1use crate::config::{SemanticBackend, SemanticBackendConfig};
2use crate::parser::{detect_language, extract_symbols_from_tree, grammar_for};
3use crate::symbols::{Symbol, SymbolKind};
4use crate::{slog_info, slog_warn};
5
6use fastembed::{EmbeddingModel as FastembedEmbeddingModel, InitOptions, TextEmbedding};
7use rayon::prelude::*;
8use reqwest::blocking::Client;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet};
11use std::env;
12use std::fmt::Display;
13use std::fs;
14use std::path::{Path, PathBuf};
15use std::time::Duration;
16use std::time::SystemTime;
17use tree_sitter::Parser;
18use url::Url;
19
/// Vector dimension recorded for an index that does not hold any embeddings yet.
const DEFAULT_DIMENSION: usize = 384;
// NOTE(review): MAX_ENTRIES / MAX_DIMENSION look like sanity limits for persisted
// indexes; they are not referenced in the visible part of this file — confirm.
const MAX_ENTRIES: usize = 1_000_000;
const MAX_DIMENSION: usize = 1024;
/// Size in bytes of one `f32` vector component.
const F32_BYTES: usize = std::mem::size_of::<f32>();
// NOTE(review): presumably the serialized header sizes for the V1 and V2 on-disk
// formats — the reader/writer is outside this view; confirm.
const HEADER_BYTES_V1: usize = 9;
const HEADER_BYTES_V2: usize = 13;
/// Hint prepended to initialization errors that are classified as "ONNX Runtime unavailable".
const ONNX_RUNTIME_INSTALL_HINT: &str =
    "ONNX Runtime not found. Install via: brew install onnxruntime (macOS) or apt install libonnxruntime (Linux).";

// Version tags of the semantic index format.
// NOTE(review): what changed between versions is not visible from here — confirm
// against the (de)serialization code.
const SEMANTIC_INDEX_VERSION_V1: u8 = 1;
const SEMANTIC_INDEX_VERSION_V2: u8 = 2;
const SEMANTIC_INDEX_VERSION_V3: u8 = 3;
const SEMANTIC_INDEX_VERSION_V4: u8 = 4;
const SEMANTIC_INDEX_VERSION_V5: u8 = 5;
/// Path appended (after a `/v1` segment) to an OpenAI-compatible base URL.
const DEFAULT_OPENAI_EMBEDDING_PATH: &str = "/embeddings";
/// Path appended to an Ollama base URL that does not already end in `/api`.
const DEFAULT_OLLAMA_EMBEDDING_PATH: &str = "/api/embed";
/// HTTP client timeout used when the configuration specifies `timeout_ms == 0`.
const DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS: u64 = 25_000;
/// Batch size used when the configuration specifies `max_batch_size == 0`.
const DEFAULT_MAX_BATCH_SIZE: usize = 64;
/// Placeholder stored as the fingerprint's `base_url` when the configuration has
/// no base URL or the URL fails normalization.
const FALLBACK_BACKEND: &str = "none";
/// Total number of attempts (initial try included) for one embedding HTTP request.
const EMBEDDING_REQUEST_MAX_ATTEMPTS: usize = 3;
/// Delay before each retry, indexed by the zero-based attempt that just failed.
const EMBEDDING_REQUEST_BACKOFF_MS: [u64; 2] = [500, 1_000];
50
/// Describes the embedding setup an index was produced with: backend, model,
/// endpoint and vector dimension. Serialized to JSON for comparison.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticIndexFingerprint {
    /// Backend identifier, taken from `SemanticBackend::as_str`.
    pub backend: String,
    /// Model name copied from the backend configuration.
    pub model: String,
    /// Normalized base URL, or the `"none"` placeholder when the configuration
    /// has no usable URL. Missing in serialized data => empty string.
    #[serde(default)]
    pub base_url: String,
    /// Length of the vectors the model produces.
    pub dimension: usize,
}
59
60impl SemanticIndexFingerprint {
61 fn from_config(config: &SemanticBackendConfig, dimension: usize) -> Self {
62 let base_url = config
65 .base_url
66 .as_ref()
67 .and_then(|u| normalize_base_url(u).ok())
68 .unwrap_or_else(|| FALLBACK_BACKEND.to_string());
69 Self {
70 backend: config.backend.as_str().to_string(),
71 model: config.model.clone(),
72 base_url,
73 dimension,
74 }
75 }
76
77 pub fn as_string(&self) -> String {
78 serde_json::to_string(self).unwrap_or_else(|_| String::new())
79 }
80
81 fn matches_expected(&self, expected: &str) -> bool {
82 let encoded = self.as_string();
83 !encoded.is_empty() && encoded == expected
84 }
85}
86
/// Backend-specific state needed to turn text into embedding vectors.
enum SemanticEmbeddingEngine {
    /// In-process model driven through the `fastembed` crate.
    Fastembed(TextEmbedding),
    /// Remote server reached through the OpenAI-style `/v1/embeddings` endpoint.
    OpenAiCompatible {
        /// Blocking HTTP client configured with the request timeout.
        client: Client,
        /// Model name sent in the request body.
        model: String,
        /// Normalized base URL (no trailing slash).
        base_url: String,
        /// Bearer token sent in the `Authorization` header, when configured.
        api_key: Option<String>,
    },
    /// Remote Ollama server reached through its `/api/embed` endpoint.
    Ollama {
        /// Blocking HTTP client configured with the request timeout.
        client: Client,
        /// Model name sent in the request body.
        model: String,
        /// Normalized base URL (no trailing slash).
        base_url: String,
    },
}
101
/// An embedding model bound to one configured backend.
pub struct SemanticEmbeddingModel {
    /// Backend kind copied from the configuration.
    backend: SemanticBackend,
    /// Model name copied from the configuration.
    model: String,
    /// Base URL exactly as configured (not normalized); `None` when unset.
    base_url: Option<String>,
    /// Effective HTTP timeout in milliseconds (default applied when configured as 0).
    timeout_ms: u64,
    /// Effective maximum batch size (default applied when configured as 0).
    max_batch_size: usize,
    /// Cached vector length; `None` until an embedding call has revealed it.
    dimension: Option<usize>,
    /// Backend-specific engine that performs the actual embedding.
    engine: SemanticEmbeddingEngine,
}

