use std::sync::Arc;
use super::selection::SelectionVector;
/// Physical layout of a chunk: either plain rows or a multi-level factorized
/// representation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum FactorizationState {
    /// Rows are stored flat; `row_count` rows are present.
    Flat { row_count: usize },
    /// Data is spread over `level_count` levels that expand to
    /// `logical_rows` logical rows.
    Unflat {
        level_count: usize,
        logical_rows: usize,
    },
}

impl FactorizationState {
    /// Returns `true` for [`FactorizationState::Flat`].
    #[must_use]
    pub fn is_flat(&self) -> bool {
        match self {
            Self::Flat { .. } => true,
            Self::Unflat { .. } => false,
        }
    }

    /// Returns `true` for [`FactorizationState::Unflat`].
    #[must_use]
    pub fn is_unflat(&self) -> bool {
        match self {
            Self::Flat { .. } => false,
            Self::Unflat { .. } => true,
        }
    }

    /// Number of logical rows: `row_count` when flat, `logical_rows` when
    /// factorized.
    #[must_use]
    pub fn logical_row_count(&self) -> usize {
        match *self {
            Self::Flat { row_count: rows } | Self::Unflat { logical_rows: rows, .. } => rows,
        }
    }

    /// Number of levels: a flat layout counts as a single level.
    #[must_use]
    pub fn level_count(&self) -> usize {
        match *self {
            Self::Unflat { level_count, .. } => level_count,
            Self::Flat { .. } => 1,
        }
    }
}
/// Selection over the physical positions of one factorization level.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum LevelSelection {
    /// Every position in `0..count` is selected.
    All { count: usize },
    /// Only the positions held by the wrapped [`SelectionVector`] are selected.
    Sparse(SelectionVector),
}

impl LevelSelection {
    /// Selects every position in `0..count`.
    #[must_use]
    pub fn all(count: usize) -> Self {
        Self::All { count }
    }

    /// Selects the positions in `0..count` for which `predicate` returns
    /// `true`. When every position passes, the result collapses to
    /// [`LevelSelection::All`].
    #[must_use]
    pub fn from_predicate<F>(count: usize, predicate: F) -> Self
    where
        F: Fn(usize) -> bool,
    {
        let kept = SelectionVector::from_predicate(count, predicate);
        if kept.len() != count {
            return Self::Sparse(kept);
        }
        Self::All { count }
    }

    /// Number of selected positions.
    #[must_use]
    pub fn selected_count(&self) -> usize {
        match self {
            Self::Sparse(kept) => kept.len(),
            Self::All { count: total } => *total,
        }
    }

    /// Returns `true` if `physical_idx` belongs to this selection. For `All`
    /// that means `physical_idx < count`.
    #[must_use]
    pub fn is_selected(&self, physical_idx: usize) -> bool {
        match self {
            Self::Sparse(kept) => kept.contains(physical_idx),
            Self::All { count } => (0..*count).contains(&physical_idx),
        }
    }

    /// Narrows the selection to the currently selected positions that also
    /// satisfy `predicate`. `All` may stay `All` if nothing is filtered out;
    /// `Sparse` always stays `Sparse`.
    #[must_use]
    pub fn filter<F>(&self, predicate: F) -> Self
    where
        F: Fn(usize) -> bool,
    {
        match self {
            Self::Sparse(kept) => Self::Sparse(kept.filter(predicate)),
            Self::All { count } => Self::from_predicate(*count, predicate),
        }
    }

    /// Iterates the selected physical positions: `0..count` for `All`, and
    /// whatever order the selection vector yields for `Sparse`.
    pub fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
        match self {
            Self::Sparse(kept) => Box::new(kept.iter()),
            Self::All { count } => Box::new(0..*count),
        }
    }
}

impl<'a> IntoIterator for &'a LevelSelection {
    type Item = usize;
    type IntoIter = Box<dyn Iterator<Item = usize> + 'a>;

