use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
/// Hard upper limit on bins per feature: binned output is stored as `u8`,
/// so at most 256 distinct bin indices (0..=255) are representable, and
/// bin index 0 is reserved for NaN values.
pub const MAX_BINS: usize = 256;
/// Quantile-based feature binner: maps continuous feature values to `u8`
/// bin indices, reserving bin index 0 for NaN ("missing") values.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct FeatureBinner {
    // Per-feature sorted bin boundaries learned by `fit`. Empty for
    // features with fewer than two unique non-NaN values.
    bin_edges: Vec<Vec<f64>>,
    // Per-feature count of *value* bins (edges.len() + 1); the NaN bin 0
    // is not included in this count.
    n_bins_per_feature: Vec<usize>,
    // Total bin budget per feature (value bins + the NaN bin); set via the
    // `max_bins` builder, which clamps it to [2, MAX_BINS].
    max_bins: usize,
    // Set to true by a successful `fit`; `transform` errors until then.
    fitted: bool,
    // Schema tag for serialized models; serde(default) makes it 0 when
    // absent from older serialized input.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
impl FeatureBinner {
    /// Creates an unfitted binner with the default budget of [`MAX_BINS`]
    /// bins per feature.
    pub fn new() -> Self {
        Self {
            bin_edges: Vec::new(),
            n_bins_per_feature: Vec::new(),
            max_bins: MAX_BINS,
            fitted: false,
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }

    /// Builder-style setter for the total per-feature bin budget (value
    /// bins plus the reserved NaN bin 0). Clamped to `[2, MAX_BINS]`.
    pub fn max_bins(mut self, bins: usize) -> Self {
        self.max_bins = bins.clamp(2, MAX_BINS);
        self
    }

    /// Learns quantile-based bin edges for every feature in `data`.
    ///
    /// For each feature, edges are placed at evenly spaced quantiles of the
    /// sorted unique non-NaN values, using linear interpolation between
    /// neighboring order statistics. Features with fewer than two unique
    /// non-NaN values (constant or all-NaN columns) get no edges and a
    /// single value bin. Refitting replaces any previously learned state.
    ///
    /// # Errors
    /// Returns an error if `data` contains infinite values or has zero
    /// samples.
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_no_inf()?;
        if data.n_samples() == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        let n_features = data.n_features();
        // Bin 0 is reserved for NaN, so only max_bins - 1 value bins remain.
        let valid_bins = self.max_bins - 1;
        self.bin_edges = Vec::with_capacity(n_features);
        self.n_bins_per_feature = Vec::with_capacity(n_features);
        for f in 0..n_features {
            let col = &data.features[f];
            // Sort (total order, NaN already excluded) and keep unique values.
            let mut valid: Vec<f64> = col.iter().copied().filter(|v| !v.is_nan()).collect();
            valid.sort_unstable_by(|a, b| a.total_cmp(b));
            valid.dedup();
            let n_unique = valid.len();
            let actual_bins = n_unique.min(valid_bins);
            if actual_bins <= 1 {
                // Constant or all-NaN feature: no edges, one value bin.
                self.bin_edges.push(Vec::new());
                self.n_bins_per_feature.push(1);
                continue;
            }
            let mut edges = Vec::with_capacity(actual_bins - 1);
            for i in 1..actual_bins {
                // Linearly interpolated quantile over the unique sorted values.
                let q = i as f64 / actual_bins as f64;
                let pos = q * (n_unique - 1) as f64;
                let lo = pos.floor() as usize;
                let hi = (lo + 1).min(n_unique - 1);
                let frac = pos - lo as f64;
                edges.push(valid[lo] * (1.0 - frac) + valid[hi] * frac);
            }
            // Interpolation can yield (near-)identical neighbors; merge them
            // so every bin boundary is strictly increasing.
            edges.dedup_by(|a, b| (*a - *b).abs() < f64::EPSILON);
            self.n_bins_per_feature.push(edges.len() + 1);
            self.bin_edges.push(edges);
        }
        self.fitted = true;
        Ok(())
    }

