use crate::dataset::binner::QuantileBinner;
use crate::dataset::{BinnedDataset, DatasetBinner, FeatureInfo, FeatureType};
use crate::{Result, TreeBoostError};
use polars::prelude::*;
use std::path::Path;
/// Loads tabular data (Parquet, CSV, or an in-memory `DataFrame`) and converts
/// it into a [`BinnedDataset`] by quantile-binning every feature column.
pub struct DatasetLoader {
    // Shared binner that turns raw feature values into u8 bin indices.
    binner: DatasetBinner,
}
impl DatasetLoader {
    /// Creates a loader whose binner discretizes each feature into at most
    /// `num_bins` bins.
    pub fn new(num_bins: usize) -> Self {
        Self {
            binner: DatasetBinner::new(num_bins),
        }
    }

    /// Loads a Parquet file and bins it for training.
    ///
    /// When `feature_columns` is `None`, every column except `target_column`
    /// is treated as a feature.
    ///
    /// # Errors
    /// Returns an error if the file cannot be read or a named column is
    /// missing or has an unsupported dtype.
    pub fn load_parquet(
        &self,
        path: impl AsRef<Path>,
        target_column: &str,
        feature_columns: Option<&[&str]>,
    ) -> Result<BinnedDataset> {
        let df = Self::read_parquet(path.as_ref())?;
        self.from_dataframe(df, target_column, feature_columns)
    }

    /// Loads a CSV file (default Polars CSV options) and bins it for training.
    ///
    /// # Errors
    /// Same conditions as [`Self::load_parquet`].
    pub fn load_csv(
        &self,
        path: impl AsRef<Path>,
        target_column: &str,
        feature_columns: Option<&[&str]>,
    ) -> Result<BinnedDataset> {
        let df = Self::read_csv(path.as_ref())?;
        self.from_dataframe(df, target_column, feature_columns)
    }

    /// Bins an in-memory `DataFrame` for training.
    ///
    /// The target column is cast to f32 (nulls become NaN); each feature
    /// column is binned by the loader's quantile binner.
    ///
    /// # Errors
    /// Returns an error if the target or any feature column is missing, or a
    /// feature column has an unsupported dtype.
    pub fn from_dataframe(
        &self,
        df: DataFrame,
        target_column: &str,
        feature_columns: Option<&[&str]>,
    ) -> Result<BinnedDataset> {
        let num_rows = df.height();
        let target_col = df.column(target_column).map_err(|e| {
            TreeBoostError::Data(format!(
                "Target column '{}' not found: {}",
                target_column, e
            ))
        })?;
        let target_series = target_col.as_materialized_series();
        let targets = self.series_to_f32(target_series)?;

        // Explicit feature list wins; otherwise use every non-target column.
        let feature_names: Vec<String> = match feature_columns {
            Some(cols) => cols.iter().map(|s| s.to_string()).collect(),
            None => df
                .get_column_names()
                .into_iter()
                .filter(|name| *name != target_column)
                .map(|s| s.to_string())
                .collect(),
        };

        let mut all_binned: Vec<Vec<u8>> = Vec::with_capacity(feature_names.len());
        let mut all_info: Vec<FeatureInfo> = Vec::with_capacity(feature_names.len());
        for name in &feature_names {
            let col = df.column(name.as_str()).map_err(|e| {
                TreeBoostError::Data(format!("Feature column '{}' not found: {}", name, e))
            })?;
            let series = col.as_materialized_series();
            let (binned, info) = self.process_series(name.clone(), series)?;
            all_binned.push(binned);
            all_info.push(info);
        }

        let features = Self::flatten_columns(all_binned, num_rows);
        Ok(BinnedDataset::new(num_rows, features, targets, all_info))
    }

