rhtdl/
lib.rs

1#![doc = include_str!("../README.md")]
2
3#![forbid(unsafe_code)]
4// TODO: integration test with traffic control (tc(8)) to simulate poor network connection
5
6use std::fmt;
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, Instant};
9use std::ops::Deref;
10use std::net::SocketAddr;
11use reqwest::{Url, Client, Response};
12use headers::{ContentRange, Range, HeaderMap, HeaderMapExt};
13use tracing::{trace, error, warn, info, info_span, debug};
14
15mod integrity;
16mod error;
17mod output;
18
19use crate::integrity::Integrity;
20pub use crate::output::Output;
21pub use crate::error::Error;
22
/// convenience alias: every fallible operation in this crate fails with [`Error`].
pub type Result<T> = std::result::Result<T, Error>;

// this could be made into a wrapper struct with 0xFFFF reserved as a sentinel
// value, but i don't think it's worth it just to save 3 bytes.
/// Amount of times to retry an operation.
///
/// `None` = infinite retries.
///
/// note that this is the number of retries, not the number of *attempts*.
pub type RetryCount = Option<u16>;
33
/// Tracks how much of a file has been downloaded.
///
/// Allows resuming downloads and showing a progress bar.
///
/// This may be written to a file or database to allow gracefully resuming if a program unexpectedly terminates.
#[derive(Default, Copy, Clone, Debug)]
pub struct Progress {
    /// number of bytes read so far (and written to the output).
    pub bytes_read: u64,
    /// total size of the resource in bytes, if known.
    /// `None` until the server reports a length (or if it never does).
    pub bytes_total: Option<u64>,
}
44
45impl fmt::Display for Progress {
46    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
47        if let Some(total) = self.bytes_total {
48            write!(f, "{}/{}", self.bytes_read, total)?;
49        } else {
50            write!(f, "{}/??", self.bytes_read)?;
51        }
52        Ok(())
53    }
54}
55
56impl Progress {
57    fn complete(&self) -> bool {
58        if let Some(total) = self.bytes_total {
59            self.bytes_read >= total
60        } else {
61            false
62        }
63    }
64
65    /// Recover the progress of a download based on the length of its output file.
66    ///
67    /// `total_len` will be set to `None`.
68    pub fn from_file_len(file: &std::fs::File) -> std::io::Result<Progress> {
69        Ok(Progress{
70            bytes_read: file.metadata()?.len(),
71            bytes_total: None,
72        })
73    }
74}
75
/// parameters for how a file should be downloaded.
///
/// the default was carefully selected to be suitable for most use cases over
/// virtually any type of network.
///
/// unlike [`Options`], this may be reused for multiple downloads.
#[derive(Debug, Clone)]
// TODO: optional serde support
pub struct Config {
    /// if `Some`, abort the download after the specified duration has passed.
    pub timeout: Option<Duration>,
    /// if the specified duration passes with no progress being made,
    /// abort the download.
    pub stalled_timeout: Option<Duration>,
    /// number of times to retry the download while no progress is being made.
    ///
    /// None = infinite retries
    ///
    /// Note that this is the number of *retries*, which is one less than
    /// the number of attempts.  With a retry count of 0, exactly
    /// one attempt will be made, and it will fail if any error occurs.
    ///
    /// The larger this number, the longer a burst of spurious errors
    /// will need to be in order to be promoted to a fatal error.
    pub stalled_retries: RetryCount,
    /// total number of times to retry the download
    ///
    /// Using this is not recommended unless you know you will be
    /// downloading over a stable network.  instead try configuring
    /// the stalled download timeout or overall deadline.
    pub max_retries: RetryCount,
    /// maximum number of times to restart the download from the beginning
    ///
    /// usually, if there is an error, `rhtdl` will use the http Range
    /// header to resume the download from where it left off,
    /// but if the server does not implement `Range` requests, it will
    /// restart the download from the beginning.
    ///
    /// the download will also be restarted if `rhtdl` detects that the
    /// resource has been modified (eg. due to an `ETag` change).
    pub max_restarts: RetryCount,
    /// default size of transfer chunks requested with the Range header.
    ///
    /// larger chunk sizes will usually result in faster downloads,
    /// (especially if the remote server refuses to keep the connection open
    /// after a request), but will use more memory.
    pub chunk_size: u64,
    /// base retry delay used for exponential backoff.
    ///
    /// if there is only one error, this is the amount of time that will be
    /// waited before retrying.
    pub backoff_base: Duration,
    // TODO: arbitrary headers
}
130
impl Default for Config {
    /// default configuration: no overall deadline, a 3-minute stall timeout
    /// with up to 7 retries per stall, unlimited total retries,
    /// up to 3 full restarts, ~1 MB chunks, and a 200 ms backoff base.
    fn default() -> Config {
        // very generous defaults suitable for downloading massive files over
        // horrible connections.
        Self {
            timeout: None,
            stalled_timeout: Some(Duration::from_secs(3*60)),
            stalled_retries: Some(7),
            max_retries: None,
            max_restarts: Some(3),
            chunk_size: 1_024_000,
            backoff_base: Duration::from_millis(200),
        }
    }
}
146
147impl Config {
148    /// set infinite timeout and retries.
149    ///
150    /// useful for background downloads that should
151    /// keep going unless the user cancels them,
152    /// such as a podcast player.
153    #[inline]
154    pub fn never_give_up(&mut self) -> &mut Self {
155        self.timeout = None;
156        self.max_retries = None;
157        self.stalled_retries = None;
158        self.max_restarts = None;
159        // TODO: timeouts
160        self
161    }
162}
163
// TODO: bon::builder so the internal representation can be modified?
// TODO: split the immutable data (configuration options) from the mutable data like `progress` so an Options struct can be easily reused for multiple downloads.
/// structure that controls the behavior of a specific download.
///
/// this contains an `Arc<Config>`, but also other data, such as channels for
/// reporting the progress of the download.
#[derive(Default)]
pub struct Options {
    /// shared, reusable download parameters ([`Deref`] target).
    config: Arc<Config>,
    /// place to report progress to
    progress: Option<Arc<Mutex<Progress>>>,
    /// absolute point in time after which the download is aborted;
    /// derived from `config.timeout` when the download starts.
    deadline: Option<Instant>,
    /// callback invoked with the remote peer address once a response arrives.
    address_callback: Option<Box<dyn Fn(SocketAddr)>>,
}
178
179impl fmt::Debug for Options {
180    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181        // FIXME: do this properly
182        write!(f, "{:?}", &self.config)
183    }
184}
185
186impl Deref for Options {
187    type Target = Config;
188
189    fn deref(&self) -> &Self::Target {
190        &self.config 
191    }
192}
193
194impl From<Arc<Config>> for Options {
195    #[inline]
196    fn from(config: Arc<Config>) -> Self {
197        Options {
198            config,
199            deadline: None,
200            progress: None,
201            address_callback: None,
202        }
203    }
204}
205
206impl From<Config> for Options {
207    #[inline]
208    fn from(config: Config) -> Self {
209        Arc::new(config).into()
210    }
211}
212
213impl Options {
214    /// return a handle that can be used to track the progress of the download
215    ///
216    /// # Panics
217    /// this method panics if it is called multiple times on the same `Options`.
218    pub fn progress(&mut self) -> Arc<Mutex<Progress>> {
219        assert!(self.progress.is_none());
220        let handle = Arc::new(Mutex::new(Progress::default()));
221        self.progress = Some(handle.clone());
222        handle
223    }
224
225    /// registers a callback that is called as soon as an address is obtained.
226    ///
227    /// this mostly exists for the purpose of integration testing.
228    pub fn register_address_callback(&mut self, callback: Box<dyn Fn(SocketAddr)>) {
229        self.address_callback = Some(callback);
230    }
231
232    fn deadline_expired(&self) -> bool {
233        if let Some(deadline) = self.deadline {
234            Instant::now().saturating_duration_since(deadline) > Duration::from_secs(0)
235        } else {
236            false
237        }
238    }
239
240    fn deadline_start(&mut self) {
241        if self.deadline.is_none() && self.timeout.is_some() {
242            self.deadline = Some(Instant::now() + self.timeout.unwrap());
243        }
244    }
245}
246
// TODO: use Content-Digest to verify the download, if available.
// TODO: multiple download workers, each connecting to the host over a different IP, if available. two `RangeSet`s behind a mutex to coordinate parallel range downloads.  one `RangeSet` stores the blocks that are currently being downloaded by a worker, and one stores the blocks that have already been downloaded (maybe it would be better to invert this slightly, and have map of ranges that have been completed, and a map of ranges yet to be delegated)
// TODO: use If-Range to have the server restart the download if the etag changes.
/// a download task that is in progress
struct Download<O> {
    /// per-download options (also derefs to the shared `Config`)
    options: Options,
    /// current url; updated on redirect so every chunk comes from
    /// the same final resource
    url: Url,
    /// internal progress status
    progress: Progress,
    /// the in-flight response body, if a request is currently open
    resp: Option<Response>,
    /// total number of non-fatal errors over the download's lifetime
    error_count: u32,
    /// number of times the download was rewound to the beginning
    restart_count: u16,
    /// fatal error encountered
    fatal: bool,
    /// reusable http client
    client: Client,
    /// time since a data byte was received
    last_progress_made: Instant,
    /// number of errors since a data byte was received
    ///
    /// used for `stalled_retries` and exponential backoff
    // TODO: wrap all error counts in `Saturating`.
    errors_since_last_progress: u16,
    /// destination the downloaded bytes are written to
    output: O,
    /// response validators used to detect mid-air modification
    /// (presumably ETag-based — see the `integrity` module)
    integrity: Integrity,
    /// most recent non-fatal error, kept so it can be reported
    /// if it is later upgraded to fatal
    last_error: Option<Error>,
}
273
274// TODO: this should never panic, use that binary analysis thing to assert that.
275/// Used to resume a download across program restarts.
276///
277/// The `bytes_read` of `progress` must exactly match the amount of bytes previously written to `output`, but the `bytes_total` may be `None`.
278/// The easiest way to get this is via `Progress::from_file_len`.
279pub async fn resume(url: Url, options: Options, output: impl Output, progress: Progress) -> Result<()> {
280    use tracing::Instrument;
281    let span = info_span!("download", url = url.to_string());
282    let mut dl = Download {
283        options,
284        output,
285        url,
286        progress,
287        error_count: 0,
288        restart_count: 0,
289        last_progress_made: Instant::now(),
290        errors_since_last_progress: 0,
291        fatal: false,
292        resp: None,
293        client: Client::new(),
294        integrity: Integrity::default(),
295        last_error: None,
296    };
297    dl.resume().instrument(span).await
298}
299
300/// The central function of the crate.
301///
302/// Downloads `url` using `options`, saving the result to `output`.
303///
304/// `output` must be empty.
305pub async fn download(url: Url, options: impl Into<Options>, output: impl Output) -> Result<()> {
306    resume(url, options.into(), output, Progress::default()).await?;
307    Ok(())
308}
309
310impl<O: Output> Download<O> {
311    async fn try_step_inner(&mut self) -> Result<()> {
312        if let Some(resp) = &mut self.resp {
313            if resp.url() != &self.url {
314                // redirect happened, save the new url,
315                // since the older url might me a "latest"
316                // url that gets changed to point to a new url
317                // whenever a new version is published,
318                // but we don't want to splice two versions together.
319                self.url = resp.url().clone();
320            }
321            let Some(chunk) = resp.chunk().await? else {
322                self.resp = None;
323                return Ok(())
324            };
325            self.progress.bytes_read += chunk.len() as u64;
326            trace!("recieved {} bytes", chunk.len());
327            if let Some(prog) = &self.options.progress {
328                // if the progress is currently being read, don't block,
329                // just wait until next pass (maybe this should use buf2sync instead to make it lockless?)
330                if let Ok(mut prog) = prog.try_lock() {
331                    *prog = self.progress;
332                }
333            }
334            self.last_progress_made = Instant::now();
335            self.errors_since_last_progress = 0;
336            self.output.write_all(&chunk)?;
337        } else {
338            let rstart = self.progress.bytes_read;
339            let rend = rstart + self.options.chunk_size;
340            let mut req = self.client.get(self.url.clone());
341            let mut headers = HeaderMap::new();
342            headers.typed_insert(Range::bytes(rstart..rend).unwrap());
343            // TODO: use if_range around here? will also need to slighly tweak retry logic to avoid wasting a request restarting a download that has already restarted.
344            req = req.headers(headers);
345            if let Some(deadline) = self.options.deadline {
346                // if deadline expires, reqwest will give an error.
347                req = req.timeout(deadline.saturating_duration_since(Instant::now()));
348            }
349            let resp = req.send().await?;
350            if let Some(ac) = &self.options.address_callback {
351                if let Some(ra) = resp.remote_addr() {
352                    ac(ra);
353                }
354            }
355            let maybe_cr = resp.headers().typed_get::<ContentRange>();
356            let real_rstart = maybe_cr
357                .as_ref()
358                .and_then(headers::ContentRange::bytes_range)
359                .map_or(0, |x| x.0);
360            if real_rstart != rstart {
361                // if rstart != 0, and Content-Range is missing,
362                // we can't continue the download.  retry from start.
363                warn!("expected to recive a response starting at {rstart}, instead recived a chunk starting at {real_rstart}");
364                return self.restart();
365            }
366            if let Err(e) = self.integrity.ingest_headers(resp.headers()) {
367                error!("mid-air collision detected, restarting download ({e})");
368                self.last_error = Some(e);
369                return self.restart();
370            }
371            if self.progress.bytes_total.is_none() {
372                if let Some(cr) = maybe_cr {
373                    self.progress.bytes_total = cr.bytes_len();
374                } else {
375                    self.progress.bytes_total = resp.content_length();
376                }
377            } else if let Some(cr) = maybe_cr {
378                let old_len = self.progress.bytes_total;
379                let new_len = cr.bytes_len();
380                if old_len != new_len {
381                    warn!(
382                        old = old_len,
383                        new = new_len,
384                        "total length of resource changed",
385                    );
386                    return self.restart();
387                }
388            }
389            self.resp = Some(resp);
390        }
391        Ok(())
392    }
393
394    async fn try_step(&mut self) -> Result<()> {
395        if let Some(deadline) = self.stall_deadline() {
396            tokio::time::timeout_at(deadline.into(), self.try_step_inner()).await?
397        } else {
398            self.try_step_inner().await
399        }
400    }
401    
402    // only returns the last error
403    async fn resume(&mut self) -> Result<()> {
404        self.options.deadline_start();
405        debug!("downloading from offset {}", self.progress.bytes_read);
406        while !self.progress.complete() {
407            match self.try_step().await {
408                Ok(()) => {},
409                Err(err) => {
410                    if err.is_fatal() {
411                        error!("fatal error: {err}");
412                        return Err(err);
413                    }
414                    if err.is_timeout() {
415                        // since there is no way to construct
416                        // reqwest::Error, we need to get reqwest to
417                        // make one for us, by using its timeout
418                        // TODO: now that we have our own error type, perhaps reconsider?
419                        if self.options.deadline_expired() {
420                            error!("deadline expired");
421                            return Err(err);
422                        }
423                    }
424                    // exponential backoff.
425                    // is computed before incrementing error_count to
426                    // avoid a (mostly insignificant) off-by-one error.
427                    // TODO: "recovery time", if there are no errors for a long period, reduce the retry delay
428                    let delay = self.options.backoff_base * 2_u16.saturating_pow(self.error_count).into();
429                    self.error_count += 1;
430                    self.errors_since_last_progress += 1;
431                    warn!(error_count = self.error_count, "non-fatal error: {err:?}");
432                    if let Some(max) = self.options.max_retries {
433                        if self.error_count >= max.into() {
434                            error!("max_retries exceeded");
435                            return Err(err);
436                        }
437                    }
438                    if let Some(max) = self.options.stalled_retries {
439                        if self.error_count >= max.into() {
440                            error!(stalled_retries = self.options.stalled_retries, "stalled_retries exceeded");
441                            return Err(err);
442                        }
443                    }
444                    self.last_error = Some(err);
445                    info!("retrying in {} ms", delay.as_millis());
446                    tokio::time::sleep(delay).await;
447                },
448            }
449        }
450        if self.fatal {
451            if let Some(e) = self.last_error.take() {
452                error!("previous error upgraded to fatal: {e}");
453                return Err(e);
454            }
455        }
456        Ok(())
457    }
458
459    /// rewind to the start and retry the download from there.
460    fn restart(&mut self) -> Result<()> {
461        warn!("restarting download");
462        if let Some(max) = self.options.max_restarts {
463            error!("maximum restart count exceeded");
464            if self.restart_count >= max {
465                self.fatal = true;
466                return Ok(());
467            }
468        }
469        self.restart_count = self.restart_count.saturating_add(1);
470        self.progress.bytes_read = 0;
471        self.resp = None;
472        self.output.restart()?;
473        self.integrity = Default::default();
474        Ok(())
475    }
476
477    fn stall_deadline(&self) -> Option<Instant> {
478        self.options.stalled_timeout.map(|x| self.last_progress_made + x)
479    }
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    /*#[tokio::test]
    async fn failure() {
        // for some reason, my ISP will redirect *any* unknown domain to
        // the router login page, making this test impossible.
        return;
        // TODO: check the actual error
        assert_eq!(
            resume("http://bad-domain.example/".parse().unwrap(),
                   Options::default(), &std::io::sink(),
                   Progress::default()).await.unwrap_err().to_string(),
            "error sending request for url (http://bad-domain.example/)");
    }*/

    // NOTE(review): both tests below hit the live network and will fail
    // offline; consider serving fixtures from a local loopback server.
    #[tokio::test]
    async fn success() {
        // end-to-end download of a small, version-pinned file into a Vec.
        let mut out = Vec::<u8>::new();
        resume("https://raw.githubusercontent.com/unicode-org/cldr-json/44.0.1/cldr-json/cldr-annotations-modern/annotations/en-001/annotations.json".parse().unwrap(), Options::default(), &mut out, Progress::default()).await.unwrap();
    }

    #[tokio::test]
    async fn short_timeout() {
        // a 10 ns overall deadline should expire before the transfer can
        // finish, and must surface as a timeout error.
        // TODO: capture the "tracing" output to make sure it mentions a timeout/deadline
        let mut cfg = Config::default();
        cfg.timeout = Some(Duration::from_nanos(10));
        let result = download(
            "https://static.rust-lang.org/dist/2024-07-26/cargo-nightly-powerpc64-unknown-linux-gnu.tar.xz".parse().unwrap(),
            cfg, &std::io::sink())
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().is_timeout());
    }
}