fast_pull/reqwest/
reader.rs

1extern crate alloc;
2use crate::{RandReader, SeqReader};
3use alloc::{boxed::Box, format};
4use bytes::Bytes;
5use core::{
6    pin::{Pin, pin},
7    task::{Context, Poll},
8};
9use futures::{Stream, TryFutureExt, TryStream};
10use reqwest::{Client, Response, header};
11use url::Url;
12
13#[derive(Clone)]
14pub struct ReqwestReader {
15    pub(crate) client: Client,
16    url: Url,
17}
18
19impl ReqwestReader {
20    pub fn new(url: Url, client: Client) -> Self {
21        Self { client, url }
22    }
23}
24
25impl RandReader for ReqwestReader {
26    type Error = reqwest::Error;
27    fn read(
28        &mut self,
29        range: &crate::ProgressEntry,
30    ) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
31        ReqwestStream {
32            client: self.client.clone(),
33            url: self.url.clone(),
34            start: range.start,
35            end: range.end,
36            resp: ResponseState::None,
37            max_retries: 3,
38            curr_retries: 0,
39        }
40    }
41}
42type ResponseFut = Pin<Box<dyn Future<Output = Result<Response, reqwest::Error>> + Send>>;
43enum ResponseState {
44    Pending(ResponseFut),
45    Ready(Response),
46    None,
47}
48struct ReqwestStream {
49    client: Client,
50    url: Url,
51    start: u64,
52    end: u64,
53    resp: ResponseState,
54    max_retries: usize,
55    curr_retries: usize,
56}
57impl Stream for ReqwestStream {
58    type Item = Result<Bytes, reqwest::Error>;
59    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
60        let chunk_global;
61        match &mut self.resp {
62            ResponseState::Pending(resp) => {
63                return match resp.try_poll_unpin(cx) {
64                    Poll::Ready(resp) => match resp {
65                        Ok(resp) => {
66                            self.resp = ResponseState::Ready(resp);
67                            self.poll_next(cx)
68                        }
69                        Err(e) => {
70                            self.resp = ResponseState::None;
71                            Poll::Ready(Some(Err(e)))
72                        }
73                    },
74                    Poll::Pending => Poll::Pending,
75                };
76            }
77            ResponseState::None => {
78                let resp = self
79                    .client
80                    .get(self.url.clone())
81                    .header(
82                        header::RANGE,
83                        format!("bytes={}-{}", self.start, self.end - 1),
84                    )
85                    .send();
86                self.resp = ResponseState::Pending(Box::pin(resp));
87                return self.poll_next(cx);
88            }
89            ResponseState::Ready(resp) => {
90                let mut chunk = pin!(resp.chunk());
91                match chunk.try_poll_unpin(cx) {
92                    Poll::Ready(Ok(Some(chunk))) => chunk_global = Ok(chunk),
93                    Poll::Ready(Ok(None)) => return Poll::Ready(None),
94                    Poll::Ready(Err(e)) => chunk_global = Err(e),
95                    Poll::Pending => return Poll::Pending,
96                };
97            }
98        };
99        match chunk_global {
100            Ok(chunk) => {
101                self.start += chunk.len() as u64;
102                Poll::Ready(Some(Ok(chunk)))
103            }
104            Err(e) => {
105                self.curr_retries += 1;
106                if self.curr_retries >= self.max_retries {
107                    self.curr_retries = 0;
108                    self.resp = ResponseState::None;
109                }
110                Poll::Ready(Some(Err(e)))
111            }
112        }
113    }
114}
115
116impl SeqReader for ReqwestReader {
117    type Error = reqwest::Error;
118    fn read(&mut self) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
119        let req = self.client.get(self.url.clone());
120        Box::pin(async move {
121            let resp = req.send().await?;
122            Ok(resp.bytes_stream())
123        })
124        .try_flatten_stream()
125    }
126}
127
// Tests that download mock data from a local `mockito` HTTP server through
// `ReqwestReader`, exercising both the multi-connection and single-connection
// download paths.
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        Event, MergeProgress, ProgressEntry,
        mock::{MockRandWriter, MockSeqWriter, build_mock_data},
        multi::{self, download_multi},
        reqwest::ReqwestReader,
        single::{self, download_single},
    };
    use alloc::vec;
    use core::{num::NonZeroUsize, time::Duration};
    use reqwest::Client;
    use std::{dbg, println};
    use vec::Vec;

    /// Downloads 300 MiB of mock data with `download_multi` (32 concurrent
    /// workers) and checks that the merged read and write progress each cover
    /// exactly the requested range.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let mock_body_clone = mock_data.clone();
        // The mock always answers 206. Without a `Range` header it returns the
        // whole body; otherwise it concatenates the slices for each
        // comma-separated `start-end` spec (end inclusive, as in HTTP).
        let _mock = server
            .mock("GET", "/concurrent")
            .with_status(206)
            .with_body_from_request(move |request| {
                if !request.has_header("Range") {
                    return mock_body_clone.clone();
                }
                let range = request.header("Range")[0];
                println!("range: {range:?}");
                range
                    .to_str()
                    .unwrap()
                    .rsplit('=')
                    .next()
                    .unwrap()
                    .split(',')
                    .map(|p| p.trim().splitn(2, '-'))
                    .map(|mut p| {
                        let start = p.next().unwrap().parse::<usize>().unwrap();
                        let end = p.next().unwrap().parse::<usize>().unwrap();
                        start..=end
                    })
                    .flat_map(|p| mock_body_clone[p].to_vec())
                    .collect()
            })
            .create_async()
            .await;
        let reader = ReqwestReader::new(
            format!("{}/concurrent", server.url()).parse().unwrap(),
            Client::new(),
        );
        let writer = MockRandWriter::new(&mock_data);
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            reader,
            writer.clone(),
            multi::DownloadOptions {
                concurrent: NonZeroUsize::new(32).unwrap(),
                retry_gap: Duration::from_secs(1),
                write_queue_cap: 1024,
                download_chunks: download_chunks.clone(),
            },
        )
        .await;

        // Drain events until `recv` returns `Err` (presumably once the event
        // channel is closed), merging every reported read/write progress range.
        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::ReadProgress(_, p) => {
                    download_progress.merge_progress(p);
                }
                Event::WriteProgress(_, p) => {
                    write_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&download_progress);
        dbg!(&write_progress);
        assert_eq!(download_progress, download_chunks);
        assert_eq!(write_progress, download_chunks);

        result.join().await.unwrap();
        // NOTE(review): presumably checks the written bytes against the
        // `mock_data` the writer was built from — confirm in `mock`.
        writer.assert().await;
    }

    /// Downloads 300 MiB of mock data with `download_single` from a plain 200
    /// response and checks that read and write progress each cover the whole
    /// body.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/sequential")
            .with_status(200)
            .with_body(mock_data.clone())
            .create_async()
            .await;
        let reader = ReqwestReader::new(
            format!("{}/sequential", server.url()).parse().unwrap(),
            Client::new(),
        );
        let writer = MockSeqWriter::new(&mock_data);
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_single(
            reader,
            writer.clone(),
            single::DownloadOptions {
                retry_gap: Duration::from_secs(1),
                write_queue_cap: 1024,
            },
        )
        .await;

        // Same event-draining pattern as the concurrent test.
        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::ReadProgress(_, p) => {
                    download_progress.merge_progress(p);
                }
                Event::WriteProgress(_, p) => {
                    write_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&download_progress);
        dbg!(&write_progress);
        assert_eq!(download_progress, download_chunks);
        assert_eq!(write_progress, download_chunks);

        result.join().await.unwrap();
        writer.assert().await;
    }
}