fast_down_gui/core/
task.rs1use parking_lot::Mutex;
2use std::{
3 collections::{HashMap, VecDeque},
4 future::Future,
5 hash::Hash,
6 sync::Arc,
7};
8use tokio::sync::Notify;
9use tokio_util::sync::CancellationToken;
10
/// A task that is waiting in the queue for a free concurrency slot.
struct QueuedTask<K> {
    /// Id the task was registered under; used to find it again on cancel/replace.
    id: K,
    /// Deferred start action. Invoking it (at most once) spawns the task's future.
    task: Box<dyn FnOnce() + Send + 'static>,
}
15
16struct State<K> {
17 max_concurrency: usize,
18 current_running: usize,
19 pending_queue: VecDeque<QueuedTask<K>>,
20 tasks: HashMap<K, (u64, CancellationToken)>,
21 next_tag: u64,
22}
23
24impl<K> Drop for State<K> {
25 fn drop(&mut self) {
26 for queued in self.pending_queue.drain(..) {
27 (queued.task)();
28 }
29 }
30}
31
32#[derive(Clone)]
33pub struct TaskSet<K> {
34 state: Arc<Mutex<State<K>>>,
35 idle_notify: Arc<Notify>,
36}
37
38struct TaskGuard<K: Clone + Eq + Hash + Send + 'static> {
39 this: TaskSet<K>,
40 id: K,
41 tag: u64,
42}
43
44impl<K: Clone + Eq + Hash + Send + 'static> Drop for TaskGuard<K> {
45 fn drop(&mut self) {
46 self.this.on_task_finished(&self.id, self.tag);
47 }
48}
49
50impl<K: Clone + Eq + Hash + Send + 'static> TaskSet<K> {
51 pub fn new(max_concurrency: usize) -> Self {
52 Self {
53 state: Arc::new(Mutex::new(State {
54 max_concurrency,
55 current_running: 0,
56 pending_queue: VecDeque::new(),
57 tasks: HashMap::new(),
58 next_tag: 0,
59 })),
60 idle_notify: Arc::new(Notify::new()),
61 }
62 }
63
64 pub fn add_task<F>(&self, id: K, cancel_token: CancellationToken, fut: F)
66 where
67 F: Future<Output = ()> + Send + 'static,
68 {
69 let mut state = self.state.lock();
70 let tag = {
71 state.next_tag += 1;
72 state.next_tag - 1
73 };
74
75 if let Some((_, old_token)) = state.tasks.insert(id.clone(), (tag, cancel_token.clone())) {
76 old_token.cancel();
77 let mut i = 0;
78 while i < state.pending_queue.len() {
79 if state.pending_queue[i].id == id {
80 let queued = state.pending_queue.remove(i).unwrap();
81 state.current_running += 1;
82 (queued.task)();
83 break;
84 } else {
85 i += 1;
86 }
87 }
88 };
89
90 let wrapped_fn = {
91 let weak_state = Arc::downgrade(&self.state);
92 let weak_notify = Arc::downgrade(&self.idle_notify);
93 let id = id.clone();
94 move || match (weak_state.upgrade(), weak_notify.upgrade()) {
95 (Some(state), Some(idle_notify)) => {
96 let this = TaskSet { state, idle_notify };
97 tokio::spawn(async move {
98 let _guard = TaskGuard { this, id, tag };
99 fut.await;
100 });
101 }
102 _ => {
103 tokio::spawn(fut);
104 }
105 }
106 };
107
108 if state.current_running < state.max_concurrency {
109 state.current_running += 1;
110 wrapped_fn();
111 } else {
112 state.pending_queue.push_back(QueuedTask {
113 id,
114 task: Box::new(wrapped_fn),
115 });
116 }
117 }
118
119 pub fn cancel_task(&self, id: &K) {
121 let mut state = self.state.lock();
122 if let Some(entry) = state.tasks.remove(id) {
123 entry.1.cancel();
124 }
125
126 let mut i = 0;
127 while i < state.pending_queue.len() {
128 if state.pending_queue[i].id == *id {
129 let queued = state.pending_queue.remove(i).unwrap();
130 state.current_running += 1;
131 (queued.task)();
132 break;
133 } else {
134 i += 1;
135 }
136 }
137
138 self.try_spawn_next(&mut state);
139 }
140
141 pub fn cancel_all(&self) {
143 let mut state = self.state.lock();
144 for (_, (_, token)) in state.tasks.drain() {
145 token.cancel();
146 }
147
148 while let Some(queued) = state.pending_queue.pop_front() {
149 state.current_running += 1;
150 (queued.task)();
151 }
152
153 self.try_spawn_next(&mut state);
154 }
155
156 pub fn join(&self) -> impl Future<Output = ()> {
158 let state = self.state.clone();
159 let notify = self.idle_notify.clone();
160 async move {
161 loop {
162 {
163 let state = state.lock();
164 if state.current_running == 0 && state.pending_queue.is_empty() {
165 return;
166 }
167 notify.notified()
168 }
169 .await;
170 }
171 }
172 }
173
174 pub fn set_concurrency(&self, new_max: usize) {
176 let mut state = self.state.lock();
177 state.max_concurrency = new_max;
178 self.try_spawn_next(&mut state);
179 }
180
181 pub fn stats(&self) -> (usize, usize) {
183 let state = self.state.lock();
184 (state.current_running, state.pending_queue.len())
185 }
186
187 fn on_task_finished(&self, id: &K, task_tag: u64) {
188 let mut state = self.state.lock();
189 if let Some((existing_tag, _)) = state.tasks.get(id)
190 && *existing_tag == task_tag
191 {
192 state.tasks.remove(id);
193 }
194 state.current_running = state.current_running.saturating_sub(1);
195 self.try_spawn_next(&mut state);
196 }
197
198 fn try_spawn_next(&self, state: &mut State<K>) {
199 while state.current_running < state.max_concurrency
200 && let Some(queued) = state.pending_queue.pop_front()
201 {
202 state.current_running += 1;
203 (queued.task)();
204 }
205 if state.current_running == 0 && state.pending_queue.is_empty() {
206 self.idle_notify.notify_waiters();
207 }
208 }
209}