1use crate::http::{HttpClient, HttpHeaders, HttpRequestBuilder, HttpResponse};
2use fast_pull::ProgressEntry;
3use httpdate::parse_http_date;
4use reqwest::{
5 Client, RequestBuilder, Response,
6 header::{self, HeaderMap, HeaderName, InvalidHeaderName},
7};
8use std::time::{Duration, SystemTime};
9use url::Url;
10
11impl HttpClient for Client {
12 type RequestBuilder = RequestBuilder;
13 fn get(&self, url: Url, range: Option<ProgressEntry>) -> Self::RequestBuilder {
14 let mut req = self.get(url);
15 if let Some(range) = range {
16 req = req.header(
17 header::RANGE,
18 format!("bytes={}-{}", range.start, range.end - 1),
19 );
20 }
21 req
22 }
23}
24
25impl HttpRequestBuilder for RequestBuilder {
26 type Response = Response;
27 type RequestError = ReqwestResponseError;
28 async fn send(self) -> Result<Self::Response, (Self::RequestError, Option<Duration>)> {
29 let res = self
30 .send()
31 .await
32 .map_err(|e| (ReqwestResponseError::Reqwest(e), None))?;
33 let status = res.status();
34 if status.is_success() {
35 Ok(res)
36 } else {
37 let retry_after = res
38 .headers()
39 .get(header::RETRY_AFTER)
40 .and_then(|r| r.to_str().ok())
41 .and_then(|r| match r.parse() {
42 Ok(r) => Some(Duration::from_secs(r)),
43 Err(_) => match parse_http_date(r) {
44 Ok(target_time) => target_time.duration_since(SystemTime::now()).ok(),
45 Err(_) => None,
46 },
47 });
48 Err((ReqwestResponseError::StatusCode(status), retry_after))
49 }
50 }
51}
52
53impl HttpResponse for Response {
54 type Headers = HeaderMap;
55 type ChunkError = reqwest::Error;
56 fn headers(&self) -> &Self::Headers {
57 self.headers()
58 }
59 fn url(&self) -> &Url {
60 self.url()
61 }
62 async fn chunk(&mut self) -> Result<Option<bytes::Bytes>, Self::ChunkError> {
63 self.chunk().await
64 }
65}
66
67impl HttpHeaders for HeaderMap {
68 type GetHeaderError = ReqwestGetHeaderError;
69 fn get(&self, header: &str) -> Result<&str, Self::GetHeaderError> {
70 let header_name: HeaderName = header
71 .parse()
72 .map_err(ReqwestGetHeaderError::InvalidHeaderName)?;
73 let header_value = self
74 .get(&header_name)
75 .ok_or(ReqwestGetHeaderError::NotFound)?;
76 header_value.to_str().map_err(ReqwestGetHeaderError::ToStr)
77 }
78}
79
80#[derive(thiserror::Error, Debug)]
81pub enum ReqwestGetHeaderError {
82 #[error("Invalid header name {0:?}")]
83 InvalidHeaderName(InvalidHeaderName),
84 #[error("Failed to convert header value to string {0:?}")]
85 ToStr(reqwest::header::ToStrError),
86 #[error("Header not found")]
87 NotFound,
88}
89
90#[derive(thiserror::Error, Debug)]
91pub enum ReqwestResponseError {
92 #[error("Reqwest error {0:?}")]
93 Reqwest(reqwest::Error),
94 #[error("Status code {0:?}")]
95 StatusCode(reqwest::StatusCode),
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        http::{HttpError, HttpPuller, Prefetch},
        url_info::FileId,
    };
    use fast_pull::{
        Event, Merge,
        mem::MemPusher,
        mock::build_mock_data,
        multi::{self, download_multi},
        single::{self, download_single},
    };
    use reqwest::{Client, StatusCode};
    use std::time::Duration;

    /// `prefetch` follows a 301 redirect. It reports the redirect target's URL,
    /// size, percent-decoded file name and `Accept-Ranges` support.
    #[tokio::test]
    async fn test_redirect_and_content_range() {
        let mut server = mockito::Server::new_async().await;
        let client = Client::builder().no_proxy().build().unwrap();

        let mock_redirect = server
            .mock("GET", "/redirect")
            .with_status(301)
            .with_header("Location", "/%e4%bd%a0%e5%a5%bd.txt")
            .create_async()
            .await;

        let mock_file = server
            .mock("GET", "/%e4%bd%a0%e5%a5%bd.txt")
            .with_status(200)
            .with_header("Content-Length", "1024")
            .with_header("Accept-Ranges", "bytes")
            .with_body(vec![0; 1024])
            .create_async()
            .await;

        let url = Url::parse(&format!("{}/redirect", server.url())).unwrap();
        let (url_info, _) = client.prefetch(url).await.expect("Request should succeed");

        assert_eq!(
            url_info.final_url.as_str(),
            format!("{}/%e4%bd%a0%e5%a5%bd.txt", server.url())
        );
        assert_eq!(url_info.size, 1024);
        assert_eq!(url_info.raw_name, "你好.txt");
        assert!(url_info.supports_range);

        mock_redirect.assert_async().await;
        mock_file.assert_async().await;
    }

    /// The file name comes from `Content-Disposition` when present. Otherwise it
    /// comes from the percent-decoded last URL path segment.
    #[tokio::test]
    async fn test_filename_sources() {
        let mut server = mockito::Server::new_async().await;
        let client = Client::builder().no_proxy().build().unwrap();

        let mock1 = server
            .mock("GET", "/test1")
            .with_header("Content-Disposition", r#"attachment; filename="test.txt""#)
            .create_async()
            .await;
        let url = Url::parse(&format!("{}/test1", server.url())).unwrap();
        let (url_info, _) = client.prefetch(url).await.unwrap();
        assert_eq!(url_info.raw_name, "test.txt");
        mock1.assert_async().await;

        let mock2 = server
            .mock("GET", "/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf")
            .create_async()
            .await;
        let url = Url::parse(&format!(
            "{}/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf",
            server.url()
        ))
        .unwrap();
        let (url_info, _) = client.prefetch(url).await.unwrap();
        assert_eq!(url_info.raw_name, "好好好.pdf");
        mock2.assert_async().await;
    }

    /// A 404 surfaces as `HttpError::Request(ReqwestResponseError::StatusCode(404))`.
    #[tokio::test]
    async fn test_error_handling() {
        let mut server = mockito::Server::new_async().await;
        let client = Client::builder().no_proxy().build().unwrap();
        let mock1 = server
            .mock("GET", "/404")
            .with_status(404)
            .create_async()
            .await;

        let url = Url::parse(&format!("{}/404", server.url())).unwrap();
        match client.prefetch(url).await {
            Ok(info) => unreachable!("404 status code should not success: {info:?}"),
            Err((err, _)) => match err {
                HttpError::Request(e) => match e {
                    ReqwestResponseError::Reqwest(error) => unreachable!("{error:?}"),
                    ReqwestResponseError::StatusCode(status_code) => {
                        assert_eq!(status_code, StatusCode::NOT_FOUND)
                    }
                },
                HttpError::Chunk(_) => unreachable!(),
                HttpError::GetHeader(_) => unreachable!(),
                HttpError::Irrecoverable => unreachable!(),
                HttpError::MismatchedBody(file_id) => {
                    unreachable!("404 status code should not return mismatched body: {file_id:?}")
                }
            },
        }
        mock1.assert_async().await;
    }

    /// Multi-connection download: the mock answers every request with the exact
    /// byte ranges it asked for. The pulled/pushed progress and the final buffer
    /// must each cover the whole payload.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let client = Client::builder().no_proxy().build().unwrap();
        let mock_body_clone = mock_data.clone();
        let _mock = server
            .mock("GET", "/concurrent")
            .with_status(206)
            .with_header("Accept-Ranges", "bytes")
            .with_body_from_request(move |request| {
                if !request.has_header("Range") {
                    return mock_body_clone.clone();
                }
                // The header has the form `bytes=a-b[, c-d...]` with inclusive
                // ends, matching what `HttpClient::get` sends, hence `start..=end`.
                let range = request.header("Range")[0];
                range
                    .to_str()
                    .unwrap()
                    .rsplit('=')
                    .next()
                    .unwrap()
                    .split(',')
                    .map(|p| p.trim().splitn(2, '-'))
                    .map(|mut p| {
                        let start = p.next().unwrap().parse::<usize>().unwrap();
                        let end = p.next().unwrap().parse::<usize>().unwrap();
                        start..=end
                    })
                    .flat_map(|p| mock_body_clone[p].to_vec())
                    .collect()
            })
            .create_async()
            .await;
        let puller = HttpPuller::new(
            format!("{}/concurrent", server.url()).parse().unwrap(),
            client,
            None,
            FileId::default(),
        );
        let pusher = MemPusher::with_capacity(mock_data.len());
        // Intentionally one chunk spanning the whole payload, not a range expansion.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            multi::DownloadOptions {
                concurrent: 32,
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.iter(),
                min_chunk_size: 1,
            },
        );

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(_, p) => {
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(_, p) => {
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }

    /// Single-connection download of a plain 200 response. The pulled/pushed
    /// progress and the final buffer must each cover the whole payload.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let client = Client::builder().no_proxy().build().unwrap();
        let _mock = server
            .mock("GET", "/sequential")
            .with_status(200)
            .with_body(mock_data.clone())
            .create_async()
            .await;
        let puller = HttpPuller::new(
            format!("{}/sequential", server.url()).parse().unwrap(),
            client,
            None,
            FileId::default(),
        );
        let pusher = MemPusher::with_capacity(mock_data.len());
        // Intentionally one chunk spanning the whole payload, not a range expansion.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_single(
            puller,
            pusher.clone(),
            single::DownloadOptions {
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
            },
        );

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(_, p) => {
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(_, p) => {
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}
340}