1use std::error::Error;
2use std::fmt;
3use std::fs;
4use std::io::{self, Read, Write};
5use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
6use std::path::{Path, PathBuf};
7use std::sync::{mpsc, Arc};
8use std::thread;
9use std::time::{Duration, SystemTime, UNIX_EPOCH};
10
11use reqwest::blocking::{Client, Response as HttpResponse};
12use reqwest::header::{ACCEPT, CONTENT_TYPE, LOCATION, USER_AGENT};
13use reqwest::redirect::Policy;
14use serde::{Deserialize, Serialize};
15use sha2::{Digest, Sha256};
16use url::Url;
17
/// Largest response body, in bytes, that will be accepted (10 MiB).
const MAX_RESPONSE_BYTES: u64 = 10 << 20;
/// Age, in milliseconds, up to which a cached entry is considered usable (24 hours).
const CACHE_TTL_MS: u64 = 24 * 3_600 * 1_000;
/// Connect timeout handed to the HTTP client builder.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Longest wait for the next body event before the read is reported as stalled.
const BODY_CHUNK_TIMEOUT: Duration = Duration::from_secs(15);
/// Number of redirect hops followed before giving up.
const MAX_REDIRECTS: usize = 5;

/// Number of retries after a failed send that is classified as transient.
const TRANSIENT_RETRY_ATTEMPTS: usize = 2;
/// Sleep, in milliseconds, before each retry; one entry per retry.
const TRANSIENT_RETRY_BACKOFFS_MS: [u64; TRANSIENT_RETRY_ATTEMPTS] = [200, 600];
/// Value sent in the `Accept` request header.
const ACCEPT_HEADER: &str = "application/vnd.github.raw, text/markdown, text/x-markdown, text/html;q=0.9, application/json;q=0.8, text/plain;q=0.5";
/// Value sent in the `User-Agent` request header.
const USER_AGENT_VALUE: &str = "aft-opencode-plugin";
38
/// Options controlling how [`fetch_url_to_cache`] validates and fetches a URL.
#[derive(Clone, Default)]
pub struct UrlFetchOptions {
    /// When `true`, URL validation skips the private/loopback host check.
    pub allow_private: bool,
    /// Per-host address lists used instead of a DNS lookup when checking
    /// whether a host resolves to a public address.
    #[doc(hidden)]
    pub public_host_overrides: Vec<(String, Vec<IpAddr>)>,
    /// Host-to-socket-address pins registered with the HTTP client's resolver.
    #[doc(hidden)]
    pub connect_overrides: Vec<(String, SocketAddr)>,
    /// Callback invoked with `(temp_path, final_path)` after a cache temp file
    /// has been written and before it is renamed into place.
    #[doc(hidden)]
    pub atomic_write_observer: Option<Arc<dyn Fn(&Path, &Path) + Send + Sync>>,
}

// `Debug` cannot be derived because the observer is a `dyn Fn`; the callback
// is rendered as a placeholder so the remaining fields stay inspectable.
impl fmt::Debug for UrlFetchOptions {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("UrlFetchOptions")
            .field("allow_private", &self.allow_private)
            .field("public_host_overrides", &self.public_host_overrides)
            .field("connect_overrides", &self.connect_overrides)
            .field(
                "atomic_write_observer",
                &self.atomic_write_observer.as_ref().map(|_| "<callback>"),
            )
            .finish()
    }
}
54
/// Error returned by the URL fetch and cache helpers; carries a single
/// human-readable message.
#[derive(Debug, Clone)]
pub struct UrlFetchError {
    message: String,
}

impl UrlFetchError {
    /// Builds an error from anything convertible into a `String`.
    fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl fmt::Display for UrlFetchError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.message)
    }
}

