Skip to main content

aft/
semantic_index.rs

1use crate::config::{SemanticBackend, SemanticBackendConfig};
2use crate::parser::FileParser;
3use crate::symbols::{Symbol, SymbolKind};
4
5use fastembed::{EmbeddingModel as FastembedEmbeddingModel, InitOptions, TextEmbedding};
6use reqwest::blocking::Client;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::env;
10use std::fmt::Display;
11use std::fs;
12use std::path::{Path, PathBuf};
13use std::time::Duration;
14use std::time::SystemTime;
15use url::Url;
16
// ---------------------------------------------------------------------------
// Index shape and on-disk format limits
// ---------------------------------------------------------------------------

/// Embedding width assumed when nothing else is known (all-MiniLM-L6-v2 produces 384).
const DEFAULT_DIMENSION: usize = 384;
/// Upper bound on stored entries (presumably enforced when loading a persisted index).
const MAX_ENTRIES: usize = 1_000_000;
/// Upper bound on an accepted embedding dimension (presumably enforced on load).
const MAX_DIMENSION: usize = 1024;
/// Size in bytes of one serialized vector component.
const F32_BYTES: usize = std::mem::size_of::<f32>();
/// Header size of the V1 persisted format (layout defined by the (de)serializer).
const HEADER_BYTES_V1: usize = 9;
/// Header size of the V2/V3 persisted format (layout defined by the (de)serializer).
const HEADER_BYTES_V2: usize = 13;

/// User-facing hint used when the ONNX Runtime shared library cannot be loaded.
const ONNX_RUNTIME_INSTALL_HINT: &str =
    "ONNX Runtime not found. Install via: brew install onnxruntime (macOS) or apt install libonnxruntime (Linux).";

// Persisted index format versions.
const SEMANTIC_INDEX_VERSION_V1: u8 = 1;
const SEMANTIC_INDEX_VERSION_V2: u8 = 2;
/// V3 adds subsec_nanos to the file-mtime table so staleness detection survives
/// restart round-trips on filesystems with subsecond mtime precision (APFS,
/// ext4 with nsec, NTFS). V1/V2 persisted whole-second mtimes only, which
/// caused every restart to flag ~99% of files as stale and re-embed them.
const SEMANTIC_INDEX_VERSION_V3: u8 = 3;

// ---------------------------------------------------------------------------
// HTTP embedding backends
// ---------------------------------------------------------------------------

