1extern crate alloc;
2use crate::{RandPuller, SeqPuller};
3use alloc::{boxed::Box, format};
4use bytes::Bytes;
5use core::{
6 pin::{Pin, pin},
7 task::{Context, Poll},
8};
9use futures::{Stream, TryFutureExt, TryStream};
10use reqwest::{Client, Response, header};
11use url::Url;
12
/// HTTP puller that downloads a single URL through a `reqwest` [`Client`].
///
/// The same value implements both [`RandPuller`] (ranged requests) and
/// [`SeqPuller`] (one plain `GET` streamed from the beginning).
#[derive(Clone)]
pub struct ReqwestPuller {
    /// Client used to send every request; cloned into each range stream.
    pub(crate) client: Client,
    /// Resource that every request issued by this puller targets.
    url: Url,
}
18
19impl ReqwestPuller {
20 pub fn new(url: Url, client: Client) -> Self {
21 Self { client, url }
22 }
23}
24
25impl RandPuller for ReqwestPuller {
26 type Error = reqwest::Error;
27 fn pull(
28 &mut self,
29 range: &crate::ProgressEntry,
30 ) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
31 ReqwestStream {
32 client: self.client.clone(),
33 url: self.url.clone(),
34 start: range.start,
35 end: range.end,
36 resp: ResponseState::None,
37 }
38 }
39}
/// Boxed, pinned future of an in-flight request; resolves to the [`Response`]
/// or to the `reqwest` error that prevented obtaining one.
type ResponseFut = Pin<Box<dyn Future<Output = Result<Response, reqwest::Error>> + Send>>;
/// Request lifecycle of a [`ReqwestStream`].
enum ResponseState {
    /// A request has been sent and its response has not arrived yet.
    Pending(ResponseFut),
    /// A response is available and its body is being read chunk by chunk.
    Ready(Response),
    /// No request is in flight: the initial state, and the state the stream
    /// returns to after yielding an error.
    None,
}
/// Stream of body chunks for one byte range of a URL.
///
/// Produced by the `RandPuller` implementation of [`ReqwestPuller`].
struct ReqwestStream {
    /// Client used to send the ranged request.
    client: Client,
    /// Resource being downloaded.
    url: Url,
    /// Offset of the next byte to request. It is advanced by the length of
    /// every chunk yielded, so a request re-issued after an error resumes at
    /// the first byte that has not been yielded yet.
    start: u64,
    /// Exclusive end offset; the `Range` header is sent with `end - 1` as its
    /// inclusive last byte.
    end: u64,
    /// Current request/response state.
    resp: ResponseState,
}
53impl Stream for ReqwestStream {
54 type Item = Result<Bytes, reqwest::Error>;
55 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
56 let chunk_global;
57 match &mut self.resp {
58 ResponseState::Pending(resp) => {
59 return match resp.try_poll_unpin(cx) {
60 Poll::Ready(resp) => match resp {
61 Ok(resp) => {
62 self.resp = ResponseState::Ready(resp);
63 self.poll_next(cx)
64 }
65 Err(e) => {
66 self.resp = ResponseState::None;
67 Poll::Ready(Some(Err(e)))
68 }
69 },
70 Poll::Pending => Poll::Pending,
71 };
72 }
73 ResponseState::None => {
74 let resp = self
75 .client
76 .get(self.url.clone())
77 .header(
78 header::RANGE,
79 format!("bytes={}-{}", self.start, self.end - 1),
80 )
81 .send();
82 self.resp = ResponseState::Pending(Box::pin(resp));
83 return self.poll_next(cx);
84 }
85 ResponseState::Ready(resp) => {
86 if let Err(e) = resp.error_for_status_ref() {
87 self.resp = ResponseState::None;
88 return Poll::Ready(Some(Err(e)));
89 }
90 let mut chunk = pin!(resp.chunk());
91 match chunk.try_poll_unpin(cx) {
92 Poll::Ready(Ok(Some(chunk))) => chunk_global = Ok(chunk),
93 Poll::Ready(Ok(None)) => return Poll::Ready(None),
94 Poll::Ready(Err(e)) => chunk_global = Err(e),
95 Poll::Pending => return Poll::Pending,
96 };
97 }
98 };
99 match chunk_global {
100 Ok(chunk) => {
101 self.start += chunk.len() as u64;
102 Poll::Ready(Some(Ok(chunk)))
103 }
104 Err(e) => {
105 self.resp = ResponseState::None;
106 Poll::Ready(Some(Err(e)))
107 }
108 }
109 }
110}
111
112impl SeqPuller for ReqwestPuller {
113 type Error = reqwest::Error;
114 fn pull(&mut self) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
115 let req = self.client.get(self.url.clone());
116 Box::pin(async move {
117 let resp = req.send().await?;
118 Ok(resp.bytes_stream())
119 })
120 .try_flatten_stream()
121 }
122}
123
// Integration tests that run the pullers against a local `mockito` HTTP
// server and check both the reported progress and the pushed data.
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        Event, MergeProgress, ProgressEntry,
        mock::{MockRandPusher, MockSeqPusher, build_mock_data},
        multi::{self, download_multi},
        reqwest::ReqwestPuller,
        single::{self, download_single},
    };
    use alloc::vec;
    use core::{num::NonZeroUsize, time::Duration};
    use reqwest::Client;
    use std::{dbg, println};
    use vec::Vec;

    // Downloads the mock payload through `download_multi` with ranged
    // requests and verifies that pull and push progress each merge into the
    // single full range.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let mock_body_clone = mock_data.clone();
        let _mock = server
            .mock("GET", "/concurrent")
            .with_status(206)
            .with_body_from_request(move |request| {
                // Without a Range header the whole payload is served.
                if !request.has_header("Range") {
                    return mock_body_clone.clone();
                }
                let range = request.header("Range")[0];
                println!("range: {range:?}");
                // Take the text after the last '=', split it on ',' and parse
                // each "start-end" pair as an inclusive byte range; the body
                // is the concatenation of those slices of the payload.
                range
                    .to_str()
                    .unwrap()
                    .rsplit('=')
                    .next()
                    .unwrap()
                    .split(',')
                    .map(|p| p.trim().splitn(2, '-'))
                    .map(|mut p| {
                        let start = p.next().unwrap().parse::<usize>().unwrap();
                        let end = p.next().unwrap().parse::<usize>().unwrap();
                        start..=end
                    })
                    .flat_map(|p| mock_body_clone[p].to_vec())
                    .collect()
            })
            .create_async()
            .await;
        let puller = ReqwestPuller::new(
            format!("{}/concurrent", server.url()).parse().unwrap(),
            Client::new(),
        );
        let pusher = MockRandPusher::new(&mock_data);
        // A one-element vec holding a range is intended here: it is the list
        // of chunks to download, not a mistyped `vec![0; n]`.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            multi::DownloadOptions {
                concurrent: NonZeroUsize::new(32).unwrap(),
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.clone(),
                min_chunk_size: 1,
            },
        )
        .await;

        // Consume events until `recv` returns an error, merging every
        // reported progress entry.
        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(_, p) => {
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(_, p) => {
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        // The merged progress must equal the requested chunk list.
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        pusher.assert().await;
    }

    // Downloads the mock payload through `download_single` with one plain
    // GET and verifies that pull and push progress each merge into the single
    // full range.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/sequential")
            .with_status(200)
            .with_body(mock_data.clone())
            .create_async()
            .await;
        let puller = ReqwestPuller::new(
            format!("{}/sequential", server.url()).parse().unwrap(),
            Client::new(),
        );
        let pusher = MockSeqPusher::new(&mock_data);
        // A one-element vec holding a range is intended here: it is the
        // expected progress list, not a mistyped `vec![0; n]`.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_single(
            puller,
            pusher.clone(),
            single::DownloadOptions {
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
            },
        )
        .await;

        // Consume events until `recv` returns an error, merging every
        // reported progress entry.
        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(_, p) => {
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(_, p) => {
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        // The merged progress must equal the full range of the payload.
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        pusher.assert().await;
    }
}