mod embedding;
mod parsing;
mod types;
mod upsert;
mod windowing;
pub(crate) use types::embed_batch_size;
pub(crate) use types::PipelineStats;
pub(crate) use windowing::apply_windowing;
use std::path::Path;
use std::path::PathBuf;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use anyhow::{Context, Result};
use crossbeam_channel::{bounded, Receiver, Sender};
use indicatif::{ProgressBar, ProgressStyle};
use cqs::embedder::ModelConfig;
use cqs::{Parser as CqParser, Store};
use embedding::{cpu_embed_stage, gpu_embed_stage};
use parsing::{parser_stage, ParserStageContext};
use types::{
embed_channel_depth, parse_channel_depth, EmbedStageContext, EmbeddedBatch, ParsedBatch,
};
use upsert::store_stage;
/// Runs the multi-threaded indexing pipeline over `files`, returning aggregate stats.
///
/// Stage topology (each worker on its own thread, connected by bounded channels):
///
/// ```text
///   parser thread --ParsedBatch--> GPU embedder --EmbeddedBatch--> store stage (this thread)
///                \-- (shared rx) -> CPU embedder --/
/// ```
///
/// Batches the GPU stage cannot embed are forwarded on a failure channel that the
/// CPU stage drains, so no parsed work is lost on GPU errors.
///
/// # Arguments
/// * `root` - project root; passed to the parser stage for path resolution
/// * `files` - the set of files to index (drives the progress-bar total)
/// * `store` - shared store written by the store stage and consulted by the parser
/// * `force` - forwarded to `parser_stage`; presumably forces re-indexing of
///   unchanged files — TODO confirm against `parser_stage`
/// * `quiet` - when true, the progress bar is hidden
/// * `model_config` - embedding model configuration, shared by both embed stages
///
/// # Errors
/// Returns an error if the parser fails to initialize, if the store stage fails,
/// or if any worker thread returns an error or panics.
pub(crate) fn run_index_pipeline(
    root: &Path,
    files: Vec<PathBuf>,
    store: Arc<Store>,
    force: bool,
    quiet: bool,
    model_config: ModelConfig,
) -> Result<PipelineStats> {
    let _span = tracing::info_span!("run_index_pipeline", file_count = files.len()).entered();
    let total_files = files.len();
    // Bounded channels provide backpressure between stages so a fast parser
    // cannot buffer unbounded work ahead of the embedders.
    let (parse_tx, parse_rx): (Sender<ParsedBatch>, Receiver<ParsedBatch>) =
        bounded(parse_channel_depth());
    let (embed_tx, embed_rx): (Sender<EmbeddedBatch>, Receiver<EmbeddedBatch>) =
        bounded(embed_channel_depth());
    // Failure channel: GPU-rejected batches are retried on the CPU stage.
    let (fail_tx, fail_rx): (Sender<ParsedBatch>, Receiver<ParsedBatch>) =
        bounded(embed_channel_depth());
    let parser = Arc::new(CqParser::new().context("Failed to initialize parser")?);
    // Cross-thread counters; folded into `PipelineStats` after all stages finish.
    let parsed_count = Arc::new(AtomicUsize::new(0));
    let embedded_count = Arc::new(AtomicUsize::new(0));
    let gpu_failures = Arc::new(AtomicUsize::new(0));
    let parse_errors = Arc::new(AtomicUsize::new(0));
    // Crossbeam receivers are multi-consumer: the CPU stage shares the parse
    // queue with the GPU stage, and gets its own clone of the embed sender.
    let parse_rx_cpu = parse_rx.clone();
    let embed_tx_cpu = embed_tx.clone();
    let parser_handle = {
        let parser = Arc::clone(&parser);
        let store = Arc::clone(&store);
        let parsed_count = Arc::clone(&parsed_count);
        let parse_errors = Arc::clone(&parse_errors);
        let root = root.to_path_buf();
        // `parse_tx` is moved into this thread; when the parser finishes and the
        // sender drops, the embed stages observe channel disconnect and wind down.
        thread::spawn(move || {
            parser_stage(
                files,
                ParserStageContext {
                    root,
                    force,
                    parser,
                    store,
                    parsed_count,
                    parse_errors,
                },
                parse_tx,
            )
        })
    };
    // Best-effort global embedding cache: failure to open degrades to "no cache"
    // rather than aborting the run.
    let global_cache: Option<Arc<cqs::cache::EmbeddingCache>> = {
        let cache_path = cqs::cache::EmbeddingCache::default_path();
        match cqs::cache::EmbeddingCache::open(&cache_path) {
            Ok(c) => {
                tracing::info!(path = %cache_path.display(), "Global embedding cache opened");
                Some(Arc::new(c))
            }
            Err(e) => {
                tracing::warn!(error = %e, "Global embedding cache unavailable");
                None
            }
        }
    };
    // GPU embed stage: owns the primary parse receiver, the primary embed
    // sender, and the failure sender used to hand rejected batches to the CPU.
    let gpu_handle = {
        let ctx = EmbedStageContext {
            store: Arc::clone(&store),
            embedded_count: Arc::clone(&embedded_count),
            model_config: model_config.clone(),
            global_cache: global_cache.clone(),
        };
        let gpu_failures = Arc::clone(&gpu_failures);
        thread::spawn(move || gpu_embed_stage(parse_rx, embed_tx, fail_tx, ctx, gpu_failures))
    };
    // CPU embed stage: drains both the shared parse queue and GPU failures.
    let cpu_handle = {
        let ctx = EmbedStageContext {
            store: Arc::clone(&store),
            embedded_count: Arc::clone(&embedded_count),
            model_config,
            global_cache: global_cache.clone(),
        };
        thread::spawn(move || cpu_embed_stage(parse_rx_cpu, fail_rx, embed_tx_cpu, ctx))
    };
    let progress = if quiet {
        ProgressBar::hidden()
    } else {
        let pb = ProgressBar::new(total_files as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("[{elapsed_precise}] {bar:40.cyan/blue} {msg}")
                // A bad template is cosmetic; fall back to the default style.
                .unwrap_or_else(|e| {
                    tracing::warn!("Progress template error: {}, using default", e);
                    ProgressStyle::default_bar()
                }),
        );
        pb
    };
    // Store stage runs on the calling thread, draining `embed_rx` until all
    // embed senders have dropped.
    // NOTE(review): if store_stage returns Err, the `?` below returns before the
    // worker threads are joined, leaving them detached — confirm acceptable.
    let (total_embedded, total_cached, total_type_edges, total_calls) =
        store_stage(embed_rx, &store, &parsed_count, &embedded_count, &progress)?;
    progress.finish_with_message("done");
    // Join each stage; the double `?` surfaces both a panic (outer) and a
    // stage-returned error (inner).
    parser_handle
        .join()
        .map_err(|e| anyhow::anyhow!("Parser thread panicked: {}", panic_message(&e)))??;
    gpu_handle
        .join()
        .map_err(|e| anyhow::anyhow!("GPU embedder thread panicked: {}", panic_message(&e)))??;
    cpu_handle
        .join()
        .map_err(|e| anyhow::anyhow!("CPU embedder thread panicked: {}", panic_message(&e)))??;
    // Cache eviction and timestamp update are best-effort: log and continue.
    if let Some(ref cache) = global_cache {
        if let Err(e) = cache.evict() {
            tracing::warn!(error = %e, "Global cache eviction failed");
        }
    }
    if let Err(e) = store.touch_updated_at() {
        tracing::warn!(error = %e, "Failed to update timestamp");
    }
    let stats = PipelineStats {
        total_embedded,
        total_cached,
        gpu_failures: gpu_failures.load(Ordering::Relaxed),
        parse_errors: parse_errors.load(Ordering::Relaxed),
        total_type_edges,
        total_calls,
    };
    tracing::info!(
        total_embedded = stats.total_embedded,
        total_cached = stats.total_cached,
        gpu_failures = stats.gpu_failures,
        parse_errors = stats.parse_errors,
        total_type_edges = stats.total_type_edges,
        total_calls = stats.total_calls,
        "Pipeline indexing complete"
    );
    Ok(stats)
}
/// Extracts a human-readable message from a thread panic payload.
///
/// `panic!("literal")` produces a `&str` payload and `panic!("{..}", ..)` a
/// `String`; any other payload type yields the `"unknown panic"` placeholder.
///
/// Takes `&(dyn Any + Send)` rather than `&Box<dyn Any + Send>` (clippy
/// `borrowed_box`); call sites passing `&boxed_payload` still work via deref
/// coercion.
fn panic_message(payload: &(dyn std::any::Any + Send)) -> String {
    if let Some(s) = payload.downcast_ref::<&str>() {
        (*s).to_string()
    } else if let Some(s) = payload.downcast_ref::<String>() {
        s.clone()
    } else {
        "unknown panic".to_string()
    }
}
#[cfg(test)]
mod tests {
    use super::embedding::create_embedded_batch;
    use super::types::RelationshipData;
    use super::windowing::*;
    use cqs::language::{ChunkType, Language};
    use cqs::{Chunk, Embedding};
    use std::path::PathBuf;

