Skip to main content

fast_down/reqwest/
mod.rs

1use crate::http::{HttpClient, HttpHeaders, HttpRequestBuilder, HttpResponse};
2use fast_pull::ProgressEntry;
3use httpdate::parse_http_date;
4use reqwest::{
5    Client, RequestBuilder, Response,
6    header::{self, HeaderMap, HeaderName, InvalidHeaderName},
7};
8use std::time::{Duration, SystemTime};
9use url::Url;
10
11impl HttpClient for Client {
12    type RequestBuilder = RequestBuilder;
13    fn get(&self, url: Url, range: Option<ProgressEntry>) -> Self::RequestBuilder {
14        let mut req = self.get(url);
15        if let Some(range) = range {
16            req = req.header(
17                header::RANGE,
18                format!("bytes={}-{}", range.start, range.end - 1),
19            );
20        }
21        req
22    }
23}
24
25impl HttpRequestBuilder for RequestBuilder {
26    type Response = Response;
27    type RequestError = ReqwestResponseError;
28    async fn send(self) -> Result<Self::Response, (Self::RequestError, Option<Duration>)> {
29        let res = self
30            .send()
31            .await
32            .map_err(|e| (ReqwestResponseError::Reqwest(e), None))?;
33        let status = res.status();
34        if status.is_success() {
35            Ok(res)
36        } else {
37            let retry_after = res
38                .headers()
39                .get(header::RETRY_AFTER)
40                .and_then(|r| r.to_str().ok())
41                .and_then(|r| match r.parse() {
42                    Ok(r) => Some(Duration::from_secs(r)),
43                    Err(_) => match parse_http_date(r) {
44                        Ok(target_time) => target_time.duration_since(SystemTime::now()).ok(),
45                        Err(_) => None,
46                    },
47                });
48            Err((ReqwestResponseError::StatusCode(status), retry_after))
49        }
50    }
51}
52
53impl HttpResponse for Response {
54    type Headers = HeaderMap;
55    type ChunkError = reqwest::Error;
56    fn headers(&self) -> &Self::Headers {
57        self.headers()
58    }
59    fn url(&self) -> &Url {
60        self.url()
61    }
62    async fn chunk(&mut self) -> Result<Option<bytes::Bytes>, Self::ChunkError> {
63        self.chunk().await
64    }
65}
66
67impl HttpHeaders for HeaderMap {
68    type GetHeaderError = ReqwestGetHeaderError;
69    fn get(&self, header: &str) -> Result<&str, Self::GetHeaderError> {
70        let header_name: HeaderName = header
71            .parse()
72            .map_err(ReqwestGetHeaderError::InvalidHeaderName)?;
73        let header_value = self
74            .get(&header_name)
75            .ok_or(ReqwestGetHeaderError::NotFound)?;
76        header_value.to_str().map_err(ReqwestGetHeaderError::ToStr)
77    }
78}
79
80#[derive(thiserror::Error, Debug)]
81pub enum ReqwestGetHeaderError {
82    #[error("Invalid header name {0:?}")]
83    InvalidHeaderName(InvalidHeaderName),
84    #[error("Failed to convert header value to string {0:?}")]
85    ToStr(reqwest::header::ToStrError),
86    #[error("Header not found")]
87    NotFound,
88}
89
90#[derive(thiserror::Error, Debug)]
91pub enum ReqwestResponseError {
92    #[error("Reqwest error {0:?}")]
93    Reqwest(reqwest::Error),
94    #[error("Status code {0:?}")]
95    StatusCode(reqwest::StatusCode),
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        http::{HttpError, HttpPuller, Prefetch},
        url_info::FileId,
    };
    use fast_pull::{
        Event, Merge,
        mem::MemPusher,
        mock::build_mock_data,
        multi::{self, download_multi},
        single::{self, download_single},
    };
    use reqwest::{Client, StatusCode};
    use std::time::Duration;

    // Drains the event channel of a running download, merging every pull and push progress
    // report. Evaluates to `(pull_progress, push_progress)` once the channel closes.
    // A macro (not a fn) keeps the tests independent of the concrete result/event types
    // returned by `download_multi` / `download_single`.
    macro_rules! drain_progress {
        ($download:ident) => {{
            let mut pulled: Vec<ProgressEntry> = Vec::new();
            let mut pushed: Vec<ProgressEntry> = Vec::new();
            while let Ok(event) = $download.event_chain.recv().await {
                match event {
                    Event::PullProgress(_, span) => {
                        pulled.merge_progress(span);
                    }
                    Event::PushProgress(_, span) => {
                        pushed.merge_progress(span);
                    }
                    _ => {}
                }
            }
            (pulled, pushed)
        }};
    }

    /// Client used by every test; proxies are disabled so requests reach the mock server.
    fn local_client() -> Client {
        Client::builder().no_proxy().build().unwrap()
    }

    /// Joins the mock server base URL and `path` into a parsed [`Url`].
    fn endpoint(base: &str, path: &str) -> Url {
        Url::parse(&format!("{base}{path}")).unwrap()
    }

    #[tokio::test]
    async fn test_redirect_and_content_range() {
        let mut server = mockito::Server::new_async().await;
        let client = local_client();

        let redirect = server
            .mock("GET", "/redirect")
            .with_status(301)
            .with_header("Location", "/%e4%bd%a0%e5%a5%bd.txt")
            .create_async()
            .await;
        let target = server
            .mock("GET", "/%e4%bd%a0%e5%a5%bd.txt")
            .with_status(200)
            .with_header("Content-Length", "1024")
            .with_header("Accept-Ranges", "bytes")
            .with_body(vec![0; 1024])
            .create_async()
            .await;

        let base = server.url();
        let (info, _) = client
            .prefetch(endpoint(&base, "/redirect"))
            .await
            .expect("Request should succeed");

        // The redirect must be followed and the percent-encoded name decoded.
        assert_eq!(
            info.final_url.as_str(),
            format!("{base}/%e4%bd%a0%e5%a5%bd.txt")
        );
        assert_eq!(info.size, 1024);
        assert_eq!(info.raw_name, "你好.txt");
        assert!(info.supports_range);

        redirect.assert_async().await;
        target.assert_async().await;
    }

    #[tokio::test]
    async fn test_filename_sources() {
        let mut server = mockito::Server::new_async().await;
        let client = local_client();
        let base = server.url();

        // Name taken from the Content-Disposition header.
        let disposition = server
            .mock("GET", "/test1")
            .with_header("Content-Disposition", r#"attachment; filename="test.txt""#)
            .create_async()
            .await;
        let (info, _) = client.prefetch(endpoint(&base, "/test1")).await.unwrap();
        assert_eq!(info.raw_name, "test.txt");
        disposition.assert_async().await;

        // Name taken from the (percent-encoded) URL path.
        let path_only = server
            .mock("GET", "/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf")
            .create_async()
            .await;
        let (info, _) = client
            .prefetch(endpoint(&base, "/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf"))
            .await
            .unwrap();
        assert_eq!(info.raw_name, "好好好.pdf");
        path_only.assert_async().await;
    }

    #[tokio::test]
    async fn test_error_handling() {
        let mut server = mockito::Server::new_async().await;
        let client = local_client();
        let not_found = server
            .mock("GET", "/404")
            .with_status(404)
            .create_async()
            .await;

        let failure = match client.prefetch(endpoint(&server.url(), "/404")).await {
            Ok(info) => unreachable!("404 status code should not success: {info:?}"),
            Err((err, _)) => err,
        };
        // Kept exhaustive on purpose: a new HttpError variant should force this test to be
        // revisited.
        match failure {
            HttpError::Request(ReqwestResponseError::StatusCode(status_code)) => {
                assert_eq!(status_code, StatusCode::NOT_FOUND)
            }
            HttpError::Request(ReqwestResponseError::Reqwest(error)) => unreachable!("{error:?}"),
            HttpError::Chunk(_) | HttpError::GetHeader(_) | HttpError::Irrecoverable => {
                unreachable!()
            }
            HttpError::MismatchedBody(file_id) => {
                unreachable!("404 status code should not return mismatched body: {file_id:?}")
            }
        }
        not_found.assert_async().await;
    }

    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let client = local_client();
        let served_body = mock_data.clone();
        let _mock = server
            .mock("GET", "/concurrent")
            .with_status(206)
            .with_header("Accept-Ranges", "bytes")
            .with_body_from_request(move |request| {
                // Without a Range header the whole payload is served.
                if !request.has_header("Range") {
                    return served_body.clone();
                }
                let range = request.header("Range")[0];
                println!("range: {range:?}");
                // Header looks like `bytes=a-b[, c-d ...]`; both bounds are inclusive.
                let ranges = range.to_str().unwrap().rsplit('=').next().unwrap();
                ranges
                    .split(',')
                    .flat_map(|part| {
                        let (first, last) = part.trim().split_once('-').unwrap();
                        let first = first.parse::<usize>().unwrap();
                        let last = last.parse::<usize>().unwrap();
                        served_body[first..=last].to_vec()
                    })
                    .collect()
            })
            .create_async()
            .await;
        let puller = HttpPuller::new(
            format!("{}/concurrent", server.url()).parse().unwrap(),
            client,
            None,
            FileId::default(),
        );
        let pusher = MemPusher::with_capacity(mock_data.len());
        #[allow(clippy::single_range_in_vec_init)] // one chunk spanning the whole file is intended
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            multi::DownloadOptions {
                concurrent: 32,
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.iter(),
                min_chunk_size: 1,
            },
        );

        let (pull_progress, push_progress) = drain_progress!(result);
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }

    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let client = local_client();
        let _mock = server
            .mock("GET", "/sequential")
            .with_status(200)
            .with_body(mock_data.clone())
            .create_async()
            .await;
        let puller = HttpPuller::new(
            format!("{}/sequential", server.url()).parse().unwrap(),
            client,
            None,
            FileId::default(),
        );
        let pusher = MemPusher::with_capacity(mock_data.len());
        #[allow(clippy::single_range_in_vec_init)] // one chunk spanning the whole file is intended
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_single(
            puller,
            pusher.clone(),
            single::DownloadOptions {
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
            },
        );

        let (pull_progress, push_progress) = drain_progress!(result);
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}