Skip to main content

fast_steal/
task_queue.rs

1extern crate alloc;
2use crate::{Executor, Handle, Task};
3use alloc::{collections::vec_deque::VecDeque, sync::Arc, vec::Vec};
4use core::ops::Range;
5use parking_lot::Mutex;
6
7#[derive(Debug)]
8pub struct TaskQueue<H: Handle> {
9    inner: Arc<Mutex<TaskQueueInner<H>>>,
10}
11impl<H: Handle> Clone for TaskQueue<H> {
12    fn clone(&self) -> Self {
13        Self {
14            inner: self.inner.clone(),
15        }
16    }
17}
18#[derive(Debug)]
19struct TaskQueueInner<H: Handle> {
20    running: VecDeque<(Task, H)>,
21    waiting: VecDeque<Task>,
22}
23impl<H: Handle> TaskQueue<H> {
24    pub fn new<'a>(tasks: impl Iterator<Item = &'a Range<u64>>) -> Self {
25        let waiting: VecDeque<_> = tasks.map(Task::from).collect();
26        Self {
27            inner: Arc::new(Mutex::new(TaskQueueInner {
28                running: VecDeque::with_capacity(waiting.len()),
29                waiting,
30            })),
31        }
32    }
33    /// 仅供 Worker 线程在任务完成后调用
34    pub fn finish_work(&self, task: &Task) -> usize {
35        let mut guard = self.inner.lock();
36        let len = guard.running.len();
37        guard.running.retain(|(t, _)| t != task);
38        len - guard.running.len()
39    }
40    /// 用于从外部取消任务
41    pub fn cancel(&self, task: &Task) -> usize {
42        let mut guard = self.inner.lock();
43        let len = guard.running.len() + guard.waiting.len();
44        guard.running.retain(|(t, _)| t != task);
45        guard.waiting.retain(|t| t != task);
46        len - guard.running.len() - guard.waiting.len()
47    }
48    pub fn add(&self, task: Task) {
49        let mut guard = self.inner.lock();
50        guard.waiting.push_back(task);
51    }
52    pub fn steal(&self, task: &Task, min_chunk_size: u64) -> bool {
53        let min_chunk_size = min_chunk_size.max(1);
54        let mut guard = self.inner.lock();
55        while let Some(new_task) = guard.waiting.pop_front() {
56            if let Some(range) = new_task.take() {
57                task.set(range);
58                return true;
59            }
60        }
61        if let Some(steal_task) = guard.running.iter().map(|w| &w.0).max()
62            && steal_task.remain() >= min_chunk_size * 2
63            && let Ok(Some(range)) = steal_task.split_two()
64        {
65            task.set(range);
66            true
67        } else {
68            false
69        }
70    }
71    /// 当线程数需要增加时,但 executor 为空时,返回 None
72    pub fn set_threads<E: Executor<Handle = H>>(
73        &self,
74        threads: usize,
75        min_chunk_size: u64,
76        executor: Option<&E>,
77    ) -> Option<()> {
78        let threads = threads.max(1);
79        let min_chunk_size = min_chunk_size.max(1);
80        let mut guard = self.inner.lock();
81        let len = guard.running.len();
82        if len < threads {
83            let executor = executor?;
84            let need = (threads - len).min(guard.waiting.len());
85            let mut temp = Vec::with_capacity(need);
86            let iter = guard.waiting.drain(..need);
87            for task in iter {
88                let handle = executor.execute(task.clone(), self.clone());
89                temp.push((task, handle));
90            }
91            guard.running.extend(temp);
92            while guard.running.len() < threads
93                && let Some(steal_task) = guard.running.iter().map(|w| &w.0).max()
94                && steal_task.remain() >= min_chunk_size * 2
95                && let Ok(Some(range)) = steal_task.split_two()
96            {
97                let task = Task::new(range);
98                let handle = executor.execute(task.clone(), self.clone());
99                guard.running.push_back((task, handle));
100            }
101        } else if len > threads {
102            let mut temp = Vec::with_capacity(len - threads);
103            let iter = guard.running.drain(threads..);
104            for (task, mut handle) in iter {
105                handle.abort();
106                temp.push(task);
107            }
108            guard.waiting.extend(temp);
109        }
110        Some(())
111    }
112    pub fn handles<F, R>(&self, f: F) -> R
113    where
114        F: FnOnce(&mut dyn Iterator<Item = &mut H>) -> R,
115    {
116        let mut guard = self.inner.lock();
117        let mut iter = guard.running.iter_mut().map(|w| &mut w.1);
118        f(&mut iter)
119    }
120}
121
#[cfg(test)]
mod tests {
    extern crate std;
    use crate::{Executor, Handle, Task, TaskQueue};
    use std::{collections::HashMap, dbg, println};
    use tokio::{sync::mpsc, task::AbortHandle};