    /// Builds a minimal Rust function chunk in `test.rs` with the given id/content.
    fn make_test_chunk(id: &str, content: &str) -> Chunk {
        Chunk {
            id: id.to_string(),
            file: PathBuf::from("test.rs"),
            language: Language::Rust,
            chunk_type: ChunkType::Function,
            name: id.to_string(),
            signature: String::new(),
            content: content.to_string(),
            doc: None,
            line_start: 1,
            line_end: 10,
            content_hash: blake3::hash(content.as_bytes()).to_hex().to_string(),
            parent_id: None,
            window_idx: None,
            parent_type_name: None,
        }
    }

    /// Single-entry mtime map for `test.rs`.
    fn test_mtimes(mtime: i64) -> std::collections::HashMap<PathBuf, i64> {
        let mut m = std::collections::HashMap::new();
        m.insert(PathBuf::from("test.rs"), mtime);
        m
    }

    #[test]
    fn test_create_embedded_batch_all_cached() {
        let chunk = make_test_chunk("c1", "fn foo() {}");
        let emb = Embedding::new(vec![0.0; cqs::EMBEDDING_DIM]);
        let cached = vec![(chunk, emb)];
        let batch = create_embedded_batch(
            cached,
            vec![],
            vec![],
            RelationshipData::default(),
            test_mtimes(12345),
        );
        assert_eq!(batch.chunk_embeddings.len(), 1);
        assert_eq!(batch.cached_count, 1);
        assert_eq!(batch.file_mtimes[&PathBuf::from("test.rs")], 12345);
    }

