dynamo-llm 1.0.2

Dynamo LLM Library
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::collections::HashSet;
use std::time::Duration;

use anyhow::Result;

use dynamo_async_openai::types::ChatCompletionRequestUserMessageContentPart;
use dynamo_memory::nixl::NixlAgent;

use super::common::EncodedMediaData;
use super::decoders::{Decoder, MediaDecoder};
use super::rdma::{RdmaMediaDataDescriptor, get_nixl_agent};

const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo";
const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30);

#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct MediaFetcher {
    pub user_agent: String,
    pub allow_direct_ip: bool,
    pub allow_direct_port: bool,
    pub allowed_media_domains: Option<HashSet<String>>,
    pub timeout: Option<Duration>,
}

impl Default for MediaFetcher {
    fn default() -> Self {
        Self {
            user_agent: DEFAULT_HTTP_USER_AGENT.to_string(),
            allow_direct_ip: false,
            allow_direct_port: false,
            allowed_media_domains: None,
            timeout: Some(DEFAULT_HTTP_TIMEOUT),
        }
    }
}

impl MediaFetcher {
    pub fn check_if_url_allowed(&self, url: &url::Url) -> Result<()> {
        if !matches!(url.scheme(), "http" | "https" | "data") {
            anyhow::bail!("Only HTTP(S) and data URLs are allowed");
        }

        if url.scheme() == "data" {
            return Ok(());
        }

        if !self.allow_direct_ip && !matches!(url.host(), Some(url::Host::Domain(_))) {
            anyhow::bail!("Direct IP access is not allowed");
        }
        if !self.allow_direct_port && url.port().is_some() {
            anyhow::bail!("Direct port access is not allowed");
        }
        if let Some(allowed_domains) = &self.allowed_media_domains
            && let Some(host) = url.host_str()
            && !allowed_domains.contains(host)
        {
            anyhow::bail!("Domain '{host}' is not in allowed list");
        }

        Ok(())
    }
}

pub struct MediaLoader {
    #[allow(dead_code)]
    media_decoder: MediaDecoder,
    #[allow(dead_code)]
    http_client: reqwest::Client,
    #[allow(dead_code)]
    media_fetcher: MediaFetcher,
    nixl_agent: NixlAgent,
}

impl MediaLoader {
    pub fn new(media_decoder: MediaDecoder, media_fetcher: Option<MediaFetcher>) -> Result<Self> {
        let media_fetcher = media_fetcher.unwrap_or_default();
        let mut http_client_builder: reqwest::ClientBuilder =
            reqwest::Client::builder().user_agent(&media_fetcher.user_agent);

        if let Some(timeout) = media_fetcher.timeout {
            http_client_builder = http_client_builder.timeout(timeout);
        }

        let http_client = http_client_builder.build()?;

        let nixl_agent = get_nixl_agent()?;

        Ok(Self {
            media_decoder,
            http_client,
            media_fetcher,
            nixl_agent,
        })
    }

    pub async fn fetch_and_decode_media_part(
        &self,
        oai_content_part: &ChatCompletionRequestUserMessageContentPart,
        media_io_kwargs: Option<&MediaDecoder>,
    ) -> Result<RdmaMediaDataDescriptor> {
        // fetch the media, decode and NIXL-register
        let decoded = match oai_content_part {
            ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
                let mdc_decoder = self
                    .media_decoder
                    .image
                    .as_ref()
                    .ok_or_else(|| anyhow::anyhow!("Model does not support image inputs"))?;

                let url = &image_part.image_url.url;
                self.media_fetcher.check_if_url_allowed(url)?;
                let data = EncodedMediaData::from_url(url, &self.http_client).await?;

                // Use runtime decoder if provided, with MDC limits enforced
                let decoder =
                    mdc_decoder.with_runtime(media_io_kwargs.and_then(|k| k.image.as_ref()));
                decoder.decode_async(data).await?
            }
            #[allow(unused_variables)]
            ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
                #[cfg(not(feature = "media-ffmpeg"))]
                anyhow::bail!("Video decoding requires the 'media-ffmpeg' feature to be enabled");

                #[cfg(feature = "media-ffmpeg")]
                {
                    let mdc_decoder =
                        self.media_decoder.video.as_ref().ok_or_else(|| {
                            anyhow::anyhow!("Model does not support video inputs")
                        })?;

                    let url = &video_part.video_url.url;
                    self.media_fetcher.check_if_url_allowed(url)?;
                    let data = EncodedMediaData::from_url(url, &self.http_client).await?;

                    // Use runtime decoder if provided, with MDC limits enforced
                    let decoder =
                        mdc_decoder.with_runtime(media_io_kwargs.and_then(|k| k.video.as_ref()));
                    decoder.decode_async(data).await?
                }
            }
            ChatCompletionRequestUserMessageContentPart::AudioUrl(_) => {
                anyhow::bail!("Audio decoding is not supported yet");
            }
            _ => anyhow::bail!("Unsupported media type"),
        };

        let rdma_descriptor = decoded.into_rdma_descriptor(&self.nixl_agent)?;
        Ok(rdma_descriptor)
    }
}

