pandrs/dataframe/
multi_index_cross_section.rs

//! # Advanced MultiIndex with Cross-Section Selection
//!
//! This module provides cross-section selection for DataFrames indexed by a MultiIndex,
//! including partial indexing, level slicing, boolean indexing, and hierarchical navigation.
5
6use crate::column::{Column, ColumnTrait};
7use crate::core::advanced_multi_index::{
8    AdvancedMultiIndex, CrossSectionResult, IndexValue, SelectionCriteria,
9};
10use crate::core::error::{Error, Result};
11use crate::dataframe::DataFrame;
12use std::collections::HashMap;
13
14/// DataFrame with advanced MultiIndex cross-section support
15#[derive(Debug, Clone)]
16pub struct MultiIndexDataFrame {
17    /// The underlying DataFrame
18    pub dataframe: DataFrame,
19    /// The row index (MultiIndex)
20    pub index: AdvancedMultiIndex,
21    /// Column names (simple for now, could be extended to MultiIndex)
22    pub column_names: Vec<String>,
23}
24
25/// Result of cross-section selection
26#[derive(Debug, Clone)]
27pub struct CrossSectionDataFrame {
28    /// Resulting DataFrame
29    pub dataframe: MultiIndexDataFrame,
30    /// Selected row indices
31    pub selected_indices: Vec<usize>,
32    /// Whether any level was dropped
33    pub level_dropped: bool,
34}
35
36impl MultiIndexDataFrame {
37    /// Create a new MultiIndexDataFrame
38    pub fn new(dataframe: DataFrame, index: AdvancedMultiIndex) -> Result<Self> {
39        if index.len() != dataframe.row_count() {
40            return Err(Error::InconsistentRowCount {
41                expected: dataframe.row_count(),
42                found: index.len(),
43            });
44        }
45
46        let column_names = dataframe.column_names().to_vec();
47
48        Ok(Self {
49            dataframe,
50            index,
51            column_names,
52        })
53    }
54
55    /// Cross-section selection: select rows with specific value at given level
56    pub fn xs(
57        &mut self,
58        key: IndexValue,
59        level: usize,
60        drop_level: bool,
61    ) -> Result<CrossSectionDataFrame> {
62        let xs_result = self.index.xs(key, level, drop_level)?;
63
64        if xs_result.indices.is_empty() {
65            return Err(Error::InvalidOperation(
66                "No matching rows found for cross-section".to_string(),
67            ));
68        }
69
70        // Create new DataFrame with selected rows
71        let selected_dataframe = self.select_rows(&xs_result.indices)?;
72
73        // Handle index transformation
74        let result_index = if drop_level {
75            xs_result.index.ok_or_else(|| {
76                Error::InvalidOperation("Expected index after dropping level".to_string())
77            })?
78        } else {
79            // Keep original index structure but only selected rows
80            self.select_index_rows(&xs_result.indices)?
81        };
82
83        let result = MultiIndexDataFrame {
84            dataframe: selected_dataframe,
85            index: result_index,
86            column_names: self.column_names.clone(),
87        };
88
89        Ok(CrossSectionDataFrame {
90            dataframe: result,
91            selected_indices: xs_result.indices,
92            level_dropped: drop_level,
93        })
94    }
95
96    /// Advanced selection using multiple criteria
97    pub fn select(&self, criteria: SelectionCriteria) -> Result<CrossSectionDataFrame> {
98        let selected_indices = self.index.select(criteria)?;
99
100        if selected_indices.is_empty() {
101            return Err(Error::InvalidOperation(
102                "No matching rows found for selection criteria".to_string(),
103            ));
104        }
105
106        let selected_dataframe = self.select_rows(&selected_indices)?;
107        let selected_index = self.select_index_rows(&selected_indices)?;
108
109        let result = MultiIndexDataFrame {
110            dataframe: selected_dataframe,
111            index: selected_index,
112            column_names: self.column_names.clone(),
113        };
114
115        Ok(CrossSectionDataFrame {
116            dataframe: result,
117            selected_indices,
118            level_dropped: false,
119        })
120    }
121
122    /// GroupBy operation using index levels
123    pub fn groupby_level(&self, levels: &[usize]) -> Result<MultiIndexGroupBy> {
124        // Validate levels
125        for &level in levels {
126            if level >= self.index.n_levels() {
127                return Err(Error::IndexOutOfBounds {
128                    index: level,
129                    size: self.index.n_levels(),
130                });
131            }
132        }
133
134        // Get unique group keys
135        let group_keys = self.index.get_group_keys(levels)?;
136
137        // Build groups: group_key -> row_indices
138        let mut groups = HashMap::new();
139        for group_key in group_keys {
140            let indices = self.index.get_group_indices(levels, &group_key)?;
141            groups.insert(group_key, indices);
142        }
143
144        Ok(MultiIndexGroupBy {
145            dataframe: self.clone(),
146            groups,
147            group_levels: levels.to_vec(),
148        })
149    }
150
151    /// Select specific levels from the index
152    pub fn select_levels(&self, levels: &[usize]) -> Result<MultiIndexDataFrame> {
153        // Validate levels
154        for &level in levels {
155            if level >= self.index.n_levels() {
156                return Err(Error::IndexOutOfBounds {
157                    index: level,
158                    size: self.index.n_levels(),
159                });
160            }
161        }
162
163        // Create new tuples with only selected levels
164        let mut new_tuples = Vec::with_capacity(self.index.len());
165        for i in 0..self.index.len() {
166            let original_tuple = self.index.get_tuple(i)?;
167            let new_tuple: Vec<IndexValue> = levels
168                .iter()
169                .map(|&level| original_tuple[level].clone())
170                .collect();
171            new_tuples.push(new_tuple);
172        }
173
174        // Create new level names
175        let original_names = self.index.level_names();
176        let new_names: Vec<Option<String>> = levels
177            .iter()
178            .map(|&level| original_names[level].clone())
179            .collect();
180
181        let new_index = AdvancedMultiIndex::new(new_tuples, Some(new_names))?;
182
183        Ok(MultiIndexDataFrame {
184            dataframe: self.dataframe.clone(),
185            index: new_index,
186            column_names: self.column_names.clone(),
187        })
188    }
189
190    /// Reindex with new MultiIndex
191    pub fn reindex(&self, new_index: AdvancedMultiIndex) -> Result<MultiIndexDataFrame> {
192        if new_index.len() != self.index.len() {
193            return Err(Error::InconsistentRowCount {
194                expected: self.index.len(),
195                found: new_index.len(),
196            });
197        }
198
199        Ok(MultiIndexDataFrame {
200            dataframe: self.dataframe.clone(),
201            index: new_index,
202            column_names: self.column_names.clone(),
203        })
204    }
205
206    /// Swap levels in the index
207    pub fn swaplevel(&self, i: usize, j: usize) -> Result<MultiIndexDataFrame> {
208        let swapped_index = self.index.reorder_levels(&{
209            let mut order: Vec<usize> = (0..self.index.n_levels()).collect();
210            order.swap(i, j);
211            order
212        })?;
213
214        Ok(MultiIndexDataFrame {
215            dataframe: self.dataframe.clone(),
216            index: swapped_index,
217            column_names: self.column_names.clone(),
218        })
219    }
220
221    /// Sort by index levels
222    pub fn sort_index(
223        &self,
224        levels: Option<&[usize]>,
225        ascending: bool,
226    ) -> Result<MultiIndexDataFrame> {
227        let default_levels: Vec<usize> = (0..self.index.n_levels()).collect();
228        let sort_levels = levels.unwrap_or(&default_levels);
229
230        // Create sortable tuples with original indices
231        let mut indexed_tuples: Vec<(usize, Vec<IndexValue>)> =
232            Vec::with_capacity(self.index.len());
233        for i in 0..self.index.len() {
234            let tuple = self.index.get_tuple(i)?;
235            let sort_tuple: Vec<IndexValue> = sort_levels
236                .iter()
237                .map(|&level| tuple[level].clone())
238                .collect();
239            indexed_tuples.push((i, sort_tuple));
240        }
241
242        // Sort tuples
243        indexed_tuples.sort_by(|a, b| {
244            let comparison = a.1.cmp(&b.1);
245            if ascending {
246                comparison
247            } else {
248                comparison.reverse()
249            }
250        });
251
252        // Extract sorted indices
253        let sorted_indices: Vec<usize> = indexed_tuples.into_iter().map(|(idx, _)| idx).collect();
254
255        // Apply sorting to DataFrame and index
256        let sorted_dataframe = self.select_rows(&sorted_indices)?;
257        let sorted_index = self.select_index_rows(&sorted_indices)?;
258
259        Ok(MultiIndexDataFrame {
260            dataframe: sorted_dataframe,
261            index: sorted_index,
262            column_names: self.column_names.clone(),
263        })
264    }
265
266    /// Get level values as a vector
267    pub fn get_level_values(&self, level: usize) -> Result<Vec<IndexValue>> {
268        if level >= self.index.n_levels() {
269            return Err(Error::IndexOutOfBounds {
270                index: level,
271                size: self.index.n_levels(),
272            });
273        }
274
275        let mut values = Vec::with_capacity(self.index.len());
276        for i in 0..self.index.len() {
277            let tuple = self.index.get_tuple(i)?;
278            values.push(tuple[level].clone());
279        }
280
281        Ok(values)
282    }
283
284    /// Check if index is monotonic at specified level
285    pub fn is_monotonic(&self, level: usize) -> Result<bool> {
286        let values = self.get_level_values(level)?;
287        if values.len() <= 1 {
288            return Ok(true);
289        }
290
291        let increasing = values.windows(2).all(|w| w[0] <= w[1]);
292        let decreasing = values.windows(2).all(|w| w[0] >= w[1]);
293
294        Ok(increasing || decreasing)
295    }
296
297    // Helper methods
298
299    fn select_rows(&self, indices: &[usize]) -> Result<DataFrame> {
300        // This is a simplified implementation - would need actual DataFrame row selection
301        // For now, return a clone as placeholder
302        Ok(self.dataframe.clone())
303    }
304
305    fn select_index_rows(&self, indices: &[usize]) -> Result<AdvancedMultiIndex> {
306        let selected_tuples: Result<Vec<Vec<IndexValue>>> = indices
307            .iter()
308            .map(|&i| self.index.get_tuple(i).map(|t| t.to_vec()))
309            .collect();
310
311        AdvancedMultiIndex::new(selected_tuples?, Some(self.index.level_names().to_vec()))
312    }
313}
314
315/// GroupBy operation result for MultiIndex DataFrames
316#[derive(Debug, Clone)]
317pub struct MultiIndexGroupBy {
318    dataframe: MultiIndexDataFrame,
319    groups: HashMap<Vec<IndexValue>, Vec<usize>>,
320    group_levels: Vec<usize>,
321}
322
323impl MultiIndexGroupBy {
324    /// Get the number of groups
325    pub fn ngroups(&self) -> usize {
326        self.groups.len()
327    }
328
329    /// Get group sizes
330    pub fn size(&self) -> HashMap<Vec<IndexValue>, usize> {
331        self.groups
332            .iter()
333            .map(|(key, indices)| (key.clone(), indices.len()))
334            .collect()
335    }
336
337    /// Get group by key
338    pub fn get_group(&self, key: &[IndexValue]) -> Result<MultiIndexDataFrame> {
339        let indices = self
340            .groups
341            .get(key)
342            .ok_or_else(|| Error::InvalidOperation(format!("Group key {:?} not found", key)))?;
343
344        let selected_dataframe = self.dataframe.select_rows(indices)?;
345        let selected_index = self.dataframe.select_index_rows(indices)?;
346
347        Ok(MultiIndexDataFrame {
348            dataframe: selected_dataframe,
349            index: selected_index,
350            column_names: self.dataframe.column_names.clone(),
351        })
352    }
353
354    /// Apply aggregation function to all groups
355    pub fn agg(&self, func: AggregationFunction) -> Result<MultiIndexDataFrame> {
356        let mut result_data: HashMap<String, Vec<f64>> = HashMap::new();
357        let mut result_index_tuples = Vec::new();
358
359        // Initialize result columns
360        for col_name in &self.dataframe.column_names {
361            result_data.insert(col_name.clone(), Vec::new());
362        }
363
364        // Process each group
365        for (group_key, indices) in &self.groups {
366            result_index_tuples.push(group_key.clone());
367
368            for col_name in &self.dataframe.column_names {
369                let values = self.extract_column_values(col_name, indices)?;
370                let aggregated = func.apply(&values);
371                result_data.get_mut(col_name).unwrap().push(aggregated);
372            }
373        }
374
375        // Create result index with only group levels
376        let group_level_names: Vec<Option<String>> = self
377            .group_levels
378            .iter()
379            .map(|&level| self.dataframe.index.level_names()[level].clone())
380            .collect();
381
382        let result_index = AdvancedMultiIndex::new(result_index_tuples, Some(group_level_names))?;
383
384        // Create result DataFrame (simplified - would need actual DataFrame construction)
385        let result_dataframe = self.dataframe.dataframe.clone(); // Placeholder
386
387        Ok(MultiIndexDataFrame {
388            dataframe: result_dataframe,
389            index: result_index,
390            column_names: self.dataframe.column_names.clone(),
391        })
392    }
393
394    /// Apply custom function to each group
395    pub fn apply<F>(&self, func: F) -> Result<MultiIndexDataFrame>
396    where
397        F: Fn(&MultiIndexDataFrame) -> Result<MultiIndexDataFrame>,
398    {
399        let mut result_parts = Vec::new();
400
401        for (group_key, indices) in &self.groups {
402            let group_dataframe = self.get_group(group_key)?;
403            let group_result = func(&group_dataframe)?;
404            result_parts.push((group_key.clone(), group_result));
405        }
406
407        // Combine results (simplified implementation)
408        if let Some((_, first_result)) = result_parts.first() {
409            Ok(first_result.clone()) // Placeholder
410        } else {
411            Err(Error::InvalidOperation(
412                "No groups to apply function to".to_string(),
413            ))
414        }
415    }
416
417    /// Get group keys
418    pub fn groups(&self) -> &HashMap<Vec<IndexValue>, Vec<usize>> {
419        &self.groups
420    }
421
422    fn extract_column_values(&self, column_name: &str, indices: &[usize]) -> Result<Vec<f64>> {
423        // This would extract numeric values from the specified column at given indices
424        // For now, return placeholder values
425        Ok(vec![0.0; indices.len()])
426    }
427}
428
/// Aggregation functions for GroupBy operations
#[derive(Debug, Clone, Copy)]
pub enum AggregationFunction {
    /// Sum of the values; `0.0` for an empty slice.
    Sum,
    /// Arithmetic mean; NaN for an empty slice.
    Mean,
    /// Number of values (NaNs included); `0.0` for an empty slice.
    Count,
    /// Smallest value, ignoring NaNs as `f64::min` does; NaN for an empty slice.
    Min,
    /// Largest value, ignoring NaNs as `f64::max` does; NaN for an empty slice.
    Max,
    /// Sample standard deviation (ddof = 1); NaN for fewer than two values.
    Std,
    /// Sample variance (ddof = 1); NaN for fewer than two values.
    Var,
    /// First value; NaN for an empty slice.
    First,
    /// Last value; NaN for an empty slice.
    Last,
}