/// Shorter alias for [`SemanticEmbeddingModel`].
pub type EmbeddingModel = SemanticEmbeddingModel;
113
/// Checks that a backend returned exactly one vector per input and that all
/// vectors in the batch share the same length.
///
/// `context` is a human-readable label that prefixes every error message.
///
/// # Errors
/// Returns a descriptive message when the batch is unexpectedly empty, has the
/// wrong number of vectors, or mixes vector lengths.
fn validate_embedding_batch(
    vectors: &[Vec<f32>],
    expected_count: usize,
    context: &str,
) -> Result<(), String> {
    let received = vectors.len();

    // An entirely empty reply gets its own, more specific message.
    if received == 0 && expected_count != 0 {
        return Err(format!(
            "{context} returned no vectors for {expected_count} inputs"
        ));
    }

    if received != expected_count {
        return Err(format!(
            "{context} returned {} vectors for {} inputs",
            received, expected_count
        ));
    }

    // Vector 0 defines the reference length for the rest of the batch.
    let Some((head, tail)) = vectors.split_first() else {
        return Ok(());
    };
    let expected_dimension = head.len();

    match tail
        .iter()
        .position(|vector| vector.len() != expected_dimension)
    {
        None => Ok(()),
        Some(offset) => {
            // `tail` starts at vector 1, so shift back to the batch index.
            let index = offset + 1;
            Err(format!(
                "{context} returned inconsistent embedding dimensions: vector 0 has length {expected_dimension}, vector {index} has length {}",
                vectors[index].len()
            ))
        }
    }
}
148
149fn normalize_base_url(raw: &str) -> Result<String, String> {
153 let parsed = Url::parse(raw).map_err(|error| format!("invalid base_url '{raw}': {error}"))?;
154 let scheme = parsed.scheme();
155 if scheme != "http" && scheme != "https" {
156 return Err(format!(
157 "unsupported URL scheme '{}' — only http:// and https:// are allowed",
158 scheme
159 ));
160 }
161 Ok(parsed.to_string().trim_end_matches('/').to_string())
162}
163
164pub fn validate_base_url_no_ssrf(raw: &str) -> Result<(), String> {
179 use std::net::{IpAddr, ToSocketAddrs};
180
181 let parsed = Url::parse(raw).map_err(|error| format!("invalid base_url '{raw}': {error}"))?;
182
183 let host = parsed.host_str().unwrap_or("");
184
185 let is_loopback_host =
190 host == "localhost" || host == "localhost.localdomain" || host.ends_with(".localhost");
191 if is_loopback_host {
192 return Ok(());
193 }
194
195 if host.ends_with(".local") {
198 return Err(format!(
199 "base_url host '{host}' is an mDNS name — only loopback (localhost / 127.0.0.1) and public endpoints are allowed"
200 ));
201 }
202
203 let port = parsed.port_or_known_default().unwrap_or(443);
206 let addr_str = format!("{host}:{port}");
207 let addrs: Vec<IpAddr> = addr_str
208 .to_socket_addrs()
209 .map(|iter| iter.map(|sa| sa.ip()).collect())
210 .unwrap_or_default();
211 for ip in &addrs {
212 if is_private_non_loopback_ip(ip) {
213 return Err(format!(
214 "base_url '{raw}' resolves to a private/reserved IP — only loopback (127.0.0.1) and public endpoints are allowed"
215 ));
216 }
217 }
218
219 Ok(())
220}
221
/// Returns `true` for addresses in private or reserved ranges, excluding loopback.
///
/// IPv4 ranges matched: `10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`,
/// `169.254.0.0/16` (link-local), `100.64.0.0/10` (carrier-grade NAT) and
/// `0.0.0.0/8`.
///
/// IPv6 ranges matched: `fe80::/10` (link-local) and `fc00::/7` (unique local).
/// IPv4-mapped addresses (`::ffff:a.b.c.d`) are classified by their embedded
/// IPv4 address.
fn is_private_non_loopback_ip(ip: &std::net::IpAddr) -> bool {
    use std::net::IpAddr;

    match ip {
        IpAddr::V4(v4) => {
            let [first, second, ..] = v4.octets();
            first == 10
                || (first == 172 && (16..=31).contains(&second))
                || (first == 192 && second == 168)
                || (first == 169 && second == 254)
                || (first == 100 && (64..=127).contains(&second))
                || first == 0
        }
        IpAddr::V6(v6) => {
            // A mapped address starts with zero segments, so it can never match
            // the prefix tests below; defer entirely to the IPv4 rules.
            if let Some(mapped) = v6.to_ipv4_mapped() {
                return is_private_non_loopback_ip(&IpAddr::V4(mapped));
            }

            let leading = v6.segments()[0];
            (leading & 0xffc0) == 0xfe80 || (leading & 0xfe00) == 0xfc00
        }
    }
}
263
264fn build_openai_embeddings_endpoint(base_url: &str) -> String {
265 if base_url.ends_with("/v1") {
266 format!("{base_url}{DEFAULT_OPENAI_EMBEDDING_PATH}")
267 } else {
268 format!("{base_url}/v1{}", DEFAULT_OPENAI_EMBEDDING_PATH)
269 }
270}
271
272fn build_ollama_embeddings_endpoint(base_url: &str) -> String {
273 if base_url.ends_with("/api") {
274 format!("{base_url}/embed")
275 } else {
276 format!("{base_url}{DEFAULT_OLLAMA_EMBEDDING_PATH}")
277 }
278}
279
/// Trims surrounding whitespace from an optional token and maps a missing or
/// blank token to `None`.
fn normalize_api_key(value: Option<String>) -> Option<String> {
    let raw = value?;
    match raw.trim() {
        "" => None,
        trimmed => Some(trimmed.to_string()),
    }
}
290
291fn is_retryable_embedding_status(status: reqwest::StatusCode) -> bool {
292 status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS
293}
294
/// Returns `true` when a transport error is worth retrying.
///
/// Only errors that `reqwest` classifies as connection failures
/// (`Error::is_connect`) qualify; every other error kind is reported to the
/// caller without a retry.
fn is_retryable_embedding_error(error: &reqwest::Error) -> bool {
    error.is_connect()
}
298
299fn sleep_before_embedding_retry(attempt_index: usize) {
300 if let Some(delay_ms) = EMBEDDING_REQUEST_BACKOFF_MS.get(attempt_index) {
301 std::thread::sleep(Duration::from_millis(*delay_ms));
302 }
303}
304
305fn send_embedding_request<F>(mut make_request: F, backend_label: &str) -> Result<String, String>
306where
307 F: FnMut() -> reqwest::blocking::RequestBuilder,
308{
309 for attempt_index in 0..EMBEDDING_REQUEST_MAX_ATTEMPTS {
310 let last_attempt = attempt_index + 1 == EMBEDDING_REQUEST_MAX_ATTEMPTS;
311
312 let response = match make_request().send() {
313 Ok(response) => response,
314 Err(error) => {
315 if !last_attempt && is_retryable_embedding_error(&error) {
316 sleep_before_embedding_retry(attempt_index);
317 continue;
318 }
319 return Err(format!("{backend_label} request failed: {error}"));
320 }
321 };
322
323 let status = response.status();
324 let raw = match response.text() {
325 Ok(raw) => raw,
326 Err(error) => {
327 if !last_attempt && is_retryable_embedding_error(&error) {
328 sleep_before_embedding_retry(attempt_index);
329 continue;
330 }
331 return Err(format!("{backend_label} response read failed: {error}"));
332 }
333 };
334
335 if status.is_success() {
336 return Ok(raw);
337 }
338
339 if !last_attempt && is_retryable_embedding_status(status) {
340 sleep_before_embedding_retry(attempt_index);
341 continue;
342 }
343
344 return Err(format!(
345 "{backend_label} request failed (HTTP {}): {}",
346 status, raw
347 ));
348 }
349
350 unreachable!("embedding request retries exhausted without returning")
351}
352
353impl SemanticEmbeddingModel {
354 pub fn from_config(config: &SemanticBackendConfig) -> Result<Self, String> {
355 let timeout_ms = if config.timeout_ms == 0 {
356 DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS
357 } else {
358 config.timeout_ms
359 };
360
361 let max_batch_size = if config.max_batch_size == 0 {
362 DEFAULT_MAX_BATCH_SIZE
363 } else {
364 config.max_batch_size
365 };
366
367 let api_key_env = normalize_api_key(config.api_key_env.clone());
368 let model = config.model.clone();
369
370 let client = Client::builder()
371 .timeout(Duration::from_millis(timeout_ms))
372 .redirect(reqwest::redirect::Policy::none())
373 .build()
374 .map_err(|error| format!("failed to configure embedding client: {error}"))?;
375
376 let engine = match config.backend {
377 SemanticBackend::Fastembed => {
378 SemanticEmbeddingEngine::Fastembed(initialize_text_embedding(&model)?)
379 }
380 SemanticBackend::OpenAiCompatible => {
381 let raw = config.base_url.as_ref().ok_or_else(|| {
382 "base_url is required for openai_compatible backend".to_string()
383 })?;
384 let base_url = normalize_base_url(raw)?;
385
386 let api_key = match api_key_env {
387 Some(var_name) => Some(env::var(&var_name).map_err(|_| {
388 format!("missing api_key_env '{var_name}' for openai_compatible backend")
389 })?),
390 None => None,
391 };
392
393 SemanticEmbeddingEngine::OpenAiCompatible {
394 client,
395 model,
396 base_url,
397 api_key,
398 }
399 }
400 SemanticBackend::Ollama => {
401 let raw = config
402 .base_url
403 .as_ref()
404 .ok_or_else(|| "base_url is required for ollama backend".to_string())?;
405 let base_url = normalize_base_url(raw)?;
406
407 SemanticEmbeddingEngine::Ollama {
408 client,
409 model,
410 base_url,
411 }
412 }
413 };
414
415 Ok(Self {
416 backend: config.backend,
417 model: config.model.clone(),
418 base_url: config.base_url.clone(),
419 timeout_ms,
420 max_batch_size,
421 dimension: None,
422 engine,
423 })
424 }
425
426 pub fn backend(&self) -> SemanticBackend {
427 self.backend
428 }
429
430 pub fn model(&self) -> &str {
431 &self.model
432 }
433
434 pub fn base_url(&self) -> Option<&str> {
435 self.base_url.as_deref()
436 }
437
438 pub fn max_batch_size(&self) -> usize {
439 self.max_batch_size
440 }
441
442 pub fn timeout_ms(&self) -> u64 {
443 self.timeout_ms
444 }
445
446 pub fn fingerprint(
447 &mut self,
448 config: &SemanticBackendConfig,
449 ) -> Result<SemanticIndexFingerprint, String> {
450 let dimension = self.dimension()?;
451 Ok(SemanticIndexFingerprint::from_config(config, dimension))
452 }
453
454 pub fn dimension(&mut self) -> Result<usize, String> {
455 if let Some(dimension) = self.dimension {
456 return Ok(dimension);
457 }
458
459 let dimension = match &mut self.engine {
460 SemanticEmbeddingEngine::Fastembed(model) => {
461 let vectors = model
462 .embed(vec!["semantic index fingerprint probe".to_string()], None)
463 .map_err(|error| format_embedding_init_error(error.to_string()))?;
464 vectors
465 .first()
466 .map(|v| v.len())
467 .ok_or_else(|| "embedding backend returned no vectors".to_string())?
468 }
469 SemanticEmbeddingEngine::OpenAiCompatible { .. } => {
470 let vectors =
471 self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
472 vectors
473 .first()
474 .map(|v| v.len())
475 .ok_or_else(|| "embedding backend returned no vectors".to_string())?
476 }
477 SemanticEmbeddingEngine::Ollama { .. } => {
478 let vectors =
479 self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
480 vectors
481 .first()
482 .map(|v| v.len())
483 .ok_or_else(|| "embedding backend returned no vectors".to_string())?
484 }
485 };
486
487 self.dimension = Some(dimension);
488 Ok(dimension)
489 }
490
491 pub fn embed(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
492 self.embed_texts(texts)
493 }
494
495 fn embed_texts(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
496 match &mut self.engine {
497 SemanticEmbeddingEngine::Fastembed(model) => model
498 .embed(texts, None::<usize>)
499 .map_err(|error| format_embedding_init_error(error.to_string()))
500 .map_err(|error| format!("failed to embed batch: {error}")),
501 SemanticEmbeddingEngine::OpenAiCompatible {
502 client,
503 model,
504 base_url,
505 api_key,
506 } => {
507 let expected_text_count = texts.len();
508 let endpoint = build_openai_embeddings_endpoint(base_url);
509 let body = serde_json::json!({
510 "input": texts,
511 "model": model,
512 });
513
514 let raw = send_embedding_request(
515 || {
516 let mut request = client.post(&endpoint).json(&body);
526
527 if let Some(api_key) = api_key {
528 request = request.header("Authorization", format!("Bearer {api_key}"));
529 }
530
531 request
532 },
533 "openai compatible",
534 )?;
535
536 #[derive(Deserialize)]
537 struct OpenAiResponse {
538 data: Vec<OpenAiEmbeddingResult>,
539 }
540
541 #[derive(Deserialize)]
542 struct OpenAiEmbeddingResult {
543 embedding: Vec<f32>,
544 index: Option<u32>,
545 }
546
547 let parsed: OpenAiResponse = serde_json::from_str(&raw)
548 .map_err(|error| format!("invalid openai compatible response: {error}"))?;
549 if parsed.data.len() != expected_text_count {
550 return Err(format!(
551 "openai compatible response returned {} embeddings for {} inputs",
552 parsed.data.len(),
553 expected_text_count
554 ));
555 }
556
557 let mut vectors = vec![Vec::new(); parsed.data.len()];
558 for (i, item) in parsed.data.into_iter().enumerate() {
559 let index = item.index.unwrap_or(i as u32) as usize;
560 if index >= vectors.len() {
561 return Err(
562 "openai compatible response contains invalid vector index".to_string()
563 );
564 }
565 vectors[index] = item.embedding;
566 }
567
568 for vector in &vectors {
569 if vector.is_empty() {
570 return Err(
571 "openai compatible response contained missing vectors".to_string()
572 );
573 }
574 }
575
576 self.dimension = vectors.first().map(Vec::len);
577 Ok(vectors)
578 }
579 SemanticEmbeddingEngine::Ollama {
580 client,
581 model,
582 base_url,
583 } => {
584 let expected_text_count = texts.len();
585 let endpoint = build_ollama_embeddings_endpoint(base_url);
586
587 #[derive(Serialize)]
588 struct OllamaPayload<'a> {
589 model: &'a str,
590 input: Vec<String>,
591 }
592
593 let payload = OllamaPayload {
594 model,
595 input: texts,
596 };
597
598 let raw = send_embedding_request(
599 || {
600 client.post(&endpoint).json(&payload)
605 },
606 "ollama",
607 )?;
608
609 #[derive(Deserialize)]
610 struct OllamaResponse {
611 embeddings: Vec<Vec<f32>>,
612 }
613
614 let parsed: OllamaResponse = serde_json::from_str(&raw)
615 .map_err(|error| format!("invalid ollama response: {error}"))?;
616 if parsed.embeddings.is_empty() {
617 return Err("ollama response returned no embeddings".to_string());
618 }
619 if parsed.embeddings.len() != expected_text_count {
620 return Err(format!(
621 "ollama response returned {} embeddings for {} inputs",
622 parsed.embeddings.len(),
623 expected_text_count
624 ));
625 }
626
627 let vectors = parsed.embeddings;
628 for vector in &vectors {
629 if vector.is_empty() {
630 return Err("ollama response contained empty embeddings".to_string());
631 }
632 }
633
634 self.dimension = vectors.first().map(Vec::len);
635 Ok(vectors)
636 }
637 }
638 }
639}
640
641pub fn pre_validate_onnx_runtime() -> Result<(), String> {
645 let dylib_path = std::env::var("ORT_DYLIB_PATH").ok();
646
647 #[cfg(any(target_os = "linux", target_os = "macos"))]
648 {
649 #[cfg(target_os = "linux")]
650 let default_name = "libonnxruntime.so";
651 #[cfg(target_os = "macos")]
652 let default_name = "libonnxruntime.dylib";
653
654 let lib_name = dylib_path.as_deref().unwrap_or(default_name);
655
656 unsafe {
657 let c_name = std::ffi::CString::new(lib_name)
658 .map_err(|e| format!("invalid library path: {}", e))?;
659 let handle = libc::dlopen(c_name.as_ptr(), libc::RTLD_NOW);
660 if handle.is_null() {
661 let err = libc::dlerror();
662 let msg = if err.is_null() {
663 "unknown dlopen error".to_string()
664 } else {
665 std::ffi::CStr::from_ptr(err).to_string_lossy().into_owned()
666 };
667 return Err(format!(
668 "ONNX Runtime not found. dlopen('{}') failed: {}. \
669 Run `npx @cortexkit/aft doctor` to diagnose.",
670 lib_name, msg
671 ));
672 }
673
674 let detected_version = detect_ort_version_from_path(lib_name);
677
678 libc::dlclose(handle);
679
680 if let Some(ref version) = detected_version {
682 let parts: Vec<&str> = version.split('.').collect();
683 if let (Some(major), Some(minor)) = (
684 parts.first().and_then(|s| s.parse::<u32>().ok()),
685 parts.get(1).and_then(|s| s.parse::<u32>().ok()),
686 ) {
687 if major != 1 || minor < 20 {
688 return Err(format_ort_version_mismatch(version, lib_name));
689 }
690 }
691 }
692 }
693 }
694
695 #[cfg(target_os = "windows")]
696 {
697 let _ = dylib_path;
699 }
700
701 Ok(())
702}
703
704#[cfg(any(test, target_os = "linux", target_os = "macos"))]
707fn detect_ort_version_from_path(lib_path: &str) -> Option<String> {
708 let path = std::path::Path::new(lib_path);
709
710 for candidate in [Some(path.to_path_buf()), std::fs::canonicalize(path).ok()]
712 .into_iter()
713 .flatten()
714 {
715 if let Some(name) = candidate.file_name().and_then(|n| n.to_str()) {
716 if let Some(version) = extract_version_from_filename(name) {
717 return Some(version);
718 }
719 }
720 }
721
722 if let Some(parent) = path.parent() {
724 if let Ok(entries) = std::fs::read_dir(parent) {
725 for entry in entries.flatten() {
726 if let Some(name) = entry.file_name().to_str() {
727 if name.starts_with("libonnxruntime") {
728 if let Some(version) = extract_version_from_filename(name) {
729 return Some(version);
730 }
731 }
732 }
733 }
734 }
735 }
736
737 None
738}
739
740#[cfg(any(test, target_os = "linux", target_os = "macos"))]
742fn extract_version_from_filename(name: &str) -> Option<String> {
743 let re = regex::Regex::new(r"(\d+\.\d+\.\d+)").ok()?;
745 re.find(name).map(|m| m.as_str().to_string())
746}
747
/// Suggests a shell command for removing an outdated ONNX Runtime library.
///
/// Libraries under `/usr/local/lib`, or referenced only by their default bare
/// name, get a platform-specific command; any other path gets a plain `rm`.
#[cfg(any(test, target_os = "linux", target_os = "macos"))]
fn suggest_removal_command(lib_path: &str) -> String {
    let is_default_install = lib_path.starts_with("/usr/local/lib")
        || matches!(lib_path, "libonnxruntime.so" | "libonnxruntime.dylib");

    if is_default_install {
        #[cfg(target_os = "linux")]
        return String::from(" sudo rm /usr/local/lib/libonnxruntime* && sudo ldconfig");
        #[cfg(target_os = "macos")]
        return String::from(" sudo rm /usr/local/lib/libonnxruntime*");
        #[cfg(target_os = "windows")]
        return String::from(" Delete the ONNX Runtime DLL from your PATH");
    }

    format!(" rm '{}'", lib_path)
}
763
764#[cfg(any(test, target_os = "linux", target_os = "macos"))]
770pub(crate) fn format_ort_version_mismatch(version: &str, lib_name: &str) -> String {
771 format!(
772 "ONNX Runtime version mismatch: found v{} at '{}', but AFT requires v1.20+. \
773 Solutions:\n\
774 1. Auto-fix (recommended): run `npx @cortexkit/aft doctor --fix`. \
775 This downloads AFT-managed ONNX Runtime v1.24 into AFT's storage and \
776 configures the bridge to load it instead of the system library — no \
777 changes to '{}'.\n\
778 2. Remove the old library and restart (AFT auto-downloads the correct version on next start):\n\
779 {}\n\
780 3. Or install ONNX Runtime 1.24 system-wide: https://github.com/microsoft/onnxruntime/releases/tag/v1.24.0\n\
781 4. Run `npx @cortexkit/aft doctor` for full diagnostics.",
782 version,
783 lib_name,
784 lib_name,
785 suggest_removal_command(lib_name),
786 )
787}
788
789pub fn initialize_text_embedding(model: &str) -> Result<TextEmbedding, String> {
790 pre_validate_onnx_runtime()?;
792
793 let selected_model = match model {
794 "all-MiniLM-L6-v2" | "all-minilm-l6-v2" => FastembedEmbeddingModel::AllMiniLML6V2,
795 _ => {
796 return Err(format!(
797 "unsupported fastembed model '{}'. Supported: all-MiniLM-L6-v2",
798 model
799 ))
800 }
801 };
802
803 TextEmbedding::try_new(InitOptions::new(selected_model)).map_err(format_embedding_init_error)
804}
805
/// Heuristically decides whether an error message means the ONNX Runtime
/// shared library could not be found or loaded.
///
/// A message starting (after leading whitespace) with
/// `"ONNX Runtime not found."` always qualifies. Otherwise the lower-cased
/// message must mention the runtime *and* contain a load-failure phrase.
pub fn is_onnx_runtime_unavailable(message: &str) -> bool {
    const RUNTIME_MARKERS: [&str; 3] = ["onnx runtime", "onnxruntime", "libonnxruntime"];
    const LOAD_FAILURE_MARKERS: [&str; 9] = [
        "shared library",
        "dynamic library",
        "failed to load",
        "could not load",
        "unable to load",
        "dlopen",
        "loadlibrary",
        "no such file",
        "not found",
    ];

    if message.trim_start().starts_with("ONNX Runtime not found.") {
        return true;
    }

    let lowered = message.to_ascii_lowercase();
    let contains_any = |markers: &[&str]| markers.iter().any(|marker| lowered.contains(*marker));

    contains_any(&RUNTIME_MARKERS) && contains_any(&LOAD_FAILURE_MARKERS)
}
831
832fn format_embedding_init_error(error: impl Display) -> String {
833 let message = error.to_string();
834
835 if is_onnx_runtime_unavailable(&message) {
836 return format!("{ONNX_RUNTIME_INSTALL_HINT} Original error: {message}");
837 }
838
839 format!("failed to initialize semantic embedding model: {message}")
840}
841
/// One indexable unit of source code together with the text used to embed it.
#[derive(Debug, Clone)]
pub struct SemanticChunk {
    /// File the chunk was extracted from.
    pub file: PathBuf,
    /// Name of the chunk (presumably the symbol name — confirm against the extractor).
    pub name: String,
    /// Kind of symbol the chunk represents.
    pub kind: SymbolKind,
    /// First line of the chunk. NOTE(review): 0- vs 1-based is not visible here.
    pub start_line: u32,
    /// Last line of the chunk.
    pub end_line: u32,
    /// Whether the symbol is exported (as reported by the extractor).
    pub exported: bool,
    /// Text handed to the embedding function for this chunk.
    pub embed_text: String,
    /// Source excerpt carried alongside the chunk.
    pub snippet: String,
}
861
/// A chunk paired with the embedding vector computed for it.
#[derive(Debug)]
struct EmbeddingEntry {
    /// The chunk that was embedded.
    chunk: SemanticChunk,
    /// Embedding of `chunk.embed_text`.
    vector: Vec<f32>,
}
868
/// In-memory semantic index: embedded chunks plus per-file bookkeeping used to
/// detect stale files.
#[derive(Debug)]
pub struct SemanticIndex {
    /// All embedded chunks.
    entries: Vec<EmbeddingEntry>,
    /// Modification time recorded for each indexed file.
    file_mtimes: HashMap<PathBuf, SystemTime>,
    /// Size in bytes recorded for each indexed file.
    file_sizes: HashMap<PathBuf, u64>,
    /// Length of the stored vectors (`DEFAULT_DIMENSION` while the index is empty).
    dimension: usize,
    /// Fingerprint of the embedding setup, when one has been attached.
    fingerprint: Option<SemanticIndexFingerprint>,
}
881
/// File metadata captured at indexing time.
#[derive(Debug, Clone, Copy)]
struct IndexedFileMetadata {
    /// Modification time of the file.
    mtime: SystemTime,
    /// Size of the file in bytes.
    size: u64,
}
887
/// Outcome counters of an incremental index refresh.
#[derive(Debug, Default, Clone, Copy)]
pub struct RefreshSummary {
    /// Previously indexed files that were re-embedded.
    pub changed: usize,
    /// Files indexed for the first time.
    pub added: usize,
    /// Files removed from the index.
    pub deleted: usize,
    /// Number of distinct files considered by the refresh.
    pub total_processed: usize,
}

