use std::collections::HashMap;
use crate::core::config::TfIdfVariant;
use crate::transformation::boss::{Boss, BossConfig, BossFitted};
use crate::utils::tfidf::{build_vocabulary, cosine_similarity};
/// Hyper-parameters for the [`Bossvs`] classifier.
#[derive(Debug, Clone)]
pub struct BossvsConfig {
    /// Window size passed to [`BossConfig::new`] for word extraction.
    pub window_size: usize,
    /// Word size passed to [`BossConfig::new`] for word extraction.
    pub word_size: usize,
    /// Overrides the `n_bins` field of the [`BossConfig`] used for word extraction.
    pub n_bins: usize,
    /// Term-frequency weighting used when building the per-class tf-idf vectors.
    pub variant: TfIdfVariant,
}
impl BossvsConfig {
    /// Builds a configuration for the given `window_size`.
    ///
    /// The remaining settings take their defaults:
    /// - `word_size` is 4.
    /// - `n_bins` is 4.
    /// - `variant` is [`TfIdfVariant::SublinearTfIdf`].
    pub fn new(window_size: usize) -> Self {
        let (word_size, n_bins) = (4, 4);
        let variant = TfIdfVariant::SublinearTfIdf;
        Self {
            window_size,
            word_size,
            n_bins,
            variant,
        }
    }
}
/// A trained BOSS VS model, as returned by `Bossvs::fit`.
#[derive(Debug, Clone)]
pub struct BossvsFitted {
    /// Fitted BOSS transformer, reused at prediction time to extract words.
    pub boss_fitted: BossFitted,
    /// Words known to the model.
    ///
    /// A word's position in this list is its index in every class vector.
    pub vocabulary: Vec<String>,
    /// One `(label, tf-idf vector)` pair per distinct training label, in sorted label order.
    ///
    /// Each vector has `vocabulary.len()` entries. It is L2-normalised unless it is all zeros.
    pub class_vectors: Vec<(String, Vec<f64>)>,
}
pub struct Bossvs;
impl Bossvs {
pub fn fit(config: &BossvsConfig, x: &[Vec<f64>], y: &[String]) -> BossvsFitted {
assert!(!x.is_empty(), "Input must have at least one sample");
assert_eq!(x.len(), y.len(), "X and y must have same length");
let boss_config = BossConfig {
n_bins: config.n_bins,
..BossConfig::new(config.window_size, config.word_size)
};
let (boss_fitted, histograms) = Boss::fit_with_histograms(&boss_config, x, Some(y));
let all_words: Vec<Vec<String>> = histograms
.iter()
.map(|h| h.keys().cloned().collect())
.collect();
let vocabulary = build_vocabulary(&all_words);
let word_idx: HashMap<&str, usize> = vocabulary
.iter()
.enumerate()
.map(|(i, w)| (w.as_str(), i))
.collect();
let mut classes: Vec<String> = y.to_vec();
classes.sort();
classes.dedup();
let n_docs = histograms.len();
let n_features = vocabulary.len();
let mut doc_freq = vec![0usize; n_features];
for hist in &histograms {
for word in hist.keys() {
if let Some(&idx) = word_idx.get(word.as_str()) {
doc_freq[idx] += 1;
}
}
}
let idf: Vec<f64> = doc_freq
.iter()
.map(|&df| ((1.0 + n_docs as f64) / (1.0 + df as f64)).ln() + 1.0)
.collect();
let mut class_vectors = Vec::new();
for class in &classes {
let mut class_vec = vec![0.0; n_features];
for (i, label) in y.iter().enumerate() {
if label == class {
for (word, &count) in &histograms[i] {
if let Some(&idx) = word_idx.get(word.as_str()) {
let tf = match config.variant {
TfIdfVariant::SublinearTfIdf => {
if count > 0 {
1.0 + (count as f64).ln()
} else {
0.0
}
}
_ => count as f64,
};
class_vec[idx] += tf * idf[idx];
}
}
}
}
let norm: f64 = class_vec.iter().map(|v| v * v).sum::<f64>().sqrt();
if norm > 0.0 {
for v in &mut class_vec {
*v /= norm;
}
}
class_vectors.push((class.clone(), class_vec));
}
BossvsFitted {
boss_fitted,
vocabulary,
class_vectors,
}
}
pub fn predict(fitted: &BossvsFitted, x: &[Vec<f64>]) -> Vec<String> {
let histograms = Boss::transform(&fitted.boss_fitted, x);
let word_idx: HashMap<&str, usize> = fitted
.vocabulary
.iter()
.enumerate()
.map(|(i, w)| (w.as_str(), i))
.collect();
let n_features = fitted.vocabulary.len();
histograms
.iter()
.map(|hist| {
let mut sample_vec = vec![0.0; n_features];
for (word, &count) in hist {
if let Some(&idx) = word_idx.get(word.as_str()) {
sample_vec[idx] = count as f64;
}
}
fitted
.class_vectors
.iter()
.map(|(class, class_vec)| {
(class.as_str(), cosine_similarity(&sample_vec, class_vec))
})
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
.map(|(class, _)| class.to_string())
.unwrap()
})
.collect()
}
pub fn score(fitted: &BossvsFitted, x: &[Vec<f64>], y: &[String]) -> f64 {
let predictions = Self::predict(fitted, x);
let correct = predictions
.iter()
.zip(y.iter())
.filter(|(p, t)| p == t)
.count();
correct as f64 / y.len() as f64
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the owned class labels the classifier expects from string literals.
    fn labels(names: &[&str]) -> Vec<String> {
        names.iter().map(|&name| name.to_owned()).collect()
    }

    #[test]
    fn test_bossvs_basic() {
        // Two rising and two falling series of length 8.
        let series = vec![
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0],
            vec![7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0],
            vec![8.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0],
        ];
        let targets = labels(&["A", "A", "B", "B"]);
        let model = Bossvs::fit(&BossvsConfig::new(4), &series, &targets);
        assert_eq!(Bossvs::predict(&model, &series).len(), 4);
    }

    #[test]
    fn test_bossvs_score() {
        // Two rising and two falling series of length 6.
        let series = vec![
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 6.0],
            vec![5.0, 4.0, 3.0, 2.0, 1.0, 0.0],
            vec![6.0, 4.0, 3.0, 2.0, 1.0, 0.0],
        ];
        let targets = labels(&["A", "A", "B", "B"]);
        let model = Bossvs::fit(&BossvsConfig::new(3), &series, &targets);
        let accuracy = Bossvs::score(&model, &series, &targets);
        assert!((0.0..=1.0).contains(&accuracy));
    }
}