impl AggregationFunction {
    /// Reduce `values` to a single number according to this aggregation.
    ///
    /// Statistics that are undefined for the given input return NaN rather than a
    /// `0.0` sentinel, which would be indistinguishable from a genuine zero result.
    /// Examples are the mean of nothing, or the sample variance of one observation.
    pub fn apply(&self, values: &[f64]) -> f64 {
        match self {
            Self::Sum => values.iter().sum(),
            Self::Count => values.len() as f64,
            Self::Mean => {
                if values.is_empty() {
                    f64::NAN
                } else {
                    values.iter().sum::<f64>() / values.len() as f64
                }
            }
            Self::Min => values.iter().copied().reduce(f64::min).unwrap_or(f64::NAN),
            Self::Max => values.iter().copied().reduce(f64::max).unwrap_or(f64::NAN),
            Self::Var => Self::sample_variance(values),
            Self::Std => Self::sample_variance(values).sqrt(),
            Self::First => values.first().copied().unwrap_or(f64::NAN),
            Self::Last => values.last().copied().unwrap_or(f64::NAN),
        }
    }

    /// Two-pass sample variance with Bessel's correction (divide by `n - 1`).
    /// Returns NaN for fewer than two values, where the estimator is undefined.
    fn sample_variance(values: &[f64]) -> f64 {
        let n = values.len();
        if n < 2 {
            return f64::NAN;
        }
        let mean = values.iter().sum::<f64>() / n as f64;
        values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (n - 1) as f64
    }
}
479
#[cfg(test)]
mod tests {
    use super::*;

