use anyhow::{anyhow, Context, Result};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use sha2::{Digest, Sha256};
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
use std::time::{SystemTime, UNIX_EPOCH};
use tracing::warn;
/// Serde adapter for `Arc<Path>` fields (used via `#[serde(with = ...)]`):
/// serialized like a plain path, deserialized by way of an owned `PathBuf`.
mod arc_path_serde {
    use super::*;

    /// Serialize the shared path exactly as the underlying `Path` would be.
    pub fn serialize<S: Serializer>(path: &Arc<Path>, serializer: S) -> Result<S::Ok, S::Error> {
        (**path).serialize(serializer)
    }

    /// Deserialize an owned `PathBuf`, then share it as `Arc<Path>`.
    pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Arc<Path>, D::Error> {
        PathBuf::deserialize(deserializer).map(|pb| Arc::from(pb.as_path()))
    }
}
/// Serde adapter for `Arc<str>` fields (used via `#[serde(with = ...)]`):
/// serialized like a plain string, deserialized by way of an owned `String`.
mod arc_str_serde {
    use super::*;

    /// Serialize the shared string exactly as the underlying `str` would be.
    pub fn serialize<S: Serializer>(s: &Arc<str>, serializer: S) -> Result<S::Ok, S::Error> {
        (**s).serialize(serializer)
    }

