airust/
tfidf_agent.rs

1// src/tfidf_agent.rs - Optimized TF-IDF/BM25 Agent
2use crate::agent::{text_utils, Agent, ResponseFormat, TrainableAgent, TrainingExample};
3use indexmap::IndexMap;
4use std::collections::HashSet;
5
6/// TF-IDF Agent using BM25 scoring for intelligent text matching
7pub struct TfidfAgent {
8    /// Stored training documents
9    docs: Vec<TrainingExample>,
10
11    /// Document frequency for each term (in how many documents a term appears)
12    term_df: IndexMap<String, f32>,
13
14    /// Term frequencies for each document
15    doc_term_freq: Vec<IndexMap<String, f32>>,
16
17    /// Total number of documents
18    doc_count: f32,
19
20    /// BM25 parameter k1 (controls term frequency scaling)
21    bm25_k1: f32,
22
23    /// BM25 parameter b (controls document length normalization)
24    bm25_b: f32,
25}
26
27impl TfidfAgent {
28    /// Creates a new TF-IDF agent with default BM25 parameters
29    pub fn new() -> Self {
30        Self {
31            docs: Vec::new(),
32            term_df: IndexMap::new(),
33            doc_term_freq: Vec::new(),
34            doc_count: 0.0,
35            bm25_k1: 1.2, // Default term frequency scaling
36            bm25_b: 0.75, // Default length normalization
37        }
38    }
39
40    /// Configures custom BM25 parameters for fine-tuned matching
41    pub fn with_bm25_params(mut self, k1: f32, b: f32) -> Self {
42        self.bm25_k1 = k1;
43        self.bm25_b = b;
44        self
45    }
46
47    /// Calculates BM25 score between query terms and a specific document
48    fn bm25_score(&self, query_terms: &[String], doc_idx: usize) -> f32 {
49        // Calculate average document length
50        let avg_doc_len: f32 = self
51            .doc_term_freq
52            .iter()
53            .map(|doc| doc.values().sum::<f32>())
54            .sum::<f32>()
55            / self.doc_count;
56
57        // Length of the current document
58        let doc_len: f32 = self.doc_term_freq[doc_idx].values().sum();
59
60        query_terms
61            .iter()
62            .map(|term| {
63                // Check if term exists in the document frequency index
64                if let Some(&df) = self.term_df.get(term) {
65                    // Inverse Document Frequency (IDF) component
66                    let idf = (self.doc_count - df + 0.5) / (df + 0.5);
67                    let idf = (1.0 + idf).ln();
68
69                    // Term Frequency (TF) with BM25 normalization
70                    let tf = self.doc_term_freq[doc_idx]
71                        .get(term)
72                        .cloned()
73                        .unwrap_or(0.0);
74
75                    // BM25 scoring formula
76                    let numerator = tf * (self.bm25_k1 + 1.0);
77                    let denominator = tf
78                        + self.bm25_k1 * (1.0 - self.bm25_b + self.bm25_b * doc_len / avg_doc_len);
79
80                    idf * numerator / denominator
81                } else {
82                    0.0
83                }
84            })
85            .sum()
86    }
87}
88
89impl Agent for TfidfAgent {
90    /// Predicts the most relevant response using BM25 scoring
91    fn predict(&self, input: &str) -> ResponseFormat {
92        // Handle empty training data
93        if self.docs.is_empty() {
94            return ResponseFormat::Text("No training data available.".to_string());
95        }
96
97        // Tokenize input into terms
98        let query_terms = text_utils::tokenize(input);
99
100        // Calculate BM25 scores for each document
101        let mut scores: Vec<(usize, f32)> = self
102            .docs
103            .iter()
104            .enumerate()
105            .map(|(i, doc)| {
106                // Calculate score with document weight
107                let score = self.bm25_score(&query_terms, i) * doc.weight;
108                (i, score)
109            })
110            .collect();
111
112        // Sort scores in descending order
113        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
114
115        // Select best matching document
116        if let Some(&(best_idx, score)) = scores.first() {
117            if score > 0.0 {
118                return self.docs[best_idx].output.clone();
119            }
120        }
121
122        ResponseFormat::Text("No matching answer found.".to_string())
123    }
124}
125
126impl TrainableAgent for TfidfAgent {
127    /// Trains the agent by processing training documents
128    fn train(&mut self, data: &[TrainingExample]) {
129        // Reset existing data
130        self.docs = data.to_vec();
131        self.doc_count = data.len() as f32;
132        self.term_df.clear();
133        self.doc_term_freq.clear();
134
135        // Process each document
136        for doc in &self.docs {
137            // Tokenize document input
138            let mut doc_terms: IndexMap<String, f32> = IndexMap::new();
139            let terms = text_utils::tokenize(&doc.input);
140
141            // Count term frequencies
142            for term in &terms {
143                *doc_terms.entry(term.clone()).or_insert(0.0) += 1.0;
144            }
145
146            // Track unique terms for document frequency
147            let unique_terms: HashSet<String> = terms.into_iter().collect();
148            for term in unique_terms {
149                *self.term_df.entry(term).or_insert(0.0) += 1.0;
150            }
151
152            self.doc_term_freq.push(doc_terms);
153        }
154    }
155}
156
157// Default implementation for creating a new TF-IDF agent
158impl Default for TfidfAgent {
159    fn default() -> Self {
160        Self::new()
161    }
162}