use rayon::prelude::*;
use crate::config::ReferenceConfig;
use crate::hnsw::HnswIndex;
use crate::index::VectorIndex;
use crate::store::{SearchFilter, SearchResult, Store, StoreError, UnifiedResult};
use crate::Embedding;
/// A loaded reference index: its store, optional vector index, and score weight.
///
/// `load_single_reference` builds this with a store opened read-only from
/// `<path>/index.db`.
pub struct ReferenceIndex {
/// Reference name (copied from its `ReferenceConfig`); used to tag merged results.
pub name: String,
/// Backing store searched by `search_reference` / `search_reference_by_name`.
pub store: Store,
/// Optional vector index passed to the store search; `None` when no index was loaded.
pub index: Option<Box<dyn VectorIndex>>,
/// Multiplier applied to this reference's scores when searching with `apply_weight = true`.
pub weight: f32,
}
/// Summarized `Debug` output: shows `name`, `weight`, and whether a vector
/// index is loaded (`has_index`) instead of dumping the store and index.
impl std::fmt::Debug for ReferenceIndex {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_index = self.index.is_some();
        let mut builder = f.debug_struct("ReferenceIndex");
        builder.field("name", &self.name);
        builder.field("weight", &self.weight);
        builder.field("has_index", &has_index);
        builder.finish()
    }
}
/// A merged search result, tagged with the index it came from.
#[derive(Debug)]
pub struct TaggedResult {
/// The underlying search result.
pub result: UnifiedResult,
/// `None` for primary-project results; `Some(name)` for results from the named reference.
pub source: Option<String>,
}
/// Opens the reference index described by `cfg`, or returns `None` if it is unusable.
///
/// The reference is skipped, with a warning, when:
/// - `cfg.path` is itself a symlink (checked without following the link), or
/// - `<cfg.path>/index.db` cannot be opened read-only.
///
/// If the canonical reference path is not under the home directory, not under
/// the current working directory, and has no `.cqs` component, a warning is
/// logged but the reference is still loaded. The vector index is loaded with
/// `HnswIndex::try_load_with_ef`; when that yields `None`, the reference is kept
/// with `index: None`.
fn load_single_reference(cfg: &ReferenceConfig) -> Option<ReferenceIndex> {
    let _span = tracing::info_span!("load_single_reference", name = %cfg.name).entered();

    // `symlink_metadata` does not follow links, so this inspects the path entry
    // itself. Errors (e.g. a missing path) count as "not a symlink"; the store
    // open below reports those instead.
    let is_symlink = cfg
        .path
        .symlink_metadata()
        .map(|m| m.is_symlink())
        .unwrap_or(false);
    if is_symlink {
        tracing::warn!(
            name = %cfg.name,
            path = %cfg.path.display(),
            "Skipping reference: path is a symlink (use the real path instead)"
        );
        return None;
    }

    if let Ok(canonical) = cfg.path.canonicalize() {
        // Compare canonical paths with canonical roots. If home or the working
        // directory is reached through a symlink (e.g. `/home -> /usr/home`), the
        // raw root never prefixes the canonical reference path and the warning
        // would fire for every reference. Fall back to the raw root on failure.
        let canonical_or_raw = |p: std::path::PathBuf| p.canonicalize().unwrap_or(p);
        let home = dirs::home_dir().map(canonical_or_raw);
        let cwd = std::env::current_dir().ok().map(canonical_or_raw);
        let in_home = home.as_ref().is_some_and(|h| canonical.starts_with(h));
        let in_project = cwd.as_ref().is_some_and(|p| canonical.starts_with(p));
        let in_cqs_dir = canonical.components().any(|c| c.as_os_str() == ".cqs");
        if !in_home && !in_project && !in_cqs_dir {
            tracing::warn!(
                name = %cfg.name,
                path = %canonical.display(),
                "Reference path is outside project and home directories"
            );
        }
    }

    let db_path = cfg.path.join("index.db");
    let store = match Store::open_readonly(&db_path) {
        Ok(s) => s,
        Err(e) => {
            tracing::warn!(
                "Skipping reference '{}': failed to open {}: {}",
                cfg.name,
                db_path.display(),
                e
            );
            return None;
        }
    };

    // A missing vector index is not fatal: `index` stays `None` and is passed
    // through to the store search as such.
    let index = HnswIndex::try_load_with_ef(&cfg.path, None, None);
    Some(ReferenceIndex {
        name: cfg.name.clone(),
        store,
        index,
        weight: cfg.weight,
    })
}
/// Loads every configured reference, skipping any that fail to load.
///
/// References load on a dedicated rayon pool. The pool size comes from the
/// `CQS_RAYON_THREADS` environment variable (default 4). An unparsable value
/// logs a warning and uses the default. If the pool cannot be built, the
/// references are loaded sequentially on the calling thread instead.
pub fn load_references(configs: &[ReferenceConfig]) -> Vec<ReferenceIndex> {
    const DEFAULT_THREADS: usize = 4;
    let _span = tracing::debug_span!("load_references", count = configs.len()).entered();

    // An unset (or non-Unicode) variable silently uses the default; a set but
    // unparsable one warns first.
    let threads = match std::env::var("CQS_RAYON_THREADS") {
        Ok(raw) => raw.parse::<usize>().unwrap_or_else(|_| {
            tracing::warn!(value = %raw, "Invalid CQS_RAYON_THREADS, using default");
            DEFAULT_THREADS
        }),
        Err(_) => DEFAULT_THREADS,
    };

    let pool = match rayon::ThreadPoolBuilder::new().num_threads(threads).build() {
        Ok(built) => built,
        Err(e) => {
            tracing::warn!(error = %e, "Failed to create reference loading thread pool, loading sequentially");
            return configs.iter().filter_map(load_single_reference).collect();
        }
    };

    let loaded: Vec<ReferenceIndex> =
        pool.install(|| configs.par_iter().filter_map(load_single_reference).collect());
    if loaded.is_empty() {
        return loaded;
    }
    tracing::info!("Loaded {} reference indexes", loaded.len());
    loaded
}
/// Runs a filtered vector search against a single reference.
///
/// The store search receives `filter`, `limit`, `threshold`, and the
/// reference's optional vector index. When `apply_weight` is true, every score
/// is multiplied by `ref_idx.weight`, and results whose weighted score falls
/// below `threshold` are dropped.
///
/// # Errors
///
/// Propagates any `StoreError` returned by the store search.
pub fn search_reference(
    ref_idx: &ReferenceIndex,
    query_embedding: &Embedding,
    filter: &SearchFilter,
    limit: usize,
    threshold: f32,
    apply_weight: bool,
) -> Result<Vec<SearchResult>, StoreError> {
    let _span =
        tracing::info_span!("search_reference", name = %ref_idx.name, weight = ref_idx.weight, apply_weight)
            .entered();
    let hits = ref_idx.store.search_filtered_with_index(
        query_embedding,
        filter,
        limit,
        threshold,
        ref_idx.index.as_deref(),
    )?;
    if !apply_weight {
        return Ok(hits);
    }
    let weight = ref_idx.weight;
    Ok(hits
        .into_iter()
        .map(|mut hit| {
            hit.score *= weight;
            hit
        })
        .filter(|hit| hit.score >= threshold)
        .collect())
}
/// Searches a single reference by symbol name.
///
/// Fetches up to `limit` matches from the store's name search. When
/// `apply_weight` is true, each score is first scaled by `ref_idx.weight`.
/// Results whose (possibly scaled) score is below `threshold` are then dropped.
///
/// # Errors
///
/// Propagates any `StoreError` returned by the store's name search.
pub fn search_reference_by_name(
    ref_idx: &ReferenceIndex,
    name: &str,
    limit: usize,
    threshold: f32,
    apply_weight: bool,
) -> Result<Vec<SearchResult>, StoreError> {
    let _span =
        tracing::info_span!("search_reference_by_name", ref_name = %ref_idx.name, query = name, apply_weight)
            .entered();
    let mut found = ref_idx.store.search_by_name(name, limit)?;
    if apply_weight {
        let weight = ref_idx.weight;
        found.iter_mut().for_each(|r| r.score *= weight);
    }
    // Scaling first means the threshold is compared against the final score.
    found.retain(|r| r.score >= threshold);
    Ok(found)
}
/// Merges primary-project results and per-reference results into one ranked list.
///
/// Primary results are tagged `source: None`. Reference results are wrapped as
/// `UnifiedResult::Code` and tagged `source: Some(<reference name>)`.
///
/// The combined list is:
/// - sorted by descending score;
/// - de-duplicated by content hash, keeping the first (highest-ranked) copy;
/// - truncated to `limit` entries.
///
/// Results with an empty `content_hash` are keyed by a BLAKE3 hash of their
/// content instead.
pub fn merge_results(
    primary: Vec<UnifiedResult>,
    refs: Vec<(String, Vec<SearchResult>)>,
    limit: usize,
) -> Vec<TaggedResult> {
    let from_primary = primary.into_iter().map(|result| TaggedResult {
        result,
        source: None,
    });
    let from_refs = refs.into_iter().flat_map(|(ref_name, hits)| {
        hits.into_iter().map(move |hit| TaggedResult {
            result: UnifiedResult::Code(hit),
            source: Some(ref_name.clone()),
        })
    });
    let mut ranked: Vec<TaggedResult> = from_primary.chain(from_refs).collect();

    // `sort_by` is stable: on equal scores, primary results stay ahead of
    // reference results (they were inserted first), so de-duplication below
    // keeps the primary copy.
    ranked.sort_by(|lhs, rhs| rhs.result.score().total_cmp(&lhs.result.score()));

    let mut seen = std::collections::HashSet::new();
    ranked.retain(|tagged| match &tagged.result {
        UnifiedResult::Code(hit) => {
            let key = if hit.chunk.content_hash.is_empty() {
                blake3::hash(hit.chunk.content.as_bytes()).to_string()
            } else {
                hit.chunk.content_hash.clone()
            };
            // `insert` returns false for an already-seen key, dropping the duplicate.
            seen.insert(key)
        }
    });
    ranked.truncate(limit);
    ranked
}
/// Returns the directory where named references are stored:
/// `<local data dir>/cqs/refs`.
///
/// Returns `None` (after logging a warning) when `dirs::data_local_dir`
/// cannot determine the platform's local data directory.
pub fn refs_dir() -> Option<std::path::PathBuf> {
    let Some(base) = dirs::data_local_dir() else {
        tracing::warn!("Could not determine local data directory for reference storage");
        return None;
    };
    // Join components one at a time so the platform's native separator is used;
    // a literal "cqs/refs" would leave a '/' embedded in Windows paths.
    Some(base.join("cqs").join("refs"))
}
/// Checks that `name` is safe to use as a single directory name under the refs directory.
///
/// # Errors
///
/// Returns a static message describing the first rule the name violates:
/// - it is empty;
/// - it contains a NUL byte;
/// - it contains `/`, `\`, or `..` (path separators or traversal);
/// - it contains `:`. On Windows, `C:foo` is a drive-relative path that
///   `Path::join` substitutes for the whole base path, which would escape the
///   refs directory. `:` is also not allowed in Windows file names.
/// - it is exactly `.`, or it starts with `.`.
pub fn validate_ref_name(name: &str) -> Result<(), &'static str> {
    if name.is_empty() {
        return Err("Reference name cannot be empty");
    }
    if name.contains('\0') {
        return Err("Reference name cannot contain null bytes");
    }
    if name.contains('/') || name.contains('\\') || name.contains("..") {
        return Err("Reference name cannot contain '/', '\\', or '..'");
    }
    if name.contains(':') {
        return Err("Reference name cannot contain ':'");
    }
    if name == "." {
        return Err("Reference name cannot be '.'");
    }
    if name.starts_with('.') {
        return Err("Reference name cannot start with '.'");
    }
    Ok(())
}
/// Returns the storage path for the reference called `name`.
///
/// Returns `None` if `name` fails `validate_ref_name`, or if `refs_dir` cannot
/// determine the refs directory. The name is validated before the directory
/// is looked up.
pub fn ref_path(name: &str) -> Option<std::path::PathBuf> {
    if validate_ref_name(name).is_err() {
        return None;
    }
    let root = refs_dir()?;
    Some(root.join(name))
}
// Unit tests for result merging, reference-name validation, and reference loading.
#[cfg(test)]
mod tests {
use super::*;
use crate::store::ChunkSummary;
// Builds a code `SearchResult` for a function called `name` with the given score.
// `content_hash` is left empty, so `merge_results` de-duplicates these by hashing
// `content`, which is distinct per name ("fn <name>() {}").
fn make_code_result(name: &str, score: f32) -> SearchResult {
SearchResult {
chunk: ChunkSummary {
id: format!("id-{}", name),
file: std::path::PathBuf::from(format!("src/{}.rs", name)),
language: crate::parser::Language::Rust,
chunk_type: crate::parser::ChunkType::Function,
name: name.to_string(),
signature: String::new(),
content: format!("fn {}() {{}}", name),
doc: None,
line_start: 1,
line_end: 1,
parent_id: None,
parent_type_name: None,
content_hash: String::new(),
window_idx: None,
},
score,
}
}
// With no reference results, the primary result passes through untagged.
#[test]
fn test_merge_results_empty_refs() {
let primary = vec![UnifiedResult::Code(make_code_result("foo", 0.9))];
let refs: Vec<(String, Vec<SearchResult>)> = vec![];
let merged = merge_results(primary, refs, 10);
assert_eq!(merged.len(), 1);
assert!(merged[0].source.is_none());
}
// Reference results are tagged with the name of the reference they came from.
#[test]
fn test_merge_results_only_refs() {
let primary: Vec<UnifiedResult> = vec![];
let refs = vec![("tokio".to_string(), vec![make_code_result("spawn", 0.8)])];
let merged = merge_results(primary, refs, 10);
assert_eq!(merged.len(), 1);
assert_eq!(merged[0].source.as_deref(), Some("tokio"));
}
// Primary and reference results are interleaved in non-increasing score order.
#[test]
fn test_merge_results_sorted_by_score() {
let primary = vec![
UnifiedResult::Code(make_code_result("primary_low", 0.5)),
UnifiedResult::Code(make_code_result("primary_high", 0.95)),
];
let refs = vec![(
"tokio".to_string(),
vec![
make_code_result("ref_mid", 0.7),
make_code_result("ref_high", 0.9),
],
)];
let merged = merge_results(primary, refs, 10);
assert_eq!(merged.len(), 4);
assert!(merged[0].result.score() >= merged[1].result.score());
assert!(merged[1].result.score() >= merged[2].result.score());
assert!(merged[2].result.score() >= merged[3].result.score());
}
// Truncation happens after sorting: with limit 2 the top entry is "a" (0.9),
// ranked above the reference result "d" (0.85).
#[test]
fn test_merge_results_truncates_to_limit() {
let primary = vec![
UnifiedResult::Code(make_code_result("a", 0.9)),
UnifiedResult::Code(make_code_result("b", 0.8)),
UnifiedResult::Code(make_code_result("c", 0.7)),
];
let refs = vec![("tokio".to_string(), vec![make_code_result("d", 0.85)])];
let merged = merge_results(primary, refs, 2);
assert_eq!(merged.len(), 2);
assert!(merged[0].result.score() > 0.85);
}
// NOTE(review): `merge_results` applies no weights itself. The 0.72 reference
// score presumably stands in for an already-weighted score (e.g. 0.9 * 0.8);
// this test only checks that it ranks below the 0.8 primary result.
#[test]
fn test_merge_results_weight_applied() {
let primary = vec![UnifiedResult::Code(make_code_result("project_fn", 0.8))];
let refs = vec![(
"tokio".to_string(),
vec![make_code_result("ref_fn", 0.72)], )];
let merged = merge_results(primary, refs, 10);
assert_eq!(merged.len(), 2);
assert!(merged[0].source.is_none());
assert_eq!(merged[1].source.as_deref(), Some("tokio"));
}
// Tags survive sorting across multiple references: primary (0.9), then
// "tokio" (0.8), then "serde" (0.7).
#[test]
fn test_tagged_result_source_values() {
let primary = vec![UnifiedResult::Code(make_code_result("a", 0.9))];
let refs = vec![
("tokio".to_string(), vec![make_code_result("b", 0.8)]),
("serde".to_string(), vec![make_code_result("c", 0.7)]),
];
let merged = merge_results(primary, refs, 10);
assert!(merged[0].source.is_none()); assert_eq!(merged[1].source.as_deref(), Some("tokio"));
assert_eq!(merged[2].source.as_deref(), Some("serde"));
}
// A reference whose index.db cannot be opened is skipped rather than failing the load.
#[test]
fn test_load_references_skips_missing_path() {
let configs = vec![ReferenceConfig {
name: "nonexistent".into(),
path: "/tmp/cqs_test_nonexistent_ref_path_12345".into(),
source: None,
weight: 0.8,
}];
let refs = load_references(&configs);
assert!(refs.is_empty());
}
// The assertion only runs when a local data directory is available;
// otherwise `ref_path` returns None and the test passes vacuously.
#[test]
fn test_ref_path_helper() {
if let Some(path) = ref_path("tokio") {
assert!(path.ends_with("cqs/refs/tokio"));
}
}
// Separators, traversal, '.', empty names, and NUL bytes are all rejected.
#[test]
fn test_validate_ref_name_rejects_traversal() {
assert!(validate_ref_name("../etc").is_err());
assert!(validate_ref_name("foo/bar").is_err());
assert!(validate_ref_name("foo\\bar").is_err());
assert!(validate_ref_name("..").is_err());
assert!(validate_ref_name(".").is_err());
assert!(validate_ref_name("").is_err());
assert!(validate_ref_name("foo\0bar").is_err());
}
// Ordinary names with letters, digits, '-' and '_' are accepted.
#[test]
fn test_validate_ref_name_accepts_valid() {
assert!(validate_ref_name("tokio").is_ok());
assert!(validate_ref_name("my-ref").is_ok());
assert!(validate_ref_name("ref_v2").is_ok());
}
// Identical content with empty content_hash in the primary set and a reference
// collapses to one entry: the higher-scored primary copy (0.9) is kept.
#[test]
fn test_merge_deduplicates_by_content() {
let primary = vec![UnifiedResult::Code(SearchResult {
chunk: ChunkSummary {
id: "primary-id".to_string(),
file: std::path::PathBuf::from("src/foo.rs"),
language: crate::parser::Language::Rust,
chunk_type: crate::parser::ChunkType::Function,
name: "foo".to_string(),
signature: String::new(),
content: "fn foo() {}".to_string(), doc: None,
line_start: 1,
line_end: 1,
parent_id: None,
parent_type_name: None,
content_hash: String::new(),
window_idx: None,
},
score: 0.9,
})];
let refs = vec![(
"ref1".to_string(),
vec![SearchResult {
chunk: ChunkSummary {
id: "ref-id".to_string(),
file: std::path::PathBuf::from("src/foo.rs"),
language: crate::parser::Language::Rust,
chunk_type: crate::parser::ChunkType::Function,
name: "foo".to_string(),
signature: String::new(),
content: "fn foo() {}".to_string(), doc: None,
line_start: 1,
line_end: 1,
parent_id: None,
parent_type_name: None,
content_hash: String::new(),
window_idx: None,
},
score: 0.7,
}],
)];
let merged = merge_results(primary, refs, 10);
assert_eq!(merged.len(), 1);
assert!(merged[0].source.is_none());
assert!((merged[0].result.score() - 0.9).abs() < 0.01);
}
// `ref_path` returns None for names that `validate_ref_name` rejects.
#[test]
fn test_ref_path_rejects_traversal() {
assert!(ref_path("../etc").is_none());
assert!(ref_path("foo/bar").is_none());
}
}