Skip to main content

aft/
semantic_index.rs

1use crate::config::{SemanticBackend, SemanticBackendConfig};
2use crate::parser::FileParser;
3use crate::symbols::{Symbol, SymbolKind};
4
5use fastembed::{EmbeddingModel as FastembedEmbeddingModel, InitOptions, TextEmbedding};
6use reqwest::blocking::Client;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::env;
10use std::fmt::Display;
11use std::fs;
12use std::path::{Path, PathBuf};
13use std::time::Duration;
14use std::time::SystemTime;
15use url::Url;
16
/// Embedding dimension reported by an index that holds no vectors to measure.
const DEFAULT_DIMENSION: usize = 384;
// NOTE(review): the following limits and header sizes are not referenced in the
// visible part of this file; presumably they bound/describe the persisted index
// format (V1 vs V2 headers) — confirm against the (de)serialization code.
const MAX_ENTRIES: usize = 1_000_000;
const MAX_DIMENSION: usize = 1024;
/// Size in bytes of a single `f32` vector component.
const F32_BYTES: usize = std::mem::size_of::<f32>();
const HEADER_BYTES_V1: usize = 9;
const HEADER_BYTES_V2: usize = 13;
/// Prefix used for errors that are recognised as "ONNX Runtime could not be loaded".
const ONNX_RUNTIME_INSTALL_HINT: &str =
    "ONNX Runtime not found. Install via: brew install onnxruntime (macOS) or apt install libonnxruntime (Linux).";

const SEMANTIC_INDEX_VERSION_V1: u8 = 1;
const SEMANTIC_INDEX_VERSION_V2: u8 = 2;
/// Path appended (after `/v1`) to an OpenAI-compatible base URL.
const DEFAULT_OPENAI_EMBEDDING_PATH: &str = "/embeddings";
/// Path appended to an Ollama base URL that does not already end in `/api`.
const DEFAULT_OLLAMA_EMBEDDING_PATH: &str = "/api/embed";
// Must stay below the bridge timeout (30s) to avoid bridge kills on slow backends.
// Used as the HTTP client timeout whenever the configured timeout is 0.
const DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS: u64 = 25_000;
/// Batch size used whenever the configured `max_batch_size` is 0.
const DEFAULT_MAX_BATCH_SIZE: usize = 64;
/// Placeholder stored as the fingerprint's `base_url` when the config has no
/// (valid) base URL.
const FALLBACK_BACKEND: &str = "none";
/// Total number of tries (initial request included) for one embedding request.
const EMBEDDING_REQUEST_MAX_ATTEMPTS: usize = 3;
/// Delay before retry N (indexed by the zero-based attempt that just failed).
const EMBEDDING_REQUEST_BACKOFF_MS: [u64; 2] = [500, 1_000];
36
/// Identity of the embedding setup an index was produced with.
///
/// The JSON encoding of this struct (see `as_string`) is what gets compared
/// against an expected fingerprint in `matches_expected`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticIndexFingerprint {
    /// Backend identifier, taken from `config.backend.as_str()`.
    pub backend: String,
    /// Embedding model name as configured.
    pub model: String,
    /// Normalized base URL, or the `FALLBACK_BACKEND` placeholder when none is
    /// configured. Deserializes to an empty string when the field is absent.
    #[serde(default)]
    pub base_url: String,
    /// Length of the embedding vectors.
    pub dimension: usize,
}
45
46impl SemanticIndexFingerprint {
47    fn from_config(config: &SemanticBackendConfig, dimension: usize) -> Self {
48        // Use normalized URL for fingerprinting so cosmetic differences
49        // (e.g. "http://host/v1" vs "http://host/v1/") don't cause rebuilds.
50        let base_url = config
51            .base_url
52            .as_ref()
53            .and_then(|u| normalize_base_url(u).ok())
54            .unwrap_or_else(|| FALLBACK_BACKEND.to_string());
55        Self {
56            backend: config.backend.as_str().to_string(),
57            model: config.model.clone(),
58            base_url,
59            dimension,
60        }
61    }
62
63    pub fn as_string(&self) -> String {
64        serde_json::to_string(self).unwrap_or_else(|_| String::new())
65    }
66
67    fn matches_expected(&self, expected: &str) -> bool {
68        let encoded = self.as_string();
69        !encoded.is_empty() && encoded == expected
70    }
71}
72
/// Backend-specific state needed to turn texts into embedding vectors.
enum SemanticEmbeddingEngine {
    /// Local model driven through the `fastembed` crate.
    Fastembed(TextEmbedding),
    /// Remote server reached via `POST {base_url}[/v1]/embeddings`.
    OpenAiCompatible {
        client: Client,
        model: String,
        /// Normalized base URL (no trailing slash).
        base_url: String,
        /// Bearer token sent in the `Authorization` header when present.
        api_key: Option<String>,
    },
    /// Ollama server reached via `POST {base_url}/api/embed`.
    Ollama {
        client: Client,
        model: String,
        /// Normalized base URL (no trailing slash).
        base_url: String,
    },
}
87
/// An embedding backend together with the effective settings it was created with.
pub struct SemanticEmbeddingModel {
    backend: SemanticBackend,
    model: String,
    /// Base URL exactly as it appeared in the configuration (not normalized).
    base_url: Option<String>,
    /// Effective HTTP timeout in milliseconds (default applied when config is 0).
    timeout_ms: u64,
    /// Effective batch size (default applied when config is 0).
    max_batch_size: usize,
    /// Cached embedding dimension; `None` until it has been observed.
    dimension: Option<usize>,
    engine: SemanticEmbeddingEngine,
}

