use anyhow::{anyhow, Result};
use std::path::Path;
use tantivy::collector::TopDocs;
use tantivy::query::QueryParser;
use tantivy::schema::{Field, Schema, Value, STORED, STRING, TEXT};
use tantivy::{doc, Index, IndexReader, IndexWriter, ReloadPolicy, TantivyDocument};

#[cfg(test)]
mod tests;

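/// Tuning options for [`TextIndex`].
///
/// A minimal usage sketch; `my_crate` is a placeholder for this crate's
/// actual name:
///
/// ```ignore
/// use my_crate::{TextIndex, TextSearchConfig};
///
/// // Give the tantivy writer a 100 MB heap instead of the 50 MB default.
/// let config = TextSearchConfig { writer_buffer_mb: 100 };
/// let index = TextIndex::open_in_memory_with_config(&config)?;
/// ```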
#[derive(Debug, Clone)]
pub struct TextSearchConfig {
    /// Heap budget for the tantivy `IndexWriter`, in megabytes.
    pub writer_buffer_mb: usize,
}

impl Default for TextSearchConfig {
    fn default() -> Self {
        Self {
            writer_buffer_mb: 50,
        }
    }
}

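/// A full-text index over `(id, text)` pairs, backed by tantivy.
///
/// An end-to-end sketch (`my_crate` is a placeholder, and error handling is
/// elided with `?`):
///
/// ```ignore
/// use my_crate::TextIndex;
///
/// let mut index = TextIndex::open_in_memory()?;
/// index.index_document("doc-1", "a quick brown fox")?;
/// index.index_document("doc-2", "lazy dogs sleep")?;
/// index.commit()?; // writes are not searchable until committed
///
/// let hits = index.search("fox", 10)?;
/// assert_eq!(hits[0].0, "doc-1");
/// ```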
pub struct TextIndex {
    index: Index,
    writer: IndexWriter,
    reader: IndexReader,
    id_field: Field,
    text_field: Field,
}

impl TextIndex {
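    /// Opens (or creates) an on-disk index at `path` with default settings.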
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
        Self::open_with_config(path, &TextSearchConfig::default())
    }

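    /// Like [`Self::open`], but with explicit [`TextSearchConfig`] settings.
    /// Creates the directory if it does not exist.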
    pub fn open_with_config<P: AsRef<Path>>(path: P, config: &TextSearchConfig) -> Result<Self> {
        let path = path.as_ref();
        std::fs::create_dir_all(path)?;

        let schema = Self::create_schema();
        let id_field = schema.get_field("id").expect("id field exists");
        let text_field = schema.get_field("text").expect("text field exists");

        // An existing index is detected by its meta.json; otherwise create a fresh one.
        let index = if path.join("meta.json").exists() {
            Index::open_in_dir(path)?
        } else {
            Index::create_in_dir(path, schema.clone())?
        };

        let buffer_bytes = config.writer_buffer_mb * 1_000_000;
        let writer = index.writer(buffer_bytes)?;

        let reader = index
            .reader_builder()
            .reload_policy(ReloadPolicy::OnCommitWithDelay)
            .try_into()?;

        Ok(Self {
            index,
            writer,
            reader,
            id_field,
            text_field,
        })
    }

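    /// Creates a transient in-memory index with default settings, useful
    /// for tests.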
    pub fn open_in_memory() -> Result<Self> {
        Self::open_in_memory_with_config(&TextSearchConfig::default())
    }

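    /// Like [`Self::open_in_memory`], but with explicit
    /// [`TextSearchConfig`] settings.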
    pub fn open_in_memory_with_config(config: &TextSearchConfig) -> Result<Self> {
        let schema = Self::create_schema();
        let id_field = schema.get_field("id").expect("id field exists");
        let text_field = schema.get_field("text").expect("text field exists");

        let index = Index::create_in_ram(schema);

        let buffer_bytes = config.writer_buffer_mb * 1_000_000;
        let writer = index.writer(buffer_bytes)?;

        let reader = index
            .reader_builder()
            .reload_policy(ReloadPolicy::OnCommitWithDelay)
            .try_into()?;

        Ok(Self {
            index,
            writer,
            reader,
            id_field,
            text_field,
        })
    }

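    /// Builds the two-field schema: a stored, untokenized `id` and a
    /// tokenized (but unstored) `text` body.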
    fn create_schema() -> Schema {
        let mut builder = Schema::builder();

        // `id` is a single untokenized term, stored so results can return it.
        builder.add_text_field("id", STRING | STORED);

        // `text` is tokenized for full-text search; the raw body is not stored.
        builder.add_text_field("text", TEXT);

        builder.build()
    }

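    /// Adds a document, replacing any existing document with the same `id`
    /// (upsert). The write is not searchable until [`Self::commit`] is
    /// called.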
    pub fn index_document(&mut self, id: &str, text: &str) -> Result<()> {
        // Upsert: queue a delete of any existing document with this id first.
        self.delete_document(id)?;

        self.writer.add_document(doc!(
            self.id_field => id,
            self.text_field => text,
        ))?;

        Ok(())
    }

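    /// Queues a delete of every document whose `id` matches; takes effect
    /// on the next [`Self::commit`].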
    pub fn delete_document(&mut self, id: &str) -> Result<()> {
        let term = tantivy::Term::from_field_text(self.id_field, id);
        self.writer.delete_term(term);
        Ok(())
    }

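    /// Commits pending adds and deletes, then reloads the reader so the
    /// changes are immediately visible to [`Self::search`].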
    pub fn commit(&mut self) -> Result<()> {
        self.writer.commit()?;
        // Reload eagerly so the commit is visible without waiting for the
        // reader's delayed auto-reload.
        self.reader.reload()?;
        Ok(())
    }

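    /// Runs a keyword query against the `text` field and returns up to
    /// `limit` `(id, bm25_score)` pairs, best first. An empty or
    /// whitespace-only query returns no results.
    ///
    /// The query string uses tantivy's query syntax (e.g. `AND`/`OR` and
    /// quoted phrases). A sketch, assuming documents were committed as in
    /// the type-level example:
    ///
    /// ```ignore
    /// let hits = index.search("\"quick brown\" OR dogs", 10)?;
    /// for (id, score) in hits {
    ///     println!("{id}: {score:.3}");
    /// }
    /// ```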
    pub fn search(&self, query_str: &str, limit: usize) -> Result<Vec<(String, f32)>> {
        if query_str.trim().is_empty() {
            return Ok(vec![]);
        }

        let searcher = self.reader.searcher();

        let query_parser = QueryParser::for_index(&self.index, vec![self.text_field]);
        let query = query_parser
            .parse_query(query_str)
            .map_err(|e| anyhow!("Invalid query: {e}"))?;

        let top_docs = searcher.search(&query, &TopDocs::with_limit(limit))?;

        let results = top_docs
            .into_iter()
            .filter_map(|(score, doc_addr)| {
                let doc: TantivyDocument = searcher.doc(doc_addr).ok()?;
                let id = doc.get_first(self.id_field)?.as_str()?.to_string();
                Some((id, score))
            })
            .collect();

        Ok(results)
    }

    /// Number of documents visible to the current searcher.
    #[must_use]
    pub fn num_docs(&self) -> u64 {
        self.reader.searcher().num_docs()
    }

    /// The underlying tantivy [`Index`].
    #[must_use]
    pub fn index(&self) -> &Index {
        &self.index
    }

    /// The underlying tantivy [`IndexReader`].
    #[must_use]
    pub fn reader(&self) -> &IndexReader {
        &self.reader
    }
}

/// Default `k` for reciprocal rank fusion; 60 is the value commonly used
/// in the RRF literature.
pub const DEFAULT_RRF_K: usize = 60;

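/// A single fused result from hybrid (vector + keyword) search, carrying
/// the combined RRF score alongside the per-source subscores.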
#[derive(Debug, Clone)]
pub struct HybridResult {
    /// Document id.
    pub id: String,
    /// Fused RRF score used for the final ranking.
    pub score: f32,
    /// BM25 score from the text index, if the document matched the keyword query.
    pub keyword_score: Option<f32>,
    /// Raw score (distance) from the vector search, if the document matched semantically.
    pub semantic_score: Option<f32>,
}

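/// Fuses vector and text results with reciprocal rank fusion (RRF): each
/// list contributes `1 / (k + rank)` per document (1-based rank), and the
/// contributions are summed. This delegates to the weighted variant with
/// `alpha = 0.5`, which scales every score by a constant factor and so
/// preserves the classic unweighted RRF ordering.
///
/// A sketch; the ids and scores are made up:
///
/// ```ignore
/// use my_crate::{reciprocal_rank_fusion, DEFAULT_RRF_K};
///
/// let vector = vec![("a".to_string(), 0.12), ("b".to_string(), 0.34)];
/// let text = vec![("b".to_string(), 7.1), ("c".to_string(), 3.4)];
///
/// let fused = reciprocal_rank_fusion(vector, text, 10, DEFAULT_RRF_K);
/// // "b" wins: it appears in both lists.
/// assert_eq!(fused[0].0, "b");
/// ```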
#[must_use]
pub fn reciprocal_rank_fusion(
    vector_results: Vec<(String, f32)>,
    text_results: Vec<(String, f32)>,
    limit: usize,
    rrf_k: usize,
) -> Vec<(String, f32)> {
    weighted_reciprocal_rank_fusion(vector_results, text_results, limit, rrf_k, 0.5)
}

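/// Weighted RRF: vector ranks are weighted by `alpha` and text ranks by
/// `1 - alpha`, with `alpha` clamped to `[0.0, 1.0]`. `alpha = 1.0` is
/// pure vector search; `alpha = 0.0` is pure keyword search.
///
/// A sketch with a keyword-heavy weighting, reusing the inputs from the
/// example above:
///
/// ```ignore
/// let fused = weighted_reciprocal_rank_fusion(vector, text, 10, DEFAULT_RRF_K, 0.3);
/// ```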
#[must_use]
pub fn weighted_reciprocal_rank_fusion(
    vector_results: Vec<(String, f32)>,
    text_results: Vec<(String, f32)>,
    limit: usize,
    rrf_k: usize,
    alpha: f32,
) -> Vec<(String, f32)> {
    use std::collections::HashMap;

    // Keep the weight in a sane range.
    let alpha = alpha.clamp(0.0, 1.0);

    let mut scores: HashMap<String, f32> = HashMap::new();

    // Vector hits contribute alpha * 1/(k + rank); `rank` is 0-based here,
    // so the +1 turns the denominator into k plus the 1-based rank.
    for (rank, (id, _distance)) in vector_results.iter().enumerate() {
        let rrf_score = 1.0 / (rrf_k + rank + 1) as f32;
        *scores.entry(id.clone()).or_default() += alpha * rrf_score;
    }

    // Keyword hits contribute (1 - alpha) * 1/(k + rank).
    for (rank, (id, _score)) in text_results.iter().enumerate() {
        let rrf_score = 1.0 / (rrf_k + rank + 1) as f32;
        *scores.entry(id.clone()).or_default() += (1.0 - alpha) * rrf_score;
    }

    // Sort by fused score, descending, and keep the top `limit`.
    let mut results: Vec<_> = scores.into_iter().collect();
    results.sort_by(|a, b| b.1.total_cmp(&a.1));
    results.truncate(limit);

    results
}

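/// Like [`weighted_reciprocal_rank_fusion`], but returns [`HybridResult`]s
/// that keep each document's raw vector score and BM25 score alongside the
/// fused score, which is useful for debugging or re-ranking.
///
/// A sketch reusing the inputs from the examples above:
///
/// ```ignore
/// let results = weighted_reciprocal_rank_fusion_with_subscores(
///     vector, text, 10, DEFAULT_RRF_K, 0.5,
/// );
/// for r in &results {
///     println!("{}: fused={:.4} bm25={:?} vector={:?}",
///         r.id, r.score, r.keyword_score, r.semantic_score);
/// }
/// ```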
#[must_use]
pub fn weighted_reciprocal_rank_fusion_with_subscores(
    vector_results: Vec<(String, f32)>,
    text_results: Vec<(String, f32)>,
    limit: usize,
    rrf_k: usize,
    alpha: f32,
) -> Vec<HybridResult> {
    use std::collections::HashMap;

    let alpha = alpha.clamp(0.0, 1.0);

    // Fused scores plus the per-source subscores, keyed by document id.
    let mut rrf_scores: HashMap<String, f32> = HashMap::new();
    let mut semantic_scores: HashMap<String, f32> = HashMap::new();
    let mut keyword_scores: HashMap<String, f32> = HashMap::new();

    // Vector hits: weight alpha, remember the raw vector score.
    for (rank, (id, distance)) in vector_results.iter().enumerate() {
        let rrf_score = 1.0 / (rrf_k + rank + 1) as f32;
        *rrf_scores.entry(id.clone()).or_default() += alpha * rrf_score;
        semantic_scores.insert(id.clone(), *distance);
    }

    // Keyword hits: weight (1 - alpha), remember the BM25 score.
    for (rank, (id, bm25_score)) in text_results.iter().enumerate() {
        let rrf_score = 1.0 / (rrf_k + rank + 1) as f32;
        *rrf_scores.entry(id.clone()).or_default() += (1.0 - alpha) * rrf_score;
        keyword_scores.insert(id.clone(), *bm25_score);
    }

    // Attach the subscores; either may be absent if a document appeared in
    // only one of the two result lists.
    let mut results: Vec<HybridResult> = rrf_scores
        .into_iter()
        .map(|(id, score)| HybridResult {
            keyword_score: keyword_scores.get(&id).copied(),
            semantic_score: semantic_scores.get(&id).copied(),
            id,
            score,
        })
        .collect();

    results.sort_by(|a, b| b.score.total_cmp(&a.score));
    results.truncate(limit);

    results
}