    /// Maps every feature value to a `u8` bin index using the fitted edges.
    ///
    /// NaN maps to bin 0; valid values map to `1..=n_bins_per_feature[f]`
    /// (a value equal to an edge falls in the bin above it). Returns one
    /// `Vec<u8>` per feature, each with one entry per sample.
    ///
    /// # Errors
    /// Returns [`ScryLearnError::NotFitted`] if `fit` has not succeeded, and
    /// [`ScryLearnError::ShapeMismatch`] if `data` has a different feature
    /// count than the fitted dataset.
    pub fn transform(&self, data: &Dataset) -> Result<Vec<Vec<u8>>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let n_features = data.n_features();
        if n_features != self.bin_edges.len() {
            return Err(ScryLearnError::ShapeMismatch {
                expected: self.bin_edges.len(),
                got: n_features,
            });
        }
        let mut result = Vec::with_capacity(n_features);
        for f in 0..n_features {
            let edges = &self.bin_edges[f];
            let binned: Vec<u8> = data.features[f]
                .iter()
                .map(|&val| {
                    if val.is_nan() {
                        // Bin 0 is reserved for missing values.
                        0
                    } else {
                        // Count of edges <= val; partition_point is valid
                        // because fit produces sorted, NaN-free edges.
                        let bin = edges.partition_point(|&e| e <= val);
                        // fit caps edges.len() at max_bins - 2 (<= 254), so
                        // bin + 1 fits in u8; the clamp guards edge vectors
                        // obtained by other means (e.g. deserialization).
                        (bin + 1).min(255) as u8
                    }
                })
                .collect();
            result.push(binned);
        }
        Ok(result)
    }

    /// Convenience: [`fit`](Self::fit) then [`transform`](Self::transform)
    /// on the same dataset.
    ///
    /// # Errors
    /// Propagates any error from `fit` or `transform`.
    pub fn fit_transform(&mut self, data: &Dataset) -> Result<Vec<Vec<u8>>> {
        self.fit(data)?;
        self.transform(data)
    }

    /// Number of *value* bins per feature (the NaN bin 0 is not counted;
    /// size per-feature histograms with `n_bins_per_feature()[f] + 1`).
    pub fn n_bins_per_feature(&self) -> &[usize] {
        &self.n_bins_per_feature
    }

    /// The learned bin boundaries, one sorted `Vec<f64>` per feature.
    pub fn bin_edges(&self) -> &[Vec<f64>] {
        &self.bin_edges
    }
}
impl Default for FeatureBinner {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two strictly increasing features over ten samples with dummy targets.
    fn simple_dataset() -> Dataset {
        let col_a: Vec<f64> = (1..=10).map(f64::from).collect();
        let col_b: Vec<f64> = (1..=10).map(|i| f64::from(i) * 100.0).collect();
        Dataset::new(
            vec![col_a, col_b],
            vec![0.0; 10],
            vec!["a".into(), "b".into()],
            "y",
        )
    }

    #[test]
    fn test_fit_transform_basic() {
        let ds = simple_dataset();
        let mut binner = FeatureBinner::new();
        let binned = binner.fit_transform(&ds).unwrap();
        assert_eq!(binned.len(), 2);
        assert_eq!(binned[0].len(), 10);
        assert!(
            binned[0].iter().all(|&b| b >= 1),
            "valid values should map to bins >= 1"
        );
        // Increasing inputs must yield non-decreasing bin indices.
        assert!(binned[0].windows(2).all(|w| w[0] <= w[1]));
    }

    #[test]
    fn test_nan_handling() {
        let ds = Dataset::new(
            vec![vec![1.0, f64::NAN, 3.0, f64::NAN, 5.0]],
            vec![0.0; 5],
            vec!["x".into()],
            "y",
        );
        let mut binner = FeatureBinner::new();
        let binned = binner.fit_transform(&ds).unwrap();
        for &nan_idx in &[1usize, 3] {
            assert_eq!(binned[0][nan_idx], 0, "NaN should map to bin 0");
        }
        assert!(binned[0][0] >= 1, "valid value should be >= 1");
    }

    #[test]
    fn test_constant_feature() {
        let ds = Dataset::new(
            vec![vec![5.0; 4]],
            vec![0.0; 4],
            vec!["x".into()],
            "y",
        );
        let mut binner = FeatureBinner::new();
        let binned = binner.fit_transform(&ds).unwrap();
        // Every sample of a constant feature lands in the same bin.
        assert!(binned[0].windows(2).all(|w| w[0] == w[1]));
    }

    #[test]
    fn test_max_bins_param() {
        let ds = simple_dataset();
        let mut binner = FeatureBinner::new().max_bins(4);
        let binned = binner.fit_transform(&ds).unwrap();
        for &b in binned[0].iter() {
            assert!(b <= 3, "with max_bins=4, bin index should be <= 3, got {b}");
        }
    }

    #[test]
    fn test_not_fitted_error() {
        let binner = FeatureBinner::new();
        assert!(binner.transform(&simple_dataset()).is_err());
    }

    #[test]
    fn test_all_nan_feature() {
        let ds = Dataset::new(
            vec![vec![f64::NAN; 3]],
            vec![0.0; 3],
            vec!["x".into()],
            "y",
        );
        let mut binner = FeatureBinner::new();
        let binned = binner.fit_transform(&ds).unwrap();
        assert!(
            binned[0].iter().all(|&b| b == 0),
            "all-NaN feature should map entirely to bin 0"
        );
    }
}