http-downloader 0.2.3

An HTTP download library that supports multithreaded, resumable downloads
Documentation
use std::collections::HashMap;
use std::num::{NonZeroU8, NonZeroUsize};
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, Ordering};

use reqwest::Request;
use tokio::{select, sync};
use tokio::fs::File;
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;

use crate::{
    chunk_item::{ChunkItem, ChunkMessageInfo, ChunkMessageKind, DownloadedChunkItem},
    ChunkIterator, ChunkRange, DownloadError,
};
use crate::{DownloadedLenChangeNotify, DownloadingEndCause};

// Snapshot of download progress, as assembled by `ChunkManager::get_chunks_info`.
// Plain `//` comments are used deliberately: with the `async-graphql` feature,
// `///` doc comments would turn into GraphQL schema descriptions.
// NOTE(review): no reader of these private fields is visible here other than the
// async-graphql derive/resolver, which is presumably why `dead_code` is allowed.
#[allow(dead_code)]
#[cfg_attr(
feature = "async-graphql",
derive(async_graphql::SimpleObject),
graphql(complex)
)]
pub struct ChunksInfo {
    // Byte ranges (inclusive ends) inferred to be fully downloaded.
    finished_chunks: Vec<ChunkRange>,
    // Chunks currently downloading. Skipped by the derive; exposed through the
    // `downloading_chunks` resolver of the complex object instead.
    #[cfg_attr(feature = "async-graphql", graphql(skip))]
    downloading_chunks: Vec<Arc<ChunkItem>>,
    // `true` once the chunk iterator has no more chunks to hand out.
    no_chunk_remaining: bool,
}

/// GraphQL-facing wrapper around a shared [`ChunkItem`] that is currently
/// downloading (only compiled with the `async-graphql` feature).
#[cfg(feature = "async-graphql")]
pub struct DownloadChunkObject(pub Arc<ChunkItem>);

#[cfg(feature = "async-graphql")]
impl From<Arc<ChunkItem>> for DownloadChunkObject {
    /// Wraps the shared chunk handle; only the `Arc` is moved, the chunk is not copied.
    fn from(chunk: Arc<ChunkItem>) -> Self {
        Self(chunk)
    }
}

#[cfg(feature = "async-graphql")]
#[async_graphql::Object]
impl DownloadChunkObject {
    // Every resolver reads straight from the wrapped `ChunkItem`. Plain `//`
    // comments are used so they do not become GraphQL field descriptions.

    // Index identifying the chunk.
    pub async fn index(&self) -> usize {
        let Self(chunk) = self;
        chunk.chunk_info.index
    }

    // First byte offset of the chunk's range.
    pub async fn start(&self) -> u64 {
        let Self(chunk) = self;
        chunk.chunk_info.range.start
    }

    // Last byte offset of the chunk's range (treated as inclusive elsewhere in this file).
    pub async fn end(&self) -> u64 {
        let Self(chunk) = self;
        chunk.chunk_info.range.end
    }

    // Length of the chunk's range as reported by `ChunkRange::len`.
    pub async fn len(&self) -> u64 {
        let Self(chunk) = self;
        chunk.chunk_info.range.len()
    }

    // Bytes of this chunk downloaded so far (relaxed atomic read).
    pub async fn downloaded_len(&self) -> u64 {
        let Self(chunk) = self;
        chunk.downloaded_len.load(Ordering::Relaxed)
    }
}

#[cfg_attr(feature = "async-graphql", async_graphql::ComplexObject)]
impl ChunksInfo {
    // Resolver for the `downloading_chunks` field the derive skips: wraps each
    // shared chunk handle in a `DownloadChunkObject`, preserving order.
    #[cfg(feature = "async-graphql")]
    pub async fn downloading_chunks(&self) -> Vec<DownloadChunkObject> {
        let mut objects = Vec::with_capacity(self.downloading_chunks.len());
        for chunk in &self.downloading_chunks {
            objects.push(DownloadChunkObject::from(Arc::clone(chunk)));
        }
        objects
    }
}

