Skip to main content

aft/
semantic_index.rs

1use crate::config::{SemanticBackend, SemanticBackendConfig};
2use crate::parser::{detect_language, extract_symbols_from_tree, grammar_for};
3use crate::symbols::{Symbol, SymbolKind};
4use crate::{slog_info, slog_warn};
5
6use fastembed::{EmbeddingModel as FastembedEmbeddingModel, InitOptions, TextEmbedding};
7use rayon::prelude::*;
8use reqwest::blocking::Client;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::env;
12use std::fmt::Display;
13use std::fs;
14use std::path::{Path, PathBuf};
15use std::time::Duration;
16use std::time::SystemTime;
17use tree_sitter::Parser;
18use url::Url;
19
/// Embedding dimension a freshly created `SemanticIndex` starts with
/// (384 is the vector size of MiniLM-L6-v2).
const DEFAULT_DIMENSION: usize = 384;
// NOTE(review): presumably a sanity cap on the number of entries accepted
// from a persisted index — its use is outside this part of the file; confirm.
const MAX_ENTRIES: usize = 1_000_000;
// NOTE(review): presumably a sanity cap on the embedding dimension accepted
// from a persisted index — its use is outside this part of the file; confirm.
const MAX_DIMENSION: usize = 1024;
/// Size in bytes of a single `f32`.
const F32_BYTES: usize = std::mem::size_of::<f32>();
// NOTE(review): header sizes in bytes, presumably for the V1 and V2 on-disk
// formats respectively — confirm against the (de)serialization code.
const HEADER_BYTES_V1: usize = 9;
const HEADER_BYTES_V2: usize = 13;
/// User-facing hint that `format_embedding_init_error` prepends when the
/// error message indicates the ONNX Runtime library could not be loaded.
/// `is_onnx_runtime_unavailable` also recognizes messages that start with
/// the "ONNX Runtime not found." prefix of this text.
const ONNX_RUNTIME_INSTALL_HINT: &str =
    "ONNX Runtime not found. Install via: brew install onnxruntime (macOS) or apt install libonnxruntime (Linux).";

