1use crate::config::{SemanticBackend, SemanticBackendConfig};
2use crate::parser::{detect_language, extract_symbols_from_tree, grammar_for};
3use crate::symbols::{Symbol, SymbolKind};
4
5use fastembed::{EmbeddingModel as FastembedEmbeddingModel, InitOptions, TextEmbedding};
6use rayon::prelude::*;
7use reqwest::blocking::Client;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::env;
11use std::fmt::Display;
12use std::fs;
13use std::path::{Path, PathBuf};
14use std::time::Duration;
15use std::time::SystemTime;
16use tree_sitter::Parser;
17use url::Url;
18
// Embedding dimension assumed when there are no vectors to measure
// (empty index in `SemanticIndex::new` / `build_from_chunks`).
const DEFAULT_DIMENSION: usize = 384;
// Upper bounds enforced while deserializing a persisted index, rejecting
// corrupt or oversized files before allocating.
const MAX_ENTRIES: usize = 1_000_000;
const MAX_DIMENSION: usize = 1024;
// Bytes per serialized vector element (little-endian f32).
const F32_BYTES: usize = std::mem::size_of::<f32>();
// Minimum byte length of the v1 header and of the v2+ header (which adds
// the fingerprint length field).
const HEADER_BYTES_V1: usize = 9;
const HEADER_BYTES_V2: usize = 13;
const ONNX_RUNTIME_INSTALL_HINT: &str =
    "ONNX Runtime not found. Install via: brew install onnxruntime (macOS) or apt install libonnxruntime (Linux).";

// Known on-disk index format versions. Only V4 is ever written
// (`to_bytes`); `from_bytes` still recognizes V1-V4.
const SEMANTIC_INDEX_VERSION_V1: u8 = 1;
const SEMANTIC_INDEX_VERSION_V2: u8 = 2;
const SEMANTIC_INDEX_VERSION_V3: u8 = 3;
const SEMANTIC_INDEX_VERSION_V4: u8 = 4;
// Endpoint paths appended to user-supplied base URLs.
const DEFAULT_OPENAI_EMBEDDING_PATH: &str = "/embeddings";
const DEFAULT_OLLAMA_EMBEDDING_PATH: &str = "/api/embed";
// Defaults applied when the config leaves timeout/batch size at 0.
const DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS: u64 = 25_000;
const DEFAULT_MAX_BATCH_SIZE: usize = 64;
// Placeholder recorded in fingerprints when no usable base URL exists.
// NOTE(review): despite the name this is used as the base_url fallback.
const FALLBACK_BACKEND: &str = "none";
// HTTP retry policy: up to 3 attempts, sleeping 500ms then 1s in between.
const EMBEDDING_REQUEST_MAX_ATTEMPTS: usize = 3;
const EMBEDDING_REQUEST_BACKOFF_MS: [u64; 2] = [500, 1_000];
46
/// Identity of the embedding configuration an index was built with.
///
/// Serialized to JSON and stored inside the persisted index so a cached
/// index is rebuilt whenever backend, model, base URL, or embedding
/// dimension changes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticIndexFingerprint {
    pub backend: String,
    pub model: String,
    // Normalized base URL (or the fallback marker); `default` keeps older
    // serialized fingerprints without the field deserializable.
    #[serde(default)]
    pub base_url: String,
    pub dimension: usize,
}
55
56impl SemanticIndexFingerprint {
57 fn from_config(config: &SemanticBackendConfig, dimension: usize) -> Self {
58 let base_url = config
61 .base_url
62 .as_ref()
63 .and_then(|u| normalize_base_url(u).ok())
64 .unwrap_or_else(|| FALLBACK_BACKEND.to_string());
65 Self {
66 backend: config.backend.as_str().to_string(),
67 model: config.model.clone(),
68 base_url,
69 dimension,
70 }
71 }
72
73 pub fn as_string(&self) -> String {
74 serde_json::to_string(self).unwrap_or_else(|_| String::new())
75 }
76
77 fn matches_expected(&self, expected: &str) -> bool {
78 let encoded = self.as_string();
79 !encoded.is_empty() && encoded == expected
80 }
81}
82
/// Backend-specific embedding implementation held by `SemanticEmbeddingModel`.
enum SemanticEmbeddingEngine {
    /// Local in-process model via the fastembed crate.
    Fastembed(TextEmbedding),
    /// Remote OpenAI-compatible HTTP API (`…/v1/embeddings`).
    OpenAiCompatible {
        client: Client,
        model: String,
        base_url: String,
        // Bearer token; the Authorization header is omitted when None.
        api_key: Option<String>,
    },
    /// Remote Ollama HTTP API (`…/api/embed`).
    Ollama {
        client: Client,
        model: String,
        base_url: String,
    },
}
97
/// Unified embedding front-end over the configured backend engine.
///
/// Keeps the resolved configuration alongside the engine so callers can
/// inspect effective settings, and lazily caches the embedding dimension.
pub struct SemanticEmbeddingModel {
    backend: SemanticBackend,
    model: String,
    // Raw configured base URL, if any (HTTP backends require one).
    base_url: Option<String>,
    // Effective values after zero-means-default resolution.
    timeout_ms: u64,
    max_batch_size: usize,
    // Cached embedding dimension, learned from the first embedding call.
    dimension: Option<usize>,
    engine: SemanticEmbeddingEngine,
}

/// Public alias for `SemanticEmbeddingModel`.
pub type EmbeddingModel = SemanticEmbeddingModel;
109
/// Sanity-check a batch of embedding vectors returned by a backend.
///
/// Verifies that the vector count matches the number of inputs and that
/// every vector shares the dimension of the first one. Returns a
/// descriptive error (prefixed with `context`) on any mismatch.
fn validate_embedding_batch(
    vectors: &[Vec<f32>],
    expected_count: usize,
    context: &str,
) -> Result<(), String> {
    if vectors.is_empty() && expected_count > 0 {
        return Err(format!(
            "{context} returned no vectors for {expected_count} inputs"
        ));
    }
    if vectors.len() != expected_count {
        return Err(format!(
            "{context} returned {} vectors for {} inputs",
            vectors.len(),
            expected_count
        ));
    }

    // Use the first vector's length as the reference dimension.
    let expected_dimension = match vectors.first() {
        Some(first) => first.len(),
        None => return Ok(()),
    };
    if let Some((index, vector)) = vectors
        .iter()
        .enumerate()
        .find(|(_, vector)| vector.len() != expected_dimension)
    {
        return Err(format!(
            "{context} returned inconsistent embedding dimensions: vector 0 has length {expected_dimension}, vector {index} has length {}",
            vector.len()
        ));
    }

    Ok(())
}
144
145fn normalize_base_url(raw: &str) -> Result<String, String> {
146 let parsed = Url::parse(raw).map_err(|error| format!("invalid base_url '{raw}': {error}"))?;
147 let scheme = parsed.scheme();
148 if scheme != "http" && scheme != "https" {
149 return Err(format!(
150 "unsupported URL scheme '{}' — only http:// and https:// are allowed",
151 scheme
152 ));
153 }
154 Ok(parsed.to_string().trim_end_matches('/').to_string())
155}
156
157fn build_openai_embeddings_endpoint(base_url: &str) -> String {
158 if base_url.ends_with("/v1") {
159 format!("{base_url}{DEFAULT_OPENAI_EMBEDDING_PATH}")
160 } else {
161 format!("{base_url}/v1{}", DEFAULT_OPENAI_EMBEDDING_PATH)
162 }
163}
164
165fn build_ollama_embeddings_endpoint(base_url: &str) -> String {
166 if base_url.ends_with("/api") {
167 format!("{base_url}/embed")
168 } else {
169 format!("{base_url}{DEFAULT_OLLAMA_EMBEDDING_PATH}")
170 }
171}
172
/// Trim an optional API-key-env value, mapping blank or missing input to
/// `None` so downstream code can treat "unset" and "whitespace" the same.
fn normalize_api_key(value: Option<String>) -> Option<String> {
    let trimmed = value.as_deref().map(str::trim).unwrap_or("");
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_string())
    }
}
183
184fn is_retryable_embedding_status(status: reqwest::StatusCode) -> bool {
185 status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS
186}
187
/// Whether a transport-level reqwest error is worth retrying. Only
/// connection-establishment failures are retried; all other transport
/// errors are surfaced immediately.
fn is_retryable_embedding_error(error: &reqwest::Error) -> bool {
    error.is_connect()
}
191
192fn sleep_before_embedding_retry(attempt_index: usize) {
193 if let Some(delay_ms) = EMBEDDING_REQUEST_BACKOFF_MS.get(attempt_index) {
194 std::thread::sleep(Duration::from_millis(*delay_ms));
195 }
196}
197
/// POST an embedding request with bounded retries.
///
/// `make_request` must build a fresh `RequestBuilder` on every call (a
/// builder cannot be reused across attempts). Up to
/// `EMBEDDING_REQUEST_MAX_ATTEMPTS` attempts are made, sleeping per
/// `EMBEDDING_REQUEST_BACKOFF_MS` between them. Retried conditions:
/// connection errors, HTTP 5xx, and HTTP 429. Returns the raw response body
/// on success, or a `backend_label`-prefixed error message.
fn send_embedding_request<F>(mut make_request: F, backend_label: &str) -> Result<String, String>
where
    F: FnMut() -> reqwest::blocking::RequestBuilder,
{
    for attempt_index in 0..EMBEDDING_REQUEST_MAX_ATTEMPTS {
        let last_attempt = attempt_index + 1 == EMBEDDING_REQUEST_MAX_ATTEMPTS;

        let response = match make_request().send() {
            Ok(response) => response,
            Err(error) => {
                if !last_attempt && is_retryable_embedding_error(&error) {
                    sleep_before_embedding_retry(attempt_index);
                    continue;
                }
                return Err(format!("{backend_label} request failed: {error}"));
            }
        };

        // Read the body before checking the status so non-success responses
        // can include the server's message in the error.
        let status = response.status();
        let raw = match response.text() {
            Ok(raw) => raw,
            Err(error) => {
                if !last_attempt && is_retryable_embedding_error(&error) {
                    sleep_before_embedding_retry(attempt_index);
                    continue;
                }
                return Err(format!("{backend_label} response read failed: {error}"));
            }
        };

        if status.is_success() {
            return Ok(raw);
        }

        if !last_attempt && is_retryable_embedding_status(status) {
            sleep_before_embedding_retry(attempt_index);
            continue;
        }

        return Err(format!(
            "{backend_label} request failed (HTTP {}): {}",
            status, raw
        ));
    }

    // The final iteration always returns (retry guards require !last_attempt),
    // so the loop cannot fall through.
    unreachable!("embedding request retries exhausted without returning")
}
245
impl SemanticEmbeddingModel {
    /// Build an embedding model from the semantic backend configuration.
    ///
    /// A `timeout_ms` or `max_batch_size` of 0 selects the module default.
    /// HTTP backends require `base_url` (normalized here); the
    /// OpenAI-compatible backend additionally resolves `api_key_env` — when
    /// configured — against the process environment and fails if the
    /// variable is missing.
    pub fn from_config(config: &SemanticBackendConfig) -> Result<Self, String> {
        let timeout_ms = if config.timeout_ms == 0 {
            DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS
        } else {
            config.timeout_ms
        };

        let max_batch_size = if config.max_batch_size == 0 {
            DEFAULT_MAX_BATCH_SIZE
        } else {
            config.max_batch_size
        };

        let api_key_env = normalize_api_key(config.api_key_env.clone());
        let model = config.model.clone();

        // Built unconditionally; the fastembed branch simply drops it.
        // Redirects are disabled, so credentials are never re-sent to a
        // redirect target.
        let client = Client::builder()
            .timeout(Duration::from_millis(timeout_ms))
            .redirect(reqwest::redirect::Policy::none())
            .build()
            .map_err(|error| format!("failed to configure embedding client: {error}"))?;

        let engine = match config.backend {
            SemanticBackend::Fastembed => {
                SemanticEmbeddingEngine::Fastembed(initialize_text_embedding(&model)?)
            }
            SemanticBackend::OpenAiCompatible => {
                let raw = config.base_url.as_ref().ok_or_else(|| {
                    "base_url is required for openai_compatible backend".to_string()
                })?;
                let base_url = normalize_base_url(raw)?;

                // `api_key_env` names the environment variable holding the key.
                let api_key = match api_key_env {
                    Some(var_name) => Some(env::var(&var_name).map_err(|_| {
                        format!("missing api_key_env '{var_name}' for openai_compatible backend")
                    })?),
                    None => None,
                };

                SemanticEmbeddingEngine::OpenAiCompatible {
                    client,
                    model,
                    base_url,
                    api_key,
                }
            }
            SemanticBackend::Ollama => {
                let raw = config
                    .base_url
                    .as_ref()
                    .ok_or_else(|| "base_url is required for ollama backend".to_string())?;
                let base_url = normalize_base_url(raw)?;

                SemanticEmbeddingEngine::Ollama {
                    client,
                    model,
                    base_url,
                }
            }
        };

        Ok(Self {
            backend: config.backend,
            model: config.model.clone(),
            base_url: config.base_url.clone(),
            timeout_ms,
            max_batch_size,
            // Unknown until the first embedding round-trip; see `dimension`.
            dimension: None,
            engine,
        })
    }

