1use crate::http::{
2 FileId, GetRequestError, GetResponse, HttpClient, HttpError, HttpHeaders, HttpRequestBuilder,
3 HttpResponse,
4};
5use bytes::Bytes;
6use fast_pull::{ProgressEntry, PullResult, PullStream, RandPuller, SeqPuller};
7use futures::Stream;
8use spin::mutex::SpinMutex;
9use std::{
10 fmt::Debug,
11 future::Future,
12 pin::Pin,
13 sync::Arc,
14 task::{Context, Poll},
15 time::Duration,
16};
17use url::Url;
18
19#[derive(Clone)]
20pub struct HttpPuller<Client: HttpClient> {
21 pub(crate) client: Client,
22 url: Url,
23 resp: Option<Arc<SpinMutex<Option<GetResponse<Client>>>>>,
24 file_id: FileId,
25}
26impl<Client: HttpClient> HttpPuller<Client> {
27 pub fn new(
28 url: Url,
29 client: Client,
30 resp: Option<Arc<SpinMutex<Option<GetResponse<Client>>>>>,
31 file_id: FileId,
32 ) -> Self {
33 Self {
34 client,
35 url,
36 resp,
37 file_id,
38 }
39 }
40}
41impl<Client: HttpClient> Debug for HttpPuller<Client> {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 f.debug_struct("HttpPuller")
44 .field("client", &"...")
45 .field("url", &self.url)
46 .field("resp", &"...")
47 .field("file_id", &self.file_id)
48 .finish()
49 }
50}
51
52type ResponseFut<Client> = Pin<
53 Box<
54 dyn Future<
55 Output = Result<GetResponse<Client>, (GetRequestError<Client>, Option<Duration>)>,
56 > + Send,
57 >,
58>;
59
60type ChunkStream<Client> = Pin<Box<dyn Stream<Item = Result<Bytes, HttpError<Client>>> + Send>>;
61
62enum ResponseState<Client: HttpClient> {
63 Pending(ResponseFut<Client>),
64 Streaming(ChunkStream<Client>),
65 None,
66}
67
68fn into_chunk_stream<Client: HttpClient + 'static>(
69 resp: GetResponse<Client>,
70) -> ChunkStream<Client> {
71 Box::pin(futures::stream::try_unfold(resp, |mut r| async move {
72 match r.chunk().await {
73 Ok(Some(chunk)) => Ok(Some((chunk, r))),
74 Ok(None) => Ok(None),
75 Err(e) => Err(HttpError::Chunk(e)),
76 }
77 }))
78}
79
80impl<Client: HttpClient + 'static> RandPuller for HttpPuller<Client> {
81 type Error = HttpError<Client>;
82 async fn pull(
83 &mut self,
84 range: &ProgressEntry,
85 ) -> PullResult<impl PullStream<Self::Error>, Self::Error> {
86 Ok(RandRequestStream {
87 client: self.client.clone(),
88 url: self.url.clone(),
89 start: range.start,
90 end: range.end,
91 state: if range.start == 0
92 && let Some(resp) = &self.resp
93 && let Some(resp) = resp.lock().take()
94 {
95 ResponseState::Streaming(into_chunk_stream(resp))
96 } else {
97 ResponseState::None
98 },
99 file_id: self.file_id.clone(),
100 })
101 }
102}
103struct RandRequestStream<Client: HttpClient + 'static> {
104 client: Client,
105 url: Url,
106 start: u64,
107 end: u64,
108 state: ResponseState<Client>,
109 file_id: FileId,
110}
111impl<Client: HttpClient> Stream for RandRequestStream<Client> {
112 type Item = Result<Bytes, (HttpError<Client>, Option<Duration>)>;
113 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
114 match &mut self.state {
115 ResponseState::Pending(resp) => match resp.as_mut().poll(cx) {
116 Poll::Ready(Ok(resp)) => {
117 let new_file_id = FileId::new(
118 resp.headers().get("etag").ok(),
119 resp.headers().get("last-modified").ok(),
120 );
121 if new_file_id != self.file_id {
122 self.state = ResponseState::None;
123 Poll::Ready(Some(Err((HttpError::MismatchedBody(new_file_id), None))))
124 } else {
125 self.state = ResponseState::Streaming(into_chunk_stream(resp));
126 self.poll_next(cx)
127 }
128 }
129 Poll::Ready(Err((e, d))) => {
130 self.state = ResponseState::None;
131 Poll::Ready(Some(Err((HttpError::Request(e), d))))
132 }
133 Poll::Pending => Poll::Pending,
134 },
135 ResponseState::None => {
136 let resp = self
137 .client
138 .get(self.url.clone(), Some(self.start..self.end))
139 .send();
140 self.state = ResponseState::Pending(Box::pin(resp));
141 self.poll_next(cx)
142 }
143 ResponseState::Streaming(stream) => match stream.as_mut().poll_next(cx) {
144 Poll::Ready(Some(Ok(chunk))) => {
145 self.start += chunk.len() as u64;
146 Poll::Ready(Some(Ok(chunk)))
147 }
148 Poll::Ready(Some(Err(e))) => {
149 self.state = ResponseState::None;
150 Poll::Ready(Some(Err((e, None))))
151 }
152 Poll::Ready(None) => Poll::Ready(None),
153 Poll::Pending => Poll::Pending,
154 },
155 }
156 }
157}
158
159impl<Client: HttpClient + 'static> SeqPuller for HttpPuller<Client> {
160 type Error = HttpError<Client>;
161 async fn pull(&mut self) -> PullResult<impl PullStream<Self::Error>, Self::Error> {
162 Ok(SeqRequestStream {
163 state: if let Some(resp) = &self.resp
164 && let Some(resp) = resp.lock().take()
165 {
166 ResponseState::Streaming(into_chunk_stream(resp))
167 } else {
168 let req = self.client.get(self.url.clone(), None).send();
169 ResponseState::Pending(Box::pin(req))
170 },
171 file_id: self.file_id.clone(),
172 })
173 }
174}
175struct SeqRequestStream<Client: HttpClient + 'static> {
176 state: ResponseState<Client>,
177 file_id: FileId,
178}
179impl<Client: HttpClient> Stream for SeqRequestStream<Client> {
180 type Item = Result<Bytes, (HttpError<Client>, Option<Duration>)>;
181 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
182 match &mut self.state {
183 ResponseState::Pending(resp) => match resp.as_mut().poll(cx) {
184 Poll::Ready(Ok(resp)) => {
185 let new_file_id = FileId::new(
186 resp.headers().get("etag").ok(),
187 resp.headers().get("last-modified").ok(),
188 );
189 if new_file_id != self.file_id {
190 self.state = ResponseState::None;
191 Poll::Ready(Some(Err((HttpError::MismatchedBody(new_file_id), None))))
192 } else {
193 self.state = ResponseState::Streaming(into_chunk_stream(resp));
194 self.poll_next(cx)
195 }
196 }
197 Poll::Ready(Err((e, d))) => {
198 self.state = ResponseState::None;
199 Poll::Ready(Some(Err((HttpError::Request(e), d))))
200 }
201 Poll::Pending => Poll::Pending,
202 },
203 ResponseState::None => Poll::Ready(Some(Err((HttpError::Irrecoverable, None)))),
204 ResponseState::Streaming(stream) => match stream.as_mut().poll_next(cx) {
205 Poll::Ready(Some(Ok(chunk))) => Poll::Ready(Some(Ok(chunk))),
206 Poll::Ready(Some(Err(e))) => {
207 self.state = ResponseState::None;
208 Poll::Ready(Some(Err((e, None))))
209 }
210 Poll::Ready(None) => Poll::Ready(None),
211 Poll::Pending => Poll::Pending,
212 },
213 }
214 }
215}
216
#[cfg(test)]
mod tests {
    use crate::http::{
        FileId, HttpClient, HttpHeaders, HttpPuller, HttpRequestBuilder, HttpResponse,
    };
    use bytes::Bytes;
    use fast_pull::{ProgressEntry, RandPuller};
    use futures::{Future, TryStreamExt, task::Context};
    use spin::mutex::SpinMutex;
    use std::{pin::Pin, sync::Arc, task::Poll, time::Duration};
    use url::Url;

