fast_pull/reqwest/
reader.rs

1extern crate alloc;
2use crate::{RandReader, SeqReader};
3use alloc::{boxed::Box, format};
4use bytes::Bytes;
5use core::{
6    pin::{Pin, pin},
7    task::{Context, Poll},
8};
9use futures::{Stream, TryFutureExt, TryStream};
10use reqwest::{Client, Response, header};
11use url::Url;
12
13#[derive(Clone)]
14pub struct ReqwestReader {
15    pub(crate) client: Client,
16    url: Url,
17}
18
19impl ReqwestReader {
20    pub fn new(url: Url, client: Client) -> Self {
21        Self { client, url }
22    }
23}
24
25impl RandReader for ReqwestReader {
26    type Error = reqwest::Error;
27    fn read(
28        &mut self,
29        range: &crate::ProgressEntry,
30    ) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
31        ReqwestStream {
32            client: self.client.clone(),
33            url: self.url.clone(),
34            start: range.start,
35            end: range.end,
36            resp: ResponseState::None,
37        }
38    }
39}
40type ResponseFut = Pin<Box<dyn Future<Output = Result<Response, reqwest::Error>> + Send>>;
41enum ResponseState {
42    Pending(ResponseFut),
43    Ready(Response),
44    None,
45}
46struct ReqwestStream {
47    client: Client,
48    url: Url,
49    start: u64,
50    end: u64,
51    resp: ResponseState,
52}
53impl Stream for ReqwestStream {
54    type Item = Result<Bytes, reqwest::Error>;
55    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
56        let chunk_global;
57        match &mut self.resp {
58            ResponseState::Pending(resp) => {
59                return match resp.try_poll_unpin(cx) {
60                    Poll::Ready(resp) => match resp {
61                        Ok(resp) => {
62                            self.resp = ResponseState::Ready(resp);
63                            self.poll_next(cx)
64                        }
65                        Err(e) => {
66                            self.resp = ResponseState::None;
67                            Poll::Ready(Some(Err(e)))
68                        }
69                    },
70                    Poll::Pending => Poll::Pending,
71                };
72            }
73            ResponseState::None => {
74                let resp = self
75                    .client
76                    .get(self.url.clone())
77                    .header(
78                        header::RANGE,
79                        format!("bytes={}-{}", self.start, self.end - 1),
80                    )
81                    .send();
82                self.resp = ResponseState::Pending(Box::pin(resp));
83                return self.poll_next(cx);
84            }
85            ResponseState::Ready(resp) => {
86                let mut chunk = pin!(resp.chunk());
87                match chunk.try_poll_unpin(cx) {
88                    Poll::Ready(Ok(Some(chunk))) => chunk_global = Ok(chunk),
89                    Poll::Ready(Ok(None)) => return Poll::Ready(None),
90                    Poll::Ready(Err(e)) => chunk_global = Err(e),
91                    Poll::Pending => return Poll::Pending,
92                };
93            }
94        };
95        match chunk_global {
96            Ok(chunk) => {
97                self.start += chunk.len() as u64;
98                Poll::Ready(Some(Ok(chunk)))
99            }
100            Err(e) => {
101                self.resp = ResponseState::None;
102                Poll::Ready(Some(Err(e)))
103            }
104        }
105    }
106}
107
108impl SeqReader for ReqwestReader {
109    type Error = reqwest::Error;
110    fn read(&mut self) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
111        let req = self.client.get(self.url.clone());
112        Box::pin(async move {
113            let resp = req.send().await?;
114            Ok(resp.bytes_stream())
115        })
116        .try_flatten_stream()
117    }
118}
119
#[cfg(test)]
mod tests {
    // Integration tests: drive `ReqwestReader` against a local `mockito` HTTP server
    // through both download drivers, then check the progress events and the
    // data received by the mock writers.
    extern crate std;
    use super::*;
    use crate::{
        Event, MergeProgress, ProgressEntry,
        mock::{MockRandWriter, MockSeqWriter, build_mock_data},
        multi::{self, download_multi},
        reqwest::ReqwestReader,
        single::{self, download_single},
    };
    use alloc::vec;
    use core::{num::NonZeroUsize, time::Duration};
    use reqwest::Client;
    use std::{dbg, println};
    use vec::Vec;

