// avila_parallel/work_stealing.rs
use std::collections::VecDeque;
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
use std::thread;
9
/// A mutex-protected double-ended task queue with an owner end and a thief end.
///
/// The owner pushes and pops at the front, so it sees tasks in LIFO order
/// (most recently pushed first). Thieves take from the back, which yields the
/// oldest remaining task. Cloning produces another handle to the *same*
/// underlying queue.
#[derive(Debug)]
pub struct WorkStealingDeque<T> {
    tasks: Arc<Mutex<VecDeque<T>>>,
}

impl<T> WorkStealingDeque<T> {
    /// Creates an empty deque.
    pub fn new() -> Self {
        Self {
            tasks: Arc::new(Mutex::new(VecDeque::new())),
        }
    }

    /// Locks the queue, recovering the guard if a previous holder panicked.
    ///
    /// Only short `VecDeque` operations run while this lock is held (no user
    /// code), so a poisoned lock does not imply an inconsistent queue.
    fn queue(&self) -> MutexGuard<'_, VecDeque<T>> {
        self.tasks.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Pushes `task` onto the owner's end (the front) of the deque.
    pub fn push(&self, task: T) {
        self.queue().push_front(task);
    }

    /// Pops the most recently pushed task from the owner's end, if any.
    pub fn pop(&self) -> Option<T> {
        self.queue().pop_front()
    }

    /// Steals the oldest task from the thief's end (the back), if any.
    pub fn steal(&self) -> Option<T> {
        self.queue().pop_back()
    }

    /// Returns the number of queued tasks.
    ///
    /// Under concurrent use this is only a snapshot; other handles may change
    /// the queue immediately afterwards.
    pub fn len(&self) -> usize {
        self.queue().len()
    }

    /// Returns `true` if no tasks are queued (a snapshot, like [`Self::len`]).
    pub fn is_empty(&self) -> bool {
        self.queue().is_empty()
    }
}

// Hand-written so that `T` is not required to implement `Default`.
impl<T> Default for WorkStealingDeque<T> {
    fn default() -> Self {
        Self::new()
    }
}

