1use crate::config::{SemanticBackend, SemanticBackendConfig};
2use crate::parser::FileParser;
3use crate::symbols::{Symbol, SymbolKind};
4
5use fastembed::{EmbeddingModel as FastembedEmbeddingModel, InitOptions, TextEmbedding};
6use reqwest::blocking::Client;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::env;
10use std::fmt::Display;
11use std::fs;
12use std::path::{Path, PathBuf};
13use std::time::Duration;
14use std::time::SystemTime;
15use url::Url;
16
/// Dimension recorded for an index that holds no embedded entries.
const DEFAULT_DIMENSION: usize = 384;
/// Upper bound on the entry count and file-mtime count accepted when decoding a persisted index.
const MAX_ENTRIES: usize = 1_000_000;
/// Largest embedding dimension accepted when decoding a persisted index.
const MAX_DIMENSION: usize = 1024;
/// Size in bytes of one vector component (`f32`).
const F32_BYTES: usize = std::mem::size_of::<f32>();
/// Minimum length of a v1 payload: version byte + dimension (u32) + entry count (u32).
const HEADER_BYTES_V1: usize = 9;
/// Minimum length of a v2 payload: the v1 header plus the fingerprint length (u32).
const HEADER_BYTES_V2: usize = 13;
/// Prefix prepended to embedding errors that look like a missing ONNX Runtime library.
const ONNX_RUNTIME_INSTALL_HINT: &str =
    "ONNX Runtime not found. Install via: brew install onnxruntime (macOS) or apt install libonnxruntime (Linux).";

