1use crate::config::{SemanticBackend, SemanticBackendConfig};
2use crate::parser::{detect_language, extract_symbols_from_tree, grammar_for};
3use crate::symbols::{Symbol, SymbolKind};
4use crate::{slog_info, slog_warn};
5
6use fastembed::{EmbeddingModel as FastembedEmbeddingModel, InitOptions, TextEmbedding};
7use rayon::prelude::*;
8use reqwest::blocking::Client;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet};
11use std::env;
12use std::fmt::Display;
13use std::fs;
14use std::path::{Path, PathBuf};
15use std::time::Duration;
16use std::time::SystemTime;
17use tree_sitter::Parser;
18use url::Url;
19
/// Vector length assumed by an index that holds no entries yet (see `SemanticIndex::new` and the
/// empty case of `build_from_chunks`). NOTE(review): 384 matches the output width of the default
/// `all-MiniLM-L6-v2` fastembed model — presumably why it was chosen.
const DEFAULT_DIMENSION: usize = 384;
// NOTE(review): MAX_ENTRIES / MAX_DIMENSION are not referenced in the visible part of this file;
// presumably they bound what is accepted when a persisted index is loaded — confirm.
const MAX_ENTRIES: usize = 1_000_000;
const MAX_DIMENSION: usize = 1024;
/// Byte width of one `f32` vector component.
const F32_BYTES: usize = std::mem::size_of::<f32>();
// NOTE(review): header sizes of the persisted index layouts (presumably the v1 layout and the
// v2-and-later layout) — confirm against the serializer/deserializer.
const HEADER_BYTES_V1: usize = 9;
const HEADER_BYTES_V2: usize = 13;
/// Prefix added by `format_embedding_init_error` when the ONNX Runtime library cannot be loaded.
/// It starts with "ONNX Runtime not found." so `is_onnx_runtime_unavailable` recognizes it again.
const ONNX_RUNTIME_INSTALL_HINT: &str =
    "ONNX Runtime not found. Install via: brew install onnxruntime (macOS) or apt install libonnxruntime (Linux).";

// Semantic index format versions. NOTE(review): presumably written into the persisted index
// header and checked on load — confirm.
const SEMANTIC_INDEX_VERSION_V1: u8 = 1;
const SEMANTIC_INDEX_VERSION_V2: u8 = 2;
const SEMANTIC_INDEX_VERSION_V3: u8 = 3;
const SEMANTIC_INDEX_VERSION_V4: u8 = 4;
const SEMANTIC_INDEX_VERSION_V5: u8 = 5;
/// Path appended to an OpenAI-compatible base URL (after `/v1`) to reach the embeddings endpoint.
const DEFAULT_OPENAI_EMBEDDING_PATH: &str = "/embeddings";
/// Path appended to an Ollama base URL to reach the embeddings endpoint.
const DEFAULT_OLLAMA_EMBEDDING_PATH: &str = "/api/embed";
/// HTTP timeout used for every HTTP backend (not only OpenAI) when `timeout_ms` is configured as 0.
const DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS: u64 = 25_000;
/// Batch size used when `max_batch_size` is configured as 0.
const DEFAULT_MAX_BATCH_SIZE: usize = 64;
/// Placeholder recorded as the fingerprint's `base_url` when none (or an invalid one) is configured.
const FALLBACK_BACKEND: &str = "none";
/// Total number of attempts (first try included) made by `send_embedding_request`.
const EMBEDDING_REQUEST_MAX_ATTEMPTS: usize = 3;
/// Delay before the retry that follows failed attempt `i` (indexed by `i`); one entry per retry.
const EMBEDDING_REQUEST_BACKOFF_MS: [u64; 2] = [500, 1_000];
50
/// Identity of the embedding setup that produced a semantic index.
///
/// It is encoded to JSON by `as_string` and compared as a string by `matches_expected`, so the
/// declaration order of the fields below is part of the encoded value — do not reorder them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticIndexFingerprint {
    /// Backend identifier, as returned by `SemanticBackend::as_str`.
    pub backend: String,
    /// Model name from the backend configuration.
    pub model: String,
    /// Normalized base URL, or `"none"` when no valid URL was configured. Deserializes to an
    /// empty string when absent (presumably for fingerprints written before this field existed).
    #[serde(default)]
    pub base_url: String,
    /// Length of the vectors produced by the backend.
    pub dimension: usize,
}
59
60impl SemanticIndexFingerprint {
61 fn from_config(config: &SemanticBackendConfig, dimension: usize) -> Self {
62 let base_url = config
65 .base_url
66 .as_ref()
67 .and_then(|u| normalize_base_url(u).ok())
68 .unwrap_or_else(|| FALLBACK_BACKEND.to_string());
69 Self {
70 backend: config.backend.as_str().to_string(),
71 model: config.model.clone(),
72 base_url,
73 dimension,
74 }
75 }
76
77 pub fn as_string(&self) -> String {
78 serde_json::to_string(self).unwrap_or_else(|_| String::new())
79 }
80
81 fn matches_expected(&self, expected: &str) -> bool {
82 let encoded = self.as_string();
83 !encoded.is_empty() && encoded == expected
84 }
85}
86
/// Concrete embedding backend owned by a `SemanticEmbeddingModel`.
enum SemanticEmbeddingEngine {
    /// In-process fastembed model; creating it requires the ONNX Runtime shared library
    /// (see `initialize_text_embedding`).
    Fastembed(TextEmbedding),
    /// OpenAI-style embeddings HTTP endpoint (`<base_url>[/v1]/embeddings`).
    OpenAiCompatible {
        /// HTTP client configured with the request timeout and redirects disabled.
        client: Client,
        /// Model name sent in each request body.
        model: String,
        /// Base URL as produced by `normalize_base_url` (no trailing slash).
        base_url: String,
        /// Value of the configured `api_key_env` variable, sent as a Bearer token when present.
        api_key: Option<String>,
    },
    /// Ollama embeddings HTTP endpoint (`<base_url>/api/embed`).
    Ollama {
        /// HTTP client configured with the request timeout and redirects disabled.
        client: Client,
        /// Model name sent in each request body.
        model: String,
        /// Base URL as produced by `normalize_base_url` (no trailing slash).
        base_url: String,
    },
}
101
/// Embedding model built from a `SemanticBackendConfig`, together with the settings it was
/// created with.
pub struct SemanticEmbeddingModel {
    /// Backend kind selected in the configuration.
    backend: SemanticBackend,
    /// Model name from the configuration.
    model: String,
    /// `base_url` exactly as configured (not normalized).
    base_url: Option<String>,
    /// HTTP timeout in milliseconds (configured value, or the default when it was 0).
    timeout_ms: u64,
    /// Batch size limit exposed to callers (configured value, or the default when it was 0).
    max_batch_size: usize,
    /// Cached vector length; filled lazily by `dimension()` or by HTTP embedding calls.
    dimension: Option<usize>,
    /// The backend that actually produces embeddings.
    engine: SemanticEmbeddingEngine,
}

