1use crate::http::{HttpClient, HttpHeaders, HttpRequestBuilder, HttpResponse};
2use fast_pull::ProgressEntry;
3use reqwest::{
4 Client, RequestBuilder, Response,
5 header::{self, HeaderMap, HeaderName, InvalidHeaderName},
6};
7use url::Url;
8
9impl HttpClient for Client {
10 type RequestBuilder = RequestBuilder;
11 fn get(&self, url: Url, range: Option<ProgressEntry>) -> Self::RequestBuilder {
12 let mut req = self.get(url);
13 if let Some(range) = range {
14 req = req.header(
15 header::RANGE,
16 format!("bytes={}-{}", range.start, range.end - 1),
17 );
18 }
19 req
20 }
21}
22
23impl HttpRequestBuilder for RequestBuilder {
24 type Response = Response;
25 type RequestError = reqwest::Error;
26 async fn send(self) -> Result<Self::Response, Self::RequestError> {
27 let res = self.send().await?;
28 res.error_for_status()
29 }
30}
31
32impl HttpResponse for Response {
33 type Headers = HeaderMap;
34 type ChunkError = reqwest::Error;
35 fn headers(&self) -> &Self::Headers {
36 self.headers()
37 }
38 fn url(&self) -> &Url {
39 self.url()
40 }
41 fn chunk(
42 &mut self,
43 ) -> impl Future<Output = Result<Option<bytes::Bytes>, Self::ChunkError>> + Send {
44 Response::chunk(self)
45 }
46}
47
48impl HttpHeaders for HeaderMap {
49 type GetHeaderError = ReqwestGetHeaderError;
50 fn get(&self, header: &str) -> Result<&str, Self::GetHeaderError> {
51 let header_name: HeaderName = header
52 .parse()
53 .map_err(ReqwestGetHeaderError::InvalidHeaderName)?;
54 let header_value = self
55 .get(&header_name)
56 .ok_or(ReqwestGetHeaderError::NotFound)?;
57 header_value.to_str().map_err(ReqwestGetHeaderError::ToStr)
58 }
59}
60
61#[derive(thiserror::Error, Debug)]
62pub enum ReqwestGetHeaderError {
63 #[error("Invalid header name {0:?}")]
64 InvalidHeaderName(InvalidHeaderName),
65 #[error("Failed to convert header value to string {0:?}")]
66 ToStr(reqwest::header::ToStrError),
67 #[error("Header not found")]
68 NotFound,
69}
70
#[cfg(test)]
mod tests {
    // Tests for the reqwest-backed HTTP trait impls in this file. Each test runs against
    // a local `mockito` server.
    use super::*;
    use crate::http::{HttpError, HttpPuller, Prefetch};
    use fast_pull::{
        Event, MergeProgress,
        mem::MemPusher,
        mock::build_mock_data,
        multi::{self, download_multi},
        single::{self, download_single},
    };
    use reqwest::{Client, StatusCode};
    use std::{num::NonZero, time::Duration};

    // Prefetching a URL that 301-redirects to a percent-encoded path should report:
    // - the post-redirect URL,
    // - the `Content-Length` size,
    // - the percent-decoded file name,
    // - range support.
    #[tokio::test]
    async fn test_redirect_and_content_range() {
        let mut server = mockito::Server::new_async().await;

        // `/%e4%bd%a0%e5%a5%bd.txt` is the percent-encoded UTF-8 form of "你好.txt".
        let mock_redirect = server
            .mock("GET", "/redirect")
            .with_status(301)
            .with_header("Location", "/%e4%bd%a0%e5%a5%bd.txt")
            .create_async()
            .await;

        let mock_file = server
            .mock("GET", "/%e4%bd%a0%e5%a5%bd.txt")
            .with_status(200)
            .with_header("Content-Length", "1024")
            .with_header("Accept-Ranges", "bytes")
            .with_body(vec![0; 1024])
            .create_async()
            .await;

        let client = Client::new();
        let url = Url::parse(&format!("{}/redirect", server.url())).unwrap();
        let (url_info, _) = client.prefetch(url).await.expect("Request should succeed");

        assert_eq!(
            url_info.final_url.as_str(),
            format!("{}/%e4%bd%a0%e5%a5%bd.txt", server.url())
        );
        assert_eq!(url_info.size, 1024);
        assert_eq!(url_info.name, "你好.txt");
        assert!(url_info.supports_range);

        // Both the redirect and the target must have been requested as mocked.
        mock_redirect.assert_async().await;
        mock_file.assert_async().await;
    }

