fast_down/http/
puller.rs

1use crate::http::{
2    FileId, GetRequestError, GetResponse, HttpClient, HttpError, HttpHeaders, HttpRequestBuilder,
3    HttpResponse,
4};
5use bytes::Bytes;
6use fast_pull::{ProgressEntry, PullResult, PullStream, RandPuller, SeqPuller};
7use futures::Stream;
8use spin::mutex::SpinMutex;
9use std::{
10    fmt::Debug,
11    future::Future,
12    pin::Pin,
13    sync::Arc,
14    task::{Context, Poll},
15    time::Duration,
16};
17use url::Url;
18
19#[derive(Clone)]
20pub struct HttpPuller<Client: HttpClient> {
21    pub(crate) client: Client,
22    url: Url,
23    resp: Option<Arc<SpinMutex<Option<GetResponse<Client>>>>>,
24    file_id: FileId,
25}
26impl<Client: HttpClient> HttpPuller<Client> {
27    pub fn new(
28        url: Url,
29        client: Client,
30        resp: Option<Arc<SpinMutex<Option<GetResponse<Client>>>>>,
31        file_id: FileId,
32    ) -> Self {
33        Self {
34            client,
35            url,
36            resp,
37            file_id,
38        }
39    }
40}
41impl<Client: HttpClient> Debug for HttpPuller<Client> {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        f.debug_struct("HttpPuller")
44            .field("client", &"...")
45            .field("url", &self.url)
46            .field("resp", &"...")
47            .field("file_id", &self.file_id)
48            .finish()
49    }
50}
51
52type ResponseFut<Client> = Pin<
53    Box<
54        dyn Future<
55                Output = Result<GetResponse<Client>, (GetRequestError<Client>, Option<Duration>)>,
56            > + Send,
57    >,
58>;
59
60type ChunkStream<Client> = Pin<Box<dyn Stream<Item = Result<Bytes, HttpError<Client>>> + Send>>;
61
62enum ResponseState<Client: HttpClient> {
63    Pending(ResponseFut<Client>),
64    Streaming(ChunkStream<Client>),
65    None,
66}
67
68fn into_chunk_stream<Client: HttpClient + 'static>(
69    resp: GetResponse<Client>,
70) -> ChunkStream<Client> {
71    Box::pin(futures::stream::try_unfold(resp, |mut r| async move {
72        match r.chunk().await {
73            Ok(Some(chunk)) => Ok(Some((chunk, r))),
74            Ok(None) => Ok(None),
75            Err(e) => Err(HttpError::Chunk(e)),
76        }
77    }))
78}
79
80impl<Client: HttpClient + 'static> RandPuller for HttpPuller<Client> {
81    type Error = HttpError<Client>;
82    async fn pull(
83        &mut self,
84        range: &ProgressEntry,
85    ) -> PullResult<impl PullStream<Self::Error>, Self::Error> {
86        Ok(RandRequestStream {
87            client: self.client.clone(),
88            url: self.url.clone(),
89            start: range.start,
90            end: range.end,
91            state: if range.start == 0
92                && let Some(resp) = &self.resp
93                && let Some(resp) = resp.lock().take()
94            {
95                ResponseState::Streaming(into_chunk_stream(resp))
96            } else {
97                ResponseState::None
98            },
99            file_id: self.file_id.clone(),
100        })
101    }
102}
103struct RandRequestStream<Client: HttpClient + 'static> {
104    client: Client,
105    url: Url,
106    start: u64,
107    end: u64,
108    state: ResponseState<Client>,
109    file_id: FileId,
110}
111impl<Client: HttpClient> Stream for RandRequestStream<Client> {
112    type Item = Result<Bytes, (HttpError<Client>, Option<Duration>)>;
113    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
114        match &mut self.state {
115            ResponseState::Pending(resp) => match resp.as_mut().poll(cx) {
116                Poll::Ready(Ok(resp)) => {
117                    let new_file_id = FileId::new(
118                        resp.headers().get("etag").ok(),
119                        resp.headers().get("last-modified").ok(),
120                    );
121                    if new_file_id != self.file_id {
122                        self.state = ResponseState::None;
123                        Poll::Ready(Some(Err((HttpError::MismatchedBody(new_file_id), None))))
124                    } else {
125                        self.state = ResponseState::Streaming(into_chunk_stream(resp));
126                        self.poll_next(cx)
127                    }
128                }
129                Poll::Ready(Err((e, d))) => {
130                    self.state = ResponseState::None;
131                    Poll::Ready(Some(Err((HttpError::Request(e), d))))
132                }
133                Poll::Pending => Poll::Pending,
134            },
135            ResponseState::None => {
136                let resp = self
137                    .client
138                    .get(self.url.clone(), Some(self.start..self.end))
139                    .send();
140                self.state = ResponseState::Pending(Box::pin(resp));
141                self.poll_next(cx)
142            }
143            ResponseState::Streaming(stream) => match stream.as_mut().poll_next(cx) {
144                Poll::Ready(Some(Ok(chunk))) => {
145                    self.start += chunk.len() as u64;
146                    Poll::Ready(Some(Ok(chunk)))
147                }
148                Poll::Ready(Some(Err(e))) => {
149                    self.state = ResponseState::None;
150                    Poll::Ready(Some(Err((e, None))))
151                }
152                Poll::Ready(None) => Poll::Ready(None),
153                Poll::Pending => Poll::Pending,
154            },
155        }
156    }
157}
158
159impl<Client: HttpClient + 'static> SeqPuller for HttpPuller<Client> {
160    type Error = HttpError<Client>;
161    async fn pull(&mut self) -> PullResult<impl PullStream<Self::Error>, Self::Error> {
162        Ok(SeqRequestStream {
163            state: if let Some(resp) = &self.resp
164                && let Some(resp) = resp.lock().take()
165            {
166                ResponseState::Streaming(into_chunk_stream(resp))
167            } else {
168                let req = self.client.get(self.url.clone(), None).send();
169                ResponseState::Pending(Box::pin(req))
170            },
171            file_id: self.file_id.clone(),
172        })
173    }
174}
175struct SeqRequestStream<Client: HttpClient + 'static> {
176    state: ResponseState<Client>,
177    file_id: FileId,
178}
179impl<Client: HttpClient> Stream for SeqRequestStream<Client> {
180    type Item = Result<Bytes, (HttpError<Client>, Option<Duration>)>;
181    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
182        match &mut self.state {
183            ResponseState::Pending(resp) => match resp.as_mut().poll(cx) {
184                Poll::Ready(Ok(resp)) => {
185                    let new_file_id = FileId::new(
186                        resp.headers().get("etag").ok(),
187                        resp.headers().get("last-modified").ok(),
188                    );
189                    if new_file_id != self.file_id {
190                        self.state = ResponseState::None;
191                        Poll::Ready(Some(Err((HttpError::MismatchedBody(new_file_id), None))))
192                    } else {
193                        self.state = ResponseState::Streaming(into_chunk_stream(resp));
194                        self.poll_next(cx)
195                    }
196                }
197                Poll::Ready(Err((e, d))) => {
198                    self.state = ResponseState::None;
199                    Poll::Ready(Some(Err((HttpError::Request(e), d))))
200                }
201                Poll::Pending => Poll::Pending,
202            },
203            ResponseState::None => Poll::Ready(Some(Err((HttpError::Irrecoverable, None)))),
204            ResponseState::Streaming(stream) => match stream.as_mut().poll_next(cx) {
205                Poll::Ready(Some(Ok(chunk))) => Poll::Ready(Some(Ok(chunk))),
206                Poll::Ready(Some(Err(e))) => {
207                    self.state = ResponseState::None;
208                    Poll::Ready(Some(Err((e, None))))
209                }
210                Poll::Ready(None) => Poll::Ready(None),
211                Poll::Pending => Poll::Pending,
212            },
213        }
214    }
215}
216
#[cfg(test)]
mod tests {
    use crate::http::{
        FileId, HttpClient, HttpHeaders, HttpPuller, HttpRequestBuilder, HttpResponse,
    };
    use bytes::Bytes;
    use fast_pull::{ProgressEntry, RandPuller};
    use futures::{Future, TryStreamExt, task::Context};
    use spin::mutex::SpinMutex;
    use std::{pin::Pin, sync::Arc, task::Poll, time::Duration};
    use url::Url;