    /// Same as [`LevelSelection::iter`].
    fn into_iter(self) -> Self::IntoIter {
        LevelSelection::iter(self)
    }
}
/// Per-level selections of a factorized chunk, together with a lazily
/// computed (and memoized) count of selected positions at the deepest level.
#[derive(Debug, Clone)]
pub struct FactorizedSelection {
    /// One selection per level; level 0 is the root whose positions parent
    /// the positions of level 1, and so on.
    level_selections: Vec<LevelSelection>,
    /// Memoized result of `selected_count`; reset by `invalidate_cache`.
    cached_selected_count: Option<usize>,
}

impl FactorizedSelection {
    /// Creates a selection that keeps everything, with `level_counts[i]`
    /// positions on level `i`.
    #[must_use]
    pub fn all(level_counts: &[usize]) -> Self {
        Self::new(level_counts.iter().copied().map(LevelSelection::all).collect())
    }

    /// Wraps explicit per-level selections; the count cache starts empty.
    #[must_use]
    pub fn new(level_selections: Vec<LevelSelection>) -> Self {
        Self {
            level_selections,
            cached_selected_count: None,
        }
    }

    /// Number of levels in this selection.
    #[must_use]
    pub fn level_count(&self) -> usize {
        self.level_selections.len()
    }

    /// Selection of `level`, or `None` if the level does not exist.
    #[must_use]
    pub fn level(&self, level: usize) -> Option<&LevelSelection> {
        self.level_selections.get(level)
    }

    /// Returns a copy with `level` narrowed by `predicate`. An out-of-range
    /// `level` yields an unchanged copy. The copy always starts with an empty
    /// count cache.
    #[must_use]
    pub fn filter_level<F>(&self, level: usize, predicate: F) -> Self
    where
        F: Fn(usize) -> bool,
    {
        let mut level_selections = self.level_selections.clone();
        if let Some(sel) = level_selections.get_mut(level) {
            *sel = sel.filter(predicate);
        }
        Self::new(level_selections)
    }

    /// Returns `true` if `physical_idx` is selected on `level`; `false` for an
    /// unknown level.
    #[must_use]
    pub fn is_selected(&self, level: usize, physical_idx: usize) -> bool {
        self.level_selections
            .get(level)
            .is_some_and(|sel| sel.is_selected(physical_idx))
    }

    /// Number of selected positions at the deepest level that
    /// `multiplicities` describes, memoized after the first call.
    ///
    /// `multiplicities[i][p]` is the number of consecutive level-`i` children
    /// owned by position `p` of level `i - 1`; `multiplicities[0]` is never
    /// read. A child counts only if its own level selects it AND its parent is
    /// selected. With no levels the result is `0`. With one level the
    /// multiplicities are ignored and that level's selected count is returned.
    /// If `multiplicities` has fewer entries than there are levels, the descent
    /// stops at the last described level.
    ///
    /// Once memoized, the value is returned unchanged even if later calls pass
    /// different `multiplicities`. Call `invalidate_cache` first when the
    /// layout changes.
    pub fn selected_count(&mut self, multiplicities: &[Vec<usize>]) -> usize {
        if let Some(count) = self.cached_selected_count {
            return count;
        }
        let count = self.compute_selected_count(multiplicities);
        self.cached_selected_count = Some(count);
        count
    }

