pub mod paths;
pub mod state;
pub mod worktree;
use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use std::sync::{mpsc, Arc};
use std::thread;
use anyhow::{Context, Result};
use globset::{Glob, GlobSet, GlobSetBuilder};
use ignore::gitignore::GitignoreBuilder;
use ignore::WalkBuilder;
use indicatif::{ProgressBar, ProgressStyle};
use next_plaid::{
delete_from_index, encode_index_chunk, filtering, prepare_codec_artifacts,
write_index_from_encoded_chunks, EncodedIndexChunk, IndexConfig, Metadata, MmapIndex,
SearchParameters, UpdateConfig,
};
use next_plaid_onnx::{pool_document_embeddings, Colbert, ExecutionProvider};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
#[cfg(feature = "cuda")]
use crate::acceleration::apply_acceleration_mode;
use crate::acceleration::{env_acceleration_mode_lossy, AccelerationMode};
use crate::embed::build_embedding_text;
use crate::parser::{build_call_graph, detect_language, extract_units, CodeUnit, Language};
use crate::signal::{is_interrupted, is_interrupted_outside_critical, CriticalSectionGuard};
use paths::{
acquire_index_lock, get_index_dir_for_project, get_vector_index_path, try_acquire_index_lock,
ProjectMetadata,
};
use state::{get_mtime, hash_file, FileInfo, IndexState};
/// Size limit of 512 KiB, in bytes. Not referenced in this part of the file;
/// presumably the per-file cap applied while scanning — TODO confirm.
const MAX_FILE_SIZE: u64 = 512 * 1024;
/// Default number of units per pipeline chunk when no override was set through
/// `IndexBuilder::set_index_chunk_size`.
const INDEX_CHUNK_SIZE: usize = 1024;
/// Name of the marker file inside the index directory; while it exists,
/// `run_indexing` routes to `build_resumable`.
const BUILDING_MARKER: &str = ".building";
/// Not referenced in this part of the file; presumably the number of units
/// between checkpoints of a resumable build — TODO confirm.
const BUILD_CHECKPOINT_UNITS: usize = 4096;
/// Above this many units `resolve_pool_factor` forces `LARGE_BATCH_POOL_FACTOR`.
const LARGE_BATCH_THRESHOLD: usize = 10_000;
/// Pool factor applied to batches larger than `LARGE_BATCH_THRESHOLD`.
const LARGE_BATCH_POOL_FACTOR: usize = 2;
/// Encode batch size used when the builder's `encode_batch_size` is unset.
const DEFAULT_ENCODE_BATCH_SIZE: usize = 64;
/// In `Auto` acceleration mode, batches with fewer units than this run on CPU.
#[cfg(feature = "cuda")]
const SMALL_BATCH_CPU_THRESHOLD: usize = 300;
/// Capacity of the bounded channel between the pool and index stages.
const POOLED_EMBEDDING_QUEUE_CAPACITY: usize = 4;
/// Capacity of the bounded channel between the index and metadata stages.
const METADATA_QUEUE_CAPACITY: usize = 8;
/// Outcome of parsing a single file in `parse_files_parallel`.
struct ParsedFileResult {
/// The path exactly as it was handed to the parser.
path: PathBuf,
/// Extracted code units; empty when the file was skipped, the run was
/// interrupted, or no language was detected.
units: Vec<CodeUnit>,
/// Content hash and mtime; `Some` only when the file was read, hashed and
/// stat'ed successfully.
file_info: Option<FileInfo>,
/// Human-readable message set when reading, hashing or stat'ing failed.
skip_reason: Option<String>,
}
/// Per-category file counts reported by an indexing run.
#[derive(Debug)]
pub struct UpdateStats {
/// Files that were not in the saved state before the run.
pub added: usize,
/// Files whose content hash differed from the saved state.
pub changed: usize,
/// Files removed from the index.
pub deleted: usize,
/// Files whose content hash matched the saved state.
pub unchanged: usize,
/// Files skipped by the run (always 0 in the code paths visible here).
pub skipped: usize,
}
/// File-level change set. Not constructed in this part of the file;
/// presumably computed before an incremental update — TODO confirm.
#[derive(Debug, Default)]
pub struct UpdatePlan {
/// Paths classified as added.
pub added: Vec<PathBuf>,
/// Paths classified as changed.
pub changed: Vec<PathBuf>,
/// Paths classified as deleted.
pub deleted: Vec<PathBuf>,
/// Number of unchanged files (only counted, not listed).
pub unchanged: usize,
}
/// A code unit paired with the text that will be embedded for it, as produced
/// by `prepare_units_for_encoding`.
#[derive(Clone)]
struct SortedUnit {
/// Shared handle to the code unit.
unit: Arc<CodeUnit>,
/// Output of `build_embedding_text` for `unit`.
text: Arc<str>,
}
/// A chunk of units with duplicate embedding texts collapsed; input of the
/// tokenize stage.
struct PreparedChunk {
/// Units of the chunk, in chunk order.
units: Vec<Arc<CodeUnit>>,
/// Distinct embedding texts in order of first appearance.
unique_texts: Vec<Arc<str>>,
/// For each unit, the index of its text in `unique_texts`.
original_to_unique: Vec<usize>,
}
/// A chunk after tokenization; input of the encode stage.
struct TokenizedChunk {
/// Units of the chunk, in chunk order.
units: Vec<Arc<CodeUnit>>,
/// Batches returned by `Colbert::tokenize_documents_in_batches` for the
/// chunk's unique texts.
prepared_batches: Vec<next_plaid_onnx::PreparedDocumentBatch>,
/// For each unit, the index of its text among the unique texts.
original_to_unique: Vec<usize>,
}
/// A chunk after model encoding; input of the pool stage.
struct RawEncodedChunk {
/// Units of the chunk, in chunk order.
units: Vec<Arc<CodeUnit>>,
/// Embeddings returned by `Colbert::encode_prepared_document_batches`.
raw_embeddings: Vec<ndarray::Array2<f32>>,
/// For each unit, the index of its embedding after pooling.
original_to_unique: Vec<usize>,
}
/// A chunk with pooled embeddings expanded back to one entry per unit; input
/// of the index stage.
struct PooledChunkForIndex {
/// Units of the chunk, in chunk order.
units: Vec<Arc<CodeUnit>>,
/// One embedding per unit (duplicated texts share a cloned embedding).
embeddings: Vec<ndarray::Array2<f32>>,
}
/// A chunk whose documents have been given ids by the index stage; input of
/// the metadata stage.
struct IndexedChunkForMetadata {
/// Units of the chunk, in chunk order.
units: Vec<Arc<CodeUnit>>,
/// Document ids assigned by the index stage.
doc_ids: Vec<i64>,
}
/// Settings handed to `run_chunk_pipeline`.
struct ChunkPipelineConfig<'a> {
/// Number of units per chunk; also passed to the index stage as the k-means
/// sample size for an initial build.
index_chunk_size: usize,
/// Pool factor forwarded to `pool_document_embeddings`.
pool_factor: Option<usize>,
/// Directory of the vector index.
index_path: &'a str,
/// Index configuration forwarded to the index stage.
config: IndexConfig,
/// Update configuration forwarded to the index stage.
update_config: UpdateConfig,
/// Optional progress bar advanced by the metadata stage.
pb: Option<&'a ProgressBar>,
}
/// Embeddings of one chunk, sent from the index stage to its coding thread
/// during an initial build.
struct ChunkForCoding {
/// One embedding per document of the chunk.
embeddings: Vec<ndarray::Array2<f32>>,
}
/// Public threshold of 30,000. Not referenced in this part of the file;
/// presumably the unit count above which the user is asked to confirm —
/// TODO confirm.
pub const CONFIRMATION_THRESHOLD: usize = 30_000;
/// Pairs every unit with its embedding text and orders the result for the
/// pipeline.
///
/// Units are sorted by `(file, line)`. When `sample_prefix_size` is non-zero
/// and smaller than the number of units, an evenly strided sample of that many
/// units is moved to the front (keeping sorted order within the sample and
/// within the remainder), so the first chunk spans the whole input.
fn prepare_units_for_encoding(units: &[CodeUnit], sample_prefix_size: usize) -> Vec<SortedUnit> {
    let mut sorted: Vec<SortedUnit> = Vec::with_capacity(units.len());
    for unit in units {
        sorted.push(SortedUnit {
            unit: Arc::new(unit.clone()),
            text: Arc::<str>::from(build_embedding_text(unit)),
        });
    }
    sorted.sort_unstable_by(|lhs, rhs| match lhs.unit.file.cmp(&rhs.unit.file) {
        std::cmp::Ordering::Equal => lhs.unit.line.cmp(&rhs.unit.line),
        unequal => unequal,
    });

    let total = sorted.len();
    let prefix_len = sample_prefix_size.min(total);
    if prefix_len == 0 || prefix_len >= total {
        return sorted;
    }

    // Sampled positions are 0, stride, 2 * stride, ..., (prefix_len - 1) * stride.
    let stride = total / prefix_len;
    let mut prefix: Vec<SortedUnit> = Vec::with_capacity(total);
    let mut remainder: Vec<SortedUnit> = Vec::with_capacity(total - prefix_len);
    for (position, item) in sorted.into_iter().enumerate() {
        let is_sampled = position % stride == 0 && position / stride < prefix_len;
        if is_sampled {
            prefix.push(item);
        } else {
            remainder.push(item);
        }
    }
    prefix.append(&mut remainder);
    prefix
}
/// Collapses identical embedding texts within a chunk.
///
/// Returns the chunk's units, the distinct texts in order of first
/// appearance, and for every unit the index of its text in that list.
fn prepare_deduplicated_chunk(unit_chunk: &[SortedUnit]) -> PreparedChunk {
    let mut first_seen: HashMap<&str, usize> = HashMap::new();
    let mut unique_texts: Vec<Arc<str>> = Vec::new();

    let original_to_unique: Vec<usize> = unit_chunk
        .iter()
        .map(|item| {
            *first_seen.entry(item.text.as_ref()).or_insert_with(|| {
                // First occurrence: register the text and hand out its slot.
                unique_texts.push(Arc::clone(&item.text));
                unique_texts.len() - 1
            })
        })
        .collect();

    let units = unit_chunk
        .iter()
        .map(|item| Arc::clone(&item.unit))
        .collect();

    PreparedChunk {
        units,
        unique_texts,
        original_to_unique,
    }
}
fn run_encode_stage(
receiver: mpsc::Receiver<TokenizedChunk>,
sender: mpsc::Sender<RawEncodedChunk>,
model: Colbert,
) -> Result<()> {
while let Ok(chunk) = receiver.recv() {
let raw_embeddings = model.encode_prepared_document_batches(chunk.prepared_batches)?;
sender
.send(RawEncodedChunk {
units: chunk.units,
raw_embeddings,
original_to_unique: chunk.original_to_unique,
})
.context("Failed to send raw embeddings to pooling stage")?;
}
Ok(())
}
fn run_tokenize_stage(
receiver: mpsc::Receiver<PreparedChunk>,
sender: mpsc::Sender<TokenizedChunk>,
model: Colbert,
) -> Result<()> {
while let Ok(chunk) = receiver.recv() {
let text_refs: Vec<&str> = chunk
.unique_texts
.iter()
.map(|text| text.as_ref())
.collect();
let prepared_batches = model.tokenize_documents_in_batches(&text_refs)?;
sender
.send(TokenizedChunk {
units: chunk.units,
prepared_batches,
original_to_unique: chunk.original_to_unique,
})
.context("Failed to send tokenized chunk to encode stage")?;
}
Ok(())
}
fn run_pool_stage(
receiver: mpsc::Receiver<RawEncodedChunk>,
sender: mpsc::SyncSender<PooledChunkForIndex>,
pool_factor: Option<usize>,
) -> Result<()> {
while let Ok(chunk) = receiver.recv() {
let pooled_unique = pool_document_embeddings(chunk.raw_embeddings, pool_factor);
let embeddings = chunk
.original_to_unique
.into_iter()
.map(|unique_idx| pooled_unique[unique_idx].clone())
.collect();
sender
.send(PooledChunkForIndex {
units: chunk.units,
embeddings,
})
.context("Failed to send pooled embeddings to index stage")?;
}
Ok(())
}
/// Index stage: gives every pooled chunk its document ids and forwards the
/// units plus ids to the metadata stage.
///
/// Two modes, chosen by whether `metadata.json` exists under `index_path`:
///
/// * Initial create (no `metadata.json`): k-means and codec preparation run on
///   a dedicated thread using the embeddings of the first chunk only. A second
///   thread waits for those artifacts, encodes every chunk, and writes the
///   index once the input is exhausted. Document ids are handed out
///   sequentially starting at 0.
/// * Update: each chunk is applied with `MmapIndex::update_or_create`, which
///   returns the document ids.
///
/// `initial_kmeans_sample_docs` is only used as `n_samples_kmeans` when the
/// configuration does not set one.
///
/// NOTE(review): in the initial-create path a failure of the coding or k-means
/// thread makes `dedup_tx.send` fail, and the `?` on `handle_chunk` returns
/// that send error without joining `coding_handle`, so the underlying error is
/// not reported.
fn run_index_stage(
receiver: mpsc::Receiver<PooledChunkForIndex>,
sender: mpsc::SyncSender<IndexedChunkForMetadata>,
index_path: String,
config: IndexConfig,
update_config: UpdateConfig,
initial_kmeans_sample_docs: usize,
) -> Result<()> {
let initial_create = !Path::new(&index_path).join("metadata.json").exists();
if initial_create {
// Nothing to index at all: upstream closed without sending a chunk.
let first_chunk = match receiver.recv() {
Ok(chunk) => chunk,
Err(_) => return Ok(()),
};
let mut next_doc_id = 0i64;
let mut initial_create_config = config.clone();
if initial_create_config.n_samples_kmeans.is_none() {
initial_create_config.n_samples_kmeans = Some(initial_kmeans_sample_docs.max(1));
}
// The k-means thread gets its own copy; the first chunk itself is still
// sent through `handle_chunk` below.
let sample_embeddings = first_chunk.embeddings.clone();
let kmeans_config = initial_create_config.clone();
let kmeans_handle = thread::Builder::new()
.name("colgrep-kmeans".to_string())
.spawn(move || -> Result<_> {
let centroids = next_plaid::compute_kmeans(
&sample_embeddings,
&next_plaid::kmeans::ComputeKmeansConfig {
kmeans_niters: kmeans_config.kmeans_niters,
max_points_per_centroid: kmeans_config.max_points_per_centroid,
seed: kmeans_config.seed.unwrap_or(42),
n_samples_kmeans: kmeans_config.n_samples_kmeans,
num_partitions: None,
force_cpu: kmeans_config.force_cpu,
},
)?;
Ok(prepare_codec_artifacts(
&sample_embeddings,
centroids,
&kmeans_config,
)?)
})
.context("Failed to spawn kmeans stage thread")?;
let (dedup_tx, dedup_rx) = mpsc::sync_channel::<ChunkForCoding>(8);
let coding_index_path = index_path.clone();
let coding_config = initial_create_config.clone();
let coding_force_cpu = update_config.force_cpu;
let coding_handle = thread::Builder::new()
.name("colgrep-coding".to_string())
.spawn(move || -> Result<()> {
// Chunks queue up in `dedup_rx` until the codec artifacts are ready.
let codec_artifacts = kmeans_handle
.join()
.map_err(|_| anyhow::anyhow!("K-means stage thread panicked"))??;
let mut encoded_chunks: Vec<EncodedIndexChunk> = Vec::new();
while let Ok(chunk) = dedup_rx.recv() {
let encoded = encode_index_chunk(
&chunk.embeddings,
&codec_artifacts.codec,
coding_force_cpu,
)?;
encoded_chunks.push(encoded);
}
// The guard is held while the index files are written.
let _guard = CriticalSectionGuard::new();
write_index_from_encoded_chunks(
&encoded_chunks,
&codec_artifacts,
&coding_index_path,
&coding_config,
)?;
Ok(())
})
.context("Failed to spawn coding stage thread")?;
// Assigns the next contiguous id range, then fans the chunk out to the
// coding thread (embeddings) and the metadata stage (units + ids).
let mut handle_chunk = |chunk: PooledChunkForIndex| -> Result<()> {
let doc_count = chunk.embeddings.len();
let doc_ids: Vec<i64> = (next_doc_id..next_doc_id + doc_count as i64).collect();
next_doc_id += doc_count as i64;
dedup_tx
.send(ChunkForCoding {
embeddings: chunk.embeddings,
})
.context("Failed to send chunk to coding stage")?;
sender
.send(IndexedChunkForMetadata {
units: chunk.units,
doc_ids,
})
.context("Failed to send indexed chunk to metadata stage")
};
handle_chunk(first_chunk)?;
while let Ok(chunk) = receiver.recv() {
handle_chunk(chunk)?;
}
// Closing the channel lets the coding thread leave its receive loop and
// write the index.
drop(dedup_tx);
coding_handle
.join()
.map_err(|_| anyhow::anyhow!("Coding stage thread panicked"))??;
return Ok(());
}
// Update path: the index already exists, so each chunk is applied directly.
while let Ok(chunk) = receiver.recv() {
let _guard = CriticalSectionGuard::new();
let (_, doc_ids) =
MmapIndex::update_or_create(&chunk.embeddings, &index_path, &config, &update_config)?;
sender
.send(IndexedChunkForMetadata {
units: chunk.units,
doc_ids,
})
.context("Failed to send indexed chunk to metadata stage")?;
}
Ok(())
}
/// Value of `n_full_scores` used when `COLGREP_N_FULL_SCORES` is unset or invalid.
const DEFAULT_N_FULL_SCORES: usize = 8192;

