Skip to main content

fast_pull/core/
mod.rs

1use crate::{Event, handle::SharedHandle};
2use core::sync::atomic::{AtomicBool, Ordering};
3use crossfire::{MAsyncRx, mpmc};
4use fast_steal::{Executor, Handle, TaskQueue};
5use std::sync::{Arc, Weak};
6use tokio::task::{AbortHandle, JoinError, JoinHandle};
7
8pub mod handle;
9pub mod mock;
10pub mod multi;
11pub mod single;
12
13#[derive(Debug)]
14pub struct DownloadResult<E, PullError, PushError>
15where
16    E: Executor + Send + Sync,
17    PullError: Send + Unpin + 'static,
18    PushError: Send + Unpin + 'static,
19{
20    pub event_chain: MAsyncRx<mpmc::List<Event<PullError, PushError>>>,
21    handle: Arc<SharedHandle<()>>,
22    abort_handles: Option<Arc<[AbortHandle]>>,
23    task_queue: Option<(Weak<E>, TaskQueue<E::Handle>)>,
24    is_aborted: Arc<AtomicBool>,
25}
26
27impl<E, PullError, PushError> Clone for DownloadResult<E, PullError, PushError>
28where
29    E: Executor + Send + Sync,
30    PullError: Send + Unpin + 'static,
31    PushError: Send + Unpin + 'static,
32{
33    fn clone(&self) -> Self {
34        Self {
35            event_chain: self.event_chain.clone(),
36            handle: self.handle.clone(),
37            abort_handles: self.abort_handles.clone(),
38            task_queue: self.task_queue.clone(),
39            is_aborted: self.is_aborted.clone(),
40        }
41    }
42}
43
44impl<E, PullError, PushError> DownloadResult<E, PullError, PushError>
45where
46    E: Executor + Send + Sync,
47    PullError: Send + Unpin + 'static,
48    PushError: Send + Unpin + 'static,
49{
50    pub fn new(
51        event_chain: MAsyncRx<mpmc::List<Event<PullError, PushError>>>,
52        handle: JoinHandle<()>,
53        abort_handles: Option<&[AbortHandle]>,
54        task_queue: Option<(Weak<E>, TaskQueue<E::Handle>)>,
55    ) -> Self {
56        Self {
57            event_chain,
58            handle: Arc::new(SharedHandle::new(handle)),
59            abort_handles: abort_handles.map(Arc::from),
60            task_queue,
61            is_aborted: Arc::new(AtomicBool::new(false)),
62        }
63    }
64
65    /// # Errors
66    /// 当写入线程意外退出时返回 `Arc<JoinError>`
67    pub async fn join(&self) -> Result<(), Arc<JoinError>> {
68        self.handle.join().await
69    }
70
71    pub fn abort(&self) {
72        if let Some(handles) = &self.abort_handles {
73            for handle in handles.iter() {
74                handle.abort();
75            }
76        }
77        if let Some((_, task_queue)) = &self.task_queue {
78            task_queue.handles(|iter| {
79                for handle in iter {
80                    handle.abort();
81                }
82            });
83        }
84        self.is_aborted.store(true, Ordering::Release);
85    }
86
87    pub fn set_threads(&self, threads: usize, min_chunk_size: u64) {
88        if let Some((executor, task_queue)) = &self.task_queue {
89            let executor = executor.upgrade();
90            let res = task_queue.set_threads(
91                threads,
92                min_chunk_size,
93                executor.as_ref().map(AsRef::as_ref),
94            );
95            if res.is_some() && threads > 0 {
96                self.is_aborted.store(false, Ordering::Release);
97            }
98        }
99    }
100
101    #[must_use]
102    pub fn is_aborted(&self) -> bool {
103        self.is_aborted.load(Ordering::Acquire)
104    }
105}
106
107impl<E, PullError, PushError> Drop for DownloadResult<E, PullError, PushError>
108where
109    E: Executor + Send + Sync,
110    PullError: Send + Unpin + 'static,
111    PushError: Send + Unpin + 'static,
112{
113    fn drop(&mut self) {
114        self.abort();
115    }
116}