fast_down/reqwest/
mod.rs

1use crate::http::{HttpClient, HttpHeaders, HttpRequestBuilder, HttpResponse};
2use fast_pull::ProgressEntry;
3use httpdate::parse_http_date;
4use reqwest::{
5    Client, RequestBuilder, Response,
6    header::{self, HeaderMap, HeaderName, InvalidHeaderName},
7};
8use std::time::{Duration, SystemTime};
9use url::Url;
10
11impl HttpClient for Client {
12    type RequestBuilder = RequestBuilder;
13    fn get(&self, url: Url, range: Option<ProgressEntry>) -> Self::RequestBuilder {
14        let mut req = self.get(url);
15        if let Some(range) = range {
16            req = req.header(
17                header::RANGE,
18                format!("bytes={}-{}", range.start, range.end - 1),
19            );
20        }
21        req
22    }
23}
24
25impl HttpRequestBuilder for RequestBuilder {
26    type Response = Response;
27    type RequestError = ReqwestResponseError;
28    async fn send(self) -> Result<Self::Response, (Self::RequestError, Option<Duration>)> {
29        let res = self
30            .send()
31            .await
32            .map_err(|e| (ReqwestResponseError::Reqwest(e), None))?;
33        let status = res.status();
34        if status.is_success() {
35            Ok(res)
36        } else {
37            let retry_after = res
38                .headers()
39                .get(header::RETRY_AFTER)
40                .and_then(|r| r.to_str().ok())
41                .and_then(|r| match r.parse() {
42                    Ok(r) => Some(Duration::from_secs(r)),
43                    Err(_) => match parse_http_date(r) {
44                        Ok(target_time) => target_time.duration_since(SystemTime::now()).ok(),
45                        Err(_) => None,
46                    },
47                });
48            Err((ReqwestResponseError::StatusCode(status), retry_after))
49        }
50    }
51}
52
53impl HttpResponse for Response {
54    type Headers = HeaderMap;
55    type ChunkError = reqwest::Error;
56    fn headers(&self) -> &Self::Headers {
57        self.headers()
58    }
59    fn url(&self) -> &Url {
60        self.url()
61    }
62    fn chunk(
63        &mut self,
64    ) -> impl Future<Output = Result<Option<bytes::Bytes>, Self::ChunkError>> + Send {
65        self.chunk()
66    }
67}
68
69impl HttpHeaders for HeaderMap {
70    type GetHeaderError = ReqwestGetHeaderError;
71    fn get(&self, header: &str) -> Result<&str, Self::GetHeaderError> {
72        let header_name: HeaderName = header
73            .parse()
74            .map_err(ReqwestGetHeaderError::InvalidHeaderName)?;
75        let header_value = self
76            .get(&header_name)
77            .ok_or(ReqwestGetHeaderError::NotFound)?;
78        header_value.to_str().map_err(ReqwestGetHeaderError::ToStr)
79    }
80}
81
82#[derive(thiserror::Error, Debug)]
83pub enum ReqwestGetHeaderError {
84    #[error("Invalid header name {0:?}")]
85    InvalidHeaderName(InvalidHeaderName),
86    #[error("Failed to convert header value to string {0:?}")]
87    ToStr(reqwest::header::ToStrError),
88    #[error("Header not found")]
89    NotFound,
90}
91
92#[derive(thiserror::Error, Debug)]
93pub enum ReqwestResponseError {
94    #[error("Reqwest error {0:?}")]
95    Reqwest(reqwest::Error),
96    #[error("Status code {0:?}")]
97    StatusCode(reqwest::StatusCode),
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        http::{HttpError, HttpPuller, Prefetch},
        url_info::FileId,
    };
    use fast_pull::{
        Event, MergeProgress,
        mem::MemPusher,
        mock::build_mock_data,
        multi::{self, download_multi},
        single::{self, download_single},
    };
    use reqwest::{Client, StatusCode};
    use std::{num::NonZero, time::Duration};

    /// Prefetching a URL that answers with a 301 must report the redirect
    /// target as the final URL, together with size, decoded name and range
    /// support taken from the target's response.
    #[tokio::test]
    async fn test_redirect_and_content_range() {
        // Percent-encoded path of the redirect target ("/你好.txt").
        const TARGET_PATH: &str = "/%e4%bd%a0%e5%a5%bd.txt";

        let mut server = mockito::Server::new_async().await;

        let redirect_mock = server
            .mock("GET", "/redirect")
            .with_status(301)
            .with_header("Location", TARGET_PATH)
            .create_async()
            .await;

        let file_mock = server
            .mock("GET", TARGET_PATH)
            .with_status(200)
            .with_header("Content-Length", "1024")
            .with_header("Accept-Ranges", "bytes")
            .with_body(vec![0; 1024])
            .create_async()
            .await;

        let entry_point = format!("{}/redirect", server.url());
        let url = Url::parse(&entry_point).unwrap();
        let client = Client::new();
        let (url_info, _) = client.prefetch(url).await.expect("Request should succeed");

        let expected_url = format!("{}{}", server.url(), TARGET_PATH);
        assert_eq!(url_info.final_url.as_str(), expected_url);
        assert_eq!(url_info.size, 1024);
        assert_eq!(url_info.raw_name, "你好.txt");
        assert!(url_info.supports_range);

        // Both endpoints must have been hit.
        redirect_mock.assert_async().await;
        file_mock.assert_async().await;
    }

    /// The file name is taken from `Content-Disposition` when present and
    /// from the (percent-decoded) URL path otherwise.
    #[tokio::test]
    async fn test_filename_sources() {
        let mut server = mockito::Server::new_async().await;

        // Source 1: Content-Disposition header.
        let disposition_mock = server
            .mock("GET", "/test1")
            .with_header("Content-Disposition", r#"attachment; filename="test.txt""#)
            .create_async()
            .await;
        let address = format!("{}/test1", server.url());
        let url = Url::parse(&address).unwrap();
        let (url_info, _) = Client::new().prefetch(url).await.unwrap();
        assert_eq!(url_info.raw_name, "test.txt");
        disposition_mock.assert_async().await;

        // Source 2: last segment of the URL path.
        let path_mock = server
            .mock("GET", "/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf")
            .create_async()
            .await;
        let address = format!("{}/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf", server.url());
        let url = Url::parse(&address).unwrap();
        let (url_info, _) = Client::new().prefetch(url).await.unwrap();
        assert_eq!(url_info.raw_name, "好好好.pdf");
        path_mock.assert_async().await;
    }

    /// A 404 response must surface as a request error carrying the status code.
    #[tokio::test]
    async fn test_error_handling() {
        let mut server = mockito::Server::new_async().await;
        let not_found_mock = server
            .mock("GET", "/404")
            .with_status(404)
            .create_async()
            .await;

        let address = format!("{}/404", server.url());
        let url = Url::parse(&address).unwrap();
        let client = Client::new();

        let err = match client.prefetch(url).await {
            Ok(info) => unreachable!("404 status code should not success: {info:?}"),
            Err((err, _)) => err,
        };
        match err {
            HttpError::Request(ReqwestResponseError::StatusCode(status_code)) => {
                assert_eq!(status_code, StatusCode::NOT_FOUND)
            }
            HttpError::Request(ReqwestResponseError::Reqwest(error)) => unreachable!("{error:?}"),
            HttpError::Chunk(_) => unreachable!(),
            HttpError::GetHeader(_) => unreachable!(),
            HttpError::Irrecoverable => unreachable!(),
            HttpError::MismatchedBody(file_id) => {
                unreachable!("404 status code should not return mismatched body: {file_id:?}")
            }
        }
        not_found_mock.assert_async().await;
    }

    /// Multi-connection download against a mock that serves the byte ranges
    /// named in the `Range` header; checks reported progress and final bytes.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let served_body = mock_data.clone();
        let _mock = server
            .mock("GET", "/concurrent")
            .with_status(206)
            .with_header("Accept-Ranges", "bytes")
            .with_body_from_request(move |request| {
                // No Range header: answer with the complete payload.
                if !request.has_header("Range") {
                    return served_body.clone();
                }
                let range = request.header("Range")[0];
                println!("range: {range:?}");
                // Everything after the last '=' is the comma-separated list of
                // inclusive "start-end" pairs.
                let header_text = range.to_str().unwrap();
                let spec = header_text.rsplit('=').next().unwrap();
                let mut body = Vec::new();
                for part in spec.split(',') {
                    let (first, last) = part.trim().split_once('-').unwrap();
                    let first = first.parse::<usize>().unwrap();
                    let last = last.parse::<usize>().unwrap();
                    body.extend_from_slice(&served_body[first..=last]);
                }
                body
            })
            .create_async()
            .await;

        let target = format!("{}/concurrent", server.url()).parse().unwrap();
        let puller = HttpPuller::new(target, Client::new(), None, FileId::empty());
        let pusher = MemPusher::with_capacity(mock_data.len());
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let options = multi::DownloadOptions {
            concurrent: NonZero::new(32).unwrap(),
            retry_gap: Duration::from_secs(1),
            push_queue_cap: 1024,
            download_chunks: download_chunks.clone(),
            min_chunk_size: NonZero::new(1).unwrap(),
        };
        let result = download_multi(puller, pusher.clone(), options).await;

        // Fold every progress event into a merged list per direction.
        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::PullProgress(_, entry) => {
                    pull_progress.merge_progress(entry);
                }
                Event::PushProgress(_, entry) => {
                    push_progress.merge_progress(entry);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }

    /// Single-connection download of the whole payload; checks reported
    /// progress and final bytes.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/sequential")
            .with_status(200)
            .with_body(mock_data.clone())
            .create_async()
            .await;

        let target = format!("{}/sequential", server.url()).parse().unwrap();
        let puller = HttpPuller::new(target, Client::new(), None, FileId::empty());
        let pusher = MemPusher::with_capacity(mock_data.len());
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let options = single::DownloadOptions {
            retry_gap: Duration::from_secs(1),
            push_queue_cap: 1024,
        };
        let result = download_single(puller, pusher.clone(), options).await;

        // Fold every progress event into a merged list per direction.
        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::PullProgress(_, entry) => {
                    pull_progress.merge_progress(entry);
                }
                Event::PushProgress(_, entry) => {
                    push_progress.merge_progress(entry);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}
341}