impl RefreshSummary {
    /// Returns `true` when the refresh changed, added and deleted nothing.
    /// `total_processed` is not taken into account.
    pub fn is_noop(&self) -> bool {
        [self.changed, self.added, self.deleted]
            .iter()
            .all(|&count| count == 0)
    }
}
904
/// One search hit: the matched chunk's location and metadata plus its score.
#[derive(Debug)]
pub struct SemanticResult {
    /// File containing the matched chunk.
    pub file: PathBuf,
    /// Name of the matched chunk.
    pub name: String,
    /// Kind of symbol the chunk represents.
    pub kind: SymbolKind,
    /// First line of the chunk.
    pub start_line: u32,
    /// Last line of the chunk.
    pub end_line: u32,
    /// Whether the symbol is exported.
    pub exported: bool,
    /// Source excerpt of the chunk.
    pub snippet: String,
    /// Similarity score of the chunk against the query vector.
    pub score: f32,
}
917
918impl SemanticIndex {
919 pub fn new() -> Self {
920 Self {
921 entries: Vec::new(),
922 file_mtimes: HashMap::new(),
923 file_sizes: HashMap::new(),
924 dimension: DEFAULT_DIMENSION, fingerprint: None,
926 }
927 }
928
/// Number of embedded chunks currently stored in the index.
pub fn entry_count(&self) -> usize {
    self.entries.len()
}
933
934 pub fn status_label(&self) -> &'static str {
936 if self.entries.is_empty() {
937 "empty"
938 } else {
939 "ready"
940 }
941 }
942
943 fn collect_chunks(
944 project_root: &Path,
945 files: &[PathBuf],
946 ) -> (Vec<SemanticChunk>, HashMap<PathBuf, IndexedFileMetadata>) {
947 let per_file: Vec<(
948 PathBuf,
949 Result<(IndexedFileMetadata, Vec<SemanticChunk>), String>,
950 )> = files
951 .par_iter()
952 .map_init(HashMap::new, |parsers, file| {
953 let result = collect_file_metadata(file).and_then(|metadata| {
954 collect_file_chunks(project_root, file, parsers)
955 .map(|chunks| (metadata, chunks))
956 });
957 (file.clone(), result)
958 })
959 .collect();
960
961 let mut chunks: Vec<SemanticChunk> = Vec::new();
962 let mut file_metadata: HashMap<PathBuf, IndexedFileMetadata> = HashMap::new();
963
964 for (file, result) in per_file {
965 match result {
966 Ok((metadata, file_chunks)) => {
967 file_metadata.insert(file, metadata);
968 chunks.extend(file_chunks);
969 }
970 Err(error) => {
971 if error == "unsupported file extension" {
977 continue;
978 }
979 slog_warn!(
980 "failed to collect semantic chunks for {}: {}",
981 file.display(),
982 error
983 );
984 }
985 }
986 }
987
988 (chunks, file_metadata)
989 }
990
991 fn build_from_chunks<F, P>(
992 chunks: Vec<SemanticChunk>,
993 file_metadata: HashMap<PathBuf, IndexedFileMetadata>,
994 embed_fn: &mut F,
995 max_batch_size: usize,
996 mut progress: Option<&mut P>,
997 ) -> Result<Self, String>
998 where
999 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
1000 P: FnMut(usize, usize),
1001 {
1002 let total_chunks = chunks.len();
1003
1004 if chunks.is_empty() {
1005 return Ok(Self {
1006 entries: Vec::new(),
1007 file_mtimes: file_metadata
1008 .iter()
1009 .map(|(path, metadata)| (path.clone(), metadata.mtime))
1010 .collect(),
1011 file_sizes: file_metadata
1012 .into_iter()
1013 .map(|(path, metadata)| (path, metadata.size))
1014 .collect(),
1015 dimension: DEFAULT_DIMENSION,
1016 fingerprint: None,
1017 });
1018 }
1019
1020 let mut entries: Vec<EmbeddingEntry> = Vec::with_capacity(chunks.len());
1022 let mut expected_dimension: Option<usize> = None;
1023 let batch_size = max_batch_size.max(1);
1024 for batch_start in (0..chunks.len()).step_by(batch_size) {
1025 let batch_end = (batch_start + batch_size).min(chunks.len());
1026 let batch_texts: Vec<String> = chunks[batch_start..batch_end]
1027 .iter()
1028 .map(|c| c.embed_text.clone())
1029 .collect();
1030
1031 let vectors = embed_fn(batch_texts)?;
1032 validate_embedding_batch(&vectors, batch_end - batch_start, "embedding backend")?;
1033
1034 if let Some(dim) = vectors.first().map(|v| v.len()) {
1036 match expected_dimension {
1037 None => expected_dimension = Some(dim),
1038 Some(expected) if dim != expected => {
1039 return Err(format!(
1040 "embedding dimension changed across batches: expected {expected}, got {dim}"
1041 ));
1042 }
1043 _ => {}
1044 }
1045 }
1046
1047 for (i, vector) in vectors.into_iter().enumerate() {
1048 let chunk_idx = batch_start + i;
1049 entries.push(EmbeddingEntry {
1050 chunk: chunks[chunk_idx].clone(),
1051 vector,
1052 });
1053 }
1054
1055 if let Some(callback) = progress.as_mut() {
1056 callback(entries.len(), total_chunks);
1057 }
1058 }
1059
1060 let dimension = entries
1061 .first()
1062 .map(|e| e.vector.len())
1063 .unwrap_or(DEFAULT_DIMENSION);
1064
1065 Ok(Self {
1066 entries,
1067 file_mtimes: file_metadata
1068 .iter()
1069 .map(|(path, metadata)| (path.clone(), metadata.mtime))
1070 .collect(),
1071 file_sizes: file_metadata
1072 .into_iter()
1073 .map(|(path, metadata)| (path, metadata.size))
1074 .collect(),
1075 dimension,
1076 fingerprint: None,
1077 })
1078 }
1079
1080 pub fn build<F>(
1083 project_root: &Path,
1084 files: &[PathBuf],
1085 embed_fn: &mut F,
1086 max_batch_size: usize,
1087 ) -> Result<Self, String>
1088 where
1089 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
1090 {
1091 let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
1092 Self::build_from_chunks(
1093 chunks,
1094 file_mtimes,
1095 embed_fn,
1096 max_batch_size,
1097 Option::<&mut fn(usize, usize)>::None,
1098 )
1099 }
1100
1101 pub fn build_with_progress<F, P>(
1103 project_root: &Path,
1104 files: &[PathBuf],
1105 embed_fn: &mut F,
1106 max_batch_size: usize,
1107 progress: &mut P,
1108 ) -> Result<Self, String>
1109 where
1110 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
1111 P: FnMut(usize, usize),
1112 {
1113 let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
1114 let total_chunks = chunks.len();
1115 progress(0, total_chunks);
1116 Self::build_from_chunks(
1117 chunks,
1118 file_mtimes,
1119 embed_fn,
1120 max_batch_size,
1121 Some(progress),
1122 )
1123 }
1124
/// Incrementally brings the index in line with `current_files`.
///
/// Indexed files missing from `current_files` are dropped, indexed files that
/// `is_file_stale` reports as stale are re-embedded, and files not yet indexed
/// are embedded for the first time. `progress` is called with `(0, 0)` when
/// nothing needs embedding, otherwise with `(0, total_chunks)` up front and
/// `(done, total_chunks)` after every batch.
///
/// # Errors
/// Returns the error from `embed_fn`, a batch-validation error, or a message
/// when new vectors have a different length than the vectors already stored.
/// Deletions applied before the failure are not rolled back.
pub fn refresh_stale_files<F, P>(
    &mut self,
    project_root: &Path,
    current_files: &[PathBuf],
    embed_fn: &mut F,
    max_batch_size: usize,
    progress: &mut P,
) -> Result<RefreshSummary, String>
where
    F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
    P: FnMut(usize, usize),
{
    // NOTE(review): helper defined outside this view; by its name it fills in
    // `file_sizes` for files that lack a recorded size — confirm.
    self.backfill_missing_file_sizes();

    let current_set: HashSet<&Path> = current_files.iter().map(PathBuf::as_path).collect();
    // Size of the union of "currently present" and "already indexed" files.
    let total_processed = current_set.len() + self.file_mtimes.len()
        - self
            .file_mtimes
            .keys()
            .filter(|path| current_set.contains(path.as_path()))
            .count();

    // Classify every indexed file as deleted, changed, or untouched.
    let mut deleted: Vec<PathBuf> = Vec::new();
    let mut changed: Vec<PathBuf> = Vec::new();
    for indexed_path in self.file_mtimes.keys() {
        if !current_set.contains(indexed_path.as_path()) {
            deleted.push(indexed_path.clone());
            continue;
        }
        if self.is_file_stale(indexed_path) {
            changed.push(indexed_path.clone());
        }
    }

    // Files present now but never indexed before.
    let mut added: Vec<PathBuf> = Vec::new();
    for path in current_files {
        if !self.file_mtimes.contains_key(path) {
            added.push(path.clone());
        }
    }

    // Fast path: nothing to do.
    if deleted.is_empty() && changed.is_empty() && added.is_empty() {
        progress(0, 0);
        return Ok(RefreshSummary {
            total_processed,
            ..RefreshSummary::default()
        });
    }

    // Purge deleted files from the entries and from both bookkeeping maps.
    if !deleted.is_empty() {
        let deleted_set: HashSet<&Path> = deleted.iter().map(PathBuf::as_path).collect();
        self.entries
            .retain(|entry| !deleted_set.contains(entry.chunk.file.as_path()));
        for path in &deleted {
            self.file_mtimes.remove(path);
            self.file_sizes.remove(path);
        }
    }

    let mut to_embed: Vec<PathBuf> = Vec::with_capacity(changed.len() + added.len());
    to_embed.extend(changed.iter().cloned());
    to_embed.extend(added.iter().cloned());

    // Only deletions happened.
    if to_embed.is_empty() {
        progress(0, 0);
        return Ok(RefreshSummary {
            changed: 0,
            added: 0,
            deleted: deleted.len(),
            total_processed,
        });
    }

    // `fresh_metadata` only contains files that were collected successfully.
    let (chunks, fresh_metadata) = Self::collect_chunks(project_root, &to_embed);

    // The files yielded no chunks: drop their old entries and record the new
    // metadata, without calling the embedder.
    if chunks.is_empty() {
        progress(0, 0);
        let successful_files: HashSet<PathBuf> = fresh_metadata.keys().cloned().collect();
        if !successful_files.is_empty() {
            self.entries
                .retain(|entry| !successful_files.contains(&entry.chunk.file));
        }
        let changed_count = changed
            .iter()
            .filter(|path| successful_files.contains(*path))
            .count();
        let added_count = added
            .iter()
            .filter(|path| successful_files.contains(*path))
            .count();
        for (file, metadata) in fresh_metadata {
            self.file_mtimes.insert(file.clone(), metadata.mtime);
            self.file_sizes.insert(file, metadata.size);
        }
        return Ok(RefreshSummary {
            changed: changed_count,
            added: added_count,
            deleted: deleted.len(),
            total_processed,
        });
    }

    let total_chunks = chunks.len();
    progress(0, total_chunks);
    let batch_size = max_batch_size.max(1);
    // When entries survive, new vectors must match the stored dimension; an
    // emptied index accepts whatever dimension the first batch has.
    let existing_dimension = if self.entries.is_empty() {
        None
    } else {
        Some(self.dimension)
    };
    let mut new_entries: Vec<EmbeddingEntry> = Vec::with_capacity(chunks.len());
    let mut observed_dimension: Option<usize> = existing_dimension;

    // New entries are staged in `new_entries`; `self.entries` is only touched
    // after every batch has succeeded.
    for batch_start in (0..chunks.len()).step_by(batch_size) {
        let batch_end = (batch_start + batch_size).min(chunks.len());
        let batch_texts: Vec<String> = chunks[batch_start..batch_end]
            .iter()
            .map(|c| c.embed_text.clone())
            .collect();

        let vectors = embed_fn(batch_texts)?;
        validate_embedding_batch(&vectors, batch_end - batch_start, "embedding backend")?;

        if let Some(dim) = vectors.first().map(|v| v.len()) {
            match observed_dimension {
                None => observed_dimension = Some(dim),
                Some(expected) if dim != expected => {
                    return Err(format!(
                        "embedding dimension changed during incremental refresh: \
                         cached index uses {expected}, new vectors use {dim}"
                    ));
                }
                _ => {}
            }
        }

        for (i, vector) in vectors.into_iter().enumerate() {
            let chunk_idx = batch_start + i;
            new_entries.push(EmbeddingEntry {
                chunk: chunks[chunk_idx].clone(),
                vector,
            });
        }

        progress(new_entries.len(), total_chunks);
    }

    // Replace the old entries of every successfully collected file.
    let successful_files: HashSet<PathBuf> = fresh_metadata.keys().cloned().collect();
    if !successful_files.is_empty() {
        self.entries
            .retain(|entry| !successful_files.contains(&entry.chunk.file));
    }

    self.entries.extend(new_entries);
    for (file, metadata) in fresh_metadata {
        self.file_mtimes.insert(file.clone(), metadata.mtime);
        self.file_sizes.insert(file, metadata.size);
    }
    if let Some(dim) = observed_dimension {
        self.dimension = dim;
    }

    // Files that failed collection are not counted as changed or added.
    Ok(RefreshSummary {
        changed: changed
            .iter()
            .filter(|path| successful_files.contains(*path))
            .count(),
        added: added
            .iter()
            .filter(|path| successful_files.contains(*path))
            .count(),
        deleted: deleted.len(),
        total_processed,
    })
}
1323
1324 pub fn search(&self, query_vector: &[f32], top_k: usize) -> Vec<SemanticResult> {
1326 if self.entries.is_empty() || query_vector.len() != self.dimension {
1327 return Vec::new();
1328 }
1329
1330 let mut scored: Vec<(f32, usize)> = self
1331 .entries
1332 .iter()
1333 .enumerate()
1334 .map(|(i, entry)| (cosine_similarity(query_vector, &entry.vector), i))
1335 .collect();
1336
1337 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
1339
1340 scored
1341 .into_iter()
1342 .take(top_k)
1343 .filter(|(score, _)| *score > 0.0)
1344 .map(|(score, idx)| {
1345 let entry = &self.entries[idx];
1346 SemanticResult {
1347 file: entry.chunk.file.clone(),
1348 name: entry.chunk.name.clone(),
1349 kind: entry.chunk.kind.clone(),
1350 start_line: entry.chunk.start_line,
1351 end_line: entry.chunk.end_line,
1352 exported: entry.chunk.exported,
1353 snippet: entry.chunk.snippet.clone(),
1354 score,
1355 }
1356 })
1357 .collect()
1358 }
1359
    /// Returns the number of embedding entries currently stored in the index.
    pub fn len(&self) -> usize {
        self.entries.len()
    }
