use anyhow::{Context, Result};
use super::{decode_base64_image, detect_format, Encoding, ImageMetadata, MediaFormat};
use crate::mode::SafetyLimits;
pub fn inspect_bytes(data: &[u8]) -> Result<ImageMetadata> {
let format = detect_format(data);
match format {
MediaFormat::Svg => inspect_svg(data),
_ if format.is_image() => inspect_raster(data, format),
_ => anyhow::bail!("not a recognized image format"),
}
}
pub fn inspect_base64(input: &str) -> Result<ImageMetadata> {
let (bytes, _mime_hint) = decode_base64_image(input)?;
let mut meta = inspect_bytes(&bytes)?;
meta.encoding = Encoding::Base64;
meta.size_bytes = bytes.len(); Ok(meta)
}
pub fn inspect_url(url: &str) -> Result<ImageMetadata> {
inspect_url_with_limits(url, &SafetyLimits::default())
}
pub fn inspect_url_with_limits(url: &str, limits: &SafetyLimits) -> Result<ImageMetadata> {
validate_url(url)?;
let bytes = fetch_url_safe(url, limits)?;
let mut meta = inspect_bytes(&bytes)?;
meta.encoding = Encoding::Url(url.to_string());
meta.size_bytes = bytes.len();
Ok(meta)
}
fn validate_url(input: &str) -> Result<()> {
let parsed = url::Url::parse(input).context("invalid URL")?;
match parsed.scheme() {
"https" | "http" => {}
scheme => anyhow::bail!(
"unsupported URL scheme '{}': only http/https allowed",
scheme
),
}
let host = parsed.host_str().context("URL missing host")?;
if host == "localhost" || host == "metadata.google.internal" {
anyhow::bail!("URL host '{}' is not allowed", host);
}
if host.starts_with("0x") || host.starts_with("0X") {
anyhow::bail!("URL host appears to be a hex-encoded IP address");
}
match parsed.host() {
Some(url::Host::Ipv4(ip)) => {
if is_private_ip(&std::net::IpAddr::V4(ip)) {
anyhow::bail!("URL contains a private/loopback IP address");
}
}
Some(url::Host::Ipv6(ip)) => {
if is_private_ip(&std::net::IpAddr::V6(ip)) {
anyhow::bail!("URL contains a private/loopback IP address");
}
}
Some(url::Host::Domain(_)) => {
let port = parsed
.port()
.unwrap_or(if parsed.scheme() == "https" { 443 } else { 80 });
if let Ok(addrs) = std::net::ToSocketAddrs::to_socket_addrs(&(host, port)) {
for addr in addrs {
if is_private_ip(&addr.ip()) {
anyhow::bail!("URL hostname resolves to a private/loopback IP address");
}
}
}
}
None => {
anyhow::bail!("URL has no host");
}
}
Ok(())
}
fn is_private_ip(ip: &std::net::IpAddr) -> bool {
match ip {
std::net::IpAddr::V4(v4) => {
v4.is_loopback() || v4.is_private() || v4.is_link_local() || v4.is_broadcast()
|| v4.is_unspecified()
|| v4.octets()[0] == 0 }
std::net::IpAddr::V6(v6) => {
if let Some(mapped_v4) = v6.to_ipv4_mapped() {
return is_private_ip(&std::net::IpAddr::V4(mapped_v4));
}
v6.is_loopback() || v6.is_unspecified() || (v6.segments()[0] & 0xffc0) == 0xfe80
|| (v6.segments()[0] & 0xfe00) == 0xfc00
}
}
}
fn inspect_raster(data: &[u8], detected_format: MediaFormat) -> Result<ImageMetadata> {
let reader = image::ImageReader::new(std::io::Cursor::new(data))
.with_guessed_format()
.context("failed to guess image format")?;
let (width, height) = reader
.into_dimensions()
.context("failed to read image dimensions")?;
let limits = SafetyLimits::default();
let pixels = width as u64 * height as u64;
if pixels > limits.max_pixels {
anyhow::bail!(
"image too large: {}x{} ({:.1} megapixels) exceeds limit of {:.0} megapixels",
width,
height,
pixels as f64 / 1_000_000.0,
limits.max_pixels as f64 / 1_000_000.0
);
}
Ok(ImageMetadata::new(
detected_format,
width,
height,
data.len(),
Encoding::Raw,
))
}
fn inspect_svg(data: &[u8]) -> Result<ImageMetadata> {
let source = std::str::from_utf8(data).context("SVG is not valid UTF-8")?;
let (width, height) = parse_svg_dimensions(source);
let mut meta = ImageMetadata::new(MediaFormat::Svg, width, height, data.len(), Encoding::Raw);
meta.svg_source = Some(source.to_string());
Ok(meta)
}
fn parse_svg_dimensions(svg: &str) -> (u32, u32) {
let width = extract_svg_attr(svg, "width");
let height = extract_svg_attr(svg, "height");
if let (Some(w), Some(h)) = (width, height) {
if w > 0 && h > 0 {
return (w, h);
}
}
if let Some(vb) = extract_svg_viewbox(svg) {
return vb;
}
(300, 150)
}
fn extract_svg_attr(svg: &str, attr_name: &str) -> Option<u32> {
let svg_tag_start = svg.find("<svg")?;
let svg_tag_end = svg[svg_tag_start..].find('>')? + svg_tag_start;
let tag = &svg[svg_tag_start..=svg_tag_end];
let search_pattern = format!(" {}=", attr_name);
let attr_pos = tag.find(&search_pattern)?;
let after_eq = &tag[attr_pos + search_pattern.len()..];
let value = if let Some(stripped) = after_eq.strip_prefix('"') {
let end = stripped.find('"')?;
&stripped[..end]
} else if let Some(stripped) = after_eq.strip_prefix('\'') {
let end = stripped.find('\'')?;
&stripped[..end]
} else {
let end = after_eq
.find(|c: char| c.is_whitespace() || c == '>')
.unwrap_or(after_eq.len());
&after_eq[..end]
};
let lower = value.to_lowercase();
if lower.contains('%')
|| lower.contains("em")
|| lower.contains("rem")
|| lower.contains("vw")
|| lower.contains("vh")
|| lower.contains("vmin")
|| lower.contains("vmax")
{
return None;
}
let numeric: String = value
.chars()
.take_while(|c| c.is_ascii_digit() || *c == '.')
.collect();
numeric.parse::<f64>().ok().map(|v| v as u32)
}
fn extract_svg_viewbox(svg: &str) -> Option<(u32, u32)> {
let svg_tag_start = svg.find("<svg")?;
let svg_tag_end = svg[svg_tag_start..].find('>')? + svg_tag_start;
let tag = &svg[svg_tag_start..=svg_tag_end];
let vb_pos = tag.find("viewBox=")?;
let after_eq = &tag[vb_pos + 8..];
let value = if let Some(stripped) = after_eq.strip_prefix('"') {
let end = stripped.find('"')?;
&stripped[..end]
} else if let Some(stripped) = after_eq.strip_prefix('\'') {
let end = stripped.find('\'')?;
&stripped[..end]
} else {
return None;
};
let parts: Vec<f64> = value
.split_whitespace()
.flat_map(|s| s.split(','))
.filter(|s| !s.is_empty())
.filter_map(|s| s.parse::<f64>().ok())
.collect();
if parts.len() >= 4 && parts[2] > 0.0 && parts[3] > 0.0 {
Some((parts[2] as u32, parts[3] as u32))
} else {
None
}
}
pub fn fetch_url_safe(url: &str, limits: &SafetyLimits) -> Result<Vec<u8>> {
validate_url(url)?;
let agent = ureq::Agent::new_with_config(
ureq::config::Config::builder()
.redirect_auth_headers(ureq::config::RedirectAuthHeaders::Never)
.max_redirects(0)
.timeout_global(Some(std::time::Duration::from_secs(30)))
.build(),
);
let response = agent
.get(url)
.call()
.with_context(|| "failed to fetch image from URL".to_string())?;
let status = response.status().as_u16();
if (300..400).contains(&status) {
anyhow::bail!(
"image URL returned a redirect (HTTP {}); redirects are disabled for security",
status
);
}
if status != 200 {
anyhow::bail!("failed to fetch image: HTTP {}", status);
}
if let Some(cl) = response.headers().get("content-length") {
if let Ok(size) = cl.to_str().unwrap_or("").parse::<usize>() {
if size > limits.max_download_bytes {
anyhow::bail!(
"image URL Content-Length ({} bytes) exceeds limit of {} bytes",
size,
limits.max_download_bytes
);
}
}
}
use std::io::Read;
let max = limits.max_download_bytes;
let mut body = response.into_body();
let mut buf = Vec::new();
let bytes_read = body
.as_reader()
.take((max + 1) as u64)
.read_to_end(&mut buf)
.context("failed to read image response body")?;
if bytes_read > max {
anyhow::bail!(
"downloaded image too large: read at least {} bytes, exceeds limit of {} bytes",
bytes_read,
max
);
}
Ok(buf)
}
#[cfg(test)]
mod tests {
use super::*;
fn make_png(width: u32, height: u32) -> Vec<u8> {
let img = image::RgbaImage::new(width, height);
let mut buf = Vec::new();
let encoder = image::codecs::png::PngEncoder::new(&mut buf);
image::ImageEncoder::write_image(
encoder,
img.as_raw(),
width,
height,
image::ExtendedColorType::Rgba8,
)
.unwrap();
buf
}
fn make_jpeg(width: u32, height: u32) -> Vec<u8> {
let img = image::RgbImage::new(width, height);
let mut buf = Vec::new();
let mut encoder = image::codecs::jpeg::JpegEncoder::new_with_quality(&mut buf, 80);
encoder
.encode(img.as_raw(), width, height, image::ExtendedColorType::Rgb8)
.unwrap();
buf
}
#[test]
fn test_inspect_png() {
let data = make_png(640, 480);
let meta = inspect_bytes(&data).unwrap();
assert_eq!(meta.format, MediaFormat::Png);
assert_eq!(meta.width, 640);
assert_eq!(meta.height, 480);
assert_eq!(meta.max_dim(), 640);
}
#[test]
fn test_inspect_jpeg() {
let data = make_jpeg(1920, 1080);
let meta = inspect_bytes(&data).unwrap();
assert_eq!(meta.format, MediaFormat::Jpeg);
assert_eq!(meta.width, 1920);
assert_eq!(meta.height, 1080);
}
#[test]
fn test_inspect_svg_with_dimensions() {
let svg =
r#"<svg xmlns="http://www.w3.org/2000/svg" width="200" height="100"><rect/></svg>"#;
let meta = inspect_bytes(svg.as_bytes()).unwrap();
assert_eq!(meta.format, MediaFormat::Svg);
assert_eq!(meta.width, 200);
assert_eq!(meta.height, 100);
assert!(meta.svg_source.is_some());
}
#[test]
fn test_inspect_svg_with_viewbox() {
let svg = r#"<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 800 600"><rect/></svg>"#;
let meta = inspect_bytes(svg.as_bytes()).unwrap();
assert_eq!(meta.format, MediaFormat::Svg);
assert_eq!(meta.width, 800);
assert_eq!(meta.height, 600);
}
#[test]
fn test_inspect_svg_with_xml_declaration() {
let svg = r#"<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" width="500" height="300">
<circle cx="250" cy="150" r="100"/>
</svg>"#;
let meta = inspect_bytes(svg.as_bytes()).unwrap();
assert_eq!(meta.format, MediaFormat::Svg);
assert_eq!(meta.width, 500);
assert_eq!(meta.height, 300);
}
#[test]
fn test_inspect_svg_viewbox_comma_separated() {
let svg = r#"<svg xmlns="http://www.w3.org/2000/svg" viewBox="0,0,1024,768"><rect/></svg>"#;
let meta = inspect_bytes(svg.as_bytes()).unwrap();
assert_eq!(meta.width, 1024);
assert_eq!(meta.height, 768);
}
#[test]
fn test_inspect_svg_px_units() {
let svg =
r#"<svg xmlns="http://www.w3.org/2000/svg" width="200px" height="150px"><rect/></svg>"#;
let meta = inspect_bytes(svg.as_bytes()).unwrap();
assert_eq!(meta.width, 200);
assert_eq!(meta.height, 150);
}
#[test]
fn test_inspect_svg_percentage_falls_to_viewbox() {
let svg = r#"<svg xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 4000 3000"><rect/></svg>"#;
let meta = inspect_bytes(svg.as_bytes()).unwrap();
assert_eq!(meta.width, 4000);
assert_eq!(meta.height, 3000);
}
#[test]
fn test_inspect_svg_em_units_falls_to_viewbox() {
let svg = r#"<svg xmlns="http://www.w3.org/2000/svg" width="10em" height="8em" viewBox="0 0 500 400"><rect/></svg>"#;
let meta = inspect_bytes(svg.as_bytes()).unwrap();
assert_eq!(meta.width, 500);
assert_eq!(meta.height, 400);
}
#[test]
fn test_inspect_svg_stroke_width_not_confused() {
let svg = r#"<svg xmlns="http://www.w3.org/2000/svg" stroke-width="3" width="800" height="600"><rect/></svg>"#;
let meta = inspect_bytes(svg.as_bytes()).unwrap();
assert_eq!(meta.width, 800);
assert_eq!(meta.height, 600);
}
#[test]
fn test_inspect_base64_png() {
use base64::Engine;
let png_data = make_png(100, 50);
let encoded = base64::engine::general_purpose::STANDARD.encode(&png_data);
let data_uri = format!("data:image/png;base64,{}", encoded);
let meta = inspect_base64(&data_uri).unwrap();
assert_eq!(meta.format, MediaFormat::Png);
assert_eq!(meta.width, 100);
assert_eq!(meta.height, 50);
assert_eq!(meta.encoding, Encoding::Base64);
}
#[test]
fn test_inspect_base64_raw() {
use base64::Engine;
let png_data = make_png(64, 64);
let encoded = base64::engine::general_purpose::STANDARD.encode(&png_data);
let meta = inspect_base64(&encoded).unwrap();
assert_eq!(meta.format, MediaFormat::Png);
assert_eq!(meta.width, 64);
assert_eq!(meta.height, 64);
}
#[test]
fn test_inspect_not_an_image() {
let result = inspect_bytes(b"this is just text, not an image");
assert!(result.is_err());
}
#[test]
fn test_megapixels() {
let data = make_png(4000, 3000);
let meta = inspect_bytes(&data).unwrap();
assert!((meta.megapixels - 12.0).abs() < 0.001);
}
#[test]
fn test_validate_url_rejects_private_ip() {
assert!(validate_url("http://127.0.0.1/image.png").is_err());
assert!(validate_url("http://10.0.0.1/image.png").is_err());
assert!(validate_url("http://172.16.0.1/image.png").is_err());
assert!(validate_url("http://192.168.1.1/image.png").is_err());
assert!(validate_url("http://169.254.169.254/latest/meta-data/").is_err());
}
#[test]
fn test_validate_url_rejects_localhost() {
assert!(validate_url("http://localhost/image.png").is_err());
assert!(validate_url("http://localhost:8080/secret").is_err());
}
#[test]
fn test_validate_url_rejects_ipv6_loopback() {
assert!(validate_url("http://[::1]/image.png").is_err());
}
#[test]
fn test_validate_url_rejects_file_scheme() {
assert!(validate_url("file:///etc/passwd").is_err());
}
#[test]
fn test_validate_url_rejects_hex_ip() {
assert!(validate_url("http://0x7f000001/image.png").is_err());
}
#[test]
fn test_validate_url_allows_public() {
assert!(validate_url("https://example.com/image.png").is_ok());
assert!(validate_url("https://cdn.openai.com/image.png").is_ok());
}
#[test]
fn test_validate_url_rejects_ipv4_mapped_ipv6() {
assert!(validate_url("http://[::ffff:127.0.0.1]/image.png").is_err());
assert!(validate_url("http://[::ffff:10.0.0.1]/image.png").is_err());
assert!(validate_url("http://[::ffff:169.254.169.254]/image.png").is_err());
assert!(validate_url("http://[::ffff:192.168.1.1]/image.png").is_err());
}
#[test]
fn test_is_private_ip_ipv4_mapped_ipv6() {
use std::net::{IpAddr, Ipv6Addr};
let mapped_loopback: Ipv6Addr = "::ffff:127.0.0.1".parse().unwrap();
assert!(is_private_ip(&IpAddr::V6(mapped_loopback)));
let mapped_private: Ipv6Addr = "::ffff:10.0.0.1".parse().unwrap();
assert!(is_private_ip(&IpAddr::V6(mapped_private)));
let mapped_public: Ipv6Addr = "::ffff:8.8.8.8".parse().unwrap();
assert!(!is_private_ip(&IpAddr::V6(mapped_public)));
}
#[test]
fn test_validate_url_resolves_hostname_localhost() {
assert!(validate_url("http://localhost/image.png").is_err());
}
#[test]
fn test_normal_image_passes_pixel_budget() {
let data = make_png(4000, 3000); let meta = inspect_bytes(&data).unwrap();
assert_eq!(meta.width, 4000);
}
#[test]
fn test_pixel_budget_rejects_oversized() {
use crate::mode::SafetyLimits;
assert_eq!(SafetyLimits::default().max_pixels, 100_000_000);
let data = make_png(10000, 10000);
let meta = inspect_bytes(&data).unwrap();
assert_eq!(meta.width, 10000);
}
#[test]
fn test_svg_viewbox_negative_dims_fallback() {
let svg =
r#"<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 -100 -100"><rect/></svg>"#;
let meta = inspect_bytes(svg.as_bytes()).unwrap();
assert_eq!(meta.width, 300);
assert_eq!(meta.height, 150);
}
}