oxits 0.1.0

Time series classification and transformation library for Rust
Documentation
use crate::core::config::BinStrategy;
use crate::preprocessing::quantile_transform::norm_ppf;

/// Settings for [`KBinsDiscretizer::transform`].
#[derive(Debug, Clone)]
pub struct KBinsDiscretizerConfig {
    /// Number of bins to create; `transform` rejects values below 2.
    pub n_bins: usize,
    /// Rule used to place the bin edges.
    pub strategy: BinStrategy,
}

impl KBinsDiscretizerConfig {
    /// Creates a configuration with `n_bins` bins using the quantile strategy.
    pub fn new(n_bins: usize) -> Self {
        let strategy = BinStrategy::Quantile;
        Self { n_bins, strategy }
    }
}

/// Discretize time series into bin indices (0-based).
///
/// Stateless: every call to [`KBinsDiscretizer::transform`] derives the bin edges
/// from the configuration and, for the data-dependent strategies, from each
/// sample individually.
#[derive(Debug, Clone, Copy, Default)]
pub struct KBinsDiscretizer;

impl KBinsDiscretizer {
    /// Maps every value of every sample to the index of the bin it falls into.
    ///
    /// Bin edges depend on `config.strategy`:
    /// - `Normal`: `n_bins - 1` standard-normal quantiles, shared by all samples.
    /// - `Uniform`: `n_bins - 1` equally spaced edges between each sample's own
    ///   minimum and maximum.
    /// - `Quantile`: up to `n_bins - 1` interpolated empirical quantiles of each
    ///   sample; nearly identical edges are merged, so a sample may use fewer
    ///   than `n_bins` distinct bins.
    ///
    /// A value equal to an edge is assigned to the bin above that edge. The
    /// output has the same shape as `x`, and every index lies in `0..n_bins`.
    /// With the `parallel` feature enabled, samples are processed with rayon.
    ///
    /// # Panics
    ///
    /// Panics if `x` is empty or if `config.n_bins < 2`.
    pub fn transform(config: &KBinsDiscretizerConfig, x: &[Vec<f64>]) -> Vec<Vec<usize>> {
        assert!(!x.is_empty(), "Input must have at least one sample");
        assert!(
            config.n_bins >= 2,
            "n_bins must be at least 2, got {}",
            config.n_bins
        );

        let n_bins = config.n_bins;
        match config.strategy {
            BinStrategy::Normal => {
                // These edges do not depend on the data: compute once, share across samples.
                let edges = normal_bin_edges(n_bins);
                Self::map_samples(x, |sample| digitize_1d(sample, &edges))
            }
            // Data-dependent strategies: edges are recomputed for every sample.
            BinStrategy::Uniform => Self::map_samples(x, |sample| {
                digitize_1d(sample, &uniform_bin_edges(sample, n_bins))
            }),
            BinStrategy::Quantile => Self::map_samples(x, |sample| {
                digitize_1d(sample, &quantile_bin_edges(sample, n_bins))
            }),
        }
    }

    /// Applies `per_sample` to every sample in parallel (rayon), preserving order.
    #[cfg(feature = "parallel")]
    fn map_samples<F>(x: &[Vec<f64>], per_sample: F) -> Vec<Vec<usize>>
    where
        F: Fn(&[f64]) -> Vec<usize> + Sync,
    {
        use rayon::prelude::*;
        x.par_iter()
            .map(|sample| per_sample(sample.as_slice()))
            .collect()
    }

    /// Applies `per_sample` to every sample sequentially, preserving order.
    // The `Sync` bound mirrors the parallel variant so both builds accept the same closures.
    #[cfg(not(feature = "parallel"))]
    fn map_samples<F>(x: &[Vec<f64>], per_sample: F) -> Vec<Vec<usize>>
    where
        F: Fn(&[f64]) -> Vec<usize> + Sync,
    {
        x.iter()
            .map(|sample| per_sample(sample.as_slice()))
            .collect()
    }
}

/// Bin edges for the normal strategy.
///
/// Evaluates `norm_ppf` (the standard-normal quantile function) at the
/// probabilities `k / n_bins` for `k = 1, ..., n_bins - 1`. The result does not
/// depend on any data, and is empty when `n_bins < 2`.
fn normal_bin_edges(n_bins: usize) -> Vec<f64> {
    let denom = n_bins as f64;
    let mut edges = Vec::with_capacity(n_bins.saturating_sub(1));
    for k in 1..n_bins {
        edges.push(norm_ppf(k as f64 / denom));
    }
    edges
}

/// Bin edges for the uniform strategy, computed from a single sample.
///
/// Splits `[min(x), max(x)]` into `n_bins` equal-width intervals and returns the
/// `n_bins - 1` interior boundaries. NaN values are ignored when locating the
/// extremes. A constant sample collapses every edge onto that constant, and an
/// empty (or all-NaN) sample yields NaN edges.
fn uniform_bin_edges(x: &[f64], n_bins: usize) -> Vec<f64> {
    let (lo, hi) = x
        .iter()
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
            (lo.min(v), hi.max(v))
        });
    let span = hi - lo;
    let bins = n_bins as f64;

    let mut edges = Vec::with_capacity(n_bins.saturating_sub(1));
    for k in 1..n_bins {
        // Multiply before dividing so the rounding matches `lo + span * k / bins`.
        edges.push(lo + span * k as f64 / bins);
    }
    edges
}

