use std::collections::HashSet;
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Duration;
use futures::StreamExt as _;
use tokio::sync::watch;
use crate::chunker::{ChunkerConfig, CodeChunk, chunk_file};
use crate::context::contextualize_for_embedding;
use crate::error::{IndexError, Result};
use crate::languages::{detect_language, is_indexable};
use crate::store::{ChunkInsert, CodeStore};
use zeph_common::BlockingSpawner;
use zeph_llm::any::AnyProvider;
use zeph_llm::provider::LlmProvider;
// Monotonic counter used to give each spawned chunking task a unique,
// human-readable name (`chunk_file_<n>`) for tracing/debugging.
static CHUNK_TASK_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Tuning knobs for project indexing.
#[derive(Debug, Clone)]
pub struct IndexerConfig {
    /// Chunking parameters forwarded to `chunk_file`.
    pub chunker: ChunkerConfig,
    /// General worker concurrency. NOTE(review): not read in this file —
    /// presumably consumed by a caller; confirm.
    pub concurrency: usize,
    /// Batch size. NOTE(review): not read in this file; confirm usage.
    pub batch_size: usize,
    /// Number of files whose chunks are materialized in memory per batch
    /// of `index_batch` (clamped to at least 1).
    pub memory_batch_size: usize,
    /// Files larger than this many bytes are skipped entirely.
    pub max_file_bytes: usize,
    /// Maximum number of files chunked/embedded concurrently within a batch
    /// (clamped to at least 1).
    pub embed_concurrency: usize,
}
impl Default for IndexerConfig {
    fn default() -> Self {
        Self {
            chunker: ChunkerConfig::default(),
            concurrency: 2,
            batch_size: 16,
            memory_batch_size: 16,
            // 512 KiB cap keeps oversized (often generated) files out of
            // the index.
            max_file_bytes: 512 * 1024,
            // Conservative default: embed one file at a time.
            embed_concurrency: 1,
        }
    }
}
/// Progress snapshot published over a `tokio::sync::watch` channel during
/// `index_project`.
#[derive(Debug, Clone, Default)]
pub struct IndexProgress {
    /// Files processed so far (both successes and failures count).
    pub files_done: usize,
    /// Total indexable files discovered by the directory walk.
    pub files_total: usize,
    /// Chunks inserted so far across all files.
    pub chunks_created: usize,
}
/// Summary returned by `index_project` describing what one run did.
#[derive(Debug, Default)]
pub struct IndexReport {
    /// Files examined (every file that reached a worker).
    pub files_scanned: usize,
    /// Files that produced at least one new chunk.
    pub files_indexed: usize,
    /// New chunks inserted into the store.
    pub chunks_created: usize,
    /// Chunks skipped because their content hash was already stored.
    pub chunks_skipped: usize,
    /// Chunks removed for files that no longer exist on disk.
    pub chunks_removed: usize,
    /// Per-file / cleanup error messages; errors do not abort the run.
    pub errors: Vec<String>,
    /// Wall-clock duration of the run in milliseconds.
    pub duration_ms: u64,
}
/// Orchestrates project indexing: walks files, chunks them, embeds the
/// chunks via the LLM provider, and upserts them into the `CodeStore`.
pub struct CodeIndexer {
    store: CodeStore,
    provider: Arc<AnyProvider>,
    config: IndexerConfig,
    // Optional custom spawner for CPU-bound chunking; when `None`,
    // `tokio::task::spawn_blocking` is used instead.
    spawner: Option<Arc<dyn BlockingSpawner>>,
    // Set while `index_project` runs; rejects concurrent invocations.
    indexing: Arc<AtomicBool>,
}
impl CodeIndexer {
    /// Creates a new indexer backed by `store`, embedding text via `provider`.
    #[must_use]
    pub fn new(store: CodeStore, provider: Arc<AnyProvider>, config: IndexerConfig) -> Self {
        Self {
            store,
            provider,
            config,
            spawner: None,
            indexing: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Routes CPU-bound chunking work through `spawner` instead of the
    /// default `tokio::task::spawn_blocking`.
    #[must_use]
    pub fn with_spawner(mut self, spawner: Arc<dyn BlockingSpawner>) -> Self {
        self.spawner = Some(spawner);
        self
    }

    /// Indexes every indexable file under `root`, optionally reporting
    /// progress through `progress_tx`.
    ///
    /// Only one run may be active per indexer: a concurrent call observes the
    /// `indexing` flag and returns an empty report immediately. After all
    /// batches complete, chunks belonging to files that no longer exist on
    /// disk are removed from the store.
    ///
    /// # Errors
    ///
    /// Returns an error if the embedding probe, collection setup, directory
    /// walk, or cleanup query fails. Per-file failures are collected into
    /// `IndexReport::errors` instead of aborting the run.
    pub async fn index_project(
        &self,
        root: &Path,
        progress_tx: Option<&watch::Sender<IndexProgress>>,
    ) -> Result<IndexReport> {
        // Acquire here pairs with the Release store in `IndexingGuard::drop`.
        if self
            .indexing
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            tracing::info!("index_project already running, skipping concurrent request");
            return Ok(IndexReport::default());
        }
        // RAII guard clears the flag on every exit path, including `?` returns.
        let _guard = IndexingGuard(Arc::clone(&self.indexing));
        let start = std::time::Instant::now();
        let mut report = IndexReport::default();
        self.ensure_collection_for_provider().await?;
        let (entries, current_files) = self.walk_project_files(root).await?;
        let total = entries.len();
        tracing::info!(total, "indexing started");
        let memory_batch_size = self.config.memory_batch_size.max(1);
        let mut files_done = 0usize;
        for batch in entries.chunks(memory_batch_size) {
            self.index_batch(
                batch,
                root,
                total,
                &mut files_done,
                &mut report,
                progress_tx,
            )
            .await;
        }
        self.cleanup_removed_files(&current_files, &mut report)
            .await?;
        report.duration_ms = start.elapsed().as_millis().try_into().unwrap_or(u64::MAX);
        Ok(report)
    }

    /// Embeds a probe string to discover the provider's vector dimension,
    /// then ensures the backing collection exists with that size.
    async fn ensure_collection_for_provider(&self) -> Result<()> {
        let probe = self.provider.embed("probe").await?;
        let vector_size = u64::try_from(probe.len())?;
        self.store.ensure_collection(vector_size).await
    }

    /// Walks `root` on a blocking thread (honoring hidden-file and gitignore
    /// rules) and returns the indexable entries together with their
    /// root-relative paths, which `cleanup_removed_files` later compares
    /// against the store.
    async fn walk_project_files(
        &self,
        root: &Path,
    ) -> Result<(Vec<ignore::DirEntry>, HashSet<String>)> {
        let root_buf = root.to_path_buf();
        tokio::task::spawn_blocking(move || {
            let entries: Vec<_> = ignore::WalkBuilder::new(&root_buf)
                .hidden(true)
                .git_ignore(true)
                .build()
                .flatten()
                .filter(|e| e.file_type().is_some_and(|ft| ft.is_file()) && is_indexable(e.path()))
                .collect();
            let mut current_files: HashSet<String> = HashSet::new();
            for entry in &entries {
                let rel_path = entry
                    .path()
                    .strip_prefix(&root_buf)
                    .unwrap_or(entry.path())
                    .to_string_lossy()
                    .to_string();
                current_files.insert(rel_path);
            }
            (entries, current_files)
        })
        .await
        .map_err(|e| IndexError::Other(format!("directory walk panicked: {e:#}")))
    }

    /// Indexes one batch of files with up to `embed_concurrency` files in
    /// flight, folding each outcome into `report` and publishing progress
    /// after every completed file.
    #[allow(clippy::too_many_arguments)]
    async fn index_batch(
        &self,
        batch: &[ignore::DirEntry],
        root: &Path,
        total: usize,
        files_done: &mut usize,
        report: &mut IndexReport,
        progress_tx: Option<&watch::Sender<IndexProgress>>,
    ) {
        let store = self.store.clone();
        let provider = Arc::clone(&self.provider);
        let config = self.config.clone();
        let spawner = self.spawner.clone();
        let concurrency = self.config.embed_concurrency.max(1);
        let file_pairs = make_file_pairs(batch, root);
        let mut stream =
            futures::stream::iter(file_pairs.into_iter().map(|(rel_path, abs_path)| {
                // Each worker task gets its own clones so the future is
                // 'static and can run concurrently.
                let store = store.clone();
                let provider = Arc::clone(&provider);
                let config = config.clone();
                let spawner = spawner.clone();
                async move {
                    let worker = FileIndexWorker {
                        store,
                        provider,
                        config,
                        spawner,
                    };
                    let result = worker.index_file(&abs_path, &rel_path).await;
                    (rel_path, result)
                }
            }))
            .buffer_unordered(concurrency);
        while let Some((rel_path, outcome)) = stream.next().await {
            report.files_scanned += 1;
            *files_done += 1;
            match outcome {
                Ok((created, skipped)) => {
                    if created > 0 {
                        report.files_indexed += 1;
                    }
                    report.chunks_created += created;
                    report.chunks_skipped += skipped;
                    tracing::info!(
                        file = %rel_path,
                        progress = format_args!("{files_done}/{total}"),
                        created,
                        skipped,
                    );
                }
                Err(e) => {
                    report.errors.push(format!("{rel_path}: {e:#}"));
                }
            }
            if let Some(tx) = progress_tx {
                // Send failures (no receivers) are intentionally ignored.
                let _ = tx.send(IndexProgress {
                    files_done: *files_done,
                    files_total: total,
                    chunks_created: report.chunks_created,
                });
            }
        }
        drop(stream);
        // Give other tasks a chance to run between memory batches.
        tokio::task::yield_now().await;
    }

    /// Removes stored chunks for files that were indexed previously but are
    /// no longer present in `current_files`; failures are recorded in the
    /// report rather than propagated.
    async fn cleanup_removed_files(
        &self,
        current_files: &HashSet<String>,
        report: &mut IndexReport,
    ) -> Result<()> {
        let indexed = self.store.indexed_files().await?;
        for old_file in &indexed {
            if !current_files.contains(old_file) {
                match self.store.remove_file_chunks(old_file).await {
                    Ok(n) => report.chunks_removed += n,
                    Err(e) => report.errors.push(format!("cleanup {old_file}: {e:#}")),
                }
            }
        }
        Ok(())
    }

    /// Drops all stored chunks for a single file and re-indexes it from
    /// scratch, returning the number of chunks created.
    ///
    /// # Errors
    ///
    /// Propagates store and `index_file` failures.
    pub async fn reindex_file(&self, root: &Path, abs_path: &Path) -> Result<usize> {
        let rel_path = abs_path
            .strip_prefix(root)
            .unwrap_or(abs_path)
            .to_string_lossy()
            .to_string();
        self.store.remove_file_chunks(&rel_path).await?;
        let worker = FileIndexWorker {
            store: self.store.clone(),
            provider: Arc::clone(&self.provider),
            config: self.config.clone(),
            spawner: self.spawner.clone(),
        };
        let (created, _) = worker.index_file(abs_path, &rel_path).await?;
        Ok(created)
    }
}
/// Per-file indexing state; one clone of this is moved into each concurrent
/// worker future spawned by `index_batch`.
struct FileIndexWorker {
    store: CodeStore,
    provider: Arc<AnyProvider>,
    config: IndexerConfig,
    // `None` means fall back to `tokio::task::spawn_blocking` for chunking.
    spawner: Option<Arc<dyn BlockingSpawner>>,
}
impl FileIndexWorker {
    /// Chunks, deduplicates, embeds, and upserts a single file.
    ///
    /// Returns `(created, skipped)` chunk counts. Oversized files short-circuit
    /// with `(0, 0)`. Upsert failures and timeouts are logged and reported as
    /// zero created chunks rather than returned as errors, so one bad store
    /// write does not fail the whole file.
    ///
    /// # Errors
    ///
    /// Fails on filesystem errors, unsupported languages, chunking failures,
    /// hash lookups, or embedding failures.
    async fn index_file(&self, abs_path: &Path, rel_path: &str) -> Result<(usize, usize)> {
        let metadata = tokio::fs::metadata(abs_path).await?;
        // `usize as u64` is lossless on supported (<=64-bit) platforms.
        if metadata.len() > self.config.max_file_bytes as u64 {
            tracing::debug!(
                file = %abs_path.display(),
                size = metadata.len(),
                "skipping oversized file"
            );
            return Ok((0, 0));
        }
        let source = tokio::fs::read_to_string(abs_path).await?;
        let lang = detect_language(abs_path).ok_or(IndexError::UnsupportedLanguage)?;
        let rel_path_owned = rel_path.to_owned();
        let chunker_config = self.config.chunker.clone();
        // Chunking is CPU-bound, so it always runs off the async runtime:
        // either via the injected spawner (result delivered over a oneshot
        // channel) or via tokio's own blocking pool.
        let chunks = if let Some(ref spawner) = self.spawner {
            let task_id = CHUNK_TASK_COUNTER.fetch_add(1, Ordering::Relaxed);
            let task_name: std::sync::Arc<str> =
                std::sync::Arc::from(format!("chunk_file_{task_id}").as_str());
            let (result_tx, result_rx) = tokio::sync::oneshot::channel();
            let _join = spawner.spawn_blocking_named(
                task_name,
                Box::new(move || {
                    let result = chunk_file(&source, &rel_path_owned, lang, &chunker_config);
                    // Receiver may have been dropped; ignore send failure.
                    let _ = result_tx.send(result);
                }),
            );
            // First `?`: sender dropped without sending (task died).
            // Second `?`: chunk_file itself failed.
            result_rx
                .await
                .map_err(|_| IndexError::Other("chunk_file task dropped result".to_owned()))??
        } else {
            tokio::task::spawn_blocking(move || {
                chunk_file(&source, &rel_path_owned, lang, &chunker_config)
            })
            .await
            .map_err(|e| IndexError::Other(format!("chunk_file panicked: {e}")))??
        };
        // Dedup against the store by content hash: only unseen chunks get
        // embedded and inserted.
        let all_hashes: Vec<&str> = chunks.iter().map(|c| c.content_hash.as_str()).collect();
        let existing = self.store.existing_hashes(&all_hashes).await?;
        let mut new_chunks: Vec<CodeChunk> = Vec::new();
        let mut skipped = 0usize;
        for chunk in chunks {
            if existing.contains(&chunk.content_hash) {
                skipped += 1;
            } else {
                new_chunks.push(chunk);
            }
        }
        if new_chunks.is_empty() {
            return Ok((0, skipped));
        }
        let embedding_texts: Vec<String> =
            new_chunks.iter().map(contextualize_for_embedding).collect();
        let text_refs: Vec<&str> = embedding_texts.iter().map(String::as_str).collect();
        let vectors = self.provider.embed_batch(&text_refs).await?;
        let batch: Vec<(ChunkInsert<'_>, Vec<f32>)> = new_chunks
            .iter()
            .zip(vectors)
            .map(|(chunk, vector)| (chunk_to_insert(chunk), vector))
            .collect();
        // Bound the store write so a hung backend cannot stall indexing
        // forever; on error or timeout the batch is dropped, not retried.
        let created = match tokio::time::timeout(
            Duration::from_secs(30),
            self.store.upsert_chunks_batch(batch),
        )
        .await
        {
            Ok(Ok(inserted)) => inserted.len(),
            Ok(Err(e)) => {
                tracing::warn!("upsert_chunks_batch failed, skipping batch: {e}");
                0
            }
            Err(_elapsed) => {
                tracing::warn!(
                    "upsert_chunks_batch timed out after 30s, skipping batch of {} chunks",
                    new_chunks.len()
                );
                0
            }
        };
        if created > 0 {
            tracing::debug!("{rel_path}: {created} chunks indexed, {skipped} unchanged");
        }
        Ok((created, skipped))
    }
}
/// Builds `(relative_path, absolute_path)` pairs for each directory entry,
/// relativizing against `root`; entries outside `root` keep their full path.
fn make_file_pairs(batch: &[ignore::DirEntry], root: &Path) -> Vec<(String, std::path::PathBuf)> {
    let mut pairs = Vec::with_capacity(batch.len());
    for entry in batch {
        let path = entry.path();
        let rel = path
            .strip_prefix(root)
            .unwrap_or(path)
            .to_string_lossy()
            .into_owned();
        pairs.push((rel, path.to_path_buf()));
    }
    pairs
}
/// Converts a `CodeChunk` into the borrowed `ChunkInsert` form expected by
/// the store, without cloning any string data.
fn chunk_to_insert(chunk: &CodeChunk) -> ChunkInsert<'_> {
    let (line_start, line_end) = chunk.line_range;
    ChunkInsert {
        file_path: chunk.file_path.as_str(),
        language: chunk.language.id(),
        node_type: chunk.node_type.as_str(),
        entity_name: chunk.entity_name.as_deref(),
        line_start,
        line_end,
        code: chunk.code.as_str(),
        scope_chain: chunk.scope_chain.as_str(),
        content_hash: chunk.content_hash.as_str(),
    }
}
/// RAII guard that clears the `indexing` flag on drop, so the
/// "already running" check in `index_project` is released on every exit
/// path, including early returns and panics.
struct IndexingGuard(Arc<AtomicBool>);
impl Drop for IndexingGuard {
    fn drop(&mut self) {
        // Release pairs with the Acquire compare_exchange in `index_project`.
        self.0.store(false, Ordering::Release);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Default-constructed progress starts at zero everywhere.
    #[test]
    fn index_progress_default() {
        let p = IndexProgress::default();
        assert_eq!(p.files_done, 0);
        assert_eq!(p.files_total, 0);
        assert_eq!(p.chunks_created, 0);
    }

    // Sending on a watch channel with no receivers must not panic; the
    // indexer intentionally ignores the send result.
    #[test]
    fn progress_send_no_receivers_is_ignored() {
        let (tx, rx) = tokio::sync::watch::channel(IndexProgress::default());
        drop(rx);
        let _ = tx.send(IndexProgress {
            files_done: 1,
            files_total: 5,
            chunks_created: 3,
        });
    }

    // A watch receiver observes only the latest sent value.
    #[test]
    fn progress_send_multiple_times_accumulates() {
        let (tx, rx) = tokio::sync::watch::channel(IndexProgress::default());
        for i in 1..=3usize {
            let _ = tx.send(IndexProgress {
                files_done: i,
                files_total: 3,
                chunks_created: i * 2,
            });
        }
        let p = rx.borrow();
        assert_eq!(p.files_done, 3);
        assert_eq!(p.files_total, 3);
        assert_eq!(p.chunks_created, 6);
    }

    // With no progress sender, the `if let Some(tx)` branch is simply skipped.
    #[test]
    fn progress_none_tx_skips_send() {
        let progress_tx: Option<&tokio::sync::watch::Sender<IndexProgress>> = None;
        let entries = [1usize, 2, 3];
        for (i, _) in entries.iter().enumerate() {
            if let Some(tx) = progress_tx {
                let _ = tx.send(IndexProgress {
                    files_done: i + 1,
                    files_total: entries.len(),
                    chunks_created: 0,
                });
            }
        }
    }

    // `chunk_to_insert` maps every field through without transformation.
    #[test]
    fn chunk_to_insert_maps_fields() {
        let chunk = CodeChunk {
            code: "fn test() {}".to_string(),
            file_path: "src/lib.rs".to_string(),
            language: crate::languages::Lang::Rust,
            node_type: "function_item".to_string(),
            entity_name: Some("test".to_string()),
            line_range: (1, 3),
            scope_chain: "Foo".to_string(),
            imports: String::new(),
            content_hash: "abc".to_string(),
        };
        let insert = chunk_to_insert(&chunk);
        assert_eq!(insert.file_path, "src/lib.rs");
        assert_eq!(insert.language, "rust");
        assert_eq!(insert.entity_name, Some("test"));
        assert_eq!(insert.line_start, 1);
        assert_eq!(insert.line_end, 3);
    }

    // Pins the documented default configuration values.
    #[test]
    fn default_config() {
        let config = IndexerConfig::default();
        assert_eq!(config.chunker.target_size, 600);
        assert_eq!(config.concurrency, 2);
        assert_eq!(config.batch_size, 16);
        assert_eq!(config.embed_concurrency, 1);
    }

    // Struct-update syntax overrides only the named fields.
    #[test]
    fn indexer_config_custom_concurrency_and_batch_size() {
        let config = IndexerConfig {
            concurrency: 8,
            batch_size: 64,
            ..IndexerConfig::default()
        };
        assert_eq!(config.concurrency, 8);
        assert_eq!(config.batch_size, 64);
    }

    #[test]
    fn index_report_defaults() {
        let report = IndexReport::default();
        assert_eq!(report.files_scanned, 0);
        assert!(report.errors.is_empty());
    }

    // End-to-end dedup: pre-populating chunk_metadata with every hash means
    // a subsequent index_file call creates nothing and skips everything.
    // Uses an unreachable Qdrant URL; only the metadata DB path is exercised.
    #[tokio::test]
    async fn index_file_spawn_blocking_dedup_path() {
        use std::sync::Arc;
        use tempfile::TempDir;
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;
        use zeph_memory::QdrantOps;
        let dir = TempDir::new().unwrap();
        let rs_path = dir.path().join("sample.rs");
        std::fs::write(
            &rs_path,
            "pub fn hello() -> &'static str { \"hello\" }\n\
             pub fn world() -> &'static str { \"world\" }\n",
        )
        .unwrap();
        let pool = zeph_db::DbConfig {
            url: ":memory:".to_string(),
            ..Default::default()
        }
        .connect()
        .await
        .unwrap();
        // Chunk the file the same way the indexer will, so the inserted
        // hashes match exactly.
        let source = std::fs::read_to_string(&rs_path).unwrap();
        let lang = crate::languages::detect_language(&rs_path).unwrap();
        let chunks =
            crate::chunker::chunk_file(&source, "sample.rs", lang, &ChunkerConfig::default())
                .unwrap();
        let chunk_count = chunks.len();
        assert!(chunk_count > 0, "test file must produce at least one chunk");
        for (i, chunk) in chunks.iter().enumerate() {
            zeph_db::query(zeph_db::sql!(
                "INSERT INTO chunk_metadata \
                 (qdrant_id, file_path, content_hash, line_start, line_end, language, node_type) \
                 VALUES (?, ?, ?, ?, ?, ?, ?)"
            ))
            .bind(format!("q{i}"))
            .bind("sample.rs")
            .bind(&chunk.content_hash)
            .bind(i64::try_from(chunk.line_range.0).unwrap_or(i64::MAX))
            .bind(i64::try_from(chunk.line_range.1).unwrap_or(i64::MAX))
            .bind("rust")
            .bind("function_item")
            .execute(&pool)
            .await
            .unwrap();
        }
        let ops = QdrantOps::new("http://127.0.0.1:1", None).unwrap();
        let store = crate::store::CodeStore::with_ops(ops, pool);
        let provider = Arc::new(AnyProvider::Mock(
            MockProvider::default().with_embedding(vec![0.0_f32; 384]),
        ));
        let worker = FileIndexWorker {
            store,
            provider,
            config: IndexerConfig::default(),
            spawner: None,
        };
        let (created, skipped) = worker.index_file(&rs_path, "sample.rs").await.unwrap();
        assert_eq!(created, 0);
        assert_eq!(skipped, chunk_count);
        // Second run is idempotent: still nothing new to create.
        let (created2, skipped2) = worker.index_file(&rs_path, "sample.rs").await.unwrap();
        assert_eq!(created2, 0);
        assert_eq!(skipped2, chunk_count);
    }

    // The injected-spawner path must deliver the chunking result over its
    // oneshot channel rather than dropping it.
    #[tokio::test]
    async fn index_file_with_blocking_spawner() {
        use std::sync::Arc;
        use tempfile::TempDir;
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;
        use zeph_memory::QdrantOps;
        // Minimal spawner that just defers to tokio's blocking pool.
        struct MockBlockingSpawner;
        impl BlockingSpawner for MockBlockingSpawner {
            fn spawn_blocking_named(
                &self,
                _name: std::sync::Arc<str>,
                f: Box<dyn FnOnce() + Send + 'static>,
            ) -> tokio::task::JoinHandle<()> {
                tokio::task::spawn_blocking(f)
            }
        }
        let dir = TempDir::new().unwrap();
        let rs_path = dir.path().join("sample.rs");
        tokio::fs::write(&rs_path, b"fn hello() {}\n")
            .await
            .unwrap();
        let pool = zeph_db::DbConfig {
            url: ":memory:".to_string(),
            ..Default::default()
        }
        .connect()
        .await
        .unwrap();
        let ops = QdrantOps::new("http://127.0.0.1:1", None).unwrap();
        let store = crate::store::CodeStore::with_ops(ops, pool);
        let provider = Arc::new(AnyProvider::Mock(
            MockProvider::default().with_embedding(vec![0.0_f32; 384]),
        ));
        let worker = FileIndexWorker {
            store,
            provider,
            config: IndexerConfig::default(),
            spawner: Some(Arc::new(MockBlockingSpawner)),
        };
        // Other errors (e.g. the unreachable Qdrant) are tolerated; only the
        // dropped-channel failure mode is asserted against.
        let result = worker.index_file(&rs_path, "sample.rs").await;
        if let Err(ref e) = result {
            let msg = e.to_string();
            assert!(
                !msg.contains("chunk_file task dropped result"),
                "spawner path must not drop the result channel; got: {msg}"
            );
        }
    }

    // Guard drop resets the flag, mirroring `index_project`'s cleanup.
    #[test]
    fn indexing_guard_resets_flag_on_drop() {
        let flag = Arc::new(AtomicBool::new(false));
        {
            flag.store(true, Ordering::Relaxed);
            let _guard = IndexingGuard(Arc::clone(&flag));
            assert!(flag.load(Ordering::Relaxed));
        }
        assert!(!flag.load(Ordering::Relaxed));
    }

    // Mirrors the compare_exchange protocol used to reject concurrent runs.
    #[test]
    fn indexing_guard_compare_exchange_skips_concurrent() {
        let flag = Arc::new(AtomicBool::new(false));
        assert!(
            flag.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
                .is_ok(),
            "first caller should succeed"
        );
        assert!(
            flag.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
                .is_err(),
            "second caller should be rejected while flag is true"
        );
        flag.store(false, Ordering::Release);
        assert!(
            flag.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
                .is_ok(),
            "third caller should succeed after reset"
        );
    }
}