1364
1365 pub fn is_file_stale(&self, file: &Path) -> bool {
1367 let Some(stored_mtime) = self.file_mtimes.get(file) else {
1368 return true;
1369 };
1370 let Some(stored_size) = self.file_sizes.get(file) else {
1371 return true;
1372 };
1373 match collect_file_metadata(file) {
1374 Ok(current) => *stored_mtime != current.mtime || *stored_size != current.size,
1375 Err(_) => true,
1376 }
1377 }
1378
1379 fn backfill_missing_file_sizes(&mut self) {
1380 for path in self.file_mtimes.keys() {
1381 if self.file_sizes.contains_key(path) {
1382 continue;
1383 }
1384 if let Ok(metadata) = fs::metadata(path) {
1385 self.file_sizes.insert(path.clone(), metadata.len());
1386 }
1387 }
1388 }
1389
    /// Removes `file` from the index by delegating to [`Self::invalidate_file`].
    pub fn remove_file(&mut self, file: &Path) {
        self.invalidate_file(file);
    }
1394
1395 pub fn invalidate_file(&mut self, file: &Path) {
1396 self.entries.retain(|e| e.chunk.file != file);
1397 self.file_mtimes.remove(file);
1398 self.file_sizes.remove(file);
1399 }
1400
    /// Returns the embedding dimension recorded for this index.
    pub fn dimension(&self) -> usize {
        self.dimension
    }
1405
    /// Returns the fingerprint attached to this index, or `None` if none has been set.
    pub fn fingerprint(&self) -> Option<&SemanticIndexFingerprint> {
        self.fingerprint.as_ref()
    }
1409
1410 pub fn backend_label(&self) -> Option<&str> {
1411 self.fingerprint.as_ref().map(|f| f.backend.as_str())
1412 }
1413
1414 pub fn model_label(&self) -> Option<&str> {
1415 self.fingerprint.as_ref().map(|f| f.model.as_str())
1416 }
1417
    /// Attaches `fingerprint` to the index, replacing any previously stored one.
    pub fn set_fingerprint(&mut self, fingerprint: SemanticIndexFingerprint) {
        self.fingerprint = Some(fingerprint);
    }
