// oxits 0.1.0
//
// Time series classification and transformation library for Rust.
use crate::core::config::BinStrategy;
use crate::core::traits::FittableTransformer;
use crate::preprocessing::discretizer::{KBinsDiscretizer, KBinsDiscretizerConfig};

/// Configuration for the MCB transformer ([`Mcb`]).
#[derive(Debug, Clone)]
pub struct McbConfig {
    /// Number of bins (symbols) per coefficient; `Mcb::fit` panics unless it is in `[2, 26]`.
    pub n_bins: usize,
    /// How bin edges are placed:
    /// * `Normal` — one data-independent edge set shared by every coefficient,
    ///   at `norm_ppf(i / n_bins)` for `i` in `1..n_bins`;
    /// * `Uniform` — per coefficient, equal-width edges between the column min and max;
    /// * `Quantile` — per coefficient, edges at the `i * 100 / n_bins` percentiles
    ///   (the strategy chosen by `McbConfig::new`).
    pub strategy: BinStrategy,
}

impl McbConfig {
    pub fn new(n_bins: usize) -> Self {
        Self {
            n_bins,
            strategy: BinStrategy::Quantile,
        }
    }
}

/// Fitted state produced by `Mcb::fit`: the learned bin edges plus the settings
/// that produced them.
#[derive(Debug, Clone)]
pub struct McbFitted {
    /// For normal strategy: single set of edges (shared across all coefficients)
    /// For uniform/quantile: per-coefficient edges (`bin_edges[j]` for coefficient `j`)
    ///
    /// Each inner vector holds at most `n_bins - 1` interior edges. Quantile edge sets
    /// may be shorter, because near-equal consecutive edges are merged. Bin lookup
    /// uses `partition_point`, so each set is expected to be in ascending order.
    pub bin_edges: Vec<Vec<f64>>,
    /// Strategy used at fit time; `Normal` means every coefficient uses `bin_edges[0]`.
    pub strategy: BinStrategy,
    /// Number of bins requested at fit time (validated by `fit` to lie in `[2, 26]`).
    pub n_bins: usize,
    /// Symbol alphabet: the values `0..n_bins` as `u8`, one per bin.
    pub alphabet: Vec<u8>,
}

/// MCB transformer (presumably "multiple coefficient binning" — confirm).
///
/// This is a stateless unit type. Configuration lives in [`McbConfig`] and learned
/// edges live in [`McbFitted`]. Use it through the `FittableTransformer` methods
/// `fit` and `transform`, or call [`mcb_transform_symbolic`] for `u8` output.
pub struct Mcb;

impl FittableTransformer for Mcb {
    type Config = McbConfig;
    type Fitted = McbFitted;

    /// Learn bin edges from `x`, where `x[i][j]` is coefficient `j` of sample `i`.
    ///
    /// * `BinStrategy::Normal` produces one shared edge set, `norm_ppf(i / n_bins)`
    ///   for `i` in `1..n_bins`. The values in `x` are only validated, never read.
    /// * `BinStrategy::Uniform` / `BinStrategy::Quantile` produce one edge set per
    ///   coefficient column, computed by `compute_bin_edges`.
    ///
    /// The labels `_y` are ignored.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// * `x` is empty;
    /// * `config.n_bins` is outside `[2, 26]`;
    /// * the samples do not all have the same number of coefficients.
    fn fit(config: &Self::Config, x: &[Vec<f64>], _y: Option<&[String]>) -> Self::Fitted {
        assert!(!x.is_empty(), "Input must have at least one sample");
        assert!((2..=26).contains(&config.n_bins), "n_bins must be in [2, 26]");

        let n_coefs = x[0].len();
        // Checked up front. Otherwise a shorter sample panics with a bare index error
        // during column extraction, and a longer one has its trailing values silently
        // ignored here and then crashes `transform`.
        assert!(
            x.iter().all(|sample| sample.len() == n_coefs),
            "All samples must have the same number of coefficients"
        );

        // `n_bins <= 26` (asserted above), so the `u8` cast cannot truncate.
        let alphabet: Vec<u8> = (0..config.n_bins as u8).collect();

        let bin_edges: Vec<Vec<f64>> = match config.strategy {
            BinStrategy::Normal => {
                // Data-independent edges shared by every coefficient.
                use crate::preprocessing::quantile_transform::norm_ppf;
                let edges: Vec<f64> = (1..config.n_bins)
                    .map(|i| norm_ppf(i as f64 / config.n_bins as f64))
                    .collect();
                vec![edges]
            }
            BinStrategy::Uniform | BinStrategy::Quantile => (0..n_coefs)
                .map(|j| {
                    // Gather coefficient `j` across all samples, one column at a time.
                    let column: Vec<f64> = x.iter().map(|sample| sample[j]).collect();
                    compute_bin_edges(&column, config.n_bins, config.strategy)
                })
                .collect(),
        };

        McbFitted {
            bin_edges,
            strategy: config.strategy,
            n_bins: config.n_bins,
            alphabet,
        }
    }