    /// Dispatches a single column to the appropriate conversion: numeric
    /// dtypes are cast to f64; string/categorical columns are integer-coded
    /// first. Both paths feed the quantile binner.
    fn process_series(&self, name: String, series: &Series) -> Result<(Vec<u8>, FeatureInfo)> {
        match series.dtype() {
            DataType::Float64
            | DataType::Float32
            | DataType::Int64
            | DataType::Int32
            | DataType::Int16
            | DataType::Int8
            | DataType::UInt64
            | DataType::UInt32
            | DataType::UInt16
            | DataType::UInt8 => {
                let values = self.series_to_f64(series)?;
                self.binner.process_numeric_column(name, &values)
            }
            DataType::String | DataType::Categorical(_, _) => {
                let values = self.categorical_to_numeric(series)?;
                self.binner.process_numeric_column(name, &values)
            }
            dtype => Err(TreeBoostError::Data(format!(
                "Unsupported dtype for column '{}': {:?}",
                name, dtype
            ))),
        }
    }

    /// Casts any numeric series to f64; null entries become NaN.
    fn series_to_f64(&self, series: &Series) -> Result<Vec<f64>> {
        let cast = series
            .cast(&DataType::Float64)
            .map_err(|e| TreeBoostError::Data(format!("Failed to cast to f64: {}", e)))?;
        let ca = cast
            .f64()
            .map_err(|e| TreeBoostError::Data(format!("Failed to get f64 chunked: {}", e)))?;
        // No element can fail past this point, so collect infallibly.
        Ok(ca.into_iter().map(|opt| opt.unwrap_or(f64::NAN)).collect())
    }

    /// Casts any numeric series to f32; null entries become NaN.
    fn series_to_f32(&self, series: &Series) -> Result<Vec<f32>> {
        let cast = series
            .cast(&DataType::Float32)
            .map_err(|e| TreeBoostError::Data(format!("Failed to cast to f32: {}", e)))?;
        let ca = cast
            .f32()
            .map_err(|e| TreeBoostError::Data(format!("Failed to get f32 chunked: {}", e)))?;
        Ok(ca.into_iter().map(|opt| opt.unwrap_or(f32::NAN)).collect())
    }

    /// Maps string/categorical values to dense f64 codes in first-appearance
    /// order; nulls become NaN.
    ///
    /// NOTE(review): the mapping is rebuilt on every call and depends on row
    /// order, so the same category can receive a different code at prediction
    /// time than at training time. Persisting the mapping (e.g. in
    /// `FeatureInfo`) would make the encoding stable — TODO confirm whether
    /// this is intended.
    fn categorical_to_numeric(&self, series: &Series) -> Result<Vec<f64>> {
        use std::collections::HashMap;
        let str_series = series
            .cast(&DataType::String)
            .map_err(|e| TreeBoostError::Data(format!("Failed to cast to string: {}", e)))?;
        let str_ca = str_series
            .str()
            .map_err(|e| TreeBoostError::Data(format!("Failed to get string chunked: {}", e)))?;
        let mut mapping: HashMap<String, u32> = HashMap::new();
        let mut next_idx = 0u32;
        let values: Vec<f64> = str_ca
            .into_iter()
            .map(|opt| match opt {
                Some(s) => {
                    // Assign the next free code on first sight of a category.
                    let idx = *mapping.entry(s.to_string()).or_insert_with(|| {
                        let idx = next_idx;
                        next_idx += 1;
                        idx
                    });
                    idx as f64
                }
                None => f64::NAN,
            })
            .collect();
        Ok(values)
    }

    /// Loads a Parquet file and bins it with pre-computed boundaries (from a
    /// trained model) rather than fitting new bins.
    ///
    /// # Errors
    /// Returns an error if the file cannot be read or a column named in
    /// `feature_info` is missing.
    pub fn load_parquet_for_prediction(
        &self,
        path: impl AsRef<Path>,
        feature_info: &[FeatureInfo],
    ) -> Result<BinnedDataset> {
        let df = Self::read_parquet(path.as_ref())?;
        self.from_dataframe_for_prediction(df, feature_info)
    }

