Skip to main content

oxigdal_gpu_advanced/multi_gpu/
work_queue.rs

1//! Work queue for GPU task management.
2
3use super::GpuDevice;
4use crate::error::{GpuAdvancedError, Result};
5use crossbeam_channel::{Receiver, Sender, bounded};
6use parking_lot::Mutex;
7use std::sync::Arc;
8use std::thread;
9use std::time::Instant;
10
11/// Work item for GPU execution
12type WorkItem = Box<dyn FnOnce(&GpuDevice) -> Result<()> + Send>;
13
14/// Result sender for work completion
15type ResultSender = Sender<Result<()>>;
16
17/// Work queue for a single GPU device
18pub struct WorkQueue {
19    /// Associated GPU device
20    device: Arc<GpuDevice>,
21    /// Work sender (wrapped in Option for dropping in Drop impl)
22    work_sender: Option<Sender<(WorkItem, ResultSender)>>,
23    /// Worker thread handle
24    worker_handle: Option<Arc<Mutex<Option<thread::JoinHandle<()>>>>>,
25    /// Number of pending tasks
26    pending_tasks: Arc<Mutex<usize>>,
27}
28
29impl WorkQueue {
30    /// Create a new work queue
31    pub fn new(device: Arc<GpuDevice>) -> Self {
32        let (work_sender, work_receiver) = bounded::<(WorkItem, ResultSender)>(256);
33        let device_clone = device.clone();
34        let pending_tasks = Arc::new(Mutex::new(0));
35        let pending_clone = pending_tasks.clone();
36
37        // Spawn worker thread
38        let handle = thread::spawn(move || {
39            Self::worker_loop(device_clone, work_receiver, pending_clone);
40        });
41
42        Self {
43            device,
44            work_sender: Some(work_sender),
45            worker_handle: Some(Arc::new(Mutex::new(Some(handle)))),
46            pending_tasks,
47        }
48    }
49
50    /// Worker loop for processing tasks
51    fn worker_loop(
52        device: Arc<GpuDevice>,
53        work_receiver: Receiver<(WorkItem, ResultSender)>,
54        pending_tasks: Arc<Mutex<usize>>,
55    ) {
56        while let Ok((work, result_sender)) = work_receiver.recv() {
57            let start = Instant::now();
58
59            // Update workload
60            device.set_workload(1.0);
61
62            // Execute work
63            let result = work(&device);
64
65            // Update workload
66            device.set_workload(0.0);
67
68            // Send result
69            let _ = result_sender.send(result);
70
71            // Update pending tasks
72            {
73                let mut pending = pending_tasks.lock();
74                *pending = pending.saturating_sub(1);
75            }
76
77            let duration = start.elapsed();
78            tracing::debug!(
79                "Task completed on GPU {} in {:?}",
80                device.info.index,
81                duration
82            );
83        }
84    }
85
86    /// Submit work to the queue
87    pub async fn submit_work<F, T>(&self, work: F) -> Result<T>
88    where
89        F: FnOnce(&GpuDevice) -> Result<T> + Send + 'static,
90        T: Send + 'static,
91    {
92        let (result_sender, result_receiver) = bounded(1);
93        let result_arc = Arc::new(Mutex::new(None));
94        let result_clone = result_arc.clone();
95
96        // Wrap work to capture result
97        let work_wrapper: WorkItem = Box::new(move |device| {
98            let result = work(device);
99            match result {
100                Ok(value) => {
101                    *result_clone.lock() = Some(Ok(value));
102                    Ok(())
103                }
104                Err(e) => {
105                    *result_clone.lock() = Some(Err(e));
106                    Ok(())
107                }
108            }
109        });
110
111        // Update pending tasks
112        {
113            let mut pending = self.pending_tasks.lock();
114            *pending = pending.saturating_add(1);
115        }
116
117        // Send work
118        self.work_sender
119            .as_ref()
120            .ok_or_else(|| GpuAdvancedError::WorkStealingError("Work queue is closed".to_string()))?
121            .send((work_wrapper, result_sender))
122            .map_err(|e| {
123                GpuAdvancedError::WorkStealingError(format!("Failed to send work: {}", e))
124            })?;
125
126        // Wait for completion
127        let _ = result_receiver
128            .recv()
129            .map_err(|e| GpuAdvancedError::SyncError(format!("Failed to receive result: {}", e)))?;
130
131        // Extract result
132        result_arc
133            .lock()
134            .take()
135            .ok_or_else(|| GpuAdvancedError::SyncError("Result not available".to_string()))?
136    }
137
138    /// Get number of pending tasks
139    pub fn pending_count(&self) -> usize {
140        *self.pending_tasks.lock()
141    }
142
143    /// Check if queue is empty
144    pub fn is_empty(&self) -> bool {
145        self.pending_count() == 0
146    }
147
148    /// Get associated device
149    pub fn device(&self) -> Arc<GpuDevice> {
150        self.device.clone()
151    }
152}
153
154impl Drop for WorkQueue {
155    fn drop(&mut self) {
156        // Close the channel by dropping the sender
157        // This will cause the worker thread's recv() to return Err,
158        // allowing it to exit cleanly
159        drop(self.work_sender.take());
160
161        // Wait for worker thread to finish
162        if let Some(handle_arc) = self.worker_handle.take() {
163            if let Some(handle) = handle_arc.lock().take() {
164                let _ = handle.join();
165            }
166        }
167    }
168}
169
170/// Work stealing queue for load balancing
171pub struct WorkStealingQueue {
172    /// Local work queue
173    local_queue: Arc<Mutex<Vec<WorkItem>>>,
174    /// Steal threshold (steal if more than this many items)
175    steal_threshold: usize,
176}
177
178impl WorkStealingQueue {
179    /// Create a new work stealing queue
180    pub fn new(steal_threshold: usize) -> Self {
181        Self {
182            local_queue: Arc::new(Mutex::new(Vec::new())),
183            steal_threshold,
184        }
185    }
186
187    /// Push work to local queue
188    pub fn push(&self, work: WorkItem) {
189        let mut queue = self.local_queue.lock();
190        queue.push(work);
191    }
192
193    /// Pop work from local queue
194    pub fn pop(&self) -> Option<WorkItem> {
195        let mut queue = self.local_queue.lock();
196        queue.pop()
197    }
198
199    /// Steal work from this queue (take half if above threshold)
200    pub fn steal(&self) -> Vec<WorkItem> {
201        let mut queue = self.local_queue.lock();
202        let len = queue.len();
203
204        if len <= self.steal_threshold {
205            return Vec::new();
206        }
207
208        let steal_count = len / 2;
209        let split_point = len - steal_count;
210        queue.split_off(split_point)
211    }
212
213    /// Get queue length
214    pub fn len(&self) -> usize {
215        self.local_queue.lock().len()
216    }
217
218    /// Check if queue is empty
219    pub fn is_empty(&self) -> bool {
220        self.len() == 0
221    }
222
223    /// Check if stealing should be allowed
224    pub fn should_allow_stealing(&self) -> bool {
225        self.len() > self.steal_threshold
226    }
227}
228
229/// Batch work submitter for efficient multi-GPU processing
230pub struct BatchSubmitter {
231    /// Work queues for each device
232    queues: Vec<Arc<WorkQueue>>,
233    /// Current queue index for round-robin
234    current_index: Mutex<usize>,
235}
236
237impl BatchSubmitter {
238    /// Create a new batch submitter
239    pub fn new(queues: Vec<Arc<WorkQueue>>) -> Self {
240        Self {
241            queues,
242            current_index: Mutex::new(0),
243        }
244    }
245
246    /// Submit batch of work items
247    pub async fn submit_batch<F, T>(&self, work_items: Vec<F>) -> Result<Vec<T>>
248    where
249        F: FnOnce(&GpuDevice) -> Result<T> + Send + 'static,
250        T: Send + 'static,
251    {
252        if self.queues.is_empty() {
253            return Err(GpuAdvancedError::WorkStealingError(
254                "No work queues available".to_string(),
255            ));
256        }
257
258        let mut futures = Vec::new();
259
260        for work in work_items {
261            // Select queue in round-robin fashion
262            let queue_index = {
263                let mut index = self.current_index.lock();
264                let current = *index;
265                *index = (*index + 1) % self.queues.len();
266                current
267            };
268
269            let queue = &self.queues[queue_index];
270            let future = queue.submit_work(work);
271            futures.push(future);
272        }
273
274        // Wait for all to complete
275        let mut results = Vec::new();
276        for future in futures {
277            results.push(future.await?);
278        }
279
280        Ok(results)
281    }
282
283    /// Submit batch with explicit device assignment
284    pub async fn submit_batch_to_devices<F, T>(&self, work_items: Vec<(usize, F)>) -> Result<Vec<T>>
285    where
286        F: FnOnce(&GpuDevice) -> Result<T> + Send + 'static,
287        T: Send + 'static,
288    {
289        let mut futures = Vec::new();
290
291        for (device_index, work) in work_items {
292            let queue = self
293                .queues
294                .get(device_index)
295                .ok_or(GpuAdvancedError::InvalidGpuIndex {
296                    index: device_index,
297                    total: self.queues.len(),
298                })?;
299
300            let future = queue.submit_work(work);
301            futures.push(future);
302        }
303
304        // Wait for all to complete
305        let mut results = Vec::new();
306        for future in futures {
307            results.push(future.await?);
308        }
309
310        Ok(results)
311    }
312
313    /// Get total pending tasks across all queues
314    pub fn total_pending(&self) -> usize {
315        self.queues.iter().map(|q| q.pending_count()).sum()
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// A task that does nothing; only its presence in the queue matters here.
    fn noop_item() -> WorkItem {
        Box::new(|_device| Ok(()))
    }

    #[test]
    fn test_work_stealing_queue() {
        let queue = WorkStealingQueue::new(10);
        assert!(queue.is_empty());

        queue.push(noop_item());
        assert_eq!(queue.len(), 1);

        let popped = queue.pop();
        assert!(popped.is_some());
        assert!(queue.is_empty());
        assert!(queue.pop().is_none());
    }

    #[test]
    fn test_work_stealing_threshold() {
        let queue = WorkStealingQueue::new(5);

        // Below threshold
        for _ in 0..4 {
            queue.push(noop_item());
        }
        assert!(!queue.should_allow_stealing());

        // Exactly at threshold: stealing requires strictly more items
        queue.push(noop_item());
        assert!(!queue.should_allow_stealing());

        // Above threshold
        for _ in 0..2 {
            queue.push(noop_item());
        }
        assert!(queue.should_allow_stealing());
    }

    #[test]
    fn test_steal_takes_half_above_threshold() {
        let queue = WorkStealingQueue::new(5);
        for _ in 0..10 {
            queue.push(noop_item());
        }

        assert_eq!(queue.steal().len(), 5);
        assert_eq!(queue.len(), 5);

        // Now at the threshold, so nothing more can be stolen.
        assert!(queue.steal().is_empty());
        assert_eq!(queue.len(), 5);
    }

    #[test]
    fn test_steal_rounds_down_odd_lengths() {
        let queue = WorkStealingQueue::new(0);

        // A single item is above a zero threshold, but half of one rounds to zero.
        queue.push(noop_item());
        assert!(queue.steal().is_empty());
        assert_eq!(queue.len(), 1);

        for _ in 0..6 {
            queue.push(noop_item());
        }
        assert_eq!(queue.steal().len(), 3);
        assert_eq!(queue.len(), 4);
    }
}