/// Drives a multi-connection download: takes chunks from a [`ChunkIterator`],
/// starts a [`ChunkItem`] for each, tracks the ones in flight and aggregates
/// their progress.
pub struct ChunkManager {
    /// Running total of downloaded bytes; incremented for every
    /// `DownloadLenAppend` message a chunk sends.
    downloaded_len_sender: Arc<sync::watch::Sender<u64>>,
    /// Source of the chunks that have not been handed out yet.
    pub chunk_iterator: ChunkIterator,
    /// Chunks currently downloading, keyed by their chunk index.
    downloading_chunks: Mutex<HashMap<usize, DownloadedChunkItem>>,
    /// Signalled by `save_spec_data` to request that the download progress be
    /// archived (the archiver itself is outside this view).
    #[cfg(feature = "breakpoint-resume")]
    pub data_archive_notify: sync::Notify,
    /// Awaited by `save_spec_data` after signalling `data_archive_notify`; the
    /// archiver is presumably expected to notify it when done — confirm.
    #[cfg(feature = "breakpoint-resume")]
    pub archive_complete_notify: sync::Notify,
    /// Publishes the requested number of concurrent connections.
    download_connection_count_sender: sync::watch::Sender<NonZeroU8>,
    /// Receiving side of `download_connection_count_sender`; a clone of it is
    /// watched while a download runs.
    pub download_connection_count_receiver: sync::watch::Receiver<NonZeroU8>,
    /// HTTP client cloned into every chunk download.
    client: reqwest::Client,
    /// Parent token: each chunk receives a child token, and the token is
    /// cancelled when a chunk reports an error.
    cancel_token: CancellationToken,
    /// Number of running connections to retire (finish without starting a new
    /// chunk) after the connection count was lowered.
    pub superfluities_connection_count: AtomicU8,
    /// Entity tag handed to every chunk (presumably to detect a changed resource).
    pub etag: Option<headers::ETag>,
    /// Retry count passed to each chunk's `start_download`.
    pub retry_count: u8,
}