    /// Client whose `get` ignores the URL and range and always returns a `MockRequestBuilder`.
    #[derive(Clone, Debug)]
    struct MockClient;
    impl HttpClient for MockClient {
        type RequestBuilder = MockRequestBuilder;
        fn get(&self, _url: Url, _range: Option<ProgressEntry>) -> Self::RequestBuilder {
            MockRequestBuilder
        }
    }
    /// Request builder whose `send` always succeeds with a fresh `MockResponse`.
    struct MockRequestBuilder;
    impl HttpRequestBuilder for MockRequestBuilder {
        type Response = MockResponse;
        type RequestError = MockError;
        async fn send(self) -> Result<Self::Response, (Self::RequestError, Option<Duration>)> {
            Ok(MockResponse::new())
        }
    }
    /// Response with no readable headers and an endless body of `b"success"` chunks.
    pub struct MockResponse {
        headers: MockHeaders,
        url: Url,
    }
    impl MockResponse {
        fn new() -> Self {
            Self {
                headers: MockHeaders,
                url: Url::parse("http://mock-url").unwrap(),
            }
        }
    }
    impl HttpResponse for MockResponse {
        type Headers = MockHeaders;
        type ChunkError = MockError;
        fn headers(&self) -> &Self::Headers {
            &self.headers
        }
        fn url(&self) -> &Url {
            &self.url
        }
        // Each chunk is `Pending` on its first poll, then `Ready`; see `DelayChunk`.
        async fn chunk(&mut self) -> Result<Option<Bytes>, Self::ChunkError> {
            DelayChunk::new().await
        }
    }
    /// Header map in which every lookup fails. The puller therefore computes
    /// `FileId::new(None, None)`, which matches the `file_id` the test expects.
    pub struct MockHeaders;
    impl HttpHeaders for MockHeaders {
        type GetHeaderError = MockError;
        fn get(&self, _header: &str) -> Result<&str, Self::GetHeaderError> {
            Err(MockError)
        }
    }
    /// Error type shared by all of the mocks.
    #[derive(Debug)]
    pub struct MockError;