/// Serialization format written when the index has no fingerprint.
const SEMANTIC_INDEX_VERSION_V1: u8 = 1;
/// Serialization format written when the index carries a fingerprint.
const SEMANTIC_INDEX_VERSION_V2: u8 = 2;
/// Path appended (after the `/v1` segment) to an OpenAI-compatible base URL.
const DEFAULT_OPENAI_EMBEDDING_PATH: &str = "/embeddings";
/// Path appended to an Ollama base URL that does not already end in `/api`.
const DEFAULT_OLLAMA_EMBEDDING_PATH: &str = "/api/embed";
/// HTTP client timeout (milliseconds) used when the configured `timeout_ms` is 0.
const DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS: u64 = 25_000;
/// Batch size used when the configured `max_batch_size` is 0.
const DEFAULT_MAX_BATCH_SIZE: usize = 64;
/// Placeholder stored as the fingerprint's `base_url` when the config has no usable base URL.
const FALLBACK_BACKEND: &str = "none";
/// Total number of attempts made for one embedding HTTP request.
const EMBEDDING_REQUEST_MAX_ATTEMPTS: usize = 3;
/// Delay (milliseconds) slept before retry N, indexed by the zero-based attempt that just failed.
const EMBEDDING_REQUEST_BACKOFF_MS: [u64; 2] = [500, 1_000];
36
/// Identity of the embedding setup that produced an index.
///
/// The JSON encoding of this struct is stored inside persisted (v2) indexes and
/// compared against an expected string when an index is loaded from disk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticIndexFingerprint {
    /// Backend identifier, taken from `SemanticBackend::as_str()`.
    pub backend: String,
    /// Model name as given in the configuration.
    pub model: String,
    /// Normalized base URL, or the `FALLBACK_BACKEND` placeholder when none is usable.
    /// Defaults to an empty string when absent from the serialized form.
    #[serde(default)]
    pub base_url: String,
    /// Length of the embedding vectors.
    pub dimension: usize,
}
45
46impl SemanticIndexFingerprint {
47 fn from_config(config: &SemanticBackendConfig, dimension: usize) -> Self {
48 let base_url = config
51 .base_url
52 .as_ref()
53 .and_then(|u| normalize_base_url(u).ok())
54 .unwrap_or_else(|| FALLBACK_BACKEND.to_string());
55 Self {
56 backend: config.backend.as_str().to_string(),
57 model: config.model.clone(),
58 base_url,
59 dimension,
60 }
61 }
62
63 pub fn as_string(&self) -> String {
64 serde_json::to_string(self).unwrap_or_else(|_| String::new())
65 }
66
67 fn matches_expected(&self, expected: &str) -> bool {
68 let encoded = self.as_string();
69 !encoded.is_empty() && encoded == expected
70 }
71}
72
/// Backend-specific state needed to turn text into embedding vectors.
enum SemanticEmbeddingEngine {
    /// In-process embedding through a `fastembed` model.
    Fastembed(TextEmbedding),
    /// Remote embedding through an OpenAI-style `/v1/embeddings` HTTP endpoint.
    OpenAiCompatible {
        client: Client,
        model: String,
        /// Normalized base URL (no trailing slash).
        base_url: String,
        /// Bearer token sent in the `Authorization` header when present.
        api_key: Option<String>,
    },
    /// Remote embedding through an Ollama `/api/embed` HTTP endpoint.
    Ollama {
        client: Client,
        model: String,
        /// Normalized base URL (no trailing slash).
        base_url: String,
    },
}
87
/// An embedding backend together with the settings it was created from.
pub struct SemanticEmbeddingModel {
    /// Which backend was selected in the configuration.
    backend: SemanticBackend,
    /// Model name as given in the configuration.
    model: String,
    /// Base URL exactly as given in the configuration (not normalized).
    base_url: Option<String>,
    /// Effective HTTP timeout in milliseconds (configured value, or the default when 0).
    timeout_ms: u64,
    /// Effective batch size (configured value, or the default when 0).
    max_batch_size: usize,
    /// Cached embedding dimension; `None` until it has been observed.
    dimension: Option<usize>,
    /// Backend-specific state used to compute embeddings.
    engine: SemanticEmbeddingEngine,
}
97
/// Alias for [`SemanticEmbeddingModel`].
pub type EmbeddingModel = SemanticEmbeddingModel;
99
/// Checks that a backend returned one vector per input and that all vectors share a length.
///
/// `context` names the backend in the error message. An empty batch for zero
/// expected inputs is valid.
///
/// # Errors
/// Returns a descriptive message when no vectors came back for a non-empty
/// request, when the vector count differs from `expected_count`, or when any
/// vector's length differs from the first vector's length.
fn validate_embedding_batch(
    vectors: &[Vec<f32>],
    expected_count: usize,
    context: &str,
) -> Result<(), String> {
    let received = vectors.len();

    if received == 0 && expected_count != 0 {
        return Err(format!(
            "{context} returned no vectors for {expected_count} inputs"
        ));
    }

    if received != expected_count {
        return Err(format!(
            "{context} returned {} vectors for {} inputs",
            received, expected_count
        ));
    }

    let expected_dimension = match vectors.first() {
        Some(first_vector) => first_vector.len(),
        None => return Ok(()),
    };

    // Report the first vector whose length deviates from vector 0.
    match vectors
        .iter()
        .position(|vector| vector.len() != expected_dimension)
    {
        Some(index) => Err(format!(
            "{context} returned inconsistent embedding dimensions: vector 0 has length {expected_dimension}, vector {index} has length {}",
            vectors[index].len()
        )),
        None => Ok(()),
    }
}
134
135fn normalize_base_url(raw: &str) -> Result<String, String> {
136 let parsed = Url::parse(raw).map_err(|error| format!("invalid base_url '{raw}': {error}"))?;
137 let scheme = parsed.scheme();
138 if scheme != "http" && scheme != "https" {
139 return Err(format!(
140 "unsupported URL scheme '{}' — only http:// and https:// are allowed",
141 scheme
142 ));
143 }
144 Ok(parsed.to_string().trim_end_matches('/').to_string())
145}
146
147fn build_openai_embeddings_endpoint(base_url: &str) -> String {
148 if base_url.ends_with("/v1") {
149 format!("{base_url}{DEFAULT_OPENAI_EMBEDDING_PATH}")
150 } else {
151 format!("{base_url}/v1{}", DEFAULT_OPENAI_EMBEDDING_PATH)
152 }
153}
154
155fn build_ollama_embeddings_endpoint(base_url: &str) -> String {
156 if base_url.ends_with("/api") {
157 format!("{base_url}/embed")
158 } else {
159 format!("{base_url}{DEFAULT_OLLAMA_EMBEDDING_PATH}")
160 }
161}
162
/// Trims surrounding whitespace from an optional value and maps blank values to `None`.
fn normalize_api_key(value: Option<String>) -> Option<String> {
    value
        .map(|raw| raw.trim().to_string())
        .filter(|trimmed| !trimmed.is_empty())
}
173
174fn is_retryable_embedding_status(status: reqwest::StatusCode) -> bool {
175 status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS
176}
177
/// Reports whether a transport error justifies another attempt.
///
/// Only errors that `reqwest` classifies as connection errors are retried.
fn is_retryable_embedding_error(error: &reqwest::Error) -> bool {
    error.is_connect()
}
181
182fn sleep_before_embedding_retry(attempt_index: usize) {
183 if let Some(delay_ms) = EMBEDDING_REQUEST_BACKOFF_MS.get(attempt_index) {
184 std::thread::sleep(Duration::from_millis(*delay_ms));
185 }
186}
187
188fn send_embedding_request<F>(mut make_request: F, backend_label: &str) -> Result<String, String>
189where
190 F: FnMut() -> reqwest::blocking::RequestBuilder,
191{
192 for attempt_index in 0..EMBEDDING_REQUEST_MAX_ATTEMPTS {
193 let last_attempt = attempt_index + 1 == EMBEDDING_REQUEST_MAX_ATTEMPTS;
194
195 let response = match make_request().send() {
196 Ok(response) => response,
197 Err(error) => {
198 if !last_attempt && is_retryable_embedding_error(&error) {
199 sleep_before_embedding_retry(attempt_index);
200 continue;
201 }
202 return Err(format!("{backend_label} request failed: {error}"));
203 }
204 };
205
206 let status = response.status();
207 let raw = match response.text() {
208 Ok(raw) => raw,
209 Err(error) => {
210 if !last_attempt && is_retryable_embedding_error(&error) {
211 sleep_before_embedding_retry(attempt_index);
212 continue;
213 }
214 return Err(format!("{backend_label} response read failed: {error}"));
215 }
216 };
217
218 if status.is_success() {
219 return Ok(raw);
220 }
221
222 if !last_attempt && is_retryable_embedding_status(status) {
223 sleep_before_embedding_retry(attempt_index);
224 continue;
225 }
226
227 return Err(format!(
228 "{backend_label} request failed (HTTP {}): {}",
229 status, raw
230 ));
231 }
232
233 unreachable!("embedding request retries exhausted without returning")
234}
235
236impl SemanticEmbeddingModel {
237 pub fn from_config(config: &SemanticBackendConfig) -> Result<Self, String> {
238 let timeout_ms = if config.timeout_ms == 0 {
239 DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS
240 } else {
241 config.timeout_ms
242 };
243
244 let max_batch_size = if config.max_batch_size == 0 {
245 DEFAULT_MAX_BATCH_SIZE
246 } else {
247 config.max_batch_size
248 };
249
250 let api_key_env = normalize_api_key(config.api_key_env.clone());
251 let model = config.model.clone();
252
253 let client = Client::builder()
254 .timeout(Duration::from_millis(timeout_ms))
255 .redirect(reqwest::redirect::Policy::none())
256 .build()
257 .map_err(|error| format!("failed to configure embedding client: {error}"))?;
258
259 let engine = match config.backend {
260 SemanticBackend::Fastembed => {
261 SemanticEmbeddingEngine::Fastembed(initialize_text_embedding(&model)?)
262 }
263 SemanticBackend::OpenAiCompatible => {
264 let raw = config.base_url.as_ref().ok_or_else(|| {
265 "base_url is required for openai_compatible backend".to_string()
266 })?;
267 let base_url = normalize_base_url(raw)?;
268
269 let api_key = match api_key_env {
270 Some(var_name) => Some(env::var(&var_name).map_err(|_| {
271 format!("missing api_key_env '{var_name}' for openai_compatible backend")
272 })?),
273 None => None,
274 };
275
276 SemanticEmbeddingEngine::OpenAiCompatible {
277 client,
278 model,
279 base_url,
280 api_key,
281 }
282 }
283 SemanticBackend::Ollama => {
284 let raw = config
285 .base_url
286 .as_ref()
287 .ok_or_else(|| "base_url is required for ollama backend".to_string())?;
288 let base_url = normalize_base_url(raw)?;
289
290 SemanticEmbeddingEngine::Ollama {
291 client,
292 model,
293 base_url,
294 }
295 }
296 };
297
298 Ok(Self {
299 backend: config.backend,
300 model: config.model.clone(),
301 base_url: config.base_url.clone(),
302 timeout_ms,
303 max_batch_size,
304 dimension: None,
305 engine,
306 })
307 }
308
309 pub fn backend(&self) -> SemanticBackend {
310 self.backend
311 }
312
313 pub fn model(&self) -> &str {
314 &self.model
315 }
316
317 pub fn base_url(&self) -> Option<&str> {
318 self.base_url.as_deref()
319 }
320
321 pub fn max_batch_size(&self) -> usize {
322 self.max_batch_size
323 }
324
325 pub fn timeout_ms(&self) -> u64 {
326 self.timeout_ms
327 }
328
329 pub fn fingerprint(
330 &mut self,
331 config: &SemanticBackendConfig,
332 ) -> Result<SemanticIndexFingerprint, String> {
333 let dimension = self.dimension()?;
334 Ok(SemanticIndexFingerprint::from_config(config, dimension))
335 }
336
337 pub fn dimension(&mut self) -> Result<usize, String> {
338 if let Some(dimension) = self.dimension {
339 return Ok(dimension);
340 }
341
342 let dimension = match &mut self.engine {
343 SemanticEmbeddingEngine::Fastembed(model) => {
344 let vectors = model
345 .embed(vec!["semantic index fingerprint probe".to_string()], None)
346 .map_err(|error| format_embedding_init_error(error.to_string()))?;
347 vectors
348 .first()
349 .map(|v| v.len())
350 .ok_or_else(|| "embedding backend returned no vectors".to_string())?
351 }
352 SemanticEmbeddingEngine::OpenAiCompatible { .. } => {
353 let vectors =
354 self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
355 vectors
356 .first()
357 .map(|v| v.len())
358 .ok_or_else(|| "embedding backend returned no vectors".to_string())?
359 }
360 SemanticEmbeddingEngine::Ollama { .. } => {
361 let vectors =
362 self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
363 vectors
364 .first()
365 .map(|v| v.len())
366 .ok_or_else(|| "embedding backend returned no vectors".to_string())?
367 }
368 };
369
370 self.dimension = Some(dimension);
371 Ok(dimension)
372 }
373
374 pub fn embed(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
375 self.embed_texts(texts)
376 }
377
378 fn embed_texts(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
379 match &mut self.engine {
380 SemanticEmbeddingEngine::Fastembed(model) => model
381 .embed(texts, None::<usize>)
382 .map_err(|error| format_embedding_init_error(error.to_string()))
383 .map_err(|error| format!("failed to embed batch: {error}")),
384 SemanticEmbeddingEngine::OpenAiCompatible {
385 client,
386 model,
387 base_url,
388 api_key,
389 } => {
390 let expected_text_count = texts.len();
391 let endpoint = build_openai_embeddings_endpoint(base_url);
392 let body = serde_json::json!({
393 "input": texts,
394 "model": model,
395 });
396
397 let raw = send_embedding_request(
398 || {
399 let mut request = client
400 .post(&endpoint)
401 .json(&body)
402 .header("Content-Type", "application/json");
403
404 if let Some(api_key) = api_key {
405 request = request.header("Authorization", format!("Bearer {api_key}"));
406 }
407
408 request
409 },
410 "openai compatible",
411 )?;
412
413 #[derive(Deserialize)]
414 struct OpenAiResponse {
415 data: Vec<OpenAiEmbeddingResult>,
416 }
417
418 #[derive(Deserialize)]
419 struct OpenAiEmbeddingResult {
420 embedding: Vec<f32>,
421 index: Option<u32>,
422 }
423
424 let parsed: OpenAiResponse = serde_json::from_str(&raw)
425 .map_err(|error| format!("invalid openai compatible response: {error}"))?;
426 if parsed.data.len() != expected_text_count {
427 return Err(format!(
428 "openai compatible response returned {} embeddings for {} inputs",
429 parsed.data.len(),
430 expected_text_count
431 ));
432 }
433
434 let mut vectors = vec![Vec::new(); parsed.data.len()];
435 for (i, item) in parsed.data.into_iter().enumerate() {
436 let index = item.index.unwrap_or(i as u32) as usize;
437 if index >= vectors.len() {
438 return Err(
439 "openai compatible response contains invalid vector index".to_string()
440 );
441 }
442 vectors[index] = item.embedding;
443 }
444
445 for vector in &vectors {
446 if vector.is_empty() {
447 return Err(
448 "openai compatible response contained missing vectors".to_string()
449 );
450 }
451 }
452
453 self.dimension = vectors.first().map(Vec::len);
454 Ok(vectors)
455 }
456 SemanticEmbeddingEngine::Ollama {
457 client,
458 model,
459 base_url,
460 } => {
461 let expected_text_count = texts.len();
462 let endpoint = build_ollama_embeddings_endpoint(base_url);
463
464 #[derive(Serialize)]
465 struct OllamaPayload<'a> {
466 model: &'a str,
467 input: Vec<String>,
468 }
469
470 let payload = OllamaPayload {
471 model,
472 input: texts,
473 };
474
475 let raw = send_embedding_request(
476 || {
477 client
478 .post(&endpoint)
479 .json(&payload)
480 .header("Content-Type", "application/json")
481 },
482 "ollama",
483 )?;
484
485 #[derive(Deserialize)]
486 struct OllamaResponse {
487 embeddings: Vec<Vec<f32>>,
488 }
489
490 let parsed: OllamaResponse = serde_json::from_str(&raw)
491 .map_err(|error| format!("invalid ollama response: {error}"))?;
492 if parsed.embeddings.is_empty() {
493 return Err("ollama response returned no embeddings".to_string());
494 }
495 if parsed.embeddings.len() != expected_text_count {
496 return Err(format!(
497 "ollama response returned {} embeddings for {} inputs",
498 parsed.embeddings.len(),
499 expected_text_count
500 ));
501 }
502
503 let vectors = parsed.embeddings;
504 for vector in &vectors {
505 if vector.is_empty() {
506 return Err("ollama response contained empty embeddings".to_string());
507 }
508 }
509
510 self.dimension = vectors.first().map(Vec::len);
511 Ok(vectors)
512 }
513 }
514 }
515}
516
/// Checks that an ONNX Runtime shared library can be loaded before it is needed.
///
/// On Linux and macOS the library named by `ORT_DYLIB_PATH` (or the platform's
/// default library name when the variable is unset) is opened with `dlopen`
/// and closed again. If a version number can be found in the library's file
/// name, anything other than 1.20 or a later 1.x release is rejected. On
/// Windows, and on any other platform, no check is performed.
///
/// # Errors
/// Returns a user-facing message when the library path contains an interior
/// NUL byte, when `dlopen` fails, or when the detected version is unsupported.
pub fn pre_validate_onnx_runtime() -> Result<(), String> {
    let dylib_path = std::env::var("ORT_DYLIB_PATH").ok();

    #[cfg(any(target_os = "linux", target_os = "macos"))]
    {
        #[cfg(target_os = "linux")]
        let default_name = "libonnxruntime.so";
        #[cfg(target_os = "macos")]
        let default_name = "libonnxruntime.dylib";

        let lib_name = dylib_path.as_deref().unwrap_or(default_name);

        // SAFETY: `c_name` is a NUL-terminated `CString` that outlives the
        // `dlopen` call; the `dlerror` pointer is checked for null before it is
        // read as a C string; `dlclose` is only reached with a non-null handle.
        unsafe {
            let c_name = std::ffi::CString::new(lib_name)
                .map_err(|e| format!("invalid library path: {}", e))?;
            let handle = libc::dlopen(c_name.as_ptr(), libc::RTLD_NOW);
            if handle.is_null() {
                let err = libc::dlerror();
                let msg = if err.is_null() {
                    "unknown dlopen error".to_string()
                } else {
                    std::ffi::CStr::from_ptr(err).to_string_lossy().into_owned()
                };
                return Err(format!(
                    "ONNX Runtime not found. dlopen('{}') failed: {}. \
                     Run `bunx @cortexkit/aft-opencode@latest doctor` to diagnose.",
                    lib_name, msg
                ));
            }

            // The version is inferred from file names only; the opened handle is not queried.
            let detected_version = detect_ort_version_from_path(lib_name);

            libc::dlclose(handle);

            // When no version can be detected the library is accepted as-is.
            if let Some(ref version) = detected_version {
                let parts: Vec<&str> = version.split('.').collect();
                if let (Some(major), Some(minor)) = (
                    parts.first().and_then(|s| s.parse::<u32>().ok()),
                    parts.get(1).and_then(|s| s.parse::<u32>().ok()),
                ) {
                    if major != 1 || minor < 20 {
                        return Err(format!(
                            "ONNX Runtime version mismatch: found v{} at '{}', but AFT requires v1.20+. \
                             Solutions:\n\
                             1. Remove the old library and restart (AFT auto-downloads the correct version):\n\
                             {}\n\
                             2. Or install ONNX Runtime 1.24: https://github.com/microsoft/onnxruntime/releases/tag/v1.24.0\n\
                             3. Run `bunx @cortexkit/aft-opencode@latest doctor` for full diagnostics.",
                            version,
                            lib_name,
                            suggest_removal_command(lib_name),
                        ));
                    }
                }
            }
        }
    }

    #[cfg(target_os = "windows")]
    {
        // Nothing is validated on Windows; this only marks the variable as used.
        let _ = dylib_path;
    }

    Ok(())
}
589
590fn detect_ort_version_from_path(lib_path: &str) -> Option<String> {
593 let path = std::path::Path::new(lib_path);
594
595 for candidate in [Some(path.to_path_buf()), std::fs::canonicalize(path).ok()]
597 .into_iter()
598 .flatten()
599 {
600 if let Some(name) = candidate.file_name().and_then(|n| n.to_str()) {
601 if let Some(version) = extract_version_from_filename(name) {
602 return Some(version);
603 }
604 }
605 }
606
607 if let Some(parent) = path.parent() {
609 if let Ok(entries) = std::fs::read_dir(parent) {
610 for entry in entries.flatten() {
611 if let Some(name) = entry.file_name().to_str() {
612 if name.starts_with("libonnxruntime") {
613 if let Some(version) = extract_version_from_filename(name) {
614 return Some(version);
615 }
616 }
617 }
618 }
619 }
620 }
621
622 None
623}
624
625fn extract_version_from_filename(name: &str) -> Option<String> {
627 let re = regex::Regex::new(r"(\d+\.\d+\.\d+)").ok()?;
629 re.find(name).map(|m| m.as_str().to_string())
630}
631
/// Returns a hint telling the user how to remove the library at `lib_path`.
///
/// Paths under `/usr/local/lib` and the bare default library names get a
/// platform-specific hint; any other path gets an `rm` command for that path.
fn suggest_removal_command(lib_path: &str) -> String {
    let is_default_install = matches!(lib_path, "libonnxruntime.so" | "libonnxruntime.dylib")
        || lib_path.starts_with("/usr/local/lib");

    if is_default_install {
        #[cfg(target_os = "linux")]
        return " sudo rm /usr/local/lib/libonnxruntime* && sudo ldconfig".to_string();
        #[cfg(target_os = "macos")]
        return " sudo rm /usr/local/lib/libonnxruntime*".to_string();
        #[cfg(target_os = "windows")]
        return " Delete the ONNX Runtime DLL from your PATH".to_string();
    }

    format!(" rm '{}'", lib_path)
}
646
647pub fn initialize_text_embedding(model: &str) -> Result<TextEmbedding, String> {
648 pre_validate_onnx_runtime()?;
650
651 let selected_model = match model {
652 "all-MiniLM-L6-v2" | "all-minilm-l6-v2" => FastembedEmbeddingModel::AllMiniLML6V2,
653 _ => {
654 return Err(format!(
655 "unsupported fastembed model '{}'. Supported: all-MiniLM-L6-v2",
656 model
657 ))
658 }
659 };
660
661 TextEmbedding::try_new(InitOptions::new(selected_model)).map_err(format_embedding_init_error)
662}
663
/// Heuristically decides whether an error message means ONNX Runtime could not be loaded.
///
/// A message qualifies when it starts (after leading whitespace) with
/// `ONNX Runtime not found.`, or when, compared case-insensitively, it both
/// names the runtime and contains a phrase typical of a library-load failure.
pub fn is_onnx_runtime_unavailable(message: &str) -> bool {
    const RUNTIME_MARKERS: [&str; 3] = ["onnx runtime", "onnxruntime", "libonnxruntime"];
    const LOAD_FAILURE_MARKERS: [&str; 9] = [
        "shared library",
        "dynamic library",
        "failed to load",
        "could not load",
        "unable to load",
        "dlopen",
        "loadlibrary",
        "no such file",
        "not found",
    ];

    if message.trim_start().starts_with("ONNX Runtime not found.") {
        return true;
    }

    let lowered = message.to_ascii_lowercase();
    let mentions_any =
        |markers: &[&str]| markers.iter().any(|marker| lowered.contains(*marker));

    mentions_any(&RUNTIME_MARKERS) && mentions_any(&LOAD_FAILURE_MARKERS)
}
689
690fn format_embedding_init_error(error: impl Display) -> String {
691 let message = error.to_string();
692
693 if is_onnx_runtime_unavailable(&message) {
694 return format!("{ONNX_RUNTIME_INSTALL_HINT} Original error: {message}");
695 }
696
697 format!("failed to initialize semantic embedding model: {message}")
698}
699
/// One embeddable unit of source code, together with where it came from.
#[derive(Debug, Clone)]
pub struct SemanticChunk {
    /// File the chunk was extracted from.
    pub file: PathBuf,
    /// Name of the chunk (presumably the symbol name — confirm against `symbols_to_chunks`).
    pub name: String,
    /// Symbol kind of the chunk.
    pub kind: SymbolKind,
    /// First line of the chunk (line-numbering base is not visible here — TODO confirm).
    pub start_line: u32,
    /// Last line of the chunk (same numbering as `start_line`).
    pub end_line: u32,
    /// Export flag carried through to search results.
    pub exported: bool,
    /// Text handed to the embedding function for this chunk.
    pub embed_text: String,
    /// Text copied into search results for display.
    pub snippet: String,
}
719
/// A chunk paired with the embedding vector computed for it.
#[derive(Debug)]
struct EmbeddingEntry {
    chunk: SemanticChunk,
    vector: Vec<f32>,
}
726
/// In-memory collection of embedded chunks that can be searched by vector similarity.
#[derive(Debug)]
pub struct SemanticIndex {
    /// All embedded chunks.
    entries: Vec<EmbeddingEntry>,
    /// Modification time recorded for each indexed file; used for staleness checks.
    file_mtimes: HashMap<PathBuf, SystemTime>,
    /// Length of the stored vectors; queries of any other length return no results.
    dimension: usize,
    /// Identity of the embedding setup that produced the vectors, if it has been set.
    fingerprint: Option<SemanticIndexFingerprint>,
}
737
/// One search hit: the matched chunk's metadata plus its similarity score.
#[derive(Debug)]
pub struct SemanticResult {
    pub file: PathBuf,
    pub name: String,
    pub kind: SymbolKind,
    pub start_line: u32,
    pub end_line: u32,
    pub exported: bool,
    pub snippet: String,
    /// Value returned by `cosine_similarity` for the query and this chunk; always greater than 0.
    pub score: f32,
}
750
751impl SemanticIndex {
    /// Creates an empty index with the default dimension and no fingerprint.
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
            file_mtimes: HashMap::new(),
            dimension: DEFAULT_DIMENSION,
            fingerprint: None,
        }
    }
