openmemory 0.1.1

OpenMemory - Cognitive memory system for AI applications
Documentation
//! Keyword extraction and matching utilities
//!
//! This module provides token and n-gram keyword extraction, TF-IDF term
//! weighting, BM25 scoring, and various text matching algorithms.

use crate::utils::text::canonical_tokens_from_text;
use std::collections::{HashMap, HashSet};

/// Result of matching a single memory against a keyword query.
///
/// `PartialEq` is derived (but not `Eq`, because `score` is an `f64`) so that
/// results can be compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct KeywordMatch {
    /// Identifier of the matched memory
    pub id: String,
    /// Match score (the scorers in this module give better matches higher values)
    pub score: f64,
    /// Query terms that were found in the memory
    pub matched_terms: Vec<String>,
}

/// Extract keywords from text.
///
/// Tokenizes `text` with [`canonical_tokens_from_text`] and collects:
///
/// * every token at least `min_length` characters long,
/// * the character trigrams of each such token,
/// * word bigrams (`a_b`) and word trigrams (`a_b_c`) built from consecutive
///   tokens, kept when they are at least `min_length` characters long.
///
/// Lengths are measured in Unicode scalar values (`char`s), not bytes, so
/// multi-byte (non-ASCII) tokens are measured and sliced correctly.
///
/// # Arguments
/// * `text` - The text to extract keywords from
/// * `min_length` - Minimum keyword length in characters (default: 3)
pub fn extract_keywords(text: &str, min_length: Option<usize>) -> HashSet<String> {
    let min_len = min_length.unwrap_or(3);
    let tokens = canonical_tokens_from_text(text);
    let mut keywords = HashSet::new();

    // Whole tokens and their character trigrams.
    for token in &tokens {
        // Collect chars once; they serve both the length check and the trigram
        // windows. Using byte length here would let a short multi-byte token
        // (e.g. "日本": 6 bytes, 2 chars) reach trigram slicing and panic.
        let chars: Vec<char> = token.chars().collect();
        if chars.len() < min_len {
            continue;
        }
        keywords.insert(token.clone());
        // `windows(3)` yields nothing for tokens shorter than three chars.
        keywords.extend(chars.windows(3).map(|w| w.iter().collect::<String>()));
    }

    // Word bigrams.
    for pair in tokens.windows(2) {
        let bigram = format!("{}_{}", pair[0], pair[1]);
        if bigram.chars().count() >= min_len {
            keywords.insert(bigram);
        }
    }

    // Word trigrams, filtered by the same minimum length as bigrams.
    for triple in tokens.windows(3) {
        let trigram = format!("{}_{}_{}", triple[0], triple[1], triple[2]);
        if trigram.chars().count() >= min_len {
            keywords.insert(trigram);
        }
    }

    keywords
}

/// Compute a weighted keyword overlap score between a query and some content.
///
/// Each query keyword carries weight 2.0 when it is an n-gram (contains `'_'`)
/// and 1.0 otherwise. The score is the summed weight of the query keywords
/// present in `content_keywords`, divided by the total query weight, so it lies
/// in `[0.0, 1.0]`. An empty query yields `0.0`.
pub fn compute_keyword_overlap(
    query_keywords: &HashSet<String>,
    content_keywords: &HashSet<String>,
) -> f64 {
    let (hit_weight, all_weight) =
        query_keywords
            .iter()
            .fold((0.0, 0.0), |(hit, all), keyword| {
                let w = if keyword.contains('_') { 2.0 } else { 1.0 };
                let hit = if content_keywords.contains(keyword) {
                    hit + w
                } else {
                    hit
                };
                (hit, all + w)
            });

    if all_weight == 0.0 {
        return 0.0;
    }
    hit_weight / all_weight
}

/// Check whether the query phrase appears verbatim in the content.
///
/// The comparison is case-insensitive (both sides are lowercased) and ignores
/// leading and trailing whitespace in `query`. It is a plain substring test,
/// so `"cat"` also matches inside `"concatenate"`.
///
/// A query that is empty or whitespace-only never matches. An empty phrase is
/// trivially a substring of everything, so it would otherwise match every
/// piece of content. In [`keyword_filter_memories`] that would award the
/// exact-phrase bonus to every memory.
pub fn exact_phrase_match(query: &str, content: &str) -> bool {
    let query = query.to_lowercase();
    let query = query.trim();
    if query.is_empty() {
        return false;
    }
    content.to_lowercase().contains(query)
}

