Skip to main content

aft/
url_fetch.rs

1use std::error::Error;
2use std::fmt;
3use std::fs;
4use std::io::{self, Read, Write};
5use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
6use std::path::{Path, PathBuf};
7use std::sync::{mpsc, Arc};
8use std::thread;
9use std::time::{Duration, SystemTime, UNIX_EPOCH};
10
11use htmd::{
12    element_handler::{HandlerResult, Handlers},
13    Element, HtmlToMarkdown,
14};
15use reqwest::blocking::{Client, Response as HttpResponse};
16use reqwest::header::{ACCEPT, CONTENT_TYPE, LOCATION, USER_AGENT};
17use reqwest::redirect::Policy;
18use serde::{Deserialize, Serialize};
19use sha2::{Digest, Sha256};
20use url::Url;
21
22use crate::parser::detect_language;
23
/// Upper bound on accepted response bodies, in bytes (10 MiB).
const MAX_RESPONSE_BYTES: u64 = 10 << 20;
/// Age in milliseconds (24 hours) beyond which a cache entry is stale.
const CACHE_TTL_MS: u64 = 86_400_000;
/// Connect timeout handed to the HTTP client.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Longest wait for the next body chunk before the read counts as stalled.
const BODY_CHUNK_TIMEOUT: Duration = Duration::from_secs(15);
/// Maximum number of redirects followed for a single fetch.
const MAX_REDIRECTS: usize = 5;

/// Number of retries granted to a request whose send failed transiently
/// (so at most `TRANSIENT_RETRY_ATTEMPTS + 1` sends per hop). The budget is
/// kept small so a genuinely broken host fails fast.
///
/// Not retried:
///   - HTTP error statuses (4xx/5xx): the server did answer
///   - redirect errors and SSRF rejections: deterministic
///   - body read stalls: covered by `BODY_CHUNK_TIMEOUT`
const TRANSIENT_RETRY_ATTEMPTS: usize = 2;
/// Fixed sleep in milliseconds before each retry, indexed by the zero-based
/// attempt that just failed. No jitter is applied.
const TRANSIENT_RETRY_BACKOFFS_MS: [u64; TRANSIENT_RETRY_ATTEMPTS] = [200, 600];
/// `Accept` header sent with every request.
const ACCEPT_HEADER: &str = "application/vnd.github.raw, text/markdown, text/x-markdown, text/html;q=0.9, application/json;q=0.8, text/plain;q=0.5";
/// `User-Agent` header sent with every request.
const USER_AGENT_VALUE: &str = "aft-opencode-plugin";
/// Content type recorded in cache metadata for bodies converted from HTML.
const CONVERTED_MARKDOWN_CONTENT_TYPE: &str = "text/markdown; charset=utf-8";
45
/// Per-call knobs for [`fetch_url_to_cache`]. `Default` yields the production
/// configuration: private hosts rejected and every test hook disabled.
#[derive(Clone, Default)]
pub struct UrlFetchOptions {
    /// When `true`, the private/reserved address check is skipped entirely.
    pub allow_private: bool,
    /// Test hook: hostnames listed here are treated as resolving to the paired
    /// IPs during SSRF validation instead of going through
    /// `std::net::ToSocketAddrs`. Leave empty in production.
    #[doc(hidden)]
    pub public_host_overrides: Vec<(String, Vec<IpAddr>)>,
    /// Test hook: makes the HTTP client connect the given hostname to the given
    /// socket address (e.g. a local mock server). SSRF validation is unaffected
    /// and still consults `public_host_overrides`.
    #[doc(hidden)]
    pub connect_overrides: Vec<(String, SocketAddr)>,
    /// Test hook: called with `(temp_path, final_path)` after the temp file is
    /// written and before it is renamed into place.
    #[doc(hidden)]
    pub atomic_write_observer: Option<Arc<dyn Fn(&Path, &Path) + Send + Sync>>,
}
61
/// Error returned by every fallible function in this module. It carries only
/// a human-readable message.
#[derive(Debug, Clone)]
pub struct UrlFetchError {
    message: String,
}

impl UrlFetchError {
    /// Builds an error from anything convertible into a `String`.
    fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl fmt::Display for UrlFetchError {
    /// Writes the stored message verbatim; width and padding flags are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.message.as_str())
    }
}