    /// The configured backend kind.
    pub fn backend(&self) -> SemanticBackend {
        self.backend
    }

    /// The configured model name.
    pub fn model(&self) -> &str {
        &self.model
    }

    /// The raw (un-normalized) base URL from the configuration, if any.
    pub fn base_url(&self) -> Option<&str> {
        self.base_url.as_deref()
    }

    /// Effective batch-size cap after zero-means-default resolution.
    pub fn max_batch_size(&self) -> usize {
        self.max_batch_size
    }

    /// Effective request timeout in milliseconds after resolution.
    pub fn timeout_ms(&self) -> u64 {
        self.timeout_ms
    }

    /// Fingerprint of this configuration including the probed dimension.
    /// May perform a model/network round-trip via `dimension`.
    pub fn fingerprint(
        &mut self,
        config: &SemanticBackendConfig,
    ) -> Result<SemanticIndexFingerprint, String> {
        let dimension = self.dimension()?;
        Ok(SemanticIndexFingerprint::from_config(config, dimension))
    }

    /// The embedding dimension, probed lazily with a single throwaway input
    /// and cached for subsequent calls.
    pub fn dimension(&mut self) -> Result<usize, String> {
        if let Some(dimension) = self.dimension {
            return Ok(dimension);
        }

        let dimension = match &mut self.engine {
            SemanticEmbeddingEngine::Fastembed(model) => {
                let vectors = model
                    .embed(vec!["semantic index fingerprint probe".to_string()], None)
                    .map_err(|error| format_embedding_init_error(error.to_string()))?;
                vectors
                    .first()
                    .map(|v| v.len())
                    .ok_or_else(|| "embedding backend returned no vectors".to_string())?
            }
            // The HTTP arms bind nothing from the matched variant, so the
            // scrutinee borrow ends and `embed_texts` may re-borrow `self`.
            SemanticEmbeddingEngine::OpenAiCompatible { .. } => {
                let vectors =
                    self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
                vectors
                    .first()
                    .map(|v| v.len())
                    .ok_or_else(|| "embedding backend returned no vectors".to_string())?
            }
            SemanticEmbeddingEngine::Ollama { .. } => {
                let vectors =
                    self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
                vectors
                    .first()
                    .map(|v| v.len())
                    .ok_or_else(|| "embedding backend returned no vectors".to_string())?
            }
        };

        self.dimension = Some(dimension);
        Ok(dimension)
    }

    /// Embed a batch of texts. Public wrapper over `embed_texts`.
    pub fn embed(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
        self.embed_texts(texts)
    }