#[cfg(all(test, feature = "testing-nixl"))]
mod tests {
    use super::super::decoders::ImageDecoder;
    use super::super::rdma::DataType;
    use super::*;
    use dynamo_async_openai::types::{ChatCompletionRequestMessageContentPartImage, ImageUrl};

    #[tokio::test]
    async fn test_fetch_and_decode() {
        let test_image_bytes =
            include_bytes!("../../../tests/data/media/llm-optimize-deploy-graphic.png");

        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/llm-optimize-deploy-graphic.png")
            .with_status(200)
            .with_header("content-type", "image/png")
            .with_body(&test_image_bytes[..])
            .create_async()
            .await;

        let media_decoder = MediaDecoder {
            image: Some(ImageDecoder::default()),
            #[cfg(feature = "media-ffmpeg")]
            video: None,
        };
        let fetcher = MediaFetcher {
            allow_direct_ip: true,
            allow_direct_port: true,
            ..Default::default()
        };

        let loader: MediaLoader = match MediaLoader::new(media_decoder, Some(fetcher)) {
            Ok(l) => l,
            Err(e) => {
                println!(
                    "test test_fetch_and_decode ... ignored (NIXL/UCX not available: {})",
                    e
                );
                return;
            }
        };

        let image_url = ImageUrl::from(format!("{}/llm-optimize-deploy-graphic.png", server.url()));
        let content_part = ChatCompletionRequestUserMessageContentPart::ImageUrl(
            ChatCompletionRequestMessageContentPartImage { image_url },
        );

        let result = loader
            .fetch_and_decode_media_part(&content_part, None)
            .await;

        let descriptor = match result {
            Ok(descriptor) => descriptor,
            Err(e) if e.to_string().contains("NIXL agent is not available") => {
                println!("test test_fetch_and_decode ... ignored (NIXL agent not available)");
                return;
            }
            Err(e) => panic!("Failed to fetch and decode image: {}", e),
        };
        mock.assert_async().await;
        assert_eq!(descriptor.tensor_info.dtype, DataType::UINT8);

        // Verify image dimensions: 1,999px × 1,125px (width × height)
        // Shape format is [height, width, channels]
        assert_eq!(descriptor.tensor_info.shape.len(), 3);
        assert_eq!(
            descriptor.tensor_info.shape[0], 1125,
            "Height should be 1125"
        );
        assert_eq!(
            descriptor.tensor_info.shape[1], 1999,
            "Width should be 1999"
        );
        assert_eq!(
            descriptor.tensor_info.shape[2], 4,
            "RGBA channels should be 4"
        );

        assert!(
            descriptor.source_storage.is_some(),
            "Source storage should be present"
        );
        assert!(
            descriptor.source_storage.unwrap().is_registered(),
            "Source storage should be registered with NIXL"
        );
    }
}

#[cfg(test)]
mod tests_non_nixl {
    use super::*;

    #[test]
    fn test_direct_ip_blocked() {
        let fetcher = MediaFetcher {
            allow_direct_ip: false,
            ..Default::default()
        };

        let url = url::Url::parse("http://192.168.1.1/image.jpg").unwrap();
        let result = fetcher.check_if_url_allowed(&url);

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Direct IP access is not allowed")
        );
    }

    #[test]
    fn test_direct_port_blocked() {
        let fetcher = MediaFetcher {
            allow_direct_port: false,
            ..Default::default()
        };

        let url = url::Url::parse("http://example.com:8080/image.jpg").unwrap();
        let result = fetcher.check_if_url_allowed(&url);

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Direct port access is not allowed")
        );
    }

    #[test]
    fn test_domain_allowlist() {
        let mut allowed_domains = HashSet::new();
        allowed_domains.insert("trusted.com".to_string());
        allowed_domains.insert("example.com".to_string());

        let fetcher = MediaFetcher {
            allowed_media_domains: Some(allowed_domains),
            ..Default::default()
        };

        // Allowed domain should pass
        let url = url::Url::parse("https://trusted.com/image.jpg").unwrap();
        assert!(fetcher.check_if_url_allowed(&url).is_ok());

        // Disallowed domain should fail
        let url = url::Url::parse("https://untrusted.com/image.jpg").unwrap();
        let result = fetcher.check_if_url_allowed(&url);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("not in allowed list")
        );
    }
}