1use crate::http::{HttpClient, HttpHeaders, HttpRequestBuilder, HttpResponse};
2use fast_pull::ProgressEntry;
3use httpdate::parse_http_date;
4use reqwest::{
5 Client, RequestBuilder, Response,
6 header::{self, HeaderMap, HeaderName, InvalidHeaderName},
7};
8use std::time::{Duration, SystemTime};
9use url::Url;
10
11impl HttpClient for Client {
12 type RequestBuilder = RequestBuilder;
13 fn get(&self, url: Url, range: Option<ProgressEntry>) -> Self::RequestBuilder {
14 let mut req = self.get(url);
15 if let Some(range) = range {
16 req = req.header(
17 header::RANGE,
18 format!("bytes={}-{}", range.start, range.end - 1),
19 );
20 }
21 req
22 }
23}
24
25impl HttpRequestBuilder for RequestBuilder {
26 type Response = Response;
27 type RequestError = ReqwestResponseError;
28 async fn send(self) -> Result<Self::Response, (Self::RequestError, Option<Duration>)> {
29 let res = self
30 .send()
31 .await
32 .map_err(|e| (ReqwestResponseError::Reqwest(e), None))?;
33 let status = res.status();
34 if status.is_success() {
35 Ok(res)
36 } else {
37 let retry_after = res
38 .headers()
39 .get(header::RETRY_AFTER)
40 .and_then(|r| r.to_str().ok())
41 .and_then(|r| match r.parse() {
42 Ok(r) => Some(Duration::from_secs(r)),
43 Err(_) => match parse_http_date(r) {
44 Ok(target_time) => target_time.duration_since(SystemTime::now()).ok(),
45 Err(_) => None,
46 },
47 });
48 Err((ReqwestResponseError::StatusCode(status), retry_after))
49 }
50 }
51}
52
53impl HttpResponse for Response {
54 type Headers = HeaderMap;
55 type ChunkError = reqwest::Error;
56 fn headers(&self) -> &Self::Headers {
57 self.headers()
58 }
59 fn url(&self) -> &Url {
60 self.url()
61 }
62 fn chunk(
63 &mut self,
64 ) -> impl Future<Output = Result<Option<bytes::Bytes>, Self::ChunkError>> + Send {
65 self.chunk()
66 }
67}
68
69impl HttpHeaders for HeaderMap {
70 type GetHeaderError = ReqwestGetHeaderError;
71 fn get(&self, header: &str) -> Result<&str, Self::GetHeaderError> {
72 let header_name: HeaderName = header
73 .parse()
74 .map_err(ReqwestGetHeaderError::InvalidHeaderName)?;
75 let header_value = self
76 .get(&header_name)
77 .ok_or(ReqwestGetHeaderError::NotFound)?;
78 header_value.to_str().map_err(ReqwestGetHeaderError::ToStr)
79 }
80}
81
82#[derive(thiserror::Error, Debug)]
83pub enum ReqwestGetHeaderError {
84 #[error("Invalid header name {0:?}")]
85 InvalidHeaderName(InvalidHeaderName),
86 #[error("Failed to convert header value to string {0:?}")]
87 ToStr(reqwest::header::ToStrError),
88 #[error("Header not found")]
89 NotFound,
90}
91
92#[derive(thiserror::Error, Debug)]
93pub enum ReqwestResponseError {
94 #[error("Reqwest error {0:?}")]
95 Reqwest(reqwest::Error),
96 #[error("Status code {0:?}")]
97 StatusCode(reqwest::StatusCode),
98}
99
// Tests for the reqwest-backed HTTP implementation, run against a local
// `mockito` server.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        http::{HttpError, HttpPuller, Prefetch},
        url_info::FileId,
    };
    use fast_pull::{
        Event, MergeProgress,
        mem::MemPusher,
        mock::build_mock_data,
        multi::{self, download_multi},
        single::{self, download_single},
    };
    use reqwest::{Client, StatusCode};
    use std::{num::NonZero, time::Duration};

    // `prefetch` follows a 301 redirect to a percent-encoded path and reports:
    // - the final URL after the redirect;
    // - the size advertised by the mock's `Content-Length`;
    // - the percent-decoded file name;
    // - range support (the mock sends `Accept-Ranges: bytes`).
    #[tokio::test]
    async fn test_redirect_and_content_range() {
        let mut server = mockito::Server::new_async().await;

        let mock_redirect = server
            .mock("GET", "/redirect")
            .with_status(301)
            .with_header("Location", "/%e4%bd%a0%e5%a5%bd.txt")
            .create_async()
            .await;

        let mock_file = server
            .mock("GET", "/%e4%bd%a0%e5%a5%bd.txt")
            .with_status(200)
            .with_header("Content-Length", "1024")
            .with_header("Accept-Ranges", "bytes")
            .with_body(vec![0; 1024])
            .create_async()
            .await;

        let client = Client::new();
        let url = Url::parse(&format!("{}/redirect", server.url())).unwrap();
        let (url_info, _) = client.prefetch(url).await.expect("Request should succeed");

        assert_eq!(
            url_info.final_url.as_str(),
            format!("{}/%e4%bd%a0%e5%a5%bd.txt", server.url())
        );
        assert_eq!(url_info.size, 1024);
        assert_eq!(url_info.raw_name, "你好.txt");
        assert!(url_info.supports_range);

        // Both endpoints must actually have been hit.
        mock_redirect.assert_async().await;
        mock_file.assert_async().await;
    }

    // The file name comes from `Content-Disposition` when present (test1).
    // Otherwise it is the percent-decoded last URL path segment (test2).
    // NOTE(review): these mocks set no explicit status; mockito's default
    // (assumed 200) applies — confirm.
    #[tokio::test]
    async fn test_filename_sources() {
        let mut server = mockito::Server::new_async().await;

        let mock1 = server
            .mock("GET", "/test1")
            .with_header("Content-Disposition", r#"attachment; filename="test.txt""#)
            .create_async()
            .await;
        let url = Url::parse(&format!("{}/test1", server.url())).unwrap();
        let (url_info, _) = Client::new().prefetch(url).await.unwrap();
        assert_eq!(url_info.raw_name, "test.txt");
        mock1.assert_async().await;

        let mock2 = server
            .mock("GET", "/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf")
            .create_async()
            .await;
        let url = Url::parse(&format!(
            "{}/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf",
            server.url()
        ))
        .unwrap();
        let (url_info, _) = Client::new().prefetch(url).await.unwrap();
        assert_eq!(url_info.raw_name, "好好好.pdf");
        mock2.assert_async().await;
    }

    // A 404 must surface as `HttpError::Request(ReqwestResponseError::StatusCode(404))`.
    // The match is deliberately exhaustive (no `_` arm), so adding a new
    // `HttpError` variant forces this test to be revisited.
    #[tokio::test]
    async fn test_error_handling() {
        let mut server = mockito::Server::new_async().await;
        let mock1 = server
            .mock("GET", "/404")
            .with_status(404)
            .create_async()
            .await;

        let client = Client::new();
        let url = Url::parse(&format!("{}/404", server.url())).unwrap();
        match client.prefetch(url).await {
            Ok(info) => unreachable!("404 status code should not success: {info:?}"),
            Err((err, _)) => match err {
                HttpError::Request(e) => match e {
                    ReqwestResponseError::Reqwest(error) => unreachable!("{error:?}"),
                    ReqwestResponseError::StatusCode(status_code) => {
                        assert_eq!(status_code, StatusCode::NOT_FOUND)
                    }
                },
                HttpError::Chunk(_) => unreachable!(),
                HttpError::GetHeader(_) => unreachable!(),
                HttpError::Irrecoverable => unreachable!(),
                HttpError::MismatchedBody(file_id) => {
                    unreachable!("404 status code should not return mismatched body: {file_id:?}")
                }
            },
        }
        mock1.assert_async().await;
    }

    // Multi-connection download against a mock that serves `Range` requests.
    // Checks that pull and push progress both cover the whole file and that
    // the pushed bytes equal the source data.
    // NOTE(review): uses ~300 MiB of mock data (assuming `build_mock_data`
    // takes a byte count), so this test is heavy.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let mock_body_clone = mock_data.clone();
        // Every request gets a 206. Without a `Range` header the full body is
        // returned; otherwise the requested spans are concatenated.
        let _mock = server
            .mock("GET", "/concurrent")
            .with_status(206)
            .with_header("Accept-Ranges", "bytes")
            .with_body_from_request(move |request| {
                if !request.has_header("Range") {
                    return mock_body_clone.clone();
                }
                let range = request.header("Range")[0];
                println!("range: {range:?}");
                // Parse "bytes=a-b[, c-d...]".
                // NOTE(review): only explicit `start-end` spans are handled.
                // Open-ended (`start-`) or suffix (`-n`) forms would panic in
                // the `unwrap`s below.
                range
                    .to_str()
                    .unwrap()
                    .rsplit('=')
                    .next()
                    .unwrap()
                    .split(',')
                    .map(|p| p.trim().splitn(2, '-'))
                    .map(|mut p| {
                        let start = p.next().unwrap().parse::<usize>().unwrap();
                        let end = p.next().unwrap().parse::<usize>().unwrap();
                        // HTTP byte ranges are inclusive on both ends.
                        start..=end
                    })
                    .flat_map(|p| mock_body_clone[p].to_vec())
                    .collect()
            })
            .create_async()
            .await;
        let puller = HttpPuller::new(
            format!("{}/concurrent", server.url()).parse().unwrap(),
            Client::new(),
            None,
            FileId::empty(),
        );
        let pusher = MemPusher::with_capacity(mock_data.len());
        // A single-element `Vec` of ranges is intended here (the whole file
        // as one chunk). Clippy would otherwise flag it as a likely typo for
        // `(0..n).collect()`.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            multi::DownloadOptions {
                concurrent: NonZero::new(32).unwrap(),
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.clone(),
                min_chunk_size: NonZero::new(1).unwrap(),
            },
        )
        .await;

        // Drain events until `recv` fails (presumably once the download
        // finishes and the channel closes).
        // Pulled and pushed ranges are merged separately.
        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(_, p) => {
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(_, p) => {
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        // The merged progress must cover exactly the requested chunk(s).
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }

    // Single-connection download from a plain 200 response with no range
    // support. Uses the same progress and content checks as the concurrent
    // test.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/sequential")
            .with_status(200)
            .with_body(mock_data.clone())
            .create_async()
            .await;
        let puller = HttpPuller::new(
            format!("{}/sequential", server.url()).parse().unwrap(),
            Client::new(),
            None,
            FileId::empty(),
        );
        let pusher = MemPusher::with_capacity(mock_data.len());
        // Single whole-file range is intentional (see clippy note above).
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_single(
            puller,
            pusher.clone(),
            single::DownloadOptions {
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
            },
        )
        .await;

        // Collect pull/push progress until the event channel stops yielding.
        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(_, p) => {
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(_, p) => {
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}