Skip to main content

fast_steal/
task_queue.rs

1extern crate alloc;
2use crate::{Executor, Handle, Task, WeakTask};
3use alloc::{collections::vec_deque::VecDeque, sync::Arc, vec::Vec};
4use core::ops::Range;
5use parking_lot::Mutex;
6
7#[derive(Debug)]
8pub struct TaskQueue<H: Handle> {
9    inner: Arc<Mutex<TaskQueueInner<H>>>,
10}
11impl<H: Handle> Clone for TaskQueue<H> {
12    fn clone(&self) -> Self {
13        Self {
14            inner: self.inner.clone(),
15        }
16    }
17}
18#[derive(Debug)]
19struct TaskQueueInner<H: Handle> {
20    running: VecDeque<(WeakTask, H)>,
21    waiting: VecDeque<Task>,
22}
23impl<H: Handle> TaskQueue<H> {
24    pub fn new<'a>(tasks: impl Iterator<Item = &'a Range<u64>>) -> Self {
25        let waiting: VecDeque<_> = tasks.map(Task::from).collect();
26        Self {
27            inner: Arc::new(Mutex::new(TaskQueueInner {
28                running: VecDeque::with_capacity(waiting.len()),
29                waiting,
30            })),
31        }
32    }
33    pub fn add(&self, task: Task) {
34        let mut guard = self.inner.lock();
35        guard.waiting.push_back(task);
36    }
37    pub fn steal(&self, task: &mut Task, min_chunk_size: u64, speculative: bool) -> bool {
38        let min_chunk_size = min_chunk_size.max(1);
39        let mut guard = self.inner.lock();
40        while let Some(new_task) = guard.waiting.pop_front() {
41            if let Some(range) = new_task.take() {
42                task.set(range);
43                return true;
44            }
45        }
46        if let Some(steal_task) = guard
47            .running
48            .iter()
49            .filter_map(|w| w.0.upgrade())
50            .filter(|w| w != task)
51            .max_by_key(|w| w.remain())
52        {
53            if steal_task.remain() >= min_chunk_size * 2
54                && let Ok(Some(range)) = steal_task.split_two()
55            {
56                task.set(range);
57                true
58            } else if speculative && steal_task.remain() > 0 {
59                task.state = steal_task.state.clone();
60                true
61            } else {
62                false
63            }
64        } else {
65            false
66        }
67    }
68    /// 当线程数需要增加时,但 executor 为空时,返回 None
69    pub fn set_threads<E: Executor<Handle = H>>(
70        &self,
71        threads: usize,
72        min_chunk_size: u64,
73        executor: Option<&E>,
74    ) -> Option<()> {
75        let min_chunk_size = min_chunk_size.max(1);
76        let mut guard = self.inner.lock();
77        guard.running.retain(|t| t.0.strong_count() > 0);
78        let len = guard.running.len();
79        if len < threads {
80            let executor = executor?;
81            let need = (threads - len).min(guard.waiting.len());
82            let mut temp = Vec::with_capacity(need);
83            let iter = guard.waiting.drain(..need);
84            for task in iter {
85                let weak = task.downgrade();
86                let handle = executor.execute(task, self.clone());
87                temp.push((weak, handle));
88            }
89            guard.running.extend(temp);
90            while guard.running.len() < threads
91                && let Some(steal_task) = guard
92                    .running
93                    .iter()
94                    .filter_map(|w| w.0.upgrade())
95                    .max_by_key(|w| w.remain())
96                && steal_task.remain() >= min_chunk_size * 2
97                && let Ok(Some(range)) = steal_task.split_two()
98            {
99                let task = Task::new(range);
100                let weak = task.downgrade();
101                let handle = executor.execute(task, self.clone());
102                guard.running.push_back((weak, handle));
103            }
104        } else if len > threads {
105            let mut temp = Vec::with_capacity(len - threads);
106            let iter = guard.running.drain(threads..);
107            for (task, mut handle) in iter {
108                if let Some(task) = task.upgrade() {
109                    temp.push(task);
110                }
111                handle.abort();
112            }
113            guard.waiting.extend(temp);
114        }
115        Some(())
116    }
117    pub fn handles<F, R>(&self, f: F) -> R
118    where
119        F: FnOnce(&mut dyn Iterator<Item = &mut H>) -> R,
120    {
121        let mut guard = self.inner.lock();
122        let mut iter = guard.running.iter_mut().map(|w| &mut w.1);
123        f(&mut iter)
124    }
125    pub fn cancel_tasks(&self, task: &Task) {
126        let mut guard = self.inner.lock();
127        for (weak, handle) in &mut guard.running {
128            if let Some(t) = weak.upgrade()
129                && t == *task
130            {
131                handle.abort();
132            }
133        }
134    }
135}
136
#[cfg(test)]
mod tests {
    extern crate std;
    use crate::{Executor, Handle, Task, TaskQueue};
    use std::{collections::HashMap, dbg, println};
    use tokio::{sync::mpsc, task::AbortHandle};

