use std::collections::HashSet;
use std::net::{IpAddr, SocketAddr};
use std::sync::{Arc, LazyLock};
use std::time::Duration;
use anyhow::Result;
use ipnet::IpNet;
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use reqwest::redirect::Policy;
use dynamo_memory::nixl::NixlAgent;
use dynamo_protocols::types::ChatCompletionRequestUserMessageContentPart;
use super::common::EncodedMediaData;
use super::decoders::{Decoder, MediaDecoder};
use super::rdma::{DataType, RdmaMediaDataDescriptor, get_nixl_agent};
use lru::LruCache;
use parking_lot::Mutex;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
/// `User-Agent` header value used by `MediaFetcher::default`.
const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo";
/// Request timeout used by `MediaFetcher::default` (30 seconds).
const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Redirect limit enforced by the custom redirect policy in `build_http_client`,
/// compared against the length of the redirect chain reported by reqwest.
const MAX_REDIRECTS: usize = 3;
/// CIDR ranges that media fetches must never reach unless
/// `MediaFetcher::allow_private_ips` is set. Consulted by `is_blocked_ip`,
/// which is used for IP-literal hosts, the DNS pre-check and the client's
/// DNS resolver.
static BLOCKED_IP_NETWORKS: LazyLock<Vec<IpNet>> = LazyLock::new(|| {
    [
        // ---- IPv4 ----
        "0.0.0.0/8",          // "this network"
        "10.0.0.0/8",         // RFC 1918 private
        "100.64.0.0/10",      // shared address space (carrier-grade NAT)
        "127.0.0.0/8",        // loopback
        "169.254.0.0/16",     // link-local (includes 169.254.169.254 metadata)
        "172.16.0.0/12",      // RFC 1918 private
        "192.0.0.0/24",       // IETF protocol assignments
        "192.0.2.0/24",       // documentation (TEST-NET-1)
        "192.168.0.0/16",     // RFC 1918 private
        "198.18.0.0/15",      // benchmarking
        "198.51.100.0/24",    // documentation (TEST-NET-2)
        "203.0.113.0/24",     // documentation (TEST-NET-3)
        "224.0.0.0/4",        // multicast
        "240.0.0.0/4",        // reserved
        "255.255.255.255/32", // limited broadcast
        // ---- IPv6 ----
        "::/128",        // unspecified
        "::1/128",       // loopback
        "::ffff:0:0/96", // IPv4-mapped (blocked wholesale)
        // RFC 8215 local-use NAT64 prefix: unlike the well-known 64:ff9b::/96
        // prefix it may legitimately embed private IPv4 destinations, so an
        // address here can be a route to internal IPv4 services.
        "64:ff9b:1::/48",
        // IPv6 documentation prefix; blocked for consistency with the IPv4
        // TEST-NET ranges above.
        "2001:db8::/32",
        "fc00::/7",  // unique local
        "fe80::/10", // link-local
        "ff00::/8",  // multicast
    ]
    .iter()
    .map(|s| s.parse().expect("invalid CIDR in BLOCKED_IP_NETWORKS"))
    .collect()
});
/// Hostnames rejected outright when private-IP access is disabled.
///
/// Entries are lower-case with no trailing dot; `check_if_url_allowed`
/// normalizes the URL host the same way before looking it up here.
static BLOCKED_HOSTS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    HashSet::from([
        // Loopback aliases.
        "localhost",
        "localhost.localdomain",
        "ip6-localhost",
        "ip6-loopback",
        // Instance-metadata service names.
        "metadata",
        "metadata.google.internal",
        "metadata.goog",
        // In-cluster Kubernetes API service names.
        "kubernetes.default",
        "kubernetes.default.svc",
    ])
});
/// Returns `true` if `ip` falls inside one of the `BLOCKED_IP_NETWORKS` ranges.
///
/// IPv6 addresses in the NAT64 well-known prefix `64:ff9b::/96` (RFC 6052)
/// are additionally judged by the IPv4 address embedded in their low 32 bits.
/// On a DNS64/NAT64 network such an address is translated to that IPv4
/// destination. Without this check, `64:ff9b::a9fe:a9fe` could be used to
/// reach `169.254.169.254`, even though the IPv4 form is blocked.
pub fn is_blocked_ip(ip: &IpAddr) -> bool {
    /// First six 16-bit segments of the NAT64 well-known prefix `64:ff9b::/96`.
    const NAT64_WELL_KNOWN_PREFIX: [u16; 6] = [0x0064, 0xff9b, 0, 0, 0, 0];

    let in_blocklist = |addr: &IpAddr| BLOCKED_IP_NETWORKS.iter().any(|net| net.contains(addr));

    if in_blocklist(ip) {
        return true;
    }
    match ip {
        IpAddr::V6(v6) if v6.segments()[..6] == NAT64_WELL_KNOWN_PREFIX => {
            // /96 prefix: the IPv4 destination occupies the last four octets.
            let [.., a, b, c, d] = v6.octets();
            in_blocklist(&IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d)))
        }
        _ => false,
    }
}
/// Network policy and HTTP settings used when fetching remote media.
///
/// Every `allow_*` relaxation is `false` in `MediaFetcher::default()`;
/// `MediaFetcher::from_env()` can turn all three on together.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct MediaFetcher {
/// Value sent as the `User-Agent` header by the client from `build_http_client`.
pub user_agent: String,
/// When `false`, URLs whose host is an IP literal rather than a domain name are rejected.
pub allow_direct_ip: bool,
/// When `false`, URLs for which `Url::port()` returns `Some` are rejected.
pub allow_direct_port: bool,
/// When `false`, blocklisted hostnames, IP literals in blocked ranges and DNS
/// answers in blocked ranges are all rejected; when `true` those checks are skipped.
pub allow_private_ips: bool,
/// Optional allowlist checked against the URL host of non-`data:` URLs;
/// `None` disables the check.
pub allowed_media_domains: Option<HashSet<String>>,
/// Request timeout applied to the HTTP client; `None` means no client timeout is configured.
pub timeout: Option<Duration>,
}
impl Default for MediaFetcher {
    /// Locked-down defaults.
    ///
    /// - IP-literal hosts, explicit ports and private/internal destinations are refused.
    /// - No domain allowlist is set.
    /// - Requests time out after `DEFAULT_HTTP_TIMEOUT`.
    fn default() -> Self {
        let user_agent = String::from(DEFAULT_HTTP_USER_AGENT);
        let timeout = Some(DEFAULT_HTTP_TIMEOUT);
        Self {
            allowed_media_domains: None,
            allow_direct_ip: false,
            allow_direct_port: false,
            allow_private_ips: false,
            timeout,
            user_agent,
        }
    }
}
impl MediaFetcher {
    /// Builds a fetcher from the process environment.
    ///
    /// If `DYN_MM_ALLOW_INTERNAL` is set to exactly `"1"`, direct-IP hosts,
    /// explicit ports and private/internal destinations are all permitted.
    /// Any other value, or an unset or non-UTF-8 variable, keeps the
    /// locked-down defaults. All remaining fields come from `Default`.
    pub fn from_env() -> Self {
        let relax = matches!(std::env::var("DYN_MM_ALLOW_INTERNAL").as_deref(), Ok("1"));
        let mut fetcher = Self::default();
        fetcher.allow_direct_ip = relax;
        fetcher.allow_direct_port = relax;
        fetcher.allow_private_ips = relax;
        fetcher
    }
}
impl MediaFetcher {
    /// Synchronous policy check for a media URL (performs no DNS lookups).
    ///
    /// Rules, applied in order:
    /// 1. The scheme must be `http`, `https` or `data`. `data:` URLs are
    ///    accepted immediately, since nothing is fetched over the network.
    /// 2. The URL must have a host.
    /// 3. Unless `allow_direct_ip` is set, the host must be a domain name.
    /// 4. Unless `allow_direct_port` is set, no explicit port may be present.
    /// 5. Unless `allow_private_ips` is set:
    ///    - a domain host must not be in `BLOCKED_HOSTS` (compared
    ///      case-insensitively, ignoring a trailing dot);
    ///    - an IP-literal host must not be in a blocked range.
    /// 6. If `allowed_media_domains` is set, the host must match one of its
    ///    entries. The match is ASCII case-insensitive and ignores a trailing
    ///    root dot, mirroring the `BLOCKED_HOSTS` normalization.
    ///
    /// # Errors
    /// Returns an error naming the first rule the URL violates.
    pub fn check_if_url_allowed(&self, url: &url::Url) -> Result<()> {
        if !matches!(url.scheme(), "http" | "https" | "data") {
            anyhow::bail!("Only HTTP(S) and data URLs are allowed");
        }
        if url.scheme() == "data" {
            return Ok(());
        }
        let host = url
            .host()
            .ok_or_else(|| anyhow::anyhow!("URL has no host component"))?;
        if !self.allow_direct_ip && !matches!(host, url::Host::Domain(_)) {
            anyhow::bail!("Direct IP access is not allowed");
        }
        if !self.allow_direct_port && url.port().is_some() {
            anyhow::bail!("Direct port access is not allowed");
        }
        if !self.allow_private_ips {
            let ip_literal = match host {
                url::Host::Domain(domain) => {
                    let lowered = domain.trim_end_matches('.').to_ascii_lowercase();
                    if BLOCKED_HOSTS.contains(lowered.as_str()) {
                        anyhow::bail!("Host '{domain}' is blocked (resolves to internal service)");
                    }
                    None
                }
                url::Host::Ipv4(ip) => Some(IpAddr::V4(ip)),
                url::Host::Ipv6(ip) => Some(IpAddr::V6(ip)),
            };
            if let Some(ip) = ip_literal
                && is_blocked_ip(&ip)
            {
                anyhow::bail!("IP literal '{ip}' is in a blocked range");
            }
        }
        if let Some(allowed_domains) = &self.allowed_media_domains
            && let Some(host_str) = url.host_str()
            && !Self::host_in_allowlist(allowed_domains, host_str)
        {
            anyhow::bail!("Host '{host_str}' is not in the allowed_media_domains list");
        }
        Ok(())
    }