impl ChunkManager {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        download_connection_count: NonZeroU8,
        client: reqwest::Client,
        cancel_token: CancellationToken,
        downloaded_len_sender: Arc<sync::watch::Sender<u64>>,
        chunk_iterator: ChunkIterator,
        etag: Option<headers::ETag>,
        retry_count: u8,
    ) -> Self {
        let (download_connection_count_sender, download_connection_count_receiver) =
            sync::watch::channel(download_connection_count);

        #[cfg(feature = "breakpoint-resume")]
            let (data_archive_notify, archive_complete_notify) =
            (sync::Notify::new(), sync::Notify::new());

        Self {
            #[cfg(feature = "breakpoint-resume")]
            data_archive_notify,
            #[cfg(feature = "breakpoint-resume")]
            archive_complete_notify,
            downloaded_len_sender,
            chunk_iterator,
            downloading_chunks: Mutex::new(HashMap::new()),
            download_connection_count_sender,
            download_connection_count_receiver,
            client,
            cancel_token,
            superfluities_connection_count: AtomicU8::new(0),
            etag,
            retry_count,
        }
    }

    pub fn change_connection_count(
        &self,
        connection_count: NonZeroU8,
    ) -> Result<(), sync::watch::error::SendError<NonZeroU8>> {
        self.download_connection_count_sender.send(connection_count)
    }

    pub fn change_chunk_size(&self, chunk_size: NonZeroUsize) {
        let mut guard = self.chunk_iterator.data.write();
        guard.remaining.chunk_size = chunk_size.get();
    }

    pub fn downloaded_len(&self) -> u64 {
        *self.downloaded_len_sender.borrow()
    }

    pub fn connection_count(&self) -> u8 {
        self.download_connection_count_sender.borrow().get()
    }

    pub fn clone_request(request: &Request) -> Box<Request> {
        let mut req = Request::new(request.method().clone(), request.url().clone());
        *req.headers_mut() = request.headers().clone();
        *req.version_mut() = request.version();
        *req.timeout_mut() = request.timeout().map(Clone::clone);
        Box::new(req)
    }

    pub async fn start_download(
        &self,
        file: Arc<Mutex<File>>,
        request: Box<Request>,
        downloaded_len_receiver: Option<Arc<dyn DownloadedLenChangeNotify>>,
        _breakpoint_resume: bool,
    ) -> Result<DownloadingEndCause, DownloadError> {
        let (chunk_message_sender, mut chunk_message_receiver) =
            sync::mpsc::channel::<ChunkMessageInfo>(64);

        let mut is_iter_all_chunk = !self
            .download_next_chunk(
                file.clone(),
                chunk_message_sender.clone(),
                downloaded_len_receiver.clone(),
                Self::clone_request(&request),
            )
            .await;
        if is_iter_all_chunk {
            #[cfg(feature = "tracing")]
            tracing::trace!("No Chunk!");
            return Ok(DownloadingEndCause::DownloadFinished);
        }

        for _ in 0..(self.connection_count() - 1) {
            is_iter_all_chunk = !self
                .download_next_chunk(
                    file.clone(),
                    chunk_message_sender.clone(),
                    downloaded_len_receiver.clone(),
                    Self::clone_request(&request),
                )
                .await;
            if is_iter_all_chunk {
                break;
            }
        }

        let connection_count_handle_future = {
            let mut download_connection_count_receiver =
                self.download_connection_count_receiver.clone();
            let chunk_message_sender = chunk_message_sender.clone();
            let file = file.clone();
            let downloaded_len_receiver = downloaded_len_receiver.clone();
            let request = Self::clone_request(&request);
            async move {
                while download_connection_count_receiver.changed().await.is_ok() {
                    let download_connection_count =
                        download_connection_count_receiver.borrow().get();
                    let current_count = self.get_chunks().await.len();
                    let diff = download_connection_count as i16 - current_count as i16;
                    if diff >= 0 {
                        self.superfluities_connection_count
                            .store(0, Ordering::SeqCst);
                        for _ in 0..diff {
                            if !self
                                .download_next_chunk(
                                    file.clone(),
                                    chunk_message_sender.clone(),
                                    downloaded_len_receiver.clone(),
                                    Self::clone_request(&request),
                                )
                                .await
                            {
                                break;
                            }
                        }
                    } else {
                        self.superfluities_connection_count
                            .store(diff.unsigned_abs() as u8, Ordering::SeqCst);
                    }
                }
                DownloadingEndCause::Cancelled
            }
        };

        let message_handle_future = async move {
            let chunk_message_sender = chunk_message_sender;
            while let Some(message_info) = chunk_message_receiver.recv().await {
                match message_info.kind {
                    ChunkMessageKind::DownloadFinished => {
                        #[cfg(feature = "tracing")]
                            let span = tracing::info_span!(
                            "Start Handle DownloadFinished",
                            chunk_inde = message_info.chunk_index
                        );
                        #[cfg(feature = "tracing")]
                            let _ = span.enter();
                        let (downloading_chunk_count, chunk_item) =
                            self.remove_chunk(message_info.chunk_index).await;
                        let _ = chunk_item
                            .ok_or(DownloadError::ChunkRemoveFailed(message_info.chunk_index))?;
                        if _breakpoint_resume {
                            self.save_spec_data().await;
                        }
                        if is_iter_all_chunk {
                            if downloading_chunk_count == 0 {
                                debug_assert_eq!(
                                    self.chunk_iterator.content_length,
                                    *self.downloaded_len_sender.borrow()
                                );
                                break;
                            }
                        } else if self.superfluities_connection_count.load(Ordering::SeqCst) == 0 {
                            is_iter_all_chunk = !self
                                .download_next_chunk(
                                    file.clone(),
                                    chunk_message_sender.clone(),
                                    downloaded_len_receiver.clone(),
                                    Self::clone_request(&request),
                                )
                                .await;
                        } else {
                            self.superfluities_connection_count
                                .fetch_sub(1, Ordering::SeqCst);
                        }
                    }
                    ChunkMessageKind::Error(err) => {
                        if _breakpoint_resume {
                            self.save_spec_data().await;
                        }
                        self.cancel_token.cancel();
                        return Err(err);
                    }
                    ChunkMessageKind::DownloadCancelled => {
                        if _breakpoint_resume {
                            self.save_spec_data().await;
                        }
                    }
                    ChunkMessageKind::DownloadLenAppend(append_len) => {
                        self.downloaded_len_sender
                            .send_modify(|n| *n += append_len as u64);
                    }
                }
            }
            Result::<DownloadingEndCause, DownloadError>::Ok(DownloadingEndCause::DownloadFinished)
        };

        let cancellation_token = self.cancel_token.clone();
        let r = select! {
            r = connection_count_handle_future => {Ok(r)}
            r = message_handle_future => {r}
            _ = cancellation_token.cancelled() => {
                Ok(DownloadingEndCause::Cancelled)
            }
        };
        r
    }
    async fn insert_chunk(&self, item: DownloadedChunkItem) {
        let mut downloading_chunks = self.downloading_chunks.lock().await;
        downloading_chunks.insert(item.chunk_info.index, item);
    }

    pub async fn get_chunks(&self) -> Vec<Arc<ChunkItem>> {
        let mut downloading_chunks: Vec<_> = self
            .downloading_chunks
            .lock()
            .await
            .values()
            .map(|n| n.chunk_item.clone())
            .collect();
        downloading_chunks.sort_by(|a, b| a.chunk_info.range.start.cmp(&b.chunk_info.range.start));
        downloading_chunks
    }

    pub async fn get_chunks_info(&self) -> ChunksInfo {
        let downloading_chunks = self.get_chunks().await;
        let mut finished_chunks = vec![];

        let no_chunk_remaining = self.chunk_iterator.data.read().no_chunk_remaining();
        if !downloading_chunks.is_empty() {
            let first_start = downloading_chunks[0].chunk_info.range.start;
            if first_start != 0 {
                finished_chunks.push(ChunkRange::new(0, first_start - 1));
            }
            for (index, _) in downloading_chunks.iter().enumerate() {
                if index == downloading_chunks.len() - 1 {
                    break;
                }

                let start = downloading_chunks[index].chunk_info.range.end;
                let end = downloading_chunks[index + 1].chunk_info.range.start;
                if (end - start) != 1 {
                    finished_chunks.push(ChunkRange::new(start + 1, end - 1));
                }
            }
            if no_chunk_remaining {
                let last = downloading_chunks.last().unwrap();
                if last.chunk_info.range.end != self.chunk_iterator.content_length - 1 {
                    finished_chunks.push(ChunkRange::new(
                        last.chunk_info.range.end + 1,
                        self.chunk_iterator.content_length - 1,
                    ))
                }
            }
        }
        ChunksInfo {
            downloading_chunks,
            finished_chunks,
            no_chunk_remaining,
        }
    }

    async fn save_spec_data(&self) {
        #[cfg(feature = "breakpoint-resume")]
        {
            #[cfg(feature = "tracing")]
                let span = tracing::info_span!("Archive Data");
            #[cfg(feature = "tracing")]
                let _ = span.enter();
            let notified = self.archive_complete_notify.notified();
            self.data_archive_notify.notify_one();
            notified.await;
        }
    }

    async fn remove_chunk(&self, index: usize) -> (usize, Option<DownloadedChunkItem>) {
        let mut downloading_chunks = self.downloading_chunks.lock().await;
        let removed = downloading_chunks.remove(&index);
        (downloading_chunks.len(), removed)
    }

    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
    async fn download_next_chunk(
        &self,
        file: Arc<Mutex<File>>,
        sender: sync::mpsc::Sender<ChunkMessageInfo>,
        downloaded_len_receiver: Option<Arc<dyn DownloadedLenChangeNotify>>,
        request: Box<Request>,
    ) -> bool {
        if let Some(chunk_info) = self.chunk_iterator.next() {
            let chunk_item = Arc::new(ChunkItem::new(
                chunk_info,
                self.cancel_token.child_token(),
                self.client.clone(),
                sender,
                file,
                downloaded_len_receiver,
                self.etag.clone(),
            ));
            let item = chunk_item.start_download(request, self.retry_count);
            self.insert_chunk(item).await;
            true
        } else {
            false
        }
    }
}