1use crate::config::{SemanticBackend, SemanticBackendConfig};
2use crate::parser::{detect_language, extract_symbols_from_tree, grammar_for};
3use crate::symbols::{Symbol, SymbolKind};
4use crate::{slog_info, slog_warn};
5
6use fastembed::{EmbeddingModel as FastembedEmbeddingModel, InitOptions, TextEmbedding};
7use rayon::prelude::*;
8use reqwest::blocking::Client;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::env;
12use std::fmt::Display;
13use std::fs;
14use std::path::{Path, PathBuf};
15use std::time::Duration;
16use std::time::SystemTime;
17use tree_sitter::Parser;
18use url::Url;
19
/// Dimension recorded for an index that holds no embedding entries.
const DEFAULT_DIMENSION: usize = 384;
// NOTE(review): MAX_ENTRIES, MAX_DIMENSION, F32_BYTES and HEADER_BYTES_V2 are
// not referenced in the visible part of this file; presumably they are limits
// and sizes used while decoding a serialized index — confirm against `from_bytes`.
const MAX_ENTRIES: usize = 1_000_000;
const MAX_DIMENSION: usize = 1024;
const F32_BYTES: usize = std::mem::size_of::<f32>();
/// Minimum byte length a serialized index must have before decoding is attempted.
const HEADER_BYTES_V1: usize = 9;
const HEADER_BYTES_V2: usize = 13;
/// Hint prepended to initialization errors that look like a missing ONNX Runtime library.
const ONNX_RUNTIME_INSTALL_HINT: &str =
    "ONNX Runtime not found. Install via: brew install onnxruntime (macOS) or apt install libonnxruntime (Linux).";

