1use crate::db::IndexDb;
5use crate::embedding_store::{EmbeddingChunk, EmbeddingStore, ScoredChunk};
6use crate::project::ProjectRoot;
7use anyhow::{Context, Result};
8use fastembed::{
9 ExecutionProviderDispatch, InitOptionsUserDefined, TextEmbedding, TokenizerFiles,
10 UserDefinedEmbeddingModel,
11};
12use rusqlite::Connection;
13use serde::Serialize;
14use std::collections::{HashMap, HashSet, VecDeque};
15use std::sync::{Arc, Mutex, Once};
16use std::thread::available_parallelism;
17use tracing::debug;
18
19pub(super) mod ffi {
21 use anyhow::Result;
22
23 pub fn register_sqlite_vec() -> Result<()> {
24 let rc = unsafe {
25 rusqlite::ffi::sqlite3_auto_extension(Some(std::mem::transmute::<
26 *const (),
27 unsafe extern "C" fn(
28 *mut rusqlite::ffi::sqlite3,
29 *mut *mut i8,
30 *const rusqlite::ffi::sqlite3_api_routines,
31 ) -> i32,
32 >(
33 sqlite_vec::sqlite3_vec_init as *const ()
34 )))
35 };
36 if rc != rusqlite::ffi::SQLITE_OK {
37 anyhow::bail!("failed to register sqlite-vec extension (SQLite error code: {rc})");
38 }
39 Ok(())
40 }
41
42 #[cfg(target_os = "macos")]
43 pub fn sysctl_usize(name: &[u8]) -> Option<usize> {
44 let mut value: libc::c_uint = 0;
45 let mut size = std::mem::size_of::<libc::c_uint>();
46 let rc = unsafe {
47 libc::sysctlbyname(
48 name.as_ptr().cast(),
49 (&mut value as *mut libc::c_uint).cast(),
50 &mut size,
51 std::ptr::null_mut(),
52 0,
53 )
54 };
55 (rc == 0 && size == std::mem::size_of::<libc::c_uint>()).then_some(value as usize)
56 }
57}
58
59#[derive(Debug, Clone, Serialize)]
61pub struct SemanticMatch {
62 pub file_path: String,
63 pub symbol_name: String,
64 pub kind: String,
65 pub line: usize,
66 pub signature: String,
67 pub name_path: String,
68 pub score: f64,
69}
70
71impl From<ScoredChunk> for SemanticMatch {
72 fn from(c: ScoredChunk) -> Self {
73 Self {
74 file_path: c.file_path,
75 symbol_name: c.symbol_name,
76 kind: c.kind,
77 line: c.line,
78 signature: c.signature,
79 name_path: c.name_path,
80 score: c.score,
81 }
82 }
83}
84
85mod vec_store;
86use vec_store::SqliteVecStore;
87
/// Identity of an embedding that may be reused without re-running the model:
/// `(file_path, symbol_name, kind, signature, name_path, text)`.
type ReusableEmbeddingKey = (String, String, String, String, String, String);

/// Builds an owned [`ReusableEmbeddingKey`] from borrowed parts.
///
/// The components are stored in argument order, so two keys compare equal
/// only when all six strings match exactly.
fn reusable_embedding_key(
    file_path: &str,
    symbol_name: &str,
    kind: &str,
    signature: &str,
    name_path: &str,
    text: &str,
) -> ReusableEmbeddingKey {
    let [file_path, symbol_name, kind, signature, name_path, text] =
        [file_path, symbol_name, kind, signature, name_path, text].map(String::from);
    (file_path, symbol_name, kind, signature, name_path, text)
}
107
108fn reusable_embedding_key_for_chunk(chunk: &EmbeddingChunk) -> ReusableEmbeddingKey {
109 reusable_embedding_key(
110 &chunk.file_path,
111 &chunk.symbol_name,
112 &chunk.kind,
113 &chunk.signature,
114 &chunk.name_path,
115 &chunk.text,
116 )
117}
118
119fn reusable_embedding_key_for_symbol(
120 sym: &crate::db::SymbolWithFile,
121 text: &str,
122) -> ReusableEmbeddingKey {
123 reusable_embedding_key(
124 &sym.file_path,
125 &sym.name,
126 &sym.kind,
127 &sym.signature,
128 &sym.name_path,
129 text,
130 )
131}
132
/// Batch size used by `embed_batch_size` on non-macOS targets when
/// `CODELENS_EMBED_BATCH_SIZE` is not set.
const DEFAULT_EMBED_BATCH_SIZE: usize = 128;
/// macOS counterpart of [`DEFAULT_EMBED_BATCH_SIZE`].
const DEFAULT_MACOS_EMBED_BATCH_SIZE: usize = 128;
/// Query-text cache capacity on non-macOS targets when
/// `CODELENS_EMBED_TEXT_CACHE_SIZE` is not set.
const DEFAULT_TEXT_EMBED_CACHE_SIZE: usize = 256;
/// macOS counterpart of [`DEFAULT_TEXT_EMBED_CACHE_SIZE`].
const DEFAULT_MACOS_TEXT_EMBED_CACHE_SIZE: usize = 1024;
/// Vector dimension reported by `load_codesearch_model` and handed to the
/// vector store.
const CODESEARCH_DIMENSION: usize = 384;
/// Symbol cap used by `max_embed_symbols` when `CODELENS_MAX_EMBED_SYMBOLS`
/// is not set.
const DEFAULT_MAX_EMBED_SYMBOLS: usize = 50_000;
/// Number of file paths sent per store/database query in
/// `index_changed_files`.
const CHANGED_FILE_QUERY_CHUNK: usize = 128;
/// NOTE(review): not referenced in the part of the file reviewed here;
/// presumably the batch size of a duplicate scan — confirm at the use site.
const DEFAULT_DUPLICATE_SCAN_BATCH_SIZE: usize = 128;
/// Ensures the ONNX Runtime environment is initialised at most once per
/// process (see `configure_embedding_runtime`).
static ORT_ENV_INIT: Once = Once::new();

/// Model name reported for the bundled CodeSearchNet embedding model.
const CODESEARCH_MODEL_NAME: &str = "MiniLM-L12-CodeSearchNet-INT8";
148
149pub struct EmbeddingEngine {
150 model: Mutex<TextEmbedding>,
151 store: Box<dyn EmbeddingStore>,
152 model_name: String,
153 runtime_info: EmbeddingRuntimeInfo,
154 text_embed_cache: Mutex<TextEmbeddingCache>,
155 indexing: std::sync::atomic::AtomicBool,
156}
157
158#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
159pub struct EmbeddingIndexInfo {
160 pub model_name: String,
161 pub indexed_symbols: usize,
162}
163
164#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
165pub struct EmbeddingRuntimeInfo {
166 pub runtime_preference: String,
167 pub backend: String,
168 pub threads: usize,
169 pub max_length: usize,
170 pub coreml_model_format: Option<String>,
171 pub coreml_compute_units: Option<String>,
172 pub coreml_static_input_shapes: Option<bool>,
173 pub coreml_profile_compute_plan: Option<bool>,
174 pub coreml_specialization_strategy: Option<String>,
175 pub coreml_model_cache_dir: Option<String>,
176 pub fallback_reason: Option<String>,
177}
178
/// Bounded least-recently-used cache from input text to its embedding.
///
/// `order` lists keys from least to most recently used; `entries` holds the
/// vectors. A capacity of 0 disables caching.
struct TextEmbeddingCache {
    capacity: usize,
    order: VecDeque<String>,
    entries: HashMap<String, Vec<f32>>,
}

impl TextEmbeddingCache {
    /// Creates an empty cache holding at most `capacity` entries.
    fn new(capacity: usize) -> Self {
        let order = VecDeque::new();
        let entries = HashMap::new();
        Self {
            capacity,
            order,
            entries,
        }
    }

    /// Returns a copy of the cached embedding for `key` and marks the entry
    /// as most recently used; `None` on a miss.
    fn get(&mut self, key: &str) -> Option<Vec<f32>> {
        let hit = self.entries.get(key).cloned()?;
        self.touch(key);
        Some(hit)
    }

    /// Stores `value` under `key` (replacing any previous value), marks it
    /// most recently used and evicts the oldest entries beyond capacity.
    fn insert(&mut self, key: String, value: Vec<f32>) {
        if self.capacity == 0 {
            return;
        }

        // Recency bookkeeping only needs a borrow, so do it before the key
        // is moved into the map.
        self.touch(&key);
        self.entries.insert(key, value);

        while self.entries.len() > self.capacity {
            let Some(evicted) = self.order.pop_front() else {
                break;
            };
            self.entries.remove(&evicted);
        }
    }

    /// Moves `key` to the most-recently-used end of `order`, adding it when
    /// it is not tracked yet.
    fn touch(&mut self, key: &str) {
        let stale = self.order.iter().position(|queued| queued == key);
        if let Some(index) = stale {
            self.order.remove(index);
        }
        self.order.push_back(key.to_owned());
    }
}
224
225fn resolve_model_dir() -> Result<std::path::PathBuf> {
233 if let Ok(dir) = std::env::var("CODELENS_MODEL_DIR") {
235 let p = std::path::PathBuf::from(dir).join("codesearch");
236 if p.join("model.onnx").exists() {
237 return Ok(p);
238 }
239 }
240
241 if let Ok(exe) = std::env::current_exe()
243 && let Some(exe_dir) = exe.parent()
244 {
245 let p = exe_dir.join("models").join("codesearch");
246 if p.join("model.onnx").exists() {
247 return Ok(p);
248 }
249 }
250
251 if let Some(home) = dirs_fallback() {
253 let p = home
254 .join(".cache")
255 .join("codelens")
256 .join("models")
257 .join("codesearch");
258 if p.join("model.onnx").exists() {
259 return Ok(p);
260 }
261 }
262
263 let dev_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
265 .join("models")
266 .join("codesearch");
267 if dev_path.join("model.onnx").exists() {
268 return Ok(dev_path);
269 }
270
271 anyhow::bail!(
272 "CodeSearchNet model not found. Place model files in one of:\n\
273 - $CODELENS_MODEL_DIR/codesearch/\n\
274 - <executable>/models/codesearch/\n\
275 - ~/.cache/codelens/models/codesearch/\n\
276 Required files: model.onnx, tokenizer.json, config.json, special_tokens_map.json, tokenizer_config.json"
277 )
278}
279
/// Returns the value of `$HOME` as a path, or `None` when it is not set.
fn dirs_fallback() -> Option<std::path::PathBuf> {
    let home = std::env::var_os("HOME")?;
    Some(std::path::PathBuf::from(home))
}
283
/// Reads environment variable `name` as a strictly positive `usize`.
///
/// Surrounding whitespace is ignored. Returns `None` when the variable is
/// unset, not valid Unicode, not a number, or zero.
fn parse_usize_env(name: &str) -> Option<usize> {
    let raw = std::env::var(name).ok()?;
    match raw.trim().parse::<usize>() {
        Ok(value) if value > 0 => Some(value),
        _ => None,
    }
}
290
/// Reads environment variable `name` as a boolean switch.
///
/// Matching ignores ASCII case and surrounding whitespace:
/// `1`/`true`/`yes`/`on` give `Some(true)`, `0`/`false`/`no`/`off` give
/// `Some(false)`. Anything else, or an unset variable, gives `None`.
fn parse_bool_env(name: &str) -> Option<bool> {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    const FALSY: [&str; 4] = ["0", "false", "no", "off"];

    let raw = std::env::var(name).ok()?;
    let token = raw.trim().to_ascii_lowercase();
    if TRUTHY.contains(&token.as_str()) {
        Some(true)
    } else if FALSY.contains(&token.as_str()) {
        Some(false)
    } else {
        None
    }
}
301
/// Number of CPU cores to prefer on macOS.
///
/// Queries `hw.perflevel0.physicalcpu` first and falls back to
/// `hw.physicalcpu`; a value of zero is treated as unavailable.
#[cfg(target_os = "macos")]
fn apple_perf_cores() -> Option<usize> {
    fn positive_sysctl(name: &[u8]) -> Option<usize> {
        ffi::sysctl_usize(name).filter(|count| *count > 0)
    }

    positive_sysctl(b"hw.perflevel0.physicalcpu\0")
        .or_else(|| positive_sysctl(b"hw.physicalcpu\0"))
}

