//! Work-stealing task queue (`fast_steal/task_queue.rs`).
extern crate alloc;
use crate::{Executor, Handle, Task, WeakTask};
use alloc::{collections::vec_deque::VecDeque, sync::Arc, vec::Vec};
use core::ops::Range;
use parking_lot::Mutex;
6
/// Shared, cloneable handle to the work-stealing task queue.
///
/// Every clone refers to the same inner state behind an `Arc<Mutex<_>>`,
/// so workers and the controller all observe one queue.
#[derive(Debug)]
pub struct TaskQueue<H: Handle> {
    // One lock guards both the running and waiting lists, so a steal or a
    // resize always sees a consistent snapshot of both.
    inner: Arc<Mutex<TaskQueueInner<H>>>,
}
11impl<H: Handle> Clone for TaskQueue<H> {
12    fn clone(&self) -> Self {
13        Self {
14            inner: self.inner.clone(),
15        }
16    }
17}
/// Lock-protected state shared by all clones of a [`TaskQueue`].
#[derive(Debug)]
struct TaskQueueInner<H: Handle> {
    // Tasks currently being executed, each paired with its worker's handle.
    // Weak references: a slot whose task has been dropped is pruned later
    // (see the `retain` on strong_count in `set_threads`).
    running: VecDeque<(WeakTask, H)>,
    // Tasks not yet assigned to any worker.
    waiting: VecDeque<Task>,
}
impl<H: Handle> TaskQueue<H> {
    /// Creates a queue with one waiting [`Task`] per input range.
    ///
    /// `running` starts empty but is pre-allocated to the same capacity,
    /// since at most that many workers can be started from these tasks.
    pub fn new(tasks: impl Iterator<Item = Range<u64>>) -> Self {
        let waiting: VecDeque<_> = tasks.map(Task::from).collect();
        Self {
            inner: Arc::new(Mutex::new(TaskQueueInner {
                running: VecDeque::with_capacity(waiting.len()),
                waiting,
            })),
        }
    }
    /// Appends a task to the back of the waiting list.
    pub fn add(&self, task: Task) {
        let mut guard = self.inner.lock();
        guard.waiting.push_back(task);
    }
    /// Tries to hand the worker identified by `id` a new piece of work,
    /// writing it into `task`. Returns `true` iff `task` was replaced.
    ///
    /// Attempts, in order:
    /// 1. pop a claimable task from the waiting list;
    /// 2. split the live running task with the most remaining work in two,
    ///    provided at least `min_chunk_size * 2` items remain;
    /// 3. if `max_speculative > 1`, speculatively share the busiest task's
    ///    state so several workers race on the same range, as long as its
    ///    strong count stays within `max_speculative`.
    pub fn steal(
        &self,
        id: &H::Id,
        task: &mut Task,
        min_chunk_size: u64,
        max_speculative: usize,
    ) -> bool {
        // Clamp to 1 so the `min_chunk_size * 2` split threshold below is
        // never zero.
        let min_chunk_size = min_chunk_size.max(1);
        let mut guard = self.inner.lock();
        // Locate this worker's slot in `running`. `iter_mut` is needed
        // because `Handle::is_self` takes `&mut self`.
        let mut worker_idx = None;
        for (i, (_, handle)) in guard.running.iter_mut().enumerate() {
            if handle.is_self(id) {
                worker_idx = Some(i);
                break;
            }
        }
        // A caller that is not registered in `running` may not steal.
        let Some(worker_idx) = worker_idx else {
            return false;
        };
        let mut found = false;
        // 1) Prefer fresh tasks from the waiting list. `take()` returning
        //    `None` presumably means the task has nothing left to claim —
        //    such entries are discarded (TODO confirm `Task::take` contract).
        while let Some(new_task) = guard.waiting.pop_front() {
            if let Some(range) = new_task.take() {
                *task = Task::new(range);
                found = true;
                break;
            }
        }
        // 2)/3) Otherwise steal from the live task with the most remaining
        //       work, excluding the caller's own current task.
        if !found
            && let Some(steal_task) = guard
                .running
                .iter()
                .filter_map(|w| w.0.upgrade())
                .filter(|w| w != task)
                .max_by_key(Task::remain)
        {
            if steal_task.remain() >= min_chunk_size * 2
                && let Ok(Some(range)) = steal_task.split_two()
            {
                // The victim keeps one half; this worker takes the other.
                *task = Task::new(range);
                found = true;
            } else if max_speculative > 1
                && steal_task.remain() > 0
                && steal_task.strong_count() <= max_speculative
            {
                // Speculative execution: duplicate the victim's state so up
                // to `max_speculative` workers run the same range.
                task.state = steal_task.state;
                found = true;
            }
        }
        if found {
            // Re-point this worker's slot at the task it now executes.
            guard.running[worker_idx].0 = task.downgrade();
        }
        found
    }
    /// Adjusts the number of running workers to `threads`.
    ///
    /// Returns `None` when the thread count needs to grow but `executor`
    /// is `None`; otherwise `Some(())`.
    #[must_use]
    pub fn set_threads<E: Executor<Handle = H>>(
        &self,
        threads: usize,
        min_chunk_size: u64,
        executor: Option<&E>,
    ) -> Option<()> {
        #![allow(clippy::significant_drop_tightening)]
        let min_chunk_size = min_chunk_size.max(1);
        let mut guard = self.inner.lock();
        // Prune slots whose task has already been fully released.
        guard.running.retain(|t| t.0.strong_count() > 0);
        let len = guard.running.len();
        if len < threads {
            let executor = executor?;
            // Growing: first start workers on tasks that are already
            // waiting…
            let need = (threads - len).min(guard.waiting.len());
            let mut temp = Vec::with_capacity(need);
            let iter = guard.waiting.drain(..need);
            for task in iter {
                let weak = task.downgrade();
                let handle = executor.execute(task, self.clone());
                temp.push((weak, handle));
            }
            guard.running.extend(temp);
            // …then keep splitting the largest live task until either the
            // target count is reached or no task is big enough to split.
            while guard.running.len() < threads
                && let Some(steal_task) = guard
                    .running
                    .iter()
                    .filter_map(|w| w.0.upgrade())
                    .max_by_key(Task::remain)
                && steal_task.remain() >= min_chunk_size * 2
                && let Ok(Some(range)) = steal_task.split_two()
            {
                let task = Task::new(range);
                let weak = task.downgrade();
                let handle = executor.execute(task, self.clone());
                guard.running.push_back((weak, handle));
            }
        } else if len > threads {
            // Shrinking: abort surplus workers and return their still-live
            // tasks to the waiting list so no work is lost.
            let mut temp = Vec::with_capacity(len - threads);
            let iter = guard.running.drain(threads..);
            for (task, mut handle) in iter {
                if let Some(task) = task.upgrade() {
                    temp.push(task);
                }
                handle.abort();
            }
            guard.waiting.extend(temp);
        }
        Some(())
    }
    /// Runs `f` over mutable references to all registered worker handles.
    /// The queue lock is held for the duration of `f`.
    pub fn handles<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut dyn Iterator<Item = &mut H>) -> R,
    {
        #![allow(clippy::significant_drop_tightening)]
        let mut guard = self.inner.lock();
        let mut iter = guard.running.iter_mut().map(|w| &mut w.1);
        f(&mut iter)
    }

    /// Aborts every worker other than the caller (identified by `id`) whose
    /// current task equals `task` — e.g. to stop speculative duplicates
    /// once one copy has finished.
    pub fn cancel_task(&self, task: &Task, id: &H::Id) {
        let mut guard = self.inner.lock();
        for (weak, handle) in &mut guard.running {
            if let Some(t) = weak.upgrade()
                && t == *task
                && !handle.is_self(id)
            {
                handle.abort();
            }
        }
    }
}
163
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    extern crate std;
    use crate::{Executor, Handle, Task, TaskQueue};
    use std::{collections::HashMap, dbg, println};
    use tokio::{sync::mpsc, task::AbortHandle};

    /// Test executor: spawns tokio tasks that compute Fibonacci numbers and
    /// report each `(n, fib(n))` pair over an unbounded channel.
    struct TokioExecutor {
        tx: mpsc::UnboundedSender<(u64, u64)>,
        // Passed through to `TaskQueue::steal` as `max_speculative`.
        speculative: usize,
    }
    #[derive(Clone)]
    struct TokioHandle(AbortHandle);

    impl Handle for TokioHandle {
        type Output = ();
        type Id = ();
        fn abort(&mut self) -> Self::Output {
            self.0.abort();
        }
        // No worker identity in this test setup: no handle ever claims to
        // be the caller.
        fn is_self(&mut self, (): &Self::Id) -> bool {
            false
        }
    }

    impl Executor for TokioExecutor {
        type Handle = TokioHandle;
        fn execute(&self, mut task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
            println!("execute");
            let tx = self.tx.clone();
            let speculative = self.speculative;
            let handle = tokio::spawn(async move {
                loop {
                    // Drain the current task, claiming one index at a time.
                    while task.start() < task.end() {
                        let i = task.start();
                        let res = fib(i);
                        // `safe_add_start` failing means another speculative
                        // worker already claimed `i`; skip reporting it.
                        let Ok(_) = task.safe_add_start(i, 1) else {
                            println!("task-failed: {i} = {res}");
                            continue;
                        };
                        println!("task: {i} = {res}");
                        tx.send((i, res)).unwrap();
                    }
                    // Current task exhausted: try to steal more work.
                    if !task_queue.steal(&(), &mut task, 1, speculative) {
                        break;
                    }
                }
            });
            let abort_handle = handle.abort_handle();
            TokioHandle(abort_handle)
        }
    }

    /// Naive exponential Fibonacci — deliberately slow so workers stay busy
    /// long enough for splitting/stealing to occur.
    fn fib(n: u64) -> u64 {
        match n {
            0 => 0,
            1 => 1,
            _ => fib(n - 1) + fib(n - 2),
        }
    }
    /// Fast O(n) Fibonacci used as the reference value in assertions.
    fn fib_fast(n: u64) -> u64 {
        let mut a = 0;
        let mut b = 1;
        for _ in 0..n {
            (a, b) = (b, a + b);
        }
        a
    }

    /// Shared body for both tests (previously duplicated verbatim): run the
    /// queue over two ranges with the given speculative limit, then check
    /// every value was computed exactly once and matches `fib_fast`.
    async fn run_queue_test(speculative: usize) {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let executor = TokioExecutor { tx, speculative };
        let pre_data = [1..20, 41..48];
        let task_queue = TaskQueue::new(pre_data.iter().cloned());
        task_queue.set_threads(8, 1, Some(&executor)).unwrap();
        // Dropping the executor drops the last sender clone's owner, so the
        // channel closes once all workers finish and `recv` returns `None`.
        drop(executor);
        let mut data = HashMap::new();
        while let Some((i, res)) = rx.recv().await {
            println!("main: {i} = {res}");
            assert!(
                data.insert(i, res).is_none(),
                "数字 {i},值为 {res} 重复计算"
            );
        }
        dbg!(&data);
        for range in pre_data {
            for i in range {
                assert_eq!((i, data.get(&i)), (i, Some(&fib_fast(i))));
                data.remove(&i);
            }
        }
        assert_eq!(data.len(), 0);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_task_queue() {
        run_queue_test(1).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_task_queue2() {
        run_queue_test(2).await;
    }
}
285}