use crate::core::config::BinStrategy;
use crate::preprocessing::quantile_transform::norm_ppf;
/// Settings for [`KBinsDiscretizer::transform`].
#[derive(Debug, Clone)]
pub struct KBinsDiscretizerConfig {
    /// Number of bins to produce. Not checked here; `transform` panics when it is below 2.
    pub n_bins: usize,
    /// Rule used to place the bin edges.
    pub strategy: BinStrategy,
}

impl KBinsDiscretizerConfig {
    /// Builds a configuration with `n_bins` bins using [`BinStrategy::Quantile`].
    ///
    /// `n_bins` is stored as given; validation is deferred to
    /// [`KBinsDiscretizer::transform`].
    pub fn new(n_bins: usize) -> Self {
        let strategy = BinStrategy::Quantile;
        Self { n_bins, strategy }
    }
}
/// Stateless discretizer that maps continuous values to integer bin indices.
///
/// All settings live in [`KBinsDiscretizerConfig`]; bin edges are recomputed on
/// every call to [`KBinsDiscretizer::transform`].
pub struct KBinsDiscretizer;

impl KBinsDiscretizer {
    /// Replaces every value in `x` with the index of the bin it falls into.
    ///
    /// Each inner `Vec` of `x` is discretized independently, and the output has
    /// the same shape as `x`. Bin edges depend on `config.strategy`:
    ///
    /// - [`BinStrategy::Normal`]: one set of edges, `norm_ppf(i / n_bins)` for
    ///   `i` in `1..n_bins`, shared by every row and independent of the data.
    /// - [`BinStrategy::Uniform`]: `n_bins` equal-width bins spanning each row's
    ///   minimum to maximum.
    /// - [`BinStrategy::Quantile`]: edges at each row's `i / n_bins` quantiles.
    ///   Nearly equal edges are merged, so fewer than `n_bins` bins may be used.
    ///
    /// Indices lie in `0..config.n_bins`. A value equal to an edge belongs to the
    /// bin above that edge. With the `parallel` feature, rows are processed with
    /// rayon, and the output row order always matches `x`.
    ///
    /// # Panics
    ///
    /// Panics if `x` is empty or if `config.n_bins < 2`.
    pub fn transform(config: &KBinsDiscretizerConfig, x: &[Vec<f64>]) -> Vec<Vec<usize>> {
        assert!(!x.is_empty(), "Input must have at least one sample");
        assert!(
            config.n_bins >= 2,
            "n_bins must be at least 2, got {}",
            config.n_bins
        );
        let n_bins = config.n_bins;
        match config.strategy {
            BinStrategy::Normal => {
                // These edges depend only on `n_bins`, so compute them once for all rows.
                let edges = normal_bin_edges(n_bins);
                map_samples(x, |sample| digitize_1d(sample, &edges))
            }
            BinStrategy::Uniform => map_samples(x, |sample| {
                digitize_1d(sample, &uniform_bin_edges(sample, n_bins))
            }),
            BinStrategy::Quantile => map_samples(x, |sample| {
                digitize_1d(sample, &quantile_bin_edges(sample, n_bins))
            }),
        }
    }
}

