//! oxigdal_gpu_advanced/multi_gpu/work_queue.rs
use super::GpuDevice;
4use crate::error::{GpuAdvancedError, Result};
5use crossbeam_channel::{Receiver, Sender, bounded};
6use parking_lot::Mutex;
7use std::sync::Arc;
8use std::thread;
9use std::time::Instant;
10
11type WorkItem = Box<dyn FnOnce(&GpuDevice) -> Result<()> + Send>;
13
14type ResultSender = Sender<Result<()>>;
16
17pub struct WorkQueue {
19 device: Arc<GpuDevice>,
21 work_sender: Option<Sender<(WorkItem, ResultSender)>>,
23 worker_handle: Option<Arc<Mutex<Option<thread::JoinHandle<()>>>>>,
25 pending_tasks: Arc<Mutex<usize>>,
27}
28
29impl WorkQueue {
30 pub fn new(device: Arc<GpuDevice>) -> Self {
32 let (work_sender, work_receiver) = bounded::<(WorkItem, ResultSender)>(256);
33 let device_clone = device.clone();
34 let pending_tasks = Arc::new(Mutex::new(0));
35 let pending_clone = pending_tasks.clone();
36
37 let handle = thread::spawn(move || {
39 Self::worker_loop(device_clone, work_receiver, pending_clone);
40 });
41
42 Self {
43 device,
44 work_sender: Some(work_sender),
45 worker_handle: Some(Arc::new(Mutex::new(Some(handle)))),
46 pending_tasks,
47 }
48 }
49
50 fn worker_loop(
52 device: Arc<GpuDevice>,
53 work_receiver: Receiver<(WorkItem, ResultSender)>,
54 pending_tasks: Arc<Mutex<usize>>,
55 ) {
56 while let Ok((work, result_sender)) = work_receiver.recv() {
57 let start = Instant::now();
58
59 device.set_workload(1.0);
61
62 let result = work(&device);
64
65 device.set_workload(0.0);
67
68 let _ = result_sender.send(result);
70
71 {
73 let mut pending = pending_tasks.lock();
74 *pending = pending.saturating_sub(1);
75 }
76
77 let duration = start.elapsed();
78 tracing::debug!(
79 "Task completed on GPU {} in {:?}",
80 device.info.index,
81 duration
82 );
83 }
84 }
85
86 pub async fn submit_work<F, T>(&self, work: F) -> Result<T>
88 where
89 F: FnOnce(&GpuDevice) -> Result<T> + Send + 'static,
90 T: Send + 'static,
91 {
92 let (result_sender, result_receiver) = bounded(1);
93 let result_arc = Arc::new(Mutex::new(None));
94 let result_clone = result_arc.clone();
95
96 let work_wrapper: WorkItem = Box::new(move |device| {
98 let result = work(device);
99 match result {
100 Ok(value) => {
101 *result_clone.lock() = Some(Ok(value));
102 Ok(())
103 }
104 Err(e) => {
105 *result_clone.lock() = Some(Err(e));
106 Ok(())
107 }
108 }
109 });
110
111 {
113 let mut pending = self.pending_tasks.lock();
114 *pending = pending.saturating_add(1);
115 }
116
117 self.work_sender
119 .as_ref()
120 .ok_or_else(|| GpuAdvancedError::WorkStealingError("Work queue is closed".to_string()))?
121 .send((work_wrapper, result_sender))
122 .map_err(|e| {
123 GpuAdvancedError::WorkStealingError(format!("Failed to send work: {}", e))
124 })?;
125
126 let _ = result_receiver
128 .recv()
129 .map_err(|e| GpuAdvancedError::SyncError(format!("Failed to receive result: {}", e)))?;
130
131 result_arc
133 .lock()
134 .take()
135 .ok_or_else(|| GpuAdvancedError::SyncError("Result not available".to_string()))?
136 }
137
138 pub fn pending_count(&self) -> usize {
140 *self.pending_tasks.lock()
141 }
142
143 pub fn is_empty(&self) -> bool {
145 self.pending_count() == 0
146 }
147
148 pub fn device(&self) -> Arc<GpuDevice> {
150 self.device.clone()
151 }
152}
153
154impl Drop for WorkQueue {
155 fn drop(&mut self) {
156 drop(self.work_sender.take());
160
161 if let Some(handle_arc) = self.worker_handle.take() {
163 if let Some(handle) = handle_arc.lock().take() {
164 let _ = handle.join();
165 }
166 }
167 }
168}
169
/// Mutex-protected LIFO queue of work items that supports bulk stealing.
pub struct WorkStealingQueue {
    /// Items owned by this worker; pushed and popped at the back.
    local_queue: Arc<Mutex<Vec<WorkItem>>>,
    /// Stealing is only allowed while the queue holds strictly more than this
    /// many items.
    steal_threshold: usize,
}
177
178impl WorkStealingQueue {
179 pub fn new(steal_threshold: usize) -> Self {
181 Self {
182 local_queue: Arc::new(Mutex::new(Vec::new())),
183 steal_threshold,
184 }
185 }
186
187 pub fn push(&self, work: WorkItem) {
189 let mut queue = self.local_queue.lock();
190 queue.push(work);
191 }
192
193 pub fn pop(&self) -> Option<WorkItem> {
195 let mut queue = self.local_queue.lock();
196 queue.pop()
197 }
198
199 pub fn steal(&self) -> Vec<WorkItem> {
201 let mut queue = self.local_queue.lock();
202 let len = queue.len();
203
204 if len <= self.steal_threshold {
205 return Vec::new();
206 }
207
208 let steal_count = len / 2;
209 let split_point = len - steal_count;
210 queue.split_off(split_point)
211 }
212
213 pub fn len(&self) -> usize {
215 self.local_queue.lock().len()
216 }
217
218 pub fn is_empty(&self) -> bool {
220 self.len() == 0
221 }
222
223 pub fn should_allow_stealing(&self) -> bool {
225 self.len() > self.steal_threshold
226 }
227}
228
/// Round-robin dispatcher that spreads batches of work across several
/// per-device work queues.
pub struct BatchSubmitter {
    /// One `WorkQueue` per GPU; index into this vector is the device index.
    queues: Vec<Arc<WorkQueue>>,
    /// Cursor for round-robin queue selection, shared across callers.
    current_index: Mutex<usize>,
}
236
237impl BatchSubmitter {
238 pub fn new(queues: Vec<Arc<WorkQueue>>) -> Self {
240 Self {
241 queues,
242 current_index: Mutex::new(0),
243 }
244 }
245
246 pub async fn submit_batch<F, T>(&self, work_items: Vec<F>) -> Result<Vec<T>>
248 where
249 F: FnOnce(&GpuDevice) -> Result<T> + Send + 'static,
250 T: Send + 'static,
251 {
252 if self.queues.is_empty() {
253 return Err(GpuAdvancedError::WorkStealingError(
254 "No work queues available".to_string(),
255 ));
256 }
257
258 let mut futures = Vec::new();
259
260 for work in work_items {
261 let queue_index = {
263 let mut index = self.current_index.lock();
264 let current = *index;
265 *index = (*index + 1) % self.queues.len();
266 current
267 };
268
269 let queue = &self.queues[queue_index];
270 let future = queue.submit_work(work);
271 futures.push(future);
272 }
273
274 let mut results = Vec::new();
276 for future in futures {
277 results.push(future.await?);
278 }
279
280 Ok(results)
281 }
282
283 pub async fn submit_batch_to_devices<F, T>(&self, work_items: Vec<(usize, F)>) -> Result<Vec<T>>
285 where
286 F: FnOnce(&GpuDevice) -> Result<T> + Send + 'static,
287 T: Send + 'static,
288 {
289 let mut futures = Vec::new();
290
291 for (device_index, work) in work_items {
292 let queue = self
293 .queues
294 .get(device_index)
295 .ok_or(GpuAdvancedError::InvalidGpuIndex {
296 index: device_index,
297 total: self.queues.len(),
298 })?;
299
300 let future = queue.submit_work(work);
301 futures.push(future);
302 }
303
304 let mut results = Vec::new();
306 for future in futures {
307 results.push(future.await?);
308 }
309
310 Ok(results)
311 }
312
313 pub fn total_pending(&self) -> usize {
315 self.queues.iter().map(|q| q.pending_count()).sum()
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Push/pop round-trip on an initially empty stealing queue.
    #[test]
    fn test_work_stealing_queue() {
        let queue = WorkStealingQueue::new(10);
        assert!(queue.is_empty());

        queue.push(Box::new(|_device| Ok(())) as WorkItem);
        assert_eq!(queue.len(), 1);

        assert!(queue.pop().is_some());
        assert!(queue.is_empty());
    }

    /// Stealing is gated on the queue length exceeding the threshold.
    #[test]
    fn test_work_stealing_threshold() {
        let queue = WorkStealingQueue::new(5);

        // 4 items: at or below the threshold of 5, stealing disallowed.
        for _ in 0..4 {
            queue.push(Box::new(|_device| Ok(())));
        }
        assert!(!queue.should_allow_stealing());

        // 7 items total: above the threshold, stealing allowed.
        for _ in 0..3 {
            queue.push(Box::new(|_device| Ok(())));
        }
        assert!(queue.should_allow_stealing());
    }
}