fast-down 4.0.4

Download everything fast
Documentation
#![cfg(not(target_family = "wasm"))]

use crate::http::{HttpClient, HttpHeaders, HttpRequestBuilder, HttpResponse};
use fast_pull::ProgressEntry;
use httpdate::parse_http_date;
use reqwest::{
    Client, RequestBuilder, Response,
    header::{self, HeaderMap, InvalidHeaderName},
};
use std::{
    borrow::Cow,
    time::{Duration, SystemTime},
};
use url::Url;

impl HttpClient for Client {
    type RequestBuilder = RequestBuilder;

    /// Builds a GET request for `url`, optionally limited to a byte range.
    ///
    /// When `range` is `Some`, a `Range: bytes=<start>-<end - 1>` header is
    /// attached: `range.end` is treated as exclusive, while the last-byte
    /// position in an HTTP `Range` header is inclusive.
    fn get(&self, url: Url, range: Option<ProgressEntry>) -> Self::RequestBuilder {
        // Method resolution picks the client's own `get`, not this trait method.
        let builder = self.get(url);
        match range {
            Some(span) => {
                let value = format!("bytes={}-{}", span.start, span.end - 1);
                builder.header(header::RANGE, value)
            }
            None => builder,
        }
    }
}

impl HttpRequestBuilder for RequestBuilder {
    type Response = Response;
    type RequestError = ReqwestResponseError;

    /// Sends the request and turns non-success statuses into errors.
    ///
    /// # Errors
    ///
    /// * `(ReqwestResponseError::Reqwest(_), None)` when sending itself fails.
    /// * `(ReqwestResponseError::StatusCode(_), retry_after)` when the response
    ///   status is not a success status. `retry_after` is taken from the
    ///   `Retry-After` header when it is present and parseable.
    async fn send(self) -> Result<Self::Response, (Self::RequestError, Option<Duration>)> {
        let response = match self.send().await {
            Ok(response) => response,
            Err(err) => return Err((ReqwestResponseError::Reqwest(err), None)),
        };

        let status = response.status();
        if status.is_success() {
            return Ok(response);
        }

        let retry_after = response
            .headers()
            .get(header::RETRY_AFTER)
            .and_then(|value| value.to_str().ok())
            .and_then(|text| match text.parse::<u64>() {
                // A bare integer is a delay expressed in seconds.
                Ok(secs) => Some(Duration::from_secs(secs)),
                // Otherwise try an HTTP date; a date that is already in the
                // past makes `duration_since` fail and yields `None`.
                Err(_) => parse_http_date(text)
                    .ok()
                    .and_then(|deadline| deadline.duration_since(SystemTime::now()).ok()),
            });
        Err((ReqwestResponseError::StatusCode(status), retry_after))
    }
}

impl HttpResponse for Response {
    type Headers = HeaderMap;
    type ChunkError = reqwest::Error;
    fn headers(&self) -> &Self::Headers {
        self.headers()
    }
    fn url(&self) -> &Url {
        self.url()
    }
    async fn chunk(&mut self) -> Result<Option<bytes::Bytes>, Self::ChunkError> {
        self.chunk().await
    }
}

impl HttpHeaders for HeaderMap {
    type GetHeaderError = ReqwestGetHeaderError;

    /// Looks up `header` and returns its value as text.
    ///
    /// The value bytes are converted with `String::from_utf8_lossy`, so the
    /// result borrows from the map when the value is valid UTF-8 and is an
    /// owned, lossily converted string otherwise.
    ///
    /// # Errors
    ///
    /// Returns `ReqwestGetHeaderError::NotFound` when the lookup yields nothing.
    fn get(&self, header: &str) -> Result<Cow<'_, str>, Self::GetHeaderError> {
        match self.get(header) {
            Some(value) => Ok(String::from_utf8_lossy(value.as_bytes())),
            None => Err(ReqwestGetHeaderError::NotFound),
        }
    }
}

/// Error produced when reading a header value out of a `HeaderMap`.
#[derive(thiserror::Error, Debug)]
pub enum ReqwestGetHeaderError {
    /// The header name was rejected as invalid.
    ///
    /// NOTE(review): nothing in this part of the file constructs this variant;
    /// confirm whether it is still produced elsewhere.
    #[error("Invalid header name {0:?}")]
    InvalidHeaderName(InvalidHeaderName),
    /// The header map has no entry for the requested name.
    #[error("Header not found")]
    NotFound,
}

