aprender/data/
mod.rs

1//! DataFrame module for named column containers.
2//!
3//! Provides a minimal DataFrame implementation (~300 LOC) for ML workflows.
4//! Heavy data wrangling should be delegated to ruchy/polars.
5
6use crate::error::Result;
7use crate::primitives::{Matrix, Vector};
8
9/// A minimal DataFrame with named columns.
10///
11/// This is a thin wrapper around `Vec<(String, Vector<f32>)>` with
12/// convenience methods for ML workflows.
13///
14/// # Examples
15///
16/// ```
17/// use aprender::data::DataFrame;
18/// use aprender::primitives::Vector;
19///
20/// let columns = vec![
21///     ("x".to_string(), Vector::from_slice(&[1.0, 2.0, 3.0])),
22///     ("y".to_string(), Vector::from_slice(&[4.0, 5.0, 6.0])),
23/// ];
24/// let df = DataFrame::new(columns).expect("DataFrame creation should succeed with valid columns");
25/// assert_eq!(df.shape(), (3, 2));
26/// ```
27#[derive(Debug, Clone)]
28pub struct DataFrame {
29    columns: Vec<(String, Vector<f32>)>,
30    n_rows: usize,
31}
32
33impl DataFrame {
34    /// Creates a new DataFrame from named columns.
35    ///
36    /// # Errors
37    ///
38    /// Returns an error if columns have different lengths or if empty.
39    pub fn new(columns: Vec<(String, Vector<f32>)>) -> Result<Self> {
40        if columns.is_empty() {
41            return Err("DataFrame must have at least one column".into());
42        }
43
44        let n_rows = columns[0].1.len();
45
46        // Verify all columns have same length
47        for (name, col) in &columns {
48            if col.len() != n_rows {
49                return Err("All columns must have the same length".into());
50            }
51            if name.is_empty() {
52                return Err("Column names cannot be empty".into());
53            }
54        }
55
56        // Check for duplicate column names
57        let mut names: Vec<&str> = columns.iter().map(|(n, _)| n.as_str()).collect();
58        names.sort_unstable();
59        for i in 1..names.len() {
60            if names[i] == names[i - 1] {
61                return Err("Duplicate column names not allowed".into());
62            }
63        }
64
65        Ok(Self { columns, n_rows })
66    }
67
68    /// Returns the shape as (n_rows, n_cols).
69    #[must_use]
70    pub fn shape(&self) -> (usize, usize) {
71        (self.n_rows, self.columns.len())
72    }
73
74    /// Returns the number of rows.
75    #[must_use]
76    pub fn n_rows(&self) -> usize {
77        self.n_rows
78    }
79
80    /// Returns the number of columns.
81    #[must_use]
82    pub fn n_cols(&self) -> usize {
83        self.columns.len()
84    }
85
86    /// Returns the column names.
87    #[must_use]
88    pub fn column_names(&self) -> Vec<&str> {
89        self.columns.iter().map(|(n, _)| n.as_str()).collect()
90    }
91
92    /// Returns a reference to a column by name.
93    ///
94    /// # Errors
95    ///
96    /// Returns an error if the column doesn't exist.
97    pub fn column(&self, name: &str) -> Result<&Vector<f32>> {
98        self.columns
99            .iter()
100            .find(|(n, _)| n == name)
101            .map(|(_, v)| v)
102            .ok_or_else(|| "Column not found".into())
103    }
104
105    /// Selects multiple columns by name, returning a new DataFrame.
106    ///
107    /// # Errors
108    ///
109    /// Returns an error if any column doesn't exist.
110    pub fn select(&self, names: &[&str]) -> Result<Self> {
111        if names.is_empty() {
112            return Err("Must select at least one column".into());
113        }
114
115        let mut selected = Vec::with_capacity(names.len());
116
117        for &name in names {
118            let col = self.column(name)?;
119            selected.push((name.to_string(), col.clone()));
120        }
121
122        Self::new(selected)
123    }
124
125    /// Returns a row as a Vector.
126    ///
127    /// # Errors
128    ///
129    /// Returns an error if the index is out of bounds.
130    pub fn row(&self, idx: usize) -> Result<Vector<f32>> {
131        if idx >= self.n_rows {
132            return Err("Row index out of bounds".into());
133        }
134
135        let data: Vec<f32> = self.columns.iter().map(|(_, col)| col[idx]).collect();
136        Ok(Vector::from_vec(data))
137    }
138
139    /// Converts the DataFrame to a Matrix (column-major stacking).
140    ///
141    /// Returns a Matrix with shape (n_rows, n_cols).
142    #[must_use]
143    pub fn to_matrix(&self) -> Matrix<f32> {
144        let mut data = Vec::with_capacity(self.n_rows * self.columns.len());
145
146        for row_idx in 0..self.n_rows {
147            for (_, col) in &self.columns {
148                data.push(col[row_idx]);
149            }
150        }
151
152        Matrix::from_vec(self.n_rows, self.columns.len(), data)
153            .expect("Internal error: data size mismatch")
154    }
155
156    /// Returns an iterator over columns as (name, vector) pairs.
157    pub fn iter_columns(&self) -> impl Iterator<Item = (&str, &Vector<f32>)> {
158        self.columns.iter().map(|(n, v)| (n.as_str(), v))
159    }
160
161    /// Adds a new column to the DataFrame.
162    ///
163    /// # Errors
164    ///
165    /// Returns an error if column length doesn't match or name already exists.
166    pub fn add_column(&mut self, name: String, data: Vector<f32>) -> Result<()> {
167        if data.len() != self.n_rows {
168            return Err("Column length must match existing rows".into());
169        }
170
171        if self.columns.iter().any(|(n, _)| n == &name) {
172            return Err("Column name already exists".into());
173        }
174
175        if name.is_empty() {
176            return Err("Column name cannot be empty".into());
177        }
178
179        self.columns.push((name, data));
180        Ok(())
181    }
182
183    /// Drops a column by name.
184    ///
185    /// # Errors
186    ///
187    /// Returns an error if the column doesn't exist or is the last column.
188    pub fn drop_column(&mut self, name: &str) -> Result<()> {
189        if self.columns.len() == 1 {
190            return Err("Cannot drop the last column".into());
191        }
192
193        let idx = self
194            .columns
195            .iter()
196            .position(|(n, _)| n == name)
197            .ok_or("Column not found")?;
198
199        self.columns.remove(idx);
200        Ok(())
201    }
202
203    /// Returns descriptive statistics for all columns.
204    #[must_use]
205    pub fn describe(&self) -> Vec<ColumnStats> {
206        self.columns
207            .iter()
208            .map(|(name, col)| {
209                let mean = col.mean();
210                let variance = col.variance();
211                let std = variance.sqrt();
212
213                let mut sorted: Vec<f32> = col.as_slice().to_vec();
214                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
215
216                let min = sorted.first().copied().unwrap_or(0.0);
217                let max = sorted.last().copied().unwrap_or(0.0);
218                let median = if sorted.is_empty() {
219                    0.0
220                } else if sorted.len() % 2 == 0 {
221                    (sorted[sorted.len() / 2 - 1] + sorted[sorted.len() / 2]) / 2.0
222                } else {
223                    sorted[sorted.len() / 2]
224                };
225
226                ColumnStats {
227                    name: name.clone(),
228                    count: col.len(),
229                    mean,
230                    std,
231                    min,
232                    median,
233                    max,
234                }
235            })
236            .collect()
237    }
238}
239
/// Descriptive statistics for a single column, as produced by
/// `DataFrame::describe`.
///
/// Derives `PartialEq` so that results can be compared directly (for
/// example with `assert_eq!`). Equality follows `f32` semantics, so two
/// values that contain NaN in the same field do not compare equal.
#[derive(Debug, Clone, PartialEq)]
pub struct ColumnStats {
    /// Column name.
    pub name: String,
    /// Number of elements in the column.
    pub count: usize,
    /// Mean value.
    pub mean: f32,
    /// Standard deviation (square root of the column's variance).
    pub std: f32,
    /// Minimum value.
    pub min: f32,
    /// Median value.
    pub median: f32,
    /// Maximum value.
    pub max: f32,
}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for every floating-point comparison below.
    const TOL: f32 = 1e-6;

    /// Returns `true` when `actual` lies strictly within `TOL` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOL
    }

    /// Three columns `a`, `b`, `c` holding 1..=3, 4..=6 and 7..=9.
    fn sample_df() -> DataFrame {
        DataFrame::new(vec![
            ("a".to_string(), Vector::from_slice(&[1.0, 2.0, 3.0])),
            ("b".to_string(), Vector::from_slice(&[4.0, 5.0, 6.0])),
            ("c".to_string(), Vector::from_slice(&[7.0, 8.0, 9.0])),
        ])
        .expect("sample_df should create valid DataFrame with equal-length columns")
    }

    #[test]
    fn test_new() {
        let frame = sample_df();
        assert_eq!(frame.shape(), (3, 3));
        assert_eq!(frame.n_rows(), 3);
        assert_eq!(frame.n_cols(), 3);
    }

    #[test]
    fn test_new_empty_error() {
        assert!(DataFrame::new(vec![]).is_err());
    }

    #[test]
    fn test_new_mismatched_lengths_error() {
        let outcome = DataFrame::new(vec![
            ("a".to_string(), Vector::from_slice(&[1.0, 2.0, 3.0])),
            ("b".to_string(), Vector::from_slice(&[4.0, 5.0])),
        ]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_new_duplicate_names_error() {
        let outcome = DataFrame::new(vec![
            ("a".to_string(), Vector::from_slice(&[1.0, 2.0])),
            ("a".to_string(), Vector::from_slice(&[3.0, 4.0])),
        ]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_new_empty_name_error() {
        let outcome = DataFrame::new(vec![(String::new(), Vector::from_slice(&[1.0, 2.0]))]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_column_names() {
        let frame = sample_df();
        assert_eq!(frame.column_names(), vec!["a", "b", "c"]);
    }

    #[test]
    fn test_column() {
        let frame = sample_df();
        let b = frame
            .column("b")
            .expect("column 'b' should exist in sample_df");
        assert_eq!(b.len(), 3);

        let expected: [f32; 3] = [4.0, 5.0, 6.0];
        for (i, want) in expected.iter().enumerate() {
            assert!(close(b[i], *want));
        }
    }

    #[test]
    fn test_column_not_found() {
        assert!(sample_df().column("z").is_err());
    }

    #[test]
    fn test_select() {
        let subset = sample_df()
            .select(&["a", "c"])
            .expect("select should succeed with existing column names");
        assert_eq!(subset.shape(), (3, 2));
        assert_eq!(subset.column_names(), vec!["a", "c"]);
    }

    #[test]
    fn test_select_empty_error() {
        assert!(sample_df().select(&[]).is_err());
    }

    #[test]
    fn test_select_not_found_error() {
        assert!(sample_df().select(&["a", "z"]).is_err());
    }

    #[test]
    fn test_row() {
        let second = sample_df()
            .row(1)
            .expect("row index 1 should be valid for 3-row DataFrame");
        assert_eq!(second.len(), 3);

        let expected: [f32; 3] = [2.0, 5.0, 8.0];
        for (i, want) in expected.iter().enumerate() {
            assert!(close(second[i], *want));
        }
    }

    #[test]
    fn test_row_out_of_bounds() {
        assert!(sample_df().row(10).is_err());
    }

    #[test]
    fn test_to_matrix() {
        let matrix = sample_df().to_matrix();
        assert_eq!(matrix.shape(), (3, 3));

        // Rows 0 and 1 must read [1, 4, 7] and [2, 5, 8] respectively.
        let expected_rows: [[f32; 3]; 2] = [[1.0, 4.0, 7.0], [2.0, 5.0, 8.0]];
        for (r, expected_row) in expected_rows.iter().enumerate() {
            for (c, want) in expected_row.iter().enumerate() {
                assert!((matrix.get(r, c) - *want).abs() < TOL);
            }
        }
    }

    #[test]
    fn test_add_column() {
        let mut frame = sample_df();
        frame
            .add_column("d".to_string(), Vector::from_slice(&[10.0, 11.0, 12.0]))
            .expect("add_column should succeed with matching length");

        assert_eq!(frame.n_cols(), 4);
        let d = frame
            .column("d")
            .expect("column 'd' should exist after add_column");
        assert!(close(d[0], 10.0));
    }

    #[test]
    fn test_add_column_wrong_length() {
        let mut frame = sample_df();
        let too_short = Vector::from_slice(&[10.0, 11.0]);
        assert!(frame.add_column("d".to_string(), too_short).is_err());
    }

    #[test]
    fn test_add_column_duplicate_name() {
        let mut frame = sample_df();
        let extra = Vector::from_slice(&[10.0, 11.0, 12.0]);
        assert!(frame.add_column("a".to_string(), extra).is_err());
    }

    #[test]
    fn test_add_column_empty_name() {
        let mut frame = sample_df();
        let extra = Vector::from_slice(&[10.0, 11.0, 12.0]);
        assert!(frame.add_column(String::new(), extra).is_err());
    }

    #[test]
    fn test_drop_column() {
        let mut frame = sample_df();
        frame
            .drop_column("b")
            .expect("drop_column should succeed for existing column 'b'");

        assert_eq!(frame.n_cols(), 2);
        assert!(frame.column("b").is_err());
    }

    #[test]
    fn test_drop_column_not_found() {
        let mut frame = sample_df();
        assert!(frame.drop_column("z").is_err());
    }

    #[test]
    fn test_drop_last_column_error() {
        let mut frame = DataFrame::new(vec![("a".to_string(), Vector::from_slice(&[1.0, 2.0]))])
            .expect("DataFrame creation should succeed with single valid column");
        assert!(frame.drop_column("a").is_err());
    }

    #[test]
    fn test_describe() {
        let frame = DataFrame::new(vec![(
            "x".to_string(),
            Vector::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0]),
        )])
        .expect("DataFrame creation should succeed with valid 5-element column");
        let summary = frame.describe();

        assert_eq!(summary.len(), 1);
        let x = &summary[0];
        assert_eq!(x.name, "x");
        assert_eq!(x.count, 5);
        assert!(close(x.mean, 3.0));
        assert!(close(x.min, 1.0));
        assert!(close(x.max, 5.0));
        assert!(close(x.median, 3.0));
    }

    #[test]
    fn test_iter_columns() {
        let frame = sample_df();
        let pairs: Vec<_> = frame.iter_columns().collect();
        assert_eq!(pairs.len(), 3);

        let expected_names = ["a", "b", "c"];
        for (pair, want) in pairs.iter().zip(expected_names.iter()) {
            assert_eq!(pair.0, *want);
        }
    }

    #[test]
    fn test_select_preserves_property() {
        // Property: a selected column holds the same values as the source column.
        let frame = sample_df();
        let subset = frame
            .select(&["a", "c"])
            .expect("select should succeed with existing columns");

        let source = frame
            .column("a")
            .expect("column 'a' should exist in original DataFrame");
        let copied = subset
            .column("a")
            .expect("column 'a' should exist in selected DataFrame");

        assert_eq!(source.len(), copied.len());
        for i in 0..source.len() {
            assert!(close(source[i], copied[i]));
        }
    }

    #[test]
    fn test_to_matrix_column_count() {
        // Property: the matrix has exactly as many columns as were selected.
        let matrix = sample_df()
            .select(&["a", "b"])
            .expect("select should succeed with existing columns 'a' and 'b'")
            .to_matrix();
        assert_eq!(matrix.n_cols(), 2);
    }

    #[test]
    fn test_describe_median_even_length() {
        // Even length: the median of [1, 2, 3, 4] is (2 + 3) / 2 = 2.5.
        let frame = DataFrame::new(vec![(
            "x".to_string(),
            Vector::from_slice(&[1.0, 2.0, 3.0, 4.0]),
        )])
        .expect("DataFrame creation should succeed with valid 4-element column");
        let summary = frame.describe();

        // Guards the parity test, both middle indices, the sum and the
        // division by two against mutation.
        let median = summary[0].median;
        assert!(close(median, 2.5), "Expected median 2.5, got {}", median);
    }

    #[test]
    fn test_describe_median_odd_length() {
        // Odd length: the median of [1, 2, 3] is the middle element, 2.0.
        let frame = DataFrame::new(vec![("x".to_string(), Vector::from_slice(&[1.0, 2.0, 3.0]))])
            .expect("DataFrame creation should succeed with valid 3-element column");
        let summary = frame.describe();

        let median = summary[0].median;
        assert!(close(median, 2.0), "Expected median 2.0, got {}", median);
    }

    #[test]
    fn test_describe_median_two_elements() {
        // Exactly two values: the median of [10, 20] is their average, 15.
        let frame = DataFrame::new(vec![("x".to_string(), Vector::from_slice(&[10.0, 20.0]))])
            .expect("DataFrame creation should succeed with valid 2-element column");
        let summary = frame.describe();

        let median = summary[0].median;
        assert!(close(median, 15.0), "Expected median 15.0, got {}", median);
    }

    #[test]
    fn test_describe_median_arithmetic_mutations() {
        // [2, 4, 6, 8] has median (4 + 6) / 2 = 5.0. The values are chosen so
        // that a wrong operator yields a clearly different result:
        //   (4 - 6) / 2 = -1, (4 + 6) * 2 = 20, or a shifted index.
        let frame = DataFrame::new(vec![(
            "x".to_string(),
            Vector::from_slice(&[2.0, 4.0, 6.0, 8.0]),
        )])
        .expect("DataFrame creation should succeed with valid 4-element column");
        let summary = frame.describe();

        let median = summary[0].median;
        assert!(close(median, 5.0), "Expected median 5.0, got {}", median);
        assert!(median > 0.0, "Median should be positive, got {}", median);
        assert!(median < 10.0, "Median should be < 10, got {}", median);
    }

    #[test]
    fn test_describe_median_unsorted_input() {
        // Unsorted input [5, 1, 3, 2, 4] must be sorted before the middle
        // element (3) is picked.
        let frame = DataFrame::new(vec![(
            "x".to_string(),
            Vector::from_slice(&[5.0, 1.0, 3.0, 2.0, 4.0]),
        )])
        .expect("DataFrame creation should succeed with valid 5-element unsorted column");
        let summary = frame.describe();

        let median = summary[0].median;
        assert!(close(median, 3.0), "Expected median 3.0, got {}", median);
    }

    #[test]
    fn test_describe_six_elements() {
        // Six values: len / 2 = 3 and len / 2 - 1 = 2, so the median of
        // [1, 2, 3, 4, 5, 6] is (3 + 4) / 2 = 3.5.
        let frame = DataFrame::new(vec![(
            "x".to_string(),
            Vector::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
        )])
        .expect("DataFrame creation should succeed with valid 6-element column");
        let summary = frame.describe();

        let median = summary[0].median;
        assert!(close(median, 3.5), "Expected median 3.5, got {}", median);
    }
}