1use std::collections::{HashMap, HashSet};
12use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
13
14use crate::core::error::{Error, Result};
15use crate::dataframe::base::DataFrame;
16use crate::series::base::Series;
17
/// Describes which rows of a `DataFrame` a selection should keep.
///
/// `PartialEq`/`Eq` are derived so selectors can be compared and asserted on.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RowSelector {
    /// A single row identified by its index label.
    Single(String),
    /// A single row identified by its zero-based position.
    Position(usize),
    /// Several rows identified by their index labels.
    Multiple(Vec<String>),
    /// Several rows identified by their zero-based positions.
    Positions(Vec<usize>),
    /// A boolean mask; row `i` is kept when `mask[i]` is `true`.
    Boolean(Vec<bool>),
    /// A contiguous range of row positions.
    Range(IndexRange),
    /// Every row.
    All,
}

/// Describes which columns of a `DataFrame` a selection should keep.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ColumnSelector {
    /// A single column, by name.
    Single(String),
    /// Several columns, by name.
    Multiple(Vec<String>),
    /// Every column.
    All,
}

/// Position range mirroring the shapes of the `std::ops` range types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IndexRange {
    /// `start..end` (end exclusive).
    Standard { start: usize, end: usize },
    /// `start..` (through the last row).
    From { start: usize },
    /// `..end` (end exclusive).
    To { end: usize },
    /// `..` (all rows).
    Full,
    /// `start..=end` (end inclusive).
    Inclusive { start: usize, end: usize },
    /// `..=end` (end inclusive).
    ToInclusive { end: usize },
}
64
/// How `IndexAligner::align` reconciles the row counts of two frames.
///
/// `PartialEq`/`Eq` are derived so strategies can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AlignmentStrategy {
    /// Both frames are extended to the larger of the two row counts.
    Outer,
    /// Both frames are truncated to the smaller of the two row counts.
    Inner,
    /// The left frame is kept; the right frame is resized to the left row count.
    Left,
    /// The right frame is kept; the left frame is resized to the right row count.
    Right,
}
77
78#[derive(Debug, Clone)]
80pub struct MultiLevelIndex {
81 pub names: Vec<String>,
83 pub levels: Vec<Vec<String>>,
85 pub tuples: Vec<Vec<String>>,
87}
88
89impl MultiLevelIndex {
90 pub fn new(names: Vec<String>, levels: Vec<Vec<String>>) -> Result<Self> {
92 if names.len() != levels.len() {
93 return Err(Error::InvalidValue(
94 "Number of names must match number of levels".to_string(),
95 ));
96 }
97
98 if levels.is_empty() {
99 return Err(Error::InvalidValue(
100 "At least one level required".to_string(),
101 ));
102 }
103
104 let row_count = levels[0].len();
105 for level in &levels {
106 if level.len() != row_count {
107 return Err(Error::InvalidValue(
108 "All levels must have the same length".to_string(),
109 ));
110 }
111 }
112
113 let mut tuples = Vec::with_capacity(row_count);
114 for i in 0..row_count {
115 let mut tuple = Vec::with_capacity(levels.len());
116 for level in &levels {
117 tuple.push(level[i].clone());
118 }
119 tuples.push(tuple);
120 }
121
122 Ok(Self {
123 names,
124 levels,
125 tuples,
126 })
127 }
128
129 pub fn len(&self) -> usize {
131 if self.levels.is_empty() {
132 0
133 } else {
134 self.levels[0].len()
135 }
136 }
137
138 pub fn is_empty(&self) -> bool {
140 self.len() == 0
141 }
142
143 pub fn level_values(&self, level: usize) -> Result<Vec<String>> {
145 if level >= self.levels.len() {
146 return Err(Error::IndexOutOfBounds {
147 index: level,
148 size: self.levels.len(),
149 });
150 }
151
152 let mut unique_values: Vec<String> = self.levels[level].iter().cloned().collect();
153 unique_values.sort();
154 unique_values.dedup();
155 Ok(unique_values)
156 }
157
158 pub fn find_tuple(&self, tuple: &[String]) -> Vec<usize> {
160 self.tuples
161 .iter()
162 .enumerate()
163 .filter_map(|(i, t)| {
164 if t.len() >= tuple.len() && &t[..tuple.len()] == tuple {
165 Some(i)
166 } else {
167 None
168 }
169 })
170 .collect()
171 }
172}
173
174pub struct ILocIndexer<'a> {
176 dataframe: &'a DataFrame,
177}
178
179impl<'a> ILocIndexer<'a> {
180 pub fn new(dataframe: &'a DataFrame) -> Self {
181 Self { dataframe }
182 }
183
184 pub fn get(&self, row: usize) -> Result<HashMap<String, String>> {
186 if row >= self.dataframe.row_count() {
187 return Err(Error::IndexOutOfBounds {
188 index: row,
189 size: self.dataframe.row_count(),
190 });
191 }
192
193 let mut result = HashMap::new();
194 for col_name in self.dataframe.column_names() {
195 let values = self.dataframe.get_column_string_values(&col_name)?;
196 if row < values.len() {
197 result.insert(col_name, values[row].clone());
198 }
199 }
200 Ok(result)
201 }
202
203 pub fn get_at(&self, row: usize, col: usize) -> Result<String> {
205 let col_names = self.dataframe.column_names();
206 if col >= col_names.len() {
207 return Err(Error::IndexOutOfBounds {
208 index: col,
209 size: col_names.len(),
210 });
211 }
212
213 let col_name = &col_names[col];
214 let values = self.dataframe.get_column_string_values(col_name)?;
215
216 if row >= values.len() {
217 return Err(Error::IndexOutOfBounds {
218 index: row,
219 size: values.len(),
220 });
221 }
222
223 Ok(values[row].clone())
224 }
225
226 pub fn get_range(&self, rows: Range<usize>) -> Result<DataFrame> {
228 self.select_rows(RowSelector::Range(IndexRange::Standard {
229 start: rows.start,
230 end: rows.end,
231 }))
232 }
233
234 pub fn get_slice(&self, rows: Range<usize>, cols: Range<usize>) -> Result<DataFrame> {
236 let result = self.get_range(rows)?;
237 let col_names = self.dataframe.column_names();
238
239 let selected_cols: Vec<String> = col_names
241 .into_iter()
242 .skip(cols.start)
243 .take(cols.end - cols.start)
244 .collect();
245
246 let col_refs: Vec<&str> = selected_cols.iter().map(|s| s.as_str()).collect();
247 result.select_columns(&col_refs)
248 }
249
250 pub fn get_positions(&self, positions: &[usize]) -> Result<DataFrame> {
252 self.select_rows(RowSelector::Positions(positions.to_vec()))
253 }
254
255 pub fn get_boolean(&self, mask: &[bool]) -> Result<DataFrame> {
257 self.select_rows(RowSelector::Boolean(mask.to_vec()))
258 }
259
260 fn select_rows(&self, selector: RowSelector) -> Result<DataFrame> {
262 let mut result = DataFrame::new();
263
264 let indices = match selector {
265 RowSelector::Range(range) => {
266 let (start, end) = match range {
267 IndexRange::Standard { start, end } => (start, end),
268 IndexRange::From { start } => (start, self.dataframe.row_count()),
269 IndexRange::To { end } => (0, end),
270 IndexRange::Full => (0, self.dataframe.row_count()),
271 IndexRange::Inclusive { start, end } => (start, end + 1),
272 IndexRange::ToInclusive { end } => (0, end + 1),
273 };
274 (start..end.min(self.dataframe.row_count())).collect()
275 }
276 RowSelector::Positions(positions) => positions,
277 RowSelector::Boolean(mask) => mask
278 .iter()
279 .enumerate()
280 .filter_map(|(i, &include)| if include { Some(i) } else { None })
281 .collect(),
282 _ => {
283 return Err(Error::InvalidValue(
284 "Unsupported selector for iloc".to_string(),
285 ))
286 }
287 };
288
289 for col_name in self.dataframe.column_names() {
291 let column_values = self.dataframe.get_column_string_values(&col_name)?;
292 let filtered_values: Vec<String> = indices
293 .iter()
294 .filter_map(|&idx| {
295 if idx < column_values.len() {
296 Some(column_values[idx].clone())
297 } else {
298 None
299 }
300 })
301 .collect();
302
303 let filtered_series = Series::new(filtered_values, Some(col_name.clone()))?;
304 result.add_column(col_name.clone(), filtered_series)?
305 }
306
307 Ok(result)
308 }
309}
310
311pub struct LocIndexer<'a> {
313 dataframe: &'a DataFrame,
314 index: Option<&'a MultiLevelIndex>,
315}
316
317impl<'a> LocIndexer<'a> {
318 pub fn new(dataframe: &'a DataFrame) -> Self {
319 Self {
320 dataframe,
321 index: None,
322 }
323 }
324
325 pub fn with_index(dataframe: &'a DataFrame, index: &'a MultiLevelIndex) -> Self {
326 Self {
327 dataframe,
328 index: Some(index),
329 }
330 }
331
332 pub fn get(&self, label: &str) -> Result<HashMap<String, String>> {
334 let position = self.find_label_position(label)?;
335 let iloc = ILocIndexer::new(self.dataframe);
336 iloc.get(position)
337 }
338
339 pub fn get_at(&self, label: &str, column: &str) -> Result<String> {
341 let position = self.find_label_position(label)?;
342 let values = self.dataframe.get_column_string_values(column)?;
343
344 if position >= values.len() {
345 return Err(Error::IndexOutOfBounds {
346 index: position,
347 size: values.len(),
348 });
349 }
350
351 Ok(values[position].clone())
352 }
353
354 pub fn get_labels(&self, labels: &[String]) -> Result<DataFrame> {
356 let positions: Result<Vec<usize>> = labels
357 .iter()
358 .map(|label| self.find_label_position(label))
359 .collect();
360
361 let iloc = ILocIndexer::new(self.dataframe);
362 iloc.get_positions(&positions?)
363 }
364
365 pub fn get_tuple(&self, tuple: &[String]) -> Result<DataFrame> {
367 if let Some(index) = self.index {
368 let positions = index.find_tuple(tuple);
369 if positions.is_empty() {
370 return Err(Error::InvalidValue(format!("Tuple {:?} not found", tuple)));
371 }
372 let iloc = ILocIndexer::new(self.dataframe);
373 iloc.get_positions(&positions)
374 } else {
375 Err(Error::InvalidValue(
376 "Multi-level index required for tuple selection".to_string(),
377 ))
378 }
379 }
380
381 fn find_label_position(&self, label: &str) -> Result<usize> {
383 label
386 .parse::<usize>()
387 .map_err(|_| Error::InvalidValue(format!("Label '{}' not found in index", label)))
388 }
389}
390
391pub struct AtIndexer<'a> {
393 dataframe: &'a DataFrame,
394}
395
396impl<'a> AtIndexer<'a> {
397 pub fn new(dataframe: &'a DataFrame) -> Self {
398 Self { dataframe }
399 }
400
401 pub fn get(&self, label: &str, column: &str) -> Result<String> {
403 let loc = LocIndexer::new(self.dataframe);
404 loc.get_at(label, column)
405 }
406
407 pub fn set(&self, label: &str, column: &str, value: String) -> Result<DataFrame> {
409 let _result = self.dataframe.clone();
411
412 Err(Error::NotImplemented(
415 "Mutable .at operations not yet implemented".to_string(),
416 ))
417 }
418}
419
420pub struct IAtIndexer<'a> {
422 dataframe: &'a DataFrame,
423}
424
425impl<'a> IAtIndexer<'a> {
426 pub fn new(dataframe: &'a DataFrame) -> Self {
427 Self { dataframe }
428 }
429
430 pub fn get(&self, row: usize, col: usize) -> Result<String> {
432 let iloc = ILocIndexer::new(self.dataframe);
433 iloc.get_at(row, col)
434 }
435
436 pub fn set(&self, row: usize, col: usize, value: String) -> Result<DataFrame> {
438 let _result = self.dataframe.clone();
440
441 Err(Error::NotImplemented(
444 "Mutable .iat operations not yet implemented".to_string(),
445 ))
446 }
447}
448
449pub struct SelectionBuilder<'a> {
451 dataframe: &'a DataFrame,
452 row_selector: Option<RowSelector>,
453 column_selector: Option<ColumnSelector>,
454}
455
456impl<'a> SelectionBuilder<'a> {
457 pub fn new(dataframe: &'a DataFrame) -> Self {
458 Self {
459 dataframe,
460 row_selector: None,
461 column_selector: None,
462 }
463 }
464
465 pub fn rows(mut self, selector: RowSelector) -> Self {
467 self.row_selector = Some(selector);
468 self
469 }
470
471 pub fn columns(mut self, selector: ColumnSelector) -> Self {
473 self.column_selector = Some(selector);
474 self
475 }
476
477 pub fn select(self) -> Result<DataFrame> {
479 let mut result = self.dataframe.clone();
480
481 if let Some(row_selector) = self.row_selector {
483 let iloc = ILocIndexer::new(&result);
484 result = iloc.select_rows(row_selector)?;
485 }
486
487 if let Some(column_selector) = self.column_selector {
489 match column_selector {
490 ColumnSelector::Single(col) => {
491 result = result.select_columns(&[&col])?;
492 }
493 ColumnSelector::Multiple(cols) => {
494 let col_refs: Vec<&str> = cols.iter().map(|s| s.as_str()).collect();
495 result = result.select_columns(&col_refs)?;
496 }
497 ColumnSelector::All => {
498 }
500 }
501 }
502
503 Ok(result)
504 }
505}
506
507pub struct IndexAligner;
509
510impl IndexAligner {
511 pub fn align(
513 left: &DataFrame,
514 right: &DataFrame,
515 strategy: AlignmentStrategy,
516 ) -> Result<(DataFrame, DataFrame)> {
517 let left_len = left.row_count();
519 let right_len = right.row_count();
520
521 match strategy {
522 AlignmentStrategy::Outer => {
523 let max_len = left_len.max(right_len);
524 let aligned_left = Self::extend_dataframe(left, max_len)?;
525 let aligned_right = Self::extend_dataframe(right, max_len)?;
526 Ok((aligned_left, aligned_right))
527 }
528 AlignmentStrategy::Inner => {
529 let min_len = left_len.min(right_len);
530 let aligned_left = Self::truncate_dataframe(left, min_len)?;
531 let aligned_right = Self::truncate_dataframe(right, min_len)?;
532 Ok((aligned_left, aligned_right))
533 }
534 AlignmentStrategy::Left => {
535 let aligned_right = if right_len < left_len {
536 Self::extend_dataframe(right, left_len)?
537 } else {
538 Self::truncate_dataframe(right, left_len)?
539 };
540 Ok((left.clone(), aligned_right))
541 }
542 AlignmentStrategy::Right => {
543 let aligned_left = if left_len < right_len {
544 Self::extend_dataframe(left, right_len)?
545 } else {
546 Self::truncate_dataframe(left, right_len)?
547 };
548 Ok((aligned_left, right.clone()))
549 }
550 }
551 }
552
553 fn extend_dataframe(df: &DataFrame, target_len: usize) -> Result<DataFrame> {
555 if df.row_count() >= target_len {
556 return Ok(df.clone());
557 }
558
559 let mut result = DataFrame::new();
560 let current_len = df.row_count();
561
562 for col_name in df.column_names() {
563 let values = df.get_column_string_values(&col_name)?;
564 let mut extended_values = values.clone();
565
566 for i in current_len..target_len {
568 if values.is_empty() {
569 extended_values.push("NaN".to_string());
570 } else {
571 extended_values.push(values[i % values.len()].clone());
572 }
573 }
574
575 let extended_series = Series::new(extended_values, Some(col_name.clone()))?;
576 result.add_column(col_name, extended_series)?;
577 }
578
579 Ok(result)
580 }
581
582 fn truncate_dataframe(df: &DataFrame, target_len: usize) -> Result<DataFrame> {
584 if df.row_count() <= target_len {
585 return Ok(df.clone());
586 }
587
588 let iloc = ILocIndexer::new(df);
589 iloc.get_range(0..target_len)
590 }
591
592 pub fn reindex(df: &DataFrame, new_index: &[String]) -> Result<DataFrame> {
594 let mut result = DataFrame::new();
595
596 for col_name in df.column_names() {
597 let current_values = df.get_column_string_values(&col_name)?;
598 let mut reindexed_values = Vec::with_capacity(new_index.len());
599
600 for index_val in new_index {
601 if let Ok(pos) = index_val.parse::<usize>() {
603 if pos < current_values.len() {
604 reindexed_values.push(current_values[pos].clone());
605 } else {
606 reindexed_values.push("NaN".to_string());
607 }
608 } else {
609 reindexed_values.push("NaN".to_string());
610 }
611 }
612
613 let reindexed_series = Series::new(reindexed_values, Some(col_name.clone()))?;
614 result.add_column(col_name, reindexed_series)?;
615 }
616
617 Ok(result)
618 }
619}
620
621pub trait AdvancedIndexingExt {
623 fn iloc(&self) -> ILocIndexer;
625
626 fn loc(&self) -> LocIndexer;
628
629 fn at(&self) -> AtIndexer;
631
632 fn iat(&self) -> IAtIndexer;
634
635 fn select(&self) -> SelectionBuilder;
637
638 fn reset_index(&self) -> Result<DataFrame>;
640
641 fn set_index(&self, column: &str) -> Result<DataFrame>;
643
644 fn set_multi_index(&self, columns: &[String]) -> Result<(DataFrame, MultiLevelIndex)>;
646
647 fn select_columns(&self, columns: &[String]) -> Result<DataFrame>;
649
650 fn drop_columns(&self, columns: &[String]) -> Result<DataFrame>;
652
653 fn sample(&self, n: usize) -> Result<DataFrame>;
655
656 fn head(&self, n: usize) -> Result<DataFrame>;
658
659 fn tail(&self, n: usize) -> Result<DataFrame>;
661}
662
663impl AdvancedIndexingExt for DataFrame {
664 fn iloc(&self) -> ILocIndexer {
665 ILocIndexer::new(self)
666 }
667
668 fn loc(&self) -> LocIndexer {
669 LocIndexer::new(self)
670 }
671
672 fn at(&self) -> AtIndexer {
673 AtIndexer::new(self)
674 }
675
676 fn iat(&self) -> IAtIndexer {
677 IAtIndexer::new(self)
678 }
679
680 fn select(&self) -> SelectionBuilder {
681 SelectionBuilder::new(self)
682 }
683
684 fn reset_index(&self) -> Result<DataFrame> {
685 Ok(self.clone())
687 }
688
689 fn set_index(&self, column: &str) -> Result<DataFrame> {
690 self.drop_columns(&[column.to_string()])
692 }
693
694 fn set_multi_index(&self, columns: &[String]) -> Result<(DataFrame, MultiLevelIndex)> {
695 let mut level_values = Vec::new();
696 let names = columns.to_vec();
697
698 for col_name in columns {
699 let values = self.get_column_string_values(col_name)?;
700 level_values.push(values);
701 }
702
703 let multi_index = MultiLevelIndex::new(names.clone(), level_values)?;
704 let result_df = self.drop_columns(columns)?;
705
706 Ok((result_df, multi_index))
707 }
708
709 fn select_columns(&self, columns: &[String]) -> Result<DataFrame> {
710 let column_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
711 let mut result = DataFrame::new();
712
713 for col_name in &column_refs {
714 if !self.contains_column(col_name) {
715 return Err(Error::ColumnNotFound(col_name.to_string()));
716 }
717
718 let values = self.get_column_string_values(col_name)?;
719 let series = Series::new(values, Some(col_name.to_string()))?;
720 result.add_column(col_name.to_string(), series)?;
721 }
722
723 Ok(result)
724 }
725
726 fn drop_columns(&self, columns: &[String]) -> Result<DataFrame> {
727 let all_columns: HashSet<String> = self.column_names().into_iter().collect();
728 let to_drop: HashSet<String> = columns.iter().cloned().collect();
729 let to_keep: Vec<String> = all_columns.difference(&to_drop).cloned().collect();
730 let to_keep_refs: Vec<&str> = to_keep.iter().map(|s| s.as_str()).collect();
731
732 self.select_columns(&to_keep_refs)
733 }
734
735 fn sample(&self, n: usize) -> Result<DataFrame> {
736 use rand::rng;
737 use rand::seq::SliceRandom;
738
739 let row_count = self.row_count();
740 if n >= row_count {
741 return Ok(self.clone());
742 }
743
744 let mut indices: Vec<usize> = (0..row_count).collect();
745 indices.shuffle(&mut rng());
746 indices.truncate(n);
747
748 let iloc = self.iloc();
749 iloc.get_positions(&indices)
750 }
751
752 fn head(&self, n: usize) -> Result<DataFrame> {
753 let iloc = self.iloc();
754 iloc.get_range(0..n.min(self.row_count()))
755 }
756
757 fn tail(&self, n: usize) -> Result<DataFrame> {
758 let row_count = self.row_count();
759 let start = if n >= row_count { 0 } else { row_count - n };
760 let iloc = self.iloc();
761 iloc.get_range(start..row_count)
762 }
763}
764
765pub mod selectors {
767 use super::*;
768
769 pub fn row(index: String) -> RowSelector {
771 RowSelector::Single(index)
772 }
773
774 pub fn rows(indices: Vec<String>) -> RowSelector {
776 RowSelector::Multiple(indices)
777 }
778
779 pub fn pos(position: usize) -> RowSelector {
781 RowSelector::Position(position)
782 }
783
784 pub fn positions(positions: Vec<usize>) -> RowSelector {
786 RowSelector::Positions(positions)
787 }
788
789 pub fn mask(mask: Vec<bool>) -> RowSelector {
791 RowSelector::Boolean(mask)
792 }
793
794 pub fn col(name: String) -> ColumnSelector {
796 ColumnSelector::Single(name)
797 }
798
799 pub fn cols(names: Vec<String>) -> ColumnSelector {
801 ColumnSelector::Multiple(names)
802 }
803
804 pub fn range(start: usize, end: usize) -> RowSelector {
806 RowSelector::Range(IndexRange::Standard { start, end })
807 }
808
809 pub fn range_inclusive(start: usize, end: usize) -> RowSelector {
811 RowSelector::Range(IndexRange::Inclusive { start, end })
812 }
813}
814
/// Position-based access shorthand.
///
/// * `iloc!(df, row)` expands to `df.iloc().get(row)`.
/// * `iloc!(df, row, col)` expands to `df.iloc().get_at(row, col)`.
/// * `iloc!(df, rows: r, cols: c)` expands to `df.iloc().get_slice(r, c)`.
///
/// The slice form uses `rows:` / `cols:` keywords because `macro_rules!`
/// cannot tell two range expressions from two scalar expressions; an arm with
/// the same shape as the `row, col` arm would never be reached.
#[macro_export]
macro_rules! iloc {
    // Keyword form first: it can only match input that literally contains
    // `rows:` and `cols:`, so the positional forms below are unaffected.
    ($df:expr, rows: $rows:expr, cols: $cols:expr) => {
        $df.iloc().get_slice($rows, $cols)
    };
    ($df:expr, $row:expr) => {
        $df.iloc().get($row)
    };
    ($df:expr, $row:expr, $col:expr) => {
        $df.iloc().get_at($row, $col)
    };
}
828
/// Label-based access shorthand.
///
/// * `loc!(df, label)` expands to `df.loc().get(label)`.
/// * `loc!(df, label, column)` expands to `df.loc().get_at(label, column)`.
#[macro_export]
macro_rules! loc {
    ($frame:expr, $row_label:expr) => {
        $frame.loc().get($row_label)
    };
    ($frame:expr, $row_label:expr, $column:expr) => {
        $frame.loc().get_at($row_label, $column)
    };
}
838
/// Fluent-selection shorthand built on `select()`.
///
/// * `select!(df, rows: r)` expands to `df.select().rows(r).select()`.
/// * `select!(df, cols: c)` expands to `df.select().columns(c).select()`.
/// * `select!(df, rows: r, cols: c)` applies both, rows first.
#[macro_export]
macro_rules! select {
    ($frame:expr, rows: $row_sel:expr) => {
        $frame.select().rows($row_sel).select()
    };
    ($frame:expr, cols: $col_sel:expr) => {
        $frame.select().columns($col_sel).select()
    };
    ($frame:expr, rows: $row_sel:expr, cols: $col_sel:expr) => {
        $frame.select().rows($row_sel).columns($col_sel).select()
    };
}