// On-disk format version tags. `to_bytes` writes V4 and `read_from_disk`
// discards any cache whose first byte is not V4.
const SEMANTIC_INDEX_VERSION_V1: u8 = 1;
const SEMANTIC_INDEX_VERSION_V2: u8 = 2;
const SEMANTIC_INDEX_VERSION_V3: u8 = 3;
const SEMANTIC_INDEX_VERSION_V4: u8 = 4;
/// Path appended (after a `/v1` segment) to an OpenAI-compatible base URL.
const DEFAULT_OPENAI_EMBEDDING_PATH: &str = "/embeddings";
/// Path appended to an Ollama base URL that does not already end in `/api`.
const DEFAULT_OLLAMA_EMBEDDING_PATH: &str = "/api/embed";
/// HTTP client timeout used when the configured `timeout_ms` is 0.
const DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS: u64 = 25_000;
/// Batch size used when the configured `max_batch_size` is 0.
const DEFAULT_MAX_BATCH_SIZE: usize = 64;
/// Placeholder stored as the fingerprint's `base_url` when none is configured
/// or the configured value fails normalization.
const FALLBACK_BACKEND: &str = "none";
/// Total number of attempts (first try included) for one embedding HTTP request.
const EMBEDDING_REQUEST_MAX_ATTEMPTS: usize = 3;
/// Delay in milliseconds before each retry, indexed by the zero-based attempt that failed.
const EMBEDDING_REQUEST_BACKOFF_MS: [u64; 2] = [500, 1_000];
47
/// Identifies the embedding configuration an index was built with.
///
/// The JSON encoding of this value is compared against an expected string when
/// a cached index is loaded; a mismatch causes the cache to be discarded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticIndexFingerprint {
    /// Backend name, taken from `config.backend.as_str()`.
    pub backend: String,
    /// Model name as configured.
    pub model: String,
    /// Normalized base URL, or `FALLBACK_BACKEND` when none is usable.
    /// Deserializes to an empty string when the field is absent.
    #[serde(default)]
    pub base_url: String,
    /// Length of the embedding vectors produced by the backend.
    pub dimension: usize,
}
56
57impl SemanticIndexFingerprint {
58 fn from_config(config: &SemanticBackendConfig, dimension: usize) -> Self {
59 let base_url = config
62 .base_url
63 .as_ref()
64 .and_then(|u| normalize_base_url(u).ok())
65 .unwrap_or_else(|| FALLBACK_BACKEND.to_string());
66 Self {
67 backend: config.backend.as_str().to_string(),
68 model: config.model.clone(),
69 base_url,
70 dimension,
71 }
72 }
73
74 pub fn as_string(&self) -> String {
75 serde_json::to_string(self).unwrap_or_else(|_| String::new())
76 }
77
78 fn matches_expected(&self, expected: &str) -> bool {
79 let encoded = self.as_string();
80 !encoded.is_empty() && encoded == expected
81 }
82}
83
/// Backend-specific state needed to produce embeddings.
enum SemanticEmbeddingEngine {
    /// In-process fastembed model.
    Fastembed(TextEmbedding),
    /// Remote endpoint reached via `POST {base_url}[/v1]/embeddings`.
    OpenAiCompatible {
        client: Client,
        model: String,
        /// Normalized base URL.
        base_url: String,
        /// Bearer token sent in the `Authorization` header when present.
        api_key: Option<String>,
    },
    /// Remote Ollama endpoint reached via `POST {base_url}/api/embed`.
    Ollama {
        client: Client,
        model: String,
        /// Normalized base URL.
        base_url: String,
    },
}
98
/// An embedding backend together with its effective settings.
pub struct SemanticEmbeddingModel {
    backend: SemanticBackend,
    model: String,
    /// Base URL exactly as configured (the engine holds the normalized form).
    base_url: Option<String>,
    /// Effective HTTP timeout after the zero-means-default substitution.
    timeout_ms: u64,
    /// Effective batch size after the zero-means-default substitution.
    max_batch_size: usize,
    /// Embedding dimension once known; `None` until the backend has been queried.
    dimension: Option<usize>,
    engine: SemanticEmbeddingEngine,
}
108
/// Alias for [`SemanticEmbeddingModel`].
pub type EmbeddingModel = SemanticEmbeddingModel;
110
/// Checks that a backend returned one vector per input and that every vector
/// has the same length as the first one.
///
/// `context` names the backend and is used as the prefix of every error message.
///
/// # Errors
/// Returns a descriptive message when the batch is unexpectedly empty, has the
/// wrong number of vectors, or mixes vector lengths.
fn validate_embedding_batch(
    vectors: &[Vec<f32>],
    expected_count: usize,
    context: &str,
) -> Result<(), String> {
    let received = vectors.len();

    // An entirely empty reply gets its own message so it stands out in logs.
    if received == 0 && expected_count > 0 {
        return Err(format!(
            "{context} returned no vectors for {expected_count} inputs"
        ));
    }

    if received != expected_count {
        return Err(format!(
            "{context} returned {} vectors for {} inputs",
            received,
            expected_count
        ));
    }

    let expected_dimension = match vectors.first() {
        Some(first_vector) => first_vector.len(),
        None => return Ok(()),
    };

    let first_mismatch = vectors
        .iter()
        .enumerate()
        .find(|(_, vector)| vector.len() != expected_dimension);

    match first_mismatch {
        Some((index, vector)) => Err(format!(
            "{context} returned inconsistent embedding dimensions: vector 0 has length {expected_dimension}, vector {index} has length {}",
            vector.len()
        )),
        None => Ok(()),
    }
}
145
146fn normalize_base_url(raw: &str) -> Result<String, String> {
150 let parsed = Url::parse(raw).map_err(|error| format!("invalid base_url '{raw}': {error}"))?;
151 let scheme = parsed.scheme();
152 if scheme != "http" && scheme != "https" {
153 return Err(format!(
154 "unsupported URL scheme '{}' — only http:// and https:// are allowed",
155 scheme
156 ));
157 }
158 Ok(parsed.to_string().trim_end_matches('/').to_string())
159}
160
161pub fn validate_base_url_no_ssrf(raw: &str) -> Result<(), String> {
165 use std::net::{IpAddr, ToSocketAddrs};
166
167 let parsed = Url::parse(raw).map_err(|error| format!("invalid base_url '{raw}': {error}"))?;
168
169 let host = parsed.host_str().unwrap_or("");
171 if host == "localhost"
172 || host == "localhost.localdomain"
173 || host.ends_with(".localhost")
174 || host.ends_with(".local")
175 {
176 return Err(format!(
177 "base_url host '{host}' resolves to a private/loopback address — only public endpoints are allowed"
178 ));
179 }
180
181 let port = parsed.port_or_known_default().unwrap_or(443);
183 let addr_str = format!("{host}:{port}");
184 let addrs: Vec<IpAddr> = addr_str
185 .to_socket_addrs()
186 .map(|iter| iter.map(|sa| sa.ip()).collect())
187 .unwrap_or_default();
188 for ip in &addrs {
189 if is_private_ip(ip) {
190 return Err(format!(
191 "base_url '{raw}' resolves to a private/reserved IP address — only public endpoints are allowed"
192 ));
193 }
194 }
195
196 Ok(())
197}
198
/// Returns `true` when `ip` belongs to a range that must not be reached from
/// a user-supplied base URL.
///
/// IPv4: 10/8, 172.16/12, 192.168/16, 127/8 (loopback), 169.254/16
/// (link-local), 100.64/10 (carrier-grade NAT) and 0/8.
///
/// IPv6: `::1`, `::`, fe80::/10 (link-local), fc00::/7 (unique local), and any
/// IPv4-mapped address (`::ffff:a.b.c.d`) whose embedded IPv4 address is
/// itself private.
fn is_private_ip(ip: &std::net::IpAddr) -> bool {
    use std::net::{IpAddr, Ipv6Addr};
    match ip {
        IpAddr::V4(v4) => {
            let o = v4.octets();
            o[0] == 10
                || (o[0] == 172 && (16..=31).contains(&o[1]))
                || (o[0] == 192 && o[1] == 168)
                || o[0] == 127
                || (o[0] == 169 && o[1] == 254)
                || (o[0] == 100 && (64..=127).contains(&o[1]))
                || o[0] == 0
        }
        IpAddr::V6(v6) => {
            // An IPv4-mapped address is routed as its embedded IPv4 address,
            // so it must be classified by the IPv4 rules. The previous slice
            // pattern could never match, letting `::ffff:127.0.0.1` through.
            if let Some(mapped) = v6.to_ipv4_mapped() {
                return is_private_ip(&IpAddr::V4(mapped));
            }

            let first_segment = v6.segments()[0];
            *v6 == Ipv6Addr::LOCALHOST
                // `::` is the IPv6 counterpart of 0.0.0.0, which is rejected above.
                || *v6 == Ipv6Addr::UNSPECIFIED
                || (first_segment & 0xffc0) == 0xfe80
                || (first_segment & 0xfe00) == 0xfc00
        }
    }
}
238
239fn build_openai_embeddings_endpoint(base_url: &str) -> String {
240 if base_url.ends_with("/v1") {
241 format!("{base_url}{DEFAULT_OPENAI_EMBEDDING_PATH}")
242 } else {
243 format!("{base_url}/v1{}", DEFAULT_OPENAI_EMBEDDING_PATH)
244 }
245}
246
247fn build_ollama_embeddings_endpoint(base_url: &str) -> String {
248 if base_url.ends_with("/api") {
249 format!("{base_url}/embed")
250 } else {
251 format!("{base_url}{DEFAULT_OLLAMA_EMBEDDING_PATH}")
252 }
253}
254
/// Trims surrounding whitespace from an optional value and maps blank or
/// missing values to `None`.
fn normalize_api_key(value: Option<String>) -> Option<String> {
    let trimmed = value.as_deref().map(str::trim)?;
    if trimmed.is_empty() {
        return None;
    }
    Some(trimmed.to_string())
}
265
266fn is_retryable_embedding_status(status: reqwest::StatusCode) -> bool {
267 status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS
268}
269
/// Returns `true` when a transport error is worth retrying.
///
/// Only errors that `reqwest` reports as connection errors qualify.
fn is_retryable_embedding_error(error: &reqwest::Error) -> bool {
    error.is_connect()
}
273
274fn sleep_before_embedding_retry(attempt_index: usize) {
275 if let Some(delay_ms) = EMBEDDING_REQUEST_BACKOFF_MS.get(attempt_index) {
276 std::thread::sleep(Duration::from_millis(*delay_ms));
277 }
278}
279
280fn send_embedding_request<F>(mut make_request: F, backend_label: &str) -> Result<String, String>
281where
282 F: FnMut() -> reqwest::blocking::RequestBuilder,
283{
284 for attempt_index in 0..EMBEDDING_REQUEST_MAX_ATTEMPTS {
285 let last_attempt = attempt_index + 1 == EMBEDDING_REQUEST_MAX_ATTEMPTS;
286
287 let response = match make_request().send() {
288 Ok(response) => response,
289 Err(error) => {
290 if !last_attempt && is_retryable_embedding_error(&error) {
291 sleep_before_embedding_retry(attempt_index);
292 continue;
293 }
294 return Err(format!("{backend_label} request failed: {error}"));
295 }
296 };
297
298 let status = response.status();
299 let raw = match response.text() {
300 Ok(raw) => raw,
301 Err(error) => {
302 if !last_attempt && is_retryable_embedding_error(&error) {
303 sleep_before_embedding_retry(attempt_index);
304 continue;
305 }
306 return Err(format!("{backend_label} response read failed: {error}"));
307 }
308 };
309
310 if status.is_success() {
311 return Ok(raw);
312 }
313
314 if !last_attempt && is_retryable_embedding_status(status) {
315 sleep_before_embedding_retry(attempt_index);
316 continue;
317 }
318
319 return Err(format!(
320 "{backend_label} request failed (HTTP {}): {}",
321 status, raw
322 ));
323 }
324
325 unreachable!("embedding request retries exhausted without returning")
326}
327
328impl SemanticEmbeddingModel {
329 pub fn from_config(config: &SemanticBackendConfig) -> Result<Self, String> {
330 let timeout_ms = if config.timeout_ms == 0 {
331 DEFAULT_OPENAI_EMBEDDING_TIMEOUT_MS
332 } else {
333 config.timeout_ms
334 };
335
336 let max_batch_size = if config.max_batch_size == 0 {
337 DEFAULT_MAX_BATCH_SIZE
338 } else {
339 config.max_batch_size
340 };
341
342 let api_key_env = normalize_api_key(config.api_key_env.clone());
343 let model = config.model.clone();
344
345 let client = Client::builder()
346 .timeout(Duration::from_millis(timeout_ms))
347 .redirect(reqwest::redirect::Policy::none())
348 .build()
349 .map_err(|error| format!("failed to configure embedding client: {error}"))?;
350
351 let engine = match config.backend {
352 SemanticBackend::Fastembed => {
353 SemanticEmbeddingEngine::Fastembed(initialize_text_embedding(&model)?)
354 }
355 SemanticBackend::OpenAiCompatible => {
356 let raw = config.base_url.as_ref().ok_or_else(|| {
357 "base_url is required for openai_compatible backend".to_string()
358 })?;
359 let base_url = normalize_base_url(raw)?;
360
361 let api_key = match api_key_env {
362 Some(var_name) => Some(env::var(&var_name).map_err(|_| {
363 format!("missing api_key_env '{var_name}' for openai_compatible backend")
364 })?),
365 None => None,
366 };
367
368 SemanticEmbeddingEngine::OpenAiCompatible {
369 client,
370 model,
371 base_url,
372 api_key,
373 }
374 }
375 SemanticBackend::Ollama => {
376 let raw = config
377 .base_url
378 .as_ref()
379 .ok_or_else(|| "base_url is required for ollama backend".to_string())?;
380 let base_url = normalize_base_url(raw)?;
381
382 SemanticEmbeddingEngine::Ollama {
383 client,
384 model,
385 base_url,
386 }
387 }
388 };
389
390 Ok(Self {
391 backend: config.backend,
392 model: config.model.clone(),
393 base_url: config.base_url.clone(),
394 timeout_ms,
395 max_batch_size,
396 dimension: None,
397 engine,
398 })
399 }
400
    /// Returns the backend this model was configured with.
    pub fn backend(&self) -> SemanticBackend {
        self.backend
    }

    /// Returns the configured model name.
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Returns the base URL exactly as it appeared in the configuration, if any.
    pub fn base_url(&self) -> Option<&str> {
        self.base_url.as_deref()
    }

    /// Returns the effective batch size (the default when 0 was configured).
    pub fn max_batch_size(&self) -> usize {
        self.max_batch_size
    }

    /// Returns the effective request timeout in milliseconds (the default when 0 was configured).
    pub fn timeout_ms(&self) -> u64 {
        self.timeout_ms
    }
420
421 pub fn fingerprint(
422 &mut self,
423 config: &SemanticBackendConfig,
424 ) -> Result<SemanticIndexFingerprint, String> {
425 let dimension = self.dimension()?;
426 Ok(SemanticIndexFingerprint::from_config(config, dimension))
427 }
428
429 pub fn dimension(&mut self) -> Result<usize, String> {
430 if let Some(dimension) = self.dimension {
431 return Ok(dimension);
432 }
433
434 let dimension = match &mut self.engine {
435 SemanticEmbeddingEngine::Fastembed(model) => {
436 let vectors = model
437 .embed(vec!["semantic index fingerprint probe".to_string()], None)
438 .map_err(|error| format_embedding_init_error(error.to_string()))?;
439 vectors
440 .first()
441 .map(|v| v.len())
442 .ok_or_else(|| "embedding backend returned no vectors".to_string())?
443 }
444 SemanticEmbeddingEngine::OpenAiCompatible { .. } => {
445 let vectors =
446 self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
447 vectors
448 .first()
449 .map(|v| v.len())
450 .ok_or_else(|| "embedding backend returned no vectors".to_string())?
451 }
452 SemanticEmbeddingEngine::Ollama { .. } => {
453 let vectors =
454 self.embed_texts(vec!["semantic index fingerprint probe".to_string()])?;
455 vectors
456 .first()
457 .map(|v| v.len())
458 .ok_or_else(|| "embedding backend returned no vectors".to_string())?
459 }
460 };
461
462 self.dimension = Some(dimension);
463 Ok(dimension)
464 }
465
    /// Embeds `texts` with the configured backend, returning one vector per input.
    ///
    /// # Errors
    /// Returns the backend's error message when embedding fails.
    pub fn embed(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
        self.embed_texts(texts)
    }
469
470 fn embed_texts(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
471 match &mut self.engine {
472 SemanticEmbeddingEngine::Fastembed(model) => model
473 .embed(texts, None::<usize>)
474 .map_err(|error| format_embedding_init_error(error.to_string()))
475 .map_err(|error| format!("failed to embed batch: {error}")),
476 SemanticEmbeddingEngine::OpenAiCompatible {
477 client,
478 model,
479 base_url,
480 api_key,
481 } => {
482 let expected_text_count = texts.len();
483 let endpoint = build_openai_embeddings_endpoint(base_url);
484 let body = serde_json::json!({
485 "input": texts,
486 "model": model,
487 });
488
489 let raw = send_embedding_request(
490 || {
491 let mut request = client
492 .post(&endpoint)
493 .json(&body)
494 .header("Content-Type", "application/json");
495
496 if let Some(api_key) = api_key {
497 request = request.header("Authorization", format!("Bearer {api_key}"));
498 }
499
500 request
501 },
502 "openai compatible",
503 )?;
504
505 #[derive(Deserialize)]
506 struct OpenAiResponse {
507 data: Vec<OpenAiEmbeddingResult>,
508 }
509
510 #[derive(Deserialize)]
511 struct OpenAiEmbeddingResult {
512 embedding: Vec<f32>,
513 index: Option<u32>,
514 }
515
516 let parsed: OpenAiResponse = serde_json::from_str(&raw)
517 .map_err(|error| format!("invalid openai compatible response: {error}"))?;
518 if parsed.data.len() != expected_text_count {
519 return Err(format!(
520 "openai compatible response returned {} embeddings for {} inputs",
521 parsed.data.len(),
522 expected_text_count
523 ));
524 }
525
526 let mut vectors = vec![Vec::new(); parsed.data.len()];
527 for (i, item) in parsed.data.into_iter().enumerate() {
528 let index = item.index.unwrap_or(i as u32) as usize;
529 if index >= vectors.len() {
530 return Err(
531 "openai compatible response contains invalid vector index".to_string()
532 );
533 }
534 vectors[index] = item.embedding;
535 }
536
537 for vector in &vectors {
538 if vector.is_empty() {
539 return Err(
540 "openai compatible response contained missing vectors".to_string()
541 );
542 }
543 }
544
545 self.dimension = vectors.first().map(Vec::len);
546 Ok(vectors)
547 }
548 SemanticEmbeddingEngine::Ollama {
549 client,
550 model,
551 base_url,
552 } => {
553 let expected_text_count = texts.len();
554 let endpoint = build_ollama_embeddings_endpoint(base_url);
555
556 #[derive(Serialize)]
557 struct OllamaPayload<'a> {
558 model: &'a str,
559 input: Vec<String>,
560 }
561
562 let payload = OllamaPayload {
563 model,
564 input: texts,
565 };
566
567 let raw = send_embedding_request(
568 || {
569 client
570 .post(&endpoint)
571 .json(&payload)
572 .header("Content-Type", "application/json")
573 },
574 "ollama",
575 )?;
576
577 #[derive(Deserialize)]
578 struct OllamaResponse {
579 embeddings: Vec<Vec<f32>>,
580 }
581
582 let parsed: OllamaResponse = serde_json::from_str(&raw)
583 .map_err(|error| format!("invalid ollama response: {error}"))?;
584 if parsed.embeddings.is_empty() {
585 return Err("ollama response returned no embeddings".to_string());
586 }
587 if parsed.embeddings.len() != expected_text_count {
588 return Err(format!(
589 "ollama response returned {} embeddings for {} inputs",
590 parsed.embeddings.len(),
591 expected_text_count
592 ));
593 }
594
595 let vectors = parsed.embeddings;
596 for vector in &vectors {
597 if vector.is_empty() {
598 return Err("ollama response contained empty embeddings".to_string());
599 }
600 }
601
602 self.dimension = vectors.first().map(Vec::len);
603 Ok(vectors)
604 }
605 }
606 }
607}
608
609pub fn pre_validate_onnx_runtime() -> Result<(), String> {
613 let dylib_path = std::env::var("ORT_DYLIB_PATH").ok();
614
615 #[cfg(any(target_os = "linux", target_os = "macos"))]
616 {
617 #[cfg(target_os = "linux")]
618 let default_name = "libonnxruntime.so";
619 #[cfg(target_os = "macos")]
620 let default_name = "libonnxruntime.dylib";
621
622 let lib_name = dylib_path.as_deref().unwrap_or(default_name);
623
624 unsafe {
625 let c_name = std::ffi::CString::new(lib_name)
626 .map_err(|e| format!("invalid library path: {}", e))?;
627 let handle = libc::dlopen(c_name.as_ptr(), libc::RTLD_NOW);
628 if handle.is_null() {
629 let err = libc::dlerror();
630 let msg = if err.is_null() {
631 "unknown dlopen error".to_string()
632 } else {
633 std::ffi::CStr::from_ptr(err).to_string_lossy().into_owned()
634 };
635 return Err(format!(
636 "ONNX Runtime not found. dlopen('{}') failed: {}. \
637 Run `bunx @cortexkit/aft-opencode@latest doctor` to diagnose.",
638 lib_name, msg
639 ));
640 }
641
642 let detected_version = detect_ort_version_from_path(lib_name);
645
646 libc::dlclose(handle);
647
648 if let Some(ref version) = detected_version {
650 let parts: Vec<&str> = version.split('.').collect();
651 if let (Some(major), Some(minor)) = (
652 parts.first().and_then(|s| s.parse::<u32>().ok()),
653 parts.get(1).and_then(|s| s.parse::<u32>().ok()),
654 ) {
655 if major != 1 || minor < 20 {
656 return Err(format!(
657 "ONNX Runtime version mismatch: found v{} at '{}', but AFT requires v1.20+. \
658 Solutions:\n\
659 1. Remove the old library and restart (AFT auto-downloads the correct version):\n\
660 {}\n\
661 2. Or install ONNX Runtime 1.24: https://github.com/microsoft/onnxruntime/releases/tag/v1.24.0\n\
662 3. Run `bunx @cortexkit/aft-opencode@latest doctor` for full diagnostics.",
663 version,
664 lib_name,
665 suggest_removal_command(lib_name),
666 ));
667 }
668 }
669 }
670 }
671 }
672
673 #[cfg(target_os = "windows")]
674 {
675 let _ = dylib_path;
677 }
678
679 Ok(())
680}
681
682fn detect_ort_version_from_path(lib_path: &str) -> Option<String> {
685 let path = std::path::Path::new(lib_path);
686
687 for candidate in [Some(path.to_path_buf()), std::fs::canonicalize(path).ok()]
689 .into_iter()
690 .flatten()
691 {
692 if let Some(name) = candidate.file_name().and_then(|n| n.to_str()) {
693 if let Some(version) = extract_version_from_filename(name) {
694 return Some(version);
695 }
696 }
697 }
698
699 if let Some(parent) = path.parent() {
701 if let Ok(entries) = std::fs::read_dir(parent) {
702 for entry in entries.flatten() {
703 if let Some(name) = entry.file_name().to_str() {
704 if name.starts_with("libonnxruntime") {
705 if let Some(version) = extract_version_from_filename(name) {
706 return Some(version);
707 }
708 }
709 }
710 }
711 }
712 }
713
714 None
715}
716
717fn extract_version_from_filename(name: &str) -> Option<String> {
719 let re = regex::Regex::new(r"(\d+\.\d+\.\d+)").ok()?;
721 re.find(name).map(|m| m.as_str().to_string())
722}
723
/// Returns a command (or instruction) the user can run to remove the library
/// at `lib_path`.
///
/// Bare default library names and paths under `/usr/local/lib` get a
/// platform-specific system-wide suggestion; anything else gets a plain `rm`.
fn suggest_removal_command(lib_path: &str) -> String {
    let is_system_install = matches!(lib_path, "libonnxruntime.so" | "libonnxruntime.dylib")
        || lib_path.starts_with("/usr/local/lib");

    if is_system_install {
        #[cfg(target_os = "linux")]
        return " sudo rm /usr/local/lib/libonnxruntime* && sudo ldconfig".to_string();
        #[cfg(target_os = "macos")]
        return " sudo rm /usr/local/lib/libonnxruntime*".to_string();
        #[cfg(target_os = "windows")]
        return " Delete the ONNX Runtime DLL from your PATH".to_string();
    }

    format!(" rm '{}'", lib_path)
}
738
739pub fn initialize_text_embedding(model: &str) -> Result<TextEmbedding, String> {
740 pre_validate_onnx_runtime()?;
742
743 let selected_model = match model {
744 "all-MiniLM-L6-v2" | "all-minilm-l6-v2" => FastembedEmbeddingModel::AllMiniLML6V2,
745 _ => {
746 return Err(format!(
747 "unsupported fastembed model '{}'. Supported: all-MiniLM-L6-v2",
748 model
749 ))
750 }
751 };
752
753 TextEmbedding::try_new(InitOptions::new(selected_model)).map_err(format_embedding_init_error)
754}
755
/// Heuristically decides whether an error message means the ONNX Runtime
/// shared library could not be loaded.
///
/// A message qualifies when it starts (after leading whitespace) with
/// `ONNX Runtime not found.`, or when — compared case-insensitively — it
/// mentions the runtime and also contains a phrase typical of a
/// dynamic-loading failure.
pub fn is_onnx_runtime_unavailable(message: &str) -> bool {
    const RUNTIME_MARKERS: [&str; 3] = ["onnx runtime", "onnxruntime", "libonnxruntime"];
    const LOAD_FAILURE_MARKERS: [&str; 9] = [
        "shared library",
        "dynamic library",
        "failed to load",
        "could not load",
        "unable to load",
        "dlopen",
        "loadlibrary",
        "no such file",
        "not found",
    ];

    if message.trim_start().starts_with("ONNX Runtime not found.") {
        return true;
    }

    let lowered = message.to_ascii_lowercase();
    let contains_any = |markers: &[&str]| markers.iter().any(|marker| lowered.contains(*marker));

    contains_any(&RUNTIME_MARKERS) && contains_any(&LOAD_FAILURE_MARKERS)
}
781
782fn format_embedding_init_error(error: impl Display) -> String {
783 let message = error.to_string();
784
785 if is_onnx_runtime_unavailable(&message) {
786 return format!("{ONNX_RUNTIME_INSTALL_HINT} Original error: {message}");
787 }
788
789 format!("failed to initialize semantic embedding model: {message}")
790}
791
/// One indexable unit of source code together with the text used to embed it.
#[derive(Debug, Clone)]
pub struct SemanticChunk {
    /// File the chunk belongs to.
    pub file: PathBuf,
    /// Name of the chunk (presumably the symbol name — confirm against the chunk builder).
    pub name: String,
    /// Kind of symbol the chunk represents.
    pub kind: SymbolKind,
    /// First line of the chunk.
    pub start_line: u32,
    /// Last line of the chunk.
    pub end_line: u32,
    pub exported: bool,
    /// Text handed to the embedding backend for this chunk.
    pub embed_text: String,
    /// Snippet returned with search results.
    pub snippet: String,
}
811
/// A chunk paired with the embedding vector computed for it.
#[derive(Debug)]
struct EmbeddingEntry {
    chunk: SemanticChunk,
    vector: Vec<f32>,
}
818
/// In-memory collection of embedded chunks that can be searched by similarity.
#[derive(Debug)]
pub struct SemanticIndex {
    entries: Vec<EmbeddingEntry>,
    /// Modification time recorded for each indexed file, used for staleness checks.
    file_mtimes: HashMap<PathBuf, SystemTime>,
    /// Length of the stored vectors; queries of any other length return no results.
    dimension: usize,
    /// Embedding configuration the index was built with, when it has been set.
    fingerprint: Option<SemanticIndexFingerprint>,
}
829
/// One search hit: the matched chunk's metadata plus its similarity score.
#[derive(Debug)]
pub struct SemanticResult {
    pub file: PathBuf,
    pub name: String,
    pub kind: SymbolKind,
    pub start_line: u32,
    pub end_line: u32,
    pub exported: bool,
    pub snippet: String,
    /// Cosine similarity between the query vector and the chunk's vector.
    pub score: f32,
}
842
843impl SemanticIndex {
    /// Creates an empty index with the default dimension and no fingerprint.
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
            file_mtimes: HashMap::new(),
            dimension: DEFAULT_DIMENSION,
            fingerprint: None,
        }
    }
