use indexmap::IndexMap;
use super::Value;
/// A column-oriented table: named columns of [`Value`]s, kept in insertion order.
///
/// Every column is expected to hold exactly `nrows` values; the methods that add
/// or append data enforce or maintain that invariant.
#[derive(Debug, Clone)]
pub struct DataFrame {
    /// Column name -> column values, ordered by insertion.
    columns: IndexMap<String, Vec<Value>>,
    /// Row count shared by every column.
    nrows: usize,
}
impl DataFrame {
pub fn new() -> Self {
DataFrame {
columns: IndexMap::new(),
nrows: 0,
}
}
pub fn column(&self, name: &str) -> Option<&[Value]> {
self.columns.get(name).map(|v| v.as_slice())
}
pub fn nrows(&self) -> usize {
self.nrows
}
pub fn ncols(&self) -> usize {
self.columns.len()
}
pub fn column_names(&self) -> Vec<&str> {
self.columns.keys().map(|s| s.as_str()).collect()
}
pub fn has_column(&self, name: &str) -> bool {
self.columns.contains_key(name)
}
pub fn add_column(&mut self, name: String, values: Vec<Value>) {
if self.columns.is_empty() {
self.nrows = values.len();
} else {
assert_eq!(
values.len(),
self.nrows,
"Column '{}' has {} values but DataFrame has {} rows",
name,
values.len(),
self.nrows
);
}
self.columns.insert(name, values);
}
pub fn column_mut(&mut self, name: &str) -> Option<&mut Vec<Value>> {
self.columns.get_mut(name)
}
pub fn group_by(&self, keys: &[&str]) -> Vec<DataFrame> {
if self.nrows == 0 {
return vec![];
}
let mut group_map: IndexMap<Vec<String>, Vec<usize>> = IndexMap::new();
for i in 0..self.nrows {
let key: Vec<String> = keys
.iter()
.map(|k| {
self.columns
.get(*k)
.map(|col| col[i].to_group_key())
.unwrap_or_else(|| "NA".to_string())
})
.collect();
group_map.entry(key).or_default().push(i);
}
group_map
.into_values()
.map(|indices| {
let mut df = DataFrame::new();
for (name, col) in &self.columns {
let values: Vec<Value> = indices.iter().map(|&i| col[i].clone()).collect();
df.add_column(name.clone(), values);
}
df
})
.collect()
}
pub fn vstack(&mut self, other: &DataFrame) {
if other.nrows == 0 {
return;
}
if self.columns.is_empty() {
*self = other.clone();
return;
}
for (name, col) in &self.columns {
if let Some(other_col) = other.columns.get(name) {
let _ = (col, other_col);
}
}
for name in other.columns.keys() {
if !self.columns.contains_key(name) {
self.columns
.insert(name.clone(), vec![Value::Na; self.nrows]);
}
}
let old_nrows = self.nrows;
self.nrows += other.nrows;
for (name, col) in &mut self.columns {
if let Some(other_col) = other.columns.get(name) {
col.extend(other_col.iter().cloned());
} else {
col.extend(std::iter::repeat_with(|| Value::Na).take(other.nrows));
}
debug_assert_eq!(col.len(), old_nrows + other.nrows);
}
}
pub fn select(&self, columns: &[&str]) -> DataFrame {
let mut df = DataFrame::new();
for &col_name in columns {
if let Some(col) = self.columns.get(col_name) {
df.add_column(col_name.to_string(), col.clone());
}
}
df
}
pub fn row(&self, idx: usize) -> IndexMap<String, Value> {
assert!(
idx < self.nrows,
"Row index {idx} out of bounds ({} rows)",
self.nrows
);
let mut map = IndexMap::new();
for (name, col) in &self.columns {
map.insert(name.clone(), col[idx].clone());
}
map
}
pub fn sort_by(&self, column: &str) -> DataFrame {
let col = match self.columns.get(column) {
Some(c) => c,
None => return self.clone(),
};
let mut indices: Vec<usize> = (0..self.nrows).collect();
indices.sort_by(|&a, &b| {
let va = col[a].as_f64().unwrap_or(f64::NAN);
let vb = col[b].as_f64().unwrap_or(f64::NAN);
va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
});
let mut df = DataFrame::new();
for (name, c) in &self.columns {
let values: Vec<Value> = indices.iter().map(|&i| c[i].clone()).collect();
df.add_column(name.clone(), values);
}
df
}
pub fn from_rows(rows: Vec<IndexMap<String, Value>>) -> Self {
if rows.is_empty() {
return DataFrame::new();
}
let mut col_names: IndexMap<String, ()> = IndexMap::new();
for row in &rows {
for key in row.keys() {
col_names.entry(key.clone()).or_default();
}
}
let mut df = DataFrame::new();
for name in col_names.keys() {
let values: Vec<Value> = rows
.iter()
.map(|row| row.get(name).cloned().unwrap_or(Value::Na))
.collect();
df.add_column(name.clone(), values);
}
df
}
pub fn unique_values(&self, column: &str) -> Vec<Value> {
let col = match self.columns.get(column) {
Some(c) => c,
None => return vec![],
};
let mut seen: Vec<String> = Vec::new();
let mut result = Vec::new();
for v in col {
let key = v.to_group_key();
if !seen.contains(&key) {
seen.push(key);
result.push(v.clone());
}
}
result
}
}
impl DataFrame {
    /// Reads a comma-separated file into a `DataFrame`.
    ///
    /// This is a minimal parser, not a full RFC 4180 implementation:
    /// - A leading UTF-8 byte-order mark is ignored.
    /// - The first line is the header. Column names are trimmed.
    /// - Every `,` splits fields. Quoted fields containing commas are NOT supported.
    /// - `NA`/`na` becomes `Value::Na`. Anything `str::parse::<f64>` accepts becomes
    ///   `Value::Float`; note that this includes spellings like `inf` and `NaN`.
    ///   Everything else becomes `Value::Str`.
    /// - Blank lines are skipped. Rows shorter than the header are padded with
    ///   `Value::Na`. Fields beyond the header's width are dropped.
    ///
    /// An empty file, or one with a header but no data rows, yields a frame with no
    /// columns.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from reading `path`. This includes the case where the
    /// file is not valid UTF-8.
    pub fn from_csv(path: &str) -> Result<Self, std::io::Error> {
        let content = std::fs::read_to_string(path)?;
        Ok(Self::parse_csv(&content))
    }

