1use std::io::{self, Write};
32
33use digest::Digest;
34use rand::seq::SliceRandom;
35
/// Errors produced by this module.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// The requested hash algorithm is not SHA-1, SHA-256 or SHA-512.
    /// Carries the algorithm name after normalization (e.g. `md5`).
    #[error("unsupported algorithm: {0}")]
    UnsupportedAlgorithm(String),

    /// The digest of the data written through a `HashVerifier` did not match.
    /// `expected` is the hash the verifier was given; `actual` is the computed
    /// digest as lowercase hex.
    #[error("hash mismatch: expected {expected}, got {actual}")]
    HashMismatch { expected: String, actual: String },
}
47
/// Canonicalizes a hash-algorithm name for comparison.
///
/// Only ASCII letters and digits are kept, and letters are lowercased, so
/// `"SHA-256"`, `"sha_256"` and `"Sha256"` all become `"sha256"`. Every other
/// character (punctuation, whitespace, non-ASCII) is dropped.
pub fn normalize_algo(name: &str) -> String {
    name.chars()
        .filter(char::is_ascii_alphanumeric)
        .map(|ch| ch.to_ascii_lowercase())
        .collect()
}
61
62pub fn is_supported(algo: &str) -> bool {
64 matches!(normalize_algo(algo).as_str(), "sha1" | "sha256" | "sha512")
65}
66
/// Parses a `FETCHURL_SERVER` value into cache-server base URLs.
///
/// The value is a comma-separated list of double-quoted strings in HTTP
/// structured-field style, e.g. `"http://cache1", "http://cache2"`. Members
/// that are not quoted strings are ignored, as is anything after a string up
/// to the next comma (such as `;q=0.9` parameters).
pub fn parse_fetchurl_server(value: &str) -> Vec<String> {
    parse_sfv_string_list(value)
}
71
72pub fn encode_source_urls(urls: &[impl AsRef<str>]) -> String {
74 let strs: Vec<&str> = urls.iter().map(|s| s.as_ref()).collect();
75 encode_sfv_string_list(&strs)
76}
77
/// Serializes `strings` as a structured-field list of quoted strings.
///
/// Each item is wrapped in double quotes with `\` and `"` backslash-escaped,
/// and items are joined with `", "`. An empty slice yields an empty string.
/// Other characters (including non-ASCII) are emitted unchanged.
fn encode_sfv_string_list(strings: &[&str]) -> String {
    let mut out = String::new();
    for (idx, item) in strings.iter().enumerate() {
        if idx > 0 {
            out.push_str(", ");
        }
        out.push('"');
        for ch in item.chars() {
            if ch == '\\' || ch == '"' {
                out.push('\\');
            }
            out.push(ch);
        }
        out.push('"');
    }
    out
}
90
/// Parses a structured-field list, returning only its string members.
///
/// Members are separated by commas, with optional spaces/tabs before each.
/// A member starting with `"` is decoded as a string, where `\x` yields `x`.
/// Everything after its closing quote up to the next comma (e.g. `;q=0.9`
/// parameters) is discarded. Members that are not strings (tokens, numbers, …)
/// are skipped entirely.
///
/// The parser is deliberately lenient: an unterminated string is still
/// returned with whatever was read, and a trailing lone `\` is kept literally.
fn parse_sfv_string_list(input: &str) -> Vec<String> {
    /// Consumes characters up to and including the next `,` (or to the end).
    fn skip_past_comma(chars: &mut impl Iterator<Item = char>) {
        for c in chars {
            if c == ',' {
                break;
            }
        }
    }

    let mut results = Vec::new();
    // Walk `char`s, not bytes: casting each UTF-8 byte to `char` would split a
    // multi-byte character into several unrelated Latin-1 characters.
    let mut chars = input.chars().peekable();

    loop {
        while chars.next_if(|&c| matches!(c, ' ' | '\t')).is_some() {}

        match chars.next() {
            None => break,
            // Empty member: its separator has just been consumed.
            Some(',') => {}
            Some('"') => {
                let mut s = String::new();
                while let Some(c) = chars.next() {
                    match c {
                        // Escape: take the next char verbatim; a backslash at
                        // end of input is kept as-is.
                        '\\' => s.push(chars.next().unwrap_or('\\')),
                        '"' => break,
                        other => s.push(other),
                    }
                }
                results.push(s);
                skip_past_comma(&mut chars);
            }
            Some(_) => skip_past_comma(&mut chars),
        }
    }

    results
}
149
/// A single URL to try, together with the HTTP headers to send with it.
#[derive(Clone, Debug)]
pub struct FetchAttempt {
    /// URL to request.
    url: String,
    /// Extra request headers as `(name, value)` pairs, in order. Attempts for
    /// direct source URLs carry none.
    headers: Vec<(String, String)>,
}

