1use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::RwLock;
11
/// One chunk of an indexed document, kept so search hits can be mapped
/// back to their source file and original text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagChunk {
    /// Identifier of the file this chunk was extracted from.
    pub file_id: String,
    /// Human-readable file name (copied into `RagResult.file` on search hits).
    pub file_name: String,
    /// Zero-based position of this chunk within its source document.
    pub chunk_index: usize,
    /// Raw chunk text.
    pub text: String,
}
20
/// A single scored search hit returned by `RagIndex::search`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagResult {
    /// Name of the file the matching chunk came from.
    pub file: String,
    /// Zero-based chunk index within that file.
    pub chunk: usize,
    /// BM25 relevance score; results are ordered highest-first.
    pub score: f64,
    /// Text of the matching chunk.
    pub text: String,
}
29
/// Snapshot of the index state (document/chunk counts and backing engine).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagStatus {
    /// Number of distinct file ids that have been indexed.
    pub doc_count: usize,
    /// Total number of chunks across all indexed files.
    pub chunk_count: usize,
    /// True once at least one chunk has been indexed.
    pub indexed: bool,
    /// Backing engine name when the `rag` feature is enabled
    /// (`Some("trueno-rag")`); `None` for the built-in fallback, in which
    /// case the field is omitted from serialized output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub engine: Option<String>,
}
39
/// Retrieval index backed by the external `trueno_rag` BM25 engine
/// (compiled in only with the `rag` feature).
#[cfg(feature = "rag")]
pub struct RagIndex {
    /// Sparse BM25 index from `trueno_rag`.
    bm25: RwLock<trueno_rag::BM25Index>,
    /// All chunks in insertion order; `id_map` values index into this Vec.
    chunks: RwLock<Vec<RagChunk>>,
    /// Maps `trueno_rag` chunk-id strings to positions in `chunks`.
    id_map: RwLock<HashMap<String, usize>>,
    /// File ids already indexed (drives `doc_count` and `is_indexed`).
    indexed_files: RwLock<std::collections::HashSet<String>>,
}
55
/// An empty index; identical to calling [`RagIndex::new`].
#[cfg(feature = "rag")]
impl Default for RagIndex {
    fn default() -> Self {
        RagIndex::new()
    }
}
62
#[cfg(feature = "rag")]
impl RagIndex {
    /// Creates an empty index backed by a fresh `trueno_rag` BM25 index.
    #[must_use]
    pub fn new() -> Self {
        Self {
            bm25: RwLock::new(trueno_rag::BM25Index::new()),
            chunks: RwLock::new(Vec::new()),
            id_map: RwLock::new(HashMap::new()),
            indexed_files: RwLock::new(std::collections::HashSet::new()),
        }
    }

    /// Splits `text` into overlapping chunks (512-token budget, 64-token
    /// overlap), adds each chunk to the BM25 index, and records the
    /// bookkeeping needed to map search hits back to file/chunk metadata.
    ///
    /// NOTE(review): indexing the same `file_id` twice appends a second
    /// copy of its chunks — nothing is de-duplicated or replaced; callers
    /// appear expected to guard with `is_indexed` first.
    pub fn index_document(&self, file_id: &str, file_name: &str, text: &str) {
        let chunk_texts = chunk_text(text, 512, 64);
        // One document id shared by every chunk of this file.
        let doc_id = trueno_rag::DocumentId::new();

        // Recover from poisoned locks instead of panicking: a writer that
        // panicked mid-update leaves data we still prefer to keep serving.
        let mut bm25 = self.bm25.write().unwrap_or_else(|e| e.into_inner());
        let mut chunks = self.chunks.write().unwrap_or_else(|e| e.into_inner());
        let mut id_map = self.id_map.write().unwrap_or_else(|e| e.into_inner());

        let mut offset = 0;
        for (i, chunk_text) in chunk_texts.iter().enumerate() {
            // NOTE(review): these offsets are cumulative chunk byte lengths;
            // because consecutive chunks overlap, they do not correspond to
            // positions in the original `text`. Confirm what offsets
            // trueno_rag::Chunk::new actually expects.
            let end_offset = offset + chunk_text.len();
            let chunk = trueno_rag::Chunk::new(doc_id, chunk_text.clone(), offset, end_offset);

            // Remember which slot in `chunks` this trueno_rag chunk maps to.
            let chunk_id_str = chunk.id.0.to_string();
            let our_idx = chunks.len();

            use trueno_rag::SparseIndex;
            bm25.add(&chunk);

            chunks.push(RagChunk {
                file_id: file_id.to_string(),
                file_name: file_name.to_string(),
                chunk_index: i,
                text: chunk_text.clone(),
            });
            id_map.insert(chunk_id_str, our_idx);
            offset = end_offset;
        }

        // Mark the file as indexed (skipped silently if the lock is poisoned).
        if let Ok(mut files) = self.indexed_files.write() {
            files.insert(file_id.to_string());
        }
    }