    /// Deserialize an owned `String`, then share it as `Arc<str>`.
    pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Arc<str>, D::Error> {
        String::deserialize(deserializer).map(Arc::from)
    }
}
/// Dimension of every embedding vector produced and stored by this module.
pub const EMBEDDING_DIM: usize = 384;
/// Hard cap on chunks per `VectorCollection`; `add_chunk` errors beyond it.
pub const MAX_CHUNKS: usize = 100_000;
/// Cap on the TF-IDF vocabulary; exceeding it triggers least-used eviction.
pub const MAX_VOCABULARY_SIZE: usize = 50_000;
/// An `f32` wrapper with a total order (IEEE-754 `totalOrder` via
/// `f32::total_cmp`), usable as a key in ordered collections such as the
/// `BinaryHeap` used by `VectorIndex::search`.
///
/// `PartialEq` is implemented through `total_cmp` instead of being derived:
/// the derived impl (plain `f32 ==`) disagrees with `Ord` for NaN (where
/// `total_cmp` yields `Equal` but `==` is false) and for ±0.0 (where `==`
/// is true but `total_cmp` orders them), violating the consistency that
/// `Ord`/`Eq` require.
#[derive(Clone, Copy)]
struct OrdF32(f32);
impl PartialEq for OrdF32 {
    fn eq(&self, other: &Self) -> bool {
        // Delegate to `cmp` so equality and ordering always agree.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}
impl Eq for OrdF32 {}
impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for OrdF32 {
    /// Total ordering over all f32 bit patterns, including NaN.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}
/// Syntactic category assigned to an indexed chunk; drives the relevance
/// weighting applied to search scores (see `ChunkType::weight`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum ChunkType {
    Function,
    Struct,
    Enum,
    Trait,
    Impl,
    Module,
    Import,
    Comment,
    Test,
    Constant,
    // Default for chunks that matched no specific item pattern.
    #[default]
    CodeBlock,
    Text,
}
impl ChunkType {
    /// Relevance weight used to scale raw similarity scores during search:
    /// definition-like chunks rank highest, imports lowest.
    pub fn weight(&self) -> f32 {
        match self {
            Self::Function | Self::Struct | Self::Trait => 1.0,
            Self::Enum => 0.9,
            Self::Impl | Self::Test => 0.8,
            Self::Module | Self::CodeBlock => 0.7,
            Self::Constant => 0.6,
            Self::Comment | Self::Text => 0.5,
            Self::Import => 0.3,
        }
    }
}
/// Provenance and classification for one indexed chunk of a source file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkMetadata {
    #[serde(with = "arc_path_serde")]
    pub file_path: Arc<Path>,
    // 1-based, inclusive line range within the source file.
    pub start_line: usize,
    pub end_line: usize,
    pub chunk_type: ChunkType,
    // Name of the function/struct/etc., when one could be extracted.
    pub symbol_name: Option<String>,
    #[serde(with = "arc_str_serde")]
    pub language: Arc<str>,
    // Hex-encoded SHA-256 of the chunk content (see `ChunkMetadata::new`).
    pub content_hash: String,
    // Unix seconds at indexing time.
    pub indexed_at: u64,
    pub tags: Vec<String>,
}
impl ChunkMetadata {
    /// Construct metadata for one chunk of `content`.
    ///
    /// The content is fingerprinted with SHA-256 (hex-encoded) and the record
    /// is stamped with the current Unix time (0 if the clock reads before the
    /// epoch). Symbol name and tags start empty; use the builder methods.
    pub fn new(
        file_path: impl Into<Arc<Path>>,
        start_line: usize,
        end_line: usize,
        chunk_type: ChunkType,
        language: impl Into<Arc<str>>,
        content: &str,
    ) -> Self {
        let content_hash = hex::encode(Sha256::digest(content.as_bytes()));
        let indexed_at = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0);
        Self {
            file_path: file_path.into(),
            start_line,
            end_line,
            chunk_type,
            symbol_name: None,
            language: language.into(),
            content_hash,
            indexed_at,
            tags: Vec::new(),
        }
    }
    /// Builder-style setter for the extracted symbol name.
    pub fn with_symbol(mut self, name: impl Into<String>) -> Self {
        self.symbol_name = Some(name.into());
        self
    }
    /// Builder-style appender for a single tag.
    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
        self.tags.push(tag.into());
        self
    }
}
/// One indexed span of source text plus its metadata and (optionally) its
/// embedding vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeChunk {
    // "<path>:<start_line>:<first 8 hex chars of content hash>".
    pub id: String,
    pub content: String,
    pub metadata: ChunkMetadata,
    // Not persisted with the chunk; indexes are stored separately.
    #[serde(skip)]
    pub embedding: Option<Vec<f32>>,
}
impl CodeChunk {
pub fn new(content: String, metadata: ChunkMetadata) -> Self {
let id = format!(
"{}:{}:{}",
metadata.file_path.display(),
metadata.start_line,
&metadata.content_hash[..8]
);
Self {
id,
content,
metadata,
embedding: None,
}
}
pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
self.embedding = Some(embedding);
self
}
pub fn len(&self) -> usize {
self.content.len()
}
pub fn is_empty(&self) -> bool {
self.content.is_empty()
}
}
/// A single search hit: the matching chunk, its type-weighted score, and the
/// cosine distance derived from the raw similarity.
#[derive(Debug, Clone)]
pub struct SearchResult {
    pub chunk: CodeChunk,
    // Raw similarity multiplied by the chunk type's weight.
    pub score: f32,
    // 1.0 - raw (unweighted) similarity.
    pub distance: f32,
}
/// Post-search filter; an empty list means "no constraint" for that facet.
#[derive(Debug, Clone, Default)]
pub struct SearchFilter {
    // Glob patterns matched against the chunk's file path.
    pub file_patterns: Vec<String>,
    pub chunk_types: Vec<ChunkType>,
    // Compared case-insensitively against the chunk's language.
    pub languages: Vec<String>,
    pub tags: Vec<String>,
    // Minimum raw similarity; applied before type weighting.
    pub min_score: Option<f32>,
}
impl SearchFilter {
    /// Empty filter: matches every chunk.
    pub fn new() -> Self {
        Self::default()
    }
    /// Add a glob pattern the chunk's file path may match.
    pub fn with_file_pattern(mut self, pattern: impl Into<String>) -> Self {
        self.file_patterns.push(pattern.into());
        self
    }
    /// Add an accepted chunk type.
    pub fn with_chunk_type(mut self, chunk_type: ChunkType) -> Self {
        self.chunk_types.push(chunk_type);
        self
    }
    /// Add an accepted language (compared case-insensitively).
    pub fn with_language(mut self, language: impl Into<String>) -> Self {
        self.languages.push(language.into());
        self
    }
    /// Add a tag; a chunk passes if it carries any listed tag.
    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
        self.tags.push(tag.into());
        self
    }
    /// Require a minimum raw similarity score.
    pub fn with_min_score(mut self, score: f32) -> Self {
        self.min_score = Some(score);
        self
    }
    /// True when the chunk satisfies every populated facet. `min_score` is
    /// not checked here — callers compare it against the raw search score.
    pub fn matches(&self, chunk: &CodeChunk) -> bool {
        let meta = &chunk.metadata;
        if !self.file_patterns.is_empty() {
            let path = meta.file_path.to_string_lossy();
            // Patterns that fail to compile simply never match.
            let any_pattern_hits = self
                .file_patterns
                .iter()
                .filter_map(|pat| glob::Pattern::new(pat).ok())
                .any(|p| p.matches(&path));
            if !any_pattern_hits {
                return false;
            }
        }
        let type_ok = self.chunk_types.is_empty() || self.chunk_types.contains(&meta.chunk_type);
        let lang_ok = self.languages.is_empty()
            || self
                .languages
                .iter()
                .any(|l| l.eq_ignore_ascii_case(&meta.language));
        let tag_ok = self.tags.is_empty() || self.tags.iter().any(|t| meta.tags.contains(t));
        type_ok && lang_ok && tag_ok
    }
}
/// Lifetime/visibility scope attached to a collection at creation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum CollectionScope {
    #[default]
    Project,
    Session,
    Global,
}
/// Result of `VectorIndex::check_health`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexHealth {
    Healthy,
    // Integrity issues found, but none indicating data corruption
    // (e.g. duplicate chunk ids).
    Degraded,
    // NaN/Inf values, dimension or array-length mismatch, or empty vectors.
    Corrupt,
}
/// A named set of chunks with id- and file-based lookup maps.
///
/// NOTE(review): `chunks` and `id_index` are `#[serde(skip)]`, so a
/// deserialized collection has a populated `file_index` but no chunks —
/// confirm callers rebuild or tolerate this (see `VectorStore::load`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorCollection {
    pub name: String,
    pub scope: CollectionScope,
    #[serde(skip)]
    chunks: Vec<CodeChunk>,
    // chunk id -> position in `chunks`.
    #[serde(skip)]
    id_index: HashMap<String, usize>,
    // file path -> ids of chunks originating from that file.
    file_index: HashMap<PathBuf, Vec<String>>,
    // Unix seconds.
    pub created_at: u64,
    pub updated_at: u64,
}
impl VectorCollection {
    /// New empty collection stamped with the current Unix time.
    pub fn new(name: impl Into<String>, scope: CollectionScope) -> Self {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        Self {
            name: name.into(),
            scope,
            chunks: Vec::new(),
            id_index: HashMap::new(),
            file_index: HashMap::new(),
            created_at: now,
            updated_at: now,
        }
    }
    /// Append a chunk and update both lookup maps; errors at `MAX_CHUNKS`.
    ///
    /// NOTE(review): duplicate ids are not rejected — re-adding an existing
    /// id overwrites its `id_index` slot while the older chunk remains in
    /// `chunks`, unreachable by id. Confirm callers never reuse ids.
    pub fn add_chunk(&mut self, chunk: CodeChunk) -> Result<()> {
        if self.chunks.len() >= MAX_CHUNKS {
            return Err(anyhow!(
                "Collection {} is full (max {} chunks)",
                self.name,
                MAX_CHUNKS
            ));
        }
        self.file_index
            .entry(chunk.metadata.file_path.to_path_buf())
            .or_default()
            .push(chunk.id.clone());
        let idx = self.chunks.len();
        self.id_index.insert(chunk.id.clone(), idx);
        self.chunks.push(chunk);
        self.updated_at = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        Ok(())
    }
    /// O(1) lookup by chunk id.
    pub fn get_chunk(&self, id: &str) -> Option<&CodeChunk> {
        self.id_index.get(id).map(|&idx| &self.chunks[idx])
    }
    /// Remove by id in O(1) via `swap_remove`; the element moved into the
    /// vacated slot has its `id_index` entry repointed.
    pub fn remove_chunk(&mut self, id: &str) -> Option<CodeChunk> {
        if let Some(&idx) = self.id_index.get(id) {
            let chunk = self.chunks.swap_remove(idx);
            self.id_index.remove(id);
            // If an element was swapped into `idx`, fix its index entry.
            if idx < self.chunks.len() {
                self.id_index.insert(self.chunks[idx].id.clone(), idx);
            }
            if let Some(file_chunks) = self.file_index.get_mut(chunk.metadata.file_path.as_ref()) {
                file_chunks.retain(|cid| cid != id);
            }
            self.updated_at = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs();
            Some(chunk)
        } else {
            None
        }
    }
    /// Remove every chunk that came from `path`, then rebuild the entire
    /// id -> position map (bulk `retain` invalidates positions wholesale).
    pub fn remove_file(&mut self, path: &Path) {
        if let Some(chunk_ids) = self.file_index.remove(path) {
            let ids_to_remove: HashSet<&String> = chunk_ids.iter().collect();
            self.chunks.retain(|c| !ids_to_remove.contains(&c.id));
            self.id_index.clear();
            for (i, c) in self.chunks.iter().enumerate() {
                self.id_index.insert(c.id.clone(), i);
            }
            self.updated_at = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs();
        }
    }
    /// All chunks in storage order (not insertion order after removals,
    /// because removal uses `swap_remove`).
    pub fn chunks(&self) -> &[CodeChunk] {
        &self.chunks
    }
    /// Number of chunks currently stored.
    pub fn len(&self) -> usize {
        self.chunks.len()
    }
    pub fn is_empty(&self) -> bool {
        self.chunks.is_empty()
    }
    /// Paths of all files with at least one indexed chunk.
    pub fn files(&self) -> Vec<&PathBuf> {
        self.file_index.keys().collect()
    }
}
/// Anything that can turn text into fixed-dimension embedding vectors.
#[async_trait::async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Embed a single text.
    async fn embed(&self, text: &str) -> Result<Vec<f32>>;
    /// Embed several texts; output order matches input order.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
    /// Length of every vector this provider produces.
    fn dimension(&self) -> usize;
}
/// Deterministic hash-based provider (useful for tests): the same text always
/// yields the same vector, but vectors carry no semantic similarity.
pub struct MockEmbeddingProvider {
    dimension: usize,
}
impl MockEmbeddingProvider {
    /// Provider producing vectors of the given length.
    pub fn new(dimension: usize) -> Self {
        Self { dimension }
    }
}
impl Default for MockEmbeddingProvider {
    // Uses the module-wide EMBEDDING_DIM.
    fn default() -> Self {
        Self::new(EMBEDDING_DIM)
    }
}
#[async_trait::async_trait]
impl EmbeddingProvider for MockEmbeddingProvider {
    /// Deterministic pseudo-embedding: the SHA-256 digest bytes of `text`
    /// are cycled out to `dimension`, mapped into [-1, 1), and the vector
    /// is L2-normalized (zero vectors are left as-is).
    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let digest = Sha256::digest(text.as_bytes());
        let mut embedding: Vec<f32> = digest
            .iter()
            .cycle()
            .take(self.dimension)
            .map(|&b| (b as f32 - 128.0) / 128.0)
            .collect();
        let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            embedding.iter_mut().for_each(|x| *x /= norm);
        }
        Ok(embedding)
    }
    /// Sequentially embeds each text; order is preserved.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let mut out = Vec::with_capacity(texts.len());
        for t in texts {
            out.push(self.embed(t).await?);
        }
        Ok(out)
    }
    fn dimension(&self) -> usize {
        self.dimension
    }
}
/// Hashed term-frequency embedding provider ("TF-IDF" in name only — no IDF
/// term is applied anywhere; vectors are term frequencies hashed into
/// `dimension` slots).
pub struct TfIdfEmbeddingProvider {
    dimension: usize,
    // token -> vector slot (slot = insertion order % dimension).
    vocabulary: Arc<RwLock<HashMap<String, usize>>>,
    // token -> lookup count; used to pick eviction victims.
    usage_counts: Arc<RwLock<HashMap<String, u64>>>,
}
impl TfIdfEmbeddingProvider {
    /// Provider with an empty vocabulary producing vectors of `dimension`.
    pub fn new(dimension: usize) -> Self {
        Self {
            dimension,
            vocabulary: Arc::new(RwLock::new(HashMap::new())),
            usage_counts: Arc::new(RwLock::new(HashMap::new())),
        }
    }
    /// Lowercase, split on non-alphanumeric (underscore is kept as a word
    /// character), and drop tokens of length <= 1.
    fn tokenize(text: &str) -> Vec<String> {
        text.to_lowercase()
            .split(|c: char| !c.is_alphanumeric() && c != '_')
            .filter(|s| s.len() > 1)
            .map(String::from)
            .collect()
    }
    /// Return the vector slot for `token`, assigning one on first sight, and
    /// bump its usage count.
    ///
    /// Uses a read-then-write double-check so concurrent callers agree on
    /// the slot; poisoned locks are recovered with `into_inner`. Slots are
    /// `vocab_len % dimension`, so distinct tokens collide once the
    /// vocabulary outgrows `dimension` — inherent to a hashed embedding.
    fn get_or_create_index(&self, token: &str) -> usize {
        {
            // Fast path: token already known under the read lock.
            let read = self.vocabulary.read().unwrap_or_else(|e| e.into_inner());
            if let Some(&idx) = read.get(token) {
                drop(read);
                let mut counts = self.usage_counts.write().unwrap_or_else(|e| e.into_inner());
                *counts.entry(token.to_string()).or_default() += 1;
                return idx;
            }
        }
        let mut write = self.vocabulary.write().unwrap_or_else(|e| e.into_inner());
        // Double-check: another writer may have inserted the token between
        // our read unlock and write lock.
        if let Some(&idx) = write.get(token) {
            drop(write);
            let mut counts = self.usage_counts.write().unwrap_or_else(|e| e.into_inner());
            *counts.entry(token.to_string()).or_default() += 1;
            return idx;
        }
        let idx = write.len() % self.dimension;
        write.insert(token.to_string(), idx);
        if write.len() > MAX_VOCABULARY_SIZE {
            let mut counts = self.usage_counts.write().unwrap_or_else(|e| e.into_inner());
            let evict_count = write.len() - MAX_VOCABULARY_SIZE;
            warn!(
                "TF-IDF vocabulary exceeded cap of {}; evicting {} least-used terms",
                MAX_VOCABULARY_SIZE, evict_count
            );
            // Rank every term by usage; evict the least used, but never the
            // token just inserted. NOTE(review): an evicted term that later
            // reappears gets a slot based on the then-current vocabulary
            // size, so its slot can differ from before — confirm this drift
            // is acceptable for embeddings already stored in indexes.
            let mut terms_by_usage: Vec<(String, u64)> = write
                .keys()
                .map(|k| {
                    let count = counts.get(k).copied().unwrap_or(0);
                    (k.clone(), count)
                })
                .collect();
            terms_by_usage.sort_by_key(|(_, count)| *count);
            for (term, _) in terms_by_usage.into_iter().take(evict_count) {
                if term != token {
                    write.remove(&term);
                    counts.remove(&term);
                }
            }
        }
        drop(write);
        let mut counts = self.usage_counts.write().unwrap_or_else(|e| e.into_inner());
        *counts.entry(token.to_string()).or_default() += 1;
        idx
    }
}
impl Default for TfIdfEmbeddingProvider {
    // Uses the module-wide EMBEDDING_DIM.
    fn default() -> Self {
        Self::new(EMBEDDING_DIM)
    }
}
#[async_trait::async_trait]
impl EmbeddingProvider for TfIdfEmbeddingProvider {
    /// Hashed bag-of-words embedding: per-token frequencies are accumulated
    /// into vocabulary slots and the resulting vector is L2-normalized
    /// (an empty token list yields the all-zero vector).
    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let tokens = Self::tokenize(text);
        let total = tokens.len() as f32;
        // Count occurrences per token.
        let mut term_freq: HashMap<String, f32> = HashMap::new();
        for tok in &tokens {
            *term_freq.entry(tok.clone()).or_default() += 1.0;
        }
        // Scatter normalized frequencies into the hashed slots.
        let mut embedding = vec![0.0f32; self.dimension];
        for (tok, count) in term_freq {
            let slot = self.get_or_create_index(&tok);
            embedding[slot] += count / total;
        }
        let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            embedding.iter_mut().for_each(|x| *x /= norm);
        }
        Ok(embedding)
    }
    /// Sequentially embeds each text; order is preserved.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let mut out = Vec::with_capacity(texts.len());
        for t in texts {
            out.push(self.embed(t).await?);
        }
        Ok(out)
    }
    fn dimension(&self) -> usize {
        self.dimension
    }
}
/// Brute-force cosine-similarity index over L2-normalized vectors.
/// Parallel arrays: `embeddings[i]` belongs to `chunk_ids[i]`.
pub struct VectorIndex {
    embeddings: Vec<Vec<f32>>,
    chunk_ids: Vec<String>,
    dimension: usize,
}
impl VectorIndex {
    /// Empty index for vectors of length `dimension`.
    pub fn new(dimension: usize) -> Self {
        Self {
            embeddings: Vec::new(),
            chunk_ids: Vec::new(),
            dimension,
        }
    }
    /// Add a vector (L2-normalized on insert) under `chunk_id`.
    ///
    /// Errors if the vector length differs from the index dimension.
    pub fn add(&mut self, chunk_id: String, mut embedding: Vec<f32>) -> Result<()> {
        if embedding.len() != self.dimension {
            return Err(anyhow!(
                "Embedding dimension mismatch: expected {}, got {}",
                self.dimension,
                embedding.len()
            ));
        }
        Self::l2_normalize(&mut embedding);
        self.embeddings.push(embedding);
        self.chunk_ids.push(chunk_id);
        Ok(())
    }
    /// Remove the first entry with this id (O(n) scan, O(1) swap_remove).
    pub fn remove(&mut self, chunk_id: &str) {
        if let Some(pos) = self.chunk_ids.iter().position(|id| id == chunk_id) {
            self.embeddings.swap_remove(pos);
            self.chunk_ids.swap_remove(pos);
        }
    }
    /// Top-k by dot product against the normalized query — cosine
    /// similarity, since stored vectors are normalized on insert. Returns
    /// (chunk id, score) pairs sorted best-first; empty for a dimension
    /// mismatch or k == 0.
    ///
    /// Keeps a size-k min-heap of the best scores seen so far.
    /// NOTE(review): the early-termination break fires as soon as the k-th
    /// best score exceeds 0.95, which can skip later vectors with even
    /// higher similarity — results are then "good enough", not exact top-k.
    /// Confirm this approximation is intended.
    pub fn search(&self, query: &[f32], k: usize) -> Vec<(String, f32)> {
        if query.len() != self.dimension || k == 0 {
            return Vec::new();
        }
        let mut normed_query = query.to_vec();
        Self::l2_normalize(&mut normed_query);
        // `Reverse` turns the max-heap into a min-heap keyed on score, so
        // `peek` exposes the weakest of the current top-k.
        let mut heap: BinaryHeap<std::cmp::Reverse<(OrdF32, usize)>> =
            BinaryHeap::with_capacity(k + 1);
        const EARLY_TERM_THRESHOLD: f32 = 0.95;
        for (i, emb) in self.embeddings.iter().enumerate() {
            let score = Self::dot_product(&normed_query, emb);
            if heap.len() < k {
                heap.push(std::cmp::Reverse((OrdF32(score), i)));
            } else if let Some(&std::cmp::Reverse((OrdF32(min_score), _))) = heap.peek() {
                if score > min_score {
                    heap.pop();
                    heap.push(std::cmp::Reverse((OrdF32(score), i)));
                }
            }
            if heap.len() == k {
                if let Some(&std::cmp::Reverse((OrdF32(min_score), _))) = heap.peek() {
                    if min_score > EARLY_TERM_THRESHOLD {
                        break;
                    }
                }
            }
        }
        // Drain the heap and emit best-first.
        let mut results: Vec<(String, f32)> = heap
            .into_iter()
            .map(|std::cmp::Reverse((OrdF32(score), i))| (self.chunk_ids[i].clone(), score))
            .collect();
        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        results
    }
    #[inline]
    fn dot_product(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }
    /// In-place L2 normalization; zero vectors are left untouched.
    fn l2_normalize(v: &mut [f32]) {
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in v.iter_mut() {
                *x /= norm;
            }
        }
    }
    /// Cosine similarity of two raw (possibly unnormalized) vectors.
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let mut na = a.to_vec();
        let mut nb = b.to_vec();
        Self::l2_normalize(&mut na);
        Self::l2_normalize(&mut nb);
        Self::dot_product(&na, &nb)
    }
    /// Number of stored vectors.
    pub fn len(&self) -> usize {
        self.embeddings.len()
    }
    pub fn is_empty(&self) -> bool {
        self.embeddings.is_empty()
    }
    /// Drop all vectors and ids.
    pub fn clear(&mut self) {
        self.embeddings.clear();
        self.chunk_ids.clear();
    }
    /// Scan for structural problems and return human-readable descriptions:
    /// duplicate ids, dimension mismatches, empty vectors, NaN/Inf values,
    /// and embeddings/chunk_ids length skew. Empty result = index is sound.
    pub fn verify_index_integrity(&self) -> Vec<String> {
        let mut issues = Vec::new();
        let mut seen_ids = HashSet::new();
        for id in &self.chunk_ids {
            if !seen_ids.insert(id.as_str()) {
                issues.push(format!("Duplicate chunk ID: {}", id));
            }
        }
        for (i, embedding) in self.embeddings.iter().enumerate() {
            let id = self
                .chunk_ids
                .get(i)
                .map(|s| s.as_str())
                .unwrap_or("<missing>");
            if embedding.len() != self.dimension {
                issues.push(format!(
                    "Dimension mismatch for '{}': expected {}, got {}",
                    id,
                    self.dimension,
                    embedding.len()
                ));
            }
            // Empty vectors are reported once; NaN/Inf checks are skipped.
            if embedding.is_empty() {
                issues.push(format!("Empty embedding vector for '{}'", id));
                continue;
            }
            let has_nan = embedding.iter().any(|v| v.is_nan());
            let has_inf = embedding.iter().any(|v| v.is_infinite());
            if has_nan {
                issues.push(format!("NaN values in embedding for '{}'", id));
            }
            if has_inf {
                issues.push(format!("Inf values in embedding for '{}'", id));
            }
        }
        if self.embeddings.len() != self.chunk_ids.len() {
            issues.push(format!(
                "Array length mismatch: {} embeddings vs {} chunk_ids",
                self.embeddings.len(),
                self.chunk_ids.len()
            ));
        }
        issues
    }
    /// Classify integrity issues: data-level problems (NaN/Inf, dimension or
    /// array-length mismatch, empty vectors) mean `Corrupt`; anything else
    /// (e.g. duplicate ids) means `Degraded`. Matching is done on the issue
    /// message text produced by `verify_index_integrity`.
    pub fn check_health(&self) -> IndexHealth {
        let issues = self.verify_index_integrity();
        if issues.is_empty() {
            return IndexHealth::Healthy;
        }
        let has_corrupt = issues.iter().any(|issue| {
            issue.contains("NaN")
                || issue.contains("Inf")
                || issue.contains("Dimension mismatch")
                || issue.contains("Array length mismatch")
                || issue.contains("Empty embedding")
        });
        if has_corrupt {
            IndexHealth::Corrupt
        } else {
            IndexHealth::Degraded
        }
    }
}
/// Splits source text into `CodeChunk`s, either syntax-aware (Rust) or by
/// fixed size (everything else).
pub struct CodeChunker {
    // Upper bound on chunk content, in bytes.
    pub max_chunk_size: usize,
    // Chunks below this size are dropped by the Rust chunker.
    pub min_chunk_size: usize,
    // Overlap between consecutive fixed-size chunks. NOTE(review): this is
    // consumed as `overlap / 50` lines in `chunk_fixed_size`, so the unit
    // appears to be "characters at ~50 chars/line" — confirm intent.
    pub overlap: usize,
}
impl Default for CodeChunker {
    fn default() -> Self {
        Self {
            max_chunk_size: 2000,
            min_chunk_size: 100,
            overlap: 50,
        }
    }
}
impl CodeChunker {
    /// Chunker with a custom maximum chunk size (bytes); remaining settings
    /// come from `Default`.
    pub fn new(max_chunk_size: usize) -> Self {
        Self {
            max_chunk_size,
            ..Default::default()
        }
    }
    /// Split Rust source into chunks by scanning for item-start patterns
    /// (fn/struct/enum/trait/impl/mod/test/const/use) and closing the chunk
    /// once brace depth returns to zero.
    ///
    /// Brace counting is purely lexical, so braces inside string literals or
    /// comments can mis-split chunks — an accepted limitation of this
    /// regex-based chunker. Inter-item text shorter than `min_chunk_size`
    /// is dropped.
    pub fn chunk_rust(&self, content: &str, file_path: &Path) -> Vec<CodeChunk> {
        // Compiled once per process; patterns that fail to compile are
        // silently dropped by filter_map.
        static PATTERNS: once_cell::sync::Lazy<Vec<(regex::Regex, ChunkType)>> =
            once_cell::sync::Lazy::new(|| {
                [
                    (r"^\s*(pub\s+)?(async\s+)?fn\s+", ChunkType::Function),
                    (r"^\s*(pub\s+)?struct\s+", ChunkType::Struct),
                    (r"^\s*(pub\s+)?enum\s+", ChunkType::Enum),
                    (r"^\s*(pub\s+)?trait\s+", ChunkType::Trait),
                    (r"^\s*impl\s+", ChunkType::Impl),
                    (r"^\s*(pub\s+)?mod\s+", ChunkType::Module),
                    (r"^\s*#\[test\]", ChunkType::Test),
                    (r"^\s*(pub\s+)?const\s+", ChunkType::Constant),
                    (r"^\s*use\s+", ChunkType::Import),
                ]
                .into_iter()
                .filter_map(|(pat, ct)| regex::Regex::new(pat).ok().map(|re| (re, ct)))
                .collect()
            });
        let mut chunks = Vec::new();
        let lines: Vec<&str> = content.lines().collect();
        // Shared Arc path/language so every chunk's metadata clones cheaply.
        let shared_path: Arc<Path> = Arc::from(file_path);
        let shared_lang: Arc<str> = Arc::from("rust");
        let mut current_start = 0;
        let mut current_type = ChunkType::CodeBlock;
        let mut brace_depth = 0;
        let mut in_block = false;
        for (line_num, line) in lines.iter().enumerate() {
            for (pattern, chunk_type) in PATTERNS.iter() {
                if pattern.is_match(line) && !in_block {
                    // An item starts here: flush any accumulated inter-item
                    // text as its own chunk (if big enough).
                    if line_num > current_start {
                        let chunk_content: String = lines[current_start..line_num].join("\n");
                        if chunk_content.len() >= self.min_chunk_size {
                            let metadata = ChunkMetadata::new(
                                shared_path.clone(),
                                current_start + 1,
                                line_num,
                                current_type,
                                shared_lang.clone(),
                                &chunk_content,
                            );
                            chunks.push(CodeChunk::new(chunk_content, metadata));
                        }
                    }
                    current_start = line_num;
                    current_type = *chunk_type;
                    in_block = true;
                    break;
                }
            }
            brace_depth += line.chars().filter(|c| *c == '{').count() as i32;
            brace_depth -= line.chars().filter(|c| *c == '}').count() as i32;
            // Item closed (or was brace-less, e.g. a `use` line): emit it.
            if in_block && brace_depth <= 0 {
                let chunk_content: String = lines[current_start..=line_num].join("\n");
                let symbol_name = self.extract_rust_symbol(&chunk_content, current_type);
                let mut metadata = ChunkMetadata::new(
                    shared_path.clone(),
                    current_start + 1,
                    line_num + 1,
                    current_type,
                    shared_lang.clone(),
                    &chunk_content,
                );
                if let Some(name) = symbol_name {
                    metadata = metadata.with_symbol(name);
                }
                chunks.push(CodeChunk::new(chunk_content, metadata));
                current_start = line_num + 1;
                current_type = ChunkType::CodeBlock;
                in_block = false;
                brace_depth = 0;
            }
        }
        // Flush any trailing text after the last closed item.
        if current_start < lines.len() {
            let chunk_content: String = lines[current_start..].join("\n");
            if chunk_content.len() >= self.min_chunk_size {
                let metadata = ChunkMetadata::new(
                    shared_path.clone(),
                    current_start + 1,
                    lines.len(),
                    current_type,
                    shared_lang.clone(),
                    &chunk_content,
                );
                chunks.push(CodeChunk::new(chunk_content, metadata));
            }
        }
        chunks
    }
    /// Pull a symbol name from the first line of a chunk using a per-type
    /// regex; returns `None` for types without a named symbol or when the
    /// pattern does not match.
    fn extract_rust_symbol(&self, content: &str, chunk_type: ChunkType) -> Option<String> {
        use std::sync::LazyLock;
        static SYM_FN_RE: LazyLock<regex::Regex> =
            LazyLock::new(|| regex::Regex::new(r"fn\s+(\w+)").expect("invalid fn regex"));
        static SYM_STRUCT_RE: LazyLock<regex::Regex> =
            LazyLock::new(|| regex::Regex::new(r"struct\s+(\w+)").expect("invalid struct regex"));
        static SYM_ENUM_RE: LazyLock<regex::Regex> =
            LazyLock::new(|| regex::Regex::new(r"enum\s+(\w+)").expect("invalid enum regex"));
        static SYM_TRAIT_RE: LazyLock<regex::Regex> =
            LazyLock::new(|| regex::Regex::new(r"trait\s+(\w+)").expect("invalid trait regex"));
        // For `impl Foo` captures group 1; for `impl Trait for Foo`, group 2.
        static SYM_IMPL_RE: LazyLock<regex::Regex> = LazyLock::new(|| {
            regex::Regex::new(r"impl(?:<[^>]+>)?\s+(?:(\w+)|(?:\w+)\s+for\s+(\w+))")
                .expect("invalid impl regex")
        });
        static SYM_MOD_RE: LazyLock<regex::Regex> =
            LazyLock::new(|| regex::Regex::new(r"mod\s+(\w+)").expect("invalid mod regex"));
        let first_line = content.lines().next()?;
        match chunk_type {
            ChunkType::Function => SYM_FN_RE
                .captures(first_line)
                .and_then(|c| c.get(1))
                .map(|m| m.as_str().to_string()),
            ChunkType::Struct => SYM_STRUCT_RE
                .captures(first_line)
                .and_then(|c| c.get(1))
                .map(|m| m.as_str().to_string()),
            ChunkType::Enum => SYM_ENUM_RE
                .captures(first_line)
                .and_then(|c| c.get(1))
                .map(|m| m.as_str().to_string()),
            ChunkType::Trait => SYM_TRAIT_RE
                .captures(first_line)
                .and_then(|c| c.get(1))
                .map(|m| m.as_str().to_string()),
            ChunkType::Impl => SYM_IMPL_RE.captures(first_line).and_then(|c| {
                c.get(1)
                    .or_else(|| c.get(2))
                    .map(|m| m.as_str().to_string())
            }),
            ChunkType::Module => SYM_MOD_RE
                .captures(first_line)
                .and_then(|c| c.get(1))
                .map(|m| m.as_str().to_string()),
            _ => None,
        }
    }
    /// Split arbitrary text into chunks of at most `max_chunk_size` bytes
    /// (always at least one line per chunk), with `overlap / 50` lines of
    /// overlap between consecutive chunks.
    pub fn chunk_fixed_size(
        &self,
        content: &str,
        file_path: &Path,
        language: &str,
    ) -> Vec<CodeChunk> {
        let mut chunks = Vec::new();
        let lines: Vec<&str> = content.lines().collect();
        let shared_path: Arc<Path> = Arc::from(file_path);
        let shared_lang: Arc<str> = Arc::from(language);
        let mut start = 0;
        while start < lines.len() {
            let mut end = start;
            let mut size = 0;
            while end < lines.len() && size + lines[end].len() < self.max_chunk_size {
                size += lines[end].len() + 1; // +1 accounts for the newline
                end += 1;
            }
            // A single line longer than max_chunk_size still forms a chunk.
            if end == start {
                end = start + 1;
            }
            let chunk_content: String = lines[start..end].join("\n");
            let metadata = ChunkMetadata::new(
                shared_path.clone(),
                start + 1,
                end,
                ChunkType::CodeBlock,
                shared_lang.clone(),
                &chunk_content,
            );
            chunks.push(CodeChunk::new(chunk_content, metadata));
            if end >= lines.len() {
                break;
            }
            // BUGFIX: always make forward progress. Previously
            // `start = end.saturating_sub(self.overlap / 50)` could set
            // `start` back to its old value when a chunk was a single
            // oversized line (end == start + 1 with the default overlap of
            // 50, i.e. 1 line of overlap), looping forever.
            let next_start = end.saturating_sub(self.overlap / 50);
            start = if next_start > start { next_start } else { end };
        }
        chunks
    }
    /// Dispatch by file extension: Rust gets the syntax-aware chunker,
    /// everything else the fixed-size chunker (with the extension recorded
    /// as the language).
    pub fn chunk(&self, content: &str, file_path: &Path) -> Vec<CodeChunk> {
        let ext = file_path.extension().and_then(|e| e.to_str()).unwrap_or("");
        match ext {
            "rs" => self.chunk_rust(content, file_path),
            _ => self.chunk_fixed_size(content, file_path, ext),
        }
    }
}
/// Static dispatch over the concrete embedding providers; mirrors the
/// `EmbeddingProvider` surface without a trait object.
pub enum EmbeddingBackend {
    Mock(MockEmbeddingProvider),
    TfIdf(TfIdfEmbeddingProvider),
}
impl EmbeddingBackend {
    /// Delegates to the wrapped provider's `embed`.
    pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        match self {
            Self::Mock(p) => p.embed(text).await,
            Self::TfIdf(p) => p.embed(text).await,
        }
    }
    /// Delegates to the wrapped provider's `embed_batch`.
    pub async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        match self {
            Self::Mock(p) => p.embed_batch(texts).await,
            Self::TfIdf(p) => p.embed_batch(texts).await,
        }
    }
    /// Delegates to the wrapped provider's `dimension`.
    pub fn dimension(&self) -> usize {
        match self {
            Self::Mock(p) => p.dimension(),
            Self::TfIdf(p) => p.dimension(),
        }
    }
}
/// Top-level store: named collections, each paired (by name) with a
/// `VectorIndex`, an embedding backend, optional on-disk persistence, and
/// the chunker used by `index_file`.
pub struct VectorStore {
    collections: HashMap<String, VectorCollection>,
    // Keyed identically to `collections`.
    indices: HashMap<String, VectorIndex>,
    provider: Arc<EmbeddingBackend>,
    // Directory for `<name>.json` / `<name>.idx` files; None = in-memory only.
    storage_path: Option<PathBuf>,
    chunker: CodeChunker,
}
impl VectorStore {
    /// Store with no collections and no on-disk persistence.
    pub fn new(provider: Arc<EmbeddingBackend>) -> Self {
        Self {
            collections: HashMap::new(),
            indices: HashMap::new(),
            provider,
            storage_path: None,
            chunker: CodeChunker::default(),
        }
    }
    /// Builder: set the directory used by `save`/`load`.
    pub fn with_storage(mut self, path: impl Into<PathBuf>) -> Self {
        self.storage_path = Some(path.into());
        self
    }
    /// Builder: replace the default chunker.
    pub fn with_chunker(mut self, chunker: CodeChunker) -> Self {
        self.chunker = chunker;
        self
    }
    /// Get-or-create a collection (and its paired index) by name. `scope`
    /// only applies on creation; an existing collection keeps its scope.
    pub fn collection(&mut self, name: &str, scope: CollectionScope) -> &mut VectorCollection {
        if !self.collections.contains_key(name) {
            let collection = VectorCollection::new(name, scope);
            let index = VectorIndex::new(self.provider.dimension());
            self.collections.insert(name.to_string(), collection);
            self.indices.insert(name.to_string(), index);
        }
        self.collections
            .get_mut(name)
            .unwrap_or_else(|| unreachable!("collection was just inserted"))
    }
    /// Read-only lookup by name.
    pub fn get_collection(&self, name: &str) -> Option<&VectorCollection> {
        self.collections.get(name)
    }
    /// Names of all collections (unordered).
    pub fn list_collections(&self) -> Vec<&str> {
        self.collections.keys().map(|s| s.as_str()).collect()
    }
    /// Drop a collection and its index; best-effort removal of its persisted
    /// files (file deletion errors are ignored). Returns the removed
    /// collection, if it existed.
    pub fn delete_collection(&mut self, name: &str) -> Option<VectorCollection> {
        self.indices.remove(name);
        let removed = self.collections.remove(name);
        if let Some(ref storage_path) = self.storage_path {
            let json_path = storage_path.join(format!("{}.json", name));
            let idx_path = storage_path.join(format!("{}.idx", name));
            if json_path.exists() {
                let _ = std::fs::remove_file(&json_path);
            }
            if idx_path.exists() {
                let _ = std::fs::remove_file(&idx_path);
            }
        }
        removed
    }
    /// Read, chunk, embed and index one file into `collection_name`
    /// (creating the collection with `Project` scope if needed). Returns the
    /// number of chunks produced.
    ///
    /// NOTE(review): if `add_chunk` fails partway (e.g. collection full),
    /// chunks inserted earlier in this call remain — collection and index
    /// are not rolled back to the pre-call state.
    pub async fn index_file(&mut self, collection_name: &str, file_path: &Path) -> Result<usize> {
        let content = std::fs::read_to_string(file_path)?;
        let chunks = self.chunker.chunk(&content, file_path);
        let chunk_count = chunks.len();
        let texts: Vec<String> = chunks.iter().map(|c| c.content.clone()).collect();
        let embeddings = self.provider.embed_batch(&texts).await?;
        if !self.collections.contains_key(collection_name) {
            self.collection(collection_name, CollectionScope::Project);
        }
        let collection = self.collections.get_mut(collection_name).with_context(|| {
            format!("collection '{}' not found after creation", collection_name)
        })?;
        let index = self
            .indices
            .get_mut(collection_name)
            .with_context(|| format!("index for collection '{}' not found", collection_name))?;
        for (chunk, embedding) in chunks.into_iter().zip(embeddings.into_iter()) {
            let chunk_id = chunk.id.clone();
            let chunk = chunk.with_embedding(embedding.clone());
            collection.add_chunk(chunk)?;
            index.add(chunk_id, embedding)?;
        }
        Ok(chunk_count)
    }
    /// Re-embed every chunk of the collection and swap in a fresh index.
    pub async fn rebuild_index(&mut self, collection_name: &str) -> Result<()> {
        let collection = self
            .collections
            .get(collection_name)
            .ok_or_else(|| anyhow!("Collection not found: {}", collection_name))?;
        let texts: Vec<String> = collection
            .chunks()
            .iter()
            .map(|c| c.content.clone())
            .collect();
        let ids: Vec<String> = collection.chunks().iter().map(|c| c.id.clone()).collect();
        let embeddings = self.provider.embed_batch(&texts).await?;
        let mut new_index = VectorIndex::new(self.provider.dimension());
        for (id, embedding) in ids.into_iter().zip(embeddings.into_iter()) {
            new_index.add(id, embedding)?;
        }
        self.indices.insert(collection_name.to_string(), new_index);
        warn!(
            "Rebuilt vector index for collection '{}' ({} vectors)",
            collection_name,
            texts.len()
        );
        Ok(())
    }
    /// Map raw (id, score) hits to `SearchResult`s: apply the filter and
    /// `min_score` (checked against the raw score), weight by chunk type,
    /// sort descending by weighted score, and truncate to `k`.
    fn build_search_results(
        collection: &VectorCollection,
        raw_results: Vec<(String, f32)>,
        k: usize,
        filter: Option<&SearchFilter>,
    ) -> Vec<SearchResult> {
        let mut search_results = Vec::new();
        for (chunk_id, score) in raw_results {
            // Ids no longer present in the collection are silently skipped.
            if let Some(chunk) = collection.get_chunk(&chunk_id) {
                if let Some(filter) = filter {
                    if !filter.matches(chunk) {
                        continue;
                    }
                    if let Some(min_score) = filter.min_score {
                        if score < min_score {
                            continue;
                        }
                    }
                }
                let weighted_score = score * chunk.metadata.chunk_type.weight();
                search_results.push(SearchResult {
                    chunk: chunk.clone(),
                    score: weighted_score,
                    distance: 1.0 - score,
                });
            }
        }
        search_results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        search_results.truncate(k);
        search_results
    }
    /// Embed `query` and return up to `k` filtered, weighted results.
    /// Over-fetches `2 * k` raw hits so filtering still leaves enough.
    /// All-NaN scores are only warned about here; see `search_or_rebuild`
    /// for the self-healing variant.
    pub async fn search(
        &self,
        collection_name: &str,
        query: &str,
        k: usize,
        filter: Option<&SearchFilter>,
    ) -> Result<Vec<SearchResult>> {
        let collection = self
            .collections
            .get(collection_name)
            .ok_or_else(|| anyhow!("Collection not found: {}", collection_name))?;
        let index = self
            .indices
            .get(collection_name)
            .ok_or_else(|| anyhow!("Index not found: {}", collection_name))?;
        let query_embedding = self.provider.embed(query).await?;
        let raw_results = index.search(&query_embedding, k * 2);
        let all_nan =
            !raw_results.is_empty() && raw_results.iter().all(|(_, score)| score.is_nan());
        if all_nan {
            warn!(
                "All search scores are NaN for collection '{}' — index may be corrupt; \
                 consider calling search_or_rebuild()",
                collection_name
            );
        }
        Ok(Self::build_search_results(
            collection,
            raw_results,
            k,
            filter,
        ))
    }
    /// Like `search`, but when every raw score is NaN (corrupt index) the
    /// index is rebuilt once and the query re-run against the new index.
    pub async fn search_or_rebuild(
        &mut self,
        collection_name: &str,
        query: &str,
        k: usize,
        filter: Option<&SearchFilter>,
    ) -> Result<Vec<SearchResult>> {
        let query_embedding = self.provider.embed(query).await?;
        // Scoped borrow so `rebuild_index` can take `&mut self` below.
        let raw_results = {
            let index = self
                .indices
                .get(collection_name)
                .ok_or_else(|| anyhow!("Index not found: {}", collection_name))?;
            index.search(&query_embedding, k * 2)
        };
        let all_nan =
            !raw_results.is_empty() && raw_results.iter().all(|(_, score)| score.is_nan());
        let raw_results = if all_nan {
            warn!(
                "All search scores are NaN for collection '{}' — rebuilding index",
                collection_name
            );
            self.rebuild_index(collection_name).await?;
            let index = self
                .indices
                .get(collection_name)
                .ok_or_else(|| anyhow!("Index not found after rebuild: {}", collection_name))?;
            index.search(&query_embedding, k * 2)
        } else {
            raw_results
        };
        let collection = self
            .collections
            .get(collection_name)
            .ok_or_else(|| anyhow!("Collection not found: {}", collection_name))?;
        Ok(Self::build_search_results(
            collection,
            raw_results,
            k,
            filter,
        ))
    }
    /// Persist every collection as pretty JSON and its index via bincode,
    /// writing to pid-suffixed temp files and renaming into place so a
    /// crash mid-write never clobbers the previous file.
    pub fn save(&self) -> Result<()> {
        let storage_path = self
            .storage_path
            .as_ref()
            .ok_or_else(|| anyhow!("Storage path not set"))?;
        std::fs::create_dir_all(storage_path)?;
        let pid = std::process::id();
        for (name, collection) in &self.collections {
            let collection_path = storage_path.join(format!("{}.json", name));
            let json = serde_json::to_string_pretty(collection)?;
            let tmp_json = collection_path.with_extension(format!("json.tmp.{}", pid));
            std::fs::write(&tmp_json, &json)?;
            if let Err(e) = std::fs::rename(&tmp_json, &collection_path) {
                // Clean up the temp file before surfacing the error.
                let _ = std::fs::remove_file(&tmp_json);
                return Err(e).context("Failed to atomically save collection");
            }
            if let Some(index) = self.indices.get(name) {
                let index_path = storage_path.join(format!("{}.idx", name));
                // Index is stored as a (embeddings, chunk_ids) pair.
                let data = bincode::serde::encode_to_vec(
                    (&index.embeddings, &index.chunk_ids),
                    bincode::config::standard(),
                )?;
                let tmp_idx = index_path.with_extension(format!("idx.tmp.{}", pid));
                std::fs::write(&tmp_idx, &data)?;
                if let Err(e) = std::fs::rename(&tmp_idx, &index_path) {
                    let _ = std::fs::remove_file(&tmp_idx);
                    return Err(e).context("Failed to atomically save index");
                }
            }
        }
        Ok(())
    }
    /// Load every `<name>.json` collection (and its matching `.idx` index)
    /// from the storage directory; a missing directory is a no-op.
    ///
    /// Indexes with mismatched array lengths are skipped with a warning,
    /// leaving that collection present but index-less (searches on it will
    /// then error with "Index not found"). NOTE(review): chunks are
    /// `#[serde(skip)]`, so loaded collections carry file metadata but no
    /// chunk contents — confirm callers re-index after loading.
    pub fn load(&mut self) -> Result<()> {
        let storage_path = self
            .storage_path
            .as_ref()
            .ok_or_else(|| anyhow!("Storage path not set"))?
            .clone();
        if !storage_path.exists() {
            return Ok(());
        }
        for entry in std::fs::read_dir(&storage_path)? {
            let entry = entry?;
            let path = entry.path();
            if path.extension().and_then(|e| e.to_str()) == Some("json") {
                let name = path
                    .file_stem()
                    .and_then(|s| s.to_str())
                    .ok_or_else(|| anyhow!("Invalid collection file name"))?;
                let json = std::fs::read_to_string(&path)?;
                let collection: VectorCollection = serde_json::from_str(&json)?;
                self.collections.insert(name.to_string(), collection);
                let index_path = storage_path.join(format!("{}.idx", name));
                if index_path.exists() {
                    let data = std::fs::read(&index_path)?;
                    let ((embeddings, chunk_ids), _): ((Vec<Vec<f32>>, Vec<String>), usize) =
                        bincode::serde::decode_from_slice(&data, bincode::config::standard())?;
                    if embeddings.len() != chunk_ids.len() {
                        tracing::warn!(
                            "Corrupt vector index for '{}': {} embeddings vs {} chunk_ids — skipping",
                            name,
                            embeddings.len(),
                            chunk_ids.len()
                        );
                        continue;
                    }
                    let mut index = VectorIndex::new(self.provider.dimension());
                    for (chunk_id, embedding) in chunk_ids.into_iter().zip(embeddings.into_iter()) {
                        index.add(chunk_id, embedding)?;
                    }
                    self.indices.insert(name.to_string(), index);
                }
            }
        }
        Ok(())
    }
    /// Aggregate chunk/file counts across all collections.
    pub fn stats(&self) -> VectorStoreStats {
        let mut total_chunks = 0;
        let mut total_files = 0;
        let mut collections = Vec::new();
        for (name, collection) in &self.collections {
            total_chunks += collection.len();
            total_files += collection.files().len();
            collections.push(CollectionStats {
                name: name.clone(),
                chunk_count: collection.len(),
                file_count: collection.files().len(),
                scope: collection.scope,
            });
        }
        VectorStoreStats {
            total_chunks,
            total_files,
            collection_count: self.collections.len(),
            collections,
            embedding_dimension: self.provider.dimension(),
        }
    }
}
/// Aggregate counts across the whole store (see `VectorStore::stats`).
#[derive(Debug, Clone)]
pub struct VectorStoreStats {
    pub total_chunks: usize,
    pub total_files: usize,
    pub collection_count: usize,
    pub collections: Vec<CollectionStats>,
    pub embedding_dimension: usize,
}
/// Per-collection counts within `VectorStoreStats`.
#[derive(Debug, Clone)]
pub struct CollectionStats {
    pub name: String,
    pub chunk_count: usize,
    pub file_count: usize,
    pub scope: CollectionScope,
}
/// Default chunk capacity for `BoundedVectorStore`.
pub const DEFAULT_MAX_ITEMS: usize = 10_000;
/// FIFO-bounded wrapper around `VectorStore`: the oldest-inserted chunks are
/// evicted once the total chunk count exceeds `max_items`.
pub struct BoundedVectorStore {
    inner: VectorStore,
    max_items: usize,
    // (collection name, chunk id) in insertion order; front = oldest.
    insertion_order: std::sync::Mutex<std::collections::VecDeque<(String, String)>>,
}
impl BoundedVectorStore {
    /// Wraps `inner`, capping the total number of stored chunks at
    /// `max_items`. Chunk ids are recorded in insertion (FIFO) order so the
    /// oldest chunks are evicted first once the cap is exceeded.
    pub fn new(inner: VectorStore, max_items: usize) -> Self {
        Self {
            inner,
            max_items,
            insertion_order: std::sync::Mutex::new(std::collections::VecDeque::new()),
        }
    }