impl std::error::Error for UrlFetchError {}
82
83#[derive(Debug, Serialize, Deserialize)]
84struct CacheMeta {
85    url: String,
86    #[serde(rename = "contentType")]
87    content_type: String,
88    extension: String,
89    #[serde(rename = "fetchedAt")]
90    fetched_at: u64,
91}
92
/// Returns `true` when `value` begins with an `http://` or `https://` scheme.
///
/// The scheme is compared ASCII case-insensitively: URL schemes are
/// case-insensitive, so `HTTPS://example.com` is recognised as well. Only the
/// prefix is inspected; the rest of the string is not validated.
pub fn is_http_url(value: &str) -> bool {
    ["http://", "https://"].iter().any(|prefix| {
        // `get` yields `None` when `value` is too short or the cut would land
        // inside a multi-byte character, so this never panics.
        value
            .get(..prefix.len())
            .is_some_and(|head| head.eq_ignore_ascii_case(prefix))
    })
}
96
97pub fn fetch_url_to_cache(
98    url: &str,
99    storage_dir: &Path,
100    options: UrlFetchOptions,
101) -> Result<PathBuf, UrlFetchError> {
102    let parsed = Url::parse(url).map_err(|_| UrlFetchError::new(format!("Invalid URL: {url}")))?;
103    validate_public_url(&parsed, &options)?;
104
105    let dir = cache_dir(storage_dir);
106    fs::create_dir_all(&dir).map_err(|error| {
107        UrlFetchError::new(format!(
108            "Failed to create URL cache directory {}: {error}",
109            dir.display()
110        ))
111    })?;
112
113    let hash = hash_url(url);
114    let meta_file = meta_path(storage_dir, &hash);
115    if let Some(cached) = fresh_cached_path(storage_dir, &hash, &meta_file, &parsed)? {
116        return Ok(cached);
117    }
118
119    let response = fetch_with_redirects(&parsed, url, &options)?;
120    if !response.status().is_success() {
121        return Err(UrlFetchError::new(format!(
122            "HTTP {} {} fetching {url}",
123            response.status().as_u16(),
124            response.status().canonical_reason().unwrap_or("")
125        )));
126    }
127
128    let content_type = response
129        .headers()
130        .get(CONTENT_TYPE)
131        .and_then(|value| value.to_str().ok())
132        .unwrap_or("text/plain")
133        .to_string();
134    let (extension, from_source_path) =
135        resolve_fetch_extension(&parsed, &content_type).ok_or_else(|| {
136            UrlFetchError::new(format!(
137                "Unsupported content type '{content_type}' for {url}. Supported: text/html, text/markdown, application/json, text/plain; source files via URL path extension (e.g. .rs, .ts, .mjs)"
138            ))
139        })?;
140
141    if let Some(length) = response.content_length() {
142        if length > MAX_RESPONSE_BYTES {
143            return Err(UrlFetchError::new(format!(
144                "Response too large: {length} bytes (max {MAX_RESPONSE_BYTES})"
145            )));
146        }
147    }
148
149    let body = read_response_body(response, url)?;
150    if from_source_path && body_contains_nul_in_prefix(&body) {
151        return Err(UrlFetchError::new(format!(
152            "Binary content detected for source URL {url}"
153        )));
154    }
155    let (body, content_type, extension) = if extension == ".html" {
156        (
157            convert_html_body_to_markdown(&body, url)?,
158            CONVERTED_MARKDOWN_CONTENT_TYPE.to_string(),
159            ".md",
160        )
161    } else {
162        (body, content_type, extension)
163    };
164
165    let content_file = content_path(storage_dir, &hash, extension);
166    atomic_write(&content_file, &body, &options)?;
167
168    let meta = CacheMeta {
169        url: url.to_string(),
170        content_type,
171        extension: extension.to_string(),
172        fetched_at: now_ms(),
173    };
174    let meta_bytes = serde_json::to_vec(&meta).map_err(|error| {
175        UrlFetchError::new(format!("Failed to encode URL cache metadata: {error}"))
176    })?;
177    atomic_write(&meta_file, &meta_bytes, &options)?;
178
179    Ok(content_file)
180}
181
182pub fn cleanup_url_cache(storage_dir: &Path) -> Result<usize, UrlFetchError> {
183    let dir = cache_dir(storage_dir);
184    if !dir.exists() {
185        return Ok(0);
186    }
187
188    let entries = fs::read_dir(&dir).map_err(|error| {
189        UrlFetchError::new(format!(
190            "URL cache cleanup failed reading {}: {error}",
191            dir.display()
192        ))
193    })?;
194    let mut removed = 0usize;
195    let now = now_ms();
196
197    for entry in entries.flatten() {
198        let path = entry.path();
199        let Some(name) = path.file_name().and_then(|name| name.to_str()) else {
200            continue;
201        };
202        if !name.ends_with(".meta.json") {
203            continue;
204        }
205
206        let meta = fs::read_to_string(&path)
207            .ok()
208            .and_then(|content| serde_json::from_str::<CacheMeta>(&content).ok());
209        let Some(meta) = meta else {
210            if fs::remove_file(&path).is_ok() {
211                removed += 1;
212            }
213            continue;
214        };
215
216        if now.saturating_sub(meta.fetched_at) <= CACHE_TTL_MS {
217            continue;
218        }
219
220        let hash = name.trim_end_matches(".meta.json");
221        let content = content_path(storage_dir, hash, &meta.extension);
222        let _ = fs::remove_file(content);
223        if fs::remove_file(&path).is_ok() {
224            removed += 1;
225        }
226    }
227
228    Ok(removed)
229}
230
231#[doc(hidden)]
232pub fn cache_content_path_for_url(storage_dir: &Path, url: &str, extension: &str) -> PathBuf {
233    content_path(storage_dir, &hash_url(url), extension)
234}
235
236#[doc(hidden)]
237pub fn cache_meta_path_for_url(storage_dir: &Path, url: &str) -> PathBuf {
238    meta_path(storage_dir, &hash_url(url))
239}
240
241#[doc(hidden)]
242pub fn is_private_ip_for_test(ip: IpAddr) -> bool {
243    is_private_ip(ip)
244}
245
/// Directory holding all URL cache files: `<storage_dir>/url_cache`.
fn cache_dir(storage_dir: &Path) -> PathBuf {
    let mut dir = storage_dir.to_path_buf();
    dir.push("url_cache");
    dir
}
249
250fn hash_url(url: &str) -> String {
251    let digest = Sha256::digest(url.as_bytes());
252    format!("{digest:x}").chars().take(16).collect()
253}
254
255fn meta_path(storage_dir: &Path, hash: &str) -> PathBuf {
256    cache_dir(storage_dir).join(format!("{hash}.meta.json"))
257}
258
259fn content_path(storage_dir: &Path, hash: &str, extension: &str) -> PathBuf {
260    cache_dir(storage_dir).join(format!("{hash}{extension}"))
261}
262
263fn fresh_cached_path(
264    storage_dir: &Path,
265    hash: &str,
266    meta_file: &Path,
267    url: &Url,
268) -> Result<Option<PathBuf>, UrlFetchError> {
269    if !meta_file.exists() {
270        return Ok(None);
271    }
272
273    let meta = match fs::read_to_string(meta_file)
274        .ok()
275        .and_then(|content| serde_json::from_str::<CacheMeta>(&content).ok())
276    {
277        Some(meta) => meta,
278        None => return Ok(None),
279    };
280    let age = now_ms().saturating_sub(meta.fetched_at);
281    if meta.extension == ".html" {
282        return Ok(None);
283    }
284
285    let content_type = meta.content_type.as_str();
286    let current = resolve_fetch_extension(url, content_type);
287    let expected_ext = current.map(|(ext, _)| ext);
288    if expected_ext != Some(meta.extension.as_str()) {
289        return Ok(None);
290    }
291
292    let cached = content_path(storage_dir, hash, &meta.extension);
293    if age < CACHE_TTL_MS && cached.exists() {
294        return Ok(Some(cached));
295    }
296    Ok(None)
297}
298
299fn fetch_with_redirects(
300    start_url: &Url,
301    original_url: &str,
302    options: &UrlFetchOptions,
303) -> Result<HttpResponse, UrlFetchError> {
304    let client = build_client(options)?;
305    let mut current_url = start_url.clone();
306
307    for redirect_count in 0..=MAX_REDIRECTS {
308        validate_public_url(&current_url, options)?;
309        let response = send_with_transient_retries(&client, &current_url)?;
310
311        if !response.status().is_redirection() {
312            return Ok(response);
313        }
314        if redirect_count == MAX_REDIRECTS {
315            return Err(UrlFetchError::new(format!(
316                "Too many redirects fetching {original_url}"
317            )));
318        }
319
320        let location = response
321            .headers()
322            .get(LOCATION)
323            .and_then(|value| value.to_str().ok())
324            .ok_or_else(|| {
325                UrlFetchError::new(format!(
326                    "Redirect from {} missing Location header",
327                    current_url.as_str()
328                ))
329            })?;
330        current_url = current_url.join(location).map_err(|error| {
331            UrlFetchError::new(format!(
332                "Invalid redirect Location '{location}' from {}: {error}",
333                current_url.as_str()
334            ))
335        })?;
336    }
337
338    Err(UrlFetchError::new(format!(
339        "Too many redirects fetching {original_url}"
340    )))
341}
342
343/// Issue a single GET with the configured User-Agent + Accept headers and
344/// transparently retry only on transient connect/transport failures.
345///
346/// Returns the response (including 4xx/5xx — caller decides how to treat
347/// those). On a non-transient reqwest error (e.g. an HTTP-shaped reply that
348/// reqwest still surfaces as Err, or a TLS handshake fault that doesn't read
349/// as `is_connect`), the original error is returned immediately so the user
350/// sees the real failure without an artificial 800ms-plus delay.
351fn send_with_transient_retries(
352    client: &Client,
353    target: &Url,
354) -> Result<HttpResponse, UrlFetchError> {
355    let mut last_error: Option<reqwest::Error> = None;
356    for attempt in 0..=TRANSIENT_RETRY_ATTEMPTS {
357        let result = client
358            .get(target.clone())
359            .header(USER_AGENT, USER_AGENT_VALUE)
360            .header(ACCEPT, ACCEPT_HEADER)
361            .send();
362        match result {
363            Ok(response) => return Ok(response),
364            Err(error) => {
365                if attempt < TRANSIENT_RETRY_ATTEMPTS && is_transient_reqwest_error(&error) {
366                    thread::sleep(Duration::from_millis(TRANSIENT_RETRY_BACKOFFS_MS[attempt]));
367                    last_error = Some(error);
368                    continue;
369                }
370                return Err(UrlFetchError::new(format!(
371                    "Failed to fetch {}: {}",
372                    target.as_str(),
373                    reqwest_error_detail(&error)
374                )));
375            }
376        }
377    }
378    // Loop fell through after the last allowed retry exhausted — surface the
379    // most recent transient error rather than swallowing it.
380    Err(UrlFetchError::new(format!(
381        "Failed to fetch {} after {} retries: {}",
382        target.as_str(),
383        TRANSIENT_RETRY_ATTEMPTS,
384        last_error
385            .as_ref()
386            .map(reqwest_error_detail)
387            .unwrap_or_else(|| "unknown transient error".to_string())
388    )))
389}
390
391/// Classify a reqwest error as transient (worth a quick retry) vs terminal.
392///
393/// Transient: TCP connect failures, request-build/send TCP-level failures
394/// that don't carry status, and timeouts. These typically clear on a single
395/// retry — agents shouldn't have to ask twice for a momentary blip.
396///
397/// Terminal: anything where reqwest got far enough to decode an HTTP-shaped
398/// reply (`is_status()`, `is_body()`, `is_decode()`). Retrying those would
399/// just hammer a server that already answered.
400fn is_transient_reqwest_error(error: &reqwest::Error) -> bool {
401    error.is_connect() || error.is_timeout() || error.is_request()
402}
403
404fn build_client(options: &UrlFetchOptions) -> Result<Client, UrlFetchError> {
405    let mut builder = Client::builder()
406        .redirect(Policy::none())
407        .connect_timeout(CONNECT_TIMEOUT);
408
409    for (host, address) in &options.connect_overrides {
410        builder = builder.resolve(host, *address);
411    }
412
413    builder
414        .build()
415        .map_err(|error| UrlFetchError::new(format!("Failed to build URL fetch client: {error}")))
416}
417
418fn validate_public_url(url: &Url, options: &UrlFetchOptions) -> Result<(), UrlFetchError> {
419    if url.scheme() != "http" && url.scheme() != "https" {
420        return Err(UrlFetchError::new(format!(
421            "Only http:// and https:// URLs are supported, got: {}:",
422            url.scheme()
423        )));
424    }
425    if options.allow_private {
426        return Ok(());
427    }
428
429    let host = url
430        .host_str()
431        .ok_or_else(|| UrlFetchError::new(format!("URL missing host: {url}")))?;
432    let host_for_parse = host
433        .trim_matches(['[', ']'])
434        .split('%')
435        .next()
436        .unwrap_or(host);
437
438    if let Ok(ip) = host_for_parse.parse::<IpAddr>() {
439        reject_private_ip(host, ip)?;
440        return Ok(());
441    }
442    if host_for_parse.contains(':') {
443        return Err(UrlFetchError::new(format!(
444            "Blocked private URL host {host} ({host_for_parse})"
445        )));
446    }
447
448    let addresses = resolve_host_ips(host_for_parse, url.port_or_known_default(), options)?;
449    if addresses.is_empty() {
450        return Err(UrlFetchError::new(format!(
451            "Failed to resolve URL host {host}"
452        )));
453    }
454    for ip in addresses {
455        reject_private_ip(host, ip)?;
456    }
457
458    // We validate all resolved addresses before issuing the request. Reqwest's
459    // default resolver runs again during TCP connect, leaving the same small
460    // DNS-rebinding window the old Bun fallback accepted. A custom per-request
461    // resolver hook would close that window but adds complexity for marginal
462    // value in this opt-in agent-tooling surface.
463    Ok(())
464}
465
466fn resolve_host_ips(
467    host: &str,
468    port: Option<u16>,
469    options: &UrlFetchOptions,
470) -> Result<Vec<IpAddr>, UrlFetchError> {
471    if let Some((_, ips)) = options
472        .public_host_overrides
473        .iter()
474        .find(|(override_host, _)| override_host == host)
475    {
476        return Ok(ips.clone());
477    }
478
479    let port = port.unwrap_or(80);
480    let addrs = (host, port).to_socket_addrs().map_err(|error| {
481        UrlFetchError::new(format!("Failed to resolve URL host {host}: {error}"))
482    })?;
483    Ok(addrs.map(|addr| addr.ip()).collect())
484}
485
486fn reject_private_ip(host: &str, ip: IpAddr) -> Result<(), UrlFetchError> {
487    if is_private_ip(ip) {
488        return Err(UrlFetchError::new(format!(
489            "Blocked private URL host {host} ({ip})"
490        )));
491    }
492    Ok(())
493}
494
495/// True for any private/link-local/CGNAT/benchmark/multicast/reserved/loopback
496/// address (the full set this module refuses to fetch). Exposed so the semantic
497/// embedding SSRF guard shares one authoritative range list instead of keeping a
498/// drifting copy. Note: this INCLUDES loopback — callers that intentionally
499/// allow loopback (e.g. a local Ollama endpoint) must exclude it themselves.
500pub fn is_private_or_reserved_ip(ip: IpAddr) -> bool {
501    is_private_ip(ip)
502}
503
/// Classifies `ip` as private/reserved (`true`) or publicly routable (`false`)
/// by dispatching to the IPv4 or IPv6 rules.
fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(ipv4) => is_private_ipv4(ipv4),
        IpAddr::V6(ipv6) => is_private_ipv6(ipv6),
    }
}