    /// Dispatch a batch of texts to the active engine and return one vector
    /// per input, in input order. The HTTP branches also cache the observed
    /// dimension on success.
    fn embed_texts(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
        match &mut self.engine {
            SemanticEmbeddingEngine::Fastembed(model) => model
                .embed(texts, None::<usize>)
                .map_err(|error| format_embedding_init_error(error.to_string()))
                .map_err(|error| format!("failed to embed batch: {error}")),
            SemanticEmbeddingEngine::OpenAiCompatible {
                client,
                model,
                base_url,
                api_key,
            } => {
                let expected_text_count = texts.len();
                let endpoint = build_openai_embeddings_endpoint(base_url);
                let body = serde_json::json!({
                    "input": texts,
                    "model": model,
                });

                let raw = send_embedding_request(
                    || {
                        let mut request = client
                            .post(&endpoint)
                            .json(&body)
                            .header("Content-Type", "application/json");

                        if let Some(api_key) = api_key {
                            request = request.header("Authorization", format!("Bearer {api_key}"));
                        }

                        request
                    },
                    "openai compatible",
                )?;

                #[derive(Deserialize)]
                struct OpenAiResponse {
                    data: Vec<OpenAiEmbeddingResult>,
                }

                #[derive(Deserialize)]
                struct OpenAiEmbeddingResult {
                    embedding: Vec<f32>,
                    // Position in the input batch; servers may return items
                    // out of order.
                    index: Option<u32>,
                }

                let parsed: OpenAiResponse = serde_json::from_str(&raw)
                    .map_err(|error| format!("invalid openai compatible response: {error}"))?;
                if parsed.data.len() != expected_text_count {
                    return Err(format!(
                        "openai compatible response returned {} embeddings for {} inputs",
                        parsed.data.len(),
                        expected_text_count
                    ));
                }

                // Re-order by the server-provided index, falling back to the
                // item's position when no index is present.
                let mut vectors = vec![Vec::new(); parsed.data.len()];
                for (i, item) in parsed.data.into_iter().enumerate() {
                    let index = item.index.unwrap_or(i as u32) as usize;
                    if index >= vectors.len() {
                        return Err(
                            "openai compatible response contains invalid vector index".to_string()
                        );
                    }
                    vectors[index] = item.embedding;
                }

                // Duplicate or missing indices leave empty slots behind.
                for vector in &vectors {
                    if vector.is_empty() {
                        return Err(
                            "openai compatible response contained missing vectors".to_string()
                        );
                    }
                }

                self.dimension = vectors.first().map(Vec::len);
                Ok(vectors)
            }
            SemanticEmbeddingEngine::Ollama {
                client,
                model,
                base_url,
            } => {
                let expected_text_count = texts.len();
                let endpoint = build_ollama_embeddings_endpoint(base_url);

                #[derive(Serialize)]
                struct OllamaPayload<'a> {
                    model: &'a str,
                    input: Vec<String>,
                }

                let payload = OllamaPayload {
                    model,
                    input: texts,
                };

                let raw = send_embedding_request(
                    || {
                        client
                            .post(&endpoint)
                            .json(&payload)
                            .header("Content-Type", "application/json")
                    },
                    "ollama",
                )?;

                #[derive(Deserialize)]
                struct OllamaResponse {
                    embeddings: Vec<Vec<f32>>,
                }

                let parsed: OllamaResponse = serde_json::from_str(&raw)
                    .map_err(|error| format!("invalid ollama response: {error}"))?;
                if parsed.embeddings.is_empty() {
                    return Err("ollama response returned no embeddings".to_string());
                }
                if parsed.embeddings.len() != expected_text_count {
                    return Err(format!(
                        "ollama response returned {} embeddings for {} inputs",
                        parsed.embeddings.len(),
                        expected_text_count
                    ));
                }

                let vectors = parsed.embeddings;
                for vector in &vectors {
                    if vector.is_empty() {
                        return Err("ollama response contained empty embeddings".to_string());
                    }
                }

                self.dimension = vectors.first().map(Vec::len);
                Ok(vectors)
            }
        }
    }
}
526
/// Probe for a loadable ONNX Runtime shared library before fastembed tries
/// to use it, so users get an actionable error instead of a deep failure.
///
/// On Linux/macOS this `dlopen`s either `ORT_DYLIB_PATH` or the platform's
/// default library name, then rejects libraries whose file name reveals a
/// version outside the required 1.20+ range. On Windows no probing is done.
pub fn pre_validate_onnx_runtime() -> Result<(), String> {
    let dylib_path = std::env::var("ORT_DYLIB_PATH").ok();

    #[cfg(any(target_os = "linux", target_os = "macos"))]
    {
        #[cfg(target_os = "linux")]
        let default_name = "libonnxruntime.so";
        #[cfg(target_os = "macos")]
        let default_name = "libonnxruntime.dylib";

        let lib_name = dylib_path.as_deref().unwrap_or(default_name);

        // SAFETY: `c_name` stays alive for the duration of the dlopen call,
        // `dlerror` output is copied before further libc calls, and the
        // non-null handle is closed with dlclose before returning.
        unsafe {
            let c_name = std::ffi::CString::new(lib_name)
                .map_err(|e| format!("invalid library path: {}", e))?;
            let handle = libc::dlopen(c_name.as_ptr(), libc::RTLD_NOW);
            if handle.is_null() {
                let err = libc::dlerror();
                let msg = if err.is_null() {
                    "unknown dlopen error".to_string()
                } else {
                    std::ffi::CStr::from_ptr(err).to_string_lossy().into_owned()
                };
                return Err(format!(
                    "ONNX Runtime not found. dlopen('{}') failed: {}. \
                     Run `bunx @cortexkit/aft-opencode@latest doctor` to diagnose.",
                    lib_name, msg
                ));
            }

            // Best-effort: infer the version from the library's file name
            // (or a sibling's); None means "unknown", which is accepted.
            let detected_version = detect_ort_version_from_path(lib_name);

            libc::dlclose(handle);

            if let Some(ref version) = detected_version {
                // Require major 1 and minor >= 20; anything else is rejected.
                let parts: Vec<&str> = version.split('.').collect();
                if let (Some(major), Some(minor)) = (
                    parts.first().and_then(|s| s.parse::<u32>().ok()),
                    parts.get(1).and_then(|s| s.parse::<u32>().ok()),
                ) {
                    if major != 1 || minor < 20 {
                        return Err(format!(
                            "ONNX Runtime version mismatch: found v{} at '{}', but AFT requires v1.20+. \
                             Solutions:\n\
                             1. Remove the old library and restart (AFT auto-downloads the correct version):\n\
                             {}\n\
                             2. Or install ONNX Runtime 1.24: https://github.com/microsoft/onnxruntime/releases/tag/v1.24.0\n\
                             3. Run `bunx @cortexkit/aft-opencode@latest doctor` for full diagnostics.",
                            version,
                            lib_name,
                            suggest_removal_command(lib_name),
                        ));
                    }
                }
            }
        }
    }

    #[cfg(target_os = "windows")]
    {
        // Silence the unused-variable warning on Windows builds.
        let _ = dylib_path;
    }

    Ok(())
}
599
600fn detect_ort_version_from_path(lib_path: &str) -> Option<String> {
603 let path = std::path::Path::new(lib_path);
604
605 for candidate in [Some(path.to_path_buf()), std::fs::canonicalize(path).ok()]
607 .into_iter()
608 .flatten()
609 {
610 if let Some(name) = candidate.file_name().and_then(|n| n.to_str()) {
611 if let Some(version) = extract_version_from_filename(name) {
612 return Some(version);
613 }
614 }
615 }
616
617 if let Some(parent) = path.parent() {
619 if let Ok(entries) = std::fs::read_dir(parent) {
620 for entry in entries.flatten() {
621 if let Some(name) = entry.file_name().to_str() {
622 if name.starts_with("libonnxruntime") {
623 if let Some(version) = extract_version_from_filename(name) {
624 return Some(version);
625 }
626 }
627 }
628 }
629 }
630 }
631
632 None
633}
634
/// Extract the first `major.minor.patch` version substring from a file name.
///
/// Scans for the leftmost occurrence of three dot-separated runs of ASCII
/// digits (e.g. "1.16.3" in "libonnxruntime.1.16.3.dylib"). This mirrors
/// the previous regex `(\d+\.\d+\.\d+)` but avoids recompiling a pattern on
/// every call and needs no regex dependency here.
fn extract_version_from_filename(name: &str) -> Option<String> {
    let bytes = name.as_bytes();
    // Try every position as a potential match start (leftmost-first, like
    // `Regex::find`).
    for start in 0..bytes.len() {
        if !bytes[start].is_ascii_digit() {
            continue;
        }
        let mut end = start;
        let mut groups = 0;
        loop {
            // Consume one run of digits; a missing run aborts this start.
            let digits_start = end;
            while end < bytes.len() && bytes[end].is_ascii_digit() {
                end += 1;
            }
            if end == digits_start {
                break;
            }
            groups += 1;
            if groups == 3 {
                // All indices sit on ASCII bytes, so slicing is safe.
                return Some(name[start..end].to_string());
            }
            // A dot must separate the groups; otherwise retry from the next
            // start position.
            if end < bytes.len() && bytes[end] == b'.' {
                end += 1;
            } else {
                break;
            }
        }
    }
    None
}
641
/// Suggest a shell command for deleting an outdated ONNX Runtime library.
///
/// Default install locations (or the bare default library names) get the
/// blanket `/usr/local/lib` cleanup for the current OS; custom paths get a
/// targeted `rm`.
fn suggest_removal_command(lib_path: &str) -> String {
    let default_install = lib_path.starts_with("/usr/local/lib")
        || matches!(lib_path, "libonnxruntime.so" | "libonnxruntime.dylib");
    if default_install {
        #[cfg(target_os = "linux")]
        return " sudo rm /usr/local/lib/libonnxruntime* && sudo ldconfig".to_string();
        #[cfg(target_os = "macos")]
        return " sudo rm /usr/local/lib/libonnxruntime*".to_string();
        #[cfg(target_os = "windows")]
        return " Delete the ONNX Runtime DLL from your PATH".to_string();
    }
    format!(" rm '{}'", lib_path)
}
656
657pub fn initialize_text_embedding(model: &str) -> Result<TextEmbedding, String> {
658 pre_validate_onnx_runtime()?;
660
661 let selected_model = match model {
662 "all-MiniLM-L6-v2" | "all-minilm-l6-v2" => FastembedEmbeddingModel::AllMiniLML6V2,
663 _ => {
664 return Err(format!(
665 "unsupported fastembed model '{}'. Supported: all-MiniLM-L6-v2",
666 model
667 ))
668 }
669 };
670
671 TextEmbedding::try_new(InitOptions::new(selected_model)).map_err(format_embedding_init_error)
672}
673
/// Heuristic: does this error message indicate a missing/unloadable ONNX
/// Runtime library?
///
/// Matches our own pre-validation prefix directly, otherwise requires both
/// an ONNX Runtime mention and a dynamic-load-failure phrase
/// (case-insensitive).
pub fn is_onnx_runtime_unavailable(message: &str) -> bool {
    // Fast path: the exact prefix produced by `pre_validate_onnx_runtime`.
    if message.trim_start().starts_with("ONNX Runtime not found.") {
        return true;
    }

    const RUNTIME_MARKERS: [&str; 3] = ["onnx runtime", "onnxruntime", "libonnxruntime"];
    const LOAD_FAILURE_MARKERS: [&str; 9] = [
        "shared library",
        "dynamic library",
        "failed to load",
        "could not load",
        "unable to load",
        "dlopen",
        "loadlibrary",
        "no such file",
        "not found",
    ];

    let lowered = message.to_ascii_lowercase();
    let mentions_runtime = RUNTIME_MARKERS
        .iter()
        .any(|marker| lowered.contains(marker));
    let mentions_load_failure = LOAD_FAILURE_MARKERS
        .iter()
        .any(|marker| lowered.contains(marker));
    mentions_runtime && mentions_load_failure
}
699
700fn format_embedding_init_error(error: impl Display) -> String {
701 let message = error.to_string();
702
703 if is_onnx_runtime_unavailable(&message) {
704 return format!("{ONNX_RUNTIME_INSTALL_HINT} Original error: {message}");
705 }
706
707 format!("failed to initialize semantic embedding model: {message}")
708}
709
/// One embeddable unit extracted from a source file: a symbol's location and
/// metadata, the text that was embedded, and the snippet shown in results.
#[derive(Debug, Clone)]
pub struct SemanticChunk {
    pub file: PathBuf,
    pub name: String,
    pub kind: SymbolKind,
    pub start_line: u32,
    pub end_line: u32,
    pub exported: bool,
    // Text sent to the embedding backend.
    pub embed_text: String,
    // Source excerpt surfaced in search results.
    pub snippet: String,
}
729
/// A chunk paired with its embedding vector.
#[derive(Debug)]
struct EmbeddingEntry {
    chunk: SemanticChunk,
    vector: Vec<f32>,
}
736
/// In-memory semantic search index: embedded chunks plus per-file mtimes for
/// staleness detection. Persistable via `to_bytes`/`write_to_disk`.
#[derive(Debug)]
pub struct SemanticIndex {
    entries: Vec<EmbeddingEntry>,
    // Modification time of each indexed file at embed time.
    file_mtimes: HashMap<PathBuf, SystemTime>,
    // Dimension shared by all stored vectors.
    dimension: usize,
    // Configuration identity the index was built with, if known.
    fingerprint: Option<SemanticIndexFingerprint>,
}
747
/// A single search hit: chunk metadata plus its similarity score.
#[derive(Debug)]
pub struct SemanticResult {
    pub file: PathBuf,
    pub name: String,
    pub kind: SymbolKind,
    pub start_line: u32,
    pub end_line: u32,
    pub exported: bool,
    pub snippet: String,
    // Cosine similarity against the query vector (> 0 for returned hits).
    pub score: f32,
}
760
761impl SemanticIndex {
762 pub fn new() -> Self {
763 Self {
764 entries: Vec::new(),
765 file_mtimes: HashMap::new(),
766 dimension: DEFAULT_DIMENSION, fingerprint: None,
768 }
769 }
770
771 pub fn entry_count(&self) -> usize {
773 self.entries.len()
774 }
775
776 pub fn status_label(&self) -> &'static str {
778 if self.entries.is_empty() {
779 "empty"
780 } else {
781 "ready"
782 }
783 }
784
    /// Parse all `files` in parallel, gathering their semantic chunks and
    /// each file's modification time.
    ///
    /// Per-file failures are swallowed: unparsable files contribute no
    /// chunks, and unreadable mtimes fall back to the Unix epoch so the file
    /// is still tracked.
    fn collect_chunks(
        project_root: &Path,
        files: &[PathBuf],
    ) -> (Vec<SemanticChunk>, HashMap<PathBuf, SystemTime>) {
        // `map_init` gives each rayon worker its own scratch HashMap, reused
        // across files and handed to `collect_file_chunks` (presumably a
        // parser cache — confirm in that helper).
        let per_file: Vec<(PathBuf, SystemTime, Vec<SemanticChunk>)> = files
            .par_iter()
            .map_init(HashMap::new, |parsers, file| {
                let mtime = std::fs::metadata(file)
                    .and_then(|m| m.modified())
                    .unwrap_or(SystemTime::UNIX_EPOCH);

                let chunks = collect_file_chunks(project_root, file, parsers).unwrap_or_default();

                (file.clone(), mtime, chunks)
            })
            .collect();

        // Merge the parallel results into flat output collections.
        let mut chunks: Vec<SemanticChunk> = Vec::new();
        let mut file_mtimes: HashMap<PathBuf, SystemTime> = HashMap::new();

        for (file, mtime, file_chunks) in per_file {
            file_mtimes.insert(file, mtime);
            chunks.extend(file_chunks);
        }

        (chunks, file_mtimes)
    }
