1use crate::{RandReader, SeqReader};
2use bytes::Bytes;
3use core::{
4 pin::{Pin, pin},
5 task::{Context, Poll},
6};
7use futures::{Stream, TryFutureExt, TryStream};
8use reqwest::{Client, Response, header};
9use url::Url;
10
11#[derive(Clone)]
12pub struct ReqwestReader {
13 pub(crate) client: Client,
14 url: Url,
15}
16
17impl ReqwestReader {
18 pub fn new(url: Url, client: Client) -> Self {
19 Self { client, url }
20 }
21}
22
23impl RandReader for ReqwestReader {
24 type Error = reqwest::Error;
25 fn read(
26 &mut self,
27 range: &crate::ProgressEntry,
28 ) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
29 ReqwestStream {
30 client: self.client.clone(),
31 url: self.url.clone(),
32 start: range.start,
33 end: range.end,
34 resp: ResponseState::None,
35 max_retries: 3,
36 curr_retries: 0,
37 }
38 }
39}
40type ResponseFut = Pin<Box<dyn Future<Output = Result<Response, reqwest::Error>> + Send>>;
41enum ResponseState {
42 Pending(ResponseFut),
43 Ready(Response),
44 None,
45}
46struct ReqwestStream {
47 client: Client,
48 url: Url,
49 start: u64,
50 end: u64,
51 resp: ResponseState,
52 max_retries: usize,
53 curr_retries: usize,
54}
55impl Stream for ReqwestStream {
56 type Item = Result<Bytes, reqwest::Error>;
57 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
58 let chunk_global;
59 match &mut self.resp {
60 ResponseState::Pending(resp) => match resp.try_poll_unpin(cx) {
61 Poll::Ready(resp) => {
62 return match resp {
63 Ok(resp) => {
64 self.resp = ResponseState::Ready(resp);
65 self.poll_next(cx)
66 }
67 Err(e) => {
68 self.resp = ResponseState::None;
69 Poll::Ready(Some(Err(e)))
70 }
71 };
72 }
73 Poll::Pending => return Poll::Pending,
74 },
75 ResponseState::None => {
76 let resp = self
77 .client
78 .get(self.url.clone())
79 .header(
80 header::RANGE,
81 format!("bytes={}-{}", self.start, self.end - 1),
82 )
83 .send();
84 self.resp = ResponseState::Pending(Box::pin(resp));
85 return self.poll_next(cx);
86 }
87 ResponseState::Ready(resp) => {
88 let mut chunk = pin!(resp.chunk());
89 match chunk.try_poll_unpin(cx) {
90 Poll::Ready(Ok(Some(chunk))) => chunk_global = Ok(chunk),
91 Poll::Ready(Ok(None)) => return Poll::Ready(None),
92 Poll::Ready(Err(e)) => chunk_global = Err(e),
93 Poll::Pending => return Poll::Pending,
94 };
95 }
96 };
97 match chunk_global {
98 Ok(chunk) => {
99 self.start += chunk.len() as u64;
100 Poll::Ready(Some(Ok(chunk)))
101 }
102 Err(e) => {
103 self.curr_retries += 1;
104 if self.curr_retries >= self.max_retries {
105 self.curr_retries = 0;
106 self.resp = ResponseState::None;
107 }
108 Poll::Ready(Some(Err(e)))
109 }
110 }
111 }
112}
113
114impl SeqReader for ReqwestReader {
115 type Error = reqwest::Error;
116 fn read(&mut self) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
117 let req = self.client.get(self.url.clone());
118 Box::pin(async move {
119 let resp = req.send().await?;
120 Ok(resp.bytes_stream())
121 })
122 .try_flatten_stream()
123 }
124}
125
#[cfg(test)]
mod tests {
    use crate::{
        Event, MergeProgress, ProgressEntry,
        mock::{MockRandWriter, MockSeqWriter, build_mock_data},
        multi::{self, download_multi},
        reqwest::ReqwestReader,
        single::{self, download_single},
    };
    use core::{num::NonZeroUsize, time::Duration};
    use reqwest::Client;

    /// Downloads the mock payload through `download_multi` against a mock
    /// server that answers `Range` requests, then checks that the merged read
    /// and write progress both cover the whole payload.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let served = mock_data.clone();
        let _mock = server
            .mock("GET", "/concurrent")
            .with_status(206)
            .with_body_from_request(move |request| {
                // Without a Range header the whole payload is returned.
                if !request.has_header("Range") {
                    return served.clone();
                }
                let range = request.header("Range")[0];
                println!("range: {range:?}");
                // Everything after the last `=` is a comma-separated list of
                // inclusive `first-last` byte ranges.
                let spec = range.to_str().unwrap().rsplit('=').next().unwrap();
                let mut body = Vec::new();
                for part in spec.split(',') {
                    let mut bounds = part.trim().splitn(2, '-');
                    let first = bounds.next().unwrap().parse::<usize>().unwrap();
                    let last = bounds.next().unwrap().parse::<usize>().unwrap();
                    body.extend_from_slice(&served[first..=last]);
                }
                body
            })
            .create_async()
            .await;

        let target = format!("{}/concurrent", server.url()).parse().unwrap();
        let reader = ReqwestReader::new(target, Client::new());
        let writer = MockRandWriter::new(&mock_data);
        #[allow(clippy::single_range_in_vec_init)]
        let expected = vec![0..mock_data.len() as u64];
        let options = multi::DownloadOptions {
            concurrent: NonZeroUsize::new(32).unwrap(),
            retry_gap: Duration::from_secs(1),
            write_queue_cap: 1024,
            download_chunks: expected.clone(),
        };
        let result = download_multi(reader, writer.clone(), options).await;

        // Drain the event channel, merging every progress report.
        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::ReadProgress(_, entry) => download_progress.merge_progress(entry),
                Event::WriteProgress(_, entry) => write_progress.merge_progress(entry),
                _ => {}
            }
        }
        dbg!(&download_progress);
        dbg!(&write_progress);
        assert_eq!(download_progress, expected);
        assert_eq!(write_progress, expected);

        result.join().await.unwrap();
        writer.assert().await;
    }

    /// Downloads the mock payload through `download_single` from a mock server
    /// that returns the full body, then checks that the merged read and write
    /// progress both cover the whole payload.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/sequential")
            .with_status(200)
            .with_body(mock_data.clone())
            .create_async()
            .await;

        let target = format!("{}/sequential", server.url()).parse().unwrap();
        let reader = ReqwestReader::new(target, Client::new());
        let writer = MockSeqWriter::new(&mock_data);
        #[allow(clippy::single_range_in_vec_init)]
        let expected = vec![0..mock_data.len() as u64];
        let options = single::DownloadOptions {
            retry_gap: Duration::from_secs(1),
            write_queue_cap: 1024,
        };
        let result = download_single(reader, writer.clone(), options).await;

        // Drain the event channel, merging every progress report.
        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::ReadProgress(_, entry) => download_progress.merge_progress(entry),
                Event::WriteProgress(_, entry) => write_progress.merge_progress(entry),
                _ => {}
            }
        }
        dbg!(&download_progress);
        dbg!(&write_progress);
        assert_eq!(download_progress, expected);
        assert_eq!(write_progress, expected);

        result.join().await.unwrap();
        writer.assert().await;
    }
}