fast_down_gui/core/
task.rs1use parking_lot::Mutex;
2use std::{
3 collections::{HashMap, VecDeque},
4 future::Future,
5 hash::Hash,
6 sync::Arc,
7};
8use tokio::sync::watch;
9use tokio_util::sync::CancellationToken;
10
/// A task parked in the pending queue until a concurrency slot frees up.
struct QueuedTask<K> {
    /// Key the task was registered under; used to locate it again on cancel/replace.
    id: K,
    /// Launches the task's future when invoked (consumed on first call).
    task: Box<dyn FnOnce() + Send>,
}
15
16struct State<K> {
17 max_concurrency: usize,
18 current_running: usize,
19 pending_queue: VecDeque<QueuedTask<K>>,
20 tasks: HashMap<K, (u64, CancellationToken)>,
21 next_tag: u64,
22}
23
24impl<K> Drop for State<K> {
25 fn drop(&mut self) {
26 for queued in self.pending_queue.drain(..) {
27 (queued.task)();
28 }
29 }
30}
31
32#[derive(Clone)]
33pub struct TaskSet<K> {
34 state: Arc<Mutex<State<K>>>,
35 idle_tx: Arc<watch::Sender<usize>>,
36}
37
38struct TaskGuard<K: Clone + Eq + Hash + Send + 'static> {
39 this: TaskSet<K>,
40 id: K,
41 tag: u64,
42}
43
44impl<K: Clone + Eq + Hash + Send + 'static> Drop for TaskGuard<K> {
45 fn drop(&mut self) {
46 self.this.on_task_finished(&self.id, self.tag);
47 }
48}
49
50impl<K: Clone + Eq + Hash + Send + 'static> TaskSet<K> {
51 pub fn new(max_concurrency: usize) -> Self {
52 let (tx, _rx) = watch::channel(0);
53 Self {
54 state: Arc::new(Mutex::new(State {
55 max_concurrency,
56 current_running: 0,
57 pending_queue: VecDeque::new(),
58 tasks: HashMap::new(),
59 next_tag: 0,
60 })),
61 idle_tx: Arc::new(tx),
62 }
63 }
64
65 pub fn add_task<F>(&self, id: K, cancel_token: CancellationToken, fut: F)
67 where
68 F: Future<Output = ()> + Send + 'static,
69 {
70 let mut state = self.state.lock();
71 let tag = {
72 state.next_tag += 1;
73 state.next_tag - 1
74 };
75
76 if let Some((_, old_token)) = state.tasks.insert(id.clone(), (tag, cancel_token.clone())) {
77 old_token.cancel();
78 if let Some(pos) = state.pending_queue.iter().position(|q| q.id == id) {
79 let queued = state.pending_queue.remove(pos).unwrap();
80 state.current_running += 1;
81 (queued.task)();
82 }
83 };
84
85 let wrapped_fn = {
86 let weak_state = Arc::downgrade(&self.state);
87 let weak_tx = Arc::downgrade(&self.idle_tx);
88 let id = id.clone();
89 move || match (weak_state.upgrade(), weak_tx.upgrade()) {
90 (Some(state), Some(idle_tx)) => {
91 let this = TaskSet { state, idle_tx };
92 tokio::spawn(async move {
93 let _guard = TaskGuard { this, id, tag };
94 fut.await;
95 });
96 }
97 _ => {
98 tokio::spawn(fut);
99 }
100 }
101 };
102
103 if state.current_running < state.max_concurrency {
104 state.current_running += 1;
105 wrapped_fn();
106 } else {
107 state.pending_queue.push_back(QueuedTask {
108 id,
109 task: Box::new(wrapped_fn),
110 });
111 }
112 }
113
114 pub fn cancel_task(&self, id: &K) {
116 let mut state = self.state.lock();
117 if let Some(entry) = state.tasks.remove(id) {
118 entry.1.cancel();
119 }
120 if let Some(pos) = state.pending_queue.iter().position(|q| q.id == *id) {
121 let queued = state.pending_queue.remove(pos).unwrap();
122 state.current_running += 1;
123 (queued.task)();
124 }
125 self.try_spawn_next(&mut state);
126 }
127
128 pub fn cancel_all(&self) {
130 let mut state = self.state.lock();
131 for (_, (_, token)) in state.tasks.drain() {
132 token.cancel();
133 }
134 while let Some(queued) = state.pending_queue.pop_front() {
135 state.current_running += 1;
136 (queued.task)();
137 }
138 self.try_spawn_next(&mut state);
139 }
140
141 pub fn join(&self) -> impl Future<Output = ()> {
143 let state = self.state.clone();
144 let mut rx = self.idle_tx.subscribe();
145 async move {
146 loop {
147 {
148 let s = state.lock();
149 if s.current_running == 0 && s.pending_queue.is_empty() {
150 return;
151 }
152 }
153 let _ = rx.changed().await;
154 }
155 }
156 }
157
158 pub fn wait_last(&self) -> impl Future<Output = ()> {
160 let state = self.state.clone();
161 let mut rx = self.idle_tx.subscribe();
162 async move {
163 let baseline = {
164 let s = state.lock();
165 if s.current_running == 0 && s.pending_queue.is_empty() {
166 Some(s.next_tag)
167 } else {
168 None
169 }
170 };
171 loop {
172 {
173 let s = state.lock();
174 if s.current_running == 0 && s.pending_queue.is_empty() {
175 match baseline {
176 Some(tag) => {
177 if s.next_tag > tag {
178 return;
179 }
180 }
181 None => return,
182 }
183 }
184 }
185 let _ = rx.changed().await;
186 }
187 }
188 }
189
190 pub fn set_concurrency(&self, new_max: usize) {
192 let mut state = self.state.lock();
193 state.max_concurrency = new_max;
194 self.try_spawn_next(&mut state);
195 }
196
197 pub fn stats(&self) -> (usize, usize) {
199 let state = self.state.lock();
200 (state.current_running, state.pending_queue.len())
201 }
202
203 fn on_task_finished(&self, id: &K, task_tag: u64) {
204 let mut state = self.state.lock();
205 if let Some((existing_tag, _)) = state.tasks.get(id)
206 && *existing_tag == task_tag
207 {
208 state.tasks.remove(id);
209 }
210 state.current_running = state.current_running.saturating_sub(1);
211 self.try_spawn_next(&mut state);
212 }
213
214 fn try_spawn_next(&self, state: &mut State<K>) {
215 while state.current_running < state.max_concurrency
216 && let Some(queued) = state.pending_queue.pop_front()
217 {
218 state.current_running += 1;
219 (queued.task)();
220 }
221 if state.current_running == 0 && state.pending_queue.is_empty() {
222 self.idle_tx.send_modify(|v| *v = v.wrapping_add(1));
223 }
224 }
225}