/// Shorter alias for [`SemanticEmbeddingModel`].
pub type EmbeddingModel = SemanticEmbeddingModel;
99
/// Check that a backend returned one vector per input and that all vectors in
/// the batch share the same length.
///
/// `context` is a label for the backend and is used as the prefix of every
/// error message. An empty batch for zero inputs is accepted.
///
/// # Errors
/// Returns a description of the first problem found: no vectors at all, a
/// vector/input count mismatch, or a vector whose length differs from vector 0.
fn validate_embedding_batch(
    vectors: &[Vec<f32>],
    expected_count: usize,
    context: &str,
) -> Result<(), String> {
    let actual_count = vectors.len();

    // A completely empty reply gets its own, more explicit message.
    if actual_count == 0 && expected_count > 0 {
        return Err(format!(
            "{context} returned no vectors for {expected_count} inputs"
        ));
    }

    if actual_count != expected_count {
        return Err(format!(
            "{context} returned {} vectors for {} inputs",
            actual_count, expected_count
        ));
    }

    let Some((head, _)) = vectors.split_first() else {
        return Ok(());
    };
    let expected_dimension = head.len();

    // Report the first vector that disagrees with vector 0, if any.
    let mismatch = vectors
        .iter()
        .enumerate()
        .find(|(_, candidate)| candidate.len() != expected_dimension);

    match mismatch {
        Some((index, offender)) => Err(format!(
            "{context} returned inconsistent embedding dimensions: vector 0 has length {expected_dimension}, vector {index} has length {}",
            offender.len()
        )),
        None => Ok(()),
    }
}
134
135fn normalize_base_url(raw: &str) -> Result<String, String> {
136    let parsed = Url::parse(raw).map_err(|error| format!("invalid base_url '{raw}': {error}"))?;
137    let scheme = parsed.scheme();
138    if scheme != "http" && scheme != "https" {
139        return Err(format!(
140            "unsupported URL scheme '{}' — only http:// and https:// are allowed",
141            scheme
142        ));
143    }
144    Ok(parsed.to_string().trim_end_matches('/').to_string())
145}
146
147fn build_openai_embeddings_endpoint(base_url: &str) -> String {
148    if base_url.ends_with("/v1") {
149        format!("{base_url}{DEFAULT_OPENAI_EMBEDDING_PATH}")
150    } else {
151        format!("{base_url}/v1{}", DEFAULT_OPENAI_EMBEDDING_PATH)
152    }
153}
154
155fn build_ollama_embeddings_endpoint(base_url: &str) -> String {
156    if base_url.ends_with("/api") {
157        format!("{base_url}/embed")
158    } else {
159        format!("{base_url}{DEFAULT_OLLAMA_EMBEDDING_PATH}")
160    }
161}
162
/// Trim surrounding whitespace from an optional value and treat a blank
/// result the same as an absent one.
fn normalize_api_key(value: Option<String>) -> Option<String> {
    let raw = value?;
    let trimmed = raw.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_string())
}
173
174fn is_retryable_embedding_status(status: reqwest::StatusCode) -> bool {
175    status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS
176}
177
/// Whether a transport-level failure is worth retrying.
///
/// Only connection errors qualify; every other `reqwest::Error` (including
/// timeouts) is treated as final by the caller.
fn is_retryable_embedding_error(error: &reqwest::Error) -> bool {
    error.is_connect()
}
181
182fn sleep_before_embedding_retry(attempt_index: usize) {
183    if let Some(delay_ms) = EMBEDDING_REQUEST_BACKOFF_MS.get(attempt_index) {
184        std::thread::sleep(Duration::from_millis(*delay_ms));
185    }
186}
187
/// Send an embedding request, retrying transient failures, and return the raw
/// response body of the first successful (2xx) reply.
///
/// `make_request` is invoked once per attempt to build a fresh request.
/// `backend_label` prefixes every error message.
///
/// Up to `EMBEDDING_REQUEST_MAX_ATTEMPTS` attempts are made. A retry (preceded
/// by a backoff sleep) happens only for connection errors and for HTTP 429 /
/// 5xx replies; anything else, or a failure on the last attempt, is returned
/// immediately.
///
/// # Errors
/// A message describing the send failure, the body-read failure, or the
/// non-success HTTP status together with the response body.
fn send_embedding_request<F>(mut make_request: F, backend_label: &str) -> Result<String, String>
where
    F: FnMut() -> reqwest::blocking::RequestBuilder,
{
    for attempt_index in 0..EMBEDDING_REQUEST_MAX_ATTEMPTS {
        let last_attempt = attempt_index + 1 == EMBEDDING_REQUEST_MAX_ATTEMPTS;

        let response = match make_request().send() {
            Ok(response) => response,
            Err(error) => {
                if !last_attempt && is_retryable_embedding_error(&error) {
                    sleep_before_embedding_retry(attempt_index);
                    continue;
                }
                return Err(format!("{backend_label} request failed: {error}"));
            }
        };

        // Capture the status first: reading the body consumes the response.
        let status = response.status();
        let raw = match response.text() {
            Ok(raw) => raw,
            Err(error) => {
                if !last_attempt && is_retryable_embedding_error(&error) {
                    sleep_before_embedding_retry(attempt_index);
                    continue;
                }
                return Err(format!("{backend_label} response read failed: {error}"));
            }
        };

        if status.is_success() {
            return Ok(raw);
        }

        if !last_attempt && is_retryable_embedding_status(status) {
            sleep_before_embedding_retry(attempt_index);
            continue;
        }

        // Non-retryable status, or retries exhausted: surface status and body.
        return Err(format!(
            "{backend_label} request failed (HTTP {}): {}",
            status, raw
        ));
    }

    // Every path through the final iteration returns, so the loop never falls through.
    unreachable!("embedding request retries exhausted without returning")
}
235
236impl SemanticEmbeddingModel {
237    pub fn from_config(config: &SemanticBackendConfig) -> Result<Self, String> {
238        let timeout_ms = if config.timeout_ms == 0 {
239            DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS
240        } else {
241            config.timeout_ms
242        };
243
244        let max_batch_size = if config.max_batch_size == 0 {
245            DEFAULT_MAX_BATCH_SIZE
246        } else {
247            config.max_batch_size
248        };
249
250        let api_key_env = normalize_api_key(config.api_key_env.clone());
251        let model = config.model.clone();
252
253        let client = Client::builder()
254            .timeout(Duration::from_millis(timeout_ms))
255            .redirect(reqwest::redirect::Policy::none())
256            .build()
257            .map_err(|error| format!("failed to configure embedding client: {error}"))?;
258
259        let engine = match config.backend {
260            SemanticBackend::Fastembed => {
261                SemanticEmbeddingEngine::Fastembed(initialize_text_embedding(&model)?)
262            }
263            SemanticBackend::OpenAiCompatible => {
264                let raw = config.base_url.as_ref().ok_or_else(|| {
265                    "base_url is required for openai_compatible backend".to_string()
266                })?;
267                let base_url = normalize_base_url(raw)?;
268
269                let api_key = match api_key_env {
270                    Some(var_name) => Some(env::var(&var_name).map_err(|_| {
271                        format!("missing api_key_env '{var_name}' for openai_compatible backend")
272                    })?),
273                    None => None,
274                };
275
276                SemanticEmbeddingEngine::OpenAiCompatible {
277                    client,
278                    model,
279                    base_url,
280                    api_key,
281                }
282            }
283            SemanticBackend::Ollama => {
284                let raw = config
285                    .base_url
286                    .as_ref()
287                    .ok_or_else(|| "base_url is required for ollama backend".to_string())?;
288                let base_url = normalize_base_url(raw)?;
289
290                SemanticEmbeddingEngine::Ollama {
291                    client,
292                    model,
293                    base_url,
294                }
295            }
296        };
297
298        Ok(Self {
299            backend: config.backend,
300            model: config.model.clone(),
301            base_url: config.base_url.clone(),
302            timeout_ms,
303            max_batch_size,
304            dimension: None,
305            engine,
306        })
307    }
308
    /// Backend this model was configured with.
    pub fn backend(&self) -> SemanticBackend {
        self.backend
    }

    /// Configured model name.
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Base URL as given in the configuration, if any.
    pub fn base_url(&self) -> Option<&str> {
        self.base_url.as_deref()
    }

    /// Effective maximum number of texts per embedding request.
    pub fn max_batch_size(&self) -> usize {
        self.max_batch_size
    }

    /// Effective HTTP timeout in milliseconds.
    pub fn timeout_ms(&self) -> u64 {
        self.timeout_ms
    }
328
329    pub fn fingerprint(
330        &mut self,
331        config: &SemanticBackendConfig,
332    ) -> Result<SemanticIndexFingerprint, String> {
333        let dimension = self.dimension()?;
334        Ok(SemanticIndexFingerprint::from_config(config, dimension))
335    }
336
337    pub fn dimension(&mut self) -> Result<usize, String> {
338        if let Some(dimension) = self.dimension {
339            return Ok(dimension);
340        }
341
342        let dimension = match &mut self.engine {
343            SemanticEmbeddingEngine::Fastembed(model) => {
344                let vectors = model
345                    .embed(vec!["semantic index fingerprint probe".to_string()], None)
346                    .map_err(|error| format_embedding_init_error(error.to_string()))?;
347                vectors
348                    .first()
349                    .map(|v| v.len())
350                    .ok_or_else(|| "embedding backend returned no vectors".to_string())?
351            }
352            SemanticEmbeddingEngine::OpenAiCompatible { .. } => {
353                let vectors =
354                    self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
355                vectors
356                    .first()
357                    .map(|v| v.len())
358                    .ok_or_else(|| "embedding backend returned no vectors".to_string())?
359            }
360            SemanticEmbeddingEngine::Ollama { .. } => {
361                let vectors =
362                    self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
363                vectors
364                    .first()
365                    .map(|v| v.len())
366                    .ok_or_else(|| "embedding backend returned no vectors".to_string())?
367            }
368        };
369
370        self.dimension = Some(dimension);
371        Ok(dimension)
372    }
373
    /// Embed a batch of texts, returning one vector per input text.
    ///
    /// Thin public wrapper around `embed_texts`; errors are passed through.
    pub fn embed(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
        self.embed_texts(texts)
    }
377
378    fn embed_texts(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
379        match &mut self.engine {
380            SemanticEmbeddingEngine::Fastembed(model) => model
381                .embed(texts, None::<usize>)
382                .map_err(|error| format_embedding_init_error(error.to_string()))
383                .map_err(|error| format!("failed to embed batch: {error}")),
384            SemanticEmbeddingEngine::OpenAiCompatible {
385                client,
386                model,
387                base_url,
388                api_key,
389            } => {
390                let expected_text_count = texts.len();
391                let endpoint = build_openai_embeddings_endpoint(base_url);
392                let body = serde_json::json!({
393                    "input": texts,
394                    "model": model,
395                });
396
397                let raw = send_embedding_request(
398                    || {
399                        let mut request = client
400                            .post(&endpoint)
401                            .json(&body)
402                            .header("Content-Type", "application/json");
403
404                        if let Some(api_key) = api_key {
405                            request = request.header("Authorization", format!("Bearer {api_key}"));
406                        }
407
408                        request
409                    },
410                    "openai compatible",
411                )?;
412
413                #[derive(Deserialize)]
414                struct OpenAiResponse {
415                    data: Vec<OpenAiEmbeddingResult>,
416                }
417
418                #[derive(Deserialize)]
419                struct OpenAiEmbeddingResult {
420                    embedding: Vec<f32>,
421                    index: Option<u32>,
422                }
423
424                let parsed: OpenAiResponse = serde_json::from_str(&raw)
425                    .map_err(|error| format!("invalid openai compatible response: {error}"))?;
426                if parsed.data.len() != expected_text_count {
427                    return Err(format!(
428                        "openai compatible response returned {} embeddings for {} inputs",
429                        parsed.data.len(),
430                        expected_text_count
431                    ));
432                }
433
434                let mut vectors = vec![Vec::new(); parsed.data.len()];
435                for (i, item) in parsed.data.into_iter().enumerate() {
436                    let index = item.index.unwrap_or(i as u32) as usize;
437                    if index >= vectors.len() {
438                        return Err(
439                            "openai compatible response contains invalid vector index".to_string()
440                        );
441                    }
442                    vectors[index] = item.embedding;
443                }
444
445                for vector in &vectors {
446                    if vector.is_empty() {
447                        return Err(
448                            "openai compatible response contained missing vectors".to_string()
449                        );
450                    }
451                }
452
453                self.dimension = vectors.first().map(Vec::len);
454                Ok(vectors)
455            }
456            SemanticEmbeddingEngine::Ollama {
457                client,
458                model,
459                base_url,
460            } => {
461                let expected_text_count = texts.len();
462                let endpoint = build_ollama_embeddings_endpoint(base_url);
463
464                #[derive(Serialize)]
465                struct OllamaPayload<'a> {
466                    model: &'a str,
467                    input: Vec<String>,
468                }
469
470                let payload = OllamaPayload {
471                    model,
472                    input: texts,
473                };
474
475                let raw = send_embedding_request(
476                    || {
477                        client
478                            .post(&endpoint)
479                            .json(&payload)
480                            .header("Content-Type", "application/json")
481                    },
482                    "ollama",
483                )?;
484
485                #[derive(Deserialize)]
486                struct OllamaResponse {
487                    embeddings: Vec<Vec<f32>>,
488                }
489
490                let parsed: OllamaResponse = serde_json::from_str(&raw)
491                    .map_err(|error| format!("invalid ollama response: {error}"))?;
492                if parsed.embeddings.is_empty() {
493                    return Err("ollama response returned no embeddings".to_string());
494                }
495                if parsed.embeddings.len() != expected_text_count {
496                    return Err(format!(
497                        "ollama response returned {} embeddings for {} inputs",
498                        parsed.embeddings.len(),
499                        expected_text_count
500                    ));
501                }
502
503                let vectors = parsed.embeddings;
504                for vector in &vectors {
505                    if vector.is_empty() {
506                        return Err("ollama response contained empty embeddings".to_string());
507                    }
508                }
509
510                self.dimension = vectors.first().map(Vec::len);
511                Ok(vectors)
512            }
513        }
514    }
515}
516
/// Pre-validate ONNX Runtime by attempting a raw dlopen before ort touches it.
/// This catches broken/incompatible .so files without risking a panic in the ort crate.
///
/// The library to load is taken from `ORT_DYLIB_PATH`, falling back to the
/// platform default name (`libonnxruntime.so` / `libonnxruntime.dylib`).
/// After a successful load the version is inferred from file names only (no
/// runtime API is queried); a detected version outside `1.20+` is rejected,
/// while an undetectable version is accepted.
///
/// On Windows, and on any other platform that is neither Linux nor macOS,
/// nothing is checked and `Ok(())` is returned.
///
/// # Errors
/// Returns a user-facing message when the library path contains a NUL byte,
/// when `dlopen` fails, or when the detected version is too old / not 1.x.
pub fn pre_validate_onnx_runtime() -> Result<(), String> {
    let dylib_path = std::env::var("ORT_DYLIB_PATH").ok();

    #[cfg(any(target_os = "linux", target_os = "macos"))]
    {
        #[cfg(target_os = "linux")]
        let default_name = "libonnxruntime.so";
        #[cfg(target_os = "macos")]
        let default_name = "libonnxruntime.dylib";

        let lib_name = dylib_path.as_deref().unwrap_or(default_name);

        // SAFETY: `c_name` is a NUL-terminated string that outlives the
        // `dlopen` call; `dlerror`'s result is checked for null before being
        // read as a C string; `handle` is non-null when passed to `dlclose`
        // and is not used afterwards.
        unsafe {
            let c_name = std::ffi::CString::new(lib_name)
                .map_err(|e| format!("invalid library path: {}", e))?;
            let handle = libc::dlopen(c_name.as_ptr(), libc::RTLD_NOW);
            if handle.is_null() {
                let err = libc::dlerror();
                let msg = if err.is_null() {
                    "unknown dlopen error".to_string()
                } else {
                    std::ffi::CStr::from_ptr(err).to_string_lossy().into_owned()
                };
                return Err(format!(
                    "ONNX Runtime not found. dlopen('{}') failed: {}. \
                     Run `bunx @cortexkit/aft-opencode@latest doctor` to diagnose.",
                    lib_name, msg
                ));
            }

            // Try to detect the runtime version from the file path or soname.
            // libonnxruntime.so.1.19.0, libonnxruntime.1.24.4.dylib, etc.
            let detected_version = detect_ort_version_from_path(lib_name);

            // The handle was only needed to prove the library loads.
            libc::dlclose(handle);

            // Check version compatibility — major must be 1 and minor at least 20.
            if let Some(ref version) = detected_version {
                let parts: Vec<&str> = version.split('.').collect();
                if let (Some(major), Some(minor)) = (
                    parts.first().and_then(|s| s.parse::<u32>().ok()),
                    parts.get(1).and_then(|s| s.parse::<u32>().ok()),
                ) {
                    if major != 1 || minor < 20 {
                        return Err(format!(
                            "ONNX Runtime version mismatch: found v{} at '{}', but AFT requires v1.20+. \
                             Solutions:\n\
                             1. Remove the old library and restart (AFT auto-downloads the correct version):\n\
                             {}\n\
                             2. Or install ONNX Runtime 1.24: https://github.com/microsoft/onnxruntime/releases/tag/v1.24.0\n\
                             3. Run `bunx @cortexkit/aft-opencode@latest doctor` for full diagnostics.",
                            version,
                            lib_name,
                            suggest_removal_command(lib_name),
                        ));
                    }
                }
            }
        }
    }

    #[cfg(target_os = "windows")]
    {
        // On Windows, skip pre-validation — let ort handle LoadLibrary
        let _ = dylib_path;
    }

    Ok(())
}
589
590/// Try to extract the ORT version from the library filename or resolved symlink.
591/// Examples: "libonnxruntime.so.1.19.0" → "1.19.0", "libonnxruntime.1.24.4.dylib" → "1.24.4"
592fn detect_ort_version_from_path(lib_path: &str) -> Option<String> {
593    let path = std::path::Path::new(lib_path);
594
595    // Try the path as given, then follow symlinks
596    for candidate in [Some(path.to_path_buf()), std::fs::canonicalize(path).ok()]
597        .into_iter()
598        .flatten()
599    {
600        if let Some(name) = candidate.file_name().and_then(|n| n.to_str()) {
601            if let Some(version) = extract_version_from_filename(name) {
602                return Some(version);
603            }
604        }
605    }
606
607    // Also check for versioned siblings in the same directory
608    if let Some(parent) = path.parent() {
609        if let Ok(entries) = std::fs::read_dir(parent) {
610            for entry in entries.flatten() {
611                if let Some(name) = entry.file_name().to_str() {
612                    if name.starts_with("libonnxruntime") {
613                        if let Some(version) = extract_version_from_filename(name) {
614                            return Some(version);
615                        }
616                    }
617                }
618            }
619        }
620    }
621
622    None
623}
624
625/// Extract version from filenames like "libonnxruntime.so.1.19.0" or "libonnxruntime.1.24.4.dylib"
626fn extract_version_from_filename(name: &str) -> Option<String> {
627    // Match patterns: .so.X.Y.Z or .X.Y.Z.dylib or .X.Y.Z.so
628    let re = regex::Regex::new(r"(\d+\.\d+\.\d+)").ok()?;
629    re.find(name).map(|m| m.as_str().to_string())
630}
631
/// Produce an indented, human-readable hint for removing an outdated ONNX
/// Runtime library.
///
/// Libraries under `/usr/local/lib`, or referenced by their bare default
/// name, get a platform-specific instruction on Linux, macOS and Windows.
/// Every other case yields a plain `rm` of the given path.
fn suggest_removal_command(lib_path: &str) -> String {
    let is_default_install = lib_path.starts_with("/usr/local/lib")
        || matches!(lib_path, "libonnxruntime.so" | "libonnxruntime.dylib");

    let platform_hint = if cfg!(target_os = "linux") {
        Some("   sudo rm /usr/local/lib/libonnxruntime* && sudo ldconfig")
    } else if cfg!(target_os = "macos") {
        Some("   sudo rm /usr/local/lib/libonnxruntime*")
    } else if cfg!(target_os = "windows") {
        Some("   Delete the ONNX Runtime DLL from your PATH")
    } else {
        None
    };

    match platform_hint {
        Some(hint) if is_default_install => hint.to_string(),
        _ => format!("   rm '{}'", lib_path),
    }
}
646
647pub fn initialize_text_embedding(model: &str) -> Result<TextEmbedding, String> {
648    // Pre-validate before ort can panic on a bad library
649    pre_validate_onnx_runtime()?;
650
651    let selected_model = match model {
652        "all-MiniLM-L6-v2" | "all-minilm-l6-v2" => FastembedEmbeddingModel::AllMiniLML6V2,
653        _ => {
654            return Err(format!(
655                "unsupported fastembed model '{}'. Supported: all-MiniLM-L6-v2",
656                model
657            ))
658        }
659    };
660
661    TextEmbedding::try_new(InitOptions::new(selected_model)).map_err(format_embedding_init_error)
662}
663
/// Heuristically decide whether an error message means that the ONNX Runtime
/// shared library could not be found or loaded.
///
/// A message starting (after leading whitespace) with
/// "ONNX Runtime not found." always qualifies. Otherwise the message must,
/// case-insensitively, mention ONNX Runtime *and* a dynamic-loading failure.
pub fn is_onnx_runtime_unavailable(message: &str) -> bool {
    const RUNTIME_MARKERS: &[&str] = &["onnx runtime", "onnxruntime", "libonnxruntime"];
    const LOAD_FAILURE_MARKERS: &[&str] = &[
        "shared library",
        "dynamic library",
        "failed to load",
        "could not load",
        "unable to load",
        "dlopen",
        "loadlibrary",
        "no such file",
        "not found",
    ];

    if message.trim_start().starts_with("ONNX Runtime not found.") {
        return true;
    }

    let lowered = message.to_ascii_lowercase();
    let mentions_any = |markers: &[&str]| markers.iter().any(|marker| lowered.contains(*marker));

    mentions_any(RUNTIME_MARKERS) && mentions_any(LOAD_FAILURE_MARKERS)
}
689
690fn format_embedding_init_error(error: impl Display) -> String {
691    let message = error.to_string();
692
693    if is_onnx_runtime_unavailable(&message) {
694        return format!("{ONNX_RUNTIME_INSTALL_HINT} Original error: {message}");
695    }
696
697    format!("failed to initialize semantic embedding model: {message}")
698}
699
/// A chunk of code ready for embedding — derived from a Symbol with context enrichment
#[derive(Debug, Clone)]
pub struct SemanticChunk {
    /// Absolute file path
    pub file: PathBuf,
    /// Symbol name
    pub name: String,
    /// Symbol kind (function, class, struct, etc.)
    pub kind: SymbolKind,
    /// First line of the symbol (0-based internally, inclusive)
    pub start_line: u32,
    /// Last line of the symbol (0-based internally, inclusive)
    pub end_line: u32,
    /// Whether the symbol is exported
    pub exported: bool,
    /// The enriched text that gets embedded (scope + signature + body snippet)
    pub embed_text: String,
    /// Short code snippet for display in results
    pub snippet: String,
}
719
/// A stored embedding entry — chunk metadata + vector
#[derive(Debug)]
struct EmbeddingEntry {
    /// The chunk the vector was computed from.
    chunk: SemanticChunk,
    /// Embedding returned by the backend for `chunk.embed_text`.
    vector: Vec<f32>,
}
726
/// The semantic index — stores embeddings for all symbols in a project
#[derive(Debug)]
pub struct SemanticIndex {
    /// One entry per embedded chunk, in the order the chunks were embedded.
    entries: Vec<EmbeddingEntry>,
    /// Track which files are indexed and their mtime for staleness detection
    file_mtimes: HashMap<PathBuf, SystemTime>,
    /// Embedding dimension (384 for MiniLM-L6-v2); falls back to
    /// `DEFAULT_DIMENSION` while the index has no entries.
    dimension: usize,
    /// Fingerprint of the embedding setup, when one has been attached.
    fingerprint: Option<SemanticIndexFingerprint>,
}
737
/// Search result from a semantic query
///
/// The metadata fields have the same names and types as those of
/// [`SemanticChunk`].
#[derive(Debug)]
pub struct SemanticResult {
    pub file: PathBuf,
    pub name: String,
    pub kind: SymbolKind,
    pub start_line: u32,
    pub end_line: u32,
    pub exported: bool,
    pub snippet: String,
    /// Relevance score of this result for the query.
    // NOTE(review): scoring function and value range are defined by the
    // search code, which is outside this view — confirm there.
    pub score: f32,
}
750
751impl SemanticIndex {
    /// Create an empty index with no entries, no tracked files, the default
    /// dimension and no fingerprint.
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
            file_mtimes: HashMap::new(),
            dimension: DEFAULT_DIMENSION, // MiniLM-L6-v2 default
            fingerprint: None,
        }
    }