    /// Error type shared by all mocks; it carries no information.
    #[derive(Debug)]
    pub struct MockError;

    /// Header set in which every lookup fails.
    pub struct MockHeaders;
    impl HttpHeaders for MockHeaders {
        type GetHeaderError = MockError;
        fn get(&self, _header: &str) -> Result<&str, Self::GetHeaderError> {
            Err(MockError)
        }
    }

    /// Future that is `Pending` on its first poll (after waking the task) and
    /// resolves to the chunk `b"success"` on every later poll.
    struct DelayChunk {
        polled_once: bool,
    }
    impl DelayChunk {
        fn new() -> Self {
            DelayChunk { polled_once: false }
        }
    }
    impl Future for DelayChunk {
        type Output = Result<Option<Bytes>, MockError>;
        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if self.polled_once {
                println!("Done! [Mock: 数据到达 Ready]");
                return Poll::Ready(Ok(Some(Bytes::from_static(b"success"))));
            }
            println!("Wait... [Mock: 模拟网络延迟 Pending]");
            self.polled_once = true;
            // Wake immediately so the executor polls the task again.
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    /// Response whose every `chunk()` call goes through a fresh `DelayChunk`.
    pub struct MockResponse {
        headers: MockHeaders,
        url: Url,
    }
    impl MockResponse {
        fn new() -> Self {
            let url = Url::parse("http://mock-url").unwrap();
            MockResponse {
                headers: MockHeaders,
                url,
            }
        }
    }
    impl HttpResponse for MockResponse {
        type Headers = MockHeaders;
        type ChunkError = MockError;
        fn headers(&self) -> &Self::Headers {
            &self.headers
        }
        fn url(&self) -> &Url {
            &self.url
        }
        async fn chunk(&mut self) -> Result<Option<Bytes>, Self::ChunkError> {
            DelayChunk::new().await
        }
    }