// Version tags of the persisted index format.
// NOTE(review): V1/V2 layouts are not described in this part of the file.
const SEMANTIC_INDEX_VERSION_V1: u8 = 1;
const SEMANTIC_INDEX_VERSION_V2: u8 = 2;
/// V3 adds subsec_nanos to the file-mtime table so staleness detection survives
/// restart round-trips on filesystems with subsecond mtime precision (APFS,
/// ext4 with nsec, NTFS). V1/V2 persisted whole-second mtimes only, which
/// caused every restart to flag ~99% of files as stale and re-embed them.
const SEMANTIC_INDEX_VERSION_V3: u8 = 3;
/// V4 keeps the V3 on-disk layout but rebuilds persisted snippets once after
/// fixing symbol ranges that were incorrectly treated as 1-based.
const SEMANTIC_INDEX_VERSION_V4: u8 = 4;
/// Path appended (after the `/v1` segment) by `build_openai_embeddings_endpoint`.
const DEFAULT_OPENAI_EMBEDDING_PATH: &str = "/embeddings";
/// Path appended by `build_ollama_embeddings_endpoint` when the base URL does
/// not already end in `/api`.
const DEFAULT_OLLAMA_EMBEDDING_PATH: &str = "/api/embed";
// Must stay below the bridge timeout (30s) to avoid bridge kills on slow backends.
// Used for the HTTP client whenever the configured `timeout_ms` is 0.
const DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS: u64 = 25_000;
/// Batch size used whenever the configured `max_batch_size` is 0.
const DEFAULT_MAX_BATCH_SIZE: usize = 64;
/// Placeholder stored as the fingerprint's `base_url` when the config has no
/// base URL or it cannot be normalized.
const FALLBACK_BACKEND: &str = "none";
/// Total number of attempts `send_embedding_request` makes per request.
const EMBEDDING_REQUEST_MAX_ATTEMPTS: usize = 3;
/// Delay before each retry, indexed by the zero-based attempt that just failed.
const EMBEDDING_REQUEST_BACKOFF_MS: [u64; 2] = [500, 1_000];
47
/// Identifies the embedding setup (backend, model, endpoint, vector size).
///
/// Serialized to JSON by `as_string` and compared as a string by
/// `matches_expected`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticIndexFingerprint {
    /// Backend identifier, taken from `config.backend.as_str()`.
    pub backend: String,
    /// Model name as configured.
    pub model: String,
    /// Normalized base URL, or `FALLBACK_BACKEND` when none is usable.
    /// Deserializes to an empty string when the field is absent
    /// (presumably to accept fingerprints written before the field existed — confirm).
    #[serde(default)]
    pub base_url: String,
    /// Length of the vectors produced by the model.
    pub dimension: usize,
}
56
57impl SemanticIndexFingerprint {
58    fn from_config(config: &SemanticBackendConfig, dimension: usize) -> Self {
59        // Use normalized URL for fingerprinting so cosmetic differences
60        // (e.g. "http://host/v1" vs "http://host/v1/") don't cause rebuilds.
61        let base_url = config
62            .base_url
63            .as_ref()
64            .and_then(|u| normalize_base_url(u).ok())
65            .unwrap_or_else(|| FALLBACK_BACKEND.to_string());
66        Self {
67            backend: config.backend.as_str().to_string(),
68            model: config.model.clone(),
69            base_url,
70            dimension,
71        }
72    }
73
74    pub fn as_string(&self) -> String {
75        serde_json::to_string(self).unwrap_or_else(|_| String::new())
76    }
77
78    fn matches_expected(&self, expected: &str) -> bool {
79        let encoded = self.as_string();
80        !encoded.is_empty() && encoded == expected
81    }
82}
83
/// Backend-specific state used by `SemanticEmbeddingModel` to produce embeddings.
enum SemanticEmbeddingEngine {
    /// In-process embedding through fastembed's `TextEmbedding`.
    Fastembed(TextEmbedding),
    /// Remote endpoint reached with a JSON POST to the URL built by
    /// `build_openai_embeddings_endpoint`.
    OpenAiCompatible {
        client: Client,
        model: String,
        /// Normalized base URL (see `normalize_base_url`).
        base_url: String,
        /// Sent as `Authorization: Bearer <key>` when present.
        api_key: Option<String>,
    },
    /// Remote endpoint reached with a JSON POST to the URL built by
    /// `build_ollama_embeddings_endpoint`.
    Ollama {
        client: Client,
        model: String,
        /// Normalized base URL (see `normalize_base_url`).
        base_url: String,
    },
}
98
/// An embedding model bound to one configured backend.
///
/// Created by `SemanticEmbeddingModel::from_config`.
pub struct SemanticEmbeddingModel {
    /// Backend selected in the config.
    backend: SemanticBackend,
    /// Model name as configured.
    model: String,
    /// Base URL exactly as configured (not normalized); `None` if unset.
    base_url: Option<String>,
    /// Effective HTTP timeout in milliseconds (default applied when the config value is 0).
    timeout_ms: u64,
    /// Effective batch size (default applied when the config value is 0).
    max_batch_size: usize,
    /// Cached vector length; `None` until `dimension()` or a remote
    /// `embed_texts` call has determined it.
    dimension: Option<usize>,
    /// Backend-specific state used to run the embedding.
    engine: SemanticEmbeddingEngine,
}
108
/// Alias for `SemanticEmbeddingModel`.
pub type EmbeddingModel = SemanticEmbeddingModel;
110
/// Checks that an embedding backend returned a well-formed batch.
///
/// A batch is valid when it contains exactly `expected_count` vectors and all
/// of them have the same length as the first one. `context` names the backend
/// and is used as the prefix of every error message.
///
/// # Errors
/// Returns a descriptive message when the batch is unexpectedly empty, has the
/// wrong number of vectors, or mixes vector lengths.
fn validate_embedding_batch(
    vectors: &[Vec<f32>],
    expected_count: usize,
    context: &str,
) -> Result<(), String> {
    let received = vectors.len();

    if received == 0 && expected_count > 0 {
        return Err(format!(
            "{context} returned no vectors for {expected_count} inputs"
        ));
    }
    if received != expected_count {
        return Err(format!(
            "{context} returned {} vectors for {} inputs",
            received, expected_count
        ));
    }

    // An empty batch for zero inputs is trivially consistent.
    let expected_dimension = match vectors.first() {
        Some(first_vector) => first_vector.len(),
        None => return Ok(()),
    };

    // Report the first vector whose length deviates from vector 0.
    match vectors
        .iter()
        .position(|vector| vector.len() != expected_dimension)
    {
        Some(index) => Err(format!(
            "{context} returned inconsistent embedding dimensions: vector 0 has length {expected_dimension}, vector {index} has length {}",
            vectors[index].len()
        )),
        None => Ok(()),
    }
}
145
146/// Normalize a base URL: validate scheme and strip trailing slash.
147/// Does NOT perform SSRF/private-IP validation — call
148/// `validate_base_url_no_ssrf` separately when processing user-supplied config.
149fn normalize_base_url(raw: &str) -> Result<String, String> {
150    let parsed = Url::parse(raw).map_err(|error| format!("invalid base_url '{raw}': {error}"))?;
151    let scheme = parsed.scheme();
152    if scheme != "http" && scheme != "https" {
153        return Err(format!(
154            "unsupported URL scheme '{}' — only http:// and https:// are allowed",
155            scheme
156        ));
157    }
158    Ok(parsed.to_string().trim_end_matches('/').to_string())
159}
160
161/// Validate that a base URL does not point to a private/loopback address.
162/// Call this on user-supplied config (at configure time) to prevent SSRF.
163/// Not called for programmatically constructed configs (e.g. tests).
164pub fn validate_base_url_no_ssrf(raw: &str) -> Result<(), String> {
165    use std::net::{IpAddr, ToSocketAddrs};
166
167    let parsed = Url::parse(raw).map_err(|error| format!("invalid base_url '{raw}': {error}"))?;
168
169    // Reject well-known private/loopback hostnames before DNS lookup.
170    let host = parsed.host_str().unwrap_or("");
171    if host == "localhost"
172        || host == "localhost.localdomain"
173        || host.ends_with(".localhost")
174        || host.ends_with(".local")
175    {
176        return Err(format!(
177            "base_url host '{host}' resolves to a private/loopback address — only public endpoints are allowed"
178        ));
179    }
180
181    // Resolve the hostname and reject any private/loopback/link-local IPs.
182    let port = parsed.port_or_known_default().unwrap_or(443);
183    let addr_str = format!("{host}:{port}");
184    let addrs: Vec<IpAddr> = addr_str
185        .to_socket_addrs()
186        .map(|iter| iter.map(|sa| sa.ip()).collect())
187        .unwrap_or_default();
188    for ip in &addrs {
189        if is_private_ip(ip) {
190            return Err(format!(
191                "base_url '{raw}' resolves to a private/reserved IP address — only public endpoints are allowed"
192            ));
193        }
194    }
195
196    Ok(())
197}
198
/// Returns `true` when `ip` belongs to a range that must not be reached from
/// user-supplied endpoints (private, loopback, link-local, CGNAT, unspecified).
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are classified by their
/// embedded IPv4 address, so `::ffff:127.0.0.1` is treated like `127.0.0.1`.
fn is_private_ip(ip: &std::net::IpAddr) -> bool {
    use std::net::{IpAddr, Ipv6Addr};
    match ip {
        IpAddr::V4(v4) => {
            let o = v4.octets();
            // 10.0.0.0/8
            o[0] == 10
            // 172.16.0.0/12
            || (o[0] == 172 && (16..=31).contains(&o[1]))
            // 192.168.0.0/16
            || (o[0] == 192 && o[1] == 168)
            // 127.0.0.0/8 loopback
            || o[0] == 127
            // 169.254.0.0/16 link-local
            || (o[0] == 169 && o[1] == 254)
            // 100.64.0.0/10 CGNAT
            || (o[0] == 100 && (64..=127).contains(&o[1]))
            // 0.0.0.0/8
            || o[0] == 0
        }
        IpAddr::V6(v6) => {
            // ::ffff:0:0/96 IPv4-mapped — judge the embedded IPv4 address.
            // (The previous hand-rolled check destructured a 2-element slice
            // with a 4-element pattern and therefore never matched.)
            if let Some(mapped) = v6.to_ipv4_mapped() {
                return is_private_ip(&IpAddr::V4(mapped));
            }

            let first_segment = v6.segments()[0];
            // ::1 loopback
            *v6 == Ipv6Addr::LOCALHOST
            // :: unspecified (IPv6 counterpart of the 0.0.0.0 rule above)
            || v6.is_unspecified()
            // fe80::/10 link-local
            || (first_segment & 0xffc0) == 0xfe80
            // fc00::/7 unique-local
            || (first_segment & 0xfe00) == 0xfc00
        }
    }
}
238
239fn build_openai_embeddings_endpoint(base_url: &str) -> String {
240    if base_url.ends_with("/v1") {
241        format!("{base_url}{DEFAULT_OPENAI_EMBEDDING_PATH}")
242    } else {
243        format!("{base_url}/v1{}", DEFAULT_OPENAI_EMBEDDING_PATH)
244    }
245}
246
247fn build_ollama_embeddings_endpoint(base_url: &str) -> String {
248    if base_url.ends_with("/api") {
249        format!("{base_url}/embed")
250    } else {
251        format!("{base_url}{DEFAULT_OLLAMA_EMBEDDING_PATH}")
252    }
253}
254
/// Trims surrounding whitespace from an optional value and maps values that
/// are empty after trimming to `None`.
fn normalize_api_key(value: Option<String>) -> Option<String> {
    let raw = value?;
    match raw.trim() {
        "" => None,
        trimmed => Some(trimmed.to_owned()),
    }
}
265
266fn is_retryable_embedding_status(status: reqwest::StatusCode) -> bool {
267    status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS
268}
269
270fn is_retryable_embedding_error(error: &reqwest::Error) -> bool {
271    error.is_connect()
272}
273
274fn sleep_before_embedding_retry(attempt_index: usize) {
275    if let Some(delay_ms) = EMBEDDING_REQUEST_BACKOFF_MS.get(attempt_index) {
276        std::thread::sleep(Duration::from_millis(*delay_ms));
277    }
278}
279
280fn send_embedding_request<F>(mut make_request: F, backend_label: &str) -> Result<String, String>
281where
282    F: FnMut() -> reqwest::blocking::RequestBuilder,
283{
284    for attempt_index in 0..EMBEDDING_REQUEST_MAX_ATTEMPTS {
285        let last_attempt = attempt_index + 1 == EMBEDDING_REQUEST_MAX_ATTEMPTS;
286
287        let response = match make_request().send() {
288            Ok(response) => response,
289            Err(error) => {
290                if !last_attempt && is_retryable_embedding_error(&error) {
291                    sleep_before_embedding_retry(attempt_index);
292                    continue;
293                }
294                return Err(format!("{backend_label} request failed: {error}"));
295            }
296        };
297
298        let status = response.status();
299        let raw = match response.text() {
300            Ok(raw) => raw,
301            Err(error) => {
302                if !last_attempt && is_retryable_embedding_error(&error) {
303                    sleep_before_embedding_retry(attempt_index);
304                    continue;
305                }
306                return Err(format!("{backend_label} response read failed: {error}"));
307            }
308        };
309
310        if status.is_success() {
311            return Ok(raw);
312        }
313
314        if !last_attempt && is_retryable_embedding_status(status) {
315            sleep_before_embedding_retry(attempt_index);
316            continue;
317        }
318
319        return Err(format!(
320            "{backend_label} request failed (HTTP {}): {}",
321            status, raw
322        ));
323    }
324
325    unreachable!("embedding request retries exhausted without returning")
326}
327
328impl SemanticEmbeddingModel {
329    pub fn from_config(config: &SemanticBackendConfig) -> Result<Self, String> {
330        let timeout_ms = if config.timeout_ms == 0 {
331            DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS
332        } else {
333            config.timeout_ms
334        };
335
336        let max_batch_size = if config.max_batch_size == 0 {
337            DEFAULT_MAX_BATCH_SIZE
338        } else {
339            config.max_batch_size
340        };
341
342        let api_key_env = normalize_api_key(config.api_key_env.clone());
343        let model = config.model.clone();
344
345        let client = Client::builder()
346            .timeout(Duration::from_millis(timeout_ms))
347            .redirect(reqwest::redirect::Policy::none())
348            .build()
349            .map_err(|error| format!("failed to configure embedding client: {error}"))?;
350
351        let engine = match config.backend {
352            SemanticBackend::Fastembed => {
353                SemanticEmbeddingEngine::Fastembed(initialize_text_embedding(&model)?)
354            }
355            SemanticBackend::OpenAiCompatible => {
356                let raw = config.base_url.as_ref().ok_or_else(|| {
357                    "base_url is required for openai_compatible backend".to_string()
358                })?;
359                let base_url = normalize_base_url(raw)?;
360
361                let api_key = match api_key_env {
362                    Some(var_name) => Some(env::var(&var_name).map_err(|_| {
363                        format!("missing api_key_env '{var_name}' for openai_compatible backend")
364                    })?),
365                    None => None,
366                };
367
368                SemanticEmbeddingEngine::OpenAiCompatible {
369                    client,
370                    model,
371                    base_url,
372                    api_key,
373                }
374            }
375            SemanticBackend::Ollama => {
376                let raw = config
377                    .base_url
378                    .as_ref()
379                    .ok_or_else(|| "base_url is required for ollama backend".to_string())?;
380                let base_url = normalize_base_url(raw)?;
381
382                SemanticEmbeddingEngine::Ollama {
383                    client,
384                    model,
385                    base_url,
386                }
387            }
388        };
389
390        Ok(Self {
391            backend: config.backend,
392            model: config.model.clone(),
393            base_url: config.base_url.clone(),
394            timeout_ms,
395            max_batch_size,
396            dimension: None,
397            engine,
398        })
399    }
400
    /// Returns the backend this model was configured with.
    pub fn backend(&self) -> SemanticBackend {
        self.backend
    }
404
    /// Returns the configured model name.
    pub fn model(&self) -> &str {
        &self.model
    }
408
    /// Returns the base URL exactly as configured (not normalized), if any.
    pub fn base_url(&self) -> Option<&str> {
        self.base_url.as_deref()
    }
412
    /// Returns the effective batch size (the default when the config value was 0).
    pub fn max_batch_size(&self) -> usize {
        self.max_batch_size
    }
416
    /// Returns the effective HTTP timeout in milliseconds (the default when
    /// the config value was 0).
    pub fn timeout_ms(&self) -> u64 {
        self.timeout_ms
    }
420
421    pub fn fingerprint(
422        &mut self,
423        config: &SemanticBackendConfig,
424    ) -> Result<SemanticIndexFingerprint, String> {
425        let dimension = self.dimension()?;
426        Ok(SemanticIndexFingerprint::from_config(config, dimension))
427    }
428
429    pub fn dimension(&mut self) -> Result<usize, String> {
430        if let Some(dimension) = self.dimension {
431            return Ok(dimension);
432        }
433
434        let dimension = match &mut self.engine {
435            SemanticEmbeddingEngine::Fastembed(model) => {
436                let vectors = model
437                    .embed(vec!["semantic index fingerprint probe".to_string()], None)
438                    .map_err(|error| format_embedding_init_error(error.to_string()))?;
439                vectors
440                    .first()
441                    .map(|v| v.len())
442                    .ok_or_else(|| "embedding backend returned no vectors".to_string())?
443            }
444            SemanticEmbeddingEngine::OpenAiCompatible { .. } => {
445                let vectors =
446                    self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
447                vectors
448                    .first()
449                    .map(|v| v.len())
450                    .ok_or_else(|| "embedding backend returned no vectors".to_string())?
451            }
452            SemanticEmbeddingEngine::Ollama { .. } => {
453                let vectors =
454                    self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
455                vectors
456                    .first()
457                    .map(|v| v.len())
458                    .ok_or_else(|| "embedding backend returned no vectors".to_string())?
459            }
460        };
461
462        self.dimension = Some(dimension);
463        Ok(dimension)
464    }
465
    /// Embeds `texts` with the configured backend.
    ///
    /// Public entry point; delegates to `embed_texts`.
    pub fn embed(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
        self.embed_texts(texts)
    }
469
470    fn embed_texts(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
471        match &mut self.engine {
472            SemanticEmbeddingEngine::Fastembed(model) => model
473                .embed(texts, None::<usize>)
474                .map_err(|error| format_embedding_init_error(error.to_string()))
475                .map_err(|error| format!("failed to embed batch: {error}")),
476            SemanticEmbeddingEngine::OpenAiCompatible {
477                client,
478                model,
479                base_url,
480                api_key,
481            } => {
482                let expected_text_count = texts.len();
483                let endpoint = build_openai_embeddings_endpoint(base_url);
484                let body = serde_json::json!({
485                    "input": texts,
486                    "model": model,
487                });
488
489                let raw = send_embedding_request(
490                    || {
491                        let mut request = client
492                            .post(&endpoint)
493                            .json(&body)
494                            .header("Content-Type", "application/json");
495
496                        if let Some(api_key) = api_key {
497                            request = request.header("Authorization", format!("Bearer {api_key}"));
498                        }
499
500                        request
501                    },
502                    "openai compatible",
503                )?;
504
505                #[derive(Deserialize)]
506                struct OpenAiResponse {
507                    data: Vec<OpenAiEmbeddingResult>,
508                }
509
510                #[derive(Deserialize)]
511                struct OpenAiEmbeddingResult {
512                    embedding: Vec<f32>,
513                    index: Option<u32>,
514                }
515
516                let parsed: OpenAiResponse = serde_json::from_str(&raw)
517                    .map_err(|error| format!("invalid openai compatible response: {error}"))?;
518                if parsed.data.len() != expected_text_count {
519                    return Err(format!(
520                        "openai compatible response returned {} embeddings for {} inputs",
521                        parsed.data.len(),
522                        expected_text_count
523                    ));
524                }
525
526                let mut vectors = vec![Vec::new(); parsed.data.len()];
527                for (i, item) in parsed.data.into_iter().enumerate() {
528                    let index = item.index.unwrap_or(i as u32) as usize;
529                    if index >= vectors.len() {
530                        return Err(
531                            "openai compatible response contains invalid vector index".to_string()
532                        );
533                    }
534                    vectors[index] = item.embedding;
535                }
536
537                for vector in &vectors {
538                    if vector.is_empty() {
539                        return Err(
540                            "openai compatible response contained missing vectors".to_string()
541                        );
542                    }
543                }
544
545                self.dimension = vectors.first().map(Vec::len);
546                Ok(vectors)
547            }
548            SemanticEmbeddingEngine::Ollama {
549                client,
550                model,
551                base_url,
552            } => {
553                let expected_text_count = texts.len();
554                let endpoint = build_ollama_embeddings_endpoint(base_url);
555
556                #[derive(Serialize)]
557                struct OllamaPayload<'a> {
558                    model: &'a str,
559                    input: Vec<String>,
560                }
561
562                let payload = OllamaPayload {
563                    model,
564                    input: texts,
565                };
566
567                let raw = send_embedding_request(
568                    || {
569                        client
570                            .post(&endpoint)
571                            .json(&payload)
572                            .header("Content-Type", "application/json")
573                    },
574                    "ollama",
575                )?;
576
577                #[derive(Deserialize)]
578                struct OllamaResponse {
579                    embeddings: Vec<Vec<f32>>,
580                }
581
582                let parsed: OllamaResponse = serde_json::from_str(&raw)
583                    .map_err(|error| format!("invalid ollama response: {error}"))?;
584                if parsed.embeddings.is_empty() {
585                    return Err("ollama response returned no embeddings".to_string());
586                }
587                if parsed.embeddings.len() != expected_text_count {
588                    return Err(format!(
589                        "ollama response returned {} embeddings for {} inputs",
590                        parsed.embeddings.len(),
591                        expected_text_count
592                    ));
593                }
594
595                let vectors = parsed.embeddings;
596                for vector in &vectors {
597                    if vector.is_empty() {
598                        return Err("ollama response contained empty embeddings".to_string());
599                    }
600                }
601
602                self.dimension = vectors.first().map(Vec::len);
603                Ok(vectors)
604            }
605        }
606    }
607}
608
609/// Pre-validate ONNX Runtime by attempting a raw dlopen before ort touches it.
610/// This catches broken/incompatible .so files without risking a panic in the ort crate.
611/// Also checks the runtime version via OrtGetApiBase if available.
612pub fn pre_validate_onnx_runtime() -> Result<(), String> {
613    let dylib_path = std::env::var("ORT_DYLIB_PATH").ok();
614
615    #[cfg(any(target_os = "linux", target_os = "macos"))]
616    {
617        #[cfg(target_os = "linux")]
618        let default_name = "libonnxruntime.so";
619        #[cfg(target_os = "macos")]
620        let default_name = "libonnxruntime.dylib";
621
622        let lib_name = dylib_path.as_deref().unwrap_or(default_name);
623
624        unsafe {
625            let c_name = std::ffi::CString::new(lib_name)
626                .map_err(|e| format!("invalid library path: {}", e))?;
627            let handle = libc::dlopen(c_name.as_ptr(), libc::RTLD_NOW);
628            if handle.is_null() {
629                let err = libc::dlerror();
630                let msg = if err.is_null() {
631                    "unknown dlopen error".to_string()
632                } else {
633                    std::ffi::CStr::from_ptr(err).to_string_lossy().into_owned()
634                };
635                return Err(format!(
636                    "ONNX Runtime not found. dlopen('{}') failed: {}. \
637                     Run `bunx @cortexkit/aft-opencode@latest doctor` to diagnose.",
638                    lib_name, msg
639                ));
640            }
641
642            // Try to detect the runtime version from the file path or soname.
643            // libonnxruntime.so.1.19.0, libonnxruntime.1.24.4.dylib, etc.
644            let detected_version = detect_ort_version_from_path(lib_name);
645
646            libc::dlclose(handle);
647
648            // Check version compatibility — we need 1.24.x
649            if let Some(ref version) = detected_version {
650                let parts: Vec<&str> = version.split('.').collect();
651                if let (Some(major), Some(minor)) = (
652                    parts.first().and_then(|s| s.parse::<u32>().ok()),
653                    parts.get(1).and_then(|s| s.parse::<u32>().ok()),
654                ) {
655                    if major != 1 || minor < 20 {
656                        return Err(format!(
657                            "ONNX Runtime version mismatch: found v{} at '{}', but AFT requires v1.20+. \
658                             Solutions:\n\
659                             1. Remove the old library and restart (AFT auto-downloads the correct version):\n\
660                             {}\n\
661                             2. Or install ONNX Runtime 1.24: https://github.com/microsoft/onnxruntime/releases/tag/v1.24.0\n\
662                             3. Run `bunx @cortexkit/aft-opencode@latest doctor` for full diagnostics.",
663                            version,
664                            lib_name,
665                            suggest_removal_command(lib_name),
666                        ));
667                    }
668                }
669            }
670        }
671    }
672
673    #[cfg(target_os = "windows")]
674    {
675        // On Windows, skip pre-validation — let ort handle LoadLibrary
676        let _ = dylib_path;
677    }
678
679    Ok(())
680}
681
682/// Try to extract the ORT version from the library filename or resolved symlink.
683/// Examples: "libonnxruntime.so.1.19.0" → "1.19.0", "libonnxruntime.1.24.4.dylib" → "1.24.4"
684fn detect_ort_version_from_path(lib_path: &str) -> Option<String> {
685    let path = std::path::Path::new(lib_path);
686
687    // Try the path as given, then follow symlinks
688    for candidate in [Some(path.to_path_buf()), std::fs::canonicalize(path).ok()]
689        .into_iter()
690        .flatten()
691    {
692        if let Some(name) = candidate.file_name().and_then(|n| n.to_str()) {
693            if let Some(version) = extract_version_from_filename(name) {
694                return Some(version);
695            }
696        }
697    }
698
699    // Also check for versioned siblings in the same directory
700    if let Some(parent) = path.parent() {
701        if let Ok(entries) = std::fs::read_dir(parent) {
702            for entry in entries.flatten() {
703                if let Some(name) = entry.file_name().to_str() {
704                    if name.starts_with("libonnxruntime") {
705                        if let Some(version) = extract_version_from_filename(name) {
706                            return Some(version);
707                        }
708                    }
709                }
710            }
711        }
712    }
713
714    None
715}
716
717/// Extract version from filenames like "libonnxruntime.so.1.19.0" or "libonnxruntime.1.24.4.dylib"
718fn extract_version_from_filename(name: &str) -> Option<String> {
719    // Match patterns: .so.X.Y.Z or .X.Y.Z.dylib or .X.Y.Z.so
720    let re = regex::Regex::new(r"(\d+\.\d+\.\d+)").ok()?;
721    re.find(name).map(|m| m.as_str().to_string())
722}
723
/// Returns an indented, copy-pastable instruction for removing the library at
/// `lib_path`.
///
/// For the default library names and anything under `/usr/local/lib` a
/// platform-specific instruction is returned on Linux, macOS and Windows;
/// in every other case the result is a plain `rm` of the given path.
fn suggest_removal_command(lib_path: &str) -> String {
    let is_system_install = matches!(lib_path, "libonnxruntime.so" | "libonnxruntime.dylib")
        || lib_path.starts_with("/usr/local/lib");

    if is_system_install {
        if cfg!(target_os = "linux") {
            return "   sudo rm /usr/local/lib/libonnxruntime* && sudo ldconfig".to_string();
        }
        if cfg!(target_os = "macos") {
            return "   sudo rm /usr/local/lib/libonnxruntime*".to_string();
        }
        if cfg!(target_os = "windows") {
            return "   Delete the ONNX Runtime DLL from your PATH".to_string();
        }
    }

    format!("   rm '{}'", lib_path)
}
738
739pub fn initialize_text_embedding(model: &str) -> Result<TextEmbedding, String> {
740    // Pre-validate before ort can panic on a bad library
741    pre_validate_onnx_runtime()?;
742
743    let selected_model = match model {
744        "all-MiniLM-L6-v2" | "all-minilm-l6-v2" => FastembedEmbeddingModel::AllMiniLML6V2,
745        _ => {
746            return Err(format!(
747                "unsupported fastembed model '{}'. Supported: all-MiniLM-L6-v2",
748                model
749            ))
750        }
751    };
752
753    TextEmbedding::try_new(InitOptions::new(selected_model)).map_err(format_embedding_init_error)
754}
755
/// Returns `true` when `message` looks like a failure to load ONNX Runtime.
///
/// A message qualifies if it starts (after leading whitespace) with
/// "ONNX Runtime not found.", or if — compared case-insensitively — it both
/// names the runtime and mentions a dynamic-library load failure.
pub fn is_onnx_runtime_unavailable(message: &str) -> bool {
    const RUNTIME_MARKERS: [&str; 3] = ["onnx runtime", "onnxruntime", "libonnxruntime"];
    const LOAD_FAILURE_MARKERS: [&str; 9] = [
        "shared library",
        "dynamic library",
        "failed to load",
        "could not load",
        "unable to load",
        "dlopen",
        "loadlibrary",
        "no such file",
        "not found",
    ];

    if message.trim_start().starts_with("ONNX Runtime not found.") {
        return true;
    }

    let lowered = message.to_ascii_lowercase();
    let mentions_any = |markers: &[&str]| markers.iter().any(|marker| lowered.contains(*marker));

    mentions_any(&RUNTIME_MARKERS) && mentions_any(&LOAD_FAILURE_MARKERS)
}
781
782fn format_embedding_init_error(error: impl Display) -> String {
783    let message = error.to_string();
784
785    if is_onnx_runtime_unavailable(&message) {
786        return format!("{ONNX_RUNTIME_INSTALL_HINT} Original error: {message}");
787    }
788
789    format!("failed to initialize semantic embedding model: {message}")
790}
791
/// A chunk of code ready for embedding — derived from a `Symbol` with context
/// enrichment. Carries both the text that is embedded and the metadata shown
/// alongside search results.
#[derive(Debug, Clone)]
pub struct SemanticChunk {
    /// Absolute file path
    pub file: PathBuf,
    /// Symbol name
    pub name: String,
    /// Symbol kind (function, class, struct, etc.)
    pub kind: SymbolKind,
    /// First line of the symbol (0-based internally, inclusive)
    pub start_line: u32,
    /// Last line of the symbol (0-based internally, inclusive)
    pub end_line: u32,
    /// Whether the symbol is exported
    pub exported: bool,
    /// The enriched text that gets embedded (scope + signature + body snippet)
    pub embed_text: String,
    /// Short code snippet for display in results
    pub snippet: String,
}
811
/// A stored embedding entry — chunk metadata + vector
#[derive(Debug)]
struct EmbeddingEntry {
    /// The chunk this vector was computed for.
    chunk: SemanticChunk,
    /// Embedding of the chunk.
    // NOTE(review): presumably always `SemanticIndex::dimension` long — not enforced in the visible code; confirm.
    vector: Vec<f32>,
}
818
/// The semantic index — stores embeddings for all symbols in a project
#[derive(Debug)]
pub struct SemanticIndex {
    /// One entry per embedded symbol chunk.
    entries: Vec<EmbeddingEntry>,
    /// Track which files are indexed and their mtime for staleness detection
    file_mtimes: HashMap<PathBuf, SystemTime>,
    /// Embedding dimension (384 for MiniLM-L6-v2)
    dimension: usize,
    /// Embedding setup associated with this index; `None` for an index
    /// created by `new`.
    fingerprint: Option<SemanticIndexFingerprint>,
}
829
/// Search result from a semantic query.
///
/// Mirrors the display-related fields of `SemanticChunk` and adds the match
/// score.
#[derive(Debug)]
pub struct SemanticResult {
    /// File containing the symbol.
    pub file: PathBuf,
    /// Symbol name.
    pub name: String,
    /// Symbol kind.
    pub kind: SymbolKind,
    /// First line of the symbol.
    pub start_line: u32,
    /// Last line of the symbol.
    pub end_line: u32,
    /// Whether the symbol is exported.
    pub exported: bool,
    /// Short code snippet for display.
    pub snippet: String,
    /// Match score for the query.
    // NOTE(review): presumably a similarity where higher is better — the scoring code is outside this view; confirm.
    pub score: f32,
}
842
843impl SemanticIndex {
844    pub fn new() -> Self {
845        Self {
846            entries: Vec::new(),
847            file_mtimes: HashMap::new(),
848            dimension: DEFAULT_DIMENSION, // MiniLM-L6-v2 default
849            fingerprint: None,
850        }
851    }
852
853    /// Number of embedded symbol entries.
854    pub fn entry_count(&self) -> usize {
855        self.entries.len()
856    }
857
858    /// Human-readable status label for the index.
859    pub fn status_label(&self) -> &'static str {
860        if self.entries.is_empty() {
861            "empty"
862        } else {
863            "ready"
864        }
865    }
866
867    fn collect_chunks(
868        project_root: &Path,
869        files: &[PathBuf],
870    ) -> (Vec<SemanticChunk>, HashMap<PathBuf, SystemTime>) {
871        let per_file: Vec<(PathBuf, SystemTime, Vec<SemanticChunk>)> = files
872            .par_iter()
873            .map_init(HashMap::new, |parsers, file| {
874                let mtime = std::fs::metadata(file)
875                    .and_then(|m| m.modified())
876                    .unwrap_or(SystemTime::UNIX_EPOCH);
877
878                let chunks = collect_file_chunks(project_root, file, parsers).unwrap_or_default();
879
880                (file.clone(), mtime, chunks)
881            })
882            .collect();
883
884        let mut chunks: Vec<SemanticChunk> = Vec::new();
885        let mut file_mtimes: HashMap<PathBuf, SystemTime> = HashMap::new();
886
887        for (file, mtime, file_chunks) in per_file {
888            file_mtimes.insert(file, mtime);
889            chunks.extend(file_chunks);
890        }
891
892        (chunks, file_mtimes)
893    }
894
895    fn build_from_chunks<F, P>(
896        chunks: Vec<SemanticChunk>,
897        file_mtimes: HashMap<PathBuf, SystemTime>,
898        embed_fn: &mut F,
899        max_batch_size: usize,
900        mut progress: Option<&mut P>,
901    ) -> Result<Self, String>
902    where
903        F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
904        P: FnMut(usize, usize),
905    {
906        let total_chunks = chunks.len();
907
908        if chunks.is_empty() {
909            return Ok(Self {
910                entries: Vec::new(),
911                file_mtimes,
912                dimension: DEFAULT_DIMENSION,
913                fingerprint: None,
914            });
915        }
916
917        // Embed in batches
918        let mut entries: Vec<EmbeddingEntry> = Vec::with_capacity(chunks.len());
919        let mut expected_dimension: Option<usize> = None;
920        let batch_size = max_batch_size.max(1);
921        for batch_start in (0..chunks.len()).step_by(batch_size) {
922            let batch_end = (batch_start + batch_size).min(chunks.len());
923            let batch_texts: Vec<String> = chunks[batch_start..batch_end]
924                .iter()
925                .map(|c| c.embed_text.clone())
926                .collect();
927
928            let vectors = embed_fn(batch_texts)?;
929            validate_embedding_batch(&vectors, batch_end - batch_start, "embedding backend")?;
930
931            // Track consistent dimension across all batches
932            if let Some(dim) = vectors.first().map(|v| v.len()) {
933                match expected_dimension {
934                    None => expected_dimension = Some(dim),
935                    Some(expected) if dim != expected => {
936                        return Err(format!(
937                            "embedding dimension changed across batches: expected {expected}, got {dim}"
938                        ));
939                    }
940                    _ => {}
941                }
942            }
943
944            for (i, vector) in vectors.into_iter().enumerate() {
945                let chunk_idx = batch_start + i;
946                entries.push(EmbeddingEntry {
947                    chunk: chunks[chunk_idx].clone(),
948                    vector,
949                });
950            }
951
952            if let Some(callback) = progress.as_mut() {
953                callback(entries.len(), total_chunks);
954            }
955        }
956
957        let dimension = entries
958            .first()
959            .map(|e| e.vector.len())
960            .unwrap_or(DEFAULT_DIMENSION);
961
962        Ok(Self {
963            entries,
964            file_mtimes,
965            dimension,
966            fingerprint: None,
967        })
968    }
969
970    /// Build the semantic index from a set of files using the provided embedding function.
971    /// `embed_fn` takes a batch of texts and returns a batch of embedding vectors.
972    pub fn build<F>(
973        project_root: &Path,
974        files: &[PathBuf],
975        embed_fn: &mut F,
976        max_batch_size: usize,
977    ) -> Result<Self, String>
978    where
979        F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
980    {
981        let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
982        Self::build_from_chunks(
983            chunks,
984            file_mtimes,
985            embed_fn,
986            max_batch_size,
987            Option::<&mut fn(usize, usize)>::None,
988        )
989    }
990
991    /// Build the semantic index and report embedding progress using entry counts.
992    pub fn build_with_progress<F, P>(
993        project_root: &Path,
994        files: &[PathBuf],
995        embed_fn: &mut F,
996        max_batch_size: usize,
997        progress: &mut P,
998    ) -> Result<Self, String>
999    where
1000        F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
1001        P: FnMut(usize, usize),
1002    {
1003        let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
1004        let total_chunks = chunks.len();
1005        progress(0, total_chunks);
1006        Self::build_from_chunks(
1007            chunks,
1008            file_mtimes,
1009            embed_fn,
1010            max_batch_size,
1011            Some(progress),
1012        )
1013    }
1014
1015    /// Search the index with a query embedding, returning top-K results sorted by relevance
1016    pub fn search(&self, query_vector: &[f32], top_k: usize) -> Vec<SemanticResult> {
1017        if self.entries.is_empty() || query_vector.len() != self.dimension {
1018            return Vec::new();
1019        }
1020
1021        let mut scored: Vec<(f32, usize)> = self
1022            .entries
1023            .iter()
1024            .enumerate()
1025            .map(|(i, entry)| (cosine_similarity(query_vector, &entry.vector), i))
1026            .collect();
1027
1028        // Sort descending by score
1029        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
1030
1031        scored
1032            .into_iter()
1033            .take(top_k)
1034            .filter(|(score, _)| *score > 0.0)
1035            .map(|(score, idx)| {
1036                let entry = &self.entries[idx];
1037                SemanticResult {
1038                    file: entry.chunk.file.clone(),
1039                    name: entry.chunk.name.clone(),
1040                    kind: entry.chunk.kind.clone(),
1041                    start_line: entry.chunk.start_line,
1042                    end_line: entry.chunk.end_line,
1043                    exported: entry.chunk.exported,
1044                    snippet: entry.chunk.snippet.clone(),
1045                    score,
1046                }
1047            })
1048            .collect()
1049    }
1050
1051    /// Number of indexed entries
1052    pub fn len(&self) -> usize {
1053        self.entries.len()
1054    }
1055
1056    /// Check if a file needs re-indexing based on mtime
1057    pub fn is_file_stale(&self, file: &Path) -> bool {
1058        match self.file_mtimes.get(file) {
1059            None => true,
1060            Some(stored_mtime) => match fs::metadata(file).and_then(|m| m.modified()) {
1061                Ok(current_mtime) => *stored_mtime != current_mtime,
1062                Err(_) => true,
1063            },
1064        }
1065    }
1066
1067    pub fn count_stale_files(&self) -> usize {
1068        self.file_mtimes
1069            .keys()
1070            .filter(|path| self.is_file_stale(path))
1071            .count()
1072    }
1073
1074    /// Remove entries for a specific file
1075    pub fn remove_file(&mut self, file: &Path) {
1076        self.invalidate_file(file);
1077    }
1078
1079    pub fn invalidate_file(&mut self, file: &Path) {
1080        self.entries.retain(|e| e.chunk.file != file);
1081        self.file_mtimes.remove(file);
1082    }
1083
1084    /// Get the embedding dimension
1085    pub fn dimension(&self) -> usize {
1086        self.dimension
1087    }
1088
1089    pub fn fingerprint(&self) -> Option<&SemanticIndexFingerprint> {
1090        self.fingerprint.as_ref()
1091    }
1092
1093    pub fn backend_label(&self) -> Option<&str> {
1094        self.fingerprint.as_ref().map(|f| f.backend.as_str())
1095    }
1096
1097    pub fn model_label(&self) -> Option<&str> {
1098        self.fingerprint.as_ref().map(|f| f.model.as_str())
1099    }
1100
1101    pub fn set_fingerprint(&mut self, fingerprint: SemanticIndexFingerprint) {
1102        self.fingerprint = Some(fingerprint);
1103    }
1104
1105    /// Write the semantic index to disk using atomic temp+rename pattern
1106    pub fn write_to_disk(&self, storage_dir: &Path, project_key: &str) {
1107        // Don't persist empty indexes — they would be loaded on next startup
1108        // and prevent a fresh build that might find files.
1109        if self.entries.is_empty() {
1110            slog_info!("skipping semantic index persistence (0 entries)");
1111            return;
1112        }
1113        let dir = storage_dir.join("semantic").join(project_key);
1114        if let Err(e) = fs::create_dir_all(&dir) {
1115            slog_warn!("failed to create semantic cache dir: {}", e);
1116            return;
1117        }
1118        let data_path = dir.join("semantic.bin");
1119        let tmp_path = dir.join("semantic.bin.tmp");
1120        let bytes = self.to_bytes();
1121        if let Err(e) = fs::write(&tmp_path, &bytes) {
1122            slog_warn!("failed to write semantic index: {}", e);
1123            let _ = fs::remove_file(&tmp_path);
1124            return;
1125        }
1126        if let Err(e) = fs::rename(&tmp_path, &data_path) {
1127            slog_warn!("failed to rename semantic index: {}", e);
1128            let _ = fs::remove_file(&tmp_path);
1129            return;
1130        }
1131        slog_info!(
1132            "semantic index persisted: {} entries, {:.1} KB",
1133            self.entries.len(),
1134            bytes.len() as f64 / 1024.0
1135        );
1136    }
1137
1138    /// Read the semantic index from disk
1139    pub fn read_from_disk(
1140        storage_dir: &Path,
1141        project_key: &str,
1142        expected_fingerprint: Option<&str>,
1143    ) -> Option<Self> {
1144        let data_path = storage_dir
1145            .join("semantic")
1146            .join(project_key)
1147            .join("semantic.bin");
1148        let file_len = usize::try_from(fs::metadata(&data_path).ok()?.len()).ok()?;
1149        if file_len < HEADER_BYTES_V1 {
1150            slog_warn!(
1151                "corrupt semantic index (too small: {} bytes), removing",
1152                file_len
1153            );
1154            let _ = fs::remove_file(&data_path);
1155            return None;
1156        }
1157
1158        let bytes = fs::read(&data_path).ok()?;
1159        let version = bytes[0];
1160        if version != SEMANTIC_INDEX_VERSION_V4 {
1161            slog_info!(
1162                "cached semantic index version {} is older than {}, rebuilding",
1163                version,
1164                SEMANTIC_INDEX_VERSION_V4
1165            );
1166            let _ = fs::remove_file(&data_path);
1167            return None;
1168        }
1169        match Self::from_bytes(&bytes) {
1170            Ok(index) => {
1171                if index.entries.is_empty() {
1172                    slog_info!("cached semantic index is empty, will rebuild");
1173                    let _ = fs::remove_file(&data_path);
1174                    return None;
1175                }
1176                if let Some(expected) = expected_fingerprint {
1177                    let matches = index
1178                        .fingerprint()
1179                        .map(|fingerprint| fingerprint.matches_expected(expected))
1180                        .unwrap_or(false);
1181                    if !matches {
1182                        slog_info!("cached semantic index fingerprint mismatch, rebuilding");
1183                        let _ = fs::remove_file(&data_path);
1184                        return None;
1185                    }
1186                }
1187                slog_info!(
1188                    "loaded semantic index from disk: {} entries",
1189                    index.entries.len()
1190                );
1191                Some(index)
1192            }
1193            Err(e) => {
1194                slog_warn!("corrupt semantic index, rebuilding: {}", e);
1195                let _ = fs::remove_file(&data_path);
1196                None
1197            }
1198        }
1199    }
1200
1201    /// Serialize the index to bytes for disk persistence
1202    pub fn to_bytes(&self) -> Vec<u8> {
1203        let mut buf = Vec::new();
1204        let fingerprint_bytes = self.fingerprint.as_ref().and_then(|fingerprint| {
1205            let encoded = fingerprint.as_string();
1206            if encoded.is_empty() {
1207                None
1208            } else {
1209                Some(encoded.into_bytes())
1210            }
1211        });
1212
1213        // Header: version(1) + dimension(4) + entry_count(4) + fingerprint_len(4) + fingerprint
1214        //
1215        // V4 (v0.16.0+) is the single write format. Layout matches V3:
1216        //   - fingerprint is always represented (absent ⇒ fingerprint_len=0,
1217        //     no bytes follow). Uniform format simplifies the reader.
1218        //   - mtimes stored as secs(u64) + subsec_nanos(u32). Preserves full
1219        //     APFS/ext4/NTFS precision so staleness checks survive restart
1220        //     round-trips.
1221        //
1222        // V1/V2 remain readable for backward compatibility (see from_bytes).
1223        // V3 loads as the same format but is rejected on disk so snippets are
1224        // rebuilt once after the symbol-range bug fixed in v0.16.0.
1225        let version = SEMANTIC_INDEX_VERSION_V4;
1226        buf.push(version);
1227        buf.extend_from_slice(&(self.dimension as u32).to_le_bytes());
1228        buf.extend_from_slice(&(self.entries.len() as u32).to_le_bytes());
1229        let fp_bytes_ref: &[u8] = fingerprint_bytes.as_deref().unwrap_or(&[]);
1230        buf.extend_from_slice(&(fp_bytes_ref.len() as u32).to_le_bytes());
1231        buf.extend_from_slice(fp_bytes_ref);
1232
1233        // File mtime table: count(4) + entries
1234        // V3 layout per entry: path_len(4) + path + secs(8) + subsec_nanos(4)
1235        buf.extend_from_slice(&(self.file_mtimes.len() as u32).to_le_bytes());
1236        for (path, mtime) in &self.file_mtimes {
1237            let path_bytes = path.to_string_lossy().as_bytes().to_vec();
1238            buf.extend_from_slice(&(path_bytes.len() as u32).to_le_bytes());
1239            buf.extend_from_slice(&path_bytes);
1240            let duration = mtime
1241                .duration_since(SystemTime::UNIX_EPOCH)
1242                .unwrap_or_default();
1243            buf.extend_from_slice(&duration.as_secs().to_le_bytes());
1244            buf.extend_from_slice(&duration.subsec_nanos().to_le_bytes());
1245        }
1246
1247        // Entries: each is metadata + vector
1248        for entry in &self.entries {
1249            let c = &entry.chunk;
1250
1251            // File path
1252            let file_bytes = c.file.to_string_lossy().as_bytes().to_vec();
1253            buf.extend_from_slice(&(file_bytes.len() as u32).to_le_bytes());
1254            buf.extend_from_slice(&file_bytes);
1255
1256            // Name
1257            let name_bytes = c.name.as_bytes();
1258            buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
1259            buf.extend_from_slice(name_bytes);
1260
1261            // Kind (1 byte)
1262            buf.push(symbol_kind_to_u8(&c.kind));
1263
1264            // Lines + exported
1265            buf.extend_from_slice(&(c.start_line as u32).to_le_bytes());
1266            buf.extend_from_slice(&(c.end_line as u32).to_le_bytes());
1267            buf.push(c.exported as u8);
1268
1269            // Snippet
1270            let snippet_bytes = c.snippet.as_bytes();
1271            buf.extend_from_slice(&(snippet_bytes.len() as u32).to_le_bytes());
1272            buf.extend_from_slice(snippet_bytes);
1273
1274            // Embed text
1275            let embed_bytes = c.embed_text.as_bytes();
1276            buf.extend_from_slice(&(embed_bytes.len() as u32).to_le_bytes());
1277            buf.extend_from_slice(embed_bytes);
1278
1279            // Vector (f32 array)
1280            for &val in &entry.vector {
1281                buf.extend_from_slice(&val.to_le_bytes());
1282            }
1283        }
1284
1285        buf
1286    }
1287
1288    /// Deserialize the index from bytes
1289    pub fn from_bytes(data: &[u8]) -> Result<Self, String> {
1290        let mut pos = 0;
1291
1292        if data.len() < HEADER_BYTES_V1 {
1293            return Err("data too short".to_string());
1294        }
1295
1296        let version = data[pos];
1297        pos += 1;
1298        if version != SEMANTIC_INDEX_VERSION_V1
1299            && version != SEMANTIC_INDEX_VERSION_V2
1300            && version != SEMANTIC_INDEX_VERSION_V3
1301            && version != SEMANTIC_INDEX_VERSION_V4
1302        {
1303            return Err(format!("unsupported version: {}", version));
1304        }
1305        // V2, V3, and V4 share the same header layout (V3/V4 only differ from
1306        // V2 in the per-mtime entry layout): version(1) + dimension(4) +
1307        // entry_count(4) + fingerprint_len(4) + fingerprint bytes.
1308        if (version == SEMANTIC_INDEX_VERSION_V2
1309            || version == SEMANTIC_INDEX_VERSION_V3
1310            || version == SEMANTIC_INDEX_VERSION_V4)
1311            && data.len() < HEADER_BYTES_V2
1312        {
1313            return Err("data too short for semantic index v2/v3/v4 header".to_string());
1314        }
1315
1316        let dimension = read_u32(data, &mut pos)? as usize;
1317        let entry_count = read_u32(data, &mut pos)? as usize;
1318        if dimension == 0 || dimension > MAX_DIMENSION {
1319            return Err(format!("invalid embedding dimension: {}", dimension));
1320        }
1321        if entry_count > MAX_ENTRIES {
1322            return Err(format!("too many semantic index entries: {}", entry_count));
1323        }
1324
1325        // Fingerprint handling:
1326        //   - V1: no fingerprint field at all.
1327        //   - V2: fingerprint_len + fingerprint bytes; always present (writer
1328        //     only emitted V2 when fingerprint was Some).
1329        //   - V3: fingerprint_len always present; fingerprint_len==0 ⇒ None.
1330        let has_fingerprint_field = version == SEMANTIC_INDEX_VERSION_V2
1331            || version == SEMANTIC_INDEX_VERSION_V3
1332            || version == SEMANTIC_INDEX_VERSION_V4;
1333        let fingerprint = if has_fingerprint_field {
1334            let fingerprint_len = read_u32(data, &mut pos)? as usize;
1335            if pos + fingerprint_len > data.len() {
1336                return Err("unexpected end of data reading fingerprint".to_string());
1337            }
1338            if fingerprint_len == 0 {
1339                None
1340            } else {
1341                let raw = String::from_utf8_lossy(&data[pos..pos + fingerprint_len]).to_string();
1342                pos += fingerprint_len;
1343                Some(
1344                    serde_json::from_str::<SemanticIndexFingerprint>(&raw)
1345                        .map_err(|error| format!("invalid semantic fingerprint: {error}"))?,
1346                )
1347            }
1348        } else {
1349            None
1350        };
1351
1352        // File mtimes
1353        let mtime_count = read_u32(data, &mut pos)? as usize;
1354        if mtime_count > MAX_ENTRIES {
1355            return Err(format!("too many semantic file mtimes: {}", mtime_count));
1356        }
1357
1358        let vector_bytes = entry_count
1359            .checked_mul(dimension)
1360            .and_then(|count| count.checked_mul(F32_BYTES))
1361            .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
1362        if vector_bytes > data.len().saturating_sub(pos) {
1363            return Err("semantic index vectors exceed available data".to_string());
1364        }
1365
1366        let mut file_mtimes = HashMap::with_capacity(mtime_count);
1367        for _ in 0..mtime_count {
1368            let path = read_string(data, &mut pos)?;
1369            let secs = read_u64(data, &mut pos)?;
1370            // V3 persists subsec_nanos alongside secs so staleness checks
1371            // survive restart round-trips. V1/V2 load with 0 nanos, which
1372            // causes one rebuild on upgrade (they never matched live APFS
1373            // mtimes anyway — the bug v0.15.2 fixes). After that rebuild,
1374            // the cache is persisted as V3 and stabilises.
1375            let nanos =
1376                if version == SEMANTIC_INDEX_VERSION_V3 || version == SEMANTIC_INDEX_VERSION_V4 {
1377                    read_u32(data, &mut pos)?
1378                } else {
1379                    0
1380                };
1381            // Hardening against corrupt / maliciously crafted cache files
1382            // (v0.15.2). `Duration::new(secs, nanos)` can panic when the
1383            // nanosecond carry overflows the second counter, and
1384            // `SystemTime + Duration` can panic on carry past the platform's
1385            // upper bound. Explicit validation keeps a corrupted semantic.bin
1386            // from taking down the whole aft process.
1387            if nanos >= 1_000_000_000 {
1388                return Err(format!(
1389                    "invalid semantic mtime: nanos {} >= 1_000_000_000",
1390                    nanos
1391                ));
1392            }
1393            let duration = std::time::Duration::new(secs, nanos);
1394            let mtime = SystemTime::UNIX_EPOCH
1395                .checked_add(duration)
1396                .ok_or_else(|| {
1397                    format!(
1398                        "invalid semantic mtime: secs={} nanos={} overflows SystemTime",
1399                        secs, nanos
1400                    )
1401                })?;
1402            file_mtimes.insert(PathBuf::from(path), mtime);
1403        }
1404
1405        // Entries
1406        let mut entries = Vec::with_capacity(entry_count);
1407        for _ in 0..entry_count {
1408            let file = PathBuf::from(read_string(data, &mut pos)?);
1409            let name = read_string(data, &mut pos)?;
1410
1411            if pos >= data.len() {
1412                return Err("unexpected end of data".to_string());
1413            }
1414            let kind = u8_to_symbol_kind(data[pos]);
1415            pos += 1;
1416
1417            let start_line = read_u32(data, &mut pos)?;
1418            let end_line = read_u32(data, &mut pos)?;
1419
1420            if pos >= data.len() {
1421                return Err("unexpected end of data".to_string());
1422            }
1423            let exported = data[pos] != 0;
1424            pos += 1;
1425
1426            let snippet = read_string(data, &mut pos)?;
1427            let embed_text = read_string(data, &mut pos)?;
1428
1429            // Vector
1430            let vec_bytes = dimension
1431                .checked_mul(F32_BYTES)
1432                .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
1433            if pos + vec_bytes > data.len() {
1434                return Err("unexpected end of data reading vector".to_string());
1435            }
1436            let mut vector = Vec::with_capacity(dimension);
1437            for _ in 0..dimension {
1438                let bytes = [data[pos], data[pos + 1], data[pos + 2], data[pos + 3]];
1439                vector.push(f32::from_le_bytes(bytes));
1440                pos += 4;
1441            }
1442
1443            entries.push(EmbeddingEntry {
1444                chunk: SemanticChunk {
1445                    file,
1446                    name,
1447                    kind,
1448                    start_line,
1449                    end_line,
1450                    exported,
1451                    embed_text,
1452                    snippet,
1453                },
1454                vector,
1455            });
1456        }
1457
1458        Ok(Self {
1459            entries,
1460            file_mtimes,
1461            dimension,
1462            fingerprint,
1463        })
1464    }
1465}
1466
1467/// Build enriched embedding text from a symbol with cAST-style context
1468fn build_embed_text(symbol: &Symbol, source: &str, file: &Path, project_root: &Path) -> String {
1469    let relative = file
1470        .strip_prefix(project_root)
1471        .unwrap_or(file)
1472        .to_string_lossy();
1473
1474    let kind_label = match &symbol.kind {
1475        SymbolKind::Function => "function",
1476        SymbolKind::Class => "class",
1477        SymbolKind::Method => "method",
1478        SymbolKind::Struct => "struct",
1479        SymbolKind::Interface => "interface",
1480        SymbolKind::Enum => "enum",
1481        SymbolKind::TypeAlias => "type",
1482        SymbolKind::Variable => "variable",
1483        SymbolKind::Heading => "heading",
1484    };
1485
1486    // Build: "file:relative/path kind:function name:validateAuth signature:fn validateAuth(token: &str) -> bool"
1487    let mut text = format!("file:{} kind:{} name:{}", relative, kind_label, symbol.name);
1488
1489    if let Some(sig) = &symbol.signature {
1490        text.push_str(&format!(" signature:{}", sig));
1491    }
1492
1493    // Add body snippet (first ~300 chars of symbol body)
1494    let lines: Vec<&str> = source.lines().collect();
1495    let start = (symbol.range.start_line as usize).min(lines.len());
1496    // range.end_line is inclusive 0-based; +1 makes it an exclusive slice bound.
1497    let end = (symbol.range.end_line as usize + 1).min(lines.len());
1498    if start < end {
1499        let body: String = lines[start..end]
1500            .iter()
1501            .take(15) // max 15 lines
1502            .copied()
1503            .collect::<Vec<&str>>()
1504            .join("\n");
1505        let snippet = if body.len() > 300 {
1506            format!("{}...", &body[..body.floor_char_boundary(300)])
1507        } else {
1508            body
1509        };
1510        text.push_str(&format!(" body:{}", snippet));
1511    }
1512
1513    text
1514}
1515
1516fn parser_for(
1517    parsers: &mut HashMap<crate::parser::LangId, Parser>,
1518    lang: crate::parser::LangId,
1519) -> Result<&mut Parser, String> {
1520    use std::collections::hash_map::Entry;
1521
1522    match parsers.entry(lang) {
1523        Entry::Occupied(entry) => Ok(entry.into_mut()),
1524        Entry::Vacant(entry) => {
1525            let grammar = grammar_for(lang);
1526            let mut parser = Parser::new();
1527            parser
1528                .set_language(&grammar)
1529                .map_err(|error| error.to_string())?;
1530            Ok(entry.insert(parser))
1531        }
1532    }
1533}
1534
1535fn collect_file_chunks(
1536    project_root: &Path,
1537    file: &Path,
1538    parsers: &mut HashMap<crate::parser::LangId, Parser>,
1539) -> Result<Vec<SemanticChunk>, String> {
1540    let lang = detect_language(file).ok_or_else(|| "unsupported file extension".to_string())?;
1541    let source = std::fs::read_to_string(file).map_err(|error| error.to_string())?;
1542    let tree = parser_for(parsers, lang)?
1543        .parse(&source, None)
1544        .ok_or_else(|| format!("tree-sitter parse returned None for {}", file.display()))?;
1545    let symbols =
1546        extract_symbols_from_tree(&source, &tree, lang).map_err(|error| error.to_string())?;
1547
1548    Ok(symbols_to_chunks(file, &symbols, &source, project_root))
1549}
1550
1551/// Build a display snippet from a symbol's source
1552fn build_snippet(symbol: &Symbol, source: &str) -> String {
1553    let lines: Vec<&str> = source.lines().collect();
1554    let start = (symbol.range.start_line as usize).min(lines.len());
1555    // range.end_line is inclusive 0-based; +1 makes it an exclusive slice bound.
1556    let end = (symbol.range.end_line as usize + 1).min(lines.len());
1557    if start < end {
1558        let snippet_lines: Vec<&str> = lines[start..end].iter().take(5).copied().collect();
1559        let mut snippet = snippet_lines.join("\n");
1560        if end - start > 5 {
1561            snippet.push_str("\n  ...");
1562        }
1563        if snippet.len() > 300 {
1564            snippet = format!("{}...", &snippet[..snippet.floor_char_boundary(300)]);
1565        }
1566        snippet
1567    } else {
1568        String::new()
1569    }
1570}
1571
1572/// Convert symbols to semantic chunks with enriched context
1573fn symbols_to_chunks(
1574    file: &Path,
1575    symbols: &[Symbol],
1576    source: &str,
1577    project_root: &Path,
1578) -> Vec<SemanticChunk> {
1579    let mut chunks = Vec::new();
1580
1581    for symbol in symbols {
1582        // Skip Markdown / HTML heading chunks: empirically they dominate result
1583        // lists even for code-shaped queries because heading prose embeds well.
1584        // Agents querying for code lose the actual matches under doc noise.
1585        // README/docs queries are still served by grep on the same files.
1586        if matches!(symbol.kind, SymbolKind::Heading) {
1587            continue;
1588        }
1589
1590        // Skip very small symbols (single-line variables, etc.)
1591        let line_count = symbol
1592            .range
1593            .end_line
1594            .saturating_sub(symbol.range.start_line)
1595            + 1;
1596        if line_count < 2 && !matches!(symbol.kind, SymbolKind::Variable) {
1597            continue;
1598        }
1599
1600        let embed_text = build_embed_text(symbol, source, file, project_root);
1601        let snippet = build_snippet(symbol, source);
1602
1603        chunks.push(SemanticChunk {
1604            file: file.to_path_buf(),
1605            name: symbol.name.clone(),
1606            kind: symbol.kind.clone(),
1607            start_line: symbol.range.start_line,
1608            end_line: symbol.range.end_line,
1609            exported: symbol.exported,
1610            embed_text,
1611            snippet,
1612        });
1613
1614        // Note: Nested symbols are handled separately by the outline system
1615        // Each symbol is indexed individually
1616    }
1617
1618    chunks
1619}
1620
/// Cosine similarity of `a` and `b`.
///
/// Returns `0.0` when the slices differ in length or when either vector has
/// zero magnitude (which includes empty input).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let (mut dot, mut sq_a, mut sq_b) = (0.0f32, 0.0f32, 0.0f32);
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let magnitude = sq_a.sqrt() * sq_b.sqrt();
    if magnitude == 0.0 {
        return 0.0;
    }
    dot / magnitude
}
1644
1645// Serialization helpers
1646fn symbol_kind_to_u8(kind: &SymbolKind) -> u8 {
1647    match kind {
1648        SymbolKind::Function => 0,
1649        SymbolKind::Class => 1,
1650        SymbolKind::Method => 2,
1651        SymbolKind::Struct => 3,
1652        SymbolKind::Interface => 4,
1653        SymbolKind::Enum => 5,
1654        SymbolKind::TypeAlias => 6,
1655        SymbolKind::Variable => 7,
1656        SymbolKind::Heading => 8,
1657    }
1658}
1659
1660fn u8_to_symbol_kind(v: u8) -> SymbolKind {
1661    match v {
1662        0 => SymbolKind::Function,
1663        1 => SymbolKind::Class,
1664        2 => SymbolKind::Method,
1665        3 => SymbolKind::Struct,
1666        4 => SymbolKind::Interface,
1667        5 => SymbolKind::Enum,
1668        6 => SymbolKind::TypeAlias,
1669        7 => SymbolKind::Variable,
1670        _ => SymbolKind::Heading,
1671    }
1672}
1673
/// Reads a little-endian `u32` from `data` at `*pos` and advances `*pos` past it.
///
/// # Errors
///
/// Returns an error, leaving `*pos` untouched, when fewer than four bytes
/// remain at `*pos` or when `*pos + 4` would overflow `usize`.
fn read_u32(data: &[u8], pos: &mut usize) -> Result<u32, String> {
    let start = *pos;
    // `checked_add` + `get` turn an out-of-range position into an error
    // instead of an arithmetic-overflow or indexing panic.
    let window = start.checked_add(4).and_then(|end| data.get(start..end));
    match window {
        Some(&[b0, b1, b2, b3]) => {
            *pos = start + 4;
            Ok(u32::from_le_bytes([b0, b1, b2, b3]))
        }
        _ => Err("unexpected end of data reading u32".to_string()),
    }
}
1682
/// Reads a little-endian `u64` from `data` at `*pos` and advances `*pos` past it.
///
/// # Errors
///
/// Returns an error, leaving `*pos` untouched, when fewer than eight bytes
/// remain at `*pos` or when `*pos + 8` would overflow `usize`.
fn read_u64(data: &[u8], pos: &mut usize) -> Result<u64, String> {
    const WIDTH: usize = std::mem::size_of::<u64>();

    let start = *pos;
    // `checked_add` + `get` turn an out-of-range position into an error
    // instead of an arithmetic-overflow or indexing panic.
    let window = start
        .checked_add(WIDTH)
        .and_then(|end| data.get(start..end))
        .ok_or_else(|| "unexpected end of data reading u64".to_string())?;

    // `window` is exactly WIDTH bytes long, so this copy cannot panic.
    let mut bytes = [0u8; WIDTH];
    bytes.copy_from_slice(window);
    *pos = start + WIDTH;
    Ok(u64::from_le_bytes(bytes))
}
1691
1692fn read_string(data: &[u8], pos: &mut usize) -> Result<String, String> {
1693    let len = read_u32(data, pos)? as usize;
1694    if *pos + len > data.len() {
1695        return Err("unexpected end of data reading string".to_string());
1696    }
1697    let s = String::from_utf8_lossy(&data[*pos..*pos + len]).to_string();
1698    *pos += len;
1699    Ok(s)
1700}
1701
1702#[cfg(test)]
1703mod tests {
1704    use super::*;
1705    use crate::config::{SemanticBackend, SemanticBackendConfig};
1706    use crate::parser::FileParser;
1707    use std::io::{Read, Write};
1708    use std::net::TcpListener;
1709    use std::thread;
1710
1711    fn start_mock_http_server<F>(handler: F) -> (String, thread::JoinHandle<()>)
1712    where
1713        F: Fn(String, String, String) -> String + Send + 'static,
1714    {
1715        let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
1716        let addr = listener.local_addr().expect("local addr");
1717        let handle = thread::spawn(move || {
1718            let (mut stream, _) = listener.accept().expect("accept request");
1719            let mut buf = Vec::new();
1720            let mut chunk = [0u8; 4096];
1721            let mut header_end = None;
1722            let mut content_length = 0usize;
1723            loop {
1724                let n = stream.read(&mut chunk).expect("read request");
1725                if n == 0 {
1726                    break;
1727                }
1728                buf.extend_from_slice(&chunk[..n]);
1729                if header_end.is_none() {
1730                    if let Some(pos) = buf.windows(4).position(|window| window == b"\r\n\r\n") {
1731                        header_end = Some(pos + 4);
1732                        let headers = String::from_utf8_lossy(&buf[..pos + 4]);
1733                        for line in headers.lines() {
1734                            if let Some(value) = line.strip_prefix("Content-Length:") {
1735                                content_length = value.trim().parse::<usize>().unwrap_or(0);
1736                            }
1737                        }
1738                    }
1739                }
1740                if let Some(end) = header_end {
1741                    if buf.len() >= end + content_length {
1742                        break;
1743                    }
1744                }
1745            }
1746
1747            let end = header_end.expect("header terminator");
1748            let request = String::from_utf8_lossy(&buf[..end]).to_string();
1749            let body = String::from_utf8_lossy(&buf[end..end + content_length]).to_string();
1750            let mut lines = request.lines();
1751            let request_line = lines.next().expect("request line").to_string();
1752            let path = request_line
1753                .split_whitespace()
1754                .nth(1)
1755                .expect("request path")
1756                .to_string();
1757            let response_body = handler(request_line, path, body);
1758            let response = format!(
1759                "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
1760                response_body.len(),
1761                response_body
1762            );
1763            stream
1764                .write_all(response.as_bytes())
1765                .expect("write response");
1766        });
1767
1768        (format!("http://{}", addr), handle)
1769    }
1770
1771    #[test]
1772    fn test_cosine_similarity_identical() {
1773        let a = vec![1.0, 0.0, 0.0];
1774        let b = vec![1.0, 0.0, 0.0];
1775        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
1776    }
1777
1778    #[test]
1779    fn test_cosine_similarity_orthogonal() {
1780        let a = vec![1.0, 0.0, 0.0];
1781        let b = vec![0.0, 1.0, 0.0];
1782        assert!(cosine_similarity(&a, &b).abs() < 0.001);
1783    }
1784
1785    #[test]
1786    fn test_cosine_similarity_opposite() {
1787        let a = vec![1.0, 0.0, 0.0];
1788        let b = vec![-1.0, 0.0, 0.0];
1789        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 0.001);
1790    }
1791
1792    #[test]
1793    fn test_serialization_roundtrip() {
1794        let mut index = SemanticIndex::new();
1795        index.entries.push(EmbeddingEntry {
1796            chunk: SemanticChunk {
1797                file: PathBuf::from("/src/main.rs"),
1798                name: "handle_request".to_string(),
1799                kind: SymbolKind::Function,
1800                start_line: 10,
1801                end_line: 25,
1802                exported: true,
1803                embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
1804                snippet: "fn handle_request() {\n  // ...\n}".to_string(),
1805            },
1806            vector: vec![0.1, 0.2, 0.3, 0.4],
1807        });
1808        index.dimension = 4;
1809        index
1810            .file_mtimes
1811            .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
1812        index.set_fingerprint(SemanticIndexFingerprint {
1813            backend: "fastembed".to_string(),
1814            model: "all-MiniLM-L6-v2".to_string(),
1815            base_url: FALLBACK_BACKEND.to_string(),
1816            dimension: 4,
1817        });
1818
1819        let bytes = index.to_bytes();
1820        let restored = SemanticIndex::from_bytes(&bytes).unwrap();
1821
1822        assert_eq!(restored.entries.len(), 1);
1823        assert_eq!(restored.entries[0].chunk.name, "handle_request");
1824        assert_eq!(restored.entries[0].vector, vec![0.1, 0.2, 0.3, 0.4]);
1825        assert_eq!(restored.dimension, 4);
1826        assert_eq!(restored.backend_label(), Some("fastembed"));
1827        assert_eq!(restored.model_label(), Some("all-MiniLM-L6-v2"));
1828    }
1829
1830    #[test]
1831    fn test_search_top_k() {
1832        let mut index = SemanticIndex::new();
1833        index.dimension = 3;
1834
1835        // Add entries with known vectors
1836        for (i, name) in ["auth", "database", "handler"].iter().enumerate() {
1837            let mut vec = vec![0.0f32; 3];
1838            vec[i] = 1.0; // orthogonal vectors
1839            index.entries.push(EmbeddingEntry {
1840                chunk: SemanticChunk {
1841                    file: PathBuf::from("/src/lib.rs"),
1842                    name: name.to_string(),
1843                    kind: SymbolKind::Function,
1844                    start_line: (i * 10 + 1) as u32,
1845                    end_line: (i * 10 + 5) as u32,
1846                    exported: true,
1847                    embed_text: format!("kind:function name:{}", name),
1848                    snippet: format!("fn {}() {{}}", name),
1849                },
1850                vector: vec,
1851            });
1852        }
1853
1854        // Query aligned with "auth" (index 0)
1855        let query = vec![0.9, 0.1, 0.0];
1856        let results = index.search(&query, 2);
1857
1858        assert_eq!(results.len(), 2);
1859        assert_eq!(results[0].name, "auth"); // highest score
1860        assert!(results[0].score > results[1].score);
1861    }
1862
1863    #[test]
1864    fn test_empty_index_search() {
1865        let index = SemanticIndex::new();
1866        let results = index.search(&[0.1, 0.2, 0.3], 10);
1867        assert!(results.is_empty());
1868    }
1869
1870    #[test]
1871    fn single_line_symbol_builds_non_empty_snippet() {
1872        let symbol = Symbol {
1873            name: "answer".to_string(),
1874            kind: SymbolKind::Variable,
1875            range: crate::symbols::Range {
1876                start_line: 0,
1877                start_col: 0,
1878                end_line: 0,
1879                end_col: 24,
1880            },
1881            signature: Some("const answer = 42".to_string()),
1882            scope_chain: Vec::new(),
1883            exported: true,
1884            parent: None,
1885        };
1886        let source = "export const answer = 42;\n";
1887
1888        let snippet = build_snippet(&symbol, source);
1889
1890        assert_eq!(snippet, "export const answer = 42;");
1891    }
1892
1893    #[test]
1894    fn optimized_file_chunk_collection_matches_file_parser_path() {
1895        let project_root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
1896        let file = project_root.join("src/semantic_index.rs");
1897        let source = std::fs::read_to_string(&file).unwrap();
1898
1899        let mut legacy_parser = FileParser::new();
1900        let legacy_symbols = legacy_parser.extract_symbols(&file).unwrap();
1901        let legacy_chunks = symbols_to_chunks(&file, &legacy_symbols, &source, &project_root);
1902
1903        let mut parsers = HashMap::new();
1904        let optimized_chunks = collect_file_chunks(&project_root, &file, &mut parsers).unwrap();
1905
1906        assert_eq!(
1907            chunk_fingerprint(&optimized_chunks),
1908            chunk_fingerprint(&legacy_chunks)
1909        );
1910    }
1911
1912    fn chunk_fingerprint(
1913        chunks: &[SemanticChunk],
1914    ) -> Vec<(String, SymbolKind, u32, u32, bool, String, String)> {
1915        chunks
1916            .iter()
1917            .map(|chunk| {
1918                (
1919                    chunk.name.clone(),
1920                    chunk.kind.clone(),
1921                    chunk.start_line,
1922                    chunk.end_line,
1923                    chunk.exported,
1924                    chunk.embed_text.clone(),
1925                    chunk.snippet.clone(),
1926                )
1927            })
1928            .collect()
1929    }
1930
1931    #[test]
1932    fn rejects_oversized_dimension_during_deserialization() {
1933        let mut bytes = Vec::new();
1934        bytes.push(1u8);
1935        bytes.extend_from_slice(&((MAX_DIMENSION as u32) + 1).to_le_bytes());
1936        bytes.extend_from_slice(&0u32.to_le_bytes());
1937        bytes.extend_from_slice(&0u32.to_le_bytes());
1938
1939        assert!(SemanticIndex::from_bytes(&bytes).is_err());
1940    }
1941
1942    #[test]
1943    fn rejects_oversized_entry_count_during_deserialization() {
1944        let mut bytes = Vec::new();
1945        bytes.push(1u8);
1946        bytes.extend_from_slice(&(DEFAULT_DIMENSION as u32).to_le_bytes());
1947        bytes.extend_from_slice(&((MAX_ENTRIES as u32) + 1).to_le_bytes());
1948        bytes.extend_from_slice(&0u32.to_le_bytes());
1949
1950        assert!(SemanticIndex::from_bytes(&bytes).is_err());
1951    }
1952
1953    #[test]
1954    fn invalidate_file_removes_entries_and_mtime() {
1955        let target = PathBuf::from("/src/main.rs");
1956        let mut index = SemanticIndex::new();
1957        index.entries.push(EmbeddingEntry {
1958            chunk: SemanticChunk {
1959                file: target.clone(),
1960                name: "main".to_string(),
1961                kind: SymbolKind::Function,
1962                start_line: 0,
1963                end_line: 1,
1964                exported: false,
1965                embed_text: "main".to_string(),
1966                snippet: "fn main() {}".to_string(),
1967            },
1968            vector: vec![1.0; DEFAULT_DIMENSION],
1969        });
1970        index
1971            .file_mtimes
1972            .insert(target.clone(), SystemTime::UNIX_EPOCH);
1973
1974        index.invalidate_file(&target);
1975
1976        assert!(index.entries.is_empty());
1977        assert!(!index.file_mtimes.contains_key(&target));
1978    }
1979
1980    #[test]
1981    fn detects_missing_onnx_runtime_from_dynamic_load_error() {
1982        let message = "Failed to load ONNX Runtime shared library libonnxruntime.dylib via dlopen: no such file";
1983
1984        assert!(is_onnx_runtime_unavailable(message));
1985    }
1986
1987    #[test]
1988    fn formats_missing_onnx_runtime_with_install_hint() {
1989        let message = format_embedding_init_error(
1990            "Failed to load ONNX Runtime shared library libonnxruntime.so via dlopen: no such file",
1991        );
1992
1993        assert!(message.starts_with("ONNX Runtime not found. Install via:"));
1994        assert!(message.contains("Original error:"));
1995    }
1996
1997    #[test]
1998    fn openai_compatible_backend_embeds_with_mock_server() {
1999        let (base_url, handle) = start_mock_http_server(|request_line, path, _body| {
2000            assert!(request_line.starts_with("POST "));
2001            assert_eq!(path, "/v1/embeddings");
2002            "{\"data\":[{\"embedding\":[0.1,0.2,0.3],\"index\":0},{\"embedding\":[0.4,0.5,0.6],\"index\":1}]}".to_string()
2003        });
2004
2005        let config = SemanticBackendConfig {
2006            backend: SemanticBackend::OpenAiCompatible,
2007            model: "test-embedding".to_string(),
2008            base_url: Some(base_url),
2009            api_key_env: None,
2010            timeout_ms: 5_000,
2011            max_batch_size: 64,
2012        };
2013
2014        let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();
2015        let vectors = model
2016            .embed(vec!["hello".to_string(), "world".to_string()])
2017            .unwrap();
2018
2019        assert_eq!(vectors, vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]]);
2020        handle.join().unwrap();
2021    }
2022
2023    #[test]
2024    fn ollama_backend_embeds_with_mock_server() {
2025        let (base_url, handle) = start_mock_http_server(|request_line, path, _body| {
2026            assert!(request_line.starts_with("POST "));
2027            assert_eq!(path, "/api/embed");
2028            "{\"embeddings\":[[0.7,0.8,0.9],[1.0,1.1,1.2]]}".to_string()
2029        });
2030
2031        let config = SemanticBackendConfig {
2032            backend: SemanticBackend::Ollama,
2033            model: "embeddinggemma".to_string(),
2034            base_url: Some(base_url),
2035            api_key_env: None,
2036            timeout_ms: 5_000,
2037            max_batch_size: 64,
2038        };
2039
2040        let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();
2041        let vectors = model
2042            .embed(vec!["hello".to_string(), "world".to_string()])
2043            .unwrap();
2044
2045        assert_eq!(vectors, vec![vec![0.7, 0.8, 0.9], vec![1.0, 1.1, 1.2]]);
2046        handle.join().unwrap();
2047    }
2048
2049    #[test]
2050    fn read_from_disk_rejects_fingerprint_mismatch() {
2051        let storage = tempfile::tempdir().unwrap();
2052        let project_key = "proj";
2053
2054        let mut index = SemanticIndex::new();
2055        index.entries.push(EmbeddingEntry {
2056            chunk: SemanticChunk {
2057                file: PathBuf::from("/src/main.rs"),
2058                name: "handle_request".to_string(),
2059                kind: SymbolKind::Function,
2060                start_line: 10,
2061                end_line: 25,
2062                exported: true,
2063                embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
2064                snippet: "fn handle_request() {}".to_string(),
2065            },
2066            vector: vec![0.1, 0.2, 0.3],
2067        });
2068        index.dimension = 3;
2069        index
2070            .file_mtimes
2071            .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
2072        index.set_fingerprint(SemanticIndexFingerprint {
2073            backend: "openai_compatible".to_string(),
2074            model: "test-embedding".to_string(),
2075            base_url: "http://127.0.0.1:1234/v1".to_string(),
2076            dimension: 3,
2077        });
2078        index.write_to_disk(storage.path(), project_key);
2079
2080        let matching = index.fingerprint().unwrap().as_string();
2081        assert!(
2082            SemanticIndex::read_from_disk(storage.path(), project_key, Some(&matching)).is_some()
2083        );
2084
2085        let mismatched = SemanticIndexFingerprint {
2086            backend: "ollama".to_string(),
2087            model: "embeddinggemma".to_string(),
2088            base_url: "http://127.0.0.1:11434".to_string(),
2089            dimension: 3,
2090        }
2091        .as_string();
2092        assert!(
2093            SemanticIndex::read_from_disk(storage.path(), project_key, Some(&mismatched)).is_none()
2094        );
2095    }
2096
2097    #[test]
2098    fn read_from_disk_rejects_v3_cache_for_snippet_rebuild() {
2099        let storage = tempfile::tempdir().unwrap();
2100        let project_key = "proj-v3";
2101        let dir = storage.path().join("semantic").join(project_key);
2102        fs::create_dir_all(&dir).unwrap();
2103
2104        let mut index = SemanticIndex::new();
2105        index.entries.push(EmbeddingEntry {
2106            chunk: SemanticChunk {
2107                file: PathBuf::from("/src/main.rs"),
2108                name: "handle_request".to_string(),
2109                kind: SymbolKind::Function,
2110                start_line: 0,
2111                end_line: 0,
2112                exported: true,
2113                embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
2114                snippet: "fn handle_request() {}".to_string(),
2115            },
2116            vector: vec![0.1, 0.2, 0.3],
2117        });
2118        index.dimension = 3;
2119        index
2120            .file_mtimes
2121            .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
2122        let fingerprint = SemanticIndexFingerprint {
2123            backend: "fastembed".to_string(),
2124            model: "test".to_string(),
2125            base_url: FALLBACK_BACKEND.to_string(),
2126            dimension: 3,
2127        };
2128        index.set_fingerprint(fingerprint.clone());
2129
2130        let mut bytes = index.to_bytes();
2131        bytes[0] = SEMANTIC_INDEX_VERSION_V3;
2132        fs::write(dir.join("semantic.bin"), bytes).unwrap();
2133
2134        assert!(SemanticIndex::read_from_disk(
2135            storage.path(),
2136            project_key,
2137            Some(&fingerprint.as_string())
2138        )
2139        .is_none());
2140        assert!(!dir.join("semantic.bin").exists());
2141    }
2142
2143    fn make_symbol(kind: SymbolKind, name: &str, start: u32, end: u32) -> crate::symbols::Symbol {
2144        crate::symbols::Symbol {
2145            name: name.to_string(),
2146            kind,
2147            range: crate::symbols::Range {
2148                start_line: start,
2149                start_col: 0,
2150                end_line: end,
2151                end_col: 0,
2152            },
2153            signature: None,
2154            scope_chain: Vec::new(),
2155            exported: false,
2156            parent: None,
2157        }
2158    }
2159
2160    /// Heading symbols (Markdown / HTML headings) must NOT be indexed —
2161    /// they overwhelmingly dominated semantic results even on code-shaped
2162    /// queries because heading prose embeds far more strongly than code
2163    /// chunks. Skipping headings keeps aft_search a code-finder.
2164    #[test]
2165    fn symbols_to_chunks_skips_heading_symbols() {
2166        let project_root = PathBuf::from("/proj");
2167        let file = project_root.join("README.md");
2168        let source = "# Title\n\nbody text\n\n## Section\n\nmore text\n";
2169
2170        let symbols = vec![
2171            make_symbol(SymbolKind::Heading, "Title", 0, 2),
2172            make_symbol(SymbolKind::Heading, "Section", 4, 6),
2173        ];
2174
2175        let chunks = symbols_to_chunks(&file, &symbols, source, &project_root);
2176        assert!(
2177            chunks.is_empty(),
2178            "Heading symbols must be filtered out before embedding; got {} chunk(s)",
2179            chunks.len()
2180        );
2181    }
2182
2183    /// Code symbols (functions, classes, methods, structs, etc.) must still
2184    /// be indexed alongside the heading skip — otherwise we'd starve the
2185    /// index entirely.
2186    #[test]
2187    fn symbols_to_chunks_keeps_code_symbols_alongside_skipped_headings() {
2188        let project_root = PathBuf::from("/proj");
2189        let file = project_root.join("src/lib.rs");
2190        let source = "pub fn handle_request() -> bool {\n    true\n}\n";
2191
2192        let symbols = vec![
2193            // A heading mixed in (e.g. from a doc comment block elsewhere).
2194            make_symbol(SymbolKind::Heading, "doc heading", 0, 1),
2195            make_symbol(SymbolKind::Function, "handle_request", 0, 2),
2196            make_symbol(SymbolKind::Struct, "AuthService", 4, 6),
2197        ];
2198
2199        let chunks = symbols_to_chunks(&file, &symbols, source, &project_root);
2200        assert_eq!(
2201            chunks.len(),
2202            2,
2203            "Expected 2 code chunks (Function + Struct), got {}",
2204            chunks.len()
2205        );
2206        let names: Vec<&str> = chunks.iter().map(|c| c.name.as_str()).collect();
2207        assert!(names.contains(&"handle_request"));
2208        assert!(names.contains(&"AuthService"));
2209        assert!(
2210            !names.contains(&"doc heading"),
2211            "Heading symbol leaked into chunks: {names:?}"
2212        );
2213    }
2214
2215    #[test]
2216    fn validate_ssrf_rejects_loopback_hostnames() {
2217        for host in &[
2218            "http://localhost",
2219            "http://localhost:8080",
2220            "http://localhost.localdomain",
2221            "http://foo.localhost",
2222        ] {
2223            assert!(
2224                validate_base_url_no_ssrf(host).is_err(),
2225                "Expected {host} to be rejected"
2226            );
2227        }
2228    }
2229
2230    #[test]
2231    fn validate_ssrf_rejects_private_ips() {
2232        for url in &[
2233            "http://192.168.1.1",
2234            "http://10.0.0.1",
2235            "http://172.16.0.1",
2236            "http://127.0.0.1",
2237            "http://169.254.169.254",
2238        ] {
2239            let result = validate_base_url_no_ssrf(url);
2240            assert!(
2241                result.is_err(),
2242                "Expected {url} to be rejected, got: {:?}",
2243                result
2244            );
2245        }
2246    }
2247
2248    #[test]
2249    fn normalize_base_url_allows_localhost_for_tests() {
2250        // normalize_base_url itself should NOT block localhost — only
2251        // validate_base_url_no_ssrf does. Tests construct backends directly.
2252        assert!(normalize_base_url("http://127.0.0.1:9999").is_ok());
2253        assert!(normalize_base_url("http://localhost:8080").is_ok());
2254    }
2255}