760
    /// Number of embedded symbol entries currently stored in the index.
    pub fn entry_count(&self) -> usize {
        self.entries.len()
    }
765
766    /// Human-readable status label for the index.
767    pub fn status_label(&self) -> &'static str {
768        if self.entries.is_empty() {
769            "empty"
770        } else {
771            "ready"
772        }
773    }
774
775    fn collect_chunks(
776        project_root: &Path,
777        files: &[PathBuf],
778    ) -> (Vec<SemanticChunk>, HashMap<PathBuf, SystemTime>) {
779        let mut parser = FileParser::new();
780        let mut chunks: Vec<SemanticChunk> = Vec::new();
781        let mut file_mtimes: HashMap<PathBuf, SystemTime> = HashMap::new();
782
783        for file in files {
784            let mtime = std::fs::metadata(file)
785                .and_then(|m| m.modified())
786                .unwrap_or(SystemTime::UNIX_EPOCH);
787            file_mtimes.insert(file.clone(), mtime);
788
789            let source = match std::fs::read_to_string(file) {
790                Ok(s) => s,
791                Err(_) => continue,
792            };
793
794            let symbols = match parser.extract_symbols(file) {
795                Ok(s) => s,
796                Err(_) => continue,
797            };
798            let file_chunks = symbols_to_chunks(file, &symbols, &source, project_root);
799            chunks.extend(file_chunks);
800        }
801
802        (chunks, file_mtimes)
803    }
804
805    fn build_from_chunks<F, P>(
806        chunks: Vec<SemanticChunk>,
807        file_mtimes: HashMap<PathBuf, SystemTime>,
808        embed_fn: &mut F,
809        max_batch_size: usize,
810        mut progress: Option<&mut P>,
811    ) -> Result<Self, String>
812    where
813        F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
814        P: FnMut(usize, usize),
815    {
816        let total_chunks = chunks.len();
817
818        if chunks.is_empty() {
819            return Ok(Self {
820                entries: Vec::new(),
821                file_mtimes,
822                dimension: DEFAULT_DIMENSION,
823                fingerprint: None,
824            });
825        }
826
827        // Embed in batches
828        let mut entries: Vec<EmbeddingEntry> = Vec::with_capacity(chunks.len());
829        let mut expected_dimension: Option<usize> = None;
830        let batch_size = max_batch_size.max(1);
831        for batch_start in (0..chunks.len()).step_by(batch_size) {
832            let batch_end = (batch_start + batch_size).min(chunks.len());
833            let batch_texts: Vec<String> = chunks[batch_start..batch_end]
834                .iter()
835                .map(|c| c.embed_text.clone())
836                .collect();
837
838            let vectors = embed_fn(batch_texts)?;
839            validate_embedding_batch(&vectors, batch_end - batch_start, "embedding backend")?;
840
841            // Track consistent dimension across all batches
842            if let Some(dim) = vectors.first().map(|v| v.len()) {
843                match expected_dimension {
844                    None => expected_dimension = Some(dim),
845                    Some(expected) if dim != expected => {
846                        return Err(format!(
847                            "embedding dimension changed across batches: expected {expected}, got {dim}"
848                        ));
849                    }
850                    _ => {}
851                }
852            }
853
854            for (i, vector) in vectors.into_iter().enumerate() {
855                let chunk_idx = batch_start + i;
856                entries.push(EmbeddingEntry {
857                    chunk: chunks[chunk_idx].clone(),
858                    vector,
859                });
860            }
861
862            if let Some(callback) = progress.as_mut() {
863                callback(entries.len(), total_chunks);
864            }
865        }
866
867        let dimension = entries
868            .first()
869            .map(|e| e.vector.len())
870            .unwrap_or(DEFAULT_DIMENSION);
871
872        Ok(Self {
873            entries,
874            file_mtimes,
875            dimension,
876            fingerprint: None,
877        })
878    }
879
880    /// Build the semantic index from a set of files using the provided embedding function.
881    /// `embed_fn` takes a batch of texts and returns a batch of embedding vectors.
882    pub fn build<F>(
883        project_root: &Path,
884        files: &[PathBuf],
885        embed_fn: &mut F,
886        max_batch_size: usize,
887    ) -> Result<Self, String>
888    where
889        F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
890    {
891        let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
892        Self::build_from_chunks(
893            chunks,
894            file_mtimes,
895            embed_fn,
896            max_batch_size,
897            Option::<&mut fn(usize, usize)>::None,
898        )
899    }
900
901    /// Build the semantic index and report embedding progress using entry counts.
902    pub fn build_with_progress<F, P>(
903        project_root: &Path,
904        files: &[PathBuf],
905        embed_fn: &mut F,
906        max_batch_size: usize,
907        progress: &mut P,
908    ) -> Result<Self, String>
909    where
910        F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
911        P: FnMut(usize, usize),
912    {
913        let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
914        let total_chunks = chunks.len();
915        progress(0, total_chunks);
916        Self::build_from_chunks(
917            chunks,
918            file_mtimes,
919            embed_fn,
920            max_batch_size,
921            Some(progress),
922        )
923    }
924
925    /// Search the index with a query embedding, returning top-K results sorted by relevance
926    pub fn search(&self, query_vector: &[f32], top_k: usize) -> Vec<SemanticResult> {
927        if self.entries.is_empty() || query_vector.len() != self.dimension {
928            return Vec::new();
929        }
930
931        let mut scored: Vec<(f32, usize)> = self
932            .entries
933            .iter()
934            .enumerate()
935            .map(|(i, entry)| (cosine_similarity(query_vector, &entry.vector), i))
936            .collect();
937
938        // Sort descending by score
939        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
940
941        scored
942            .into_iter()
943            .take(top_k)
944            .filter(|(score, _)| *score > 0.0)
945            .map(|(score, idx)| {
946                let entry = &self.entries[idx];
947                SemanticResult {
948                    file: entry.chunk.file.clone(),
949                    name: entry.chunk.name.clone(),
950                    kind: entry.chunk.kind.clone(),
951                    start_line: entry.chunk.start_line,
952                    end_line: entry.chunk.end_line,
953                    exported: entry.chunk.exported,
954                    snippet: entry.chunk.snippet.clone(),
955                    score,
956                }
957            })
958            .collect()
959    }
960
961    /// Number of indexed entries
962    pub fn len(&self) -> usize {
963        self.entries.len()
964    }
965
966    /// Check if a file needs re-indexing based on mtime
967    pub fn is_file_stale(&self, file: &Path) -> bool {
968        match self.file_mtimes.get(file) {
969            None => true,
970            Some(stored_mtime) => match fs::metadata(file).and_then(|m| m.modified()) {
971                Ok(current_mtime) => *stored_mtime != current_mtime,
972                Err(_) => true,
973            },
974        }
975    }
976
977    pub fn count_stale_files(&self) -> usize {
978        self.file_mtimes
979            .keys()
980            .filter(|path| self.is_file_stale(path))
981            .count()
982    }
983
984    /// Remove entries for a specific file
985    pub fn remove_file(&mut self, file: &Path) {
986        self.invalidate_file(file);
987    }
988
989    pub fn invalidate_file(&mut self, file: &Path) {
990        self.entries.retain(|e| e.chunk.file != file);
991        self.file_mtimes.remove(file);
992    }
993
994    /// Get the embedding dimension
995    pub fn dimension(&self) -> usize {
996        self.dimension
997    }
998
999    pub fn fingerprint(&self) -> Option<&SemanticIndexFingerprint> {
1000        self.fingerprint.as_ref()
1001    }
1002
1003    pub fn backend_label(&self) -> Option<&str> {
1004        self.fingerprint.as_ref().map(|f| f.backend.as_str())
1005    }
1006
1007    pub fn model_label(&self) -> Option<&str> {
1008        self.fingerprint.as_ref().map(|f| f.model.as_str())
1009    }
1010
1011    pub fn set_fingerprint(&mut self, fingerprint: SemanticIndexFingerprint) {
1012        self.fingerprint = Some(fingerprint);
1013    }
1014
1015    /// Write the semantic index to disk using atomic temp+rename pattern
1016    pub fn write_to_disk(&self, storage_dir: &Path, project_key: &str) {
1017        // Don't persist empty indexes — they would be loaded on next startup
1018        // and prevent a fresh build that might find files.
1019        if self.entries.is_empty() {
1020            log::info!("[aft] skipping semantic index persistence (0 entries)");
1021            return;
1022        }
1023        let dir = storage_dir.join("semantic").join(project_key);
1024        if let Err(e) = fs::create_dir_all(&dir) {
1025            log::warn!("[aft] failed to create semantic cache dir: {}", e);
1026            return;
1027        }
1028        let data_path = dir.join("semantic.bin");
1029        let tmp_path = dir.join("semantic.bin.tmp");
1030        let bytes = self.to_bytes();
1031        if let Err(e) = fs::write(&tmp_path, &bytes) {
1032            log::warn!("[aft] failed to write semantic index: {}", e);
1033            let _ = fs::remove_file(&tmp_path);
1034            return;
1035        }
1036        if let Err(e) = fs::rename(&tmp_path, &data_path) {
1037            log::warn!("[aft] failed to rename semantic index: {}", e);
1038            let _ = fs::remove_file(&tmp_path);
1039            return;
1040        }
1041        log::info!(
1042            "[aft] semantic index persisted: {} entries, {:.1} KB",
1043            self.entries.len(),
1044            bytes.len() as f64 / 1024.0
1045        );
1046    }
1047
1048    /// Read the semantic index from disk
1049    pub fn read_from_disk(
1050        storage_dir: &Path,
1051        project_key: &str,
1052        expected_fingerprint: Option<&str>,
1053    ) -> Option<Self> {
1054        let data_path = storage_dir
1055            .join("semantic")
1056            .join(project_key)
1057            .join("semantic.bin");
1058        let file_len = usize::try_from(fs::metadata(&data_path).ok()?.len()).ok()?;
1059        if file_len < HEADER_BYTES_V1 {
1060            log::warn!(
1061                "[aft] corrupt semantic index (too small: {} bytes), removing",
1062                file_len
1063            );
1064            let _ = fs::remove_file(&data_path);
1065            return None;
1066        }
1067
1068        let bytes = fs::read(&data_path).ok()?;
1069        match Self::from_bytes(&bytes) {
1070            Ok(index) => {
1071                if index.entries.is_empty() {
1072                    log::info!("[aft] cached semantic index is empty, will rebuild");
1073                    let _ = fs::remove_file(&data_path);
1074                    return None;
1075                }
1076                if let Some(expected) = expected_fingerprint {
1077                    let matches = index
1078                        .fingerprint()
1079                        .map(|fingerprint| fingerprint.matches_expected(expected))
1080                        .unwrap_or(false);
1081                    if !matches {
1082                        log::info!("[aft] cached semantic index fingerprint mismatch, rebuilding");
1083                        let _ = fs::remove_file(&data_path);
1084                        return None;
1085                    }
1086                }
1087                log::info!(
1088                    "[aft] loaded semantic index from disk: {} entries",
1089                    index.entries.len()
1090                );
1091                Some(index)
1092            }
1093            Err(e) => {
1094                log::warn!("[aft] corrupt semantic index, rebuilding: {}", e);
1095                let _ = fs::remove_file(&data_path);
1096                None
1097            }
1098        }
1099    }
1100
1101    /// Serialize the index to bytes for disk persistence
1102    pub fn to_bytes(&self) -> Vec<u8> {
1103        let mut buf = Vec::new();
1104        let fingerprint_bytes = self.fingerprint.as_ref().and_then(|fingerprint| {
1105            let encoded = fingerprint.as_string();
1106            if encoded.is_empty() {
1107                None
1108            } else {
1109                Some(encoded.into_bytes())
1110            }
1111        });
1112
1113        // Header: version(1) + dimension(4) + entry_count(4) [+ fingerprint_len(4)]
1114        let version = if fingerprint_bytes.is_some() {
1115            SEMANTIC_INDEX_VERSION_V2
1116        } else {
1117            SEMANTIC_INDEX_VERSION_V1
1118        };
1119        buf.push(version);
1120        buf.extend_from_slice(&(self.dimension as u32).to_le_bytes());
1121        buf.extend_from_slice(&(self.entries.len() as u32).to_le_bytes());
1122        if let Some(bytes) = &fingerprint_bytes {
1123            buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
1124            buf.extend_from_slice(bytes);
1125        }
1126
1127        // File mtime table: count(4) + entries
1128        buf.extend_from_slice(&(self.file_mtimes.len() as u32).to_le_bytes());
1129        for (path, mtime) in &self.file_mtimes {
1130            let path_bytes = path.to_string_lossy().as_bytes().to_vec();
1131            buf.extend_from_slice(&(path_bytes.len() as u32).to_le_bytes());
1132            buf.extend_from_slice(&path_bytes);
1133            let duration = mtime
1134                .duration_since(SystemTime::UNIX_EPOCH)
1135                .unwrap_or_default();
1136            buf.extend_from_slice(&duration.as_secs().to_le_bytes());
1137        }
1138
1139        // Entries: each is metadata + vector
1140        for entry in &self.entries {
1141            let c = &entry.chunk;
1142
1143            // File path
1144            let file_bytes = c.file.to_string_lossy().as_bytes().to_vec();
1145            buf.extend_from_slice(&(file_bytes.len() as u32).to_le_bytes());
1146            buf.extend_from_slice(&file_bytes);
1147
1148            // Name
1149            let name_bytes = c.name.as_bytes();
1150            buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
1151            buf.extend_from_slice(name_bytes);
1152
1153            // Kind (1 byte)
1154            buf.push(symbol_kind_to_u8(&c.kind));
1155
1156            // Lines + exported
1157            buf.extend_from_slice(&(c.start_line as u32).to_le_bytes());
1158            buf.extend_from_slice(&(c.end_line as u32).to_le_bytes());
1159            buf.push(c.exported as u8);
1160
1161            // Snippet
1162            let snippet_bytes = c.snippet.as_bytes();
1163            buf.extend_from_slice(&(snippet_bytes.len() as u32).to_le_bytes());
1164            buf.extend_from_slice(snippet_bytes);
1165
1166            // Embed text
1167            let embed_bytes = c.embed_text.as_bytes();
1168            buf.extend_from_slice(&(embed_bytes.len() as u32).to_le_bytes());
1169            buf.extend_from_slice(embed_bytes);
1170
1171            // Vector (f32 array)
1172            for &val in &entry.vector {
1173                buf.extend_from_slice(&val.to_le_bytes());
1174            }
1175        }
1176
1177        buf
1178    }
1179
1180    /// Deserialize the index from bytes
1181    pub fn from_bytes(data: &[u8]) -> Result<Self, String> {
1182        let mut pos = 0;
1183
1184        if data.len() < HEADER_BYTES_V1 {
1185            return Err("data too short".to_string());
1186        }
1187
1188        let version = data[pos];
1189        pos += 1;
1190        if version != SEMANTIC_INDEX_VERSION_V1 && version != SEMANTIC_INDEX_VERSION_V2 {
1191            return Err(format!("unsupported version: {}", version));
1192        }
1193        if version == SEMANTIC_INDEX_VERSION_V2 && data.len() < HEADER_BYTES_V2 {
1194            return Err("data too short for semantic index v2 header".to_string());
1195        }
1196
1197        let dimension = read_u32(data, &mut pos)? as usize;
1198        let entry_count = read_u32(data, &mut pos)? as usize;
1199        if dimension == 0 || dimension > MAX_DIMENSION {
1200            return Err(format!("invalid embedding dimension: {}", dimension));
1201        }
1202        if entry_count > MAX_ENTRIES {
1203            return Err(format!("too many semantic index entries: {}", entry_count));
1204        }
1205
1206        let fingerprint = if version == SEMANTIC_INDEX_VERSION_V2 {
1207            let fingerprint_len = read_u32(data, &mut pos)? as usize;
1208            if pos + fingerprint_len > data.len() {
1209                return Err("unexpected end of data reading fingerprint".to_string());
1210            }
1211            let raw = String::from_utf8_lossy(&data[pos..pos + fingerprint_len]).to_string();
1212            pos += fingerprint_len;
1213            Some(
1214                serde_json::from_str::<SemanticIndexFingerprint>(&raw)
1215                    .map_err(|error| format!("invalid semantic fingerprint: {error}"))?,
1216            )
1217        } else {
1218            None
1219        };
1220
1221        // File mtimes
1222        let mtime_count = read_u32(data, &mut pos)? as usize;
1223        if mtime_count > MAX_ENTRIES {
1224            return Err(format!("too many semantic file mtimes: {}", mtime_count));
1225        }
1226
1227        let vector_bytes = entry_count
1228            .checked_mul(dimension)
1229            .and_then(|count| count.checked_mul(F32_BYTES))
1230            .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
1231        if vector_bytes > data.len().saturating_sub(pos) {
1232            return Err("semantic index vectors exceed available data".to_string());
1233        }
1234
1235        let mut file_mtimes = HashMap::with_capacity(mtime_count);
1236        for _ in 0..mtime_count {
1237            let path = read_string(data, &mut pos)?;
1238            let secs = read_u64(data, &mut pos)?;
1239            let mtime = SystemTime::UNIX_EPOCH + std::time::Duration::from_secs(secs);
1240            file_mtimes.insert(PathBuf::from(path), mtime);
1241        }
1242
1243        // Entries
1244        let mut entries = Vec::with_capacity(entry_count);
1245        for _ in 0..entry_count {
1246            let file = PathBuf::from(read_string(data, &mut pos)?);
1247            let name = read_string(data, &mut pos)?;
1248
1249            if pos >= data.len() {
1250                return Err("unexpected end of data".to_string());
1251            }
1252            let kind = u8_to_symbol_kind(data[pos]);
1253            pos += 1;
1254
1255            let start_line = read_u32(data, &mut pos)?;
1256            let end_line = read_u32(data, &mut pos)?;
1257
1258            if pos >= data.len() {
1259                return Err("unexpected end of data".to_string());
1260            }
1261            let exported = data[pos] != 0;
1262            pos += 1;
1263
1264            let snippet = read_string(data, &mut pos)?;
1265            let embed_text = read_string(data, &mut pos)?;
1266
1267            // Vector
1268            let vec_bytes = dimension
1269                .checked_mul(F32_BYTES)
1270                .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
1271            if pos + vec_bytes > data.len() {
1272                return Err("unexpected end of data reading vector".to_string());
1273            }
1274            let mut vector = Vec::with_capacity(dimension);
1275            for _ in 0..dimension {
1276                let bytes = [data[pos], data[pos + 1], data[pos + 2], data[pos + 3]];
1277                vector.push(f32::from_le_bytes(bytes));
1278                pos += 4;
1279            }
1280
1281            entries.push(EmbeddingEntry {
1282                chunk: SemanticChunk {
1283                    file,
1284                    name,
1285                    kind,
1286                    start_line,
1287                    end_line,
1288                    exported,
1289                    embed_text,
1290                    snippet,
1291                },
1292                vector,
1293            });
1294        }
1295
1296        Ok(Self {
1297            entries,
1298            file_mtimes,
1299            dimension,
1300            fingerprint,
1301        })
1302    }
1303}
1304
1305/// Build enriched embedding text from a symbol with cAST-style context
1306fn build_embed_text(symbol: &Symbol, source: &str, file: &Path, project_root: &Path) -> String {
1307    let relative = file
1308        .strip_prefix(project_root)
1309        .unwrap_or(file)
1310        .to_string_lossy();
1311
1312    let kind_label = match &symbol.kind {
1313        SymbolKind::Function => "function",
1314        SymbolKind::Class => "class",
1315        SymbolKind::Method => "method",
1316        SymbolKind::Struct => "struct",
1317        SymbolKind::Interface => "interface",
1318        SymbolKind::Enum => "enum",
1319        SymbolKind::TypeAlias => "type",
1320        SymbolKind::Variable => "variable",
1321        SymbolKind::Heading => "heading",
1322    };
1323
1324    // Build: "file:relative/path kind:function name:validateAuth signature:fn validateAuth(token: &str) -> bool"
1325    let mut text = format!("file:{} kind:{} name:{}", relative, kind_label, symbol.name);
1326
1327    if let Some(sig) = &symbol.signature {
1328        text.push_str(&format!(" signature:{}", sig));
1329    }
1330
1331    // Add body snippet (first ~300 chars of symbol body)
1332    let lines: Vec<&str> = source.lines().collect();
1333    let start = (symbol.range.start_line.saturating_sub(1) as usize).min(lines.len()); // 1-based to 0-based
1334    let end = (symbol.range.end_line as usize).min(lines.len()); // 1-based inclusive
1335    if start < end {
1336        let body: String = lines[start..end]
1337            .iter()
1338            .take(15) // max 15 lines
1339            .copied()
1340            .collect::<Vec<&str>>()
1341            .join("\n");
1342        let snippet = if body.len() > 300 {
1343            format!("{}...", &body[..body.floor_char_boundary(300)])
1344        } else {
1345            body
1346        };
1347        text.push_str(&format!(" body:{}", snippet));
1348    }
1349
1350    text
1351}
1352
1353/// Build a display snippet from a symbol's source
1354fn build_snippet(symbol: &Symbol, source: &str) -> String {
1355    let lines: Vec<&str> = source.lines().collect();
1356    let start = (symbol.range.start_line.saturating_sub(1) as usize).min(lines.len());
1357    let end = (symbol.range.end_line as usize).min(lines.len());
1358    if start < end {
1359        let snippet_lines: Vec<&str> = lines[start..end].iter().take(5).copied().collect();
1360        let mut snippet = snippet_lines.join("\n");
1361        if end - start > 5 {
1362            snippet.push_str("\n  ...");
1363        }
1364        if snippet.len() > 300 {
1365            snippet = format!("{}...", &snippet[..snippet.floor_char_boundary(300)]);
1366        }
1367        snippet
1368    } else {
1369        String::new()
1370    }
1371}
1372
1373/// Convert symbols to semantic chunks with enriched context
1374fn symbols_to_chunks(
1375    file: &Path,
1376    symbols: &[Symbol],
1377    source: &str,
1378    project_root: &Path,
1379) -> Vec<SemanticChunk> {
1380    let mut chunks = Vec::new();
1381
1382    for symbol in symbols {
1383        // Skip very small symbols (single-line variables, etc.)
1384        let line_count = symbol
1385            .range
1386            .end_line
1387            .saturating_sub(symbol.range.start_line)
1388            + 1;
1389        if line_count < 2 && !matches!(symbol.kind, SymbolKind::Variable) {
1390            continue;
1391        }
1392
1393        let embed_text = build_embed_text(symbol, source, file, project_root);
1394        let snippet = build_snippet(symbol, source);
1395
1396        chunks.push(SemanticChunk {
1397            file: file.to_path_buf(),
1398            name: symbol.name.clone(),
1399            kind: symbol.kind.clone(),
1400            start_line: symbol.range.start_line,
1401            end_line: symbol.range.end_line,
1402            exported: symbol.exported,
1403            embed_text,
1404            snippet,
1405        });
1406
1407        // Note: Nested symbols are handled separately by the outline system
1408        // Each symbol is indexed individually
1409    }
1410
1411    chunks
1412}
1413
/// Computes the cosine similarity of `a` and `b`.
///
/// Returns `0.0` when the slices differ in length or when either vector has zero
/// magnitude (which includes empty input).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Accumulate the dot product and both squared norms in a single pass.
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, norm_a, norm_b), (&x, &y)| (dot + x * y, norm_a + x * x, norm_b + y * y),
    );

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }
    dot / denom
}
1437
1438// Serialization helpers
1439fn symbol_kind_to_u8(kind: &SymbolKind) -> u8 {
1440    match kind {
1441        SymbolKind::Function => 0,
1442        SymbolKind::Class => 1,
1443        SymbolKind::Method => 2,
1444        SymbolKind::Struct => 3,
1445        SymbolKind::Interface => 4,
1446        SymbolKind::Enum => 5,
1447        SymbolKind::TypeAlias => 6,
1448        SymbolKind::Variable => 7,
1449        SymbolKind::Heading => 8,
1450    }
1451}
1452
1453fn u8_to_symbol_kind(v: u8) -> SymbolKind {
1454    match v {
1455        0 => SymbolKind::Function,
1456        1 => SymbolKind::Class,
1457        2 => SymbolKind::Method,
1458        3 => SymbolKind::Struct,
1459        4 => SymbolKind::Interface,
1460        5 => SymbolKind::Enum,
1461        6 => SymbolKind::TypeAlias,
1462        7 => SymbolKind::Variable,
1463        _ => SymbolKind::Heading,
1464    }
1465}
1466
/// Reads a little-endian `u32` at `*pos` and advances `*pos` by 4.
///
/// Returns an error and leaves `*pos` unchanged when fewer than 4 bytes remain.
fn read_u32(data: &[u8], pos: &mut usize) -> Result<u32, String> {
    let start = *pos;
    let end = start + 4;
    match data.get(start..end) {
        Some(&[b0, b1, b2, b3]) => {
            *pos = end;
            Ok(u32::from_le_bytes([b0, b1, b2, b3]))
        }
        _ => Err("unexpected end of data reading u32".to_string()),
    }
}
1475
/// Reads a little-endian `u64` at `*pos` and advances `*pos` by 8.
///
/// Returns an error and leaves `*pos` unchanged when fewer than 8 bytes remain.
fn read_u64(data: &[u8], pos: &mut usize) -> Result<u64, String> {
    let start = *pos;
    let end = start + 8;
    let window = data
        .get(start..end)
        .ok_or_else(|| "unexpected end of data reading u64".to_string())?;

    // `window` is exactly 8 bytes long here, so the copy cannot fail.
    let mut raw = [0u8; 8];
    raw.copy_from_slice(window);
    *pos = end;
    Ok(u64::from_le_bytes(raw))
}
1484
1485fn read_string(data: &[u8], pos: &mut usize) -> Result<String, String> {
1486    let len = read_u32(data, pos)? as usize;
1487    if *pos + len > data.len() {
1488        return Err("unexpected end of data reading string".to_string());
1489    }
1490    let s = String::from_utf8_lossy(&data[*pos..*pos + len]).to_string();
1491    *pos += len;
1492    Ok(s)
1493}
1494
1495#[cfg(test)]
1496mod tests {
1497    use super::*;
1498    use crate::config::{SemanticBackend, SemanticBackendConfig};
1499    use std::io::{Read, Write};
1500    use std::net::TcpListener;
1501    use std::thread;
1502
1503    fn start_mock_http_server<F>(handler: F) -> (String, thread::JoinHandle<()>)
1504    where
1505        F: Fn(String, String, String) -> String + Send + 'static,
1506    {
1507        let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
1508        let addr = listener.local_addr().expect("local addr");
1509        let handle = thread::spawn(move || {
1510            let (mut stream, _) = listener.accept().expect("accept request");
1511            let mut buf = Vec::new();
1512            let mut chunk = [0u8; 4096];
1513            let mut header_end = None;
1514            let mut content_length = 0usize;
1515            loop {
1516                let n = stream.read(&mut chunk).expect("read request");
1517                if n == 0 {
1518                    break;
1519                }
1520                buf.extend_from_slice(&chunk[..n]);
1521                if header_end.is_none() {
1522                    if let Some(pos) = buf.windows(4).position(|window| window == b"\r\n\r\n") {
1523                        header_end = Some(pos + 4);
1524                        let headers = String::from_utf8_lossy(&buf[..pos + 4]);
1525                        for line in headers.lines() {
1526                            if let Some(value) = line.strip_prefix("Content-Length:") {
1527                                content_length = value.trim().parse::<usize>().unwrap_or(0);
1528                            }
1529                        }
1530                    }
1531                }
1532                if let Some(end) = header_end {
1533                    if buf.len() >= end + content_length {
1534                        break;
1535                    }
1536                }
1537            }
1538
1539            let end = header_end.expect("header terminator");
1540            let request = String::from_utf8_lossy(&buf[..end]).to_string();
1541            let body = String::from_utf8_lossy(&buf[end..end + content_length]).to_string();
1542            let mut lines = request.lines();
1543            let request_line = lines.next().expect("request line").to_string();
1544            let path = request_line
1545                .split_whitespace()
1546                .nth(1)
1547                .expect("request path")
1548                .to_string();
1549            let response_body = handler(request_line, path, body);
1550            let response = format!(
1551                "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
1552                response_body.len(),
1553                response_body
1554            );
1555            stream
1556                .write_all(response.as_bytes())
1557                .expect("write response");
1558        });
1559
1560        (format!("http://{}", addr), handle)
1561    }
1562
1563    #[test]
1564    fn test_cosine_similarity_identical() {
1565        let a = vec![1.0, 0.0, 0.0];
1566        let b = vec![1.0, 0.0, 0.0];
1567        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
1568    }
1569
1570    #[test]
1571    fn test_cosine_similarity_orthogonal() {
1572        let a = vec![1.0, 0.0, 0.0];
1573        let b = vec![0.0, 1.0, 0.0];
1574        assert!(cosine_similarity(&a, &b).abs() < 0.001);
1575    }
1576
1577    #[test]
1578    fn test_cosine_similarity_opposite() {
1579        let a = vec![1.0, 0.0, 0.0];
1580        let b = vec![-1.0, 0.0, 0.0];
1581        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 0.001);
1582    }
1583
1584    #[test]
1585    fn test_serialization_roundtrip() {
1586        let mut index = SemanticIndex::new();
1587        index.entries.push(EmbeddingEntry {
1588            chunk: SemanticChunk {
1589                file: PathBuf::from("/src/main.rs"),
1590                name: "handle_request".to_string(),
1591                kind: SymbolKind::Function,
1592                start_line: 10,
1593                end_line: 25,
1594                exported: true,
1595                embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
1596                snippet: "fn handle_request() {\n  // ...\n}".to_string(),
1597            },
1598            vector: vec![0.1, 0.2, 0.3, 0.4],
1599        });
1600        index.dimension = 4;
1601        index
1602            .file_mtimes
1603            .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
1604        index.set_fingerprint(SemanticIndexFingerprint {
1605            backend: "fastembed".to_string(),
1606            model: "all-MiniLM-L6-v2".to_string(),
1607            base_url: FALLBACK_BACKEND.to_string(),
1608            dimension: 4,
1609        });
1610
1611        let bytes = index.to_bytes();
1612        let restored = SemanticIndex::from_bytes(&bytes).unwrap();
1613
1614        assert_eq!(restored.entries.len(), 1);
1615        assert_eq!(restored.entries[0].chunk.name, "handle_request");
1616        assert_eq!(restored.entries[0].vector, vec![0.1, 0.2, 0.3, 0.4]);
1617        assert_eq!(restored.dimension, 4);
1618        assert_eq!(restored.backend_label(), Some("fastembed"));
1619        assert_eq!(restored.model_label(), Some("all-MiniLM-L6-v2"));
1620    }
1621
1622    #[test]
1623    fn test_search_top_k() {
1624        let mut index = SemanticIndex::new();
1625        index.dimension = 3;
1626
1627        // Add entries with known vectors
1628        for (i, name) in ["auth", "database", "handler"].iter().enumerate() {
1629            let mut vec = vec![0.0f32; 3];
1630            vec[i] = 1.0; // orthogonal vectors
1631            index.entries.push(EmbeddingEntry {
1632                chunk: SemanticChunk {
1633                    file: PathBuf::from("/src/lib.rs"),
1634                    name: name.to_string(),
1635                    kind: SymbolKind::Function,
1636                    start_line: (i * 10 + 1) as u32,
1637                    end_line: (i * 10 + 5) as u32,
1638                    exported: true,
1639                    embed_text: format!("kind:function name:{}", name),
1640                    snippet: format!("fn {}() {{}}", name),
1641                },
1642                vector: vec,
1643            });
1644        }
1645
1646        // Query aligned with "auth" (index 0)
1647        let query = vec![0.9, 0.1, 0.0];
1648        let results = index.search(&query, 2);
1649
1650        assert_eq!(results.len(), 2);
1651        assert_eq!(results[0].name, "auth"); // highest score
1652        assert!(results[0].score > results[1].score);
1653    }
1654
1655    #[test]
1656    fn test_empty_index_search() {
1657        let index = SemanticIndex::new();
1658        let results = index.search(&[0.1, 0.2, 0.3], 10);
1659        assert!(results.is_empty());
1660    }
1661
1662    #[test]
1663    fn rejects_oversized_dimension_during_deserialization() {
1664        let mut bytes = Vec::new();
1665        bytes.push(1u8);
1666        bytes.extend_from_slice(&((MAX_DIMENSION as u32) + 1).to_le_bytes());
1667        bytes.extend_from_slice(&0u32.to_le_bytes());
1668        bytes.extend_from_slice(&0u32.to_le_bytes());
1669
1670        assert!(SemanticIndex::from_bytes(&bytes).is_err());
1671    }
1672
1673    #[test]
1674    fn rejects_oversized_entry_count_during_deserialization() {
1675        let mut bytes = Vec::new();
1676        bytes.push(1u8);
1677        bytes.extend_from_slice(&(DEFAULT_DIMENSION as u32).to_le_bytes());
1678        bytes.extend_from_slice(&((MAX_ENTRIES as u32) + 1).to_le_bytes());
1679        bytes.extend_from_slice(&0u32.to_le_bytes());
1680
1681        assert!(SemanticIndex::from_bytes(&bytes).is_err());
1682    }
1683
/// After `invalidate_file` is called for a path, the index must hold neither
/// an entry nor a recorded mtime for that path.
#[test]
fn invalidate_file_removes_entries_and_mtime() {
    let path = PathBuf::from("/src/main.rs");

    // The only chunk in the index belongs to `path`.
    let chunk = SemanticChunk {
        file: path.clone(),
        name: String::from("main"),
        kind: SymbolKind::Function,
        start_line: 0,
        end_line: 1,
        exported: false,
        embed_text: String::from("main"),
        snippet: String::from("fn main() {}"),
    };

    let mut index = SemanticIndex::new();
    index.entries.push(EmbeddingEntry {
        chunk,
        vector: vec![1.0; DEFAULT_DIMENSION],
    });
    index.file_mtimes.insert(path.clone(), SystemTime::UNIX_EPOCH);

    index.invalidate_file(&path);

    assert!(index.entries.is_empty());
    assert!(!index.file_mtimes.contains_key(&path));
}
1710
/// A dlopen-style failure message naming the ONNX Runtime shared library must
/// be classified as "runtime unavailable".
#[test]
fn detects_missing_onnx_runtime_from_dynamic_load_error() {
    let dlopen_failure = "Failed to load ONNX Runtime shared library libonnxruntime.dylib via dlopen: no such file";

    let detected = is_onnx_runtime_unavailable(dlopen_failure);

    assert!(detected);
}
1717
/// When the underlying error indicates a missing ONNX Runtime library, the
/// formatted message must lead with the install hint and still include the
/// original error text marker.
#[test]
fn formats_missing_onnx_runtime_with_install_hint() {
    let raw_error =
        "Failed to load ONNX Runtime shared library libonnxruntime.so via dlopen: no such file";

    let formatted = format_embedding_init_error(raw_error);

    let leads_with_hint = formatted.starts_with("ONNX Runtime not found. Install via:");
    assert!(leads_with_hint);

    let keeps_original = formatted.contains("Original error:");
    assert!(keeps_original);
}
1727
/// End-to-end check of the OpenAI-compatible backend against a local mock
/// server: the handler verifies a POST to `/v1/embeddings` and replies with
/// two embeddings, which `embed` must return in order.
#[test]
fn openai_compatible_backend_embeds_with_mock_server() {
    let (server_url, server_thread) =
        start_mock_http_server(|req_line, req_path, _req_body| {
            assert!(req_line.starts_with("POST "));
            assert_eq!(req_path, "/v1/embeddings");
            String::from("{\"data\":[{\"embedding\":[0.1,0.2,0.3],\"index\":0},{\"embedding\":[0.4,0.5,0.6],\"index\":1}]}")
        });

    let backend_config = SemanticBackendConfig {
        backend: SemanticBackend::OpenAiCompatible,
        model: String::from("test-embedding"),
        base_url: Some(server_url),
        api_key_env: None,
        timeout_ms: 5_000,
        max_batch_size: 64,
    };
    let mut embedder = SemanticEmbeddingModel::from_config(&backend_config).unwrap();

    let inputs = vec![String::from("hello"), String::from("world")];
    let embeddings = embedder.embed(inputs).unwrap();

    assert_eq!(embeddings, vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]]);

    // Joining surfaces any assertion failure raised inside the mock handler.
    server_thread.join().unwrap();
}
1753
/// End-to-end check of the Ollama backend against a local mock server: the
/// handler verifies a POST to `/api/embed` and replies with two embeddings,
/// which `embed` must return in order.
#[test]
fn ollama_backend_embeds_with_mock_server() {
    let (server_url, server_thread) =
        start_mock_http_server(|req_line, req_path, _req_body| {
            assert!(req_line.starts_with("POST "));
            assert_eq!(req_path, "/api/embed");
            String::from("{\"embeddings\":[[0.7,0.8,0.9],[1.0,1.1,1.2]]}")
        });

    let backend_config = SemanticBackendConfig {
        backend: SemanticBackend::Ollama,
        model: String::from("embeddinggemma"),
        base_url: Some(server_url),
        api_key_env: None,
        timeout_ms: 5_000,
        max_batch_size: 64,
    };
    let mut embedder = SemanticEmbeddingModel::from_config(&backend_config).unwrap();

    let inputs = vec![String::from("hello"), String::from("world")];
    let embeddings = embedder.embed(inputs).unwrap();

    assert_eq!(embeddings, vec![vec![0.7, 0.8, 0.9], vec![1.0, 1.1, 1.2]]);

    // Joining surfaces any assertion failure raised inside the mock handler.
    server_thread.join().unwrap();
}
1779
/// An index written to disk must load again when the caller supplies the
/// fingerprint string it was written with, and must NOT load when the caller
/// supplies a fingerprint describing a different backend/model/base URL.
#[test]
fn read_from_disk_rejects_fingerprint_mismatch() {
    let storage = tempfile::tempdir().unwrap();
    let project_key = "proj";
    let source_file = PathBuf::from("/src/main.rs");

    // Populate a tiny 3-dimensional index with a single chunk.
    let chunk = SemanticChunk {
        file: source_file.clone(),
        name: String::from("handle_request"),
        kind: SymbolKind::Function,
        start_line: 10,
        end_line: 25,
        exported: true,
        embed_text: String::from("file:src/main.rs kind:function name:handle_request"),
        snippet: String::from("fn handle_request() {}"),
    };
    let mut index = SemanticIndex::new();
    index.entries.push(EmbeddingEntry {
        chunk,
        vector: vec![0.1, 0.2, 0.3],
    });
    index.dimension = 3;
    index.file_mtimes.insert(source_file, SystemTime::UNIX_EPOCH);

    let written_fingerprint = SemanticIndexFingerprint {
        backend: String::from("openai_compatible"),
        model: String::from("test-embedding"),
        base_url: String::from("http://127.0.0.1:1234/v1"),
        dimension: 3,
    };
    index.set_fingerprint(written_fingerprint);
    index.write_to_disk(storage.path(), project_key);

    // Same fingerprint as the one stored on the index: load succeeds.
    let matching = index.fingerprint().unwrap().as_string();
    let reloaded = SemanticIndex::read_from_disk(storage.path(), project_key, Some(&matching));
    assert!(reloaded.is_some());

    // Different backend, model and base URL: load is refused.
    let other_fingerprint = SemanticIndexFingerprint {
        backend: String::from("ollama"),
        model: String::from("embeddinggemma"),
        base_url: String::from("http://127.0.0.1:11434"),
        dimension: 3,
    };
    let mismatched = other_fingerprint.as_string();
    let rejected = SemanticIndex::read_from_disk(storage.path(), project_key, Some(&mismatched));
    assert!(rejected.is_none());
}
1827}