use crate::column::Column;
use std::fmt;
/// Errors produced while assembling a [`DataFrame`] from named columns.
///
/// Derives `PartialEq`/`Eq` so callers (and tests) can assert on the exact
/// variant and payload instead of only checking `is_err()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DataError {
    /// A column's row count differs from the length the other columns share.
    ColumnLengthMismatch {
        /// Row count the column was expected to have.
        expected: usize,
        /// Row count the offending column actually has.
        got: usize,
        /// Name of the offending column.
        column: String,
    },
    /// The same column name was supplied more than once; carries the repeated name.
    DuplicateColumn(String),
    /// No columns were supplied.
    ///
    /// NOTE(review): `DataFrame::from_columns` currently returns an empty frame
    /// for an empty input instead of this error.
    Empty,
}

impl fmt::Display for DataError {
    /// Renders a short, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            DataError::ColumnLengthMismatch {
                expected,
                got,
                column,
            } => write!(
                f,
                "column `{}` has {} rows, expected {}",
                column, got, expected
            ),
            DataError::DuplicateColumn(name) => write!(f, "duplicate column `{}`", name),
            DataError::Empty => write!(f, "no columns provided"),
        }
    }
}

/// Lets `DataError` be used with `?` into `Box<dyn Error>` and similar error-handling code.
impl std::error::Error for DataError {}
/// An ordered collection of named columns.
///
/// Build it with [`DataFrame::from_columns`] to have names checked for
/// uniqueness and column lengths checked for equality.
///
/// NOTE(review): `columns` is public, so code that mutates it directly can
/// bypass the checks performed by `from_columns`.
#[derive(Debug, Clone)]
pub struct DataFrame {
    /// `(name, column)` pairs, kept in the order they were supplied.
    pub columns: Vec<(String, Column)>,
}

impl DataFrame {
    /// Creates a frame with no columns (and therefore no rows).
    pub fn new() -> Self {
        DataFrame {
            columns: Vec::new(),
        }
    }

    /// Builds a frame from `(name, column)` pairs, preserving their order.
    ///
    /// An empty `columns` vector yields an empty frame rather than an error.
    ///
    /// # Errors
    ///
    /// Names are validated before lengths:
    /// - [`DataError::DuplicateColumn`] with the first name that repeats an
    ///   earlier one;
    /// - [`DataError::ColumnLengthMismatch`] for the first column whose length
    ///   differs from the first column's length.
    pub fn from_columns(columns: Vec<(String, Column)>) -> Result<Self, DataError> {
        // Every column must match the length of the first one.
        let expected = match columns.first() {
            Some((_, first)) => first.len(),
            None => return Ok(Self::new()),
        };

        // `insert` returns false once a name has already been seen.
        let mut seen = std::collections::BTreeSet::new();
        if let Some((dup, _)) = columns
            .iter()
            .find(|(name, _)| !seen.insert(name.as_str()))
        {
            return Err(DataError::DuplicateColumn(dup.clone()));
        }

        if let Some((name, col)) = columns.iter().find(|(_, col)| col.len() != expected) {
            return Err(DataError::ColumnLengthMismatch {
                expected,
                got: col.len(),
                column: name.clone(),
            });
        }

        Ok(Self { columns })
    }

    /// Number of rows, read from the first column; `0` when there are no columns.
    pub fn nrows(&self) -> usize {
        match self.columns.first() {
            Some((_, col)) => col.len(),
            None => 0,
        }
    }

    /// Number of columns.
    pub fn ncols(&self) -> usize {
        self.columns.len()
    }

    /// Returns the first column named `name`, or `None` if there is none.
    pub fn get_column(&self, name: &str) -> Option<&Column> {
        let idx = self.column_index(name)?;
        Some(&self.columns[idx].1)
    }

    /// Returns the position of the first column named `name`, or `None` if there is none.
    pub fn column_index(&self, name: &str) -> Option<usize> {
        self.columns
            .iter()
            .position(|(col_name, _)| col_name.as_str() == name)
    }

    /// Returns the column names in column order, borrowed from the frame.
    pub fn column_names(&self) -> Vec<&str> {
        let mut names = Vec::with_capacity(self.columns.len());
        for (name, _) in &self.columns {
            names.push(name.as_str());
        }
        names
    }
}

impl Default for DataFrame {
    /// Equivalent to [`DataFrame::new`]: a frame with no columns.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Happy path: shape, name order and index lookups of a valid frame.
    #[test]
    fn test_from_columns() {
        let df = DataFrame::from_columns(vec![
            ("id".into(), Column::Int(vec![1, 2, 3])),
            ("name".into(), Column::Str(vec!["a".into(), "b".into(), "c".into()])),
        ])
        .unwrap();
        assert_eq!(df.nrows(), 3);
        assert_eq!(df.ncols(), 2);
        assert_eq!(df.column_names(), vec!["id", "name"]);
        assert_eq!(df.column_index("name"), Some(1));
        assert_eq!(df.column_index("missing"), None);
    }

    /// An empty input is accepted and produces an empty frame.
    #[test]
    fn test_from_columns_empty() {
        let df = DataFrame::from_columns(Vec::new()).unwrap();
        assert_eq!(df.nrows(), 0);
        assert_eq!(df.ncols(), 0);
        assert!(df.column_names().is_empty());
        assert_eq!(DataFrame::default().ncols(), 0);
    }

    /// The error must name the offending column and report both lengths,
    /// not merely be *some* error.
    #[test]
    fn test_length_mismatch() {
        let err = DataFrame::from_columns(vec![
            ("a".into(), Column::Int(vec![1, 2])),
            ("b".into(), Column::Int(vec![1, 2, 3])),
        ])
        .unwrap_err();
        assert!(matches!(
            &err,
            DataError::ColumnLengthMismatch { expected: 2, got: 3, column } if column == "b"
        ));
    }

    /// Duplicate names are reported, and name validation takes priority over
    /// the (also mismatched) lengths here.
    #[test]
    fn test_duplicate_column() {
        let err = DataFrame::from_columns(vec![
            ("a".into(), Column::Int(vec![1])),
            ("a".into(), Column::Int(vec![2, 3])),
        ])
        .unwrap_err();
        assert!(matches!(&err, DataError::DuplicateColumn(name) if name == "a"));
    }

    /// Lookup by name finds existing columns and rejects unknown ones.
    #[test]
    fn test_get_column() {
        let df = DataFrame::from_columns(vec![("x".into(), Column::Float(vec![1.0, 2.0]))])
            .unwrap();
        assert!(df.get_column("x").is_some());
        assert!(df.get_column("y").is_none());
    }
}