    // Ranged download of a 300 MiB body via `download_multi` with `concurrent = 32`.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let mock_body_clone = mock_data.clone();
        // Endpoint that serves `Range: bytes=a-b[, c-d ...]` by concatenating the
        // requested (inclusive) slices of `mock_data`, or the whole body when no
        // `Range` header is sent. The status is 206 in both cases.
        let _mock = server
            .mock("GET", "/concurrent")
            .with_status(206)
            .with_body_from_request(move |request| {
                if !request.has_header("Range") {
                    return mock_body_clone.clone();
                }
                let range = request.header("Range")[0];
                println!("range: {range:?}");
                range
                    .to_str()
                    .unwrap()
                    // Keep only the text after the last `=` (drops the `bytes` unit).
                    .rsplit('=')
                    .next()
                    .unwrap()
                    .split(',')
                    .map(|p| p.trim().splitn(2, '-'))
                    .map(|mut p| {
                        let start = p.next().unwrap().parse::<usize>().unwrap();
                        let end = p.next().unwrap().parse::<usize>().unwrap();
                        // HTTP range ends are inclusive.
                        start..=end
                    })
                    .flat_map(|p| mock_body_clone[p].to_vec())
                    .collect()
            })
            .create_async()
            .await;
        let reader = ReqwestReader::new(
            format!("{}/concurrent", server.url()).parse().unwrap(),
            Client::new(),
        );
        let writer = MockRandWriter::new(&mock_data);
        // Intentionally a one-element Vec holding a range, not a Vec of numbers.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            reader,
            writer.clone(),
            multi::DownloadOptions {
                concurrent: NonZeroUsize::new(32).unwrap(),
                retry_gap: Duration::from_secs(1),
                write_queue_cap: 1024,
                download_chunks: download_chunks.clone(),
            },
        )
        .await;

        // Drain events until `recv` returns an error (presumably once the channel
        // closes), merging read and write progress separately.
        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::ReadProgress(_, p) => {
                    download_progress.merge_progress(p);
                }
                Event::WriteProgress(_, p) => {
                    write_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&download_progress);
        dbg!(&write_progress);
        // Both read and write progress must cover exactly the requested range.
        assert_eq!(download_progress, download_chunks);
        assert_eq!(write_progress, download_chunks);

        result.join().await.unwrap();
        // Presumably verifies the writer received exactly `mock_data` — see MockRandWriter.
        writer.assert().await;
    }

    // Single-stream download of a 300 MiB body via `download_single`.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        // Endpoint that always returns the whole body with status 200.
        let _mock = server
            .mock("GET", "/sequential")
            .with_status(200)
            .with_body(mock_data.clone())
            .create_async()
            .await;
        let reader = ReqwestReader::new(
            format!("{}/sequential", server.url()).parse().unwrap(),
            Client::new(),
        );
        let writer = MockSeqWriter::new(&mock_data);
        // Intentionally a one-element Vec holding a range, not a Vec of numbers.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_single(
            reader,
            writer.clone(),
            single::DownloadOptions {
                retry_gap: Duration::from_secs(1),
                write_queue_cap: 1024,
            },
        )
        .await;

        // Drain events until `recv` returns an error (presumably once the channel
        // closes), merging read and write progress separately.
        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::ReadProgress(_, p) => {
                    download_progress.merge_progress(p);
                }
                Event::WriteProgress(_, p) => {
                    write_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&download_progress);
        dbg!(&write_progress);
        // The whole body must be reported as both read and written.
        assert_eq!(download_progress, download_chunks);
        assert_eq!(write_progress, download_chunks);

        result.join().await.unwrap();
        // Presumably verifies the writer received exactly `mock_data` — see MockSeqWriter.
        writer.assert().await;
    }
}