Skip to main content

aft/
url_fetch.rs

1use std::error::Error;
2use std::fmt;
3use std::fs;
4use std::io::{self, Read, Write};
5use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
6use std::path::{Path, PathBuf};
7use std::sync::{mpsc, Arc};
8use std::thread;
9use std::time::{Duration, SystemTime, UNIX_EPOCH};
10
11use reqwest::blocking::{Client, Response as HttpResponse};
12use reqwest::header::{ACCEPT, CONTENT_TYPE, LOCATION, USER_AGENT};
13use reqwest::redirect::Policy;
14use serde::{Deserialize, Serialize};
15use sha2::{Digest, Sha256};
16use url::Url;
17
/// Upper bound on a response body in bytes (10 MiB). Checked against the
/// declared content length and again while the body is being read.
const MAX_RESPONSE_BYTES: u64 = 10 * 1024 * 1024;
/// Age in milliseconds (24 hours) up to which a cache entry is kept.
const CACHE_TTL_MS: u64 = 24 * 60 * 60 * 1000;
/// Connect timeout handed to the HTTP client builder.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Longest wait for the next body chunk before the read is reported as stalled.
const BODY_CHUNK_TIMEOUT: Duration = Duration::from_millis(15_000);
/// Maximum number of redirects followed for a single fetch.
const MAX_REDIRECTS: usize = 5;
23
/// Retry budget for transient connect/transport failures only. Agents
/// shouldn't have to retry manually for a single TCP/TLS hiccup. We cap
/// at 2 retries (= 3 total attempts) with short fixed backoffs (see
/// `TRANSIENT_RETRY_BACKOFFS_MS`; no jitter is applied) so a
/// genuinely-broken host fails fast instead of dragging the foreground
/// fetch out to many seconds.
///
/// We deliberately do NOT retry on:
///   - HTTP error status (4xx/5xx) — the server actually answered
///   - Redirect errors / SSRF rejections — those are deterministic
///   - Body read stalls — already handled by BODY_CHUNK_TIMEOUT
const TRANSIENT_RETRY_ATTEMPTS: usize = 2;
/// Sleep in milliseconds before each retry, indexed by the number of the
/// attempt that just failed (0-based).
const TRANSIENT_RETRY_BACKOFFS_MS: [u64; TRANSIENT_RETRY_ATTEMPTS] = [200, 600];
/// `Accept` header value sent with every request.
const ACCEPT_HEADER: &str = "application/vnd.github.raw, text/markdown, text/x-markdown, text/html;q=0.9, application/json;q=0.8, text/plain;q=0.5";
/// `User-Agent` header value sent with every request.
const USER_AGENT_VALUE: &str = "aft-opencode-plugin";
38
/// Caller-supplied knobs for a URL fetch. `Default` yields the production
/// configuration: private hosts blocked and every test hook disabled.
#[derive(Clone, Default)]
pub struct UrlFetchOptions {
    /// When `true`, the private-address (SSRF) host check is skipped; the
    /// http/https scheme check still applies.
    pub allow_private: bool,
    /// Test hook: treat a hostname as resolving to these IPs during SSRF validation.
    /// Production callers leave this empty and use `std::net::ToSocketAddrs`.
    #[doc(hidden)]
    pub public_host_overrides: Vec<(String, Vec<IpAddr>)>,
    /// Test hook: force reqwest to connect a hostname to a local mock server while
    /// SSRF validation still sees `public_host_overrides` above.
    #[doc(hidden)]
    pub connect_overrides: Vec<(String, SocketAddr)>,
    /// Test hook: observes the temp path immediately before the atomic rename.
    /// Called with `(temp_path, final_path)`.
    #[doc(hidden)]
    pub atomic_write_observer: Option<Arc<dyn Fn(&Path, &Path) + Send + Sync>>,
}
54
/// Error type returned by the fallible functions in this module.
///
/// Carries one human-readable message; `Display` prints it verbatim.
#[derive(Debug, Clone)]
pub struct UrlFetchError {
    message: String,
}

