use anyhow::{anyhow, bail, Context, Result};
use clap::Args;
use futures::future::join_all;
use qdrant_client::{
qdrant::{SearchPointsBuilder, Filter, Condition, PointStruct, ScoredPoint},
Qdrant,
};
use std::{
sync::Arc,
path::PathBuf,
};
use crate::{
config::AppConfig,
vectordb::{embedding, embedding_logic::EmbeddingHandler},
cli::repo_commands::get_collection_name,
cli::formatters::print_search_results,
cli::commands::{FIELD_LANGUAGE, FIELD_ELEMENT_TYPE, FIELD_BRANCH},
};
use super::commands::CliArgs;
// CLI arguments for the `query` subcommand; consumed by `handle_query` below.
// NOTE: plain `//` comments are used on purpose — clap's derive turns `///` doc
// comments into `--help` text, so adding those would change the CLI's output.
#[derive(Args, Debug)]
pub struct QueryArgs {
    // Search text; it is embedded and the resulting vector is used for the search.
    #[arg(required = true)]
    pub query: String,
    // Maximum number of hits. `handle_query` passes it to each per-collection
    // search and also truncates the merged, score-sorted list to this length.
    #[arg(short, long, default_value_t = 10)]
    pub limit: u64,
    // Repository names to search; each must exist in the configuration.
    // Mutually exclusive with `--all-repos`.
    #[arg(short, long, conflicts_with = "all_repos")]
    pub repo: Option<Vec<String>>,
    // Search every repository listed in the configuration.
    #[arg(long)]
    pub all_repos: bool,
    // Only return points whose branch payload field (FIELD_BRANCH) matches this value.
    #[arg(short, long)]
    pub branch: Option<String>,
    // Only return points whose language payload field (FIELD_LANGUAGE) matches this value.
    #[arg(long)]
    pub lang: Option<String>,
    // Only return points whose element-type payload field (FIELD_ELEMENT_TYPE) matches.
    // Exposed as `--type`; the field is named `element_type` because `type` is a Rust keyword.
    #[arg(long = "type")]
    pub element_type: Option<String>,
}
pub async fn handle_query(
args: &QueryArgs,
cli_args: &CliArgs,
config: AppConfig,
client: Arc<Qdrant>,
) -> Result<()> {
log::info!("Starting query process...");
let target_repos: Vec<String> = match (&args.repo, args.all_repos) {
(Some(repo_names), _) => { for name in repo_names {
if !config.repositories.iter().any(|r| r.name == *name) {
bail!("Repository '{}' not found in configuration.", name);
}
}
repo_names.clone()
}
(None, true) => { config.repositories.iter().map(|r| r.name.clone()).collect()
}
(None, false) => { vec![config.active_repository.clone().ok_or_else(|| {
anyhow!("No active repository set and no specific repository requested via --repo or --all-repos. Use 'repo use <name>' first.")
})?]
}
};
if target_repos.is_empty() {
println!("No repositories configured or specified to search.");
return Ok(());
}
log::info!("Target repositories: {:?}", target_repos);
let collection_names: Vec<String> = target_repos.iter().map(|name| get_collection_name(name)).collect();
let model_env_var = std::env::var("VECTORDB_ONNX_MODEL").ok();
let tokenizer_env_var = std::env::var("VECTORDB_ONNX_TOKENIZER_DIR").ok();
let onnx_model_path_str = cli_args.onnx_model_path_arg.as_ref()
.or(model_env_var.as_ref())
.or(config.onnx_model_path.as_ref())
.ok_or_else(|| anyhow!("ONNX model path must be provided via --onnx-model, VECTORDB_ONNX_MODEL, or config"))?;
let onnx_tokenizer_dir_str = cli_args.onnx_tokenizer_dir_arg.as_ref()
.or(tokenizer_env_var.as_ref())
.or(config.onnx_tokenizer_path.as_ref())
.ok_or_else(|| anyhow!("ONNX tokenizer path must be provided via --onnx-tokenizer-dir, VECTORDB_ONNX_TOKENIZER_DIR, or config"))?;
let onnx_model_path = PathBuf::from(onnx_model_path_str);
let onnx_tokenizer_path = PathBuf::from(onnx_tokenizer_dir_str);
let embedding_handler = EmbeddingHandler::new(
embedding::EmbeddingModelType::Onnx,
Some(onnx_model_path),
Some(onnx_tokenizer_path),
)
.context("Failed to initialize embedding handler")?;
let query_embedding = embedding_handler.create_embedding_model()?
.embed(&args.query)?;
log::info!("Query embedding generated.");
let mut filter_conditions = Vec::new();
if let Some(branch_name) = &args.branch {
if !target_repos.is_empty() {
filter_conditions.push(Condition::matches(FIELD_BRANCH, branch_name.clone()));
log::info!("Filtering by branch: {}", branch_name);
} else {
log::warn!("Branch filter specified but no repository target found (this shouldn't happen). Ignoring filter.");
}
}
if let Some(lang_name) = &args.lang {
filter_conditions.push(Condition::matches(FIELD_LANGUAGE, lang_name.clone()));
log::info!("Filtering by language: {}", lang_name);
}
if let Some(element_type) = &args.element_type {
filter_conditions.push(Condition::matches(FIELD_ELEMENT_TYPE, element_type.clone()));
log::info!("Filtering by element type: {}", element_type);
}
let search_filter = if filter_conditions.is_empty() { None } else { Some(Filter::must(filter_conditions)) };
log::info!("Executing search against collections: {:?}...", collection_names);
let search_futures: Vec<_> = collection_names.into_iter().map(|collection_name| {
let client = Arc::clone(&client);
let query_embedding_clone = query_embedding.clone();
let search_filter_clone = search_filter.clone();
let limit = args.limit;
tokio::spawn(async move {
let mut builder = SearchPointsBuilder::new(&collection_name, query_embedding_clone, limit)
.with_payload(true);
if let Some(filter) = search_filter_clone {
builder = builder.filter(filter);
}
let search_request = builder.build();
client.search_points(search_request).await
})
}).collect();
let search_results = join_all(search_futures).await;
let mut all_scored_points = Vec::new();
let mut errors = Vec::new();
for (i, result) in search_results.into_iter().enumerate() {
match result {
Ok(Ok(search_response)) => {
log::debug!("Search returned {} results from collection {}", search_response.result.len(), target_repos[i]);
all_scored_points.extend(search_response.result);
}
Ok(Err(e)) => {
let err_msg = format!("Qdrant search failed for repo '{}': {}", target_repos[i], e);
log::error!("{}", err_msg);
errors.push(err_msg);
}
Err(e) => { let err_msg = format!("Task panicked for repo '{}': {}", target_repos[i], e);
log::error!("{}", err_msg);
errors.push(err_msg);
}
}
}
if !errors.is_empty() {
eprintln!("Warning: Some searches failed:\n - {}", errors.join("\n - "));
}
all_scored_points.sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
all_scored_points.truncate(args.limit as usize);
log::info!("Total unique results after aggregation: {}", all_scored_points.len());
print_search_results(&all_scored_points, &args.query)?;
Ok(())
}