rhtdl/lib.rs
1#![doc = include_str!("../README.md")]
2
3#![forbid(unsafe_code)]
4// TODO: integration test with traffic control (tc(8)) to simulate poor network connection
5
6use std::fmt;
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, Instant};
9use std::ops::Deref;
10use std::net::SocketAddr;
11use reqwest::{Url, Client, Response};
12use headers::{ContentRange, Range, HeaderMap, HeaderMapExt};
13use tracing::{trace, error, warn, info, info_span, debug};
14
15mod integrity;
16mod error;
17mod output;
18
19use crate::integrity::Integrity;
20pub use crate::output::Output;
21pub use crate::error::Error;
22
/// Convenience alias: all fallible operations in this crate use [`Error`].
pub type Result<T> = std::result::Result<T, Error>;

// this could be made into a wrapper struct with 0xFFFF reserved as a sentinel
// value, but i don't think it's worth it just to save 3 bytes.
/// Amount of times to retry an operation.
///
/// `None` = infinite retries.
///
/// note that this is the number of retries, not the number of *attempts*.
pub type RetryCount = Option<u16>;
33
/// Tracks how much of a file has been downloaded.
///
/// Allows resuming downloads and showing a progress bar.
///
/// This may be written to a file or database to allow gracefully resuming if a program unexpectedly terminates.
#[derive(Default, Copy, Clone, Debug)]
pub struct Progress {
    /// number of bytes received so far.
    pub bytes_read: u64,
    /// total size of the resource in bytes, if known.
    ///
    /// `None` until a `Content-Range` or `Content-Length` header reveals the size.
    pub bytes_total: Option<u64>,
}
44
45impl fmt::Display for Progress {
46 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
47 if let Some(total) = self.bytes_total {
48 write!(f, "{}/{}", self.bytes_read, total)?;
49 } else {
50 write!(f, "{}/??", self.bytes_read)?;
51 }
52 Ok(())
53 }
54}
55
56impl Progress {
57 fn complete(&self) -> bool {
58 if let Some(total) = self.bytes_total {
59 self.bytes_read >= total
60 } else {
61 false
62 }
63 }
64
65 /// Recover the progress of a download based on the length of its output file.
66 ///
67 /// `total_len` will be set to `None`.
68 pub fn from_file_len(file: &std::fs::File) -> std::io::Result<Progress> {
69 Ok(Progress{
70 bytes_read: file.metadata()?.len(),
71 bytes_total: None,
72 })
73 }
74}
75
/// parameters for how a file should be downloaded.
///
/// the default was carefully selected to be suitable for most use cases over
/// virtually any type of network.
///
/// unlike [`Options`], this may be reused for multiple downloads.
#[derive(Debug, Clone)]
// TODO: optional serde support
pub struct Config {
    /// if `Some`, abort the download after the specified duration has passed.
    pub timeout: Option<Duration>,
    /// if the specified duration passes with no progress being made,
    /// abort the download.
    pub stalled_timeout: Option<Duration>,
    /// number of times to retry the download while no progress is being made.
    ///
    /// None = infinite retries
    ///
    /// Note that this is the number of *retries*, which is one less than
    /// the number of attempts. With a retry count of 0, exactly
    /// one attempt will be made, and it will fail if any error occurs.
    ///
    /// The larger this number, the longer a burst of spurious errors
    /// will need to be in order to be promoted to a fatal error.
    pub stalled_retries: RetryCount,
    /// total number of times to retry the download
    ///
    /// Using this is not recommended unless you know you will be
    /// downloading over a stable network. instead try configuring
    /// the stalled download timeout or overall deadline.
    pub max_retries: RetryCount,
    /// maximum number of times to restart the download from the beginning
    ///
    /// usually, if there is an error, `rhtdl` will use the http Range
    /// header to resume the download from where it left off,
    /// but if the server does not implement `Range` requests, it will
    /// restart the download from the beginning.
    ///
    /// the download will also be restarted if `rhtdl` detects that the
    /// resource has been modified (eg. due to an `ETag` change).
    pub max_restarts: RetryCount,
    /// default size of transfer chunks requested with the Range header.
    ///
    /// larger chunk sizes will usually result in faster downloads,
    /// (especially if the remote server refuses to keep the connection open
    /// after a request), but will use more memory.
    pub chunk_size: u64,
    /// base retry delay used for exponential backoff.
    ///
    /// if there is only one error, this is the amount of time that will be
    /// waited before retrying.
    pub backoff_base: Duration,
    // TODO: arbitrary headers
}
130
131impl Default for Config {
132 fn default() -> Config {
133 // very generous defaults sutible for downloading massive files over
134 // horrible connections.
135 Self {
136 timeout: None,
137 stalled_timeout: Some(Duration::from_secs(3*60)),
138 stalled_retries: Some(7),
139 max_retries: None,
140 max_restarts: Some(3),
141 chunk_size: 1_024_000,
142 backoff_base: Duration::from_millis(200),
143 }
144 }
145}
146
147impl Config {
148 /// set infinite timeout and retries.
149 ///
150 /// useful for background downloads that should
151 /// keep going unless the user cancels them,
152 /// such as a podcast player.
153 #[inline]
154 pub fn never_give_up(&mut self) -> &mut Self {
155 self.timeout = None;
156 self.max_retries = None;
157 self.stalled_retries = None;
158 self.max_restarts = None;
159 // TODO: timeouts
160 self
161 }
162}
163
// TODO: bon::builder so the internal representation can be modified?
// TODO: split the immutable data (configuration options) from the mutable data like `progress` so an Options struct can be easily reused for multiple downloads.
/// structure that controls the behavior of a specific download.
///
/// this contains an `Arc<Config>`, but also other data, such as channels for
/// reporting the progress of the download.
#[derive(Default)]
pub struct Options {
    /// shared download parameters; exposed read-only through `Deref`.
    config: Arc<Config>,
    /// place to report progress to
    progress: Option<Arc<Mutex<Progress>>>,
    /// absolute point in time after which the download is aborted;
    /// derived from `config.timeout` when the download starts.
    deadline: Option<Instant>,
    /// called with the remote peer's address whenever a response is received.
    address_callback: Option<Box<dyn Fn(SocketAddr)>>,
}
178
179impl fmt::Debug for Options {
180 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181 // FIXME: do this properly
182 write!(f, "{:?}", &self.config)
183 }
184}
185
186impl Deref for Options {
187 type Target = Config;
188
189 fn deref(&self) -> &Self::Target {
190 &self.config
191 }
192}
193
194impl From<Arc<Config>> for Options {
195 #[inline]
196 fn from(config: Arc<Config>) -> Self {
197 Options {
198 config,
199 deadline: None,
200 progress: None,
201 address_callback: None,
202 }
203 }
204}
205
206impl From<Config> for Options {
207 #[inline]
208 fn from(config: Config) -> Self {
209 Arc::new(config).into()
210 }
211}
212
213impl Options {
214 /// return a handle that can be used to track the progress of the download
215 ///
216 /// # Panics
217 /// this method panics if it is called multiple times on the same `Options`.
218 pub fn progress(&mut self) -> Arc<Mutex<Progress>> {
219 assert!(self.progress.is_none());
220 let handle = Arc::new(Mutex::new(Progress::default()));
221 self.progress = Some(handle.clone());
222 handle
223 }
224
225 /// registers a callback that is called as soon as an address is obtained.
226 ///
227 /// this mostly exists for the purpose of integration testing.
228 pub fn register_address_callback(&mut self, callback: Box<dyn Fn(SocketAddr)>) {
229 self.address_callback = Some(callback);
230 }
231
232 fn deadline_expired(&self) -> bool {
233 if let Some(deadline) = self.deadline {
234 Instant::now().saturating_duration_since(deadline) > Duration::from_secs(0)
235 } else {
236 false
237 }
238 }
239
240 fn deadline_start(&mut self) {
241 if self.deadline.is_none() && self.timeout.is_some() {
242 self.deadline = Some(Instant::now() + self.timeout.unwrap());
243 }
244 }
245}
246
// TODO: use Content-Digest to verify the download, if available.
// TODO: multiple download workers, each connecting to the host over a different IP, if available. two `RangeSet`s behind a mutex to coordinate parallel range downloads. one `RangeSet` stores the blocks that are currently being downloaded by a worker, and one stores the blocks that have already been downloaded (maybe it would be better to invert this slightly, and have map of ranges that have been completed, and a map of ranges yet to be delegated)
// TODO: use If-Range to have the server restart the download if the etag changes.
/// a download task that is in progress
struct Download<O> {
    options: Options,
    /// current target url; updated when the server redirects.
    url: Url,
    /// internal progress status
    progress: Progress,
    /// the in-flight response whose body is being streamed, if any.
    resp: Option<Response>,
    /// total number of non-fatal errors over the whole download.
    error_count: u32,
    /// number of times the download has been restarted from the beginning.
    restart_count: u16,
    /// fatal error encountered
    fatal: bool,
    client: Client,
    /// time since a data byte was received
    last_progress_made: Instant,
    /// number of errors since a data byte was received
    ///
    /// used for `stalled_retries` and exponential backoff
    // TODO: wrap all error counts in `Saturating`.
    errors_since_last_progress: u16,
    /// destination the downloaded bytes are written to.
    output: O,
    /// response-header state used to detect that the remote resource
    /// changed between requests (see the `integrity` module).
    integrity: Integrity,
    /// most recent non-fatal error, kept so it can be reported
    /// if the download is later given up on.
    last_error: Option<Error>,
}
273
274// TODO: this should never panic, use that binary analysis thing to assert that.
275/// Used to resume a download across program restarts.
276///
277/// The `bytes_read` of `progress` must exactly match the amount of bytes previously written to `output`, but the `bytes_total` may be `None`.
278/// The easiest way to get this is via `Progress::from_file_len`.
279pub async fn resume(url: Url, options: Options, output: impl Output, progress: Progress) -> Result<()> {
280 use tracing::Instrument;
281 let span = info_span!("download", url = url.to_string());
282 let mut dl = Download {
283 options,
284 output,
285 url,
286 progress,
287 error_count: 0,
288 restart_count: 0,
289 last_progress_made: Instant::now(),
290 errors_since_last_progress: 0,
291 fatal: false,
292 resp: None,
293 client: Client::new(),
294 integrity: Integrity::default(),
295 last_error: None,
296 };
297 dl.resume().instrument(span).await
298}
299
300/// The central function of the crate.
301///
302/// Downloads `url` using `options`, saving the result to `output`.
303///
304/// `output` must be empty.
305pub async fn download(url: Url, options: impl Into<Options>, output: impl Output) -> Result<()> {
306 resume(url, options.into(), output, Progress::default()).await?;
307 Ok(())
308}
309
310impl<O: Output> Download<O> {
    /// perform one unit of work: either read the next body chunk of the
    /// in-flight response, or, if there is none, send a new ranged request.
    ///
    /// non-fatal errors are surfaced as `Err` and retried by `resume`.
    async fn try_step_inner(&mut self) -> Result<()> {
        if let Some(resp) = &mut self.resp {
            if resp.url() != &self.url {
                // redirect happened, save the new url,
                // since the older url might be a "latest"
                // url that gets changed to point to a new url
                // whenever a new version is published,
                // but we don't want to splice two versions together.
                self.url = resp.url().clone();
            }
            // a `None` chunk means this response's body is exhausted;
            // drop it so the next step issues a request for the next range.
            let Some(chunk) = resp.chunk().await? else {
                self.resp = None;
                return Ok(())
            };
            self.progress.bytes_read += chunk.len() as u64;
            trace!("recieved {} bytes", chunk.len());
            if let Some(prog) = &self.options.progress {
                // if the progress is currently being read, don't block,
                // just wait until next pass (maybe this should use buf2sync instead to make it lockless?)
                if let Ok(mut prog) = prog.try_lock() {
                    *prog = self.progress;
                }
            }
            // receiving data resets the stall tracking.
            self.last_progress_made = Instant::now();
            self.errors_since_last_progress = 0;
            self.output.write_all(&chunk)?;
        } else {
            // no response in flight: request the next `chunk_size` bytes
            // starting at the current offset.
            let rstart = self.progress.bytes_read;
            let rend = rstart + self.options.chunk_size;
            let mut req = self.client.get(self.url.clone());
            let mut headers = HeaderMap::new();
            headers.typed_insert(Range::bytes(rstart..rend).unwrap());
            // TODO: use if_range around here? will also need to slighly tweak retry logic to avoid wasting a request restarting a download that has already restarted.
            req = req.headers(headers);
            if let Some(deadline) = self.options.deadline {
                // if deadline expires, reqwest will give an error.
                req = req.timeout(deadline.saturating_duration_since(Instant::now()));
            }
            let resp = req.send().await?;
            if let Some(ac) = &self.options.address_callback {
                if let Some(ra) = resp.remote_addr() {
                    ac(ra);
                }
            }
            // verify the server honored the Range request; a missing
            // Content-Range header is treated as a response starting at 0.
            let maybe_cr = resp.headers().typed_get::<ContentRange>();
            let real_rstart = maybe_cr
                .as_ref()
                .and_then(headers::ContentRange::bytes_range)
                .map_or(0, |x| x.0);
            if real_rstart != rstart {
                // if rstart != 0, and Content-Range is missing,
                // we can't continue the download. retry from start.
                warn!("expected to recive a response starting at {rstart}, instead recived a chunk starting at {real_rstart}");
                return self.restart();
            }
            // if the resource changed between requests, splicing the old and
            // new bytes together would corrupt the output; start over.
            if let Err(e) = self.integrity.ingest_headers(resp.headers()) {
                error!("mid-air collision detected, restarting download ({e})");
                self.last_error = Some(e);
                return self.restart();
            }
            if self.progress.bytes_total.is_none() {
                // learn the total size, preferring Content-Range
                // over Content-Length.
                if let Some(cr) = maybe_cr {
                    self.progress.bytes_total = cr.bytes_len();
                } else {
                    self.progress.bytes_total = resp.content_length();
                }
            } else if let Some(cr) = maybe_cr {
                // the total was already known; a different total now means
                // the resource changed under us.
                let old_len = self.progress.bytes_total;
                let new_len = cr.bytes_len();
                if old_len != new_len {
                    warn!(
                        old = old_len,
                        new = new_len,
                        "total length of resource changed",
                    );
                    return self.restart();
                }
            }
            self.resp = Some(resp);
        }
        Ok(())
    }
393
394 async fn try_step(&mut self) -> Result<()> {
395 if let Some(deadline) = self.stall_deadline() {
396 tokio::time::timeout_at(deadline.into(), self.try_step_inner()).await?
397 } else {
398 self.try_step_inner().await
399 }
400 }
401
402 // only returns the last error
403 async fn resume(&mut self) -> Result<()> {
404 self.options.deadline_start();
405 debug!("downloading from offset {}", self.progress.bytes_read);
406 while !self.progress.complete() {
407 match self.try_step().await {
408 Ok(()) => {},
409 Err(err) => {
410 if err.is_fatal() {
411 error!("fatal error: {err}");
412 return Err(err);
413 }
414 if err.is_timeout() {
415 // since there is no way to construct
416 // reqwest::Error, we need to get reqwest to
417 // make one for us, by using its timeout
418 // TODO: now that we have our own error type, perhaps reconsider?
419 if self.options.deadline_expired() {
420 error!("deadline expired");
421 return Err(err);
422 }
423 }
424 // exponential backoff.
425 // is computed before incrementing error_count to
426 // avoid a (mostly insignificant) off-by-one error.
427 // TODO: "recovery time", if there are no errors for a long period, reduce the retry delay
428 let delay = self.options.backoff_base * 2_u16.saturating_pow(self.error_count).into();
429 self.error_count += 1;
430 self.errors_since_last_progress += 1;
431 warn!(error_count = self.error_count, "non-fatal error: {err:?}");
432 if let Some(max) = self.options.max_retries {
433 if self.error_count >= max.into() {
434 error!("max_retries exceeded");
435 return Err(err);
436 }
437 }
438 if let Some(max) = self.options.stalled_retries {
439 if self.error_count >= max.into() {
440 error!(stalled_retries = self.options.stalled_retries, "stalled_retries exceeded");
441 return Err(err);
442 }
443 }
444 self.last_error = Some(err);
445 info!("retrying in {} ms", delay.as_millis());
446 tokio::time::sleep(delay).await;
447 },
448 }
449 }
450 if self.fatal {
451 if let Some(e) = self.last_error.take() {
452 error!("previous error upgraded to fatal: {e}");
453 return Err(e);
454 }
455 }
456 Ok(())
457 }
458
459 /// rewind to the start and retry the download from there.
460 fn restart(&mut self) -> Result<()> {
461 warn!("restarting download");
462 if let Some(max) = self.options.max_restarts {
463 error!("maximum restart count exceeded");
464 if self.restart_count >= max {
465 self.fatal = true;
466 return Ok(());
467 }
468 }
469 self.restart_count = self.restart_count.saturating_add(1);
470 self.progress.bytes_read = 0;
471 self.resp = None;
472 self.output.restart()?;
473 self.integrity = Default::default();
474 Ok(())
475 }
476
477 fn stall_deadline(&self) -> Option<Instant> {
478 self.options.stalled_timeout.map(|x| self.last_progress_made + x)
479 }
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    /*#[tokio::test]
    async fn failure() {
        // for some reason, my ISP will redirect *any* unknown domain to
        // the router login page, making this test impossible.
        return;
        // TODO: check the actual error
        assert_eq!(
            resume("http://bad-domain.example/".parse().unwrap(),
            Options::default(), &std::io::sink(),
            Progress::default()).await.unwrap_err().to_string(),
            "error sending request for url (http://bad-domain.example/)");
    }*/

    /// a small real-world file should download without errors.
    #[tokio::test]
    async fn success() {
        let url = "https://raw.githubusercontent.com/unicode-org/cldr-json/44.0.1/cldr-json/cldr-annotations-modern/annotations/en-001/annotations.json".parse().unwrap();
        let mut out = Vec::<u8>::new();
        resume(url, Options::default(), &mut out, Progress::default()).await.unwrap();
    }

    /// an absurdly short overall timeout must make the download
    /// fail with a timeout error.
    #[tokio::test]
    async fn short_timeout() {
        // TODO: capture the "tracing" output to make sure it mentions a timeout/deadline
        let cfg = Config {
            timeout: Some(Duration::from_nanos(10)),
            ..Config::default()
        };
        let url = "https://static.rust-lang.org/dist/2024-07-26/cargo-nightly-powerpc64-unknown-linux-gnu.tar.xz".parse().unwrap();
        let result = download(url, cfg, &std::io::sink()).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().is_timeout());
    }
}
517}