fast_pull/reqwest/
reader.rs

1extern crate alloc;
2use crate::{RandReader, SeqReader};
3use alloc::{boxed::Box, format};
4use bytes::Bytes;
5use core::{
6    pin::{Pin, pin},
7    task::{Context, Poll},
8};
9use futures::{Stream, TryFutureExt, TryStream};
10use reqwest::{Client, Response, header};
11use url::Url;
12
13#[derive(Clone)]
14pub struct ReqwestReader {
15    pub(crate) client: Client,
16    url: Url,
17}
18
19impl ReqwestReader {
20    pub fn new(url: Url, client: Client) -> Self {
21        Self { client, url }
22    }
23}
24
25impl RandReader for ReqwestReader {
26    type Error = reqwest::Error;
27    fn read(
28        &mut self,
29        range: &crate::ProgressEntry,
30    ) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
31        ReqwestStream {
32            client: self.client.clone(),
33            url: self.url.clone(),
34            start: range.start,
35            end: range.end,
36            resp: ResponseState::None,
37        }
38    }
39}
40type ResponseFut = Pin<Box<dyn Future<Output = Result<Response, reqwest::Error>> + Send>>;
41enum ResponseState {
42    Pending(ResponseFut),
43    Ready(Response),
44    None,
45}
46struct ReqwestStream {
47    client: Client,
48    url: Url,
49    start: u64,
50    end: u64,
51    resp: ResponseState,
52}
53impl Stream for ReqwestStream {
54    type Item = Result<Bytes, reqwest::Error>;
55    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
56        let chunk_global;
57        match &mut self.resp {
58            ResponseState::Pending(resp) => {
59                return match resp.try_poll_unpin(cx) {
60                    Poll::Ready(resp) => match resp {
61                        Ok(resp) => {
62                            self.resp = ResponseState::Ready(resp);
63                            self.poll_next(cx)
64                        }
65                        Err(e) => {
66                            self.resp = ResponseState::None;
67                            Poll::Ready(Some(Err(e)))
68                        }
69                    },
70                    Poll::Pending => Poll::Pending,
71                };
72            }
73            ResponseState::None => {
74                let resp = self
75                    .client
76                    .get(self.url.clone())
77                    .header(
78                        header::RANGE,
79                        format!("bytes={}-{}", self.start, self.end - 1),
80                    )
81                    .send();
82                self.resp = ResponseState::Pending(Box::pin(resp));
83                return self.poll_next(cx);
84            }
85            ResponseState::Ready(resp) => {
86                if let Err(e) = resp.error_for_status_ref() {
87                    self.resp = ResponseState::None;
88                    return Poll::Ready(Some(Err(e)));
89                }
90                let mut chunk = pin!(resp.chunk());
91                match chunk.try_poll_unpin(cx) {
92                    Poll::Ready(Ok(Some(chunk))) => chunk_global = Ok(chunk),
93                    Poll::Ready(Ok(None)) => return Poll::Ready(None),
94                    Poll::Ready(Err(e)) => chunk_global = Err(e),
95                    Poll::Pending => return Poll::Pending,
96                };
97            }
98        };
99        match chunk_global {
100            Ok(chunk) => {
101                self.start += chunk.len() as u64;
102                Poll::Ready(Some(Ok(chunk)))
103            }
104            Err(e) => {
105                self.resp = ResponseState::None;
106                Poll::Ready(Some(Err(e)))
107            }
108        }
109    }
110}
111
112impl SeqReader for ReqwestReader {
113    type Error = reqwest::Error;
114    fn read(&mut self) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
115        let req = self.client.get(self.url.clone());
116        Box::pin(async move {
117            let resp = req.send().await?;
118            Ok(resp.bytes_stream())
119        })
120        .try_flatten_stream()
121    }
122}
123
#[cfg(test)]
mod tests {
    //! End-to-end tests that drive `ReqwestReader` against a local `mockito` HTTP server
    //! and check both the reported progress and the written data.
    extern crate std;
    use super::*;
    use crate::{
        Event, MergeProgress, ProgressEntry,
        mock::{MockRandWriter, MockSeqWriter, build_mock_data},
        multi::{self, download_multi},
        reqwest::ReqwestReader,
        single::{self, download_single},
    };
    use alloc::vec;
    use core::{num::NonZeroUsize, time::Duration};
    use reqwest::Client;
    use std::{dbg, println};
    use vec::Vec;

