http-downloader 0.2.3

A http download library that supports multithreading and resumable
Documentation
use std::io::SeekFrom;
use std::ops::Deref;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

use anyhow::Result;
use bytes::Bytes;
use futures_util::future::BoxFuture;
use futures_util::StreamExt;
use headers::HeaderMapExt;
use reqwest::Request;
use tokio::fs::File;
use tokio::io::{AsyncSeekExt, AsyncWriteExt};
use tokio::select;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
#[cfg(feature = "tracing")]
use tracing::Instrument;

use crate::{ChunkInfo, ChunkManager, ChunkRange, DownloadError};

#[derive(Debug)]
pub enum ChunkMessageKind {
    DownloadFinished,
    DownloadCancelled,
    DownloadLenAppend(usize),
    Error(DownloadError),
}

#[derive(Debug)]
pub struct ChunkMessageInfo {
    pub chunk_index: usize,
    pub kind: ChunkMessageKind,
}

pub trait DownloadedLenChangeNotify: Send + Sync {
    fn receive_len(&self, len: usize) -> Option<BoxFuture<()>>;
}

pub struct ChunkItem {
    pub chunk_info: ChunkInfo,
    pub downloaded_len: AtomicU64,
    cancel_token: CancellationToken,
    client: reqwest::Client,
    sender: tokio::sync::mpsc::Sender<ChunkMessageInfo>,
    file: Arc<Mutex<File>>,
    etag: Option<headers::ETag>,
    downloaded_len_receiver: Option<Arc<dyn DownloadedLenChangeNotify>>,
}

impl ChunkItem {
    pub fn new(
        chunk_info: ChunkInfo,
        cancel_token: CancellationToken,
        client: reqwest::Client,
        sender: tokio::sync::mpsc::Sender<ChunkMessageInfo>,
        file: Arc<Mutex<File>>,
        downloaded_len_receiver: Option<Arc<dyn DownloadedLenChangeNotify>>,
        etag: Option<headers::ETag>,
    ) -> Self {
        Self {
            downloaded_len: AtomicU64::new(0),
            cancel_token,
            client,
            chunk_info,
            file,
            downloaded_len_receiver,
            sender,
            etag,
        }
    }
    async fn send_message(
        &self,
        message_kind: ChunkMessageKind,
    ) -> Result<(), SendError<ChunkMessageInfo>> {
        self.sender
            .send(ChunkMessageInfo {
                chunk_index: self.chunk_info.index,
                kind: message_kind,
            })
            .await
    }

    #[inline]
    async fn add_downloaded_len(&self, len: usize) {
        self.downloaded_len.fetch_add(len as u64, Ordering::Relaxed);
        debug_assert!(
            self.downloaded_len.load(Ordering::SeqCst) <= self.chunk_info.range.len(),
            "downloaded_len:{},chunk_info.range.len():{}",
            self.downloaded_len.load(Ordering::SeqCst),
            self.chunk_info.range.len()
        );
        self.send_message(ChunkMessageKind::DownloadLenAppend(len))
            .await
            .unwrap_or_else(|_err| {
                #[cfg(feature = "tracing")]
                tracing::trace!("ChunkMessageInfoSendFailed! {:?}", _err);
            });
    }