/// IPv4 ranges that must never be fetched: `0.0.0.0/8`, `10.0.0.0/8`,
/// loopback `127.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`, link-local
/// `169.254.0.0/16`, CGNAT `100.64.0.0/10`, benchmark `198.18.0.0/15`, and
/// everything from `224.0.0.0` upward (multicast and reserved).
fn is_private_ipv4(ip: Ipv4Addr) -> bool {
    let [a, b, _, _] = ip.octets();
    a == 0
        || a == 10
        || a == 127
        || (a == 172 && (16..=31).contains(&b))
        || (a == 192 && b == 168)
        || (a == 169 && b == 254)
        // RFC 6598 Shared Address Space (CGNAT): 100.64.0.0/10. Not globally
        // routable; used for provider/VPC-internal endpoints — must not be
        // reachable via SSRF.
        || (a == 100 && (64..=127).contains(&b))
        // RFC 2544 benchmark subnet: 198.18.0.0/15. Reserved, non-routable.
        || (a == 198 && (18..=19).contains(&b))
        || a >= 224
}

/// IPv6 rules. Addresses that merely wrap an IPv4 address are judged by the
/// IPv4 rules applied to the embedded address; otherwise link-local
/// (`fe80::/10`), unique-local (`fc00::/7`) and multicast (`ff00::/8`) are
/// refused.
fn is_private_ipv6(ip: Ipv6Addr) -> bool {
    let segments = ip.segments();

    // Locate an embedded IPv4 address, if any. Each of these forms can be used
    // to reach an IPv4 destination, so a private IPv4 must not slip through
    // just because it is spelled as IPv6.
    let embedded = match segments {
        // IPv4-compatible `::/96` (also covers `::` and `::1`) and
        // IPv4-mapped `::ffff:0:0/96`.
        [0, 0, 0, 0, 0, 0 | 0xffff, hi, lo] => Some((hi, lo)),
        // NAT64 well-known prefix `64:ff9b::/96` (RFC 6052).
        [0x0064, 0xff9b, 0, 0, 0, 0, hi, lo] => Some((hi, lo)),
        // 6to4 `2002::/16` (RFC 3056): the IPv4 address follows the prefix.
        [0x2002, hi, lo, ..] => Some((hi, lo)),
        _ => None,
    };
    if let Some((hi, lo)) = embedded {
        let [a, b] = hi.to_be_bytes();
        let [c, d] = lo.to_be_bytes();
        return is_private_ipv4(Ipv4Addr::new(a, b, c, d));
    }

    let first = segments[0];
    (0xfe80..=0xfebf).contains(&first) || (0xfc00..=0xfdff).contains(&first) || first >= 0xff00
}
545
/// Number of leading body bytes inspected when sniffing for binary content.
const BINARY_SNIFF_PREFIX: usize = 8 * 1024;