/// Alias for [`SemanticEmbeddingModel`].
pub type EmbeddingModel = SemanticEmbeddingModel;
113
/// Checks that an embedding backend returned exactly `expected_count` vectors and that all of
/// them have the same length.
///
/// `context` names the backend in the error messages. An empty batch for zero inputs is valid.
///
/// # Errors
/// Returns a message describing the first problem found: no vectors at all, a count mismatch,
/// or the first vector whose length differs from vector 0.
fn validate_embedding_batch(
    vectors: &[Vec<f32>],
    expected_count: usize,
    context: &str,
) -> Result<(), String> {
    let actual_count = vectors.len();
    if actual_count != expected_count {
        // An empty answer gets its own wording; any other mismatch reports both counts.
        return Err(if actual_count == 0 {
            format!("{context} returned no vectors for {expected_count} inputs")
        } else {
            format!("{context} returned {actual_count} vectors for {expected_count} inputs")
        });
    }

    let mut lengths = vectors.iter().map(Vec::len).enumerate();
    let Some((_, reference_len)) = lengths.next() else {
        return Ok(());
    };

    match lengths.find(|&(_, len)| len != reference_len) {
        Some((index, len)) => Err(format!(
            "{context} returned inconsistent embedding dimensions: vector 0 has length {reference_len}, vector {index} has length {len}"
        )),
        None => Ok(()),
    }
}
148
149fn normalize_base_url(raw: &str) -> Result<String, String> {
153 let parsed = Url::parse(raw).map_err(|error| format!("invalid base_url '{raw}': {error}"))?;
154 let scheme = parsed.scheme();
155 if scheme != "http" && scheme != "https" {
156 return Err(format!(
157 "unsupported URL scheme '{}' — only http:// and https:// are allowed",
158 scheme
159 ));
160 }
161 Ok(parsed.to_string().trim_end_matches('/').to_string())
162}
163
164pub fn validate_base_url_no_ssrf(raw: &str) -> Result<(), String> {
179 use std::net::{IpAddr, ToSocketAddrs};
180
181 let parsed = Url::parse(raw).map_err(|error| format!("invalid base_url '{raw}': {error}"))?;
182
183 let host = parsed.host_str().unwrap_or("");
184
185 let is_loopback_host =
190 host == "localhost" || host == "localhost.localdomain" || host.ends_with(".localhost");
191 if is_loopback_host {
192 return Ok(());
193 }
194
195 if host.ends_with(".local") {
198 return Err(format!(
199 "base_url host '{host}' is an mDNS name — only loopback (localhost / 127.0.0.1) and public endpoints are allowed"
200 ));
201 }
202
203 let port = parsed.port_or_known_default().unwrap_or(443);
206 let addr_str = format!("{host}:{port}");
207 let addrs: Vec<IpAddr> = addr_str
208 .to_socket_addrs()
209 .map(|iter| iter.map(|sa| sa.ip()).collect())
210 .unwrap_or_default();
211 for ip in &addrs {
212 if is_private_non_loopback_ip(ip) {
213 return Err(format!(
214 "base_url '{raw}' resolves to a private/reserved IP — only loopback (127.0.0.1) and public endpoints are allowed"
215 ));
216 }
217 }
218
219 Ok(())
220}
221
/// Returns `true` when `ip` is a non-public, non-loopback address that an embedding `base_url`
/// must not resolve to (used by `validate_base_url_no_ssrf`).
///
/// Loopback (`127.0.0.0/8`, `::1`, `::ffff:127.x.x.x`) is deliberately *not* flagged, so local
/// model servers keep working. Flagged ranges:
/// - IPv4: `0.0.0.0/8`, `10.0.0.0/8`, `100.64.0.0/10` (CGNAT), `169.254.0.0/16` (link-local),
///   `172.16.0.0/12`, `192.168.0.0/16`, `198.18.0.0/15` (benchmarking), and everything from
///   `224.0.0.0` upward (multicast `224.0.0.0/4`, reserved `240.0.0.0/4`, broadcast).
/// - IPv6: unspecified `::` (the counterpart of `0.0.0.0`), link-local `fe80::/10`,
///   unique-local `fc00::/7`, multicast `ff00::/8`, and IPv4-mapped `::ffff:a.b.c.d` addresses
///   whose IPv4 part is flagged.
fn is_private_non_loopback_ip(ip: &std::net::IpAddr) -> bool {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    fn is_private_v4(ip: Ipv4Addr) -> bool {
        let [a, b, _, _] = ip.octets();
        a == 0
            || a == 10
            || (a == 100 && (64..=127).contains(&b))
            || (a == 169 && b == 254)
            || (a == 172 && (16..=31).contains(&b))
            || (a == 192 && b == 168)
            || (a == 198 && (18..=19).contains(&b))
            // Multicast, reserved (240/4) and the limited broadcast address are never public
            // unicast endpoints.
            || a >= 224
    }

    fn is_private_v6(ip: Ipv6Addr) -> bool {
        // `::ffff:a.b.c.d` reaches the embedded IPv4 host, so classify that address instead.
        if let Some(mapped) = ip.to_ipv4_mapped() {
            return is_private_v4(mapped);
        }
        let first = ip.segments()[0];
        ip.is_unspecified()
            || (first & 0xffc0) == 0xfe80 // link-local fe80::/10
            || (first & 0xfe00) == 0xfc00 // unique-local fc00::/7
            || (first & 0xff00) == 0xff00 // multicast ff00::/8
    }

    match *ip {
        IpAddr::V4(v4) => is_private_v4(v4),
        IpAddr::V6(v6) => is_private_v6(v6),
    }
}
263
264fn build_openai_embeddings_endpoint(base_url: &str) -> String {
265 if base_url.ends_with("/v1") {
266 format!("{base_url}{DEFAULT_OPENAI_EMBEDDING_PATH}")
267 } else {
268 format!("{base_url}/v1{}", DEFAULT_OPENAI_EMBEDDING_PATH)
269 }
270}
271
272fn build_ollama_embeddings_endpoint(base_url: &str) -> String {
273 if base_url.ends_with("/api") {
274 format!("{base_url}/embed")
275 } else {
276 format!("{base_url}{DEFAULT_OLLAMA_EMBEDDING_PATH}")
277 }
278}
279
/// Trims surrounding whitespace and maps a missing or blank value to `None`.
///
/// In `from_config` this is applied to the *name* of the API-key environment variable.
fn normalize_api_key(value: Option<String>) -> Option<String> {
    let raw = value?;
    let trimmed = raw.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_owned())
}
290
291fn is_retryable_embedding_status(status: reqwest::StatusCode) -> bool {
292 status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS
293}
294
295fn is_retryable_embedding_error(error: &reqwest::Error) -> bool {
296 error.is_connect()
297}
298
299fn sleep_before_embedding_retry(attempt_index: usize) {
300 if let Some(delay_ms) = EMBEDDING_REQUEST_BACKOFF_MS.get(attempt_index) {
301 std::thread::sleep(Duration::from_millis(*delay_ms));
302 }
303}
304
305fn send_embedding_request<F>(mut make_request: F, backend_label: &str) -> Result<String, String>
306where
307 F: FnMut() -> reqwest::blocking::RequestBuilder,
308{
309 for attempt_index in 0..EMBEDDING_REQUEST_MAX_ATTEMPTS {
310 let last_attempt = attempt_index + 1 == EMBEDDING_REQUEST_MAX_ATTEMPTS;
311
312 let response = match make_request().send() {
313 Ok(response) => response,
314 Err(error) => {
315 if !last_attempt && is_retryable_embedding_error(&error) {
316 sleep_before_embedding_retry(attempt_index);
317 continue;
318 }
319 return Err(format!("{backend_label} request failed: {error}"));
320 }
321 };
322
323 let status = response.status();
324 let raw = match response.text() {
325 Ok(raw) => raw,
326 Err(error) => {
327 if !last_attempt && is_retryable_embedding_error(&error) {
328 sleep_before_embedding_retry(attempt_index);
329 continue;
330 }
331 return Err(format!("{backend_label} response read failed: {error}"));
332 }
333 };
334
335 if status.is_success() {
336 return Ok(raw);
337 }
338
339 if !last_attempt && is_retryable_embedding_status(status) {
340 sleep_before_embedding_retry(attempt_index);
341 continue;
342 }
343
344 return Err(format!(
345 "{backend_label} request failed (HTTP {}): {}",
346 status, raw
347 ));
348 }
349
350 unreachable!("embedding request retries exhausted without returning")
351}
352
353impl SemanticEmbeddingModel {
354 pub fn from_config(config: &SemanticBackendConfig) -> Result<Self, String> {
355 let timeout_ms = if config.timeout_ms == 0 {
356 DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS
357 } else {
358 config.timeout_ms
359 };
360
361 let max_batch_size = if config.max_batch_size == 0 {
362 DEFAULT_MAX_BATCH_SIZE
363 } else {
364 config.max_batch_size
365 };
366
367 let api_key_env = normalize_api_key(config.api_key_env.clone());
368 let model = config.model.clone();
369
370 let client = Client::builder()
371 .timeout(Duration::from_millis(timeout_ms))
372 .redirect(reqwest::redirect::Policy::none())
373 .build()
374 .map_err(|error| format!("failed to configure embedding client: {error}"))?;
375
376 let engine = match config.backend {
377 SemanticBackend::Fastembed => {
378 SemanticEmbeddingEngine::Fastembed(initialize_text_embedding(&model)?)
379 }
380 SemanticBackend::OpenAiCompatible => {
381 let raw = config.base_url.as_ref().ok_or_else(|| {
382 "base_url is required for openai_compatible backend".to_string()
383 })?;
384 let base_url = normalize_base_url(raw)?;
385
386 let api_key = match api_key_env {
387 Some(var_name) => Some(env::var(&var_name).map_err(|_| {
388 format!("missing api_key_env '{var_name}' for openai_compatible backend")
389 })?),
390 None => None,
391 };
392
393 SemanticEmbeddingEngine::OpenAiCompatible {
394 client,
395 model,
396 base_url,
397 api_key,
398 }
399 }
400 SemanticBackend::Ollama => {
401 let raw = config
402 .base_url
403 .as_ref()
404 .ok_or_else(|| "base_url is required for ollama backend".to_string())?;
405 let base_url = normalize_base_url(raw)?;
406
407 SemanticEmbeddingEngine::Ollama {
408 client,
409 model,
410 base_url,
411 }
412 }
413 };
414
415 Ok(Self {
416 backend: config.backend,
417 model: config.model.clone(),
418 base_url: config.base_url.clone(),
419 timeout_ms,
420 max_batch_size,
421 dimension: None,
422 engine,
423 })
424 }
425
426 pub fn backend(&self) -> SemanticBackend {
427 self.backend
428 }
429
430 pub fn model(&self) -> &str {
431 &self.model
432 }
433
434 pub fn base_url(&self) -> Option<&str> {
435 self.base_url.as_deref()
436 }
437
438 pub fn max_batch_size(&self) -> usize {
439 self.max_batch_size
440 }
441
442 pub fn timeout_ms(&self) -> u64 {
443 self.timeout_ms
444 }
445
446 pub fn fingerprint(
447 &mut self,
448 config: &SemanticBackendConfig,
449 ) -> Result<SemanticIndexFingerprint, String> {
450 let dimension = self.dimension()?;
451 Ok(SemanticIndexFingerprint::from_config(config, dimension))
452 }
453
454 pub fn dimension(&mut self) -> Result<usize, String> {
455 if let Some(dimension) = self.dimension {
456 return Ok(dimension);
457 }
458
459 let dimension = match &mut self.engine {
460 SemanticEmbeddingEngine::Fastembed(model) => {
461 let vectors = model
462 .embed(vec!["semantic index fingerprint probe".to_string()], None)
463 .map_err(|error| format_embedding_init_error(error.to_string()))?;
464 vectors
465 .first()
466 .map(|v| v.len())
467 .ok_or_else(|| "embedding backend returned no vectors".to_string())?
468 }
469 SemanticEmbeddingEngine::OpenAiCompatible { .. } => {
470 let vectors =
471 self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
472 vectors
473 .first()
474 .map(|v| v.len())
475 .ok_or_else(|| "embedding backend returned no vectors".to_string())?
476 }
477 SemanticEmbeddingEngine::Ollama { .. } => {
478 let vectors =
479 self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
480 vectors
481 .first()
482 .map(|v| v.len())
483 .ok_or_else(|| "embedding backend returned no vectors".to_string())?
484 }
485 };
486
487 self.dimension = Some(dimension);
488 Ok(dimension)
489 }
490
491 pub fn embed(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
492 self.embed_texts(texts)
493 }
494
495 fn embed_texts(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
496 match &mut self.engine {
497 SemanticEmbeddingEngine::Fastembed(model) => model
498 .embed(texts, None::<usize>)
499 .map_err(|error| format_embedding_init_error(error.to_string()))
500 .map_err(|error| format!("failed to embed batch: {error}")),
501 SemanticEmbeddingEngine::OpenAiCompatible {
502 client,
503 model,
504 base_url,
505 api_key,
506 } => {
507 let expected_text_count = texts.len();
508 let endpoint = build_openai_embeddings_endpoint(base_url);
509 let body = serde_json::json!({
510 "input": texts,
511 "model": model,
512 });
513
514 let raw = send_embedding_request(
515 || {
516 let mut request = client
517 .post(&endpoint)
518 .json(&body)
519 .header("Content-Type", "application/json");
520
521 if let Some(api_key) = api_key {
522 request = request.header("Authorization", format!("Bearer {api_key}"));
523 }
524
525 request
526 },
527 "openai compatible",
528 )?;
529
530 #[derive(Deserialize)]
531 struct OpenAiResponse {
532 data: Vec<OpenAiEmbeddingResult>,
533 }
534
535 #[derive(Deserialize)]
536 struct OpenAiEmbeddingResult {
537 embedding: Vec<f32>,
538 index: Option<u32>,
539 }
540
541 let parsed: OpenAiResponse = serde_json::from_str(&raw)
542 .map_err(|error| format!("invalid openai compatible response: {error}"))?;
543 if parsed.data.len() != expected_text_count {
544 return Err(format!(
545 "openai compatible response returned {} embeddings for {} inputs",
546 parsed.data.len(),
547 expected_text_count
548 ));
549 }
550
551 let mut vectors = vec![Vec::new(); parsed.data.len()];
552 for (i, item) in parsed.data.into_iter().enumerate() {
553 let index = item.index.unwrap_or(i as u32) as usize;
554 if index >= vectors.len() {
555 return Err(
556 "openai compatible response contains invalid vector index".to_string()
557 );
558 }
559 vectors[index] = item.embedding;
560 }
561
562 for vector in &vectors {
563 if vector.is_empty() {
564 return Err(
565 "openai compatible response contained missing vectors".to_string()
566 );
567 }
568 }
569
570 self.dimension = vectors.first().map(Vec::len);
571 Ok(vectors)
572 }
573 SemanticEmbeddingEngine::Ollama {
574 client,
575 model,
576 base_url,
577 } => {
578 let expected_text_count = texts.len();
579 let endpoint = build_ollama_embeddings_endpoint(base_url);
580
581 #[derive(Serialize)]
582 struct OllamaPayload<'a> {
583 model: &'a str,
584 input: Vec<String>,
585 }
586
587 let payload = OllamaPayload {
588 model,
589 input: texts,
590 };
591
592 let raw = send_embedding_request(
593 || {
594 client
595 .post(&endpoint)
596 .json(&payload)
597 .header("Content-Type", "application/json")
598 },
599 "ollama",
600 )?;
601
602 #[derive(Deserialize)]
603 struct OllamaResponse {
604 embeddings: Vec<Vec<f32>>,
605 }
606
607 let parsed: OllamaResponse = serde_json::from_str(&raw)
608 .map_err(|error| format!("invalid ollama response: {error}"))?;
609 if parsed.embeddings.is_empty() {
610 return Err("ollama response returned no embeddings".to_string());
611 }
612 if parsed.embeddings.len() != expected_text_count {
613 return Err(format!(
614 "ollama response returned {} embeddings for {} inputs",
615 parsed.embeddings.len(),
616 expected_text_count
617 ));
618 }
619
620 let vectors = parsed.embeddings;
621 for vector in &vectors {
622 if vector.is_empty() {
623 return Err("ollama response contained empty embeddings".to_string());
624 }
625 }
626
627 self.dimension = vectors.first().map(Vec::len);
628 Ok(vectors)
629 }
630 }
631 }
632}
633
634pub fn pre_validate_onnx_runtime() -> Result<(), String> {
638 let dylib_path = std::env::var("ORT_DYLIB_PATH").ok();
639
640 #[cfg(any(target_os = "linux", target_os = "macos"))]
641 {
642 #[cfg(target_os = "linux")]
643 let default_name = "libonnxruntime.so";
644 #[cfg(target_os = "macos")]
645 let default_name = "libonnxruntime.dylib";
646
647 let lib_name = dylib_path.as_deref().unwrap_or(default_name);
648
649 unsafe {
650 let c_name = std::ffi::CString::new(lib_name)
651 .map_err(|e| format!("invalid library path: {}", e))?;
652 let handle = libc::dlopen(c_name.as_ptr(), libc::RTLD_NOW);
653 if handle.is_null() {
654 let err = libc::dlerror();
655 let msg = if err.is_null() {
656 "unknown dlopen error".to_string()
657 } else {
658 std::ffi::CStr::from_ptr(err).to_string_lossy().into_owned()
659 };
660 return Err(format!(
661 "ONNX Runtime not found. dlopen('{}') failed: {}. \
662 Run `npx @cortexkit/aft doctor` to diagnose.",
663 lib_name, msg
664 ));
665 }
666
667 let detected_version = detect_ort_version_from_path(lib_name);
670
671 libc::dlclose(handle);
672
673 if let Some(ref version) = detected_version {
675 let parts: Vec<&str> = version.split('.').collect();
676 if let (Some(major), Some(minor)) = (
677 parts.first().and_then(|s| s.parse::<u32>().ok()),
678 parts.get(1).and_then(|s| s.parse::<u32>().ok()),
679 ) {
680 if major != 1 || minor < 20 {
681 return Err(format_ort_version_mismatch(version, lib_name));
682 }
683 }
684 }
685 }
686 }
687
688 #[cfg(target_os = "windows")]
689 {
690 let _ = dylib_path;
692 }
693
694 Ok(())
695}
696
697fn detect_ort_version_from_path(lib_path: &str) -> Option<String> {
700 let path = std::path::Path::new(lib_path);
701
702 for candidate in [Some(path.to_path_buf()), std::fs::canonicalize(path).ok()]
704 .into_iter()
705 .flatten()
706 {
707 if let Some(name) = candidate.file_name().and_then(|n| n.to_str()) {
708 if let Some(version) = extract_version_from_filename(name) {
709 return Some(version);
710 }
711 }
712 }
713
714 if let Some(parent) = path.parent() {
716 if let Ok(entries) = std::fs::read_dir(parent) {
717 for entry in entries.flatten() {
718 if let Some(name) = entry.file_name().to_str() {
719 if name.starts_with("libonnxruntime") {
720 if let Some(version) = extract_version_from_filename(name) {
721 return Some(version);
722 }
723 }
724 }
725 }
726 }
727 }
728
729 None
730}
731
/// Returns the first `MAJOR.MINOR.PATCH` triple of ASCII digits found in `name`, e.g. `"1.20.1"`
/// from `libonnxruntime.so.1.20.1` or `libonnxruntime.1.20.1.dylib`.
///
/// This finds the same match as the regex `\d+\.\d+\.\d+`: the leftmost start, with greedy
/// components. It is restricted to ASCII digits because callers parse the components with
/// `str::parse::<u32>`, which only accepts ASCII.
///
/// It is hand-rolled so that `detect_ort_version_from_path`, which calls it once per directory
/// entry, no longer compiles a regex on every call.
fn extract_version_from_filename(name: &str) -> Option<String> {
    let bytes = name.as_bytes();
    // Length of the run of ASCII digits starting at byte offset `from`.
    let digits_at = |from: usize| -> usize {
        bytes
            .get(from..)
            .map_or(0, |rest| rest.iter().take_while(|b| b.is_ascii_digit()).count())
    };

    (0..bytes.len()).find_map(|start| {
        let mut end = start;
        for component in 0..3 {
            if component > 0 {
                if bytes.get(end) != Some(&b'.') {
                    return None;
                }
                end += 1;
            }
            let run = digits_at(end);
            if run == 0 {
                return None;
            }
            end += run;
        }
        // `start` and `end` both border ASCII bytes, so both are valid char boundaries.
        Some(name[start..end].to_string())
    })
}
738
/// Removal command (formatted for inclusion in the version-mismatch message) for a stale ONNX
/// Runtime library at `lib_path`.
///
/// A system-wide install gets a platform-specific cleanup command on Linux, macOS and Windows.
/// That means anything under `/usr/local/lib`, or the bare default library name resolved
/// through the loader search path. Every other path — and, on other platforms, every path — gets
/// an `rm` of that exact file.
fn suggest_removal_command(lib_path: &str) -> String {
    let system_wide = matches!(lib_path, "libonnxruntime.so" | "libonnxruntime.dylib")
        || lib_path.starts_with("/usr/local/lib");

    if system_wide {
        #[cfg(target_os = "linux")]
        return " sudo rm /usr/local/lib/libonnxruntime* && sudo ldconfig".to_string();
        #[cfg(target_os = "macos")]
        return " sudo rm /usr/local/lib/libonnxruntime*".to_string();
        #[cfg(target_os = "windows")]
        return " Delete the ONNX Runtime DLL from your PATH".to_string();
    }

    format!(" rm '{lib_path}'")
}
753
754pub(crate) fn format_ort_version_mismatch(version: &str, lib_name: &str) -> String {
760 format!(
761 "ONNX Runtime version mismatch: found v{} at '{}', but AFT requires v1.20+. \
762 Solutions:\n\
763 1. Auto-fix (recommended): run `npx @cortexkit/aft doctor --fix`. \
764 This downloads AFT-managed ONNX Runtime v1.24 into AFT's storage and \
765 configures the bridge to load it instead of the system library — no \
766 changes to '{}'.\n\
767 2. Remove the old library and restart (AFT auto-downloads the correct version on next start):\n\
768 {}\n\
769 3. Or install ONNX Runtime 1.24 system-wide: https://github.com/microsoft/onnxruntime/releases/tag/v1.24.0\n\
770 4. Run `npx @cortexkit/aft doctor` for full diagnostics.",
771 version,
772 lib_name,
773 lib_name,
774 suggest_removal_command(lib_name),
775 )
776}
777
778pub fn initialize_text_embedding(model: &str) -> Result<TextEmbedding, String> {
779 pre_validate_onnx_runtime()?;
781
782 let selected_model = match model {
783 "all-MiniLM-L6-v2" | "all-minilm-l6-v2" => FastembedEmbeddingModel::AllMiniLML6V2,
784 _ => {
785 return Err(format!(
786 "unsupported fastembed model '{}'. Supported: all-MiniLM-L6-v2",
787 model
788 ))
789 }
790 };
791
792 TextEmbedding::try_new(InitOptions::new(selected_model)).map_err(format_embedding_init_error)
793}
794
/// Heuristically decides whether an error message means the ONNX Runtime library could not be
/// loaded.
///
/// Returns `true` in either of two cases:
/// - the message starts with `ONNX Runtime not found.` (leading whitespace ignored);
/// - it mentions ONNX Runtime and also a dynamic-loading failure (case-insensitive).
pub fn is_onnx_runtime_unavailable(message: &str) -> bool {
    const RUNTIME_MARKERS: [&str; 3] = ["onnx runtime", "onnxruntime", "libonnxruntime"];
    const LOAD_FAILURE_MARKERS: [&str; 9] = [
        "shared library",
        "dynamic library",
        "failed to load",
        "could not load",
        "unable to load",
        "dlopen",
        "loadlibrary",
        "no such file",
        "not found",
    ];

    if message.trim_start().starts_with("ONNX Runtime not found.") {
        return true;
    }

    let lowered = message.to_ascii_lowercase();
    let mentions_any = |markers: &[&str]| markers.iter().any(|marker| lowered.contains(marker));
    mentions_any(&RUNTIME_MARKERS) && mentions_any(&LOAD_FAILURE_MARKERS)
}
820
821fn format_embedding_init_error(error: impl Display) -> String {
822 let message = error.to_string();
823
824 if is_onnx_runtime_unavailable(&message) {
825 return format!("{ONNX_RUNTIME_INSTALL_HINT} Original error: {message}");
826 }
827
828 format!("failed to initialize semantic embedding model: {message}")
829}
830
/// One symbol-sized piece of source code that gets embedded into the semantic index.
#[derive(Debug, Clone)]
pub struct SemanticChunk {
    /// File the chunk was extracted from; refresh logic matches it against the paths in the
    /// indexed file list.
    pub file: PathBuf,
    /// Symbol name.
    pub name: String,
    /// Symbol kind.
    pub kind: SymbolKind,
    /// First line of the symbol. NOTE(review): 1-based vs 0-based is decided by
    /// `collect_file_chunks` — confirm.
    pub start_line: u32,
    /// Last line of the symbol (same convention as `start_line`).
    pub end_line: u32,
    /// Whether the symbol is exported.
    pub exported: bool,
    /// Text sent to the embedding backend for this chunk.
    pub embed_text: String,
    /// Source excerpt; presumably what search results display.
    pub snippet: String,
}
850
/// A chunk paired with its embedding vector.
#[derive(Debug)]
struct EmbeddingEntry {
    chunk: SemanticChunk,
    vector: Vec<f32>,
}
857
/// In-memory semantic index: embedded chunks plus per-file metadata used for incremental refresh.
#[derive(Debug)]
pub struct SemanticIndex {
    /// One entry per embedded chunk.
    entries: Vec<EmbeddingEntry>,
    /// Modification time recorded when each file was last chunked. Its keys are also the set of
    /// files the index considers "indexed".
    file_mtimes: HashMap<PathBuf, SystemTime>,
    /// File size recorded alongside the mtime. NOTE(review): may be missing for some files
    /// (see `backfill_missing_file_sizes`) — confirm when.
    file_sizes: HashMap<PathBuf, u64>,
    /// Length of every vector in `entries`; `search` returns nothing for queries of any other
    /// length.
    dimension: usize,
    /// Embedding setup that produced the vectors. `None` for freshly built indexes;
    /// presumably set by the caller before the index is persisted.
    fingerprint: Option<SemanticIndexFingerprint>,
}
870
/// File modification time and size captured (via `collect_file_metadata`) when a file is chunked.
#[derive(Debug, Clone, Copy)]
struct IndexedFileMetadata {
    mtime: SystemTime,
    size: u64,
}
876
/// Outcome of `SemanticIndex::refresh_stale_files`.
#[derive(Debug, Default, Clone, Copy)]
pub struct RefreshSummary {
    /// Previously indexed files that were stale and were re-chunked successfully.
    pub changed: usize,
    /// Newly seen files that were chunked successfully.
    pub added: usize,
    /// Indexed files that are no longer present and were dropped from the index.
    pub deleted: usize,
    /// Number of distinct files considered: current files plus previously indexed ones.
    pub total_processed: usize,
}

