use core::hash::Hash;
use hashbrown::HashSet;
use crate::channel::Channel;
/// Number of independent channels tracked by every [`InvalidationSet`].
///
/// Channels are stored in a fixed-size array, so a `Channel` whose `index()`
/// is not below this value makes any channel-taking method panic.
const MAX_CHANNELS: usize = 64;

/// Per-channel sets of invalidated keys plus a change counter.
///
/// Each of the `MAX_CHANNELS` channels owns its own key set: marking a key on
/// one channel never affects another. The `generation` counter (wrapping on
/// overflow) is bumped by `mark`, `drain`, `clear` and `clear_all` on every
/// call, and by `take` / `remove_key` only when they actually remove a key.
///
/// The `K: Copy + Eq + Hash` bounds live on the `impl` blocks that need them
/// rather than on the type itself, so the type can be named without them.
#[derive(Debug)]
pub struct InvalidationSet<K> {
    /// One key set per channel, indexed by `Channel::index()`.
    channels: [HashSet<K>; MAX_CHANNELS],
    /// Wrapping change counter; see the type-level docs for when it moves.
    generation: u64,
}
impl<K> Default for InvalidationSet<K>
where
K: Copy + Eq + Hash,
{
fn default() -> Self {
Self::new()
}
}
impl<K> InvalidationSet<K>
where
    K: Copy + Eq + Hash,
{
    /// Creates an empty set: no key is marked on any channel and the
    /// generation counter starts at `0`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            // Built with `array::from_fn` rather than `Default`, since the
            // standard library only implements `Default` for arrays of up to
            // 32 elements.
            channels: core::array::from_fn(|_| HashSet::new()),
            generation: 0,
        }
    }

    /// Returns the current generation counter.
    ///
    /// The counter wraps on overflow. [`mark`](Self::mark),
    /// [`drain`](Self::drain), [`clear`](Self::clear) and
    /// [`clear_all`](Self::clear_all) bump it on every call;
    /// [`take`](Self::take) and [`remove_key`](Self::remove_key) bump it only
    /// when they actually remove a key.
    #[inline]
    #[must_use]
    pub fn generation(&self) -> u64 {
        self.generation
    }

    /// Marks `key` as invalidated on `channel`.
    ///
    /// Returns `true` if the key was newly inserted, `false` if it was already
    /// marked on that channel. The generation is bumped in both cases.
    ///
    /// # Panics
    ///
    /// Panics if `channel.index()` is not below `MAX_CHANNELS` (64).
    #[inline]
    pub fn mark(&mut self, key: K, channel: Channel) -> bool {
        self.bump_generation();
        self.slot_mut(channel).insert(key)
    }

    /// Returns `true` if `key` is currently marked on `channel`.
    ///
    /// # Panics
    ///
    /// Panics if `channel.index()` is not below `MAX_CHANNELS` (64).
    #[inline]
    #[must_use]
    pub fn is_invalidated(&self, key: K, channel: Channel) -> bool {
        self.slot(channel).contains(&key)
    }

    /// Returns `true` if at least one key is marked on `channel`.
    ///
    /// # Panics
    ///
    /// Panics if `channel.index()` is not below `MAX_CHANNELS` (64).
    #[inline]
    #[must_use]
    pub fn has_invalidated(&self, channel: Channel) -> bool {
        !self.slot(channel).is_empty()
    }

    /// Returns `true` if no key is marked on any channel.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.channels.iter().all(HashSet::is_empty)
    }

    /// Returns the number of keys marked on `channel`.
    ///
    /// # Panics
    ///
    /// Panics if `channel.index()` is not below `MAX_CHANNELS` (64).
    #[must_use]
    pub fn len(&self, channel: Channel) -> usize {
        self.slot(channel).len()
    }

    /// Iterates over the keys marked on `channel`, in the set's unspecified
    /// internal order, without removing them.
    ///
    /// # Panics
    ///
    /// Panics if `channel.index()` is not below `MAX_CHANNELS` (64).
    pub fn iter(&self, channel: Channel) -> impl Iterator<Item = K> + '_ {
        self.slot(channel).iter().copied()
    }

    /// Removes and yields every key marked on `channel`.
    ///
    /// The generation is bumped even if the channel was already empty.
    ///
    /// # Panics
    ///
    /// Panics if `channel.index()` is not below `MAX_CHANNELS` (64).
    #[inline]
    pub fn drain(&mut self, channel: Channel) -> impl Iterator<Item = K> + '_ {
        self.bump_generation();
        self.slot_mut(channel).drain()
    }

    /// Unmarks `key` on `channel`, returning `true` if it had been marked.
    ///
    /// The generation is bumped only when a key is actually removed.
    ///
    /// # Panics
    ///
    /// Panics if `channel.index()` is not below `MAX_CHANNELS` (64).
    #[inline]
    pub fn take(&mut self, key: K, channel: Channel) -> bool {
        let removed = self.slot_mut(channel).remove(&key);
        if removed {
            self.bump_generation();
        }
        removed
    }

    /// Unmarks every key on `channel`.
    ///
    /// The generation is bumped even if the channel was already empty.
    ///
    /// # Panics
    ///
    /// Panics if `channel.index()` is not below `MAX_CHANNELS` (64).
    pub fn clear(&mut self, channel: Channel) {
        self.bump_generation();
        self.slot_mut(channel).clear();
    }

    /// Unmarks every key on every channel, bumping the generation once.
    pub fn clear_all(&mut self) {
        self.bump_generation();
        for set in &mut self.channels {
            set.clear();
        }
    }

    /// Unmarks `key` on every channel.
    ///
    /// The generation is bumped once if the key was removed from at least
    /// one channel, and left untouched otherwise.
    pub fn remove_key(&mut self, key: K) {
        let mut removed = false;
        for set in &mut self.channels {
            // `|=` (not `||`) so the removal runs on every channel.
            removed |= set.remove(&key);
        }
        if removed {
            self.bump_generation();
        }
    }

    /// Returns the key set backing `channel`; panics on an out-of-range index.
    #[inline]
    fn slot(&self, channel: Channel) -> &HashSet<K> {
        &self.channels[channel.index() as usize]
    }

    /// Mutable counterpart of `slot`; panics on an out-of-range index.
    #[inline]
    fn slot_mut(&mut self, channel: Channel) -> &mut HashSet<K> {
        &mut self.channels[channel.index() as usize]
    }

    /// Advances the generation counter, wrapping on overflow.
    #[inline]
    fn bump_generation(&mut self) {
        self.generation = self.generation.wrapping_add(1);
    }
}
/// Deep-copies every channel's key set together with the generation counter.
impl<K> Clone for InvalidationSet<K>
where
    K: Copy + Eq + Hash,
{
    fn clone(&self) -> Self {
        Self {
            // Arrays of any length are `Clone` when their elements are, so no
            // per-index rebuild is needed.
            channels: self.channels.clone(),
            generation: self.generation,
        }
    }

    /// Overwrites `self` with a copy of `source`, letting each channel's
    /// `HashSet::clone_from` reuse its existing allocation where it can
    /// instead of dropping and reallocating all channels.
    fn clone_from(&mut self, source: &Self) {
        for (dst, src) in self.channels.iter_mut().zip(source.channels.iter()) {
            dst.clone_from(src);
        }
        self.generation = source.generation;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec::Vec;

    const LAYOUT: Channel = Channel::new(0);
    const PAINT: Channel = Channel::new(1);

    #[test]
    fn mark_and_query() {
        let mut invalidated = InvalidationSet::<u32>::new();
        assert!(!invalidated.is_invalidated(1, LAYOUT));
        assert!(invalidated.is_empty());
        let inserted = invalidated.mark(1, LAYOUT);
        assert!(inserted);
        assert!(invalidated.is_invalidated(1, LAYOUT));
        assert!(!invalidated.is_empty());
        assert!(invalidated.has_invalidated(LAYOUT));
        let inserted_again = invalidated.mark(1, LAYOUT);
        assert!(!inserted_again);
    }

    #[test]
    fn channel_independence() {
        let mut invalidated = InvalidationSet::<u32>::new();
        invalidated.mark(1, LAYOUT);
        invalidated.mark(2, PAINT);
        assert!(invalidated.is_invalidated(1, LAYOUT));
        assert!(!invalidated.is_invalidated(1, PAINT));
        assert!(!invalidated.is_invalidated(2, LAYOUT));
        assert!(invalidated.is_invalidated(2, PAINT));
    }

    #[test]
    fn drain_clears_channel() {
        let mut invalidated = InvalidationSet::<u32>::new();
        invalidated.mark(1, LAYOUT);
        invalidated.mark(2, LAYOUT);
        invalidated.mark(1, PAINT);
        let drained: Vec<_> = invalidated.drain(LAYOUT).collect();
        assert_eq!(drained.len(), 2);
        assert!(!invalidated.has_invalidated(LAYOUT));
        assert!(invalidated.has_invalidated(PAINT));
    }

    #[test]
    fn take_removes_single_key() {
        let mut invalidated = InvalidationSet::<u32>::new();
        invalidated.mark(1, LAYOUT);
        invalidated.mark(2, LAYOUT);
        assert!(invalidated.take(1, LAYOUT));
        assert!(!invalidated.is_invalidated(1, LAYOUT));
        assert!(invalidated.is_invalidated(2, LAYOUT));
        assert!(!invalidated.take(1, LAYOUT));
    }

    #[test]
    fn clear_specific_channel() {
        let mut invalidated = InvalidationSet::<u32>::new();
        invalidated.mark(1, LAYOUT);
        invalidated.mark(1, PAINT);
        invalidated.clear(LAYOUT);
        assert!(!invalidated.has_invalidated(LAYOUT));
        assert!(invalidated.has_invalidated(PAINT));
    }

    #[test]
    fn clear_all() {
        let mut invalidated = InvalidationSet::<u32>::new();
        invalidated.mark(1, LAYOUT);
        invalidated.mark(2, PAINT);
        invalidated.clear_all();
        assert!(invalidated.is_empty());
    }

    #[test]
    fn remove_key_from_all_channels() {
        let mut invalidated = InvalidationSet::<u32>::new();
        invalidated.mark(1, LAYOUT);
        invalidated.mark(1, PAINT);
        invalidated.mark(2, LAYOUT);
        invalidated.remove_key(1);
        assert!(!invalidated.is_invalidated(1, LAYOUT));
        assert!(!invalidated.is_invalidated(1, PAINT));
        assert!(invalidated.is_invalidated(2, LAYOUT));
    }

    #[test]
    fn generation_increments() {
        let mut invalidated = InvalidationSet::<u32>::new();
        let initial = invalidated.generation();
        invalidated.mark(1, LAYOUT);
        assert_eq!(invalidated.generation(), initial + 1);
        invalidated.mark(2, LAYOUT);
        assert_eq!(invalidated.generation(), initial + 2);
        let _ = invalidated.drain(LAYOUT).count();
        assert_eq!(invalidated.generation(), initial + 3);
        invalidated.clear(PAINT);
        assert_eq!(invalidated.generation(), initial + 4);
    }

    /// `take` and `remove_key` only advance the generation when they remove
    /// something, and `remove_key` advances it once even across channels.
    #[test]
    fn removals_bump_generation_only_on_change() {
        let mut invalidated = InvalidationSet::<u32>::new();
        invalidated.mark(1, LAYOUT);
        invalidated.mark(1, PAINT);
        let before = invalidated.generation();
        assert!(!invalidated.take(2, LAYOUT));
        invalidated.remove_key(3);
        assert_eq!(invalidated.generation(), before);
        invalidated.remove_key(1);
        assert_eq!(invalidated.generation(), before + 1);
        assert!(invalidated.is_empty());
    }

    #[test]
    fn len_and_iter() {
        let mut invalidated = InvalidationSet::<u32>::new();
        invalidated.mark(1, LAYOUT);
        invalidated.mark(2, LAYOUT);
        invalidated.mark(3, LAYOUT);
        assert_eq!(invalidated.len(LAYOUT), 3);
        assert_eq!(invalidated.len(PAINT), 0);
        let keys: Vec<_> = invalidated.iter(LAYOUT).collect();
        assert_eq!(keys.len(), 3);
    }

    #[test]
    fn default_is_empty() {
        let invalidated = InvalidationSet::<u32>::default();
        assert!(invalidated.is_empty());
        assert_eq!(invalidated.generation(), 0);
    }

    /// Clones are deep: mutating the original afterwards does not affect the
    /// copy, and `clone_from` fully replaces the target's contents.
    #[test]
    fn clone_is_independent() {
        let mut original = InvalidationSet::<u32>::new();
        original.mark(1, LAYOUT);
        original.mark(2, PAINT);
        let copy = original.clone();
        assert_eq!(copy.generation(), original.generation());
        assert!(copy.is_invalidated(1, LAYOUT));
        assert!(copy.is_invalidated(2, PAINT));
        original.clear_all();
        assert!(original.is_empty());
        assert!(copy.is_invalidated(1, LAYOUT));

        let mut target = InvalidationSet::<u32>::new();
        target.mark(9, LAYOUT);
        target.clone_from(&copy);
        assert!(!target.is_invalidated(9, LAYOUT));
        assert!(target.is_invalidated(1, LAYOUT));
        assert!(target.is_invalidated(2, PAINT));
        assert_eq!(target.generation(), copy.generation());
    }
}