impl UrlFetchError {
    /// Builds an error from anything convertible into a `String`.
    fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl fmt::Display for UrlFetchError {
    /// Writes the stored message without any decoration.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.message)
    }
}

impl std::error::Error for UrlFetchError {}
75
76#[derive(Debug, Serialize, Deserialize)]
77struct CacheMeta {
78    url: String,
79    #[serde(rename = "contentType")]
80    content_type: String,
81    extension: String,
82    #[serde(rename = "fetchedAt")]
83    fetched_at: u64,
84}
85
/// Returns `true` when `value` starts with an `http://` or `https://` scheme.
///
/// URL schemes are case-insensitive, so `HTTPS://example.com` is accepted too.
/// The comparison is done on bytes so a multi-byte character near the start of
/// `value` can never cause a slicing panic.
pub fn is_http_url(value: &str) -> bool {
    ["http://", "https://"].iter().any(|prefix| {
        value
            .as_bytes()
            .get(..prefix.len())
            .map_or(false, |head| head.eq_ignore_ascii_case(prefix.as_bytes()))
    })
}
89
90pub fn fetch_url_to_cache(
91    url: &str,
92    storage_dir: &Path,
93    options: UrlFetchOptions,
94) -> Result<PathBuf, UrlFetchError> {
95    let parsed = Url::parse(url).map_err(|_| UrlFetchError::new(format!("Invalid URL: {url}")))?;
96    validate_public_url(&parsed, &options)?;
97
98    let dir = cache_dir(storage_dir);
99    fs::create_dir_all(&dir).map_err(|error| {
100        UrlFetchError::new(format!(
101            "Failed to create URL cache directory {}: {error}",
102            dir.display()
103        ))
104    })?;
105
106    let hash = hash_url(url);
107    let meta_file = meta_path(storage_dir, &hash);
108    if let Some(cached) = fresh_cached_path(storage_dir, &hash, &meta_file)? {
109        return Ok(cached);
110    }
111
112    let response = fetch_with_redirects(&parsed, url, &options)?;
113    if !response.status().is_success() {
114        return Err(UrlFetchError::new(format!(
115            "HTTP {} {} fetching {url}",
116            response.status().as_u16(),
117            response.status().canonical_reason().unwrap_or("")
118        )));
119    }
120
121    let content_type = response
122        .headers()
123        .get(CONTENT_TYPE)
124        .and_then(|value| value.to_str().ok())
125        .unwrap_or("text/plain")
126        .to_string();
127    let extension = resolve_extension(&content_type).ok_or_else(|| {
128        UrlFetchError::new(format!(
129            "Unsupported content type '{content_type}' for {url}. Supported: text/html, text/markdown, application/json, text/plain"
130        ))
131    })?;
132
133    if let Some(length) = response.content_length() {
134        if length > MAX_RESPONSE_BYTES {
135            return Err(UrlFetchError::new(format!(
136                "Response too large: {length} bytes (max {MAX_RESPONSE_BYTES})"
137            )));
138        }
139    }
140
141    let body = read_response_body(response, url)?;
142    let content_file = content_path(storage_dir, &hash, extension);
143    atomic_write(&content_file, &body, &options)?;
144
145    let meta = CacheMeta {
146        url: url.to_string(),
147        content_type,
148        extension: extension.to_string(),
149        fetched_at: now_ms(),
150    };
151    let meta_bytes = serde_json::to_vec(&meta).map_err(|error| {
152        UrlFetchError::new(format!("Failed to encode URL cache metadata: {error}"))
153    })?;
154    atomic_write(&meta_file, &meta_bytes, &options)?;
155
156    Ok(content_file)
157}
158
159pub fn cleanup_url_cache(storage_dir: &Path) -> Result<usize, UrlFetchError> {
160    let dir = cache_dir(storage_dir);
161    if !dir.exists() {
162        return Ok(0);
163    }
164
165    let entries = fs::read_dir(&dir).map_err(|error| {
166        UrlFetchError::new(format!(
167            "URL cache cleanup failed reading {}: {error}",
168            dir.display()
169        ))
170    })?;
171    let mut removed = 0usize;
172    let now = now_ms();
173
174    for entry in entries.flatten() {
175        let path = entry.path();
176        let Some(name) = path.file_name().and_then(|name| name.to_str()) else {
177            continue;
178        };
179        if !name.ends_with(".meta.json") {
180            continue;
181        }
182
183        let meta = fs::read_to_string(&path)
184            .ok()
185            .and_then(|content| serde_json::from_str::<CacheMeta>(&content).ok());
186        let Some(meta) = meta else {
187            if fs::remove_file(&path).is_ok() {
188                removed += 1;
189            }
190            continue;
191        };
192
193        if now.saturating_sub(meta.fetched_at) <= CACHE_TTL_MS {
194            continue;
195        }
196
197        let hash = name.trim_end_matches(".meta.json");
198        let content = content_path(storage_dir, hash, &meta.extension);
199        let _ = fs::remove_file(content);
200        if fs::remove_file(&path).is_ok() {
201            removed += 1;
202        }
203    }
204
205    Ok(removed)
206}
207
208#[doc(hidden)]
209pub fn cache_content_path_for_url(storage_dir: &Path, url: &str, extension: &str) -> PathBuf {
210    content_path(storage_dir, &hash_url(url), extension)
211}
212
213#[doc(hidden)]
214pub fn cache_meta_path_for_url(storage_dir: &Path, url: &str) -> PathBuf {
215    meta_path(storage_dir, &hash_url(url))
216}
217
218#[doc(hidden)]
219pub fn is_private_ip_for_test(ip: IpAddr) -> bool {
220    is_private_ip(ip)
221}
222
/// Returns the `url_cache` directory under `storage_dir`.
fn cache_dir(storage_dir: &Path) -> PathBuf {
    let mut dir = storage_dir.to_path_buf();
    dir.push("url_cache");
    dir
}
226
227fn hash_url(url: &str) -> String {
228    let digest = Sha256::digest(url.as_bytes());
229    format!("{digest:x}").chars().take(16).collect()
230}
231
232fn meta_path(storage_dir: &Path, hash: &str) -> PathBuf {
233    cache_dir(storage_dir).join(format!("{hash}.meta.json"))
234}
235
236fn content_path(storage_dir: &Path, hash: &str, extension: &str) -> PathBuf {
237    cache_dir(storage_dir).join(format!("{hash}{extension}"))
238}
239
240fn fresh_cached_path(
241    storage_dir: &Path,
242    hash: &str,
243    meta_file: &Path,
244) -> Result<Option<PathBuf>, UrlFetchError> {
245    if !meta_file.exists() {
246        return Ok(None);
247    }
248
249    let meta = match fs::read_to_string(meta_file)
250        .ok()
251        .and_then(|content| serde_json::from_str::<CacheMeta>(&content).ok())
252    {
253        Some(meta) => meta,
254        None => return Ok(None),
255    };
256    let age = now_ms().saturating_sub(meta.fetched_at);
257    let cached = content_path(storage_dir, hash, &meta.extension);
258    if age < CACHE_TTL_MS && cached.exists() {
259        return Ok(Some(cached));
260    }
261    Ok(None)
262}
263
264fn fetch_with_redirects(
265    start_url: &Url,
266    original_url: &str,
267    options: &UrlFetchOptions,
268) -> Result<HttpResponse, UrlFetchError> {
269    let client = build_client(options)?;
270    let mut current_url = start_url.clone();
271
272    for redirect_count in 0..=MAX_REDIRECTS {
273        validate_public_url(&current_url, options)?;
274        let response = send_with_transient_retries(&client, &current_url)?;
275
276        if !response.status().is_redirection() {
277            return Ok(response);
278        }
279        if redirect_count == MAX_REDIRECTS {
280            return Err(UrlFetchError::new(format!(
281                "Too many redirects fetching {original_url}"
282            )));
283        }
284
285        let location = response
286            .headers()
287            .get(LOCATION)
288            .and_then(|value| value.to_str().ok())
289            .ok_or_else(|| {
290                UrlFetchError::new(format!(
291                    "Redirect from {} missing Location header",
292                    current_url.as_str()
293                ))
294            })?;
295        current_url = current_url.join(location).map_err(|error| {
296            UrlFetchError::new(format!(
297                "Invalid redirect Location '{location}' from {}: {error}",
298                current_url.as_str()
299            ))
300        })?;
301    }
302
303    Err(UrlFetchError::new(format!(
304        "Too many redirects fetching {original_url}"
305    )))
306}
307
308/// Issue a single GET with the configured User-Agent + Accept headers and
309/// transparently retry only on transient connect/transport failures.
310///
311/// Returns the response (including 4xx/5xx — caller decides how to treat
312/// those). On a non-transient reqwest error (e.g. an HTTP-shaped reply that
313/// reqwest still surfaces as Err, or a TLS handshake fault that doesn't read
314/// as `is_connect`), the original error is returned immediately so the user
315/// sees the real failure without an artificial 800ms-plus delay.
316fn send_with_transient_retries(
317    client: &Client,
318    target: &Url,
319) -> Result<HttpResponse, UrlFetchError> {
320    let mut last_error: Option<reqwest::Error> = None;
321    for attempt in 0..=TRANSIENT_RETRY_ATTEMPTS {
322        let result = client
323            .get(target.clone())
324            .header(USER_AGENT, USER_AGENT_VALUE)
325            .header(ACCEPT, ACCEPT_HEADER)
326            .send();
327        match result {
328            Ok(response) => return Ok(response),
329            Err(error) => {
330                if attempt < TRANSIENT_RETRY_ATTEMPTS && is_transient_reqwest_error(&error) {
331                    thread::sleep(Duration::from_millis(TRANSIENT_RETRY_BACKOFFS_MS[attempt]));
332                    last_error = Some(error);
333                    continue;
334                }
335                return Err(UrlFetchError::new(format!(
336                    "Failed to fetch {}: {}",
337                    target.as_str(),
338                    reqwest_error_detail(&error)
339                )));
340            }
341        }
342    }
343    // Loop fell through after the last allowed retry exhausted — surface the
344    // most recent transient error rather than swallowing it.
345    Err(UrlFetchError::new(format!(
346        "Failed to fetch {} after {} retries: {}",
347        target.as_str(),
348        TRANSIENT_RETRY_ATTEMPTS,
349        last_error
350            .as_ref()
351            .map(reqwest_error_detail)
352            .unwrap_or_else(|| "unknown transient error".to_string())
353    )))
354}
355
356/// Classify a reqwest error as transient (worth a quick retry) vs terminal.
357///
358/// Transient: TCP connect failures, request-build/send TCP-level failures
359/// that don't carry status, and timeouts. These typically clear on a single
360/// retry — agents shouldn't have to ask twice for a momentary blip.
361///
362/// Terminal: anything where reqwest got far enough to decode an HTTP-shaped
363/// reply (`is_status()`, `is_body()`, `is_decode()`). Retrying those would
364/// just hammer a server that already answered.
365fn is_transient_reqwest_error(error: &reqwest::Error) -> bool {
366    error.is_connect() || error.is_timeout() || error.is_request()
367}
368
369fn build_client(options: &UrlFetchOptions) -> Result<Client, UrlFetchError> {
370    let mut builder = Client::builder()
371        .redirect(Policy::none())
372        .connect_timeout(CONNECT_TIMEOUT);
373
374    for (host, address) in &options.connect_overrides {
375        builder = builder.resolve(host, *address);
376    }
377
378    builder
379        .build()
380        .map_err(|error| UrlFetchError::new(format!("Failed to build URL fetch client: {error}")))
381}
382
383fn validate_public_url(url: &Url, options: &UrlFetchOptions) -> Result<(), UrlFetchError> {
384    if url.scheme() != "http" && url.scheme() != "https" {
385        return Err(UrlFetchError::new(format!(
386            "Only http:// and https:// URLs are supported, got: {}:",
387            url.scheme()
388        )));
389    }
390    if options.allow_private {
391        return Ok(());
392    }
393
394    let host = url
395        .host_str()
396        .ok_or_else(|| UrlFetchError::new(format!("URL missing host: {url}")))?;
397    let host_for_parse = host
398        .trim_matches(['[', ']'])
399        .split('%')
400        .next()
401        .unwrap_or(host);
402
403    if let Ok(ip) = host_for_parse.parse::<IpAddr>() {
404        reject_private_ip(host, ip)?;
405        return Ok(());
406    }
407    if host_for_parse.contains(':') {
408        return Err(UrlFetchError::new(format!(
409            "Blocked private URL host {host} ({host_for_parse})"
410        )));
411    }
412
413    let addresses = resolve_host_ips(host_for_parse, url.port_or_known_default(), options)?;
414    if addresses.is_empty() {
415        return Err(UrlFetchError::new(format!(
416            "Failed to resolve URL host {host}"
417        )));
418    }
419    for ip in addresses {
420        reject_private_ip(host, ip)?;
421    }
422
423    // We validate all resolved addresses before issuing the request. Reqwest's
424    // default resolver runs again during TCP connect, leaving the same small
425    // DNS-rebinding window the old Bun fallback accepted. A custom per-request
426    // resolver hook would close that window but adds complexity for marginal
427    // value in this opt-in agent-tooling surface.
428    Ok(())
429}
430
431fn resolve_host_ips(
432    host: &str,
433    port: Option<u16>,
434    options: &UrlFetchOptions,
435) -> Result<Vec<IpAddr>, UrlFetchError> {
436    if let Some((_, ips)) = options
437        .public_host_overrides
438        .iter()
439        .find(|(override_host, _)| override_host == host)
440    {
441        return Ok(ips.clone());
442    }
443
444    let port = port.unwrap_or(80);
445    let addrs = (host, port).to_socket_addrs().map_err(|error| {
446        UrlFetchError::new(format!("Failed to resolve URL host {host}: {error}"))
447    })?;
448    Ok(addrs.map(|addr| addr.ip()).collect())
449}
450
451fn reject_private_ip(host: &str, ip: IpAddr) -> Result<(), UrlFetchError> {
452    if is_private_ip(ip) {
453        return Err(UrlFetchError::new(format!(
454            "Blocked private URL host {host} ({ip})"
455        )));
456    }
457    Ok(())
458}
459
460fn is_private_ip(ip: IpAddr) -> bool {
461    match ip {
462        IpAddr::V4(ipv4) => is_private_ipv4(ipv4),
463        IpAddr::V6(ipv6) => is_private_ipv6(ipv6),
464    }
465}
466
/// Returns `true` for IPv4 addresses that must not be reachable through a
/// fetch: anything outside globally routable unicast space that this list
/// knows about.
fn is_private_ipv4(ip: Ipv4Addr) -> bool {
    match ip.octets() {
        // 0.0.0.0/8, 10.0.0.0/8 (private), 127.0.0.0/8 (loopback).
        [0 | 10 | 127, ..] => true,
        // 172.16.0.0/12 (private).
        [172, 16..=31, ..] => true,
        // 192.168.0.0/16 (private).
        [192, 168, ..] => true,
        // 169.254.0.0/16 (link-local).
        [169, 254, ..] => true,
        // RFC 6598 Shared Address Space (CGNAT): 100.64.0.0/10. Not globally
        // routable; used for provider/VPC-internal endpoints — must not be
        // reachable via SSRF.
        [100, 64..=127, ..] => true,
        // RFC 2544 benchmark subnet: 198.18.0.0/15. Reserved, non-routable.
        [198, 18..=19, ..] => true,
        // 224.0.0.0/3: multicast, reserved, and broadcast.
        [224..=255, ..] => true,
        _ => false,
    }
}
483
484fn is_private_ipv6(ip: Ipv6Addr) -> bool {
485    let segments = ip.segments();
486    let top_six_zero = segments[..6].iter().all(|segment| *segment == 0);
487    let is_mapped = segments[..5].iter().all(|segment| *segment == 0) && segments[5] == 0xffff;
488    if is_mapped || top_six_zero {
489        let embedded = Ipv4Addr::new(
490            (segments[6] >> 8) as u8,
491            (segments[6] & 0xff) as u8,
492            (segments[7] >> 8) as u8,
493            (segments[7] & 0xff) as u8,
494        );
495        return is_private_ipv4(embedded);
496    }
497
498    let first = segments[0];
499    (0xfe80..=0xfebf).contains(&first) || (0xfc00..=0xfdff).contains(&first) || first >= 0xff00
500}
501
/// Maps a `Content-Type` header value to the cache file extension, or `None`
/// when the media type is unsupported.
///
/// Matching is case-insensitive and looks only at the media type: parameters
/// after `;` and anything after the first `,` are ignored. Plain text is
/// stored as `.md`; any `+json` structured suffix maps to `.json`.
fn resolve_extension(content_type: &str) -> Option<&'static str> {
    const HTML_TYPES: [&str; 4] = [
        "text/html",
        "application/xhtml+xml",
        "application/vnd.github.html",
        "application/vnd.github+html",
    ];
    const MARKDOWN_TYPES: [&str; 7] = [
        "text/markdown",
        "text/x-markdown",
        "application/markdown",
        "application/vnd.github.raw",
        "application/vnd.github+raw",
        "application/vnd.github.v3.raw",
        "text/plain",
    ];

    let lower = content_type.to_ascii_lowercase();
    let without_params = lower.split(';').next().unwrap_or("");
    let media_type = without_params.split(',').next().unwrap_or("").trim();

    if HTML_TYPES.iter().any(|known| *known == media_type) {
        Some(".html")
    } else if MARKDOWN_TYPES.iter().any(|known| *known == media_type) {
        Some(".md")
    } else if media_type == "application/json" || media_type.ends_with("+json") {
        // `application/ld+json` is covered by the `+json` suffix rule.
        Some(".json")
    } else {
        None
    }
}
530
/// Message sent from the body-reader thread to `read_response_body` over an
/// mpsc channel.
enum BodyReadEvent {
    /// A non-empty slice of body bytes returned by one `read` call.
    Chunk(Vec<u8>),
    /// The reader hit end-of-stream (`read` returned 0).
    Done,
    /// `read` failed; carries the error kind and its display text.
    Error(io::ErrorKind, String),
}
536
537fn read_response_body(mut response: HttpResponse, url: &str) -> Result<Vec<u8>, UrlFetchError> {
538    let (tx, rx) = mpsc::channel();
539    thread::spawn(move || {
540        let mut buffer = [0u8; 16 * 1024];
541        loop {
542            match response.read(&mut buffer) {
543                Ok(0) => {
544                    let _ = tx.send(BodyReadEvent::Done);
545                    break;
546                }
547                Ok(n) => {
548                    if tx.send(BodyReadEvent::Chunk(buffer[..n].to_vec())).is_err() {
549                        break;
550                    }
551                }
552                Err(error) => {
553                    let kind = error.kind();
554                    let message = error.to_string();
555                    let _ = tx.send(BodyReadEvent::Error(kind, message));
556                    break;
557                }
558            }
559        }
560    });
561
562    let mut chunks = Vec::new();
563    let mut total = 0u64;
564    loop {
565        match rx.recv_timeout(BODY_CHUNK_TIMEOUT) {
566            Ok(BodyReadEvent::Chunk(chunk)) => {
567                total += chunk.len() as u64;
568                if total > MAX_RESPONSE_BYTES {
569                    return Err(UrlFetchError::new(format!(
570                        "Response exceeded {MAX_RESPONSE_BYTES} bytes, aborted"
571                    )));
572                }
573                chunks.extend_from_slice(&chunk);
574            }
575            Ok(BodyReadEvent::Done) => return Ok(chunks),
576            Ok(BodyReadEvent::Error(kind, _message)) if is_body_stall_kind(kind) => {
577                return Err(body_stall_error(url));
578            }
579            Ok(BodyReadEvent::Error(_, message)) => {
580                return Err(UrlFetchError::new(format!(
581                    "Failed to read response body for {url}: {message}"
582                )));
583            }
584            Err(mpsc::RecvTimeoutError::Timeout) => return Err(body_stall_error(url)),
585            Err(mpsc::RecvTimeoutError::Disconnected) => {
586                return Err(UrlFetchError::new(format!(
587                    "Failed to read response body for {url}: body reader stopped unexpectedly"
588                )));
589            }
590        }
591    }
592}
593
594fn body_stall_error(url: &str) -> UrlFetchError {
595    UrlFetchError::new(format!(
596        "Body read stalled (no data for {}ms) fetching {url}",
597        BODY_CHUNK_TIMEOUT.as_millis()
598    ))
599}
600
/// Returns `true` for the I/O error kinds treated as a body-read stall:
/// `TimedOut` and `WouldBlock`.
fn is_body_stall_kind(kind: io::ErrorKind) -> bool {
    kind == io::ErrorKind::TimedOut || kind == io::ErrorKind::WouldBlock
}
604
605fn atomic_write(
606    final_path: &Path,
607    bytes: &[u8],
608    options: &UrlFetchOptions,
609) -> Result<(), UrlFetchError> {
610    let parent = final_path.parent().unwrap_or_else(|| Path::new("."));
611    fs::create_dir_all(parent).map_err(|error| {
612        UrlFetchError::new(format!(
613            "Failed to create URL cache parent {}: {error}",
614            parent.display()
615        ))
616    })?;
617
618    let file_name = final_path
619        .file_name()
620        .and_then(|name| name.to_str())
621        .ok_or_else(|| {
622            UrlFetchError::new(format!("Invalid cache path: {}", final_path.display()))
623        })?;
624    let tmp_path = final_path.with_file_name(format!(
625        "{file_name}.tmp-{}-{}",
626        std::process::id(),
627        random_nonce()
628    ));
629
630    let write_result = (|| -> io::Result<()> {
631        let mut file = fs::File::create(&tmp_path)?;
632        file.write_all(bytes)?;
633        file.flush()?;
634        Ok(())
635    })();
636    if let Err(error) = write_result {
637        let _ = fs::remove_file(&tmp_path);
638        return Err(UrlFetchError::new(format!(
639            "Failed to write URL cache temp file {}: {error}",
640            tmp_path.display()
641        )));
642    }
643
644    if let Some(observer) = &options.atomic_write_observer {
645        observer(&tmp_path, final_path);
646    }
647
648    fs::rename(&tmp_path, final_path).map_err(|error| {
649        let _ = fs::remove_file(&tmp_path);
650        UrlFetchError::new(format!(
651            "Failed to finalize URL cache file {}: {error}",
652            final_path.display()
653        ))
654    })
655}
656
657fn random_nonce() -> String {
658    let mut bytes = [0u8; 8];
659    if getrandom::fill(&mut bytes).is_err() {
660        let fallback = now_ms() ^ u64::from(std::process::id());
661        bytes = fallback.to_le_bytes();
662    }
663    let mut out = String::with_capacity(bytes.len() * 2);
664    for byte in bytes {
665        use std::fmt::Write as _;
666        let _ = write!(out, "{byte:02x}");
667    }
668    out
669}
670
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// A clock set before the epoch yields 0; a value that does not fit in `u64`
/// saturates to `u64::MAX`.
fn now_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
679
680fn reqwest_error_detail(error: &reqwest::Error) -> String {
681    if error.is_timeout() {
682        return format!("timeout: {error}");
683    }
684    if let Some(source) = error.source() {
685        return format!("{source}");
686    }
687    error.to_string()
688}