812
813 fn build_from_chunks<F, P>(
814 chunks: Vec<SemanticChunk>,
815 file_mtimes: HashMap<PathBuf, SystemTime>,
816 embed_fn: &mut F,
817 max_batch_size: usize,
818 mut progress: Option<&mut P>,
819 ) -> Result<Self, String>
820 where
821 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
822 P: FnMut(usize, usize),
823 {
824 let total_chunks = chunks.len();
825
826 if chunks.is_empty() {
827 return Ok(Self {
828 entries: Vec::new(),
829 file_mtimes,
830 dimension: DEFAULT_DIMENSION,
831 fingerprint: None,
832 });
833 }
834
835 let mut entries: Vec<EmbeddingEntry> = Vec::with_capacity(chunks.len());
837 let mut expected_dimension: Option<usize> = None;
838 let batch_size = max_batch_size.max(1);
839 for batch_start in (0..chunks.len()).step_by(batch_size) {
840 let batch_end = (batch_start + batch_size).min(chunks.len());
841 let batch_texts: Vec<String> = chunks[batch_start..batch_end]
842 .iter()
843 .map(|c| c.embed_text.clone())
844 .collect();
845
846 let vectors = embed_fn(batch_texts)?;
847 validate_embedding_batch(&vectors, batch_end - batch_start, "embedding backend")?;
848
849 if let Some(dim) = vectors.first().map(|v| v.len()) {
851 match expected_dimension {
852 None => expected_dimension = Some(dim),
853 Some(expected) if dim != expected => {
854 return Err(format!(
855 "embedding dimension changed across batches: expected {expected}, got {dim}"
856 ));
857 }
858 _ => {}
859 }
860 }
861
862 for (i, vector) in vectors.into_iter().enumerate() {
863 let chunk_idx = batch_start + i;
864 entries.push(EmbeddingEntry {
865 chunk: chunks[chunk_idx].clone(),
866 vector,
867 });
868 }
869
870 if let Some(callback) = progress.as_mut() {
871 callback(entries.len(), total_chunks);
872 }
873 }
874
875 let dimension = entries
876 .first()
877 .map(|e| e.vector.len())
878 .unwrap_or(DEFAULT_DIMENSION);
879
880 Ok(Self {
881 entries,
882 file_mtimes,
883 dimension,
884 fingerprint: None,
885 })
886 }
887
    /// Build a full index for `files`, embedding in batches of
    /// `max_batch_size`, without progress reporting.
    pub fn build<F>(
        project_root: &Path,
        files: &[PathBuf],
        embed_fn: &mut F,
        max_batch_size: usize,
    ) -> Result<Self, String>
    where
        F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
    {
        let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
        Self::build_from_chunks(
            chunks,
            file_mtimes,
            embed_fn,
            max_batch_size,
            // No callback; the fn-pointer type only pins `P` for the call.
            Option::<&mut fn(usize, usize)>::None,
        )
    }
908
    /// Build a full index for `files`, reporting progress as
    /// (entries embedded, total chunks). An initial `(0, total)` call is
    /// made before the first batch is embedded.
    pub fn build_with_progress<F, P>(
        project_root: &Path,
        files: &[PathBuf],
        embed_fn: &mut F,
        max_batch_size: usize,
        progress: &mut P,
    ) -> Result<Self, String>
    where
        F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
        P: FnMut(usize, usize),
    {
        let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
        let total_chunks = chunks.len();
        progress(0, total_chunks);
        Self::build_from_chunks(
            chunks,
            file_mtimes,
            embed_fn,
            max_batch_size,
            Some(progress),
        )
    }
932
933 pub fn search(&self, query_vector: &[f32], top_k: usize) -> Vec<SemanticResult> {
935 if self.entries.is_empty() || query_vector.len() != self.dimension {
936 return Vec::new();
937 }
938
939 let mut scored: Vec<(f32, usize)> = self
940 .entries
941 .iter()
942 .enumerate()
943 .map(|(i, entry)| (cosine_similarity(query_vector, &entry.vector), i))
944 .collect();
945
946 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
948
949 scored
950 .into_iter()
951 .take(top_k)
952 .filter(|(score, _)| *score > 0.0)
953 .map(|(score, idx)| {
954 let entry = &self.entries[idx];
955 SemanticResult {
956 file: entry.chunk.file.clone(),
957 name: entry.chunk.name.clone(),
958 kind: entry.chunk.kind.clone(),
959 start_line: entry.chunk.start_line,
960 end_line: entry.chunk.end_line,
961 exported: entry.chunk.exported,
962 snippet: entry.chunk.snippet.clone(),
963 score,
964 }
965 })
966 .collect()
967 }
968
    /// Number of embedded chunks in the index.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Whether `file` needs re-indexing: either unknown to the index, or
    /// its current on-disk mtime differs from the recorded one (metadata
    /// read failures also count as stale).
    pub fn is_file_stale(&self, file: &Path) -> bool {
        match self.file_mtimes.get(file) {
            None => true,
            Some(stored_mtime) => match fs::metadata(file).and_then(|m| m.modified()) {
                Ok(current_mtime) => *stored_mtime != current_mtime,
                Err(_) => true,
            },
        }
    }

    /// Count tracked files whose recorded mtime no longer matches disk.
    pub fn count_stale_files(&self) -> usize {
        self.file_mtimes
            .keys()
            .filter(|path| self.is_file_stale(path))
            .count()
    }

    /// Alias for `invalidate_file`.
    pub fn remove_file(&mut self, file: &Path) {
        self.invalidate_file(file);
    }

    /// Drop every entry belonging to `file` and forget its mtime.
    pub fn invalidate_file(&mut self, file: &Path) {
        self.entries.retain(|e| e.chunk.file != file);
        self.file_mtimes.remove(file);
    }

    /// Dimension of the stored vectors.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Fingerprint of the configuration the index was built with, if set.
    pub fn fingerprint(&self) -> Option<&SemanticIndexFingerprint> {
        self.fingerprint.as_ref()
    }

    /// Backend name from the fingerprint, if one is set.
    pub fn backend_label(&self) -> Option<&str> {
        self.fingerprint.as_ref().map(|f| f.backend.as_str())
    }

    /// Model name from the fingerprint, if one is set.
    pub fn model_label(&self) -> Option<&str> {
        self.fingerprint.as_ref().map(|f| f.model.as_str())
    }

    /// Attach the fingerprint to persist alongside the index.
    pub fn set_fingerprint(&mut self, fingerprint: SemanticIndexFingerprint) {
        self.fingerprint = Some(fingerprint);
    }
