oxits 0.1.0

Time series classification and transformation library for Rust
Documentation
use crate::core::traits::Transformer;
use realfft::{RealFftPlanner, RealToComplex};
use std::sync::Arc;

/// Options controlling how [`Dft`] turns each sample into Fourier coefficients.
///
/// The default value (also returned by [`DftConfig::new`]) keeps every coefficient and
/// applies no normalization.
#[derive(Debug, Clone, Default)]
pub struct DftConfig {
    /// Maximum number of coefficients returned per sample; `None` returns all of them.
    pub n_coefs: Option<usize>,
    /// Subtract each sample's mean before transforming.
    pub norm_mean: bool,
    /// Divide each sample by its population standard deviation before transforming;
    /// samples whose standard deviation is zero are left unscaled.
    pub norm_std: bool,
    /// Omit the DC (sum) coefficient from the output.
    pub drop_sum: bool,
}

impl DftConfig {
    /// Configuration with every option off: all coefficients, no normalization.
    pub fn new() -> Self {
        Self::default()
    }
}

/// Discrete Fourier transform of time series.
///
/// Each input sample becomes one output row of interleaved coefficients
/// `[Re(c0), Re(c1), Im(c1), ...]`, shaped by [`DftConfig`].
pub struct Dft;

impl Transformer for Dft {
    type Config = DftConfig;

    /// Transform every sample in `x`, using rayon when the `parallel` feature is enabled.
    ///
    /// # Panics
    /// Panics if `x` is empty or if its samples do not all have the same length.
    fn transform(config: &Self::Config, x: &[Vec<f64>]) -> Vec<Vec<f64>> {
        #[cfg(feature = "parallel")]
        use rayon::prelude::*;

        assert!(!x.is_empty(), "Input must have at least one sample");
        let series_len = x[0].len();
        let uniform = x.iter().all(|row| row.len() == series_len);
        assert!(uniform, "All samples must have same length");

        // Plan once; the resulting `Arc` is shared by every per-sample call.
        let plan = RealFftPlanner::<f64>::new().plan_fft_forward(series_len);

        // Pick the sample iterator at compile time; both expose `map` + `collect`.
        #[cfg(feature = "parallel")]
        let rows = x.par_iter();
        #[cfg(not(feature = "parallel"))]
        let rows = x.iter();

        rows.map(|row| dft_single(row, config, &plan)).collect()
    }
}

/// Standardize one series in place.
///
/// With `norm_mean`, the series mean is subtracted from every value. With `norm_std`,
/// every value is divided by the population standard deviation (the sum of squared
/// deviations divided by the length); a standard deviation of zero leaves the values
/// unscaled.
fn normalize_data(data: &mut [f64], norm_mean: bool, norm_std: bool) {
    let len = data.len() as f64;
    let avg = data.iter().sum::<f64>() / len;

    if norm_mean {
        data.iter_mut().for_each(|x| *x -= avg);
    }
    if !norm_std {
        return;
    }

    // Once the mean has been removed, the stored values already are the deviations.
    let sq_dev_sum: f64 = data
        .iter()
        .map(|&x| {
            let dev = if norm_mean { x } else { x - avg };
            dev * dev
        })
        .sum();
    let sigma = (sq_dev_sum / len).sqrt();

    // Constant series (sigma == 0) are left as-is instead of being divided by zero.
    if sigma > 0.0 {
        data.iter_mut().for_each(|x| *x /= sigma);
    }
}

/// Whether bin `i` should contribute only its real part.
///
/// The DC bin (`i == 0`) always does; the bin at `last_idx` does too when
/// `nyquist_real_only` is set (the Nyquist bin of an even-length signal).
#[inline]
fn is_real_only_bin(i: usize, last_idx: usize, nyquist_real_only: bool) -> bool {
    match i {
        0 => true,
        _ => nyquist_real_only && i == last_idx,
    }
}

