1extern crate alloc;
2use crate::{Executor, Handle, Task, WeakTask};
3use alloc::{collections::vec_deque::VecDeque, sync::Arc, vec::Vec};
4use core::ops::Range;
5use parking_lot::Mutex;
6
7#[derive(Debug)]
8pub struct TaskQueue<H: Handle> {
9 inner: Arc<Mutex<TaskQueueInner<H>>>,
10}
11impl<H: Handle> Clone for TaskQueue<H> {
12 fn clone(&self) -> Self {
13 Self {
14 inner: self.inner.clone(),
15 }
16 }
17}
18#[derive(Debug)]
19struct TaskQueueInner<H: Handle> {
20 running: VecDeque<(WeakTask, H)>,
21 waiting: VecDeque<Task>,
22}
23impl<H: Handle> TaskQueue<H> {
24 pub fn new(tasks: impl Iterator<Item = Range<u64>>) -> Self {
25 let waiting: VecDeque<_> = tasks.map(Task::from).collect();
26 Self {
27 inner: Arc::new(Mutex::new(TaskQueueInner {
28 running: VecDeque::with_capacity(waiting.len()),
29 waiting,
30 })),
31 }
32 }
33 pub fn add(&self, task: Task) {
34 let mut guard = self.inner.lock();
35 guard.waiting.push_back(task);
36 }
37 pub fn steal(&self, task: &mut Task, min_chunk_size: u64, max_speculative: usize) -> bool {
38 let min_chunk_size = min_chunk_size.max(1);
39 let mut guard = self.inner.lock();
40 while let Some(new_task) = guard.waiting.pop_front() {
41 if let Some(range) = new_task.take() {
42 task.set(range);
43 return true;
44 }
45 }
46 if let Some(steal_task) = guard
47 .running
48 .iter()
49 .filter_map(|w| w.0.upgrade())
50 .filter(|w| w != task)
51 .max_by_key(|w| w.remain())
52 {
53 if steal_task.remain() >= min_chunk_size * 2
54 && let Ok(Some(range)) = steal_task.split_two()
55 {
56 task.set(range);
57 true
58 } else if max_speculative > 1
59 && steal_task.remain() > 0
60 && steal_task.strong_count() <= max_speculative
61 {
62 task.state = steal_task.state.clone();
63 true
64 } else {
65 false
66 }
67 } else {
68 false
69 }
70 }
71 pub fn set_threads<E: Executor<Handle = H>>(
73 &self,
74 threads: usize,
75 min_chunk_size: u64,
76 executor: Option<&E>,
77 ) -> Option<()> {
78 let min_chunk_size = min_chunk_size.max(1);
79 let mut guard = self.inner.lock();
80 guard.running.retain(|t| t.0.strong_count() > 0);
81 let len = guard.running.len();
82 if len < threads {
83 let executor = executor?;
84 let need = (threads - len).min(guard.waiting.len());
85 let mut temp = Vec::with_capacity(need);
86 let iter = guard.waiting.drain(..need);
87 for task in iter {
88 let weak = task.downgrade();
89 let handle = executor.execute(task, self.clone());
90 temp.push((weak, handle));
91 }
92 guard.running.extend(temp);
93 while guard.running.len() < threads
94 && let Some(steal_task) = guard
95 .running
96 .iter()
97 .filter_map(|w| w.0.upgrade())
98 .max_by_key(|w| w.remain())
99 && steal_task.remain() >= min_chunk_size * 2
100 && let Ok(Some(range)) = steal_task.split_two()
101 {
102 let task = Task::new(range);
103 let weak = task.downgrade();
104 let handle = executor.execute(task, self.clone());
105 guard.running.push_back((weak, handle));
106 }
107 } else if len > threads {
108 let mut temp = Vec::with_capacity(len - threads);
109 let iter = guard.running.drain(threads..);
110 for (task, mut handle) in iter {
111 if let Some(task) = task.upgrade() {
112 temp.push(task);
113 }
114 handle.abort();
115 }
116 guard.waiting.extend(temp);
117 }
118 Some(())
119 }
120 pub fn handles<F, R>(&self, f: F) -> R
121 where
122 F: FnOnce(&mut dyn Iterator<Item = &mut H>) -> R,
123 {
124 let mut guard = self.inner.lock();
125 let mut iter = guard.running.iter_mut().map(|w| &mut w.1);
126 f(&mut iter)
127 }
128
129 pub fn cancel_task(&self, task: &Task, id: &H::Id) {
130 let mut guard = self.inner.lock();
131 for (weak, handle) in &mut guard.running {
132 if let Some(t) = weak.upgrade()
133 && t == *task
134 && !handle.is_self(id)
135 {
136 handle.abort();
137 }
138 }
139 }
140}
141
#[cfg(test)]
mod tests {
    extern crate std;
    use crate::{Executor, Handle, Task, TaskQueue};
    use std::{collections::HashMap, dbg, println};
    use tokio::{sync::mpsc, task::AbortHandle};

