1use crate::error::Result;
7use crate::primitives::{Matrix, Vector};
8
/// A column-oriented table of `f32` values with named columns.
///
/// Invariants upheld by [`DataFrame::new`], [`DataFrame::add_column`] and
/// [`DataFrame::drop_column`]:
/// - there is always at least one column,
/// - every column holds exactly `n_rows` values,
/// - column names are non-empty and unique.
#[derive(Debug, Clone)]
pub struct DataFrame {
    // Columns in insertion order, each paired with its name.
    columns: Vec<(String, Vector<f32>)>,
    // Row count shared by every column (fixed by the first column in `new`).
    n_rows: usize,
}
32
33impl DataFrame {
34 pub fn new(columns: Vec<(String, Vector<f32>)>) -> Result<Self> {
40 if columns.is_empty() {
41 return Err("DataFrame must have at least one column".into());
42 }
43
44 let n_rows = columns[0].1.len();
45
46 for (name, col) in &columns {
48 if col.len() != n_rows {
49 return Err("All columns must have the same length".into());
50 }
51 if name.is_empty() {
52 return Err("Column names cannot be empty".into());
53 }
54 }
55
56 let mut names: Vec<&str> = columns.iter().map(|(n, _)| n.as_str()).collect();
58 names.sort_unstable();
59 for i in 1..names.len() {
60 if names[i] == names[i - 1] {
61 return Err("Duplicate column names not allowed".into());
62 }
63 }
64
65 Ok(Self { columns, n_rows })
66 }
67
68 #[must_use]
70 pub fn shape(&self) -> (usize, usize) {
71 (self.n_rows, self.columns.len())
72 }
73
74 #[must_use]
76 pub fn n_rows(&self) -> usize {
77 self.n_rows
78 }
79
80 #[must_use]
82 pub fn n_cols(&self) -> usize {
83 self.columns.len()
84 }
85
86 #[must_use]
88 pub fn column_names(&self) -> Vec<&str> {
89 self.columns.iter().map(|(n, _)| n.as_str()).collect()
90 }
91
92 pub fn column(&self, name: &str) -> Result<&Vector<f32>> {
98 self.columns
99 .iter()
100 .find(|(n, _)| n == name)
101 .map(|(_, v)| v)
102 .ok_or_else(|| "Column not found".into())
103 }
104
105 pub fn select(&self, names: &[&str]) -> Result<Self> {
111 if names.is_empty() {
112 return Err("Must select at least one column".into());
113 }
114
115 let mut selected = Vec::with_capacity(names.len());
116
117 for &name in names {
118 let col = self.column(name)?;
119 selected.push((name.to_string(), col.clone()));
120 }
121
122 Self::new(selected)
123 }
124
125 pub fn row(&self, idx: usize) -> Result<Vector<f32>> {
131 if idx >= self.n_rows {
132 return Err("Row index out of bounds".into());
133 }
134
135 let data: Vec<f32> = self.columns.iter().map(|(_, col)| col[idx]).collect();
136 Ok(Vector::from_vec(data))
137 }
138
139 #[must_use]
143 pub fn to_matrix(&self) -> Matrix<f32> {
144 let mut data = Vec::with_capacity(self.n_rows * self.columns.len());
145
146 for row_idx in 0..self.n_rows {
147 for (_, col) in &self.columns {
148 data.push(col[row_idx]);
149 }
150 }
151
152 Matrix::from_vec(self.n_rows, self.columns.len(), data)
153 .expect("Internal error: data size mismatch")
154 }
155
156 pub fn iter_columns(&self) -> impl Iterator<Item = (&str, &Vector<f32>)> {
158 self.columns.iter().map(|(n, v)| (n.as_str(), v))
159 }
160
161 pub fn add_column(&mut self, name: String, data: Vector<f32>) -> Result<()> {
167 if data.len() != self.n_rows {
168 return Err("Column length must match existing rows".into());
169 }
170
171 if self.columns.iter().any(|(n, _)| n == &name) {
172 return Err("Column name already exists".into());
173 }
174
175 if name.is_empty() {
176 return Err("Column name cannot be empty".into());
177 }
178
179 self.columns.push((name, data));
180 Ok(())
181 }
182
183 pub fn drop_column(&mut self, name: &str) -> Result<()> {
189 if self.columns.len() == 1 {
190 return Err("Cannot drop the last column".into());
191 }
192
193 let idx = self
194 .columns
195 .iter()
196 .position(|(n, _)| n == name)
197 .ok_or("Column not found")?;
198
199 self.columns.remove(idx);
200 Ok(())
201 }
202
203 #[must_use]
205 pub fn describe(&self) -> Vec<ColumnStats> {
206 self.columns
207 .iter()
208 .map(|(name, col)| {
209 let mean = col.mean();
210 let variance = col.variance();
211 let std = variance.sqrt();
212
213 let mut sorted: Vec<f32> = col.as_slice().to_vec();
214 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
215
216 let min = sorted.first().copied().unwrap_or(0.0);
217 let max = sorted.last().copied().unwrap_or(0.0);
218 let median = if sorted.is_empty() {
219 0.0
220 } else if sorted.len() % 2 == 0 {
221 (sorted[sorted.len() / 2 - 1] + sorted[sorted.len() / 2]) / 2.0
222 } else {
223 sorted[sorted.len() / 2]
224 };
225
226 ColumnStats {
227 name: name.clone(),
228 count: col.len(),
229 mean,
230 std,
231 min,
232 median,
233 max,
234 }
235 })
236 .collect()
237 }
238}
239
/// Summary statistics for a single column, as produced by [`DataFrame::describe`].
#[derive(Debug, Clone, PartialEq)]
pub struct ColumnStats {
    /// Name of the column these statistics describe.
    pub name: String,
    /// Number of values in the column (NaN values included).
    pub count: usize,
    /// Mean of the column, as computed by `Vector::mean`.
    pub mean: f32,
    /// Standard deviation: the square root of the column's `Vector::variance`.
    pub std: f32,
    /// Smallest value (`0.0` for an empty column).
    pub min: f32,
    /// Median value. For an even count it is the mean of the two middle values
    /// (`0.0` for an empty column).
    pub median: f32,
    /// Largest value (`0.0` for an empty column).
    pub max: f32,
}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for `f32` comparisons. Every expected value below is
    /// exactly representable, so this only absorbs rounding in the computations.
    const EPS: f32 = 1e-6;

    /// Asserts that `actual` is within `EPS` of `expected`, printing both on failure.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// A 3x3 frame: a = [1, 2, 3], b = [4, 5, 6], c = [7, 8, 9].
    fn sample_df() -> DataFrame {
        let columns = vec![
            ("a".to_string(), Vector::from_slice(&[1.0, 2.0, 3.0])),
            ("b".to_string(), Vector::from_slice(&[4.0, 5.0, 6.0])),
            ("c".to_string(), Vector::from_slice(&[7.0, 8.0, 9.0])),
        ];
        DataFrame::new(columns)
            .expect("sample_df should create valid DataFrame with equal-length columns")
    }

    #[test]
    fn test_new() {
        let df = sample_df();
        assert_eq!(df.shape(), (3, 3));
        assert_eq!(df.n_rows(), 3);
        assert_eq!(df.n_cols(), 3);
    }

    #[test]
    fn test_new_empty_error() {
        let result = DataFrame::new(vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_new_mismatched_lengths_error() {
        let columns = vec![
            ("a".to_string(), Vector::from_slice(&[1.0, 2.0, 3.0])),
            ("b".to_string(), Vector::from_slice(&[4.0, 5.0])),
        ];
        let result = DataFrame::new(columns);
        assert!(result.is_err());
    }

    #[test]
    fn test_new_duplicate_names_error() {
        let columns = vec![
            ("a".to_string(), Vector::from_slice(&[1.0, 2.0])),
            ("a".to_string(), Vector::from_slice(&[3.0, 4.0])),
        ];
        let result = DataFrame::new(columns);
        assert!(result.is_err());
    }

    #[test]
    fn test_new_empty_name_error() {
        let columns = vec![(String::new(), Vector::from_slice(&[1.0, 2.0]))];
        let result = DataFrame::new(columns);
        assert!(result.is_err());
    }

    #[test]
    fn test_column_names() {
        let df = sample_df();
        let names = df.column_names();
        assert_eq!(names, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_column() {
        let df = sample_df();
        let col = df
            .column("b")
            .expect("column 'b' should exist in sample_df");
        assert_eq!(col.len(), 3);
        assert_close(col[0], 4.0);
        assert_close(col[1], 5.0);
        assert_close(col[2], 6.0);
    }

    #[test]
    fn test_column_not_found() {
        let df = sample_df();
        let result = df.column("z");
        assert!(result.is_err());
    }

    #[test]
    fn test_select() {
        let df = sample_df();
        let selected = df
            .select(&["a", "c"])
            .expect("select should succeed with existing column names");
        assert_eq!(selected.shape(), (3, 2));
        assert_eq!(selected.column_names(), vec!["a", "c"]);
    }

    #[test]
    fn test_select_empty_error() {
        let df = sample_df();
        let result = df.select(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_select_not_found_error() {
        let df = sample_df();
        let result = df.select(&["a", "z"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_select_duplicate_names_error() {
        // Selecting the same column twice would create duplicate names.
        let df = sample_df();
        assert!(df.select(&["a", "a"]).is_err());
    }

    #[test]
    fn test_row() {
        let df = sample_df();
        let row = df
            .row(1)
            .expect("row index 1 should be valid for 3-row DataFrame");
        assert_eq!(row.len(), 3);
        assert_close(row[0], 2.0);
        assert_close(row[1], 5.0);
        assert_close(row[2], 8.0);
    }

    #[test]
    fn test_row_out_of_bounds() {
        let df = sample_df();
        let result = df.row(10);
        assert!(result.is_err());
    }

    #[test]
    fn test_to_matrix() {
        let df = sample_df();
        let matrix = df.to_matrix();
        assert_eq!(matrix.shape(), (3, 3));

        // Each matrix row holds one value from every frame column, in column order.
        let expected = [[1.0, 4.0, 7.0], [2.0, 5.0, 8.0], [3.0, 6.0, 9.0]];
        for (r, row) in expected.iter().enumerate() {
            for (c, &value) in row.iter().enumerate() {
                assert_close(matrix.get(r, c), value);
            }
        }
    }

    #[test]
    fn test_add_column() {
        let mut df = sample_df();
        let new_col = Vector::from_slice(&[10.0, 11.0, 12.0]);
        df.add_column("d".to_string(), new_col)
            .expect("add_column should succeed with matching length");

        assert_eq!(df.n_cols(), 4);
        let col = df
            .column("d")
            .expect("column 'd' should exist after add_column");
        assert_close(col[0], 10.0);
    }

    #[test]
    fn test_add_column_wrong_length() {
        let mut df = sample_df();
        let new_col = Vector::from_slice(&[10.0, 11.0]);
        let result = df.add_column("d".to_string(), new_col);
        assert!(result.is_err());
    }

    #[test]
    fn test_add_column_duplicate_name() {
        let mut df = sample_df();
        let new_col = Vector::from_slice(&[10.0, 11.0, 12.0]);
        let result = df.add_column("a".to_string(), new_col);
        assert!(result.is_err());
    }

    #[test]
    fn test_add_column_empty_name() {
        let mut df = sample_df();
        let new_col = Vector::from_slice(&[10.0, 11.0, 12.0]);
        let result = df.add_column(String::new(), new_col);
        assert!(result.is_err());
    }

    #[test]
    fn test_drop_column() {
        let mut df = sample_df();
        df.drop_column("b")
            .expect("drop_column should succeed for existing column 'b'");

        assert_eq!(df.n_cols(), 2);
        assert!(df.column("b").is_err());
    }

    #[test]
    fn test_drop_column_preserves_order() {
        let mut df = sample_df();
        df.drop_column("b")
            .expect("drop_column should succeed for existing column 'b'");
        assert_eq!(df.column_names(), vec!["a", "c"]);
    }

    #[test]
    fn test_drop_column_not_found() {
        let mut df = sample_df();
        let result = df.drop_column("z");
        assert!(result.is_err());
    }

    #[test]
    fn test_drop_last_column_error() {
        let columns = vec![("a".to_string(), Vector::from_slice(&[1.0, 2.0]))];
        let mut df = DataFrame::new(columns)
            .expect("DataFrame creation should succeed with single valid column");
        let result = df.drop_column("a");
        assert!(result.is_err());
    }

    #[test]
    fn test_describe() {
        let columns = vec![(
            "x".to_string(),
            Vector::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0]),
        )];
        let df = DataFrame::new(columns)
            .expect("DataFrame creation should succeed with valid 5-element column");
        let stats = df.describe();

        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].name, "x");
        assert_eq!(stats[0].count, 5);
        assert_close(stats[0].mean, 3.0);
        assert_close(stats[0].min, 1.0);
        assert_close(stats[0].max, 5.0);
        assert_close(stats[0].median, 3.0);
    }

    #[test]
    fn test_describe_multiple_columns() {
        let stats = sample_df().describe();
        assert_eq!(stats.len(), 3);

        // (name, min, median == mean, max) for each sample column.
        let expected = [
            ("a", 1.0, 2.0, 3.0),
            ("b", 4.0, 5.0, 6.0),
            ("c", 7.0, 8.0, 9.0),
        ];
        for (s, &(name, min, median, max)) in stats.iter().zip(expected.iter()) {
            assert_eq!(s.name, name);
            assert_eq!(s.count, 3);
            assert_close(s.mean, median);
            assert_close(s.min, min);
            assert_close(s.median, median);
            assert_close(s.max, max);
        }
    }

    #[test]
    fn test_iter_columns() {
        let df = sample_df();
        let cols: Vec<_> = df.iter_columns().collect();
        assert_eq!(cols.len(), 3);
        assert_eq!(cols[0].0, "a");
        assert_eq!(cols[1].0, "b");
        assert_eq!(cols[2].0, "c");
    }

    #[test]
    fn test_select_preserves_property() {
        let df = sample_df();
        let selected = df
            .select(&["a", "c"])
            .expect("select should succeed with existing columns");

        let orig_a = df
            .column("a")
            .expect("column 'a' should exist in original DataFrame");
        let sel_a = selected
            .column("a")
            .expect("column 'a' should exist in selected DataFrame");

        assert_eq!(orig_a.len(), sel_a.len());
        for i in 0..orig_a.len() {
            assert_close(sel_a[i], orig_a[i]);
        }
    }

    #[test]
    fn test_to_matrix_column_count() {
        let df = sample_df();
        let selected = df
            .select(&["a", "b"])
            .expect("select should succeed with existing columns 'a' and 'b'");
        let matrix = selected.to_matrix();
        assert_eq!(matrix.n_cols(), 2);
    }

    #[test]
    fn test_describe_median_even_length() {
        let columns = vec![("x".to_string(), Vector::from_slice(&[1.0, 2.0, 3.0, 4.0]))];
        let df = DataFrame::new(columns)
            .expect("DataFrame creation should succeed with valid 4-element column");
        let stats = df.describe();

        assert_close(stats[0].median, 2.5);
    }

    #[test]
    fn test_describe_median_odd_length() {
        let columns = vec![("x".to_string(), Vector::from_slice(&[1.0, 2.0, 3.0]))];
        let df = DataFrame::new(columns)
            .expect("DataFrame creation should succeed with valid 3-element column");
        let stats = df.describe();

        assert_close(stats[0].median, 2.0);
    }

    #[test]
    fn test_describe_median_two_elements() {
        let columns = vec![("x".to_string(), Vector::from_slice(&[10.0, 20.0]))];
        let df = DataFrame::new(columns)
            .expect("DataFrame creation should succeed with valid 2-element column");
        let stats = df.describe();

        assert_close(stats[0].median, 15.0);
    }

    #[test]
    fn test_describe_median_arithmetic_mutations() {
        // Distinct, evenly spaced values so any mutation of the midpoint
        // arithmetic (wrong index, wrong operator) changes the result.
        let columns = vec![("x".to_string(), Vector::from_slice(&[2.0, 4.0, 6.0, 8.0]))];
        let df = DataFrame::new(columns)
            .expect("DataFrame creation should succeed with valid 4-element column");
        let stats = df.describe();

        assert_close(stats[0].median, 5.0);
        assert!(
            stats[0].median > 0.0,
            "Median should be positive, got {}",
            stats[0].median
        );
        assert!(
            stats[0].median < 10.0,
            "Median should be < 10, got {}",
            stats[0].median
        );
    }

    #[test]
    fn test_describe_median_unsorted_input() {
        let columns = vec![(
            "x".to_string(),
            Vector::from_slice(&[5.0, 1.0, 3.0, 2.0, 4.0]),
        )];
        let df = DataFrame::new(columns)
            .expect("DataFrame creation should succeed with valid 5-element unsorted column");
        let stats = df.describe();

        assert_close(stats[0].median, 3.0);
    }

    #[test]
    fn test_describe_six_elements() {
        let columns = vec![(
            "x".to_string(),
            Vector::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
        )];
        let df = DataFrame::new(columns)
            .expect("DataFrame creation should succeed with valid 6-element column");
        let stats = df.describe();

        assert_close(stats[0].median, 3.5);
    }
}