Skip to main content

oximedia_gpu/
barrier_manager.rs

1#![allow(dead_code)]
2//! GPU barrier and synchronization management.
3//!
4//! This module provides a barrier manager that tracks and optimizes
5//! memory and execution barriers between GPU operations, ensuring
6//! correct ordering of read/write operations across command buffers.
7
8use std::collections::HashMap;
9use std::fmt;
10
/// Kind of access an operation performs on a tracked resource.
///
/// The barrier manager compares the previous and the next `AccessType` of a
/// resource to decide whether a barrier has to be queued between them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AccessType {
    /// Nothing has accessed the resource yet (initial state).
    None,
    /// A shader reads the resource.
    ShaderRead,
    /// A shader (compute or fragment) writes the resource.
    ShaderWrite,
    /// The resource is the source of a copy.
    TransferSrc,
    /// The resource is the destination of a copy.
    TransferDst,
    /// The CPU reads the resource back.
    HostRead,
    /// The CPU uploads data into the resource.
    HostWrite,
}

impl fmt::Display for AccessType {
    /// Writes the bare variant name, e.g. `ShaderWrite`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            AccessType::None => "None",
            AccessType::ShaderRead => "ShaderRead",
            AccessType::ShaderWrite => "ShaderWrite",
            AccessType::TransferSrc => "TransferSrc",
            AccessType::TransferDst => "TransferDst",
            AccessType::HostRead => "HostRead",
            AccessType::HostWrite => "HostWrite",
        };
        f.write_str(label)
    }
}
43
/// Pipeline stage at which an access happens; barriers record the stage of
/// both the previous and the next access.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PipelineStage {
    /// Top of pipe: before any operation has started.
    TopOfPipe,
    /// Compute shader execution.
    Compute,
    /// Transfer / copy operations.
    Transfer,
    /// Host (CPU) access.
    Host,
    /// Bottom of pipe: after all operations have completed.
    BottomOfPipe,
}

