use crate::error::AprenderError;
use crate::index::hnsw::HNSWIndex;
use crate::primitives::Vector;
use crate::text::incremental_idf::IncrementalIDF;
use crate::text::tokenize::WhitespaceTokenizer;
use crate::text::Tokenizer;
use std::collections::HashMap;
/// Content-based recommender: each item's text is embedded as a TF-IDF
/// vector over an incrementally maintained vocabulary, and similar items
/// are found via approximate nearest-neighbour search.
#[derive(Debug)]
pub struct ContentRecommender {
    /// Approximate nearest-neighbour index over the item TF-IDF vectors.
    hnsw: HNSWIndex,
    /// Incrementally updated inverse-document-frequency statistics;
    /// its key set defines the (growing) vocabulary.
    idf: IncrementalIDF,
    /// item id -> raw content, retained so the whole index can be
    /// re-embedded when the vocabulary grows (see `rebuild_index`).
    item_content: HashMap<String, String>,
    /// Splits content on whitespace; tokens are lowercased downstream.
    tokenizer: WhitespaceTokenizer,
}
impl ContentRecommender {
    /// Creates an empty recommender.
    ///
    /// * `m` / `ef_construction` — HNSW graph construction parameters,
    ///   forwarded to [`HNSWIndex::new`].
    /// * `decay_factor` — exponential decay applied by [`IncrementalIDF`]
    ///   as document frequencies are updated.
    #[must_use]
    pub fn new(m: usize, ef_construction: usize, decay_factor: f64) -> Self {
        Self {
            hnsw: HNSWIndex::new(m, ef_construction, 0.0),
            idf: IncrementalIDF::new(decay_factor),
            item_content: HashMap::new(),
            tokenizer: WhitespaceTokenizer::new(),
        }
    }

    /// Adds (or replaces) an item and indexes its TF-IDF vector.
    ///
    /// If the item introduces new vocabulary terms, every stored vector
    /// changes dimensionality, so the whole index is rebuilt; otherwise only
    /// the new item's vector is inserted. IDF *values* of existing terms
    /// still drift between rebuilds — a deliberate accuracy-for-speed
    /// trade-off.
    pub fn add_item(&mut self, item_id: impl Into<String>, content: impl Into<String>) {
        let item_id = item_id.into();
        let content = content.into();
        let vocab_size_before = self.idf.len();
        // Tokenization failure is treated as empty content (best effort).
        let tokens: Vec<String> = self.tokenizer.tokenize(&content).unwrap_or_default();
        // Deduplicate case-insensitively: document frequency counts each
        // term at most once per document. Collect straight into the set —
        // no intermediate Vec needed.
        let unique_terms: std::collections::HashSet<String> =
            tokens.iter().map(|s| s.to_lowercase()).collect();
        let term_refs: Vec<&str> = unique_terms.iter().map(String::as_str).collect();
        self.idf.update(&term_refs);
        // NOTE(review): re-adding an existing `item_id` replaces the stored
        // content, but the old vector stays in the HNSW index until the next
        // rebuild — confirm HNSWIndex::add semantics for duplicate ids.
        self.item_content.insert(item_id.clone(), content);
        // The map was just inserted into, so it is necessarily non-empty;
        // only the vocabulary-growth check is needed here.
        if self.idf.len() > vocab_size_before {
            // Vocabulary grew: every existing vector has the wrong
            // dimensionality, so re-embed everything.
            self.rebuild_index();
        } else {
            let tfidf_vec = self.compute_tfidf(&tokens);
            self.hnsw.add(item_id, tfidf_vec);
        }
    }

    /// Returns up to `k` items most similar to `item_id`, as
    /// `(id, similarity)` pairs, where `similarity = 1.0 - distance`
    /// under the index's distance metric.
    ///
    /// # Errors
    ///
    /// Returns [`AprenderError::Other`] if `item_id` is unknown, and
    /// propagates tokenizer failures.
    pub fn recommend(&self, item_id: &str, k: usize) -> Result<Vec<(String, f64)>, AprenderError> {
        let content = self
            .item_content
            .get(item_id)
            .ok_or_else(|| AprenderError::Other(format!("Item not found: {item_id}")))?;
        let tokens = self.tokenizer.tokenize(content)?;
        let query_vec = self.compute_tfidf(&tokens);
        // Over-fetch by one so the query item itself can be filtered out
        // without shrinking the result set; `saturating_add` avoids a
        // debug-build overflow panic on the degenerate k == usize::MAX.
        let results = self.hnsw.search(&query_vec, k.saturating_add(1));
        let recommendations: Vec<(String, f64)> = results
            .into_iter()
            .filter(|(id, _)| id != item_id)
            .take(k)
            .map(|(id, dist)| (id, 1.0 - dist))
            .collect();
        Ok(recommendations)
    }