/// Non-macOS stand-in: there is no sysctl to query, so always `None`.
#[cfg(not(target_os = "macos"))]
fn apple_perf_cores() -> Option<usize> {
    Option::None
}
313
/// Resolves the execution-provider preference from `CODELENS_EMBED_PROVIDER`.
///
/// The value is trimmed and lower-cased before matching:
/// * `cpu` gives `"cpu"` on every platform;
/// * `coreml` gives `"coreml"` on macOS and `"cpu"` elsewhere;
/// * anything else (or unset) gives `"coreml_preferred"` on macOS and
///   `"cpu"` elsewhere.
pub fn configured_embedding_runtime_preference() -> String {
    let on_macos = cfg!(target_os = "macos");
    let requested = std::env::var("CODELENS_EMBED_PROVIDER")
        .ok()
        .map(|raw| raw.trim().to_ascii_lowercase());

    let preference = match requested.as_deref() {
        Some("cpu") => "cpu",
        Some("coreml") => {
            if on_macos {
                "coreml"
            } else {
                "cpu"
            }
        }
        _ => {
            if on_macos {
                "coreml_preferred"
            } else {
                "cpu"
            }
        }
    };
    preference.to_string()
}
327
328pub fn configured_embedding_threads() -> usize {
329 recommended_embed_threads()
330}
331
332fn configured_embedding_max_length() -> usize {
333 parse_usize_env("CODELENS_EMBED_MAX_LENGTH")
334 .unwrap_or(256)
335 .clamp(32, 512)
336}
337
338fn configured_embedding_text_cache_size() -> usize {
339 std::env::var("CODELENS_EMBED_TEXT_CACHE_SIZE")
340 .ok()
341 .and_then(|value| value.trim().parse::<usize>().ok())
342 .unwrap_or({
343 if cfg!(target_os = "macos") {
344 DEFAULT_MACOS_TEXT_EMBED_CACHE_SIZE
345 } else {
346 DEFAULT_TEXT_EMBED_CACHE_SIZE
347 }
348 })
349 .min(8192)
350}
351
/// Canonical CoreML compute-unit name from
/// `CODELENS_EMBED_COREML_COMPUTE_UNITS` (trimmed, case-insensitive).
///
/// Accepted spellings: `all`; `cpu`/`cpu_only`; `gpu`/`cpu_and_gpu`;
/// `ane`/`neural_engine`/`cpu_and_neural_engine`. Unset or unrecognised
/// values resolve to `"cpu_and_neural_engine"`.
#[cfg(target_os = "macos")]
fn configured_coreml_compute_units_name() -> String {
    let requested = std::env::var("CODELENS_EMBED_COREML_COMPUTE_UNITS")
        .ok()
        .map(|raw| raw.trim().to_ascii_lowercase());

    let canonical = match requested.as_deref() {
        Some("all") => "all",
        Some("cpu" | "cpu_only") => "cpu_only",
        Some("gpu" | "cpu_and_gpu") => "cpu_and_gpu",
        // The neural-engine aliases and the default share one result.
        _ => "cpu_and_neural_engine",
    };
    canonical.to_string()
}
368
/// CoreML model format name from `CODELENS_EMBED_COREML_MODEL_FORMAT`
/// (trimmed, case-insensitive).
///
/// `neuralnetwork`/`neural_network` select `"neural_network"`; everything
/// else, including an unset variable, selects `"mlprogram"`.
#[cfg(target_os = "macos")]
fn configured_coreml_model_format_name() -> String {
    let requested = std::env::var("CODELENS_EMBED_COREML_MODEL_FORMAT")
        .ok()
        .map(|raw| raw.trim().to_ascii_lowercase());

    let wants_neural_network = matches!(
        requested.as_deref(),
        Some("neuralnetwork" | "neural_network")
    );
    if wants_neural_network {
        "neural_network".to_string()
    } else {
        "mlprogram".to_string()
    }
}
380
/// Whether CoreML compute-plan profiling is requested via
/// `CODELENS_EMBED_COREML_PROFILE_PLAN`; off unless explicitly enabled.
#[cfg(target_os = "macos")]
fn configured_coreml_profile_compute_plan() -> bool {
    match parse_bool_env("CODELENS_EMBED_COREML_PROFILE_PLAN") {
        Some(enabled) => enabled,
        None => false,
    }
}
385
/// Whether CoreML static input shapes are requested via
/// `CODELENS_EMBED_COREML_STATIC_INPUT_SHAPES`; on unless explicitly disabled.
#[cfg(target_os = "macos")]
fn configured_coreml_static_input_shapes() -> bool {
    match parse_bool_env("CODELENS_EMBED_COREML_STATIC_INPUT_SHAPES") {
        Some(enabled) => enabled,
        None => true,
    }
}
390
/// CoreML specialization strategy name from
/// `CODELENS_EMBED_COREML_SPECIALIZATION` (trimmed, case-insensitive).
///
/// Only `default` selects `"default"`; every other value, or an unset
/// variable, selects `"fast_prediction"`.
#[cfg(target_os = "macos")]
fn configured_coreml_specialization_strategy_name() -> String {
    let requested = std::env::var("CODELENS_EMBED_COREML_SPECIALIZATION")
        .ok()
        .map(|raw| raw.trim().to_ascii_lowercase());

    let strategy = if requested.as_deref() == Some("default") {
        "default"
    } else {
        "fast_prediction"
    };
    strategy.to_string()
}
402
/// Directory for the CoreML model cache:
/// `<base>/.cache/codelens/coreml-cache/codesearch`, where `<base>` is
/// `$HOME` or, when that is unset, the system temporary directory.
#[cfg(target_os = "macos")]
fn configured_coreml_model_cache_dir() -> std::path::PathBuf {
    let base = match dirs_fallback() {
        Some(home) => home,
        None => std::env::temp_dir(),
    };

    let mut cache_dir = base;
    for segment in [".cache", "codelens", "coreml-cache", "codesearch"] {
        cache_dir.push(segment);
    }
    cache_dir
}
412
413fn recommended_embed_threads() -> usize {
414 if let Some(explicit) = parse_usize_env("CODELENS_EMBED_THREADS") {
415 return explicit.max(1);
416 }
417
418 let available = available_parallelism().map(|n| n.get()).unwrap_or(1);
419 if cfg!(target_os = "macos") {
420 apple_perf_cores()
421 .unwrap_or(available)
422 .min(available)
423 .clamp(1, 8)
424 } else {
425 available.div_ceil(2).clamp(1, 8)
426 }
427}
428
429fn embed_batch_size() -> usize {
430 parse_usize_env("CODELENS_EMBED_BATCH_SIZE").unwrap_or({
431 if cfg!(target_os = "macos") {
432 DEFAULT_MACOS_EMBED_BATCH_SIZE
433 } else {
434 DEFAULT_EMBED_BATCH_SIZE
435 }
436 })
437}
438
439fn max_embed_symbols() -> usize {
440 parse_usize_env("CODELENS_MAX_EMBED_SYMBOLS").unwrap_or(DEFAULT_MAX_EMBED_SYMBOLS)
441}
442
/// Sets environment variable `name` to `value` unless it already has a
/// value (an existing value, even an empty one, is left untouched).
fn set_env_if_unset(name: &str, value: impl Into<String>) {
    if std::env::var_os(name).is_some() {
        return;
    }

    // SAFETY: mutating the process environment is only sound while no other
    // thread reads or writes it concurrently.
    // NOTE(review): nothing here enforces that — confirm callers run this
    // before other threads touch the environment.
    unsafe {
        std::env::set_var(name, value.into());
    }
}
452
453fn configure_embedding_runtime() {
454 let threads = recommended_embed_threads();
455 let runtime_preference = configured_embedding_runtime_preference();
456
457 set_env_if_unset("OMP_NUM_THREADS", threads.to_string());
460 set_env_if_unset("OMP_WAIT_POLICY", "PASSIVE");
461 set_env_if_unset("OMP_DYNAMIC", "FALSE");
462 set_env_if_unset("TOKENIZERS_PARALLELISM", "false");
463 if cfg!(target_os = "macos") {
464 set_env_if_unset("VECLIB_MAXIMUM_THREADS", threads.to_string());
465 }
466
467 ORT_ENV_INIT.call_once(|| {
468 let pool = ort::environment::GlobalThreadPoolOptions::default()
469 .with_intra_threads(threads)
470 .and_then(|pool| pool.with_inter_threads(1))
471 .and_then(|pool| pool.with_spin_control(false));
472
473 if let Ok(pool) = pool {
474 let _ = ort::init()
475 .with_name("codelens-embedding")
476 .with_telemetry(false)
477 .with_global_thread_pool(pool)
478 .commit();
479 }
480 });
481
482 debug!(
483 threads,
484 runtime_preference = %runtime_preference,
485 "configured embedding runtime"
486 );
487}
488
489pub fn configured_embedding_runtime_info() -> EmbeddingRuntimeInfo {
490 let runtime_preference = configured_embedding_runtime_preference();
491 let threads = configured_embedding_threads();
492
493 #[cfg(target_os = "macos")]
494 {
495 let coreml_enabled = runtime_preference != "cpu";
496 EmbeddingRuntimeInfo {
497 runtime_preference,
498 backend: "not_loaded".to_string(),
499 threads,
500 max_length: configured_embedding_max_length(),
501 coreml_model_format: coreml_enabled.then(configured_coreml_model_format_name),
502 coreml_compute_units: coreml_enabled.then(configured_coreml_compute_units_name),
503 coreml_static_input_shapes: coreml_enabled.then(configured_coreml_static_input_shapes),
504 coreml_profile_compute_plan: coreml_enabled
505 .then(configured_coreml_profile_compute_plan),
506 coreml_specialization_strategy: coreml_enabled
507 .then(configured_coreml_specialization_strategy_name),
508 coreml_model_cache_dir: coreml_enabled
509 .then(|| configured_coreml_model_cache_dir().display().to_string()),
510 fallback_reason: None,
511 }
512 }
513
514 #[cfg(not(target_os = "macos"))]
515 {
516 EmbeddingRuntimeInfo {
517 runtime_preference,
518 backend: "not_loaded".to_string(),
519 threads,
520 max_length: configured_embedding_max_length(),
521 coreml_model_format: None,
522 coreml_compute_units: None,
523 coreml_static_input_shapes: None,
524 coreml_profile_compute_plan: None,
525 coreml_specialization_strategy: None,
526 coreml_model_cache_dir: None,
527 fallback_reason: None,
528 }
529 }
530}
531
/// Builds the CoreML execution provider from the `CODELENS_EMBED_COREML_*`
/// settings.
///
/// The model cache directory is created on a best-effort basis (a failure
/// to create it is ignored). The provider is built with `error_on_failure`.
#[cfg(target_os = "macos")]
fn build_coreml_execution_provider() -> ExecutionProviderDispatch {
    use ort::ep::{
        CoreML,
        coreml::{ComputeUnits, ModelFormat, SpecializationStrategy},
    };

    let units_name = configured_coreml_compute_units_name();
    let compute_units = match units_name.as_str() {
        "all" => ComputeUnits::All,
        "cpu_only" => ComputeUnits::CPUOnly,
        "cpu_and_gpu" => ComputeUnits::CPUAndGPU,
        _ => ComputeUnits::CPUAndNeuralEngine,
    };

    let format_name = configured_coreml_model_format_name();
    let model_format = if format_name == "neural_network" {
        ModelFormat::NeuralNetwork
    } else {
        ModelFormat::MLProgram
    };

    let strategy_name = configured_coreml_specialization_strategy_name();
    let specialization = if strategy_name == "default" {
        SpecializationStrategy::Default
    } else {
        SpecializationStrategy::FastPrediction
    };

    let cache_dir = configured_coreml_model_cache_dir();
    let _ = std::fs::create_dir_all(&cache_dir);

    let provider = CoreML::default()
        .with_model_format(model_format)
        .with_compute_units(compute_units)
        .with_static_input_shapes(configured_coreml_static_input_shapes())
        .with_specialization_strategy(specialization)
        .with_profile_compute_plan(configured_coreml_profile_compute_plan())
        .with_model_cache_dir(cache_dir.display().to_string());
    provider.build().error_on_failure()
}
566
567fn cpu_runtime_info(
568 runtime_preference: String,
569 fallback_reason: Option<String>,
570) -> EmbeddingRuntimeInfo {
571 EmbeddingRuntimeInfo {
572 runtime_preference,
573 backend: "cpu".to_string(),
574 threads: configured_embedding_threads(),
575 max_length: configured_embedding_max_length(),
576 coreml_model_format: None,
577 coreml_compute_units: None,
578 coreml_static_input_shapes: None,
579 coreml_profile_compute_plan: None,
580 coreml_specialization_strategy: None,
581 coreml_model_cache_dir: None,
582 fallback_reason,
583 }
584}
585
/// Runtime description for a load attempt that targeted CoreML.
///
/// `backend` is `"coreml"` when there is no fallback reason and `"cpu"`
/// when there is one. The `coreml_*` fields always carry the configured
/// settings, including in the fallback case.
#[cfg(target_os = "macos")]
fn coreml_runtime_info(
    runtime_preference: String,
    fallback_reason: Option<String>,
) -> EmbeddingRuntimeInfo {
    let backend = match &fallback_reason {
        Some(_) => "cpu",
        None => "coreml",
    };
    let cache_dir = configured_coreml_model_cache_dir().display().to_string();

    EmbeddingRuntimeInfo {
        runtime_preference,
        backend: backend.to_string(),
        threads: configured_embedding_threads(),
        max_length: configured_embedding_max_length(),
        coreml_model_format: Some(configured_coreml_model_format_name()),
        coreml_compute_units: Some(configured_coreml_compute_units_name()),
        coreml_static_input_shapes: Some(configured_coreml_static_input_shapes()),
        coreml_profile_compute_plan: Some(configured_coreml_profile_compute_plan()),
        coreml_specialization_strategy: Some(configured_coreml_specialization_strategy_name()),
        coreml_model_cache_dir: Some(cache_dir),
        fallback_reason,
    }
}
609
610fn load_codesearch_model() -> Result<(TextEmbedding, usize, String, EmbeddingRuntimeInfo)> {
612 configure_embedding_runtime();
613 let model_dir = resolve_model_dir()?;
614
615 let onnx_bytes =
616 std::fs::read(model_dir.join("model.onnx")).context("failed to read model.onnx")?;
617 let tokenizer_bytes =
618 std::fs::read(model_dir.join("tokenizer.json")).context("failed to read tokenizer.json")?;
619 let config_bytes =
620 std::fs::read(model_dir.join("config.json")).context("failed to read config.json")?;
621 let special_tokens_bytes = std::fs::read(model_dir.join("special_tokens_map.json"))
622 .context("failed to read special_tokens_map.json")?;
623 let tokenizer_config_bytes = std::fs::read(model_dir.join("tokenizer_config.json"))
624 .context("failed to read tokenizer_config.json")?;
625
626 let user_model = UserDefinedEmbeddingModel::new(
627 onnx_bytes,
628 TokenizerFiles {
629 tokenizer_file: tokenizer_bytes,
630 config_file: config_bytes,
631 special_tokens_map_file: special_tokens_bytes,
632 tokenizer_config_file: tokenizer_config_bytes,
633 },
634 );
635
636 let runtime_preference = configured_embedding_runtime_preference();
637
638 #[cfg(target_os = "macos")]
639 if runtime_preference != "cpu" {
640 let init_opts = InitOptionsUserDefined::new()
641 .with_max_length(configured_embedding_max_length())
642 .with_execution_providers(vec![build_coreml_execution_provider()]);
643 match TextEmbedding::try_new_from_user_defined(user_model.clone(), init_opts) {
644 Ok(model) => {
645 let runtime_info = coreml_runtime_info(runtime_preference.clone(), None);
646 debug!(
647 threads = runtime_info.threads,
648 runtime_preference = %runtime_info.runtime_preference,
649 backend = %runtime_info.backend,
650 coreml_compute_units = ?runtime_info.coreml_compute_units,
651 coreml_static_input_shapes = ?runtime_info.coreml_static_input_shapes,
652 coreml_profile_compute_plan = ?runtime_info.coreml_profile_compute_plan,
653 coreml_specialization_strategy = ?runtime_info.coreml_specialization_strategy,
654 coreml_model_cache_dir = ?runtime_info.coreml_model_cache_dir,
655 "loaded CodeSearchNet embedding model"
656 );
657 return Ok((
658 model,
659 CODESEARCH_DIMENSION,
660 CODESEARCH_MODEL_NAME.to_string(),
661 runtime_info,
662 ));
663 }
664 Err(err) => {
665 let reason = err.to_string();
666 debug!(
667 runtime_preference = %runtime_preference,
668 fallback_reason = %reason,
669 "CoreML embedding load failed; falling back to CPU"
670 );
671 let model = TextEmbedding::try_new_from_user_defined(
672 user_model,
673 InitOptionsUserDefined::new()
674 .with_max_length(configured_embedding_max_length()),
675 )
676 .context("failed to load CodeSearchNet embedding model")?;
677 let runtime_info = coreml_runtime_info(runtime_preference.clone(), Some(reason));
678 debug!(
679 threads = runtime_info.threads,
680 runtime_preference = %runtime_info.runtime_preference,
681 backend = %runtime_info.backend,
682 coreml_compute_units = ?runtime_info.coreml_compute_units,
683 coreml_static_input_shapes = ?runtime_info.coreml_static_input_shapes,
684 coreml_profile_compute_plan = ?runtime_info.coreml_profile_compute_plan,
685 coreml_specialization_strategy = ?runtime_info.coreml_specialization_strategy,
686 coreml_model_cache_dir = ?runtime_info.coreml_model_cache_dir,
687 fallback_reason = ?runtime_info.fallback_reason,
688 "loaded CodeSearchNet embedding model"
689 );
690 return Ok((
691 model,
692 CODESEARCH_DIMENSION,
693 CODESEARCH_MODEL_NAME.to_string(),
694 runtime_info,
695 ));
696 }
697 }
698 }
699
700 let model = TextEmbedding::try_new_from_user_defined(
701 user_model,
702 InitOptionsUserDefined::new().with_max_length(configured_embedding_max_length()),
703 )
704 .context("failed to load CodeSearchNet embedding model")?;
705 let runtime_info = cpu_runtime_info(runtime_preference.clone(), None);
706
707 debug!(
708 threads = runtime_info.threads,
709 runtime_preference = %runtime_info.runtime_preference,
710 backend = %runtime_info.backend,
711 "loaded CodeSearchNet embedding model"
712 );
713
714 Ok((
715 model,
716 CODESEARCH_DIMENSION,
717 CODESEARCH_MODEL_NAME.to_string(),
718 runtime_info,
719 ))
720}
721
722pub fn configured_embedding_model_name() -> String {
723 std::env::var("CODELENS_EMBED_MODEL").unwrap_or_else(|_| CODESEARCH_MODEL_NAME.to_string())
724}
725
726pub fn embedding_model_assets_available() -> bool {
727 resolve_model_dir().is_ok()
728}
729
730impl EmbeddingEngine {
731 fn embed_texts_cached(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
732 if texts.is_empty() {
733 return Ok(Vec::new());
734 }
735
736 let mut resolved: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
737 let mut missing_order: Vec<String> = Vec::new();
738 let mut missing_positions: HashMap<String, Vec<usize>> = HashMap::new();
739
740 {
741 let mut cache = self
742 .text_embed_cache
743 .lock()
744 .map_err(|_| anyhow::anyhow!("text embedding cache lock"))?;
745 for (index, text) in texts.iter().enumerate() {
746 if let Some(cached) = cache.get(text) {
747 resolved[index] = Some(cached);
748 } else {
749 let key = (*text).to_owned();
750 if !missing_positions.contains_key(&key) {
751 missing_order.push(key.clone());
752 }
753 missing_positions.entry(key).or_default().push(index);
754 }
755 }
756 }
757
758 if !missing_order.is_empty() {
759 let missing_refs: Vec<&str> = missing_order.iter().map(String::as_str).collect();
760 let embeddings = self
761 .model
762 .lock()
763 .map_err(|_| anyhow::anyhow!("model lock"))?
764 .embed(missing_refs, None)
765 .context("text embedding failed")?;
766
767 let mut cache = self
768 .text_embed_cache
769 .lock()
770 .map_err(|_| anyhow::anyhow!("text embedding cache lock"))?;
771 for (text, embedding) in missing_order.into_iter().zip(embeddings.into_iter()) {
772 cache.insert(text.clone(), embedding.clone());
773 if let Some(indices) = missing_positions.remove(&text) {
774 for index in indices {
775 resolved[index] = Some(embedding.clone());
776 }
777 }
778 }
779 }
780
781 resolved
782 .into_iter()
783 .map(|item| item.ok_or_else(|| anyhow::anyhow!("missing embedding cache entry")))
784 .collect()
785 }
786
787 pub fn new(project: &ProjectRoot) -> Result<Self> {
788 let (model, dimension, model_name, runtime_info) = load_codesearch_model()?;
789
790 let db_dir = project.as_path().join(".codelens/index");
791 std::fs::create_dir_all(&db_dir)?;
792 let db_path = db_dir.join("embeddings.db");
793
794 let store = SqliteVecStore::new(&db_path, dimension, &model_name)?;
795
796 Ok(Self {
797 model: Mutex::new(model),
798 store: Box::new(store),
799 model_name,
800 runtime_info,
801 text_embed_cache: Mutex::new(TextEmbeddingCache::new(
802 configured_embedding_text_cache_size(),
803 )),
804 indexing: std::sync::atomic::AtomicBool::new(false),
805 })
806 }
807
808 pub fn model_name(&self) -> &str {
809 &self.model_name
810 }
811
812 pub fn runtime_info(&self) -> &EmbeddingRuntimeInfo {
813 &self.runtime_info
814 }
815
816 pub fn is_indexing(&self) -> bool {
823 self.indexing.load(std::sync::atomic::Ordering::Relaxed)
824 }
825
826 pub fn index_from_project(&self, project: &ProjectRoot) -> Result<usize> {
827 if self
829 .indexing
830 .compare_exchange(
831 false,
832 true,
833 std::sync::atomic::Ordering::AcqRel,
834 std::sync::atomic::Ordering::Relaxed,
835 )
836 .is_err()
837 {
838 anyhow::bail!(
839 "Embedding indexing already in progress — wait for the current run to complete before retrying."
840 );
841 }
842 struct IndexGuard<'a>(&'a std::sync::atomic::AtomicBool);
844 impl Drop for IndexGuard<'_> {
845 fn drop(&mut self) {
846 self.0.store(false, std::sync::atomic::Ordering::Release);
847 }
848 }
849 let _guard = IndexGuard(&self.indexing);
850
851 let db_path = crate::db::index_db_path(project.as_path());
852 let symbol_db = IndexDb::open(&db_path)?;
853 let batch_size = embed_batch_size();
854 let max_symbols = max_embed_symbols();
855 let mut total_indexed = 0usize;
856 let mut total_seen = 0usize;
857 let mut model = None;
858 let mut existing_embeddings: HashMap<
859 String,
860 HashMap<ReusableEmbeddingKey, EmbeddingChunk>,
861 > = HashMap::new();
862 let mut current_db_files = HashSet::new();
863 let mut capped = false;
864
865 self.store
866 .for_each_file_embeddings(&mut |file_path, chunks| {
867 existing_embeddings.insert(
868 file_path,
869 chunks
870 .into_iter()
871 .map(|chunk| (reusable_embedding_key_for_chunk(&chunk), chunk))
872 .collect(),
873 );
874 Ok(())
875 })?;
876
877 symbol_db.for_each_file_symbols_with_bytes(|file_path, symbols| {
878 current_db_files.insert(file_path.clone());
879 if capped {
880 return Ok(());
881 }
882
883 let source = std::fs::read_to_string(project.as_path().join(&file_path)).ok();
884 let relevant_symbols: Vec<_> = symbols
885 .into_iter()
886 .filter(|sym| !is_test_only_symbol(sym, source.as_deref()))
887 .collect();
888
889 if relevant_symbols.is_empty() {
890 self.store.delete_by_file(&[file_path.as_str()])?;
891 existing_embeddings.remove(&file_path);
892 return Ok(());
893 }
894
895 if total_seen + relevant_symbols.len() > max_symbols {
896 capped = true;
897 return Ok(());
898 }
899 total_seen += relevant_symbols.len();
900
901 let existing_for_file = existing_embeddings.remove(&file_path).unwrap_or_default();
902 total_indexed += self.reconcile_file_embeddings(
903 &file_path,
904 relevant_symbols,
905 source.as_deref(),
906 existing_for_file,
907 batch_size,
908 &mut model,
909 )?;
910 Ok(())
911 })?;
912
913 let removed_files: Vec<String> = existing_embeddings
914 .into_keys()
915 .filter(|file_path| !current_db_files.contains(file_path))
916 .collect();
917 if !removed_files.is_empty() {
918 let removed_refs: Vec<&str> = removed_files.iter().map(String::as_str).collect();
919 self.store.delete_by_file(&removed_refs)?;
920 }
921
922 Ok(total_indexed)
923 }
924
925 fn reconcile_file_embeddings<'a>(
926 &'a self,
927 file_path: &str,
928 symbols: Vec<crate::db::SymbolWithFile>,
929 source: Option<&str>,
930 mut existing_embeddings: HashMap<ReusableEmbeddingKey, EmbeddingChunk>,
931 batch_size: usize,
932 model: &mut Option<std::sync::MutexGuard<'a, TextEmbedding>>,
933 ) -> Result<usize> {
934 let mut reconciled_chunks = Vec::with_capacity(symbols.len());
935 let mut batch_texts: Vec<String> = Vec::with_capacity(batch_size);
936 let mut batch_meta: Vec<crate::db::SymbolWithFile> = Vec::with_capacity(batch_size);
937
938 for sym in symbols {
939 let text = build_embedding_text(&sym, source);
940 if let Some(existing) =
941 existing_embeddings.remove(&reusable_embedding_key_for_symbol(&sym, &text))
942 {
943 reconciled_chunks.push(EmbeddingChunk {
944 file_path: sym.file_path.clone(),
945 symbol_name: sym.name.clone(),
946 kind: sym.kind.clone(),
947 line: sym.line as usize,
948 signature: sym.signature.clone(),
949 name_path: sym.name_path.clone(),
950 text,
951 embedding: existing.embedding,
952 doc_embedding: existing.doc_embedding,
953 });
954 continue;
955 }
956
957 batch_texts.push(text);
958 batch_meta.push(sym);
959
960 if batch_texts.len() >= batch_size {
961 if model.is_none() {
962 *model = Some(
963 self.model
964 .lock()
965 .map_err(|_| anyhow::anyhow!("model lock"))?,
966 );
967 }
968 reconciled_chunks.extend(Self::embed_chunks(
969 model.as_mut().expect("model lock initialized"),
970 &batch_texts,
971 &batch_meta,
972 )?);
973 batch_texts.clear();
974 batch_meta.clear();
975 }
976 }
977
978 if !batch_texts.is_empty() {
979 if model.is_none() {
980 *model = Some(
981 self.model
982 .lock()
983 .map_err(|_| anyhow::anyhow!("model lock"))?,
984 );
985 }
986 reconciled_chunks.extend(Self::embed_chunks(
987 model.as_mut().expect("model lock initialized"),
988 &batch_texts,
989 &batch_meta,
990 )?);
991 }
992
993 self.store.delete_by_file(&[file_path])?;
994 if reconciled_chunks.is_empty() {
995 return Ok(0);
996 }
997 self.store.insert(&reconciled_chunks)
998 }
999
1000 fn embed_chunks(
1001 model: &mut TextEmbedding,
1002 texts: &[String],
1003 meta: &[crate::db::SymbolWithFile],
1004 ) -> Result<Vec<EmbeddingChunk>> {
1005 let batch_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
1006 let embeddings = model.embed(batch_refs, None).context("embedding failed")?;
1007
1008 Ok(meta
1009 .iter()
1010 .zip(embeddings)
1011 .zip(texts.iter())
1012 .map(|((sym, emb), text)| EmbeddingChunk {
1013 file_path: sym.file_path.clone(),
1014 symbol_name: sym.name.clone(),
1015 kind: sym.kind.clone(),
1016 line: sym.line as usize,
1017 signature: sym.signature.clone(),
1018 name_path: sym.name_path.clone(),
1019 text: text.clone(),
1020 embedding: emb,
1021 doc_embedding: None,
1022 })
1023 .collect())
1024 }
1025
1026 fn flush_batch(
1028 model: &mut TextEmbedding,
1029 store: &dyn EmbeddingStore,
1030 texts: &[String],
1031 meta: &[crate::db::SymbolWithFile],
1032 ) -> Result<usize> {
1033 let chunks = Self::embed_chunks(model, texts, meta)?;
1034 store.insert(&chunks)
1035 }
1036
1037 pub fn search(&self, query: &str, max_results: usize) -> Result<Vec<SemanticMatch>> {
1039 let results = self.search_scored(query, max_results)?;
1040 Ok(results.into_iter().map(SemanticMatch::from).collect())
1041 }
1042
1043 pub fn search_scored(&self, query: &str, max_results: usize) -> Result<Vec<ScoredChunk>> {
1047 let query_embedding = self.embed_texts_cached(&[query])?;
1048
1049 if query_embedding.is_empty() {
1050 return Ok(Vec::new());
1051 }
1052
1053 self.store.search(&query_embedding[0], max_results)
1054 }
1055
    /// Rebuilds the stored embeddings for `changed_files`.
    ///
    /// The existing rows for those files are read first, then deleted, and the
    /// symbols currently recorded in the project's symbol database are
    /// re-inserted. A symbol whose reuse key (derived from the symbol and its
    /// freshly built embedding text) matches a previously stored chunk keeps
    /// that chunk's vectors; every other symbol is embedded through
    /// `Self::flush_batch`.
    ///
    /// Returns the sum of the counts reported by `store.insert` and
    /// `flush_batch`. Returns `Ok(0)` immediately when `changed_files` is empty.
    ///
    /// # Errors
    /// Propagates store, symbol-database and embedding errors, and reports
    /// "model lock" when the model mutex is poisoned.
    pub fn index_changed_files(
        &self,
        project: &ProjectRoot,
        changed_files: &[&str],
    ) -> Result<usize> {
        if changed_files.is_empty() {
            return Ok(0);
        }
        let batch_size = embed_batch_size();
        // Snapshot the current embeddings before deleting them so that
        // unchanged symbols can reuse their vectors below.
        let mut existing_embeddings: HashMap<ReusableEmbeddingKey, EmbeddingChunk> = HashMap::new();
        for file_chunk in changed_files.chunks(CHANGED_FILE_QUERY_CHUNK) {
            for chunk in self.store.embeddings_for_files(file_chunk)? {
                existing_embeddings.insert(reusable_embedding_key_for_chunk(&chunk), chunk);
            }
        }
        self.store.delete_by_file(changed_files)?;

        let db_path = crate::db::index_db_path(project.as_path());
        let symbol_db = IndexDb::open(&db_path)?;

        let mut total_indexed = 0usize;
        // `batch_texts` and `batch_meta` are filled in lockstep (one entry per
        // symbol that needs a fresh embedding); `batch_reused` holds chunks
        // whose vectors were carried over.
        let mut batch_texts: Vec<String> = Vec::with_capacity(batch_size);
        let mut batch_meta: Vec<crate::db::SymbolWithFile> = Vec::with_capacity(batch_size);
        let mut batch_reused: Vec<EmbeddingChunk> = Vec::with_capacity(batch_size);
        // Source text per file path; `None` records a failed read so the file
        // is not read again for later symbols.
        let mut file_cache: std::collections::HashMap<String, Option<String>> =
            std::collections::HashMap::new();
        // The model lock is taken lazily, the first time a batch actually
        // needs embedding, and the guard is then kept for the rest of the call.
        let mut model = None;

        for file_chunk in changed_files.chunks(CHANGED_FILE_QUERY_CHUNK) {
            let relevant = symbol_db.symbols_for_files(file_chunk)?;
            for sym in relevant {
                let source = file_cache.entry(sym.file_path.clone()).or_insert_with(|| {
                    std::fs::read_to_string(project.as_path().join(&sym.file_path)).ok()
                });
                if is_test_only_symbol(&sym, source.as_deref()) {
                    continue;
                }
                let text = build_embedding_text(&sym, source.as_deref());
                // `remove` (rather than `get`) means each stored chunk can be
                // reused by at most one symbol.
                if let Some(existing) =
                    existing_embeddings.remove(&reusable_embedding_key_for_symbol(&sym, &text))
                {
                    batch_reused.push(EmbeddingChunk {
                        file_path: sym.file_path.clone(),
                        symbol_name: sym.name.clone(),
                        kind: sym.kind.clone(),
                        line: sym.line as usize,
                        signature: sym.signature.clone(),
                        name_path: sym.name_path.clone(),
                        text,
                        embedding: existing.embedding,
                        doc_embedding: existing.doc_embedding,
                    });
                    if batch_reused.len() >= batch_size {
                        total_indexed += self.store.insert(&batch_reused)?;
                        batch_reused.clear();
                    }
                    continue;
                }
                batch_texts.push(text);
                batch_meta.push(sym);

                if batch_texts.len() >= batch_size {
                    if model.is_none() {
                        model = Some(
                            self.model
                                .lock()
                                .map_err(|_| anyhow::anyhow!("model lock"))?,
                        );
                    }
                    total_indexed += Self::flush_batch(
                        model.as_mut().expect("model lock initialized"),
                        &*self.store,
                        &batch_texts,
                        &batch_meta,
                    )?;
                    batch_texts.clear();
                    batch_meta.clear();
                }
            }
        }

        // Flush whatever is left in the partially filled batches.
        if !batch_reused.is_empty() {
            total_indexed += self.store.insert(&batch_reused)?;
        }

        if !batch_texts.is_empty() {
            if model.is_none() {
                model = Some(
                    self.model
                        .lock()
                        .map_err(|_| anyhow::anyhow!("model lock"))?,
                );
            }
            total_indexed += Self::flush_batch(
                model.as_mut().expect("model lock initialized"),
                &*self.store,
                &batch_texts,
                &batch_meta,
            )?;
        }

        Ok(total_indexed)
    }