    #[test]
    fn test_create_embedded_batch_all_new() {
        let chunk = make_test_chunk("c1", "fn foo() {}");
        let emb = Embedding::new(vec![1.0; cqs::EMBEDDING_DIM]);
        let batch = create_embedded_batch(
            vec![],
            vec![chunk],
            vec![emb],
            RelationshipData::default(),
            test_mtimes(99),
        );
        assert_eq!(batch.chunk_embeddings.len(), 1);
        assert_eq!(batch.cached_count, 0);
        assert_eq!(batch.file_mtimes[&PathBuf::from("test.rs")], 99);
    }

    #[test]
    fn test_create_embedded_batch_mixed() {
        let cached_chunk = make_test_chunk("c1", "fn foo() {}");
        let cached_emb = Embedding::new(vec![0.0; cqs::EMBEDDING_DIM]);
        let new_chunk = make_test_chunk("c2", "fn bar() {}");
        let new_emb = Embedding::new(vec![1.0; cqs::EMBEDDING_DIM]);
        let batch = create_embedded_batch(
            vec![(cached_chunk, cached_emb)],
            vec![new_chunk],
            vec![new_emb],
            RelationshipData::default(),
            test_mtimes(12345),
        );
        assert_eq!(batch.chunk_embeddings.len(), 2);
        assert_eq!(batch.cached_count, 1);
    }