    /// Uncached worker behind `selected_count`. Each level's membership is
    /// materialized once, so sparse levels cost O(selected + children) rather
    /// than one `is_selected` lookup per child.
    fn compute_selected_count(&self, multiplicities: &[Vec<usize>]) -> usize {
        let Some((root, descendants)) = self.level_selections.split_first() else {
            return 0;
        };
        if descendants.is_empty() {
            return root.selected_count();
        }

        // Root positions span 0..count for `All`, or 0..=max selected index
        // for `Sparse`; positions beyond the mask count as unselected.
        let root_extent = match root {
            LevelSelection::All { count } => *count,
            LevelSelection::Sparse(sel) => sel.iter().max().map_or(0, |max_idx| max_idx + 1),
        };
        let mut parent_selected = Self::membership_mask(root, root_extent);

        // Level i (i >= 1) is paired with multiplicities[i]; zip truncation
        // ends the descent early when multiplicities runs out.
        for (level_sel, level_mults) in descendants.iter().zip(multiplicities.iter().skip(1)) {
            let child_total: usize = level_mults.iter().sum();
            let level_mask = Self::membership_mask(level_sel, child_total);
            let mut child_selected = Vec::with_capacity(child_total);
            for (parent_idx, &mult) in level_mults.iter().enumerate() {
                let parent_is_selected =
                    parent_selected.get(parent_idx).copied().unwrap_or(false);
                // Children are numbered consecutively across parents, so the
                // next child's physical index is the number emitted so far.
                let first_child = child_selected.len();
                child_selected.extend(
                    level_mask[first_child..first_child + mult]
                        .iter()
                        .map(|&child_in_level| parent_is_selected && child_in_level),
                );
            }
            parent_selected = child_selected;
        }

        parent_selected.into_iter().filter(|&selected| selected).count()
    }

    /// Dense membership of the first `len` physical positions of `selection`:
    /// `mask[i]` is `true` iff position `i` is selected.
    ///
    /// NOTE(review): for `Sparse`, this relies on `SelectionVector::iter`
    /// agreeing with `SelectionVector::contains`. The previous implementation
    /// already mixed the two (iter for the root, contains for deeper levels).
    fn membership_mask(selection: &LevelSelection, len: usize) -> Vec<bool> {
        match selection {
            LevelSelection::All { count } => (0..len).map(|idx| idx < *count).collect(),
            LevelSelection::Sparse(sel) => {
                let mut mask = vec![false; len];
                for idx in sel.iter() {
                    if let Some(slot) = mask.get_mut(idx) {
                        *slot = true;
                    }
                }
                mask
            }
        }
    }

    /// Drops the memoized selected count so the next `selected_count` call
    /// recomputes it.
    pub fn invalidate_cache(&mut self) {
        self.cached_selected_count = None;
    }
}
/// Runtime state of a chunk: its factorization layout, an optional selection,
/// and a memoized multiplicity array guarded by a generation counter.
#[derive(Debug, Clone)]
pub struct ChunkState {
    /// Current flat/unflat layout.
    state: FactorizationState,
    /// Selection applied to the chunk, if any.
    selection: Option<FactorizedSelection>,
    /// Memoized multiplicities; dropped by `invalidate_cache`.
    cached_multiplicities: Option<Arc<[usize]>>,
    /// Starts at 0 and grows by one on every `invalidate_cache` call.
    generation: u64,
}

impl ChunkState {
    /// Flat state holding `row_count` rows, with no selection, no cached
    /// multiplicities and generation 0.
    #[must_use]
    pub fn flat(row_count: usize) -> Self {
        Self {
            state: FactorizationState::Flat { row_count },
            selection: None,
            cached_multiplicities: None,
            generation: 0,
        }
    }

    /// Factorized state with `level_count` levels expanding to `logical_rows`
    /// rows. All other fields start out exactly as in [`ChunkState::flat`].
    #[must_use]
    pub fn unflat(level_count: usize, logical_rows: usize) -> Self {
        Self {
            state: FactorizationState::Unflat {
                level_count,
                logical_rows,
            },
            ..Self::flat(0)
        }
    }

    /// Copy of the current layout.
    #[must_use]
    pub fn factorization_state(&self) -> FactorizationState {
        self.state
    }

    /// `true` when the layout is flat.
    #[must_use]
    pub fn is_flat(&self) -> bool {
        self.factorization_state().is_flat()
    }

    /// `true` when the layout is factorized (unflat).
    #[must_use]
    pub fn is_factorized(&self) -> bool {
        self.factorization_state().is_unflat()
    }