/// Compute bin edges for the quantile strategy from a single sample.
///
/// Returns up to `n_bins - 1` ascending edges. These are the empirical quantiles
/// at probabilities `p = i / n_bins` for `i in 1..n_bins`. Each one is obtained
/// by linear interpolation between the order statistics surrounding position
/// `p * (len - 1)` of the sorted sample. The sample is sorted once and all
/// quantiles are read from that sorted copy.
///
/// Edges closer than `1e-8` are merged, so a sample with repeated values can
/// yield fewer edges (and therefore fewer bins) than requested.
///
/// NaN values have no rank and are skipped. A sample that is empty or all-NaN
/// produces no edges instead of panicking.
fn quantile_bin_edges(x: &[f64], n_bins: usize) -> Vec<f64> {
    // Adjacent edges closer than this are treated as duplicates.
    const EDGE_TOLERANCE: f64 = 1e-8;

    // Filtering NaN first keeps the comparator total. The previous
    // `partial_cmp(..).unwrap()` panicked on any NaN in the sample.
    let mut sorted: Vec<f64> = x.iter().copied().filter(|v| !v.is_nan()).collect();
    if sorted.is_empty() {
        // Nothing to rank. This also avoids the `n - 1` underflow below.
        return Vec::new();
    }
    sorted.sort_unstable_by(f64::total_cmp);

    let n = sorted.len();
    let last_pos = (n - 1) as f64;

    let mut edges: Vec<f64> = (1..n_bins)
        .map(|i| {
            let p = i as f64 / n_bins as f64;
            let pos = p * last_pos;
            let lo = pos.floor() as usize;
            let frac = pos - lo as f64;
            match sorted.get(lo + 1) {
                Some(&upper) => sorted[lo] + frac * (upper - sorted[lo]),
                // `pos` sits on the last order statistic: nothing to interpolate with.
                None => sorted[n - 1],
            }
        })
        .collect();

    edges.dedup_by(|a, b| (*a - *b).abs() < EDGE_TOLERANCE);
    edges
}

/// Digitize one sample against ascending `edges`.
///
/// Each value is mapped to the number of edges that are `<=` it. This is the
/// insertion point of a *right*-sided binary search, equivalent to NumPy's
/// `searchsorted(edges, v, side="right")`. Consequently:
/// - a value equal to an edge lands in the bin above that edge;
/// - NaN compares false against every edge and lands in bin 0;
/// - returned indices lie in `0..=edges.len()`.
///
/// `edges` must be sorted ascending for the binary search to be meaningful.
fn digitize_1d(x: &[f64], edges: &[f64]) -> Vec<usize> {
    x.iter()
        .map(|&v| edges.partition_point(|&e| e <= v))
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_uniform_basic() {
        let config = KBinsDiscretizerConfig {
            n_bins: 4,
            strategy: BinStrategy::Uniform,
        };
        let x = vec![vec![0.0, 2.5, 5.0, 7.5, 10.0]];
        let result = KBinsDiscretizer::transform(&config, &x);
        // Edges at 2.5, 5.0, 7.5. A value equal to an edge goes to the bin above it,
        // and 10.0 is past all edges -> bin 3.
        assert_eq!(result[0], vec![0, 1, 2, 3, 3]);
    }

    #[test]
    fn test_normal_strategy() {
        let config = KBinsDiscretizerConfig {
            n_bins: 3,
            strategy: BinStrategy::Normal,
        };
        let x = vec![vec![-2.0, -0.5, 0.0, 0.5, 2.0]];
        let result = KBinsDiscretizer::transform(&config, &x);
        // Normal bins: edges at norm_ppf(1/3) ≈ -0.43 and norm_ppf(2/3) ≈ 0.43
        assert_eq!(result[0], vec![0, 0, 1, 2, 2]);
    }

    #[test]
    fn test_quantile_strategy() {
        let config = KBinsDiscretizerConfig {
            n_bins: 2,
            strategy: BinStrategy::Quantile,
        };
        let x = vec![vec![1.0, 2.0, 3.0, 4.0]];
        let result = KBinsDiscretizer::transform(&config, &x);
        // 1 edge at the 50th percentile = 2.5:
        // values < 2.5 -> bin 0, values >= 2.5 -> bin 1.
        assert_eq!(result[0], vec![0, 0, 1, 1]);
    }

    #[test]
    fn test_multiple_samples() {
        let config = KBinsDiscretizerConfig::new(3);
        let x = vec![vec![1.0, 2.0, 3.0, 4.0], vec![10.0, 20.0, 30.0, 40.0]];
        let result = KBinsDiscretizer::transform(&config, &x);
        assert_eq!(result.len(), 2);
        // Edges are computed per sample: [2, 3] and [20, 30] respectively.
        assert_eq!(result[0], vec![0, 1, 2, 2]);
        assert_eq!(result[1], vec![0, 1, 2, 2]);
    }

    #[test]
    #[should_panic(expected = "n_bins must be at least 2")]
    fn test_invalid_n_bins() {
        let config = KBinsDiscretizerConfig::new(1);
        KBinsDiscretizer::transform(&config, &[vec![1.0, 2.0]]);
    }

    #[test]
    fn test_normal_bin_edges() {
        let edges = normal_bin_edges(3);
        assert_eq!(edges.len(), 2);
        // Should be symmetric around 0
        assert!((edges[0] + edges[1]).abs() < 1e-10);
    }

    #[test]
    fn test_quantile_bin_edges_interpolate() {
        assert_eq!(
            quantile_bin_edges(&[4.0, 1.0, 3.0, 2.0], 4),
            vec![1.75, 2.5, 3.25]
        );
    }

    #[test]
    fn test_quantile_bin_edges_dedup() {
        // A constant sample produces identical edges, which collapse into one.
        assert_eq!(quantile_bin_edges(&[5.0, 5.0, 5.0, 5.0], 4), vec![5.0]);
    }

    #[test]
    fn test_uniform_bin_edges() {
        assert_eq!(uniform_bin_edges(&[0.0, 10.0, 5.0], 4), vec![2.5, 5.0, 7.5]);
    }
}