1022
    /// Persist the index to `<storage_dir>/semantic/<project_key>/semantic.bin`.
    ///
    /// Best-effort: failures are logged, never returned. The bytes are
    /// written to a `.tmp` sibling first and renamed into place so readers
    /// never observe a partial file. An empty index is not persisted.
    pub fn write_to_disk(&self, storage_dir: &Path, project_key: &str) {
        if self.entries.is_empty() {
            log::info!("[aft] skipping semantic index persistence (0 entries)");
            return;
        }
        let dir = storage_dir.join("semantic").join(project_key);
        if let Err(e) = fs::create_dir_all(&dir) {
            log::warn!("[aft] failed to create semantic cache dir: {}", e);
            return;
        }
        let data_path = dir.join("semantic.bin");
        let tmp_path = dir.join("semantic.bin.tmp");
        let bytes = self.to_bytes();
        if let Err(e) = fs::write(&tmp_path, &bytes) {
            log::warn!("[aft] failed to write semantic index: {}", e);
            // Clean up the partial temp file; ignore secondary failures.
            let _ = fs::remove_file(&tmp_path);
            return;
        }
        if let Err(e) = fs::rename(&tmp_path, &data_path) {
            log::warn!("[aft] failed to rename semantic index: {}", e);
            let _ = fs::remove_file(&tmp_path);
            return;
        }
        log::info!(
            "[aft] semantic index persisted: {} entries, {:.1} KB",
            self.entries.len(),
            bytes.len() as f64 / 1024.0
        );
    }
1055
    /// Load a previously persisted index, validating format version and —
    /// when provided — the expected fingerprint encoding.
    ///
    /// Returns `None` (after deleting the cache file) when the file is
    /// missing, truncated, not version 4, empty, corrupt, or has a
    /// mismatched fingerprint, so the caller falls back to a rebuild.
    pub fn read_from_disk(
        storage_dir: &Path,
        project_key: &str,
        expected_fingerprint: Option<&str>,
    ) -> Option<Self> {
        let data_path = storage_dir
            .join("semantic")
            .join(project_key)
            .join("semantic.bin");
        let file_len = usize::try_from(fs::metadata(&data_path).ok()?.len()).ok()?;
        if file_len < HEADER_BYTES_V1 {
            log::warn!(
                "[aft] corrupt semantic index (too small: {} bytes), removing",
                file_len
            );
            let _ = fs::remove_file(&data_path);
            return None;
        }

        let bytes = fs::read(&data_path).ok()?;
        // Only the current (V4) format is accepted here; older versions are
        // discarded and rebuilt even though `from_bytes` can parse them.
        let version = bytes[0];
        if version != SEMANTIC_INDEX_VERSION_V4 {
            log::info!(
                "[aft] cached semantic index version {} is older than {}, rebuilding",
                version,
                SEMANTIC_INDEX_VERSION_V4
            );
            let _ = fs::remove_file(&data_path);
            return None;
        }
        match Self::from_bytes(&bytes) {
            Ok(index) => {
                if index.entries.is_empty() {
                    log::info!("[aft] cached semantic index is empty, will rebuild");
                    let _ = fs::remove_file(&data_path);
                    return None;
                }
                if let Some(expected) = expected_fingerprint {
                    // An absent fingerprint never matches a provided one.
                    let matches = index
                        .fingerprint()
                        .map(|fingerprint| fingerprint.matches_expected(expected))
                        .unwrap_or(false);
                    if !matches {
                        log::info!("[aft] cached semantic index fingerprint mismatch, rebuilding");
                        let _ = fs::remove_file(&data_path);
                        return None;
                    }
                }
                log::info!(
                    "[aft] loaded semantic index from disk: {} entries",
                    index.entries.len()
                );
                Some(index)
            }
            Err(e) => {
                log::warn!("[aft] corrupt semantic index, rebuilding: {}", e);
                let _ = fs::remove_file(&data_path);
                None
            }
        }
    }
1118
    /// Serialize the index to the version-4 binary format.
    ///
    /// Layout (integers little-endian):
    /// `[version: u8][dimension: u32][entry_count: u32][fingerprint_len: u32]
    /// [fingerprint JSON][mtime_count: u32][mtimes…][entries…]`.
    /// Each mtime record is `(path_len: u32, path, secs: u64, nanos: u32)`;
    /// each entry is length-prefixed file/name, kind byte, start/end lines,
    /// exported flag, length-prefixed snippet/embed text, then the raw f32
    /// vector. This layout must stay in sync with `from_bytes`.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut buf = Vec::new();
        // A fingerprint that fails to serialize is stored as a zero-length
        // field, indistinguishable from "no fingerprint".
        let fingerprint_bytes = self.fingerprint.as_ref().and_then(|fingerprint| {
            let encoded = fingerprint.as_string();
            if encoded.is_empty() {
                None
            } else {
                Some(encoded.into_bytes())
            }
        });

        let version = SEMANTIC_INDEX_VERSION_V4;
        buf.push(version);
        buf.extend_from_slice(&(self.dimension as u32).to_le_bytes());
        buf.extend_from_slice(&(self.entries.len() as u32).to_le_bytes());
        let fp_bytes_ref: &[u8] = fingerprint_bytes.as_deref().unwrap_or(&[]);
        buf.extend_from_slice(&(fp_bytes_ref.len() as u32).to_le_bytes());
        buf.extend_from_slice(fp_bytes_ref);

        buf.extend_from_slice(&(self.file_mtimes.len() as u32).to_le_bytes());
        for (path, mtime) in &self.file_mtimes {
            let path_bytes = path.to_string_lossy().as_bytes().to_vec();
            buf.extend_from_slice(&(path_bytes.len() as u32).to_le_bytes());
            buf.extend_from_slice(&path_bytes);
            // Pre-epoch mtimes degrade to zero rather than failing.
            let duration = mtime
                .duration_since(SystemTime::UNIX_EPOCH)
                .unwrap_or_default();
            buf.extend_from_slice(&duration.as_secs().to_le_bytes());
            buf.extend_from_slice(&duration.subsec_nanos().to_le_bytes());
        }

        for entry in &self.entries {
            let c = &entry.chunk;

            let file_bytes = c.file.to_string_lossy().as_bytes().to_vec();
            buf.extend_from_slice(&(file_bytes.len() as u32).to_le_bytes());
            buf.extend_from_slice(&file_bytes);

            let name_bytes = c.name.as_bytes();
            buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
            buf.extend_from_slice(name_bytes);

            buf.push(symbol_kind_to_u8(&c.kind));

            buf.extend_from_slice(&(c.start_line as u32).to_le_bytes());
            buf.extend_from_slice(&(c.end_line as u32).to_le_bytes());
            buf.push(c.exported as u8);

            let snippet_bytes = c.snippet.as_bytes();
            buf.extend_from_slice(&(snippet_bytes.len() as u32).to_le_bytes());
            buf.extend_from_slice(snippet_bytes);

            let embed_bytes = c.embed_text.as_bytes();
            buf.extend_from_slice(&(embed_bytes.len() as u32).to_le_bytes());
            buf.extend_from_slice(embed_bytes);

            for &val in &entry.vector {
                buf.extend_from_slice(&val.to_le_bytes());
            }
        }

        buf
    }
