1use crate::config::{SemanticBackend, SemanticBackendConfig};
2use crate::parser::{detect_language, extract_symbols_from_tree, grammar_for};
3use crate::symbols::{Symbol, SymbolKind};
4use crate::{slog_info, slog_warn};
5
6use fastembed::{EmbeddingModel as FastembedEmbeddingModel, InitOptions, TextEmbedding};
7use rayon::prelude::*;
8use reqwest::blocking::Client;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::env;
12use std::fmt::Display;
13use std::fs;
14use std::path::{Path, PathBuf};
15use std::time::Duration;
16use std::time::SystemTime;
17use tree_sitter::Parser;
18use url::Url;
19
/// Dimension reported by an index that holds no vectors.
const DEFAULT_DIMENSION: usize = 384;
/// Upper bound on the entry count and file-mtime count accepted when decoding a serialized index.
const MAX_ENTRIES: usize = 1_000_000;
/// Largest embedding dimension accepted when decoding a serialized index.
// NOTE(review): some hosted embedding models produce vectors longer than 1024 components;
// confirm that indexes built from them are meant to fail the decode-time dimension check.
const MAX_DIMENSION: usize = 1024;
/// Size in bytes of one serialized vector component.
const F32_BYTES: usize = std::mem::size_of::<f32>();
/// Minimum length of a v1 header: version byte + `u32` dimension + `u32` entry count.
const HEADER_BYTES_V1: usize = 9;
/// Minimum length of a v2+ header: the v1 header plus the `u32` fingerprint length.
const HEADER_BYTES_V2: usize = 13;
/// Hint prepended to errors that look like a missing ONNX Runtime shared library.
const ONNX_RUNTIME_INSTALL_HINT: &str =
    "ONNX Runtime not found. Install via: brew install onnxruntime (macOS) or apt install libonnxruntime (Linux).";

