// fast_down/reqwest/mod.rs

1use crate::http::{HttpClient, HttpHeaders, HttpRequestBuilder, HttpResponse};
2use fast_pull::ProgressEntry;
3use reqwest::{
4    Client, RequestBuilder, Response,
5    header::{self, HeaderMap, HeaderName, InvalidHeaderName},
6};
7use url::Url;
8
9impl HttpClient for Client {
10    type RequestBuilder = RequestBuilder;
11    fn get(&self, url: Url, range: Option<ProgressEntry>) -> Self::RequestBuilder {
12        let mut req = self.get(url);
13        if let Some(range) = range {
14            req = req.header(
15                header::RANGE,
16                format!("bytes={}-{}", range.start, range.end - 1),
17            );
18        }
19        req
20    }
21}
22
23impl HttpRequestBuilder for RequestBuilder {
24    type Response = Response;
25    type RequestError = reqwest::Error;
26    async fn send(self) -> Result<Self::Response, Self::RequestError> {
27        let res = self.send().await?;
28        res.error_for_status()
29    }
30}
31
32impl HttpResponse for Response {
33    type Headers = HeaderMap;
34    type ChunkError = reqwest::Error;
35    fn headers(&self) -> &Self::Headers {
36        self.headers()
37    }
38    fn url(&self) -> &Url {
39        self.url()
40    }
41    fn chunk(
42        &mut self,
43    ) -> impl Future<Output = Result<Option<bytes::Bytes>, Self::ChunkError>> + Send {
44        Response::chunk(self)
45    }
46}
47
48impl HttpHeaders for HeaderMap {
49    type GetHeaderError = ReqwestGetHeaderError;
50    fn get(&self, header: &str) -> Result<&str, Self::GetHeaderError> {
51        let header_name: HeaderName = header
52            .parse()
53            .map_err(ReqwestGetHeaderError::InvalidHeaderName)?;
54        let header_value = self
55            .get(&header_name)
56            .ok_or(ReqwestGetHeaderError::NotFound)?;
57        header_value.to_str().map_err(ReqwestGetHeaderError::ToStr)
58    }
59}
60
61#[derive(thiserror::Error, Debug)]
62pub enum ReqwestGetHeaderError {
63    #[error("Invalid header name {0:?}")]
64    InvalidHeaderName(InvalidHeaderName),
65    #[error("Failed to convert header value to string {0:?}")]
66    ToStr(reqwest::header::ToStrError),
67    #[error("Header not found")]
68    NotFound,
69}
70
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::{HttpError, HttpPuller, Prefetch};
    use fast_pull::{
        Event, MergeProgress,
        mem::MemPusher,
        mock::build_mock_data,
        multi::{self, download_multi},
        single::{self, download_single},
    };
    use reqwest::{Client, StatusCode};
    use std::{num::NonZero, time::Duration};

    #[tokio::test]
    async fn test_redirect_and_content_range() {
        let mut server = mockito::Server::new_async().await;

        let mock_redirect = server
            .mock("GET", "/redirect")
            .with_status(301)
            .with_header("Location", "/%e4%bd%a0%e5%a5%bd.txt")
            .create_async()
            .await;

        let mock_file = server
            .mock("GET", "/%e4%bd%a0%e5%a5%bd.txt")
            .with_status(200)
            .with_header("Content-Length", "1024")
            .with_header("Accept-Ranges", "bytes")
            .with_body(vec![0; 1024])
            .create_async()
            .await;

        let client = Client::new();
        let url = Url::parse(&format!("{}/redirect", server.url())).unwrap();
        let (url_info, _) = client.prefetch(url).await.expect("Request should succeed");

        assert_eq!(
            url_info.final_url.as_str(),
            format!("{}/%e4%bd%a0%e5%a5%bd.txt", server.url())
        );
        assert_eq!(url_info.size, 1024);
        assert_eq!(url_info.name, "你好.txt");
        assert!(url_info.supports_range);

        mock_redirect.assert_async().await;
        mock_file.assert_async().await;
    }

    #[tokio::test]
    async fn test_filename_sources() {
        let mut server = mockito::Server::new_async().await;

        // Test with Content-Disposition header
        let mock1 = server
            .mock("GET", "/test1")
            .with_header("Content-Disposition", r#"attachment; filename="test.txt""#)
            .create_async()
            .await;
        let url = Url::parse(&format!("{}/test1", server.url())).unwrap();
        let (url_info, _) = Client::new().prefetch(url).await.unwrap();
        assert_eq!(url_info.name, "test.txt");
        mock1.assert_async().await;

        // Test URL path source
        let mock2 = server
            .mock("GET", "/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf")
            .create_async()
            .await;
        let url = Url::parse(&format!(
            "{}/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf",
            server.url()
        ))
        .unwrap();
        let (url_info, _) = Client::new().prefetch(url).await.unwrap();
        assert_eq!(url_info.name, "好好好.pdf");
        mock2.assert_async().await;
    }

    #[tokio::test]
    async fn test_error_handling() {
        let mut server = mockito::Server::new_async().await;
        let mock1 = server
            .mock("GET", "/404")
            .with_status(404)
            .create_async()
            .await;

        let client = Client::new();
        let url = Url::parse(&format!("{}/404", server.url())).unwrap();
        match client.prefetch(url).await {
            Ok(info) => unreachable!("404 status code should not succeed: {info:?}"),
            Err(err) => match err {
                HttpError::Request(e) => assert_eq!(e.status(), Some(StatusCode::NOT_FOUND)),
                HttpError::Chunk(_) => unreachable!(),
                HttpError::GetHeader(_) => unreachable!(),
                HttpError::Irrecoverable => unreachable!(),
            },
        }
        mock1.assert_async().await;
    }

    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let mock_body_clone = mock_data.clone();
        let _mock = server
            .mock("GET", "/concurrent")
            .with_status(206)
            .with_header("Accept-Ranges", "bytes")
            .with_body_from_request(move |request| {
                if !request.has_header("Range") {
                    return mock_body_clone.clone();
                }
                // Serve every inclusive `start-end` range of a `bytes=...` header, in order.
                let range = request.header("Range")[0];
                range
                    .to_str()
                    .unwrap()
                    .strip_prefix("bytes=")
                    .expect("Range header should use the `bytes` unit")
                    .split(',')
                    .map(|part| {
                        let (start, end) = part
                            .trim()
                            .split_once('-')
                            .expect("range should be of the form `start-end`");
                        let start = start.parse::<usize>().unwrap();
                        let end = end.parse::<usize>().unwrap();
                        start..=end
                    })
                    .flat_map(|r| mock_body_clone[r].iter().copied())
                    .collect()
            })
            .create_async()
            .await;
        let puller = HttpPuller::new(
            format!("{}/concurrent", server.url()).parse().unwrap(),
            Client::new(),
        );
        let pusher = MemPusher::with_capacity(mock_data.len());
        // Intentionally a single range covering the whole file, not a Vec of its elements.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            multi::DownloadOptions {
                concurrent: NonZero::new(32).unwrap(),
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.clone(),
                min_chunk_size: NonZero::new(1).unwrap(),
            },
        )
        .await;

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(_, p) => {
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(_, p) => {
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }

    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/sequential")
            .with_status(200)
            .with_body(mock_data.clone())
            .create_async()
            .await;
        let puller = HttpPuller::new(
            format!("{}/sequential", server.url()).parse().unwrap(),
            Client::new(),
        );
        let pusher = MemPusher::with_capacity(mock_data.len());
        // Intentionally a single range covering the whole file, not a Vec of its elements.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_single(
            puller,
            pusher.clone(),
            single::DownloadOptions {
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
            },
        )
        .await;

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(_, p) => {
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(_, p) => {
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}