1use crate::column::{Column, ColumnTrait};
7use crate::core::advanced_multi_index::{
8 AdvancedMultiIndex, CrossSectionResult, IndexValue, SelectionCriteria,
9};
10use crate::core::error::{Error, Result};
11use crate::dataframe::DataFrame;
12use std::collections::HashMap;
13
14#[derive(Debug, Clone)]
16pub struct MultiIndexDataFrame {
17 pub dataframe: DataFrame,
19 pub index: AdvancedMultiIndex,
21 pub column_names: Vec<String>,
23}
24
25#[derive(Debug, Clone)]
27pub struct CrossSectionDataFrame {
28 pub dataframe: MultiIndexDataFrame,
30 pub selected_indices: Vec<usize>,
32 pub level_dropped: bool,
34}
35
36impl MultiIndexDataFrame {
37 pub fn new(dataframe: DataFrame, index: AdvancedMultiIndex) -> Result<Self> {
39 if index.len() != dataframe.row_count() {
40 return Err(Error::InconsistentRowCount {
41 expected: dataframe.row_count(),
42 found: index.len(),
43 });
44 }
45
46 let column_names = dataframe.column_names().to_vec();
47
48 Ok(Self {
49 dataframe,
50 index,
51 column_names,
52 })
53 }
54
55 pub fn xs(
57 &mut self,
58 key: IndexValue,
59 level: usize,
60 drop_level: bool,
61 ) -> Result<CrossSectionDataFrame> {
62 let xs_result = self.index.xs(key, level, drop_level)?;
63
64 if xs_result.indices.is_empty() {
65 return Err(Error::InvalidOperation(
66 "No matching rows found for cross-section".to_string(),
67 ));
68 }
69
70 let selected_dataframe = self.select_rows(&xs_result.indices)?;
72
73 let result_index = if drop_level {
75 xs_result.index.ok_or_else(|| {
76 Error::InvalidOperation("Expected index after dropping level".to_string())
77 })?
78 } else {
79 self.select_index_rows(&xs_result.indices)?
81 };
82
83 let result = MultiIndexDataFrame {
84 dataframe: selected_dataframe,
85 index: result_index,
86 column_names: self.column_names.clone(),
87 };
88
89 Ok(CrossSectionDataFrame {
90 dataframe: result,
91 selected_indices: xs_result.indices,
92 level_dropped: drop_level,
93 })
94 }
95
96 pub fn select(&self, criteria: SelectionCriteria) -> Result<CrossSectionDataFrame> {
98 let selected_indices = self.index.select(criteria)?;
99
100 if selected_indices.is_empty() {
101 return Err(Error::InvalidOperation(
102 "No matching rows found for selection criteria".to_string(),
103 ));
104 }
105
106 let selected_dataframe = self.select_rows(&selected_indices)?;
107 let selected_index = self.select_index_rows(&selected_indices)?;
108
109 let result = MultiIndexDataFrame {
110 dataframe: selected_dataframe,
111 index: selected_index,
112 column_names: self.column_names.clone(),
113 };
114
115 Ok(CrossSectionDataFrame {
116 dataframe: result,
117 selected_indices,
118 level_dropped: false,
119 })
120 }
121
122 pub fn groupby_level(&self, levels: &[usize]) -> Result<MultiIndexGroupBy> {
124 for &level in levels {
126 if level >= self.index.n_levels() {
127 return Err(Error::IndexOutOfBounds {
128 index: level,
129 size: self.index.n_levels(),
130 });
131 }
132 }
133
134 let group_keys = self.index.get_group_keys(levels)?;
136
137 let mut groups = HashMap::new();
139 for group_key in group_keys {
140 let indices = self.index.get_group_indices(levels, &group_key)?;
141 groups.insert(group_key, indices);
142 }
143
144 Ok(MultiIndexGroupBy {
145 dataframe: self.clone(),
146 groups,
147 group_levels: levels.to_vec(),
148 })
149 }
150
151 pub fn select_levels(&self, levels: &[usize]) -> Result<MultiIndexDataFrame> {
153 for &level in levels {
155 if level >= self.index.n_levels() {
156 return Err(Error::IndexOutOfBounds {
157 index: level,
158 size: self.index.n_levels(),
159 });
160 }
161 }
162
163 let mut new_tuples = Vec::with_capacity(self.index.len());
165 for i in 0..self.index.len() {
166 let original_tuple = self.index.get_tuple(i)?;
167 let new_tuple: Vec<IndexValue> = levels
168 .iter()
169 .map(|&level| original_tuple[level].clone())
170 .collect();
171 new_tuples.push(new_tuple);
172 }
173
174 let original_names = self.index.level_names();
176 let new_names: Vec<Option<String>> = levels
177 .iter()
178 .map(|&level| original_names[level].clone())
179 .collect();
180
181 let new_index = AdvancedMultiIndex::new(new_tuples, Some(new_names))?;
182
183 Ok(MultiIndexDataFrame {
184 dataframe: self.dataframe.clone(),
185 index: new_index,
186 column_names: self.column_names.clone(),
187 })
188 }
189
190 pub fn reindex(&self, new_index: AdvancedMultiIndex) -> Result<MultiIndexDataFrame> {
192 if new_index.len() != self.index.len() {
193 return Err(Error::InconsistentRowCount {
194 expected: self.index.len(),
195 found: new_index.len(),
196 });
197 }
198
199 Ok(MultiIndexDataFrame {
200 dataframe: self.dataframe.clone(),
201 index: new_index,
202 column_names: self.column_names.clone(),
203 })
204 }
205
206 pub fn swaplevel(&self, i: usize, j: usize) -> Result<MultiIndexDataFrame> {
208 let swapped_index = self.index.reorder_levels(&{
209 let mut order: Vec<usize> = (0..self.index.n_levels()).collect();
210 order.swap(i, j);
211 order
212 })?;
213
214 Ok(MultiIndexDataFrame {
215 dataframe: self.dataframe.clone(),
216 index: swapped_index,
217 column_names: self.column_names.clone(),
218 })
219 }
220
221 pub fn sort_index(
223 &self,
224 levels: Option<&[usize]>,
225 ascending: bool,
226 ) -> Result<MultiIndexDataFrame> {
227 let default_levels: Vec<usize> = (0..self.index.n_levels()).collect();
228 let sort_levels = levels.unwrap_or(&default_levels);
229
230 let mut indexed_tuples: Vec<(usize, Vec<IndexValue>)> =
232 Vec::with_capacity(self.index.len());
233 for i in 0..self.index.len() {
234 let tuple = self.index.get_tuple(i)?;
235 let sort_tuple: Vec<IndexValue> = sort_levels
236 .iter()
237 .map(|&level| tuple[level].clone())
238 .collect();
239 indexed_tuples.push((i, sort_tuple));
240 }
241
242 indexed_tuples.sort_by(|a, b| {
244 let comparison = a.1.cmp(&b.1);
245 if ascending {
246 comparison
247 } else {
248 comparison.reverse()
249 }
250 });
251
252 let sorted_indices: Vec<usize> = indexed_tuples.into_iter().map(|(idx, _)| idx).collect();
254
255 let sorted_dataframe = self.select_rows(&sorted_indices)?;
257 let sorted_index = self.select_index_rows(&sorted_indices)?;
258
259 Ok(MultiIndexDataFrame {
260 dataframe: sorted_dataframe,
261 index: sorted_index,
262 column_names: self.column_names.clone(),
263 })
264 }
265
266 pub fn get_level_values(&self, level: usize) -> Result<Vec<IndexValue>> {
268 if level >= self.index.n_levels() {
269 return Err(Error::IndexOutOfBounds {
270 index: level,
271 size: self.index.n_levels(),
272 });
273 }
274
275 let mut values = Vec::with_capacity(self.index.len());
276 for i in 0..self.index.len() {
277 let tuple = self.index.get_tuple(i)?;
278 values.push(tuple[level].clone());
279 }
280
281 Ok(values)
282 }
283
284 pub fn is_monotonic(&self, level: usize) -> Result<bool> {
286 let values = self.get_level_values(level)?;
287 if values.len() <= 1 {
288 return Ok(true);
289 }
290
291 let increasing = values.windows(2).all(|w| w[0] <= w[1]);
292 let decreasing = values.windows(2).all(|w| w[0] >= w[1]);
293
294 Ok(increasing || decreasing)
295 }
296
297 fn select_rows(&self, indices: &[usize]) -> Result<DataFrame> {
300 Ok(self.dataframe.clone())
303 }
304
305 fn select_index_rows(&self, indices: &[usize]) -> Result<AdvancedMultiIndex> {
306 let selected_tuples: Result<Vec<Vec<IndexValue>>> = indices
307 .iter()
308 .map(|&i| self.index.get_tuple(i).map(|t| t.to_vec()))
309 .collect();
310
311 AdvancedMultiIndex::new(selected_tuples?, Some(self.index.level_names().to_vec()))
312 }
313}
314
315#[derive(Debug, Clone)]
317pub struct MultiIndexGroupBy {
318 dataframe: MultiIndexDataFrame,
319 groups: HashMap<Vec<IndexValue>, Vec<usize>>,
320 group_levels: Vec<usize>,
321}
322
323impl MultiIndexGroupBy {
324 pub fn ngroups(&self) -> usize {
326 self.groups.len()
327 }
328
329 pub fn size(&self) -> HashMap<Vec<IndexValue>, usize> {
331 self.groups
332 .iter()
333 .map(|(key, indices)| (key.clone(), indices.len()))
334 .collect()
335 }
336
337 pub fn get_group(&self, key: &[IndexValue]) -> Result<MultiIndexDataFrame> {
339 let indices = self
340 .groups
341 .get(key)
342 .ok_or_else(|| Error::InvalidOperation(format!("Group key {:?} not found", key)))?;
343
344 let selected_dataframe = self.dataframe.select_rows(indices)?;
345 let selected_index = self.dataframe.select_index_rows(indices)?;
346
347 Ok(MultiIndexDataFrame {
348 dataframe: selected_dataframe,
349 index: selected_index,
350 column_names: self.dataframe.column_names.clone(),
351 })
352 }
353
354 pub fn agg(&self, func: AggregationFunction) -> Result<MultiIndexDataFrame> {
356 let mut result_data: HashMap<String, Vec<f64>> = HashMap::new();
357 let mut result_index_tuples = Vec::new();
358
359 for col_name in &self.dataframe.column_names {
361 result_data.insert(col_name.clone(), Vec::new());
362 }
363
364 for (group_key, indices) in &self.groups {
366 result_index_tuples.push(group_key.clone());
367
368 for col_name in &self.dataframe.column_names {
369 let values = self.extract_column_values(col_name, indices)?;
370 let aggregated = func.apply(&values);
371 result_data.get_mut(col_name).unwrap().push(aggregated);
372 }
373 }
374
375 let group_level_names: Vec<Option<String>> = self
377 .group_levels
378 .iter()
379 .map(|&level| self.dataframe.index.level_names()[level].clone())
380 .collect();
381
382 let result_index = AdvancedMultiIndex::new(result_index_tuples, Some(group_level_names))?;
383
384 let result_dataframe = self.dataframe.dataframe.clone(); Ok(MultiIndexDataFrame {
388 dataframe: result_dataframe,
389 index: result_index,
390 column_names: self.dataframe.column_names.clone(),
391 })
392 }
393
394 pub fn apply<F>(&self, func: F) -> Result<MultiIndexDataFrame>
396 where
397 F: Fn(&MultiIndexDataFrame) -> Result<MultiIndexDataFrame>,
398 {
399 let mut result_parts = Vec::new();
400
401 for (group_key, indices) in &self.groups {
402 let group_dataframe = self.get_group(group_key)?;
403 let group_result = func(&group_dataframe)?;
404 result_parts.push((group_key.clone(), group_result));
405 }
406
407 if let Some((_, first_result)) = result_parts.first() {
409 Ok(first_result.clone()) } else {
411 Err(Error::InvalidOperation(
412 "No groups to apply function to".to_string(),
413 ))
414 }
415 }
416
417 pub fn groups(&self) -> &HashMap<Vec<IndexValue>, Vec<usize>> {
419 &self.groups
420 }
421
422 fn extract_column_values(&self, column_name: &str, indices: &[usize]) -> Result<Vec<f64>> {
423 Ok(vec![0.0; indices.len()])
426 }
427}
428
/// A reduction that collapses a slice of numbers into a single `f64`.
#[derive(Debug, Clone, Copy)]
pub enum AggregationFunction {
    /// Sum of the values.
    Sum,
    /// Arithmetic mean.
    Mean,
    /// Number of values, as `f64`.
    Count,
    /// Smallest value.
    Min,
    /// Largest value.
    Max,
    /// Sample standard deviation (denominator `n - 1`).
    Std,
    /// Sample variance (denominator `n - 1`).
    Var,
    /// First value in slice order.
    First,
    /// Last value in slice order.
    Last,
}

