Skip to main content

fast_down/http/
puller.rs

1use crate::http::{
2    FileId, GetRequestError, GetResponse, HttpClient, HttpError, HttpHeaders, HttpRequestBuilder,
3    HttpResponse,
4};
5use bytes::Bytes;
6use fast_pull::{ProgressEntry, PullResult, PullStream, Puller};
7use futures::Stream;
8use parking_lot::Mutex;
9use std::{
10    fmt::Debug,
11    future::Future,
12    ops::Range,
13    pin::Pin,
14    sync::Arc,
15    task::{Context, Poll},
16    time::Duration,
17};
18use url::Url;
19
20#[derive(Clone)]
21pub struct HttpPuller<Client: HttpClient> {
22    client: Client,
23    url: Url,
24    resp: Option<Arc<Mutex<Option<GetResponse<Client>>>>>,
25    file_id: FileId,
26}
27impl<Client: HttpClient> HttpPuller<Client> {
28    pub fn new(
29        url: Url,
30        client: Client,
31        resp: Option<Arc<Mutex<Option<GetResponse<Client>>>>>,
32        file_id: FileId,
33    ) -> Self {
34        Self {
35            client,
36            url,
37            resp,
38            file_id,
39        }
40    }
41}
42impl<Client: HttpClient> Debug for HttpPuller<Client> {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        f.debug_struct("HttpPuller")
45            .field("url", &self.url)
46            .field("file_id", &self.file_id)
47            .field("client", &"...")
48            .field("resp", &"...")
49            .finish()
50    }
51}
52
/// A GET request in flight.
///
/// It resolves to the response, or to the request error paired with an optional
/// `Duration` (presumably a retry-after hint from the client — confirm in
/// `HttpRequestBuilder::send`).
type ResponseFut<Client> = Pin<
    Box<
        dyn Future<
                Output = Result<GetResponse<Client>, (GetRequestError<Client>, Option<Duration>)>,
            > + Send,
    >,
>;

/// Boxed stream of response-body chunks, as produced by [`into_chunk_stream`].
type ChunkStream<Client> = Pin<Box<dyn Stream<Item = Result<Bytes, HttpError<Client>>> + Send>>;

/// State machine driven by `RandRequestStream::poll_next`.
enum ResponseState<Client: HttpClient> {
    /// A GET request has been sent and its response is being awaited.
    Pending(ResponseFut<Client>),
    /// The response body is being read chunk by chunk.
    Streaming(ChunkStream<Client>),
    /// No request is active: either none has been sent yet, or the previous attempt
    /// failed. The next poll decides whether to (re)send one.
    None,
}
68
69fn into_chunk_stream<Client: HttpClient>(resp: GetResponse<Client>) -> ChunkStream<Client> {
70    Box::pin(futures::stream::try_unfold(resp, |mut r| async move {
71        match r.chunk().await {
72            Ok(Some(chunk)) => Ok(Some((chunk, r))),
73            Ok(None) => Ok(None),
74            Err(e) => Err(HttpError::Chunk(e)),
75        }
76    }))
77}
78
79impl<Client: HttpClient> Puller for HttpPuller<Client> {
80    type Error = HttpError<Client>;
81    async fn pull(
82        &mut self,
83        range: Option<&ProgressEntry>,
84    ) -> PullResult<impl PullStream<Self::Error>, Self::Error> {
85        let range = range.cloned().unwrap_or(0..u64::MAX);
86        Ok(RandRequestStream {
87            client: self.client.clone(),
88            url: self.url.clone(),
89            state: if range.start == 0
90                && let Some(resp) = &self.resp
91                && let Some(resp) = resp.lock().take()
92            {
93                ResponseState::Streaming(into_chunk_stream(resp))
94            } else if range.end == u64::MAX {
95                let req = self.client.get(self.url.clone(), None).send();
96                ResponseState::Pending(Box::pin(req))
97            } else {
98                ResponseState::None
99            },
100            range,
101            file_id: self.file_id.clone(),
102        })
103    }
104}
/// Stream returned by `HttpPuller::pull`.
///
/// It yields the bytes of `range`, re-issuing ranged requests for the remainder after
/// a failure, provided the range is bounded.
struct RandRequestStream<Client: HttpClient> {
    /// Client used to (re)issue requests.
    client: Client,
    /// The resource being downloaded.
    url: Url,
    /// Bytes still to deliver. `start` advances as chunks are yielded, and
    /// `end == u64::MAX` marks an unbounded download that cannot be resumed.
    range: Range<u64>,
    /// Current request/stream state.
    state: ResponseState<Client>,
    /// Identity every newly received response must match.
    file_id: FileId,
}
112impl<Client: HttpClient> Stream for RandRequestStream<Client> {
113    type Item = Result<Bytes, (HttpError<Client>, Option<Duration>)>;
114    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
115        loop {
116            break match &mut self.state {
117                ResponseState::Pending(resp) => match resp.as_mut().poll(cx) {
118                    Poll::Ready(Ok(resp)) => {
119                        let new_file_id = FileId::new(
120                            resp.headers().get("etag").ok(),
121                            resp.headers().get("last-modified").ok(),
122                        );
123                        if new_file_id != self.file_id {
124                            self.state = ResponseState::None;
125                            Poll::Ready(Some(Err((HttpError::MismatchedBody(new_file_id), None))))
126                        } else {
127                            self.state = ResponseState::Streaming(into_chunk_stream(resp));
128                            continue;
129                        }
130                    }
131                    Poll::Ready(Err((e, d))) => {
132                        self.state = ResponseState::None;
133                        Poll::Ready(Some(Err((HttpError::Request(e), d))))
134                    }
135                    Poll::Pending => Poll::Pending,
136                },
137                ResponseState::None => {
138                    if self.range.end == u64::MAX {
139                        break Poll::Ready(Some(Err((HttpError::Irrecoverable, None))));
140                    } else {
141                        let resp = self
142                            .client
143                            .get(self.url.clone(), Some(self.range.clone()))
144                            .send();
145                        self.state = ResponseState::Pending(Box::pin(resp));
146                        continue;
147                    }
148                }
149                ResponseState::Streaming(stream) => match stream.as_mut().poll_next(cx) {
150                    Poll::Ready(Some(Ok(chunk))) => {
151                        self.range.start += chunk.len() as u64;
152                        Poll::Ready(Some(Ok(chunk)))
153                    }
154                    Poll::Ready(Some(Err(e))) => {
155                        self.state = ResponseState::None;
156                        Poll::Ready(Some(Err((e, None))))
157                    }
158                    Poll::Ready(None) => Poll::Ready(None),
159                    Poll::Pending => Poll::Pending,
160                },
161            };
162        }
163    }
164}
165
#[cfg(test)]
mod tests {
    // Regression test for `RandRequestStream::poll_next`. The mocks below return
    // `Poll::Pending` once before producing data. The test checks that the stream keeps
    // its in-flight future/stream across that `Pending` instead of looping forever.
    use super::*;
    use futures::TryStreamExt;

