// fast_down/http/puller.rs

1use crate::http::{
2    FileId, GetRequestError, GetResponse, HttpClient, HttpError, HttpHeaders, HttpRequestBuilder,
3    HttpResponse,
4};
5use bytes::Bytes;
6use fast_pull::{ProgressEntry, RandPuller, SeqPuller};
7use futures::{Stream, TryFutureExt, TryStream};
8use spin::mutex::SpinMutex;
9use std::{
10    pin::{Pin, pin},
11    sync::Arc,
12    task::{Context, Poll},
13};
14use url::Url;
15
16#[derive(Clone)]
17pub struct HttpPuller<Client: HttpClient> {
18    pub(crate) client: Client,
19    url: Url,
20    resp: Arc<SpinMutex<Option<GetResponse<Client>>>>,
21    file_id: FileId,
22}
23impl<Client: HttpClient> HttpPuller<Client> {
24    pub fn new(
25        url: Url,
26        client: Client,
27        resp: Option<GetResponse<Client>>,
28        file_id: FileId,
29    ) -> Self {
30        Self {
31            client,
32            url,
33            resp: Arc::new(SpinMutex::new(resp)),
34            file_id,
35        }
36    }
37}
38
39type ResponseFut<Client> =
40    Pin<Box<dyn Future<Output = Result<GetResponse<Client>, GetRequestError<Client>>> + Send>>;
41enum ResponseState<Client: HttpClient> {
42    Pending(ResponseFut<Client>),
43    Ready(GetResponse<Client>),
44    None,
45}
46
47impl<Client: HttpClient + 'static> RandPuller for HttpPuller<Client> {
48    type Error = HttpError<Client>;
49    fn pull(
50        &mut self,
51        range: &ProgressEntry,
52    ) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
53        RandRequestStream {
54            client: self.client.clone(),
55            url: self.url.clone(),
56            start: range.start,
57            end: range.end,
58            state: if range.start == 0
59                && let Some(resp) = self.resp.lock().take()
60            {
61                ResponseState::Ready(resp)
62            } else {
63                ResponseState::None
64            },
65            file_id: self.file_id.clone(),
66        }
67    }
68}
69struct RandRequestStream<Client: HttpClient + 'static> {
70    client: Client,
71    url: Url,
72    start: u64,
73    end: u64,
74    state: ResponseState<Client>,
75    file_id: FileId,
76}
77impl<Client: HttpClient> Stream for RandRequestStream<Client> {
78    type Item = Result<Bytes, HttpError<Client>>;
79    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
80        let chunk_global;
81        match &mut self.state {
82            ResponseState::Pending(resp) => {
83                return match resp.try_poll_unpin(cx) {
84                    Poll::Ready(resp) => match resp {
85                        Ok(resp) => {
86                            let etag = resp.headers().get("etag").ok();
87                            let last_modified = resp.headers().get("last-modified").ok();
88                            let new_file_id = FileId::new(etag, last_modified);
89                            if new_file_id != self.file_id {
90                                self.state = ResponseState::None;
91                                Poll::Ready(Some(Err(HttpError::MismatchedBody(new_file_id))))
92                            } else {
93                                self.state = ResponseState::Ready(resp);
94                                self.poll_next(cx)
95                            }
96                        }
97                        Err(e) => {
98                            self.state = ResponseState::None;
99                            Poll::Ready(Some(Err(HttpError::Request(e))))
100                        }
101                    },
102                    Poll::Pending => Poll::Pending,
103                };
104            }
105            ResponseState::None => {
106                let resp = self
107                    .client
108                    .get(self.url.clone(), Some(self.start..self.end))
109                    .send();
110                self.state = ResponseState::Pending(Box::pin(resp));
111                return self.poll_next(cx);
112            }
113            ResponseState::Ready(resp) => {
114                let mut chunk = pin!(resp.chunk());
115                match chunk.try_poll_unpin(cx) {
116                    Poll::Ready(Ok(Some(chunk))) => chunk_global = Ok(chunk),
117                    Poll::Ready(Ok(None)) => return Poll::Ready(None),
118                    Poll::Ready(Err(e)) => chunk_global = Err(e),
119                    Poll::Pending => return Poll::Pending,
120                };
121            }
122        };
123        match chunk_global {
124            Ok(chunk) => {
125                self.start += chunk.len() as u64;
126                Poll::Ready(Some(Ok(chunk)))
127            }
128            Err(e) => {
129                self.state = ResponseState::None;
130                Poll::Ready(Some(Err(HttpError::Chunk(e))))
131            }
132        }
133    }
134}
135
136impl<Client: HttpClient + 'static> SeqPuller for HttpPuller<Client> {
137    type Error = HttpError<Client>;
138    fn pull(&mut self) -> impl TryStream<Ok = Bytes, Error = Self::Error> + Send + Unpin {
139        SeqRequestStream {
140            state: match self.resp.lock().take() {
141                Some(resp) => ResponseState::Ready(resp),
142                None => {
143                    let req = self.client.get(self.url.clone(), None).send();
144                    ResponseState::Pending(Box::pin(req))
145                }
146            },
147            file_id: self.file_id.clone(),
148        }
149    }
150}
151struct SeqRequestStream<Client: HttpClient + 'static> {
152    state: ResponseState<Client>,
153    file_id: FileId,
154}
155impl<Client: HttpClient> Stream for SeqRequestStream<Client> {
156    type Item = Result<Bytes, HttpError<Client>>;
157    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
158        let chunk_global;
159        match &mut self.state {
160            ResponseState::Pending(resp) => {
161                return match resp.try_poll_unpin(cx) {
162                    Poll::Ready(resp) => match resp {
163                        Ok(resp) => {
164                            let etag = resp.headers().get("etag").ok();
165                            let last_modified = resp.headers().get("last-modified").ok();
166                            let new_file_id = FileId::new(etag, last_modified);
167                            if new_file_id != self.file_id {
168                                self.state = ResponseState::None;
169                                Poll::Ready(Some(Err(HttpError::MismatchedBody(new_file_id))))
170                            } else {
171                                self.state = ResponseState::Ready(resp);
172                                self.poll_next(cx)
173                            }
174                        }
175                        Err(e) => {
176                            self.state = ResponseState::None;
177                            Poll::Ready(Some(Err(HttpError::Request(e))))
178                        }
179                    },
180                    Poll::Pending => Poll::Pending,
181                };
182            }
183            ResponseState::None => return Poll::Ready(Some(Err(HttpError::Irrecoverable))),
184            ResponseState::Ready(resp) => {
185                let mut chunk = pin!(resp.chunk());
186                match chunk.try_poll_unpin(cx) {
187                    Poll::Ready(Ok(Some(chunk))) => chunk_global = Ok(chunk),
188                    Poll::Ready(Ok(None)) => return Poll::Ready(None),
189                    Poll::Ready(Err(e)) => chunk_global = Err(e),
190                    Poll::Pending => return Poll::Pending,
191                };
192            }
193        };
194        match chunk_global {
195            Ok(chunk) => Poll::Ready(Some(Ok(chunk))),
196            Err(e) => {
197                self.state = ResponseState::None;
198                Poll::Ready(Some(Err(HttpError::Chunk(e))))
199            }
200        }
201    }
202}