Skip to main content

grafeo_core/execution/
chunk_state.rs

//! Unified chunk state tracking for factorized execution.
//!
//! This module provides [`ChunkState`] for centralized state management,
//! inspired by LadybugDB's `FStateType` pattern. Key benefits:
//!
//! - **Cached multiplicities**: Computed once, reused for all aggregates
//! - **Selection integration**: Lazy filtering without data copying
//! - **O(1) logical row count**: Cached, not recomputed
//!
//! # Example
//!
//! ```rust
//! use grafeo_core::execution::chunk_state::ChunkState;
//!
//! let mut state = ChunkState::unflat(3, 1000);
//!
//! // First access computes, subsequent accesses use cache
//! let mults = state.get_or_compute_multiplicities(|| vec![1, 2, 3]);
//! let mults2 = state.get_or_compute_multiplicities(|| unreachable!());
//! assert!(std::ptr::eq(mults.as_ptr(), mults2.as_ptr()));
//! ```
22
23use std::sync::Arc;
24
25use super::selection::SelectionVector;
26
/// Factorization state of a chunk (flat vs factorized).
///
/// Similar to LadybugDB's `FStateType`, this provides a single source
/// of truth for the chunk's factorization status.
///
/// The type is `Copy`, and every accessor is a `const fn`, so a state can be
/// inspected in constant contexts as well as at runtime.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum FactorizationState {
    /// All vectors are flat - one value per logical row.
    ///
    /// This is the state after flattening or for simple scans.
    Flat {
        /// Number of rows (physical = logical).
        row_count: usize,
    },
    /// One or more vectors are unflat - values grouped by parent.
    ///
    /// The chunk has multi-level structure.
    Unflat {
        /// Number of factorization levels.
        level_count: usize,
        /// Total logical row count (cached, not recomputed).
        logical_rows: usize,
    },
}

impl FactorizationState {
    /// Returns `true` for the [`Flat`](Self::Flat) variant.
    #[must_use]
    pub const fn is_flat(&self) -> bool {
        matches!(self, Self::Flat { .. })
    }

    /// Returns `true` for the [`Unflat`](Self::Unflat) (factorized) variant.
    #[must_use]
    pub const fn is_unflat(&self) -> bool {
        matches!(self, Self::Unflat { .. })
    }

    /// Returns the logical row count.
    ///
    /// This is `row_count` for a flat state and the stored `logical_rows`
    /// for an unflat state; nothing is recomputed.
    #[must_use]
    pub const fn logical_row_count(&self) -> usize {
        match *self {
            Self::Flat { row_count } => row_count,
            Self::Unflat { logical_rows, .. } => logical_rows,
        }
    }

