1extern crate alloc;
2use crate::{RandReader, SeqReader};
3use alloc::{boxed::Box, format};
4use bytes::Bytes;
5use core::{
6 pin::{Pin, pin},
7 task::{Context, Poll},
8};
9use futures::{Stream, TryFutureExt, TryStream};
10use reqwest::{Client, Response, header};
11use url::Url;
12
13#[derive(Clone)]
14pub struct ReqwestReader {
15 pub(crate) client: Client,
16 url: Url,
17}
18
19impl ReqwestReader {
20 pub fn new(url: Url, client: Client) -> Self {
21 Self { client, url }
22 }
23}
24
25impl RandReader for ReqwestReader {
26 type Error = reqwest::Error;
27 fn read(
28 &mut self,
29 range: &crate::ProgressEntry,
30 ) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
31 ReqwestStream {
32 client: self.client.clone(),
33 url: self.url.clone(),
34 start: range.start,
35 end: range.end,
36 resp: ResponseState::None,
37 }
38 }
39}
40type ResponseFut = Pin<Box<dyn Future<Output = Result<Response, reqwest::Error>> + Send>>;
41enum ResponseState {
42 Pending(ResponseFut),
43 Ready(Response),
44 None,
45}
46struct ReqwestStream {
47 client: Client,
48 url: Url,
49 start: u64,
50 end: u64,
51 resp: ResponseState,
52}
53impl Stream for ReqwestStream {
54 type Item = Result<Bytes, reqwest::Error>;
55 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
56 let chunk_global;
57 match &mut self.resp {
58 ResponseState::Pending(resp) => {
59 return match resp.try_poll_unpin(cx) {
60 Poll::Ready(resp) => match resp {
61 Ok(resp) => {
62 self.resp = ResponseState::Ready(resp);
63 self.poll_next(cx)
64 }
65 Err(e) => {
66 self.resp = ResponseState::None;
67 Poll::Ready(Some(Err(e)))
68 }
69 },
70 Poll::Pending => Poll::Pending,
71 };
72 }
73 ResponseState::None => {
74 let resp = self
75 .client
76 .get(self.url.clone())
77 .header(
78 header::RANGE,
79 format!("bytes={}-{}", self.start, self.end - 1),
80 )
81 .send();
82 self.resp = ResponseState::Pending(Box::pin(resp));
83 return self.poll_next(cx);
84 }
85 ResponseState::Ready(resp) => {
86 if let Err(e) = resp.error_for_status_ref() {
87 self.resp = ResponseState::None;
88 return Poll::Ready(Some(Err(e)));
89 }
90 let mut chunk = pin!(resp.chunk());
91 match chunk.try_poll_unpin(cx) {
92 Poll::Ready(Ok(Some(chunk))) => chunk_global = Ok(chunk),
93 Poll::Ready(Ok(None)) => return Poll::Ready(None),
94 Poll::Ready(Err(e)) => chunk_global = Err(e),
95 Poll::Pending => return Poll::Pending,
96 };
97 }
98 };
99 match chunk_global {
100 Ok(chunk) => {
101 self.start += chunk.len() as u64;
102 Poll::Ready(Some(Ok(chunk)))
103 }
104 Err(e) => {
105 self.resp = ResponseState::None;
106 Poll::Ready(Some(Err(e)))
107 }
108 }
109 }
110}
111
112impl SeqReader for ReqwestReader {
113 type Error = reqwest::Error;
114 fn read(&mut self) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
115 let req = self.client.get(self.url.clone());
116 Box::pin(async move {
117 let resp = req.send().await?;
118 Ok(resp.bytes_stream())
119 })
120 .try_flatten_stream()
121 }
122}
123
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        Event, MergeProgress, ProgressEntry,
        mock::{MockRandWriter, MockSeqWriter, build_mock_data},
        multi::{self, download_multi},
        reqwest::ReqwestReader,
        single::{self, download_single},
    };
    use alloc::vec;
    use core::{num::NonZeroUsize, time::Duration};
    use reqwest::Client;
    use std::{dbg, println};
    use vec::Vec;

    /// Downloads the mock body with ranged requests against a mock server
    /// and checks that the merged read and write progress both cover the
    /// whole body.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let mock_body_clone = mock_data.clone();
        let _mock = server
            .mock("GET", "/concurrent")
            .with_status(206)
            .with_body_from_request(move |request| {
                // Without a Range header the server answers with everything.
                if !request.has_header("Range") {
                    return mock_body_clone.clone();
                }
                let range = request.header("Range")[0];
                println!("range: {range:?}");
                let header_text = range.to_str().unwrap();
                // Keep what follows the last '=': the "a-b[, c-d]" list.
                let spec = header_text.rsplit('=').next().unwrap();
                spec.split(',')
                    .flat_map(|part| {
                        let mut bounds = part.trim().splitn(2, '-');
                        let first = bounds.next().unwrap().parse::<usize>().unwrap();
                        let last = bounds.next().unwrap().parse::<usize>().unwrap();
                        // Both bounds of an HTTP byte range are inclusive.
                        mock_body_clone[first..=last].to_vec()
                    })
                    .collect()
            })
            .create_async()
            .await;
        let url: Url = format!("{}/concurrent", server.url()).parse().unwrap();
        let reader = ReqwestReader::new(url, Client::new());
        let writer = MockRandWriter::new(&mock_data);
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let options = multi::DownloadOptions {
            concurrent: NonZeroUsize::new(32).unwrap(),
            retry_gap: Duration::from_secs(1),
            write_queue_cap: 1024,
            download_chunks: download_chunks.clone(),
        };
        let result = download_multi(reader, writer.clone(), options).await;

        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        // Drain events until the channel reports an error (i.e. is closed).
        loop {
            let Ok(event) = result.event_chain.recv().await else {
                break;
            };
            match event {
                Event::ReadProgress(_, p) => {
                    download_progress.merge_progress(p);
                }
                Event::WriteProgress(_, p) => {
                    write_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&download_progress);
        dbg!(&write_progress);
        assert_eq!(download_progress, download_chunks);
        assert_eq!(write_progress, download_chunks);

        result.join().await.unwrap();
        writer.assert().await;
    }

    /// Downloads the mock body with one plain request against a mock server
    /// and checks that the merged read and write progress both cover the
    /// whole body.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/sequential")
            .with_status(200)
            .with_body(mock_data.clone())
            .create_async()
            .await;
        let url: Url = format!("{}/sequential", server.url()).parse().unwrap();
        let reader = ReqwestReader::new(url, Client::new());
        let writer = MockSeqWriter::new(&mock_data);
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let options = single::DownloadOptions {
            retry_gap: Duration::from_secs(1),
            write_queue_cap: 1024,
        };
        let result = download_single(reader, writer.clone(), options).await;

        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        // Drain events until the channel reports an error (i.e. is closed).
        loop {
            let Ok(event) = result.event_chain.recv().await else {
                break;
            };
            match event {
                Event::ReadProgress(_, p) => {
                    download_progress.merge_progress(p);
                }
                Event::WriteProgress(_, p) => {
                    write_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&download_progress);
        dbg!(&write_progress);
        assert_eq!(download_progress, download_chunks);
        assert_eq!(write_progress, download_chunks);

        result.join().await.unwrap();
        writer.assert().await;
    }
}