use anofox_ml_core::{FitUnsupervised, Float, Result, RustMlError, Transform};
use ndarray::Array2;
/// Strategy used to place the bin edges of each feature during fitting.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum BinStrategy {
/// Equal-width bins: edges are evenly spaced between the column minimum and maximum.
Uniform,
/// Edges are taken at the `i / n_bins` percentiles (`i = 0..=n_bins`) of the fitted
/// column, linearly interpolated between neighbouring sorted values.
Quantile,
}
/// Output layout produced when transforming data with a fitted discretizer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum EncodeStrategy {
/// Output has the same shape as the input; each value is replaced by its bin index
/// stored as a float.
Ordinal,
/// Each feature expands into `n_bins` indicator columns; the column of the matching
/// bin is set to one and the others stay zero.
Onehot,
}
/// Configuration of a discretizer that maps continuous features to a fixed number of bins.
///
/// Fitting computes per-feature bin edges; the settings stored here are not modified
/// by fitting.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct KBinsDiscretizer {
/// Number of bins per feature. Fitting rejects values below 2.
pub n_bins: usize,
/// How the bin edges are chosen.
pub strategy: BinStrategy,
/// How transformed values are encoded.
pub encode: EncodeStrategy,
}
impl KBinsDiscretizer {
pub fn new() -> Self {
Self {
n_bins: 5,
strategy: BinStrategy::Quantile,
encode: EncodeStrategy::Ordinal,
}
}
pub fn n_bins(mut self, n_bins: usize) -> Self {
self.n_bins = n_bins;
self
}
pub fn strategy(mut self, strategy: BinStrategy) -> Self {
self.strategy = strategy;
self
}
pub fn encode(mut self, encode: EncodeStrategy) -> Self {
self.encode = encode;
self
}
}
impl Default for KBinsDiscretizer {
fn default() -> Self {
Self::new()
}
}
/// Fitted state of a [`KBinsDiscretizer`]: the per-feature bin edges plus the settings
/// needed to transform new data.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
pub struct FittedKBinsDiscretizer<F: Float> {
/// One edge list per input feature; fitting stores `n_bins + 1` edges in each.
bin_edges: Vec<Vec<F>>,
/// Number of bins per feature, copied from the configuration at fit time.
n_bins: usize,
/// Output encoding, copied from the configuration at fit time.
encode: EncodeStrategy,
}
/// Returns the value at fraction `p` of an ascending-sorted slice, linearly
/// interpolating between the two neighbouring elements when `p` falls between them.
///
/// `p` is a fraction of the index range, so `0.0` yields the first element and `1.0`
/// the last. The caller must pass a non-empty slice and a `p` within `[0, 1]`; neither
/// is checked here.
fn percentile_sorted<F: Float>(sorted: &[F], p: f64) -> F {
    let len = sorted.len();
    if len == 1 {
        return sorted[0];
    }

    // Fractional position inside the index range [0, len - 1].
    let last_index = (len - 1) as f64;
    let position = p * last_index;
    let below = position.floor() as usize;
    let above = position.ceil().min(last_index) as usize;

    // Landing exactly on an element needs no interpolation.
    if below == above {
        return sorted[below];
    }

    let weight = F::from_f64(position - below as f64).unwrap();
    sorted[below] * (F::one() - weight) + sorted[above] * weight
}
impl<F: Float> FitUnsupervised<F> for KBinsDiscretizer {
    type Fitted = FittedKBinsDiscretizer<F>;

    /// Computes `n_bins + 1` bin edges for every column of `x`.
    ///
    /// With [`BinStrategy::Uniform`] the edges are evenly spaced between the column
    /// minimum and maximum. With [`BinStrategy::Quantile`] edge `i` is the `i / n_bins`
    /// percentile of the column.
    ///
    /// # Errors
    ///
    /// * `RustMlError::EmptyInput` when `x` has no elements.
    /// * `RustMlError::InvalidParameter` when `n_bins < 2`, or when `x` contains NaN
    ///   (NaN cannot be ordered, so no meaningful edges exist).
    ///
    /// Infinite values are not rejected here.
    fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
        if x.is_empty() {
            return Err(RustMlError::EmptyInput("input array is empty".into()));
        }
        if self.n_bins < 2 {
            return Err(RustMlError::InvalidParameter(
                "n_bins must be at least 2".into(),
            ));
        }

        let n_bins = self.n_bins;
        let ncols = x.ncols();
        let mut bin_edges = Vec::with_capacity(ncols);

        for j in 0..ncols {
            let mut col: Vec<F> = x.column(j).to_vec();

            // A value that does not compare with itself is NaN. Report it as an error
            // instead of panicking inside the sort comparator.
            if col.iter().any(|v| v.partial_cmp(v).is_none()) {
                return Err(RustMlError::InvalidParameter(
                    "input contains NaN values; cannot compute bin edges".into(),
                ));
            }
            // The fallback ordering is never used once NaN has been ruled out; it only
            // keeps the comparator panic-free.
            col.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

            let edges: Vec<F> = match self.strategy {
                BinStrategy::Uniform => {
                    // `col` is non-empty because `x` is non-empty, so both ends exist.
                    let min_val = col[0];
                    let max_val = col[col.len() - 1];
                    let step = (max_val - min_val) / F::from_usize(n_bins).unwrap();
                    (0..=n_bins)
                        .map(|i| min_val + step * F::from_usize(i).unwrap())
                        .collect()
                }
                BinStrategy::Quantile => (0..=n_bins)
                    .map(|i| percentile_sorted(&col, i as f64 / n_bins as f64))
                    .collect(),
            };
            bin_edges.push(edges);
        }