    /// Returns `true` if `host` matches an allowlist entry, ignoring ASCII
    /// case and a single trailing root dot on either side.
    ///
    /// Exact membership is tried first (the common case, O(1)). The
    /// normalized scan only runs when that misses, so configured entries
    /// such as `"Images.Example.com"` or `"example.com."` still match.
    fn host_in_allowlist(allowed: &HashSet<String>, host: &str) -> bool {
        if allowed.contains(host) {
            return true;
        }
        let wanted = host.trim_end_matches('.');
        allowed
            .iter()
            .any(|entry| entry.trim_end_matches('.').eq_ignore_ascii_case(wanted))
    }

    /// `check_if_url_allowed` plus a DNS pre-check for domain hosts.
    ///
    /// The DNS step is skipped when `allow_private_ips` is set and for `data:`
    /// URLs. IP-literal hosts are not resolved; step 5 of
    /// `check_if_url_allowed` has already judged them.
    ///
    /// For a domain host, the name is resolved with `tokio::net::lookup_host`.
    /// The URL is rejected if *any* returned address is blocked.
    ///
    /// This is an early, best-effort rejection. The client from
    /// `build_http_client` independently filters addresses at connect time
    /// through `BlocklistResolver`, because DNS answers can change between
    /// this check and the real connection.
    ///
    /// # Errors
    /// Returns an error if the policy check fails, resolution fails, or a
    /// resolved address is in a blocked range.
    pub async fn check_if_url_allowed_with_dns(&self, url: &url::Url) -> Result<()> {
        self.check_if_url_allowed(url)?;
        if self.allow_private_ips || url.scheme() == "data" {
            return Ok(());
        }
        let Some(url::Host::Domain(host)) = url.host() else {
            return Ok(());
        };
        let port = url.port_or_known_default().unwrap_or(0);
        let iter = tokio::net::lookup_host((host, port))
            .await
            .map_err(|e| anyhow::anyhow!("Could not resolve host '{host}': {e}"))?;
        for sock_addr in iter {
            let ip = sock_addr.ip();
            if is_blocked_ip(&ip) {
                anyhow::bail!("Host '{host}' resolves to blocked IP '{ip}'");
            }
        }
        Ok(())
    }

