use std::ops::Range;
/// How an operation interacts with the initialization state of a memory range.
///
/// NOTE(review): the descriptions below are read off the variant names; the code
/// that acts on these variants is not in this file — confirm against its users.
#[derive(Clone, Copy, Debug)]
pub(crate) enum MemoryInitKind {
    /// Presumably: the operation itself initializes the range (e.g. it writes all of it).
    ImplicitlyInitialized,
    /// Presumably: the operation reads the range, so it must be initialized beforehand.
    NeedsInitializedMemory,
}
/// Associates a resource id with an address range of that resource and the kind of
/// access made to it.
///
/// NOTE(review): whatever consumes these actions lives outside this file.
#[derive(Clone, Debug)]
pub(crate) struct MemoryInitTrackerAction<ResourceId> {
    /// Identifier of the resource whose memory is affected.
    pub(crate) id: ResourceId,
    /// Address range within that resource.
    pub(crate) range: Range<wgt::BufferAddress>,
    /// How the access relates to initialization (see `MemoryInitKind`).
    pub(crate) kind: MemoryInitKind,
}
/// Tracks which addresses of a `0..size` region (see `MemoryInitTracker::new`) are
/// still uninitialized.
#[derive(Debug)]
pub(crate) struct MemoryInitTracker {
    /// The uninitialized sub-ranges, kept sorted by address and non-overlapping;
    /// `lower_bound` binary-searches on this ordering.
    uninitialized_ranges: Vec<Range<wgt::BufferAddress>>,
}
/// Draining iterator over the uninitialized parts of a range, created by
/// `MemoryInitTracker::drain`.
pub(crate) struct MemoryInitTrackerDrain<'t> {
    /// The tracker's list of uninitialized ranges; rewritten in place once the
    /// drain has yielded everything.
    uninitialized_ranges: &'t mut Vec<Range<wgt::BufferAddress>>,
    /// The range being drained.
    drain_range: Range<wgt::BufferAddress>,
    /// Index of the first stored range that may overlap `drain_range`.
    first_index: usize,
    /// Index of the next stored range to inspect.
    next_index: usize,
}
impl<'a> Iterator for MemoryInitTrackerDrain<'a> {
    type Item = Range<wgt::BufferAddress>;

    /// Returns the next uninitialized part of the drain range, clamped to it.
    ///
    /// Parts are yielded in ascending address order. The tracker is updated (all
    /// yielded parts are removed from its uninitialized set) only on the call that
    /// first returns `None`. The drain therefore has to be run to exhaustion;
    /// dropping it early leaves the tracker unchanged.
    ///
    /// Once `None` has been returned, every further call also returns `None`
    /// without touching the tracker again.
    fn next(&mut self) -> Option<Self::Item> {
        // An empty (or reversed) drain range covers no memory: yield nothing rather
        // than a zero-length range, and leave the tracker as it is.
        if self.drain_range.start >= self.drain_range.end {
            return None;
        }

        let drain_end = self.drain_range.end;
        if let Some(range) = self
            .uninitialized_ranges
            .get(self.next_index)
            .filter(|range| range.start < drain_end)
            .cloned()
        {
            self.next_index += 1;
            return Some(range.start.max(self.drain_range.start)..range.end.min(drain_end));
        }

        // Every overlapping range has been yielded; now remove them from the tracker.
        let num_affected = self.next_index - self.first_index;
        if num_affected == 0 {
            return None;
        }

        let first_range = &mut self.uninitialized_ranges[self.first_index];
        let resume_index = if num_affected == 1
            && first_range.start < self.drain_range.start
            && first_range.end > self.drain_range.end
        {
            // The drain range lies strictly inside a single uninitialized range:
            // split that range into the parts before and after the drain range.
            let old_start = first_range.start;
            first_range.start = self.drain_range.end;
            self.uninitialized_ranges
                .insert(self.first_index, old_start..self.drain_range.start);
            self.first_index + 1
        } else {
            // Trim the border ranges that stick out of the drain range and delete
            // everything in between.
            let remove_start = if first_range.start >= self.drain_range.start {
                self.first_index
            } else {
                first_range.end = self.drain_range.start;
                self.first_index + 1
            };
            let last_range = &mut self.uninitialized_ranges[self.next_index - 1];
            let remove_end = if last_range.end <= self.drain_range.end {
                self.next_index
            } else {
                last_range.start = self.drain_range.end;
                self.next_index - 1
            };
            self.uninitialized_ranges.drain(remove_start..remove_end);
            remove_start
        };

        // No uninitialized memory is left inside the drain range. The range now at
        // `resume_index` (if any) starts at or after `drain_range.end`. Parking both
        // cursors there makes later calls fall through to `num_affected == 0` and
        // return `None`. Without this reset, a later call would re-run the cleanup
        // against the already-modified vector, which can index out of bounds or
        // rewrite ranges outside the drain range.
        self.first_index = resume_index;
        self.next_index = resume_index;
        None
    }
}

