1#![allow(dead_code)]
2use std::collections::HashMap;
9use std::fmt;
10
/// Kind of access a resource is used for, as tracked by the barrier manager.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AccessType {
    /// No access recorded. `BarrierManager::transition` assumes this for untracked resources.
    None,
    /// Read from a shader.
    ShaderRead,
    /// Written from a shader.
    ShaderWrite,
    /// Read as the source of a transfer operation.
    TransferSrc,
    /// Written as the destination of a transfer operation.
    TransferDst,
    /// Read by the host.
    HostRead,
    /// Written by the host.
    HostWrite,
}

impl fmt::Display for AccessType {
    /// Writes the variant name.
    ///
    /// Uses [`fmt::Formatter::pad`] so width, fill and alignment flags (e.g. `{:>12}`) are
    /// honoured; without flags the output is exactly the variant name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::None => "None",
            Self::ShaderRead => "ShaderRead",
            Self::ShaderWrite => "ShaderWrite",
            Self::TransferSrc => "TransferSrc",
            Self::TransferDst => "TransferDst",
            Self::HostRead => "HostRead",
            Self::HostWrite => "HostWrite",
        };
        f.pad(name)
    }
}
43
/// Pipeline stage at which an access happens.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PipelineStage {
    /// Start of the pipeline. `BarrierManager::transition` assumes this for untracked resources.
    TopOfPipe,
    /// Compute stage.
    Compute,
    /// Transfer stage.
    Transfer,
    /// Host stage.
    Host,
    /// End of the pipeline.
    BottomOfPipe,
}

