Skip to main content

fast_down_gui/core/
task.rs

1use parking_lot::Mutex;
2use std::{
3    collections::{HashMap, VecDeque},
4    future::Future,
5    hash::Hash,
6    sync::Arc,
7};
8use tokio::sync::watch;
9use tokio_util::sync::CancellationToken;
10
/// A task waiting in the pending queue for a free concurrency slot.
struct QueuedTask<K> {
    /// Key the task was registered under; used to find this entry in the queue.
    id: K,
    /// One-shot closure that starts the task when invoked.
    task: Box<dyn FnOnce() + Send>,
}
15
/// Mutable bookkeeping shared by every handle to a task set.
struct State<K> {
    /// Upper bound that `current_running` is compared against before starting queued work.
    max_concurrency: usize,
    /// Number of tasks that have been started and have not yet reported completion.
    current_running: usize,
    /// Tasks that could not be started yet, in arrival order.
    pending_queue: VecDeque<QueuedTask<K>>,
    /// Registered tasks by key: the tag assigned at registration and the token used to cancel it.
    tasks: HashMap<K, (u64, CancellationToken)>,
    /// Tag that will be handed to the next registered task; incremented on every registration.
    next_tag: u64,
}
23
24impl<K> Drop for State<K> {
25    fn drop(&mut self) {
26        for queued in self.pending_queue.drain(..) {
27            (queued.task)();
28        }
29    }
30}
31
/// Cloneable handle to a set of keyed tasks with a concurrency limit.
///
/// Clones share the same underlying state and notification channel.
#[derive(Clone)]
pub struct TaskSet<K> {
    /// Shared bookkeeping, guarded by a mutex.
    state: Arc<Mutex<State<K>>>,
    /// Watch channel whose counter is bumped whenever the set is observed to be idle.
    idle_tx: Arc<watch::Sender<usize>>,
}
37
/// Drop guard held by a running task; reports the task as finished when dropped.
struct TaskGuard<K: Clone + Eq + Hash + Send + 'static> {
    /// Handle to the set the task belongs to.
    this: TaskSet<K>,
    /// Key the task was registered under.
    id: K,
    /// Tag assigned at registration; distinguishes this task from a later one using the same key.
    tag: u64,
}
43
44impl<K: Clone + Eq + Hash + Send + 'static> Drop for TaskGuard<K> {
45    fn drop(&mut self) {
46        self.this.on_task_finished(&self.id, self.tag);
47    }
48}
49
50impl<K: Clone + Eq + Hash + Send + 'static> TaskSet<K> {
51    pub fn new(max_concurrency: usize) -> Self {
52        let (tx, _rx) = watch::channel(0);
53        Self {
54            state: Arc::new(Mutex::new(State {
55                max_concurrency,
56                current_running: 0,
57                pending_queue: VecDeque::new(),
58                tasks: HashMap::new(),
59                next_tag: 0,
60            })),
61            idle_tx: Arc::new(tx),
62        }
63    }
64
65    /// 添加任务
66    pub fn add_task<F>(&self, id: K, cancel_token: CancellationToken, fut: F)
67    where
68        F: Future<Output = ()> + Send + 'static,
69    {
70        let mut state = self.state.lock();
71        let tag = {
72            state.next_tag += 1;
73            state.next_tag - 1
74        };
75
76        if let Some((_, old_token)) = state.tasks.insert(id.clone(), (tag, cancel_token.clone())) {
77            old_token.cancel();
78            if let Some(pos) = state.pending_queue.iter().position(|q| q.id == id) {
79                let queued = state.pending_queue.remove(pos).unwrap();
80                state.current_running += 1;
81                (queued.task)();
82            }
83        };
84
85        let wrapped_fn = {
86            let weak_state = Arc::downgrade(&self.state);
87            let weak_tx = Arc::downgrade(&self.idle_tx);
88            let id = id.clone();
89            move || match (weak_state.upgrade(), weak_tx.upgrade()) {
90                (Some(state), Some(idle_tx)) => {
91                    let this = TaskSet { state, idle_tx };
92                    tokio::spawn(async move {
93                        let _guard = TaskGuard { this, id, tag };
94                        fut.await;
95                    });
96                }
97                _ => {
98                    tokio::spawn(fut);
99                }
100            }
101        };
102
103        if state.current_running < state.max_concurrency {
104            state.current_running += 1;
105            wrapped_fn();
106        } else {
107            state.pending_queue.push_back(QueuedTask {
108                id,
109                task: Box::new(wrapped_fn),
110            });
111        }
112    }
113
114    /// 取消指定任务
115    pub fn cancel_task(&self, id: &K) {
116        let mut state = self.state.lock();
117        if let Some(entry) = state.tasks.remove(id) {
118            entry.1.cancel();
119        }
120        if let Some(pos) = state.pending_queue.iter().position(|q| q.id == *id) {
121            let queued = state.pending_queue.remove(pos).unwrap();
122            state.current_running += 1;
123            (queued.task)();
124        }
125        self.try_spawn_next(&mut state);
126    }
127
128    /// 取消全部任务
129    pub fn cancel_all(&self) {
130        let mut state = self.state.lock();
131        for (_, (_, token)) in state.tasks.drain() {
132            token.cancel();
133        }
134        while let Some(queued) = state.pending_queue.pop_front() {
135            state.current_running += 1;
136            (queued.task)();
137        }
138        self.try_spawn_next(&mut state);
139    }
140
141    /// 等待所有任务完成,无任务时立刻返回
142    pub fn join(&self) -> impl Future<Output = ()> {
143        let state = self.state.clone();
144        let mut rx = self.idle_tx.subscribe();
145        async move {
146            loop {
147                {
148                    let s = state.lock();
149                    if s.current_running == 0 && s.pending_queue.is_empty() {
150                        return;
151                    }
152                }
153                let _ = rx.changed().await;
154            }
155        }
156    }
157
158    /// 等待所有任务完成,无任务时等待
159    pub fn wait_last(&self) -> impl Future<Output = ()> {
160        let state = self.state.clone();
161        let mut rx = self.idle_tx.subscribe();
162        async move {
163            let baseline = {
164                let s = state.lock();
165                if s.current_running == 0 && s.pending_queue.is_empty() {
166                    Some(s.next_tag)
167                } else {
168                    None
169                }
170            };
171            loop {
172                {
173                    let s = state.lock();
174                    if s.current_running == 0 && s.pending_queue.is_empty() {
175                        match baseline {
176                            Some(tag) => {
177                                if s.next_tag > tag {
178                                    return;
179                                }
180                            }
181                            None => return,
182                        }
183                    }
184                }
185                let _ = rx.changed().await;
186            }
187        }
188    }
189
190    /// 调整并发数
191    pub fn set_concurrency(&self, new_max: usize) {
192        let mut state = self.state.lock();
193        state.max_concurrency = new_max;
194        self.try_spawn_next(&mut state);
195    }
196
197    /// 状态统计
198    pub fn stats(&self) -> (usize, usize) {
199        let state = self.state.lock();
200        (state.current_running, state.pending_queue.len())
201    }
202
203    fn on_task_finished(&self, id: &K, task_tag: u64) {
204        let mut state = self.state.lock();
205        if let Some((existing_tag, _)) = state.tasks.get(id)
206            && *existing_tag == task_tag
207        {
208            state.tasks.remove(id);
209        }
210        state.current_running = state.current_running.saturating_sub(1);
211        self.try_spawn_next(&mut state);
212    }
213
214    fn try_spawn_next(&self, state: &mut State<K>) {
215        while state.current_running < state.max_concurrency
216            && let Some(queued) = state.pending_queue.pop_front()
217        {
218            state.current_running += 1;
219            (queued.task)();
220        }
221        if state.current_running == 0 && state.pending_queue.is_empty() {
222            self.idle_tx.send_modify(|v| *v = v.wrapping_add(1));
223        }
224    }
225}