// oxigdal_gpu_advanced/multi_gpu/sync.rs

//! Advanced multi-GPU synchronization primitives.
//!
//! This module provides sophisticated synchronization mechanisms for coordinating
//! operations across multiple GPUs including barriers, events, and cross-GPU transfers.

use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};

use parking_lot::{Mutex, RwLock};
use tokio::sync::{Notify, Semaphore};
use wgpu::{Buffer, Device, Queue};

use crate::error::{GpuAdvancedError, Result};
13
/// Multi-GPU synchronization manager.
///
/// Owns the per-GPU device/queue lists and the shared registries of named
/// barriers and events. Cloning is cheap: all state is behind `Arc`s, so every
/// clone observes the same registries and fence pool.
#[derive(Clone)]
pub struct SyncManager {
    /// One logical device per GPU; index-aligned with `queues`.
    devices: Arc<Vec<Arc<Device>>>,
    /// Submission queue for the device at the same index in `devices`.
    queues: Arc<Vec<Arc<Queue>>>,
    /// Named barriers, shared across clones.
    barriers: Arc<RwLock<HashMap<String, Arc<Barrier>>>>,
    /// Named events, shared across clones.
    events: Arc<RwLock<HashMap<String, Arc<Event>>>>,
    /// Pool of reusable fence handles.
    fence_pool: Arc<Mutex<FencePool>>,
}
23
24impl SyncManager {
25    /// Create a new synchronization manager
26    pub fn new(devices: Vec<Arc<Device>>, queues: Vec<Arc<Queue>>) -> Result<Self> {
27        if devices.len() != queues.len() {
28            return Err(GpuAdvancedError::InvalidConfiguration(
29                "Device and queue count mismatch".to_string(),
30            ));
31        }
32
33        Ok(Self {
34            devices: Arc::new(devices),
35            queues: Arc::new(queues),
36            barriers: Arc::new(RwLock::new(HashMap::new())),
37            events: Arc::new(RwLock::new(HashMap::new())),
38            fence_pool: Arc::new(Mutex::new(FencePool::new())),
39        })
40    }
41
42    /// Create a barrier for N GPUs
43    pub fn create_barrier(&self, name: &str, gpu_count: usize) -> Result<Arc<Barrier>> {
44        if gpu_count == 0 || gpu_count > self.devices.len() {
45            return Err(GpuAdvancedError::ConfigError(format!(
46                "Invalid GPU count {} for barrier (available: {})",
47                gpu_count,
48                self.devices.len()
49            )));
50        }
51
52        let barrier = Arc::new(Barrier::new(gpu_count));
53        self.barriers
54            .write()
55            .insert(name.to_string(), barrier.clone());
56        Ok(barrier)
57    }
58
59    /// Get an existing barrier
60    pub fn get_barrier(&self, name: &str) -> Option<Arc<Barrier>> {
61        self.barriers.read().get(name).cloned()
62    }
63
64    /// Create an event for GPU-to-GPU synchronization
65    pub fn create_event(&self, name: &str) -> Arc<Event> {
66        let event = Arc::new(Event::new());
67        self.events.write().insert(name.to_string(), event.clone());
68        event
69    }
70
71    /// Get an existing event
72    pub fn get_event(&self, name: &str) -> Option<Arc<Event>> {
73        self.events.read().get(name).cloned()
74    }
75
76    /// Transfer data between GPUs
77    pub async fn transfer_between_gpus(
78        &self,
79        src_gpu_idx: usize,
80        dst_gpu_idx: usize,
81        src_buffer: &Buffer,
82        dst_buffer: &Buffer,
83        size: u64,
84    ) -> Result<Duration> {
85        if src_gpu_idx >= self.devices.len() || dst_gpu_idx >= self.devices.len() {
86            return Err(GpuAdvancedError::InvalidConfiguration(
87                "GPU index out of bounds".to_string(),
88            ));
89        }
90
91        let start = Instant::now();
92
93        // Create staging buffer on host
94        let staging_buffer = self.devices[src_gpu_idx].create_buffer(&wgpu::BufferDescriptor {
95            label: Some("cross_gpu_staging"),
96            size,
97            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
98            mapped_at_creation: false,
99        });
100
101        // Copy from source GPU to staging
102        let mut encoder =
103            self.devices[src_gpu_idx].create_command_encoder(&wgpu::CommandEncoderDescriptor {
104                label: Some("cross_gpu_copy_src"),
105            });
106        encoder.copy_buffer_to_buffer(src_buffer, 0, &staging_buffer, 0, size);
107        self.queues[src_gpu_idx].submit(Some(encoder.finish()));
108
109        // Wait for copy to complete
110        let slice = staging_buffer.slice(..);
111        let (tx, rx) = futures::channel::oneshot::channel();
112        slice.map_async(wgpu::MapMode::Read, move |result| {
113            let _ = tx.send(result);
114        });
115        // wgpu 28 polls automatically - no explicit poll needed
116
117        rx.await
118            .map_err(|_| GpuAdvancedError::SyncError("Transfer channel closed".to_string()))?
119            .map_err(|e| GpuAdvancedError::SyncError(format!("Map async failed: {:?}", e)))?;
120
121        // Read from staging
122        let data = slice.get_mapped_range();
123        let vec_data: Vec<u8> = data.to_vec();
124        drop(data);
125        staging_buffer.unmap();
126
127        // Create destination staging buffer
128        let dst_staging = self.devices[dst_gpu_idx].create_buffer(&wgpu::BufferDescriptor {
129            label: Some("cross_gpu_staging_dst"),
130            size,
131            usage: wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::MAP_WRITE,
132            mapped_at_creation: true,
133        });
134
135        // Write to destination staging
136        {
137            let mut mapped = dst_staging.slice(..).get_mapped_range_mut();
138            mapped.copy_from_slice(&vec_data);
139        }
140        dst_staging.unmap();
141
142        // Copy from staging to destination GPU
143        let mut encoder =
144            self.devices[dst_gpu_idx].create_command_encoder(&wgpu::CommandEncoderDescriptor {
145                label: Some("cross_gpu_copy_dst"),
146            });
147        encoder.copy_buffer_to_buffer(&dst_staging, 0, dst_buffer, 0, size);
148        self.queues[dst_gpu_idx].submit(Some(encoder.finish()));
149
150        // wgpu 28 polls automatically - submission completes asynchronously
151
152        Ok(start.elapsed())
153    }
154
155    /// Acquire a fence from the pool
156    pub fn acquire_fence(&self) -> Fence {
157        self.fence_pool.lock().acquire()
158    }
159
160    /// Release a fence back to the pool
161    pub fn release_fence(&self, fence: Fence) {
162        self.fence_pool.lock().release(fence);
163    }
164
165    /// Number of available GPUs
166    pub fn gpu_count(&self) -> usize {
167        self.devices.len()
168    }
169}
170
/// Reusable barrier that releases once `count` participants have waited.
pub struct Barrier {
    /// Number of participants required to release each generation.
    count: usize,
    /// How many participants have arrived in the current generation.
    arrived: Mutex<usize>,
    /// Incremented on every release; lets waiters detect that their
    /// generation has completed.
    generation: Mutex<usize>,
    /// Wakes waiters when a generation completes.
    notify: Notify,
}
178
179impl Barrier {
180    /// Create a new barrier
181    pub fn new(count: usize) -> Self {
182        Self {
183            count,
184            arrived: Mutex::new(0),
185            generation: Mutex::new(0),
186            notify: Notify::new(),
187        }
188    }
189
190    /// Wait at the barrier
191    pub async fn wait(&self) -> Result<()> {
192        let current_gen = *self.generation.lock();
193
194        let arrived = {
195            let mut arrived = self.arrived.lock();
196            *arrived += 1;
197            *arrived
198        };
199
200        if arrived == self.count {
201            // Last one to arrive, reset and notify all
202            {
203                let mut arrived = self.arrived.lock();
204                *arrived = 0;
205            }
206            {
207                let mut gen_val = self.generation.lock();
208                *gen_val += 1;
209            }
210            self.notify.notify_waiters();
211            Ok(())
212        } else {
213            // Wait for notification
214            loop {
215                self.notify.notified().await;
216                let gen_val = *self.generation.lock();
217                if gen_val > current_gen {
218                    break;
219                }
220            }
221            Ok(())
222        }
223    }
224
225    /// Wait with timeout
226    pub async fn wait_timeout(&self, timeout: Duration) -> Result<bool> {
227        let wait_future = self.wait();
228        match tokio::time::timeout(timeout, wait_future).await {
229            Ok(Ok(())) => Ok(true),
230            Ok(Err(e)) => Err(e),
231            Err(_) => Ok(false), // Timeout
232        }
233    }
234
235    /// Get the barrier count
236    pub fn count(&self) -> usize {
237        self.count
238    }
239
240    /// Get the current number of waiting threads
241    pub fn waiting(&self) -> usize {
242        *self.arrived.lock()
243    }
244}
245
/// Resettable event for GPU-to-GPU signaling.
pub struct Event {
    /// Whether the event is currently signaled.
    signaled: Mutex<bool>,
    /// Wakes tasks blocked in `wait`/`wait_timeout` when the event fires.
    notify: Notify,
    /// Instant of the most recent `signal`; `None` while unsignaled.
    timestamp: Mutex<Option<Instant>>,
}
252
253impl Event {
254    /// Create a new event
255    pub fn new() -> Self {
256        Self {
257            signaled: Mutex::new(false),
258            notify: Notify::new(),
259            timestamp: Mutex::new(None),
260        }
261    }
262
263    /// Signal the event
264    pub fn signal(&self) {
265        *self.signaled.lock() = true;
266        *self.timestamp.lock() = Some(Instant::now());
267        self.notify.notify_waiters();
268    }
269
270    /// Reset the event
271    pub fn reset(&self) {
272        *self.signaled.lock() = false;
273        *self.timestamp.lock() = None;
274    }
275
276    /// Wait for the event to be signaled
277    pub async fn wait(&self) {
278        if *self.signaled.lock() {
279            return;
280        }
281        self.notify.notified().await;
282    }
283
284    /// Wait with timeout
285    pub async fn wait_timeout(&self, timeout: Duration) -> bool {
286        if *self.signaled.lock() {
287            return true;
288        }
289        tokio::time::timeout(timeout, self.notify.notified())
290            .await
291            .is_ok()
292    }
293
294    /// Check if the event is signaled
295    pub fn is_signaled(&self) -> bool {
296        *self.signaled.lock()
297    }
298
299    /// Get the timestamp when the event was signaled
300    pub fn timestamp(&self) -> Option<Instant> {
301        *self.timestamp.lock()
302    }
303}
304
305impl Default for Event {
306    fn default() -> Self {
307        Self::new()
308    }
309}
310
/// Fence handle for command-buffer synchronization.
///
/// Fences are lightweight ids minted by [`FencePool`]; equality and hashing
/// are by id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Fence {
    id: u64,
}