impl fmt::Display for PipelineStage {
    /// Writes the variant name.
    ///
    /// Uses [`fmt::Formatter::pad`] so width, fill and alignment flags are honoured;
    /// without flags the output is exactly the variant name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::TopOfPipe => "TopOfPipe",
            Self::Compute => "Compute",
            Self::Transfer => "Transfer",
            Self::Host => "Host",
            Self::BottomOfPipe => "BottomOfPipe",
        };
        f.pad(name)
    }
}
70
/// Opaque numeric handle identifying a resource tracked by the barrier manager.
///
/// Used as the key of the manager's state map, hence `Eq + Hash`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ResourceId(pub u64);
74
75#[derive(Debug, Clone, PartialEq, Eq)]
77pub struct BarrierDesc {
78 pub resource_id: ResourceId,
80 pub src_access: AccessType,
82 pub dst_access: AccessType,
84 pub src_stage: PipelineStage,
86 pub dst_stage: PipelineStage,
88}
89
90impl BarrierDesc {
91 pub fn new(
93 resource_id: ResourceId,
94 src_access: AccessType,
95 dst_access: AccessType,
96 src_stage: PipelineStage,
97 dst_stage: PipelineStage,
98 ) -> Self {
99 Self {
100 resource_id,
101 src_access,
102 dst_access,
103 src_stage,
104 dst_stage,
105 }
106 }
107
108 pub fn is_raw_hazard(&self) -> bool {
110 matches!(
111 self.src_access,
112 AccessType::ShaderWrite | AccessType::TransferDst | AccessType::HostWrite
113 ) && matches!(
114 self.dst_access,
115 AccessType::ShaderRead | AccessType::TransferSrc | AccessType::HostRead
116 )
117 }
118
119 pub fn is_waw_hazard(&self) -> bool {
121 matches!(
122 self.src_access,
123 AccessType::ShaderWrite | AccessType::TransferDst | AccessType::HostWrite
124 ) && matches!(
125 self.dst_access,
126 AccessType::ShaderWrite | AccessType::TransferDst | AccessType::HostWrite
127 )
128 }
129
130 pub fn is_war_hazard(&self) -> bool {
132 matches!(
133 self.src_access,
134 AccessType::ShaderRead | AccessType::TransferSrc | AccessType::HostRead
135 ) && matches!(
136 self.dst_access,
137 AccessType::ShaderWrite | AccessType::TransferDst | AccessType::HostWrite
138 )
139 }
140}
141
142#[derive(Debug, Clone)]
144struct ResourceState {
145 access: AccessType,
147 stage: PipelineStage,
149}
150
151pub struct BarrierManager {
157 states: HashMap<ResourceId, ResourceState>,
159 pending: Vec<BarrierDesc>,
161 total_barriers: u64,
163 optimized_away: u64,
165}
166
167impl BarrierManager {
168 pub fn new() -> Self {
170 Self {
171 states: HashMap::new(),
172 pending: Vec::new(),
173 total_barriers: 0,
174 optimized_away: 0,
175 }
176 }
177
178 pub fn register_resource(
180 &mut self,
181 id: ResourceId,
182 initial_access: AccessType,
183 stage: PipelineStage,
184 ) {
185 self.states.insert(
186 id,
187 ResourceState {
188 access: initial_access,
189 stage,
190 },
191 );
192 }
193
194 pub fn transition(
199 &mut self,
200 id: ResourceId,
201 new_access: AccessType,
202 new_stage: PipelineStage,
203 ) -> bool {
204 let current = self.states.get(&id).cloned().unwrap_or(ResourceState {
205 access: AccessType::None,
206 stage: PipelineStage::TopOfPipe,
207 });
208
209 if current.access == new_access && current.stage == new_stage {
211 self.optimized_away += 1;
212 return false;
213 }
214
215 if is_read_only(current.access) && is_read_only(new_access) && current.stage == new_stage {
217 self.optimized_away += 1;
218 self.states.insert(
220 id,
221 ResourceState {
222 access: new_access,
223 stage: new_stage,
224 },
225 );
226 return false;
227 }
228
229 let barrier = BarrierDesc::new(id, current.access, new_access, current.stage, new_stage);
230 self.pending.push(barrier);
231 self.total_barriers += 1;
232
233 self.states.insert(
234 id,
235 ResourceState {
236 access: new_access,
237 stage: new_stage,
238 },
239 );
240
241 true
242 }
243
244 pub fn flush(&mut self) -> Vec<BarrierDesc> {
246 std::mem::take(&mut self.pending)
247 }
248
249 pub fn pending_count(&self) -> usize {
251 self.pending.len()
252 }
253
254 pub fn total_barriers(&self) -> u64 {
256 self.total_barriers
257 }
258
259 pub fn optimized_away(&self) -> u64 {
261 self.optimized_away
262 }
263
264 pub fn current_access(&self, id: ResourceId) -> Option<AccessType> {
266 self.states.get(&id).map(|s| s.access)
267 }
268
269 pub fn current_stage(&self, id: ResourceId) -> Option<PipelineStage> {
271 self.states.get(&id).map(|s| s.stage)
272 }
273
274 pub fn unregister_resource(&mut self, id: ResourceId) -> bool {
276 self.states.remove(&id).is_some()
277 }
278
279 pub fn resource_count(&self) -> usize {
281 self.states.len()
282 }
283
284 pub fn reset(&mut self) {
286 self.states.clear();
287 self.pending.clear();
288 }
289
290 pub fn batch_transition(
292 &mut self,
293 transitions: &[(ResourceId, AccessType, PipelineStage)],
294 ) -> usize {
295 let mut count = 0;
296 for &(id, access, stage) in transitions {
297 if self.transition(id, access, stage) {
298 count += 1;
299 }
300 }
301 count
302 }
303}
304
305impl Default for BarrierManager {
306 fn default() -> Self {
307 Self::new()
308 }
309}
310
311fn is_read_only(access: AccessType) -> bool {
313 matches!(
314 access,
315 AccessType::ShaderRead | AccessType::TransferSrc | AccessType::HostRead | AccessType::None
316 )
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_barrier_manager() {
        let mgr = BarrierManager::new();
        assert_eq!(mgr.resource_count(), 0);
        assert_eq!(mgr.pending_count(), 0);
        assert_eq!(mgr.total_barriers(), 0);
    }

    #[test]
    fn test_register_resource() {
        let mut mgr = BarrierManager::new();
        mgr.register_resource(ResourceId(1), AccessType::None, PipelineStage::TopOfPipe);
        assert_eq!(mgr.resource_count(), 1);
        assert_eq!(mgr.current_access(ResourceId(1)), Some(AccessType::None));
    }

    #[test]
    fn test_current_state_queries() {
        let mut mgr = BarrierManager::new();
        mgr.register_resource(ResourceId(1), AccessType::HostWrite, PipelineStage::Host);
        assert_eq!(mgr.current_access(ResourceId(1)), Some(AccessType::HostWrite));
        assert_eq!(mgr.current_stage(ResourceId(1)), Some(PipelineStage::Host));
        assert_eq!(mgr.current_access(ResourceId(2)), None);
        assert_eq!(mgr.current_stage(ResourceId(2)), None);
    }

    #[test]
    fn test_transition_emits_barrier() {
        let mut mgr = BarrierManager::new();
        mgr.register_resource(ResourceId(1), AccessType::ShaderWrite, PipelineStage::Compute);
        let emitted = mgr.transition(ResourceId(1), AccessType::ShaderRead, PipelineStage::Compute);
        assert!(emitted);
        assert_eq!(mgr.pending_count(), 1);
    }

    #[test]
    fn test_read_to_write_same_stage_emits_war_barrier() {
        let mut mgr = BarrierManager::new();
        mgr.register_resource(ResourceId(1), AccessType::ShaderRead, PipelineStage::Compute);
        assert!(mgr.transition(ResourceId(1), AccessType::ShaderWrite, PipelineStage::Compute));
        let barriers = mgr.flush();
        assert_eq!(barriers.len(), 1);
        assert!(barriers[0].is_war_hazard());
    }

    #[test]
    fn test_same_state_no_barrier() {
        let mut mgr = BarrierManager::new();
        mgr.register_resource(ResourceId(1), AccessType::ShaderRead, PipelineStage::Compute);
        let emitted = mgr.transition(ResourceId(1), AccessType::ShaderRead, PipelineStage::Compute);
        assert!(!emitted);
        assert_eq!(mgr.pending_count(), 0);
        assert_eq!(mgr.optimized_away(), 1);
    }

    #[test]
    fn test_read_to_read_same_stage_no_barrier() {
        let mut mgr = BarrierManager::new();
        mgr.register_resource(ResourceId(1), AccessType::ShaderRead, PipelineStage::Compute);
        let emitted = mgr.transition(ResourceId(1), AccessType::TransferSrc, PipelineStage::Compute);
        assert!(!emitted);
    }

    #[test]
    fn test_elided_read_transition_updates_state() {
        let mut mgr = BarrierManager::new();
        mgr.register_resource(ResourceId(1), AccessType::ShaderRead, PipelineStage::Compute);
        assert!(!mgr.transition(ResourceId(1), AccessType::TransferSrc, PipelineStage::Compute));
        assert_eq!(mgr.current_access(ResourceId(1)), Some(AccessType::TransferSrc));
        assert_eq!(mgr.optimized_away(), 1);
        assert_eq!(mgr.total_barriers(), 0);
    }

    #[test]
    fn test_read_to_read_different_stage_emits_barrier() {
        let mut mgr = BarrierManager::new();
        mgr.register_resource(ResourceId(1), AccessType::ShaderRead, PipelineStage::Compute);
        assert!(mgr.transition(ResourceId(1), AccessType::TransferSrc, PipelineStage::Transfer));
        assert_eq!(mgr.total_barriers(), 1);
    }

    #[test]
    fn test_flush_clears_pending() {
        let mut mgr = BarrierManager::new();
        mgr.register_resource(ResourceId(1), AccessType::None, PipelineStage::TopOfPipe);
        mgr.transition(ResourceId(1), AccessType::ShaderWrite, PipelineStage::Compute);
        let barriers = mgr.flush();
        assert_eq!(barriers.len(), 1);
        assert_eq!(mgr.pending_count(), 0);
    }

    #[test]
    fn test_flush_returns_recorded_barrier() {
        let mut mgr = BarrierManager::new();
        mgr.register_resource(ResourceId(1), AccessType::None, PipelineStage::TopOfPipe);
        mgr.transition(ResourceId(1), AccessType::ShaderWrite, PipelineStage::Compute);
        assert_eq!(
            mgr.flush(),
            vec![BarrierDesc::new(
                ResourceId(1),
                AccessType::None,
                AccessType::ShaderWrite,
                PipelineStage::TopOfPipe,
                PipelineStage::Compute,
            )]
        );
        assert!(mgr.flush().is_empty());
    }

    #[test]
    fn test_barrier_desc_raw_hazard() {
        let desc = BarrierDesc::new(
            ResourceId(1),
            AccessType::ShaderWrite,
            AccessType::ShaderRead,
            PipelineStage::Compute,
            PipelineStage::Compute,
        );
        assert!(desc.is_raw_hazard());
        assert!(!desc.is_waw_hazard());
        assert!(!desc.is_war_hazard());
    }

    #[test]
    fn test_barrier_desc_waw_hazard() {
        let desc = BarrierDesc::new(
            ResourceId(1),
            AccessType::ShaderWrite,
            AccessType::TransferDst,
            PipelineStage::Compute,
            PipelineStage::Transfer,
        );
        assert!(desc.is_waw_hazard());
    }

    #[test]
    fn test_barrier_desc_war_hazard() {
        let desc = BarrierDesc::new(
            ResourceId(1),
            AccessType::ShaderRead,
            AccessType::ShaderWrite,
            PipelineStage::Compute,
            PipelineStage::Compute,
        );
        assert!(desc.is_war_hazard());
    }

    #[test]
    fn test_barrier_desc_no_hazard_from_none() {
        let desc = BarrierDesc::new(
            ResourceId(1),
            AccessType::None,
            AccessType::ShaderRead,
            PipelineStage::TopOfPipe,
            PipelineStage::Compute,
        );
        assert!(!desc.is_raw_hazard());
        assert!(!desc.is_waw_hazard());
        assert!(!desc.is_war_hazard());
    }

    #[test]
    fn test_unregister_resource() {
        let mut mgr = BarrierManager::new();
        mgr.register_resource(ResourceId(1), AccessType::None, PipelineStage::TopOfPipe);
        assert!(mgr.unregister_resource(ResourceId(1)));
        assert!(!mgr.unregister_resource(ResourceId(1)));
        assert_eq!(mgr.resource_count(), 0);
    }

    #[test]
    fn test_batch_transition() {
        let mut mgr = BarrierManager::new();
        mgr.register_resource(ResourceId(1), AccessType::None, PipelineStage::TopOfPipe);
        mgr.register_resource(ResourceId(2), AccessType::None, PipelineStage::TopOfPipe);
        let count = mgr.batch_transition(&[
            (ResourceId(1), AccessType::ShaderWrite, PipelineStage::Compute),
            (ResourceId(2), AccessType::TransferDst, PipelineStage::Transfer),
        ]);
        assert_eq!(count, 2);
        assert_eq!(mgr.pending_count(), 2);
    }

    #[test]
    fn test_batch_transition_counts_only_emitted() {
        let mut mgr = BarrierManager::new();
        mgr.register_resource(ResourceId(1), AccessType::ShaderRead, PipelineStage::Compute);
        let count = mgr.batch_transition(&[
            (ResourceId(1), AccessType::ShaderRead, PipelineStage::Compute),
            (ResourceId(1), AccessType::ShaderWrite, PipelineStage::Compute),
        ]);
        assert_eq!(count, 1);
        assert_eq!(mgr.optimized_away(), 1);
        assert_eq!(mgr.total_barriers(), 1);
    }

    #[test]
    fn test_reset() {
        let mut mgr = BarrierManager::new();
        mgr.register_resource(ResourceId(1), AccessType::None, PipelineStage::TopOfPipe);
        mgr.transition(ResourceId(1), AccessType::ShaderWrite, PipelineStage::Compute);
        mgr.reset();
        assert_eq!(mgr.resource_count(), 0);
        assert_eq!(mgr.pending_count(), 0);
    }

    #[test]
    fn test_transition_unregistered_resource() {
        let mut mgr = BarrierManager::new();
        let emitted = mgr.transition(ResourceId(99), AccessType::ShaderRead, PipelineStage::Compute);
        assert!(emitted);
        assert_eq!(mgr.resource_count(), 1);
    }

    #[test]
    fn test_display_access_type() {
        assert_eq!(format!("{}", AccessType::ShaderWrite), "ShaderWrite");
        assert_eq!(format!("{}", AccessType::HostRead), "HostRead");
    }

    #[test]
    fn test_display_pipeline_stage() {
        assert_eq!(format!("{}", PipelineStage::Compute), "Compute");
        assert_eq!(format!("{}", PipelineStage::BottomOfPipe), "BottomOfPipe");
    }
}