use crate::utils::text::canonical_tokens_from_text;
use std::collections::{HashMap, HashSet};
/// A scored keyword match for a single memory/document.
///
/// `PartialEq` is derived so matches can be compared in tests and
/// deduplicated by callers (`Eq` is impossible because of the `f64` score).
#[derive(Debug, Clone, PartialEq)]
pub struct KeywordMatch {
    /// Identifier of the matched item.
    pub id: String,
    /// Combined relevance score; higher means more relevant.
    pub score: f64,
    /// Terms that matched — presumably drawn from the query; confirm at use sites.
    pub matched_terms: Vec<String>,
}
/// Extract lexical keys from `text`: whole canonical tokens of at least
/// `min_length` bytes (default 3), character trigrams of those tokens, and
/// `_`-joined token bigrams/trigrams.
///
/// Token bigrams are kept only when the joined string reaches `min_length`;
/// token trigrams are always kept.
pub fn extract_keywords(text: &str, min_length: Option<usize>) -> HashSet<String> {
    let min_len = min_length.unwrap_or(3);
    let tokens = canonical_tokens_from_text(text);
    let mut keywords = HashSet::new();

    for token in &tokens {
        // `len()` here is a byte length, matching the original filter semantics.
        if token.len() >= min_len {
            keywords.insert(token.clone());
            // Character trigrams. `windows(3)` yields nothing for tokens
            // shorter than 3 CHARS; the previous code gated on BYTE length
            // and then sliced `chars[0..3]`, which panicked on multi-byte
            // tokens such as "éé" (4 bytes, 2 chars).
            let chars: Vec<char> = token.chars().collect();
            for window in chars.windows(3) {
                keywords.insert(window.iter().collect());
            }
        }
    }

    // Adjacent-token bigrams, length-filtered like single tokens.
    for pair in tokens.windows(2) {
        let bigram = format!("{}_{}", pair[0], pair[1]);
        if bigram.len() >= min_len {
            keywords.insert(bigram);
        }
    }

    // Adjacent-token trigrams are always kept regardless of length.
    for triple in tokens.windows(3) {
        keywords.insert(format!("{}_{}_{}", triple[0], triple[1], triple[2]));
    }

    keywords
}
/// Weighted fraction of `query_keywords` that also appear in
/// `content_keywords`.
///
/// Multi-token phrase keywords (those containing `'_'`) carry double weight.
/// Returns 0.0 for an empty query set.
pub fn compute_keyword_overlap(
    query_keywords: &HashSet<String>,
    content_keywords: &HashSet<String>,
) -> f64 {
    let (matched, total) = query_keywords.iter().fold((0.0, 0.0), |(m, t), kw| {
        let weight = if kw.contains('_') { 2.0 } else { 1.0 };
        let gained = if content_keywords.contains(kw) { weight } else { 0.0 };
        (m + gained, t + weight)
    });
    if total == 0.0 {
        0.0
    } else {
        matched / total
    }
}
/// Case-insensitive substring check: does `content` contain `query`?
///
/// Leading/trailing whitespace around the query is ignored; an empty (or
/// all-whitespace) query matches everything.
pub fn exact_phrase_match(query: &str, content: &str) -> bool {
    let needle = query.to_lowercase();
    content.to_lowercase().contains(needle.trim())
}
/// Score `content_terms` against `query_terms` with a BM25-style formula.
///
/// `corpus_size` defaults to 10_000 and `avg_doc_length` to 100 when `None`.
/// Repeated query terms contribute once per occurrence in the query.
///
/// NOTE(review): no per-term corpus document frequencies are available in
/// this signature, so the idf component substitutes the in-document term
/// frequency — an approximation, not textbook BM25.
pub fn compute_bm25_score(
    query_terms: &[String],
    content_terms: &[String],
    corpus_size: Option<usize>,
    avg_doc_length: Option<usize>,
) -> f64 {
    const K1: f64 = 1.5;
    const B: f64 = 0.75;
    let corpus = corpus_size.unwrap_or(10000) as f64;
    let avg_len = avg_doc_length.unwrap_or(100) as f64;

    // Term -> occurrence count within the content.
    let mut counts: HashMap<&str, usize> = HashMap::new();
    for term in content_terms {
        *counts.entry(term.as_str()).or_default() += 1;
    }
    let doc_length = content_terms.len() as f64;

    query_terms
        .iter()
        // Absent terms (tf == 0) contribute nothing and are skipped here.
        .filter_map(|q| counts.get(q.as_str()).copied())
        .map(|count| {
            let tf = count as f64;
            let idf = ((corpus + 1.0) / (tf + 0.5)).ln();
            let numerator = tf * (K1 + 1.0);
            let denominator = tf + K1 * (1.0 - B + B * (doc_length / avg_len));
            idf * (numerator / denominator)
        })
        .sum()
}
/// Score each `(id, content)` memory against `query` and keep those above
/// `threshold` (default 0.1).
///
/// The score blends three lexical signals:
/// - exact phrase match: +1.0
/// - keyword-set overlap: weight 0.8
/// - BM25 score, divided by 10 and capped at 1.0: weight 0.5
pub fn keyword_filter_memories(
    query: &str,
    memories: &[(String, String)],
    threshold: Option<f64>,
    min_keyword_length: Option<usize>,
) -> HashMap<String, f64> {
    let thresh = threshold.unwrap_or(0.1);
    let min_len = min_keyword_length.unwrap_or(3);
    // Query-side features are computed once, outside the per-memory loop.
    let query_keywords = extract_keywords(query, Some(min_len));
    let query_terms = canonical_tokens_from_text(query);

    memories
        .iter()
        .filter_map(|(id, content)| {
            let phrase = if exact_phrase_match(query, content) { 1.0 } else { 0.0 };
            let overlap =
                compute_keyword_overlap(&query_keywords, &extract_keywords(content, Some(min_len)));
            let bm25 =
                compute_bm25_score(&query_terms, &canonical_tokens_from_text(content), None, None);
            let score = phrase + overlap * 0.8 + (bm25 / 10.0).min(1.0) * 0.5;
            (score > thresh).then(|| (id.clone(), score))
        })
        .collect()
}
/// Compute a TF-IDF weight for every distinct token in `document`.
///
/// `document_frequencies` maps term -> number of corpus documents containing
/// it; unseen terms default to a df of 1 (i.e. maximum idf). Returns an empty
/// map for an empty document.
pub fn compute_tfidf(
    document: &str,
    document_frequencies: &HashMap<String, usize>,
    total_documents: usize,
) -> HashMap<String, f64> {
    let tokens = canonical_tokens_from_text(document);
    let doc_length = tokens.len() as f64;
    let n_docs = total_documents as f64;

    // Raw term counts within this document.
    let mut counts: HashMap<String, usize> = HashMap::new();
    for token in &tokens {
        *counts.entry(token.clone()).or_default() += 1;
    }

    counts
        .into_iter()
        .map(|(term, count)| {
            let tf = count as f64 / doc_length;
            let df = document_frequencies.get(&term).copied().unwrap_or(1) as f64;
            // Smoothed idf: +1.0 keeps terms present in every document non-zero.
            let idf = (n_docs / df).ln() + 1.0;
            (term, tf * idf)
        })
        .collect()
}
/// Return the `k` highest-scoring `(term, score)` pairs, best first.
///
/// Returns fewer than `k` entries when the map is smaller. Ties are returned
/// in arbitrary order (the map's iteration order is unspecified anyway).
pub fn top_k_keywords(tfidf: &HashMap<String, f64>, k: usize) -> Vec<(String, f64)> {
    let mut ranked: Vec<(String, f64)> = tfidf.iter().map(|(t, s)| (t.clone(), *s)).collect();
    // `total_cmp` is a true total order over f64 (NaN sorts deterministically),
    // unlike the old `partial_cmp(..).unwrap_or(Equal)`, which can silently
    // corrupt the ordering when a NaN score slips in.
    ranked.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
    ranked.truncate(k);
    ranked
}
#[cfg(test)]
mod tests {
use super::*;
// NOTE(review): several tests below implicitly exercise
// `canonical_tokens_from_text`; they assume it lowercases and splits on
// whitespace — confirm against utils::text.
#[test]
fn test_extract_keywords() {
// Whole tokens, token bigrams, and character trigrams should all be emitted.
let keywords = extract_keywords("hello world test", None);
assert!(keywords.contains("hello"));
assert!(keywords.contains("world"));
assert!(keywords.contains("test"));
assert!(keywords.contains("hello_world"));
assert!(keywords.contains("world_test"));
assert!(keywords.contains("hel"));
assert!(keywords.contains("ell"));
}
#[test]
fn test_compute_keyword_overlap() {
// One of two query keywords appears in content -> overlap 0.5.
let query: HashSet<_> = vec!["hello".to_string(), "world".to_string()]
.into_iter()
.collect();
let content: HashSet<_> = vec!["hello".to_string(), "there".to_string()]
.into_iter()
.collect();
let overlap = compute_keyword_overlap(&query, &content);
assert!((overlap - 0.5).abs() < 0.01); }
#[test]
fn test_exact_phrase_match() {
// Matching is case-insensitive substring containment.
assert!(exact_phrase_match("hello", "say hello world"));
assert!(exact_phrase_match("HELLO", "say hello world"));
assert!(!exact_phrase_match("goodbye", "say hello world"));
}
#[test]
fn test_compute_bm25_score() {
// Any query term present in the content should yield a positive score.
let query = vec!["hello".to_string(), "world".to_string()];
let content = vec![
"hello".to_string(),
"world".to_string(),
"test".to_string(),
];
let score = compute_bm25_score(&query, &content, None, None);
assert!(score > 0.0);
}
#[test]
fn test_bm25_no_matches() {
// No shared terms -> score is exactly zero (checked with a float tolerance).
let query = vec!["foo".to_string()];
let content = vec!["bar".to_string(), "baz".to_string()];
let score = compute_bm25_score(&query, &content, None, None);
assert!(score.abs() < 1e-6);
}
#[test]
fn test_keyword_filter_memories() {
// Memories containing "hello" pass the default threshold; "goodbye moon"
// is expected to be filtered out (not asserted explicitly).
let memories = vec![
("1".to_string(), "hello world programming".to_string()),
("2".to_string(), "goodbye moon".to_string()),
("3".to_string(), "hello there".to_string()),
];
let scores = keyword_filter_memories("hello", &memories, None, None);
assert!(scores.contains_key("1"));
assert!(scores.contains_key("3"));
}
#[test]
fn test_tfidf() {
// Rarer terms (lower document frequency) must score higher than common ones.
let mut doc_freq = HashMap::new();
doc_freq.insert("hello".to_string(), 5);
doc_freq.insert("world".to_string(), 2);
doc_freq.insert("rare".to_string(), 1);
let tfidf = compute_tfidf("hello world rare", &doc_freq, 100);
assert!(tfidf.get("rare").unwrap() > tfidf.get("hello").unwrap());
}
#[test]
fn test_top_k_keywords() {
// Results come back best-first and truncated to k.
let mut tfidf = HashMap::new();
tfidf.insert("low".to_string(), 0.1);
tfidf.insert("medium".to_string(), 0.5);
tfidf.insert("high".to_string(), 0.9);
let top = top_k_keywords(&tfidf, 2);
assert_eq!(top.len(), 2);
assert_eq!(top[0].0, "high");
assert_eq!(top[1].0, "medium");
}
}