1use std::sync::Arc;
24
25use super::selection::SelectionVector;
26
/// Describes whether a chunk is stored flat or factorized ("unflat").
///
/// The value is `Copy`, so it is returned by value from accessors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FactorizationState {
    /// Flat layout; always reports exactly one level.
    Flat {
        /// Row count returned by [`FactorizationState::logical_row_count`].
        row_count: usize,
    },
    /// Unflat layout carrying an explicit level count.
    Unflat {
        /// Level count returned by [`FactorizationState::level_count`].
        level_count: usize,
        /// Row count returned by [`FactorizationState::logical_row_count`].
        logical_rows: usize,
    },
}

impl FactorizationState {
    /// Returns `true` when this is the [`FactorizationState::Flat`] variant.
    #[must_use]
    pub fn is_flat(&self) -> bool {
        match self {
            Self::Flat { .. } => true,
            Self::Unflat { .. } => false,
        }
    }

    /// Returns `true` when this is the [`FactorizationState::Unflat`] variant.
    #[must_use]
    pub fn is_unflat(&self) -> bool {
        // The enum has exactly two variants, so "not flat" means "unflat".
        !self.is_flat()
    }

    /// Returns the row count stored in the variant (`row_count` for flat,
    /// `logical_rows` for unflat).
    #[must_use]
    pub fn logical_row_count(&self) -> usize {
        match *self {
            Self::Unflat { logical_rows, .. } => logical_rows,
            Self::Flat { row_count } => row_count,
        }
    }

