1extern crate alloc;
2use crate::{Executor, Handle, Task, WeakTask};
3use alloc::{collections::vec_deque::VecDeque, sync::Arc, vec::Vec};
4use core::ops::Range;
5use parking_lot::Mutex;
6
7#[derive(Debug)]
8pub struct TaskQueue<H: Handle> {
9 inner: Arc<Mutex<TaskQueueInner<H>>>,
10}
11impl<H: Handle> Clone for TaskQueue<H> {
12 fn clone(&self) -> Self {
13 Self {
14 inner: self.inner.clone(),
15 }
16 }
17}
18#[derive(Debug)]
19struct TaskQueueInner<H: Handle> {
20 running: VecDeque<(WeakTask, H)>,
21 waiting: VecDeque<Task>,
22}
23impl<H: Handle> TaskQueue<H> {
24 pub fn new(tasks: impl Iterator<Item = Range<u64>>) -> Self {
25 let waiting: VecDeque<_> = tasks.map(Task::from).collect();
26 Self {
27 inner: Arc::new(Mutex::new(TaskQueueInner {
28 running: VecDeque::with_capacity(waiting.len()),
29 waiting,
30 })),
31 }
32 }
33 pub fn add(&self, task: Task) {
34 let mut guard = self.inner.lock();
35 guard.waiting.push_back(task);
36 }
37 pub fn steal(&self, task: &mut Task, min_chunk_size: u64, max_speculative: usize) -> bool {
38 let min_chunk_size = min_chunk_size.max(1);
39 let mut guard = self.inner.lock();
40 while let Some(new_task) = guard.waiting.pop_front() {
41 if let Some(range) = new_task.take() {
42 task.set(range);
43 return true;
44 }
45 }
46 if let Some(steal_task) = guard
47 .running
48 .iter()
49 .filter_map(|w| w.0.upgrade())
50 .filter(|w| w != task)
51 .max_by_key(Task::remain)
52 {
53 if steal_task.remain() >= min_chunk_size * 2
54 && let Ok(Some(range)) = steal_task.split_two()
55 {
56 task.set(range);
57 true
58 } else if max_speculative > 1
59 && steal_task.remain() > 0
60 && steal_task.strong_count() <= max_speculative
61 {
62 task.state = steal_task.state;
63 true
64 } else {
65 false
66 }
67 } else {
68 false
69 }
70 }
71 #[must_use]
73 pub fn set_threads<E: Executor<Handle = H>>(
74 &self,
75 threads: usize,
76 min_chunk_size: u64,
77 executor: Option<&E>,
78 ) -> Option<()> {
79 #![allow(clippy::significant_drop_tightening)]
80 let min_chunk_size = min_chunk_size.max(1);
81 let mut guard = self.inner.lock();
82 guard.running.retain(|t| t.0.strong_count() > 0);
83 let len = guard.running.len();
84 if len < threads {
85 let executor = executor?;
86 let need = (threads - len).min(guard.waiting.len());
87 let mut temp = Vec::with_capacity(need);
88 let iter = guard.waiting.drain(..need);
89 for task in iter {
90 let weak = task.downgrade();
91 let handle = executor.execute(task, self.clone());
92 temp.push((weak, handle));
93 }
94 guard.running.extend(temp);
95 while guard.running.len() < threads
96 && let Some(steal_task) = guard
97 .running
98 .iter()
99 .filter_map(|w| w.0.upgrade())
100 .max_by_key(Task::remain)
101 && steal_task.remain() >= min_chunk_size * 2
102 && let Ok(Some(range)) = steal_task.split_two()
103 {
104 let task = Task::new(range);
105 let weak = task.downgrade();
106 let handle = executor.execute(task, self.clone());
107 guard.running.push_back((weak, handle));
108 }
109 } else if len > threads {
110 let mut temp = Vec::with_capacity(len - threads);
111 let iter = guard.running.drain(threads..);
112 for (task, mut handle) in iter {
113 if let Some(task) = task.upgrade() {
114 temp.push(task);
115 }
116 handle.abort();
117 }
118 guard.waiting.extend(temp);
119 }
120 Some(())
121 }
122 pub fn handles<F, R>(&self, f: F) -> R
123 where
124 F: FnOnce(&mut dyn Iterator<Item = &mut H>) -> R,
125 {
126 #![allow(clippy::significant_drop_tightening)]
127 let mut guard = self.inner.lock();
128 let mut iter = guard.running.iter_mut().map(|w| &mut w.1);
129 f(&mut iter)
130 }
131
132 pub fn cancel_task(&self, task: &Task, id: &H::Id) {
133 let mut guard = self.inner.lock();
134 for (weak, handle) in &mut guard.running {
135 if let Some(t) = weak.upgrade()
136 && t == *task
137 && !handle.is_self(id)
138 {
139 handle.abort();
140 }
141 }
142 }
143}
144
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    extern crate std;
    use crate::{Executor, Handle, Task, TaskQueue};
    use std::{collections::HashMap, dbg, println};
    use tokio::{sync::mpsc, task::AbortHandle};

    /// Test executor.
    ///
    /// Each task is spawned on the Tokio runtime, and every `(index, fib(index))`
    /// pair a worker commits is sent over `tx`.
    struct TokioExecutor {
        tx: mpsc::UnboundedSender<(u64, u64)>,
        /// Forwarded as `max_speculative` to `TaskQueue::steal`.
        speculative: usize,
    }

    /// Worker handle wrapping a Tokio `AbortHandle`.
    #[derive(Clone)]
    struct TokioHandle(AbortHandle);

    impl Handle for TokioHandle {
        type Output = ();
        type Id = ();
        fn abort(&mut self) -> Self::Output {
            self.0.abort();
        }
        fn is_self(&mut self, (): &Self::Id) -> bool {
            false
        }
    }

    impl Executor for TokioExecutor {
        type Handle = TokioHandle;
        fn execute(&self, mut task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
            println!("execute");
            let tx = self.tx.clone();
            let speculative = self.speculative;
            let handle = tokio::spawn(async move {
                loop {
                    while task.start() < task.end() {
                        let i = task.start();
                        let res = fib(i);
                        // Only report the value if this worker still owns index `i`.
                        // Otherwise another worker (split or speculative) got there
                        // first.
                        let Ok(_) = task.safe_add_start(i, 1) else {
                            println!("task-failed: {i} = {res}");
                            continue;
                        };
                        println!("task: {i} = {res}");
                        tx.send((i, res)).unwrap();
                    }
                    if !task_queue.steal(&mut task, 1, speculative) {
                        break;
                    }
                }
            });
            let abort_handle = handle.abort_handle();
            TokioHandle(abort_handle)
        }
    }

    /// Naive exponential-time Fibonacci: the workload executed by the workers.
    fn fib(n: u64) -> u64 {
        match n {
            0 => 0,
            1 => 1,
            _ => fib(n - 1) + fib(n - 2),
        }
    }

    /// Iterative Fibonacci, used as the reference value in assertions.
    fn fib_fast(n: u64) -> u64 {
        let mut a = 0;
        let mut b = 1;
        for _ in 0..n {
            (a, b) = (b, a + b);
        }
        a
    }

    /// Runs `fib` over the ranges `1..20` and `41..48` with 8 workers and the given
    /// speculation limit.
    ///
    /// Checks that every index is reported exactly once, that each value equals
    /// `fib_fast`, and that nothing outside the ranges is reported.
    async fn run_task_queue(speculative: usize) {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let executor = TokioExecutor { tx, speculative };
        let pre_data = [1..20, 41..48];
        let task_queue = TaskQueue::new(pre_data.iter().cloned());
        task_queue.set_threads(8, 1, Some(&executor)).unwrap();
        // Dropping the executor drops the original sender. `rx.recv()` then yields
        // `None` once every spawned worker (each owning a cloned sender) has exited.
        drop(executor);
        let mut data = HashMap::new();
        while let Some((i, res)) = rx.recv().await {
            println!("main: {i} = {res}");
            // Message (Chinese): "number {i}, value {res}, was computed twice".
            assert!(
                data.insert(i, res).is_none(),
                "数字 {i},值为 {res} 重复计算"
            );
        }
        dbg!(&data);
        for range in pre_data {
            for i in range {
                assert_eq!((i, data.get(&i)), (i, Some(&fib_fast(i))));
                data.remove(&i);
            }
        }
        assert_eq!(data.len(), 0);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_task_queue() {
        run_task_queue(1).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_task_queue2() {
        run_task_queue(2).await;
    }
}