1use crate::config::{SemanticBackend, SemanticBackendConfig};
2use crate::parser::{detect_language, extract_symbols_from_tree, grammar_for};
3use crate::symbols::{Symbol, SymbolKind};
4use crate::{slog_info, slog_warn};
5
6use fastembed::{EmbeddingModel as FastembedEmbeddingModel, InitOptions, TextEmbedding};
7use rayon::prelude::*;
8use reqwest::blocking::Client;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet};
11use std::env;
12use std::fmt::Display;
13use std::fs;
14use std::path::{Path, PathBuf};
15use std::time::Duration;
16use std::time::SystemTime;
17use tree_sitter::Parser;
18use url::Url;
19
/// Vector length reported by a new index, or by one built from zero chunks.
const DEFAULT_DIMENSION: usize = 384;
// NOTE(review): the limits and header sizes below are not referenced in the visible part of
// this file; they presumably bound/describe the persisted index format — confirm at use sites.
const MAX_ENTRIES: usize = 1_000_000;
const MAX_DIMENSION: usize = 1024;
const F32_BYTES: usize = std::mem::size_of::<f32>();
const HEADER_BYTES_V1: usize = 9;
const HEADER_BYTES_V2: usize = 13;
/// Hint prepended by `format_embedding_init_error` to errors that look like a missing ONNX Runtime.
const ONNX_RUNTIME_INSTALL_HINT: &str =
    "ONNX Runtime not found. Install via: brew install onnxruntime (macOS) or apt install libonnxruntime (Linux).";

// Semantic index format versions. Not referenced in the visible code; presumably written to /
// checked against the persisted index header — TODO confirm.
const SEMANTIC_INDEX_VERSION_V1: u8 = 1;
const SEMANTIC_INDEX_VERSION_V2: u8 = 2;
const SEMANTIC_INDEX_VERSION_V3: u8 = 3;
const SEMANTIC_INDEX_VERSION_V4: u8 = 4;
const SEMANTIC_INDEX_VERSION_V5: u8 = 5;
/// Path appended (after `/v1`) to an OpenAI-compatible base URL for embedding requests.
const DEFAULT_OPENAI_EMBEDDING_PATH: &str = "/embeddings";
/// Path appended to an Ollama base URL for embedding requests.
const DEFAULT_OLLAMA_EMBEDDING_PATH: &str = "/api/embed";
/// HTTP timeout used for every remote backend (not only OpenAI) when `timeout_ms` is 0.
const DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS: u64 = 25_000;
/// Batch size used when the config sets `max_batch_size` to 0.
const DEFAULT_MAX_BATCH_SIZE: usize = 64;
/// Placeholder stored as the fingerprint `base_url` when none is configured or it is invalid.
const FALLBACK_BACKEND: &str = "none";
/// Total attempts (first try plus retries) for one embedding HTTP request.
const EMBEDDING_REQUEST_MAX_ATTEMPTS: usize = 3;
/// Delay slept after failed attempt `i` (zero-based) before the next attempt.
const EMBEDDING_REQUEST_BACKOFF_MS: [u64; 2] = [500, 1_000];
50
/// Identifies the embedding configuration a semantic index was built with.
///
/// `as_string` serializes this struct to JSON and `matches_expected` compares that exact
/// string, so the field names and their order are part of the fingerprint format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticIndexFingerprint {
    /// Backend name as produced by `SemanticBackend::as_str`.
    pub backend: String,
    /// Embedding model identifier from the backend config.
    pub model: String,
    /// Normalized base URL, or "none" when unset or invalid. Deserializes to an empty
    /// string when the field is missing.
    #[serde(default)]
    pub base_url: String,
    /// Length of the embedding vectors.
    pub dimension: usize,
}
59
60impl SemanticIndexFingerprint {
61 fn from_config(config: &SemanticBackendConfig, dimension: usize) -> Self {
62 let base_url = config
65 .base_url
66 .as_ref()
67 .and_then(|u| normalize_base_url(u).ok())
68 .unwrap_or_else(|| FALLBACK_BACKEND.to_string());
69 Self {
70 backend: config.backend.as_str().to_string(),
71 model: config.model.clone(),
72 base_url,
73 dimension,
74 }
75 }
76
77 pub fn as_string(&self) -> String {
78 serde_json::to_string(self).unwrap_or_else(|_| String::new())
79 }
80
81 fn matches_expected(&self, expected: &str) -> bool {
82 let encoded = self.as_string();
83 !encoded.is_empty() && encoded == expected
84 }
85}
86
/// Concrete backend behind a [`SemanticEmbeddingModel`].
enum SemanticEmbeddingEngine {
    /// In-process fastembed model (created via `initialize_text_embedding`, which first
    /// checks that ONNX Runtime can be loaded).
    Fastembed(TextEmbedding),
    /// HTTP server accepting OpenAI-style `{ "input": [...], "model": ... }` embedding requests.
    OpenAiCompatible {
        client: Client,
        model: String,
        /// Base URL after `normalize_base_url` (no trailing slash).
        base_url: String,
        /// Token read from the env var named by `api_key_env`, sent as a Bearer token.
        api_key: Option<String>,
    },
    /// Ollama server, called through its embed endpoint.
    Ollama {
        client: Client,
        model: String,
        /// Base URL after `normalize_base_url` (no trailing slash).
        base_url: String,
    },
}
101
/// Embedding model configured from a `SemanticBackendConfig`, plus the settings it was
/// built with.
pub struct SemanticEmbeddingModel {
    /// Which backend kind `engine` implements.
    backend: SemanticBackend,
    /// Model name as configured.
    model: String,
    /// Base URL exactly as configured (not normalized).
    base_url: Option<String>,
    /// Effective HTTP timeout in milliseconds (config value 0 replaced by the default).
    timeout_ms: u64,
    /// Effective batch size (config value 0 replaced by the default).
    max_batch_size: usize,
    /// Cached vector length, filled by `dimension()` or by a successful HTTP embed call.
    dimension: Option<usize>,
    /// The backend implementation.
    engine: SemanticEmbeddingEngine,
}