// `next` keeps returning `None` once it has returned `None` (see the cursor reset
// above), so the drain is a fused iterator.
impl<'a> std::iter::FusedIterator for MemoryInitTrackerDrain<'a> {}
impl MemoryInitTracker {
    /// Creates a tracker for the addresses `0..size`, all of them uninitialized.
    pub(crate) fn new(size: wgt::BufferAddress) -> Self {
        // A zero-sized region has nothing to track. Storing `0..0` would only add a
        // dead entry that every lookup has to skip over.
        let uninitialized_ranges = if size == 0 {
            Vec::new()
        } else {
            vec![0..size]
        };
        Self {
            uninitialized_ranges,
        }
    }

    /// Returns the index of the first uninitialized range whose end lies past
    /// `bound`, or `uninitialized_ranges.len()` if there is none.
    ///
    /// Uses binary search, so it relies on `uninitialized_ranges` being sorted and
    /// non-overlapping.
    fn lower_bound(&self, bound: wgt::BufferAddress) -> usize {
        let ranges = &self.uninitialized_ranges;
        let (mut left, mut right) = (0, ranges.len());
        while left < right {
            let mid = left + (right - left) / 2;
            // `mid < right <= ranges.len()`, so plain indexing cannot panic. The
            // bounds check is not worth an `unsafe` block.
            if ranges[mid].end <= bound {
                left = mid + 1;
            } else {
                right = mid;
            }
        }
        left
    }

    /// Returns a range covering the uninitialized parts of `query_range`, or `None`
    /// if none of it is uninitialized (including when the query is empty).
    ///
    /// The result starts at the first uninitialized address inside the query.
    /// - If only one uninitialized range overlaps the query, the result ends where
    ///   that range or the query ends, whichever comes first.
    /// - If several overlap, the result extends to the end of the query, so it may
    ///   include initialized gaps.
    pub(crate) fn check(
        &self,
        query_range: Range<wgt::BufferAddress>,
    ) -> Option<Range<wgt::BufferAddress>> {
        if query_range.start >= query_range.end {
            return None;
        }
        let index = self.lower_bound(query_range.start);
        let start_range = self.uninitialized_ranges.get(index)?;
        if start_range.start >= query_range.end {
            return None;
        }
        let start = start_range.start.max(query_range.start);
        let end = match self.uninitialized_ranges.get(index + 1) {
            // A second uninitialized range begins inside the query: report up to
            // the end of the query instead of enumerating the pieces.
            Some(next_range) if next_range.start < query_range.end => query_range.end,
            _ => start_range.end.min(query_range.end),
        };
        Some(start..end)
    }