impl RefreshSummary {
    /// `true` when the refresh changed nothing; `total_processed` is not considered.
    pub fn is_noop(&self) -> bool {
        [self.changed, self.added, self.deleted]
            .iter()
            .all(|&count| count == 0)
    }
}
893
/// One semantic search hit: the matched chunk's location and metadata plus its similarity score.
#[derive(Debug)]
pub struct SemanticResult {
    pub file: PathBuf,
    pub name: String,
    pub kind: SymbolKind,
    pub start_line: u32,
    pub end_line: u32,
    pub exported: bool,
    pub snippet: String,
    /// Cosine similarity between the query vector and the chunk vector; `search` only returns
    /// hits with a positive score.
    pub score: f32,
}
906
907impl SemanticIndex {
908 pub fn new() -> Self {
909 Self {
910 entries: Vec::new(),
911 file_mtimes: HashMap::new(),
912 file_sizes: HashMap::new(),
913 dimension: DEFAULT_DIMENSION, fingerprint: None,
915 }
916 }
917
918 pub fn entry_count(&self) -> usize {
920 self.entries.len()
921 }
922
923 pub fn status_label(&self) -> &'static str {
925 if self.entries.is_empty() {
926 "empty"
927 } else {
928 "ready"
929 }
930 }
931
932 fn collect_chunks(
933 project_root: &Path,
934 files: &[PathBuf],
935 ) -> (Vec<SemanticChunk>, HashMap<PathBuf, IndexedFileMetadata>) {
936 let per_file: Vec<(
937 PathBuf,
938 Result<(IndexedFileMetadata, Vec<SemanticChunk>), String>,
939 )> = files
940 .par_iter()
941 .map_init(HashMap::new, |parsers, file| {
942 let result = collect_file_metadata(file).and_then(|metadata| {
943 collect_file_chunks(project_root, file, parsers)
944 .map(|chunks| (metadata, chunks))
945 });
946 (file.clone(), result)
947 })
948 .collect();
949
950 let mut chunks: Vec<SemanticChunk> = Vec::new();
951 let mut file_metadata: HashMap<PathBuf, IndexedFileMetadata> = HashMap::new();
952
953 for (file, result) in per_file {
954 match result {
955 Ok((metadata, file_chunks)) => {
956 file_metadata.insert(file, metadata);
957 chunks.extend(file_chunks);
958 }
959 Err(error) => {
960 if error == "unsupported file extension" {
966 continue;
967 }
968 slog_warn!(
969 "failed to collect semantic chunks for {}: {}",
970 file.display(),
971 error
972 );
973 }
974 }
975 }
976
977 (chunks, file_metadata)
978 }
979
980 fn build_from_chunks<F, P>(
981 chunks: Vec<SemanticChunk>,
982 file_metadata: HashMap<PathBuf, IndexedFileMetadata>,
983 embed_fn: &mut F,
984 max_batch_size: usize,
985 mut progress: Option<&mut P>,
986 ) -> Result<Self, String>
987 where
988 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
989 P: FnMut(usize, usize),
990 {
991 let total_chunks = chunks.len();
992
993 if chunks.is_empty() {
994 return Ok(Self {
995 entries: Vec::new(),
996 file_mtimes: file_metadata
997 .iter()
998 .map(|(path, metadata)| (path.clone(), metadata.mtime))
999 .collect(),
1000 file_sizes: file_metadata
1001 .into_iter()
1002 .map(|(path, metadata)| (path, metadata.size))
1003 .collect(),
1004 dimension: DEFAULT_DIMENSION,
1005 fingerprint: None,
1006 });
1007 }
1008
1009 let mut entries: Vec<EmbeddingEntry> = Vec::with_capacity(chunks.len());
1011 let mut expected_dimension: Option<usize> = None;
1012 let batch_size = max_batch_size.max(1);
1013 for batch_start in (0..chunks.len()).step_by(batch_size) {
1014 let batch_end = (batch_start + batch_size).min(chunks.len());
1015 let batch_texts: Vec<String> = chunks[batch_start..batch_end]
1016 .iter()
1017 .map(|c| c.embed_text.clone())
1018 .collect();
1019
1020 let vectors = embed_fn(batch_texts)?;
1021 validate_embedding_batch(&vectors, batch_end - batch_start, "embedding backend")?;
1022
1023 if let Some(dim) = vectors.first().map(|v| v.len()) {
1025 match expected_dimension {
1026 None => expected_dimension = Some(dim),
1027 Some(expected) if dim != expected => {
1028 return Err(format!(
1029 "embedding dimension changed across batches: expected {expected}, got {dim}"
1030 ));
1031 }
1032 _ => {}
1033 }
1034 }
1035
1036 for (i, vector) in vectors.into_iter().enumerate() {
1037 let chunk_idx = batch_start + i;
1038 entries.push(EmbeddingEntry {
1039 chunk: chunks[chunk_idx].clone(),
1040 vector,
1041 });
1042 }
1043
1044 if let Some(callback) = progress.as_mut() {
1045 callback(entries.len(), total_chunks);
1046 }
1047 }
1048
1049 let dimension = entries
1050 .first()
1051 .map(|e| e.vector.len())
1052 .unwrap_or(DEFAULT_DIMENSION);
1053
1054 Ok(Self {
1055 entries,
1056 file_mtimes: file_metadata
1057 .iter()
1058 .map(|(path, metadata)| (path.clone(), metadata.mtime))
1059 .collect(),
1060 file_sizes: file_metadata
1061 .into_iter()
1062 .map(|(path, metadata)| (path, metadata.size))
1063 .collect(),
1064 dimension,
1065 fingerprint: None,
1066 })
1067 }
1068
1069 pub fn build<F>(
1072 project_root: &Path,
1073 files: &[PathBuf],
1074 embed_fn: &mut F,
1075 max_batch_size: usize,
1076 ) -> Result<Self, String>
1077 where
1078 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
1079 {
1080 let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
1081 Self::build_from_chunks(
1082 chunks,
1083 file_mtimes,
1084 embed_fn,
1085 max_batch_size,
1086 Option::<&mut fn(usize, usize)>::None,
1087 )
1088 }
1089
1090 pub fn build_with_progress<F, P>(
1092 project_root: &Path,
1093 files: &[PathBuf],
1094 embed_fn: &mut F,
1095 max_batch_size: usize,
1096 progress: &mut P,
1097 ) -> Result<Self, String>
1098 where
1099 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
1100 P: FnMut(usize, usize),
1101 {
1102 let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
1103 let total_chunks = chunks.len();
1104 progress(0, total_chunks);
1105 Self::build_from_chunks(
1106 chunks,
1107 file_mtimes,
1108 embed_fn,
1109 max_batch_size,
1110 Some(progress),
1111 )
1112 }
1113
1114 pub fn refresh_stale_files<F, P>(
1125 &mut self,
1126 project_root: &Path,
1127 current_files: &[PathBuf],
1128 embed_fn: &mut F,
1129 max_batch_size: usize,
1130 progress: &mut P,
1131 ) -> Result<RefreshSummary, String>
1132 where
1133 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
1134 P: FnMut(usize, usize),
1135 {
1136 self.backfill_missing_file_sizes();
1137
1138 let current_set: HashSet<&Path> = current_files.iter().map(PathBuf::as_path).collect();
1140 let total_processed = current_set.len() + self.file_mtimes.len()
1141 - self
1142 .file_mtimes
1143 .keys()
1144 .filter(|path| current_set.contains(path.as_path()))
1145 .count();
1146
1147 let mut deleted: Vec<PathBuf> = Vec::new();
1150 let mut changed: Vec<PathBuf> = Vec::new();
1151 for indexed_path in self.file_mtimes.keys() {
1152 if !current_set.contains(indexed_path.as_path()) {
1153 deleted.push(indexed_path.clone());
1154 continue;
1155 }
1156 if self.is_file_stale(indexed_path) {
1157 changed.push(indexed_path.clone());
1158 }
1159 }
1160
1161 let mut added: Vec<PathBuf> = Vec::new();
1163 for path in current_files {
1164 if !self.file_mtimes.contains_key(path) {
1165 added.push(path.clone());
1166 }
1167 }
1168
1169 if deleted.is_empty() && changed.is_empty() && added.is_empty() {
1171 progress(0, 0);
1172 return Ok(RefreshSummary {
1173 total_processed,
1174 ..RefreshSummary::default()
1175 });
1176 }
1177
1178 if !deleted.is_empty() {
1182 let deleted_set: HashSet<&Path> = deleted.iter().map(PathBuf::as_path).collect();
1183 self.entries
1184 .retain(|entry| !deleted_set.contains(entry.chunk.file.as_path()));
1185 for path in &deleted {
1186 self.file_mtimes.remove(path);
1187 self.file_sizes.remove(path);
1188 }
1189 }
1190
1191 let mut to_embed: Vec<PathBuf> = Vec::with_capacity(changed.len() + added.len());
1193 to_embed.extend(changed.iter().cloned());
1194 to_embed.extend(added.iter().cloned());
1195
1196 if to_embed.is_empty() {
1197 progress(0, 0);
1199 return Ok(RefreshSummary {
1200 changed: 0,
1201 added: 0,
1202 deleted: deleted.len(),
1203 total_processed,
1204 });
1205 }
1206
1207 let (chunks, fresh_metadata) = Self::collect_chunks(project_root, &to_embed);
1208
1209 if chunks.is_empty() {
1210 progress(0, 0);
1211 let successful_files: HashSet<PathBuf> = fresh_metadata.keys().cloned().collect();
1212 if !successful_files.is_empty() {
1213 self.entries
1214 .retain(|entry| !successful_files.contains(&entry.chunk.file));
1215 }
1216 let changed_count = changed
1217 .iter()
1218 .filter(|path| successful_files.contains(*path))
1219 .count();
1220 let added_count = added
1221 .iter()
1222 .filter(|path| successful_files.contains(*path))
1223 .count();
1224 for (file, metadata) in fresh_metadata {
1225 self.file_mtimes.insert(file.clone(), metadata.mtime);
1226 self.file_sizes.insert(file, metadata.size);
1227 }
1228 return Ok(RefreshSummary {
1229 changed: changed_count,
1230 added: added_count,
1231 deleted: deleted.len(),
1232 total_processed,
1233 });
1234 }
1235
1236 let total_chunks = chunks.len();
1238 progress(0, total_chunks);
1239 let batch_size = max_batch_size.max(1);
1240 let existing_dimension = if self.entries.is_empty() {
1241 None
1242 } else {
1243 Some(self.dimension)
1244 };
1245 let mut new_entries: Vec<EmbeddingEntry> = Vec::with_capacity(chunks.len());
1246 let mut observed_dimension: Option<usize> = existing_dimension;
1247
1248 for batch_start in (0..chunks.len()).step_by(batch_size) {
1249 let batch_end = (batch_start + batch_size).min(chunks.len());
1250 let batch_texts: Vec<String> = chunks[batch_start..batch_end]
1251 .iter()
1252 .map(|c| c.embed_text.clone())
1253 .collect();
1254
1255 let vectors = embed_fn(batch_texts)?;
1256 validate_embedding_batch(&vectors, batch_end - batch_start, "embedding backend")?;
1257
1258 if let Some(dim) = vectors.first().map(|v| v.len()) {
1259 match observed_dimension {
1260 None => observed_dimension = Some(dim),
1261 Some(expected) if dim != expected => {
1262 return Err(format!(
1265 "embedding dimension changed during incremental refresh: \
1266 cached index uses {expected}, new vectors use {dim}"
1267 ));
1268 }
1269 _ => {}
1270 }
1271 }
1272
1273 for (i, vector) in vectors.into_iter().enumerate() {
1274 let chunk_idx = batch_start + i;
1275 new_entries.push(EmbeddingEntry {
1276 chunk: chunks[chunk_idx].clone(),
1277 vector,
1278 });
1279 }
1280
1281 progress(new_entries.len(), total_chunks);
1282 }
1283
1284 let successful_files: HashSet<PathBuf> = fresh_metadata.keys().cloned().collect();
1285 if !successful_files.is_empty() {
1286 self.entries
1287 .retain(|entry| !successful_files.contains(&entry.chunk.file));
1288 }
1289
1290 self.entries.extend(new_entries);
1291 for (file, metadata) in fresh_metadata {
1292 self.file_mtimes.insert(file.clone(), metadata.mtime);
1293 self.file_sizes.insert(file, metadata.size);
1294 }
1295 if let Some(dim) = observed_dimension {
1296 self.dimension = dim;
1297 }
1298
1299 Ok(RefreshSummary {
1300 changed: changed
1301 .iter()
1302 .filter(|path| successful_files.contains(*path))
1303 .count(),
1304 added: added
1305 .iter()
1306 .filter(|path| successful_files.contains(*path))
1307 .count(),
1308 deleted: deleted.len(),
1309 total_processed,
1310 })
1311 }
1312
1313 pub fn search(&self, query_vector: &[f32], top_k: usize) -> Vec<SemanticResult> {
1315 if self.entries.is_empty() || query_vector.len() != self.dimension {
1316 return Vec::new();
1317 }
1318
1319 let mut scored: Vec<(f32, usize)> = self
1320 .entries
1321 .iter()
1322 .enumerate()
1323 .map(|(i, entry)| (cosine_similarity(query_vector, &entry.vector), i))
1324 .collect();
1325
1326 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
1328
1329 scored
1330 .into_iter()
1331 .take(top_k)
1332 .filter(|(score, _)| *score > 0.0)
1333 .map(|(score, idx)| {
1334 let entry = &self.entries[idx];
1335 SemanticResult {
1336 file: entry.chunk.file.clone(),
1337 name: entry.chunk.name.clone(),
1338 kind: entry.chunk.kind.clone(),
1339 start_line: entry.chunk.start_line,
1340 end_line: entry.chunk.end_line,
1341 exported: entry.chunk.exported,
1342 snippet: entry.chunk.snippet.clone(),
1343 score,
1344 }
1345 })
1346 .collect()
1347 }
1348
1349 pub fn len(&self) -> usize {
1351 self.entries.len()
1352 }
1353
1354 pub fn is_file_stale(&self, file: &Path) -> bool {
1356 let Some(stored_mtime) = self.file_mtimes.get(file) else {
1357 return true;
1358 };
1359 let Some(stored_size) = self.file_sizes.get(file) else {
1360 return true;
1361 };
1362 match collect_file_metadata(file) {
1363 Ok(current) => *stored_mtime != current.mtime || *stored_size != current.size,
1364 Err(_) => true,
1365 }
1366 }
1367
1368 fn backfill_missing_file_sizes(&mut self) {
1369 for path in self.file_mtimes.keys() {
1370 if self.file_sizes.contains_key(path) {
1371 continue;
1372 }
1373 if let Ok(metadata) = fs::metadata(path) {
1374 self.file_sizes.insert(path.clone(), metadata.len());
1375 }
1376 }
1377 }
1378
1379 pub fn remove_file(&mut self, file: &Path) {
1381 self.invalidate_file(file);
1382 }
1383
1384 pub fn invalidate_file(&mut self, file: &Path) {
1385 self.entries.retain(|e| e.chunk.file != file);
1386 self.file_mtimes.remove(file);
1387 self.file_sizes.remove(file);
1388 }
1389
1390 pub fn dimension(&self) -> usize {
1392 self.dimension
1393 }
1394
1395 pub fn fingerprint(&self) -> Option<&SemanticIndexFingerprint> {
1396 self.fingerprint.as_ref()
1397 }
1398
1399 pub fn backend_label(&self) -> Option<&str> {
1400 self.fingerprint.as_ref().map(|f| f.backend.as_str())
1401 }
1402
1403 pub fn model_label(&self) -> Option<&str> {
1404 self.fingerprint.as_ref().map(|f| f.model.as_str())
1405 }
1406
1407 pub fn set_fingerprint(&mut self, fingerprint: SemanticIndexFingerprint) {
1408 self.fingerprint = Some(fingerprint);
1409 }
1410
1411 pub fn write_to_disk(&self, storage_dir: &Path, project_key: &str) {
1413 if self.entries.is_empty() {
1416 slog_info!("skipping semantic index persistence (0 entries)");
1417 return;
1418 }
1419 let dir = storage_dir.join("semantic").join(project_key);
1420 if let Err(e) = fs::create_dir_all(&dir) {
1421 slog_warn!("failed to create semantic cache dir: {}", e);
1422 return;
1423 }
1424 let data_path = dir.join("semantic.bin");
1425 let tmp_path = dir.join("semantic.bin.tmp");
1426 let bytes = self.to_bytes();
1427 if let Err(e) = fs::write(&tmp_path, &bytes) {
1428 slog_warn!("failed to write semantic index: {}", e);
1429 let _ = fs::remove_file(&tmp_path);
1430 return;
1431 }
1432 if let Err(e) = fs::rename(&tmp_path, &data_path) {
1433 slog_warn!("failed to rename semantic index: {}", e);
1434 let _ = fs::remove_file(&tmp_path);
1435 return;
1436 }
1437 slog_info!(
1438 "semantic index persisted: {} entries, {:.1} KB",
1439 self.entries.len(),
1440 bytes.len() as f64 / 1024.0
1441 );
1442 }
1443
1444 pub fn read_from_disk(
1446 storage_dir: &Path,
1447 project_key: &str,
1448 expected_fingerprint: Option<&str>,
1449 ) -> Option<Self> {
1450 let data_path = storage_dir
1451 .join("semantic")
1452 .join(project_key)
1453 .join("semantic.bin");
1454 let file_len = usize::try_from(fs::metadata(&data_path).ok()?.len()).ok()?;
1455 if file_len < HEADER_BYTES_V1 {
1456 slog_warn!(
1457 "corrupt semantic index (too small: {} bytes), removing",
1458 file_len
1459 );
1460 let _ = fs::remove_file(&data_path);
1461 return None;
1462 }
1463
1464 let bytes = fs::read(&data_path).ok()?;
1465 let version = bytes[0];
1466 if version != SEMANTIC_INDEX_VERSION_V5 {
1467 slog_info!(
1468 "cached semantic index version {} is older than {}, rebuilding",
1469 version,
1470 SEMANTIC_INDEX_VERSION_V5
1471 );
1472 let _ = fs::remove_file(&data_path);
1473 return None;
1474 }
1475 match Self::from_bytes(&bytes) {
1476 Ok(index) => {
1477 if index.entries.is_empty() {
1478 slog_info!("cached semantic index is empty, will rebuild");
1479 let _ = fs::remove_file(&data_path);
1480 return None;
1481 }
1482 if let Some(expected) = expected_fingerprint {
1483 let matches = index
1484 .fingerprint()
1485 .map(|fingerprint| fingerprint.matches_expected(expected))
1486 .unwrap_or(false);
1487 if !matches {
1488 slog_info!("cached semantic index fingerprint mismatch, rebuilding");
1489 let _ = fs::remove_file(&data_path);
1490 return None;
1491 }
1492 }
1493 slog_info!(
1494 "loaded semantic index from disk: {} entries",
1495 index.entries.len()
1496 );
1497 Some(index)
1498 }
1499 Err(e) => {
1500 slog_warn!("corrupt semantic index, rebuilding: {}", e);
1501 let _ = fs::remove_file(&data_path);
1502 None
1503 }
1504 }
1505 }
1506
1507 pub fn to_bytes(&self) -> Vec<u8> {
1509 let mut buf = Vec::new();
1510 let fingerprint_bytes = self.fingerprint.as_ref().and_then(|fingerprint| {
1511 let encoded = fingerprint.as_string();
1512 if encoded.is_empty() {
1513 None
1514 } else {
1515 Some(encoded.into_bytes())
1516 }
1517 });
1518
1519 let version = SEMANTIC_INDEX_VERSION_V5;
1531 buf.push(version);
1532 buf.extend_from_slice(&(self.dimension as u32).to_le_bytes());
1533 buf.extend_from_slice(&(self.entries.len() as u32).to_le_bytes());
1534 let fp_bytes_ref: &[u8] = fingerprint_bytes.as_deref().unwrap_or(&[]);
1535 buf.extend_from_slice(&(fp_bytes_ref.len() as u32).to_le_bytes());
1536 buf.extend_from_slice(fp_bytes_ref);
1537
1538 buf.extend_from_slice(&(self.file_mtimes.len() as u32).to_le_bytes());
1541 for (path, mtime) in &self.file_mtimes {
1542 let path_bytes = path.to_string_lossy().as_bytes().to_vec();
1543 buf.extend_from_slice(&(path_bytes.len() as u32).to_le_bytes());
1544 buf.extend_from_slice(&path_bytes);
1545 let duration = mtime
1546 .duration_since(SystemTime::UNIX_EPOCH)
1547 .unwrap_or_default();
1548 buf.extend_from_slice(&duration.as_secs().to_le_bytes());
1549 buf.extend_from_slice(&duration.subsec_nanos().to_le_bytes());
1550 let size = self.file_sizes.get(path).copied().unwrap_or_default();
1551 buf.extend_from_slice(&size.to_le_bytes());
1552 }
1553
1554 for entry in &self.entries {
1556 let c = &entry.chunk;
1557
1558 let file_bytes = c.file.to_string_lossy().as_bytes().to_vec();
1560 buf.extend_from_slice(&(file_bytes.len() as u32).to_le_bytes());
1561 buf.extend_from_slice(&file_bytes);
1562
1563 let name_bytes = c.name.as_bytes();
1565 buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
1566 buf.extend_from_slice(name_bytes);
1567
1568 buf.push(symbol_kind_to_u8(&c.kind));
1570
1571 buf.extend_from_slice(&(c.start_line as u32).to_le_bytes());
1573 buf.extend_from_slice(&(c.end_line as u32).to_le_bytes());
1574 buf.push(c.exported as u8);
1575
1576 let snippet_bytes = c.snippet.as_bytes();
1578 buf.extend_from_slice(&(snippet_bytes.len() as u32).to_le_bytes());
1579 buf.extend_from_slice(snippet_bytes);
1580
1581 let embed_bytes = c.embed_text.as_bytes();
1583 buf.extend_from_slice(&(embed_bytes.len() as u32).to_le_bytes());
1584 buf.extend_from_slice(embed_bytes);
1585
1586 for &val in &entry.vector {
1588 buf.extend_from_slice(&val.to_le_bytes());
1589 }
1590 }
1591
1592 buf
1593 }
1594
1595 pub fn from_bytes(data: &[u8]) -> Result<Self, String> {
1597 let mut pos = 0;
1598
1599 if data.len() < HEADER_BYTES_V1 {
1600 return Err("data too short".to_string());
1601 }
1602
1603 let version = data[pos];
1604 pos += 1;
1605 if version != SEMANTIC_INDEX_VERSION_V1
1606 && version != SEMANTIC_INDEX_VERSION_V2
1607 && version != SEMANTIC_INDEX_VERSION_V3
1608 && version != SEMANTIC_INDEX_VERSION_V4
1609 && version != SEMANTIC_INDEX_VERSION_V5
1610 {
1611 return Err(format!("unsupported version: {}", version));
1612 }
1613 if (version == SEMANTIC_INDEX_VERSION_V2
1617 || version == SEMANTIC_INDEX_VERSION_V3
1618 || version == SEMANTIC_INDEX_VERSION_V4
1619 || version == SEMANTIC_INDEX_VERSION_V5)
1620 && data.len() < HEADER_BYTES_V2
1621 {
1622 return Err("data too short for semantic index v2/v3/v4/v5 header".to_string());
1623 }
1624
1625 let dimension = read_u32(data, &mut pos)? as usize;
1626 let entry_count = read_u32(data, &mut pos)? as usize;
1627 if dimension == 0 || dimension > MAX_DIMENSION {
1628 return Err(format!("invalid embedding dimension: {}", dimension));
1629 }
1630 if entry_count > MAX_ENTRIES {
1631 return Err(format!("too many semantic index entries: {}", entry_count));
1632 }
1633
1634 let has_fingerprint_field = version == SEMANTIC_INDEX_VERSION_V2
1640 || version == SEMANTIC_INDEX_VERSION_V3
1641 || version == SEMANTIC_INDEX_VERSION_V4
1642 || version == SEMANTIC_INDEX_VERSION_V5;
1643 let fingerprint = if has_fingerprint_field {
1644 let fingerprint_len = read_u32(data, &mut pos)? as usize;
1645 if pos + fingerprint_len > data.len() {
1646 return Err("unexpected end of data reading fingerprint".to_string());
1647 }
1648 if fingerprint_len == 0 {
1649 None
1650 } else {
1651 let raw = String::from_utf8_lossy(&data[pos..pos + fingerprint_len]).to_string();
1652 pos += fingerprint_len;
1653 Some(
1654 serde_json::from_str::<SemanticIndexFingerprint>(&raw)
1655 .map_err(|error| format!("invalid semantic fingerprint: {error}"))?,
1656 )
1657 }
1658 } else {
1659 None
1660 };
1661
1662 let mtime_count = read_u32(data, &mut pos)? as usize;
1664 if mtime_count > MAX_ENTRIES {
1665 return Err(format!("too many semantic file mtimes: {}", mtime_count));
1666 }
1667
1668 let vector_bytes = entry_count
1669 .checked_mul(dimension)
1670 .and_then(|count| count.checked_mul(F32_BYTES))
1671 .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
1672 if vector_bytes > data.len().saturating_sub(pos) {
1673 return Err("semantic index vectors exceed available data".to_string());
1674 }
1675
1676 let mut file_mtimes = HashMap::with_capacity(mtime_count);
1677 let mut file_sizes = HashMap::with_capacity(mtime_count);
1678 for _ in 0..mtime_count {
1679 let path = read_string(data, &mut pos)?;
1680 let secs = read_u64(data, &mut pos)?;
1681 let nanos = if version == SEMANTIC_INDEX_VERSION_V3
1687 || version == SEMANTIC_INDEX_VERSION_V4
1688 || version == SEMANTIC_INDEX_VERSION_V5
1689 {
1690 read_u32(data, &mut pos)?
1691 } else {
1692 0
1693 };
1694 let size = if version == SEMANTIC_INDEX_VERSION_V5 {
1695 read_u64(data, &mut pos)?
1696 } else {
1697 0
1698 };
1699 if nanos >= 1_000_000_000 {
1706 return Err(format!(
1707 "invalid semantic mtime: nanos {} >= 1_000_000_000",
1708 nanos
1709 ));
1710 }
1711 let duration = std::time::Duration::new(secs, nanos);
1712 let mtime = SystemTime::UNIX_EPOCH
1713 .checked_add(duration)
1714 .ok_or_else(|| {
1715 format!(
1716 "invalid semantic mtime: secs={} nanos={} overflows SystemTime",
1717 secs, nanos
1718 )
1719 })?;
1720 let path = PathBuf::from(path);
1721 file_mtimes.insert(path.clone(), mtime);
1722 file_sizes.insert(path, size);
1723 }
1724
1725 let mut entries = Vec::with_capacity(entry_count);
1727 for _ in 0..entry_count {
1728 let file = PathBuf::from(read_string(data, &mut pos)?);
1729 let name = read_string(data, &mut pos)?;
1730
1731 if pos >= data.len() {
1732 return Err("unexpected end of data".to_string());
1733 }
1734 let kind = u8_to_symbol_kind(data[pos]);
1735 pos += 1;
1736
1737 let start_line = read_u32(data, &mut pos)?;
1738 let end_line = read_u32(data, &mut pos)?;
1739
1740 if pos >= data.len() {
1741 return Err("unexpected end of data".to_string());
1742 }
1743 let exported = data[pos] != 0;
1744 pos += 1;
1745
1746 let snippet = read_string(data, &mut pos)?;
1747 let embed_text = read_string(data, &mut pos)?;
1748
1749 let vec_bytes = dimension
1751 .checked_mul(F32_BYTES)
1752 .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
1753 if pos + vec_bytes > data.len() {
1754 return Err("unexpected end of data reading vector".to_string());
1755 }
1756 let mut vector = Vec::with_capacity(dimension);
1757 for _ in 0..dimension {
1758 let bytes = [data[pos], data[pos + 1], data[pos + 2], data[pos + 3]];
1759 vector.push(f32::from_le_bytes(bytes));
1760 pos += 4;
1761 }
1762
1763 entries.push(EmbeddingEntry {
1764 chunk: SemanticChunk {
1765 file,
1766 name,
1767 kind,
1768 start_line,
1769 end_line,
1770 exported,
1771 embed_text,
1772 snippet,
1773 },
1774 vector,
1775 });
1776 }
1777
1778 Ok(Self {
1779 entries,
1780 file_mtimes,
1781 file_sizes,
1782 dimension,
1783 fingerprint,
1784 })
1785 }
1786}
1787
1788fn build_embed_text(symbol: &Symbol, source: &str, file: &Path, project_root: &Path) -> String {
1790 let relative = file
1791 .strip_prefix(project_root)
1792 .unwrap_or(file)
1793 .to_string_lossy();
1794
1795 let kind_label = match &symbol.kind {
1796 SymbolKind::Function => "function",
1797 SymbolKind::Class => "class",
1798 SymbolKind::Method => "method",
1799 SymbolKind::Struct => "struct",
1800 SymbolKind::Interface => "interface",
1801 SymbolKind::Enum => "enum",
1802 SymbolKind::TypeAlias => "type",
1803 SymbolKind::Variable => "variable",
1804 SymbolKind::Heading => "heading",
1805 };
1806
1807 let mut text = format!("file:{} kind:{} name:{}", relative, kind_label, symbol.name);
1809
1810 if let Some(sig) = &symbol.signature {
1811 text.push_str(&format!(" signature:{}", sig));
1812 }
1813
1814 let lines: Vec<&str> = source.lines().collect();
1816 let start = (symbol.range.start_line as usize).min(lines.len());
1817 let end = (symbol.range.end_line as usize + 1).min(lines.len());
1819 if start < end {
1820 let body: String = lines[start..end]
1821 .iter()
1822 .take(15) .copied()
1824 .collect::<Vec<&str>>()
1825 .join("\n");
1826 let snippet = if body.len() > 300 {
1827 format!("{}...", &body[..body.floor_char_boundary(300)])
1828 } else {
1829 body
1830 };
1831 text.push_str(&format!(" body:{}", snippet));
1832 }
1833
1834 text
1835}
1836
1837fn parser_for(
1838 parsers: &mut HashMap<crate::parser::LangId, Parser>,
1839 lang: crate::parser::LangId,
1840) -> Result<&mut Parser, String> {
1841 use std::collections::hash_map::Entry;
1842
1843 match parsers.entry(lang) {
1844 Entry::Occupied(entry) => Ok(entry.into_mut()),
1845 Entry::Vacant(entry) => {
1846 let grammar = grammar_for(lang);
1847 let mut parser = Parser::new();
1848 parser
1849 .set_language(&grammar)
1850 .map_err(|error| error.to_string())?;
1851 Ok(entry.insert(parser))
1852 }
1853 }
1854}
1855
1856fn collect_file_metadata(file: &Path) -> Result<IndexedFileMetadata, String> {
1857 let metadata = fs::metadata(file).map_err(|error| error.to_string())?;
1858 let mtime = metadata.modified().map_err(|error| error.to_string())?;
1859 Ok(IndexedFileMetadata {
1860 mtime,
1861 size: metadata.len(),
1862 })
1863}
1864
1865fn collect_file_chunks(
1866 project_root: &Path,
1867 file: &Path,
1868 parsers: &mut HashMap<crate::parser::LangId, Parser>,
1869) -> Result<Vec<SemanticChunk>, String> {
1870 let lang = detect_language(file).ok_or_else(|| "unsupported file extension".to_string())?;
1871 let source = std::fs::read_to_string(file).map_err(|error| error.to_string())?;
1872 let tree = parser_for(parsers, lang)?
1873 .parse(&source, None)
1874 .ok_or_else(|| format!("tree-sitter parse returned None for {}", file.display()))?;
1875 let symbols =
1876 extract_symbols_from_tree(&source, &tree, lang).map_err(|error| error.to_string())?;
1877
1878 Ok(symbols_to_chunks(file, &symbols, &source, project_root))
1879}
1880
1881fn build_snippet(symbol: &Symbol, source: &str) -> String {
1883 let lines: Vec<&str> = source.lines().collect();
1884 let start = (symbol.range.start_line as usize).min(lines.len());
1885 let end = (symbol.range.end_line as usize + 1).min(lines.len());
1887 if start < end {
1888 let snippet_lines: Vec<&str> = lines[start..end].iter().take(5).copied().collect();
1889 let mut snippet = snippet_lines.join("\n");
1890 if end - start > 5 {
1891 snippet.push_str("\n ...");
1892 }
1893 if snippet.len() > 300 {
1894 snippet = format!("{}...", &snippet[..snippet.floor_char_boundary(300)]);
1895 }
1896 snippet
1897 } else {
1898 String::new()
1899 }
1900}
1901
1902fn symbols_to_chunks(
1904 file: &Path,
1905 symbols: &[Symbol],
1906 source: &str,
1907 project_root: &Path,
1908) -> Vec<SemanticChunk> {
1909 let mut chunks = Vec::new();
1910
1911 for symbol in symbols {
1912 if matches!(symbol.kind, SymbolKind::Heading) {
1917 continue;
1918 }
1919
1920 let line_count = symbol
1922 .range
1923 .end_line
1924 .saturating_sub(symbol.range.start_line)
1925 + 1;
1926 if line_count < 2 && !matches!(symbol.kind, SymbolKind::Variable) {
1927 continue;
1928 }
1929
1930 let embed_text = build_embed_text(symbol, source, file, project_root);
1931 let snippet = build_snippet(symbol, source);
1932
1933 chunks.push(SemanticChunk {
1934 file: file.to_path_buf(),
1935 name: symbol.name.clone(),
1936 kind: symbol.kind.clone(),
1937 start_line: symbol.range.start_line,
1938 end_line: symbol.range.end_line,
1939 exported: symbol.exported,
1940 embed_text,
1941 snippet,
1942 });
1943
1944 }
1947
1948 chunks
1949}
1950
/// Cosine similarity of `a` and `b`: `dot(a, b) / (|a| * |b|)`.
///
/// Returns 0.0 when the slices differ in length or when either vector has zero norm,
/// so the result is never a division by zero.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Single pass with zipped iterators: no per-index bounds checks, and the same
    // left-to-right accumulation order as an index loop.
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, norm_a, norm_b), (&x, &y)| (dot + x * y, norm_a + x * x, norm_b + y * y),
    );

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
1974
1975fn symbol_kind_to_u8(kind: &SymbolKind) -> u8 {
1977 match kind {
1978 SymbolKind::Function => 0,
1979 SymbolKind::Class => 1,
1980 SymbolKind::Method => 2,
1981 SymbolKind::Struct => 3,
1982 SymbolKind::Interface => 4,
1983 SymbolKind::Enum => 5,
1984 SymbolKind::TypeAlias => 6,
1985 SymbolKind::Variable => 7,
1986 SymbolKind::Heading => 8,
1987 }
1988}
1989
1990fn u8_to_symbol_kind(v: u8) -> SymbolKind {
1991 match v {
1992 0 => SymbolKind::Function,
1993 1 => SymbolKind::Class,
1994 2 => SymbolKind::Method,
1995 3 => SymbolKind::Struct,
1996 4 => SymbolKind::Interface,
1997 5 => SymbolKind::Enum,
1998 6 => SymbolKind::TypeAlias,
1999 7 => SymbolKind::Variable,
2000 _ => SymbolKind::Heading,
2001 }
2002}
2003
/// Reads a little-endian `u32` at `*pos` and advances `pos` by 4.
///
/// On insufficient data, returns an error and leaves `pos` unchanged.
fn read_u32(data: &[u8], pos: &mut usize) -> Result<u32, String> {
    match data.get(*pos..*pos + 4) {
        Some(&[b0, b1, b2, b3]) => {
            *pos += 4;
            Ok(u32::from_le_bytes([b0, b1, b2, b3]))
        }
        _ => Err("unexpected end of data reading u32".to_string()),
    }
}
2012
/// Reads a little-endian `u64` at `*pos` and advances `pos` by 8.
///
/// On insufficient data, returns an error and leaves `pos` unchanged.
fn read_u64(data: &[u8], pos: &mut usize) -> Result<u64, String> {
    let window = data
        .get(*pos..*pos + 8)
        .ok_or_else(|| "unexpected end of data reading u64".to_string())?;
    let mut raw = [0u8; 8];
    raw.copy_from_slice(window);
    *pos += 8;
    Ok(u64::from_le_bytes(raw))
}
2021
2022fn read_string(data: &[u8], pos: &mut usize) -> Result<String, String> {
2023 let len = read_u32(data, pos)? as usize;
2024 if *pos + len > data.len() {
2025 return Err("unexpected end of data reading string".to_string());
2026 }
2027 let s = String::from_utf8_lossy(&data[*pos..*pos + len]).to_string();
2028 *pos += len;
2029 Ok(s)
2030}
2031
2032#[cfg(test)]
2033mod tests {
2034 use super::*;
2035 use crate::config::{SemanticBackend, SemanticBackendConfig};
2036 use crate::parser::FileParser;
2037 use std::io::{Read, Write};
2038 use std::net::TcpListener;
2039 use std::thread;
2040
2041 fn start_mock_http_server<F>(handler: F) -> (String, thread::JoinHandle<()>)
2042 where
2043 F: Fn(String, String, String) -> String + Send + 'static,
2044 {
2045 let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
2046 let addr = listener.local_addr().expect("local addr");
2047 let handle = thread::spawn(move || {
2048 let (mut stream, _) = listener.accept().expect("accept request");
2049 let mut buf = Vec::new();
2050 let mut chunk = [0u8; 4096];
2051 let mut header_end = None;
2052 let mut content_length = 0usize;
2053 loop {
2054 let n = stream.read(&mut chunk).expect("read request");
2055 if n == 0 {
2056 break;
2057 }
2058 buf.extend_from_slice(&chunk[..n]);
2059 if header_end.is_none() {
2060 if let Some(pos) = buf.windows(4).position(|window| window == b"\r\n\r\n") {
2061 header_end = Some(pos + 4);
2062 let headers = String::from_utf8_lossy(&buf[..pos + 4]);
2063 for line in headers.lines() {
2064 if let Some(value) = line.strip_prefix("Content-Length:") {
2065 content_length = value.trim().parse::<usize>().unwrap_or(0);
2066 }
2067 }
2068 }
2069 }
2070 if let Some(end) = header_end {
2071 if buf.len() >= end + content_length {
2072 break;
2073 }
2074 }
2075 }
2076
2077 let end = header_end.expect("header terminator");
2078 let request = String::from_utf8_lossy(&buf[..end]).to_string();
2079 let body = String::from_utf8_lossy(&buf[end..end + content_length]).to_string();
2080 let mut lines = request.lines();
2081 let request_line = lines.next().expect("request line").to_string();
2082 let path = request_line
2083 .split_whitespace()
2084 .nth(1)
2085 .expect("request path")
2086 .to_string();
2087 let response_body = handler(request_line, path, body);
2088 let response = format!(
2089 "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
2090 response_body.len(),
2091 response_body
2092 );
2093 stream
2094 .write_all(response.as_bytes())
2095 .expect("write response");
2096 });
2097
2098 (format!("http://{}", addr), handle)
2099 }
2100
2101 fn test_vector_for_texts(texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
2102 Ok(texts.iter().map(|_| vec![1.0, 0.0, 0.0]).collect())
2103 }
2104
2105 fn write_rust_file(path: &Path, function_name: &str) {
2106 fs::write(
2107 path,
2108 format!("pub fn {function_name}() -> bool {{\n true\n}}\n"),
2109 )
2110 .unwrap();
2111 }
2112
2113 fn build_test_index(project_root: &Path, files: &[PathBuf]) -> SemanticIndex {
2114 let mut embed = test_vector_for_texts;
2115 SemanticIndex::build(project_root, files, &mut embed, 8).unwrap()
2116 }
2117
2118 fn set_file_metadata(index: &mut SemanticIndex, file: &Path, mtime: SystemTime, size: u64) {
2119 index.file_mtimes.insert(file.to_path_buf(), mtime);
2120 index.file_sizes.insert(file.to_path_buf(), size);
2121 }
2122
2123 #[test]
2124 fn test_cosine_similarity_identical() {
2125 let a = vec![1.0, 0.0, 0.0];
2126 let b = vec![1.0, 0.0, 0.0];
2127 assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
2128 }
2129
2130 #[test]
2131 fn test_cosine_similarity_orthogonal() {
2132 let a = vec![1.0, 0.0, 0.0];
2133 let b = vec![0.0, 1.0, 0.0];
2134 assert!(cosine_similarity(&a, &b).abs() < 0.001);
2135 }
2136
2137 #[test]
2138 fn test_cosine_similarity_opposite() {
2139 let a = vec![1.0, 0.0, 0.0];
2140 let b = vec![-1.0, 0.0, 0.0];
2141 assert!((cosine_similarity(&a, &b) + 1.0).abs() < 0.001);
2142 }
2143
2144 #[test]
2145 fn test_serialization_roundtrip() {
2146 let mut index = SemanticIndex::new();
2147 index.entries.push(EmbeddingEntry {
2148 chunk: SemanticChunk {
2149 file: PathBuf::from("/src/main.rs"),
2150 name: "handle_request".to_string(),
2151 kind: SymbolKind::Function,
2152 start_line: 10,
2153 end_line: 25,
2154 exported: true,
2155 embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
2156 snippet: "fn handle_request() {\n // ...\n}".to_string(),
2157 },
2158 vector: vec![0.1, 0.2, 0.3, 0.4],
2159 });
2160 index.dimension = 4;
2161 index
2162 .file_mtimes
2163 .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
2164 index.file_sizes.insert(PathBuf::from("/src/main.rs"), 0);
2165 index.set_fingerprint(SemanticIndexFingerprint {
2166 backend: "fastembed".to_string(),
2167 model: "all-MiniLM-L6-v2".to_string(),
2168 base_url: FALLBACK_BACKEND.to_string(),
2169 dimension: 4,
2170 });
2171
2172 let bytes = index.to_bytes();
2173 let restored = SemanticIndex::from_bytes(&bytes).unwrap();
2174
2175 assert_eq!(restored.entries.len(), 1);
2176 assert_eq!(restored.entries[0].chunk.name, "handle_request");
2177 assert_eq!(restored.entries[0].vector, vec![0.1, 0.2, 0.3, 0.4]);
2178 assert_eq!(restored.dimension, 4);
2179 assert_eq!(restored.backend_label(), Some("fastembed"));
2180 assert_eq!(restored.model_label(), Some("all-MiniLM-L6-v2"));
2181 }
2182
2183 #[test]
2184 fn test_search_top_k() {
2185 let mut index = SemanticIndex::new();
2186 index.dimension = 3;
2187
2188 for (i, name) in ["auth", "database", "handler"].iter().enumerate() {
2190 let mut vec = vec![0.0f32; 3];
2191 vec[i] = 1.0; index.entries.push(EmbeddingEntry {
2193 chunk: SemanticChunk {
2194 file: PathBuf::from("/src/lib.rs"),
2195 name: name.to_string(),
2196 kind: SymbolKind::Function,
2197 start_line: (i * 10 + 1) as u32,
2198 end_line: (i * 10 + 5) as u32,
2199 exported: true,
2200 embed_text: format!("kind:function name:{}", name),
2201 snippet: format!("fn {}() {{}}", name),
2202 },
2203 vector: vec,
2204 });
2205 }
2206
2207 let query = vec![0.9, 0.1, 0.0];
2209 let results = index.search(&query, 2);
2210
2211 assert_eq!(results.len(), 2);
2212 assert_eq!(results[0].name, "auth"); assert!(results[0].score > results[1].score);
2214 }
2215
2216 #[test]
2217 fn test_empty_index_search() {
2218 let index = SemanticIndex::new();
2219 let results = index.search(&[0.1, 0.2, 0.3], 10);
2220 assert!(results.is_empty());
2221 }
2222
2223 #[test]
2224 fn single_line_symbol_builds_non_empty_snippet() {
2225 let symbol = Symbol {
2226 name: "answer".to_string(),
2227 kind: SymbolKind::Variable,
2228 range: crate::symbols::Range {
2229 start_line: 0,
2230 start_col: 0,
2231 end_line: 0,
2232 end_col: 24,
2233 },
2234 signature: Some("const answer = 42".to_string()),
2235 scope_chain: Vec::new(),
2236 exported: true,
2237 parent: None,
2238 };
2239 let source = "export const answer = 42;\n";
2240
2241 let snippet = build_snippet(&symbol, source);
2242
2243 assert_eq!(snippet, "export const answer = 42;");
2244 }
2245
2246 #[test]
2247 fn optimized_file_chunk_collection_matches_file_parser_path() {
2248 let project_root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
2249 let file = project_root.join("src/semantic_index.rs");
2250 let source = std::fs::read_to_string(&file).unwrap();
2251
2252 let mut legacy_parser = FileParser::new();
2253 let legacy_symbols = legacy_parser.extract_symbols(&file).unwrap();
2254 let legacy_chunks = symbols_to_chunks(&file, &legacy_symbols, &source, &project_root);
2255
2256 let mut parsers = HashMap::new();
2257 let optimized_chunks = collect_file_chunks(&project_root, &file, &mut parsers).unwrap();
2258
2259 assert_eq!(
2260 chunk_fingerprint(&optimized_chunks),
2261 chunk_fingerprint(&legacy_chunks)
2262 );
2263 }
2264
2265 fn chunk_fingerprint(
2266 chunks: &[SemanticChunk],
2267 ) -> Vec<(String, SymbolKind, u32, u32, bool, String, String)> {
2268 chunks
2269 .iter()
2270 .map(|chunk| {
2271 (
2272 chunk.name.clone(),
2273 chunk.kind.clone(),
2274 chunk.start_line,
2275 chunk.end_line,
2276 chunk.exported,
2277 chunk.embed_text.clone(),
2278 chunk.snippet.clone(),
2279 )
2280 })
2281 .collect()
2282 }
2283
2284 #[test]
2285 fn rejects_oversized_dimension_during_deserialization() {
2286 let mut bytes = Vec::new();
2287 bytes.push(1u8);
2288 bytes.extend_from_slice(&((MAX_DIMENSION as u32) + 1).to_le_bytes());
2289 bytes.extend_from_slice(&0u32.to_le_bytes());
2290 bytes.extend_from_slice(&0u32.to_le_bytes());
2291
2292 assert!(SemanticIndex::from_bytes(&bytes).is_err());
2293 }
2294
2295 #[test]
2296 fn rejects_oversized_entry_count_during_deserialization() {
2297 let mut bytes = Vec::new();
2298 bytes.push(1u8);
2299 bytes.extend_from_slice(&(DEFAULT_DIMENSION as u32).to_le_bytes());
2300 bytes.extend_from_slice(&((MAX_ENTRIES as u32) + 1).to_le_bytes());
2301 bytes.extend_from_slice(&0u32.to_le_bytes());
2302
2303 assert!(SemanticIndex::from_bytes(&bytes).is_err());
2304 }
2305
2306 #[test]
2307 fn invalidate_file_removes_entries_and_mtime() {
2308 let target = PathBuf::from("/src/main.rs");
2309 let mut index = SemanticIndex::new();
2310 index.entries.push(EmbeddingEntry {
2311 chunk: SemanticChunk {
2312 file: target.clone(),
2313 name: "main".to_string(),
2314 kind: SymbolKind::Function,
2315 start_line: 0,
2316 end_line: 1,
2317 exported: false,
2318 embed_text: "main".to_string(),
2319 snippet: "fn main() {}".to_string(),
2320 },
2321 vector: vec![1.0; DEFAULT_DIMENSION],
2322 });
2323 index
2324 .file_mtimes
2325 .insert(target.clone(), SystemTime::UNIX_EPOCH);
2326 index.file_sizes.insert(target.clone(), 0);
2327
2328 index.invalidate_file(&target);
2329
2330 assert!(index.entries.is_empty());
2331 assert!(!index.file_mtimes.contains_key(&target));
2332 assert!(!index.file_sizes.contains_key(&target));
2333 }
2334
2335 #[test]
2336 fn refresh_transient_error_preserves_existing_entry_and_mtime() {
2337 let temp = tempfile::tempdir().unwrap();
2338 let project_root = temp.path();
2339 let file = project_root.join("src/lib.rs");
2340 fs::create_dir_all(file.parent().unwrap()).unwrap();
2341 write_rust_file(&file, "kept_symbol");
2342
2343 let mut index = build_test_index(project_root, std::slice::from_ref(&file));
2344 let original_entry_count = index.entries.len();
2345 let original_mtime = *index.file_mtimes.get(&file).unwrap();
2346 let original_size = *index.file_sizes.get(&file).unwrap();
2347
2348 let stale_mtime = SystemTime::UNIX_EPOCH;
2349 set_file_metadata(&mut index, &file, stale_mtime, original_size + 1);
2350 fs::remove_file(&file).unwrap();
2351
2352 let mut embed = test_vector_for_texts;
2353 let mut progress = |_done: usize, _total: usize| {};
2354 let summary = index
2355 .refresh_stale_files(
2356 project_root,
2357 std::slice::from_ref(&file),
2358 &mut embed,
2359 8,
2360 &mut progress,
2361 )
2362 .unwrap();
2363
2364 assert_eq!(summary.changed, 0);
2365 assert_eq!(summary.added, 0);
2366 assert_eq!(summary.deleted, 0);
2367 assert_eq!(index.entries.len(), original_entry_count);
2368 assert_eq!(index.entries[0].chunk.name, "kept_symbol");
2369 assert_eq!(index.file_mtimes.get(&file), Some(&stale_mtime));
2370 assert_ne!(index.file_mtimes.get(&file), Some(&original_mtime));
2371 assert_eq!(index.file_sizes.get(&file), Some(&(original_size + 1)));
2372 }
2373
2374 #[test]
2375 fn refresh_never_indexed_file_error_does_not_record_mtime() {
2376 let temp = tempfile::tempdir().unwrap();
2377 let project_root = temp.path();
2378 let missing = project_root.join("src/missing.rs");
2379 fs::create_dir_all(missing.parent().unwrap()).unwrap();
2380
2381 let mut index = SemanticIndex::new();
2382 let mut embed = test_vector_for_texts;
2383 let mut progress = |_done: usize, _total: usize| {};
2384 let summary = index
2385 .refresh_stale_files(
2386 project_root,
2387 std::slice::from_ref(&missing),
2388 &mut embed,
2389 8,
2390 &mut progress,
2391 )
2392 .unwrap();
2393
2394 assert_eq!(summary.added, 0);
2395 assert_eq!(summary.changed, 0);
2396 assert_eq!(summary.deleted, 0);
2397 assert!(!index.file_mtimes.contains_key(&missing));
2398 assert!(!index.file_sizes.contains_key(&missing));
2399 assert!(index.entries.is_empty());
2400 }
2401
2402 #[test]
2403 fn refresh_reports_added_for_new_files() {
2404 let temp = tempfile::tempdir().unwrap();
2405 let project_root = temp.path();
2406 let existing = project_root.join("src/lib.rs");
2407 let added = project_root.join("src/new.rs");
2408 fs::create_dir_all(existing.parent().unwrap()).unwrap();
2409 write_rust_file(&existing, "existing_symbol");
2410 write_rust_file(&added, "added_symbol");
2411
2412 let mut index = build_test_index(project_root, std::slice::from_ref(&existing));
2413 let mut embed = test_vector_for_texts;
2414 let mut progress = |_done: usize, _total: usize| {};
2415 let summary = index
2416 .refresh_stale_files(
2417 project_root,
2418 &[existing.clone(), added.clone()],
2419 &mut embed,
2420 8,
2421 &mut progress,
2422 )
2423 .unwrap();
2424
2425 assert_eq!(summary.added, 1);
2426 assert_eq!(summary.changed, 0);
2427 assert_eq!(summary.deleted, 0);
2428 assert_eq!(summary.total_processed, 2);
2429 assert!(index.file_mtimes.contains_key(&added));
2430 assert!(index.entries.iter().any(|entry| entry.chunk.file == added));
2431 }
2432
2433 #[test]
2434 fn refresh_reports_deleted_for_removed_files() {
2435 let temp = tempfile::tempdir().unwrap();
2436 let project_root = temp.path();
2437 let deleted = project_root.join("src/deleted.rs");
2438 fs::create_dir_all(deleted.parent().unwrap()).unwrap();
2439 write_rust_file(&deleted, "deleted_symbol");
2440
2441 let mut index = build_test_index(project_root, std::slice::from_ref(&deleted));
2442 fs::remove_file(&deleted).unwrap();
2443
2444 let mut embed = test_vector_for_texts;
2445 let mut progress = |_done: usize, _total: usize| {};
2446 let summary = index
2447 .refresh_stale_files(project_root, &[], &mut embed, 8, &mut progress)
2448 .unwrap();
2449
2450 assert_eq!(summary.deleted, 1);
2451 assert_eq!(summary.changed, 0);
2452 assert_eq!(summary.added, 0);
2453 assert_eq!(summary.total_processed, 1);
2454 assert!(!index.file_mtimes.contains_key(&deleted));
2455 assert!(index.entries.is_empty());
2456 }
2457
2458 #[test]
2459 fn refresh_reports_changed_for_modified_files() {
2460 let temp = tempfile::tempdir().unwrap();
2461 let project_root = temp.path();
2462 let file = project_root.join("src/lib.rs");
2463 fs::create_dir_all(file.parent().unwrap()).unwrap();
2464 write_rust_file(&file, "old_symbol");
2465
2466 let mut index = build_test_index(project_root, std::slice::from_ref(&file));
2467 set_file_metadata(&mut index, &file, SystemTime::UNIX_EPOCH, 0);
2468 write_rust_file(&file, "new_symbol");
2469
2470 let mut embed = test_vector_for_texts;
2471 let mut progress = |_done: usize, _total: usize| {};
2472 let summary = index
2473 .refresh_stale_files(
2474 project_root,
2475 std::slice::from_ref(&file),
2476 &mut embed,
2477 8,
2478 &mut progress,
2479 )
2480 .unwrap();
2481
2482 assert_eq!(summary.changed, 1);
2483 assert_eq!(summary.added, 0);
2484 assert_eq!(summary.deleted, 0);
2485 assert_eq!(summary.total_processed, 1);
2486 assert!(index
2487 .entries
2488 .iter()
2489 .any(|entry| entry.chunk.name == "new_symbol"));
2490 assert!(!index
2491 .entries
2492 .iter()
2493 .any(|entry| entry.chunk.name == "old_symbol"));
2494 }
2495
2496 #[test]
2497 fn refresh_all_clean_reports_zero_counts_and_no_embedding_work() {
2498 let temp = tempfile::tempdir().unwrap();
2499 let project_root = temp.path();
2500 let file = project_root.join("src/lib.rs");
2501 fs::create_dir_all(file.parent().unwrap()).unwrap();
2502 write_rust_file(&file, "clean_symbol");
2503
2504 let mut index = build_test_index(project_root, std::slice::from_ref(&file));
2505 let original_entries = index.entries.len();
2506 let mut embed_called = false;
2507 let mut embed = |texts: Vec<String>| {
2508 embed_called = true;
2509 test_vector_for_texts(texts)
2510 };
2511 let mut progress = |_done: usize, _total: usize| {};
2512 let summary = index
2513 .refresh_stale_files(
2514 project_root,
2515 std::slice::from_ref(&file),
2516 &mut embed,
2517 8,
2518 &mut progress,
2519 )
2520 .unwrap();
2521
2522 assert!(summary.is_noop());
2523 assert_eq!(summary.total_processed, 1);
2524 assert!(!embed_called);
2525 assert_eq!(index.entries.len(), original_entries);
2526 }
2527
2528 #[test]
2529 fn detects_missing_onnx_runtime_from_dynamic_load_error() {
2530 let message = "Failed to load ONNX Runtime shared library libonnxruntime.dylib via dlopen: no such file";
2531
2532 assert!(is_onnx_runtime_unavailable(message));
2533 }
2534
2535 #[test]
2536 fn formats_missing_onnx_runtime_with_install_hint() {
2537 let message = format_embedding_init_error(
2538 "Failed to load ONNX Runtime shared library libonnxruntime.so via dlopen: no such file",
2539 );
2540
2541 assert!(message.starts_with("ONNX Runtime not found. Install via:"));
2542 assert!(message.contains("Original error:"));
2543 }
2544
2545 #[test]
2546 fn openai_compatible_backend_embeds_with_mock_server() {
2547 let (base_url, handle) = start_mock_http_server(|request_line, path, _body| {
2548 assert!(request_line.starts_with("POST "));
2549 assert_eq!(path, "/v1/embeddings");
2550 "{\"data\":[{\"embedding\":[0.1,0.2,0.3],\"index\":0},{\"embedding\":[0.4,0.5,0.6],\"index\":1}]}".to_string()
2551 });
2552
2553 let config = SemanticBackendConfig {
2554 backend: SemanticBackend::OpenAiCompatible,
2555 model: "test-embedding".to_string(),
2556 base_url: Some(base_url),
2557 api_key_env: None,
2558 timeout_ms: 5_000,
2559 max_batch_size: 64,
2560 };
2561
2562 let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();
2563 let vectors = model
2564 .embed(vec!["hello".to_string(), "world".to_string()])
2565 .unwrap();
2566
2567 assert_eq!(vectors, vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]]);
2568 handle.join().unwrap();
2569 }
2570
2571 #[test]
2572 fn ollama_backend_embeds_with_mock_server() {
2573 let (base_url, handle) = start_mock_http_server(|request_line, path, _body| {
2574 assert!(request_line.starts_with("POST "));
2575 assert_eq!(path, "/api/embed");
2576 "{\"embeddings\":[[0.7,0.8,0.9],[1.0,1.1,1.2]]}".to_string()
2577 });
2578
2579 let config = SemanticBackendConfig {
2580 backend: SemanticBackend::Ollama,
2581 model: "embeddinggemma".to_string(),
2582 base_url: Some(base_url),
2583 api_key_env: None,
2584 timeout_ms: 5_000,
2585 max_batch_size: 64,
2586 };
2587
2588 let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();
2589 let vectors = model
2590 .embed(vec!["hello".to_string(), "world".to_string()])
2591 .unwrap();
2592
2593 assert_eq!(vectors, vec![vec![0.7, 0.8, 0.9], vec![1.0, 1.1, 1.2]]);
2594 handle.join().unwrap();
2595 }
2596
2597 #[test]
2598 fn read_from_disk_rejects_fingerprint_mismatch() {
2599 let storage = tempfile::tempdir().unwrap();
2600 let project_key = "proj";
2601
2602 let mut index = SemanticIndex::new();
2603 index.entries.push(EmbeddingEntry {
2604 chunk: SemanticChunk {
2605 file: PathBuf::from("/src/main.rs"),
2606 name: "handle_request".to_string(),
2607 kind: SymbolKind::Function,
2608 start_line: 10,
2609 end_line: 25,
2610 exported: true,
2611 embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
2612 snippet: "fn handle_request() {}".to_string(),
2613 },
2614 vector: vec![0.1, 0.2, 0.3],
2615 });
2616 index.dimension = 3;
2617 index
2618 .file_mtimes
2619 .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
2620 index.file_sizes.insert(PathBuf::from("/src/main.rs"), 0);
2621 index.set_fingerprint(SemanticIndexFingerprint {
2622 backend: "openai_compatible".to_string(),
2623 model: "test-embedding".to_string(),
2624 base_url: "http://127.0.0.1:1234/v1".to_string(),
2625 dimension: 3,
2626 });
2627 index.write_to_disk(storage.path(), project_key);
2628
2629 let matching = index.fingerprint().unwrap().as_string();
2630 assert!(
2631 SemanticIndex::read_from_disk(storage.path(), project_key, Some(&matching)).is_some()
2632 );
2633
2634 let mismatched = SemanticIndexFingerprint {
2635 backend: "ollama".to_string(),
2636 model: "embeddinggemma".to_string(),
2637 base_url: "http://127.0.0.1:11434".to_string(),
2638 dimension: 3,
2639 }
2640 .as_string();
2641 assert!(
2642 SemanticIndex::read_from_disk(storage.path(), project_key, Some(&mismatched)).is_none()
2643 );
2644 }
2645
2646 #[test]
2647 fn read_from_disk_rejects_v3_cache_for_snippet_rebuild() {
2648 let storage = tempfile::tempdir().unwrap();
2649 let project_key = "proj-v3";
2650 let dir = storage.path().join("semantic").join(project_key);
2651 fs::create_dir_all(&dir).unwrap();
2652
2653 let mut index = SemanticIndex::new();
2654 index.entries.push(EmbeddingEntry {
2655 chunk: SemanticChunk {
2656 file: PathBuf::from("/src/main.rs"),
2657 name: "handle_request".to_string(),
2658 kind: SymbolKind::Function,
2659 start_line: 0,
2660 end_line: 0,
2661 exported: true,
2662 embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
2663 snippet: "fn handle_request() {}".to_string(),
2664 },
2665 vector: vec![0.1, 0.2, 0.3],
2666 });
2667 index.dimension = 3;
2668 index
2669 .file_mtimes
2670 .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
2671 index.file_sizes.insert(PathBuf::from("/src/main.rs"), 0);
2672 let fingerprint = SemanticIndexFingerprint {
2673 backend: "fastembed".to_string(),
2674 model: "test".to_string(),
2675 base_url: FALLBACK_BACKEND.to_string(),
2676 dimension: 3,
2677 };
2678 index.set_fingerprint(fingerprint.clone());
2679
2680 let mut bytes = index.to_bytes();
2681 bytes[0] = SEMANTIC_INDEX_VERSION_V3;
2682 fs::write(dir.join("semantic.bin"), bytes).unwrap();
2683
2684 assert!(SemanticIndex::read_from_disk(
2685 storage.path(),
2686 project_key,
2687 Some(&fingerprint.as_string())
2688 )
2689 .is_none());
2690 assert!(!dir.join("semantic.bin").exists());
2691 }
2692
2693 fn make_symbol(kind: SymbolKind, name: &str, start: u32, end: u32) -> crate::symbols::Symbol {
2694 crate::symbols::Symbol {
2695 name: name.to_string(),
2696 kind,
2697 range: crate::symbols::Range {
2698 start_line: start,
2699 start_col: 0,
2700 end_line: end,
2701 end_col: 0,
2702 },
2703 signature: None,
2704 scope_chain: Vec::new(),
2705 exported: false,
2706 parent: None,
2707 }
2708 }
2709
2710 #[test]
2715 fn symbols_to_chunks_skips_heading_symbols() {
2716 let project_root = PathBuf::from("/proj");
2717 let file = project_root.join("README.md");
2718 let source = "# Title\n\nbody text\n\n## Section\n\nmore text\n";
2719
2720 let symbols = vec![
2721 make_symbol(SymbolKind::Heading, "Title", 0, 2),
2722 make_symbol(SymbolKind::Heading, "Section", 4, 6),
2723 ];
2724
2725 let chunks = symbols_to_chunks(&file, &symbols, source, &project_root);
2726 assert!(
2727 chunks.is_empty(),
2728 "Heading symbols must be filtered out before embedding; got {} chunk(s)",
2729 chunks.len()
2730 );
2731 }
2732
2733 #[test]
2737 fn symbols_to_chunks_keeps_code_symbols_alongside_skipped_headings() {
2738 let project_root = PathBuf::from("/proj");
2739 let file = project_root.join("src/lib.rs");
2740 let source = "pub fn handle_request() -> bool {\n true\n}\n";
2741
2742 let symbols = vec![
2743 make_symbol(SymbolKind::Heading, "doc heading", 0, 1),
2745 make_symbol(SymbolKind::Function, "handle_request", 0, 2),
2746 make_symbol(SymbolKind::Struct, "AuthService", 4, 6),
2747 ];
2748
2749 let chunks = symbols_to_chunks(&file, &symbols, source, &project_root);
2750 assert_eq!(
2751 chunks.len(),
2752 2,
2753 "Expected 2 code chunks (Function + Struct), got {}",
2754 chunks.len()
2755 );
2756 let names: Vec<&str> = chunks.iter().map(|c| c.name.as_str()).collect();
2757 assert!(names.contains(&"handle_request"));
2758 assert!(names.contains(&"AuthService"));
2759 assert!(
2760 !names.contains(&"doc heading"),
2761 "Heading symbol leaked into chunks: {names:?}"
2762 );
2763 }
2764
2765 #[test]
2766 fn validate_ssrf_allows_loopback_hostnames() {
2767 for host in &[
2770 "http://localhost",
2771 "http://localhost:8080",
2772 "http://localhost:11434", "http://localhost.localdomain",
2774 "http://foo.localhost",
2775 ] {
2776 assert!(
2777 validate_base_url_no_ssrf(host).is_ok(),
2778 "Expected {host} to be allowed (loopback), got: {:?}",
2779 validate_base_url_no_ssrf(host)
2780 );
2781 }
2782 }
2783
2784 #[test]
2785 fn validate_ssrf_allows_loopback_ips() {
2786 for url in &[
2789 "http://127.0.0.1",
2790 "http://127.0.0.1:11434", "http://127.0.0.1:8080",
2792 "http://127.1.2.3",
2793 ] {
2794 let result = validate_base_url_no_ssrf(url);
2795 assert!(
2796 result.is_ok(),
2797 "Expected {url} to be allowed (loopback), got: {:?}",
2798 result
2799 );
2800 }
2801 }
2802
2803 #[test]
2804 fn validate_ssrf_rejects_private_non_loopback_ips() {
2805 for url in &[
2810 "http://192.168.1.1",
2811 "http://10.0.0.1",
2812 "http://172.16.0.1",
2813 "http://169.254.169.254",
2814 "http://100.64.0.1",
2815 ] {
2816 let result = validate_base_url_no_ssrf(url);
2817 assert!(
2818 result.is_err(),
2819 "Expected {url} to be rejected (non-loopback private), got: {:?}",
2820 result
2821 );
2822 }
2823 }
2824
2825 #[test]
2826 fn validate_ssrf_rejects_mdns_local_hostnames() {
2827 for host in &[
2830 "http://printer.local",
2831 "http://nas.local:8080",
2832 "http://homelab.local",
2833 ] {
2834 let result = validate_base_url_no_ssrf(host);
2835 assert!(
2836 result.is_err(),
2837 "Expected {host} to be rejected (mDNS), got: {:?}",
2838 result
2839 );
2840 }
2841 }
2842
2843 #[test]
2844 fn normalize_base_url_allows_localhost_for_tests() {
2845 assert!(normalize_base_url("http://127.0.0.1:9999").is_ok());
2848 assert!(normalize_base_url("http://localhost:8080").is_ok());
2849 }
2850
2851 #[test]
2858 fn ort_mismatch_message_recommends_auto_fix_first() {
2859 let msg =
2860 format_ort_version_mismatch("1.9.0", "/usr/lib/x86_64-linux-gnu/libonnxruntime.so");
2861
2862 assert!(
2864 msg.contains("v1.9.0"),
2865 "should report detected version: {msg}"
2866 );
2867 assert!(
2868 msg.contains("/usr/lib/x86_64-linux-gnu/libonnxruntime.so"),
2869 "should report system path: {msg}"
2870 );
2871 assert!(msg.contains("v1.20+"), "should state requirement: {msg}");
2872
2873 let auto_fix_pos = msg
2875 .find("Auto-fix")
2876 .expect("Auto-fix solution missing — users won't discover --fix");
2877 let remove_pos = msg
2878 .find("Remove the old library")
2879 .expect("system-rm solution missing");
2880 assert!(
2881 auto_fix_pos < remove_pos,
2882 "Auto-fix must come before manual rm — see PR comment thread"
2883 );
2884
2885 assert!(
2887 msg.contains("npx @cortexkit/aft doctor --fix"),
2888 "auto-fix command must be present and copy-pasteable: {msg}"
2889 );
2890 }
2891
2892 #[test]
2896 fn ort_mismatch_message_handles_macos_dylib_path() {
2897 let msg = format_ort_version_mismatch("1.9.0", "/opt/homebrew/lib/libonnxruntime.dylib");
2898 assert!(msg.contains("v1.9.0"));
2899 assert!(msg.contains("/opt/homebrew/lib/libonnxruntime.dylib"));
2900 assert!(
2904 msg.contains("'/opt/homebrew/lib/libonnxruntime.dylib'"),
2905 "system path should be quoted in the auto-fix sentence: {msg}"
2906 );
2907 }
2908}