/// Returns `true` when a NUL byte occurs within the first
/// `BINARY_SNIFF_PREFIX` bytes of `body`. An empty body yields `false`.
fn body_contains_nul_in_prefix(body: &[u8]) -> bool {
    body.iter()
        .take(BINARY_SNIFF_PREFIX)
        .any(|&byte| byte == 0)
}
552
553/// Resolve cache extension: URL path (parsed, percent-decoded) first for source
554/// languages AFT parses; otherwise content-type mapping.
555fn resolve_fetch_extension(url: &Url, content_type: &str) -> Option<(&'static str, bool)> {
556    if let Some(ext) = extension_from_url_path(url) {
557        return Some((ext, true));
558    }
559    resolve_extension_from_content_type(content_type).map(|ext| (ext, false))
560}
561
562fn extension_from_url_path(url: &Url) -> Option<&'static str> {
563    let path = url.path();
564    if path.is_empty() || path == "/" {
565        return None;
566    }
567    let segment = path.rsplit('/').next().unwrap_or(path);
568    let file_name = percent_decode_path_segment(segment);
569    let dot = file_name.rfind('.')?;
570    let ext = &file_name[dot + 1..];
571    if ext.is_empty() {
572        return None;
573    }
574    let probe = Path::new("file").with_extension(ext);
575    if detect_language(&probe).is_some() {
576        static_extension_for_lang_ext(ext)
577    } else {
578        None
579    }
580}
581
/// Percent-decodes one URL path segment.
///
/// Every `%XY` with two hex digits becomes the byte `0xXY`; anything else,
/// including a stray or truncated `%`, is copied through unchanged. The
/// decoded bytes are then interpreted as UTF-8, with invalid sequences
/// replaced by U+FFFD. Decoding at the byte level matters: pushing each
/// decoded byte as its own `char` would turn `%C3%A9` into `Ã©` instead of
/// `é`.
fn percent_decode_path_segment(segment: &str) -> String {
    let bytes = segment.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        // `get` returns `None` when fewer than three bytes remain, so a
        // trailing `%` or `%X` falls through to the literal copy below.
        let escape = match bytes.get(i..i + 3) {
            Some(&[b'%', hi, lo]) => from_hex(hi).zip(from_hex(lo)),
            _ => None,
        };
        match escape {
            Some((hi, lo)) => {
                decoded.push(hi << 4 | lo);
                i += 3;
            }
            None => {
                decoded.push(bytes[i]);
                i += 1;
            }
        }
    }
    String::from_utf8_lossy(&decoded).into_owned()
}

