use crate::approximation::paa::{Paa, PaaConfig};
use crate::core::config::BinStrategy;
use crate::core::traits::Transformer;
use crate::preprocessing::discretizer::{KBinsDiscretizer, KBinsDiscretizerConfig};
/// Configuration for [`Mtf`].
#[derive(Debug, Clone)]
pub struct MtfConfig {
    /// Side length of the output image. `None` keeps one pixel per timestamp;
    /// a value smaller than the series length triggers PAA downsampling of the image.
    pub image_size: Option<usize>,
    /// Number of bins (Markov states) used to discretize each series.
    pub n_bins: usize,
    /// Binning strategy forwarded to the discretizer.
    pub strategy: BinStrategy,
}

impl MtfConfig {
    /// Returns the default configuration: full-resolution images, 5 bins,
    /// quantile binning.
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}

impl Default for MtfConfig {
    fn default() -> Self {
        MtfConfig {
            strategy: BinStrategy::Quantile,
            n_bins: 5,
            image_size: None,
        }
    }
}
/// Markov Transition Field (MTF) encoder that turns each time series into a
/// square image of first-order transition probabilities between bins.
pub struct Mtf;

impl Mtf {
    /// Encodes every time series in `x` as a Markov Transition Field image.
    ///
    /// Each sample is processed in four steps:
    ///
    /// 1. It is discretized into `config.n_bins` bins with [`KBinsDiscretizer`],
    ///    using `config.strategy`.
    /// 2. A first-order transition matrix is estimated from consecutive bin pairs.
    /// 3. The matrix is row-normalized. Rows with no outgoing transitions stay all zero.
    /// 4. Pixel `(i, j)` of the image is the probability of moving from the bin of
    ///    timestamp `i` to the bin of timestamp `j`.
    ///
    /// When `config.image_size` is `Some(s)` with `s` smaller than a sample's
    /// length, that sample's image is downsampled with PAA along both axes.
    /// Otherwise the full `len x len` image is returned.
    ///
    /// # Panics
    ///
    /// Panics if `x` is empty or if `config.image_size` is `Some(0)`.
    pub fn transform(config: &MtfConfig, x: &[Vec<f64>]) -> Vec<Vec<Vec<f64>>> {
        assert!(!x.is_empty(), "Input must have at least one sample");
        if let Some(size) = config.image_size {
            assert!(size > 0, "image_size must be at least 1");
        }

        let n_bins = config.n_bins;
        let image_size = config.image_size;
        let disc_config = KBinsDiscretizerConfig {
            n_bins,
            strategy: config.strategy,
        };

        let build = |sample: &Vec<f64>| {
            let bins = KBinsDiscretizer::transform(&disc_config, std::slice::from_ref(sample))
                .into_iter()
                .next()
                .expect("discretizer returns one row per input sample");
            let transition = transition_matrix(&bins, n_bins);
            let image = transition_field(&bins, &transition);
            // Downsample per image (inside the possibly-parallel map) so that
            // all full-resolution images are never held in memory at once.
            match image_size {
                Some(size) if size < image.len() => reduce_image(&image, size),
                _ => image,
            }
        };

        #[cfg(feature = "parallel")]
        let images: Vec<Vec<Vec<f64>>> = {
            use rayon::prelude::*;
            x.par_iter().map(build).collect()
        };
        #[cfg(not(feature = "parallel"))]
        let images: Vec<Vec<Vec<f64>>> = x.iter().map(build).collect();

        images
    }
}

/// Builds the row-normalized `n_bins x n_bins` first-order transition matrix
/// from a sequence of bin indices. Rows for bins never left stay all zero.
fn transition_matrix(bins: &[usize], n_bins: usize) -> Vec<Vec<f64>> {
    let mut transition = vec![vec![0.0; n_bins]; n_bins];
    // `windows(2)` yields nothing for 0 or 1 bins, avoiding `len - 1` underflow.
    for pair in bins.windows(2) {
        transition[pair[0]][pair[1]] += 1.0;
    }
    for row in &mut transition {
        // Counts are small integers, so this float sum is exact.
        let total: f64 = row.iter().sum();
        if total > 0.0 {
            row.iter_mut().for_each(|p| *p /= total);
        }
    }
    transition
}

/// Spreads transition probabilities over time: pixel `(i, j)` is
/// `transition[bins[i]][bins[j]]`.
fn transition_field(bins: &[usize], transition: &[Vec<f64>]) -> Vec<Vec<f64>> {
    bins.iter()
        .map(|&from| {
            let row = &transition[from];
            bins.iter().map(|&to| row[to]).collect()
        })
        .collect()
}
/// Downsamples an image to `size x size` by applying PAA along rows, then
/// along columns (via transposition).
///
/// Returns an empty image when `image` is empty, instead of panicking on `[0]`.
fn reduce_image(image: &[Vec<f64>], size: usize) -> Vec<Vec<f64>> {
    if image.is_empty() {
        return Vec::new();
    }
    let paa_config = PaaConfig::new(size);
    // Assumes `Paa::transform` shrinks each row to `size` values — TODO confirm.
    let row_reduced = Paa::transform(&paa_config, image);
    // Transpose so the second PAA pass shrinks the other axis, then transpose back.
    let col_reduced = Paa::transform(&paa_config, &transpose(&row_reduced));
    transpose(&col_reduced)
}

/// Transposes a rectangular matrix, taking the column count from the first row.
/// An empty matrix transposes to an empty matrix.
fn transpose(matrix: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n_cols = matrix.first().map_or(0, Vec::len);
    (0..n_cols)
        .map(|j| matrix.iter().map(|row| row[j]).collect())
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mtf_shape() {
        let config = MtfConfig::new();
        let x = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let result = Mtf::transform(&config, &x);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 5);
        assert_eq!(result[0][0].len(), 5);
    }

    #[test]
    fn test_mtf_probabilities() {
        let config = MtfConfig {
            n_bins: 3,
            ..MtfConfig::new()
        };
        let x = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let result = Mtf::transform(&config, &x);
        for row in &result[0] {
            for &v in row {
                assert!(
                    (-1e-10..=1.0 + 1e-10).contains(&v),
                    "MTF value {v} out of range"
                );
            }
        }
    }

    #[test]
    fn test_mtf_reduced_size() {
        let config = MtfConfig {
            image_size: Some(3),
            ..MtfConfig::new()
        };
        let x = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]];
        let result = Mtf::transform(&config, &x);
        assert_eq!(result[0].len(), 3);
        assert_eq!(result[0][0].len(), 3);
    }

    /// An `image_size` at or above the series length falls back to full resolution.
    #[test]
    fn test_mtf_oversized_image_size_keeps_full_resolution() {
        let config = MtfConfig {
            image_size: Some(10),
            ..MtfConfig::new()
        };
        let x = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let result = Mtf::transform(&config, &x);
        assert_eq!(result[0].len(), 5);
        assert!(result[0].iter().all(|row| row.len() == 5));
    }

    /// Every sample produces its own square image.
    #[test]
    fn test_mtf_multiple_samples() {
        let config = MtfConfig::new();
        let x = vec![
            vec![1.0, 3.0, 2.0, 5.0, 4.0],
            vec![5.0, 4.0, 3.0, 2.0, 1.0],
        ];
        let result = Mtf::transform(&config, &x);
        assert_eq!(result.len(), 2);
        for image in &result {
            assert_eq!(image.len(), 5);
            assert!(image.iter().all(|row| row.len() == 5));
        }
    }
}