#![allow(dead_code)]
use std::collections::HashMap;
use std::fmt;
/// Kind of access a resource is (or will be) used for.
///
/// The manager classifies `None`, `ShaderRead`, `TransferSrc` and `HostRead`
/// as read-only; `ShaderWrite`, `TransferDst` and `HostWrite` are treated as
/// writes by the hazard checks on `BarrierDesc`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AccessType {
    /// No access. Also the implied state of a resource the manager does not track.
    None,
    /// Read from a shader.
    ShaderRead,
    /// Written from a shader.
    ShaderWrite,
    /// Source of a transfer.
    TransferSrc,
    /// Destination of a transfer.
    TransferDst,
    /// Read by the host.
    HostRead,
    /// Written by the host.
    HostWrite,
}

impl fmt::Display for AccessType {
    /// Writes the variant name verbatim. Width/alignment flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            AccessType::None => "None",
            AccessType::ShaderRead => "ShaderRead",
            AccessType::ShaderWrite => "ShaderWrite",
            AccessType::TransferSrc => "TransferSrc",
            AccessType::TransferDst => "TransferDst",
            AccessType::HostRead => "HostRead",
            AccessType::HostWrite => "HostWrite",
        };
        f.write_str(label)
    }
}
/// Pipeline stage at which an access happens.
///
/// The manager compares stages for equality when deciding whether a
/// read-to-read transition can be elided. `TopOfPipe` is the implied stage of
/// a resource the manager does not track.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PipelineStage {
    TopOfPipe,
    Compute,
    Transfer,
    Host,
    BottomOfPipe,
}

impl fmt::Display for PipelineStage {
    /// Writes the variant name verbatim. Width/alignment flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            PipelineStage::TopOfPipe => "TopOfPipe",
            PipelineStage::Compute => "Compute",
            PipelineStage::Transfer => "Transfer",
            PipelineStage::Host => "Host",
            PipelineStage::BottomOfPipe => "BottomOfPipe",
        };
        f.write_str(label)
    }
}
/// Numeric handle used as the key for per-resource state in `BarrierManager`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ResourceId(pub u64);
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BarrierDesc {
pub resource_id: ResourceId,
pub src_access: AccessType,
pub dst_access: AccessType,
pub src_stage: PipelineStage,
pub dst_stage: PipelineStage,
}
impl BarrierDesc {
pub fn new(
resource_id: ResourceId,
src_access: AccessType,
dst_access: AccessType,
src_stage: PipelineStage,
dst_stage: PipelineStage,
) -> Self {
Self {
resource_id,
src_access,
dst_access,
src_stage,
dst_stage,
}
}
pub fn is_raw_hazard(&self) -> bool {
matches!(self.src_access, AccessType::ShaderWrite | AccessType::TransferDst | AccessType::HostWrite)
&& matches!(self.dst_access, AccessType::ShaderRead | AccessType::TransferSrc | AccessType::HostRead)
}
pub fn is_waw_hazard(&self) -> bool {
matches!(self.src_access, AccessType::ShaderWrite | AccessType::TransferDst | AccessType::HostWrite)
&& matches!(self.dst_access, AccessType::ShaderWrite | AccessType::TransferDst | AccessType::HostWrite)
}
pub fn is_war_hazard(&self) -> bool {
matches!(self.src_access, AccessType::ShaderRead | AccessType::TransferSrc | AccessType::HostRead)
&& matches!(self.dst_access, AccessType::ShaderWrite | AccessType::TransferDst | AccessType::HostWrite)
}
}
/// Last access/stage recorded for a tracked resource.
#[derive(Debug, Clone)]
struct ResourceState {
    /// Most recent access the resource was transitioned to or registered with.
    access: AccessType,
    /// Stage associated with `access`.
    stage: PipelineStage,
}
/// Tracks per-resource access state and queues the barriers needed to move
/// resources between states.
pub struct BarrierManager {
    /// Current state of every tracked resource.
    states: HashMap<ResourceId, ResourceState>,
    /// Barriers recorded since the last `flush`, in recording order.
    pending: Vec<BarrierDesc>,
    /// Count of barriers ever recorded.
    total_barriers: u64,
    /// Count of transitions that were elided instead of producing a barrier.
    optimized_away: u64,
}
impl BarrierManager {
    /// Creates a manager with no tracked resources, no pending barriers and
    /// zeroed counters.
    pub fn new() -> Self {
        Self {
            states: HashMap::new(),
            pending: Vec::new(),
            total_barriers: 0,
            optimized_away: 0,
        }
    }

    /// Starts tracking `id` in the given access/stage state.
    ///
    /// If `id` is already tracked its state is overwritten; no barrier is
    /// recorded either way.
    pub fn register_resource(&mut self, id: ResourceId, initial_access: AccessType, stage: PipelineStage) {
        self.set_state(id, initial_access, stage);
    }

