1use std::sync::Arc;
22
23use super::selection::SelectionVector;
24
/// Physical layout of a data chunk's rows.
///
/// A chunk is either flat, where every physical row is one logical row, or
/// unflat (factorized), where rows are spread across several levels that
/// together describe `logical_rows` logical rows.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FactorizationState {
    /// Single-level layout.
    Flat {
        /// Number of rows in the chunk.
        row_count: usize,
    },
    /// Multi-level (factorized) layout.
    Unflat {
        /// Number of factorization levels.
        level_count: usize,
        /// Number of logical rows represented by the levels.
        logical_rows: usize,
    },
}

impl FactorizationState {
    /// Returns `true` when the state is [`FactorizationState::Flat`].
    #[must_use]
    pub fn is_flat(&self) -> bool {
        !self.is_unflat()
    }

    /// Returns `true` when the state is [`FactorizationState::Unflat`].
    #[must_use]
    pub fn is_unflat(&self) -> bool {
        matches!(*self, Self::Unflat { .. })
    }

    /// Number of logical rows: `row_count` when flat, `logical_rows` when unflat.
    #[must_use]
    pub fn logical_row_count(&self) -> usize {
        match *self {
            Self::Unflat { logical_rows, .. } => logical_rows,
            Self::Flat { row_count } => row_count,
        }
    }

