use std::collections::{HashMap, HashSet};
use crate::core::config::TfIdfVariant;
/// Collects every distinct token in `corpus` and returns them sorted
/// lexicographically, so each word has a stable column index.
pub fn build_vocabulary(corpus: &[Vec<String>]) -> Vec<String> {
    // Flatten all documents into one stream and dedupe via a set.
    let unique: HashSet<&String> = corpus.iter().flatten().collect();
    let mut words: Vec<String> = unique.into_iter().cloned().collect();
    // Unstable sort is fine: duplicate keys were removed above.
    words.sort_unstable();
    words
}
/// Returns, for each vocabulary index, the number of documents that
/// contain the corresponding word at least once (document frequency).
fn compute_doc_freq(
    corpus: &[Vec<String>],
    word_idx: &HashMap<&str, usize>,
    n_vocab: usize,
) -> Vec<usize> {
    let mut df = vec![0usize; n_vocab];
    for doc in corpus {
        // Dedupe within the document: a word counts once per document.
        let seen: HashSet<&str> = doc.iter().map(String::as_str).collect();
        seen.into_iter()
            .filter_map(|w| word_idx.get(w))
            .for_each(|&i| df[i] += 1);
    }
    df
}
/// Maps a raw term count to its term-frequency value: the count itself
/// for `Tf`/`TfIdf`, or `1 + ln(count)` (0 for absent terms) under the
/// sublinear scaling variant.
fn tf_value(count: usize, variant: TfIdfVariant) -> f64 {
    let raw = count as f64;
    match variant {
        TfIdfVariant::SublinearTfIdf => {
            // ln(0) is -inf, so absent terms are pinned to 0.
            if count == 0 { 0.0 } else { raw.ln() + 1.0 }
        }
        TfIdfVariant::Tf | TfIdfVariant::TfIdf => raw,
    }
}
/// Writes one document's (tf or tf-idf) weights into `row`, indexed by
/// `word_idx`. Idf weighting is skipped for the plain `Tf` variant.
fn fill_tfidf_row(
    doc: &[String],
    row: &mut [f64],
    word_idx: &HashMap<&str, usize>,
    idf: &[f64],
    variant: TfIdfVariant,
) {
    // Raw occurrence counts for this document.
    let mut counts: HashMap<&str, usize> = HashMap::new();
    doc.iter()
        .for_each(|w| *counts.entry(w.as_str()).or_default() += 1);

    let use_idf = matches!(variant, TfIdfVariant::TfIdf | TfIdfVariant::SublinearTfIdf);
    for (word, &count) in &counts {
        let Some(&idx) = word_idx.get(word) else { continue };
        let weight = tf_value(count, variant);
        row[idx] = if use_idf { weight * idf[idx] } else { weight };
    }
}
/// Scales `row` in place to unit Euclidean length; an all-zero row is
/// left untouched to avoid dividing by zero.
fn l2_normalize(row: &mut [f64]) {
    let sum_sq = row.iter().fold(0.0_f64, |acc, &v| acc + v * v);
    let norm = sum_sq.sqrt();
    if norm > 0.0 {
        row.iter_mut().for_each(|v| *v /= norm);
    }
}
/// Vectorizes `corpus` into a dense document-term matrix.
///
/// Returns the sorted vocabulary and one row per document. The idf uses
/// the smoothed form `ln((1 + N) / (1 + df)) + 1`; rows are
/// L2-normalized for the idf-weighted variants.
///
/// # Panics
/// Panics if `corpus` is empty.
pub fn tfidf_vectorize(
    corpus: &[Vec<String>],
    variant: TfIdfVariant,
) -> (Vec<String>, Vec<Vec<f64>>) {
    assert!(!corpus.is_empty(), "Corpus must not be empty");

    let vocab = build_vocabulary(corpus);
    let word_idx: HashMap<&str, usize> = vocab
        .iter()
        .enumerate()
        .map(|(i, w)| (w.as_str(), i))
        .collect();

    let n_docs = corpus.len() as f64;
    // Smoothed idf, computed once per vocabulary entry.
    let idf: Vec<f64> = compute_doc_freq(corpus, &word_idx, vocab.len())
        .into_iter()
        .map(|df| ((n_docs + 1.0) / (df as f64 + 1.0)).ln() + 1.0)
        .collect();

    let normalize = matches!(variant, TfIdfVariant::TfIdf | TfIdfVariant::SublinearTfIdf);
    let matrix: Vec<Vec<f64>> = corpus
        .iter()
        .map(|doc| {
            let mut row = vec![0.0; vocab.len()];
            fill_tfidf_row(doc, &mut row, &word_idx, &idf, variant);
            if normalize {
                l2_normalize(&mut row);
            }
            row
        })
        .collect();

    (vocab, matrix)
}
/// Cosine of the angle between `a` and `b`: their dot product divided
/// by the product of their Euclidean norms. Returns 0.0 when either
/// vector is all zeros.
///
/// # Panics
/// Panics if the vectors have different lengths.
pub fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "Vectors must have same length");
    // Accumulate dot product and both squared norms in a single pass.
    let mut dot = 0.0;
    let mut sq_a = 0.0;
    let mut sq_b = 0.0;
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let denom = sq_a.sqrt() * sq_b.sqrt();
    if denom == 0.0 { 0.0 } else { dot / denom }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn vocabulary_is_sorted_and_deduplicated() {
        let corpus = vec![
            vec!["a".into(), "b".into(), "a".into()],
            vec!["b".into(), "c".into()],
        ];
        assert_eq!(build_vocabulary(&corpus), vec!["a", "b", "c"]);
    }

    #[test]
    fn tfidf_output_shapes_match_corpus() {
        let corpus = vec![
            vec!["hello".into(), "world".into()],
            vec!["hello".into(), "rust".into()],
        ];
        let (vocab, matrix) = tfidf_vectorize(&corpus, TfIdfVariant::TfIdf);
        assert_eq!(vocab.len(), 3);
        assert_eq!(matrix.len(), 2);
        assert_eq!(matrix[0].len(), 3);
    }

    #[test]
    fn cosine_of_vector_with_itself_is_one() {
        let v = vec![1.0, 2.0, 3.0];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn cosine_of_orthogonal_vectors_is_zero() {
        let x = vec![1.0, 0.0];
        let y = vec![0.0, 1.0];
        assert!(cosine_similarity(&x, &y).abs() < 1e-10);
    }
}