1use std::sync::Arc;
22
23use super::selection::SelectionVector;
24
/// Factorization state of a chunk: either flat or unflat.
///
/// The type is `Copy`, so it can be passed and returned by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FactorizationState {
    /// Flat state, described only by a row count.
    Flat {
        /// Number of rows; reported by [`FactorizationState::logical_row_count`].
        row_count: usize,
    },
    /// Unflat state, described by a level count and a logical row count.
    Unflat {
        /// Number of levels; reported by [`FactorizationState::level_count`].
        level_count: usize,
        /// Number of logical rows; reported by
        /// [`FactorizationState::logical_row_count`].
        logical_rows: usize,
    },
}

impl FactorizationState {
    /// Returns `true` for the [`FactorizationState::Flat`] variant.
    #[must_use]
    pub fn is_flat(&self) -> bool {
        match self {
            Self::Flat { .. } => true,
            Self::Unflat { .. } => false,
        }
    }

    /// Returns `true` for the [`FactorizationState::Unflat`] variant.
    #[must_use]
    pub fn is_unflat(&self) -> bool {
        match self {
            Self::Unflat { .. } => true,
            Self::Flat { .. } => false,
        }
    }

    /// Returns `row_count` for a flat state and `logical_rows` for an unflat one.
    #[must_use]
    pub fn logical_row_count(&self) -> usize {
        match *self {
            Self::Flat { row_count: rows }
            | Self::Unflat {
                logical_rows: rows, ..
            } => rows,
        }
    }