1160
1161 pub fn is_indexed(&self) -> bool {
1163 self.store.count().unwrap_or(0) > 0
1164 }
1165
1166 pub fn index_info(&self) -> EmbeddingIndexInfo {
1167 EmbeddingIndexInfo {
1168 model_name: self.model_name.clone(),
1169 indexed_symbols: self.store.count().unwrap_or(0),
1170 }
1171 }
1172
1173 pub fn inspect_existing_index(project: &ProjectRoot) -> Result<Option<EmbeddingIndexInfo>> {
1174 let db_path = project.as_path().join(".codelens/index/embeddings.db");
1175 if !db_path.exists() {
1176 return Ok(None);
1177 }
1178
1179 let conn =
1180 crate::db::open_derived_sqlite_with_recovery(&db_path, "embedding index", || {
1181 ffi::register_sqlite_vec()?;
1182 let conn = Connection::open(&db_path)?;
1183 conn.execute_batch("PRAGMA busy_timeout=5000;")?;
1184 conn.query_row("PRAGMA schema_version", [], |_row| Ok(()))?;
1185 Ok(conn)
1186 })?;
1187
1188 let model_name: Option<String> = conn
1189 .query_row(
1190 "SELECT value FROM meta WHERE key = 'model' LIMIT 1",
1191 [],
1192 |row| row.get(0),
1193 )
1194 .ok();
1195 let indexed_symbols: usize = conn
1196 .query_row("SELECT COUNT(*) FROM symbols", [], |row| {
1197 row.get::<_, i64>(0)
1198 })
1199 .map(|count| count.max(0) as usize)
1200 .unwrap_or(0);
1201
1202 Ok(model_name.map(|model_name| EmbeddingIndexInfo {
1203 model_name,
1204 indexed_symbols,
1205 }))
1206 }
1207
1208 pub fn find_similar_code(
1212 &self,
1213 file_path: &str,
1214 symbol_name: &str,
1215 max_results: usize,
1216 ) -> Result<Vec<SemanticMatch>> {
1217 let target = self
1218 .store
1219 .get_embedding(file_path, symbol_name)?
1220 .ok_or_else(|| anyhow::anyhow!("Symbol '{}' not found in index", symbol_name))?;
1221
1222 let oversample = max_results.saturating_add(8).max(1);
1223 let scored = self
1224 .store
1225 .search(&target.embedding, oversample)?
1226 .into_iter()
1227 .filter(|c| !(c.file_path == file_path && c.symbol_name == symbol_name))
1228 .take(max_results)
1229 .map(SemanticMatch::from)
1230 .collect();
1231 Ok(scored)
1232 }
1233
    /// Scans the stored embeddings for pairs of symbols whose cosine
    /// similarity is at least `threshold`.
    ///
    /// For every stored chunk the store is asked for its nearest neighbours
    /// (the count comes from `duplicate_candidate_limit(max_pairs)`); each
    /// neighbour's embedding is then loaded and compared with the chunk.
    /// Collection stops once `max_pairs` pairs have been found, and the
    /// collected pairs are returned sorted by descending similarity.
    ///
    /// # Errors
    /// Propagates errors from the store.
    pub fn find_duplicates(&self, threshold: f64, max_pairs: usize) -> Result<Vec<DuplicatePair>> {
        let mut pairs = Vec::new();
        // Unordered (file, symbol) pairs that were already examined, so that
        // A/B and B/A are only evaluated once.
        let mut seen_pairs = HashSet::new();
        // Candidate embeddings loaded so far, kept across batches.
        let mut embedding_cache: HashMap<StoredChunkKey, Arc<EmbeddingChunk>> = HashMap::new();
        let candidate_limit = duplicate_candidate_limit(max_pairs);
        // Set once `max_pairs` is reached; later callback invocations then
        // return immediately.
        let mut done = false;

        self.store
            .for_each_embedding_batch(DEFAULT_DUPLICATE_SCAN_BATCH_SIZE, &mut |batch| {
                if done {
                    return Ok(());
                }

                let mut candidate_lists = Vec::with_capacity(batch.len());
                let mut missing_candidates = Vec::new();
                let mut missing_keys = HashSet::new();

                // Pass 1: gather neighbour candidates for each chunk and note
                // which candidate embeddings are not cached yet.
                for chunk in &batch {
                    if pairs.len() >= max_pairs {
                        done = true;
                        break;
                    }

                    // Drop the hit that is the chunk itself (same file, name,
                    // line, signature and name path).
                    let filtered: Vec<ScoredChunk> = self
                        .store
                        .search(&chunk.embedding, candidate_limit)?
                        .into_iter()
                        .filter(|candidate| {
                            !(chunk.file_path == candidate.file_path
                                && chunk.symbol_name == candidate.symbol_name
                                && chunk.line == candidate.line
                                && chunk.signature == candidate.signature
                                && chunk.name_path == candidate.name_path)
                        })
                        .collect();

                    for candidate in &filtered {
                        let cache_key = stored_chunk_key_for_score(candidate);
                        if !embedding_cache.contains_key(&cache_key)
                            && missing_keys.insert(cache_key)
                        {
                            missing_candidates.push(candidate.clone());
                        }
                    }

                    candidate_lists.push(filtered);
                }

                // Pass 2: load all missing candidate embeddings with a single
                // store call.
                if !missing_candidates.is_empty() {
                    for candidate_chunk in self
                        .store
                        .embeddings_for_scored_chunks(&missing_candidates)?
                    {
                        embedding_cache
                            .entry(stored_chunk_key(&candidate_chunk))
                            .or_insert_with(|| Arc::new(candidate_chunk));
                    }
                }

                // Pass 3: score each chunk against its candidates. `zip` stops
                // at the shorter list, which matters when pass 1 broke early.
                for (chunk, candidates) in batch.iter().zip(candidate_lists.iter()) {
                    if pairs.len() >= max_pairs {
                        done = true;
                        break;
                    }

                    for candidate in candidates {
                        let pair_key = duplicate_pair_key(
                            &chunk.file_path,
                            &chunk.symbol_name,
                            &candidate.file_path,
                            &candidate.symbol_name,
                        );
                        // The pair is marked as seen before the threshold
                        // check, so a rejected pair is not re-evaluated later.
                        if !seen_pairs.insert(pair_key) {
                            continue;
                        }

                        let Some(candidate_chunk) =
                            embedding_cache.get(&stored_chunk_key_for_score(candidate))
                        else {
                            continue;
                        };

                        let sim = cosine_similarity(&chunk.embedding, &candidate_chunk.embedding);
                        if sim < threshold {
                            continue;
                        }

                        pairs.push(DuplicatePair {
                            symbol_a: format!("{}:{}", chunk.file_path, chunk.symbol_name),
                            symbol_b: format!(
                                "{}:{}",
                                candidate_chunk.file_path, candidate_chunk.symbol_name
                            ),
                            file_a: chunk.file_path.clone(),
                            file_b: candidate_chunk.file_path.clone(),
                            line_a: chunk.line,
                            line_b: candidate_chunk.line,
                            similarity: sim,
                        });
                        if pairs.len() >= max_pairs {
                            done = true;
                            break;
                        }
                    }
                }
                Ok(())
            })?;

        // Highest similarity first; incomparable values (NaN) are treated as
        // equal.
        pairs.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        Ok(pairs)
    }
