// fast_steal/task_queue.rs
1extern crate alloc;
2use crate::{Executor, Handle, Task, WeakTask};
3use alloc::{collections::vec_deque::VecDeque, sync::Arc, vec::Vec};
4use core::ops::Range;
5use parking_lot::Mutex;
6
/// A shared, thread-safe queue of work-stealing tasks.
///
/// Cloning a `TaskQueue` is cheap: clones are handles to the same underlying
/// queue state (the manual `Clone` impl just clones the inner `Arc`).
#[derive(Debug)]
pub struct TaskQueue<H: Handle> {
    // All queue state lives behind a single `parking_lot::Mutex`.
    inner: Arc<Mutex<TaskQueueInner<H>>>,
}
11impl<H: Handle> Clone for TaskQueue<H> {
12    fn clone(&self) -> Self {
13        Self {
14            inner: self.inner.clone(),
15        }
16    }
17}
/// Interior state of [`TaskQueue`], always accessed under the outer mutex.
#[derive(Debug)]
struct TaskQueueInner<H: Handle> {
    // Tasks currently being executed, paired with the handle that can abort
    // them. Held as `WeakTask` so a worker dropping its finished task makes
    // the entry prunable (see the `strong_count` retain in `set_threads`).
    running: VecDeque<(WeakTask, H)>,
    // Tasks not yet handed to any worker.
    waiting: VecDeque<Task>,
}
23impl<H: Handle> TaskQueue<H> {
24    pub fn new(tasks: impl Iterator<Item = Range<u64>>) -> Self {
25        let waiting: VecDeque<_> = tasks.map(Task::from).collect();
26        Self {
27            inner: Arc::new(Mutex::new(TaskQueueInner {
28                running: VecDeque::with_capacity(waiting.len()),
29                waiting,
30            })),
31        }
32    }
33    pub fn add(&self, task: Task) {
34        let mut guard = self.inner.lock();
35        guard.waiting.push_back(task);
36    }
37    pub fn steal(&self, task: &mut Task, min_chunk_size: u64, max_speculative: usize) -> bool {
38        let min_chunk_size = min_chunk_size.max(1);
39        let mut guard = self.inner.lock();
40        while let Some(new_task) = guard.waiting.pop_front() {
41            if let Some(range) = new_task.take() {
42                task.set(range);
43                return true;
44            }
45        }
46        if let Some(steal_task) = guard
47            .running
48            .iter()
49            .filter_map(|w| w.0.upgrade())
50            .filter(|w| w != task)
51            .max_by_key(Task::remain)
52        {
53            if steal_task.remain() >= min_chunk_size * 2
54                && let Ok(Some(range)) = steal_task.split_two()
55            {
56                task.set(range);
57                true
58            } else if max_speculative > 1
59                && steal_task.remain() > 0
60                && steal_task.strong_count() <= max_speculative
61            {
62                task.state = steal_task.state;
63                true
64            } else {
65                false
66            }
67        } else {
68            false
69        }
70    }
71    /// 当线程数需要增加时,但 executor 为空时,返回 None
72    #[must_use]
73    pub fn set_threads<E: Executor<Handle = H>>(
74        &self,
75        threads: usize,
76        min_chunk_size: u64,
77        executor: Option<&E>,
78    ) -> Option<()> {
79        #![allow(clippy::significant_drop_tightening)]
80        let min_chunk_size = min_chunk_size.max(1);
81        let mut guard = self.inner.lock();
82        guard.running.retain(|t| t.0.strong_count() > 0);
83        let len = guard.running.len();
84        if len < threads {
85            let executor = executor?;
86            let need = (threads - len).min(guard.waiting.len());
87            let mut temp = Vec::with_capacity(need);
88            let iter = guard.waiting.drain(..need);
89            for task in iter {
90                let weak = task.downgrade();
91                let handle = executor.execute(task, self.clone());
92                temp.push((weak, handle));
93            }
94            guard.running.extend(temp);
95            while guard.running.len() < threads
96                && let Some(steal_task) = guard
97                    .running
98                    .iter()
99                    .filter_map(|w| w.0.upgrade())
100                    .max_by_key(Task::remain)
101                && steal_task.remain() >= min_chunk_size * 2
102                && let Ok(Some(range)) = steal_task.split_two()
103            {
104                let task = Task::new(range);
105                let weak = task.downgrade();
106                let handle = executor.execute(task, self.clone());
107                guard.running.push_back((weak, handle));
108            }
109        } else if len > threads {
110            let mut temp = Vec::with_capacity(len - threads);
111            let iter = guard.running.drain(threads..);
112            for (task, mut handle) in iter {
113                if let Some(task) = task.upgrade() {
114                    temp.push(task);
115                }
116                handle.abort();
117            }
118            guard.waiting.extend(temp);
119        }
120        Some(())
121    }
122    pub fn handles<F, R>(&self, f: F) -> R
123    where
124        F: FnOnce(&mut dyn Iterator<Item = &mut H>) -> R,
125    {
126        #![allow(clippy::significant_drop_tightening)]
127        let mut guard = self.inner.lock();
128        let mut iter = guard.running.iter_mut().map(|w| &mut w.1);
129        f(&mut iter)
130    }
131
132    pub fn cancel_task(&self, task: &Task, id: &H::Id) {
133        let mut guard = self.inner.lock();
134        for (weak, handle) in &mut guard.running {
135            if let Some(t) = weak.upgrade()
136                && t == *task
137                && !handle.is_self(id)
138            {
139                handle.abort();
140            }
141        }
142    }
143}
144
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    extern crate std;
    use crate::{Executor, Handle, Task, TaskQueue};
    use std::{collections::HashMap, dbg, println};
    use tokio::{sync::mpsc, task::AbortHandle};

    /// Executor that spawns each task onto the tokio runtime and reports
    /// every computed `(n, fib(n))` pair over an unbounded channel.
    struct TokioExecutor {
        tx: mpsc::UnboundedSender<(u64, u64)>,
        // Speculation limit forwarded to `TaskQueue::steal`.
        speculative: usize,
    }
    #[derive(Clone)]
    struct TokioHandle(AbortHandle);

    impl Handle for TokioHandle {
        type Output = ();
        type Id = ();
        fn abort(&mut self) -> Self::Output {
            self.0.abort();
        }
        fn is_self(&mut self, (): &Self::Id) -> bool {
            false
        }
    }

    impl Executor for TokioExecutor {
        type Handle = TokioHandle;
        fn execute(&self, mut task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
            println!("execute");
            let tx = self.tx.clone();
            let speculative = self.speculative;
            let handle = tokio::spawn(async move {
                loop {
                    // Work through the current range, then try to steal more.
                    while task.start() < task.end() {
                        let i = task.start();
                        let res = fib(i);
                        // `safe_add_start` fails when another speculative
                        // worker already advanced past `i`; just retry.
                        let Ok(_) = task.safe_add_start(i, 1) else {
                            println!("task-failed: {i} = {res}");
                            continue;
                        };
                        println!("task: {i} = {res}");
                        tx.send((i, res)).unwrap();
                    }
                    if !task_queue.steal(&mut task, 1, speculative) {
                        break;
                    }
                }
            });
            let abort_handle = handle.abort_handle();
            TokioHandle(abort_handle)
        }
    }

    /// Deliberately slow recursive Fibonacci so tasks take measurable time.
    fn fib(n: u64) -> u64 {
        match n {
            0 => 0,
            1 => 1,
            _ => fib(n - 1) + fib(n - 2),
        }
    }
    /// Fast iterative Fibonacci used as the reference oracle.
    fn fib_fast(n: u64) -> u64 {
        let mut a = 0;
        let mut b = 1;
        for _ in 0..n {
            (a, b) = (b, a + b);
        }
        a
    }

    /// Shared test body: runs the queue over two ranges with the given
    /// speculation limit, then checks every value was computed correctly and
    /// exactly once. (The two public tests previously duplicated this code.)
    async fn run_task_queue_test(speculative: usize) {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let executor = TokioExecutor { tx, speculative };
        let pre_data = [1..20, 41..48];
        let task_queue = TaskQueue::new(pre_data.iter().cloned());
        task_queue.set_threads(8, 1, Some(&executor)).unwrap();
        // Dropping the executor closes the sender side once all workers end.
        drop(executor);
        let mut data = HashMap::new();
        while let Some((i, res)) = rx.recv().await {
            println!("main: {i} = {res}");
            assert!(
                data.insert(i, res).is_none(),
                "数字 {i},值为 {res} 重复计算"
            );
        }
        dbg!(&data);
        for range in pre_data {
            for i in range {
                assert_eq!((i, data.get(&i)), (i, Some(&fib_fast(i))));
                data.remove(&i);
            }
        }
        assert_eq!(data.len(), 0);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_task_queue() {
        run_task_queue_test(1).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_task_queue2() {
        run_task_queue_test(2).await;
    }
}
266}