1extern crate alloc;
2use crate::{Executor, Handle, Task, WeakTask};
3use alloc::{collections::vec_deque::VecDeque, sync::Arc, vec::Vec};
4use core::ops::Range;
5use parking_lot::Mutex;
6
7#[derive(Debug)]
8pub struct TaskQueue<H: Handle> {
9 inner: Arc<Mutex<TaskQueueInner<H>>>,
10}
11impl<H: Handle> Clone for TaskQueue<H> {
12 fn clone(&self) -> Self {
13 Self {
14 inner: self.inner.clone(),
15 }
16 }
17}
18#[derive(Debug)]
19struct TaskQueueInner<H: Handle> {
20 running: VecDeque<(WeakTask, H)>,
21 waiting: VecDeque<Task>,
22}
23impl<H: Handle> TaskQueue<H> {
24 pub fn new<'a>(tasks: impl Iterator<Item = &'a Range<u64>>) -> Self {
25 let waiting: VecDeque<_> = tasks.map(Task::from).collect();
26 Self {
27 inner: Arc::new(Mutex::new(TaskQueueInner {
28 running: VecDeque::with_capacity(waiting.len()),
29 waiting,
30 })),
31 }
32 }
33 pub fn add(&self, task: Task) {
34 let mut guard = self.inner.lock();
35 guard.waiting.push_back(task);
36 }
37 pub fn steal(&self, task: &mut Task, min_chunk_size: u64, speculative: bool) -> bool {
38 let min_chunk_size = min_chunk_size.max(1);
39 let mut guard = self.inner.lock();
40 while let Some(new_task) = guard.waiting.pop_front() {
41 if let Some(range) = new_task.take() {
42 task.set(range);
43 return true;
44 }
45 }
46 if let Some(steal_task) = guard
47 .running
48 .iter()
49 .filter_map(|w| w.0.upgrade())
50 .filter(|w| w != task)
51 .max_by_key(|w| w.remain())
52 {
53 if steal_task.remain() >= min_chunk_size * 2
54 && let Ok(Some(range)) = steal_task.split_two()
55 {
56 task.set(range);
57 true
58 } else if speculative && steal_task.remain() > 0 {
59 task.state = steal_task.state.clone();
60 true
61 } else {
62 false
63 }
64 } else {
65 false
66 }
67 }
68 pub fn set_threads<E: Executor<Handle = H>>(
70 &self,
71 threads: usize,
72 min_chunk_size: u64,
73 executor: Option<&E>,
74 ) -> Option<()> {
75 let min_chunk_size = min_chunk_size.max(1);
76 let mut guard = self.inner.lock();
77 guard.running.retain(|t| t.0.strong_count() > 0);
78 let len = guard.running.len();
79 if len < threads {
80 let executor = executor?;
81 let need = (threads - len).min(guard.waiting.len());
82 let mut temp = Vec::with_capacity(need);
83 let iter = guard.waiting.drain(..need);
84 for task in iter {
85 let weak = task.downgrade();
86 let handle = executor.execute(task, self.clone());
87 temp.push((weak, handle));
88 }
89 guard.running.extend(temp);
90 while guard.running.len() < threads
91 && let Some(steal_task) = guard
92 .running
93 .iter()
94 .filter_map(|w| w.0.upgrade())
95 .max_by_key(|w| w.remain())
96 && steal_task.remain() >= min_chunk_size * 2
97 && let Ok(Some(range)) = steal_task.split_two()
98 {
99 let task = Task::new(range);
100 let weak = task.downgrade();
101 let handle = executor.execute(task, self.clone());
102 guard.running.push_back((weak, handle));
103 }
104 } else if len > threads {
105 let mut temp = Vec::with_capacity(len - threads);
106 let iter = guard.running.drain(threads..);
107 for (task, mut handle) in iter {
108 if let Some(task) = task.upgrade() {
109 temp.push(task);
110 }
111 handle.abort();
112 }
113 guard.waiting.extend(temp);
114 }
115 Some(())
116 }
117 pub fn handles<F, R>(&self, f: F) -> R
118 where
119 F: FnOnce(&mut dyn Iterator<Item = &mut H>) -> R,
120 {
121 let mut guard = self.inner.lock();
122 let mut iter = guard.running.iter_mut().map(|w| &mut w.1);
123 f(&mut iter)
124 }
125 pub fn cancel_tasks(&self, task: &Task) {
126 let mut guard = self.inner.lock();
127 for (weak, handle) in &mut guard.running {
128 if let Some(t) = weak.upgrade()
129 && t == *task
130 {
131 handle.abort();
132 }
133 }
134 }
135}
136
#[cfg(test)]
mod tests {
    extern crate std;
    use crate::{Executor, Handle, Task, TaskQueue};
    use std::{collections::HashMap, dbg, println};
    use tokio::{sync::mpsc, task::AbortHandle};

