async_fetcher/
lib.rs

1// Copyright 2021-2022 System76 <info@system76.com>
2// SPDX-License-Identifier: MPL-2.0
3
4//! Asynchronously fetch files from HTTP servers
5//!
6//! - Concurrently fetch multiple files at the same time.
7//! - Define alternative mirrors for the source of a file.
8//! - Use multiple concurrent connections per file.
9//! - Use mirrors for concurrent connections.
10//! - Resume a download which has been interrupted.
11//! - Progress events for fetches
12//!
//! ```ignore
14//! let (events_tx, events_rx) = tokio::sync::mpsc::unbounded_channel();
15//!
16//! let shutdown = async_shutdown::ShutdownManager::new();
17//!
18//! let results_stream = Fetcher::default()
19//!     // Define a max number of ranged connections per file.
20//!     .connections_per_file(4)
21//!     // Max size of a connection's part, concatenated on completion.
22//!     .max_part_size(4 * 1024 * 1024)
23//!     // The channel for sending progress notifications.
24//!     .events(events_tx)
25//!     // Maximum number of retry attempts.
26//!     .retries(3)
27//!     // Cancels the fetching process when a shutdown is triggered.
28//!     .shutdown(shutdown)
29//!     // How long to wait before aborting a download that hasn't progressed.
30//!     .timeout(Duration::from_secs(15))
31//!     // Finalize the struct into an `Arc` for use with fetching.
32//!     .build()
33//!     // Take a stream of `Source` inputs and generate a stream of fetches.
//!     // Spawns one task per fetch, keeping up to 4 fetches in flight at once.
35//!     .stream_from(input_stream, 4)
36//! ```
37
38#[macro_use]
39extern crate derive_new;
40#[macro_use]
41extern crate derive_setters;
42#[macro_use]
43extern crate log;
44#[macro_use]
45extern crate thiserror;
46
47pub mod iface;
48
49mod checksum;
50mod checksum_system;
51mod concatenator;
52mod get;
53mod get_many;
54mod range;
55mod source;
56mod time;
57mod utils;
58
59pub use self::checksum::*;
60pub use self::checksum_system::*;
61pub use self::concatenator::*;
62pub use self::source::*;
63
64use self::get::{get, FetchLocation};
65use self::get_many::get_many;
66use self::time::{date_as_timestamp, update_modified};
67use async_shutdown::ShutdownManager;
68use futures::{
69    prelude::*,
70    stream::{self, StreamExt},
71};
72
73use http::StatusCode;
74use httpdate::HttpDate;
75use numtoa::NumToA;
76use reqwest::redirect::Policy;
77use reqwest::{
78    Client as ReqwestClient, RequestBuilder as ReqwestBuilder, Response as ReqwestResponse,
79};
80
81use std::sync::atomic::Ordering;
82use std::{
83    fmt::Debug,
84    io,
85    path::Path,
86    pin::Pin,
87    sync::{atomic::AtomicU16, Arc},
88    time::{Duration, UNIX_EPOCH},
89};
90use tokio::fs;
91use tokio::sync::mpsc;
92
93/// The result of a fetched task from a stream of input sources.
94pub type AsyncFetchOutput<Data> = (Arc<Path>, Arc<Data>, Result<(), Error>);
95
96/// A channel for sending `FetchEvent`s to.
97pub type EventSender<Data> = mpsc::UnboundedSender<(Arc<Path>, Data, FetchEvent)>;
98
99/// An error from the asynchronous file fetcher.
100#[derive(Debug, Error)]
101pub enum Error {
102    #[error("task was canceled")]
103    Canceled,
104    #[error("http client error")]
105    ReqwestClient(#[source] reqwest::Error),
106    #[error("unable to concatenate fetched parts")]
107    Concatenate(#[source] io::Error),
108    #[error("unable to create file")]
109    FileCreate(#[source] io::Error),
110    #[error("unable to set timestamp on {:?}", _0)]
111    FileTime(Arc<Path>, #[source] io::Error),
112    #[error("content length is an invalid range")]
113    InvalidRange(#[source] io::Error),
114    #[error("unable to remove file with bad metadata")]
115    MetadataRemove(#[source] io::Error),
116    #[error("destination has no file name")]
117    Nameless,
118    #[error("network connection was interrupted while fetching")]
119    NetworkChanged,
120    #[error("unable to open fetched part")]
121    OpenPart(Arc<Path>, #[source] io::Error),
122    #[error("destination lacks parent")]
123    Parentless,
124    #[error("connection timed out")]
125    TimedOut,
126    #[error("error writing to file")]
127    Write(#[source] io::Error),
128    #[error("network input error")]
129    Read(#[source] io::Error),
130    #[error("failed to rename partial to destination")]
131    Rename(#[source] io::Error),
132    #[error("server responded with an error: {}", _0)]
133    Status(StatusCode),
134    #[error("internal tokio join handle error")]
135    TokioSpawn(#[source] tokio::task::JoinError),
136    #[error("the request builder did not match the client used")]
137    InvalidGetRequestBuilder,
138}
139
140impl From<reqwest::Error> for Error {
141    fn from(e: reqwest::Error) -> Self {
142        Self::ReqwestClient(e)
143    }
144}
145
/// Events which are submitted by the fetcher.
///
/// Events are plain values, so they can be cloned and compared by consumers (for
/// example, to deduplicate or assert on received notifications).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FetchEvent {
    /// States that we know the length of the file being fetched.
    ContentLength(u64),
    /// Notifies that the file has been fetched.
    Fetched,
    /// Notifies that a file is being fetched.
    Fetching,
    /// Reports the amount of bytes that have been read for a file.
    Progress(u64),
    /// Notification that a fetch is being re-attempted.
    Retrying,
}
160
161/// An asynchronous file fetcher for clients fetching files.
162///
163/// The futures generated by the fetcher are compatible with single and multi-threaded
164/// runtimes, allowing you to choose between the runtime that works best for your
165/// application. A single-threaded runtime is generally recommended for fetching files,
166/// as your network connection is unlikely to be faster than a single CPU core.
167#[derive(new, Setters)]
168pub struct Fetcher<Data> {
169    /// Creates an instance of a client. The caller can decide if the instance
170    /// is shared or unique.
171    #[setters(skip)]
172    client: Client,
173
174    /// The number of concurrent connections to sustain per file being fetched.
175    /// # Note
176    /// Defaults to 1 connection
177    #[new(value = "1")]
178    connections_per_file: u16,
179
180    /// Configure the delay between file requests.
181    /// # Note
182    /// Defaults to no delay
183    #[new(value = "0")]
184    delay_between_requests: u64,
185
186    /// The number of attempts to make when a request fails.
187    /// # Note
188    /// Defaults to 3 retries.
189    #[new(value = "3")]
190    retries: u16,
191
192    /// The maximum size of a part file when downloading in parts.
193    /// # Note
194    /// Defaults to 2 MiB.
195    #[new(value = "2 * 1024 * 1024")]
196    max_part_size: u32,
197
198    /// Time in ms between progress messages
199    /// # Note
200    /// Defaults to 500.
201    #[new(value = "500")]
202    progress_interval: u64,
203
204    /// The time to wait between chunks before giving up.
205    #[new(default)]
206    #[setters(strip_option)]
207    timeout: Option<Duration>,
208
209    /// Holds a sender for submitting events to.
210    #[new(default)]
211    #[setters(into)]
212    #[setters(strip_option)]
213    events: Option<Arc<EventSender<Arc<Data>>>>,
214
215    /// Utilized to know when to shut down the fetching process.
216    #[new(value = "ShutdownManager::new()")]
217    shutdown: ShutdownManager<()>,
218}
219
220/// The underlying Client used for the Fetcher
221pub enum Client {
222    Reqwest(ReqwestClient),
223}
224
225pub(crate) enum RequestBuilder {
226    Reqwest(ReqwestBuilder),
227}
228
229impl<Data> Default for Fetcher<Data> {
230    fn default() -> Self {
231        let client = ReqwestClient::builder()
232            // Keep a TCP connection alive for up to 90s
233            .tcp_keepalive(Duration::from_secs(90))
234            // Follow up to 10 redirect links
235            .redirect(Policy::limited(10))
236            // Allow the server to be eager about sending packets
237            .tcp_nodelay(true)
238            // Cache DNS records for 24 hours
239            // .dns_cache(Duration::from_secs(60 * 60 * 24))
240            .build()
241            .expect("failed to create HTTP Client");
242
243        Self::new(Client::Reqwest(client))
244    }
245}
246
247impl<Data: Send + Sync + 'static> Fetcher<Data> {
248    /// Finalizes the fetcher to prepare it for fetch tasks.
249    pub fn build(self) -> Arc<Self> {
250        Arc::new(self)
251    }
252
253    /// Given an input stream of source fetches, returns an output stream of fetch results.
254    ///
255    /// Spawns up to `concurrent` + `1` number of concurrent async tasks on the runtime.
256    /// One task for managing the fetch tasks, and one task per fetch request.
257    pub fn stream_from(
258        self: Arc<Self>,
259        inputs: impl Stream<Item = (Source, Arc<Data>)> + Send + 'static,
260        concurrent: usize,
261    ) -> Pin<Box<dyn Stream<Item = AsyncFetchOutput<Data>> + Send + 'static>> {
262        let shutdown = self.shutdown.clone();
263        let cancel_trigger = shutdown.wait_shutdown_triggered();
264        // Takes input requests and converts them into a stream of fetch requests.
265        let stream = inputs
266            .map(move |(Source { dest, urls, part }, extra)| {
267                let fetcher = self.clone();
268                async move {
269                    if fetcher.delay_between_requests != 0 {
270                        let delay = Duration::from_millis(fetcher.delay_between_requests);
271                        tokio::time::sleep(delay).await;
272                    }
273
274                    tokio::spawn(async move {
275                        let _token = match fetcher.shutdown.delay_shutdown_token() {
276                            Ok(token) => token,
277                            Err(_) => return (dest, extra, Err(Error::Canceled)),
278                        };
279
280                        let task = async {
281                            match part {
282                                Some(part) => {
283                                    match fetcher.request(urls, part.clone(), extra.clone()).await {
284                                        Ok(()) => {
285                                            fs::rename(&*part, &*dest).await.map_err(Error::Rename)
286                                        }
287                                        Err(why) => Err(why),
288                                    }
289                                }
290                                None => fetcher.request(urls, dest.clone(), extra.clone()).await,
291                            }
292                        };
293
294                        let result = task.await;
295
296                        (dest, extra, result)
297                    })
298                    .await
299                    .unwrap()
300                }
301            })
302            .buffer_unordered(concurrent)
303            .take_until(cancel_trigger);
304
305        Box::pin(stream)
306    }
307
308    /// Request a file from one or more URIs.
309    ///
310    /// At least one URI must be provided as a source for the file. Each additional URI
311    /// serves as a mirror for failover and load-balancing purposes.
312    pub async fn request(
313        self: Arc<Self>,
314        uris: Arc<[Box<str>]>,
315        to: Arc<Path>,
316        extra: Arc<Data>,
317    ) -> Result<(), Error> {
318        self.send(|| (to.clone(), extra.clone(), FetchEvent::Fetching));
319
320        remove_parts(&to).await;
321
322        let attempts = Arc::new(AtomicU16::new(0));
323
324        let fetch = || async {
325            loop {
326                let task = self.clone().inner_request(
327                    &self.client,
328                    uris.clone(),
329                    to.clone(),
330                    extra.clone(),
331                    attempts.clone(),
332                );
333
334                let result = task.await;
335
336                if let Err(Error::NetworkChanged) | Err(Error::TimedOut) = result {
337                    let mut attempts = 5;
338                    while attempts != 0 {
339                        tokio::time::sleep(Duration::from_secs(3)).await;
340
341                        match &self.client {
342                            Client::Reqwest(client) => {
343                                let future = head_reqwest(client, &uris[0]);
344                                let net_check =
345                                    crate::utils::timed_interrupt(Duration::from_secs(3), future);
346
347                                if net_check.await.is_ok() {
348                                    tokio::time::sleep(Duration::from_secs(3)).await;
349                                    break;
350                                }
351                            }
352                        };
353
354                        attempts -= 1;
355                    }
356
357                    self.send(|| (to.clone(), extra.clone(), FetchEvent::Retrying));
358                    remove_parts(&to).await;
359                    tokio::time::sleep(Duration::from_secs(3)).await;
360
361                    continue;
362                }
363
364                return result;
365            }
366        };
367
368        let task = async {
369            let mut attempted = false;
370            loop {
371                if attempted {
372                    self.send(|| (to.clone(), extra.clone(), FetchEvent::Retrying));
373                }
374
375                attempted = true;
376                remove_parts(&to).await;
377
378                let error = match fetch().await {
379                    Ok(()) => return Ok(()),
380                    Err(error) => error,
381                };
382
383                if let Error::Canceled = error {
384                    return Err(error);
385                }
386
387                tokio::time::sleep(Duration::from_secs(3)).await;
388
389                // Uncondtionally retry connection errors.
390                if let Error::ReqwestClient(ref error) = error {
391                    use std::error::Error;
392                    if let Some(source) = error.source() {
393                        if let Some(error) = source.downcast_ref::<reqwest::Error>() {
394                            if error.is_connect() || error.is_request() {
395                                error!("retrying due to connection error: {}", error);
396                                continue;
397                            }
398                        }
399                    }
400                }
401
402                error!("retrying after error encountered: {}", error);
403
404                if attempts.fetch_add(1, Ordering::SeqCst) > self.retries {
405                    return Err(error);
406                }
407            }
408        };
409
410        let result = task.await;
411
412        remove_parts(&to).await;
413
414        match result {
415            Ok(()) => {
416                self.send(|| (to.clone(), extra.clone(), FetchEvent::Fetched));
417
418                Ok(())
419            }
420            Err(why) => Err(why),
421        }
422    }
423
424    async fn inner_request(
425        self: Arc<Self>,
426        client: &Client,
427        uris: Arc<[Box<str>]>,
428        to: Arc<Path>,
429        extra: Arc<Data>,
430        attempts: Arc<AtomicU16>,
431    ) -> Result<(), Error> {
432        let mut length = None;
433        let mut modified = None;
434        let mut resume = 0;
435
436        match client {
437            Client::Reqwest(client) => {
438                let head_response = head_reqwest(client, &*uris[0]).await?;
439
440                if let Some(response) = head_response.as_ref() {
441                    length = response
442                        .headers()
443                        .get(reqwest::header::CONTENT_LENGTH)
444                        .and_then(|value| value.to_str().ok())
445                        .and_then(|value| value.parse().ok());
446                    modified = response.last_modified();
447                }
448            }
449        }
450
451        // If the file already exists, validate that it is the same.
452        if to.exists() {
453            if let (Some(length), Some(last_modified)) = (length, modified) {
454                match fs::metadata(to.as_ref()).await {
455                    Ok(metadata) => {
456                        let modified = metadata.modified().map_err(Error::Write)?;
457                        let ts = modified
458                            .duration_since(UNIX_EPOCH)
459                            .expect("time went backwards");
460
461                        if metadata.len() == length {
462                            if ts.as_secs() == date_as_timestamp(last_modified) {
463                                info!("already fetched {}", to.display());
464                                return Ok(());
465                            } else {
466                                error!("removing file with outdated timestamp: {:?}", to);
467                                let _ = fs::remove_file(to.as_ref())
468                                    .await
469                                    .map_err(Error::MetadataRemove)?;
470                            }
471                        } else {
472                            resume = metadata.len();
473                        }
474                    }
475                    Err(why) => {
476                        error!("failed to fetch metadata of {:?}: {}", to, why);
477                        fs::remove_file(to.as_ref())
478                            .await
479                            .map_err(Error::MetadataRemove)?;
480                    }
481                }
482            }
483        }
484
485        // If set, this will use multiple connections to download a file in parts.
486        if self.connections_per_file > 1 {
487            if let Some(length) = length {
488                if supports_range(client, &*uris[0], resume, Some(length)).await? {
489                    self.send(|| (to.clone(), extra.clone(), FetchEvent::ContentLength(length)));
490
491                    if resume != 0 {
492                        self.send(|| (to.clone(), extra.clone(), FetchEvent::Progress(resume)));
493                    }
494
495                    let result = get_many(
496                        self.clone(),
497                        to.clone(),
498                        uris,
499                        resume,
500                        length,
501                        modified,
502                        extra,
503                        attempts.clone(),
504                    )
505                    .await;
506
507                    if let Err(why) = result {
508                        return Err(why);
509                    }
510
511                    if let Some(modified) = modified {
512                        update_modified(&to, modified)?;
513                    }
514
515                    return Ok(());
516                }
517            }
518        }
519
520        if let Some(length) = length {
521            self.send(|| (to.clone(), extra.clone(), FetchEvent::ContentLength(length)));
522
523            if resume > length {
524                resume = 0;
525            }
526        }
527
528        let mut request = match client {
529            Client::Reqwest(client) => RequestBuilder::Reqwest(client.get(&*uris[0])),
530        };
531
532        if resume != 0 {
533            if let Ok(true) = supports_range(client, &*uris[0], resume, length).await {
534                match request {
535                    RequestBuilder::Reqwest(inner) => {
536                        request = RequestBuilder::Reqwest(
537                            inner.header("Range", range::to_string(resume, length)),
538                        );
539                    }
540                }
541                self.send(|| (to.clone(), extra.clone(), FetchEvent::Progress(resume)));
542            } else {
543                resume = 0;
544            }
545        }
546
547        let path = match crate::get(
548            self.clone(),
549            request,
550            FetchLocation::create(to.clone(), resume != 0).await?,
551            to.clone(),
552            extra.clone(),
553            attempts.clone(),
554        )
555        .await
556        {
557            Ok((path, _)) => path,
558            Err(Error::Status(StatusCode::NOT_MODIFIED)) => to,
559
560            // Server does not support if-modified-since
561            Err(Error::Status(StatusCode::NOT_IMPLEMENTED)) => {
562                let request = match client {
563                    Client::Reqwest(client) => RequestBuilder::Reqwest(client.get(&*uris[0])),
564                };
565
566                let (path, _) = crate::get(
567                    self.clone(),
568                    request,
569                    FetchLocation::create(to.clone(), resume != 0).await?,
570                    to.clone(),
571                    extra.clone(),
572                    attempts,
573                )
574                .await?;
575
576                path
577            }
578
579            Err(why) => return Err(why),
580        };
581
582        if let Some(modified) = modified {
583            update_modified(&path, modified)?;
584        }
585
586        Ok(())
587    }
588
589    fn send(&self, event: impl FnOnce() -> (Arc<Path>, Arc<Data>, FetchEvent)) {
590        if let Some(sender) = self.events.as_ref() {
591            let _ = sender.send(event());
592        }
593    }
594}
595
596async fn head_reqwest(client: &ReqwestClient, uri: &str) -> Result<Option<ReqwestResponse>, Error> {
597    let request = client.head(uri).build().unwrap();
598
599    match validate_reqwest(client.execute(request).await?).map(Some) {
600        result @ Ok(_) => result,
601        Err(Error::Status(StatusCode::NOT_MODIFIED))
602        | Err(Error::Status(StatusCode::NOT_IMPLEMENTED)) => Ok(None),
603        Err(other) => Err(other),
604    }
605}
606
607async fn supports_range(
608    client: &Client,
609    uri: &str,
610    resume: u64,
611    length: Option<u64>,
612) -> Result<bool, Error> {
613    match client {
614        Client::Reqwest(client) => {
615            let request = client
616                .head(uri)
617                .header("Range", range::to_string(resume, length).as_str())
618                .build()
619                .unwrap();
620
621            let response = client.execute(request).await?;
622
623            if response.status() == StatusCode::PARTIAL_CONTENT {
624                if let Some(header) = response.headers().get("Content-Range") {
625                    if let Ok(header) = header.to_str() {
626                        if header.starts_with(&format!("bytes {}-", resume)) {
627                            return Ok(true);
628                        }
629                    }
630                }
631
632                Ok(false)
633            } else {
634                validate_reqwest(response).map(|_| false)
635            }
636        }
637    }
638}
639
640fn validate_reqwest(response: ReqwestResponse) -> Result<ReqwestResponse, Error> {
641    let status = response.status();
642
643    if status.is_informational() || status.is_success() {
644        Ok(response)
645    } else {
646        Err(Error::Status(status))
647    }
648}
649
650trait ResponseExt {
651    fn content_length(&self) -> Option<u64>;
652    fn last_modified(&self) -> Option<HttpDate>;
653}
654
655impl ResponseExt for ReqwestResponse {
656    fn content_length(&self) -> Option<u64> {
657        let header = self.headers().get("content-length")?;
658        header.to_str().ok()?.parse::<u64>().ok()
659    }
660
661    fn last_modified(&self) -> Option<HttpDate> {
662        let header = self.headers().get("last-modified")?;
663        httpdate::parse_http_date(header.to_str().ok()?)
664            .ok()
665            .map(HttpDate::from)
666    }
667}
668
669/// Cleans up after a process that may have been aborted.
670async fn remove_parts(to: &Path) {
671    let original_filename = match to.file_name().and_then(|x| x.to_str()) {
672        Some(name) => name,
673        None => return,
674    };
675
676    if let Some(parent) = to.parent() {
677        if let Ok(mut dir) = tokio::fs::read_dir(parent).await {
678            while let Ok(Some(entry)) = dir.next_entry().await {
679                if let Some(entry_name) = entry.file_name().to_str() {
680                    if let Some(potential_part) = entry_name.strip_prefix(original_filename) {
681                        if potential_part.starts_with(".part") {
682                            let path = entry.path();
683                            let _ = tokio::fs::remove_file(path).await;
684                        }
685                    }
686                }
687            }
688        }
689    }
690}