1use std::sync::Arc;
24
25use super::selection::SelectionVector;
26
/// Factorization state of a chunk: either flat or factorized ("unflat").
///
/// The `Flat` variant stores a row count directly; the `Unflat` variant
/// stores how many levels the chunk has together with its logical row count.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum FactorizationState {
    /// A flat chunk.
    Flat {
        /// Number of rows in the chunk.
        row_count: usize,
    },
    /// A factorized chunk.
    Unflat {
        /// Number of levels in the factorized representation.
        level_count: usize,
        /// Logical row count (presumably the row count after full
        /// expansion of all levels; confirm against the producer).
        logical_rows: usize,
    },
}
51
52impl FactorizationState {
53 #[must_use]
55 pub fn is_flat(&self) -> bool {
56 matches!(self, Self::Flat { .. })
57 }
58
59 #[must_use]
61 pub fn is_unflat(&self) -> bool {
62 matches!(self, Self::Unflat { .. })
63 }
64
65 #[must_use]
67 pub fn logical_row_count(&self) -> usize {
68 match self {
69 Self::Flat { row_count } => *row_count,
70 Self::Unflat { logical_rows, .. } => *logical_rows,
71 }
72 }
73
74 #[must_use]
76 pub fn level_count(&self) -> usize {
77 match self {
78 Self::Flat { .. } => 1,
79 Self::Unflat { level_count, .. } => *level_count,
80 }
81 }
82}
83
84#[derive(Debug, Clone)]
89#[non_exhaustive]
90pub enum LevelSelection {
91 All {
93 count: usize,
95 },
96 Sparse(SelectionVector),
100}
101
102impl LevelSelection {
103 #[must_use]
105 pub fn all(count: usize) -> Self {
106 Self::All { count }
107 }
108
109 #[must_use]
111 pub fn from_predicate<F>(count: usize, predicate: F) -> Self
112 where
113 F: Fn(usize) -> bool,
114 {
115 let selected = SelectionVector::from_predicate(count, predicate);
116 if selected.len() == count {
117 Self::All { count }
118 } else {
119 Self::Sparse(selected)
120 }
121 }
122
123 #[must_use]
125 pub fn selected_count(&self) -> usize {
126 match self {
127 Self::All { count } => *count,
128 Self::Sparse(sel) => sel.len(),
129 }
130 }
131
132 #[must_use]
134 pub fn is_selected(&self, physical_idx: usize) -> bool {
135 match self {
136 Self::All { count } => physical_idx < *count,
137 Self::Sparse(sel) => sel.contains(physical_idx),
138 }
139 }
140
141 #[must_use]
143 pub fn filter<F>(&self, predicate: F) -> Self
144 where
145 F: Fn(usize) -> bool,
146 {
147 match self {
148 Self::All { count } => Self::from_predicate(*count, predicate),
149 Self::Sparse(sel) => {
150 let filtered = sel.filter(predicate);
151 Self::Sparse(filtered)
152 }
153 }
154 }
155
156 pub fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
158 match self {
159 Self::All { count } => Box::new(0..*count),
160 Self::Sparse(sel) => Box::new(sel.iter()),
161 }
162 }
163}
164
165impl<'a> IntoIterator for &'a LevelSelection {
166 type Item = usize;
167 type IntoIter = Box<dyn Iterator<Item = usize> + 'a>;
168
169 fn into_iter(self) -> Self::IntoIter {
170 self.iter()
171 }
172}
173
174#[derive(Debug, Clone)]
179pub struct FactorizedSelection {
180 level_selections: Vec<LevelSelection>,
182 cached_selected_count: Option<usize>,
184}
185
186impl FactorizedSelection {
187 #[must_use]
189 pub fn all(level_counts: &[usize]) -> Self {
190 let level_selections = level_counts
191 .iter()
192 .map(|&count| LevelSelection::all(count))
193 .collect();
194 Self {
195 level_selections,
196 cached_selected_count: None,
197 }
198 }
199
200 #[must_use]
202 pub fn new(level_selections: Vec<LevelSelection>) -> Self {
203 Self {
204 level_selections,
205 cached_selected_count: None,
206 }
207 }
208
209 #[must_use]
211 pub fn level_count(&self) -> usize {
212 self.level_selections.len()
213 }
214
215 #[must_use]
217 pub fn level(&self, level: usize) -> Option<&LevelSelection> {
218 self.level_selections.get(level)
219 }
220
221 #[must_use]
227 pub fn filter_level<F>(&self, level: usize, predicate: F) -> Self
228 where
229 F: Fn(usize) -> bool,
230 {
231 let mut new_selections = self.level_selections.clone();
232
233 if let Some(sel) = new_selections.get_mut(level) {
234 *sel = sel.filter(predicate);
235 }
236
237 Self {
238 level_selections: new_selections,
239 cached_selected_count: None, }
241 }
242
243 #[must_use]
245 pub fn is_selected(&self, level: usize, physical_idx: usize) -> bool {
246 self.level_selections
247 .get(level)
248 .is_some_and(|sel| sel.is_selected(physical_idx))
249 }
250
251 pub fn selected_count(&mut self, multiplicities: &[Vec<usize>]) -> usize {
256 if let Some(count) = self.cached_selected_count {
257 return count;
258 }
259
260 let count = self.compute_selected_count(multiplicities);
261 self.cached_selected_count = Some(count);
262 count
263 }
264
265 fn compute_selected_count(&self, multiplicities: &[Vec<usize>]) -> usize {
267 if self.level_selections.is_empty() {
268 return 0;
269 }
270
271 if self.level_selections.len() == 1 {
273 return self.level_selections[0].selected_count();
274 }
275
276 let mut parent_selected: Vec<bool> = match &self.level_selections[0] {
279 LevelSelection::All { count } => vec![true; *count],
280 LevelSelection::Sparse(sel) => {
281 let max_idx = sel.iter().max().unwrap_or(0);
282 let mut selected = vec![false; max_idx + 1];
283 for idx in sel.iter() {
284 selected[idx] = true;
285 }
286 selected
287 }
288 };
289
290 for (level_sel, level_mults) in self
292 .level_selections
293 .iter()
294 .skip(1)
295 .zip(multiplicities.iter().skip(1))
296 {
297 let mut child_selected = Vec::new();
298 let mut child_idx = 0;
299
300 for (parent_idx, &mult) in level_mults.iter().enumerate() {
301 let parent_is_selected = parent_selected.get(parent_idx).copied().unwrap_or(false);
302
303 for _ in 0..mult {
304 let child_is_selected = parent_is_selected && level_sel.is_selected(child_idx);
305 child_selected.push(child_is_selected);
306 child_idx += 1;
307 }
308 }
309
310 parent_selected = child_selected;
311 }
312
313 parent_selected.iter().filter(|&&s| s).count()
315 }
316
317 pub fn invalidate_cache(&mut self) {
319 self.cached_selected_count = None;
320 }
321}
322
323#[derive(Debug, Clone)]
334pub struct ChunkState {
335 state: FactorizationState,
337 selection: Option<FactorizedSelection>,
340 cached_multiplicities: Option<Arc<[usize]>>,
343 generation: u64,
345}
346
347impl ChunkState {
348 #[must_use]
350 pub fn flat(row_count: usize) -> Self {
351 Self {
352 state: FactorizationState::Flat { row_count },
353 selection: None,
354 cached_multiplicities: None,
355 generation: 0,
356 }
357 }
358
359 #[must_use]
361 pub fn unflat(level_count: usize, logical_rows: usize) -> Self {
362 Self {
363 state: FactorizationState::Unflat {
364 level_count,
365 logical_rows,
366 },
367 selection: None,
368 cached_multiplicities: None,
369 generation: 0,
370 }
371 }
372
373 #[must_use]
375 pub fn factorization_state(&self) -> FactorizationState {
376 self.state
377 }
378
379 #[must_use]
381 pub fn is_flat(&self) -> bool {
382 self.state.is_flat()
383 }
384
385 #[must_use]
387 pub fn is_factorized(&self) -> bool {
388 self.state.is_unflat()
389 }
390
391 #[must_use]
396 pub fn logical_row_count(&self) -> usize {
397 self.state.logical_row_count()
398 }
399
400 #[must_use]
402 pub fn level_count(&self) -> usize {
403 self.state.level_count()
404 }
405
406 #[must_use]
408 pub fn generation(&self) -> u64 {
409 self.generation
410 }
411
412 #[must_use]
414 pub fn selection(&self) -> Option<&FactorizedSelection> {
415 self.selection.as_ref()
416 }
417
418 pub fn selection_mut(&mut self) -> &mut Option<FactorizedSelection> {
420 &mut self.selection
421 }
422
423 pub fn set_selection(&mut self, selection: FactorizedSelection) {
425 self.selection = Some(selection);
426 }
428
429 pub fn clear_selection(&mut self) {
431 self.selection = None;
432 }
433
434 pub fn set_state(&mut self, state: FactorizationState) {
436 self.state = state;
437 self.invalidate_cache();
438 }
439
440 pub fn invalidate_cache(&mut self) {
442 self.cached_multiplicities = None;
443 self.generation += 1;
444 }
445
446 pub fn get_or_compute_multiplicities<F>(&mut self, compute: F) -> Arc<[usize]>
465 where
466 F: FnOnce() -> Vec<usize>,
467 {
468 if let Some(ref cached) = self.cached_multiplicities {
469 return Arc::clone(cached);
470 }
471
472 let mults: Arc<[usize]> = compute().into();
473 self.cached_multiplicities = Some(Arc::clone(&mults));
474 mults
475 }
476
477 #[must_use]
481 pub fn cached_multiplicities(&self) -> Option<&Arc<[usize]>> {
482 self.cached_multiplicities.as_ref()
483 }
484
485 pub fn set_cached_multiplicities(&mut self, mults: Arc<[usize]>) {
489 self.cached_multiplicities = Some(mults);
490 }
491}
492
493impl Default for ChunkState {
494 fn default() -> Self {
495 Self::flat(0)
496 }
497}
498
#[cfg(test)]
mod tests {
    //! Unit tests for the factorization state, level selections, factorized
    //! selections, and chunk state bookkeeping.

