1use crate::http::{
2 FileId, GetRequestError, GetResponse, HttpClient, HttpError, HttpHeaders, HttpRequestBuilder,
3 HttpResponse,
4};
5use bytes::Bytes;
6use fast_pull::{ProgressEntry, PullResult, PullStream, Puller};
7use futures::Stream;
8use parking_lot::Mutex;
9use std::{
10 fmt::Debug,
11 future::Future,
12 ops::Range,
13 pin::Pin,
14 sync::Arc,
15 task::{Context, Poll},
16 time::Duration,
17};
18use url::Url;
19
20#[derive(Clone)]
21pub struct HttpPuller<Client: HttpClient> {
22 client: Client,
23 url: Url,
24 resp: Option<Arc<Mutex<Option<GetResponse<Client>>>>>,
25 file_id: FileId,
26}
27impl<Client: HttpClient> HttpPuller<Client> {
28 pub fn new(
29 url: Url,
30 client: Client,
31 resp: Option<Arc<Mutex<Option<GetResponse<Client>>>>>,
32 file_id: FileId,
33 ) -> Self {
34 Self {
35 client,
36 url,
37 resp,
38 file_id,
39 }
40 }
41}
42impl<Client: HttpClient> Debug for HttpPuller<Client> {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 f.debug_struct("HttpPuller")
45 .field("url", &self.url)
46 .field("file_id", &self.file_id)
47 .field("client", &"...")
48 .field("resp", &"...")
49 .finish()
50 }
51}
52
53type ResponseFut<Client> = Pin<
54 Box<
55 dyn Future<
56 Output = Result<GetResponse<Client>, (GetRequestError<Client>, Option<Duration>)>,
57 > + Send,
58 >,
59>;
60
61type ChunkStream<Client> = Pin<Box<dyn Stream<Item = Result<Bytes, HttpError<Client>>> + Send>>;
62
63enum ResponseState<Client: HttpClient> {
64 Pending(ResponseFut<Client>),
65 Streaming(ChunkStream<Client>),
66 None,
67}
68
69fn into_chunk_stream<Client: HttpClient>(resp: GetResponse<Client>) -> ChunkStream<Client> {
70 Box::pin(futures::stream::try_unfold(resp, |mut r| async move {
71 match r.chunk().await {
72 Ok(Some(chunk)) => Ok(Some((chunk, r))),
73 Ok(None) => Ok(None),
74 Err(e) => Err(HttpError::Chunk(e)),
75 }
76 }))
77}
78
79impl<Client: HttpClient> Puller for HttpPuller<Client> {
80 type Error = HttpError<Client>;
81 async fn pull(
82 &mut self,
83 range: Option<&ProgressEntry>,
84 ) -> PullResult<impl PullStream<Self::Error>, Self::Error> {
85 let range = range.cloned().unwrap_or(0..u64::MAX);
86 Ok(RandRequestStream {
87 client: self.client.clone(),
88 url: self.url.clone(),
89 state: if range.start == 0
90 && let Some(resp) = &self.resp
91 && let Some(resp) = resp.lock().take()
92 {
93 ResponseState::Streaming(into_chunk_stream(resp))
94 } else if range.end == u64::MAX {
95 let req = self.client.get(self.url.clone(), None).send();
96 ResponseState::Pending(Box::pin(req))
97 } else {
98 ResponseState::None
99 },
100 range,
101 file_id: self.file_id.clone(),
102 })
103 }
104}
105struct RandRequestStream<Client: HttpClient> {
106 client: Client,
107 url: Url,
108 range: Range<u64>,
109 state: ResponseState<Client>,
110 file_id: FileId,
111}
112impl<Client: HttpClient> Stream for RandRequestStream<Client> {
113 type Item = Result<Bytes, (HttpError<Client>, Option<Duration>)>;
114 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
115 loop {
116 break match &mut self.state {
117 ResponseState::Pending(resp) => match resp.as_mut().poll(cx) {
118 Poll::Ready(Ok(resp)) => {
119 let new_file_id = FileId::new(
120 resp.headers().get("etag").ok(),
121 resp.headers().get("last-modified").ok(),
122 );
123 if new_file_id != self.file_id {
124 self.state = ResponseState::None;
125 Poll::Ready(Some(Err((HttpError::MismatchedBody(new_file_id), None))))
126 } else {
127 self.state = ResponseState::Streaming(into_chunk_stream(resp));
128 continue;
129 }
130 }
131 Poll::Ready(Err((e, d))) => {
132 self.state = ResponseState::None;
133 Poll::Ready(Some(Err((HttpError::Request(e), d))))
134 }
135 Poll::Pending => Poll::Pending,
136 },
137 ResponseState::None => {
138 if self.range.end == u64::MAX {
139 break Poll::Ready(Some(Err((HttpError::Irrecoverable, None))));
140 } else {
141 let resp = self
142 .client
143 .get(self.url.clone(), Some(self.range.clone()))
144 .send();
145 self.state = ResponseState::Pending(Box::pin(resp));
146 continue;
147 }
148 }
149 ResponseState::Streaming(stream) => match stream.as_mut().poll_next(cx) {
150 Poll::Ready(Some(Ok(chunk))) => {
151 self.range.start += chunk.len() as u64;
152 Poll::Ready(Some(Ok(chunk)))
153 }
154 Poll::Ready(Some(Err(e))) => {
155 self.state = ResponseState::None;
156 Poll::Ready(Some(Err((e, None))))
157 }
158 Poll::Ready(None) => Poll::Ready(None),
159 Poll::Pending => Poll::Pending,
160 },
161 };
162 }
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use futures::TryStreamExt;

    /// Client stub: every GET produces a [`MockRequestBuilder`], whatever the
    /// URL or range.
    #[derive(Clone, Debug)]
    struct MockClient;

    impl HttpClient for MockClient {
        type RequestBuilder = MockRequestBuilder;

        fn get(&self, _url: Url, _range: Option<ProgressEntry>) -> Self::RequestBuilder {
            MockRequestBuilder
        }
    }

    /// Request stub whose `send` always succeeds with a fresh [`MockResponse`].
    struct MockRequestBuilder;

    impl HttpRequestBuilder for MockRequestBuilder {
        type Response = MockResponse;
        type RequestError = MockError;

        async fn send(self) -> Result<Self::Response, (Self::RequestError, Option<Duration>)> {
            Ok(MockResponse::new())
        }
    }

    /// Response stub with a fixed URL and headers that never resolve.
    pub struct MockResponse {
        headers: MockHeaders,
        url: Url,
    }

    impl MockResponse {
        fn new() -> Self {
            let url = Url::parse("http://mock-url").unwrap();
            Self {
                headers: MockHeaders,
                url,
            }
        }
    }

    impl HttpResponse for MockResponse {
        type Headers = MockHeaders;
        type ChunkError = MockError;

        fn headers(&self) -> &Self::Headers {
            &self.headers
        }

        fn url(&self) -> &Url {
            &self.url
        }

        /// Each call awaits a new [`DelayChunk`], so every chunk is pending on
        /// its first poll and ready on the second.
        async fn chunk(&mut self) -> Result<Option<Bytes>, Self::ChunkError> {
            DelayChunk::new().await
        }
    }

    /// Header stub: every lookup fails, so the `FileId` built from a mock
    /// response is built from `None` values.
    pub struct MockHeaders;

    impl HttpHeaders for MockHeaders {
        type GetHeaderError = MockError;

        fn get(&self, _header: &str) -> Result<&str, Self::GetHeaderError> {
            Err(MockError)
        }
    }

    /// Error type shared by all mock traits.
    #[derive(Debug)]
    pub struct MockError;

    /// Future that reports `Pending` exactly once (waking itself) and then
    /// resolves to the chunk `b"success"`.
    struct DelayChunk {
        polled_once: bool,
    }

    impl DelayChunk {
        fn new() -> Self {
            Self { polled_once: false }
        }
    }

    impl Future for DelayChunk {
        type Output = Result<Option<Bytes>, MockError>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if self.polled_once {
                println!("Done! [Mock: 数据到达 Ready]");
                return Poll::Ready(Ok(Some(Bytes::from_static(b"success"))));
            }
            println!("Wait... [Mock: 模拟网络延迟 Pending]");
            self.polled_once = true;
            // Wake immediately so the task is polled again without outside help.
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    /// Pulls the range `0..7` and expects the first chunk within one second,
    /// even though the mock chunk future returns `Pending` on its first poll.
    #[tokio::test]
    async fn test_http_puller_infinite_loop_fix() {
        let mut puller = HttpPuller::new(
            Url::parse("http://localhost").unwrap(),
            MockClient,
            None,
            FileId::new(None, None),
        );
        let requested = 0..7;
        let mut chunks = Puller::pull(&mut puller, Some(&requested))
            .await
            .expect("Failed to create stream");
        println!("--- 开始测试 HttpPuller ---");
        let outcome = tokio::time::timeout(Duration::from_secs(1), chunks.try_next()).await;
        match outcome {
            Ok(Ok(Some(bytes))) => {
                println!("收到数据: {:?}", bytes);
                assert_eq!(bytes, Bytes::from_static(b"success"));
                println!("测试通过:HttpPuller 正确处理了 Pending 状态!");
            }
            // Any other outcome lands here: a timeout, a stream error, or an
            // early end of the stream.
            other => {
                panic!(
                    "测试失败:超时!这表明 HttpPuller 可能在收到 Pending 后丢失了 Future 状态并陷入了死循环。 {e:?}",
                    e = other
                );
            }
        }
    }
}