1351}
1352
/// Number of nearest neighbours requested per chunk during duplicate
/// detection: four times `max_pairs` (saturating), kept within `32..=128`.
fn duplicate_candidate_limit(max_pairs: usize) -> usize {
    const FLOOR: usize = 32;
    const CEILING: usize = 128;
    let scaled = max_pairs.saturating_mul(4);
    scaled.max(FLOOR).min(CEILING)
}
1356
/// Builds an order-independent key for a pair of symbols.
///
/// Each side is a `(file, symbol)` tuple; the lexicographically smaller tuple
/// always comes first, so `(a, b)` and `(b, a)` produce the same key.
fn duplicate_pair_key(
    file_a: &str,
    symbol_a: &str,
    file_b: &str,
    symbol_b: &str,
) -> ((String, String), (String, String)) {
    let first = (file_a.to_string(), symbol_a.to_string());
    let second = (file_b.to_string(), symbol_b.to_string());
    match first.cmp(&second) {
        std::cmp::Ordering::Greater => (second, first),
        _ => (first, second),
    }
}
1371
/// Identity of a stored chunk, in field order:
/// `(file_path, symbol_name, line, signature, name_path)`.
type StoredChunkKey = (String, String, usize, String, String);
1373
1374fn stored_chunk_key(chunk: &EmbeddingChunk) -> StoredChunkKey {
1375 (
1376 chunk.file_path.clone(),
1377 chunk.symbol_name.clone(),
1378 chunk.line,
1379 chunk.signature.clone(),
1380 chunk.name_path.clone(),
1381 )
1382}
1383
1384fn stored_chunk_key_for_score(chunk: &ScoredChunk) -> StoredChunkKey {
1385 (
1386 chunk.file_path.clone(),
1387 chunk.symbol_name.clone(),
1388 chunk.line,
1389 chunk.signature.clone(),
1390 chunk.name_path.clone(),
1391 )
1392}
1393
1394impl EmbeddingEngine {
1395 pub fn classify_symbol(
1397 &self,
1398 file_path: &str,
1399 symbol_name: &str,
1400 categories: &[&str],
1401 ) -> Result<Vec<CategoryScore>> {
1402 let target = match self.store.get_embedding(file_path, symbol_name)? {
1403 Some(target) => target,
1404 None => self
1405 .store
1406 .all_with_embeddings()?
1407 .into_iter()
1408 .find(|c| c.file_path == file_path && c.symbol_name == symbol_name)
1409 .ok_or_else(|| anyhow::anyhow!("Symbol '{}' not found in index", symbol_name))?,
1410 };
1411
1412 let embeddings = self.embed_texts_cached(categories)?;
1413
1414 let mut scores: Vec<CategoryScore> = categories
1415 .iter()
1416 .zip(embeddings.iter())
1417 .map(|(cat, emb)| CategoryScore {
1418 category: cat.to_string(),
1419 score: cosine_similarity(&target.embedding, emb),
1420 })
1421 .collect();
1422
1423 scores.sort_by(|a, b| {
1424 b.score
1425 .partial_cmp(&a.score)
1426 .unwrap_or(std::cmp::Ordering::Equal)
1427 });
1428 Ok(scores)
1429 }
1430
1431 pub fn find_misplaced_code(&self, max_results: usize) -> Result<Vec<OutlierSymbol>> {
1433 let mut outliers = Vec::new();
1434
1435 self.store
1436 .for_each_file_embeddings(&mut |file_path, chunks| {
1437 if chunks.len() < 2 {
1438 return Ok(());
1439 }
1440
1441 for (idx, chunk) in chunks.iter().enumerate() {
1442 let mut sim_sum = 0.0;
1443 let mut count = 0;
1444 for (other_idx, other_chunk) in chunks.iter().enumerate() {
1445 if other_idx == idx {
1446 continue;
1447 }
1448 sim_sum += cosine_similarity(&chunk.embedding, &other_chunk.embedding);
1449 count += 1;
1450 }
1451 if count > 0 {
1452 let avg_sim = sim_sum / count as f64; outliers.push(OutlierSymbol {
1454 file_path: file_path.clone(),
1455 symbol_name: chunk.symbol_name.clone(),
1456 kind: chunk.kind.clone(),
1457 line: chunk.line,
1458 avg_similarity_to_file: avg_sim,
1459 });
1460 }
1461 }
1462 Ok(())
1463 })?;
1464
1465 outliers.sort_by(|a, b| {
1466 a.avg_similarity_to_file
1467 .partial_cmp(&b.avg_similarity_to_file)
1468 .unwrap_or(std::cmp::Ordering::Equal)
1469 });
1470 outliers.truncate(max_results);
1471 Ok(outliers)
1472 }
1473}
1474
/// One pair of similar symbols reported by `find_duplicates`.
#[derive(Debug, Clone, Serialize)]
pub struct DuplicatePair {
    /// First symbol, formatted as `"{file_path}:{symbol_name}"`.
    pub symbol_a: String,
    /// Second symbol, formatted as `"{file_path}:{symbol_name}"`.
    pub symbol_b: String,
    /// File path of the first symbol.
    pub file_a: String,
    /// File path of the second symbol.
    pub file_b: String,
    /// Line recorded for the first symbol.
    pub line_a: usize,
    /// Line recorded for the second symbol.
    pub line_b: usize,
    /// Cosine similarity between the two symbols' embeddings.
    pub similarity: f64,
}
1487
/// Similarity of a symbol to one category label, as produced by
/// `classify_symbol`.
#[derive(Debug, Clone, Serialize)]
pub struct CategoryScore {
    /// The category label exactly as passed by the caller.
    pub category: String,
    /// Cosine similarity between the symbol's embedding and the embedding of
    /// the category label.
    pub score: f64,
}
1493
/// A symbol together with its average similarity to the other symbols of its
/// file, as produced by `find_misplaced_code`.
#[derive(Debug, Clone, Serialize)]
pub struct OutlierSymbol {
    /// Path of the file the symbol was stored under.
    pub file_path: String,
    /// Name of the symbol.
    pub symbol_name: String,
    /// Symbol kind string copied from the stored chunk.
    pub kind: String,
    /// Line recorded for the symbol.
    pub line: usize,
    /// Mean cosine similarity between this symbol's embedding and the
    /// embeddings of the other chunks reported for the same file.
    pub avg_similarity_to_file: f64,
}
1502
/// Cosine similarity of two vectors.
///
/// The dot product and squared norms are accumulated in `f32`; the square
/// roots and the final division are done in `f64`. Returns `0.0` when either
/// vector has zero norm. The slices are expected to have equal length
/// (checked in debug builds only); extra elements of the longer one are
/// ignored.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    debug_assert_eq!(a.len(), b.len());

    let (dot, sq_a, sq_b) = a.iter().zip(b.iter()).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let len_a = f64::from(sq_a).sqrt();
    let len_b = f64::from(sq_b).sqrt();
    if len_a != 0.0 && len_b != 0.0 {
        f64::from(dot) / (len_a * len_b)
    } else {
        0.0
    }
}
1527
/// Splits a `snake_case` / `camelCase` identifier into space-separated words.
///
/// Underscores separate words and are dropped. An uppercase letter starts a
/// new word when the previous character is lowercase or the next character
/// is lowercase, so acronym runs stay together (`parseHTTPResponse` becomes
/// `parse HTTP Response`). When the split yields at most one word, `name` is
/// returned unchanged.
fn split_identifier(name: &str) -> String {
    let has_separator = name.contains('_');
    let has_upper = name.chars().any(char::is_uppercase);
    if !has_separator && !has_upper {
        return name.to_string();
    }

    let letters: Vec<char> = name.chars().collect();
    let mut words: Vec<String> = Vec::new();
    let mut word = String::new();
    for (pos, &letter) in letters.iter().enumerate() {
        if letter == '_' {
            if !word.is_empty() {
                words.push(std::mem::take(&mut word));
            }
            continue;
        }

        let starts_new_word = letter.is_uppercase() && !word.is_empty() && {
            let after_lower = word.chars().last().is_some_and(char::is_lowercase);
            let before_lower = letters
                .get(pos + 1)
                .is_some_and(|next| next.is_lowercase());
            after_lower || before_lower
        };
        if starts_new_word {
            words.push(std::mem::take(&mut word));
        }
        word.push(letter);
    }
    if !word.is_empty() {
        words.push(word);
    }

    if words.len() <= 1 {
        name.to_string()
    } else {
        words.join(" ")
    }
}
1581
1582fn is_test_only_symbol(sym: &crate::db::SymbolWithFile, source: Option<&str>) -> bool {
1583 if sym.file_path.contains("/tests/") || sym.file_path.ends_with("_tests.rs") {
1584 return true;
1585 }
1586 if sym.name_path.starts_with("tests::")
1587 || sym.name_path.contains("::tests::")
1588 || sym.name_path.starts_with("test::")
1589 || sym.name_path.contains("::test::")
1590 {
1591 return true;
1592 }
1593
1594 let Some(source) = source else {
1595 return false;
1596 };
1597
1598 let start = usize::try_from(sym.start_byte.max(0))
1599 .unwrap_or(0)
1600 .min(source.len());
1601 let window_start = start.saturating_sub(2048);
1602 let attrs = String::from_utf8_lossy(&source.as_bytes()[window_start..start]);
1603 attrs.contains("#[test]")
1604 || attrs.contains("#[tokio::test]")
1605 || attrs.contains("#[cfg(test)]")
1606 || attrs.contains("#[cfg(all(test")
1607}
1608
/// Builds the text that is embedded for `sym`.
///
/// The base text has the shape `"{kind} {name}{parent}{file}"`, followed by
/// `": {signature}"` when the signature is not empty. Unless docstrings are
/// disabled through `CODELENS_EMBED_DOCSTRINGS`, a hint taken from the
/// symbol's leading doc text (or, when there is none, from its body) is
/// appended after `" — "`, followed by optional `" · NL: …"` and
/// `" · API: …"` sections.
///
/// `source` is the full text of the symbol's file; when it is `None` only the
/// base text is produced.
fn build_embedding_text(sym: &crate::db::SymbolWithFile, source: Option<&str>) -> String {
    let file_ctx = if sym.file_path.is_empty() {
        String::new()
    } else {
        format!(" in {}", sym.file_path)
    };

    // Show the identifier followed by its word-split form in parentheses when
    // splitting changed anything.
    let split_name = split_identifier(&sym.name);
    let name_with_split = if split_name != sym.name {
        format!("{} ({})", sym.name, split_name)
    } else {
        sym.name.clone()
    };

    // The parent is everything before the last '/' of the name path.
    let parent_ctx = if !sym.name_path.is_empty() && sym.name_path.contains('/') {
        let parent = sym.name_path.rsplit_once('/').map(|x| x.0).unwrap_or("");
        if parent.is_empty() {
            String::new()
        } else {
            format!(" (in {})", parent)
        }
    } else {
        String::new()
    };

    let base = if sym.signature.is_empty() {
        format!("{} {}{}{}", sym.kind, name_with_split, parent_ctx, file_ctx)
    } else {
        format!(
            "{} {}{}{}: {}",
            sym.kind, name_with_split, parent_ctx, file_ctx, sym.signature
        )
    };

    // Only the exact values "0" and "false" disable the extra hints.
    let docstrings_disabled = std::env::var("CODELENS_EMBED_DOCSTRINGS")
        .map(|v| v == "0" || v == "false")
        .unwrap_or(false);

    if docstrings_disabled {
        return base;
    }

    let docstring = source
        .and_then(|src| extract_leading_doc(src, sym.start_byte as usize, sym.end_byte as usize))
        .unwrap_or_default();

    let mut text = if docstring.is_empty() {
        // No doc text: fall back to a hint built from the symbol body.
        let body_hint = source
            .and_then(|src| extract_body_hint(src, sym.start_byte as usize, sym.end_byte as usize))
            .unwrap_or_default();
        if body_hint.is_empty() {
            base
        } else {
            format!("{} — {}", base, body_hint)
        }
    } else {
        // Keep at most `line_budget` non-empty doc lines and join them into
        // one hint.
        let line_budget = hint_line_budget();
        let lines: Vec<String> = docstring
            .lines()
            .map(str::trim)
            .filter(|line| !line.is_empty())
            .take(line_budget)
            .map(str::to_string)
            .collect();
        let hint = join_hint_lines(&lines);
        if hint.is_empty() {
            base
        } else {
            format!("{} — {}", base, hint)
        }
    };

    // Optional section with natural-language tokens from the symbol body.
    if let Some(src) = source
        && let Some(nl_tokens) =
            extract_nl_tokens(src, sym.start_byte as usize, sym.end_byte as usize)
        && !nl_tokens.is_empty()
    {
        text.push_str(" · NL: ");
        text.push_str(&nl_tokens);
    }

    // Optional section with `Type::method` style calls from the symbol body.
    if let Some(src) = source
        && let Some(api_calls) =
            extract_api_calls(src, sym.start_byte as usize, sym.end_byte as usize)
        && !api_calls.is_empty()
    {
        text.push_str(" · API: ");
        text.push_str(&api_calls);
    }

    text
}
1722
/// Fallback character budget returned by `hint_char_budget` when
/// `CODELENS_EMBED_HINT_CHARS` is unset or not a valid number.
const DEFAULT_HINT_TOTAL_CHAR_BUDGET: usize = 60;