    /// Returns the number of levels: `1` for flat, the stored `level_count`
    /// for unflat.
    #[must_use]
    pub fn level_count(&self) -> usize {
        if let Self::Unflat { level_count, .. } = *self {
            level_count
        } else {
            1
        }
    }
}
82
83#[derive(Debug, Clone)]
88pub enum LevelSelection {
89 All {
91 count: usize,
93 },
94 Sparse(SelectionVector),
98}
99
100impl LevelSelection {
101 #[must_use]
103 pub fn all(count: usize) -> Self {
104 Self::All { count }
105 }
106
107 #[must_use]
109 pub fn from_predicate<F>(count: usize, predicate: F) -> Self
110 where
111 F: Fn(usize) -> bool,
112 {
113 let selected = SelectionVector::from_predicate(count, predicate);
114 if selected.len() == count {
115 Self::All { count }
116 } else {
117 Self::Sparse(selected)
118 }
119 }
120
121 #[must_use]
123 pub fn selected_count(&self) -> usize {
124 match self {
125 Self::All { count } => *count,
126 Self::Sparse(sel) => sel.len(),
127 }
128 }
129
130 #[must_use]
132 pub fn is_selected(&self, physical_idx: usize) -> bool {
133 match self {
134 Self::All { count } => physical_idx < *count,
135 Self::Sparse(sel) => sel.contains(physical_idx),
136 }
137 }
138
139 #[must_use]
141 pub fn filter<F>(&self, predicate: F) -> Self
142 where
143 F: Fn(usize) -> bool,
144 {
145 match self {
146 Self::All { count } => Self::from_predicate(*count, predicate),
147 Self::Sparse(sel) => {
148 let filtered = sel.filter(predicate);
149 Self::Sparse(filtered)
150 }
151 }
152 }
153
154 pub fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
156 match self {
157 Self::All { count } => Box::new(0..*count),
158 Self::Sparse(sel) => Box::new(sel.iter()),
159 }
160 }
161}
162
163impl<'a> IntoIterator for &'a LevelSelection {
164 type Item = usize;
165 type IntoIter = Box<dyn Iterator<Item = usize> + 'a>;
166
167 fn into_iter(self) -> Self::IntoIter {
168 self.iter()
169 }
170}
171
172#[derive(Debug, Clone)]
177pub struct FactorizedSelection {
178 level_selections: Vec<LevelSelection>,
180 cached_selected_count: Option<usize>,
182}
183
184impl FactorizedSelection {
185 #[must_use]
187 pub fn all(level_counts: &[usize]) -> Self {
188 let level_selections = level_counts
189 .iter()
190 .map(|&count| LevelSelection::all(count))
191 .collect();
192 Self {
193 level_selections,
194 cached_selected_count: None,
195 }
196 }
197
198 #[must_use]
200 pub fn new(level_selections: Vec<LevelSelection>) -> Self {
201 Self {
202 level_selections,
203 cached_selected_count: None,
204 }
205 }
206
207 #[must_use]
209 pub fn level_count(&self) -> usize {
210 self.level_selections.len()
211 }
212
213 #[must_use]
215 pub fn level(&self, level: usize) -> Option<&LevelSelection> {
216 self.level_selections.get(level)
217 }
218
219 #[must_use]
225 pub fn filter_level<F>(&self, level: usize, predicate: F) -> Self
226 where
227 F: Fn(usize) -> bool,
228 {
229 let mut new_selections = self.level_selections.clone();
230
231 if let Some(sel) = new_selections.get_mut(level) {
232 *sel = sel.filter(predicate);
233 }
234
235 Self {
236 level_selections: new_selections,
237 cached_selected_count: None, }
239 }
240
241 #[must_use]
243 pub fn is_selected(&self, level: usize, physical_idx: usize) -> bool {
244 self.level_selections
245 .get(level)
246 .is_some_and(|sel| sel.is_selected(physical_idx))
247 }
248
249 pub fn selected_count(&mut self, multiplicities: &[Vec<usize>]) -> usize {
254 if let Some(count) = self.cached_selected_count {
255 return count;
256 }
257
258 let count = self.compute_selected_count(multiplicities);
259 self.cached_selected_count = Some(count);
260 count
261 }
262
263 fn compute_selected_count(&self, multiplicities: &[Vec<usize>]) -> usize {
265 if self.level_selections.is_empty() {
266 return 0;
267 }
268
269 if self.level_selections.len() == 1 {
271 return self.level_selections[0].selected_count();
272 }
273
274 let mut parent_selected: Vec<bool> = match &self.level_selections[0] {
277 LevelSelection::All { count } => vec![true; *count],
278 LevelSelection::Sparse(sel) => {
279 let max_idx = sel.iter().max().unwrap_or(0);
280 let mut selected = vec![false; max_idx + 1];
281 for idx in sel.iter() {
282 selected[idx] = true;
283 }
284 selected
285 }
286 };
287
288 for (level_sel, level_mults) in self
290 .level_selections
291 .iter()
292 .skip(1)
293 .zip(multiplicities.iter().skip(1))
294 {
295 let mut child_selected = Vec::new();
296 let mut child_idx = 0;
297
298 for (parent_idx, &mult) in level_mults.iter().enumerate() {
299 let parent_is_selected = parent_selected.get(parent_idx).copied().unwrap_or(false);
300
301 for _ in 0..mult {
302 let child_is_selected = parent_is_selected && level_sel.is_selected(child_idx);
303 child_selected.push(child_is_selected);
304 child_idx += 1;
305 }
306 }
307
308 parent_selected = child_selected;
309 }
310
311 parent_selected.iter().filter(|&&s| s).count()
313 }
314
315 pub fn invalidate_cache(&mut self) {
317 self.cached_selected_count = None;
318 }
319}
320
321#[derive(Debug, Clone)]
332pub struct ChunkState {
333 state: FactorizationState,
335 selection: Option<FactorizedSelection>,
338 cached_multiplicities: Option<Arc<[usize]>>,
341 generation: u64,
343}
344
345impl ChunkState {
346 #[must_use]
348 pub fn flat(row_count: usize) -> Self {
349 Self {
350 state: FactorizationState::Flat { row_count },
351 selection: None,
352 cached_multiplicities: None,
353 generation: 0,
354 }
355 }
356
357 #[must_use]
359 pub fn unflat(level_count: usize, logical_rows: usize) -> Self {
360 Self {
361 state: FactorizationState::Unflat {
362 level_count,
363 logical_rows,
364 },
365 selection: None,
366 cached_multiplicities: None,
367 generation: 0,
368 }
369 }
370
371 #[must_use]
373 pub fn factorization_state(&self) -> FactorizationState {
374 self.state
375 }
376
377 #[must_use]
379 pub fn is_flat(&self) -> bool {
380 self.state.is_flat()
381 }
382
383 #[must_use]
385 pub fn is_factorized(&self) -> bool {
386 self.state.is_unflat()
387 }
388
389 #[must_use]
394 pub fn logical_row_count(&self) -> usize {
395 self.state.logical_row_count()
396 }
397
398 #[must_use]
400 pub fn level_count(&self) -> usize {
401 self.state.level_count()
402 }
403
404 #[must_use]
406 pub fn generation(&self) -> u64 {
407 self.generation
408 }
409
410 #[must_use]
412 pub fn selection(&self) -> Option<&FactorizedSelection> {
413 self.selection.as_ref()
414 }
415
416 pub fn selection_mut(&mut self) -> &mut Option<FactorizedSelection> {
418 &mut self.selection
419 }
420
421 pub fn set_selection(&mut self, selection: FactorizedSelection) {
423 self.selection = Some(selection);
424 }
426
427 pub fn clear_selection(&mut self) {
429 self.selection = None;
430 }
431
432 pub fn set_state(&mut self, state: FactorizationState) {
434 self.state = state;
435 self.invalidate_cache();
436 }
437
438 pub fn invalidate_cache(&mut self) {
440 self.cached_multiplicities = None;
441 self.generation += 1;
442 }
443
444 pub fn get_or_compute_multiplicities<F>(&mut self, compute: F) -> Arc<[usize]>
463 where
464 F: FnOnce() -> Vec<usize>,
465 {
466 if let Some(ref cached) = self.cached_multiplicities {
467 return Arc::clone(cached);
468 }
469
470 let mults: Arc<[usize]> = compute().into();
471 self.cached_multiplicities = Some(Arc::clone(&mults));
472 mults
473 }
474
475 #[must_use]
479 pub fn cached_multiplicities(&self) -> Option<&Arc<[usize]>> {
480 self.cached_multiplicities.as_ref()
481 }
482
483 pub fn set_cached_multiplicities(&mut self, mults: Arc<[usize]>) {
487 self.cached_multiplicities = Some(mults);
488 }
489}
490
491impl Default for ChunkState {
492 fn default() -> Self {
493 Self::flat(0)
494 }
495}
496
/// Unit tests for the state and selection types defined in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_factorization_state_flat() {
        let state = FactorizationState::Flat { row_count: 100 };
        assert!(state.is_flat());
        assert!(!state.is_unflat());
        assert_eq!(state.logical_row_count(), 100);
        assert_eq!(state.level_count(), 1);
    }

    #[test]
    fn test_factorization_state_unflat() {
        let state = FactorizationState::Unflat {
            level_count: 3,
            logical_rows: 1000,
        };
        assert!(!state.is_flat());
        assert!(state.is_unflat());
        assert_eq!(state.logical_row_count(), 1000);
        assert_eq!(state.level_count(), 3);
    }

    #[test]
    fn test_level_selection_all() {
        let sel = LevelSelection::all(10);
        assert_eq!(sel.selected_count(), 10);
        for i in 0..10 {
            assert!(sel.is_selected(i));
        }
        assert!(!sel.is_selected(10));
    }

    #[test]
    fn test_level_selection_filter() {
        let sel = LevelSelection::all(10);
        let filtered = sel.filter(|i| i % 2 == 0);
        assert_eq!(filtered.selected_count(), 5);
        assert!(filtered.is_selected(0));
        assert!(!filtered.is_selected(1));
        assert!(filtered.is_selected(2));
    }

    #[test]
    fn test_level_selection_filter_sparse() {
        // Indices 0..5 are selected to start with.
        let sel = LevelSelection::from_predicate(10, |i| i < 5);
        assert_eq!(sel.selected_count(), 5);

        // Keeping the even ones leaves 0, 2 and 4.
        let filtered = sel.filter(|i| i % 2 == 0);
        assert_eq!(filtered.selected_count(), 3);
        assert!(filtered.is_selected(0));
        assert!(!filtered.is_selected(1));
        assert!(filtered.is_selected(2));
    }

    #[test]
    fn test_level_selection_iter_all() {
        let sel = LevelSelection::all(5);
        let indices: Vec<usize> = sel.iter().collect();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_level_selection_iter_sparse() {
        let sel = LevelSelection::from_predicate(10, |i| i % 3 == 0);
        let indices: Vec<usize> = sel.iter().collect();
        assert_eq!(indices, vec![0, 3, 6, 9]);
    }

    #[test]
    fn test_level_selection_from_predicate_all_selected() {
        // A predicate that accepts everything collapses to the `All` variant.
        let sel = LevelSelection::from_predicate(5, |_| true);
        assert_eq!(sel.selected_count(), 5);
        match sel {
            LevelSelection::All { count } => assert_eq!(count, 5),
            LevelSelection::Sparse(_) => panic!("Expected All variant"),
        }
    }

    #[test]
    fn test_level_selection_from_predicate_partial() {
        // A predicate that rejects something yields the `Sparse` variant.
        let sel = LevelSelection::from_predicate(10, |i| i < 3);
        assert_eq!(sel.selected_count(), 3);
        match sel {
            LevelSelection::Sparse(_) => {}
            LevelSelection::All { .. } => panic!("Expected Sparse variant"),
        }
    }

    #[test]
    fn test_factorized_selection_all() {
        let sel = FactorizedSelection::all(&[10, 100, 1000]);
        assert_eq!(sel.level_count(), 3);
        assert!(sel.is_selected(0, 5));
        assert!(sel.is_selected(1, 50));
        assert!(sel.is_selected(2, 500));
    }

    #[test]
    fn test_factorized_selection_new() {
        let level_sels = vec![
            LevelSelection::all(5),
            LevelSelection::from_predicate(10, |i| i < 3),
        ];
        let sel = FactorizedSelection::new(level_sels);

        assert_eq!(sel.level_count(), 2);
        assert!(sel.is_selected(0, 4));
        assert!(sel.is_selected(1, 2));
        assert!(!sel.is_selected(1, 5));
    }

    #[test]
    fn test_factorized_selection_filter_level() {
        let sel = FactorizedSelection::all(&[10, 100]);
        let filtered = sel.filter_level(1, |i| i < 50);

        // Level 0 is untouched; level 1 keeps only indices below 50.
        assert!(filtered.is_selected(0, 5));
        assert!(filtered.is_selected(1, 25));
        assert!(!filtered.is_selected(1, 75));
    }

    #[test]
    fn test_factorized_selection_filter_level_invalid() {
        let sel = FactorizedSelection::all(&[10, 100]);

        // Filtering a level that does not exist leaves the structure as is.
        let filtered = sel.filter_level(5, |_| true);
        assert_eq!(filtered.level_count(), 2);
    }

    #[test]
    fn test_factorized_selection_is_selected_invalid_level() {
        let sel = FactorizedSelection::all(&[10]);
        // An unknown level never reports a selected index.
        assert!(!sel.is_selected(5, 0));
    }

    #[test]
    fn test_factorized_selection_level() {
        let sel = FactorizedSelection::all(&[10, 20]);

        let level0 = sel.level(0);
        assert!(level0.is_some());
        assert_eq!(level0.unwrap().selected_count(), 10);

        let level1 = sel.level(1);
        assert!(level1.is_some());
        assert_eq!(level1.unwrap().selected_count(), 20);

        assert!(sel.level(5).is_none());
    }

    #[test]
    fn test_factorized_selection_selected_count_single_level() {
        let mut sel = FactorizedSelection::all(&[10]);
        let multiplicities: Vec<Vec<usize>> = vec![vec![1; 10]];

        let count = sel.selected_count(&multiplicities);
        assert_eq!(count, 10);
    }

    #[test]
    fn test_factorized_selection_selected_count_multi_level() {
        let level_sels = vec![
            // Both parents are selected.
            LevelSelection::all(2),
            // Only even child indices are selected.
            LevelSelection::from_predicate(4, |i| i % 2 == 0),
        ];
        let mut sel = FactorizedSelection::new(level_sels);

        let multiplicities = vec![
            // Level 0 multiplicities are not read by the computation.
            vec![1, 1],
            // Each parent owns two consecutive children.
            vec![2, 2],
        ];

        // Children 0 and 2 have a selected parent and are selected themselves.
        let count = sel.selected_count(&multiplicities);
        assert_eq!(count, 2);
    }

    #[test]
    fn test_factorized_selection_selected_count_cached() {
        let mut sel = FactorizedSelection::all(&[2, 4]);
        let two_children_each: Vec<Vec<usize>> = vec![vec![1, 1], vec![2, 2]];
        let one_child_each: Vec<Vec<usize>> = vec![vec![1, 1], vec![1, 1]];

        // The first call computes the count and stores it.
        let count1 = sel.selected_count(&two_children_each);
        assert_eq!(count1, 4);

        // The second call must come from the cache: recomputing with these
        // multiplicities would give 2, so 4 proves the cache was used.
        let count2 = sel.selected_count(&one_child_each);
        assert_eq!(count2, 4);
    }

    #[test]
    fn test_factorized_selection_selected_count_empty() {
        let mut sel = FactorizedSelection::all(&[]);
        let multiplicities: Vec<Vec<usize>> = vec![];

        assert_eq!(sel.selected_count(&multiplicities), 0);
    }

    #[test]
    fn test_factorized_selection_invalidate_cache() {
        let mut sel = FactorizedSelection::all(&[2, 4]);
        let two_children_each: Vec<Vec<usize>> = vec![vec![1, 1], vec![2, 2]];
        let one_child_each: Vec<Vec<usize>> = vec![vec![1, 1], vec![1, 1]];

        // Populate the cache.
        assert_eq!(sel.selected_count(&two_children_each), 4);

        sel.invalidate_cache();

        // After invalidation the count is recomputed from the new input.
        assert_eq!(sel.selected_count(&one_child_each), 2);
    }

    #[test]
    fn test_chunk_state_flat() {
        let state = ChunkState::flat(100);
        assert!(state.is_flat());
        assert!(!state.is_factorized());
        assert_eq!(state.logical_row_count(), 100);
        assert_eq!(state.level_count(), 1);
    }

    #[test]
    fn test_chunk_state_unflat() {
        let state = ChunkState::unflat(3, 1000);
        assert!(!state.is_flat());
        assert!(state.is_factorized());
        assert_eq!(state.logical_row_count(), 1000);
        assert_eq!(state.level_count(), 3);
    }

    #[test]
    fn test_chunk_state_factorization_state() {
        let state = ChunkState::flat(50);
        let fs = state.factorization_state();
        assert!(fs.is_flat());
    }

    #[test]
    fn test_chunk_state_selection() {
        let mut state = ChunkState::unflat(2, 100);

        // No selection is attached initially.
        assert!(state.selection().is_none());

        let sel = FactorizedSelection::all(&[10, 100]);
        state.set_selection(sel);

        assert!(state.selection().is_some());
        assert_eq!(state.selection().unwrap().level_count(), 2);
    }

    #[test]
    fn test_chunk_state_selection_mut() {
        let mut state = ChunkState::unflat(2, 100);

        let sel = FactorizedSelection::all(&[10, 100]);
        state.set_selection(sel);

        let sel_mut = state.selection_mut();
        assert!(sel_mut.is_some());

        // Writing through the mutable slot removes the selection.
        *sel_mut = None;
        assert!(state.selection().is_none());
    }

    #[test]
    fn test_chunk_state_clear_selection() {
        let mut state = ChunkState::unflat(2, 100);

        let sel = FactorizedSelection::all(&[10, 100]);
        state.set_selection(sel);
        assert!(state.selection().is_some());

        state.clear_selection();
        assert!(state.selection().is_none());
    }

    #[test]
    fn test_chunk_state_set_state() {
        let mut state = ChunkState::flat(100);
        assert!(state.is_flat());
        assert_eq!(state.generation(), 0);

        // Seed the cache so its invalidation is observable.
        state.set_cached_multiplicities(vec![1, 2].into());
        assert!(state.cached_multiplicities().is_some());

        state.set_state(FactorizationState::Unflat {
            level_count: 2,
            logical_rows: 200,
        });

        assert!(state.is_factorized());
        assert_eq!(state.logical_row_count(), 200);
        // Changing the state invalidates the cache and bumps the generation.
        assert_eq!(state.generation(), 1);
        assert!(state.cached_multiplicities().is_none());
    }

    #[test]
    fn test_chunk_state_caching() {
        let mut state = ChunkState::unflat(2, 100);

        let mut computed = false;

        // First access runs the closure.
        let mults1 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![1, 2, 3, 4, 5]
        });
        assert!(computed);
        assert_eq!(mults1.len(), 5);

        computed = false;

        // Second access is served from the cache; the closure is not run.
        let mults2 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![99, 99, 99]
        });
        assert!(!computed);
        assert_eq!(mults2.len(), 5);
        assert!(Arc::ptr_eq(&mults1, &mults2));

        // After invalidation the closure runs again.
        state.invalidate_cache();
        let mults3 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![10, 20, 30]
        });
        assert!(computed);
        assert_eq!(mults3.len(), 3);
        assert_eq!(mults3.to_vec(), vec![10, 20, 30]);
    }

    #[test]
    fn test_chunk_state_cached_multiplicities() {
        let mut state = ChunkState::unflat(2, 100);

        // Nothing is cached initially.
        assert!(state.cached_multiplicities().is_none());

        let _ = state.get_or_compute_multiplicities(|| vec![1, 2, 3]);

        // The computed value is now visible through the accessor.
        assert!(state.cached_multiplicities().is_some());
        assert_eq!(state.cached_multiplicities().unwrap().len(), 3);
    }

    #[test]
    fn test_chunk_state_set_cached_multiplicities() {
        let mut state = ChunkState::unflat(2, 100);

        let mults: Arc<[usize]> = vec![5, 10, 15].into();
        state.set_cached_multiplicities(mults);

        assert!(state.cached_multiplicities().is_some());
        assert_eq!(state.cached_multiplicities().unwrap().len(), 3);
    }

    #[test]
    fn test_chunk_state_generation() {
        let mut state = ChunkState::flat(100);
        assert_eq!(state.generation(), 0);

        state.invalidate_cache();
        assert_eq!(state.generation(), 1);

        state.set_state(FactorizationState::Unflat {
            level_count: 2,
            logical_rows: 200,
        });
        assert_eq!(state.generation(), 2);
    }

    #[test]
    fn test_chunk_state_default() {
        let state = ChunkState::default();
        assert!(state.is_flat());
        assert_eq!(state.logical_row_count(), 0);
        assert_eq!(state.generation(), 0);
    }
}