/// Extract interleaved real coefficients from a complex FFT spectrum.
///
/// Layout: `[Re(c0), Re(c1), Im(c1), Re(c2), Im(c2), ...]`.
/// DC (index 0) and Nyquist (last bin, if `n` is even) are real-only.
///
/// * `spectrum` — half-spectrum bins as produced by a real-to-complex FFT.
/// * `n` — length of the original real signal; only its parity is used, to decide whether
///   the last bin is a real-only Nyquist bin.
/// * `n_coefs_limit` — maximum number of values to return. It may exceed what the spectrum
///   can supply, in which case every available value is returned.
/// * `start_idx` — first bin to emit (`1` skips the DC bin).
fn extract_coefs(
    spectrum: &[realfft::num_complex::Complex<f64>],
    n: usize,
    n_coefs_limit: usize,
    start_idx: usize,
) -> Vec<f64> {
    // `saturating_sub` keeps an empty spectrum from underflowing; the loop then emits nothing.
    let last_idx = spectrum.len().saturating_sub(1);
    let nyquist_real_only = n.is_multiple_of(2);

    // `n_coefs_limit` comes from user configuration and may be far larger than the spectrum
    // can produce (at most two values per bin). Reserving it verbatim could trigger a huge
    // allocation or a capacity-overflow panic, so cap the reservation.
    let available = spectrum.len().saturating_mul(2);
    let mut coefs = Vec::with_capacity(n_coefs_limit.min(available));

    for (i, c) in spectrum.iter().enumerate().skip(start_idx) {
        if coefs.len() >= n_coefs_limit {
            break;
        }
        coefs.push(c.re);
        // Only emit the imaginary part for complex-valued bins, and only if still under the limit.
        if !is_real_only_bin(i, last_idx, nyquist_real_only) && coefs.len() < n_coefs_limit {
            coefs.push(c.im);
        }
    }
    coefs
}

/// Compute the interleaved Fourier coefficients of one sample.
///
/// The sample is copied, optionally normalized with `normalize_data` (per
/// `config.norm_mean` / `config.norm_std`), transformed with the pre-planned real FFT `fft`,
/// and flattened by `extract_coefs`. At most `config.n_coefs` values are returned (all of
/// them when unset), and the DC bin is skipped when `config.drop_sum` is set.
///
/// # Panics
/// Panics if `fft` was planned for a length other than `sample.len()`.
fn dft_single(sample: &[f64], config: &DftConfig, fft: &Arc<dyn RealToComplex<f64>>) -> Vec<f64> {
    let n = sample.len();
    // `process` takes the input mutably, so work on a private copy of the sample.
    let mut data = sample.to_vec();

    if config.norm_mean || config.norm_std {
        normalize_data(&mut data, config.norm_mean, config.norm_std);
    }

    let mut spectrum = fft.make_output_vec();
    fft.process(&mut data, &mut spectrum)
        .expect("FFT plan length must equal the sample length");

    // A length-`n` real signal yields at most `n` coefficients, so larger requests are
    // clamped; this keeps downstream buffer sizing bounded by the data, not by user input.
    let n_coefs_limit = config.n_coefs.map_or(n, |k| k.min(n));
    let start_idx = usize::from(config.drop_sum);
    extract_coefs(&spectrum, n, n_coefs_limit, start_idx)
}

/// Get the indices of Fourier coefficients selected by ANOVA F-test.
///
/// Every sample in `x` is transformed with [`Dft`] using `config`'s normalization and
/// `drop_sum` settings. `config.n_coefs` is ignored so that every coefficient is scored.
/// A one-way ANOVA F-statistic is then computed for each coefficient against the labels
/// `y`. The indices of the `n_coefs` highest-scoring coefficients are returned in ascending
/// order. `n_coefs` is capped at the number of available coefficients, and ties keep the
/// lower index.
///
/// A coefficient scores 0 when its F-statistic is undefined: `y` holds a single class, the
/// within-class mean square is not positive, or non-finite inputs make the ratio NaN.
///
/// # Panics
/// Panics if `x` and `y` differ in length, if `x` is empty, or if the samples in `x` differ
/// in length.
pub fn anova_selection(
    x: &[Vec<f64>],
    y: &[String],
    n_coefs: usize,
    config: &DftConfig,
) -> Vec<usize> {
    assert_eq!(x.len(), y.len(), "X and y must have same number of samples");

    // Every coefficient is needed for scoring, so drop any `n_coefs` limit from the config.
    let dft_config = DftConfig {
        n_coefs: None,
        ..config.clone()
    };
    let all_coefs = Dft::transform(&dft_config, x);
    let n_features = all_coefs[0].len();
    let n_coefs = n_coefs.min(n_features);

    // Group sample indices by label once rather than re-filtering for every coefficient.
    let class_members = anova_class_members(y);
    let f_scores: Vec<f64> = (0..n_features)
        .map(|j| anova_f_score(&all_coefs, &class_members, j))
        .collect();

    // Rank by descending F-score. Scores are never NaN (see `anova_f_score`), so `total_cmp`
    // cannot panic and matches numeric order. The stable sort keeps lower indices first on ties.
    let mut indices: Vec<usize> = (0..n_features).collect();
    indices.sort_by(|&a, &b| f_scores[b].total_cmp(&f_scores[a]));
    indices.truncate(n_coefs);
    indices.sort_unstable(); // Return in original coefficient order.
    indices
}

