avila_parallel/
work_stealing.rs

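//! Work Stealing Scheduler
//!
//! This module implements a work-stealing scheduler for better load balancing
//! across threads. Each thread has its own deque and can steal work from others:
//! a thread pops tasks from the front of its own deque and steals from the back
//! of the other threads' deques. The deques are `Mutex`-guarded `VecDeque`s,
//! not lock-free structures.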
5
6use std::collections::VecDeque;
7use std::sync::{Arc, Mutex};
8use std::thread;
9
/// A work-stealing deque for task distribution.
///
/// The owning worker pushes and pops at the front, so it processes its own
/// tasks newest-first, while other workers [`steal`](WorkStealingDeque::steal)
/// from the back and therefore take the oldest tasks. Storage is a `VecDeque`
/// behind an `Arc<Mutex<..>>`: every operation takes the lock for the duration
/// of one `VecDeque` call, and clones share the same underlying queue.
#[derive(Debug)]
pub struct WorkStealingDeque<T> {
    tasks: Arc<Mutex<VecDeque<T>>>,
}

impl<T> WorkStealingDeque<T> {
    /// Create a new, empty work-stealing deque.
    pub fn new() -> Self {
        Self {
            tasks: Arc::new(Mutex::new(VecDeque::new())),
        }
    }

    /// Lock the queue, recovering the guard if the mutex is poisoned.
    ///
    /// Every critical section in this type is a single `VecDeque` call, so the
    /// queue is still a valid `VecDeque` even if a thread panicked while holding
    /// the lock. Recovering keeps one panicking thread from making the deque
    /// unusable (panicking on every access) for all other threads.
    fn lock(&self) -> std::sync::MutexGuard<'_, VecDeque<T>> {
        self.tasks
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Push a task to the front (local end).
    pub fn push(&self, task: T) {
        self.lock().push_front(task);
    }

    /// Pop a task from the front (local end), i.e. the most recently pushed one.
    pub fn pop(&self) -> Option<T> {
        self.lock().pop_front()
    }

    /// Steal a task from the back (remote end), i.e. the oldest queued one.
    pub fn steal(&self) -> Option<T> {
        self.lock().pop_back()
    }

    /// Get the number of queued tasks.
    ///
    /// This is a snapshot: when the deque is shared, other threads may change
    /// it as soon as the lock is released.
    pub fn len(&self) -> usize {
        self.lock().len()
    }

    /// Check whether the deque is empty (a snapshot, like [`len`](Self::len)).
    pub fn is_empty(&self) -> bool {
        self.lock().is_empty()
    }
}

impl<T> Default for WorkStealingDeque<T> {
    /// Same as [`WorkStealingDeque::new`]; no `T: Default` bound is needed.
    fn default() -> Self {
        Self::new()
    }
}

// Implemented by hand rather than derived so that `T` does not need to be
// `Clone`: cloning only bumps the `Arc`, and the clone shares the same queue.
impl<T> Clone for WorkStealingDeque<T> {
    fn clone(&self) -> Self {
        Self {
            tasks: Arc::clone(&self.tasks),
        }
    }
}
56
57/// Work-stealing thread pool
58pub struct WorkStealingPool {
59    workers: Vec<WorkStealingDeque<Box<dyn FnOnce() + Send + 'static>>>,
60    num_workers: usize,
61}
62
63impl WorkStealingPool {
64    /// Create a new work-stealing pool
65    pub fn new(num_threads: usize) -> Self {
66        let num_threads = if num_threads == 0 {
67            thread::available_parallelism()
68                .map(|n| n.get())
69                .unwrap_or(1)
70        } else {
71            num_threads
72        };
73
74        let workers: Vec<_> = (0..num_threads)
75            .map(|_| WorkStealingDeque::new())
76            .collect();
77
78        Self {
79            workers,
80            num_workers: num_threads,
81        }
82    }
83
84    /// Execute a task using work stealing
85    pub fn execute<F>(&self, tasks: Vec<F>)
86    where
87        F: FnOnce() + Send + 'static,
88    {
89        // Distribute tasks to workers
90        for (idx, task) in tasks.into_iter().enumerate() {
91            let worker_idx = idx % self.num_workers;
92            self.workers[worker_idx].push(Box::new(task));
93        }
94
95        // Spawn threads
96        thread::scope(|s| {
97            for (thread_id, worker) in self.workers.iter().enumerate() {
98                let worker = worker.clone();
99                let all_workers = self.workers.clone();
100
101                s.spawn(move || {
102                    loop {
103                        // Try to get local work
104                        if let Some(task) = worker.pop() {
105                            task();
106                            continue;
107                        }
108
109                        // Try to steal from others
110                        let mut found_work = false;
111                        for (other_id, other_worker) in all_workers.iter().enumerate() {
112                            if other_id != thread_id {
113                                if let Some(task) = other_worker.steal() {
114                                    task();
115                                    found_work = true;
116                                    break;
117                                }
118                            }
119                        }
120
121                        if !found_work {
122                            // No work available, exit
123                            break;
124                        }
125                    }
126                });
127            }
128        });
129    }
130
131    /// Get the number of workers
132    pub fn num_workers(&self) -> usize {
133        self.num_workers
134    }
135}
136
137/// Parallel map with work stealing
138///
139/// Note: This is a simplified implementation for demonstration.
140/// For production use, consider using a dedicated work-stealing library.
141pub fn work_stealing_map<T, R, F>(items: &[T], f: F) -> Vec<R>
142where
143    T: Sync,
144    R: Send + 'static,
145    F: Fn(&T) -> R + Send + Sync,
146{
147    use crate::executor::parallel_map;
148    parallel_map(items, f)
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Build `count` tasks that each add one to `counter`.
    fn counting_tasks(
        counter: Arc<Mutex<usize>>,
        count: usize,
    ) -> Vec<impl FnOnce() + Send + 'static> {
        (0..count)
            .map(|_| {
                let counter = Arc::clone(&counter);
                move || *counter.lock().unwrap() += 1
            })
            .collect()
    }

    #[test]
    fn test_work_stealing_deque() {
        let deque = WorkStealingDeque::new();
        deque.push(1);
        deque.push(2);
        deque.push(3);
        assert_eq!(deque.len(), 3);

        // The owner pops the newest task; thieves take the oldest.
        assert_eq!(deque.pop(), Some(3));
        assert_eq!(deque.steal(), Some(1));
        assert_eq!(deque.pop(), Some(2));
        assert!(deque.is_empty());
        assert_eq!(deque.pop(), None);
        assert_eq!(deque.steal(), None);
    }

    #[test]
    fn test_work_stealing_deque_clones_share_storage() {
        let deque = WorkStealingDeque::new();
        let handle = deque.clone();
        handle.push("task");
        assert_eq!(deque.len(), 1);
        assert_eq!(deque.steal(), Some("task"));
        assert!(handle.is_empty());
    }

    #[test]
    fn test_work_stealing_pool() {
        let pool = WorkStealingPool::new(2);
        let counter = Arc::new(Mutex::new(0));
        pool.execute(counting_tasks(Arc::clone(&counter), 10));
        assert_eq!(*counter.lock().unwrap(), 10);
    }

    #[test]
    fn test_pool_thread_count() {
        assert_eq!(WorkStealingPool::new(3).num_workers(), 3);
        // Zero means "pick automatically", which always yields at least one worker.
        assert!(WorkStealingPool::new(0).num_workers() >= 1);
    }

    #[test]
    fn test_pool_handles_empty_and_small_batches() {
        let pool = WorkStealingPool::new(8);
        pool.execute(Vec::<fn()>::new());

        // Fewer tasks than workers must still run every task exactly once.
        let counter = Arc::new(Mutex::new(0));
        pool.execute(counting_tasks(Arc::clone(&counter), 3));
        assert_eq!(*counter.lock().unwrap(), 3);
    }

    #[test]
    fn test_pool_is_reusable() {
        let pool = WorkStealingPool::new(2);
        let counter = Arc::new(Mutex::new(0));
        for _ in 0..3 {
            pool.execute(counting_tasks(Arc::clone(&counter), 5));
        }
        assert_eq!(*counter.lock().unwrap(), 15);
    }

    #[test]
    fn test_work_stealing_map() {
        let data = vec![1, 2, 3, 4, 5];
        let results = work_stealing_map(&data, |x| x * 2);
        assert_eq!(results, vec![2, 4, 6, 8, 10]);
    }
}