http_downloader/
chunk_manager.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::num::{NonZeroU8, NonZeroUsize};
4use std::pin::Pin;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU8, Ordering};
7use std::task::{Context, Poll};
8
9use futures_util::{FutureExt, StreamExt};
10use futures_util::future::{BoxFuture, OptionFuture};
11use futures_util::stream::FuturesUnordered;
12use reqwest::Request;
13use tokio::fs::File;
14use tokio::sync;
15use tokio::sync::Mutex;
16use tokio_util::sync::CancellationToken;
17
18use crate::{chunk_item::ChunkItem, ChunkIterator, ChunkRange, DownloadError};
19use crate::{DownloadedLenChangeNotify, DownloadingEndCause};
20
21#[allow(dead_code)]
22#[cfg_attr(
23feature = "async-graphql",
24derive(async_graphql::SimpleObject),
25graphql(complex)
26)]
27pub struct ChunksInfo {
28    finished_chunks: Vec<ChunkRange>,
29    #[cfg_attr(feature = "async-graphql", graphql(skip))]
30    downloading_chunks: Vec<Arc<ChunkItem>>,
31    no_chunk_remaining: bool,
32}
33
34pub struct ChunkManager {
35    downloaded_len_sender: Arc<sync::watch::Sender<u64>>,
36    pub chunk_iterator: ChunkIterator,
37    downloading_chunks: Mutex<HashMap<usize, Arc<ChunkItem>>>,
38    download_connection_count_sender: sync::watch::Sender<u8>,
39    pub download_connection_count_receiver: sync::watch::Receiver<u8>,
40    client: reqwest::Client,
41    cancel_token: CancellationToken,
42    pub superfluities_connection_count: AtomicU8,
43    pub etag: Option<headers::ETag>,
44    pub retry_count: u8,
45}
46
47impl ChunkManager {
48    #[allow(clippy::too_many_arguments)]
49    pub fn new(
50        download_connection_count: NonZeroU8,
51        client: reqwest::Client,
52        cancel_token: CancellationToken,
53        downloaded_len_sender: Arc<sync::watch::Sender<u64>>,
54        chunk_iterator: ChunkIterator,
55        etag: Option<headers::ETag>,
56        retry_count: u8,
57    ) -> Self {
58        let (download_connection_count_sender, download_connection_count_receiver) =
59            sync::watch::channel(download_connection_count.get());
60
61        Self {
62            downloaded_len_sender,
63            chunk_iterator,
64            downloading_chunks: Mutex::new(HashMap::new()),
65            download_connection_count_sender,
66            download_connection_count_receiver,
67            client,
68            cancel_token,
69            superfluities_connection_count: AtomicU8::new(0),
70            etag,
71            retry_count,
72        }
73    }
74
75    pub fn change_connection_count(
76        &self,
77        connection_count: NonZeroU8,
78    ) -> Result<(), sync::watch::error::SendError<u8>> {
79        self.download_connection_count_sender.send(connection_count.get())
80    }
81
82    pub fn change_chunk_size(&self, chunk_size: NonZeroUsize) {
83        let mut guard = self.chunk_iterator.data.write();
84        guard.remaining.chunk_size = chunk_size.get();
85    }
86
87    pub fn downloaded_len(&self) -> u64 {
88        *self.downloaded_len_sender.borrow()
89    }
90
91    pub fn connection_count(&self) -> u8 {
92        *self.download_connection_count_sender.borrow()
93    }
94
95    pub fn clone_request(request: &Request) -> Box<Request> {
96        let mut req = Request::new(request.method().clone(), request.url().clone());
97        *req.headers_mut() = request.headers().clone();
98        *req.version_mut() = request.version();
99        *req.timeout_mut() = request.timeout().map(Clone::clone);
100        Box::new(req)
101    }
102
103    pub async fn start_download(
104        &self,
105        file: File,
106        request: Box<Request>,
107        downloaded_len_receiver: Option<Arc<dyn DownloadedLenChangeNotify>>,
108        #[cfg(feature = "breakpoint-resume")]
109        breakpoint_resume: Option<Arc<crate::BreakpointResume>>,
110    ) -> Result<DownloadingEndCause, DownloadError> {
111        enum RunFuture<'a> {
112            DownloadConnectionCountChanged(BoxFuture<'a, (sync::watch::Receiver<u8>, u8)>),
113            ChunkDownloadEnd {
114                chunk_index: usize,
115                future: BoxFuture<'a, Result<DownloadingEndCause, DownloadError>>,
116            },
117        }
118
119        #[derive(Debug)]
120        enum RunFutureResult {
121            DownloadConnectionCountChanged {
122                receiver: sync::watch::Receiver<u8>,
123                download_connection_count: u8,
124            },
125            ChunkDownloadEnd {
126                chunk_index: usize,
127                result: Result<DownloadingEndCause, DownloadError>,
128            },
129        }
130
131        impl Future for RunFuture<'_> {
132            type Output = RunFutureResult;
133
134            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
135                match self.get_mut() {
136                    RunFuture::DownloadConnectionCountChanged(future) => {
137                        future.poll_unpin(cx).map(|r| RunFutureResult::DownloadConnectionCountChanged {
138                            receiver: r.0,
139                            download_connection_count: r.1,
140                        })
141                    }
142                    RunFuture::ChunkDownloadEnd {
143                        future,
144                        chunk_index
145                    } => {
146                        future.poll_unpin(cx).map(|result| RunFutureResult::ChunkDownloadEnd {
147                            chunk_index: chunk_index.clone(),
148                            result,
149                        })
150                    }
151                }
152            }
153        }
154
155        let mut futures_unordered = FuturesUnordered::new();
156
157
158        let file = Arc::new(Mutex::new(file));
159        let download_next_chunk = || async {
160            match self
161                .download_next_chunk(
162                    file.clone(),
163                    downloaded_len_receiver.clone(),
164                    Self::clone_request(&request),
165                )
166                .await {
167                None => {
168                    None
169                }
170                Some((chunk_index, future)) => {
171                    Some(RunFuture::ChunkDownloadEnd {
172                        chunk_index,
173                        future: future.boxed(),
174                    })
175                }
176            }
177        };
178        match download_next_chunk().await {
179            None => {
180                #[cfg(feature = "tracing")]
181                tracing::trace!("No Chunk!");
182                return Ok(DownloadingEndCause::DownloadFinished);
183            }
184            Some(future) => futures_unordered.push(future)
185        }
186
187        let mut is_iter_finished = false;
188        for _ in 0..(self.connection_count() - 1) {
189            match download_next_chunk().await {
190                None => {
191                    is_iter_finished = true;
192                    break;
193                }
194                Some(future) => futures_unordered.push(future)
195            }
196        }
197        futures_unordered.push(RunFuture::DownloadConnectionCountChanged({
198            let mut receiver = self.download_connection_count_receiver.clone();
199            async move {
200                let _ = receiver.changed().await;
201                let i = *receiver.borrow();
202                (receiver, i)
203            }.boxed()
204        }));
205
206        #[cfg(feature = "breakpoint-resume")]
207            let save_data = || async {
208            if let Some(notifies) = breakpoint_resume.as_ref() {
209                #[cfg(feature = "tracing")]
210                    let span = tracing::info_span!("Archive Data");
211                #[cfg(feature = "tracing")]
212                    let _ = span.enter();
213                let notified = notifies.archive_complete_notify.notified();
214                notifies.data_archive_notify.notify_one();
215                notified.await;
216            }
217        };
218
219        let mut result = Result::<DownloadingEndCause, DownloadError>::Ok(DownloadingEndCause::DownloadFinished);
220        while let Some(future_result) = futures_unordered.next().await {
221            match future_result {
222                RunFutureResult::DownloadConnectionCountChanged {
223                    download_connection_count,
224                    mut receiver
225                } => {
226                    if download_connection_count == 0 {
227                        continue;
228                    }
229
230                    let current_count = self.get_chunks().await.len();
231                    let diff = download_connection_count as i16 - current_count as i16;
232                    if diff >= 0 {
233                        self.superfluities_connection_count
234                            .store(0, Ordering::SeqCst);
235                        for _ in 0..diff {
236                            match download_next_chunk().await {
237                                None => {
238                                    is_iter_finished = true;
239                                    break;
240                                }
241                                Some(future) => futures_unordered.push(future)
242                            }
243                        }
244                    } else {
245                        self.superfluities_connection_count
246                            .store(diff.unsigned_abs() as u8, Ordering::SeqCst);
247                    }
248
249                    futures_unordered.push(RunFuture::DownloadConnectionCountChanged(async move {
250                        let _ = receiver.changed().await;
251                        let i = *receiver.borrow();
252                        (receiver, i)
253                    }.boxed()))
254                }
255                RunFutureResult::ChunkDownloadEnd {
256                    chunk_index,
257                    result: Ok(DownloadingEndCause::DownloadFinished)
258                } => {
259                    let (downloading_chunk_count, _) = self.remove_chunk(chunk_index).await;
260
261                    #[cfg(feature = "breakpoint-resume")]
262                    save_data().await;
263                    if is_iter_finished {
264                        if downloading_chunk_count == 0 {
265                            debug_assert_eq!(
266                                self.chunk_iterator.content_length,
267                                *self.downloaded_len_sender.borrow()
268                            );
269                            break;
270                        }
271                    } else if self.superfluities_connection_count.load(Ordering::SeqCst) == 0 {
272                        match download_next_chunk().await {
273                            None => {
274                                is_iter_finished = true;
275                                if downloading_chunk_count == 0 {
276                                    debug_assert_eq!(
277                                        self.chunk_iterator.content_length,
278                                        *self.downloaded_len_sender.borrow()
279                                    );
280                                    break;
281                                }
282                            }
283                            Some(future) => futures_unordered.push(future)
284                        }
285                    } else {
286                        self.superfluities_connection_count
287                            .fetch_sub(1, Ordering::SeqCst);
288                    }
289                }
290                RunFutureResult::ChunkDownloadEnd {
291                    result: Err(err),
292                    ..
293                } => {
294                    // 只记录第一个错误
295                    if matches!(result,Ok(DownloadingEndCause::DownloadFinished)) {
296                        result = Err(err);
297                        // 取消监听 连接数 的更改
298                        let _ =
299                            self.download_connection_count_sender.send(0);
300                        // 取消其他的 Chunk 下载
301                        self.cancel_token.cancel();
302                    }
303                }
304                RunFutureResult::ChunkDownloadEnd {
305                    result: Ok(DownloadingEndCause::Cancelled),
306                    ..
307                } => {
308                    if matches!(result,Ok(DownloadingEndCause::DownloadFinished)) {
309                        result = Ok(DownloadingEndCause::Cancelled);
310                        // 取消监听 连接数 的更改
311                        let _ =
312                            self.download_connection_count_sender.send(0);
313                    }
314                }
315            }
316        }
317        // 如果没有完成,怎保存进度
318        if !matches!(result,Ok(DownloadingEndCause::DownloadFinished)) {
319            #[cfg(feature = "breakpoint-resume")]
320            save_data().await;
321        }
322        result
323    }
324    async fn insert_chunk(&self, item: Arc<ChunkItem>) {
325        let mut downloading_chunks = self.downloading_chunks.lock().await;
326        downloading_chunks.insert(item.chunk_info.index, item);
327    }
328
329    pub async fn get_chunks(&self) -> Vec<Arc<ChunkItem>> {
330        let mut downloading_chunks: Vec<_> = self
331            .downloading_chunks
332            .lock()
333            .await
334            .values()
335            .cloned()
336            .collect();
337        downloading_chunks.sort_by(|a, b| a.chunk_info.range.start.cmp(&b.chunk_info.range.start));
338        downloading_chunks
339    }
340
341    pub async fn get_chunks_info(&self) -> ChunksInfo {
342        let downloading_chunks = self.get_chunks().await;
343        let mut finished_chunks = vec![];
344
345        let no_chunk_remaining = self.chunk_iterator.data.read().no_chunk_remaining();
346        if !downloading_chunks.is_empty() {
347            let first_start = downloading_chunks[0].chunk_info.range.start;
348            if first_start != 0 {
349                finished_chunks.push(ChunkRange::new(0, first_start - 1));
350            }
351            for (index, _) in downloading_chunks.iter().enumerate() {
352                if index == downloading_chunks.len() - 1 {
353                    break;
354                }
355
356                let start = downloading_chunks[index].chunk_info.range.end;
357                let end = downloading_chunks[index + 1].chunk_info.range.start;
358                if (end - start) != 1 {
359                    finished_chunks.push(ChunkRange::new(start + 1, end - 1));
360                }
361            }
362            if no_chunk_remaining {
363                let last = downloading_chunks.last().unwrap();
364                if last.chunk_info.range.end != self.chunk_iterator.content_length - 1 {
365                    finished_chunks.push(ChunkRange::new(
366                        last.chunk_info.range.end + 1,
367                        self.chunk_iterator.content_length - 1,
368                    ))
369                }
370            }
371        }
372        ChunksInfo {
373            downloading_chunks,
374            finished_chunks,
375            no_chunk_remaining,
376        }
377    }
378
379    async fn remove_chunk(&self, index: usize) -> (usize, Option<Arc<ChunkItem>>) {
380        let mut downloading_chunks = self.downloading_chunks.lock().await;
381        let removed = downloading_chunks.remove(&index);
382        (downloading_chunks.len(), removed)
383    }
384
385    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
386    async fn download_next_chunk(
387        &self,
388        file: Arc<Mutex<File>>,
389        downloaded_len_receiver: Option<Arc<dyn DownloadedLenChangeNotify>>,
390        request: Box<Request>,
391    ) -> Option<(usize, impl Future<Output=Result<DownloadingEndCause, DownloadError>>)> {
392        if let Some(chunk_info) = self.chunk_iterator.next() {
393            let chunk_item = Arc::new(ChunkItem::new(
394                chunk_info,
395                self.cancel_token.child_token(),
396                self.client.clone(),
397                file,
398                self.etag.clone(),
399            ));
400            self.insert_chunk(chunk_item.clone()).await;
401            Some((chunk_item.chunk_info.index, chunk_item.download_chunk(request, self.retry_count, Some(LenChangedNotify {
402                notify: downloaded_len_receiver,
403                downloaded_len_sender: self.downloaded_len_sender.clone(),
404            }))))
405        } else {
406            None
407        }
408    }
409}
410
411pub struct LenChangedNotify {
412    downloaded_len_sender: Arc<sync::watch::Sender<u64>>,
413    notify: Option<Arc<dyn DownloadedLenChangeNotify>>,
414}
415
416impl DownloadedLenChangeNotify for LenChangedNotify {
417    fn receive_len(&self, len: usize) -> OptionFuture<BoxFuture<()>> {
418        self.downloaded_len_sender
419            .send_modify(|n| *n += len as u64);
420        if let Some(notify) = self.notify.as_ref() {
421            notify.receive_len(len)
422        } else {
423            None.into()
424        }
425    }
426}
427
428
/// Newtype around an in-flight chunk, used as the GraphQL-facing view of a [`ChunkItem`].
#[cfg(feature = "async-graphql")]
pub struct DownloadChunkObject(pub Arc<ChunkItem>);
431
#[cfg(feature = "async-graphql")]
impl From<Arc<ChunkItem>> for DownloadChunkObject {
    /// Wraps the shared chunk without copying it.
    fn from(value: Arc<ChunkItem>) -> Self {
        Self(value)
    }
}
438
// GraphQL resolvers exposing the wrapped chunk's metadata and progress.
// Plain `//` comments are used on purpose: `///` doc comments inside an
// `#[async_graphql::Object]` impl would become schema descriptions.
#[cfg(feature = "async-graphql")]
#[async_graphql::Object]
impl DownloadChunkObject {
    // Index of the chunk.
    pub async fn index(&self) -> usize {
        let info = &self.0.chunk_info;
        info.index
    }
    // First byte of the chunk's range.
    pub async fn start(&self) -> u64 {
        let range = &self.0.chunk_info.range;
        range.start
    }
    // Last byte of the chunk's range.
    pub async fn end(&self) -> u64 {
        let range = &self.0.chunk_info.range;
        range.end
    }
    // Length of the chunk's range as reported by `range.len()`.
    pub async fn len(&self) -> u64 {
        let range = &self.0.chunk_info.range;
        range.len()
    }
    // Bytes downloaded so far for this chunk (relaxed atomic read).
    pub async fn downloaded_len(&self) -> u64 {
        let counter = &self.0.downloaded_len;
        counter.load(Ordering::Relaxed)
    }
}
458
459#[cfg_attr(feature = "async-graphql", async_graphql::ComplexObject)]
460impl ChunksInfo {
461    #[cfg(feature = "async-graphql")]
462    pub async fn downloading_chunks(&self) -> Vec<DownloadChunkObject> {
463        self.downloading_chunks
464            .iter()
465            .cloned()
466            .map(Into::into)
467            .collect()
468    }
469}