1extern crate alloc;
2use crate::{RandReader, SeqReader};
3use alloc::{boxed::Box, format};
4use bytes::Bytes;
5use core::{
6 pin::{Pin, pin},
7 task::{Context, Poll},
8};
9use futures::{Stream, TryFutureExt, TryStream};
10use reqwest::{Client, Response, header};
11use url::Url;
12
13#[derive(Clone)]
14pub struct ReqwestReader {
15 pub(crate) client: Client,
16 url: Url,
17}
18
19impl ReqwestReader {
20 pub fn new(url: Url, client: Client) -> Self {
21 Self { client, url }
22 }
23}
24
25impl RandReader for ReqwestReader {
26 type Error = reqwest::Error;
27 fn read(
28 &mut self,
29 range: &crate::ProgressEntry,
30 ) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
31 ReqwestStream {
32 client: self.client.clone(),
33 url: self.url.clone(),
34 start: range.start,
35 end: range.end,
36 resp: ResponseState::None,
37 max_retries: 3,
38 curr_retries: 0,
39 }
40 }
41}
42type ResponseFut = Pin<Box<dyn Future<Output = Result<Response, reqwest::Error>> + Send>>;
43enum ResponseState {
44 Pending(ResponseFut),
45 Ready(Response),
46 None,
47}
48struct ReqwestStream {
49 client: Client,
50 url: Url,
51 start: u64,
52 end: u64,
53 resp: ResponseState,
54 max_retries: usize,
55 curr_retries: usize,
56}
57impl Stream for ReqwestStream {
58 type Item = Result<Bytes, reqwest::Error>;
59 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
60 let chunk_global;
61 match &mut self.resp {
62 ResponseState::Pending(resp) => {
63 return match resp.try_poll_unpin(cx) {
64 Poll::Ready(resp) => match resp {
65 Ok(resp) => {
66 self.resp = ResponseState::Ready(resp);
67 self.poll_next(cx)
68 }
69 Err(e) => {
70 self.resp = ResponseState::None;
71 Poll::Ready(Some(Err(e)))
72 }
73 },
74 Poll::Pending => Poll::Pending,
75 };
76 }
77 ResponseState::None => {
78 let resp = self
79 .client
80 .get(self.url.clone())
81 .header(
82 header::RANGE,
83 format!("bytes={}-{}", self.start, self.end - 1),
84 )
85 .send();
86 self.resp = ResponseState::Pending(Box::pin(resp));
87 return self.poll_next(cx);
88 }
89 ResponseState::Ready(resp) => {
90 let mut chunk = pin!(resp.chunk());
91 match chunk.try_poll_unpin(cx) {
92 Poll::Ready(Ok(Some(chunk))) => chunk_global = Ok(chunk),
93 Poll::Ready(Ok(None)) => return Poll::Ready(None),
94 Poll::Ready(Err(e)) => chunk_global = Err(e),
95 Poll::Pending => return Poll::Pending,
96 };
97 }
98 };
99 match chunk_global {
100 Ok(chunk) => {
101 self.start += chunk.len() as u64;
102 Poll::Ready(Some(Ok(chunk)))
103 }
104 Err(e) => {
105 self.curr_retries += 1;
106 if self.curr_retries >= self.max_retries {
107 self.curr_retries = 0;
108 self.resp = ResponseState::None;
109 }
110 Poll::Ready(Some(Err(e)))
111 }
112 }
113 }
114}
115
116impl SeqReader for ReqwestReader {
117 type Error = reqwest::Error;
118 fn read(&mut self) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
119 let req = self.client.get(self.url.clone());
120 Box::pin(async move {
121 let resp = req.send().await?;
122 Ok(resp.bytes_stream())
123 })
124 .try_flatten_stream()
125 }
126}
127
#[cfg(test)]
mod tests {
    // End-to-end tests: drive `ReqwestReader` through the multi-connection and
    // single-connection download pipelines against a local `mockito` server, then check
    // the reported progress events and the mock writers.
    extern crate std;
    use super::*;
    use crate::{
        Event, MergeProgress, ProgressEntry,
        mock::{MockRandWriter, MockSeqWriter, build_mock_data},
        multi::{self, download_multi},
        reqwest::ReqwestReader,
        single::{self, download_single},
    };
    use alloc::vec;
    use core::{num::NonZeroUsize, time::Duration};
    use reqwest::Client;
    use std::{dbg, println};
    use vec::Vec;