// On-disk format versions. `to_bytes` always writes V4 and `read_from_disk` discards
// any cache file whose version byte is not V4.
const SEMANTIC_INDEX_VERSION_V1: u8 = 1;
const SEMANTIC_INDEX_VERSION_V2: u8 = 2;
const SEMANTIC_INDEX_VERSION_V3: u8 = 3;
const SEMANTIC_INDEX_VERSION_V4: u8 = 4;
/// Path appended (after `/v1`) to an OpenAI-compatible base URL.
const DEFAULT_OPENAI_EMBEDDING_PATH: &str = "/embeddings";
/// Path appended to an Ollama base URL that does not already end in `/api`.
const DEFAULT_OLLAMA_EMBEDDING_PATH: &str = "/api/embed";
/// HTTP client timeout used when the configured `timeout_ms` is 0 (applies to every backend).
const DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS: u64 = 25_000;
/// Batch size used when the configured `max_batch_size` is 0.
const DEFAULT_MAX_BATCH_SIZE: usize = 64;
/// Placeholder stored as the fingerprint's `base_url` when none is configured or it fails to normalize.
const FALLBACK_BACKEND: &str = "none";
/// Total number of attempts made for one embedding HTTP request.
const EMBEDDING_REQUEST_MAX_ATTEMPTS: usize = 3;
/// Delay in milliseconds slept after the first and second failed attempts, respectively.
const EMBEDDING_REQUEST_BACKOFF_MS: [u64; 2] = [500, 1_000];
47
/// Identity of the embedding setup that produced an index.
///
/// Serialized to JSON and stored in the index header so a cached index can be
/// rejected when it does not match the expected configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticIndexFingerprint {
    /// Backend identifier, taken from `SemanticBackend::as_str`.
    pub backend: String,
    /// Model name as configured.
    pub model: String,
    /// Normalized base URL, or a placeholder when none applies; defaults to
    /// an empty string when absent from the stored JSON.
    #[serde(default)]
    pub base_url: String,
    /// Length of the vectors produced by the backend.
    pub dimension: usize,
}
56
57impl SemanticIndexFingerprint {
58 fn from_config(config: &SemanticBackendConfig, dimension: usize) -> Self {
59 let base_url = config
62 .base_url
63 .as_ref()
64 .and_then(|u| normalize_base_url(u).ok())
65 .unwrap_or_else(|| FALLBACK_BACKEND.to_string());
66 Self {
67 backend: config.backend.as_str().to_string(),
68 model: config.model.clone(),
69 base_url,
70 dimension,
71 }
72 }
73
74 pub fn as_string(&self) -> String {
75 serde_json::to_string(self).unwrap_or_else(|_| String::new())
76 }
77
78 fn matches_expected(&self, expected: &str) -> bool {
79 let encoded = self.as_string();
80 !encoded.is_empty() && encoded == expected
81 }
82}
83
/// Backend-specific state needed to produce embeddings.
enum SemanticEmbeddingEngine {
    /// Model driven through the `fastembed` crate.
    Fastembed(TextEmbedding),
    /// Server speaking the OpenAI embeddings protocol (`.../v1/embeddings`).
    OpenAiCompatible {
        client: Client,
        model: String,
        /// Normalized base URL (no trailing slash).
        base_url: String,
        /// Bearer token sent in the `Authorization` header when present.
        api_key: Option<String>,
    },
    /// Ollama server (`.../api/embed`).
    Ollama {
        client: Client,
        model: String,
        /// Normalized base URL (no trailing slash).
        base_url: String,
    },
}
98
/// An embedding backend together with the settings it was created from.
pub struct SemanticEmbeddingModel {
    backend: SemanticBackend,
    model: String,
    /// Base URL exactly as configured (not normalized).
    base_url: Option<String>,
    /// Effective HTTP timeout after applying the default for a configured 0.
    timeout_ms: u64,
    /// Effective batch size after applying the default for a configured 0.
    max_batch_size: usize,
    /// Vector length, cached once it has been observed; `None` until then.
    dimension: Option<usize>,
    engine: SemanticEmbeddingEngine,
}
108
/// Shorter alias for [`SemanticEmbeddingModel`].
pub type EmbeddingModel = SemanticEmbeddingModel;
110
/// Checks that a backend returned one vector per input and that all vectors
/// in the batch share the same length.
///
/// `context` names the backend in the error text.
///
/// # Errors
/// Returns a description of the first problem found: no vectors at all, a
/// count mismatch, or a vector whose length differs from the first one.
fn validate_embedding_batch(
    vectors: &[Vec<f32>],
    expected_count: usize,
    context: &str,
) -> Result<(), String> {
    let received = vectors.len();

    if received == 0 && expected_count > 0 {
        return Err(format!(
            "{context} returned no vectors for {expected_count} inputs"
        ));
    }
    if received != expected_count {
        return Err(format!(
            "{context} returned {} vectors for {} inputs",
            received, expected_count
        ));
    }

    let Some(reference) = vectors.first() else {
        return Ok(());
    };
    let expected_dimension = reference.len();

    // Vector 0 is the reference, so only later vectors can disagree with it.
    let mismatch = vectors
        .iter()
        .enumerate()
        .skip(1)
        .find(|(_, vector)| vector.len() != expected_dimension);

    match mismatch {
        Some((index, vector)) => Err(format!(
            "{context} returned inconsistent embedding dimensions: vector 0 has length {expected_dimension}, vector {index} has length {}",
            vector.len()
        )),
        None => Ok(()),
    }
}
145
146fn normalize_base_url(raw: &str) -> Result<String, String> {
147 let parsed = Url::parse(raw).map_err(|error| format!("invalid base_url '{raw}': {error}"))?;
148 let scheme = parsed.scheme();
149 if scheme != "http" && scheme != "https" {
150 return Err(format!(
151 "unsupported URL scheme '{}' — only http:// and https:// are allowed",
152 scheme
153 ));
154 }
155 Ok(parsed.to_string().trim_end_matches('/').to_string())
156}
157
158fn build_openai_embeddings_endpoint(base_url: &str) -> String {
159 if base_url.ends_with("/v1") {
160 format!("{base_url}{DEFAULT_OPENAI_EMBEDDING_PATH}")
161 } else {
162 format!("{base_url}/v1{}", DEFAULT_OPENAI_EMBEDDING_PATH)
163 }
164}
165
166fn build_ollama_embeddings_endpoint(base_url: &str) -> String {
167 if base_url.ends_with("/api") {
168 format!("{base_url}/embed")
169 } else {
170 format!("{base_url}{DEFAULT_OLLAMA_EMBEDDING_PATH}")
171 }
172}
173
/// Trims surrounding whitespace from an optional value and maps a blank
/// result to `None`.
fn normalize_api_key(value: Option<String>) -> Option<String> {
    value
        .map(|token| token.trim().to_string())
        .filter(|token| !token.is_empty())
}
184
185fn is_retryable_embedding_status(status: reqwest::StatusCode) -> bool {
186 status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS
187}
188
189fn is_retryable_embedding_error(error: &reqwest::Error) -> bool {
190 error.is_connect()
191}
192
193fn sleep_before_embedding_retry(attempt_index: usize) {
194 if let Some(delay_ms) = EMBEDDING_REQUEST_BACKOFF_MS.get(attempt_index) {
195 std::thread::sleep(Duration::from_millis(*delay_ms));
196 }
197}
198
199fn send_embedding_request<F>(mut make_request: F, backend_label: &str) -> Result<String, String>
200where
201 F: FnMut() -> reqwest::blocking::RequestBuilder,
202{
203 for attempt_index in 0..EMBEDDING_REQUEST_MAX_ATTEMPTS {
204 let last_attempt = attempt_index + 1 == EMBEDDING_REQUEST_MAX_ATTEMPTS;
205
206 let response = match make_request().send() {
207 Ok(response) => response,
208 Err(error) => {
209 if !last_attempt && is_retryable_embedding_error(&error) {
210 sleep_before_embedding_retry(attempt_index);
211 continue;
212 }
213 return Err(format!("{backend_label} request failed: {error}"));
214 }
215 };
216
217 let status = response.status();
218 let raw = match response.text() {
219 Ok(raw) => raw,
220 Err(error) => {
221 if !last_attempt && is_retryable_embedding_error(&error) {
222 sleep_before_embedding_retry(attempt_index);
223 continue;
224 }
225 return Err(format!("{backend_label} response read failed: {error}"));
226 }
227 };
228
229 if status.is_success() {
230 return Ok(raw);
231 }
232
233 if !last_attempt && is_retryable_embedding_status(status) {
234 sleep_before_embedding_retry(attempt_index);
235 continue;
236 }
237
238 return Err(format!(
239 "{backend_label} request failed (HTTP {}): {}",
240 status, raw
241 ));
242 }
243
244 unreachable!("embedding request retries exhausted without returning")
245}
246
247impl SemanticEmbeddingModel {
248 pub fn from_config(config: &SemanticBackendConfig) -> Result<Self, String> {
249 let timeout_ms = if config.timeout_ms == 0 {
250 DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS
251 } else {
252 config.timeout_ms
253 };
254
255 let max_batch_size = if config.max_batch_size == 0 {
256 DEFAULT_MAX_BATCH_SIZE
257 } else {
258 config.max_batch_size
259 };
260
261 let api_key_env = normalize_api_key(config.api_key_env.clone());
262 let model = config.model.clone();
263
264 let client = Client::builder()
265 .timeout(Duration::from_millis(timeout_ms))
266 .redirect(reqwest::redirect::Policy::none())
267 .build()
268 .map_err(|error| format!("failed to configure embedding client: {error}"))?;
269
270 let engine = match config.backend {
271 SemanticBackend::Fastembed => {
272 SemanticEmbeddingEngine::Fastembed(initialize_text_embedding(&model)?)
273 }
274 SemanticBackend::OpenAiCompatible => {
275 let raw = config.base_url.as_ref().ok_or_else(|| {
276 "base_url is required for openai_compatible backend".to_string()
277 })?;
278 let base_url = normalize_base_url(raw)?;
279
280 let api_key = match api_key_env {
281 Some(var_name) => Some(env::var(&var_name).map_err(|_| {
282 format!("missing api_key_env '{var_name}' for openai_compatible backend")
283 })?),
284 None => None,
285 };
286
287 SemanticEmbeddingEngine::OpenAiCompatible {
288 client,
289 model,
290 base_url,
291 api_key,
292 }
293 }
294 SemanticBackend::Ollama => {
295 let raw = config
296 .base_url
297 .as_ref()
298 .ok_or_else(|| "base_url is required for ollama backend".to_string())?;
299 let base_url = normalize_base_url(raw)?;
300
301 SemanticEmbeddingEngine::Ollama {
302 client,
303 model,
304 base_url,
305 }
306 }
307 };
308
309 Ok(Self {
310 backend: config.backend,
311 model: config.model.clone(),
312 base_url: config.base_url.clone(),
313 timeout_ms,
314 max_batch_size,
315 dimension: None,
316 engine,
317 })
318 }
319
    /// Returns the configured backend kind.
    pub fn backend(&self) -> SemanticBackend {
        self.backend
    }

    /// Returns the configured model name.
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Returns the base URL exactly as configured, if any.
    pub fn base_url(&self) -> Option<&str> {
        self.base_url.as_deref()
    }

    /// Returns the effective maximum number of texts per embedding request.
    pub fn max_batch_size(&self) -> usize {
        self.max_batch_size
    }

    /// Returns the effective HTTP timeout in milliseconds.
    pub fn timeout_ms(&self) -> u64 {
        self.timeout_ms
    }
339
340 pub fn fingerprint(
341 &mut self,
342 config: &SemanticBackendConfig,
343 ) -> Result<SemanticIndexFingerprint, String> {
344 let dimension = self.dimension()?;
345 Ok(SemanticIndexFingerprint::from_config(config, dimension))
346 }
347
348 pub fn dimension(&mut self) -> Result<usize, String> {
349 if let Some(dimension) = self.dimension {
350 return Ok(dimension);
351 }
352
353 let dimension = match &mut self.engine {
354 SemanticEmbeddingEngine::Fastembed(model) => {
355 let vectors = model
356 .embed(vec!["semantic index fingerprint probe".to_string()], None)
357 .map_err(|error| format_embedding_init_error(error.to_string()))?;
358 vectors
359 .first()
360 .map(|v| v.len())
361 .ok_or_else(|| "embedding backend returned no vectors".to_string())?
362 }
363 SemanticEmbeddingEngine::OpenAiCompatible { .. } => {
364 let vectors =
365 self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
366 vectors
367 .first()
368 .map(|v| v.len())
369 .ok_or_else(|| "embedding backend returned no vectors".to_string())?
370 }
371 SemanticEmbeddingEngine::Ollama { .. } => {
372 let vectors =
373 self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
374 vectors
375 .first()
376 .map(|v| v.len())
377 .ok_or_else(|| "embedding backend returned no vectors".to_string())?
378 }
379 };
380
381 self.dimension = Some(dimension);
382 Ok(dimension)
383 }
384
    /// Embeds `texts` with the configured backend, returning one vector per
    /// input text.
    ///
    /// # Errors
    /// Returns the backend's failure description.
    pub fn embed(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
        self.embed_texts(texts)
    }
388
389 fn embed_texts(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
390 match &mut self.engine {
391 SemanticEmbeddingEngine::Fastembed(model) => model
392 .embed(texts, None::<usize>)
393 .map_err(|error| format_embedding_init_error(error.to_string()))
394 .map_err(|error| format!("failed to embed batch: {error}")),
395 SemanticEmbeddingEngine::OpenAiCompatible {
396 client,
397 model,
398 base_url,
399 api_key,
400 } => {
401 let expected_text_count = texts.len();
402 let endpoint = build_openai_embeddings_endpoint(base_url);
403 let body = serde_json::json!({
404 "input": texts,
405 "model": model,
406 });
407
408 let raw = send_embedding_request(
409 || {
410 let mut request = client
411 .post(&endpoint)
412 .json(&body)
413 .header("Content-Type", "application/json");
414
415 if let Some(api_key) = api_key {
416 request = request.header("Authorization", format!("Bearer {api_key}"));
417 }
418
419 request
420 },
421 "openai compatible",
422 )?;
423
424 #[derive(Deserialize)]
425 struct OpenAiResponse {
426 data: Vec<OpenAiEmbeddingResult>,
427 }
428
429 #[derive(Deserialize)]
430 struct OpenAiEmbeddingResult {
431 embedding: Vec<f32>,
432 index: Option<u32>,
433 }
434
435 let parsed: OpenAiResponse = serde_json::from_str(&raw)
436 .map_err(|error| format!("invalid openai compatible response: {error}"))?;
437 if parsed.data.len() != expected_text_count {
438 return Err(format!(
439 "openai compatible response returned {} embeddings for {} inputs",
440 parsed.data.len(),
441 expected_text_count
442 ));
443 }
444
445 let mut vectors = vec![Vec::new(); parsed.data.len()];
446 for (i, item) in parsed.data.into_iter().enumerate() {
447 let index = item.index.unwrap_or(i as u32) as usize;
448 if index >= vectors.len() {
449 return Err(
450 "openai compatible response contains invalid vector index".to_string()
451 );
452 }
453 vectors[index] = item.embedding;
454 }
455
456 for vector in &vectors {
457 if vector.is_empty() {
458 return Err(
459 "openai compatible response contained missing vectors".to_string()
460 );
461 }
462 }
463
464 self.dimension = vectors.first().map(Vec::len);
465 Ok(vectors)
466 }
467 SemanticEmbeddingEngine::Ollama {
468 client,
469 model,
470 base_url,
471 } => {
472 let expected_text_count = texts.len();
473 let endpoint = build_ollama_embeddings_endpoint(base_url);
474
475 #[derive(Serialize)]
476 struct OllamaPayload<'a> {
477 model: &'a str,
478 input: Vec<String>,
479 }
480
481 let payload = OllamaPayload {
482 model,
483 input: texts,
484 };
485
486 let raw = send_embedding_request(
487 || {
488 client
489 .post(&endpoint)
490 .json(&payload)
491 .header("Content-Type", "application/json")
492 },
493 "ollama",
494 )?;
495
496 #[derive(Deserialize)]
497 struct OllamaResponse {
498 embeddings: Vec<Vec<f32>>,
499 }
500
501 let parsed: OllamaResponse = serde_json::from_str(&raw)
502 .map_err(|error| format!("invalid ollama response: {error}"))?;
503 if parsed.embeddings.is_empty() {
504 return Err("ollama response returned no embeddings".to_string());
505 }
506 if parsed.embeddings.len() != expected_text_count {
507 return Err(format!(
508 "ollama response returned {} embeddings for {} inputs",
509 parsed.embeddings.len(),
510 expected_text_count
511 ));
512 }
513
514 let vectors = parsed.embeddings;
515 for vector in &vectors {
516 if vector.is_empty() {
517 return Err("ollama response contained empty embeddings".to_string());
518 }
519 }
520
521 self.dimension = vectors.first().map(Vec::len);
522 Ok(vectors)
523 }
524 }
525 }
526}
527
/// Checks that the ONNX Runtime shared library can be loaded and, when a
/// version can be inferred from its file name, that it is 1.20 or newer
/// within major version 1.
///
/// The library is taken from `ORT_DYLIB_PATH` when set, otherwise the
/// platform default name is used. On Windows no check is performed.
///
/// # Errors
/// Returns a diagnostic message when the library cannot be loaded or its
/// detected version is unsupported.
pub fn pre_validate_onnx_runtime() -> Result<(), String> {
    let dylib_path = std::env::var("ORT_DYLIB_PATH").ok();

    #[cfg(any(target_os = "linux", target_os = "macos"))]
    {
        #[cfg(target_os = "linux")]
        let default_name = "libonnxruntime.so";
        #[cfg(target_os = "macos")]
        let default_name = "libonnxruntime.dylib";

        let lib_name = dylib_path.as_deref().unwrap_or(default_name);

        // SAFETY: `c_name` is a valid NUL-terminated string that outlives the
        // `dlopen` call; `dlerror`'s result is only dereferenced after a null
        // check; `dlclose` receives the non-null handle returned by `dlopen`.
        unsafe {
            let c_name = std::ffi::CString::new(lib_name)
                .map_err(|e| format!("invalid library path: {}", e))?;
            let handle = libc::dlopen(c_name.as_ptr(), libc::RTLD_NOW);
            if handle.is_null() {
                let err = libc::dlerror();
                let msg = if err.is_null() {
                    "unknown dlopen error".to_string()
                } else {
                    std::ffi::CStr::from_ptr(err).to_string_lossy().into_owned()
                };
                return Err(format!(
                    "ONNX Runtime not found. dlopen('{}') failed: {}. \
                     Run `bunx @cortexkit/aft-opencode@latest doctor` to diagnose.",
                    lib_name, msg
                ));
            }

            // The version is inferred from file names only, not from the loaded library.
            let detected_version = detect_ort_version_from_path(lib_name);

            // The handle was only needed to prove the library loads.
            libc::dlclose(handle);

            if let Some(ref version) = detected_version {
                let parts: Vec<&str> = version.split('.').collect();
                if let (Some(major), Some(minor)) = (
                    parts.first().and_then(|s| s.parse::<u32>().ok()),
                    parts.get(1).and_then(|s| s.parse::<u32>().ok()),
                ) {
                    if major != 1 || minor < 20 {
                        return Err(format!(
                            "ONNX Runtime version mismatch: found v{} at '{}', but AFT requires v1.20+. \
                             Solutions:\n\
                             1. Remove the old library and restart (AFT auto-downloads the correct version):\n\
                             {}\n\
                             2. Or install ONNX Runtime 1.24: https://github.com/microsoft/onnxruntime/releases/tag/v1.24.0\n\
                             3. Run `bunx @cortexkit/aft-opencode@latest doctor` for full diagnostics.",
                            version,
                            lib_name,
                            suggest_removal_command(lib_name),
                        ));
                    }
                }
            }
        }
    }

    #[cfg(target_os = "windows")]
    {
        // Nothing is validated on Windows; this only marks the variable as used.
        let _ = dylib_path;
    }

    Ok(())
}
600
601fn detect_ort_version_from_path(lib_path: &str) -> Option<String> {
604 let path = std::path::Path::new(lib_path);
605
606 for candidate in [Some(path.to_path_buf()), std::fs::canonicalize(path).ok()]
608 .into_iter()
609 .flatten()
610 {
611 if let Some(name) = candidate.file_name().and_then(|n| n.to_str()) {
612 if let Some(version) = extract_version_from_filename(name) {
613 return Some(version);
614 }
615 }
616 }
617
618 if let Some(parent) = path.parent() {
620 if let Ok(entries) = std::fs::read_dir(parent) {
621 for entry in entries.flatten() {
622 if let Some(name) = entry.file_name().to_str() {
623 if name.starts_with("libonnxruntime") {
624 if let Some(version) = extract_version_from_filename(name) {
625 return Some(version);
626 }
627 }
628 }
629 }
630 }
631 }
632
633 None
634}
635
636fn extract_version_from_filename(name: &str) -> Option<String> {
638 let re = regex::Regex::new(r"(\d+\.\d+\.\d+)").ok()?;
640 re.find(name).map(|m| m.as_str().to_string())
641}
642
/// Returns the command (or instruction) suggested for removing the library
/// at `lib_path`.
///
/// System-wide installs (under `/usr/local/lib`, or a bare default library
/// name) get a platform-specific suggestion; anything else gets a plain `rm`
/// of the given path.
fn suggest_removal_command(lib_path: &str) -> String {
    let is_system_install = lib_path.starts_with("/usr/local/lib")
        || matches!(lib_path, "libonnxruntime.so" | "libonnxruntime.dylib");

    if is_system_install {
        if cfg!(target_os = "linux") {
            return " sudo rm /usr/local/lib/libonnxruntime* && sudo ldconfig".to_string();
        }
        if cfg!(target_os = "macos") {
            return " sudo rm /usr/local/lib/libonnxruntime*".to_string();
        }
        if cfg!(target_os = "windows") {
            return " Delete the ONNX Runtime DLL from your PATH".to_string();
        }
        // Other platforms fall through to the generic suggestion.
    }

    format!(" rm '{}'", lib_path)
}
657
658pub fn initialize_text_embedding(model: &str) -> Result<TextEmbedding, String> {
659 pre_validate_onnx_runtime()?;
661
662 let selected_model = match model {
663 "all-MiniLM-L6-v2" | "all-minilm-l6-v2" => FastembedEmbeddingModel::AllMiniLML6V2,
664 _ => {
665 return Err(format!(
666 "unsupported fastembed model '{}'. Supported: all-MiniLM-L6-v2",
667 model
668 ))
669 }
670 };
671
672 TextEmbedding::try_new(InitOptions::new(selected_model)).map_err(format_embedding_init_error)
673}
674
/// Reports whether `message` looks like an error caused by a missing or
/// unloadable ONNX Runtime library.
///
/// A message qualifies when it starts with `ONNX Runtime not found.`
/// (ignoring leading whitespace), or when it mentions the runtime and a
/// dynamic-loading failure, compared case-insensitively.
pub fn is_onnx_runtime_unavailable(message: &str) -> bool {
    const RUNTIME_MARKERS: [&str; 3] = ["onnx runtime", "onnxruntime", "libonnxruntime"];
    const LOAD_FAILURE_MARKERS: [&str; 9] = [
        "shared library",
        "dynamic library",
        "failed to load",
        "could not load",
        "unable to load",
        "dlopen",
        "loadlibrary",
        "no such file",
        "not found",
    ];

    if message.trim_start().starts_with("ONNX Runtime not found.") {
        return true;
    }

    let lowered = message.to_ascii_lowercase();
    let mentions_any = |markers: &[&str]| markers.iter().any(|marker| lowered.contains(*marker));

    mentions_any(&RUNTIME_MARKERS) && mentions_any(&LOAD_FAILURE_MARKERS)
}
700
701fn format_embedding_init_error(error: impl Display) -> String {
702 let message = error.to_string();
703
704 if is_onnx_runtime_unavailable(&message) {
705 return format!("{ONNX_RUNTIME_INSTALL_HINT} Original error: {message}");
706 }
707
708 format!("failed to initialize semantic embedding model: {message}")
709}
710
/// One indexed unit of source code together with the text used to embed it.
#[derive(Debug, Clone)]
pub struct SemanticChunk {
    /// File the chunk belongs to.
    pub file: PathBuf,
    /// Name of the chunk (presumably the symbol name; confirm against the chunk builder).
    pub name: String,
    /// Symbol kind of the chunk.
    pub kind: SymbolKind,
    /// First line of the chunk.
    pub start_line: u32,
    /// Last line of the chunk.
    pub end_line: u32,
    /// Whether the chunk is marked as exported.
    pub exported: bool,
    /// Text handed to the embedding backend for this chunk.
    pub embed_text: String,
    /// Text returned to callers in search results.
    pub snippet: String,
}
730
/// A chunk paired with the vector the backend returned for it.
#[derive(Debug)]
struct EmbeddingEntry {
    chunk: SemanticChunk,
    vector: Vec<f32>,
}
737
/// In-memory collection of embedded chunks that can be searched by vector
/// similarity and persisted to disk.
#[derive(Debug)]
pub struct SemanticIndex {
    entries: Vec<EmbeddingEntry>,
    /// Modification time recorded for each indexed file, used for staleness checks.
    file_mtimes: HashMap<PathBuf, SystemTime>,
    /// Vector length queries must have to be searched against this index.
    dimension: usize,
    /// Embedding setup that produced the index, when known.
    fingerprint: Option<SemanticIndexFingerprint>,
}
748
/// One search hit: the chunk's metadata and snippet plus its similarity score.
#[derive(Debug)]
pub struct SemanticResult {
    pub file: PathBuf,
    pub name: String,
    pub kind: SymbolKind,
    pub start_line: u32,
    pub end_line: u32,
    pub exported: bool,
    pub snippet: String,
    /// Cosine similarity between the query vector and the chunk's vector.
    pub score: f32,
}
761
762impl SemanticIndex {
    /// Creates an empty index with the default dimension and no fingerprint.
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
            file_mtimes: HashMap::new(),
            dimension: DEFAULT_DIMENSION,
            fingerprint: None,
        }
    }
