1#![doc = include_str!("../README.md")]
2
3use std::fmt;
4use std::fmt::Write as _;
5use std::future::Future;
6use std::path::PathBuf;
7use std::pin::pin;
8use std::time::{Duration, Instant};
9
10use futures_core::Stream;
11use futures_util::TryStreamExt;
12use md5::Md5;
13use sha1::{Digest, Sha1};
14use url::Url;
15
16mod error;
17mod http;
18pub mod local;
19mod range;
20#[cfg(feature = "s3")]
21pub mod s3;
22
23pub use error::Error;
24use http::is_retryable;
25pub use http::{Downloader, DownloaderBuilder};
26
/// Minimum time that must pass since the previous progress event before
/// `run_download` emits another [`Outcome::Progress`] while streaming chunks.
/// The first chunk-driven progress event of a download is not throttled.
const PROGRESS_INTERVAL: Duration = Duration::from_micros(500_000);
28
/// An expected file digest, tagged with the algorithm that produced it.
///
/// The payload is the hex-encoded digest string, stored exactly as supplied.
#[derive(Debug, Clone)]
pub enum Checksum {
    /// Hex-encoded SHA-1 digest.
    Sha1(String),
    /// Hex-encoded MD5 digest.
    Md5(String),
}

impl Checksum {
    /// Returns the stored hex digest, whichever algorithm it belongs to.
    pub fn hex(&self) -> &str {
        let (Checksum::Sha1(digest) | Checksum::Md5(digest)) = self;
        digest
    }

    /// Returns the lowercase algorithm name (`"sha1"` or `"md5"`), as used in
    /// error messages.
    pub fn algorithm(&self) -> &'static str {
        match self {
            Self::Md5(_) => "md5",
            Self::Sha1(_) => "sha1",
        }
    }

    /// Builds a checksum of the same algorithm as `self` but carrying `value`
    /// as its digest.
    pub fn with_value(&self, value: String) -> Checksum {
        // Pick the tuple-variant constructor for our algorithm, then apply it.
        let rebuild: fn(String) -> Checksum = match self {
            Self::Md5(_) => Self::Md5,
            Self::Sha1(_) => Self::Sha1,
        };
        rebuild(value)
    }
}
63
64pub enum Hasher {
68 None,
69 Sha1(Sha1),
70 Md5(Md5),
71}
72
73impl Hasher {
74 pub fn for_checksum(checksum: Option<&Checksum>) -> Self {
75 match checksum {
76 Some(Checksum::Sha1(_)) => Hasher::Sha1(Sha1::new()),
77 Some(Checksum::Md5(_)) => Hasher::Md5(Md5::new()),
78 None => Hasher::None,
79 }
80 }
81
82 pub fn update(&mut self, bytes: &[u8]) {
83 match self {
84 Hasher::Sha1(h) => h.update(bytes),
85 Hasher::Md5(h) => h.update(bytes),
86 Hasher::None => {}
87 }
88 }
89
90 pub fn finalize_hex(self) -> Option<String> {
91 match self {
92 Hasher::Sha1(h) => Some(to_hex(&h.finalize())),
93 Hasher::Md5(h) => Some(to_hex(&h.finalize())),
94 Hasher::None => None,
95 }
96 }
97}
98
/// Encodes `bytes` as a lowercase hexadecimal string, two characters per byte.
fn to_hex(bytes: &[u8]) -> String {
    bytes
        .iter()
        .fold(String::with_capacity(bytes.len() * 2), |mut out, byte| {
            write!(out, "{byte:02x}").expect("writing to String cannot fail");
            out
        })
}
106
/// Metadata describing one file that can be downloaded.
///
/// Implementors must be `Clone` because `run_download` clones the item into
/// every `Outcome::Progress` event it emits.
pub trait Source: Clone + Send + 'static {
    /// Display name; shown in `Outcome`'s `Display` output and passed to sinks
    /// as `Target::name`.
    fn name(&self) -> &str;
    /// Expected length in bytes. `run_download` reports any other final byte
    /// count as `Error::SizeMismatch`.
    fn size(&self) -> u64;
    /// Expected digest, if known. Defaults to `None`, in which case completed
    /// downloads are reported as unverified.
    fn checksum(&self) -> Option<Checksum> {
        None
    }
}
118
/// A ready-to-fetch file whose URL is already known.
///
/// Used with `drive_downloads`, which resolves each item to its own `url`.
#[derive(Debug, Clone)]
pub struct Download {
    /// Address the file is fetched from.
    pub url: Url,
    /// Expected length in bytes.
    pub size: u64,
    /// Expected digest, if any; `None` means the result is reported unverified.
    pub checksum: Option<Checksum>,
    /// Display name handed to sinks and shown in outcomes.
    pub name: String,
}
128
129impl Source for Download {
130 fn name(&self) -> &str {
131 &self.name
132 }
133
134 fn size(&self) -> u64 {
135 self.size
136 }
137
138 fn checksum(&self) -> Option<Checksum> {
139 self.checksum.clone()
140 }
141}
142
143#[derive(Clone, Copy)]
147pub struct Target<'a> {
148 pub name: &'a str,
149 pub size: u64,
150 pub checksum: Option<&'a Checksum>,
151}
152
/// Destination that receives the bytes of a single download.
///
/// `run_download` drives a sink through this sequence: one `prepare` call,
/// zero or more `write_chunk` calls, possibly `restart` calls, and `finalize`
/// on success. On any error, the sink is dropped without `finalize`.
pub trait Sink: Send {
    /// Value reported in `Outcome::Downloaded` / `Outcome::Skipped` that tells
    /// the caller where the data ended up.
    type Location: Send + 'static;