1205
    /// Deserializes a `SemanticIndex` from the binary layout produced by
    /// `to_bytes`.
    ///
    /// Accepts on-disk versions v1–v4. Layout (all integers little-endian):
    /// version byte, dimension (u32), entry count (u32), then — v2+ only —
    /// a length-prefixed JSON fingerprint, then the mtime table, then the
    /// entries with their embedding vectors.
    ///
    /// Returns an error string on any truncation, unsupported version, or
    /// out-of-range size field; never panics on untrusted input.
    pub fn from_bytes(data: &[u8]) -> Result<Self, String> {
        // Cursor into `data`; every read_* helper advances it.
        let mut pos = 0;

        if data.len() < HEADER_BYTES_V1 {
            return Err("data too short".to_string());
        }

        let version = data[pos];
        pos += 1;
        if version != SEMANTIC_INDEX_VERSION_V1
            && version != SEMANTIC_INDEX_VERSION_V2
            && version != SEMANTIC_INDEX_VERSION_V3
            && version != SEMANTIC_INDEX_VERSION_V4
        {
            return Err(format!("unsupported version: {}", version));
        }
        // v2+ carries the extra fingerprint length field in the header.
        if (version == SEMANTIC_INDEX_VERSION_V2
            || version == SEMANTIC_INDEX_VERSION_V3
            || version == SEMANTIC_INDEX_VERSION_V4)
            && data.len() < HEADER_BYTES_V2
        {
            return Err("data too short for semantic index v2/v3/v4 header".to_string());
        }

        let dimension = read_u32(data, &mut pos)? as usize;
        let entry_count = read_u32(data, &mut pos)? as usize;
        // Reject absurd sizes before any allocation is attempted.
        if dimension == 0 || dimension > MAX_DIMENSION {
            return Err(format!("invalid embedding dimension: {}", dimension));
        }
        if entry_count > MAX_ENTRIES {
            return Err(format!("too many semantic index entries: {}", entry_count));
        }

        let has_fingerprint_field = version == SEMANTIC_INDEX_VERSION_V2
            || version == SEMANTIC_INDEX_VERSION_V3
            || version == SEMANTIC_INDEX_VERSION_V4;
        let fingerprint = if has_fingerprint_field {
            let fingerprint_len = read_u32(data, &mut pos)? as usize;
            if pos + fingerprint_len > data.len() {
                return Err("unexpected end of data reading fingerprint".to_string());
            }
            // A zero length means the index was written without a fingerprint.
            if fingerprint_len == 0 {
                None
            } else {
                let raw = String::from_utf8_lossy(&data[pos..pos + fingerprint_len]).to_string();
                pos += fingerprint_len;
                Some(
                    serde_json::from_str::<SemanticIndexFingerprint>(&raw)
                        .map_err(|error| format!("invalid semantic fingerprint: {error}"))?,
                )
            }
        } else {
            None
        };

        let mtime_count = read_u32(data, &mut pos)? as usize;
        if mtime_count > MAX_ENTRIES {
            return Err(format!("too many semantic file mtimes: {}", mtime_count));
        }

        // Sanity-check that at least the vector payload fits in the remaining
        // bytes, with overflow-checked multiplication on untrusted counts.
        let vector_bytes = entry_count
            .checked_mul(dimension)
            .and_then(|count| count.checked_mul(F32_BYTES))
            .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
        if vector_bytes > data.len().saturating_sub(pos) {
            return Err("semantic index vectors exceed available data".to_string());
        }

        let mut file_mtimes = HashMap::with_capacity(mtime_count);
        for _ in 0..mtime_count {
            let path = read_string(data, &mut pos)?;
            let secs = read_u64(data, &mut pos)?;
            // v1/v2 stored whole seconds only; v3+ adds subsecond nanos.
            let nanos =
                if version == SEMANTIC_INDEX_VERSION_V3 || version == SEMANTIC_INDEX_VERSION_V4 {
                    read_u32(data, &mut pos)?
                } else {
                    0
                };
            if nanos >= 1_000_000_000 {
                return Err(format!(
                    "invalid semantic mtime: nanos {} >= 1_000_000_000",
                    nanos
                ));
            }
            let duration = std::time::Duration::new(secs, nanos);
            // checked_add guards against secs values past the SystemTime range.
            let mtime = SystemTime::UNIX_EPOCH
                .checked_add(duration)
                .ok_or_else(|| {
                    format!(
                        "invalid semantic mtime: secs={} nanos={} overflows SystemTime",
                        secs, nanos
                    )
                })?;
            file_mtimes.insert(PathBuf::from(path), mtime);
        }

        let mut entries = Vec::with_capacity(entry_count);
        for _ in 0..entry_count {
            let file = PathBuf::from(read_string(data, &mut pos)?);
            let name = read_string(data, &mut pos)?;

            if pos >= data.len() {
                return Err("unexpected end of data".to_string());
            }
            let kind = u8_to_symbol_kind(data[pos]);
            pos += 1;

            let start_line = read_u32(data, &mut pos)?;
            let end_line = read_u32(data, &mut pos)?;

            if pos >= data.len() {
                return Err("unexpected end of data".to_string());
            }
            let exported = data[pos] != 0;
            pos += 1;

            let snippet = read_string(data, &mut pos)?;
            let embed_text = read_string(data, &mut pos)?;

            // Per-entry vector length check; guards trailing truncation.
            let vec_bytes = dimension
                .checked_mul(F32_BYTES)
                .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
            if pos + vec_bytes > data.len() {
                return Err("unexpected end of data reading vector".to_string());
            }
            let mut vector = Vec::with_capacity(dimension);
            for _ in 0..dimension {
                let bytes = [data[pos], data[pos + 1], data[pos + 2], data[pos + 3]];
                vector.push(f32::from_le_bytes(bytes));
                pos += 4;
            }

            entries.push(EmbeddingEntry {
                chunk: SemanticChunk {
                    file,
                    name,
                    kind,
                    start_line,
                    end_line,
                    exported,
                    embed_text,
                    snippet,
                },
                vector,
            });
        }

        Ok(Self {
            entries,
            file_mtimes,
            dimension,
            fingerprint,
        })
    }
1383}
1384
1385fn build_embed_text(symbol: &Symbol, source: &str, file: &Path, project_root: &Path) -> String {
1387 let relative = file
1388 .strip_prefix(project_root)
1389 .unwrap_or(file)
1390 .to_string_lossy();
1391
1392 let kind_label = match &symbol.kind {
1393 SymbolKind::Function => "function",
1394 SymbolKind::Class => "class",
1395 SymbolKind::Method => "method",
1396 SymbolKind::Struct => "struct",
1397 SymbolKind::Interface => "interface",
1398 SymbolKind::Enum => "enum",
1399 SymbolKind::TypeAlias => "type",
1400 SymbolKind::Variable => "variable",
1401 SymbolKind::Heading => "heading",
1402 };
1403
1404 let mut text = format!("file:{} kind:{} name:{}", relative, kind_label, symbol.name);
1406
1407 if let Some(sig) = &symbol.signature {
1408 text.push_str(&format!(" signature:{}", sig));
1409 }
1410
1411 let lines: Vec<&str> = source.lines().collect();
1413 let start = (symbol.range.start_line as usize).min(lines.len());
1414 let end = (symbol.range.end_line as usize + 1).min(lines.len());
1416 if start < end {
1417 let body: String = lines[start..end]
1418 .iter()
1419 .take(15) .copied()
1421 .collect::<Vec<&str>>()
1422 .join("\n");
1423 let snippet = if body.len() > 300 {
1424 format!("{}...", &body[..body.floor_char_boundary(300)])
1425 } else {
1426 body
1427 };
1428 text.push_str(&format!(" body:{}", snippet));
1429 }
1430
1431 text
1432}
1433
1434fn parser_for(
1435 parsers: &mut HashMap<crate::parser::LangId, Parser>,
1436 lang: crate::parser::LangId,
1437) -> Result<&mut Parser, String> {
1438 use std::collections::hash_map::Entry;
1439
1440 match parsers.entry(lang) {
1441 Entry::Occupied(entry) => Ok(entry.into_mut()),
1442 Entry::Vacant(entry) => {
1443 let grammar = grammar_for(lang);
1444 let mut parser = Parser::new();
1445 parser
1446 .set_language(&grammar)
1447 .map_err(|error| error.to_string())?;
1448 Ok(entry.insert(parser))
1449 }
1450 }
1451}
1452
1453fn collect_file_chunks(
1454 project_root: &Path,
1455 file: &Path,
1456 parsers: &mut HashMap<crate::parser::LangId, Parser>,
1457) -> Result<Vec<SemanticChunk>, String> {
1458 let lang = detect_language(file).ok_or_else(|| "unsupported file extension".to_string())?;
1459 let source = std::fs::read_to_string(file).map_err(|error| error.to_string())?;
1460 let tree = parser_for(parsers, lang)?
1461 .parse(&source, None)
1462 .ok_or_else(|| format!("tree-sitter parse returned None for {}", file.display()))?;
1463 let symbols =
1464 extract_symbols_from_tree(&source, &tree, lang).map_err(|error| error.to_string())?;
1465
1466 Ok(symbols_to_chunks(file, &symbols, &source, project_root))
1467}
1468
1469fn build_snippet(symbol: &Symbol, source: &str) -> String {
1471 let lines: Vec<&str> = source.lines().collect();
1472 let start = (symbol.range.start_line as usize).min(lines.len());
1473 let end = (symbol.range.end_line as usize + 1).min(lines.len());
1475 if start < end {
1476 let snippet_lines: Vec<&str> = lines[start..end].iter().take(5).copied().collect();
1477 let mut snippet = snippet_lines.join("\n");
1478 if end - start > 5 {
1479 snippet.push_str("\n ...");
1480 }
1481 if snippet.len() > 300 {
1482 snippet = format!("{}...", &snippet[..snippet.floor_char_boundary(300)]);
1483 }
1484 snippet
1485 } else {
1486 String::new()
1487 }
1488}
1489
1490fn symbols_to_chunks(
1492 file: &Path,
1493 symbols: &[Symbol],
1494 source: &str,
1495 project_root: &Path,
1496) -> Vec<SemanticChunk> {
1497 let mut chunks = Vec::new();
1498
1499 for symbol in symbols {
1500 if matches!(symbol.kind, SymbolKind::Heading) {
1505 continue;
1506 }
1507
1508 let line_count = symbol
1510 .range
1511 .end_line
1512 .saturating_sub(symbol.range.start_line)
1513 + 1;
1514 if line_count < 2 && !matches!(symbol.kind, SymbolKind::Variable) {
1515 continue;
1516 }
1517
1518 let embed_text = build_embed_text(symbol, source, file, project_root);
1519 let snippet = build_snippet(symbol, source);
1520
1521 chunks.push(SemanticChunk {
1522 file: file.to_path_buf(),
1523 name: symbol.name.clone(),
1524 kind: symbol.kind.clone(),
1525 start_line: symbol.range.start_line,
1526 end_line: symbol.range.end_line,
1527 exported: symbol.exported,
1528 embed_text,
1529 snippet,
1530 });
1531
1532 }
1535
1536 chunks
1537}
1538
/// Computes the cosine similarity between two vectors.
///
/// Returns 0.0 when the slices differ in length or when either vector has
/// zero magnitude (avoiding a division by zero). Output is in [-1.0, 1.0]
/// up to floating-point rounding.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Single left-to-right pass accumulating the dot product and both squared
    // norms; the zip/fold form avoids the per-element bounds checks of indexed
    // access while keeping the same accumulation order as a plain loop.
    let (dot, norm_a, norm_b) = a.iter().zip(b.iter()).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