    /// Client whose requests ignore the URL and range and always succeed.
    #[derive(Clone, Debug)]
    struct MockClient;
    impl HttpClient for MockClient {
        type RequestBuilder = MockRequestBuilder;
        fn get(&self, _url: Url, _range: Option<ProgressEntry>) -> Self::RequestBuilder {
            MockRequestBuilder
        }
    }
    /// Request builder whose `send` immediately returns a fresh `MockResponse`.
    struct MockRequestBuilder;
    impl HttpRequestBuilder for MockRequestBuilder {
        type Response = MockResponse;
        type RequestError = MockError;
        async fn send(self) -> Result<Self::Response, (Self::RequestError, Option<Duration>)> {
            Ok(MockResponse::new())
        }
    }
    /// Response with unreadable headers and an endless body: every `chunk` call
    /// yields `b"success"` after one `Pending`.
    pub struct MockResponse {
        headers: MockHeaders,
        url: Url,
    }
    impl MockResponse {
        fn new() -> Self {
            Self {
                headers: MockHeaders,
                url: Url::parse("http://mock-url").unwrap(),
            }
        }
    }
    impl HttpResponse for MockResponse {
        type Headers = MockHeaders;
        type ChunkError = MockError;
        fn headers(&self) -> &Self::Headers {
            &self.headers
        }
        fn url(&self) -> &Url {
            &self.url
        }
        async fn chunk(&mut self) -> Result<Option<Bytes>, Self::ChunkError> {
            DelayChunk::new().await
        }
    }
    /// Headers where every lookup fails. The puller therefore derives
    /// `FileId::new(None, None)`, which matches the file id the test passes in.
    pub struct MockHeaders;
    impl HttpHeaders for MockHeaders {
        type GetHeaderError = MockError;
        fn get(&self, _header: &str) -> Result<&str, Self::GetHeaderError> {
            Err(MockError)
        }
    }
    #[derive(Debug)]
    pub struct MockError;

    /// Future that returns `Pending` on its first poll, waking itself right away so
    /// the executor polls again. It returns `Ready(Ok(Some(b"success")))` afterwards.
    struct DelayChunk {
        polled_once: bool,
    }
    impl DelayChunk {
        fn new() -> Self {
            Self { polled_once: false }
        }
    }
    impl Future for DelayChunk {
        type Output = Result<Option<Bytes>, MockError>;
        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if !self.polled_once {
                // Message: "simulated network latency: Pending".
                println!("Wait... [Mock: 模拟网络延迟 Pending]");
                self.polled_once = true;
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            // Message: "data arrived: Ready".
            println!("Done! [Mock: 数据到达 Ready]");
            Poll::Ready(Ok(Some(Bytes::from_static(b"success"))))
        }
    }

    #[tokio::test]
    async fn test_http_puller_infinite_loop_fix() {
        let url = Url::parse("http://localhost").unwrap();
        let client = MockClient;
        let file_id = FileId::new(None, None);
        // No prefetched response, so the first poll sends a ranged GET.
        let mut puller = HttpPuller::new(url, client, None, file_id);
        // 7 == b"success".len().
        let range = 0..7;
        let mut stream = Puller::pull(&mut puller, Some(&range))
            .await
            .expect("Failed to create stream");
        // Message: "--- start testing HttpPuller ---".
        println!("--- 开始测试 HttpPuller ---");
        // The 1 s timeout turns a busy loop in `poll_next` into a test failure
        // instead of a hang.
        let result =
            tokio::time::timeout(Duration::from_secs(1), async { stream.try_next().await }).await;
        match result {
            Ok(Ok(Some(bytes))) => {
                // Messages: "received data: ..." / "test passed: HttpPuller handled the
                // Pending state correctly!".
                println!("收到数据: {:?}", bytes);
                assert_eq!(bytes, Bytes::from_static(b"success"));
                println!("测试通过:HttpPuller 正确处理了 Pending 状态!");
            }
            e => {
                // Message: "test failed: timeout! HttpPuller probably lost its Future state
                // after receiving Pending and fell into an infinite loop."
                // NOTE(review): this arm also catches a stream error or an early
                // end-of-stream, not only a timeout; `{e:?}` shows which one occurred.
                panic!(
                    "测试失败:超时!这表明 HttpPuller 可能在收到 Pending 后丢失了 Future 状态并陷入了死循环。 {e:?}"
                );
            }
        }
    }
}