pub mod array;
pub mod dataframe;
pub use array::Array;
pub use dataframe::core::JoinType;
pub use dataframe::{DataFrame, Series};
// Unit tests for the `Array` and `DataFrame` APIs re-exported by this crate.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Absolute tolerance for floating-point results of numerical routines.
    const EPS: f64 = 1e-10;

    /// Asserts that `actual` is within [`EPS`] of `expected`; `what` names the
    /// quantity in the failure message.
    fn assert_close(actual: f64, expected: f64, what: &str) {
        assert!(
            (actual - expected).abs() < EPS,
            "{}: expected {}, got {}",
            what,
            expected,
            actual
        );
    }

    /// Asserts that the 2x2 matrix whose `(row, col)` entries are produced by
    /// `entry` is the identity matrix within [`EPS`].
    fn assert_identity_2x2(entry: impl Fn(usize, usize) -> f64, what: &str) {
        for i in 0..2_usize {
            for j in 0..2_usize {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_close(entry(i, j), expected, what);
            }
        }
    }

    #[test]
    fn test_array_n_dimensional() {
        let arr = Array::from_vec((0..24).map(|x| x as f64).collect(), vec![2, 3, 4]);
        assert_eq!(arr.shape, vec![2, 3, 4]);
        assert_eq!(arr.ndim(), 3);
        assert_eq!(arr[&[0, 0, 0][..]], 0.0);
        assert_eq!(arr[&[1, 2, 3][..]], 23.0);

        let reshaped = arr.reshape(vec![6, 4]);
        assert_eq!(reshaped.shape, vec![6, 4]);
        assert_eq!(reshaped.data.len(), 24);
    }

    #[test]
    fn test_array_broadcasting() {
        // Shapes [1, 3] and [2, 1] broadcast to [2, 3]. `expect` makes a failed
        // broadcast fail the test instead of silently skipping the checks.
        let arr1 = Array::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3]);
        let arr2 = Array::from_vec(vec![10.0, 20.0], vec![2, 1]);
        let result = arr1
            .add_broadcast(&arr2)
            .expect("shapes [1, 3] and [2, 1] should broadcast");
        assert_eq!(result.shape, vec![2, 3]);
        assert_eq!(result[&[0, 0][..]], 11.0);
        assert_eq!(result[&[1, 2][..]], 23.0);

        // Adding a scalar to a borrowed array.
        let arr = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let scaled = &arr + 5.0;
        assert_eq!(scaled.data, vec![6.0, 7.0, 8.0, 9.0]);
    }

    #[test]
    fn test_array_reductions() {
        let arr = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        assert_eq!(arr.sum(), 21.0);
        assert_eq!(arr.mean(), 3.5);
        assert_eq!(arr.max(), 6.0);
        assert_eq!(arr.min(), 1.0);

        let sum_axis_0 = arr.sum_axis(0);
        assert_eq!(sum_axis_0.shape, vec![3]);
        assert_eq!(sum_axis_0.data, vec![5.0, 7.0, 9.0]);

        let mean_axis_1 = arr.mean_axis(1);
        assert_eq!(mean_axis_1.shape, vec![2]);
        assert_eq!(mean_axis_1.data, vec![2.0, 5.0]);
    }

    #[test]
    fn test_linear_algebra() {
        let matrix = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        // det([[1, 2], [3, 4]]) = 1*4 - 2*3 = -2. Compared with a tolerance
        // because elimination-based determinants may carry rounding error.
        assert_close(matrix.det(), -2.0, "det");
        assert_eq!(matrix.trace(), 5.0);

        let other = Array::from_vec(vec![2.0, 0.0, 1.0, 3.0], vec![2, 2]);
        let product = matrix.dot(&other);
        assert_eq!(product.data, vec![4.0, 6.0, 10.0, 12.0]);

        // The determinant is non-zero, so an inverse must exist.
        let inv = matrix
            .inv()
            .expect("matrix with non-zero determinant should be invertible");
        let should_be_identity = matrix.dot(&inv);
        assert_identity_2x2(|i, j| should_be_identity[(i, j)], "A * inv(A)");

        let (q, r) = matrix.qr();
        assert_eq!(q.shape, vec![2, 2]);
        assert_eq!(r.shape, vec![2, 2]);

        // Q must be orthogonal (Q * Q^T = I), and the factors must reproduce A.
        let q_qt = q.dot(&q.transpose());
        assert_identity_2x2(|i, j| q_qt[(i, j)], "Q * Q^T");
        let q_r = q.dot(&r);
        for i in 0..2_usize {
            for j in 0..2_usize {
                assert_close(q_r[(i, j)], matrix[(i, j)], "Q * R");
            }
        }
    }

    #[test]
    fn test_dataframe_enhanced() {
        let df = DataFrame::new(vec![
            ("id".to_string(), Series::from(vec![1, 2, 3, 4])),
            (
                "name".to_string(),
                Series::from(vec!["Alice", "Bob", "Charlie", "Diana"]),
            ),
            (
                "score".to_string(),
                Series::from(vec![85.5, 92.0, 78.5, 88.0]),
            ),
            (
                "active".to_string(),
                Series::from(vec![true, true, false, true]),
            ),
        ]);
        assert_eq!(df.shape(), (4, 4));
        assert_eq!(df.len(), 4);
        assert!(!df.is_empty());

        let head = df.head(2);
        assert_eq!(head.len(), 2);
        let tail = df.tail(2);
        assert_eq!(tail.len(), 2);

        let mask = vec![true, false, true, false];
        let filtered = df.filter(&mask);
        assert_eq!(filtered.len(), 2);

        // Sorting must keep `score` as a Float64 column. Every adjacent pair is
        // checked, not just the first, and a type change fails the test.
        let sorted = df.sort_by("score", true);
        let scores = match sorted.get_column("score") {
            Some(Series::Float64(scores)) => scores,
            _ => panic!("score should remain a Float64 column after sorting"),
        };
        for i in 1..sorted.len() {
            assert!(scores[i - 1] < scores[i], "scores not ascending at index {}", i);
        }

        let with_bonus = df.with_column(
            "bonus".to_string(),
            Series::from(vec![100.0, 150.0, 75.0, 120.0]),
        );
        assert_eq!(with_bonus.shape().1, 5);

        let dropped = df.drop(&["active"]);
        assert_eq!(dropped.shape().1, 3);
    }

    #[test]
    fn test_groupby_enhanced() {
        let df = DataFrame::new(vec![
            (
                "department".to_string(),
                Series::from(vec!["IT", "HR", "IT", "Finance", "HR"]),
            ),
            (
                "salary".to_string(),
                Series::from(vec![75000, 65000, 80000, 70000, 68000]),
            ),
            ("experience".to_string(), Series::from(vec![3, 5, 7, 4, 6])),
        ]);
        // Three distinct departments: IT, HR, Finance.
        let grouped = df.groupby("department");

        let counts = grouped.count();
        assert_eq!(counts.len(), 3);
        let sums = grouped.sum();
        assert_eq!(sums.columns.len(), 3);
        let means = grouped.mean();
        assert_eq!(means.columns.len(), 3);
        let first = grouped.first();
        assert_eq!(first.len(), 3);
        let last = grouped.last();
        assert_eq!(last.len(), 3);
    }

    #[test]
    fn test_joins() {
        let left = DataFrame::new(vec![
            ("id".to_string(), Series::from(vec!["1", "2", "3"])),
            (
                "name".to_string(),
                Series::from(vec!["Alice", "Bob", "Charlie"]),
            ),
        ]);
        let right = DataFrame::new(vec![
            ("id".to_string(), Series::from(vec!["1", "2", "4"])),
            ("score".to_string(), Series::from(vec!["85", "92", "78"])),
        ]);
        // Only ids "1" and "2" appear on both sides.
        let joined = left.join(&right, "id", JoinType::Inner);
        assert_eq!(joined.len(), 2);
        assert_eq!(joined.columns.len(), 3);
    }

    #[test]
    fn test_csv_io_with_inference() -> Result<(), Box<dyn std::error::Error>> {
        let mut temp_file = NamedTempFile::new()?;
        writeln!(temp_file, "name,age,salary,active")?;
        writeln!(temp_file, "Alice,25,50000.5,true")?;
        writeln!(temp_file, "Bob,30,60000.0,false")?;
        writeln!(temp_file, "Charlie,35,70000.25,true")?;
        // Make sure every byte reaches the file before it is re-opened by path.
        temp_file.flush()?;

        let input_path = temp_file
            .path()
            .to_str()
            .ok_or("temp file path is not valid UTF-8")?;
        let df = DataFrame::from_csv(input_path)?;
        assert_eq!(df.shape(), (3, 4));
        assert!(
            matches!(df.get_column("age"), Some(Series::Int64(_))),
            "Age should be inferred as Int64"
        );
        assert!(
            matches!(df.get_column("salary"), Some(Series::Float64(_))),
            "Salary should be inferred as Float64"
        );
        assert!(
            matches!(df.get_column("active"), Some(Series::Bool(_))),
            "Active should be inferred as Bool"
        );

        // Round-trip through `to_csv` / `from_csv` must preserve the shape.
        let output_file = NamedTempFile::new()?;
        let output_path = output_file
            .path()
            .to_str()
            .ok_or("temp file path is not valid UTF-8")?;
        df.to_csv(output_path)?;
        let df2 = DataFrame::from_csv(output_path)?;
        assert_eq!(df2.shape(), df.shape());
        Ok(())
    }

    #[test]
    fn test_json_io() -> Result<(), Box<dyn std::error::Error>> {
        let df = DataFrame::new(vec![
            ("name".to_string(), Series::from(vec!["Alice", "Bob"])),
            ("age".to_string(), Series::from(vec![25, 30])),
            ("active".to_string(), Series::from(vec![true, false])),
        ]);

        // JSON Lines round trip.
        let jsonl_file = NamedTempFile::new()?;
        let jsonl_path = jsonl_file
            .path()
            .to_str()
            .ok_or("temp file path is not valid UTF-8")?;
        df.to_jsonl(jsonl_path)?;
        let df_from_jsonl = DataFrame::from_jsonl(jsonl_path)?;
        assert_eq!(df_from_jsonl.shape(), (2, 3));

        // Plain JSON export: at minimum it must write something to the file.
        let json_file = NamedTempFile::new()?;
        let json_path = json_file
            .path()
            .to_str()
            .ok_or("temp file path is not valid UTF-8")?;
        df.to_json(json_path)?;
        assert!(
            std::fs::metadata(json_path)?.len() > 0,
            "to_json should write a non-empty file"
        );
        Ok(())
    }

    #[test]
    fn test_statistical_summary() {
        let df = DataFrame::new(vec![
            (
                "values".to_string(),
                Series::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]),
            ),
            (
                "integers".to_string(),
                Series::from(vec![10, 20, 30, 40, 50]),
            ),
            (
                "text".to_string(),
                Series::from(vec!["a", "b", "c", "d", "e"]),
            ),
        ]);
        let stats = df.describe();
        assert_eq!(
            stats.columns,
            vec!["count", "mean", "std", "min", "25%", "50%", "75%", "max"]
        );

        // Index 1 corresponds to the "mean" column asserted above. A missing or
        // non-Float64 column fails the test instead of skipping the checks.
        match stats.data.get(1) {
            Some(Series::Float64(means)) => {
                assert_eq!(means[0], 3.0);
                assert_eq!(means[1], 30.0);
            }
            _ => panic!("describe() should produce a Float64 mean column"),
        }
    }

    #[test]
    fn test_mathematical_functions() {
        let arr = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

        let exp_arr = arr.exp();
        assert_close(exp_arr[&[0, 0][..]], 1.0_f64.exp(), "exp");
        let ln_arr = arr.ln();
        assert_close(ln_arr[(0, 0)], 1.0_f64.ln(), "ln");
        let sin_arr = arr.sin();
        assert_close(sin_arr[(0, 0)], 1.0_f64.sin(), "sin");
        let sqrt_arr = arr.sqrt();
        assert_close(sqrt_arr[(0, 0)], 1.0, "sqrt");

        let pow_arr = arr.pow(2.0);
        assert_eq!(pow_arr[(0, 0)], 1.0);
        assert_eq!(pow_arr[(1, 1)], 16.0);
    }
}