    /// Replace every value by the index of its bin, returned as `f64`.
    ///
    /// A value `v` maps to the number of edges `e` with `e <= v`, assuming ascending
    /// edges. Under this rule NaN always maps to bin 0. `Normal` uses the shared edge
    /// set `bin_edges[0]`; the other strategies use `bin_edges[j]` for coefficient `j`.
    ///
    /// # Panics
    ///
    /// Panics when a coefficient index has no matching edge set, for example a sample
    /// with more coefficients than the data seen by `fit`.
    fn transform(fitted: &Self::Fitted, x: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let shared = fitted.strategy == BinStrategy::Normal;
        x.iter()
            .map(|sample| {
                sample
                    .iter()
                    .enumerate()
                    .map(|(j, &v)| {
                        let edges = &fitted.bin_edges[if shared { 0 } else { j }];
                        edges.partition_point(|&e| e <= v) as f64
                    })
                    .collect()
            })
            .collect()
    }
}

/// Compute the interior bin edges of one coefficient column for the given strategy.
///
/// Returns at most `n_bins - 1` edges. A value `v` is later assigned to bin
/// `edges.partition_point(|&e| e <= v)`.
///
/// * `Uniform` returns `n_bins - 1` equally spaced edges strictly between the column
///   minimum and maximum. NaN entries are ignored when locating the extremes.
/// * `Quantile` returns the `i * 100 / n_bins` percentiles of the column, computed via
///   `percentile`. Near-equal consecutive edges are merged, so fewer edges may come back.
///   How NaN entries affect these edges is up to `percentile`.
///
/// Expects a non-empty column; `Mcb::fit` always passes one value per sample.
///
/// # Panics
///
/// Panics (via `unreachable!`) for `BinStrategy::Normal`. Normal edges do not depend
/// on the data, so `Mcb::fit` computes them itself.
fn compute_bin_edges(values: &[f64], n_bins: usize, strategy: BinStrategy) -> Vec<f64> {
    match strategy {
        BinStrategy::Uniform => {
            // Only the extremes matter, so a single O(n) pass replaces a full sort.
            // `f64::min`/`f64::max` skip NaN, whereas `partial_cmp(..).unwrap()` panicked.
            let (min_val, max_val) = values
                .iter()
                .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
                    (lo.min(v), hi.max(v))
                });
            (1..n_bins)
                .map(|i| min_val + (max_val - min_val) * i as f64 / n_bins as f64)
                .collect()
        }
        BinStrategy::Quantile => {
            use crate::preprocessing::scaler::percentile;
            let mut edges: Vec<f64> = (1..n_bins)
                .map(|i| percentile(values, i as f64 * 100.0 / n_bins as f64))
                .collect();
            // Collapse near-identical consecutive edges (|Δ| < 1e-8), which would
            // otherwise produce zero-width bins.
            edges.dedup_by(|a, b| (*a - *b).abs() < 1e-8);
            edges
        }
        BinStrategy::Normal => unreachable!("Normal strategy handled separately"),
    }
}

