use anyhow::{Context, Result};
use rusqlite::Connection;
use std::path::{Path, PathBuf};
use std::sync::{Mutex, OnceLock};
use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
use crate::commands::oxidize::trainer::Projection;
use crate::commands::scry::internal::enrichment::{enrich_results, SearchResults};
use crate::retrieval::oracle::{Oracle, OracleMetadata, OracleResult};
use patina::embeddings::{create_embedder, EmbeddingEngine};
/// Resources that are expensive to build and are shared by every query made
/// through one `SemanticOracle`.
///
/// Built by `SemanticOracle::init_cache` and stored in the oracle's `cache` cell.
struct SemanticCache {
/// Embedding engine used to turn query text into a vector. `query` locks this
/// mutex before calling `embed_query` on the engine.
embedder: Mutex<Box<dyn EmbeddingEngine>>,
/// Optional projection applied to the query embedding before searching;
/// `None` when no projection file exists at the oracle's `projection_path`.
projection: Option<Projection>,
/// Vector index loaded from the oracle's `index_path`.
index: Index,
}
/// Oracle that answers queries by vector search over a per-domain index and
/// then enriches the hits from the SQLite database.
pub struct SemanticOracle {
/// Location of the SQLite database opened on every query.
db_path: PathBuf,
/// Location of the `.usearch` index file for the selected domain.
index_path: PathBuf,
/// Location of the `.safetensors` projection file; the file is optional and
/// is only loaded when it exists.
projection_path: PathBuf,
/// Domain name as requested by the caller; passed to `enrich_results`.
domain: String,
/// Outcome of the one-time initialisation (`init_cache`). Both success and the
/// failure message are stored, so initialisation runs at most once.
cache: OnceLock<Result<SemanticCache, String>>,
}
impl SemanticOracle {
    /// Creates an oracle for the default `"knowledge"` domain.
    pub fn new() -> Self {
        Self::for_domain("knowledge")
    }

    /// Returns the directory that holds the per-domain index and projection
    /// files for the configured embedding model.
    ///
    /// The model name is read from the project configuration in the current
    /// directory; when that cannot be loaded, `"e5-base-v2"` is used. The
    /// returned path is relative to the current working directory.
    fn projections_dir() -> PathBuf {
        let model = patina::project::load(Path::new("."))
            .ok()
            .map(|c| c.embeddings.model)
            .unwrap_or_else(|| "e5-base-v2".to_string());
        PathBuf::from(format!(
            ".patina/local/data/embeddings/{}/projections",
            model
        ))
    }

    /// Creates an oracle for `domain`.
    ///
    /// The index and projection paths are `<projections dir>/<domain>.usearch`
    /// and `<projections dir>/<domain>.safetensors`. The `"knowledge"` domain is
    /// special-cased: when `knowledge.usearch` does not exist, both paths fall
    /// back to the `semantic.*` files instead. No files are opened here; loading
    /// is deferred until the first query.
    pub fn for_domain(domain: &str) -> Self {
        let dir = Self::projections_dir();

        // Only "knowledge" has a fallback; every other domain maps to its own name.
        let stem = if domain == "knowledge" && !dir.join("knowledge.usearch").exists() {
            "semantic"
        } else {
            domain
        };

        Self {
            db_path: PathBuf::from(".patina/local/data/patina.db"),
            index_path: dir.join(format!("{}.usearch", stem)),
            projection_path: dir.join(format!("{}.safetensors", stem)),
            domain: domain.to_string(),
            cache: OnceLock::new(),
        }
    }

    /// Lists the domains that have a `.usearch` index in the projections
    /// directory, sorted and without duplicates.
    ///
    /// The `temporal` and `dependency` indexes are never reported, and a
    /// `semantic` index is reported as `"knowledge"`. An unreadable or missing
    /// directory yields an empty list.
    pub fn available_domains() -> Vec<String> {
        // Index stems that are present on disk but are not reported as domains.
        const EXCLUDED: [&str; 2] = ["temporal", "dependency"];

        let Ok(entries) = std::fs::read_dir(Self::projections_dir()) else {
            return Vec::new();
        };

        let mut domains: Vec<String> = entries
            .flatten()
            .map(|entry| entry.path())
            .filter(|path| path.extension().and_then(|e| e.to_str()) == Some("usearch"))
            .filter_map(|path| {
                path.file_stem()
                    .and_then(|s| s.to_str())
                    .map(str::to_owned)
            })
            .filter(|stem| !EXCLUDED.contains(&stem.as_str()))
            .map(|stem| {
                if stem == "semantic" {
                    "knowledge".to_string()
                } else {
                    stem
                }
            })
            .collect();

        // `semantic` and `knowledge` can both be present and map to the same
        // name, so duplicates are removed after sorting.
        domains.sort_unstable();
        domains.dedup();
        domains
    }