/// Shorter alias for [`SemanticEmbeddingModel`].
pub type EmbeddingModel = SemanticEmbeddingModel;
113
/// Checks that an embedding backend returned exactly `expected_count` vectors, all of the
/// same length.
///
/// `context` prefixes every error message (e.g. "embedding backend"). An empty response to
/// a non-empty request gets its own "returned no vectors" message.
fn validate_embedding_batch(
    vectors: &[Vec<f32>],
    expected_count: usize,
    context: &str,
) -> Result<(), String> {
    let actual_count = vectors.len();
    if actual_count != expected_count {
        let message = if actual_count == 0 {
            format!("{context} returned no vectors for {expected_count} inputs")
        } else {
            format!("{context} returned {actual_count} vectors for {expected_count} inputs")
        };
        return Err(message);
    }

    let mut lengths = vectors.iter().map(Vec::len).enumerate();
    if let Some((_, reference)) = lengths.next() {
        if let Some((index, length)) = lengths.find(|&(_, length)| length != reference) {
            return Err(format!(
                "{context} returned inconsistent embedding dimensions: vector 0 has length {reference}, vector {index} has length {length}"
            ));
        }
    }

    Ok(())
}
148
149fn normalize_base_url(raw: &str) -> Result<String, String> {
153 let parsed = Url::parse(raw).map_err(|error| format!("invalid base_url '{raw}': {error}"))?;
154 let scheme = parsed.scheme();
155 if scheme != "http" && scheme != "https" {
156 return Err(format!(
157 "unsupported URL scheme '{}' — only http:// and https:// are allowed",
158 scheme
159 ));
160 }
161 Ok(parsed.to_string().trim_end_matches('/').to_string())
162}
163
164pub fn validate_base_url_no_ssrf(raw: &str) -> Result<(), String> {
179 use std::net::{IpAddr, ToSocketAddrs};
180
181 let parsed = Url::parse(raw).map_err(|error| format!("invalid base_url '{raw}': {error}"))?;
182
183 let host = parsed.host_str().unwrap_or("");
184
185 let is_loopback_host =
190 host == "localhost" || host == "localhost.localdomain" || host.ends_with(".localhost");
191 if is_loopback_host {
192 return Ok(());
193 }
194
195 if host.ends_with(".local") {
198 return Err(format!(
199 "base_url host '{host}' is an mDNS name — only loopback (localhost / 127.0.0.1) and public endpoints are allowed"
200 ));
201 }
202
203 let port = parsed.port_or_known_default().unwrap_or(443);
206 let addr_str = format!("{host}:{port}");
207 let addrs: Vec<IpAddr> = addr_str
208 .to_socket_addrs()
209 .map(|iter| iter.map(|sa| sa.ip()).collect())
210 .unwrap_or_default();
211 for ip in &addrs {
212 if is_private_non_loopback_ip(ip) {
213 return Err(format!(
214 "base_url '{raw}' resolves to a private/reserved IP — only loopback (127.0.0.1) and public endpoints are allowed"
215 ));
216 }
217 }
218
219 Ok(())
220}
221
/// Returns `true` for addresses an embedding `base_url` must not reach.
///
/// IPv4: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, link-local 169.254.0.0/16,
/// shared/CGNAT 100.64.0.0/10 and "this network" 0.0.0.0/8.
/// IPv6: link-local fe80::/10, unique-local fc00::/7 and the unspecified address `::`
/// (the IPv6 counterpart of 0.0.0.0). IPv4-mapped addresses (`::ffff:a.b.c.d`) are
/// classified by their embedded IPv4 address.
///
/// Loopback (127.0.0.0/8, `::1`) is deliberately *not* reported, because local embedding
/// servers are allowed.
fn is_private_non_loopback_ip(ip: &std::net::IpAddr) -> bool {
    use std::net::IpAddr;

    match ip {
        IpAddr::V4(v4) => {
            let [a, b, _, _] = v4.octets();
            a == 10
                || (a == 172 && (16..=31).contains(&b))
                || (a == 192 && b == 168)
                || (a == 169 && b == 254)
                || (a == 100 && (64..=127).contains(&b))
                || a == 0
        }
        IpAddr::V6(v6) => {
            // `::ffff:a.b.c.d` reaches the embedded IPv4 host, so judge that address instead.
            if let Some(mapped) = v6.to_ipv4_mapped() {
                return is_private_non_loopback_ip(&IpAddr::V4(mapped));
            }
            let first_segment = v6.segments()[0];
            (first_segment & 0xffc0) == 0xfe80
                || (first_segment & 0xfe00) == 0xfc00
                || v6.is_unspecified()
        }
    }
}
263
264fn build_openai_embeddings_endpoint(base_url: &str) -> String {
265 if base_url.ends_with("/v1") {
266 format!("{base_url}{DEFAULT_OPENAI_EMBEDDING_PATH}")
267 } else {
268 format!("{base_url}/v1{}", DEFAULT_OPENAI_EMBEDDING_PATH)
269 }
270}
271
272fn build_ollama_embeddings_endpoint(base_url: &str) -> String {
273 if base_url.ends_with("/api") {
274 format!("{base_url}/embed")
275 } else {
276 format!("{base_url}{DEFAULT_OLLAMA_EMBEDDING_PATH}")
277 }
278}
279
/// Trims a configured value, treating a missing or blank value as `None`.
///
/// Used for the `api_key_env` setting, i.e. the *name* of the variable holding the key.
fn normalize_api_key(value: Option<String>) -> Option<String> {
    let raw = value?;
    let trimmed = raw.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_owned())
}
290
291fn is_retryable_embedding_status(status: reqwest::StatusCode) -> bool {
292 status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS
293}
294
295fn is_retryable_embedding_error(error: &reqwest::Error) -> bool {
296 error.is_connect()
297}
298
299fn sleep_before_embedding_retry(attempt_index: usize) {
300 if let Some(delay_ms) = EMBEDDING_REQUEST_BACKOFF_MS.get(attempt_index) {
301 std::thread::sleep(Duration::from_millis(*delay_ms));
302 }
303}
304
305fn send_embedding_request<F>(mut make_request: F, backend_label: &str) -> Result<String, String>
306where
307 F: FnMut() -> reqwest::blocking::RequestBuilder,
308{
309 for attempt_index in 0..EMBEDDING_REQUEST_MAX_ATTEMPTS {
310 let last_attempt = attempt_index + 1 == EMBEDDING_REQUEST_MAX_ATTEMPTS;
311
312 let response = match make_request().send() {
313 Ok(response) => response,
314 Err(error) => {
315 if !last_attempt && is_retryable_embedding_error(&error) {
316 sleep_before_embedding_retry(attempt_index);
317 continue;
318 }
319 return Err(format!("{backend_label} request failed: {error}"));
320 }
321 };
322
323 let status = response.status();
324 let raw = match response.text() {
325 Ok(raw) => raw,
326 Err(error) => {
327 if !last_attempt && is_retryable_embedding_error(&error) {
328 sleep_before_embedding_retry(attempt_index);
329 continue;
330 }
331 return Err(format!("{backend_label} response read failed: {error}"));
332 }
333 };
334
335 if status.is_success() {
336 return Ok(raw);
337 }
338
339 if !last_attempt && is_retryable_embedding_status(status) {
340 sleep_before_embedding_retry(attempt_index);
341 continue;
342 }
343
344 return Err(format!(
345 "{backend_label} request failed (HTTP {}): {}",
346 status, raw
347 ));
348 }
349
350 unreachable!("embedding request retries exhausted without returning")
351}
352
353impl SemanticEmbeddingModel {
354 pub fn from_config(config: &SemanticBackendConfig) -> Result<Self, String> {
355 let timeout_ms = if config.timeout_ms == 0 {
356 DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS
357 } else {
358 config.timeout_ms
359 };
360
361 let max_batch_size = if config.max_batch_size == 0 {
362 DEFAULT_MAX_BATCH_SIZE
363 } else {
364 config.max_batch_size
365 };
366
367 let api_key_env = normalize_api_key(config.api_key_env.clone());
368 let model = config.model.clone();
369
370 let client = Client::builder()
371 .timeout(Duration::from_millis(timeout_ms))
372 .redirect(reqwest::redirect::Policy::none())
373 .build()
374 .map_err(|error| format!("failed to configure embedding client: {error}"))?;
375
376 let engine = match config.backend {
377 SemanticBackend::Fastembed => {
378 SemanticEmbeddingEngine::Fastembed(initialize_text_embedding(&model)?)
379 }
380 SemanticBackend::OpenAiCompatible => {
381 let raw = config.base_url.as_ref().ok_or_else(|| {
382 "base_url is required for openai_compatible backend".to_string()
383 })?;
384 let base_url = normalize_base_url(raw)?;
385
386 let api_key = match api_key_env {
387 Some(var_name) => Some(env::var(&var_name).map_err(|_| {
388 format!("missing api_key_env '{var_name}' for openai_compatible backend")
389 })?),
390 None => None,
391 };
392
393 SemanticEmbeddingEngine::OpenAiCompatible {
394 client,
395 model,
396 base_url,
397 api_key,
398 }
399 }
400 SemanticBackend::Ollama => {
401 let raw = config
402 .base_url
403 .as_ref()
404 .ok_or_else(|| "base_url is required for ollama backend".to_string())?;
405 let base_url = normalize_base_url(raw)?;
406
407 SemanticEmbeddingEngine::Ollama {
408 client,
409 model,
410 base_url,
411 }
412 }
413 };
414
415 Ok(Self {
416 backend: config.backend,
417 model: config.model.clone(),
418 base_url: config.base_url.clone(),
419 timeout_ms,
420 max_batch_size,
421 dimension: None,
422 engine,
423 })
424 }
425
426 pub fn backend(&self) -> SemanticBackend {
427 self.backend
428 }
429
430 pub fn model(&self) -> &str {
431 &self.model
432 }
433
434 pub fn base_url(&self) -> Option<&str> {
435 self.base_url.as_deref()
436 }
437
438 pub fn max_batch_size(&self) -> usize {
439 self.max_batch_size
440 }
441
442 pub fn timeout_ms(&self) -> u64 {
443 self.timeout_ms
444 }
445
446 pub fn fingerprint(
447 &mut self,
448 config: &SemanticBackendConfig,
449 ) -> Result<SemanticIndexFingerprint, String> {
450 let dimension = self.dimension()?;
451 Ok(SemanticIndexFingerprint::from_config(config, dimension))
452 }
453
454 pub fn dimension(&mut self) -> Result<usize, String> {
455 if let Some(dimension) = self.dimension {
456 return Ok(dimension);
457 }
458
459 let dimension = match &mut self.engine {
460 SemanticEmbeddingEngine::Fastembed(model) => {
461 let vectors = model
462 .embed(vec!["semantic index fingerprint probe".to_string()], None)
463 .map_err(|error| format_embedding_init_error(error.to_string()))?;
464 vectors
465 .first()
466 .map(|v| v.len())
467 .ok_or_else(|| "embedding backend returned no vectors".to_string())?
468 }
469 SemanticEmbeddingEngine::OpenAiCompatible { .. } => {
470 let vectors =
471 self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
472 vectors
473 .first()
474 .map(|v| v.len())
475 .ok_or_else(|| "embedding backend returned no vectors".to_string())?
476 }
477 SemanticEmbeddingEngine::Ollama { .. } => {
478 let vectors =
479 self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
480 vectors
481 .first()
482 .map(|v| v.len())
483 .ok_or_else(|| "embedding backend returned no vectors".to_string())?
484 }
485 };
486
487 self.dimension = Some(dimension);
488 Ok(dimension)
489 }
490
491 pub fn embed(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
492 self.embed_texts(texts)
493 }
494
495 fn embed_texts(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
496 match &mut self.engine {
497 SemanticEmbeddingEngine::Fastembed(model) => model
498 .embed(texts, None::<usize>)
499 .map_err(|error| format_embedding_init_error(error.to_string()))
500 .map_err(|error| format!("failed to embed batch: {error}")),
501 SemanticEmbeddingEngine::OpenAiCompatible {
502 client,
503 model,
504 base_url,
505 api_key,
506 } => {
507 let expected_text_count = texts.len();
508 let endpoint = build_openai_embeddings_endpoint(base_url);
509 let body = serde_json::json!({
510 "input": texts,
511 "model": model,
512 });
513
514 let raw = send_embedding_request(
515 || {
516 let mut request = client
517 .post(&endpoint)
518 .json(&body)
519 .header("Content-Type", "application/json");
520
521 if let Some(api_key) = api_key {
522 request = request.header("Authorization", format!("Bearer {api_key}"));
523 }
524
525 request
526 },
527 "openai compatible",
528 )?;
529
530 #[derive(Deserialize)]
531 struct OpenAiResponse {
532 data: Vec<OpenAiEmbeddingResult>,
533 }
534
535 #[derive(Deserialize)]
536 struct OpenAiEmbeddingResult {
537 embedding: Vec<f32>,
538 index: Option<u32>,
539 }
540
541 let parsed: OpenAiResponse = serde_json::from_str(&raw)
542 .map_err(|error| format!("invalid openai compatible response: {error}"))?;
543 if parsed.data.len() != expected_text_count {
544 return Err(format!(
545 "openai compatible response returned {} embeddings for {} inputs",
546 parsed.data.len(),
547 expected_text_count
548 ));
549 }
550
551 let mut vectors = vec![Vec::new(); parsed.data.len()];
552 for (i, item) in parsed.data.into_iter().enumerate() {
553 let index = item.index.unwrap_or(i as u32) as usize;
554 if index >= vectors.len() {
555 return Err(
556 "openai compatible response contains invalid vector index".to_string()
557 );
558 }
559 vectors[index] = item.embedding;
560 }
561
562 for vector in &vectors {
563 if vector.is_empty() {
564 return Err(
565 "openai compatible response contained missing vectors".to_string()
566 );
567 }
568 }
569
570 self.dimension = vectors.first().map(Vec::len);
571 Ok(vectors)
572 }
573 SemanticEmbeddingEngine::Ollama {
574 client,
575 model,
576 base_url,
577 } => {
578 let expected_text_count = texts.len();
579 let endpoint = build_ollama_embeddings_endpoint(base_url);
580
581 #[derive(Serialize)]
582 struct OllamaPayload<'a> {
583 model: &'a str,
584 input: Vec<String>,
585 }
586
587 let payload = OllamaPayload {
588 model,
589 input: texts,
590 };
591
592 let raw = send_embedding_request(
593 || {
594 client
595 .post(&endpoint)
596 .json(&payload)
597 .header("Content-Type", "application/json")
598 },
599 "ollama",
600 )?;
601
602 #[derive(Deserialize)]
603 struct OllamaResponse {
604 embeddings: Vec<Vec<f32>>,
605 }
606
607 let parsed: OllamaResponse = serde_json::from_str(&raw)
608 .map_err(|error| format!("invalid ollama response: {error}"))?;
609 if parsed.embeddings.is_empty() {
610 return Err("ollama response returned no embeddings".to_string());
611 }
612 if parsed.embeddings.len() != expected_text_count {
613 return Err(format!(
614 "ollama response returned {} embeddings for {} inputs",
615 parsed.embeddings.len(),
616 expected_text_count
617 ));
618 }
619
620 let vectors = parsed.embeddings;
621 for vector in &vectors {
622 if vector.is_empty() {
623 return Err("ollama response contained empty embeddings".to_string());
624 }
625 }
626
627 self.dimension = vectors.first().map(Vec::len);
628 Ok(vectors)
629 }
630 }
631 }
632}
633
634pub fn pre_validate_onnx_runtime() -> Result<(), String> {
638 let dylib_path = std::env::var("ORT_DYLIB_PATH").ok();
639
640 #[cfg(any(target_os = "linux", target_os = "macos"))]
641 {
642 #[cfg(target_os = "linux")]
643 let default_name = "libonnxruntime.so";
644 #[cfg(target_os = "macos")]
645 let default_name = "libonnxruntime.dylib";
646
647 let lib_name = dylib_path.as_deref().unwrap_or(default_name);
648
649 unsafe {
650 let c_name = std::ffi::CString::new(lib_name)
651 .map_err(|e| format!("invalid library path: {}", e))?;
652 let handle = libc::dlopen(c_name.as_ptr(), libc::RTLD_NOW);
653 if handle.is_null() {
654 let err = libc::dlerror();
655 let msg = if err.is_null() {
656 "unknown dlopen error".to_string()
657 } else {
658 std::ffi::CStr::from_ptr(err).to_string_lossy().into_owned()
659 };
660 return Err(format!(
661 "ONNX Runtime not found. dlopen('{}') failed: {}. \
662 Run `npx @cortexkit/aft doctor` to diagnose.",
663 lib_name, msg
664 ));
665 }
666
667 let detected_version = detect_ort_version_from_path(lib_name);
670
671 libc::dlclose(handle);
672
673 if let Some(ref version) = detected_version {
675 let parts: Vec<&str> = version.split('.').collect();
676 if let (Some(major), Some(minor)) = (
677 parts.first().and_then(|s| s.parse::<u32>().ok()),
678 parts.get(1).and_then(|s| s.parse::<u32>().ok()),
679 ) {
680 if major != 1 || minor < 20 {
681 return Err(format_ort_version_mismatch(version, lib_name));
682 }
683 }
684 }
685 }
686 }
687
688 #[cfg(target_os = "windows")]
689 {
690 let _ = dylib_path;
692 }
693
694 Ok(())
695}
696
697fn detect_ort_version_from_path(lib_path: &str) -> Option<String> {
700 let path = std::path::Path::new(lib_path);
701
702 for candidate in [Some(path.to_path_buf()), std::fs::canonicalize(path).ok()]
704 .into_iter()
705 .flatten()
706 {
707 if let Some(name) = candidate.file_name().and_then(|n| n.to_str()) {
708 if let Some(version) = extract_version_from_filename(name) {
709 return Some(version);
710 }
711 }
712 }
713
714 if let Some(parent) = path.parent() {
716 if let Ok(entries) = std::fs::read_dir(parent) {
717 for entry in entries.flatten() {
718 if let Some(name) = entry.file_name().to_str() {
719 if name.starts_with("libonnxruntime") {
720 if let Some(version) = extract_version_from_filename(name) {
721 return Some(version);
722 }
723 }
724 }
725 }
726 }
727 }
728
729 None
730}
731
732fn extract_version_from_filename(name: &str) -> Option<String> {
734 let re = regex::Regex::new(r"(\d+\.\d+\.\d+)").ok()?;
736 re.find(name).map(|m| m.as_str().to_string())
737}
738
/// Returns an indented suggestion telling the user how to remove the ONNX Runtime at
/// `lib_path`.
///
/// Paths under `/usr/local/lib`, and the bare default library names, get a
/// platform-specific system-wide cleanup command. Any other path gets an `rm` of that exact
/// file. That path is single-quoted with embedded `'` characters escaped POSIX-style
/// (`'\''`), so the suggested command stays valid for paths that contain quotes.
fn suggest_removal_command(lib_path: &str) -> String {
    let is_system_install = lib_path.starts_with("/usr/local/lib")
        || lib_path == "libonnxruntime.so"
        || lib_path == "libonnxruntime.dylib";
    if is_system_install {
        #[cfg(target_os = "linux")]
        return " sudo rm /usr/local/lib/libonnxruntime* && sudo ldconfig".to_string();
        #[cfg(target_os = "macos")]
        return " sudo rm /usr/local/lib/libonnxruntime*".to_string();
        #[cfg(target_os = "windows")]
        return " Delete the ONNX Runtime DLL from your PATH".to_string();
    }
    format!(" rm '{}'", lib_path.replace('\'', r"'\''"))
}
753
/// Builds the user-facing error for an unsupported ONNX Runtime version.
///
/// `version` is the detected version string and `lib_name` is the library path or name
/// that was checked. `lib_name` also feeds `suggest_removal_command` for option 2.
pub(crate) fn format_ort_version_mismatch(version: &str, lib_name: &str) -> String {
    // Each trailing `\` drops the newline and the next line's leading whitespace, so only the
    // explicit `\n` escapes produce line breaks in the final message.
    format!(
        "ONNX Runtime version mismatch: found v{} at '{}', but AFT requires v1.20+. \
         Solutions:\n\
         1. Auto-fix (recommended): run `npx @cortexkit/aft doctor --fix`. \
         This downloads AFT-managed ONNX Runtime v1.24 into AFT's storage and \
         configures the bridge to load it instead of the system library — no \
         changes to '{}'.\n\
         2. Remove the old library and restart (AFT auto-downloads the correct version on next start):\n\
         {}\n\
         3. Or install ONNX Runtime 1.24 system-wide: https://github.com/microsoft/onnxruntime/releases/tag/v1.24.0\n\
         4. Run `npx @cortexkit/aft doctor` for full diagnostics.",
        version,
        lib_name,
        lib_name,
        suggest_removal_command(lib_name),
    )
}
777
778pub fn initialize_text_embedding(model: &str) -> Result<TextEmbedding, String> {
779 pre_validate_onnx_runtime()?;
781
782 let selected_model = match model {
783 "all-MiniLM-L6-v2" | "all-minilm-l6-v2" => FastembedEmbeddingModel::AllMiniLML6V2,
784 _ => {
785 return Err(format!(
786 "unsupported fastembed model '{}'. Supported: all-MiniLM-L6-v2",
787 model
788 ))
789 }
790 };
791
792 TextEmbedding::try_new(InitOptions::new(selected_model)).map_err(format_embedding_init_error)
793}
794
/// Heuristically decides whether an error message means the ONNX Runtime library could not
/// be loaded.
///
/// Messages starting with "ONNX Runtime not found." (ignoring leading whitespace) always
/// qualify. Otherwise the lower-cased message must both mention ONNX Runtime and mention a
/// dynamic-loading failure.
pub fn is_onnx_runtime_unavailable(message: &str) -> bool {
    const RUNTIME_MARKERS: [&str; 3] = ["onnx runtime", "onnxruntime", "libonnxruntime"];
    const LOAD_FAILURE_MARKERS: [&str; 9] = [
        "shared library",
        "dynamic library",
        "failed to load",
        "could not load",
        "unable to load",
        "dlopen",
        "loadlibrary",
        "no such file",
        "not found",
    ];

    fn contains_any(haystack: &str, needles: &[&str]) -> bool {
        needles.iter().any(|needle| haystack.contains(*needle))
    }

    if message.trim_start().starts_with("ONNX Runtime not found.") {
        return true;
    }

    let lowered = message.to_ascii_lowercase();
    contains_any(&lowered, &RUNTIME_MARKERS) && contains_any(&lowered, &LOAD_FAILURE_MARKERS)
}
820
821fn format_embedding_init_error(error: impl Display) -> String {
822 let message = error.to_string();
823
824 if is_onnx_runtime_unavailable(&message) {
825 return format!("{ONNX_RUNTIME_INSTALL_HINT} Original error: {message}");
826 }
827
828 format!("failed to initialize semantic embedding model: {message}")
829}
830
/// A named slice of a source file that gets embedded into the semantic index.
///
/// NOTE(review): presumably one chunk per extracted symbol (built by `collect_file_chunks`,
/// which is not visible here) — confirm.
#[derive(Debug, Clone)]
pub struct SemanticChunk {
    /// File the chunk was collected from.
    pub file: PathBuf,
    /// Symbol name.
    pub name: String,
    /// Kind of symbol.
    pub kind: SymbolKind,
    /// First line of the chunk (0- vs 1-based not established here — TODO confirm).
    pub start_line: u32,
    /// Last line of the chunk.
    pub end_line: u32,
    /// Whether the symbol is exported.
    pub exported: bool,
    /// Text sent to the embedding backend for this chunk.
    pub embed_text: String,
    /// Source excerpt kept for display; presumably copied into `SemanticResult::snippet`.
    pub snippet: String,
}
850
/// A chunk paired with the embedding vector computed from its `embed_text`.
#[derive(Debug)]
struct EmbeddingEntry {
    chunk: SemanticChunk,
    vector: Vec<f32>,
}
857
/// In-memory semantic index: embedded chunks plus per-file metadata used to detect changes.
#[derive(Debug)]
pub struct SemanticIndex {
    /// One entry per embedded chunk.
    entries: Vec<EmbeddingEntry>,
    /// Modification time recorded when each file was last collected. Its keys are the set
    /// of indexed files used by `refresh_stale_files`.
    file_mtimes: HashMap<PathBuf, SystemTime>,
    /// File size recorded when each file was last collected (presumably bytes).
    file_sizes: HashMap<PathBuf, u64>,
    /// Expected vector length. It is `DEFAULT_DIMENSION` for a new index or one built from
    /// zero chunks.
    dimension: usize,
    /// Backend fingerprint associated with the index, if one has been attached.
    fingerprint: Option<SemanticIndexFingerprint>,
}
870
/// File metadata captured while collecting chunks; split into `file_mtimes`/`file_sizes`
/// when stored in a [`SemanticIndex`].
#[derive(Debug, Clone, Copy)]
struct IndexedFileMetadata {
    mtime: SystemTime,
    size: u64,
}
876
/// Counts reported by an incremental refresh of the semantic index.
#[derive(Debug, Default, Clone, Copy)]
pub struct RefreshSummary {
    /// Previously indexed files that were stale and were successfully re-collected.
    pub changed: usize,
    /// Newly seen files that were successfully collected.
    pub added: usize,
    /// Previously indexed files that are gone and were dropped from the index.
    pub deleted: usize,
    /// Number of distinct files considered (current files plus previously indexed ones).
    pub total_processed: usize,
}

