fast_pull/reqwest/
reader.rs

1use crate::{RandReader, SeqReader};
2use bytes::Bytes;
3use core::{
4    pin::{Pin, pin},
5    task::{Context, Poll},
6};
7use futures::{Stream, TryFutureExt, TryStream};
8use reqwest::{Client, Response, header};
9use url::Url;
10
11#[derive(Clone)]
12pub struct ReqwestReader {
13    pub(crate) client: Client,
14    url: Url,
15}
16
17impl ReqwestReader {
18    pub fn new(url: Url, client: Client) -> Self {
19        Self { client, url }
20    }
21}
22
23impl RandReader for ReqwestReader {
24    type Error = reqwest::Error;
25    fn read(
26        &mut self,
27        range: &crate::ProgressEntry,
28    ) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
29        ReqwestStream {
30            client: self.client.clone(),
31            url: self.url.clone(),
32            start: range.start,
33            end: range.end,
34            resp: ResponseState::None,
35            max_retries: 3,
36            curr_retries: 0,
37        }
38    }
39}
40type ResponseFut = Pin<Box<dyn Future<Output = Result<Response, reqwest::Error>> + Send>>;
41enum ResponseState {
42    Pending(ResponseFut),
43    Ready(Response),
44    None,
45}
46struct ReqwestStream {
47    client: Client,
48    url: Url,
49    start: u64,
50    end: u64,
51    resp: ResponseState,
52    max_retries: usize,
53    curr_retries: usize,
54}
55impl Stream for ReqwestStream {
56    type Item = Result<Bytes, reqwest::Error>;
57    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
58        let chunk_global;
59        match &mut self.resp {
60            ResponseState::Pending(resp) => match resp.try_poll_unpin(cx) {
61                Poll::Ready(resp) => {
62                    return match resp {
63                        Ok(resp) => {
64                            self.resp = ResponseState::Ready(resp);
65                            self.poll_next(cx)
66                        }
67                        Err(e) => {
68                            self.resp = ResponseState::None;
69                            Poll::Ready(Some(Err(e)))
70                        }
71                    };
72                }
73                Poll::Pending => return Poll::Pending,
74            },
75            ResponseState::None => {
76                let resp = self
77                    .client
78                    .get(self.url.clone())
79                    .header(
80                        header::RANGE,
81                        format!("bytes={}-{}", self.start, self.end - 1),
82                    )
83                    .send();
84                self.resp = ResponseState::Pending(Box::pin(resp));
85                return self.poll_next(cx);
86            }
87            ResponseState::Ready(resp) => {
88                let mut chunk = pin!(resp.chunk());
89                match chunk.try_poll_unpin(cx) {
90                    Poll::Ready(Ok(Some(chunk))) => chunk_global = Ok(chunk),
91                    Poll::Ready(Ok(None)) => return Poll::Ready(None),
92                    Poll::Ready(Err(e)) => chunk_global = Err(e),
93                    Poll::Pending => return Poll::Pending,
94                };
95            }
96        };
97        match chunk_global {
98            Ok(chunk) => {
99                self.start += chunk.len() as u64;
100                Poll::Ready(Some(Ok(chunk)))
101            }
102            Err(e) => {
103                self.curr_retries += 1;
104                if self.curr_retries >= self.max_retries {
105                    self.curr_retries = 0;
106                    self.resp = ResponseState::None;
107                }
108                Poll::Ready(Some(Err(e)))
109            }
110        }
111    }
112}
113
114impl SeqReader for ReqwestReader {
115    type Error = reqwest::Error;
116    fn read(&mut self) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
117        let req = self.client.get(self.url.clone());
118        Box::pin(async move {
119            let resp = req.send().await?;
120            Ok(resp.bytes_stream())
121        })
122        .try_flatten_stream()
123    }
124}
125
126#[cfg(test)]
127mod tests {
128    use crate::{
129        Event, MergeProgress, ProgressEntry,
130        mock::{MockRandWriter, MockSeqWriter, build_mock_data},
131        multi::{self, download_multi},
132        reqwest::ReqwestReader,
133        single::{self, download_single},
134    };
135    use core::{num::NonZeroUsize, time::Duration};
136    use reqwest::Client;
137
138    #[tokio::test]
139    async fn test_concurrent_download() {
140        let mock_data = build_mock_data(300 * 1024 * 1024);
141        let mut server = mockito::Server::new_async().await;
142        let mock_body_clone = mock_data.clone();
143        let _mock = server
144            .mock("GET", "/concurrent")
145            .with_status(206)
146            .with_body_from_request(move |request| {
147                if !request.has_header("Range") {
148                    return mock_body_clone.clone();
149                }
150                let range = request.header("Range")[0];
151                println!("range: {range:?}");
152                range
153                    .to_str()
154                    .unwrap()
155                    .rsplit('=')
156                    .next()
157                    .unwrap()
158                    .split(',')
159                    .map(|p| p.trim().splitn(2, '-'))
160                    .map(|mut p| {
161                        let start = p.next().unwrap().parse::<usize>().unwrap();
162                        let end = p.next().unwrap().parse::<usize>().unwrap();
163                        start..=end
164                    })
165                    .flat_map(|p| mock_body_clone[p].to_vec())
166                    .collect()
167            })
168            .create_async()
169            .await;
170        let reader = ReqwestReader::new(
171            format!("{}/concurrent", server.url()).parse().unwrap(),
172            Client::new(),
173        );
174        let writer = MockRandWriter::new(&mock_data);
175        #[allow(clippy::single_range_in_vec_init)]
176        let download_chunks = vec![0..mock_data.len() as u64];
177        let result = download_multi(
178            reader,
179            writer.clone(),
180            multi::DownloadOptions {
181                concurrent: NonZeroUsize::new(32).unwrap(),
182                retry_gap: Duration::from_secs(1),
183                write_queue_cap: 1024,
184                download_chunks: download_chunks.clone(),
185            },
186        )
187        .await;
188
189        let mut download_progress: Vec<ProgressEntry> = Vec::new();
190        let mut write_progress: Vec<ProgressEntry> = Vec::new();
191        while let Ok(e) = result.event_chain.recv().await {
192            match e {
193                Event::ReadProgress(_, p) => {
194                    download_progress.merge_progress(p);
195                }
196                Event::WriteProgress(_, p) => {
197                    write_progress.merge_progress(p);
198                }
199                _ => {}
200            }
201        }
202        dbg!(&download_progress);
203        dbg!(&write_progress);
204        assert_eq!(download_progress, download_chunks);
205        assert_eq!(write_progress, download_chunks);
206
207        result.join().await.unwrap();
208        writer.assert().await;
209    }
210
211    #[tokio::test]
212    async fn test_sequential_download() {
213        let mock_data = build_mock_data(300 * 1024 * 1024);
214        let mut server = mockito::Server::new_async().await;
215        let _mock = server
216            .mock("GET", "/sequential")
217            .with_status(200)
218            .with_body(mock_data.clone())
219            .create_async()
220            .await;
221        let reader = ReqwestReader::new(
222            format!("{}/sequential", server.url()).parse().unwrap(),
223            Client::new(),
224        );
225        let writer = MockSeqWriter::new(&mock_data);
226        #[allow(clippy::single_range_in_vec_init)]
227        let download_chunks = vec![0..mock_data.len() as u64];
228        let result = download_single(
229            reader,
230            writer.clone(),
231            single::DownloadOptions {
232                retry_gap: Duration::from_secs(1),
233                write_queue_cap: 1024,
234            },
235        )
236        .await;
237
238        let mut download_progress: Vec<ProgressEntry> = Vec::new();
239        let mut write_progress: Vec<ProgressEntry> = Vec::new();
240        while let Ok(e) = result.event_chain.recv().await {
241            match e {
242                Event::ReadProgress(_, p) => {
243                    download_progress.merge_progress(p);
244                }
245                Event::WriteProgress(_, p) => {
246                    write_progress.merge_progress(p);
247                }
248                _ => {}
249            }
250        }
251        dbg!(&download_progress);
252        dbg!(&write_progress);
253        assert_eq!(download_progress, download_chunks);
254        assert_eq!(write_progress, download_chunks);
255
256        result.join().await.unwrap();
257        writer.assert().await;
258    }
259}