    /// Test executor that runs each task on the tokio runtime and sends every
    /// computed `(index, fib(index))` pair over `tx`.
    pub struct TokioExecutor {
        tx: mpsc::UnboundedSender<(u64, u64)>,
        /// Passed to `TaskQueue::steal` as `max_speculative`.
        speculative: usize,
    }

    /// Handle for a spawned worker, wrapping its tokio `AbortHandle`.
    #[derive(Clone)]
    pub struct TokioHandle(AbortHandle);

    impl Handle for TokioHandle {
        type Output = ();
        type Id = ();
        fn abort(&mut self) -> Self::Output {
            self.0.abort();
        }
        // Never matches, so `cancel_task` never exempts any worker in these tests.
        fn is_self(&mut self, _: &Self::Id) -> bool {
            false
        }
    }

    impl Executor for TokioExecutor {
        type Handle = TokioHandle;
        fn execute(&self, mut task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
            println!("execute");
            let tx = self.tx.clone();
            let speculative = self.speculative;
            let handle = tokio::spawn(async move {
                loop {
                    while task.start() < task.end() {
                        let i = task.start();
                        let res = fib(i);
                        // If advancing `start` from `i` fails, presumably another worker
                        // (split or speculative) already claimed `i`; re-read `start`.
                        let Ok(_) = task.safe_add_start(i, 1) else {
                            println!("task-failed: {i} = {res}");
                            continue;
                        };
                        println!("task: {i} = {res}");
                        tx.send((i, res)).unwrap();
                    }
                    // Current range is exhausted: ask the queue for more work or stop.
                    if !task_queue.steal(&mut task, 1, speculative) {
                        break;
                    }
                }
            });
            let abort_handle = handle.abort_handle();
            TokioHandle(abort_handle)
        }
    }

    /// Naive exponential-time Fibonacci, presumably chosen so each index costs
    /// noticeable CPU time.
    fn fib(n: u64) -> u64 {
        match n {
            0 => 0,
            1 => 1,
            _ => fib(n - 1) + fib(n - 2),
        }
    }

    /// Linear-time Fibonacci used as the reference when checking results.
    fn fib_fast(n: u64) -> u64 {
        let mut a = 0;
        let mut b = 1;
        for _ in 0..n {
            (a, b) = (b, a + b);
        }
        a
    }

    /// Shared scenario for both tests.
    ///
    /// Runs the queue on two index ranges with 8 workers. Every index must be
    /// reported exactly once with the correct value, and nothing else may be
    /// reported.
    async fn run_fib_queue(speculative: usize) {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let executor = TokioExecutor { tx, speculative };
        let pre_data = [1..20, 41..48];
        let task_queue = TaskQueue::new(pre_data.iter().cloned());
        task_queue.set_threads(8, 1, Some(&executor));
        // Drop the executor's sender so `recv` returns `None` once all workers finish.
        drop(executor);
        let mut data = HashMap::new();
        while let Some((i, res)) = rx.recv().await {
            println!("main: {i} = {res}");
            // Message means "number {i} with value {res} was computed twice".
            if data.insert(i, res).is_some() {
                panic!("数字 {i},值为 {res} 重复计算");
            }
        }
        dbg!(&data);
        for range in pre_data {
            for i in range {
                assert_eq!((i, data.get(&i)), (i, Some(&fib_fast(i))));
                data.remove(&i);
            }
        }
        // Nothing outside the requested ranges may have been reported.
        assert_eq!(data.len(), 0);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_task_queue() {
        run_fib_queue(1).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_task_queue2() {
        run_fib_queue(2).await;
    }
}