    /// Number of levels: always `1` when flat, `level_count` when unflat.
    #[must_use]
    pub fn level_count(&self) -> usize {
        if let Self::Unflat { level_count, .. } = *self {
            level_count
        } else {
            1
        }
    }
}
80
81#[derive(Debug, Clone)]
86pub enum LevelSelection {
87 All {
89 count: usize,
91 },
92 Sparse(SelectionVector),
96}
97
98impl LevelSelection {
99 #[must_use]
101 pub fn all(count: usize) -> Self {
102 Self::All { count }
103 }
104
105 #[must_use]
107 pub fn from_predicate<F>(count: usize, predicate: F) -> Self
108 where
109 F: Fn(usize) -> bool,
110 {
111 let selected = SelectionVector::from_predicate(count, predicate);
112 if selected.len() == count {
113 Self::All { count }
114 } else {
115 Self::Sparse(selected)
116 }
117 }
118
119 #[must_use]
121 pub fn selected_count(&self) -> usize {
122 match self {
123 Self::All { count } => *count,
124 Self::Sparse(sel) => sel.len(),
125 }
126 }
127
128 #[must_use]
130 pub fn is_selected(&self, physical_idx: usize) -> bool {
131 match self {
132 Self::All { count } => physical_idx < *count,
133 Self::Sparse(sel) => sel.contains(physical_idx),
134 }
135 }
136
137 #[must_use]
139 pub fn filter<F>(&self, predicate: F) -> Self
140 where
141 F: Fn(usize) -> bool,
142 {
143 match self {
144 Self::All { count } => Self::from_predicate(*count, predicate),
145 Self::Sparse(sel) => {
146 let filtered = sel.filter(predicate);
147 Self::Sparse(filtered)
148 }
149 }
150 }
151
152 #[allow(clippy::iter_without_into_iter)]
154 pub fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
155 match self {
156 Self::All { count } => Box::new(0..*count),
157 Self::Sparse(sel) => Box::new(sel.iter()),
158 }
159 }
160}
161
162#[derive(Debug, Clone)]
167pub struct FactorizedSelection {
168 level_selections: Vec<LevelSelection>,
170 cached_selected_count: Option<usize>,
172}
173
174impl FactorizedSelection {
175 #[must_use]
177 pub fn all(level_counts: &[usize]) -> Self {
178 let level_selections = level_counts
179 .iter()
180 .map(|&count| LevelSelection::all(count))
181 .collect();
182 Self {
183 level_selections,
184 cached_selected_count: None,
185 }
186 }
187
188 #[must_use]
190 pub fn new(level_selections: Vec<LevelSelection>) -> Self {
191 Self {
192 level_selections,
193 cached_selected_count: None,
194 }
195 }
196
197 #[must_use]
199 pub fn level_count(&self) -> usize {
200 self.level_selections.len()
201 }
202
203 #[must_use]
205 pub fn level(&self, level: usize) -> Option<&LevelSelection> {
206 self.level_selections.get(level)
207 }
208
209 #[must_use]
215 pub fn filter_level<F>(&self, level: usize, predicate: F) -> Self
216 where
217 F: Fn(usize) -> bool,
218 {
219 let mut new_selections = self.level_selections.clone();
220
221 if let Some(sel) = new_selections.get_mut(level) {
222 *sel = sel.filter(predicate);
223 }
224
225 Self {
226 level_selections: new_selections,
227 cached_selected_count: None, }
229 }
230
231 #[must_use]
233 pub fn is_selected(&self, level: usize, physical_idx: usize) -> bool {
234 self.level_selections
235 .get(level)
236 .is_some_and(|sel| sel.is_selected(physical_idx))
237 }
238
239 pub fn selected_count(&mut self, multiplicities: &[Vec<usize>]) -> usize {
244 if let Some(count) = self.cached_selected_count {
245 return count;
246 }
247
248 let count = self.compute_selected_count(multiplicities);
249 self.cached_selected_count = Some(count);
250 count
251 }
252
253 fn compute_selected_count(&self, multiplicities: &[Vec<usize>]) -> usize {
255 if self.level_selections.is_empty() {
256 return 0;
257 }
258
259 if self.level_selections.len() == 1 {
261 return self.level_selections[0].selected_count();
262 }
263
264 let mut parent_selected: Vec<bool> = match &self.level_selections[0] {
267 LevelSelection::All { count } => vec![true; *count],
268 LevelSelection::Sparse(sel) => {
269 let max_idx = sel.iter().max().unwrap_or(0);
270 let mut selected = vec![false; max_idx + 1];
271 for idx in sel.iter() {
272 selected[idx] = true;
273 }
274 selected
275 }
276 };
277
278 for (level_sel, level_mults) in self
280 .level_selections
281 .iter()
282 .skip(1)
283 .zip(multiplicities.iter().skip(1))
284 {
285 let mut child_selected = Vec::new();
286 let mut child_idx = 0;
287
288 for (parent_idx, &mult) in level_mults.iter().enumerate() {
289 let parent_is_selected = parent_selected.get(parent_idx).copied().unwrap_or(false);
290
291 for _ in 0..mult {
292 let child_is_selected = parent_is_selected && level_sel.is_selected(child_idx);
293 child_selected.push(child_is_selected);
294 child_idx += 1;
295 }
296 }
297
298 parent_selected = child_selected;
299 }
300
301 parent_selected.iter().filter(|&&s| s).count()
303 }
304
305 pub fn invalidate_cache(&mut self) {
307 self.cached_selected_count = None;
308 }
309}
310
311#[derive(Debug, Clone)]
322pub struct ChunkState {
323 state: FactorizationState,
325 selection: Option<FactorizedSelection>,
328 cached_multiplicities: Option<Arc<[usize]>>,
331 generation: u64,
333}
334
335impl ChunkState {
336 #[must_use]
338 pub fn flat(row_count: usize) -> Self {
339 Self {
340 state: FactorizationState::Flat { row_count },
341 selection: None,
342 cached_multiplicities: None,
343 generation: 0,
344 }
345 }
346
347 #[must_use]
349 pub fn unflat(level_count: usize, logical_rows: usize) -> Self {
350 Self {
351 state: FactorizationState::Unflat {
352 level_count,
353 logical_rows,
354 },
355 selection: None,
356 cached_multiplicities: None,
357 generation: 0,
358 }
359 }
360
361 #[must_use]
363 pub fn factorization_state(&self) -> FactorizationState {
364 self.state
365 }
366
367 #[must_use]
369 pub fn is_flat(&self) -> bool {
370 self.state.is_flat()
371 }
372
373 #[must_use]
375 pub fn is_factorized(&self) -> bool {
376 self.state.is_unflat()
377 }
378
379 #[must_use]
384 pub fn logical_row_count(&self) -> usize {
385 self.state.logical_row_count()
386 }
387
388 #[must_use]
390 pub fn level_count(&self) -> usize {
391 self.state.level_count()
392 }
393
394 #[must_use]
396 pub fn generation(&self) -> u64 {
397 self.generation
398 }
399
400 #[must_use]
402 pub fn selection(&self) -> Option<&FactorizedSelection> {
403 self.selection.as_ref()
404 }
405
406 pub fn selection_mut(&mut self) -> &mut Option<FactorizedSelection> {
408 &mut self.selection
409 }
410
411 pub fn set_selection(&mut self, selection: FactorizedSelection) {
413 self.selection = Some(selection);
414 }
416
417 pub fn clear_selection(&mut self) {
419 self.selection = None;
420 }
421
422 pub fn set_state(&mut self, state: FactorizationState) {
424 self.state = state;
425 self.invalidate_cache();
426 }
427
428 pub fn invalidate_cache(&mut self) {
430 self.cached_multiplicities = None;
431 self.generation += 1;
432 }
433
434 pub fn get_or_compute_multiplicities<F>(&mut self, compute: F) -> Arc<[usize]>
451 where
452 F: FnOnce() -> Vec<usize>,
453 {
454 if let Some(ref cached) = self.cached_multiplicities {
455 return Arc::clone(cached);
456 }
457
458 let mults: Arc<[usize]> = compute().into();
459 self.cached_multiplicities = Some(Arc::clone(&mults));
460 mults
461 }
462
463 #[must_use]
467 pub fn cached_multiplicities(&self) -> Option<&Arc<[usize]>> {
468 self.cached_multiplicities.as_ref()
469 }
470
471 pub fn set_cached_multiplicities(&mut self, mults: Arc<[usize]>) {
475 self.cached_multiplicities = Some(mults);
476 }
477}
478
479impl Default for ChunkState {
480 fn default() -> Self {
481 Self::flat(0)
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_factorization_state_flat() {
        let state = FactorizationState::Flat { row_count: 100 };
        assert!(state.is_flat());
        assert!(!state.is_unflat());
        assert_eq!(state.logical_row_count(), 100);
        assert_eq!(state.level_count(), 1);
    }

    #[test]
    fn test_factorization_state_unflat() {
        let state = FactorizationState::Unflat {
            level_count: 3,
            logical_rows: 1000,
        };
        assert!(!state.is_flat());
        assert!(state.is_unflat());
        assert_eq!(state.logical_row_count(), 1000);
        assert_eq!(state.level_count(), 3);
    }

    #[test]
    fn test_level_selection_all() {
        let sel = LevelSelection::all(10);
        assert_eq!(sel.selected_count(), 10);
        for i in 0..10 {
            assert!(sel.is_selected(i));
        }
        assert!(!sel.is_selected(10));
    }

    #[test]
    fn test_level_selection_filter() {
        let sel = LevelSelection::all(10);
        let filtered = sel.filter(|i| i % 2 == 0);
        assert_eq!(filtered.selected_count(), 5);
        assert!(filtered.is_selected(0));
        assert!(!filtered.is_selected(1));
        assert!(filtered.is_selected(2));
    }

    #[test]
    fn test_level_selection_filter_sparse() {
        // Start sparse: indices 0..5 out of 10.
        let sel = LevelSelection::from_predicate(10, |i| i < 5);
        assert_eq!(sel.selected_count(), 5);

        // Keep the even ones among them: 0, 2, 4.
        let filtered = sel.filter(|i| i % 2 == 0);
        assert_eq!(filtered.selected_count(), 3);
        assert!(filtered.is_selected(0));
        assert!(!filtered.is_selected(1));
        assert!(filtered.is_selected(2));
        assert!(filtered.is_selected(4));
        assert!(!filtered.is_selected(6));
    }

    #[test]
    fn test_level_selection_iter_all() {
        let sel = LevelSelection::all(5);
        let indices: Vec<usize> = sel.iter().collect();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_level_selection_iter_sparse() {
        let sel = LevelSelection::from_predicate(10, |i| i % 3 == 0);
        let indices: Vec<usize> = sel.iter().collect();
        assert_eq!(indices, vec![0, 3, 6, 9]);
    }

    #[test]
    fn test_level_selection_from_predicate_all_selected() {
        // A predicate that keeps everything collapses to the dense variant.
        let sel = LevelSelection::from_predicate(5, |_| true);
        assert_eq!(sel.selected_count(), 5);
        match sel {
            LevelSelection::All { count } => assert_eq!(count, 5),
            LevelSelection::Sparse(_) => panic!("Expected All variant"),
        }
    }

    #[test]
    fn test_level_selection_from_predicate_partial() {
        // A partial selection must be stored sparsely.
        let sel = LevelSelection::from_predicate(10, |i| i < 3);
        assert_eq!(sel.selected_count(), 3);
        match sel {
            LevelSelection::Sparse(_) => {}
            LevelSelection::All { .. } => panic!("Expected Sparse variant"),
        }
    }

    #[test]
    fn test_factorized_selection_all() {
        let sel = FactorizedSelection::all(&[10, 100, 1000]);
        assert_eq!(sel.level_count(), 3);
        assert!(sel.is_selected(0, 5));
        assert!(sel.is_selected(1, 50));
        assert!(sel.is_selected(2, 500));
    }

    #[test]
    fn test_factorized_selection_new() {
        let level_sels = vec![
            LevelSelection::all(5),
            LevelSelection::from_predicate(10, |i| i < 3),
        ];
        let sel = FactorizedSelection::new(level_sels);

        assert_eq!(sel.level_count(), 2);
        assert!(sel.is_selected(0, 4));
        assert!(sel.is_selected(1, 2));
        assert!(!sel.is_selected(1, 5));
    }

    #[test]
    fn test_factorized_selection_filter_level() {
        let sel = FactorizedSelection::all(&[10, 100]);
        let filtered = sel.filter_level(1, |i| i < 50);

        // Level 0 is untouched; level 1 keeps only indices below 50.
        assert!(filtered.is_selected(0, 5));
        assert!(filtered.is_selected(1, 25));
        assert!(!filtered.is_selected(1, 75));
        // The source selection is not modified.
        assert!(sel.is_selected(1, 75));
    }

    #[test]
    fn test_factorized_selection_filter_level_invalid() {
        let sel = FactorizedSelection::all(&[10, 100]);

        // Filtering a non-existent level is a no-op, even with a reject-all predicate.
        let filtered = sel.filter_level(5, |_| false);
        assert_eq!(filtered.level_count(), 2);
        assert!(filtered.is_selected(0, 9));
        assert!(filtered.is_selected(1, 99));
    }

    #[test]
    fn test_factorized_selection_is_selected_invalid_level() {
        let sel = FactorizedSelection::all(&[10]);
        // Out-of-range levels report nothing as selected.
        assert!(!sel.is_selected(5, 0));
    }

    #[test]
    fn test_factorized_selection_level() {
        let sel = FactorizedSelection::all(&[10, 20]);

        let level0 = sel.level(0);
        assert!(level0.is_some());
        assert_eq!(level0.unwrap().selected_count(), 10);

        let level1 = sel.level(1);
        assert!(level1.is_some());
        assert_eq!(level1.unwrap().selected_count(), 20);

        assert!(sel.level(5).is_none());
    }

    #[test]
    fn test_factorized_selection_selected_count_single_level() {
        let mut sel = FactorizedSelection::all(&[10]);
        let multiplicities: Vec<Vec<usize>> = vec![vec![1; 10]];

        let count = sel.selected_count(&multiplicities);
        assert_eq!(count, 10);
    }

    #[test]
    fn test_factorized_selection_selected_count_multi_level() {
        let level_sels = vec![
            // Level 0: both root entries selected.
            LevelSelection::all(2),
            // Level 1: only even children (0 and 2) selected.
            LevelSelection::from_predicate(4, |i| i % 2 == 0),
        ];
        let mut sel = FactorizedSelection::new(level_sels);

        let multiplicities = vec![
            // Root level (not read).
            vec![1, 1],
            // Each root entry has two children.
            vec![2, 2],
        ];

        let count = sel.selected_count(&multiplicities);
        assert_eq!(count, 2);
    }

    #[test]
    fn test_factorized_selection_selected_count_sparse_root() {
        // Only root entry 1 is selected; each root entry has two children.
        let mut sel = FactorizedSelection::new(vec![
            LevelSelection::from_predicate(3, |i| i == 1),
            LevelSelection::all(6),
        ]);
        let multiplicities = vec![vec![1, 1, 1], vec![2, 2, 2]];

        // Only the children of root entry 1 (children 2 and 3) survive.
        assert_eq!(sel.selected_count(&multiplicities), 2);
    }

    #[test]
    fn test_factorized_selection_selected_count_empty_sparse_root() {
        // No root entry is selected, so no child can be selected.
        let mut sel = FactorizedSelection::new(vec![
            LevelSelection::from_predicate(3, |_| false),
            LevelSelection::all(6),
        ]);
        let multiplicities = vec![vec![1, 1, 1], vec![2, 2, 2]];

        assert_eq!(sel.selected_count(&multiplicities), 0);
    }

    #[test]
    fn test_factorized_selection_selected_count_cached() {
        // Multiple levels are needed so the count actually depends on multiplicities.
        let mut sel = FactorizedSelection::all(&[2, 4]);
        // Two children per root entry: four leaves, all selected.
        let two_each: Vec<Vec<usize>> = vec![vec![1, 1], vec![2, 2]];
        // One child per root entry: two leaves.
        let one_each: Vec<Vec<usize>> = vec![vec![1, 1], vec![1, 1]];

        assert_eq!(sel.selected_count(&two_each), 4);
        // The cached value is returned even though the multiplicities changed.
        assert_eq!(sel.selected_count(&one_each), 4);
    }

    #[test]
    fn test_factorized_selection_selected_count_empty() {
        let mut sel = FactorizedSelection::all(&[]);
        let multiplicities: Vec<Vec<usize>> = vec![];

        assert_eq!(sel.selected_count(&multiplicities), 0);
    }

    #[test]
    fn test_factorized_selection_invalidate_cache() {
        let mut sel = FactorizedSelection::all(&[2, 4]);
        let two_each: Vec<Vec<usize>> = vec![vec![1, 1], vec![2, 2]];
        let one_each: Vec<Vec<usize>> = vec![vec![1, 1], vec![1, 1]];

        assert_eq!(sel.selected_count(&two_each), 4);

        sel.invalidate_cache();

        // After invalidation the count is recomputed from the new multiplicities.
        assert_eq!(sel.selected_count(&one_each), 2);
    }

    #[test]
    fn test_chunk_state_flat() {
        let state = ChunkState::flat(100);
        assert!(state.is_flat());
        assert!(!state.is_factorized());
        assert_eq!(state.logical_row_count(), 100);
        assert_eq!(state.level_count(), 1);
    }

    #[test]
    fn test_chunk_state_unflat() {
        let state = ChunkState::unflat(3, 1000);
        assert!(!state.is_flat());
        assert!(state.is_factorized());
        assert_eq!(state.logical_row_count(), 1000);
        assert_eq!(state.level_count(), 3);
    }

    #[test]
    fn test_chunk_state_factorization_state() {
        let state = ChunkState::flat(50);
        let fs = state.factorization_state();
        assert!(fs.is_flat());
    }

    #[test]
    fn test_chunk_state_selection() {
        let mut state = ChunkState::unflat(2, 100);

        // A new state has no selection.
        assert!(state.selection().is_none());

        let sel = FactorizedSelection::all(&[10, 100]);
        state.set_selection(sel);

        assert!(state.selection().is_some());
        assert_eq!(state.selection().unwrap().level_count(), 2);
    }

    #[test]
    fn test_chunk_state_selection_mut() {
        let mut state = ChunkState::unflat(2, 100);

        let sel = FactorizedSelection::all(&[10, 100]);
        state.set_selection(sel);

        let sel_mut = state.selection_mut();
        assert!(sel_mut.is_some());

        // Clearing through the mutable slot is visible through the getter.
        *sel_mut = None;
        assert!(state.selection().is_none());
    }

    #[test]
    fn test_chunk_state_clear_selection() {
        let mut state = ChunkState::unflat(2, 100);

        let sel = FactorizedSelection::all(&[10, 100]);
        state.set_selection(sel);
        assert!(state.selection().is_some());

        state.clear_selection();
        assert!(state.selection().is_none());
    }

    #[test]
    fn test_chunk_state_set_state() {
        let mut state = ChunkState::flat(100);
        assert!(state.is_flat());
        assert_eq!(state.generation(), 0);

        state.set_state(FactorizationState::Unflat {
            level_count: 2,
            logical_rows: 200,
        });

        assert!(state.is_factorized());
        assert_eq!(state.logical_row_count(), 200);
        // Changing the state invalidates the cache and bumps the generation.
        assert_eq!(state.generation(), 1);
    }

    #[test]
    fn test_chunk_state_caching() {
        let mut state = ChunkState::unflat(2, 100);

        // First call computes.
        let mut computed = false;
        let mults1 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![1, 2, 3, 4, 5]
        });
        assert!(computed);
        assert_eq!(mults1.len(), 5);

        // Second call hits the cache and does not run the closure.
        computed = false;
        let mults2 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![99, 99, 99]
        });
        assert!(!computed);
        assert_eq!(mults2.len(), 5);

        // After invalidation the closure runs again.
        state.invalidate_cache();
        computed = false;
        let mults3 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![10, 20, 30]
        });
        assert!(computed);
        assert_eq!(mults3.len(), 3);
    }

    #[test]
    fn test_chunk_state_cached_multiplicities() {
        let mut state = ChunkState::unflat(2, 100);

        // Nothing is cached initially.
        assert!(state.cached_multiplicities().is_none());

        let _ = state.get_or_compute_multiplicities(|| vec![1, 2, 3]);

        // The computed value is now cached.
        assert!(state.cached_multiplicities().is_some());
        assert_eq!(state.cached_multiplicities().unwrap().len(), 3);
    }

    #[test]
    fn test_chunk_state_set_cached_multiplicities() {
        let mut state = ChunkState::unflat(2, 100);

        let mults: Arc<[usize]> = vec![5, 10, 15].into();
        state.set_cached_multiplicities(mults);

        assert!(state.cached_multiplicities().is_some());
        assert_eq!(state.cached_multiplicities().unwrap().len(), 3);
    }

    #[test]
    fn test_chunk_state_generation() {
        let mut state = ChunkState::flat(100);
        assert_eq!(state.generation(), 0);

        state.invalidate_cache();
        assert_eq!(state.generation(), 1);

        state.set_state(FactorizationState::Unflat {
            level_count: 2,
            logical_rows: 200,
        });
        assert_eq!(state.generation(), 2);
    }

    #[test]
    fn test_chunk_state_default() {
        let state = ChunkState::default();
        assert!(state.is_flat());
        assert_eq!(state.logical_row_count(), 0);
        assert_eq!(state.generation(), 0);
    }
}