use polars::prelude::*;
use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
impl Dataset {
    /// Builds a [`Dataset`] from a polars [`DataFrame`].
    ///
    /// The column named `target_col` becomes the target vector. Every other
    /// column whose dtype is a float, integer, or boolean type becomes a
    /// feature, converted to `f64` (booleans become `1.0` / `0.0`). Columns of
    /// any other dtype (e.g. strings) are skipped without error. Features keep
    /// the DataFrame's column order.
    ///
    /// # Errors
    ///
    /// - [`ScryLearnError::InvalidColumn`] if `target_col` does not exist, or
    ///   if its dtype is not a float, integer, or boolean type.
    /// - [`ScryLearnError::InvalidParameter`] if the target or any kept feature
    ///   column contains nulls, or if no numeric feature column remains once
    ///   the target is excluded.
    pub fn from_dataframe(df: &DataFrame, target_col: &str) -> Result<Self> {
        let target_series = df
            .column(target_col)
            .map_err(|_| ScryLearnError::InvalidColumn(target_col.to_string()))?;
        let target_f64 = cast_column_to_f64(target_series).ok_or_else(|| {
            ScryLearnError::InvalidColumn(format!(
                "target column '{target_col}' is not numeric (type: {})",
                target_series.dtype()
            ))
        })?;
        let target = extract_f64_vec(&target_f64, target_col)?;

        // Upper bound: every column except the target could become a feature.
        let max_features = df.get_columns().len().saturating_sub(1);
        let mut features: Vec<Vec<f64>> = Vec::with_capacity(max_features);
        let mut feature_names: Vec<String> = Vec::with_capacity(max_features);
        for col in df.get_columns() {
            let name = col.name().as_str();
            if name == target_col {
                continue;
            }
            // Non-numeric columns yield `None` and are intentionally skipped.
            if let Some(cast) = cast_column_to_f64(col) {
                let vals = extract_f64_vec(&cast, name)?;
                features.push(vals);
                feature_names.push(name.to_string());
            }
        }

        if features.is_empty() {
            return Err(ScryLearnError::InvalidParameter(
                "no numeric feature columns found in DataFrame".to_string(),
            ));
        }
        Ok(Dataset::new(features, target, feature_names, target_col))
    }

