Skip to main content

grafeo_core/execution/
chunk_state.rs

//! Unified chunk state tracking for factorized execution.
//!
//! This module provides [`ChunkState`] for centralized state management,
//! inspired by LadybugDB's `FStateType` pattern. Key benefits:
//!
//! - **Cached multiplicities**: Computed once, reused for all aggregates
//! - **Selection integration**: Lazy filtering without data copying
//! - **O(1) logical row count**: Cached, not recomputed
//!
//! # Example
//!
//! ```ignore
//! let mut state = ChunkState::unflat(3, 1000);
//!
//! // The first call computes; later calls return the same cached `Arc`.
//! let mults = state.get_or_compute_multiplicities(|| expensive_compute());
//! let mults2 = state.get_or_compute_multiplicities(|| unreachable!());
//! assert!(std::ptr::eq(mults.as_ptr(), mults2.as_ptr()));
//! ```
20
21use std::sync::Arc;
22
23use super::selection::SelectionVector;
24
/// Whether a chunk is flat or factorized, together with its row bookkeeping.
///
/// Modeled on LadybugDB's `FStateType`: this enum is the one place that
/// records a chunk's factorization status.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FactorizationState {
    /// Every vector is flat: exactly one value per logical row.
    ///
    /// Produced by flattening, or directly by simple scans.
    Flat {
        /// Row count; physical and logical counts coincide.
        row_count: usize,
    },
    /// At least one vector is unflat, with values grouped under a parent.
    ///
    /// The chunk is organised into several nested levels.
    Unflat {
        /// How many factorization levels the chunk has.
        level_count: usize,
        /// Logical row count, stored so it never has to be recomputed.
        logical_rows: usize,
    },
}

impl FactorizationState {
    /// Reports whether the state is [`FactorizationState::Flat`].
    #[must_use]
    pub fn is_flat(&self) -> bool {
        match self {
            Self::Flat { .. } => true,
            Self::Unflat { .. } => false,
        }
    }

    /// Reports whether the state is [`FactorizationState::Unflat`].
    #[must_use]
    pub fn is_unflat(&self) -> bool {
        !self.is_flat()
    }

    /// Logical row count: `row_count` when flat, `logical_rows` when unflat.
    #[must_use]
    pub fn logical_row_count(&self) -> usize {
        match *self {
            Self::Flat { row_count: rows } | Self::Unflat { logical_rows: rows, .. } => rows,
        }
    }