impl fmt::Display for PipelineStage {
    /// Writes the bare variant name, e.g. `Compute`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            PipelineStage::TopOfPipe => "TopOfPipe",
            PipelineStage::Compute => "Compute",
            PipelineStage::Transfer => "Transfer",
            PipelineStage::Host => "Host",
            PipelineStage::BottomOfPipe => "BottomOfPipe",
        };
        f.write_str(label)
    }
}
70
/// Handle naming one resource tracked by the barrier manager.
///
/// The wrapped `u64` is public, so callers pick their own numbering scheme.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ResourceId(pub u64);
74
75/// Describes a single barrier between two access types.
76#[derive(Debug, Clone, PartialEq, Eq)]
77pub struct BarrierDesc {
78    /// The resource this barrier applies to.
79    pub resource_id: ResourceId,
80    /// The previous access type.
81    pub src_access: AccessType,
82    /// The next access type.
83    pub dst_access: AccessType,
84    /// The pipeline stage of the previous access.
85    pub src_stage: PipelineStage,
86    /// The pipeline stage of the next access.
87    pub dst_stage: PipelineStage,
88}
89
90impl BarrierDesc {
91    /// Create a new barrier description.
92    pub fn new(
93        resource_id: ResourceId,
94        src_access: AccessType,
95        dst_access: AccessType,
96        src_stage: PipelineStage,
97        dst_stage: PipelineStage,
98    ) -> Self {
99        Self {
100            resource_id,
101            src_access,
102            dst_access,
103            src_stage,
104            dst_stage,
105        }
106    }
107
108    /// Check whether this barrier is a read-after-write hazard.
109    pub fn is_raw_hazard(&self) -> bool {
110        matches!(
111            self.src_access,
112            AccessType::ShaderWrite | AccessType::TransferDst | AccessType::HostWrite
113        ) && matches!(
114            self.dst_access,
115            AccessType::ShaderRead | AccessType::TransferSrc | AccessType::HostRead
116        )
117    }
118
119    /// Check whether this barrier is a write-after-write hazard.
120    pub fn is_waw_hazard(&self) -> bool {
121        matches!(
122            self.src_access,
123            AccessType::ShaderWrite | AccessType::TransferDst | AccessType::HostWrite
124        ) && matches!(
125            self.dst_access,
126            AccessType::ShaderWrite | AccessType::TransferDst | AccessType::HostWrite
127        )
128    }
129
130    /// Check whether this barrier is a write-after-read hazard.
131    pub fn is_war_hazard(&self) -> bool {
132        matches!(
133            self.src_access,
134            AccessType::ShaderRead | AccessType::TransferSrc | AccessType::HostRead
135        ) && matches!(
136            self.dst_access,
137            AccessType::ShaderWrite | AccessType::TransferDst | AccessType::HostWrite
138        )
139    }
140}
141
142/// Tracks the current access state of each resource.
143#[derive(Debug, Clone)]
144struct ResourceState {
145    /// Current access type.
146    access: AccessType,
147    /// Current pipeline stage.
148    stage: PipelineStage,
149}
150
151/// Manages barriers for a set of GPU resources.
152///
153/// The barrier manager tracks current access states and automatically
154/// determines which barriers are needed when a resource transitions
155/// to a new access pattern.
156pub struct BarrierManager {
157    /// Current state of each tracked resource.
158    states: HashMap<ResourceId, ResourceState>,
159    /// Accumulated pending barriers to be submitted.
160    pending: Vec<BarrierDesc>,
161    /// Total number of barriers emitted.
162    total_barriers: u64,
163    /// Number of barriers that were optimized away (redundant).
164    optimized_away: u64,
165}
166
167impl BarrierManager {
168    /// Create a new empty barrier manager.
169    pub fn new() -> Self {
170        Self {
171            states: HashMap::new(),
172            pending: Vec::new(),
173            total_barriers: 0,
174            optimized_away: 0,
175        }
176    }
177
178    /// Register a new resource with an initial access type.
179    pub fn register_resource(
180        &mut self,
181        id: ResourceId,
182        initial_access: AccessType,
183        stage: PipelineStage,
184    ) {
185        self.states.insert(
186            id,
187            ResourceState {
188                access: initial_access,
189                stage,
190            },
191        );
192    }
193
194    /// Transition a resource to a new access type, emitting a barrier if needed.
195    ///
196    /// Returns `true` if a barrier was emitted, `false` if the transition
197    /// was redundant (same access/stage).
198    pub fn transition(
199        &mut self,
200        id: ResourceId,
201        new_access: AccessType,
202        new_stage: PipelineStage,
203    ) -> bool {
204        let current = self.states.get(&id).cloned().unwrap_or(ResourceState {
205            access: AccessType::None,
206            stage: PipelineStage::TopOfPipe,
207        });
208
209        // No barrier needed if access type and stage are the same
210        if current.access == new_access && current.stage == new_stage {
211            self.optimized_away += 1;
212            return false;
213        }
214
215        // No barrier needed for read-to-read transitions at the same stage
216        if is_read_only(current.access) && is_read_only(new_access) && current.stage == new_stage {
217            self.optimized_away += 1;
218            // Still update state
219            self.states.insert(
220                id,
221                ResourceState {
222                    access: new_access,
223                    stage: new_stage,
224                },
225            );
226            return false;
227        }
228
229        let barrier = BarrierDesc::new(id, current.access, new_access, current.stage, new_stage);
230        self.pending.push(barrier);
231        self.total_barriers += 1;
232
233        self.states.insert(
234            id,
235            ResourceState {
236                access: new_access,
237                stage: new_stage,
238            },
239        );
240
241        true
242    }
243
244    /// Drain all pending barriers, returning them.
245    pub fn flush(&mut self) -> Vec<BarrierDesc> {
246        std::mem::take(&mut self.pending)
247    }
248
249    /// Get the number of pending barriers.
250    pub fn pending_count(&self) -> usize {
251        self.pending.len()
252    }
253
254    /// Get the total number of barriers emitted since creation.
255    pub fn total_barriers(&self) -> u64 {
256        self.total_barriers
257    }
258
259    /// Get the number of barriers that were optimized away.
260    pub fn optimized_away(&self) -> u64 {
261        self.optimized_away
262    }
263
264    /// Get the current access state of a resource.
265    pub fn current_access(&self, id: ResourceId) -> Option<AccessType> {
266        self.states.get(&id).map(|s| s.access)
267    }
268
269    /// Get the current pipeline stage of a resource.
270    pub fn current_stage(&self, id: ResourceId) -> Option<PipelineStage> {
271        self.states.get(&id).map(|s| s.stage)
272    }
273
274    /// Remove a resource from tracking.
275    pub fn unregister_resource(&mut self, id: ResourceId) -> bool {
276        self.states.remove(&id).is_some()
277    }
278
279    /// Get the number of tracked resources.
280    pub fn resource_count(&self) -> usize {
281        self.states.len()
282    }
283
284    /// Clear all tracked state and pending barriers.
285    pub fn reset(&mut self) {
286        self.states.clear();
287        self.pending.clear();
288    }
289
290    /// Batch-transition multiple resources at once.
291    pub fn batch_transition(
292        &mut self,
293        transitions: &[(ResourceId, AccessType, PipelineStage)],
294    ) -> usize {
295        let mut count = 0;
296        for &(id, access, stage) in transitions {
297            if self.transition(id, access, stage) {
298                count += 1;
299            }
300        }
301        count
302    }
303}
304
305impl Default for BarrierManager {
306    fn default() -> Self {
307        Self::new()
308    }
309}
310
311/// Check if an access type is read-only.
312fn is_read_only(access: AccessType) -> bool {
313    matches!(
314        access,
315        AccessType::ShaderRead | AccessType::TransferSrc | AccessType::HostRead | AccessType::None
316    )
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Resource id used by most tests.
    const ID: ResourceId = ResourceId(1);

    /// Builds a manager that already tracks `ID` in the given state.
    fn manager_with(access: AccessType, stage: PipelineStage) -> BarrierManager {
        let mut mgr = BarrierManager::new();
        mgr.register_resource(ID, access, stage);
        mgr
    }

    /// Builds a barrier description on `ID`.
    fn desc(
        src: AccessType,
        dst: AccessType,
        src_stage: PipelineStage,
        dst_stage: PipelineStage,
    ) -> BarrierDesc {
        BarrierDesc::new(ID, src, dst, src_stage, dst_stage)
    }

    #[test]
    fn test_new_barrier_manager() {
        let mgr = BarrierManager::new();
        assert_eq!(
            (mgr.resource_count(), mgr.pending_count(), mgr.total_barriers()),
            (0, 0, 0)
        );
    }

    #[test]
    fn test_register_resource() {
        let mgr = manager_with(AccessType::None, PipelineStage::TopOfPipe);
        assert_eq!(mgr.resource_count(), 1);
        assert_eq!(mgr.current_access(ID), Some(AccessType::None));
    }

    #[test]
    fn test_transition_emits_barrier() {
        let mut mgr = manager_with(AccessType::ShaderWrite, PipelineStage::Compute);
        assert!(mgr.transition(ID, AccessType::ShaderRead, PipelineStage::Compute));
        assert_eq!(mgr.pending_count(), 1);
    }

    #[test]
    fn test_same_state_no_barrier() {
        let mut mgr = manager_with(AccessType::ShaderRead, PipelineStage::Compute);
        assert!(!mgr.transition(ID, AccessType::ShaderRead, PipelineStage::Compute));
        assert_eq!(mgr.pending_count(), 0);
        assert_eq!(mgr.optimized_away(), 1);
    }

    #[test]
    fn test_read_to_read_same_stage_no_barrier() {
        let mut mgr = manager_with(AccessType::ShaderRead, PipelineStage::Compute);
        assert!(!mgr.transition(ID, AccessType::TransferSrc, PipelineStage::Compute));
    }

    #[test]
    fn test_flush_clears_pending() {
        let mut mgr = manager_with(AccessType::None, PipelineStage::TopOfPipe);
        mgr.transition(ID, AccessType::ShaderWrite, PipelineStage::Compute);
        assert_eq!(mgr.flush().len(), 1);
        assert_eq!(mgr.pending_count(), 0);
    }

    #[test]
    fn test_barrier_desc_raw_hazard() {
        let d = desc(
            AccessType::ShaderWrite,
            AccessType::ShaderRead,
            PipelineStage::Compute,
            PipelineStage::Compute,
        );
        assert!(d.is_raw_hazard());
        assert!(!d.is_waw_hazard());
        assert!(!d.is_war_hazard());
    }

    #[test]
    fn test_barrier_desc_waw_hazard() {
        let d = desc(
            AccessType::ShaderWrite,
            AccessType::TransferDst,
            PipelineStage::Compute,
            PipelineStage::Transfer,
        );
        assert!(d.is_waw_hazard());
    }

    #[test]
    fn test_barrier_desc_war_hazard() {
        let d = desc(
            AccessType::ShaderRead,
            AccessType::ShaderWrite,
            PipelineStage::Compute,
            PipelineStage::Compute,
        );
        assert!(d.is_war_hazard());
    }

    #[test]
    fn test_unregister_resource() {
        let mut mgr = manager_with(AccessType::None, PipelineStage::TopOfPipe);
        assert!(mgr.unregister_resource(ID));
        assert!(!mgr.unregister_resource(ID));
        assert_eq!(mgr.resource_count(), 0);
    }

    #[test]
    fn test_batch_transition() {
        let other = ResourceId(2);
        let mut mgr = manager_with(AccessType::None, PipelineStage::TopOfPipe);
        mgr.register_resource(other, AccessType::None, PipelineStage::TopOfPipe);
        let transitions = [
            (ID, AccessType::ShaderWrite, PipelineStage::Compute),
            (other, AccessType::TransferDst, PipelineStage::Transfer),
        ];
        assert_eq!(mgr.batch_transition(&transitions), 2);
        assert_eq!(mgr.pending_count(), 2);
    }

    #[test]
    fn test_reset() {
        let mut mgr = manager_with(AccessType::None, PipelineStage::TopOfPipe);
        mgr.transition(ID, AccessType::ShaderWrite, PipelineStage::Compute);
        mgr.reset();
        assert_eq!(mgr.resource_count(), 0);
        assert_eq!(mgr.pending_count(), 0);
    }

    #[test]
    fn test_transition_unregistered_resource() {
        let mut mgr = BarrierManager::new();
        assert!(mgr.transition(ResourceId(99), AccessType::ShaderRead, PipelineStage::Compute));
        assert_eq!(mgr.resource_count(), 1);
    }

    #[test]
    fn test_display_access_type() {
        assert_eq!(AccessType::ShaderWrite.to_string(), "ShaderWrite");
        assert_eq!(AccessType::HostRead.to_string(), "HostRead");
    }

    #[test]
    fn test_display_pipeline_stage() {
        assert_eq!(PipelineStage::Compute.to_string(), "Compute");
        assert_eq!(PipelineStage::BottomOfPipe.to_string(), "BottomOfPipe");
    }
}