1use std::error::Error;
2use std::fmt;
3use std::fs;
4use std::io::{self, Read, Write};
5use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
6use std::path::{Path, PathBuf};
7use std::sync::{mpsc, Arc};
8use std::thread;
9use std::time::{Duration, SystemTime, UNIX_EPOCH};
10
11use htmd::{
12 element_handler::{HandlerResult, Handlers},
13 Element, HtmlToMarkdown,
14};
15use reqwest::blocking::{Client, Response as HttpResponse};
16use reqwest::header::{ACCEPT, CONTENT_TYPE, LOCATION, USER_AGENT};
17use reqwest::redirect::Policy;
18use serde::{Deserialize, Serialize};
19use sha2::{Digest, Sha256};
20use url::Url;
21
22use crate::parser::detect_language;
23
/// Largest response body accepted, in bytes (10 MiB). Checked against both the
/// declared `Content-Length` and the running total while streaming the body.
const MAX_RESPONSE_BYTES: u64 = 10 * 1024 * 1024;
/// Maximum age of a cache entry, in milliseconds (24 hours).
const CACHE_TTL_MS: u64 = 24 * 60 * 60 * 1000;
/// Connect timeout handed to the HTTP client builder.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Longest wait for the next body chunk before the read is reported as stalled.
const BODY_CHUNK_TIMEOUT: Duration = Duration::from_millis(15_000);
/// Number of redirects followed manually before giving up.
const MAX_REDIRECTS: usize = 5;

/// Number of retries (after the first attempt) for transient send failures.
const TRANSIENT_RETRY_ATTEMPTS: usize = 2;
/// Sleep before each retry, in milliseconds, indexed by the failed attempt number.
const TRANSIENT_RETRY_BACKOFFS_MS: [u64; TRANSIENT_RETRY_ATTEMPTS] = [200, 600];
/// Value of the `Accept` header sent with every request.
const ACCEPT_HEADER: &str = "application/vnd.github.raw, text/markdown, text/x-markdown, text/html;q=0.9, application/json;q=0.8, text/plain;q=0.5";
/// Value of the `User-Agent` header sent with every request.
const USER_AGENT_VALUE: &str = "aft-opencode-plugin";
/// Content type recorded in cache metadata after an HTML body was converted to Markdown.
const CONVERTED_MARKDOWN_CONTENT_TYPE: &str = "text/markdown; charset=utf-8";
45
/// Options controlling how `fetch_url_to_cache` validates and performs a fetch.
#[derive(Clone, Default)]
pub struct UrlFetchOptions {
    /// When `true`, `validate_public_url` skips the private-address check
    /// (the http/https scheme check still applies).
    pub allow_private: bool,
    /// Host → IP list consulted by `resolve_host_ips` in place of a DNS lookup.
    #[doc(hidden)]
    pub public_host_overrides: Vec<(String, Vec<IpAddr>)>,
    /// Host → socket address pairs passed to the HTTP client builder's `resolve`.
    #[doc(hidden)]
    pub connect_overrides: Vec<(String, SocketAddr)>,
    /// Callback invoked by `atomic_write` with `(temp_path, final_path)` after the
    /// temp file is written and before it is renamed into place.
    #[doc(hidden)]
    pub atomic_write_observer: Option<Arc<dyn Fn(&Path, &Path) + Send + Sync>>,
}
61
/// Error returned by the URL fetch/cache functions.
///
/// Carries only a human-readable message; `Display` prints that message verbatim.
#[derive(Debug, Clone)]
pub struct UrlFetchError {
    message: String,
}

impl UrlFetchError {
    /// Builds an error from anything convertible into a `String`.
    fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl fmt::Display for UrlFetchError {
    /// Writes the stored message without any decoration.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.message)
    }
}