    /// Runs a BM25 query and returns up to `top_k` hits whose score (cast
    /// to f64) is at least `min_score`, mapped back to chunk metadata.
    /// Hits whose id is no longer in `id_map` are silently dropped.
    pub fn search(&self, query: &str, top_k: usize, min_score: f64) -> Vec<RagResult> {
        let bm25 = self.bm25.read().unwrap_or_else(|e| e.into_inner());
        let chunks = self.chunks.read().unwrap_or_else(|e| e.into_inner());
        let id_map = self.id_map.read().unwrap_or_else(|e| e.into_inner());

        use trueno_rag::SparseIndex;
        let results = bm25.search(query, top_k);

        results
            .into_iter()
            .filter(|(_, score)| (*score as f64) >= min_score)
            .filter_map(|(chunk_id, score)| {
                let key = chunk_id.0.to_string();
                let idx = id_map.get(&key)?;
                let c = chunks.get(*idx)?;
                Some(RagResult {
                    file: c.file_name.clone(),
                    chunk: c.chunk_index,
                    score: score as f64,
                    text: c.text.clone(),
                })
            })
            .collect()
    }

    /// Reports current counts; a poisoned lock reads as zero here rather
    /// than recovering via `into_inner`, so counts are best-effort.
    #[must_use]
    pub fn status(&self) -> RagStatus {
        let chunk_count = self.chunks.read().map(|c| c.len()).unwrap_or(0);
        let doc_count = self.indexed_files.read().map(|f| f.len()).unwrap_or(0);
        RagStatus {
            doc_count,
            chunk_count,
            indexed: chunk_count > 0,
            engine: Some("trueno-rag".to_string()),
        }
    }

    /// Drops all indexed data, replacing the BM25 index with a fresh one.
    pub fn clear(&self) {
        *self.bm25.write().unwrap_or_else(|e| e.into_inner()) = trueno_rag::BM25Index::new();
        if let Ok(mut c) = self.chunks.write() {
            c.clear();
        }
        if let Ok(mut m) = self.id_map.write() {
            m.clear();
        }
        if let Ok(mut f) = self.indexed_files.write() {
            f.clear();
        }
    }