    /// Returns the number of factorization levels.
    ///
    /// A flat state always reports `1`; an unflat state reports its stored
    /// `level_count`.
    #[must_use]
    pub const fn level_count(&self) -> usize {
        match *self {
            Self::Flat { .. } => 1,
            Self::Unflat { level_count, .. } => level_count,
        }
    }
}
83
84/// Selection state for a single factorization level.
85///
86/// Supports both sparse (for low selectivity) and dense (for high selectivity)
87/// representations to optimize memory usage.
88#[derive(Debug, Clone)]
89#[non_exhaustive]
90pub enum LevelSelection {
91    /// All values at this level are selected.
92    All {
93        /// Total count of values at this level.
94        count: usize,
95    },
96    /// Only specific indices are selected (for low selectivity).
97    ///
98    /// Uses `SelectionVector` which stores indices as `u16`.
99    Sparse(SelectionVector),
100}
101
102impl LevelSelection {
103    /// Creates a selection that selects all values.
104    #[must_use]
105    pub fn all(count: usize) -> Self {
106        Self::All { count }
107    }
108
109    /// Creates a sparse selection from a predicate.
110    #[must_use]
111    pub fn from_predicate<F>(count: usize, predicate: F) -> Self
112    where
113        F: Fn(usize) -> bool,
114    {
115        let selected = SelectionVector::from_predicate(count, predicate);
116        if selected.len() == count {
117            Self::All { count }
118        } else {
119            Self::Sparse(selected)
120        }
121    }
122
123    /// Returns the number of selected values.
124    #[must_use]
125    pub fn selected_count(&self) -> usize {
126        match self {
127            Self::All { count } => *count,
128            Self::Sparse(sel) => sel.len(),
129        }
130    }
131
132    /// Returns true if a physical index is selected.
133    #[must_use]
134    pub fn is_selected(&self, physical_idx: usize) -> bool {
135        match self {
136            Self::All { count } => physical_idx < *count,
137            Self::Sparse(sel) => sel.contains(physical_idx),
138        }
139    }
140
141    /// Filters this selection with a predicate, returning a new selection.
142    #[must_use]
143    pub fn filter<F>(&self, predicate: F) -> Self
144    where
145        F: Fn(usize) -> bool,
146    {
147        match self {
148            Self::All { count } => Self::from_predicate(*count, predicate),
149            Self::Sparse(sel) => {
150                let filtered = sel.filter(predicate);
151                Self::Sparse(filtered)
152            }
153        }
154    }
155
156    /// Returns an iterator over selected indices.
157    pub fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
158        match self {
159            Self::All { count } => Box::new(0..*count),
160            Self::Sparse(sel) => Box::new(sel.iter()),
161        }
162    }
163}
164
165impl<'a> IntoIterator for &'a LevelSelection {
166    type Item = usize;
167    type IntoIter = Box<dyn Iterator<Item = usize> + 'a>;
168
169    fn into_iter(self) -> Self::IntoIter {
170        self.iter()
171    }
172}
173
174/// Hierarchical selection for factorized data.
175///
176/// Tracks selections at each factorization level, enabling filtering
177/// without flattening or copying data.
178#[derive(Debug, Clone)]
179pub struct FactorizedSelection {
180    /// Selection state per level (level 0 = sources, higher = more nested).
181    level_selections: Vec<LevelSelection>,
182    /// Cached logical row count after selection (lazily computed).
183    cached_selected_count: Option<usize>,
184}
185
186impl FactorizedSelection {
187    /// Creates a selection that selects all values at all levels.
188    #[must_use]
189    pub fn all(level_counts: &[usize]) -> Self {
190        let level_selections = level_counts
191            .iter()
192            .map(|&count| LevelSelection::all(count))
193            .collect();
194        Self {
195            level_selections,
196            cached_selected_count: None,
197        }
198    }
199
200    /// Creates a selection from level selections.
201    #[must_use]
202    pub fn new(level_selections: Vec<LevelSelection>) -> Self {
203        Self {
204            level_selections,
205            cached_selected_count: None,
206        }
207    }
208
209    /// Returns the number of levels.
210    #[must_use]
211    pub fn level_count(&self) -> usize {
212        self.level_selections.len()
213    }
214
215    /// Gets the selection for a specific level.
216    #[must_use]
217    pub fn level(&self, level: usize) -> Option<&LevelSelection> {
218        self.level_selections.get(level)
219    }
220
221    /// Filters at a specific level using a predicate.
222    ///
223    /// Returns a new selection with the filter applied.
224    /// This is O(n_physical) where n is the physical size of that level,
225    /// not O(n_logical) where n is the logical row count.
226    #[must_use]
227    pub fn filter_level<F>(&self, level: usize, predicate: F) -> Self
228    where
229        F: Fn(usize) -> bool,
230    {
231        let mut new_selections = self.level_selections.clone();
232
233        if let Some(sel) = new_selections.get_mut(level) {
234            *sel = sel.filter(predicate);
235        }
236
237        Self {
238            level_selections: new_selections,
239            cached_selected_count: None, // Invalidate cache
240        }
241    }
242
243    /// Checks if a physical index at a level is selected.
244    #[must_use]
245    pub fn is_selected(&self, level: usize, physical_idx: usize) -> bool {
246        self.level_selections
247            .get(level)
248            .is_some_and(|sel| sel.is_selected(physical_idx))
249    }
250
251    /// Computes and caches the selected logical row count.
252    ///
253    /// The computation considers parent-child relationships:
254    /// a child is only counted if its parent is selected.
255    pub fn selected_count(&mut self, multiplicities: &[Vec<usize>]) -> usize {
256        if let Some(count) = self.cached_selected_count {
257            return count;
258        }
259
260        let count = self.compute_selected_count(multiplicities);
261        self.cached_selected_count = Some(count);
262        count
263    }
264
265    /// Computes the selected count without caching.
266    fn compute_selected_count(&self, multiplicities: &[Vec<usize>]) -> usize {
267        if self.level_selections.is_empty() {
268            return 0;
269        }
270
271        // For single level, just count selected
272        if self.level_selections.len() == 1 {
273            return self.level_selections[0].selected_count();
274        }
275
276        // For multi-level, we need to propagate selection through levels
277        // Start with level 0 selection
278        let mut parent_selected: Vec<bool> = match &self.level_selections[0] {
279            LevelSelection::All { count } => vec![true; *count],
280            LevelSelection::Sparse(sel) => {
281                let max_idx = sel.iter().max().unwrap_or(0);
282                let mut selected = vec![false; max_idx + 1];
283                for idx in sel.iter() {
284                    selected[idx] = true;
285                }
286                selected
287            }
288        };
289
290        // Propagate through subsequent levels
291        for (level_sel, level_mults) in self
292            .level_selections
293            .iter()
294            .skip(1)
295            .zip(multiplicities.iter().skip(1))
296        {
297            let mut child_selected = Vec::new();
298            let mut child_idx = 0;
299
300            for (parent_idx, &mult) in level_mults.iter().enumerate() {
301                let parent_is_selected = parent_selected.get(parent_idx).copied().unwrap_or(false);
302
303                for _ in 0..mult {
304                    let child_is_selected = parent_is_selected && level_sel.is_selected(child_idx);
305                    child_selected.push(child_is_selected);
306                    child_idx += 1;
307                }
308            }
309
310            parent_selected = child_selected;
311        }
312
313        // Count final selected
314        parent_selected.iter().filter(|&&s| s).count()
315    }
316
317    /// Invalidates the cached selected count.
318    pub fn invalidate_cache(&mut self) {
319        self.cached_selected_count = None;
320    }
321}
322
323/// Unified chunk state tracking metadata for factorized execution.
324///
325/// This replaces scattered state tracking with a centralized structure
326/// that is updated incrementally rather than recomputed.
327///
328/// # Key Features
329///
330/// - **Cached multiplicities**: Computed once per chunk, reused for all aggregates
331/// - **Selection integration**: Supports lazy filtering without data copying
332/// - **Generation tracking**: Enables cache invalidation on structure changes
333#[derive(Debug, Clone)]
334pub struct ChunkState {
335    /// Factorization state of this chunk.
336    state: FactorizationState,
337    /// Selection for filtering without data copying.
338    /// When Some, only selected indices are "active".
339    selection: Option<FactorizedSelection>,
340    /// Cached path multiplicities (invalidated on structure change).
341    /// Key optimization: computed once, reused for all aggregates.
342    cached_multiplicities: Option<Arc<[usize]>>,
343    /// Generation counter for cache invalidation.
344    generation: u64,
345}
346
347impl ChunkState {
348    /// Creates a new flat chunk state.
349    #[must_use]
350    pub fn flat(row_count: usize) -> Self {
351        Self {
352            state: FactorizationState::Flat { row_count },
353            selection: None,
354            cached_multiplicities: None,
355            generation: 0,
356        }
357    }
358
359    /// Creates an unflat (factorized) chunk state.
360    #[must_use]
361    pub fn unflat(level_count: usize, logical_rows: usize) -> Self {
362        Self {
363            state: FactorizationState::Unflat {
364                level_count,
365                logical_rows,
366            },
367            selection: None,
368            cached_multiplicities: None,
369            generation: 0,
370        }
371    }
372
373    /// Returns the factorization state.
374    #[must_use]
375    pub fn factorization_state(&self) -> FactorizationState {
376        self.state
377    }
378
379    /// Returns true if this chunk is flat.
380    #[must_use]
381    pub fn is_flat(&self) -> bool {
382        self.state.is_flat()
383    }
384
385    /// Returns true if this chunk is factorized (unflat).
386    #[must_use]
387    pub fn is_factorized(&self) -> bool {
388        self.state.is_unflat()
389    }
390
391    /// Returns the logical row count.
392    ///
393    /// If a selection is active, this returns the base logical row count
394    /// (selection count must be computed separately with multiplicities).
395    #[must_use]
396    pub fn logical_row_count(&self) -> usize {
397        self.state.logical_row_count()
398    }
399
400    /// Returns the number of factorization levels.
401    #[must_use]
402    pub fn level_count(&self) -> usize {
403        self.state.level_count()
404    }
405
406    /// Returns the current generation (for cache validation).
407    #[must_use]
408    pub fn generation(&self) -> u64 {
409        self.generation
410    }
411
412    /// Returns the selection, if any.
413    #[must_use]
414    pub fn selection(&self) -> Option<&FactorizedSelection> {
415        self.selection.as_ref()
416    }
417
418    /// Returns mutable access to the selection.
419    pub fn selection_mut(&mut self) -> &mut Option<FactorizedSelection> {
420        &mut self.selection
421    }
422
423    /// Sets the selection.
424    pub fn set_selection(&mut self, selection: FactorizedSelection) {
425        self.selection = Some(selection);
426        // Don't invalidate multiplicity cache - selection is orthogonal
427    }
428
429    /// Clears the selection.
430    pub fn clear_selection(&mut self) {
431        self.selection = None;
432    }
433
434    /// Updates the state (e.g., after adding a level).
435    pub fn set_state(&mut self, state: FactorizationState) {
436        self.state = state;
437        self.invalidate_cache();
438    }
439
440    /// Invalidates cached data (call when structure changes).
441    pub fn invalidate_cache(&mut self) {
442        self.cached_multiplicities = None;
443        self.generation += 1;
444    }
445
446    /// Gets cached multiplicities, or computes and caches them.
447    ///
448    /// This is the key optimization: multiplicities are computed once
449    /// and reused for all aggregates (COUNT, SUM, AVG, etc.).
450    ///
451    /// # Arguments
452    ///
453    /// * `compute` - Function to compute multiplicities if not cached
454    ///
455    /// # Example
456    ///
457    /// ```rust
458    /// # use grafeo_core::execution::chunk_state::ChunkState;
459    /// # let mut state = ChunkState::unflat(2, 100);
460    /// let mults = state.get_or_compute_multiplicities(|| {
461    ///     vec![1; 100] // compute multiplicities
462    /// });
463    /// ```
464    pub fn get_or_compute_multiplicities<F>(&mut self, compute: F) -> Arc<[usize]>
465    where
466        F: FnOnce() -> Vec<usize>,
467    {
468        if let Some(ref cached) = self.cached_multiplicities {
469            return Arc::clone(cached);
470        }
471
472        let mults: Arc<[usize]> = compute().into();
473        self.cached_multiplicities = Some(Arc::clone(&mults));
474        mults
475    }
476
477    /// Returns cached multiplicities without computing.
478    ///
479    /// Returns None if not yet computed.
480    #[must_use]
481    pub fn cached_multiplicities(&self) -> Option<&Arc<[usize]>> {
482        self.cached_multiplicities.as_ref()
483    }
484
485    /// Sets the cached multiplicities directly.
486    ///
487    /// Useful when multiplicities are computed externally.
488    pub fn set_cached_multiplicities(&mut self, mults: Arc<[usize]>) {
489        self.cached_multiplicities = Some(mults);
490    }
491}
492
493impl Default for ChunkState {
494    fn default() -> Self {
495        Self::flat(0)
496    }
497}
498
// Unit tests for the chunk-state types. The caching tests use two-level
// selections with two different multiplicity tables so that a cache hit and a
// fresh computation produce different, checkable values.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_factorization_state_flat() {
        let state = FactorizationState::Flat { row_count: 100 };
        assert!(state.is_flat());
        assert!(!state.is_unflat());
        assert_eq!(state.logical_row_count(), 100);
        assert_eq!(state.level_count(), 1);
    }

    #[test]
    fn test_factorization_state_unflat() {
        let state = FactorizationState::Unflat {
            level_count: 3,
            logical_rows: 1000,
        };
        assert!(!state.is_flat());
        assert!(state.is_unflat());
        assert_eq!(state.logical_row_count(), 1000);
        assert_eq!(state.level_count(), 3);
    }

    #[test]
    fn test_level_selection_all() {
        let sel = LevelSelection::all(10);
        assert_eq!(sel.selected_count(), 10);
        for i in 0..10 {
            assert!(sel.is_selected(i));
        }
        assert!(!sel.is_selected(10));
    }

    #[test]
    fn test_level_selection_filter() {
        let sel = LevelSelection::all(10);
        let filtered = sel.filter(|i| i % 2 == 0);
        assert_eq!(filtered.selected_count(), 5);
        assert!(filtered.is_selected(0));
        assert!(!filtered.is_selected(1));
        assert!(filtered.is_selected(2));
    }

    #[test]
    fn test_level_selection_filter_sparse() {
        // Start with sparse selection
        let sel = LevelSelection::from_predicate(10, |i| i < 5);
        assert_eq!(sel.selected_count(), 5);

        // Filter the sparse selection further
        let filtered = sel.filter(|i| i % 2 == 0);
        // Only 0, 2, 4 pass both filters
        assert_eq!(filtered.selected_count(), 3);
        assert!(filtered.is_selected(0));
        assert!(!filtered.is_selected(1));
        assert!(filtered.is_selected(2));
    }

    #[test]
    fn test_level_selection_iter_all() {
        let sel = LevelSelection::all(5);
        let indices: Vec<usize> = sel.iter().collect();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_level_selection_iter_sparse() {
        let sel = LevelSelection::from_predicate(10, |i| i % 3 == 0);
        let indices: Vec<usize> = sel.iter().collect();
        assert_eq!(indices, vec![0, 3, 6, 9]);
    }

    #[test]
    fn test_level_selection_from_predicate_all_selected() {
        // When predicate selects everything, should return All variant
        let sel = LevelSelection::from_predicate(5, |_| true);
        assert_eq!(sel.selected_count(), 5);
        match sel {
            LevelSelection::All { count } => assert_eq!(count, 5),
            LevelSelection::Sparse(_) => panic!("Expected All variant"),
        }
    }

    #[test]
    fn test_level_selection_from_predicate_partial() {
        // When predicate selects some, should return Sparse variant
        let sel = LevelSelection::from_predicate(10, |i| i < 3);
        assert_eq!(sel.selected_count(), 3);
        match sel {
            LevelSelection::Sparse(_) => {}
            LevelSelection::All { .. } => panic!("Expected Sparse variant"),
        }
    }

    #[test]
    fn test_factorized_selection_all() {
        let sel = FactorizedSelection::all(&[10, 100, 1000]);
        assert_eq!(sel.level_count(), 3);
        assert!(sel.is_selected(0, 5));
        assert!(sel.is_selected(1, 50));
        assert!(sel.is_selected(2, 500));
    }

    #[test]
    fn test_factorized_selection_new() {
        let level_sels = vec![
            LevelSelection::all(5),
            LevelSelection::from_predicate(10, |i| i < 3),
        ];
        let sel = FactorizedSelection::new(level_sels);

        assert_eq!(sel.level_count(), 2);
        assert!(sel.is_selected(0, 4));
        assert!(sel.is_selected(1, 2));
        assert!(!sel.is_selected(1, 5));
    }

    #[test]
    fn test_factorized_selection_filter_level() {
        let sel = FactorizedSelection::all(&[10, 100]);
        let filtered = sel.filter_level(1, |i| i < 50);

        assert!(filtered.is_selected(0, 5)); // Level 0 unchanged
        assert!(filtered.is_selected(1, 25)); // Level 1: 25 < 50
        assert!(!filtered.is_selected(1, 75)); // Level 1: 75 >= 50

        // The source selection must be left untouched.
        assert!(sel.is_selected(1, 75));
    }

    #[test]
    fn test_factorized_selection_filter_level_invalid() {
        let sel = FactorizedSelection::all(&[10, 100]);

        // Filtering a non-existent level should not panic
        let filtered = sel.filter_level(5, |_| true);
        assert_eq!(filtered.level_count(), 2);
    }

    #[test]
    fn test_factorized_selection_is_selected_invalid_level() {
        let sel = FactorizedSelection::all(&[10]);
        assert!(!sel.is_selected(5, 0)); // Non-existent level
    }

    #[test]
    fn test_factorized_selection_level() {
        let sel = FactorizedSelection::all(&[10, 20]);

        let level0 = sel.level(0);
        assert!(level0.is_some());
        assert_eq!(level0.unwrap().selected_count(), 10);

        let level1 = sel.level(1);
        assert!(level1.is_some());
        assert_eq!(level1.unwrap().selected_count(), 20);

        assert!(sel.level(5).is_none());
    }

    #[test]
    fn test_factorized_selection_selected_count_single_level() {
        let mut sel = FactorizedSelection::all(&[10]);
        let multiplicities: Vec<Vec<usize>> = vec![vec![1; 10]];

        let count = sel.selected_count(&multiplicities);
        assert_eq!(count, 10);
    }

    #[test]
    fn test_factorized_selection_selected_count_multi_level() {
        let level_sels = vec![
            LevelSelection::all(2),                            // 2 parents
            LevelSelection::from_predicate(4, |i| i % 2 == 0), // Select indices 0, 2
        ];
        let mut sel = FactorizedSelection::new(level_sels);

        // Parent 0 has 2 children (indices 0, 1)
        // Parent 1 has 2 children (indices 2, 3)
        let multiplicities: Vec<Vec<usize>> = vec![
            vec![1, 1], // Level 0 multiplicities
            vec![2, 2], // Level 1 multiplicities (children per parent)
        ];

        let count = sel.selected_count(&multiplicities);
        // Parent 0 is selected, child 0 is selected (1 passes)
        // Parent 1 is selected, child 2 is selected (1 passes)
        assert_eq!(count, 2);
    }

    #[test]
    fn test_factorized_selection_selected_count_cached() {
        let mut sel = FactorizedSelection::all(&[2, 4]);
        let two_children_each: Vec<Vec<usize>> = vec![vec![1, 1], vec![2, 2]];
        let one_child_each: Vec<Vec<usize>> = vec![vec![1, 1], vec![1, 1]];

        // First call computes: both parents selected, two children each.
        assert_eq!(sel.selected_count(&two_children_each), 4);

        // Second call must be served from the cache: recomputing against
        // `one_child_each` would yield 2 instead.
        assert_eq!(sel.selected_count(&one_child_each), 4);
    }

    #[test]
    fn test_factorized_selection_selected_count_empty() {
        let mut sel = FactorizedSelection::all(&[]);
        let multiplicities: Vec<Vec<usize>> = vec![];

        assert_eq!(sel.selected_count(&multiplicities), 0);
    }

    #[test]
    fn test_factorized_selection_invalidate_cache() {
        let mut sel = FactorizedSelection::all(&[2, 4]);
        let two_children_each: Vec<Vec<usize>> = vec![vec![1, 1], vec![2, 2]];
        let one_child_each: Vec<Vec<usize>> = vec![vec![1, 1], vec![1, 1]];

        // Compute and cache
        assert_eq!(sel.selected_count(&two_children_each), 4);

        // Invalidate
        sel.invalidate_cache();

        // The next call recomputes, so the new multiplicities take effect.
        assert_eq!(sel.selected_count(&one_child_each), 2);
    }

    #[test]
    fn test_chunk_state_flat() {
        let state = ChunkState::flat(100);
        assert!(state.is_flat());
        assert!(!state.is_factorized());
        assert_eq!(state.logical_row_count(), 100);
        assert_eq!(state.level_count(), 1);
    }

    #[test]
    fn test_chunk_state_unflat() {
        let state = ChunkState::unflat(3, 1000);
        assert!(!state.is_flat());
        assert!(state.is_factorized());
        assert_eq!(state.logical_row_count(), 1000);
        assert_eq!(state.level_count(), 3);
    }

    #[test]
    fn test_chunk_state_factorization_state() {
        let state = ChunkState::flat(50);
        let fs = state.factorization_state();
        assert!(fs.is_flat());
    }

    #[test]
    fn test_chunk_state_selection() {
        let mut state = ChunkState::unflat(2, 100);

        // Initially no selection
        assert!(state.selection().is_none());

        // Set selection
        let sel = FactorizedSelection::all(&[10, 100]);
        state.set_selection(sel);

        assert!(state.selection().is_some());
        assert_eq!(state.selection().unwrap().level_count(), 2);
    }

    #[test]
    fn test_chunk_state_selection_mut() {
        let mut state = ChunkState::unflat(2, 100);

        // Set selection
        let sel = FactorizedSelection::all(&[10, 100]);
        state.set_selection(sel);

        // Get mutable access
        let sel_mut = state.selection_mut();
        assert!(sel_mut.is_some());

        // Clear via mutable access
        *sel_mut = None;
        assert!(state.selection().is_none());
    }

    #[test]
    fn test_chunk_state_clear_selection() {
        let mut state = ChunkState::unflat(2, 100);

        let sel = FactorizedSelection::all(&[10, 100]);
        state.set_selection(sel);
        assert!(state.selection().is_some());

        state.clear_selection();
        assert!(state.selection().is_none());
    }

    #[test]
    fn test_chunk_state_set_state() {
        let mut state = ChunkState::flat(100);
        assert!(state.is_flat());
        assert_eq!(state.generation(), 0);

        state.set_state(FactorizationState::Unflat {
            level_count: 2,
            logical_rows: 200,
        });

        assert!(state.is_factorized());
        assert_eq!(state.logical_row_count(), 200);
        assert_eq!(state.generation(), 1); // Cache invalidated
    }

    #[test]
    fn test_chunk_state_caching() {
        let mut state = ChunkState::unflat(2, 100);

        // First call should compute
        let mut computed = false;
        let mults1 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![1, 2, 3, 4, 5]
        });
        assert!(computed);
        assert_eq!(mults1.len(), 5);

        // Second call should use cache
        computed = false;
        let mults2 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![99, 99, 99]
        });
        assert!(!computed);
        assert_eq!(mults2.len(), 5); // Same as before
        // Both handles must share one allocation, not merely equal contents.
        assert!(Arc::ptr_eq(&mults1, &mults2));

        // After invalidation, should recompute
        state.invalidate_cache();
        let mults3 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![10, 20, 30]
        });
        assert!(computed);
        assert_eq!(mults3.len(), 3);
        assert_eq!(mults3.to_vec(), vec![10, 20, 30]);
    }

    #[test]
    fn test_chunk_state_cached_multiplicities() {
        let mut state = ChunkState::unflat(2, 100);

        // Initially not cached
        assert!(state.cached_multiplicities().is_none());

        // Compute multiplicities
        let _ = state.get_or_compute_multiplicities(|| vec![1, 2, 3]);

        // Now cached
        assert!(state.cached_multiplicities().is_some());
        assert_eq!(state.cached_multiplicities().unwrap().len(), 3);
    }

    #[test]
    fn test_chunk_state_set_cached_multiplicities() {
        let mut state = ChunkState::unflat(2, 100);

        let mults: Arc<[usize]> = vec![5, 10, 15].into();
        state.set_cached_multiplicities(mults);

        assert!(state.cached_multiplicities().is_some());
        assert_eq!(state.cached_multiplicities().unwrap().len(), 3);
        assert_eq!(
            state.cached_multiplicities().unwrap().to_vec(),
            vec![5, 10, 15]
        );
    }

    #[test]
    fn test_chunk_state_generation() {
        let mut state = ChunkState::flat(100);
        assert_eq!(state.generation(), 0);

        state.invalidate_cache();
        assert_eq!(state.generation(), 1);

        state.set_state(FactorizationState::Unflat {
            level_count: 2,
            logical_rows: 200,
        });
        assert_eq!(state.generation(), 2);
    }

    #[test]
    fn test_chunk_state_default() {
        let state = ChunkState::default();
        assert!(state.is_flat());
        assert_eq!(state.logical_row_count(), 0);
        assert_eq!(state.generation(), 0);
    }
}