use std::{
collections::HashMap,
path::{Path, PathBuf},
sync::{Arc, Mutex},
};
use crate::{
embed::Embedder,
error::Result,
retrieve_store::RetrieveStore,
vector_store::{ChunkSearchResult, VecInfo},
};
pub use crate::retrieve_store::{Document, SearchResult};
#[cfg(feature = "sqlite-store")]
use crate::sqlite_store::SqliteStore;
#[cfg(feature = "lancedb-store")]
use crate::lancedb_store::LanceDbBackend;
#[cfg(feature = "sqlite-store")]
pub use crate::sqlite_store::SCHEMA_VERSION;
/// Store that keeps all of its data in process memory.
///
/// The whole state sits behind a single [`Mutex`], so the store can be
/// mutated through a shared reference.
struct InMemoryStore {
    /// File and document maps, guarded together by one lock.
    state: Mutex<InMemoryState>,
}
/// Data held by [`InMemoryStore`]; `Default` yields two empty maps.
#[derive(Default)]
struct InMemoryState {
    /// Recorded mtime for each file path string.
    files: HashMap<String, i64>,
    /// Stored documents, keyed by `Document::id`.
    documents: HashMap<i64, Document>,
}
impl InMemoryStore {
    /// Creates a store whose state is `InMemoryState::default()`.
    fn new() -> Self {
        let initial = InMemoryState::default();
        Self {
            state: Mutex::new(initial),
        }
    }
}
/// [`RetrieveStore`] implementation that keeps every record in process memory.
///
/// File mtimes and documents live in two hash maps behind one mutex. The
/// embedding and vector-search methods are stubs: they accept their arguments,
/// ignore them, and report nothing embedded and no matches.
///
/// # Panics
///
/// Every method that touches the maps panics if the state mutex is poisoned.
impl RetrieveStore for InMemoryStore {
    /// Returns a snapshot copy of the path -> mtime map.
    fn file_mtimes(&self) -> Result<HashMap<String, i64>> {
        Ok(self.state.lock().unwrap().files.clone())
    }

    /// Records `mtime` for `path`, replacing any earlier value for that path.
    fn upsert_file(&self, path: &str, mtime: i64) -> Result<()> {
        self.state.lock().unwrap().files.insert(path.to_owned(), mtime);
        Ok(())
    }

    /// Forgets `path`; succeeds whether or not the path was recorded.
    fn remove_file(&self, path: &str) -> Result<()> {
        self.state.lock().unwrap().files.remove(path);
        Ok(())
    }

    /// Number of file paths currently recorded.
    fn file_count(&self) -> Result<u64> {
        Ok(self.state.lock().unwrap().files.len() as u64)
    }

    /// Stores a clone of `doc` under `doc.id`, replacing any document with the same id.
    fn upsert_document(&self, doc: &Document) -> Result<()> {
        self.state.lock().unwrap().documents.insert(doc.id, doc.clone());
        Ok(())
    }

    /// Drops the document with this `id`; succeeds whether or not it existed.
    fn remove_document(&self, id: i64) -> Result<()> {
        self.state.lock().unwrap().documents.remove(&id);
        Ok(())
    }

    /// No-op: there is no index to rebuild, `search_fts` scans the documents directly.
    fn rebuild_fts(&self) -> Result<()> {
        Ok(())
    }