    /// Builds a reqwest client that enforces this fetcher's policy on every hop.
    ///
    /// - Redirects are refused once the redirect chain reaches
    ///   `MAX_REDIRECTS`. Each redirect target must also pass
    ///   `check_if_url_allowed`, using a clone of this fetcher.
    /// - Hostnames are resolved through `BlocklistResolver`, which drops
    ///   blocked addresses unless `allow_private_ips` is set.
    /// - `user_agent` and the optional `timeout` are applied.
    ///
    /// NOTE(review): if requests are routed through an HTTP proxy, the proxy
    /// resolves target hostnames and this resolver only sees the proxy's own
    /// name. Confirm no proxy is configured where SSRF protection matters.
    pub fn build_http_client(&self) -> Result<reqwest::Client> {
        let fetcher_for_redirects = self.clone();
        let redirect_policy = Policy::custom(move |attempt| {
            if attempt.previous().len() >= MAX_REDIRECTS {
                return attempt.error(anyhow::anyhow!("too many redirects (max={MAX_REDIRECTS})"));
            }
            match fetcher_for_redirects.check_if_url_allowed(attempt.url()) {
                Ok(()) => attempt.follow(),
                Err(e) => attempt.error(e),
            }
        });
        let mut builder = reqwest::Client::builder()
            .user_agent(&self.user_agent)
            .redirect(redirect_policy)
            .dns_resolver(Arc::new(BlocklistResolver {
                allow_private_ips: self.allow_private_ips,
            }));
        if let Some(timeout) = self.timeout {
            builder = builder.timeout(timeout);
        }
        Ok(builder.build()?)
    }
}
/// DNS resolver installed into the media HTTP client by `build_http_client`.
///
/// Removes addresses rejected by `is_blocked_ip` from every lookup, so the
/// IP check also applies to the addresses actually used for connecting.
struct BlocklistResolver {
/// When `true`, resolved addresses are returned unfiltered.
allow_private_ips: bool,
}
impl Resolve for BlocklistResolver {
    /// Resolves `name` via `tokio::net::lookup_host` (with port 0).
    ///
    /// Unless private addresses are allowed, every address rejected by
    /// `is_blocked_ip` is discarded.
    ///
    /// Fails with `AddrNotAvailable` when no usable address remains, so a
    /// host that resolves only to internal IPs cannot be connected to.
    fn resolve(&self, name: Name) -> Resolving {
        let allow_private = self.allow_private_ips;
        let host = name.as_str().to_string();
        Box::pin(async move {
            let candidates = tokio::net::lookup_host((host.as_str(), 0_u16)).await?;
            let usable: Vec<SocketAddr> = candidates
                .filter(|addr| allow_private || !is_blocked_ip(&addr.ip()))
                .collect();
            if usable.is_empty() {
                let err: Box<dyn std::error::Error + Send + Sync> =
                    Box::new(std::io::Error::new(
                        std::io::ErrorKind::AddrNotAvailable,
                        format!("no non-blocked addresses for host '{host}'"),
                    ));
                return Err(err);
            }
            Ok(Box::new(usable.into_iter()) as Addrs)
        })
    }
}
/// Byte-budgeted LRU cache of decoded media descriptors, used by `MediaLoader`.
struct LoaderCache {
/// Entries keyed by `MediaLoader::cache_key`. The cache has no entry-count
/// limit (`LruCache::unbounded`); its size is bounded only by `budget_bytes`.
lru: LruCache<u64, RdmaMediaDataDescriptor>,
/// Running sum of `descriptor_bytes` over the entries currently stored.
bytes_used: u64,
/// Upper bound for `bytes_used`; least-recently-used entries are evicted past it.
budget_bytes: u64,
}
impl LoaderCache {
fn new(budget_bytes: u64) -> Self {
Self {
lru: LruCache::unbounded(),
bytes_used: 0,
budget_bytes,
}
}
fn get(&mut self, key: &u64) -> Option<RdmaMediaDataDescriptor> {
self.lru.get(key).cloned()
}
fn put(&mut self, key: u64, val: RdmaMediaDataDescriptor) {
let val_bytes = descriptor_bytes(&val);
if let Some(old) = self.lru.pop(&key) {
self.bytes_used = self.bytes_used.saturating_sub(descriptor_bytes(&old));
}
self.lru.put(key, val);
self.bytes_used = self.bytes_used.saturating_add(val_bytes);
while self.bytes_used > self.budget_bytes && !self.lru.is_empty() {
if let Some((_, old)) = self.lru.pop_lru() {
self.bytes_used = self.bytes_used.saturating_sub(descriptor_bytes(&old));
} else {
break;
}
}
}
fn len(&self) -> usize {
self.lru.len()
}
}
/// Byte size of the tensor described by `d`: the product of its shape
/// dimensions times the element width of its dtype.
///
/// If the element count overflows `u64`, the result is `u64::MAX`. The final
/// multiplication also saturates rather than wrapping.
fn descriptor_bytes(d: &RdmaMediaDataDescriptor) -> u64 {
    let width: u64 = match d.tensor_info.dtype {
        DataType::UINT8 => 1,
    };
    let mut count = Some(1u64);
    for &dim in d.tensor_info.shape.iter() {
        count = count.and_then(|n| n.checked_mul(dim as u64));
    }
    count.unwrap_or(u64::MAX).saturating_mul(width)
}
/// Fetches media parts, decodes them and converts the result into RDMA
/// descriptors. Image results can optionally be cached.
pub struct MediaLoader {
/// Per-modality decoders; a missing decoder means that modality is unsupported.
/// NOTE(review): this field is read in `fetch_and_decode_media_part`, so the
/// `dead_code` allow may be stale — confirm before removing.
#[allow(dead_code)]
media_decoder: MediaDecoder,
/// Client built by `MediaFetcher::build_http_client` (redirect and DNS policy enforced).
#[allow(dead_code)]
http_client: reqwest::Client,
/// URL policy checked before every network fetch.
#[allow(dead_code)]
media_fetcher: MediaFetcher,
/// Agent passed to `into_rdma_descriptor` for decoded media.
nixl_agent: NixlAgent,
/// Shared image cache; `None` when the configured budget is 0.
cache: Option<Arc<Mutex<LoaderCache>>>,
}
impl MediaLoader {
fn cache_budget_bytes_from_env() -> u64 {
let gb = std::env::var("DYN_MULTIMODAL_LOADER_CACHE_GB")
.ok()
.and_then(|s| s.parse::<f64>().ok())
.filter(|v| v.is_finite() && *v >= 0.0)
.unwrap_or(0.0);
(gb * (1024.0 * 1024.0 * 1024.0)) as u64
}
fn cache_key(url: &str) -> u64 {
let mut h = DefaultHasher::new();
url.hash(&mut h);
h.finish()
}
pub fn new(media_decoder: MediaDecoder, media_fetcher: Option<MediaFetcher>) -> Result<Self> {
let media_fetcher = media_fetcher.unwrap_or_else(MediaFetcher::from_env);
let http_client = media_fetcher.build_http_client()?;
let nixl_agent = get_nixl_agent()?;
let cache = match Self::cache_budget_bytes_from_env() {
0 => {
tracing::debug!(
"[mm-cache] frontend media cache disabled (DYN_MULTIMODAL_LOADER_CACHE_GB=0)"
);
None
}
budget => {
tracing::info!(
budget_bytes = budget,
"[mm-cache] frontend media cache enabled (DYN_MULTIMODAL_LOADER_CACHE_GB)"
);
Some(Arc::new(Mutex::new(LoaderCache::new(budget))))
}
};
Ok(Self {
media_decoder,
http_client,
media_fetcher,
nixl_agent,
cache,
})
}
#[cfg(test)]
pub fn with_cache_budget_bytes(
media_decoder: MediaDecoder,
media_fetcher: Option<MediaFetcher>,
budget_bytes: u64,
) -> Result<Self> {
let mut loader = Self::new(media_decoder, media_fetcher)?;
loader.cache = if budget_bytes == 0 {
None
} else {
Some(Arc::new(Mutex::new(LoaderCache::new(budget_bytes))))
};
Ok(loader)
}
pub fn cache_len(&self) -> usize {
self.cache.as_ref().map(|c| c.lock().len()).unwrap_or(0)
}
pub async fn fetch_and_decode_media_part(
&self,
oai_content_part: &ChatCompletionRequestUserMessageContentPart,
media_io_kwargs: Option<&MediaDecoder>,
) -> Result<RdmaMediaDataDescriptor> {
if let (Some(cache), ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part)) =
(self.cache.as_ref(), oai_content_part)
{
if media_io_kwargs.is_none() {
let key = Self::cache_key(image_part.image_url.url.as_str());
if let Some(hit) = cache.lock().get(&key) {
tracing::debug!(url_hash = key, "[mm-cache] hit");
return Ok(hit);
}
}
}
let decoded = match oai_content_part {
ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
let mdc_decoder = self
.media_decoder
.image
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Model does not support image inputs"))?;
let url = &image_part.image_url.url;
self.media_fetcher
.check_if_url_allowed_with_dns(url)
.await?;
let data = EncodedMediaData::from_url(url, &self.http_client).await?;
let decoder =
mdc_decoder.with_runtime(media_io_kwargs.and_then(|k| k.image.as_ref()));
decoder.decode_async(data).await?
}
#[allow(unused_variables)]
ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
#[cfg(not(feature = "media-ffmpeg"))]
anyhow::bail!("Video decoding requires the 'media-ffmpeg' feature to be enabled");
#[cfg(feature = "media-ffmpeg")]
{
let mdc_decoder =
self.media_decoder.video.as_ref().ok_or_else(|| {
anyhow::anyhow!("Model does not support video inputs")
})?;
let url = &video_part.video_url.url;
self.media_fetcher
.check_if_url_allowed_with_dns(url)
.await?;
let data = EncodedMediaData::from_url(url, &self.http_client).await?;
let decoder =
mdc_decoder.with_runtime(media_io_kwargs.and_then(|k| k.video.as_ref()));
decoder.decode_async(data).await?
}
}
ChatCompletionRequestUserMessageContentPart::AudioUrl(_) => {
anyhow::bail!("Audio decoding is not supported yet");
}
_ => anyhow::bail!("Unsupported media type"),
};
let rdma_descriptor = decoded.into_rdma_descriptor(&self.nixl_agent)?;
if let (Some(cache), ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part)) =
(self.cache.as_ref(), oai_content_part)
&& media_io_kwargs.is_none()
{
let key = Self::cache_key(image_part.image_url.url.as_str());
let bytes = descriptor_bytes(&rdma_descriptor);
cache.lock().put(key, rdma_descriptor.clone());
tracing::debug!(url_hash = key, bytes, "[mm-cache] insert");
}
Ok(rdma_descriptor)
}
}
// Integration tests that need a NIXL agent (`testing-nixl` feature) and a local
// mockito HTTP server. Each test prints an "ignored" line and returns early,
// instead of failing, when the loader or the NIXL agent cannot be created.
#[cfg(all(test, feature = "testing-nixl"))]
mod tests {
use super::super::decoders::ImageDecoder;
use super::super::rdma::DataType;
use super::*;
use dynamo_protocols::types::{ChatCompletionRequestMessageContentPartImage, ImageUrl};
/// End-to-end fetch of a PNG from the mock server. Checks the decoded tensor's
/// dtype and shape, and that its source storage is registered with NIXL.
#[tokio::test]
async fn test_fetch_and_decode() {
let test_image_bytes =
include_bytes!("../../../tests/data/media/llm-optimize-deploy-graphic.png");
let mut server = mockito::Server::new_async().await;
let mock = server
.mock("GET", "/llm-optimize-deploy-graphic.png")
.with_status(200)
.with_header("content-type", "image/png")
.with_body(&test_image_bytes[..])
.create_async()
.await;
let media_decoder = MediaDecoder {
image: Some(ImageDecoder::default()),
#[cfg(feature = "media-ffmpeg")]
video: None,
};
// The mock server listens on a local IP and port, so every relaxation is enabled.
let fetcher = MediaFetcher {
allow_direct_ip: true,
allow_direct_port: true,
allow_private_ips: true,
..Default::default()
};
let loader: MediaLoader = match MediaLoader::new(media_decoder, Some(fetcher)) {
Ok(l) => l,
Err(e) => {
println!(
"test test_fetch_and_decode ... ignored (NIXL/UCX not available: {})",
e
);
return;
}
};
let image_url = ImageUrl::from(format!("{}/llm-optimize-deploy-graphic.png", server.url()));
let content_part = ChatCompletionRequestUserMessageContentPart::ImageUrl(
ChatCompletionRequestMessageContentPartImage { image_url },
);
let result = loader
.fetch_and_decode_media_part(&content_part, None)
.await;
let descriptor = match result {
Ok(descriptor) => descriptor,
Err(e) if e.to_string().contains("NIXL agent is not available") => {
println!("test test_fetch_and_decode ... ignored (NIXL agent not available)");
return;
}
Err(e) => panic!("Failed to fetch and decode image: {}", e),
};
mock.assert_async().await;
assert_eq!(descriptor.tensor_info.dtype, DataType::UINT8);
assert_eq!(descriptor.tensor_info.shape.len(), 3);
assert_eq!(
descriptor.tensor_info.shape[0], 1125,
"Height should be 1125"
);
assert_eq!(
descriptor.tensor_info.shape[1], 1999,
"Width should be 1999"
);
assert_eq!(
descriptor.tensor_info.shape[2], 4,
"RGBA channels should be 4"
);
assert!(
descriptor.source_storage.is_some(),
"Source storage should be present"
);
assert!(
descriptor.source_storage.unwrap().is_registered(),
"Source storage should be registered with NIXL"
);
}
/// With the cache enabled, a repeated request for the same URL must be served
/// from the cache. The mock allows exactly one GET, and both results must share
/// the same `source_storage` Arc.
#[tokio::test]
async fn test_cache_hit_skips_second_fetch() {
let test_image_bytes =
include_bytes!("../../../tests/data/media/llm-optimize-deploy-graphic.png");
let mut server = mockito::Server::new_async().await;
let mock = server
.mock("GET", "/cache-image.png")
.with_status(200)
.with_header("content-type", "image/png")
.with_body(&test_image_bytes[..])
.expect(1)
.create_async()
.await;
let media_decoder = MediaDecoder {
image: Some(ImageDecoder::default()),
#[cfg(feature = "media-ffmpeg")]
video: None,
};
let fetcher = MediaFetcher {
allow_direct_ip: true,
allow_direct_port: true,
allow_private_ips: true,
..Default::default()
};
// 100 MiB budget: comfortably larger than one decoded fixture image.
let loader = match MediaLoader::with_cache_budget_bytes(
media_decoder,
Some(fetcher),
100 * 1024 * 1024,
) {
Ok(l) => l,
Err(e) => {
println!("test_cache_hit_skips_second_fetch ... ignored ({})", e);
return;
}
};
assert_eq!(loader.cache_len(), 0, "cache should start empty");
let url_string = format!("{}/cache-image.png", server.url());
let image_url = ImageUrl::from(url_string);
let content_part = ChatCompletionRequestUserMessageContentPart::ImageUrl(
ChatCompletionRequestMessageContentPartImage { image_url },
);
let first = match loader
.fetch_and_decode_media_part(&content_part, None)
.await
{
Ok(d) => d,
Err(e) if e.to_string().contains("NIXL agent is not available") => {
println!("test_cache_hit_skips_second_fetch ... ignored (NIXL not available)");
return;
}
Err(e) => panic!("first fetch failed: {}", e),
};
assert_eq!(loader.cache_len(), 1, "cache should hold one entry");
let second = loader
.fetch_and_decode_media_part(&content_part, None)
.await
.expect("second fetch should hit cache");
match (&first.source_storage, &second.source_storage) {
(Some(a), Some(b)) => assert!(
Arc::ptr_eq(a, b),
"cache hit should return the same Arc'd source_storage"
),
_ => panic!("source_storage missing from one of the descriptors"),
}
mock.assert_async().await;
}
/// Exercises byte-budget LRU eviction. The fixture decodes to
/// 1125 x 1999 x 4 = 8,995,500 bytes (the shape asserted in
/// `test_fetch_and_decode`), so the 18 MiB budget holds two entries but not three.
#[tokio::test]
async fn test_cache_budget_lru_eviction() {
let test_image_bytes =
include_bytes!("../../../tests/data/media/llm-optimize-deploy-graphic.png");
let mut server = mockito::Server::new_async().await;
let mock_a = server
.mock("GET", "/a.png")
.with_status(200)
.with_header("content-type", "image/png")
.with_body(&test_image_bytes[..])
.expect(2)
.create_async()
.await;
let mock_b = server
.mock("GET", "/b.png")
.with_status(200)
.with_header("content-type", "image/png")
.with_body(&test_image_bytes[..])
.expect(1)
.create_async()
.await;
let mock_c = server
.mock("GET", "/c.png")
.with_status(200)
.with_header("content-type", "image/png")
.with_body(&test_image_bytes[..])
.expect(1)
.create_async()
.await;
let media_decoder = MediaDecoder {
image: Some(ImageDecoder::default()),
#[cfg(feature = "media-ffmpeg")]
video: None,
};
let fetcher = MediaFetcher {
allow_direct_ip: true,
allow_direct_port: true,
allow_private_ips: true,
..Default::default()
};
let loader = match MediaLoader::with_cache_budget_bytes(
media_decoder,
Some(fetcher),
18 * 1024 * 1024,
) {
Ok(l) => l,
Err(e) => {
println!("test_cache_budget_lru_eviction ... ignored ({})", e);
return;
}
};
let make_part = |path: &str| {
let image_url = ImageUrl::from(format!("{}{}", server.url(), path));
ChatCompletionRequestUserMessageContentPart::ImageUrl(
ChatCompletionRequestMessageContentPartImage { image_url },
)
};
let part_a = make_part("/a.png");
let part_b = make_part("/b.png");
let part_c = make_part("/c.png");
// Cold-fetch a then b; both fit in the budget.
for (label, part) in [("a", &part_a), ("b", &part_b)] {
match loader.fetch_and_decode_media_part(part, None).await {
Ok(_) => {}
Err(e) if e.to_string().contains("NIXL agent is not available") => {
println!(
"test_cache_budget_lru_eviction ... ignored (NIXL not available, fetch={})",
label
);
return;
}
Err(e) => panic!("fetch {} failed: {}", label, e),
}
}
assert_eq!(loader.cache_len(), 2);
// Re-request b (a cache hit) so that a becomes least-recently-used.
loader
.fetch_and_decode_media_part(&part_b, None)
.await
.expect("b re-fetch should hit cache");
// Inserting c goes over budget and evicts a, the LRU entry.
loader
.fetch_and_decode_media_part(&part_c, None)
.await
.expect("c cold fetch should succeed");
assert_eq!(loader.cache_len(), 2);
// a is fetched from the server again (mock_a expects 2 GETs); inserting it evicts b.
loader
.fetch_and_decode_media_part(&part_a, None)
.await
.expect("a re-fetch after eviction should succeed");
assert_eq!(loader.cache_len(), 2);
mock_a.assert_async().await;
mock_b.assert_async().await;
mock_c.assert_async().await;
}
}
// Unit tests that need no NIXL agent. The `_with_dns` cases use a data: URL,
// an IP literal or a blocklisted hostname, and each of these returns before
// any DNS lookup is made.
#[cfg(test)]
mod tests_non_nixl {
use super::*;
/// Same URL gives the same key; different URLs give different keys.
#[test]
fn test_cache_key_is_stable_per_url() {
let u1 = "http://images.example.com/a.png";
let u2 = "http://images.example.com/b.png";
assert_eq!(MediaLoader::cache_key(u1), MediaLoader::cache_key(u1));
assert_ne!(MediaLoader::cache_key(u1), MediaLoader::cache_key(u2));
let datauri = "data:image/png;base64,iVBORw0KGgoAAAA...".to_string();
assert_eq!(
MediaLoader::cache_key(&datauri),
MediaLoader::cache_key(&datauri)
);
}
/// Parsing of `DYN_MULTIMODAL_LOADER_CACHE_GB`: unset, integral, fractional,
/// non-numeric and negative values. The previous value is restored at the end.
///
/// NOTE(review): `set_var`/`remove_var` are `unsafe` because other threads may
/// read the environment concurrently. `MediaLoader::new` reads this same
/// variable, so running this test alongside loader-constructing tests in the
/// same process can race. Consider serializing env-mutating tests.
#[test]
fn test_cache_budget_from_env_default_zero() {
const VAR: &str = "DYN_MULTIMODAL_LOADER_CACHE_GB";
const GIB: u64 = 1024 * 1024 * 1024;
let prev = std::env::var(VAR).ok();
unsafe {
std::env::remove_var(VAR);
}
assert_eq!(MediaLoader::cache_budget_bytes_from_env(), 0);
unsafe {
std::env::set_var(VAR, "1");
}
assert_eq!(MediaLoader::cache_budget_bytes_from_env(), GIB);
unsafe {
std::env::set_var(VAR, "0.5");
}
assert_eq!(MediaLoader::cache_budget_bytes_from_env(), GIB / 2);
unsafe {
std::env::set_var(VAR, "not-a-number");
}
assert_eq!(
MediaLoader::cache_budget_bytes_from_env(),
0,
"non-numeric value should fall back to 0"
);
unsafe {
std::env::set_var(VAR, "-1");
}
assert_eq!(MediaLoader::cache_budget_bytes_from_env(), 0);
match prev {
Some(v) => unsafe { std::env::set_var(VAR, v) },
None => unsafe { std::env::remove_var(VAR) },
}
}
/// An IP-literal host is rejected when `allow_direct_ip` is false.
#[test]
fn test_direct_ip_blocked() {
let fetcher = MediaFetcher {
allow_direct_ip: false,
..Default::default()
};
let url = url::Url::parse("http://192.168.1.1/image.jpg").unwrap();
let result = fetcher.check_if_url_allowed(&url);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Direct IP access is not allowed")
);
}
/// An explicit non-default port is rejected when `allow_direct_port` is false.
#[test]
fn test_direct_port_blocked() {
let fetcher = MediaFetcher {
allow_direct_port: false,
..Default::default()
};
let url = url::Url::parse("http://example.com:8080/image.jpg").unwrap();
let result = fetcher.check_if_url_allowed(&url);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Direct port access is not allowed")
);
}
/// Only hosts listed in `allowed_media_domains` pass when the allowlist is set.
#[test]
fn test_domain_allowlist() {
let mut allowed_domains = HashSet::new();
allowed_domains.insert("trusted.com".to_string());
allowed_domains.insert("example.com".to_string());
let fetcher = MediaFetcher {
allowed_media_domains: Some(allowed_domains),
..Default::default()
};
let url = url::Url::parse("https://trusted.com/image.jpg").unwrap();
assert!(fetcher.check_if_url_allowed(&url).is_ok());
let url = url::Url::parse("https://untrusted.com/image.jpg").unwrap();
let result = fetcher.check_if_url_allowed(&url);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("allowed_media_domains")
);
}
/// Representative private, loopback and link-local addresses are blocked;
/// well-known public resolvers are not.
#[test]
fn test_is_blocked_ip_ranges() {
for ip in [
"127.0.0.1",
"10.0.0.1",
"172.16.5.5",
"192.168.1.1",
"169.254.169.254", "100.64.0.1", "::1",
"fe80::1",
"fc00::1",
] {
let addr: IpAddr = ip.parse().unwrap();
assert!(is_blocked_ip(&addr), "{ip} should be blocked");
}
for ip in ["8.8.8.8", "1.1.1.1", "2606:4700:4700::1111"] {
let addr: IpAddr = ip.parse().unwrap();
assert!(!is_blocked_ip(&addr), "{ip} should not be blocked");
}
}
/// `allow_direct_ip` alone does not open blocked ranges such as the metadata address.
#[test]
fn test_blocked_ip_literal_rejected_even_when_direct_ip_allowed() {
let fetcher = MediaFetcher {
allow_direct_ip: true,
..Default::default()
};
let url = url::Url::parse("http://169.254.169.254/latest/meta-data/").unwrap();
let result = fetcher.check_if_url_allowed(&url);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("is in a blocked range")
);
}
/// Names in `BLOCKED_HOSTS` are rejected with a "blocked" error.
#[test]
fn test_blocked_hostname_rejected() {
let fetcher = MediaFetcher::default();
for host in [
"localhost",
"metadata.google.internal",
"kubernetes.default.svc",
] {
let url = url::Url::parse(&format!("https://{host}/x")).unwrap();
let result = fetcher.check_if_url_allowed(&url);
assert!(result.is_err(), "{host} should be blocked");
assert!(
result.unwrap_err().to_string().contains("blocked"),
"{host} error should mention 'blocked'"
);
}
}
/// `allow_private_ips` skips both the IP-range and the hostname blocklists.
#[test]
fn test_allow_private_ips_bypasses_blocklist() {
let fetcher = MediaFetcher {
allow_direct_ip: true,
allow_private_ips: true,
..Default::default()
};
assert!(
fetcher
.check_if_url_allowed(&url::Url::parse("http://10.0.0.5/x").unwrap())
.is_ok()
);
assert!(
fetcher
.check_if_url_allowed(&url::Url::parse("https://localhost/x").unwrap())
.is_ok()
);
}
/// Hostname blocklist matching ignores case.
#[test]
fn test_hostname_blocklist_case_insensitive() {
let fetcher = MediaFetcher::default();
let url = url::Url::parse("https://Metadata.Google.Internal/x").unwrap();
let result = fetcher.check_if_url_allowed(&url);
assert!(result.is_err());
}
/// With `DYN_MM_ALLOW_INTERNAL` unset, `from_env` keeps every relaxation off.
/// NOTE(review): the variable is removed without restoring any previous value.
#[test]
fn test_from_env_default() {
unsafe {
std::env::remove_var("DYN_MM_ALLOW_INTERNAL");
}
let f = MediaFetcher::from_env();
assert!(!f.allow_private_ips);
assert!(!f.allow_direct_ip);
assert!(!f.allow_direct_port);
}
/// A trailing root dot does not bypass the hostname blocklist.
#[test]
fn test_hostname_blocklist_strips_trailing_dot() {
let fetcher = MediaFetcher::default();
let url = url::Url::parse("https://metadata.google.internal./x").unwrap();
let result = fetcher.check_if_url_allowed(&url);
assert!(result.is_err(), "FQDN with trailing dot should be rejected");
}
/// data: URLs return before any resolution.
#[tokio::test]
async fn test_check_with_dns_data_url_skips_resolution() {
let fetcher = MediaFetcher::default();
let url = url::Url::parse("data:image/png;base64,iVBORw0KGgoAAAA=").unwrap();
fetcher.check_if_url_allowed_with_dns(&url).await.unwrap();
}
/// A public IP literal passes; IP hosts are never resolved.
#[tokio::test]
async fn test_check_with_dns_public_ip_literal_passes() {
let fetcher = MediaFetcher {
allow_direct_ip: true,
..Default::default()
};
let url = url::Url::parse("https://8.8.8.8/x").unwrap();
fetcher.check_if_url_allowed_with_dns(&url).await.unwrap();
}
/// A blocklisted hostname fails in the synchronous check, before DNS.
#[tokio::test]
async fn test_check_with_dns_blocked_hostname_fails_before_resolution() {
let fetcher = MediaFetcher::default();
let url = url::Url::parse("https://localhost/x").unwrap();
let result = fetcher.check_if_url_allowed_with_dns(&url).await;
assert!(result.is_err());
}
}