771
    /// Returns the number of embedded chunks in the index.
    pub fn entry_count(&self) -> usize {
        self.entries.len()
    }
776
777 pub fn status_label(&self) -> &'static str {
779 if self.entries.is_empty() {
780 "empty"
781 } else {
782 "ready"
783 }
784 }
785
786 fn collect_chunks(
787 project_root: &Path,
788 files: &[PathBuf],
789 ) -> (Vec<SemanticChunk>, HashMap<PathBuf, SystemTime>) {
790 let per_file: Vec<(PathBuf, SystemTime, Vec<SemanticChunk>)> = files
791 .par_iter()
792 .map_init(HashMap::new, |parsers, file| {
793 let mtime = std::fs::metadata(file)
794 .and_then(|m| m.modified())
795 .unwrap_or(SystemTime::UNIX_EPOCH);
796
797 let chunks = collect_file_chunks(project_root, file, parsers).unwrap_or_default();
798
799 (file.clone(), mtime, chunks)
800 })
801 .collect();
802
803 let mut chunks: Vec<SemanticChunk> = Vec::new();
804 let mut file_mtimes: HashMap<PathBuf, SystemTime> = HashMap::new();
805
806 for (file, mtime, file_chunks) in per_file {
807 file_mtimes.insert(file, mtime);
808 chunks.extend(file_chunks);
809 }
810
811 (chunks, file_mtimes)
812 }
813
814 fn build_from_chunks<F, P>(
815 chunks: Vec<SemanticChunk>,
816 file_mtimes: HashMap<PathBuf, SystemTime>,
817 embed_fn: &mut F,
818 max_batch_size: usize,
819 mut progress: Option<&mut P>,
820 ) -> Result<Self, String>
821 where
822 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
823 P: FnMut(usize, usize),
824 {
825 let total_chunks = chunks.len();
826
827 if chunks.is_empty() {
828 return Ok(Self {
829 entries: Vec::new(),
830 file_mtimes,
831 dimension: DEFAULT_DIMENSION,
832 fingerprint: None,
833 });
834 }
835
836 let mut entries: Vec<EmbeddingEntry> = Vec::with_capacity(chunks.len());
838 let mut expected_dimension: Option<usize> = None;
839 let batch_size = max_batch_size.max(1);
840 for batch_start in (0..chunks.len()).step_by(batch_size) {
841 let batch_end = (batch_start + batch_size).min(chunks.len());
842 let batch_texts: Vec<String> = chunks[batch_start..batch_end]
843 .iter()
844 .map(|c| c.embed_text.clone())
845 .collect();
846
847 let vectors = embed_fn(batch_texts)?;
848 validate_embedding_batch(&vectors, batch_end - batch_start, "embedding backend")?;
849
850 if let Some(dim) = vectors.first().map(|v| v.len()) {
852 match expected_dimension {
853 None => expected_dimension = Some(dim),
854 Some(expected) if dim != expected => {
855 return Err(format!(
856 "embedding dimension changed across batches: expected {expected}, got {dim}"
857 ));
858 }
859 _ => {}
860 }
861 }
862
863 for (i, vector) in vectors.into_iter().enumerate() {
864 let chunk_idx = batch_start + i;
865 entries.push(EmbeddingEntry {
866 chunk: chunks[chunk_idx].clone(),
867 vector,
868 });
869 }
870
871 if let Some(callback) = progress.as_mut() {
872 callback(entries.len(), total_chunks);
873 }
874 }
875
876 let dimension = entries
877 .first()
878 .map(|e| e.vector.len())
879 .unwrap_or(DEFAULT_DIMENSION);
880
881 Ok(Self {
882 entries,
883 file_mtimes,
884 dimension,
885 fingerprint: None,
886 })
887 }
888
889 pub fn build<F>(
892 project_root: &Path,
893 files: &[PathBuf],
894 embed_fn: &mut F,
895 max_batch_size: usize,
896 ) -> Result<Self, String>
897 where
898 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
899 {
900 let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
901 Self::build_from_chunks(
902 chunks,
903 file_mtimes,
904 embed_fn,
905 max_batch_size,
906 Option::<&mut fn(usize, usize)>::None,
907 )
908 }
909
910 pub fn build_with_progress<F, P>(
912 project_root: &Path,
913 files: &[PathBuf],
914 embed_fn: &mut F,
915 max_batch_size: usize,
916 progress: &mut P,
917 ) -> Result<Self, String>
918 where
919 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
920 P: FnMut(usize, usize),
921 {
922 let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
923 let total_chunks = chunks.len();
924 progress(0, total_chunks);
925 Self::build_from_chunks(
926 chunks,
927 file_mtimes,
928 embed_fn,
929 max_batch_size,
930 Some(progress),
931 )
932 }
933
934 pub fn search(&self, query_vector: &[f32], top_k: usize) -> Vec<SemanticResult> {
936 if self.entries.is_empty() || query_vector.len() != self.dimension {
937 return Vec::new();
938 }
939
940 let mut scored: Vec<(f32, usize)> = self
941 .entries
942 .iter()
943 .enumerate()
944 .map(|(i, entry)| (cosine_similarity(query_vector, &entry.vector), i))
945 .collect();
946
947 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
949
950 scored
951 .into_iter()
952 .take(top_k)
953 .filter(|(score, _)| *score > 0.0)
954 .map(|(score, idx)| {
955 let entry = &self.entries[idx];
956 SemanticResult {
957 file: entry.chunk.file.clone(),
958 name: entry.chunk.name.clone(),
959 kind: entry.chunk.kind.clone(),
960 start_line: entry.chunk.start_line,
961 end_line: entry.chunk.end_line,
962 exported: entry.chunk.exported,
963 snippet: entry.chunk.snippet.clone(),
964 score,
965 }
966 })
967 .collect()
968 }
969
    /// Returns the number of entries in the index (same value as `entry_count`).
    pub fn len(&self) -> usize {
        self.entries.len()
    }
