Skip to main content

grafeo_core/execution/
chunk_state.rs

//! Unified chunk state tracking for factorized execution.
//!
//! This module provides [`ChunkState`] for centralized state management,
//! inspired by LadybugDB's `FStateType` pattern. Key benefits:
//!
//! - **Cached multiplicities**: computed once, reused for all aggregates
//! - **Selection integration**: lazy filtering without copying data
//! - **O(1) logical row count**: cached, not recomputed
//!
//! # Example
//!
//! ```rust
//! use grafeo_core::execution::chunk_state::ChunkState;
//!
//! let mut state = ChunkState::unflat(3, 1000);
//!
//! // The first access computes; subsequent accesses reuse the cache.
//! let mults = state.get_or_compute_multiplicities(|| vec![1, 2, 3]);
//! let mults2 = state.get_or_compute_multiplicities(|| unreachable!());
//! assert!(std::ptr::eq(mults.as_ptr(), mults2.as_ptr()));
//! ```
22
23use std::sync::Arc;
24
25use super::selection::SelectionVector;
26
/// Factorization state of a chunk: either flat or factorized (unflat).
///
/// Modeled after LadybugDB's `FStateType`, this enum is the single source of
/// truth for whether a chunk is factorized and how many logical rows it holds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FactorizationState {
    /// Every vector is flat, holding one value per logical row.
    ///
    /// This is the state of simple scans and of chunks that were flattened.
    Flat {
        /// Row count; physical and logical counts coincide.
        row_count: usize,
    },
    /// At least one vector is unflat, with values grouped by parent.
    ///
    /// The chunk has a multi-level structure.
    Unflat {
        /// Number of factorization levels.
        level_count: usize,
        /// Total logical row count, stored so it is never recomputed.
        logical_rows: usize,
    },
}

impl FactorizationState {
    /// Reports whether the chunk is flat.
    #[must_use]
    pub fn is_flat(&self) -> bool {
        // Only two variants exist, so "flat" is exactly "not unflat".
        !self.is_unflat()
    }

    /// Reports whether the chunk is factorized (unflat).
    #[must_use]
    pub fn is_unflat(&self) -> bool {
        matches!(*self, Self::Unflat { .. })
    }

    /// Logical number of rows represented by the chunk.
    #[must_use]
    pub fn logical_row_count(&self) -> usize {
        match *self {
            Self::Flat { row_count: rows } | Self::Unflat { logical_rows: rows, .. } => rows,
        }
    }

