1use crate::config::{SemanticBackend, SemanticBackendConfig};
2use crate::parser::FileParser;
3use crate::symbols::{Symbol, SymbolKind};
4
5use fastembed::{EmbeddingModel as FastembedEmbeddingModel, InitOptions, TextEmbedding};
6use reqwest::blocking::Client;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::env;
10use std::fmt::Display;
11use std::fs;
12use std::path::{Path, PathBuf};
13use std::time::Duration;
14use std::time::SystemTime;
15use url::Url;
16
/// Vector length assumed for an index that has no entries yet.
const DEFAULT_DIMENSION: usize = 384;
/// Upper bound on entries (and on file-mtime records) accepted when decoding a persisted index.
const MAX_ENTRIES: usize = 1_000_000;
/// Upper bound on the embedding dimension accepted when decoding a persisted index.
///
/// This must cover every model the supported backends can return. OpenAI-compatible
/// embedding models commonly produce 1536 or 3072 dimensions. The previous limit of 1024
/// made such indexes fail to load ("invalid embedding dimension"), so they were deleted
/// and rebuilt on every start. Memory stays bounded because the decoder also checks
/// `entries * dimension * F32_BYTES` against the bytes actually present in the file.
const MAX_DIMENSION: usize = 4096;
/// Size of one serialized vector component.
const F32_BYTES: usize = std::mem::size_of::<f32>();
/// Minimum v1 header: version (u8) + dimension (u32) + entry count (u32).
const HEADER_BYTES_V1: usize = 9;
/// Minimum v2/v3 header: the v1 header plus the fingerprint length (u32).
const HEADER_BYTES_V2: usize = 13;
/// Prefix for errors raised when the ONNX Runtime shared library cannot be used.
const ONNX_RUNTIME_INSTALL_HINT: &str =
    "ONNX Runtime not found. Install via: brew install onnxruntime (macOS) or apt install libonnxruntime (Linux).";

/// On-disk format without a fingerprint; file mtimes stored with whole-second precision.
const SEMANTIC_INDEX_VERSION_V1: u8 = 1;
/// Adds a JSON fingerprint after the header; mtimes still whole seconds.
const SEMANTIC_INDEX_VERSION_V2: u8 = 2;
/// Adds sub-second (nanosecond) mtime precision; the version written by `to_bytes`.
const SEMANTIC_INDEX_VERSION_V3: u8 = 3;
/// Path appended after the `/v1` segment of an OpenAI-compatible base URL.
const DEFAULT_OPENAI_EMBEDDING_PATH: &str = "/embeddings";
/// Path appended to an Ollama base URL.
const DEFAULT_OLLAMA_EMBEDDING_PATH: &str = "/api/embed";
/// HTTP timeout used for every HTTP backend when the configured `timeout_ms` is 0.
const DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS: u64 = 25_000;
/// Batch size used when the configured `max_batch_size` is 0.
const DEFAULT_MAX_BATCH_SIZE: usize = 64;
/// Placeholder stored in a fingerprint's `base_url` when no valid URL is configured.
const FALLBACK_BACKEND: &str = "none";
/// Total number of attempts (first try plus retries) for one embedding HTTP request.
const EMBEDDING_REQUEST_MAX_ATTEMPTS: usize = 3;
/// Delay before retry N+1, indexed by the zero-based attempt that just failed.
const EMBEDDING_REQUEST_BACKOFF_MS: [u64; 2] = [500, 1_000];
41
/// Identity of the embedding configuration a persisted semantic index was built with.
///
/// `as_string` serializes it to JSON, and `matches_expected` compares that string
/// verbatim, so renaming or reordering fields changes the encoding and invalidates
/// existing caches.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticIndexFingerprint {
    /// Backend identifier (`SemanticBackend::as_str`).
    pub backend: String,
    /// Embedding model name from the configuration.
    pub model: String,
    /// Normalized base URL, or `"none"` when the config has no valid URL.
    /// Missing in older serialized fingerprints; then it deserializes as an empty string.
    #[serde(default)]
    pub base_url: String,
    /// Length of the vectors produced by the model.
    pub dimension: usize,
}
50
51impl SemanticIndexFingerprint {
52 fn from_config(config: &SemanticBackendConfig, dimension: usize) -> Self {
53 let base_url = config
56 .base_url
57 .as_ref()
58 .and_then(|u| normalize_base_url(u).ok())
59 .unwrap_or_else(|| FALLBACK_BACKEND.to_string());
60 Self {
61 backend: config.backend.as_str().to_string(),
62 model: config.model.clone(),
63 base_url,
64 dimension,
65 }
66 }
67
68 pub fn as_string(&self) -> String {
69 serde_json::to_string(self).unwrap_or_else(|_| String::new())
70 }
71
72 fn matches_expected(&self, expected: &str) -> bool {
73 let encoded = self.as_string();
74 !encoded.is_empty() && encoded == expected
75 }
76}
77
/// Concrete embedding engine behind a `SemanticEmbeddingModel`.
enum SemanticEmbeddingEngine {
    /// Local in-process fastembed model (initialization first validates ONNX Runtime).
    Fastembed(TextEmbedding),
    /// HTTP backend using the OpenAI-style `/v1/embeddings` request format.
    OpenAiCompatible {
        client: Client,
        model: String,
        /// Normalized base URL (no trailing slash).
        base_url: String,
        /// Bearer token read from the configured environment variable, if any.
        api_key: Option<String>,
    },
    /// Ollama HTTP backend using `/api/embed`.
    Ollama {
        client: Client,
        model: String,
        /// Normalized base URL (no trailing slash).
        base_url: String,
    },
}
92
/// A configured embedding model plus the settings it was created from.
pub struct SemanticEmbeddingModel {
    backend: SemanticBackend,
    model: String,
    /// Raw `base_url` from the configuration (not normalized).
    base_url: Option<String>,
    /// HTTP request timeout in milliseconds (config value, or the default when 0).
    timeout_ms: u64,
    /// Batch size exposed via `max_batch_size()` (config value, or the default when 0).
    max_batch_size: usize,
    /// Cached vector dimension; `None` until probed or learned from an HTTP response.
    dimension: Option<usize>,
    engine: SemanticEmbeddingEngine,
}