    /// Starts draining the uninitialized parts of `drain_range`.
    ///
    /// The returned iterator yields those parts clamped to `drain_range`. They are
    /// removed from the tracker only once the iterator has been exhausted, so
    /// ignoring or dropping it early marks nothing as initialized.
    #[must_use]
    pub(crate) fn drain(
        &mut self,
        drain_range: Range<wgt::BufferAddress>,
    ) -> MemoryInitTrackerDrain<'_> {
        let index = self.lower_bound(drain_range.start);
        MemoryInitTrackerDrain {
            drain_range,
            uninitialized_ranges: &mut self.uninitialized_ranges,
            first_index: index,
            next_index: index,
        }
    }

    /// Marks all of `range` as initialized, discarding the drained sub-ranges.
    pub(crate) fn clear(&mut self, range: Range<wgt::BufferAddress>) {
        // The tracker is only updated once the drain is exhausted, so run it to the end.
        self.drain(range).for_each(drop);
    }
}
#[cfg(test)]
mod test {
    use super::MemoryInitTracker;
    use std::ops::Range;

    /// Runs a drain over `range` to completion and collects everything it yielded.
    fn drained(
        tracker: &mut MemoryInitTracker,
        range: Range<wgt::BufferAddress>,
    ) -> Vec<Range<wgt::BufferAddress>> {
        tracker.drain(range).collect()
    }

    #[test]
    fn check_for_newly_created_tracker() {
        let fresh = MemoryInitTracker::new(10);
        // Nothing has been cleared yet, so every query comes back unchanged.
        for query in [0..10, 0..3, 3..4, 4..10].iter().cloned() {
            assert_eq!(fresh.check(query.clone()), Some(query));
        }
    }

    #[test]
    fn check_for_cleared_tracker() {
        let mut cleared = MemoryInitTracker::new(10);
        cleared.clear(0..10);
        for query in [0..10, 0..3, 3..4, 4..10].iter().cloned() {
            assert_eq!(cleared.check(query), None);
        }
    }

    #[test]
    fn check_for_partially_filled_tracker() {
        let mut tracker = MemoryInitTracker::new(25);
        // Leaves 5..10 and 15..20 uninitialized.
        for done in [0..5, 10..15, 20..25].iter().cloned() {
            tracker.clear(done);
        }

        let expectations = [
            (0..25, Some(5..25)),   // the whole region
            (0..5, None),           // entirely before the first gap
            (3..8, Some(5..8)),     // overlaps the start of 5..10
            (3..17, Some(5..17)),   // covers 5..10 and reaches into 15..20
            (8..22, Some(8..22)),   // two gaps hit: the end is not trimmed back to 20
            (17..22, Some(17..20)), // overlaps the end of 15..20
            (20..25, None),         // entirely after the last gap
        ];
        for (query, expected) in expectations.iter().cloned() {
            assert_eq!(tracker.check(query), expected);
        }
    }

    #[test]
    fn clear_already_cleared() {
        let mut tracker = MemoryInitTracker::new(30);
        tracker.clear(10..20);

        // Partially overlap the already-cleared 10..20 from either side...
        tracker.clear(5..15);
        tracker.clear(15..25);
        // ...then cover everything, cleared and uncleared alike...
        tracker.clear(0..30);
        // ...and finally clear a region that is already fully cleared.
        tracker.clear(0..30);

        assert_eq!(tracker.check(0..30), None);
    }

    #[test]
    fn drain_never_returns_ranges_twice_for_same_range() {
        let mut tracker = MemoryInitTracker::new(19);
        assert_eq!(tracker.drain(0..19).count(), 1);
        assert_eq!(tracker.drain(0..19).count(), 0);

        let mut tracker = MemoryInitTracker::new(17);
        for range in [5..8, 1..3, 7..13].iter().cloned() {
            assert_eq!(tracker.drain(range.clone()).count(), 1);
            assert_eq!(tracker.drain(range).count(), 0);
        }
    }

    #[test]
    fn drain_splits_ranges_correctly() {
        let mut tracker = MemoryInitTracker::new(1337);
        assert_eq!(drained(&mut tracker, 21..42), vec![21..42]);
        assert_eq!(drained(&mut tracker, 900..1000), vec![900..1000]);
        assert_eq!(
            drained(&mut tracker, 5..1003),
            vec![5..21, 42..900, 1000..1003]
        );
        assert_eq!(drained(&mut tracker, 0..1337), vec![0..5, 1003..1337]);
    }
}