    /// Convenience constructor using [`DEFAULT_MAX_ITEMS`] as the cap.
    pub fn with_default_capacity(inner: VectorStore) -> Self {
        Self::new(inner, DEFAULT_MAX_ITEMS)
    }

    /// Total number of chunks currently stored across all collections.
    pub fn len(&self) -> usize {
        self.inner.stats().total_chunks
    }

    /// Returns `true` when no chunks are stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// The configured chunk-count cap.
    pub fn max_items(&self) -> usize {
        self.max_items
    }

    /// Deletes every collection from the wrapped store and resets the
    /// eviction queue.
    pub fn clear(&mut self) {
        let names: Vec<String> = self
            .inner
            .list_collections()
            .iter()
            .map(|s| s.to_string())
            .collect();
        for name in names {
            self.inner.delete_collection(&name);
        }
        // Recover from a poisoned mutex instead of silently skipping the
        // reset: the queue holds plain data and is still usable after a
        // panic. The previous `if let Ok(...)` left stale entries behind on
        // poison, desynchronizing the queue from the now-empty store. This
        // matches the poison handling in `evict_if_needed` and `index_file`.
        self.insertion_order
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .clear();
    }

    /// Evicts chunks oldest-first until the store is back within
    /// `max_items`. No-op when already within bounds.
    fn evict_if_needed(&mut self) {
        let mut current = self.len();
        if current <= self.max_items {
            return;
        }
        // Recover the queue even if a previous holder panicked.
        let mut order = self
            .insertion_order
            .lock()
            .unwrap_or_else(|e| e.into_inner());
        while current > self.max_items {
            if let Some((collection_name, chunk_id)) = order.pop_front() {
                if let Some(collection) = self.inner.collections.get_mut(&collection_name) {
                    if collection.remove_chunk(&chunk_id).is_some() {
                        // Keep the vector index consistent with the collection.
                        if let Some(index) = self.inner.indices.get_mut(&collection_name) {
                            index.remove(&chunk_id);
                        }
                        current -= 1;
                    }
                }
                // Stale entries (chunk already gone) are simply dropped.
            } else {
                // Queue exhausted — nothing left that we know how to evict.
                break;
            }
        }
    }