974
975 pub fn is_file_stale(&self, file: &Path) -> bool {
977 match self.file_mtimes.get(file) {
978 None => true,
979 Some(stored_mtime) => match fs::metadata(file).and_then(|m| m.modified()) {
980 Ok(current_mtime) => *stored_mtime != current_mtime,
981 Err(_) => true,
982 },
983 }
984 }
985
986 pub fn count_stale_files(&self) -> usize {
987 self.file_mtimes
988 .keys()
989 .filter(|path| self.is_file_stale(path))
990 .count()
991 }
992
    /// Removes everything recorded for `file`; delegates to `invalidate_file`.
    pub fn remove_file(&mut self, file: &Path) {
        self.invalidate_file(file);
    }
997
998 pub fn invalidate_file(&mut self, file: &Path) {
999 self.entries.retain(|e| e.chunk.file != file);
1000 self.file_mtimes.remove(file);
1001 }
1002
    /// Returns the vector length of this index.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Returns the fingerprint attached to this index, if any.
    pub fn fingerprint(&self) -> Option<&SemanticIndexFingerprint> {
        self.fingerprint.as_ref()
    }

    /// Returns the backend name recorded in the fingerprint, if any.
    pub fn backend_label(&self) -> Option<&str> {
        self.fingerprint.as_ref().map(|f| f.backend.as_str())
    }

    /// Returns the model name recorded in the fingerprint, if any.
    pub fn model_label(&self) -> Option<&str> {
        self.fingerprint.as_ref().map(|f| f.model.as_str())
    }

    /// Attaches `fingerprint` to this index, replacing any previous one.
    pub fn set_fingerprint(&mut self, fingerprint: SemanticIndexFingerprint) {
        self.fingerprint = Some(fingerprint);
    }