760
    /// Returns the number of embedded chunks in the index.
    pub fn entry_count(&self) -> usize {
        self.entries.len()
    }
765
766 pub fn status_label(&self) -> &'static str {
768 if self.entries.is_empty() {
769 "empty"
770 } else {
771 "ready"
772 }
773 }
774
775 fn collect_chunks(
776 project_root: &Path,
777 files: &[PathBuf],
778 ) -> (Vec<SemanticChunk>, HashMap<PathBuf, SystemTime>) {
779 let mut parser = FileParser::new();
780 let mut chunks: Vec<SemanticChunk> = Vec::new();
781 let mut file_mtimes: HashMap<PathBuf, SystemTime> = HashMap::new();
782
783 for file in files {
784 let mtime = std::fs::metadata(file)
785 .and_then(|m| m.modified())
786 .unwrap_or(SystemTime::UNIX_EPOCH);
787 file_mtimes.insert(file.clone(), mtime);
788
789 let source = match std::fs::read_to_string(file) {
790 Ok(s) => s,
791 Err(_) => continue,
792 };
793
794 let symbols = match parser.extract_symbols(file) {
795 Ok(s) => s,
796 Err(_) => continue,
797 };
798 let file_chunks = symbols_to_chunks(file, &symbols, &source, project_root);
799 chunks.extend(file_chunks);
800 }
801
802 (chunks, file_mtimes)
803 }
804
805 fn build_from_chunks<F, P>(
806 chunks: Vec<SemanticChunk>,
807 file_mtimes: HashMap<PathBuf, SystemTime>,
808 embed_fn: &mut F,
809 max_batch_size: usize,
810 mut progress: Option<&mut P>,
811 ) -> Result<Self, String>
812 where
813 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
814 P: FnMut(usize, usize),
815 {
816 let total_chunks = chunks.len();
817
818 if chunks.is_empty() {
819 return Ok(Self {
820 entries: Vec::new(),
821 file_mtimes,
822 dimension: DEFAULT_DIMENSION,
823 fingerprint: None,
824 });
825 }
826
827 let mut entries: Vec<EmbeddingEntry> = Vec::with_capacity(chunks.len());
829 let mut expected_dimension: Option<usize> = None;
830 let batch_size = max_batch_size.max(1);
831 for batch_start in (0..chunks.len()).step_by(batch_size) {
832 let batch_end = (batch_start + batch_size).min(chunks.len());
833 let batch_texts: Vec<String> = chunks[batch_start..batch_end]
834 .iter()
835 .map(|c| c.embed_text.clone())
836 .collect();
837
838 let vectors = embed_fn(batch_texts)?;
839 validate_embedding_batch(&vectors, batch_end - batch_start, "embedding backend")?;
840
841 if let Some(dim) = vectors.first().map(|v| v.len()) {
843 match expected_dimension {
844 None => expected_dimension = Some(dim),
845 Some(expected) if dim != expected => {
846 return Err(format!(
847 "embedding dimension changed across batches: expected {expected}, got {dim}"
848 ));
849 }
850 _ => {}
851 }
852 }
853
854 for (i, vector) in vectors.into_iter().enumerate() {
855 let chunk_idx = batch_start + i;
856 entries.push(EmbeddingEntry {
857 chunk: chunks[chunk_idx].clone(),
858 vector,
859 });
860 }
861
862 if let Some(callback) = progress.as_mut() {
863 callback(entries.len(), total_chunks);
864 }
865 }
866
867 let dimension = entries
868 .first()
869 .map(|e| e.vector.len())
870 .unwrap_or(DEFAULT_DIMENSION);
871
872 Ok(Self {
873 entries,
874 file_mtimes,
875 dimension,
876 fingerprint: None,
877 })
878 }
879
880 pub fn build<F>(
883 project_root: &Path,
884 files: &[PathBuf],
885 embed_fn: &mut F,
886 max_batch_size: usize,
887 ) -> Result<Self, String>
888 where
889 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
890 {
891 let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
892 Self::build_from_chunks(
893 chunks,
894 file_mtimes,
895 embed_fn,
896 max_batch_size,
897 Option::<&mut fn(usize, usize)>::None,
898 )
899 }
900
901 pub fn build_with_progress<F, P>(
903 project_root: &Path,
904 files: &[PathBuf],
905 embed_fn: &mut F,
906 max_batch_size: usize,
907 progress: &mut P,
908 ) -> Result<Self, String>
909 where
910 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
911 P: FnMut(usize, usize),
912 {
913 let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
914 let total_chunks = chunks.len();
915 progress(0, total_chunks);
916 Self::build_from_chunks(
917 chunks,
918 file_mtimes,
919 embed_fn,
920 max_batch_size,
921 Some(progress),
922 )
923 }
924
925 pub fn search(&self, query_vector: &[f32], top_k: usize) -> Vec<SemanticResult> {
927 if self.entries.is_empty() || query_vector.len() != self.dimension {
928 return Vec::new();
929 }
930
931 let mut scored: Vec<(f32, usize)> = self
932 .entries
933 .iter()
934 .enumerate()
935 .map(|(i, entry)| (cosine_similarity(query_vector, &entry.vector), i))
936 .collect();
937
938 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
940
941 scored
942 .into_iter()
943 .take(top_k)
944 .filter(|(score, _)| *score > 0.0)
945 .map(|(score, idx)| {
946 let entry = &self.entries[idx];
947 SemanticResult {
948 file: entry.chunk.file.clone(),
949 name: entry.chunk.name.clone(),
950 kind: entry.chunk.kind.clone(),
951 start_line: entry.chunk.start_line,
952 end_line: entry.chunk.end_line,
953 exported: entry.chunk.exported,
954 snippet: entry.chunk.snippet.clone(),
955 score,
956 }
957 })
958 .collect()
959 }
960
    /// Returns the number of embedded chunks in the index.
    pub fn len(&self) -> usize {
        self.entries.len()
    }
