Skip to main content

fast_down_gui/core/
task.rs

1use parking_lot::Mutex;
2use std::{
3    collections::{HashMap, VecDeque},
4    future::Future,
5    hash::Hash,
6    sync::Arc,
7};
8use tokio::sync::Notify;
9use tokio_util::sync::CancellationToken;
10
/// A task waiting in the pending queue for a free concurrency slot.
struct QueuedTask<K> {
    /// Identifier the task was added under; used to locate it again when it is
    /// cancelled or replaced.
    id: K,
    /// Starts the task when invoked (the closure built by `TaskSet::add_task`).
    task: Box<dyn Send + FnOnce()>,
}
15
16struct State<K> {
17    max_concurrency: usize,
18    current_running: usize,
19    pending_queue: VecDeque<QueuedTask<K>>,
20    tasks: HashMap<K, (u64, CancellationToken)>,
21    next_tag: u64,
22}
23
24impl<K> Drop for State<K> {
25    fn drop(&mut self) {
26        for queued in self.pending_queue.drain(..) {
27            (queued.task)();
28        }
29    }
30}
31
32#[derive(Clone)]
33pub struct TaskSet<K> {
34    state: Arc<Mutex<State<K>>>,
35    idle_notify: Arc<Notify>,
36}
37
38struct TaskGuard<K: Clone + Eq + Hash + Send + 'static> {
39    this: TaskSet<K>,
40    id: K,
41    tag: u64,
42}
43
44impl<K: Clone + Eq + Hash + Send + 'static> Drop for TaskGuard<K> {
45    fn drop(&mut self) {
46        self.this.on_task_finished(&self.id, self.tag);
47    }
48}
49
50impl<K: Clone + Eq + Hash + Send + 'static> TaskSet<K> {
51    pub fn new(max_concurrency: usize) -> Self {
52        Self {
53            state: Arc::new(Mutex::new(State {
54                max_concurrency,
55                current_running: 0,
56                pending_queue: VecDeque::new(),
57                tasks: HashMap::new(),
58                next_tag: 0,
59            })),
60            idle_notify: Arc::new(Notify::new()),
61        }
62    }
63
64    /// 添加任务
65    pub fn add_task<F>(&self, id: K, cancel_token: CancellationToken, fut: F)
66    where
67        F: Future<Output = ()> + Send + 'static,
68    {
69        let mut state = self.state.lock();
70        let tag = {
71            state.next_tag += 1;
72            state.next_tag - 1
73        };
74
75        if let Some((_, old_token)) = state.tasks.insert(id.clone(), (tag, cancel_token.clone())) {
76            old_token.cancel();
77            let mut i = 0;
78            while i < state.pending_queue.len() {
79                if state.pending_queue[i].id == id {
80                    let queued = state.pending_queue.remove(i).unwrap();
81                    state.current_running += 1;
82                    (queued.task)();
83                    break;
84                } else {
85                    i += 1;
86                }
87            }
88        };
89
90        let wrapped_fn = {
91            let weak_state = Arc::downgrade(&self.state);
92            let weak_notify = Arc::downgrade(&self.idle_notify);
93            let id = id.clone();
94            move || match (weak_state.upgrade(), weak_notify.upgrade()) {
95                (Some(state), Some(idle_notify)) => {
96                    let this = TaskSet { state, idle_notify };
97                    tokio::spawn(async move {
98                        let _guard = TaskGuard { this, id, tag };
99                        fut.await;
100                    });
101                }
102                _ => {
103                    tokio::spawn(fut);
104                }
105            }
106        };
107
108        if state.current_running < state.max_concurrency {
109            state.current_running += 1;
110            wrapped_fn();
111        } else {
112            state.pending_queue.push_back(QueuedTask {
113                id,
114                task: Box::new(wrapped_fn),
115            });
116        }
117    }
118
119    /// 取消指定任务
120    pub fn cancel_task(&self, id: &K) {
121        let mut state = self.state.lock();
122        if let Some(entry) = state.tasks.remove(id) {
123            entry.1.cancel();
124        }
125
126        let mut i = 0;
127        while i < state.pending_queue.len() {
128            if state.pending_queue[i].id == *id {
129                let queued = state.pending_queue.remove(i).unwrap();
130                state.current_running += 1;
131                (queued.task)();
132                break;
133            } else {
134                i += 1;
135            }
136        }
137
138        self.try_spawn_next(&mut state);
139    }
140
141    /// 取消全部任务
142    pub fn cancel_all(&self) {
143        let mut state = self.state.lock();
144        for (_, (_, token)) in state.tasks.drain() {
145            token.cancel();
146        }
147
148        while let Some(queued) = state.pending_queue.pop_front() {
149            state.current_running += 1;
150            (queued.task)();
151        }
152
153        self.try_spawn_next(&mut state);
154    }
155
156    /// 等待所有任务完成
157    pub fn join(&self) -> impl Future<Output = ()> {
158        let state = self.state.clone();
159        let notify = self.idle_notify.clone();
160        async move {
161            loop {
162                {
163                    let state = state.lock();
164                    if state.current_running == 0 && state.pending_queue.is_empty() {
165                        return;
166                    }
167                    notify.notified()
168                }
169                .await;
170            }
171        }
172    }
173
174    /// 调整并发数
175    pub fn set_concurrency(&self, new_max: usize) {
176        let mut state = self.state.lock();
177        state.max_concurrency = new_max;
178        self.try_spawn_next(&mut state);
179    }
180
181    /// 状态统计
182    pub fn stats(&self) -> (usize, usize) {
183        let state = self.state.lock();
184        (state.current_running, state.pending_queue.len())
185    }
186
187    fn on_task_finished(&self, id: &K, task_tag: u64) {
188        let mut state = self.state.lock();
189        if let Some((existing_tag, _)) = state.tasks.get(id)
190            && *existing_tag == task_tag
191        {
192            state.tasks.remove(id);
193        }
194        state.current_running = state.current_running.saturating_sub(1);
195        self.try_spawn_next(&mut state);
196    }
197
198    fn try_spawn_next(&self, state: &mut State<K>) {
199        while state.current_running < state.max_concurrency
200            && let Some(queued) = state.pending_queue.pop_front()
201        {
202            state.current_running += 1;
203            (queued.task)();
204        }
205        if state.current_running == 0 && state.pending_queue.is_empty() {
206            self.idle_notify.notify_waiters();
207        }
208    }
209}