impl Error for UrlFetchError {}
75
76#[derive(Debug, Serialize, Deserialize)]
77struct CacheMeta {
78 url: String,
79 #[serde(rename = "contentType")]
80 content_type: String,
81 extension: String,
82 #[serde(rename = "fetchedAt")]
83 fetched_at: u64,
84}
85
/// Reports whether `value` starts with an `http://` or `https://` prefix.
///
/// The prefix is compared ASCII case-insensitively because URL schemes are
/// case-insensitive; `HTTPS://example.com` is therefore recognised, matching
/// what URL parsing in [`fetch_url_to_cache`] accepts.
pub fn is_http_url(value: &str) -> bool {
    ["http://", "https://"].iter().any(|prefix| {
        // `get` yields `None` when `value` is too short or the cut would fall
        // inside a multi-byte character, so this never panics.
        matches!(
            value.get(..prefix.len()),
            Some(head) if head.eq_ignore_ascii_case(prefix)
        )
    })
}
89
90pub fn fetch_url_to_cache(
91 url: &str,
92 storage_dir: &Path,
93 options: UrlFetchOptions,
94) -> Result<PathBuf, UrlFetchError> {
95 let parsed = Url::parse(url).map_err(|_| UrlFetchError::new(format!("Invalid URL: {url}")))?;
96 validate_public_url(&parsed, &options)?;
97
98 let dir = cache_dir(storage_dir);
99 fs::create_dir_all(&dir).map_err(|error| {
100 UrlFetchError::new(format!(
101 "Failed to create URL cache directory {}: {error}",
102 dir.display()
103 ))
104 })?;
105
106 let hash = hash_url(url);
107 let meta_file = meta_path(storage_dir, &hash);
108 if let Some(cached) = fresh_cached_path(storage_dir, &hash, &meta_file)? {
109 return Ok(cached);
110 }
111
112 let response = fetch_with_redirects(&parsed, url, &options)?;
113 if !response.status().is_success() {
114 return Err(UrlFetchError::new(format!(
115 "HTTP {} {} fetching {url}",
116 response.status().as_u16(),
117 response.status().canonical_reason().unwrap_or("")
118 )));
119 }
120
121 let content_type = response
122 .headers()
123 .get(CONTENT_TYPE)
124 .and_then(|value| value.to_str().ok())
125 .unwrap_or("text/plain")
126 .to_string();
127 let extension = resolve_extension(&content_type).ok_or_else(|| {
128 UrlFetchError::new(format!(
129 "Unsupported content type '{content_type}' for {url}. Supported: text/html, text/markdown, application/json, text/plain"
130 ))
131 })?;
132
133 if let Some(length) = response.content_length() {
134 if length > MAX_RESPONSE_BYTES {
135 return Err(UrlFetchError::new(format!(
136 "Response too large: {length} bytes (max {MAX_RESPONSE_BYTES})"
137 )));
138 }
139 }
140
141 let body = read_response_body(response, url)?;
142 let content_file = content_path(storage_dir, &hash, extension);
143 atomic_write(&content_file, &body, &options)?;
144
145 let meta = CacheMeta {
146 url: url.to_string(),
147 content_type,
148 extension: extension.to_string(),
149 fetched_at: now_ms(),
150 };
151 let meta_bytes = serde_json::to_vec(&meta).map_err(|error| {
152 UrlFetchError::new(format!("Failed to encode URL cache metadata: {error}"))
153 })?;
154 atomic_write(&meta_file, &meta_bytes, &options)?;
155
156 Ok(content_file)
157}
158
159pub fn cleanup_url_cache(storage_dir: &Path) -> Result<usize, UrlFetchError> {
160 let dir = cache_dir(storage_dir);
161 if !dir.exists() {
162 return Ok(0);
163 }
164
165 let entries = fs::read_dir(&dir).map_err(|error| {
166 UrlFetchError::new(format!(
167 "URL cache cleanup failed reading {}: {error}",
168 dir.display()
169 ))
170 })?;
171 let mut removed = 0usize;
172 let now = now_ms();
173
174 for entry in entries.flatten() {
175 let path = entry.path();
176 let Some(name) = path.file_name().and_then(|name| name.to_str()) else {
177 continue;
178 };
179 if !name.ends_with(".meta.json") {
180 continue;
181 }
182
183 let meta = fs::read_to_string(&path)
184 .ok()
185 .and_then(|content| serde_json::from_str::<CacheMeta>(&content).ok());
186 let Some(meta) = meta else {
187 if fs::remove_file(&path).is_ok() {
188 removed += 1;
189 }
190 continue;
191 };
192
193 if now.saturating_sub(meta.fetched_at) <= CACHE_TTL_MS {
194 continue;
195 }
196
197 let hash = name.trim_end_matches(".meta.json");
198 let content = content_path(storage_dir, hash, &meta.extension);
199 let _ = fs::remove_file(content);
200 if fs::remove_file(&path).is_ok() {
201 removed += 1;
202 }
203 }
204
205 Ok(removed)
206}
207
208#[doc(hidden)]
209pub fn cache_content_path_for_url(storage_dir: &Path, url: &str, extension: &str) -> PathBuf {
210 content_path(storage_dir, &hash_url(url), extension)
211}
212
213#[doc(hidden)]
214pub fn cache_meta_path_for_url(storage_dir: &Path, url: &str) -> PathBuf {
215 meta_path(storage_dir, &hash_url(url))
216}
217
218#[doc(hidden)]
219pub fn is_private_ip_for_test(ip: IpAddr) -> bool {
220 is_private_ip(ip)
221}
222
/// Returns the `url_cache` directory inside `storage_dir`.
fn cache_dir(storage_dir: &Path) -> PathBuf {
    let mut dir = storage_dir.to_path_buf();
    dir.push("url_cache");
    dir
}
226
227fn hash_url(url: &str) -> String {
228 let digest = Sha256::digest(url.as_bytes());
229 format!("{digest:x}").chars().take(16).collect()
230}
231
232fn meta_path(storage_dir: &Path, hash: &str) -> PathBuf {
233 cache_dir(storage_dir).join(format!("{hash}.meta.json"))
234}
235
236fn content_path(storage_dir: &Path, hash: &str, extension: &str) -> PathBuf {
237 cache_dir(storage_dir).join(format!("{hash}{extension}"))
238}
239
240fn fresh_cached_path(
241 storage_dir: &Path,
242 hash: &str,
243 meta_file: &Path,
244) -> Result<Option<PathBuf>, UrlFetchError> {
245 if !meta_file.exists() {
246 return Ok(None);
247 }
248
249 let meta = match fs::read_to_string(meta_file)
250 .ok()
251 .and_then(|content| serde_json::from_str::<CacheMeta>(&content).ok())
252 {
253 Some(meta) => meta,
254 None => return Ok(None),
255 };
256 let age = now_ms().saturating_sub(meta.fetched_at);
257 let cached = content_path(storage_dir, hash, &meta.extension);
258 if age < CACHE_TTL_MS && cached.exists() {
259 return Ok(Some(cached));
260 }
261 Ok(None)
262}
263
264fn fetch_with_redirects(
265 start_url: &Url,
266 original_url: &str,
267 options: &UrlFetchOptions,
268) -> Result<HttpResponse, UrlFetchError> {
269 let client = build_client(options)?;
270 let mut current_url = start_url.clone();
271
272 for redirect_count in 0..=MAX_REDIRECTS {
273 validate_public_url(¤t_url, options)?;
274 let response = send_with_transient_retries(&client, ¤t_url)?;
275
276 if !response.status().is_redirection() {
277 return Ok(response);
278 }
279 if redirect_count == MAX_REDIRECTS {
280 return Err(UrlFetchError::new(format!(
281 "Too many redirects fetching {original_url}"
282 )));
283 }
284
285 let location = response
286 .headers()
287 .get(LOCATION)
288 .and_then(|value| value.to_str().ok())
289 .ok_or_else(|| {
290 UrlFetchError::new(format!(
291 "Redirect from {} missing Location header",
292 current_url.as_str()
293 ))
294 })?;
295 current_url = current_url.join(location).map_err(|error| {
296 UrlFetchError::new(format!(
297 "Invalid redirect Location '{location}' from {}: {error}",
298 current_url.as_str()
299 ))
300 })?;
301 }
302
303 Err(UrlFetchError::new(format!(
304 "Too many redirects fetching {original_url}"
305 )))
306}
307
308fn send_with_transient_retries(
317 client: &Client,
318 target: &Url,
319) -> Result<HttpResponse, UrlFetchError> {
320 let mut last_error: Option<reqwest::Error> = None;
321 for attempt in 0..=TRANSIENT_RETRY_ATTEMPTS {
322 let result = client
323 .get(target.clone())
324 .header(USER_AGENT, USER_AGENT_VALUE)
325 .header(ACCEPT, ACCEPT_HEADER)
326 .send();
327 match result {
328 Ok(response) => return Ok(response),
329 Err(error) => {
330 if attempt < TRANSIENT_RETRY_ATTEMPTS && is_transient_reqwest_error(&error) {
331 thread::sleep(Duration::from_millis(TRANSIENT_RETRY_BACKOFFS_MS[attempt]));
332 last_error = Some(error);
333 continue;
334 }
335 return Err(UrlFetchError::new(format!(
336 "Failed to fetch {}: {}",
337 target.as_str(),
338 reqwest_error_detail(&error)
339 )));
340 }
341 }
342 }
343 Err(UrlFetchError::new(format!(
346 "Failed to fetch {} after {} retries: {}",
347 target.as_str(),
348 TRANSIENT_RETRY_ATTEMPTS,
349 last_error
350 .as_ref()
351 .map(reqwest_error_detail)
352 .unwrap_or_else(|| "unknown transient error".to_string())
353 )))
354}
355
356fn is_transient_reqwest_error(error: &reqwest::Error) -> bool {
366 error.is_connect() || error.is_timeout() || error.is_request()
367}
368
369fn build_client(options: &UrlFetchOptions) -> Result<Client, UrlFetchError> {
370 let mut builder = Client::builder()
371 .redirect(Policy::none())
372 .connect_timeout(CONNECT_TIMEOUT);
373
374 for (host, address) in &options.connect_overrides {
375 builder = builder.resolve(host, *address);
376 }
377
378 builder
379 .build()
380 .map_err(|error| UrlFetchError::new(format!("Failed to build URL fetch client: {error}")))
381}
382
383fn validate_public_url(url: &Url, options: &UrlFetchOptions) -> Result<(), UrlFetchError> {
384 if url.scheme() != "http" && url.scheme() != "https" {
385 return Err(UrlFetchError::new(format!(
386 "Only http:// and https:// URLs are supported, got: {}:",
387 url.scheme()
388 )));
389 }
390 if options.allow_private {
391 return Ok(());
392 }
393
394 let host = url
395 .host_str()
396 .ok_or_else(|| UrlFetchError::new(format!("URL missing host: {url}")))?;
397 let host_for_parse = host
398 .trim_matches(['[', ']'])
399 .split('%')
400 .next()
401 .unwrap_or(host);
402
403 if let Ok(ip) = host_for_parse.parse::<IpAddr>() {
404 reject_private_ip(host, ip)?;
405 return Ok(());
406 }
407 if host_for_parse.contains(':') {
408 return Err(UrlFetchError::new(format!(
409 "Blocked private URL host {host} ({host_for_parse})"
410 )));
411 }
412
413 let addresses = resolve_host_ips(host_for_parse, url.port_or_known_default(), options)?;
414 if addresses.is_empty() {
415 return Err(UrlFetchError::new(format!(
416 "Failed to resolve URL host {host}"
417 )));
418 }
419 for ip in addresses {
420 reject_private_ip(host, ip)?;
421 }
422
423 Ok(())
429}
430
431fn resolve_host_ips(
432 host: &str,
433 port: Option<u16>,
434 options: &UrlFetchOptions,
435) -> Result<Vec<IpAddr>, UrlFetchError> {
436 if let Some((_, ips)) = options
437 .public_host_overrides
438 .iter()
439 .find(|(override_host, _)| override_host == host)
440 {
441 return Ok(ips.clone());
442 }
443
444 let port = port.unwrap_or(80);
445 let addrs = (host, port).to_socket_addrs().map_err(|error| {
446 UrlFetchError::new(format!("Failed to resolve URL host {host}: {error}"))
447 })?;
448 Ok(addrs.map(|addr| addr.ip()).collect())
449}
450
451fn reject_private_ip(host: &str, ip: IpAddr) -> Result<(), UrlFetchError> {
452 if is_private_ip(ip) {
453 return Err(UrlFetchError::new(format!(
454 "Blocked private URL host {host} ({ip})"
455 )));
456 }
457 Ok(())
458}
459
460fn is_private_ip(ip: IpAddr) -> bool {
461 match ip {
462 IpAddr::V4(ipv4) => is_private_ipv4(ipv4),
463 IpAddr::V6(ipv6) => is_private_ipv6(ipv6),
464 }
465}
466
/// Reports whether an IPv4 address falls in one of the blocked ranges.
fn is_private_ipv4(ip: Ipv4Addr) -> bool {
    matches!(
        ip.octets(),
        // 0.0.0.0/8, 10.0.0.0/8, 127.0.0.0/8
        [0 | 10 | 127, ..]
            // 172.16.0.0/12
            | [172, 16..=31, ..]
            // 192.168.0.0/16
            | [192, 168, ..]
            // 169.254.0.0/16
            | [169, 254, ..]
            // 100.64.0.0/10
            | [100, 64..=127, ..]
            // 198.18.0.0/15
            | [198, 18..=19, ..]
            // 224.0.0.0 and above
            | [224..=255, ..]
    )
}
483
484fn is_private_ipv6(ip: Ipv6Addr) -> bool {
485 let segments = ip.segments();
486 let top_six_zero = segments[..6].iter().all(|segment| *segment == 0);
487 let is_mapped = segments[..5].iter().all(|segment| *segment == 0) && segments[5] == 0xffff;
488 if is_mapped || top_six_zero {
489 let embedded = Ipv4Addr::new(
490 (segments[6] >> 8) as u8,
491 (segments[6] & 0xff) as u8,
492 (segments[7] >> 8) as u8,
493 (segments[7] & 0xff) as u8,
494 );
495 return is_private_ipv4(embedded);
496 }
497
498 let first = segments[0];
499 (0xfe80..=0xfebf).contains(&first) || (0xfc00..=0xfdff).contains(&first) || first >= 0xff00
500}
501
/// Maps a `Content-Type` header value to the cache file extension.
///
/// Only the part before the first `;` or `,` is considered; it is trimmed and
/// compared case-insensitively. Returns `None` for unsupported media types.
fn resolve_extension(content_type: &str) -> Option<&'static str> {
    const HTML_TYPES: [&str; 4] = [
        "text/html",
        "application/xhtml+xml",
        "application/vnd.github.html",
        "application/vnd.github+html",
    ];
    const MARKDOWN_TYPES: [&str; 7] = [
        "text/markdown",
        "text/x-markdown",
        "application/markdown",
        "application/vnd.github.raw",
        "application/vnd.github+raw",
        "application/vnd.github.v3.raw",
        "text/plain",
    ];

    let media_type = content_type
        .split([';', ','])
        .next()
        .unwrap_or("")
        .trim()
        .to_ascii_lowercase();
    let media_type = media_type.as_str();

    if HTML_TYPES.contains(&media_type) {
        Some(".html")
    } else if MARKDOWN_TYPES.contains(&media_type) {
        Some(".md")
    } else if media_type == "application/json" || media_type.ends_with("+json") {
        // `application/ld+json` and other structured-syntax JSON types are
        // covered by the suffix test.
        Some(".json")
    } else {
        None
    }
}
530
531enum BodyReadEvent {
532 Chunk(Vec<u8>),
533 Done,
534 Error(io::ErrorKind, String),
535}
536
537fn read_response_body(mut response: HttpResponse, url: &str) -> Result<Vec<u8>, UrlFetchError> {
538 let (tx, rx) = mpsc::channel();
539 thread::spawn(move || {
540 let mut buffer = [0u8; 16 * 1024];
541 loop {
542 match response.read(&mut buffer) {
543 Ok(0) => {
544 let _ = tx.send(BodyReadEvent::Done);
545 break;
546 }
547 Ok(n) => {
548 if tx.send(BodyReadEvent::Chunk(buffer[..n].to_vec())).is_err() {
549 break;
550 }
551 }
552 Err(error) => {
553 let kind = error.kind();
554 let message = error.to_string();
555 let _ = tx.send(BodyReadEvent::Error(kind, message));
556 break;
557 }
558 }
559 }
560 });
561
562 let mut chunks = Vec::new();
563 let mut total = 0u64;
564 loop {
565 match rx.recv_timeout(BODY_CHUNK_TIMEOUT) {
566 Ok(BodyReadEvent::Chunk(chunk)) => {
567 total += chunk.len() as u64;
568 if total > MAX_RESPONSE_BYTES {
569 return Err(UrlFetchError::new(format!(
570 "Response exceeded {MAX_RESPONSE_BYTES} bytes, aborted"
571 )));
572 }
573 chunks.extend_from_slice(&chunk);
574 }
575 Ok(BodyReadEvent::Done) => return Ok(chunks),
576 Ok(BodyReadEvent::Error(kind, _message)) if is_body_stall_kind(kind) => {
577 return Err(body_stall_error(url));
578 }
579 Ok(BodyReadEvent::Error(_, message)) => {
580 return Err(UrlFetchError::new(format!(
581 "Failed to read response body for {url}: {message}"
582 )));
583 }
584 Err(mpsc::RecvTimeoutError::Timeout) => return Err(body_stall_error(url)),
585 Err(mpsc::RecvTimeoutError::Disconnected) => {
586 return Err(UrlFetchError::new(format!(
587 "Failed to read response body for {url}: body reader stopped unexpectedly"
588 )));
589 }
590 }
591 }
592}
593
594fn body_stall_error(url: &str) -> UrlFetchError {
595 UrlFetchError::new(format!(
596 "Body read stalled (no data for {}ms) fetching {url}",
597 BODY_CHUNK_TIMEOUT.as_millis()
598 ))
599}
600
/// Reports whether an I/O error kind is treated as a body-read stall
/// (`TimedOut` or `WouldBlock`).
fn is_body_stall_kind(kind: io::ErrorKind) -> bool {
    kind == io::ErrorKind::TimedOut || kind == io::ErrorKind::WouldBlock
}
604
605fn atomic_write(
606 final_path: &Path,
607 bytes: &[u8],
608 options: &UrlFetchOptions,
609) -> Result<(), UrlFetchError> {
610 let parent = final_path.parent().unwrap_or_else(|| Path::new("."));
611 fs::create_dir_all(parent).map_err(|error| {
612 UrlFetchError::new(format!(
613 "Failed to create URL cache parent {}: {error}",
614 parent.display()
615 ))
616 })?;
617
618 let file_name = final_path
619 .file_name()
620 .and_then(|name| name.to_str())
621 .ok_or_else(|| {
622 UrlFetchError::new(format!("Invalid cache path: {}", final_path.display()))
623 })?;
624 let tmp_path = final_path.with_file_name(format!(
625 "{file_name}.tmp-{}-{}",
626 std::process::id(),
627 random_nonce()
628 ));
629
630 let write_result = (|| -> io::Result<()> {
631 let mut file = fs::File::create(&tmp_path)?;
632 file.write_all(bytes)?;
633 file.flush()?;
634 Ok(())
635 })();
636 if let Err(error) = write_result {
637 let _ = fs::remove_file(&tmp_path);
638 return Err(UrlFetchError::new(format!(
639 "Failed to write URL cache temp file {}: {error}",
640 tmp_path.display()
641 )));
642 }
643
644 if let Some(observer) = &options.atomic_write_observer {
645 observer(&tmp_path, final_path);
646 }
647
648 fs::rename(&tmp_path, final_path).map_err(|error| {
649 let _ = fs::remove_file(&tmp_path);
650 UrlFetchError::new(format!(
651 "Failed to finalize URL cache file {}: {error}",
652 final_path.display()
653 ))
654 })
655}
656
657fn random_nonce() -> String {
658 let mut bytes = [0u8; 8];
659 if getrandom::fill(&mut bytes).is_err() {
660 let fallback = now_ms() ^ u64::from(std::process::id());
661 bytes = fallback.to_le_bytes();
662 }
663 let mut out = String::with_capacity(bytes.len() * 2);
664 for byte in bytes {
665 use std::fmt::Write as _;
666 let _ = write!(out, "{byte:02x}");
667 }
668 out
669}
670
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// A clock set before the epoch yields `0`; a value that does not fit in
/// `u64` saturates to `u64::MAX`.
fn now_ms() -> u64 {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::default(),
    };
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
679
680fn reqwest_error_detail(error: &reqwest::Error) -> String {
681 if error.is_timeout() {
682 return format!("timeout: {error}");
683 }
684 if let Some(source) = error.source() {
685 return format!("{source}");
686 }
687 error.to_string()
688}