    /// Case-insensitive substring search over document titles and bodies.
    ///
    /// `query` is lowercased and matched with `str::contains` against the
    /// lowercased title and body of every stored document; an empty query
    /// therefore matches every document. Matches are ordered by title (ties
    /// broken by id) and only then cut down to `limit`, so the same store
    /// contents always produce the same results. Every hit carries a score
    /// of `0.0`.
    fn search_fts(&self, query: &str, limit: usize) -> Result<Vec<SearchResult>> {
        let needle = query.to_lowercase();
        let state = self.state.lock().unwrap();

        let mut matches: Vec<&Document> = state
            .documents
            .values()
            .filter(|doc| {
                doc.title.to_lowercase().contains(&needle)
                    || doc.body.to_lowercase().contains(&needle)
            })
            .collect();

        // Order first, truncate second: the map iterates in arbitrary order, so
        // truncating before sorting would keep an arbitrary subset of the matches.
        matches.sort_by(|a, b| a.title.cmp(&b.title).then_with(|| a.id.cmp(&b.id)));

        // Only the surviving documents are cloned into results.
        let hits: Vec<SearchResult> = matches
            .into_iter()
            .take(limit)
            .map(|doc| SearchResult {
                id: doc.id,
                title: doc.title.clone(),
                path: doc.path.clone(),
                score: 0.0,
            })
            .collect();
        Ok(hits)
    }

    /// Ids of all stored documents, in the map's (unspecified) iteration order.
    fn document_ids(&self) -> Result<Vec<i64>> {
        Ok(self.state.lock().unwrap().documents.keys().copied().collect())
    }

    /// Number of stored documents.
    fn document_count(&self) -> Result<u64> {
        Ok(self.state.lock().unwrap().documents.len() as u64)
    }

    /// Stub: ignores both arguments, never calls the progress callback, returns `Ok(0)`.
    fn embed_pending(
        &self,
        _embedder: &dyn Embedder,
        _on_progress: &dyn Fn(usize, usize),
    ) -> Result<usize> {
        Ok(0)
    }

    /// Stub: reports a `VecInfo` with every field set to zero.
    fn vec_info(&self) -> Result<VecInfo> {
        Ok(VecInfo { embedding_dim: 0, vector_count: 0, pending_count: 0 })
    }

    /// Stub: ignores the query and always returns an empty result list.
    fn search_similar(
        &self,
        _query_vec: &[f32],
        _limit: usize,
    ) -> Result<Vec<ChunkSearchResult>> {
        Ok(vec![])
    }
}
/// The storage backend a `RetrieveDb` is currently using.
///
/// Each variant owns its store through an [`Arc`]; the SQLite and LanceDB
/// variants exist only when the matching cargo feature is enabled.
enum BackendState {
    /// Store kept entirely in process memory.
    InMemory(Arc<InMemoryStore>),
    /// SQLite-backed store (feature `sqlite-store`).
    #[cfg(feature = "sqlite-store")]
    Sqlite(Arc<SqliteStore>),
    /// LanceDB-backed store (feature `lancedb-store`).
    #[cfg(feature = "lancedb-store")]
    LanceDb(Arc<LanceDbBackend>),
}
impl BackendState {
    /// Hands out a new shared handle to the active store as a trait object.
    ///
    /// Only the `Arc` is cloned; the returned handle points at the same store.
    fn as_store(&self) -> Arc<dyn RetrieveStore> {
        match self {
            Self::InMemory(store) => Arc::clone(store) as Arc<dyn RetrieveStore>,
            #[cfg(feature = "sqlite-store")]
            Self::Sqlite(store) => Arc::clone(store) as Arc<dyn RetrieveStore>,
            #[cfg(feature = "lancedb-store")]
            Self::LanceDb(store) => Arc::clone(store) as Arc<dyn RetrieveStore>,
        }
    }