    /// Moves `id` to `new_access` at `new_stage`, queueing a barrier if needed.
    ///
    /// Returns `true` if a barrier was queued (retrieve it with [`Self::flush`])
    /// and `false` if the transition was elided. A resource that is not
    /// tracked is treated as being in `AccessType::None` at
    /// `PipelineStage::TopOfPipe`, and becomes tracked once it is transitioned
    /// to a different state.
    ///
    /// A barrier is elided only when the current and requested accesses are
    /// both read-only and in the same stage. Any transition involving a write
    /// records a barrier. That includes repeating the identical write state,
    /// because back-to-back writes form a write-after-write hazard (see
    /// [`BarrierDesc::is_waw_hazard`]).
    pub fn transition(&mut self, id: ResourceId, new_access: AccessType, new_stage: PipelineStage) -> bool {
        let (old_access, old_stage) = self
            .states
            .get(&id)
            .map_or((AccessType::None, PipelineStage::TopOfPipe), |s| (s.access, s.stage));

        if old_stage == new_stage && is_read_only(old_access) && is_read_only(new_access) {
            self.optimized_away += 1;
            // An identical read-only request is a pure no-op: leave the map
            // untouched so it does not implicitly start tracking `id`.
            if old_access != new_access {
                self.set_state(id, new_access, new_stage);
            }
            return false;
        }

        self.pending
            .push(BarrierDesc::new(id, old_access, new_access, old_stage, new_stage));
        self.total_barriers += 1;
        self.set_state(id, new_access, new_stage);
        true
    }

    /// Removes and returns every queued barrier, in the order it was recorded.
    pub fn flush(&mut self) -> Vec<BarrierDesc> {
        std::mem::take(&mut self.pending)
    }

    /// Number of barriers queued since the last [`Self::flush`].
    pub fn pending_count(&self) -> usize {
        self.pending.len()
    }

    /// Number of barriers recorded since construction. Not cleared by
    /// [`Self::reset`] or [`Self::flush`].
    pub fn total_barriers(&self) -> u64 {
        self.total_barriers
    }

    /// Number of transitions elided without a barrier since construction.
    /// Not cleared by [`Self::reset`].
    pub fn optimized_away(&self) -> u64 {
        self.optimized_away
    }

    /// Current access of `id`, or `None` if it is not tracked.
    pub fn current_access(&self, id: ResourceId) -> Option<AccessType> {
        self.states.get(&id).map(|s| s.access)
    }

    /// Current stage of `id`, or `None` if it is not tracked.
    pub fn current_stage(&self, id: ResourceId) -> Option<PipelineStage> {
        self.states.get(&id).map(|s| s.stage)
    }

    /// Stops tracking `id`. Returns `true` if it was tracked.
    pub fn unregister_resource(&mut self, id: ResourceId) -> bool {
        self.states.remove(&id).is_some()
    }

    /// Number of tracked resources.
    pub fn resource_count(&self) -> usize {
        self.states.len()
    }

    /// Forgets all tracked resources and discards pending barriers.
    ///
    /// The `total_barriers` and `optimized_away` counters are left untouched.
    pub fn reset(&mut self) {
        self.states.clear();
        self.pending.clear();
    }

    /// Applies [`Self::transition`] to each entry in order and returns how
    /// many of them queued a barrier.
    pub fn batch_transition(&mut self, transitions: &[(ResourceId, AccessType, PipelineStage)]) -> usize {
        let mut emitted = 0;
        for &(id, access, stage) in transitions {
            if self.transition(id, access, stage) {
                emitted += 1;
            }
        }
        emitted
    }

