use anyhow::{bail, Context, Result};
use clap::Args;
use indicatif::{ProgressBar, ProgressStyle};
use qdrant_client::{qdrant::{PointStruct, VectorParamsBuilder, CreateCollectionBuilder, Distance, FieldType}, Payload, Qdrant};
use std::{
collections::HashSet,
path::PathBuf,
sync::Arc,
time::Duration,
};
use uuid::Uuid;
use walkdir::WalkDir;
use git2::Repository;
use crate::{
config::AppConfig,
syntax,
vectordb::{embedding, embedding_logic::EmbeddingHandler},
};
use super::commands::{
upsert_batch, BATCH_SIZE, CliArgs, FIELD_CHUNK_CONTENT, FIELD_ELEMENT_TYPE,
FIELD_END_LINE, FIELD_FILE_EXTENSION, FIELD_FILE_PATH, FIELD_LANGUAGE, FIELD_START_LINE,
ensure_payload_index,
};
use super::repo_commands::{get_collection_name, FIELD_BRANCH, FIELD_COMMIT_HASH, DEFAULT_VECTOR_DIMENSION};
/// Name of the Qdrant collection that the legacy `handle_index` command writes
/// into (it is created on demand by `ensure_legacy_collection_exists`).
const LEGACY_INDEX_COLLECTION: &str = "vectordb-code-search";
// Command-line arguments for the legacy indexing command.
// NOTE: plain `//` comments are used on purpose: clap's derive turns `///`
// doc comments into `--help`/about text, which would change CLI output.
#[derive(Args, Debug)]
pub struct IndexArgs {
    // Files and/or directories to index; at least one must be given.
    #[arg(required = true)]
    pub paths: Vec<PathBuf>,
    // Optional extension filter (`-e` / `--extension`). `handle_index` strips a
    // leading '.' and lower-cases each value, and compares against the
    // lower-cased file extension, so matching is case-insensitive.
    #[arg(short = 'e', long = "extension")]
    pub file_extensions: Option<Vec<String>>,
}
pub async fn handle_index(
cmd_args: &IndexArgs,
cli_args: &CliArgs,
config: AppConfig, client: Arc<Qdrant>,
) -> Result<()> {
log::info!("Starting legacy indexing process...");
let collection_name = LEGACY_INDEX_COLLECTION;
log::info!("Indexing into default collection: '{}'", collection_name);
ensure_legacy_collection_exists(&client, collection_name).await?;
for path in &cmd_args.paths {
if !path.exists() {
bail!("Input path does not exist: {}", path.display());
}
}
log::info!("Processing input paths: {:?}", cmd_args.paths);
let model_env_var = std::env::var("VECTORDB_ONNX_MODEL").ok();
let tokenizer_env_var = std::env::var("VECTORDB_ONNX_TOKENIZER_DIR").ok();
if cli_args.onnx_model_path_arg.is_some() && model_env_var.is_some() {
return Err(anyhow::anyhow!("Cannot provide ONNX model path via both --onnx-model argument and VECTORDB_ONNX_MODEL environment variable."));
}
if cli_args.onnx_tokenizer_dir_arg.is_some() && tokenizer_env_var.is_some() {
return Err(anyhow::anyhow!("Cannot provide ONNX tokenizer dir via both --onnx-tokenizer-dir argument and VECTORDB_ONNX_TOKENIZER_DIR environment variable."));
}
let onnx_model_path_str = cli_args.onnx_model_path_arg.as_ref()
.or(model_env_var.as_ref())
.or(config.onnx_model_path.as_ref())
.ok_or_else(|| anyhow::anyhow!("ONNX model path must be provided via --onnx-model, VECTORDB_ONNX_MODEL, or config"))?;
let onnx_tokenizer_dir_str = cli_args.onnx_tokenizer_dir_arg.as_ref()
.or(tokenizer_env_var.as_ref())
.or(config.onnx_tokenizer_path.as_ref())
.ok_or_else(|| anyhow::anyhow!("ONNX tokenizer path must be provided via --onnx-tokenizer-dir, VECTORDB_ONNX_TOKENIZER_DIR, or config"))?;
let onnx_model_path = PathBuf::from(onnx_model_path_str);
let onnx_tokenizer_path = PathBuf::from(onnx_tokenizer_dir_str);
if !onnx_model_path.exists() {
return Err(anyhow::anyhow!("Resolved ONNX model path does not exist: {}", onnx_model_path.display()));
}
if !onnx_tokenizer_path.is_dir() {
return Err(anyhow::anyhow!("Resolved ONNX tokenizer path is not a directory: {}", onnx_tokenizer_path.display()));
}
let tokenizer_file = onnx_tokenizer_path.join("tokenizer.json");
if !tokenizer_file.exists() {
return Err(anyhow::anyhow!("tokenizer.json not found in the ONNX tokenizer directory: {}", onnx_tokenizer_path.display()));
}
log::info!("Using resolved ONNX model: {}", onnx_model_path.display());
log::info!("Using resolved ONNX tokenizer directory: {}", onnx_tokenizer_path.display());
log::info!("Initializing embedding handler...");
let embedding_handler = Arc::new(
EmbeddingHandler::new(
embedding::EmbeddingModelType::Onnx,
Some(onnx_model_path),
Some(onnx_tokenizer_path),
)
.context("Failed to initialize embedding handler")?,
);
let embedding_dim = embedding_handler
.dimension()
.context("Failed to get embedding dimension")?;
log::info!("Embedding dimension: {}", embedding_dim);
if !client.collection_exists(collection_name.to_string()).await? {
bail!("Collection '{}' not found. Please run 'repo add' again or check Qdrant.", collection_name);
}
let file_types_set: Option<HashSet<String>> = cmd_args
.file_extensions
.as_ref()
.map(|ft_vec| {
ft_vec
.iter()
.map(|s| s.trim_start_matches('.').to_lowercase())
.collect()
});
if let Some(ref ft_set) = file_types_set {
log::info!("Filtering by file extensions: {:?}", ft_set);
}
log::info!("Starting file traversal and processing...");
let pb_style = ProgressStyle::with_template(
"{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} files ({per_sec}) {msg}",
)?
.progress_chars("#>-");
let pb = ProgressBar::new(0);
pb.set_style(pb_style);
pb.enable_steady_tick(Duration::from_millis(100));
pb.set_message("Scanning directories...");
let mut files_to_process = Vec::new();
for path_arg in &cmd_args.paths {
if path_arg.is_file() {
if let Some(ref filter_set) = file_types_set {
let extension = path_arg
.extension()
.and_then(|ext| ext.to_str())
.map(|s| s.to_lowercase())
.unwrap_or_default();
if filter_set.contains(&extension) {
files_to_process.push(path_arg.clone());
} else {
log::trace!("Skipping file due to extension filter: {}", path_arg.display());
}
} else {
files_to_process.push(path_arg.clone());
}
} else if path_arg.is_dir() {
for entry_result in WalkDir::new(path_arg).into_iter().filter_map(|e| e.ok()) {
let absolute_path = entry_result.path();
if !absolute_path.is_file() {
continue;
}
if let Some(ref filter_set) = file_types_set {
let extension = absolute_path
.extension()
.and_then(|ext| ext.to_str())
.map(|s| s.to_lowercase())
.unwrap_or_default();
if !filter_set.contains(&extension) {
log::trace!("Skipping file due to extension filter: {}", absolute_path.display());
continue;
}
}
files_to_process.push(absolute_path.to_path_buf());
}
} else {
log::warn!("Input path is neither a file nor a directory: {}. Skipping.", path_arg.display());
}
}
pb.set_length(files_to_process.len() as u64);
pb.set_position(0);
pb.set_message("Processing files...");
let mut total_points_processed = 0;
let mut total_files_processed = 0;
let mut total_files_skipped = 0;
let model = embedding_handler
.create_embedding_model()
.context("Failed to create embedding model")?;
let mut points_batch = Vec::with_capacity(BATCH_SIZE);
for file_path in files_to_process { let absolute_path_str = file_path.to_string_lossy().to_string(); log::debug!("Processing file: {}", file_path.display());
let chunks = match syntax::get_chunks(&file_path) {
Ok(chunks) => chunks,
Err(e) => {
log::warn!("Failed to parse file {}: {}. Skipping.", file_path.display(), e);
pb.println(format!("Warning: Failed to parse {}, skipping.", file_path.display()));
total_files_skipped += 1;
pb.inc(1);
continue;
}
};
if chunks.is_empty() {
log::debug!("No text chunks found in file {}. Skipping.", file_path.display());
total_files_skipped += 1;
pb.inc(1);
continue;
}
let chunk_contents: Vec<&str> = chunks.iter().map(|c| c.content.as_str()).collect();
let embeddings = match model.embed_batch(&chunk_contents) {
Ok(embeddings) => embeddings,
Err(e) => {
log::error!(
"Failed to generate embeddings for {}: {}. Skipping file.",
file_path.display(),
e
);
pb.println(format!("Error embedding {}, skipping.", file_path.display()));
total_files_skipped += 1;
pb.inc(1); continue;
}
};
let file_extension = file_path.extension().and_then(|ext| ext.to_str()).unwrap_or("").to_string();
for (i, chunk) in chunks.iter().enumerate() {
let mut payload = Payload::new();
payload.insert(FIELD_FILE_PATH, absolute_path_str.clone()); payload.insert(FIELD_START_LINE, chunk.start_line as i64);
payload.insert(FIELD_END_LINE, chunk.end_line as i64);
payload.insert(FIELD_LANGUAGE, chunk.language.to_string());
payload.insert(FIELD_FILE_EXTENSION, file_extension.clone());
payload.insert(FIELD_ELEMENT_TYPE, chunk.element_type.clone());
payload.insert(FIELD_CHUNK_CONTENT, chunk.content.clone());
let point = PointStruct::new(
Uuid::new_v4().to_string(), embeddings[i].clone(), payload,
);
points_batch.push(point);
if points_batch.len() >= BATCH_SIZE {
let batch_to_upsert = std::mem::take(&mut points_batch);
upsert_batch(&client, &collection_name, batch_to_upsert, &pb).await?;
total_points_processed += BATCH_SIZE;
}
}
total_files_processed += 1;
pb.inc(1); }
if !points_batch.is_empty() {
let final_batch_size = points_batch.len();
upsert_batch(&client, &collection_name, points_batch, &pb).await?;
total_points_processed += final_batch_size;
}
pb.finish_with_message("Indexing complete!");
log::info!("Indexing finished.");
log::info!("Total files processed: {}", total_files_processed);
log::info!("Total files skipped: {}", total_files_skipped);
log::info!("Total points upserted: {}", total_points_processed);
Ok(())
}
/// Guarantees that the legacy collection exists and carries the payload
/// indexes used for filtering.
///
/// If the collection is missing, it is created with `DEFAULT_VECTOR_DIMENSION`
/// cosine vectors. The function then polls `collection_info` every 100 ms
/// until Qdrant reports the `Green` status, and gives up with an error after
/// 51 unsuccessful checks. The payload indexes (file path, start/end line,
/// language) are ensured on every call, whether or not the collection was
/// just created.
///
/// # Errors
/// Propagates Qdrant client errors, and fails if a newly created collection
/// does not become ready in time.
async fn ensure_legacy_collection_exists(
    client: &Qdrant,
    collection_name: &str,
) -> Result<()> {
    let already_present = client.collection_exists(collection_name.to_string()).await?;
    if !already_present {
        log::info!("Default collection '{}' does not exist. Creating...", collection_name);
        let cosine_vectors = VectorParamsBuilder::new(DEFAULT_VECTOR_DIMENSION, Distance::Cosine).build();
        let request = CreateCollectionBuilder::new(collection_name)
            .vectors_config(cosine_vectors)
            .build();
        client.create_collection(request).await?;
        log::info!("Default collection '{}' created.", collection_name);
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Poll until the collection reports Green.
        let green_status = qdrant_client::qdrant::CollectionStatus::Green as i32;
        let mut failed_checks = 0;
        loop {
            let response = client.collection_info(collection_name.to_string()).await?;
            let is_green = match response.result {
                Some(details) => details.status == green_status,
                None => false,
            };
            if is_green {
                break;
            }
            failed_checks += 1;
            if failed_checks > 50 {
                bail!("Collection '{}' did not become ready in time.", collection_name);
            }
            tokio::time::sleep(Duration::from_millis(100)).await;
        }
        log::info!("Collection '{}' is ready.", collection_name);
    }

    // Payload indexes, created in this exact order.
    let payload_indexes = [
        (FIELD_FILE_PATH, FieldType::Keyword),
        (FIELD_START_LINE, FieldType::Integer),
        (FIELD_END_LINE, FieldType::Integer),
        (FIELD_LANGUAGE, FieldType::Keyword),
    ];
    for (field_name, field_type) in payload_indexes {
        ensure_payload_index(client, collection_name, field_name, field_type).await?;
    }
    Ok(())
}