    /// Tells the `init_*` methods whether they should install a new backend.
    ///
    /// * in-memory: always `true`
    /// * SQLite: `true` while the store reports no dimension (`dim()` is `None`)
    /// * LanceDB: always `false`
    fn needs_init(&self) -> bool {
        match self {
            Self::InMemory(_) => true,
            #[cfg(feature = "sqlite-store")]
            Self::Sqlite(store) => store.dim().is_none(),
            #[cfg(feature = "lancedb-store")]
            Self::LanceDb(_) => false,
        }
    }
}
/// Facade over the active storage backend.
///
/// The backend sits behind a [`Mutex`] so the `init_*` methods can replace it
/// through a shared reference.
pub struct RetrieveDb {
    /// Path this database was opened with; reused when a SQLite backend is
    /// re-created with vector support.
    db_path: PathBuf,
    /// Backend currently in use.
    backend: Mutex<BackendState>,
}
/// Public API of [`RetrieveDb`]: backend selection plus thin delegating wrappers.
///
/// # Panics
///
/// Every method that reaches the backend panics if the backend mutex is poisoned.
impl RetrieveDb {
    /// Opens the database at `db_path`.
    ///
    /// With the `sqlite-store` feature the backend is
    /// `SqliteStore::new_fts_only(db_path)`; without it an empty in-memory
    /// store is used and `db_path` is only remembered.
    pub fn open(db_path: &Path) -> Result<Self> {
        #[cfg(feature = "sqlite-store")]
        {
            let store = SqliteStore::new_fts_only(db_path.to_owned());
            return Ok(Self {
                db_path: db_path.to_owned(),
                backend: Mutex::new(BackendState::Sqlite(Arc::new(store))),
            });
        }
        #[cfg(not(feature = "sqlite-store"))]
        Ok(Self {
            db_path: db_path.to_owned(),
            backend: Mutex::new(BackendState::InMemory(Arc::new(InMemoryStore::new()))),
        })
    }

    /// Opens the database after calling `wipe_db_files(db_path)` first
    /// (`sqlite-store` feature only; without it this is the same as [`Self::open`]).
    pub fn rebuild(db_path: &Path) -> Result<Self> {
        #[cfg(feature = "sqlite-store")]
        crate::sqlite_store::wipe_db_files(db_path);
        Self::open(db_path)
    }

    /// Installs a SQLite backend created with `SqliteStore::new_with_vec`.
    ///
    /// Does nothing (and still returns `Ok`) when the current backend reports
    /// that it does not need initialising, so a later call with a different
    /// `embedding_dim` is silently ignored.
    ///
    /// # Errors
    ///
    /// Propagates the error from `SqliteStore::new_with_vec`.
    #[cfg(feature = "sqlite-store")]
    pub fn init_sqlite_vec(&self, embedding_dim: u32) -> Result<()> {
        let mut guard = self.backend.lock().unwrap();
        if guard.needs_init() {
            let store = SqliteStore::new_with_vec(self.db_path.clone(), embedding_dim)?;
            *guard = BackendState::Sqlite(Arc::new(store));
        }
        Ok(())
    }

    /// Installs a LanceDB backend rooted at `lancedb_dir`.
    ///
    /// Does nothing (and still returns `Ok`) when the current backend reports
    /// that it does not need initialising.
    ///
    /// # Errors
    ///
    /// Propagates the error from `LanceDbBackend::new`.
    #[cfg(feature = "lancedb-store")]
    pub fn init_lancedb(&self, lancedb_dir: &Path, embedding_dim: u32) -> Result<()> {
        let mut guard = self.backend.lock().unwrap();
        if guard.needs_init() {
            let backend = LanceDbBackend::new(lancedb_dir, embedding_dim)?;
            *guard = BackendState::LanceDb(Arc::new(backend));
        }
        Ok(())
    }

    /// Returns a shared handle to the active store.
    ///
    /// The backend lock is held only while the `Arc` is cloned, so the store
    /// call that follows runs without the backend mutex held.
    fn store(&self) -> Arc<dyn RetrieveStore> {
        self.backend.lock().unwrap().as_store()
    }

    /// Forwards to the active store's `upsert_document`.
    pub fn upsert_document(&self, doc: &Document) -> Result<()> {
        self.store().upsert_document(doc)
    }

    /// Forwards to the active store's `remove_document`.
    pub fn remove_document(&self, id: i64) -> Result<()> {
        self.store().remove_document(id)
    }

    /// Forwards to the active store's `rebuild_fts`.
    pub fn rebuild_fts(&self) -> Result<()> {
        self.store().rebuild_fts()
    }