impl std::error::Error for UrlFetchError {}
82
/// Sidecar metadata stored as JSON next to each cached body (`<hash>.meta.json`).
#[derive(Debug, Serialize, Deserialize)]
struct CacheMeta {
    /// URL string exactly as passed to `fetch_url_to_cache`.
    url: String,
    /// Content type recorded for the cached body (serialized as `contentType`).
    #[serde(rename = "contentType")]
    content_type: String,
    /// File extension of the cached body, including the leading dot.
    extension: String,
    /// Fetch time in milliseconds since the Unix epoch (serialized as `fetchedAt`).
    #[serde(rename = "fetchedAt")]
    fetched_at: u64,
}
92
/// Returns `true` when `value` starts with an `http://` or `https://` scheme prefix.
///
/// The scheme is compared ASCII case-insensitively (URI schemes are
/// case-insensitive), so `HTTPS://example.com` is recognised just like the
/// lowercase form.
pub fn is_http_url(value: &str) -> bool {
    ["http://", "https://"].iter().any(|prefix| {
        // `get` yields `None` when `value` is too short or the cut would land
        // inside a multi-byte character, so this never panics.
        value
            .get(..prefix.len())
            .is_some_and(|head| head.eq_ignore_ascii_case(prefix))
    })
}
96
/// Fetches `url` into the on-disk cache under `storage_dir` and returns the path
/// of the cached body.
///
/// A fresh cache entry is returned without any network traffic. Otherwise the
/// URL is fetched (redirects are followed manually), the body is size-limited,
/// HTML bodies are converted to Markdown, and the body plus a JSON metadata
/// sidecar are written atomically.
///
/// # Errors
///
/// Returns `UrlFetchError` when the URL is invalid or rejected by
/// `validate_public_url`, the cache directory cannot be created, the request
/// fails or ends in a non-success status, the content type is unsupported, the
/// response is too large, a source-file URL yields binary content, HTML
/// conversion fails, or a cache file cannot be written.
pub fn fetch_url_to_cache(
    url: &str,
    storage_dir: &Path,
    options: UrlFetchOptions,
) -> Result<PathBuf, UrlFetchError> {
    let parsed = Url::parse(url).map_err(|_| UrlFetchError::new(format!("Invalid URL: {url}")))?;
    validate_public_url(&parsed, &options)?;

    let dir = cache_dir(storage_dir);
    fs::create_dir_all(&dir).map_err(|error| {
        UrlFetchError::new(format!(
            "Failed to create URL cache directory {}: {error}",
            dir.display()
        ))
    })?;

    // The cache key is derived from the URL string as given, not from the parsed form.
    let hash = hash_url(url);
    let meta_file = meta_path(storage_dir, &hash);
    if let Some(cached) = fresh_cached_path(storage_dir, &hash, &meta_file, &parsed)? {
        return Ok(cached);
    }

    let response = fetch_with_redirects(&parsed, url, &options)?;
    if !response.status().is_success() {
        return Err(UrlFetchError::new(format!(
            "HTTP {} {} fetching {url}",
            response.status().as_u16(),
            response.status().canonical_reason().unwrap_or("")
        )));
    }

    // A missing or non-UTF-8 Content-Type header is treated as text/plain.
    let content_type = response
        .headers()
        .get(CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .unwrap_or("text/plain")
        .to_string();
    let (extension, from_source_path) =
        resolve_fetch_extension(&parsed, &content_type).ok_or_else(|| {
            UrlFetchError::new(format!(
                "Unsupported content type '{content_type}' for {url}. Supported: text/html, text/markdown, application/json, text/plain; source files via URL path extension (e.g. .rs, .ts, .mjs)"
            ))
        })?;

    // Reject early when the server declares an oversized body; the streaming
    // read enforces the same limit when no length is declared.
    if let Some(length) = response.content_length() {
        if length > MAX_RESPONSE_BYTES {
            return Err(UrlFetchError::new(format!(
                "Response too large: {length} bytes (max {MAX_RESPONSE_BYTES})"
            )));
        }
    }

    let body = read_response_body(response, url)?;
    // The NUL sniff applies only when the extension came from the URL path.
    if from_source_path && body_contains_nul_in_prefix(&body) {
        return Err(UrlFetchError::new(format!(
            "Binary content detected for source URL {url}"
        )));
    }
    // HTML is never cached as-is: it is converted and stored as Markdown.
    let (body, content_type, extension) = if extension == ".html" {
        (
            convert_html_body_to_markdown(&body, url)?,
            CONVERTED_MARKDOWN_CONTENT_TYPE.to_string(),
            ".md",
        )
    } else {
        (body, content_type, extension)
    };

    // Body first, metadata second: the sidecar is only written once the body exists.
    let content_file = content_path(storage_dir, &hash, extension);
    atomic_write(&content_file, &body, &options)?;

    let meta = CacheMeta {
        url: url.to_string(),
        content_type,
        extension: extension.to_string(),
        fetched_at: now_ms(),
    };
    let meta_bytes = serde_json::to_vec(&meta).map_err(|error| {
        UrlFetchError::new(format!("Failed to encode URL cache metadata: {error}"))
    })?;
    atomic_write(&meta_file, &meta_bytes, &options)?;

    Ok(content_file)
}
181
182pub fn cleanup_url_cache(storage_dir: &Path) -> Result<usize, UrlFetchError> {
183 let dir = cache_dir(storage_dir);
184 if !dir.exists() {
185 return Ok(0);
186 }
187
188 let entries = fs::read_dir(&dir).map_err(|error| {
189 UrlFetchError::new(format!(
190 "URL cache cleanup failed reading {}: {error}",
191 dir.display()
192 ))
193 })?;
194 let mut removed = 0usize;
195 let now = now_ms();
196
197 for entry in entries.flatten() {
198 let path = entry.path();
199 let Some(name) = path.file_name().and_then(|name| name.to_str()) else {
200 continue;
201 };
202 if !name.ends_with(".meta.json") {
203 continue;
204 }
205
206 let meta = fs::read_to_string(&path)
207 .ok()
208 .and_then(|content| serde_json::from_str::<CacheMeta>(&content).ok());
209 let Some(meta) = meta else {
210 if fs::remove_file(&path).is_ok() {
211 removed += 1;
212 }
213 continue;
214 };
215
216 if now.saturating_sub(meta.fetched_at) <= CACHE_TTL_MS {
217 continue;
218 }
219
220 let hash = name.trim_end_matches(".meta.json");
221 let content = content_path(storage_dir, hash, &meta.extension);
222 let _ = fs::remove_file(content);
223 if fs::remove_file(&path).is_ok() {
224 removed += 1;
225 }
226 }
227
228 Ok(removed)
229}
230
231#[doc(hidden)]
232pub fn cache_content_path_for_url(storage_dir: &Path, url: &str, extension: &str) -> PathBuf {
233 content_path(storage_dir, &hash_url(url), extension)
234}
235
236#[doc(hidden)]
237pub fn cache_meta_path_for_url(storage_dir: &Path, url: &str) -> PathBuf {
238 meta_path(storage_dir, &hash_url(url))
239}
240
241#[doc(hidden)]
242pub fn is_private_ip_for_test(ip: IpAddr) -> bool {
243 is_private_ip(ip)
244}
245
/// Returns the `url_cache` directory inside `storage_dir`.
fn cache_dir(storage_dir: &Path) -> PathBuf {
    let mut dir = storage_dir.to_path_buf();
    dir.push("url_cache");
    dir
}
249
250fn hash_url(url: &str) -> String {
251 let digest = Sha256::digest(url.as_bytes());
252 format!("{digest:x}").chars().take(16).collect()
253}
254
255fn meta_path(storage_dir: &Path, hash: &str) -> PathBuf {
256 cache_dir(storage_dir).join(format!("{hash}.meta.json"))
257}
258
259fn content_path(storage_dir: &Path, hash: &str, extension: &str) -> PathBuf {
260 cache_dir(storage_dir).join(format!("{hash}{extension}"))
261}
262
263fn fresh_cached_path(
264 storage_dir: &Path,
265 hash: &str,
266 meta_file: &Path,
267 url: &Url,
268) -> Result<Option<PathBuf>, UrlFetchError> {
269 if !meta_file.exists() {
270 return Ok(None);
271 }
272
273 let meta = match fs::read_to_string(meta_file)
274 .ok()
275 .and_then(|content| serde_json::from_str::<CacheMeta>(&content).ok())
276 {
277 Some(meta) => meta,
278 None => return Ok(None),
279 };
280 let age = now_ms().saturating_sub(meta.fetched_at);
281 if meta.extension == ".html" {
282 return Ok(None);
283 }
284
285 let content_type = meta.content_type.as_str();
286 let current = resolve_fetch_extension(url, content_type);
287 let expected_ext = current.map(|(ext, _)| ext);
288 if expected_ext != Some(meta.extension.as_str()) {
289 return Ok(None);
290 }
291
292 let cached = content_path(storage_dir, hash, &meta.extension);
293 if age < CACHE_TTL_MS && cached.exists() {
294 return Ok(Some(cached));
295 }
296 Ok(None)
297}
298
299fn fetch_with_redirects(
300 start_url: &Url,
301 original_url: &str,
302 options: &UrlFetchOptions,
303) -> Result<HttpResponse, UrlFetchError> {
304 let client = build_client(options)?;
305 let mut current_url = start_url.clone();
306
307 for redirect_count in 0..=MAX_REDIRECTS {
308 validate_public_url(¤t_url, options)?;
309 let response = send_with_transient_retries(&client, ¤t_url)?;
310
311 if !response.status().is_redirection() {
312 return Ok(response);
313 }
314 if redirect_count == MAX_REDIRECTS {
315 return Err(UrlFetchError::new(format!(
316 "Too many redirects fetching {original_url}"
317 )));
318 }
319
320 let location = response
321 .headers()
322 .get(LOCATION)
323 .and_then(|value| value.to_str().ok())
324 .ok_or_else(|| {
325 UrlFetchError::new(format!(
326 "Redirect from {} missing Location header",
327 current_url.as_str()
328 ))
329 })?;
330 current_url = current_url.join(location).map_err(|error| {
331 UrlFetchError::new(format!(
332 "Invalid redirect Location '{location}' from {}: {error}",
333 current_url.as_str()
334 ))
335 })?;
336 }
337
338 Err(UrlFetchError::new(format!(
339 "Too many redirects fetching {original_url}"
340 )))
341}
342
343fn send_with_transient_retries(
352 client: &Client,
353 target: &Url,
354) -> Result<HttpResponse, UrlFetchError> {
355 let mut last_error: Option<reqwest::Error> = None;
356 for attempt in 0..=TRANSIENT_RETRY_ATTEMPTS {
357 let result = client
358 .get(target.clone())
359 .header(USER_AGENT, USER_AGENT_VALUE)
360 .header(ACCEPT, ACCEPT_HEADER)
361 .send();
362 match result {
363 Ok(response) => return Ok(response),
364 Err(error) => {
365 if attempt < TRANSIENT_RETRY_ATTEMPTS && is_transient_reqwest_error(&error) {
366 thread::sleep(Duration::from_millis(TRANSIENT_RETRY_BACKOFFS_MS[attempt]));
367 last_error = Some(error);
368 continue;
369 }
370 return Err(UrlFetchError::new(format!(
371 "Failed to fetch {}: {}",
372 target.as_str(),
373 reqwest_error_detail(&error)
374 )));
375 }
376 }
377 }
378 Err(UrlFetchError::new(format!(
381 "Failed to fetch {} after {} retries: {}",
382 target.as_str(),
383 TRANSIENT_RETRY_ATTEMPTS,
384 last_error
385 .as_ref()
386 .map(reqwest_error_detail)
387 .unwrap_or_else(|| "unknown transient error".to_string())
388 )))
389}
390
391fn is_transient_reqwest_error(error: &reqwest::Error) -> bool {
401 error.is_connect() || error.is_timeout() || error.is_request()
402}
403
404fn build_client(options: &UrlFetchOptions) -> Result<Client, UrlFetchError> {
405 let mut builder = Client::builder()
406 .redirect(Policy::none())
407 .connect_timeout(CONNECT_TIMEOUT);
408
409 for (host, address) in &options.connect_overrides {
410 builder = builder.resolve(host, *address);
411 }
412
413 builder
414 .build()
415 .map_err(|error| UrlFetchError::new(format!("Failed to build URL fetch client: {error}")))
416}
417
418fn validate_public_url(url: &Url, options: &UrlFetchOptions) -> Result<(), UrlFetchError> {
419 if url.scheme() != "http" && url.scheme() != "https" {
420 return Err(UrlFetchError::new(format!(
421 "Only http:// and https:// URLs are supported, got: {}:",
422 url.scheme()
423 )));
424 }
425 if options.allow_private {
426 return Ok(());
427 }
428
429 let host = url
430 .host_str()
431 .ok_or_else(|| UrlFetchError::new(format!("URL missing host: {url}")))?;
432 let host_for_parse = host
433 .trim_matches(['[', ']'])
434 .split('%')
435 .next()
436 .unwrap_or(host);
437
438 if let Ok(ip) = host_for_parse.parse::<IpAddr>() {
439 reject_private_ip(host, ip)?;
440 return Ok(());
441 }
442 if host_for_parse.contains(':') {
443 return Err(UrlFetchError::new(format!(
444 "Blocked private URL host {host} ({host_for_parse})"
445 )));
446 }
447
448 let addresses = resolve_host_ips(host_for_parse, url.port_or_known_default(), options)?;
449 if addresses.is_empty() {
450 return Err(UrlFetchError::new(format!(
451 "Failed to resolve URL host {host}"
452 )));
453 }
454 for ip in addresses {
455 reject_private_ip(host, ip)?;
456 }
457
458 Ok(())
464}
465
466fn resolve_host_ips(
467 host: &str,
468 port: Option<u16>,
469 options: &UrlFetchOptions,
470) -> Result<Vec<IpAddr>, UrlFetchError> {
471 if let Some((_, ips)) = options
472 .public_host_overrides
473 .iter()
474 .find(|(override_host, _)| override_host == host)
475 {
476 return Ok(ips.clone());
477 }
478
479 let port = port.unwrap_or(80);
480 let addrs = (host, port).to_socket_addrs().map_err(|error| {
481 UrlFetchError::new(format!("Failed to resolve URL host {host}: {error}"))
482 })?;
483 Ok(addrs.map(|addr| addr.ip()).collect())
484}
485
486fn reject_private_ip(host: &str, ip: IpAddr) -> Result<(), UrlFetchError> {
487 if is_private_ip(ip) {
488 return Err(UrlFetchError::new(format!(
489 "Blocked private URL host {host} ({ip})"
490 )));
491 }
492 Ok(())
493}
494
495pub fn is_private_or_reserved_ip(ip: IpAddr) -> bool {
501 is_private_ip(ip)
502}
503
504fn is_private_ip(ip: IpAddr) -> bool {
505 match ip {
506 IpAddr::V4(ipv4) => is_private_ipv4(ipv4),
507 IpAddr::V6(ipv6) => is_private_ipv6(ipv6),
508 }
509}
510
/// Returns `true` when `ip` falls in one of the blocked IPv4 ranges:
/// `0.0.0.0/8`, `10.0.0.0/8`, `127.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`,
/// `169.254.0.0/16`, `100.64.0.0/10`, `198.18.0.0/15`, and everything from
/// `224.0.0.0` upward.
fn is_private_ipv4(ip: Ipv4Addr) -> bool {
    match ip.octets() {
        [0 | 10 | 127, ..] => true,
        [172, 16..=31, ..] => true,
        [192, 168, ..] => true,
        [169, 254, ..] => true,
        [100, 64..=127, ..] => true,
        [198, 18..=19, ..] => true,
        [first, ..] => first >= 224,
    }
}
527
528fn is_private_ipv6(ip: Ipv6Addr) -> bool {
529 let segments = ip.segments();
530 let top_six_zero = segments[..6].iter().all(|segment| *segment == 0);
531 let is_mapped = segments[..5].iter().all(|segment| *segment == 0) && segments[5] == 0xffff;
532 if is_mapped || top_six_zero {
533 let embedded = Ipv4Addr::new(
534 (segments[6] >> 8) as u8,
535 (segments[6] & 0xff) as u8,
536 (segments[7] >> 8) as u8,
537 (segments[7] & 0xff) as u8,
538 );
539 return is_private_ipv4(embedded);
540 }
541
542 let first = segments[0];
543 (0xfe80..=0xfebf).contains(&first) || (0xfc00..=0xfdff).contains(&first) || first >= 0xff00
544}
545
/// Number of leading body bytes inspected by the binary sniff (8 KiB).
const BINARY_SNIFF_PREFIX: usize = 8 * 1024;

/// Returns `true` when a NUL byte occurs within the first
/// `BINARY_SNIFF_PREFIX` bytes of `body`.
fn body_contains_nul_in_prefix(body: &[u8]) -> bool {
    body.iter()
        .take(BINARY_SNIFF_PREFIX)
        .any(|&byte| byte == 0)
}
552
553fn resolve_fetch_extension(url: &Url, content_type: &str) -> Option<(&'static str, bool)> {
556 if let Some(ext) = extension_from_url_path(url) {
557 return Some((ext, true));
558 }
559 resolve_extension_from_content_type(content_type).map(|ext| (ext, false))
560}
561
562fn extension_from_url_path(url: &Url) -> Option<&'static str> {
563 let path = url.path();
564 if path.is_empty() || path == "/" {
565 return None;
566 }
567 let segment = path.rsplit('/').next().unwrap_or(path);
568 let file_name = percent_decode_path_segment(segment);
569 let dot = file_name.rfind('.')?;
570 let ext = &file_name[dot + 1..];
571 if ext.is_empty() {
572 return None;
573 }
574 let probe = Path::new("file").with_extension(ext);
575 if detect_language(&probe).is_some() {
576 static_extension_for_lang_ext(ext)
577 } else {
578 None
579 }
580}
581
/// Percent-decodes one URL path segment.
///
/// Every `%XY` escape with two hex digits becomes the byte `0xXY`; malformed or
/// truncated escapes are copied through unchanged. The decoded bytes are then
/// interpreted as UTF-8, so multi-byte escapes such as `%C3%A9` yield `é`
/// instead of one Latin-1 character per byte. Byte sequences that are not
/// valid UTF-8 are replaced with U+FFFD.
fn percent_decode_path_segment(segment: &str) -> String {
    let bytes = segment.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' {
            // `get` handles an escape cut short at the end of the segment.
            let hi = bytes.get(i + 1).copied().and_then(from_hex);
            let lo = bytes.get(i + 2).copied().and_then(from_hex);
            if let (Some(hi), Some(lo)) = (hi, lo) {
                decoded.push(hi << 4 | lo);
                i += 3;
                continue;
            }
        }
        decoded.push(bytes[i]);
        i += 1;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}

