1extern crate alloc;
2use crate::{Executor, Handle, Task};
3use alloc::{collections::vec_deque::VecDeque, sync::Arc, vec::Vec};
4use core::ops::Range;
5use parking_lot::Mutex;
6
7#[derive(Debug)]
8pub struct TaskQueue<H: Handle> {
9 inner: Arc<Mutex<TaskQueueInner<H>>>,
10}
11impl<H: Handle> Clone for TaskQueue<H> {
12 fn clone(&self) -> Self {
13 Self {
14 inner: self.inner.clone(),
15 }
16 }
17}
18#[derive(Debug)]
19struct TaskQueueInner<H: Handle> {
20 running: VecDeque<(Task, H)>,
21 waiting: VecDeque<Task>,
22}
23impl<H: Handle> TaskQueue<H> {
24 pub fn new<'a>(tasks: impl Iterator<Item = &'a Range<u64>>) -> Self {
25 let waiting: VecDeque<_> = tasks.map(Task::from).collect();
26 Self {
27 inner: Arc::new(Mutex::new(TaskQueueInner {
28 running: VecDeque::with_capacity(waiting.len()),
29 waiting,
30 })),
31 }
32 }
33 pub fn finish_work(&self, task: &Task) -> usize {
35 let mut guard = self.inner.lock();
36 let len = guard.running.len();
37 guard.running.retain(|(t, _)| t != task);
38 len - guard.running.len()
39 }
40 pub fn cancel(&self, task: &Task) -> usize {
42 let mut guard = self.inner.lock();
43 let len = guard.running.len() + guard.waiting.len();
44 guard.running.retain(|(t, _)| t != task);
45 guard.waiting.retain(|t| t != task);
46 len - guard.running.len() - guard.waiting.len()
47 }
48 pub fn add(&self, task: Task) {
49 let mut guard = self.inner.lock();
50 guard.waiting.push_back(task);
51 }
52 pub fn steal(&self, task: &Task, min_chunk_size: u64) -> bool {
53 let min_chunk_size = min_chunk_size.max(1);
54 let mut guard = self.inner.lock();
55 while let Some(new_task) = guard.waiting.pop_front() {
56 if let Some(range) = new_task.take() {
57 task.set(range);
58 return true;
59 }
60 }
61 if let Some(steal_task) = guard.running.iter().map(|w| &w.0).max()
62 && steal_task.remain() >= min_chunk_size * 2
63 && let Ok(Some(range)) = steal_task.split_two()
64 {
65 task.set(range);
66 true
67 } else {
68 false
69 }
70 }
71 pub fn set_threads<E: Executor<Handle = H>>(
73 &self,
74 threads: usize,
75 min_chunk_size: u64,
76 executor: Option<&E>,
77 ) -> Option<()> {
78 let threads = threads.max(1);
79 let min_chunk_size = min_chunk_size.max(1);
80 let mut guard = self.inner.lock();
81 let len = guard.running.len();
82 if len < threads {
83 let executor = executor?;
84 let need = (threads - len).min(guard.waiting.len());
85 let mut temp = Vec::with_capacity(need);
86 let iter = guard.waiting.drain(..need);
87 for task in iter {
88 let handle = executor.execute(task.clone(), self.clone());
89 temp.push((task, handle));
90 }
91 guard.running.extend(temp);
92 while guard.running.len() < threads
93 && let Some(steal_task) = guard.running.iter().map(|w| &w.0).max()
94 && steal_task.remain() >= min_chunk_size * 2
95 && let Ok(Some(range)) = steal_task.split_two()
96 {
97 let task = Task::new(range);
98 let handle = executor.execute(task.clone(), self.clone());
99 guard.running.push_back((task, handle));
100 }
101 } else if len > threads {
102 let mut temp = Vec::with_capacity(len - threads);
103 let iter = guard.running.drain(threads..);
104 for (task, mut handle) in iter {
105 handle.abort();
106 temp.push(task);
107 }
108 guard.waiting.extend(temp);
109 }
110 Some(())
111 }
112 pub fn handles<F, R>(&self, f: F) -> R
113 where
114 F: FnOnce(&mut dyn Iterator<Item = &mut H>) -> R,
115 {
116 let mut guard = self.inner.lock();
117 let mut iter = guard.running.iter_mut().map(|w| &mut w.1);
118 f(&mut iter)
119 }
120}
121
#[cfg(test)]
mod tests {
    extern crate std;
    use crate::{Executor, Handle, Task, TaskQueue};
    use std::{collections::HashMap, dbg, println};
    use tokio::{sync::mpsc, task::AbortHandle};

    /// Test executor: spawns every task on the Tokio runtime and reports
    /// `(index, fib(index))` pairs through an unbounded channel.
    pub struct TokioExecutor {
        tx: mpsc::UnboundedSender<(u64, u64)>,
    }

    /// Handle returned by `TokioExecutor`; aborting it aborts the spawned
    /// Tokio task.
    #[derive(Clone)]
    pub struct TokioHandle(AbortHandle);

    impl Handle for TokioHandle {
        type Output = ();
        fn abort(&mut self) -> Self::Output {
            self.0.abort();
        }
    }

    impl Executor for TokioExecutor {
        type Handle = TokioHandle;
        fn execute(&self, task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
            println!("execute");
            let sender = self.tx.clone();
            let join = tokio::spawn(async move {
                loop {
                    // Work through the currently assigned range one index at a time.
                    while task.start() < task.end() {
                        let i = task.start();
                        task.fetch_add_start(1).unwrap();
                        let res = fib(i);
                        println!("task: {i} = {res}");
                        sender.send((i, res)).unwrap();
                    }
                    // Range exhausted: ask the queue for more work, stop if there is none.
                    if !task_queue.steal(&task, 1) {
                        break;
                    }
                }
                assert_eq!(task_queue.finish_work(&task), 1);
            });
            TokioHandle(join.abort_handle())
        }
    }

    /// Deliberately naive recursive Fibonacci, used as the per-index workload.
    fn fib(n: u64) -> u64 {
        if n < 2 {
            return n;
        }
        fib(n - 1) + fib(n - 2)
    }

    /// Iterative Fibonacci used to compute the expected results.
    fn fib_fast(n: u64) -> u64 {
        (0..n).fold((0, 1), |(a, b), _| (b, a + b)).0
    }

    /// Runs two ranges through the queue with up to eight workers and checks
    /// that every index is computed exactly once with the right value.
    #[tokio::test]
    async fn test_task_queue() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let executor = TokioExecutor { tx };
        let pre_data = [1..20, 41..48];
        let task_queue = TaskQueue::new(pre_data.iter());
        task_queue.set_threads(8, 1, Some(&executor));
        // Drop the executor's sender so the channel closes once all workers finish.
        drop(executor);
        let mut data = HashMap::new();
        while let Some((i, res)) = rx.recv().await {
            println!("main: {i} = {res}");
            if data.insert(i, res).is_some() {
                panic!("数字 {i},值为 {res} 重复计算");
            }
        }
        dbg!(&data);
        for i in pre_data.into_iter().flatten() {
            assert_eq!((i, data.get(&i)), (i, Some(&fib_fast(i))));
            data.remove(&i);
        }
        // Nothing outside the requested ranges may have been computed.
        assert_eq!(data.len(), 0);
    }
}