    /// Labels at `level` for every row of `index`, in row order.
    fn labels_at(index: &AdvancedMultiIndex, level: usize) -> Vec<IndexValue> {
        (0..index.len())
            .map(|row| index.get_tuple(row).unwrap()[level].clone())
            .collect()
    }

    #[test]
    fn test_cross_section_selection() {
        // Two outer labels, each paired with inner labels 1 and 2.
        let rows = vec![
            vec![IndexValue::from("A"), IndexValue::from(1)],
            vec![IndexValue::from("A"), IndexValue::from(2)],
            vec![IndexValue::from("B"), IndexValue::from(1)],
            vec![IndexValue::from("B"), IndexValue::from(2)],
        ];
        let index = AdvancedMultiIndex::new(rows, None).unwrap();

        // Exercising `xs` itself needs a real DataFrame; until then only the
        // shape of the index the cross-section would operate on is checked.
        assert_eq!(index.len(), 4);
        assert_eq!(index.n_levels(), 2);
    }

    #[test]
    fn test_groupby_level() {
        let rows = vec![
            vec![IndexValue::from("A"), IndexValue::from("X"), IndexValue::from(1)],
            vec![IndexValue::from("A"), IndexValue::from("Y"), IndexValue::from(2)],
            vec![IndexValue::from("B"), IndexValue::from("X"), IndexValue::from(3)],
            vec![IndexValue::from("B"), IndexValue::from("Y"), IndexValue::from(4)],
        ];
        let index = AdvancedMultiIndex::new(rows, None).unwrap();

        // Level 0 has exactly two distinct labels: "A" and "B".
        assert_eq!(index.get_group_keys(&[0]).unwrap().len(), 2);

        let expectations: [(&str, Vec<usize>); 2] = [("A", vec![0, 1]), ("B", vec![2, 3])];
        for (label, expected_rows) in expectations.iter() {
            let rows_for_label = index
                .get_group_indices(&[0], &[IndexValue::from(*label)])
                .unwrap();
            assert_eq!(rows_for_label, *expected_rows);
        }
    }