    #[test]
    fn test_create_embedded_batch_empty() {
        let batch = create_embedded_batch(
            vec![],
            vec![],
            vec![],
            RelationshipData::default(),
            std::collections::HashMap::new(),
        );
        assert_eq!(batch.chunk_embeddings.len(), 0);
        assert_eq!(batch.cached_count, 0);
    }

    /// Cached embeddings must come first, followed by new ones in input order.
    #[test]
    fn test_create_embedded_batch_preserves_order() {
        let c1 = make_test_chunk("c1", "fn first() {}");
        let e1 = Embedding::new(vec![1.0; cqs::EMBEDDING_DIM]);
        let c2 = make_test_chunk("c2", "fn second() {}");
        let e2 = Embedding::new(vec![2.0; cqs::EMBEDDING_DIM]);
        let c3 = make_test_chunk("c3", "fn third() {}");
        let e3 = Embedding::new(vec![3.0; cqs::EMBEDDING_DIM]);
        let batch = create_embedded_batch(
            vec![(c1, e1)],
            vec![c2, c3],
            vec![e2, e3],
            RelationshipData::default(),
            test_mtimes(0),
        );
        assert_eq!(batch.chunk_embeddings.len(), 3);
        assert_eq!(batch.chunk_embeddings[0].0.id, "c1");
        assert_eq!(batch.chunk_embeddings[1].0.id, "c2");
        assert_eq!(batch.chunk_embeddings[2].0.id, "c3");
    }

    #[test]
    fn test_windowing_constants() {
        assert_eq!(max_tokens_per_window(512), 480);
        assert_eq!(max_tokens_per_window(8192), 8160);
        assert_eq!(max_tokens_per_window(32768), 32736);
        assert_eq!(max_tokens_per_window(0), 480);
        assert!(max_tokens_per_window(64) >= 128);
        assert_eq!(window_overlap_tokens(480), 64);
        assert_eq!(window_overlap_tokens(8160), 1020);
        assert_eq!(window_overlap_tokens(32736), 4092);
        assert_eq!(window_overlap_tokens(0), 0);
        // Overlap must stay strictly below half the window size.
        assert_eq!(window_overlap_tokens(128), 63);
        assert!(window_overlap_tokens(128) < 128 / 2);
        assert!(window_overlap_tokens(200) < 200 / 2);
    }

    // Ignored: constructs a CPU embedder, which presumably loads a real model.
    #[test]
    #[ignore]
    fn test_apply_windowing_empty() {
        use cqs::embedder::ModelConfig;
        use cqs::Embedder;
        let embedder = Embedder::new_cpu(ModelConfig::resolve(None, None)).unwrap();
        let result = apply_windowing(vec![], &embedder);
        assert!(result.is_empty());
    }

    // Ignored: constructs a CPU embedder, which presumably loads a real model.
    #[test]
    #[ignore]
    fn test_apply_windowing_short_chunk() {
        use cqs::embedder::ModelConfig;
        use cqs::Embedder;
        let embedder = Embedder::new_cpu(ModelConfig::resolve(None, None)).unwrap();
        let mut chunk = make_test_chunk("short1", "fn foo() {}");
        chunk.doc = Some("A short function".to_string());
        let result = apply_windowing(vec![chunk], &embedder);
        assert_eq!(result.len(), 1);
        let c = &result[0];
        assert_eq!(c.id, "short1");
        assert_eq!(c.name, "short1");
        assert_eq!(c.doc, Some("A short function".to_string()));
        assert_eq!(c.parent_id, None, "short chunk should not have parent_id");
        assert_eq!(c.window_idx, None, "short chunk should not have window_idx");
        assert_eq!(c.file, PathBuf::from("test.rs"));
        assert_eq!(c.language, Language::Rust);
        assert_eq!(c.chunk_type, ChunkType::Function);
        assert_eq!(c.content, "fn foo() {}");
    }