    /// Returns the stored `level_count` for an unflat state; a flat state
    /// always reports `1`.
    #[must_use]
    pub fn level_count(&self) -> usize {
        if let Self::Unflat {
            level_count: levels,
            ..
        } = *self
        {
            levels
        } else {
            1
        }
    }
}
80
/// Selection over the physical indices of a single level.
///
/// Either a dense "everything up to `count`" marker or an explicit
/// [`SelectionVector`].
#[derive(Debug, Clone)]
pub enum LevelSelection {
    /// Every index in `0..count` is selected; nothing is stored per index.
    All {
        /// Exclusive upper bound of the selected index range.
        count: usize,
    },
    /// Only the indices contained in the wrapped [`SelectionVector`] are
    /// selected.
    Sparse(SelectionVector),
}
97
98impl LevelSelection {
99 #[must_use]
101 pub fn all(count: usize) -> Self {
102 Self::All { count }
103 }
104
105 #[must_use]
107 pub fn from_predicate<F>(count: usize, predicate: F) -> Self
108 where
109 F: Fn(usize) -> bool,
110 {
111 let selected = SelectionVector::from_predicate(count, predicate);
112 if selected.len() == count {
113 Self::All { count }
114 } else {
115 Self::Sparse(selected)
116 }
117 }
118
119 #[must_use]
121 pub fn selected_count(&self) -> usize {
122 match self {
123 Self::All { count } => *count,
124 Self::Sparse(sel) => sel.len(),
125 }
126 }
127
128 #[must_use]
130 pub fn is_selected(&self, physical_idx: usize) -> bool {
131 match self {
132 Self::All { count } => physical_idx < *count,
133 Self::Sparse(sel) => sel.contains(physical_idx),
134 }
135 }
136
137 #[must_use]
139 pub fn filter<F>(&self, predicate: F) -> Self
140 where
141 F: Fn(usize) -> bool,
142 {
143 match self {
144 Self::All { count } => Self::from_predicate(*count, predicate),
145 Self::Sparse(sel) => {
146 let filtered = sel.filter(predicate);
147 Self::Sparse(filtered)
148 }
149 }
150 }
151
152 pub fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
154 match self {
155 Self::All { count } => Box::new(0..*count),
156 Self::Sparse(sel) => Box::new(sel.iter()),
157 }
158 }
159}
160
161impl<'a> IntoIterator for &'a LevelSelection {
162 type Item = usize;
163 type IntoIter = Box<dyn Iterator<Item = usize> + 'a>;
164
165 fn into_iter(self) -> Self::IntoIter {
166 self.iter()
167 }
168}
169
/// A stack of per-level selections plus a memoized selected-row count.
#[derive(Debug, Clone)]
pub struct FactorizedSelection {
    /// One [`LevelSelection`] per level; the vector index is the level number.
    level_selections: Vec<LevelSelection>,
    /// Result of the last `selected_count` computation; `None` until computed
    /// or after `invalidate_cache`.
    cached_selected_count: Option<usize>,
}
181
182impl FactorizedSelection {
183 #[must_use]
185 pub fn all(level_counts: &[usize]) -> Self {
186 let level_selections = level_counts
187 .iter()
188 .map(|&count| LevelSelection::all(count))
189 .collect();
190 Self {
191 level_selections,
192 cached_selected_count: None,
193 }
194 }
195
196 #[must_use]
198 pub fn new(level_selections: Vec<LevelSelection>) -> Self {
199 Self {
200 level_selections,
201 cached_selected_count: None,
202 }
203 }
204
205 #[must_use]
207 pub fn level_count(&self) -> usize {
208 self.level_selections.len()
209 }
210
211 #[must_use]
213 pub fn level(&self, level: usize) -> Option<&LevelSelection> {
214 self.level_selections.get(level)
215 }
216
217 #[must_use]
223 pub fn filter_level<F>(&self, level: usize, predicate: F) -> Self
224 where
225 F: Fn(usize) -> bool,
226 {
227 let mut new_selections = self.level_selections.clone();
228
229 if let Some(sel) = new_selections.get_mut(level) {
230 *sel = sel.filter(predicate);
231 }
232
233 Self {
234 level_selections: new_selections,
235 cached_selected_count: None, }
237 }
238
239 #[must_use]
241 pub fn is_selected(&self, level: usize, physical_idx: usize) -> bool {
242 self.level_selections
243 .get(level)
244 .is_some_and(|sel| sel.is_selected(physical_idx))
245 }
246
247 pub fn selected_count(&mut self, multiplicities: &[Vec<usize>]) -> usize {
252 if let Some(count) = self.cached_selected_count {
253 return count;
254 }
255
256 let count = self.compute_selected_count(multiplicities);
257 self.cached_selected_count = Some(count);
258 count
259 }
260
261 fn compute_selected_count(&self, multiplicities: &[Vec<usize>]) -> usize {
263 if self.level_selections.is_empty() {
264 return 0;
265 }
266
267 if self.level_selections.len() == 1 {
269 return self.level_selections[0].selected_count();
270 }
271
272 let mut parent_selected: Vec<bool> = match &self.level_selections[0] {
275 LevelSelection::All { count } => vec![true; *count],
276 LevelSelection::Sparse(sel) => {
277 let max_idx = sel.iter().max().unwrap_or(0);
278 let mut selected = vec![false; max_idx + 1];
279 for idx in sel.iter() {
280 selected[idx] = true;
281 }
282 selected
283 }
284 };
285
286 for (level_sel, level_mults) in self
288 .level_selections
289 .iter()
290 .skip(1)
291 .zip(multiplicities.iter().skip(1))
292 {
293 let mut child_selected = Vec::new();
294 let mut child_idx = 0;
295
296 for (parent_idx, &mult) in level_mults.iter().enumerate() {
297 let parent_is_selected = parent_selected.get(parent_idx).copied().unwrap_or(false);
298
299 for _ in 0..mult {
300 let child_is_selected = parent_is_selected && level_sel.is_selected(child_idx);
301 child_selected.push(child_is_selected);
302 child_idx += 1;
303 }
304 }
305
306 parent_selected = child_selected;
307 }
308
309 parent_selected.iter().filter(|&&s| s).count()
311 }
312
313 pub fn invalidate_cache(&mut self) {
315 self.cached_selected_count = None;
316 }
317}
318
/// Bookkeeping for a chunk: its factorization state, an optional selection,
/// memoized multiplicities and a generation counter.
#[derive(Debug, Clone)]
pub struct ChunkState {
    /// Current factorization state.
    state: FactorizationState,
    /// Selection attached to the chunk; `None` when no selection is set.
    selection: Option<FactorizedSelection>,
    /// Memoized multiplicities; cleared by `invalidate_cache`.
    cached_multiplicities: Option<Arc<[usize]>>,
    /// Starts at 0 and is incremented by every `invalidate_cache` call.
    generation: u64,
}
342
343impl ChunkState {
344 #[must_use]
346 pub fn flat(row_count: usize) -> Self {
347 Self {
348 state: FactorizationState::Flat { row_count },
349 selection: None,
350 cached_multiplicities: None,
351 generation: 0,
352 }
353 }
354
355 #[must_use]
357 pub fn unflat(level_count: usize, logical_rows: usize) -> Self {
358 Self {
359 state: FactorizationState::Unflat {
360 level_count,
361 logical_rows,
362 },
363 selection: None,
364 cached_multiplicities: None,
365 generation: 0,
366 }
367 }
368
369 #[must_use]
371 pub fn factorization_state(&self) -> FactorizationState {
372 self.state
373 }
374
375 #[must_use]
377 pub fn is_flat(&self) -> bool {
378 self.state.is_flat()
379 }
380
381 #[must_use]
383 pub fn is_factorized(&self) -> bool {
384 self.state.is_unflat()
385 }
386
387 #[must_use]
392 pub fn logical_row_count(&self) -> usize {
393 self.state.logical_row_count()
394 }
395
396 #[must_use]
398 pub fn level_count(&self) -> usize {
399 self.state.level_count()
400 }
401
402 #[must_use]
404 pub fn generation(&self) -> u64 {
405 self.generation
406 }
407
408 #[must_use]
410 pub fn selection(&self) -> Option<&FactorizedSelection> {
411 self.selection.as_ref()
412 }
413
414 pub fn selection_mut(&mut self) -> &mut Option<FactorizedSelection> {
416 &mut self.selection
417 }
418
419 pub fn set_selection(&mut self, selection: FactorizedSelection) {
421 self.selection = Some(selection);
422 }
424
425 pub fn clear_selection(&mut self) {
427 self.selection = None;
428 }
429
430 pub fn set_state(&mut self, state: FactorizationState) {
432 self.state = state;
433 self.invalidate_cache();
434 }
435
436 pub fn invalidate_cache(&mut self) {
438 self.cached_multiplicities = None;
439 self.generation += 1;
440 }
441
442 pub fn get_or_compute_multiplicities<F>(&mut self, compute: F) -> Arc<[usize]>
459 where
460 F: FnOnce() -> Vec<usize>,
461 {
462 if let Some(ref cached) = self.cached_multiplicities {
463 return Arc::clone(cached);
464 }
465
466 let mults: Arc<[usize]> = compute().into();
467 self.cached_multiplicities = Some(Arc::clone(&mults));
468 mults
469 }
470
471 #[must_use]
475 pub fn cached_multiplicities(&self) -> Option<&Arc<[usize]>> {
476 self.cached_multiplicities.as_ref()
477 }
478
479 pub fn set_cached_multiplicities(&mut self, mults: Arc<[usize]>) {
483 self.cached_multiplicities = Some(mults);
484 }
485}
486
487impl Default for ChunkState {
488 fn default() -> Self {
489 Self::flat(0)
490 }
491}
492
/// Unit tests for `FactorizationState`, `LevelSelection`,
/// `FactorizedSelection` and `ChunkState`.
#[cfg(test)]
mod tests {
    use super::*;