impl RefreshSummary {
    /// True when the refresh changed, added, and deleted nothing.
    pub fn is_noop(&self) -> bool {
        [self.changed, self.added, self.deleted]
            .iter()
            .all(|&count| count == 0)
    }
}
893
/// One semantic search hit: the matched chunk's location and metadata plus its score.
#[derive(Debug)]
pub struct SemanticResult {
    pub file: PathBuf,
    pub name: String,
    pub kind: SymbolKind,
    pub start_line: u32,
    pub end_line: u32,
    pub exported: bool,
    pub snippet: String,
    /// Cosine similarity between the query vector and the chunk's vector.
    pub score: f32,
}
906
907impl SemanticIndex {
908 pub fn new() -> Self {
909 Self {
910 entries: Vec::new(),
911 file_mtimes: HashMap::new(),
912 file_sizes: HashMap::new(),
913 dimension: DEFAULT_DIMENSION, fingerprint: None,
915 }
916 }
917
918 pub fn entry_count(&self) -> usize {
920 self.entries.len()
921 }
922
923 pub fn status_label(&self) -> &'static str {
925 if self.entries.is_empty() {
926 "empty"
927 } else {
928 "ready"
929 }
930 }
931
932 fn collect_chunks(
933 project_root: &Path,
934 files: &[PathBuf],
935 ) -> (Vec<SemanticChunk>, HashMap<PathBuf, IndexedFileMetadata>) {
936 let per_file: Vec<(
937 PathBuf,
938 Result<(IndexedFileMetadata, Vec<SemanticChunk>), String>,
939 )> = files
940 .par_iter()
941 .map_init(HashMap::new, |parsers, file| {
942 let result = collect_file_metadata(file).and_then(|metadata| {
943 collect_file_chunks(project_root, file, parsers)
944 .map(|chunks| (metadata, chunks))
945 });
946 (file.clone(), result)
947 })
948 .collect();
949
950 let mut chunks: Vec<SemanticChunk> = Vec::new();
951 let mut file_metadata: HashMap<PathBuf, IndexedFileMetadata> = HashMap::new();
952
953 for (file, result) in per_file {
954 match result {
955 Ok((metadata, file_chunks)) => {
956 file_metadata.insert(file, metadata);
957 chunks.extend(file_chunks);
958 }
959 Err(error) => {
960 if error == "unsupported file extension" {
966 continue;
967 }
968 slog_warn!(
969 "failed to collect semantic chunks for {}: {}",
970 file.display(),
971 error
972 );
973 }
974 }
975 }
976
977 (chunks, file_metadata)
978 }
979
980 fn build_from_chunks<F, P>(
981 chunks: Vec<SemanticChunk>,
982 file_metadata: HashMap<PathBuf, IndexedFileMetadata>,
983 embed_fn: &mut F,
984 max_batch_size: usize,
985 mut progress: Option<&mut P>,
986 ) -> Result<Self, String>
987 where
988 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
989 P: FnMut(usize, usize),
990 {
991 let total_chunks = chunks.len();
992
993 if chunks.is_empty() {
994 return Ok(Self {
995 entries: Vec::new(),
996 file_mtimes: file_metadata
997 .iter()
998 .map(|(path, metadata)| (path.clone(), metadata.mtime))
999 .collect(),
1000 file_sizes: file_metadata
1001 .into_iter()
1002 .map(|(path, metadata)| (path, metadata.size))
1003 .collect(),
1004 dimension: DEFAULT_DIMENSION,
1005 fingerprint: None,
1006 });
1007 }
1008
1009 let mut entries: Vec<EmbeddingEntry> = Vec::with_capacity(chunks.len());
1011 let mut expected_dimension: Option<usize> = None;
1012 let batch_size = max_batch_size.max(1);
1013 for batch_start in (0..chunks.len()).step_by(batch_size) {
1014 let batch_end = (batch_start + batch_size).min(chunks.len());
1015 let batch_texts: Vec<String> = chunks[batch_start..batch_end]
1016 .iter()
1017 .map(|c| c.embed_text.clone())
1018 .collect();
1019
1020 let vectors = embed_fn(batch_texts)?;
1021 validate_embedding_batch(&vectors, batch_end - batch_start, "embedding backend")?;
1022
1023 if let Some(dim) = vectors.first().map(|v| v.len()) {
1025 match expected_dimension {
1026 None => expected_dimension = Some(dim),
1027 Some(expected) if dim != expected => {
1028 return Err(format!(
1029 "embedding dimension changed across batches: expected {expected}, got {dim}"
1030 ));
1031 }
1032 _ => {}
1033 }
1034 }
1035
1036 for (i, vector) in vectors.into_iter().enumerate() {
1037 let chunk_idx = batch_start + i;
1038 entries.push(EmbeddingEntry {
1039 chunk: chunks[chunk_idx].clone(),
1040 vector,
1041 });
1042 }
1043
1044 if let Some(callback) = progress.as_mut() {
1045 callback(entries.len(), total_chunks);
1046 }
1047 }
1048
1049 let dimension = entries
1050 .first()
1051 .map(|e| e.vector.len())
1052 .unwrap_or(DEFAULT_DIMENSION);
1053
1054 Ok(Self {
1055 entries,
1056 file_mtimes: file_metadata
1057 .iter()
1058 .map(|(path, metadata)| (path.clone(), metadata.mtime))
1059 .collect(),
1060 file_sizes: file_metadata
1061 .into_iter()
1062 .map(|(path, metadata)| (path, metadata.size))
1063 .collect(),
1064 dimension,
1065 fingerprint: None,
1066 })
1067 }
1068
1069 pub fn build<F>(
1072 project_root: &Path,
1073 files: &[PathBuf],
1074 embed_fn: &mut F,
1075 max_batch_size: usize,
1076 ) -> Result<Self, String>
1077 where
1078 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
1079 {
1080 let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
1081 Self::build_from_chunks(
1082 chunks,
1083 file_mtimes,
1084 embed_fn,
1085 max_batch_size,
1086 Option::<&mut fn(usize, usize)>::None,
1087 )
1088 }
1089
1090 pub fn build_with_progress<F, P>(
1092 project_root: &Path,
1093 files: &[PathBuf],
1094 embed_fn: &mut F,
1095 max_batch_size: usize,
1096 progress: &mut P,
1097 ) -> Result<Self, String>
1098 where
1099 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
1100 P: FnMut(usize, usize),
1101 {
1102 let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
1103 let total_chunks = chunks.len();
1104 progress(0, total_chunks);
1105 Self::build_from_chunks(
1106 chunks,
1107 file_mtimes,
1108 embed_fn,
1109 max_batch_size,
1110 Some(progress),
1111 )
1112 }
1113
1114 pub fn refresh_stale_files<F, P>(
1125 &mut self,
1126 project_root: &Path,
1127 current_files: &[PathBuf],
1128 embed_fn: &mut F,
1129 max_batch_size: usize,
1130 progress: &mut P,
1131 ) -> Result<RefreshSummary, String>
1132 where
1133 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
1134 P: FnMut(usize, usize),
1135 {
1136 self.backfill_missing_file_sizes();
1137
1138 let current_set: HashSet<&Path> = current_files.iter().map(PathBuf::as_path).collect();
1140 let total_processed = current_set.len() + self.file_mtimes.len()
1141 - self
1142 .file_mtimes
1143 .keys()
1144 .filter(|path| current_set.contains(path.as_path()))
1145 .count();
1146
1147 let mut deleted: Vec<PathBuf> = Vec::new();
1150 let mut changed: Vec<PathBuf> = Vec::new();
1151 for indexed_path in self.file_mtimes.keys() {
1152 if !current_set.contains(indexed_path.as_path()) {
1153 deleted.push(indexed_path.clone());
1154 continue;
1155 }
1156 if self.is_file_stale(indexed_path) {
1157 changed.push(indexed_path.clone());
1158 }
1159 }
1160
1161 let mut added: Vec<PathBuf> = Vec::new();
1163 for path in current_files {
1164 if !self.file_mtimes.contains_key(path) {
1165 added.push(path.clone());
1166 }
1167 }
1168
1169 if deleted.is_empty() && changed.is_empty() && added.is_empty() {
1171 progress(0, 0);
1172 return Ok(RefreshSummary {
1173 total_processed,
1174 ..RefreshSummary::default()
1175 });
1176 }
1177
1178 if !deleted.is_empty() {
1182 let deleted_set: HashSet<&Path> = deleted.iter().map(PathBuf::as_path).collect();
1183 self.entries
1184 .retain(|entry| !deleted_set.contains(entry.chunk.file.as_path()));
1185 for path in &deleted {
1186 self.file_mtimes.remove(path);
1187 self.file_sizes.remove(path);
1188 }
1189 }
1190
1191 let mut to_embed: Vec<PathBuf> = Vec::with_capacity(changed.len() + added.len());
1193 to_embed.extend(changed.iter().cloned());
1194 to_embed.extend(added.iter().cloned());
1195
1196 if to_embed.is_empty() {
1197 progress(0, 0);
1199 return Ok(RefreshSummary {
1200 changed: 0,
1201 added: 0,
1202 deleted: deleted.len(),
1203 total_processed,
1204 });
1205 }
1206
1207 let (chunks, fresh_metadata) = Self::collect_chunks(project_root, &to_embed);
1208
1209 if chunks.is_empty() {
1210 progress(0, 0);
1211 let successful_files: HashSet<PathBuf> = fresh_metadata.keys().cloned().collect();
1212 if !successful_files.is_empty() {
1213 self.entries
1214 .retain(|entry| !successful_files.contains(&entry.chunk.file));
1215 }
1216 let changed_count = changed
1217 .iter()
1218 .filter(|path| successful_files.contains(*path))
1219 .count();
1220 let added_count = added
1221 .iter()
1222 .filter(|path| successful_files.contains(*path))
1223 .count();
1224 for (file, metadata) in fresh_metadata {
1225 self.file_mtimes.insert(file.clone(), metadata.mtime);
1226 self.file_sizes.insert(file, metadata.size);
1227 }
1228 return Ok(RefreshSummary {
1229 changed: changed_count,
1230 added: added_count,
1231 deleted: deleted.len(),
1232 total_processed,
1233 });
1234 }
1235
1236 let total_chunks = chunks.len();
1238 progress(0, total_chunks);
1239 let batch_size = max_batch_size.max(1);
1240 let existing_dimension = if self.entries.is_empty() {
1241 None
1242 } else {
1243 Some(self.dimension)
1244 };
1245 let mut new_entries: Vec<EmbeddingEntry> = Vec::with_capacity(chunks.len());
1246 let mut observed_dimension: Option<usize> = existing_dimension;
1247
1248 for batch_start in (0..chunks.len()).step_by(batch_size) {
1249 let batch_end = (batch_start + batch_size).min(chunks.len());
1250 let batch_texts: Vec<String> = chunks[batch_start..batch_end]
1251 .iter()
1252 .map(|c| c.embed_text.clone())
1253 .collect();
1254
1255 let vectors = embed_fn(batch_texts)?;
1256 validate_embedding_batch(&vectors, batch_end - batch_start, "embedding backend")?;
1257
1258 if let Some(dim) = vectors.first().map(|v| v.len()) {
1259 match observed_dimension {
1260 None => observed_dimension = Some(dim),
1261 Some(expected) if dim != expected => {
1262 return Err(format!(
1265 "embedding dimension changed during incremental refresh: \
1266 cached index uses {expected}, new vectors use {dim}"
1267 ));
1268 }
1269 _ => {}
1270 }
1271 }
1272
1273 for (i, vector) in vectors.into_iter().enumerate() {
1274 let chunk_idx = batch_start + i;
1275 new_entries.push(EmbeddingEntry {
1276 chunk: chunks[chunk_idx].clone(),
1277 vector,
1278 });
1279 }
1280
1281 progress(new_entries.len(), total_chunks);
1282 }
1283
1284 let successful_files: HashSet<PathBuf> = fresh_metadata.keys().cloned().collect();
1285 if !successful_files.is_empty() {
1286 self.entries
1287 .retain(|entry| !successful_files.contains(&entry.chunk.file));
1288 }
1289
1290 self.entries.extend(new_entries);
1291 for (file, metadata) in fresh_metadata {
1292 self.file_mtimes.insert(file.clone(), metadata.mtime);
1293 self.file_sizes.insert(file, metadata.size);
1294 }
1295 if let Some(dim) = observed_dimension {
1296 self.dimension = dim;
1297 }
1298
1299 Ok(RefreshSummary {
1300 changed: changed
1301 .iter()
1302 .filter(|path| successful_files.contains(*path))
1303 .count(),
1304 added: added
1305 .iter()
1306 .filter(|path| successful_files.contains(*path))
1307 .count(),
1308 deleted: deleted.len(),
1309 total_processed,
1310 })
1311 }
1312
1313 pub fn search(&self, query_vector: &[f32], top_k: usize) -> Vec<SemanticResult> {
1315 if self.entries.is_empty() || query_vector.len() != self.dimension {
1316 return Vec::new();
1317 }
1318
1319 let mut scored: Vec<(f32, usize)> = self
1320 .entries
1321 .iter()
1322 .enumerate()
1323 .map(|(i, entry)| (cosine_similarity(query_vector, &entry.vector), i))
1324 .collect();
1325
1326 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
1328
1329 scored
1330 .into_iter()
1331 .take(top_k)
1332 .filter(|(score, _)| *score > 0.0)
1333 .map(|(score, idx)| {
1334 let entry = &self.entries[idx];
1335 SemanticResult {
1336 file: entry.chunk.file.clone(),
1337 name: entry.chunk.name.clone(),
1338 kind: entry.chunk.kind.clone(),
1339 start_line: entry.chunk.start_line,
1340 end_line: entry.chunk.end_line,
1341 exported: entry.chunk.exported,
1342 snippet: entry.chunk.snippet.clone(),
1343 score,
1344 }
1345 })
1346 .collect()
1347 }
1348
1349 pub fn len(&self) -> usize {
1351 self.entries.len()
1352 }
1353
1354 pub fn is_file_stale(&self, file: &Path) -> bool {
1356 let Some(stored_mtime) = self.file_mtimes.get(file) else {
1357 return true;
1358 };
1359 let Some(stored_size) = self.file_sizes.get(file) else {
1360 return true;
1361 };
1362 match collect_file_metadata(file) {
1363 Ok(current) => *stored_mtime != current.mtime || *stored_size != current.size,
1364 Err(_) => true,
1365 }
1366 }
1367
1368 fn backfill_missing_file_sizes(&mut self) {
1369 for path in self.file_mtimes.keys() {
1370 if self.file_sizes.contains_key(path) {
1371 continue;
1372 }
1373 if let Ok(metadata) = fs::metadata(path) {
1374 self.file_sizes.insert(path.clone(), metadata.len());
1375 }
1376 }
1377 }
1378
1379 pub fn remove_file(&mut self, file: &Path) {
1381 self.invalidate_file(file);
1382 }
1383
1384 pub fn invalidate_file(&mut self, file: &Path) {
1385 self.entries.retain(|e| e.chunk.file != file);
1386 self.file_mtimes.remove(file);
1387 self.file_sizes.remove(file);
1388 }
1389
1390 pub fn dimension(&self) -> usize {
1392 self.dimension
1393 }
1394
1395 pub fn fingerprint(&self) -> Option<&SemanticIndexFingerprint> {
1396 self.fingerprint.as_ref()
1397 }
1398
1399 pub fn backend_label(&self) -> Option<&str> {
1400 self.fingerprint.as_ref().map(|f| f.backend.as_str())
1401 }
1402
1403 pub fn model_label(&self) -> Option<&str> {
1404 self.fingerprint.as_ref().map(|f| f.model.as_str())
1405 }
1406
1407 pub fn set_fingerprint(&mut self, fingerprint: SemanticIndexFingerprint) {
1408 self.fingerprint = Some(fingerprint);
1409 }
1410
1411 pub fn write_to_disk(&self, storage_dir: &Path, project_key: &str) {
1413 if self.entries.is_empty() {
1416 slog_info!("skipping semantic index persistence (0 entries)");
1417 return;
1418 }
1419 let dir = storage_dir.join("semantic").join(project_key);
1420 if let Err(e) = fs::create_dir_all(&dir) {
1421 slog_warn!("failed to create semantic cache dir: {}", e);
1422 return;
1423 }
1424 let data_path = dir.join("semantic.bin");
1425 let tmp_path = dir.join(format!(
1426 "semantic.bin.tmp.{}.{}",
1427 std::process::id(),
1428 SystemTime::now()
1429 .duration_since(SystemTime::UNIX_EPOCH)
1430 .unwrap_or(Duration::ZERO)
1431 .as_nanos()
1432 ));
1433 let bytes = self.to_bytes();
1434 let write_result = (|| -> std::io::Result<()> {
1435 use std::io::Write;
1436 let mut file = fs::File::create(&tmp_path)?;
1437 file.write_all(&bytes)?;
1438 file.sync_all()?;
1439 Ok(())
1440 })();
1441 if let Err(e) = write_result {
1442 slog_warn!("failed to write semantic index: {}", e);
1443 let _ = fs::remove_file(&tmp_path);
1444 return;
1445 }
1446 if let Err(e) = fs::rename(&tmp_path, &data_path) {
1447 slog_warn!("failed to rename semantic index: {}", e);
1448 let _ = fs::remove_file(&tmp_path);
1449 return;
1450 }
1451 slog_info!(
1452 "semantic index persisted: {} entries, {:.1} KB",
1453 self.entries.len(),
1454 bytes.len() as f64 / 1024.0
1455 );
1456 }
1457
1458 pub fn read_from_disk(
1460 storage_dir: &Path,
1461 project_key: &str,
1462 expected_fingerprint: Option<&str>,
1463 ) -> Option<Self> {
1464 let data_path = storage_dir
1465 .join("semantic")
1466 .join(project_key)
1467 .join("semantic.bin");
1468 let file_len = usize::try_from(fs::metadata(&data_path).ok()?.len()).ok()?;
1469 if file_len < HEADER_BYTES_V1 {
1470 slog_warn!(
1471 "corrupt semantic index (too small: {} bytes), removing",
1472 file_len
1473 );
1474 let _ = fs::remove_file(&data_path);
1475 return None;
1476 }
1477
1478 let bytes = fs::read(&data_path).ok()?;
1479 let version = bytes[0];
1480 if version != SEMANTIC_INDEX_VERSION_V5 {
1481 slog_info!(
1482 "cached semantic index version {} is older than {}, rebuilding",
1483 version,
1484 SEMANTIC_INDEX_VERSION_V5
1485 );
1486 let _ = fs::remove_file(&data_path);
1487 return None;
1488 }
1489 match Self::from_bytes(&bytes) {
1490 Ok(index) => {
1491 if index.entries.is_empty() {
1492 slog_info!("cached semantic index is empty, will rebuild");
1493 let _ = fs::remove_file(&data_path);
1494 return None;
1495 }
1496 if let Some(expected) = expected_fingerprint {
1497 let matches = index
1498 .fingerprint()
1499 .map(|fingerprint| fingerprint.matches_expected(expected))
1500 .unwrap_or(false);
1501 if !matches {
1502 slog_info!("cached semantic index fingerprint mismatch, rebuilding");
1503 let _ = fs::remove_file(&data_path);
1504 return None;
1505 }
1506 }
1507 slog_info!(
1508 "loaded semantic index from disk: {} entries",
1509 index.entries.len()
1510 );
1511 Some(index)
1512 }
1513 Err(e) => {
1514 slog_warn!("corrupt semantic index, rebuilding: {}", e);
1515 let _ = fs::remove_file(&data_path);
1516 None
1517 }
1518 }
1519 }
1520
1521 pub fn to_bytes(&self) -> Vec<u8> {
1523 let mut buf = Vec::new();
1524 let fingerprint_bytes = self.fingerprint.as_ref().and_then(|fingerprint| {
1525 let encoded = fingerprint.as_string();
1526 if encoded.is_empty() {
1527 None
1528 } else {
1529 Some(encoded.into_bytes())
1530 }
1531 });
1532
1533 let version = SEMANTIC_INDEX_VERSION_V5;
1545 buf.push(version);
1546 buf.extend_from_slice(&(self.dimension as u32).to_le_bytes());
1547 buf.extend_from_slice(&(self.entries.len() as u32).to_le_bytes());
1548 let fp_bytes_ref: &[u8] = fingerprint_bytes.as_deref().unwrap_or(&[]);
1549 buf.extend_from_slice(&(fp_bytes_ref.len() as u32).to_le_bytes());
1550 buf.extend_from_slice(fp_bytes_ref);
1551
1552 buf.extend_from_slice(&(self.file_mtimes.len() as u32).to_le_bytes());
1555 for (path, mtime) in &self.file_mtimes {
1556 let path_bytes = path.to_string_lossy().as_bytes().to_vec();
1557 buf.extend_from_slice(&(path_bytes.len() as u32).to_le_bytes());
1558 buf.extend_from_slice(&path_bytes);
1559 let duration = mtime
1560 .duration_since(SystemTime::UNIX_EPOCH)
1561 .unwrap_or_default();
1562 buf.extend_from_slice(&duration.as_secs().to_le_bytes());
1563 buf.extend_from_slice(&duration.subsec_nanos().to_le_bytes());
1564 let size = self.file_sizes.get(path).copied().unwrap_or_default();
1565 buf.extend_from_slice(&size.to_le_bytes());
1566 }
1567
1568 for entry in &self.entries {
1570 let c = &entry.chunk;
1571
1572 let file_bytes = c.file.to_string_lossy().as_bytes().to_vec();
1574 buf.extend_from_slice(&(file_bytes.len() as u32).to_le_bytes());
1575 buf.extend_from_slice(&file_bytes);
1576
1577 let name_bytes = c.name.as_bytes();
1579 buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
1580 buf.extend_from_slice(name_bytes);
1581
1582 buf.push(symbol_kind_to_u8(&c.kind));
1584
1585 buf.extend_from_slice(&(c.start_line as u32).to_le_bytes());
1587 buf.extend_from_slice(&(c.end_line as u32).to_le_bytes());
1588 buf.push(c.exported as u8);
1589
1590 let snippet_bytes = c.snippet.as_bytes();
1592 buf.extend_from_slice(&(snippet_bytes.len() as u32).to_le_bytes());
1593 buf.extend_from_slice(snippet_bytes);
1594
1595 let embed_bytes = c.embed_text.as_bytes();
1597 buf.extend_from_slice(&(embed_bytes.len() as u32).to_le_bytes());
1598 buf.extend_from_slice(embed_bytes);
1599
1600 for &val in &entry.vector {
1602 buf.extend_from_slice(&val.to_le_bytes());
1603 }
1604 }
1605
1606 buf
1607 }
1608
1609 pub fn from_bytes(data: &[u8]) -> Result<Self, String> {
1611 let mut pos = 0;
1612
1613 if data.len() < HEADER_BYTES_V1 {
1614 return Err("data too short".to_string());
1615 }
1616
1617 let version = data[pos];
1618 pos += 1;
1619 if version != SEMANTIC_INDEX_VERSION_V1
1620 && version != SEMANTIC_INDEX_VERSION_V2
1621 && version != SEMANTIC_INDEX_VERSION_V3
1622 && version != SEMANTIC_INDEX_VERSION_V4
1623 && version != SEMANTIC_INDEX_VERSION_V5
1624 {
1625 return Err(format!("unsupported version: {}", version));
1626 }
1627 if (version == SEMANTIC_INDEX_VERSION_V2
1631 || version == SEMANTIC_INDEX_VERSION_V3
1632 || version == SEMANTIC_INDEX_VERSION_V4
1633 || version == SEMANTIC_INDEX_VERSION_V5)
1634 && data.len() < HEADER_BYTES_V2
1635 {
1636 return Err("data too short for semantic index v2/v3/v4/v5 header".to_string());
1637 }
1638
1639 let dimension = read_u32(data, &mut pos)? as usize;
1640 let entry_count = read_u32(data, &mut pos)? as usize;
1641 if dimension == 0 || dimension > MAX_DIMENSION {
1642 return Err(format!("invalid embedding dimension: {}", dimension));
1643 }
1644 if entry_count > MAX_ENTRIES {
1645 return Err(format!("too many semantic index entries: {}", entry_count));
1646 }
1647
1648 let has_fingerprint_field = version == SEMANTIC_INDEX_VERSION_V2
1654 || version == SEMANTIC_INDEX_VERSION_V3
1655 || version == SEMANTIC_INDEX_VERSION_V4
1656 || version == SEMANTIC_INDEX_VERSION_V5;
1657 let fingerprint = if has_fingerprint_field {
1658 let fingerprint_len = read_u32(data, &mut pos)? as usize;
1659 if pos + fingerprint_len > data.len() {
1660 return Err("unexpected end of data reading fingerprint".to_string());
1661 }
1662 if fingerprint_len == 0 {
1663 None
1664 } else {
1665 let raw = String::from_utf8_lossy(&data[pos..pos + fingerprint_len]).to_string();
1666 pos += fingerprint_len;
1667 Some(
1668 serde_json::from_str::<SemanticIndexFingerprint>(&raw)
1669 .map_err(|error| format!("invalid semantic fingerprint: {error}"))?,
1670 )
1671 }
1672 } else {
1673 None
1674 };
1675
1676 let mtime_count = read_u32(data, &mut pos)? as usize;
1678 if mtime_count > MAX_ENTRIES {
1679 return Err(format!("too many semantic file mtimes: {}", mtime_count));
1680 }
1681
1682 let vector_bytes = entry_count
1683 .checked_mul(dimension)
1684 .and_then(|count| count.checked_mul(F32_BYTES))
1685 .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
1686 if vector_bytes > data.len().saturating_sub(pos) {
1687 return Err("semantic index vectors exceed available data".to_string());
1688 }
1689
1690 let mut file_mtimes = HashMap::with_capacity(mtime_count);
1691 let mut file_sizes = HashMap::with_capacity(mtime_count);
1692 for _ in 0..mtime_count {
1693 let path = read_string(data, &mut pos)?;
1694 let secs = read_u64(data, &mut pos)?;
1695 let nanos = if version == SEMANTIC_INDEX_VERSION_V3
1701 || version == SEMANTIC_INDEX_VERSION_V4
1702 || version == SEMANTIC_INDEX_VERSION_V5
1703 {
1704 read_u32(data, &mut pos)?
1705 } else {
1706 0
1707 };
1708 let size = if version == SEMANTIC_INDEX_VERSION_V5 {
1709 read_u64(data, &mut pos)?
1710 } else {
1711 0
1712 };
1713 if nanos >= 1_000_000_000 {
1720 return Err(format!(
1721 "invalid semantic mtime: nanos {} >= 1_000_000_000",
1722 nanos
1723 ));
1724 }
1725 let duration = std::time::Duration::new(secs, nanos);
1726 let mtime = SystemTime::UNIX_EPOCH
1727 .checked_add(duration)
1728 .ok_or_else(|| {
1729 format!(
1730 "invalid semantic mtime: secs={} nanos={} overflows SystemTime",
1731 secs, nanos
1732 )
1733 })?;
1734 let path = PathBuf::from(path);
1735 file_mtimes.insert(path.clone(), mtime);
1736 file_sizes.insert(path, size);
1737 }
1738
1739 let mut entries = Vec::with_capacity(entry_count);
1741 for _ in 0..entry_count {
1742 let file = PathBuf::from(read_string(data, &mut pos)?);
1743 let name = read_string(data, &mut pos)?;
1744
1745 if pos >= data.len() {
1746 return Err("unexpected end of data".to_string());
1747 }
1748 let kind = u8_to_symbol_kind(data[pos]);
1749 pos += 1;
1750
1751 let start_line = read_u32(data, &mut pos)?;
1752 let end_line = read_u32(data, &mut pos)?;
1753
1754 if pos >= data.len() {
1755 return Err("unexpected end of data".to_string());
1756 }
1757 let exported = data[pos] != 0;
1758 pos += 1;
1759
1760 let snippet = read_string(data, &mut pos)?;
1761 let embed_text = read_string(data, &mut pos)?;
1762
1763 let vec_bytes = dimension
1765 .checked_mul(F32_BYTES)
1766 .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
1767 if pos + vec_bytes > data.len() {
1768 return Err("unexpected end of data reading vector".to_string());
1769 }
1770 let mut vector = Vec::with_capacity(dimension);
1771 for _ in 0..dimension {
1772 let bytes = [data[pos], data[pos + 1], data[pos + 2], data[pos + 3]];
1773 vector.push(f32::from_le_bytes(bytes));
1774 pos += 4;
1775 }
1776
1777 entries.push(EmbeddingEntry {
1778 chunk: SemanticChunk {
1779 file,
1780 name,
1781 kind,
1782 start_line,
1783 end_line,
1784 exported,
1785 embed_text,
1786 snippet,
1787 },
1788 vector,
1789 });
1790 }
1791
1792 if entries.len() != entry_count {
1793 return Err(format!(
1794 "semantic cache entry count drift: header={} decoded={}",
1795 entry_count,
1796 entries.len()
1797 ));
1798 }
1799 for entry in &entries {
1800 if !file_mtimes.contains_key(&entry.chunk.file) {
1801 return Err(format!(
1802 "semantic cache metadata missing for entry file {}",
1803 entry.chunk.file.display()
1804 ));
1805 }
1806 }
1807
1808 Ok(Self {
1809 entries,
1810 file_mtimes,
1811 file_sizes,
1812 dimension,
1813 fingerprint,
1814 })
1815 }
1816}
1817
1818fn build_embed_text(symbol: &Symbol, source: &str, file: &Path, project_root: &Path) -> String {
1820 let relative = file
1821 .strip_prefix(project_root)
1822 .unwrap_or(file)
1823 .to_string_lossy();
1824
1825 let kind_label = match &symbol.kind {
1826 SymbolKind::Function => "function",
1827 SymbolKind::Class => "class",
1828 SymbolKind::Method => "method",
1829 SymbolKind::Struct => "struct",
1830 SymbolKind::Interface => "interface",
1831 SymbolKind::Enum => "enum",
1832 SymbolKind::TypeAlias => "type",
1833 SymbolKind::Variable => "variable",
1834 SymbolKind::Heading => "heading",
1835 };
1836
1837 let mut text = format!("file:{} kind:{} name:{}", relative, kind_label, symbol.name);
1839
1840 if let Some(sig) = &symbol.signature {
1841 text.push_str(&format!(" signature:{}", sig));
1842 }
1843
1844 let lines: Vec<&str> = source.lines().collect();
1846 let start = (symbol.range.start_line as usize).min(lines.len());
1847 let end = (symbol.range.end_line as usize + 1).min(lines.len());
1849 if start < end {
1850 let body: String = lines[start..end]
1851 .iter()
1852 .take(15) .copied()
1854 .collect::<Vec<&str>>()
1855 .join("\n");
1856 let snippet = if body.len() > 300 {
1857 format!("{}...", &body[..body.floor_char_boundary(300)])
1858 } else {
1859 body
1860 };
1861 text.push_str(&format!(" body:{}", snippet));
1862 }
1863
1864 text
1865}
1866
1867fn parser_for(
1868 parsers: &mut HashMap<crate::parser::LangId, Parser>,
1869 lang: crate::parser::LangId,
1870) -> Result<&mut Parser, String> {
1871 use std::collections::hash_map::Entry;
1872
1873 match parsers.entry(lang) {
1874 Entry::Occupied(entry) => Ok(entry.into_mut()),
1875 Entry::Vacant(entry) => {
1876 let grammar = grammar_for(lang);
1877 let mut parser = Parser::new();
1878 parser
1879 .set_language(&grammar)
1880 .map_err(|error| error.to_string())?;
1881 Ok(entry.insert(parser))
1882 }
1883 }
1884}
1885
1886fn collect_file_metadata(file: &Path) -> Result<IndexedFileMetadata, String> {
1887 let metadata = fs::metadata(file).map_err(|error| error.to_string())?;
1888 let mtime = metadata.modified().map_err(|error| error.to_string())?;
1889 Ok(IndexedFileMetadata {
1890 mtime,
1891 size: metadata.len(),
1892 })
1893}
1894
1895fn collect_file_chunks(
1896 project_root: &Path,
1897 file: &Path,
1898 parsers: &mut HashMap<crate::parser::LangId, Parser>,
1899) -> Result<Vec<SemanticChunk>, String> {
1900 let lang = detect_language(file).ok_or_else(|| "unsupported file extension".to_string())?;
1901 let source = std::fs::read_to_string(file).map_err(|error| error.to_string())?;
1902 let tree = parser_for(parsers, lang)?
1903 .parse(&source, None)
1904 .ok_or_else(|| format!("tree-sitter parse returned None for {}", file.display()))?;
1905 let symbols =
1906 extract_symbols_from_tree(&source, &tree, lang).map_err(|error| error.to_string())?;
1907
1908 Ok(symbols_to_chunks(file, &symbols, &source, project_root))
1909}
1910
1911fn build_snippet(symbol: &Symbol, source: &str) -> String {
1913 let lines: Vec<&str> = source.lines().collect();
1914 let start = (symbol.range.start_line as usize).min(lines.len());
1915 let end = (symbol.range.end_line as usize + 1).min(lines.len());
1917 if start < end {
1918 let snippet_lines: Vec<&str> = lines[start..end].iter().take(5).copied().collect();
1919 let mut snippet = snippet_lines.join("\n");
1920 if end - start > 5 {
1921 snippet.push_str("\n ...");
1922 }
1923 if snippet.len() > 300 {
1924 snippet = format!("{}...", &snippet[..snippet.floor_char_boundary(300)]);
1925 }
1926 snippet
1927 } else {
1928 String::new()
1929 }
1930}
1931
1932fn symbols_to_chunks(
1934 file: &Path,
1935 symbols: &[Symbol],
1936 source: &str,
1937 project_root: &Path,
1938) -> Vec<SemanticChunk> {
1939 let mut chunks = Vec::new();
1940
1941 for symbol in symbols {
1942 if matches!(symbol.kind, SymbolKind::Heading) {
1947 continue;
1948 }
1949
1950 let line_count = symbol
1952 .range
1953 .end_line
1954 .saturating_sub(symbol.range.start_line)
1955 + 1;
1956 if line_count < 2 && !matches!(symbol.kind, SymbolKind::Variable) {
1957 continue;
1958 }
1959
1960 let embed_text = build_embed_text(symbol, source, file, project_root);
1961 let snippet = build_snippet(symbol, source);
1962
1963 chunks.push(SemanticChunk {
1964 file: file.to_path_buf(),
1965 name: symbol.name.clone(),
1966 kind: symbol.kind.clone(),
1967 start_line: symbol.range.start_line,
1968 end_line: symbol.range.end_line,
1969 exported: symbol.exported,
1970 embed_text,
1971 snippet,
1972 });
1973
1974 }
1977
1978 chunks
1979}
1980
/// Cosine similarity of two vectors.
///
/// Returns 0.0 when the lengths differ or when either vector has zero magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Accumulate the dot product and both squared norms in a single in-order pass.
    let (dot, sq_norm_a, sq_norm_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let denom = sq_norm_a.sqrt() * sq_norm_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
2004
2005fn symbol_kind_to_u8(kind: &SymbolKind) -> u8 {
2007 match kind {
2008 SymbolKind::Function => 0,
2009 SymbolKind::Class => 1,
2010 SymbolKind::Method => 2,
2011 SymbolKind::Struct => 3,
2012 SymbolKind::Interface => 4,
2013 SymbolKind::Enum => 5,
2014 SymbolKind::TypeAlias => 6,
2015 SymbolKind::Variable => 7,
2016 SymbolKind::Heading => 8,
2017 }
2018}
2019
2020fn u8_to_symbol_kind(v: u8) -> SymbolKind {
2021 match v {
2022 0 => SymbolKind::Function,
2023 1 => SymbolKind::Class,
2024 2 => SymbolKind::Method,
2025 3 => SymbolKind::Struct,
2026 4 => SymbolKind::Interface,
2027 5 => SymbolKind::Enum,
2028 6 => SymbolKind::TypeAlias,
2029 7 => SymbolKind::Variable,
2030 _ => SymbolKind::Heading,
2031 }
2032}
2033
/// Read a little-endian `u32` at `*pos` and advance `pos` by 4.
///
/// # Errors
/// Returns an error, leaving `pos` untouched, if fewer than 4 bytes remain.
fn read_u32(data: &[u8], pos: &mut usize) -> Result<u32, String> {
    let start = *pos;
    match data.get(start..start + 4) {
        Some(&[b0, b1, b2, b3]) => {
            *pos = start + 4;
            Ok(u32::from_le_bytes([b0, b1, b2, b3]))
        }
        _ => Err("unexpected end of data reading u32".to_string()),
    }
}
2042
/// Read a little-endian `u64` at `*pos` and advance `pos` by 8.
///
/// # Errors
/// Returns an error, leaving `pos` untouched, if fewer than 8 bytes remain.
fn read_u64(data: &[u8], pos: &mut usize) -> Result<u64, String> {
    let start = *pos;
    let end = start + 8;
    let raw: [u8; 8] = data
        .get(start..end)
        .and_then(|slice| slice.try_into().ok())
        .ok_or_else(|| "unexpected end of data reading u64".to_string())?;
    *pos = end;
    Ok(u64::from_le_bytes(raw))
}
2051
2052fn read_string(data: &[u8], pos: &mut usize) -> Result<String, String> {
2053 let len = read_u32(data, pos)? as usize;
2054 if *pos + len > data.len() {
2055 return Err("unexpected end of data reading string".to_string());
2056 }
2057 let s = String::from_utf8_lossy(&data[*pos..*pos + len]).to_string();
2058 *pos += len;
2059 Ok(s)
2060}
2061
2062#[cfg(test)]
2063mod tests {
2064 use super::*;
2065 use crate::config::{SemanticBackend, SemanticBackendConfig};
2066 use crate::parser::FileParser;
2067 use std::io::{Read, Write};
2068 use std::net::TcpListener;
2069 use std::thread;
2070
2071 fn start_mock_http_server<F>(handler: F) -> (String, thread::JoinHandle<()>)
2072 where
2073 F: Fn(String, String, String) -> String + Send + 'static,
2074 {
2075 let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
2076 let addr = listener.local_addr().expect("local addr");
2077 let handle = thread::spawn(move || {
2078 let (mut stream, _) = listener.accept().expect("accept request");
2079 let mut buf = Vec::new();
2080 let mut chunk = [0u8; 4096];
2081 let mut header_end = None;
2082 let mut content_length = 0usize;
2083 loop {
2084 let n = stream.read(&mut chunk).expect("read request");
2085 if n == 0 {
2086 break;
2087 }
2088 buf.extend_from_slice(&chunk[..n]);
2089 if header_end.is_none() {
2090 if let Some(pos) = buf.windows(4).position(|window| window == b"\r\n\r\n") {
2091 header_end = Some(pos + 4);
2092 let headers = String::from_utf8_lossy(&buf[..pos + 4]);
2093 for line in headers.lines() {
2094 if let Some(value) = line.strip_prefix("Content-Length:") {
2095 content_length = value.trim().parse::<usize>().unwrap_or(0);
2096 }
2097 }
2098 }
2099 }
2100 if let Some(end) = header_end {
2101 if buf.len() >= end + content_length {
2102 break;
2103 }
2104 }
2105 }
2106
2107 let end = header_end.expect("header terminator");
2108 let request = String::from_utf8_lossy(&buf[..end]).to_string();
2109 let body = String::from_utf8_lossy(&buf[end..end + content_length]).to_string();
2110 let mut lines = request.lines();
2111 let request_line = lines.next().expect("request line").to_string();
2112 let path = request_line
2113 .split_whitespace()
2114 .nth(1)
2115 .expect("request path")
2116 .to_string();
2117 let response_body = handler(request_line, path, body);
2118 let response = format!(
2119 "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
2120 response_body.len(),
2121 response_body
2122 );
2123 stream
2124 .write_all(response.as_bytes())
2125 .expect("write response");
2126 });
2127
2128 (format!("http://{}", addr), handle)
2129 }
2130
2131 fn test_vector_for_texts(texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
2132 Ok(texts.iter().map(|_| vec![1.0, 0.0, 0.0]).collect())
2133 }
2134
2135 fn write_rust_file(path: &Path, function_name: &str) {
2136 fs::write(
2137 path,
2138 format!("pub fn {function_name}() -> bool {{\n true\n}}\n"),
2139 )
2140 .unwrap();
2141 }
2142
2143 fn build_test_index(project_root: &Path, files: &[PathBuf]) -> SemanticIndex {
2144 let mut embed = test_vector_for_texts;
2145 SemanticIndex::build(project_root, files, &mut embed, 8).unwrap()
2146 }
2147
2148 fn set_file_metadata(index: &mut SemanticIndex, file: &Path, mtime: SystemTime, size: u64) {
2149 index.file_mtimes.insert(file.to_path_buf(), mtime);
2150 index.file_sizes.insert(file.to_path_buf(), size);
2151 }
2152
2153 #[test]
2154 fn test_cosine_similarity_identical() {
2155 let a = vec![1.0, 0.0, 0.0];
2156 let b = vec![1.0, 0.0, 0.0];
2157 assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
2158 }
2159
2160 #[test]
2161 fn test_cosine_similarity_orthogonal() {
2162 let a = vec![1.0, 0.0, 0.0];
2163 let b = vec![0.0, 1.0, 0.0];
2164 assert!(cosine_similarity(&a, &b).abs() < 0.001);
2165 }
2166
2167 #[test]
2168 fn test_cosine_similarity_opposite() {
2169 let a = vec![1.0, 0.0, 0.0];
2170 let b = vec![-1.0, 0.0, 0.0];
2171 assert!((cosine_similarity(&a, &b) + 1.0).abs() < 0.001);
2172 }
2173
2174 #[test]
2175 fn test_serialization_roundtrip() {
2176 let mut index = SemanticIndex::new();
2177 index.entries.push(EmbeddingEntry {
2178 chunk: SemanticChunk {
2179 file: PathBuf::from("/src/main.rs"),
2180 name: "handle_request".to_string(),
2181 kind: SymbolKind::Function,
2182 start_line: 10,
2183 end_line: 25,
2184 exported: true,
2185 embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
2186 snippet: "fn handle_request() {\n // ...\n}".to_string(),
2187 },
2188 vector: vec![0.1, 0.2, 0.3, 0.4],
2189 });
2190 index.dimension = 4;
2191 index
2192 .file_mtimes
2193 .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
2194 index.file_sizes.insert(PathBuf::from("/src/main.rs"), 0);
2195 index.set_fingerprint(SemanticIndexFingerprint {
2196 backend: "fastembed".to_string(),
2197 model: "all-MiniLM-L6-v2".to_string(),
2198 base_url: FALLBACK_BACKEND.to_string(),
2199 dimension: 4,
2200 });
2201
2202 let bytes = index.to_bytes();
2203 let restored = SemanticIndex::from_bytes(&bytes).unwrap();
2204
2205 assert_eq!(restored.entries.len(), 1);
2206 assert_eq!(restored.entries[0].chunk.name, "handle_request");
2207 assert_eq!(restored.entries[0].vector, vec![0.1, 0.2, 0.3, 0.4]);
2208 assert_eq!(restored.dimension, 4);
2209 assert_eq!(restored.backend_label(), Some("fastembed"));
2210 assert_eq!(restored.model_label(), Some("all-MiniLM-L6-v2"));
2211 }
2212
2213 #[test]
2214 fn test_search_top_k() {
2215 let mut index = SemanticIndex::new();
2216 index.dimension = 3;
2217
2218 for (i, name) in ["auth", "database", "handler"].iter().enumerate() {
2220 let mut vec = vec![0.0f32; 3];
2221 vec[i] = 1.0; index.entries.push(EmbeddingEntry {
2223 chunk: SemanticChunk {
2224 file: PathBuf::from("/src/lib.rs"),
2225 name: name.to_string(),
2226 kind: SymbolKind::Function,
2227 start_line: (i * 10 + 1) as u32,
2228 end_line: (i * 10 + 5) as u32,
2229 exported: true,
2230 embed_text: format!("kind:function name:{}", name),
2231 snippet: format!("fn {}() {{}}", name),
2232 },
2233 vector: vec,
2234 });
2235 }
2236
2237 let query = vec![0.9, 0.1, 0.0];
2239 let results = index.search(&query, 2);
2240
2241 assert_eq!(results.len(), 2);
2242 assert_eq!(results[0].name, "auth"); assert!(results[0].score > results[1].score);
2244 }
2245
2246 #[test]
2247 fn test_empty_index_search() {
2248 let index = SemanticIndex::new();
2249 let results = index.search(&[0.1, 0.2, 0.3], 10);
2250 assert!(results.is_empty());
2251 }
2252
2253 #[test]
2254 fn single_line_symbol_builds_non_empty_snippet() {
2255 let symbol = Symbol {
2256 name: "answer".to_string(),
2257 kind: SymbolKind::Variable,
2258 range: crate::symbols::Range {
2259 start_line: 0,
2260 start_col: 0,
2261 end_line: 0,
2262 end_col: 24,
2263 },
2264 signature: Some("const answer = 42".to_string()),
2265 scope_chain: Vec::new(),
2266 exported: true,
2267 parent: None,
2268 };
2269 let source = "export const answer = 42;\n";
2270
2271 let snippet = build_snippet(&symbol, source);
2272
2273 assert_eq!(snippet, "export const answer = 42;");
2274 }
2275
2276 #[test]
2277 fn optimized_file_chunk_collection_matches_file_parser_path() {
2278 let project_root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
2279 let file = project_root.join("src/semantic_index.rs");
2280 let source = std::fs::read_to_string(&file).unwrap();
2281
2282 let mut legacy_parser = FileParser::new();
2283 let legacy_symbols = legacy_parser.extract_symbols(&file).unwrap();
2284 let legacy_chunks = symbols_to_chunks(&file, &legacy_symbols, &source, &project_root);
2285
2286 let mut parsers = HashMap::new();
2287 let optimized_chunks = collect_file_chunks(&project_root, &file, &mut parsers).unwrap();
2288
2289 assert_eq!(
2290 chunk_fingerprint(&optimized_chunks),
2291 chunk_fingerprint(&legacy_chunks)
2292 );
2293 }
2294
2295 fn chunk_fingerprint(
2296 chunks: &[SemanticChunk],
2297 ) -> Vec<(String, SymbolKind, u32, u32, bool, String, String)> {
2298 chunks
2299 .iter()
2300 .map(|chunk| {
2301 (
2302 chunk.name.clone(),
2303 chunk.kind.clone(),
2304 chunk.start_line,
2305 chunk.end_line,
2306 chunk.exported,
2307 chunk.embed_text.clone(),
2308 chunk.snippet.clone(),
2309 )
2310 })
2311 .collect()
2312 }
2313
2314 #[test]
2315 fn rejects_oversized_dimension_during_deserialization() {
2316 let mut bytes = Vec::new();
2317 bytes.push(1u8);
2318 bytes.extend_from_slice(&((MAX_DIMENSION as u32) + 1).to_le_bytes());
2319 bytes.extend_from_slice(&0u32.to_le_bytes());
2320 bytes.extend_from_slice(&0u32.to_le_bytes());
2321
2322 assert!(SemanticIndex::from_bytes(&bytes).is_err());
2323 }
2324
2325 #[test]
2326 fn rejects_oversized_entry_count_during_deserialization() {
2327 let mut bytes = Vec::new();
2328 bytes.push(1u8);
2329 bytes.extend_from_slice(&(DEFAULT_DIMENSION as u32).to_le_bytes());
2330 bytes.extend_from_slice(&((MAX_ENTRIES as u32) + 1).to_le_bytes());
2331 bytes.extend_from_slice(&0u32.to_le_bytes());
2332
2333 assert!(SemanticIndex::from_bytes(&bytes).is_err());
2334 }
2335
2336 #[test]
2337 fn invalidate_file_removes_entries_and_mtime() {
2338 let target = PathBuf::from("/src/main.rs");
2339 let mut index = SemanticIndex::new();
2340 index.entries.push(EmbeddingEntry {
2341 chunk: SemanticChunk {
2342 file: target.clone(),
2343 name: "main".to_string(),
2344 kind: SymbolKind::Function,
2345 start_line: 0,
2346 end_line: 1,
2347 exported: false,
2348 embed_text: "main".to_string(),
2349 snippet: "fn main() {}".to_string(),
2350 },
2351 vector: vec![1.0; DEFAULT_DIMENSION],
2352 });
2353 index
2354 .file_mtimes
2355 .insert(target.clone(), SystemTime::UNIX_EPOCH);
2356 index.file_sizes.insert(target.clone(), 0);
2357
2358 index.invalidate_file(&target);
2359
2360 assert!(index.entries.is_empty());
2361 assert!(!index.file_mtimes.contains_key(&target));
2362 assert!(!index.file_sizes.contains_key(&target));
2363 }
2364
/// When a listed file cannot be read during refresh, its existing entries and
/// its (stale) recorded metadata are kept unchanged and nothing is counted.
#[test]
fn refresh_transient_error_preserves_existing_entry_and_mtime() {
    let tmp = tempfile::tempdir().unwrap();
    let root = tmp.path();
    let lib = root.join("src/lib.rs");
    fs::create_dir_all(lib.parent().unwrap()).unwrap();
    write_rust_file(&lib, "kept_symbol");

    let mut index = build_test_index(root, std::slice::from_ref(&lib));
    let entries_before = index.entries.len();
    let indexed_mtime = index.file_mtimes.get(&lib).copied().unwrap();
    let indexed_size = index.file_sizes.get(&lib).copied().unwrap();

    // Make the recorded metadata look stale, then delete the file so the
    // refresh hits a read error for a path that is still listed.
    let epoch = SystemTime::UNIX_EPOCH;
    let bumped_size = indexed_size + 1;
    set_file_metadata(&mut index, &lib, epoch, bumped_size);
    fs::remove_file(&lib).unwrap();

    let mut embed = test_vector_for_texts;
    let mut progress = |_done: usize, _total: usize| {};
    let summary = index
        .refresh_stale_files(root, std::slice::from_ref(&lib), &mut embed, 8, &mut progress)
        .unwrap();

    assert_eq!((summary.changed, summary.added, summary.deleted), (0, 0, 0));
    assert_eq!(index.entries.len(), entries_before);
    assert_eq!(index.entries[0].chunk.name, "kept_symbol");
    assert_eq!(index.file_mtimes.get(&lib), Some(&epoch));
    assert_ne!(index.file_mtimes.get(&lib), Some(&indexed_mtime));
    assert_eq!(index.file_sizes.get(&lib), Some(&bumped_size));
}
2403
/// A listed file that does not exist and was never indexed produces no counts,
/// no entries, and no recorded mtime/size.
#[test]
fn refresh_never_indexed_file_error_does_not_record_mtime() {
    let tmp = tempfile::tempdir().unwrap();
    let root = tmp.path();
    let ghost = root.join("src/missing.rs");
    // Only the parent directory is created; the file itself is never written.
    fs::create_dir_all(ghost.parent().unwrap()).unwrap();

    let mut index = SemanticIndex::new();
    let mut embed = test_vector_for_texts;
    let mut progress = |_done: usize, _total: usize| {};
    let listed = std::slice::from_ref(&ghost);
    let summary = index
        .refresh_stale_files(root, listed, &mut embed, 8, &mut progress)
        .unwrap();

    assert_eq!((summary.added, summary.changed, summary.deleted), (0, 0, 0));
    assert!(index.file_mtimes.get(&ghost).is_none());
    assert!(index.file_sizes.get(&ghost).is_none());
    assert!(index.entries.is_empty());
}
2431
/// A file present on disk but absent from the initial index is reported as
/// added, and its mtime and entries are recorded.
#[test]
fn refresh_reports_added_for_new_files() {
    let tmp = tempfile::tempdir().unwrap();
    let root = tmp.path();
    let (old_file, new_file) = (root.join("src/lib.rs"), root.join("src/new.rs"));
    fs::create_dir_all(old_file.parent().unwrap()).unwrap();
    for (path, symbol) in [(&old_file, "existing_symbol"), (&new_file, "added_symbol")] {
        write_rust_file(path, symbol);
    }

    // Only the pre-existing file goes into the initial index.
    let mut index = build_test_index(root, std::slice::from_ref(&old_file));
    let mut embed = test_vector_for_texts;
    let mut progress = |_done: usize, _total: usize| {};
    let current = [old_file.clone(), new_file.clone()];
    let summary = index
        .refresh_stale_files(root, &current, &mut embed, 8, &mut progress)
        .unwrap();

    assert_eq!(summary.added, 1);
    assert_eq!((summary.changed, summary.deleted), (0, 0));
    assert_eq!(summary.total_processed, 2);
    assert!(index.file_mtimes.contains_key(&new_file));
    assert!(index.entries.iter().any(|e| e.chunk.file == new_file));
}
2462
/// An indexed file that is removed from disk and no longer listed is reported
/// as deleted, and its mtime and entries are dropped.
#[test]
fn refresh_reports_deleted_for_removed_files() {
    let tmp = tempfile::tempdir().unwrap();
    let root = tmp.path();
    let gone = root.join("src/deleted.rs");
    fs::create_dir_all(gone.parent().unwrap()).unwrap();
    write_rust_file(&gone, "deleted_symbol");

    let mut index = build_test_index(root, std::slice::from_ref(&gone));
    // Delete the file, then refresh with an empty current-file list.
    fs::remove_file(&gone).unwrap();

    let mut embed = test_vector_for_texts;
    let mut progress = |_done: usize, _total: usize| {};
    let summary = index
        .refresh_stale_files(root, &[], &mut embed, 8, &mut progress)
        .unwrap();

    assert_eq!(summary.deleted, 1);
    assert_eq!((summary.changed, summary.added), (0, 0));
    assert_eq!(summary.total_processed, 1);
    assert!(!index.file_mtimes.contains_key(&gone));
    assert!(index.entries.is_empty());
}
2487
/// A file whose recorded metadata is stale and whose contents changed is
/// reported as changed, and its old symbol is replaced by the new one.
#[test]
fn refresh_reports_changed_for_modified_files() {
    let tmp = tempfile::tempdir().unwrap();
    let root = tmp.path();
    let lib = root.join("src/lib.rs");
    fs::create_dir_all(lib.parent().unwrap()).unwrap();
    write_rust_file(&lib, "old_symbol");

    let mut index = build_test_index(root, std::slice::from_ref(&lib));
    // Force the recorded metadata to look stale, then rewrite the contents.
    set_file_metadata(&mut index, &lib, SystemTime::UNIX_EPOCH, 0);
    write_rust_file(&lib, "new_symbol");

    let mut embed = test_vector_for_texts;
    let mut progress = |_done: usize, _total: usize| {};
    let summary = index
        .refresh_stale_files(root, std::slice::from_ref(&lib), &mut embed, 8, &mut progress)
        .unwrap();

    assert_eq!(summary.changed, 1);
    assert_eq!((summary.added, summary.deleted), (0, 0));
    assert_eq!(summary.total_processed, 1);

    let has_symbol = |wanted: &str| index.entries.iter().any(|e| e.chunk.name == wanted);
    assert!(has_symbol("new_symbol"));
    assert!(!has_symbol("old_symbol"));
}
2525
/// Refreshing an up-to-date index is a no-op: the summary reports nothing,
/// the embed callback is never invoked, and the entry count is unchanged.
#[test]
fn refresh_all_clean_reports_zero_counts_and_no_embedding_work() {
    let tmp = tempfile::tempdir().unwrap();
    let root = tmp.path();
    let lib = root.join("src/lib.rs");
    fs::create_dir_all(lib.parent().unwrap()).unwrap();
    write_rust_file(&lib, "clean_symbol");

    let mut index = build_test_index(root, std::slice::from_ref(&lib));
    let entries_before = index.entries.len();

    // Count embed invocations; a clean refresh must not make any.
    let mut embed_calls = 0usize;
    let mut embed = |texts: Vec<String>| {
        embed_calls += 1;
        test_vector_for_texts(texts)
    };
    let mut progress = |_done: usize, _total: usize| {};
    let summary = index
        .refresh_stale_files(root, std::slice::from_ref(&lib), &mut embed, 8, &mut progress)
        .unwrap();

    assert!(summary.is_noop());
    assert_eq!(summary.total_processed, 1);
    assert_eq!(embed_calls, 0);
    assert_eq!(index.entries.len(), entries_before);
}
2557
/// A dlopen failure for the `.dylib` ONNX Runtime library is classified as
/// "runtime unavailable".
#[test]
fn detects_missing_onnx_runtime_from_dynamic_load_error() {
    let dlopen_failure =
        "Failed to load ONNX Runtime shared library libonnxruntime.dylib via dlopen: no such file";
    let unavailable = is_onnx_runtime_unavailable(dlopen_failure);
    assert!(unavailable);
}
2564
/// A missing-runtime error is rendered with the install hint first and still
/// carries the underlying error text.
#[test]
fn formats_missing_onnx_runtime_with_install_hint() {
    let raw = "Failed to load ONNX Runtime shared library libonnxruntime.so via dlopen: no such file";
    let formatted = format_embedding_init_error(raw);

    assert!(formatted.starts_with("ONNX Runtime not found. Install via:"));
    assert!(formatted.contains("Original error:"));
}
2574
/// The OpenAI-compatible backend POSTs to `/v1/embeddings` on the mock server
/// and decodes the `data[].embedding` vectors in order.
#[test]
fn openai_compatible_backend_embeds_with_mock_server() {
    let (base_url, server) = start_mock_http_server(|first_line, path, _body| {
        assert!(first_line.starts_with("POST "));
        assert_eq!(path, "/v1/embeddings");
        String::from(
            "{\"data\":[{\"embedding\":[0.1,0.2,0.3],\"index\":0},{\"embedding\":[0.4,0.5,0.6],\"index\":1}]}",
        )
    });

    let config = SemanticBackendConfig {
        backend: SemanticBackend::OpenAiCompatible,
        model: String::from("test-embedding"),
        base_url: Some(base_url),
        api_key_env: None,
        timeout_ms: 5_000,
        max_batch_size: 64,
    };
    let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();

    let inputs: Vec<String> = ["hello", "world"].iter().map(|s| s.to_string()).collect();
    let vectors = model.embed(inputs).unwrap();

    assert_eq!(vectors, vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]]);
    server.join().unwrap();
}
2600
/// The Ollama backend POSTs to `/api/embed` on the mock server and decodes the
/// `embeddings` array in order.
#[test]
fn ollama_backend_embeds_with_mock_server() {
    let (base_url, server) = start_mock_http_server(|first_line, path, _body| {
        assert!(first_line.starts_with("POST "));
        assert_eq!(path, "/api/embed");
        String::from("{\"embeddings\":[[0.7,0.8,0.9],[1.0,1.1,1.2]]}")
    });

    let config = SemanticBackendConfig {
        backend: SemanticBackend::Ollama,
        model: String::from("embeddinggemma"),
        base_url: Some(base_url),
        api_key_env: None,
        timeout_ms: 5_000,
        max_batch_size: 64,
    };
    let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();

    let inputs: Vec<String> = ["hello", "world"].iter().map(|s| s.to_string()).collect();
    let vectors = model.embed(inputs).unwrap();

    assert_eq!(vectors, vec![vec![0.7, 0.8, 0.9], vec![1.0, 1.1, 1.2]]);
    server.join().unwrap();
}
2626
/// An index written to disk loads back with its own fingerprint and is refused
/// when a different backend fingerprint is requested.
#[test]
fn read_from_disk_rejects_fingerprint_mismatch() {
    let storage = tempfile::tempdir().unwrap();
    let key = "proj";
    let main_rs = PathBuf::from("/src/main.rs");

    let entry = EmbeddingEntry {
        chunk: SemanticChunk {
            file: main_rs.clone(),
            name: "handle_request".to_string(),
            kind: SymbolKind::Function,
            start_line: 10,
            end_line: 25,
            exported: true,
            embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
            snippet: "fn handle_request() {}".to_string(),
        },
        vector: vec![0.1, 0.2, 0.3],
    };

    let mut index = SemanticIndex::new();
    index.entries.push(entry);
    index.dimension = 3;
    index.file_mtimes.insert(main_rs.clone(), SystemTime::UNIX_EPOCH);
    index.file_sizes.insert(main_rs, 0);
    index.set_fingerprint(SemanticIndexFingerprint {
        backend: "openai_compatible".to_string(),
        model: "test-embedding".to_string(),
        base_url: "http://127.0.0.1:1234/v1".to_string(),
        dimension: 3,
    });
    index.write_to_disk(storage.path(), key);

    // Reading back with the fingerprint the index was written with succeeds.
    let written = index.fingerprint().unwrap().as_string();
    let reloaded = SemanticIndex::read_from_disk(storage.path(), key, Some(&written)).is_some();
    assert!(reloaded);

    // A fingerprint for a different backend/model/base_url is refused.
    let other = SemanticIndexFingerprint {
        backend: "ollama".to_string(),
        model: "embeddinggemma".to_string(),
        base_url: "http://127.0.0.1:11434".to_string(),
        dimension: 3,
    }
    .as_string();
    let refused = SemanticIndex::read_from_disk(storage.path(), key, Some(&other)).is_none();
    assert!(refused);
}
2675
/// A cache file stamped with version 3 is refused by `read_from_disk` even
/// though its fingerprint matches, and the file is removed from disk.
#[test]
fn read_from_disk_rejects_v3_cache_for_snippet_rebuild() {
    let storage = tempfile::tempdir().unwrap();
    let key = "proj-v3";
    let cache_dir = storage.path().join("semantic").join(key);
    fs::create_dir_all(&cache_dir).unwrap();
    let cache_file = cache_dir.join("semantic.bin");

    let main_rs = PathBuf::from("/src/main.rs");
    let mut index = SemanticIndex::new();
    index.entries.push(EmbeddingEntry {
        chunk: SemanticChunk {
            file: main_rs.clone(),
            name: "handle_request".to_string(),
            kind: SymbolKind::Function,
            start_line: 0,
            end_line: 0,
            exported: true,
            embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
            snippet: "fn handle_request() {}".to_string(),
        },
        vector: vec![0.1, 0.2, 0.3],
    });
    index.dimension = 3;
    index.file_mtimes.insert(main_rs.clone(), SystemTime::UNIX_EPOCH);
    index.file_sizes.insert(main_rs, 0);

    let fingerprint = SemanticIndexFingerprint {
        backend: "fastembed".to_string(),
        model: "test".to_string(),
        base_url: FALLBACK_BACKEND.to_string(),
        dimension: 3,
    };
    index.set_fingerprint(fingerprint.clone());

    // Serialize the index, then overwrite the leading version byte with v3.
    let mut stale_bytes = index.to_bytes();
    stale_bytes[0] = SEMANTIC_INDEX_VERSION_V3;
    fs::write(&cache_file, stale_bytes).unwrap();

    let expected_fp = fingerprint.as_string();
    let loaded = SemanticIndex::read_from_disk(storage.path(), key, Some(&expected_fp)).is_some();
    assert!(!loaded);
    // Rejecting the v3 cache also deletes the file.
    assert!(!cache_file.exists());
}
2722
2723 fn make_symbol(kind: SymbolKind, name: &str, start: u32, end: u32) -> crate::symbols::Symbol {
2724 crate::symbols::Symbol {
2725 name: name.to_string(),
2726 kind,
2727 range: crate::symbols::Range {
2728 start_line: start,
2729 start_col: 0,
2730 end_line: end,
2731 end_col: 0,
2732 },
2733 signature: None,
2734 scope_chain: Vec::new(),
2735 exported: false,
2736 parent: None,
2737 }
2738 }
2739
/// A file containing only heading symbols yields no chunks at all.
#[test]
fn symbols_to_chunks_skips_heading_symbols() {
    let root = PathBuf::from("/proj");
    let readme = root.join("README.md");
    let markdown = "# Title\n\nbody text\n\n## Section\n\nmore text\n";
    let headings = vec![
        make_symbol(SymbolKind::Heading, "Title", 0, 2),
        make_symbol(SymbolKind::Heading, "Section", 4, 6),
    ];

    let chunks = symbols_to_chunks(&readme, &headings, markdown, &root);
    let chunk_count = chunks.len();
    assert!(
        chunks.is_empty(),
        "Heading symbols must be filtered out before embedding; got {} chunk(s)",
        chunk_count
    );
}
2762
/// Code symbols survive `symbols_to_chunks` even when a heading symbol is mixed
/// in; only the heading is dropped.
///
/// The fixture text contains both the function (0-based lines 0-2) and the
/// struct (lines 4-6). Every code symbol's range therefore points at real text,
/// instead of the struct referencing lines past the end of the source.
#[test]
fn symbols_to_chunks_keeps_code_symbols_alongside_skipped_headings() {
    let project_root = PathBuf::from("/proj");
    let file = project_root.join("src/lib.rs");
    // Line 3 is blank; `AuthService` occupies lines 4-6, matching its symbol below.
    let source = concat!(
        "pub fn handle_request() -> bool {\n true\n}\n",
        "\n",
        "pub struct AuthService {\n token: String,\n}\n",
    );

    let symbols = vec![
        make_symbol(SymbolKind::Heading, "doc heading", 0, 1),
        make_symbol(SymbolKind::Function, "handle_request", 0, 2),
        make_symbol(SymbolKind::Struct, "AuthService", 4, 6),
    ];

    let chunks = symbols_to_chunks(&file, &symbols, source, &project_root);
    assert_eq!(
        chunks.len(),
        2,
        "Expected 2 code chunks (Function + Struct), got {}",
        chunks.len()
    );
    let names: Vec<&str> = chunks.iter().map(|c| c.name.as_str()).collect();
    assert!(names.contains(&"handle_request"));
    assert!(names.contains(&"AuthService"));
    assert!(
        !names.contains(&"doc heading"),
        "Heading symbol leaked into chunks: {names:?}"
    );
}
2794
/// Loopback hostnames pass the SSRF guard: bare `localhost`, `localhost` with
/// ports, `localhost.localdomain`, and a `*.localhost` subdomain.
///
/// Each URL is validated exactly once and the result is reused in the failure
/// message, matching the sibling SSRF tests.
#[test]
fn validate_ssrf_allows_loopback_hostnames() {
    for host in &[
        "http://localhost",
        "http://localhost:8080",
        "http://localhost:11434",
        "http://localhost.localdomain",
        "http://foo.localhost",
    ] {
        let result = validate_base_url_no_ssrf(host);
        assert!(
            result.is_ok(),
            "Expected {host} to be allowed (loopback), got: {:?}",
            result
        );
    }
}
2813
/// Addresses in 127.0.0.0/8, with or without a port, pass the SSRF guard.
#[test]
fn validate_ssrf_allows_loopback_ips() {
    let loopback_urls = [
        "http://127.0.0.1",
        "http://127.0.0.1:11434",
        "http://127.0.0.1:8080",
        "http://127.1.2.3",
    ];
    for url in loopback_urls.iter() {
        let verdict = validate_base_url_no_ssrf(url);
        assert!(
            verdict.is_ok(),
            "Expected {url} to be allowed (loopback), got: {:?}",
            verdict
        );
    }
}
2832
/// Private, link-local, and shared-address-space IPs that are not loopback are
/// rejected by the SSRF guard.
#[test]
fn validate_ssrf_rejects_private_non_loopback_ips() {
    let private_targets = [
        "http://192.168.1.1",
        "http://10.0.0.1",
        "http://172.16.0.1",
        "http://169.254.169.254",
        "http://100.64.0.1",
    ];
    for url in private_targets.iter() {
        let verdict = validate_base_url_no_ssrf(url);
        assert!(
            verdict.is_err(),
            "Expected {url} to be rejected (non-loopback private), got: {:?}",
            verdict
        );
    }
}
2854
/// Hostnames under the `.local` suffix are rejected by the SSRF guard.
#[test]
fn validate_ssrf_rejects_mdns_local_hostnames() {
    let mdns_hosts = [
        "http://printer.local",
        "http://nas.local:8080",
        "http://homelab.local",
    ];
    for host in mdns_hosts.iter() {
        let verdict = validate_base_url_no_ssrf(host);
        assert!(
            verdict.is_err(),
            "Expected {host} to be rejected (mDNS), got: {:?}",
            verdict
        );
    }
}
2872
/// `normalize_base_url` accepts both a loopback IP and the `localhost` name.
#[test]
fn normalize_base_url_allows_localhost_for_tests() {
    for url in ["http://127.0.0.1:9999", "http://localhost:8080"] {
        assert!(normalize_base_url(url).is_ok());
    }
}
2880
/// The ONNX Runtime version-mismatch message reports the detected version, the
/// system library path, and the requirement. It lists the auto-fix before the
/// manual removal step and includes the exact auto-fix command.
#[test]
fn ort_mismatch_message_recommends_auto_fix_first() {
    let system_lib = "/usr/lib/x86_64-linux-gnu/libonnxruntime.so";
    let msg = format_ort_version_mismatch("1.9.0", system_lib);

    assert!(msg.contains("v1.9.0"), "should report detected version: {msg}");
    assert!(msg.contains(system_lib), "should report system path: {msg}");
    assert!(msg.contains("v1.20+"), "should state requirement: {msg}");

    // Locate each suggested solution; a missing one fails with its own message.
    let position_of = |needle: &str, missing: &str| msg.find(needle).expect(missing);
    let auto_fix_pos = position_of(
        "Auto-fix",
        "Auto-fix solution missing — users won't discover --fix",
    );
    let remove_pos = position_of("Remove the old library", "system-rm solution missing");
    assert!(
        auto_fix_pos < remove_pos,
        "Auto-fix must come before manual rm — see PR comment thread"
    );

    assert!(
        msg.contains("npx @cortexkit/aft doctor --fix"),
        "auto-fix command must be present and copy-pasteable: {msg}"
    );
}
2921
/// With a macOS `.dylib` path, the mismatch message reports the version and
/// path, and also contains the path wrapped in single quotes.
#[test]
fn ort_mismatch_message_handles_macos_dylib_path() {
    let dylib = "/opt/homebrew/lib/libonnxruntime.dylib";
    let msg = format_ort_version_mismatch("1.9.0", dylib);

    assert!(msg.contains("v1.9.0"));
    assert!(msg.contains(dylib));

    let quoted = format!("'{dylib}'");
    assert!(
        msg.contains(quoted.as_str()),
        "system path should be quoted in the auto-fix sentence: {msg}"
    );
}
2938}