    /// Records `access`/`stage` as the current state of `id`.
    fn set_state(&mut self, id: ResourceId, access: AccessType, stage: PipelineStage) {
        self.states.insert(id, ResourceState { access, stage });
    }
}
impl Default for BarrierManager {
fn default() -> Self {
Self::new()
}
}
/// Whether `access` leaves the resource contents unmodified.
///
/// `AccessType::None` counts as read-only. The match is exhaustive, so adding
/// an `AccessType` variant forces a decision here.
fn is_read_only(access: AccessType) -> bool {
    match access {
        AccessType::None | AccessType::ShaderRead | AccessType::TransferSrc | AccessType::HostRead => true,
        AccessType::ShaderWrite | AccessType::TransferDst | AccessType::HostWrite => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager that already tracks `ResourceId(id)` in the given state.
    fn tracked(id: u64, access: AccessType, stage: PipelineStage) -> BarrierManager {
        let mut mgr = BarrierManager::new();
        mgr.register_resource(ResourceId(id), access, stage);
        mgr
    }

    /// Builds a descriptor for resource 1 with the given source/destination.
    fn desc(
        src: AccessType,
        dst: AccessType,
        src_stage: PipelineStage,
        dst_stage: PipelineStage,
    ) -> BarrierDesc {
        BarrierDesc::new(ResourceId(1), src, dst, src_stage, dst_stage)
    }

    #[test]
    fn test_new_barrier_manager() {
        let fresh = BarrierManager::new();
        assert_eq!(fresh.resource_count(), 0);
        assert_eq!(fresh.pending_count(), 0);
        assert_eq!(fresh.total_barriers(), 0);
    }

    #[test]
    fn test_register_resource() {
        let mgr = tracked(1, AccessType::None, PipelineStage::TopOfPipe);
        assert_eq!(mgr.resource_count(), 1);
        assert_eq!(mgr.current_access(ResourceId(1)), Some(AccessType::None));
    }

    #[test]
    fn test_transition_emits_barrier() {
        let mut mgr = tracked(1, AccessType::ShaderWrite, PipelineStage::Compute);
        assert!(mgr.transition(ResourceId(1), AccessType::ShaderRead, PipelineStage::Compute));
        assert_eq!(mgr.pending_count(), 1);
    }

    #[test]
    fn test_same_state_no_barrier() {
        let mut mgr = tracked(1, AccessType::ShaderRead, PipelineStage::Compute);
        assert!(!mgr.transition(ResourceId(1), AccessType::ShaderRead, PipelineStage::Compute));
        assert_eq!(mgr.pending_count(), 0);
        assert_eq!(mgr.optimized_away(), 1);
    }

    #[test]
    fn test_read_to_read_same_stage_no_barrier() {
        let mut mgr = tracked(1, AccessType::ShaderRead, PipelineStage::Compute);
        assert!(!mgr.transition(ResourceId(1), AccessType::TransferSrc, PipelineStage::Compute));
    }

    #[test]
    fn test_flush_clears_pending() {
        let mut mgr = tracked(1, AccessType::None, PipelineStage::TopOfPipe);
        mgr.transition(ResourceId(1), AccessType::ShaderWrite, PipelineStage::Compute);
        assert_eq!(mgr.flush().len(), 1);
        assert_eq!(mgr.pending_count(), 0);
    }

    #[test]
    fn test_barrier_desc_raw_hazard() {
        let d = desc(
            AccessType::ShaderWrite,
            AccessType::ShaderRead,
            PipelineStage::Compute,
            PipelineStage::Compute,
        );
        assert!(d.is_raw_hazard());
        assert!(!d.is_waw_hazard());
        assert!(!d.is_war_hazard());
    }

    #[test]
    fn test_barrier_desc_waw_hazard() {
        let d = desc(
            AccessType::ShaderWrite,
            AccessType::TransferDst,
            PipelineStage::Compute,
            PipelineStage::Transfer,
        );
        assert!(d.is_waw_hazard());
    }

    #[test]
    fn test_barrier_desc_war_hazard() {
        let d = desc(
            AccessType::ShaderRead,
            AccessType::ShaderWrite,
            PipelineStage::Compute,
            PipelineStage::Compute,
        );
        assert!(d.is_war_hazard());
    }

    #[test]
    fn test_unregister_resource() {
        let mut mgr = tracked(1, AccessType::None, PipelineStage::TopOfPipe);
        assert!(mgr.unregister_resource(ResourceId(1)));
        assert!(!mgr.unregister_resource(ResourceId(1)));
        assert_eq!(mgr.resource_count(), 0);
    }

    #[test]
    fn test_batch_transition() {
        let mut mgr = BarrierManager::new();
        for id in [1, 2] {
            mgr.register_resource(ResourceId(id), AccessType::None, PipelineStage::TopOfPipe);
        }
        let emitted = mgr.batch_transition(&[
            (ResourceId(1), AccessType::ShaderWrite, PipelineStage::Compute),
            (ResourceId(2), AccessType::TransferDst, PipelineStage::Transfer),
        ]);
        assert_eq!(emitted, 2);
        assert_eq!(mgr.pending_count(), 2);
    }

    #[test]
    fn test_reset() {
        let mut mgr = tracked(1, AccessType::None, PipelineStage::TopOfPipe);
        mgr.transition(ResourceId(1), AccessType::ShaderWrite, PipelineStage::Compute);
        mgr.reset();
        assert_eq!(mgr.resource_count(), 0);
        assert_eq!(mgr.pending_count(), 0);
    }

    #[test]
    fn test_transition_unregistered_resource() {
        let mut mgr = BarrierManager::new();
        assert!(mgr.transition(ResourceId(99), AccessType::ShaderRead, PipelineStage::Compute));
        assert_eq!(mgr.resource_count(), 1);
    }

    #[test]
    fn test_display_access_type() {
        assert_eq!(AccessType::ShaderWrite.to_string(), "ShaderWrite");
        assert_eq!(AccessType::HostRead.to_string(), "HostRead");
    }

    #[test]
    fn test_display_pipeline_stage() {
        assert_eq!(PipelineStage::Compute.to_string(), "Compute");
        assert_eq!(PipelineStage::BottomOfPipe.to_string(), "BottomOfPipe");
    }
}