fast-down 4.0.6

Download everything fast
Documentation
use crate::http::{
    FileId, GetRequestError, GetResponse, HttpClient, HttpError, HttpHeaders, HttpRequestBuilder,
    HttpResponse,
};
use bytes::Bytes;
use fast_pull::{ProgressEntry, PullResult, PullStream, Puller};
use futures::Stream;
use parking_lot::Mutex;
use std::{
    fmt::Debug,
    future::Future,
    ops::Range,
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
    time::Duration,
};
use url::Url;

/// HTTP implementation of [`Puller`]: hands out byte streams for ranges of the
/// resource at `url`, fetched through `client`.
pub struct HttpPuller<Client>
where
    Client: HttpClient,
{
    /// Client used to issue GET requests.
    client: Client,
    /// Location of the resource being downloaded.
    url: Url,
    /// Optional already-received response shared between clones; `pull` takes it
    /// (at most once) for a range starting at offset 0 instead of issuing a request.
    resp: Option<Arc<Mutex<Option<GetResponse<Client>>>>>,
    /// Identity every freshly requested response is compared against.
    file_id: FileId,
}
impl<C: HttpClient> Clone for HttpPuller<C> {
    /// Field-wise clone; the shared prefetched-response slot is an `Arc`, so
    /// clones refer to the same slot.
    fn clone(&self) -> Self {
        // Destructure so adding a field without updating this impl is a compile error.
        let Self {
            client,
            url,
            resp,
            file_id,
        } = self;
        Self {
            client: client.clone(),
            url: url.clone(),
            resp: resp.clone(),
            file_id: file_id.clone(),
        }
    }
}
impl<Client: HttpClient> HttpPuller<Client> {
    /// Builds a puller for `url`.
    ///
    /// * `client` – issues the GET requests.
    /// * `resp` – optional shared slot holding an already-received response that
    ///   a pull starting at offset 0 may consume instead of sending a request.
    /// * `file_id` – identity that later responses must match.
    pub const fn new(
        url: Url,
        client: Client,
        resp: Option<Arc<Mutex<Option<GetResponse<Client>>>>>,
        file_id: FileId,
    ) -> Self {
        Self { url, client, file_id, resp }
    }
}
impl<Client: HttpClient> Debug for HttpPuller<Client> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `client` and `resp` are printed as opaque "..." placeholders rather
        // than their contents.
        let mut out = f.debug_struct("HttpPuller");
        out.field("url", &self.url);
        out.field("file_id", &self.file_id);
        out.field("client", &"...");
        out.field("resp", &"...");
        out.finish()
    }
}

/// Result of sending a GET: the response, or the request error paired with the
/// optional `Duration` the request builder returns alongside it (presumably a
/// retry delay hint — not established in this file).
type PendingResponseOutput<Client> =
    Result<GetResponse<Client>, (GetRequestError<Client>, Option<Duration>)>;

/// Boxed, `Send` future of an in-flight GET request.
type ResponseFut<Client> = Pin<Box<dyn Future<Output = PendingResponseOutput<Client>> + Send>>;

/// Boxed, `Send` stream of body chunks read from a response.
type ChunkStream<Client> = Pin<Box<dyn Stream<Item = Result<Bytes, HttpError<Client>>> + Send>>;

/// Where a [`RandRequestStream`] currently is in its request/read cycle.
enum ResponseState<Client>
where
    Client: HttpClient,
{
    /// A request has been sent and its response is awaited.
    Pending(ResponseFut<Client>),
    /// A response was accepted and its body is being read.
    Streaming(ChunkStream<Client>),
    /// No request in flight; the next poll decides whether to issue one.
    None,
}

/// Turns a response into a stream of its body chunks.
///
/// The stream yields each chunk returned by `chunk()`, ends when `chunk()`
/// returns `Ok(None)`, and wraps a read failure in [`HttpError::Chunk`].
fn into_chunk_stream<Client: HttpClient>(resp: GetResponse<Client>) -> ChunkStream<Client> {
    Box::pin(futures::stream::try_unfold(resp, |mut body| async move {
        match body.chunk().await {
            // Hand the response back as the unfold state so the next step can
            // keep reading; `None` drops it and finishes the stream.
            Ok(next) => Ok(next.map(|chunk| (chunk, body))),
            Err(e) => Err(HttpError::Chunk(e)),
        }
    }))
}

impl<Client: HttpClient> Puller for HttpPuller<Client> {
    type Error = HttpError<Client>;

