use crate::{
hub,
id::{self, TypedId},
Epoch, LifeGuard, RefCount,
};
use bit_vec::BitVec;
use std::{borrow::Cow, marker::PhantomData, mem};
use wgt::strict_assert;
/// Per-index bookkeeping for a set of tracked resources.
///
/// The three arrays are parallel: each is indexed by the resource index and
/// `set_size` resizes all of them to the same length. A slot is occupied when
/// its bit in `owned` is set; an occupied slot is expected to hold a `Some`
/// ref count (checked by `tracker_assert_in_bounds` under strict asserts).
#[derive(Debug)]
pub(super) struct ResourceMetadata<A: hub::HalApi> {
/// One bit per index; set while the slot is occupied.
owned: BitVec<usize>,
/// Ref count held for each occupied slot; `None` for unoccupied slots.
ref_counts: Vec<Option<RefCount>>,
/// Epoch recorded per slot; `u32::MAX` for slots created by `set_size` that
/// were never filled, and for slots cleared by `remove`.
epochs: Vec<Epoch>,
/// Ties the metadata to backend `A` without storing a value of that type.
_phantom: PhantomData<A>,
}
impl<A: hub::HalApi> ResourceMetadata<A> {
    /// Creates metadata with zero slots.
    pub(super) fn new() -> Self {
        Self {
            owned: BitVec::default(),
            ref_counts: Vec::default(),
            epochs: Vec::default(),
            _phantom: PhantomData,
        }
    }

    /// Number of slots currently allocated, occupied or not.
    pub(super) fn size(&self) -> usize {
        self.owned.len()
    }

    /// Resizes every per-index array to exactly `size` slots.
    ///
    /// Newly created slots are unoccupied, carry no ref count and have an
    /// epoch of `u32::MAX`; shrinking discards the slots at and past `size`.
    pub(super) fn set_size(&mut self, size: usize) {
        self.ref_counts.resize(size, None);
        self.epochs.resize(size, u32::MAX);
        resize_bitvec(&mut self.owned, size);
    }

    /// Sanity-checks that `index` is within all three arrays and that an
    /// occupied slot holds a ref count.
    ///
    /// Every check is a `strict_assert!`, so they are only active with the
    /// `strict_asserts` feature (hence the `unused_variables` allowance).
    #[cfg_attr(not(feature = "strict_asserts"), allow(unused_variables))]
    pub(super) fn tracker_assert_in_bounds(&self, index: usize) {
        strict_assert!(index < self.owned.len());
        strict_assert!(index < self.ref_counts.len());
        strict_assert!(index < self.epochs.len());
        // Either the slot is free, or it must carry a ref count.
        strict_assert!(!self.contains(index) || self.ref_counts[index].is_some());
    }

    /// Returns `true` when no slot is occupied.
    pub(super) fn is_empty(&self) -> bool {
        !self.owned.any()
    }

    /// Returns whether slot `index` is occupied (bounds-checked counterpart
    /// of `contains_unchecked`).
    pub(super) fn contains(&self, index: usize) -> bool {
        self.owned[index]
    }

    /// Returns whether slot `index` is occupied, skipping the bounds check.
    ///
    /// # Safety
    ///
    /// `index` must be less than `self.size()`.
    #[inline(always)]
    pub(super) unsafe fn contains_unchecked(&self, index: usize) -> bool {
        let bit = self.owned.get(index);
        // SAFETY: the caller guarantees `index` is in bounds, so `get` returned `Some`.
        unsafe { bit.unwrap_unchecked() }
    }

    /// Marks slot `index` occupied and stores its epoch and ref count,
    /// replacing (and dropping) whatever the slot held before.
    ///
    /// # Safety
    ///
    /// `index` must be less than `self.size()`.
    #[inline(always)]
    pub(super) unsafe fn insert(&mut self, index: usize, epoch: Epoch, ref_count: RefCount) {
        self.owned.set(index, true);
        // SAFETY: the caller guarantees `index < size()`, and `set_size` keeps
        // all three arrays the same length.
        let epoch_slot = unsafe { self.epochs.get_unchecked_mut(index) };
        *epoch_slot = epoch;
        let ref_count_slot = unsafe { self.ref_counts.get_unchecked_mut(index) };
        *ref_count_slot = Some(ref_count);
    }

    /// Borrows the ref count stored in slot `index`.
    ///
    /// # Safety
    ///
    /// `index` must be less than `self.size()` and the slot must hold `Some`
    /// ref count (true for every occupied slot).
    #[inline(always)]
    pub(super) unsafe fn get_ref_count_unchecked(&self, index: usize) -> &RefCount {
        // SAFETY: both requirements are upheld by the caller.
        let slot = unsafe { self.ref_counts.get_unchecked(index) };
        unsafe { slot.as_ref().unwrap_unchecked() }
    }

    /// Reads the epoch stored in slot `index`.
    ///
    /// # Safety
    ///
    /// `index` must be less than `self.size()`.
    #[inline(always)]
    pub(super) unsafe fn get_epoch_unchecked(&self, index: usize) -> Epoch {
        // SAFETY: the caller guarantees `index` is in bounds.
        let epoch = unsafe { self.epochs.get_unchecked(index) };
        *epoch
    }