    /// Test executor.
    ///
    /// Runs each task on a tokio worker, computing `fib(i)` for every index of the task's
    /// range and sending `(i, fib(i))` over `tx`.
    pub struct TokioExecutor {
        tx: mpsc::UnboundedSender<(u64, u64)>,
        /// Forwarded to `TaskQueue::steal` when a worker runs out of work.
        speculative: bool,
    }

    /// Wraps tokio's `AbortHandle` so it can serve as the queue's `Handle`.
    #[derive(Clone)]
    pub struct TokioHandle(AbortHandle);

    impl Handle for TokioHandle {
        type Output = ();
        fn abort(&mut self) -> Self::Output {
            self.0.abort();
        }
    }

    impl Executor for TokioExecutor {
        type Handle = TokioHandle;
        fn execute(&self, mut task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
            println!("execute");
            let tx = self.tx.clone();
            let speculative = self.speculative;
            let handle = tokio::spawn(async move {
                loop {
                    while task.start() < task.end() {
                        let i = task.start();
                        let res = fib(i);
                        // Index `i` is reported only if advancing `start` from `i` succeeds.
                        // On failure (presumably another worker touched this range
                        // concurrently; see `Task::safe_add_start`), the result is dropped
                        // and `start` is re-read.
                        let Ok(_) = task.safe_add_start(i, 1) else {
                            println!("task-failed: {i} = {res}");
                            continue;
                        };
                        println!("task: {i} = {res}");
                        tx.send((i, res)).unwrap();
                    }
                    if !task_queue.steal(&mut task, 1, speculative) {
                        break;
                    }
                }
            });
            TokioHandle(handle.abort_handle())
        }
    }

    /// Naive exponential-time Fibonacci.
    ///
    /// The large indices used by the tests presumably keep workers busy long enough for
    /// stealing to occur.
    fn fib(n: u64) -> u64 {
        match n {
            0 => 0,
            1 => 1,
            _ => fib(n - 1) + fib(n - 2),
        }
    }

    /// Linear-time Fibonacci used as the reference value.
    fn fib_fast(n: u64) -> u64 {
        let mut a = 0;
        let mut b = 1;
        for _ in 0..n {
            (a, b) = (b, a + b);
        }
        a
    }

    /// Shared body of the queue tests.
    ///
    /// Runs a fixed workload on 8 workers and checks that every index is reported exactly
    /// once, with the correct value, and that nothing outside the workload is reported.
    async fn run_task_queue(speculative: bool) {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let executor = TokioExecutor { tx, speculative };
        let pre_data = [1..20, 41..48];
        let task_queue = TaskQueue::new(pre_data.iter());
        task_queue.set_threads(8, 1, Some(&executor));
        // Drop the executor's sender so `rx.recv()` yields `None` once every worker has
        // finished and dropped its own clone.
        drop(executor);
        let mut data = HashMap::new();
        while let Some((i, res)) = rx.recv().await {
            println!("main: {i} = {res}");
            if data.insert(i, res).is_some() {
                panic!("数字 {i},值为 {res} 重复计算");
            }
        }
        dbg!(&data);
        for range in pre_data {
            for i in range {
                assert_eq!((i, data.get(&i)), (i, Some(&fib_fast(i))));
                data.remove(&i);
            }
        }
        assert_eq!(data.len(), 0);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_task_queue() {
        run_task_queue(false).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_task_queue2() {
        run_task_queue(true).await;
    }
}