fast_down/reqwest/mod.rs

1use crate::http::{HttpClient, HttpHeaders, HttpRequestBuilder, HttpResponse};
2use fast_pull::ProgressEntry;
3use httpdate::parse_http_date;
4use reqwest::{
5    Client, RequestBuilder, Response,
6    header::{self, HeaderMap, HeaderName, InvalidHeaderName},
7};
8use std::time::{Duration, SystemTime};
9use url::Url;
10
11impl HttpClient for Client {
12    type RequestBuilder = RequestBuilder;
13    fn get(&self, url: Url, range: Option<ProgressEntry>) -> Self::RequestBuilder {
14        let mut req = self.get(url);
15        if let Some(range) = range {
16            req = req.header(
17                header::RANGE,
18                format!("bytes={}-{}", range.start, range.end - 1),
19            );
20        }
21        req
22    }
23}
24
25impl HttpRequestBuilder for RequestBuilder {
26    type Response = Response;
27    type RequestError = ReqwestResponseError;
28    async fn send(self) -> Result<Self::Response, Self::RequestError> {
29        let res = self.send().await.map_err(ReqwestResponseError::Reqwest)?;
30        let retry_after = res.headers().get(header::RETRY_AFTER);
31        if let Some(retry_after) = retry_after
32            && let Ok(retry_after) = retry_after.to_str()
33        {
34            let retry_after = match retry_after.parse() {
35                Ok(retry_after) => Some(Duration::from_secs(retry_after)),
36                Err(_) => match parse_http_date(retry_after) {
37                    Ok(target_time) => target_time.duration_since(SystemTime::now()).ok(),
38                    Err(_) => None,
39                },
40            };
41            if let Some(retry_after) = retry_after
42                && retry_after > Duration::ZERO
43            {
44                tokio::time::sleep(retry_after).await;
45            }
46        }
47        let status = res.status();
48        if status.is_success() {
49            Ok(res)
50        } else {
51            Err(ReqwestResponseError::StatusCode(status))
52        }
53    }
54}
55
56impl HttpResponse for Response {
57    type Headers = HeaderMap;
58    type ChunkError = reqwest::Error;
59    fn headers(&self) -> &Self::Headers {
60        self.headers()
61    }
62    fn url(&self) -> &Url {
63        self.url()
64    }
65    fn chunk(
66        &mut self,
67    ) -> impl Future<Output = Result<Option<bytes::Bytes>, Self::ChunkError>> + Send {
68        Response::chunk(self)
69    }
70}
71
72impl HttpHeaders for HeaderMap {
73    type GetHeaderError = ReqwestGetHeaderError;
74    fn get(&self, header: &str) -> Result<&str, Self::GetHeaderError> {
75        let header_name: HeaderName = header
76            .parse()
77            .map_err(ReqwestGetHeaderError::InvalidHeaderName)?;
78        let header_value = self
79            .get(&header_name)
80            .ok_or(ReqwestGetHeaderError::NotFound)?;
81        header_value.to_str().map_err(ReqwestGetHeaderError::ToStr)
82    }
83}
84
85#[derive(thiserror::Error, Debug)]
86pub enum ReqwestGetHeaderError {
87    #[error("Invalid header name {0:?}")]
88    InvalidHeaderName(InvalidHeaderName),
89    #[error("Failed to convert header value to string {0:?}")]
90    ToStr(reqwest::header::ToStrError),
91    #[error("Header not found")]
92    NotFound,
93}
94
95#[derive(thiserror::Error, Debug)]
96pub enum ReqwestResponseError {
97    #[error("Reqwest error {0:?}")]
98    Reqwest(reqwest::Error),
99    #[error("Status code {0:?}")]
100    StatusCode(reqwest::StatusCode),
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        http::{HttpError, HttpPuller, Prefetch},
        url_info::FileId,
    };
    use fast_pull::{
        Event, MergeProgress,
        mem::MemPusher,
        mock::build_mock_data,
        multi::{self, download_multi},
        single::{self, download_single},
    };
    use reqwest::{Client, StatusCode};
    use std::{num::NonZero, time::Duration};

    /// `prefetch` follows a 301 redirect, reports the final URL, and decodes the
    /// percent-encoded path into the file name. It also picks up `Content-Length` and
    /// `Accept-Ranges`.
    #[tokio::test]
    async fn test_redirect_and_content_range() {
        let mut server = mockito::Server::new_async().await;

        let mock_redirect = server
            .mock("GET", "/redirect")
            .with_status(301)
            .with_header("Location", "/%e4%bd%a0%e5%a5%bd.txt")
            .create_async()
            .await;

        let mock_file = server
            .mock("GET", "/%e4%bd%a0%e5%a5%bd.txt")
            .with_status(200)
            .with_header("Content-Length", "1024")
            .with_header("Accept-Ranges", "bytes")
            .with_body(vec![0; 1024])
            .create_async()
            .await;

        let client = Client::new();
        let url = Url::parse(&format!("{}/redirect", server.url())).unwrap();
        let (url_info, _) = client.prefetch(url).await.expect("Request should succeed");

        assert_eq!(
            url_info.final_url.as_str(),
            format!("{}/%e4%bd%a0%e5%a5%bd.txt", server.url())
        );
        assert_eq!(url_info.size, 1024);
        assert_eq!(url_info.name, "你好.txt");
        assert!(url_info.supports_range);

        mock_redirect.assert_async().await;
        mock_file.assert_async().await;
    }

    /// The file name comes from `Content-Disposition` when present. Otherwise it comes from
    /// the percent-decoded last URL path segment.
    #[tokio::test]
    async fn test_filename_sources() {
        let mut server = mockito::Server::new_async().await;

        // Test with Content-Disposition header
        let mock1 = server
            .mock("GET", "/test1")
            .with_header("Content-Disposition", r#"attachment; filename="test.txt""#)
            .create_async()
            .await;
        let url = Url::parse(&format!("{}/test1", server.url())).unwrap();
        let (url_info, _) = Client::new().prefetch(url).await.unwrap();
        assert_eq!(url_info.name, "test.txt");
        mock1.assert_async().await;

        // Test URL path source
        let mock2 = server
            .mock("GET", "/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf")
            .create_async()
            .await;
        let url = Url::parse(&format!(
            "{}/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf",
            server.url()
        ))
        .unwrap();
        let (url_info, _) = Client::new().prefetch(url).await.unwrap();
        assert_eq!(url_info.name, "好好好.pdf");
        mock2.assert_async().await;
    }

    /// A 404 surfaces as `HttpError::Request(ReqwestResponseError::StatusCode(404))`.
    /// The match is deliberately exhaustive so that adding a new `HttpError` variant forces
    /// this test to be revisited.
    #[tokio::test]
    async fn test_error_handling() {
        let mut server = mockito::Server::new_async().await;
        let mock1 = server
            .mock("GET", "/404")
            .with_status(404)
            .create_async()
            .await;

        let client = Client::new();
        let url = Url::parse(&format!("{}/404", server.url())).unwrap();
        match client.prefetch(url).await {
            Ok(info) => unreachable!("404 status code should not success: {info:?}"),
            Err(err) => match err {
                HttpError::Request(e) => match e {
                    ReqwestResponseError::Reqwest(error) => unreachable!("{error:?}"),
                    ReqwestResponseError::StatusCode(status_code) => {
                        assert_eq!(status_code, StatusCode::NOT_FOUND)
                    }
                },
                HttpError::Chunk(_) => unreachable!(),
                HttpError::GetHeader(_) => unreachable!(),
                HttpError::Irrecoverable => unreachable!(),
                HttpError::MismatchedBody(file_id) => {
                    unreachable!("404 status code should not return mismatched body: {file_id:?}")
                }
            },
        }
        mock1.assert_async().await;
    }

    /// A multi-connection download against a server that honours `Range` reassembles the
    /// exact body. Pull and push progress both cover the whole file.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let mock_body_clone = mock_data.clone();
        let _mock = server
            .mock("GET", "/concurrent")
            .with_status(206)
            .with_header("Accept-Ranges", "bytes")
            .with_body_from_request(move |request| {
                // Requests without a Range header get the whole body.
                if !request.has_header("Range") {
                    return mock_body_clone.clone();
                }
                // Serve every inclusive `start-end` span of `bytes=start-end[, start-end...]`.
                let range = request.header("Range")[0];
                range
                    .to_str()
                    .unwrap()
                    .rsplit('=')
                    .next()
                    .unwrap()
                    .split(',')
                    .map(|spec| {
                        let (start, end) = spec.trim().split_once('-').unwrap();
                        start.parse::<usize>().unwrap()..=end.parse::<usize>().unwrap()
                    })
                    .flat_map(|p| mock_body_clone[p].to_vec())
                    .collect()
            })
            .create_async()
            .await;
        let puller = HttpPuller::new(
            format!("{}/concurrent", server.url()).parse().unwrap(),
            Client::new(),
            None,
            FileId::empty(),
        );
        let pusher = MemPusher::with_capacity(mock_data.len());
        // Intentionally a Vec holding one range (the whole file), not a Vec of its elements.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            multi::DownloadOptions {
                concurrent: NonZero::new(32).unwrap(),
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.clone(),
                min_chunk_size: NonZero::new(1).unwrap(),
            },
        )
        .await;

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(_, p) => {
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(_, p) => {
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        // `assert_eq!` prints both sides on failure, so no extra debug dump is needed.
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }

    /// A single-connection download reassembles the exact body. Pull and push progress both
    /// cover the whole file.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/sequential")
            .with_status(200)
            .with_body(mock_data.clone())
            .create_async()
            .await;
        let puller = HttpPuller::new(
            format!("{}/sequential", server.url()).parse().unwrap(),
            Client::new(),
            None,
            FileId::empty(),
        );
        let pusher = MemPusher::with_capacity(mock_data.len());
        // Intentionally a Vec holding one range (the whole file), not a Vec of its elements.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_single(
            puller,
            pusher.clone(),
            single::DownloadOptions {
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
            },
        )
        .await;

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(_, p) => {
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(_, p) => {
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        // `assert_eq!` prints both sides on failure, so no extra debug dump is needed.
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}