/// Compute a BM25-style relevance score of `content_terms` for `query_terms`.
///
/// BM25 is a ranking function used by search engines to rank matching documents.
/// No per-term document frequencies are available here, so the IDF factor is
/// approximated as `ln((corpus_size + 1) / (tf + 0.5))`, where `tf` is the
/// term's frequency within this content. Query terms absent from the content
/// contribute nothing. A repeated query term contributes once per occurrence.
/// The parameters are fixed at `k1 = 1.5` and `b = 0.75`.
///
/// # Arguments
/// * `query_terms` - Tokenized query terms
/// * `content_terms` - Tokenized content terms
/// * `corpus_size` - Estimated corpus size (default: 10000)
/// * `avg_doc_length` - Estimated average document length (default: 100)
pub fn compute_bm25_score(
    query_terms: &[String],
    content_terms: &[String],
    corpus_size: Option<usize>,
    avg_doc_length: Option<usize>,
) -> f64 {
    const K1: f64 = 1.5;
    const B: f64 = 0.75;

    let n_docs = corpus_size.unwrap_or(10000) as f64;
    let mean_len = avg_doc_length.unwrap_or(100) as f64;
    let this_len = content_terms.len() as f64;

    // How often each distinct term occurs in the content.
    let counts = content_terms
        .iter()
        .fold(HashMap::<&str, usize>::new(), |mut acc, term| {
            *acc.entry(term.as_str()).or_default() += 1;
            acc
        });

    // Length normalisation does not depend on the query term; compute it once.
    let length_norm = K1 * (1.0 - B + B * (this_len / mean_len));

    query_terms
        .iter()
        .filter_map(|q| counts.get(q.as_str()).copied())
        .fold(0.0, |acc, count| {
            let tf = count as f64;
            let idf = ((n_docs + 1.0) / (tf + 0.5)).ln();
            acc + idf * ((tf * (K1 + 1.0)) / (tf + length_norm))
        })
}

/// Score memories against a query by keyword relevance and keep the relevant ones.
///
/// Each memory's score is the sum of three parts:
///
/// * `1.0` if [`exact_phrase_match`] finds the query in the content,
/// * `0.8 ×` the [`compute_keyword_overlap`] of query and content keywords,
/// * `0.5 ×` the [`compute_bm25_score`] divided by 10 and capped at 1.0.
///
/// Returns a map from memory ID to score, containing only memories whose
/// score strictly exceeds the threshold. If an ID appears more than once,
/// the last occurrence wins.
///
/// # Arguments
/// * `query` - The search query
/// * `memories` - List of (id, content) pairs
/// * `threshold` - Score a memory must strictly exceed to be kept (default: 0.1)
/// * `min_keyword_length` - Minimum keyword length (default: 3)
pub fn keyword_filter_memories(
    query: &str,
    memories: &[(String, String)],
    threshold: Option<f64>,
    min_keyword_length: Option<usize>,
) -> HashMap<String, f64> {
    let min_score = threshold.unwrap_or(0.1);
    let keyword_min_len = min_keyword_length.unwrap_or(3);

    // Query-side features are shared by every memory; compute them once.
    let query_keywords = extract_keywords(query, Some(keyword_min_len));
    let query_terms = canonical_tokens_from_text(query);

    memories
        .iter()
        .filter_map(|(id, content)| {
            let phrase_bonus = if exact_phrase_match(query, content) { 1.0 } else { 0.0 };

            let content_keywords = extract_keywords(content, Some(keyword_min_len));
            let overlap = compute_keyword_overlap(&query_keywords, &content_keywords);

            let content_terms = canonical_tokens_from_text(content);
            let bm25 = compute_bm25_score(&query_terms, &content_terms, None, None);
            // BM25 is unbounded; scale it down and cap it before weighting.
            let bm25_component = (bm25 / 10.0).min(1.0);

            let total = phrase_bonus + overlap * 0.8 + bm25_component * 0.5;
            (total > min_score).then(|| (id.clone(), total))
        })
        .collect()
}

/// Calculate TF-IDF scores for the terms of a document.
///
/// The document is tokenized with [`canonical_tokens_from_text`]. For each
/// distinct term:
///
/// * TF is the term's count divided by the number of tokens in the document;
/// * IDF is `ln(total_documents / df) + 1`, where `df` is the term's entry in
///   `document_frequencies` (1 for terms not in the map).
///
/// Degenerate inputs are clamped rather than producing infinities:
///
/// * a `df` of 0 is treated as 1;
/// * a `total_documents` of 0 is treated as 1.
///
/// An empty document yields an empty map.
///
/// # Arguments
/// * `document` - The document text
/// * `document_frequencies` - Map of term to number of documents containing it
/// * `total_documents` - Total number of documents in corpus
pub fn compute_tfidf(
    document: &str,
    document_frequencies: &HashMap<String, usize>,
    total_documents: usize,
) -> HashMap<String, f64> {
    let tokens = canonical_tokens_from_text(document);
    let doc_length = tokens.len() as f64;
    // Keep `ln(total / df)` finite even when no corpus size is known.
    let total_docs = total_documents.max(1) as f64;

    // Term frequency in the document; tokens are moved in, not cloned.
    let mut counts: HashMap<String, usize> = HashMap::new();
    for token in tokens {
        *counts.entry(token).or_default() += 1;
    }

    counts
        .into_iter()
        .map(|(term, count)| {
            // Normalized TF (doc_length > 0 whenever there is a term to score).
            let tf = count as f64 / doc_length;
            // Unseen terms default to df = 1; an explicit 0 would divide by zero.
            let df = document_frequencies
                .get(&term)
                .copied()
                .unwrap_or(1)
                .max(1) as f64;
            let idf = (total_docs / df).ln() + 1.0;
            (term, tf * idf)
        })
        .collect()
}