965
966 pub fn is_file_stale(&self, file: &Path) -> bool {
968 match self.file_mtimes.get(file) {
969 None => true,
970 Some(stored_mtime) => match fs::metadata(file).and_then(|m| m.modified()) {
971 Ok(current_mtime) => *stored_mtime != current_mtime,
972 Err(_) => true,
973 },
974 }
975 }
976
977 pub fn count_stale_files(&self) -> usize {
978 self.file_mtimes
979 .keys()
980 .filter(|path| self.is_file_stale(path))
981 .count()
982 }
983
    /// Removes `file` from the index; equivalent to [`Self::invalidate_file`].
    pub fn remove_file(&mut self, file: &Path) {
        self.invalidate_file(file);
    }
988
    /// Drops every entry that came from `file` and forgets its recorded modification time.
    pub fn invalidate_file(&mut self, file: &Path) {
        self.entries.retain(|e| e.chunk.file != file);
        self.file_mtimes.remove(file);
    }
993
    /// Returns the vector length this index expects from queries.
    pub fn dimension(&self) -> usize {
        self.dimension
    }
998
    /// Returns the fingerprint attached to this index, if any.
    pub fn fingerprint(&self) -> Option<&SemanticIndexFingerprint> {
        self.fingerprint.as_ref()
    }