// Hand-written so that `T` is not required to implement `Clone`: only the
// `Arc` is cloned, so both handles refer to the same queue.
impl<T> Clone for WorkStealingDeque<T> {
    fn clone(&self) -> Self {
        Self {
            tasks: Arc::clone(&self.tasks),
        }
    }
}
56
57pub struct WorkStealingPool {
59 workers: Vec<WorkStealingDeque<Box<dyn FnOnce() + Send + 'static>>>,
60 num_workers: usize,
61}
62
63impl WorkStealingPool {
64 pub fn new(num_threads: usize) -> Self {
66 let num_threads = if num_threads == 0 {
67 thread::available_parallelism()
68 .map(|n| n.get())
69 .unwrap_or(1)
70 } else {
71 num_threads
72 };
73
74 let workers: Vec<_> = (0..num_threads)
75 .map(|_| WorkStealingDeque::new())
76 .collect();
77
78 Self {
79 workers,
80 num_workers: num_threads,
81 }
82 }
83
84 pub fn execute<F>(&self, tasks: Vec<F>)
86 where
87 F: FnOnce() + Send + 'static,
88 {
89 for (idx, task) in tasks.into_iter().enumerate() {
91 let worker_idx = idx % self.num_workers;
92 self.workers[worker_idx].push(Box::new(task));
93 }
94
95 thread::scope(|s| {
97 for (thread_id, worker) in self.workers.iter().enumerate() {
98 let worker = worker.clone();
99 let all_workers = self.workers.clone();
100
101 s.spawn(move || {
102 loop {
103 if let Some(task) = worker.pop() {
105 task();
106 continue;
107 }
108
109 let mut found_work = false;
111 for (other_id, other_worker) in all_workers.iter().enumerate() {
112 if other_id != thread_id {
113 if let Some(task) = other_worker.steal() {
114 task();
115 found_work = true;
116 break;
117 }
118 }
119 }
120
121 if !found_work {
122 break;
124 }
125 }
126 });
127 }
128 });
129 }
130
131 pub fn num_workers(&self) -> usize {
133 self.num_workers
134 }
135}
136
137pub fn work_stealing_map<T, R, F>(items: &[T], f: F) -> Vec<R>
142where
143 T: Sync,
144 R: Send + 'static,
145 F: Fn(&T) -> R + Send + Sync,
146{
147 use crate::executor::parallel_map;
148 parallel_map(items, f)
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds `n` tasks that each increment `counter` once.
    fn counting_tasks(counter: &Arc<AtomicUsize>, n: usize) -> Vec<impl FnOnce() + Send + 'static> {
        (0..n)
            .map(|_| {
                let counter = Arc::clone(counter);
                move || {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })
            .collect()
    }

    #[test]
    fn test_work_stealing_deque() {
        let deque = WorkStealingDeque::new();
        deque.push(1);
        deque.push(2);
        deque.push(3);
        assert_eq!(deque.len(), 3);

        // Owner end is LIFO; the thief end yields the oldest task.
        assert_eq!(deque.pop(), Some(3));
        assert_eq!(deque.steal(), Some(1));
        assert_eq!(deque.pop(), Some(2));
        assert!(deque.is_empty());
        assert_eq!(deque.pop(), None);
        assert_eq!(deque.steal(), None);
    }

    #[test]
    fn test_deque_clone_shares_storage() {
        let deque = WorkStealingDeque::new();
        let handle = deque.clone();
        handle.push("task");
        assert_eq!(deque.len(), 1);
        assert_eq!(deque.steal(), Some("task"));
        assert!(handle.is_empty());
    }

    #[test]
    fn test_work_stealing_pool() {
        let pool = WorkStealingPool::new(2);
        let counter = Arc::new(Mutex::new(0));

        let tasks: Vec<_> = (0..10)
            .map(|_| {
                let counter = Arc::clone(&counter);
                move || {
                    *counter.lock().unwrap() += 1;
                }
            })
            .collect();

        pool.execute(tasks);
        assert_eq!(*counter.lock().unwrap(), 10);
    }

    #[test]
    fn test_pool_runs_every_task_exactly_once() {
        let pool = WorkStealingPool::new(4);
        let hits: Arc<Vec<AtomicUsize>> =
            Arc::new((0..100).map(|_| AtomicUsize::new(0)).collect());

        let tasks: Vec<_> = (0..100usize)
            .map(|i| {
                let hits = Arc::clone(&hits);
                move || {
                    hits[i].fetch_add(1, Ordering::SeqCst);
                }
            })
            .collect();

        pool.execute(tasks);
        assert!(hits.iter().all(|h| h.load(Ordering::SeqCst) == 1));
    }

    #[test]
    fn test_pool_empty_batch_and_reuse() {
        let pool = WorkStealingPool::new(3);
        pool.execute(Vec::<fn()>::new());

        let counter = Arc::new(AtomicUsize::new(0));
        pool.execute(counting_tasks(&counter, 5));
        pool.execute(counting_tasks(&counter, 5));
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[test]
    fn test_pool_fewer_tasks_than_workers() {
        let pool = WorkStealingPool::new(8);
        let counter = Arc::new(AtomicUsize::new(0));
        pool.execute(counting_tasks(&counter, 3));
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_zero_threads_uses_available_parallelism() {
        let pool = WorkStealingPool::new(0);
        assert!(pool.num_workers() >= 1);

        let counter = Arc::new(AtomicUsize::new(0));
        pool.execute(counting_tasks(&counter, 4));
        assert_eq!(counter.load(Ordering::SeqCst), 4);
    }

    #[test]
    fn test_work_stealing_map() {
        let data = vec![1, 2, 3, 4, 5];
        let results = work_stealing_map(&data, |x| x * 2);
        assert_eq!(results, vec![2, 4, 6, 8, 10]);
    }
}