1562
1563fn symbol_kind_to_u8(kind: &SymbolKind) -> u8 {
1565 match kind {
1566 SymbolKind::Function => 0,
1567 SymbolKind::Class => 1,
1568 SymbolKind::Method => 2,
1569 SymbolKind::Struct => 3,
1570 SymbolKind::Interface => 4,
1571 SymbolKind::Enum => 5,
1572 SymbolKind::TypeAlias => 6,
1573 SymbolKind::Variable => 7,
1574 SymbolKind::Heading => 8,
1575 }
1576}
1577
1578fn u8_to_symbol_kind(v: u8) -> SymbolKind {
1579 match v {
1580 0 => SymbolKind::Function,
1581 1 => SymbolKind::Class,
1582 2 => SymbolKind::Method,
1583 3 => SymbolKind::Struct,
1584 4 => SymbolKind::Interface,
1585 5 => SymbolKind::Enum,
1586 6 => SymbolKind::TypeAlias,
1587 7 => SymbolKind::Variable,
1588 _ => SymbolKind::Heading,
1589 }
1590}
1591
/// Reads a little-endian `u32` at `*pos`, advancing the cursor by 4 on
/// success. On truncated input the cursor is left unchanged.
fn read_u32(data: &[u8], pos: &mut usize) -> Result<u32, String> {
    let bytes: [u8; 4] = data
        .get(*pos..*pos + 4)
        .ok_or_else(|| "unexpected end of data reading u32".to_string())?
        .try_into()
        .expect("4-byte slice converts to [u8; 4]");
    *pos += 4;
    Ok(u32::from_le_bytes(bytes))
}
1600
/// Reads a little-endian `u64` at `*pos`, advancing the cursor by 8 on
/// success. On truncated input the cursor is left unchanged.
fn read_u64(data: &[u8], pos: &mut usize) -> Result<u64, String> {
    let bytes: [u8; 8] = data
        .get(*pos..*pos + 8)
        .ok_or_else(|| "unexpected end of data reading u64".to_string())?
        .try_into()
        .expect("8-byte slice converts to [u8; 8]");
    *pos += 8;
    Ok(u64::from_le_bytes(bytes))
}
1609
1610fn read_string(data: &[u8], pos: &mut usize) -> Result<String, String> {
1611 let len = read_u32(data, pos)? as usize;
1612 if *pos + len > data.len() {
1613 return Err("unexpected end of data reading string".to_string());
1614 }
1615 let s = String::from_utf8_lossy(&data[*pos..*pos + len]).to_string();
1616 *pos += len;
1617 Ok(s)
1618}
1619
#[cfg(test)]
mod tests {
    //! Unit tests for the semantic index: similarity math, binary
    //! (de)serialization, chunk extraction, and the HTTP embedding backends
    //! (exercised against a local single-request mock server).
    use super::*;
    use crate::config::{SemanticBackend, SemanticBackendConfig};
    use crate::parser::FileParser;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;