    /// Test executor that spawns each task on the tokio runtime and reports every
    /// computed `(i, fib(i))` pair through `tx`.
    pub struct TokioExecutor {
        tx: mpsc::UnboundedSender<(u64, u64)>,
        /// Forwarded to [`TaskQueue::steal`] when a worker runs out of work.
        speculative: bool,
    }

    /// [`Handle`] implementation wrapping a tokio [`AbortHandle`].
    #[derive(Clone)]
    pub struct TokioHandle(AbortHandle);

    impl Handle for TokioHandle {
        type Output = ();
        fn abort(&mut self) -> Self::Output {
            self.0.abort();
        }
    }

    impl Executor for TokioExecutor {
        type Handle = TokioHandle;
        fn execute(&self, mut task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
            println!("execute");
            let tx = self.tx.clone();
            let speculative = self.speculative;
            let handle = tokio::spawn(async move {
                loop {
                    while task.start() < task.end() {
                        let i = task.start();
                        let res = fib(i);
                        // Claim index `i`. If this fails (presumably because another worker
                        // sharing the state advanced `start` first), discard the result and
                        // re-read the range.
                        let Ok(_) = task.safe_add_start(i, 1) else {
                            println!("task-failed: {i} = {res}");
                            continue;
                        };
                        println!("task: {i} = {res}");
                        tx.send((i, res)).unwrap();
                    }
                    if !task_queue.steal(&mut task, 1, speculative) {
                        break;
                    }
                }
            });
            TokioHandle(handle.abort_handle())
        }
    }

    /// Naive exponential Fibonacci. It is presumably kept slow on purpose so that
    /// workers stay busy long enough for stealing to happen.
    fn fib(n: u64) -> u64 {
        match n {
            0 => 0,
            1 => 1,
            _ => fib(n - 1) + fib(n - 2),
        }
    }

    /// Iterative Fibonacci, used as the reference result.
    fn fib_fast(n: u64) -> u64 {
        let mut a = 0;
        let mut b = 1;
        for _ in 0..n {
            (a, b) = (b, a + b);
        }
        a
    }

    /// Runs the fixed ranges through a queue with 8 workers. It checks that every
    /// index is reported exactly once, with the correct value, and nothing extra.
    async fn run_case(speculative: bool) {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let executor = TokioExecutor { tx, speculative };
        let pre_data = [1..20, 41..48];
        let task_queue = TaskQueue::new(pre_data.iter());
        task_queue.set_threads(8, 1, Some(&executor));
        // Drop our sender so `rx` closes once every worker's sender clone is gone.
        drop(executor);
        let mut data = HashMap::new();
        while let Some((i, res)) = rx.recv().await {
            println!("main: {i} = {res}");
            if data.insert(i, res).is_some() {
                panic!("数字 {i},值为 {res} 重复计算");
            }
        }
        dbg!(&data);
        for range in pre_data {
            for i in range {
                assert_eq!((i, data.get(&i)), (i, Some(&fib_fast(i))));
                data.remove(&i);
            }
        }
        assert_eq!(data.len(), 0);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_task_queue() {
        run_case(false).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_task_queue2() {
        run_case(true).await;
    }
}