    /// Called once, before any request is made.
    ///
    /// Return `Prepared::Skip` when the target is already present. Otherwise
    /// return `Prepared::Resume` with the number of bytes already stored and a
    /// `Hasher` that has consumed exactly those bytes. Use `received: 0` to
    /// start from scratch.
    fn prepare(
        &mut self,
        target: Target<'_>,
    ) -> impl Future<Output = Result<Prepared<Self::Location>, Error>> + Send;

    /// Appends the next chunk of the response body, in order.
    fn write_chunk(&mut self, chunk: &[u8]) -> impl Future<Output = Result<(), Error>> + Send;

    /// Discards everything written so far.
    ///
    /// Called when a resume request is answered with `200 OK` (the server
    /// ignored the range). Data is then rewritten from byte 0.
    fn restart(&mut self) -> impl Future<Output = Result<(), Error>> + Send;

    /// Commits the download and returns its final location.
    ///
    /// Called only after the byte count matched the expected size, and after
    /// the checksum matched when one could be checked.
    fn finalize(self) -> impl Future<Output = Result<Self::Location, Error>> + Send;
}
167
/// Creates a fresh `Sink` for each item processed by `drive`.
pub trait SinkFactory: Send {
    /// Sink type produced. It is `'static` so it can be moved into the per-item
    /// `run_download` stream.
    type Sink: Sink<Location = Self::Location> + 'static;
    /// Location type reported in outcomes; must match `Self::Sink::Location`.
    type Location: DownloadLocation;

    /// Builds the sink for `target`.
    ///
    /// `drive` calls this once per item whose URL resolved successfully. An
    /// error here is reported as `Outcome::Failed` for that item only.
    fn make(
        &mut self,
        target: Target<'_>,
    ) -> impl Future<Output = Result<Self::Sink, Error>> + Send;
}
179
/// A place a finished (or skipped) download can be reported at, for example in
/// [`Outcome`]'s `Display` output.
pub trait DownloadLocation: Send + 'static {
    /// Writes a human-readable rendering of this location into `f`.
    fn fmt_location(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result;
}

/// Filesystem paths are rendered through `Path::display`.
impl DownloadLocation for PathBuf {
    fn fmt_location(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let shown = self.display();
        f.write_fmt(format_args!("{shown}"))
    }
}