/// Value of a single ASCII hex digit (either case), or `None` if `byte` is not
/// a hex digit.
fn from_hex(byte: u8) -> Option<u8> {
    match byte {
        b'0'..=b'9' => Some(byte - b'0'),
        b'a'..=b'f' => Some(byte - b'a' + 10),
        b'A'..=b'F' => Some(byte - b'A' + 10),
        _ => None,
    }
}
608
/// Maps a bare path extension (no dot, any ASCII case) to the canonical dotted
/// extension used for cache files, or `None` when the extension is not in the
/// table. Several aliases can share one canonical extension.
fn static_extension_for_lang_ext(ext: &str) -> Option<&'static str> {
    // (accepted aliases, canonical dotted extension)
    const GROUPS: &[(&[&str], &str)] = &[
        (&["ts", "mts", "cts"], ".ts"),
        (&["tsx"], ".tsx"),
        (&["js"], ".js"),
        (&["jsx"], ".jsx"),
        (&["mjs"], ".mjs"),
        (&["cjs"], ".cjs"),
        (&["py", "pyi"], ".py"),
        (&["rs"], ".rs"),
        (&["go"], ".go"),
        (&["c", "h"], ".c"),
        (&["cc", "cpp", "cxx", "hpp", "hh"], ".cpp"),
        (&["zig"], ".zig"),
        (&["cs"], ".cs"),
        (&["sh", "bash", "zsh"], ".sh"),
        (&["html", "htm"], ".html"),
        (&["md", "markdown", "mdx"], ".md"),
        (&["sol"], ".sol"),
        (&["scss"], ".scss"),
        (&["vue"], ".vue"),
        (&["json", "jsonc"], ".json"),
        (&["scala", "sc"], ".scala"),
        (&["java"], ".java"),
        (&["rb"], ".rb"),
        (&["kt", "kts"], ".kt"),
        (&["swift"], ".swift"),
        (&["inc", "php"], ".php"),
        (&["lua"], ".lua"),
        (&["pl", "pm", "t"], ".pl"),
        (&["yaml", "yml"], ".yaml"),
    ];

    let wanted = ext.to_ascii_lowercase();
    GROUPS
        .iter()
        .find(|entry| entry.0.iter().any(|alias| *alias == wanted.as_str()))
        .map(|entry| entry.1)
}
644
/// Maps a `Content-Type` header value to a cache extension.
///
/// Only the media type is considered: the value is lowercased, cut at the
/// first `;` and then at the first `,`, and trimmed. Plain text and the
/// GitHub raw types are cached as Markdown. Any type ending in `+json` counts
/// as JSON. Unknown media types yield `None`.
fn resolve_extension_from_content_type(content_type: &str) -> Option<&'static str> {
    const HTML_TYPES: [&str; 4] = [
        "text/html",
        "application/xhtml+xml",
        "application/vnd.github.html",
        "application/vnd.github+html",
    ];
    const MARKDOWN_TYPES: [&str; 7] = [
        "text/markdown",
        "text/x-markdown",
        "application/markdown",
        "application/vnd.github.raw",
        "application/vnd.github+raw",
        "application/vnd.github.v3.raw",
        "text/plain",
    ];
    const JAVASCRIPT_TYPES: [&str; 3] = [
        "text/javascript",
        "application/javascript",
        "application/ecmascript",
    ];
    const TYPESCRIPT_TYPES: [&str; 2] = ["text/typescript", "application/typescript"];

    let lowered = content_type.to_ascii_lowercase();
    let without_params = lowered.split(';').next().unwrap_or("");
    let media_type = without_params.split(',').next().unwrap_or("").trim();

    if HTML_TYPES.contains(&media_type) {
        return Some(".html");
    }
    if MARKDOWN_TYPES.contains(&media_type) {
        return Some(".md");
    }
    if media_type == "application/json" || media_type.ends_with("+json") {
        return Some(".json");
    }
    if JAVASCRIPT_TYPES.contains(&media_type) {
        return Some(".js");
    }
    if TYPESCRIPT_TYPES.contains(&media_type) {
        return Some(".ts");
    }
    None
}
675
676fn convert_html_body_to_markdown(body: &[u8], url: &str) -> Result<Vec<u8>, UrlFetchError> {
677    let html = String::from_utf8_lossy(body);
678    let mut markdown = html_to_markdown_converter()
679        .convert(&html)
680        .map_err(|error| {
681            UrlFetchError::new(format!(
682                "Failed to convert HTML from {url} to Markdown: {error}"
683            ))
684        })?;
685    if !markdown.ends_with('\n') {
686        markdown.push('\n');
687    }
688    Ok(markdown.into_bytes())
689}
690
691fn html_to_markdown_converter() -> HtmlToMarkdown {
692    HtmlToMarkdown::builder()
693        .skip_tags(vec![
694            "head", "script", "style", "nav", "footer", "aside", "noscript",
695        ])
696        .add_handler(
697            vec!["a"],
698            |handlers: &dyn Handlers, element: Element| -> Option<HandlerResult> {
699                if is_permalink_anchor(&element) {
700                    None
701                } else {
702                    handlers.fallback(element)
703                }
704            },
705        )
706        .add_handler(
707            vec!["header"],
708            |handlers: &dyn Handlers, element: Element| -> Option<HandlerResult> {
709                if should_skip_header(&element) {
710                    None
711                } else {
712                    handlers.fallback(element)
713                }
714            },
715        )
716        .add_handler(
717            vec!["span"],
718            |handlers: &dyn Handlers, element: Element| -> Option<HandlerResult> {
719                if element_has_class_token(&element, "token-line") {
720                    let mut content = handlers.walk_children(element.node).content;
721                    content.push('\n');
722                    Some(content.into())
723                } else {
724                    handlers.fallback(element)
725                }
726            },
727        )
728        .build()
729}
730
731fn is_permalink_anchor(element: &Element<'_>) -> bool {
732    element_has_class_token(element, "hash-link")
733        || element_attr_value(element, "aria-label")
734            .is_some_and(|value| value.to_ascii_lowercase().starts_with("direct link to"))
735}
736
737fn should_skip_header(element: &Element<'_>) -> bool {
738    element_has_class_token(element, "navbar")
739        || element_has_class_token(element, "site-header")
740        || element_has_class_token(element, "site-nav")
741        || element_has_class_token(element, "topbar")
742        || element_attr_value(element, "role")
743            .is_some_and(|value| value.eq_ignore_ascii_case("banner"))
744        || element_attr_value(element, "id").is_some_and(|value| {
745            let value = value.to_ascii_lowercase();
746            value.contains("navbar") || value.contains("site-header") || value.contains("site-nav")
747        })
748}
749
750fn element_has_class_token(element: &Element<'_>, token: &str) -> bool {
751    element_attr_value(element, "class")
752        .is_some_and(|value| value.split_ascii_whitespace().any(|class| class == token))
753}
754
755fn element_attr_value<'a>(element: &'a Element<'_>, name: &str) -> Option<&'a str> {
756    element
757        .attrs
758        .iter()
759        .find(|attr| attr.name.local.as_ref() == name)
760        .map(|attr| attr.value.as_ref())
761}
762
763enum BodyReadEvent {
764    Chunk(Vec<u8>),
765    Done,
766    Error(io::ErrorKind, String),
767}
768
769fn read_response_body(mut response: HttpResponse, url: &str) -> Result<Vec<u8>, UrlFetchError> {
770    let (tx, rx) = mpsc::channel();
771    thread::spawn(move || {
772        let mut buffer = [0u8; 16 * 1024];
773        loop {
774            match response.read(&mut buffer) {
775                Ok(0) => {
776                    let _ = tx.send(BodyReadEvent::Done);
777                    break;
778                }
779                Ok(n) => {
780                    if tx.send(BodyReadEvent::Chunk(buffer[..n].to_vec())).is_err() {
781                        break;
782                    }
783                }
784                Err(error) => {
785                    let kind = error.kind();
786                    let message = error.to_string();
787                    let _ = tx.send(BodyReadEvent::Error(kind, message));
788                    break;
789                }
790            }
791        }
792    });
793
794    let mut chunks = Vec::new();
795    let mut total = 0u64;
796    loop {
797        match rx.recv_timeout(BODY_CHUNK_TIMEOUT) {
798            Ok(BodyReadEvent::Chunk(chunk)) => {
799                total += chunk.len() as u64;
800                if total > MAX_RESPONSE_BYTES {
801                    return Err(UrlFetchError::new(format!(
802                        "Response exceeded {MAX_RESPONSE_BYTES} bytes, aborted"
803                    )));
804                }
805                chunks.extend_from_slice(&chunk);
806            }
807            Ok(BodyReadEvent::Done) => return Ok(chunks),
808            Ok(BodyReadEvent::Error(kind, _message)) if is_body_stall_kind(kind) => {
809                return Err(body_stall_error(url));
810            }
811            Ok(BodyReadEvent::Error(_, message)) => {
812                return Err(UrlFetchError::new(format!(
813                    "Failed to read response body for {url}: {message}"
814                )));
815            }
816            Err(mpsc::RecvTimeoutError::Timeout) => return Err(body_stall_error(url)),
817            Err(mpsc::RecvTimeoutError::Disconnected) => {
818                return Err(UrlFetchError::new(format!(
819                    "Failed to read response body for {url}: body reader stopped unexpectedly"
820                )));
821            }
822        }
823    }
824}
825
826fn body_stall_error(url: &str) -> UrlFetchError {
827    UrlFetchError::new(format!(
828        "Body read stalled (no data for {}ms) fetching {url}",
829        BODY_CHUNK_TIMEOUT.as_millis()
830    ))
831}
832
/// Returns `true` for the I/O error kinds treated as a stalled body read:
/// `TimedOut` and `WouldBlock`.
fn is_body_stall_kind(kind: io::ErrorKind) -> bool {
    kind == io::ErrorKind::TimedOut || kind == io::ErrorKind::WouldBlock
}
836
837fn atomic_write(
838    final_path: &Path,
839    bytes: &[u8],
840    options: &UrlFetchOptions,
841) -> Result<(), UrlFetchError> {
842    let parent = final_path.parent().unwrap_or_else(|| Path::new("."));
843    fs::create_dir_all(parent).map_err(|error| {
844        UrlFetchError::new(format!(
845            "Failed to create URL cache parent {}: {error}",
846            parent.display()
847        ))
848    })?;
849
850    let file_name = final_path
851        .file_name()
852        .and_then(|name| name.to_str())
853        .ok_or_else(|| {
854            UrlFetchError::new(format!("Invalid cache path: {}", final_path.display()))
855        })?;
856    let tmp_path = final_path.with_file_name(format!(
857        "{file_name}.tmp-{}-{}",
858        std::process::id(),
859        random_nonce()
860    ));
861
862    let write_result = (|| -> io::Result<()> {
863        let mut file = fs::File::create(&tmp_path)?;
864        file.write_all(bytes)?;
865        file.flush()?;
866        Ok(())
867    })();
868    if let Err(error) = write_result {
869        let _ = fs::remove_file(&tmp_path);
870        return Err(UrlFetchError::new(format!(
871            "Failed to write URL cache temp file {}: {error}",
872            tmp_path.display()
873        )));
874    }
875
876    if let Some(observer) = &options.atomic_write_observer {
877        observer(&tmp_path, final_path);
878    }
879
880    fs::rename(&tmp_path, final_path).map_err(|error| {
881        let _ = fs::remove_file(&tmp_path);
882        UrlFetchError::new(format!(
883            "Failed to finalize URL cache file {}: {error}",
884            final_path.display()
885        ))
886    })
887}
888
889fn random_nonce() -> String {
890    let mut bytes = [0u8; 8];
891    if getrandom::fill(&mut bytes).is_err() {
892        let fallback = now_ms() ^ u64::from(std::process::id());
893        bytes = fallback.to_le_bytes();
894    }
895    let mut out = String::with_capacity(bytes.len() * 2);
896    for byte in bytes {
897        use std::fmt::Write as _;
898        let _ = write!(out, "{byte:02x}");
899    }
900    out
901}
902
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A clock set before the epoch yields 0, and a value too large for `u64`
/// saturates at `u64::MAX`.
fn now_ms() -> u64 {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::default(),
    };
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
911
912fn reqwest_error_detail(error: &reqwest::Error) -> String {
913    if error.is_timeout() {
914        return format!("timeout: {error}");
915    }
916    if let Some(source) = error.source() {
917        return format!("{source}");
918    }
919    error.to_string()
920}
921
#[cfg(test)]
mod tests {
    use super::*;
    use url::Url;