impl FetchAttempt {
    /// The URL to request.
    pub fn url(&self) -> &str {
        self.url.as_str()
    }

    /// Headers to attach to the request, as `(name, value)` pairs.
    pub fn headers(&self) -> &[(String, String)] {
        self.headers.as_slice()
    }
}
170
171pub struct FetchSession {
182 attempts: Vec<FetchAttempt>,
183 current: usize,
184 algo: String,
185 hash: String,
186 done: bool,
187 success: bool,
188}
189
190impl FetchSession {
191 pub fn new(
197 algo: &str,
198 hash: &str,
199 source_urls: &[impl AsRef<str>],
200 ) -> Result<Self, Error> {
201 let algo = normalize_algo(algo);
202 if !is_supported(&algo) {
203 return Err(Error::UnsupportedAlgorithm(algo));
204 }
205
206 let servers_env = std::env::var("FETCHURL_SERVER").unwrap_or_default();
207 let servers = parse_fetchurl_server(&servers_env);
208
209 let source_header = if !source_urls.is_empty() {
210 Some(encode_source_urls(source_urls))
211 } else {
212 None
213 };
214
215 let mut attempts = Vec::new();
216
217 for server in servers {
219 let base = server.trim_end_matches('/');
220 let url = format!("{base}/api/fetchurl/{algo}/{hash}");
221 let mut headers = Vec::new();
222 if let Some(ref val) = source_header {
223 headers.push(("X-Source-Urls".to_string(), val.clone()));
224 }
225 attempts.push(FetchAttempt { url, headers });
226 }
227
228 let mut direct: Vec<String> = source_urls
230 .iter()
231 .map(|s| s.as_ref().to_string())
232 .collect();
233 direct.shuffle(&mut rand::thread_rng());
234 for url in direct {
235 attempts.push(FetchAttempt {
236 url,
237 headers: Vec::new(),
238 });
239 }
240
241 Ok(FetchSession {
242 attempts,
243 current: 0,
244 algo,
245 hash: hash.to_string(),
246 done: false,
247 success: false,
248 })
249 }
250
251 pub fn next_attempt(&mut self) -> Option<FetchAttempt> {
260 if self.done || self.current >= self.attempts.len() {
261 return None;
262 }
263 let attempt = self.attempts[self.current].clone();
264 self.current += 1;
265 Some(attempt)
266 }
267
268 pub fn report_success(&mut self) {
270 self.done = true;
271 self.success = true;
272 }
273
274 pub fn report_partial(&mut self) {
277 self.done = true;
278 }
279
280 pub fn succeeded(&self) -> bool {
282 self.success
283 }
284
285 pub fn verifier<W: Write>(&self, writer: W) -> HashVerifier<W> {
290 HashVerifier::new(&self.algo, &self.hash, writer)
291 }
292}
293
294enum HasherInner {
297 Sha1(sha1::Sha1),
298 Sha256(sha2::Sha256),
299 Sha512(sha2::Sha512),
300}
301
302impl HasherInner {
303 fn new(algo: &str) -> Option<Self> {
304 match algo {
305 "sha1" => Some(HasherInner::Sha1(sha1::Sha1::new())),
306 "sha256" => Some(HasherInner::Sha256(sha2::Sha256::new())),
307 "sha512" => Some(HasherInner::Sha512(sha2::Sha512::new())),
308 _ => None,
309 }
310 }
311
312 fn update(&mut self, data: &[u8]) {
313 match self {
314 HasherInner::Sha1(h) => h.update(data),
315 HasherInner::Sha256(h) => h.update(data),
316 HasherInner::Sha512(h) => h.update(data),
317 }
318 }
319
320 fn finalize(self) -> Vec<u8> {
321 match self {
322 HasherInner::Sha1(h) => h.finalize().to_vec(),
323 HasherInner::Sha256(h) => h.finalize().to_vec(),
324 HasherInner::Sha512(h) => h.finalize().to_vec(),
325 }
326 }
327}
328
/// Renders `bytes` as lowercase hexadecimal, two digits per byte.
fn to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
332
333pub struct HashVerifier<W: Write> {
338 inner: W,
339 hasher: HasherInner,
340 expected_hash: String,
341 bytes_written: u64,
342}
343
344impl<W: Write> HashVerifier<W> {
345 fn new(algo: &str, expected_hash: &str, inner: W) -> Self {
346 let hasher =
347 HasherInner::new(algo).expect("HashVerifier created with validated algo");
348 HashVerifier {
349 inner,
350 hasher,
351 expected_hash: expected_hash.to_string(),
352 bytes_written: 0,
353 }
354 }
355
356 pub fn bytes_written(&self) -> u64 {
358 self.bytes_written
359 }
360
361 pub fn finish(self) -> Result<W, Error> {
365 let actual = to_hex(&self.hasher.finalize());
366 if actual == self.expected_hash {
367 Ok(self.inner)
368 } else {
369 Err(Error::HashMismatch {
370 expected: self.expected_hash,
371 actual,
372 })
373 }
374 }
375}
376
377impl<W: Write> Write for HashVerifier<W> {
378 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
379 let n = self.inner.write(buf)?;
380 self.hasher.update(&buf[..n]);
381 self.bytes_written += n as u64;
382 Ok(n)
383 }
384
385 fn flush(&mut self) -> io::Result<()> {
386 self.inner.flush()
387 }
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes access to the process-wide `FETCHURL_SERVER` variable.
    ///
    /// libtest runs tests on parallel threads. Without this lock, one session
    /// test can read the value another test has just set (e.g. a single cache
    /// instead of two), and concurrent `set_var` is a data race.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Sets `FETCHURL_SERVER` to `value` and returns the lock guard.
    ///
    /// Bind the guard to a named variable (`_env`, not `_`) so it is held until
    /// the end of the test, after the session has read the variable.
    fn with_fetchurl_server(value: &str) -> MutexGuard<'static, ()> {
        // A panicking test poisons the lock; the protected data is `()`, so
        // recovering the guard is safe and keeps later tests running.
        let guard = ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner);
        // SAFETY: every test in this module that sets or reads the environment
        // does so while holding `ENV_LOCK`, so no other thread in this module
        // touches it concurrently. Tests elsewhere in the crate that touch the
        // environment must take this lock too.
        unsafe { std::env::set_var("FETCHURL_SERVER", value) };
        guard
    }

    fn sha256_hex(data: &[u8]) -> String {
        to_hex(&sha2::Sha256::digest(data))
    }

    #[test]
    fn test_normalize_algo() {
        assert_eq!(normalize_algo("SHA-256"), "sha256");
        assert_eq!(normalize_algo("sha256"), "sha256");
        assert_eq!(normalize_algo("SHA512"), "sha512");
        assert_eq!(normalize_algo("md5"), "md5");
    }

    #[test]
    fn test_is_supported() {
        assert!(is_supported("sha256"));
        assert!(is_supported("SHA-256"));
        assert!(is_supported("sha1"));
        assert!(is_supported("sha512"));
        assert!(!is_supported("md5"));
    }

    #[test]
    fn test_sfv_encode() {
        assert_eq!(
            encode_sfv_string_list(&["https://a.com", "https://b.com"]),
            r#""https://a.com", "https://b.com""#
        );
    }

    #[test]
    fn test_sfv_parse() {
        let parsed = parse_sfv_string_list(r#""https://a.com", "https://b.com""#);
        assert_eq!(parsed, vec!["https://a.com", "https://b.com"]);
    }

    #[test]
    fn test_sfv_roundtrip() {
        let urls = &[
            "https://cdn.example.com/file.tar.gz",
            "https://mirror.org/archive.tgz",
        ];
        let encoded = encode_sfv_string_list(urls);
        let decoded = parse_sfv_string_list(&encoded);
        assert_eq!(decoded, urls);
    }

    #[test]
    fn test_sfv_parse_with_parameters() {
        let parsed = parse_sfv_string_list(r#""https://a.com";q=0.9, "https://b.com""#);
        assert_eq!(parsed, vec!["https://a.com", "https://b.com"]);
    }

    #[test]
    fn test_hash_verifier_success() {
        let data = b"hello world";
        let hash = sha256_hex(data);

        let mut output = Vec::new();
        {
            let mut verifier = HashVerifier::new("sha256", &hash, &mut output);
            verifier.write_all(data).unwrap();
            assert_eq!(verifier.bytes_written(), data.len() as u64);
            verifier.finish().unwrap();
        }
        assert_eq!(output, data);
    }

    #[test]
    fn test_hash_verifier_mismatch() {
        let data = b"hello world";
        let wrong_hash = sha256_hex(b"wrong");

        let mut output = Vec::new();
        let mut verifier = HashVerifier::new("sha256", &wrong_hash, &mut output);
        verifier.write_all(data).unwrap();
        let err = verifier.finish().unwrap_err();
        assert!(matches!(err, Error::HashMismatch { .. }));
    }

    #[test]
    fn test_session_unsupported_algo() {
        let err = FetchSession::new("md5", "abc", &["http://src"]);
        assert!(matches!(err, Err(Error::UnsupportedAlgorithm(_))));
    }

    #[test]
    fn test_session_attempt_ordering() {
        let hash = sha256_hex(b"test");
        let _env = with_fetchurl_server("\"http://cache1\", \"http://cache2\"");
        let mut session = FetchSession::new("sha256", &hash, &["http://src1"]).unwrap();

        let a1 = session.next_attempt().unwrap();
        assert!(a1.url().starts_with("http://cache1/api/fetchurl/sha256/"));
        assert!(!a1.headers().is_empty());

        let a2 = session.next_attempt().unwrap();
        assert!(a2.url().starts_with("http://cache2/api/fetchurl/sha256/"));

        let a3 = session.next_attempt().unwrap();
        assert_eq!(a3.url(), "http://src1");
        assert!(a3.headers().is_empty());

        assert!(session.next_attempt().is_none());
        assert!(!session.succeeded());
    }

    #[test]
    fn test_session_success_stops() {
        let hash = sha256_hex(b"test");
        let _env = with_fetchurl_server("\"http://cache\"");
        let mut session = FetchSession::new("sha256", &hash, &["http://src"]).unwrap();

        let _ = session.next_attempt().unwrap();
        session.report_success();
        assert!(session.succeeded());
        assert!(session.next_attempt().is_none());
    }

    #[test]
    fn test_session_partial_stops() {
        let hash = sha256_hex(b"test");
        let _env = with_fetchurl_server("\"http://cache\"");
        let mut session = FetchSession::new("sha256", &hash, &["http://src"]).unwrap();

        let _ = session.next_attempt().unwrap();
        session.report_partial();
        assert!(!session.succeeded());
        assert!(session.next_attempt().is_none());
    }

    #[test]
    fn test_session_server_has_source_header() {
        let hash = sha256_hex(b"test");
        let _env = with_fetchurl_server("\"http://cache\"");
        let mut session =
            FetchSession::new("sha256", &hash, &["http://src1", "http://src2"]).unwrap();

        let attempt = session.next_attempt().unwrap();
        let source_header = attempt
            .headers()
            .iter()
            .find(|(k, _)| k == "X-Source-Urls")
            .map(|(_, v)| v.clone())
            .unwrap();

        let parsed = parse_sfv_string_list(&source_header);
        assert!(parsed.contains(&"http://src1".to_string()));
        assert!(parsed.contains(&"http://src2".to_string()));
    }
}