    /// Test executor that runs every task as a tokio task and reports each
    /// computed `(index, value)` pair through an unbounded channel.
    pub struct TokioExecutor {
        tx: mpsc::UnboundedSender<(u64, u64)>,
    }

    /// Handle wrapping the tokio `AbortHandle` of a spawned worker.
    #[derive(Clone)]
    pub struct TokioHandle(AbortHandle);

    impl Handle for TokioHandle {
        type Output = ();

        /// Forwards to the wrapped `AbortHandle`.
        fn abort(&mut self) -> Self::Output {
            self.0.abort();
        }
    }

    impl Executor for TokioExecutor {
        type Handle = TokioHandle;

        /// Spawns a worker that computes `fib(i)` for every index of `task`,
        /// asks the queue for more work whenever the task is exhausted, and
        /// finally reports itself as finished.
        fn execute(&self, task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
            println!("execute");
            let tx = self.tx.clone();
            let worker = tokio::spawn(async move {
                let mut has_work = true;
                while has_work {
                    // Claim one index at a time until the task is exhausted.
                    while task.start() < task.end() {
                        let i = task.start();
                        task.fetch_add_start(1).unwrap();
                        let res = fib(i);
                        println!("task: {i} = {res}");
                        tx.send((i, res)).unwrap();
                    }
                    // Keep going only if the queue handed us a new range.
                    has_work = task_queue.steal(&task, 1);
                }
                assert_eq!(task_queue.finish_work(&task), 1);
            });
            TokioHandle(worker.abort_handle())
        }
    }

    /// Naive recursive Fibonacci; deliberately slow so workers have real work.
    fn fib(n: u64) -> u64 {
        if n < 2 {
            n
        } else {
            fib(n - 1) + fib(n - 2)
        }
    }

    /// Iterative Fibonacci used to check the workers' results.
    fn fib_fast(n: u64) -> u64 {
        let (nth, _next) = (0..n).fold((0u64, 1u64), |(a, b), _| (b, a + b));
        nth
    }

    /// Runs two ranges through the queue with up to 8 workers and checks that
    /// every index is computed exactly once and with the right value.
    #[tokio::test]
    async fn test_task_queue() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let executor = TokioExecutor { tx };
        let pre_data = [1..20, 41..48];
        let task_queue = TaskQueue::new(pre_data.iter());
        task_queue.set_threads(8, 1, Some(&executor));
        // Drop the executor's sender so the channel closes once every worker
        // (each holding its own clone) has finished.
        drop(executor);

        let mut data = HashMap::new();
        while let Some((i, res)) = rx.recv().await {
            println!("main: {i} = {res}");
            assert!(
                data.insert(i, res).is_none(),
                "数字 {i},值为 {res} 重复计算"
            );
        }
        dbg!(&data);

        // Every expected index must be present with the correct value...
        for i in pre_data.into_iter().flatten() {
            assert_eq!((i, data.get(&i)), (i, Some(&fib_fast(i))));
            data.remove(&i);
        }
        // ...and nothing else may have been computed.
        assert_eq!(data.len(), 0);
    }
}