    #[test]
    fn test_aggregation_functions() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let exact = [
            (AggregationFunction::Sum, 15.0),
            (AggregationFunction::Mean, 3.0),
            (AggregationFunction::Count, 5.0),
            (AggregationFunction::Min, 1.0),
            (AggregationFunction::Max, 5.0),
            (AggregationFunction::First, 1.0),
            (AggregationFunction::Last, 5.0),
        ];
        for (func, expected) in exact.iter() {
            assert_eq!(func.apply(&data), *expected);
        }

        // Sample (ddof = 1) dispersion statistics.
        assert!((AggregationFunction::Std.apply(&data) - 1.58113883).abs() < 1e-6);
        assert!((AggregationFunction::Var.apply(&data) - 2.5).abs() < 1e-6);
    }

    #[test]
    fn test_level_selection() {
        let rows = vec![
            vec![IndexValue::from("A"), IndexValue::from("X"), IndexValue::from(1)],
            vec![IndexValue::from("B"), IndexValue::from("Y"), IndexValue::from(2)],
        ];
        let index = AdvancedMultiIndex::new(rows, None).unwrap();

        assert_eq!(
            labels_at(&index, 0),
            vec![IndexValue::from("A"), IndexValue::from("B")]
        );
        assert_eq!(
            labels_at(&index, 2),
            vec![IndexValue::from(1), IndexValue::from(2)]
        );
    }
}