    /// Logical row count of the current layout.
    #[must_use]
    pub fn logical_row_count(&self) -> usize {
        self.factorization_state().logical_row_count()
    }

    /// Level count of the current layout (1 when flat).
    #[must_use]
    pub fn level_count(&self) -> usize {
        self.factorization_state().level_count()
    }

    /// Number of cache invalidations performed so far.
    #[must_use]
    pub fn generation(&self) -> u64 {
        self.generation
    }

    /// Borrow of the current selection, if one is set.
    #[must_use]
    pub fn selection(&self) -> Option<&FactorizedSelection> {
        self.selection.as_ref()
    }

    /// Mutable access to the selection slot itself, so callers can replace or
    /// clear it in place.
    pub fn selection_mut(&mut self) -> &mut Option<FactorizedSelection> {
        &mut self.selection
    }

    /// Installs `selection`, replacing any previous one. Neither the cache nor
    /// the generation is touched.
    pub fn set_selection(&mut self, selection: FactorizedSelection) {
        self.selection = Some(selection);
    }

    /// Removes the selection. Neither the cache nor the generation is touched.
    pub fn clear_selection(&mut self) {
        self.selection = None;
    }

    /// Replaces the layout and invalidates the multiplicity cache, which bumps
    /// the generation.
    pub fn set_state(&mut self, state: FactorizationState) {
        self.state = state;
        self.invalidate_cache();
    }

    /// Drops the cached multiplicities and increments the generation.
    pub fn invalidate_cache(&mut self) {
        self.generation += 1;
        self.cached_multiplicities = None;
    }

    /// Returns the cached multiplicities. If none are cached, runs `compute`,
    /// caches its result, and returns a shared handle to it.
    pub fn get_or_compute_multiplicities<F>(&mut self, compute: F) -> Arc<[usize]>
    where
        F: FnOnce() -> Vec<usize>,
    {
        let cached = self
            .cached_multiplicities
            .get_or_insert_with(|| Arc::from(compute()));
        Arc::clone(cached)
    }

    /// Borrow of the cached multiplicities, if any.
    #[must_use]
    pub fn cached_multiplicities(&self) -> Option<&Arc<[usize]>> {
        self.cached_multiplicities.as_ref()
    }

    /// Stores `mults` as the cached multiplicities without changing the
    /// generation.
    pub fn set_cached_multiplicities(&mut self, mults: Arc<[usize]>) {
        self.cached_multiplicities = Some(mults);
    }
}

