use crate::config::{SemanticBackend, SemanticBackendConfig};
use crate::parser::{detect_language, extract_symbols_from_tree, grammar_for};
use crate::symbols::{Symbol, SymbolKind};
use crate::{slog_info, slog_warn};

use fastembed::{EmbeddingModel as FastembedEmbeddingModel, InitOptions, TextEmbedding};
use rayon::prelude::*;
use reqwest::blocking::Client;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::env;
use std::fmt::Display;
use std::fs;
use std::path::{Path, PathBuf};
use std::time::{Duration, SystemTime};
use tree_sitter::Parser;
use url::Url;

const DEFAULT_DIMENSION: usize = 384;
const MAX_ENTRIES: usize = 1_000_000;
const MAX_DIMENSION: usize = 1024;
const F32_BYTES: usize = std::mem::size_of::<f32>();
const HEADER_BYTES_V1: usize = 9;
const HEADER_BYTES_V2: usize = 13;
const ONNX_RUNTIME_INSTALL_HINT: &str =
    "ONNX Runtime not found. Install via: brew install onnxruntime (macOS) or apt install libonnxruntime (Linux).";

const SEMANTIC_INDEX_VERSION_V1: u8 = 1;
const SEMANTIC_INDEX_VERSION_V2: u8 = 2;
const SEMANTIC_INDEX_VERSION_V3: u8 = 3;
const SEMANTIC_INDEX_VERSION_V4: u8 = 4;
const SEMANTIC_INDEX_VERSION_V5: u8 = 5;
const DEFAULT_OPENAI_EMBEDDING_PATH: &str = "/embeddings";
const DEFAULT_OLLAMA_EMBEDDING_PATH: &str = "/api/embed";
const DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS: u64 = 25_000;
const DEFAULT_MAX_BATCH_SIZE: usize = 64;
const FALLBACK_BACKEND: &str = "none";
const EMBEDDING_REQUEST_MAX_ATTEMPTS: usize = 3;
const EMBEDDING_REQUEST_BACKOFF_MS: [u64; 2] = [500, 1_000];

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticIndexFingerprint {
    pub backend: String,
    pub model: String,
    #[serde(default)]
    pub base_url: String,
    pub dimension: usize,
}

impl SemanticIndexFingerprint {
    fn from_config(config: &SemanticBackendConfig, dimension: usize) -> Self {
        let base_url = config
            .base_url
            .as_ref()
            .and_then(|u| normalize_base_url(u).ok())
            .unwrap_or_else(|| FALLBACK_BACKEND.to_string());
        Self {
            backend: config.backend.as_str().to_string(),
            model: config.model.clone(),
            base_url,
            dimension,
        }
    }

    pub fn as_string(&self) -> String {
        serde_json::to_string(self).unwrap_or_else(|_| String::new())
    }

    fn matches_expected(&self, expected: &str) -> bool {
        let encoded = self.as_string();
        !encoded.is_empty() && encoded == expected
    }
}

enum SemanticEmbeddingEngine {
    Fastembed(TextEmbedding),
    OpenAiCompatible {
        client: Client,
        model: String,
        base_url: String,
        api_key: Option<String>,
    },
    Ollama {
        client: Client,
        model: String,
        base_url: String,
    },
}

pub struct SemanticEmbeddingModel {
    backend: SemanticBackend,
    model: String,
    base_url: Option<String>,
    timeout_ms: u64,
    max_batch_size: usize,
    dimension: Option<usize>,
    engine: SemanticEmbeddingEngine,
}

pub type EmbeddingModel = SemanticEmbeddingModel;

fn validate_embedding_batch(
    vectors: &[Vec<f32>],
    expected_count: usize,
    context: &str,
) -> Result<(), String> {
    if expected_count > 0 && vectors.is_empty() {
        return Err(format!(
            "{context} returned no vectors for {expected_count} inputs"
        ));
    }

    if vectors.len() != expected_count {
        return Err(format!(
            "{context} returned {} vectors for {} inputs",
            vectors.len(),
            expected_count
        ));
    }

    let Some(first_vector) = vectors.first() else {
        return Ok(());
    };
    let expected_dimension = first_vector.len();
    for (index, vector) in vectors.iter().enumerate() {
        if vector.len() != expected_dimension {
            return Err(format!(
                "{context} returned inconsistent embedding dimensions: vector 0 has length {expected_dimension}, vector {index} has length {}",
                vector.len()
            ));
        }
    }

    Ok(())
}

fn normalize_base_url(raw: &str) -> Result<String, String> {
    let parsed = Url::parse(raw).map_err(|error| format!("invalid base_url '{raw}': {error}"))?;
    let scheme = parsed.scheme();
    if scheme != "http" && scheme != "https" {
        return Err(format!(
            "unsupported URL scheme '{}' — only http:// and https:// are allowed",
            scheme
        ));
    }
    Ok(parsed.to_string().trim_end_matches('/').to_string())
}

pub fn validate_base_url_no_ssrf(raw: &str) -> Result<(), String> {
    use std::net::{IpAddr, ToSocketAddrs};

    let parsed = Url::parse(raw).map_err(|error| format!("invalid base_url '{raw}': {error}"))?;

    let host = parsed.host_str().unwrap_or("");
    if host == "localhost"
        || host == "localhost.localdomain"
        || host.ends_with(".localhost")
        || host.ends_with(".local")
    {
        return Err(format!(
            "base_url host '{host}' resolves to a private/loopback address — only public endpoints are allowed"
        ));
    }

    let port = parsed.port_or_known_default().unwrap_or(443);
    let addr_str = format!("{host}:{port}");
    let addrs: Vec<IpAddr> = addr_str
        .to_socket_addrs()
        .map(|iter| iter.map(|sa| sa.ip()).collect())
        .unwrap_or_default();
    for ip in &addrs {
        if is_private_ip(ip) {
            return Err(format!(
                "base_url '{raw}' resolves to a private/reserved IP address — only public endpoints are allowed"
            ));
        }
    }

    Ok(())
}

fn is_private_ip(ip: &std::net::IpAddr) -> bool {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
    match ip {
        IpAddr::V4(v4) => {
            let o = v4.octets();
            o[0] == 10 // 10.0.0.0/8 private
                || (o[0] == 172 && (16..=31).contains(&o[1])) // 172.16.0.0/12 private
                || (o[0] == 192 && o[1] == 168) // 192.168.0.0/16 private
                || o[0] == 127 // 127.0.0.0/8 loopback
                || (o[0] == 169 && o[1] == 254) // 169.254.0.0/16 link-local
                || (o[0] == 100 && (64..=127).contains(&o[1])) // 100.64.0.0/10 carrier-grade NAT
                || o[0] == 0 // 0.0.0.0/8 "this network"
        }
        IpAddr::V6(v6) => {
            let s = v6.segments();
            // IPv4-mapped addresses (::ffff:a.b.c.d) are classified by the embedded IPv4.
            if s[..6] == [0, 0, 0, 0, 0, 0xffff] {
                let ipv4 = Ipv4Addr::new(
                    (s[6] >> 8) as u8,
                    (s[6] & 0xff) as u8,
                    (s[7] >> 8) as u8,
                    (s[7] & 0xff) as u8,
                );
                return is_private_ip(&IpAddr::V4(ipv4));
            }
            *v6 == Ipv6Addr::LOCALHOST // ::1 loopback
                || (s[0] & 0xffc0) == 0xfe80 // fe80::/10 link-local
                || (s[0] & 0xfe00) == 0xfc00 // fc00::/7 unique local
        }
    }
}

fn build_openai_embeddings_endpoint(base_url: &str) -> String {
    if base_url.ends_with("/v1") {
        format!("{base_url}{DEFAULT_OPENAI_EMBEDDING_PATH}")
    } else {
        format!("{base_url}/v1{}", DEFAULT_OPENAI_EMBEDDING_PATH)
    }
}

fn build_ollama_embeddings_endpoint(base_url: &str) -> String {
    if base_url.ends_with("/api") {
        format!("{base_url}/embed")
    } else {
        format!("{base_url}{DEFAULT_OLLAMA_EMBEDDING_PATH}")
    }
}

fn normalize_api_key(value: Option<String>) -> Option<String> {
    value.and_then(|token| {
        let token = token.trim();
        if token.is_empty() {
            None
        } else {
            Some(token.to_string())
        }
    })
}

fn is_retryable_embedding_status(status: reqwest::StatusCode) -> bool {
    status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS
}

fn is_retryable_embedding_error(error: &reqwest::Error) -> bool {
    error.is_connect()
}

fn sleep_before_embedding_retry(attempt_index: usize) {
    if let Some(delay_ms) = EMBEDDING_REQUEST_BACKOFF_MS.get(attempt_index) {
        std::thread::sleep(Duration::from_millis(*delay_ms));
    }
}

fn send_embedding_request<F>(mut make_request: F, backend_label: &str) -> Result<String, String>
where
    F: FnMut() -> reqwest::blocking::RequestBuilder,
{
    for attempt_index in 0..EMBEDDING_REQUEST_MAX_ATTEMPTS {
        let last_attempt = attempt_index + 1 == EMBEDDING_REQUEST_MAX_ATTEMPTS;

        let response = match make_request().send() {
            Ok(response) => response,
            Err(error) => {
                if !last_attempt && is_retryable_embedding_error(&error) {
                    sleep_before_embedding_retry(attempt_index);
                    continue;
                }
                return Err(format!("{backend_label} request failed: {error}"));
            }
        };

        let status = response.status();
        let raw = match response.text() {
            Ok(raw) => raw,
            Err(error) => {
                if !last_attempt && is_retryable_embedding_error(&error) {
                    sleep_before_embedding_retry(attempt_index);
                    continue;
                }
                return Err(format!("{backend_label} response read failed: {error}"));
            }
        };

        if status.is_success() {
            return Ok(raw);
        }

        if !last_attempt && is_retryable_embedding_status(status) {
            sleep_before_embedding_retry(attempt_index);
            continue;
        }

        return Err(format!(
            "{backend_label} request failed (HTTP {}): {}",
            status, raw
        ));
    }

    unreachable!("embedding request retries exhausted without returning")
}

impl SemanticEmbeddingModel {
    pub fn from_config(config: &SemanticBackendConfig) -> Result<Self, String> {
        let timeout_ms = if config.timeout_ms == 0 {
            DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS
        } else {
            config.timeout_ms
        };

        let max_batch_size = if config.max_batch_size == 0 {
            DEFAULT_MAX_BATCH_SIZE
        } else {
            config.max_batch_size
        };

        let api_key_env = normalize_api_key(config.api_key_env.clone());
        let model = config.model.clone();

        let client = Client::builder()
            .timeout(Duration::from_millis(timeout_ms))
            .redirect(reqwest::redirect::Policy::none())
            .build()
            .map_err(|error| format!("failed to configure embedding client: {error}"))?;

        let engine = match config.backend {
            SemanticBackend::Fastembed => {
                SemanticEmbeddingEngine::Fastembed(initialize_text_embedding(&model)?)
            }
            SemanticBackend::OpenAiCompatible => {
                let raw = config.base_url.as_ref().ok_or_else(|| {
                    "base_url is required for openai_compatible backend".to_string()
                })?;
                let base_url = normalize_base_url(raw)?;

                let api_key = match api_key_env {
                    Some(var_name) => Some(env::var(&var_name).map_err(|_| {
                        format!("missing api_key_env '{var_name}' for openai_compatible backend")
                    })?),
                    None => None,
                };

                SemanticEmbeddingEngine::OpenAiCompatible {
                    client,
                    model,
                    base_url,
                    api_key,
                }
            }
            SemanticBackend::Ollama => {
                let raw = config
                    .base_url
                    .as_ref()
                    .ok_or_else(|| "base_url is required for ollama backend".to_string())?;
                let base_url = normalize_base_url(raw)?;

                SemanticEmbeddingEngine::Ollama {
                    client,
                    model,
                    base_url,
                }
            }
        };

        Ok(Self {
            backend: config.backend,
            model: config.model.clone(),
            base_url: config.base_url.clone(),
            timeout_ms,
            max_batch_size,
            dimension: None,
            engine,
        })
    }

    pub fn backend(&self) -> SemanticBackend {
        self.backend
    }

    pub fn model(&self) -> &str {
        &self.model
    }

    pub fn base_url(&self) -> Option<&str> {
        self.base_url.as_deref()
    }

    pub fn max_batch_size(&self) -> usize {
        self.max_batch_size
    }

    pub fn timeout_ms(&self) -> u64 {
        self.timeout_ms
    }

    pub fn fingerprint(
        &mut self,
        config: &SemanticBackendConfig,
    ) -> Result<SemanticIndexFingerprint, String> {
        let dimension = self.dimension()?;
        Ok(SemanticIndexFingerprint::from_config(config, dimension))
    }

    pub fn dimension(&mut self) -> Result<usize, String> {
        if let Some(dimension) = self.dimension {
            return Ok(dimension);
        }

        let dimension = match &mut self.engine {
            SemanticEmbeddingEngine::Fastembed(model) => {
                let vectors = model
                    .embed(vec!["semantic index fingerprint probe".to_string()], None)
                    .map_err(|error| format_embedding_init_error(error.to_string()))?;
                vectors
                    .first()
                    .map(|v| v.len())
                    .ok_or_else(|| "embedding backend returned no vectors".to_string())?
            }
            SemanticEmbeddingEngine::OpenAiCompatible { .. }
            | SemanticEmbeddingEngine::Ollama { .. } => {
                let vectors =
                    self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
                vectors
                    .first()
                    .map(|v| v.len())
                    .ok_or_else(|| "embedding backend returned no vectors".to_string())?
            }
        };

        self.dimension = Some(dimension);
        Ok(dimension)
    }

    pub fn embed(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
        self.embed_texts(texts)
    }

    fn embed_texts(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
        match &mut self.engine {
            SemanticEmbeddingEngine::Fastembed(model) => model
                .embed(texts, None::<usize>)
                .map_err(|error| format_embedding_init_error(error.to_string()))
                .map_err(|error| format!("failed to embed batch: {error}")),
            SemanticEmbeddingEngine::OpenAiCompatible {
                client,
                model,
                base_url,
                api_key,
            } => {
                let expected_text_count = texts.len();
                let endpoint = build_openai_embeddings_endpoint(base_url);
                let body = serde_json::json!({
                    "input": texts,
                    "model": model,
                });

                let raw = send_embedding_request(
                    || {
                        let mut request = client
                            .post(&endpoint)
                            .json(&body)
                            .header("Content-Type", "application/json");

                        if let Some(api_key) = api_key {
                            request = request.header("Authorization", format!("Bearer {api_key}"));
                        }

                        request
                    },
                    "openai compatible",
                )?;

                #[derive(Deserialize)]
                struct OpenAiResponse {
                    data: Vec<OpenAiEmbeddingResult>,
                }

                #[derive(Deserialize)]
                struct OpenAiEmbeddingResult {
                    embedding: Vec<f32>,
                    index: Option<u32>,
                }

                let parsed: OpenAiResponse = serde_json::from_str(&raw)
                    .map_err(|error| format!("invalid openai compatible response: {error}"))?;
                if parsed.data.len() != expected_text_count {
                    return Err(format!(
                        "openai compatible response returned {} embeddings for {} inputs",
                        parsed.data.len(),
                        expected_text_count
                    ));
                }

                let mut vectors = vec![Vec::new(); parsed.data.len()];
                for (i, item) in parsed.data.into_iter().enumerate() {
                    let index = item.index.unwrap_or(i as u32) as usize;
                    if index >= vectors.len() {
                        return Err(
                            "openai compatible response contains invalid vector index".to_string()
                        );
                    }
                    vectors[index] = item.embedding;
                }

                for vector in &vectors {
                    if vector.is_empty() {
                        return Err(
                            "openai compatible response contained missing vectors".to_string()
                        );
                    }
                }

                self.dimension = vectors.first().map(Vec::len);
                Ok(vectors)
            }
            SemanticEmbeddingEngine::Ollama {
                client,
                model,
                base_url,
            } => {
                let expected_text_count = texts.len();
                let endpoint = build_ollama_embeddings_endpoint(base_url);

                #[derive(Serialize)]
                struct OllamaPayload<'a> {
                    model: &'a str,
                    input: Vec<String>,
                }

                let payload = OllamaPayload {
                    model,
                    input: texts,
                };

                let raw = send_embedding_request(
                    || {
                        client
                            .post(&endpoint)
                            .json(&payload)
                            .header("Content-Type", "application/json")
                    },
                    "ollama",
                )?;

                #[derive(Deserialize)]
                struct OllamaResponse {
                    embeddings: Vec<Vec<f32>>,
                }

                let parsed: OllamaResponse = serde_json::from_str(&raw)
                    .map_err(|error| format!("invalid ollama response: {error}"))?;
                if parsed.embeddings.is_empty() {
                    return Err("ollama response returned no embeddings".to_string());
                }
                if parsed.embeddings.len() != expected_text_count {
                    return Err(format!(
                        "ollama response returned {} embeddings for {} inputs",
                        parsed.embeddings.len(),
                        expected_text_count
                    ));
                }

                let vectors = parsed.embeddings;
                for vector in &vectors {
                    if vector.is_empty() {
                        return Err("ollama response contained empty embeddings".to_string());
                    }
                }

                self.dimension = vectors.first().map(Vec::len);
                Ok(vectors)
            }
        }
    }
}

pub fn pre_validate_onnx_runtime() -> Result<(), String> {
    let dylib_path = std::env::var("ORT_DYLIB_PATH").ok();

    #[cfg(any(target_os = "linux", target_os = "macos"))]
    {
        #[cfg(target_os = "linux")]
        let default_name = "libonnxruntime.so";
        #[cfg(target_os = "macos")]
        let default_name = "libonnxruntime.dylib";

        let lib_name = dylib_path.as_deref().unwrap_or(default_name);

        unsafe {
            let c_name = std::ffi::CString::new(lib_name)
                .map_err(|e| format!("invalid library path: {}", e))?;
            let handle = libc::dlopen(c_name.as_ptr(), libc::RTLD_NOW);
            if handle.is_null() {
                let err = libc::dlerror();
                let msg = if err.is_null() {
                    "unknown dlopen error".to_string()
                } else {
                    std::ffi::CStr::from_ptr(err).to_string_lossy().into_owned()
                };
                return Err(format!(
                    "ONNX Runtime not found. dlopen('{}') failed: {}. \
                     Run `bunx @cortexkit/aft-opencode@latest doctor` to diagnose.",
                    lib_name, msg
                ));
            }

            let detected_version = detect_ort_version_from_path(lib_name);

            libc::dlclose(handle);

            if let Some(ref version) = detected_version {
                let parts: Vec<&str> = version.split('.').collect();
                if let (Some(major), Some(minor)) = (
                    parts.first().and_then(|s| s.parse::<u32>().ok()),
                    parts.get(1).and_then(|s| s.parse::<u32>().ok()),
                ) {
                    if major != 1 || minor < 20 {
                        return Err(format!(
                            "ONNX Runtime version mismatch: found v{} at '{}', but AFT requires v1.20+. \
                             Solutions:\n\
                             1. Remove the old library and restart (AFT auto-downloads the correct version):\n\
                             {}\n\
                             2. Or install ONNX Runtime 1.24: https://github.com/microsoft/onnxruntime/releases/tag/v1.24.0\n\
                             3. Run `bunx @cortexkit/aft-opencode@latest doctor` for full diagnostics.",
                            version,
                            lib_name,
                            suggest_removal_command(lib_name),
                        ));
                    }
                }
            }
        }
    }

    #[cfg(target_os = "windows")]
    {
        let _ = dylib_path;
    }

    Ok(())
}

fn detect_ort_version_from_path(lib_path: &str) -> Option<String> {
    let path = std::path::Path::new(lib_path);

    for candidate in [Some(path.to_path_buf()), std::fs::canonicalize(path).ok()]
        .into_iter()
        .flatten()
    {
        if let Some(name) = candidate.file_name().and_then(|n| n.to_str()) {
            if let Some(version) = extract_version_from_filename(name) {
                return Some(version);
            }
        }
    }

    if let Some(parent) = path.parent() {
        if let Ok(entries) = std::fs::read_dir(parent) {
            for entry in entries.flatten() {
                if let Some(name) = entry.file_name().to_str() {
                    if name.starts_with("libonnxruntime") {
                        if let Some(version) = extract_version_from_filename(name) {
                            return Some(version);
                        }
                    }
                }
            }
        }
    }

    None
}

fn extract_version_from_filename(name: &str) -> Option<String> {
    let re = regex::Regex::new(r"(\d+\.\d+\.\d+)").ok()?;
    re.find(name).map(|m| m.as_str().to_string())
}

fn suggest_removal_command(lib_path: &str) -> String {
    if lib_path.starts_with("/usr/local/lib")
        || lib_path == "libonnxruntime.so"
        || lib_path == "libonnxruntime.dylib"
    {
        #[cfg(target_os = "linux")]
        return " sudo rm /usr/local/lib/libonnxruntime* && sudo ldconfig".to_string();
        #[cfg(target_os = "macos")]
        return " sudo rm /usr/local/lib/libonnxruntime*".to_string();
        #[cfg(target_os = "windows")]
        return " Delete the ONNX Runtime DLL from your PATH".to_string();
    }
    format!(" rm '{}'", lib_path)
}

pub fn initialize_text_embedding(model: &str) -> Result<TextEmbedding, String> {
    pre_validate_onnx_runtime()?;

    let selected_model = match model {
        "all-MiniLM-L6-v2" | "all-minilm-l6-v2" => FastembedEmbeddingModel::AllMiniLML6V2,
        _ => {
            return Err(format!(
                "unsupported fastembed model '{}'. Supported: all-MiniLM-L6-v2",
                model
            ))
        }
    };

    TextEmbedding::try_new(InitOptions::new(selected_model)).map_err(format_embedding_init_error)
}

pub fn is_onnx_runtime_unavailable(message: &str) -> bool {
    if message.trim_start().starts_with("ONNX Runtime not found.") {
        return true;
    }

    let message = message.to_ascii_lowercase();
    let mentions_onnx_runtime = ["onnx runtime", "onnxruntime", "libonnxruntime"]
        .iter()
        .any(|pattern| message.contains(pattern));
    let mentions_dynamic_load_failure = [
        "shared library",
        "dynamic library",
        "failed to load",
        "could not load",
        "unable to load",
        "dlopen",
        "loadlibrary",
        "no such file",
        "not found",
    ]
    .iter()
    .any(|pattern| message.contains(pattern));

    mentions_onnx_runtime && mentions_dynamic_load_failure
}

fn format_embedding_init_error(error: impl Display) -> String {
    let message = error.to_string();

    if is_onnx_runtime_unavailable(&message) {
        return format!("{ONNX_RUNTIME_INSTALL_HINT} Original error: {message}");
    }

    format!("failed to initialize semantic embedding model: {message}")
}

#[derive(Debug, Clone)]
pub struct SemanticChunk {
    pub file: PathBuf,
    pub name: String,
    pub kind: SymbolKind,
    pub start_line: u32,
    pub end_line: u32,
    pub exported: bool,
    pub embed_text: String,
    pub snippet: String,
}

#[derive(Debug)]
struct EmbeddingEntry {
    chunk: SemanticChunk,
    vector: Vec<f32>,
}

#[derive(Debug)]
pub struct SemanticIndex {
    entries: Vec<EmbeddingEntry>,
    file_mtimes: HashMap<PathBuf, SystemTime>,
    file_sizes: HashMap<PathBuf, u64>,
    dimension: usize,
    fingerprint: Option<SemanticIndexFingerprint>,
}

#[derive(Debug, Clone, Copy)]
struct IndexedFileMetadata {
    mtime: SystemTime,
    size: u64,
}

#[derive(Debug, Default, Clone, Copy)]
pub struct RefreshSummary {
    pub changed: usize,
    pub added: usize,
    pub deleted: usize,
    pub total_processed: usize,
}

impl RefreshSummary {
    pub fn is_noop(&self) -> bool {
        self.changed == 0 && self.added == 0 && self.deleted == 0
    }
}

#[derive(Debug)]
pub struct SemanticResult {
    pub file: PathBuf,
    pub name: String,
    pub kind: SymbolKind,
    pub start_line: u32,
    pub end_line: u32,
    pub exported: bool,
    pub snippet: String,
    pub score: f32,
}

impl SemanticIndex {
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
            file_mtimes: HashMap::new(),
            file_sizes: HashMap::new(),
            dimension: DEFAULT_DIMENSION,
            fingerprint: None,
        }
    }

    pub fn entry_count(&self) -> usize {
        self.entries.len()
    }

    pub fn status_label(&self) -> &'static str {
        if self.entries.is_empty() {
            "empty"
        } else {
            "ready"
        }
    }

    fn collect_chunks(
        project_root: &Path,
        files: &[PathBuf],
    ) -> (Vec<SemanticChunk>, HashMap<PathBuf, IndexedFileMetadata>) {
        let per_file: Vec<(
            PathBuf,
            Result<(IndexedFileMetadata, Vec<SemanticChunk>), String>,
        )> = files
            .par_iter()
            .map_init(HashMap::new, |parsers, file| {
                let result = collect_file_metadata(file).and_then(|metadata| {
                    collect_file_chunks(project_root, file, parsers)
                        .map(|chunks| (metadata, chunks))
                });
                (file.clone(), result)
            })
            .collect();

        let mut chunks: Vec<SemanticChunk> = Vec::new();
        let mut file_metadata: HashMap<PathBuf, IndexedFileMetadata> = HashMap::new();

        for (file, result) in per_file {
            match result {
                Ok((metadata, file_chunks)) => {
                    file_metadata.insert(file, metadata);
                    chunks.extend(file_chunks);
                }
                Err(error) => {
                    if error == "unsupported file extension" {
                        continue;
                    }
                    slog_warn!(
                        "failed to collect semantic chunks for {}: {}",
                        file.display(),
                        error
                    );
                }
            }
        }

        (chunks, file_metadata)
    }

    fn build_from_chunks<F, P>(
        chunks: Vec<SemanticChunk>,
        file_metadata: HashMap<PathBuf, IndexedFileMetadata>,
        embed_fn: &mut F,
        max_batch_size: usize,
        mut progress: Option<&mut P>,
    ) -> Result<Self, String>
    where
        F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
        P: FnMut(usize, usize),
    {
        let total_chunks = chunks.len();

        if chunks.is_empty() {
            return Ok(Self {
                entries: Vec::new(),
                file_mtimes: file_metadata
                    .iter()
                    .map(|(path, metadata)| (path.clone(), metadata.mtime))
                    .collect(),
                file_sizes: file_metadata
                    .into_iter()
                    .map(|(path, metadata)| (path, metadata.size))
                    .collect(),
                dimension: DEFAULT_DIMENSION,
                fingerprint: None,
            });
        }

        let mut entries: Vec<EmbeddingEntry> = Vec::with_capacity(chunks.len());
        let mut expected_dimension: Option<usize> = None;
        let batch_size = max_batch_size.max(1);
        for batch_start in (0..chunks.len()).step_by(batch_size) {
            let batch_end = (batch_start + batch_size).min(chunks.len());
            let batch_texts: Vec<String> = chunks[batch_start..batch_end]
                .iter()
                .map(|c| c.embed_text.clone())
                .collect();

            let vectors = embed_fn(batch_texts)?;
            validate_embedding_batch(&vectors, batch_end - batch_start, "embedding backend")?;

            if let Some(dim) = vectors.first().map(|v| v.len()) {
                match expected_dimension {
                    None => expected_dimension = Some(dim),
                    Some(expected) if dim != expected => {
                        return Err(format!(
                            "embedding dimension changed across batches: expected {expected}, got {dim}"
                        ));
                    }
                    _ => {}
                }
            }

            for (i, vector) in vectors.into_iter().enumerate() {
                let chunk_idx = batch_start + i;
                entries.push(EmbeddingEntry {
                    chunk: chunks[chunk_idx].clone(),
                    vector,
                });
            }

            if let Some(callback) = progress.as_mut() {
                callback(entries.len(), total_chunks);
            }
        }

        let dimension = entries
            .first()
            .map(|e| e.vector.len())
            .unwrap_or(DEFAULT_DIMENSION);

        Ok(Self {
            entries,
            file_mtimes: file_metadata
                .iter()
                .map(|(path, metadata)| (path.clone(), metadata.mtime))
                .collect(),
            file_sizes: file_metadata
                .into_iter()
                .map(|(path, metadata)| (path, metadata.size))
                .collect(),
            dimension,
            fingerprint: None,
        })
    }

    pub fn build<F>(
        project_root: &Path,
        files: &[PathBuf],
        embed_fn: &mut F,
        max_batch_size: usize,
    ) -> Result<Self, String>
    where
        F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
    {
        let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
        Self::build_from_chunks(
            chunks,
            file_mtimes,
            embed_fn,
            max_batch_size,
            Option::<&mut fn(usize, usize)>::None,
        )
    }

    pub fn build_with_progress<F, P>(
        project_root: &Path,
        files: &[PathBuf],
        embed_fn: &mut F,
        max_batch_size: usize,
        progress: &mut P,
    ) -> Result<Self, String>
    where
        F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
        P: FnMut(usize, usize),
    {
        let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
        let total_chunks = chunks.len();
        progress(0, total_chunks);
        Self::build_from_chunks(
            chunks,
            file_mtimes,
            embed_fn,
            max_batch_size,
            Some(progress),
        )
    }

    pub fn refresh_stale_files<F, P>(
        &mut self,
        project_root: &Path,
        current_files: &[PathBuf],
        embed_fn: &mut F,
        max_batch_size: usize,
        progress: &mut P,
    ) -> Result<RefreshSummary, String>
    where
        F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
        P: FnMut(usize, usize),
    {
        self.backfill_missing_file_sizes();

        let current_set: HashSet<&Path> = current_files.iter().map(PathBuf::as_path).collect();
        // Total files considered: the union of currently present and already indexed files.
        let total_processed = current_set.len() + self.file_mtimes.len()
            - self
                .file_mtimes
                .keys()
                .filter(|path| current_set.contains(path.as_path()))
                .count();

        // Partition indexed files into deleted (no longer present) and changed (stale on disk).
        let mut deleted: Vec<PathBuf> = Vec::new();
        let mut changed: Vec<PathBuf> = Vec::new();
        for indexed_path in self.file_mtimes.keys() {
            if !current_set.contains(indexed_path.as_path()) {
                deleted.push(indexed_path.clone());
                continue;
            }
            if self.is_file_stale(indexed_path) {
                changed.push(indexed_path.clone());
            }
        }

        // Files present on disk but not yet indexed.
        let mut added: Vec<PathBuf> = Vec::new();
        for path in current_files {
            if !self.file_mtimes.contains_key(path) {
                added.push(path.clone());
            }
        }

        if deleted.is_empty() && changed.is_empty() && added.is_empty() {
            progress(0, 0);
            return Ok(RefreshSummary {
                total_processed,
                ..RefreshSummary::default()
            });
        }

        // Drop entries and bookkeeping for deleted files before re-embedding anything.
        if !deleted.is_empty() {
            let deleted_set: HashSet<&Path> = deleted.iter().map(PathBuf::as_path).collect();
            self.entries
                .retain(|entry| !deleted_set.contains(entry.chunk.file.as_path()));
            for path in &deleted {
                self.file_mtimes.remove(path);
                self.file_sizes.remove(path);
            }
        }

        let mut to_embed: Vec<PathBuf> = Vec::with_capacity(changed.len() + added.len());
        to_embed.extend(changed.iter().cloned());
        to_embed.extend(added.iter().cloned());

        if to_embed.is_empty() {
            progress(0, 0);
            return Ok(RefreshSummary {
                changed: 0,
                added: 0,
                deleted: deleted.len(),
                total_processed,
            });
        }

        let (chunks, fresh_metadata) = Self::collect_chunks(project_root, &to_embed);

        if chunks.is_empty() {
            progress(0, 0);
            // No chunks were produced, but files that were read successfully still need
            // their stale entries removed and their metadata refreshed.
            let successful_files: HashSet<PathBuf> = fresh_metadata.keys().cloned().collect();
            if !successful_files.is_empty() {
                self.entries
                    .retain(|entry| !successful_files.contains(&entry.chunk.file));
            }
            let changed_count = changed
                .iter()
                .filter(|path| successful_files.contains(*path))
                .count();
            let added_count = added
                .iter()
                .filter(|path| successful_files.contains(*path))
                .count();
            for (file, metadata) in fresh_metadata {
                self.file_mtimes.insert(file.clone(), metadata.mtime);
                self.file_sizes.insert(file, metadata.size);
            }
            return Ok(RefreshSummary {
                changed: changed_count,
                added: added_count,
                deleted: deleted.len(),
                total_processed,
            });
        }

        let total_chunks = chunks.len();
        progress(0, total_chunks);
        let batch_size = max_batch_size.max(1);
        let existing_dimension = if self.entries.is_empty() {
            None
        } else {
            Some(self.dimension)
        };
        let mut new_entries: Vec<EmbeddingEntry> = Vec::with_capacity(chunks.len());
        let mut observed_dimension: Option<usize> = existing_dimension;

        for batch_start in (0..chunks.len()).step_by(batch_size) {
            let batch_end = (batch_start + batch_size).min(chunks.len());
            let batch_texts: Vec<String> = chunks[batch_start..batch_end]
                .iter()
                .map(|c| c.embed_text.clone())
                .collect();

            let vectors = embed_fn(batch_texts)?;
            validate_embedding_batch(&vectors, batch_end - batch_start, "embedding backend")?;

            if let Some(dim) = vectors.first().map(|v| v.len()) {
                match observed_dimension {
                    None => observed_dimension = Some(dim),
                    Some(expected) if dim != expected => {
                        return Err(format!(
                            "embedding dimension changed during incremental refresh: \
                             cached index uses {expected}, new vectors use {dim}"
                        ));
                    }
                    _ => {}
                }
            }

            for (i, vector) in vectors.into_iter().enumerate() {
                let chunk_idx = batch_start + i;
                new_entries.push(EmbeddingEntry {
                    chunk: chunks[chunk_idx].clone(),
                    vector,
                });
            }

            progress(new_entries.len(), total_chunks);
        }

        // Replace entries for files that were successfully re-collected, then merge.
        let successful_files: HashSet<PathBuf> = fresh_metadata.keys().cloned().collect();
        if !successful_files.is_empty() {
            self.entries
                .retain(|entry| !successful_files.contains(&entry.chunk.file));
        }

        self.entries.extend(new_entries);
        for (file, metadata) in fresh_metadata {
            self.file_mtimes.insert(file.clone(), metadata.mtime);
            self.file_sizes.insert(file, metadata.size);
        }
        if let Some(dim) = observed_dimension {
            self.dimension = dim;
        }

        Ok(RefreshSummary {
            changed: changed
                .iter()
                .filter(|path| successful_files.contains(*path))
                .count(),
            added: added
                .iter()
                .filter(|path| successful_files.contains(*path))
                .count(),
            deleted: deleted.len(),
            total_processed,
        })
    }

    pub fn search(&self, query_vector: &[f32], top_k: usize) -> Vec<SemanticResult> {
        if self.entries.is_empty() || query_vector.len() != self.dimension {
            return Vec::new();
        }

        let mut scored: Vec<(f32, usize)> = self
            .entries
            .iter()
            .enumerate()
            .map(|(i, entry)| (cosine_similarity(query_vector, &entry.vector), i))
            .collect();

        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));

        scored
            .into_iter()
            .take(top_k)
            .filter(|(score, _)| *score > 0.0)
            .map(|(score, idx)| {
                let entry = &self.entries[idx];
                SemanticResult {
                    file: entry.chunk.file.clone(),
                    name: entry.chunk.name.clone(),
                    kind: entry.chunk.kind.clone(),
                    start_line: entry.chunk.start_line,
                    end_line: entry.chunk.end_line,
                    exported: entry.chunk.exported,
                    snippet: entry.chunk.snippet.clone(),
                    score,
                }
            })
            .collect()
    }

    pub fn len(&self) -> usize {
        self.entries.len()
    }

    pub fn is_file_stale(&self, file: &Path) -> bool {
        let Some(stored_mtime) = self.file_mtimes.get(file) else {
            return true;
        };
        let Some(stored_size) = self.file_sizes.get(file) else {
            return true;
        };
        match collect_file_metadata(file) {
            Ok(current) => *stored_mtime != current.mtime || *stored_size != current.size,
            Err(_) => true,
        }
    }

    fn backfill_missing_file_sizes(&mut self) {
        for path in self.file_mtimes.keys() {
            if self.file_sizes.contains_key(path) {
                continue;
            }
            if let Ok(metadata) = fs::metadata(path) {
                self.file_sizes.insert(path.clone(), metadata.len());
            }
        }
    }

    pub fn remove_file(&mut self, file: &Path) {
        self.invalidate_file(file);
    }

    pub fn invalidate_file(&mut self, file: &Path) {
        self.entries.retain(|e| e.chunk.file != file);
        self.file_mtimes.remove(file);
        self.file_sizes.remove(file);
    }

    pub fn dimension(&self) -> usize {
        self.dimension
    }

    pub fn fingerprint(&self) -> Option<&SemanticIndexFingerprint> {
        self.fingerprint.as_ref()
    }

    pub fn backend_label(&self) -> Option<&str> {
        self.fingerprint.as_ref().map(|f| f.backend.as_str())
    }

    pub fn model_label(&self) -> Option<&str> {
        self.fingerprint.as_ref().map(|f| f.model.as_str())
    }

    pub fn set_fingerprint(&mut self, fingerprint: SemanticIndexFingerprint) {
        self.fingerprint = Some(fingerprint);
    }

    pub fn write_to_disk(&self, storage_dir: &Path, project_key: &str) {
        if self.entries.is_empty() {
            slog_info!("skipping semantic index persistence (0 entries)");
            return;
        }
        let dir = storage_dir.join("semantic").join(project_key);
        if let Err(e) = fs::create_dir_all(&dir) {
            slog_warn!("failed to create semantic cache dir: {}", e);
            return;
        }
        let data_path = dir.join("semantic.bin");
        let tmp_path = dir.join("semantic.bin.tmp");
        let bytes = self.to_bytes();
        if let Err(e) = fs::write(&tmp_path, &bytes) {
            slog_warn!("failed to write semantic index: {}", e);
            let _ = fs::remove_file(&tmp_path);
            return;
        }
        if let Err(e) = fs::rename(&tmp_path, &data_path) {
            slog_warn!("failed to rename semantic index: {}", e);
            let _ = fs::remove_file(&tmp_path);
            return;
        }
        slog_info!(
            "semantic index persisted: {} entries, {:.1} KB",
            self.entries.len(),
            bytes.len() as f64 / 1024.0
        );
    }

    pub fn read_from_disk(
        storage_dir: &Path,
        project_key: &str,
        expected_fingerprint: Option<&str>,
    ) -> Option<Self> {
        let data_path = storage_dir
            .join("semantic")
            .join(project_key)
            .join("semantic.bin");
        let file_len = usize::try_from(fs::metadata(&data_path).ok()?.len()).ok()?;
        if file_len < HEADER_BYTES_V1 {
            slog_warn!(
                "corrupt semantic index (too small: {} bytes), removing",
                file_len
            );
            let _ = fs::remove_file(&data_path);
            return None;
        }

        let bytes = fs::read(&data_path).ok()?;
        let version = bytes[0];
        if version != SEMANTIC_INDEX_VERSION_V5 {
            slog_info!(
                "cached semantic index version {} does not match current {}, rebuilding",
                version,
                SEMANTIC_INDEX_VERSION_V5
            );
            let _ = fs::remove_file(&data_path);
            return None;
        }
        match Self::from_bytes(&bytes) {
            Ok(index) => {
                if index.entries.is_empty() {
                    slog_info!("cached semantic index is empty, will rebuild");
                    let _ = fs::remove_file(&data_path);
                    return None;
                }
                if let Some(expected) = expected_fingerprint {
                    let matches = index
                        .fingerprint()
                        .map(|fingerprint| fingerprint.matches_expected(expected))
                        .unwrap_or(false);
                    if !matches {
                        slog_info!("cached semantic index fingerprint mismatch, rebuilding");
                        let _ = fs::remove_file(&data_path);
                        return None;
                    }
                }
                slog_info!(
                    "loaded semantic index from disk: {} entries",
                    index.entries.len()
                );
                Some(index)
            }
            Err(e) => {
                slog_warn!("corrupt semantic index, rebuilding: {}", e);
                let _ = fs::remove_file(&data_path);
                None
            }
        }
    }

    pub fn to_bytes(&self) -> Vec<u8> {
        // Layout (v5, little-endian): version u8, dimension u32, entry count u32,
        // fingerprint length u32 + JSON bytes, mtime count u32 + per-file records
        // (path, secs u64, nanos u32, size u64), then the entries with their vectors.
        let mut buf = Vec::new();
        let fingerprint_bytes = self.fingerprint.as_ref().and_then(|fingerprint| {
            let encoded = fingerprint.as_string();
            if encoded.is_empty() {
                None
            } else {
                Some(encoded.into_bytes())
            }
        });

        let version = SEMANTIC_INDEX_VERSION_V5;
        buf.push(version);
        buf.extend_from_slice(&(self.dimension as u32).to_le_bytes());
        buf.extend_from_slice(&(self.entries.len() as u32).to_le_bytes());
        let fp_bytes_ref: &[u8] = fingerprint_bytes.as_deref().unwrap_or(&[]);
        buf.extend_from_slice(&(fp_bytes_ref.len() as u32).to_le_bytes());
        buf.extend_from_slice(fp_bytes_ref);

        buf.extend_from_slice(&(self.file_mtimes.len() as u32).to_le_bytes());
        for (path, mtime) in &self.file_mtimes {
            let path_bytes = path.to_string_lossy().as_bytes().to_vec();
            buf.extend_from_slice(&(path_bytes.len() as u32).to_le_bytes());
            buf.extend_from_slice(&path_bytes);
            let duration = mtime
                .duration_since(SystemTime::UNIX_EPOCH)
                .unwrap_or_default();
            buf.extend_from_slice(&duration.as_secs().to_le_bytes());
            buf.extend_from_slice(&duration.subsec_nanos().to_le_bytes());
            let size = self.file_sizes.get(path).copied().unwrap_or_default();
            buf.extend_from_slice(&size.to_le_bytes());
        }

        for entry in &self.entries {
            let c = &entry.chunk;

            let file_bytes = c.file.to_string_lossy().as_bytes().to_vec();
            buf.extend_from_slice(&(file_bytes.len() as u32).to_le_bytes());
            buf.extend_from_slice(&file_bytes);

            let name_bytes = c.name.as_bytes();
            buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
            buf.extend_from_slice(name_bytes);

            buf.push(symbol_kind_to_u8(&c.kind));

            buf.extend_from_slice(&c.start_line.to_le_bytes());
            buf.extend_from_slice(&c.end_line.to_le_bytes());
            buf.push(c.exported as u8);

            let snippet_bytes = c.snippet.as_bytes();
            buf.extend_from_slice(&(snippet_bytes.len() as u32).to_le_bytes());
            buf.extend_from_slice(snippet_bytes);

            let embed_bytes = c.embed_text.as_bytes();
            buf.extend_from_slice(&(embed_bytes.len() as u32).to_le_bytes());
            buf.extend_from_slice(embed_bytes);

            for &val in &entry.vector {
                buf.extend_from_slice(&val.to_le_bytes());
            }
        }

        buf
    }

    pub fn from_bytes(data: &[u8]) -> Result<Self, String> {
        let mut pos = 0;

        if data.len() < HEADER_BYTES_V1 {
            return Err("data too short".to_string());
        }

        let version = data[pos];
        pos += 1;
        if version != SEMANTIC_INDEX_VERSION_V1
            && version != SEMANTIC_INDEX_VERSION_V2
            && version != SEMANTIC_INDEX_VERSION_V3
            && version != SEMANTIC_INDEX_VERSION_V4
            && version != SEMANTIC_INDEX_VERSION_V5
        {
            return Err(format!("unsupported version: {}", version));
        }
        if (version == SEMANTIC_INDEX_VERSION_V2
            || version == SEMANTIC_INDEX_VERSION_V3
            || version == SEMANTIC_INDEX_VERSION_V4
            || version == SEMANTIC_INDEX_VERSION_V5)
            && data.len() < HEADER_BYTES_V2
        {
            return Err("data too short for semantic index v2/v3/v4/v5 header".to_string());
        }

        let dimension = read_u32(data, &mut pos)? as usize;
        let entry_count = read_u32(data, &mut pos)? as usize;
        if dimension == 0 || dimension > MAX_DIMENSION {
            return Err(format!("invalid embedding dimension: {}", dimension));
        }
        if entry_count > MAX_ENTRIES {
            return Err(format!("too many semantic index entries: {}", entry_count));
        }

        // The fingerprint field exists from v2 onward.
        let has_fingerprint_field = version == SEMANTIC_INDEX_VERSION_V2
            || version == SEMANTIC_INDEX_VERSION_V3
            || version == SEMANTIC_INDEX_VERSION_V4
            || version == SEMANTIC_INDEX_VERSION_V5;
        let fingerprint = if has_fingerprint_field {
            let fingerprint_len = read_u32(data, &mut pos)? as usize;
            if pos + fingerprint_len > data.len() {
                return Err("unexpected end of data reading fingerprint".to_string());
            }
            if fingerprint_len == 0 {
                None
            } else {
                let raw = String::from_utf8_lossy(&data[pos..pos + fingerprint_len]).to_string();
                pos += fingerprint_len;
                Some(
                    serde_json::from_str::<SemanticIndexFingerprint>(&raw)
                        .map_err(|error| format!("invalid semantic fingerprint: {error}"))?,
                )
            }
        } else {
            None
        };

        let mtime_count = read_u32(data, &mut pos)? as usize;
        if mtime_count > MAX_ENTRIES {
            return Err(format!("too many semantic file mtimes: {}", mtime_count));
        }

        // Sanity-check that the advertised vectors can fit in the remaining bytes
        // before allocating anything.
        let vector_bytes = entry_count
            .checked_mul(dimension)
            .and_then(|count| count.checked_mul(F32_BYTES))
            .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
        if vector_bytes > data.len().saturating_sub(pos) {
            return Err("semantic index vectors exceed available data".to_string());
        }

        let mut file_mtimes = HashMap::with_capacity(mtime_count);
        let mut file_sizes = HashMap::with_capacity(mtime_count);
        for _ in 0..mtime_count {
            let path = read_string(data, &mut pos)?;
            let secs = read_u64(data, &mut pos)?;
            // Mtime nanoseconds were added in v3; file sizes in v5.
            let nanos = if version == SEMANTIC_INDEX_VERSION_V3
                || version == SEMANTIC_INDEX_VERSION_V4
                || version == SEMANTIC_INDEX_VERSION_V5
            {
                read_u32(data, &mut pos)?
            } else {
                0
            };
            let size = if version == SEMANTIC_INDEX_VERSION_V5 {
                read_u64(data, &mut pos)?
            } else {
                0
            };
            if nanos >= 1_000_000_000 {
                return Err(format!(
                    "invalid semantic mtime: nanos {} >= 1_000_000_000",
                    nanos
                ));
            }
            let duration = std::time::Duration::new(secs, nanos);
            let mtime = SystemTime::UNIX_EPOCH
                .checked_add(duration)
                .ok_or_else(|| {
                    format!(
                        "invalid semantic mtime: secs={} nanos={} overflows SystemTime",
                        secs, nanos
                    )
                })?;
            let path = PathBuf::from(path);
            file_mtimes.insert(path.clone(), mtime);
            file_sizes.insert(path, size);
        }

        let mut entries = Vec::with_capacity(entry_count);
        for _ in 0..entry_count {
            let file = PathBuf::from(read_string(data, &mut pos)?);
            let name = read_string(data, &mut pos)?;

            if pos >= data.len() {
                return Err("unexpected end of data".to_string());
            }
            let kind = u8_to_symbol_kind(data[pos]);
            pos += 1;

            let start_line = read_u32(data, &mut pos)?;
            let end_line = read_u32(data, &mut pos)?;

            if pos >= data.len() {
                return Err("unexpected end of data".to_string());
            }
            let exported = data[pos] != 0;
            pos += 1;

            let snippet = read_string(data, &mut pos)?;
            let embed_text = read_string(data, &mut pos)?;

            let vec_bytes = dimension
                .checked_mul(F32_BYTES)
                .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
            if pos + vec_bytes > data.len() {
                return Err("unexpected end of data reading vector".to_string());
            }
            let mut vector = Vec::with_capacity(dimension);
            for _ in 0..dimension {
                let bytes = [data[pos], data[pos + 1], data[pos + 2], data[pos + 3]];
                vector.push(f32::from_le_bytes(bytes));
                pos += 4;
            }

            entries.push(EmbeddingEntry {
                chunk: SemanticChunk {
                    file,
                    name,
                    kind,
                    start_line,
                    end_line,
                    exported,
                    embed_text,
                    snippet,
                },
                vector,
            });
        }

        Ok(Self {
            entries,
            file_mtimes,
            file_sizes,
            dimension,
            fingerprint,
        })
    }
}

fn build_embed_text(symbol: &Symbol, source: &str, file: &Path, project_root: &Path) -> String {
    let relative = file
        .strip_prefix(project_root)
        .unwrap_or(file)
        .to_string_lossy();

    let kind_label = match &symbol.kind {
        SymbolKind::Function => "function",
        SymbolKind::Class => "class",
        SymbolKind::Method => "method",
        SymbolKind::Struct => "struct",
        SymbolKind::Interface => "interface",
        SymbolKind::Enum => "enum",
        SymbolKind::TypeAlias => "type",
        SymbolKind::Variable => "variable",
        SymbolKind::Heading => "heading",
    };

    let mut text = format!("file:{} kind:{} name:{}", relative, kind_label, symbol.name);

    if let Some(sig) = &symbol.signature {
        text.push_str(&format!(" signature:{}", sig));
    }

    let lines: Vec<&str> = source.lines().collect();
    let start = (symbol.range.start_line as usize).min(lines.len());
    let end = (symbol.range.end_line as usize + 1).min(lines.len());
    if start < end {
        let body: String = lines[start..end]
            .iter()
            .take(15)
            .copied()
            .collect::<Vec<&str>>()
            .join("\n");
        let snippet = if body.len() > 300 {
            format!("{}...", &body[..body.floor_char_boundary(300)])
        } else {
            body
        };
        text.push_str(&format!(" body:{}", snippet));
    }

    text
}

fn parser_for(
    parsers: &mut HashMap<crate::parser::LangId, Parser>,
    lang: crate::parser::LangId,
) -> Result<&mut Parser, String> {
    use std::collections::hash_map::Entry;

    match parsers.entry(lang) {
        Entry::Occupied(entry) => Ok(entry.into_mut()),
        Entry::Vacant(entry) => {
            let grammar = grammar_for(lang);
            let mut parser = Parser::new();
            parser
                .set_language(&grammar)
                .map_err(|error| error.to_string())?;
            Ok(entry.insert(parser))
        }
    }
}

fn collect_file_metadata(file: &Path) -> Result<IndexedFileMetadata, String> {
    let metadata = fs::metadata(file).map_err(|error| error.to_string())?;
    let mtime = metadata.modified().map_err(|error| error.to_string())?;
    Ok(IndexedFileMetadata {
        mtime,
        size: metadata.len(),
    })
}

fn collect_file_chunks(
    project_root: &Path,
    file: &Path,
    parsers: &mut HashMap<crate::parser::LangId, Parser>,
) -> Result<Vec<SemanticChunk>, String> {
    let lang = detect_language(file).ok_or_else(|| "unsupported file extension".to_string())?;
    let source = std::fs::read_to_string(file).map_err(|error| error.to_string())?;
    let tree = parser_for(parsers, lang)?
        .parse(&source, None)
        .ok_or_else(|| format!("tree-sitter parse returned None for {}", file.display()))?;
    let symbols =
        extract_symbols_from_tree(&source, &tree, lang).map_err(|error| error.to_string())?;

    Ok(symbols_to_chunks(file, &symbols, &source, project_root))
}

fn build_snippet(symbol: &Symbol, source: &str) -> String {
    let lines: Vec<&str> = source.lines().collect();
    let start = (symbol.range.start_line as usize).min(lines.len());
    let end = (symbol.range.end_line as usize + 1).min(lines.len());
    if start < end {
        let snippet_lines: Vec<&str> = lines[start..end].iter().take(5).copied().collect();
        let mut snippet = snippet_lines.join("\n");
        if end - start > 5 {
            snippet.push_str("\n ...");
        }
        if snippet.len() > 300 {
            snippet = format!("{}...", &snippet[..snippet.floor_char_boundary(300)]);
        }
        snippet
    } else {
        String::new()
    }
}

fn symbols_to_chunks(
    file: &Path,
    symbols: &[Symbol],
    source: &str,
    project_root: &Path,
) -> Vec<SemanticChunk> {
    let mut chunks = Vec::new();

    for symbol in symbols {
        if matches!(symbol.kind, SymbolKind::Heading) {
            continue;
        }

        let line_count = symbol
            .range
            .end_line
            .saturating_sub(symbol.range.start_line)
            + 1;
        if line_count < 2 && !matches!(symbol.kind, SymbolKind::Variable) {
            continue;
        }

        let embed_text = build_embed_text(symbol, source, file, project_root);
        let snippet = build_snippet(symbol, source);

        chunks.push(SemanticChunk {
            file: file.to_path_buf(),
            name: symbol.name.clone(),
            kind: symbol.kind.clone(),
            start_line: symbol.range.start_line,
            end_line: symbol.range.end_line,
            exported: symbol.exported,
            embed_text,
            snippet,
        });
    }

    chunks
}

fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let mut dot = 0.0f32;
    let mut norm_a = 0.0f32;
    let mut norm_b = 0.0f32;

    for i in 0..a.len() {
        dot += a[i] * b[i];
        norm_a += a[i] * a[i];
        norm_b += b[i] * b[i];
    }

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}

fn symbol_kind_to_u8(kind: &SymbolKind) -> u8 {
    match kind {
        SymbolKind::Function => 0,
        SymbolKind::Class => 1,
        SymbolKind::Method => 2,
        SymbolKind::Struct => 3,
        SymbolKind::Interface => 4,
        SymbolKind::Enum => 5,
        SymbolKind::TypeAlias => 6,
        SymbolKind::Variable => 7,
        SymbolKind::Heading => 8,
    }
}

fn u8_to_symbol_kind(v: u8) -> SymbolKind {
    match v {
        0 => SymbolKind::Function,
        1 => SymbolKind::Class,
        2 => SymbolKind::Method,
        3 => SymbolKind::Struct,
        4 => SymbolKind::Interface,
        5 => SymbolKind::Enum,
        6 => SymbolKind::TypeAlias,
        7 => SymbolKind::Variable,
        _ => SymbolKind::Heading,
    }
}

fn read_u32(data: &[u8], pos: &mut usize) -> Result<u32, String> {
    if *pos + 4 > data.len() {
        return Err("unexpected end of data reading u32".to_string());
    }
    let val = u32::from_le_bytes([data[*pos], data[*pos + 1], data[*pos + 2], data[*pos + 3]]);
    *pos += 4;
    Ok(val)
}

fn read_u64(data: &[u8], pos: &mut usize) -> Result<u64, String> {
    if *pos + 8 > data.len() {
        return Err("unexpected end of data reading u64".to_string());
    }
    let bytes: [u8; 8] = data[*pos..*pos + 8].try_into().unwrap();
    *pos += 8;
    Ok(u64::from_le_bytes(bytes))
}

fn read_string(data: &[u8], pos: &mut usize) -> Result<String, String> {
    let len = read_u32(data, pos)? as usize;
    if *pos + len > data.len() {
        return Err("unexpected end of data reading string".to_string());
    }
    let s = String::from_utf8_lossy(&data[*pos..*pos + len]).to_string();
    *pos += len;
    Ok(s)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{SemanticBackend, SemanticBackendConfig};
    use crate::parser::FileParser;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;

    fn start_mock_http_server<F>(handler: F) -> (String, thread::JoinHandle<()>)
    where
        F: Fn(String, String, String) -> String + Send + 'static,
    {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
        let addr = listener.local_addr().expect("local addr");
        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept request");
            let mut buf = Vec::new();
            let mut chunk = [0u8; 4096];
            let mut header_end = None;
            let mut content_length = 0usize;
            loop {
                let n = stream.read(&mut chunk).expect("read request");
                if n == 0 {
                    break;
                }
                buf.extend_from_slice(&chunk[..n]);
                if header_end.is_none() {
                    if let Some(pos) = buf.windows(4).position(|window| window == b"\r\n\r\n") {
                        header_end = Some(pos + 4);
                        let headers = String::from_utf8_lossy(&buf[..pos + 4]);
                        for line in headers.lines() {
                            if let Some(value) = line.strip_prefix("Content-Length:") {
                                content_length = value.trim().parse::<usize>().unwrap_or(0);
                            }
                        }
                    }
                }
                if let Some(end) = header_end {
                    if buf.len() >= end + content_length {
                        break;
                    }
                }
            }

            let end = header_end.expect("header terminator");
            let request = String::from_utf8_lossy(&buf[..end]).to_string();
            let body = String::from_utf8_lossy(&buf[end..end + content_length]).to_string();
            let mut lines = request.lines();
            let request_line = lines.next().expect("request line").to_string();
            let path = request_line
                .split_whitespace()
                .nth(1)
                .expect("request path")
                .to_string();
            let response_body = handler(request_line, path, body);
            let response = format!(
                "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
                response_body.len(),
                response_body
            );
            stream
                .write_all(response.as_bytes())
                .expect("write response");
        });

        (format!("http://{}", addr), handle)
    }

    fn test_vector_for_texts(texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
        Ok(texts.iter().map(|_| vec![1.0, 0.0, 0.0]).collect())
    }
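
    // Added test (sketch): exercises validate_embedding_batch's contract as
    // implemented above: count mismatches and ragged dimensions are rejected,
    // while an empty batch for zero inputs is accepted.
    #[test]
    fn validate_embedding_batch_rejects_mismatches() {
        let ok = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        assert!(validate_embedding_batch(&ok, 2, "test").is_ok());
        assert!(validate_embedding_batch(&ok, 3, "test").is_err());
        let ragged = vec![vec![1.0, 0.0], vec![0.0]];
        assert!(validate_embedding_batch(&ragged, 2, "test").is_err());
        assert!(validate_embedding_batch(&[], 0, "test").is_ok());
    }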

    fn write_rust_file(path: &Path, function_name: &str) {
        fs::write(
            path,
            format!("pub fn {function_name}() -> bool {{\n true\n}}\n"),
        )
        .unwrap();
    }

    fn build_test_index(project_root: &Path, files: &[PathBuf]) -> SemanticIndex {
        let mut embed = test_vector_for_texts;
        SemanticIndex::build(project_root, files, &mut embed, 8).unwrap()
    }

    fn set_file_metadata(index: &mut SemanticIndex, file: &Path, mtime: SystemTime, size: u64) {
        index.file_mtimes.insert(file.to_path_buf(), mtime);
        index.file_sizes.insert(file.to_path_buf(), size);
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &b).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 0.001);
    }
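
    // Added test (sketch): mismatched vector lengths short-circuit to 0.0 rather
    // than panicking, per the guard at the top of cosine_similarity.
    #[test]
    fn test_cosine_similarity_length_mismatch_is_zero() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
    }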

    #[test]
    fn test_serialization_roundtrip() {
        let mut index = SemanticIndex::new();
        index.entries.push(EmbeddingEntry {
            chunk: SemanticChunk {
                file: PathBuf::from("/src/main.rs"),
                name: "handle_request".to_string(),
                kind: SymbolKind::Function,
                start_line: 10,
                end_line: 25,
                exported: true,
                embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
                snippet: "fn handle_request() {\n // ...\n}".to_string(),
            },
            vector: vec![0.1, 0.2, 0.3, 0.4],
        });
        index.dimension = 4;
        index
            .file_mtimes
            .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
        index.file_sizes.insert(PathBuf::from("/src/main.rs"), 0);
        index.set_fingerprint(SemanticIndexFingerprint {
            backend: "fastembed".to_string(),
            model: "all-MiniLM-L6-v2".to_string(),
            base_url: FALLBACK_BACKEND.to_string(),
            dimension: 4,
        });

        let bytes = index.to_bytes();
        let restored = SemanticIndex::from_bytes(&bytes).unwrap();

        assert_eq!(restored.entries.len(), 1);
        assert_eq!(restored.entries[0].chunk.name, "handle_request");
        assert_eq!(restored.entries[0].vector, vec![0.1, 0.2, 0.3, 0.4]);
        assert_eq!(restored.dimension, 4);
        assert_eq!(restored.backend_label(), Some("fastembed"));
        assert_eq!(restored.model_label(), Some("all-MiniLM-L6-v2"));
    }
2146
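    // One-hot vectors make the ranking deterministic: the query
    // [0.9, 0.1, 0.0] is closest to the "auth" axis, so it must come back first.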
    #[test]
    fn test_search_top_k() {
        let mut index = SemanticIndex::new();
        index.dimension = 3;

        for (i, name) in ["auth", "database", "handler"].iter().enumerate() {
            let mut vec = vec![0.0f32; 3];
            vec[i] = 1.0;
            index.entries.push(EmbeddingEntry {
                chunk: SemanticChunk {
                    file: PathBuf::from("/src/lib.rs"),
                    name: name.to_string(),
                    kind: SymbolKind::Function,
                    start_line: (i * 10 + 1) as u32,
                    end_line: (i * 10 + 5) as u32,
                    exported: true,
                    embed_text: format!("kind:function name:{}", name),
                    snippet: format!("fn {}() {{}}", name),
                },
                vector: vec,
            });
        }

        let query = vec![0.9, 0.1, 0.0];
        let results = index.search(&query, 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].name, "auth");
        assert!(results[0].score > results[1].score);
    }

    #[test]
    fn test_empty_index_search() {
        let index = SemanticIndex::new();
        let results = index.search(&[0.1, 0.2, 0.3], 10);
        assert!(results.is_empty());
    }

    #[test]
    fn single_line_symbol_builds_non_empty_snippet() {
        let symbol = Symbol {
            name: "answer".to_string(),
            kind: SymbolKind::Variable,
            range: crate::symbols::Range {
                start_line: 0,
                start_col: 0,
                end_line: 0,
                end_col: 24,
            },
            signature: Some("const answer = 42".to_string()),
            scope_chain: Vec::new(),
            exported: true,
            parent: None,
        };
        let source = "export const answer = 42;\n";

        let snippet = build_snippet(&symbol, source);

        assert_eq!(snippet, "export const answer = 42;");
    }

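    // Parity check: the optimized collect_file_chunks path must yield exactly
    // the same chunks as the legacy FileParser + symbols_to_chunks pipeline
    // when run over this very source file.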
    #[test]
    fn optimized_file_chunk_collection_matches_file_parser_path() {
        let project_root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let file = project_root.join("src/semantic_index.rs");
        let source = std::fs::read_to_string(&file).unwrap();

        let mut legacy_parser = FileParser::new();
        let legacy_symbols = legacy_parser.extract_symbols(&file).unwrap();
        let legacy_chunks = symbols_to_chunks(&file, &legacy_symbols, &source, &project_root);

        let mut parsers = HashMap::new();
        let optimized_chunks = collect_file_chunks(&project_root, &file, &mut parsers).unwrap();

        assert_eq!(
            chunk_fingerprint(&optimized_chunks),
            chunk_fingerprint(&legacy_chunks)
        );
    }

    fn chunk_fingerprint(
        chunks: &[SemanticChunk],
    ) -> Vec<(String, SymbolKind, u32, u32, bool, String, String)> {
        chunks
            .iter()
            .map(|chunk| {
                (
                    chunk.name.clone(),
                    chunk.kind.clone(),
                    chunk.start_line,
                    chunk.end_line,
                    chunk.exported,
                    chunk.embed_text.clone(),
                    chunk.snippet.clone(),
                )
            })
            .collect()
    }

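    // The next two tests handcraft a v1 header: a version byte followed by the
    // dimension and entry count as little-endian u32s (matching
    // HEADER_BYTES_V1 = 1 + 4 + 4 = 9). Deserialization must reject
    // out-of-range values before reading any entry data.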
    #[test]
    fn rejects_oversized_dimension_during_deserialization() {
        let mut bytes = Vec::new();
        bytes.push(1u8);
        bytes.extend_from_slice(&((MAX_DIMENSION as u32) + 1).to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());

        assert!(SemanticIndex::from_bytes(&bytes).is_err());
    }

    #[test]
    fn rejects_oversized_entry_count_during_deserialization() {
        let mut bytes = Vec::new();
        bytes.push(1u8);
        bytes.extend_from_slice(&(DEFAULT_DIMENSION as u32).to_le_bytes());
        bytes.extend_from_slice(&((MAX_ENTRIES as u32) + 1).to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());

        assert!(SemanticIndex::from_bytes(&bytes).is_err());
    }

    #[test]
    fn invalidate_file_removes_entries_and_mtime() {
        let target = PathBuf::from("/src/main.rs");
        let mut index = SemanticIndex::new();
        index.entries.push(EmbeddingEntry {
            chunk: SemanticChunk {
                file: target.clone(),
                name: "main".to_string(),
                kind: SymbolKind::Function,
                start_line: 0,
                end_line: 1,
                exported: false,
                embed_text: "main".to_string(),
                snippet: "fn main() {}".to_string(),
            },
            vector: vec![1.0; DEFAULT_DIMENSION],
        });
        index
            .file_mtimes
            .insert(target.clone(), SystemTime::UNIX_EPOCH);
        index.file_sizes.insert(target.clone(), 0);

        index.invalidate_file(&target);

        assert!(index.entries.is_empty());
        assert!(!index.file_mtimes.contains_key(&target));
        assert!(!index.file_sizes.contains_key(&target));
    }

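    // Simulates a transient refresh failure: the file looks stale (forged
    // mtime/size) but is gone by the time it is re-read. The existing entry
    // must be kept rather than dropped.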
    #[test]
    fn refresh_transient_error_preserves_existing_entry_and_mtime() {
        let temp = tempfile::tempdir().unwrap();
        let project_root = temp.path();
        let file = project_root.join("src/lib.rs");
        fs::create_dir_all(file.parent().unwrap()).unwrap();
        write_rust_file(&file, "kept_symbol");

        let mut index = build_test_index(project_root, &[file.clone()]);
        let original_entry_count = index.entries.len();
        let original_mtime = *index.file_mtimes.get(&file).unwrap();
        let original_size = *index.file_sizes.get(&file).unwrap();

        let stale_mtime = SystemTime::UNIX_EPOCH;
        set_file_metadata(&mut index, &file, stale_mtime, original_size + 1);
        fs::remove_file(&file).unwrap();

        let mut embed = test_vector_for_texts;
        let mut progress = |_done: usize, _total: usize| {};
        let summary = index
            .refresh_stale_files(project_root, &[file.clone()], &mut embed, 8, &mut progress)
            .unwrap();

        assert_eq!(summary.changed, 0);
        assert_eq!(summary.added, 0);
        assert_eq!(summary.deleted, 0);
        assert_eq!(index.entries.len(), original_entry_count);
        assert_eq!(index.entries[0].chunk.name, "kept_symbol");
        assert_eq!(index.file_mtimes.get(&file), Some(&stale_mtime));
        assert_ne!(index.file_mtimes.get(&file), Some(&original_mtime));
        assert_eq!(index.file_sizes.get(&file), Some(&(original_size + 1)));
    }

    #[test]
    fn refresh_never_indexed_file_error_does_not_record_mtime() {
        let temp = tempfile::tempdir().unwrap();
        let project_root = temp.path();
        let missing = project_root.join("src/missing.rs");
        fs::create_dir_all(missing.parent().unwrap()).unwrap();

        let mut index = SemanticIndex::new();
        let mut embed = test_vector_for_texts;
        let mut progress = |_done: usize, _total: usize| {};
        let summary = index
            .refresh_stale_files(
                project_root,
                &[missing.clone()],
                &mut embed,
                8,
                &mut progress,
            )
            .unwrap();

        assert_eq!(summary.added, 0);
        assert_eq!(summary.changed, 0);
        assert_eq!(summary.deleted, 0);
        assert!(!index.file_mtimes.contains_key(&missing));
        assert!(!index.file_sizes.contains_key(&missing));
        assert!(index.entries.is_empty());
    }

    #[test]
    fn refresh_reports_added_for_new_files() {
        let temp = tempfile::tempdir().unwrap();
        let project_root = temp.path();
        let existing = project_root.join("src/lib.rs");
        let added = project_root.join("src/new.rs");
        fs::create_dir_all(existing.parent().unwrap()).unwrap();
        write_rust_file(&existing, "existing_symbol");
        write_rust_file(&added, "added_symbol");

        let mut index = build_test_index(project_root, &[existing.clone()]);
        let mut embed = test_vector_for_texts;
        let mut progress = |_done: usize, _total: usize| {};
        let summary = index
            .refresh_stale_files(
                project_root,
                &[existing.clone(), added.clone()],
                &mut embed,
                8,
                &mut progress,
            )
            .unwrap();

        assert_eq!(summary.added, 1);
        assert_eq!(summary.changed, 0);
        assert_eq!(summary.deleted, 0);
        assert_eq!(summary.total_processed, 2);
        assert!(index.file_mtimes.contains_key(&added));
        assert!(index.entries.iter().any(|entry| entry.chunk.file == added));
    }

    #[test]
    fn refresh_reports_deleted_for_removed_files() {
        let temp = tempfile::tempdir().unwrap();
        let project_root = temp.path();
        let deleted = project_root.join("src/deleted.rs");
        fs::create_dir_all(deleted.parent().unwrap()).unwrap();
        write_rust_file(&deleted, "deleted_symbol");

        let mut index = build_test_index(project_root, &[deleted.clone()]);
        fs::remove_file(&deleted).unwrap();

        let mut embed = test_vector_for_texts;
        let mut progress = |_done: usize, _total: usize| {};
        let summary = index
            .refresh_stale_files(project_root, &[], &mut embed, 8, &mut progress)
            .unwrap();

        assert_eq!(summary.deleted, 1);
        assert_eq!(summary.changed, 0);
        assert_eq!(summary.added, 0);
        assert_eq!(summary.total_processed, 1);
        assert!(!index.file_mtimes.contains_key(&deleted));
        assert!(index.entries.is_empty());
    }

    #[test]
    fn refresh_reports_changed_for_modified_files() {
        let temp = tempfile::tempdir().unwrap();
        let project_root = temp.path();
        let file = project_root.join("src/lib.rs");
        fs::create_dir_all(file.parent().unwrap()).unwrap();
        write_rust_file(&file, "old_symbol");

        let mut index = build_test_index(project_root, &[file.clone()]);
        set_file_metadata(&mut index, &file, SystemTime::UNIX_EPOCH, 0);
        write_rust_file(&file, "new_symbol");

        let mut embed = test_vector_for_texts;
        let mut progress = |_done: usize, _total: usize| {};
        let summary = index
            .refresh_stale_files(project_root, &[file.clone()], &mut embed, 8, &mut progress)
            .unwrap();

        assert_eq!(summary.changed, 1);
        assert_eq!(summary.added, 0);
        assert_eq!(summary.deleted, 0);
        assert_eq!(summary.total_processed, 1);
        assert!(index
            .entries
            .iter()
            .any(|entry| entry.chunk.name == "new_symbol"));
        assert!(!index
            .entries
            .iter()
            .any(|entry| entry.chunk.name == "old_symbol"));
    }

    #[test]
    fn refresh_all_clean_reports_zero_counts_and_no_embedding_work() {
        let temp = tempfile::tempdir().unwrap();
        let project_root = temp.path();
        let file = project_root.join("src/lib.rs");
        fs::create_dir_all(file.parent().unwrap()).unwrap();
        write_rust_file(&file, "clean_symbol");

        let mut index = build_test_index(project_root, &[file.clone()]);
        let original_entries = index.entries.len();
        let mut embed_called = false;
        let mut embed = |texts: Vec<String>| {
            embed_called = true;
            test_vector_for_texts(texts)
        };
        let mut progress = |_done: usize, _total: usize| {};
        let summary = index
            .refresh_stale_files(project_root, &[file.clone()], &mut embed, 8, &mut progress)
            .unwrap();

        assert!(summary.is_noop());
        assert_eq!(summary.total_processed, 1);
        assert!(!embed_called);
        assert_eq!(index.entries.len(), original_entries);
    }

    #[test]
    fn detects_missing_onnx_runtime_from_dynamic_load_error() {
        let message = "Failed to load ONNX Runtime shared library libonnxruntime.dylib via dlopen: no such file";

        assert!(is_onnx_runtime_unavailable(message));
    }

    #[test]
    fn formats_missing_onnx_runtime_with_install_hint() {
        let message = format_embedding_init_error(
            "Failed to load ONNX Runtime shared library libonnxruntime.so via dlopen: no such file",
        );

        assert!(message.starts_with("ONNX Runtime not found. Install via:"));
        assert!(message.contains("Original error:"));
    }

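    // Wire-format check against the in-process mock server: the
    // OpenAI-compatible backend POSTs to /v1/embeddings and reads vectors back
    // from the `data[].embedding` field of the response.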
    #[test]
    fn openai_compatible_backend_embeds_with_mock_server() {
        let (base_url, handle) = start_mock_http_server(|request_line, path, _body| {
            assert!(request_line.starts_with("POST "));
            assert_eq!(path, "/v1/embeddings");
            "{\"data\":[{\"embedding\":[0.1,0.2,0.3],\"index\":0},{\"embedding\":[0.4,0.5,0.6],\"index\":1}]}".to_string()
        });

        let config = SemanticBackendConfig {
            backend: SemanticBackend::OpenAiCompatible,
            model: "test-embedding".to_string(),
            base_url: Some(base_url),
            api_key_env: None,
            timeout_ms: 5_000,
            max_batch_size: 64,
        };

        let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();
        let vectors = model
            .embed(vec!["hello".to_string(), "world".to_string()])
            .unwrap();

        assert_eq!(vectors, vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]]);
        handle.join().unwrap();
    }

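    // Same flow for the Ollama backend: POST to /api/embed, with vectors
    // returned in a top-level `embeddings` array.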
    #[test]
    fn ollama_backend_embeds_with_mock_server() {
        let (base_url, handle) = start_mock_http_server(|request_line, path, _body| {
            assert!(request_line.starts_with("POST "));
            assert_eq!(path, "/api/embed");
            "{\"embeddings\":[[0.7,0.8,0.9],[1.0,1.1,1.2]]}".to_string()
        });

        let config = SemanticBackendConfig {
            backend: SemanticBackend::Ollama,
            model: "embeddinggemma".to_string(),
            base_url: Some(base_url),
            api_key_env: None,
            timeout_ms: 5_000,
            max_batch_size: 64,
        };

        let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();
        let vectors = model
            .embed(vec!["hello".to_string(), "world".to_string()])
            .unwrap();

        assert_eq!(vectors, vec![vec![0.7, 0.8, 0.9], vec![1.0, 1.1, 1.2]]);
        handle.join().unwrap();
    }

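    // A cached index is only reusable when the caller's fingerprint (backend,
    // model, base URL, dimension) matches the one stored with it; any mismatch
    // must read back as a cache miss.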
    #[test]
    fn read_from_disk_rejects_fingerprint_mismatch() {
        let storage = tempfile::tempdir().unwrap();
        let project_key = "proj";

        let mut index = SemanticIndex::new();
        index.entries.push(EmbeddingEntry {
            chunk: SemanticChunk {
                file: PathBuf::from("/src/main.rs"),
                name: "handle_request".to_string(),
                kind: SymbolKind::Function,
                start_line: 10,
                end_line: 25,
                exported: true,
                embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
                snippet: "fn handle_request() {}".to_string(),
            },
            vector: vec![0.1, 0.2, 0.3],
        });
        index.dimension = 3;
        index
            .file_mtimes
            .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
        index.file_sizes.insert(PathBuf::from("/src/main.rs"), 0);
        index.set_fingerprint(SemanticIndexFingerprint {
            backend: "openai_compatible".to_string(),
            model: "test-embedding".to_string(),
            base_url: "http://127.0.0.1:1234/v1".to_string(),
            dimension: 3,
        });
        index.write_to_disk(storage.path(), project_key);

        let matching = index.fingerprint().unwrap().as_string();
        assert!(
            SemanticIndex::read_from_disk(storage.path(), project_key, Some(&matching)).is_some()
        );

        let mismatched = SemanticIndexFingerprint {
            backend: "ollama".to_string(),
            model: "embeddinggemma".to_string(),
            base_url: "http://127.0.0.1:11434".to_string(),
            dimension: 3,
        }
        .as_string();
        assert!(
            SemanticIndex::read_from_disk(storage.path(), project_key, Some(&mismatched)).is_none()
        );
    }

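    // Forces the version byte back to v3: read_from_disk must treat the cache
    // as unusable (snippets need rebuilding) and delete the stale file.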
    #[test]
    fn read_from_disk_rejects_v3_cache_for_snippet_rebuild() {
        let storage = tempfile::tempdir().unwrap();
        let project_key = "proj-v3";
        let dir = storage.path().join("semantic").join(project_key);
        fs::create_dir_all(&dir).unwrap();

        let mut index = SemanticIndex::new();
        index.entries.push(EmbeddingEntry {
            chunk: SemanticChunk {
                file: PathBuf::from("/src/main.rs"),
                name: "handle_request".to_string(),
                kind: SymbolKind::Function,
                start_line: 0,
                end_line: 0,
                exported: true,
                embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
                snippet: "fn handle_request() {}".to_string(),
            },
            vector: vec![0.1, 0.2, 0.3],
        });
        index.dimension = 3;
        index
            .file_mtimes
            .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
        index.file_sizes.insert(PathBuf::from("/src/main.rs"), 0);
        let fingerprint = SemanticIndexFingerprint {
            backend: "fastembed".to_string(),
            model: "test".to_string(),
            base_url: FALLBACK_BACKEND.to_string(),
            dimension: 3,
        };
        index.set_fingerprint(fingerprint.clone());

        let mut bytes = index.to_bytes();
        bytes[0] = SEMANTIC_INDEX_VERSION_V3;
        fs::write(dir.join("semantic.bin"), bytes).unwrap();

        assert!(SemanticIndex::read_from_disk(
            storage.path(),
            project_key,
            Some(&fingerprint.as_string())
        )
        .is_none());
        assert!(!dir.join("semantic.bin").exists());
    }

    fn make_symbol(kind: SymbolKind, name: &str, start: u32, end: u32) -> crate::symbols::Symbol {
        crate::symbols::Symbol {
            name: name.to_string(),
            kind,
            range: crate::symbols::Range {
                start_line: start,
                start_col: 0,
                end_line: end,
                end_col: 0,
            },
            signature: None,
            scope_chain: Vec::new(),
            exported: false,
            parent: None,
        }
    }

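    // Markdown headings describe document structure rather than code, so they
    // must be filtered out before embedding.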
    #[test]
    fn symbols_to_chunks_skips_heading_symbols() {
        let project_root = PathBuf::from("/proj");
        let file = project_root.join("README.md");
        let source = "# Title\n\nbody text\n\n## Section\n\nmore text\n";

        let symbols = vec![
            make_symbol(SymbolKind::Heading, "Title", 0, 2),
            make_symbol(SymbolKind::Heading, "Section", 4, 6),
        ];

        let chunks = symbols_to_chunks(&file, &symbols, source, &project_root);
        assert!(
            chunks.is_empty(),
            "Heading symbols must be filtered out before embedding; got {} chunk(s)",
            chunks.len()
        );
    }

    #[test]
    fn symbols_to_chunks_keeps_code_symbols_alongside_skipped_headings() {
        let project_root = PathBuf::from("/proj");
        let file = project_root.join("src/lib.rs");
        let source = "pub fn handle_request() -> bool {\n true\n}\n";

        let symbols = vec![
            make_symbol(SymbolKind::Heading, "doc heading", 0, 1),
            make_symbol(SymbolKind::Function, "handle_request", 0, 2),
            make_symbol(SymbolKind::Struct, "AuthService", 4, 6),
        ];

        let chunks = symbols_to_chunks(&file, &symbols, source, &project_root);
        assert_eq!(
            chunks.len(),
            2,
            "Expected 2 code chunks (Function + Struct), got {}",
            chunks.len()
        );
        let names: Vec<&str> = chunks.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"handle_request"));
        assert!(names.contains(&"AuthService"));
        assert!(
            !names.contains(&"doc heading"),
            "Heading symbol leaked into chunks: {names:?}"
        );
    }

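    // validate_base_url_no_ssrf guards remote embedding backends: loopback
    // hostnames and private/link-local IP ranges are rejected outright.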
    #[test]
    fn validate_ssrf_rejects_loopback_hostnames() {
        for host in &[
            "http://localhost",
            "http://localhost:8080",
            "http://localhost.localdomain",
            "http://foo.localhost",
        ] {
            assert!(
                validate_base_url_no_ssrf(host).is_err(),
                "Expected {host} to be rejected"
            );
        }
    }

    #[test]
    fn validate_ssrf_rejects_private_ips() {
        for url in &[
            "http://192.168.1.1",
            "http://10.0.0.1",
            "http://172.16.0.1",
            "http://127.0.0.1",
            "http://169.254.169.254",
        ] {
            let result = validate_base_url_no_ssrf(url);
            assert!(
                result.is_err(),
                "Expected {url} to be rejected, got: {:?}",
                result
            );
        }
    }

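    // Unlike validate_base_url_no_ssrf above, normalize_base_url accepts
    // loopback addresses, which keeps local endpoints (e.g. Ollama on
    // 127.0.0.1) usable in tests and development.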
    #[test]
    fn normalize_base_url_allows_localhost_for_tests() {
        assert!(normalize_base_url("http://127.0.0.1:9999").is_ok());
        assert!(normalize_base_url("http://localhost:8080").is_ok());
    }
}