    // Ignored: constructs a CPU embedder, which presumably loads a real model.
    #[test]
    #[ignore]
    fn test_apply_windowing_long_chunk() {
        use cqs::embedder::ModelConfig;
        use cqs::Embedder;
        let embedder = Embedder::new_cpu(ModelConfig::resolve(None, None)).unwrap();
        let long_content: String = (0..500)
            .map(|i| format!(" let variable_{i} = {i};\n"))
            .collect();
        let content = format!("fn big_function() {{\n{long_content}}}");
        let mut chunk = make_test_chunk("long1", &content);
        chunk.doc = Some("A very long function".to_string());
        chunk.line_start = 10;
        chunk.line_end = 520;
        chunk.parent_type_name = Some("MyStruct".to_string());
        let original_id = chunk.id.clone();
        let result = apply_windowing(vec![chunk], &embedder);
        assert!(
            result.len() > 1,
            "Expected multiple windows, got {}",
            result.len()
        );
        for (i, window) in result.iter().enumerate() {
            let idx = i as u32;
            assert_eq!(
                window.id,
                format!("{original_id}:w{idx}"),
                "window {i} has wrong id"
            );
            assert_eq!(
                window.parent_id,
                Some(original_id.clone()),
                "window {i} missing parent_id"
            );
            assert_eq!(
                window.window_idx,
                Some(idx),
                "window {i} has wrong window_idx"
            );
            assert_eq!(window.file, PathBuf::from("test.rs"));
            assert_eq!(window.language, Language::Rust);
            assert_eq!(window.chunk_type, ChunkType::Function);
            assert_eq!(window.name, "long1");
            assert_eq!(window.line_start, 10);
            assert_eq!(window.line_end, 520);
            assert_eq!(window.parent_type_name, Some("MyStruct".to_string()));
            let expected_hash = blake3::hash(window.content.as_bytes()).to_hex().to_string();
            assert_eq!(
                window.content_hash, expected_hash,
                "window {i} hash mismatch"
            );
            assert!(!window.content.is_empty(), "window {i} has empty content");
        }
        assert_eq!(
            result[0].doc,
            Some("A very long function".to_string()),
            "first window should preserve doc"
        );
        for window in &result[1..] {
            assert_eq!(window.doc, None, "non-first window should have doc = None");
        }
    }

    /// Serializes env-var mutation across tests in this module.
    static ENV_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(());

    #[test]
    fn test_embed_batch_size() {
        use super::types::embed_batch_size;
        // Recover from poisoning: a prior assertion failure while holding the
        // lock should not mask this test's own result behind a lock panic.
        let _lock = ENV_MUTEX
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        std::env::remove_var("CQS_EMBED_BATCH_SIZE");
        assert_eq!(embed_batch_size(), 64);
        std::env::set_var("CQS_EMBED_BATCH_SIZE", "128");
        assert_eq!(embed_batch_size(), 128);
        std::env::remove_var("CQS_EMBED_BATCH_SIZE");
        // Non-numeric and zero values must fall back to the default of 64.
        std::env::set_var("CQS_EMBED_BATCH_SIZE", "not_a_number");
        assert_eq!(embed_batch_size(), 64);
        std::env::remove_var("CQS_EMBED_BATCH_SIZE");
        std::env::set_var("CQS_EMBED_BATCH_SIZE", "0");
        assert_eq!(embed_batch_size(), 64);
        std::env::remove_var("CQS_EMBED_BATCH_SIZE");
    }
}