Skip to main content

ati/core/
file_manager.rs

1//! File manager — `file_manager:download` / `file_manager:upload` virtual
2//! tools. Registered automatically with no TOML manifest so sandboxed agents
3//! can move binary bytes through the proxy (network egress is otherwise
4//! confined to the proxy host).
5//!
6//! In proxy mode the proxy performs the fetch/upload; bytes travel over the
7//! `/call` JSON wire as base64. The sandbox-side CLI materializes them to
8//! disk (`--out`) or ships them (`--path`). Local mode does the work inline.
9
10use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
11use serde_json::{json, Value};
12use std::collections::HashMap;
13use std::net::ToSocketAddrs;
14use std::time::Duration;
15use thiserror::Error;
16
/// Download size ceiling used when the caller does not pass `max_bytes`
/// (500 MiB).
pub const DEFAULT_MAX_BYTES: u64 = 500 << 20;
/// Download timeout, in seconds, used when the caller does not pass `timeout`.
pub const DEFAULT_TIMEOUT_SECS: u64 = 120;
/// Largest decoded upload payload the proxy will accept (1 GiB).
pub const MAX_UPLOAD_BYTES: u64 = 1 << 30;
23
24#[derive(Error, Debug)]
25pub enum FileManagerError {
26    #[error("Missing required argument: {0}")]
27    MissingArg(&'static str),
28    #[error("Invalid argument '{name}': {reason}")]
29    InvalidArg { name: &'static str, reason: String },
30    #[error("URL is not allowed (private/internal address): {0}")]
31    PrivateUrl(String),
32    #[error("Host '{host}' is not in the download allowlist")]
33    HostNotAllowed { host: String },
34    #[error("Invalid URL: {0}")]
35    InvalidUrl(String),
36    #[error("HTTP error fetching '{url}': {source}")]
37    Http {
38        url: String,
39        #[source]
40        source: reqwest::Error,
41    },
42    #[error("Upstream returned status {status} for '{url}': {body}")]
43    Upstream {
44        url: String,
45        status: u16,
46        body: String,
47    },
48    #[error("Response exceeds max-bytes ({limit} bytes)")]
49    SizeCap { limit: u64 },
50    #[error("Invalid extra header '{name}': {reason}")]
51    BadHeader { name: String, reason: String },
52    #[error("Failed to read file '{path}': {source}")]
53    Io {
54        path: String,
55        #[source]
56        source: std::io::Error,
57    },
58    #[error("Upload destinations not configured on the proxy — operator must declare `[provider.upload_destinations.<name>]` in `manifests/file_manager.toml`")]
59    UploadNotConfigured,
60    #[error("Unknown upload destination '{0}' — not in the operator's allowlist")]
61    UnknownDestination(String),
62    #[error("Upload failed: {0}")]
63    Upload(String),
64    #[error("Invalid base64 in upload payload: {0}")]
65    Base64(#[from] base64::DecodeError),
66}
67
68impl FileManagerError {
69    /// HTTP status this variant should map to when surfaced over the proxy
70    /// `POST /call` endpoint. Kept here (rather than in `proxy/server.rs`)
71    /// so adding a new error variant doesn't silently default to 500 in one
72    /// handler and 400 in another.
73    pub fn http_status(&self) -> u16 {
74        match self {
75            Self::MissingArg(_)
76            | Self::InvalidArg { .. }
77            | Self::BadHeader { .. }
78            | Self::Base64(_) => 400,
79            Self::PrivateUrl(_) | Self::HostNotAllowed { .. } | Self::UnknownDestination(_) => 403,
80            Self::SizeCap { .. } => 413,
81            Self::UploadNotConfigured => 503,
82            Self::Upstream { status, .. } => (*status).clamp(400, 599),
83            Self::Http { .. } | Self::InvalidUrl(_) | Self::Upload(_) => 502,
84            Self::Io { .. } => 500,
85        }
86    }
87}
88
89/// Headers an agent must not be able to set on outbound downloads.
90const DENIED_DOWNLOAD_HEADERS: &[&str] = &[
91    "host",
92    "content-length",
93    "transfer-encoding",
94    "connection",
95    "proxy-authorization",
96];
97
98/// Validate caller-supplied headers against the deny-list.
99fn validate_extra_headers(headers: &HashMap<String, String>) -> Result<(), FileManagerError> {
100    for name in headers.keys() {
101        let lower = name.to_lowercase();
102        if DENIED_DOWNLOAD_HEADERS.contains(&lower.as_str()) {
103            return Err(FileManagerError::BadHeader {
104                name: name.clone(),
105                reason: "header is not allowed".into(),
106            });
107        }
108        if !name.bytes().all(|b| b.is_ascii() && b > 32 && b != b':') {
109            return Err(FileManagerError::BadHeader {
110                name: name.clone(),
111                reason: "header name contains invalid characters".into(),
112            });
113        }
114    }
115    Ok(())
116}
117
/// Arguments for a `file_manager:download` call, normally produced and
/// validated by `DownloadArgs::from_value`.
#[derive(Debug, Clone)]
pub struct DownloadArgs {
    /// Source URL (trimmed by `from_value`).
    pub url: String,
    /// Abort once the body would exceed this many bytes.
    pub max_bytes: u64,
    /// Timeout handed to the HTTP client builder.
    pub timeout: Duration,
    /// Whether redirects may be followed.
    pub follow_redirects: bool,
    /// Extra request headers supplied by the caller.
    pub headers: HashMap<String, String>,
}
127
128impl DownloadArgs {
129    pub fn from_value(args: &HashMap<String, Value>) -> Result<Self, FileManagerError> {
130        let url = args
131            .get("url")
132            .and_then(|v| v.as_str())
133            .ok_or(FileManagerError::MissingArg("url"))?
134            .trim()
135            .to_string();
136        if url.is_empty() {
137            return Err(FileManagerError::MissingArg("url"));
138        }
139
140        let max_bytes = parse_u64_arg(args, &["max_bytes", "max-bytes"], "max_bytes")?
141            .unwrap_or(DEFAULT_MAX_BYTES);
142        if max_bytes == 0 {
143            return Err(FileManagerError::InvalidArg {
144                name: "max_bytes",
145                reason: "must be > 0".into(),
146            });
147        }
148
149        let timeout_secs =
150            parse_u64_arg(args, &["timeout"], "timeout")?.unwrap_or(DEFAULT_TIMEOUT_SECS);
151
152        let follow_redirects = args
153            .get("follow_redirects")
154            .or_else(|| args.get("follow-redirects"))
155            .and_then(|v| v.as_bool())
156            .unwrap_or(true);
157
158        let headers = parse_headers(args.get("headers"))?;
159        validate_extra_headers(&headers)?;
160
161        Ok(DownloadArgs {
162            url,
163            max_bytes,
164            timeout: Duration::from_secs(timeout_secs),
165            follow_redirects,
166            headers,
167        })
168    }
169}
170
171/// Look up an optional u64 arg under any of several aliases (to handle both
172/// `max_bytes` and `max-bytes` from CLI arg normalization). Accepts JSON
173/// numbers or numeric strings.
174fn parse_u64_arg(
175    args: &HashMap<String, Value>,
176    aliases: &[&str],
177    field: &'static str,
178) -> Result<Option<u64>, FileManagerError> {
179    let raw = aliases.iter().find_map(|k| args.get(*k));
180    let Some(v) = raw else {
181        return Ok(None);
182    };
183    let err = || FileManagerError::InvalidArg {
184        name: field,
185        reason: "must be a positive integer".into(),
186    };
187    match v {
188        Value::Number(n) => n.as_u64().map(Some).ok_or_else(err),
189        Value::String(s) => s
190            .parse::<u64>()
191            .map(Some)
192            .map_err(|e| FileManagerError::InvalidArg {
193                name: field,
194                reason: e.to_string(),
195            }),
196        _ => Err(err()),
197    }
198}
199
200/// Parse a `headers` argument that may be a JSON object or a JSON-encoded string.
201fn parse_headers(value: Option<&Value>) -> Result<HashMap<String, String>, FileManagerError> {
202    let value = match value {
203        Some(v) => v,
204        None => return Ok(HashMap::new()),
205    };
206    let map = match value {
207        Value::Object(map) => map.clone(),
208        Value::String(s) if s.trim().is_empty() => return Ok(HashMap::new()),
209        Value::String(s) => match serde_json::from_str::<Value>(s) {
210            Ok(Value::Object(map)) => map,
211            Ok(_) => {
212                return Err(FileManagerError::InvalidArg {
213                    name: "headers",
214                    reason: "must be a JSON object".into(),
215                });
216            }
217            Err(e) => {
218                return Err(FileManagerError::InvalidArg {
219                    name: "headers",
220                    reason: format!("invalid JSON: {e}"),
221                });
222            }
223        },
224        Value::Null => return Ok(HashMap::new()),
225        _ => {
226            return Err(FileManagerError::InvalidArg {
227                name: "headers",
228                reason: "must be a JSON object or JSON string".into(),
229            });
230        }
231    };
232    let mut out = HashMap::with_capacity(map.len());
233    for (k, v) in map {
234        let s = match v {
235            Value::String(s) => s,
236            Value::Number(n) => n.to_string(),
237            Value::Bool(b) => b.to_string(),
238            _ => {
239                return Err(FileManagerError::InvalidArg {
240                    name: "headers",
241                    reason: format!("value for '{k}' must be a string, number, or bool"),
242                });
243            }
244        };
245        out.insert(k, s);
246    }
247    Ok(out)
248}
249
/// Body and metadata returned by a successful `fetch_bytes` call.
///
/// Deliberately not `Clone`: `bytes` can be as large as the caller's
/// `max_bytes` (by default [`DEFAULT_MAX_BYTES`]), so any copy should be an
/// explicit decision.
#[derive(Debug)]
pub struct DownloadResult {
    /// Complete response body.
    pub bytes: Vec<u8>,
    /// Upstream `Content-Type`, when present and representable as a string.
    pub content_type: Option<String>,
    /// The URL that was requested.
    pub source_url: String,
}
258
/// Patterns from the `ATI_DOWNLOAD_ALLOWLIST` env var.
///
/// Returns `None` when the variable is unset, unreadable as Unicode, or holds
/// no non-blank entries, meaning "no allowlist configured". Otherwise returns
/// the trimmed, lower-cased, comma-separated entries. Each entry is one of:
/// - exact host: `v3b.fal.media`
/// - subdomain wildcard: `*.fal.media` (matches `v3b.fal.media`, `cdn.fal.media`, …)
/// - bare wildcard: `*` matches anything (NOT recommended — defeats the purpose)
fn allowlist_patterns() -> Option<Vec<String>> {
    let raw = std::env::var("ATI_DOWNLOAD_ALLOWLIST").ok()?;
    let mut patterns = Vec::new();
    for piece in raw.split(',') {
        let pattern = piece.trim().to_lowercase();
        if !pattern.is_empty() {
            patterns.push(pattern);
        }
    }
    (!patterns.is_empty()).then_some(patterns)
}
279
/// Whether `host` satisfies a single allowlist `pattern`.
///
/// `host` is lower-cased here; `pattern` is compared as given (the patterns
/// from `allowlist_patterns` are already lower-case).
/// - `*` matches every host.
/// - `*.suffix` matches `suffix` itself and any name ending in `.suffix`.
/// - Anything else must equal the host exactly.
fn host_matches_pattern(host: &str, pattern: &str) -> bool {
    let host = host.to_lowercase();
    if let Some(suffix) = pattern.strip_prefix("*.") {
        // What precedes the suffix must be nothing (apex) or end in a label dot.
        return host
            .strip_suffix(suffix)
            .map_or(false, |head| head.is_empty() || head.ends_with('.'));
    }
    pattern == "*" || host == pattern
}
291
292/// Reject the URL if `ATI_DOWNLOAD_ALLOWLIST` is set and the host doesn't match.
293/// When the env var is unset or empty, downloads to any (non-private) host are
294/// allowed — local-mode operators who want a wide-open dev experience can leave
295/// the allowlist off; production proxies should always set it.
296pub fn enforce_download_allowlist(url: &str) -> Result<(), FileManagerError> {
297    let patterns = match allowlist_patterns() {
298        Some(p) => p,
299        None => return Ok(()),
300    };
301    let parsed = reqwest::Url::parse(url)
302        .map_err(|e| FileManagerError::InvalidUrl(format!("could not parse URL: {e}")))?;
303    let host = parsed
304        .host_str()
305        .ok_or_else(|| FileManagerError::InvalidUrl("URL has no host component".into()))?;
306
307    if patterns.iter().any(|p| host_matches_pattern(host, p)) {
308        Ok(())
309    } else {
310        Err(FileManagerError::HostNotAllowed {
311            host: host.to_string(),
312        })
313    }
314}
315
316/// Perform the actual HTTP fetch. Streams the body and aborts if it exceeds `max_bytes`.
317///
318/// Applies SSRF protection per `crate::core::http::validate_url_not_private`,
319/// then enforces the download host allowlist (env `ATI_DOWNLOAD_ALLOWLIST`).
320pub async fn fetch_bytes(args: &DownloadArgs) -> Result<DownloadResult, FileManagerError> {
321    crate::core::http::validate_url_not_private(&args.url).map_err(|e| match e {
322        crate::core::http::HttpError::SsrfBlocked(url) => FileManagerError::PrivateUrl(url),
323        other => FileManagerError::InvalidUrl(other.to_string()),
324    })?;
325
326    enforce_download_allowlist(&args.url)?;
327
328    let redirect_policy = if args.follow_redirects {
329        reqwest::redirect::Policy::limited(10)
330    } else {
331        reqwest::redirect::Policy::none()
332    };
333
334    let client = reqwest::Client::builder()
335        .timeout(args.timeout)
336        .redirect(redirect_policy)
337        .build()
338        .map_err(|e| FileManagerError::Http {
339            url: args.url.clone(),
340            source: e,
341        })?;
342
343    let mut req = client.get(&args.url);
344    for (k, v) in &args.headers {
345        req = req.header(k.as_str(), v.as_str());
346    }
347
348    let response = req.send().await.map_err(|e| FileManagerError::Http {
349        url: args.url.clone(),
350        source: e,
351    })?;
352    let status = response.status();
353    let content_type = response
354        .headers()
355        .get(reqwest::header::CONTENT_TYPE)
356        .and_then(|v| v.to_str().ok())
357        .map(|s| s.to_string());
358
359    if !status.is_success() {
360        let body = response.text().await.unwrap_or_default();
361        let truncated = if body.len() > 512 {
362            &body[..512]
363        } else {
364            &body
365        };
366        return Err(FileManagerError::Upstream {
367            url: args.url.clone(),
368            status: status.as_u16(),
369            body: truncated.to_string(),
370        });
371    }
372
373    // Pre-flight against Content-Length when present, and use it to seed the
374    // accumulator's capacity so we avoid ~log2(N) regrow memcpy cycles for
375    // large downloads.
376    let content_length = response
377        .headers()
378        .get(reqwest::header::CONTENT_LENGTH)
379        .and_then(|h| h.to_str().ok())
380        .and_then(|s| s.parse::<u64>().ok());
381    if let Some(len) = content_length {
382        if len > args.max_bytes {
383            return Err(FileManagerError::SizeCap {
384                limit: args.max_bytes,
385            });
386        }
387    }
388
389    // Stream the body so we can abort early on oversize. Cap the preallocation
390    // at `max_bytes` so a spoofed Content-Length can't force a huge allocation.
391    use futures::StreamExt;
392    let initial_cap = content_length
393        .map(|l| l.min(args.max_bytes) as usize)
394        .unwrap_or(64 * 1024);
395    let mut bytes = Vec::with_capacity(initial_cap);
396    let mut stream = response.bytes_stream();
397    while let Some(chunk) = stream.next().await {
398        let chunk = chunk.map_err(|e| FileManagerError::Http {
399            url: args.url.clone(),
400            source: e,
401        })?;
402        if (bytes.len() as u64).saturating_add(chunk.len() as u64) > args.max_bytes {
403            return Err(FileManagerError::SizeCap {
404                limit: args.max_bytes,
405            });
406        }
407        bytes.extend_from_slice(&chunk);
408    }
409
410    Ok(DownloadResult {
411        bytes,
412        content_type,
413        source_url: args.url.clone(),
414    })
415}
416
417/// Build the JSON response payload that the proxy / local-mode core returns
418/// to the CLI. Always carries `content_base64` so the CLI can write to `--out`
419/// or print inline depending on caller intent.
420pub fn build_download_response(result: &DownloadResult) -> Value {
421    json!({
422        "success": true,
423        "size_bytes": result.bytes.len(),
424        "content_type": result.content_type,
425        "source_url": result.source_url,
426        "content_base64": B64.encode(&result.bytes),
427    })
428}
429
/// Best-effort MIME type from the extension of `path`'s final component
/// (case-insensitive).
///
/// Shared across `file_manager:*` tools and CLI output capture. Falls back to
/// `application/octet-stream` for unknown or missing extensions.
///
/// Uses [`std::path::Path::extension`], so a name with no dot (for example a
/// file literally called `log` or `md`) or a bare dotfile (`.json`) counts as
/// having no extension. It is no longer matched in full against the table.
pub fn guess_content_type(path: &str) -> &'static str {
    let ext = std::path::Path::new(path)
        .extension()
        .and_then(|e| e.to_str())
        .map(str::to_ascii_lowercase)
        .unwrap_or_default();
    match ext.as_str() {
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "webp" => "image/webp",
        "svg" => "image/svg+xml",
        "pdf" => "application/pdf",
        "mp4" | "m4v" => "video/mp4",
        "mov" => "video/quicktime",
        "webm" => "video/webm",
        "mp3" => "audio/mpeg",
        "wav" => "audio/wav",
        "ogg" | "oga" => "audio/ogg",
        "flac" => "audio/flac",
        "m4a" => "audio/mp4",
        "csv" => "text/csv",
        "json" => "application/json",
        "xml" => "application/xml",
        "zip" => "application/zip",
        "html" | "htm" => "text/html",
        "md" => "text/markdown",
        "txt" | "log" => "text/plain",
        _ => "application/octet-stream",
    }
}
460
461// ---------------------------------------------------------------------------
462// Upload
463// ---------------------------------------------------------------------------
464
/// Decoded arguments for a `file_manager:upload` call.
///
/// Deliberately not `Clone`: `bytes` can be as large as [`MAX_UPLOAD_BYTES`],
/// and an accidental copy would be expensive. Each sink consumes the value.
#[derive(Debug)]
pub struct UploadArgs {
    /// Sanitized file name (single path component).
    pub filename: String,
    /// Caller-declared MIME type, if any.
    pub content_type: Option<String>,
    /// Decoded payload.
    pub bytes: Vec<u8>,
    /// Key into the proxy's destination allowlist; `None` selects the
    /// operator-configured default.
    pub destination: Option<String>,
}
477
478impl UploadArgs {
479    /// Decode upload args sent over the wire (base64 + filename + content_type
480    /// + optional destination).
481    pub fn from_wire(args: &HashMap<String, Value>) -> Result<Self, FileManagerError> {
482        let filename = args
483            .get("filename")
484            .and_then(|v| v.as_str())
485            .map(|s| s.trim().to_string())
486            .filter(|s| !s.is_empty())
487            .ok_or(FileManagerError::MissingArg("filename"))?;
488        let content_type = args
489            .get("content_type")
490            .or_else(|| args.get("content-type"))
491            .and_then(|v| v.as_str())
492            .map(|s| s.to_string());
493        let b64 = args
494            .get("content_base64")
495            .or_else(|| args.get("content-base64"))
496            .and_then(|v| v.as_str())
497            .ok_or(FileManagerError::MissingArg("content_base64"))?;
498        let bytes = B64.decode(b64.as_bytes())?;
499        if (bytes.len() as u64) > MAX_UPLOAD_BYTES {
500            return Err(FileManagerError::SizeCap {
501                limit: MAX_UPLOAD_BYTES,
502            });
503        }
504        let destination = args
505            .get("destination")
506            .and_then(|v| v.as_str())
507            .map(|s| s.trim().to_string())
508            .filter(|s| !s.is_empty());
509        Ok(UploadArgs {
510            filename: sanitize_filename(&filename),
511            content_type,
512            bytes,
513            destination,
514        })
515    }
516}
517
518/// Strip directory components and disallow path traversal in the filename
519/// the agent gave us — we use it as the GCS object key.
520fn sanitize_filename(input: &str) -> String {
521    let trimmed = input.trim_matches(|c: char| c == '/' || c.is_whitespace());
522    let last = trimmed.rsplit('/').next().unwrap_or(trimmed);
523    let cleaned: String = last.chars().filter(|c| !c.is_control()).collect::<String>();
524    if cleaned.is_empty() || cleaned == "." || cleaned == ".." {
525        format!("upload-{}", chrono::Utc::now().timestamp_millis())
526    } else {
527        cleaned
528    }
529}
530
/// What a successful upload reports back to the CLI.
#[derive(Debug)]
pub struct UploadResult {
    /// URL of the stored object, as returned by the sink.
    pub url: String,
    /// Number of bytes uploaded.
    pub size_bytes: u64,
    /// MIME type the object was stored with.
    pub content_type: String,
    /// Key of the configured destination that received the upload.
    pub destination: String,
}
540
541// ---------------------------------------------------------------------------
542// Upload destination allowlist
543// ---------------------------------------------------------------------------
544
545/// One typed sink the operator's manifest declares as a permitted upload
546/// destination. The agent can pick from these keys via `--destination <key>`;
547/// anything else is refused with a typed error.
548#[derive(Debug, Clone, serde::Deserialize)]
549#[serde(tag = "kind", rename_all = "snake_case")]
550pub enum UploadDestination {
551    /// Google Cloud Storage bucket. Object goes to `<bucket>/<prefix>/<date>/<uuid>-<filename>`.
552    /// `key_ref` names a keyring key holding the GCP service account JSON.
553    Gcs {
554        bucket: String,
555        #[serde(default = "default_gcs_prefix")]
556        prefix: String,
557        #[serde(default = "default_gcs_key_ref")]
558        key_ref: String,
559    },
560    /// fal.ai CDN — uploads via fal's signed-token storage flow.
561    /// `key_ref` names a keyring key holding the fal API key.
562    /// `endpoint` overrides the REST base (default `https://rest.alpha.fal.ai`).
563    FalStorage {
564        #[serde(default = "default_fal_key_ref")]
565        key_ref: String,
566        #[serde(default)]
567        endpoint: Option<String>,
568    },
569}
570
/// Serde default for `UploadDestination::Gcs::prefix`.
fn default_gcs_prefix() -> String {
    String::from("ati-uploads")
}

/// Serde default for `UploadDestination::Gcs::key_ref`.
fn default_gcs_key_ref() -> String {
    String::from("gcp_credentials")
}

/// Serde default for `UploadDestination::FalStorage::key_ref`.
fn default_fal_key_ref() -> String {
    String::from("fal_api_key")
}
582
583/// Resolve a caller-supplied (or omitted) destination key against the operator
584/// manifest's allowlist. Refuses any key not in the map with a typed error.
585pub fn resolve_destination<'a>(
586    destinations: &'a HashMap<String, UploadDestination>,
587    default: Option<&str>,
588    requested: Option<&str>,
589) -> Result<(String, &'a UploadDestination), FileManagerError> {
590    if destinations.is_empty() {
591        return Err(FileManagerError::UploadNotConfigured);
592    }
593    let key = match requested {
594        Some(k) if !k.is_empty() => k.to_string(),
595        _ => default
596            .map(|s| s.to_string())
597            .ok_or(FileManagerError::UploadNotConfigured)?,
598    };
599    let sink = destinations
600        .get(&key)
601        .ok_or_else(|| FileManagerError::UnknownDestination(key.clone()))?;
602    Ok((key, sink))
603}
604
605pub fn build_upload_response(result: &UploadResult) -> Value {
606    json!({
607        "success": true,
608        "url": result.url,
609        "size_bytes": result.size_bytes,
610        "content_type": result.content_type,
611        "destination": result.destination,
612    })
613}
614
615/// Dispatch an upload to one of the operator-allowlisted destinations.
616/// Resolves the requested key (or default) against the manifest's destinations
617/// map, then routes to the typed sink. Refuses any key not in the map.
618pub async fn upload_to_destination(
619    args: UploadArgs,
620    destinations: &HashMap<String, UploadDestination>,
621    default: Option<&str>,
622    keyring: &crate::core::keyring::Keyring,
623) -> Result<Value, FileManagerError> {
624    let (key, sink) = resolve_destination(destinations, default, args.destination.as_deref())?;
625    let result = match sink {
626        UploadDestination::Gcs {
627            bucket,
628            prefix,
629            key_ref,
630        } => upload_to_gcs(args, bucket, prefix, key_ref, keyring, &key).await?,
631        UploadDestination::FalStorage { key_ref, endpoint } => {
632            upload_to_fal(args, key_ref, endpoint.as_deref(), keyring, &key).await?
633        }
634    };
635    Ok(build_upload_response(&result))
636}
637
638async fn upload_to_gcs(
639    args: UploadArgs,
640    bucket: &str,
641    prefix: &str,
642    key_ref: &str,
643    keyring: &crate::core::keyring::Keyring,
644    destination_key: &str,
645) -> Result<UploadResult, FileManagerError> {
646    let service_account_json = keyring
647        .get(key_ref)
648        .ok_or_else(|| {
649            FileManagerError::Upload(format!("keyring key '{key_ref}' missing for GCS upload"))
650        })?
651        .to_string();
652
653    let content_type = args
654        .content_type
655        .unwrap_or_else(|| "application/octet-stream".to_string());
656    let size_bytes = args.bytes.len() as u64;
657    let date = chrono::Utc::now().format("%Y-%m-%d");
658    let uuid = uuid::Uuid::new_v4();
659    let object_name = format!("{prefix}/{date}/{uuid}-{}", args.filename);
660
661    let client =
662        crate::core::gcs::GcsClient::new_read_write(bucket.to_string(), &service_account_json)
663            .map_err(|e| FileManagerError::Upload(e.to_string()))?;
664    let url = client
665        .upload_object(&object_name, args.bytes, &content_type)
666        .await
667        .map_err(|e| FileManagerError::Upload(e.to_string()))?;
668
669    Ok(UploadResult {
670        url,
671        size_bytes,
672        content_type,
673        destination: destination_key.to_string(),
674    })
675}
676
677/// Always-on SSRF guard for URLs that came from a remote server's response.
678///
679/// Applies to URLs derived from a third-party response rather than from agent
680/// input or operator config. Refuses non-HTTPS URLs and any host that
681/// resolves to a private/internal address.
682///
683/// Ignores the `ATI_SSRF_PROTECTION` env knob — that's for the
684/// agent-controlled-URL path where the operator might want unrestricted dev
685/// mode. Here we have no reason to ever trust a server-supplied internal
686/// address.
687fn require_public_https_url(url: &str) -> Result<(), FileManagerError> {
688    let parsed = reqwest::Url::parse(url)
689        .map_err(|e| FileManagerError::Upload(format!("server returned malformed URL: {e}")))?;
690    if parsed.scheme() != "https" {
691        return Err(FileManagerError::Upload(format!(
692            "refusing non-HTTPS URL from server: {url}"
693        )));
694    }
695    let host = parsed
696        .host_str()
697        .ok_or_else(|| FileManagerError::Upload(format!("server URL has no host: {url}")))?;
698    let host_lower = host.to_lowercase();
699    if host_lower == "localhost"
700        || host_lower == "metadata.google.internal"
701        || host_lower.ends_with(".internal")
702        || host_lower.ends_with(".local")
703    {
704        return Err(FileManagerError::Upload(format!(
705            "server URL targets a private hostname: {url}"
706        )));
707    }
708    let port = parsed.port_or_known_default().unwrap_or(443);
709    let ip_host = host.trim_matches(['[', ']']);
710    let is_private = if let Ok(ip) = ip_host.parse::<std::net::IpAddr>() {
711        is_private_ip_addr(ip)
712    } else if let Ok(addrs) = (ip_host, port).to_socket_addrs() {
713        addrs.into_iter().any(|addr| is_private_ip_addr(addr.ip()))
714    } else {
715        false
716    };
717    if is_private {
718        return Err(FileManagerError::Upload(format!(
719            "server URL resolves to a private address: {url}"
720        )));
721    }
722    Ok(())
723}
724
/// True when `ip` is not a safe public destination for a server-supplied URL.
///
/// IPv4 addresses are judged by [`is_private_ipv4`]. IPv6 addresses are
/// refused when they are:
/// - loopback, unspecified, or multicast;
/// - unique-local (`fc00::/7`), link-local (`fe80::/10`), or deprecated
///   site-local (`fec0::/10`);
/// - an embedded IPv4 address that fails the IPv4 rules.
///
/// Link-local and unique-local are tested with bit masks rather than the
/// newer `Ipv6Addr` helpers.
fn is_private_ip_addr(ip: std::net::IpAddr) -> bool {
    match ip {
        std::net::IpAddr::V4(v4) => is_private_ipv4(v4),
        std::net::IpAddr::V6(v6) => {
            // IPv4-mapped (::ffff:a.b.c.d) and NAT64 well-known-prefix
            // (64:ff9b::a.b.c.d, RFC 6052) addresses reach an IPv4 host. A
            // compromised server could otherwise hide e.g. 169.254.169.254
            // behind them, so judge the embedded address by the v4 rules.
            if let Some(v4) = v6.to_ipv4_mapped() {
                return is_private_ipv4(v4);
            }
            let seg = v6.segments();
            if seg[..6] == [0x64, 0xff9b, 0, 0, 0, 0] {
                let embedded = (u32::from(seg[6]) << 16) | u32::from(seg[7]);
                return is_private_ipv4(std::net::Ipv4Addr::from(embedded));
            }
            v6.is_loopback()
                || v6.is_unspecified()
                || v6.is_multicast()
                || (seg[0] & 0xfe00) == 0xfc00 // unique local fc00::/7
                || (seg[0] & 0xffc0) == 0xfe80 // link-local fe80::/10
                || (seg[0] & 0xffc0) == 0xfec0 // deprecated site-local fec0::/10
        }
    }
}

/// True for IPv4 addresses a server-supplied URL must not target.
///
/// This covers anything outside the ordinary globally routable unicast space:
/// - loopback, RFC 1918 private ranges, link-local (cloud metadata lives at
///   169.254.169.254), and carrier-grade NAT;
/// - 0.0.0.0/8 and the IETF protocol block 192.0.0.0/24;
/// - benchmarking, documentation, multicast, and reserved/broadcast ranges.
fn is_private_ipv4(ip: std::net::Ipv4Addr) -> bool {
    let [a, b, c, _] = ip.octets();
    ip.is_loopback()
        || ip.is_private()
        || ip.is_link_local()
        || ip.is_multicast()
        || ip.is_documentation()
        // "This network" (RFC 1122), which includes the unspecified 0.0.0.0.
        || a == 0
        // Carrier-grade NAT (RFC 6598): 100.64.0.0/10
        || (a == 100 && (64..=127).contains(&b))
        // IETF protocol assignments (RFC 6890): 192.0.0.0/24
        || (a == 192 && b == 0 && c == 0)
        // Benchmarking (RFC 2544): 198.18.0.0/15
        || (a == 198 && (b == 18 || b == 19))
        // Reserved 240.0.0.0/4, including limited broadcast 255.255.255.255
        || a >= 240
}
752
753/// Upload to fal.ai's CDN via their two-step signed-token flow.
754///
755/// 1. POST `<rest>/storage/auth/token?storage_type=fal-cdn-v3` with
756///    `Authorization: Key <api_key>` → `{token, token_type, base_url, expires_at}`
757/// 2. POST `<base_url or v3.fal.media>/files/upload` with the signed token,
758///    `Content-Type: <mime>`, `X-Fal-File-Name: <filename>`, body = bytes
759///    → `{access_url: "..."}`
760async fn upload_to_fal(
761    args: UploadArgs,
762    key_ref: &str,
763    endpoint: Option<&str>,
764    keyring: &crate::core::keyring::Keyring,
765    destination_key: &str,
766) -> Result<UploadResult, FileManagerError> {
767    use serde::Deserialize;
768
769    let api_key = keyring
770        .get(key_ref)
771        .ok_or_else(|| {
772            FileManagerError::Upload(format!("keyring key '{key_ref}' missing for fal upload"))
773        })?
774        .to_string();
775    let rest_base = endpoint.unwrap_or("https://rest.alpha.fal.ai");
776
777    let http = reqwest::Client::builder()
778        .timeout(std::time::Duration::from_secs(60))
779        .build()
780        .map_err(|e| FileManagerError::Upload(format!("http client init: {e}")))?;
781
782    // Step 1: mint signed token
783    let token_url = format!("{rest_base}/storage/auth/token?storage_type=fal-cdn-v3");
784    let token_resp = http
785        .post(&token_url)
786        .header("Authorization", format!("Key {api_key}"))
787        .header("Accept", "application/json")
788        .header("Content-Type", "application/json")
789        .body("{}")
790        .send()
791        .await
792        .map_err(|e| FileManagerError::Upload(format!("fal token request failed: {e}")))?;
793    if !token_resp.status().is_success() {
794        let status = token_resp.status().as_u16();
795        let body = token_resp.text().await.unwrap_or_default();
796        return Err(FileManagerError::Upload(format!(
797            "fal token mint returned {status}: {body}"
798        )));
799    }
800    #[derive(Deserialize)]
801    struct FalToken {
802        token: String,
803        token_type: String,
804        base_url: String,
805    }
806    let token: FalToken = token_resp
807        .json()
808        .await
809        .map_err(|e| FileManagerError::Upload(format!("fal token JSON parse failed: {e}")))?;
810
811    // Step 2: PUT bytes to <base_url>/files/upload
812    let content_type = args
813        .content_type
814        .unwrap_or_else(|| "application/octet-stream".to_string());
815    let size_bytes = args.bytes.len() as u64;
816    let upload_url = format!("{}/files/upload", token.base_url.trim_end_matches('/'));
817
818    // SSRF guard: the `base_url` came from fal's token response. A compromised
819    // or DNS-hijacked fal endpoint returning e.g. `base_url =
820    // "http://169.254.169.254/"` would otherwise cause the proxy to POST the
821    // file payload + signed token to that internal address. Always enforce —
822    // the env-gated `ATI_SSRF_PROTECTION` is for agent-supplied URLs where the
823    // operator might want unrestricted dev access; this is a server-supplied
824    // URL we can't trust unconditionally.
825    require_public_https_url(&upload_url)?;
826
827    let upload_resp = http
828        .post(&upload_url)
829        .header(
830            "Authorization",
831            format!("{} {}", token.token_type, token.token),
832        )
833        .header("Content-Type", &content_type)
834        .header("X-Fal-File-Name", &args.filename)
835        .body(args.bytes)
836        .send()
837        .await
838        .map_err(|e| FileManagerError::Upload(format!("fal upload request failed: {e}")))?;
839    if !upload_resp.status().is_success() {
840        let status = upload_resp.status().as_u16();
841        let body = upload_resp.text().await.unwrap_or_default();
842        return Err(FileManagerError::Upload(format!(
843            "fal upload returned {status}: {body}"
844        )));
845    }
846    #[derive(Deserialize)]
847    struct FalUploadResponse {
848        access_url: String,
849    }
850    let body: FalUploadResponse = upload_resp
851        .json()
852        .await
853        .map_err(|e| FileManagerError::Upload(format!("fal upload JSON parse failed: {e}")))?;
854
855    Ok(UploadResult {
856        url: body.access_url,
857        size_bytes,
858        content_type,
859        destination: destination_key.to_string(),
860    })
861}
862
#[cfg(test)]
mod tests {
    //! Unit tests for the file-manager helpers: header parsing/validation,
    //! download/upload argument parsing, the download host allowlist matcher,
    //! upload-destination resolution, and the always-on SSRF guard applied to
    //! server-supplied upload URLs.

    use super::*;

    #[test]
    fn parse_headers_object() {
        let v = serde_json::json!({"X-Test": "1", "X-Other": "abc"});
        let map = parse_headers(Some(&v)).unwrap();
        assert_eq!(map.len(), 2);
        assert_eq!(map.get("X-Test").map(String::as_str), Some("1"));
    }

    // Headers may also arrive as a JSON-encoded string (e.g. from a CLI flag).
    #[test]
    fn parse_headers_string_json() {
        let v = Value::String(r#"{"Authorization":"Bearer abc"}"#.into());
        let map = parse_headers(Some(&v)).unwrap();
        assert_eq!(
            map.get("Authorization").map(String::as_str),
            Some("Bearer abc")
        );
    }

    #[test]
    fn parse_headers_empty_string() {
        let v = Value::String("".into());
        assert!(parse_headers(Some(&v)).unwrap().is_empty());
    }

    #[test]
    fn parse_headers_invalid_type() {
        let v = Value::Number(42.into());
        assert!(parse_headers(Some(&v)).is_err());
    }

    // Caller-supplied `Host` must be refused so it cannot redirect the fetch.
    #[test]
    fn validate_denied_header() {
        let mut map = HashMap::new();
        map.insert("Host".to_string(), "evil.com".to_string());
        assert!(validate_extra_headers(&map).is_err());
    }

    #[test]
    fn download_args_defaults() {
        let mut args = HashMap::new();
        args.insert(
            "url".to_string(),
            Value::String("https://example.com".into()),
        );
        let parsed = DownloadArgs::from_value(&args).unwrap();
        assert_eq!(parsed.max_bytes, DEFAULT_MAX_BYTES);
        assert_eq!(parsed.timeout, Duration::from_secs(DEFAULT_TIMEOUT_SECS));
        assert!(parsed.follow_redirects);
        assert!(parsed.headers.is_empty());
    }

    #[test]
    fn download_args_missing_url() {
        let args = HashMap::new();
        assert!(DownloadArgs::from_value(&args).is_err());
    }

    #[test]
    fn download_args_zero_max_bytes_rejected() {
        let mut args = HashMap::new();
        args.insert(
            "url".to_string(),
            Value::String("https://example.com".into()),
        );
        args.insert("max_bytes".to_string(), Value::Number(0.into()));
        assert!(DownloadArgs::from_value(&args).is_err());
    }

    // Numeric arguments are accepted in string form as well as JSON numbers.
    #[test]
    fn download_args_max_bytes_string() {
        let mut args = HashMap::new();
        args.insert(
            "url".to_string(),
            Value::String("https://example.com".into()),
        );
        args.insert("max_bytes".to_string(), Value::String("1024".into()));
        let parsed = DownloadArgs::from_value(&args).unwrap();
        assert_eq!(parsed.max_bytes, 1024);
    }

    #[test]
    fn upload_args_round_trip() {
        let bytes = b"hello world".to_vec();
        let mut args = HashMap::new();
        args.insert("filename".to_string(), Value::String("hello.txt".into()));
        args.insert(
            "content_type".to_string(),
            Value::String("text/plain".into()),
        );
        args.insert(
            "content_base64".to_string(),
            Value::String(B64.encode(&bytes)),
        );
        let parsed = UploadArgs::from_wire(&args).unwrap();
        assert_eq!(parsed.bytes, bytes);
        assert_eq!(parsed.filename, "hello.txt");
        assert_eq!(parsed.content_type.as_deref(), Some("text/plain"));
    }

    // Directory components are dropped; only the final path segment survives.
    #[test]
    fn upload_args_path_traversal_stripped() {
        let mut args = HashMap::new();
        args.insert(
            "filename".to_string(),
            Value::String("../../etc/passwd".into()),
        );
        args.insert(
            "content_base64".to_string(),
            Value::String(B64.encode(b"x")),
        );
        let parsed = UploadArgs::from_wire(&args).unwrap();
        assert_eq!(parsed.filename, "passwd");
    }

    #[test]
    fn upload_args_missing_filename() {
        let mut args = HashMap::new();
        args.insert(
            "content_base64".to_string(),
            Value::String(B64.encode(b"x")),
        );
        assert!(UploadArgs::from_wire(&args).is_err());
    }

    #[test]
    fn upload_args_invalid_base64() {
        let mut args = HashMap::new();
        args.insert("filename".to_string(), Value::String("a".into()));
        args.insert(
            "content_base64".to_string(),
            Value::String("!!! not base64 !!!".into()),
        );
        assert!(UploadArgs::from_wire(&args).is_err());
    }

    #[test]
    fn build_download_response_includes_base64() {
        let bytes = b"hello".to_vec();
        let result = DownloadResult {
            bytes,
            content_type: Some("text/plain".into()),
            source_url: "https://example.com/h".into(),
        };
        let v = build_download_response(&result);
        assert_eq!(v["size_bytes"], 5);
        assert_eq!(v["content_type"], "text/plain");
        assert!(v["content_base64"].as_str().is_some());
    }

    // Exact patterns match case-insensitively.
    #[test]
    fn host_pattern_exact_match() {
        assert!(host_matches_pattern("v3b.fal.media", "v3b.fal.media"));
        assert!(!host_matches_pattern("evil.com", "v3b.fal.media"));
        assert!(host_matches_pattern("V3B.FAL.MEDIA", "v3b.fal.media"));
    }

    // `*.domain` matches any subdomain and (by design) the bare apex domain,
    // but never a host that merely ends with the same characters.
    #[test]
    fn host_pattern_subdomain_wildcard() {
        assert!(host_matches_pattern("v3b.fal.media", "*.fal.media"));
        assert!(host_matches_pattern("cdn.fal.media", "*.fal.media"));
        assert!(host_matches_pattern("fal.media", "*.fal.media"));
        assert!(!host_matches_pattern("evil.com", "*.fal.media"));
        // Don't match suffix-collision tricks like "evilfal.media"
        assert!(!host_matches_pattern("evilfal.media", "*.fal.media"));
    }

    #[test]
    fn host_pattern_bare_wildcard_matches_anything() {
        assert!(host_matches_pattern("anywhere.com", "*"));
    }

    /// Two configured sinks, "gcs" and "fal", shared by the
    /// `resolve_destination_*` tests below.
    fn make_destinations() -> HashMap<String, UploadDestination> {
        let mut m = HashMap::new();
        m.insert(
            "gcs".to_string(),
            UploadDestination::Gcs {
                bucket: "b".to_string(),
                prefix: "p".to_string(),
                key_ref: "gcp_credentials".to_string(),
            },
        );
        m.insert(
            "fal".to_string(),
            UploadDestination::FalStorage {
                key_ref: "fal_api_key".to_string(),
                endpoint: None,
            },
        );
        m
    }

    // An explicitly requested destination wins over the configured default.
    #[test]
    fn resolve_destination_picks_explicit_key() {
        let m = make_destinations();
        let (k, sink) = resolve_destination(&m, Some("gcs"), Some("fal")).unwrap();
        assert_eq!(k, "fal");
        assert!(matches!(sink, UploadDestination::FalStorage { .. }));
    }

    #[test]
    fn resolve_destination_falls_back_to_default() {
        let m = make_destinations();
        let (k, _) = resolve_destination(&m, Some("gcs"), None).unwrap();
        assert_eq!(k, "gcs");
    }

    #[test]
    fn resolve_destination_unknown_key_rejected() {
        let m = make_destinations();
        let err = resolve_destination(&m, Some("gcs"), Some("evil")).unwrap_err();
        assert!(matches!(err, FileManagerError::UnknownDestination(ref s) if s == "evil"));
    }

    #[test]
    fn resolve_destination_empty_map_not_configured() {
        let m: HashMap<String, UploadDestination> = HashMap::new();
        let err = resolve_destination(&m, None, None).unwrap_err();
        assert!(matches!(err, FileManagerError::UploadNotConfigured));
    }

    // Even with sinks configured, no default and no explicit request means
    // there is nothing to pick.
    #[test]
    fn resolve_destination_no_default_no_request_not_configured() {
        let m = make_destinations();
        let err = resolve_destination(&m, None, None).unwrap_err();
        assert!(matches!(err, FileManagerError::UploadNotConfigured));
    }

    // Always-on SSRF guard for server-supplied URLs (e.g. fal's base_url).
    // NOTE(review): this uses a real public hostname. If
    // `require_public_https_url` resolves hostnames (this module imports
    // `ToSocketAddrs`), the test may need network/DNS access — confirm it is
    // stable in offline CI.
    #[test]
    fn require_public_https_accepts_public_https() {
        assert!(require_public_https_url("https://v3b.fal.media/files/upload").is_ok());
    }

    #[test]
    fn require_public_https_rejects_http_scheme() {
        let err = require_public_https_url("http://v3b.fal.media/files/upload").unwrap_err();
        assert!(
            matches!(&err, FileManagerError::Upload(m) if m.contains("non-HTTPS")),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn require_public_https_rejects_loopback_hostname() {
        let err = require_public_https_url("https://localhost/files/upload").unwrap_err();
        assert!(matches!(&err, FileManagerError::Upload(m) if m.contains("private")));
    }

    /// Plain (non-mapped) loopback literals, complementing the `localhost`
    /// hostname case and the IPv4-mapped loopback case below.
    #[test]
    fn require_public_https_rejects_loopback_ip_literals() {
        assert!(require_public_https_url("https://127.0.0.1/x").is_err());
        assert!(require_public_https_url("https://[::1]/x").is_err());
    }

    #[test]
    fn require_public_https_rejects_metadata_ip() {
        // Link-local instance-metadata endpoint used by the major clouds
        // (AWS, GCP, Azure) — the classic SSRF credential-theft target.
        let err = require_public_https_url("https://169.254.169.254/").unwrap_err();
        assert!(matches!(&err, FileManagerError::Upload(m) if m.contains("private")));
    }

    #[test]
    fn require_public_https_rejects_rfc1918() {
        assert!(require_public_https_url("https://10.0.0.1/x").is_err());
        assert!(require_public_https_url("https://192.168.1.1/x").is_err());
        assert!(require_public_https_url("https://172.16.0.1/x").is_err());
    }

    #[test]
    fn require_public_https_rejects_link_local_ipv6() {
        assert!(require_public_https_url("https://[fe80::1]/x").is_err());
    }

    /// Regression: v1 of `is_private_ip_addr` missed IPv4-mapped IPv6 addresses,
    /// letting a compromised server bypass the SSRF guard with
    /// `::ffff:169.254.169.254` et al.
    #[test]
    fn require_public_https_rejects_ipv4_mapped_metadata_address() {
        assert!(require_public_https_url("https://[::ffff:169.254.169.254]/").is_err());
    }

    #[test]
    fn require_public_https_rejects_ipv4_mapped_loopback() {
        assert!(require_public_https_url("https://[::ffff:127.0.0.1]/x").is_err());
    }

    #[test]
    fn require_public_https_rejects_ipv4_mapped_rfc1918() {
        assert!(require_public_https_url("https://[::ffff:10.0.0.1]/x").is_err());
        assert!(require_public_https_url("https://[::ffff:192.168.1.1]/x").is_err());
        assert!(require_public_https_url("https://[::ffff:172.16.0.1]/x").is_err());
    }

    #[test]
    fn require_public_https_rejects_ipv4_mapped_cgnat() {
        // 100.64.0.0/10 — carrier-grade NAT
        assert!(require_public_https_url("https://[::ffff:100.64.0.1]/x").is_err());
    }

    // Internal-only DNS suffixes are refused by name.
    #[test]
    fn require_public_https_rejects_dotinternal_hostname() {
        assert!(require_public_https_url("https://storage.internal/x").is_err());
        assert!(require_public_https_url("https://api.local/x").is_err());
    }

    #[test]
    fn require_public_https_rejects_malformed_url() {
        assert!(require_public_https_url("not a url").is_err());
    }
}