    /// Indexes `file_path` into `collection_name` via the wrapped store,
    /// records the newly produced chunk ids for FIFO eviction, then evicts
    /// if the cap was exceeded. Returns the number of chunks produced.
    pub async fn index_file(&mut self, collection_name: &str, file_path: &Path) -> Result<usize> {
        let count = self.inner.index_file(collection_name, file_path).await?;
        if let Some(collection) = self.inner.get_collection(collection_name) {
            let mut order = self
                .insertion_order
                .lock()
                .unwrap_or_else(|e| e.into_inner());
            // NOTE(review): assumes the freshly indexed chunks occupy the
            // tail of `collection.chunks()` — confirm against
            // `VectorStore::index_file`'s insertion behavior.
            let chunks = collection.chunks();
            let start = chunks.len().saturating_sub(count);
            for chunk in &chunks[start..] {
                order.push_back((collection_name.to_string(), chunk.id.clone()));
            }
        }
        self.evict_if_needed();
        Ok(count)
    }

    /// Read-only access to the wrapped store.
    pub fn inner(&self) -> &VectorStore {
        &self.inner
    }

    /// Mutable access to the wrapped store. Mutations made through this
    /// handle bypass the eviction bookkeeping.
    pub fn inner_mut(&mut self) -> &mut VectorStore {
        &mut self.inner
    }

