Skip to main content

oximedia_gpu/
sync.rs

1//! Synchronization primitives for GPU operations
2//!
3//! This module provides abstractions for synchronizing GPU operations,
4//! including fences, barriers, and event synchronization.
5
6use crate::GpuDevice;
7use parking_lot::{Condvar, Mutex};
8use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
12/// Fence for CPU-GPU synchronization
13///
14/// A fence allows the CPU to wait for GPU operations to complete.
15pub struct Fence {
16    device: Arc<wgpu::Device>,
17    signaled: Arc<AtomicBool>,
18    timestamp: Arc<AtomicU64>,
19}
20
21impl Fence {
22    /// Create a new fence
23    #[must_use]
24    pub fn new(device: &GpuDevice) -> Self {
25        Self {
26            device: Arc::clone(device.device()),
27            signaled: Arc::new(AtomicBool::new(false)),
28            timestamp: Arc::new(AtomicU64::new(0)),
29        }
30    }
31
32    /// Signal the fence
33    ///
34    /// This should be called after submitting GPU commands that you want to wait for.
35    pub fn signal(&self) {
36        let _ = self.device.poll(wgpu::PollType::wait_indefinitely());
37        self.signaled.store(true, Ordering::Release);
38        self.timestamp.store(
39            Instant::now().elapsed().as_nanos() as u64,
40            Ordering::Relaxed,
41        );
42    }
43
44    /// Wait for the fence to be signaled
45    ///
46    /// This will block the current thread until the GPU has completed the operations
47    /// that were submitted before the fence was signaled.
48    pub fn wait(&self) {
49        while !self.signaled.load(Ordering::Acquire) {
50            std::thread::yield_now();
51        }
52    }
53
54    /// Wait for the fence with a timeout
55    ///
56    /// # Arguments
57    ///
58    /// * `timeout` - Maximum time to wait
59    ///
60    /// # Returns
61    ///
62    /// True if the fence was signaled within the timeout, false otherwise.
63    #[must_use]
64    pub fn wait_timeout(&self, timeout: Duration) -> bool {
65        let start = Instant::now();
66        while !self.signaled.load(Ordering::Acquire) {
67            if start.elapsed() > timeout {
68                return false;
69            }
70            std::thread::yield_now();
71        }
72        true
73    }
74
75    /// Check if the fence is signaled without blocking
76    #[must_use]
77    pub fn is_signaled(&self) -> bool {
78        self.signaled.load(Ordering::Acquire)
79    }
80
81    /// Reset the fence to unsignaled state
82    pub fn reset(&self) {
83        self.signaled.store(false, Ordering::Release);
84    }
85
86    /// Get the timestamp when the fence was signaled (in nanoseconds)
87    #[must_use]
88    pub fn timestamp(&self) -> Option<u64> {
89        if self.is_signaled() {
90            Some(self.timestamp.load(Ordering::Relaxed))
91        } else {
92            None
93        }
94    }
95}
96
97impl Clone for Fence {
98    fn clone(&self) -> Self {
99        Self {
100            device: Arc::clone(&self.device),
101            signaled: Arc::clone(&self.signaled),
102            timestamp: Arc::clone(&self.timestamp),
103        }
104    }
105}
106
107/// Semaphore for GPU-GPU synchronization
108///
109/// Semaphores are used to synchronize operations between different command queues.
110pub struct Semaphore {
111    value: Arc<AtomicU64>,
112    condvar: Arc<Condvar>,
113    mutex: Arc<Mutex<()>>,
114}
115
116impl Semaphore {
117    /// Create a new semaphore with an initial value
118    #[must_use]
119    pub fn new(initial_value: u64) -> Self {
120        Self {
121            value: Arc::new(AtomicU64::new(initial_value)),
122            condvar: Arc::new(Condvar::new()),
123            mutex: Arc::new(Mutex::new(())),
124        }
125    }
126
127    /// Signal the semaphore (increment value)
128    pub fn signal(&self) {
129        self.value.fetch_add(1, Ordering::Release);
130        self.condvar.notify_all();
131    }
132
133    /// Wait for the semaphore (decrement value, block if zero)
134    pub fn wait(&self) {
135        loop {
136            let current = self.value.load(Ordering::Acquire);
137            if current > 0 {
138                if self
139                    .value
140                    .compare_exchange_weak(
141                        current,
142                        current - 1,
143                        Ordering::AcqRel,
144                        Ordering::Acquire,
145                    )
146                    .is_ok()
147                {
148                    return;
149                }
150            } else {
151                let mut guard = self.mutex.lock();
152                self.condvar.wait(&mut guard);
153            }
154        }
155    }
156
157    /// Try to wait for the semaphore without blocking
158    ///
159    /// # Returns
160    ///
161    /// True if the semaphore was successfully acquired, false otherwise.
162    #[must_use]
163    pub fn try_wait(&self) -> bool {
164        loop {
165            let current = self.value.load(Ordering::Acquire);
166            if current > 0 {
167                match self.value.compare_exchange_weak(
168                    current,
169                    current - 1,
170                    Ordering::AcqRel,
171                    Ordering::Acquire,
172                ) {
173                    Ok(_) => return true,
174                    Err(_) => continue,
175                }
176            }
177            return false;
178        }
179    }
180
181    /// Get the current semaphore value
182    #[must_use]
183    pub fn value(&self) -> u64 {
184        self.value.load(Ordering::Acquire)
185    }
186
187    /// Reset the semaphore to a specific value
188    pub fn reset(&self, value: u64) {
189        self.value.store(value, Ordering::Release);
190        self.condvar.notify_all();
191    }
192}
193
194impl Clone for Semaphore {
195    fn clone(&self) -> Self {
196        Self {
197            value: Arc::clone(&self.value),
198            condvar: Arc::clone(&self.condvar),
199            mutex: Arc::clone(&self.mutex),
200        }
201    }
202}
203
204/// Event for signaling completion of operations
205pub struct Event {
206    signaled: Arc<AtomicBool>,
207    condvar: Arc<Condvar>,
208    mutex: Arc<Mutex<()>>,
209}
210
211impl Event {
212    /// Create a new event
213    #[must_use]
214    pub fn new() -> Self {
215        Self {
216            signaled: Arc::new(AtomicBool::new(false)),
217            condvar: Arc::new(Condvar::new()),
218            mutex: Arc::new(Mutex::new(())),
219        }
220    }
221
222    /// Signal the event
223    pub fn signal(&self) {
224        self.signaled.store(true, Ordering::Release);
225        self.condvar.notify_all();
226    }
227
228    /// Wait for the event to be signaled
229    pub fn wait(&self) {
230        while !self.signaled.load(Ordering::Acquire) {
231            let mut guard = self.mutex.lock();
232            self.condvar.wait(&mut guard);
233        }
234    }
235
236    /// Wait for the event with a timeout
237    ///
238    /// # Arguments
239    ///
240    /// * `timeout` - Maximum time to wait
241    ///
242    /// # Returns
243    ///
244    /// True if the event was signaled within the timeout, false otherwise.
245    #[must_use]
246    pub fn wait_timeout(&self, timeout: Duration) -> bool {
247        let start = Instant::now();
248        while !self.signaled.load(Ordering::Acquire) {
249            if start.elapsed() > timeout {
250                return false;
251            }
252            let mut guard = self.mutex.lock();
253            let remaining = timeout.saturating_sub(start.elapsed());
254            if remaining.is_zero() {
255                return false;
256            }
257            self.condvar.wait_for(&mut guard, remaining);
258        }
259        true
260    }
261
262    /// Check if the event is signaled
263    #[must_use]
264    pub fn is_signaled(&self) -> bool {
265        self.signaled.load(Ordering::Acquire)
266    }
267
268    /// Reset the event to unsignaled state
269    pub fn reset(&self) {
270        self.signaled.store(false, Ordering::Release);
271    }
272}
273
274impl Default for Event {
275    fn default() -> Self {
276        Self::new()
277    }
278}
279
280impl Clone for Event {
281    fn clone(&self) -> Self {
282        Self {
283            signaled: Arc::clone(&self.signaled),
284            condvar: Arc::clone(&self.condvar),
285            mutex: Arc::clone(&self.mutex),
286        }
287    }
288}
289
290/// Barrier for synchronizing multiple operations
291pub struct Barrier {
292    total_count: usize,
293    current_count: Arc<AtomicU64>,
294    generation: Arc<AtomicU64>,
295    condvar: Arc<Condvar>,
296    mutex: Arc<Mutex<()>>,
297}
298
299impl Barrier {
300    /// Create a new barrier
301    ///
302    /// # Arguments
303    ///
304    /// * `count` - Number of threads/operations that must reach the barrier
305    #[must_use]
306    pub fn new(count: usize) -> Self {
307        Self {
308            total_count: count,
309            current_count: Arc::new(AtomicU64::new(0)),
310            generation: Arc::new(AtomicU64::new(0)),
311            condvar: Arc::new(Condvar::new()),
312            mutex: Arc::new(Mutex::new(())),
313        }
314    }
315
316    /// Wait at the barrier
317    ///
318    /// This will block until all threads/operations have reached the barrier.
319    pub fn wait(&self) {
320        let gen = self.generation.load(Ordering::Acquire);
321        let count = self.current_count.fetch_add(1, Ordering::AcqRel) + 1;
322
323        if count >= self.total_count as u64 {
324            // Last thread resets the barrier
325            self.current_count.store(0, Ordering::Release);
326            self.generation.fetch_add(1, Ordering::Release);
327            self.condvar.notify_all();
328        } else {
329            // Wait for all threads
330            let mut guard = self.mutex.lock();
331            while gen == self.generation.load(Ordering::Acquire) {
332                self.condvar.wait(&mut guard);
333            }
334        }
335    }
336
337    /// Get the total count required for the barrier
338    #[must_use]
339    pub fn count(&self) -> usize {
340        self.total_count
341    }
342
343    /// Get the current number of waiting threads
344    #[must_use]
345    pub fn waiting(&self) -> u64 {
346        self.current_count.load(Ordering::Acquire)
347    }
348}
349
350impl Clone for Barrier {
351    fn clone(&self) -> Self {
352        Self {
353            total_count: self.total_count,
354            current_count: Arc::clone(&self.current_count),
355            generation: Arc::clone(&self.generation),
356            condvar: Arc::clone(&self.condvar),
357            mutex: Arc::clone(&self.mutex),
358        }
359    }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Non-blocking acquire/release round trip on a single permit.
    #[test]
    fn test_semaphore() {
        let sem = Semaphore::new(1);
        assert_eq!(sem.value(), 1);

        assert!(sem.try_wait());
        assert_eq!(sem.value(), 0);

        assert!(!sem.try_wait());
        assert_eq!(sem.value(), 0);

        sem.signal();
        assert_eq!(sem.value(), 1);

        assert!(sem.try_wait());
        assert_eq!(sem.value(), 0);
    }

    /// `wait` must return without blocking while permits are available, and
    /// `reset` must overwrite the counter.
    #[test]
    fn test_semaphore_wait_with_available_permit() {
        let sem = Semaphore::new(2);

        sem.wait();
        assert_eq!(sem.value(), 1);

        sem.reset(3);
        assert_eq!(sem.value(), 3);
    }

    /// Signal / reset toggle the observable state.
    #[test]
    fn test_event() {
        let event = Event::new();
        assert!(!event.is_signaled());

        event.signal();
        assert!(event.is_signaled());

        event.reset();
        assert!(!event.is_signaled());
    }

    /// An unsignaled event times out; a signaled one is reported immediately.
    #[test]
    fn test_event_wait_timeout() {
        let event = Event::new();
        assert!(!event.wait_timeout(Duration::from_millis(5)));

        event.signal();
        assert!(event.wait_timeout(Duration::from_millis(5)));
        // Must not block once the event is signaled.
        event.wait();
    }

    /// A clone moved to another thread drives the same underlying event.
    #[test]
    fn test_event_clone_shares_state() {
        let event = Event::new();
        let remote = event.clone();

        // Join before waiting so the test never depends on wakeup timing.
        std::thread::spawn(move || remote.signal())
            .join()
            .expect("signalling thread panicked");

        assert!(event.is_signaled());
        assert!(event.wait_timeout(Duration::from_secs(1)));
    }

    /// Construction exposes the configured count and starts with no arrivals.
    #[test]
    fn test_barrier() {
        let barrier = Barrier::new(3);
        assert_eq!(barrier.count(), 3);
        assert_eq!(barrier.waiting(), 0);
    }

    /// With a single participant every arrival completes a round immediately,
    /// and the barrier is reusable afterwards.
    #[test]
    fn test_barrier_single_participant() {
        let barrier = Barrier::new(1);

        barrier.wait();
        assert_eq!(barrier.waiting(), 0);

        barrier.wait();
        assert_eq!(barrier.waiting(), 0);
    }
}