impl Default for ChunkState {
    /// An empty flat chunk (zero rows, generation 0).
    fn default() -> Self {
        ChunkState::flat(0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_factorization_state_flat() {
        let state = FactorizationState::Flat { row_count: 100 };
        assert!(state.is_flat());
        assert!(!state.is_unflat());
        assert_eq!(state.logical_row_count(), 100);
        assert_eq!(state.level_count(), 1);
    }

    #[test]
    fn test_factorization_state_unflat() {
        let state = FactorizationState::Unflat {
            level_count: 3,
            logical_rows: 1000,
        };
        assert!(!state.is_flat());
        assert!(state.is_unflat());
        assert_eq!(state.logical_row_count(), 1000);
        assert_eq!(state.level_count(), 3);
    }

    #[test]
    fn test_level_selection_all() {
        let sel = LevelSelection::all(10);
        assert_eq!(sel.selected_count(), 10);
        for i in 0..10 {
            assert!(sel.is_selected(i));
        }
        assert!(!sel.is_selected(10));
    }

    #[test]
    fn test_level_selection_filter() {
        let sel = LevelSelection::all(10);
        let filtered = sel.filter(|i| i % 2 == 0);
        assert_eq!(filtered.selected_count(), 5);
        assert!(filtered.is_selected(0));
        assert!(!filtered.is_selected(1));
        assert!(filtered.is_selected(2));
    }

    #[test]
    fn test_level_selection_filter_sparse() {
        let sel = LevelSelection::from_predicate(10, |i| i < 5);
        assert_eq!(sel.selected_count(), 5);
        let filtered = sel.filter(|i| i % 2 == 0);
        assert_eq!(filtered.selected_count(), 3);
        assert!(filtered.is_selected(0));
        assert!(!filtered.is_selected(1));
        assert!(filtered.is_selected(2));
    }

    #[test]
    fn test_level_selection_iter_all() {
        let sel = LevelSelection::all(5);
        let indices: Vec<usize> = sel.iter().collect();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_level_selection_iter_sparse() {
        let sel = LevelSelection::from_predicate(10, |i| i % 3 == 0);
        let indices: Vec<usize> = sel.iter().collect();
        assert_eq!(indices, vec![0, 3, 6, 9]);
    }

    #[test]
    fn test_level_selection_into_iterator() {
        let sel = LevelSelection::from_predicate(6, |i| i % 2 == 1);
        let collected: Vec<usize> = (&sel).into_iter().collect();
        assert_eq!(collected, vec![1, 3, 5]);
        let mut looped = Vec::new();
        for idx in &sel {
            looped.push(idx);
        }
        assert_eq!(looped, collected);
    }

    #[test]
    fn test_level_selection_from_predicate_all_selected() {
        let sel = LevelSelection::from_predicate(5, |_| true);
        assert_eq!(sel.selected_count(), 5);
        match sel {
            LevelSelection::All { count } => assert_eq!(count, 5),
            LevelSelection::Sparse(_) => panic!("Expected All variant"),
        }
    }

    #[test]
    fn test_level_selection_from_predicate_partial() {
        let sel = LevelSelection::from_predicate(10, |i| i < 3);
        assert_eq!(sel.selected_count(), 3);
        match sel {
            LevelSelection::Sparse(_) => {}
            LevelSelection::All { .. } => panic!("Expected Sparse variant"),
        }
    }

    #[test]
    fn test_factorized_selection_all() {
        let sel = FactorizedSelection::all(&[10, 100, 1000]);
        assert_eq!(sel.level_count(), 3);
        assert!(sel.is_selected(0, 5));
        assert!(sel.is_selected(1, 50));
        assert!(sel.is_selected(2, 500));
    }

    #[test]
    fn test_factorized_selection_new() {
        let level_sels = vec![
            LevelSelection::all(5),
            LevelSelection::from_predicate(10, |i| i < 3),
        ];
        let sel = FactorizedSelection::new(level_sels);
        assert_eq!(sel.level_count(), 2);
        assert!(sel.is_selected(0, 4));
        assert!(sel.is_selected(1, 2));
        assert!(!sel.is_selected(1, 5));
    }

    #[test]
    fn test_factorized_selection_filter_level() {
        let sel = FactorizedSelection::all(&[10, 100]);
        let filtered = sel.filter_level(1, |i| i < 50);
        // Level 0 is untouched by a level-1 filter.
        assert!(filtered.is_selected(0, 5));
        assert!(filtered.is_selected(1, 25));
        assert!(!filtered.is_selected(1, 75));
    }

    #[test]
    fn test_factorized_selection_filter_level_twice() {
        // Filtering an already sparse level narrows it further.
        let sel = FactorizedSelection::all(&[10, 100])
            .filter_level(1, |i| i < 50)
            .filter_level(1, |i| i % 2 == 0);
        assert!(sel.is_selected(1, 24));
        assert!(!sel.is_selected(1, 25));
        assert!(!sel.is_selected(1, 60));
        assert_eq!(sel.level(1).map(LevelSelection::selected_count), Some(25));
    }

    #[test]
    fn test_factorized_selection_filter_level_invalid() {
        let sel = FactorizedSelection::all(&[10, 100]);
        let filtered = sel.filter_level(5, |_| true);
        assert_eq!(filtered.level_count(), 2);
    }

    #[test]
    fn test_factorized_selection_is_selected_invalid_level() {
        let sel = FactorizedSelection::all(&[10]);
        // Unknown levels select nothing.
        assert!(!sel.is_selected(5, 0));
    }

    #[test]
    fn test_factorized_selection_level() {
        let sel = FactorizedSelection::all(&[10, 20]);
        assert_eq!(sel.level(0).map(LevelSelection::selected_count), Some(10));
        assert_eq!(sel.level(1).map(LevelSelection::selected_count), Some(20));
        assert!(sel.level(5).is_none());
    }

    #[test]
    fn test_factorized_selection_selected_count_single_level() {
        let mut sel = FactorizedSelection::all(&[10]);
        let multiplicities: Vec<Vec<usize>> = vec![vec![1; 10]];
        assert_eq!(sel.selected_count(&multiplicities), 10);
    }

    #[test]
    fn test_factorized_selection_selected_count_multi_level() {
        let level_sels = vec![
            LevelSelection::all(2),
            LevelSelection::from_predicate(4, |i| i % 2 == 0),
        ];
        let mut sel = FactorizedSelection::new(level_sels);
        // Each of the 2 parents owns 2 children; only even children pass.
        let multiplicities = vec![vec![1, 1], vec![2, 2]];
        assert_eq!(sel.selected_count(&multiplicities), 2);
    }

    #[test]
    fn test_factorized_selection_selected_count_sparse_root() {
        // Only parent 1 is selected, so only its two children can count.
        let mut sel = FactorizedSelection::new(vec![
            LevelSelection::from_predicate(3, |i| i == 1),
            LevelSelection::all(6),
        ]);
        let multiplicities = vec![vec![1, 1, 1], vec![2, 2, 2]];
        assert_eq!(sel.selected_count(&multiplicities), 2);
    }

    #[test]
    fn test_factorized_selection_selected_count_three_levels() {
        // Level 1 drops its middle position, which removes that position's
        // two level-2 children.
        let mut sel = FactorizedSelection::new(vec![
            LevelSelection::all(2),
            LevelSelection::from_predicate(3, |i| i != 1),
            LevelSelection::all(6),
        ]);
        let multiplicities = vec![vec![1, 1], vec![2, 1], vec![2, 2, 2]];
        assert_eq!(sel.selected_count(&multiplicities), 4);
    }

    #[test]
    fn test_factorized_selection_selected_count_cached() {
        let mut sel = FactorizedSelection::all(&[2, 4]);
        assert_eq!(sel.selected_count(&[vec![1, 1], vec![2, 2]]), 4);
        // A different layout is ignored while the cache is warm.
        assert_eq!(sel.selected_count(&[vec![1, 1], vec![1, 1]]), 4);
    }

    #[test]
    fn test_factorized_selection_selected_count_empty() {
        let mut sel = FactorizedSelection::all(&[]);
        let multiplicities: Vec<Vec<usize>> = vec![];
        assert_eq!(sel.selected_count(&multiplicities), 0);
    }

    #[test]
    fn test_factorized_selection_invalidate_cache() {
        let mut sel = FactorizedSelection::all(&[2, 4]);
        assert_eq!(sel.selected_count(&[vec![1, 1], vec![2, 2]]), 4);
        sel.invalidate_cache();
        // After invalidation the new layout is honoured.
        assert_eq!(sel.selected_count(&[vec![1, 1], vec![1, 1]]), 2);
    }

    #[test]
    fn test_chunk_state_flat() {
        let state = ChunkState::flat(100);
        assert!(state.is_flat());
        assert!(!state.is_factorized());
        assert_eq!(state.logical_row_count(), 100);
        assert_eq!(state.level_count(), 1);
    }

    #[test]
    fn test_chunk_state_unflat() {
        let state = ChunkState::unflat(3, 1000);
        assert!(!state.is_flat());
        assert!(state.is_factorized());
        assert_eq!(state.logical_row_count(), 1000);
        assert_eq!(state.level_count(), 3);
    }

    #[test]
    fn test_chunk_state_factorization_state() {
        let state = ChunkState::flat(50);
        assert!(state.factorization_state().is_flat());
    }

    #[test]
    fn test_chunk_state_selection() {
        let mut state = ChunkState::unflat(2, 100);
        assert!(state.selection().is_none());
        state.set_selection(FactorizedSelection::all(&[10, 100]));
        assert_eq!(state.selection().map(FactorizedSelection::level_count), Some(2));
    }

    #[test]
    fn test_chunk_state_selection_mut() {
        let mut state = ChunkState::unflat(2, 100);
        state.set_selection(FactorizedSelection::all(&[10, 100]));
        let sel_mut = state.selection_mut();
        assert!(sel_mut.is_some());
        *sel_mut = None;
        assert!(state.selection().is_none());
    }

    #[test]
    fn test_chunk_state_clear_selection() {
        let mut state = ChunkState::unflat(2, 100);
        state.set_selection(FactorizedSelection::all(&[10, 100]));
        assert!(state.selection().is_some());
        state.clear_selection();
        assert!(state.selection().is_none());
    }

    #[test]
    fn test_chunk_state_set_state() {
        let mut state = ChunkState::flat(100);
        assert!(state.is_flat());
        assert_eq!(state.generation(), 0);
        state.set_state(FactorizationState::Unflat {
            level_count: 2,
            logical_rows: 200,
        });
        assert!(state.is_factorized());
        assert_eq!(state.logical_row_count(), 200);
        // set_state invalidates the cache, which bumps the generation.
        assert_eq!(state.generation(), 1);
    }

    #[test]
    fn test_chunk_state_caching() {
        let mut state = ChunkState::unflat(2, 100);
        let mut computed = false;
        let mults1 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![1, 2, 3, 4, 5]
        });
        assert!(computed);
        assert_eq!(mults1.len(), 5);

        computed = false;
        let mults2 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![99, 99, 99]
        });
        assert!(!computed);
        assert_eq!(mults2.len(), 5);
        // The cached value is shared, not copied.
        assert!(Arc::ptr_eq(&mults1, &mults2));