/// Returns the value (0–15) of an ASCII hex digit, or `None` for any other byte.
fn from_hex(byte: u8) -> Option<u8> {
    match byte {
        b'0'..=b'9' => Some(byte - b'0'),
        b'a'..=b'f' => Some(byte - b'a' + 10),
        b'A'..=b'F' => Some(byte - b'A' + 10),
        _ => None,
    }
}
608
/// Maps a file extension (without dot, any ASCII case) to the canonical
/// dotted extension used for cache files, or `None` when it is not listed.
fn static_extension_for_lang_ext(ext: &str) -> Option<&'static str> {
    // (canonical extension, accepted spellings)
    const CANONICAL: &[(&str, &[&str])] = &[
        (".ts", &["ts", "mts", "cts"]),
        (".tsx", &["tsx"]),
        (".js", &["js"]),
        (".jsx", &["jsx"]),
        (".mjs", &["mjs"]),
        (".cjs", &["cjs"]),
        (".py", &["py", "pyi"]),
        (".rs", &["rs"]),
        (".go", &["go"]),
        (".c", &["c", "h"]),
        (".cpp", &["cc", "cpp", "cxx", "hpp", "hh"]),
        (".zig", &["zig"]),
        (".cs", &["cs"]),
        (".sh", &["sh", "bash", "zsh"]),
        (".html", &["html", "htm"]),
        (".md", &["md", "markdown", "mdx"]),
        (".sol", &["sol"]),
        (".scss", &["scss"]),
        (".vue", &["vue"]),
        (".json", &["json", "jsonc"]),
        (".scala", &["scala", "sc"]),
        (".java", &["java"]),
        (".rb", &["rb"]),
        (".kt", &["kt", "kts"]),
        (".swift", &["swift"]),
        (".php", &["inc", "php"]),
        (".lua", &["lua"]),
        (".pl", &["pl", "pm", "t"]),
        (".yaml", &["yaml", "yml"]),
    ];

    let wanted = ext.to_ascii_lowercase();
    CANONICAL
        .iter()
        .find(|(_, aliases)| aliases.iter().any(|alias| *alias == wanted))
        .map(|(canonical, _)| *canonical)
}
644
/// Maps a `Content-Type` header value to a cache file extension.
///
/// Only the media type is considered: the value is lowercased, cut at the
/// first `;` or `,`, and trimmed. Returns `None` for unsupported media types.
fn resolve_extension_from_content_type(content_type: &str) -> Option<&'static str> {
    let normalized = content_type.to_ascii_lowercase();
    let essence = normalized.split([';', ',']).next().unwrap_or("").trim();

    let is_html = matches!(
        essence,
        "text/html"
            | "application/xhtml+xml"
            | "application/vnd.github.html"
            | "application/vnd.github+html"
    );
    if is_html {
        return Some(".html");
    }

    // Plain text is stored with the Markdown extension as well.
    let is_markdown_like = matches!(
        essence,
        "text/markdown"
            | "text/x-markdown"
            | "application/markdown"
            | "application/vnd.github.raw"
            | "application/vnd.github+raw"
            | "application/vnd.github.v3.raw"
            | "text/plain"
    );
    if is_markdown_like {
        return Some(".md");
    }

    if matches!(essence, "application/json" | "application/ld+json") || essence.ends_with("+json") {
        return Some(".json");
    }

    match essence {
        "text/javascript" | "application/javascript" | "application/ecmascript" => Some(".js"),
        "text/typescript" | "application/typescript" => Some(".ts"),
        _ => None,
    }
}
675
676fn convert_html_body_to_markdown(body: &[u8], url: &str) -> Result<Vec<u8>, UrlFetchError> {
677 let html = String::from_utf8_lossy(body);
678 let mut markdown = html_to_markdown_converter()
679 .convert(&html)
680 .map_err(|error| {
681 UrlFetchError::new(format!(
682 "Failed to convert HTML from {url} to Markdown: {error}"
683 ))
684 })?;
685 if !markdown.ends_with('\n') {
686 markdown.push('\n');
687 }
688 Ok(markdown.into_bytes())
689}
690
/// Builds the HTML → Markdown converter used for fetched pages.
///
/// Beyond the listed skip tags, three custom handlers are registered:
/// - `<a>`: permalink anchors (see `is_permalink_anchor`) produce no output;
/// - `<header>`: headers matched by `should_skip_header` produce no output;
/// - `<span class="token-line">`: children are rendered followed by a newline.
///
/// Every other element is passed to the converter's fallback handling.
fn html_to_markdown_converter() -> HtmlToMarkdown {
    HtmlToMarkdown::builder()
        .skip_tags(vec![
            "head", "script", "style", "nav", "footer", "aside", "noscript",
        ])
        .add_handler(
            vec!["a"],
            |handlers: &dyn Handlers, element: Element| -> Option<HandlerResult> {
                // Returning `None` emits nothing for the anchor.
                if is_permalink_anchor(&element) {
                    None
                } else {
                    handlers.fallback(element)
                }
            },
        )
        .add_handler(
            vec!["header"],
            |handlers: &dyn Handlers, element: Element| -> Option<HandlerResult> {
                if should_skip_header(&element) {
                    None
                } else {
                    handlers.fallback(element)
                }
            },
        )
        .add_handler(
            vec!["span"],
            |handlers: &dyn Handlers, element: Element| -> Option<HandlerResult> {
                // Each `token-line` span is terminated with an explicit newline.
                if element_has_class_token(&element, "token-line") {
                    let mut content = handlers.walk_children(element.node).content;
                    content.push('\n');
                    Some(content.into())
                } else {
                    handlers.fallback(element)
                }
            },
        )
        .build()
}
730
731fn is_permalink_anchor(element: &Element<'_>) -> bool {
732 element_has_class_token(element, "hash-link")
733 || element_attr_value(element, "aria-label")
734 .is_some_and(|value| value.to_ascii_lowercase().starts_with("direct link to"))
735}
736
737fn should_skip_header(element: &Element<'_>) -> bool {
738 element_has_class_token(element, "navbar")
739 || element_has_class_token(element, "site-header")
740 || element_has_class_token(element, "site-nav")
741 || element_has_class_token(element, "topbar")
742 || element_attr_value(element, "role")
743 .is_some_and(|value| value.eq_ignore_ascii_case("banner"))
744 || element_attr_value(element, "id").is_some_and(|value| {
745 let value = value.to_ascii_lowercase();
746 value.contains("navbar") || value.contains("site-header") || value.contains("site-nav")
747 })
748}
749
750fn element_has_class_token(element: &Element<'_>, token: &str) -> bool {
751 element_attr_value(element, "class")
752 .is_some_and(|value| value.split_ascii_whitespace().any(|class| class == token))
753}
754
755fn element_attr_value<'a>(element: &'a Element<'_>, name: &str) -> Option<&'a str> {
756 element
757 .attrs
758 .iter()
759 .find(|attr| attr.name.local.as_ref() == name)
760 .map(|attr| attr.value.as_ref())
761}
762
/// Message sent from the body-reading thread to `read_response_body`.
enum BodyReadEvent {
    /// Bytes returned by one successful read.
    Chunk(Vec<u8>),
    /// The read returned zero bytes: the body is complete.
    Done,
    /// The read failed; carries the I/O error kind and its display text.
    Error(io::ErrorKind, String),
}
768
/// Reads the whole response body, enforcing a size limit and a stall timeout.
///
/// The body is read on a spawned thread that forwards chunks over a channel.
/// This function collects them and fails when no event arrives within
/// `BODY_CHUNK_TIMEOUT`, when the total exceeds `MAX_RESPONSE_BYTES`, or when
/// the reader reports an I/O error. `url` is only used in error messages.
fn read_response_body(mut response: HttpResponse, url: &str) -> Result<Vec<u8>, UrlFetchError> {
    let (tx, rx) = mpsc::channel();
    thread::spawn(move || {
        let mut buffer = [0u8; 16 * 1024];
        loop {
            match response.read(&mut buffer) {
                Ok(0) => {
                    let _ = tx.send(BodyReadEvent::Done);
                    break;
                }
                Ok(n) => {
                    // A failed send means the receiver is gone (the caller
                    // already returned), so the reader stops.
                    if tx.send(BodyReadEvent::Chunk(buffer[..n].to_vec())).is_err() {
                        break;
                    }
                }
                Err(error) => {
                    let kind = error.kind();
                    let message = error.to_string();
                    let _ = tx.send(BodyReadEvent::Error(kind, message));
                    break;
                }
            }
        }
    });

    let mut chunks = Vec::new();
    let mut total = 0u64;
    loop {
        // The timeout applies to each event, not to the body as a whole.
        match rx.recv_timeout(BODY_CHUNK_TIMEOUT) {
            Ok(BodyReadEvent::Chunk(chunk)) => {
                total += chunk.len() as u64;
                if total > MAX_RESPONSE_BYTES {
                    return Err(UrlFetchError::new(format!(
                        "Response exceeded {MAX_RESPONSE_BYTES} bytes, aborted"
                    )));
                }
                chunks.extend_from_slice(&chunk);
            }
            Ok(BodyReadEvent::Done) => return Ok(chunks),
            // Timeout-like I/O errors are reported with the same message as a stall.
            Ok(BodyReadEvent::Error(kind, _message)) if is_body_stall_kind(kind) => {
                return Err(body_stall_error(url));
            }
            Ok(BodyReadEvent::Error(_, message)) => {
                return Err(UrlFetchError::new(format!(
                    "Failed to read response body for {url}: {message}"
                )));
            }
            Err(mpsc::RecvTimeoutError::Timeout) => return Err(body_stall_error(url)),
            // The sender was dropped without sending `Done` or `Error`.
            Err(mpsc::RecvTimeoutError::Disconnected) => {
                return Err(UrlFetchError::new(format!(
                    "Failed to read response body for {url}: body reader stopped unexpectedly"
                )));
            }
        }
    }
}
825
826fn body_stall_error(url: &str) -> UrlFetchError {
827 UrlFetchError::new(format!(
828 "Body read stalled (no data for {}ms) fetching {url}",
829 BODY_CHUNK_TIMEOUT.as_millis()
830 ))
831}
832
/// Returns `true` for the I/O error kinds treated as a stalled body read
/// (`TimedOut` and `WouldBlock`).
fn is_body_stall_kind(kind: io::ErrorKind) -> bool {
    kind == io::ErrorKind::TimedOut || kind == io::ErrorKind::WouldBlock
}
836
/// Writes `bytes` to `final_path` via a temporary sibling file and a rename.
///
/// The parent directory is created if needed. The temp file name is the final
/// file name plus `.tmp-<pid>-<nonce>`. If writing or renaming fails, the temp
/// file is removed (best effort) and an error is returned.
/// `options.atomic_write_observer`, when set, is called between the write and
/// the rename.
fn atomic_write(
    final_path: &Path,
    bytes: &[u8],
    options: &UrlFetchOptions,
) -> Result<(), UrlFetchError> {
    // A path with no parent component falls back to the current directory.
    let parent = final_path.parent().unwrap_or_else(|| Path::new("."));
    fs::create_dir_all(parent).map_err(|error| {
        UrlFetchError::new(format!(
            "Failed to create URL cache parent {}: {error}",
            parent.display()
        ))
    })?;

    let file_name = final_path
        .file_name()
        .and_then(|name| name.to_str())
        .ok_or_else(|| {
            UrlFetchError::new(format!("Invalid cache path: {}", final_path.display()))
        })?;
    // The temp file lives next to the final path; pid + nonce keep its name unique.
    let tmp_path = final_path.with_file_name(format!(
        "{file_name}.tmp-{}-{}",
        std::process::id(),
        random_nonce()
    ));

    // Grouped in a closure so any failing step funnels into one cleanup path.
    let write_result = (|| -> io::Result<()> {
        let mut file = fs::File::create(&tmp_path)?;
        file.write_all(bytes)?;
        file.flush()?;
        Ok(())
    })();
    if let Err(error) = write_result {
        let _ = fs::remove_file(&tmp_path);
        return Err(UrlFetchError::new(format!(
            "Failed to write URL cache temp file {}: {error}",
            tmp_path.display()
        )));
    }

    // The observer sees the fully written temp file before it is renamed.
    if let Some(observer) = &options.atomic_write_observer {
        observer(&tmp_path, final_path);
    }

    fs::rename(&tmp_path, final_path).map_err(|error| {
        let _ = fs::remove_file(&tmp_path);
        UrlFetchError::new(format!(
            "Failed to finalize URL cache file {}: {error}",
            final_path.display()
        ))
    })
}
888
889fn random_nonce() -> String {
890 let mut bytes = [0u8; 8];
891 if getrandom::fill(&mut bytes).is_err() {
892 let fallback = now_ms() ^ u64::from(std::process::id());
893 bytes = fallback.to_le_bytes();
894 }
895 let mut out = String::with_capacity(bytes.len() * 2);
896 for byte in bytes {
897 use std::fmt::Write as _;
898 let _ = write!(out, "{byte:02x}");
899 }
900 out
901}
902
/// Returns the current time as milliseconds since the Unix epoch.
///
/// A clock set before the epoch yields 0; a value that does not fit in `u64`
/// saturates to `u64::MAX`.
fn now_ms() -> u64 {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::default(),
    };
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
911
912fn reqwest_error_detail(error: &reqwest::Error) -> String {
913 if error.is_timeout() {
914 return format!("timeout: {error}");
915 }
916 if let Some(source) = error.source() {
917 return format!("{source}");
918 }
919 error.to_string()
920}
921
/// Unit tests for extension resolution and the binary sniff.
#[cfg(test)]
mod tests {
    use super::*;
    use url::Url;