    use super::*;

    #[test]
    fn test_factorization_state_flat() {
        let state = FactorizationState::Flat { row_count: 100 };
        assert!(state.is_flat());
        assert!(!state.is_unflat());
        assert_eq!(state.logical_row_count(), 100);
        assert_eq!(state.level_count(), 1);
    }

    #[test]
    fn test_factorization_state_unflat() {
        let state = FactorizationState::Unflat {
            level_count: 3,
            logical_rows: 1000,
        };
        assert!(!state.is_flat());
        assert!(state.is_unflat());
        assert_eq!(state.logical_row_count(), 1000);
        assert_eq!(state.level_count(), 3);
    }

    #[test]
    fn test_level_selection_all() {
        let sel = LevelSelection::all(10);
        assert_eq!(sel.selected_count(), 10);
        for i in 0..10 {
            assert!(sel.is_selected(i));
        }
        assert!(!sel.is_selected(10));
    }

    #[test]
    fn test_level_selection_filter() {
        let sel = LevelSelection::all(10);
        let filtered = sel.filter(|i| i % 2 == 0);
        assert_eq!(filtered.selected_count(), 5);
        assert!(filtered.is_selected(0));
        assert!(!filtered.is_selected(1));
        assert!(filtered.is_selected(2));
    }

    #[test]
    fn test_level_selection_filter_sparse() {
        // Start from a sparse selection holding 0..5.
        let sel = LevelSelection::from_predicate(10, |i| i < 5);
        assert_eq!(sel.selected_count(), 5);

        // Keeping only even indices leaves 0, 2, 4.
        let filtered = sel.filter(|i| i % 2 == 0);
        assert_eq!(filtered.selected_count(), 3);
        assert!(filtered.is_selected(0));
        assert!(!filtered.is_selected(1));
        assert!(filtered.is_selected(2));
    }