    // --- FactorizationState -------------------------------------------------

    #[test]
    fn test_factorization_state_flat() {
        let state = FactorizationState::Flat { row_count: 100 };
        assert!(state.is_flat());
        assert!(!state.is_unflat());
        assert_eq!(state.logical_row_count(), 100);
        assert_eq!(state.level_count(), 1);
    }

    #[test]
    fn test_factorization_state_unflat() {
        let state = FactorizationState::Unflat {
            level_count: 3,
            logical_rows: 1000,
        };
        assert!(!state.is_flat());
        assert!(state.is_unflat());
        assert_eq!(state.logical_row_count(), 1000);
        assert_eq!(state.level_count(), 3);
    }

    // --- LevelSelection -----------------------------------------------------

    #[test]
    fn test_level_selection_all() {
        let sel = LevelSelection::all(10);
        assert_eq!(sel.selected_count(), 10);
        for i in 0..10 {
            assert!(sel.is_selected(i));
        }
        // The first index past the range is not selected.
        assert!(!sel.is_selected(10));
    }

    #[test]
    fn test_level_selection_filter() {
        let sel = LevelSelection::all(10);
        let filtered = sel.filter(|i| i % 2 == 0);
        assert_eq!(filtered.selected_count(), 5);
        assert!(filtered.is_selected(0));
        assert!(!filtered.is_selected(1));
        assert!(filtered.is_selected(2));
    }