/// Sample indices grouped by label: one group per distinct label in `y`, groups ordered by
/// label, and indices ascending within each group.
fn anova_class_members(y: &[String]) -> Vec<Vec<usize>> {
    use std::collections::BTreeMap;

    let mut groups: BTreeMap<&str, Vec<usize>> = BTreeMap::new();
    for (i, label) in y.iter().enumerate() {
        groups.entry(label.as_str()).or_default().push(i);
    }
    groups.into_values().collect()
}

/// One-way ANOVA F-statistic of coefficient column `j` across the groups in `class_members`.
///
/// Returns 0 when there are fewer than two groups, when the within-class mean square is not
/// positive, or when the ratio is NaN (e.g. because of non-finite inputs).
fn anova_f_score(all_coefs: &[Vec<f64>], class_members: &[Vec<usize>], j: usize) -> f64 {
    // With a single class the between-class mean square is 0/0, so there is nothing to rank.
    if class_members.len() < 2 {
        return 0.0;
    }
    let n_total = all_coefs.len() as f64;
    let k = class_members.len() as f64;
    let grand_mean = all_coefs.iter().map(|c| c[j]).sum::<f64>() / n_total;

    let mut ss_between = 0.0;
    let mut ss_within = 0.0;
    for members in class_members {
        let n_c = members.len() as f64;
        let class_mean = members.iter().map(|&i| all_coefs[i][j]).sum::<f64>() / n_c;

        let d = class_mean - grand_mean;
        ss_between += n_c * d * d;
        ss_within += members
            .iter()
            .map(|&i| {
                let dev = all_coefs[i][j] - class_mean;
                dev * dev
            })
            .sum::<f64>();
    }

    let ms_between = ss_between / (k - 1.0);
    // No within-class degrees of freedom (one sample per class): fall back to a divisor of 1.
    let ms_within = if n_total - k > 0.0 {
        ss_within / (n_total - k)
    } else {
        1.0
    };
    let f = if ms_within > 0.0 {
        ms_between / ms_within
    } else {
        0.0
    };
    // Non-finite inputs can still produce NaN; map it to 0 so ranking stays well-defined.
    if f.is_nan() {
        0.0
    } else {
        f
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const TOL: f64 = 1e-10;

    /// Transform a single series and return its coefficient row.
    fn coefs_of(config: &DftConfig, series: &[f64]) -> Vec<f64> {
        let mut rows = Dft::transform(config, &[series.to_vec()]);
        rows.pop().expect("one output row per input sample")
    }

    #[test]
    fn test_dft_basic() {
        let coefs = coefs_of(&DftConfig::new(), &[1.0, 2.0, 3.0, 4.0]);
        assert!(!coefs.is_empty());
        // Without normalization the DC term equals the sum of the series.
        assert!((coefs[0] - 10.0).abs() < TOL);
    }

    #[test]
    fn test_dft_with_n_coefs() {
        let limited = DftConfig {
            n_coefs: Some(3),
            ..DftConfig::new()
        };
        let series: Vec<f64> = (1..=8).map(f64::from).collect();
        assert_eq!(coefs_of(&limited, &series).len(), 3);
    }

    #[test]
    fn test_dft_drop_sum() {
        let series = [1.0, 2.0, 3.0, 4.0];
        let kept = coefs_of(&DftConfig::new(), &series);
        let dropped = coefs_of(
            &DftConfig {
                drop_sum: true,
                ..DftConfig::new()
            },
            &series,
        );
        // Dropping the DC term must shorten the output.
        assert!(dropped.len() < kept.len());
    }

    #[test]
    fn test_dft_norm_mean() {
        let centered = DftConfig {
            norm_mean: true,
            ..DftConfig::new()
        };
        let coefs = coefs_of(&centered, &[1.0, 2.0, 3.0, 4.0]);
        // Removing the mean zeroes the DC term.
        assert!(coefs[0].abs() < TOL);
    }

    #[test]
    fn test_dft_constant_series() {
        let coefs = coefs_of(&DftConfig::new(), &[5.0; 4]);
        // DC = 4 * 5 = 20; a constant series has no energy at any other frequency.
        let (dc, rest) = coefs.split_first().expect("non-empty output");
        assert!((dc - 20.0).abs() < TOL);
        assert!(rest.iter().all(|v| v.abs() < TOL));
    }
}