use std::{any, fmt, marker::PhantomData, mem, ops::Range};
use enumset::{EnumSet, EnumSetType};
use sorted_iter::{assume::AssumeSortedByItemExt, sorted_iterator::SortedByItem, SortedIterator};
use crate::{div_ceil, Bitset};
/// A bit matrix whose rows are addressed by values of the enum `R`.
///
/// Every row lives in one shared [`Bitset`], stored one after another.
/// `R` is not stored; it only appears through [`PhantomData`] so that the
/// row type is part of the matrix type.
#[derive(Clone, PartialEq, Eq)]
pub struct EnumBitMatrix<R>(Bitset<Box<[u32]>>, PhantomData<R>)
where
    R: EnumSetType;
/// Debug output is a tuple named `EnumBitMatrix` holding the type name of
/// `R` followed by the underlying bitset.
impl<R: EnumSetType> fmt::Debug for EnumBitMatrix<R> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `R` has no value stored in `self`, so its type name stands in for it.
        let row_type = any::type_name::<R>();
        let Self(bits, _) = self;

        let mut tuple = f.debug_tuple("EnumBitMatrix");
        tuple.field(&row_type);
        tuple.field(bits);
        tuple.finish()
    }
}
impl<R: EnumSetType> EnumBitMatrix<R> {
    /// Number of bits held by one `u32` storage word.
    const WORD_BITS: usize = mem::size_of::<u32>() * 8;

    /// Creates a matrix with all bits cleared and room for at least `width`
    /// bits in every row.
    ///
    /// Storage is allocated in whole `u32` words, so
    /// [`bit_width`](Self::bit_width) may report more than `width`, but not
    /// less.
    ///
    /// # Panics
    ///
    /// Panics when `width * R::BIT_WIDTH` does not fit in a `u32`.
    #[must_use]
    pub fn new(width: u32) -> Self {
        let bits = width
            .checked_mul(R::BIT_WIDTH)
            .expect("EnumBitMatrix: `width * R::BIT_WIDTH` overflows u32") as usize;
        // Divide by the bits per word, not the bytes per word.
        // NOTE(review): relies on `crate::div_ceil` rounding up — confirm.
        let data = vec![0; div_ceil(bits, Self::WORD_BITS)];
        Self(Bitset(data.into_boxed_slice()), PhantomData)
    }

    /// Enables, in `row`, the bit of every column index yielded by `iter`.
    ///
    /// Indices at or beyond [`bit_width`](Self::bit_width) are skipped.
    /// Bits are only ever enabled here, never cleared.
    ///
    /// # Panics
    ///
    /// Panics if the underlying bitset rejects an index, which would mean
    /// the row layout invariant is broken.
    pub fn set_row(&mut self, row: R, iter: impl Iterator<Item = u32>) {
        let row = row.enum_into_u32();
        let width = self.bit_width();
        let start = row * width;
        for to_set in iter.filter(|&col| col < width).map(|col| col + start) {
            // `to_set < (row + 1) * width`, which stays within the storage as
            // long as `row < R::BIT_WIDTH`. The check is kept (rather than
            // `unwrap_unchecked`) because `Bitset` is not visible from here.
            self.0
                .enable_bit(to_set as usize)
                .expect("EnumBitMatrix: bit index outside of the backing storage");
        }
    }

    /// The number of bit columns in each row.
    ///
    /// Computed as the total number of stored bits divided by
    /// `R::BIT_WIDTH`, saturating at `u32::MAX`.
    #[must_use]
    pub const fn bit_width(&self) -> u32 {
        // Widen to `u64`: the bit count of a maximal matrix exceeds `u32::MAX`.
        let total_bits = self.0 .0.len() as u64 * Self::WORD_BITS as u64;
        let per_row = total_bits / R::BIT_WIDTH as u64;
        if per_row > u32::MAX as u64 {
            u32::MAX
        } else {
            per_row as u32
        }
    }

    /// Iterates over the enabled bits of `row` whose column lies in `range`.
    ///
    /// `range` is clamped to `0..bit_width()`, so bits of neighbouring rows
    /// are never reported. Yielded values are column indices relative to the
    /// start of the row.
    pub fn row(&self, row: R, mut range: Range<u32>) -> impl SortedIterator<Item = u32> + '_ {
        let row = row.enum_into_u32();
        let width = self.bit_width();
        // Clamp the end to the row itself; clamping to `start + width` would
        // let a non-zero start spill into the following row.
        range.end = range.end.min(width);
        range.start = range.start.min(range.end);
        let start = row * width;
        // Sum in `usize` so that the very last row cannot overflow `u32`.
        let subrange_start = start as usize + range.start as usize;
        let subrange_end = start as usize + range.end as usize;
        // NOTE(review): assumes `ones_in_range` yields ascending indices,
        // which is what `assume_sorted_by_item` promises to consumers.
        self.0
            .ones_in_range(subrange_start..subrange_end)
            .map(move |i| i - start)
            .assume_sorted_by_item()
    }

    /// Creates a [`Rows`] iterator reading the columns in `range` from
    /// every row contained in `rows`.
    #[must_use]
    pub const fn rows(&self, rows: EnumSet<R>, range: Range<u32>) -> Rows<'_, R> {
        Rows { range, rows, bitset: self }
    }
}
/// Iterator created by [`EnumBitMatrix::rows`], reading a column range from
/// a set of rows of a borrowed matrix.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Rows<'a, R>
where
    R: EnumSetType,
{
    /// Columns not visited yet; shrinks from the front as iteration advances.
    range: Range<u32>,
    /// The rows whose bits are looked up.
    rows: EnumSet<R>,
    /// The matrix the bits are read from.
    bitset: &'a EnumBitMatrix<R>,
}
/// Yields, in strictly ascending order and without repeats, every column in
/// the remaining range that has an enabled bit in at least one selected row.
impl<'a, R: EnumSetType> Iterator for Rows<'a, R> {
    type Item = u32;

    fn next(&mut self) -> Option<Self::Item> {
        if self.range.is_empty() {
            return None;
        }
        let remaining = self.range.clone();
        // Copy the shared reference and the row set out of `self` so that
        // the closure below does not hold a borrow of `self`.
        let matrix = self.bitset;
        let rows = self.rows;

        // Each row reports its first enabled column in `remaining`; the
        // smallest of those is the next item. Taking the first row that has
        // any hit would break the ordering whenever a later row holds a
        // lower column.
        let lowest = rows
            .iter()
            .filter_map(|row| matrix.row(row, remaining.clone()).next())
            .min();

        self.range.start = match lowest {
            // Resume just past the reported column so it is not yielded twice.
            // `column < range.end`, hence the addition cannot overflow.
            Some(column) => column + 1,
            // Nothing is left in any row: exhaust the range.
            None => self.range.end,
        };
        lowest
    }
}
/// Marker impl declaring that [`Rows`] yields its items in sorted order.
///
/// The trait has no methods; the ordering itself has to be upheld by the
/// `Iterator` implementation of [`Rows`].
impl<'a, R> SortedByItem for Rows<'a, R> where R: EnumSetType {}