1extern crate alloc;
2use crate::{RandReader, SeqReader};
3use alloc::{boxed::Box, format};
4use bytes::Bytes;
5use core::{
6 pin::{Pin, pin},
7 task::{Context, Poll},
8};
9use futures::{Stream, TryFutureExt, TryStream};
10use reqwest::{Client, Response, header};
11use url::Url;
12
13#[derive(Clone)]
14pub struct ReqwestReader {
15 pub(crate) client: Client,
16 url: Url,
17}
18
19impl ReqwestReader {
20 pub fn new(url: Url, client: Client) -> Self {
21 Self { client, url }
22 }
23}
24
25impl RandReader for ReqwestReader {
26 type Error = reqwest::Error;
27 fn read(
28 &mut self,
29 range: &crate::ProgressEntry,
30 ) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
31 ReqwestStream {
32 client: self.client.clone(),
33 url: self.url.clone(),
34 start: range.start,
35 end: range.end,
36 resp: ResponseState::None,
37 }
38 }
39}
40type ResponseFut = Pin<Box<dyn Future<Output = Result<Response, reqwest::Error>> + Send>>;
41enum ResponseState {
42 Pending(ResponseFut),
43 Ready(Response),
44 None,
45}
46struct ReqwestStream {
47 client: Client,
48 url: Url,
49 start: u64,
50 end: u64,
51 resp: ResponseState,
52}
53impl Stream for ReqwestStream {
54 type Item = Result<Bytes, reqwest::Error>;
55 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
56 let chunk_global;
57 match &mut self.resp {
58 ResponseState::Pending(resp) => {
59 return match resp.try_poll_unpin(cx) {
60 Poll::Ready(resp) => match resp {
61 Ok(resp) => {
62 self.resp = ResponseState::Ready(resp);
63 self.poll_next(cx)
64 }
65 Err(e) => {
66 self.resp = ResponseState::None;
67 Poll::Ready(Some(Err(e)))
68 }
69 },
70 Poll::Pending => Poll::Pending,
71 };
72 }
73 ResponseState::None => {
74 let resp = self
75 .client
76 .get(self.url.clone())
77 .header(
78 header::RANGE,
79 format!("bytes={}-{}", self.start, self.end - 1),
80 )
81 .send();
82 self.resp = ResponseState::Pending(Box::pin(resp));
83 return self.poll_next(cx);
84 }
85 ResponseState::Ready(resp) => {
86 let mut chunk = pin!(resp.chunk());
87 match chunk.try_poll_unpin(cx) {
88 Poll::Ready(Ok(Some(chunk))) => chunk_global = Ok(chunk),
89 Poll::Ready(Ok(None)) => return Poll::Ready(None),
90 Poll::Ready(Err(e)) => chunk_global = Err(e),
91 Poll::Pending => return Poll::Pending,
92 };
93 }
94 };
95 match chunk_global {
96 Ok(chunk) => {
97 self.start += chunk.len() as u64;
98 Poll::Ready(Some(Ok(chunk)))
99 }
100 Err(e) => {
101 self.resp = ResponseState::None;
102 Poll::Ready(Some(Err(e)))
103 }
104 }
105 }
106}
107
108impl SeqReader for ReqwestReader {
109 type Error = reqwest::Error;
110 fn read(&mut self) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
111 let req = self.client.get(self.url.clone());
112 Box::pin(async move {
113 let resp = req.send().await?;
114 Ok(resp.bytes_stream())
115 })
116 .try_flatten_stream()
117 }
118}
119
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        Event, MergeProgress, ProgressEntry,
        mock::{MockRandWriter, MockSeqWriter, build_mock_data},
        multi::{self, download_multi},
        reqwest::ReqwestReader,
        single::{self, download_single},
    };
    use alloc::vec;
    use core::{num::NonZeroUsize, time::Duration};
    use reqwest::Client;
    use std::{dbg, println};
    use vec::Vec;

    /// Runs `download_multi` with 32 workers against a mock server that
    /// answers `Range` requests from an in-memory payload, then checks the
    /// merged read/write progress and the writer's contents.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let served = mock_data.clone();
        let _mock = server
            .mock("GET", "/concurrent")
            .with_status(206)
            .with_body_from_request(move |request| {
                // Without a `Range` header the whole payload is returned.
                if !request.has_header("Range") {
                    return served.clone();
                }
                let range = request.header("Range")[0];
                println!("range: {range:?}");
                // Everything after the last '=' is the comma-separated list
                // of inclusive `first-last` pairs.
                let spec = range.to_str().unwrap().rsplit('=').next().unwrap();
                spec.split(',')
                    .map(|part| {
                        let mut bounds = part.trim().splitn(2, '-');
                        let first = bounds.next().unwrap().parse::<usize>().unwrap();
                        let last = bounds.next().unwrap().parse::<usize>().unwrap();
                        first..=last
                    })
                    .flat_map(|span| served[span].to_vec())
                    .collect()
            })
            .create_async()
            .await;

        let endpoint = format!("{}/concurrent", server.url()).parse().unwrap();
        let reader = ReqwestReader::new(endpoint, Client::new());
        let writer = MockRandWriter::new(&mock_data);
        #[allow(clippy::single_range_in_vec_init)]
        let expected = vec![0..mock_data.len() as u64];
        let options = multi::DownloadOptions {
            concurrent: NonZeroUsize::new(32).unwrap(),
            retry_gap: Duration::from_secs(1),
            write_queue_cap: 1024,
            download_chunks: expected.clone(),
        };
        let result = download_multi(reader, writer.clone(), options).await;

        // Drain the event channel, merging every progress report.
        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::ReadProgress(_, span) => {
                    download_progress.merge_progress(span);
                }
                Event::WriteProgress(_, span) => {
                    write_progress.merge_progress(span);
                }
                _ => {}
            }
        }
        dbg!(&download_progress);
        dbg!(&write_progress);
        assert_eq!(download_progress, expected);
        assert_eq!(write_progress, expected);

        result.join().await.unwrap();
        writer.assert().await;
    }

    /// Runs `download_single` against a mock server that returns the whole
    /// payload with status 200, then checks the merged read/write progress
    /// and the writer's contents.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/sequential")
            .with_status(200)
            .with_body(mock_data.clone())
            .create_async()
            .await;

        let endpoint = format!("{}/sequential", server.url()).parse().unwrap();
        let reader = ReqwestReader::new(endpoint, Client::new());
        let writer = MockSeqWriter::new(&mock_data);
        #[allow(clippy::single_range_in_vec_init)]
        let expected = vec![0..mock_data.len() as u64];
        let options = single::DownloadOptions {
            retry_gap: Duration::from_secs(1),
            write_queue_cap: 1024,
        };
        let result = download_single(reader, writer.clone(), options).await;

        // Drain the event channel, merging every progress report.
        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::ReadProgress(_, span) => {
                    download_progress.merge_progress(span);
                }
                Event::WriteProgress(_, span) => {
                    write_progress.merge_progress(span);
                }
                _ => {}
            }
        }
        dbg!(&download_progress);
        dbg!(&write_progress);
        assert_eq!(download_progress, expected);
        assert_eq!(write_progress, expected);

        result.join().await.unwrap();
        writer.assert().await;
    }
}