    /// Creates a stream for `range` (or `0..u64::MAX` when no range is given).
    ///
    /// The initial state is chosen as follows:
    /// * For a range starting at 0, the shared prefetched response, if still
    ///   present, is consumed and streamed directly.
    /// * Otherwise, for an unbounded range (`end == u64::MAX`), a plain GET
    ///   without a range is sent immediately.
    /// * Otherwise, no request is made yet; the stream issues a ranged GET on
    ///   its first poll.
    async fn pull(
        &mut self,
        range: Option<&ProgressEntry>,
    ) -> PullResult<impl PullStream<Self::Error>, Self::Error> {
        let range = match range {
            Some(r) => r.clone(),
            None => 0..u64::MAX,
        };
        // The slot is emptied by `take()`, so only one pull ever reuses the
        // prefetched response.
        let prefetched = if range.start == 0 {
            self.resp.as_ref().and_then(|slot| slot.lock().take())
        } else {
            None
        };
        let state = match prefetched {
            Some(resp) => ResponseState::Streaming(into_chunk_stream(resp)),
            None if range.end == u64::MAX => {
                let req = self.client.get(self.url.clone(), None).send();
                ResponseState::Pending(Box::pin(req))
            }
            None => ResponseState::None,
        };
        Ok(RandRequestStream {
            client: self.client.clone(),
            url: self.url.clone(),
            state,
            range,
            file_id: self.file_id.clone(),
        })
    }
}
/// Stream of body bytes for one range of the resource, able to re-issue a
/// ranged request after a failure.
struct RandRequestStream<Client>
where
    Client: HttpClient,
{
    /// Client used for (re-)requests.
    client: Client,
    /// Resource location.
    url: Url,
    /// Bytes still wanted; `start` advances as chunks are yielded, and
    /// `end == u64::MAX` marks an unbounded request.
    range: Range<u64>,
    /// Current request/read state.
    state: ResponseState<Client>,
    /// Identity that each new response must match.
    file_id: FileId,
}
impl<Client: HttpClient> Stream for RandRequestStream<Client> {
    type Item = Result<Bytes, (HttpError<Client>, Option<Duration>)>;

    /// Drives the request/read state machine.
    ///
    /// * `Pending` – polls the in-flight request. An accepted response whose
    ///   ETag/Last-Modified identity matches `file_id` switches to `Streaming`.
    ///   A mismatch yields [`HttpError::MismatchedBody`], and a request failure
    ///   yields [`HttpError::Request`] with its delay value. In both failure
    ///   cases the state resets to `None`.
    /// * `Streaming` – yields body chunks and advances `range.start` by each
    ///   chunk's length. A read error resets to `None`, so the next poll resumes
    ///   from the current offset.
    /// * `None` – an unbounded range cannot be resumed and yields
    ///   [`HttpError::Irrecoverable`]. An exhausted range ends the stream.
    ///   Otherwise a ranged GET for the remaining `range` is sent.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            break match &mut self.state {
                ResponseState::Pending(resp) => match resp.as_mut().poll(cx) {
                    Poll::Ready(Ok(resp)) => {
                        let new_file_id = FileId::new(
                            resp.headers().get("etag").ok().as_deref(),
                            resp.headers().get("last-modified").ok().as_deref(),
                        );
                        if new_file_id == self.file_id {
                            self.state = ResponseState::Streaming(into_chunk_stream(resp));
                            continue;
                        }
                        // The resource changed (or identifies differently); do not
                        // splice its bytes into the existing download.
                        self.state = ResponseState::None;
                        Poll::Ready(Some(Err((HttpError::MismatchedBody(new_file_id), None))))
                    }
                    Poll::Ready(Err((e, d))) => {
                        self.state = ResponseState::None;
                        Poll::Ready(Some(Err((HttpError::Request(e), d))))
                    }
                    Poll::Pending => Poll::Pending,
                },
                ResponseState::None => {
                    if self.range.end == u64::MAX {
                        // Unbounded requests carry no offset to resume from.
                        break Poll::Ready(Some(Err((HttpError::Irrecoverable, None))));
                    }
                    if self.range.start >= self.range.end {
                        // Nothing left to fetch. Requesting an empty or inverted
                        // range would produce an invalid `Range` header, which a
                        // server may ignore and answer with the whole body.
                        break Poll::Ready(None);
                    }
                    let resp = self
                        .client
                        .get(self.url.clone(), Some(self.range.clone()))
                        .send();
                    self.state = ResponseState::Pending(Box::pin(resp));
                    continue;
                }
                ResponseState::Streaming(stream) => match stream.as_mut().poll_next(cx) {
                    Poll::Ready(Some(Ok(chunk))) => {
                        // Track progress so a retry only asks for the remainder.
                        self.range.start += chunk.len() as u64;
                        Poll::Ready(Some(Ok(chunk)))
                    }
                    Poll::Ready(Some(Err(e))) => {
                        self.state = ResponseState::None;
                        Poll::Ready(Some(Err((e, None))))
                    }
                    Poll::Ready(None) => Poll::Ready(None),
                    Poll::Pending => Poll::Pending,
                },
            };
        }
    }
}