/// Path appended to an OpenAI-compatible base URL (after `/v1`).
const DEFAULT_OPENAI_EMBEDDING_PATH: &str = "/embeddings";
/// Path appended to an Ollama base URL.
const DEFAULT_OLLAMA_EMBEDDING_PATH: &str = "/api/embed";
/// Per-request timeout used when the configured `timeout_ms` is 0.
/// Must stay below the bridge timeout (30s) to avoid bridge kills on slow backends.
const DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS: u64 = 25_000;
/// Batch size used when the configured `max_batch_size` is 0.
const DEFAULT_MAX_BATCH_SIZE: usize = 64;
/// Placeholder stored as the fingerprint `base_url` when no usable URL is configured.
const FALLBACK_BACKEND: &str = "none";
/// Total attempts (first try plus retries) for one embedding request.
const EMBEDDING_REQUEST_MAX_ATTEMPTS: usize = 3;
/// Sleep before each retry, indexed by the failed attempt; one entry per retry.
const EMBEDDING_REQUEST_BACKOFF_MS: [u64; 2] = [500, 1_000];
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct SemanticIndexFingerprint {
44    pub backend: String,
45    pub model: String,
46    #[serde(default)]
47    pub base_url: String,
48    pub dimension: usize,
49}
50
51impl SemanticIndexFingerprint {
52    fn from_config(config: &SemanticBackendConfig, dimension: usize) -> Self {
53        // Use normalized URL for fingerprinting so cosmetic differences
54        // (e.g. "http://host/v1" vs "http://host/v1/") don't cause rebuilds.
55        let base_url = config
56            .base_url
57            .as_ref()
58            .and_then(|u| normalize_base_url(u).ok())
59            .unwrap_or_else(|| FALLBACK_BACKEND.to_string());
60        Self {
61            backend: config.backend.as_str().to_string(),
62            model: config.model.clone(),
63            base_url,
64            dimension,
65        }
66    }
67
68    pub fn as_string(&self) -> String {
69        serde_json::to_string(self).unwrap_or_else(|_| String::new())
70    }
71
72    fn matches_expected(&self, expected: &str) -> bool {
73        let encoded = self.as_string();
74        !encoded.is_empty() && encoded == expected
75    }
76}
77
78enum SemanticEmbeddingEngine {
79    Fastembed(TextEmbedding),
80    OpenAiCompatible {
81        client: Client,
82        model: String,
83        base_url: String,
84        api_key: Option<String>,
85    },
86    Ollama {
87        client: Client,
88        model: String,
89        base_url: String,
90    },
91}
92
93pub struct SemanticEmbeddingModel {
94    backend: SemanticBackend,
95    model: String,
96    base_url: Option<String>,
97    timeout_ms: u64,
98    max_batch_size: usize,
99    dimension: Option<usize>,
100    engine: SemanticEmbeddingEngine,
101}
102
103pub type EmbeddingModel = SemanticEmbeddingModel;
104
/// Check that an embedding batch has one vector per input and a uniform width.
///
/// `context` prefixes every error message (e.g. `"embedding backend"`).
///
/// # Errors
/// - no vectors although `expected_count > 0`;
/// - a vector count different from `expected_count`;
/// - any vector whose length differs from the first vector's length.
fn validate_embedding_batch(
    vectors: &[Vec<f32>],
    expected_count: usize,
    context: &str,
) -> Result<(), String> {
    let actual_count = vectors.len();

    if actual_count == 0 && expected_count != 0 {
        return Err(format!(
            "{context} returned no vectors for {expected_count} inputs"
        ));
    }
    if actual_count != expected_count {
        return Err(format!(
            "{context} returned {actual_count} vectors for {expected_count} inputs"
        ));
    }

    // An empty batch (expected and received zero vectors) is trivially valid.
    let reference_len = match vectors.first() {
        Some(first) => first.len(),
        None => return Ok(()),
    };

    let mismatch = vectors
        .iter()
        .enumerate()
        .find(|(_, vector)| vector.len() != reference_len);

    match mismatch {
        Some((index, vector)) => Err(format!(
            "{context} returned inconsistent embedding dimensions: vector 0 has length {reference_len}, vector {index} has length {}",
            vector.len()
        )),
        None => Ok(()),
    }
}
139
140fn normalize_base_url(raw: &str) -> Result<String, String> {
141    let parsed = Url::parse(raw).map_err(|error| format!("invalid base_url '{raw}': {error}"))?;
142    let scheme = parsed.scheme();
143    if scheme != "http" && scheme != "https" {
144        return Err(format!(
145            "unsupported URL scheme '{}' — only http:// and https:// are allowed",
146            scheme
147        ));
148    }
149    Ok(parsed.to_string().trim_end_matches('/').to_string())
150}
151
152fn build_openai_embeddings_endpoint(base_url: &str) -> String {
153    if base_url.ends_with("/v1") {
154        format!("{base_url}{DEFAULT_OPENAI_EMBEDDING_PATH}")
155    } else {
156        format!("{base_url}/v1{}", DEFAULT_OPENAI_EMBEDDING_PATH)
157    }
158}
159
160fn build_ollama_embeddings_endpoint(base_url: &str) -> String {
161    if base_url.ends_with("/api") {
162        format!("{base_url}/embed")
163    } else {
164        format!("{base_url}{DEFAULT_OLLAMA_EMBEDDING_PATH}")
165    }
166}
167
/// Trim surrounding whitespace; treat a missing or blank value as `None`.
fn normalize_api_key(value: Option<String>) -> Option<String> {
    let trimmed = value?.trim().to_owned();
    (!trimmed.is_empty()).then_some(trimmed)
}
178
179fn is_retryable_embedding_status(status: reqwest::StatusCode) -> bool {
180    status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS
181}
182
183fn is_retryable_embedding_error(error: &reqwest::Error) -> bool {
184    error.is_connect()
185}
186
187fn sleep_before_embedding_retry(attempt_index: usize) {
188    if let Some(delay_ms) = EMBEDDING_REQUEST_BACKOFF_MS.get(attempt_index) {
189        std::thread::sleep(Duration::from_millis(*delay_ms));
190    }
191}
192
193fn send_embedding_request<F>(mut make_request: F, backend_label: &str) -> Result<String, String>
194where
195    F: FnMut() -> reqwest::blocking::RequestBuilder,
196{
197    for attempt_index in 0..EMBEDDING_REQUEST_MAX_ATTEMPTS {
198        let last_attempt = attempt_index + 1 == EMBEDDING_REQUEST_MAX_ATTEMPTS;
199
200        let response = match make_request().send() {
201            Ok(response) => response,
202            Err(error) => {
203                if !last_attempt && is_retryable_embedding_error(&error) {
204                    sleep_before_embedding_retry(attempt_index);
205                    continue;
206                }
207                return Err(format!("{backend_label} request failed: {error}"));
208            }
209        };
210
211        let status = response.status();
212        let raw = match response.text() {
213            Ok(raw) => raw,
214            Err(error) => {
215                if !last_attempt && is_retryable_embedding_error(&error) {
216                    sleep_before_embedding_retry(attempt_index);
217                    continue;
218                }
219                return Err(format!("{backend_label} response read failed: {error}"));
220            }
221        };
222
223        if status.is_success() {
224            return Ok(raw);
225        }
226
227        if !last_attempt && is_retryable_embedding_status(status) {
228            sleep_before_embedding_retry(attempt_index);
229            continue;
230        }
231
232        return Err(format!(
233            "{backend_label} request failed (HTTP {}): {}",
234            status, raw
235        ));
236    }
237
238    unreachable!("embedding request retries exhausted without returning")
239}
240
241impl SemanticEmbeddingModel {
242    pub fn from_config(config: &SemanticBackendConfig) -> Result<Self, String> {
243        let timeout_ms = if config.timeout_ms == 0 {
244            DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS
245        } else {
246            config.timeout_ms
247        };
248
249        let max_batch_size = if config.max_batch_size == 0 {
250            DEFAULT_MAX_BATCH_SIZE
251        } else {
252            config.max_batch_size
253        };
254
255        let api_key_env = normalize_api_key(config.api_key_env.clone());
256        let model = config.model.clone();
257
258        let client = Client::builder()
259            .timeout(Duration::from_millis(timeout_ms))
260            .redirect(reqwest::redirect::Policy::none())
261            .build()
262            .map_err(|error| format!("failed to configure embedding client: {error}"))?;
263
264        let engine = match config.backend {
265            SemanticBackend::Fastembed => {
266                SemanticEmbeddingEngine::Fastembed(initialize_text_embedding(&model)?)
267            }
268            SemanticBackend::OpenAiCompatible => {
269                let raw = config.base_url.as_ref().ok_or_else(|| {
270                    "base_url is required for openai_compatible backend".to_string()
271                })?;
272                let base_url = normalize_base_url(raw)?;
273
274                let api_key = match api_key_env {
275                    Some(var_name) => Some(env::var(&var_name).map_err(|_| {
276                        format!("missing api_key_env '{var_name}' for openai_compatible backend")
277                    })?),
278                    None => None,
279                };
280
281                SemanticEmbeddingEngine::OpenAiCompatible {
282                    client,
283                    model,
284                    base_url,
285                    api_key,
286                }
287            }
288            SemanticBackend::Ollama => {
289                let raw = config
290                    .base_url
291                    .as_ref()
292                    .ok_or_else(|| "base_url is required for ollama backend".to_string())?;
293                let base_url = normalize_base_url(raw)?;
294
295                SemanticEmbeddingEngine::Ollama {
296                    client,
297                    model,
298                    base_url,
299                }
300            }
301        };
302
303        Ok(Self {
304            backend: config.backend,
305            model: config.model.clone(),
306            base_url: config.base_url.clone(),
307            timeout_ms,
308            max_batch_size,
309            dimension: None,
310            engine,
311        })
312    }
313
314    pub fn backend(&self) -> SemanticBackend {
315        self.backend
316    }
317
318    pub fn model(&self) -> &str {
319        &self.model
320    }
321
322    pub fn base_url(&self) -> Option<&str> {
323        self.base_url.as_deref()
324    }
325
326    pub fn max_batch_size(&self) -> usize {
327        self.max_batch_size
328    }
329
330    pub fn timeout_ms(&self) -> u64 {
331        self.timeout_ms
332    }
333
334    pub fn fingerprint(
335        &mut self,
336        config: &SemanticBackendConfig,
337    ) -> Result<SemanticIndexFingerprint, String> {
338        let dimension = self.dimension()?;
339        Ok(SemanticIndexFingerprint::from_config(config, dimension))
340    }
341
342    pub fn dimension(&mut self) -> Result<usize, String> {
343        if let Some(dimension) = self.dimension {
344            return Ok(dimension);
345        }
346
347        let dimension = match &mut self.engine {
348            SemanticEmbeddingEngine::Fastembed(model) => {
349                let vectors = model
350                    .embed(vec!["semantic index fingerprint probe".to_string()], None)
351                    .map_err(|error| format_embedding_init_error(error.to_string()))?;
352                vectors
353                    .first()
354                    .map(|v| v.len())
355                    .ok_or_else(|| "embedding backend returned no vectors".to_string())?
356            }
357            SemanticEmbeddingEngine::OpenAiCompatible { .. } => {
358                let vectors =
359                    self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
360                vectors
361                    .first()
362                    .map(|v| v.len())
363                    .ok_or_else(|| "embedding backend returned no vectors".to_string())?
364            }
365            SemanticEmbeddingEngine::Ollama { .. } => {
366                let vectors =
367                    self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
368                vectors
369                    .first()
370                    .map(|v| v.len())
371                    .ok_or_else(|| "embedding backend returned no vectors".to_string())?
372            }
373        };
374
375        self.dimension = Some(dimension);
376        Ok(dimension)
377    }
378
379    pub fn embed(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
380        self.embed_texts(texts)
381    }
382
383    fn embed_texts(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
384        match &mut self.engine {
385            SemanticEmbeddingEngine::Fastembed(model) => model
386                .embed(texts, None::<usize>)
387                .map_err(|error| format_embedding_init_error(error.to_string()))
388                .map_err(|error| format!("failed to embed batch: {error}")),
389            SemanticEmbeddingEngine::OpenAiCompatible {
390                client,
391                model,
392                base_url,
393                api_key,
394            } => {
395                let expected_text_count = texts.len();
396                let endpoint = build_openai_embeddings_endpoint(base_url);
397                let body = serde_json::json!({
398                    "input": texts,
399                    "model": model,
400                });
401
402                let raw = send_embedding_request(
403                    || {
404                        let mut request = client
405                            .post(&endpoint)
406                            .json(&body)
407                            .header("Content-Type", "application/json");
408
409                        if let Some(api_key) = api_key {
410                            request = request.header("Authorization", format!("Bearer {api_key}"));
411                        }
412
413                        request
414                    },
415                    "openai compatible",
416                )?;
417
418                #[derive(Deserialize)]
419                struct OpenAiResponse {
420                    data: Vec<OpenAiEmbeddingResult>,
421                }
422
423                #[derive(Deserialize)]
424                struct OpenAiEmbeddingResult {
425                    embedding: Vec<f32>,
426                    index: Option<u32>,
427                }
428
429                let parsed: OpenAiResponse = serde_json::from_str(&raw)
430                    .map_err(|error| format!("invalid openai compatible response: {error}"))?;
431                if parsed.data.len() != expected_text_count {
432                    return Err(format!(
433                        "openai compatible response returned {} embeddings for {} inputs",
434                        parsed.data.len(),
435                        expected_text_count
436                    ));
437                }
438
439                let mut vectors = vec![Vec::new(); parsed.data.len()];
440                for (i, item) in parsed.data.into_iter().enumerate() {
441                    let index = item.index.unwrap_or(i as u32) as usize;
442                    if index >= vectors.len() {
443                        return Err(
444                            "openai compatible response contains invalid vector index".to_string()
445                        );
446                    }
447                    vectors[index] = item.embedding;
448                }
449
450                for vector in &vectors {
451                    if vector.is_empty() {
452                        return Err(
453                            "openai compatible response contained missing vectors".to_string()
454                        );
455                    }
456                }
457
458                self.dimension = vectors.first().map(Vec::len);
459                Ok(vectors)
460            }
461            SemanticEmbeddingEngine::Ollama {
462                client,
463                model,
464                base_url,
465            } => {
466                let expected_text_count = texts.len();
467                let endpoint = build_ollama_embeddings_endpoint(base_url);
468
469                #[derive(Serialize)]
470                struct OllamaPayload<'a> {
471                    model: &'a str,
472                    input: Vec<String>,
473                }
474
475                let payload = OllamaPayload {
476                    model,
477                    input: texts,
478                };
479
480                let raw = send_embedding_request(
481                    || {
482                        client
483                            .post(&endpoint)
484                            .json(&payload)
485                            .header("Content-Type", "application/json")
486                    },
487                    "ollama",
488                )?;
489
490                #[derive(Deserialize)]
491                struct OllamaResponse {
492                    embeddings: Vec<Vec<f32>>,
493                }
494
495                let parsed: OllamaResponse = serde_json::from_str(&raw)
496                    .map_err(|error| format!("invalid ollama response: {error}"))?;
497                if parsed.embeddings.is_empty() {
498                    return Err("ollama response returned no embeddings".to_string());
499                }
500                if parsed.embeddings.len() != expected_text_count {
501                    return Err(format!(
502                        "ollama response returned {} embeddings for {} inputs",
503                        parsed.embeddings.len(),
504                        expected_text_count
505                    ));
506                }
507
508                let vectors = parsed.embeddings;
509                for vector in &vectors {
510                    if vector.is_empty() {
511                        return Err("ollama response contained empty embeddings".to_string());
512                    }
513                }
514
515                self.dimension = vectors.first().map(Vec::len);
516                Ok(vectors)
517            }
518        }
519    }
520}
521
522/// Pre-validate ONNX Runtime by attempting a raw dlopen before ort touches it.
523/// This catches broken/incompatible .so files without risking a panic in the ort crate.
524/// Also checks the runtime version via OrtGetApiBase if available.
525pub fn pre_validate_onnx_runtime() -> Result<(), String> {
526    let dylib_path = std::env::var("ORT_DYLIB_PATH").ok();
527
528    #[cfg(any(target_os = "linux", target_os = "macos"))]
529    {
530        #[cfg(target_os = "linux")]
531        let default_name = "libonnxruntime.so";
532        #[cfg(target_os = "macos")]
533        let default_name = "libonnxruntime.dylib";
534
535        let lib_name = dylib_path.as_deref().unwrap_or(default_name);
536
537        unsafe {
538            let c_name = std::ffi::CString::new(lib_name)
539                .map_err(|e| format!("invalid library path: {}", e))?;
540            let handle = libc::dlopen(c_name.as_ptr(), libc::RTLD_NOW);
541            if handle.is_null() {
542                let err = libc::dlerror();
543                let msg = if err.is_null() {
544                    "unknown dlopen error".to_string()
545                } else {
546                    std::ffi::CStr::from_ptr(err).to_string_lossy().into_owned()
547                };
548                return Err(format!(
549                    "ONNX Runtime not found. dlopen('{}') failed: {}. \
550                     Run `bunx @cortexkit/aft-opencode@latest doctor` to diagnose.",
551                    lib_name, msg
552                ));
553            }
554
555            // Try to detect the runtime version from the file path or soname.
556            // libonnxruntime.so.1.19.0, libonnxruntime.1.24.4.dylib, etc.
557            let detected_version = detect_ort_version_from_path(lib_name);
558
559            libc::dlclose(handle);
560
561            // Check version compatibility — we need 1.24.x
562            if let Some(ref version) = detected_version {
563                let parts: Vec<&str> = version.split('.').collect();
564                if let (Some(major), Some(minor)) = (
565                    parts.first().and_then(|s| s.parse::<u32>().ok()),
566                    parts.get(1).and_then(|s| s.parse::<u32>().ok()),
567                ) {
568                    if major != 1 || minor < 20 {
569                        return Err(format!(
570                            "ONNX Runtime version mismatch: found v{} at '{}', but AFT requires v1.20+. \
571                             Solutions:\n\
572                             1. Remove the old library and restart (AFT auto-downloads the correct version):\n\
573                             {}\n\
574                             2. Or install ONNX Runtime 1.24: https://github.com/microsoft/onnxruntime/releases/tag/v1.24.0\n\
575                             3. Run `bunx @cortexkit/aft-opencode@latest doctor` for full diagnostics.",
576                            version,
577                            lib_name,
578                            suggest_removal_command(lib_name),
579                        ));
580                    }
581                }
582            }
583        }
584    }
585
586    #[cfg(target_os = "windows")]
587    {
588        // On Windows, skip pre-validation — let ort handle LoadLibrary
589        let _ = dylib_path;
590    }
591
592    Ok(())
593}
594
595/// Try to extract the ORT version from the library filename or resolved symlink.
596/// Examples: "libonnxruntime.so.1.19.0" → "1.19.0", "libonnxruntime.1.24.4.dylib" → "1.24.4"
597fn detect_ort_version_from_path(lib_path: &str) -> Option<String> {
598    let path = std::path::Path::new(lib_path);
599
600    // Try the path as given, then follow symlinks
601    for candidate in [Some(path.to_path_buf()), std::fs::canonicalize(path).ok()]
602        .into_iter()
603        .flatten()
604    {
605        if let Some(name) = candidate.file_name().and_then(|n| n.to_str()) {
606            if let Some(version) = extract_version_from_filename(name) {
607                return Some(version);
608            }
609        }
610    }
611
612    // Also check for versioned siblings in the same directory
613    if let Some(parent) = path.parent() {
614        if let Ok(entries) = std::fs::read_dir(parent) {
615            for entry in entries.flatten() {
616                if let Some(name) = entry.file_name().to_str() {
617                    if name.starts_with("libonnxruntime") {
618                        if let Some(version) = extract_version_from_filename(name) {
619                            return Some(version);
620                        }
621                    }
622                }
623            }
624        }
625    }
626
627    None
628}
629
/// Extract the first `X.Y.Z` version (ASCII digits) from a file name.
///
/// Examples: `"libonnxruntime.so.1.19.0"` → `"1.19.0"`,
/// `"libonnxruntime.1.24.4.dylib"` → `"1.24.4"`, `"libonnxruntime.so"` → `None`.
/// With more than three components, only the leftmost three are returned
/// (`"….1.2.3.4"` → `"1.2.3"`).
///
/// Hand-rolled rather than regex-based: the function runs once per directory
/// entry during sibling scans, and compiling a regex each time was wasted work.
fn extract_version_from_filename(name: &str) -> Option<String> {
    let bytes = name.as_bytes();

    // Length of the ASCII digit run beginning at byte `at` (0 when there is none).
    let digit_run = |at: usize| {
        bytes
            .get(at..)
            .map_or(0, |rest| rest.iter().take_while(|b| b.is_ascii_digit()).count())
    };

    (0..bytes.len())
        // Only the first digit of a run can start a match. A later start in the
        // same run is followed by the same characters, so it fails identically.
        .filter(|&start| {
            bytes[start].is_ascii_digit() && (start == 0 || !bytes[start - 1].is_ascii_digit())
        })
        .find_map(|start| {
            let mut end = start;
            for component in 0..3 {
                if component > 0 {
                    if bytes.get(end) != Some(&b'.') {
                        return None;
                    }
                    end += 1;
                }
                let run = digit_run(end);
                if run == 0 {
                    return None;
                }
                end += run;
            }
            // `start` and `end` sit on ASCII bytes, so both are char boundaries.
            Some(name[start..end].to_string())
        })
}
636
/// Suggest a shell command for removing an incompatible ONNX Runtime library.
///
/// System-wide installs (under `/usr/local/lib`, or a bare default library
/// name) get a platform-specific command on Linux, macOS, and Windows. Any
/// other path, or any other platform, gets `rm '<path>'`. The result is
/// indented by three spaces for the multi-line error message.
fn suggest_removal_command(lib_path: &str) -> String {
    #[cfg(target_os = "linux")]
    let system_hint = Some("   sudo rm /usr/local/lib/libonnxruntime* && sudo ldconfig");
    #[cfg(target_os = "macos")]
    let system_hint = Some("   sudo rm /usr/local/lib/libonnxruntime*");
    #[cfg(target_os = "windows")]
    let system_hint = Some("   Delete the ONNX Runtime DLL from your PATH");
    #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
    let system_hint: Option<&str> = None;

    let targets_system_copy = lib_path.starts_with("/usr/local/lib")
        || lib_path == "libonnxruntime.so"
        || lib_path == "libonnxruntime.dylib";

    match system_hint {
        Some(hint) if targets_system_copy => hint.to_string(),
        _ => format!("   rm '{}'", lib_path),
    }
}
651
652pub fn initialize_text_embedding(model: &str) -> Result<TextEmbedding, String> {
653    // Pre-validate before ort can panic on a bad library
654    pre_validate_onnx_runtime()?;
655
656    let selected_model = match model {
657        "all-MiniLM-L6-v2" | "all-minilm-l6-v2" => FastembedEmbeddingModel::AllMiniLML6V2,
658        _ => {
659            return Err(format!(
660                "unsupported fastembed model '{}'. Supported: all-MiniLM-L6-v2",
661                model
662            ))
663        }
664    };
665
666    TextEmbedding::try_new(InitOptions::new(selected_model)).map_err(format_embedding_init_error)
667}
668
/// Heuristically decide whether `message` describes a missing or unloadable ONNX Runtime.
///
/// True when either holds:
/// - the message starts (after leading whitespace) with this module's own
///   `"ONNX Runtime not found."` prefix;
/// - the message, case-insensitively, both names the runtime and mentions a
///   dynamic-loading failure.
pub fn is_onnx_runtime_unavailable(message: &str) -> bool {
    const RUNTIME_MARKERS: &[&str] = &["onnx runtime", "onnxruntime", "libonnxruntime"];
    const LOAD_FAILURE_MARKERS: &[&str] = &[
        "shared library",
        "dynamic library",
        "failed to load",
        "could not load",
        "unable to load",
        "dlopen",
        "loadlibrary",
        "no such file",
        "not found",
    ];

    // Errors produced by this module's own pre-validation use this exact prefix.
    if message.trim_start().starts_with("ONNX Runtime not found.") {
        return true;
    }

    let lowered = message.to_ascii_lowercase();
    let mentions_any = |markers: &[&str]| markers.iter().any(|marker| lowered.contains(marker));
    mentions_any(RUNTIME_MARKERS) && mentions_any(LOAD_FAILURE_MARKERS)
}
694
695fn format_embedding_init_error(error: impl Display) -> String {
696    let message = error.to_string();
697
698    if is_onnx_runtime_unavailable(&message) {
699        return format!("{ONNX_RUNTIME_INSTALL_HINT} Original error: {message}");
700    }
701
702    format!("failed to initialize semantic embedding model: {message}")
703}
704
705/// A chunk of code ready for embedding — derived from a Symbol with context enrichment
706#[derive(Debug, Clone)]
707pub struct SemanticChunk {
708    /// Absolute file path
709    pub file: PathBuf,
710    /// Symbol name
711    pub name: String,
712    /// Symbol kind (function, class, struct, etc.)
713    pub kind: SymbolKind,
714    /// Line range (0-based internally, inclusive)
715    pub start_line: u32,
716    pub end_line: u32,
717    /// Whether the symbol is exported
718    pub exported: bool,
719    /// The enriched text that gets embedded (scope + signature + body snippet)
720    pub embed_text: String,
721    /// Short code snippet for display in results
722    pub snippet: String,
723}
724
725/// A stored embedding entry — chunk metadata + vector
726#[derive(Debug)]
727struct EmbeddingEntry {
728    chunk: SemanticChunk,
729    vector: Vec<f32>,
730}
731
732/// The semantic index — stores embeddings for all symbols in a project
733#[derive(Debug)]
734pub struct SemanticIndex {
735    entries: Vec<EmbeddingEntry>,
736    /// Track which files are indexed and their mtime for staleness detection
737    file_mtimes: HashMap<PathBuf, SystemTime>,
738    /// Embedding dimension (384 for MiniLM-L6-v2)
739    dimension: usize,
740    fingerprint: Option<SemanticIndexFingerprint>,
741}
742
743/// Search result from a semantic query
744#[derive(Debug)]
745pub struct SemanticResult {
746    pub file: PathBuf,
747    pub name: String,
748    pub kind: SymbolKind,
749    pub start_line: u32,
750    pub end_line: u32,
751    pub exported: bool,
752    pub snippet: String,
753    pub score: f32,
754}
755
756impl SemanticIndex {
757    pub fn new() -> Self {
758        Self {
759            entries: Vec::new(),
760            file_mtimes: HashMap::new(),
761            dimension: DEFAULT_DIMENSION, // MiniLM-L6-v2 default
762            fingerprint: None,
763        }
764    }
765
766    /// Number of embedded symbol entries.
767    pub fn entry_count(&self) -> usize {
768        self.entries.len()
769    }
770
771    /// Human-readable status label for the index.
772    pub fn status_label(&self) -> &'static str {
773        if self.entries.is_empty() {
774            "empty"
775        } else {
776            "ready"
777        }
778    }
779
780    fn collect_chunks(
781        project_root: &Path,
782        files: &[PathBuf],
783    ) -> (Vec<SemanticChunk>, HashMap<PathBuf, SystemTime>) {
784        let mut parser = FileParser::new();
785        let mut chunks: Vec<SemanticChunk> = Vec::new();
786        let mut file_mtimes: HashMap<PathBuf, SystemTime> = HashMap::new();
787
788        for file in files {
789            let mtime = std::fs::metadata(file)
790                .and_then(|m| m.modified())
791                .unwrap_or(SystemTime::UNIX_EPOCH);
792            file_mtimes.insert(file.clone(), mtime);
793
794            let source = match std::fs::read_to_string(file) {
795                Ok(s) => s,
796                Err(_) => continue,
797            };
798
799            let symbols = match parser.extract_symbols(file) {
800                Ok(s) => s,
801                Err(_) => continue,
802            };
803            let file_chunks = symbols_to_chunks(file, &symbols, &source, project_root);
804            chunks.extend(file_chunks);
805        }
806
807        (chunks, file_mtimes)
808    }
809
810    fn build_from_chunks<F, P>(
811        chunks: Vec<SemanticChunk>,
812        file_mtimes: HashMap<PathBuf, SystemTime>,
813        embed_fn: &mut F,
814        max_batch_size: usize,
815        mut progress: Option<&mut P>,
816    ) -> Result<Self, String>
817    where
818        F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
819        P: FnMut(usize, usize),
820    {
821        let total_chunks = chunks.len();
822
823        if chunks.is_empty() {
824            return Ok(Self {
825                entries: Vec::new(),
826                file_mtimes,
827                dimension: DEFAULT_DIMENSION,
828                fingerprint: None,
829            });
830        }
831
832        // Embed in batches
833        let mut entries: Vec<EmbeddingEntry> = Vec::with_capacity(chunks.len());
834        let mut expected_dimension: Option<usize> = None;
835        let batch_size = max_batch_size.max(1);
836        for batch_start in (0..chunks.len()).step_by(batch_size) {
837            let batch_end = (batch_start + batch_size).min(chunks.len());
838            let batch_texts: Vec<String> = chunks[batch_start..batch_end]
839                .iter()
840                .map(|c| c.embed_text.clone())
841                .collect();
842
843            let vectors = embed_fn(batch_texts)?;
844            validate_embedding_batch(&vectors, batch_end - batch_start, "embedding backend")?;
845
846            // Track consistent dimension across all batches
847            if let Some(dim) = vectors.first().map(|v| v.len()) {
848                match expected_dimension {
849                    None => expected_dimension = Some(dim),
850                    Some(expected) if dim != expected => {
851                        return Err(format!(
852                            "embedding dimension changed across batches: expected {expected}, got {dim}"
853                        ));
854                    }
855                    _ => {}
856                }
857            }
858
859            for (i, vector) in vectors.into_iter().enumerate() {
860                let chunk_idx = batch_start + i;
861                entries.push(EmbeddingEntry {
862                    chunk: chunks[chunk_idx].clone(),
863                    vector,
864                });
865            }
866
867            if let Some(callback) = progress.as_mut() {
868                callback(entries.len(), total_chunks);
869            }
870        }
871
872        let dimension = entries
873            .first()
874            .map(|e| e.vector.len())
875            .unwrap_or(DEFAULT_DIMENSION);
876
877        Ok(Self {
878            entries,
879            file_mtimes,
880            dimension,
881            fingerprint: None,
882        })
883    }
884
885    /// Build the semantic index from a set of files using the provided embedding function.
886    /// `embed_fn` takes a batch of texts and returns a batch of embedding vectors.
887    pub fn build<F>(
888        project_root: &Path,
889        files: &[PathBuf],
890        embed_fn: &mut F,
891        max_batch_size: usize,
892    ) -> Result<Self, String>
893    where
894        F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
895    {
896        let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
897        Self::build_from_chunks(
898            chunks,
899            file_mtimes,
900            embed_fn,
901            max_batch_size,
902            Option::<&mut fn(usize, usize)>::None,
903        )
904    }
905
906    /// Build the semantic index and report embedding progress using entry counts.
907    pub fn build_with_progress<F, P>(
908        project_root: &Path,
909        files: &[PathBuf],
910        embed_fn: &mut F,
911        max_batch_size: usize,
912        progress: &mut P,
913    ) -> Result<Self, String>
914    where
915        F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
916        P: FnMut(usize, usize),
917    {
918        let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
919        let total_chunks = chunks.len();
920        progress(0, total_chunks);
921        Self::build_from_chunks(
922            chunks,
923            file_mtimes,
924            embed_fn,
925            max_batch_size,
926            Some(progress),
927        )
928    }
929
930    /// Search the index with a query embedding, returning top-K results sorted by relevance
931    pub fn search(&self, query_vector: &[f32], top_k: usize) -> Vec<SemanticResult> {
932        if self.entries.is_empty() || query_vector.len() != self.dimension {
933            return Vec::new();
934        }
935
936        let mut scored: Vec<(f32, usize)> = self
937            .entries
938            .iter()
939            .enumerate()
940            .map(|(i, entry)| (cosine_similarity(query_vector, &entry.vector), i))
941            .collect();
942
943        // Sort descending by score
944        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
945
946        scored
947            .into_iter()
948            .take(top_k)
949            .filter(|(score, _)| *score > 0.0)
950            .map(|(score, idx)| {
951                let entry = &self.entries[idx];
952                SemanticResult {
953                    file: entry.chunk.file.clone(),
954                    name: entry.chunk.name.clone(),
955                    kind: entry.chunk.kind.clone(),
956                    start_line: entry.chunk.start_line,
957                    end_line: entry.chunk.end_line,
958                    exported: entry.chunk.exported,
959                    snippet: entry.chunk.snippet.clone(),
960                    score,
961                }
962            })
963            .collect()
964    }
965
966    /// Number of indexed entries
967    pub fn len(&self) -> usize {
968        self.entries.len()
969    }
970
971    /// Check if a file needs re-indexing based on mtime
972    pub fn is_file_stale(&self, file: &Path) -> bool {
973        match self.file_mtimes.get(file) {
974            None => true,
975            Some(stored_mtime) => match fs::metadata(file).and_then(|m| m.modified()) {
976                Ok(current_mtime) => *stored_mtime != current_mtime,
977                Err(_) => true,
978            },
979        }
980    }
981
982    pub fn count_stale_files(&self) -> usize {
983        self.file_mtimes
984            .keys()
985            .filter(|path| self.is_file_stale(path))
986            .count()
987    }
988
989    /// Remove entries for a specific file
990    pub fn remove_file(&mut self, file: &Path) {
991        self.invalidate_file(file);
992    }
993
994    pub fn invalidate_file(&mut self, file: &Path) {
995        self.entries.retain(|e| e.chunk.file != file);
996        self.file_mtimes.remove(file);
997    }
998
999    /// Get the embedding dimension
1000    pub fn dimension(&self) -> usize {
1001        self.dimension
1002    }
1003
1004    pub fn fingerprint(&self) -> Option<&SemanticIndexFingerprint> {
1005        self.fingerprint.as_ref()
1006    }
1007
1008    pub fn backend_label(&self) -> Option<&str> {
1009        self.fingerprint.as_ref().map(|f| f.backend.as_str())
1010    }
1011
1012    pub fn model_label(&self) -> Option<&str> {
1013        self.fingerprint.as_ref().map(|f| f.model.as_str())
1014    }
1015
1016    pub fn set_fingerprint(&mut self, fingerprint: SemanticIndexFingerprint) {
1017        self.fingerprint = Some(fingerprint);
1018    }
1019
1020    /// Write the semantic index to disk using atomic temp+rename pattern
1021    pub fn write_to_disk(&self, storage_dir: &Path, project_key: &str) {
1022        // Don't persist empty indexes — they would be loaded on next startup
1023        // and prevent a fresh build that might find files.
1024        if self.entries.is_empty() {
1025            log::info!("[aft] skipping semantic index persistence (0 entries)");
1026            return;
1027        }
1028        let dir = storage_dir.join("semantic").join(project_key);
1029        if let Err(e) = fs::create_dir_all(&dir) {
1030            log::warn!("[aft] failed to create semantic cache dir: {}", e);
1031            return;
1032        }
1033        let data_path = dir.join("semantic.bin");
1034        let tmp_path = dir.join("semantic.bin.tmp");
1035        let bytes = self.to_bytes();
1036        if let Err(e) = fs::write(&tmp_path, &bytes) {
1037            log::warn!("[aft] failed to write semantic index: {}", e);
1038            let _ = fs::remove_file(&tmp_path);
1039            return;
1040        }
1041        if let Err(e) = fs::rename(&tmp_path, &data_path) {
1042            log::warn!("[aft] failed to rename semantic index: {}", e);
1043            let _ = fs::remove_file(&tmp_path);
1044            return;
1045        }
1046        log::info!(
1047            "[aft] semantic index persisted: {} entries, {:.1} KB",
1048            self.entries.len(),
1049            bytes.len() as f64 / 1024.0
1050        );
1051    }
1052
1053    /// Read the semantic index from disk
1054    pub fn read_from_disk(
1055        storage_dir: &Path,
1056        project_key: &str,
1057        expected_fingerprint: Option<&str>,
1058    ) -> Option<Self> {
1059        let data_path = storage_dir
1060            .join("semantic")
1061            .join(project_key)
1062            .join("semantic.bin");
1063        let file_len = usize::try_from(fs::metadata(&data_path).ok()?.len()).ok()?;
1064        if file_len < HEADER_BYTES_V1 {
1065            log::warn!(
1066                "[aft] corrupt semantic index (too small: {} bytes), removing",
1067                file_len
1068            );
1069            let _ = fs::remove_file(&data_path);
1070            return None;
1071        }
1072
1073        let bytes = fs::read(&data_path).ok()?;
1074        match Self::from_bytes(&bytes) {
1075            Ok(index) => {
1076                if index.entries.is_empty() {
1077                    log::info!("[aft] cached semantic index is empty, will rebuild");
1078                    let _ = fs::remove_file(&data_path);
1079                    return None;
1080                }
1081                if let Some(expected) = expected_fingerprint {
1082                    let matches = index
1083                        .fingerprint()
1084                        .map(|fingerprint| fingerprint.matches_expected(expected))
1085                        .unwrap_or(false);
1086                    if !matches {
1087                        log::info!("[aft] cached semantic index fingerprint mismatch, rebuilding");
1088                        let _ = fs::remove_file(&data_path);
1089                        return None;
1090                    }
1091                }
1092                log::info!(
1093                    "[aft] loaded semantic index from disk: {} entries",
1094                    index.entries.len()
1095                );
1096                Some(index)
1097            }
1098            Err(e) => {
1099                log::warn!("[aft] corrupt semantic index, rebuilding: {}", e);
1100                let _ = fs::remove_file(&data_path);
1101                None
1102            }
1103        }
1104    }
1105
1106    /// Serialize the index to bytes for disk persistence
1107    pub fn to_bytes(&self) -> Vec<u8> {
1108        let mut buf = Vec::new();
1109        let fingerprint_bytes = self.fingerprint.as_ref().and_then(|fingerprint| {
1110            let encoded = fingerprint.as_string();
1111            if encoded.is_empty() {
1112                None
1113            } else {
1114                Some(encoded.into_bytes())
1115            }
1116        });
1117
1118        // Header: version(1) + dimension(4) + entry_count(4) + fingerprint_len(4) + fingerprint
1119        //
1120        // V3 (v0.15.2+) is the single write format. Layout:
1121        //   - fingerprint is always represented (absent ⇒ fingerprint_len=0,
1122        //     no bytes follow). Uniform format simplifies the reader.
1123        //   - mtimes stored as secs(u64) + subsec_nanos(u32). Preserves full
1124        //     APFS/ext4/NTFS precision so staleness checks survive restart
1125        //     round-trips.
1126        //
1127        // V1/V2 remain readable for backward compatibility (see from_bytes).
1128        let version = SEMANTIC_INDEX_VERSION_V3;
1129        buf.push(version);
1130        buf.extend_from_slice(&(self.dimension as u32).to_le_bytes());
1131        buf.extend_from_slice(&(self.entries.len() as u32).to_le_bytes());
1132        let fp_bytes_ref: &[u8] = fingerprint_bytes.as_deref().unwrap_or(&[]);
1133        buf.extend_from_slice(&(fp_bytes_ref.len() as u32).to_le_bytes());
1134        buf.extend_from_slice(fp_bytes_ref);
1135
1136        // File mtime table: count(4) + entries
1137        // V3 layout per entry: path_len(4) + path + secs(8) + subsec_nanos(4)
1138        buf.extend_from_slice(&(self.file_mtimes.len() as u32).to_le_bytes());
1139        for (path, mtime) in &self.file_mtimes {
1140            let path_bytes = path.to_string_lossy().as_bytes().to_vec();
1141            buf.extend_from_slice(&(path_bytes.len() as u32).to_le_bytes());
1142            buf.extend_from_slice(&path_bytes);
1143            let duration = mtime
1144                .duration_since(SystemTime::UNIX_EPOCH)
1145                .unwrap_or_default();
1146            buf.extend_from_slice(&duration.as_secs().to_le_bytes());
1147            buf.extend_from_slice(&duration.subsec_nanos().to_le_bytes());
1148        }
1149
1150        // Entries: each is metadata + vector
1151        for entry in &self.entries {
1152            let c = &entry.chunk;
1153
1154            // File path
1155            let file_bytes = c.file.to_string_lossy().as_bytes().to_vec();
1156            buf.extend_from_slice(&(file_bytes.len() as u32).to_le_bytes());
1157            buf.extend_from_slice(&file_bytes);
1158
1159            // Name
1160            let name_bytes = c.name.as_bytes();
1161            buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
1162            buf.extend_from_slice(name_bytes);
1163
1164            // Kind (1 byte)
1165            buf.push(symbol_kind_to_u8(&c.kind));
1166
1167            // Lines + exported
1168            buf.extend_from_slice(&(c.start_line as u32).to_le_bytes());
1169            buf.extend_from_slice(&(c.end_line as u32).to_le_bytes());
1170            buf.push(c.exported as u8);
1171
1172            // Snippet
1173            let snippet_bytes = c.snippet.as_bytes();
1174            buf.extend_from_slice(&(snippet_bytes.len() as u32).to_le_bytes());
1175            buf.extend_from_slice(snippet_bytes);
1176
1177            // Embed text
1178            let embed_bytes = c.embed_text.as_bytes();
1179            buf.extend_from_slice(&(embed_bytes.len() as u32).to_le_bytes());
1180            buf.extend_from_slice(embed_bytes);
1181
1182            // Vector (f32 array)
1183            for &val in &entry.vector {
1184                buf.extend_from_slice(&val.to_le_bytes());
1185            }
1186        }
1187
1188        buf
1189    }
1190
1191    /// Deserialize the index from bytes
1192    pub fn from_bytes(data: &[u8]) -> Result<Self, String> {
1193        let mut pos = 0;
1194
1195        if data.len() < HEADER_BYTES_V1 {
1196            return Err("data too short".to_string());
1197        }
1198
1199        let version = data[pos];
1200        pos += 1;
1201        if version != SEMANTIC_INDEX_VERSION_V1
1202            && version != SEMANTIC_INDEX_VERSION_V2
1203            && version != SEMANTIC_INDEX_VERSION_V3
1204        {
1205            return Err(format!("unsupported version: {}", version));
1206        }
1207        // V2 and V3 share the same header layout (V3 only differs in the
1208        // per-mtime entry layout): version(1) + dimension(4) + entry_count(4)
1209        // + fingerprint_len(4) + fingerprint bytes.
1210        if (version == SEMANTIC_INDEX_VERSION_V2 || version == SEMANTIC_INDEX_VERSION_V3)
1211            && data.len() < HEADER_BYTES_V2
1212        {
1213            return Err("data too short for semantic index v2/v3 header".to_string());
1214        }
1215
1216        let dimension = read_u32(data, &mut pos)? as usize;
1217        let entry_count = read_u32(data, &mut pos)? as usize;
1218        if dimension == 0 || dimension > MAX_DIMENSION {
1219            return Err(format!("invalid embedding dimension: {}", dimension));
1220        }
1221        if entry_count > MAX_ENTRIES {
1222            return Err(format!("too many semantic index entries: {}", entry_count));
1223        }
1224
1225        // Fingerprint handling:
1226        //   - V1: no fingerprint field at all.
1227        //   - V2: fingerprint_len + fingerprint bytes; always present (writer
1228        //     only emitted V2 when fingerprint was Some).
1229        //   - V3: fingerprint_len always present; fingerprint_len==0 ⇒ None.
1230        let has_fingerprint_field =
1231            version == SEMANTIC_INDEX_VERSION_V2 || version == SEMANTIC_INDEX_VERSION_V3;
1232        let fingerprint = if has_fingerprint_field {
1233            let fingerprint_len = read_u32(data, &mut pos)? as usize;
1234            if pos + fingerprint_len > data.len() {
1235                return Err("unexpected end of data reading fingerprint".to_string());
1236            }
1237            if fingerprint_len == 0 {
1238                None
1239            } else {
1240                let raw = String::from_utf8_lossy(&data[pos..pos + fingerprint_len]).to_string();
1241                pos += fingerprint_len;
1242                Some(
1243                    serde_json::from_str::<SemanticIndexFingerprint>(&raw)
1244                        .map_err(|error| format!("invalid semantic fingerprint: {error}"))?,
1245                )
1246            }
1247        } else {
1248            None
1249        };
1250
1251        // File mtimes
1252        let mtime_count = read_u32(data, &mut pos)? as usize;
1253        if mtime_count > MAX_ENTRIES {
1254            return Err(format!("too many semantic file mtimes: {}", mtime_count));
1255        }
1256
1257        let vector_bytes = entry_count
1258            .checked_mul(dimension)
1259            .and_then(|count| count.checked_mul(F32_BYTES))
1260            .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
1261        if vector_bytes > data.len().saturating_sub(pos) {
1262            return Err("semantic index vectors exceed available data".to_string());
1263        }
1264
1265        let mut file_mtimes = HashMap::with_capacity(mtime_count);
1266        for _ in 0..mtime_count {
1267            let path = read_string(data, &mut pos)?;
1268            let secs = read_u64(data, &mut pos)?;
1269            // V3 persists subsec_nanos alongside secs so staleness checks
1270            // survive restart round-trips. V1/V2 load with 0 nanos, which
1271            // causes one rebuild on upgrade (they never matched live APFS
1272            // mtimes anyway — the bug v0.15.2 fixes). After that rebuild,
1273            // the cache is persisted as V3 and stabilises.
1274            let nanos = if version == SEMANTIC_INDEX_VERSION_V3 {
1275                read_u32(data, &mut pos)?
1276            } else {
1277                0
1278            };
1279            // Hardening against corrupt / maliciously crafted cache files
1280            // (v0.15.2). `Duration::new(secs, nanos)` can panic when the
1281            // nanosecond carry overflows the second counter, and
1282            // `SystemTime + Duration` can panic on carry past the platform's
1283            // upper bound. Explicit validation keeps a corrupted semantic.bin
1284            // from taking down the whole aft process.
1285            if nanos >= 1_000_000_000 {
1286                return Err(format!(
1287                    "invalid semantic mtime: nanos {} >= 1_000_000_000",
1288                    nanos
1289                ));
1290            }
1291            let duration = std::time::Duration::new(secs, nanos);
1292            let mtime = SystemTime::UNIX_EPOCH
1293                .checked_add(duration)
1294                .ok_or_else(|| {
1295                    format!(
1296                        "invalid semantic mtime: secs={} nanos={} overflows SystemTime",
1297                        secs, nanos
1298                    )
1299                })?;
1300            file_mtimes.insert(PathBuf::from(path), mtime);
1301        }
1302
1303        // Entries
1304        let mut entries = Vec::with_capacity(entry_count);
1305        for _ in 0..entry_count {
1306            let file = PathBuf::from(read_string(data, &mut pos)?);
1307            let name = read_string(data, &mut pos)?;
1308
1309            if pos >= data.len() {
1310                return Err("unexpected end of data".to_string());
1311            }
1312            let kind = u8_to_symbol_kind(data[pos]);
1313            pos += 1;
1314
1315            let start_line = read_u32(data, &mut pos)?;
1316            let end_line = read_u32(data, &mut pos)?;
1317
1318            if pos >= data.len() {
1319                return Err("unexpected end of data".to_string());
1320            }
1321            let exported = data[pos] != 0;
1322            pos += 1;
1323
1324            let snippet = read_string(data, &mut pos)?;
1325            let embed_text = read_string(data, &mut pos)?;
1326
1327            // Vector
1328            let vec_bytes = dimension
1329                .checked_mul(F32_BYTES)
1330                .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
1331            if pos + vec_bytes > data.len() {
1332                return Err("unexpected end of data reading vector".to_string());
1333            }
1334            let mut vector = Vec::with_capacity(dimension);
1335            for _ in 0..dimension {
1336                let bytes = [data[pos], data[pos + 1], data[pos + 2], data[pos + 3]];
1337                vector.push(f32::from_le_bytes(bytes));
1338                pos += 4;
1339            }
1340
1341            entries.push(EmbeddingEntry {
1342                chunk: SemanticChunk {
1343                    file,
1344                    name,
1345                    kind,
1346                    start_line,
1347                    end_line,
1348                    exported,
1349                    embed_text,
1350                    snippet,
1351                },
1352                vector,
1353            });
1354        }
1355
1356        Ok(Self {
1357            entries,
1358            file_mtimes,
1359            dimension,
1360            fingerprint,
1361        })
1362    }
1363}
1364
1365/// Build enriched embedding text from a symbol with cAST-style context
1366fn build_embed_text(symbol: &Symbol, source: &str, file: &Path, project_root: &Path) -> String {
1367    let relative = file
1368        .strip_prefix(project_root)
1369        .unwrap_or(file)
1370        .to_string_lossy();
1371
1372    let kind_label = match &symbol.kind {
1373        SymbolKind::Function => "function",
1374        SymbolKind::Class => "class",
1375        SymbolKind::Method => "method",
1376        SymbolKind::Struct => "struct",
1377        SymbolKind::Interface => "interface",
1378        SymbolKind::Enum => "enum",
1379        SymbolKind::TypeAlias => "type",
1380        SymbolKind::Variable => "variable",
1381        SymbolKind::Heading => "heading",
1382    };
1383
1384    // Build: "file:relative/path kind:function name:validateAuth signature:fn validateAuth(token: &str) -> bool"
1385    let mut text = format!("file:{} kind:{} name:{}", relative, kind_label, symbol.name);
1386
1387    if let Some(sig) = &symbol.signature {
1388        text.push_str(&format!(" signature:{}", sig));
1389    }
1390
1391    // Add body snippet (first ~300 chars of symbol body)
1392    let lines: Vec<&str> = source.lines().collect();
1393    let start = (symbol.range.start_line.saturating_sub(1) as usize).min(lines.len()); // 1-based to 0-based
1394    let end = (symbol.range.end_line as usize).min(lines.len()); // 1-based inclusive
1395    if start < end {
1396        let body: String = lines[start..end]
1397            .iter()
1398            .take(15) // max 15 lines
1399            .copied()
1400            .collect::<Vec<&str>>()
1401            .join("\n");
1402        let snippet = if body.len() > 300 {
1403            format!("{}...", &body[..body.floor_char_boundary(300)])
1404        } else {
1405            body
1406        };
1407        text.push_str(&format!(" body:{}", snippet));
1408    }
1409
1410    text
1411}
1412
1413/// Build a display snippet from a symbol's source
1414fn build_snippet(symbol: &Symbol, source: &str) -> String {
1415    let lines: Vec<&str> = source.lines().collect();
1416    let start = (symbol.range.start_line.saturating_sub(1) as usize).min(lines.len());
1417    let end = (symbol.range.end_line as usize).min(lines.len());
1418    if start < end {
1419        let snippet_lines: Vec<&str> = lines[start..end].iter().take(5).copied().collect();
1420        let mut snippet = snippet_lines.join("\n");
1421        if end - start > 5 {
1422            snippet.push_str("\n  ...");
1423        }
1424        if snippet.len() > 300 {
1425            snippet = format!("{}...", &snippet[..snippet.floor_char_boundary(300)]);
1426        }
1427        snippet
1428    } else {
1429        String::new()
1430    }
1431}
1432
1433/// Convert symbols to semantic chunks with enriched context
1434fn symbols_to_chunks(
1435    file: &Path,
1436    symbols: &[Symbol],
1437    source: &str,
1438    project_root: &Path,
1439) -> Vec<SemanticChunk> {
1440    let mut chunks = Vec::new();
1441
1442    for symbol in symbols {
1443        // Skip very small symbols (single-line variables, etc.)
1444        let line_count = symbol
1445            .range
1446            .end_line
1447            .saturating_sub(symbol.range.start_line)
1448            + 1;
1449        if line_count < 2 && !matches!(symbol.kind, SymbolKind::Variable) {
1450            continue;
1451        }
1452
1453        let embed_text = build_embed_text(symbol, source, file, project_root);
1454        let snippet = build_snippet(symbol, source);
1455
1456        chunks.push(SemanticChunk {
1457            file: file.to_path_buf(),
1458            name: symbol.name.clone(),
1459            kind: symbol.kind.clone(),
1460            start_line: symbol.range.start_line,
1461            end_line: symbol.range.end_line,
1462            exported: symbol.exported,
1463            embed_text,
1464            snippet,
1465        });
1466
1467        // Note: Nested symbols are handled separately by the outline system
1468        // Each symbol is indexed individually
1469    }
1470
1471    chunks
1472}
1473
/// Cosine similarity between two vectors.
///
/// Returns `0.0` when the lengths differ or either vector has zero magnitude.
/// Non-finite components propagate, so the result may be NaN.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // One pass over both slices. Accumulation order matches a plain index loop, so
    // results are bit-identical, and zipping avoids per-element bounds checks.
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
1497
1498// Serialization helpers
1499fn symbol_kind_to_u8(kind: &SymbolKind) -> u8 {
1500    match kind {
1501        SymbolKind::Function => 0,
1502        SymbolKind::Class => 1,
1503        SymbolKind::Method => 2,
1504        SymbolKind::Struct => 3,
1505        SymbolKind::Interface => 4,
1506        SymbolKind::Enum => 5,
1507        SymbolKind::TypeAlias => 6,
1508        SymbolKind::Variable => 7,
1509        SymbolKind::Heading => 8,
1510    }
1511}
1512
1513fn u8_to_symbol_kind(v: u8) -> SymbolKind {
1514    match v {
1515        0 => SymbolKind::Function,
1516        1 => SymbolKind::Class,
1517        2 => SymbolKind::Method,
1518        3 => SymbolKind::Struct,
1519        4 => SymbolKind::Interface,
1520        5 => SymbolKind::Enum,
1521        6 => SymbolKind::TypeAlias,
1522        7 => SymbolKind::Variable,
1523        _ => SymbolKind::Heading,
1524    }
1525}
1526
/// Read a little-endian `u32` at `*pos`, advancing `*pos` by 4 on success.
///
/// Returns an error and leaves `*pos` unchanged when fewer than 4 bytes remain.
/// The bounds check uses `checked_add`, so it cannot overflow.
fn read_u32(data: &[u8], pos: &mut usize) -> Result<u32, String> {
    let end = match pos.checked_add(4) {
        Some(end) if end <= data.len() => end,
        _ => return Err("unexpected end of data reading u32".to_string()),
    };
    let b = &data[*pos..end];
    let val = u32::from_le_bytes([b[0], b[1], b[2], b[3]]);
    *pos = end;
    Ok(val)
}
1535
/// Read a little-endian `u64` at `*pos`, advancing `*pos` by 8 on success.
///
/// Returns an error and leaves `*pos` unchanged when fewer than 8 bytes remain.
/// The bounds check uses `checked_add`, so it cannot overflow.
fn read_u64(data: &[u8], pos: &mut usize) -> Result<u64, String> {
    let end = match pos.checked_add(8) {
        Some(end) if end <= data.len() => end,
        _ => return Err("unexpected end of data reading u64".to_string()),
    };
    let mut raw = [0u8; 8];
    raw.copy_from_slice(&data[*pos..end]);
    *pos = end;
    Ok(u64::from_le_bytes(raw))
}
1544
1545fn read_string(data: &[u8], pos: &mut usize) -> Result<String, String> {
1546    let len = read_u32(data, pos)? as usize;
1547    if *pos + len > data.len() {
1548        return Err("unexpected end of data reading string".to_string());
1549    }
1550    let s = String::from_utf8_lossy(&data[*pos..*pos + len]).to_string();
1551    *pos += len;
1552    Ok(s)
1553}
1554
1555#[cfg(test)]
1556mod tests {
1557    use super::*;
1558    use crate::config::{SemanticBackend, SemanticBackendConfig};
1559    use std::io::{Read, Write};
1560    use std::net::TcpListener;
1561    use std::thread;
1562
1563    fn start_mock_http_server<F>(handler: F) -> (String, thread::JoinHandle<()>)
1564    where
1565        F: Fn(String, String, String) -> String + Send + 'static,
1566    {
1567        let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
1568        let addr = listener.local_addr().expect("local addr");
1569        let handle = thread::spawn(move || {
1570            let (mut stream, _) = listener.accept().expect("accept request");
1571            let mut buf = Vec::new();
1572            let mut chunk = [0u8; 4096];
1573            let mut header_end = None;
1574            let mut content_length = 0usize;
1575            loop {
1576                let n = stream.read(&mut chunk).expect("read request");
1577                if n == 0 {
1578                    break;
1579                }
1580                buf.extend_from_slice(&chunk[..n]);
1581                if header_end.is_none() {
1582                    if let Some(pos) = buf.windows(4).position(|window| window == b"\r\n\r\n") {
1583                        header_end = Some(pos + 4);
1584                        let headers = String::from_utf8_lossy(&buf[..pos + 4]);
1585                        for line in headers.lines() {
1586                            if let Some(value) = line.strip_prefix("Content-Length:") {
1587                                content_length = value.trim().parse::<usize>().unwrap_or(0);
1588                            }
1589                        }
1590                    }
1591                }
1592                if let Some(end) = header_end {
1593                    if buf.len() >= end + content_length {
1594                        break;
1595                    }
1596                }
1597            }
1598
1599            let end = header_end.expect("header terminator");
1600            let request = String::from_utf8_lossy(&buf[..end]).to_string();
1601            let body = String::from_utf8_lossy(&buf[end..end + content_length]).to_string();
1602            let mut lines = request.lines();
1603            let request_line = lines.next().expect("request line").to_string();
1604            let path = request_line
1605                .split_whitespace()
1606                .nth(1)
1607                .expect("request path")
1608                .to_string();
1609            let response_body = handler(request_line, path, body);
1610            let response = format!(
1611                "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
1612                response_body.len(),
1613                response_body
1614            );
1615            stream
1616                .write_all(response.as_bytes())
1617                .expect("write response");
1618        });
1619
1620        (format!("http://{}", addr), handle)
1621    }
1622
1623    #[test]
1624    fn test_cosine_similarity_identical() {
1625        let a = vec![1.0, 0.0, 0.0];
1626        let b = vec![1.0, 0.0, 0.0];
1627        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
1628    }
1629
1630    #[test]
1631    fn test_cosine_similarity_orthogonal() {
1632        let a = vec![1.0, 0.0, 0.0];
1633        let b = vec![0.0, 1.0, 0.0];
1634        assert!(cosine_similarity(&a, &b).abs() < 0.001);
1635    }
1636
1637    #[test]
1638    fn test_cosine_similarity_opposite() {
1639        let a = vec![1.0, 0.0, 0.0];
1640        let b = vec![-1.0, 0.0, 0.0];
1641        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 0.001);
1642    }
1643
1644    #[test]
1645    fn test_serialization_roundtrip() {
1646        let mut index = SemanticIndex::new();
1647        index.entries.push(EmbeddingEntry {
1648            chunk: SemanticChunk {
1649                file: PathBuf::from("/src/main.rs"),
1650                name: "handle_request".to_string(),
1651                kind: SymbolKind::Function,
1652                start_line: 10,
1653                end_line: 25,
1654                exported: true,
1655                embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
1656                snippet: "fn handle_request() {\n  // ...\n}".to_string(),
1657            },
1658            vector: vec![0.1, 0.2, 0.3, 0.4],
1659        });
1660        index.dimension = 4;
1661        index
1662            .file_mtimes
1663            .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
1664        index.set_fingerprint(SemanticIndexFingerprint {
1665            backend: "fastembed".to_string(),
1666            model: "all-MiniLM-L6-v2".to_string(),
1667            base_url: FALLBACK_BACKEND.to_string(),
1668            dimension: 4,
1669        });
1670
1671        let bytes = index.to_bytes();
1672        let restored = SemanticIndex::from_bytes(&bytes).unwrap();
1673
1674        assert_eq!(restored.entries.len(), 1);
1675        assert_eq!(restored.entries[0].chunk.name, "handle_request");
1676        assert_eq!(restored.entries[0].vector, vec![0.1, 0.2, 0.3, 0.4]);
1677        assert_eq!(restored.dimension, 4);
1678        assert_eq!(restored.backend_label(), Some("fastembed"));
1679        assert_eq!(restored.model_label(), Some("all-MiniLM-L6-v2"));
1680    }
1681
1682    #[test]
1683    fn test_search_top_k() {
1684        let mut index = SemanticIndex::new();
1685        index.dimension = 3;
1686
1687        // Add entries with known vectors
1688        for (i, name) in ["auth", "database", "handler"].iter().enumerate() {
1689            let mut vec = vec![0.0f32; 3];
1690            vec[i] = 1.0; // orthogonal vectors
1691            index.entries.push(EmbeddingEntry {
1692                chunk: SemanticChunk {
1693                    file: PathBuf::from("/src/lib.rs"),
1694                    name: name.to_string(),
1695                    kind: SymbolKind::Function,
1696                    start_line: (i * 10 + 1) as u32,
1697                    end_line: (i * 10 + 5) as u32,
1698                    exported: true,
1699                    embed_text: format!("kind:function name:{}", name),
1700                    snippet: format!("fn {}() {{}}", name),
1701                },
1702                vector: vec,
1703            });
1704        }
1705
1706        // Query aligned with "auth" (index 0)
1707        let query = vec![0.9, 0.1, 0.0];
1708        let results = index.search(&query, 2);
1709
1710        assert_eq!(results.len(), 2);
1711        assert_eq!(results[0].name, "auth"); // highest score
1712        assert!(results[0].score > results[1].score);
1713    }
1714
1715    #[test]
1716    fn test_empty_index_search() {
1717        let index = SemanticIndex::new();
1718        let results = index.search(&[0.1, 0.2, 0.3], 10);
1719        assert!(results.is_empty());
1720    }
1721
1722    #[test]
1723    fn rejects_oversized_dimension_during_deserialization() {
1724        let mut bytes = Vec::new();
1725        bytes.push(1u8);
1726        bytes.extend_from_slice(&((MAX_DIMENSION as u32) + 1).to_le_bytes());
1727        bytes.extend_from_slice(&0u32.to_le_bytes());
1728        bytes.extend_from_slice(&0u32.to_le_bytes());
1729
1730        assert!(SemanticIndex::from_bytes(&bytes).is_err());
1731    }
1732
1733    #[test]
1734    fn rejects_oversized_entry_count_during_deserialization() {
1735        let mut bytes = Vec::new();
1736        bytes.push(1u8);
1737        bytes.extend_from_slice(&(DEFAULT_DIMENSION as u32).to_le_bytes());
1738        bytes.extend_from_slice(&((MAX_ENTRIES as u32) + 1).to_le_bytes());
1739        bytes.extend_from_slice(&0u32.to_le_bytes());
1740
1741        assert!(SemanticIndex::from_bytes(&bytes).is_err());
1742    }
1743
1744    #[test]
1745    fn invalidate_file_removes_entries_and_mtime() {
1746        let target = PathBuf::from("/src/main.rs");
1747        let mut index = SemanticIndex::new();
1748        index.entries.push(EmbeddingEntry {
1749            chunk: SemanticChunk {
1750                file: target.clone(),
1751                name: "main".to_string(),
1752                kind: SymbolKind::Function,
1753                start_line: 0,
1754                end_line: 1,
1755                exported: false,
1756                embed_text: "main".to_string(),
1757                snippet: "fn main() {}".to_string(),
1758            },
1759            vector: vec![1.0; DEFAULT_DIMENSION],
1760        });
1761        index
1762            .file_mtimes
1763            .insert(target.clone(), SystemTime::UNIX_EPOCH);
1764
1765        index.invalidate_file(&target);
1766
1767        assert!(index.entries.is_empty());
1768        assert!(!index.file_mtimes.contains_key(&target));
1769    }
1770
1771    #[test]
1772    fn detects_missing_onnx_runtime_from_dynamic_load_error() {
1773        let message = "Failed to load ONNX Runtime shared library libonnxruntime.dylib via dlopen: no such file";
1774
1775        assert!(is_onnx_runtime_unavailable(message));
1776    }
1777
1778    #[test]
1779    fn formats_missing_onnx_runtime_with_install_hint() {
1780        let message = format_embedding_init_error(
1781            "Failed to load ONNX Runtime shared library libonnxruntime.so via dlopen: no such file",
1782        );
1783
1784        assert!(message.starts_with("ONNX Runtime not found. Install via:"));
1785        assert!(message.contains("Original error:"));
1786    }
1787
1788    #[test]
1789    fn openai_compatible_backend_embeds_with_mock_server() {
1790        let (base_url, handle) = start_mock_http_server(|request_line, path, _body| {
1791            assert!(request_line.starts_with("POST "));
1792            assert_eq!(path, "/v1/embeddings");
1793            "{\"data\":[{\"embedding\":[0.1,0.2,0.3],\"index\":0},{\"embedding\":[0.4,0.5,0.6],\"index\":1}]}".to_string()
1794        });
1795
1796        let config = SemanticBackendConfig {
1797            backend: SemanticBackend::OpenAiCompatible,
1798            model: "test-embedding".to_string(),
1799            base_url: Some(base_url),
1800            api_key_env: None,
1801            timeout_ms: 5_000,
1802            max_batch_size: 64,
1803        };
1804
1805        let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();
1806        let vectors = model
1807            .embed(vec!["hello".to_string(), "world".to_string()])
1808            .unwrap();
1809
1810        assert_eq!(vectors, vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]]);
1811        handle.join().unwrap();
1812    }
1813
1814    #[test]
1815    fn ollama_backend_embeds_with_mock_server() {
1816        let (base_url, handle) = start_mock_http_server(|request_line, path, _body| {
1817            assert!(request_line.starts_with("POST "));
1818            assert_eq!(path, "/api/embed");
1819            "{\"embeddings\":[[0.7,0.8,0.9],[1.0,1.1,1.2]]}".to_string()
1820        });
1821
1822        let config = SemanticBackendConfig {
1823            backend: SemanticBackend::Ollama,
1824            model: "embeddinggemma".to_string(),
1825            base_url: Some(base_url),
1826            api_key_env: None,
1827            timeout_ms: 5_000,
1828            max_batch_size: 64,
1829        };
1830
1831        let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();
1832        let vectors = model
1833            .embed(vec!["hello".to_string(), "world".to_string()])
1834            .unwrap();
1835
1836        assert_eq!(vectors, vec![vec![0.7, 0.8, 0.9], vec![1.0, 1.1, 1.2]]);
1837        handle.join().unwrap();
1838    }
1839
1840    #[test]
1841    fn read_from_disk_rejects_fingerprint_mismatch() {
1842        let storage = tempfile::tempdir().unwrap();
1843        let project_key = "proj";
1844
1845        let mut index = SemanticIndex::new();
1846        index.entries.push(EmbeddingEntry {
1847            chunk: SemanticChunk {
1848                file: PathBuf::from("/src/main.rs"),
1849                name: "handle_request".to_string(),
1850                kind: SymbolKind::Function,
1851                start_line: 10,
1852                end_line: 25,
1853                exported: true,
1854                embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
1855                snippet: "fn handle_request() {}".to_string(),
1856            },
1857            vector: vec![0.1, 0.2, 0.3],
1858        });
1859        index.dimension = 3;
1860        index
1861            .file_mtimes
1862            .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
1863        index.set_fingerprint(SemanticIndexFingerprint {
1864            backend: "openai_compatible".to_string(),
1865            model: "test-embedding".to_string(),
1866            base_url: "http://127.0.0.1:1234/v1".to_string(),
1867            dimension: 3,
1868        });
1869        index.write_to_disk(storage.path(), project_key);
1870
1871        let matching = index.fingerprint().unwrap().as_string();
1872        assert!(
1873            SemanticIndex::read_from_disk(storage.path(), project_key, Some(&matching)).is_some()
1874        );
1875
1876        let mismatched = SemanticIndexFingerprint {
1877            backend: "ollama".to_string(),
1878            model: "embeddinggemma".to_string(),
1879            base_url: "http://127.0.0.1:11434".to_string(),
1880            dimension: 3,
1881        }
1882        .as_string();
1883        assert!(
1884            SemanticIndex::read_from_disk(storage.path(), project_key, Some(&mismatched)).is_none()
1885        );
1886    }
1887}