impl Fence {
    /// Wrap a raw id in a fence handle.
    fn new(id: u64) -> Self {
        Fence { id }
    }

    /// The fence's id.
    pub fn id(&self) -> u64 {
        self.id
    }
}

/// Pool of reusable fences.
///
/// Released fences are recycled (up to `max_pool_size` of them) before any
/// fresh id is minted.
struct FencePool {
    /// Next id to mint when the free list is empty.
    next_id: u64,
    /// Fences returned to the pool and ready for reuse (LIFO).
    available: Vec<Fence>,
    /// Cap on how many released fences are retained.
    max_pool_size: usize,
}

impl FencePool {
    /// Create an empty pool retaining at most 256 released fences.
    fn new() -> Self {
        FencePool {
            next_id: 0,
            available: Vec::new(),
            max_pool_size: 256,
        }
    }

    /// Hand out a recycled fence if one is available, otherwise mint one.
    fn acquire(&mut self) -> Fence {
        match self.available.pop() {
            Some(recycled) => recycled,
            None => {
                let fresh = Fence::new(self.next_id);
                self.next_id += 1;
                fresh
            }
        }
    }

    /// Return a fence to the pool; silently dropped once the pool is full.
    fn release(&mut self, fence: Fence) {
        if self.available.len() < self.max_pool_size {
            self.available.push(fence);
        }
    }
}
360
/// Semaphore bounding concurrent access to GPU resources.
pub struct GpuSemaphore {
    /// Shared tokio semaphore; clones of `GpuSemaphore` share these permits.
    inner: Arc<Semaphore>,
}
365
366impl GpuSemaphore {
367    /// Create a new semaphore with the given permit count
368    pub fn new(permits: usize) -> Self {
369        Self {
370            inner: Arc::new(Semaphore::new(permits)),
371        }
372    }
373
374    /// Acquire a permit
375    pub async fn acquire(&self) -> Result<SemaphoreGuard<'_>> {
376        let permit =
377            self.inner.acquire().await.map_err(|e| {
378                GpuAdvancedError::SyncError(format!("Semaphore acquire failed: {}", e))
379            })?;
380        Ok(SemaphoreGuard { _permit: permit })
381    }
382
383    /// Try to acquire a permit without waiting
384    pub fn try_acquire(&self) -> Option<SemaphoreGuard<'_>> {
385        self.inner
386            .try_acquire()
387            .ok()
388            .map(|permit| SemaphoreGuard { _permit: permit })
389    }
390
391    /// Get available permits
392    pub fn available_permits(&self) -> usize {
393        self.inner.available_permits()
394    }
395}
396
397impl Clone for GpuSemaphore {
398    fn clone(&self) -> Self {
399        Self {
400            inner: Arc::clone(&self.inner),
401        }
402    }
403}
404
/// RAII guard holding one semaphore permit; the permit is returned when the
/// guard is dropped.
pub struct SemaphoreGuard<'a> {
    /// Held permit, kept only for its `Drop` behavior.
    _permit: tokio::sync::SemaphorePermit<'a>,
}
409
/// Counters describing synchronization activity.
#[derive(Debug, Clone, Default)]
pub struct SyncStats {
    /// Number of barrier waits
    pub barrier_waits: u64,
    /// Number of event signals
    pub event_signals: u64,
    /// Number of cross-GPU transfers
    pub cross_gpu_transfers: u64,
    /// Total transfer time
    pub total_transfer_time: Duration,
    /// Total bytes transferred
    pub total_bytes_transferred: u64,
}

