1use crate::http::{HttpClient, HttpHeaders, HttpRequestBuilder, HttpResponse};
2use fast_pull::ProgressEntry;
3use httpdate::parse_http_date;
4use reqwest::{
5 Client, RequestBuilder, Response,
6 header::{self, HeaderMap, HeaderName, InvalidHeaderName},
7};
8use std::time::{Duration, SystemTime};
9use url::Url;
10
11impl HttpClient for Client {
12 type RequestBuilder = RequestBuilder;
13 fn get(&self, url: Url, range: Option<ProgressEntry>) -> Self::RequestBuilder {
14 let mut req = self.get(url);
15 if let Some(range) = range {
16 req = req.header(
17 header::RANGE,
18 format!("bytes={}-{}", range.start, range.end - 1),
19 );
20 }
21 req
22 }
23}
24
25impl HttpRequestBuilder for RequestBuilder {
26 type Response = Response;
27 type RequestError = ReqwestResponseError;
28 async fn send(self) -> Result<Self::Response, Self::RequestError> {
29 let res = self.send().await.map_err(ReqwestResponseError::Reqwest)?;
30 let retry_after = res.headers().get(header::RETRY_AFTER);
31 if let Some(retry_after) = retry_after
32 && let Ok(retry_after) = retry_after.to_str()
33 {
34 let retry_after = match retry_after.parse() {
35 Ok(retry_after) => Some(Duration::from_secs(retry_after)),
36 Err(_) => match parse_http_date(retry_after) {
37 Ok(target_time) => target_time.duration_since(SystemTime::now()).ok(),
38 Err(_) => None,
39 },
40 };
41 if let Some(retry_after) = retry_after
42 && retry_after > Duration::ZERO
43 {
44 tokio::time::sleep(retry_after).await;
45 }
46 }
47 let status = res.status();
48 if status.is_success() {
49 Ok(res)
50 } else {
51 Err(ReqwestResponseError::StatusCode(status))
52 }
53 }
54}
55
56impl HttpResponse for Response {
57 type Headers = HeaderMap;
58 type ChunkError = reqwest::Error;
59 fn headers(&self) -> &Self::Headers {
60 self.headers()
61 }
62 fn url(&self) -> &Url {
63 self.url()
64 }
65 fn chunk(
66 &mut self,
67 ) -> impl Future<Output = Result<Option<bytes::Bytes>, Self::ChunkError>> + Send {
68 Response::chunk(self)
69 }
70}
71
72impl HttpHeaders for HeaderMap {
73 type GetHeaderError = ReqwestGetHeaderError;
74 fn get(&self, header: &str) -> Result<&str, Self::GetHeaderError> {
75 let header_name: HeaderName = header
76 .parse()
77 .map_err(ReqwestGetHeaderError::InvalidHeaderName)?;
78 let header_value = self
79 .get(&header_name)
80 .ok_or(ReqwestGetHeaderError::NotFound)?;
81 header_value.to_str().map_err(ReqwestGetHeaderError::ToStr)
82 }
83}
84
85#[derive(thiserror::Error, Debug)]
86pub enum ReqwestGetHeaderError {
87 #[error("Invalid header name {0:?}")]
88 InvalidHeaderName(InvalidHeaderName),
89 #[error("Failed to convert header value to string {0:?}")]
90 ToStr(reqwest::header::ToStrError),
91 #[error("Header not found")]
92 NotFound,
93}
94
95#[derive(thiserror::Error, Debug)]
96pub enum ReqwestResponseError {
97 #[error("Reqwest error {0:?}")]
98 Reqwest(reqwest::Error),
99 #[error("Status code {0:?}")]
100 StatusCode(reqwest::StatusCode),
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        http::{HttpError, HttpPuller, Prefetch},
        url_info::FileId,
    };
    use fast_pull::{
        Event, MergeProgress,
        mem::MemPusher,
        mock::build_mock_data,
        multi::{self, download_multi},
        single::{self, download_single},
    };
    use reqwest::{Client, StatusCode};
    use std::{num::NonZero, time::Duration};

    /// Prefetching a URL that answers with a 301 must report the redirect
    /// target as the final URL, together with the size, decoded file name and
    /// range support advertised by the target.
    #[tokio::test]
    async fn test_redirect_and_content_range() {
        let mut server = mockito::Server::new_async().await;

        // First hop: a permanent redirect to a percent-encoded path.
        let redirect_mock = server
            .mock("GET", "/redirect")
            .with_status(301)
            .with_header("Location", "/%e4%bd%a0%e5%a5%bd.txt")
            .create_async()
            .await;

        // Second hop: the actual file, advertising its length and range support.
        let file_mock = server
            .mock("GET", "/%e4%bd%a0%e5%a5%bd.txt")
            .with_status(200)
            .with_header("Content-Length", "1024")
            .with_header("Accept-Ranges", "bytes")
            .with_body(vec![0; 1024])
            .create_async()
            .await;

        let start_url = Url::parse(&format!("{}/redirect", server.url())).unwrap();
        let (url_info, _) = Client::new()
            .prefetch(start_url)
            .await
            .expect("Request should succeed");

        assert_eq!(
            url_info.final_url.as_str(),
            format!("{}/%e4%bd%a0%e5%a5%bd.txt", server.url())
        );
        assert_eq!(url_info.size, 1024);
        assert_eq!(url_info.name, "你好.txt");
        assert!(url_info.supports_range);

        redirect_mock.assert_async().await;
        file_mock.assert_async().await;
    }

    /// The file name is taken from `Content-Disposition` when that header is
    /// present, and from the percent-decoded last path segment otherwise.
    #[tokio::test]
    async fn test_filename_sources() {
        let mut server = mockito::Server::new_async().await;

        // Case 1: the name comes from the `Content-Disposition` header.
        let disposition_mock = server
            .mock("GET", "/test1")
            .with_header("Content-Disposition", r#"attachment; filename="test.txt""#)
            .create_async()
            .await;
        let disposition_url = Url::parse(&format!("{}/test1", server.url())).unwrap();
        let (url_info, _) = Client::new().prefetch(disposition_url).await.unwrap();
        assert_eq!(url_info.name, "test.txt");
        disposition_mock.assert_async().await;

        // Case 2: no header, so the name comes from the URL path.
        let path_mock = server
            .mock("GET", "/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf")
            .create_async()
            .await;
        let path_url = Url::parse(&format!(
            "{}/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf",
            server.url()
        ))
        .unwrap();
        let (url_info, _) = Client::new().prefetch(path_url).await.unwrap();
        assert_eq!(url_info.name, "好好好.pdf");
        path_mock.assert_async().await;
    }

    /// A 404 response must surface as
    /// `HttpError::Request(ReqwestResponseError::StatusCode(404))`; every other
    /// outcome is treated as a test failure.
    #[tokio::test]
    async fn test_error_handling() {
        let mut server = mockito::Server::new_async().await;
        let not_found_mock = server
            .mock("GET", "/404")
            .with_status(404)
            .create_async()
            .await;

        let client = Client::new();
        let missing_url = Url::parse(&format!("{}/404", server.url())).unwrap();
        match client.prefetch(missing_url).await {
            Ok(info) => unreachable!("404 status code should not success: {info:?}"),
            Err(HttpError::Request(ReqwestResponseError::StatusCode(status_code))) => {
                assert_eq!(status_code, StatusCode::NOT_FOUND)
            }
            Err(HttpError::Request(ReqwestResponseError::Reqwest(error))) => {
                unreachable!("{error:?}")
            }
            Err(HttpError::Chunk(_)) => unreachable!(),
            Err(HttpError::GetHeader(_)) => unreachable!(),
            Err(HttpError::Irrecoverable) => unreachable!(),
            Err(HttpError::MismatchedBody(file_id)) => {
                unreachable!("404 status code should not return mismatched body: {file_id:?}")
            }
        }
        not_found_mock.assert_async().await;
    }

    /// Runs `download_multi` (with `concurrent` set to 32) against a mock
    /// server that answers `Range` requests, then checks that the merged pull
    /// and push progress both cover the whole payload and that the pushed
    /// bytes equal the mock data.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let mock_body_clone = mock_data.clone();
        let _mock = server
            .mock("GET", "/concurrent")
            .with_status(206)
            .with_header("Accept-Ranges", "bytes")
            .with_body_from_request(move |request| {
                // Without a `Range` header the whole buffer is served.
                if !request.has_header("Range") {
                    return mock_body_clone.clone();
                }
                let range = request.header("Range")[0];
                println!("range: {range:?}");
                // Take the text after the last `=` and read it as a
                // comma-separated list of inclusive `start-end` pairs.
                let spec = range.to_str().unwrap().rsplit('=').next().unwrap();
                spec.split(',')
                    .map(|part| {
                        let (first, last) = part.trim().split_once('-').unwrap();
                        let first = first.parse::<usize>().unwrap();
                        let last = last.parse::<usize>().unwrap();
                        first..=last
                    })
                    .flat_map(|span| mock_body_clone[span].to_vec())
                    .collect()
            })
            .create_async()
            .await;

        let puller = HttpPuller::new(
            format!("{}/concurrent", server.url()).parse().unwrap(),
            Client::new(),
            None,
            FileId::empty(),
        );
        let pusher = MemPusher::with_capacity(mock_data.len());
        // A single range spanning the whole payload is intended here.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let options = multi::DownloadOptions {
            concurrent: NonZero::new(32).unwrap(),
            retry_gap: Duration::from_secs(1),
            push_queue_cap: 1024,
            download_chunks: download_chunks.clone(),
            min_chunk_size: NonZero::new(1).unwrap(),
        };
        let result = download_multi(puller, pusher.clone(), options).await;

        // Drain the event stream, merging every reported progress entry.
        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::PullProgress(_, entry) => {
                    pull_progress.merge_progress(entry);
                }
                Event::PushProgress(_, entry) => {
                    push_progress.merge_progress(entry);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }

    /// Runs `download_single` against a mock server that returns the whole
    /// payload with status 200, then checks that the merged pull and push
    /// progress both cover the whole payload and that the pushed bytes equal
    /// the mock data.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/sequential")
            .with_status(200)
            .with_body(mock_data.clone())
            .create_async()
            .await;

        let puller = HttpPuller::new(
            format!("{}/sequential", server.url()).parse().unwrap(),
            Client::new(),
            None,
            FileId::empty(),
        );
        let pusher = MemPusher::with_capacity(mock_data.len());
        // A single range spanning the whole payload is intended here.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let options = single::DownloadOptions {
            retry_gap: Duration::from_secs(1),
            push_queue_cap: 1024,
        };
        let result = download_single(puller, pusher.clone(), options).await;

        // Drain the event stream, merging every reported progress entry.
        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::PullProgress(_, entry) => {
                    pull_progress.merge_progress(entry);
                }
                Event::PushProgress(_, entry) => {
                    push_progress.merge_progress(entry);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}