    /// Number of factorization levels; a flat chunk always counts as one level.
    #[must_use]
    pub fn level_count(&self) -> usize {
        if let Self::Unflat { level_count, .. } = *self {
            level_count
        } else {
            1
        }
    }
}
80
81/// Selection state for a single factorization level.
82///
83/// Supports both sparse (for low selectivity) and dense (for high selectivity)
84/// representations to optimize memory usage.
85#[derive(Debug, Clone)]
86pub enum LevelSelection {
87    /// All values at this level are selected.
88    All {
89        /// Total count of values at this level.
90        count: usize,
91    },
92    /// Only specific indices are selected (for low selectivity).
93    ///
94    /// Uses `SelectionVector` which stores indices as `u16`.
95    Sparse(SelectionVector),
96}
97
98impl LevelSelection {
99    /// Creates a selection that selects all values.
100    #[must_use]
101    pub fn all(count: usize) -> Self {
102        Self::All { count }
103    }
104
105    /// Creates a sparse selection from a predicate.
106    #[must_use]
107    pub fn from_predicate<F>(count: usize, predicate: F) -> Self
108    where
109        F: Fn(usize) -> bool,
110    {
111        let selected = SelectionVector::from_predicate(count, predicate);
112        if selected.len() == count {
113            Self::All { count }
114        } else {
115            Self::Sparse(selected)
116        }
117    }
118
119    /// Returns the number of selected values.
120    #[must_use]
121    pub fn selected_count(&self) -> usize {
122        match self {
123            Self::All { count } => *count,
124            Self::Sparse(sel) => sel.len(),
125        }
126    }
127
128    /// Returns true if a physical index is selected.
129    #[must_use]
130    pub fn is_selected(&self, physical_idx: usize) -> bool {
131        match self {
132            Self::All { count } => physical_idx < *count,
133            Self::Sparse(sel) => sel.contains(physical_idx),
134        }
135    }
136
137    /// Filters this selection with a predicate, returning a new selection.
138    #[must_use]
139    pub fn filter<F>(&self, predicate: F) -> Self
140    where
141        F: Fn(usize) -> bool,
142    {
143        match self {
144            Self::All { count } => Self::from_predicate(*count, predicate),
145            Self::Sparse(sel) => {
146                let filtered = sel.filter(predicate);
147                Self::Sparse(filtered)
148            }
149        }
150    }
151
152    /// Returns an iterator over selected indices.
153    #[allow(clippy::iter_without_into_iter)]
154    pub fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
155        match self {
156            Self::All { count } => Box::new(0..*count),
157            Self::Sparse(sel) => Box::new(sel.iter()),
158        }
159    }
160}
161
162/// Hierarchical selection for factorized data.
163///
164/// Tracks selections at each factorization level, enabling filtering
165/// without flattening or copying data.
166#[derive(Debug, Clone)]
167pub struct FactorizedSelection {
168    /// Selection state per level (level 0 = sources, higher = more nested).
169    level_selections: Vec<LevelSelection>,
170    /// Cached logical row count after selection (lazily computed).
171    cached_selected_count: Option<usize>,
172}
173
174impl FactorizedSelection {
175    /// Creates a selection that selects all values at all levels.
176    #[must_use]
177    pub fn all(level_counts: &[usize]) -> Self {
178        let level_selections = level_counts
179            .iter()
180            .map(|&count| LevelSelection::all(count))
181            .collect();
182        Self {
183            level_selections,
184            cached_selected_count: None,
185        }
186    }
187
188    /// Creates a selection from level selections.
189    #[must_use]
190    pub fn new(level_selections: Vec<LevelSelection>) -> Self {
191        Self {
192            level_selections,
193            cached_selected_count: None,
194        }
195    }
196
197    /// Returns the number of levels.
198    #[must_use]
199    pub fn level_count(&self) -> usize {
200        self.level_selections.len()
201    }
202
203    /// Gets the selection for a specific level.
204    #[must_use]
205    pub fn level(&self, level: usize) -> Option<&LevelSelection> {
206        self.level_selections.get(level)
207    }
208
209    /// Filters at a specific level using a predicate.
210    ///
211    /// Returns a new selection with the filter applied.
212    /// This is O(n_physical) where n is the physical size of that level,
213    /// not O(n_logical) where n is the logical row count.
214    #[must_use]
215    pub fn filter_level<F>(&self, level: usize, predicate: F) -> Self
216    where
217        F: Fn(usize) -> bool,
218    {
219        let mut new_selections = self.level_selections.clone();
220
221        if let Some(sel) = new_selections.get_mut(level) {
222            *sel = sel.filter(predicate);
223        }
224
225        Self {
226            level_selections: new_selections,
227            cached_selected_count: None, // Invalidate cache
228        }
229    }
230
231    /// Checks if a physical index at a level is selected.
232    #[must_use]
233    pub fn is_selected(&self, level: usize, physical_idx: usize) -> bool {
234        self.level_selections
235            .get(level)
236            .is_some_and(|sel| sel.is_selected(physical_idx))
237    }
238
239    /// Computes and caches the selected logical row count.
240    ///
241    /// The computation considers parent-child relationships:
242    /// a child is only counted if its parent is selected.
243    pub fn selected_count(&mut self, multiplicities: &[Vec<usize>]) -> usize {
244        if let Some(count) = self.cached_selected_count {
245            return count;
246        }
247
248        let count = self.compute_selected_count(multiplicities);
249        self.cached_selected_count = Some(count);
250        count
251    }
252
253    /// Computes the selected count without caching.
254    fn compute_selected_count(&self, multiplicities: &[Vec<usize>]) -> usize {
255        if self.level_selections.is_empty() {
256            return 0;
257        }
258
259        // For single level, just count selected
260        if self.level_selections.len() == 1 {
261            return self.level_selections[0].selected_count();
262        }
263
264        // For multi-level, we need to propagate selection through levels
265        // Start with level 0 selection
266        let mut parent_selected: Vec<bool> = match &self.level_selections[0] {
267            LevelSelection::All { count } => vec![true; *count],
268            LevelSelection::Sparse(sel) => {
269                let max_idx = sel.iter().max().unwrap_or(0);
270                let mut selected = vec![false; max_idx + 1];
271                for idx in sel.iter() {
272                    selected[idx] = true;
273                }
274                selected
275            }
276        };
277
278        // Propagate through subsequent levels
279        for (level_sel, level_mults) in self
280            .level_selections
281            .iter()
282            .skip(1)
283            .zip(multiplicities.iter().skip(1))
284        {
285            let mut child_selected = Vec::new();
286            let mut child_idx = 0;
287
288            for (parent_idx, &mult) in level_mults.iter().enumerate() {
289                let parent_is_selected = parent_selected.get(parent_idx).copied().unwrap_or(false);
290
291                for _ in 0..mult {
292                    let child_is_selected = parent_is_selected && level_sel.is_selected(child_idx);
293                    child_selected.push(child_is_selected);
294                    child_idx += 1;
295                }
296            }
297
298            parent_selected = child_selected;
299        }
300
301        // Count final selected
302        parent_selected.iter().filter(|&&s| s).count()
303    }
304
305    /// Invalidates the cached selected count.
306    pub fn invalidate_cache(&mut self) {
307        self.cached_selected_count = None;
308    }
309}
310
311/// Unified chunk state tracking metadata for factorized execution.
312///
313/// This replaces scattered state tracking with a centralized structure
314/// that is updated incrementally rather than recomputed.
315///
316/// # Key Features
317///
318/// - **Cached multiplicities**: Computed once per chunk, reused for all aggregates
319/// - **Selection integration**: Supports lazy filtering without data copying
320/// - **Generation tracking**: Enables cache invalidation on structure changes
321#[derive(Debug, Clone)]
322pub struct ChunkState {
323    /// Factorization state of this chunk.
324    state: FactorizationState,
325    /// Selection for filtering without data copying.
326    /// When Some, only selected indices are "active".
327    selection: Option<FactorizedSelection>,
328    /// Cached path multiplicities (invalidated on structure change).
329    /// Key optimization: computed once, reused for all aggregates.
330    cached_multiplicities: Option<Arc<[usize]>>,
331    /// Generation counter for cache invalidation.
332    generation: u64,
333}
334
335impl ChunkState {
336    /// Creates a new flat chunk state.
337    #[must_use]
338    pub fn flat(row_count: usize) -> Self {
339        Self {
340            state: FactorizationState::Flat { row_count },
341            selection: None,
342            cached_multiplicities: None,
343            generation: 0,
344        }
345    }
346
347    /// Creates an unflat (factorized) chunk state.
348    #[must_use]
349    pub fn unflat(level_count: usize, logical_rows: usize) -> Self {
350        Self {
351            state: FactorizationState::Unflat {
352                level_count,
353                logical_rows,
354            },
355            selection: None,
356            cached_multiplicities: None,
357            generation: 0,
358        }
359    }
360
361    /// Returns the factorization state.
362    #[must_use]
363    pub fn factorization_state(&self) -> FactorizationState {
364        self.state
365    }
366
367    /// Returns true if this chunk is flat.
368    #[must_use]
369    pub fn is_flat(&self) -> bool {
370        self.state.is_flat()
371    }
372
373    /// Returns true if this chunk is factorized (unflat).
374    #[must_use]
375    pub fn is_factorized(&self) -> bool {
376        self.state.is_unflat()
377    }
378
379    /// Returns the logical row count.
380    ///
381    /// If a selection is active, this returns the base logical row count
382    /// (selection count must be computed separately with multiplicities).
383    #[must_use]
384    pub fn logical_row_count(&self) -> usize {
385        self.state.logical_row_count()
386    }
387
388    /// Returns the number of factorization levels.
389    #[must_use]
390    pub fn level_count(&self) -> usize {
391        self.state.level_count()
392    }
393
394    /// Returns the current generation (for cache validation).
395    #[must_use]
396    pub fn generation(&self) -> u64 {
397        self.generation
398    }
399
400    /// Returns the selection, if any.
401    #[must_use]
402    pub fn selection(&self) -> Option<&FactorizedSelection> {
403        self.selection.as_ref()
404    }
405
406    /// Returns mutable access to the selection.
407    pub fn selection_mut(&mut self) -> &mut Option<FactorizedSelection> {
408        &mut self.selection
409    }
410
411    /// Sets the selection.
412    pub fn set_selection(&mut self, selection: FactorizedSelection) {
413        self.selection = Some(selection);
414        // Don't invalidate multiplicity cache - selection is orthogonal
415    }
416
417    /// Clears the selection.
418    pub fn clear_selection(&mut self) {
419        self.selection = None;
420    }
421
422    /// Updates the state (e.g., after adding a level).
423    pub fn set_state(&mut self, state: FactorizationState) {
424        self.state = state;
425        self.invalidate_cache();
426    }
427
428    /// Invalidates cached data (call when structure changes).
429    pub fn invalidate_cache(&mut self) {
430        self.cached_multiplicities = None;
431        self.generation += 1;
432    }
433
434    /// Gets cached multiplicities, or computes and caches them.
435    ///
436    /// This is the key optimization: multiplicities are computed once
437    /// and reused for all aggregates (COUNT, SUM, AVG, etc.).
438    ///
439    /// # Arguments
440    ///
441    /// * `compute` - Function to compute multiplicities if not cached
442    ///
443    /// # Example
444    ///
445    /// ```ignore
446    /// let mults = state.get_or_compute_multiplicities(|| {
447    ///     chunk.compute_path_multiplicities_impl()
448    /// });
449    /// ```
450    pub fn get_or_compute_multiplicities<F>(&mut self, compute: F) -> Arc<[usize]>
451    where
452        F: FnOnce() -> Vec<usize>,
453    {
454        if let Some(ref cached) = self.cached_multiplicities {
455            return Arc::clone(cached);
456        }
457
458        let mults: Arc<[usize]> = compute().into();
459        self.cached_multiplicities = Some(Arc::clone(&mults));
460        mults
461    }
462
463    /// Returns cached multiplicities without computing.
464    ///
465    /// Returns None if not yet computed.
466    #[must_use]
467    pub fn cached_multiplicities(&self) -> Option<&Arc<[usize]>> {
468        self.cached_multiplicities.as_ref()
469    }
470
471    /// Sets the cached multiplicities directly.
472    ///
473    /// Useful when multiplicities are computed externally.
474    pub fn set_cached_multiplicities(&mut self, mults: Arc<[usize]>) {
475        self.cached_multiplicities = Some(mults);
476    }
477}
478
479impl Default for ChunkState {
480    fn default() -> Self {
481        Self::flat(0)
482    }
483}
484
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_factorization_state_flat() {
        let state = FactorizationState::Flat { row_count: 100 };
        assert!(state.is_flat());
        assert!(!state.is_unflat());
        assert_eq!(state.logical_row_count(), 100);
        assert_eq!(state.level_count(), 1);
    }

    #[test]
    fn test_factorization_state_unflat() {
        let state = FactorizationState::Unflat {
            level_count: 3,
            logical_rows: 1000,
        };
        assert!(!state.is_flat());
        assert!(state.is_unflat());
        assert_eq!(state.logical_row_count(), 1000);
        assert_eq!(state.level_count(), 3);
    }

    #[test]
    fn test_level_selection_all() {
        let sel = LevelSelection::all(10);
        assert_eq!(sel.selected_count(), 10);
        for i in 0..10 {
            assert!(sel.is_selected(i));
        }
        assert!(!sel.is_selected(10));
    }

    #[test]
    fn test_level_selection_filter() {
        let sel = LevelSelection::all(10);
        let filtered = sel.filter(|i| i % 2 == 0);
        assert_eq!(filtered.selected_count(), 5);
        assert!(filtered.is_selected(0));
        assert!(!filtered.is_selected(1));
        assert!(filtered.is_selected(2));
    }

    #[test]
    fn test_level_selection_filter_sparse() {
        // Start with sparse selection
        let sel = LevelSelection::from_predicate(10, |i| i < 5);
        assert_eq!(sel.selected_count(), 5);

        // Filter the sparse selection further
        let filtered = sel.filter(|i| i % 2 == 0);
        // Only 0, 2, 4 pass both filters
        assert_eq!(filtered.selected_count(), 3);
        assert!(filtered.is_selected(0));
        assert!(!filtered.is_selected(1));
        assert!(filtered.is_selected(2));
    }

    #[test]
    fn test_level_selection_iter_all() {
        let sel = LevelSelection::all(5);
        let indices: Vec<usize> = sel.iter().collect();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_level_selection_iter_sparse() {
        let sel = LevelSelection::from_predicate(10, |i| i % 3 == 0);
        let indices: Vec<usize> = sel.iter().collect();
        assert_eq!(indices, vec![0, 3, 6, 9]);
    }

    #[test]
    fn test_level_selection_from_predicate_all_selected() {
        // When predicate selects everything, should return All variant
        let sel = LevelSelection::from_predicate(5, |_| true);
        assert_eq!(sel.selected_count(), 5);
        match sel {
            LevelSelection::All { count } => assert_eq!(count, 5),
            LevelSelection::Sparse(_) => panic!("Expected All variant"),
        }
    }

    #[test]
    fn test_level_selection_from_predicate_partial() {
        // When predicate selects some, should return Sparse variant
        let sel = LevelSelection::from_predicate(10, |i| i < 3);
        assert_eq!(sel.selected_count(), 3);
        match sel {
            LevelSelection::Sparse(_) => {}
            LevelSelection::All { .. } => panic!("Expected Sparse variant"),
        }
    }

    #[test]
    fn test_factorized_selection_all() {
        let sel = FactorizedSelection::all(&[10, 100, 1000]);
        assert_eq!(sel.level_count(), 3);
        assert!(sel.is_selected(0, 5));
        assert!(sel.is_selected(1, 50));
        assert!(sel.is_selected(2, 500));
    }

    #[test]
    fn test_factorized_selection_new() {
        let level_sels = vec![
            LevelSelection::all(5),
            LevelSelection::from_predicate(10, |i| i < 3),
        ];
        let sel = FactorizedSelection::new(level_sels);

        assert_eq!(sel.level_count(), 2);
        assert!(sel.is_selected(0, 4));
        assert!(sel.is_selected(1, 2));
        assert!(!sel.is_selected(1, 5));
    }

    #[test]
    fn test_factorized_selection_filter_level() {
        let sel = FactorizedSelection::all(&[10, 100]);
        let filtered = sel.filter_level(1, |i| i < 50);

        assert!(filtered.is_selected(0, 5)); // Level 0 unchanged
        assert!(filtered.is_selected(1, 25)); // Level 1: 25 < 50
        assert!(!filtered.is_selected(1, 75)); // Level 1: 75 >= 50
    }

    #[test]
    fn test_factorized_selection_filter_level_invalid() {
        let sel = FactorizedSelection::all(&[10, 100]);

        // Filtering a non-existent level should not panic
        let filtered = sel.filter_level(5, |_| true);
        assert_eq!(filtered.level_count(), 2);
    }

    #[test]
    fn test_factorized_selection_is_selected_invalid_level() {
        let sel = FactorizedSelection::all(&[10]);
        assert!(!sel.is_selected(5, 0)); // Non-existent level
    }

    #[test]
    fn test_factorized_selection_level() {
        let sel = FactorizedSelection::all(&[10, 20]);

        let level0 = sel.level(0);
        assert!(level0.is_some());
        assert_eq!(level0.unwrap().selected_count(), 10);

        let level1 = sel.level(1);
        assert!(level1.is_some());
        assert_eq!(level1.unwrap().selected_count(), 20);

        assert!(sel.level(5).is_none());
    }

    #[test]
    fn test_factorized_selection_selected_count_single_level() {
        let mut sel = FactorizedSelection::all(&[10]);
        let multiplicities: Vec<Vec<usize>> = vec![vec![1; 10]];

        let count = sel.selected_count(&multiplicities);
        assert_eq!(count, 10);
    }

    #[test]
    fn test_factorized_selection_selected_count_multi_level() {
        let level_sels = vec![
            LevelSelection::all(2),                            // 2 parents
            LevelSelection::from_predicate(4, |i| i % 2 == 0), // Select indices 0, 2
        ];
        let mut sel = FactorizedSelection::new(level_sels);

        // Parent 0 has 2 children (indices 0, 1)
        // Parent 1 has 2 children (indices 2, 3)
        let multiplicities = vec![
            vec![1, 1], // Level 0 multiplicities
            vec![2, 2], // Level 1 multiplicities (children per parent)
        ];

        let count = sel.selected_count(&multiplicities);
        // Parent 0 is selected; of its children {0, 1}, only 0 passes.
        // Parent 1 is selected; of its children {2, 3}, only 2 passes.
        assert_eq!(count, 2);
    }

    #[test]
    fn test_factorized_selection_selected_count_sparse_root() {
        let level_sels = vec![
            LevelSelection::from_predicate(3, |i| i == 1), // Only parent 1
            LevelSelection::all(6),
        ];
        let mut sel = FactorizedSelection::new(level_sels);

        // Parent 0 -> child {0}, parent 1 -> {1, 2, 3}, parent 2 -> {4, 5}
        let multiplicities = vec![vec![1, 1, 1], vec![1, 3, 2]];

        // Only parent 1's three children survive.
        assert_eq!(sel.selected_count(&multiplicities), 3);
    }

    #[test]
    fn test_factorized_selection_selected_count_empty_sparse_root() {
        let level_sels = vec![
            LevelSelection::from_predicate(2, |_| false), // No parents selected
            LevelSelection::all(4),
        ];
        let mut sel = FactorizedSelection::new(level_sels);
        let multiplicities = vec![vec![1, 1], vec![2, 2]];

        assert_eq!(sel.selected_count(&multiplicities), 0);
    }

    #[test]
    fn test_factorized_selection_selected_count_cached() {
        let mut sel =
            FactorizedSelection::new(vec![LevelSelection::all(2), LevelSelection::all(4)]);
        let wide: Vec<Vec<usize>> = vec![vec![1, 1], vec![2, 2]];
        let narrow: Vec<Vec<usize>> = vec![vec![1, 1], vec![1, 1]];

        // First call computes: 2 parents x 2 children, all selected.
        let count1 = sel.selected_count(&wide);
        assert_eq!(count1, 4);

        // Second call returns the cached value, even though the narrow
        // multiplicities would yield 2 if recomputed.
        let count2 = sel.selected_count(&narrow);
        assert_eq!(count2, 4);
    }

    #[test]
    fn test_factorized_selection_selected_count_empty() {
        let mut sel = FactorizedSelection::all(&[]);
        let multiplicities: Vec<Vec<usize>> = vec![];

        assert_eq!(sel.selected_count(&multiplicities), 0);
    }

    #[test]
    fn test_factorized_selection_invalidate_cache() {
        let mut sel =
            FactorizedSelection::new(vec![LevelSelection::all(2), LevelSelection::all(4)]);
        let wide: Vec<Vec<usize>> = vec![vec![1, 1], vec![2, 2]];
        let narrow: Vec<Vec<usize>> = vec![vec![1, 1], vec![1, 1]];

        // Compute and cache against the wide structure.
        assert_eq!(sel.selected_count(&wide), 4);

        // Invalidate
        sel.invalidate_cache();

        // Must recompute against the new multiplicities instead of returning 4.
        assert_eq!(sel.selected_count(&narrow), 2);
    }

    #[test]
    fn test_chunk_state_flat() {
        let state = ChunkState::flat(100);
        assert!(state.is_flat());
        assert!(!state.is_factorized());
        assert_eq!(state.logical_row_count(), 100);
        assert_eq!(state.level_count(), 1);
    }

    #[test]
    fn test_chunk_state_unflat() {
        let state = ChunkState::unflat(3, 1000);
        assert!(!state.is_flat());
        assert!(state.is_factorized());
        assert_eq!(state.logical_row_count(), 1000);
        assert_eq!(state.level_count(), 3);
    }

    #[test]
    fn test_chunk_state_factorization_state() {
        let state = ChunkState::flat(50);
        let fs = state.factorization_state();
        assert!(fs.is_flat());
    }

    #[test]
    fn test_chunk_state_selection() {
        let mut state = ChunkState::unflat(2, 100);

        // Initially no selection
        assert!(state.selection().is_none());

        // Set selection
        let sel = FactorizedSelection::all(&[10, 100]);
        state.set_selection(sel);

        assert!(state.selection().is_some());
        assert_eq!(state.selection().unwrap().level_count(), 2);
    }

    #[test]
    fn test_chunk_state_selection_mut() {
        let mut state = ChunkState::unflat(2, 100);

        // Set selection
        let sel = FactorizedSelection::all(&[10, 100]);
        state.set_selection(sel);

        // Get mutable access
        let sel_mut = state.selection_mut();
        assert!(sel_mut.is_some());

        // Clear via mutable access
        *sel_mut = None;
        assert!(state.selection().is_none());
    }

    #[test]
    fn test_chunk_state_clear_selection() {
        let mut state = ChunkState::unflat(2, 100);

        let sel = FactorizedSelection::all(&[10, 100]);
        state.set_selection(sel);
        assert!(state.selection().is_some());

        state.clear_selection();
        assert!(state.selection().is_none());
    }

    #[test]
    fn test_chunk_state_set_state() {
        let mut state = ChunkState::flat(100);
        assert!(state.is_flat());
        assert_eq!(state.generation(), 0);

        state.set_state(FactorizationState::Unflat {
            level_count: 2,
            logical_rows: 200,
        });

        assert!(state.is_factorized());
        assert_eq!(state.logical_row_count(), 200);
        assert_eq!(state.generation(), 1); // Cache invalidated
    }

    #[test]
    fn test_chunk_state_caching() {
        let mut state = ChunkState::unflat(2, 100);

        // First call should compute
        let mut computed = false;
        let mults1 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![1, 2, 3, 4, 5]
        });
        assert!(computed);
        assert_eq!(mults1.len(), 5);

        // Second call should use cache
        computed = false;
        let mults2 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![99, 99, 99]
        });
        assert!(!computed);
        assert_eq!(mults2.len(), 5); // Same as before

        // After invalidation, should recompute
        state.invalidate_cache();
        let mults3 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![10, 20, 30]
        });
        assert!(computed);
        assert_eq!(mults3.len(), 3);
    }

    #[test]
    fn test_chunk_state_cached_multiplicities() {
        let mut state = ChunkState::unflat(2, 100);

        // Initially not cached
        assert!(state.cached_multiplicities().is_none());

        // Compute multiplicities
        let _ = state.get_or_compute_multiplicities(|| vec![1, 2, 3]);

        // Now cached
        assert!(state.cached_multiplicities().is_some());
        assert_eq!(state.cached_multiplicities().unwrap().len(), 3);
    }

    #[test]
    fn test_chunk_state_set_cached_multiplicities() {
        let mut state = ChunkState::unflat(2, 100);

        let mults: Arc<[usize]> = vec![5, 10, 15].into();
        state.set_cached_multiplicities(mults);

        assert!(state.cached_multiplicities().is_some());
        assert_eq!(state.cached_multiplicities().unwrap().len(), 3);
    }

    #[test]
    fn test_chunk_state_generation() {
        let mut state = ChunkState::flat(100);
        assert_eq!(state.generation(), 0);

        state.invalidate_cache();
        assert_eq!(state.generation(), 1);

        state.set_state(FactorizationState::Unflat {
            level_count: 2,
            logical_rows: 200,
        });
        assert_eq!(state.generation(), 2);
    }

    #[test]
    fn test_chunk_state_default() {
        let state = ChunkState::default();
        assert!(state.is_flat());
        assert_eq!(state.logical_row_count(), 0);
        assert_eq!(state.generation(), 0);
    }
}