    /// Returns true if `file_id` was previously passed to `index_document`
    /// (false if the lock is poisoned).
    #[must_use]
    pub fn is_indexed(&self, file_id: &str) -> bool {
        self.indexed_files.read().map(|f| f.contains(file_id)).unwrap_or(false)
    }
}
171
/// In-process BM25 fallback index used when the `rag` feature is disabled.
#[cfg(not(feature = "rag"))]
pub struct RagIndex {
    /// All chunks in insertion order; postings reference positions here.
    chunks: RwLock<Vec<RagChunk>>,
    /// Inverted index: term -> list of (chunk index, term frequency).
    postings: RwLock<HashMap<String, Vec<(usize, u32)>>>,
    /// Token count per chunk, parallel to `chunks` (BM25 length norm).
    doc_lengths: RwLock<Vec<usize>>,
    /// File ids already indexed (drives `doc_count` and `is_indexed`).
    indexed_files: RwLock<std::collections::HashSet<String>>,
}
183
184#[cfg(not(feature = "rag"))]
185impl Default for RagIndex {
186 fn default() -> Self {
187 Self::new()
188 }
189}
190
191#[cfg(not(feature = "rag"))]
192impl RagIndex {
193 #[must_use]
194 pub fn new() -> Self {
195 Self {
196 chunks: RwLock::new(Vec::new()),
197 postings: RwLock::new(HashMap::new()),
198 doc_lengths: RwLock::new(Vec::new()),
199 indexed_files: RwLock::new(std::collections::HashSet::new()),
200 }
201 }
202
203 pub fn index_document(&self, file_id: &str, file_name: &str, text: &str) {
205 let chunk_texts = chunk_text(text, 512, 64);
206
207 let mut chunks = self.chunks.write().unwrap_or_else(|e| e.into_inner());
208 let mut postings = self.postings.write().unwrap_or_else(|e| e.into_inner());
209 let mut doc_lens = self.doc_lengths.write().unwrap_or_else(|e| e.into_inner());
210
211 for (i, ct) in chunk_texts.iter().enumerate() {
212 let chunk_idx = chunks.len();
213 chunks.push(RagChunk {
214 file_id: file_id.to_string(),
215 file_name: file_name.to_string(),
216 chunk_index: i,
217 text: ct.clone(),
218 });
219
220 let terms = tokenize(ct);
221 let mut term_freqs: HashMap<&str, u32> = HashMap::new();
222 for term in &terms {
223 *term_freqs.entry(term.as_str()).or_insert(0) += 1;
224 }
225
226 for (term, freq) in term_freqs {
227 postings.entry(term.to_string()).or_default().push((chunk_idx, freq));
228 }
229 doc_lens.push(terms.len());
230 }
231
232 if let Ok(mut files) = self.indexed_files.write() {
233 files.insert(file_id.to_string());
234 }
235 }
236
237 pub fn search(&self, query: &str, top_k: usize, min_score: f64) -> Vec<RagResult> {
239 let chunks = self.chunks.read().unwrap_or_else(|e| e.into_inner());
240 let postings = self.postings.read().unwrap_or_else(|e| e.into_inner());
241 let doc_lens = self.doc_lengths.read().unwrap_or_else(|e| e.into_inner());
242
243 if chunks.is_empty() {
244 return Vec::new();
245 }
246
247 let n = chunks.len() as f64;
248 let avg_dl: f64 = if doc_lens.is_empty() {
249 1.0
250 } else {
251 doc_lens.iter().sum::<usize>() as f64 / doc_lens.len() as f64
252 };
253
254 let query_terms = tokenize(query);
255 let mut scores: HashMap<usize, f64> = HashMap::new();
256 let (k1, b) = (1.2, 0.75);
257
258 for term in &query_terms {
259 if let Some(posting_list) = postings.get(term.as_str()) {
260 let df = posting_list.len() as f64;
261 let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
262 for &(chunk_idx, tf) in posting_list {
263 let dl = doc_lens.get(chunk_idx).copied().unwrap_or(1) as f64;
264 let tf_norm =
265 (tf as f64 * (k1 + 1.0)) / (tf as f64 + k1 * (1.0 - b + b * dl / avg_dl));
266 *scores.entry(chunk_idx).or_insert(0.0) += idf * tf_norm;
267 }
268 }
269 }
270
271 let mut results: Vec<(usize, f64)> =
272 scores.into_iter().filter(|&(_, s)| s >= min_score).collect();
273 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
274 results.truncate(top_k);
275
276 results
277 .into_iter()
278 .filter_map(|(idx, score)| {
279 chunks.get(idx).map(|c| RagResult {
280 file: c.file_name.clone(),
281 chunk: c.chunk_index,
282 score,
283 text: c.text.clone(),
284 })
285 })
286 .collect()
287 }
288
289 #[must_use]
291 pub fn status(&self) -> RagStatus {
292 let chunk_count = self.chunks.read().map(|c| c.len()).unwrap_or(0);
293 let doc_count = self.indexed_files.read().map(|f| f.len()).unwrap_or(0);
294 RagStatus { doc_count, chunk_count, indexed: chunk_count > 0, engine: None }
295 }
296
297 pub fn clear(&self) {
299 if let Ok(mut c) = self.chunks.write() {
300 c.clear();
301 }
302 if let Ok(mut p) = self.postings.write() {
303 p.clear();
304 }
305 if let Ok(mut d) = self.doc_lengths.write() {
306 d.clear();
307 }
308 if let Ok(mut f) = self.indexed_files.write() {
309 f.clear();
310 }
311 }
312
313 #[must_use]
315 pub fn is_indexed(&self, file_id: &str) -> bool {
316 self.indexed_files.read().map(|f| f.contains(file_id)).unwrap_or(false)
317 }
318}
319
/// Splits `text` into chunks of roughly `max_tokens` tokens (approximated
/// as 4 bytes per token) with `overlap_tokens` of overlap between
/// consecutive chunks (capped at half the chunk size).
///
/// Slicing only happens on UTF-8 character boundaries, so multi-byte text
/// cannot panic: a chunk's end is nudged backward (and the next start
/// forward) by up to 3 bytes to land on a boundary. For pure-ASCII input
/// the output is identical to plain byte slicing.
///
/// Returns the whole text as one chunk when it already fits, or when
/// `max_tokens` is 0 (which would otherwise loop forever making no progress).
fn chunk_text(text: &str, max_tokens: usize, overlap_tokens: usize) -> Vec<String> {
    // Heuristic: ~4 bytes per token.
    let max_chars = max_tokens * 4;
    let overlap_chars = overlap_tokens.min(max_tokens / 2) * 4;

    if max_chars == 0 || text.len() <= max_chars {
        return vec![text.to_string()];
    }

    let mut chunks = Vec::new();
    let mut start = 0;
    while start < text.len() {
        // Tentative end, pulled back to the nearest char boundary.
        // max_chars >= 4 and a UTF-8 char is at most 4 bytes, so end > start.
        let mut end = (start + max_chars).min(text.len());
        while !text.is_char_boundary(end) {
            end -= 1;
        }
        chunks.push(text[start..end].to_string());
        if end == text.len() {
            break;
        }
        // Step back by the overlap, then forward to a char boundary.
        // overlap_chars <= max_chars / 2 guarantees `start` always advances.
        let mut next = end.saturating_sub(overlap_chars);
        while !text.is_char_boundary(next) {
            next += 1;
        }
        start = next;
    }
    chunks
}
345
/// Lowercases whitespace-separated words, strips leading and trailing
/// non-alphanumeric characters, and drops tokens shorter than two bytes.
#[cfg(not(feature = "rag"))]
fn tokenize(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for raw in text.split_whitespace() {
        let lowered = raw.to_lowercase();
        let stripped = lowered.trim_matches(|c: char| !c.is_alphanumeric());
        if stripped.len() > 1 {
            tokens.push(stripped.to_string());
        }
    }
    tokens
}