1extern crate alloc;
2use crate::{Executor, Handle, Task, WeakTask};
3use alloc::{collections::vec_deque::VecDeque, sync::Arc, vec::Vec};
4use core::ops::Range;
5use parking_lot::Mutex;
6
/// Shared queue of waiting tasks and running workers.
///
/// Cloning a `TaskQueue` clones the inner `Arc`, so every clone refers to the
/// same mutex-protected state.
#[derive(Debug)]
pub struct TaskQueue<H: Handle> {
    /// Queue state shared by all clones; every method locks this mutex.
    inner: Arc<Mutex<TaskQueueInner<H>>>,
}
11impl<H: Handle> Clone for TaskQueue<H> {
12 fn clone(&self) -> Self {
13 Self {
14 inner: self.inner.clone(),
15 }
16 }
17}
/// State protected by the mutex inside [`TaskQueue`].
#[derive(Debug)]
struct TaskQueueInner<H: Handle> {
    /// One entry per started worker: a weak reference to the task it was last
    /// given, paired with the handle returned by the executor.
    running: VecDeque<(WeakTask, H)>,
    /// Tasks not currently assigned to a worker: the initial tasks, tasks
    /// passed to `add`, and tasks returned when `set_threads` aborts workers.
    waiting: VecDeque<Task>,
}
23impl<H: Handle> TaskQueue<H> {
24 pub fn new(tasks: impl Iterator<Item = Range<u64>>) -> Self {
25 let waiting: VecDeque<_> = tasks.map(Task::from).collect();
26 Self {
27 inner: Arc::new(Mutex::new(TaskQueueInner {
28 running: VecDeque::with_capacity(waiting.len()),
29 waiting,
30 })),
31 }
32 }
33 pub fn add(&self, task: Task) {
34 let mut guard = self.inner.lock();
35 guard.waiting.push_back(task);
36 }
37 pub fn steal(
38 &self,
39 id: &H::Id,
40 task: &mut Task,
41 min_chunk_size: u64,
42 max_speculative: usize,
43 ) -> bool {
44 let min_chunk_size = min_chunk_size.max(1);
45 let mut guard = self.inner.lock();
46 let mut worker_idx = None;
47 for (i, (_, handle)) in guard.running.iter_mut().enumerate() {
48 if handle.is_self(id) {
49 worker_idx = Some(i);
50 break;
51 }
52 }
53 let Some(worker_idx) = worker_idx else {
54 return false;
55 };
56 let mut found = false;
57 while let Some(new_task) = guard.waiting.pop_front() {
58 if let Some(range) = new_task.take() {
59 *task = Task::new(range);
60 found = true;
61 break;
62 }
63 }
64 if !found
65 && let Some(steal_task) = guard
66 .running
67 .iter()
68 .filter_map(|w| w.0.upgrade())
69 .filter(|w| w != task)
70 .max_by_key(Task::remain)
71 {
72 if steal_task.remain() >= min_chunk_size * 2
73 && let Ok(Some(range)) = steal_task.split_two()
74 {
75 *task = Task::new(range);
76 found = true;
77 } else if max_speculative > 1
78 && steal_task.remain() > 0
79 && steal_task.strong_count() <= max_speculative
80 {
81 task.state = steal_task.state;
82 found = true;
83 }
84 }
85 if found {
86 guard.running[worker_idx].0 = task.downgrade();
87 }
88 found
89 }
90 #[must_use]
92 pub fn set_threads<E: Executor<Handle = H>>(
93 &self,
94 threads: usize,
95 min_chunk_size: u64,
96 executor: Option<&E>,
97 ) -> Option<()> {
98 #![allow(clippy::significant_drop_tightening)]
99 let min_chunk_size = min_chunk_size.max(1);
100 let mut guard = self.inner.lock();
101 guard.running.retain(|t| t.0.strong_count() > 0);
102 let len = guard.running.len();
103 if len < threads {
104 let executor = executor?;
105 let need = (threads - len).min(guard.waiting.len());
106 let mut temp = Vec::with_capacity(need);
107 let iter = guard.waiting.drain(..need);
108 for task in iter {
109 let weak = task.downgrade();
110 let handle = executor.execute(task, self.clone());
111 temp.push((weak, handle));
112 }
113 guard.running.extend(temp);
114 while guard.running.len() < threads
115 && let Some(steal_task) = guard
116 .running
117 .iter()
118 .filter_map(|w| w.0.upgrade())
119 .max_by_key(Task::remain)
120 && steal_task.remain() >= min_chunk_size * 2
121 && let Ok(Some(range)) = steal_task.split_two()
122 {
123 let task = Task::new(range);
124 let weak = task.downgrade();
125 let handle = executor.execute(task, self.clone());
126 guard.running.push_back((weak, handle));
127 }
128 } else if len > threads {
129 let mut temp = Vec::with_capacity(len - threads);
130 let iter = guard.running.drain(threads..);
131 for (task, mut handle) in iter {
132 if let Some(task) = task.upgrade() {
133 temp.push(task);
134 }
135 handle.abort();
136 }
137 guard.waiting.extend(temp);
138 }
139 Some(())
140 }
141 pub fn handles<F, R>(&self, f: F) -> R
142 where
143 F: FnOnce(&mut dyn Iterator<Item = &mut H>) -> R,
144 {
145 #![allow(clippy::significant_drop_tightening)]
146 let mut guard = self.inner.lock();
147 let mut iter = guard.running.iter_mut().map(|w| &mut w.1);
148 f(&mut iter)
149 }
150
151 pub fn cancel_task(&self, task: &Task, id: &H::Id) {
152 let mut guard = self.inner.lock();
153 for (weak, handle) in &mut guard.running {
154 if let Some(t) = weak.upgrade()
155 && t == *task
156 && !handle.is_self(id)
157 {
158 handle.abort();
159 }
160 }
161 }
162}
163
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    extern crate std;
    use crate::{Executor, Handle, Task, TaskQueue};
    use std::{collections::HashMap, dbg, println};
    use tokio::{sync::mpsc, task::AbortHandle};

    /// Executor that spawns each task on the Tokio runtime and reports every
    /// computed `(index, value)` pair through an unbounded channel.
    struct TokioExecutor {
        tx: mpsc::UnboundedSender<(u64, u64)>,
        speculative: usize,
    }

    /// Worker handle wrapping the abort handle of the spawned Tokio task.
    #[derive(Clone)]
    struct TokioHandle(AbortHandle);

    impl Handle for TokioHandle {
        type Output = ();
        type Id = ();

        fn abort(&mut self) -> Self::Output {
            self.0.abort();
        }

        // NOTE(review): this always answers `false`, so `TaskQueue::steal`
        // returns early for every worker in these tests; all parallelism
        // comes from the splitting done in `set_threads`.
        fn is_self(&mut self, (): &Self::Id) -> bool {
            false
        }
    }

    impl Executor for TokioExecutor {
        type Handle = TokioHandle;

        fn execute(&self, mut task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
            println!("execute");
            let tx = self.tx.clone();
            let speculative = self.speculative;
            let join_handle = tokio::spawn(async move {
                loop {
                    // Work through the current range one index at a time.
                    while task.start() < task.end() {
                        let i = task.start();
                        let res = fib(i);
                        if task.safe_add_start(i, 1).is_err() {
                            println!("task-failed: {i} = {res}");
                            continue;
                        }
                        println!("task: {i} = {res}");
                        tx.send((i, res)).unwrap();
                    }
                    let got_more_work = task_queue.steal(&(), &mut task, 1, speculative);
                    if !got_more_work {
                        break;
                    }
                }
            });
            TokioHandle(join_handle.abort_handle())
        }
    }

    /// Deliberately slow recursive Fibonacci, used as the per-index workload.
    fn fib(n: u64) -> u64 {
        if n < 2 {
            n
        } else {
            fib(n - 1) + fib(n - 2)
        }
    }

    /// Iterative Fibonacci used to check the workers' results.
    fn fib_fast(n: u64) -> u64 {
        (0..n).fold((0, 1), |(a, b), _| (b, a + b)).0
    }

    /// Runs two ranges through a queue with eight workers and checks that
    /// every index is reported exactly once with the correct value.
    async fn run_and_verify(speculative: usize) {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let executor = TokioExecutor { tx, speculative };
        let pre_data = [1..20, 41..48];
        let task_queue = TaskQueue::new(pre_data.iter().cloned());
        task_queue.set_threads(8, 1, Some(&executor)).unwrap();
        // Drop the original sender so the channel closes once all workers
        // have finished.
        drop(executor);
        let mut data = HashMap::new();
        while let Some((i, res)) = rx.recv().await {
            println!("main: {i} = {res}");
            assert!(
                data.insert(i, res).is_none(),
                "数字 {i},值为 {res} 重复计算"
            );
        }
        dbg!(&data);
        for i in pre_data.into_iter().flatten() {
            assert_eq!((i, data.get(&i)), (i, Some(&fib_fast(i))));
            data.remove(&i);
        }
        // Nothing outside the requested ranges may have been reported.
        assert_eq!(data.len(), 0);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_task_queue() {
        run_and_verify(1).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_task_queue2() {
        run_and_verify(2).await;
    }
}