Skip to main content

fast_pull/core/
mod.rs

1use crate::{Event, handle::SharedHandle};
2use core::sync::atomic::{AtomicBool, Ordering};
3use crossfire::{MAsyncRx, mpmc};
4use fast_steal::{Executor, Handle, TaskQueue};
5use std::sync::{Arc, Weak};
6use tokio::task::{AbortHandle, JoinError, JoinHandle};
7
8pub mod handle;
9pub mod mock;
10pub mod multi;
11pub mod single;
12
13pub struct DownloadResult<E, PullError, PushError>
14where
15    E: Executor,
16    PullError: Send + Unpin + 'static,
17    PushError: Send + Unpin + 'static,
18{
19    pub event_chain: MAsyncRx<mpmc::List<Event<PullError, PushError>>>,
20    handle: Arc<SharedHandle<()>>,
21    abort_handles: Option<Arc<[AbortHandle]>>,
22    task_queue: Option<(Weak<E>, TaskQueue<E::Handle>)>,
23    is_aborted: Arc<AtomicBool>,
24}
25
26impl<E, PullError, PushError> Clone for DownloadResult<E, PullError, PushError>
27where
28    E: Executor,
29    PullError: Send + Unpin + 'static,
30    PushError: Send + Unpin + 'static,
31{
32    fn clone(&self) -> Self {
33        Self {
34            event_chain: self.event_chain.clone(),
35            handle: self.handle.clone(),
36            abort_handles: self.abort_handles.clone(),
37            task_queue: self.task_queue.clone(),
38            is_aborted: self.is_aborted.clone(),
39        }
40    }
41}
42
43impl<E, PullError, PushError> DownloadResult<E, PullError, PushError>
44where
45    E: Executor,
46    PullError: Send + Unpin + 'static,
47    PushError: Send + Unpin + 'static,
48{
49    pub fn new(
50        event_chain: MAsyncRx<mpmc::List<Event<PullError, PushError>>>,
51        handle: JoinHandle<()>,
52        abort_handles: Option<&[AbortHandle]>,
53        task_queue: Option<(Weak<E>, TaskQueue<E::Handle>)>,
54    ) -> Self {
55        Self {
56            event_chain,
57            handle: Arc::new(SharedHandle::new(handle)),
58            abort_handles: abort_handles.map(Arc::from),
59            task_queue,
60            is_aborted: Arc::new(AtomicBool::new(false)),
61        }
62    }
63
64    pub async fn join(&self) -> Result<(), Arc<JoinError>> {
65        self.handle.join().await
66    }
67
68    pub fn abort(&self) {
69        if let Some(handles) = &self.abort_handles {
70            for handle in handles.iter() {
71                handle.abort();
72            }
73        }
74        if let Some((_, task_queue)) = &self.task_queue {
75            task_queue.handles(|iter| {
76                for handle in iter {
77                    handle.abort();
78                }
79            });
80        }
81        self.is_aborted.store(true, Ordering::Release);
82    }
83
84    pub fn set_threads(&self, threads: usize, min_chunk_size: u64) {
85        if let Some((executor, task_queue)) = &self.task_queue {
86            let executor = executor.upgrade();
87            let res = task_queue.set_threads(
88                threads,
89                min_chunk_size,
90                executor.as_ref().map(|e| e.as_ref()),
91            );
92            if res.is_some() && threads > 0 {
93                self.is_aborted.store(false, Ordering::Release);
94            }
95        }
96    }
97
98    pub fn is_aborted(&self) -> bool {
99        self.is_aborted.load(Ordering::Acquire)
100    }
101}
102
103impl<E, PullError, PushError> Drop for DownloadResult<E, PullError, PushError>
104where
105    E: Executor,
106    PullError: Send + Unpin + 'static,
107    PushError: Send + Unpin + 'static,
108{
109    fn drop(&mut self) {
110        self.abort();
111    }
112}