    #[test]
    fn test_level_selection_filter_sparse() {
        // Start from a sparse selection: {0, 1, 2, 3, 4}.
        let sel = LevelSelection::from_predicate(10, |i| i < 5);
        assert_eq!(sel.selected_count(), 5);

        // Keeping the even indices leaves {0, 2, 4}.
        let filtered = sel.filter(|i| i % 2 == 0);
        assert_eq!(filtered.selected_count(), 3);
        assert!(filtered.is_selected(0));
        assert!(!filtered.is_selected(1));
        assert!(filtered.is_selected(2));
    }

    #[test]
    fn test_level_selection_iter_all() {
        let sel = LevelSelection::all(5);
        let indices: Vec<usize> = sel.iter().collect();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_level_selection_iter_sparse() {
        let sel = LevelSelection::from_predicate(10, |i| i % 3 == 0);
        let indices: Vec<usize> = sel.iter().collect();
        assert_eq!(indices, vec![0, 3, 6, 9]);
    }

    #[test]
    fn test_level_selection_from_predicate_all_selected() {
        // A predicate accepting everything must collapse to the `All` variant.
        let sel = LevelSelection::from_predicate(5, |_| true);
        assert_eq!(sel.selected_count(), 5);
        match sel {
            LevelSelection::All { count } => assert_eq!(count, 5),
            LevelSelection::Sparse(_) => panic!("Expected All variant"),
        }
    }

    #[test]
    fn test_level_selection_from_predicate_partial() {
        // A partial predicate must stay `Sparse`.
        let sel = LevelSelection::from_predicate(10, |i| i < 3);
        assert_eq!(sel.selected_count(), 3);
        match sel {
            LevelSelection::Sparse(_) => {}
            LevelSelection::All { .. } => panic!("Expected Sparse variant"),
        }
    }

    // --- FactorizedSelection ------------------------------------------------

    #[test]
    fn test_factorized_selection_all() {
        let sel = FactorizedSelection::all(&[10, 100, 1000]);
        assert_eq!(sel.level_count(), 3);
        assert!(sel.is_selected(0, 5));
        assert!(sel.is_selected(1, 50));
        assert!(sel.is_selected(2, 500));
    }

    #[test]
    fn test_factorized_selection_new() {
        let level_sels = vec![
            LevelSelection::all(5),
            LevelSelection::from_predicate(10, |i| i < 3),
        ];
        let sel = FactorizedSelection::new(level_sels);

        assert_eq!(sel.level_count(), 2);
        assert!(sel.is_selected(0, 4));
        assert!(sel.is_selected(1, 2));
        assert!(!sel.is_selected(1, 5));
    }

    #[test]
    fn test_factorized_selection_filter_level() {
        let sel = FactorizedSelection::all(&[10, 100]);
        let filtered = sel.filter_level(1, |i| i < 50);

        // Level 0 is untouched; level 1 keeps only indices below 50.
        assert!(filtered.is_selected(0, 5));
        assert!(filtered.is_selected(1, 25));
        assert!(!filtered.is_selected(1, 75));
    }

    #[test]
    fn test_factorized_selection_filter_level_invalid() {
        let sel = FactorizedSelection::all(&[10, 100]);

        // Filtering a non-existent level yields an unchanged copy.
        let filtered = sel.filter_level(5, |_| true);
        assert_eq!(filtered.level_count(), 2);
    }

    #[test]
    fn test_factorized_selection_is_selected_invalid_level() {
        let sel = FactorizedSelection::all(&[10]);
        // An out-of-range level is never selected.
        assert!(!sel.is_selected(5, 0));
    }

    #[test]
    fn test_factorized_selection_level() {
        let sel = FactorizedSelection::all(&[10, 20]);

        let level0 = sel.level(0);
        assert!(level0.is_some());
        assert_eq!(level0.unwrap().selected_count(), 10);

        let level1 = sel.level(1);
        assert!(level1.is_some());
        assert_eq!(level1.unwrap().selected_count(), 20);

        assert!(sel.level(5).is_none());
    }

    #[test]
    fn test_factorized_selection_selected_count_single_level() {
        let mut sel = FactorizedSelection::all(&[10]);
        let multiplicities: Vec<Vec<usize>> = vec![vec![1; 10]];

        let count = sel.selected_count(&multiplicities);
        assert_eq!(count, 10);
    }

    #[test]
    fn test_factorized_selection_selected_count_multi_level() {
        let level_sels = vec![
            // Level 0: both parents selected.
            LevelSelection::all(2),
            // Level 1: children 0 and 2 selected.
            LevelSelection::from_predicate(4, |i| i % 2 == 0),
        ];
        let mut sel = FactorizedSelection::new(level_sels);

        let multiplicities = vec![
            vec![1, 1],
            // Each parent owns two consecutive children.
            vec![2, 2],
        ];

        let count = sel.selected_count(&multiplicities);
        assert_eq!(count, 2);
    }

    #[test]
    fn test_factorized_selection_selected_count_cached() {
        // Two levels are needed: with a single level the multiplicities are
        // never read, so a cache hit could not be told apart from a recompute.
        let mut sel = FactorizedSelection::all(&[2, 4]);
        let two_children_each: Vec<Vec<usize>> = vec![vec![1, 1], vec![2, 2]];
        let one_child_each: Vec<Vec<usize>> = vec![vec![1, 1], vec![1, 1]];

        let count1 = sel.selected_count(&two_children_each);
        assert_eq!(count1, 4);

        // A recompute with these multiplicities would give 2; the cached 4
        // must be returned instead.
        let count2 = sel.selected_count(&one_child_each);
        assert_eq!(count2, 4);
    }

    #[test]
    fn test_factorized_selection_selected_count_empty() {
        let mut sel = FactorizedSelection::all(&[]);
        let multiplicities: Vec<Vec<usize>> = vec![];

        assert_eq!(sel.selected_count(&multiplicities), 0);
    }

    #[test]
    fn test_factorized_selection_invalidate_cache() {
        let mut sel = FactorizedSelection::all(&[2, 4]);
        let two_children_each: Vec<Vec<usize>> = vec![vec![1, 1], vec![2, 2]];
        let one_child_each: Vec<Vec<usize>> = vec![vec![1, 1], vec![1, 1]];

        // Populate the cache.
        assert_eq!(sel.selected_count(&two_children_each), 4);
        // Still served from the cache despite the different multiplicities.
        assert_eq!(sel.selected_count(&one_child_each), 4);

        sel.invalidate_cache();

        // After invalidation the count is recomputed from the new input.
        assert_eq!(sel.selected_count(&one_child_each), 2);
    }

    // --- ChunkState ---------------------------------------------------------

    #[test]
    fn test_chunk_state_flat() {
        let state = ChunkState::flat(100);
        assert!(state.is_flat());
        assert!(!state.is_factorized());
        assert_eq!(state.logical_row_count(), 100);
        assert_eq!(state.level_count(), 1);
    }

    #[test]
    fn test_chunk_state_unflat() {
        let state = ChunkState::unflat(3, 1000);
        assert!(!state.is_flat());
        assert!(state.is_factorized());
        assert_eq!(state.logical_row_count(), 1000);
        assert_eq!(state.level_count(), 3);
    }

    #[test]
    fn test_chunk_state_factorization_state() {
        let state = ChunkState::flat(50);
        let fs = state.factorization_state();
        assert!(fs.is_flat());
    }

    #[test]
    fn test_chunk_state_selection() {
        let mut state = ChunkState::unflat(2, 100);

        // No selection until one is set.
        assert!(state.selection().is_none());

        let sel = FactorizedSelection::all(&[10, 100]);
        state.set_selection(sel);

        assert!(state.selection().is_some());
        assert_eq!(state.selection().unwrap().level_count(), 2);
    }

    #[test]
    fn test_chunk_state_selection_mut() {
        let mut state = ChunkState::unflat(2, 100);

        let sel = FactorizedSelection::all(&[10, 100]);
        state.set_selection(sel);

        let sel_mut = state.selection_mut();
        assert!(sel_mut.is_some());

        // Clearing through the mutable slot is visible via `selection()`.
        *sel_mut = None;
        assert!(state.selection().is_none());
    }

    #[test]
    fn test_chunk_state_clear_selection() {
        let mut state = ChunkState::unflat(2, 100);

        let sel = FactorizedSelection::all(&[10, 100]);
        state.set_selection(sel);
        assert!(state.selection().is_some());

        state.clear_selection();
        assert!(state.selection().is_none());
    }

    #[test]
    fn test_chunk_state_set_state() {
        let mut state = ChunkState::flat(100);
        assert!(state.is_flat());
        assert_eq!(state.generation(), 0);

        state.set_state(FactorizationState::Unflat {
            level_count: 2,
            logical_rows: 200,
        });

        assert!(state.is_factorized());
        assert_eq!(state.logical_row_count(), 200);
        // Changing the state invalidates the cache and bumps the generation.
        assert_eq!(state.generation(), 1);
    }

    #[test]
    fn test_chunk_state_caching() {
        let mut state = ChunkState::unflat(2, 100);

        let mut computed = false;
        // First call: cache miss, the closure runs.
        let mults1 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![1, 2, 3, 4, 5]
        });
        assert!(computed);
        assert_eq!(mults1.len(), 5);