impl AggregationFunction {
    /// Reduces `values` to one number according to `self`.
    ///
    /// An empty slice yields `0.0` for every variant. `Std` and `Var` also yield `0.0`
    /// for a single value. `Min`/`Max` fold with `f64::min`/`f64::max`, which return the
    /// non-NaN operand when only one operand is NaN.
    pub fn apply(&self, values: &[f64]) -> f64 {
        let (first, last) = match (values.first(), values.last()) {
            (Some(&first), Some(&last)) => (first, last),
            _ => return 0.0,
        };

        match *self {
            Self::Sum => values.iter().sum(),
            Self::Mean => values.iter().sum::<f64>() / values.len() as f64,
            Self::Count => values.len() as f64,
            Self::Min => values.iter().copied().fold(f64::INFINITY, f64::min),
            Self::Max => values.iter().copied().fold(f64::NEG_INFINITY, f64::max),
            Self::Std => Self::sample_variance(values).sqrt(),
            Self::Var => Self::sample_variance(values),
            Self::First => first,
            Self::Last => last,
        }
    }

    /// Two-pass sample variance (denominator `n - 1`).
    ///
    /// Returns `0.0` when fewer than two values are supplied.
    fn sample_variance(values: &[f64]) -> f64 {
        if values.len() <= 1 {
            return 0.0;
        }

        let mean = values.iter().sum::<f64>() / values.len() as f64;
        let squared_deviations = values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>();
        squared_deviations / (values.len() - 1) as f64
    }
}
479
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an index over two-level tuples and checks its basic shape.
    #[test]
    fn test_cross_section_selection() {
        let tuples = vec![
            vec![IndexValue::from("A"), IndexValue::from(1)],
            vec![IndexValue::from("A"), IndexValue::from(2)],
            vec![IndexValue::from("B"), IndexValue::from(1)],
            vec![IndexValue::from("B"), IndexValue::from(2)],
        ];

        let index = AdvancedMultiIndex::new(tuples, None).unwrap();

        assert_eq!(index.len(), 4);
        assert_eq!(index.n_levels(), 2);
    }

    /// Grouping on level 0 yields one group per distinct first-level value.
    #[test]
    fn test_groupby_level() {
        let tuples = vec![
            vec![
                IndexValue::from("A"),
                IndexValue::from("X"),
                IndexValue::from(1),
            ],
            vec![
                IndexValue::from("A"),
                IndexValue::from("Y"),
                IndexValue::from(2),
            ],
            vec![
                IndexValue::from("B"),
                IndexValue::from("X"),
                IndexValue::from(3),
            ],
            vec![
                IndexValue::from("B"),
                IndexValue::from("Y"),
                IndexValue::from(4),
            ],
        ];

        let index = AdvancedMultiIndex::new(tuples, None).unwrap();
        let group_keys = index.get_group_keys(&[0]).unwrap();

        assert_eq!(group_keys.len(), 2);

        let indices_a = index
            .get_group_indices(&[0], &[IndexValue::from("A")])
            .unwrap();
        assert_eq!(indices_a, vec![0, 1]);

        let indices_b = index
            .get_group_indices(&[0], &[IndexValue::from("B")])
            .unwrap();
        assert_eq!(indices_b, vec![2, 3]);
    }

    #[test]
    fn test_aggregation_functions() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        assert_eq!(AggregationFunction::Sum.apply(&values), 15.0);
        assert_eq!(AggregationFunction::Mean.apply(&values), 3.0);
        assert_eq!(AggregationFunction::Count.apply(&values), 5.0);
        assert_eq!(AggregationFunction::Min.apply(&values), 1.0);
        assert_eq!(AggregationFunction::Max.apply(&values), 5.0);
        assert_eq!(AggregationFunction::First.apply(&values), 1.0);
        assert_eq!(AggregationFunction::Last.apply(&values), 5.0);

        let std_val = AggregationFunction::Std.apply(&values);
        let var_val = AggregationFunction::Var.apply(&values);
        assert!((std_val - 1.58113883).abs() < 1e-6);
        assert!((var_val - 2.5).abs() < 1e-6);
    }

    /// Pins the edge-case contract of `AggregationFunction::apply`: empty input yields
    /// `0.0` for every variant, and the sample statistics are `0.0` for one value.
    #[test]
    fn test_aggregation_functions_edge_cases() {
        let all = [
            AggregationFunction::Sum,
            AggregationFunction::Mean,
            AggregationFunction::Count,
            AggregationFunction::Min,
            AggregationFunction::Max,
            AggregationFunction::Std,
            AggregationFunction::Var,
            AggregationFunction::First,
            AggregationFunction::Last,
        ];
        for func in all.iter() {
            assert_eq!(func.apply(&[]), 0.0, "{:?} on empty input", func);
        }

        let single = [7.0];
        assert_eq!(AggregationFunction::Std.apply(&single), 0.0);
        assert_eq!(AggregationFunction::Var.apply(&single), 0.0);
        assert_eq!(AggregationFunction::Mean.apply(&single), 7.0);
        assert_eq!(AggregationFunction::Min.apply(&single), 7.0);
        assert_eq!(AggregationFunction::Max.apply(&single), 7.0);
        assert_eq!(AggregationFunction::Count.apply(&single), 1.0);
    }

    /// Reading individual levels back out of each tuple returns them in row order.
    #[test]
    fn test_level_selection() {
        let tuples = vec![
            vec![
                IndexValue::from("A"),
                IndexValue::from("X"),
                IndexValue::from(1),
            ],
            vec![
                IndexValue::from("B"),
                IndexValue::from("Y"),
                IndexValue::from(2),
            ],
        ];

        let index = AdvancedMultiIndex::new(tuples, None).unwrap();

        let level_0_values = (0..index.len())
            .map(|i| index.get_tuple(i).unwrap()[0].clone())
            .collect::<Vec<_>>();
        assert_eq!(
            level_0_values,
            vec![IndexValue::from("A"), IndexValue::from("B")]
        );

        let level_2_values = (0..index.len())
            .map(|i| index.get_tuple(i).unwrap()[2].clone())
            .collect::<Vec<_>>();
        assert_eq!(
            level_2_values,
            vec![IndexValue::from(1), IndexValue::from(2)]
        );
    }
}