    /// Downloads the payload with up to 32 concurrent range readers and checks that the
    /// merged read and write progress both equal the single range `0..len`.
    #[tokio::test]
    async fn test_concurrent_download() {
        // NOTE(review): argument is presumably the payload size in bytes (300 MiB).
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let mock_body_clone = mock_data.clone();
        // Bound to `_mock` rather than `_` so the value is not dropped immediately.
        let _mock = server
            .mock("GET", "/concurrent")
            .with_status(206)
            // Serves `Range: bytes=a-b[, c-d...]` by concatenating the requested inclusive
            // slices; without a `Range` header the full body is returned (still as 206).
            .with_body_from_request(move |request| {
                if !request.has_header("Range") {
                    return mock_body_clone.clone();
                }
                let range = request.header("Range")[0];
                println!("range: {range:?}");
                range
                    .to_str()
                    .unwrap()
                    .rsplit('=')
                    .next()
                    .unwrap()
                    .split(',')
                    .map(|p| p.trim().splitn(2, '-'))
                    .map(|mut p| {
                        let start = p.next().unwrap().parse::<usize>().unwrap();
                        let end = p.next().unwrap().parse::<usize>().unwrap();
                        // HTTP byte ranges are inclusive of `end`.
                        start..=end
                    })
                    .flat_map(|p| mock_body_clone[p].to_vec())
                    .collect()
            })
            .create_async()
            .await;
        let reader = ReqwestReader::new(
            format!("{}/concurrent", server.url()).parse().unwrap(),
            Client::new(),
        );
        let writer = MockRandWriter::new(&mock_data);
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            reader,
            writer.clone(),
            multi::DownloadOptions {
                concurrent: NonZeroUsize::new(32).unwrap(),
                retry_gap: Duration::from_secs(1),
                write_queue_cap: 1024,
                download_chunks: download_chunks.clone(),
            },
        )
        .await;

        // Drain events until `recv` fails (presumably once the download finished and the
        // channel closed), merging the reported read and write ranges separately.
        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::ReadProgress(_, p) => {
                    download_progress.merge_progress(p);
                }
                Event::WriteProgress(_, p) => {
                    write_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&download_progress);
        dbg!(&write_progress);
        assert_eq!(download_progress, download_chunks);
        assert_eq!(write_progress, download_chunks);

        result.join().await.unwrap();
        // NOTE(review): presumably verifies the written bytes against `mock_data`.
        writer.assert().await;
    }

    /// Downloads the payload over a single plain `GET` and checks that the merged read and
    /// write progress both equal the single range `0..len`.
    #[tokio::test]
    async fn test_sequential_download() {
        // NOTE(review): argument is presumably the payload size in bytes (300 MiB).
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        // Bound to `_mock` rather than `_` so the value is not dropped immediately.
        let _mock = server
            .mock("GET", "/sequential")
            .with_status(200)
            .with_body(mock_data.clone())
            .create_async()
            .await;
        let reader = ReqwestReader::new(
            format!("{}/sequential", server.url()).parse().unwrap(),
            Client::new(),
        );
        let writer = MockSeqWriter::new(&mock_data);
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_single(
            reader,
            writer.clone(),
            single::DownloadOptions {
                retry_gap: Duration::from_secs(1),
                write_queue_cap: 1024,
            },
        )
        .await;

        // Drain events until `recv` fails (presumably once the download finished and the
        // channel closed), merging the reported read and write ranges separately.
        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::ReadProgress(_, p) => {
                    download_progress.merge_progress(p);
                }
                Event::WriteProgress(_, p) => {
                    write_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&download_progress);
        dbg!(&write_progress);
        assert_eq!(download_progress, download_chunks);
        assert_eq!(write_progress, download_chunks);

        result.join().await.unwrap();
        // NOTE(review): presumably verifies the written bytes against `mock_data`.
        writer.assert().await;
    }
}