        Ok(FittedKBinsDiscretizer {
            bin_edges,
            n_bins,
            encode: self.encode,
        })
    }
}
/// Returns the index of the bin that `val` falls into, given ascending `edges`.
///
/// Bin `k` covers `edges[k] <= val < edges[k + 1]`. Values at or below the first edge
/// map to bin 0 and values at or above the last edge map to bin `n_bins - 1`, so
/// out-of-range inputs are clamped to the outermost bins. The result never exceeds
/// `n_bins - 1`.
///
/// `edges` must be non-empty; this is not checked.
fn find_bin<F: Float>(val: F, edges: &[F], n_bins: usize) -> usize {
    let last_edge = edges.len() - 1;

    // Clamp values lying outside (or exactly on) the outermost edges.
    if val <= edges[0] {
        return 0;
    }
    if val >= edges[last_edge] {
        return n_bins - 1;
    }

    // Binary search keeping `edges[left] <= val < edges[right]`.
    let (mut left, mut right) = (0usize, last_edge);
    while right - left > 1 {
        let probe = (left + right) / 2;
        if edges[probe] <= val {
            left = probe;
        } else {
            right = probe;
        }
    }
    left.min(n_bins - 1)
}
impl<F: Float> Transform<F> for FittedKBinsDiscretizer<F> {
    /// Replaces every value of `x` by its bin, encoded according to the stored
    /// [`EncodeStrategy`].
    ///
    /// * `Ordinal`: the output has the shape of `x` and holds bin indices as floats.
    /// * `Onehot`: the output has `n_features * n_bins` columns; for feature `j` and bin
    ///   `b`, column `j * n_bins + b` is set to one.
    ///
    /// # Errors
    ///
    /// Returns `RustMlError::ShapeMismatch` when the column count of `x` differs from
    /// the number of fitted features.
    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
        let n_features = self.bin_edges.len();
        if x.ncols() != n_features {
            return Err(RustMlError::ShapeMismatch(format!(
                "expected {} features, got {}",
                n_features,
                x.ncols()
            )));
        }

        let encoded = match self.encode {
            EncodeStrategy::Ordinal => {
                let mut codes = Array2::<F>::zeros(x.raw_dim());
                for ((row, feature), &value) in x.indexed_iter() {
                    let bin = find_bin(value, &self.bin_edges[feature], self.n_bins);
                    codes[[row, feature]] = F::from_usize(bin).unwrap();
                }
                codes
            }
            EncodeStrategy::Onehot => {
                let width = self.n_bins;
                let mut indicators = Array2::<F>::zeros((x.nrows(), n_features * width));
                for ((row, feature), &value) in x.indexed_iter() {
                    let bin = find_bin(value, &self.bin_edges[feature], width);
                    // Each feature owns a contiguous group of `width` columns.
                    indicators[[row, feature * width + bin]] = F::one();
                }
                indicators
            }
        };
        Ok(encoded)
    }
}
impl<F: Float> FittedKBinsDiscretizer<F> {
    /// Borrows the stored bin edges, one inner vector per fitted feature.
    pub fn bin_edges(&self) -> &Vec<Vec<F>> {
        let Self { bin_edges, .. } = self;
        bin_edges
    }

    /// Returns the number of bins per feature.
    pub fn n_bins(&self) -> usize {
        let Self { n_bins, .. } = *self;
        n_bins
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Uniform bins on evenly spaced data: a value sitting on an inner edge belongs to
    /// the bin that starts there, and the maximum lands in the last bin.
    #[test]
    fn test_uniform_ordinal() {
        let data = array![
            [0.0, 0.0],
            [2.5, 5.0],
            [5.0, 10.0],
            [7.5, 15.0],
            [10.0, 20.0],
        ];
        let discretizer = KBinsDiscretizer::new()
            .n_bins(4)
            .strategy(BinStrategy::Uniform)
            .encode(EncodeStrategy::Ordinal);
        let model = FitUnsupervised::<f64>::fit(&discretizer, &data).unwrap();
        let binned = model.transform(&data).unwrap();

        // (row, expected bin) pairs for the first feature.
        let expected_bins = [(0usize, 0.0f64), (1, 1.0), (2, 2.0), (4, 3.0)];
        for &(row, expected) in expected_bins.iter() {
            assert_abs_diff_eq!(binned[[row, 0]], expected, epsilon = 1e-10);
        }
    }

    /// Quantile bins stay within `0..=n_bins - 1` and never decrease on sorted input.
    #[test]
    fn test_quantile_ordinal() {
        let data = array![
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0],
            [10.0],
        ];
        let discretizer = KBinsDiscretizer::new()
            .n_bins(5)
            .strategy(BinStrategy::Quantile)
            .encode(EncodeStrategy::Ordinal);
        let model = FitUnsupervised::<f64>::fit(&discretizer, &data).unwrap();
        let binned = model.transform(&data).unwrap();

        for &v in binned.iter() {
            assert!((0.0..=4.0).contains(&v), "bin index out of range: {}", v);
        }
        for row in 1..data.nrows() {
            let (previous, current) = (binned[[row - 1, 0]], binned[[row, 0]]);
            assert!(
                current >= previous,
                "monotonicity violated at row {}",
                row
            );
        }
    }

