Skip to main content

batuta/serve/banco/
rag.rs

1//! Banco RAG (Retrieval-Augmented Generation) pipeline.
2//!
3//! With `rag` feature: uses trueno-rag's BM25Index for production-grade retrieval.
4//! Without `rag` feature: uses built-in BM25 for zero-dependency operation.
5//!
6//! Chat requests with `rag: true` retrieve relevant chunks before generation.
7
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::RwLock;
11
12/// A chunk in the RAG index.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct RagChunk {
15    pub file_id: String,
16    pub file_name: String,
17    pub chunk_index: usize,
18    pub text: String,
19}
20
21/// A search result with relevance score.
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct RagResult {
24    pub file: String,
25    pub chunk: usize,
26    pub score: f64,
27    pub text: String,
28}
29
30/// RAG index status.
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct RagStatus {
33    pub doc_count: usize,
34    pub chunk_count: usize,
35    pub indexed: bool,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub engine: Option<String>,
38}
39
40// ============================================================================
41// trueno-rag powered RAG index (rag feature)
42// ============================================================================
43
/// RAG index backed by trueno-rag's BM25 implementation (`rag` feature).
///
/// Every field sits behind its own `RwLock`, so the index is used through `&self`.
#[cfg(feature = "rag")]
pub struct RagIndex {
    /// The trueno-rag BM25 index that performs the actual scoring.
    bm25: RwLock<trueno_rag::BM25Index>,
    /// Our own per-chunk metadata, in insertion order.
    chunks: RwLock<Vec<RagChunk>>,
    /// Stringified trueno-rag chunk id → position in `chunks`.
    id_map: RwLock<HashMap<String, usize>>,
    /// IDs of every file that has been indexed so far.
    indexed_files: RwLock<std::collections::HashSet<String>>,
}
55
#[cfg(feature = "rag")]
impl Default for RagIndex {
    /// Builds an empty index; identical to calling [`RagIndex::new`].
    fn default() -> Self {
        RagIndex::new()
    }
}
62
#[cfg(feature = "rag")]
impl RagIndex {
    /// Creates an empty index backed by a fresh trueno-rag BM25 index.
    #[must_use]
    pub fn new() -> Self {
        Self {
            bm25: RwLock::new(trueno_rag::BM25Index::new()),
            chunks: RwLock::new(Vec::new()),
            id_map: RwLock::new(HashMap::new()),
            indexed_files: RwLock::new(std::collections::HashSet::new()),
        }
    }

    /// Index a document's text via trueno-rag BM25.
    ///
    /// The text is split with `chunk_text(text, 512, 64)`; each chunk is added
    /// to the BM25 index and mirrored in our own metadata store. `file_id` is
    /// recorded as indexed afterwards.
    ///
    /// Calling this again with the same `file_id` appends a second set of
    /// chunks; nothing is replaced. Poisoned locks are recovered rather than
    /// propagated.
    pub fn index_document(&self, file_id: &str, file_name: &str, text: &str) {
        use trueno_rag::SparseIndex;

        const CHUNK_TOKENS: usize = 512;
        const OVERLAP_TOKENS: usize = 64;
        // Same overlap (in bytes) that `chunk_text` derives from these arguments.
        let overlap_bytes = OVERLAP_TOKENS.min(CHUNK_TOKENS / 2) * 4;

        let chunk_texts = chunk_text(text, CHUNK_TOKENS, OVERLAP_TOKENS);
        let doc_id = trueno_rag::DocumentId::new();

        // Lock order (bm25 → chunks → id_map → indexed_files) is shared with
        // `search` and `clear` so the methods cannot deadlock each other.
        let mut bm25 = self.bm25.write().unwrap_or_else(|e| e.into_inner());
        let mut chunks = self.chunks.write().unwrap_or_else(|e| e.into_inner());
        let mut id_map = self.id_map.write().unwrap_or_else(|e| e.into_inner());

        let mut prev_end = 0;
        for (i, piece) in chunk_texts.iter().enumerate() {
            // Consecutive chunks overlap, so a chunk does not start where the
            // previous one ended; recover its real position in `text`.
            let start = if i == 0 {
                0
            } else {
                Self::chunk_start(text, piece, prev_end, overlap_bytes)
            };
            let end = start + piece.len();
            let chunk = trueno_rag::Chunk::new(doc_id, piece.clone(), start, end);

            let chunk_id_str = chunk.id.0.to_string();
            let our_idx = chunks.len();

            // Add to trueno-rag BM25 index
            bm25.add(&chunk);

            // Store our metadata
            chunks.push(RagChunk {
                file_id: file_id.to_string(),
                file_name: file_name.to_string(),
                chunk_index: i,
                text: piece.clone(),
            });
            id_map.insert(chunk_id_str, our_idx);
            prev_end = end;
        }

        self.indexed_files.write().unwrap_or_else(|e| e.into_inner()).insert(file_id.to_string());
    }