/// Error produced when sending a request through `reqwest`.
#[derive(thiserror::Error, Debug)]
pub enum ReqwestResponseError {
    /// `reqwest` failed while sending the request; wraps the underlying error.
    #[error("Reqwest error {0:?}")]
    Reqwest(reqwest::Error),
    /// A response arrived, but its status code was not a success status.
    #[error("Status code {0:?}")]
    StatusCode(reqwest::StatusCode),
}

#[cfg(test)]
#[cfg(feature = "mem")]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::significant_drop_tightening
    )]
    use super::*;
    use crate::{
        http::{HttpError, HttpPuller, Prefetch},
        url_info::FileId,
    };
    use fast_pull::{
        Event, Merge,
        mem::MemPusher,
        mock::build_mock_data,
        multi::{self, download_multi},
        single::{self, download_single},
    };
    use reqwest::{Client, StatusCode};
    use std::time::Duration;

    /// Prefetching a URL that answers with a 301 must report the redirect
    /// target as the final URL, a size of 1024, the percent-decoded file name
    /// and range support.
    #[tokio::test]
    async fn test_redirect_and_content_range() {
        let mut server = mockito::Server::new_async().await;
        let client = Client::builder().no_proxy().build().unwrap();
        let base = server.url();
        // Percent-encoded path that the redirect points at.
        let target_path = "/%e4%bd%a0%e5%a5%bd.txt";

        // First hop: permanent redirect to the encoded path.
        let _redirect_mock = server
            .mock("GET", "/redirect")
            .with_status(301)
            .with_header("Location", target_path)
            .create_async()
            .await;

        // Second hop: a 206 response that advertises a total size of 1024.
        let _file_mock = server
            .mock("GET", target_path)
            .with_status(206)
            .with_header("Content-Length", "1024")
            .with_header("Content-Range", "bytes 0-0/1024")
            .with_body(vec![0; 1024])
            .create_async()
            .await;

        let start_url = Url::parse(&format!("{base}/redirect")).unwrap();
        let (info, _) = client
            .prefetch(start_url)
            .await
            .expect("Request should succeed");

        assert_eq!(info.final_url.as_str(), format!("{base}{target_path}"));
        assert_eq!(info.size, 1024);
        assert_eq!(info.raw_name, "你好.txt");
        assert!(info.supports_range);
    }

    /// Exercises the places `prefetch` reads the file name from: the
    /// `Content-Disposition` header (`filename`, `filename*`, or both) and the
    /// last segment of the URL path.
    #[tokio::test]
    async fn test_filename_sources() {
        let mut server = mockito::Server::new_async().await;
        let client = Client::builder().no_proxy().build().unwrap();
        let base = server.url();

        // Content-Disposition with a plain `filename` parameter.
        let _plain_mock = server
            .mock("GET", "/test1")
            .with_header("Content-Disposition", r#"attachment; filename="test.txt""#)
            .create_async()
            .await;
        let plain_url = Url::parse(&format!("{base}/test1")).unwrap();
        let (plain_info, _) = client.prefetch(plain_url).await.unwrap();
        assert_eq!(plain_info.raw_name, "test.txt");

        // Content-Disposition with only `filename*` (UTF-8 encoded).
        let _star_mock = server
            .mock("GET", "/test_star")
            .with_header(
                "Content-Disposition",
                // Percent-encoded form of "测试.txt".
                "attachment; filename*=UTF-8''%E6%B5%8B%E8%AF%95.txt",
            )
            .create_async()
            .await;
        let star_url = Url::parse(&format!("{base}/test_star")).unwrap();
        let (star_info, _) = client.prefetch(star_url).await.unwrap();
        assert_eq!(star_info.raw_name, "测试.txt");

        // Both parameters present: the `filename*` value is the one expected.
        let _both_mock = server
            .mock("GET", "/test_both")
            .with_header(
                "Content-Disposition",
                r#"attachment; filename="fallback.txt"; filename*=UTF-8''%E6%B5%8B%E8%AF%95.txt"#,
            )
            .create_async()
            .await;
        let both_url = Url::parse(&format!("{base}/test_both")).unwrap();
        let (both_info, _) = client.prefetch(both_url).await.unwrap();
        assert_eq!(both_info.raw_name, "测试.txt");

        // No Content-Disposition: the name comes from the URL path.
        let _path_mock = server
            .mock("GET", "/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf")
            .create_async()
            .await;
        let path_url =
            Url::parse(&format!("{base}/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf")).unwrap();
        let (path_info, _) = client.prefetch(path_url).await.unwrap();
        assert_eq!(path_info.raw_name, "好好好.pdf");
    }

    #[tokio::test]
    async fn test_error_handling() {
        let mut server = mockito::Server::new_async().await;
        let client = Client::builder().no_proxy().build().unwrap();
        let _mock1 = server
            .mock("GET", "/404")
            .with_status(404)
            .create_async()
            .await;

        let url = Url::parse(&format!("{}/404", server.url())).unwrap();
        match client.prefetch(url).await {
            Ok(info) => unreachable!("404 status code should not success: {info:?}"),
            Err((err, _)) => match err {
                HttpError::Request(e) => match e {
                    ReqwestResponseError::Reqwest(error) => unreachable!("{error:?}"),
                    ReqwestResponseError::StatusCode(status_code) => {
                        assert_eq!(status_code, StatusCode::NOT_FOUND);
                    }
                },
                HttpError::Chunk(_) | HttpError::GetHeader(_) | HttpError::Irrecoverable => {
                    unreachable!()
                }
                HttpError::MismatchedBody(file_id) => {
                    unreachable!("404 status code should not return mismatched body: {file_id:?}")
                }
            },
        }
    }

    /// Downloads a 300 MiB mock payload with `download_multi` (32 concurrent
    /// workers) from a mock that honours `Range` requests, then checks the
    /// merged pull/push progress and the bytes received by the pusher.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let client = Client::builder().no_proxy().build().unwrap();
        let served = mock_data.clone();
        let _mock = server
            .mock("GET", "/concurrent")
            .with_status(206)
            .with_header("Accept-Ranges", "bytes")
            .with_body_from_request(move |request| {
                // Without a Range header the whole payload is returned.
                if !request.has_header("Range") {
                    return served.clone();
                }
                let range = request.header("Range")[0];
                println!("range: {range:?}");
                // Everything after the last '=' is a comma-separated list of
                // inclusive `start-end` pairs.
                let spec = range.to_str().unwrap().rsplit('=').next().unwrap();
                spec.split(',')
                    .flat_map(|part| {
                        let mut bounds = part.trim().splitn(2, '-');
                        let first = bounds.next().unwrap().parse::<usize>().unwrap();
                        let last = bounds.next().unwrap().parse::<usize>().unwrap();
                        served[first..=last].to_vec()
                    })
                    .collect()
            })
            .create_async()
            .await;

        let target = format!("{}/concurrent", server.url()).parse().unwrap();
        let puller = HttpPuller::new(target, client, None, FileId::default());
        let pusher = MemPusher::with_capacity(mock_data.len());
        // A single range covering the whole payload.
        #[allow(clippy::single_range_in_vec_init)]
        let whole_file = vec![0..mock_data.len() as u64];
        let options = multi::DownloadOptions {
            concurrent: 32,
            retry_gap: Duration::from_secs(1),
            push_queue_cap: 1024,
            download_chunks: whole_file.iter().cloned(),
            pull_timeout: Duration::from_secs(5),
            min_chunk_size: 1,
            max_speculative: 3,
        };
        let result = download_multi(puller, pusher.clone(), options);

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        // Consume events until `recv` reports an error.
        loop {
            let Ok(event) = result.event_chain.recv().await else {
                break;
            };
            match event {
                Event::PullProgress(_, span) => {
                    pull_progress.merge_progress(span);
                }
                Event::PushProgress(_, span) => {
                    push_progress.merge_progress(span);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, whole_file);
        assert_eq!(push_progress, whole_file);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }

    /// Downloads a 300 MiB mock payload with `download_single` from a mock that
    /// answers 200 with the full body, then checks the merged pull/push
    /// progress and the bytes received by the pusher.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let client = Client::builder().no_proxy().build().unwrap();
        let _mock = server
            .mock("GET", "/sequential")
            .with_status(200)
            .with_body(mock_data.clone())
            .create_async()
            .await;

        let target = format!("{}/sequential", server.url()).parse().unwrap();
        let puller = HttpPuller::new(target, client, None, FileId::default());
        let pusher = MemPusher::with_capacity(mock_data.len());
        // A single range covering the whole payload, used as the expected progress.
        #[allow(clippy::single_range_in_vec_init)]
        let whole_file = vec![0..mock_data.len() as u64];
        let options = single::DownloadOptions {
            retry_gap: Duration::from_secs(1),
            push_queue_cap: 1024,
        };
        let result = download_single(puller, pusher.clone(), options);

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        // Consume events until `recv` reports an error.
        loop {
            let Ok(event) = result.event_chain.recv().await else {
                break;
            };
            match event {
                Event::PullProgress(_, span) => {
                    pull_progress.merge_progress(span);
                }
                Event::PushProgress(_, span) => {
                    push_progress.merge_progress(span);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, whole_file);
        assert_eq!(push_progress, whole_file);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}