    /// Number of factorization levels; a flat chunk counts as a single level.
    #[must_use]
    pub fn level_count(&self) -> usize {
        if let Self::Unflat { level_count, .. } = *self {
            level_count
        } else {
            1
        }
    }
}
82
83/// Selection state for a single factorization level.
84///
85/// Supports both sparse (for low selectivity) and dense (for high selectivity)
86/// representations to optimize memory usage.
87#[derive(Debug, Clone)]
88pub enum LevelSelection {
89    /// All values at this level are selected.
90    All {
91        /// Total count of values at this level.
92        count: usize,
93    },
94    /// Only specific indices are selected (for low selectivity).
95    ///
96    /// Uses `SelectionVector` which stores indices as `u16`.
97    Sparse(SelectionVector),
98}
99
100impl LevelSelection {
101    /// Creates a selection that selects all values.
102    #[must_use]
103    pub fn all(count: usize) -> Self {
104        Self::All { count }
105    }
106
107    /// Creates a sparse selection from a predicate.
108    #[must_use]
109    pub fn from_predicate<F>(count: usize, predicate: F) -> Self
110    where
111        F: Fn(usize) -> bool,
112    {
113        let selected = SelectionVector::from_predicate(count, predicate);
114        if selected.len() == count {
115            Self::All { count }
116        } else {
117            Self::Sparse(selected)
118        }
119    }
120
121    /// Returns the number of selected values.
122    #[must_use]
123    pub fn selected_count(&self) -> usize {
124        match self {
125            Self::All { count } => *count,
126            Self::Sparse(sel) => sel.len(),
127        }
128    }
129
130    /// Returns true if a physical index is selected.
131    #[must_use]
132    pub fn is_selected(&self, physical_idx: usize) -> bool {
133        match self {
134            Self::All { count } => physical_idx < *count,
135            Self::Sparse(sel) => sel.contains(physical_idx),
136        }
137    }
138
139    /// Filters this selection with a predicate, returning a new selection.
140    #[must_use]
141    pub fn filter<F>(&self, predicate: F) -> Self
142    where
143        F: Fn(usize) -> bool,
144    {
145        match self {
146            Self::All { count } => Self::from_predicate(*count, predicate),
147            Self::Sparse(sel) => {
148                let filtered = sel.filter(predicate);
149                Self::Sparse(filtered)
150            }
151        }
152    }
153
154    /// Returns an iterator over selected indices.
155    pub fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
156        match self {
157            Self::All { count } => Box::new(0..*count),
158            Self::Sparse(sel) => Box::new(sel.iter()),
159        }
160    }
161}
162
163impl<'a> IntoIterator for &'a LevelSelection {
164    type Item = usize;
165    type IntoIter = Box<dyn Iterator<Item = usize> + 'a>;
166
167    fn into_iter(self) -> Self::IntoIter {
168        self.iter()
169    }
170}
171
172/// Hierarchical selection for factorized data.
173///
174/// Tracks selections at each factorization level, enabling filtering
175/// without flattening or copying data.
176#[derive(Debug, Clone)]
177pub struct FactorizedSelection {
178    /// Selection state per level (level 0 = sources, higher = more nested).
179    level_selections: Vec<LevelSelection>,
180    /// Cached logical row count after selection (lazily computed).
181    cached_selected_count: Option<usize>,
182}
183
184impl FactorizedSelection {
185    /// Creates a selection that selects all values at all levels.
186    #[must_use]
187    pub fn all(level_counts: &[usize]) -> Self {
188        let level_selections = level_counts
189            .iter()
190            .map(|&count| LevelSelection::all(count))
191            .collect();
192        Self {
193            level_selections,
194            cached_selected_count: None,
195        }
196    }
197
198    /// Creates a selection from level selections.
199    #[must_use]
200    pub fn new(level_selections: Vec<LevelSelection>) -> Self {
201        Self {
202            level_selections,
203            cached_selected_count: None,
204        }
205    }
206
207    /// Returns the number of levels.
208    #[must_use]
209    pub fn level_count(&self) -> usize {
210        self.level_selections.len()
211    }
212
213    /// Gets the selection for a specific level.
214    #[must_use]
215    pub fn level(&self, level: usize) -> Option<&LevelSelection> {
216        self.level_selections.get(level)
217    }
218
219    /// Filters at a specific level using a predicate.
220    ///
221    /// Returns a new selection with the filter applied.
222    /// This is O(n_physical) where n is the physical size of that level,
223    /// not O(n_logical) where n is the logical row count.
224    #[must_use]
225    pub fn filter_level<F>(&self, level: usize, predicate: F) -> Self
226    where
227        F: Fn(usize) -> bool,
228    {
229        let mut new_selections = self.level_selections.clone();
230
231        if let Some(sel) = new_selections.get_mut(level) {
232            *sel = sel.filter(predicate);
233        }
234
235        Self {
236            level_selections: new_selections,
237            cached_selected_count: None, // Invalidate cache
238        }
239    }
240
241    /// Checks if a physical index at a level is selected.
242    #[must_use]
243    pub fn is_selected(&self, level: usize, physical_idx: usize) -> bool {
244        self.level_selections
245            .get(level)
246            .is_some_and(|sel| sel.is_selected(physical_idx))
247    }
248
249    /// Computes and caches the selected logical row count.
250    ///
251    /// The computation considers parent-child relationships:
252    /// a child is only counted if its parent is selected.
253    pub fn selected_count(&mut self, multiplicities: &[Vec<usize>]) -> usize {
254        if let Some(count) = self.cached_selected_count {
255            return count;
256        }
257
258        let count = self.compute_selected_count(multiplicities);
259        self.cached_selected_count = Some(count);
260        count
261    }
262
263    /// Computes the selected count without caching.
264    fn compute_selected_count(&self, multiplicities: &[Vec<usize>]) -> usize {
265        if self.level_selections.is_empty() {
266            return 0;
267        }
268
269        // For single level, just count selected
270        if self.level_selections.len() == 1 {
271            return self.level_selections[0].selected_count();
272        }
273
274        // For multi-level, we need to propagate selection through levels
275        // Start with level 0 selection
276        let mut parent_selected: Vec<bool> = match &self.level_selections[0] {
277            LevelSelection::All { count } => vec![true; *count],
278            LevelSelection::Sparse(sel) => {
279                let max_idx = sel.iter().max().unwrap_or(0);
280                let mut selected = vec![false; max_idx + 1];
281                for idx in sel.iter() {
282                    selected[idx] = true;
283                }
284                selected
285            }
286        };
287
288        // Propagate through subsequent levels
289        for (level_sel, level_mults) in self
290            .level_selections
291            .iter()
292            .skip(1)
293            .zip(multiplicities.iter().skip(1))
294        {
295            let mut child_selected = Vec::new();
296            let mut child_idx = 0;
297
298            for (parent_idx, &mult) in level_mults.iter().enumerate() {
299                let parent_is_selected = parent_selected.get(parent_idx).copied().unwrap_or(false);
300
301                for _ in 0..mult {
302                    let child_is_selected = parent_is_selected && level_sel.is_selected(child_idx);
303                    child_selected.push(child_is_selected);
304                    child_idx += 1;
305                }
306            }
307
308            parent_selected = child_selected;
309        }
310
311        // Count final selected
312        parent_selected.iter().filter(|&&s| s).count()
313    }
314
315    /// Invalidates the cached selected count.
316    pub fn invalidate_cache(&mut self) {
317        self.cached_selected_count = None;
318    }
319}
320
321/// Unified chunk state tracking metadata for factorized execution.
322///
323/// This replaces scattered state tracking with a centralized structure
324/// that is updated incrementally rather than recomputed.
325///
326/// # Key Features
327///
328/// - **Cached multiplicities**: Computed once per chunk, reused for all aggregates
329/// - **Selection integration**: Supports lazy filtering without data copying
330/// - **Generation tracking**: Enables cache invalidation on structure changes
331#[derive(Debug, Clone)]
332pub struct ChunkState {
333    /// Factorization state of this chunk.
334    state: FactorizationState,
335    /// Selection for filtering without data copying.
336    /// When Some, only selected indices are "active".
337    selection: Option<FactorizedSelection>,
338    /// Cached path multiplicities (invalidated on structure change).
339    /// Key optimization: computed once, reused for all aggregates.
340    cached_multiplicities: Option<Arc<[usize]>>,
341    /// Generation counter for cache invalidation.
342    generation: u64,
343}
344
345impl ChunkState {
346    /// Creates a new flat chunk state.
347    #[must_use]
348    pub fn flat(row_count: usize) -> Self {
349        Self {
350            state: FactorizationState::Flat { row_count },
351            selection: None,
352            cached_multiplicities: None,
353            generation: 0,
354        }
355    }
356
357    /// Creates an unflat (factorized) chunk state.
358    #[must_use]
359    pub fn unflat(level_count: usize, logical_rows: usize) -> Self {
360        Self {
361            state: FactorizationState::Unflat {
362                level_count,
363                logical_rows,
364            },
365            selection: None,
366            cached_multiplicities: None,
367            generation: 0,
368        }
369    }
370
371    /// Returns the factorization state.
372    #[must_use]
373    pub fn factorization_state(&self) -> FactorizationState {
374        self.state
375    }
376
377    /// Returns true if this chunk is flat.
378    #[must_use]
379    pub fn is_flat(&self) -> bool {
380        self.state.is_flat()
381    }
382
383    /// Returns true if this chunk is factorized (unflat).
384    #[must_use]
385    pub fn is_factorized(&self) -> bool {
386        self.state.is_unflat()
387    }
388
389    /// Returns the logical row count.
390    ///
391    /// If a selection is active, this returns the base logical row count
392    /// (selection count must be computed separately with multiplicities).
393    #[must_use]
394    pub fn logical_row_count(&self) -> usize {
395        self.state.logical_row_count()
396    }
397
398    /// Returns the number of factorization levels.
399    #[must_use]
400    pub fn level_count(&self) -> usize {
401        self.state.level_count()
402    }
403
404    /// Returns the current generation (for cache validation).
405    #[must_use]
406    pub fn generation(&self) -> u64 {
407        self.generation
408    }
409
410    /// Returns the selection, if any.
411    #[must_use]
412    pub fn selection(&self) -> Option<&FactorizedSelection> {
413        self.selection.as_ref()
414    }
415
416    /// Returns mutable access to the selection.
417    pub fn selection_mut(&mut self) -> &mut Option<FactorizedSelection> {
418        &mut self.selection
419    }
420
421    /// Sets the selection.
422    pub fn set_selection(&mut self, selection: FactorizedSelection) {
423        self.selection = Some(selection);
424        // Don't invalidate multiplicity cache - selection is orthogonal
425    }
426
427    /// Clears the selection.
428    pub fn clear_selection(&mut self) {
429        self.selection = None;
430    }
431
432    /// Updates the state (e.g., after adding a level).
433    pub fn set_state(&mut self, state: FactorizationState) {
434        self.state = state;
435        self.invalidate_cache();
436    }
437
438    /// Invalidates cached data (call when structure changes).
439    pub fn invalidate_cache(&mut self) {
440        self.cached_multiplicities = None;
441        self.generation += 1;
442    }
443
444    /// Gets cached multiplicities, or computes and caches them.
445    ///
446    /// This is the key optimization: multiplicities are computed once
447    /// and reused for all aggregates (COUNT, SUM, AVG, etc.).
448    ///
449    /// # Arguments
450    ///
451    /// * `compute` - Function to compute multiplicities if not cached
452    ///
453    /// # Example
454    ///
455    /// ```rust
456    /// # use grafeo_core::execution::chunk_state::ChunkState;
457    /// # let mut state = ChunkState::unflat(2, 100);
458    /// let mults = state.get_or_compute_multiplicities(|| {
459    ///     vec![1; 100] // compute multiplicities
460    /// });
461    /// ```
462    pub fn get_or_compute_multiplicities<F>(&mut self, compute: F) -> Arc<[usize]>
463    where
464        F: FnOnce() -> Vec<usize>,
465    {
466        if let Some(ref cached) = self.cached_multiplicities {
467            return Arc::clone(cached);
468        }
469
470        let mults: Arc<[usize]> = compute().into();
471        self.cached_multiplicities = Some(Arc::clone(&mults));
472        mults
473    }
474
475    /// Returns cached multiplicities without computing.
476    ///
477    /// Returns None if not yet computed.
478    #[must_use]
479    pub fn cached_multiplicities(&self) -> Option<&Arc<[usize]>> {
480        self.cached_multiplicities.as_ref()
481    }
482
483    /// Sets the cached multiplicities directly.
484    ///
485    /// Useful when multiplicities are computed externally.
486    pub fn set_cached_multiplicities(&mut self, mults: Arc<[usize]>) {
487        self.cached_multiplicities = Some(mults);
488    }
489}
490
491impl Default for ChunkState {
492    fn default() -> Self {
493        Self::flat(0)
494    }
495}
496
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------
    // FactorizationState
    // ---------------------------------------------------------------

    #[test]
    fn test_factorization_state_flat() {
        let state = FactorizationState::Flat { row_count: 100 };
        assert!(state.is_flat());
        assert!(!state.is_unflat());
        assert_eq!(state.logical_row_count(), 100);
        assert_eq!(state.level_count(), 1);
    }

    #[test]
    fn test_factorization_state_unflat() {
        let state = FactorizationState::Unflat {
            level_count: 3,
            logical_rows: 1000,
        };
        assert!(!state.is_flat());
        assert!(state.is_unflat());
        assert_eq!(state.logical_row_count(), 1000);
        assert_eq!(state.level_count(), 3);
    }

    // ---------------------------------------------------------------
    // LevelSelection
    // ---------------------------------------------------------------

    #[test]
    fn test_level_selection_all() {
        let sel = LevelSelection::all(10);
        assert_eq!(sel.selected_count(), 10);
        for i in 0..10 {
            assert!(sel.is_selected(i));
        }
        assert!(!sel.is_selected(10));
    }

    #[test]
    fn test_level_selection_filter() {
        let sel = LevelSelection::all(10);
        let filtered = sel.filter(|i| i % 2 == 0);
        assert_eq!(filtered.selected_count(), 5);
        assert!(filtered.is_selected(0));
        assert!(!filtered.is_selected(1));
        assert!(filtered.is_selected(2));
    }

    #[test]
    fn test_level_selection_filter_sparse() {
        // Start with a sparse selection.
        let sel = LevelSelection::from_predicate(10, |i| i < 5);
        assert_eq!(sel.selected_count(), 5);

        // Filter the sparse selection further.
        let filtered = sel.filter(|i| i % 2 == 0);
        // Only 0, 2, 4 pass both filters.
        assert_eq!(filtered.selected_count(), 3);
        assert!(filtered.is_selected(0));
        assert!(!filtered.is_selected(1));
        assert!(filtered.is_selected(2));
    }

    #[test]
    fn test_level_selection_iter_all() {
        let sel = LevelSelection::all(5);
        let indices: Vec<usize> = sel.iter().collect();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_level_selection_iter_sparse() {
        let sel = LevelSelection::from_predicate(10, |i| i % 3 == 0);
        let indices: Vec<usize> = sel.iter().collect();
        assert_eq!(indices, vec![0, 3, 6, 9]);
    }

    #[test]
    fn test_level_selection_into_iterator() {
        // `&LevelSelection` is iterable directly, yielding the same indices as `iter()`.
        let sel = LevelSelection::from_predicate(6, |i| i % 2 == 1);
        let mut seen = Vec::new();
        for idx in &sel {
            seen.push(idx);
        }
        assert_eq!(seen, vec![1, 3, 5]);
    }

    #[test]
    fn test_level_selection_from_predicate_all_selected() {
        // When the predicate selects everything, the All variant is returned.
        let sel = LevelSelection::from_predicate(5, |_| true);
        assert_eq!(sel.selected_count(), 5);
        match sel {
            LevelSelection::All { count } => assert_eq!(count, 5),
            LevelSelection::Sparse(_) => panic!("Expected All variant"),
        }
    }

    #[test]
    fn test_level_selection_from_predicate_partial() {
        // When the predicate selects only some values, the Sparse variant is returned.
        let sel = LevelSelection::from_predicate(10, |i| i < 3);
        assert_eq!(sel.selected_count(), 3);
        match sel {
            LevelSelection::Sparse(_) => {}
            LevelSelection::All { .. } => panic!("Expected Sparse variant"),
        }
    }

    // ---------------------------------------------------------------
    // FactorizedSelection
    // ---------------------------------------------------------------

    #[test]
    fn test_factorized_selection_all() {
        let sel = FactorizedSelection::all(&[10, 100, 1000]);
        assert_eq!(sel.level_count(), 3);
        assert!(sel.is_selected(0, 5));
        assert!(sel.is_selected(1, 50));
        assert!(sel.is_selected(2, 500));
    }

    #[test]
    fn test_factorized_selection_new() {
        let level_sels = vec![
            LevelSelection::all(5),
            LevelSelection::from_predicate(10, |i| i < 3),
        ];
        let sel = FactorizedSelection::new(level_sels);

        assert_eq!(sel.level_count(), 2);
        assert!(sel.is_selected(0, 4));
        assert!(sel.is_selected(1, 2));
        assert!(!sel.is_selected(1, 5));
    }

    #[test]
    fn test_factorized_selection_filter_level() {
        let sel = FactorizedSelection::all(&[10, 100]);
        let filtered = sel.filter_level(1, |i| i < 50);

        assert!(filtered.is_selected(0, 5)); // Level 0 unchanged
        assert!(filtered.is_selected(1, 25)); // Level 1: 25 < 50
        assert!(!filtered.is_selected(1, 75)); // Level 1: 75 >= 50
    }

    #[test]
    fn test_factorized_selection_filter_level_invalid() {
        let sel = FactorizedSelection::all(&[10, 100]);

        // Filtering a non-existent level must not panic and leaves every level as-is.
        let filtered = sel.filter_level(5, |_| true);
        assert_eq!(filtered.level_count(), 2);
        assert!(filtered.is_selected(0, 9));
        assert!(filtered.is_selected(1, 99));
    }

    #[test]
    fn test_factorized_selection_filter_level_resets_cache() {
        let mut sel = FactorizedSelection::all(&[2, 4]);
        let multiplicities: Vec<Vec<usize>> = vec![vec![1, 1], vec![2, 2]];
        assert_eq!(sel.selected_count(&multiplicities), 4);

        // The filtered copy must not inherit the original's cached count.
        let mut filtered = sel.filter_level(1, |i| i != 3);
        assert_eq!(filtered.selected_count(&multiplicities), 3);
    }

    #[test]
    fn test_factorized_selection_is_selected_invalid_level() {
        let sel = FactorizedSelection::all(&[10]);
        assert!(!sel.is_selected(5, 0)); // Non-existent level
    }

    #[test]
    fn test_factorized_selection_level() {
        let sel = FactorizedSelection::all(&[10, 20]);

        let level0 = sel.level(0);
        assert!(level0.is_some());
        assert_eq!(level0.unwrap().selected_count(), 10);

        let level1 = sel.level(1);
        assert!(level1.is_some());
        assert_eq!(level1.unwrap().selected_count(), 20);

        assert!(sel.level(5).is_none());
    }

    #[test]
    fn test_factorized_selection_selected_count_single_level() {
        let mut sel = FactorizedSelection::all(&[10]);
        let multiplicities: Vec<Vec<usize>> = vec![vec![1; 10]];

        let count = sel.selected_count(&multiplicities);
        assert_eq!(count, 10);
    }

    #[test]
    fn test_factorized_selection_selected_count_multi_level() {
        let level_sels = vec![
            LevelSelection::all(2),                            // 2 parents
            LevelSelection::from_predicate(4, |i| i % 2 == 0), // Select indices 0, 2
        ];
        let mut sel = FactorizedSelection::new(level_sels);

        // Parent 0 has 2 children (indices 0, 1)
        // Parent 1 has 2 children (indices 2, 3)
        let multiplicities = vec![
            vec![1, 1], // Level 0 multiplicities
            vec![2, 2], // Level 1 multiplicities (children per parent)
        ];

        let count = sel.selected_count(&multiplicities);
        // Parent 0 is selected; of its children {0, 1} only 0 passes the filter.
        // Parent 1 is selected; of its children {2, 3} only 2 passes the filter.
        assert_eq!(count, 2);
    }

    #[test]
    fn test_factorized_selection_selected_count_sparse_root() {
        // Level 0: three parents, parent 1 filtered out.
        let level_sels = vec![
            LevelSelection::from_predicate(3, |i| i != 1),
            LevelSelection::all(6),
        ];
        let mut sel = FactorizedSelection::new(level_sels);

        // Parent 0 -> children {0, 1}, parent 1 -> {2, 3}, parent 2 -> {4, 5}.
        let multiplicities: Vec<Vec<usize>> = vec![vec![1, 1, 1], vec![2, 2, 2]];

        // Children of the unselected parent 1 are dropped even though level 1 selects all.
        assert_eq!(sel.selected_count(&multiplicities), 4);
    }

    #[test]
    fn test_factorized_selection_selected_count_three_levels() {
        let level_sels = vec![
            LevelSelection::all(1),
            LevelSelection::all(2),
            LevelSelection::from_predicate(4, |i| i >= 1),
        ];
        let mut sel = FactorizedSelection::new(level_sels);

        // One root with two children, each of which has two leaves.
        let multiplicities: Vec<Vec<usize>> = vec![vec![1], vec![2], vec![2, 2]];

        // Leaf 0 is filtered out; leaves 1, 2 and 3 survive.
        assert_eq!(sel.selected_count(&multiplicities), 3);
    }

    #[test]
    fn test_factorized_selection_selected_count_cached() {
        let mut sel = FactorizedSelection::all(&[5]);
        let multiplicities: Vec<Vec<usize>> = vec![vec![1; 5]];

        // First call computes
        let count1 = sel.selected_count(&multiplicities);
        assert_eq!(count1, 5);

        // Second call uses cache
        let count2 = sel.selected_count(&multiplicities);
        assert_eq!(count2, 5);
    }

    #[test]
    fn test_factorized_selection_selected_count_empty() {
        let mut sel = FactorizedSelection::all(&[]);
        let multiplicities: Vec<Vec<usize>> = vec![];

        assert_eq!(sel.selected_count(&multiplicities), 0);
    }

    #[test]
    fn test_factorized_selection_invalidate_cache() {
        let mut sel = FactorizedSelection::all(&[2, 4]);
        let two_children_each: Vec<Vec<usize>> = vec![vec![1, 1], vec![2, 2]];
        let one_child_each: Vec<Vec<usize>> = vec![vec![1, 1], vec![1, 1]];

        // Compute and cache.
        assert_eq!(sel.selected_count(&two_children_each), 4);

        // Without invalidation the cached value wins, even for new multiplicities.
        assert_eq!(sel.selected_count(&one_child_each), 4);

        // After invalidation the count is recomputed from the new multiplicities.
        sel.invalidate_cache();
        assert_eq!(sel.selected_count(&one_child_each), 2);
    }

    // ---------------------------------------------------------------
    // ChunkState
    // ---------------------------------------------------------------

    #[test]
    fn test_chunk_state_flat() {
        let state = ChunkState::flat(100);
        assert!(state.is_flat());
        assert!(!state.is_factorized());
        assert_eq!(state.logical_row_count(), 100);
        assert_eq!(state.level_count(), 1);
    }

    #[test]
    fn test_chunk_state_unflat() {
        let state = ChunkState::unflat(3, 1000);
        assert!(!state.is_flat());
        assert!(state.is_factorized());
        assert_eq!(state.logical_row_count(), 1000);
        assert_eq!(state.level_count(), 3);
    }

    #[test]
    fn test_chunk_state_factorization_state() {
        let state = ChunkState::flat(50);
        let fs = state.factorization_state();
        assert!(fs.is_flat());
    }

    #[test]
    fn test_chunk_state_selection() {
        let mut state = ChunkState::unflat(2, 100);

        // Initially no selection
        assert!(state.selection().is_none());

        // Set selection
        let sel = FactorizedSelection::all(&[10, 100]);
        state.set_selection(sel);

        assert!(state.selection().is_some());
        assert_eq!(state.selection().unwrap().level_count(), 2);
    }

    #[test]
    fn test_chunk_state_selection_mut() {
        let mut state = ChunkState::unflat(2, 100);

        // Set selection
        let sel = FactorizedSelection::all(&[10, 100]);
        state.set_selection(sel);

        // Get mutable access
        let sel_mut = state.selection_mut();
        assert!(sel_mut.is_some());

        // Clear via mutable access
        *sel_mut = None;
        assert!(state.selection().is_none());
    }

    #[test]
    fn test_chunk_state_clear_selection() {
        let mut state = ChunkState::unflat(2, 100);

        let sel = FactorizedSelection::all(&[10, 100]);
        state.set_selection(sel);
        assert!(state.selection().is_some());

        state.clear_selection();
        assert!(state.selection().is_none());
    }

    #[test]
    fn test_chunk_state_set_selection_keeps_multiplicities() {
        let mut state = ChunkState::unflat(2, 100);
        let _ = state.get_or_compute_multiplicities(|| vec![1, 2, 3]);
        let generation = state.generation();

        // A selection does not change structure, so the cache must survive.
        state.set_selection(FactorizedSelection::all(&[10, 100]));

        assert!(state.cached_multiplicities().is_some());
        assert_eq!(state.generation(), generation);
    }

    #[test]
    fn test_chunk_state_set_state() {
        let mut state = ChunkState::flat(100);
        assert!(state.is_flat());
        assert_eq!(state.generation(), 0);

        state.set_state(FactorizationState::Unflat {
            level_count: 2,
            logical_rows: 200,
        });

        assert!(state.is_factorized());
        assert_eq!(state.logical_row_count(), 200);
        assert_eq!(state.generation(), 1); // Cache invalidated
    }

    #[test]
    fn test_chunk_state_set_state_clears_multiplicities() {
        let mut state = ChunkState::flat(10);
        let _ = state.get_or_compute_multiplicities(|| vec![1; 10]);
        assert!(state.cached_multiplicities().is_some());

        state.set_state(FactorizationState::Flat { row_count: 20 });

        assert!(state.cached_multiplicities().is_none());
        assert_eq!(state.logical_row_count(), 20);
    }

    #[test]
    fn test_chunk_state_caching() {
        let mut state = ChunkState::unflat(2, 100);

        // First call should compute
        let mut computed = false;
        let mults1 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![1, 2, 3, 4, 5]
        });
        assert!(computed);
        assert_eq!(mults1.len(), 5);

        // Second call should use cache
        computed = false;
        let mults2 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![99, 99, 99]
        });
        assert!(!computed);
        assert_eq!(mults2.len(), 5); // Same as before

        // After invalidation, should recompute
        state.invalidate_cache();
        let mults3 = state.get_or_compute_multiplicities(|| {
            computed = true;
            vec![10, 20, 30]
        });
        assert!(computed);
        assert_eq!(mults3.len(), 3);
    }

    #[test]
    fn test_chunk_state_multiplicities_shared() {
        let mut state = ChunkState::unflat(2, 100);

        let first = state.get_or_compute_multiplicities(|| vec![1, 2, 3]);
        let second = state.get_or_compute_multiplicities(|| unreachable!());

        // Both handles point at the single cached allocation.
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn test_chunk_state_cached_multiplicities() {
        let mut state = ChunkState::unflat(2, 100);

        // Initially not cached
        assert!(state.cached_multiplicities().is_none());

        // Compute multiplicities
        let _ = state.get_or_compute_multiplicities(|| vec![1, 2, 3]);

        // Now cached
        assert!(state.cached_multiplicities().is_some());
        assert_eq!(state.cached_multiplicities().unwrap().len(), 3);
    }

    #[test]
    fn test_chunk_state_set_cached_multiplicities() {
        let mut state = ChunkState::unflat(2, 100);

        let mults: Arc<[usize]> = vec![5, 10, 15].into();
        state.set_cached_multiplicities(mults);

        assert!(state.cached_multiplicities().is_some());
        assert_eq!(state.cached_multiplicities().unwrap().len(), 3);
    }

    #[test]
    fn test_chunk_state_generation() {
        let mut state = ChunkState::flat(100);
        assert_eq!(state.generation(), 0);

        state.invalidate_cache();
        assert_eq!(state.generation(), 1);

        state.set_state(FactorizationState::Unflat {
            level_count: 2,
            logical_rows: 200,
        });
        assert_eq!(state.generation(), 2);
    }

    #[test]
    fn test_chunk_state_default() {
        let state = ChunkState::default();
        assert!(state.is_flat());
        assert_eq!(state.logical_row_count(), 0);
        assert_eq!(state.generation(), 0);
    }
}