use std::collections::{HashMap, HashSet};
/// Tracks inverse document frequency (IDF) over a stream of documents,
/// applying exponential decay so that recent documents carry more weight
/// than older ones.
#[derive(Debug, Clone)]
pub struct IncrementalIDF {
    // Exponentially decayed count of documents containing each term.
    doc_freq: HashMap<String, f64>,
    // Exponentially decayed total number of documents seen.
    total_docs: f64,
    // Multiplicative decay applied to all counts before each update; in (0, 1).
    decay_factor: f64,
}
impl IncrementalIDF {
    /// Creates an empty tracker with the given decay factor.
    ///
    /// # Panics
    ///
    /// Panics if `decay_factor` is not strictly between 0 and 1.
    #[must_use]
    pub fn new(decay_factor: f64) -> Self {
        // BUG FIX: `(0.0..1.0).contains()` is half-open and accepted 0.0,
        // contradicting the message's open interval (0, 1). A factor of
        // exactly 0.0 would erase all history on every update, so reject it.
        assert!(
            decay_factor > 0.0 && decay_factor < 1.0,
            "Decay factor must be in (0, 1)"
        );
        Self {
            doc_freq: HashMap::new(),
            total_docs: 0.0,
            decay_factor,
        }
    }
    /// Ingests one document, given as its list of terms.
    ///
    /// All existing counts are first decayed by `decay_factor`, then the new
    /// document is counted. Terms repeated within the same document count
    /// once, matching the standard definition of document frequency.
    pub fn update(&mut self, terms: &[&str]) {
        // Age every existing statistic before incorporating the new document.
        self.total_docs *= self.decay_factor;
        for freq in self.doc_freq.values_mut() {
            *freq *= self.decay_factor;
        }
        self.total_docs += 1.0;
        // BUG FIX: deduplicate so a term repeated inside one document is
        // counted once — otherwise doc_freq can exceed total_docs.
        let unique: HashSet<&str> = terms.iter().copied().collect();
        for term in unique {
            *self.doc_freq.entry(term.to_string()).or_insert(0.0) += 1.0;
        }
    }
    /// Returns the smoothed IDF of `term`:
    /// `ln((total_docs + 1) / (df + 1)) + 1`.
    ///
    /// The `+1` smoothing keeps the value finite and positive for unseen
    /// terms (df = 0) and for an empty tracker.
    #[must_use]
    pub fn idf(&self, term: &str) -> f64 {
        let df = self.doc_freq.get(term).copied().unwrap_or(0.0);
        ((self.total_docs + 1.0) / (df + 1.0)).ln() + 1.0
    }
    /// Returns a snapshot mapping every tracked term to its current IDF.
    #[must_use]
    pub fn terms(&self) -> HashMap<String, f64> {
        self.doc_freq
            .keys()
            .map(|term| (term.clone(), self.idf(term)))
            .collect()
    }
    /// Number of distinct terms currently tracked.
    #[must_use]
    pub fn len(&self) -> usize {
        self.doc_freq.len()
    }
    /// Returns `true` if no terms have been tracked yet.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.doc_freq.is_empty()
    }
    /// Decayed total document count (fractional once decay has applied).
    #[must_use]
    pub fn total_docs(&self) -> f64 {
        self.total_docs
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Approximate f64 equality helper shared by the tests below.
    fn close(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-10
    }

    #[test]
    fn test_empty_tracker() {
        let tracker = IncrementalIDF::new(0.95);
        assert!(tracker.is_empty());
        assert_eq!(tracker.len(), 0);
        assert!(close(tracker.total_docs(), 0.0));
    }

    #[test]
    fn test_single_document() {
        let mut tracker = IncrementalIDF::new(0.95);
        tracker.update(&["hello", "world"]);
        assert_eq!(tracker.len(), 2);
        assert!(close(tracker.total_docs(), 1.0));
        // Both terms appear in the single document, so their IDFs agree.
        assert!(close(tracker.idf("hello"), tracker.idf("world")));
    }

    #[test]
    fn test_multiple_documents() {
        let mut tracker = IncrementalIDF::new(0.95);
        for doc in [
            ["machine", "learning"],
            ["machine", "intelligence"],
            ["deep", "learning"],
        ] {
            tracker.update(&doc);
        }
        let machine_idf = tracker.idf("machine");
        let deep_idf = tracker.idf("deep");
        // The rarer term must score strictly higher than the common one.
        assert!(
            deep_idf > machine_idf,
            "deep_idf={deep_idf}, machine_idf={machine_idf}"
        );
    }

    #[test]
    fn test_exponential_decay() {
        let mut tracker = IncrementalIDF::new(0.5);
        tracker.update(&["old"]);
        (0..10).for_each(|_| tracker.update(&["new"]));
        let old_df = tracker.doc_freq.get("old").copied().unwrap_or(0.0);
        assert!(old_df < 0.01, "Old term should have decayed, df={old_df}");
    }

    #[test]
    fn test_unseen_term() {
        let mut tracker = IncrementalIDF::new(0.95);
        tracker.update(&["hello"]);
        // Smoothing keeps unseen terms at a positive IDF.
        assert!(tracker.idf("unseen") > 0.0);
    }

    #[test]
    fn test_terms_map() {
        let mut tracker = IncrementalIDF::new(0.95);
        tracker.update(&["alpha", "beta"]);
        let snapshot = tracker.terms();
        assert_eq!(snapshot.len(), 2);
        for key in ["alpha", "beta"] {
            assert!(snapshot.contains_key(key));
        }
    }

    #[test]
    #[should_panic(expected = "Decay factor must be in (0, 1)")]
    fn test_invalid_decay_factor() {
        let _ = IncrementalIDF::new(1.5);
    }

    #[test]
    fn test_idf_monotonicity() {
        let mut tracker = IncrementalIDF::new(0.95);
        for doc in [["common", "doc1"], ["common", "doc2"], ["common", "doc3"]] {
            tracker.update(&doc);
        }
        assert!(tracker.idf("doc1") > tracker.idf("common"));
    }
}