    /// Ranged, concurrent download through `RandReader` (`download_multi`).
    #[tokio::test]
    async fn test_concurrent_download() {
        // Presumably `build_mock_data` takes a byte count (300 MiB here) — confirm.
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let mock_body_clone = mock_data.clone();
        // Mock endpoint: always answers 206. Without a `Range` header it returns the
        // whole body; otherwise it parses `bytes=a-b[,c-d...]` as inclusive ranges and
        // concatenates the matching slices. Real servers answer multi-range requests
        // with multipart/byteranges, but the reader only ever sends a single range.
        // Any malformed header makes the handler panic via `unwrap`.
        let _mock = server
            .mock("GET", "/concurrent")
            .with_status(206)
            .with_body_from_request(move |request| {
                if !request.has_header("Range") {
                    return mock_body_clone.clone();
                }
                let range = request.header("Range")[0];
                println!("range: {range:?}");
                range
                    .to_str()
                    .unwrap()
                    // Keep the text after the last `=`, i.e. drop the `bytes=` unit.
                    .rsplit('=')
                    .next()
                    .unwrap()
                    .split(',')
                    .map(|p| p.trim().splitn(2, '-'))
                    .map(|mut p| {
                        let start = p.next().unwrap().parse::<usize>().unwrap();
                        let end = p.next().unwrap().parse::<usize>().unwrap();
                        // HTTP byte ranges include their end offset.
                        start..=end
                    })
                    .flat_map(|p| mock_body_clone[p].to_vec())
                    .collect()
            })
            .create_async()
            .await;
        let reader = ReqwestReader::new(
            format!("{}/concurrent", server.url()).parse().unwrap(),
            Client::new(),
        );
        let writer = MockRandWriter::new(&mock_data);
        // A one-element list of ranges is intended: it is the set of chunks to
        // download (here the whole file), not a `vec![value; n]`-style initialiser.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            reader,
            writer.clone(),
            multi::DownloadOptions {
                concurrent: NonZeroUsize::new(32).unwrap(),
                retry_gap: Duration::from_secs(1),
                write_queue_cap: 1024,
                download_chunks: download_chunks.clone(),
            },
        )
        .await;

        // Drain every event, merging read and write progress separately. The loop
        // stops when `recv` returns `Err`, presumably once the download has finished
        // and the channel is closed — confirm against `event_chain`'s type.
        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::ReadProgress(_, p) => {
                    download_progress.merge_progress(p);
                }
                Event::WriteProgress(_, p) => {
                    write_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&download_progress);
        dbg!(&write_progress);
        // Both merged progress lists must cover exactly the requested chunks.
        assert_eq!(download_progress, download_chunks);
        assert_eq!(write_progress, download_chunks);

        result.join().await.unwrap();
        // NOTE(review): presumably checks the written bytes against `mock_data`.
        writer.assert().await;
    }

    /// Single-connection download through `SeqReader` (`download_single`).
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        // Plain 200 response carrying the full body; no `Range` handling needed.
        let _mock = server
            .mock("GET", "/sequential")
            .with_status(200)
            .with_body(mock_data.clone())
            .create_async()
            .await;
        let reader = ReqwestReader::new(
            format!("{}/sequential", server.url()).parse().unwrap(),
            Client::new(),
        );
        let writer = MockSeqWriter::new(&mock_data);
        // Expected progress: the whole file as a single range (see the lint note in
        // the concurrent test).
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_single(
            reader,
            writer.clone(),
            single::DownloadOptions {
                retry_gap: Duration::from_secs(1),
                write_queue_cap: 1024,
            },
        )
        .await;

        // Same event draining and progress checks as the concurrent test.
        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::ReadProgress(_, p) => {
                    download_progress.merge_progress(p);
                }
                Event::WriteProgress(_, p) => {
                    write_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&download_progress);
        dbg!(&write_progress);
        assert_eq!(download_progress, download_chunks);
        assert_eq!(write_progress, download_chunks);

        result.join().await.unwrap();
        writer.assert().await;
    }
}