1421
1422 pub fn write_to_disk(&self, storage_dir: &Path, project_key: &str) {
1424 if self.entries.is_empty() {
1427 slog_info!("skipping semantic index persistence (0 entries)");
1428 return;
1429 }
1430 let dir = storage_dir.join("semantic").join(project_key);
1431 if let Err(e) = fs::create_dir_all(&dir) {
1432 slog_warn!("failed to create semantic cache dir: {}", e);
1433 return;
1434 }
1435 let data_path = dir.join("semantic.bin");
1436 let tmp_path = dir.join(format!(
1437 "semantic.bin.tmp.{}.{}",
1438 std::process::id(),
1439 SystemTime::now()
1440 .duration_since(SystemTime::UNIX_EPOCH)
1441 .unwrap_or(Duration::ZERO)
1442 .as_nanos()
1443 ));
1444 let bytes = self.to_bytes();
1445 let write_result = (|| -> std::io::Result<()> {
1446 use std::io::Write;
1447 let mut file = fs::File::create(&tmp_path)?;
1448 file.write_all(&bytes)?;
1449 file.sync_all()?;
1450 Ok(())
1451 })();
1452 if let Err(e) = write_result {
1453 slog_warn!("failed to write semantic index: {}", e);
1454 let _ = fs::remove_file(&tmp_path);
1455 return;
1456 }
1457 if let Err(e) = fs::rename(&tmp_path, &data_path) {
1458 slog_warn!("failed to rename semantic index: {}", e);
1459 let _ = fs::remove_file(&tmp_path);
1460 return;
1461 }
1462 slog_info!(
1463 "semantic index persisted: {} entries, {:.1} KB",
1464 self.entries.len(),
1465 bytes.len() as f64 / 1024.0
1466 );
1467 }
1468
1469 pub fn read_from_disk(
1471 storage_dir: &Path,
1472 project_key: &str,
1473 expected_fingerprint: Option<&str>,
1474 ) -> Option<Self> {
1475 let data_path = storage_dir
1476 .join("semantic")
1477 .join(project_key)
1478 .join("semantic.bin");
1479 let file_len = usize::try_from(fs::metadata(&data_path).ok()?.len()).ok()?;
1480 if file_len < HEADER_BYTES_V1 {
1481 slog_warn!(
1482 "corrupt semantic index (too small: {} bytes), removing",
1483 file_len
1484 );
1485 let _ = fs::remove_file(&data_path);
1486 return None;
1487 }
1488
1489 let bytes = fs::read(&data_path).ok()?;
1490 let version = bytes[0];
1491 if version != SEMANTIC_INDEX_VERSION_V5 {
1492 slog_info!(
1493 "cached semantic index version {} is older than {}, rebuilding",
1494 version,
1495 SEMANTIC_INDEX_VERSION_V5
1496 );
1497 let _ = fs::remove_file(&data_path);
1498 return None;
1499 }
1500 match Self::from_bytes(&bytes) {
1501 Ok(index) => {
1502 if index.entries.is_empty() {
1503 slog_info!("cached semantic index is empty, will rebuild");
1504 let _ = fs::remove_file(&data_path);
1505 return None;
1506 }
1507 if let Some(expected) = expected_fingerprint {
1508 let matches = index
1509 .fingerprint()
1510 .map(|fingerprint| fingerprint.matches_expected(expected))
1511 .unwrap_or(false);
1512 if !matches {
1513 slog_info!("cached semantic index fingerprint mismatch, rebuilding");
1514 let _ = fs::remove_file(&data_path);
1515 return None;
1516 }
1517 }
1518 slog_info!(
1519 "loaded semantic index from disk: {} entries",
1520 index.entries.len()
1521 );
1522 Some(index)
1523 }
1524 Err(e) => {
1525 slog_warn!("corrupt semantic index, rebuilding: {}", e);
1526 let _ = fs::remove_file(&data_path);
1527 None
1528 }
1529 }
1530 }
1531
1532 pub fn to_bytes(&self) -> Vec<u8> {
1534 let mut buf = Vec::new();
1535 let fingerprint_bytes = self.fingerprint.as_ref().and_then(|fingerprint| {
1536 let encoded = fingerprint.as_string();
1537 if encoded.is_empty() {
1538 None
1539 } else {
1540 Some(encoded.into_bytes())
1541 }
1542 });
1543
1544 let version = SEMANTIC_INDEX_VERSION_V5;
1556 buf.push(version);
1557 buf.extend_from_slice(&(self.dimension as u32).to_le_bytes());
1558 buf.extend_from_slice(&(self.entries.len() as u32).to_le_bytes());
1559 let fp_bytes_ref: &[u8] = fingerprint_bytes.as_deref().unwrap_or(&[]);
1560 buf.extend_from_slice(&(fp_bytes_ref.len() as u32).to_le_bytes());
1561 buf.extend_from_slice(fp_bytes_ref);
1562
1563 buf.extend_from_slice(&(self.file_mtimes.len() as u32).to_le_bytes());
1566 for (path, mtime) in &self.file_mtimes {
1567 let path_bytes = path.to_string_lossy().as_bytes().to_vec();
1568 buf.extend_from_slice(&(path_bytes.len() as u32).to_le_bytes());
1569 buf.extend_from_slice(&path_bytes);
1570 let duration = mtime
1571 .duration_since(SystemTime::UNIX_EPOCH)
1572 .unwrap_or_default();
1573 buf.extend_from_slice(&duration.as_secs().to_le_bytes());
1574 buf.extend_from_slice(&duration.subsec_nanos().to_le_bytes());
1575 let size = self.file_sizes.get(path).copied().unwrap_or_default();
1576 buf.extend_from_slice(&size.to_le_bytes());
1577 }
1578
1579 for entry in &self.entries {
1581 let c = &entry.chunk;
1582
1583 let file_bytes = c.file.to_string_lossy().as_bytes().to_vec();
1585 buf.extend_from_slice(&(file_bytes.len() as u32).to_le_bytes());
1586 buf.extend_from_slice(&file_bytes);
1587
1588 let name_bytes = c.name.as_bytes();
1590 buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
1591 buf.extend_from_slice(name_bytes);
1592
1593 buf.push(symbol_kind_to_u8(&c.kind));
1595
1596 buf.extend_from_slice(&(c.start_line as u32).to_le_bytes());
1598 buf.extend_from_slice(&(c.end_line as u32).to_le_bytes());
1599 buf.push(c.exported as u8);
1600
1601 let snippet_bytes = c.snippet.as_bytes();
1603 buf.extend_from_slice(&(snippet_bytes.len() as u32).to_le_bytes());
1604 buf.extend_from_slice(snippet_bytes);
1605
1606 let embed_bytes = c.embed_text.as_bytes();
1608 buf.extend_from_slice(&(embed_bytes.len() as u32).to_le_bytes());
1609 buf.extend_from_slice(embed_bytes);
1610
1611 for &val in &entry.vector {
1613 buf.extend_from_slice(&val.to_le_bytes());
1614 }
1615 }
1616
1617 buf
1618 }
1619
    /// Decodes an index from its binary cache representation.
    ///
    /// Accepts format versions 1 through 5. Layout, in read order:
    /// - version byte, dimension `u32`, entry count `u32`;
    /// - versions 2+: fingerprint length `u32` followed by that many bytes of JSON
    ///   (length 0 means "no fingerprint");
    /// - file-metadata count `u32`, then per file: length-prefixed path, mtime
    ///   seconds `u64`, mtime nanoseconds `u32` (versions 3+), size `u64`
    ///   (version 5 only; older versions get size 0);
    /// - per entry: length-prefixed file path and name, kind byte, start line
    ///   `u32`, end line `u32`, exported byte, length-prefixed snippet and embed
    ///   text, then `dimension` raw `f32` values.
    ///
    /// # Errors
    ///
    /// Returns a descriptive message when the data is truncated, the version is
    /// unknown, the dimension is zero or above `MAX_DIMENSION`, a count exceeds
    /// `MAX_ENTRIES`, the fingerprint is not valid JSON, an mtime is invalid, or an
    /// entry refers to a file that has no metadata record.
    pub fn from_bytes(data: &[u8]) -> Result<Self, String> {
        let mut pos = 0;

        if data.len() < HEADER_BYTES_V1 {
            return Err("data too short".to_string());
        }

        let version = data[pos];
        pos += 1;
        if version != SEMANTIC_INDEX_VERSION_V1
            && version != SEMANTIC_INDEX_VERSION_V2
            && version != SEMANTIC_INDEX_VERSION_V3
            && version != SEMANTIC_INDEX_VERSION_V4
            && version != SEMANTIC_INDEX_VERSION_V5
        {
            return Err(format!("unsupported version: {}", version));
        }
        // Versions 2+ carry an extra fingerprint-length field, so their minimum
        // header is longer than the version-1 header checked above.
        if (version == SEMANTIC_INDEX_VERSION_V2
            || version == SEMANTIC_INDEX_VERSION_V3
            || version == SEMANTIC_INDEX_VERSION_V4
            || version == SEMANTIC_INDEX_VERSION_V5)
            && data.len() < HEADER_BYTES_V2
        {
            return Err("data too short for semantic index v2/v3/v4/v5 header".to_string());
        }

        let dimension = read_u32(data, &mut pos)? as usize;
        let entry_count = read_u32(data, &mut pos)? as usize;
        // Reject implausible header values before any size derived from them is
        // used for allocation.
        if dimension == 0 || dimension > MAX_DIMENSION {
            return Err(format!("invalid embedding dimension: {}", dimension));
        }
        if entry_count > MAX_ENTRIES {
            return Err(format!("too many semantic index entries: {}", entry_count));
        }

        let has_fingerprint_field = version == SEMANTIC_INDEX_VERSION_V2
            || version == SEMANTIC_INDEX_VERSION_V3
            || version == SEMANTIC_INDEX_VERSION_V4
            || version == SEMANTIC_INDEX_VERSION_V5;
        let fingerprint = if has_fingerprint_field {
            let fingerprint_len = read_u32(data, &mut pos)? as usize;
            if pos + fingerprint_len > data.len() {
                return Err("unexpected end of data reading fingerprint".to_string());
            }
            if fingerprint_len == 0 {
                None
            } else {
                let raw = String::from_utf8_lossy(&data[pos..pos + fingerprint_len]).to_string();
                pos += fingerprint_len;
                Some(
                    serde_json::from_str::<SemanticIndexFingerprint>(&raw)
                        .map_err(|error| format!("invalid semantic fingerprint: {error}"))?,
                )
            }
        } else {
            None
        };

        let mtime_count = read_u32(data, &mut pos)? as usize;
        if mtime_count > MAX_ENTRIES {
            return Err(format!("too many semantic file mtimes: {}", mtime_count));
        }

        // The vectors alone need `entry_count * dimension * 4` bytes; if the
        // remaining data cannot hold even that, fail before allocating anything.
        let vector_bytes = entry_count
            .checked_mul(dimension)
            .and_then(|count| count.checked_mul(F32_BYTES))
            .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
        if vector_bytes > data.len().saturating_sub(pos) {
            return Err("semantic index vectors exceed available data".to_string());
        }

        let mut file_mtimes = HashMap::with_capacity(mtime_count);
        let mut file_sizes = HashMap::with_capacity(mtime_count);
        for _ in 0..mtime_count {
            let path = read_string(data, &mut pos)?;
            let secs = read_u64(data, &mut pos)?;
            // Sub-second precision is only stored from version 3 onwards.
            let nanos = if version == SEMANTIC_INDEX_VERSION_V3
                || version == SEMANTIC_INDEX_VERSION_V4
                || version == SEMANTIC_INDEX_VERSION_V5
            {
                read_u32(data, &mut pos)?
            } else {
                0
            };
            // The file size is only stored in version 5.
            let size = if version == SEMANTIC_INDEX_VERSION_V5 {
                read_u64(data, &mut pos)?
            } else {
                0
            };
            // A nanosecond field must stay below one second.
            if nanos >= 1_000_000_000 {
                return Err(format!(
                    "invalid semantic mtime: nanos {} >= 1_000_000_000",
                    nanos
                ));
            }
            let duration = std::time::Duration::new(secs, nanos);
            let mtime = SystemTime::UNIX_EPOCH
                .checked_add(duration)
                .ok_or_else(|| {
                    format!(
                        "invalid semantic mtime: secs={} nanos={} overflows SystemTime",
                        secs, nanos
                    )
                })?;
            let path = PathBuf::from(path);
            file_mtimes.insert(path.clone(), mtime);
            file_sizes.insert(path, size);
        }

        let mut entries = Vec::with_capacity(entry_count);
        for _ in 0..entry_count {
            let file = PathBuf::from(read_string(data, &mut pos)?);
            let name = read_string(data, &mut pos)?;

            if pos >= data.len() {
                return Err("unexpected end of data".to_string());
            }
            let kind = u8_to_symbol_kind(data[pos]);
            pos += 1;

            let start_line = read_u32(data, &mut pos)?;
            let end_line = read_u32(data, &mut pos)?;

            if pos >= data.len() {
                return Err("unexpected end of data".to_string());
            }
            // Any non-zero byte counts as "exported".
            let exported = data[pos] != 0;
            pos += 1;

            let snippet = read_string(data, &mut pos)?;
            let embed_text = read_string(data, &mut pos)?;

            let vec_bytes = dimension
                .checked_mul(F32_BYTES)
                .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
            if pos + vec_bytes > data.len() {
                return Err("unexpected end of data reading vector".to_string());
            }
            let mut vector = Vec::with_capacity(dimension);
            for _ in 0..dimension {
                let bytes = [data[pos], data[pos + 1], data[pos + 2], data[pos + 3]];
                vector.push(f32::from_le_bytes(bytes));
                pos += 4;
            }

            entries.push(EmbeddingEntry {
                chunk: SemanticChunk {
                    file,
                    name,
                    kind,
                    start_line,
                    end_line,
                    exported,
                    embed_text,
                    snippet,
                },
                vector,
            });
        }

        // NOTE(review): the loop above either pushes exactly `entry_count` entries
        // or returns early, so this looks like a purely defensive check.
        if entries.len() != entry_count {
            return Err(format!(
                "semantic cache entry count drift: header={} decoded={}",
                entry_count,
                entries.len()
            ));
        }
        // Every entry must refer to a file that has a metadata record.
        for entry in &entries {
            if !file_mtimes.contains_key(&entry.chunk.file) {
                return Err(format!(
                    "semantic cache metadata missing for entry file {}",
                    entry.chunk.file.display()
                ));
            }
        }

        Ok(Self {
            entries,
            file_mtimes,
            file_sizes,
            dimension,
            fingerprint,
        })
    }
