1extern crate alloc;
2use crate::{Executor, Handle, Task};
3use alloc::{collections::vec_deque::VecDeque, sync::Arc, vec::Vec};
4use core::ops::Range;
5use parking_lot::Mutex;
6
7#[derive(Debug)]
8pub struct TaskQueue<H: Handle> {
9 inner: Arc<Mutex<TaskQueueInner<H>>>,
10}
11impl<H: Handle> Clone for TaskQueue<H> {
12 fn clone(&self) -> Self {
13 Self {
14 inner: self.inner.clone(),
15 }
16 }
17}
18#[derive(Debug)]
19struct TaskQueueInner<H: Handle> {
20 running: VecDeque<(Task, H)>,
21 waiting: VecDeque<Task>,
22}
23impl<H: Handle> TaskQueue<H> {
24 pub fn new<'a>(tasks: impl Iterator<Item = &'a Range<u64>>) -> Self {
25 let waiting: VecDeque<_> = tasks.map(Task::from).collect();
26 Self {
27 inner: Arc::new(Mutex::new(TaskQueueInner {
28 running: VecDeque::with_capacity(waiting.len()),
29 waiting,
30 })),
31 }
32 }
33 pub fn finish_work(&self, task: &Task) -> usize {
35 let mut guard = self.inner.lock();
36 let len = guard.running.len();
37 guard.running.retain(|(t, _)| t != task);
38 len - guard.running.len()
39 }
40 pub fn cancel(&self, task: &Task) -> usize {
42 let mut guard = self.inner.lock();
43 let len = guard.running.len() + guard.waiting.len();
44 guard.running.retain(|(t, _)| t != task);
45 guard.waiting.retain(|t| t != task);
46 len - guard.running.len() - guard.waiting.len()
47 }
48 pub fn add(&self, task: Task) {
49 let mut guard = self.inner.lock();
50 guard.waiting.push_back(task);
51 }
52 pub fn steal(&self, task: &Task, min_chunk_size: u64) -> bool {
53 let min_chunk_size = min_chunk_size.max(1);
54 let mut guard = self.inner.lock();
55 while let Some(new_task) = guard.waiting.pop_front() {
56 if let Some(range) = new_task.take() {
57 task.set(range);
58 return true;
59 }
60 }
61 if let Some(steal_task) = guard.running.iter().map(|w| &w.0).max()
62 && steal_task.remain() >= min_chunk_size * 2
63 && let Ok(Some(range)) = steal_task.split_two()
64 {
65 task.set(range);
66 true
67 } else {
68 false
69 }
70 }
71 pub fn set_threads<E: Executor<Handle = H>>(
73 &self,
74 threads: usize,
75 min_chunk_size: u64,
76 executor: Option<&E>,
77 ) -> Option<()> {
78 let min_chunk_size = min_chunk_size.max(1);
79 let mut guard = self.inner.lock();
80 let len = guard.running.len();
81 if len < threads {
82 let executor = executor?;
83 let need = (threads - len).min(guard.waiting.len());
84 let mut temp = Vec::with_capacity(need);
85 let iter = guard.waiting.drain(..need);
86 for task in iter {
87 let handle = executor.execute(task.clone(), self.clone());
88 temp.push((task, handle));
89 }
90 guard.running.extend(temp);
91 while guard.running.len() < threads
92 && let Some(steal_task) = guard.running.iter().map(|w| &w.0).max()
93 && steal_task.remain() >= min_chunk_size * 2
94 && let Ok(Some(range)) = steal_task.split_two()
95 {
96 let task = Task::new(range);
97 let handle = executor.execute(task.clone(), self.clone());
98 guard.running.push_back((task, handle));
99 }
100 } else if len > threads {
101 let mut temp = Vec::with_capacity(len - threads);
102 let iter = guard.running.drain(threads..);
103 for (task, mut handle) in iter {
104 handle.abort();
105 temp.push(task);
106 }
107 guard.waiting.extend(temp);
108 }
109 Some(())
110 }
111 pub fn handles<F, R>(&self, f: F) -> R
112 where
113 F: FnOnce(&mut dyn Iterator<Item = &mut H>) -> R,
114 {
115 let mut guard = self.inner.lock();
116 let mut iter = guard.running.iter_mut().map(|w| &mut w.1);
117 f(&mut iter)
118 }
119}
120
#[cfg(test)]
mod tests {
    extern crate std;
    use crate::{Executor, Handle, Task, TaskQueue};
    use std::{collections::HashMap, dbg, println};
    use tokio::{sync::mpsc, task::AbortHandle};

    /// Test executor that runs each task as a tokio task and reports every computed
    /// `(index, value)` pair through an unbounded channel.
    pub struct TokioExecutor {
        tx: mpsc::UnboundedSender<(u64, u64)>,
    }

    /// Handle wrapping the tokio `AbortHandle` of a spawned worker.
    #[derive(Clone)]
    pub struct TokioHandle(AbortHandle);

    impl Handle for TokioHandle {
        type Output = ();

        /// Forwards to the wrapped tokio abort handle.
        fn abort(&mut self) -> Self::Output {
            self.0.abort();
        }
    }

    impl Executor for TokioExecutor {
        type Handle = TokioHandle;

        /// Spawns a worker that processes `task` one index at a time, asks the queue
        /// for more work whenever its range is exhausted, and finally deregisters
        /// itself from the queue.
        fn execute(&self, task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
            println!("execute");
            let tx = self.tx.clone();
            let worker = async move {
                let mut has_work = true;
                while has_work {
                    while task.start() < task.end() {
                        let i = task.start();
                        task.fetch_add_start(1).unwrap();
                        let res = fib(i);
                        println!("task: {i} = {res}");
                        tx.send((i, res)).unwrap();
                    }
                    // Current range is exhausted: try to pick up more work.
                    has_work = task_queue.steal(&task, 1);
                }
                assert_eq!(task_queue.finish_work(&task), 1);
            };
            // The join handle is dropped right away; only the abort handle is kept.
            TokioHandle(tokio::spawn(worker).abort_handle())
        }
    }

    /// Deliberately slow recursive Fibonacci, used as the per-index workload.
    fn fib(n: u64) -> u64 {
        if n < 2 {
            n
        } else {
            fib(n - 1) + fib(n - 2)
        }
    }

    /// Iterative Fibonacci used to compute the expected values.
    fn fib_fast(n: u64) -> u64 {
        let (nth, _next) = (0..n).fold((0u64, 1u64), |(a, b), _| (b, a + b));
        nth
    }

    /// Every index in the input ranges must be computed exactly once and with the
    /// correct value, and nothing outside those ranges may be reported.
    #[tokio::test]
    async fn test_task_queue() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let executor = TokioExecutor { tx };
        let pre_data = [1..20, 41..48];
        let task_queue = TaskQueue::new(pre_data.iter());
        task_queue.set_threads(8, 1, Some(&executor));
        // Drop the executor's sender so the channel closes once all workers finish.
        drop(executor);

        let mut data = HashMap::new();
        while let Some((i, res)) = rx.recv().await {
            println!("main: {i} = {res}");
            assert!(
                data.insert(i, res).is_none(),
                "数字 {i},值为 {res} 重复计算"
            );
        }
        dbg!(&data);

        for i in pre_data.iter().cloned().flatten() {
            assert_eq!((i, data.get(&i)), (i, Some(&fib_fast(i))));
            data.remove(&i);
        }
        assert_eq!(data.len(), 0);
    }
}