async_fetcher/
lib.rs

1// Copyright 2021-2022 System76 <info@system76.com>
2// SPDX-License-Identifier: MPL-2.0
3
4//! Asynchronously fetch files from HTTP servers
5//!
6//! - Concurrently fetch multiple files at the same time.
7//! - Define alternative mirrors for the source of a file.
8//! - Use multiple concurrent connections per file.
9//! - Use mirrors for concurrent connections.
10//! - Resume a download which has been interrupted.
//! - Emit progress events for fetches.
12//!
//! ```ignore
14//! use async_fetcher::Fetcher;
15//! use std::time::Duration;
16//!
17//! let (events_tx, events_rx) = tokio::sync::mpsc::unbounded_channel();
18//!
19//! let shutdown = async_shutdown::ShutdownManager::new();
20//!
21//! let results_stream = Fetcher::default()
22//!     // Define a max number of ranged connections per file.
23//!     .connections_per_file(4)
24//!     // Max size of a connection's part, concatenated on completion.
25//!     .max_part_size(4 * 1024 * 1024)
26//!     // The channel for sending progress notifications.
27//!     .events(events_tx)
28//!     // Maximum number of retry attempts.
29//!     .retries(3)
30//!     // Cancels the fetching process when a shutdown is triggered.
31//!     .shutdown(shutdown)
32//!     // How long to wait before aborting a download that hasn't progressed.
33//!     .timeout(Duration::from_secs(15))
34//!     // Finalize the struct into an `Arc` for use with fetching.
35//!     .build()
36//!     // Take a stream of `Source` inputs and generate a stream of fetches.
//!     // Each fetch runs on its own spawned task, with at most 4 in flight at once.
38//!     .stream_from(input_stream, 4);
39//! ```
40
41#[macro_use]
42extern crate derive_new;
43#[macro_use]
44extern crate derive_setters;
45#[macro_use]
46extern crate log;
47#[macro_use]
48extern crate thiserror;
49
50pub mod iface;
51
52mod checksum;
53mod checksum_system;
54mod concatenator;
55mod get;
56mod get_many;
57mod range;
58mod source;
59mod time;
60mod utils;
61
62pub use self::checksum::*;
63pub use self::checksum_system::*;
64pub use self::concatenator::*;
65pub use self::source::*;
66
67use self::get::{get, FetchLocation};
68use self::get_many::get_many;
69use self::time::{date_as_timestamp, update_modified};
70use async_shutdown::ShutdownManager;
71use futures::{
72    prelude::*,
73    stream::{self, StreamExt},
74};
75
76use http::StatusCode;
77use httpdate::HttpDate;
78use numtoa::NumToA;
79use reqwest::redirect::Policy;
80use reqwest::{
81    Client as ReqwestClient, RequestBuilder as ReqwestBuilder, Response as ReqwestResponse,
82};
83
84use std::sync::atomic::Ordering;
85use std::{
86    fmt::Debug,
87    io,
88    path::Path,
89    pin::Pin,
90    sync::{atomic::AtomicU16, Arc},
91    time::{Duration, UNIX_EPOCH},
92};
93use tokio::fs;
94use tokio::sync::mpsc;
95
96/// The result of a fetched task from a stream of input sources.
97pub type AsyncFetchOutput<Data> = (Arc<Path>, Arc<Data>, Result<(), Error>);
98
99/// A channel for sending `FetchEvent`s to.
100pub type EventSender<Data> = mpsc::UnboundedSender<(Arc<Path>, Data, FetchEvent)>;
101
102/// An error from the asynchronous file fetcher.
103#[derive(Debug, Error)]
104pub enum Error {
105    #[error("task was canceled")]
106    Canceled,
107    #[error("http client error")]
108    ReqwestClient(#[source] reqwest::Error),
109    #[error("unable to concatenate fetched parts")]
110    Concatenate(#[source] io::Error),
111    #[error("unable to create file")]
112    FileCreate(#[source] io::Error),
113    #[error("unable to set timestamp on {:?}", _0)]
114    FileTime(Arc<Path>, #[source] io::Error),
115    #[error("content length is an invalid range")]
116    InvalidRange(#[source] io::Error),
117    #[error("unable to remove file with bad metadata")]
118    MetadataRemove(#[source] io::Error),
119    #[error("destination has no file name")]
120    Nameless,
121    #[error("network connection was interrupted while fetching")]
122    NetworkChanged,
123    #[error("unable to open fetched part")]
124    OpenPart(Arc<Path>, #[source] io::Error),
125    #[error("destination lacks parent")]
126    Parentless,
127    #[error("connection timed out")]
128    TimedOut,
129    #[error("error writing to file")]
130    Write(#[source] io::Error),
131    #[error("network input error")]
132    Read(#[source] io::Error),
133    #[error("failed to rename partial to destination")]
134    Rename(#[source] io::Error),
135    #[error("server responded with an error: {}", _0)]
136    Status(StatusCode),
137    #[error("internal tokio join handle error")]
138    TokioSpawn(#[source] tokio::task::JoinError),
139    #[error("the request builder did not match the client used")]
140    InvalidGetRequestBuilder,
141}
142
143impl From<reqwest::Error> for Error {
144    fn from(e: reqwest::Error) -> Self {
145        Self::ReqwestClient(e)
146    }
147}
148
/// Events which are submitted by the fetcher.
///
/// Events are plain values, so they can be cloned and compared, e.g. when
/// forwarding them to several consumers or asserting on them in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FetchEvent {
    /// States that we know the length of the file being fetched.
    ContentLength(u64),
    /// Notifies that the file has been fetched.
    Fetched,
    /// Notifies that a file is being fetched.
    Fetching,
    /// Reports the amount of bytes that have been read for a file.
    Progress(u64),
    /// Notification that a fetch is being re-attempted.
    Retrying,
}
163
/// An asynchronous file fetcher for clients fetching files.
///
/// The futures generated by the fetcher are compatible with single and multi-threaded
/// runtimes, allowing you to choose between the runtime that works best for your
/// application. A single-threaded runtime is generally recommended for fetching files,
/// as your network connection is unlikely to be faster than a single CPU core.
///
/// Fetches started through `stream_from` are spawned with `tokio::spawn`, so they must
/// be driven from within a Tokio runtime.
///
/// Construct one with `Fetcher::new(client)` or `Fetcher::default()`, adjust it with the
/// generated setter methods, then call `build` to wrap it in an `Arc`.
#[derive(new, Setters)]
pub struct Fetcher<Data> {
    /// The HTTP client used for every request made by this fetcher.
    ///
    /// Supplied to `Fetcher::new`; no setter is generated for it. The caller decides
    /// whether the client instance is shared or unique.
    #[setters(skip)]
    client: Client,

    /// The number of concurrent connections to sustain per file being fetched.
    ///
    /// Values above 1 only take effect when the server reports a content length and
    /// honours range requests; otherwise a single connection is used.
    /// # Note
    /// Defaults to 1 connection
    #[new(value = "1")]
    connections_per_file: u16,

    /// Delay, in milliseconds, slept before each fetch is spawned by `stream_from`.
    ///
    /// The delays run inside `buffer_unordered`, so up to `concurrent` of them can
    /// elapse at the same time rather than strictly one after another.
    /// # Note
    /// Defaults to no delay
    #[new(value = "0")]
    delay_between_requests: u64,

    /// Limit on how many failed attempts `request` retries before returning the error.
    ///
    /// Network changes and timeouts are retried without counting against this limit.
    /// # Note
    /// Defaults to 3 retries.
    #[new(value = "3")]
    retries: u16,

    /// The maximum size, in bytes, of a part file when downloading in parts.
    /// # Note
    /// Defaults to 2 MiB.
    #[new(value = "2 * 1024 * 1024")]
    max_part_size: u32,

    /// Time in ms between progress messages
    /// # Note
    /// Defaults to 500.
    #[new(value = "500")]
    progress_interval: u64,

    /// The time to wait between chunks before giving up.
    /// # Note
    /// Defaults to `None`.
    #[new(default)]
    #[setters(strip_option)]
    timeout: Option<Duration>,

    /// Holds a sender for submitting events to.
    ///
    /// Defaults to `None`, in which case events are not built and are discarded. Send
    /// failures (e.g. a dropped receiver) are ignored.
    #[new(default)]
    #[setters(into)]
    #[setters(strip_option)]
    events: Option<Arc<EventSender<Arc<Data>>>>,

    /// Utilized to know when to shut down the fetching process.
    ///
    /// Triggering it ends the stream returned by `stream_from`; fetch tasks that start
    /// after shutdown has completed report `Error::Canceled`. Defaults to a fresh manager.
    #[new(value = "ShutdownManager::new()")]
    shutdown: ShutdownManager<()>,
}
222
223/// The underlying Client used for the Fetcher
224pub enum Client {
225    Reqwest(ReqwestClient),
226}
227
228pub(crate) enum RequestBuilder {
229    Reqwest(ReqwestBuilder),
230}
231
232impl<Data> Default for Fetcher<Data> {
233    fn default() -> Self {
234        let client = ReqwestClient::builder()
235            // Keep a TCP connection alive for up to 90s
236            .tcp_keepalive(Duration::from_secs(90))
237            // Follow up to 10 redirect links
238            .redirect(Policy::limited(10))
239            // Allow the server to be eager about sending packets
240            .tcp_nodelay(true)
241            // Cache DNS records for 24 hours
242            // .dns_cache(Duration::from_secs(60 * 60 * 24))
243            .build()
244            .expect("failed to create HTTP Client");
245
246        Self::new(Client::Reqwest(client))
247    }
248}
249
250impl<Data: Send + Sync + 'static> Fetcher<Data> {
251    /// Finalizes the fetcher to prepare it for fetch tasks.
252    pub fn build(self) -> Arc<Self> {
253        Arc::new(self)
254    }
255
256    /// Given an input stream of source fetches, returns an output stream of fetch results.
257    ///
258    /// Spawns up to `concurrent` + `1` number of concurrent async tasks on the runtime.
259    /// One task for managing the fetch tasks, and one task per fetch request.
260    pub fn stream_from(
261        self: Arc<Self>,
262        inputs: impl Stream<Item = (Source, Arc<Data>)> + Send + 'static,
263        concurrent: usize,
264    ) -> Pin<Box<dyn Stream<Item = AsyncFetchOutput<Data>> + Send + 'static>> {
265        let shutdown = self.shutdown.clone();
266        let cancel_trigger = shutdown.wait_shutdown_triggered();
267        // Takes input requests and converts them into a stream of fetch requests.
268        let stream = inputs
269            .map(move |(Source { dest, urls, part }, extra)| {
270                let fetcher = self.clone();
271                async move {
272                    if fetcher.delay_between_requests != 0 {
273                        let delay = Duration::from_millis(fetcher.delay_between_requests);
274                        tokio::time::sleep(delay).await;
275                    }
276
277                    tokio::spawn(async move {
278                        let _token = match fetcher.shutdown.delay_shutdown_token() {
279                            Ok(token) => token,
280                            Err(_) => return (dest, extra, Err(Error::Canceled)),
281                        };
282
283                        let task = async {
284                            match part {
285                                Some(part) => {
286                                    match fetcher.request(urls, part.clone(), extra.clone()).await {
287                                        Ok(()) => {
288                                            fs::rename(&*part, &*dest).await.map_err(Error::Rename)
289                                        }
290                                        Err(why) => Err(why),
291                                    }
292                                }
293                                None => fetcher.request(urls, dest.clone(), extra.clone()).await,
294                            }
295                        };
296
297                        let result = task.await;
298
299                        (dest, extra, result)
300                    })
301                    .await
302                    .unwrap()
303                }
304            })
305            .buffer_unordered(concurrent)
306            .take_until(cancel_trigger);
307
308        Box::pin(stream)
309    }
310
311    /// Request a file from one or more URIs.
312    ///
313    /// At least one URI must be provided as a source for the file. Each additional URI
314    /// serves as a mirror for failover and load-balancing purposes.
315    pub async fn request(
316        self: Arc<Self>,
317        uris: Arc<[Box<str>]>,
318        to: Arc<Path>,
319        extra: Arc<Data>,
320    ) -> Result<(), Error> {
321        self.send(|| (to.clone(), extra.clone(), FetchEvent::Fetching));
322
323        remove_parts(&to).await;
324
325        let attempts = Arc::new(AtomicU16::new(0));
326
327        let fetch = || async {
328            loop {
329                let task = self.clone().inner_request(
330                    &self.client,
331                    uris.clone(),
332                    to.clone(),
333                    extra.clone(),
334                    attempts.clone(),
335                );
336
337                let result = task.await;
338
339                if let Err(Error::NetworkChanged) | Err(Error::TimedOut) = result {
340                    let mut attempts = 5;
341                    while attempts != 0 {
342                        tokio::time::sleep(Duration::from_secs(3)).await;
343
344                        match &self.client {
345                            Client::Reqwest(client) => {
346                                let future = head_reqwest(client, &uris[0]);
347                                let net_check =
348                                    crate::utils::timed_interrupt(Duration::from_secs(3), future);
349
350                                if net_check.await.is_ok() {
351                                    tokio::time::sleep(Duration::from_secs(3)).await;
352                                    break;
353                                }
354                            }
355                        };
356
357                        attempts -= 1;
358                    }
359
360                    self.send(|| (to.clone(), extra.clone(), FetchEvent::Retrying));
361                    remove_parts(&to).await;
362                    tokio::time::sleep(Duration::from_secs(3)).await;
363
364                    continue;
365                }
366
367                return result;
368            }
369        };
370
371        let task = async {
372            let mut attempted = false;
373            loop {
374                if attempted {
375                    self.send(|| (to.clone(), extra.clone(), FetchEvent::Retrying));
376                }
377
378                attempted = true;
379                remove_parts(&to).await;
380
381                let error = match fetch().await {
382                    Ok(()) => return Ok(()),
383                    Err(error) => error,
384                };
385
386                if let Error::Canceled = error {
387                    return Err(error);
388                }
389
390                tokio::time::sleep(Duration::from_secs(3)).await;
391
392                // Uncondtionally retry connection errors.
393                if let Error::ReqwestClient(ref error) = error {
394                    use std::error::Error;
395                    if let Some(source) = error.source() {
396                        if let Some(error) = source.downcast_ref::<reqwest::Error>() {
397                            if error.is_connect() || error.is_request() {
398                                error!("retrying due to connection error: {}", error);
399                                continue;
400                            }
401                        }
402                    }
403                }
404
405                error!("retrying after error encountered: {}", error);
406
407                if attempts.fetch_add(1, Ordering::SeqCst) > self.retries {
408                    return Err(error);
409                }
410            }
411        };
412
413        let result = task.await;
414
415        remove_parts(&to).await;
416
417        match result {
418            Ok(()) => {
419                self.send(|| (to.clone(), extra.clone(), FetchEvent::Fetched));
420
421                Ok(())
422            }
423            Err(why) => Err(why),
424        }
425    }
426
427    async fn inner_request(
428        self: Arc<Self>,
429        client: &Client,
430        uris: Arc<[Box<str>]>,
431        to: Arc<Path>,
432        extra: Arc<Data>,
433        attempts: Arc<AtomicU16>,
434    ) -> Result<(), Error> {
435        let mut length = None;
436        let mut modified = None;
437        let mut resume = 0;
438
439        match client {
440            Client::Reqwest(client) => {
441                let head_response = head_reqwest(client, &*uris[0]).await?;
442
443                if let Some(response) = head_response.as_ref() {
444                    length = response
445                        .headers()
446                        .get(reqwest::header::CONTENT_LENGTH)
447                        .and_then(|value| value.to_str().ok())
448                        .and_then(|value| value.parse().ok());
449                    modified = response.last_modified();
450                }
451            }
452        }
453
454        // If the file already exists, validate that it is the same.
455        if to.exists() {
456            if let (Some(length), Some(last_modified)) = (length, modified) {
457                match fs::metadata(to.as_ref()).await {
458                    Ok(metadata) => {
459                        let modified = metadata.modified().map_err(Error::Write)?;
460                        let ts = modified
461                            .duration_since(UNIX_EPOCH)
462                            .expect("time went backwards");
463
464                        if metadata.len() == length {
465                            if ts.as_secs() == date_as_timestamp(last_modified) {
466                                info!("already fetched {}", to.display());
467                                return Ok(());
468                            } else {
469                                error!("removing file with outdated timestamp: {:?}", to);
470                                let _ = fs::remove_file(to.as_ref())
471                                    .await
472                                    .map_err(Error::MetadataRemove)?;
473                            }
474                        } else {
475                            resume = metadata.len();
476                        }
477                    }
478                    Err(why) => {
479                        error!("failed to fetch metadata of {:?}: {}", to, why);
480                        fs::remove_file(to.as_ref())
481                            .await
482                            .map_err(Error::MetadataRemove)?;
483                    }
484                }
485            }
486        }
487
488        // If set, this will use multiple connections to download a file in parts.
489        if self.connections_per_file > 1 {
490            if let Some(length) = length {
491                if supports_range(client, &*uris[0], resume, Some(length)).await? {
492                    self.send(|| (to.clone(), extra.clone(), FetchEvent::ContentLength(length)));
493
494                    if resume != 0 {
495                        self.send(|| (to.clone(), extra.clone(), FetchEvent::Progress(resume)));
496                    }
497
498                    let result = get_many(
499                        self.clone(),
500                        to.clone(),
501                        uris,
502                        resume,
503                        length,
504                        modified,
505                        extra,
506                        attempts.clone(),
507                    )
508                    .await;
509
510                    if let Err(why) = result {
511                        return Err(why);
512                    }
513
514                    if let Some(modified) = modified {
515                        update_modified(&to, modified)?;
516                    }
517
518                    return Ok(());
519                }
520            }
521        }
522
523        if let Some(length) = length {
524            self.send(|| (to.clone(), extra.clone(), FetchEvent::ContentLength(length)));
525
526            if resume > length {
527                resume = 0;
528            }
529        }
530
531        let mut request = match client {
532            Client::Reqwest(client) => RequestBuilder::Reqwest(client.get(&*uris[0])),
533        };
534
535        if resume != 0 {
536            if let Ok(true) = supports_range(client, &*uris[0], resume, length).await {
537                match request {
538                    RequestBuilder::Reqwest(inner) => {
539                        request = RequestBuilder::Reqwest(
540                            inner.header("Range", range::to_string(resume, length)),
541                        );
542                    }
543                }
544                self.send(|| (to.clone(), extra.clone(), FetchEvent::Progress(resume)));
545            } else {
546                resume = 0;
547            }
548        }
549
550        let path = match crate::get(
551            self.clone(),
552            request,
553            FetchLocation::create(to.clone(), resume != 0).await?,
554            to.clone(),
555            extra.clone(),
556            attempts.clone(),
557        )
558        .await
559        {
560            Ok((path, _)) => path,
561            Err(Error::Status(StatusCode::NOT_MODIFIED)) => to,
562
563            // Server does not support if-modified-since
564            Err(Error::Status(StatusCode::NOT_IMPLEMENTED)) => {
565                let request = match client {
566                    Client::Reqwest(client) => RequestBuilder::Reqwest(client.get(&*uris[0])),
567                };
568
569                let (path, _) = crate::get(
570                    self.clone(),
571                    request,
572                    FetchLocation::create(to.clone(), resume != 0).await?,
573                    to.clone(),
574                    extra.clone(),
575                    attempts,
576                )
577                .await?;
578
579                path
580            }
581
582            Err(why) => return Err(why),
583        };
584
585        if let Some(modified) = modified {
586            update_modified(&path, modified)?;
587        }
588
589        Ok(())
590    }
591
592    fn send(&self, event: impl FnOnce() -> (Arc<Path>, Arc<Data>, FetchEvent)) {
593        if let Some(sender) = self.events.as_ref() {
594            let _ = sender.send(event());
595        }
596    }
597}
598
599async fn head_reqwest(client: &ReqwestClient, uri: &str) -> Result<Option<ReqwestResponse>, Error> {
600    let request = client.head(uri).build().unwrap();
601
602    match validate_reqwest(client.execute(request).await?).map(Some) {
603        result @ Ok(_) => result,
604        Err(Error::Status(StatusCode::NOT_MODIFIED))
605        | Err(Error::Status(StatusCode::NOT_IMPLEMENTED)) => Ok(None),
606        Err(other) => Err(other),
607    }
608}
609
610async fn supports_range(
611    client: &Client,
612    uri: &str,
613    resume: u64,
614    length: Option<u64>,
615) -> Result<bool, Error> {
616    match client {
617        Client::Reqwest(client) => {
618            let request = client
619                .head(uri)
620                .header("Range", range::to_string(resume, length).as_str())
621                .build()
622                .unwrap();
623
624            let response = client.execute(request).await?;
625
626            if response.status() == StatusCode::PARTIAL_CONTENT {
627                if let Some(header) = response.headers().get("Content-Range") {
628                    if let Ok(header) = header.to_str() {
629                        if header.starts_with(&format!("bytes {}-", resume)) {
630                            return Ok(true);
631                        }
632                    }
633                }
634
635                Ok(false)
636            } else {
637                validate_reqwest(response).map(|_| false)
638            }
639        }
640    }
641}
642
643fn validate_reqwest(response: ReqwestResponse) -> Result<ReqwestResponse, Error> {
644    let status = response.status();
645
646    if status.is_informational() || status.is_success() {
647        Ok(response)
648    } else {
649        Err(Error::Status(status))
650    }
651}
652
653trait ResponseExt {
654    fn content_length(&self) -> Option<u64>;
655    fn last_modified(&self) -> Option<HttpDate>;
656}
657
658impl ResponseExt for ReqwestResponse {
659    fn content_length(&self) -> Option<u64> {
660        let header = self.headers().get("content-length")?;
661        header.to_str().ok()?.parse::<u64>().ok()
662    }
663
664    fn last_modified(&self) -> Option<HttpDate> {
665        let header = self.headers().get("last-modified")?;
666        httpdate::parse_http_date(header.to_str().ok()?)
667            .ok()
668            .map(HttpDate::from)
669    }
670}
671
672/// Cleans up after a process that may have been aborted.
673async fn remove_parts(to: &Path) {
674    let original_filename = match to.file_name().and_then(|x| x.to_str()) {
675        Some(name) => name,
676        None => return,
677    };
678
679    if let Some(parent) = to.parent() {
680        if let Ok(mut dir) = tokio::fs::read_dir(parent).await {
681            while let Ok(Some(entry)) = dir.next_entry().await {
682                if let Some(entry_name) = entry.file_name().to_str() {
683                    if let Some(potential_part) = entry_name.strip_prefix(original_filename) {
684                        if potential_part.starts_with(".part") {
685                            let path = entry.path();
686                            let _ = tokio::fs::remove_file(path).await;
687                        }
688                    }
689                }
690            }
691        }
692    }
693}