use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::net::ToSocketAddrs;
use std::time::Duration;
use thiserror::Error;
/// Download size cap in bytes (500 MiB) applied when the caller supplies no `max_bytes`.
pub const DEFAULT_MAX_BYTES: u64 = 500 * 1024 * 1024;
/// Download timeout in seconds applied when the caller supplies no `timeout`.
pub const DEFAULT_TIMEOUT_SECS: u64 = 120;
/// Largest decoded upload payload accepted, in bytes (1 GiB).
pub const MAX_UPLOAD_BYTES: u64 = 1024 * 1024 * 1024;
/// Errors produced by the download and upload helpers in this file.
///
/// The `#[error(...)]` string on each variant is its `Display` output.
/// `http_status` maps every variant to an HTTP status code.
#[derive(Error, Debug)]
pub enum FileManagerError {
/// A required argument was absent (or blank); carries the argument name.
#[error("Missing required argument: {0}")]
MissingArg(&'static str),
/// An argument was present but could not be interpreted.
#[error("Invalid argument '{name}': {reason}")]
InvalidArg { name: &'static str, reason: String },
/// The URL was rejected by the private/internal address check; carries the URL.
#[error("URL is not allowed (private/internal address): {0}")]
PrivateUrl(String),
/// The URL's host matched none of the configured download allowlist patterns.
#[error("Host '{host}' is not in the download allowlist")]
HostNotAllowed { host: String },
/// The URL could not be parsed or has no host component.
#[error("Invalid URL: {0}")]
InvalidUrl(String),
/// The HTTP client failed while building, sending, or streaming the request.
#[error("HTTP error fetching '{url}': {source}")]
Http {
url: String,
#[source]
source: reqwest::Error,
},
/// The remote server answered with a non-success status; `body` holds a snippet of its response.
#[error("Upstream returned status {status} for '{url}': {body}")]
Upstream {
url: String,
status: u16,
body: String,
},
/// A payload exceeded the applicable byte limit (download `max_bytes` or the upload cap).
#[error("Response exceeds max-bytes ({limit} bytes)")]
SizeCap { limit: u64 },
/// A caller-supplied extra request header was rejected.
#[error("Invalid extra header '{name}': {reason}")]
BadHeader { name: String, reason: String },
/// A file read failed; not constructed anywhere in this part of the file.
#[error("Failed to read file '{path}': {source}")]
Io {
path: String,
#[source]
source: std::io::Error,
},
/// No upload destination is available: the destination map is empty, or no default exists
/// and none was requested.
#[error("Upload destinations not configured on the proxy — operator must declare `[provider.upload_destinations.<name>]` in `manifests/file_manager.toml`")]
UploadNotConfigured,
/// The selected upload destination key is not present in the configured map.
#[error("Unknown upload destination '{0}' — not in the operator's allowlist")]
UnknownDestination(String),
/// An upload step failed; carries a human-readable description.
#[error("Upload failed: {0}")]
Upload(String),
/// The upload payload was not valid base64. `#[from]` lets `?` convert the decoder error.
#[error("Invalid base64 in upload payload: {0}")]
Base64(#[from] base64::DecodeError),
}
impl FileManagerError {
    /// Returns the HTTP status code that represents this error.
    ///
    /// * argument, header and base64 problems → 400
    /// * blocked URLs, disallowed hosts and unknown destinations → 403
    /// * size-cap violations → 413
    /// * local I/O failures → 500
    /// * transport failures, invalid URLs and upload failures → 502
    /// * missing upload configuration → 503
    /// * upstream failures → the upstream status, forced into `400..=599`
    pub fn http_status(&self) -> u16 {
        const BAD_REQUEST: u16 = 400;
        const FORBIDDEN: u16 = 403;
        const PAYLOAD_TOO_LARGE: u16 = 413;
        const INTERNAL_SERVER_ERROR: u16 = 500;
        const BAD_GATEWAY: u16 = 502;
        const SERVICE_UNAVAILABLE: u16 = 503;
        const HIGHEST_ERROR_STATUS: u16 = 599;

        match self {
            // Mirror the remote status, but never report a non-error code for an error.
            Self::Upstream { status, .. } => (*status).clamp(BAD_REQUEST, HIGHEST_ERROR_STATUS),

            Self::Base64(_)
            | Self::BadHeader { .. }
            | Self::InvalidArg { .. }
            | Self::MissingArg(_) => BAD_REQUEST,

            Self::UnknownDestination(_) | Self::HostNotAllowed { .. } | Self::PrivateUrl(_) => {
                FORBIDDEN
            }

            Self::SizeCap { .. } => PAYLOAD_TOO_LARGE,
            Self::Io { .. } => INTERNAL_SERVER_ERROR,
            Self::Upload(_) | Self::InvalidUrl(_) | Self::Http { .. } => BAD_GATEWAY,
            Self::UploadNotConfigured => SERVICE_UNAVAILABLE,
        }
    }
}
/// Lower-case names of request headers that callers may not set on downloads.
/// `validate_extra_headers` compares caller-supplied names against this list case-insensitively.
const DENIED_DOWNLOAD_HEADERS: &[&str] = &[
"host",
"content-length",
"transfer-encoding",
"connection",
"proxy-authorization",
];
/// Validates caller-supplied extra request headers before they are attached to a download.
///
/// A header is rejected with [`FileManagerError::BadHeader`] when:
/// * its name (case-insensitively) appears in `DENIED_DOWNLOAD_HEADERS`;
/// * its name is empty or contains a byte outside the HTTP `token` character set;
/// * its value contains a control character other than horizontal tab (this covers CR and LF).
///
/// Checking here means malformed headers are reported as a caller error instead of
/// surfacing later as an opaque HTTP client failure when the request is built.
fn validate_extra_headers(headers: &HashMap<String, String>) -> Result<(), FileManagerError> {
    /// True for bytes allowed in an HTTP field name (`tchar`).
    fn is_token_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b"!#$%&'*+-.^_`|~".contains(&b)
    }

    /// True for bytes that may not appear in a field value: C0 controls except tab, and DEL.
    fn is_forbidden_value_byte(b: u8) -> bool {
        (b < 0x20 && b != b'\t') || b == 0x7f
    }

    let reject = |name: &str, reason: &str| FileManagerError::BadHeader {
        name: name.to_string(),
        reason: reason.into(),
    };

    for (name, value) in headers {
        // The deny list is lower-case ASCII, so an ASCII case-insensitive compare is sufficient.
        if DENIED_DOWNLOAD_HEADERS
            .iter()
            .any(|denied| name.eq_ignore_ascii_case(denied))
        {
            return Err(reject(name, "header is not allowed"));
        }
        // `all` is vacuously true for an empty name, so emptiness is checked explicitly.
        if name.is_empty() || !name.bytes().all(is_token_byte) {
            return Err(reject(name, "header name contains invalid characters"));
        }
        if value.bytes().any(is_forbidden_value_byte) {
            return Err(reject(name, "header value contains control characters"));
        }
    }
    Ok(())
}
/// Validated parameters for a single download, as produced by `DownloadArgs::from_value`.
#[derive(Debug, Clone)]
pub struct DownloadArgs {
// URL to fetch (stored trimmed and non-empty by `from_value`).
pub url: String,
// Maximum number of body bytes to accept before failing with a size-cap error.
pub max_bytes: u64,
// Timeout handed to the HTTP client builder.
pub timeout: Duration,
// Whether the HTTP client may follow redirects.
pub follow_redirects: bool,
// Extra request headers to attach, keyed by header name.
pub headers: HashMap<String, String>,
}
impl DownloadArgs {
pub fn from_value(args: &HashMap<String, Value>) -> Result<Self, FileManagerError> {
let url = args
.get("url")
.and_then(|v| v.as_str())
.ok_or(FileManagerError::MissingArg("url"))?
.trim()
.to_string();
if url.is_empty() {
return Err(FileManagerError::MissingArg("url"));
}
let max_bytes = parse_u64_arg(args, &["max_bytes", "max-bytes"], "max_bytes")?
.unwrap_or(DEFAULT_MAX_BYTES);
if max_bytes == 0 {
return Err(FileManagerError::InvalidArg {
name: "max_bytes",
reason: "must be > 0".into(),
});
}
let timeout_secs =
parse_u64_arg(args, &["timeout"], "timeout")?.unwrap_or(DEFAULT_TIMEOUT_SECS);
let follow_redirects = args
.get("follow_redirects")
.or_else(|| args.get("follow-redirects"))
.and_then(|v| v.as_bool())
.unwrap_or(true);
let headers = parse_headers(args.get("headers"))?;
validate_extra_headers(&headers)?;
Ok(DownloadArgs {
url,
max_bytes,
timeout: Duration::from_secs(timeout_secs),
follow_redirects,
headers,
})
}
}
/// Looks up an optional unsigned-integer argument under any of `aliases`.
///
/// The first alias present in `args` decides the outcome:
/// * a JSON number that fits in `u64` → `Ok(Some(n))`;
/// * a string that parses as `u64` → `Ok(Some(n))`;
/// * anything else → [`FileManagerError::InvalidArg`] naming `field`.
///
/// Returns `Ok(None)` when none of the aliases is present.
fn parse_u64_arg(
    args: &HashMap<String, Value>,
    aliases: &[&str],
    field: &'static str,
) -> Result<Option<u64>, FileManagerError> {
    let invalid = |reason: String| FileManagerError::InvalidArg {
        name: field,
        reason,
    };

    for alias in aliases {
        let Some(raw) = args.get(*alias) else {
            continue;
        };
        // Only the first alias that is present is considered.
        return match raw {
            Value::Number(number) => match number.as_u64() {
                Some(n) => Ok(Some(n)),
                None => Err(invalid("must be a positive integer".into())),
            },
            Value::String(text) => match text.parse::<u64>() {
                Ok(n) => Ok(Some(n)),
                Err(e) => Err(invalid(e.to_string())),
            },
            _ => Err(invalid("must be a positive integer".into())),
        };
    }
    Ok(None)
}
/// Converts the optional `headers` argument into a name → value map.
///
/// Accepted shapes:
/// * absent, `null`, or a blank string → empty map;
/// * a JSON object;
/// * a string containing a JSON-encoded object.
///
/// Each header value must be a string, number, or boolean; numbers and booleans are
/// rendered with `to_string`. Any other shape yields [`FileManagerError::InvalidArg`]
/// for `headers`.
fn parse_headers(value: Option<&Value>) -> Result<HashMap<String, String>, FileManagerError> {
    let invalid = |reason: String| FileManagerError::InvalidArg {
        name: "headers",
        reason,
    };

    let object = match value {
        None | Some(Value::Null) => return Ok(HashMap::new()),
        Some(Value::Object(object)) => object.clone(),
        Some(Value::String(text)) => {
            if text.trim().is_empty() {
                return Ok(HashMap::new());
            }
            match serde_json::from_str::<Value>(text) {
                Ok(Value::Object(object)) => object,
                Ok(_) => return Err(invalid("must be a JSON object".into())),
                Err(e) => return Err(invalid(format!("invalid JSON: {e}"))),
            }
        }
        Some(_) => return Err(invalid("must be a JSON object or JSON string".into())),
    };

    // Collecting into `Result` stops at the first entry whose value has an unsupported type.
    object
        .into_iter()
        .map(|(name, raw)| {
            let rendered = match raw {
                Value::String(s) => s,
                Value::Number(n) => n.to_string(),
                Value::Bool(b) => b.to_string(),
                _ => {
                    return Err(invalid(format!(
                        "value for '{name}' must be a string, number, or bool"
                    )))
                }
            };
            Ok((name, rendered))
        })
        .collect()
}
/// Outcome of a successful download performed by `fetch_bytes`.
#[derive(Debug)]
pub struct DownloadResult {
// The response body.
pub bytes: Vec<u8>,
// Value of the response's Content-Type header, when present and readable as a string.
pub content_type: Option<String>,
// The URL that was requested (the caller-supplied URL, as set by `fetch_bytes`).
pub source_url: String,
}
/// Reads the download host allowlist from the `ATI_DOWNLOAD_ALLOWLIST` environment variable.
///
/// The value is a comma-separated list; each entry is trimmed and lower-cased, and blank
/// entries are dropped. Returns `None` when the variable is unset, unreadable, or contains
/// no non-blank entries.
fn allowlist_patterns() -> Option<Vec<String>> {
    let raw = std::env::var("ATI_DOWNLOAD_ALLOWLIST").ok()?;

    let mut patterns = Vec::new();
    for entry in raw.split(',') {
        let entry = entry.trim();
        if !entry.is_empty() {
            patterns.push(entry.to_lowercase());
        }
    }

    (!patterns.is_empty()).then_some(patterns)
}
/// Tests a host name against one allowlist pattern.
///
/// * `*` matches every host.
/// * `*.example.com` matches `example.com` itself and any host ending in `.example.com`.
/// * Anything else must equal the host exactly.
///
/// The host is lower-cased before comparison; the pattern is compared as given.
fn host_matches_pattern(host: &str, pattern: &str) -> bool {
    if pattern == "*" {
        return true;
    }

    let normalized = host.to_lowercase();
    match pattern.strip_prefix("*.") {
        // Either the apex itself, or something that ends in "." followed by the apex.
        Some(apex) => {
            normalized == apex
                || normalized
                    .strip_suffix(apex)
                    .is_some_and(|rest| rest.ends_with('.'))
        }
        None => normalized == pattern,
    }
}
/// Checks `url` against the download host allowlist, when one is configured.
///
/// With no allowlist configured every URL is accepted without being parsed. Otherwise the
/// URL's host must match at least one pattern.
///
/// # Errors
/// * [`FileManagerError::InvalidUrl`] when the URL cannot be parsed or has no host.
/// * [`FileManagerError::HostNotAllowed`] when no pattern matches the host.
pub fn enforce_download_allowlist(url: &str) -> Result<(), FileManagerError> {
    let Some(patterns) = allowlist_patterns() else {
        return Ok(());
    };

    let parsed = match reqwest::Url::parse(url) {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(FileManagerError::InvalidUrl(format!(
                "could not parse URL: {e}"
            )))
        }
    };
    let Some(host) = parsed.host_str() else {
        return Err(FileManagerError::InvalidUrl(
            "URL has no host component".into(),
        ));
    };

    for pattern in &patterns {
        if host_matches_pattern(host, pattern) {
            return Ok(());
        }
    }
    Err(FileManagerError::HostNotAllowed {
        host: host.to_string(),
    })
}
pub async fn fetch_bytes(args: &DownloadArgs) -> Result<DownloadResult, FileManagerError> {
crate::core::http::validate_url_not_private(&args.url).map_err(|e| match e {
crate::core::http::HttpError::SsrfBlocked(url) => FileManagerError::PrivateUrl(url),
other => FileManagerError::InvalidUrl(other.to_string()),
})?;
enforce_download_allowlist(&args.url)?;
let redirect_policy = if args.follow_redirects {
reqwest::redirect::Policy::limited(10)
} else {
reqwest::redirect::Policy::none()
};
let client = reqwest::Client::builder()
.timeout(args.timeout)
.redirect(redirect_policy)
.build()
.map_err(|e| FileManagerError::Http {
url: args.url.clone(),
source: e,
})?;
let mut req = client.get(&args.url);
for (k, v) in &args.headers {
req = req.header(k.as_str(), v.as_str());
}
let response = req.send().await.map_err(|e| FileManagerError::Http {
url: args.url.clone(),
source: e,
})?;
let status = response.status();
let content_type = response
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
if !status.is_success() {
let body = response.text().await.unwrap_or_default();
let truncated = if body.len() > 512 {
&body[..512]
} else {
&body
};
return Err(FileManagerError::Upstream {
url: args.url.clone(),
status: status.as_u16(),
body: truncated.to_string(),
});
}
let content_length = response
.headers()
.get(reqwest::header::CONTENT_LENGTH)
.and_then(|h| h.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
if let Some(len) = content_length {
if len > args.max_bytes {
return Err(FileManagerError::SizeCap {
limit: args.max_bytes,
});
}
}
use futures::StreamExt;
let initial_cap = content_length
.map(|l| l.min(args.max_bytes) as usize)
.unwrap_or(64 * 1024);
let mut bytes = Vec::with_capacity(initial_cap);
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(|e| FileManagerError::Http {
url: args.url.clone(),
source: e,
})?;
if (bytes.len() as u64).saturating_add(chunk.len() as u64) > args.max_bytes {
return Err(FileManagerError::SizeCap {
limit: args.max_bytes,
});
}
bytes.extend_from_slice(&chunk);
}
Ok(DownloadResult {
bytes,
content_type,
source_url: args.url.clone(),
})
}
pub fn build_download_response(result: &DownloadResult) -> Value {
json!({
"success": true,
"size_bytes": result.bytes.len(),
"content_type": result.content_type,
"source_url": result.source_url,
"content_base64": B64.encode(&result.bytes),
})
}
/// Guesses a MIME type from the extension of a file path.
///
/// Only the final path component is examined (both `/` and `\` count as separators), and
/// the extension is the text after its last `.`, compared case-insensitively. A name with
/// no `.` has no extension, so a file literally called `png` is not treated as an image.
/// Unknown or missing extensions yield `application/octet-stream`.
pub fn guess_content_type(path: &str) -> &'static str {
    const FALLBACK: &str = "application/octet-stream";