    #[test]
    fn test_level_selection_iter_all() {
        let sel = LevelSelection::all(5);
        let indices: Vec<usize> = sel.iter().collect();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_level_selection_iter_sparse() {
        let sel = LevelSelection::from_predicate(10, |i| i % 3 == 0);
        let indices: Vec<usize> = sel.iter().collect();
        assert_eq!(indices, vec![0, 3, 6, 9]);
    }

    #[test]
    fn test_level_selection_into_iterator() {
        // `&LevelSelection` must be usable directly in a `for` loop.
        let sel = LevelSelection::all(4);
        let mut seen = Vec::new();
        for idx in &sel {
            seen.push(idx);
        }
        assert_eq!(seen, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_level_selection_from_predicate_all_selected() {
        // A predicate that keeps everything collapses to the `All` variant.
        let sel = LevelSelection::from_predicate(5, |_| true);
        assert_eq!(sel.selected_count(), 5);
        match sel {
            LevelSelection::All { count } => assert_eq!(count, 5),
            LevelSelection::Sparse(_) => panic!("Expected All variant"),
        }
    }

    #[test]
    fn test_level_selection_from_predicate_partial() {
        let sel = LevelSelection::from_predicate(10, |i| i < 3);
        assert_eq!(sel.selected_count(), 3);
        match sel {
            LevelSelection::Sparse(_) => {}
            LevelSelection::All { .. } => panic!("Expected Sparse variant"),
        }
    }

    #[test]
    fn test_factorized_selection_all() {
        let sel = FactorizedSelection::all(&[10, 100, 1000]);
        assert_eq!(sel.level_count(), 3);
        assert!(sel.is_selected(0, 5));
        assert!(sel.is_selected(1, 50));
        assert!(sel.is_selected(2, 500));
    }

    #[test]
    fn test_factorized_selection_new() {
        let level_sels = vec![
            LevelSelection::all(5),
            LevelSelection::from_predicate(10, |i| i < 3),
        ];
        let sel = FactorizedSelection::new(level_sels);

        assert_eq!(sel.level_count(), 2);
        assert!(sel.is_selected(0, 4));
        assert!(sel.is_selected(1, 2));
        assert!(!sel.is_selected(1, 5));
    }

    #[test]
    fn test_factorized_selection_filter_level() {
        let sel = FactorizedSelection::all(&[10, 100]);
        let filtered = sel.filter_level(1, |i| i < 50);

        // Level 0 is untouched; level 1 keeps only indices below 50.
        assert!(filtered.is_selected(0, 5));
        assert!(filtered.is_selected(1, 25));
        assert!(!filtered.is_selected(1, 75));
    }

    #[test]
    fn test_factorized_selection_filter_level_invalid() {
        let sel = FactorizedSelection::all(&[10, 100]);

        // Filtering a level that does not exist returns an unchanged copy.
        let filtered = sel.filter_level(5, |_| true);
        assert_eq!(filtered.level_count(), 2);
    }

    #[test]
    fn test_factorized_selection_is_selected_invalid_level() {
        let sel = FactorizedSelection::all(&[10]);
        // An out-of-range level is never selected.
        assert!(!sel.is_selected(5, 0));
    }

    #[test]
    fn test_factorized_selection_level() {
        let sel = FactorizedSelection::all(&[10, 20]);

        let level0 = sel.level(0);
        assert!(level0.is_some());
        assert_eq!(level0.unwrap().selected_count(), 10);

        let level1 = sel.level(1);
        assert!(level1.is_some());
        assert_eq!(level1.unwrap().selected_count(), 20);

        assert!(sel.level(5).is_none());
    }

    #[test]
    fn test_factorized_selection_selected_count_single_level() {
        let mut sel = FactorizedSelection::all(&[10]);
        let multiplicities: Vec<Vec<usize>> = vec![vec![1; 10]];

        let count = sel.selected_count(&multiplicities);
        assert_eq!(count, 10);
    }

    #[test]
    fn test_factorized_selection_selected_count_multi_level() {
        let level_sels = vec![
            // Both parents are selected.
            LevelSelection::all(2),
            // Children 0 and 2 are selected.
            LevelSelection::from_predicate(4, |i| i % 2 == 0),
        ];
        let mut sel = FactorizedSelection::new(level_sels);

        let multiplicities = vec![
            vec![1, 1],
            // Each parent has two children.
            vec![2, 2],
        ];

        // One selected child per parent.
        let count = sel.selected_count(&multiplicities);
        assert_eq!(count, 2);
    }

    #[test]
    fn test_factorized_selection_selected_count_three_levels() {
        let level_sels = vec![
            // Parents 0 and 2 are selected.
            LevelSelection::from_predicate(3, |i| i != 1),
            LevelSelection::all(5),
            // Leaves 0, 2 and 4 are selected.
            LevelSelection::from_predicate(6, |i| i % 2 == 0),
        ];
        let mut sel = FactorizedSelection::new(level_sels);

        let multiplicities = vec![vec![1, 1, 1], vec![2, 1, 2], vec![1, 2, 1, 1, 1]];

        // Level 1 flags: [T, T, F, T, T]; leaves 0, 2 and 4 survive.
        assert_eq!(sel.selected_count(&multiplicities), 3);
    }

    #[test]
    fn test_factorized_selection_selected_count_cached() {
        // Two levels are needed: a single-level selection ignores the
        // multiplicities, so it could not tell a cache hit from a recompute.
        let mut sel = FactorizedSelection::all(&[2, 4]);
        let first_mults: Vec<Vec<usize>> = vec![vec![1, 1], vec![2, 2]];
        let other_mults: Vec<Vec<usize>> = vec![vec![1, 1], vec![1, 1]];

        let count1 = sel.selected_count(&first_mults);
        assert_eq!(count1, 4);

        // A recompute with `other_mults` would give 2; the cached 4 is returned.
        let count2 = sel.selected_count(&other_mults);
        assert_eq!(count2, 4);
    }

    #[test]
    fn test_factorized_selection_selected_count_empty() {
        let mut sel = FactorizedSelection::all(&[]);
        let multiplicities: Vec<Vec<usize>> = vec![];

        assert_eq!(sel.selected_count(&multiplicities), 0);
    }

    #[test]
    fn test_factorized_selection_invalidate_cache() {
        let mut sel = FactorizedSelection::all(&[2, 4]);
        let first_mults: Vec<Vec<usize>> = vec![vec![1, 1], vec![2, 2]];
        let other_mults: Vec<Vec<usize>> = vec![vec![1, 1], vec![1, 1]];

        // Populate the cache.
        assert_eq!(sel.selected_count(&first_mults), 4);

        sel.invalidate_cache();

        // After invalidation the count is recomputed from the new input.
        assert_eq!(sel.selected_count(&other_mults), 2);
    }

    #[test]
    fn test_chunk_state_flat() {
        let state = ChunkState::flat(100);
        assert!(state.is_flat());
        assert!(!state.is_factorized());
        assert_eq!(state.logical_row_count(), 100);
        assert_eq!(state.level_count(), 1);
    }

    #[test]
    fn test_chunk_state_unflat() {
        let state = ChunkState::unflat(3, 1000);
        assert!(!state.is_flat());
        assert!(state.is_factorized());
        assert_eq!(state.logical_row_count(), 1000);
        assert_eq!(state.level_count(), 3);
    }

    #[test]
    fn test_chunk_state_factorization_state() {
        let state = ChunkState::flat(50);
        let fs = state.factorization_state();
        assert!(fs.is_flat());
        assert_eq!(fs.logical_row_count(), 50);
    }

    #[test]
    fn test_chunk_state_selection() {
        let mut state = ChunkState::unflat(2, 100);

        // No selection initially.
        assert!(state.selection().is_none());

        let sel = FactorizedSelection::all(&[10, 100]);
        state.set_selection(sel);

        assert!(state.selection().is_some());
        assert_eq!(state.selection().unwrap().level_count(), 2);
        // Setting a selection must not bump the generation.
        assert_eq!(state.generation(), 0);
    }

    #[test]
    fn test_chunk_state_selection_mut() {
        let mut state = ChunkState::unflat(2, 100);

        let sel = FactorizedSelection::all(&[10, 100]);
        state.set_selection(sel);

        let sel_mut = state.selection_mut();
        assert!(sel_mut.is_some());

        // The slot can be cleared through the mutable reference.
        *sel_mut = None;
        assert!(state.selection().is_none());
    }

    #[test]
    fn test_chunk_state_clear_selection() {
        let mut state = ChunkState::unflat(2, 100);

        let sel = FactorizedSelection::all(&[10, 100]);
        state.set_selection(sel);
        assert!(state.selection().is_some());

        state.clear_selection();
        assert!(state.selection().is_none());
    }

    #[test]
    fn test_chunk_state_set_state() {
        let mut state = ChunkState::flat(100);
        assert!(state.is_flat());
        assert_eq!(state.generation(), 0);

        state.set_state(FactorizationState::Unflat {
            level_count: 2,
            logical_rows: 200,
        });

        assert!(state.is_factorized());
        assert_eq!(state.logical_row_count(), 200);
        // Changing the state invalidates the cache and bumps the generation.
        assert_eq!(state.generation(), 1);
    }

    #[test]
    fn test_chunk_state_caching() {
        let mut state = ChunkState::unflat(2, 100);

        let mut computed = false;

        // First access runs the closure.
        let mults1 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![1, 2, 3, 4, 5]
        });
        assert!(computed);
        assert_eq!(mults1.len(), 5);

        // Second access is served from the cache; the closure is not run.
        computed = false;
        let mults2 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![99, 99, 99]
        });
        assert!(!computed);
        assert_eq!(mults2.len(), 5);
        assert!(Arc::ptr_eq(&mults1, &mults2));

