use anyhow::{Context, Result};
use super::super::types::ChunkOutput;
use super::super::BatchContext;
use crate::cli::args::SearchArgs;
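
/// Dispatch a batch-mode `search` request and return its JSON payload.
///
/// Illustrative response shape (a sketch, not a schema guarantee; per-result
/// fields such as `name`, `score`, `chunk_type`, and `language` come from
/// `ChunkOutput`):
///
/// ```json
/// {
///   "results": [{"name": "process_data", "score": 1.0, "chunk_type": "function", "language": "rust"}],
///   "query": "process_data",
///   "total": 1
/// }
/// ```
///
/// `--ref` responses additionally carry a `"source"` field, and `--tokens`
/// responses get token-budget info injected via `inject_token_info`.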
pub(in crate::cli::batch) fn dispatch_search(
    ctx: &BatchContext,
    args: &SearchArgs,
) -> Result<serde_json::Value> {
    let _span = tracing::info_span!("batch_search", query = %args.query).entered();

    // These flags are parsed for CLI parity but intentionally unused in batch mode.
    let _ = (args.context, args.expand, args.no_stale_check);
    let _ = (args.include_docs, args.pattern.as_ref());

    // Fast path: name lookup needs no embedding, so return before loading the embedder.
    if args.name_only {
        let results = ctx
            .store()
            .search_by_name(&args.query, args.limit.clamp(1, 100))?;
        let json_results: Vec<serde_json::Value> = results
            .iter()
            .map(|r| {
                serde_json::to_value(ChunkOutput::from_search_result(r, false))
                    .unwrap_or_else(|e| {
                        tracing::warn!(error = %e, name = %r.chunk.name, "ChunkOutput serialization failed (NaN score?)");
                        serde_json::json!({"error": "serialization failed", "name": r.chunk.name})
                    })
            })
            .collect();
        return Ok(serde_json::json!({
            "results": json_results,
            "query": args.query,
            "total": json_results.len(),
        }));
    }

    let embedder = ctx.embedder()?;
    let query_embedding = embedder
        .embed_query(&args.query)
        .context("Failed to embed query")?;

    let languages = match &args.lang {
        Some(l) => Some(vec![l
            .parse()
            .map_err(|_| anyhow::anyhow!("Invalid language '{}'", l))?]),
        None => None,
    };

    let limit = args.limit.clamp(1, 100);
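
    // Over-fetch when reranking so the reranker has a candidate pool to
    // reorder: 4x the requested limit, capped at 100. For example, --limit 10
    // --rerank retrieves 40 candidates and returns the top 10 after reranking.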
    let effective_limit = if args.rerank {
        (limit * 4).min(100)
    } else {
        limit
    };

    let include_types = match &args.include_type {
        Some(types) => {
            let parsed: Result<Vec<cqs::parser::ChunkType>, _> =
                types.iter().map(|t| t.parse()).collect();
            Some(parsed.map_err(|e| anyhow::anyhow!("Invalid --include-type: {e}"))?)
        }
        None => Some(cqs::parser::ChunkType::code_types()),
    };
    let exclude_types = match &args.exclude_type {
        Some(types) => {
            let parsed: Result<Vec<cqs::parser::ChunkType>, _> =
                types.iter().map(|t| t.parse()).collect();
            Some(parsed.map_err(|e| anyhow::anyhow!("Invalid --exclude-type: {e}"))?)
        }
        None => None,
    };
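
    // The router classifies the query (e.g. identifier-like vs natural
    // language) and drives SPLADE weighting, type boosts, and index selection.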
    let classification = cqs::search::router::classify_query(&args.query);
    // SPLADE is always attempted: --splade-alpha overrides the router's
    // category-based alpha, and the --splade flag is accepted but has no
    // effect here.
    let (use_splade, splade_alpha) = match args.splade_alpha {
        Some(alpha) => (true, alpha),
        None => (
            true,
            cqs::search::router::resolve_splade_alpha(&classification.category),
        ),
    };
    let _ = args.splade;

    let use_base = matches!(
        classification.strategy,
        cqs::search::router::SearchStrategy::DenseBase
    ) || std::env::var("CQS_FORCE_BASE_INDEX").as_deref() == Ok("1");

    let filter = cqs::SearchFilter {
        languages,
        include_types,
        exclude_types,
        path_pattern: args.path.clone(),
        name_boost: args.name_boost,
        query_text: args.query.clone(),
        enable_rrf: args.rrf,
        enable_demotion: !args.no_demote,
        enable_splade: use_splade,
        splade_alpha,
        type_boost_types: classification.type_hints.clone(),
    };
    filter.validate().map_err(|e| anyhow::anyhow!(e))?;
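
    // --ref redirects the search to a prebuilt reference index and returns
    // early; reranking and output shaping mirror the main path below.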
    if let Some(ref ref_name) = args.ref_name {
        let ref_idx = crate::cli::commands::resolve::find_reference(&ctx.root, ref_name)?;
        // Same over-fetch rule as the main path.
        let ref_limit = effective_limit;
        let threshold = args.threshold;
        let mut results = cqs::reference::search_reference(
            &ref_idx,
            &query_embedding,
            &filter,
            ref_limit,
            threshold,
            false,
        )?;
        if args.rerank && results.len() > 1 {
            let reranker = ctx.reranker()?;
            reranker
                .rerank(&args.query, &mut results, limit)
                .map_err(|e| anyhow::anyhow!("Reranking failed: {e}"))?;
        }
        let show_content = !args.no_content;
        let json_results: Vec<serde_json::Value> = results
            .iter()
            .map(|r| {
                serde_json::to_value(ChunkOutput::from_search_result(r, show_content))
                    .unwrap_or_else(|e| {
                        tracing::warn!(error = %e, name = %r.chunk.name, "ChunkOutput serialization failed (NaN score?)");
                        serde_json::json!({"error": "serialization failed", "name": r.chunk.name})
                    })
            })
            .collect();
        return Ok(serde_json::json!({
            "results": json_results,
            "query": args.query,
            "total": json_results.len(),
            "source": ref_name,
        }));
    }
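
    // Encode the query for SPLADE sparse scoring; encoder failures degrade
    // gracefully to dense-only (cosine) search instead of failing the request.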
    let splade_query = if use_splade {
        ctx.splade_encoder()
            .and_then(|enc| match enc.encode(&args.query) {
                Ok(sv) => Some(sv),
                Err(e) => {
                    tracing::warn!(error = %e, "SPLADE query encoding failed, falling back to cosine-only");
                    None
                }
            })
    } else {
        None
    };
    if use_splade {
        ctx.ensure_splade_index();
    }

    let audit_mode = ctx.audit_state();
    let index = if use_base {
        match ctx.base_vector_index()? {
            Some(base_idx) => {
                tracing::info!(
                    category = %classification.category,
                    "Router selected base HNSW for non-enriched query (batch)"
                );
                Some(base_idx)
            }
            None => {
                tracing::info!("Base HNSW unavailable — falling back to enriched index (batch)");
                ctx.vector_index()?
            }
        }
    } else {
        ctx.vector_index()?
    };
    let index = index.as_deref();
    let splade_index_ref = ctx.borrow_splade_index();
    let splade_arg = splade_query
        .as_ref()
        .and_then(|sq| splade_index_ref.as_ref().map(|si| (si, sq)));
    let threshold = args.threshold;
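
    // Hybrid (dense + SPLADE) search runs when a sparse query is available or
    // audit mode is active; otherwise take the unified dense path.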
    let results = if audit_mode.is_active() || splade_arg.is_some() {
        let code_results = ctx.store().search_hybrid(
            &query_embedding,
            &filter,
            effective_limit,
            threshold,
            index,
            splade_arg,
        )?;
        code_results
            .into_iter()
            .map(cqs::store::UnifiedResult::Code)
            .collect()
    } else {
        ctx.store().search_unified_with_index(
            &query_embedding,
            &filter,
            effective_limit,
            threshold,
            index,
        )?
    };

    let results = if args.rerank && results.len() > 1 {
        let mut code_results: Vec<cqs::store::SearchResult> = results
            .into_iter()
            .map(|r| match r {
                cqs::store::UnifiedResult::Code(sr) => sr,
            })
            .collect();
        let reranker = ctx.reranker()?;
        reranker
            .rerank(&args.query, &mut code_results, limit)
            .map_err(|e| anyhow::anyhow!("Reranking failed: {e}"))?;
        code_results
            .into_iter()
            .map(cqs::store::UnifiedResult::Code)
            .collect()
    } else {
        results
    };
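
    // --include-refs fans out to configured reference indexes in parallel and
    // merges their hits with the project results under the same limit.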
    let results = if args.include_refs {
        let config = cqs::config::Config::load(&ctx.root);
        let references = cqs::reference::load_references(&config.references);
        if !references.is_empty() {
            use rayon::prelude::*;
            let ref_results: Vec<_> = references
                .par_iter()
                .filter_map(|ref_idx| {
                    match cqs::reference::search_reference(
                        ref_idx,
                        &query_embedding,
                        &filter,
                        limit,
                        threshold,
                        true,
                    ) {
                        Ok(r) if !r.is_empty() => Some((ref_idx.name.clone(), r)),
                        Err(e) => {
                            tracing::warn!(reference = %ref_idx.name, error = %e, "Reference search failed");
                            None
                        }
                        _ => None,
                    }
                })
                .collect();
            let tagged = cqs::reference::merge_results(results, ref_results, limit);
            tagged.into_iter().map(|t| t.result).collect()
        } else {
            results
        }
    } else {
        results
    };
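
    // --tokens trims the result set to fit a token budget, accounting for
    // per-result JSON overhead; the resulting token_info is attached to the
    // response below.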
    let (results, token_info) = if let Some(budget) = args.tokens {
        let embedder = ctx.embedder()?;
        crate::cli::commands::token_pack_results(
            results,
            budget,
            crate::cli::commands::JSON_OVERHEAD_PER_RESULT,
            embedder,
            |r| match r {
                cqs::store::UnifiedResult::Code(sr) => sr.chunk.content.as_str(),
            },
            |r| match r {
                cqs::store::UnifiedResult::Code(sr) => sr.score,
            },
            "batch_search",
        )
    } else {
        (results, None)
    };

    let show_content = !args.no_content;
    let json_results: Vec<serde_json::Value> = results
        .iter()
        .map(|r| match r {
            cqs::store::UnifiedResult::Code(sr) => {
                serde_json::to_value(ChunkOutput::from_search_result(sr, show_content))
                    .unwrap_or_else(|e| {
                        tracing::warn!(error = %e, name = %sr.chunk.name, "ChunkOutput serialization failed (NaN score?)");
                        serde_json::json!({"error": "serialization failed", "name": sr.chunk.name})
                    })
            }
        })
        .collect();

    let mut response = serde_json::json!({
        "results": json_results,
        "query": args.query,
        "total": json_results.len(),
    });
    crate::cli::commands::inject_token_info(&mut response, token_info);
    Ok(response)
}

#[cfg(test)]
mod tests {
    use super::*;
    use super::super::super::commands::{BatchCmd, BatchInput};
    use super::super::super::{create_test_context, BatchContext};
    use clap::Parser;
    use cqs::embedder::Embedding;
    use cqs::parser::{Chunk, ChunkType, Language};
    use cqs::store::{ModelInfo, Store};
    use std::path::PathBuf;
    use tempfile::TempDir;
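
    // Shared helpers: every chunk is stored with the same deterministic unit
    // embedding (1.0 in dimension 0, zero elsewhere), so these tests exercise
    // the name-only lookup and flag-validation paths, not semantic ranking.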
    fn make_chunk(
        id: &str,
        file: &str,
        language: Language,
        chunk_type: ChunkType,
        name: &str,
        signature: &str,
        content: &str,
    ) -> Chunk {
        let content_hash = blake3::hash(content.as_bytes()).to_hex().to_string();
        Chunk {
            id: id.to_string(),
            file: PathBuf::from(file),
            language,
            chunk_type,
            name: name.to_string(),
            signature: signature.to_string(),
            content: content.to_string(),
            doc: None,
            line_start: 1,
            line_end: 5,
            content_hash,
            parent_id: None,
            window_idx: None,
            parent_type_name: None,
        }
    }

    fn ctx_with_chunks(chunks: Vec<Chunk>) -> (TempDir, BatchContext) {
        let dir = TempDir::new().expect("Failed to create temp dir");
        let cqs_dir = dir.path().join(".cqs");
        std::fs::create_dir_all(&cqs_dir).expect("Failed to create .cqs dir");
        let index_path = cqs_dir.join("index.db");
        let mut emb_vec = vec![0.0_f32; cqs::EMBEDDING_DIM];
        emb_vec[0] = 1.0;
        let embedding = Embedding::new(emb_vec);
        {
            let store = Store::open(&index_path).expect("Failed to open test store");
            store
                .init(&ModelInfo::default())
                .expect("Failed to init test store");
            if !chunks.is_empty() {
                let pairs: Vec<(Chunk, Embedding)> = chunks
                    .iter()
                    .map(|c| (c.clone(), embedding.clone()))
                    .collect();
                store
                    .upsert_chunks_batch(&pairs, Some(0))
                    .expect("upsert_chunks_batch failed");
            }
        }
        let ctx = create_test_context(&cqs_dir).expect("Failed to create test context");
        (dir, ctx)
    }

    fn ctx_with_chunk(
        id: &str,
        file: &str,
        language: Language,
        chunk_type: ChunkType,
        name: &str,
        signature: &str,
        content: &str,
    ) -> (TempDir, BatchContext) {
        ctx_with_chunks(vec![make_chunk(
            id, file, language, chunk_type, name, signature, content,
        )])
    }

    fn empty_ctx() -> (TempDir, BatchContext) {
        ctx_with_chunks(vec![])
    }

    fn parse_search_args(cli_args: &[&str]) -> crate::cli::args::SearchArgs {
        let mut full = vec!["search"];
        full.extend_from_slice(cli_args);
        let input = BatchInput::try_parse_from(&full).expect("clap parse failed");
        match input.cmd {
            BatchCmd::Search { args } => args,
            other => panic!("Expected Search, got {:?}", other),
        }
    }

    #[test]
    fn test_dispatch_search_name_only_exact_match_top_result() {
        let (_dir, ctx) = ctx_with_chunks(vec![
            make_chunk(
                "src/lib.rs:1:aaaa0001",
                "src/lib.rs",
                Language::Rust,
                ChunkType::Function,
                "process_data",
                "fn process_data(input: &str) -> String",
                "fn process_data(input: &str) -> String { input.to_uppercase() }",
            ),
            make_chunk(
                "src/lib.rs:7:aaaa0002",
                "src/lib.rs",
                Language::Rust,
                ChunkType::Function,
                "unrelated_helper",
                "fn unrelated_helper()",
                "fn unrelated_helper() { println!(\"noop\"); }",
            ),
        ]);
        let args = parse_search_args(&["process_data", "--name-only"]);
        let json = dispatch_search(&ctx, &args).expect("dispatch_search failed");
        assert_eq!(json["query"], "process_data");
        assert_eq!(json["total"], 1, "Expected exactly 1 matching chunk");
        let results = json["results"].as_array().expect("results is array");
        assert_eq!(results.len(), 1, "results.len() must match total");
        assert_eq!(
            results[0]["name"], "process_data",
            "Top result must be the exact-name match, not '{}'",
            results[0]["name"]
        );
        let score = results[0]["score"]
            .as_f64()
            .expect("score is finite number");
        assert!(
            (score - 1.0).abs() < 1e-6,
            "Exact-name match should score 1.0, got {score}. A regression in \
             score_name_match_pre_lower or the sort in search_by_name would \
             break this."
        );
        assert_eq!(results[0]["chunk_type"], "function");
        assert_eq!(results[0]["language"], "rust");
    }

    #[test]
    fn test_dispatch_search_name_only_prefix_match_ranks_first() {
        let (_dir, ctx) = ctx_with_chunks(vec![
            make_chunk(
                "src/parse.rs:1:bbbb0001",
                "src/parse.rs",
                Language::Rust,
                ChunkType::Function,
                "parse_config",
                "fn parse_config() -> Config",
                "fn parse_config() -> Config { Config::default() }",
            ),
            make_chunk(
                "src/lib.rs:1:bbbb0002",
                "src/lib.rs",
                Language::Rust,
                ChunkType::Function,
                "do_parse_config",
                "fn do_parse_config()",
                "fn do_parse_config() { parse_config(); }",
            ),
        ]);
        let args = parse_search_args(&["parse", "--name-only"]);
        let json = dispatch_search(&ctx, &args).expect("dispatch_search failed");
        let results = json["results"].as_array().expect("results is array");
        assert!(
            !results.is_empty(),
            "Expected at least one match for 'parse' prefix, got {}",
            results.len()
        );
        assert_eq!(
            results[0]["name"], "parse_config",
            "Prefix match (0.9) must outrank substring match (0.7); got '{}' first",
            results[0]["name"]
        );
        let top_score = results[0]["score"].as_f64().unwrap();
        assert!(
            (top_score - 0.9).abs() < 1e-6,
            "Prefix match should score 0.9, got {top_score}"
        );
        if results.len() > 1 {
            assert_eq!(
                results[1]["name"], "do_parse_config",
                "Second result should be the substring match"
            );
            let second = results[1]["score"].as_f64().unwrap();
            assert!(
                second < top_score,
                "Substring (score={second}) must rank below prefix (score={top_score})"
            );
        }
    }

    #[test]
    fn test_dispatch_search_name_only_limit_clamp() {
        let chunks: Vec<Chunk> = (0..10)
            .map(|i| {
                make_chunk(
                    &format!("src/lib.rs:{i}:cccc{i:04}"),
                    "src/lib.rs",
                    Language::Rust,
                    ChunkType::Function,
                    &format!("handler_{i}"),
                    &format!("fn handler_{i}()"),
                    &format!("fn handler_{i}() {{}}"),
                )
            })
            .collect();
        let (_dir, ctx) = ctx_with_chunks(chunks);
        let default = parse_search_args(&["handler", "--name-only"]);
        let json = dispatch_search(&ctx, &default).expect("dispatch_search failed");
        let results = json["results"].as_array().unwrap();
        assert_eq!(
            results.len(),
            5,
            "Default limit=5 must bound results; got {} with total={}",
            results.len(),
            json["total"]
        );
        assert_eq!(json["total"], 5, "total must equal results.len()");
        for r in results {
            let name = r["name"].as_str().unwrap();
            assert!(
                name.starts_with("handler_"),
                "All results must be handler_* prefix matches, got '{name}'"
            );
        }
        let three = parse_search_args(&["handler", "--name-only", "--limit", "3"]);
        let json = dispatch_search(&ctx, &three).expect("dispatch_search failed");
        assert_eq!(json["total"], 3);
        assert_eq!(json["results"].as_array().unwrap().len(), 3);
    }

    #[test]
    fn test_dispatch_search_name_only_no_match_returns_empty() {
        let (_dir, ctx) = ctx_with_chunk(
            "src/lib.rs:1:dddd0001",
            "src/lib.rs",
            Language::Rust,
            ChunkType::Function,
            "alpha",
            "fn alpha()",
            "fn alpha() {}",
        );
        let args = parse_search_args(&["zxyvwu_no_such_name", "--name-only"]);
        let json = dispatch_search(&ctx, &args).expect("dispatch_search failed");
        assert_eq!(json["query"], "zxyvwu_no_such_name");
        assert_eq!(json["total"], 0, "No-match query must return total=0");
        assert_eq!(
            json["results"].as_array().unwrap().len(),
            0,
            "Empty results array must be present (callers depend on schema)"
        );
    }

    #[test]
    fn test_dispatch_search_name_only_cross_language_content() {
        let (_dir, ctx) = ctx_with_chunks(vec![
            make_chunk(
                "src/lib.rs:1:eeee0001",
                "src/lib.rs",
                Language::Rust,
                ChunkType::Function,
                "validate_input",
                "fn validate_input()",
                "fn validate_input() {}",
            ),
            make_chunk(
                "src/app.py:1:eeee0002",
                "src/app.py",
                Language::Python,
                ChunkType::Function,
                "validate_input",
                "def validate_input()",
                "def validate_input():\n pass",
            ),
        ]);
        let args = parse_search_args(&["validate_input", "--name-only"]);
        let json = dispatch_search(&ctx, &args).expect("dispatch_search failed");
        let results = json["results"].as_array().unwrap();
        assert_eq!(
            json["total"], 2,
            "Expected both Rust and Python 'validate_input' chunks"
        );
        assert_eq!(results.len(), 2);
        let languages: std::collections::HashSet<&str> = results
            .iter()
            .map(|r| r["language"].as_str().unwrap())
            .collect();
        assert!(
            languages.contains("rust"),
            "Rust result missing from {languages:?}"
        );
        assert!(
            languages.contains("python"),
            "Python result missing from {languages:?}"
        );
        for r in results {
            assert_eq!(r["name"], "validate_input");
            let score = r["score"].as_f64().unwrap();
            assert!(
                (score - 1.0).abs() < 1e-6,
                "Exact match on {} should score 1.0, got {score}",
                r["language"]
            );
        }
    }

    #[test]
    fn test_dispatch_search_invalid_include_type_errors_fast() {
        let (_dir, ctx) = empty_ctx();
        let args = parse_search_args(&["anything", "--include-type", "not_a_real_type"]);
        let err = dispatch_search(&ctx, &args)
            .expect_err("Invalid --include-type must error, not silently return all types");
        let msg = format!("{err:#}");
        assert!(
            msg.contains("Invalid --include-type"),
            "Error message must reference --include-type flag, got: {msg}"
        );
        assert!(
            msg.contains("not_a_real_type"),
            "Error must surface the offending input, got: {msg}"
        );
    }

    #[test]
    fn test_dispatch_search_invalid_exclude_type_errors_fast() {
        let (_dir, ctx) = empty_ctx();
        let args = parse_search_args(&["anything", "--exclude-type", "bogusbogus"]);
        let err = dispatch_search(&ctx, &args)
            .expect_err("Invalid --exclude-type must error, not silently accept");
        let msg = format!("{err:#}");
        assert!(
            msg.contains("Invalid --exclude-type"),
            "Error message must reference --exclude-type flag, got: {msg}"
        );
        assert!(
            msg.contains("bogusbogus"),
            "Error must surface the offending input, got: {msg}"
        );
    }
}