#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
    // Regression test for `HttpPuller`. The mock response body returns
    // `Poll::Pending` once before yielding data. The test checks that the
    // stream keeps its in-flight state across a `Pending` poll instead of
    // restarting or spinning.
    use std::borrow::Cow;

    use super::*;
    use futures::TryStreamExt;

    // Client whose `get` ignores URL and range and always returns `MockRequestBuilder`.
    #[derive(Clone, Debug)]
    struct MockClient;
    impl HttpClient for MockClient {
        type RequestBuilder = MockRequestBuilder;
        fn get(&self, _url: Url, _range: Option<ProgressEntry>) -> Self::RequestBuilder {
            MockRequestBuilder
        }
    }
    // Request builder whose `send` always succeeds with a fresh `MockResponse`.
    struct MockRequestBuilder;
    impl HttpRequestBuilder for MockRequestBuilder {
        type Response = MockResponse;
        type RequestError = MockError;
        async fn send(self) -> Result<Self::Response, (Self::RequestError, Option<Duration>)> {
            Ok(MockResponse::new())
        }
    }
    // Response with no readable headers; every `chunk()` call goes through `DelayChunk`.
    struct MockResponse {
        headers: MockHeaders,
        url: Url,
    }
    impl MockResponse {
        fn new() -> Self {
            Self {
                headers: MockHeaders,
                url: Url::parse("http://mock-url").unwrap(),
            }
        }
    }
    impl HttpResponse for MockResponse {
        type Headers = MockHeaders;
        type ChunkError = MockError;
        fn headers(&self) -> &Self::Headers {
            &self.headers
        }
        fn url(&self) -> &Url {
            &self.url
        }
        async fn chunk(&mut self) -> Result<Option<Bytes>, Self::ChunkError> {
            DelayChunk::new().await
        }
    }
    // Every header lookup fails. Because `.ok()` is applied to the lookup, the
    // puller derives `FileId::new(None, None)`, which matches the id used below.
    struct MockHeaders;
    impl HttpHeaders for MockHeaders {
        type GetHeaderError = MockError;
        fn get(&self, _header: &str) -> Result<Cow<'_, str>, Self::GetHeaderError> {
            Err(MockError)
        }
    }
    #[derive(Debug)]
    struct MockError;

    // Future that returns `Pending` on its first poll (waking itself right away)
    // and then resolves to the 7-byte chunk `b"success"`.
    struct DelayChunk {
        polled_once: bool,
    }
    impl DelayChunk {
        fn new() -> Self {
            Self { polled_once: false }
        }
    }
    impl Future for DelayChunk {
        type Output = Result<Option<Bytes>, MockError>;
        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if !self.polled_once {
                // Message: "Wait... [Mock: simulated network delay, Pending]".
                println!("Wait... [Mock: 模拟网络延迟 Pending]");
                self.polled_once = true;
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            // Message: "Done! [Mock: data arrived, Ready]".
            println!("Done! [Mock: 数据到达 Ready]");
            Poll::Ready(Ok(Some(Bytes::from_static(b"success"))))
        }
    }

    // Pulls the bounded range 0..7 with no prefetched response, so the first
    // poll issues a ranged request. The response then streams one delayed chunk.
    // The first item must be `b"success"` within one second.
    #[tokio::test]
    async fn test_http_puller_infinite_loop_fix() {
        let url = Url::parse("http://localhost").unwrap();
        let client = MockClient;
        let file_id = FileId::new(None, None);
        let mut puller = HttpPuller::new(url, client, None, file_id);
        let range = 0..7;
        let mut stream = Puller::pull(&mut puller, Some(&range))
            .await
            .expect("Failed to create stream");
        // Message: "--- Start testing HttpPuller ---".
        println!("--- 开始测试 HttpPuller ---");
        let result =
            tokio::time::timeout(Duration::from_secs(1), async { stream.try_next().await }).await;
        match result {
            Ok(Ok(Some(bytes))) => {
                // Messages: "Received data: ..." and
                // "Test passed: HttpPuller correctly handled the Pending state!".
                println!("收到数据: {bytes:?}");
                assert_eq!(bytes, Bytes::from_static(b"success"));
                println!("测试通过:HttpPuller 正确处理了 Pending 状态!");
            }
            // Any other outcome fails: a timeout, an error, or an early end of
            // stream. The message below assumes the timeout case: "Test failed:
            // timeout! HttpPuller may have lost the Future state after Pending
            // and entered an infinite loop."
            e => {
                panic!(
                    "测试失败:超时!这表明 HttpPuller 可能在收到 Pending 后丢失了 Future 状态并陷入了死循环。 {e:?}"
                );
            }
        }
    }
}