    // A recognised path extension wins over the content type.
    #[test]
    fn extension_from_path_uses_parser_mapping() {
        let url = Url::parse("https://example.com/pkg/index.mjs").unwrap();
        let (ext, from_path) = resolve_fetch_extension(&url, "text/javascript").unwrap();
        assert_eq!(ext, ".mjs");
        assert!(from_path);
    }

    // A `.rs` path keeps its extension even when served as text/plain.
    #[test]
    fn text_plain_rs_url_ignores_content_type_gate() {
        let url = Url::parse("https://raw.githubusercontent.com/o/r/main/lib.rs").unwrap();
        let (ext, from_path) = resolve_fetch_extension(&url, "text/plain").unwrap();
        assert_eq!(ext, ".rs");
        assert!(from_path);
    }

    // Without a path extension the content type decides.
    #[test]
    fn extensionless_javascript_maps_to_js() {
        let url = Url::parse("https://cdn.example/bundle").unwrap();
        let (ext, from_path) = resolve_fetch_extension(&url, "text/javascript").unwrap();
        assert_eq!(ext, ".js");
        assert!(!from_path);
    }

    // text/plain without a path extension is stored as Markdown.
    #[test]
    fn extensionless_plain_stays_md() {
        let url = Url::parse("https://example.com/readme").unwrap();
        let (ext, _) = resolve_fetch_extension(&url, "text/plain").unwrap();
        assert_eq!(ext, ".md");
    }