    /// Builds the embedder, optional projection and vector index.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when the embedder cannot be created,
    /// the projection file exists but cannot be loaded, the index path is not
    /// valid UTF-8, or the index cannot be created or loaded.
    fn init_cache(&self) -> Result<SemanticCache, String> {
        let embedder =
            create_embedder().map_err(|e| format!("Failed to create embedder: {}", e))?;

        // The projection is optional: a missing file is not an error.
        let projection = self
            .projection_path
            .exists()
            .then(|| Projection::load_safetensors(&self.projection_path))
            .transpose()
            .map_err(|e| format!("Failed to load projection: {}", e))?;

        // NOTE(review): assumes `w2.len()` is the projection's output
        // dimension — confirm against `Projection`.
        let dimensions = match &projection {
            Some(proj) => proj.w2.len(),
            None => embedder.dimension(),
        };

        let index_options = IndexOptions {
            dimensions,
            metric: MetricKind::Cos,
            quantization: ScalarKind::F32,
            ..Default::default()
        };
        let index =
            Index::new(&index_options).map_err(|e| format!("Failed to create index: {}", e))?;

        // The index loader takes a `&str`. Report a non-UTF-8 path explicitly
        // instead of silently trying to load the empty path.
        let index_path = self.index_path.to_str().ok_or_else(|| {
            format!(
                "Index path is not valid UTF-8: {}",
                self.index_path.display()
            )
        })?;
        index
            .load(index_path)
            .map_err(|e| format!("Failed to load index: {}", e))?;

        Ok(SemanticCache {
            embedder: Mutex::new(embedder),
            projection,
            index,
        })
    }

    /// Returns the shared cache, initialising it on first use.
    ///
    /// The outcome of the first initialisation is stored, including a failure:
    /// later calls return the same error message without retrying.
    fn get_cache(&self) -> Result<&SemanticCache> {
        self.cache
            .get_or_init(|| self.init_cache())
            .as_ref()
            .map_err(|msg| anyhow::anyhow!("{}", msg))
    }
}
impl Oracle for SemanticOracle {
    /// Identifier of this oracle; also used as the `source` of every result.
    fn name(&self) -> &'static str {
        "semantic"
    }

    /// Runs a vector search for `query` and returns up to `limit` hits as
    /// `OracleResult`s.
    ///
    /// The query is embedded, passed through the projection when one is
    /// loaded, and searched against the index. The hits are then enriched
    /// from the SQLite database at `db_path`.
    ///
    /// # Errors
    ///
    /// Fails when the cache cannot be initialised, the embedder lock is
    /// poisoned, embedding or searching fails, the database cannot be opened,
    /// or enrichment fails.
    fn query(&self, query: &str, limit: usize) -> Result<Vec<OracleResult>> {
        // Indexes smaller than this are searched with `exact_search`.
        const EXACT_SEARCH_THRESHOLD: usize = 10_000;

        let cache = self.get_cache()?;

        // The lock guard is a temporary, so it is released as soon as the
        // embedding has been computed.
        let raw_embedding = cache
            .embedder
            .lock()
            .map_err(|e| anyhow::anyhow!("Embedder lock poisoned: {}", e))?
            .embed_query(query)?;

        let search_vector = if let Some(proj) = &cache.projection {
            proj.forward(&raw_embedding)
        } else {
            raw_embedding
        };

        let index = &cache.index;
        let hits = if index.size() >= EXACT_SEARCH_THRESHOLD {
            index
                .search(&search_vector, limit)
                .with_context(|| "Vector search failed")?
        } else {
            index
                .exact_search(&search_vector, limit)
                .with_context(|| "Exact vector search failed")?
        };

        let results = SearchResults {
            keys: hits.keys,
            distances: hits.distances,
        };

        let conn = Connection::open(&self.db_path)
            .with_context(|| format!("Failed to open database: {:?}", self.db_path))?;
        // NOTE(review): the trailing `0.0` is presumably a minimum-score
        // threshold — confirm against `enrich_results`.
        let enriched = enrich_results(&conn, &results, &self.domain, 0.0)?;

        let source = self.name();
        let oracle_results = enriched
            .into_iter()
            .map(|row| {
                // An empty timestamp is reported as absent.
                let timestamp = Some(row.timestamp).filter(|ts| !ts.is_empty());
                let metadata = OracleMetadata {
                    file_path: Some(row.source_id.clone()),
                    timestamp,
                    event_type: Some(row.event_type),
                    matches: None,
                };
                OracleResult {
                    doc_id: row.source_id,
                    content: row.content,
                    source,
                    score: row.score,
                    score_type: "cosine",
                    metadata,
                }
            })
            .collect();

        Ok(oracle_results)
    }

    /// Reports whether both the index file and the database file exist.
    fn is_available(&self) -> bool {
        [&self.index_path, &self.db_path]
            .iter()
            .all(|path| path.exists())
    }
}