        // After invalidation the closure runs again.
        state.invalidate_cache();
        let mults3 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![10, 20, 30]
        });
        assert!(computed);
        assert_eq!(mults3.len(), 3);
    }

    #[test]
    fn test_chunk_state_cached_multiplicities() {
        let mut state = ChunkState::unflat(2, 100);

        // Nothing cached initially.
        assert!(state.cached_multiplicities().is_none());

        let _ = state.get_or_compute_multiplicities(|| vec![1, 2, 3]);

        assert!(state.cached_multiplicities().is_some());
        assert_eq!(state.cached_multiplicities().unwrap().len(), 3);
    }

    #[test]
    fn test_chunk_state_set_cached_multiplicities() {
        let mut state = ChunkState::unflat(2, 100);

        let mults: Arc<[usize]> = vec![5, 10, 15].into();
        state.set_cached_multiplicities(mults);

        assert!(state.cached_multiplicities().is_some());
        assert_eq!(state.cached_multiplicities().unwrap().len(), 3);
        // Storing multiplicities directly must not bump the generation.
        assert_eq!(state.generation(), 0);
    }

    #[test]
    fn test_chunk_state_generation() {
        let mut state = ChunkState::flat(100);
        assert_eq!(state.generation(), 0);

        state.invalidate_cache();
        assert_eq!(state.generation(), 1);

        state.set_state(FactorizationState::Unflat {
            level_count: 2,
            logical_rows: 200,
        });
        assert_eq!(state.generation(), 2);
    }

    #[test]
    fn test_chunk_state_default() {
        let state = ChunkState::default();
        assert!(state.is_flat());
        assert_eq!(state.logical_row_count(), 0);
        assert_eq!(state.generation(), 0);
    }
}