    /// Future that emulates network latency. The first poll wakes itself
    /// immediately and returns `Pending`; the next poll returns
    /// `Ready(Ok(Some(b"success")))`.
    struct DelayChunk {
        polled_once: bool,
    }
    impl DelayChunk {
        fn new() -> Self {
            Self { polled_once: false }
        }
    }
    impl Future for DelayChunk {
        type Output = Result<Option<Bytes>, MockError>;
        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if !self.polled_once {
                // Message means: "Wait... [Mock: simulated network delay, Pending]".
                println!("Wait... [Mock: 模拟网络延迟 Pending]");
                self.polled_once = true;
                // Wake immediately so the executor polls us again.
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            // Message means: "Done! [Mock: data arrived, Ready]".
            println!("Done! [Mock: 数据到达 Ready]");
            Poll::Ready(Ok(Some(Bytes::from_static(b"success"))))
        }
    }

    /// Regression test: the ranged stream must keep its in-flight future
    /// across a `Pending` poll and eventually yield the first chunk. The slot
    /// passed in is empty, so `pull` cannot reuse a cached response and must
    /// go through a real (mock) request. A 1 s timeout turns a hang into a failure.
    #[tokio::test]
    async fn test_http_puller_infinite_loop_fix() {
        let url = Url::parse("http://localhost").unwrap();
        let client = MockClient;
        let file_id = FileId::new(None, None);
        let mut puller =
            HttpPuller::new(url, client, Some(Arc::new(SpinMutex::new(None))), file_id);
        let range = 0..7;
        let mut stream = puller.pull(&range).await.expect("Failed to create stream");
        // Message means: "--- Starting HttpPuller test ---".
        println!("--- 开始测试 HttpPuller ---");
        let result =
            tokio::time::timeout(Duration::from_secs(1), async { stream.try_next().await }).await;
        match result {
            Ok(Ok(Some(bytes))) => {
                // Message means: "Received data: ...".
                println!("收到数据: {:?}", bytes);
                assert_eq!(bytes, Bytes::from_static(b"success"));
                // Message means: "Test passed: HttpPuller handled the Pending state correctly!".
                println!("测试通过:HttpPuller 正确处理了 Pending 状态!");
            }
            e => {
                // Message means: "Test failed: timeout! This suggests HttpPuller lost
                // its Future state after Pending and fell into an infinite loop."
                panic!(
                    "测试失败:超时!这表明 HttpPuller 可能在收到 Pending 后丢失了 Future 状态并陷入了死循环。 {e:?}"
                );
            }
        }
    }
}