Skip to main content

aft/
url_fetch.rs

1use std::error::Error;
2use std::fmt;
3use std::fs;
4use std::io::{self, Read, Write};
5use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
6use std::path::{Path, PathBuf};
7use std::sync::{mpsc, Arc};
8use std::thread;
9use std::time::{Duration, SystemTime, UNIX_EPOCH};
10
11use reqwest::blocking::{Client, Response as HttpResponse};
12use reqwest::header::{ACCEPT, CONTENT_TYPE, LOCATION, USER_AGENT};
13use reqwest::redirect::Policy;
14use serde::{Deserialize, Serialize};
15use sha2::{Digest, Sha256};
16use url::Url;
17
18const MAX_RESPONSE_BYTES: u64 = 10 * 1024 * 1024;
19const CACHE_TTL_MS: u64 = 24 * 60 * 60 * 1000;
20const CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
21const BODY_CHUNK_TIMEOUT: Duration = Duration::from_millis(15_000);
22const MAX_REDIRECTS: usize = 5;
23
24/// Retry budget for transient connect/transport failures only. Agents
25/// shouldn't have to retry manually for a single TCP/TLS hiccup. We cap
26/// at 2 retries (= 3 total attempts) with short jittered backoff so a
27/// genuinely-broken host fails fast instead of dragging the foreground
28/// fetch out to many seconds.
29///
30/// We deliberately do NOT retry on:
31///   - HTTP error status (4xx/5xx) — the server actually answered
32///   - Redirect errors / SSRF rejections — those are deterministic
33///   - Body read stalls — already handled by BODY_CHUNK_TIMEOUT
34const TRANSIENT_RETRY_ATTEMPTS: usize = 2;
35const TRANSIENT_RETRY_BACKOFFS_MS: [u64; TRANSIENT_RETRY_ATTEMPTS] = [200, 600];
36const ACCEPT_HEADER: &str = "application/vnd.github.raw, text/markdown, text/x-markdown, text/html;q=0.9, application/json;q=0.8, text/plain;q=0.5";
37const USER_AGENT_VALUE: &str = "aft-opencode-plugin";
38
39#[derive(Clone, Default)]
40pub struct UrlFetchOptions {
41    pub allow_private: bool,
42    /// Test hook: treat a hostname as resolving to these IPs during SSRF validation.
43    /// Production callers leave this empty and use `std::net::ToSocketAddrs`.
44    #[doc(hidden)]
45    pub public_host_overrides: Vec<(String, Vec<IpAddr>)>,
46    /// Test hook: force reqwest to connect a hostname to a local mock server while
47    /// SSRF validation still sees `public_host_overrides` above.
48    #[doc(hidden)]
49    pub connect_overrides: Vec<(String, SocketAddr)>,
50    /// Test hook: observes the temp path immediately before the atomic rename.
51    #[doc(hidden)]
52    pub atomic_write_observer: Option<Arc<dyn Fn(&Path, &Path) + Send + Sync>>,
53}
54
55#[derive(Debug, Clone)]
56pub struct UrlFetchError {
57    message: String,
58}
59
60impl UrlFetchError {
61    fn new(message: impl Into<String>) -> Self {
62        Self {
63            message: message.into(),
64        }
65    }
66}
67
68impl fmt::Display for UrlFetchError {
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        f.write_str(&self.message)
71    }
72}
73
74impl std::error::Error for UrlFetchError {}
75
76#[derive(Debug, Serialize, Deserialize)]
77struct CacheMeta {
78    url: String,
79    #[serde(rename = "contentType")]
80    content_type: String,
81    extension: String,
82    #[serde(rename = "fetchedAt")]
83    fetched_at: u64,
84}
85
86pub fn is_http_url(value: &str) -> bool {
87    value.starts_with("http://") || value.starts_with("https://")
88}
89
90pub fn fetch_url_to_cache(
91    url: &str,
92    storage_dir: &Path,
93    options: UrlFetchOptions,
94) -> Result<PathBuf, UrlFetchError> {
95    let parsed = Url::parse(url).map_err(|_| UrlFetchError::new(format!("Invalid URL: {url}")))?;
96    validate_public_url(&parsed, &options)?;
97
98    let dir = cache_dir(storage_dir);
99    fs::create_dir_all(&dir).map_err(|error| {
100        UrlFetchError::new(format!(
101            "Failed to create URL cache directory {}: {error}",
102            dir.display()
103        ))
104    })?;
105
106    let hash = hash_url(url);
107    let meta_file = meta_path(storage_dir, &hash);
108    if let Some(cached) = fresh_cached_path(storage_dir, &hash, &meta_file)? {
109        return Ok(cached);
110    }
111
112    let response = fetch_with_redirects(&parsed, url, &options)?;
113    if !response.status().is_success() {
114        return Err(UrlFetchError::new(format!(
115            "HTTP {} {} fetching {url}",
116            response.status().as_u16(),
117            response.status().canonical_reason().unwrap_or("")
118        )));
119    }
120
121    let content_type = response
122        .headers()
123        .get(CONTENT_TYPE)
124        .and_then(|value| value.to_str().ok())
125        .unwrap_or("text/plain")
126        .to_string();
127    let extension = resolve_extension(&content_type).ok_or_else(|| {
128        UrlFetchError::new(format!(
129            "Unsupported content type '{content_type}' for {url}. Supported: text/html, text/markdown, application/json, text/plain"
130        ))
131    })?;
132
133    if let Some(length) = response.content_length() {
134        if length > MAX_RESPONSE_BYTES {
135            return Err(UrlFetchError::new(format!(
136                "Response too large: {length} bytes (max {MAX_RESPONSE_BYTES})"
137            )));
138        }
139    }
140
141    let body = read_response_body(response, url)?;
142    let content_file = content_path(storage_dir, &hash, extension);
143    atomic_write(&content_file, &body, &options)?;
144
145    let meta = CacheMeta {
146        url: url.to_string(),
147        content_type,
148        extension: extension.to_string(),
149        fetched_at: now_ms(),
150    };
151    let meta_bytes = serde_json::to_vec(&meta).map_err(|error| {
152        UrlFetchError::new(format!("Failed to encode URL cache metadata: {error}"))
153    })?;
154    atomic_write(&meta_file, &meta_bytes, &options)?;
155
156    Ok(content_file)
157}
158
159pub fn cleanup_url_cache(storage_dir: &Path) -> Result<usize, UrlFetchError> {
160    let dir = cache_dir(storage_dir);
161    if !dir.exists() {
162        return Ok(0);
163    }
164
165    let entries = fs::read_dir(&dir).map_err(|error| {
166        UrlFetchError::new(format!(
167            "URL cache cleanup failed reading {}: {error}",
168            dir.display()
169        ))
170    })?;
171    let mut removed = 0usize;
172    let now = now_ms();
173
174    for entry in entries.flatten() {
175        let path = entry.path();
176        let Some(name) = path.file_name().and_then(|name| name.to_str()) else {
177            continue;
178        };
179        if !name.ends_with(".meta.json") {
180            continue;
181        }
182
183        let meta = fs::read_to_string(&path)
184            .ok()
185            .and_then(|content| serde_json::from_str::<CacheMeta>(&content).ok());
186        let Some(meta) = meta else {
187            if fs::remove_file(&path).is_ok() {
188                removed += 1;
189            }
190            continue;
191        };
192
193        if now.saturating_sub(meta.fetched_at) <= CACHE_TTL_MS {
194            continue;
195        }
196
197        let hash = name.trim_end_matches(".meta.json");
198        let content = content_path(storage_dir, hash, &meta.extension);
199        let _ = fs::remove_file(content);
200        if fs::remove_file(&path).is_ok() {
201            removed += 1;
202        }
203    }
204
205    Ok(removed)
206}
207
208#[doc(hidden)]
209pub fn cache_content_path_for_url(storage_dir: &Path, url: &str, extension: &str) -> PathBuf {
210    content_path(storage_dir, &hash_url(url), extension)
211}
212
213#[doc(hidden)]
214pub fn cache_meta_path_for_url(storage_dir: &Path, url: &str) -> PathBuf {
215    meta_path(storage_dir, &hash_url(url))
216}
217
218#[doc(hidden)]
219pub fn is_private_ip_for_test(ip: IpAddr) -> bool {
220    is_private_ip(ip)
221}
222
223fn cache_dir(storage_dir: &Path) -> PathBuf {
224    storage_dir.join("url_cache")
225}
226
227fn hash_url(url: &str) -> String {
228    let digest = Sha256::digest(url.as_bytes());
229    format!("{digest:x}").chars().take(16).collect()
230}
231
232fn meta_path(storage_dir: &Path, hash: &str) -> PathBuf {
233    cache_dir(storage_dir).join(format!("{hash}.meta.json"))
234}
235
236fn content_path(storage_dir: &Path, hash: &str, extension: &str) -> PathBuf {
237    cache_dir(storage_dir).join(format!("{hash}{extension}"))
238}
239
240fn fresh_cached_path(
241    storage_dir: &Path,
242    hash: &str,
243    meta_file: &Path,
244) -> Result<Option<PathBuf>, UrlFetchError> {
245    if !meta_file.exists() {
246        return Ok(None);
247    }
248
249    let meta = match fs::read_to_string(meta_file)
250        .ok()
251        .and_then(|content| serde_json::from_str::<CacheMeta>(&content).ok())
252    {
253        Some(meta) => meta,
254        None => return Ok(None),
255    };
256    let age = now_ms().saturating_sub(meta.fetched_at);
257    let cached = content_path(storage_dir, hash, &meta.extension);
258    if age < CACHE_TTL_MS && cached.exists() {
259        return Ok(Some(cached));
260    }
261    Ok(None)
262}
263
264fn fetch_with_redirects(
265    start_url: &Url,
266    original_url: &str,
267    options: &UrlFetchOptions,
268) -> Result<HttpResponse, UrlFetchError> {
269    let client = build_client(options)?;
270    let mut current_url = start_url.clone();
271
272    for redirect_count in 0..=MAX_REDIRECTS {
273        validate_public_url(&current_url, options)?;
274        let response = send_with_transient_retries(&client, &current_url)?;
275
276        if !response.status().is_redirection() {
277            return Ok(response);
278        }
279        if redirect_count == MAX_REDIRECTS {
280            return Err(UrlFetchError::new(format!(
281                "Too many redirects fetching {original_url}"
282            )));
283        }
284
285        let location = response
286            .headers()
287            .get(LOCATION)
288            .and_then(|value| value.to_str().ok())
289            .ok_or_else(|| {
290                UrlFetchError::new(format!(
291                    "Redirect from {} missing Location header",
292                    current_url.as_str()
293                ))
294            })?;
295        current_url = current_url.join(location).map_err(|error| {
296            UrlFetchError::new(format!(
297                "Invalid redirect Location '{location}' from {}: {error}",
298                current_url.as_str()
299            ))
300        })?;
301    }
302
303    Err(UrlFetchError::new(format!(
304        "Too many redirects fetching {original_url}"
305    )))
306}
307
308/// Issue a single GET with the configured User-Agent + Accept headers and
309/// transparently retry only on transient connect/transport failures.
310///
311/// Returns the response (including 4xx/5xx — caller decides how to treat
312/// those). On a non-transient reqwest error (e.g. an HTTP-shaped reply that
313/// reqwest still surfaces as Err, or a TLS handshake fault that doesn't read
314/// as `is_connect`), the original error is returned immediately so the user
315/// sees the real failure without an artificial 800ms-plus delay.
316fn send_with_transient_retries(
317    client: &Client,
318    target: &Url,
319) -> Result<HttpResponse, UrlFetchError> {
320    let mut last_error: Option<reqwest::Error> = None;
321    for attempt in 0..=TRANSIENT_RETRY_ATTEMPTS {
322        let result = client
323            .get(target.clone())
324            .header(USER_AGENT, USER_AGENT_VALUE)
325            .header(ACCEPT, ACCEPT_HEADER)
326            .send();
327        match result {
328            Ok(response) => return Ok(response),
329            Err(error) => {
330                if attempt < TRANSIENT_RETRY_ATTEMPTS && is_transient_reqwest_error(&error) {
331                    thread::sleep(Duration::from_millis(TRANSIENT_RETRY_BACKOFFS_MS[attempt]));
332                    last_error = Some(error);
333                    continue;
334                }
335                return Err(UrlFetchError::new(format!(
336                    "Failed to fetch {}: {}",
337                    target.as_str(),
338                    reqwest_error_detail(&error)
339                )));
340            }
341        }
342    }
343    // Loop fell through after the last allowed retry exhausted — surface the
344    // most recent transient error rather than swallowing it.
345    Err(UrlFetchError::new(format!(
346        "Failed to fetch {} after {} retries: {}",
347        target.as_str(),
348        TRANSIENT_RETRY_ATTEMPTS,
349        last_error
350            .as_ref()
351            .map(reqwest_error_detail)
352            .unwrap_or_else(|| "unknown transient error".to_string())
353    )))
354}
355
356/// Classify a reqwest error as transient (worth a quick retry) vs terminal.
357///
358/// Transient: TCP connect failures, request-build/send TCP-level failures
359/// that don't carry status, and timeouts. These typically clear on a single
360/// retry — agents shouldn't have to ask twice for a momentary blip.
361///
362/// Terminal: anything where reqwest got far enough to decode an HTTP-shaped
363/// reply (`is_status()`, `is_body()`, `is_decode()`). Retrying those would
364/// just hammer a server that already answered.
365fn is_transient_reqwest_error(error: &reqwest::Error) -> bool {
366    error.is_connect() || error.is_timeout() || error.is_request()
367}
368
369fn build_client(options: &UrlFetchOptions) -> Result<Client, UrlFetchError> {
370    let mut builder = Client::builder()
371        .redirect(Policy::none())
372        .connect_timeout(CONNECT_TIMEOUT);
373
374    for (host, address) in &options.connect_overrides {
375        builder = builder.resolve(host, *address);
376    }
377
378    builder
379        .build()
380        .map_err(|error| UrlFetchError::new(format!("Failed to build URL fetch client: {error}")))
381}
382
383fn validate_public_url(url: &Url, options: &UrlFetchOptions) -> Result<(), UrlFetchError> {
384    if url.scheme() != "http" && url.scheme() != "https" {
385        return Err(UrlFetchError::new(format!(
386            "Only http:// and https:// URLs are supported, got: {}:",
387            url.scheme()
388        )));
389    }
390    if options.allow_private {
391        return Ok(());
392    }
393
394    let host = url
395        .host_str()
396        .ok_or_else(|| UrlFetchError::new(format!("URL missing host: {url}")))?;
397    let host_for_parse = host
398        .trim_matches(['[', ']'])
399        .split('%')
400        .next()
401        .unwrap_or(host);
402
403    if let Ok(ip) = host_for_parse.parse::<IpAddr>() {
404        reject_private_ip(host, ip)?;
405        return Ok(());
406    }
407    if host_for_parse.contains(':') {
408        return Err(UrlFetchError::new(format!(
409            "Blocked private URL host {host} ({host_for_parse})"
410        )));
411    }
412
413    let addresses = resolve_host_ips(host_for_parse, url.port_or_known_default(), options)?;
414    if addresses.is_empty() {
415        return Err(UrlFetchError::new(format!(
416            "Failed to resolve URL host {host}"
417        )));
418    }
419    for ip in addresses {
420        reject_private_ip(host, ip)?;
421    }
422
423    // We validate all resolved addresses before issuing the request. Reqwest's
424    // default resolver runs again during TCP connect, leaving the same small
425    // DNS-rebinding window the old Bun fallback accepted. A custom per-request
426    // resolver hook would close that window but adds complexity for marginal
427    // value in this opt-in agent-tooling surface.
428    Ok(())
429}
430
431fn resolve_host_ips(
432    host: &str,
433    port: Option<u16>,
434    options: &UrlFetchOptions,
435) -> Result<Vec<IpAddr>, UrlFetchError> {
436    if let Some((_, ips)) = options
437        .public_host_overrides
438        .iter()
439        .find(|(override_host, _)| override_host == host)
440    {
441        return Ok(ips.clone());
442    }
443
444    let port = port.unwrap_or(80);
445    let addrs = (host, port).to_socket_addrs().map_err(|error| {
446        UrlFetchError::new(format!("Failed to resolve URL host {host}: {error}"))
447    })?;
448    Ok(addrs.map(|addr| addr.ip()).collect())
449}
450
451fn reject_private_ip(host: &str, ip: IpAddr) -> Result<(), UrlFetchError> {
452    if is_private_ip(ip) {
453        return Err(UrlFetchError::new(format!(
454            "Blocked private URL host {host} ({ip})"
455        )));
456    }
457    Ok(())
458}
459
460fn is_private_ip(ip: IpAddr) -> bool {
461    match ip {
462        IpAddr::V4(ipv4) => is_private_ipv4(ipv4),
463        IpAddr::V6(ipv6) => is_private_ipv6(ipv6),
464    }
465}
466
467fn is_private_ipv4(ip: Ipv4Addr) -> bool {
468    let [a, b, _, _] = ip.octets();
469    a == 0
470        || a == 10
471        || a == 127
472        || (a == 172 && (16..=31).contains(&b))
473        || (a == 192 && b == 168)
474        || (a == 169 && b == 254)
475        || a >= 224
476}
477
478fn is_private_ipv6(ip: Ipv6Addr) -> bool {
479    let segments = ip.segments();
480    let top_six_zero = segments[..6].iter().all(|segment| *segment == 0);
481    let is_mapped = segments[..5].iter().all(|segment| *segment == 0) && segments[5] == 0xffff;
482    if is_mapped || top_six_zero {
483        let embedded = Ipv4Addr::new(
484            (segments[6] >> 8) as u8,
485            (segments[6] & 0xff) as u8,
486            (segments[7] >> 8) as u8,
487            (segments[7] & 0xff) as u8,
488        );
489        return is_private_ipv4(embedded);
490    }
491
492    let first = segments[0];
493    (0xfe80..=0xfebf).contains(&first) || (0xfc00..=0xfdff).contains(&first) || first >= 0xff00
494}
495
496fn resolve_extension(content_type: &str) -> Option<&'static str> {
497    let lower = content_type.to_ascii_lowercase();
498    let media_type = lower
499        .split(';')
500        .next()
501        .unwrap_or("")
502        .split(',')
503        .next()
504        .unwrap_or("")
505        .trim();
506
507    match media_type {
508        "text/html"
509        | "application/xhtml+xml"
510        | "application/vnd.github.html"
511        | "application/vnd.github+html" => Some(".html"),
512        "text/markdown"
513        | "text/x-markdown"
514        | "application/markdown"
515        | "application/vnd.github.raw"
516        | "application/vnd.github+raw"
517        | "application/vnd.github.v3.raw"
518        | "text/plain" => Some(".md"),
519        "application/json" | "application/ld+json" => Some(".json"),
520        other if other.ends_with("+json") => Some(".json"),
521        _ => None,
522    }
523}
524
525enum BodyReadEvent {
526    Chunk(Vec<u8>),
527    Done,
528    Error(io::ErrorKind, String),
529}
530
531fn read_response_body(mut response: HttpResponse, url: &str) -> Result<Vec<u8>, UrlFetchError> {
532    let (tx, rx) = mpsc::channel();
533    thread::spawn(move || {
534        let mut buffer = [0u8; 16 * 1024];
535        loop {
536            match response.read(&mut buffer) {
537                Ok(0) => {
538                    let _ = tx.send(BodyReadEvent::Done);
539                    break;
540                }
541                Ok(n) => {
542                    if tx.send(BodyReadEvent::Chunk(buffer[..n].to_vec())).is_err() {
543                        break;
544                    }
545                }
546                Err(error) => {
547                    let kind = error.kind();
548                    let message = error.to_string();
549                    let _ = tx.send(BodyReadEvent::Error(kind, message));
550                    break;
551                }
552            }
553        }
554    });
555
556    let mut chunks = Vec::new();
557    let mut total = 0u64;
558    loop {
559        match rx.recv_timeout(BODY_CHUNK_TIMEOUT) {
560            Ok(BodyReadEvent::Chunk(chunk)) => {
561                total += chunk.len() as u64;
562                if total > MAX_RESPONSE_BYTES {
563                    return Err(UrlFetchError::new(format!(
564                        "Response exceeded {MAX_RESPONSE_BYTES} bytes, aborted"
565                    )));
566                }
567                chunks.extend_from_slice(&chunk);
568            }
569            Ok(BodyReadEvent::Done) => return Ok(chunks),
570            Ok(BodyReadEvent::Error(kind, _message)) if is_body_stall_kind(kind) => {
571                return Err(body_stall_error(url));
572            }
573            Ok(BodyReadEvent::Error(_, message)) => {
574                return Err(UrlFetchError::new(format!(
575                    "Failed to read response body for {url}: {message}"
576                )));
577            }
578            Err(mpsc::RecvTimeoutError::Timeout) => return Err(body_stall_error(url)),
579            Err(mpsc::RecvTimeoutError::Disconnected) => {
580                return Err(UrlFetchError::new(format!(
581                    "Failed to read response body for {url}: body reader stopped unexpectedly"
582                )));
583            }
584        }
585    }
586}
587
588fn body_stall_error(url: &str) -> UrlFetchError {
589    UrlFetchError::new(format!(
590        "Body read stalled (no data for {}ms) fetching {url}",
591        BODY_CHUNK_TIMEOUT.as_millis()
592    ))
593}
594
595fn is_body_stall_kind(kind: io::ErrorKind) -> bool {
596    matches!(kind, io::ErrorKind::TimedOut | io::ErrorKind::WouldBlock)
597}
598
599fn atomic_write(
600    final_path: &Path,
601    bytes: &[u8],
602    options: &UrlFetchOptions,
603) -> Result<(), UrlFetchError> {
604    let parent = final_path.parent().unwrap_or_else(|| Path::new("."));
605    fs::create_dir_all(parent).map_err(|error| {
606        UrlFetchError::new(format!(
607            "Failed to create URL cache parent {}: {error}",
608            parent.display()
609        ))
610    })?;
611
612    let file_name = final_path
613        .file_name()
614        .and_then(|name| name.to_str())
615        .ok_or_else(|| {
616            UrlFetchError::new(format!("Invalid cache path: {}", final_path.display()))
617        })?;
618    let tmp_path = final_path.with_file_name(format!(
619        "{file_name}.tmp-{}-{}",
620        std::process::id(),
621        random_nonce()
622    ));
623
624    let write_result = (|| -> io::Result<()> {
625        let mut file = fs::File::create(&tmp_path)?;
626        file.write_all(bytes)?;
627        file.flush()?;
628        Ok(())
629    })();
630    if let Err(error) = write_result {
631        let _ = fs::remove_file(&tmp_path);
632        return Err(UrlFetchError::new(format!(
633            "Failed to write URL cache temp file {}: {error}",
634            tmp_path.display()
635        )));
636    }
637
638    if let Some(observer) = &options.atomic_write_observer {
639        observer(&tmp_path, final_path);
640    }
641
642    fs::rename(&tmp_path, final_path).map_err(|error| {
643        let _ = fs::remove_file(&tmp_path);
644        UrlFetchError::new(format!(
645            "Failed to finalize URL cache file {}: {error}",
646            final_path.display()
647        ))
648    })
649}
650
651fn random_nonce() -> String {
652    let mut bytes = [0u8; 8];
653    if getrandom::fill(&mut bytes).is_err() {
654        let fallback = now_ms() ^ u64::from(std::process::id());
655        bytes = fallback.to_le_bytes();
656    }
657    let mut out = String::with_capacity(bytes.len() * 2);
658    for byte in bytes {
659        use std::fmt::Write as _;
660        let _ = write!(out, "{byte:02x}");
661    }
662    out
663}
664
665fn now_ms() -> u64 {
666    SystemTime::now()
667        .duration_since(UNIX_EPOCH)
668        .unwrap_or_default()
669        .as_millis()
670        .try_into()
671        .unwrap_or(u64::MAX)
672}
673
674fn reqwest_error_detail(error: &reqwest::Error) -> String {
675    if error.is_timeout() {
676        return format!("timeout: {error}");
677    }
678    if let Some(source) = error.source() {
679        return format!("{source}");
680    }
681    error.to_string()
682}