    // The reported file name should come from `Content-Disposition` when that header is
    // present. Otherwise it should come from the percent-decoded last URL path segment.
    #[tokio::test]
    async fn test_filename_sources() {
        let mut server = mockito::Server::new_async().await;

        // Case 1: name taken from the `Content-Disposition` header.
        let mock1 = server
            .mock("GET", "/test1")
            .with_header("Content-Disposition", r#"attachment; filename="test.txt""#)
            .create_async()
            .await;
        let url = Url::parse(&format!("{}/test1", server.url())).unwrap();
        let (url_info, _) = Client::new().prefetch(url).await.unwrap();
        assert_eq!(url_info.name, "test.txt");
        mock1.assert_async().await;

        // Case 2: no `Content-Disposition`, so the name is decoded from the URL path
        // ("%E5%A5%BD" is "好").
        let mock2 = server
            .mock("GET", "/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf")
            .create_async()
            .await;
        let url = Url::parse(&format!(
            "{}/test2/%E5%A5%BD%E5%A5%BD%E5%A5%BD.pdf",
            server.url()
        ))
        .unwrap();
        let (url_info, _) = Client::new().prefetch(url).await.unwrap();
        assert_eq!(url_info.name, "好好好.pdf");
        mock2.assert_async().await;
    }

    // A 404 response must surface as `HttpError::Request` carrying the 404 status.
    // `send` applies `error_for_status`, so error statuses become request errors.
    #[tokio::test]
    async fn test_error_handling() {
        let mut server = mockito::Server::new_async().await;
        let mock1 = server
            .mock("GET", "/404")
            .with_status(404)
            .create_async()
            .await;

        let client = Client::new();
        let url = Url::parse(&format!("{}/404", server.url())).unwrap();
        match client.prefetch(url).await {
            Ok(info) => unreachable!("404 status code should not success: {info:?}"),
            // Matched exhaustively so a new `HttpError` variant forces this test to be revisited.
            Err(err) => match err {
                HttpError::Request(e) => assert_eq!(e.status(), Some(StatusCode::NOT_FOUND)),
                HttpError::Chunk(_) => unreachable!(),
                HttpError::GetHeader(_) => unreachable!(),
                HttpError::Irrecoverable => unreachable!(),
            },
        }
        mock1.assert_async().await;
    }

    // Multi-connection download of 300 MiB with up to 32 concurrent workers.
    // The merged pull and push progress must each cover the whole file exactly, and the
    // bytes received by the pusher must equal the mock data.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        // Separate copy moved into the mock's response closure.
        let mock_body_clone = mock_data.clone();
        let _mock = server
            .mock("GET", "/concurrent")
            // 206 is sent for every request, including the no-`Range` full-body fallback.
            .with_status(206)
            .with_header("Accept-Ranges", "bytes")
            .with_body_from_request(move |request| {
                if !request.has_header("Range") {
                    return mock_body_clone.clone();
                }
                let range = request.header("Range")[0];
                println!("range: {range:?}");
                // Parse "bytes=a-b[, c-d...]": take the text after the last '=', split on
                // ',' into "start-end" pairs, and concatenate the matching body slices.
                // HTTP range ends are inclusive, hence `start..=end`.
                range
                    .to_str()
                    .unwrap()
                    .rsplit('=')
                    .next()
                    .unwrap()
                    .split(',')
                    .map(|p| p.trim().splitn(2, '-'))
                    .map(|mut p| {
                        let start = p.next().unwrap().parse::<usize>().unwrap();
                        let end = p.next().unwrap().parse::<usize>().unwrap();
                        start..=end
                    })
                    .flat_map(|p| mock_body_clone[p].to_vec())
                    .collect()
            })
            .create_async()
            .await;
        let puller = HttpPuller::new(
            format!("{}/concurrent", server.url()).parse().unwrap(),
            Client::new(),
        );
        let pusher = MemPusher::with_capacity(mock_data.len());
        // A single-element Vec of ranges is intended: one chunk spanning the whole file.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            multi::DownloadOptions {
                concurrent: NonZero::new(32).unwrap(),
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.clone(),
                min_chunk_size: NonZero::new(1).unwrap(),
            },
        )
        .await;

        // Drain events until `recv` returns an error (presumably the channel closes
        // once the download ends), merging the reported pull and push ranges.
        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(_, p) => {
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(_, p) => {
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        // Fail the test if the download finished with an error, then compare the bytes.
        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }

    // Single-connection download of 300 MiB from a plain 200 response.
    // The checks mirror `test_concurrent_download`: merged pull and push progress must
    // cover the whole file, and the pushed bytes must equal the mock data.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/sequential")
            .with_status(200)
            .with_body(mock_data.clone())
            .create_async()
            .await;
        let puller = HttpPuller::new(
            format!("{}/sequential", server.url()).parse().unwrap(),
            Client::new(),
        );
        let pusher = MemPusher::with_capacity(mock_data.len());
        // Expected merged progress: one range spanning the whole file.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_single(
            puller,
            pusher.clone(),
            single::DownloadOptions {
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
            },
        )
        .await;

        // Drain events until `recv` returns an error, merging the reported pull and push
        // ranges.
        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(_, p) => {
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(_, p) => {
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        // Fail the test if the download finished with an error, then compare the bytes.
        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}
297}