1use crate::http::{HttpClient, HttpHeaders, HttpRequestBuilder, HttpResponse};
2use fast_pull::ProgressEntry;
3use httpdate::parse_http_date;
4use reqwest::{
5 Client, RequestBuilder, Response,
6 header::{self, HeaderMap, HeaderName, InvalidHeaderName},
7};
8use std::time::{Duration, SystemTime};
9use url::Url;
10
11impl HttpClient for Client {
12 type RequestBuilder = RequestBuilder;
13 fn get(&self, url: Url, range: Option<ProgressEntry>) -> Self::RequestBuilder {
14 let mut req = self.get(url);
15 if let Some(range) = range {
16 req = req.header(
17 header::RANGE,
18 format!("bytes={}-{}", range.start, range.end - 1),
19 );
20 }
21 req
22 }
23}
24
25impl HttpRequestBuilder for RequestBuilder {
26 type Response = Response;
27 type RequestError = ReqwestResponseError;
28 async fn send(self) -> Result<Self::Response, (Self::RequestError, Option<Duration>)> {
29 let res = self
30 .send()
31 .await
32 .map_err(|e| (ReqwestResponseError::Reqwest(e), None))?;
33 let status = res.status();
34 if status.is_success() {
35 Ok(res)
36 } else {
37 let retry_after = res
38 .headers()
39 .get(header::RETRY_AFTER)
40 .and_then(|r| r.to_str().ok())
41 .and_then(|r| match r.parse() {
42 Ok(r) => Some(Duration::from_secs(r)),
43 Err(_) => match parse_http_date(r) {
44 Ok(target_time) => target_time.duration_since(SystemTime::now()).ok(),
45 Err(_) => None,
46 },
47 });
48 Err((ReqwestResponseError::StatusCode(status), retry_after))
49 }
50 }
51}
52
53impl HttpResponse for Response {
54 type Headers = HeaderMap;
55 type ChunkError = reqwest::Error;
56 fn headers(&self) -> &Self::Headers {
57 self.headers()
58 }
59 fn url(&self) -> &Url {
60 self.url()
61 }
62 async fn chunk(&mut self) -> Result<Option<bytes::Bytes>, Self::ChunkError> {
63 self.chunk().await
64 }
65}
66
67impl HttpHeaders for HeaderMap {
68 type GetHeaderError = ReqwestGetHeaderError;
69 fn get(&self, header: &str) -> Result<&str, Self::GetHeaderError> {
70 let header_name: HeaderName = header
71 .parse()
72 .map_err(ReqwestGetHeaderError::InvalidHeaderName)?;
73 let header_value = self
74 .get(&header_name)
75 .ok_or(ReqwestGetHeaderError::NotFound)?;
76 header_value.to_str().map_err(ReqwestGetHeaderError::ToStr)
77 }
78}
79
80#[derive(thiserror::Error, Debug)]
81pub enum ReqwestGetHeaderError {
82 #[error("Invalid header name {0:?}")]
83 InvalidHeaderName(InvalidHeaderName),
84 #[error("Failed to convert header value to string {0:?}")]
85 ToStr(reqwest::header::ToStrError),
86 #[error("Header not found")]
87 NotFound,
88}
89
90#[derive(thiserror::Error, Debug)]
91pub enum ReqwestResponseError {
92 #[error("Reqwest error {0:?}")]
93 Reqwest(reqwest::Error),
94 #[error("Status code {0:?}")]
95 StatusCode(reqwest::StatusCode),
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        http::{HttpError, HttpPuller, Prefetch},
        url_info::FileId,
    };
    use fast_pull::{
        Event, Merge,
        mem::MemPusher,
        mock::build_mock_data,
        multi::{self, download_multi},
        single::{self, download_single},
    };
    use reqwest::{Client, StatusCode};
    use std::time::Duration;

    /// Prefetching through a 301 redirect should resolve to the final URL and report the
    /// target's size, its percent-decoded file name and range support.
    #[tokio::test]
    async fn test_redirect_and_content_range() {
        let mut server = mockito::Server::new_async().await;
        let client = Client::builder().no_proxy().build().unwrap();

        // `/redirect` points at a percent-encoded UTF-8 path ("你好.txt").
        let mock_redirect = server
            .mock("GET", "/redirect")
            .with_status(301)
            .with_header("Location", "/%e4%bd%a0%e5%a5%bd.txt")
            .create_async()
            .await;

        // Target advertises a 1024-byte body and byte-range support.
        let mock_file = server
            .mock("GET", "/%e4%bd%a0%e5%a5%bd.txt")
            .with_status(200)
            .with_header("Content-Length", "1024")
            .with_header("Accept-Ranges", "bytes")
            .with_body(vec![0; 1024])
            .create_async()
            .await;

        let url = Url::parse(&format!("{}/redirect", server.url())).unwrap();
        let (url_info, _) = client.prefetch(url).await.expect("Request should succeed");

        // `final_url` keeps the encoded form, while `raw_name` is the decoded file name.
        assert_eq!(
            url_info.final_url.as_str(),
            format!("{}/%e4%bd%a0%e5%a5%bd.txt", server.url())
        );
        assert_eq!(url_info.size, 1024);
        assert_eq!(url_info.raw_name, "你好.txt");
        assert!(url_info.supports_range);

        // Check the mock expectations for both endpoints.
        mock_redirect.assert_async().await;
        mock_file.assert_async().await;
    }

    /// `raw_name` resolution. It comes from `Content-Disposition` (`filename`, `filename*`,
    /// with `filename*` winning when both are present). Without that header it falls back
    /// to the percent-decoded last path segment of the URL.
    #[tokio::test]
    async fn test_filename_sources() {
        let mut server = mockito::Server::new_async().await;
        let client = Client::builder().no_proxy().build().unwrap();

        // Plain quoted `filename=`.
        let mock1 = server
            .mock("GET", "/test1")
            .with_header("Content-Disposition", r#"attachment; filename="test.txt""#)
            .create_async()
            .await;
        let url = Url::parse(&format!("{}/test1", server.url())).unwrap();
        let (url_info, _) = client.prefetch(url).await.unwrap();
        assert_eq!(url_info.raw_name, "test.txt");
        mock1.assert_async().await;

        // Extended `filename*=UTF-8''...` with a percent-encoded UTF-8 name ("测试.txt").
        let mock_star = server
            .mock("GET", "/test_star")
            .with_header(
                "Content-Disposition",
                r#"attachment; filename*=UTF-8''%E6%B5%8B%E8%AF%95.txt"#,
            )
            .create_async()
            .await;
        let url = Url::parse(&format!("{}/test_star", server.url())).unwrap();
        let (url_info, _) = client.prefetch(url).await.unwrap();
        assert_eq!(url_info.raw_name, "测试.txt");
        mock_star.assert_async().await;

        // Both parameters present: the `filename*` value is expected, not "fallback.txt".
        let mock_both = server
            .mock("GET", "/test_both")
            .with_header(
                "Content-Disposition",
                r#"attachment; filename="fallback.txt"; filename*=UTF-8''%E6%B5%8B%E8%AF%95.txt"#,
            )
            .create_async()
            .await;
        let url = Url::parse(&format!("{}/test_both", server.url())).unwrap();
        let (url_info, _) = client.prefetch(url).await.unwrap();
        assert_eq!(url_info.raw_name, "测试.txt");
        mock_both.assert_async().await;

        // No `Content-Disposition`: the name comes from the decoded URL path ("好好好.pdf").
        let mock2 = server
            .mock("GET", "/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf")
            .create_async()
            .await;
        let url = Url::parse(&format!(
            "{}/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf",
            server.url()
        ))
        .unwrap();
        let (url_info, _) = client.prefetch(url).await.unwrap();
        assert_eq!(url_info.raw_name, "好好好.pdf");
        mock2.assert_async().await;
    }

    /// A 404 during prefetch must surface as
    /// `HttpError::Request(ReqwestResponseError::StatusCode(404))`. Every other error
    /// variant is treated as a test failure.
    #[tokio::test]
    async fn test_error_handling() {
        let mut server = mockito::Server::new_async().await;
        let client = Client::builder().no_proxy().build().unwrap();
        let mock1 = server
            .mock("GET", "/404")
            .with_status(404)
            .create_async()
            .await;

        let url = Url::parse(&format!("{}/404", server.url())).unwrap();
        match client.prefetch(url).await {
            Ok(info) => unreachable!("404 status code should not success: {info:?}"),
            Err((err, _)) => match err {
                HttpError::Request(e) => match e {
                    ReqwestResponseError::Reqwest(error) => unreachable!("{error:?}"),
                    ReqwestResponseError::StatusCode(status_code) => {
                        assert_eq!(status_code, StatusCode::NOT_FOUND)
                    }
                },
                HttpError::Chunk(_) => unreachable!(),
                HttpError::GetHeader(_) => unreachable!(),
                HttpError::Irrecoverable => unreachable!(),
                HttpError::MismatchedBody(file_id) => {
                    unreachable!("404 status code should not return mismatched body: {file_id:?}")
                }
            },
        }
        mock1.assert_async().await;
    }

    /// Multi-connection download of a 300 MiB body served by a range-aware mock.
    /// Pull and push progress must each merge into the single full-file chunk, and the
    /// pushed bytes must equal the source data.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let client = Client::builder().no_proxy().build().unwrap();
        let mock_body_clone = mock_data.clone();
        // NOTE(review): the mock always replies 206, even when no `Range` header is sent.
        let _mock = server
            .mock("GET", "/concurrent")
            .with_status(206)
            .with_header("Accept-Ranges", "bytes")
            .with_body_from_request(move |request| {
                // No `Range`: serve the whole body.
                if !request.has_header("Range") {
                    return mock_body_clone.clone();
                }
                // Parse "bytes=a-b[, c-d...]" as inclusive ranges and concatenate the
                // matching slices. Malformed or out-of-bounds ranges panic inside the mock.
                let range = request.header("Range")[0];
                println!("range: {range:?}");
                range
                    .to_str()
                    .unwrap()
                    .rsplit('=')
                    .next()
                    .unwrap()
                    .split(',')
                    .map(|p| p.trim().splitn(2, '-'))
                    .map(|mut p| {
                        let start = p.next().unwrap().parse::<usize>().unwrap();
                        let end = p.next().unwrap().parse::<usize>().unwrap();
                        start..=end
                    })
                    .flat_map(|p| mock_body_clone[p].to_vec())
                    .collect()
            })
            .create_async()
            .await;
        let puller = HttpPuller::new(
            format!("{}/concurrent", server.url()).parse().unwrap(),
            client,
            None,
            FileId::default(),
        );
        let pusher = MemPusher::with_capacity(mock_data.len());
        // Intentionally a one-element Vec: a single chunk covering the whole file.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            multi::DownloadOptions {
                concurrent: 32,
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.iter(),
                pull_timeout: Duration::from_secs(5),
                min_chunk_size: 1,
                max_speculative: 3,
            },
        );

        // Drain events until the channel closes, merging progress ranges as they arrive.
        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(_, p) => {
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(_, p) => {
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }

    /// Single-connection download of a 300 MiB body. Progress must cover the full file
    /// and the pushed bytes must equal the source data.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let client = Client::builder().no_proxy().build().unwrap();
        let _mock = server
            .mock("GET", "/sequential")
            .with_status(200)
            .with_body(mock_data.clone())
            .create_async()
            .await;
        let puller = HttpPuller::new(
            format!("{}/sequential", server.url()).parse().unwrap(),
            client,
            None,
            FileId::default(),
        );
        let pusher = MemPusher::with_capacity(mock_data.len());
        // Intentionally a one-element Vec: a single chunk covering the whole file.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_single(
            puller,
            pusher.clone(),
            single::DownloadOptions {
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
            },
        );

        // Drain events until the channel closes, merging progress ranges as they arrive.
        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(_, p) => {
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(_, p) => {
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}
369}