1002
    /// Returns the backend name from the fingerprint, if a fingerprint is set.
    pub fn backend_label(&self) -> Option<&str> {
        self.fingerprint.as_ref().map(|f| f.backend.as_str())
    }
1006
    /// Returns the model name from the fingerprint, if a fingerprint is set.
    pub fn model_label(&self) -> Option<&str> {
        self.fingerprint.as_ref().map(|f| f.model.as_str())
    }
1010
    /// Attaches `fingerprint` to the index, replacing any previous one.
    pub fn set_fingerprint(&mut self, fingerprint: SemanticIndexFingerprint) {
        self.fingerprint = Some(fingerprint);
    }
1014
1015 pub fn write_to_disk(&self, storage_dir: &Path, project_key: &str) {
1017 if self.entries.is_empty() {
1020 log::info!("[aft] skipping semantic index persistence (0 entries)");
1021 return;
1022 }
1023 let dir = storage_dir.join("semantic").join(project_key);
1024 if let Err(e) = fs::create_dir_all(&dir) {
1025 log::warn!("[aft] failed to create semantic cache dir: {}", e);
1026 return;
1027 }
1028 let data_path = dir.join("semantic.bin");
1029 let tmp_path = dir.join("semantic.bin.tmp");
1030 let bytes = self.to_bytes();
1031 if let Err(e) = fs::write(&tmp_path, &bytes) {
1032 log::warn!("[aft] failed to write semantic index: {}", e);
1033 let _ = fs::remove_file(&tmp_path);
1034 return;
1035 }
1036 if let Err(e) = fs::rename(&tmp_path, &data_path) {
1037 log::warn!("[aft] failed to rename semantic index: {}", e);
1038 let _ = fs::remove_file(&tmp_path);
1039 return;
1040 }
1041 log::info!(
1042 "[aft] semantic index persisted: {} entries, {:.1} KB",
1043 self.entries.len(),
1044 bytes.len() as f64 / 1024.0
1045 );
1046 }
1047
1048 pub fn read_from_disk(
1050 storage_dir: &Path,
1051 project_key: &str,
1052 expected_fingerprint: Option<&str>,
1053 ) -> Option<Self> {
1054 let data_path = storage_dir
1055 .join("semantic")
1056 .join(project_key)
1057 .join("semantic.bin");
1058 let file_len = usize::try_from(fs::metadata(&data_path).ok()?.len()).ok()?;
1059 if file_len < HEADER_BYTES_V1 {
1060 log::warn!(
1061 "[aft] corrupt semantic index (too small: {} bytes), removing",
1062 file_len
1063 );
1064 let _ = fs::remove_file(&data_path);
1065 return None;
1066 }
1067
1068 let bytes = fs::read(&data_path).ok()?;
1069 match Self::from_bytes(&bytes) {
1070 Ok(index) => {
1071 if index.entries.is_empty() {
1072 log::info!("[aft] cached semantic index is empty, will rebuild");
1073 let _ = fs::remove_file(&data_path);
1074 return None;
1075 }
1076 if let Some(expected) = expected_fingerprint {
1077 let matches = index
1078 .fingerprint()
1079 .map(|fingerprint| fingerprint.matches_expected(expected))
1080 .unwrap_or(false);
1081 if !matches {
1082 log::info!("[aft] cached semantic index fingerprint mismatch, rebuilding");
1083 let _ = fs::remove_file(&data_path);
1084 return None;
1085 }
1086 }
1087 log::info!(
1088 "[aft] loaded semantic index from disk: {} entries",
1089 index.entries.len()
1090 );
1091 Some(index)
1092 }
1093 Err(e) => {
1094 log::warn!("[aft] corrupt semantic index, rebuilding: {}", e);
1095 let _ = fs::remove_file(&data_path);
1096 None
1097 }
1098 }
1099 }
1100
    /// Serializes the index into its binary on-disk representation.
    ///
    /// Layout (all integers little-endian):
    /// 1. version byte (2 when a fingerprint is present and encodable, else 1)
    /// 2. dimension (u32), entry count (u32)
    /// 3. v2 only: fingerprint length (u32) followed by the fingerprint JSON bytes
    /// 4. mtime count (u32), then per file: path length (u32), path bytes,
    ///    modification time as whole seconds since the Unix epoch (u64)
    /// 5. per entry: length-prefixed file path, length-prefixed name, kind byte,
    ///    start line (u32), end line (u32), exported flag byte,
    ///    length-prefixed snippet, length-prefixed embed text, then the vector
    ///    components as consecutive `f32` values
    ///
    /// Paths are written with lossy UTF-8 conversion, and modification times
    /// lose their sub-second part.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut buf = Vec::new();
        // A fingerprint that fails to encode is treated as absent.
        let fingerprint_bytes = self.fingerprint.as_ref().and_then(|fingerprint| {
            let encoded = fingerprint.as_string();
            if encoded.is_empty() {
                None
            } else {
                Some(encoded.into_bytes())
            }
        });

        // Header.
        let version = if fingerprint_bytes.is_some() {
            SEMANTIC_INDEX_VERSION_V2
        } else {
            SEMANTIC_INDEX_VERSION_V1
        };
        buf.push(version);
        buf.extend_from_slice(&(self.dimension as u32).to_le_bytes());
        buf.extend_from_slice(&(self.entries.len() as u32).to_le_bytes());
        if let Some(bytes) = &fingerprint_bytes {
            buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
            buf.extend_from_slice(bytes);
        }

        // File modification times.
        buf.extend_from_slice(&(self.file_mtimes.len() as u32).to_le_bytes());
        for (path, mtime) in &self.file_mtimes {
            let path_bytes = path.to_string_lossy().as_bytes().to_vec();
            buf.extend_from_slice(&(path_bytes.len() as u32).to_le_bytes());
            buf.extend_from_slice(&path_bytes);
            // Times before the epoch are written as 0.
            let duration = mtime
                .duration_since(SystemTime::UNIX_EPOCH)
                .unwrap_or_default();
            buf.extend_from_slice(&duration.as_secs().to_le_bytes());
        }

        // Entries.
        for entry in &self.entries {
            let c = &entry.chunk;

            let file_bytes = c.file.to_string_lossy().as_bytes().to_vec();
            buf.extend_from_slice(&(file_bytes.len() as u32).to_le_bytes());
            buf.extend_from_slice(&file_bytes);

            let name_bytes = c.name.as_bytes();
            buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
            buf.extend_from_slice(name_bytes);

            buf.push(symbol_kind_to_u8(&c.kind));

            buf.extend_from_slice(&(c.start_line as u32).to_le_bytes());
            buf.extend_from_slice(&(c.end_line as u32).to_le_bytes());
            buf.push(c.exported as u8);

            let snippet_bytes = c.snippet.as_bytes();
            buf.extend_from_slice(&(snippet_bytes.len() as u32).to_le_bytes());
            buf.extend_from_slice(snippet_bytes);

            let embed_bytes = c.embed_text.as_bytes();
            buf.extend_from_slice(&(embed_bytes.len() as u32).to_le_bytes());
            buf.extend_from_slice(embed_bytes);

            // The vector has no length prefix of its own.
            for &val in &entry.vector {
                buf.extend_from_slice(&val.to_le_bytes());
            }
        }

        buf
    }