    /// Forwards to the active store's `search_fts`.
    pub fn search_fts(&self, query: &str, limit: usize) -> Result<Vec<SearchResult>> {
        self.store().search_fts(query, limit)
    }

    /// Forwards to the active store's `search_similar`.
    pub fn search_similar(
        &self,
        query_vec: &[f32],
        limit: usize,
    ) -> Result<Vec<ChunkSearchResult>> {
        self.store().search_similar(query_vec, limit)
    }

    /// Collapses chunk-level hits to one result per document.
    ///
    /// For every `doc_id` the chunk with the lowest score is kept (on equal
    /// scores the first one seen wins), the survivors are sorted by ascending
    /// score and at most `limit` of them are returned. Lower scores rank
    /// first — presumably the score is a distance; confirm against the
    /// vector store.
    ///
    /// The ordering is total and deterministic: a NaN score ranks after every
    /// real score, and documents with equal scores are ordered by `doc_id`.
    pub fn dedup_chunk_results(
        results: Vec<ChunkSearchResult>,
        limit: usize,
    ) -> Vec<SearchResult> {
        // One ordering serves both the per-document pick and the final sort.
        // NaN is handled explicitly because `partial_cmp` alone is not a total
        // order, which `sort_by` requires.
        fn rank(a: &ChunkSearchResult, b: &ChunkSearchResult) -> std::cmp::Ordering {
            a.score
                .is_nan()
                .cmp(&b.score.is_nan())
                .then_with(|| {
                    a.score
                        .partial_cmp(&b.score)
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
                .then_with(|| a.doc_id.cmp(&b.doc_id))
        }

        let mut best: HashMap<i64, ChunkSearchResult> = HashMap::new();
        for candidate in results {
            match best.entry(candidate.doc_id) {
                std::collections::hash_map::Entry::Occupied(mut slot) => {
                    // Same doc_id on both sides, so only the score decides here.
                    if rank(&candidate, slot.get()) == std::cmp::Ordering::Less {
                        slot.insert(candidate);
                    }
                }
                std::collections::hash_map::Entry::Vacant(slot) => {
                    slot.insert(candidate);
                }
            }
        }

        let mut deduped: Vec<ChunkSearchResult> = best.into_values().collect();
        deduped.sort_by(rank);
        deduped.truncate(limit);
        deduped
            .into_iter()
            .map(|r| SearchResult {
                id: r.doc_id,
                title: r.doc_title,
                path: r.doc_path,
                score: r.score,
            })
            .collect()
    }

    /// Forwards to the active store's `embed_pending`, passing `on_progress`
    /// along as a `&dyn Fn(usize, usize)`.
    pub fn embed_pending(
        &self,
        embedder: &dyn Embedder,
        on_progress: impl Fn(usize, usize),
    ) -> Result<usize> {
        self.store().embed_pending(embedder, &on_progress)
    }

    /// Forwards to the active store's `vec_info`.
    pub fn vec_info(&self) -> Result<VecInfo> {
        self.store().vec_info()
    }

    /// Forwards to the active store's `document_ids`.
    pub fn document_ids(&self) -> Result<Vec<i64>> {
        self.store().document_ids()
    }

    /// Forwards to the active store's `document_count`.
    pub fn document_count(&self) -> Result<u64> {
        self.store().document_count()
    }

    /// Forwards to the active store's `file_mtimes`.
    pub fn file_mtimes(&self) -> Result<HashMap<String, i64>> {
        self.store().file_mtimes()
    }

    /// Forwards to the active store's `upsert_file`.
    pub fn upsert_file(&self, path: &str, mtime: i64) -> Result<()> {
        self.store().upsert_file(path, mtime)
    }

    /// Forwards to the active store's `remove_file`.
    pub fn remove_file(&self, path: &str) -> Result<()> {
        self.store().remove_file(path)
    }

    /// Forwards to the active store's `file_count`.
    pub fn file_count(&self) -> Result<u64> {
        self.store().file_count()
    }
}