    /// Request builder whose `send` always succeeds with a `MockResponse`.
    struct MockRequestBuilder;
    impl HttpRequestBuilder for MockRequestBuilder {
        type Response = MockResponse;
        type RequestError = MockError;
        async fn send(self) -> Result<Self::Response, (Self::RequestError, Option<Duration>)> {
            Ok(MockResponse::new())
        }
    }

    /// Client that ignores the URL and the range of every request.
    #[derive(Clone, Debug)]
    struct MockClient;
    impl HttpClient for MockClient {
        type RequestBuilder = MockRequestBuilder;
        fn get(&self, _url: Url, _range: Option<ProgressEntry>) -> Self::RequestBuilder {
            MockRequestBuilder
        }
    }

    /// The first chunk must arrive within one second even though the mock
    /// reports `Pending` once before the data is ready.
    #[tokio::test]
    async fn test_http_puller_infinite_loop_fix() {
        // The shared slot exists but is empty, so the stream sends its own request.
        let empty_slot = Some(Arc::new(SpinMutex::new(None)));
        let mut puller = HttpPuller::new(
            Url::parse("http://localhost").unwrap(),
            MockClient,
            empty_slot,
            FileId::new(None, None),
        );
        let range = 0..7;
        let mut stream = puller.pull(&range).await.expect("Failed to create stream");
        println!("--- 开始测试 HttpPuller ---");
        let e = tokio::time::timeout(Duration::from_secs(1), stream.try_next()).await;
        if let Ok(Ok(Some(bytes))) = &e {
            println!("收到数据: {:?}", bytes);
            assert_eq!(*bytes, Bytes::from_static(b"success"));
            println!("测试通过:HttpPuller 正确处理了 Pending 状态!");
        } else {
            panic!(
                "测试失败:超时!这表明 HttpPuller 可能在收到 Pending 后丢失了 Future 状态并陷入了死循环。 {e:?}"
            );
        }
    }
}