impl SyncStats {
    /// Average transfer bandwidth in GB/s (decimal: 1 GB = 1e9 bytes), or
    /// `None` when nothing has been transferred yet.
    pub fn average_bandwidth_gbs(&self) -> Option<f64> {
        if self.total_bytes_transferred == 0 || self.total_transfer_time.is_zero() {
            return None;
        }
        let seconds = self.total_transfer_time.as_secs_f64();
        Some(self.total_bytes_transferred as f64 / seconds / 1_000_000_000.0)
    }
}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Three tasks meeting at a 3-way barrier must all pass, leaving the
    /// barrier reset afterwards.
    #[tokio::test]
    async fn test_barrier() {
        let barrier = Arc::new(Barrier::new(3));

        let handles: Vec<_> = (0..3)
            .map(|i| {
                let b = Arc::clone(&barrier);
                tokio::spawn(async move {
                    println!("Task {} waiting at barrier", i);
                    b.wait().await.ok();
                    println!("Task {} passed barrier", i);
                })
            })
            .collect();

        for handle in handles {
            handle.await.ok();
        }

        assert_eq!(barrier.waiting(), 0);
    }

    /// A task blocked on an event is released once the event is signaled.
    #[tokio::test]
    async fn test_event() {
        let event = Arc::new(Event::new());
        assert!(!event.is_signaled());

        let waiter = {
            let e = Arc::clone(&event);
            tokio::spawn(async move { e.wait().await })
        };

        tokio::time::sleep(Duration::from_millis(10)).await;
        event.signal();
        assert!(event.is_signaled());

        waiter.await.ok();
    }

    /// Permits are consumed by guards and restored when a guard drops.
    #[tokio::test]
    async fn test_semaphore() {
        let sem = GpuSemaphore::new(2);
        assert_eq!(sem.available_permits(), 2);

        let _guard1 = sem.acquire().await.ok();
        assert_eq!(sem.available_permits(), 1);

        let _guard2 = sem.acquire().await.ok();
        assert_eq!(sem.available_permits(), 0);

        drop(_guard1);
        assert_eq!(sem.available_permits(), 1);
    }

    /// The pool mints distinct ids and recycles released fences.
    #[test]
    fn test_fence_pool() {
        let mut pool = FencePool::new();
        let f1 = pool.acquire();
        let f2 = pool.acquire();
        assert_ne!(f1.id(), f2.id());

        pool.release(f1);
        assert_eq!(pool.acquire(), f1);
    }
}