    /// Loads a CSV file and bins it with pre-computed boundaries.
    ///
    /// # Errors
    /// Same conditions as [`Self::load_parquet_for_prediction`].
    pub fn load_csv_for_prediction(
        &self,
        path: impl AsRef<Path>,
        feature_info: &[FeatureInfo],
    ) -> Result<BinnedDataset> {
        let df = Self::read_csv(path.as_ref())?;
        self.from_dataframe_for_prediction(df, feature_info)
    }

    /// Bins an in-memory `DataFrame` using existing bin boundaries. The
    /// returned dataset carries zeroed targets, since none exist at
    /// prediction time.
    ///
    /// # Errors
    /// Returns an error if a column named in `feature_info` is missing.
    pub fn from_dataframe_for_prediction(
        &self,
        df: DataFrame,
        feature_info: &[FeatureInfo],
    ) -> Result<BinnedDataset> {
        let num_rows = df.height();
        let mut all_binned: Vec<Vec<u8>> = Vec::with_capacity(feature_info.len());
        let mut all_info: Vec<FeatureInfo> = Vec::with_capacity(feature_info.len());
        for info in feature_info {
            let col = df.column(&info.name).map_err(|e| {
                TreeBoostError::Data(format!("Feature column '{}' not found: {}", info.name, e))
            })?;
            let series = col.as_materialized_series();
            let binned = self.bin_with_boundaries(series, info)?;
            all_binned.push(binned);
            all_info.push(info.clone());
        }
        let features = Self::flatten_columns(all_binned, num_rows);
        // Placeholder targets: prediction inputs have no labels.
        let targets = vec![0.0f32; num_rows];
        Ok(BinnedDataset::new(num_rows, features, targets, all_info))
    }

    /// Bins one column against pre-computed boundaries. Only the value
    /// extraction differs by feature type; the binning step is shared.
    fn bin_with_boundaries(&self, series: &Series, info: &FeatureInfo) -> Result<Vec<u8>> {
        let values = match info.feature_type {
            FeatureType::Numeric => self.series_to_f64(series)?,
            FeatureType::Categorical => self.categorical_to_numeric(series)?,
        };
        Ok(values
            .iter()
            .map(|&v| QuantileBinner::bin_value(v, &info.bin_boundaries))
            .collect())
    }

    /// Reads a Parquet file into an eager `DataFrame`.
    fn read_parquet(path: &Path) -> Result<DataFrame> {
        let pl_path = PlPath::new(&path.to_string_lossy());
        Ok(LazyFrame::scan_parquet(pl_path, Default::default())?.collect()?)
    }

    /// Reads a CSV file (default options) into an eager `DataFrame`.
    fn read_csv(path: &Path) -> Result<DataFrame> {
        Ok(CsvReadOptions::default()
            .try_into_reader_with_file_path(Some(path.to_path_buf()))?
            .finish()?)
    }

    /// Concatenates per-feature bin columns into the single column-major
    /// buffer layout that `BinnedDataset::new` receives.
    fn flatten_columns(columns: Vec<Vec<u8>>, num_rows: usize) -> Vec<u8> {
        let mut features = Vec::with_capacity(num_rows * columns.len());
        for column in columns {
            features.extend(column);
        }
        features
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check on a small in-memory frame: row count, feature count,
    /// and target length survive the binning round trip.
    #[test]
    fn test_loader_from_dataframe() {
        let frame = df! {
            "feature1" => &[1.0, 2.0, 3.0, 4.0, 5.0],
            "feature2" => &[10.0, 20.0, 30.0, 40.0, 50.0],
            "target" => &[100.0, 200.0, 300.0, 400.0, 500.0]
        }
        .unwrap();

        let binned = DatasetLoader::new(4)
            .from_dataframe(frame, "target", None)
            .unwrap();

        assert_eq!(binned.num_rows(), 5);
        assert_eq!(binned.num_features(), 2);
        assert_eq!(binned.targets().len(), 5);
    }
}