/// Return the `k` highest-scoring entries of a TF-IDF map, best first.
///
/// Ties are broken by term in ascending lexicographic order, so the result is
/// deterministic regardless of `HashMap` iteration order. NaN scores sort after
/// every real score. If `k` exceeds the number of entries, all entries are
/// returned; `k == 0` returns an empty vector.
pub fn top_k_keywords(tfidf: &HashMap<String, f64>, k: usize) -> Vec<(String, f64)> {
    let mut ranked: Vec<(String, f64)> = tfidf
        .iter()
        .map(|(term, &score)| (term.clone(), score))
        .collect();
    // The comparator must be a total order. `partial_cmp(..).unwrap_or(Equal)`
    // is not transitive once NaN appears, which can scramble the result.
    // Since Rust 1.81, the standard sorts may also panic on such comparators.
    ranked.sort_unstable_by(|(term_a, score_a), (term_b, score_b)| {
        score_a
            .is_nan()
            .cmp(&score_b.is_nan())
            .then_with(|| score_b.total_cmp(score_a))
            .then_with(|| term_a.cmp(term_b))
    });
    ranked.truncate(k);
    ranked
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Convert string literals into owned `String`s.
    fn owned(words: &[&str]) -> Vec<String> {
        words.iter().map(|w| w.to_string()).collect()
    }

    #[test]
    fn test_extract_keywords() {
        let keywords = extract_keywords("hello world test", None);

        // Whole tokens, then word bigrams, then character trigrams.
        for expected in ["hello", "world", "test", "hello_world", "world_test", "hel", "ell"] {
            assert!(keywords.contains(expected), "missing keyword {:?}", expected);
        }
    }

    #[test]
    fn test_compute_keyword_overlap() {
        let query: HashSet<String> = owned(&["hello", "world"]).into_iter().collect();
        let content: HashSet<String> = owned(&["hello", "there"]).into_iter().collect();

        // Only "hello" is shared: 1 of 2 equally weighted query keywords.
        let overlap = compute_keyword_overlap(&query, &content);
        assert!((overlap - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_exact_phrase_match() {
        let content = "say hello world";
        assert!(exact_phrase_match("hello", content));
        // Matching ignores case.
        assert!(exact_phrase_match("HELLO", content));
        assert!(!exact_phrase_match("goodbye", content));
    }

    #[test]
    fn test_compute_bm25_score() {
        let query = owned(&["hello", "world"]);
        let content = owned(&["hello", "world", "test"]);

        assert!(compute_bm25_score(&query, &content, None, None) > 0.0);
    }

    #[test]
    fn test_bm25_no_matches() {
        let score = compute_bm25_score(&owned(&["foo"]), &owned(&["bar", "baz"]), None, None);
        assert!(score.abs() < 1e-6);
    }

    #[test]
    fn test_keyword_filter_memories() {
        let memories: Vec<(String, String)> = [
            ("1", "hello world programming"),
            ("2", "goodbye moon"),
            ("3", "hello there"),
        ]
        .iter()
        .map(|(id, text)| (id.to_string(), text.to_string()))
        .collect();

        let scores = keyword_filter_memories("hello", &memories, None, None);

        // Both memories mentioning "hello" clear the default threshold.
        assert!(scores.contains_key("1"));
        assert!(scores.contains_key("3"));
        // Memory "2" is deliberately not asserted on; character trigrams
        // might still give it a low score.
    }

    #[test]
    fn test_tfidf() {
        let doc_freq: HashMap<String, usize> = [("hello", 5), ("world", 2), ("rare", 1)]
            .iter()
            .map(|&(term, df)| (term.to_string(), df))
            .collect();

        let tfidf = compute_tfidf("hello world rare", &doc_freq, 100);

        // A lower document frequency means a higher IDF, so "rare" outranks "hello".
        assert!(tfidf["rare"] > tfidf["hello"]);
    }

    #[test]
    fn test_top_k_keywords() {
        let tfidf: HashMap<String, f64> = [("low", 0.1), ("medium", 0.5), ("high", 0.9)]
            .iter()
            .map(|&(term, score)| (term.to_string(), score))
            .collect();

        let top = top_k_keywords(&tfidf, 2);
        assert_eq!(top.len(), 2);

        let names: Vec<&str> = top.iter().map(|(term, _)| term.as_str()).collect();
        assert_eq!(names, ["high", "medium"]);
    }
}