/// Plain strings (e.g. remote object keys or URLs) are written verbatim.
impl DownloadLocation for String {
    fn fmt_location(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
199
/// One event produced while driving downloads.
///
/// `M` is the `Source` item type. `L` is the location type reported by the
/// sink, defaulting to `PathBuf`.
#[derive(Debug)]
pub enum Outcome<M, L = PathBuf> {
    /// The file was fully written and the sink finalized.
    Downloaded {
        file: M,
        /// Value returned by `Sink::finalize`.
        location: L,
        /// `true` only when an expected checksum was available, a digest was
        /// computed, and the two matched.
        verified: bool,
    },
    /// This item failed. `drive` continues with the next item.
    Failed {
        file: M,
        error: Error,
    },
    /// Bytes written so far for `file`. Emitted at most once per
    /// `PROGRESS_INTERVAL` while streaming, plus once up front when resuming.
    Progress {
        file: M,
        received: u64,
        total: u64,
    },
    /// The sink reported the file as already present (`Prepared::Skip`).
    Skipped {
        file: M,
        location: L,
    },
    /// The input item stream itself returned an error. No further outcomes
    /// follow.
    StreamFailed {
        error: Error,
    },
}
229
230impl<M: Source, L: DownloadLocation> fmt::Display for Outcome<M, L> {
231 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232 match self {
233 Self::Progress {
234 file,
235 received,
236 total,
237 } => {
238 let pct = if *total == 0 {
239 100.0
240 } else {
241 (*received as f64 / *total as f64) * 100.0
242 };
243 write!(f, "{}: {pct:.1}% ({received} / {total} bytes)", file.name())
244 }
245 Self::Downloaded {
246 location, verified, ..
247 } => {
248 write!(f, "downloaded ")?;
249 location.fmt_location(f)?;
250 write!(
251 f,
252 " ({})",
253 if *verified { "verified" } else { "unverified" }
254 )
255 }
256 Self::Failed { file, error } => write!(f, "failed {}: {error}", file.name()),
257 Self::Skipped { location, .. } => {
258 write!(f, "skipped ")?;
259 location.fmt_location(f)?;
260 write!(f, " (already present)")
261 }
262 Self::StreamFailed { error } => write!(f, "stream failed: {error}"),
263 }
264 }
265}
266
/// A sink's answer to `Sink::prepare`: how `run_download` should proceed.
pub enum Prepared<L> {
    /// The target already exists at `location`. Nothing is downloaded and
    /// `Outcome::Skipped` is emitted.
    Skip { location: L },
    /// Continue from byte offset `received` (0 = start fresh).
    ///
    /// `partial` must have hashed exactly those `received` bytes. If it is
    /// `Hasher::None` while a checksum is expected, the result is reported
    /// unverified, unless a `200 OK` restart replaces the hasher.
    Resume { received: u64, partial: Hasher },
}
275
/// Downloads every item from `items`, one after another, yielding the
/// resulting [`Outcome`]s as a stream.
///
/// For each item:
/// - `resolve` maps the item to its URL.
/// - `factory` builds a sink for it.
/// - `run_download` streams the data; its events are forwarded until that
///   download finishes, before the next item is pulled.
///
/// A failure in `resolve`, `factory.make`, or the download itself yields
/// `Outcome::Failed` for that item and processing continues. An error from
/// `items` itself yields `Outcome::StreamFailed` and ends the stream.
pub fn drive<'a, M, F, R>(
    http: &'a Downloader,
    items: impl Stream<Item = Result<M, Error>> + Send + 'a,
    mut resolve: R,
    factory: F,
) -> impl Stream<Item = Outcome<M, F::Location>> + Send + 'a
where
    M: Source,
    F: SinkFactory + 'a,
    R: FnMut(&M) -> Result<Url, Error> + Send + 'a,
{
    async_stream::stream! {
        let mut factory = factory;
        let mut items = pin!(items);
        loop {
            let item = match items.try_next().await {
                Ok(Some(item)) => item,
                Ok(None) => break,
                Err(error) => {
                    // The source of items is broken; nothing more can be processed.
                    yield Outcome::StreamFailed { error };
                    return;
                }
            };
            let url = match resolve(&item) {
                Ok(u) => u,
                Err(error) => {
                    yield Outcome::Failed { file: item, error };
                    continue;
                }
            };
            let checksum = item.checksum();
            let target = Target {
                name: item.name(),
                size: item.size(),
                checksum: checksum.as_ref(),
            };
            let sink = match factory.make(target).await {
                Ok(s) => s,
                Err(error) => {
                    yield Outcome::Failed { file: item, error };
                    continue;
                }
            };
            // `item` is moved into `run_download`; keep a copy to attribute a
            // terminal error to the right file.
            let item_for_err = item.clone();
            let mut events = pin!(run_download(http, url, item, sink));
            loop {
                match events.try_next().await {
                    Ok(Some(outcome)) => yield outcome,
                    Ok(None) => break,
                    Err(error) => {
                        // A download error ends this item only; move on to the next.
                        yield Outcome::Failed {
                            file: item_for_err,
                            error,
                        };
                        break;
                    }
                }
            }
        }
    }
}
342
343pub fn drive_downloads<'a, F>(
346 http: &'a Downloader,
347 items: impl Stream<Item = Result<Download, Error>> + Send + 'a,
348 factory: F,
349) -> impl Stream<Item = Outcome<Download, F::Location>> + Send + 'a
350where
351 F: SinkFactory + 'a,
352{
353 drive(http, items, |download| Ok(download.url.clone()), factory)
354}
355
356pub fn run_download<M, S>(
361 http: &Downloader,
362 url: Url,
363 item: M,
364 sink: S,
365) -> impl Stream<Item = Result<Outcome<M, S::Location>, Error>> + Send + '_
366where
367 M: Source,
368 S: Sink + Send + 'static,
369{
370 async_stream::try_stream! {
371 let mut sink = sink;
372 let size = item.size();
373 let checksum = item.checksum();
374
375 let (mut received, mut hasher) = match sink
376 .prepare(Target {
377 name: item.name(),
378 size,
379 checksum: checksum.as_ref(),
380 })
381 .await?
382 {
383 Prepared::Skip { location } => {
384 yield Outcome::Skipped { file: item, location };
385 return;
386 }
387 Prepared::Resume { received, partial } => (received, partial),
388 };
389
390 let mut last_progress: Option<Instant> = None;
391 let mut attempts_left = http.max_attempts();
392 let mut delay = http.backoff();
393
394 if received > 0 && received < size {
395 yield Outcome::Progress {
396 file: item.clone(),
397 received,
398 total: size,
399 };
400 last_progress = Some(Instant::now());
401 }
402
403 'download: while received < size {
404 let mut response = http.get_response_range(url.clone(), received).await?;
405
406 if received > 0 {
407 match response.status() {
408 reqwest::StatusCode::OK => {
409 sink.restart().await?;
410 received = 0;
411 hasher = Hasher::for_checksum(checksum.as_ref());
412 attempts_left = http.max_attempts();
413 delay = http.backoff();
414 }
415 reqwest::StatusCode::PARTIAL_CONTENT => {
416 range::validate_content_range(&response, received, size, &url)?;
417 }
418 status => {
419 Err(Error::InvalidRangeResponse {
420 url: url.to_string(),
421 details: format!("expected 200 or 206 for resume, got {status}"),
422 })?;
423 }
424 }
425 }
426
427 loop {
428 let chunk = match response.chunk().await {
429 Ok(Some(chunk)) => chunk,
430 Ok(None) => break 'download,
431 Err(e) => {
432 let err = Error::from(e);
433 if attempts_left > 1 && is_retryable(&err) {
434 attempts_left -= 1;
435 tokio::time::sleep(delay).await;
436 delay = delay.saturating_mul(2);
437 continue 'download;
438 }
439 Err(err)?;
440 unreachable!();
441 }
442 };
443
444 sink.write_chunk(&chunk).await?;
445 hasher.update(&chunk);
446 received += chunk.len() as u64;
447 attempts_left = http.max_attempts();
448 delay = http.backoff();
449
450 let emit = match last_progress {
451 None => true,
452 Some(t) => t.elapsed() >= PROGRESS_INTERVAL,
453 };
454 if emit {
455 yield Outcome::Progress {
456 file: item.clone(),
457 received,
458 total: size,
459 };
460 last_progress = Some(Instant::now());
461 }
462 }
463 }
464
465 if received != size {
466 Err(Error::SizeMismatch {
467 url: url.to_string(),
468 expected: size,
469 actual: received,
470 })?;
471 }
472
473 let verified = match (checksum.as_ref(), hasher.finalize_hex()) {
474 (Some(expected), Some(actual)) => {
475 if actual != expected.hex() {
476 Err(Error::ChecksumMismatch {
477 algorithm: expected.algorithm(),
478 url: url.to_string(),
479 expected: expected.hex().to_owned(),
480 actual,
481 })?;
482 }
483 true
484 }
485 _ => false,
486 };
487
488 let location = sink.finalize().await?;
489 yield Outcome::Downloaded { file: item, location, verified };
490 }
491}