1179
1180 pub fn from_bytes(data: &[u8]) -> Result<Self, String> {
1182 let mut pos = 0;
1183
1184 if data.len() < HEADER_BYTES_V1 {
1185 return Err("data too short".to_string());
1186 }
1187
1188 let version = data[pos];
1189 pos += 1;
1190 if version != SEMANTIC_INDEX_VERSION_V1 && version != SEMANTIC_INDEX_VERSION_V2 {
1191 return Err(format!("unsupported version: {}", version));
1192 }
1193 if version == SEMANTIC_INDEX_VERSION_V2 && data.len() < HEADER_BYTES_V2 {
1194 return Err("data too short for semantic index v2 header".to_string());
1195 }
1196
1197 let dimension = read_u32(data, &mut pos)? as usize;
1198 let entry_count = read_u32(data, &mut pos)? as usize;
1199 if dimension == 0 || dimension > MAX_DIMENSION {
1200 return Err(format!("invalid embedding dimension: {}", dimension));
1201 }
1202 if entry_count > MAX_ENTRIES {
1203 return Err(format!("too many semantic index entries: {}", entry_count));
1204 }
1205
1206 let fingerprint = if version == SEMANTIC_INDEX_VERSION_V2 {
1207 let fingerprint_len = read_u32(data, &mut pos)? as usize;
1208 if pos + fingerprint_len > data.len() {
1209 return Err("unexpected end of data reading fingerprint".to_string());
1210 }
1211 let raw = String::from_utf8_lossy(&data[pos..pos + fingerprint_len]).to_string();
1212 pos += fingerprint_len;
1213 Some(
1214 serde_json::from_str::<SemanticIndexFingerprint>(&raw)
1215 .map_err(|error| format!("invalid semantic fingerprint: {error}"))?,
1216 )
1217 } else {
1218 None
1219 };
1220
1221 let mtime_count = read_u32(data, &mut pos)? as usize;
1223 if mtime_count > MAX_ENTRIES {
1224 return Err(format!("too many semantic file mtimes: {}", mtime_count));
1225 }
1226
1227 let vector_bytes = entry_count
1228 .checked_mul(dimension)
1229 .and_then(|count| count.checked_mul(F32_BYTES))
1230 .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
1231 if vector_bytes > data.len().saturating_sub(pos) {
1232 return Err("semantic index vectors exceed available data".to_string());
1233 }
1234
1235 let mut file_mtimes = HashMap::with_capacity(mtime_count);
1236 for _ in 0..mtime_count {
1237 let path = read_string(data, &mut pos)?;
1238 let secs = read_u64(data, &mut pos)?;
1239 let mtime = SystemTime::UNIX_EPOCH + std::time::Duration::from_secs(secs);
1240 file_mtimes.insert(PathBuf::from(path), mtime);
1241 }
1242
1243 let mut entries = Vec::with_capacity(entry_count);
1245 for _ in 0..entry_count {
1246 let file = PathBuf::from(read_string(data, &mut pos)?);
1247 let name = read_string(data, &mut pos)?;
1248
1249 if pos >= data.len() {
1250 return Err("unexpected end of data".to_string());
1251 }
1252 let kind = u8_to_symbol_kind(data[pos]);
1253 pos += 1;
1254
1255 let start_line = read_u32(data, &mut pos)?;
1256 let end_line = read_u32(data, &mut pos)?;
1257
1258 if pos >= data.len() {
1259 return Err("unexpected end of data".to_string());
1260 }
1261 let exported = data[pos] != 0;
1262 pos += 1;
1263
1264 let snippet = read_string(data, &mut pos)?;
1265 let embed_text = read_string(data, &mut pos)?;
1266
1267 let vec_bytes = dimension
1269 .checked_mul(F32_BYTES)
1270 .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
1271 if pos + vec_bytes > data.len() {
1272 return Err("unexpected end of data reading vector".to_string());
1273 }
1274 let mut vector = Vec::with_capacity(dimension);
1275 for _ in 0..dimension {
1276 let bytes = [data[pos], data[pos + 1], data[pos + 2], data[pos + 3]];
1277 vector.push(f32::from_le_bytes(bytes));
1278 pos += 4;
1279 }
1280
1281 entries.push(EmbeddingEntry {
1282 chunk: SemanticChunk {
1283 file,
1284 name,
1285 kind,
1286 start_line,
1287 end_line,
1288 exported,
1289 embed_text,
1290 snippet,
1291 },
1292 vector,
1293 });
1294 }
1295
1296 Ok(Self {
1297 entries,
1298 file_mtimes,
1299 dimension,
1300 fingerprint,
1301 })
1302 }
1303}
1304
1305fn build_embed_text(symbol: &Symbol, source: &str, file: &Path, project_root: &Path) -> String {
1307 let relative = file
1308 .strip_prefix(project_root)
1309 .unwrap_or(file)
1310 .to_string_lossy();
1311
1312 let kind_label = match &symbol.kind {
1313 SymbolKind::Function => "function",
1314 SymbolKind::Class => "class",
1315 SymbolKind::Method => "method",
1316 SymbolKind::Struct => "struct",
1317 SymbolKind::Interface => "interface",
1318 SymbolKind::Enum => "enum",
1319 SymbolKind::TypeAlias => "type",
1320 SymbolKind::Variable => "variable",
1321 SymbolKind::Heading => "heading",
1322 };
1323
1324 let mut text = format!("file:{} kind:{} name:{}", relative, kind_label, symbol.name);
1326
1327 if let Some(sig) = &symbol.signature {
1328 text.push_str(&format!(" signature:{}", sig));
1329 }
1330
1331 let lines: Vec<&str> = source.lines().collect();
1333 let start = (symbol.range.start_line.saturating_sub(1) as usize).min(lines.len()); let end = (symbol.range.end_line as usize).min(lines.len()); if start < end {
1336 let body: String = lines[start..end]
1337 .iter()
1338 .take(15) .copied()
1340 .collect::<Vec<&str>>()
1341 .join("\n");
1342 let snippet = if body.len() > 300 {
1343 format!("{}...", &body[..body.floor_char_boundary(300)])
1344 } else {
1345 body
1346 };
1347 text.push_str(&format!(" body:{}", snippet));
1348 }
1349
1350 text
1351}
1352
1353fn build_snippet(symbol: &Symbol, source: &str) -> String {
1355 let lines: Vec<&str> = source.lines().collect();
1356 let start = (symbol.range.start_line.saturating_sub(1) as usize).min(lines.len());
1357 let end = (symbol.range.end_line as usize).min(lines.len());
1358 if start < end {
1359 let snippet_lines: Vec<&str> = lines[start..end].iter().take(5).copied().collect();
1360 let mut snippet = snippet_lines.join("\n");
1361 if end - start > 5 {
1362 snippet.push_str("\n ...");
1363 }
1364 if snippet.len() > 300 {
1365 snippet = format!("{}...", &snippet[..snippet.floor_char_boundary(300)]);
1366 }
1367 snippet
1368 } else {
1369 String::new()
1370 }
1371}
1372
1373fn symbols_to_chunks(
1375 file: &Path,
1376 symbols: &[Symbol],
1377 source: &str,
1378 project_root: &Path,
1379) -> Vec<SemanticChunk> {
1380 let mut chunks = Vec::new();
1381
1382 for symbol in symbols {
1383 let line_count = symbol
1385 .range
1386 .end_line
1387 .saturating_sub(symbol.range.start_line)
1388 + 1;
1389 if line_count < 2 && !matches!(symbol.kind, SymbolKind::Variable) {
1390 continue;
1391 }
1392
1393 let embed_text = build_embed_text(symbol, source, file, project_root);
1394 let snippet = build_snippet(symbol, source);
1395
1396 chunks.push(SemanticChunk {
1397 file: file.to_path_buf(),
1398 name: symbol.name.clone(),
1399 kind: symbol.kind.clone(),
1400 start_line: symbol.range.start_line,
1401 end_line: symbol.range.end_line,
1402 exported: symbol.exported,
1403 embed_text,
1404 snippet,
1405 });
1406
1407 }
1410
1411 chunks
1412}
1413
/// Cosine of the angle between `a` and `b`.
///
/// Returns `0.0` when the slices differ in length or when either has zero
/// magnitude (which includes two empty slices).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Accumulate the dot product and both squared norms in a single pass.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let denom = sq_a.sqrt() * sq_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
1437
/// Maps a `SymbolKind` to the one-byte tag written into the serialized index.
///
/// These values are part of the on-disk format and must stay in sync with
/// `u8_to_symbol_kind`. The match is deliberately exhaustive (no `_` arm) so
/// adding a variant forces a tag to be chosen here.
fn symbol_kind_to_u8(kind: &SymbolKind) -> u8 {
    match kind {
        SymbolKind::Function => 0,
        SymbolKind::Class => 1,
        SymbolKind::Method => 2,
        SymbolKind::Struct => 3,
        SymbolKind::Interface => 4,
        SymbolKind::Enum => 5,
        SymbolKind::TypeAlias => 6,
        SymbolKind::Variable => 7,
        SymbolKind::Heading => 8,
    }
}
1452
/// Decodes the one-byte kind tag produced by `symbol_kind_to_u8`.
///
/// Tags 0 through 7 map back to their variants. Tag 8 and every other
/// value decode to `SymbolKind::Heading`, so an unrecognised tag is accepted
/// rather than reported as an error.
fn u8_to_symbol_kind(v: u8) -> SymbolKind {
    match v {
        0 => SymbolKind::Function,
        1 => SymbolKind::Class,
        2 => SymbolKind::Method,
        3 => SymbolKind::Struct,
        4 => SymbolKind::Interface,
        5 => SymbolKind::Enum,
        6 => SymbolKind::TypeAlias,
        7 => SymbolKind::Variable,
        // 8 is the real Heading tag; anything else falls back to it too.
        _ => SymbolKind::Heading,
    }
}
1466
/// Reads a little-endian `u32` from `data` at `*pos` and advances `*pos` by 4.
///
/// # Errors
///
/// Returns an error, leaving `*pos` untouched, when fewer than four bytes
/// are available at `*pos`. That includes a cursor beyond the end of `data`
/// or one so large that `*pos + 4` would overflow.
fn read_u32(data: &[u8], pos: &mut usize) -> Result<u32, String> {
    let start = *pos;
    // Checked add so an out-of-range cursor is an error, not an overflow panic.
    let end = start
        .checked_add(4)
        .filter(|&end| end <= data.len())
        .ok_or_else(|| "unexpected end of data reading u32".to_string())?;

    let mut raw = [0u8; 4];
    raw.copy_from_slice(&data[start..end]);
    *pos = end;
    Ok(u32::from_le_bytes(raw))
}
1475
/// Reads a little-endian `u64` from `data` at `*pos` and advances `*pos` by 8.
///
/// # Errors
///
/// Returns an error, leaving `*pos` untouched, when fewer than eight bytes
/// are available at `*pos`. That includes a cursor beyond the end of `data`
/// or one so large that `*pos + 8` would overflow.
fn read_u64(data: &[u8], pos: &mut usize) -> Result<u64, String> {
    let start = *pos;
    // Checked add so an out-of-range cursor is an error, not an overflow panic.
    let end = start
        .checked_add(8)
        .filter(|&end| end <= data.len())
        .ok_or_else(|| "unexpected end of data reading u64".to_string())?;

    // Copy into a fixed array; the length is guaranteed by the check above,
    // so no fallible conversion (and no `unwrap`) is needed.
    let mut raw = [0u8; 8];
    raw.copy_from_slice(&data[start..end]);
    *pos = end;
    Ok(u64::from_le_bytes(raw))
}
1484
1485fn read_string(data: &[u8], pos: &mut usize) -> Result<String, String> {
1486 let len = read_u32(data, pos)? as usize;
1487 if *pos + len > data.len() {
1488 return Err("unexpected end of data reading string".to_string());
1489 }
1490 let s = String::from_utf8_lossy(&data[*pos..*pos + len]).to_string();
1491 *pos += len;
1492 Ok(s)
1493}
1494
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{SemanticBackend, SemanticBackendConfig};
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;

    /// Starts a single-shot HTTP server on an ephemeral localhost port.
    ///
    /// The server accepts exactly one connection, reads one request, and calls
    /// `handler(request_line, path, body)`. The returned string is sent back
    /// as a `200 OK` JSON response. Returns the server's base URL and the
    /// thread handle; join the handle to surface assertion failures raised
    /// inside `handler`.
    fn start_mock_http_server<F>(handler: F) -> (String, thread::JoinHandle<()>)
    where
        F: Fn(String, String, String) -> String + Send + 'static,
    {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
        let addr = listener.local_addr().expect("local addr");
        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept request");
            let mut buf = Vec::new();
            let mut chunk = [0u8; 4096];
            let mut header_end = None;
            let mut content_length = 0usize;
            loop {
                let n = stream.read(&mut chunk).expect("read request");
                if n == 0 {
                    break;
                }
                buf.extend_from_slice(&chunk[..n]);
                if header_end.is_none() {
                    if let Some(pos) = buf.windows(4).position(|window| window == b"\r\n\r\n") {
                        header_end = Some(pos + 4);
                        let headers = String::from_utf8_lossy(&buf[..pos + 4]);
                        for line in headers.lines() {
                            // Header names are case-insensitive and clients may
                            // send `content-length` in lower case; an exact
                            // `Content-Length:` match would miss it and the
                            // request body would never be read.
                            if let Some((name, value)) = line.split_once(':') {
                                if name.trim().eq_ignore_ascii_case("content-length") {
                                    content_length = value.trim().parse::<usize>().unwrap_or(0);
                                }
                            }
                        }
                    }
                }
                if let Some(end) = header_end {
                    // Stop once the declared body has fully arrived.
                    if buf.len() >= end + content_length {
                        break;
                    }
                }
            }

            let end = header_end.expect("header terminator");
            let request = String::from_utf8_lossy(&buf[..end]).to_string();
            // Clamp in case the peer closed the connection before sending the
            // whole body, so the slice cannot go out of bounds.
            let body_end = (end + content_length).min(buf.len());
            let body = String::from_utf8_lossy(&buf[end..body_end]).to_string();
            let mut lines = request.lines();
            let request_line = lines.next().expect("request line").to_string();
            let path = request_line
                .split_whitespace()
                .nth(1)
                .expect("request path")
                .to_string();
            let response_body = handler(request_line, path, body);
            let response = format!(
                "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
                response_body.len(),
                response_body
            );
            stream
                .write_all(response.as_bytes())
                .expect("write response");
        });

        (format!("http://{}", addr), handle)
    }

    /// Identical vectors have similarity 1.
    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
    }

    /// Orthogonal vectors have similarity 0.
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &b).abs() < 0.001);
    }

    /// Opposite vectors have similarity -1.
    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 0.001);
    }

    /// An index with one entry, one mtime and a fingerprint survives
    /// `to_bytes` followed by `from_bytes`.
    #[test]
    fn test_serialization_roundtrip() {
        let mut index = SemanticIndex::new();
        index.entries.push(EmbeddingEntry {
            chunk: SemanticChunk {
                file: PathBuf::from("/src/main.rs"),
                name: "handle_request".to_string(),
                kind: SymbolKind::Function,
                start_line: 10,
                end_line: 25,
                exported: true,
                embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
                snippet: "fn handle_request() {\n // ...\n}".to_string(),
            },
            vector: vec![0.1, 0.2, 0.3, 0.4],
        });
        index.dimension = 4;
        index
            .file_mtimes
            .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
        index.set_fingerprint(SemanticIndexFingerprint {
            backend: "fastembed".to_string(),
            model: "all-MiniLM-L6-v2".to_string(),
            base_url: FALLBACK_BACKEND.to_string(),
            dimension: 4,
        });

        let bytes = index.to_bytes();
        let restored = SemanticIndex::from_bytes(&bytes).unwrap();

        assert_eq!(restored.entries.len(), 1);
        assert_eq!(restored.entries[0].chunk.name, "handle_request");
        assert_eq!(restored.entries[0].vector, vec![0.1, 0.2, 0.3, 0.4]);
        assert_eq!(restored.dimension, 4);
        assert_eq!(restored.backend_label(), Some("fastembed"));
        assert_eq!(restored.model_label(), Some("all-MiniLM-L6-v2"));
    }

    /// `search` returns the requested number of hits, best match first.
    #[test]
    fn test_search_top_k() {
        let mut index = SemanticIndex::new();
        index.dimension = 3;

        for (i, name) in ["auth", "database", "handler"].iter().enumerate() {
            // Give each entry its own unit axis so the ranking is unambiguous.
            let mut vec = vec![0.0f32; 3];
            vec[i] = 1.0;
            index.entries.push(EmbeddingEntry {
                chunk: SemanticChunk {
                    file: PathBuf::from("/src/lib.rs"),
                    name: name.to_string(),
                    kind: SymbolKind::Function,
                    start_line: (i * 10 + 1) as u32,
                    end_line: (i * 10 + 5) as u32,
                    exported: true,
                    embed_text: format!("kind:function name:{}", name),
                    snippet: format!("fn {}() {{}}", name),
                },
                vector: vec,
            });
        }

        // The query points mostly along the first axis, i.e. towards "auth".
        let query = vec![0.9, 0.1, 0.0];
        let results = index.search(&query, 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].name, "auth");
        assert!(results[0].score > results[1].score);
    }

    /// Searching an empty index yields no results.
    #[test]
    fn test_empty_index_search() {
        let index = SemanticIndex::new();
        let results = index.search(&[0.1, 0.2, 0.3], 10);
        assert!(results.is_empty());
    }

    /// A v1 header declaring a dimension above `MAX_DIMENSION` is rejected.
    #[test]
    fn rejects_oversized_dimension_during_deserialization() {
        let mut bytes = Vec::new();
        bytes.push(1u8);
        bytes.extend_from_slice(&((MAX_DIMENSION as u32) + 1).to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());

        assert!(SemanticIndex::from_bytes(&bytes).is_err());
    }

    /// A v1 header declaring more than `MAX_ENTRIES` entries is rejected.
    #[test]
    fn rejects_oversized_entry_count_during_deserialization() {
        let mut bytes = Vec::new();
        bytes.push(1u8);
        bytes.extend_from_slice(&(DEFAULT_DIMENSION as u32).to_le_bytes());
        bytes.extend_from_slice(&((MAX_ENTRIES as u32) + 1).to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());

        assert!(SemanticIndex::from_bytes(&bytes).is_err());
    }

    /// `invalidate_file` drops both the entries and the recorded mtime of a file.
    #[test]
    fn invalidate_file_removes_entries_and_mtime() {
        let target = PathBuf::from("/src/main.rs");
        let mut index = SemanticIndex::new();
        index.entries.push(EmbeddingEntry {
            chunk: SemanticChunk {
                file: target.clone(),
                name: "main".to_string(),
                kind: SymbolKind::Function,
                start_line: 0,
                end_line: 1,
                exported: false,
                embed_text: "main".to_string(),
                snippet: "fn main() {}".to_string(),
            },
            vector: vec![1.0; DEFAULT_DIMENSION],
        });
        index
            .file_mtimes
            .insert(target.clone(), SystemTime::UNIX_EPOCH);

        index.invalidate_file(&target);

        assert!(index.entries.is_empty());
        assert!(!index.file_mtimes.contains_key(&target));
    }

    /// A dynamic-loader failure message is recognised as "ONNX Runtime missing".
    #[test]
    fn detects_missing_onnx_runtime_from_dynamic_load_error() {
        let message = "Failed to load ONNX Runtime shared library libonnxruntime.dylib via dlopen: no such file";

        assert!(is_onnx_runtime_unavailable(message));
    }

    /// The formatted init error leads with the install hint and keeps the
    /// original error text.
    #[test]
    fn formats_missing_onnx_runtime_with_install_hint() {
        let message = format_embedding_init_error(
            "Failed to load ONNX Runtime shared library libonnxruntime.so via dlopen: no such file",
        );

        assert!(message.starts_with("ONNX Runtime not found. Install via:"));
        assert!(message.contains("Original error:"));
    }

    /// The OpenAI-compatible backend POSTs to `/v1/embeddings` and parses the
    /// `data[].embedding` arrays from the response.
    #[test]
    fn openai_compatible_backend_embeds_with_mock_server() {
        let (base_url, handle) = start_mock_http_server(|request_line, path, _body| {
            assert!(request_line.starts_with("POST "));
            assert_eq!(path, "/v1/embeddings");
            "{\"data\":[{\"embedding\":[0.1,0.2,0.3],\"index\":0},{\"embedding\":[0.4,0.5,0.6],\"index\":1}]}".to_string()
        });

        let config = SemanticBackendConfig {
            backend: SemanticBackend::OpenAiCompatible,
            model: "test-embedding".to_string(),
            base_url: Some(base_url),
            api_key_env: None,
            timeout_ms: 5_000,
            max_batch_size: 64,
        };

        let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();
        let vectors = model
            .embed(vec!["hello".to_string(), "world".to_string()])
            .unwrap();

        assert_eq!(vectors, vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]]);
        handle.join().unwrap();
    }

    /// The Ollama backend POSTs to `/api/embed` and parses the `embeddings`
    /// array from the response.
    #[test]
    fn ollama_backend_embeds_with_mock_server() {
        let (base_url, handle) = start_mock_http_server(|request_line, path, _body| {
            assert!(request_line.starts_with("POST "));
            assert_eq!(path, "/api/embed");
            "{\"embeddings\":[[0.7,0.8,0.9],[1.0,1.1,1.2]]}".to_string()
        });

        let config = SemanticBackendConfig {
            backend: SemanticBackend::Ollama,
            model: "embeddinggemma".to_string(),
            base_url: Some(base_url),
            api_key_env: None,
            timeout_ms: 5_000,
            max_batch_size: 64,
        };

        let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();
        let vectors = model
            .embed(vec!["hello".to_string(), "world".to_string()])
            .unwrap();

        assert_eq!(vectors, vec![vec![0.7, 0.8, 0.9], vec![1.0, 1.1, 1.2]]);
        handle.join().unwrap();
    }

    /// An index written to disk loads with its own fingerprint and is refused
    /// with a different one.
    #[test]
    fn read_from_disk_rejects_fingerprint_mismatch() {
        let storage = tempfile::tempdir().unwrap();
        let project_key = "proj";

        let mut index = SemanticIndex::new();
        index.entries.push(EmbeddingEntry {
            chunk: SemanticChunk {
                file: PathBuf::from("/src/main.rs"),
                name: "handle_request".to_string(),
                kind: SymbolKind::Function,
                start_line: 10,
                end_line: 25,
                exported: true,
                embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
                snippet: "fn handle_request() {}".to_string(),
            },
            vector: vec![0.1, 0.2, 0.3],
        });
        index.dimension = 3;
        index
            .file_mtimes
            .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
        index.set_fingerprint(SemanticIndexFingerprint {
            backend: "openai_compatible".to_string(),
            model: "test-embedding".to_string(),
            base_url: "http://127.0.0.1:1234/v1".to_string(),
            dimension: 3,
        });
        index.write_to_disk(storage.path(), project_key);

        let matching = index.fingerprint().unwrap().as_string();
        assert!(
            SemanticIndex::read_from_disk(storage.path(), project_key, Some(&matching)).is_some()
        );

        let mismatched = SemanticIndexFingerprint {
            backend: "ollama".to_string(),
            model: "embeddinggemma".to_string(),
            base_url: "http://127.0.0.1:11434".to_string(),
            dimension: 3,
        }
        .as_string();
        assert!(
            SemanticIndex::read_from_disk(storage.path(), project_key, Some(&mismatched)).is_none()
        );
    }
}