use crate::approximation::paa::{Paa, PaaConfig};
use crate::core::config::BinStrategy;
use crate::core::traits::Transformer;
use crate::preprocessing::discretizer::{KBinsDiscretizer, KBinsDiscretizerConfig};
use crate::preprocessing::scaler::{StandardScaler, StandardScalerConfig};
/// Settings consumed by the [`Sax`] transformer.
#[derive(Debug, Clone)]
pub struct SaxConfig {
/// Number of bins handed to the discretizer. `Sax::transform` asserts that
/// this lies in `[2, 26]`.
pub n_bins: usize,
/// Binning strategy forwarded unchanged to the discretizer configuration.
/// [`SaxConfig::new`] sets this to `BinStrategy::Normal`.
pub strategy: BinStrategy,
/// Requested number of points per output series. `None` falls back to the
/// length of the first input sample; PAA is applied only when this value is
/// smaller than that length.
pub output_size: Option<usize>,
}
impl SaxConfig {
pub fn new(n_bins: usize) -> Self {
Self {
n_bins,
strategy: BinStrategy::Normal,
output_size: None,
}
}
}
/// Unit type that carries the SAX [`Transformer`] implementation; every
/// parameter is supplied through [`SaxConfig`].
pub struct Sax;
impl Transformer for Sax {
    type Config = SaxConfig;

    /// Runs the SAX pipeline on `x` and returns the bin indices as `f64`.
    ///
    /// The input is passed through [`StandardScaler`], then through [`Paa`]
    /// when the requested output size is smaller than the length of the
    /// first sample, and finally through [`KBinsDiscretizer`] using the
    /// configured bin count and strategy. A missing `output_size` falls back
    /// to the length of the first sample, which skips the PAA step.
    ///
    /// # Panics
    ///
    /// Panics if `x` is empty or if `config.n_bins` is outside `[2, 26]`.
    fn transform(config: &Self::Config, x: &[Vec<f64>]) -> Vec<Vec<f64>> {
        assert!(!x.is_empty(), "Input must have at least one sample");
        assert!(
            (2..=26).contains(&config.n_bins),
            "n_bins must be in [2, 26]"
        );

        // The first sample's length stands in for the length of every series.
        let series_len = x[0].len();
        let target_len = config.output_size.unwrap_or(series_len);

        let scaling = StandardScalerConfig::new();
        let standardized = StandardScaler::transform(&scaling, x);

        // Only shrink: a target at least as long as the input leaves the
        // scaled series untouched.
        let reduced = if target_len >= series_len {
            standardized
        } else {
            let paa = PaaConfig::new(target_len);
            Paa::transform(&paa, &standardized)
        };

        let binning = KBinsDiscretizerConfig {
            n_bins: config.n_bins,
            strategy: config.strategy,
        };

        // The trait's return type is floating point, so widen each bin index.
        KBinsDiscretizer::transform(&binning, &reduced)
            .into_iter()
            .map(|symbols| symbols.into_iter().map(|bin| bin as f64).collect())
            .collect()
    }
}
/// Runs [`Sax::transform`] and narrows every returned bin index to `u8`.
///
/// The narrowing uses an `as` cast on each `f64` value produced by the
/// transformer.
///
/// # Panics
///
/// Panics under the same conditions as [`Sax::transform`].
pub fn sax_transform_symbolic(config: &SaxConfig, x: &[Vec<f64>]) -> Vec<Vec<u8>> {
    let rows = Sax::transform(config, x);
    let mut symbolic = Vec::with_capacity(rows.len());
    for row in rows {
        let narrowed: Vec<u8> = row.into_iter().map(|bin| bin as u8).collect();
        symbolic.push(narrowed);
    }
    symbolic
}
/// Renders bin indices as a lowercase string: `0 -> 'a'`, `1 -> 'b'`, ...,
/// `25 -> 'z'`.
///
/// An empty slice yields an empty string.
///
/// # Panics
///
/// Panics if any index is 26 or greater. Such a value has no letter, and
/// adding it to `b'a'` unchecked would either emit a non-letter character or
/// overflow `u8` (a panic in debug builds, a silent wrap in release builds).
pub fn bins_to_alphabet(bins: &[u8]) -> String {
    // Size of the `a..=z` alphabet; matches the upper limit on `n_bins`.
    const ALPHABET_LEN: u8 = 26;

    bins.iter()
        .map(|&bin| {
            assert!(
                bin < ALPHABET_LEN,
                "bin index {} has no letter; expected a value in [0, 25]",
                bin
            );
            // Cannot overflow: `bin < 26`, so the sum is at most `b'z'`.
            char::from(b'a' + bin)
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With default settings the output keeps the input shape and every bin
    /// index lies in `0..=2` for three bins.
    #[test]
    fn test_sax_basic() {
        let rising: Vec<f64> = (0..8u8).map(f64::from).collect();
        let falling: Vec<f64> = rising.iter().rev().copied().collect();
        let x = vec![rising, falling];

        let result = Sax::transform(&SaxConfig::new(3), &x);

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].len(), 8);
        for &v in result.iter().flatten() {
            assert!((0.0..=2.0).contains(&v), "Bin {v} out of range");
        }
    }

    /// Requesting a shorter output length shrinks each series to that length.
    #[test]
    fn test_sax_with_paa() {
        let config = SaxConfig {
            output_size: Some(4),
            ..SaxConfig::new(3)
        };
        let x = vec![vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]];

        let result = Sax::transform(&config, &x);

        assert_eq!(result[0].len(), 4);
    }

    /// The `u8` variant keeps the series length and stays below the bin count.
    #[test]
    fn test_sax_symbolic() {
        let x = vec![vec![0.0, 1.0, 2.0, 3.0, 4.0]];

        let result = sax_transform_symbolic(&SaxConfig::new(3), &x);

        assert_eq!(result[0].len(), 5);
        for v in result[0].iter().copied() {
            assert!(v < 3);
        }
    }

    /// Bin indices map onto consecutive lowercase letters starting at `a`.
    #[test]
    fn test_bins_to_alphabet() {
        let cases = [(&[0u8, 1, 2][..], "abc"), (&[0, 0, 1, 1][..], "aabb")];
        for (bins, expected) in cases {
            assert_eq!(bins_to_alphabet(bins), expected);
        }
    }

    /// A strictly increasing series never produces a decreasing bin index.
    #[test]
    fn test_sax_monotonic() {
        let x = vec![vec![-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0]];

        let result = Sax::transform(&SaxConfig::new(4), &x);

        for pair in result[0].windows(2) {
            assert!(pair[1] >= pair[0]);
        }
    }
}