use std::path::{Path, PathBuf};
use std::sync::OnceLock;
use anyhow::Result;
use super::config::find_project_root;
use super::definitions;
/// Locates the project's index database and opens it with the supplied `opener`.
///
/// The project root comes from `find_project_root`, the index directory from
/// `cqs::resolve_index_dir`, and the database is expected at `index.db` inside
/// that directory.
///
/// # Returns
///
/// `(store, project_root, index_dir)` on success.
///
/// # Errors
///
/// * the `index.db` file does not exist, or
/// * `opener` fails; its error is rendered into the message together with the path.
fn open_store_with(
    opener: fn(&Path) -> std::result::Result<cqs::Store, cqs::store::StoreError>,
) -> Result<(cqs::Store, PathBuf, PathBuf)> {
    let root = find_project_root();
    let cqs_dir = cqs::resolve_index_dir(&root);
    let index_path = cqs_dir.join("index.db");

    // Check for the file up front so the user gets an actionable hint rather than
    // whatever error the opener would produce for a missing database.
    if !index_path.exists() {
        return Err(anyhow::anyhow!(
            "Index not found. Run 'cqs init && cqs index' first."
        ));
    }

    match opener(&index_path) {
        Ok(store) => Ok((store, root, cqs_dir)),
        Err(e) => Err(anyhow::anyhow!(
            "Failed to open index at {}: {}",
            index_path.display(),
            e
        )),
    }
}
/// Opens the project index using `cqs::Store::open`.
///
/// Returns the store together with the project root and the resolved index
/// directory, as produced by `open_store_with`, and fails with the same errors
/// (index file missing, or the store could not be opened).
/// `CommandContext::open_readwrite` is built on top of this function.
pub(crate) fn open_project_store() -> Result<(cqs::Store, PathBuf, PathBuf)> {
open_store_with(cqs::Store::open)
}
/// Opens the project index using `cqs::Store::open_readonly_pooled`.
///
/// Returns the store together with the project root and the resolved index
/// directory, as produced by `open_store_with`, and fails with the same errors
/// (index file missing, or the store could not be opened).
/// `CommandContext::open_readonly` is built on top of this function.
pub(crate) fn open_project_store_readonly() -> Result<(cqs::Store, PathBuf, PathBuf)> {
open_store_with(cqs::Store::open_readonly_pooled)
}
/// State shared by a single command invocation: the parsed CLI arguments, the
/// opened store and its locations, plus resources that are only created when a
/// command first asks for them.
pub(crate) struct CommandContext<'a> {
/// Parsed command-line arguments, borrowed for the lifetime of the context.
pub cli: &'a definitions::Cli,
/// The opened index store.
pub store: cqs::Store,
/// Project root the store was located from.
pub root: PathBuf,
/// Index directory resolved from `root`; holds `index.db` and the on-disk indexes.
pub cqs_dir: PathBuf,
// Lazily created by `reranker()`; empty until the first successful call.
reranker: OnceLock<cqs::Reranker>,
// Lazily created by `embedder()`; empty until the first successful call.
embedder: OnceLock<cqs::Embedder>,
// Lazily resolved by `splade_encoder()`; an inner `None` records that the
// encoder is unavailable so the lookup is not repeated.
splade_encoder: OnceLock<Option<cqs::splade::SpladeEncoder>>,
// Lazily resolved by `splade_index()`; an inner `None` records that no usable
// SPLADE index exists so the lookup is not repeated.
splade_index: OnceLock<Option<cqs::splade::index::SpladeIndex>>,
}
impl<'a> CommandContext<'a> {
pub fn open_readonly(cli: &'a definitions::Cli) -> Result<Self> {
let (store, root, cqs_dir) = open_project_store_readonly()?;
Ok(Self {
cli,
store,
root,
cqs_dir,
reranker: OnceLock::new(),
embedder: OnceLock::new(),
splade_encoder: OnceLock::new(),
splade_index: OnceLock::new(),
})
}
pub fn open_readwrite(cli: &'a definitions::Cli) -> Result<Self> {
let _span = tracing::info_span!("CommandContext::open_readwrite").entered();
let (store, root, cqs_dir) = open_project_store()?;
Ok(Self {
cli,
store,
root,
cqs_dir,
reranker: OnceLock::new(),
embedder: OnceLock::new(),
splade_encoder: OnceLock::new(),
splade_index: OnceLock::new(),
})
}
#[allow(deprecated)]
pub fn model_config(&self) -> &cqs::embedder::ModelConfig {
self.cli.model_config()
}
pub fn reranker(&self) -> Result<&cqs::Reranker> {
if let Some(r) = self.reranker.get() {
return Ok(r);
}
let _span = tracing::info_span!("command_context_reranker_init").entered();
let r = cqs::Reranker::new().map_err(|e| anyhow::anyhow!("Reranker init failed: {e}"))?;
let _ = self.reranker.set(r);
Ok(self
.reranker
.get()
.expect("reranker OnceLock populated by set() above"))
}
pub fn embedder(&self) -> Result<&cqs::Embedder> {
if let Some(e) = self.embedder.get() {
return Ok(e);
}
let _span = tracing::info_span!("command_context_embedder_init").entered();
let e = cqs::Embedder::new(self.model_config().clone())
.map_err(|e| anyhow::anyhow!("Embedder init failed: {e}"))?;
let _ = self.embedder.set(e);
Ok(self
.embedder
.get()
.expect("embedder OnceLock populated by set() above"))
}
pub fn splade_encoder(&self) -> Option<&cqs::splade::SpladeEncoder> {
let opt = self.splade_encoder.get_or_init(|| {
let _span = tracing::debug_span!("command_context_splade_encoder_init").entered();
let model_dir = cqs::splade::resolve_splade_model_dir()?;
match cqs::splade::SpladeEncoder::new(
&model_dir,
cqs::splade::SpladeEncoder::default_threshold(),
) {
Ok(enc) => Some(enc),
Err(e) => {
tracing::warn!(
path = %model_dir.display(),
error = %e,
"Failed to load SPLADE encoder"
);
None
}
}
});
opt.as_ref()
}
pub fn splade_index(&self) -> Option<&cqs::splade::index::SpladeIndex> {
let opt = self.splade_index.get_or_init(|| {
let _span = tracing::debug_span!("command_context_splade_index_init").entered();
let generation = match self.store.splade_generation() {
Ok(g) => g,
Err(e) => {
tracing::warn!(
error = %e,
"Failed to read splade_generation — skipping SPLADE entirely for this \
invocation; search will fall back to dense-only"
);
return None;
}
};
let splade_path = self.cqs_dir.join(cqs::splade::index::SPLADE_INDEX_FILENAME);
let store = &self.store;
let (idx, rebuilt) =
cqs::splade::index::SpladeIndex::load_or_build(&splade_path, generation, || {
match store.load_all_sparse_vectors() {
Ok(v) => v,
Err(e) => {
tracing::warn!(error = %e, "Failed to load sparse vectors");
Vec::new()
}
}
});
if idx.is_empty() {
tracing::debug!("No sparse vectors in store, SPLADE index unavailable");
return None;
}
tracing::info!(
chunks = idx.len(),
tokens = idx.unique_tokens(),
rebuilt,
"SPLADE index ready"
);
Some(idx)
});
opt.as_ref()
}
}
/// Builds or loads the vector index for `store` with no `ef_search` override.
///
/// Thin wrapper around `build_vector_index_with_config` that passes `None` for
/// `ef_search`; see that function for the selection and fallback rules.
pub(crate) fn build_vector_index(
store: &cqs::Store,
cqs_dir: &Path,
) -> Result<Option<Box<dyn cqs::index::VectorIndex>>> {
build_vector_index_with_config(store, cqs_dir, None)
}
pub(crate) fn build_vector_index_with_config(
store: &cqs::Store,
cqs_dir: &Path,
ef_search: Option<usize>,
) -> Result<Option<Box<dyn cqs::index::VectorIndex>>> {
let _span = tracing::info_span!("build_vector_index_with_config").entered();
let _ = store; #[cfg(feature = "gpu-index")]
{
let cagra_threshold: u64 = std::env::var("CQS_CAGRA_THRESHOLD")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(5000);
let chunk_count = store.chunk_count().unwrap_or_else(|e| {
tracing::warn!(error = %e, "Failed to get chunk count for CAGRA threshold check");
0
});
if chunk_count >= cagra_threshold && cqs::cagra::CagraIndex::gpu_available() {
match cqs::cagra::CagraIndex::build_from_store(store, store.dim()) {
Ok(idx) => {
tracing::info!("Using CAGRA GPU index ({} vectors)", idx.len());
return Ok(Some(Box::new(idx) as Box<dyn cqs::index::VectorIndex>));
}
Err(e) => {
tracing::warn!(error = %e, "Failed to build CAGRA index, falling back to HNSW");
}
}
} else if chunk_count < cagra_threshold {
tracing::debug!(
"Index too small for CAGRA ({} < {}), using HNSW",
chunk_count,
cagra_threshold
);
} else {
tracing::debug!("GPU not available, using HNSW");
}
}
if store.is_hnsw_dirty().unwrap_or(true) {
match cqs::hnsw::verify_hnsw_checksums(cqs_dir, "index") {
Ok(()) => {
tracing::info!(
"HNSW dirty flag set but checksums pass — clearing flag (self-heal)"
);
if let Err(e) = store.set_hnsw_dirty(false) {
tracing::warn!(error = %e, "Failed to clear dirty flag");
}
}
Err(e) => {
tracing::warn!(
error = %e,
"HNSW index stale (checksum mismatch). \
Falling back to brute-force search. Run 'cqs index' to rebuild."
);
return Ok(None);
}
}
}
Ok(cqs::HnswIndex::try_load_with_ef(
cqs_dir,
ef_search,
Some(store.dim()),
))
}
/// Loads the base HNSW index (checksum set `index_base`) from `cqs_dir`.
///
/// Setting the environment variable `CQS_DISABLE_BASE_INDEX` to exactly `1`
/// bypasses the base index: the function returns `Ok(None)` without touching the
/// files. Any other value (including `true`, ` 1`, or the empty string) leaves the
/// normal load path active.
///
/// If the store's HNSW dirty flag is set (or cannot be read), the `index_base`
/// checksums are verified first: on success the flag is cleared, on mismatch
/// `Ok(None)` is returned.
///
/// # Returns
///
/// `Ok(Some(index))` when the base index loaded, `Ok(None)` when it is bypassed,
/// stale, or could not be loaded.
pub(crate) fn build_base_vector_index(
    store: &cqs::Store,
    cqs_dir: &Path,
) -> Result<Option<Box<dyn cqs::index::VectorIndex>>> {
    let _span = tracing::info_span!("build_base_vector_index").entered();

    // Strict match on "1": anything else must not trigger the bypass.
    if std::env::var("CQS_DISABLE_BASE_INDEX").as_deref() == Ok("1") {
        tracing::info!("CQS_DISABLE_BASE_INDEX=1 — base index bypass active");
        return Ok(None);
    }

    // An unreadable flag is treated as dirty (the conservative choice), but the
    // underlying error is surfaced instead of being silently discarded.
    let hnsw_dirty = store.is_hnsw_dirty().unwrap_or_else(|e| {
        tracing::warn!(
            error = %e,
            "Failed to read HNSW dirty flag, treating base index as dirty"
        );
        true
    });
    if hnsw_dirty {
        match cqs::hnsw::verify_hnsw_checksums(cqs_dir, "index_base") {
            Ok(()) => {
                tracing::info!(
                    "Base HNSW dirty flag set but checksums pass — clearing flag (self-heal)"
                );
                // Best effort: failing to clear the flag only means the checksums
                // are verified again next time.
                if let Err(e) = store.set_hnsw_dirty(false) {
                    tracing::warn!(error = %e, "Failed to clear dirty flag");
                }
            }
            Err(e) => {
                tracing::warn!(
                    error = %e,
                    "Base HNSW index stale (checksum mismatch) — router falls back to enriched"
                );
                return Ok(None);
            }
        }
    }

    Ok(cqs::HnswIndex::try_load_base_with_ef(
        cqs_dir,
        None,
        Some(store.dim()),
    ))
}
/// Tests for the `CQS_DISABLE_BASE_INDEX` bypass in `build_base_vector_index`.
#[cfg(test)]
mod base_index_tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Environment variable under test.
    const DISABLE_VAR: &str = "CQS_DISABLE_BASE_INDEX";

    /// Serializes the tests in this module, which all mutate the process-wide
    /// environment.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds `ENV_LOCK` and guarantees `DISABLE_VAR` is unset both when a test
    /// starts and when it ends — including when it ends by panicking, so a failed
    /// assertion cannot leak the variable into other tests.
    struct EnvGuard {
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        fn acquire() -> Self {
            // A panic in a sibling test poisons the mutex; the protected data is
            // `()`, so recovering the guard is always sound and avoids a cascade
            // of unrelated failures.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            std::env::remove_var(DISABLE_VAR);
            Self { _lock: lock }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Runs before the `_lock` field is released, so the cleanup is still
            // serialized with the other tests.
            std::env::remove_var(DISABLE_VAR);
        }
    }

    /// Builds a unit-length embedding of `dim` components, all derived from `seed`.
    ///
    /// Every positive seed normalizes to the same vector (up to rounding); the
    /// tests below only depend on how many vectors are stored, not their values.
    fn make_embedding(seed: f32, dim: usize) -> cqs::embedder::Embedding {
        let mut v = vec![seed; dim];
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in &mut v {
                *x /= norm;
            }
        }
        cqs::embedder::Embedding::new(v)
    }

    /// Creates a temp directory holding an initialized, non-dirty store plus a
    /// saved `index_base` HNSW index with `vector_count` vectors whose ids are
    /// `{id_prefix}{i}`.
    ///
    /// The `TempDir` is returned so the files outlive the call.
    fn store_with_base_index(
        vector_count: usize,
        id_prefix: &str,
    ) -> (tempfile::TempDir, cqs::Store) {
        let dir = tempfile::TempDir::new().unwrap();
        let db_path = dir.path().join("index.db");
        let store = cqs::Store::open(&db_path).unwrap();
        store.init(&cqs::store::ModelInfo::default()).unwrap();
        store.set_hnsw_dirty(false).unwrap();
        let dim = store.dim();
        let embeddings: Vec<(String, cqs::embedder::Embedding)> = (0..vector_count)
            .map(|i| (format!("{id_prefix}{i}"), make_embedding(i as f32 + 0.1, dim)))
            .collect();
        let index = cqs::HnswIndex::build_with_dim(embeddings, dim).unwrap();
        index.save(dir.path(), "index_base").unwrap();
        (dir, store)
    }

    /// With base files present and a clean store, the index loads normally, is
    /// bypassed while the variable is `1`, and loads again once it is unset.
    #[test]
    fn test_disable_base_index_env_short_circuits_with_files_present() {
        let _env = EnvGuard::acquire();
        let (dir, store) = store_with_base_index(10, "vec");

        std::env::remove_var(DISABLE_VAR);
        let loaded = build_base_vector_index(&store, dir.path()).unwrap();
        assert!(
            loaded.is_some(),
            "without bypass, base files present + store clean → should load"
        );
        assert_eq!(loaded.unwrap().len(), 10);

        std::env::set_var(DISABLE_VAR, "1");
        let bypassed = build_base_vector_index(&store, dir.path()).unwrap();
        assert!(
            bypassed.is_none(),
            "with CQS_DISABLE_BASE_INDEX=1, base files exist + store clean \
             → must return None (this is the load-bearing A/B-eval behavior)"
        );

        std::env::remove_var(DISABLE_VAR);
        let after_unset = build_base_vector_index(&store, dir.path()).unwrap();
        assert!(
            after_unset.is_some(),
            "after env var unset, normal load path should resume"
        );
    }

    /// Only the exact value `1` activates the bypass; look-alikes must not.
    #[test]
    fn test_disable_base_index_env_strict_value_match() {
        let _env = EnvGuard::acquire();
        let (dir, store) = store_with_base_index(5, "v");

        for non_one in ["", "0", "true", "yes", "on", "TRUE", " 1", "1 ", "false"] {
            std::env::set_var(DISABLE_VAR, non_one);
            let result = build_base_vector_index(&store, dir.path()).unwrap();
            assert!(
                result.is_some(),
                "CQS_DISABLE_BASE_INDEX={non_one:?} must NOT activate bypass"
            );
        }
        std::env::remove_var(DISABLE_VAR);
    }
}