        state.invalidate_cache();
        let mults3 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![10, 20, 30]
        });
        assert!(computed);
        assert_eq!(mults3.len(), 3);
    }

    #[test]
    fn test_chunk_state_cached_multiplicities() {
        let mut state = ChunkState::unflat(2, 100);
        assert!(state.cached_multiplicities().is_none());
        let _ = state.get_or_compute_multiplicities(|| vec![1, 2, 3]);
        assert_eq!(state.cached_multiplicities().map(|m| m.len()), Some(3));
    }

    #[test]
    fn test_chunk_state_set_cached_multiplicities() {
        let mut state = ChunkState::unflat(2, 100);
        let mults: Arc<[usize]> = vec![5, 10, 15].into();
        state.set_cached_multiplicities(mults);
        assert_eq!(state.cached_multiplicities().map(|m| m.len()), Some(3));
        // A pre-set cache short-circuits computation.
        let got = state.get_or_compute_multiplicities(|| vec![0]);
        assert_eq!(&*got, &[5, 10, 15]);
    }

    #[test]
    fn test_chunk_state_generation() {
        let mut state = ChunkState::flat(100);
        assert_eq!(state.generation(), 0);
        state.invalidate_cache();
        assert_eq!(state.generation(), 1);
        state.set_state(FactorizationState::Unflat {
            level_count: 2,
            logical_rows: 200,
        });
        assert_eq!(state.generation(), 2);
    }

    #[test]
    fn test_chunk_state_default() {
        let state = ChunkState::default();
        assert!(state.is_flat());
        assert_eq!(state.logical_row_count(), 0);
        assert_eq!(state.generation(), 0);
    }
}