    // Query string and fragment are not part of the path being inspected.
    #[test]
    fn query_and_fragment_do_not_break_path_extension() {
        let url = Url::parse("https://example.com/src/file.ts?v=2#L10").unwrap();
        let (ext, from_path) = resolve_fetch_extension(&url, "text/plain").unwrap();
        assert_eq!(ext, ".ts");
        assert!(from_path);
    }

    // An encoded slash stays inside the last segment; the extension is still found.
    #[test]
    fn percent_encoded_path_segment() {
        let url = Url::parse("https://example.com/foo%2Fbar.rs").unwrap();
        let (ext, _) = resolve_fetch_extension(&url, "text/plain").unwrap();
        assert_eq!(ext, ".rs");
    }

    // A NUL byte marks binary content; a 9000-byte body without NUL does not.
    #[test]
    fn binary_sniff_detects_nul() {
        let mut body = vec![b'f', b'n', 0, b' '];
        assert!(body_contains_nul_in_prefix(&body));
        body = vec![b'h'; 9000];
        assert!(!body_contains_nul_in_prefix(&body));
    }

    // Neither the `.pdf` path nor application/pdf resolves to an extension.
    #[test]
    fn unsupported_pdf_still_errors_via_resolve() {
        let url = Url::parse("https://example.com/doc.pdf").unwrap();
        assert!(resolve_fetch_extension(&url, "application/pdf").is_none());
    }
}