    /// Starts a minimal one-shot HTTP server on a random loopback port.
    ///
    /// The server accepts exactly one request, reads headers plus a
    /// `Content-Length`-sized body, hands `(request_line, path, body)` to
    /// `handler`, and writes the handler's return value back as a 200 JSON
    /// response. Returns the base URL and the server thread's join handle.
    fn start_mock_http_server<F>(handler: F) -> (String, thread::JoinHandle<()>)
    where
        F: Fn(String, String, String) -> String + Send + 'static,
    {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
        let addr = listener.local_addr().expect("local addr");
        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept request");
            let mut buf = Vec::new();
            let mut chunk = [0u8; 4096];
            let mut header_end = None;
            let mut content_length = 0usize;
            // Keep reading until the header terminator is seen and the full
            // Content-Length body has arrived (or the peer closes).
            loop {
                let n = stream.read(&mut chunk).expect("read request");
                if n == 0 {
                    break;
                }
                buf.extend_from_slice(&chunk[..n]);
                if header_end.is_none() {
                    if let Some(pos) = buf.windows(4).position(|window| window == b"\r\n\r\n") {
                        header_end = Some(pos + 4);
                        let headers = String::from_utf8_lossy(&buf[..pos + 4]);
                        for line in headers.lines() {
                            if let Some(value) = line.strip_prefix("Content-Length:") {
                                content_length = value.trim().parse::<usize>().unwrap_or(0);
                            }
                        }
                    }
                }
                if let Some(end) = header_end {
                    if buf.len() >= end + content_length {
                        break;
                    }
                }
            }

            let end = header_end.expect("header terminator");
            let request = String::from_utf8_lossy(&buf[..end]).to_string();
            let body = String::from_utf8_lossy(&buf[end..end + content_length]).to_string();
            let mut lines = request.lines();
            let request_line = lines.next().expect("request line").to_string();
            // Second whitespace-separated token of the request line is the path.
            let path = request_line
                .split_whitespace()
                .nth(1)
                .expect("request path")
                .to_string();
            let response_body = handler(request_line, path, body);
            let response = format!(
                "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
                response_body.len(),
                response_body
            );
            stream
                .write_all(response.as_bytes())
                .expect("write response");
        });

        (format!("http://{}", addr), handle)
    }

    // Identical vectors have similarity ~1.0.
    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
    }

    // Orthogonal vectors have similarity ~0.0.
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &b).abs() < 0.001);
    }

    // Opposite vectors have similarity ~-1.0.
    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 0.001);
    }

    // to_bytes followed by from_bytes must round-trip entries, vectors,
    // dimension, and the backend fingerprint.
    #[test]
    fn test_serialization_roundtrip() {
        let mut index = SemanticIndex::new();
        index.entries.push(EmbeddingEntry {
            chunk: SemanticChunk {
                file: PathBuf::from("/src/main.rs"),
                name: "handle_request".to_string(),
                kind: SymbolKind::Function,
                start_line: 10,
                end_line: 25,
                exported: true,
                embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
                snippet: "fn handle_request() {\n // ...\n}".to_string(),
            },
            vector: vec![0.1, 0.2, 0.3, 0.4],
        });
        index.dimension = 4;
        index
            .file_mtimes
            .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
        index.set_fingerprint(SemanticIndexFingerprint {
            backend: "fastembed".to_string(),
            model: "all-MiniLM-L6-v2".to_string(),
            base_url: FALLBACK_BACKEND.to_string(),
            dimension: 4,
        });

        let bytes = index.to_bytes();
        let restored = SemanticIndex::from_bytes(&bytes).unwrap();

        assert_eq!(restored.entries.len(), 1);
        assert_eq!(restored.entries[0].chunk.name, "handle_request");
        assert_eq!(restored.entries[0].vector, vec![0.1, 0.2, 0.3, 0.4]);
        assert_eq!(restored.dimension, 4);
        assert_eq!(restored.backend_label(), Some("fastembed"));
        assert_eq!(restored.model_label(), Some("all-MiniLM-L6-v2"));
    }

    // search() must return the k best-matching entries ordered by score.
    #[test]
    fn test_search_top_k() {
        let mut index = SemanticIndex::new();
        index.dimension = 3;

        // Three one-hot vectors, one per named entry.
        for (i, name) in ["auth", "database", "handler"].iter().enumerate() {
            let mut vec = vec![0.0f32; 3];
            vec[i] = 1.0;
            index.entries.push(EmbeddingEntry {
                chunk: SemanticChunk {
                    file: PathBuf::from("/src/lib.rs"),
                    name: name.to_string(),
                    kind: SymbolKind::Function,
                    start_line: (i * 10 + 1) as u32,
                    end_line: (i * 10 + 5) as u32,
                    exported: true,
                    embed_text: format!("kind:function name:{}", name),
                    snippet: format!("fn {}() {{}}", name),
                },
                vector: vec,
            });
        }

        // Query is closest to the first ("auth") axis.
        let query = vec![0.9, 0.1, 0.0];
        let results = index.search(&query, 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].name, "auth");
        assert!(results[0].score > results[1].score);
    }

    // Searching an empty index yields no results rather than panicking.
    #[test]
    fn test_empty_index_search() {
        let index = SemanticIndex::new();
        let results = index.search(&[0.1, 0.2, 0.3], 10);
        assert!(results.is_empty());
    }

    // A one-line variable symbol still produces its source line as a snippet.
    #[test]
    fn single_line_symbol_builds_non_empty_snippet() {
        let symbol = Symbol {
            name: "answer".to_string(),
            kind: SymbolKind::Variable,
            range: crate::symbols::Range {
                start_line: 0,
                start_col: 0,
                end_line: 0,
                end_col: 24,
            },
            signature: Some("const answer = 42".to_string()),
            scope_chain: Vec::new(),
            exported: true,
            parent: None,
        };
        let source = "export const answer = 42;\n";

        let snippet = build_snippet(&symbol, source);

        assert_eq!(snippet, "export const answer = 42;");
    }

    // The parser-cache fast path (collect_file_chunks) must produce the same
    // chunks as the legacy FileParser path, using this very file as input.
    #[test]
    fn optimized_file_chunk_collection_matches_file_parser_path() {
        let project_root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let file = project_root.join("src/semantic_index.rs");
        let source = std::fs::read_to_string(&file).unwrap();

        let mut legacy_parser = FileParser::new();
        let legacy_symbols = legacy_parser.extract_symbols(&file).unwrap();
        let legacy_chunks = symbols_to_chunks(&file, &legacy_symbols, &source, &project_root);

        let mut parsers = HashMap::new();
        let optimized_chunks = collect_file_chunks(&project_root, &file, &mut parsers).unwrap();

        assert_eq!(
            chunk_fingerprint(&optimized_chunks),
            chunk_fingerprint(&legacy_chunks)
        );
    }

    /// Projects chunks onto a comparable tuple form (everything except the
    /// file path) so two chunk lists can be compared with assert_eq.
    fn chunk_fingerprint(
        chunks: &[SemanticChunk],
    ) -> Vec<(String, SymbolKind, u32, u32, bool, String, String)> {
        chunks
            .iter()
            .map(|chunk| {
                (
                    chunk.name.clone(),
                    chunk.kind.clone(),
                    chunk.start_line,
                    chunk.end_line,
                    chunk.exported,
                    chunk.embed_text.clone(),
                    chunk.snippet.clone(),
                )
            })
            .collect()
    }

    // A header claiming dimension > MAX_DIMENSION must be rejected.
    #[test]
    fn rejects_oversized_dimension_during_deserialization() {
        let mut bytes = Vec::new();
        bytes.push(1u8);
        bytes.extend_from_slice(&((MAX_DIMENSION as u32) + 1).to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());

        assert!(SemanticIndex::from_bytes(&bytes).is_err());
    }

    // A header claiming entry_count > MAX_ENTRIES must be rejected.
    #[test]
    fn rejects_oversized_entry_count_during_deserialization() {
        let mut bytes = Vec::new();
        bytes.push(1u8);
        bytes.extend_from_slice(&(DEFAULT_DIMENSION as u32).to_le_bytes());
        bytes.extend_from_slice(&((MAX_ENTRIES as u32) + 1).to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());

        assert!(SemanticIndex::from_bytes(&bytes).is_err());
    }

    // invalidate_file must drop both the file's entries and its mtime record.
    #[test]
    fn invalidate_file_removes_entries_and_mtime() {
        let target = PathBuf::from("/src/main.rs");
        let mut index = SemanticIndex::new();
        index.entries.push(EmbeddingEntry {
            chunk: SemanticChunk {
                file: target.clone(),
                name: "main".to_string(),
                kind: SymbolKind::Function,
                start_line: 0,
                end_line: 1,
                exported: false,
                embed_text: "main".to_string(),
                snippet: "fn main() {}".to_string(),
            },
            vector: vec![1.0; DEFAULT_DIMENSION],
        });
        index
            .file_mtimes
            .insert(target.clone(), SystemTime::UNIX_EPOCH);

        index.invalidate_file(&target);

        assert!(index.entries.is_empty());
        assert!(!index.file_mtimes.contains_key(&target));
    }

    // A dlopen-style failure message must be classified as "ONNX missing".
    #[test]
    fn detects_missing_onnx_runtime_from_dynamic_load_error() {
        let message = "Failed to load ONNX Runtime shared library libonnxruntime.dylib via dlopen: no such file";

        assert!(is_onnx_runtime_unavailable(message));
    }

    // The formatted init error must lead with the install hint and keep the
    // original error text for debugging.
    #[test]
    fn formats_missing_onnx_runtime_with_install_hint() {
        let message = format_embedding_init_error(
            "Failed to load ONNX Runtime shared library libonnxruntime.so via dlopen: no such file",
        );

        assert!(message.starts_with("ONNX Runtime not found. Install via:"));
        assert!(message.contains("Original error:"));
    }

    // End-to-end embed() against a mock OpenAI-compatible /v1/embeddings API.
    #[test]
    fn openai_compatible_backend_embeds_with_mock_server() {
        let (base_url, handle) = start_mock_http_server(|request_line, path, _body| {
            assert!(request_line.starts_with("POST "));
            assert_eq!(path, "/v1/embeddings");
            "{\"data\":[{\"embedding\":[0.1,0.2,0.3],\"index\":0},{\"embedding\":[0.4,0.5,0.6],\"index\":1}]}".to_string()
        });

        let config = SemanticBackendConfig {
            backend: SemanticBackend::OpenAiCompatible,
            model: "test-embedding".to_string(),
            base_url: Some(base_url),
            api_key_env: None,
            timeout_ms: 5_000,
            max_batch_size: 64,
        };

        let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();
        let vectors = model
            .embed(vec!["hello".to_string(), "world".to_string()])
            .unwrap();

        assert_eq!(vectors, vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]]);
        handle.join().unwrap();
    }

    // End-to-end embed() against a mock Ollama /api/embed API.
    #[test]
    fn ollama_backend_embeds_with_mock_server() {
        let (base_url, handle) = start_mock_http_server(|request_line, path, _body| {
            assert!(request_line.starts_with("POST "));
            assert_eq!(path, "/api/embed");
            "{\"embeddings\":[[0.7,0.8,0.9],[1.0,1.1,1.2]]}".to_string()
        });

        let config = SemanticBackendConfig {
            backend: SemanticBackend::Ollama,
            model: "embeddinggemma".to_string(),
            base_url: Some(base_url),
            api_key_env: None,
            timeout_ms: 5_000,
            max_batch_size: 64,
        };

        let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();
        let vectors = model
            .embed(vec!["hello".to_string(), "world".to_string()])
            .unwrap();

        assert_eq!(vectors, vec![vec![0.7, 0.8, 0.9], vec![1.0, 1.1, 1.2]]);
        handle.join().unwrap();
    }

    // A cached index is only loaded when its fingerprint matches the one the
    // caller expects; a mismatch must be treated as a cache miss.
    #[test]
    fn read_from_disk_rejects_fingerprint_mismatch() {
        let storage = tempfile::tempdir().unwrap();
        let project_key = "proj";

        let mut index = SemanticIndex::new();
        index.entries.push(EmbeddingEntry {
            chunk: SemanticChunk {
                file: PathBuf::from("/src/main.rs"),
                name: "handle_request".to_string(),
                kind: SymbolKind::Function,
                start_line: 10,
                end_line: 25,
                exported: true,
                embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
                snippet: "fn handle_request() {}".to_string(),
            },
            vector: vec![0.1, 0.2, 0.3],
        });
        index.dimension = 3;
        index
            .file_mtimes
            .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
        index.set_fingerprint(SemanticIndexFingerprint {
            backend: "openai_compatible".to_string(),
            model: "test-embedding".to_string(),
            base_url: "http://127.0.0.1:1234/v1".to_string(),
            dimension: 3,
        });
        index.write_to_disk(storage.path(), project_key);

        let matching = index.fingerprint().unwrap().as_string();
        assert!(
            SemanticIndex::read_from_disk(storage.path(), project_key, Some(&matching)).is_some()
        );

        let mismatched = SemanticIndexFingerprint {
            backend: "ollama".to_string(),
            model: "embeddinggemma".to_string(),
            base_url: "http://127.0.0.1:11434".to_string(),
            dimension: 3,
        }
        .as_string();
        assert!(
            SemanticIndex::read_from_disk(storage.path(), project_key, Some(&mismatched)).is_none()
        );
    }

    // A v3-tagged cache must be refused (and deleted) even with a matching
    // fingerprint, forcing a rebuild with the current snippet format.
    #[test]
    fn read_from_disk_rejects_v3_cache_for_snippet_rebuild() {
        let storage = tempfile::tempdir().unwrap();
        let project_key = "proj-v3";
        let dir = storage.path().join("semantic").join(project_key);
        fs::create_dir_all(&dir).unwrap();

        let mut index = SemanticIndex::new();
        index.entries.push(EmbeddingEntry {
            chunk: SemanticChunk {
                file: PathBuf::from("/src/main.rs"),
                name: "handle_request".to_string(),
                kind: SymbolKind::Function,
                start_line: 0,
                end_line: 0,
                exported: true,
                embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
                snippet: "fn handle_request() {}".to_string(),
            },
            vector: vec![0.1, 0.2, 0.3],
        });
        index.dimension = 3;
        index
            .file_mtimes
            .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
        let fingerprint = SemanticIndexFingerprint {
            backend: "fastembed".to_string(),
            model: "test".to_string(),
            base_url: FALLBACK_BACKEND.to_string(),
            dimension: 3,
        };
        index.set_fingerprint(fingerprint.clone());

        // Rewrite the version byte to v3 to simulate an older cache file.
        let mut bytes = index.to_bytes();
        bytes[0] = SEMANTIC_INDEX_VERSION_V3;
        fs::write(dir.join("semantic.bin"), bytes).unwrap();

        assert!(SemanticIndex::read_from_disk(
            storage.path(),
            project_key,
            Some(&fingerprint.as_string())
        )
        .is_none());
        assert!(!dir.join("semantic.bin").exists());
    }

    /// Builds a minimal non-exported `Symbol` with the given kind and range.
    fn make_symbol(kind: SymbolKind, name: &str, start: u32, end: u32) -> crate::symbols::Symbol {
        crate::symbols::Symbol {
            name: name.to_string(),
            kind,
            range: crate::symbols::Range {
                start_line: start,
                start_col: 0,
                end_line: end,
                end_col: 0,
            },
            signature: None,
            scope_chain: Vec::new(),
            exported: false,
            parent: None,
        }
    }

    // Heading symbols must never become embeddable chunks.
    #[test]
    fn symbols_to_chunks_skips_heading_symbols() {
        let project_root = PathBuf::from("/proj");
        let file = project_root.join("README.md");
        let source = "# Title\n\nbody text\n\n## Section\n\nmore text\n";

        let symbols = vec![
            make_symbol(SymbolKind::Heading, "Title", 0, 2),
            make_symbol(SymbolKind::Heading, "Section", 4, 6),
        ];

        let chunks = symbols_to_chunks(&file, &symbols, source, &project_root);
        assert!(
            chunks.is_empty(),
            "Heading symbols must be filtered out before embedding; got {} chunk(s)",
            chunks.len()
        );
    }

    // Filtering headings must not drop neighbouring code symbols.
    #[test]
    fn symbols_to_chunks_keeps_code_symbols_alongside_skipped_headings() {
        let project_root = PathBuf::from("/proj");
        let file = project_root.join("src/lib.rs");
        let source = "pub fn handle_request() -> bool {\n true\n}\n";

        let symbols = vec![
            make_symbol(SymbolKind::Heading, "doc heading", 0, 1),
            make_symbol(SymbolKind::Function, "handle_request", 0, 2),
            make_symbol(SymbolKind::Struct, "AuthService", 4, 6),
        ];

        let chunks = symbols_to_chunks(&file, &symbols, source, &project_root);
        assert_eq!(
            chunks.len(),
            2,
            "Expected 2 code chunks (Function + Struct), got {}",
            chunks.len()
        );
        let names: Vec<&str> = chunks.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"handle_request"));
        assert!(names.contains(&"AuthService"));
        assert!(
            !names.contains(&"doc heading"),
            "Heading symbol leaked into chunks: {names:?}"
        );
    }
}