1#![forbid(unsafe_code)]
10
11use reqwest::blocking::Client;
12use reqwest::header::RANGE;
13use reqwest::StatusCode;
14use std::fs;
15use std::path::{Path, PathBuf};
16use std::time::Duration;
17use vanta_core::{Area, VtaError, VtaResult};
18
19pub struct Downloader {
21 client: Client,
22 retries: u32,
23}
24
25impl Downloader {
26 pub fn new() -> VtaResult<Downloader> {
28 let client = Client::builder()
29 .user_agent(concat!("vanta/", env!("CARGO_PKG_VERSION")))
30 .connect_timeout(Duration::from_secs(30))
31 .build()
32 .map_err(|e| VtaError::new(Area::Net, 4, format!("building HTTP client: {e}")))?;
33 Ok(Downloader { client, retries: 3 })
34 }
35
36 pub fn with_retries(mut self, retries: u32) -> Self {
38 self.retries = retries;
39 self
40 }
41
42 pub fn download(&self, url: &str, dest: &Path) -> VtaResult<()> {
45 let mut last: Option<VtaError> = None;
46 for attempt in 0..=self.retries {
47 match self.fetch_one(url, dest) {
48 Ok(()) => return Ok(()),
49 Err(e) => {
50 last = Some(e);
51 if attempt < self.retries {
52 std::thread::sleep(backoff(attempt));
53 }
54 }
55 }
56 }
57 Err(last.unwrap_or_else(|| VtaError::new(Area::Net, 1, format!("download failed: {url}"))))
58 }
59
60 pub fn download_any(&self, urls: &[String], dest: &Path) -> VtaResult<()> {
64 let mut last: Option<VtaError> = None;
65 for url in urls {
66 match self.download(url, dest) {
67 Ok(()) => return Ok(()),
68 Err(e) => last = Some(e),
69 }
70 }
71 Err(last.unwrap_or_else(|| {
72 VtaError::new(Area::Net, 1, "no URLs supplied to download_any".to_string())
73 }))
74 }
75
76 fn fetch_one(&self, url: &str, dest: &Path) -> VtaResult<()> {
77 let part = part_path(dest);
78 let have = fs::metadata(&part).map(|m| m.len()).unwrap_or(0);
79
80 let mut req = self.client.get(url);
81 if have > 0 {
82 req = req.header(RANGE, format!("bytes={have}-"));
83 }
84 let mut resp = req
85 .send()
86 .map_err(|e| VtaError::new(Area::Net, 1, format!("requesting {url}: {e}")))?;
87
88 let status = resp.status();
89 let resuming = have > 0 && status == StatusCode::PARTIAL_CONTENT;
90 if !(status.is_success() || resuming) {
91 return Err(VtaError::new(
92 Area::Net,
93 1,
94 format!("HTTP {status} for {url}"),
95 ));
96 }
97
98 if let Some(parent) = part.parent() {
99 fs::create_dir_all(parent).map_err(|e| io(parent, e))?;
100 }
101 let mut file = if resuming {
102 fs::OpenOptions::new()
103 .append(true)
104 .open(&part)
105 .map_err(|e| io(&part, e))?
106 } else {
107 let _ = fs::remove_file(&part);
108 fs::File::create(&part).map_err(|e| io(&part, e))?
109 };
110
111 std::io::copy(&mut resp, &mut file)
112 .map_err(|e| VtaError::new(Area::Net, 1, format!("writing {}: {e}", part.display())))?;
113 file.sync_all().ok();
114 fs::rename(&part, dest).map_err(|e| io(dest, e))?;
115 Ok(())
116 }
117}
118
/// Returns the temporary path used while `dest` is being downloaded: the full path
/// string of `dest` with `.part` appended (e.g. `/tmp/a.bin` -> `/tmp/a.bin.part`).
fn part_path(dest: &Path) -> PathBuf {
    let mut name = dest.as_os_str().to_owned();
    name.push(".part");
    name.into()
}
124
/// Delay before retrying after failed attempt number `attempt` (0-based).
///
/// Starts at 0.5 s and doubles per attempt (0.5, 1, 2, 4, 8 s). From attempt 4
/// onward it stays capped at 8 s.
fn backoff(attempt: u32) -> Duration {
    let doublings = attempt.min(4);
    Duration::from_millis(500u64 << doublings)
}
130
131fn io(path: &Path, e: std::io::Error) -> VtaError {
132 VtaError::new(Area::Net, 1, format!("{}: {e}", path.display()))
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn client_builds() {
        assert!(Downloader::new().is_ok());
    }

    #[test]
    fn part_path_appends_suffix() {
        assert_eq!(
            part_path(Path::new("/tmp/a.bin")),
            PathBuf::from("/tmp/a.bin.part")
        );
        // Relative destinations get the same suffix.
        assert_eq!(part_path(Path::new("a.bin")), PathBuf::from("a.bin.part"));
    }

    #[test]
    fn backoff_doubles_then_caps() {
        assert_eq!(backoff(0), Duration::from_millis(500));
        assert_eq!(backoff(1), Duration::from_secs(1));
        assert_eq!(backoff(3), Duration::from_secs(4));
        // Capped from attempt 4 onward.
        assert_eq!(backoff(4), Duration::from_secs(8));
        assert_eq!(backoff(10), Duration::from_secs(8));
    }

    #[test]
    fn download_any_empty_errors() {
        let d = Downloader::new().unwrap();
        assert!(d.download_any(&[], Path::new("/tmp/none")).is_err());
    }
}