use std::path::Path;
use std::time::Instant;
use crate::semantic::cache::EmbeddingCache;
use crate::semantic::chunker::chunk_code;
use crate::semantic::embedder::Embedder;
use crate::semantic::similarity::top_k_similar;
use crate::semantic::types::{
CacheConfig, ChunkGranularity, ChunkOptions, EmbeddedChunk, EmbeddingModel,
SemanticSearchReport, SemanticSearchResult, SimilarityReport,
};
use crate::{TldrError, TldrResult};
/// Hard cap on the number of chunks a single index may contain.
pub const MAX_INDEX_SIZE: usize = 100_000;
/// Rough per-chunk memory estimate: a 768-dim f32 embedding (768 * 4 bytes)
/// plus ~500 bytes of chunk metadata/content overhead.
const BYTES_PER_CHUNK: usize = 768 * 4 + 500;
/// Upper bound (500 MiB) on the estimated memory an index build may use.
const MAX_MEMORY_BYTES: usize = 500 * 1024 * 1024;
/// Options controlling how a `SemanticIndex` is built.
#[derive(Debug, Clone)]
pub struct BuildOptions {
    /// Embedding model used to embed each chunk (and, later, queries).
    pub model: EmbeddingModel,
    /// Granularity at which source files are chunked (e.g. per function).
    pub granularity: ChunkGranularity,
    /// Optional language filter; entries are treated as file extensions and
    /// unrecognized ones are dropped. `None` means all supported languages.
    pub languages: Option<Vec<String>>,
    /// Print progress information to stderr while building.
    pub show_progress: bool,
    /// Consult and update the embedding cache to avoid re-embedding chunks.
    pub use_cache: bool,
}
impl Default for BuildOptions {
fn default() -> Self {
Self {
model: EmbeddingModel::default(),
granularity: ChunkGranularity::Function,
languages: None,
show_progress: true,
use_cache: true,
}
}
}
/// Options controlling a semantic search or similarity query.
#[derive(Debug, Clone)]
pub struct SearchOptions {
    /// Maximum number of results to return.
    pub top_k: usize,
    /// Minimum similarity score a candidate must reach to be returned.
    pub threshold: f64,
    /// Attach a source snippet to each result.
    pub include_snippet: bool,
    /// Maximum number of lines included in each snippet.
    pub snippet_lines: usize,
}
impl Default for SearchOptions {
fn default() -> Self {
Self {
top_k: 10,
threshold: 0.5,
include_snippet: true,
snippet_lines: 5,
}
}
}
/// An in-memory semantic search index: embedded code chunks plus the model
/// (and a lazily created embedder) used to embed queries.
pub struct SemanticIndex {
    // All indexed chunks with their embeddings, in chunker order.
    chunks: Vec<EmbeddedChunk>,
    // Model the chunk embeddings were produced with; query embeddings use
    // the same model.
    model: EmbeddingModel,
    // Lazily initialized: `None` when every chunk came from the cache during
    // build and no query has been embedded yet.
    embedder: Option<Embedder>,
}
impl SemanticIndex {
    /// Builds a semantic index over the source tree rooted at `root`.
    ///
    /// Chunks the code, looks each chunk up in the embedding cache (only when
    /// `options.use_cache` is set AND `cache_config` is provided), and
    /// batch-embeds the cache misses. The embedder is created only if at
    /// least one chunk needs embedding, and is retained so later `search`
    /// calls can reuse it for query embedding.
    ///
    /// # Errors
    ///
    /// Returns `TldrError::IndexTooLarge` when the chunk count exceeds
    /// [`MAX_INDEX_SIZE`], `TldrError::MemoryLimitExceeded` when the
    /// estimated memory exceeds [`MAX_MEMORY_BYTES`], and propagates
    /// chunking, cache, and embedding failures.
    pub fn build<P: AsRef<Path>>(
        root: P,
        options: BuildOptions,
        cache_config: Option<CacheConfig>,
    ) -> TldrResult<Self> {
        let start = Instant::now();
        let root = root.as_ref();
        // Cache is opened only when both the option and a config are present.
        let mut cache = if options.use_cache {
            cache_config.map(EmbeddingCache::open).transpose()?
        } else {
            None
        };
        // Map user-supplied language strings (treated as file extensions) to
        // known languages; unrecognized entries are silently dropped.
        let chunk_languages = options.languages.as_ref().map(|langs| {
            langs
                .iter()
                .filter_map(|s| crate::Language::from_extension(s))
                .collect()
        });
        let chunk_opts = ChunkOptions {
            granularity: options.granularity,
            languages: chunk_languages,
            ..Default::default()
        };
        let chunk_result = chunk_code(root, &chunk_opts)?;
        // Guardrails: refuse indexes that are too large or whose estimated
        // memory footprint would exceed the budget.
        if chunk_result.chunks.len() > MAX_INDEX_SIZE {
            return Err(TldrError::IndexTooLarge {
                count: chunk_result.chunks.len(),
                max: MAX_INDEX_SIZE,
            });
        }
        let estimated_memory = chunk_result.chunks.len() * BYTES_PER_CHUNK;
        if estimated_memory > MAX_MEMORY_BYTES {
            return Err(TldrError::MemoryLimitExceeded {
                estimated_mb: estimated_memory / (1024 * 1024),
                max_mb: MAX_MEMORY_BYTES / (1024 * 1024),
            });
        }
        if options.show_progress && !chunk_result.chunks.is_empty() {
            eprintln!("Building index for {} chunks...", chunk_result.chunks.len());
        }
        if !chunk_result.skipped.is_empty() && options.show_progress {
            eprintln!(
                "Skipped {} files (parse errors or unsupported)",
                chunk_result.skipped.len()
            );
        }
        // First pass: take cache hits directly; insert placeholder (empty)
        // embeddings for misses and remember their positions so batch results
        // can be written back in order.
        let mut embedded_chunks: Vec<EmbeddedChunk> = Vec::with_capacity(chunk_result.chunks.len());
        let mut uncached_indices: Vec<usize> = Vec::new();
        for (i, chunk) in chunk_result.chunks.iter().enumerate() {
            let cached_embedding = if let Some(ref mut c) = cache {
                c.get(chunk, options.model)
            } else {
                None
            };
            match cached_embedding {
                Some(e) => {
                    embedded_chunks.push(EmbeddedChunk {
                        chunk: chunk.clone(),
                        embedding: e,
                    });
                }
                None => {
                    embedded_chunks.push(EmbeddedChunk {
                        chunk: chunk.clone(),
                        embedding: Vec::new(),
                    });
                    uncached_indices.push(i);
                }
            }
        }
        // Second pass: batch-embed all misses at once, filling the
        // placeholders and writing each new embedding through to the cache.
        let embedder = if !uncached_indices.is_empty() {
            if options.show_progress {
                eprintln!(
                    "Batch embedding {} uncached chunks...",
                    uncached_indices.len()
                );
            }
            let mut embedder = Embedder::new(options.model)?;
            let texts: Vec<&str> = uncached_indices
                .iter()
                .map(|&i| chunk_result.chunks[i].content.as_str())
                .collect();
            let embeddings = embedder.embed_batch(texts, options.show_progress)?;
            // NOTE(review): `zip` silently truncates if `embed_batch` returns
            // fewer embeddings than inputs, which would leave those chunks
            // with empty placeholder embeddings — confirm embed_batch
            // guarantees one embedding per input text.
            for (idx, embedding) in uncached_indices.iter().zip(embeddings) {
                if let Some(ref mut c) = cache {
                    c.put(&chunk_result.chunks[*idx], embedding.clone(), options.model);
                }
                embedded_chunks[*idx].embedding = embedding;
            }
            Some(embedder)
        } else {
            if options.show_progress {
                eprintln!("All chunks cached - skipping embedder initialization");
            }
            None
        };
        // Persist any newly cached embeddings.
        if let Some(ref mut c) = cache {
            c.flush()?;
        }
        if options.show_progress {
            eprintln!("Index built in {:?}", start.elapsed());
        }
        Ok(Self {
            chunks: embedded_chunks,
            model: options.model,
            embedder,
        })
    }
pub fn search(
&mut self,
query: &str,
options: &SearchOptions,
) -> TldrResult<SemanticSearchReport> {
let start = Instant::now();
if self.embedder.is_none() {
self.embedder = Some(Embedder::new(self.model)?);
}
let query_embedding = self.embedder.as_mut().unwrap().embed_text(query)?;
let candidates: Vec<(usize, &[f32])> = self
.chunks
.iter()
.enumerate()
.map(|(i, c)| (i, c.embedding.as_slice()))
.collect();
let similar = top_k_similar(
&query_embedding,
&candidates,
options.top_k,
options.threshold,
);
let results: Vec<SemanticSearchResult> = similar
.into_iter()
.map(|(idx, score)| {
let chunk = &self.chunks[idx].chunk;
let snippet = if options.include_snippet {
make_snippet(&chunk.content, options.snippet_lines)
} else {
String::new()
};
SemanticSearchResult {
file_path: chunk.file_path.clone(),
function_name: chunk.function_name.clone(),
class_name: chunk.class_name.clone(),
score,
line_start: chunk.line_start,
line_end: chunk.line_end,
snippet,
}
})
.collect();
let matches_above_threshold = results.len();
Ok(SemanticSearchReport {
query: query.to_string(),
model: self.model,
results,
total_chunks: self.chunks.len(),
matches_above_threshold,
latency_ms: start.elapsed().as_millis() as u64,
cache_hit: false, })
}
    /// Finds chunks most similar to an already-indexed chunk, identified by
    /// `file_path` and optionally `function_name` (when `None`, the first
    /// chunk from that file is used).
    ///
    /// The query chunk itself — and any chunk sharing both its file path and
    /// function name — is excluded from the candidate set.
    ///
    /// # Errors
    ///
    /// Returns `TldrError::ChunkNotFound` when no chunk matches `file_path`
    /// / `function_name`.
    pub fn find_similar(
        &self,
        file_path: &str,
        function_name: Option<&str>,
        options: &SearchOptions,
    ) -> TldrResult<SimilarityReport> {
        let query_chunk = self
            .chunks
            .iter()
            .find(|c| {
                c.chunk.file_path.to_string_lossy() == file_path
                    && (function_name.is_none()
                        || c.chunk.function_name.as_deref() == function_name)
            })
            .ok_or_else(|| TldrError::ChunkNotFound {
                file: file_path.to_string(),
                function: function_name.map(String::from),
            })?;
        // Candidates keep their ORIGINAL index into `self.chunks` (enumerate
        // runs before filter), so the indices returned by `top_k_similar`
        // can be used to index `self.chunks` directly below.
        let candidates: Vec<(usize, &[f32])> = self
            .chunks
            .iter()
            .enumerate()
            .filter(|(_, c)| {
                // Keep a candidate unless it matches the query chunk's file
                // AND function name (i.e. exclude the query chunk itself).
                c.chunk.file_path.to_string_lossy() != file_path
                    || c.chunk.function_name != query_chunk.chunk.function_name
            })
            .map(|(i, c)| (i, c.embedding.as_slice()))
            .collect();
        let similar = top_k_similar(
            &query_chunk.embedding,
            &candidates,
            options.top_k,
            options.threshold,
        );
        let results: Vec<SemanticSearchResult> = similar
            .into_iter()
            .map(|(idx, score)| {
                let chunk = &self.chunks[idx].chunk;
                // Snippets are optional to keep reports small.
                let snippet = if options.include_snippet {
                    make_snippet(&chunk.content, options.snippet_lines)
                } else {
                    String::new()
                };
                SemanticSearchResult {
                    file_path: chunk.file_path.clone(),
                    function_name: chunk.function_name.clone(),
                    class_name: chunk.class_name.clone(),
                    score,
                    line_start: chunk.line_start,
                    line_end: chunk.line_end,
                    snippet,
                }
            })
            .collect();
        Ok(SimilarityReport {
            source: query_chunk.chunk.clone(),
            model: self.model,
            similar: results,
            total_compared: candidates.len(),
            exclude_self: true,
        })
    }
pub fn get_chunk(
&self,
file_path: &str,
function_name: Option<&str>,
) -> Option<&EmbeddedChunk> {
self.chunks.iter().find(|c| {
c.chunk.file_path.to_string_lossy() == file_path
&& (function_name.is_none() || c.chunk.function_name.as_deref() == function_name)
})
}
pub fn len(&self) -> usize {
self.chunks.len()
}
pub fn is_empty(&self) -> bool {
self.chunks.is_empty()
}
pub fn chunks(&self) -> &[EmbeddedChunk] {
&self.chunks
}
pub fn model(&self) -> EmbeddingModel {
self.model
}
}
/// Returns at most the first `max_lines` lines of `content`, joined with
/// `\n` and without a trailing newline.
fn make_snippet(content: &str, max_lines: usize) -> String {
    let mut snippet = String::new();
    for (i, line) in content.lines().enumerate() {
        if i >= max_lines {
            break;
        }
        if i > 0 {
            snippet.push('\n');
        }
        snippet.push_str(line);
    }
    snippet
}
#[cfg(test)]
mod index_tests {
use super::*;
#[test]
fn search_options_default_values() {
let options = SearchOptions::default();
assert_eq!(options.top_k, 10);
assert!((options.threshold - 0.5).abs() < 1e-6);
assert!(options.include_snippet);
assert_eq!(options.snippet_lines, 5);
}
#[test]
fn build_options_default_values() {
let options = BuildOptions::default();
assert_eq!(options.model, EmbeddingModel::ArcticM);
assert_eq!(options.granularity, ChunkGranularity::Function);
assert!(options.languages.is_none());
assert!(options.show_progress);
assert!(options.use_cache);
}
#[test]
fn make_snippet_limits_lines() {
let content = "line1\nline2\nline3\nline4\nline5\nline6";
let snippet = make_snippet(content, 3);
assert_eq!(snippet, "line1\nline2\nline3");
}
#[test]
fn make_snippet_handles_short_content() {
let content = "line1\nline2";
let snippet = make_snippet(content, 5);
assert_eq!(snippet, "line1\nline2");
}
#[test]
fn make_snippet_handles_empty_content() {
let content = "";
let snippet = make_snippet(content, 5);
assert_eq!(snippet, "");
}
#[test]
#[ignore = "Requires model download"]
fn semantic_index_build_from_directory() {
let temp_dir = tempfile::tempdir().unwrap();
let test_file = temp_dir.path().join("test.py");
std::fs::write(&test_file, "def foo():\n pass\n").unwrap();
let options = BuildOptions {
show_progress: false,
use_cache: false,
..Default::default()
};
let index = SemanticIndex::build(temp_dir.path(), options, None).unwrap();
assert!(!index.is_empty());
}
#[test]
#[ignore = "Requires model download"]
fn semantic_index_search_returns_ranked_results() {
let temp_dir = tempfile::tempdir().unwrap();
std::fs::write(
temp_dir.path().join("config.py"),
"def parse_config():\n pass\n",
)
.unwrap();
std::fs::write(
temp_dir.path().join("loader.py"),
"def load_data():\n pass\n",
)
.unwrap();
let options = BuildOptions {
show_progress: false,
use_cache: false,
..Default::default()
};
let mut index = SemanticIndex::build(temp_dir.path(), options, None).unwrap();
let search_opts = SearchOptions::default();
let report = index.search("parse configuration", &search_opts).unwrap();
if report.results.len() >= 2 {
assert!(report.results[0].score >= report.results[1].score);
}
}
#[test]
#[ignore = "Requires model download"]
fn semantic_index_search_respects_top_k() {
let temp_dir = tempfile::tempdir().unwrap();
for i in 0..5 {
std::fs::write(
temp_dir.path().join(format!("file{}.py", i)),
format!("def func{}():\n pass\n", i),
)
.unwrap();
}
let options = BuildOptions {
show_progress: false,
use_cache: false,
..Default::default()
};
let mut index = SemanticIndex::build(temp_dir.path(), options, None).unwrap();
let search_opts = SearchOptions {
top_k: 2,
threshold: 0.0, ..Default::default()
};
let report = index.search("function", &search_opts).unwrap();
assert!(report.results.len() <= 2);
}
#[test]
#[ignore = "Requires model download"]
fn semantic_index_search_respects_threshold() {
let temp_dir = tempfile::tempdir().unwrap();
std::fs::write(temp_dir.path().join("test.py"), "def foo():\n pass\n").unwrap();
let options = BuildOptions {
show_progress: false,
use_cache: false,
..Default::default()
};
let mut index = SemanticIndex::build(temp_dir.path(), options, None).unwrap();
let search_opts = SearchOptions {
top_k: 10,
threshold: 0.99, ..Default::default()
};
let report = index
.search("completely unrelated query", &search_opts)
.unwrap();
assert!(report.results.iter().all(|r| r.score >= 0.99));
}
#[test]
fn semantic_index_empty_returns_no_results() {
}
#[test]
#[ignore = "Requires model download"]
fn semantic_index_len_returns_chunk_count() {
let temp_dir = tempfile::tempdir().unwrap();
std::fs::write(temp_dir.path().join("a.py"), "def a():\n pass\n").unwrap();
std::fs::write(temp_dir.path().join("b.py"), "def b():\n pass\n").unwrap();
let options = BuildOptions {
show_progress: false,
use_cache: false,
..Default::default()
};
let index = SemanticIndex::build(temp_dir.path(), options, None).unwrap();
assert!(index.len() >= 2); }
#[test]
#[ignore = "Requires model download"]
fn semantic_index_build_uses_batch_embedding() {
let temp_dir = tempfile::tempdir().unwrap();
for i in 0..10 {
std::fs::write(
temp_dir.path().join(format!("mod{}.py", i)),
format!("def func_{}(x):\n return x + {}\n", i, i),
)
.unwrap();
}
let options = BuildOptions {
show_progress: false,
use_cache: false,
..Default::default()
};
let index = SemanticIndex::build(temp_dir.path(), options, None).unwrap();
assert!(
index.len() >= 10,
"Expected at least 10 chunks, got {}",
index.len()
);
for chunk in index.chunks() {
assert_eq!(
chunk.embedding.len(),
768,
"Each chunk should have 768-dim embedding"
);
let norm: f32 = chunk.embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!(
(norm - 1.0).abs() < 1e-4,
"Embedding should be normalized, got norm={}",
norm
);
}
}
#[test]
#[ignore = "Requires model download"]
fn semantic_index_build_batch_matches_sequential() {
let temp_dir = tempfile::tempdir().unwrap();
std::fs::write(
temp_dir.path().join("parser.py"),
"def parse_config(path):\n with open(path) as f:\n return f.read()\n",
)
.unwrap();
std::fs::write(
temp_dir.path().join("loader.py"),
"def load_data(file):\n return read(file)\n",
)
.unwrap();
std::fs::write(
temp_dir.path().join("math.py"),
"def add_numbers(a, b):\n return a + b\n",
)
.unwrap();
let options = BuildOptions {
show_progress: false,
use_cache: false,
..Default::default()
};
let mut index = SemanticIndex::build(temp_dir.path(), options, None).unwrap();
let search_opts = SearchOptions {
top_k: 3,
threshold: 0.0,
..Default::default()
};
let report = index.search("parse configuration", &search_opts).unwrap();
assert!(!report.results.is_empty(), "Should have results");
let parser_result = report
.results
.iter()
.find(|r| r.function_name.as_deref() == Some("parse_config"));
let math_result = report
.results
.iter()
.find(|r| r.function_name.as_deref() == Some("add_numbers"));
if let (Some(p), Some(m)) = (parser_result, math_result) {
assert!(
p.score > m.score,
"parse_config ({}) should score higher than add_numbers ({}) for 'parse configuration'",
p.score,
m.score
);
}
}
#[test]
#[ignore = "Requires model download"]
fn semantic_index_find_similar() {
let temp_dir = tempfile::tempdir().unwrap();
std::fs::write(
temp_dir.path().join("config.py"),
"def parse_config(path):\n return read(path)\n",
)
.unwrap();
std::fs::write(
temp_dir.path().join("settings.py"),
"def load_settings(file):\n return read(file)\n",
)
.unwrap();
std::fs::write(
temp_dir.path().join("unrelated.py"),
"def calculate_sum(a, b):\n return a + b\n",
)
.unwrap();
let options = BuildOptions {
show_progress: false,
use_cache: false,
..Default::default()
};
let index = SemanticIndex::build(temp_dir.path(), options, None).unwrap();
let search_opts = SearchOptions {
top_k: 5,
threshold: 0.0,
..Default::default()
};
let report = index
.find_similar("config.py", Some("parse_config"), &search_opts)
.unwrap();
assert!(report.exclude_self);
assert!(!report.similar.iter().any(|r| {
r.file_path.to_string_lossy() == "config.py"
&& r.function_name.as_deref() == Some("parse_config")
}));
}
}