    /// Parses CSV text that is already in memory, using the format described on `from_csv`.
    fn parse_csv(content: &str) -> Self {
        /// Converts one trimmed field into a typed value.
        fn parse_field(field: &str) -> Value {
            if field == "NA" || field == "na" {
                Value::Na
            } else if let Ok(f) = field.parse::<f64>() {
                Value::Float(f)
            } else {
                Value::Str(field.to_string())
            }
        }

        // Tools such as Excel prepend a BOM. `trim()` does not remove it, so without
        // this step the first column name would start with U+FEFF.
        let content = content.strip_prefix('\u{feff}').unwrap_or(content);
        let mut lines = content.lines();
        let header = match lines.next() {
            Some(h) => h,
            None => return DataFrame::new(),
        };
        let col_names: Vec<&str> = header.split(',').map(str::trim).collect();
        let mut columns: Vec<Vec<Value>> = vec![Vec::new(); col_names.len()];
        for line in lines.map(str::trim).filter(|l| !l.is_empty()) {
            let fields: Vec<&str> = line.split(',').collect();
            // `zip` stops at the shorter side, so fields beyond the header are dropped.
            for (col, field) in columns.iter_mut().zip(&fields) {
                col.push(parse_field(field.trim()));
            }
            // Rows shorter than the header are padded with NA.
            for col in columns.iter_mut().skip(fields.len()) {
                col.push(Value::Na);
            }
        }
        let mut df = DataFrame::new();
        for (name, values) in col_names.iter().zip(columns) {
            // Every data row fills every column, so a column is empty only when
            // there were no data rows at all.
            if !values.is_empty() {
                df.add_column(name.to_string(), values);
            }
        }
        df
    }
}
impl Default for DataFrame {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a column of floats.
    fn floats(xs: &[f64]) -> Vec<Value> {
        xs.iter().map(|&x| Value::Float(x)).collect()
    }

    /// Shorthand for a column of strings.
    fn strs(xs: &[&str]) -> Vec<Value> {
        xs.iter().map(|&s| Value::Str(s.into())).collect()
    }

    #[test]
    fn test_add_column_and_access() {
        let mut df = DataFrame::new();
        df.add_column("x".into(), floats(&[1.0, 2.0]));
        df.add_column("y".into(), floats(&[3.0, 4.0]));
        assert_eq!(df.nrows(), 2);
        assert_eq!(df.ncols(), 2);
        assert!(df.has_column("x"));
        assert!(!df.has_column("z"));
    }

