1use crate::agent::{text_utils, Agent, ResponseFormat, TrainableAgent, TrainingExample};
3use indexmap::IndexMap;
4use std::collections::HashSet;
5
6pub struct TfidfAgent {
8 docs: Vec<TrainingExample>,
10
11 term_df: IndexMap<String, f32>,
13
14 doc_term_freq: Vec<IndexMap<String, f32>>,
16
17 doc_count: f32,
19
20 bm25_k1: f32,
22
23 bm25_b: f32,
25}
26
27impl TfidfAgent {
28 pub fn new() -> Self {
30 Self {
31 docs: Vec::new(),
32 term_df: IndexMap::new(),
33 doc_term_freq: Vec::new(),
34 doc_count: 0.0,
35 bm25_k1: 1.2, bm25_b: 0.75, }
38 }
39
40 pub fn with_bm25_params(mut self, k1: f32, b: f32) -> Self {
42 self.bm25_k1 = k1;
43 self.bm25_b = b;
44 self
45 }
46
47 fn bm25_score(&self, query_terms: &[String], doc_idx: usize) -> f32 {
49 let avg_doc_len: f32 = self
51 .doc_term_freq
52 .iter()
53 .map(|doc| doc.values().sum::<f32>())
54 .sum::<f32>()
55 / self.doc_count;
56
57 let doc_len: f32 = self.doc_term_freq[doc_idx].values().sum();
59
60 query_terms
61 .iter()
62 .map(|term| {
63 if let Some(&df) = self.term_df.get(term) {
65 let idf = (self.doc_count - df + 0.5) / (df + 0.5);
67 let idf = (1.0 + idf).ln();
68
69 let tf = self.doc_term_freq[doc_idx]
71 .get(term)
72 .cloned()
73 .unwrap_or(0.0);
74
75 let numerator = tf * (self.bm25_k1 + 1.0);
77 let denominator = tf
78 + self.bm25_k1 * (1.0 - self.bm25_b + self.bm25_b * doc_len / avg_doc_len);
79
80 idf * numerator / denominator
81 } else {
82 0.0
83 }
84 })
85 .sum()
86 }
87}
88
89impl Agent for TfidfAgent {
90 fn predict(&self, input: &str) -> ResponseFormat {
92 if self.docs.is_empty() {
94 return ResponseFormat::Text("No training data available.".to_string());
95 }
96
97 let query_terms = text_utils::tokenize(input);
99
100 let mut scores: Vec<(usize, f32)> = self
102 .docs
103 .iter()
104 .enumerate()
105 .map(|(i, doc)| {
106 let score = self.bm25_score(&query_terms, i) * doc.weight;
108 (i, score)
109 })
110 .collect();
111
112 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
114
115 if let Some(&(best_idx, score)) = scores.first() {
117 if score > 0.0 {
118 return self.docs[best_idx].output.clone();
119 }
120 }
121
122 ResponseFormat::Text("No matching answer found.".to_string())
123 }
124}
125
126impl TrainableAgent for TfidfAgent {
127 fn train(&mut self, data: &[TrainingExample]) {
129 self.docs = data.to_vec();
131 self.doc_count = data.len() as f32;
132 self.term_df.clear();
133 self.doc_term_freq.clear();
134
135 for doc in &self.docs {
137 let mut doc_terms: IndexMap<String, f32> = IndexMap::new();
139 let terms = text_utils::tokenize(&doc.input);
140
141 for term in &terms {
143 *doc_terms.entry(term.clone()).or_insert(0.0) += 1.0;
144 }
145
146 let unique_terms: HashSet<String> = terms.into_iter().collect();
148 for term in unique_terms {
149 *self.term_df.entry(term).or_insert(0.0) += 1.0;
150 }
151
152 self.doc_term_freq.push(doc_terms);
153 }
154 }
155}
156
157impl Default for TfidfAgent {
159 fn default() -> Self {
160 Self::new()
161 }
162}