    /// Converts this dataset back into a polars [`DataFrame`].
    ///
    /// One column is emitted per feature, named from `feature_names` in
    /// order. The target column, named `target_name`, comes last.
    ///
    /// # Errors
    ///
    /// Returns [`ScryLearnError::InvalidParameter`] if the number of feature
    /// columns differs from the number of feature names. It also returns this
    /// error if polars refuses to build the frame; presumably that covers
    /// columns of unequal length or duplicate names (not verified here).
    pub fn to_dataframe(&self) -> Result<DataFrame> {
        // Guard against out-of-sync fields: indexing would panic when there are
        // fewer feature columns than names, and silently drop data when there
        // are more.
        if self.features.len() != self.feature_names.len() {
            return Err(ScryLearnError::InvalidParameter(format!(
                "dataset has {} feature columns but {} feature names",
                self.features.len(),
                self.feature_names.len()
            )));
        }

        let mut columns: Vec<Column> = Vec::with_capacity(self.feature_names.len() + 1);
        columns.extend(
            self.feature_names
                .iter()
                .zip(&self.features)
                .map(|(name, values)| Column::new(name.as_str().into(), values)),
        );
        columns.push(Column::new(self.target_name.as_str().into(), &self.target));

        DataFrame::new(columns).map_err(|e| {
            ScryLearnError::InvalidParameter(format!("failed to create DataFrame: {e}"))
        })
    }
}
/// Returns `col` as a `Float64` column when its dtype is a float, integer, or
/// boolean type; returns `None` for every other dtype or if the cast fails.
fn cast_column_to_f64(col: &Column) -> Option<Column> {
    let dtype = col.dtype();

    // Already `Float64`: hand back a copy without casting.
    if matches!(dtype, DataType::Float64) {
        return Some(col.clone());
    }

    let is_castable = matches!(
        dtype,
        DataType::Float32
            | DataType::Int64
            | DataType::Int32
            | DataType::Int16
            | DataType::Int8
            | DataType::UInt64
            | DataType::UInt32
            | DataType::UInt16
            | DataType::UInt8
            | DataType::Boolean
    );
    if !is_castable {
        return None;
    }

    col.cast(&DataType::Float64).ok()
}
/// Copies a `Float64` column into a plain `Vec<f64>`.
///
/// `col_name` is used only in error messages. Fails with
/// `ScryLearnError::InvalidParameter` if the column is not `Float64` or if it
/// contains a null; the error is raised at the first null encountered.
fn extract_f64_vec(col: &Column, col_name: &str) -> Result<Vec<f64>> {
    let chunked = match col.f64() {
        Ok(ca) => ca,
        Err(e) => {
            return Err(ScryLearnError::InvalidParameter(format!(
                "column '{col_name}': {e}"
            )))
        }
    };

    // Collecting into `Result` stops at the first `None`, mirroring an
    // early return from a manual loop.
    chunked
        .into_iter()
        .map(|maybe_value| {
            maybe_value.ok_or_else(|| {
                ScryLearnError::InvalidParameter(format!(
                    "column '{col_name}' contains null values — \
                     use SimpleImputer or DataFrame.fill_null() first"
                ))
            })
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_dataframe_basic() {
        let df = df!(
            "f1" => [1.0, 2.0, 3.0],
            "f2" => [4.0, 5.0, 6.0],
            "target" => [0.0, 1.0, 0.0]
        )
        .unwrap();
        let ds = Dataset::from_dataframe(&df, "target").unwrap();
        assert_eq!(ds.n_samples(), 3);
        assert_eq!(ds.n_features(), 2);
        assert_eq!(ds.target, vec![0.0, 1.0, 0.0]);
        assert_eq!(ds.feature(0), &[1.0, 2.0, 3.0]);
        assert_eq!(ds.feature(1), &[4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_to_dataframe_basic() {
        let ds = Dataset::new(
            vec![vec![1.0, 2.0], vec![3.0, 4.0]],
            vec![0.0, 1.0],
            vec!["a".into(), "b".into()],
            "t",
        );
        let df = ds.to_dataframe().unwrap();
        assert_eq!(df.shape(), (2, 3));
        assert_eq!(
            df.column("a").unwrap().f64().unwrap().to_vec(),
            vec![Some(1.0), Some(2.0)]
        );
        assert_eq!(
            df.column("t").unwrap().f64().unwrap().to_vec(),
            vec![Some(0.0), Some(1.0)]
        );
    }

    #[test]
    fn test_round_trip() {
        let ds = Dataset::new(
            vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]],
            vec![10.0, 20.0, 30.0],
            vec!["x".into(), "y".into()],
            "target",
        );
        let df = ds.to_dataframe().unwrap();
        let ds2 = Dataset::from_dataframe(&df, "target").unwrap();
        assert_eq!(ds2.n_samples(), ds.n_samples());
        assert_eq!(ds2.n_features(), ds.n_features());
        assert_eq!(ds2.target, ds.target);
        assert_eq!(ds2.feature_names, ds.feature_names);
        for i in 0..ds.n_features() {
            assert_eq!(ds2.feature(i), ds.feature(i));
        }
    }

    #[test]
    fn test_mixed_types() {
        let df = DataFrame::new(vec![
            Column::new("float_col".into(), &[1.0_f64, 2.0, 3.0]),
            Column::new("int_col".into(), &[10_i64, 20, 30]),
            Column::new("bool_col".into(), &[true, false, true]),
            Column::new("str_col".into(), &["a", "b", "c"]),
            Column::new("target".into(), &[0.0_f64, 1.0, 0.0]),
        ])
        .unwrap();
        let ds = Dataset::from_dataframe(&df, "target").unwrap();
        // The string column is skipped; the other three become features.
        assert_eq!(ds.n_features(), 3);
        assert_eq!(ds.feature_names, vec!["float_col", "int_col", "bool_col"]);
        assert_eq!(ds.feature(1), &[10.0, 20.0, 30.0]);
        assert_eq!(ds.feature(2), &[1.0, 0.0, 1.0]);
    }

    #[test]
    fn test_integer_target_is_cast() {
        let df = DataFrame::new(vec![
            Column::new("f1".into(), &[1.0_f64, 2.0, 3.0]),
            Column::new("target".into(), &[0_i32, 1, 0]),
        ])
        .unwrap();
        let ds = Dataset::from_dataframe(&df, "target").unwrap();
        assert_eq!(ds.target, vec![0.0, 1.0, 0.0]);
    }

    #[test]
    fn test_null_handling_errors() {
        let s1 = Column::new("f1".into(), &[1.0_f64, 2.0, 3.0]);
        let s2: Column = {
            let ca = Float64Chunked::new("f2".into(), &[Some(1.0), None, Some(3.0)]);
            ca.into_column()
        };
        let s3 = Column::new("target".into(), &[0.0_f64, 1.0, 0.0]);
        let df = DataFrame::new(vec![s1, s2, s3]).unwrap();
        let err = Dataset::from_dataframe(&df, "target");
        assert!(err.is_err());
        let msg = err.unwrap_err().to_string();
        assert!(msg.contains("null"), "expected null error, got: {msg}");
    }

    #[test]
    fn test_missing_target_errors() {
        let df = df!(
            "f1" => [1.0, 2.0],
            "f2" => [3.0, 4.0]
        )
        .unwrap();
        let err = Dataset::from_dataframe(&df, "nonexistent");
        assert!(matches!(err, Err(ScryLearnError::InvalidColumn(_))));
    }

    #[test]
    fn test_non_numeric_target_errors() {
        let df = DataFrame::new(vec![
            Column::new("f1".into(), &[1.0_f64, 2.0]),
            Column::new("target".into(), &["a", "b"]),
        ])
        .unwrap();
        let err = Dataset::from_dataframe(&df, "target");
        assert!(matches!(err, Err(ScryLearnError::InvalidColumn(_))));
    }

    #[test]
    fn test_no_numeric_features_errors() {
        let df = DataFrame::new(vec![
            Column::new("s".into(), &["a", "b"]),
            Column::new("target".into(), &[0.0_f64, 1.0]),
        ])
        .unwrap();
        let err = Dataset::from_dataframe(&df, "target");
        assert!(matches!(err, Err(ScryLearnError::InvalidParameter(_))));
    }

    #[test]
    fn test_single_feature() {
        let df = df!(
            "feat" => [1.0, 2.0, 3.0],
            "target" => [0.0, 1.0, 0.0]
        )
        .unwrap();
        let ds = Dataset::from_dataframe(&df, "target").unwrap();
        assert_eq!(ds.n_features(), 1);
        assert_eq!(ds.feature(0), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_empty_dataframe() {
        let df = df!(
            "f1" => Vec::<f64>::new(),
            "target" => Vec::<f64>::new()
        )
        .unwrap();
        let ds = Dataset::from_dataframe(&df, "target").unwrap();
        assert_eq!(ds.n_samples(), 0);
        assert_eq!(ds.n_features(), 1);
    }

    #[test]
    fn test_large_dataframe() {
        let n = 10_000;
        let mut rng = fastrand::Rng::with_seed(42);
        let f1: Vec<f64> = (0..n).map(|_| rng.f64()).collect();
        let f2: Vec<f64> = (0..n).map(|_| rng.f64()).collect();
        let target: Vec<f64> = (0..n).map(|_| (rng.u32(0..2)) as f64).collect();
        let df = df!(
            "f1" => f1.clone(),
            "f2" => f2.clone(),
            "target" => target.clone()
        )
        .unwrap();
        let ds = Dataset::from_dataframe(&df, "target").unwrap();
        assert_eq!(ds.n_samples(), n);
        assert_eq!(ds.n_features(), 2);
        assert_eq!(ds.target, target);
        assert_eq!(ds.feature(0), f1.as_slice());
        assert_eq!(ds.feature(1), f2.as_slice());
    }
}