    /// Locates the byte offset in `text` at which `piece` begins, given where
    /// the previous chunk ended and the configured overlap.
    ///
    /// Falls back to `prev_end` (back-to-back offsets) if the expected
    /// position does not actually contain `piece`.
    fn chunk_start(text: &str, piece: &str, prev_end: usize, overlap_bytes: usize) -> usize {
        let mut candidate = prev_end.saturating_sub(overlap_bytes).min(text.len());
        // `text.len()` is always a boundary, so this terminates.
        while !text.is_char_boundary(candidate) {
            candidate += 1;
        }
        match text.get(candidate..candidate + piece.len()) {
            Some(window) if window == piece => candidate,
            _ => prev_end,
        }
    }

    /// Search the index using trueno-rag BM25.
    ///
    /// Asks the BM25 index for up to `top_k` hits, drops those scoring below
    /// `min_score`, and maps the survivors back to our chunk metadata. Hits
    /// whose id is unknown to `id_map` are skipped.
    pub fn search(&self, query: &str, top_k: usize, min_score: f64) -> Vec<RagResult> {
        use trueno_rag::SparseIndex;

        let bm25 = self.bm25.read().unwrap_or_else(|e| e.into_inner());
        let chunks = self.chunks.read().unwrap_or_else(|e| e.into_inner());
        let id_map = self.id_map.read().unwrap_or_else(|e| e.into_inner());

        bm25.search(query, top_k)
            .into_iter()
            .filter(|(_, score)| (*score as f64) >= min_score)
            .filter_map(|(chunk_id, score)| {
                let key = chunk_id.0.to_string();
                let idx = id_map.get(&key)?;
                let c = chunks.get(*idx)?;
                Some(RagResult {
                    file: c.file_name.clone(),
                    chunk: c.chunk_index,
                    score: score as f64,
                    text: c.text.clone(),
                })
            })
            .collect()
    }

    /// Get index status.
    ///
    /// Reports the number of indexed files and stored chunks; `engine` is
    /// always `"trueno-rag"` for this build.
    #[must_use]
    pub fn status(&self) -> RagStatus {
        let chunk_count = self.chunks.read().unwrap_or_else(|e| e.into_inner()).len();
        let doc_count = self.indexed_files.read().unwrap_or_else(|e| e.into_inner()).len();
        RagStatus {
            doc_count,
            chunk_count,
            indexed: chunk_count > 0,
            engine: Some("trueno-rag".to_string()),
        }
    }

    /// Clear the entire index.
    ///
    /// All four locks are held together while clearing, so a concurrent
    /// `index_document` or `search` never observes a half-cleared index.
    pub fn clear(&self) {
        let mut bm25 = self.bm25.write().unwrap_or_else(|e| e.into_inner());
        let mut chunks = self.chunks.write().unwrap_or_else(|e| e.into_inner());
        let mut id_map = self.id_map.write().unwrap_or_else(|e| e.into_inner());
        let mut files = self.indexed_files.write().unwrap_or_else(|e| e.into_inner());

        *bm25 = trueno_rag::BM25Index::new();
        chunks.clear();
        id_map.clear();
        files.clear();
    }