/// Builds the search parameters for `top_k` results, applying overrides from
/// the environment.
///
/// * `COLGREP_N_IVF_PROBE` — positive integer; otherwise the library default.
/// * `COLGREP_N_FULL_SCORES` — positive integer; otherwise
///   `DEFAULT_N_FULL_SCORES`.
/// * `COLGREP_CENTROID_SCORE_THRESHOLD` — a float; a negative value disables
///   the threshold (`None`). Unset, unparsable or NaN values fall back to the
///   library default.
fn search_params_from_env(top_k: usize) -> SearchParameters {
    /// Reads `var` as a strictly positive integer, if set and valid.
    fn positive_usize(var: &str) -> Option<usize> {
        std::env::var(var)
            .ok()?
            .parse::<usize>()
            .ok()
            .filter(|value| *value > 0)
    }

    let defaults = SearchParameters::default();
    let n_ivf_probe = positive_usize("COLGREP_N_IVF_PROBE").unwrap_or(defaults.n_ivf_probe);
    let n_full_scores = positive_usize("COLGREP_N_FULL_SCORES").unwrap_or(DEFAULT_N_FULL_SCORES);

    let requested_threshold = std::env::var("COLGREP_CENTROID_SCORE_THRESHOLD")
        .ok()
        .and_then(|raw| raw.parse::<f32>().ok());
    let centroid_score_threshold = match requested_threshold {
        // "NaN" parses successfully but cannot be compared against any score,
        // so it is treated like an invalid value.
        Some(value) if value.is_nan() => defaults.centroid_score_threshold,
        Some(value) if value < 0.0 => None,
        Some(value) => Some(value),
        None => defaults.centroid_score_threshold,
    };

    SearchParameters {
        top_k,
        n_ivf_probe,
        n_full_scores,
        centroid_score_threshold,
        ..defaults
    }
}
fn run_metadata_stage(
receiver: mpsc::Receiver<IndexedChunkForMetadata>,
index_path: String,
pb: Option<ProgressBar>,
) -> Result<()> {
let mut filtering_exists = filtering::exists(&index_path);
let mut completed_units = 0u64;
while let Ok(chunk) = receiver.recv() {
let metadata: Vec<serde_json::Value> = chunk
.units
.iter()
.map(|unit| serde_json::to_value(unit.as_ref()).unwrap())
.collect();
let db_result = if filtering_exists {
filtering::update(&index_path, &metadata, &chunk.doc_ids)
} else {
filtering::create(&index_path, &metadata, &chunk.doc_ids)
};
if let Err(e) = db_result {
if let Err(rollback_err) = delete_from_index(&chunk.doc_ids, &index_path) {
eprintln!("⚠️ Rollback failed: {}", rollback_err);
}
return Err(e.into());
}
if let Err(e) = next_plaid::text_search::index(
&index_path,
&metadata,
&chunk.doc_ids,
&next_plaid::FtsTokenizer::IdentifierAware,
) {
eprintln!("⚠️ FTS indexing failed (non-fatal): {}", e);
}
filtering_exists = true;
completed_units += chunk.units.len() as u64;
if let Some(pb) = pb.as_ref() {
pb.set_position(completed_units);
}
}
Ok(())
}
fn run_chunk_pipeline(
model: Colbert,
sorted_units: &[SortedUnit],
pipeline: ChunkPipelineConfig<'_>,
) -> Result<bool> {
let mut was_interrupted = false;
let ChunkPipelineConfig {
index_chunk_size,
pool_factor,
index_path,
config,
update_config,
pb,
} = pipeline;
let (tokenize_tx, tokenize_rx) = mpsc::channel::<PreparedChunk>();
let (encode_tx, encode_rx) = mpsc::channel::<TokenizedChunk>();
let (pool_tx, pool_rx) = mpsc::channel::<RawEncodedChunk>();
let (index_tx, index_rx) =
mpsc::sync_channel::<PooledChunkForIndex>(POOLED_EMBEDDING_QUEUE_CAPACITY);
let (metadata_tx, metadata_rx) =
mpsc::sync_channel::<IndexedChunkForMetadata>(METADATA_QUEUE_CAPACITY);
let tokenize_model = model.clone();
let tokenize_handle = thread::Builder::new()
.name("colgrep-tokenize".to_string())
.spawn(move || run_tokenize_stage(tokenize_rx, encode_tx, tokenize_model))
.context("Failed to spawn tokenize stage thread")?;
let encode_model = model.clone();
let encode_handle = thread::Builder::new()
.name("colgrep-encode".to_string())
.spawn(move || run_encode_stage(encode_rx, pool_tx, encode_model))
.context("Failed to spawn encode stage thread")?;
let pool_handle = thread::Builder::new()
.name("colgrep-pool".to_string())
.spawn(move || run_pool_stage(pool_rx, index_tx, pool_factor))
.context("Failed to spawn pool stage thread")?;
let index_path_for_index = index_path.to_string();
let index_handle = thread::Builder::new()
.name("colgrep-index".to_string())
.spawn(move || {
run_index_stage(
index_rx,
metadata_tx,
index_path_for_index,
config,
update_config,
index_chunk_size,
)
})
.context("Failed to spawn index stage thread")?;
let index_path_for_metadata = index_path.to_string();
let metadata_pb = pb.cloned();
let metadata_handle = thread::Builder::new()
.name("colgrep-metadata".to_string())
.spawn(move || run_metadata_stage(metadata_rx, index_path_for_metadata, metadata_pb))
.context("Failed to spawn metadata stage thread")?;
for unit_chunk in sorted_units.chunks(index_chunk_size) {
if is_interrupted_outside_critical() {
was_interrupted = true;
break;
}
let prepared = prepare_deduplicated_chunk(unit_chunk);
tokenize_tx
.send(prepared)
.context("Failed to send prepared chunk to tokenize stage")?;
}
drop(tokenize_tx);
tokenize_handle
.join()
.map_err(|_| anyhow::anyhow!("Tokenize stage thread panicked"))??;
encode_handle
.join()
.map_err(|_| anyhow::anyhow!("Encode stage thread panicked"))??;
pool_handle
.join()
.map_err(|_| anyhow::anyhow!("Pool stage thread panicked"))??;
index_handle
.join()
.map_err(|_| anyhow::anyhow!("Index stage thread panicked"))??;
metadata_handle
.join()
.map_err(|_| anyhow::anyhow!("Metadata stage thread panicked"))??;
Ok(was_interrupted)
}
/// Parses `paths` (relative to `project_root`) in parallel and returns one
/// result per path, in input order.
///
/// For each file: the language is detected, the file is read, units are
/// extracted, then the file is hashed and its mtime read. A failure of the
/// read, hash or mtime step yields a result with a `skip_reason` and no units.
/// Files with no detected language yield an empty result without a reason.
/// Once an interrupt is observed, remaining files yield empty results and do
/// not advance the progress bar.
fn parse_files_parallel(
    project_root: &Path,
    paths: &[PathBuf],
    pb: Option<&ProgressBar>,
) -> Vec<ParsedFileResult> {
    /// Result carrying neither units nor file info.
    fn without_units(path: &Path, skip_reason: Option<String>) -> ParsedFileResult {
        ParsedFileResult {
            path: path.to_path_buf(),
            units: Vec::new(),
            file_info: None,
            skip_reason,
        }
    }

    /// Result for a file that could not be processed because of `cause`.
    fn skipped(path: &Path, full_path: &Path, cause: impl std::fmt::Display) -> ParsedFileResult {
        without_units(
            path,
            Some(format!("Skipping {} ({})", full_path.display(), cause)),
        )
    }

    let progress = pb.cloned();
    paths
        .par_iter()
        .map(|path| {
            if is_interrupted() {
                return without_units(path, None);
            }
            let full_path = project_root.join(path);

            let parse = || -> ParsedFileResult {
                let Some(lang) = detect_language(&full_path) else {
                    return without_units(path, None);
                };
                let source = match std::fs::read_to_string(&full_path) {
                    Ok(source) => source,
                    Err(e) => return skipped(path, &full_path, e),
                };
                let units = extract_units(path, &source, lang);
                let content_hash = match hash_file(&full_path) {
                    Ok(hash) => hash,
                    Err(e) => return skipped(path, &full_path, e),
                };
                let mtime = match get_mtime(&full_path) {
                    Ok(mtime) => mtime,
                    Err(e) => return skipped(path, &full_path, e),
                };
                ParsedFileResult {
                    path: path.clone(),
                    units,
                    file_info: Some(FileInfo {
                        content_hash,
                        mtime,
                    }),
                    skip_reason: None,
                }
            };
            let result = parse();

            if let Some(pb) = &progress {
                pb.inc(1);
            }
            result
        })
        .collect()
}
/// Builds and updates the search index of one project for one model.
pub struct IndexBuilder {
/// The loaded model; `None` until `ensure_model_created` has run.
model: Option<Colbert>,
/// Location the model is loaded from.
model_path: PathBuf,
/// Passed to the model builder's `with_quantized`.
quantized: bool,
/// Explicit number of parallel sessions; a default is chosen when `None`.
parallel_sessions: Option<usize>,
/// Explicit model batch size; a default is chosen when `None`.
batch_size: Option<usize>,
/// Root directory of the project being indexed.
project_root: PathBuf,
/// Directory holding this project's index data for the model.
index_dir: PathBuf,
/// Pool factor used unless the batch is large (see `resolve_pool_factor`).
pool_factor: Option<usize>,
/// Override for the encode batch size; `DEFAULT_ENCODE_BATCH_SIZE` when `None`.
encode_batch_size: Option<usize>,
/// Override for the pipeline chunk size; `INDEX_CHUNK_SIZE` when `None`.
index_chunk_size: Option<usize>,
/// Passed to the model builder's `with_dynamic_batch`.
dynamic_batch: bool,
/// Set through `set_auto_confirm`; not read in this part of the file.
auto_confirm: bool,
/// Identifier of the model, used for the index location and log output.
model_id: String,
}
impl IndexBuilder {
/// Creates a builder for `project_root` with a non-quantized model and all
/// optional settings left at their defaults.
pub fn new(project_root: &Path, model_id: &str, model_path: &Path) -> Result<Self> {
    let quantized = false;
    Self::with_options(
        project_root,
        model_id,
        model_path,
        quantized,
        None,
        None,
        None,
    )
}
/// Creates a builder like `new`, but lets the caller choose whether the
/// quantized model is used.
pub fn with_quantized(
    project_root: &Path,
    model_id: &str,
    model_path: &Path,
    quantized: bool,
) -> Result<Self> {
    // No pool factor, session count or batch size overrides.
    let (pool_factor, parallel_sessions, batch_size) = (None, None, None);
    Self::with_options(
        project_root,
        model_id,
        model_path,
        quantized,
        pool_factor,
        parallel_sessions,
        batch_size,
    )
}
/// Creates a builder with every option spelled out.
///
/// The index directory is resolved from `project_root` and `model_id`; the
/// model itself is not loaded until it is first needed.
///
/// # Errors
/// Fails when the index directory for the project cannot be determined.
pub fn with_options(
    project_root: &Path,
    model_id: &str,
    model_path: &Path,
    quantized: bool,
    pool_factor: Option<usize>,
    parallel_sessions: Option<usize>,
    batch_size: Option<usize>,
) -> Result<Self> {
    let builder = Self {
        index_dir: get_index_dir_for_project(project_root, model_id)?,
        project_root: project_root.to_path_buf(),
        model_id: model_id.to_string(),
        model_path: model_path.to_path_buf(),
        model: None,
        quantized,
        pool_factor,
        parallel_sessions,
        batch_size,
        encode_batch_size: None,
        index_chunk_size: None,
        dynamic_batch: true,
        auto_confirm: false,
    };
    Ok(builder)
}
/// Stores the `auto_confirm` flag on the builder.
pub fn set_auto_confirm(&mut self, auto_confirm: bool) {
self.auto_confirm = auto_confirm;
}
/// Overrides the encode batch size; a value of 0 is raised to 1.
pub fn set_encode_batch_size(&mut self, encode_batch_size: usize) {
    let clamped = if encode_batch_size == 0 {
        1
    } else {
        encode_batch_size
    };
    self.encode_batch_size = Some(clamped);
}
/// Overrides the pipeline chunk size; a value of 0 is raised to 1.
pub fn set_index_chunk_size(&mut self, index_chunk_size: usize) {
    let clamped = if index_chunk_size == 0 {
        1
    } else {
        index_chunk_size
    };
    self.index_chunk_size = Some(clamped);
}
/// Sets the flag later passed to the model builder's `with_dynamic_batch`.
pub fn set_dynamic_batch(&mut self, dynamic_batch: bool) {
self.dynamic_batch = dynamic_batch;
}
/// Loads the ColBERT model on first use; does nothing when one is already
/// present.
///
/// Chooses the number of parallel sessions and the execution provider, makes
/// sure the ONNX runtime is initialized, prints the model id to stderr, and
/// builds the model with stderr suppressed.
///
/// With the `cuda` feature the choice follows the acceleration mode from the
/// environment:
/// * `ForceCpu` — CPU provider.
/// * `ForceGpu` — CUDA provider; fails when cuDNN or the CUDA provider is not
///   available.
/// * `Auto` — CPU when `num_units` is below `SMALL_BATCH_CPU_THRESHOLD` or
///   when cuDNN/CUDA are unavailable, CUDA otherwise.
///
/// Without the `cuda` feature the CPU provider is always used and `num_units`
/// is ignored.
fn ensure_model_created(&mut self, num_units: usize) -> Result<()> {
if self.model.is_none() {
#[cfg(feature = "cuda")]
let acceleration_mode = env_acceleration_mode_lossy();
#[cfg(feature = "cuda")]
let (num_sessions, execution_provider) = {
match acceleration_mode {
AccelerationMode::ForceCpu => {
// The mode is applied before the runtime is initialized.
apply_acceleration_mode(AccelerationMode::ForceCpu);
crate::onnx_runtime::ensure_onnx_runtime()
.context("Failed to initialize ONNX Runtime")?;
(
self.parallel_sessions
.unwrap_or_else(crate::config::get_default_cpu_parallel_sessions),
ExecutionProvider::Cpu,
)
}
AccelerationMode::ForceGpu => {
apply_acceleration_mode(AccelerationMode::ForceGpu);
crate::onnx_runtime::ensure_onnx_runtime()
.context("Failed to initialize ONNX Runtime")?;
// A forced GPU run must not silently fall back to CPU.
if !crate::onnx_runtime::is_cudnn_available() {
anyhow::bail!("FORCE_GPU is set, but cuDNN was not initialized");
}
if !next_plaid_onnx::is_cuda_available() {
anyhow::bail!(
"FORCE_GPU is set, but the CUDA execution provider was not initialized"
);
}
(
self.parallel_sessions
.unwrap_or(crate::config::DEFAULT_PARALLEL_SESSIONS_GPU),
ExecutionProvider::Cuda,
)
}
AccelerationMode::Auto => {
// Small batches are run on CPU even when a GPU may be present.
let force_cpu = num_units < SMALL_BATCH_CPU_THRESHOLD;
if force_cpu {
apply_acceleration_mode(AccelerationMode::ForceCpu);
} else {
apply_acceleration_mode(AccelerationMode::Auto);
}
crate::onnx_runtime::ensure_onnx_runtime()
.context("Failed to initialize ONNX Runtime")?;
let use_cuda = !force_cpu && {
crate::onnx_runtime::is_cudnn_available()
&& next_plaid_onnx::is_cuda_available()
};
if use_cuda {
(
self.parallel_sessions
.unwrap_or(crate::config::DEFAULT_PARALLEL_SESSIONS_GPU),
ExecutionProvider::Cuda,
)
} else {
(
self.parallel_sessions.unwrap_or_else(
crate::config::get_default_cpu_parallel_sessions,
),
ExecutionProvider::Cpu,
)
}
}
}
};
#[cfg(not(feature = "cuda"))]
let (num_sessions, execution_provider) = {
// Only the CUDA build uses the unit count.
let _ = num_units;
crate::onnx_runtime::ensure_onnx_runtime()
.context("Failed to initialize ONNX Runtime")?;
(
self.parallel_sessions
.unwrap_or_else(crate::config::get_default_cpu_parallel_sessions),
ExecutionProvider::Cpu,
)
};
eprintln!("🤖 Model: {}", self.model_id);
eprintln!("📂 Building index...");
let batch = self
.batch_size
.unwrap_or_else(crate::config::get_default_batch_size);
let model = crate::stderr::with_suppressed_stderr(|| {
Colbert::builder(&self.model_path)
.with_quantized(self.quantized)
.with_parallel(num_sessions)
.with_batch_size(batch)
.with_dynamic_batch(self.dynamic_batch)
.with_execution_provider(execution_provider)
.build()
})
.context("Failed to load ColBERT model")?;
self.model = Some(model);
}
Ok(())
}
/// Returns the loaded model.
///
/// # Panics
/// Panics when the model has not been created yet; call
/// `ensure_model_created` first.
fn model(&self) -> &Colbert {
    match self.model.as_ref() {
        Some(model) => model,
        None => panic!("Model not created. Call ensure_model_created() first."),
    }
}
/// Reports whether a model is loaded and was requested with an execution
/// provider other than CPU.
#[cfg(feature = "cuda")]
fn is_using_gpu(&self) -> bool {
    match self.model.as_ref() {
        Some(model) => !matches!(model.requested_execution_provider, ExecutionProvider::Cpu),
        None => false,
    }
}
/// Replaces the current model with one built for the CPU execution provider.
///
/// The existing model is dropped first, the acceleration mode is switched to
/// `ForceCpu`, and the new model uses the default CPU batch size with dynamic
/// batching turned off.
///
/// # Errors
/// Fails when the CPU model cannot be loaded; the builder then has no model.
#[cfg(feature = "cuda")]
fn rebuild_model_for_cpu(&mut self) -> Result<()> {
    // Drop the current model before switching modes.
    drop(self.model.take());
    apply_acceleration_mode(AccelerationMode::ForceCpu);

    let cpu_sessions = match self.parallel_sessions {
        Some(sessions) => sessions,
        None => crate::config::get_default_cpu_parallel_sessions(),
    };
    let cpu_batch = crate::config::DEFAULT_BATCH_SIZE_CPU;

    let cpu_model = crate::stderr::with_suppressed_stderr(|| {
        Colbert::builder(&self.model_path)
            .with_quantized(self.quantized)
            .with_parallel(cpu_sessions)
            .with_batch_size(cpu_batch)
            .with_dynamic_batch(false)
            .with_execution_provider(ExecutionProvider::Cpu)
            .build()
    })
    .context("Failed to load ColBERT model for CPU fallback")?;

    self.model = Some(cpu_model);
    Ok(())
}
/// Runs the chunk pipeline over `sorted_units` with the current model and
/// returns whether it was interrupted.
///
/// With the `cuda` feature, a failure while a non-CPU model is in use is
/// handled as follows: in `ForceGpu` mode the error is returned with a hint
/// about GPU memory; otherwise the model is rebuilt for CPU and the pipeline
/// is run once more. Failures with a CPU model are returned unchanged.
///
/// NOTE(review): the CPU retry feeds all of `sorted_units` again, and nothing
/// here removes chunks that the failed attempt may already have written to the
/// index — verify that this cannot duplicate documents.
fn run_encoding_pipeline(
&mut self,
sorted_units: &[SortedUnit],
index_chunk_size: usize,
pool_factor: Option<usize>,
index_path: &str,
pb: Option<&ProgressBar>,
) -> Result<bool> {
let force_cpu = next_plaid::is_force_cpu();
let config = IndexConfig {
force_cpu,
..Default::default()
};
let update_config = UpdateConfig {
force_cpu,
..Default::default()
};
let result = run_chunk_pipeline(
self.model().clone(),
sorted_units,
ChunkPipelineConfig {
index_chunk_size,
pool_factor,
index_path,
config,
update_config,
pb,
},
);
#[cfg(feature = "cuda")]
if let Err(gpu_err) = result {
if self.is_using_gpu() {
let accel = env_acceleration_mode_lossy();
if accel == AccelerationMode::ForceGpu {
anyhow::bail!(
"GPU encoding failed with --force-gpu. \
Not enough GPU memory for batch size {batch} and document length. \
Try reducing the batch size or use auto mode to allow CPU fallback.\n\
\nCaused by: {gpu_err}",
batch = self
.batch_size
.unwrap_or(crate::config::DEFAULT_BATCH_SIZE_GPU),
);
}
eprintln!(
"\n⚠️ GPU encoding failed, falling back to CPU. \
This is usually caused by insufficient GPU memory for the batch size.\n"
);
self.rebuild_model_for_cpu()?;
// Re-read the flag: rebuilding for CPU changed the acceleration mode.
let force_cpu = next_plaid::is_force_cpu();
let config = IndexConfig {
force_cpu,
..Default::default()
};
let update_config = UpdateConfig {
force_cpu,
..Default::default()
};
return run_chunk_pipeline(
self.model().clone(),
sorted_units,
ChunkPipelineConfig {
index_chunk_size,
pool_factor,
index_path,
config,
update_config,
pb,
},
);
}
// The failure did not come from a GPU model; report it as is.
return Err(gpu_err);
}
result
}
/// Returns the directory holding this project's index data.
pub fn index_dir(&self) -> &Path {
    self.index_dir.as_path()
}
/// Picks the pool factor for a batch of `num_units` units.
///
/// Batches larger than `LARGE_BATCH_THRESHOLD` always use
/// `LARGE_BATCH_POOL_FACTOR`, regardless of the configured value; smaller
/// batches use the configured pool factor.
fn resolve_pool_factor(&self, num_units: usize) -> Option<usize> {
    let is_large_batch = num_units > LARGE_BATCH_THRESHOLD;
    match is_large_batch {
        true => Some(LARGE_BATCH_POOL_FACTOR),
        false => self.pool_factor,
    }
}
/// Rebuilds the per-file index state from the `file` entries stored in the
/// filtering database.
///
/// Every distinct file that still exists under the project root and can be
/// hashed and stat'ed is recorded with its current hash and mtime; other files
/// are left out.
///
/// # Errors
/// Fails when the database cannot be read, is empty, or holds no `file`
/// entries.
fn reconstruct_state_from_filtering_db(&self, index_path: &str) -> Result<IndexState> {
    let all_metadata = filtering::get(index_path, None, &[], None)?;
    if all_metadata.is_empty() {
        anyhow::bail!("Filtering database is empty, cannot reconstruct state");
    }

    let unique_files: HashSet<PathBuf> = all_metadata
        .iter()
        .filter_map(|meta| meta.get("file").and_then(|value| value.as_str()))
        .map(PathBuf::from)
        .collect();
    if unique_files.is_empty() {
        anyhow::bail!("No file paths found in filtering database");
    }

    let mut state = IndexState::default();
    for file_path in unique_files {
        let full_path = self.project_root.join(&file_path);
        if !full_path.exists() {
            continue;
        }
        let (Ok(content_hash), Ok(mtime)) = (hash_file(&full_path), get_mtime(&full_path))
        else {
            continue;
        };
        state.files.insert(
            file_path,
            FileInfo {
                content_hash,
                mtime,
            },
        );
    }
    Ok(state)
}
/// Reports a count mismatch between the filtering database and the vector
/// index, and removes surplus database entries.
///
/// When the database has more entries than the vector index has documents,
/// entries matching `_subset_ >= vector_count` are deleted. When the vector
/// index is the larger one, nothing is changed and `Ok(())` is returned.
fn reconcile_document_counts(
    &self,
    index_path: &str,
    filtering_count: usize,
    vector_count: usize,
) -> Result<()> {
    eprintln!(
        "⚠️ Index/DB desync: SQLite has {} entries, vector index has {} docs",
        filtering_count, vector_count
    );
    if filtering_count <= vector_count {
        return Ok(());
    }

    let first_orphan_id = serde_json::json!(vector_count as i64);
    let orphan_ids = filtering::where_condition(index_path, "_subset_ >= ?", &[first_orphan_id])?;
    if !orphan_ids.is_empty() {
        filtering::delete(index_path, &orphan_ids)?;
    }
    Ok(())
}
fn repair_index_db_sync(&self, index_dir: &Path) -> Result<bool> {
let index_path = index_dir.to_str().unwrap();
if !index_dir.join("metadata.json").exists() {
return Ok(false); }
if !filtering::exists(index_path) {
return Ok(false); }
let index_metadata =
Metadata::load_from_path(index_dir).context("Failed to load index metadata")?;
let db_count = filtering::count(index_path).context("Failed to get DB count")?;
let index_count = index_metadata.num_documents;
if index_count == db_count {
return Ok(false); }
eprintln!(
"⚠️ Index/DB desync detected: index has {} docs, DB has {} records",
index_count, db_count
);
if db_count > index_count {
let extra_ids: Vec<i64> = (index_count as i64..db_count as i64).collect();
filtering::delete(index_path, &extra_ids)
.context("Failed to delete extra DB records")?;
eprintln!("🔧 Deleted {} orphan DB records", extra_ids.len());
} else {
let extra_ids: Vec<i64> = (db_count as i64..index_count as i64).collect();
delete_from_index(&extra_ids, index_path)
.context("Failed to delete extra index documents")?;
eprintln!("🔧 Deleted {} orphan index documents", extra_ids.len());
}
let new_index_metadata = Metadata::load_from_path(index_dir)
.context("Failed to reload index metadata after repair")?;
let new_db_count =
filtering::count(index_path).context("Failed to get DB count after repair")?;
if new_index_metadata.num_documents != new_db_count {
anyhow::bail!(
"Repair failed: index still has {} documents but DB has {} records",
new_index_metadata.num_documents,
new_db_count
);
}
Ok(true)
}
/// Indexes the project while holding the index lock.
///
/// Acquires the lock through `acquire_index_lock`, then runs the indexing
/// decision logic; the lock is released when this method returns.
pub fn index(&mut self, languages: Option<&[Language]>, force: bool) -> Result<UpdateStats> {
    let _index_lock = acquire_index_lock(&self.index_dir)?;
    let outcome = self.run_indexing(languages, force);
    outcome
}
/// Decides which kind of indexing run is needed and performs it. The caller
/// must already hold the index lock.
///
/// In order:
/// 1. Leftover `index.tmp` / `index.old` directories are removed (best effort)
///    and a new index may be seeded from a sibling worktree.
/// 2. `force` or a CLI version mismatch triggers a full rebuild.
/// 3. A present building marker or a missing index triggers a resumable build.
/// 4. A missing or unreadable filtering database triggers a full rebuild.
/// 5. An empty saved state is reconstructed from the filtering database (full
///    rebuild if that fails).
/// 6. A count mismatch between database and vector index is reconciled (full
///    rebuild if that fails).
/// 7. Otherwise an incremental update runs.
///
/// # Errors
/// Fails when the index path is not valid UTF-8, when the saved state cannot
/// be loaded or stored, or when the selected run fails.
fn run_indexing(&mut self, languages: Option<&[Language]>, force: bool) -> Result<UpdateStats> {
    // Best-effort cleanup; failures (e.g. nothing to remove) are ignored.
    let _ = std::fs::remove_dir_all(self.index_dir.join("index.tmp"));
    let _ = std::fs::remove_dir_all(self.index_dir.join("index.old"));
    self.maybe_seed_from_worktree(force);
    let state = IndexState::load(&self.index_dir)?;
    let index_dir = get_vector_index_path(&self.index_dir);
    // Report a non-UTF-8 path as an error instead of panicking.
    let index_path = index_dir
        .to_str()
        .context("Index directory path is not valid UTF-8")?;
    let index_exists = index_dir.join("metadata.json").exists();
    let filtering_exists = filtering::exists(index_path);
    let building = self.index_dir.join(BUILDING_MARKER).exists();
    let current_version = env!("CARGO_PKG_VERSION");
    let version_mismatch =
        index_exists && !state.cli_version.is_empty() && state.cli_version != current_version;
    if force || version_mismatch {
        let _ = std::fs::remove_file(self.index_dir.join(BUILDING_MARKER));
        return self.full_rebuild(languages);
    }
    if building || !index_exists {
        return self.build_resumable(languages);
    }
    if !filtering_exists || filtering::count(index_path).is_err() {
        eprintln!("⚠️ Filtering database corrupted, rebuilding index...");
        return self.full_rebuild(languages);
    }
    let state = if state.files.is_empty() {
        match self.reconstruct_state_from_filtering_db(index_path) {
            Ok(reconstructed) => {
                eprintln!(
                    "📋 Reconstructed state from index ({} files)",
                    reconstructed.files.len()
                );
                reconstructed.save(&self.index_dir)?;
                reconstructed
            }
            Err(_) => {
                return self.full_rebuild(languages);
            }
        }
    } else {
        state
    };
    if let Ok(metadata_count) = filtering::count(index_path) {
        if let Ok(index_metadata) = Metadata::load_from_path(&index_dir) {
            if metadata_count != index_metadata.num_documents {
                match self.reconcile_document_counts(
                    index_path,
                    metadata_count,
                    index_metadata.num_documents,
                ) {
                    Ok(()) => {
                        eprintln!(
                            "🔧 Reconciled index (filtering: {}, vector: {})",
                            metadata_count, index_metadata.num_documents
                        );
                    }
                    Err(_) => {
                        return self.full_rebuild(languages);
                    }
                }
            }
        }
    }
    self.incremental_update(&state, languages)
}
/// Like `index`, but does not wait for the index lock.
///
/// Returns `Ok(None)` when `try_acquire_index_lock` yields no lock; otherwise
/// runs the indexing logic while holding it and returns its statistics.
pub fn try_index(
    &mut self,
    languages: Option<&[Language]>,
    force: bool,
) -> Result<Option<UpdateStats>> {
    match try_acquire_index_lock(&self.index_dir)? {
        Some(_index_lock) => self.run_indexing(languages, force).map(Some),
        None => Ok(None),
    }
}
/// Indexes only the given files (paths relative to the project root) while
/// holding the index lock.
///
/// Files outside the project root, missing files, and files excluded by the
/// configured ignore rules or the project's `.gitignore` are left out. The
/// remaining files are classified by content hash against the saved state as
/// added, changed or unchanged. Added and changed files that are not on the
/// saved ignore list are parsed; old entries of changed files are deleted from
/// the index, the new units are encoded, and the state is saved.
///
/// # Errors
/// Fails when the index path is not valid UTF-8, on lock, state or hashing
/// errors, when the pipeline fails, or when the run is interrupted.
///
/// NOTE(review): when parsing yields no units the method returns before the
/// old entries of changed files are deleted and before the state is saved —
/// confirm that stale entries cannot remain for a file that lost all units.
pub fn index_specific_files(&mut self, files: &[PathBuf]) -> Result<UpdateStats> {
    if files.is_empty() {
        return Ok(UpdateStats {
            added: 0,
            changed: 0,
            deleted: 0,
            unchanged: 0,
            skipped: 0,
        });
    }
    let _lock = acquire_index_lock(&self.index_dir)?;
    let state = IndexState::load(&self.index_dir)?;
    let index_dir = get_vector_index_path(&self.index_dir);
    // Report a non-UTF-8 path as an error instead of panicking.
    let index_path = index_dir
        .to_str()
        .context("Index directory path is not valid UTF-8")?;
    // Only the root `.gitignore` is consulted here; a broken file is ignored.
    let gitignore = {
        let mut builder = GitignoreBuilder::new(&self.project_root);
        let gitignore_path = self.project_root.join(".gitignore");
        if gitignore_path.exists() {
            let _ = builder.add(&gitignore_path);
        }
        builder.build().ok()
    };
    let config = crate::config::Config::load().unwrap_or_default();
    let extra_ignore = config.extra_ignore;
    let force_include = config.force_include;
    let mut files_added = Vec::new();
    let mut files_changed = Vec::new();
    let mut unchanged = 0;
    for path in files {
        if !is_within_project_root(&self.project_root, path) {
            continue;
        }
        let full_path = self.project_root.join(path);
        if !full_path.exists() {
            continue;
        }
        if should_ignore(path, &extra_ignore, &force_include) {
            continue;
        }
        if let Some(ref gi) = gitignore {
            if gi
                .matched_path_or_any_parents(path, full_path.is_dir())
                .is_ignore()
            {
                continue;
            }
        }
        let hash = hash_file(&full_path)?;
        match state.files.get(path) {
            Some(info) if info.content_hash == hash => {
                unchanged += 1;
            }
            Some(_) => {
                files_changed.push(path.clone());
            }
            None => {
                files_added.push(path.clone());
            }
        }
    }
    let files_to_index: Vec<PathBuf> = files_added
        .iter()
        .chain(files_changed.iter())
        .filter(|p| !state.ignored_files.contains(*p))
        .cloned()
        .collect();
    if files_to_index.is_empty() {
        return Ok(UpdateStats {
            added: 0,
            changed: 0,
            deleted: 0,
            unchanged,
            skipped: 0,
        });
    }
    let mut new_state = state.clone();
    let mut new_units: Vec<CodeUnit> = Vec::new();
    let pb = ProgressBar::new(files_to_index.len() as u64);
    pb.set_style(
        ProgressStyle::default_bar()
            .template("{spinner:.green} [{bar:40.cyan/blue}] {pos}/{len} ({eta}) {msg}")
            .unwrap()
            .progress_chars("█▓░"),
    );
    pb.enable_steady_tick(std::time::Duration::from_millis(100));
    pb.set_message("Parsing files...");
    for parsed in parse_files_parallel(&self.project_root, &files_to_index, Some(&pb)) {
        if let Some(reason) = parsed.skip_reason {
            // Remember unreadable files so later runs do not retry them.
            eprintln!("⚠️ {}", reason);
            new_state.ignored_files.insert(parsed.path);
            continue;
        }
        new_units.extend(parsed.units);
        if let Some(file_info) = parsed.file_info {
            new_state.files.insert(parsed.path, file_info);
        }
    }
    pb.finish_and_clear();
    if new_units.is_empty() {
        return Ok(UpdateStats {
            added: 0,
            changed: 0,
            deleted: 0,
            unchanged,
            skipped: 0,
        });
    }
    build_call_graph(&mut new_units);
    self.ensure_model_created(new_units.len())?;
    let pb = ProgressBar::new(new_units.len() as u64);
    pb.set_style(
        ProgressStyle::default_bar()
            .template("{spinner:.green} [{bar:40.cyan/blue}] {pos}/{len} {msg}")
            .unwrap()
            .progress_chars("█▓░"),
    );
    pb.enable_steady_tick(std::time::Duration::from_millis(100));
    pb.set_message("Encoding...");
    std::fs::create_dir_all(index_path)?;
    let encode_batch_size = self.encode_batch_size.unwrap_or(DEFAULT_ENCODE_BATCH_SIZE);
    // A chunk is never smaller than one encode batch.
    let index_chunk_size = self
        .index_chunk_size
        .unwrap_or(INDEX_CHUNK_SIZE)
        .max(encode_batch_size);
    let pool_factor = self.resolve_pool_factor(new_units.len());
    // Drop the old entries of changed files before their new units are added.
    for file_path in &files_changed {
        self.delete_file_from_index(index_path, file_path)?;
    }
    let sorted_units = prepare_units_for_encoding(&new_units, index_chunk_size);
    let was_interrupted = self.run_encoding_pipeline(
        &sorted_units,
        index_chunk_size,
        pool_factor,
        index_path,
        Some(&pb),
    )?;
    pb.finish_and_clear();
    if was_interrupted || is_interrupted() {
        anyhow::bail!("Indexing interrupted by user");
    }
    new_state.save(&self.index_dir)?;
    Ok(UpdateStats {
        added: files_added.len(),
        changed: files_changed.len(),
        deleted: 0,
        unchanged,
        skipped: 0,
    })
}
/// Scans the project and returns the files matching any of `patterns`.
///
/// With no patterns, every scanned file is returned. The order of the scan is
/// kept.
pub fn scan_files_matching_patterns(&self, patterns: &[String]) -> Result<Vec<PathBuf>> {
    let (mut candidates, _skipped) = self.scan_files(None)?;
    if !patterns.is_empty() {
        candidates.retain(|path| matches_glob_pattern(path, patterns));
    }
    Ok(candidates)
}
/// Best-effort attempt to seed this index from a sibling worktree.
///
/// Does nothing when `force` is set or when the vector index already has a
/// `metadata.json`. A seeding failure is only reported on stderr; it never
/// aborts the caller.
fn maybe_seed_from_worktree(&self, force: bool) {
    if force {
        return;
    }
    let metadata_file = get_vector_index_path(&self.index_dir).join("metadata.json");
    if metadata_file.exists() {
        return;
    }
    // Whether or not a seed was found is irrelevant here; only errors are surfaced.
    if let Err(e) = self.try_seed_from_sibling_worktree() {
        eprintln!("⚠️ Worktree index seeding skipped ({e}); building from scratch");
    }
}
/// Copies the first usable sibling-worktree index into this project's index directory.
///
/// A candidate is usable when its vector index has a `metadata.json`, its
/// filtering database exists, and its saved state is non-empty, written by the
/// same CLI version, and not marked dirty.
///
/// Returns `Ok(true)` when an index was seeded and `Ok(false)` when no
/// candidate qualified.
///
/// # Errors
/// Fails if candidates cannot be enumerated or if copying / moving / saving
/// the seeded index fails.
fn try_seed_from_sibling_worktree(&self) -> Result<bool> {
    let running_version = env!("CARGO_PKG_VERSION");
    for candidate in worktree::seed_candidates(&self.project_root, &self.model_id)? {
        let donor_dir = &candidate.index_dir;
        let donor_vector = get_vector_index_path(donor_dir);
        // A non-UTF-8 path cannot be handed to the filtering API; skip it.
        let Some(donor_vector_str) = donor_vector.to_str() else {
            continue;
        };
        let donor_complete =
            donor_vector.join("metadata.json").exists() && filtering::exists(donor_vector_str);
        if !donor_complete {
            continue;
        }
        let Ok(donor_state) = IndexState::load(donor_dir) else {
            continue;
        };
        if donor_state.files.is_empty()
            || donor_state.cli_version != running_version
            || donor_state.dirty
        {
            continue;
        }

        // Copy into a temporary directory first, then rename it into place.
        let target_vector = get_vector_index_path(&self.index_dir);
        let staging = self.index_dir.join("index.tmp");
        if staging.exists() {
            std::fs::remove_dir_all(&staging)?;
        }
        worktree::copy_dir_all(&donor_vector, &staging)?;
        if target_vector.exists() {
            std::fs::remove_dir_all(&target_vector)?;
        }
        std::fs::rename(&staging, &target_vector)
            .context("Failed to move seeded index into place")?;

        donor_state.save(&self.index_dir)?;
        ProjectMetadata::new(&self.project_root, &self.model_id).save(&self.index_dir)?;
        eprintln!(
            "📋 Seeded index from worktree {} ({} files); re-embedding only changed files",
            candidate.worktree_root.display(),
            donor_state.files.len()
        );
        return Ok(true);
    }
    Ok(false)
}
/// Builds the index in checkpointed batches so that an interrupted run can be continued.
///
/// Visible steps: write the `BUILDING_MARKER` file, load the saved `IndexState`,
/// remove index entries for files that `scan_files` no longer returns, parse the
/// files not yet recorded in the state, then encode them in batches (flushed once
/// a batch holds at least `BUILD_CHECKPOINT_UNITS` units), saving the state after
/// every successful batch.
///
/// Returns `UpdateStats` where `added` counts files recorded during this call,
/// `deleted` the stale files removed, `unchanged` the files already in the state
/// before encoding started, and `skipped` the count reported by `scan_files`.
///
/// # Errors
/// Propagates I/O, state and index errors; bails with "Indexing interrupted by user"
/// on interruption and "Indexing cancelled by user" when the confirmation prompt is declined.
///
/// # Panics
/// Panics if the vector index path is not valid UTF-8 (`to_str().unwrap()`).
fn build_resumable(&mut self, languages: Option<&[Language]>) -> Result<UpdateStats> {
let index_dir = get_vector_index_path(&self.index_dir);
std::fs::create_dir_all(&index_dir)?;
// NOTE(review): unwrap panics on a non-UTF-8 path; other code in this file skips such paths instead.
let index_path = index_dir.to_str().unwrap().to_string();
// The marker is removed only on success or on user cancellation (see below).
let marker = self.index_dir.join(BUILDING_MARKER);
std::fs::write(&marker, "")?;
let mut state = IndexState::load(&self.index_dir)?;
// If a previous run left an index on disk, try to repair it; a repair failure is ignored here.
if index_dir.join("metadata.json").exists() && filtering::exists(&index_path) {
let _ = self.repair_index_db_sync(&index_dir);
}
let (scanned, skipped) = self.scan_files(languages)?;
let scanned_set: HashSet<PathBuf> = scanned.iter().cloned().collect();
// Files recorded in the state that the scan no longer returns.
let stale: Vec<PathBuf> = state
.files
.keys()
.filter(|p| !scanned_set.contains(*p))
.cloned()
.collect();
for path in &stale {
self.delete_file_from_index(&index_path, path)?;
state.files.remove(path);
}
// Remaining work: scanned files that are neither recorded nor marked as ignored.
let todo: Vec<PathBuf> = scanned
.iter()
.filter(|p| !state.files.contains_key(*p) && !state.ignored_files.contains(*p))
.cloned()
.collect();
let pb = ProgressBar::new(todo.len() as u64);
pb.set_style(
ProgressStyle::default_bar()
.template("{spinner:.green} [{bar:40.cyan/blue}] {pos}/{len} ({eta}) {msg}")
.unwrap()
.progress_chars("█▓░"),
);
pb.enable_steady_tick(std::time::Duration::from_millis(100));
pb.set_message("Parsing files...");
let mut all_units: Vec<CodeUnit> = Vec::new();
let mut file_info: HashMap<PathBuf, FileInfo> = HashMap::new();
for parsed in parse_files_parallel(&self.project_root, &todo, Some(&pb)) {
// Files the parser refused are remembered so later runs do not retry them.
if let Some(reason) = parsed.skip_reason {
eprintln!("⚠️ {}", reason);
state.ignored_files.insert(parsed.path);
continue;
}
// Units are kept only for results that also carry file info.
if let Some(fi) = parsed.file_info {
file_info.insert(parsed.path.clone(), fi);
all_units.extend(parsed.units);
}
}
let parsing_interrupted = is_interrupted();
pb.finish_and_clear();
if parsing_interrupted {
// Persist what is known so far (stale removals, ignored files) before bailing.
state.dirty = false;
state.save(&self.index_dir)?;
anyhow::bail!("Indexing interrupted by user");
}
build_call_graph(&mut all_units);
if !self.auto_confirm
&& all_units.len() > CONFIRMATION_THRESHOLD
&& !prompt_large_index_confirmation(all_units.len())
{
let _ = std::fs::remove_file(&marker);
anyhow::bail!("Indexing cancelled by user");
}
// Group units per file so each checkpoint batch contains whole files.
let mut units_by_file: HashMap<PathBuf, Vec<CodeUnit>> = HashMap::new();
for unit in all_units {
units_by_file
.entry(unit.file.clone())
.or_default()
.push(unit);
}
let encode_batch_size = self.encode_batch_size.unwrap_or(DEFAULT_ENCODE_BATCH_SIZE);
// The chunk size is never smaller than the encode batch size.
let index_chunk_size = self
.index_chunk_size
.unwrap_or(INDEX_CHUNK_SIZE)
.max(encode_batch_size);
let total_units: usize = units_by_file.values().map(|u| u.len()).sum();
let already = state.files.len();
let mut added = 0usize;
let mut recorded_empty = false;
// Files that parsed but produced no units have nothing to encode; record them right away.
for (path, fi) in &file_info {
if !units_by_file.contains_key(path) {
state.files.insert(path.clone(), fi.clone());
added += 1;
recorded_empty = true;
}
}
if recorded_empty {
state.dirty = false;
state.save(&self.index_dir)?;
}
let encode_pb = ProgressBar::new(total_units as u64);
encode_pb.set_style(
ProgressStyle::default_bar()
.template("{spinner:.green} [{bar:40.cyan/blue}] {pos}/{len} {msg}")
.unwrap()
.progress_chars("█▓░"),
);
encode_pb.enable_steady_tick(std::time::Duration::from_millis(100));
encode_pb.set_message("Encoding...");
let mut batch_files: Vec<PathBuf> = Vec::new();
let mut batch_units: Vec<CodeUnit> = Vec::new();
let mut interrupted = false;
// Process files in sorted path order so the batching is deterministic.
let mut ordered: Vec<PathBuf> = units_by_file.keys().cloned().collect();
ordered.sort();
for file in ordered {
let units = units_by_file.remove(&file).unwrap_or_default();
batch_units.extend(units);
batch_files.push(file);
if batch_units.len() >= BUILD_CHECKPOINT_UNITS {
// flush_build_batch returns true when the encoding was interrupted.
if self.flush_build_batch(
&index_path,
&batch_files,
&batch_units,
index_chunk_size,
Some(&encode_pb),
)? {
interrupted = true;
break;
}
// Checkpoint: the batch is in the index, so record its files and save the state.
for f in batch_files.drain(..) {
if let Some(fi) = file_info.get(&f) {
state.files.insert(f, fi.clone());
added += 1;
}
}
batch_units.clear();
state.dirty = false;
state.save(&self.index_dir)?; }
}
// Flush the final, partially filled batch.
if !interrupted && !batch_units.is_empty() {
if self.flush_build_batch(
&index_path,
&batch_files,
&batch_units,
index_chunk_size,
Some(&encode_pb),
)? {
interrupted = true;
} else {
for f in batch_files.drain(..) {
if let Some(fi) = file_info.get(&f) {
state.files.insert(f, fi.clone());
added += 1;
}
}
state.dirty = false;
state.save(&self.index_dir)?;
}
}
encode_pb.finish_and_clear();
if interrupted {
// Files of the interrupted batch were never recorded, so the saved state only lists completed batches.
state.dirty = false;
state.save(&self.index_dir)?;
anyhow::bail!("Indexing interrupted by user");
}
state.dirty = false;
state.save(&self.index_dir)?;
ProjectMetadata::new(&self.project_root, &self.model_id).save(&self.index_dir)?;
// Build finished: drop the marker (a removal failure is ignored).
let _ = std::fs::remove_file(&marker);
Ok(UpdateStats {
added,
changed: 0,
deleted: stale.len(),
unchanged: already,
skipped,
})
}
/// Encodes one checkpoint batch and writes it into the index at `index_path`.
///
/// Any entries already stored for `batch_files` are deleted first, so
/// re-flushing the same files does not leave duplicates behind.
///
/// Returns `Ok(true)` when the encoding pipeline reported an interruption or
/// the interrupt flag is set afterwards, `Ok(false)` otherwise (including the
/// trivial case of an empty batch).
///
/// # Errors
/// Propagates failures from deletion, model creation and the encoding pipeline.
fn flush_build_batch(
    &mut self,
    index_path: &str,
    batch_files: &[PathBuf],
    batch_units: &[CodeUnit],
    index_chunk_size: usize,
    pb: Option<&ProgressBar>,
) -> Result<bool> {
    let unit_count = batch_units.len();
    if unit_count == 0 {
        return Ok(false);
    }
    batch_files
        .iter()
        .try_for_each(|file| self.delete_file_from_index(index_path, file))?;
    self.ensure_model_created(unit_count)?;
    let pool_factor = self.resolve_pool_factor(unit_count);
    let ordered_units = prepare_units_for_encoding(batch_units, index_chunk_size);
    let pipeline_stopped = self.run_encoding_pipeline(
        &ordered_units,
        index_chunk_size,
        pool_factor,
        index_path,
        pb,
    )?;
    Ok(pipeline_stopped || is_interrupted())
}
/// Rebuilds the whole index from scratch and swaps it into place.
///
/// The new index is written to `index.tmp`; the existing index is moved to
/// `index.old` and only deleted after the new one has been renamed into place.
/// If that rename fails, the old index is moved back when possible.
///
/// Returns `UpdateStats` with `added` set to the number of scanned files and
/// `skipped` set to the count reported by `scan_files`.
///
/// # Errors
/// Propagates I/O, parsing-pipeline and state errors; bails with
/// "Indexing interrupted by user" or "Indexing cancelled by user".
fn full_rebuild(&mut self, languages: Option<&[Language]>) -> Result<UpdateStats> {
// A full rebuild supersedes any resumable build in progress; a missing marker is fine.
let _ = std::fs::remove_file(self.index_dir.join(BUILDING_MARKER));
let index_path = get_vector_index_path(&self.index_dir);
let temp_path = self.index_dir.join("index.tmp");
let old_path = self.index_dir.join("index.old");
// Clear leftovers of an earlier attempt.
if temp_path.exists() {
std::fs::remove_dir_all(&temp_path)?;
}
if old_path.exists() {
std::fs::remove_dir_all(&old_path)?;
}
let (files, skipped) = self.scan_files(languages)?;
let mut state = IndexState::default();
let mut all_units: Vec<CodeUnit> = Vec::new();
let pb = ProgressBar::new(files.len() as u64);
pb.set_style(
ProgressStyle::default_bar()
.template("{spinner:.green} [{bar:40.cyan/blue}] {pos}/{len} ({eta}) {msg}")
.unwrap()
.progress_chars("█▓░"),
);
pb.enable_steady_tick(std::time::Duration::from_millis(100));
pb.set_message("Parsing files...");
for parsed in parse_files_parallel(&self.project_root, &files, Some(&pb)) {
// Files the parser refused are recorded as ignored rather than indexed.
if let Some(reason) = parsed.skip_reason {
eprintln!("⚠️ {}", reason);
state.ignored_files.insert(parsed.path);
continue;
}
all_units.extend(parsed.units);
if let Some(file_info) = parsed.file_info {
state.files.insert(parsed.path, file_info);
}
}
let parsing_interrupted = is_interrupted();
pb.finish_and_clear();
if parsing_interrupted {
eprintln!("⚠️ Indexing interrupted during parsing. Partial index not saved.");
anyhow::bail!("Indexing interrupted by user");
}
build_call_graph(&mut all_units);
if !self.auto_confirm
&& all_units.len() > CONFIRMATION_THRESHOLD
&& !prompt_large_index_confirmation(all_units.len())
{
anyhow::bail!("Indexing cancelled by user");
}
// Encode into the temporary directory; with no units there is nothing to write.
let was_interrupted = if !all_units.is_empty() {
self.ensure_model_created(all_units.len())?;
// Print the cuDNN notice at most once per process; the env var serves as the "already shown" flag.
#[cfg(feature = "cuda")]
if !crate::onnx_runtime::is_cudnn_available()
&& std::env::var("_COLGREP_CUDNN_NOTICE").is_err()
{
std::env::set_var("_COLGREP_CUDNN_NOTICE", "1");
eprintln!("📂 cuDNN not found, encoding will use CPU.");
}
self.write_index_impl(&all_units, true, Some(&temp_path))?
} else {
false
};
if was_interrupted {
// Discard the partial temporary index; the existing index is untouched.
let _ = std::fs::remove_dir_all(&temp_path);
anyhow::bail!("Indexing interrupted by user");
}
if all_units.is_empty() {
// Nothing to index: remove any existing index directory.
if index_path.exists() {
std::fs::remove_dir_all(&index_path)?;
}
} else {
// Swap: current -> index.old, index.tmp -> current, then drop index.old.
if index_path.exists() {
std::fs::rename(&index_path, &old_path)
.context("Failed to move old index aside")?;
}
if let Err(e) = std::fs::rename(&temp_path, &index_path) {
// Roll back so the previous index stays usable.
if old_path.exists() && !index_path.exists() {
let _ = std::fs::rename(&old_path, &index_path);
}
return Err(anyhow::anyhow!(
"Failed to move new index into place: {}",
e
));
}
if old_path.exists() {
let _ = std::fs::remove_dir_all(&old_path);
}
}
state.save(&self.index_dir)?;
ProjectMetadata::new(&self.project_root, &self.model_id).save(&self.index_dir)?;
// NOTE(review): `added` counts every scanned file, including ones the parser skipped above.
Ok(UpdateStats {
added: files.len(),
changed: 0,
deleted: 0,
unchanged: 0,
skipped,
})
}
/// Applies the difference between `old_state` and the current file tree to the existing index.
///
/// The update plan comes from `compute_update_plan`. While the index is being
/// modified the saved state is marked `dirty`; the flag is cleared only after
/// all deletions and encodings have completed. A state that was already dirty
/// triggers a repair first, and a failed repair falls back to `full_rebuild`.
///
/// Returns `UpdateStats` with the added / changed / deleted / unchanged counts of the plan.
///
/// # Errors
/// Propagates I/O, index and state errors; bails with "Indexing interrupted by user"
/// or "Indexing cancelled by user".
///
/// # Panics
/// Panics if the vector index path is not valid UTF-8 (`to_str().unwrap()`).
fn incremental_update(
&mut self,
old_state: &IndexState,
languages: Option<&[Language]>,
) -> Result<UpdateStats> {
let plan = self.compute_update_plan(old_state, languages)?;
let index_dir = get_vector_index_path(&self.index_dir);
// NOTE(review): unwrap panics on a non-UTF-8 path.
let index_path = index_dir.to_str().unwrap();
// A dirty state means an earlier update did not finish; repair before touching the index.
if old_state.dirty {
if let Err(e) = self.repair_index_db_sync(&index_dir) {
eprintln!("⚠️ Repair failed: {}, falling back to full rebuild", e);
return self.full_rebuild(languages);
}
}
// Remove index entries whose file no longer exists on disk.
let orphaned_deleted = self.cleanup_orphaned_entries(index_path)?;
// Nothing to do: report only the unchanged count.
if plan.added.is_empty()
&& plan.changed.is_empty()
&& plan.deleted.is_empty()
&& orphaned_deleted == 0
{
return Ok(UpdateStats {
added: 0,
changed: 0,
deleted: 0,
unchanged: plan.unchanged,
skipped: 0,
});
}
let mut state = old_state.clone();
// Mark the state dirty on disk before modifying the index.
if !plan.deleted.is_empty() || !plan.changed.is_empty() || !plan.added.is_empty() {
state.dirty = true;
state.save(&self.index_dir)?;
}
for file_path in &plan.deleted {
self.delete_file_from_index(index_path, file_path)?;
}
for path in &plan.deleted {
state.files.remove(path);
}
// Also forget recorded files that no longer exist on disk.
let stale_paths: Vec<PathBuf> = state
.files
.keys()
.filter(|p| !self.project_root.join(p).exists())
.cloned()
.collect();
for path in stale_paths {
state.files.remove(&path);
}
// Added and changed files are (re)parsed, except those already marked as ignored.
let files_to_index: Vec<PathBuf> = plan
.added
.iter()
.chain(plan.changed.iter())
.filter(|p| !state.ignored_files.contains(*p))
.cloned()
.collect();
let mut new_units: Vec<CodeUnit> = Vec::new();
// Only show a parsing progress bar when there is something to parse.
let pb = if !files_to_index.is_empty() {
let pb = ProgressBar::new(files_to_index.len() as u64);
pb.set_style(
ProgressStyle::default_bar()
.template("{spinner:.green} [{bar:40.cyan/blue}] {pos}/{len} ({eta}) {msg}")
.unwrap()
.progress_chars("█▓░"),
);
pb.enable_steady_tick(std::time::Duration::from_millis(100));
pb.set_message("Parsing files...");
Some(pb)
} else {
None
};
let mut skipped_files: Vec<PathBuf> = Vec::new();
for parsed in parse_files_parallel(&self.project_root, &files_to_index, pb.as_ref()) {
// A file the parser refused moves from the recorded files to the ignored set.
if let Some(reason) = parsed.skip_reason {
eprintln!("⚠️ {}", reason);
state.files.remove(&parsed.path);
state.ignored_files.insert(parsed.path.clone());
skipped_files.push(parsed.path);
continue;
}
new_units.extend(parsed.units);
if let Some(file_info) = parsed.file_info {
state.files.insert(parsed.path, file_info);
}
}
let parsing_interrupted = is_interrupted();
if let Some(pb) = pb {
pb.finish_and_clear();
}
if parsing_interrupted {
// The dirty state saved above stays on disk.
anyhow::bail!("Indexing interrupted by user");
}
// A changed file that is now skipped must not keep its old entries (best effort, errors ignored).
for file_path in &skipped_files {
if plan.changed.contains(file_path) {
let _ = self.delete_file_from_index(index_path, file_path);
}
}
let mut was_interrupted = false;
if !new_units.is_empty() {
build_call_graph(&mut new_units);
if !self.auto_confirm
&& new_units.len() > CONFIRMATION_THRESHOLD
&& !prompt_large_index_confirmation(new_units.len())
{
anyhow::bail!("Indexing cancelled by user");
}
self.ensure_model_created(new_units.len())?;
let pb = ProgressBar::new(new_units.len() as u64);
pb.set_style(
ProgressStyle::default_bar()
.template("{spinner:.green} [{bar:40.cyan/blue}] {pos}/{len} {msg}")
.unwrap()
.progress_chars("█▓░"),
);
pb.enable_steady_tick(std::time::Duration::from_millis(100));
pb.set_message("Encoding...");
let encode_batch_size = self.encode_batch_size.unwrap_or(DEFAULT_ENCODE_BATCH_SIZE);
// The chunk size is never smaller than the encode batch size.
let index_chunk_size = self
.index_chunk_size
.unwrap_or(INDEX_CHUNK_SIZE)
.max(encode_batch_size);
let pool_factor = self.resolve_pool_factor(new_units.len());
// Drop the old entries of changed files before their new units are written.
for file_path in &plan.changed {
self.delete_file_from_index(index_path, file_path)?;
}
let sorted_units = prepare_units_for_encoding(&new_units, index_chunk_size);
let pipeline_interrupted = self.run_encoding_pipeline(
&sorted_units,
index_chunk_size,
pool_factor,
index_path,
Some(&pb),
)?;
was_interrupted |= pipeline_interrupted;
pb.finish_and_clear();
}
if was_interrupted || is_interrupted() {
anyhow::bail!("Indexing interrupted by user");
}
// Everything succeeded: clear the dirty flag and persist the new state.
state.dirty = false;
state.save(&self.index_dir)?;
Ok(UpdateStats {
added: plan.added.len(),
changed: plan.changed.len(),
deleted: plan.deleted.len(),
unchanged: plan.unchanged,
skipped: 0,
})
}
/// Walks the project tree and collects the indexable source files.
///
/// Hidden entries are visited, `.gitignore` rules are honoured and symlinks
/// are not followed; directories and files rejected by `should_ignore` are
/// pruned during the walk. A file is kept when it is a regular file, is not
/// larger than `MAX_FILE_SIZE`, has a detectable language (restricted to
/// `languages` when given) and resolves inside the project root.
///
/// Returns the project-relative paths plus the number of files skipped for
/// being too large or for resolving outside the project root.
fn scan_files(&self, languages: Option<&[Language]>) -> Result<(Vec<PathBuf>, usize)> {
    let config = crate::config::Config::load().unwrap_or_default();
    let extra_ignore = config.extra_ignore.clone();
    let force_include = config.force_include.clone();
    // The filter closure must own its data, hence the cloned root.
    let filter_root = self.project_root.clone();

    let walker = WalkBuilder::new(&self.project_root)
        .hidden(false)
        .git_ignore(true)
        .follow_links(false)
        .filter_entry(move |entry| {
            let absolute = entry.path();
            match absolute.strip_prefix(&filter_root) {
                // The root itself (empty relative path) is always walked.
                Ok(relative) => {
                    relative.as_os_str().is_empty()
                        || !should_ignore(relative, &extra_ignore, &force_include)
                }
                Err(_) => !should_ignore(absolute, &extra_ignore, &force_include),
            }
        })
        .build();

    let mut files = Vec::new();
    let mut skipped = 0usize;
    // Entries the walker failed to read are silently dropped.
    for entry in walker.flatten() {
        let is_regular_file = entry.file_type().map_or(false, |kind| kind.is_file());
        if !is_regular_file {
            continue;
        }
        let absolute = entry.path();
        if is_file_too_large(absolute) {
            skipped += 1;
            continue;
        }
        let Some(language) = detect_language(absolute) else {
            continue;
        };
        let wanted = languages.map_or(true, |allowed| allowed.contains(&language));
        if !wanted {
            continue;
        }
        let Ok(relative) = absolute.strip_prefix(&self.project_root) else {
            continue;
        };
        if is_within_project_root(&self.project_root, relative) {
            files.push(relative.to_path_buf());
        } else {
            skipped += 1;
        }
    }
    Ok((files, skipped))
}
}
/// Returns `true` when the file at `path` is larger than `MAX_FILE_SIZE` bytes.
///
/// A file whose metadata cannot be read is treated as not too large.
fn is_file_too_large(path: &Path) -> bool {
    std::fs::metadata(path).map_or(false, |meta| meta.len() > MAX_FILE_SIZE)
}
/// Checks that `relative_path`, joined onto `project_root`, resolves to a
/// location inside the project root.
///
/// * A path without `..` anywhere in its text that does not exist yet is accepted.
/// * Otherwise both the joined path and the root are canonicalized, and the
///   result is whether the former starts with the latter; any
///   canonicalization failure yields `false`.
fn is_within_project_root(project_root: &Path, relative_path: &Path) -> bool {
    let candidate = project_root.join(relative_path);
    let mentions_parent = relative_path.to_string_lossy().contains("..");
    if !mentions_parent && !candidate.exists() {
        return true;
    }
    let Ok(resolved) = candidate.canonicalize() else {
        return false;
    };
    let Ok(resolved_root) = project_root.canonicalize() else {
        return false;
    };
    resolved.starts_with(&resolved_root)
}
const IGNORED_DIRS: &[&str] = &[
".git",
".svn",
".hg",
"node_modules",
"vendor",
"third_party",
"third-party",
"external",
"target",
"build",
"dist",
"out",
"bin",
"obj",
"__pycache__",
".venv",
"venv",
".env",
"env",
".tox",
".nox",
".pytest_cache",
".mypy_cache",
".ruff_cache",
"*.egg-info",
".eggs",
".next",
".nuxt",
".output",
".cache",
".parcel-cache",
".turbo",
"target",
"go.sum",
".gradle",
".m2",
".idea",
".vscode",
".vs",
"*.xcworkspace",
"*.xcodeproj",
"coverage",
".coverage",
"htmlcov",
".nyc_output",
"tmp",
"temp",
"logs",
".DS_Store",
];
/// Dot-prefixed names that `should_ignore` does not reject for being hidden.
///
/// Any other path component starting with `.` is ignored unless it is listed
/// here or in `ALLOWED_HIDDEN_FILES`.
const ALLOWED_HIDDEN_DIRS: &[&str] = &[
".github",
".gitlab",
".circleci",
".buildkite",
".claude",
".claude-plugin",
];
/// Dot-prefixed file names that `should_ignore` does not reject for being hidden.
const ALLOWED_HIDDEN_FILES: &[&str] = &[".gitlab-ci.yml", ".gitlab-ci.yaml", ".travis.yml"];
/// Returns the first exact-name entry of `IGNORED_DIRS` that appears as a
/// component of `path`, or `None` when there is none.
///
/// Components are examined from left to right; wildcard entries (those
/// starting with `*`) are not considered.
pub fn path_contains_ignored_dir(path: &Path) -> Option<&'static str> {
    path.components()
        .filter_map(|part| match part {
            std::path::Component::Normal(name) => Some(name.to_string_lossy()),
            _ => None,
        })
        .find_map(|name| {
            IGNORED_DIRS
                .iter()
                .copied()
                .find(|pattern| !pattern.starts_with('*') && name == *pattern)
        })
}
/// Decides whether `path` must be excluded from indexing.
///
/// 1. If the whole path matches a `force_include` pattern (suffix pattern
///    `*…`, exact match, or the pattern followed by `/` as a prefix) the path is kept.
/// 2. Otherwise every normal component is checked. A component that itself
///    matches a `force_include` pattern is exempt. Any other component causes
///    exclusion when it is hidden (starts with `.` and is not in the
///    allow-lists), matches `IGNORED_DIRS`, or matches `extra_ignore`.
fn should_ignore(path: &Path, extra_ignore: &[String], force_include: &[String]) -> bool {
    // `*suffix` patterns match by suffix; everything else must match exactly.
    fn name_matches(pattern: &str, name: &str) -> bool {
        match pattern.strip_prefix('*') {
            Some(suffix) => name.ends_with(suffix),
            None => name == pattern,
        }
    }

    let whole_path = path.to_string_lossy();
    let path_force_included = force_include
        .iter()
        .any(|pattern| match pattern.strip_prefix('*') {
            Some(suffix) => whole_path.ends_with(suffix),
            None => {
                whole_path == *pattern || whole_path.starts_with(&format!("{}/", pattern))
            }
        });
    if path_force_included {
        return false;
    }

    for component in path.components() {
        let std::path::Component::Normal(raw_name) = component else {
            continue;
        };
        let lossy_name = raw_name.to_string_lossy();
        let name: &str = lossy_name.as_ref();

        if force_include.iter().any(|pattern| name_matches(pattern, name)) {
            continue;
        }
        let hidden = name.starts_with('.')
            && !ALLOWED_HIDDEN_DIRS.contains(&name)
            && !ALLOWED_HIDDEN_FILES.contains(&name);
        if hidden {
            return true;
        }
        if IGNORED_DIRS.iter().any(|pattern| name_matches(pattern, name)) {
            return true;
        }
        if extra_ignore.iter().any(|pattern| name_matches(pattern, name)) {
            return true;
        }
    }
    false
}
impl IndexBuilder {
/// Compares the files currently on disk with the recorded `state`.
///
/// * Files marked as ignored in `state` are left out of the plan.
/// * A file whose content hash equals the recorded hash counts as unchanged.
/// * A recorded file with a different hash is `changed`; an unrecorded file is `added`.
/// * A recorded file that the scan no longer returns is `deleted`.
///
/// Files that cannot be hashed are reported on stderr and left out of the plan.
///
/// # Errors
/// Propagates any error returned by `scan_files`.
fn compute_update_plan(
    &self,
    state: &IndexState,
    languages: Option<&[Language]>,
) -> Result<UpdatePlan> {
    let (current_files, _skipped) = self.scan_files(languages)?;
    let current_set: HashSet<_> = current_files.iter().cloned().collect();
    let mut plan = UpdatePlan::default();
    for path in &current_files {
        if state.ignored_files.contains(path) {
            continue;
        }
        let full_path = self.project_root.join(path);
        let hash = match hash_file(&full_path) {
            Ok(h) => h,
            Err(e) => {
                eprintln!("⚠️ Skipping {} ({})", full_path.display(), e);
                continue;
            }
        };
        match state.files.get(path) {
            Some(info) if info.content_hash == hash => plan.unchanged += 1,
            Some(_) => plan.changed.push(path.clone()),
            None => plan.added.push(path.clone()),
        }
    }
    for path in state.files.keys() {
        if !current_set.contains(path) {
            plan.deleted.push(path.clone());
        }
    }
    Ok(plan)
}
/// Removes every index entry whose `file` metadata equals `file_path`.
///
/// A failed metadata lookup is treated as "no entries" and therefore as success.
///
/// # Errors
/// Fails when the entries were found but could not be deleted from the
/// vector index or from the metadata store.
fn delete_file_from_index(&self, index_path: &str, file_path: &Path) -> Result<()> {
    let file_key = file_path.to_string_lossy().to_string();
    let params = [serde_json::json!(file_key)];
    let doc_ids =
        filtering::where_condition(index_path, "file = ?", &params).unwrap_or_default();
    if doc_ids.is_empty() {
        return Ok(());
    }
    // Vectors first, then the metadata rows.
    delete_from_index(&doc_ids, index_path)?;
    filtering::delete(index_path, &doc_ids)?;
    Ok(())
}
/// Deletes index entries whose source file no longer exists under the project root.
///
/// Returns the number of deleted entries. Failed metadata lookups are treated
/// as empty results.
///
/// # Errors
/// Fails when entries were found but could not be deleted.
fn cleanup_orphaned_entries(&self, index_path: &str) -> Result<usize> {
    let indexed_files = filtering::get_distinct_strings(index_path, "file").unwrap_or_default();
    let mut removed = 0usize;
    for indexed in indexed_files {
        if self.project_root.join(&indexed).exists() {
            continue;
        }
        let params = [serde_json::json!(indexed)];
        let doc_ids =
            filtering::where_condition(index_path, "file = ?", &params).unwrap_or_default();
        if doc_ids.is_empty() {
            continue;
        }
        delete_from_index(&doc_ids, index_path)?;
        filtering::delete(index_path, &doc_ids)?;
        removed += doc_ids.len();
    }
    Ok(removed)
}
/// Encodes `units` into the project's default vector index without a progress bar.
///
/// Returns the interruption flag produced by `write_index_impl`.
#[allow(dead_code)]
fn write_index(&mut self, units: &[CodeUnit]) -> Result<bool> {
self.write_index_impl(units, false, None)
}
/// Encodes `units` into the project's default vector index while showing a progress bar.
///
/// Returns the interruption flag produced by `write_index_impl`.
#[allow(dead_code)]
fn write_index_with_progress(&mut self, units: &[CodeUnit]) -> Result<bool> {
self.write_index_impl(units, true, None)
}
fn write_index_impl(
&mut self,
units: &[CodeUnit],
show_progress: bool,
target_index_path: Option<&Path>,
) -> Result<bool> {
let index_dir = target_index_path
.map(|p| p.to_path_buf())
.unwrap_or_else(|| get_vector_index_path(&self.index_dir));
let index_path = index_dir.to_str().unwrap();
std::fs::create_dir_all(index_path)?;
let pb = if show_progress {
let pb = ProgressBar::new(units.len() as u64);
pb.set_style(
ProgressStyle::default_bar()
.template("{spinner:.green} [{bar:40.cyan/blue}] {pos}/{len} {msg}")
.unwrap()
.progress_chars("█▓░"),
);
pb.enable_steady_tick(std::time::Duration::from_millis(100));
pb.set_message("Encoding...");
Some(pb)
} else {
None
};
let encode_batch_size = self.encode_batch_size.unwrap_or(DEFAULT_ENCODE_BATCH_SIZE);
let index_chunk_size = self
.index_chunk_size
.unwrap_or(INDEX_CHUNK_SIZE)
.max(encode_batch_size);
let pool_factor = self.resolve_pool_factor(units.len());
let sorted_units = prepare_units_for_encoding(units, index_chunk_size);
self.ensure_model_created(units.len())?;
let was_interrupted = self.run_encoding_pipeline(
&sorted_units,
index_chunk_size,
pool_factor,
index_path,
pb.as_ref(),
)?;
if let Some(pb) = pb {
pb.finish_and_clear();
}
Ok(was_interrupted || is_interrupted())
}
/// Loads the saved index state and reports what an update would do, without modifying the index.
///
/// # Errors
/// Fails when the state cannot be loaded or the update plan cannot be computed.
pub fn status(&self, languages: Option<&[Language]>) -> Result<UpdatePlan> {
let state = IndexState::load(&self.index_dir)?;
self.compute_update_plan(&state, languages)
}
}
/// A single search hit: a code unit together with its score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
/// The matched code unit.
pub unit: CodeUnit,
/// Score assigned to this hit. NOTE(review): presumably higher is better; confirm against the search code.
pub score: f32,
}
/// Rewrites the BRE-style escapes in `pattern` into their ERE form.
///
/// * `\|` becomes `|`.
/// * `\(` / `\)` become `(` / `)` when they form a matched pair.
/// * `\{` / `\}` become `{` / `}` when they form a matched pair, except that a
///   `\{` at the very start of the output stays escaped, and for each such
///   opener one later paired `\}` stays escaped as well.
/// * Every other escape (including `\\`, `\+`, `\?`, unpaired group or brace
///   escapes) and a trailing lone backslash are copied unchanged.
pub fn bre_to_ere(pattern: &str) -> String {
    /// Flags the backslash positions of every matched `\opener` … `\closer` pair.
    fn flag_matched_escapes(symbols: &[char], flags: &mut [bool], opener: char, closer: char) {
        let mut pending: Vec<usize> = Vec::new();
        let mut pos = 0;
        while pos < symbols.len() {
            let escaped = match symbols.get(pos + 1) {
                Some(&following) if symbols[pos] == '\\' => following,
                _ => {
                    pos += 1;
                    continue;
                }
            };
            if escaped == opener {
                pending.push(pos);
            } else if escaped == closer {
                if let Some(start) = pending.pop() {
                    flags[start] = true;
                    flags[pos] = true;
                }
            }
            // An escape always consumes the backslash and the following character.
            pos += 2;
        }
    }

    let symbols: Vec<char> = pattern.chars().collect();
    let mut flags = vec![false; symbols.len()];
    flag_matched_escapes(&symbols, &mut flags, '(', ')');
    flag_matched_escapes(&symbols, &mut flags, '{', '}');

    let mut out = String::with_capacity(pattern.len());
    // Number of paired `\}` that must stay escaped because their `\{` opened the output.
    let mut literal_closers_owed = 0usize;
    let mut pos = 0;
    while pos < symbols.len() {
        let current = symbols[pos];
        let escaped = match symbols.get(pos + 1) {
            Some(&following) if current == '\\' => following,
            _ => {
                out.push(current);
                pos += 1;
                continue;
            }
        };
        let paired = flags[pos];
        match escaped {
            '|' => out.push('|'),
            '(' | ')' if paired => out.push(escaped),
            '{' if paired => {
                if out.is_empty() {
                    literal_closers_owed += 1;
                    out.push_str("\\{");
                } else {
                    out.push('{');
                }
            }
            '}' if paired => {
                if literal_closers_owed > 0 {
                    literal_closers_owed -= 1;
                    out.push_str("\\}");
                } else {
                    out.push('}');
                }
            }
            other => {
                out.push('\\');
                out.push(other);
            }
        }
        pos += 2;
    }
    out
}
/// Rewrites braces that cannot be quantifiers as single-character classes.
///
/// * Text inside a `[...]` character class is copied unchanged.
/// * `\{` and `\}` become `[{]` and `[}]`; other escapes are copied unchanged.
/// * `{...}` is copied as-is when its content is a valid quantifier
///   (`{n}`, `{n,}`, `{,m}`, `{n,m}`); otherwise the `{` becomes `[{]`.
/// * A `}` reached outside a quantifier becomes `[}]`.
pub fn escape_literal_braces(pattern: &str) -> String {
    let symbols: Vec<char> = pattern.chars().collect();
    let mut out = String::with_capacity(pattern.len() + 10);
    let mut inside_class = false;
    let mut pos = 0;
    while pos < symbols.len() {
        let current = symbols[pos];
        let preceded_by_backslash = pos > 0 && symbols[pos - 1] == '\\';
        match current {
            '[' if !preceded_by_backslash => {
                inside_class = true;
                out.push(current);
                pos += 1;
            }
            ']' if inside_class && !preceded_by_backslash => {
                inside_class = false;
                out.push(current);
                pos += 1;
            }
            _ if inside_class => {
                out.push(current);
                pos += 1;
            }
            '\\' if pos + 1 < symbols.len() => {
                let escaped = symbols[pos + 1];
                if escaped == '{' || escaped == '}' {
                    out.push('[');
                    out.push(escaped);
                    out.push(']');
                } else {
                    out.push('\\');
                    out.push(escaped);
                }
                pos += 2;
            }
            '{' => {
                // Keep the braces only when they enclose a valid quantifier.
                let quantifier_end = find_matching_brace(&symbols, pos).filter(|&end| {
                    let inner: String = symbols[pos + 1..end].iter().collect();
                    is_valid_quantifier(&inner)
                });
                match quantifier_end {
                    Some(end) => {
                        out.extend(&symbols[pos..=end]);
                        pos = end + 1;
                    }
                    None => {
                        out.push_str("[{]");
                        pos += 1;
                    }
                }
            }
            '}' => {
                out.push_str("[}]");
                pos += 1;
            }
            _ => {
                out.push(current);
                pos += 1;
            }
        }
    }
    out
}
/// Finds the `}` that closes the `{` at `open_pos`.
///
/// Returns `None` when another `{` appears before any `}` (nesting is not
/// supported) or when there is no `}` after `open_pos`.
fn find_matching_brace(chars: &[char], open_pos: usize) -> Option<usize> {
    let start = open_pos + 1;
    // The first brace of either kind after the opener decides the outcome.
    let offset = chars
        .get(start..)?
        .iter()
        .position(|&ch| ch == '}' || ch == '{')?;
    let index = start + offset;
    (chars[index] == '}').then_some(index)
}
/// Reports whether `content` (the text between `{` and `}`) is a regex
/// repetition quantifier.
///
/// Accepted forms are `n`, `n,`, `,m` and `n,m`, where `n` and `m` consist of
/// ASCII digits only. An empty string, a lone `,`, and anything with more than
/// one comma are rejected.
fn is_valid_quantifier(content: &str) -> bool {
    let digits_only = |part: &str| part.chars().all(|c| c.is_ascii_digit());
    let mut bounds = content.split(',');
    match (bounds.next(), bounds.next(), bounds.next()) {
        // `{n}`: a single, non-empty run of digits.
        (Some(exact), None, _) => !exact.is_empty() && digits_only(exact),
        // `{n,m}` with either bound optional, but not both missing.
        (Some(lower), Some(upper), None) => {
            digits_only(lower) && digits_only(upper) && !(lower.is_empty() && upper.is_empty())
        }
        _ => false,
    }
}
/// Expands shell-style brace alternatives: `*.{rs,py}` yields `*.rs` and `*.py`.
///
/// Multiple groups are expanded recursively (`{a,b}{c,d}` yields four
/// patterns) and nested groups are supported (`{a,{b,c}}` yields `a`, `b`,
/// `c`). A pattern without a complete `{...}` group is returned unchanged as
/// the only element.
///
/// A group with an empty body (`{}`) produces no patterns at all, as before.
fn expand_braces(pattern: &str) -> Vec<String> {
    /// Byte index of the `}` matching the `{` at `open`.
    ///
    /// Falls back to the first `}` after `open` when the braces are unbalanced,
    /// so malformed patterns expand as they did before nesting was supported.
    fn closing_brace(pattern: &str, open: usize) -> Option<usize> {
        let mut depth = 0usize;
        for (offset, c) in pattern[open..].char_indices() {
            match c {
                '{' => depth += 1,
                '}' => {
                    // The scan starts on `{`, so depth is at least 1 here.
                    depth -= 1;
                    if depth == 0 {
                        return Some(open + offset);
                    }
                }
                _ => {}
            }
        }
        pattern[open..].find('}').map(|offset| open + offset)
    }

    let Some(start) = pattern.find('{') else {
        return vec![pattern.to_string()];
    };
    let Some(end) = closing_brace(pattern, start) else {
        return vec![pattern.to_string()];
    };
    let prefix = &pattern[..start];
    let alternatives = &pattern[start + 1..end];
    let suffix = &pattern[end + 1..];

    let mut results = Vec::new();
    let mut current = String::new();
    let mut depth = 0usize;
    for c in alternatives.chars() {
        match c {
            '{' => {
                depth += 1;
                current.push(c);
            }
            '}' => {
                depth = depth.saturating_sub(1);
                current.push(c);
            }
            // Only commas at the top level of this group separate alternatives.
            ',' if depth == 0 => {
                let expanded = format!("{}{}{}", prefix, current, suffix);
                results.extend(expand_braces(&expanded));
                current.clear();
            }
            _ => current.push(c),
        }
    }
    // A trailing comma denotes a final empty alternative.
    if !current.is_empty() || alternatives.ends_with(',') {
        let expanded = format!("{}{}{}", prefix, current, suffix);
        results.extend(expand_braces(&expanded));
    }
    results
}
/// Compiles `patterns` into a `GlobSet`.
///
/// Brace alternatives are expanded first. Every resulting pattern that does
/// not start with `**/` or `/` is prefixed with `**/` so that it matches at
/// any depth. Patterns that fail to compile are skipped.
///
/// Returns `None` for an empty `patterns` slice or when the set cannot be built.
fn build_glob_set(patterns: &[String]) -> Option<GlobSet> {
    if patterns.is_empty() {
        return None;
    }
    let mut set_builder = GlobSetBuilder::new();
    for raw in patterns {
        for expanded in expand_braces(raw) {
            let anchored = expanded.starts_with("**/") || expanded.starts_with('/');
            let candidate = if anchored {
                expanded
            } else {
                format!("**/{}", expanded)
            };
            if let Ok(glob) = Glob::new(&candidate) {
                set_builder.add(glob);
            }
        }
    }
    set_builder.build().ok()
}
/// Translates a file glob into a regular expression that ends with `$`.
///
/// * Patterns not starting with `**/` or `/` get the prefix `(^|.*/)`, so they
///   can match at any directory depth.
/// * `**/` becomes `(.*/)?`, any other `**` becomes `.*`, a single `*`
///   becomes `[^/]*`, and `?` becomes `.`.
/// * Regex metacharacters are backslash-escaped; all other characters are
///   copied unchanged.
fn glob_to_regex(pattern: &str) -> String {
    let anchored = pattern.starts_with("**/") || pattern.starts_with('/');
    let mut out = if anchored {
        String::new()
    } else {
        String::from("(^|.*/)")
    };
    let mut rest = pattern.chars().peekable();
    while let Some(ch) = rest.next() {
        match ch {
            // `**`: the guard consumes the second star only when it is present.
            '*' if rest.next_if_eq(&'*').is_some() => {
                if rest.next_if_eq(&'/').is_some() {
                    out.push_str("(.*/)?");
                } else {
                    out.push_str(".*");
                }
            }
            '*' => out.push_str("[^/]*"),
            '?' => out.push('.'),
            '.' | '+' | '(' | ')' | '[' | ']' | '{' | '}' | '^' | '$' | '|' | '\\' => {
                out.push('\\');
                out.push(ch);
            }
            _ => out.push(ch),
        }
    }
    out.push('$');
    out
}
/// Returns `true` when `pattern` contains a glob metacharacter (`*`, `?` or `[`).
fn is_glob_pattern(pattern: &str) -> bool {
    pattern.chars().any(|c| matches!(c, '*' | '?' | '['))
}
/// Translates a directory pattern into a regular expression that ends with `/`.
///
/// A pattern without glob metacharacters is regex-escaped and wrapped as
/// `(^|/)<name>/`. For glob patterns the leading part selects the anchor:
/// `**/` gives `(^|.*/)`, `*/` gives `^[^/]*/`, `/` gives `^`, and anything
/// else gives `(^|/)`. The remainder is translated like a file glob
/// (`**/` to `(.*/)?`, `**` to `.*`, `*` to `[^/]*`, `?` to `.`, regex
/// metacharacters escaped).
fn dir_pattern_to_regex(pattern: &str) -> String {
    if !is_glob_pattern(pattern) {
        return format!("(^|/){}/", regex::escape(pattern));
    }

    let (anchor, body) = if let Some(rest) = pattern.strip_prefix("**/") {
        ("(^|.*/)", rest)
    } else if let Some(rest) = pattern.strip_prefix("*/") {
        ("^[^/]*/", rest)
    } else if let Some(rest) = pattern.strip_prefix('/') {
        ("^", rest)
    } else {
        ("(^|/)", pattern)
    };

    let mut out = String::from(anchor);
    let mut rest = body.chars().peekable();
    while let Some(ch) = rest.next() {
        match ch {
            // `**`: the guard consumes the second star only when it is present.
            '*' if rest.next_if_eq(&'*').is_some() => {
                if rest.next_if_eq(&'/').is_some() {
                    out.push_str("(.*/)?");
                } else {
                    out.push_str(".*");
                }
            }
            '*' => out.push_str("[^/]*"),
            '?' => out.push('.'),
            '.' | '+' | '(' | ')' | '[' | ']' | '{' | '}' | '^' | '$' | '|' | '\\' => {
                out.push('\\');
                out.push(ch);
            }
            _ => out.push(ch),
        }
    }
    out.push('/');
    out
}
/// Tests `path` against `patterns`.
///
/// An empty pattern list matches everything. When the patterns cannot be
/// compiled into a glob set, nothing matches.
fn matches_glob_pattern(path: &Path, patterns: &[String]) -> bool {
    patterns.is_empty() || build_glob_set(patterns).map_or(false, |set| set.is_match(path))
}
/// Bundles a loaded ColBERT model with a memory-mapped index and the index path.
pub struct Searcher {
// Model built from the model path given to the `load*` constructors.
model: Colbert,
// Index loaded from `index_path`.
index: MmapIndex,
// Vector index directory as a string; passed to the `filtering` calls of the `filter_*` methods.
index_path: String,
}
impl Searcher {
/// Loads a searcher for `project_root` / `model_id` with quantization disabled.
///
/// Shorthand for `load_with_quantized(project_root, model_id, model_path, false)`.
pub fn load(project_root: &Path, model_id: &str, model_path: &Path) -> Result<Self> {
Self::load_with_quantized(project_root, model_id, model_path, false)
}
pub fn load_with_quantized(
project_root: &Path,
model_id: &str,
model_path: &Path,
quantized: bool,
) -> Result<Self> {
let index_dir = get_index_dir_for_project(project_root, model_id)?;
let vector_dir = get_vector_index_path(&index_dir);
let index_path = vector_dir.to_str().unwrap().to_string();
let acceleration_mode = env_acceleration_mode_lossy();
let execution_provider = match acceleration_mode {
AccelerationMode::ForceGpu => ExecutionProvider::Cuda,
AccelerationMode::ForceCpu => ExecutionProvider::Cpu,
AccelerationMode::Auto => {
if cfg!(feature = "coreml") {
ExecutionProvider::CoreML
} else {
ExecutionProvider::Cpu
}
}
};
#[cfg(feature = "cuda")]
match acceleration_mode {
AccelerationMode::ForceGpu => apply_acceleration_mode(AccelerationMode::ForceGpu),
AccelerationMode::ForceCpu | AccelerationMode::Auto => {
apply_acceleration_mode(AccelerationMode::ForceCpu)
}
}
crate::onnx_runtime::ensure_onnx_runtime().context("Failed to initialize ONNX Runtime")?;
#[cfg(feature = "cuda")]
if matches!(acceleration_mode, AccelerationMode::ForceGpu) {
if !crate::onnx_runtime::is_cudnn_available() {
anyhow::bail!("FORCE_GPU is set, but cuDNN was not initialized");
}
if !next_plaid_onnx::is_cuda_available() {
anyhow::bail!(
"FORCE_GPU is set, but the CUDA execution provider was not initialized"
);
}
}
let num_threads = std::thread::available_parallelism()
.map(|p| p.get())
.unwrap_or(8)
.min(crate::config::MAX_INTRA_OP_THREADS);
let model = crate::stderr::with_suppressed_stderr(|| {
Colbert::builder(model_path)
.with_quantized(quantized)
.with_threads(num_threads)
.with_execution_provider(execution_provider)
.build()
})
.context("Failed to load ColBERT model")?;
let index = MmapIndex::load(&index_path).context("Failed to load index")?;
Ok(Self {
model,
index,
index_path,
})
}
/// Loads a searcher from an explicit index directory with quantization disabled.
///
/// Shorthand for `load_from_index_dir_with_quantized(index_dir, model_path, false)`.
pub fn load_from_index_dir(index_dir: &Path, model_path: &Path) -> Result<Self> {
Self::load_from_index_dir_with_quantized(index_dir, model_path, false)
}
/// Loads the model and the vector index stored under `index_dir`.
///
/// The execution provider follows the acceleration mode read from the
/// environment: forced GPU selects CUDA, forced CPU selects the CPU, and the
/// automatic mode selects CoreML when the `coreml` feature is compiled in and
/// the CPU otherwise. With the `cuda` feature, a forced-GPU request fails
/// early when cuDNN or the CUDA execution provider is unavailable.
///
/// # Errors
/// Fails when the index path is not valid UTF-8, when ONNX Runtime cannot be
/// initialized, when a forced GPU is unavailable, or when the model or the
/// index cannot be loaded.
pub fn load_from_index_dir_with_quantized(
    index_dir: &Path,
    model_path: &Path,
    quantized: bool,
) -> Result<Self> {
    let vector_dir = get_vector_index_path(index_dir);
    // Report a non-UTF-8 path as an error instead of panicking.
    let index_path = vector_dir
        .to_str()
        .context("Index path is not valid UTF-8")?
        .to_string();
    let acceleration_mode = env_acceleration_mode_lossy();
    let execution_provider = match acceleration_mode {
        AccelerationMode::ForceGpu => ExecutionProvider::Cuda,
        AccelerationMode::ForceCpu => ExecutionProvider::Cpu,
        AccelerationMode::Auto => {
            if cfg!(feature = "coreml") {
                ExecutionProvider::CoreML
            } else {
                ExecutionProvider::Cpu
            }
        }
    };
    // In CUDA builds only an explicit GPU request enables GPU mode; Auto is treated as CPU here.
    #[cfg(feature = "cuda")]
    match acceleration_mode {
        AccelerationMode::ForceGpu => apply_acceleration_mode(AccelerationMode::ForceGpu),
        AccelerationMode::ForceCpu | AccelerationMode::Auto => {
            apply_acceleration_mode(AccelerationMode::ForceCpu)
        }
    }
    crate::onnx_runtime::ensure_onnx_runtime().context("Failed to initialize ONNX Runtime")?;
    #[cfg(feature = "cuda")]
    if matches!(acceleration_mode, AccelerationMode::ForceGpu) {
        if !crate::onnx_runtime::is_cudnn_available() {
            anyhow::bail!("FORCE_GPU is set, but cuDNN was not initialized");
        }
        if !next_plaid_onnx::is_cuda_available() {
            anyhow::bail!(
                "FORCE_GPU is set, but the CUDA execution provider was not initialized"
            );
        }
    }
    // Use the available parallelism (8 if unknown), capped at MAX_INTRA_OP_THREADS.
    let num_threads = std::thread::available_parallelism()
        .map(|p| p.get())
        .unwrap_or(8)
        .min(crate::config::MAX_INTRA_OP_THREADS);
    let model = crate::stderr::with_suppressed_stderr(|| {
        Colbert::builder(model_path)
            .with_quantized(quantized)
            .with_threads(num_threads)
            .with_execution_provider(execution_provider)
            .build()
    })
    .context("Failed to load ColBERT model")?;
    let index = MmapIndex::load(&index_path).context("Failed to load index")?;
    Ok(Self {
        model,
        index,
        index_path,
    })
}
/// Returns the ids of all documents whose `file` column starts with `prefix`.
///
/// The prefix is converted lossily to a string and matched with `file LIKE '<prefix>%'`.
/// A failing metadata query is treated as "no matches", so this currently always returns
/// `Ok`.
///
/// NOTE(review): `%` and `_` inside `prefix` are not escaped and therefore act as LIKE
/// wildcards; confirm this is acceptable for the callers.
pub fn filter_by_path_prefix(&self, prefix: &Path) -> Result<Vec<i64>> {
    let mut like_pattern = prefix.to_string_lossy().into_owned();
    like_pattern.push('%');

    let bind_values = [serde_json::json!(like_pattern)];
    let matching_ids =
        filtering::where_condition(&self.index_path, "file LIKE ?", &bind_values)
            .unwrap_or_default();
    Ok(matching_ids)
}
/// Returns the ids of all documents whose `file` path matches at least one glob in
/// `patterns`.
///
/// An empty pattern list, or a list for which `build_glob_set` yields no glob set, produces
/// an empty result. Rows without an integer `_subset_` or a string `file` are skipped, and a
/// failing metadata query is treated as "no rows".
pub fn filter_by_file_patterns(&self, patterns: &[String]) -> Result<Vec<i64>> {
    if patterns.is_empty() {
        return Ok(Vec::new());
    }
    let glob_set = match build_glob_set(patterns) {
        Some(set) => set,
        None => return Ok(Vec::new()),
    };

    // Globs are evaluated in-process, so every metadata row has to be fetched.
    let rows = filtering::get(&self.index_path, None, &[], None).unwrap_or_default();

    let mut matching_ids: Vec<i64> = Vec::new();
    for row in &rows {
        let doc_id = row.get("_subset_").and_then(|v| v.as_i64());
        let file = row.get("file").and_then(|v| v.as_str());
        if let (Some(doc_id), Some(file)) = (doc_id, file) {
            if glob_set.is_match(Path::new(file)) {
                matching_ids.push(doc_id);
            }
        }
    }
    Ok(matching_ids)
}
/// Returns the ids of all documents whose `file` path matches none of the glob `patterns`.
///
/// With no patterns, every document id is returned and a query failure is propagated as an
/// error. With patterns, each glob is converted by `glob_to_regex`, the results are joined
/// into one alternation, and a failing query is treated as "no matches".
pub fn filter_exclude_by_patterns(&self, patterns: &[String]) -> Result<Vec<i64>> {
    if patterns.is_empty() {
        let everything = filtering::where_condition(&self.index_path, "1=1", &[]);
        return everything.map_err(|e| anyhow::anyhow!("{}", e));
    }

    let exclusion_regex = patterns
        .iter()
        .map(|p| glob_to_regex(p))
        .collect::<Vec<String>>()
        .join("|");

    let kept_ids = filtering::where_condition_regexp(
        &self.index_path,
        "NOT (file REGEXP ?)",
        &[serde_json::json!(exclusion_regex)],
    )
    .unwrap_or_default();
    Ok(kept_ids)
}
/// Returns the ids of all documents that do not live under any of the directory patterns in
/// `dirs`.
///
/// With no patterns, every document id is returned and a query failure is propagated as an
/// error. With patterns, each entry is converted by `dir_pattern_to_regex`, the results are
/// joined into one alternation, and a failing query is treated as "no matches".
pub fn filter_exclude_by_dirs(&self, dirs: &[String]) -> Result<Vec<i64>> {
    if dirs.is_empty() {
        let everything = filtering::where_condition(&self.index_path, "1=1", &[]);
        return everything.map_err(|e| anyhow::anyhow!("{}", e));
    }

    let exclusion_regex = dirs
        .iter()
        .map(|d| dir_pattern_to_regex(d))
        .collect::<Vec<String>>()
        .join("|");

    let kept_ids = filtering::where_condition_regexp(
        &self.index_path,
        "NOT (file REGEXP ?)",
        &[serde_json::json!(exclusion_regex)],
    )
    .unwrap_or_default();
    Ok(kept_ids)
}
/// Returns the ids of all documents whose `file` column equals one of `files` exactly.
///
/// Builds a parameterised `file = ? OR file = ? ...` condition with one placeholder per
/// entry, so file names are never interpolated into the SQL text. An empty `files` slice
/// yields an empty result, and a failing metadata query is treated as "no matches".
pub fn filter_by_files(&self, files: &[String]) -> Result<Vec<i64>> {
    if files.is_empty() {
        return Ok(vec![]);
    }

    // One placeholder and one bound value per requested file, in the same order.
    let condition = vec!["file = ?"; files.len()].join(" OR ");
    let bind_values: Vec<serde_json::Value> =
        files.iter().map(|file| serde_json::json!(file)).collect();

    let subset = filtering::where_condition(&self.index_path, &condition, &bind_values)
        .unwrap_or_default();
    Ok(subset)
}
pub fn filter_by_text_pattern_with_options(
&self,
pattern: &str,
extended_regexp: bool,
fixed_strings: bool,
word_regexp: bool,
case_sensitive: bool,
) -> Result<Vec<i64>> {
if pattern.is_empty() {
return Ok(vec![]);
}
if fixed_strings {
let escaped = regex::escape(pattern);
let regex_pattern = if word_regexp {
format!(r"\b{}\b", escaped)
} else {
escaped
};
return filtering::where_condition_regexp(
&self.index_path,
"code REGEXP ?",
&[serde_json::json!(regex_pattern)],
)
.map_err(|e| anyhow::anyhow!("{}", e));
}
let flags = if case_sensitive { "(?m)" } else { "(?mi)" };
let regex_pattern = if word_regexp {
let ere_pattern = escape_literal_braces(&bre_to_ere(pattern));
format!(r"{}\b{}\b", flags, ere_pattern)
} else if extended_regexp {
format!("{}{}", flags, escape_literal_braces(&bre_to_ere(pattern)))
} else {
format!("{}{}", flags, regex::escape(pattern))
};
let fixed_pattern = {
let escaped = regex::escape(pattern);
if word_regexp {
format!(r"\b{}\b", escaped)
} else {
escaped
}
};
let regex_results = filtering::where_condition_regexp(
&self.index_path,
"code REGEXP ?",
&[serde_json::json!(regex_pattern)],
);
if regex_pattern == fixed_pattern {
return regex_results.map_err(|e| anyhow::anyhow!("{}", e));
}
let fixed_results = filtering::where_condition_regexp(
&self.index_path,
"code REGEXP ?",
&[serde_json::json!(fixed_pattern)],
)
.unwrap_or_default();
match regex_results {
Ok(regex_ids) => {
let mut combined: std::collections::HashSet<i64> = regex_ids.into_iter().collect();
combined.extend(fixed_results);
Ok(combined.into_iter().collect())
}
Err(_) => {
Ok(fixed_results)
}
}
}
/// Fetches the metadata rows for the given document `ids`.
///
/// An empty `ids` slice short-circuits to an empty result without touching the metadata
/// store. A failing lookup is treated as "no rows", so this currently always returns `Ok`.
pub fn get_metadata_for_ids(&self, ids: &[i64]) -> Result<Vec<serde_json::Value>> {
    let rows = if ids.is_empty() {
        Vec::new()
    } else {
        filtering::get(&self.index_path, None, &[], Some(ids)).unwrap_or_default()
    };
    Ok(rows)
}
/// Encodes `query` with the loaded model and returns its embedding matrix.
///
/// The model is invoked with stderr suppressed.
///
/// # Errors
/// Returns an error when the model fails to encode the query, or when it returns no
/// embedding at all for the single query that was submitted.
pub fn encode_query(&self, query: &str) -> Result<ndarray::Array2<f32>> {
    let query_embeddings =
        crate::stderr::with_suppressed_stderr(|| self.model.encode_queries(&[query]))
            .context("Failed to encode query")?;

    // Exactly one query was submitted; an empty batch is reported as an error rather than
    // panicking inside a fallible API.
    query_embeddings
        .into_iter()
        .next()
        .context("Model returned no embedding for the query")
}
/// Runs a full-text keyword search for `query`, optionally restricted to `subset`.
///
/// The query is first passed through `sanitize_fts5_query_or`. Returns `None` when the
/// sanitized query is empty or when the underlying search reports an error.
pub fn fts5_search(
    &self,
    query: &str,
    top_k: usize,
    subset: Option<&[i64]>,
) -> Option<next_plaid::QueryResult> {
    let sanitized = next_plaid::text_search::sanitize_fts5_query_or(query);
    if sanitized.is_empty() {
        return None;
    }

    match subset {
        Some(allowed_ids) => next_plaid::text_search::search_filtered(
            &self.index_path,
            &sanitized,
            top_k,
            allowed_ids,
        )
        .ok(),
        None => next_plaid::text_search::search(&self.index_path, &sanitized, top_k).ok(),
    }
}
/// Encodes `query` and runs a purely semantic search for the best `top_k` documents,
/// optionally restricted to the ids in `subset`.
///
/// # Errors
/// Propagates failures from `encode_query` and `search_with_embedding`.
pub fn search(
    &self,
    query: &str,
    top_k: usize,
    subset: Option<&[i64]>,
) -> Result<Vec<SearchResult>> {
    self.encode_query(query)
        .and_then(|embedding| self.search_with_embedding(&embedding, top_k, subset))
}
pub fn search_with_embedding(
&self,
query_emb: &ndarray::Array2<f32>,
top_k: usize,
subset: Option<&[i64]>,
) -> Result<Vec<SearchResult>> {
let params = search_params_from_env(top_k);
let results = self
.index
.search(query_emb, ¶ms, subset)
.context("Search failed")?;
let doc_ids: Vec<i64> = results.passage_ids.to_vec();
let metadata = filtering::get(&self.index_path, None, &[], Some(&doc_ids))
.context("Failed to retrieve metadata")?;
let search_results: Vec<SearchResult> = metadata
.into_iter()
.zip(results.scores.iter())
.filter_map(|(mut meta, &score)| {
fix_sqlite_types(&mut meta);
serde_json::from_value::<CodeUnit>(meta)
.ok()
.map(|unit| SearchResult { unit, score })
})
.collect();
Ok(search_results)
}
/// Encodes `query` and runs the hybrid (semantic + keyword) search.
///
/// No precomputed keyword results are supplied, so `search_hybrid_with_embedding` performs
/// the keyword lookup itself.
///
/// # Errors
/// Propagates failures from `encode_query` and `search_hybrid_with_embedding`.
pub fn search_hybrid(
    &self,
    query: &str,
    top_k: usize,
    subset: Option<&[i64]>,
    alpha: f32,
) -> Result<Vec<SearchResult>> {
    self.encode_query(query).and_then(|embedding| {
        self.search_hybrid_with_embedding(&embedding, query, top_k, subset, alpha, None)
    })
}
pub fn search_hybrid_with_embedding(
&self,
query_emb: &ndarray::Array2<f32>,
query: &str,
top_k: usize,
subset: Option<&[i64]>,
alpha: f32,
fts5_results: Option<&next_plaid::QueryResult>,
) -> Result<Vec<SearchResult>> {
let fetch_k = std::cmp::min(
std::cmp::max(top_k * 20, 200),
self.index.num_documents().max(top_k),
);
let params = search_params_from_env(fetch_k);
let semantic = self
.index
.search(query_emb, ¶ms, subset)
.context("Semantic search failed")?;
trace_log(
query,
"semantic",
&semantic.passage_ids,
&semantic.scores,
20,
);
let owned_fts5;
let keyword = match fts5_results {
Some(fts5) => {
if let Some(sub) = subset {
let sub_set: HashSet<i64> = sub.iter().copied().collect();
let mut filtered_ids = Vec::new();
let mut filtered_scores = Vec::new();
for (id, score) in fts5.passage_ids.iter().zip(fts5.scores.iter()) {
if sub_set.contains(id) {
filtered_ids.push(*id);
filtered_scores.push(*score);
}
}
owned_fts5 = next_plaid::QueryResult {
query_id: 0,
passage_ids: filtered_ids,
scores: filtered_scores,
};
Some(&owned_fts5)
} else {
Some(fts5)
}
}
None => {
owned_fts5 =
self.fts5_search(query, fetch_k, subset)
.unwrap_or(next_plaid::QueryResult {
query_id: 0,
passage_ids: vec![],
scores: vec![],
});
if owned_fts5.passage_ids.is_empty() {
None
} else {
Some(&owned_fts5)
}
}
};
if let Some(kw) = keyword {
trace_log(query, "bm25", &kw.passage_ids, &kw.scores, 20);
}
let (fused_ids, fused_scores) = if let Some(kw) = keyword {
if kw.passage_ids.is_empty() {
(semantic.passage_ids, semantic.scores)
} else {
next_plaid::text_search::fuse_relative_score(
&semantic.passage_ids,
&semantic.scores,
&kw.passage_ids,
&kw.scores,
alpha,
fetch_k,
)
}
} else {
(semantic.passage_ids, semantic.scores)
};
trace_log(query, "fused", &fused_ids, &fused_scores, 20);
let metadata = filtering::get(&self.index_path, None, &[], Some(&fused_ids))
.context("Failed to retrieve metadata")?;
let apply_penalty = crate::ranking::should_apply_path_penalty(query);
let mut meta_by_id: std::collections::HashMap<i64, serde_json::Value> =
std::collections::HashMap::with_capacity(metadata.len());
for mut m in metadata {
if let Some(id) = m.get("_subset_").and_then(|v| v.as_i64()) {
fix_sqlite_types(&mut m);
meta_by_id.insert(id, m);
}
}
let mut search_results: Vec<SearchResult> = fused_ids
.iter()
.zip(fused_scores.iter())
.filter_map(|(&id, &score)| {
let meta = meta_by_id.remove(&id)?;
serde_json::from_value::<CodeUnit>(meta).ok().map(|unit| {
let mut final_score = score;
if apply_penalty {
let file_str = unit.file.to_string_lossy();
final_score *= crate::ranking::file_path_penalty(&file_str);
}
SearchResult {
unit,
score: final_score,
}
})
})
.collect();
trace_log_results(query, "after_path_penalty", &search_results, 30);
crate::ranking::apply_path_stem_boost(
&mut search_results,
query,
|r| r.unit.file.to_str().unwrap_or(""),
|r| r.score,
|r, s| r.score = s,
);
trace_log_results(query, "after_path_stem_boost", &search_results, 30);
crate::ranking::apply_definition_boost(
&mut search_results,
query,
|r| r.unit.name.as_str(),
|r| {
matches!(
r.unit.unit_type,
crate::UnitType::Function
| crate::UnitType::Method
| crate::UnitType::Class
| crate::UnitType::Constant
)
},
|r| r.score,
|r, s| r.score = s,
);
trace_log_results(query, "after_definition_boost", &search_results, 30);
crate::ranking::apply_file_coherence_boost(
&mut search_results,
|r| r.unit.file.to_str().unwrap_or(""),
|r| r.score,
|r, s| r.score = s,
);
trace_log_results(query, "after_coherence_boost", &search_results, 30);
search_results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
search_results = collapse_by_file(search_results, top_k);
trace_log_results(query, "final", &search_results, 30);
Ok(search_results)
}
/// Returns the document count reported by the loaded index.
pub fn num_documents(&self) -> usize {
self.index.num_documents()
}
}
/// Emits one trace line with the leading ids and scores of a ranking stage.
///
/// Does nothing unless `trace_enabled()` is true. At most `limit` entries are written, and
/// never more than the shorter of `ids` and `scores`. The line is printed to stderr as
/// `__COLGREP_TRACE__ {"stage":...,"query":...,"ids":[...],"scores":[...]}` with scores
/// formatted to six decimal places.
fn trace_log(query: &str, stage: &str, ids: &[i64], scores: &[f32], limit: usize) {
    if !trace_enabled() {
        return;
    }

    let shown = limit.min(ids.len()).min(scores.len());
    let id_list = ids[..shown]
        .iter()
        .map(|id| id.to_string())
        .collect::<Vec<_>>()
        .join(",");
    let score_list = scores[..shown]
        .iter()
        .map(|score| format!("{:.6}", score))
        .collect::<Vec<_>>()
        .join(",");

    let mut line = String::with_capacity(64 + shown * 32);
    line.push_str("{\"stage\":\"");
    line.push_str(stage);
    line.push_str("\",\"query\":");
    json_escape(&mut line, query);
    line.push_str(",\"ids\":[");
    line.push_str(&id_list);
    line.push_str("],\"scores\":[");
    line.push_str(&score_list);
    line.push_str("]}");
    eprintln!("__COLGREP_TRACE__ {}", line);
}
/// Emits one trace line with the file and score of the leading results of a ranking stage.
///
/// Does nothing unless `trace_enabled()` is true. At most `limit` results are written. The
/// line is printed to stderr as
/// `__COLGREP_TRACE__ {"stage":...,"query":...,"results":[{"file":...,"score":...},...]}`
/// with scores formatted to six decimal places.
fn trace_log_results(query: &str, stage: &str, results: &[SearchResult], limit: usize) {
    if !trace_enabled() {
        return;
    }

    let shown = results.len().min(limit);
    let mut line = String::with_capacity(64 + shown * 64);
    line.push_str("{\"stage\":\"");
    line.push_str(stage);
    line.push_str("\",\"query\":");
    json_escape(&mut line, query);
    line.push_str(",\"results\":[");
    for (position, hit) in results[..shown].iter().enumerate() {
        if position != 0 {
            line.push(',');
        }
        line.push_str("{\"file\":");
        json_escape(&mut line, &hit.unit.file.to_string_lossy());
        line.push_str(",\"score\":");
        line.push_str(&format!("{:.6}", hit.score));
        line.push('}');
    }
    line.push_str("]}");
    eprintln!("__COLGREP_TRACE__ {}", line);
}
/// Reports whether ranking traces are switched on.
///
/// Tracing is enabled only when the `COLGREP_TRACE` environment variable is exactly `1`,
/// `true` or `TRUE`; an unset, unreadable or differently spelled value disables it.
fn trace_enabled() -> bool {
    match std::env::var("COLGREP_TRACE") {
        Ok(value) => value == "1" || value == "true" || value == "TRUE",
        Err(_) => false,
    }
}
/// Appends `s` to `out` as a double-quoted JSON string literal.
///
/// Quotes and backslashes are backslash-escaped; newline, carriage return and tab use their
/// short escapes; any other character below U+0020 is written as `\u00XX`. Every other
/// character is copied unchanged.
fn json_escape(out: &mut String, s: &str) {
    use std::fmt::Write as _;

    out.push('"');
    for ch in s.chars() {
        let short_escape = match ch {
            '"' => Some("\\\""),
            '\\' => Some("\\\\"),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            _ => None,
        };
        if let Some(escape) = short_escape {
            out.push_str(escape);
        } else if u32::from(ch) < 0x20 {
            // Formatting into a `String` cannot fail, so the result is safe to discard.
            let _ = write!(out, "\\u{:04x}", u32::from(ch));
        } else {
            out.push(ch);
        }
    }
    out.push('"');
}
fn collapse_by_file(results: Vec<SearchResult>, top_k: usize) -> Vec<SearchResult> {
let mut by_file: std::collections::HashMap<std::path::PathBuf, usize> =
std::collections::HashMap::new();
let mut out: Vec<SearchResult> = Vec::with_capacity(top_k.min(results.len()));
for r in results {
if let Some(&idx) = by_file.get(&r.unit.file) {
let leader = &mut out[idx];
leader.unit.line = leader.unit.line.min(r.unit.line);
leader.unit.end_line = leader.unit.end_line.max(r.unit.end_line);
} else {
if out.len() >= top_k {
continue;
}
by_file.insert(r.unit.file.clone(), out.len());
out.push(r);
}
}
out
}
/// Normalises the value types of a metadata row in place.
///
/// Only JSON objects are touched:
/// - fields whose name starts with `has_` or `is_` and hold an integer become booleans
///   (`0` is `false`, anything else is `true`);
/// - any other string field that starts with `[` and parses as a JSON array is replaced by
///   that array.
///
/// All other values are left as they are.
fn fix_sqlite_types(meta: &mut serde_json::Value) {
    let Some(fields) = meta.as_object_mut() else {
        return;
    };

    for (name, value) in fields.iter_mut() {
        if name.starts_with("has_") || name.starts_with("is_") {
            if let Some(number) = value.as_i64() {
                *value = serde_json::Value::Bool(number != 0);
            }
            // Flag-named fields are never considered for array decoding.
            continue;
        }

        let decoded = match &*value {
            serde_json::Value::String(text) if text.starts_with('[') => {
                serde_json::from_str::<serde_json::Value>(text).ok()
            }
            _ => None,
        };
        if let Some(array) = decoded.filter(|candidate| candidate.is_array()) {
            *value = array;
        }
    }
}
/// Reports whether an index exists for `project_root` and `model`.
///
/// Thin public wrapper around `paths::index_exists`, which performs the actual check.
pub fn index_exists(project_root: &Path, model: &str) -> bool {
paths::index_exists(project_root, model)
}
/// Asks the user on stderr whether a large indexing job should proceed.
///
/// Returns `true` without prompting when stdin is not a terminal. Otherwise one line is read
/// from stdin: an empty answer, `y` or `yes` (case-insensitive, surrounding whitespace
/// ignored) means proceed; any other answer, or a read error, means do not proceed.
fn prompt_large_index_confirmation(num_units: usize) -> bool {
    use std::io::{self, BufRead, Write};

    if !atty::is(atty::Stream::Stdin) {
        return true;
    }

    eprintln!(
        "\n⚠️ Large codebase detected: {} code units to index",
        num_units
    );
    eprintln!(" This may take a while. Use -y/--yes to skip this prompt in the future.\n");
    eprint!(" Proceed with indexing? [Y/n] ");
    // The prompt has no trailing newline; flush so it is visible before blocking on input.
    io::stderr().flush().ok();

    let mut answer = String::new();
    match io::stdin().lock().read_line(&mut answer) {
        Err(_) => false,
        Ok(_) => matches!(answer.trim().to_lowercase().as_str(), "" | "y" | "yes"),
    }
}
// Unit tests for the glob, path-safety, regex-translation and ignore-rule helpers of this
// module. The helpers under test are defined elsewhere in the file and reached via `super::*`.
#[cfg(test)]
mod tests {
use super::*;
// --- Glob matching (`matches_glob_pattern`) ---

// A bare extension glob matches files at any directory depth.
#[test]
fn test_glob_simple_extension() {
let patterns = vec!["*.rs".to_string()];
assert!(matches_glob_pattern(Path::new("src/main.rs"), &patterns));
assert!(matches_glob_pattern(
Path::new("nested/deep/file.rs"),
&patterns
));
assert!(!matches_glob_pattern(Path::new("src/main.py"), &patterns));
}
// `**/` matches across any number of directories.
#[test]
fn test_glob_recursive_double_star() {
let patterns = vec!["**/*.rs".to_string()];
assert!(matches_glob_pattern(Path::new("src/main.rs"), &patterns));
assert!(matches_glob_pattern(Path::new("a/b/c/d.rs"), &patterns));
assert!(!matches_glob_pattern(Path::new("main.py"), &patterns));
}
// A directory-anchored glob also matches when that directory is nested under another one
// (`project/src/...`), but not a different directory.
#[test]
fn test_glob_directory_pattern() {
let patterns = vec!["src/**/*.rs".to_string()];
assert!(matches_glob_pattern(Path::new("src/main.rs"), &patterns));
assert!(matches_glob_pattern(
Path::new("src/index/mod.rs"),
&patterns
));
assert!(matches_glob_pattern(
Path::new("project/src/main.rs"),
&patterns
));
assert!(!matches_glob_pattern(Path::new("lib/main.rs"), &patterns));
}
// Dot-directories can be targeted explicitly, at the root or nested.
#[test]
fn test_glob_github_workflows() {
let patterns = vec!["**/.github/**/*".to_string()];
assert!(matches_glob_pattern(
Path::new(".github/workflows/ci.yml"),
&patterns
));
assert!(matches_glob_pattern(
Path::new("project/.github/actions/setup.yml"),
&patterns
));
assert!(!matches_glob_pattern(Path::new("src/main.rs"), &patterns));
}
// Multiple patterns are combined with OR semantics.
#[test]
fn test_glob_multiple_patterns() {
let patterns = vec!["*.rs".to_string(), "*.py".to_string()];
assert!(matches_glob_pattern(Path::new("main.rs"), &patterns));
assert!(matches_glob_pattern(Path::new("main.py"), &patterns));
assert!(!matches_glob_pattern(Path::new("main.js"), &patterns));
}
// Suffix globs such as `*_test.go` match on the file name inside a directory.
#[test]
fn test_glob_test_files() {
let patterns = vec!["*_test.go".to_string()];
assert!(matches_glob_pattern(
Path::new("pkg/main_test.go"),
&patterns
));
assert!(!matches_glob_pattern(Path::new("pkg/main.go"), &patterns));
}
// An empty pattern list matches everything.
#[test]
fn test_glob_empty_patterns() {
let patterns: Vec<String> = vec![];
assert!(matches_glob_pattern(Path::new("any/file.rs"), &patterns));
}
// --- Brace expansion (`expand_braces`) ---

// A single `{a,b}` group expands into one pattern per alternative, in order.
#[test]
fn test_expand_braces_simple() {
let expanded = expand_braces("*.{rs,md}");
assert_eq!(expanded, vec!["*.rs", "*.md"]);
}
// A pattern without braces is returned unchanged as a single entry.
#[test]
fn test_expand_braces_no_braces() {
let expanded = expand_braces("*.rs");
assert_eq!(expanded, vec!["*.rs"]);
}
// The text around the brace group (here a directory prefix) is kept for every alternative.
#[test]
fn test_expand_braces_with_path() {
let expanded = expand_braces("src/**/*.{ts,tsx,js,jsx}");
assert_eq!(
expanded,
vec!["src/**/*.ts", "src/**/*.tsx", "src/**/*.js", "src/**/*.jsx"]
);
}
// A brace group may appear at the start of the pattern.
#[test]
fn test_expand_braces_prefix() {
let expanded = expand_braces("{src,lib}/**/*.rs");
assert_eq!(expanded, vec!["src/**/*.rs", "lib/**/*.rs"]);
}
// Several groups expand to their cartesian product, leftmost group varying slowest.
#[test]
fn test_expand_braces_multiple_groups() {
let expanded = expand_braces("{src,lib}/*.{rs,md}");
assert_eq!(
expanded,
vec!["src/*.rs", "src/*.md", "lib/*.rs", "lib/*.md"]
);
}
// Glob matching accepts brace groups in its patterns.
#[test]
fn test_glob_brace_expansion() {
let patterns = vec!["*.{rs,py}".to_string()];
assert!(matches_glob_pattern(Path::new("main.rs"), &patterns));
assert!(matches_glob_pattern(Path::new("main.py"), &patterns));
assert!(!matches_glob_pattern(Path::new("main.js"), &patterns));
}
// Brace groups combine with directory-anchored recursive globs.
#[test]
fn test_glob_brace_expansion_with_directory() {
let patterns = vec!["src/**/*.{ts,tsx}".to_string()];
assert!(matches_glob_pattern(Path::new("src/app.ts"), &patterns));
assert!(matches_glob_pattern(
Path::new("src/components/Button.tsx"),
&patterns
));
assert!(!matches_glob_pattern(Path::new("src/main.js"), &patterns));
}
// --- Path containment (`is_within_project_root`) ---
// These tests create directories under the system temp dir; creation errors are ignored.

// Plain relative paths are inside the project root.
#[test]
fn test_is_within_project_root_simple_path() {
let temp_dir = std::env::temp_dir().join("plaid_test_project");
let _ = std::fs::create_dir_all(&temp_dir);
assert!(is_within_project_root(&temp_dir, Path::new("src/main.rs")));
assert!(is_within_project_root(&temp_dir, Path::new("file.txt")));
}
// Paths that climb out of the root with `..` are rejected.
#[test]
fn test_is_within_project_root_path_traversal() {
let temp_dir = std::env::temp_dir().join("plaid_test_project");
let _ = std::fs::create_dir_all(&temp_dir);
assert!(!is_within_project_root(
&temp_dir,
Path::new("../../../etc/passwd")
));
assert!(!is_within_project_root(&temp_dir, Path::new("../sibling")));
assert!(!is_within_project_root(
&temp_dir,
Path::new("foo/../../..")
));
}
// Traversal hidden behind a leading in-root component is rejected as well.
#[test]
fn test_is_within_project_root_hidden_traversal() {
let temp_dir = std::env::temp_dir().join("plaid_test_project");
let _ = std::fs::create_dir_all(&temp_dir);
assert!(!is_within_project_root(
&temp_dir,
Path::new("src/../../../etc/passwd")
));
assert!(!is_within_project_root(
&temp_dir,
Path::new("./foo/../../../bar")
));
}
// A `..` that stays inside the root is allowed. The files it refers to are created first
// and the temporary directory is removed afterwards.
#[test]
fn test_is_within_project_root_valid_dotdot_in_middle() {
let temp_dir = std::env::temp_dir().join("plaid_test_project_dotdot");
let sub_dir = temp_dir.join("src").join("subdir");
let _ = std::fs::create_dir_all(&sub_dir);
let test_file = temp_dir.join("src").join("main.rs");
let _ = std::fs::write(&test_file, "fn main() {}");
assert!(is_within_project_root(
&temp_dir,
Path::new("src/subdir/../main.rs")
));
let _ = std::fs::remove_dir_all(&temp_dir);
}
// --- BRE to ERE translation (`bre_to_ere`) ---

// `\|` becomes the alternation operator `|`.
#[test]
fn test_bre_to_ere_alternation() {
assert_eq!(bre_to_ere(r"foo\|bar"), "foo|bar");
assert_eq!(bre_to_ere(r"a\|b\|c"), "a|b|c");
}
// `\+` and `\?` are left untouched, while a balanced `\{..\}` becomes `{..}`.
#[test]
fn test_bre_to_ere_quantifiers() {
assert_eq!(bre_to_ere(r"a\+"), r"a\+");
assert_eq!(bre_to_ere(r"a\?"), r"a\?");
assert_eq!(bre_to_ere(r"a\{2,3\}"), "a{2,3}");
}
// Balanced `\(..\)` becomes a group `(..)`.
#[test]
fn test_bre_to_ere_grouping() {
assert_eq!(bre_to_ere(r"\(foo\)"), "(foo)");
assert_eq!(bre_to_ere(r"\(a\|b\)"), "(a|b)");
}
// An escaped backslash is preserved and does not combine with the following character.
#[test]
fn test_bre_to_ere_escaped_backslash() {
assert_eq!(bre_to_ere(r"foo\\bar"), r"foo\\bar");
assert_eq!(bre_to_ere(r"\\|"), r"\\|"); }
// Patterns without BRE escapes, and other escapes such as `\.`, pass through unchanged.
#[test]
fn test_bre_to_ere_no_change() {
assert_eq!(bre_to_ere("foo|bar"), "foo|bar");
assert_eq!(bre_to_ere("a+b?"), "a+b?");
assert_eq!(bre_to_ere(r"foo\.bar"), r"foo\.bar"); }
// Alternation is translated inside a longer pattern.
#[test]
fn test_bre_to_ere_mixed() {
assert_eq!(
bre_to_ere(r"default.*25\|top_k.*25"),
"default.*25|top_k.*25"
);
}
// A trailing lone backslash is kept as is.
#[test]
fn test_bre_to_ere_trailing_backslash() {
assert_eq!(bre_to_ere(r"foo\"), r"foo\");
}
// Unbalanced `\(` / `\)` stay escaped; only balanced pairs are converted.
#[test]
fn test_bre_to_ere_unbalanced_parens() {
assert_eq!(bre_to_ere(r"error\(4"), r"error\(4");
assert_eq!(bre_to_ere(r"foo\)"), r"foo\)");
assert_eq!(bre_to_ere(r"a\(b\)c\(d"), "a(b)c\\(d");
}
// Leading `\+` / `\?` are left unchanged.
#[test]
fn test_bre_to_ere_leading_quantifiers() {
assert_eq!(bre_to_ere(r"\+foo"), r"\+foo");
assert_eq!(bre_to_ere(r"\?foo"), r"\?foo");
}
// An unbalanced `\{` and a leading `\{..\}` are left unchanged.
#[test]
fn test_bre_to_ere_unbalanced_braces() {
assert_eq!(bre_to_ere(r"a\{2"), r"a\{2");
assert_eq!(bre_to_ere(r"\{2\}"), r"\{2\}");
}
// --- Literal brace escaping (`escape_literal_braces`) ---

// Well-formed repetition quantifiers are left alone.
#[test]
fn test_escape_literal_braces_quantifiers_unchanged() {
assert_eq!(escape_literal_braces("a{2}"), "a{2}");
assert_eq!(escape_literal_braces("a{2,}"), "a{2,}");
assert_eq!(escape_literal_braces("a{2,4}"), "a{2,4}");
assert_eq!(escape_literal_braces("a{,4}"), "a{,4}");
assert_eq!(escape_literal_braces("Error[0-9]{2,4}"), "Error[0-9]{2,4}");
}
// Braces that are not quantifiers are wrapped in a character class.
#[test]
fn test_escape_literal_braces_literals_escaped() {
assert_eq!(escape_literal_braces("enum.*{"), "enum.*[{]");
assert_eq!(escape_literal_braces("struct {"), "struct [{]");
assert_eq!(escape_literal_braces("}"), "[}]");
assert_eq!(escape_literal_braces("{}"), "[{][}]");
}
// Literal braces and quantifiers can appear in the same pattern.
#[test]
fn test_escape_literal_braces_mixed() {
assert_eq!(
escape_literal_braces("enum.*Error.*{[^}]*Error[0-9]{2,4}[^}]*}"),
"enum.*Error.*[{][^}]*Error[0-9]{2,4}[^}]*[}]"
);
}
// Braces already inside a character class are not touched.
#[test]
fn test_escape_literal_braces_character_class_unchanged() {
assert_eq!(escape_literal_braces("[{]"), "[{]");
assert_eq!(escape_literal_braces("[}]"), "[}]");
assert_eq!(escape_literal_braces("[{}]"), "[{}]");
assert_eq!(escape_literal_braces("a[{]b"), "a[{]b");
}
// A realistic pattern mixing escapes, classes, literal braces and a quantifier.
#[test]
fn test_escape_literal_braces_complex_pattern() {
let pattern = r"enum\s+[A-Za-z0-9_]+Error\s*{[^}]*Error[0-9]{2,4}[^}]*}";
let escaped = escape_literal_braces(pattern);
assert_eq!(
escaped,
r"enum\s+[A-Za-z0-9_]+Error\s*[{][^}]*Error[0-9]{2,4}[^}]*[}]"
);
}
// Checks that a `HashSet` union of two id lists contains no duplicates for overlapping,
// identical and disjoint inputs. Note that this works on local vectors only; it does not
// call any function of this module.
#[test]
fn test_combine_search_results_no_duplicates() {
let regex_ids: Vec<i64> = vec![1, 2, 3, 4, 5];
let fixed_ids: Vec<i64> = vec![3, 4, 5, 6, 7];
let mut combined: std::collections::HashSet<i64> = regex_ids.into_iter().collect();
combined.extend(fixed_ids);
let result: Vec<i64> = combined.into_iter().collect();
let mut sorted = result.clone();
sorted.sort();
assert!(
sorted.windows(2).all(|w| w[0] != w[1]),
"Combined results contain duplicates"
);
assert_eq!(sorted.len(), 7);
let regex_ids: Vec<i64> = vec![10, 20, 30];
let fixed_ids: Vec<i64> = vec![10, 20, 30];
let mut combined: std::collections::HashSet<i64> = regex_ids.into_iter().collect();
combined.extend(fixed_ids);
let result: Vec<i64> = combined.into_iter().collect();
let mut sorted = result.clone();
sorted.sort();
assert!(
sorted.windows(2).all(|w| w[0] != w[1]),
"Identical results produced duplicates"
);
assert_eq!(sorted.len(), 3);
let regex_ids: Vec<i64> = vec![1, 2, 3];
let fixed_ids: Vec<i64> = vec![4, 5, 6];
let mut combined: std::collections::HashSet<i64> = regex_ids.into_iter().collect();
combined.extend(fixed_ids);
let result: Vec<i64> = combined.into_iter().collect();
let mut sorted = result.clone();
sorted.sort();
assert!(
sorted.windows(2).all(|w| w[0] != w[1]),
"Disjoint results produced duplicates"
);
assert_eq!(sorted.len(), 6);
}
// --- Directory patterns (`is_glob_pattern`, `dir_pattern_to_regex`) ---

// `*`, `?` and `[` mark a pattern as a glob; plain names and paths do not.
#[test]
fn test_is_glob_pattern() {
assert!(is_glob_pattern("*.rs"));
assert!(is_glob_pattern("**/test"));
assert!(is_glob_pattern("foo?bar"));
assert!(is_glob_pattern("[abc]"));
assert!(!is_glob_pattern("vendor"));
assert!(!is_glob_pattern("node_modules"));
assert!(!is_glob_pattern(".claude/plugins"));
}
// Literal directory names become `(^|/)name/` with regex metacharacters escaped.
#[test]
fn test_dir_pattern_to_regex_literal() {
assert_eq!(dir_pattern_to_regex("vendor"), "(^|/)vendor/");
assert_eq!(dir_pattern_to_regex("node_modules"), "(^|/)node_modules/");
assert_eq!(
dir_pattern_to_regex(".claude/plugins"),
"(^|/)\\.claude/plugins/"
);
}
// Exact regex text produced for glob-style directory patterns.
#[test]
fn test_dir_pattern_to_regex_glob() {
let pattern = dir_pattern_to_regex("*/plugins");
assert_eq!(pattern, "^[^/]*/plugins/");
let pattern = dir_pattern_to_regex("**/test_*");
assert_eq!(pattern, "(^|.*/)test_[^/]*/");
let pattern = dir_pattern_to_regex(".claude/*");
assert_eq!(pattern, "(^|/)\\.claude/[^/]*/");
}
// Compiles the generated regexes and checks them against concrete paths.
#[test]
fn test_dir_pattern_to_regex_matching() {
use regex::Regex;
let pattern = dir_pattern_to_regex("vendor");
let re = Regex::new(&pattern).unwrap();
assert!(re.is_match("vendor/package.json"));
assert!(re.is_match("src/vendor/lib.rs"));
assert!(!re.is_match("vendorfile.txt"));
// A leading `*` stands for exactly one path component.
let pattern = dir_pattern_to_regex("*/plugins");
let re = Regex::new(&pattern).unwrap();
assert!(re.is_match(".claude/plugins/tool.json"));
assert!(re.is_match("foo/plugins/bar.txt"));
assert!(!re.is_match("plugins/direct.txt")); assert!(!re.is_match("a/b/plugins/deep.txt"));
// A leading `**/` matches at any depth, including the root.
let pattern = dir_pattern_to_regex("**/test_*");
let re = Regex::new(&pattern).unwrap();
assert!(re.is_match("test_utils/helper.rs"));
assert!(re.is_match("src/test_integration/spec.rs"));
assert!(re.is_match("a/b/c/test_foo/file.rs"));
assert!(!re.is_match("src/testing/file.rs"));
// A trailing `*` requires a sub-directory; files directly inside do not match.
let pattern = dir_pattern_to_regex(".claude/*");
let re = Regex::new(&pattern).unwrap();
assert!(re.is_match(".claude/plugins/file.json"));
assert!(re.is_match("foo/.claude/bar/test.txt"));
assert!(!re.is_match(".claude/file.json")); }
// --- Ignore rules (`should_ignore(path, extra_ignore, force_include)`) ---

// Hidden directories are ignored in relative paths, while `.github` is kept.
#[test]
fn test_should_ignore_relative_hidden_subdir() {
let empty: &[String] = &[];
assert!(should_ignore(Path::new(".hidden/foo.rs"), empty, empty));
assert!(should_ignore(Path::new("src/.secret/bar.rs"), empty, empty));
assert!(!should_ignore(
Path::new(".github/workflows/ci.yml"),
empty,
empty
));
}
// Ordinary relative paths are not ignored.
#[test]
fn test_should_ignore_does_not_reject_dotprefixed_root_when_relative() {
let empty: &[String] = &[];
assert!(!should_ignore(Path::new("index.ts"), empty, empty));
assert!(!should_ignore(Path::new("src/lib.rs"), empty, empty));
assert!(!should_ignore(Path::new("package.json"), empty, empty));
}
// For an absolute path, a dot-prefixed ancestor directory causes the path to be ignored.
#[test]
fn test_should_ignore_absolute_dotprefixed_ancestors() {
let empty: &[String] = &[];
let path = Path::new("/home/user/.pi/agent/extensions/index.ts");
assert!(should_ignore(path, empty, empty));
}
// Extra ignore entries work both as directory names and as file-name globs.
#[test]
fn test_should_ignore_extra_ignore_patterns() {
let empty: &[String] = &[];
let extra = vec!["generated".to_string(), "*.pb.go".to_string()];
assert!(should_ignore(
Path::new("src/generated/types.rs"),
&extra,
empty
));
assert!(should_ignore(Path::new("api/service.pb.go"), &extra, empty));
assert!(!should_ignore(Path::new("src/main.rs"), &extra, empty));
}
// Force-include re-admits a hidden directory that is ignored by default.
#[test]
fn test_should_ignore_force_include_overrides_default() {
let empty: &[String] = &[];
let force = vec![".vscode".to_string()];
assert!(should_ignore(
Path::new(".vscode/settings.json"),
empty,
empty
));
assert!(!should_ignore(
Path::new(".vscode/settings.json"),
empty,
&force
));
}
// Force-include re-admits a directory from the built-in ignore list (`vendor`).
#[test]
fn test_should_ignore_force_include_overrides_ignored_dir() {
let empty: &[String] = &[];
let force = vec!["vendor".to_string()];
assert!(should_ignore(Path::new("vendor/lib/util.go"), empty, empty));
assert!(!should_ignore(
Path::new("vendor/lib/util.go"),
empty,
&force
));
}
// When the same entry is both extra-ignored and force-included, force-include wins.
#[test]
fn test_should_ignore_force_include_overrides_extra_ignore() {
let extra = vec!["generated".to_string()];
let force = vec!["generated".to_string()];
assert!(!should_ignore(
Path::new("src/generated/types.rs"),
&extra,
&force
));
}
// A force-include path prefix re-admits only that subtree, not its siblings.
#[test]
fn test_should_ignore_force_include_path_prefix() {
let empty: &[String] = &[];
let force = vec!["vendor/internal".to_string()];
assert!(!should_ignore(
Path::new("vendor/internal/lib.go"),
empty,
&force
));
assert!(should_ignore(
Path::new("vendor/external/lib.go"),
empty,
&force
));
}
// Force-include accepts suffix globs such as `*.egg-info`.
#[test]
fn test_should_ignore_force_include_suffix_pattern() {
let empty: &[String] = &[];
let force = vec!["*.egg-info".to_string()];
assert!(should_ignore(
Path::new("mypackage.egg-info/PKG-INFO"),
empty,
empty
));
assert!(!should_ignore(
Path::new("mypackage.egg-info/PKG-INFO"),
empty,
&force
));
}
// Extra-ignore and force-include apply independently; built-in ignores still hold.
#[test]
fn test_should_ignore_combined_extra_and_force() {
let extra = vec!["snapshots".to_string()];
let force = vec![".idea".to_string()];
assert!(should_ignore(
Path::new("tests/snapshots/test1.snap"),
&extra,
&force
));
assert!(!should_ignore(
Path::new(".idea/workspace.xml"),
&extra,
&force
));
assert!(should_ignore(
Path::new("node_modules/foo/bar.js"),
&extra,
&force
));
}
}