    #[test]
    #[should_panic(expected = "has 1 values but DataFrame has 2 rows")]
    fn test_add_column_length_mismatch_panics() {
        let mut df = DataFrame::new();
        df.add_column("x".into(), floats(&[1.0, 2.0]));
        df.add_column("y".into(), floats(&[1.0]));
    }

    #[test]
    fn test_group_by() {
        let mut df = DataFrame::new();
        df.add_column("cat".into(), strs(&["a", "b", "a"]));
        df.add_column("val".into(), floats(&[1.0, 2.0, 3.0]));
        let groups = df.group_by(&["cat"]);
        assert_eq!(groups.len(), 2);
        assert_eq!(groups[0].nrows(), 2);
        assert_eq!(groups[1].nrows(), 1);
        // Rows within a group keep their original order.
        let vals = groups[0].column("val").unwrap();
        assert_eq!(vals[0].as_f64(), Some(1.0));
        assert_eq!(vals[1].as_f64(), Some(3.0));
    }

    #[test]
    fn test_group_by_empty() {
        assert!(DataFrame::new().group_by(&["x"]).is_empty());
    }

    #[test]
    fn test_vstack() {
        let mut df1 = DataFrame::new();
        df1.add_column("x".into(), floats(&[1.0]));
        let mut df2 = DataFrame::new();
        df2.add_column("x".into(), floats(&[2.0]));
        df1.vstack(&df2);
        assert_eq!(df1.nrows(), 2);
    }

    #[test]
    fn test_vstack_pads_missing_columns() {
        let mut df1 = DataFrame::new();
        df1.add_column("x".into(), floats(&[1.0]));
        let mut df2 = DataFrame::new();
        df2.add_column("y".into(), floats(&[2.0]));
        df1.vstack(&df2);
        assert_eq!(df1.nrows(), 2);
        assert_eq!(df1.column_names(), vec!["x", "y"]);
        let x = df1.column("x").unwrap();
        assert_eq!(x[0].as_f64(), Some(1.0));
        assert!(matches!(x[1], Value::Na));
        let y = df1.column("y").unwrap();
        assert!(matches!(y[0], Value::Na));
        assert_eq!(y[1].as_f64(), Some(2.0));
    }

    #[test]
    fn test_sort_by() {
        let mut df = DataFrame::new();
        df.add_column("x".into(), floats(&[3.0, 1.0, 2.0]));
        let sorted = df.sort_by("x");
        let col = sorted.column("x").unwrap();
        assert_eq!(col[0].as_f64(), Some(1.0));
        assert_eq!(col[1].as_f64(), Some(2.0));
        assert_eq!(col[2].as_f64(), Some(3.0));
    }

    #[test]
    fn test_select_skips_missing() {
        let mut df = DataFrame::new();
        df.add_column("x".into(), floats(&[1.0]));
        df.add_column("y".into(), floats(&[2.0]));
        let sel = df.select(&["y", "missing", "x"]);
        assert_eq!(sel.column_names(), vec!["y", "x"]);
        assert_eq!(sel.nrows(), 1);
    }

    #[test]
    fn test_row() {
        let mut df = DataFrame::new();
        df.add_column("x".into(), floats(&[1.0, 2.0]));
        let row = df.row(1);
        assert_eq!(row.len(), 1);
        assert_eq!(row["x"].as_f64(), Some(2.0));
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_row_out_of_bounds() {
        let mut df = DataFrame::new();
        df.add_column("x".into(), floats(&[1.0, 2.0]));
        df.row(5);
    }

    #[test]
    fn test_from_rows_fills_missing_keys() {
        let mut r1 = IndexMap::new();
        r1.insert("a".to_string(), Value::Float(1.0));
        let mut r2 = IndexMap::new();
        r2.insert("b".to_string(), Value::Float(2.0));
        let df = DataFrame::from_rows(vec![r1, r2]);
        assert_eq!(df.column_names(), vec!["a", "b"]);
        assert_eq!(df.nrows(), 2);
        assert!(matches!(df.column("a").unwrap()[1], Value::Na));
        assert_eq!(df.column("b").unwrap()[1].as_f64(), Some(2.0));
    }

    #[test]
    fn test_unique_values_first_seen_order() {
        let mut df = DataFrame::new();
        df.add_column("cat".into(), strs(&["a", "b", "a"]));
        let u = df.unique_values("cat");
        assert_eq!(u.len(), 2);
        assert!(matches!(&u[0], Value::Str(s) if s == "a"));
        assert!(matches!(&u[1], Value::Str(s) if s == "b"));
        assert!(df.unique_values("missing").is_empty());
    }
}