/// Applies `discretize` to every row of `x` and collects the results in row order.
///
/// This is the single place that chooses between rayon (`parallel` feature) and
/// a sequential iterator. The `Sync + Send` bounds are what rayon requires.
fn map_samples<F>(x: &[Vec<f64>], discretize: F) -> Vec<Vec<usize>>
where
    F: Fn(&[f64]) -> Vec<usize> + Sync + Send,
{
    #[cfg(feature = "parallel")]
    {
        use rayon::prelude::*;
        return x
            .par_iter()
            .map(|sample| discretize(sample.as_slice()))
            .collect();
    }
    #[cfg(not(feature = "parallel"))]
    x.iter().map(|sample| discretize(sample.as_slice())).collect()
}
/// Returns the `n_bins - 1` interior edges `norm_ppf(k / n_bins)` for `k` in `1..n_bins`.
///
/// The edges depend only on `n_bins`, never on the data. NOTE(review): this
/// presumes `norm_ppf` is the standard-normal inverse CDF, so the edges suit
/// roughly standardized inputs. Confirm against `quantile_transform`.
fn normal_bin_edges(n_bins: usize) -> Vec<f64> {
    let denom = n_bins as f64;
    let mut edges = Vec::with_capacity(n_bins.saturating_sub(1));
    for k in 1..n_bins {
        // Cumulative probability at the k-th interior boundary.
        let prob = k as f64 / denom;
        edges.push(norm_ppf(prob));
    }
    edges
}
/// Returns the `n_bins - 1` interior edges of `n_bins` equal-width bins spanning
/// `min(x)..=max(x)`.
///
/// NaN values are ignored by the min/max scan, because `f64::min`/`f64::max`
/// return the non-NaN operand. For an empty `x`, the `±infinity` seeds survive
/// the scan and every edge is NaN. Those edges are only ever used to digitize
/// that same empty slice.
fn uniform_bin_edges(x: &[f64], n_bins: usize) -> Vec<f64> {
    let (low, high) = x
        .iter()
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
            (lo.min(v), hi.max(v))
        });
    let span = high - low;
    let denom = n_bins as f64;
    let mut edges = Vec::with_capacity(n_bins.saturating_sub(1));
    for k in 1..n_bins {
        edges.push(low + span * k as f64 / denom);
    }
    edges
}
/// Returns interior bin edges at the `i / n_bins` quantiles of `x`, for `i` in `1..n_bins`.
///
/// Each quantile is linearly interpolated between neighbouring order statistics.
/// An edge within `1e-8` of the previously kept edge is dropped, so the result
/// may hold fewer than `n_bins - 1` edges (for example, for constant input).
///
/// NaN values have no place in the ordering and are excluded from the edge
/// computation. If nothing remains (empty or all-NaN `x`), no edges are returned
/// instead of underflowing `n - 1` or panicking in the sort.
fn quantile_bin_edges(x: &[f64], n_bins: usize) -> Vec<f64> {
    let mut sorted: Vec<f64> = x.iter().copied().filter(|v| !v.is_nan()).collect();
    if sorted.is_empty() {
        return Vec::new();
    }
    // NaNs are gone, so `total_cmp` agrees with numeric order
    // (it additionally places -0.0 before +0.0).
    sorted.sort_unstable_by(f64::total_cmp);
    let n = sorted.len();
    let last = (n - 1) as f64;
    let mut edges: Vec<f64> = (1..n_bins)
        .map(|i| {
            let p = i as f64 / n_bins as f64;
            let idx = p * last;
            let lo = idx.floor() as usize;
            let frac = idx - lo as f64;
            match sorted.get(lo + 1) {
                // Interpolate between the two order statistics around `idx`.
                Some(&next) => sorted[lo] + frac * (next - sorted[lo]),
                // `idx` sits on the final element; there is nothing above it.
                None => sorted[n - 1],
            }
        })
        .collect();
    // `dedup_by` compares each edge with the last one it kept.
    edges.dedup_by(|a, b| (*a - *b).abs() < 1e-8);
    edges
}
/// Maps each value in `x` to its bin index: the number of leading `edges` that
/// are `<=` the value.
///
/// Expects `edges` sorted ascending, so the binary search in `partition_point`
/// is meaningful. A value equal to an edge is placed in the bin above it. NaN
/// compares false against every edge and therefore maps to bin 0.
fn digitize_1d(x: &[f64], edges: &[f64]) -> Vec<usize> {
    let mut bins = Vec::with_capacity(x.len());
    for value in x {
        bins.push(edges.partition_point(|edge| *edge <= *value));
    }
    bins
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_uniform_basic() {
        let config = KBinsDiscretizerConfig {
            n_bins: 4,
            strategy: BinStrategy::Uniform,
        };
        let x = vec![vec![0.0, 2.5, 5.0, 7.5, 10.0]];
        let result = KBinsDiscretizer::transform(&config, &x);
        // Edges are 2.5, 5.0 and 7.5. A value on an edge goes to the bin above it,
        // and the row maximum lands in the last bin.
        assert_eq!(result[0], vec![0, 1, 2, 3, 3]);
    }

    #[test]
    fn test_normal_strategy() {
        let config = KBinsDiscretizerConfig {
            n_bins: 3,
            strategy: BinStrategy::Normal,
        };
        let x = vec![vec![-2.0, -0.5, 0.0, 0.5, 2.0]];
        let result = KBinsDiscretizer::transform(&config, &x);
        // With a standard-normal `norm_ppf`, the edges are about -0.43 and 0.43.
        assert_eq!(result[0], vec![0, 0, 1, 2, 2]);
    }

    #[test]
    fn test_quantile_strategy() {
        let config = KBinsDiscretizerConfig {
            n_bins: 2,
            strategy: BinStrategy::Quantile,
        };
        let x = vec![vec![1.0, 2.0, 3.0, 4.0]];
        let result = KBinsDiscretizer::transform(&config, &x);
        // The single edge is the interpolated median, 2.5.
        assert_eq!(result[0], vec![0, 0, 1, 1]);
    }

    #[test]
    fn test_multiple_samples() {
        let config = KBinsDiscretizerConfig::new(3);
        let x = vec![vec![1.0, 2.0, 3.0, 4.0], vec![10.0, 20.0, 30.0, 40.0]];
        let result = KBinsDiscretizer::transform(&config, &x);
        assert_eq!(result.len(), 2);
        // Each row gets its own quantile edges (2, 3 and 20, 30), so rows that
        // differ only by scale bin identically.
        assert_eq!(result[0], vec![0, 1, 2, 2]);
        assert_eq!(result[1], result[0]);
    }

    #[test]
    fn test_bins_within_range() {
        let x = vec![vec![-3.0, -1.0, 0.0, 0.25, 1.0, 7.0, 100.0]];
        for strategy in [
            BinStrategy::Normal,
            BinStrategy::Uniform,
            BinStrategy::Quantile,
        ] {
            let config = KBinsDiscretizerConfig { n_bins: 4, strategy };
            let result = KBinsDiscretizer::transform(&config, &x);
            assert_eq!(result[0].len(), x[0].len());
            assert!(result[0].iter().all(|&bin| bin < 4));
        }
    }

    #[test]
    #[should_panic(expected = "n_bins must be at least 2")]
    fn test_invalid_n_bins() {
        let config = KBinsDiscretizerConfig::new(1);
        KBinsDiscretizer::transform(&config, &[vec![1.0, 2.0]]);
    }

    #[test]
    fn test_normal_bin_edges() {
        let edges = normal_bin_edges(3);
        assert_eq!(edges.len(), 2);
        assert!(edges[0] < edges[1]);
        assert!((edges[0] + edges[1]).abs() < 1e-10);
    }
}