    /// Parses `raw` and resolves its cache extension for `content_type`.
    fn resolve(raw: &str, content_type: &str) -> Option<(&'static str, bool)> {
        let url = Url::parse(raw).unwrap();
        resolve_fetch_extension(&url, content_type)
    }

    #[test]
    fn extension_from_path_uses_parser_mapping() {
        let resolved = resolve("https://example.com/pkg/index.mjs", "text/javascript");
        assert_eq!(resolved, Some((".mjs", true)));
    }

    #[test]
    fn text_plain_rs_url_ignores_content_type_gate() {
        let resolved = resolve(
            "https://raw.githubusercontent.com/o/r/main/lib.rs",
            "text/plain",
        );
        assert_eq!(resolved, Some((".rs", true)));
    }

    #[test]
    fn extensionless_javascript_maps_to_js() {
        let resolved = resolve("https://cdn.example/bundle", "text/javascript");
        assert_eq!(resolved, Some((".js", false)));
    }

    #[test]
    fn extensionless_plain_stays_md() {
        let (ext, _) = resolve("https://example.com/readme", "text/plain").unwrap();
        assert_eq!(ext, ".md");
    }

    #[test]
    fn query_and_fragment_do_not_break_path_extension() {
        let resolved = resolve("https://example.com/src/file.ts?v=2#L10", "text/plain");
        assert_eq!(resolved, Some((".ts", true)));
    }

    #[test]
    fn percent_encoded_path_segment() {
        // The last segment `foo%2Fbar.rs` decodes to `foo/bar.rs`, so the
        // extension is still `rs`.
        let (ext, _) = resolve("https://example.com/foo%2Fbar.rs", "text/plain").unwrap();
        assert_eq!(ext, ".rs");
    }

    #[test]
    fn binary_sniff_detects_nul() {
        let with_nul = vec![b'f', b'n', 0, b' '];
        assert!(body_contains_nul_in_prefix(&with_nul));
        let text_only = vec![b'h'; 9000];
        assert!(!body_contains_nul_in_prefix(&text_only));
    }

    #[test]
    fn unsupported_pdf_still_errors_via_resolve() {
        assert!(resolve("https://example.com/doc.pdf", "application/pdf").is_none());
    }
}