852
    /// Returns the number of embedded entries in the index.
    pub fn entry_count(&self) -> usize {
        self.entries.len()
    }
857
858 pub fn status_label(&self) -> &'static str {
860 if self.entries.is_empty() {
861 "empty"
862 } else {
863 "ready"
864 }
865 }
866
867 fn collect_chunks(
868 project_root: &Path,
869 files: &[PathBuf],
870 ) -> (Vec<SemanticChunk>, HashMap<PathBuf, SystemTime>) {
871 let per_file: Vec<(PathBuf, SystemTime, Vec<SemanticChunk>)> = files
872 .par_iter()
873 .map_init(HashMap::new, |parsers, file| {
874 let mtime = std::fs::metadata(file)
875 .and_then(|m| m.modified())
876 .unwrap_or(SystemTime::UNIX_EPOCH);
877
878 let chunks = collect_file_chunks(project_root, file, parsers).unwrap_or_default();
879
880 (file.clone(), mtime, chunks)
881 })
882 .collect();
883
884 let mut chunks: Vec<SemanticChunk> = Vec::new();
885 let mut file_mtimes: HashMap<PathBuf, SystemTime> = HashMap::new();
886
887 for (file, mtime, file_chunks) in per_file {
888 file_mtimes.insert(file, mtime);
889 chunks.extend(file_chunks);
890 }
891
892 (chunks, file_mtimes)
893 }
894
895 fn build_from_chunks<F, P>(
896 chunks: Vec<SemanticChunk>,
897 file_mtimes: HashMap<PathBuf, SystemTime>,
898 embed_fn: &mut F,
899 max_batch_size: usize,
900 mut progress: Option<&mut P>,
901 ) -> Result<Self, String>
902 where
903 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
904 P: FnMut(usize, usize),
905 {
906 let total_chunks = chunks.len();
907
908 if chunks.is_empty() {
909 return Ok(Self {
910 entries: Vec::new(),
911 file_mtimes,
912 dimension: DEFAULT_DIMENSION,
913 fingerprint: None,
914 });
915 }
916
917 let mut entries: Vec<EmbeddingEntry> = Vec::with_capacity(chunks.len());
919 let mut expected_dimension: Option<usize> = None;
920 let batch_size = max_batch_size.max(1);
921 for batch_start in (0..chunks.len()).step_by(batch_size) {
922 let batch_end = (batch_start + batch_size).min(chunks.len());
923 let batch_texts: Vec<String> = chunks[batch_start..batch_end]
924 .iter()
925 .map(|c| c.embed_text.clone())
926 .collect();
927
928 let vectors = embed_fn(batch_texts)?;
929 validate_embedding_batch(&vectors, batch_end - batch_start, "embedding backend")?;
930
931 if let Some(dim) = vectors.first().map(|v| v.len()) {
933 match expected_dimension {
934 None => expected_dimension = Some(dim),
935 Some(expected) if dim != expected => {
936 return Err(format!(
937 "embedding dimension changed across batches: expected {expected}, got {dim}"
938 ));
939 }
940 _ => {}
941 }
942 }
943
944 for (i, vector) in vectors.into_iter().enumerate() {
945 let chunk_idx = batch_start + i;
946 entries.push(EmbeddingEntry {
947 chunk: chunks[chunk_idx].clone(),
948 vector,
949 });
950 }
951
952 if let Some(callback) = progress.as_mut() {
953 callback(entries.len(), total_chunks);
954 }
955 }
956
957 let dimension = entries
958 .first()
959 .map(|e| e.vector.len())
960 .unwrap_or(DEFAULT_DIMENSION);
961
962 Ok(Self {
963 entries,
964 file_mtimes,
965 dimension,
966 fingerprint: None,
967 })
968 }
969
970 pub fn build<F>(
973 project_root: &Path,
974 files: &[PathBuf],
975 embed_fn: &mut F,
976 max_batch_size: usize,
977 ) -> Result<Self, String>
978 where
979 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
980 {
981 let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
982 Self::build_from_chunks(
983 chunks,
984 file_mtimes,
985 embed_fn,
986 max_batch_size,
987 Option::<&mut fn(usize, usize)>::None,
988 )
989 }
990
991 pub fn build_with_progress<F, P>(
993 project_root: &Path,
994 files: &[PathBuf],
995 embed_fn: &mut F,
996 max_batch_size: usize,
997 progress: &mut P,
998 ) -> Result<Self, String>
999 where
1000 F: FnMut(Vec<String>) -> Result<Vec<Vec<f32>>, String>,
1001 P: FnMut(usize, usize),
1002 {
1003 let (chunks, file_mtimes) = Self::collect_chunks(project_root, files);
1004 let total_chunks = chunks.len();
1005 progress(0, total_chunks);
1006 Self::build_from_chunks(
1007 chunks,
1008 file_mtimes,
1009 embed_fn,
1010 max_batch_size,
1011 Some(progress),
1012 )
1013 }
1014
1015 pub fn search(&self, query_vector: &[f32], top_k: usize) -> Vec<SemanticResult> {
1017 if self.entries.is_empty() || query_vector.len() != self.dimension {
1018 return Vec::new();
1019 }
1020
1021 let mut scored: Vec<(f32, usize)> = self
1022 .entries
1023 .iter()
1024 .enumerate()
1025 .map(|(i, entry)| (cosine_similarity(query_vector, &entry.vector), i))
1026 .collect();
1027
1028 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
1030
1031 scored
1032 .into_iter()
1033 .take(top_k)
1034 .filter(|(score, _)| *score > 0.0)
1035 .map(|(score, idx)| {
1036 let entry = &self.entries[idx];
1037 SemanticResult {
1038 file: entry.chunk.file.clone(),
1039 name: entry.chunk.name.clone(),
1040 kind: entry.chunk.kind.clone(),
1041 start_line: entry.chunk.start_line,
1042 end_line: entry.chunk.end_line,
1043 exported: entry.chunk.exported,
1044 snippet: entry.chunk.snippet.clone(),
1045 score,
1046 }
1047 })
1048 .collect()
1049 }
1050
    /// Returns the number of embedded entries in the index.
    pub fn len(&self) -> usize {
        self.entries.len()
    }