/// Shorter alias for `SemanticEmbeddingModel`.
pub type EmbeddingModel = SemanticEmbeddingModel;
104
/// Checks that an embedding backend returned one vector per input and that all vectors
/// share the first vector's length.
///
/// `context` prefixes every error message (e.g. "embedding backend").
fn validate_embedding_batch(
    vectors: &[Vec<f32>],
    expected_count: usize,
    context: &str,
) -> Result<(), String> {
    let actual_count = vectors.len();
    if actual_count != expected_count {
        // An empty response for a non-empty request gets its own, clearer message.
        return Err(if actual_count == 0 {
            format!("{context} returned no vectors for {expected_count} inputs")
        } else {
            format!(
                "{context} returned {} vectors for {} inputs",
                actual_count, expected_count
            )
        });
    }

    let mut lengths = vectors.iter().map(Vec::len).enumerate();
    let Some((_, expected_dimension)) = lengths.next() else {
        return Ok(());
    };

    match lengths.find(|&(_, len)| len != expected_dimension) {
        Some((index, len)) => Err(format!(
            "{context} returned inconsistent embedding dimensions: vector 0 has length {expected_dimension}, vector {index} has length {}",
            len
        )),
        None => Ok(()),
    }
}
139
140fn normalize_base_url(raw: &str) -> Result<String, String> {
141 let parsed = Url::parse(raw).map_err(|error| format!("invalid base_url '{raw}': {error}"))?;
142 let scheme = parsed.scheme();
143 if scheme != "http" && scheme != "https" {
144 return Err(format!(
145 "unsupported URL scheme '{}' — only http:// and https:// are allowed",
146 scheme
147 ));
148 }
149 Ok(parsed.to_string().trim_end_matches('/').to_string())
150}
151
152fn build_openai_embeddings_endpoint(base_url: &str) -> String {
153 if base_url.ends_with("/v1") {
154 format!("{base_url}{DEFAULT_OPENAI_EMBEDDING_PATH}")
155 } else {
156 format!("{base_url}/v1{}", DEFAULT_OPENAI_EMBEDDING_PATH)
157 }
158}
159
160fn build_ollama_embeddings_endpoint(base_url: &str) -> String {
161 if base_url.ends_with("/api") {
162 format!("{base_url}/embed")
163 } else {
164 format!("{base_url}{DEFAULT_OLLAMA_EMBEDDING_PATH}")
165 }
166}
167
/// Trims surrounding whitespace and maps absent or blank values to `None`.
fn normalize_api_key(value: Option<String>) -> Option<String> {
    let raw = value?;
    let trimmed = raw.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_owned())
}
178
179fn is_retryable_embedding_status(status: reqwest::StatusCode) -> bool {
180 status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS
181}
182
183fn is_retryable_embedding_error(error: &reqwest::Error) -> bool {
184 error.is_connect()
185}
186
187fn sleep_before_embedding_retry(attempt_index: usize) {
188 if let Some(delay_ms) = EMBEDDING_REQUEST_BACKOFF_MS.get(attempt_index) {
189 std::thread::sleep(Duration::from_millis(*delay_ms));
190 }
191}
192
193fn send_embedding_request<F>(mut make_request: F, backend_label: &str) -> Result<String, String>
194where
195 F: FnMut() -> reqwest::blocking::RequestBuilder,
196{
197 for attempt_index in 0..EMBEDDING_REQUEST_MAX_ATTEMPTS {
198 let last_attempt = attempt_index + 1 == EMBEDDING_REQUEST_MAX_ATTEMPTS;
199
200 let response = match make_request().send() {
201 Ok(response) => response,
202 Err(error) => {
203 if !last_attempt && is_retryable_embedding_error(&error) {
204 sleep_before_embedding_retry(attempt_index);
205 continue;
206 }
207 return Err(format!("{backend_label} request failed: {error}"));
208 }
209 };
210
211 let status = response.status();
212 let raw = match response.text() {
213 Ok(raw) => raw,
214 Err(error) => {
215 if !last_attempt && is_retryable_embedding_error(&error) {
216 sleep_before_embedding_retry(attempt_index);
217 continue;
218 }
219 return Err(format!("{backend_label} response read failed: {error}"));
220 }
221 };
222
223 if status.is_success() {
224 return Ok(raw);
225 }
226
227 if !last_attempt && is_retryable_embedding_status(status) {
228 sleep_before_embedding_retry(attempt_index);
229 continue;
230 }
231
232 return Err(format!(
233 "{backend_label} request failed (HTTP {}): {}",
234 status, raw
235 ));
236 }
237
238 unreachable!("embedding request retries exhausted without returning")
239}
240
241impl SemanticEmbeddingModel {
242 pub fn from_config(config: &SemanticBackendConfig) -> Result<Self, String> {
243 let timeout_ms = if config.timeout_ms == 0 {
244 DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS
245 } else {
246 config.timeout_ms
247 };
248
249 let max_batch_size = if config.max_batch_size == 0 {
250 DEFAULT_MAX_BATCH_SIZE
251 } else {
252 config.max_batch_size
253 };
254
255 let api_key_env = normalize_api_key(config.api_key_env.clone());
256 let model = config.model.clone();
257
258 let client = Client::builder()
259 .timeout(Duration::from_millis(timeout_ms))
260 .redirect(reqwest::redirect::Policy::none())
261 .build()
262 .map_err(|error| format!("failed to configure embedding client: {error}"))?;
263
264 let engine = match config.backend {
265 SemanticBackend::Fastembed => {
266 SemanticEmbeddingEngine::Fastembed(initialize_text_embedding(&model)?)
267 }
268 SemanticBackend::OpenAiCompatible => {
269 let raw = config.base_url.as_ref().ok_or_else(|| {
270 "base_url is required for openai_compatible backend".to_string()
271 })?;
272 let base_url = normalize_base_url(raw)?;
273
274 let api_key = match api_key_env {
275 Some(var_name) => Some(env::var(&var_name).map_err(|_| {
276 format!("missing api_key_env '{var_name}' for openai_compatible backend")
277 })?),
278 None => None,
279 };
280
281 SemanticEmbeddingEngine::OpenAiCompatible {
282 client,
283 model,
284 base_url,
285 api_key,
286 }
287 }
288 SemanticBackend::Ollama => {
289 let raw = config
290 .base_url
291 .as_ref()
292 .ok_or_else(|| "base_url is required for ollama backend".to_string())?;
293 let base_url = normalize_base_url(raw)?;
294
295 SemanticEmbeddingEngine::Ollama {
296 client,
297 model,
298 base_url,
299 }
300 }
301 };
302
303 Ok(Self {
304 backend: config.backend,
305 model: config.model.clone(),
306 base_url: config.base_url.clone(),
307 timeout_ms,
308 max_batch_size,
309 dimension: None,
310 engine,
311 })
312 }
313
314 pub fn backend(&self) -> SemanticBackend {
315 self.backend
316 }
317
318 pub fn model(&self) -> &str {
319 &self.model
320 }
321
322 pub fn base_url(&self) -> Option<&str> {
323 self.base_url.as_deref()
324 }
325
326 pub fn max_batch_size(&self) -> usize {
327 self.max_batch_size
328 }
329
330 pub fn timeout_ms(&self) -> u64 {
331 self.timeout_ms
332 }
333
334 pub fn fingerprint(
335 &mut self,
336 config: &SemanticBackendConfig,
337 ) -> Result<SemanticIndexFingerprint, String> {
338 let dimension = self.dimension()?;
339 Ok(SemanticIndexFingerprint::from_config(config, dimension))
340 }
341
342 pub fn dimension(&mut self) -> Result<usize, String> {
343 if let Some(dimension) = self.dimension {
344 return Ok(dimension);
345 }
346
347 let dimension = match &mut self.engine {
348 SemanticEmbeddingEngine::Fastembed(model) => {
349 let vectors = model
350 .embed(vec!["semantic index fingerprint probe".to_string()], None)
351 .map_err(|error| format_embedding_init_error(error.to_string()))?;
352 vectors
353 .first()
354 .map(|v| v.len())
355 .ok_or_else(|| "embedding backend returned no vectors".to_string())?
356 }
357 SemanticEmbeddingEngine::OpenAiCompatible { .. } => {
358 let vectors =
359 self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
360 vectors
361 .first()
362 .map(|v| v.len())
363 .ok_or_else(|| "embedding backend returned no vectors".to_string())?
364 }
365 SemanticEmbeddingEngine::Ollama { .. } => {
366 let vectors =
367 self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
368 vectors
369 .first()
370 .map(|v| v.len())
371 .ok_or_else(|| "embedding backend returned no vectors".to_string())?
372 }
373 };
374
375 self.dimension = Some(dimension);
376 Ok(dimension)
377 }
378
379 pub fn embed(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
380 self.embed_texts(texts)
381 }
382
383 fn embed_texts(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
384 match &mut self.engine {
385 SemanticEmbeddingEngine::Fastembed(model) => model
386 .embed(texts, None::<usize>)
387 .map_err(|error| format_embedding_init_error(error.to_string()))
388 .map_err(|error| format!("failed to embed batch: {error}")),
389 SemanticEmbeddingEngine::OpenAiCompatible {
390 client,
391 model,
392 base_url,
393 api_key,
394 } => {
395 let expected_text_count = texts.len();
396 let endpoint = build_openai_embeddings_endpoint(base_url);
397 let body = serde_json::json!({
398 "input": texts,
399 "model": model,
400 });
401
402 let raw = send_embedding_request(
403 || {
404 let mut request = client
405 .post(&endpoint)
406 .json(&body)
407 .header("Content-Type", "application/json");
408
409 if let Some(api_key) = api_key {
410 request = request.header("Authorization", format!("Bearer {api_key}"));
411 }
412
413 request
414 },
415 "openai compatible",
416 )?;
417
418 #[derive(Deserialize)]
419 struct OpenAiResponse {
420 data: Vec<OpenAiEmbeddingResult>,
421 }
422
423 #[derive(Deserialize)]
424 struct OpenAiEmbeddingResult {
425 embedding: Vec<f32>,
426 index: Option<u32>,
427 }
428
429 let parsed: OpenAiResponse = serde_json::from_str(&raw)
430 .map_err(|error| format!("invalid openai compatible response: {error}"))?;
431 if parsed.data.len() != expected_text_count {
432 return Err(format!(
433 "openai compatible response returned {} embeddings for {} inputs",
434 parsed.data.len(),
435 expected_text_count
436 ));
437 }
438
439 let mut vectors = vec![Vec::new(); parsed.data.len()];
440 for (i, item) in parsed.data.into_iter().enumerate() {
441 let index = item.index.unwrap_or(i as u32) as usize;
442 if index >= vectors.len() {
443 return Err(
444 "openai compatible response contains invalid vector index".to_string()
445 );
446 }
447 vectors[index] = item.embedding;
448 }
449
450 for vector in &vectors {
451 if vector.is_empty() {
452 return Err(
453 "openai compatible response contained missing vectors".to_string()
454 );
455 }
456 }
457
458 self.dimension = vectors.first().map(Vec::len);
459 Ok(vectors)
460 }
461 SemanticEmbeddingEngine::Ollama {
462 client,
463 model,
464 base_url,
465 } => {
466 let expected_text_count = texts.len();
467 let endpoint = build_ollama_embeddings_endpoint(base_url);
468
469 #[derive(Serialize)]
470 struct OllamaPayload<'a> {
471 model: &'a str,
472 input: Vec<String>,
473 }
474
475 let payload = OllamaPayload {
476 model,
477 input: texts,
478 };
479
480 let raw = send_embedding_request(
481 || {
482 client
483 .post(&endpoint)
484 .json(&payload)
485 .header("Content-Type", "application/json")
486 },
487 "ollama",
488 )?;
489
490 #[derive(Deserialize)]
491 struct OllamaResponse {
492 embeddings: Vec<Vec<f32>>,
493 }
494
495 let parsed: OllamaResponse = serde_json::from_str(&raw)
496 .map_err(|error| format!("invalid ollama response: {error}"))?;
497 if parsed.embeddings.is_empty() {
498 return Err("ollama response returned no embeddings".to_string());
499 }
500 if parsed.embeddings.len() != expected_text_count {
501 return Err(format!(
502 "ollama response returned {} embeddings for {} inputs",
503 parsed.embeddings.len(),
504 expected_text_count
505 ));
506 }
507
508 let vectors = parsed.embeddings;
509 for vector in &vectors {
510 if vector.is_empty() {
511 return Err("ollama response contained empty embeddings".to_string());
512 }
513 }
514
515 self.dimension = vectors.first().map(Vec::len);
516 Ok(vectors)
517 }
518 }
519 }
520}
521
522pub fn pre_validate_onnx_runtime() -> Result<(), String> {
526 let dylib_path = std::env::var("ORT_DYLIB_PATH").ok();
527
528 #[cfg(any(target_os = "linux", target_os = "macos"))]
529 {
530 #[cfg(target_os = "linux")]
531 let default_name = "libonnxruntime.so";
532 #[cfg(target_os = "macos")]
533 let default_name = "libonnxruntime.dylib";
534
535 let lib_name = dylib_path.as_deref().unwrap_or(default_name);
536
537 unsafe {
538 let c_name = std::ffi::CString::new(lib_name)
539 .map_err(|e| format!("invalid library path: {}", e))?;
540 let handle = libc::dlopen(c_name.as_ptr(), libc::RTLD_NOW);
541 if handle.is_null() {
542 let err = libc::dlerror();
543 let msg = if err.is_null() {
544 "unknown dlopen error".to_string()
545 } else {
546 std::ffi::CStr::from_ptr(err).to_string_lossy().into_owned()
547 };
548 return Err(format!(
549 "ONNX Runtime not found. dlopen('{}') failed: {}. \
550 Run `bunx @cortexkit/aft-opencode@latest doctor` to diagnose.",
551 lib_name, msg
552 ));
553 }
554
555 let detected_version = detect_ort_version_from_path(lib_name);
558
559 libc::dlclose(handle);
560
561 if let Some(ref version) = detected_version {
563 let parts: Vec<&str> = version.split('.').collect();
564 if let (Some(major), Some(minor)) = (
565 parts.first().and_then(|s| s.parse::<u32>().ok()),
566 parts.get(1).and_then(|s| s.parse::<u32>().ok()),
567 ) {
568 if major != 1 || minor < 20 {
569 return Err(format!(
570 "ONNX Runtime version mismatch: found v{} at '{}', but AFT requires v1.20+. \
571 Solutions:\n\
572 1. Remove the old library and restart (AFT auto-downloads the correct version):\n\
573 {}\n\
574 2. Or install ONNX Runtime 1.24: https://github.com/microsoft/onnxruntime/releases/tag/v1.24.0\n\
575 3. Run `bunx @cortexkit/aft-opencode@latest doctor` for full diagnostics.",
576 version,
577 lib_name,
578 suggest_removal_command(lib_name),
579 ));
580 }
581 }
582 }
583 }
584 }
585
586 #[cfg(target_os = "windows")]
587 {
588 let _ = dylib_path;
590 }
591
592 Ok(())
593}
594
595fn detect_ort_version_from_path(lib_path: &str) -> Option<String> {
598 let path = std::path::Path::new(lib_path);
599
600 for candidate in [Some(path.to_path_buf()), std::fs::canonicalize(path).ok()]
602 .into_iter()
603 .flatten()
604 {
605 if let Some(name) = candidate.file_name().and_then(|n| n.to_str()) {
606 if let Some(version) = extract_version_from_filename(name) {
607 return Some(version);
608 }
609 }
610 }
611
612 if let Some(parent) = path.parent() {
614 if let Ok(entries) = std::fs::read_dir(parent) {
615 for entry in entries.flatten() {
616 if let Some(name) = entry.file_name().to_str() {
617 if name.starts_with("libonnxruntime") {
618 if let Some(version) = extract_version_from_filename(name) {
619 return Some(version);
620 }
621 }
622 }
623 }
624 }
625 }
626
627 None
628}
629
/// Returns the leftmost `digits.digits.digits` run in `name` (e.g. `"1.20.1"` from
/// `"libonnxruntime.so.1.20.1"`), or `None` if there is none.
///
/// Each component is a maximal run of ASCII digits, which matches the leftmost match of
/// the pattern `\d+\.\d+\.\d+`. This replaces a regex that was recompiled on every call,
/// and this function runs once per directory entry while scanning. Only ASCII digits
/// are accepted because the caller parses the components with `u32::from_str`.
fn extract_version_from_filename(name: &str) -> Option<String> {
    let bytes = name.as_bytes();
    // Exclusive end of the run of ASCII digits beginning at `from`.
    let digits_end =
        |from: usize| from + bytes[from..].iter().take_while(|b| b.is_ascii_digit()).count();

    let mut start = 0;
    while start < bytes.len() {
        if !bytes[start].is_ascii_digit() {
            start += 1;
            continue;
        }

        let run_end = digits_end(start);
        let mut end = run_end;
        let mut matched = true;
        // After the first number, require two more `.digits` components.
        for _ in 0..2 {
            if bytes.get(end) != Some(&b'.') {
                matched = false;
                break;
            }
            let component_end = digits_end(end + 1);
            if component_end == end + 1 {
                matched = false;
                break;
            }
            end = component_end;
        }

        if matched {
            // Match boundaries sit on ASCII bytes, so slicing is char-boundary safe.
            return Some(name[start..end].to_string());
        }
        // A match starting later inside this digit run would fail at the same character,
        // so resume scanning after the run.
        start = run_end;
    }

    None
}
636
/// Suggests a shell command that removes a stale ONNX Runtime library.
///
/// Libraries under `/usr/local/lib` (matched component-wise, so `/usr/local/lib64` is
/// not included) or loaded by bare default name get a platform-specific cleanup command.
/// Any other path gets an `rm` of that exact file. Single quotes in the path are escaped
/// so the suggested command stays a valid POSIX shell command.
fn suggest_removal_command(lib_path: &str) -> String {
    let is_default_location = Path::new(lib_path).starts_with("/usr/local/lib")
        || lib_path == "libonnxruntime.so"
        || lib_path == "libonnxruntime.dylib";

    if is_default_location {
        #[cfg(target_os = "linux")]
        return " sudo rm /usr/local/lib/libonnxruntime* && sudo ldconfig".to_string();
        #[cfg(target_os = "macos")]
        return " sudo rm /usr/local/lib/libonnxruntime*".to_string();
        #[cfg(target_os = "windows")]
        return " Delete the ONNX Runtime DLL from your PATH".to_string();
    }

    // Inside single quotes, a literal quote is written as: close quote, \', reopen quote.
    format!(" rm '{}'", lib_path.replace('\'', r"'\''"))
}
651
652pub fn initialize_text_embedding(model: &str) -> Result<TextEmbedding, String> {
653 pre_validate_onnx_runtime()?;
655
656 let selected_model = match model {
657 "all-MiniLM-L6-v2" | "all-minilm-l6-v2" => FastembedEmbeddingModel::AllMiniLML6V2,
658 _ => {
659 return Err(format!(
660 "unsupported fastembed model '{}'. Supported: all-MiniLM-L6-v2",
661 model
662 ))
663 }
664 };
665
666 TextEmbedding::try_new(InitOptions::new(selected_model)).map_err(format_embedding_init_error)
667}
668
/// Heuristically decides whether an error message means the ONNX Runtime library is
/// missing or failed to load.
///
/// True when the message starts with our own "ONNX Runtime not found." prefix, or when
/// it (case-insensitively) names ONNX Runtime AND mentions a dynamic-loading failure.
pub fn is_onnx_runtime_unavailable(message: &str) -> bool {
    const RUNTIME_MARKERS: [&str; 3] = ["onnx runtime", "onnxruntime", "libonnxruntime"];
    const LOAD_FAILURE_MARKERS: [&str; 9] = [
        "shared library",
        "dynamic library",
        "failed to load",
        "could not load",
        "unable to load",
        "dlopen",
        "loadlibrary",
        "no such file",
        "not found",
    ];

    if message.trim_start().starts_with("ONNX Runtime not found.") {
        return true;
    }

    let lowered = message.to_ascii_lowercase();
    let contains_any = |needles: &[&str]| needles.iter().any(|needle| lowered.contains(needle));
    contains_any(&RUNTIME_MARKERS) && contains_any(&LOAD_FAILURE_MARKERS)
}
694
695fn format_embedding_init_error(error: impl Display) -> String {
696 let message = error.to_string();
697
698 if is_onnx_runtime_unavailable(&message) {
699 return format!("{ONNX_RUNTIME_INSTALL_HINT} Original error: {message}");
700 }
701
702 format!("failed to initialize semantic embedding model: {message}")
703}
704
/// A symbol-level piece of source code that gets embedded and searched.
#[derive(Debug, Clone)]
pub struct SemanticChunk {
    /// File the symbol was extracted from.
    pub file: PathBuf,
    /// Symbol name.
    pub name: String,
    /// Kind of symbol (serialized via `symbol_kind_to_u8`).
    pub kind: SymbolKind,
    /// First line of the symbol (numbering base set by `symbols_to_chunks` — TODO confirm).
    pub start_line: u32,
    /// Last line of the symbol.
    pub end_line: u32,
    /// Whether the symbol is exported.
    pub exported: bool,
    /// Text sent to the embedding backend for this chunk.
    pub embed_text: String,
    /// Source excerpt returned with search results.
    pub snippet: String,
}
723}
724
/// A chunk paired with its embedding vector.
#[derive(Debug)]
struct EmbeddingEntry {
    chunk: SemanticChunk,
    /// Embedding of `chunk.embed_text`; its length equals the index dimension.
    vector: Vec<f32>,
}
731
/// In-memory vector index over symbol chunks, persistable to disk.
#[derive(Debug)]
pub struct SemanticIndex {
    /// All embedded chunks.
    entries: Vec<EmbeddingEntry>,
    /// Modification time recorded for each indexed file, used for staleness checks.
    file_mtimes: HashMap<PathBuf, SystemTime>,
    /// Length of every stored vector (`DEFAULT_DIMENSION` when built from no chunks).
    dimension: usize,
    /// Configuration the index was built with, if recorded.
    fingerprint: Option<SemanticIndexFingerprint>,
}
742
/// A search hit: chunk metadata plus its similarity to the query.
#[derive(Debug)]
pub struct SemanticResult {
    pub file: PathBuf,
    pub name: String,
    pub kind: SymbolKind,
    pub start_line: u32,
    pub end_line: u32,
    pub exported: bool,
    pub snippet: String,
    /// Cosine similarity between the query vector and the chunk vector (always > 0).
    pub score: f32,
}
755
756impl SemanticIndex {
757 pub fn new() -> Self {
758 Self {
759 entries: Vec::new(),
760 file_mtimes: HashMap::new(),
761 dimension: DEFAULT_DIMENSION, fingerprint: None,
763 }
764 }
765
766 pub fn entry_count(&self) -> usize {
768 self.entries.len()
769 }
770
771 pub fn status_label(&self) -> &'static str {
773 if self.entries.is_empty() {
774 "empty"
775 } else {
776 "ready"
777 }
778 }
779
780 fn collect_chunks(
781 project_root: &Path,
782 files: &[PathBuf],
783 ) -> (Vec<SemanticChunk>, HashMap<PathBuf, SystemTime>) {
784 let mut parser = FileParser::new();
785 let mut chunks: Vec<SemanticChunk> = Vec::new();
786 let mut file_mtimes: HashMap<PathBuf, SystemTime> = HashMap::new();
787
788 for file in files {
789 let mtime = std::fs::metadata(file)
790 .and_then(|m| m.modified())
791 .unwrap_or(SystemTime::UNIX_EPOCH);
792 file_mtimes.insert(file.clone(), mtime);
793
794 let source = match std::fs::read_to_string(file) {
795 Ok(s) => s,
796 Err(_) => continue,
797 };
798
799 let symbols = match parser.extract_symbols(file) {
800 Ok(s) => s,
801 Err(_) => continue,
802 };
803 let file_chunks = symbols_to_chunks(file, &symbols, &source, project_root);
804 chunks.extend(file_chunks);
805 }
806
807 (chunks, file_mtimes)
808 }
809
810 fn build_from_chunks<F, P>(
811 chunks: Vec<SemanticChunk>,
812 file_mtimes: HashMap<PathBuf, SystemTime>,
813 embed_fn: &mut F,
814 max_batch_size: usize,
815 mut progress: Option<&mut P>,
816 ) -> Result<Self, String>
817 where
818 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
819 P: FnMut(usize, usize),
820 {
821 let total_chunks = chunks.len();
822
823 if chunks.is_empty() {
824 return Ok(Self {
825 entries: Vec::new(),
826 file_mtimes,
827 dimension: DEFAULT_DIMENSION,
828 fingerprint: None,
829 });
830 }
831
832 let mut entries: Vec<EmbeddingEntry> = Vec::with_capacity(chunks.len());
834 let mut expected_dimension: Option<usize> = None;
835 let batch_size = max_batch_size.max(1);
836 for batch_start in (0..chunks.len()).step_by(batch_size) {
837 let batch_end = (batch_start + batch_size).min(chunks.len());
838 let batch_texts: Vec<String> = chunks[batch_start..batch_end]
839 .iter()
840 .map(|c| c.embed_text.clone())
841 .collect();
842
843 let vectors = embed_fn(batch_texts)?;
844 validate_embedding_batch(&vectors, batch_end - batch_start, "embedding backend")?;
845
846 if let Some(dim) = vectors.first().map(|v| v.len()) {
848 match expected_dimension {
849 None => expected_dimension = Some(dim),
850 Some(expected) if dim != expected => {
851 return Err(format!(
852 "embedding dimension changed across batches: expected {expected}, got {dim}"
853 ));
854 }
855 _ => {}
856 }
857 }
858
859 for (i, vector) in vectors.into_iter().enumerate() {
860 let chunk_idx = batch_start + i;
861 entries.push(EmbeddingEntry {
862 chunk: chunks[chunk_idx].clone(),
863 vector,
864 });
865 }
866
867 if let Some(callback) = progress.as_mut() {
868 callback(entries.len(), total_chunks);
869 }
870 }
871
872 let dimension = entries
873 .first()
874 .map(|e| e.vector.len())
875 .unwrap_or(DEFAULT_DIMENSION);
876
877 Ok(Self {
878 entries,
879 file_mtimes,
880 dimension,
881 fingerprint: None,
882 })
883 }
884
885 pub fn build<F>(
888 project_root: &Path,
889 files: &[PathBuf],
890 embed_fn: &mut F,
891 max_batch_size: usize,
892 ) -> Result<Self, String>
893 where
894 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
895 {
896 let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
897 Self::build_from_chunks(
898 chunks,
899 file_mtimes,
900 embed_fn,
901 max_batch_size,
902 Option::<&mut fn(usize, usize)>::None,
903 )
904 }
905
906 pub fn build_with_progress<F, P>(
908 project_root: &Path,
909 files: &[PathBuf],
910 embed_fn: &mut F,
911 max_batch_size: usize,
912 progress: &mut P,
913 ) -> Result<Self, String>
914 where
915 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
916 P: FnMut(usize, usize),
917 {
918 let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
919 let total_chunks = chunks.len();
920 progress(0, total_chunks);
921 Self::build_from_chunks(
922 chunks,
923 file_mtimes,
924 embed_fn,
925 max_batch_size,
926 Some(progress),
927 )
928 }
929
930 pub fn search(&self, query_vector: &[f32], top_k: usize) -> Vec<SemanticResult> {
932 if self.entries.is_empty() || query_vector.len() != self.dimension {
933 return Vec::new();
934 }
935
936 let mut scored: Vec<(f32, usize)> = self
937 .entries
938 .iter()
939 .enumerate()
940 .map(|(i, entry)| (cosine_similarity(query_vector, &entry.vector), i))
941 .collect();
942
943 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
945
946 scored
947 .into_iter()
948 .take(top_k)
949 .filter(|(score, _)| *score > 0.0)
950 .map(|(score, idx)| {
951 let entry = &self.entries[idx];
952 SemanticResult {
953 file: entry.chunk.file.clone(),
954 name: entry.chunk.name.clone(),
955 kind: entry.chunk.kind.clone(),
956 start_line: entry.chunk.start_line,
957 end_line: entry.chunk.end_line,
958 exported: entry.chunk.exported,
959 snippet: entry.chunk.snippet.clone(),
960 score,
961 }
962 })
963 .collect()
964 }
965
966 pub fn len(&self) -> usize {
968 self.entries.len()
969 }
970
971 pub fn is_file_stale(&self, file: &Path) -> bool {
973 match self.file_mtimes.get(file) {
974 None => true,
975 Some(stored_mtime) => match fs::metadata(file).and_then(|m| m.modified()) {
976 Ok(current_mtime) => *stored_mtime != current_mtime,
977 Err(_) => true,
978 },
979 }
980 }
981
982 pub fn count_stale_files(&self) -> usize {
983 self.file_mtimes
984 .keys()
985 .filter(|path| self.is_file_stale(path))
986 .count()
987 }
988
989 pub fn remove_file(&mut self, file: &Path) {
991 self.invalidate_file(file);
992 }
993
994 pub fn invalidate_file(&mut self, file: &Path) {
995 self.entries.retain(|e| e.chunk.file != file);
996 self.file_mtimes.remove(file);
997 }
998
999 pub fn dimension(&self) -> usize {
1001 self.dimension
1002 }
1003
1004 pub fn fingerprint(&self) -> Option<&SemanticIndexFingerprint> {
1005 self.fingerprint.as_ref()
1006 }
1007
1008 pub fn backend_label(&self) -> Option<&str> {
1009 self.fingerprint.as_ref().map(|f| f.backend.as_str())
1010 }
1011
1012 pub fn model_label(&self) -> Option<&str> {
1013 self.fingerprint.as_ref().map(|f| f.model.as_str())
1014 }
1015
1016 pub fn set_fingerprint(&mut self, fingerprint: SemanticIndexFingerprint) {
1017 self.fingerprint = Some(fingerprint);
1018 }
1019
1020 pub fn write_to_disk(&self, storage_dir: &Path, project_key: &str) {
1022 if self.entries.is_empty() {
1025 log::info!("[aft] skipping semantic index persistence (0 entries)");
1026 return;
1027 }
1028 let dir = storage_dir.join("semantic").join(project_key);
1029 if let Err(e) = fs::create_dir_all(&dir) {
1030 log::warn!("[aft] failed to create semantic cache dir: {}", e);
1031 return;
1032 }
1033 let data_path = dir.join("semantic.bin");
1034 let tmp_path = dir.join("semantic.bin.tmp");
1035 let bytes = self.to_bytes();
1036 if let Err(e) = fs::write(&tmp_path, &bytes) {
1037 log::warn!("[aft] failed to write semantic index: {}", e);
1038 let _ = fs::remove_file(&tmp_path);
1039 return;
1040 }
1041 if let Err(e) = fs::rename(&tmp_path, &data_path) {
1042 log::warn!("[aft] failed to rename semantic index: {}", e);
1043 let _ = fs::remove_file(&tmp_path);
1044 return;
1045 }
1046 log::info!(
1047 "[aft] semantic index persisted: {} entries, {:.1} KB",
1048 self.entries.len(),
1049 bytes.len() as f64 / 1024.0
1050 );
1051 }
1052
1053 pub fn read_from_disk(
1055 storage_dir: &Path,
1056 project_key: &str,
1057 expected_fingerprint: Option<&str>,
1058 ) -> Option<Self> {
1059 let data_path = storage_dir
1060 .join("semantic")
1061 .join(project_key)
1062 .join("semantic.bin");
1063 let file_len = usize::try_from(fs::metadata(&data_path).ok()?.len()).ok()?;
1064 if file_len < HEADER_BYTES_V1 {
1065 log::warn!(
1066 "[aft] corrupt semantic index (too small: {} bytes), removing",
1067 file_len
1068 );
1069 let _ = fs::remove_file(&data_path);
1070 return None;
1071 }
1072
1073 let bytes = fs::read(&data_path).ok()?;
1074 match Self::from_bytes(&bytes) {
1075 Ok(index) => {
1076 if index.entries.is_empty() {
1077 log::info!("[aft] cached semantic index is empty, will rebuild");
1078 let _ = fs::remove_file(&data_path);
1079 return None;
1080 }
1081 if let Some(expected) = expected_fingerprint {
1082 let matches = index
1083 .fingerprint()
1084 .map(|fingerprint| fingerprint.matches_expected(expected))
1085 .unwrap_or(false);
1086 if !matches {
1087 log::info!("[aft] cached semantic index fingerprint mismatch, rebuilding");
1088 let _ = fs::remove_file(&data_path);
1089 return None;
1090 }
1091 }
1092 log::info!(
1093 "[aft] loaded semantic index from disk: {} entries",
1094 index.entries.len()
1095 );
1096 Some(index)
1097 }
1098 Err(e) => {
1099 log::warn!("[aft] corrupt semantic index, rebuilding: {}", e);
1100 let _ = fs::remove_file(&data_path);
1101 None
1102 }
1103 }
1104 }
1105
1106 pub fn to_bytes(&self) -> Vec<u8> {
1108 let mut buf = Vec::new();
1109 let fingerprint_bytes = self.fingerprint.as_ref().and_then(|fingerprint| {
1110 let encoded = fingerprint.as_string();
1111 if encoded.is_empty() {
1112 None
1113 } else {
1114 Some(encoded.into_bytes())
1115 }
1116 });
1117
1118 let version = SEMANTIC_INDEX_VERSION_V3;
1129 buf.push(version);
1130 buf.extend_from_slice(&(self.dimension as u32).to_le_bytes());
1131 buf.extend_from_slice(&(self.entries.len() as u32).to_le_bytes());
1132 let fp_bytes_ref: &[u8] = fingerprint_bytes.as_deref().unwrap_or(&[]);
1133 buf.extend_from_slice(&(fp_bytes_ref.len() as u32).to_le_bytes());
1134 buf.extend_from_slice(fp_bytes_ref);
1135
1136 buf.extend_from_slice(&(self.file_mtimes.len() as u32).to_le_bytes());
1139 for (path, mtime) in &self.file_mtimes {
1140 let path_bytes = path.to_string_lossy().as_bytes().to_vec();
1141 buf.extend_from_slice(&(path_bytes.len() as u32).to_le_bytes());
1142 buf.extend_from_slice(&path_bytes);
1143 let duration = mtime
1144 .duration_since(SystemTime::UNIX_EPOCH)
1145 .unwrap_or_default();
1146 buf.extend_from_slice(&duration.as_secs().to_le_bytes());
1147 buf.extend_from_slice(&duration.subsec_nanos().to_le_bytes());
1148 }
1149
1150 for entry in &self.entries {
1152 let c = &entry.chunk;
1153
1154 let file_bytes = c.file.to_string_lossy().as_bytes().to_vec();
1156 buf.extend_from_slice(&(file_bytes.len() as u32).to_le_bytes());
1157 buf.extend_from_slice(&file_bytes);
1158
1159 let name_bytes = c.name.as_bytes();
1161 buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
1162 buf.extend_from_slice(name_bytes);
1163
1164 buf.push(symbol_kind_to_u8(&c.kind));
1166
1167 buf.extend_from_slice(&(c.start_line as u32).to_le_bytes());
1169 buf.extend_from_slice(&(c.end_line as u32).to_le_bytes());
1170 buf.push(c.exported as u8);
1171
1172 let snippet_bytes = c.snippet.as_bytes();
1174 buf.extend_from_slice(&(snippet_bytes.len() as u32).to_le_bytes());
1175 buf.extend_from_slice(snippet_bytes);
1176
1177 let embed_bytes = c.embed_text.as_bytes();
1179 buf.extend_from_slice(&(embed_bytes.len() as u32).to_le_bytes());
1180 buf.extend_from_slice(embed_bytes);
1181
1182 for &val in &entry.vector {
1184 buf.extend_from_slice(&val.to_le_bytes());
1185 }
1186 }
1187
1188 buf
1189 }
1190
1191 pub fn from_bytes(data: &[u8]) -> Result<Self, String> {
1193 let mut pos = 0;
1194
1195 if data.len() < HEADER_BYTES_V1 {
1196 return Err("data too short".to_string());
1197 }
1198
1199 let version = data[pos];
1200 pos += 1;
1201 if version != SEMANTIC_INDEX_VERSION_V1
1202 && version != SEMANTIC_INDEX_VERSION_V2
1203 && version != SEMANTIC_INDEX_VERSION_V3
1204 {
1205 return Err(format!("unsupported version: {}", version));
1206 }
1207 if (version == SEMANTIC_INDEX_VERSION_V2 || version == SEMANTIC_INDEX_VERSION_V3)
1211 && data.len() < HEADER_BYTES_V2
1212 {
1213 return Err("data too short for semantic index v2/v3 header".to_string());
1214 }
1215
1216 let dimension = read_u32(data, &mut pos)? as usize;
1217 let entry_count = read_u32(data, &mut pos)? as usize;
1218 if dimension == 0 || dimension > MAX_DIMENSION {
1219 return Err(format!("invalid embedding dimension: {}", dimension));
1220 }
1221 if entry_count > MAX_ENTRIES {
1222 return Err(format!("too many semantic index entries: {}", entry_count));
1223 }
1224
1225 let has_fingerprint_field =
1231 version == SEMANTIC_INDEX_VERSION_V2 || version == SEMANTIC_INDEX_VERSION_V3;
1232 let fingerprint = if has_fingerprint_field {
1233 let fingerprint_len = read_u32(data, &mut pos)? as usize;
1234 if pos + fingerprint_len > data.len() {
1235 return Err("unexpected end of data reading fingerprint".to_string());
1236 }
1237 if fingerprint_len == 0 {
1238 None
1239 } else {
1240 let raw = String::from_utf8_lossy(&data[pos..pos + fingerprint_len]).to_string();
1241 pos += fingerprint_len;
1242 Some(
1243 serde_json::from_str::<SemanticIndexFingerprint>(&raw)
1244 .map_err(|error| format!("invalid semantic fingerprint: {error}"))?,
1245 )
1246 }
1247 } else {
1248 None
1249 };
1250
1251 let mtime_count = read_u32(data, &mut pos)? as usize;
1253 if mtime_count > MAX_ENTRIES {
1254 return Err(format!("too many semantic file mtimes: {}", mtime_count));
1255 }
1256
1257 let vector_bytes = entry_count
1258 .checked_mul(dimension)
1259 .and_then(|count| count.checked_mul(F32_BYTES))
1260 .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
1261 if vector_bytes > data.len().saturating_sub(pos) {
1262 return Err("semantic index vectors exceed available data".to_string());
1263 }
1264
1265 let mut file_mtimes = HashMap::with_capacity(mtime_count);
1266 for _ in 0..mtime_count {
1267 let path = read_string(data, &mut pos)?;
1268 let secs = read_u64(data, &mut pos)?;
1269 let nanos = if version == SEMANTIC_INDEX_VERSION_V3 {
1275 read_u32(data, &mut pos)?
1276 } else {
1277 0
1278 };
1279 if nanos >= 1_000_000_000 {
1286 return Err(format!(
1287 "invalid semantic mtime: nanos {} >= 1_000_000_000",
1288 nanos
1289 ));
1290 }
1291 let duration = std::time::Duration::new(secs, nanos);
1292 let mtime = SystemTime::UNIX_EPOCH
1293 .checked_add(duration)
1294 .ok_or_else(|| {
1295 format!(
1296 "invalid semantic mtime: secs={} nanos={} overflows SystemTime",
1297 secs, nanos
1298 )
1299 })?;
1300 file_mtimes.insert(PathBuf::from(path), mtime);
1301 }
1302
1303 let mut entries = Vec::with_capacity(entry_count);
1305 for _ in 0..entry_count {
1306 let file = PathBuf::from(read_string(data, &mut pos)?);
1307 let name = read_string(data, &mut pos)?;
1308
1309 if pos >= data.len() {
1310 return Err("unexpected end of data".to_string());
1311 }
1312 let kind = u8_to_symbol_kind(data[pos]);
1313 pos += 1;
1314
1315 let start_line = read_u32(data, &mut pos)?;
1316 let end_line = read_u32(data, &mut pos)?;
1317
1318 if pos >= data.len() {
1319 return Err("unexpected end of data".to_string());
1320 }
1321 let exported = data[pos] != 0;
1322 pos += 1;
1323
1324 let snippet = read_string(data, &mut pos)?;
1325 let embed_text = read_string(data, &mut pos)?;
1326
1327 let vec_bytes = dimension
1329 .checked_mul(F32_BYTES)
1330 .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
1331 if pos + vec_bytes > data.len() {
1332 return Err("unexpected end of data reading vector".to_string());
1333 }
1334 let mut vector = Vec::with_capacity(dimension);
1335 for _ in 0..dimension {
1336 let bytes = [data[pos], data[pos + 1], data[pos + 2], data[pos + 3]];
1337 vector.push(f32::from_le_bytes(bytes));
1338 pos += 4;
1339 }
1340
1341 entries.push(EmbeddingEntry {
1342 chunk: SemanticChunk {
1343 file,
1344 name,
1345 kind,
1346 start_line,
1347 end_line,
1348 exported,
1349 embed_text,
1350 snippet,
1351 },
1352 vector,
1353 });
1354 }
1355
1356 Ok(Self {
1357 entries,
1358 file_mtimes,
1359 dimension,
1360 fingerprint,
1361 })
1362 }
1363}
1364
1365fn build_embed_text(symbol: &Symbol, source: &str, file: &Path, project_root: &Path) -> String {
1367 let relative = file
1368 .strip_prefix(project_root)
1369 .unwrap_or(file)
1370 .to_string_lossy();
1371
1372 let kind_label = match &symbol.kind {
1373 SymbolKind::Function => "function",
1374 SymbolKind::Class => "class",
1375 SymbolKind::Method => "method",
1376 SymbolKind::Struct => "struct",
1377 SymbolKind::Interface => "interface",
1378 SymbolKind::Enum => "enum",
1379 SymbolKind::TypeAlias => "type",
1380 SymbolKind::Variable => "variable",
1381 SymbolKind::Heading => "heading",
1382 };
1383
1384 let mut text = format!("file:{} kind:{} name:{}", relative, kind_label, symbol.name);
1386
1387 if let Some(sig) = &symbol.signature {
1388 text.push_str(&format!(" signature:{}", sig));
1389 }
1390
1391 let lines: Vec<&str> = source.lines().collect();
1393 let start = (symbol.range.start_line.saturating_sub(1) as usize).min(lines.len()); let end = (symbol.range.end_line as usize).min(lines.len()); if start < end {
1396 let body: String = lines[start..end]
1397 .iter()
1398 .take(15) .copied()
1400 .collect::<Vec<&str>>()
1401 .join("\n");
1402 let snippet = if body.len() > 300 {
1403 format!("{}...", &body[..body.floor_char_boundary(300)])
1404 } else {
1405 body
1406 };
1407 text.push_str(&format!(" body:{}", snippet));
1408 }
1409
1410 text
1411}
1412
1413fn build_snippet(symbol: &Symbol, source: &str) -> String {
1415 let lines: Vec<&str> = source.lines().collect();
1416 let start = (symbol.range.start_line.saturating_sub(1) as usize).min(lines.len());
1417 let end = (symbol.range.end_line as usize).min(lines.len());
1418 if start < end {
1419 let snippet_lines: Vec<&str> = lines[start..end].iter().take(5).copied().collect();
1420 let mut snippet = snippet_lines.join("\n");
1421 if end - start > 5 {
1422 snippet.push_str("\n ...");
1423 }
1424 if snippet.len() > 300 {
1425 snippet = format!("{}...", &snippet[..snippet.floor_char_boundary(300)]);
1426 }
1427 snippet
1428 } else {
1429 String::new()
1430 }
1431}
1432
1433fn symbols_to_chunks(
1435 file: &Path,
1436 symbols: &[Symbol],
1437 source: &str,
1438 project_root: &Path,
1439) -> Vec<SemanticChunk> {
1440 let mut chunks = Vec::new();
1441
1442 for symbol in symbols {
1443 let line_count = symbol
1445 .range
1446 .end_line
1447 .saturating_sub(symbol.range.start_line)
1448 + 1;
1449 if line_count < 2 && !matches!(symbol.kind, SymbolKind::Variable) {
1450 continue;
1451 }
1452
1453 let embed_text = build_embed_text(symbol, source, file, project_root);
1454 let snippet = build_snippet(symbol, source);
1455
1456 chunks.push(SemanticChunk {
1457 file: file.to_path_buf(),
1458 name: symbol.name.clone(),
1459 kind: symbol.kind.clone(),
1460 start_line: symbol.range.start_line,
1461 end_line: symbol.range.end_line,
1462 exported: symbol.exported,
1463 embed_text,
1464 snippet,
1465 });
1466
1467 }
1470
1471 chunks
1472}
1473
/// Cosine similarity between two vectors.
///
/// Returns `0.0` when the slices differ in length or when either vector has
/// zero magnitude, so no division by zero ever happens.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Accumulate dot product and squared norms in a single pass.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let magnitude = sq_a.sqrt() * sq_b.sqrt();
    if magnitude == 0.0 {
        0.0
    } else {
        dot / magnitude
    }
}
1497
1498fn symbol_kind_to_u8(kind: &SymbolKind) -> u8 {
1500 match kind {
1501 SymbolKind::Function => 0,
1502 SymbolKind::Class => 1,
1503 SymbolKind::Method => 2,
1504 SymbolKind::Struct => 3,
1505 SymbolKind::Interface => 4,
1506 SymbolKind::Enum => 5,
1507 SymbolKind::TypeAlias => 6,
1508 SymbolKind::Variable => 7,
1509 SymbolKind::Heading => 8,
1510 }
1511}
1512
1513fn u8_to_symbol_kind(v: u8) -> SymbolKind {
1514 match v {
1515 0 => SymbolKind::Function,
1516 1 => SymbolKind::Class,
1517 2 => SymbolKind::Method,
1518 3 => SymbolKind::Struct,
1519 4 => SymbolKind::Interface,
1520 5 => SymbolKind::Enum,
1521 6 => SymbolKind::TypeAlias,
1522 7 => SymbolKind::Variable,
1523 _ => SymbolKind::Heading,
1524 }
1525}
1526
/// Reads a little-endian `u32` at `*pos` and advances `pos` by four bytes.
///
/// # Errors
/// Returns an error, leaving `pos` untouched, when fewer than four bytes
/// remain at `*pos`.
fn read_u32(data: &[u8], pos: &mut usize) -> Result<u32, String> {
    let start = *pos;
    match data.get(start..).and_then(|rest| rest.get(..4)) {
        Some(word) => {
            *pos = start + 4;
            Ok(u32::from_le_bytes([word[0], word[1], word[2], word[3]]))
        }
        None => Err("unexpected end of data reading u32".to_string()),
    }
}
1535
/// Reads a little-endian `u64` at `*pos` and advances `pos` by eight bytes.
///
/// # Errors
/// Returns an error, leaving `pos` untouched, when fewer than eight bytes
/// remain at `*pos`.
fn read_u64(data: &[u8], pos: &mut usize) -> Result<u64, String> {
    let Some(word) = data.get(*pos..).and_then(|rest| rest.get(..8)) else {
        return Err("unexpected end of data reading u64".to_string());
    };
    let mut raw = [0u8; 8];
    raw.copy_from_slice(word);
    *pos += 8;
    Ok(u64::from_le_bytes(raw))
}
1544
1545fn read_string(data: &[u8], pos: &mut usize) -> Result<String, String> {
1546 let len = read_u32(data, pos)? as usize;
1547 if *pos + len > data.len() {
1548 return Err("unexpected end of data reading string".to_string());
1549 }
1550 let s = String::from_utf8_lossy(&data[*pos..*pos + len]).to_string();
1551 *pos += len;
1552 Ok(s)
1553}
1554
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{SemanticBackend, SemanticBackendConfig};
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;

    /// Extracts the `Content-Length` value from a raw HTTP request header block.
    ///
    /// HTTP field names are case-insensitive, and hyper (which reqwest uses)
    /// emits them in lowercase by default. The name is therefore compared
    /// without regard to ASCII case. A missing or unparsable value yields 0.
    fn parse_content_length(headers: &[u8]) -> usize {
        String::from_utf8_lossy(headers)
            .lines()
            .filter_map(|line| line.split_once(':'))
            .find(|(name, _)| name.trim().eq_ignore_ascii_case("content-length"))
            .and_then(|(_, value)| value.trim().parse::<usize>().ok())
            .unwrap_or(0)
    }

    /// Spawns a one-shot HTTP/1.1 server on an ephemeral localhost port.
    ///
    /// The server:
    /// - accepts a single connection;
    /// - reads the request headers and, per `Content-Length`, the body;
    /// - calls `handler(request_line, path, body)`;
    /// - replies `200 OK` with the returned string as an `application/json` body.
    ///
    /// Returns the base URL and the server thread's join handle. Panics inside
    /// `handler` (e.g. failed assertions) surface when the handle is joined.
    fn start_mock_http_server<F>(handler: F) -> (String, thread::JoinHandle<()>)
    where
        F: Fn(String, String, String) -> String + Send + 'static,
    {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
        let addr = listener.local_addr().expect("local addr");
        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept request");
            let mut buf = Vec::new();
            let mut chunk = [0u8; 4096];
            let mut header_end = None;
            let mut content_length = 0usize;
            loop {
                let n = stream.read(&mut chunk).expect("read request");
                if n == 0 {
                    break;
                }
                buf.extend_from_slice(&chunk[..n]);
                if header_end.is_none() {
                    if let Some(pos) = buf.windows(4).position(|window| window == b"\r\n\r\n") {
                        header_end = Some(pos + 4);
                        content_length = parse_content_length(&buf[..pos + 4]);
                    }
                }
                // Keep reading until the whole body has arrived, so the socket
                // is not closed with unread request data.
                if let Some(end) = header_end {
                    if buf.len() >= end + content_length {
                        break;
                    }
                }
            }

            let end = header_end.expect("header terminator");
            let request = String::from_utf8_lossy(&buf[..end]).to_string();
            // Clamp in case the client closed the connection before the full body.
            let body_end = (end + content_length).min(buf.len());
            let body = String::from_utf8_lossy(&buf[end..body_end]).to_string();
            let mut lines = request.lines();
            let request_line = lines.next().expect("request line").to_string();
            let path = request_line
                .split_whitespace()
                .nth(1)
                .expect("request path")
                .to_string();
            let response_body = handler(request_line, path, body);
            let response = format!(
                "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
                response_body.len(),
                response_body
            );
            stream
                .write_all(response.as_bytes())
                .expect("write response");
        });

        (format!("http://{}", addr), handle)
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &b).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 0.001);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let mut index = SemanticIndex::new();
        index.entries.push(EmbeddingEntry {
            chunk: SemanticChunk {
                file: PathBuf::from("/src/main.rs"),
                name: "handle_request".to_string(),
                kind: SymbolKind::Function,
                start_line: 10,
                end_line: 25,
                exported: true,
                embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
                snippet: "fn handle_request() {\n // ...\n}".to_string(),
            },
            vector: vec![0.1, 0.2, 0.3, 0.4],
        });
        index.dimension = 4;
        index
            .file_mtimes
            .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
        index.set_fingerprint(SemanticIndexFingerprint {
            backend: "fastembed".to_string(),
            model: "all-MiniLM-L6-v2".to_string(),
            base_url: FALLBACK_BACKEND.to_string(),
            dimension: 4,
        });

        let bytes = index.to_bytes();
        let restored = SemanticIndex::from_bytes(&bytes).unwrap();

        assert_eq!(restored.entries.len(), 1);
        assert_eq!(restored.entries[0].chunk.name, "handle_request");
        assert_eq!(restored.entries[0].vector, vec![0.1, 0.2, 0.3, 0.4]);
        assert_eq!(restored.dimension, 4);
        assert_eq!(restored.backend_label(), Some("fastembed"));
        assert_eq!(restored.model_label(), Some("all-MiniLM-L6-v2"));
    }

    #[test]
    fn test_search_top_k() {
        let mut index = SemanticIndex::new();
        index.dimension = 3;

        for (i, name) in ["auth", "database", "handler"].iter().enumerate() {
            let mut vec = vec![0.0f32; 3];
            vec[i] = 1.0;
            index.entries.push(EmbeddingEntry {
                chunk: SemanticChunk {
                    file: PathBuf::from("/src/lib.rs"),
                    name: name.to_string(),
                    kind: SymbolKind::Function,
                    start_line: (i * 10 + 1) as u32,
                    end_line: (i * 10 + 5) as u32,
                    exported: true,
                    embed_text: format!("kind:function name:{}", name),
                    snippet: format!("fn {}() {{}}", name),
                },
                vector: vec,
            });
        }

        let query = vec![0.9, 0.1, 0.0];
        let results = index.search(&query, 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].name, "auth");
        assert!(results[0].score > results[1].score);
    }

    #[test]
    fn test_empty_index_search() {
        let index = SemanticIndex::new();
        let results = index.search(&[0.1, 0.2, 0.3], 10);
        assert!(results.is_empty());
    }

    #[test]
    fn rejects_oversized_dimension_during_deserialization() {
        let mut bytes = Vec::new();
        bytes.push(1u8);
        bytes.extend_from_slice(&((MAX_DIMENSION as u32) + 1).to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());

        assert!(SemanticIndex::from_bytes(&bytes).is_err());
    }

    #[test]
    fn rejects_oversized_entry_count_during_deserialization() {
        let mut bytes = Vec::new();
        bytes.push(1u8);
        bytes.extend_from_slice(&(DEFAULT_DIMENSION as u32).to_le_bytes());
        bytes.extend_from_slice(&((MAX_ENTRIES as u32) + 1).to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());

        assert!(SemanticIndex::from_bytes(&bytes).is_err());
    }

    #[test]
    fn invalidate_file_removes_entries_and_mtime() {
        let target = PathBuf::from("/src/main.rs");
        let mut index = SemanticIndex::new();
        index.entries.push(EmbeddingEntry {
            chunk: SemanticChunk {
                file: target.clone(),
                name: "main".to_string(),
                kind: SymbolKind::Function,
                start_line: 0,
                end_line: 1,
                exported: false,
                embed_text: "main".to_string(),
                snippet: "fn main() {}".to_string(),
            },
            vector: vec![1.0; DEFAULT_DIMENSION],
        });
        index
            .file_mtimes
            .insert(target.clone(), SystemTime::UNIX_EPOCH);

        index.invalidate_file(&target);

        assert!(index.entries.is_empty());
        assert!(!index.file_mtimes.contains_key(&target));
    }

    #[test]
    fn detects_missing_onnx_runtime_from_dynamic_load_error() {
        let message = "Failed to load ONNX Runtime shared library libonnxruntime.dylib via dlopen: no such file";

        assert!(is_onnx_runtime_unavailable(message));
    }

    #[test]
    fn formats_missing_onnx_runtime_with_install_hint() {
        let message = format_embedding_init_error(
            "Failed to load ONNX Runtime shared library libonnxruntime.so via dlopen: no such file",
        );

        assert!(message.starts_with("ONNX Runtime not found. Install via:"));
        assert!(message.contains("Original error:"));
    }

    #[test]
    fn openai_compatible_backend_embeds_with_mock_server() {
        let (base_url, handle) = start_mock_http_server(|request_line, path, _body| {
            assert!(request_line.starts_with("POST "));
            assert_eq!(path, "/v1/embeddings");
            "{\"data\":[{\"embedding\":[0.1,0.2,0.3],\"index\":0},{\"embedding\":[0.4,0.5,0.6],\"index\":1}]}".to_string()
        });

        let config = SemanticBackendConfig {
            backend: SemanticBackend::OpenAiCompatible,
            model: "test-embedding".to_string(),
            base_url: Some(base_url),
            api_key_env: None,
            timeout_ms: 5_000,
            max_batch_size: 64,
        };

        let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();
        let vectors = model
            .embed(vec!["hello".to_string(), "world".to_string()])
            .unwrap();

        assert_eq!(vectors, vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]]);
        handle.join().unwrap();
    }

    #[test]
    fn ollama_backend_embeds_with_mock_server() {
        let (base_url, handle) = start_mock_http_server(|request_line, path, _body| {
            assert!(request_line.starts_with("POST "));
            assert_eq!(path, "/api/embed");
            "{\"embeddings\":[[0.7,0.8,0.9],[1.0,1.1,1.2]]}".to_string()
        });

        let config = SemanticBackendConfig {
            backend: SemanticBackend::Ollama,
            model: "embeddinggemma".to_string(),
            base_url: Some(base_url),
            api_key_env: None,
            timeout_ms: 5_000,
            max_batch_size: 64,
        };

        let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();
        let vectors = model
            .embed(vec!["hello".to_string(), "world".to_string()])
            .unwrap();

        assert_eq!(vectors, vec![vec![0.7, 0.8, 0.9], vec![1.0, 1.1, 1.2]]);
        handle.join().unwrap();
    }

    #[test]
    fn read_from_disk_rejects_fingerprint_mismatch() {
        let storage = tempfile::tempdir().unwrap();
        let project_key = "proj";

        let mut index = SemanticIndex::new();
        index.entries.push(EmbeddingEntry {
            chunk: SemanticChunk {
                file: PathBuf::from("/src/main.rs"),
                name: "handle_request".to_string(),
                kind: SymbolKind::Function,
                start_line: 10,
                end_line: 25,
                exported: true,
                embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
                snippet: "fn handle_request() {}".to_string(),
            },
            vector: vec![0.1, 0.2, 0.3],
        });
        index.dimension = 3;
        index
            .file_mtimes
            .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
        index.set_fingerprint(SemanticIndexFingerprint {
            backend: "openai_compatible".to_string(),
            model: "test-embedding".to_string(),
            base_url: "http://127.0.0.1:1234/v1".to_string(),
            dimension: 3,
        });
        index.write_to_disk(storage.path(), project_key);

        let matching = index.fingerprint().unwrap().as_string();
        assert!(
            SemanticIndex::read_from_disk(storage.path(), project_key, Some(&matching)).is_some()
        );

        let mismatched = SemanticIndexFingerprint {
            backend: "ollama".to_string(),
            model: "embeddinggemma".to_string(),
            base_url: "http://127.0.0.1:11434".to_string(),
            dimension: 3,
        }
        .as_string();
        assert!(
            SemanticIndex::read_from_disk(storage.path(), project_key, Some(&mismatched)).is_none()
        );
    }
}