    /// Iterates over the ids of all occupied slots in increasing index order,
    /// rebuilding each id from its index, stored epoch and backend `A`.
    pub(super) fn owned_ids<Id: TypedId>(&self) -> impl Iterator<Item = id::Valid<Id>> + '_ {
        self.owned_indices().map(move |index| {
            // SAFETY: `owned_indices` only yields indices below `owned.len()`,
            // which equals `epochs.len()`.
            let epoch = unsafe { self.get_epoch_unchecked(index) };
            id::Valid(Id::zip(index as u32, epoch, A::VARIANT))
        })
    }

    /// Iterates over the indices of all occupied slots in increasing order.
    pub(super) fn owned_indices(&self) -> impl Iterator<Item = usize> + '_ {
        // Checking the last slot validates that the arrays cover the full range.
        if let Some(last) = self.owned.len().checked_sub(1) {
            self.tracker_assert_in_bounds(last);
        }
        iterate_bitvec_indices(&self.owned)
    }

    /// Clears slot `index`: drops its ref count, resets its epoch to
    /// `u32::MAX` and marks it unoccupied.
    ///
    /// # Safety
    ///
    /// `index` must be less than `self.size()`.
    pub(super) unsafe fn remove(&mut self, index: usize) {
        // SAFETY: the caller guarantees `index` is in bounds for every array.
        let ref_count_slot = unsafe { self.ref_counts.get_unchecked_mut(index) };
        *ref_count_slot = None;
        let epoch_slot = unsafe { self.epochs.get_unchecked_mut(index) };
        *epoch_slot = u32::MAX;
        self.owned.set(index, false);
    }
}
/// Supplies the epoch and ref count for a tracker slot, either directly or by
/// reading them from an existing `ResourceMetadata`.
pub(super) enum ResourceMetadataProvider<'a, A: hub::HalApi> {
/// Epoch and ref count are given directly; `get_own` moves the ref count out
/// of the `Cow`, cloning it only when it is borrowed.
Direct {
epoch: Epoch,
ref_count: Cow<'a, RefCount>,
},
/// Epoch and ref count are read (and the ref count cloned) from `metadata`
/// at the index passed to `get_own` / `get_epoch`.
Indirect { metadata: &'a ResourceMetadata<A> },
/// Epoch is given directly; `get_own` obtains the ref count by calling
/// `add_ref` on the `LifeGuard` it is passed.
Resource { epoch: Epoch },
}
impl<A: hub::HalApi> ResourceMetadataProvider<'_, A> {
    /// Produces the epoch and an owned ref count for slot `index`.
    ///
    /// - `Direct`: the stored epoch, and the ref count taken out of the `Cow`.
    /// - `Indirect`: epoch and a cloned ref count read from `metadata[index]`.
    /// - `Resource`: the stored epoch, and a new ref from `life_guard.add_ref()`.
    ///
    /// `index` is only consulted for `Indirect`; `life_guard` only for `Resource`.
    ///
    /// # Safety
    ///
    /// For `Indirect`, `index` must be in bounds of `metadata` and that slot
    /// must hold a ref count. For `Resource`, `life_guard` must be `Some`.
    #[inline(always)]
    pub(super) unsafe fn get_own(
        self,
        life_guard: Option<&LifeGuard>,
        index: usize,
    ) -> (Epoch, RefCount) {
        match self {
            Self::Direct { epoch, ref_count } => (epoch, ref_count.into_owned()),
            Self::Indirect { metadata } => {
                metadata.tracker_assert_in_bounds(index);
                // SAFETY: the caller guarantees `index` is in bounds and occupied.
                let epoch = unsafe { metadata.get_epoch_unchecked(index) };
                let slot = unsafe { metadata.ref_counts.get_unchecked(index) };
                let ref_count = unsafe { slot.clone().unwrap_unchecked() };
                (epoch, ref_count)
            }
            Self::Resource { epoch } => {
                strict_assert!(life_guard.is_some());
                // SAFETY: the caller guarantees a life guard accompanies `Resource`.
                let guard = unsafe { life_guard.unwrap_unchecked() };
                (epoch, guard.add_ref())
            }
        }
    }

    /// Returns the epoch for slot `index` without producing a ref count.
    ///
    /// # Safety
    ///
    /// For `Indirect`, `index` must be in bounds of `metadata`.
    #[inline(always)]
    pub(super) unsafe fn get_epoch(self, index: usize) -> Epoch {
        match self {
            Self::Indirect { metadata } => {
                metadata.tracker_assert_in_bounds(index);
                // SAFETY: the caller guarantees `index` is in bounds.
                unsafe { metadata.get_epoch_unchecked(index) }
            }
            Self::Direct { epoch, .. } | Self::Resource { epoch } => epoch,
        }
    }
}
/// Makes `vec` exactly `size` bits long: new bits are appended cleared when
/// growing, trailing bits are dropped when shrinking, and an equal size is a
/// no-op.
fn resize_bitvec<B: bit_vec::BitBlock>(vec: &mut BitVec<B>, size: usize) {
    let current = vec.len();
    if size > current {
        vec.grow(size - current, false);
    } else if size < current {
        vec.truncate(size);
    }
}
/// Yields the index of every set bit in `ownership`, in increasing order.
///
/// Works one storage block at a time: set bits are peeled off lowest-first
/// (`trailing_zeros`, then clearing that bit with `w & (w - 1)`), and any bit
/// at or beyond `ownership.len()` is ignored.
fn iterate_bitvec_indices(ownership: &BitVec<usize>) -> impl Iterator<Item = usize> + '_ {
    const BITS_PER_BLOCK: usize = mem::size_of::<usize>() * 8;
    let bit_count = ownership.len();
    ownership
        .blocks()
        .enumerate()
        .flat_map(move |(block_index, block)| {
            let base = block_index * BITS_PER_BLOCK;
            let mut remaining = block;
            std::iter::from_fn(move || {
                if remaining == 0 {
                    return None;
                }
                let offset = remaining.trailing_zeros() as usize;
                // Clear the lowest set bit; `remaining != 0`, so no underflow.
                remaining &= remaining - 1;
                Some(base + offset)
            })
            // Indices ascend within a block, so stop at the first out-of-range one.
            .take_while(move |&bit| bit < bit_count)
        })
}