/// Transform to symbolic output (u8 bin indices).
///
/// This applies the same binning rule as `Mcb::transform` but returns each bin index
/// as a `u8` symbol instead of an `f64`. The rule maps a value `v` to the number of
/// edges `e` with `e <= v`, with edges assumed ascending. With `BinStrategy::Normal`
/// every coefficient uses `bin_edges[0]`; otherwise coefficient `j` uses `bin_edges[j]`.
///
/// # Panics
///
/// Panics when a coefficient index has no matching edge set, for example a sample
/// with more coefficients than the data `fitted` was built from.
pub fn mcb_transform_symbolic(fitted: &McbFitted, x: &[Vec<f64>]) -> Vec<Vec<u8>> {
    let shared_edges = fitted.strategy == BinStrategy::Normal;
    let mut symbols = Vec::with_capacity(x.len());
    for sample in x {
        let mut row = Vec::with_capacity(sample.len());
        for (coef, &value) in sample.iter().enumerate() {
            let edge_set = if shared_edges { 0 } else { coef };
            let edges = &fitted.bin_edges[edge_set];
            // When built by `Mcb::fit` there are at most `n_bins - 1 <= 25` edges,
            // so the bin index always fits in a `u8`.
            row.push(edges.partition_point(|&e| e <= value) as u8);
        }
        symbols.push(row);
    }
    symbols
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Four samples whose two coefficient columns are both strictly increasing.
    fn increasing_samples() -> Vec<Vec<f64>> {
        vec![
            vec![1.0, 10.0],
            vec![2.0, 20.0],
            vec![3.0, 30.0],
            vec![4.0, 40.0],
        ]
    }

    #[test]
    fn test_mcb_fit_transform() {
        let config = McbConfig::new(4);
        let x = increasing_samples();
        let fitted = Mcb::fit(&config, &x, None);
        let result = Mcb::transform(&fitted, &x);
        assert_eq!(result.len(), 4);
        assert_eq!(result[0].len(), 2);
        // Every output is an integral bin index in [0, 4). Each input column is
        // strictly increasing, so its bin indices must be non-decreasing.
        for j in 0..2 {
            let column: Vec<f64> = result.iter().map(|row| row[j]).collect();
            assert!(column
                .iter()
                .all(|&b| (0.0..4.0).contains(&b) && b.fract() == 0.0));
            assert!(column.windows(2).all(|w| w[0] <= w[1]));
        }
    }

    #[test]
    fn test_mcb_uniform_exact_bins() {
        let config = McbConfig {
            n_bins: 4,
            strategy: BinStrategy::Uniform,
        };
        let x = increasing_samples();
        let fitted = Mcb::fit(&config, &x, None);
        assert_eq!(
            fitted.bin_edges,
            vec![vec![1.75, 2.5, 3.25], vec![17.5, 25.0, 32.5]]
        );
        assert_eq!(fitted.alphabet, vec![0, 1, 2, 3]);
        assert_eq!(
            Mcb::transform(&fitted, &x),
            vec![
                vec![0.0, 0.0],
                vec![1.0, 1.0],
                vec![2.0, 2.0],
                vec![3.0, 3.0]
            ]
        );
    }

    #[test]
    fn test_mcb_normal() {
        let config = McbConfig {
            n_bins: 3,
            strategy: BinStrategy::Normal,
        };
        let x = vec![vec![-2.0, 0.5], vec![0.0, -0.5], vec![2.0, 1.0]];
        let fitted = Mcb::fit(&config, &x, None);
        assert_eq!(fitted.bin_edges.len(), 1); // single shared edges
        let edges = &fitted.bin_edges[0];
        assert_eq!(edges.len(), 2);
        assert!(edges[0] < edges[1]);
        // The shared edges apply to samples of any length. Values far below, between,
        // and far above the two standard-normal tertiles land in bins 0, 1 and 2.
        let result = Mcb::transform(&fitted, &[vec![-2.0, 0.0, 2.0]]);
        assert_eq!(result, vec![vec![0.0, 1.0, 2.0]]);
    }

    #[test]
    fn test_mcb_symbolic() {
        let config = McbConfig::new(3);
        let x = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let fitted = Mcb::fit(&config, &x, None);
        let symbolic = mcb_transform_symbolic(&fitted, &x);
        let numeric = Mcb::transform(&fitted, &x);
        assert_eq!(symbolic.len(), 2);
        assert_eq!(fitted.alphabet.len(), 3);
        // The symbolic output must agree element-wise with the numeric transform.
        for (sym_row, num_row) in symbolic.iter().zip(&numeric) {
            assert_eq!(sym_row.len(), num_row.len());
            for (&s, &n) in sym_row.iter().zip(num_row) {
                assert!(s < 3);
                assert_eq!(f64::from(s), n);
            }
        }
    }

    #[test]
    #[should_panic(expected = "n_bins must be in [2, 26]")]
    fn test_mcb_rejects_too_many_bins() {
        let _ = Mcb::fit(&McbConfig::new(27), &[vec![0.0]], None);
    }

    #[test]
    #[should_panic]
    fn test_mcb_rejects_ragged_input() {
        // The second sample is shorter than the first.
        let _ = Mcb::fit(&McbConfig::new(2), &[vec![1.0, 2.0], vec![3.0]], None);
    }
}