1023
1024 pub fn write_to_disk(&self, storage_dir: &Path, project_key: &str) {
1026 if self.entries.is_empty() {
1029 slog_info!("skipping semantic index persistence (0 entries)");
1030 return;
1031 }
1032 let dir = storage_dir.join("semantic").join(project_key);
1033 if let Err(e) = fs::create_dir_all(&dir) {
1034 slog_warn!("failed to create semantic cache dir: {}", e);
1035 return;
1036 }
1037 let data_path = dir.join("semantic.bin");
1038 let tmp_path = dir.join("semantic.bin.tmp");
1039 let bytes = self.to_bytes();
1040 if let Err(e) = fs::write(&tmp_path, &bytes) {
1041 slog_warn!("failed to write semantic index: {}", e);
1042 let _ = fs::remove_file(&tmp_path);
1043 return;
1044 }
1045 if let Err(e) = fs::rename(&tmp_path, &data_path) {
1046 slog_warn!("failed to rename semantic index: {}", e);
1047 let _ = fs::remove_file(&tmp_path);
1048 return;
1049 }
1050 slog_info!(
1051 "semantic index persisted: {} entries, {:.1} KB",
1052 self.entries.len(),
1053 bytes.len() as f64 / 1024.0
1054 );
1055 }
1056
1057 pub fn read_from_disk(
1059 storage_dir: &Path,
1060 project_key: &str,
1061 expected_fingerprint: Option<&str>,
1062 ) -> Option<Self> {
1063 let data_path = storage_dir
1064 .join("semantic")
1065 .join(project_key)
1066 .join("semantic.bin");
1067 let file_len = usize::try_from(fs::metadata(&data_path).ok()?.len()).ok()?;
1068 if file_len < HEADER_BYTES_V1 {
1069 slog_warn!(
1070 "corrupt semantic index (too small: {} bytes), removing",
1071 file_len
1072 );
1073 let _ = fs::remove_file(&data_path);
1074 return None;
1075 }
1076
1077 let bytes = fs::read(&data_path).ok()?;
1078 let version = bytes[0];
1079 if version != SEMANTIC_INDEX_VERSION_V4 {
1080 slog_info!(
1081 "cached semantic index version {} is older than {}, rebuilding",
1082 version,
1083 SEMANTIC_INDEX_VERSION_V4
1084 );
1085 let _ = fs::remove_file(&data_path);
1086 return None;
1087 }
1088 match Self::from_bytes(&bytes) {
1089 Ok(index) => {
1090 if index.entries.is_empty() {
1091 slog_info!("cached semantic index is empty, will rebuild");
1092 let _ = fs::remove_file(&data_path);
1093 return None;
1094 }
1095 if let Some(expected) = expected_fingerprint {
1096 let matches = index
1097 .fingerprint()
1098 .map(|fingerprint| fingerprint.matches_expected(expected))
1099 .unwrap_or(false);
1100 if !matches {
1101 slog_info!("cached semantic index fingerprint mismatch, rebuilding");
1102 let _ = fs::remove_file(&data_path);
1103 return None;
1104 }
1105 }
1106 slog_info!(
1107 "loaded semantic index from disk: {} entries",
1108 index.entries.len()
1109 );
1110 Some(index)
1111 }
1112 Err(e) => {
1113 slog_warn!("corrupt semantic index, rebuilding: {}", e);
1114 let _ = fs::remove_file(&data_path);
1115 None
1116 }
1117 }
1118 }
1119
1120 pub fn to_bytes(&self) -> Vec<u8> {
1122 let mut buf = Vec::new();
1123 let fingerprint_bytes = self.fingerprint.as_ref().and_then(|fingerprint| {
1124 let encoded = fingerprint.as_string();
1125 if encoded.is_empty() {
1126 None
1127 } else {
1128 Some(encoded.into_bytes())
1129 }
1130 });
1131
1132 let version = SEMANTIC_INDEX_VERSION_V4;
1145 buf.push(version);
1146 buf.extend_from_slice(&(self.dimension as u32).to_le_bytes());
1147 buf.extend_from_slice(&(self.entries.len() as u32).to_le_bytes());
1148 let fp_bytes_ref: &[u8] = fingerprint_bytes.as_deref().unwrap_or(&[]);
1149 buf.extend_from_slice(&(fp_bytes_ref.len() as u32).to_le_bytes());
1150 buf.extend_from_slice(fp_bytes_ref);
1151
1152 buf.extend_from_slice(&(self.file_mtimes.len() as u32).to_le_bytes());
1155 for (path, mtime) in &self.file_mtimes {
1156 let path_bytes = path.to_string_lossy().as_bytes().to_vec();
1157 buf.extend_from_slice(&(path_bytes.len() as u32).to_le_bytes());
1158 buf.extend_from_slice(&path_bytes);
1159 let duration = mtime
1160 .duration_since(SystemTime::UNIX_EPOCH)
1161 .unwrap_or_default();
1162 buf.extend_from_slice(&duration.as_secs().to_le_bytes());
1163 buf.extend_from_slice(&duration.subsec_nanos().to_le_bytes());
1164 }
1165
1166 for entry in &self.entries {
1168 let c = &entry.chunk;
1169
1170 let file_bytes = c.file.to_string_lossy().as_bytes().to_vec();
1172 buf.extend_from_slice(&(file_bytes.len() as u32).to_le_bytes());
1173 buf.extend_from_slice(&file_bytes);
1174
1175 let name_bytes = c.name.as_bytes();
1177 buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
1178 buf.extend_from_slice(name_bytes);
1179
1180 buf.push(symbol_kind_to_u8(&c.kind));
1182
1183 buf.extend_from_slice(&(c.start_line as u32).to_le_bytes());
1185 buf.extend_from_slice(&(c.end_line as u32).to_le_bytes());
1186 buf.push(c.exported as u8);
1187
1188 let snippet_bytes = c.snippet.as_bytes();
1190 buf.extend_from_slice(&(snippet_bytes.len() as u32).to_le_bytes());
1191 buf.extend_from_slice(snippet_bytes);
1192
1193 let embed_bytes = c.embed_text.as_bytes();
1195 buf.extend_from_slice(&(embed_bytes.len() as u32).to_le_bytes());
1196 buf.extend_from_slice(embed_bytes);
1197
1198 for &val in &entry.vector {
1200 buf.extend_from_slice(&val.to_le_bytes());
1201 }
1202 }
1203
1204 buf
1205 }
1206
1207 pub fn from_bytes(data: &[u8]) -> Result<Self, String> {
1209 let mut pos = 0;
1210
1211 if data.len() < HEADER_BYTES_V1 {
1212 return Err("data too short".to_string());
1213 }
1214
1215 let version = data[pos];
1216 pos += 1;
1217 if version != SEMANTIC_INDEX_VERSION_V1
1218 && version != SEMANTIC_INDEX_VERSION_V2
1219 && version != SEMANTIC_INDEX_VERSION_V3
1220 && version != SEMANTIC_INDEX_VERSION_V4
1221 {
1222 return Err(format!("unsupported version: {}", version));
1223 }
1224 if (version == SEMANTIC_INDEX_VERSION_V2
1228 || version == SEMANTIC_INDEX_VERSION_V3
1229 || version == SEMANTIC_INDEX_VERSION_V4)
1230 && data.len() < HEADER_BYTES_V2
1231 {
1232 return Err("data too short for semantic index v2/v3/v4 header".to_string());
1233 }
1234
1235 let dimension = read_u32(data, &mut pos)? as usize;
1236 let entry_count = read_u32(data, &mut pos)? as usize;
1237 if dimension == 0 || dimension > MAX_DIMENSION {
1238 return Err(format!("invalid embedding dimension: {}", dimension));
1239 }
1240 if entry_count > MAX_ENTRIES {
1241 return Err(format!("too many semantic index entries: {}", entry_count));
1242 }
1243
1244 let has_fingerprint_field = version == SEMANTIC_INDEX_VERSION_V2
1250 || version == SEMANTIC_INDEX_VERSION_V3
1251 || version == SEMANTIC_INDEX_VERSION_V4;
1252 let fingerprint = if has_fingerprint_field {
1253 let fingerprint_len = read_u32(data, &mut pos)? as usize;
1254 if pos + fingerprint_len > data.len() {
1255 return Err("unexpected end of data reading fingerprint".to_string());
1256 }
1257 if fingerprint_len == 0 {
1258 None
1259 } else {
1260 let raw = String::from_utf8_lossy(&data[pos..pos + fingerprint_len]).to_string();
1261 pos += fingerprint_len;
1262 Some(
1263 serde_json::from_str::<SemanticIndexFingerprint>(&raw)
1264 .map_err(|error| format!("invalid semantic fingerprint: {error}"))?,
1265 )
1266 }
1267 } else {
1268 None
1269 };
1270
1271 let mtime_count = read_u32(data, &mut pos)? as usize;
1273 if mtime_count > MAX_ENTRIES {
1274 return Err(format!("too many semantic file mtimes: {}", mtime_count));
1275 }
1276
1277 let vector_bytes = entry_count
1278 .checked_mul(dimension)
1279 .and_then(|count| count.checked_mul(F32_BYTES))
1280 .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
1281 if vector_bytes > data.len().saturating_sub(pos) {
1282 return Err("semantic index vectors exceed available data".to_string());
1283 }
1284
1285 let mut file_mtimes = HashMap::with_capacity(mtime_count);
1286 for _ in 0..mtime_count {
1287 let path = read_string(data, &mut pos)?;
1288 let secs = read_u64(data, &mut pos)?;
1289 let nanos =
1295 if version == SEMANTIC_INDEX_VERSION_V3 || version == SEMANTIC_INDEX_VERSION_V4 {
1296 read_u32(data, &mut pos)?
1297 } else {
1298 0
1299 };
1300 if nanos >= 1_000_000_000 {
1307 return Err(format!(
1308 "invalid semantic mtime: nanos {} >= 1_000_000_000",
1309 nanos
1310 ));
1311 }
1312 let duration = std::time::Duration::new(secs, nanos);
1313 let mtime = SystemTime::UNIX_EPOCH
1314 .checked_add(duration)
1315 .ok_or_else(|| {
1316 format!(
1317 "invalid semantic mtime: secs={} nanos={} overflows SystemTime",
1318 secs, nanos
1319 )
1320 })?;
1321 file_mtimes.insert(PathBuf::from(path), mtime);
1322 }
1323
1324 let mut entries = Vec::with_capacity(entry_count);
1326 for _ in 0..entry_count {
1327 let file = PathBuf::from(read_string(data, &mut pos)?);
1328 let name = read_string(data, &mut pos)?;
1329
1330 if pos >= data.len() {
1331 return Err("unexpected end of data".to_string());
1332 }
1333 let kind = u8_to_symbol_kind(data[pos]);
1334 pos += 1;
1335
1336 let start_line = read_u32(data, &mut pos)?;
1337 let end_line = read_u32(data, &mut pos)?;
1338
1339 if pos >= data.len() {
1340 return Err("unexpected end of data".to_string());
1341 }
1342 let exported = data[pos] != 0;
1343 pos += 1;
1344
1345 let snippet = read_string(data, &mut pos)?;
1346 let embed_text = read_string(data, &mut pos)?;
1347
1348 let vec_bytes = dimension
1350 .checked_mul(F32_BYTES)
1351 .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
1352 if pos + vec_bytes > data.len() {
1353 return Err("unexpected end of data reading vector".to_string());
1354 }
1355 let mut vector = Vec::with_capacity(dimension);
1356 for _ in 0..dimension {
1357 let bytes = [data[pos], data[pos + 1], data[pos + 2], data[pos + 3]];
1358 vector.push(f32::from_le_bytes(bytes));
1359 pos += 4;
1360 }
1361
1362 entries.push(EmbeddingEntry {
1363 chunk: SemanticChunk {
1364 file,
1365 name,
1366 kind,
1367 start_line,
1368 end_line,
1369 exported,
1370 embed_text,
1371 snippet,
1372 },
1373 vector,
1374 });
1375 }
1376
1377 Ok(Self {
1378 entries,
1379 file_mtimes,
1380 dimension,
1381 fingerprint,
1382 })
1383 }
1384}
1385
1386fn build_embed_text(symbol: &Symbol, source: &str, file: &Path, project_root: &Path) -> String {
1388 let relative = file
1389 .strip_prefix(project_root)
1390 .unwrap_or(file)
1391 .to_string_lossy();
1392
1393 let kind_label = match &symbol.kind {
1394 SymbolKind::Function => "function",
1395 SymbolKind::Class => "class",
1396 SymbolKind::Method => "method",
1397 SymbolKind::Struct => "struct",
1398 SymbolKind::Interface => "interface",
1399 SymbolKind::Enum => "enum",
1400 SymbolKind::TypeAlias => "type",
1401 SymbolKind::Variable => "variable",
1402 SymbolKind::Heading => "heading",
1403 };
1404
1405 let mut text = format!("file:{} kind:{} name:{}", relative, kind_label, symbol.name);
1407
1408 if let Some(sig) = &symbol.signature {
1409 text.push_str(&format!(" signature:{}", sig));
1410 }
1411
1412 let lines: Vec<&str> = source.lines().collect();
1414 let start = (symbol.range.start_line as usize).min(lines.len());
1415 let end = (symbol.range.end_line as usize + 1).min(lines.len());
1417 if start < end {
1418 let body: String = lines[start..end]
1419 .iter()
1420 .take(15) .copied()
1422 .collect::<Vec<&str>>()
1423 .join("\n");
1424 let snippet = if body.len() > 300 {
1425 format!("{}...", &body[..body.floor_char_boundary(300)])
1426 } else {
1427 body
1428 };
1429 text.push_str(&format!(" body:{}", snippet));
1430 }
1431
1432 text
1433}
1434
1435fn parser_for(
1436 parsers: &mut HashMap<crate::parser::LangId, Parser>,
1437 lang: crate::parser::LangId,
1438) -> Result<&mut Parser, String> {
1439 use std::collections::hash_map::Entry;
1440
1441 match parsers.entry(lang) {
1442 Entry::Occupied(entry) => Ok(entry.into_mut()),
1443 Entry::Vacant(entry) => {
1444 let grammar = grammar_for(lang);
1445 let mut parser = Parser::new();
1446 parser
1447 .set_language(&grammar)
1448 .map_err(|error| error.to_string())?;
1449 Ok(entry.insert(parser))
1450 }
1451 }
1452}
1453
1454fn collect_file_chunks(
1455 project_root: &Path,
1456 file: &Path,
1457 parsers: &mut HashMap<crate::parser::LangId, Parser>,
1458) -> Result<Vec<SemanticChunk>, String> {
1459 let lang = detect_language(file).ok_or_else(|| "unsupported file extension".to_string())?;
1460 let source = std::fs::read_to_string(file).map_err(|error| error.to_string())?;
1461 let tree = parser_for(parsers, lang)?
1462 .parse(&source, None)
1463 .ok_or_else(|| format!("tree-sitter parse returned None for {}", file.display()))?;
1464 let symbols =
1465 extract_symbols_from_tree(&source, &tree, lang).map_err(|error| error.to_string())?;
1466
1467 Ok(symbols_to_chunks(file, &symbols, &source, project_root))
1468}
1469
1470fn build_snippet(symbol: &Symbol, source: &str) -> String {
1472 let lines: Vec<&str> = source.lines().collect();
1473 let start = (symbol.range.start_line as usize).min(lines.len());
1474 let end = (symbol.range.end_line as usize + 1).min(lines.len());
1476 if start < end {
1477 let snippet_lines: Vec<&str> = lines[start..end].iter().take(5).copied().collect();
1478 let mut snippet = snippet_lines.join("\n");
1479 if end - start > 5 {
1480 snippet.push_str("\n ...");
1481 }
1482 if snippet.len() > 300 {
1483 snippet = format!("{}...", &snippet[..snippet.floor_char_boundary(300)]);
1484 }
1485 snippet
1486 } else {
1487 String::new()
1488 }
1489}
1490
1491fn symbols_to_chunks(
1493 file: &Path,
1494 symbols: &[Symbol],
1495 source: &str,
1496 project_root: &Path,
1497) -> Vec<SemanticChunk> {
1498 let mut chunks = Vec::new();
1499
1500 for symbol in symbols {
1501 if matches!(symbol.kind, SymbolKind::Heading) {
1506 continue;
1507 }
1508
1509 let line_count = symbol
1511 .range
1512 .end_line
1513 .saturating_sub(symbol.range.start_line)
1514 + 1;
1515 if line_count < 2 && !matches!(symbol.kind, SymbolKind::Variable) {
1516 continue;
1517 }
1518
1519 let embed_text = build_embed_text(symbol, source, file, project_root);
1520 let snippet = build_snippet(symbol, source);
1521
1522 chunks.push(SemanticChunk {
1523 file: file.to_path_buf(),
1524 name: symbol.name.clone(),
1525 kind: symbol.kind.clone(),
1526 start_line: symbol.range.start_line,
1527 end_line: symbol.range.end_line,
1528 exported: symbol.exported,
1529 embed_text,
1530 snippet,
1531 });
1532
1533 }
1536
1537 chunks
1538}
1539
/// Computes the cosine similarity of two vectors.
///
/// Returns `0.0` when the slices differ in length or when the product of their
/// norms is zero (which includes empty input).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Single pass accumulating the dot product and both squared norms.
    let (dot, sq_norm_a, sq_norm_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let denom = sq_norm_a.sqrt() * sq_norm_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }
    dot / denom
}
1563
1564fn symbol_kind_to_u8(kind: &SymbolKind) -> u8 {
1566 match kind {
1567 SymbolKind::Function => 0,
1568 SymbolKind::Class => 1,
1569 SymbolKind::Method => 2,
1570 SymbolKind::Struct => 3,
1571 SymbolKind::Interface => 4,
1572 SymbolKind::Enum => 5,
1573 SymbolKind::TypeAlias => 6,
1574 SymbolKind::Variable => 7,
1575 SymbolKind::Heading => 8,
1576 }
1577}
1578
1579fn u8_to_symbol_kind(v: u8) -> SymbolKind {
1580 match v {
1581 0 => SymbolKind::Function,
1582 1 => SymbolKind::Class,
1583 2 => SymbolKind::Method,
1584 3 => SymbolKind::Struct,
1585 4 => SymbolKind::Interface,
1586 5 => SymbolKind::Enum,
1587 6 => SymbolKind::TypeAlias,
1588 7 => SymbolKind::Variable,
1589 _ => SymbolKind::Heading,
1590 }
1591}
1592
/// Reads a little-endian `u32` from `data` at `*pos` and advances `*pos` by 4.
///
/// # Errors
/// Returns an error, leaving `*pos` untouched, when fewer than 4 bytes remain.
/// A `*pos` so large that `*pos + 4` would overflow is treated the same way.
fn read_u32(data: &[u8], pos: &mut usize) -> Result<u32, String> {
    // `checked_add` keeps a bogus cursor from overflowing the bounds check.
    let end = (*pos)
        .checked_add(4)
        .filter(|&end| end <= data.len())
        .ok_or_else(|| "unexpected end of data reading u32".to_string())?;

    let mut raw = [0u8; 4];
    raw.copy_from_slice(&data[*pos..end]);
    *pos = end;
    Ok(u32::from_le_bytes(raw))
}
1601
/// Reads a little-endian `u64` from `data` at `*pos` and advances `*pos` by 8.
///
/// # Errors
/// Returns an error, leaving `*pos` untouched, when fewer than 8 bytes remain.
/// A `*pos` so large that `*pos + 8` would overflow is treated the same way.
fn read_u64(data: &[u8], pos: &mut usize) -> Result<u64, String> {
    // `checked_add` keeps a bogus cursor from overflowing the bounds check.
    let end = (*pos)
        .checked_add(8)
        .filter(|&end| end <= data.len())
        .ok_or_else(|| "unexpected end of data reading u64".to_string())?;

    let mut raw = [0u8; 8];
    raw.copy_from_slice(&data[*pos..end]);
    *pos = end;
    Ok(u64::from_le_bytes(raw))
}
1610
1611fn read_string(data: &[u8], pos: &mut usize) -> Result<String, String> {
1612 let len = read_u32(data, pos)? as usize;
1613 if *pos + len > data.len() {
1614 return Err("unexpected end of data reading string".to_string());
1615 }
1616 let s = String::from_utf8_lossy(&data[*pos..*pos + len]).to_string();
1617 *pos += len;
1618 Ok(s)
1619}
1620
1621#[cfg(test)]
1622mod tests {
1623 use super::*;
1624 use crate::config::{SemanticBackend, SemanticBackendConfig};
1625 use crate::parser::FileParser;
1626 use std::io::{Read, Write};
1627 use std::net::TcpListener;
1628 use std::thread;
1629
1630 fn start_mock_http_server<F>(handler: F) -> (String, thread::JoinHandle<()>)
1631 where
1632 F: Fn(String, String, String) -> String + Send + 'static,
1633 {
1634 let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
1635 let addr = listener.local_addr().expect("local addr");
1636 let handle = thread::spawn(move || {
1637 let (mut stream, _) = listener.accept().expect("accept request");
1638 let mut buf = Vec::new();
1639 let mut chunk = [0u8; 4096];
1640 let mut header_end = None;
1641 let mut content_length = 0usize;
1642 loop {
1643 let n = stream.read(&mut chunk).expect("read request");
1644 if n == 0 {
1645 break;
1646 }
1647 buf.extend_from_slice(&chunk[..n]);
1648 if header_end.is_none() {
1649 if let Some(pos) = buf.windows(4).position(|window| window == b"\r\n\r\n") {
1650 header_end = Some(pos + 4);
1651 let headers = String::from_utf8_lossy(&buf[..pos + 4]);
1652 for line in headers.lines() {
1653 if let Some(value) = line.strip_prefix("Content-Length:") {
1654 content_length = value.trim().parse::<usize>().unwrap_or(0);
1655 }
1656 }
1657 }
1658 }
1659 if let Some(end) = header_end {
1660 if buf.len() >= end + content_length {
1661 break;
1662 }
1663 }
1664 }
1665
1666 let end = header_end.expect("header terminator");
1667 let request = String::from_utf8_lossy(&buf[..end]).to_string();
1668 let body = String::from_utf8_lossy(&buf[end..end + content_length]).to_string();
1669 let mut lines = request.lines();
1670 let request_line = lines.next().expect("request line").to_string();
1671 let path = request_line
1672 .split_whitespace()
1673 .nth(1)
1674 .expect("request path")
1675 .to_string();
1676 let response_body = handler(request_line, path, body);
1677 let response = format!(
1678 "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
1679 response_body.len(),
1680 response_body
1681 );
1682 stream
1683 .write_all(response.as_bytes())
1684 .expect("write response");
1685 });
1686
1687 (format!("http://{}", addr), handle)
1688 }
1689
1690 #[test]
1691 fn test_cosine_similarity_identical() {
1692 let a = vec![1.0, 0.0, 0.0];
1693 let b = vec![1.0, 0.0, 0.0];
1694 assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
1695 }
1696
1697 #[test]
1698 fn test_cosine_similarity_orthogonal() {
1699 let a = vec![1.0, 0.0, 0.0];
1700 let b = vec![0.0, 1.0, 0.0];
1701 assert!(cosine_similarity(&a, &b).abs() < 0.001);
1702 }
1703
1704 #[test]
1705 fn test_cosine_similarity_opposite() {
1706 let a = vec![1.0, 0.0, 0.0];
1707 let b = vec![-1.0, 0.0, 0.0];
1708 assert!((cosine_similarity(&a, &b) + 1.0).abs() < 0.001);
1709 }
1710
1711 #[test]
1712 fn test_serialization_roundtrip() {
1713 let mut index = SemanticIndex::new();
1714 index.entries.push(EmbeddingEntry {
1715 chunk: SemanticChunk {
1716 file: PathBuf::from("/src/main.rs"),
1717 name: "handle_request".to_string(),
1718 kind: SymbolKind::Function,
1719 start_line: 10,
1720 end_line: 25,
1721 exported: true,
1722 embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
1723 snippet: "fn handle_request() {\n // ...\n}".to_string(),
1724 },
1725 vector: vec![0.1, 0.2, 0.3, 0.4],
1726 });
1727 index.dimension = 4;
1728 index
1729 .file_mtimes
1730 .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
1731 index.set_fingerprint(SemanticIndexFingerprint {
1732 backend: "fastembed".to_string(),
1733 model: "all-MiniLM-L6-v2".to_string(),
1734 base_url: FALLBACK_BACKEND.to_string(),
1735 dimension: 4,
1736 });
1737
1738 let bytes = index.to_bytes();
1739 let restored = SemanticIndex::from_bytes(&bytes).unwrap();
1740
1741 assert_eq!(restored.entries.len(), 1);
1742 assert_eq!(restored.entries[0].chunk.name, "handle_request");
1743 assert_eq!(restored.entries[0].vector, vec![0.1, 0.2, 0.3, 0.4]);
1744 assert_eq!(restored.dimension, 4);
1745 assert_eq!(restored.backend_label(), Some("fastembed"));
1746 assert_eq!(restored.model_label(), Some("all-MiniLM-L6-v2"));
1747 }
1748
1749 #[test]
1750 fn test_search_top_k() {
1751 let mut index = SemanticIndex::new();
1752 index.dimension = 3;
1753
1754 for (i, name) in ["auth", "database", "handler"].iter().enumerate() {
1756 let mut vec = vec![0.0f32; 3];
1757 vec[i] = 1.0; index.entries.push(EmbeddingEntry {
1759 chunk: SemanticChunk {
1760 file: PathBuf::from("/src/lib.rs"),
1761 name: name.to_string(),
1762 kind: SymbolKind::Function,
1763 start_line: (i * 10 + 1) as u32,
1764 end_line: (i * 10 + 5) as u32,
1765 exported: true,
1766 embed_text: format!("kind:function name:{}", name),
1767 snippet: format!("fn {}() {{}}", name),
1768 },
1769 vector: vec,
1770 });
1771 }
1772
1773 let query = vec![0.9, 0.1, 0.0];
1775 let results = index.search(&query, 2);
1776
1777 assert_eq!(results.len(), 2);
1778 assert_eq!(results[0].name, "auth"); assert!(results[0].score > results[1].score);
1780 }
1781
1782 #[test]
1783 fn test_empty_index_search() {
1784 let index = SemanticIndex::new();
1785 let results = index.search(&[0.1, 0.2, 0.3], 10);
1786 assert!(results.is_empty());
1787 }
1788
1789 #[test]
1790 fn single_line_symbol_builds_non_empty_snippet() {
1791 let symbol = Symbol {
1792 name: "answer".to_string(),
1793 kind: SymbolKind::Variable,
1794 range: crate::symbols::Range {
1795 start_line: 0,
1796 start_col: 0,
1797 end_line: 0,
1798 end_col: 24,
1799 },
1800 signature: Some("const answer = 42".to_string()),
1801 scope_chain: Vec::new(),
1802 exported: true,
1803 parent: None,
1804 };
1805 let source = "export const answer = 42;\n";
1806
1807 let snippet = build_snippet(&symbol, source);
1808
1809 assert_eq!(snippet, "export const answer = 42;");
1810 }
1811
1812 #[test]
1813 fn optimized_file_chunk_collection_matches_file_parser_path() {
1814 let project_root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
1815 let file = project_root.join("src/semantic_index.rs");
1816 let source = std::fs::read_to_string(&file).unwrap();
1817
1818 let mut legacy_parser = FileParser::new();
1819 let legacy_symbols = legacy_parser.extract_symbols(&file).unwrap();
1820 let legacy_chunks = symbols_to_chunks(&file, &legacy_symbols, &source, &project_root);
1821
1822 let mut parsers = HashMap::new();
1823 let optimized_chunks = collect_file_chunks(&project_root, &file, &mut parsers).unwrap();
1824
1825 assert_eq!(
1826 chunk_fingerprint(&optimized_chunks),
1827 chunk_fingerprint(&legacy_chunks)
1828 );
1829 }
1830
1831 fn chunk_fingerprint(
1832 chunks: &[SemanticChunk],
1833 ) -> Vec<(String, SymbolKind, u32, u32, bool, String, String)> {
1834 chunks
1835 .iter()
1836 .map(|chunk| {
1837 (
1838 chunk.name.clone(),
1839 chunk.kind.clone(),
1840 chunk.start_line,
1841 chunk.end_line,
1842 chunk.exported,
1843 chunk.embed_text.clone(),
1844 chunk.snippet.clone(),
1845 )
1846 })
1847 .collect()
1848 }
1849
1850 #[test]
1851 fn rejects_oversized_dimension_during_deserialization() {
1852 let mut bytes = Vec::new();
1853 bytes.push(1u8);
1854 bytes.extend_from_slice(&((MAX_DIMENSION as u32) + 1).to_le_bytes());
1855 bytes.extend_from_slice(&0u32.to_le_bytes());
1856 bytes.extend_from_slice(&0u32.to_le_bytes());
1857
1858 assert!(SemanticIndex::from_bytes(&bytes).is_err());
1859 }
1860
1861 #[test]
1862 fn rejects_oversized_entry_count_during_deserialization() {
1863 let mut bytes = Vec::new();
1864 bytes.push(1u8);
1865 bytes.extend_from_slice(&(DEFAULT_DIMENSION as u32).to_le_bytes());
1866 bytes.extend_from_slice(&((MAX_ENTRIES as u32) + 1).to_le_bytes());
1867 bytes.extend_from_slice(&0u32.to_le_bytes());
1868
1869 assert!(SemanticIndex::from_bytes(&bytes).is_err());
1870 }
1871
1872 #[test]
1873 fn invalidate_file_removes_entries_and_mtime() {
1874 let target = PathBuf::from("/src/main.rs");
1875 let mut index = SemanticIndex::new();
1876 index.entries.push(EmbeddingEntry {
1877 chunk: SemanticChunk {
1878 file: target.clone(),
1879 name: "main".to_string(),
1880 kind: SymbolKind::Function,
1881 start_line: 0,
1882 end_line: 1,
1883 exported: false,
1884 embed_text: "main".to_string(),
1885 snippet: "fn main() {}".to_string(),
1886 },
1887 vector: vec![1.0; DEFAULT_DIMENSION],
1888 });
1889 index
1890 .file_mtimes
1891 .insert(target.clone(), SystemTime::UNIX_EPOCH);
1892
1893 index.invalidate_file(&target);
1894
1895 assert!(index.entries.is_empty());
1896 assert!(!index.file_mtimes.contains_key(&target));
1897 }
1898
1899 #[test]
1900 fn detects_missing_onnx_runtime_from_dynamic_load_error() {
1901 let message = "Failed to load ONNX Runtime shared library libonnxruntime.dylib via dlopen: no such file";
1902
1903 assert!(is_onnx_runtime_unavailable(message));
1904 }
1905
1906 #[test]
1907 fn formats_missing_onnx_runtime_with_install_hint() {
1908 let message = format_embedding_init_error(
1909 "Failed to load ONNX Runtime shared library libonnxruntime.so via dlopen: no such file",
1910 );
1911
1912 assert!(message.starts_with("ONNX Runtime not found. Install via:"));
1913 assert!(message.contains("Original error:"));
1914 }
1915
1916 #[test]
1917 fn openai_compatible_backend_embeds_with_mock_server() {
1918 let (base_url, handle) = start_mock_http_server(|request_line, path, _body| {
1919 assert!(request_line.starts_with("POST "));
1920 assert_eq!(path, "/v1/embeddings");
1921 "{\"data\":[{\"embedding\":[0.1,0.2,0.3],\"index\":0},{\"embedding\":[0.4,0.5,0.6],\"index\":1}]}".to_string()
1922 });
1923
1924 let config = SemanticBackendConfig {
1925 backend: SemanticBackend::OpenAiCompatible,
1926 model: "test-embedding".to_string(),
1927 base_url: Some(base_url),
1928 api_key_env: None,
1929 timeout_ms: 5_000,
1930 max_batch_size: 64,
1931 };
1932
1933 let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();
1934 let vectors = model
1935 .embed(vec!["hello".to_string(), "world".to_string()])
1936 .unwrap();
1937
1938 assert_eq!(vectors, vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]]);
1939 handle.join().unwrap();
1940 }
1941
1942 #[test]
1943 fn ollama_backend_embeds_with_mock_server() {
1944 let (base_url, handle) = start_mock_http_server(|request_line, path, _body| {
1945 assert!(request_line.starts_with("POST "));
1946 assert_eq!(path, "/api/embed");
1947 "{\"embeddings\":[[0.7,0.8,0.9],[1.0,1.1,1.2]]}".to_string()
1948 });
1949
1950 let config = SemanticBackendConfig {
1951 backend: SemanticBackend::Ollama,
1952 model: "embeddinggemma".to_string(),
1953 base_url: Some(base_url),
1954 api_key_env: None,
1955 timeout_ms: 5_000,
1956 max_batch_size: 64,
1957 };
1958
1959 let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();
1960 let vectors = model
1961 .embed(vec!["hello".to_string(), "world".to_string()])
1962 .unwrap();
1963
1964 assert_eq!(vectors, vec![vec![0.7, 0.8, 0.9], vec![1.0, 1.1, 1.2]]);
1965 handle.join().unwrap();
1966 }
1967
1968 #[test]
1969 fn read_from_disk_rejects_fingerprint_mismatch() {
1970 let storage = tempfile::tempdir().unwrap();
1971 let project_key = "proj";
1972
1973 let mut index = SemanticIndex::new();
1974 index.entries.push(EmbeddingEntry {
1975 chunk: SemanticChunk {
1976 file: PathBuf::from("/src/main.rs"),
1977 name: "handle_request".to_string(),
1978 kind: SymbolKind::Function,
1979 start_line: 10,
1980 end_line: 25,
1981 exported: true,
1982 embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
1983 snippet: "fn handle_request() {}".to_string(),
1984 },
1985 vector: vec![0.1, 0.2, 0.3],
1986 });
1987 index.dimension = 3;
1988 index
1989 .file_mtimes
1990 .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
1991 index.set_fingerprint(SemanticIndexFingerprint {
1992 backend: "openai_compatible".to_string(),
1993 model: "test-embedding".to_string(),
1994 base_url: "http://127.0.0.1:1234/v1".to_string(),
1995 dimension: 3,
1996 });
1997 index.write_to_disk(storage.path(), project_key);
1998
1999 let matching = index.fingerprint().unwrap().as_string();
2000 assert!(
2001 SemanticIndex::read_from_disk(storage.path(), project_key, Some(&matching)).is_some()
2002 );
2003
2004 let mismatched = SemanticIndexFingerprint {
2005 backend: "ollama".to_string(),
2006 model: "embeddinggemma".to_string(),
2007 base_url: "http://127.0.0.1:11434".to_string(),
2008 dimension: 3,
2009 }
2010 .as_string();
2011 assert!(
2012 SemanticIndex::read_from_disk(storage.path(), project_key, Some(&mismatched)).is_none()
2013 );
2014 }
2015
2016 #[test]
2017 fn read_from_disk_rejects_v3_cache_for_snippet_rebuild() {
2018 let storage = tempfile::tempdir().unwrap();
2019 let project_key = "proj-v3";
2020 let dir = storage.path().join("semantic").join(project_key);
2021 fs::create_dir_all(&dir).unwrap();
2022
2023 let mut index = SemanticIndex::new();
2024 index.entries.push(EmbeddingEntry {
2025 chunk: SemanticChunk {
2026 file: PathBuf::from("/src/main.rs"),
2027 name: "handle_request".to_string(),
2028 kind: SymbolKind::Function,
2029 start_line: 0,
2030 end_line: 0,
2031 exported: true,
2032 embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
2033 snippet: "fn handle_request() {}".to_string(),
2034 },
2035 vector: vec![0.1, 0.2, 0.3],
2036 });
2037 index.dimension = 3;
2038 index
2039 .file_mtimes
2040 .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
2041 let fingerprint = SemanticIndexFingerprint {
2042 backend: "fastembed".to_string(),
2043 model: "test".to_string(),
2044 base_url: FALLBACK_BACKEND.to_string(),
2045 dimension: 3,
2046 };
2047 index.set_fingerprint(fingerprint.clone());
2048
2049 let mut bytes = index.to_bytes();
2050 bytes[0] = SEMANTIC_INDEX_VERSION_V3;
2051 fs::write(dir.join("semantic.bin"), bytes).unwrap();
2052
2053 assert!(SemanticIndex::read_from_disk(
2054 storage.path(),
2055 project_key,
2056 Some(&fingerprint.as_string())
2057 )
2058 .is_none());
2059 assert!(!dir.join("semantic.bin").exists());
2060 }
2061
2062 fn make_symbol(kind: SymbolKind, name: &str, start: u32, end: u32) -> crate::symbols::Symbol {
2063 crate::symbols::Symbol {
2064 name: name.to_string(),
2065 kind,
2066 range: crate::symbols::Range {
2067 start_line: start,
2068 start_col: 0,
2069 end_line: end,
2070 end_col: 0,
2071 },
2072 signature: None,
2073 scope_chain: Vec::new(),
2074 exported: false,
2075 parent: None,
2076 }
2077 }
2078
2079 #[test]
2084 fn symbols_to_chunks_skips_heading_symbols() {
2085 let project_root = PathBuf::from("/proj");
2086 let file = project_root.join("README.md");
2087 let source = "# Title\n\nbody text\n\n## Section\n\nmore text\n";
2088
2089 let symbols = vec![
2090 make_symbol(SymbolKind::Heading, "Title", 0, 2),
2091 make_symbol(SymbolKind::Heading, "Section", 4, 6),
2092 ];
2093
2094 let chunks = symbols_to_chunks(&file, &symbols, source, &project_root);
2095 assert!(
2096 chunks.is_empty(),
2097 "Heading symbols must be filtered out before embedding; got {} chunk(s)",
2098 chunks.len()
2099 );
2100 }
2101
2102 #[test]
2106 fn symbols_to_chunks_keeps_code_symbols_alongside_skipped_headings() {
2107 let project_root = PathBuf::from("/proj");
2108 let file = project_root.join("src/lib.rs");
2109 let source = "pub fn handle_request() -> bool {\n true\n}\n";
2110
2111 let symbols = vec![
2112 make_symbol(SymbolKind::Heading, "doc heading", 0, 1),
2114 make_symbol(SymbolKind::Function, "handle_request", 0, 2),
2115 make_symbol(SymbolKind::Struct, "AuthService", 4, 6),
2116 ];
2117
2118 let chunks = symbols_to_chunks(&file, &symbols, source, &project_root);
2119 assert_eq!(
2120 chunks.len(),
2121 2,
2122 "Expected 2 code chunks (Function + Struct), got {}",
2123 chunks.len()
2124 );
2125 let names: Vec<&str> = chunks.iter().map(|c| c.name.as_str()).collect();
2126 assert!(names.contains(&"handle_request"));
2127 assert!(names.contains(&"AuthService"));
2128 assert!(
2129 !names.contains(&"doc heading"),
2130 "Heading symbol leaked into chunks: {names:?}"
2131 );
2132 }
2133}