    /// Check if a file has been indexed.
    #[must_use]
    pub fn is_indexed(&self, file_id: &str) -> bool {
        self.indexed_files.read().unwrap_or_else(|e| e.into_inner()).contains(file_id)
    }
}
171
172// ============================================================================
173// Built-in BM25 RAG index (no rag feature)
174// ============================================================================
175
176#[cfg(not(feature = "rag"))]
177pub struct RagIndex {
178    chunks: RwLock<Vec<RagChunk>>,
179    postings: RwLock<HashMap<String, Vec<(usize, u32)>>>,
180    doc_lengths: RwLock<Vec<usize>>,
181    indexed_files: RwLock<std::collections::HashSet<String>>,
182}
183
184#[cfg(not(feature = "rag"))]
185impl Default for RagIndex {
186    fn default() -> Self {
187        Self::new()
188    }
189}
190
191#[cfg(not(feature = "rag"))]
192impl RagIndex {
193    #[must_use]
194    pub fn new() -> Self {
195        Self {
196            chunks: RwLock::new(Vec::new()),
197            postings: RwLock::new(HashMap::new()),
198            doc_lengths: RwLock::new(Vec::new()),
199            indexed_files: RwLock::new(std::collections::HashSet::new()),
200        }
201    }
202
203    /// Index a document's text, splitting into chunks.
204    pub fn index_document(&self, file_id: &str, file_name: &str, text: &str) {
205        let chunk_texts = chunk_text(text, 512, 64);
206
207        let mut chunks = self.chunks.write().unwrap_or_else(|e| e.into_inner());
208        let mut postings = self.postings.write().unwrap_or_else(|e| e.into_inner());
209        let mut doc_lens = self.doc_lengths.write().unwrap_or_else(|e| e.into_inner());
210
211        for (i, ct) in chunk_texts.iter().enumerate() {
212            let chunk_idx = chunks.len();
213            chunks.push(RagChunk {
214                file_id: file_id.to_string(),
215                file_name: file_name.to_string(),
216                chunk_index: i,
217                text: ct.clone(),
218            });
219
220            let terms = tokenize(ct);
221            let mut term_freqs: HashMap<&str, u32> = HashMap::new();
222            for term in &terms {
223                *term_freqs.entry(term.as_str()).or_insert(0) += 1;
224            }
225
226            for (term, freq) in term_freqs {
227                postings.entry(term.to_string()).or_default().push((chunk_idx, freq));
228            }
229            doc_lens.push(terms.len());
230        }
231
232        if let Ok(mut files) = self.indexed_files.write() {
233            files.insert(file_id.to_string());
234        }
235    }
236
237    /// Search the index using BM25 scoring.
238    pub fn search(&self, query: &str, top_k: usize, min_score: f64) -> Vec<RagResult> {
239        let chunks = self.chunks.read().unwrap_or_else(|e| e.into_inner());
240        let postings = self.postings.read().unwrap_or_else(|e| e.into_inner());
241        let doc_lens = self.doc_lengths.read().unwrap_or_else(|e| e.into_inner());
242
243        if chunks.is_empty() {
244            return Vec::new();
245        }
246
247        let n = chunks.len() as f64;
248        let avg_dl: f64 = if doc_lens.is_empty() {
249            1.0
250        } else {
251            doc_lens.iter().sum::<usize>() as f64 / doc_lens.len() as f64
252        };
253
254        let query_terms = tokenize(query);
255        let mut scores: HashMap<usize, f64> = HashMap::new();
256        let (k1, b) = (1.2, 0.75);
257
258        for term in &query_terms {
259            if let Some(posting_list) = postings.get(term.as_str()) {
260                let df = posting_list.len() as f64;
261                let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
262                for &(chunk_idx, tf) in posting_list {
263                    let dl = doc_lens.get(chunk_idx).copied().unwrap_or(1) as f64;
264                    let tf_norm =
265                        (tf as f64 * (k1 + 1.0)) / (tf as f64 + k1 * (1.0 - b + b * dl / avg_dl));
266                    *scores.entry(chunk_idx).or_insert(0.0) += idf * tf_norm;
267                }
268            }
269        }
270
271        let mut results: Vec<(usize, f64)> =
272            scores.into_iter().filter(|&(_, s)| s >= min_score).collect();
273        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
274        results.truncate(top_k);
275
276        results
277            .into_iter()
278            .filter_map(|(idx, score)| {
279                chunks.get(idx).map(|c| RagResult {
280                    file: c.file_name.clone(),
281                    chunk: c.chunk_index,
282                    score,
283                    text: c.text.clone(),
284                })
285            })
286            .collect()
287    }
288
289    /// Get index status.
290    #[must_use]
291    pub fn status(&self) -> RagStatus {
292        let chunk_count = self.chunks.read().map(|c| c.len()).unwrap_or(0);
293        let doc_count = self.indexed_files.read().map(|f| f.len()).unwrap_or(0);
294        RagStatus { doc_count, chunk_count, indexed: chunk_count > 0, engine: None }
295    }
296
297    /// Clear the entire index.
298    pub fn clear(&self) {
299        if let Ok(mut c) = self.chunks.write() {
300            c.clear();
301        }
302        if let Ok(mut p) = self.postings.write() {
303            p.clear();
304        }
305        if let Ok(mut d) = self.doc_lengths.write() {
306            d.clear();
307        }
308        if let Ok(mut f) = self.indexed_files.write() {
309            f.clear();
310        }
311    }
312
313    /// Check if a file has been indexed.
314    #[must_use]
315    pub fn is_indexed(&self, file_id: &str) -> bool {
316        self.indexed_files.read().map(|f| f.contains(file_id)).unwrap_or(false)
317    }
318}
319
320// ============================================================================
321// Shared utilities
322// ============================================================================
323
/// Split text into overlapping chunks of roughly `max_tokens * 4` bytes each.
///
/// Consecutive chunks share about `overlap_tokens * 4` bytes; the overlap is
/// capped at half of `max_tokens`. Chunk edges are always placed on UTF-8
/// character boundaries: a window end is pulled back, and the next window
/// start pushed forward, to the nearest boundary. A chunk may therefore be a
/// few bytes shorter than the budget but is never longer.
///
/// Text that already fits in one window is returned as a single chunk. A
/// `max_tokens` of zero cannot be honoured, so the text is returned unsplit.
fn chunk_text(text: &str, max_tokens: usize, overlap_tokens: usize) -> Vec<String> {
    let max_bytes = max_tokens.saturating_mul(4);
    let overlap_bytes = overlap_tokens.min(max_tokens / 2).saturating_mul(4);

    if max_bytes == 0 || text.len() <= max_bytes {
        return vec![text.to_string()];
    }

    let mut chunks = Vec::new();
    let mut start = 0;
    while start < text.len() {
        // Offset 0 is always a boundary and the window is at least 4 bytes
        // wide, so pulling `end` back stops strictly after `start`.
        let mut end = start.saturating_add(max_bytes).min(text.len());
        while !text.is_char_boundary(end) {
            end -= 1;
        }
        chunks.push(text[start..end].to_string());
        if end == text.len() {
            break;
        }

        // `end` is a boundary, so pushing `next` forward stops at `end` at the latest.
        let mut next = end.saturating_sub(overlap_bytes);
        while !text.is_char_boundary(next) {
            next += 1;
        }
        // Defensive: never let the window stall or move backwards.
        if next <= start {
            next = end;
        }
        start = next;
    }
    chunks
}
345
/// Whitespace tokenizer used by the built-in BM25 scorer.
///
/// Each whitespace-separated word is lowercased and then stripped of leading
/// and trailing non-alphanumeric characters. Tokens whose remaining length is
/// one byte or less are discarded.
#[cfg(not(feature = "rag"))]
fn tokenize(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for raw in text.split_whitespace() {
        // Lowercase first, then trim: the order matches the scorer's index.
        let lowered = raw.to_lowercase();
        let core = lowered.trim_matches(|ch: char| !ch.is_alphanumeric());
        if core.len() > 1 {
            tokens.push(core.to_string());
        }
    }
    tokens
}