1055
1056 pub fn is_file_stale(&self, file: &Path) -> bool {
1058 match self.file_mtimes.get(file) {
1059 None => true,
1060 Some(stored_mtime) => match fs::metadata(file).and_then(|m| m.modified()) {
1061 Ok(current_mtime) => *stored_mtime != current_mtime,
1062 Err(_) => true,
1063 },
1064 }
1065 }
1066
1067 pub fn count_stale_files(&self) -> usize {
1068 self.file_mtimes
1069 .keys()
1070 .filter(|path| self.is_file_stale(path))
1071 .count()
1072 }
1073
    /// Removes `file` from the index; identical to [`Self::invalidate_file`].
    pub fn remove_file(&mut self, file: &Path) {
        self.invalidate_file(file);
    }

    /// Drops every entry that belongs to `file` and forgets its recorded
    /// modification time.
    pub fn invalidate_file(&mut self, file: &Path) {
        self.entries.retain(|e| e.chunk.file != file);
        self.file_mtimes.remove(file);
    }
1083
    /// Returns the length of the vectors stored in this index.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Returns the fingerprint of the embedding configuration, if one has been set.
    pub fn fingerprint(&self) -> Option<&SemanticIndexFingerprint> {
        self.fingerprint.as_ref()
    }

    /// Returns the backend name from the fingerprint, if one has been set.
    pub fn backend_label(&self) -> Option<&str> {
        self.fingerprint.as_ref().map(|f| f.backend.as_str())
    }

    /// Returns the model name from the fingerprint, if one has been set.
    pub fn model_label(&self) -> Option<&str> {
        self.fingerprint.as_ref().map(|f| f.model.as_str())
    }

    /// Records the embedding configuration this index was built with.
    pub fn set_fingerprint(&mut self, fingerprint: SemanticIndexFingerprint) {
        self.fingerprint = Some(fingerprint);
    }