    #[cfg_attr(feature = "tracing",tracing::instrument(name="download chunk",skip_all,fields(chunk_index = self.chunk_info.index)))]
    async fn download_chunk(
        self: Arc<Self>,
        mut request: Box<Request>,
        retry_count: u8,
    ) -> Result<bool, DownloadError> {
        let cancel_token = self.cancel_token.clone();
        let mut chunk_bytes = Vec::with_capacity(self.chunk_info.range.len() as usize);

        let mut cur_retry_count = 0;
        let future = async {
            'r: loop {
                request.headers_mut().typed_insert(
                    ChunkRange::new(
                        self.chunk_info.range.start + chunk_bytes.len() as u64,
                        self.chunk_info.range.end,
                    )
                    .to_range_header(),
                );
                // 避免 clone request ?
                let response = self.client.execute(*ChunkManager::clone_request(&request));
                #[cfg(feature = "tracing")]
                let response = response.instrument(tracing::info_span!("chunk's http request"));
                let response = match response.await {
                    Ok(response) => {
                        cur_retry_count = 0;
                        response
                    }
                    Err(err) => {
                        cur_retry_count += 1;
                        #[cfg(feature = "tracing")]
                        tracing::trace!(
                            "Request error! {:?},retry_info: {}/{}",
                            err,
                            cur_retry_count,
                            retry_count
                        );
                        if cur_retry_count > retry_count {
                            return Err(DownloadError::HttpRequestFailed(err));
                        }
                        continue 'r;
                    }
                };
                if self.etag.is_some() {
                    let etag = response.headers().typed_get::<headers::ETag>();
                    if etag != self.etag {
                        #[cfg(feature = "tracing")]
                        tracing::trace!(
                            "etag mismatching,your etag: {:?} , current etag:{:?}",
                            self.etag,
                            etag
                        );
                        return Err(DownloadError::ServerFileAlreadyChanged);
                    }
                }
                let mut stream = response.bytes_stream();
                while let Some(bytes) = stream.next().await {
                    #[cfg(feature = "tracing")]
                    let span = tracing::info_span!("process received bytes", is_ok = bytes.is_ok());
                    #[cfg(feature = "tracing")]
                    let _ = span.enter();
                    let bytes: Bytes = {
                        match bytes {
                            Ok(bytes) => {
                                cur_retry_count = 0;
                                bytes
                            }
                            Err(err) => {
                                cur_retry_count += 1;
                                #[cfg(feature = "tracing")]
                                tracing::trace!(
                                    "Request error! {:?},retry_info: {}/{}",
                                    err,
                                    cur_retry_count,
                                    retry_count
                                );
                                if cur_retry_count > retry_count {
                                    // 出错后与取消一样处理:将缓冲中的数据写入磁盘并持久化数据
                                    let mut file = self.file.lock().await;
                                    file.seek(SeekFrom::Start(self.chunk_info.range.start))
                                        .await?;
                                    debug_assert!(
                                        chunk_bytes.len() as u64 <= self.chunk_info.range.len(),
                                        "chunk_bytes.len() = {}, self.chunk_info.range.len() = {}",
                                        chunk_bytes.len(),
                                        self.chunk_info.range.len()
                                    );
                                    file.write_all(chunk_bytes.as_ref()).await?;
                                    file.flush().await?;
                                    file.sync_all().await?;
                                    return Err(DownloadError::HttpRequestFailed(err));
                                }
                                continue 'r;
                            }
                        }
                    };
                    let len = bytes.len();
                    chunk_bytes.extend(bytes);
                    self.add_downloaded_len(len).await;
                    if let Some(downloaded_len_receiver) = self.downloaded_len_receiver.as_ref() {
                        match downloaded_len_receiver.receive_len(len) {
                            None => {}
                            Some(r) => r.await,
                        };
                    }
                }
                break;
            }
            Result::<(), DownloadError>::Ok(())
        };

        select! {
            r = future => {
                r?;
                let mut file = self.file.lock().await;
                file.seek(SeekFrom::Start(self.chunk_info.range.start)).await?;
                debug_assert_eq!(chunk_bytes.len() as u64,self.chunk_info.range.len());
                file.write_all(chunk_bytes.as_ref()).await?;
                file.flush().await?;
                file.sync_all().await?;
                Ok(true)
            }
            _ = cancel_token.cancelled() => {
                let mut file = self.file.lock().await;
                file.seek(SeekFrom::Start(self.chunk_info.range.start)).await?;
                debug_assert!(chunk_bytes.len() as u64 <= self.chunk_info.range.len(),"chunk_bytes.len() = {}, self.chunk_info.range.len() = {}", chunk_bytes.len(), self.chunk_info.range.len());
                file.write_all(chunk_bytes.as_ref()).await?;
                file.flush().await?;
                file.sync_all().await?;
                Ok(false)
            }
        }
    }
    pub fn start_download(
        self: Arc<Self>,
        request: Box<Request>,
        retry_count: u8,
    ) -> DownloadedChunkItem {
        use futures_util::FutureExt;

        let chunk_item = self.clone();
        let join_handle = tokio::spawn(self.clone().download_chunk(request, retry_count).then(
            |result| async move {
                match result {
                    Ok(is_finished) => {
                        if is_finished {
                            self.send_message(ChunkMessageKind::DownloadFinished)
                                .await
                                .unwrap_or_else(|_err| {
                                    #[cfg(feature = "tracing")]
                                    tracing::trace!("ChunkMessageInfoSendFailed! {:?}", _err);
                                })
                        }
                    }
                    Err(err) => self
                        .send_message(ChunkMessageKind::Error(err))
                        .await
                        .unwrap_or_else(|_err| {
                            #[cfg(feature = "tracing")]
                            tracing::trace!("ChunkMessageInfoSendFailed! {:?}", _err);
                        }),
                };
            },
        ));
        DownloadedChunkItem::new(chunk_item, join_handle)
    }
}

pub struct DownloadedChunkItem {
    pub chunk_item: Arc<ChunkItem>,
    pub join_handle: JoinHandle<()>,
}

impl DownloadedChunkItem {
    pub fn new(chunk_item: Arc<ChunkItem>, join_handle: JoinHandle<()>) -> Self {
        Self {
            chunk_item,
            join_handle,
        }
    }
}

impl Deref for DownloadedChunkItem {
    type Target = Arc<ChunkItem>;

    fn deref(&self) -> &Self::Target {
        &self.chunk_item
    }
}