fast_pull/reqwest/
puller.rs

1extern crate alloc;
2use crate::{RandPuller, SeqPuller};
3use alloc::{boxed::Box, format};
4use bytes::Bytes;
5use core::{
6    pin::{Pin, pin},
7    task::{Context, Poll},
8};
9use futures::{Stream, TryFutureExt, TryStream};
10use reqwest::{Client, Response, header};
11use url::Url;
12
13#[derive(Clone)]
14pub struct ReqwestPuller {
15    pub(crate) client: Client,
16    url: Url,
17}
18
19impl ReqwestPuller {
20    pub fn new(url: Url, client: Client) -> Self {
21        Self { client, url }
22    }
23}
24
25impl RandPuller for ReqwestPuller {
26    type Error = reqwest::Error;
27    fn pull(
28        &mut self,
29        range: &crate::ProgressEntry,
30    ) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
31        ReqwestStream {
32            client: self.client.clone(),
33            url: self.url.clone(),
34            start: range.start,
35            end: range.end,
36            resp: ResponseState::None,
37        }
38    }
39}
40type ResponseFut = Pin<Box<dyn Future<Output = Result<Response, reqwest::Error>> + Send>>;
41enum ResponseState {
42    Pending(ResponseFut),
43    Ready(Response),
44    None,
45}
46struct ReqwestStream {
47    client: Client,
48    url: Url,
49    start: u64,
50    end: u64,
51    resp: ResponseState,
52}
53impl Stream for ReqwestStream {
54    type Item = Result<Bytes, reqwest::Error>;
55    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
56        let chunk_global;
57        match &mut self.resp {
58            ResponseState::Pending(resp) => {
59                return match resp.try_poll_unpin(cx) {
60                    Poll::Ready(resp) => match resp {
61                        Ok(resp) => {
62                            self.resp = ResponseState::Ready(resp);
63                            self.poll_next(cx)
64                        }
65                        Err(e) => {
66                            self.resp = ResponseState::None;
67                            Poll::Ready(Some(Err(e)))
68                        }
69                    },
70                    Poll::Pending => Poll::Pending,
71                };
72            }
73            ResponseState::None => {
74                let resp = self
75                    .client
76                    .get(self.url.clone())
77                    .header(
78                        header::RANGE,
79                        format!("bytes={}-{}", self.start, self.end - 1),
80                    )
81                    .send();
82                self.resp = ResponseState::Pending(Box::pin(resp));
83                return self.poll_next(cx);
84            }
85            ResponseState::Ready(resp) => {
86                if let Err(e) = resp.error_for_status_ref() {
87                    self.resp = ResponseState::None;
88                    return Poll::Ready(Some(Err(e)));
89                }
90                let mut chunk = pin!(resp.chunk());
91                match chunk.try_poll_unpin(cx) {
92                    Poll::Ready(Ok(Some(chunk))) => chunk_global = Ok(chunk),
93                    Poll::Ready(Ok(None)) => return Poll::Ready(None),
94                    Poll::Ready(Err(e)) => chunk_global = Err(e),
95                    Poll::Pending => return Poll::Pending,
96                };
97            }
98        };
99        match chunk_global {
100            Ok(chunk) => {
101                self.start += chunk.len() as u64;
102                Poll::Ready(Some(Ok(chunk)))
103            }
104            Err(e) => {
105                self.resp = ResponseState::None;
106                Poll::Ready(Some(Err(e)))
107            }
108        }
109    }
110}
111
112impl SeqPuller for ReqwestPuller {
113    type Error = reqwest::Error;
114    fn pull(&mut self) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
115        let req = self.client.get(self.url.clone());
116        Box::pin(async move {
117            let resp = req.send().await?;
118            Ok(resp.bytes_stream())
119        })
120        .try_flatten_stream()
121    }
122}
123
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        Event, MergeProgress, ProgressEntry,
        mock::{MockRandPusher, MockSeqPusher, build_mock_data},
        multi::{self, download_multi},
        reqwest::ReqwestPuller,
        single::{self, download_single},
    };
    use alloc::vec;
    use core::{num::NonZeroUsize, time::Duration};
    use reqwest::Client;
    use std::{dbg, println};
    use vec::Vec;

    /// Ranged download with 32 workers against a mock server that answers
    /// `Range` requests; the merged pull and push progress must each equal
    /// the full `0..len` range.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let body = mock_data.clone();
        let _mock = server
            .mock("GET", "/concurrent")
            .with_status(206)
            .with_body_from_request(move |request| {
                // Without a Range header the whole body is served.
                if !request.has_header("Range") {
                    return body.clone();
                }
                let range = request.header("Range")[0];
                println!("range: {range:?}");
                // Everything after the last '=' is a comma-separated list of
                // inclusive `first-last` byte spans.
                let spec = range.to_str().unwrap().rsplit('=').next().unwrap();
                spec.split(',')
                    .map(|part| {
                        let mut bounds = part.trim().splitn(2, '-');
                        let first = bounds.next().unwrap().parse::<usize>().unwrap();
                        let last = bounds.next().unwrap().parse::<usize>().unwrap();
                        first..=last
                    })
                    .flat_map(|span| body[span].to_vec())
                    .collect()
            })
            .create_async()
            .await;
        let target = format!("{}/concurrent", server.url()).parse().unwrap();
        let puller = ReqwestPuller::new(target, Client::new());
        let pusher = MockRandPusher::new(&mock_data);
        #[allow(clippy::single_range_in_vec_init)]
        let whole_file = vec![0..mock_data.len() as u64];
        let options = multi::DownloadOptions {
            concurrent: NonZeroUsize::new(32).unwrap(),
            retry_gap: Duration::from_secs(1),
            push_queue_cap: 1024,
            download_chunks: whole_file.clone(),
            min_chunk_size: 1,
        };
        let result = download_multi(puller, pusher.clone(), options).await;

        // Drain the event channel, merging progress until it closes.
        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::PullProgress(_, span) => {
                    pull_progress.merge_progress(span);
                }
                Event::PushProgress(_, span) => {
                    push_progress.merge_progress(span);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, whole_file);
        assert_eq!(push_progress, whole_file);

        result.join().await.unwrap();
        pusher.assert().await;
    }

    /// Whole-body download from a mock server returning status 200; the
    /// merged pull and push progress must each equal the full `0..len` range.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(300 * 1024 * 1024);
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/sequential")
            .with_status(200)
            .with_body(mock_data.clone())
            .create_async()
            .await;
        let target = format!("{}/sequential", server.url()).parse().unwrap();
        let puller = ReqwestPuller::new(target, Client::new());
        let pusher = MockSeqPusher::new(&mock_data);
        #[allow(clippy::single_range_in_vec_init)]
        let whole_file = vec![0..mock_data.len() as u64];
        let options = single::DownloadOptions {
            retry_gap: Duration::from_secs(1),
            push_queue_cap: 1024,
        };
        let result = download_single(puller, pusher.clone(), options).await;

        // Drain the event channel, merging progress until it closes.
        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::PullProgress(_, span) => {
                    pull_progress.merge_progress(span);
                }
                Event::PushProgress(_, span) => {
                    push_progress.merge_progress(span);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, whole_file);
        assert_eq!(push_progress, whole_file);

        result.join().await.unwrap();
        pusher.assert().await;
    }
}