use std::collections::HashMap;
use crate::bag_of_words::bag_of_words::{BagOfWords, BagOfWordsConfig};
use crate::core::config::{BinStrategy, TfIdfVariant};
use crate::utils::tfidf::{cosine_similarity, tfidf_vectorize};
/// Hyper-parameters for the `Saxvsm` classifier.
///
/// The first five fields are copied field-for-field into the `BagOfWordsConfig`
/// that `Saxvsm::fit` and `Saxvsm::predict` build; `variant` is handed to
/// `tfidf_vectorize` during fitting.
#[derive(Debug, Clone)]
pub struct SaxvsmConfig {
    /// Forwarded as `BagOfWordsConfig::window_size`.
    pub window_size: usize,
    /// Forwarded as `BagOfWordsConfig::word_size`.
    pub word_size: usize,
    /// Forwarded as `BagOfWordsConfig::n_bins`.
    pub n_bins: usize,
    /// Forwarded as `BagOfWordsConfig::strategy`.
    pub strategy: BinStrategy,
    /// Forwarded as `BagOfWordsConfig::window_step`.
    pub window_step: usize,
    /// TF-IDF weighting passed to `tfidf_vectorize` when the class vectors are built.
    pub variant: TfIdfVariant,
}

impl SaxvsmConfig {
    /// Value given to `n_bins` by [`SaxvsmConfig::new`].
    const DEFAULT_N_BINS: usize = 4;
    /// Value given to `window_step` by [`SaxvsmConfig::new`].
    const DEFAULT_WINDOW_STEP: usize = 1;

    /// Creates a configuration from the two mandatory sizes.
    ///
    /// The remaining fields start as `n_bins = 4`, `strategy = BinStrategy::Normal`,
    /// `window_step = 1` and `variant = TfIdfVariant::SublinearTfIdf`. All fields are
    /// public, so callers can overwrite any of them afterwards. No validation of the
    /// arguments is performed here.
    pub fn new(window_size: usize, word_size: usize) -> Self {
        let strategy = BinStrategy::Normal;
        let variant = TfIdfVariant::SublinearTfIdf;
        Self {
            window_size,
            word_size,
            n_bins: Self::DEFAULT_N_BINS,
            strategy,
            window_step: Self::DEFAULT_WINDOW_STEP,
            variant,
        }
    }
}
/// Model state returned by `Saxvsm::fit` and read by `Saxvsm::predict` / `Saxvsm::score`.
#[derive(Debug, Clone)]
pub struct SaxvsmFitted {
/// Clone of the configuration passed to `fit`; `predict` reads it to rebuild the
/// same bag-of-words settings that were used during fitting.
pub config: SaxvsmConfig,
/// Vocabulary returned by `tfidf_vectorize`. `predict` uses each word's position in
/// this list as its feature index, so the order must line up with `class_vectors`.
pub vocabulary: Vec<String>,
/// One `(class label, vector)` pair per distinct label. Labels are sorted and
/// de-duplicated in `fit`, then paired by position with the rows of the matrix
/// returned by `tfidf_vectorize`.
pub class_vectors: Vec<(String, Vec<f64>)>,
}
pub struct Saxvsm;
impl Saxvsm {
pub fn fit(config: &SaxvsmConfig, x: &[Vec<f64>], y: &[String]) -> SaxvsmFitted {
assert!(!x.is_empty(), "Input must have at least one sample");
assert_eq!(x.len(), y.len(), "X and y must have same length");
let bow_config = BagOfWordsConfig {
window_size: config.window_size,
word_size: config.word_size,
n_bins: config.n_bins,
strategy: config.strategy,
numerosity_reduction: crate::core::config::NumerosityReduction::IdenticalConsecutive,
window_step: config.window_step,
};
let bow_strings = BagOfWords::transform(&bow_config, x);
let corpus: Vec<Vec<String>> = bow_strings
.iter()
.map(|s| s.split_whitespace().map(|w| w.to_string()).collect())
.collect();
let mut classes: Vec<String> = y.to_vec();
classes.sort();
classes.dedup();
let mut class_docs: Vec<Vec<String>> = Vec::new();
for class in &classes {
let mut doc = Vec::new();
for (i, label) in y.iter().enumerate() {
if label == class {
doc.extend(corpus[i].clone());
}
}
class_docs.push(doc);
}
let (vocabulary, matrix) = tfidf_vectorize(&class_docs, config.variant);
let class_vectors: Vec<(String, Vec<f64>)> = classes.into_iter().zip(matrix).collect();
SaxvsmFitted {
config: config.clone(),
vocabulary,
class_vectors,
}
}
pub fn predict(fitted: &SaxvsmFitted, x: &[Vec<f64>]) -> Vec<String> {
let bow_config = BagOfWordsConfig {
window_size: fitted.config.window_size,
word_size: fitted.config.word_size,
n_bins: fitted.config.n_bins,
strategy: fitted.config.strategy,
numerosity_reduction: crate::core::config::NumerosityReduction::IdenticalConsecutive,
window_step: fitted.config.window_step,
};
let bow_strings = BagOfWords::transform(&bow_config, x);
let word_idx: HashMap<&str, usize> = fitted
.vocabulary
.iter()
.enumerate()
.map(|(i, w)| (w.as_str(), i))
.collect();
let n_features = fitted.vocabulary.len();
bow_strings
.iter()
.map(|bow| {
let words: Vec<&str> = bow.split_whitespace().collect();
let mut vec = vec![0.0; n_features];
for word in &words {
if let Some(&idx) = word_idx.get(word) {
vec[idx] += 1.0;
}
}
fitted
.class_vectors
.iter()
.map(|(class, class_vec)| (class.as_str(), cosine_similarity(&vec, class_vec)))
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
.map(|(class, _)| class.to_string())
.unwrap()
})
.collect()
}
pub fn score(fitted: &SaxvsmFitted, x: &[Vec<f64>], y: &[String]) -> f64 {
let predictions = Self::predict(fitted, x);
let correct = predictions
.iter()
.zip(y.iter())
.filter(|(p, t)| p == t)
.count();
correct as f64 / y.len() as f64
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Labels shared by both fixtures: two "A" samples followed by two "B" samples.
    fn two_class_labels() -> Vec<String> {
        ["A", "A", "B", "B"].iter().map(|s| s.to_string()).collect()
    }

    /// Fitting and predicting on the training data yields one label per sample.
    #[test]
    fn test_saxvsm_basic() {
        let x: Vec<Vec<f64>> = vec![
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0],
            vec![7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0],
            vec![8.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0],
        ];
        let y = two_class_labels();
        let config = SaxvsmConfig::new(4, 2);

        let fitted = Saxvsm::fit(&config, &x, &y);
        let predictions = Saxvsm::predict(&fitted, &x);

        assert_eq!(predictions.len(), 4);
    }

    /// The accuracy on the training data is a valid fraction.
    #[test]
    fn test_saxvsm_score() {
        let x: Vec<Vec<f64>> = vec![
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 6.0],
            vec![5.0, 4.0, 3.0, 2.0, 1.0, 0.0],
            vec![6.0, 4.0, 3.0, 2.0, 1.0, 0.0],
        ];
        let y = two_class_labels();
        let config = SaxvsmConfig::new(3, 2);

        let fitted = Saxvsm::fit(&config, &x, &y);
        let score = Saxvsm::score(&fitted, &x, &y);

        assert!((0.0..=1.0).contains(&score));
    }
}