    /// Creates or fetches a collection on the wrapped store.
    pub fn collection(&mut self, name: &str, scope: CollectionScope) -> &mut VectorCollection {
        self.inner.collection(name, scope)
    }

    /// Delegates a similarity search to the wrapped store.
    pub async fn search(
        &self,
        collection_name: &str,
        query: &str,
        k: usize,
        filter: Option<&SearchFilter>,
    ) -> Result<Vec<SearchResult>> {
        self.inner.search(collection_name, query, k, filter).await
    }

    /// Delegates statistics collection to the wrapped store.
    pub fn stats(&self) -> VectorStoreStats {
        self.inner.stats()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tempfile::tempdir;

    // ---- ChunkType / ChunkMetadata / CodeChunk basics ----

    #[test]
    fn test_chunk_type_weight() {
        // Weights are exact literal constants, so float equality is safe here.
        assert_eq!(ChunkType::Function.weight(), 1.0);
        assert_eq!(ChunkType::Import.weight(), 0.3);
        assert!(ChunkType::Comment.weight() < ChunkType::Function.weight());
    }

    #[test]
    fn test_chunk_metadata_creation() {
        let meta = ChunkMetadata::new(
            PathBuf::from("src/lib.rs"),
            1,
            10,
            ChunkType::Function,
            "rust",
            "fn main() {}",
        );
        assert_eq!(*meta.file_path, *Path::new("src/lib.rs"));
        assert_eq!(meta.start_line, 1);
        assert_eq!(meta.end_line, 10);
        assert_eq!(meta.chunk_type, ChunkType::Function);
        // The content hash is derived from the chunk text, so it must be set.
        assert!(!meta.content_hash.is_empty());
    }

    #[test]
    fn test_chunk_metadata_with_symbol() {
        // Builder-style helpers should attach both symbol name and tags.
        let meta = ChunkMetadata::new(
            PathBuf::from("lib.rs"),
            1,
            5,
            ChunkType::Function,
            "rust",
            "fn test() {}",
        )
        .with_symbol("test")
        .with_tag("unit-test");
        assert_eq!(meta.symbol_name, Some("test".to_string()));
        assert!(meta.tags.contains(&"unit-test".to_string()));
    }

    #[test]
    fn test_code_chunk_creation() {
        let meta = ChunkMetadata::new(
            PathBuf::from("lib.rs"),
            1,
            3,
            ChunkType::Function,
            "rust",
            "fn hello() {}",
        );
        let chunk = CodeChunk::new("fn hello() {}".to_string(), meta);
        assert!(!chunk.id.is_empty());
        assert_eq!(chunk.content, "fn hello() {}");
        // len() reflects the content byte length ("fn hello() {}" is 13 bytes).
        assert_eq!(chunk.len(), 13);
        assert!(!chunk.is_empty());
    }

    // ---- SearchFilter matching ----

    #[test]
    fn test_search_filter() {
        // A chunk satisfying every criterion must match a fully-populated filter.
        let filter = SearchFilter::new()
            .with_file_pattern("*.rs")
            .with_chunk_type(ChunkType::Function)
            .with_language("rust")
            .with_min_score(0.5);
        let meta = ChunkMetadata::new(
            PathBuf::from("test.rs"),
            1,
            5,
            ChunkType::Function,
            "rust",
            "fn test() {}",
        );
        let chunk = CodeChunk::new("fn test() {}".to_string(), meta);
        assert!(filter.matches(&chunk));
    }

    #[test]
    fn test_search_filter_file_pattern_mismatch() {
        // A .rs chunk must not match a filter restricted to *.py files.
        let filter = SearchFilter::new().with_file_pattern("*.py");
        let meta = ChunkMetadata::new(
            PathBuf::from("test.rs"),
            1,
            5,
            ChunkType::Function,
            "rust",
            "fn test() {}",
        );
        let chunk = CodeChunk::new("fn test() {}".to_string(), meta);
        assert!(!filter.matches(&chunk));
    }

    // ---- VectorCollection CRUD ----

    #[test]
    fn test_vector_collection_add_get() {
        let mut collection = VectorCollection::new("test", CollectionScope::Project);
        let meta = ChunkMetadata::new(
            PathBuf::from("lib.rs"),
            1,
            5,
            ChunkType::Function,
            "rust",
            "fn test() {}",
        );
        let chunk = CodeChunk::new("fn test() {}".to_string(), meta);
        let chunk_id = chunk.id.clone();
        collection.add_chunk(chunk).unwrap();
        assert_eq!(collection.len(), 1);
        assert!(collection.get_chunk(&chunk_id).is_some());
    }

    #[test]
    fn test_vector_collection_remove_chunk() {
        let mut collection = VectorCollection::new("test", CollectionScope::Project);
        let meta = ChunkMetadata::new(
            PathBuf::from("lib.rs"),
            1,
            5,
            ChunkType::Function,
            "rust",
            "fn test() {}",
        );
        let chunk = CodeChunk::new("fn test() {}".to_string(), meta);
        let chunk_id = chunk.id.clone();
        collection.add_chunk(chunk).unwrap();
        assert_eq!(collection.len(), 1);
        // remove_chunk returns the removed chunk and shrinks the collection.
        let removed = collection.remove_chunk(&chunk_id);
        assert!(removed.is_some());
        assert_eq!(collection.len(), 0);
    }

    #[test]
    fn test_vector_collection_remove_file() {
        // All chunks belonging to a file must be dropped together.
        let mut collection = VectorCollection::new("test", CollectionScope::Project);
        let path = PathBuf::from("lib.rs");
        for i in 0..3 {
            let meta = ChunkMetadata::new(
                path.clone(),
                i * 10 + 1,
                (i + 1) * 10,
                ChunkType::Function,
                "rust",
                &format!("fn test{}() {{}}", i),
            );
            let chunk = CodeChunk::new(format!("fn test{}() {{}}", i), meta);
            collection.add_chunk(chunk).unwrap();
        }
        assert_eq!(collection.len(), 3);
        collection.remove_file(&path);
        assert_eq!(collection.len(), 0);
    }

    // ---- Embedding providers ----

    #[tokio::test]
    async fn test_mock_embedding_provider() {
        let provider = MockEmbeddingProvider::new(384);
        let embedding = provider.embed("test text").await.unwrap();
        assert_eq!(embedding.len(), 384);
        // Mock embeddings are expected to be (approximately) L2-normalized.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_mock_embedding_deterministic() {
        // Same input text must always yield the same embedding vector.
        let provider = MockEmbeddingProvider::new(384);
        let e1 = provider.embed("test").await.unwrap();
        let e2 = provider.embed("test").await.unwrap();
        assert_eq!(e1, e2);
    }

    #[tokio::test]
    async fn test_tfidf_embedding_provider() {
        let provider = TfIdfEmbeddingProvider::new(256);
        let embedding = provider.embed("fn test() {}").await.unwrap();
        assert_eq!(embedding.len(), 256);
    }

    #[tokio::test]
    async fn test_tfidf_similar_texts() {
        // Word order should barely matter for a bag-of-words TF-IDF embedding,
        // so the two permutations must land close together.
        let provider = TfIdfEmbeddingProvider::new(256);
        let e1 = provider.embed("function test").await.unwrap();
        let e2 = provider.embed("test function").await.unwrap();
        let similarity = VectorIndex::cosine_similarity(&e1, &e2);
        assert!(similarity > 0.5);
    }

    // ---- VectorIndex ----

    #[test]
    fn test_vector_index_add_search() {
        let mut index = VectorIndex::new(4);
        index
            .add("a".to_string(), vec![1.0, 0.0, 0.0, 0.0])
            .unwrap();
        index
            .add("b".to_string(), vec![0.0, 1.0, 0.0, 0.0])
            .unwrap();
        index
            .add("c".to_string(), vec![0.9, 0.1, 0.0, 0.0])
            .unwrap();
        // Query equals "a" exactly; "c" is the next-closest by cosine.
        let results = index.search(&[1.0, 0.0, 0.0, 0.0], 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, "a");
        assert_eq!(results[1].0, "c");
    }

    #[test]
    fn test_vector_index_remove() {
        let mut index = VectorIndex::new(4);
        index
            .add("a".to_string(), vec![1.0, 0.0, 0.0, 0.0])
            .unwrap();
        index
            .add("b".to_string(), vec![0.0, 1.0, 0.0, 0.0])
            .unwrap();
        assert_eq!(index.len(), 2);
        index.remove("a");
        assert_eq!(index.len(), 1);
        // With "a" gone, "b" is the only candidate even for an "a"-like query.
        let results = index.search(&[1.0, 0.0, 0.0, 0.0], 1);
        assert_eq!(results[0].0, "b");
    }

    // ---- CodeChunker ----

    #[test]
    fn test_code_chunker_rust() {
        let chunker = CodeChunker::default();
        // NOTE: raw-string fixture — the interior lines are intentionally
        // unindented; changing them would change the chunker's input.
        let content = r#"
pub fn hello() {
println!("Hello");
}
pub struct Point {
x: i32,
y: i32,
}
impl Point {
pub fn new() -> Self {
Self { x: 0, y: 0 }
}
}
"#;
        let chunks = chunker.chunk_rust(content, Path::new("lib.rs"));
        assert!(chunks.len() >= 3);
        let types: Vec<_> = chunks.iter().map(|c| c.metadata.chunk_type).collect();
        assert!(types.contains(&ChunkType::Function));
        assert!(types.contains(&ChunkType::Struct));
        assert!(types.contains(&ChunkType::Impl));
    }

    #[test]
    fn test_code_chunker_extract_symbol() {
        let chunker = CodeChunker::default();
        let fn_name = chunker.extract_rust_symbol("pub fn hello() {}", ChunkType::Function);
        assert_eq!(fn_name, Some("hello".to_string()));
        let struct_name = chunker.extract_rust_symbol("pub struct MyStruct {", ChunkType::Struct);
        assert_eq!(struct_name, Some("MyStruct".to_string()));
        let impl_name = chunker.extract_rust_symbol("impl MyStruct {", ChunkType::Impl);
        assert_eq!(impl_name, Some("MyStruct".to_string()));
    }

    #[test]
    fn test_code_chunker_fixed_size() {
        let chunker = CodeChunker {
            max_chunk_size: 100,
            min_chunk_size: 10,
            overlap: 10,
        };
        let content = "a\n".repeat(50);
        let chunks = chunker.chunk_fixed_size(&content, Path::new("test.txt"), "txt");
        assert!(!chunks.is_empty());
        // Every produced chunk must respect the configured maximum size.
        for chunk in &chunks {
            assert!(chunk.len() <= 100);
        }
    }

    // ---- VectorStore collection lifecycle / indexing / search ----

    #[tokio::test]
    async fn test_vector_store_create_collection() {
        let provider = Arc::new(EmbeddingBackend::Mock(MockEmbeddingProvider::default()));
        let mut store = VectorStore::new(provider);
        store.collection("test", CollectionScope::Project);
        assert!(store.get_collection("test").is_some());
        assert!(store.list_collections().contains(&"test"));
    }

    #[tokio::test]
    async fn test_vector_store_delete_collection() {
        let provider = Arc::new(EmbeddingBackend::Mock(MockEmbeddingProvider::default()));
        let mut store = VectorStore::new(provider);
        store.collection("test", CollectionScope::Project);
        // delete_collection returns the removed collection.
        let deleted = store.delete_collection("test");
        assert!(deleted.is_some());
        assert!(store.get_collection("test").is_none());
    }

    #[tokio::test]
    async fn test_vector_store_index_file() {
        let provider = Arc::new(EmbeddingBackend::Mock(MockEmbeddingProvider::default()));
        let mut store = VectorStore::new(provider);
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.rs");
        std::fs::write(&file_path, "pub fn test() {}\npub fn hello() {}").unwrap();
        store.collection("project", CollectionScope::Project);
        let count = store.index_file("project", &file_path).await.unwrap();
        assert!(count >= 1);
        let collection = store.get_collection("project").unwrap();
        assert!(!collection.is_empty());
    }

    #[tokio::test]
    async fn test_vector_store_search() {
        let provider = Arc::new(EmbeddingBackend::Mock(MockEmbeddingProvider::default()));
        let mut store = VectorStore::new(provider);
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.rs");
        // Raw-string fixture — unindented interior is intentional.
        std::fs::write(
            &file_path,
            r#"
pub fn calculate_sum(a: i32, b: i32) -> i32 {
a + b
}
pub fn calculate_product(a: i32, b: i32) -> i32 {
a * b
}
"#,
        )
        .unwrap();
        store.collection("project", CollectionScope::Project);
        store.index_file("project", &file_path).await.unwrap();
        let results = store
            .search("project", "sum addition", 5, None)
            .await
            .unwrap();
        assert!(!results.is_empty());
    }

    #[tokio::test]
    async fn test_vector_store_search_with_filter() {
        let provider = Arc::new(EmbeddingBackend::Mock(MockEmbeddingProvider::default()));
        let mut store = VectorStore::new(provider);
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.rs");
        std::fs::write(&file_path, "pub fn test() {}").unwrap();
        store.collection("project", CollectionScope::Project);
        store.index_file("project", &file_path).await.unwrap();
        // Only a function was indexed, so a Struct-only filter with a high
        // score floor must match nothing.
        let filter = SearchFilter::new()
            .with_chunk_type(ChunkType::Struct)
            .with_min_score(0.9);
        let results = store
            .search("project", "test", 5, Some(&filter))
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_vector_store_persistence() {
        let provider = Arc::new(EmbeddingBackend::Mock(MockEmbeddingProvider::default()));
        let dir = tempdir().unwrap();
        let storage_path = dir.path().join("vector_store");
        // First store: index and save to disk, then drop it.
        {
            let mut store = VectorStore::new(provider.clone()).with_storage(&storage_path);
            let file_path = dir.path().join("test.rs");
            std::fs::write(&file_path, "pub fn test() {}").unwrap();
            store.collection("project", CollectionScope::Project);
            store.index_file("project", &file_path).await.unwrap();
            store.save().unwrap();
        }
        // Second store: load from the same path and verify the data survived.
        {
            let mut store = VectorStore::new(provider).with_storage(&storage_path);
            store.load().unwrap();
            assert!(store.get_collection("project").is_some());
        }
    }

    #[tokio::test]
    async fn test_vector_store_stats() {
        let provider = Arc::new(EmbeddingBackend::Mock(MockEmbeddingProvider::default()));
        let mut store = VectorStore::new(provider);
        store.collection("project1", CollectionScope::Project);
        store.collection("project2", CollectionScope::Session);
        let stats = store.stats();
        assert_eq!(stats.collection_count, 2);
        assert_eq!(stats.embedding_dimension, EMBEDDING_DIM);
    }

    // ---- Math helpers and defaults ----

    #[test]
    fn test_cosine_similarity() {
        // Identical, orthogonal, and opposite vectors cover the full range.
        let sim = VectorIndex::cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]);
        assert!((sim - 1.0).abs() < 0.01);
        let sim = VectorIndex::cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(sim.abs() < 0.01);
        let sim = VectorIndex::cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]);
        assert!((sim + 1.0).abs() < 0.01);
    }

    #[test]
    fn test_collection_scope_default() {
        assert_eq!(CollectionScope::default(), CollectionScope::Project);
    }

    #[test]
    fn test_chunk_type_default() {
        assert_eq!(ChunkType::default(), ChunkType::CodeBlock);
    }

    #[test]
    fn test_empty_vector_index() {
        let index = VectorIndex::new(4);
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);
        // Searching an empty index must return no results rather than error.
        let results = index.search(&[1.0, 0.0, 0.0, 0.0], 5);
        assert!(results.is_empty());
    }

    #[test]
    fn test_vector_index_dimension_mismatch() {
        // A 3-dimensional vector must be rejected by a 4-dimensional index.
        let mut index = VectorIndex::new(4);
        let result = index.add("a".to_string(), vec![1.0, 0.0, 0.0]);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_embedding_batch() {
        let provider = MockEmbeddingProvider::default();
        let texts = vec!["hello".to_string(), "world".to_string()];
        let embeddings = provider.embed_batch(&texts).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), EMBEDDING_DIM);
    }

    #[test]
    fn test_search_filter_empty_matches_all() {
        // A filter with no criteria must accept any chunk.
        let filter = SearchFilter::new();
        let meta = ChunkMetadata::new(
            PathBuf::from("any.py"),
            1,
            5,
            ChunkType::Text,
            "python",
            "# comment",
        );
        let chunk = CodeChunk::new("# comment".to_string(), meta);
        assert!(filter.matches(&chunk));
    }

    #[test]
    fn test_chunk_with_embedding() {
        let meta = ChunkMetadata::new(
            PathBuf::from("lib.rs"),
            1,
            3,
            ChunkType::Function,
            "rust",
            "fn hello() {}",
        );
        let chunk = CodeChunk::new("fn hello() {}".to_string(), meta);
        let embedding = vec![0.1, 0.2, 0.3];
        let chunk = chunk.with_embedding(embedding.clone());
        assert_eq!(chunk.embedding, Some(embedding));
    }

    #[test]
    fn test_collection_files() {
        // files() should report one entry per distinct source file.
        let mut collection = VectorCollection::new("test", CollectionScope::Project);
        for path in ["a.rs", "b.rs", "c.rs"] {
            let meta = ChunkMetadata::new(
                PathBuf::from(path),
                1,
                5,
                ChunkType::Function,
                "rust",
                "fn test() {}",
            );
            let chunk = CodeChunk::new("fn test() {}".to_string(), meta);
            collection.add_chunk(chunk).unwrap();
        }
        let files = collection.files();
        assert_eq!(files.len(), 3);
    }

    #[test]
    fn test_chunk_type_all_variants() {
        // Every variant must have a weight in [0, 1] and be Debug-printable.
        let types = [
            ChunkType::Function,
            ChunkType::Struct,
            ChunkType::Enum,
            ChunkType::Trait,
            ChunkType::Impl,
            ChunkType::Module,
            ChunkType::Import,
            ChunkType::Comment,
            ChunkType::Test,
            ChunkType::Constant,
            ChunkType::CodeBlock,
            ChunkType::Text,
        ];
        for chunk_type in types {
            assert!(chunk_type.weight() >= 0.0);
            assert!(chunk_type.weight() <= 1.0);
            let _ = format!("{:?}", chunk_type);
        }
    }

    // ---- Clone / serialization round-trips ----

    #[test]
    fn test_chunk_metadata_clone() {
        let meta = ChunkMetadata::new(
            PathBuf::from("test.rs"),
            1,
            10,
            ChunkType::Function,
            "rust",
            "fn test() {}",
        );
        let cloned = meta.clone();
        assert_eq!(meta.file_path, cloned.file_path);
        assert_eq!(meta.content_hash, cloned.content_hash);
    }

    #[test]
    fn test_chunk_metadata_serialization() {
        // JSON round-trip must preserve the chunk type (exercises the custom
        // Arc<Path>/Arc<str> serde helpers indirectly).
        let meta = ChunkMetadata::new(
            PathBuf::from("test.rs"),
            1,
            10,
            ChunkType::Function,
            "rust",
            "fn test() {}",
        );
        let json = serde_json::to_string(&meta).unwrap();
        let deserialized: ChunkMetadata = serde_json::from_str(&json).unwrap();
        assert_eq!(meta.chunk_type, deserialized.chunk_type);
    }

    #[test]
    fn test_code_chunk_clone() {
        let meta = ChunkMetadata::new(
            PathBuf::from("lib.rs"),
            1,
            5,
            ChunkType::Function,
            "rust",
            "fn hello() {}",
        );
        let chunk = CodeChunk::new("fn hello() {}".to_string(), meta);
        let cloned = chunk.clone();
        assert_eq!(chunk.id, cloned.id);
        assert_eq!(chunk.content, cloned.content);
    }

    #[test]
    fn test_search_filter_clone() {
        let filter = SearchFilter::new()
            .with_file_pattern("*.rs")
            .with_chunk_type(ChunkType::Function);
        let cloned = filter.clone();
        assert_eq!(filter.file_patterns, cloned.file_patterns);
    }

    #[test]
    fn test_search_filter_with_tag() {
        // A tag filter matches chunks carrying the same tag.
        let filter = SearchFilter::new().with_tag("important");
        let meta = ChunkMetadata::new(
            PathBuf::from("test.rs"),
            1,
            5,
            ChunkType::Function,
            "rust",
            "fn test() {}",
        )
        .with_tag("important");
        let chunk = CodeChunk::new("fn test() {}".to_string(), meta);
        assert!(filter.matches(&chunk));
    }

    #[test]
    fn test_collection_scope_all_variants() {
        let scopes = [
            CollectionScope::Project,
            CollectionScope::Session,
            CollectionScope::Global,
        ];
        for scope in scopes {
            let _ = format!("{:?}", scope);
            let cloned = scope;
            assert_eq!(scope, cloned);
        }
    }

    #[test]
    fn test_vector_collection_is_empty() {
        let collection = VectorCollection::new("test", CollectionScope::Project);
        assert!(collection.is_empty());
        assert_eq!(collection.len(), 0);
    }

    #[test]
    fn test_vector_collection_name() {
        let collection = VectorCollection::new("test_collection", CollectionScope::Project);
        assert_eq!(collection.name, "test_collection");
    }

    #[test]
    fn test_search_result_clone() {
        let meta = ChunkMetadata::new(
            PathBuf::from("test.rs"),
            1,
            5,
            ChunkType::Function,
            "rust",
            "fn test() {}",
        );
        let chunk = CodeChunk::new("fn test() {}".to_string(), meta);
        let result = SearchResult {
            chunk,
            score: 0.95,
            distance: 0.05,
        };
        let cloned = result.clone();
        assert_eq!(result.score, cloned.score);
        assert_eq!(result.distance, cloned.distance);
    }

    #[test]
    fn test_vector_index_clear() {
        let mut index = VectorIndex::new(4);
        index
            .add("a".to_string(), vec![1.0, 0.0, 0.0, 0.0])
            .unwrap();
        index
            .add("b".to_string(), vec![0.0, 1.0, 0.0, 0.0])
            .unwrap();
        assert_eq!(index.len(), 2);
        index.clear();
        assert!(index.is_empty());
    }

    #[tokio::test]
    async fn test_mock_embedding_provider_dimension() {
        // The provider must honor a non-default dimension.
        let provider = MockEmbeddingProvider::new(512);
        let embedding = provider.embed("test").await.unwrap();
        assert_eq!(embedding.len(), 512);
    }

    #[test]
    fn test_code_chunker_new() {
        let chunker = CodeChunker::new(2000);
        assert_eq!(chunker.max_chunk_size, 2000);
    }

    #[test]
    fn test_vector_store_stats_empty() {
        let provider = Arc::new(EmbeddingBackend::Mock(MockEmbeddingProvider::default()));
        let store = VectorStore::new(provider);
        let stats = store.stats();
        assert_eq!(stats.collection_count, 0);
        assert_eq!(stats.total_chunks, 0);
    }

    // ---- Index integrity / health checks ----

    #[test]
    fn test_verify_index_integrity_healthy() {
        let mut index = VectorIndex::new(3);
        index.add("a".to_string(), vec![1.0, 0.0, 0.0]).unwrap();
        index.add("b".to_string(), vec![0.0, 1.0, 0.0]).unwrap();
        let issues = index.verify_index_integrity();
        assert!(issues.is_empty(), "Expected no issues, got: {:?}", issues);
    }

    #[test]
    fn test_verify_index_integrity_nan() {
        // A NaN component must be reported as an integrity issue.
        let mut index = VectorIndex::new(3);
        index
            .add("a".to_string(), vec![1.0, f32::NAN, 0.0])
            .unwrap();
        let issues = index.verify_index_integrity();
        assert!(!issues.is_empty());
        assert!(issues.iter().any(|i| i.contains("NaN")));
    }

    #[test]
    fn test_verify_index_integrity_inf() {
        // An infinite component must be reported (message wording may group
        // it with NaN, hence the either/or check).
        let mut index = VectorIndex::new(3);
        index
            .add("a".to_string(), vec![1.0, f32::INFINITY, 0.0])
            .unwrap();
        let issues = index.verify_index_integrity();
        assert!(!issues.is_empty());
        assert!(issues
            .iter()
            .any(|i| i.contains("NaN") || i.contains("Inf")));
    }

    #[test]
    fn test_verify_index_integrity_duplicate_ids() {
        let mut index = VectorIndex::new(2);
        index.add("dup".to_string(), vec![1.0, 0.0]).unwrap();
        index.add("dup".to_string(), vec![0.0, 1.0]).unwrap();
        let issues = index.verify_index_integrity();
        assert!(issues.iter().any(|i| i.contains("Duplicate")));
    }

    #[test]
    fn test_check_health_healthy() {
        let mut index = VectorIndex::new(2);
        index.add("a".to_string(), vec![1.0, 0.0]).unwrap();
        assert_eq!(index.check_health(), IndexHealth::Healthy);
    }

    #[test]
    fn test_check_health_corrupt_nan() {
        // NaN data escalates health to Corrupt (not merely Degraded).
        let mut index = VectorIndex::new(2);
        index.add("a".to_string(), vec![f32::NAN, 0.0]).unwrap();
        assert_eq!(index.check_health(), IndexHealth::Corrupt);
    }

    #[test]
    fn test_check_health_degraded_duplicates() {
        // Duplicate ids are recoverable, so health is Degraded rather than Corrupt.
        let mut index = VectorIndex::new(2);
        index.add("dup".to_string(), vec![1.0, 0.0]).unwrap();
        index.add("dup".to_string(), vec![0.0, 1.0]).unwrap();
        assert_eq!(index.check_health(), IndexHealth::Degraded);
    }

    #[test]
    fn test_check_health_empty_index() {
        let index = VectorIndex::new(4);
        assert_eq!(index.check_health(), IndexHealth::Healthy);
    }

    #[tokio::test]
    async fn test_rebuild_index() {
        let provider = Arc::new(EmbeddingBackend::Mock(MockEmbeddingProvider::default()));
        let mut store = VectorStore::new(provider);
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.rs");
        std::fs::write(&file_path, "pub fn test() {}").unwrap();
        store.collection("project", CollectionScope::Project);
        store.index_file("project", &file_path).await.unwrap();
        // Rebuilding an already-consistent index must leave it healthy.
        store.rebuild_index("project").await.unwrap();
        let index = store.indices.get("project").unwrap();
        assert_eq!(index.check_health(), IndexHealth::Healthy);
    }

    // ---- Rust symbol extraction (cached path) ----

    #[test]
    fn test_cached_extract_rust_symbol_fn() {
        let chunker = CodeChunker::default();
        assert_eq!(
            chunker.extract_rust_symbol("pub fn hello() {}", ChunkType::Function),
            Some("hello".to_string())
        );
        assert_eq!(
            chunker.extract_rust_symbol("fn world() {}", ChunkType::Function),
            Some("world".to_string())
        );
        // A struct line queried as a Function must not yield a symbol.
        assert_eq!(
            chunker.extract_rust_symbol("pub struct Foo {", ChunkType::Function),
            None,
        );
    }

    #[test]
    fn test_cached_extract_rust_symbol_all_types() {
        let chunker = CodeChunker::default();
        assert_eq!(
            chunker.extract_rust_symbol("pub struct MyStruct {", ChunkType::Struct),
            Some("MyStruct".to_string())
        );
        assert_eq!(
            chunker.extract_rust_symbol("enum Color {", ChunkType::Enum),
            Some("Color".to_string())
        );
        assert_eq!(
            chunker.extract_rust_symbol("pub trait Display {", ChunkType::Trait),
            Some("Display".to_string())
        );
        assert_eq!(
            chunker.extract_rust_symbol("impl<T> MyStruct {", ChunkType::Impl),
            Some("MyStruct".to_string())
        );
        // For `impl Trait for Type`, the trait name is extracted.
        assert_eq!(
            chunker.extract_rust_symbol("impl Display for MyStruct {", ChunkType::Impl),
            Some("Display".to_string())
        );
        assert_eq!(
            chunker.extract_rust_symbol("mod utils {", ChunkType::Module),
            Some("utils".to_string())
        );
        assert_eq!(
            chunker.extract_rust_symbol("// comment", ChunkType::Comment),
            None,
        );
    }

    // ---- BoundedVectorStore ----

    #[tokio::test]
    async fn test_bounded_vector_store_eviction_at_capacity() {
        let provider = Arc::new(EmbeddingBackend::Mock(MockEmbeddingProvider::default()));
        let inner = VectorStore::new(provider);
        let mut bounded = BoundedVectorStore::new(inner, 3);
        bounded.collection("test", CollectionScope::Project);
        let dir = tempdir().unwrap();
        // Index twice as many files as the cap allows; eviction must keep
        // the total at or under the cap.
        for i in 0..6 {
            let file_path = dir.path().join(format!("file{}.rs", i));
            std::fs::write(
                &file_path,
                format!("pub fn func_{}() {{ println!(\"hello\"); }}", i),
            )
            .unwrap();
            bounded.index_file("test", &file_path).await.unwrap();
        }
        assert!(
            bounded.len() <= 3,
            "Store has {} items but max is 3",
            bounded.len()
        );
    }

    #[tokio::test]
    async fn test_bounded_vector_store_stays_within_bounds() {
        let provider = Arc::new(EmbeddingBackend::Mock(MockEmbeddingProvider::default()));
        let inner = VectorStore::new(provider);
        let mut bounded = BoundedVectorStore::new(inner, 5);
        bounded.collection("coll", CollectionScope::Session);
        let dir = tempdir().unwrap();
        // The invariant must hold after *every* insertion, not just at the end.
        for i in 0..20 {
            let file_path = dir.path().join(format!("mod{}.rs", i));
            std::fs::write(
                &file_path,
                format!("pub fn handler_{}() {{ let x = {}; }}", i, i * 42),
            )
            .unwrap();
            bounded.index_file("coll", &file_path).await.unwrap();
            assert!(
                bounded.len() <= 5,
                "After inserting file {}, store has {} items (max 5)",
                i,
                bounded.len()
            );
        }
    }

    #[tokio::test]
    async fn test_bounded_vector_store_clear() {
        let provider = Arc::new(EmbeddingBackend::Mock(MockEmbeddingProvider::default()));
        let inner = VectorStore::new(provider);
        let mut bounded = BoundedVectorStore::new(inner, 100);
        bounded.collection("proj", CollectionScope::Project);
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("code.rs");
        std::fs::write(&file_path, "pub fn example() { let _ = 1 + 2; }").unwrap();
        bounded.index_file("proj", &file_path).await.unwrap();
        assert!(!bounded.is_empty());
        bounded.clear();
        assert!(bounded.is_empty());
        assert_eq!(bounded.len(), 0);
    }

    #[test]
    fn test_bounded_vector_store_default_capacity() {
        let provider = Arc::new(EmbeddingBackend::Mock(MockEmbeddingProvider::default()));
        let inner = VectorStore::new(provider);
        let bounded = BoundedVectorStore::with_default_capacity(inner);
        assert_eq!(bounded.max_items(), DEFAULT_MAX_ITEMS);
        assert!(bounded.is_empty());
    }
}