use std::collections::HashMap;
use std::future::Future;
use std::num::{NonZeroU8, NonZeroUsize};
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, Ordering};
use std::task::{Context, Poll};

use futures_util::{FutureExt, StreamExt};
use futures_util::future::{BoxFuture, OptionFuture};
use futures_util::stream::FuturesUnordered;
use reqwest::Request;
use tokio::fs::File;
use tokio::sync;
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;

use crate::{chunk_item::ChunkItem, ChunkIterator, ChunkRange, DownloadError};
use crate::{DownloadedLenChangeNotify, DownloadingEndCause};

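/// Snapshot of chunk progress: ranges that have already finished, chunks that
/// are currently downloading, and whether the chunk iterator has been exhausted.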
#[allow(dead_code)]
#[cfg_attr(
    feature = "async-graphql",
    derive(async_graphql::SimpleObject),
    graphql(complex)
)]
pub struct ChunksInfo {
    finished_chunks: Vec<ChunkRange>,
    #[cfg_attr(feature = "async-graphql", graphql(skip))]
    downloading_chunks: Vec<Arc<ChunkItem>>,
    no_chunk_remaining: bool,
}

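/// Coordinates concurrent chunk downloads: hands out chunks from the
/// `ChunkIterator`, tracks in-flight `ChunkItem`s, and reacts to runtime
/// changes of the connection count via a watch channel.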
pub struct ChunkManager {
    downloaded_len_sender: Arc<sync::watch::Sender<u64>>,
    pub chunk_iterator: ChunkIterator,
    downloading_chunks: Mutex<HashMap<usize, Arc<ChunkItem>>>,
    download_connection_count_sender: sync::watch::Sender<u8>,
    pub download_connection_count_receiver: sync::watch::Receiver<u8>,
    client: reqwest::Client,
    cancel_token: CancellationToken,
    pub superfluities_connection_count: AtomicU8,
    pub etag: Option<headers::ETag>,
    pub retry_count: u8,
}

impl ChunkManager {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        download_connection_count: NonZeroU8,
        client: reqwest::Client,
        cancel_token: CancellationToken,
        downloaded_len_sender: Arc<sync::watch::Sender<u64>>,
        chunk_iterator: ChunkIterator,
        etag: Option<headers::ETag>,
        retry_count: u8,
    ) -> Self {
        let (download_connection_count_sender, download_connection_count_receiver) =
            sync::watch::channel(download_connection_count.get());

        Self {
            downloaded_len_sender,
            chunk_iterator,
            downloading_chunks: Mutex::new(HashMap::new()),
            download_connection_count_sender,
            download_connection_count_receiver,
            client,
            cancel_token,
            superfluities_connection_count: AtomicU8::new(0),
            etag,
            retry_count,
        }
    }

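    /// Requests a new number of concurrent connections. The change is picked up
    /// asynchronously by `start_download` through the watch channel; chunks that
    /// are already downloading are not aborted when the count is lowered.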
    pub fn change_connection_count(
        &self,
        connection_count: NonZeroU8,
    ) -> Result<(), sync::watch::error::SendError<u8>> {
        self.download_connection_count_sender.send(connection_count.get())
    }

    pub fn change_chunk_size(&self, chunk_size: NonZeroUsize) {
        let mut guard = self.chunk_iterator.data.write();
        guard.remaining.chunk_size = chunk_size.get();
    }

    pub fn downloaded_len(&self) -> u64 {
        *self.downloaded_len_sender.borrow()
    }

    pub fn connection_count(&self) -> u8 {
        *self.download_connection_count_sender.borrow()
    }

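    /// Manually clones a `reqwest::Request`, copying the method, URL, headers,
    /// HTTP version and timeout. Any request body is not carried over.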
    pub fn clone_request(request: &Request) -> Box<Request> {
        let mut req = Request::new(request.method().clone(), request.url().clone());
        *req.headers_mut() = request.headers().clone();
        *req.version_mut() = request.version();
        *req.timeout_mut() = request.timeout().copied();
        Box::new(req)
    }

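    /// Drives the whole download: spawns up to `connection_count` chunk futures
    /// into a `FuturesUnordered`, refills the pool as chunks finish, and reacts
    /// to connection-count changes, cancellation and errors. Returns once every
    /// chunk is done, the download is cancelled, or a chunk fails.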
    pub async fn start_download(
        &self,
        file: File,
        request: Box<Request>,
        downloaded_len_receiver: Option<Arc<dyn DownloadedLenChangeNotify>>,
        #[cfg(feature = "breakpoint-resume")]
        breakpoint_resume: Option<Arc<crate::BreakpointResume>>,
    ) -> Result<DownloadingEndCause, DownloadError> {
        enum RunFuture<'a> {
            DownloadConnectionCountChanged(BoxFuture<'a, (sync::watch::Receiver<u8>, u8)>),
            ChunkDownloadEnd {
                chunk_index: usize,
                future: BoxFuture<'a, Result<DownloadingEndCause, DownloadError>>,
            },
        }

        #[derive(Debug)]
        enum RunFutureResult {
            DownloadConnectionCountChanged {
                receiver: sync::watch::Receiver<u8>,
                download_connection_count: u8,
            },
            ChunkDownloadEnd {
                chunk_index: usize,
                result: Result<DownloadingEndCause, DownloadError>,
            },
        }

        impl Future for RunFuture<'_> {
            type Output = RunFutureResult;

            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                match self.get_mut() {
                    RunFuture::DownloadConnectionCountChanged(future) => {
                        future.poll_unpin(cx).map(|(receiver, count)| {
                            RunFutureResult::DownloadConnectionCountChanged {
                                receiver,
                                download_connection_count: count,
                            }
                        })
                    }
                    RunFuture::ChunkDownloadEnd {
                        future,
                        chunk_index,
                    } => future.poll_unpin(cx).map(|result| RunFutureResult::ChunkDownloadEnd {
                        chunk_index: *chunk_index,
                        result,
                    }),
                }
            }
        }

        let mut futures_unordered = FuturesUnordered::new();

        let file = Arc::new(Mutex::new(file));
        let download_next_chunk = || async {
            self.download_next_chunk(
                file.clone(),
                downloaded_len_receiver.clone(),
                Self::clone_request(&request),
            )
            .await
            .map(|(chunk_index, future)| RunFuture::ChunkDownloadEnd {
                chunk_index,
                future: future.boxed(),
            })
        };
        match download_next_chunk().await {
            None => {
                #[cfg(feature = "tracing")]
                tracing::trace!("No Chunk!");
                return Ok(DownloadingEndCause::DownloadFinished);
            }
            Some(future) => futures_unordered.push(future),
        }

        let mut is_iter_finished = false;
        for _ in 0..(self.connection_count() - 1) {
            match download_next_chunk().await {
                None => {
                    is_iter_finished = true;
                    break;
                }
                Some(future) => futures_unordered.push(future),
            }
        }
        futures_unordered.push(RunFuture::DownloadConnectionCountChanged({
            let mut receiver = self.download_connection_count_receiver.clone();
            async move {
                let _ = receiver.changed().await;
                let count = *receiver.borrow();
                (receiver, count)
            }
            .boxed()
        }));

        #[cfg(feature = "breakpoint-resume")]
        let save_data = || async {
            if let Some(notifies) = breakpoint_resume.as_ref() {
                #[cfg(feature = "tracing")]
                let span = tracing::info_span!("Archive Data");
                // Bind the guard to a named variable; `let _ = span.enter()` would
                // drop it immediately and the span would never be entered.
                #[cfg(feature = "tracing")]
                let _entered = span.enter();
                let notified = notifies.archive_complete_notify.notified();
                notifies.data_archive_notify.notify_one();
                notified.await;
            }
        };

        let mut result: Result<DownloadingEndCause, DownloadError> =
            Ok(DownloadingEndCause::DownloadFinished);
        while let Some(future_result) = futures_unordered.next().await {
            match future_result {
                RunFutureResult::DownloadConnectionCountChanged {
                    download_connection_count,
                    mut receiver,
                } => {
                    // A count of 0 is only sent internally as a shutdown signal; ignore it here.
                    if download_connection_count == 0 {
                        continue;
                    }

                    let current_count = self.get_chunks().await.len();
                    let diff = download_connection_count as i16 - current_count as i16;
                    if diff >= 0 {
                        self.superfluities_connection_count.store(0, Ordering::SeqCst);
                        for _ in 0..diff {
                            match download_next_chunk().await {
                                None => {
                                    is_iter_finished = true;
                                    break;
                                }
                                Some(future) => futures_unordered.push(future),
                            }
                        }
                    } else {
                        self.superfluities_connection_count
                            .store(diff.unsigned_abs() as u8, Ordering::SeqCst);
                    }

                    futures_unordered.push(RunFuture::DownloadConnectionCountChanged(
                        async move {
                            let _ = receiver.changed().await;
                            let count = *receiver.borrow();
                            (receiver, count)
                        }
                        .boxed(),
                    ))
                }
                RunFutureResult::ChunkDownloadEnd {
                    chunk_index,
                    result: Ok(DownloadingEndCause::DownloadFinished),
                } => {
                    let (downloading_chunk_count, _) = self.remove_chunk(chunk_index).await;

                    #[cfg(feature = "breakpoint-resume")]
                    save_data().await;
                    if is_iter_finished {
                        if downloading_chunk_count == 0 {
                            debug_assert_eq!(
                                self.chunk_iterator.content_length,
                                *self.downloaded_len_sender.borrow()
                            );
                            break;
                        }
                    } else if self.superfluities_connection_count.load(Ordering::SeqCst) == 0 {
                        match download_next_chunk().await {
                            None => {
                                is_iter_finished = true;
                                if downloading_chunk_count == 0 {
                                    debug_assert_eq!(
                                        self.chunk_iterator.content_length,
                                        *self.downloaded_len_sender.borrow()
                                    );
                                    break;
                                }
                            }
                            Some(future) => futures_unordered.push(future),
                        }
                    } else {
                        self.superfluities_connection_count
                            .fetch_sub(1, Ordering::SeqCst);
                    }
                }
                RunFutureResult::ChunkDownloadEnd {
                    result: Err(err), ..
                } => {
                    if matches!(result, Ok(DownloadingEndCause::DownloadFinished)) {
                        result = Err(err);
                        let _ = self.download_connection_count_sender.send(0);
                        self.cancel_token.cancel();
                    }
                }
                RunFutureResult::ChunkDownloadEnd {
                    result: Ok(DownloadingEndCause::Cancelled),
                    ..
                } => {
                    if matches!(result, Ok(DownloadingEndCause::DownloadFinished)) {
                        result = Ok(DownloadingEndCause::Cancelled);
                        let _ = self.download_connection_count_sender.send(0);
                    }
                }
            }
        }
        if !matches!(result, Ok(DownloadingEndCause::DownloadFinished)) {
            #[cfg(feature = "breakpoint-resume")]
            save_data().await;
        }
        result
    }

    async fn insert_chunk(&self, item: Arc<ChunkItem>) {
        let mut downloading_chunks = self.downloading_chunks.lock().await;
        downloading_chunks.insert(item.chunk_info.index, item);
    }

    pub async fn get_chunks(&self) -> Vec<Arc<ChunkItem>> {
        let mut downloading_chunks: Vec<_> = self
            .downloading_chunks
            .lock()
            .await
            .values()
            .cloned()
            .collect();
        downloading_chunks.sort_by_key(|chunk| chunk.chunk_info.range.start);
        downloading_chunks
    }

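    /// Builds a `ChunksInfo` snapshot: finished ranges are inferred from the
    /// gaps before, between and (once the iterator is exhausted) after the
    /// currently downloading chunks.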
    pub async fn get_chunks_info(&self) -> ChunksInfo {
        let downloading_chunks = self.get_chunks().await;
        let mut finished_chunks = vec![];

        let no_chunk_remaining = self.chunk_iterator.data.read().no_chunk_remaining();
        if !downloading_chunks.is_empty() {
            let first_start = downloading_chunks[0].chunk_info.range.start;
            if first_start != 0 {
                finished_chunks.push(ChunkRange::new(0, first_start - 1));
            }
            for index in 0..(downloading_chunks.len() - 1) {
                let start = downloading_chunks[index].chunk_info.range.end;
                let end = downloading_chunks[index + 1].chunk_info.range.start;
                if (end - start) != 1 {
                    finished_chunks.push(ChunkRange::new(start + 1, end - 1));
                }
            }
            if no_chunk_remaining {
                let last = downloading_chunks.last().unwrap();
                if last.chunk_info.range.end != self.chunk_iterator.content_length - 1 {
                    finished_chunks.push(ChunkRange::new(
                        last.chunk_info.range.end + 1,
                        self.chunk_iterator.content_length - 1,
                    ))
                }
            }
        }
        ChunksInfo {
            downloading_chunks,
            finished_chunks,
            no_chunk_remaining,
        }
    }

    async fn remove_chunk(&self, index: usize) -> (usize, Option<Arc<ChunkItem>>) {
        let mut downloading_chunks = self.downloading_chunks.lock().await;
        let removed = downloading_chunks.remove(&index);
        (downloading_chunks.len(), removed)
    }

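    /// Takes the next chunk from the iterator, registers it as downloading and
    /// returns its index together with the future that performs the actual
    /// download. Returns `None` when the iterator is exhausted.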
    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
    async fn download_next_chunk(
        &self,
        file: Arc<Mutex<File>>,
        downloaded_len_receiver: Option<Arc<dyn DownloadedLenChangeNotify>>,
        request: Box<Request>,
    ) -> Option<(usize, impl Future<Output = Result<DownloadingEndCause, DownloadError>>)> {
        if let Some(chunk_info) = self.chunk_iterator.next() {
            let chunk_item = Arc::new(ChunkItem::new(
                chunk_info,
                self.cancel_token.child_token(),
                self.client.clone(),
                file,
                self.etag.clone(),
            ));
            self.insert_chunk(chunk_item.clone()).await;
            Some((
                chunk_item.chunk_info.index,
                chunk_item.download_chunk(
                    request,
                    self.retry_count,
                    Some(LenChangedNotify {
                        notify: downloaded_len_receiver,
                        downloaded_len_sender: self.downloaded_len_sender.clone(),
                    }),
                ),
            ))
        } else {
            None
        }
    }
}

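/// Bridges chunk progress to the watch channel: every received length delta is
/// added to the shared downloaded-length counter and then forwarded to the
/// user-supplied notifier, if any.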
pub struct LenChangedNotify {
    downloaded_len_sender: Arc<sync::watch::Sender<u64>>,
    notify: Option<Arc<dyn DownloadedLenChangeNotify>>,
}

impl DownloadedLenChangeNotify for LenChangedNotify {
    fn receive_len(&self, len: usize) -> OptionFuture<BoxFuture<()>> {
        self.downloaded_len_sender.send_modify(|n| *n += len as u64);
        if let Some(notify) = self.notify.as_ref() {
            notify.receive_len(len)
        } else {
            None.into()
        }
    }
}

#[cfg(feature = "async-graphql")]
pub struct DownloadChunkObject(pub Arc<ChunkItem>);

#[cfg(feature = "async-graphql")]
impl From<Arc<ChunkItem>> for DownloadChunkObject {
    fn from(value: Arc<ChunkItem>) -> Self {
        DownloadChunkObject(value)
    }
}

#[cfg(feature = "async-graphql")]
#[async_graphql::Object]
impl DownloadChunkObject {
    pub async fn index(&self) -> usize {
        self.0.chunk_info.index
    }
    pub async fn start(&self) -> u64 {
        self.0.chunk_info.range.start
    }
    pub async fn end(&self) -> u64 {
        self.0.chunk_info.range.end
    }
    pub async fn len(&self) -> u64 {
        self.0.chunk_info.range.len()
    }
    pub async fn downloaded_len(&self) -> u64 {
        self.0.downloaded_len.load(Ordering::Relaxed)
    }
}

#[cfg_attr(feature = "async-graphql", async_graphql::ComplexObject)]
impl ChunksInfo {
    #[cfg(feature = "async-graphql")]
    pub async fn downloading_chunks(&self) -> Vec<DownloadChunkObject> {
        self.downloading_chunks
            .iter()
            .cloned()
            .map(Into::into)
            .collect()
    }
}