    /// One-hot output of a single feature has `n_bins` columns and one active cell per row.
    #[test]
    fn test_onehot_encoding() {
        let data = array![[1.0], [3.0], [5.0], [7.0], [9.0]];
        let discretizer = KBinsDiscretizer::new()
            .n_bins(3)
            .strategy(BinStrategy::Uniform)
            .encode(EncodeStrategy::Onehot);
        let model = FitUnsupervised::<f64>::fit(&discretizer, &data).unwrap();
        let encoded = model.transform(&data).unwrap();

        assert_eq!(encoded.ncols(), 3);
        for row in 0..encoded.nrows() {
            let active: f64 = encoded.row(row).sum();
            assert_abs_diff_eq!(active, 1.0, epsilon = 1e-10);
        }
    }

    /// With two features the one-hot output has `2 * n_bins` columns and two active
    /// cells per row.
    #[test]
    fn test_onehot_multiple_features() {
        let data = array![[1.0, 10.0], [5.0, 50.0], [9.0, 90.0]];
        let discretizer = KBinsDiscretizer::new()
            .n_bins(3)
            .strategy(BinStrategy::Uniform)
            .encode(EncodeStrategy::Onehot);
        let model = FitUnsupervised::<f64>::fit(&discretizer, &data).unwrap();
        let encoded = model.transform(&data).unwrap();

        assert_eq!(encoded.ncols(), 6);
        for row in 0..encoded.nrows() {
            let active: f64 = encoded.row(row).sum();
            assert_abs_diff_eq!(active, 2.0, epsilon = 1e-10);
        }
    }

    /// Fitting an array without elements is an error.
    #[test]
    fn test_empty_input() {
        let empty: Array2<f64> = Array2::zeros((0, 0));
        let discretizer = KBinsDiscretizer::default();
        let outcome = FitUnsupervised::<f64>::fit(&discretizer, &empty);
        assert!(outcome.is_err());
    }

    /// A single bin is rejected at fit time.
    #[test]
    fn test_invalid_n_bins() {
        let data = array![[1.0], [2.0], [3.0]];
        let discretizer = KBinsDiscretizer::new().n_bins(1);
        let outcome = FitUnsupervised::<f64>::fit(&discretizer, &data);
        assert!(outcome.is_err());
    }

    /// Transforming data with a different number of columns than was fitted is an error.
    #[test]
    fn test_shape_mismatch() {
        let train = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let discretizer = KBinsDiscretizer::default();
        let model = FitUnsupervised::<f64>::fit(&discretizer, &train).unwrap();

        let three_columns = array![[1.0, 2.0, 3.0]];
        assert!(model.transform(&three_columns).is_err());
    }

    /// Values outside the fitted range are clamped to the first and last bin.
    #[test]
    fn test_out_of_range_values() {
        let train = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let discretizer = KBinsDiscretizer::new()
            .n_bins(3)
            .strategy(BinStrategy::Uniform)
            .encode(EncodeStrategy::Ordinal);
        let model = FitUnsupervised::<f64>::fit(&discretizer, &train).unwrap();

        let unseen = array![[-10.0], [0.0], [3.0], [6.0], [100.0]];
        let binned = model.transform(&unseen).unwrap();
        assert_abs_diff_eq!(binned[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(binned[[4, 0]], 2.0, epsilon = 1e-10);
    }

    /// A constant column (zero-width range) still yields finite bin indices.
    #[test]
    fn test_constant_feature() {
        let data = array![[5.0], [5.0], [5.0], [5.0]];
        let discretizer = KBinsDiscretizer::new()
            .n_bins(3)
            .strategy(BinStrategy::Uniform)
            .encode(EncodeStrategy::Ordinal);
        let model = FitUnsupervised::<f64>::fit(&discretizer, &data).unwrap();
        let binned = model.transform(&data).unwrap();

        for &v in binned.iter() {
            assert!(v.is_finite(), "constant feature produced non-finite: {}", v);
        }
    }

    /// The same pipeline works with single-precision input.
    #[test]
    fn test_f32() {
        let data = array![[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let discretizer = KBinsDiscretizer::new()
            .n_bins(3)
            .strategy(BinStrategy::Quantile)
            .encode(EncodeStrategy::Ordinal);
        let model = FitUnsupervised::<f32>::fit(&discretizer, &data).unwrap();
        let binned = model.transform(&data).unwrap();

        for &v in binned.iter() {
            assert!(v.is_finite());
            assert!((0.0..3.0).contains(&v));
        }
    }
}