    let file_name = path.rsplit(['/', '\\']).next().unwrap_or(path);
    let Some((_, ext)) = file_name.rsplit_once('.') else {
        return FALLBACK;
    };

    match ext.to_ascii_lowercase().as_str() {
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "webp" => "image/webp",
        "svg" => "image/svg+xml",
        "pdf" => "application/pdf",
        "mp4" | "m4v" => "video/mp4",
        "mov" => "video/quicktime",
        "webm" => "video/webm",
        "mp3" => "audio/mpeg",
        "wav" => "audio/wav",
        "ogg" | "oga" => "audio/ogg",
        "flac" => "audio/flac",
        "m4a" => "audio/mp4",
        "csv" => "text/csv",
        "json" => "application/json",
        "xml" => "application/xml",
        "zip" => "application/zip",
        "html" | "htm" => "text/html",
        "md" => "text/markdown",
        "txt" | "log" => "text/plain",
        _ => FALLBACK,
    }
}
/// Validated parameters for a single upload, as produced by `UploadArgs::from_wire`.
#[derive(Debug)]
pub struct UploadArgs {
// File name to store under; `from_wire` passes it through `sanitize_filename`.
pub filename: String,
// Caller-declared MIME type, if any.
pub content_type: Option<String>,
// Decoded payload bytes.
pub bytes: Vec<u8>,
// Requested destination key, if the caller named one.
pub destination: Option<String>,
}
impl UploadArgs {
    /// Builds [`UploadArgs`] from a loosely-typed argument map.
    ///
    /// Recognised keys:
    /// * `filename` (required) — trimmed, must be non-empty; stored after sanitising.
    /// * `content_type` / `content-type` — optional MIME type string.
    /// * `content_base64` / `content-base64` (required) — standard base64 payload;
    ///   surrounding whitespace is ignored.
    /// * `destination` — optional destination key; blank is treated as absent.
    ///
    /// # Errors
    /// * [`FileManagerError::MissingArg`] when `filename` or the payload is absent.
    /// * [`FileManagerError::SizeCap`] when the payload exceeds [`MAX_UPLOAD_BYTES`].
    /// * [`FileManagerError::Base64`] when the payload is not valid base64.
    pub fn from_wire(args: &HashMap<String, Value>) -> Result<Self, FileManagerError> {
        let filename = args
            .get("filename")
            .and_then(|v| v.as_str())
            .map(|s| s.trim().to_string())
            .filter(|s| !s.is_empty())
            .ok_or(FileManagerError::MissingArg("filename"))?;

        let content_type = args
            .get("content_type")
            .or_else(|| args.get("content-type"))
            .and_then(|v| v.as_str())
            .map(|s| s.to_string());

        let b64 = args
            .get("content_base64")
            .or_else(|| args.get("content-base64"))
            .and_then(|v| v.as_str())
            .ok_or(FileManagerError::MissingArg("content_base64"))?
            .trim();

        // Padded base64 emits four characters per three input bytes, so anything longer
        // than this cannot decode to an acceptable payload. Checking first avoids
        // allocating and decoding an oversized body only to throw it away.
        let max_encoded_len = MAX_UPLOAD_BYTES.div_ceil(3) * 4;
        if b64.len() as u64 > max_encoded_len {
            return Err(FileManagerError::SizeCap {
                limit: MAX_UPLOAD_BYTES,
            });
        }

        let bytes = B64.decode(b64.as_bytes())?;
        if (bytes.len() as u64) > MAX_UPLOAD_BYTES {
            return Err(FileManagerError::SizeCap {
                limit: MAX_UPLOAD_BYTES,
            });
        }

        let destination = args
            .get("destination")
            .and_then(|v| v.as_str())
            .map(|s| s.trim().to_string())
            .filter(|s| !s.is_empty());

        Ok(UploadArgs {
            filename: sanitize_filename(&filename),
            content_type,
            bytes,
            destination,
        })
    }
}
/// Reduces a caller-supplied file name to a single safe path component.
///
/// * Surrounding whitespace and path separators are stripped.
/// * Only the text after the last separator is kept. Both `/` and `\` count as
///   separators, so Windows-style paths such as `..\..\name` cannot smuggle directory
///   components through.
/// * Control characters are removed and the result is trimmed again.
/// * If nothing usable remains (empty, `.` or `..`), a generated name of the form
///   `upload-<milliseconds since the Unix epoch>` is returned instead.
fn sanitize_filename(input: &str) -> String {
    fn is_separator(c: char) -> bool {
        c == '/' || c == '\\'
    }

    let trimmed = input.trim_matches(|c: char| is_separator(c) || c.is_whitespace());
    let last = trimmed.rsplit(is_separator).next().unwrap_or(trimmed);
    let without_controls: String = last.chars().filter(|c| !c.is_control()).collect();
    // The final component may carry its own padding (e.g. "dir/ .. ").
    let cleaned = without_controls.trim();

    if cleaned.is_empty() || cleaned == "." || cleaned == ".." {
        // A clock set before the epoch is not worth failing an upload over; fall back to 0.
        let millis = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|elapsed| elapsed.as_millis())
            .unwrap_or(0);
        format!("upload-{millis}")
    } else {
        cleaned.to_string()
    }
}
/// Outcome of a successful upload; rendered to JSON by `build_upload_response`.
#[derive(Debug)]
pub struct UploadResult {
// URL reported by the storage backend for the uploaded object.
pub url: String,
// Number of payload bytes that were uploaded.
pub size_bytes: u64,
// MIME type the payload was uploaded with.
pub content_type: String,
// Key of the destination that handled the upload.
pub destination: String,
}
/// An operator-configured upload sink.
///
/// Deserialized with an internal `kind` tag whose values are the snake_case variant names
/// (`gcs`, `fal_storage`).
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum UploadDestination {
// Upload into a GCS bucket.
Gcs {
// Target bucket name.
bucket: String,
// Object-name prefix; defaults to the value of `default_gcs_prefix`.
#[serde(default = "default_gcs_prefix")]
prefix: String,
// Keyring entry holding the credentials; defaults to the value of `default_gcs_key_ref`.
#[serde(default = "default_gcs_key_ref")]
key_ref: String,
},
// Upload through the fal storage API.
FalStorage {
// Keyring entry holding the API key; defaults to the value of `default_fal_key_ref`.
#[serde(default = "default_fal_key_ref")]
key_ref: String,
// Optional override for the REST base URL; `None` when omitted.
#[serde(default)]
endpoint: Option<String>,
},
}
/// Serde default for the GCS object-name prefix.
fn default_gcs_prefix() -> String {
    String::from("ati-uploads")
}
/// Serde default for the keyring entry that holds the GCS credentials.
fn default_gcs_key_ref() -> String {
    String::from("gcp_credentials")
}
/// Serde default for the keyring entry that holds the fal API key.
fn default_fal_key_ref() -> String {
    String::from("fal_api_key")
}
/// Picks the upload destination to use.
///
/// A non-empty `requested` key wins; otherwise `default` is used. The chosen key must
/// exist in `destinations`.
///
/// # Errors
/// * [`FileManagerError::UploadNotConfigured`] when `destinations` is empty, or when no
///   key was requested and there is no default.
/// * [`FileManagerError::UnknownDestination`] when the chosen key is not in the map.
pub fn resolve_destination<'a>(
    destinations: &'a HashMap<String, UploadDestination>,
    default: Option<&str>,
    requested: Option<&str>,
) -> Result<(String, &'a UploadDestination), FileManagerError> {
    if destinations.is_empty() {
        return Err(FileManagerError::UploadNotConfigured);
    }

    // An empty requested key counts as "not requested".
    let chosen = requested.filter(|name| !name.is_empty()).or(default);
    let Some(name) = chosen else {
        return Err(FileManagerError::UploadNotConfigured);
    };

    match destinations.get(name) {
        Some(sink) => Ok((name.to_string(), sink)),
        None => Err(FileManagerError::UnknownDestination(name.to_string())),
    }
}
pub fn build_upload_response(result: &UploadResult) -> Value {
json!({
"success": true,
"url": result.url,
"size_bytes": result.size_bytes,
"content_type": result.content_type,
"destination": result.destination,
})
}
/// Uploads `args` to the resolved destination and returns the JSON response payload.
///
/// The destination is chosen by `resolve_destination` from the caller's request and the
/// configured default, then the upload is dispatched to the matching backend.
///
/// # Errors
/// Propagates destination-resolution errors and any error from the backend upload.
pub async fn upload_to_destination(
    args: UploadArgs,
    destinations: &HashMap<String, UploadDestination>,
    default: Option<&str>,
    keyring: &crate::core::keyring::Keyring,
) -> Result<Value, FileManagerError> {
    let requested = args.destination.as_deref();
    let (key, sink) = resolve_destination(destinations, default, requested)?;

    let outcome = match sink {
        UploadDestination::FalStorage { key_ref, endpoint } => {
            upload_to_fal(args, key_ref, endpoint.as_deref(), keyring, &key).await
        }
        UploadDestination::Gcs {
            bucket,
            prefix,
            key_ref,
        } => upload_to_gcs(args, bucket, prefix, key_ref, keyring, &key).await,
    }?;

    Ok(build_upload_response(&outcome))
}
/// Uploads the payload to a GCS bucket and reports where it landed.
///
/// The object name is `<prefix>/<YYYY-MM-DD>/<uuid>-<filename>`, using the current UTC date
/// and a fresh v4 UUID. Leading and trailing `/` on the configured prefix are ignored, and
/// an empty prefix is omitted, so the object name never contains empty path segments.
/// The content type defaults to `application/octet-stream` when the caller gave none.
///
/// # Errors
/// [`FileManagerError::Upload`] when the keyring entry `key_ref` is missing, the GCS
/// client cannot be created, or the upload itself fails.
async fn upload_to_gcs(
    args: UploadArgs,
    bucket: &str,
    prefix: &str,
    key_ref: &str,
    keyring: &crate::core::keyring::Keyring,
    destination_key: &str,
) -> Result<UploadResult, FileManagerError> {
    let service_account_json = keyring
        .get(key_ref)
        .ok_or_else(|| {
            FileManagerError::Upload(format!("keyring key '{key_ref}' missing for GCS upload"))
        })?
        .to_string();

    let content_type = args
        .content_type
        .unwrap_or_else(|| "application/octet-stream".to_string());
    let size_bytes = args.bytes.len() as u64;

    let date = chrono::Utc::now().format("%Y-%m-%d");
    let uuid = uuid::Uuid::new_v4();
    // A prefix such as "uploads/" or "" would otherwise yield "uploads//…" or "/…".
    let prefix = prefix.trim_matches('/');
    let object_name = if prefix.is_empty() {
        format!("{date}/{uuid}-{}", args.filename)
    } else {
        format!("{prefix}/{date}/{uuid}-{}", args.filename)
    };

    let client =
        crate::core::gcs::GcsClient::new_read_write(bucket.to_string(), &service_account_json)
            .map_err(|e| FileManagerError::Upload(e.to_string()))?;
    let url = client
        .upload_object(&object_name, args.bytes, &content_type)
        .await
        .map_err(|e| FileManagerError::Upload(e.to_string()))?;

    Ok(UploadResult {
        url,
        size_bytes,
        content_type,
        destination: destination_key.to_string(),
    })
}
/// Verifies that a server-provided URL is HTTPS and does not point at a private target.
///
/// Checks, in order:
/// 1. The URL parses, uses the `https` scheme, and has a host.
/// 2. The host name is not `localhost`, `metadata.google.internal`, or anything under
///    `.localhost`, `.internal`, or `.local`. A trailing root dot (`localhost.`) is
///    ignored for this comparison so the fully-qualified spelling cannot slip past.
/// 3. The host, if an IP literal, is not a private address; otherwise none of the
///    addresses it resolves to is private.
///
/// A host name that fails to resolve is not rejected here.
///
/// NOTE(review): `to_socket_addrs` performs a blocking DNS lookup, and this function is
/// called from async code — consider moving the lookup off the executor thread.
///
/// # Errors
/// [`FileManagerError::Upload`] describing which check failed.
fn require_public_https_url(url: &str) -> Result<(), FileManagerError> {
    let parsed = reqwest::Url::parse(url)
        .map_err(|e| FileManagerError::Upload(format!("server returned malformed URL: {e}")))?;
    if parsed.scheme() != "https" {
        return Err(FileManagerError::Upload(format!(
            "refusing non-HTTPS URL from server: {url}"
        )));
    }
    let host = parsed
        .host_str()
        .ok_or_else(|| FileManagerError::Upload(format!("server URL has no host: {url}")))?;

    // "name." and "name" are the same DNS name; compare without the root dot.
    let host_lower = host.trim_end_matches('.').to_lowercase();
    if host_lower == "localhost"
        || host_lower == "metadata.google.internal"
        || host_lower.ends_with(".localhost")
        || host_lower.ends_with(".internal")
        || host_lower.ends_with(".local")
    {
        return Err(FileManagerError::Upload(format!(
            "server URL targets a private hostname: {url}"
        )));
    }

    let port = parsed.port_or_known_default().unwrap_or(443);
    // IPv6 literals are reported in bracketed form; strip the brackets before parsing.
    let ip_host = host.trim_matches(['[', ']']);
    let is_private = if let Ok(ip) = ip_host.parse::<std::net::IpAddr>() {
        is_private_ip_addr(ip)
    } else if let Ok(addrs) = (ip_host, port).to_socket_addrs() {
        addrs.into_iter().any(|addr| is_private_ip_addr(addr.ip()))
    } else {
        false
    };
    if is_private {
        return Err(FileManagerError::Upload(format!(
            "server URL resolves to a private address: {url}"
        )));
    }
    Ok(())
}
/// Reports whether an IP address should be treated as private/internal.
///
/// IPv4 addresses, and IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`), are judged by
/// `is_private_ipv4`. Any other IPv6 address is private when it is loopback, unspecified,
/// unique-local, or unicast link-local.
fn is_private_ip_addr(ip: std::net::IpAddr) -> bool {
    use std::net::IpAddr;

    let v6 = match ip {
        IpAddr::V4(v4) => return is_private_ipv4(v4),
        IpAddr::V6(v6) => v6,
    };

    match v6.to_ipv4_mapped() {
        // Judge a mapped address by the IPv4 address it stands for.
        Some(embedded) => is_private_ipv4(embedded),
        None => {
            v6.is_loopback()
                || v6.is_unspecified()
                || v6.is_unique_local()
                || v6.is_unicast_link_local()
        }
    }
}
/// Reports whether an IPv4 address is outside the publicly routable unicast space.
///
/// Covers:
/// * loopback (127.0.0.0/8), RFC 1918 private ranges, link-local (169.254.0.0/16);
/// * the whole "this network" block 0.0.0.0/8, not only 0.0.0.0 itself;
/// * carrier-grade NAT shared space 100.64.0.0/10;
/// * benchmarking 198.18.0.0/15 and the documentation ranges;
/// * multicast 224.0.0.0/4 and reserved 240.0.0.0/4 (which includes the broadcast address).
fn is_private_ipv4(ip: std::net::Ipv4Addr) -> bool {
    let [first, second, _, _] = ip.octets();

    let this_network = first == 0;
    let shared_cgnat = first == 100 && (64..=127).contains(&second);
    let benchmarking = first == 198 && (second == 18 || second == 19);
    let reserved_or_broadcast = first >= 240;

    ip.is_loopback()
        || ip.is_private()
        || ip.is_link_local()
        || ip.is_multicast()
        || ip.is_documentation()
        || this_network
        || shared_cgnat
        || benchmarking
        || reserved_or_broadcast
}
/// Uploads the payload through the fal storage API and reports where it landed.
///
/// Two requests are made with a 60-second client timeout:
/// 1. `POST <rest_base>/storage/auth/token?storage_type=fal-cdn-v3`, authenticated with the
///    API key from the keyring, to mint an upload token. `rest_base` is `endpoint` when
///    given, otherwise `https://rest.alpha.fal.ai`; trailing `/` is ignored so the path
///    is never joined with a double slash.
/// 2. `POST <token.base_url>/files/upload` with the payload, authenticated with the minted
///    token. The server-supplied upload URL must pass `require_public_https_url` first.
///
/// The content type defaults to `application/octet-stream` when the caller gave none.
///
/// # Errors
/// [`FileManagerError::Upload`] when the keyring entry is missing, the client cannot be
/// built, a request fails, a response is non-success or unparsable, or the upload URL is
/// rejected. Upstream error bodies are included only as a bounded snippet.
async fn upload_to_fal(
    args: UploadArgs,
    key_ref: &str,
    endpoint: Option<&str>,
    keyring: &crate::core::keyring::Keyring,
    destination_key: &str,
) -> Result<UploadResult, FileManagerError> {
    use serde::Deserialize;

    /// Fields read from the token-mint response.
    #[derive(Deserialize)]
    struct FalToken {
        token: String,
        token_type: String,
        base_url: String,
    }

    /// Fields read from the upload response.
    #[derive(Deserialize)]
    struct FalUploadResponse {
        access_url: String,
    }

    /// Returns at most the first 512 bytes of `body`, cut on a character boundary so the
    /// slice can never panic on multi-byte text.
    fn snippet(body: &str) -> &str {
        const LIMIT: usize = 512;
        if body.len() <= LIMIT {
            return body;
        }
        let mut end = LIMIT;
        while !body.is_char_boundary(end) {
            end -= 1;
        }
        &body[..end]
    }

    let api_key = keyring
        .get(key_ref)
        .ok_or_else(|| {
            FileManagerError::Upload(format!("keyring key '{key_ref}' missing for fal upload"))
        })?
        .to_string();

    let rest_base = endpoint
        .unwrap_or("https://rest.alpha.fal.ai")
        .trim_end_matches('/');
    let http = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(60))
        .build()
        .map_err(|e| FileManagerError::Upload(format!("http client init: {e}")))?;

    // Step 1: mint a short-lived upload token.
    let token_url = format!("{rest_base}/storage/auth/token?storage_type=fal-cdn-v3");
    let token_resp = http
        .post(&token_url)
        .header("Authorization", format!("Key {api_key}"))
        .header("Accept", "application/json")
        .header("Content-Type", "application/json")
        .body("{}")
        .send()
        .await
        .map_err(|e| FileManagerError::Upload(format!("fal token request failed: {e}")))?;
    if !token_resp.status().is_success() {
        let status = token_resp.status().as_u16();
        let body = token_resp.text().await.unwrap_or_default();
        return Err(FileManagerError::Upload(format!(
            "fal token mint returned {status}: {}",
            snippet(&body)
        )));
    }
    let token: FalToken = token_resp
        .json()
        .await
        .map_err(|e| FileManagerError::Upload(format!("fal token JSON parse failed: {e}")))?;

    let content_type = args
        .content_type
        .unwrap_or_else(|| "application/octet-stream".to_string());
    let size_bytes = args.bytes.len() as u64;

    // Step 2: send the payload to the server-chosen upload host, after vetting that host.
    let upload_url = format!("{}/files/upload", token.base_url.trim_end_matches('/'));
    require_public_https_url(&upload_url)?;
    let upload_resp = http
        .post(&upload_url)
        .header(
            "Authorization",
            format!("{} {}", token.token_type, token.token),
        )
        .header("Content-Type", &content_type)
        .header("X-Fal-File-Name", &args.filename)
        .body(args.bytes)
        .send()
        .await
        .map_err(|e| FileManagerError::Upload(format!("fal upload request failed: {e}")))?;
    if !upload_resp.status().is_success() {
        let status = upload_resp.status().as_u16();
        let body = upload_resp.text().await.unwrap_or_default();
        return Err(FileManagerError::Upload(format!(
            "fal upload returned {status}: {}",
            snippet(&body)
        )));
    }
    let body: FalUploadResponse = upload_resp
        .json()
        .await
        .map_err(|e| FileManagerError::Upload(format!("fal upload JSON parse failed: {e}")))?;

    Ok(UploadResult {
        url: body.access_url,
        size_bytes,
        content_type,
        destination: destination_key.to_string(),
    })
}
// Unit tests for argument parsing, header validation, allowlist matching, destination
// resolution, and the public-HTTPS URL guard. None of them performs an HTTP request.
#[cfg(test)]
mod tests {
use super::*;
// --- parse_headers -------------------------------------------------------------------
// A JSON object is accepted directly.
#[test]
fn parse_headers_object() {
let v = serde_json::json!({"X-Test": "1", "X-Other": "abc"});
let map = parse_headers(Some(&v)).unwrap();
assert_eq!(map.len(), 2);
assert_eq!(map.get("X-Test").map(String::as_str), Some("1"));
}
// A string holding a JSON-encoded object is parsed.
#[test]
fn parse_headers_string_json() {
let v = Value::String(r#"{"Authorization":"Bearer abc"}"#.into());
let map = parse_headers(Some(&v)).unwrap();
assert_eq!(
map.get("Authorization").map(String::as_str),
Some("Bearer abc")
);
}
// An empty string means "no headers".
#[test]
fn parse_headers_empty_string() {
let v = Value::String("".into());
assert!(parse_headers(Some(&v)).unwrap().is_empty());
}
// Any other JSON type is rejected.
#[test]
fn parse_headers_invalid_type() {
let v = Value::Number(42.into());
assert!(parse_headers(Some(&v)).is_err());
}
// --- validate_extra_headers ----------------------------------------------------------
// Denied header names are matched case-insensitively.
#[test]
fn validate_denied_header() {
let mut map = HashMap::new();
map.insert("Host".to_string(), "evil.com".to_string());
assert!(validate_extra_headers(&map).is_err());
}
// --- DownloadArgs::from_value --------------------------------------------------------
// Only `url` given: every other field takes its default.
#[test]
fn download_args_defaults() {
let mut args = HashMap::new();
args.insert(
"url".to_string(),
Value::String("https://example.com".into()),
);
let parsed = DownloadArgs::from_value(&args).unwrap();
assert_eq!(parsed.max_bytes, DEFAULT_MAX_BYTES);
assert_eq!(parsed.timeout, Duration::from_secs(DEFAULT_TIMEOUT_SECS));
assert!(parsed.follow_redirects);
assert!(parsed.headers.is_empty());
}
// `url` is mandatory.
#[test]
fn download_args_missing_url() {
let args = HashMap::new();
assert!(DownloadArgs::from_value(&args).is_err());
}
// A zero size cap is rejected.
#[test]
fn download_args_zero_max_bytes_rejected() {
let mut args = HashMap::new();
args.insert(
"url".to_string(),
Value::String("https://example.com".into()),
);
args.insert("max_bytes".to_string(), Value::Number(0.into()));
assert!(DownloadArgs::from_value(&args).is_err());
}
// Numeric arguments may arrive as strings.
#[test]
fn download_args_max_bytes_string() {
let mut args = HashMap::new();
args.insert(
"url".to_string(),
Value::String("https://example.com".into()),
);
args.insert("max_bytes".to_string(), Value::String("1024".into()));
let parsed = DownloadArgs::from_value(&args).unwrap();
assert_eq!(parsed.max_bytes, 1024);
}
// --- UploadArgs::from_wire -----------------------------------------------------------
// Payload, file name and content type survive the base64 round trip.
#[test]
fn upload_args_round_trip() {
let bytes = b"hello world".to_vec();
let mut args = HashMap::new();
args.insert("filename".to_string(), Value::String("hello.txt".into()));
args.insert(
"content_type".to_string(),
Value::String("text/plain".into()),
);
args.insert(
"content_base64".to_string(),
Value::String(B64.encode(&bytes)),
);
let parsed = UploadArgs::from_wire(&args).unwrap();
assert_eq!(parsed.bytes, bytes);
assert_eq!(parsed.filename, "hello.txt");
assert_eq!(parsed.content_type.as_deref(), Some("text/plain"));
}
// Directory components in the file name are dropped.
#[test]
fn upload_args_path_traversal_stripped() {
let mut args = HashMap::new();
args.insert(
"filename".to_string(),
Value::String("../../etc/passwd".into()),
);
args.insert(
"content_base64".to_string(),
Value::String(B64.encode(b"x")),
);
let parsed = UploadArgs::from_wire(&args).unwrap();
assert_eq!(parsed.filename, "passwd");
}
// `filename` is mandatory.
#[test]
fn upload_args_missing_filename() {
let mut args = HashMap::new();
args.insert(
"content_base64".to_string(),
Value::String(B64.encode(b"x")),
);
assert!(UploadArgs::from_wire(&args).is_err());
}
// A payload that is not base64 is rejected.
#[test]
fn upload_args_invalid_base64() {
let mut args = HashMap::new();
args.insert("filename".to_string(), Value::String("a".into()));
args.insert(
"content_base64".to_string(),
Value::String("!!! not base64 !!!".into()),
);
assert!(UploadArgs::from_wire(&args).is_err());
}
// --- build_download_response ---------------------------------------------------------
#[test]
fn build_download_response_includes_base64() {
let bytes = b"hello".to_vec();
let result = DownloadResult {
bytes,
content_type: Some("text/plain".into()),
source_url: "https://example.com/h".into(),
};
let v = build_download_response(&result);
assert_eq!(v["size_bytes"], 5);
assert_eq!(v["content_type"], "text/plain");
assert!(v["content_base64"].as_str().is_some());
}
// --- host_matches_pattern ------------------------------------------------------------
// Exact patterns match only the same host, ignoring the host's case.
#[test]
fn host_pattern_exact_match() {
assert!(host_matches_pattern("v3b.fal.media", "v3b.fal.media"));
assert!(!host_matches_pattern("evil.com", "v3b.fal.media"));
assert!(host_matches_pattern("V3B.FAL.MEDIA", "v3b.fal.media"));
}
// `*.domain` matches the domain and its subdomains, but not look-alike suffixes.
#[test]
fn host_pattern_subdomain_wildcard() {
assert!(host_matches_pattern("v3b.fal.media", "*.fal.media"));
assert!(host_matches_pattern("cdn.fal.media", "*.fal.media"));
assert!(host_matches_pattern("fal.media", "*.fal.media"));
assert!(!host_matches_pattern("evil.com", "*.fal.media"));
assert!(!host_matches_pattern("evilfal.media", "*.fal.media"));
}
// A bare `*` matches every host.
#[test]
fn host_pattern_bare_wildcard_matches_anything() {
assert!(host_matches_pattern("anywhere.com", "*"));
}
// --- resolve_destination -------------------------------------------------------------
// Fixture: one GCS destination ("gcs") and one fal destination ("fal").
fn make_destinations() -> HashMap<String, UploadDestination> {
let mut m = HashMap::new();
m.insert(
"gcs".to_string(),
UploadDestination::Gcs {
bucket: "b".to_string(),
prefix: "p".to_string(),
key_ref: "gcp_credentials".to_string(),
},
);
m.insert(
"fal".to_string(),
UploadDestination::FalStorage {
key_ref: "fal_api_key".to_string(),
endpoint: None,
},
);
m
}
// An explicitly requested key overrides the default.
#[test]
fn resolve_destination_picks_explicit_key() {
let m = make_destinations();
let (k, sink) = resolve_destination(&m, Some("gcs"), Some("fal")).unwrap();
assert_eq!(k, "fal");
assert!(matches!(sink, UploadDestination::FalStorage { .. }));
}
// With no request, the default is used.
#[test]
fn resolve_destination_falls_back_to_default() {
let m = make_destinations();
let (k, _) = resolve_destination(&m, Some("gcs"), None).unwrap();
assert_eq!(k, "gcs");
}
// A requested key that is not configured is rejected.
#[test]
fn resolve_destination_unknown_key_rejected() {
let m = make_destinations();
let err = resolve_destination(&m, Some("gcs"), Some("evil")).unwrap_err();
assert!(matches!(err, FileManagerError::UnknownDestination(ref s) if s == "evil"));
}
// An empty destination map means uploads are not configured.
#[test]
fn resolve_destination_empty_map_not_configured() {
let m: HashMap<String, UploadDestination> = HashMap::new();
let err = resolve_destination(&m, None, None).unwrap_err();
assert!(matches!(err, FileManagerError::UploadNotConfigured));
}
// Neither a default nor a request: also "not configured".
#[test]
fn resolve_destination_no_default_no_request_not_configured() {
let m = make_destinations();
let err = resolve_destination(&m, None, None).unwrap_err();
assert!(matches!(err, FileManagerError::UploadNotConfigured));
}
// --- require_public_https_url --------------------------------------------------------
// NOTE(review): this case reaches the DNS lookup branch for a real host name, so its
// outcome may depend on the resolver available where the tests run — confirm.
#[test]
fn require_public_https_accepts_public_https() {
assert!(require_public_https_url("https://v3b.fal.media/files/upload").is_ok());
}
// Plain HTTP is refused before any host check.
#[test]
fn require_public_https_rejects_http_scheme() {
let err = require_public_https_url("http://v3b.fal.media/files/upload").unwrap_err();
assert!(
matches!(&err, FileManagerError::Upload(m) if m.contains("non-HTTPS")),
"unexpected error: {err:?}"
);
}
// `localhost` is refused by name.
#[test]
fn require_public_https_rejects_loopback_hostname() {
let err = require_public_https_url("https://localhost/files/upload").unwrap_err();
assert!(matches!(&err, FileManagerError::Upload(m) if m.contains("private")));
}
// The link-local metadata address is refused.
#[test]
fn require_public_https_rejects_metadata_ip() {
let err = require_public_https_url("https://169.254.169.254/").unwrap_err();
assert!(matches!(&err, FileManagerError::Upload(m) if m.contains("private")));
}
// RFC 1918 literals are refused.
#[test]
fn require_public_https_rejects_rfc1918() {
assert!(require_public_https_url("https://10.0.0.1/x").is_err());
assert!(require_public_https_url("https://192.168.1.1/x").is_err());
assert!(require_public_https_url("https://172.16.0.1/x").is_err());
}
// Link-local IPv6 literals are refused.
#[test]
fn require_public_https_rejects_link_local_ipv6() {
assert!(require_public_https_url("https://[fe80::1]/x").is_err());
}
// IPv4-mapped IPv6 literals are judged by the IPv4 address they embed.
#[test]
fn require_public_https_rejects_ipv4_mapped_metadata_address() {
assert!(require_public_https_url("https://[::ffff:169.254.169.254]/").is_err());
}
#[test]
fn require_public_https_rejects_ipv4_mapped_loopback() {
assert!(require_public_https_url("https://[::ffff:127.0.0.1]/x").is_err());
}
#[test]
fn require_public_https_rejects_ipv4_mapped_rfc1918() {
assert!(require_public_https_url("https://[::ffff:10.0.0.1]/x").is_err());
assert!(require_public_https_url("https://[::ffff:192.168.1.1]/x").is_err());
assert!(require_public_https_url("https://[::ffff:172.16.0.1]/x").is_err());
}
#[test]
fn require_public_https_rejects_ipv4_mapped_cgnat() {
assert!(require_public_https_url("https://[::ffff:100.64.0.1]/x").is_err());
}
// Hosts under `.internal` and `.local` are refused by name.
#[test]
fn require_public_https_rejects_dotinternal_hostname() {
assert!(require_public_https_url("https://storage.internal/x").is_err());
assert!(require_public_https_url("https://api.local/x").is_err());
}
// Unparsable input is an error, not a panic.
#[test]
fn require_public_https_rejects_malformed_url() {
assert!(require_public_https_url("not a url").is_err());
}
}