1104
1105 pub fn write_to_disk(&self, storage_dir: &Path, project_key: &str) {
1107 if self.entries.is_empty() {
1110 slog_info!("skipping semantic index persistence (0 entries)");
1111 return;
1112 }
1113 let dir = storage_dir.join("semantic").join(project_key);
1114 if let Err(e) = fs::create_dir_all(&dir) {
1115 slog_warn!("failed to create semantic cache dir: {}", e);
1116 return;
1117 }
1118 let data_path = dir.join("semantic.bin");
1119 let tmp_path = dir.join("semantic.bin.tmp");
1120 let bytes = self.to_bytes();
1121 if let Err(e) = fs::write(&tmp_path, &bytes) {
1122 slog_warn!("failed to write semantic index: {}", e);
1123 let _ = fs::remove_file(&tmp_path);
1124 return;
1125 }
1126 if let Err(e) = fs::rename(&tmp_path, &data_path) {
1127 slog_warn!("failed to rename semantic index: {}", e);
1128 let _ = fs::remove_file(&tmp_path);
1129 return;
1130 }
1131 slog_info!(
1132 "semantic index persisted: {} entries, {:.1} KB",
1133 self.entries.len(),
1134 bytes.len() as f64 / 1024.0
1135 );
1136 }
1137
1138 pub fn read_from_disk(
1140 storage_dir: &Path,
1141 project_key: &str,
1142 expected_fingerprint: Option<&str>,
1143 ) -> Option<Self> {
1144 let data_path = storage_dir
1145 .join("semantic")
1146 .join(project_key)
1147 .join("semantic.bin");
1148 let file_len = usize::try_from(fs::metadata(&data_path).ok()?.len()).ok()?;
1149 if file_len < HEADER_BYTES_V1 {
1150 slog_warn!(
1151 "corrupt semantic index (too small: {} bytes), removing",
1152 file_len
1153 );
1154 let _ = fs::remove_file(&data_path);
1155 return None;
1156 }
1157
1158 let bytes = fs::read(&data_path).ok()?;
1159 let version = bytes[0];
1160 if version != SEMANTIC_INDEX_VERSION_V4 {
1161 slog_info!(
1162 "cached semantic index version {} is older than {}, rebuilding",
1163 version,
1164 SEMANTIC_INDEX_VERSION_V4
1165 );
1166 let _ = fs::remove_file(&data_path);
1167 return None;
1168 }
1169 match Self::from_bytes(&bytes) {
1170 Ok(index) => {
1171 if index.entries.is_empty() {
1172 slog_info!("cached semantic index is empty, will rebuild");
1173 let _ = fs::remove_file(&data_path);
1174 return None;
1175 }
1176 if let Some(expected) = expected_fingerprint {
1177 let matches = index
1178 .fingerprint()
1179 .map(|fingerprint| fingerprint.matches_expected(expected))
1180 .unwrap_or(false);
1181 if !matches {
1182 slog_info!("cached semantic index fingerprint mismatch, rebuilding");
1183 let _ = fs::remove_file(&data_path);
1184 return None;
1185 }
1186 }
1187 slog_info!(
1188 "loaded semantic index from disk: {} entries",
1189 index.entries.len()
1190 );
1191 Some(index)
1192 }
1193 Err(e) => {
1194 slog_warn!("corrupt semantic index, rebuilding: {}", e);
1195 let _ = fs::remove_file(&data_path);
1196 None
1197 }
1198 }
1199 }
1200
1201 pub fn to_bytes(&self) -> Vec<u8> {
1203 let mut buf = Vec::new();
1204 let fingerprint_bytes = self.fingerprint.as_ref().and_then(|fingerprint| {
1205 let encoded = fingerprint.as_string();
1206 if encoded.is_empty() {
1207 None
1208 } else {
1209 Some(encoded.into_bytes())
1210 }
1211 });
1212
1213 let version = SEMANTIC_INDEX_VERSION_V4;
1226 buf.push(version);
1227 buf.extend_from_slice(&(self.dimension as u32).to_le_bytes());
1228 buf.extend_from_slice(&(self.entries.len() as u32).to_le_bytes());
1229 let fp_bytes_ref: &[u8] = fingerprint_bytes.as_deref().unwrap_or(&[]);
1230 buf.extend_from_slice(&(fp_bytes_ref.len() as u32).to_le_bytes());
1231 buf.extend_from_slice(fp_bytes_ref);
1232
1233 buf.extend_from_slice(&(self.file_mtimes.len() as u32).to_le_bytes());
1236 for (path, mtime) in &self.file_mtimes {
1237 let path_bytes = path.to_string_lossy().as_bytes().to_vec();
1238 buf.extend_from_slice(&(path_bytes.len() as u32).to_le_bytes());
1239 buf.extend_from_slice(&path_bytes);
1240 let duration = mtime
1241 .duration_since(SystemTime::UNIX_EPOCH)
1242 .unwrap_or_default();
1243 buf.extend_from_slice(&duration.as_secs().to_le_bytes());
1244 buf.extend_from_slice(&duration.subsec_nanos().to_le_bytes());
1245 }
1246
1247 for entry in &self.entries {
1249 let c = &entry.chunk;
1250
1251 let file_bytes = c.file.to_string_lossy().as_bytes().to_vec();
1253 buf.extend_from_slice(&(file_bytes.len() as u32).to_le_bytes());
1254 buf.extend_from_slice(&file_bytes);
1255
1256 let name_bytes = c.name.as_bytes();
1258 buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
1259 buf.extend_from_slice(name_bytes);
1260
1261 buf.push(symbol_kind_to_u8(&c.kind));
1263
1264 buf.extend_from_slice(&(c.start_line as u32).to_le_bytes());
1266 buf.extend_from_slice(&(c.end_line as u32).to_le_bytes());
1267 buf.push(c.exported as u8);
1268
1269 let snippet_bytes = c.snippet.as_bytes();
1271 buf.extend_from_slice(&(snippet_bytes.len() as u32).to_le_bytes());
1272 buf.extend_from_slice(snippet_bytes);
1273
1274 let embed_bytes = c.embed_text.as_bytes();
1276 buf.extend_from_slice(&(embed_bytes.len() as u32).to_le_bytes());
1277 buf.extend_from_slice(embed_bytes);
1278
1279 for &val in &entry.vector {
1281 buf.extend_from_slice(&val.to_le_bytes());
1282 }
1283 }
1284
1285 buf
1286 }
1287
1288 pub fn from_bytes(data: &[u8]) -> Result<Self, String> {
1290 let mut pos = 0;
1291
1292 if data.len() < HEADER_BYTES_V1 {
1293 return Err("data too short".to_string());
1294 }
1295
1296 let version = data[pos];
1297 pos += 1;
1298 if version != SEMANTIC_INDEX_VERSION_V1
1299 && version != SEMANTIC_INDEX_VERSION_V2
1300 && version != SEMANTIC_INDEX_VERSION_V3
1301 && version != SEMANTIC_INDEX_VERSION_V4
1302 {
1303 return Err(format!("unsupported version: {}", version));
1304 }
1305 if (version == SEMANTIC_INDEX_VERSION_V2
1309 || version == SEMANTIC_INDEX_VERSION_V3
1310 || version == SEMANTIC_INDEX_VERSION_V4)
1311 && data.len() < HEADER_BYTES_V2
1312 {
1313 return Err("data too short for semantic index v2/v3/v4 header".to_string());
1314 }
1315
1316 let dimension = read_u32(data, &mut pos)? as usize;
1317 let entry_count = read_u32(data, &mut pos)? as usize;
1318 if dimension == 0 || dimension > MAX_DIMENSION {
1319 return Err(format!("invalid embedding dimension: {}", dimension));
1320 }
1321 if entry_count > MAX_ENTRIES {
1322 return Err(format!("too many semantic index entries: {}", entry_count));
1323 }
1324
1325 let has_fingerprint_field = version == SEMANTIC_INDEX_VERSION_V2
1331 || version == SEMANTIC_INDEX_VERSION_V3
1332 || version == SEMANTIC_INDEX_VERSION_V4;
1333 let fingerprint = if has_fingerprint_field {
1334 let fingerprint_len = read_u32(data, &mut pos)? as usize;
1335 if pos + fingerprint_len > data.len() {
1336 return Err("unexpected end of data reading fingerprint".to_string());
1337 }
1338 if fingerprint_len == 0 {
1339 None
1340 } else {
1341 let raw = String::from_utf8_lossy(&data[pos..pos + fingerprint_len]).to_string();
1342 pos += fingerprint_len;
1343 Some(
1344 serde_json::from_str::<SemanticIndexFingerprint>(&raw)
1345 .map_err(|error| format!("invalid semantic fingerprint: {error}"))?,
1346 )
1347 }
1348 } else {
1349 None
1350 };
1351
1352 let mtime_count = read_u32(data, &mut pos)? as usize;
1354 if mtime_count > MAX_ENTRIES {
1355 return Err(format!("too many semantic file mtimes: {}", mtime_count));
1356 }
1357
1358 let vector_bytes = entry_count
1359 .checked_mul(dimension)
1360 .and_then(|count| count.checked_mul(F32_BYTES))
1361 .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
1362 if vector_bytes > data.len().saturating_sub(pos) {
1363 return Err("semantic index vectors exceed available data".to_string());
1364 }
1365
1366 let mut file_mtimes = HashMap::with_capacity(mtime_count);
1367 for _ in 0..mtime_count {
1368 let path = read_string(data, &mut pos)?;
1369 let secs = read_u64(data, &mut pos)?;
1370 let nanos =
1376 if version == SEMANTIC_INDEX_VERSION_V3 || version == SEMANTIC_INDEX_VERSION_V4 {
1377 read_u32(data, &mut pos)?
1378 } else {
1379 0
1380 };
1381 if nanos >= 1_000_000_000 {
1388 return Err(format!(
1389 "invalid semantic mtime: nanos {} >= 1_000_000_000",
1390 nanos
1391 ));
1392 }
1393 let duration = std::time::Duration::new(secs, nanos);
1394 let mtime = SystemTime::UNIX_EPOCH
1395 .checked_add(duration)
1396 .ok_or_else(|| {
1397 format!(
1398 "invalid semantic mtime: secs={} nanos={} overflows SystemTime",
1399 secs, nanos
1400 )
1401 })?;
1402 file_mtimes.insert(PathBuf::from(path), mtime);
1403 }
1404
1405 let mut entries = Vec::with_capacity(entry_count);
1407 for _ in 0..entry_count {
1408 let file = PathBuf::from(read_string(data, &mut pos)?);
1409 let name = read_string(data, &mut pos)?;
1410
1411 if pos >= data.len() {
1412 return Err("unexpected end of data".to_string());
1413 }
1414 let kind = u8_to_symbol_kind(data[pos]);
1415 pos += 1;
1416
1417 let start_line = read_u32(data, &mut pos)?;
1418 let end_line = read_u32(data, &mut pos)?;
1419
1420 if pos >= data.len() {
1421 return Err("unexpected end of data".to_string());
1422 }
1423 let exported = data[pos] != 0;
1424 pos += 1;
1425
1426 let snippet = read_string(data, &mut pos)?;
1427 let embed_text = read_string(data, &mut pos)?;
1428
1429 let vec_bytes = dimension
1431 .checked_mul(F32_BYTES)
1432 .ok_or_else(|| "semantic vector allocation overflow".to_string())?;
1433 if pos + vec_bytes > data.len() {
1434 return Err("unexpected end of data reading vector".to_string());
1435 }
1436 let mut vector = Vec::with_capacity(dimension);
1437 for _ in 0..dimension {
1438 let bytes = [data[pos], data[pos + 1], data[pos + 2], data[pos + 3]];
1439 vector.push(f32::from_le_bytes(bytes));
1440 pos += 4;
1441 }
1442
1443 entries.push(EmbeddingEntry {
1444 chunk: SemanticChunk {
1445 file,
1446 name,
1447 kind,
1448 start_line,
1449 end_line,
1450 exported,
1451 embed_text,
1452 snippet,
1453 },
1454 vector,
1455 });
1456 }
1457
1458 Ok(Self {
1459 entries,
1460 file_mtimes,
1461 dimension,
1462 fingerprint,
1463 })
1464 }
1465}
1466
1467fn build_embed_text(symbol: &Symbol, source: &str, file: &Path, project_root: &Path) -> String {
1469 let relative = file
1470 .strip_prefix(project_root)
1471 .unwrap_or(file)
1472 .to_string_lossy();
1473
1474 let kind_label = match &symbol.kind {
1475 SymbolKind::Function => "function",
1476 SymbolKind::Class => "class",
1477 SymbolKind::Method => "method",
1478 SymbolKind::Struct => "struct",
1479 SymbolKind::Interface => "interface",
1480 SymbolKind::Enum => "enum",
1481 SymbolKind::TypeAlias => "type",
1482 SymbolKind::Variable => "variable",
1483 SymbolKind::Heading => "heading",
1484 };
1485
1486 let mut text = format!("file:{} kind:{} name:{}", relative, kind_label, symbol.name);
1488
1489 if let Some(sig) = &symbol.signature {
1490 text.push_str(&format!(" signature:{}", sig));
1491 }
1492
1493 let lines: Vec<&str> = source.lines().collect();
1495 let start = (symbol.range.start_line as usize).min(lines.len());
1496 let end = (symbol.range.end_line as usize + 1).min(lines.len());
1498 if start < end {
1499 let body: String = lines[start..end]
1500 .iter()
1501 .take(15) .copied()
1503 .collect::<Vec<&str>>()
1504 .join("\n");
1505 let snippet = if body.len() > 300 {
1506 format!("{}...", &body[..body.floor_char_boundary(300)])
1507 } else {
1508 body
1509 };
1510 text.push_str(&format!(" body:{}", snippet));
1511 }
1512
1513 text
1514}
1515
1516fn parser_for(
1517 parsers: &mut HashMap<crate::parser::LangId, Parser>,
1518 lang: crate::parser::LangId,
1519) -> Result<&mut Parser, String> {
1520 use std::collections::hash_map::Entry;
1521
1522 match parsers.entry(lang) {
1523 Entry::Occupied(entry) => Ok(entry.into_mut()),
1524 Entry::Vacant(entry) => {
1525 let grammar = grammar_for(lang);
1526 let mut parser = Parser::new();
1527 parser
1528 .set_language(&grammar)
1529 .map_err(|error| error.to_string())?;
1530 Ok(entry.insert(parser))
1531 }
1532 }
1533}
1534
1535fn collect_file_chunks(
1536 project_root: &Path,
1537 file: &Path,
1538 parsers: &mut HashMap<crate::parser::LangId, Parser>,
1539) -> Result<Vec<SemanticChunk>, String> {
1540 let lang = detect_language(file).ok_or_else(|| "unsupported file extension".to_string())?;
1541 let source = std::fs::read_to_string(file).map_err(|error| error.to_string())?;
1542 let tree = parser_for(parsers, lang)?
1543 .parse(&source, None)
1544 .ok_or_else(|| format!("tree-sitter parse returned None for {}", file.display()))?;
1545 let symbols =
1546 extract_symbols_from_tree(&source, &tree, lang).map_err(|error| error.to_string())?;
1547
1548 Ok(symbols_to_chunks(file, &symbols, &source, project_root))
1549}
1550
1551fn build_snippet(symbol: &Symbol, source: &str) -> String {
1553 let lines: Vec<&str> = source.lines().collect();
1554 let start = (symbol.range.start_line as usize).min(lines.len());
1555 let end = (symbol.range.end_line as usize + 1).min(lines.len());
1557 if start < end {
1558 let snippet_lines: Vec<&str> = lines[start..end].iter().take(5).copied().collect();
1559 let mut snippet = snippet_lines.join("\n");
1560 if end - start > 5 {
1561 snippet.push_str("\n ...");
1562 }
1563 if snippet.len() > 300 {
1564 snippet = format!("{}...", &snippet[..snippet.floor_char_boundary(300)]);
1565 }
1566 snippet
1567 } else {
1568 String::new()
1569 }
1570}
1571
1572fn symbols_to_chunks(
1574 file: &Path,
1575 symbols: &[Symbol],
1576 source: &str,
1577 project_root: &Path,
1578) -> Vec<SemanticChunk> {
1579 let mut chunks = Vec::new();
1580
1581 for symbol in symbols {
1582 if matches!(symbol.kind, SymbolKind::Heading) {
1587 continue;
1588 }
1589
1590 let line_count = symbol
1592 .range
1593 .end_line
1594 .saturating_sub(symbol.range.start_line)
1595 + 1;
1596 if line_count < 2 && !matches!(symbol.kind, SymbolKind::Variable) {
1597 continue;
1598 }
1599
1600 let embed_text = build_embed_text(symbol, source, file, project_root);
1601 let snippet = build_snippet(symbol, source);
1602
1603 chunks.push(SemanticChunk {
1604 file: file.to_path_buf(),
1605 name: symbol.name.clone(),
1606 kind: symbol.kind.clone(),
1607 start_line: symbol.range.start_line,
1608 end_line: symbol.range.end_line,
1609 exported: symbol.exported,
1610 embed_text,
1611 snippet,
1612 });
1613
1614 }
1617
1618 chunks
1619}
1620
/// Computes the cosine similarity of two vectors.
///
/// Returns `0.0` when the slices differ in length or when either vector has
/// zero magnitude (which includes two empty slices).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // One pass accumulating the dot product and both squared norms, in index order.
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, norm_a, norm_b), (&x, &y)| (dot + x * y, norm_a + x * x, norm_b + y * y),
    );

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }
    dot / denom
}
1644
1645fn symbol_kind_to_u8(kind: &SymbolKind) -> u8 {
1647 match kind {
1648 SymbolKind::Function => 0,
1649 SymbolKind::Class => 1,
1650 SymbolKind::Method => 2,
1651 SymbolKind::Struct => 3,
1652 SymbolKind::Interface => 4,
1653 SymbolKind::Enum => 5,
1654 SymbolKind::TypeAlias => 6,
1655 SymbolKind::Variable => 7,
1656 SymbolKind::Heading => 8,
1657 }
1658}
1659
1660fn u8_to_symbol_kind(v: u8) -> SymbolKind {
1661 match v {
1662 0 => SymbolKind::Function,
1663 1 => SymbolKind::Class,
1664 2 => SymbolKind::Method,
1665 3 => SymbolKind::Struct,
1666 4 => SymbolKind::Interface,
1667 5 => SymbolKind::Enum,
1668 6 => SymbolKind::TypeAlias,
1669 7 => SymbolKind::Variable,
1670 _ => SymbolKind::Heading,
1671 }
1672}
1673
/// Reads a little-endian `u32` at `*pos` and advances `*pos` by four bytes.
///
/// # Errors
///
/// Returns a message, leaving `*pos` untouched, when fewer than four bytes
/// remain in `data`.
fn read_u32(data: &[u8], pos: &mut usize) -> Result<u32, String> {
    let start = *pos;
    let end = start + 4;
    let window = data
        .get(start..end)
        .ok_or_else(|| "unexpected end of data reading u32".to_string())?;

    let mut raw = [0u8; 4];
    raw.copy_from_slice(window);
    *pos = end;
    Ok(u32::from_le_bytes(raw))
}
1682
/// Reads a little-endian `u64` at `*pos` and advances `*pos` by eight bytes.
///
/// # Errors
///
/// Returns a message, leaving `*pos` untouched, when fewer than eight bytes
/// remain in `data`.
fn read_u64(data: &[u8], pos: &mut usize) -> Result<u64, String> {
    let start = *pos;
    let end = start + 8;
    let window = match data.get(start..end) {
        Some(window) => window,
        None => return Err("unexpected end of data reading u64".to_string()),
    };

    // `window` is exactly eight bytes long, so the conversion cannot fail.
    let raw: [u8; 8] = window.try_into().unwrap();
    *pos = end;
    Ok(u64::from_le_bytes(raw))
}
1691
1692fn read_string(data: &[u8], pos: &mut usize) -> Result<String, String> {
1693 let len = read_u32(data, pos)? as usize;
1694 if *pos + len > data.len() {
1695 return Err("unexpected end of data reading string".to_string());
1696 }
1697 let s = String::from_utf8_lossy(&data[*pos..*pos + len]).to_string();
1698 *pos += len;
1699 Ok(s)
1700}
1701
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{SemanticBackend, SemanticBackendConfig};
    use crate::parser::FileParser;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;

    /// Starts a one-shot HTTP server on an ephemeral loopback port.
    ///
    /// The server accepts a single connection, reads one request, calls
    /// `handler(request_line, path, body)` and replies `200 OK` with the
    /// returned string as a JSON body. Returns the base URL and the handle of
    /// the serving thread; joining it surfaces assertion failures raised
    /// inside `handler`.
    fn start_mock_http_server<F>(handler: F) -> (String, thread::JoinHandle<()>)
    where
        F: Fn(String, String, String) -> String + Send + 'static,
    {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
        let addr = listener.local_addr().expect("local addr");
        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept request");
            let mut buf = Vec::new();
            let mut chunk = [0u8; 4096];
            let mut header_end = None;
            let mut content_length = 0usize;
            loop {
                let n = stream.read(&mut chunk).expect("read request");
                if n == 0 {
                    break;
                }
                buf.extend_from_slice(&chunk[..n]);
                if header_end.is_none() {
                    if let Some(pos) = buf.windows(4).position(|window| window == b"\r\n\r\n") {
                        header_end = Some(pos + 4);
                        let headers = String::from_utf8_lossy(&buf[..pos + 4]);
                        for line in headers.lines() {
                            // HTTP field names are case-insensitive, so a client may
                            // send `content-length` in lowercase.
                            if let Some((name, value)) = line.split_once(':') {
                                if name.trim().eq_ignore_ascii_case("content-length") {
                                    content_length = value.trim().parse::<usize>().unwrap_or(0);
                                }
                            }
                        }
                    }
                }
                if let Some(end) = header_end {
                    if buf.len() >= end + content_length {
                        break;
                    }
                }
            }

            let end = header_end.expect("header terminator");
            let request = String::from_utf8_lossy(&buf[..end]).to_string();
            // Clamp so a connection closed before the full body arrived cannot
            // cause an out-of-range slice.
            let body_end = (end + content_length).min(buf.len());
            let body = String::from_utf8_lossy(&buf[end..body_end]).to_string();
            let mut lines = request.lines();
            let request_line = lines.next().expect("request line").to_string();
            let path = request_line
                .split_whitespace()
                .nth(1)
                .expect("request path")
                .to_string();
            let response_body = handler(request_line, path, body);
            let response = format!(
                "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
                response_body.len(),
                response_body
            );
            stream
                .write_all(response.as_bytes())
                .expect("write response");
        });

        (format!("http://{}", addr), handle)
    }

    /// Identical vectors have similarity 1.
    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
    }

    /// Orthogonal vectors have similarity 0.
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &b).abs() < 0.001);
    }

    /// Opposite vectors have similarity -1.
    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 0.001);
    }

    /// `to_bytes` followed by `from_bytes` preserves entries, dimension and fingerprint.
    #[test]
    fn test_serialization_roundtrip() {
        let mut index = SemanticIndex::new();
        index.entries.push(EmbeddingEntry {
            chunk: SemanticChunk {
                file: PathBuf::from("/src/main.rs"),
                name: "handle_request".to_string(),
                kind: SymbolKind::Function,
                start_line: 10,
                end_line: 25,
                exported: true,
                embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
                snippet: "fn handle_request() {\n // ...\n}".to_string(),
            },
            vector: vec![0.1, 0.2, 0.3, 0.4],
        });
        index.dimension = 4;
        index
            .file_mtimes
            .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
        index.set_fingerprint(SemanticIndexFingerprint {
            backend: "fastembed".to_string(),
            model: "all-MiniLM-L6-v2".to_string(),
            base_url: FALLBACK_BACKEND.to_string(),
            dimension: 4,
        });

        let bytes = index.to_bytes();
        let restored = SemanticIndex::from_bytes(&bytes).unwrap();

        assert_eq!(restored.entries.len(), 1);
        assert_eq!(restored.entries[0].chunk.name, "handle_request");
        assert_eq!(restored.entries[0].vector, vec![0.1, 0.2, 0.3, 0.4]);
        assert_eq!(restored.dimension, 4);
        assert_eq!(restored.backend_label(), Some("fastembed"));
        assert_eq!(restored.model_label(), Some("all-MiniLM-L6-v2"));
    }

    /// Searching returns the requested number of hits, best match first.
    #[test]
    fn test_search_top_k() {
        let mut index = SemanticIndex::new();
        index.dimension = 3;

        // One unit vector per entry, each along a different axis.
        for (i, name) in ["auth", "database", "handler"].iter().enumerate() {
            let mut vec = vec![0.0f32; 3];
            vec[i] = 1.0;
            index.entries.push(EmbeddingEntry {
                chunk: SemanticChunk {
                    file: PathBuf::from("/src/lib.rs"),
                    name: name.to_string(),
                    kind: SymbolKind::Function,
                    start_line: (i * 10 + 1) as u32,
                    end_line: (i * 10 + 5) as u32,
                    exported: true,
                    embed_text: format!("kind:function name:{}", name),
                    snippet: format!("fn {}() {{}}", name),
                },
                vector: vec,
            });
        }

        // The query points mostly along the first axis, i.e. towards "auth".
        let query = vec![0.9, 0.1, 0.0];
        let results = index.search(&query, 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].name, "auth");
        assert!(results[0].score > results[1].score);
    }

    /// Searching an empty index yields no results.
    #[test]
    fn test_empty_index_search() {
        let index = SemanticIndex::new();
        let results = index.search(&[0.1, 0.2, 0.3], 10);
        assert!(results.is_empty());
    }

    /// A symbol confined to a single line still produces that line as its snippet.
    #[test]
    fn single_line_symbol_builds_non_empty_snippet() {
        let symbol = Symbol {
            name: "answer".to_string(),
            kind: SymbolKind::Variable,
            range: crate::symbols::Range {
                start_line: 0,
                start_col: 0,
                end_line: 0,
                end_col: 24,
            },
            signature: Some("const answer = 42".to_string()),
            scope_chain: Vec::new(),
            exported: true,
            parent: None,
        };
        let source = "export const answer = 42;\n";

        let snippet = build_snippet(&symbol, source);

        assert_eq!(snippet, "export const answer = 42;");
    }

    /// `collect_file_chunks` yields the same chunks as going through `FileParser`.
    #[test]
    fn optimized_file_chunk_collection_matches_file_parser_path() {
        let project_root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let file = project_root.join("src/semantic_index.rs");
        let source = std::fs::read_to_string(&file).unwrap();

        let mut legacy_parser = FileParser::new();
        let legacy_symbols = legacy_parser.extract_symbols(&file).unwrap();
        let legacy_chunks = symbols_to_chunks(&file, &legacy_symbols, &source, &project_root);

        let mut parsers = HashMap::new();
        let optimized_chunks = collect_file_chunks(&project_root, &file, &mut parsers).unwrap();

        assert_eq!(
            chunk_fingerprint(&optimized_chunks),
            chunk_fingerprint(&legacy_chunks)
        );
    }

    /// Projects chunks onto comparable tuples (every field except the file path).
    fn chunk_fingerprint(
        chunks: &[SemanticChunk],
    ) -> Vec<(String, SymbolKind, u32, u32, bool, String, String)> {
        chunks
            .iter()
            .map(|chunk| {
                (
                    chunk.name.clone(),
                    chunk.kind.clone(),
                    chunk.start_line,
                    chunk.end_line,
                    chunk.exported,
                    chunk.embed_text.clone(),
                    chunk.snippet.clone(),
                )
            })
            .collect()
    }

    /// A header declaring a dimension above `MAX_DIMENSION` is rejected.
    #[test]
    fn rejects_oversized_dimension_during_deserialization() {
        let mut bytes = Vec::new();
        bytes.push(1u8);
        bytes.extend_from_slice(&((MAX_DIMENSION as u32) + 1).to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());

        assert!(SemanticIndex::from_bytes(&bytes).is_err());
    }

    /// A header declaring more than `MAX_ENTRIES` entries is rejected.
    #[test]
    fn rejects_oversized_entry_count_during_deserialization() {
        let mut bytes = Vec::new();
        bytes.push(1u8);
        bytes.extend_from_slice(&(DEFAULT_DIMENSION as u32).to_le_bytes());
        bytes.extend_from_slice(&((MAX_ENTRIES as u32) + 1).to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());

        assert!(SemanticIndex::from_bytes(&bytes).is_err());
    }

    /// Invalidating a file drops both its entries and its recorded mtime.
    #[test]
    fn invalidate_file_removes_entries_and_mtime() {
        let target = PathBuf::from("/src/main.rs");
        let mut index = SemanticIndex::new();
        index.entries.push(EmbeddingEntry {
            chunk: SemanticChunk {
                file: target.clone(),
                name: "main".to_string(),
                kind: SymbolKind::Function,
                start_line: 0,
                end_line: 1,
                exported: false,
                embed_text: "main".to_string(),
                snippet: "fn main() {}".to_string(),
            },
            vector: vec![1.0; DEFAULT_DIMENSION],
        });
        index
            .file_mtimes
            .insert(target.clone(), SystemTime::UNIX_EPOCH);

        index.invalidate_file(&target);

        assert!(index.entries.is_empty());
        assert!(!index.file_mtimes.contains_key(&target));
    }

    /// A dynamic-loader failure message is recognized as "ONNX Runtime unavailable".
    #[test]
    fn detects_missing_onnx_runtime_from_dynamic_load_error() {
        let message = "Failed to load ONNX Runtime shared library libonnxruntime.dylib via dlopen: no such file";

        assert!(is_onnx_runtime_unavailable(message));
    }

    /// The formatted init error leads with the install hint and keeps the original error.
    #[test]
    fn formats_missing_onnx_runtime_with_install_hint() {
        let message = format_embedding_init_error(
            "Failed to load ONNX Runtime shared library libonnxruntime.so via dlopen: no such file",
        );

        assert!(message.starts_with("ONNX Runtime not found. Install via:"));
        assert!(message.contains("Original error:"));
    }

    /// The OpenAI-compatible backend POSTs to `/v1/embeddings` and decodes the vectors.
    #[test]
    fn openai_compatible_backend_embeds_with_mock_server() {
        let (base_url, handle) = start_mock_http_server(|request_line, path, _body| {
            assert!(request_line.starts_with("POST "));
            assert_eq!(path, "/v1/embeddings");
            "{\"data\":[{\"embedding\":[0.1,0.2,0.3],\"index\":0},{\"embedding\":[0.4,0.5,0.6],\"index\":1}]}".to_string()
        });

        let config = SemanticBackendConfig {
            backend: SemanticBackend::OpenAiCompatible,
            model: "test-embedding".to_string(),
            base_url: Some(base_url),
            api_key_env: None,
            timeout_ms: 5_000,
            max_batch_size: 64,
        };

        let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();
        let vectors = model
            .embed(vec!["hello".to_string(), "world".to_string()])
            .unwrap();

        assert_eq!(vectors, vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]]);
        handle.join().unwrap();
    }

    /// The Ollama backend POSTs to `/api/embed` and decodes the vectors.
    #[test]
    fn ollama_backend_embeds_with_mock_server() {
        let (base_url, handle) = start_mock_http_server(|request_line, path, _body| {
            assert!(request_line.starts_with("POST "));
            assert_eq!(path, "/api/embed");
            "{\"embeddings\":[[0.7,0.8,0.9],[1.0,1.1,1.2]]}".to_string()
        });

        let config = SemanticBackendConfig {
            backend: SemanticBackend::Ollama,
            model: "embeddinggemma".to_string(),
            base_url: Some(base_url),
            api_key_env: None,
            timeout_ms: 5_000,
            max_batch_size: 64,
        };

        let mut model = SemanticEmbeddingModel::from_config(&config).unwrap();
        let vectors = model
            .embed(vec!["hello".to_string(), "world".to_string()])
            .unwrap();

        assert_eq!(vectors, vec![vec![0.7, 0.8, 0.9], vec![1.0, 1.1, 1.2]]);
        handle.join().unwrap();
    }

    /// A cached index loads with a matching fingerprint and is refused with a different one.
    #[test]
    fn read_from_disk_rejects_fingerprint_mismatch() {
        let storage = tempfile::tempdir().unwrap();
        let project_key = "proj";

        let mut index = SemanticIndex::new();
        index.entries.push(EmbeddingEntry {
            chunk: SemanticChunk {
                file: PathBuf::from("/src/main.rs"),
                name: "handle_request".to_string(),
                kind: SymbolKind::Function,
                start_line: 10,
                end_line: 25,
                exported: true,
                embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
                snippet: "fn handle_request() {}".to_string(),
            },
            vector: vec![0.1, 0.2, 0.3],
        });
        index.dimension = 3;
        index
            .file_mtimes
            .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
        index.set_fingerprint(SemanticIndexFingerprint {
            backend: "openai_compatible".to_string(),
            model: "test-embedding".to_string(),
            base_url: "http://127.0.0.1:1234/v1".to_string(),
            dimension: 3,
        });
        index.write_to_disk(storage.path(), project_key);

        let matching = index.fingerprint().unwrap().as_string();
        assert!(
            SemanticIndex::read_from_disk(storage.path(), project_key, Some(&matching)).is_some()
        );

        let mismatched = SemanticIndexFingerprint {
            backend: "ollama".to_string(),
            model: "embeddinggemma".to_string(),
            base_url: "http://127.0.0.1:11434".to_string(),
            dimension: 3,
        }
        .as_string();
        assert!(
            SemanticIndex::read_from_disk(storage.path(), project_key, Some(&mismatched)).is_none()
        );
    }

    /// A cache file tagged as version 3 is refused and deleted even if its fingerprint matches.
    #[test]
    fn read_from_disk_rejects_v3_cache_for_snippet_rebuild() {
        let storage = tempfile::tempdir().unwrap();
        let project_key = "proj-v3";
        let dir = storage.path().join("semantic").join(project_key);
        fs::create_dir_all(&dir).unwrap();

        let mut index = SemanticIndex::new();
        index.entries.push(EmbeddingEntry {
            chunk: SemanticChunk {
                file: PathBuf::from("/src/main.rs"),
                name: "handle_request".to_string(),
                kind: SymbolKind::Function,
                start_line: 0,
                end_line: 0,
                exported: true,
                embed_text: "file:src/main.rs kind:function name:handle_request".to_string(),
                snippet: "fn handle_request() {}".to_string(),
            },
            vector: vec![0.1, 0.2, 0.3],
        });
        index.dimension = 3;
        index
            .file_mtimes
            .insert(PathBuf::from("/src/main.rs"), SystemTime::UNIX_EPOCH);
        let fingerprint = SemanticIndexFingerprint {
            backend: "fastembed".to_string(),
            model: "test".to_string(),
            base_url: FALLBACK_BACKEND.to_string(),
            dimension: 3,
        };
        index.set_fingerprint(fingerprint.clone());

        // Serialize with the current writer, then relabel the payload as version 3.
        let mut bytes = index.to_bytes();
        bytes[0] = SEMANTIC_INDEX_VERSION_V3;
        fs::write(dir.join("semantic.bin"), bytes).unwrap();

        assert!(SemanticIndex::read_from_disk(
            storage.path(),
            project_key,
            Some(&fingerprint.as_string())
        )
        .is_none());
        assert!(!dir.join("semantic.bin").exists());
    }

    /// Builds a non-exported symbol without signature covering lines `start..=end`.
    fn make_symbol(kind: SymbolKind, name: &str, start: u32, end: u32) -> crate::symbols::Symbol {
        crate::symbols::Symbol {
            name: name.to_string(),
            kind,
            range: crate::symbols::Range {
                start_line: start,
                start_col: 0,
                end_line: end,
                end_col: 0,
            },
            signature: None,
            scope_chain: Vec::new(),
            exported: false,
            parent: None,
        }
    }

    /// Heading symbols never become chunks.
    #[test]
    fn symbols_to_chunks_skips_heading_symbols() {
        let project_root = PathBuf::from("/proj");
        let file = project_root.join("README.md");
        let source = "# Title\n\nbody text\n\n## Section\n\nmore text\n";

        let symbols = vec![
            make_symbol(SymbolKind::Heading, "Title", 0, 2),
            make_symbol(SymbolKind::Heading, "Section", 4, 6),
        ];

        let chunks = symbols_to_chunks(&file, &symbols, source, &project_root);
        assert!(
            chunks.is_empty(),
            "Heading symbols must be filtered out before embedding; got {} chunk(s)",
            chunks.len()
        );
    }

    /// Code symbols are kept when they appear next to filtered heading symbols.
    #[test]
    fn symbols_to_chunks_keeps_code_symbols_alongside_skipped_headings() {
        let project_root = PathBuf::from("/proj");
        let file = project_root.join("src/lib.rs");
        let source = "pub fn handle_request() -> bool {\n true\n}\n";

        let symbols = vec![
            make_symbol(SymbolKind::Heading, "doc heading", 0, 1),
            make_symbol(SymbolKind::Function, "handle_request", 0, 2),
            make_symbol(SymbolKind::Struct, "AuthService", 4, 6),
        ];

        let chunks = symbols_to_chunks(&file, &symbols, source, &project_root);
        assert_eq!(
            chunks.len(),
            2,
            "Expected 2 code chunks (Function + Struct), got {}",
            chunks.len()
        );
        let names: Vec<&str> = chunks.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"handle_request"));
        assert!(names.contains(&"AuthService"));
        assert!(
            !names.contains(&"doc heading"),
            "Heading symbol leaked into chunks: {names:?}"
        );
    }

    /// Loopback host names are rejected by the SSRF validation.
    #[test]
    fn validate_ssrf_rejects_loopback_hostnames() {
        for host in &[
            "http://localhost",
            "http://localhost:8080",
            "http://localhost.localdomain",
            "http://foo.localhost",
        ] {
            assert!(
                validate_base_url_no_ssrf(host).is_err(),
                "Expected {host} to be rejected"
            );
        }
    }

    /// Private, loopback and link-local IP addresses are rejected by the SSRF validation.
    #[test]
    fn validate_ssrf_rejects_private_ips() {
        for url in &[
            "http://192.168.1.1",
            "http://10.0.0.1",
            "http://172.16.0.1",
            "http://127.0.0.1",
            "http://169.254.169.254",
        ] {
            let result = validate_base_url_no_ssrf(url);
            assert!(
                result.is_err(),
                "Expected {url} to be rejected, got: {:?}",
                result
            );
        }
    }

    /// `normalize_base_url` accepts loopback URLs, which the mock-server tests rely on.
    #[test]
    fn normalize_base_url_allows_localhost_for_tests() {
        assert!(normalize_base_url("http://127.0.0.1:9999").is_ok());
        assert!(normalize_base_url("http://localhost:8080").is_ok());
    }
}