Skip to main content

fast_down/reqwest/
mod.rs

1use crate::http::{HttpClient, HttpHeaders, HttpRequestBuilder, HttpResponse};
2use fast_pull::ProgressEntry;
3use httpdate::parse_http_date;
4use reqwest::{
5    Client, RequestBuilder, Response,
6    header::{self, HeaderMap, HeaderName, InvalidHeaderName},
7};
8use std::time::{Duration, SystemTime};
9use url::Url;
10
11impl HttpClient for Client {
12    type RequestBuilder = RequestBuilder;
13    fn get(&self, url: Url, range: Option<ProgressEntry>) -> Self::RequestBuilder {
14        let mut req = self.get(url);
15        if let Some(range) = range {
16            req = req.header(
17                header::RANGE,
18                format!("bytes={}-{}", range.start, range.end - 1),
19            );
20        }
21        req
22    }
23}
24
25impl HttpRequestBuilder for RequestBuilder {
26    type Response = Response;
27    type RequestError = ReqwestResponseError;
28    async fn send(self) -> Result<Self::Response, (Self::RequestError, Option<Duration>)> {
29        let res = self
30            .send()
31            .await
32            .map_err(|e| (ReqwestResponseError::Reqwest(e), None))?;
33        let status = res.status();
34        if status.is_success() {
35            Ok(res)
36        } else {
37            let retry_after = res
38                .headers()
39                .get(header::RETRY_AFTER)
40                .and_then(|r| r.to_str().ok())
41                .and_then(|r| match r.parse() {
42                    Ok(r) => Some(Duration::from_secs(r)),
43                    Err(_) => match parse_http_date(r) {
44                        Ok(target_time) => target_time.duration_since(SystemTime::now()).ok(),
45                        Err(_) => None,
46                    },
47                });
48            Err((ReqwestResponseError::StatusCode(status), retry_after))
49        }
50    }
51}
52
53impl HttpResponse for Response {
54    type Headers = HeaderMap;
55    type ChunkError = reqwest::Error;
56    fn headers(&self) -> &Self::Headers {
57        self.headers()
58    }
59    fn url(&self) -> &Url {
60        self.url()
61    }
62    async fn chunk(&mut self) -> Result<Option<bytes::Bytes>, Self::ChunkError> {
63        self.chunk().await
64    }
65}
66
67impl HttpHeaders for HeaderMap {
68    type GetHeaderError = ReqwestGetHeaderError;
69    fn get(&self, header: &str) -> Result<&str, Self::GetHeaderError> {
70        let header_name: HeaderName = header
71            .parse()
72            .map_err(ReqwestGetHeaderError::InvalidHeaderName)?;
73        let header_value = self
74            .get(&header_name)
75            .ok_or(ReqwestGetHeaderError::NotFound)?;
76        header_value.to_str().map_err(ReqwestGetHeaderError::ToStr)
77    }
78}
79
80#[derive(thiserror::Error, Debug)]
81pub enum ReqwestGetHeaderError {
82    #[error("Invalid header name {0:?}")]
83    InvalidHeaderName(InvalidHeaderName),
84    #[error("Failed to convert header value to string {0:?}")]
85    ToStr(reqwest::header::ToStrError),
86    #[error("Header not found")]
87    NotFound,
88}
89
90#[derive(thiserror::Error, Debug)]
91pub enum ReqwestResponseError {
92    #[error("Reqwest error {0:?}")]
93    Reqwest(reqwest::Error),
94    #[error("Status code {0:?}")]
95    StatusCode(reqwest::StatusCode),
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        http::{HttpError, HttpPuller, Prefetch},
        url_info::FileId,
    };
    use fast_pull::{
        Event, Merge,
        mem::MemPusher,
        mock::build_mock_data,
        multi::{self, download_multi},
        single::{self, download_single},
    };
    use reqwest::{Client, StatusCode};
    use std::time::Duration;

    /// A 301 redirect to a percent-encoded path is followed. The final URL, size,
    /// decoded file name and range support all come from the redirected response.
    #[tokio::test]
    async fn test_redirect_and_content_range() {
        let mut server = mockito::Server::new_async().await;
        let client = Client::builder().no_proxy().build().unwrap();

        let mock_redirect = server
            .mock("GET", "/redirect")
            .with_status(301)
            .with_header("Location", "/%e4%bd%a0%e5%a5%bd.txt")
            .create_async()
            .await;

        let mock_file = server
            .mock("GET", "/%e4%bd%a0%e5%a5%bd.txt")
            .with_status(200)
            .with_header("Content-Length", "1024")
            .with_header("Accept-Ranges", "bytes")
            .with_body(vec![0; 1024])
            .create_async()
            .await;

        let url = Url::parse(&format!("{}/redirect", server.url())).unwrap();
        let (url_info, _) = client.prefetch(url).await.expect("Request should succeed");

        assert_eq!(
            url_info.final_url.as_str(),
            format!("{}/%e4%bd%a0%e5%a5%bd.txt", server.url())
        );
        assert_eq!(url_info.size, 1024);
        assert_eq!(url_info.raw_name, "你好.txt");
        assert!(url_info.supports_range);

        mock_redirect.assert_async().await;
        mock_file.assert_async().await;
    }

    /// The file name is taken from Content-Disposition when present, with `filename*`
    /// preferred over `filename`. Otherwise it falls back to the decoded URL path.
    #[tokio::test]
    async fn test_filename_sources() {
        let mut server = mockito::Server::new_async().await;
        let client = Client::builder().no_proxy().build().unwrap();

        // Content-Disposition with a plain `filename` parameter.
        let mock1 = server
            .mock("GET", "/test1")
            .with_header("Content-Disposition", r#"attachment; filename="test.txt""#)
            .create_async()
            .await;
        let url = Url::parse(&format!("{}/test1", server.url())).unwrap();
        let (url_info, _) = client.prefetch(url).await.unwrap();
        assert_eq!(url_info.raw_name, "test.txt");
        mock1.assert_async().await;

        // Content-Disposition with only `filename*` (UTF-8, percent-encoded).
        let mock_star = server
            .mock("GET", "/test_star")
            .with_header(
                "Content-Disposition",
                r#"attachment; filename*=UTF-8''%E6%B5%8B%E8%AF%95.txt"#,
            ) // decodes to "测试.txt"
            .create_async()
            .await;
        let url = Url::parse(&format!("{}/test_star", server.url())).unwrap();
        let (url_info, _) = client.prefetch(url).await.unwrap();
        assert_eq!(url_info.raw_name, "测试.txt");
        mock_star.assert_async().await;

        // Both parameters present: `filename*` must win over the `filename` fallback.
        let mock_both = server
            .mock("GET", "/test_both")
            .with_header(
                "Content-Disposition",
                r#"attachment; filename="fallback.txt"; filename*=UTF-8''%E6%B5%8B%E8%AF%95.txt"#,
            )
            .create_async()
            .await;
        let url = Url::parse(&format!("{}/test_both", server.url())).unwrap();
        let (url_info, _) = client.prefetch(url).await.unwrap();
        assert_eq!(url_info.raw_name, "测试.txt");
        mock_both.assert_async().await;

        // No Content-Disposition: the name comes from the percent-decoded URL path.
        let mock2 = server
            .mock("GET", "/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf")
            .create_async()
            .await;
        let url = Url::parse(&format!(
            "{}/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf",
            server.url()
        ))
        .unwrap();
        let (url_info, _) = client.prefetch(url).await.unwrap();
        assert_eq!(url_info.raw_name, "好好好.pdf");
        mock2.assert_async().await;
    }

    /// A 404 during prefetch surfaces as `HttpError::Request(StatusCode(NOT_FOUND))`.
    #[tokio::test]
    async fn test_error_handling() {
        let mut server = mockito::Server::new_async().await;
        let client = Client::builder().no_proxy().build().unwrap();
        let mock1 = server
            .mock("GET", "/404")
            .with_status(404)
            .create_async()
            .await;

        let url = Url::parse(&format!("{}/404", server.url())).unwrap();
        match client.prefetch(url).await {
            Ok(info) => unreachable!("404 status code should not success: {info:?}"),
            Err((err, _)) => match err {
                HttpError::Request(e) => match e {
                    ReqwestResponseError::Reqwest(error) => unreachable!("{error:?}"),
                    ReqwestResponseError::StatusCode(status_code) => {
                        assert_eq!(status_code, StatusCode::NOT_FOUND)
                    }
                },
                HttpError::Chunk(_) => unreachable!(),
                HttpError::GetHeader(_) => unreachable!(),
                HttpError::Irrecoverable => unreachable!(),
                HttpError::MismatchedBody(file_id) => {
                    unreachable!("404 status code should not return mismatched body: {file_id:?}")
                }
            },
        }
        mock1.assert_async().await;
    }

    /// A 300 MiB body is downloaded with 32 concurrent ranged workers against a mock
    /// that honours `Range: bytes=a-b[, c-d...]`. Pull and push progress must each
    /// cover the whole file, and the reassembled bytes must match the source.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let client = Client::builder().no_proxy().build().unwrap();
        let mock_body_clone = mock_data.clone();
        let _mock = server
            .mock("GET", "/concurrent")
            .with_status(206)
            .with_header("Accept-Ranges", "bytes")
            .with_body_from_request(move |request| {
                if !request.has_header("Range") {
                    return mock_body_clone.clone();
                }
                let range = request.header("Range")[0];
                // Serve each inclusive `start-end` span listed after `bytes=`, concatenated.
                range
                    .to_str()
                    .unwrap()
                    .rsplit('=')
                    .next()
                    .unwrap()
                    .split(',')
                    .map(|part| {
                        let (start, end) = part.trim().split_once('-').unwrap();
                        let start = start.parse::<usize>().unwrap();
                        let end = end.parse::<usize>().unwrap();
                        start..=end
                    })
                    .flat_map(|span| mock_body_clone[span].to_vec())
                    .collect()
            })
            .create_async()
            .await;
        let puller = HttpPuller::new(
            format!("{}/concurrent", server.url()).parse().unwrap(),
            client,
            None,
            FileId::default(),
        );
        let pusher = MemPusher::with_capacity(mock_data.len());
        // Intentionally a Vec holding one whole-file range, not the range's elements.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            multi::DownloadOptions {
                concurrent: 32,
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.iter(),
                pull_timeout: Duration::from_secs(5),
                min_chunk_size: 1,
            },
        );

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(_, p) => {
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(_, p) => {
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        // `assert_eq!` already prints both sides on mismatch, so no extra debug output.
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }

    /// A 300 MiB body served with a plain 200 is downloaded by the single-connection
    /// downloader. Progress must cover the whole file and the bytes must match.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let client = Client::builder().no_proxy().build().unwrap();
        let _mock = server
            .mock("GET", "/sequential")
            .with_status(200)
            .with_body(mock_data.clone())
            .create_async()
            .await;
        let puller = HttpPuller::new(
            format!("{}/sequential", server.url()).parse().unwrap(),
            client,
            None,
            FileId::default(),
        );
        let pusher = MemPusher::with_capacity(mock_data.len());
        // Intentionally a Vec holding one whole-file range, not the range's elements.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_single(
            puller,
            pusher.clone(),
            single::DownloadOptions {
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
            },
        );

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(_, p) => {
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(_, p) => {
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}