        computed = false;
        // Second call: cache hit, the closure must not run.
        let mults2 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![99, 99, 99]
        });
        assert!(!computed);
        assert_eq!(mults2.len(), 5);

        // After invalidation the closure runs again.
        state.invalidate_cache();
        let mults3 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![10, 20, 30]
        });
        assert!(computed);
        assert_eq!(mults3.len(), 3);
    }

    #[test]
    fn test_chunk_state_cached_multiplicities() {
        let mut state = ChunkState::unflat(2, 100);

        // Nothing cached initially.
        assert!(state.cached_multiplicities().is_none());

        let _ = state.get_or_compute_multiplicities(|| vec![1, 2, 3]);

        assert!(state.cached_multiplicities().is_some());
        assert_eq!(state.cached_multiplicities().unwrap().len(), 3);
    }

    #[test]
    fn test_chunk_state_set_cached_multiplicities() {
        let mut state = ChunkState::unflat(2, 100);

        let mults: Arc<[usize]> = vec![5, 10, 15].into();
        state.set_cached_multiplicities(mults);

        assert!(state.cached_multiplicities().is_some());
        assert_eq!(state.cached_multiplicities().unwrap().len(), 3);
    }

    #[test]
    fn test_chunk_state_generation() {
        let mut state = ChunkState::flat(100);
        assert_eq!(state.generation(), 0);

        state.invalidate_cache();
        assert_eq!(state.generation(), 1);

        state.set_state(FactorizationState::Unflat {
            level_count: 2,
            logical_rows: 200,
        });
        assert_eq!(state.generation(), 2);
    }

    #[test]
    fn test_chunk_state_default() {
        let state = ChunkState::default();
        assert!(state.is_flat());
        assert_eq!(state.logical_row_count(), 0);
        assert_eq!(state.generation(), 0);
    }
}