1use std::{
8 collections::HashMap,
9 path::{Path, PathBuf},
10 sync::{Arc, Mutex},
11};
12
13use crate::{
14 embed::Embedder,
15 error::Result,
16 retrieve_store::{
17 ChunkHit, Document, FileSearchResult, FtsQuery, HybridQuery, RetrieveStore, VectorQuery,
18 },
19 vector_store::VecInfo,
20};
21
22#[cfg(feature = "sqlite-store")]
23use crate::sqlite_store::SqliteStore;
24
25#[cfg(feature = "lancedb-store")]
26use crate::lancedb_store::LanceDbBackend;
27
28#[cfg(feature = "sqlite-store")]
29pub use crate::sqlite_store::SCHEMA_VERSION;
30
31struct InMemoryStore {
37 state: Mutex<InMemoryState>,
38}
39
40#[derive(Default)]
41struct InMemoryState {
42 files: HashMap<String, i64>,
43 documents: HashMap<i64, Document>,
44}
45
46impl InMemoryStore {
47 fn new() -> Self {
48 Self {
49 state: Mutex::new(InMemoryState::default()),
50 }
51 }
52}
53
54impl RetrieveStore for InMemoryStore {
55 fn file_mtimes(&self) -> Result<HashMap<String, i64>> {
56 Ok(self.state.lock().unwrap().files.clone())
57 }
58
59 fn upsert_file(&self, path: &str, mtime: i64) -> Result<()> {
60 self.state
61 .lock()
62 .unwrap()
63 .files
64 .insert(path.to_owned(), mtime);
65 Ok(())
66 }
67
68 fn remove_file(&self, path: &str) -> Result<()> {
69 self.state.lock().unwrap().files.remove(path);
70 Ok(())
71 }
72
73 fn file_count(&self) -> Result<u64> {
74 Ok(self.state.lock().unwrap().files.len() as u64)
75 }
76
77 fn upsert_document(&self, doc: &Document) -> Result<()> {
78 self.state
79 .lock()
80 .unwrap()
81 .documents
82 .insert(doc.id, doc.clone());
83 Ok(())
84 }
85
86 fn remove_document(&self, id: i64) -> Result<()> {
87 self.state.lock().unwrap().documents.remove(&id);
88 Ok(())
89 }
90
91 fn rebuild_fts(&self) -> Result<()> {
92 Ok(())
93 }
94
95 fn search_fts(&self, q: &FtsQuery<'_>) -> Result<Vec<FileSearchResult>> {
96 let state = self.state.lock().unwrap();
97 let needle = q.query.to_lowercase();
98 let prefix = q.path_prefix.map(|p| p.to_string_lossy().to_string());
99 let mut results: Vec<FileSearchResult> = state
100 .documents
101 .values()
102 .filter(|doc| {
103 if let Some(ref pfx) = prefix
104 && !doc.path.starts_with(pfx.as_str())
105 {
106 return false;
107 }
108 doc.body.to_lowercase().contains(&needle)
109 })
110 .take(q.limit)
111 .map(|doc| FileSearchResult {
112 id: doc.id,
113 path: doc.path.clone(),
114 score: 0.0,
115 chunks: vec![ChunkHit {
116 line_start: 0,
117 line_end: 0,
118 text: String::new(),
119 score: 0.0,
120 }],
121 })
122 .collect();
123 results.sort_by(|a, b| a.path.cmp(&b.path));
124 Ok(results)
125 }
126
127 fn document_ids(&self) -> Result<Vec<i64>> {
128 Ok(self
129 .state
130 .lock()
131 .unwrap()
132 .documents
133 .keys()
134 .copied()
135 .collect())
136 }
137
138 fn document_count(&self) -> Result<u64> {
139 Ok(self.state.lock().unwrap().documents.len() as u64)
140 }
141
142 fn embed_pending(
143 &self,
144 _embedder: &dyn Embedder,
145 _on_progress: &dyn Fn(usize, usize),
146 ) -> Result<usize> {
147 Ok(0)
148 }
149
150 fn vec_info(&self) -> Result<VecInfo> {
151 Ok(VecInfo {
152 embedding_dim: 0,
153 vector_count: 0,
154 pending_count: 0,
155 })
156 }
157
158 fn search_similar(&self, _q: &VectorQuery<'_>) -> Result<Vec<FileSearchResult>> {
159 Ok(vec![])
160 }
161}
162
163enum BackendState {
166 #[allow(dead_code)]
167 InMemory(Arc<InMemoryStore>),
168 #[cfg(feature = "sqlite-store")]
169 Sqlite(Arc<SqliteStore>),
170 #[cfg(feature = "lancedb-store")]
171 LanceDb(Arc<LanceDbBackend>),
172}
173
174impl BackendState {
175 fn as_store(&self) -> Arc<dyn RetrieveStore> {
176 match self {
177 BackendState::InMemory(s) => Arc::clone(s) as Arc<dyn RetrieveStore>,
178 #[cfg(feature = "sqlite-store")]
179 BackendState::Sqlite(s) => Arc::clone(s) as Arc<dyn RetrieveStore>,
180 #[cfg(feature = "lancedb-store")]
181 BackendState::LanceDb(l) => Arc::clone(l) as Arc<dyn RetrieveStore>,
182 }
183 }
184
185 fn needs_init(&self) -> bool {
186 match self {
187 BackendState::InMemory(_) => true,
188 #[cfg(feature = "sqlite-store")]
189 BackendState::Sqlite(s) => s.dim().is_none(),
190 #[cfg(feature = "lancedb-store")]
191 BackendState::LanceDb(_) => false,
192 }
193 }
194}
195
196pub struct RetrieveDb {
199 db_path: PathBuf,
200 backend: Mutex<BackendState>,
201}
202
203impl RetrieveDb {
204 pub fn open(db_path: &Path) -> Result<Self> {
205 #[cfg(feature = "sqlite-store")]
206 {
207 let store = SqliteStore::new_fts_only(db_path.to_owned());
208 Ok(Self {
209 db_path: db_path.to_owned(),
210 backend: Mutex::new(BackendState::Sqlite(Arc::new(store))),
211 })
212 }
213
214 #[cfg(not(feature = "sqlite-store"))]
215 Ok(Self {
216 db_path: db_path.to_owned(),
217 backend: Mutex::new(BackendState::InMemory(Arc::new(InMemoryStore::new()))),
218 })
219 }
220
221 pub fn rebuild(db_path: &Path) -> Result<Self> {
222 #[cfg(feature = "sqlite-store")]
223 crate::sqlite_store::wipe_db_files(db_path);
224 Self::open(db_path)
225 }
226
227 #[cfg(feature = "sqlite-store")]
228 pub fn init_sqlite_vec(&self, embedding_dim: u32) -> Result<()> {
229 let mut guard = self.backend.lock().unwrap();
230 if guard.needs_init() {
231 let store = SqliteStore::new_with_vec(self.db_path.clone(), embedding_dim)?;
232 *guard = BackendState::Sqlite(Arc::new(store));
233 }
234 Ok(())
235 }
236
237 #[cfg(feature = "lancedb-store")]
238 pub fn init_lancedb(&self, lancedb_dir: &Path, embedding_dim: u32) -> Result<()> {
239 let mut guard = self.backend.lock().unwrap();
240 if guard.needs_init() {
241 let backend = LanceDbBackend::new(lancedb_dir, embedding_dim)?;
242 *guard = BackendState::LanceDb(Arc::new(backend));
243 }
244 Ok(())
245 }
246
247 fn store(&self) -> Arc<dyn RetrieveStore> {
248 self.backend.lock().unwrap().as_store()
249 }
250
251 pub fn upsert_document(&self, doc: &Document) -> Result<()> {
254 self.store().upsert_document(doc)
255 }
256
257 pub fn remove_document(&self, id: i64) -> Result<()> {
258 self.store().remove_document(id)
259 }
260
261 pub fn rebuild_fts(&self) -> Result<()> {
262 self.store().rebuild_fts()
263 }
264
265 pub fn search_fts(&self, q: &FtsQuery<'_>) -> Result<Vec<FileSearchResult>> {
268 self.store().search_fts(q)
269 }
270
271 pub fn search_similar(&self, q: &VectorQuery<'_>) -> Result<Vec<FileSearchResult>> {
272 self.store().search_similar(q)
273 }
274
275 pub fn search_hybrid(&self, q: &HybridQuery<'_>) -> Result<Vec<FileSearchResult>> {
276 self.store().search_hybrid(q)
277 }
278
279 pub fn embed_pending(
282 &self,
283 embedder: &dyn Embedder,
284 on_progress: impl Fn(usize, usize),
285 ) -> Result<usize> {
286 self.store().embed_pending(embedder, &on_progress)
287 }
288
289 pub fn vec_info(&self) -> Result<VecInfo> {
290 self.store().vec_info()
291 }
292
293 pub fn document_ids(&self) -> Result<Vec<i64>> {
294 self.store().document_ids()
295 }
296
297 pub fn document_count(&self) -> Result<u64> {
298 self.store().document_count()
299 }
300
301 pub fn file_mtimes(&self) -> Result<HashMap<String, i64>> {
304 self.store().file_mtimes()
305 }
306
307 pub fn upsert_file(&self, path: &str, mtime: i64) -> Result<()> {
308 self.store().upsert_file(path, mtime)
309 }
310
311 pub fn remove_file(&self, path: &str) -> Result<()> {
312 self.store().remove_file(path)
313 }
314
315 pub fn file_count(&self) -> Result<u64> {
316 self.store().file_count()
317 }
318}
319
320pub fn merge_rrf_files(
328 fts: &[FileSearchResult],
329 sem: &[FileSearchResult],
330 k: f64,
331 w_fts: f64,
332 w_sem: f64,
333 limit: usize,
334) -> Vec<FileSearchResult> {
335 let mut acc: HashMap<String, (FileSearchResult, f64)> = HashMap::new();
337
338 for (rank, file) in fts.iter().enumerate() {
339 let rrf = w_fts / (k + (rank + 1) as f64);
340 acc.insert(file.path.clone(), (file.clone(), rrf));
341 }
342
343 for (rank, file) in sem.iter().enumerate() {
344 let rrf = w_sem / (k + (rank + 1) as f64);
345 acc.entry(file.path.clone())
346 .and_modify(|(existing, s)| {
347 *s += rrf;
348 merge_chunk_hits(&mut existing.chunks, &file.chunks);
349 })
350 .or_insert_with(|| (file.clone(), rrf));
351 }
352
353 let mut merged: Vec<_> = acc.into_values().collect();
354 merged.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
355 merged.truncate(limit);
356
357 merged
358 .into_iter()
359 .map(|(mut file, rrf_score)| {
360 file.score = rrf_score;
361 file
362 })
363 .collect()
364}
365
366fn merge_chunk_hits(existing: &mut Vec<ChunkHit>, incoming: &[ChunkHit]) {
372 use std::collections::HashSet;
373 let seen: HashSet<(usize, usize)> = existing
374 .iter()
375 .map(|c| (c.line_start, c.line_end))
376 .collect();
377 for c in incoming {
378 if !seen.contains(&(c.line_start, c.line_end)) {
379 existing.push(c.clone());
380 }
381 }
382}
383
384pub fn default_hybrid<S: RetrieveStore + ?Sized>(
390 store: &S,
391 q: &HybridQuery<'_>,
392) -> Result<Vec<FileSearchResult>> {
393 let over_fetch = q.limit * 3;
394 let fts = store.search_fts(&FtsQuery {
395 query: q.query,
396 limit: over_fetch,
397 path_prefix: q.path_prefix,
398 })?;
399
400 let Some(embedder) = q.embedder else {
401 return Ok(fts.into_iter().take(q.limit).collect());
402 };
403
404 let sem = store.search_similar(&VectorQuery {
405 query: q.query,
406 embedder,
407 limit: over_fetch,
408 path_prefix: q.path_prefix,
409 })?;
410
411 Ok(merge_rrf_files(
412 &fts,
413 &sem,
414 q.rrf_k,
415 q.weight_fts,
416 q.weight_sem,
417 q.limit,
418 ))
419}
420
421pub fn open_in_memory() -> Arc<dyn RetrieveStore + Send + Sync> {
425 Arc::new(InMemoryStore::new())
426}
427
/// Creates an FTS-only SQLite store for `db_path` behind a shareable trait
/// object.
#[cfg(feature = "sqlite-store")]
pub fn open_sqlite_fts(db_path: &Path) -> Arc<dyn RetrieveStore + Send + Sync> {
    let store = SqliteStore::new_fts_only(db_path.to_path_buf());
    Arc::new(store)
}
432
/// Creates a SQLite store with vector support of dimension `dim` for
/// `db_path`, behind a shareable trait object.
///
/// # Errors
///
/// Propagates any error from `SqliteStore::new_with_vec`.
#[cfg(feature = "sqlite-store")]
pub fn open_sqlite_vec(db_path: &Path, dim: u32) -> Result<Arc<dyn RetrieveStore + Send + Sync>> {
    let store = SqliteStore::new_with_vec(db_path.to_path_buf(), dim)?;
    Ok(Arc::new(store))
}
440
/// Creates a LanceDB backend rooted at `data_dir` with embedding dimension
/// `dim`, behind a shareable trait object.
///
/// # Errors
///
/// Propagates any error from `LanceDbBackend::new`.
#[cfg(feature = "lancedb-store")]
pub fn open_lancedb(data_dir: &Path, dim: u32) -> Result<Arc<dyn RetrieveStore + Send + Sync>> {
    let backend = LanceDbBackend::new(data_dir, dim)?;
    Ok(Arc::new(backend))
}