    /// Number of items currently stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.item_content.len()
    }

    /// `true` when no items have been added.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.item_content.is_empty()
    }

    /// Re-embeds every stored item into a fresh index. Called whenever the
    /// vocabulary grows, since vector dimensionality equals vocabulary size.
    fn rebuild_index(&mut self) {
        self.hnsw = HNSWIndex::new(self.hnsw.m(), self.hnsw.ef_construction(), 0.0);
        for (item_id, content) in &self.item_content {
            let tokens: Vec<String> = self.tokenizer.tokenize(content).unwrap_or_default();
            let tfidf_vec = self.compute_tfidf(&tokens);
            self.hnsw.add(item_id.clone(), tfidf_vec);
        }
    }

    /// Builds the TF-IDF vector for one document over the current
    /// vocabulary, with term frequencies normalized by the maximum TF.
    /// Dimensions follow the lexicographic order of vocabulary terms so
    /// that all vectors are mutually comparable.
    fn compute_tfidf(&self, tokens: &[String]) -> Vector<f64> {
        let mut tf: HashMap<String, f64> = HashMap::new();
        for token in tokens {
            let term = token.to_lowercase();
            *tf.entry(term).or_insert(0.0) += 1.0;
        }
        // Max-normalize so document length does not dominate the weights;
        // skip for empty documents (max_tf == 0.0).
        let max_tf = tf.values().copied().fold(0.0, f64::max);
        if max_tf > 0.0 {
            for value in tf.values_mut() {
                *value /= max_tf;
            }
        }
        // Lexicographic order fixes a deterministic dimension layout;
        // stability is irrelevant for unique keys, so sort_unstable.
        let mut vocab: Vec<String> = self.idf.terms().keys().cloned().collect();
        vocab.sort_unstable();
        let tfidf: Vec<f64> = vocab
            .iter()
            .map(|term| {
                let tf_val = tf.get(term).copied().unwrap_or(0.0);
                let idf_val = self.idf.idf(term);
                tf_val * idf_val
            })
            .collect();
        Vector::from_vec(tfidf)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a recommender with the parameters shared by every test.
    fn make_recommender() -> ContentRecommender {
        ContentRecommender::new(16, 200, 0.95)
    }

    #[test]
    fn test_empty_recommender() {
        let rec = make_recommender();
        assert_eq!(rec.len(), 0);
        assert!(rec.is_empty());
    }

    #[test]
    fn test_add_single_item() {
        let mut rec = make_recommender();
        rec.add_item("item1", "machine learning");
        assert!(!rec.is_empty());
        assert_eq!(rec.len(), 1);
    }

    #[test]
    fn test_recommend_similar_items() {
        let mut rec = make_recommender();
        for (id, text) in [
            ("ml_intro", "machine learning introduction"),
            ("dl_guide", "deep learning neural networks"),
            ("ml_practice", "machine learning applications"),
        ] {
            rec.add_item(id, text);
        }
        let similar = rec.recommend("ml_intro", 2).expect("should succeed");
        assert_eq!(similar.len(), 2);
        assert_eq!(similar[0].0, "ml_practice");
    }

    #[test]
    fn test_recommend_nonexistent_item() {
        let mut rec = make_recommender();
        rec.add_item("item1", "content");
        let result = rec.recommend("nonexistent", 1);
        assert!(matches!(result, Err(AprenderError::Other(_))));
    }

    #[test]
    fn test_similarity_scores() {
        let mut rec = make_recommender();
        rec.add_item("a", "machine learning");
        rec.add_item("b", "deep learning");
        rec.add_item("c", "data science");
        let similar = rec.recommend("a", 2).expect("should succeed");
        assert_eq!(similar.len(), 2);
        for (id, sim) in &similar {
            assert!(
                sim.is_finite(),
                "Similarity for {id} should be finite, got {sim}"
            );
        }
        let (ref top_id, top_sim) = similar[0];
        assert_eq!(top_id, "b");
        assert!(top_sim > 0.0, "Similarity should be positive: {}", top_sim);
    }

    #[test]
    fn test_empty_content() {
        let mut rec = make_recommender();
        rec.add_item("empty", "");
        rec.add_item("normal", "machine learning");
        assert!(rec.recommend("normal", 1).is_ok());
    }

    #[test]
    fn test_case_insensitive() {
        let mut rec = make_recommender();
        rec.add_item("a", "Machine Learning");
        rec.add_item("b", "machine learning");
        rec.add_item("c", "MACHINE LEARNING");
        let similar = rec.recommend("a", 2).expect("should succeed");
        assert_eq!(similar.len(), 2);
        for (_, sim) in similar {
            assert!(
                sim > 0.9,
                "Similar terms should have high similarity despite case"
            );
        }
    }
}