1827}
1828
1829fn build_embed_text(symbol: &Symbol, source: &str, file: &Path, project_root: &Path) -> String {
1831 let relative = file
1832 .strip_prefix(project_root)
1833 .unwrap_or(file)
1834 .to_string_lossy();
1835
1836 let kind_label = match &symbol.kind {
1837 SymbolKind::Function => "function",
1838 SymbolKind::Class => "class",
1839 SymbolKind::Method => "method",
1840 SymbolKind::Struct => "struct",
1841 SymbolKind::Interface => "interface",
1842 SymbolKind::Enum => "enum",
1843 SymbolKind::TypeAlias => "type",
1844 SymbolKind::Variable => "variable",
1845 SymbolKind::Heading => "heading",
1846 };
1847
1848 let mut text = format!("file:{} kind:{} name:{}", relative, kind_label, symbol.name);
1850
1851 if let Some(sig) = &symbol.signature {
1852 text.push_str(&format!(" signature:{}", sig));
1853 }
1854
1855 let lines: Vec<&str> = source.lines().collect();
1857 let start = (symbol.range.start_line as usize).min(lines.len());
1858 let end = (symbol.range.end_line as usize + 1).min(lines.len());
1860 if start < end {
1861 let body: String = lines[start..end]
1862 .iter()
1863 .take(15) .copied()
1865 .collect::<Vec<&str>>()
1866 .join("\n");
1867 let snippet = if body.len() > 300 {
1868 format!("{}...", &body[..body.floor_char_boundary(300)])
1869 } else {
1870 body
1871 };
1872 text.push_str(&format!(" body:{}", snippet));
1873 }
1874
1875 text
1876}
1877
1878fn parser_for(
1879 parsers: &mut HashMap<crate::parser::LangId, Parser>,
1880 lang: crate::parser::LangId,
1881) -> Result<&mut Parser, String> {
1882 use std::collections::hash_map::Entry;
1883
1884 match parsers.entry(lang) {
1885 Entry::Occupied(entry) => Ok(entry.into_mut()),
1886 Entry::Vacant(entry) => {
1887 let grammar = grammar_for(lang);
1888 let mut parser = Parser::new();
1889 parser
1890 .set_language(&grammar)
1891 .map_err(|error| error.to_string())?;
1892 Ok(entry.insert(parser))
1893 }
1894 }
1895}
1896
1897fn collect_file_metadata(file: &Path) -> Result<IndexedFileMetadata, String> {
1898 let metadata = fs::metadata(file).map_err(|error| error.to_string())?;
1899 let mtime = metadata.modified().map_err(|error| error.to_string())?;
1900 Ok(IndexedFileMetadata {
1901 mtime,
1902 size: metadata.len(),
1903 })
1904}
1905
1906fn collect_file_chunks(
1907 project_root: &Path,
1908 file: &Path,
1909 parsers: &mut HashMap<crate::parser::LangId, Parser>,
1910) -> Result<Vec<SemanticChunk>, String> {
1911 let lang = detect_language(file).ok_or_else(|| "unsupported file extension".to_string())?;
1912 let source = std::fs::read_to_string(file).map_err(|error| error.to_string())?;
1913 let tree = parser_for(parsers, lang)?
1914 .parse(&source, None)
1915 .ok_or_else(|| format!("tree-sitter parse returned None for {}", file.display()))?;
1916 let symbols =
1917 extract_symbols_from_tree(&source, &tree, lang).map_err(|error| error.to_string())?;
1918
1919 Ok(symbols_to_chunks(file, &symbols, &source, project_root))
1920}
1921
1922fn build_snippet(symbol: &Symbol, source: &str) -> String {
1924 let lines: Vec<&str> = source.lines().collect();
1925 let start = (symbol.range.start_line as usize).min(lines.len());
1926 let end = (symbol.range.end_line as usize + 1).min(lines.len());
1928 if start < end {
1929 let snippet_lines: Vec<&str> = lines[start..end].iter().take(5).copied().collect();
1930 let mut snippet = snippet_lines.join("\n");
1931 if end - start > 5 {
1932 snippet.push_str("\n ...");
1933 }
1934 if snippet.len() > 300 {
1935 snippet = format!("{}...", &snippet[..snippet.floor_char_boundary(300)]);
1936 }
1937 snippet
1938 } else {
1939 String::new()
1940 }
1941}
1942
1943fn symbols_to_chunks(
1945 file: &Path,
1946 symbols: &[Symbol],
1947 source: &str,
1948 project_root: &Path,
1949) -> Vec<SemanticChunk> {
1950 let mut chunks = Vec::new();
1951
1952 for symbol in symbols {
1953 if matches!(symbol.kind, SymbolKind::Heading) {
1958 continue;
1959 }
1960
1961 let line_count = symbol
1963 .range
1964 .end_line
1965 .saturating_sub(symbol.range.start_line)
1966 + 1;
1967 if line_count < 2 && !matches!(symbol.kind, SymbolKind::Variable) {
1968 continue;
1969 }
1970
1971 let embed_text = build_embed_text(symbol, source, file, project_root);
1972 let snippet = build_snippet(symbol, source);
1973
1974 chunks.push(SemanticChunk {
1975 file: file.to_path_buf(),
1976 name: symbol.name.clone(),
1977 kind: symbol.kind.clone(),
1978 start_line: symbol.range.start_line,
1979 end_line: symbol.range.end_line,
1980 exported: symbol.exported,
1981 embed_text,
1982 snippet,
1983 });
1984
1985 }
1988
1989 chunks
1990}
1991
/// Computes the cosine similarity of two vectors.
///
/// Returns `0.0` when the slices differ in length or when the product of their
/// norms is zero (which includes empty slices).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Single pass accumulating the dot product and both squared norms, in
    // element order.
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, norm_a, norm_b), (&x, &y)| (dot + x * y, norm_a + x * x, norm_b + y * y),
    );

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }
    dot / denom
}
2015
/// Maps a symbol kind to the one-byte code used for it in the serialized index.
///
/// The codes are written into persisted cache files, so existing values must not
/// be renumbered; `u8_to_symbol_kind` is the inverse mapping.
fn symbol_kind_to_u8(kind: &SymbolKind) -> u8 {
    match kind {
        SymbolKind::Function => 0,
        SymbolKind::Class => 1,
        SymbolKind::Method => 2,
        SymbolKind::Struct => 3,
        SymbolKind::Interface => 4,
        SymbolKind::Enum => 5,
        SymbolKind::TypeAlias => 6,
        SymbolKind::Variable => 7,
        SymbolKind::Heading => 8,
    }
}
2030
/// Maps a serialized one-byte kind code back to a symbol kind.
///
/// Codes 0 through 7 mirror `symbol_kind_to_u8`; every other value, including
/// the code 8 written for headings, decodes as `SymbolKind::Heading`.
fn u8_to_symbol_kind(v: u8) -> SymbolKind {
    match v {
        0 => SymbolKind::Function,
        1 => SymbolKind::Class,
        2 => SymbolKind::Method,
        3 => SymbolKind::Struct,
        4 => SymbolKind::Interface,
        5 => SymbolKind::Enum,
        6 => SymbolKind::TypeAlias,
        7 => SymbolKind::Variable,
        // NOTE(review): unknown codes are not rejected, they fall back to Heading.
        _ => SymbolKind::Heading,
    }
}
2044
/// Reads a little-endian `u32` from `data` at `*pos` and advances `*pos` by 4.
///
/// # Errors
///
/// Returns an error, leaving `*pos` unchanged, when fewer than four bytes are
/// available at `*pos` (including when `*pos + 4` would overflow `usize`).
fn read_u32(data: &[u8], pos: &mut usize) -> Result<u32, String> {
    // `checked_add` turns an offset close to `usize::MAX` into a regular error
    // instead of an arithmetic overflow.
    let end = pos
        .checked_add(4)
        .filter(|&end| end <= data.len())
        .ok_or_else(|| "unexpected end of data reading u32".to_string())?;

    let mut raw = [0u8; 4];
    raw.copy_from_slice(&data[*pos..end]);
    *pos = end;
    Ok(u32::from_le_bytes(raw))
}
2053
/// Reads a little-endian `u64` from `data` at `*pos` and advances `*pos` by 8.
///
/// # Errors
///
/// Returns an error, leaving `*pos` unchanged, when fewer than eight bytes are
/// available at `*pos` (including when `*pos + 8` would overflow `usize`).
fn read_u64(data: &[u8], pos: &mut usize) -> Result<u64, String> {
    // `checked_add` turns an offset close to `usize::MAX` into a regular error
    // instead of an arithmetic overflow.
    let end = pos
        .checked_add(8)
        .filter(|&end| end <= data.len())
        .ok_or_else(|| "unexpected end of data reading u64".to_string())?;

    let mut raw = [0u8; 8];
    raw.copy_from_slice(&data[*pos..end]);
    *pos = end;
    Ok(u64::from_le_bytes(raw))
}
2062
2063fn read_string(data: &[u8], pos: &mut usize) -> Result<String, String> {
2064 let len = read_u32(data, pos)? as usize;
2065 if *pos + len > data.len() {
2066 return Err("unexpected end of data reading string".to_string());
2067 }
2068 let s = String::from_utf8_lossy(&data[*pos..*pos + len]).to_string();
2069 *pos += len;
2070 Ok(s)
2071}
2072
2073#[cfg(test)]
2074mod tests {
2075 use super::*;
2076 use crate::config::{SemanticBackend, SemanticBackendConfig};
2077 use crate::parser::FileParser;
2078 use std::io::{Read, Write};
2079 use std::net::TcpListener;
2080 use std::thread;
2081
2082 fn start_mock_http_server<F>(handler: F) -> (String, thread::JoinHandle<()>)
2083 where
2084 F: Fn(String, String, String) -> String + Send + 'static,
2085 {
2086 let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
2087 let addr = listener.local_addr().expect("local addr");
2088 let handle = thread::spawn(move || {
2089 let (mut stream, _) = listener.accept().expect("accept request");
2090 let mut buf = Vec::new();
2091 let mut chunk = [0u8; 4096];
2092 let mut header_end = None;
2093 let mut content_length = 0usize;
2094 loop {
2095 let n = stream.read(&mut chunk).expect("read request");
2096 if n == 0 {
2097 break;
2098 }
2099 buf.extend_from_slice(&chunk[..n]);
2100 if header_end.is_none() {
2101 if let Some(pos) = buf.windows(4).position(|window| window == b"\r\n\r\n") {
2102 header_end = Some(pos + 4);
2103 let headers = String::from_utf8_lossy(&buf[..pos + 4]);
2104 for line in headers.lines() {
2105 if let Some(value) = line.strip_prefix("Content-Length:") {
2106 content_length = value.trim().parse::<usize>().unwrap_or(0);
2107 }
2108 }
2109 }
2110 }
2111 if let Some(end) = header_end {
2112 if buf.len() >= end + content_length {
2113 break;
2114 }
2115 }
2116 }
2117
2118 let end = header_end.expect("header terminator");
2119 let request = String::from_utf8_lossy(&buf[..end]).to_string();
2120 let body = String::from_utf8_lossy(&buf[end..end + content_length]).to_string();
2121 let mut lines = request.lines();
2122 let request_line = lines.next().expect("request line").to_string();
2123 let path = request_line
2124 .split_whitespace()
2125 .nth(1)
2126 .expect("request path")
2127 .to_string();
2128 let response_body = handler(request_line, path, body);
2129 let response = format!(
2130 "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
2131 response_body.len(),
2132 response_body
2133 );
2134 stream
2135 .write_all(response.as_bytes())
2136 .expect("write response");
2137 });
2138
2139 (format!("http://{}", addr), handle)
2140 }
2141
2142 fn test_vector_for_texts(texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
2143 Ok(texts.iter().map(|_| vec![1.0, 0.0, 0.0]).collect())
2144 }
2145
2146 fn write_rust_file(path: &Path, function_name: &str) {
2147 fs::write(
2148 path,
2149 format!("pub fn {function_name}() -> bool {{\n true\n}}\n"),
2150 )
2151 .unwrap();
2152 }
2153
2154 fn build_test_index(project_root: &Path, files: &[PathBuf]) -> SemanticIndex {
2155 let mut embed = test_vector_for_texts;
2156 SemanticIndex::build(project_root, files, &mut embed, 8).unwrap()
2157 }
2158
2159 fn set_file_metadata(index: &mut SemanticIndex, file: &Path, mtime: SystemTime, size: u64) {
2160 index.file_mtimes.insert(file.to_path_buf(), mtime);
2161 index.file_sizes.insert(file.to_path_buf(), size);
2162 }
2163
2164 #[test]
2165 fn test_cosine_similarity_identical() {
2166 let a = vec![1.0, 0.0, 0.0];
2167 let b = vec![1.0, 0.0, 0.0];
2168 assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
2169 }
2170
2171 #[test]
2172 fn test_cosine_similarity_orthogonal() {
2173 let a = vec![1.0, 0.0, 0.0];
2174 let b = vec![0.0, 1.0, 0.0];
2175 assert!(cosine_similarity(&a, &b).abs() < 0.001);
2176 }
2177
2178 #[test]
2179 fn test_cosine_similarity_opposite() {
2180 let a = vec![1.0, 0.0, 0.0];
2181 let b = vec![-1.0, 0.0, 0.0];
2182 assert!((cosine_similarity(&a, &b) + 1.0).abs() < 0.001);
2183 }
2184
2185 #[test]
2186 fn test_serialization_roundtrip() {
2187 let mut index = SemanticIndex::new();
2188 index.entries.push(EmbeddingEntry {
2189 chunk: SemanticChunk {
2190 file: PathBuf::from("/src/main.rs"),
2191 name: "handle_request".to_string(),
2192 kind: SymbolKind::Function,
2193 start_line: 10,
2194 end_line: 25,
2195 exported: true,
2196 embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
2197 snippet: "fn handle_request() {\n // ...\n}".to_string(),
2198 },
2199 vector: vec![0.1, 0.2, 0.3, 0.4],
2200 });
2201 index.dimension = 4;
2202 index
2203 .file_mtimes
2204 .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
2205 index.file_sizes.insert(PathBuf::from("/src/main.rs"), 0);
2206 index.set_fingerprint(SemanticIndexFingerprint {
2207 backend: "fastembed".to_string(),
2208 model: "all-MiniLM-L6-v2".to_string(),
2209 base_url: FALLBACK_BACKEND.to_string(),
2210 dimension: 4,
2211 });
2212
2213 let bytes = index.to_bytes();
2214 let restored = SemanticIndex::from_bytes(&bytes).unwrap();
2215
2216 assert_eq!(restored.entries.len(), 1);
2217 assert_eq!(restored.entries[0].chunk.name, "handle_request");
2218 assert_eq!(restored.entries[0].vector, vec![0.1, 0.2, 0.3, 0.4]);
2219 assert_eq!(restored.dimension, 4);
2220 assert_eq!(restored.backend_label(), Some("fastembed"));
2221 assert_eq!(restored.model_label(), Some("all-MiniLM-L6-v2"));
2222 }
2223
2224 #[test]
2225 fn test_search_top_k() {
2226 let mut index = SemanticIndex::new();
2227 index.dimension = 3;
2228
2229 for (i, name) in ["auth", "database", "handler"].iter().enumerate() {
2231 let mut vec = vec![0.0f32; 3];
2232 vec[i] = 1.0; index.entries.push(EmbeddingEntry {
2234 chunk: SemanticChunk {
2235 file: PathBuf::from("/src/lib.rs"),
2236 name: name.to_string(),
2237 kind: SymbolKind::Function,
2238 start_line: (i * 10 + 1) as u32,
2239 end_line: (i * 10 + 5) as u32,
2240 exported: true,
2241 embed_text: format!("kind:function name:{}", name),
2242 snippet: format!("fn {}() {{}}", name),
2243 },
2244 vector: vec,
2245 });
2246 }
2247
2248 let query = vec![0.9, 0.1, 0.0];
2250 let results = index.search(&query, 2);
2251
2252 assert_eq!(results.len(), 2);
2253 assert_eq!(results[0].name, "auth"); assert!(results[0].score > results[1].score);
2255 }
2256
2257 #[test]
2258 fn test_empty_index_search() {
2259 let index = SemanticIndex::new();
2260 let results = index.search(&[0.1, 0.2, 0.3], 10);
2261 assert!(results.is_empty());
2262 }
2263
2264 #[test]
2265 fn single_line_symbol_builds_non_empty_snippet() {
2266 let symbol = Symbol {
2267 name: "answer".to_string(),
2268 kind: SymbolKind::Variable,
2269 range: crate::symbols::Range {
2270 start_line: 0,
2271 start_col: 0,
2272 end_line: 0,
2273 end_col: 24,
2274 },
2275 signature: Some("const answer = 42".to_string()),
2276 scope_chain: Vec::new(),
2277 exported: true,
2278 parent: None,
2279 };
2280 let source = "export const answer = 42;\n";
2281
2282 let snippet = build_snippet(&symbol, source);
2283
2284 assert_eq!(snippet, "export const answer = 42;");
2285 }
2286
2287 #[test]
2288 fn optimized_file_chunk_collection_matches_file_parser_path() {
2289 let project_root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
2290 let file = project_root.join("src/semantic_index.rs");
2291 let source = std::fs::read_to_string(&file).unwrap();
2292
2293 let mut legacy_parser = FileParser::new();
2294 let legacy_symbols = legacy_parser.extract_symbols(&file).unwrap();
2295 let legacy_chunks = symbols_to_chunks(&file, &legacy_symbols, &source, &project_root);
2296
2297 let mut parsers = HashMap::new();
2298 let optimized_chunks = collect_file_chunks(&project_root, &file, &mut parsers).unwrap();
2299
2300 assert_eq!(
2301 chunk_fingerprint(&optimized_chunks),
2302 chunk_fingerprint(&legacy_chunks)
2303 );
2304 }
2305
2306 fn chunk_fingerprint(
2307 chunks: &[SemanticChunk],
2308 ) -> Vec<(String, SymbolKind, u32, u32, bool, String, String)> {
2309 chunks
2310 .iter()
2311 .map(|chunk| {
2312 (
2313 chunk.name.clone(),
2314 chunk.kind.clone(),
2315 chunk.start_line,
2316 chunk.end_line,
2317 chunk.exported,
2318 chunk.embed_text.clone(),
2319 chunk.snippet.clone(),
2320 )
2321 })
2322 .collect()
2323 }
2324
2325 #[test]
2326 fn rejects_oversized_dimension_during_deserialization() {
2327 let mut bytes = Vec::new();
2328 bytes.push(1u8);
2329 bytes.extend_from_slice(&((MAX_DIMENSION as u32) + 1).to_le_bytes());
2330 bytes.extend_from_slice(&0u32.to_le_bytes());
2331 bytes.extend_from_slice(&0u32.to_le_bytes());
2332
2333 assert!(SemanticIndex::from_bytes(&bytes).is_err());
2334 }
2335
2336 #[test]
2337 fn rejects_oversized_entry_count_during_deserialization() {
2338 let mut bytes = Vec::new();
2339 bytes.push(1u8);
2340 bytes.extend_from_slice(&(DEFAULT_DIMENSION as u32).to_le_bytes());
2341 bytes.extend_from_slice(&((MAX_ENTRIES as u32) + 1).to_le_bytes());
2342 bytes.extend_from_slice(&0u32.to_le_bytes());
2343
2344 assert!(SemanticIndex::from_bytes(&bytes).is_err());
2345 }
2346
2347 #[test]
2348 fn invalidate_file_removes_entries_and_mtime() {
2349 let target = PathBuf::from("/src/main.rs");
2350 let mut index = SemanticIndex::new();
2351 index.entries.push(EmbeddingEntry {
2352 chunk: SemanticChunk {
2353 file: target.clone(),
2354 name: "main".to_string(),
2355 kind: SymbolKind::Function,
2356 start_line: 0,
2357 end_line: 1,
2358 exported: false,
2359 embed_text: "main".to_string(),
2360 snippet: "fn main() {}".to_string(),
2361 },
2362 vector: vec![1.0; DEFAULT_DIMENSION],
2363 });
2364 index
2365 .file_mtimes
2366 .insert(target.clone(), SystemTime::UNIX_EPOCH);
2367 index.file_sizes.insert(target.clone(), 0);
2368
2369 index.invalidate_file(&target);
2370
2371 assert!(index.entries.is_empty());
2372 assert!(!index.file_mtimes.contains_key(&target));
2373 assert!(!index.file_sizes.contains_key(&target));
2374 }
2375
2376 #[test]
2377 fn refresh_transient_error_preserves_existing_entry_and_mtime() {
2378 let temp = tempfile::tempdir().unwrap();
2379 let project_root = temp.path();
2380 let file = project_root.join("src/lib.rs");
2381 fs::create_dir_all(file.parent().unwrap()).unwrap();
2382 write_rust_file(&file, "kept_symbol");
2383
2384 let mut index = build_test_index(project_root, std::slice::from_ref(&file));
2385 let original_entry_count = index.entries.len();
2386 let original_mtime = *index.file_mtimes.get(&file).unwrap();
2387 let original_size = *index.file_sizes.get(&file).unwrap();
2388
2389 let stale_mtime = SystemTime::UNIX_EPOCH;
2390 set_file_metadata(&mut index, &file, stale_mtime, original_size + 1);
2391 fs::remove_file(&file).unwrap();
2392
2393 let mut embed = test_vector_for_texts;
2394 let mut progress = |_done: usize, _total: usize| {};
2395 let summary = index
2396 .refresh_stale_files(
2397 project_root,
2398 std::slice::from_ref(&file),
2399 &mut embed,
2400 8,
2401 &mut progress,
2402 )
2403 .unwrap();
2404
2405 assert_eq!(summary.changed, 0);
2406 assert_eq!(summary.added, 0);
2407 assert_eq!(summary.deleted, 0);
2408 assert_eq!(index.entries.len(), original_entry_count);
2409 assert_eq!(index.entries[0].chunk.name, "kept_symbol");
2410 assert_eq!(index.file_mtimes.get(&file), Some(&stale_mtime));
2411 assert_ne!(index.file_mtimes.get(&file), Some(&original_mtime));
2412 assert_eq!(index.file_sizes.get(&file), Some(&(original_size + 1)));
2413 }
2414
2415 #[test]
2416 fn refresh_never_indexed_file_error_does_not_record_mtime() {
2417 let temp = tempfile::tempdir().unwrap();
2418 let project_root = temp.path();
2419 let missing = project_root.join("src/missing.rs");
2420 fs::create_dir_all(missing.parent().unwrap()).unwrap();
2421
2422 let mut index = SemanticIndex::new();
2423 let mut embed = test_vector_for_texts;
2424 let mut progress = |_done: usize, _total: usize| {};
2425 let summary = index
2426 .refresh_stale_files(
2427 project_root,
2428 std::slice::from_ref(&missing),
2429 &mut embed,
2430 8,
2431 &mut progress,
2432 )
2433 .unwrap();
2434
2435 assert_eq!(summary.added, 0);
2436 assert_eq!(summary.changed, 0);
2437 assert_eq!(summary.deleted, 0);
2438 assert!(!index.file_mtimes.contains_key(&missing));
2439 assert!(!index.file_sizes.contains_key(&missing));
2440 assert!(index.entries.is_empty());
2441 }
2442
2443 #[test]
2444 fn refresh_reports_added_for_new_files() {
2445 let temp = tempfile::tempdir().unwrap();
2446 let project_root = temp.path();
2447 let existing = project_root.join("src/lib.rs");
2448 let added = project_root.join("src/new.rs");
2449 fs::create_dir_all(existing.parent().unwrap()).unwrap();
2450 write_rust_file(&existing, "existing_symbol");
2451 write_rust_file(&added, "added_symbol");
2452
2453 let mut index = build_test_index(project_root, std::slice::from_ref(&existing));
2454 let mut embed = test_vector_for_texts;
2455 let mut progress = |_done: usize, _total: usize| {};
2456 let summary = index
2457 .refresh_stale_files(
2458 project_root,
2459 &[existing.clone(), added.clone()],
2460 &mut embed,
2461 8,
2462 &mut progress,
2463 )
2464 .unwrap();
2465
2466 assert_eq!(summary.added, 1);
2467 assert_eq!(summary.changed, 0);
2468 assert_eq!(summary.deleted, 0);
2469 assert_eq!(summary.total_processed, 2);
2470 assert!(index.file_mtimes.contains_key(&added));
2471 assert!(index.entries.iter().any(|entry| entry.chunk.file == added));
2472 }
2473
2474 #[test]
2475 fn refresh_reports_deleted_for_removed_files() {
2476 let temp = tempfile::tempdir().unwrap();
2477 let project_root = temp.path();
2478 let deleted = project_root.join("src/deleted.rs");
2479 fs::create_dir_all(deleted.parent().unwrap()).unwrap();
2480 write_rust_file(&deleted, "deleted_symbol");
2481
2482 let mut index = build_test_index(project_root, std::slice::from_ref(&deleted));
2483 fs::remove_file(&deleted).unwrap();
2484
2485 let mut embed = test_vector_for_texts;
2486 let mut progress = |_done: usize, _total: usize| {};
2487 let summary = index
2488 .refresh_stale_files(project_root, &[], &mut embed, 8, &mut progress)
2489 .unwrap();
2490
2491 assert_eq!(summary.deleted, 1);
2492 assert_eq!(summary.changed, 0);
2493 assert_eq!(summary.added, 0);
2494 assert_eq!(summary.total_processed, 1);
2495 assert!(!index.file_mtimes.contains_key(&deleted));
2496 assert!(index.entries.is_empty());
2497 }
2498
2499 #[test]
2500 fn refresh_reports_changed_for_modified_files() {
2501 let temp = tempfile::tempdir().unwrap();
2502 let project_root = temp.path();
2503 let file = project_root.join("src/lib.rs");
2504 fs::create_dir_all(file.parent().unwrap()).unwrap();
2505 write_rust_file(&file, "old_symbol");
2506
2507 let mut index = build_test_index(project_root, std::slice::from_ref(&file));
2508 set_file_metadata(&mut index, &file, SystemTime::UNIX_EPOCH, 0);
2509 write_rust_file(&file, "new_symbol");
2510
2511 let mut embed = test_vector_for_texts;
2512 let mut progress = |_done: usize, _total: usize| {};
2513 let summary = index
2514 .refresh_stale_files(
2515 project_root,
2516 std::slice::from_ref(&file),
2517 &mut embed,
2518 8,
2519 &mut progress,
2520 )
2521 .unwrap();
2522
2523 assert_eq!(summary.changed, 1);
2524 assert_eq!(summary.added, 0);
2525 assert_eq!(summary.deleted, 0);
2526 assert_eq!(summary.total_processed, 1);
2527 assert!(index
2528 .entries
2529 .iter()
2530 .any(|entry| entry.chunk.name == "new_symbol"));
2531 assert!(!index
2532 .entries
2533 .iter()
2534 .any(|entry| entry.chunk.name == "old_symbol"));
2535 }
2536
2537 #[test]
2538 fn refresh_all_clean_reports_zero_counts_and_no_embedding_work() {
2539 let temp = tempfile::tempdir().unwrap();
2540 let project_root = temp.path();
2541 let file = project_root.join("src/lib.rs");
2542 fs::create_dir_all(file.parent().unwrap()).unwrap();
2543 write_rust_file(&file, "clean_symbol");
2544
2545 let mut index = build_test_index(project_root, std::slice::from_ref(&file));
2546 let original_entries = index.entries.len();
2547 let mut embed_called = false;
2548 let mut embed = |texts: Vec<String>| {
2549 embed_called = true;
2550 test_vector_for_texts(texts)
2551 };
2552 let mut progress = |_done: usize, _total: usize| {};
2553 let summary = index
2554 .refresh_stale_files(
2555 project_root,
2556 std::slice::from_ref(&file),
2557 &mut embed,
2558 8,
2559 &mut progress,
2560 )
2561 .unwrap();
2562
2563 assert!(summary.is_noop());
2564 assert_eq!(summary.total_processed, 1);
2565 assert!(!embed_called);
2566 assert_eq!(index.entries.len(), original_entries);
2567 }
2568
2569 #[test]
2570 fn detects_missing_onnx_runtime_from_dynamic_load_error() {
2571 let message = "Failed to load ONNX Runtime shared library libonnxruntime.dylib via dlopen: no such file";
2572
2573 assert!(is_onnx_runtime_unavailable(message));
2574 }
2575
2576 #[test]
2577 fn formats_missing_onnx_runtime_with_install_hint() {
2578 let message = format_embedding_init_error(
2579 "Failed to load ONNX Runtime shared library libonnxruntime.so via dlopen: no such file",
2580 );
2581
2582 assert!(message.starts_with("ONNX Runtime not found. Install via:"));
2583 assert!(message.contains("Original error:"));
2584 }
2585
2586 #[test]
2587 fn openai_compatible_backend_embeds_with_mock_server() {
2588 let (base_url, handle) = start_mock_http_server(|request_line, path, _body| {
2589 assert!(request_line.starts_with("POST "));
2590 assert_eq!(path, "/v1/embeddings");
2591 "{\"data\":[{\"embedding\":[0.1,0.2,0.3],\"index\":0},{\"embedding\":[0.4,0.5,0.6],\"index\":1}]}".to_string()
2592 });
2593
2594 let config = SemanticBackendConfig {
2595 backend: SemanticBackend::OpenAiCompatible,
2596 model: "test-embedding".to_string(),
2597 base_url: Some(base_url),
2598 api_key_env: None,
2599 timeout_ms: 5_000,
2600 max_batch_size: 64,
2601 };
2602
2603 let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();
2604 let vectors = model
2605 .embed(vec!["hello".to_string(), "world".to_string()])
2606 .unwrap();
2607
2608 assert_eq!(vectors, vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]]);
2609 handle.join().unwrap();
2610 }
2611
2612 #[test]
2622 fn openai_compatible_request_has_single_content_type_header() {
2623 use std::sync::{Arc, Mutex};
2624 let captured: Arc<Mutex<Vec<u8>>> = Arc::new(Mutex::new(Vec::new()));
2625 let captured_for_thread = Arc::clone(&captured);
2626
2627 let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
2628 let addr = listener.local_addr().expect("local addr");
2629 let handle = thread::spawn(move || {
2630 let (mut stream, _) = listener.accept().expect("accept");
2631 let mut buf = Vec::new();
2632 let mut chunk = [0u8; 4096];
2633 let mut header_end = None;
2634 let mut content_length = 0usize;
2635 loop {
2636 let n = stream.read(&mut chunk).expect("read");
2637 if n == 0 {
2638 break;
2639 }
2640 buf.extend_from_slice(&chunk[..n]);
2641 if header_end.is_none() {
2642 if let Some(pos) = buf.windows(4).position(|window| window == b"\r\n\r\n") {
2643 header_end = Some(pos + 4);
2644 for line in String::from_utf8_lossy(&buf[..pos + 4]).lines() {
2645 if let Some(value) = line.strip_prefix("Content-Length:") {
2646 content_length = value.trim().parse::<usize>().unwrap_or(0);
2647 }
2648 }
2649 }
2650 }
2651 if let Some(end) = header_end {
2652 if buf.len() >= end + content_length {
2653 break;
2654 }
2655 }
2656 }
2657 *captured_for_thread.lock().unwrap() = buf;
2658 let body = "{\"data\":[{\"embedding\":[0.1,0.2,0.3],\"index\":0}]}";
2659 let response = format!(
2660 "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
2661 body.len(),
2662 body
2663 );
2664 let _ = stream.write_all(response.as_bytes());
2665 });
2666
2667 let config = SemanticBackendConfig {
2668 backend: SemanticBackend::OpenAiCompatible,
2669 model: "text-embedding-3-small".to_string(),
2670 base_url: Some(format!("http://{}", addr)),
2671 api_key_env: None,
2672 timeout_ms: 5_000,
2673 max_batch_size: 64,
2674 };
2675 let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();
2676 let _ = model.embed(vec!["probe".to_string()]).unwrap();
2677 handle.join().unwrap();
2678
2679 let bytes = captured.lock().unwrap().clone();
2680 let request = String::from_utf8_lossy(&bytes);
2681
2682 let content_type_lines = request
2685 .lines()
2686 .filter(|line| {
2687 let lower = line.to_ascii_lowercase();
2688 lower.starts_with("content-type:")
2689 })
2690 .count();
2691 assert_eq!(
2692 content_type_lines, 1,
2693 "expected exactly one Content-Type header but found {content_type_lines}; full request:\n{request}",
2694 );
2695
2696 assert!(
2699 request.contains(r#""model":"text-embedding-3-small""#),
2700 "request body should contain model field; full request:\n{request}",
2701 );
2702 }
2703
2704 #[test]
2705 fn ollama_backend_embeds_with_mock_server() {
2706 let (base_url, handle) = start_mock_http_server(|request_line, path, _body| {
2707 assert!(request_line.starts_with("POST "));
2708 assert_eq!(path, "/api/embed");
2709 "{\"embeddings\":[[0.7,0.8,0.9],[1.0,1.1,1.2]]}".to_string()
2710 });
2711
2712 let config = SemanticBackendConfig {
2713 backend: SemanticBackend::Ollama,
2714 model: "embeddinggemma".to_string(),
2715 base_url: Some(base_url),
2716 api_key_env: None,
2717 timeout_ms: 5_000,
2718 max_batch_size: 64,
2719 };
2720
2721 let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();
2722 let vectors = model
2723 .embed(vec!["hello".to_string(), "world".to_string()])
2724 .unwrap();
2725
2726 assert_eq!(vectors, vec![vec![0.7, 0.8, 0.9], vec![1.0, 1.1, 1.2]]);
2727 handle.join().unwrap();
2728 }
2729
2730 #[test]
2731 fn read_from_disk_rejects_fingerprint_mismatch() {
2732 let storage = tempfile::tempdir().unwrap();
2733 let project_key = "proj";
2734
2735 let mut index = SemanticIndex::new();
2736 index.entries.push(EmbeddingEntry {
2737 chunk: SemanticChunk {
2738 file: PathBuf::from("/src/main.rs"),
2739 name: "handle_request".to_string(),
2740 kind: SymbolKind::Function,
2741 start_line: 10,
2742 end_line: 25,
2743 exported: true,
2744 embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
2745 snippet: "fn handle_request() {}".to_string(),
2746 },
2747 vector: vec![0.1, 0.2, 0.3],
2748 });
2749 index.dimension = 3;
2750 index
2751 .file_mtimes
2752 .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
2753 index.file_sizes.insert(PathBuf::from("/src/main.rs"), 0);
2754 index.set_fingerprint(SemanticIndexFingerprint {
2755 backend: "openai_compatible".to_string(),
2756 model: "test-embedding".to_string(),
2757 base_url: "http://127.0.0.1:1234/v1".to_string(),
2758 dimension: 3,
2759 });
2760 index.write_to_disk(storage.path(), project_key);
2761
2762 let matching = index.fingerprint().unwrap().as_string();
2763 assert!(
2764 SemanticIndex::read_from_disk(storage.path(), project_key, Some(&matching)).is_some()
2765 );
2766
2767 let mismatched = SemanticIndexFingerprint {
2768 backend: "ollama".to_string(),
2769 model: "embeddinggemma".to_string(),
2770 base_url: "http://127.0.0.1:11434".to_string(),
2771 dimension: 3,
2772 }
2773 .as_string();
2774 assert!(
2775 SemanticIndex::read_from_disk(storage.path(), project_key, Some(&mismatched)).is_none()
2776 );
2777 }
2778
2779 #[test]
2780 fn read_from_disk_rejects_v3_cache_for_snippet_rebuild() {
2781 let storage = tempfile::tempdir().unwrap();
2782 let project_key = "proj-v3";
2783 let dir = storage.path().join("semantic").join(project_key);
2784 fs::create_dir_all(&dir).unwrap();
2785
2786 let mut index = SemanticIndex::new();
2787 index.entries.push(EmbeddingEntry {
2788 chunk: SemanticChunk {
2789 file: PathBuf::from("/src/main.rs"),
2790 name: "handle_request".to_string(),
2791 kind: SymbolKind::Function,
2792 start_line: 0,
2793 end_line: 0,
2794 exported: true,
2795 embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
2796 snippet: "fn handle_request() {}".to_string(),
2797 },
2798 vector: vec![0.1, 0.2, 0.3],
2799 });
2800 index.dimension = 3;
2801 index
2802 .file_mtimes
2803 .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
2804 index.file_sizes.insert(PathBuf::from("/src/main.rs"), 0);
2805 let fingerprint = SemanticIndexFingerprint {
2806 backend: "fastembed".to_string(),
2807 model: "test".to_string(),
2808 base_url: FALLBACK_BACKEND.to_string(),
2809 dimension: 3,
2810 };
2811 index.set_fingerprint(fingerprint.clone());
2812
2813 let mut bytes = index.to_bytes();
2814 bytes[0] = SEMANTIC_INDEX_VERSION_V3;
2815 fs::write(dir.join("semantic.bin"), bytes).unwrap();
2816
2817 assert!(SemanticIndex::read_from_disk(
2818 storage.path(),
2819 project_key,
2820 Some(&fingerprint.as_string())
2821 )
2822 .is_none());
2823 assert!(!dir.join("semantic.bin").exists());
2824 }
2825
2826 fn make_symbol(kind: SymbolKind, name: &str, start: u32, end: u32) -> crate::symbols::Symbol {
2827 crate::symbols::Symbol {
2828 name: name.to_string(),
2829 kind,
2830 range: crate::symbols::Range {
2831 start_line: start,
2832 start_col: 0,
2833 end_line: end,
2834 end_col: 0,
2835 },
2836 signature: None,
2837 scope_chain: Vec::new(),
2838 exported: false,
2839 parent: None,
2840 }
2841 }
2842
2843 #[test]
2848 fn symbols_to_chunks_skips_heading_symbols() {
2849 let project_root = PathBuf::from("/proj");
2850 let file = project_root.join("README.md");
2851 let source = "# Title\n\nbody text\n\n## Section\n\nmore text\n";
2852
2853 let symbols = vec![
2854 make_symbol(SymbolKind::Heading, "Title", 0, 2),
2855 make_symbol(SymbolKind::Heading, "Section", 4, 6),
2856 ];
2857
2858 let chunks = symbols_to_chunks(&file, &symbols, source, &project_root);
2859 assert!(
2860 chunks.is_empty(),
2861 "Heading symbols must be filtered out before embedding; got {} chunk(s)",
2862 chunks.len()
2863 );
2864 }
2865
2866 #[test]
2870 fn symbols_to_chunks_keeps_code_symbols_alongside_skipped_headings() {
2871 let project_root = PathBuf::from("/proj");
2872 let file = project_root.join("src/lib.rs");
2873 let source = "pub fn handle_request() -> bool {\n true\n}\n";
2874
2875 let symbols = vec![
2876 make_symbol(SymbolKind::Heading, "doc heading", 0, 1),
2878 make_symbol(SymbolKind::Function, "handle_request", 0, 2),
2879 make_symbol(SymbolKind::Struct, "AuthService", 4, 6),
2880 ];
2881
2882 let chunks = symbols_to_chunks(&file, &symbols, source, &project_root);
2883 assert_eq!(
2884 chunks.len(),
2885 2,
2886 "Expected 2 code chunks (Function + Struct), got {}",
2887 chunks.len()
2888 );
2889 let names: Vec<&str> = chunks.iter().map(|c| c.name.as_str()).collect();
2890 assert!(names.contains(&"handle_request"));
2891 assert!(names.contains(&"AuthService"));
2892 assert!(
2893 !names.contains(&"doc heading"),
2894 "Heading symbol leaked into chunks: {names:?}"
2895 );
2896 }
2897
2898 #[test]
2899 fn validate_ssrf_allows_loopback_hostnames() {
2900 for host in &[
2903 "http://localhost",
2904 "http://localhost:8080",
2905 "http://localhost:11434", "http://localhost.localdomain",
2907 "http://foo.localhost",
2908 ] {
2909 assert!(
2910 validate_base_url_no_ssrf(host).is_ok(),
2911 "Expected {host} to be allowed (loopback), got: {:?}",
2912 validate_base_url_no_ssrf(host)
2913 );
2914 }
2915 }
2916
2917 #[test]
2918 fn validate_ssrf_allows_loopback_ips() {
2919 for url in &[
2922 "http://127.0.0.1",
2923 "http://127.0.0.1:11434", "http://127.0.0.1:8080",
2925 "http://127.1.2.3",
2926 ] {
2927 let result = validate_base_url_no_ssrf(url);
2928 assert!(
2929 result.is_ok(),
2930 "Expected {url} to be allowed (loopback), got: {:?}",
2931 result
2932 );
2933 }
2934 }
2935
2936 #[test]
2937 fn validate_ssrf_rejects_private_non_loopback_ips() {
2938 for url in &[
2943 "http://192.168.1.1",
2944 "http://10.0.0.1",
2945 "http://172.16.0.1",
2946 "http://169.254.169.254",
2947 "http://100.64.0.1",
2948 ] {
2949 let result = validate_base_url_no_ssrf(url);
2950 assert!(
2951 result.is_err(),
2952 "Expected {url} to be rejected (non-loopback private), got: {:?}",
2953 result
2954 );
2955 }
2956 }
2957
2958 #[test]
2959 fn validate_ssrf_rejects_mdns_local_hostnames() {
2960 for host in &[
2963 "http://printer.local",
2964 "http://nas.local:8080",
2965 "http://homelab.local",
2966 ] {
2967 let result = validate_base_url_no_ssrf(host);
2968 assert!(
2969 result.is_err(),
2970 "Expected {host} to be rejected (mDNS), got: {:?}",
2971 result
2972 );
2973 }
2974 }
2975
2976 #[test]
2977 fn normalize_base_url_allows_localhost_for_tests() {
2978 assert!(normalize_base_url("http://127.0.0.1:9999").is_ok());
2981 assert!(normalize_base_url("http://localhost:8080").is_ok());
2982 }
2983
2984 #[test]
2991 fn ort_mismatch_message_recommends_auto_fix_first() {
2992 let msg =
2993 format_ort_version_mismatch("1.9.0", "/usr/lib/x86_64-linux-gnu/libonnxruntime.so");
2994
2995 assert!(
2997 msg.contains("v1.9.0"),
2998 "should report detected version: {msg}"
2999 );
3000 assert!(
3001 msg.contains("/usr/lib/x86_64-linux-gnu/libonnxruntime.so"),
3002 "should report system path: {msg}"
3003 );
3004 assert!(msg.contains("v1.20+"), "should state requirement: {msg}");
3005
3006 let auto_fix_pos = msg
3008 .find("Auto-fix")
3009 .expect("Auto-fix solution missing — users won't discover --fix");
3010 let remove_pos = msg
3011 .find("Remove the old library")
3012 .expect("system-rm solution missing");
3013 assert!(
3014 auto_fix_pos < remove_pos,
3015 "Auto-fix must come before manual rm — see PR comment thread"
3016 );
3017
3018 assert!(
3020 msg.contains("npx @cortexkit/aft doctor --fix"),
3021 "auto-fix command must be present and copy-pasteable: {msg}"
3022 );
3023 }
3024
3025 #[test]
3029 fn ort_mismatch_message_handles_macos_dylib_path() {
3030 let msg = format_ort_version_mismatch("1.9.0", "/opt/homebrew/lib/libonnxruntime.dylib");
3031 assert!(msg.contains("v1.9.0"));
3032 assert!(msg.contains("/opt/homebrew/lib/libonnxruntime.dylib"));
3033 assert!(
3037 msg.contains("'/opt/homebrew/lib/libonnxruntime.dylib'"),
3038 "system path should be quoted in the auto-fix sentence: {msg}"
3039 );
3040 }
3041}