/// Fallback line budget returned by `hint_line_budget` when
/// `CODELENS_EMBED_HINT_LINES` is unset or not a valid number.
const DEFAULT_HINT_LINES: usize = 1;
1740
1741fn hint_char_budget() -> usize {
1742 std::env::var("CODELENS_EMBED_HINT_CHARS")
1743 .ok()
1744 .and_then(|raw| raw.parse::<usize>().ok())
1745 .map(|n| n.clamp(60, 512))
1746 .unwrap_or(DEFAULT_HINT_TOTAL_CHAR_BUDGET)
1747}
1748
1749fn hint_line_budget() -> usize {
1750 std::env::var("CODELENS_EMBED_HINT_LINES")
1751 .ok()
1752 .and_then(|raw| raw.parse::<usize>().ok())
1753 .map(|n| n.clamp(1, 10))
1754 .unwrap_or(DEFAULT_HINT_LINES)
1755}
1756
1757fn join_hint_lines(lines: &[String]) -> String {
1764 if lines.is_empty() {
1765 return String::new();
1766 }
1767 let joined = lines
1768 .iter()
1769 .map(String::as_str)
1770 .collect::<Vec<_>>()
1771 .join(" · ");
1772 let budget = hint_char_budget();
1773 if joined.chars().count() > budget {
1774 let truncated: String = joined.chars().take(budget).collect();
1775 format!("{truncated}...")
1776 } else {
1777 joined
1778 }
1779}
1780
1781fn extract_body_hint(source: &str, start: usize, end: usize) -> Option<String> {
1791 if start >= source.len() || end > source.len() || start >= end {
1792 return None;
1793 }
1794 let safe_start = if source.is_char_boundary(start) {
1795 start
1796 } else {
1797 source.floor_char_boundary(start)
1798 };
1799 let safe_end = end.min(source.len());
1800 let safe_end = if source.is_char_boundary(safe_end) {
1801 safe_end
1802 } else {
1803 source.floor_char_boundary(safe_end)
1804 };
1805 let body = &source[safe_start..safe_end];
1806
1807 let max_lines = hint_line_budget();
1808 let mut collected: Vec<String> = Vec::with_capacity(max_lines);
1809
1810 let mut past_signature = false;
1813 for line in body.lines() {
1814 let trimmed = line.trim();
1815 if !past_signature {
1816 if trimmed.ends_with('{') || trimmed.ends_with(':') || trimmed == "{" {
1818 past_signature = true;
1819 }
1820 continue;
1821 }
1822 if trimmed.is_empty()
1824 || trimmed.starts_with("//")
1825 || trimmed.starts_with('#')
1826 || trimmed.starts_with("/*")
1827 || trimmed.starts_with('*')
1828 || trimmed == "}"
1829 {
1830 continue;
1831 }
1832 collected.push(trimmed.to_string());
1833 if collected.len() >= max_lines {
1834 break;
1835 }
1836 }
1837
1838 if collected.is_empty() {
1839 None
1840 } else {
1841 Some(join_hint_lines(&collected))
1842 }
1843}
1844
1845fn nl_tokens_enabled() -> bool {
1855 if let Some(explicit) = parse_bool_env("CODELENS_EMBED_HINT_INCLUDE_COMMENTS") {
1856 return explicit;
1857 }
1858 auto_hint_should_enable()
1859}
1860
1861pub(super) fn auto_hint_mode_enabled() -> bool {
1903 parse_bool_env("CODELENS_EMBED_HINT_AUTO").unwrap_or(true)
1904}
1905
1906pub(super) fn auto_hint_lang() -> Option<String> {
1917 std::env::var("CODELENS_EMBED_HINT_AUTO_LANG")
1918 .ok()
1919 .map(|raw| raw.trim().to_ascii_lowercase())
1920}
1921
1922pub(super) fn language_supports_nl_stack(lang: &str) -> bool {
1958 matches!(
1959 lang.trim().to_ascii_lowercase().as_str(),
1960 "rs" | "rust"
1961 | "cpp"
1962 | "cc"
1963 | "cxx"
1964 | "c++"
1965 | "c"
1966 | "go"
1967 | "golang"
1968 | "java"
1969 | "kt"
1970 | "kotlin"
1971 | "scala"
1972 | "cs"
1973 | "csharp"
1974 | "ts"
1975 | "typescript"
1976 | "tsx"
1977 | "js"
1978 | "javascript"
1979 | "jsx"
1980 )
1981}
1982
1983pub(super) fn language_supports_sparse_weighting(lang: &str) -> bool {
2001 matches!(
2002 lang.trim().to_ascii_lowercase().as_str(),
2003 "rs" | "rust"
2004 | "cpp"
2005 | "cc"
2006 | "cxx"
2007 | "c++"
2008 | "c"
2009 | "go"
2010 | "golang"
2011 | "java"
2012 | "kt"
2013 | "kotlin"
2014 | "scala"
2015 | "cs"
2016 | "csharp"
2017 )
2018}
2019
2020pub(super) fn auto_hint_should_enable() -> bool {
2025 if !auto_hint_mode_enabled() {
2026 return false;
2027 }
2028 match auto_hint_lang() {
2029 Some(lang) => language_supports_nl_stack(&lang),
2030 None => false, }
2032}
2033
2034pub(super) fn auto_sparse_should_enable() -> bool {
2041 if !auto_hint_mode_enabled() {
2042 return false;
2043 }
2044 match auto_hint_lang() {
2045 Some(lang) => language_supports_sparse_weighting(&lang),
2046 None => false,
2047 }
2048}
2049
2050pub(super) fn is_nl_shaped(s: &str) -> bool {
2059 let s = s.trim();
2060 if s.chars().count() < 4 {
2061 return false;
2062 }
2063 if s.contains('/') || s.contains('\\') || s.contains("::") {
2064 return false;
2065 }
2066 if !s.contains(' ') {
2067 return false;
2068 }
2069 let non_ws: usize = s.chars().filter(|c| !c.is_whitespace()).count();
2070 if non_ws == 0 {
2071 return false;
2072 }
2073 let alpha: usize = s.chars().filter(|c| c.is_alphabetic()).count();
2074 (alpha * 100) / non_ws >= 60
2075}
2076
/// Whether strict comment filtering is requested.
///
/// True when `CODELENS_EMBED_HINT_STRICT_COMMENTS` is `1`, `true`, `yes` or
/// `on` (ASCII case-insensitive; the value is not trimmed). False for any
/// other value or when the variable is unset.
fn strict_comments_enabled() -> bool {
    match std::env::var("CODELENS_EMBED_HINT_STRICT_COMMENTS") {
        Ok(raw) => ["1", "true", "yes", "on"].contains(&raw.to_ascii_lowercase().as_str()),
        Err(_) => false,
    }
}
2099
2100pub(super) fn looks_like_meta_annotation(body: &str) -> bool {
2121 let trimmed = body.trim_start();
2122 let word_end = trimmed
2125 .find(|c: char| !c.is_ascii_alphabetic())
2126 .unwrap_or(trimmed.len());
2127 if word_end == 0 {
2128 return false;
2129 }
2130 let first_word = &trimmed[..word_end];
2131 let upper = first_word.to_ascii_uppercase();
2132 matches!(
2133 upper.as_str(),
2134 "TODO"
2135 | "FIXME"
2136 | "HACK"
2137 | "XXX"
2138 | "BUG"
2139 | "REVIEW"
2140 | "REFACTOR"
2141 | "TEMP"
2142 | "TEMPORARY"
2143 | "DEPRECATED"
2144 )
2145}
2146
/// Whether strict string-literal filtering is requested.
///
/// True when `CODELENS_EMBED_HINT_STRICT_LITERALS` is `1`, `true`, `yes` or
/// `on` (ASCII case-insensitive; the value is not trimmed). False for any
/// other value or when the variable is unset.
fn strict_literal_filter_enabled() -> bool {
    match std::env::var("CODELENS_EMBED_HINT_STRICT_LITERALS") {
        Ok(raw) => ["1", "true", "yes", "on"].contains(&raw.to_ascii_lowercase().as_str()),
        Err(_) => false,
    }
}
2169
2170pub(super) fn contains_format_specifier(s: &str) -> bool {
2182 let bytes = s.as_bytes();
2183 let len = bytes.len();
2184 let mut i = 0;
2185 while i + 1 < len {
2186 if bytes[i] == b'%' {
2187 let next = bytes[i + 1];
2188 if matches!(next, b's' | b'd' | b'r' | b'f' | b'x' | b'o' | b'i' | b'u') {
2189 return true;
2190 }
2191 }
2192 i += 1;
2193 }
2194 for window in s.split('{').skip(1) {
2202 let Some(close_idx) = window.find('}') else {
2203 continue;
2204 };
2205 let inside = &window[..close_idx];
2206 if inside.is_empty() {
2208 return true;
2209 }
2210 if inside.chars().any(|c| c.is_whitespace()) {
2212 continue;
2213 }
2214 if inside.starts_with(':') {
2216 return true;
2217 }
2218 let ident_end = inside.find(':').unwrap_or(inside.len());
2222 let ident = &inside[..ident_end];
2223 if !ident.is_empty()
2224 && ident
2225 .chars()
2226 .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '.')
2227 {
2228 return true;
2229 }
2230 }
2231 false
2232}
2233
2234pub(super) fn looks_like_error_or_log_prefix(s: &str) -> bool {
2245 let lower = s.trim().to_lowercase();
2246 const PREFIXES: &[&str] = &[
2247 "invalid ",
2248 "cannot ",
2249 "could not ",
2250 "unable to ",
2251 "failed to ",
2252 "expected ",
2253 "unexpected ",
2254 "missing ",
2255 "not found",
2256 "error: ",
2257 "error ",
2258 "warning: ",
2259 "warning ",
2260 "sending ",
2261 "received ",
2262 "starting ",
2263 "stopping ",
2264 "calling ",
2265 "connecting ",
2266 "disconnecting ",
2267 ];
2268 PREFIXES.iter().any(|p| lower.starts_with(p))
2269}
2270
/// Test-only helper combining the two strict literal filters: true when `s`
/// contains a format placeholder or starts like an error/log message.
#[cfg(test)]
pub(super) fn should_reject_literal_strict(s: &str) -> bool {
    if contains_format_specifier(s) {
        return true;
    }
    looks_like_error_or_log_prefix(s)
}
2279
2280fn extract_nl_tokens(source: &str, start: usize, end: usize) -> Option<String> {
2294 if !nl_tokens_enabled() {
2295 return None;
2296 }
2297 extract_nl_tokens_inner(source, start, end)
2298}
2299
2300pub(super) fn extract_nl_tokens_inner(source: &str, start: usize, end: usize) -> Option<String> {
2305 if start >= source.len() || end > source.len() || start >= end {
2306 return None;
2307 }
2308 let safe_start = if source.is_char_boundary(start) {
2309 start
2310 } else {
2311 source.floor_char_boundary(start)
2312 };
2313 let safe_end = end.min(source.len());
2314 let safe_end = if source.is_char_boundary(safe_end) {
2315 safe_end
2316 } else {
2317 source.floor_char_boundary(safe_end)
2318 };
2319 let body = &source[safe_start..safe_end];
2320
2321 let mut tokens: Vec<String> = Vec::new();
2322
2323 let strict_comments = strict_comments_enabled();
2331 for line in body.lines() {
2332 let trimmed = line.trim();
2333 if let Some(cleaned) = extract_comment_body(trimmed)
2334 && is_nl_shaped(&cleaned)
2335 && (!strict_comments || !looks_like_meta_annotation(&cleaned))
2336 {
2337 tokens.push(cleaned);
2338 }
2339 }
2340
2341 let strict_literals = strict_literal_filter_enabled();
2351 let mut chars = body.chars().peekable();
2352 let mut in_string = false;
2353 let mut current = String::new();
2354 while let Some(c) = chars.next() {
2355 if in_string {
2356 if c == '\\' {
2357 let _ = chars.next();
2359 } else if c == '"' {
2360 if is_nl_shaped(¤t)
2361 && (!strict_literals
2362 || (!contains_format_specifier(¤t)
2363 && !looks_like_error_or_log_prefix(¤t)))
2364 {
2365 tokens.push(current.clone());
2366 }
2367 current.clear();
2368 in_string = false;
2369 } else {
2370 current.push(c);
2371 }
2372 } else if c == '"' {
2373 in_string = true;
2374 }
2375 }
2376
2377 if tokens.is_empty() {
2378 return None;
2379 }
2380 Some(join_hint_lines(&tokens))
2381}
2382
2383fn api_calls_enabled() -> bool {
2392 if let Some(explicit) = parse_bool_env("CODELENS_EMBED_HINT_INCLUDE_API_CALLS") {
2393 return explicit;
2394 }
2395 auto_hint_should_enable()
2396}
2397
2398pub(super) fn is_static_method_ident(ident: &str) -> bool {
2408 ident.chars().next().is_some_and(|c| c.is_ascii_uppercase())
2409}
2410
2411fn extract_api_calls(source: &str, start: usize, end: usize) -> Option<String> {
2423 if !api_calls_enabled() {
2424 return None;
2425 }
2426 extract_api_calls_inner(source, start, end)
2427}
2428
/// Scans the byte range `start..end` of `source` for `Type::member` pairs and
/// returns them, de-duplicated in first-seen order, joined by
/// `join_hint_lines`.
///
/// A pair is recorded when an ASCII identifier accepted by
/// `is_static_method_ident` is directly followed by `::` and another ASCII
/// identifier. Lowercase path segments (for example `std::`) are skipped.
/// The scan is purely textual, so matches inside comments or string literals
/// are recorded as well.
///
/// Returns `None` for an invalid range or when nothing matched. Offsets
/// inside a multi-byte character are moved back to the previous character
/// boundary.
pub(super) fn extract_api_calls_inner(source: &str, start: usize, end: usize) -> Option<String> {
    if start >= source.len() || end > source.len() || start >= end {
        return None;
    }
    let safe_start = if source.is_char_boundary(start) {
        start
    } else {
        source.floor_char_boundary(start)
    };
    let safe_end = end.min(source.len());
    let safe_end = if source.is_char_boundary(safe_end) {
        safe_end
    } else {
        source.floor_char_boundary(safe_end)
    };
    if safe_start >= safe_end {
        return None;
    }
    let body = &source[safe_start..safe_end];
    let bytes = body.as_bytes();
    let len = bytes.len();

    // `calls` keeps first-seen order; `seen` provides the duplicate check.
    let mut calls: Vec<String> = Vec::new();
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();

    let mut i = 0usize;
    while i < len {
        let b = bytes[i];
        // Advance until a byte that can start an identifier.
        if !(b == b'_' || b.is_ascii_alphabetic()) {
            i += 1;
            continue;
        }
        // Consume the identifier; `i` ends up one past its last byte.
        let ident_start = i;
        while i < len {
            let bb = bytes[i];
            if bb == b'_' || bb.is_ascii_alphanumeric() {
                i += 1;
            } else {
                break;
            }
        }
        let ident_end = i;

        // Only identifiers immediately followed by "::" are of interest.
        if i + 1 >= len || bytes[i] != b':' || bytes[i + 1] != b':' {
            continue;
        }

        // Identifier bytes are ASCII, so these offsets are valid `str`
        // boundaries.
        let type_ident = &body[ident_start..ident_end];
        if !is_static_method_ident(type_ident) {
            // Not type-like: step over the "::" and keep scanning.
            i += 2;
            continue;
        }

        // `j` walks the member name that follows the "::".
        let mut j = i + 2;
        if j >= len || !(bytes[j] == b'_' || bytes[j].is_ascii_alphabetic()) {
            i = j;
            continue;
        }
        let method_start = j;
        while j < len {
            let bb = bytes[j];
            if bb == b'_' || bb.is_ascii_alphanumeric() {
                j += 1;
            } else {
                break;
            }
        }
        let method_end = j;

        let method_ident = &body[method_start..method_end];
        let call = format!("{type_ident}::{method_ident}");
        if seen.insert(call.clone()) {
            calls.push(call);
        }
        // Resume after the member name; in `A::B::c` only `A::B` is recorded.
        i = j;
    }

    if calls.is_empty() {
        return None;
    }
    Some(join_hint_lines(&calls))
}
2529
/// Returns the text of a single comment line with its marker removed, or
/// `None` when the (already trimmed) line is not a comment.
///
/// Recognised forms: `///`, `//!` and `//` line comments; `#` comments
/// (but not `#[` or `#!`); `/** … */` and `/* … */` openers; and `*`
/// continuation lines of a block comment. A `*` line is rejected when it is
/// empty after stripping or contains `;` or `{`.
fn extract_comment_body(trimmed: &str) -> Option<String> {
    if trimmed.is_empty() {
        return None;
    }

    // Longest marker first so `///` is not read as `//` plus a stray slash.
    for marker in ["///", "//!", "//"] {
        if let Some(rest) = trimmed.strip_prefix(marker) {
            return Some(rest.trim().to_string());
        }
    }

    if let Some(rest) = trimmed.strip_prefix('#') {
        // `#[` and `#!` lines are never treated as comments.
        let is_directive = rest.starts_with('[') || rest.starts_with('!');
        return if is_directive {
            None
        } else {
            Some(rest.trim().to_string())
        };
    }

    for opener in ["/**", "/*"] {
        if let Some(rest) = trimmed.strip_prefix(opener) {
            return Some(rest.trim_end_matches("*/").trim().to_string());
        }
    }

    let rest = trimmed.strip_prefix('*')?.trim_end_matches("*/").trim();
    if rest.is_empty() || rest.contains(';') || rest.contains('{') {
        None
    } else {
        Some(rest.to_string())
    }
}
2576
2577fn extract_leading_doc(source: &str, start: usize, end: usize) -> Option<String> {
2580 if start >= source.len() || end > source.len() || start >= end {
2581 return None;
2582 }
2583 let safe_start = if source.is_char_boundary(start) {
2585 start
2586 } else {
2587 source.floor_char_boundary(start)
2588 };
2589 let safe_end = end.min(source.len());
2590 let safe_end = if source.is_char_boundary(safe_end) {
2591 safe_end
2592 } else {
2593 source.floor_char_boundary(safe_end)
2594 };
2595 if safe_start >= safe_end {
2596 return None;
2597 }
2598 let body = &source[safe_start..safe_end];
2599 let lines: Vec<&str> = body.lines().skip(1).collect(); if lines.is_empty() {
2601 return None;
2602 }
2603
2604 let mut doc_lines = Vec::new();
2605
2606 let first_trimmed = lines.first().map(|l| l.trim()).unwrap_or_default();
2608 if first_trimmed.starts_with("\"\"\"") || first_trimmed.starts_with("'''") {
2609 let quote = &first_trimmed[..3];
2610 for line in &lines {
2611 let t = line.trim();
2612 doc_lines.push(t.trim_start_matches(quote).trim_end_matches(quote));
2613 if doc_lines.len() > 1 && t.ends_with(quote) {
2614 break;
2615 }
2616 }
2617 }
2618 else if first_trimmed.starts_with("///") || first_trimmed.starts_with("//!") {
2620 for line in &lines {
2621 let t = line.trim();
2622 if t.starts_with("///") || t.starts_with("//!") {
2623 doc_lines.push(t.trim_start_matches("///").trim_start_matches("//!").trim());
2624 } else {
2625 break;
2626 }
2627 }
2628 }
2629 else if first_trimmed.starts_with("/**") {
2631 for line in &lines {
2632 let t = line.trim();
2633 let cleaned = t
2634 .trim_start_matches("/**")
2635 .trim_start_matches('*')
2636 .trim_end_matches("*/")
2637 .trim();
2638 if !cleaned.is_empty() {
2639 doc_lines.push(cleaned);
2640 }
2641 if t.ends_with("*/") {
2642 break;
2643 }
2644 }
2645 }
2646 else {
2648 for line in &lines {
2649 let t = line.trim();
2650 if t.starts_with("//") || t.starts_with('#') {
2651 doc_lines.push(t.trim_start_matches("//").trim_start_matches('#').trim());
2652 } else {
2653 break;
2654 }
2655 }
2656 }
2657
2658 if doc_lines.is_empty() {
2659 return None;
2660 }
2661 Some(doc_lines.join(" ").trim().to_owned())
2662}
2663
2664pub(super) fn embedding_to_bytes(embedding: &[f32]) -> Vec<u8> {
2665 embedding.iter().flat_map(|f| f.to_le_bytes()).collect()
2666}
2667
2668#[cfg(test)]
2669mod tests {
2670 use super::*;
2671 use crate::db::{IndexDb, NewSymbol};
2672 use std::sync::Mutex;
2673
2674 static MODEL_LOCK: Mutex<()> = Mutex::new(());
2676
2677 static ENV_LOCK: Mutex<()> = Mutex::new(());
2684
2685 macro_rules! skip_without_embedding_model {
2686 () => {
2687 if !super::embedding_model_assets_available() {
2688 eprintln!("skipping embedding test: CodeSearchNet model assets unavailable");
2689 return;
2690 }
2691 };
2692 }
2693
2694 fn make_project_with_source() -> (tempfile::TempDir, ProjectRoot) {
2696 let dir = tempfile::tempdir().unwrap();
2697 let root = dir.path();
2698
2699 let source = "def hello():\n print('hi')\n\ndef world():\n return 42\n";
2701 write_python_file_with_symbols(
2702 root,
2703 "main.py",
2704 source,
2705 "hash1",
2706 &[
2707 ("hello", "def hello():", "hello"),
2708 ("world", "def world():", "world"),
2709 ],
2710 );
2711
2712 let project = ProjectRoot::new_exact(root).unwrap();
2713 (dir, project)
2714 }
2715
2716 fn write_python_file_with_symbols(
2717 root: &std::path::Path,
2718 relative_path: &str,
2719 source: &str,
2720 hash: &str,
2721 symbols: &[(&str, &str, &str)],
2722 ) {
2723 std::fs::write(root.join(relative_path), source).unwrap();
2724 let db_path = crate::db::index_db_path(root);
2725 let db = IndexDb::open(&db_path).unwrap();
2726 let file_id = db
2727 .upsert_file(relative_path, 100, hash, source.len() as i64, Some("py"))
2728 .unwrap();
2729
2730 let new_symbols: Vec<NewSymbol<'_>> = symbols
2731 .iter()
2732 .map(|(name, signature, name_path)| {
2733 let start = source.find(signature).unwrap() as i64;
2734 let end = source[start as usize..]
2735 .find("\n\ndef ")
2736 .map(|offset| start + offset as i64)
2737 .unwrap_or(source.len() as i64);
2738 let line = source[..start as usize]
2739 .bytes()
2740 .filter(|&b| b == b'\n')
2741 .count() as i64
2742 + 1;
2743 NewSymbol {
2744 name,
2745 kind: "function",
2746 line,
2747 column_num: 0,
2748 start_byte: start,
2749 end_byte: end,
2750 signature,
2751 name_path,
2752 parent_id: None,
2753 }
2754 })
2755 .collect();
2756 db.insert_symbols(file_id, &new_symbols).unwrap();
2757 }
2758
2759 fn replace_file_embeddings_with_sentinels(
2760 engine: &EmbeddingEngine,
2761 file_path: &str,
2762 sentinels: &[(&str, f32)],
2763 ) {
2764 let mut chunks = engine.store.embeddings_for_files(&[file_path]).unwrap();
2765 for chunk in &mut chunks {
2766 if let Some((_, value)) = sentinels
2767 .iter()
2768 .find(|(symbol_name, _)| *symbol_name == chunk.symbol_name)
2769 {
2770 chunk.embedding = vec![*value; chunk.embedding.len()];
2771 }
2772 }
2773 engine.store.delete_by_file(&[file_path]).unwrap();
2774 engine.store.insert(&chunks).unwrap();
2775 }
2776
2777 #[test]
2778 fn build_embedding_text_with_signature() {
2779 let sym = crate::db::SymbolWithFile {
2780 name: "hello".into(),
2781 kind: "function".into(),
2782 file_path: "main.py".into(),
2783 line: 1,
2784 signature: "def hello():".into(),
2785 name_path: "hello".into(),
2786 start_byte: 0,
2787 end_byte: 10,
2788 };
2789 let text = build_embedding_text(&sym, Some("def hello(): pass"));
2790 assert_eq!(text, "function hello in main.py: def hello():");
2791 }
2792
2793 #[test]
2794 fn build_embedding_text_without_source() {
2795 let sym = crate::db::SymbolWithFile {
2796 name: "MyClass".into(),
2797 kind: "class".into(),
2798 file_path: "app.py".into(),
2799 line: 5,
2800 signature: "class MyClass:".into(),
2801 name_path: "MyClass".into(),
2802 start_byte: 0,
2803 end_byte: 50,
2804 };
2805 let text = build_embedding_text(&sym, None);
2806 assert_eq!(text, "class MyClass (My Class) in app.py: class MyClass:");
2807 }
2808
2809 #[test]
2810 fn build_embedding_text_empty_signature() {
2811 let sym = crate::db::SymbolWithFile {
2812 name: "CONFIG".into(),
2813 kind: "variable".into(),
2814 file_path: "config.py".into(),
2815 line: 1,
2816 signature: String::new(),
2817 name_path: "CONFIG".into(),
2818 start_byte: 0,
2819 end_byte: 0,
2820 };
2821 let text = build_embedding_text(&sym, None);
2822 assert_eq!(text, "variable CONFIG in config.py");
2823 }
2824
2825 #[test]
2826 fn filters_direct_test_symbols_from_embedding_index() {
2827 let source = "#[test]\nfn alias_case() {}\n";
2828 let sym = crate::db::SymbolWithFile {
2829 name: "alias_case".into(),
2830 kind: "function".into(),
2831 file_path: "src/lib.rs".into(),
2832 line: 2,
2833 signature: "fn alias_case() {}".into(),
2834 name_path: "alias_case".into(),
2835 start_byte: source.find("fn alias_case").unwrap() as i64,
2836 end_byte: source.len() as i64,
2837 };
2838
2839 assert!(is_test_only_symbol(&sym, Some(source)));
2840 }
2841
2842 #[test]
2843 fn filters_cfg_test_module_symbols_from_embedding_index() {
2844 let source = "#[cfg(all(test, feature = \"semantic\"))]\nmod semantic_tests {\n fn helper_case() {}\n}\n";
2845 let sym = crate::db::SymbolWithFile {
2846 name: "helper_case".into(),
2847 kind: "function".into(),
2848 file_path: "src/lib.rs".into(),
2849 line: 3,
2850 signature: "fn helper_case() {}".into(),
2851 name_path: "helper_case".into(),
2852 start_byte: source.find("fn helper_case").unwrap() as i64,
2853 end_byte: source.len() as i64,
2854 };
2855
2856 assert!(is_test_only_symbol(&sym, Some(source)));
2857 }
2858
2859 #[test]
2860 fn extract_python_docstring() {
2861 let source =
2862 "def greet(name):\n \"\"\"Say hello to a person.\"\"\"\n print(f'hi {name}')\n";
2863 let doc = extract_leading_doc(source, 0, source.len()).unwrap();
2864 assert!(doc.contains("Say hello to a person"));
2865 }
2866
2867 #[test]
2868 fn extract_rust_doc_comment() {
2869 let source = "fn dispatch_tool() {\n /// Route incoming tool requests.\n /// Handles all MCP methods.\n let x = 1;\n}\n";
2870 let doc = extract_leading_doc(source, 0, source.len()).unwrap();
2871 assert!(doc.contains("Route incoming tool requests"));
2872 assert!(doc.contains("Handles all MCP methods"));
2873 }
2874
2875 #[test]
2876 fn extract_leading_doc_returns_none_for_no_doc() {
2877 let source = "def f():\n return 1\n";
2878 assert!(extract_leading_doc(source, 0, source.len()).is_none());
2879 }
2880
2881 #[test]
2882 fn extract_body_hint_finds_first_meaningful_line() {
2883 let source = "pub fn parse_symbols(\n project: &ProjectRoot,\n) -> Vec<SymbolInfo> {\n let mut parser = tree_sitter::Parser::new();\n parser.set_language(lang);\n}\n";
2884 let hint = extract_body_hint(source, 0, source.len());
2885 assert!(hint.is_some());
2886 assert!(hint.unwrap().contains("tree_sitter::Parser"));
2887 }
2888
2889 #[test]
2890 fn extract_body_hint_skips_comments() {
2891 let source = "fn foo() {\n // setup\n let x = bar();\n}\n";
2892 let hint = extract_body_hint(source, 0, source.len());
2893 assert_eq!(hint.unwrap(), "let x = bar();");
2894 }
2895
2896 #[test]
2897 fn extract_body_hint_returns_none_for_empty() {
2898 let source = "fn empty() {\n}\n";
2899 let hint = extract_body_hint(source, 0, source.len());
2900 assert!(hint.is_none());
2901 }
2902
2903 #[test]
2904 fn extract_body_hint_multi_line_collection_via_env_override() {
2905 let previous_lines = std::env::var("CODELENS_EMBED_HINT_LINES").ok();
2910 let previous_chars = std::env::var("CODELENS_EMBED_HINT_CHARS").ok();
2911 unsafe {
2912 std::env::set_var("CODELENS_EMBED_HINT_LINES", "3");
2913 std::env::set_var("CODELENS_EMBED_HINT_CHARS", "200");
2914 }
2915
2916 let source = "\
2917fn route_request() {
2918 let kind = detect_request_kind();
2919 let target = dispatch_table.get(&kind);
2920 return target.handle();
2921}
2922";
2923 let hint = extract_body_hint(source, 0, source.len()).expect("hint present");
2924
2925 let env_restore = || unsafe {
2926 match &previous_lines {
2927 Some(value) => std::env::set_var("CODELENS_EMBED_HINT_LINES", value),
2928 None => std::env::remove_var("CODELENS_EMBED_HINT_LINES"),
2929 }
2930 match &previous_chars {
2931 Some(value) => std::env::set_var("CODELENS_EMBED_HINT_CHARS", value),
2932 None => std::env::remove_var("CODELENS_EMBED_HINT_CHARS"),
2933 }
2934 };
2935
2936 let all_three = hint.contains("detect_request_kind")
2937 && hint.contains("dispatch_table")
2938 && hint.contains("target.handle");
2939 let has_separator = hint.contains(" · ");
2940 env_restore();
2941
2942 assert!(all_three, "missing one of three body lines: {hint}");
2943 assert!(has_separator, "missing · separator: {hint}");
2944 }
2945
2946 #[test]
2957 fn hint_line_budget_respects_env_override() {
2958 let previous = std::env::var("CODELENS_EMBED_HINT_LINES").ok();
2961 unsafe {
2962 std::env::set_var("CODELENS_EMBED_HINT_LINES", "5");
2963 }
2964 let budget = super::hint_line_budget();
2965 unsafe {
2966 match previous {
2967 Some(value) => std::env::set_var("CODELENS_EMBED_HINT_LINES", value),
2968 None => std::env::remove_var("CODELENS_EMBED_HINT_LINES"),
2969 }
2970 }
2971 assert_eq!(budget, 5);
2972 }
2973
2974 #[test]
2975 fn is_nl_shaped_accepts_multi_word_prose() {
2976 assert!(super::is_nl_shaped("skip comments and string literals"));
2977 assert!(super::is_nl_shaped("failed to open database"));
2978 assert!(super::is_nl_shaped("detect client version"));
2979 }
2980
2981 #[test]
2982 fn is_nl_shaped_rejects_code_and_paths() {
2983 assert!(!super::is_nl_shaped("crates/codelens-engine/src"));
2985 assert!(!super::is_nl_shaped("C:\\Users\\foo"));
2986 assert!(!super::is_nl_shaped("std::sync::Mutex"));
2988 assert!(!super::is_nl_shaped("detect_client"));
2990 assert!(!super::is_nl_shaped("ok"));
2992 assert!(!super::is_nl_shaped(""));
2993 assert!(!super::is_nl_shaped("1 2 3 4 5"));
2995 }
2996
2997 #[test]
2998 fn extract_comment_body_strips_comment_markers() {
2999 assert_eq!(
3000 super::extract_comment_body("/// rust doc comment"),
3001 Some("rust doc comment".to_string())
3002 );
3003 assert_eq!(
3004 super::extract_comment_body("// regular line comment"),
3005 Some("regular line comment".to_string())
3006 );
3007 assert_eq!(
3008 super::extract_comment_body("# python line comment"),
3009 Some("python line comment".to_string())
3010 );
3011 assert_eq!(
3012 super::extract_comment_body("/* inline block */"),
3013 Some("inline block".to_string())
3014 );
3015 assert_eq!(
3016 super::extract_comment_body("* continuation line"),
3017 Some("continuation line".to_string())
3018 );
3019 }
3020
3021 #[test]
3022 fn extract_comment_body_rejects_rust_attributes_and_shebangs() {
3023 assert!(super::extract_comment_body("#[derive(Debug)]").is_none());
3024 assert!(super::extract_comment_body("#[test]").is_none());
3025 assert!(super::extract_comment_body("#!/usr/bin/env python").is_none());
3026 }
3027
3028 #[test]
3029 fn extract_nl_tokens_gated_off_by_default() {
3030 let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
3031 let previous = std::env::var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS").ok();
3033 unsafe {
3034 std::env::remove_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS");
3035 }
3036 let source = "\
3037fn skip_things() {
3038 // skip comments and string literals during search
3039 let lit = \"scan for matching tokens\";
3040}
3041";
3042 let result = extract_nl_tokens(source, 0, source.len());
3043 unsafe {
3044 if let Some(value) = previous {
3045 std::env::set_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS", value);
3046 }
3047 }
3048 assert!(result.is_none(), "gate leaked: {result:?}");
3049 }
3050
3051 #[test]
3052 fn auto_hint_mode_defaults_on_unless_explicit_off() {
3053 let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
3054 let previous = std::env::var("CODELENS_EMBED_HINT_AUTO").ok();
3062
3063 unsafe {
3065 std::env::remove_var("CODELENS_EMBED_HINT_AUTO");
3066 }
3067 let default_enabled = super::auto_hint_mode_enabled();
3068 assert!(
3069 default_enabled,
3070 "v1.6.0 default flip: auto hint mode should be ON when env unset"
3071 );
3072
3073 unsafe {
3075 std::env::set_var("CODELENS_EMBED_HINT_AUTO", "0");
3076 }
3077 let explicit_off = super::auto_hint_mode_enabled();
3078 assert!(
3079 !explicit_off,
3080 "explicit CODELENS_EMBED_HINT_AUTO=0 must still disable (opt-out escape hatch)"
3081 );
3082
3083 unsafe {
3085 std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
3086 }
3087 let explicit_on = super::auto_hint_mode_enabled();
3088 assert!(
3089 explicit_on,
3090 "explicit CODELENS_EMBED_HINT_AUTO=1 must still enable"
3091 );
3092
3093 unsafe {
3095 match previous {
3096 Some(v) => std::env::set_var("CODELENS_EMBED_HINT_AUTO", v),
3097 None => std::env::remove_var("CODELENS_EMBED_HINT_AUTO"),
3098 }
3099 }
3100 }
3101
3102 #[test]
3103 fn language_supports_nl_stack_classifies_correctly() {
3104 assert!(super::language_supports_nl_stack("rs"));
3106 assert!(super::language_supports_nl_stack("rust"));
3107 assert!(super::language_supports_nl_stack("cpp"));
3108 assert!(super::language_supports_nl_stack("c++"));
3109 assert!(super::language_supports_nl_stack("c"));
3110 assert!(super::language_supports_nl_stack("go"));
3111 assert!(super::language_supports_nl_stack("golang"));
3112 assert!(super::language_supports_nl_stack("java"));
3113 assert!(super::language_supports_nl_stack("kt"));
3114 assert!(super::language_supports_nl_stack("kotlin"));
3115 assert!(super::language_supports_nl_stack("scala"));
3116 assert!(super::language_supports_nl_stack("cs"));
3117 assert!(super::language_supports_nl_stack("csharp"));
3118 assert!(super::language_supports_nl_stack("ts"));
3121 assert!(super::language_supports_nl_stack("typescript"));
3122 assert!(super::language_supports_nl_stack("tsx"));
3123 assert!(super::language_supports_nl_stack("js"));
3124 assert!(super::language_supports_nl_stack("javascript"));
3125 assert!(super::language_supports_nl_stack("jsx"));
3126 assert!(super::language_supports_nl_stack("Rust"));
3128 assert!(super::language_supports_nl_stack("RUST"));
3129 assert!(super::language_supports_nl_stack("TypeScript"));
3130 assert!(super::language_supports_nl_stack(" rust "));
3132 assert!(super::language_supports_nl_stack(" ts "));
3133
3134 assert!(!super::language_supports_nl_stack("py"));
3136 assert!(!super::language_supports_nl_stack("python"));
3137 assert!(!super::language_supports_nl_stack("rb"));
3138 assert!(!super::language_supports_nl_stack("ruby"));
3139 assert!(!super::language_supports_nl_stack("php"));
3140 assert!(!super::language_supports_nl_stack("lua"));
3141 assert!(!super::language_supports_nl_stack("sh"));
3142 assert!(!super::language_supports_nl_stack("klingon"));
3144 assert!(!super::language_supports_nl_stack(""));
3145 }
3146
3147 #[test]
3148 fn language_supports_sparse_weighting_classifies_correctly() {
3149 assert!(super::language_supports_sparse_weighting("rs"));
3150 assert!(super::language_supports_sparse_weighting("rust"));
3151 assert!(super::language_supports_sparse_weighting("cpp"));
3152 assert!(super::language_supports_sparse_weighting("go"));
3153 assert!(super::language_supports_sparse_weighting("java"));
3154 assert!(super::language_supports_sparse_weighting("kotlin"));
3155 assert!(super::language_supports_sparse_weighting("csharp"));
3156
3157 assert!(!super::language_supports_sparse_weighting("ts"));
3158 assert!(!super::language_supports_sparse_weighting("typescript"));
3159 assert!(!super::language_supports_sparse_weighting("tsx"));
3160 assert!(!super::language_supports_sparse_weighting("js"));
3161 assert!(!super::language_supports_sparse_weighting("javascript"));
3162 assert!(!super::language_supports_sparse_weighting("jsx"));
3163 assert!(!super::language_supports_sparse_weighting("py"));
3164 assert!(!super::language_supports_sparse_weighting("python"));
3165 assert!(!super::language_supports_sparse_weighting("klingon"));
3166 assert!(!super::language_supports_sparse_weighting(""));
3167 }
3168
3169 #[test]
3170 fn auto_hint_should_enable_requires_both_gate_and_supported_lang() {
3171 let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
3172 let prev_auto = std::env::var("CODELENS_EMBED_HINT_AUTO").ok();
3173 let prev_lang = std::env::var("CODELENS_EMBED_HINT_AUTO_LANG").ok();
3174
3175 unsafe {
3179 std::env::set_var("CODELENS_EMBED_HINT_AUTO", "0");
3180 std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "rust");
3181 }
3182 assert!(
3183 !super::auto_hint_should_enable(),
3184 "gate-off (explicit =0) with lang=rust must stay disabled"
3185 );
3186
3187 unsafe {
3189 std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
3190 std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "rust");
3191 }
3192 assert!(
3193 super::auto_hint_should_enable(),
3194 "gate-on + lang=rust must enable"
3195 );
3196
3197 unsafe {
3198 std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
3199 std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "typescript");
3200 }
3201 assert!(
3202 super::auto_hint_should_enable(),
3203 "gate-on + lang=typescript must keep Phase 2b/2c enabled"
3204 );
3205
3206 unsafe {
3208 std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
3209 std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "python");
3210 }
3211 assert!(
3212 !super::auto_hint_should_enable(),
3213 "gate-on + lang=python must stay disabled"
3214 );
3215
3216 unsafe {
3218 std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
3219 std::env::remove_var("CODELENS_EMBED_HINT_AUTO_LANG");
3220 }
3221 assert!(
3222 !super::auto_hint_should_enable(),
3223 "gate-on + no lang tag must stay disabled"
3224 );
3225
3226 unsafe {
3228 match prev_auto {
3229 Some(v) => std::env::set_var("CODELENS_EMBED_HINT_AUTO", v),
3230 None => std::env::remove_var("CODELENS_EMBED_HINT_AUTO"),
3231 }
3232 match prev_lang {
3233 Some(v) => std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", v),
3234 None => std::env::remove_var("CODELENS_EMBED_HINT_AUTO_LANG"),
3235 }
3236 }
3237 }
3238
3239 #[test]
3240 fn auto_sparse_should_enable_requires_both_gate_and_sparse_supported_lang() {
3241 let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
3242 let prev_auto = std::env::var("CODELENS_EMBED_HINT_AUTO").ok();
3243 let prev_lang = std::env::var("CODELENS_EMBED_HINT_AUTO_LANG").ok();
3244
3245 unsafe {
3246 std::env::set_var("CODELENS_EMBED_HINT_AUTO", "0");
3247 std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "rust");
3248 }
3249 assert!(
3250 !super::auto_sparse_should_enable(),
3251 "gate-off (explicit =0) must disable sparse auto gate"
3252 );
3253
3254 unsafe {
3255 std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
3256 std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "rust");
3257 }
3258 assert!(
3259 super::auto_sparse_should_enable(),
3260 "gate-on + lang=rust must enable sparse auto gate"
3261 );
3262
3263 unsafe {
3264 std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
3265 std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "typescript");
3266 }
3267 assert!(
3268 !super::auto_sparse_should_enable(),
3269 "gate-on + lang=typescript must keep sparse auto gate disabled"
3270 );
3271
3272 unsafe {
3273 std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
3274 std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "python");
3275 }
3276 assert!(
3277 !super::auto_sparse_should_enable(),
3278 "gate-on + lang=python must keep sparse auto gate disabled"
3279 );
3280
3281 unsafe {
3282 std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
3283 std::env::remove_var("CODELENS_EMBED_HINT_AUTO_LANG");
3284 }
3285 assert!(
3286 !super::auto_sparse_should_enable(),
3287 "gate-on + no lang tag must keep sparse auto gate disabled"
3288 );
3289
3290 unsafe {
3291 match prev_auto {
3292 Some(v) => std::env::set_var("CODELENS_EMBED_HINT_AUTO", v),
3293 None => std::env::remove_var("CODELENS_EMBED_HINT_AUTO"),
3294 }
3295 match prev_lang {
3296 Some(v) => std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", v),
3297 None => std::env::remove_var("CODELENS_EMBED_HINT_AUTO_LANG"),
3298 }
3299 }
3300 }
3301
3302 #[test]
3303 fn nl_tokens_enabled_explicit_env_wins_over_auto() {
3304 let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
3305 let prev_explicit = std::env::var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS").ok();
3306 let prev_auto = std::env::var("CODELENS_EMBED_HINT_AUTO").ok();
3307 let prev_lang = std::env::var("CODELENS_EMBED_HINT_AUTO_LANG").ok();
3308
3309 unsafe {
3311 std::env::set_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS", "1");
3312 std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
3313 std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "python");
3314 }
3315 assert!(
3316 super::nl_tokens_enabled(),
3317 "explicit=1 must win over auto+python=off"
3318 );
3319
3320 unsafe {
3322 std::env::set_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS", "0");
3323 std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
3324 std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "rust");
3325 }
3326 assert!(
3327 !super::nl_tokens_enabled(),
3328 "explicit=0 must win over auto+rust=on"
3329 );
3330
3331 unsafe {
3333 std::env::remove_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS");
3334 std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
3335 std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "rust");
3336 }
3337 assert!(
3338 super::nl_tokens_enabled(),
3339 "no explicit + auto+rust must enable"
3340 );
3341
3342 unsafe {
3344 std::env::remove_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS");
3345 std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
3346 std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "python");
3347 }
3348 assert!(
3349 !super::nl_tokens_enabled(),
3350 "no explicit + auto+python must stay disabled"
3351 );
3352
3353 unsafe {
3355 match prev_explicit {
3356 Some(v) => std::env::set_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS", v),
3357 None => std::env::remove_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS"),
3358 }
3359 match prev_auto {
3360 Some(v) => std::env::set_var("CODELENS_EMBED_HINT_AUTO", v),
3361 None => std::env::remove_var("CODELENS_EMBED_HINT_AUTO"),
3362 }
3363 match prev_lang {
3364 Some(v) => std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", v),
3365 None => std::env::remove_var("CODELENS_EMBED_HINT_AUTO_LANG"),
3366 }
3367 }
3368 }
3369
3370 #[test]
3371 fn strict_comments_gated_off_by_default() {
3372 let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
3373 let previous = std::env::var("CODELENS_EMBED_HINT_STRICT_COMMENTS").ok();
3374 unsafe {
3375 std::env::remove_var("CODELENS_EMBED_HINT_STRICT_COMMENTS");
3376 }
3377 let enabled = super::strict_comments_enabled();
3378 unsafe {
3379 if let Some(value) = previous {
3380 std::env::set_var("CODELENS_EMBED_HINT_STRICT_COMMENTS", value);
3381 }
3382 }
3383 assert!(!enabled, "strict comments gate leaked");
3384 }
3385
3386 #[test]
3387 fn looks_like_meta_annotation_detects_rejected_prefixes() {
3388 assert!(super::looks_like_meta_annotation("TODO: fix later"));
3390 assert!(super::looks_like_meta_annotation("todo handle edge case"));
3391 assert!(super::looks_like_meta_annotation("FIXME this is broken"));
3392 assert!(super::looks_like_meta_annotation(
3393 "HACK: workaround for bug"
3394 ));
3395 assert!(super::looks_like_meta_annotation("XXX not implemented yet"));
3396 assert!(super::looks_like_meta_annotation(
3397 "BUG in the upstream crate"
3398 ));
3399 assert!(super::looks_like_meta_annotation("REVIEW before merging"));
3400 assert!(super::looks_like_meta_annotation(
3401 "REFACTOR this block later"
3402 ));
3403 assert!(super::looks_like_meta_annotation("TEMP: remove before v2"));
3404 assert!(super::looks_like_meta_annotation(
3405 "DEPRECATED use new_api instead"
3406 ));
3407 assert!(super::looks_like_meta_annotation(
3409 " TODO: with leading ws"
3410 ));
3411 }
3412
3413 #[test]
3414 fn looks_like_meta_annotation_preserves_behaviour_prefixes() {
3415 assert!(!super::looks_like_meta_annotation(
3417 "NOTE: this branch handles empty input"
3418 ));
3419 assert!(!super::looks_like_meta_annotation(
3420 "WARN: overflow is possible"
3421 ));
3422 assert!(!super::looks_like_meta_annotation(
3423 "SAFETY: caller must hold the lock"
3424 ));
3425 assert!(!super::looks_like_meta_annotation(
3426 "PANIC: unreachable by construction"
3427 ));
3428 assert!(!super::looks_like_meta_annotation(
3430 "parse json body from request"
3431 ));
3432 assert!(!super::looks_like_meta_annotation(
3433 "walk directory respecting gitignore"
3434 ));
3435 assert!(!super::looks_like_meta_annotation(
3436 "compute cosine similarity between vectors"
3437 ));
3438 assert!(!super::looks_like_meta_annotation(""));
3440 assert!(!super::looks_like_meta_annotation(" "));
3441 assert!(!super::looks_like_meta_annotation("123 numeric prefix"));
3442 }
3443
3444 #[test]
3445 fn strict_comments_filters_meta_annotations_during_extraction() {
3446 let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
3447 let previous = std::env::var("CODELENS_EMBED_HINT_STRICT_COMMENTS").ok();
3448 unsafe {
3449 std::env::set_var("CODELENS_EMBED_HINT_STRICT_COMMENTS", "1");
3450 }
3451 let source = "\
3452fn handle_request() {
3453 // TODO: handle the error path properly
3454 // parse json body from the incoming request
3455 // FIXME: this can panic on empty input
3456 // walk directory respecting the gitignore rules
3457 let _x = 1;
3458}
3459";
3460 let result = super::extract_nl_tokens_inner(source, 0, source.len());
3461 unsafe {
3462 match previous {
3463 Some(value) => std::env::set_var("CODELENS_EMBED_HINT_STRICT_COMMENTS", value),
3464 None => std::env::remove_var("CODELENS_EMBED_HINT_STRICT_COMMENTS"),
3465 }
3466 }
3467 let hint = result.expect("behaviour comments must survive");
3468 assert!(
3472 hint.contains("parse json body"),
3473 "behaviour comment dropped: {hint}"
3474 );
3475 assert!(!hint.contains("TODO"), "TODO annotation leaked: {hint}");
3478 assert!(!hint.contains("FIXME"), "FIXME annotation leaked: {hint}");
3479 }
3480
3481 #[test]
3482 fn strict_comments_is_orthogonal_to_strict_literals() {
3483 let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
3484 let prev_c = std::env::var("CODELENS_EMBED_HINT_STRICT_COMMENTS").ok();
3488 let prev_l = std::env::var("CODELENS_EMBED_HINT_STRICT_LITERALS").ok();
3489 unsafe {
3490 std::env::set_var("CODELENS_EMBED_HINT_STRICT_COMMENTS", "1");
3491 std::env::remove_var("CODELENS_EMBED_HINT_STRICT_LITERALS");
3492 }
3493 let source = "\
3496fn handle() {
3497 // handles real behaviour
3498 let fmt = \"format error string\";
3499}
3500";
3501 let result = super::extract_nl_tokens_inner(source, 0, source.len());
3502 unsafe {
3503 match prev_c {
3504 Some(v) => std::env::set_var("CODELENS_EMBED_HINT_STRICT_COMMENTS", v),
3505 None => std::env::remove_var("CODELENS_EMBED_HINT_STRICT_COMMENTS"),
3506 }
3507 match prev_l {
3508 Some(v) => std::env::set_var("CODELENS_EMBED_HINT_STRICT_LITERALS", v),
3509 None => std::env::remove_var("CODELENS_EMBED_HINT_STRICT_LITERALS"),
3510 }
3511 }
3512 let hint = result.expect("tokens must exist");
3513 assert!(hint.contains("handles real"), "comment dropped: {hint}");
3515 assert!(
3518 hint.contains("format error string"),
3519 "literal dropped: {hint}"
3520 );
3521 }
3522
3523 #[test]
3524 fn strict_literal_filter_gated_off_by_default() {
3525 let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
3526 let previous = std::env::var("CODELENS_EMBED_HINT_STRICT_LITERALS").ok();
3527 unsafe {
3528 std::env::remove_var("CODELENS_EMBED_HINT_STRICT_LITERALS");
3529 }
3530 let enabled = super::strict_literal_filter_enabled();
3531 unsafe {
3532 if let Some(value) = previous {
3533 std::env::set_var("CODELENS_EMBED_HINT_STRICT_LITERALS", value);
3534 }
3535 }
3536 assert!(!enabled, "strict literal filter gate leaked");
3537 }
3538
3539 #[test]
3540 fn contains_format_specifier_detects_c_and_python_style() {
3541 assert!(super::contains_format_specifier("Invalid URL %s"));
3543 assert!(super::contains_format_specifier("got %d matches"));
3544 assert!(super::contains_format_specifier("value=%r"));
3545 assert!(super::contains_format_specifier("size=%f"));
3546 assert!(super::contains_format_specifier("sending request to {url}"));
3548 assert!(super::contains_format_specifier("got {0} items"));
3549 assert!(super::contains_format_specifier("{:?}"));
3550 assert!(super::contains_format_specifier("value: {x:.2f}"));
3551 assert!(super::contains_format_specifier("{}"));
3552 assert!(!super::contains_format_specifier(
3554 "skip comments and string literals"
3555 ));
3556 assert!(!super::contains_format_specifier("failed to open database"));
3557 assert!(!super::contains_format_specifier("{name: foo, id: 1}"));
3560 }
3561
3562 #[test]
3563 fn looks_like_error_or_log_prefix_rejects_common_patterns() {
3564 assert!(super::looks_like_error_or_log_prefix("Invalid URL format"));
3565 assert!(super::looks_like_error_or_log_prefix(
3566 "Cannot decode response"
3567 ));
3568 assert!(super::looks_like_error_or_log_prefix("could not open file"));
3569 assert!(super::looks_like_error_or_log_prefix(
3570 "Failed to send request"
3571 ));
3572 assert!(super::looks_like_error_or_log_prefix(
3573 "Expected int, got str"
3574 ));
3575 assert!(super::looks_like_error_or_log_prefix(
3576 "sending request to server"
3577 ));
3578 assert!(super::looks_like_error_or_log_prefix(
3579 "received response headers"
3580 ));
3581 assert!(super::looks_like_error_or_log_prefix(
3582 "starting worker pool"
3583 ));
3584 assert!(!super::looks_like_error_or_log_prefix(
3586 "parse json body from request"
3587 ));
3588 assert!(!super::looks_like_error_or_log_prefix(
3589 "compute cosine similarity between vectors"
3590 ));
3591 assert!(!super::looks_like_error_or_log_prefix(
3592 "walk directory tree respecting gitignore"
3593 ));
3594 }
3595
3596 #[test]
3597 fn strict_mode_rejects_format_and_error_literals_during_extraction() {
3598 let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
3599 let previous = std::env::var("CODELENS_EMBED_HINT_STRICT_LITERALS").ok();
3603 unsafe {
3604 std::env::set_var("CODELENS_EMBED_HINT_STRICT_LITERALS", "1");
3605 }
3606 let source = "\
3607fn handle_request() {
3608 let err = \"Invalid URL %s\";
3609 let log = \"sending request to the upstream server\";
3610 let fmt = \"received {count} items in batch\";
3611 let real = \"parse json body from the incoming request\";
3612}
3613";
3614 let result = super::extract_nl_tokens_inner(source, 0, source.len());
3615 unsafe {
3616 match previous {
3617 Some(value) => std::env::set_var("CODELENS_EMBED_HINT_STRICT_LITERALS", value),
3618 None => std::env::remove_var("CODELENS_EMBED_HINT_STRICT_LITERALS"),
3619 }
3620 }
3621 let hint = result.expect("some token should survive");
3622 assert!(
3624 hint.contains("parse json body"),
3625 "real literal was filtered out: {hint}"
3626 );
3627 assert!(
3629 !hint.contains("Invalid URL"),
3630 "format-specifier literal leaked: {hint}"
3631 );
3632 assert!(
3633 !hint.contains("sending request"),
3634 "log-prefix literal leaked: {hint}"
3635 );
3636 assert!(
3637 !hint.contains("received {count}"),
3638 "python fstring literal leaked: {hint}"
3639 );
3640 }
3641
3642 #[test]
3643 fn strict_mode_leaves_comments_untouched() {
3644 let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
3645 let previous = std::env::var("CODELENS_EMBED_HINT_STRICT_LITERALS").ok();
3649 unsafe {
3650 std::env::set_var("CODELENS_EMBED_HINT_STRICT_LITERALS", "1");
3651 }
3652 let source = "\
3653fn do_work() {
3654 // Invalid inputs are rejected by this guard clause.
3655 // sending requests in parallel across worker threads.
3656 let _lit = \"format spec %s\";
3657}
3658";
3659 let result = super::extract_nl_tokens_inner(source, 0, source.len());
3660 unsafe {
3661 match previous {
3662 Some(value) => std::env::set_var("CODELENS_EMBED_HINT_STRICT_LITERALS", value),
3663 None => std::env::remove_var("CODELENS_EMBED_HINT_STRICT_LITERALS"),
3664 }
3665 }
3666 let hint = result.expect("comments should survive strict mode");
3667 assert!(
3670 hint.contains("Invalid inputs") || hint.contains("rejected by this guard"),
3671 "strict mode swallowed a comment: {hint}"
3672 );
3673 assert!(
3675 !hint.contains("format spec"),
3676 "format-specifier literal leaked under strict mode: {hint}"
3677 );
3678 }
3679
3680 #[test]
3681 fn should_reject_literal_strict_composes_format_and_prefix() {
3682 assert!(super::should_reject_literal_strict("Invalid URL %s"));
3686 assert!(super::should_reject_literal_strict(
3687 "sending request to server"
3688 ));
3689 assert!(super::should_reject_literal_strict("value: {x:.2f}"));
3690 assert!(!super::should_reject_literal_strict(
3692 "parse json body from the incoming request"
3693 ));
3694 assert!(!super::should_reject_literal_strict(
3695 "compute cosine similarity between vectors"
3696 ));
3697 }
3698
3699 #[test]
3700 fn is_static_method_ident_accepts_pascal_and_rejects_snake() {
3701 assert!(super::is_static_method_ident("HashMap"));
3702 assert!(super::is_static_method_ident("Parser"));
3703 assert!(super::is_static_method_ident("A"));
3704 assert!(!super::is_static_method_ident("std"));
3707 assert!(!super::is_static_method_ident("fs"));
3708 assert!(!super::is_static_method_ident("_private"));
3709 assert!(!super::is_static_method_ident(""));
3710 }
3711
3712 #[test]
3713 fn extract_api_calls_gated_off_by_default() {
3714 let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
3715 let previous = std::env::var("CODELENS_EMBED_HINT_INCLUDE_API_CALLS").ok();
3717 unsafe {
3718 std::env::remove_var("CODELENS_EMBED_HINT_INCLUDE_API_CALLS");
3719 }
3720 let source = "\
3721fn make_parser() {
3722 let p = Parser::new();
3723 let _ = HashMap::with_capacity(8);
3724}
3725";
3726 let result = extract_api_calls(source, 0, source.len());
3727 unsafe {
3728 if let Some(value) = previous {
3729 std::env::set_var("CODELENS_EMBED_HINT_INCLUDE_API_CALLS", value);
3730 }
3731 }
3732 assert!(result.is_none(), "gate leaked: {result:?}");
3733 }
3734
3735 #[test]
3736 fn extract_api_calls_captures_type_method_patterns() {
3737 let source = "\
3739fn open_db() {
3740 let p = Parser::new();
3741 let map = HashMap::with_capacity(16);
3742 let _ = tree_sitter::Parser::new();
3743}
3744";
3745 let hint = super::extract_api_calls_inner(source, 0, source.len())
3746 .expect("api calls should be produced");
3747 assert!(hint.contains("Parser::new"), "missing Parser::new: {hint}");
3748 assert!(
3749 hint.contains("HashMap::with_capacity"),
3750 "missing HashMap::with_capacity: {hint}"
3751 );
3752 }
3753
3754 #[test]
3755 fn extract_api_calls_rejects_module_prefixed_free_functions() {
3756 let source = "\
3759fn read_config() {
3760 let _ = std::fs::read_to_string(\"foo\");
3761 let _ = crate::util::parse();
3762}
3763";
3764 let hint = super::extract_api_calls_inner(source, 0, source.len());
3765 if let Some(hint) = hint {
3768 assert!(!hint.contains("std::fs"), "lowercase module leaked: {hint}");
3769 assert!(
3770 !hint.contains("fs::read_to_string"),
3771 "module-prefixed free function leaked: {hint}"
3772 );
3773 assert!(!hint.contains("crate::util"), "crate path leaked: {hint}");
3774 }
3775 }
3776
3777 #[test]
3778 fn extract_api_calls_deduplicates_repeated_calls() {
3779 let source = "\
3780fn hot_loop() {
3781 for _ in 0..10 {
3782 let _ = Parser::new();
3783 let _ = Parser::new();
3784 }
3785 let _ = Parser::new();
3786}
3787";
3788 let hint = super::extract_api_calls_inner(source, 0, source.len())
3789 .expect("api calls should be produced");
3790 let first = hint.find("Parser::new").expect("hit");
3791 let rest = &hint[first + "Parser::new".len()..];
3792 assert!(
3793 !rest.contains("Parser::new"),
3794 "duplicate not deduplicated: {hint}"
3795 );
3796 }
3797
3798 #[test]
3799 fn extract_api_calls_returns_none_when_body_has_no_type_calls() {
3800 let source = "\
3801fn plain() {
3802 let x = 1;
3803 let y = x + 2;
3804}
3805";
3806 assert!(super::extract_api_calls_inner(source, 0, source.len()).is_none());
3807 }
3808
3809 #[test]
3810 fn extract_nl_tokens_collects_comments_and_string_literals() {
3811 let source = "\
3815fn search_for_matches() {
3816 // skip comments and string literals during search
3817 let error = \"failed to open database\";
3818 let single = \"tok\";
3819 let path = \"src/foo/bar\";
3820 let keyword = match kind {
3821 Kind::Ident => \"detect client version\",
3822 _ => \"\",
3823 };
3824}
3825";
3826 let hint = super::extract_nl_tokens_inner(source, 0, source.len())
3832 .expect("nl tokens should be produced");
3833 let has_first_nl_signal = hint.contains("skip comments")
3837 || hint.contains("failed to open")
3838 || hint.contains("detect client");
3839 assert!(has_first_nl_signal, "no NL signal produced: {hint}");
3840 assert!(!hint.contains(" tok "), "short literal leaked: {hint}");
3842 assert!(!hint.contains("src/foo/bar"), "path literal leaked: {hint}");
3844 }
3845
3846 #[test]
3847 fn hint_char_budget_respects_env_override() {
3848 let previous = std::env::var("CODELENS_EMBED_HINT_CHARS").ok();
3849 unsafe {
3850 std::env::set_var("CODELENS_EMBED_HINT_CHARS", "120");
3851 }
3852 let budget = super::hint_char_budget();
3853 unsafe {
3854 match previous {
3855 Some(value) => std::env::set_var("CODELENS_EMBED_HINT_CHARS", value),
3856 None => std::env::remove_var("CODELENS_EMBED_HINT_CHARS"),
3857 }
3858 }
3859 assert_eq!(budget, 120);
3860 }
3861
3862 #[test]
3863 fn embedding_to_bytes_roundtrip() {
3864 let floats = vec![1.0f32, -0.5, 0.0, 3.25];
3865 let bytes = embedding_to_bytes(&floats);
3866 assert_eq!(bytes.len(), 4 * 4);
3867 let recovered: Vec<f32> = bytes
3869 .chunks_exact(4)
3870 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
3871 .collect();
3872 assert_eq!(floats, recovered);
3873 }
3874
3875 #[test]
3876 fn duplicate_pair_key_is_order_independent() {
3877 let a = duplicate_pair_key("a.py", "foo", "b.py", "bar");
3878 let b = duplicate_pair_key("b.py", "bar", "a.py", "foo");
3879 assert_eq!(a, b);
3880 }
3881
3882 #[test]
3883 fn text_embedding_cache_updates_recency() {
3884 let mut cache = TextEmbeddingCache::new(2);
3885 cache.insert("a".into(), vec![1.0]);
3886 cache.insert("b".into(), vec![2.0]);
3887 assert_eq!(cache.get("a"), Some(vec![1.0]));
3888 cache.insert("c".into(), vec![3.0]);
3889
3890 assert_eq!(cache.get("a"), Some(vec![1.0]));
3891 assert_eq!(cache.get("b"), None);
3892 assert_eq!(cache.get("c"), Some(vec![3.0]));
3893 }
3894
3895 #[test]
3896 fn text_embedding_cache_can_be_disabled() {
3897 let mut cache = TextEmbeddingCache::new(0);
3898 cache.insert("a".into(), vec![1.0]);
3899 assert_eq!(cache.get("a"), None);
3900 }
3901
3902 #[test]
3903 fn engine_new_and_index() {
3904 let _lock = MODEL_LOCK.lock().unwrap();
3905 skip_without_embedding_model!();
3906 let (_dir, project) = make_project_with_source();
3907 let engine = EmbeddingEngine::new(&project).expect("engine should load");
3908 assert!(!engine.is_indexed());
3909
3910 let count = engine.index_from_project(&project).unwrap();
3911 assert_eq!(count, 2, "should index 2 symbols");
3912 assert!(engine.is_indexed());
3913 }
3914
3915 #[test]
3916 fn engine_search_returns_results() {
3917 let _lock = MODEL_LOCK.lock().unwrap();
3918 skip_without_embedding_model!();
3919 let (_dir, project) = make_project_with_source();
3920 let engine = EmbeddingEngine::new(&project).unwrap();
3921 engine.index_from_project(&project).unwrap();
3922
3923 let results = engine.search("hello function", 10).unwrap();
3924 assert!(!results.is_empty(), "search should return results");
3925 for r in &results {
3926 assert!(
3927 r.score >= -1.0 && r.score <= 1.0,
3928 "score should be in [-1,1]: {}",
3929 r.score
3930 );
3931 }
3932 }
3933
3934 #[test]
3935 fn engine_incremental_index() {
3936 let _lock = MODEL_LOCK.lock().unwrap();
3937 skip_without_embedding_model!();
3938 let (_dir, project) = make_project_with_source();
3939 let engine = EmbeddingEngine::new(&project).unwrap();
3940 engine.index_from_project(&project).unwrap();
3941 assert_eq!(engine.store.count().unwrap(), 2);
3942
3943 let count = engine.index_changed_files(&project, &["main.py"]).unwrap();
3945 assert_eq!(count, 2);
3946 assert_eq!(engine.store.count().unwrap(), 2);
3947 }
3948
3949 #[test]
3950 fn engine_reindex_preserves_symbol_count() {
3951 let _lock = MODEL_LOCK.lock().unwrap();
3952 skip_without_embedding_model!();
3953 let (_dir, project) = make_project_with_source();
3954 let engine = EmbeddingEngine::new(&project).unwrap();
3955 engine.index_from_project(&project).unwrap();
3956 assert_eq!(engine.store.count().unwrap(), 2);
3957
3958 let count = engine.index_from_project(&project).unwrap();
3959 assert_eq!(count, 2);
3960 assert_eq!(engine.store.count().unwrap(), 2);
3961 }
3962
3963 #[test]
3964 fn full_reindex_reuses_unchanged_embeddings() {
3965 let _lock = MODEL_LOCK.lock().unwrap();
3966 skip_without_embedding_model!();
3967 let (_dir, project) = make_project_with_source();
3968 let engine = EmbeddingEngine::new(&project).unwrap();
3969 engine.index_from_project(&project).unwrap();
3970
3971 replace_file_embeddings_with_sentinels(
3972 &engine,
3973 "main.py",
3974 &[("hello", 11.0), ("world", 22.0)],
3975 );
3976
3977 let count = engine.index_from_project(&project).unwrap();
3978 assert_eq!(count, 2);
3979
3980 let hello = engine
3981 .store
3982 .get_embedding("main.py", "hello")
3983 .unwrap()
3984 .expect("hello should exist");
3985 let world = engine
3986 .store
3987 .get_embedding("main.py", "world")
3988 .unwrap()
3989 .expect("world should exist");
3990 assert!(hello.embedding.iter().all(|value| *value == 11.0));
3991 assert!(world.embedding.iter().all(|value| *value == 22.0));
3992 }
3993
3994 #[test]
3995 fn full_reindex_reuses_unchanged_sibling_after_edit() {
3996 let _lock = MODEL_LOCK.lock().unwrap();
3997 skip_without_embedding_model!();
3998 let (dir, project) = make_project_with_source();
3999 let engine = EmbeddingEngine::new(&project).unwrap();
4000 engine.index_from_project(&project).unwrap();
4001
4002 replace_file_embeddings_with_sentinels(
4003 &engine,
4004 "main.py",
4005 &[("hello", 11.0), ("world", 22.0)],
4006 );
4007
4008 let updated_source =
4009 "def hello():\n print('hi')\n\ndef world(name):\n return name.upper()\n";
4010 write_python_file_with_symbols(
4011 dir.path(),
4012 "main.py",
4013 updated_source,
4014 "hash2",
4015 &[
4016 ("hello", "def hello():", "hello"),
4017 ("world", "def world(name):", "world"),
4018 ],
4019 );
4020
4021 let count = engine.index_from_project(&project).unwrap();
4022 assert_eq!(count, 2);
4023
4024 let hello = engine
4025 .store
4026 .get_embedding("main.py", "hello")
4027 .unwrap()
4028 .expect("hello should exist");
4029 let world = engine
4030 .store
4031 .get_embedding("main.py", "world")
4032 .unwrap()
4033 .expect("world should exist");
4034 assert!(hello.embedding.iter().all(|value| *value == 11.0));
4035 assert!(world.embedding.iter().any(|value| *value != 22.0));
4036 assert_eq!(engine.store.count().unwrap(), 2);
4037 }
4038
4039 #[test]
4040 fn full_reindex_removes_deleted_files() {
4041 let _lock = MODEL_LOCK.lock().unwrap();
4042 skip_without_embedding_model!();
4043 let (dir, project) = make_project_with_source();
4044 write_python_file_with_symbols(
4045 dir.path(),
4046 "extra.py",
4047 "def bonus():\n return 7\n",
4048 "hash-extra",
4049 &[("bonus", "def bonus():", "bonus")],
4050 );
4051
4052 let engine = EmbeddingEngine::new(&project).unwrap();
4053 engine.index_from_project(&project).unwrap();
4054 assert_eq!(engine.store.count().unwrap(), 3);
4055
4056 std::fs::remove_file(dir.path().join("extra.py")).unwrap();
4057 let db_path = crate::db::index_db_path(dir.path());
4058 let db = IndexDb::open(&db_path).unwrap();
4059 db.delete_file("extra.py").unwrap();
4060
4061 let count = engine.index_from_project(&project).unwrap();
4062 assert_eq!(count, 2);
4063 assert_eq!(engine.store.count().unwrap(), 2);
4064 assert!(
4065 engine
4066 .store
4067 .embeddings_for_files(&["extra.py"])
4068 .unwrap()
4069 .is_empty()
4070 );
4071 }
4072
4073 #[test]
4074 fn engine_model_change_recreates_db() {
4075 let _lock = MODEL_LOCK.lock().unwrap();
4076 skip_without_embedding_model!();
4077 let (_dir, project) = make_project_with_source();
4078
4079 let engine1 = EmbeddingEngine::new(&project).unwrap();
4081 engine1.index_from_project(&project).unwrap();
4082 assert_eq!(engine1.store.count().unwrap(), 2);
4083 drop(engine1);
4084
4085 let engine2 = EmbeddingEngine::new(&project).unwrap();
4087 assert!(engine2.store.count().unwrap() >= 2);
4088 }
4089
4090 #[test]
4091 fn inspect_existing_index_returns_model_and_count() {
4092 let _lock = MODEL_LOCK.lock().unwrap();
4093 skip_without_embedding_model!();
4094 let (_dir, project) = make_project_with_source();
4095 let engine = EmbeddingEngine::new(&project).unwrap();
4096 engine.index_from_project(&project).unwrap();
4097
4098 let info = EmbeddingEngine::inspect_existing_index(&project)
4099 .unwrap()
4100 .expect("index info should exist");
4101 assert_eq!(info.model_name, engine.model_name());
4102 assert_eq!(info.indexed_symbols, 2);
4103 }
4104
4105 #[test]
4106 fn inspect_existing_index_recovers_from_corrupt_db() {
4107 let (_dir, project) = make_project_with_source();
4108 let index_dir = project.as_path().join(".codelens/index");
4109 let db_path = index_dir.join("embeddings.db");
4110 let wal_path = index_dir.join("embeddings.db-wal");
4111 let shm_path = index_dir.join("embeddings.db-shm");
4112
4113 std::fs::write(&db_path, b"not a sqlite database").unwrap();
4114 std::fs::write(&wal_path, b"bad wal").unwrap();
4115 std::fs::write(&shm_path, b"bad shm").unwrap();
4116
4117 let info = EmbeddingEngine::inspect_existing_index(&project).unwrap();
4118 assert!(info.is_none());
4119
4120 assert!(db_path.is_file());
4121
4122 let backup_names: Vec<String> = std::fs::read_dir(&index_dir)
4123 .unwrap()
4124 .map(|entry| entry.unwrap().file_name().to_string_lossy().into_owned())
4125 .filter(|name| name.contains(".corrupt-"))
4126 .collect();
4127
4128 assert!(
4129 backup_names
4130 .iter()
4131 .any(|name| name.starts_with("embeddings.db.corrupt-")),
4132 "expected quarantined embedding db, found {backup_names:?}"
4133 );
4134 }
4135
4136 #[test]
4137 fn store_can_fetch_single_embedding_without_loading_all() {
4138 let _lock = MODEL_LOCK.lock().unwrap();
4139 skip_without_embedding_model!();
4140 let (_dir, project) = make_project_with_source();
4141 let engine = EmbeddingEngine::new(&project).unwrap();
4142 engine.index_from_project(&project).unwrap();
4143
4144 let chunk = engine
4145 .store
4146 .get_embedding("main.py", "hello")
4147 .unwrap()
4148 .expect("embedding should exist");
4149 assert_eq!(chunk.file_path, "main.py");
4150 assert_eq!(chunk.symbol_name, "hello");
4151 assert!(!chunk.embedding.is_empty());
4152 }
4153
4154 #[test]
4155 fn find_similar_code_uses_index_and_excludes_target_symbol() {
4156 let _lock = MODEL_LOCK.lock().unwrap();
4157 skip_without_embedding_model!();
4158 let (_dir, project) = make_project_with_source();
4159 let engine = EmbeddingEngine::new(&project).unwrap();
4160 engine.index_from_project(&project).unwrap();
4161
4162 let matches = engine.find_similar_code("main.py", "hello", 5).unwrap();
4163 assert!(!matches.is_empty());
4164 assert!(
4165 matches
4166 .iter()
4167 .all(|m| !(m.file_path == "main.py" && m.symbol_name == "hello"))
4168 );
4169 }
4170
4171 #[test]
4172 fn delete_by_file_removes_rows_in_one_batch() {
4173 let _lock = MODEL_LOCK.lock().unwrap();
4174 skip_without_embedding_model!();
4175 let (_dir, project) = make_project_with_source();
4176 let engine = EmbeddingEngine::new(&project).unwrap();
4177 engine.index_from_project(&project).unwrap();
4178
4179 let deleted = engine.store.delete_by_file(&["main.py"]).unwrap();
4180 assert_eq!(deleted, 2);
4181 assert_eq!(engine.store.count().unwrap(), 0);
4182 }
4183
4184 #[test]
4185 fn store_streams_embeddings_grouped_by_file() {
4186 let _lock = MODEL_LOCK.lock().unwrap();
4187 skip_without_embedding_model!();
4188 let (_dir, project) = make_project_with_source();
4189 let engine = EmbeddingEngine::new(&project).unwrap();
4190 engine.index_from_project(&project).unwrap();
4191
4192 let mut groups = Vec::new();
4193 engine
4194 .store
4195 .for_each_file_embeddings(&mut |file_path, chunks| {
4196 groups.push((file_path, chunks.len()));
4197 Ok(())
4198 })
4199 .unwrap();
4200
4201 assert_eq!(groups, vec![("main.py".to_string(), 2)]);
4202 }
4203
4204 #[test]
4205 fn store_fetches_embeddings_for_specific_files() {
4206 let _lock = MODEL_LOCK.lock().unwrap();
4207 skip_without_embedding_model!();
4208 let (_dir, project) = make_project_with_source();
4209 let engine = EmbeddingEngine::new(&project).unwrap();
4210 engine.index_from_project(&project).unwrap();
4211
4212 let chunks = engine.store.embeddings_for_files(&["main.py"]).unwrap();
4213 assert_eq!(chunks.len(), 2);
4214 assert!(chunks.iter().all(|chunk| chunk.file_path == "main.py"));
4215 }
4216
4217 #[test]
4218 fn store_fetches_embeddings_for_scored_chunks() {
4219 let _lock = MODEL_LOCK.lock().unwrap();
4220 skip_without_embedding_model!();
4221 let (_dir, project) = make_project_with_source();
4222 let engine = EmbeddingEngine::new(&project).unwrap();
4223 engine.index_from_project(&project).unwrap();
4224
4225 let scored = engine.search_scored("hello world function", 2).unwrap();
4226 let chunks = engine.store.embeddings_for_scored_chunks(&scored).unwrap();
4227
4228 assert_eq!(chunks.len(), scored.len());
4229 assert!(scored.iter().all(|candidate| chunks.iter().any(|chunk| {
4230 chunk.file_path == candidate.file_path
4231 && chunk.symbol_name == candidate.symbol_name
4232 && chunk.line == candidate.line
4233 && chunk.signature == candidate.signature
4234 && chunk.name_path == candidate.name_path
4235 })));
4236 }
4237
4238 #[test]
4239 fn find_misplaced_code_returns_per_file_outliers() {
4240 let _lock = MODEL_LOCK.lock().unwrap();
4241 skip_without_embedding_model!();
4242 let (_dir, project) = make_project_with_source();
4243 let engine = EmbeddingEngine::new(&project).unwrap();
4244 engine.index_from_project(&project).unwrap();
4245
4246 let outliers = engine.find_misplaced_code(5).unwrap();
4247 assert_eq!(outliers.len(), 2);
4248 assert!(outliers.iter().all(|item| item.file_path == "main.py"));
4249 }
4250
4251 #[test]
4252 fn find_duplicates_uses_batched_candidate_embeddings() {
4253 let _lock = MODEL_LOCK.lock().unwrap();
4254 skip_without_embedding_model!();
4255 let (_dir, project) = make_project_with_source();
4256 let engine = EmbeddingEngine::new(&project).unwrap();
4257 engine.index_from_project(&project).unwrap();
4258
4259 replace_file_embeddings_with_sentinels(
4260 &engine,
4261 "main.py",
4262 &[("hello", 5.0), ("world", 5.0)],
4263 );
4264
4265 let duplicates = engine.find_duplicates(0.99, 4).unwrap();
4266 assert!(!duplicates.is_empty());
4267 assert!(duplicates.iter().any(|pair| {
4268 (pair.symbol_a == "main.py:hello" && pair.symbol_b == "main.py:world")
4269 || (pair.symbol_a == "main.py:world" && pair.symbol_b == "main.py:hello")
4270 }));
4271 }
4272
4273 #[test]
4274 fn search_scored_returns_raw_chunks() {
4275 let _lock = MODEL_LOCK.lock().unwrap();
4276 skip_without_embedding_model!();
4277 let (_dir, project) = make_project_with_source();
4278 let engine = EmbeddingEngine::new(&project).unwrap();
4279 engine.index_from_project(&project).unwrap();
4280
4281 let chunks = engine.search_scored("world function", 5).unwrap();
4282 assert!(!chunks.is_empty());
4283 for c in &chunks {
4284 assert!(!c.file_path.is_empty());
4285 assert!(!c.symbol_name.is_empty());
4286 }
4287 }
4288
4289 #[test]
4290 fn configured_embedding_model_name_defaults_to_codesearchnet() {
4291 assert_eq!(configured_embedding_model_name(), CODESEARCH_MODEL_NAME);
4292 }
4293
4294 #[test]
4295 fn recommended_embed_threads_caps_macos_style_load() {
4296 let threads = recommended_embed_threads();
4297 assert!(threads >= 1);
4298 assert!(threads <= 8);
4299 }
4300
4301 #[test]
4302 fn embed_batch_size_has_safe_default_floor() {
4303 assert!(embed_batch_size() >= 1);
4304 if cfg!(target_os = "macos") {
4305 assert!(embed_batch_size() <= DEFAULT_MACOS_EMBED_BATCH_SIZE);
4306 }
4307 }
4308}