1use crate::db::IndexDb;
5use crate::embedding_store::{EmbeddingChunk, EmbeddingStore, ScoredChunk};
6use crate::project::ProjectRoot;
7use anyhow::{Context, Result};
8use fastembed::{
9 ExecutionProviderDispatch, InitOptionsUserDefined, TextEmbedding, TokenizerFiles,
10 UserDefinedEmbeddingModel,
11};
12use rusqlite::Connection;
13use serde::Serialize;
14use std::collections::{HashMap, HashSet, VecDeque};
15use std::sync::{Arc, Mutex, Once};
16use std::thread::available_parallelism;
17use tracing::debug;
18
/// Low-level FFI helpers kept separate from the safe surface of this module.
pub(super) mod ffi {
    use anyhow::Result;

    /// Registers the sqlite-vec extension with SQLite's auto-extension
    /// mechanism so every subsequently opened connection gains the
    /// vector-search functions.
    ///
    /// # Errors
    /// Fails with the SQLite result code when registration is rejected.
    pub fn register_sqlite_vec() -> Result<()> {
        // SAFETY: `sqlite3_vec_init` is the extension entry point exported by
        // the sqlite-vec crate; the transmute only re-types the fn pointer to
        // the exact signature `sqlite3_auto_extension` expects (same ABI).
        let rc = unsafe {
            rusqlite::ffi::sqlite3_auto_extension(Some(std::mem::transmute::<
                *const (),
                unsafe extern "C" fn(
                    *mut rusqlite::ffi::sqlite3,
                    *mut *mut i8,
                    *const rusqlite::ffi::sqlite3_api_routines,
                ) -> i32,
            >(
                sqlite_vec::sqlite3_vec_init as *const ()
            )))
        };
        if rc != rusqlite::ffi::SQLITE_OK {
            anyhow::bail!("failed to register sqlite-vec extension (SQLite error code: {rc})");
        }
        Ok(())
    }

    /// Reads a `c_uint`-sized sysctl value by name.
    ///
    /// `name` must be a NUL-terminated byte string (callers pass `b"...\0"`).
    /// Returns `None` when the syscall fails or the reported size is not
    /// exactly `c_uint`.
    #[cfg(target_os = "macos")]
    pub fn sysctl_usize(name: &[u8]) -> Option<usize> {
        let mut value: libc::c_uint = 0;
        let mut size = std::mem::size_of::<libc::c_uint>();
        // SAFETY: the out-pointer and size describe a live, properly sized
        // c_uint; the name pointer is valid for the duration of the call and
        // is expected to be NUL-terminated by callers.
        let rc = unsafe {
            libc::sysctlbyname(
                name.as_ptr().cast(),
                (&mut value as *mut libc::c_uint).cast(),
                &mut size,
                std::ptr::null_mut(),
                0,
            )
        };
        (rc == 0 && size == std::mem::size_of::<libc::c_uint>()).then_some(value as usize)
    }
}
58
/// A single semantic-search hit, serialized for external consumers.
#[derive(Debug, Clone, Serialize)]
pub struct SemanticMatch {
    /// Project-relative path of the file containing the symbol.
    pub file_path: String,
    /// Name of the matched symbol.
    pub symbol_name: String,
    /// Symbol kind string as stored in the symbol index.
    pub kind: String,
    /// Line number copied verbatim from the symbol index
    /// (NOTE(review): 0- vs 1-based depends on the indexer — confirm there).
    pub line: usize,
    /// Signature text from the symbol index.
    pub signature: String,
    /// Qualified name path of the symbol.
    pub name_path: String,
    /// Relevance score; results are ranked descending by this value.
    pub score: f64,
}
70
71impl From<ScoredChunk> for SemanticMatch {
72 fn from(c: ScoredChunk) -> Self {
73 Self {
74 file_path: c.file_path,
75 symbol_name: c.symbol_name,
76 kind: c.kind,
77 line: c.line,
78 signature: c.signature,
79 name_path: c.name_path,
80 score: c.score,
81 }
82 }
83}
84
85mod vec_store;
86use vec_store::SqliteVecStore;
87
/// Identity key for deciding whether a stored embedding can be reused:
/// a symbol whose key is unchanged does not need re-embedding.
type ReusableEmbeddingKey = (String, String, String, String, String, String);

/// Builds the reuse key from the six identifying facets of a symbol plus the
/// exact text that was (or will be) embedded.
fn reusable_embedding_key(
    file_path: &str,
    symbol_name: &str,
    kind: &str,
    signature: &str,
    name_path: &str,
    text: &str,
) -> ReusableEmbeddingKey {
    let own = |s: &str| s.to_string();
    (
        own(file_path),
        own(symbol_name),
        own(kind),
        own(signature),
        own(name_path),
        own(text),
    )
}
107
108fn reusable_embedding_key_for_chunk(chunk: &EmbeddingChunk) -> ReusableEmbeddingKey {
109 reusable_embedding_key(
110 &chunk.file_path,
111 &chunk.symbol_name,
112 &chunk.kind,
113 &chunk.signature,
114 &chunk.name_path,
115 &chunk.text,
116 )
117}
118
/// Builds the reuse key for a symbol from the index together with the
/// embedding text freshly derived from it.
fn reusable_embedding_key_for_symbol(
    sym: &crate::db::SymbolWithFile,
    text: &str,
) -> ReusableEmbeddingKey {
    reusable_embedding_key(
        &sym.file_path,
        &sym.name,
        &sym.kind,
        &sym.signature,
        &sym.name_path,
        text,
    )
}
132
/// Texts embedded per model call (non-macOS default).
const DEFAULT_EMBED_BATCH_SIZE: usize = 128;
/// Texts embedded per model call (macOS default).
const DEFAULT_MACOS_EMBED_BATCH_SIZE: usize = 128;
/// Default capacity of the text→embedding LRU cache.
const DEFAULT_TEXT_EMBED_CACHE_SIZE: usize = 256;
/// Larger cache default used on macOS.
const DEFAULT_MACOS_TEXT_EMBED_CACHE_SIZE: usize = 1024;
/// Output vector dimension of the bundled CodeSearchNet model.
const CODESEARCH_DIMENSION: usize = 384;
/// Safety cap on how many symbols a full indexing pass will embed.
const DEFAULT_MAX_EMBED_SYMBOLS: usize = 50_000;
/// Chunk size when querying stores/DBs with lists of changed files.
const CHANGED_FILE_QUERY_CHUNK: usize = 128;
/// Batch size for duplicate scanning
/// (NOTE(review): not referenced in this chunk — presumably used further down the file).
const DEFAULT_DUPLICATE_SCAN_BATCH_SIZE: usize = 128;
/// Guards one-time initialization of the ONNX Runtime global environment.
static ORT_ENV_INIT: Once = Once::new();

/// Display name of the bundled quantized CodeSearchNet embedding model.
const CODESEARCH_MODEL_NAME: &str = "MiniLM-L12-CodeSearchNet-INT8";
148
/// Embedding-based semantic search engine: owns the ONNX embedding model,
/// the persistent vector store, and a small in-memory text-embedding cache.
pub struct EmbeddingEngine {
    /// Loaded fastembed model; behind a Mutex because `embed` takes `&mut self`.
    model: Mutex<TextEmbedding>,
    /// Persistent embedding store (sqlite-vec backed).
    store: Box<dyn EmbeddingStore>,
    /// Model name reported to callers.
    model_name: String,
    /// Snapshot of the backend/threading configuration chosen at load time.
    runtime_info: EmbeddingRuntimeInfo,
    /// LRU cache of text → embedding, to avoid re-embedding repeated inputs.
    text_embed_cache: Mutex<TextEmbeddingCache>,
    /// True while a full indexing pass is running (single-run guard).
    indexing: std::sync::atomic::AtomicBool,
}
157
/// Summary of the persisted embedding index, for status reporting.
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
pub struct EmbeddingIndexInfo {
    /// Name of the model the index was built with.
    pub model_name: String,
    /// Number of symbols currently stored in the index.
    pub indexed_symbols: usize,
}
163
/// Description of how the embedding model is (or would be) executed.
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
pub struct EmbeddingRuntimeInfo {
    /// Requested backend: "cpu", "coreml", or "coreml_preferred".
    pub runtime_preference: String,
    /// Effective backend: "cpu", "coreml", or "not_loaded" before model load.
    pub backend: String,
    /// Intra-op thread count.
    pub threads: usize,
    /// Maximum tokenized sequence length.
    pub max_length: usize,
    // The coreml_* fields are None whenever CoreML is not in play.
    pub coreml_model_format: Option<String>,
    pub coreml_compute_units: Option<String>,
    pub coreml_static_input_shapes: Option<bool>,
    pub coreml_profile_compute_plan: Option<bool>,
    pub coreml_specialization_strategy: Option<String>,
    pub coreml_model_cache_dir: Option<String>,
    /// Why a preferred backend was abandoned, when it was.
    pub fallback_reason: Option<String>,
}
178
/// A small LRU cache mapping embedding input text to its vector.
///
/// Recency lives in `order` (front = least recently used); `entries` owns the
/// cached vectors. A capacity of zero disables caching entirely.
struct TextEmbeddingCache {
    capacity: usize,
    order: VecDeque<String>,
    entries: HashMap<String, Vec<f32>>,
}

impl TextEmbeddingCache {
    /// Creates an empty cache holding at most `capacity` entries.
    fn new(capacity: usize) -> Self {
        Self {
            capacity,
            order: VecDeque::new(),
            entries: HashMap::new(),
        }
    }

    /// Returns a clone of the cached vector for `key` (marking it most
    /// recently used), or `None` on a miss.
    fn get(&mut self, key: &str) -> Option<Vec<f32>> {
        let hit = self.entries.get(key).cloned()?;
        self.touch(key);
        Some(hit)
    }

    /// Stores `value` under `key`, then evicts least-recently-used entries
    /// until the cache is back within capacity. No-op when capacity is zero.
    fn insert(&mut self, key: String, value: Vec<f32>) {
        if self.capacity == 0 {
            return;
        }

        self.entries.insert(key.clone(), value);
        self.touch(&key);

        while self.entries.len() > self.capacity {
            match self.order.pop_front() {
                Some(evicted) => {
                    self.entries.remove(&evicted);
                }
                None => break,
            }
        }
    }

    /// Moves `key` to the most-recently-used end of the recency queue.
    fn touch(&mut self, key: &str) {
        // A key occurs at most once in `order`, so removing all occurrences
        // is equivalent to removing the first.
        self.order.retain(|existing| existing != key);
        self.order.push_back(key.to_owned());
    }
}
224
225fn resolve_model_dir() -> Result<std::path::PathBuf> {
233 if let Ok(dir) = std::env::var("CODELENS_MODEL_DIR") {
235 let p = std::path::PathBuf::from(dir).join("codesearch");
236 if p.join("model.onnx").exists() {
237 return Ok(p);
238 }
239 }
240
241 if let Ok(exe) = std::env::current_exe()
243 && let Some(exe_dir) = exe.parent()
244 {
245 let p = exe_dir.join("models").join("codesearch");
246 if p.join("model.onnx").exists() {
247 return Ok(p);
248 }
249 }
250
251 if let Some(home) = dirs_fallback() {
253 let p = home
254 .join(".cache")
255 .join("codelens")
256 .join("models")
257 .join("codesearch");
258 if p.join("model.onnx").exists() {
259 return Ok(p);
260 }
261 }
262
263 let dev_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
265 .join("models")
266 .join("codesearch");
267 if dev_path.join("model.onnx").exists() {
268 return Ok(dev_path);
269 }
270
271 anyhow::bail!(
272 "CodeSearchNet model not found. Place model files in one of:\n\
273 - $CODELENS_MODEL_DIR/codesearch/\n\
274 - <executable>/models/codesearch/\n\
275 - ~/.cache/codelens/models/codesearch/\n\
276 Required files: model.onnx, tokenizer.json, config.json, special_tokens_map.json, tokenizer_config.json"
277 )
278}
279
/// Minimal stand-in for the `dirs` crate: resolves the user's home directory
/// from the `HOME` environment variable, when set.
fn dirs_fallback() -> Option<std::path::PathBuf> {
    let home = std::env::var_os("HOME")?;
    Some(std::path::PathBuf::from(home))
}
283
/// Reads environment variable `name` as a strictly positive `usize`.
///
/// Returns `None` when the variable is unset, fails to parse after trimming,
/// or parses to zero.
fn parse_usize_env(name: &str) -> Option<usize> {
    let raw = std::env::var(name).ok()?;
    match raw.trim().parse::<usize>() {
        Ok(parsed) if parsed > 0 => Some(parsed),
        _ => None,
    }
}
290
/// Reads environment variable `name` as a boolean flag.
///
/// Accepts `1/true/yes/on` and `0/false/no/off` case-insensitively after
/// trimming; anything else — including an unset variable — yields `None`.
fn parse_bool_env(name: &str) -> Option<bool> {
    let raw = std::env::var(name).ok()?;
    match raw.trim().to_ascii_lowercase().as_str() {
        "1" | "true" | "yes" | "on" => Some(true),
        "0" | "false" | "no" | "off" => Some(false),
        _ => None,
    }
}
301
/// Number of performance cores on Apple silicon (via sysctl), falling back to
/// the total physical-CPU count; `None` when both queries fail.
#[cfg(target_os = "macos")]
fn apple_perf_cores() -> Option<usize> {
    ffi::sysctl_usize(b"hw.perflevel0.physicalcpu\0")
        .filter(|value| *value > 0)
        .or_else(|| ffi::sysctl_usize(b"hw.physicalcpu\0").filter(|value| *value > 0))
}
308
/// Non-macOS stub: performance-core detection is Apple-specific.
#[cfg(not(target_os = "macos"))]
fn apple_perf_cores() -> Option<usize> {
    None
}
313
/// Resolves the requested embedding execution backend.
///
/// Honors `CODELENS_EMBED_PROVIDER` (`cpu` or `coreml`); CoreML is only
/// granted on macOS, and the macOS default is `coreml_preferred` (try CoreML,
/// fall back to CPU). Everywhere else the answer is `cpu`.
pub fn configured_embedding_runtime_preference() -> String {
    let requested = std::env::var("CODELENS_EMBED_PROVIDER")
        .ok()
        .map(|value| value.trim().to_ascii_lowercase());
    let on_macos = cfg!(target_os = "macos");

    let preference = match requested.as_deref() {
        Some("cpu") => "cpu",
        Some("coreml") => {
            if on_macos {
                "coreml"
            } else {
                // CoreML requested on a non-Apple host: silently use CPU.
                "cpu"
            }
        }
        _ => {
            if on_macos {
                "coreml_preferred"
            } else {
                "cpu"
            }
        }
    };
    preference.to_string()
}
327
/// Public accessor for the embedding thread count (see
/// `recommended_embed_threads` for the selection rules).
pub fn configured_embedding_threads() -> usize {
    recommended_embed_threads()
}
331
/// Maximum tokenized sequence length for the embedding model.
///
/// Reads `CODELENS_EMBED_MAX_LENGTH` (positive integer), defaults to 256, and
/// clamps the result into `[32, 512]`.
fn configured_embedding_max_length() -> usize {
    let configured = std::env::var("CODELENS_EMBED_MAX_LENGTH")
        .ok()
        .and_then(|raw| raw.trim().parse::<usize>().ok())
        .filter(|parsed| *parsed > 0);
    configured.unwrap_or(256).clamp(32, 512)
}
337
338fn configured_embedding_text_cache_size() -> usize {
339 std::env::var("CODELENS_EMBED_TEXT_CACHE_SIZE")
340 .ok()
341 .and_then(|value| value.trim().parse::<usize>().ok())
342 .unwrap_or({
343 if cfg!(target_os = "macos") {
344 DEFAULT_MACOS_TEXT_EMBED_CACHE_SIZE
345 } else {
346 DEFAULT_TEXT_EMBED_CACHE_SIZE
347 }
348 })
349 .min(8192)
350}
351
/// CoreML compute-unit selection from `CODELENS_EMBED_COREML_COMPUTE_UNITS`.
///
/// Recognized spellings are normalized; anything else (including unset)
/// defaults to `cpu_and_neural_engine`.
#[cfg(target_os = "macos")]
fn configured_coreml_compute_units_name() -> String {
    let requested = std::env::var("CODELENS_EMBED_COREML_COMPUTE_UNITS")
        .ok()
        .map(|raw| raw.trim().to_ascii_lowercase());
    let resolved = match requested.as_deref() {
        Some("all") => "all",
        Some("cpu" | "cpu_only") => "cpu_only",
        Some("gpu" | "cpu_and_gpu") => "cpu_and_gpu",
        Some("ane" | "neural_engine" | "cpu_and_neural_engine") => "cpu_and_neural_engine",
        _ => "cpu_and_neural_engine",
    };
    resolved.to_string()
}
368
/// CoreML model format: `mlprogram` unless `neuralnetwork` is requested via
/// `CODELENS_EMBED_COREML_MODEL_FORMAT`.
#[cfg(target_os = "macos")]
fn configured_coreml_model_format_name() -> String {
    let requested = std::env::var("CODELENS_EMBED_COREML_MODEL_FORMAT")
        .ok()
        .map(|raw| raw.trim().to_ascii_lowercase());
    if matches!(
        requested.as_deref(),
        Some("neuralnetwork" | "neural_network")
    ) {
        "neural_network".to_string()
    } else {
        "mlprogram".to_string()
    }
}
380
/// Whether CoreML should profile its compute plan
/// (`CODELENS_EMBED_COREML_PROFILE_PLAN`, default off).
#[cfg(target_os = "macos")]
fn configured_coreml_profile_compute_plan() -> bool {
    parse_bool_env("CODELENS_EMBED_COREML_PROFILE_PLAN").unwrap_or(false)
}
385
/// Whether to request static input shapes from CoreML
/// (`CODELENS_EMBED_COREML_STATIC_INPUT_SHAPES`, default on).
#[cfg(target_os = "macos")]
fn configured_coreml_static_input_shapes() -> bool {
    parse_bool_env("CODELENS_EMBED_COREML_STATIC_INPUT_SHAPES").unwrap_or(true)
}
390
/// CoreML graph-specialization strategy: `fast_prediction` unless
/// `CODELENS_EMBED_COREML_SPECIALIZATION` explicitly asks for `default`.
#[cfg(target_os = "macos")]
fn configured_coreml_specialization_strategy_name() -> String {
    let requested = std::env::var("CODELENS_EMBED_COREML_SPECIALIZATION")
        .ok()
        .map(|raw| raw.trim().to_ascii_lowercase());
    if requested.as_deref() == Some("default") {
        "default".to_string()
    } else {
        "fast_prediction".to_string()
    }
}
402
/// Directory where CoreML may cache the compiled model, rooted at the user's
/// home (or the temp dir when `HOME` is unset).
#[cfg(target_os = "macos")]
fn configured_coreml_model_cache_dir() -> std::path::PathBuf {
    dirs_fallback()
        .unwrap_or_else(std::env::temp_dir)
        .join(".cache")
        .join("codelens")
        .join("coreml-cache")
        .join("codesearch")
}
412
413fn recommended_embed_threads() -> usize {
414 if let Some(explicit) = parse_usize_env("CODELENS_EMBED_THREADS") {
415 return explicit.max(1);
416 }
417
418 let available = available_parallelism().map(|n| n.get()).unwrap_or(1);
419 if cfg!(target_os = "macos") {
420 apple_perf_cores()
421 .unwrap_or(available)
422 .min(available)
423 .clamp(1, 8)
424 } else {
425 available.div_ceil(2).clamp(1, 8)
426 }
427}
428
429fn embed_batch_size() -> usize {
430 parse_usize_env("CODELENS_EMBED_BATCH_SIZE").unwrap_or({
431 if cfg!(target_os = "macos") {
432 DEFAULT_MACOS_EMBED_BATCH_SIZE
433 } else {
434 DEFAULT_EMBED_BATCH_SIZE
435 }
436 })
437}
438
/// Upper bound on symbols embedded in one full pass
/// (`CODELENS_MAX_EMBED_SYMBOLS`, default 50k).
fn max_embed_symbols() -> usize {
    parse_usize_env("CODELENS_MAX_EMBED_SYMBOLS").unwrap_or(DEFAULT_MAX_EMBED_SYMBOLS)
}
442
/// Sets process environment variable `name` to `value` only when it is not
/// already present, so user-provided settings always win.
fn set_env_if_unset(name: &str, value: impl Into<String>) {
    if std::env::var_os(name).is_none() {
        // SAFETY: NOTE(review) — `set_var` requires no concurrent reads of
        // the environment; callers invoke this from
        // `configure_embedding_runtime` during startup. Confirm no other
        // threads touch the environment at that point.
        unsafe {
            std::env::set_var(name, value.into());
        }
    }
}
452
/// One-time process-level tuning performed before loading the embedding model.
///
/// Seeds threading-related environment variables (only when the user has not
/// set them) and commits the ONNX Runtime global thread pool exactly once via
/// `ORT_ENV_INIT`.
fn configure_embedding_runtime() {
    let threads = recommended_embed_threads();
    let runtime_preference = configured_embedding_runtime_preference();

    // Tame OpenMP/tokenizers threading; explicit user settings always win.
    set_env_if_unset("OMP_NUM_THREADS", threads.to_string());
    set_env_if_unset("OMP_WAIT_POLICY", "PASSIVE");
    set_env_if_unset("OMP_DYNAMIC", "FALSE");
    set_env_if_unset("TOKENIZERS_PARALLELISM", "false");
    if cfg!(target_os = "macos") {
        set_env_if_unset("VECLIB_MAXIMUM_THREADS", threads.to_string());
    }

    // The ort global environment can only be committed once per process.
    ORT_ENV_INIT.call_once(|| {
        let pool = ort::environment::GlobalThreadPoolOptions::default()
            .with_intra_threads(threads)
            .and_then(|pool| pool.with_inter_threads(1))
            .and_then(|pool| pool.with_spin_control(false));

        // Best-effort: on failure ort falls back to per-session defaults.
        if let Ok(pool) = pool {
            let _ = ort::init()
                .with_name("codelens-embedding")
                .with_telemetry(false)
                .with_global_thread_pool(pool)
                .commit();
        }
    });

    debug!(
        threads,
        runtime_preference = %runtime_preference,
        "configured embedding runtime"
    );
}
488
/// Runtime configuration snapshot computed before any model is loaded;
/// `backend` is therefore reported as `"not_loaded"`. On macOS the CoreML
/// settings are included whenever CoreML is at least preferred.
pub fn configured_embedding_runtime_info() -> EmbeddingRuntimeInfo {
    let runtime_preference = configured_embedding_runtime_preference();
    let threads = configured_embedding_threads();

    #[cfg(target_os = "macos")]
    {
        // "coreml" and "coreml_preferred" both imply CoreML settings apply.
        let coreml_enabled = runtime_preference != "cpu";
        EmbeddingRuntimeInfo {
            runtime_preference,
            backend: "not_loaded".to_string(),
            threads,
            max_length: configured_embedding_max_length(),
            coreml_model_format: coreml_enabled.then(configured_coreml_model_format_name),
            coreml_compute_units: coreml_enabled.then(configured_coreml_compute_units_name),
            coreml_static_input_shapes: coreml_enabled.then(configured_coreml_static_input_shapes),
            coreml_profile_compute_plan: coreml_enabled
                .then(configured_coreml_profile_compute_plan),
            coreml_specialization_strategy: coreml_enabled
                .then(configured_coreml_specialization_strategy_name),
            coreml_model_cache_dir: coreml_enabled
                .then(|| configured_coreml_model_cache_dir().display().to_string()),
            fallback_reason: None,
        }
    }

    #[cfg(not(target_os = "macos"))]
    {
        EmbeddingRuntimeInfo {
            runtime_preference,
            backend: "not_loaded".to_string(),
            threads,
            max_length: configured_embedding_max_length(),
            coreml_model_format: None,
            coreml_compute_units: None,
            coreml_static_input_shapes: None,
            coreml_profile_compute_plan: None,
            coreml_specialization_strategy: None,
            coreml_model_cache_dir: None,
            fallback_reason: None,
        }
    }
}
531
/// Builds the CoreML execution provider from the `CODELENS_EMBED_COREML_*`
/// settings.
///
/// `error_on_failure()` makes session creation fail outright (instead of ort
/// silently running on CPU) so the caller can fall back to CPU explicitly and
/// record the reason.
#[cfg(target_os = "macos")]
fn build_coreml_execution_provider() -> ExecutionProviderDispatch {
    use ort::ep::{
        CoreML,
        coreml::{ComputeUnits, ModelFormat, SpecializationStrategy},
    };

    let compute_units = match configured_coreml_compute_units_name().as_str() {
        "all" => ComputeUnits::All,
        "cpu_only" => ComputeUnits::CPUOnly,
        "cpu_and_gpu" => ComputeUnits::CPUAndGPU,
        _ => ComputeUnits::CPUAndNeuralEngine,
    };
    let model_format = match configured_coreml_model_format_name().as_str() {
        "neural_network" => ModelFormat::NeuralNetwork,
        _ => ModelFormat::MLProgram,
    };
    let specialization = match configured_coreml_specialization_strategy_name().as_str() {
        "default" => SpecializationStrategy::Default,
        _ => SpecializationStrategy::FastPrediction,
    };
    let cache_dir = configured_coreml_model_cache_dir();
    // Best-effort: a missing cache dir only disables model caching.
    let _ = std::fs::create_dir_all(&cache_dir);

    CoreML::default()
        .with_model_format(model_format)
        .with_compute_units(compute_units)
        .with_static_input_shapes(configured_coreml_static_input_shapes())
        .with_specialization_strategy(specialization)
        .with_profile_compute_plan(configured_coreml_profile_compute_plan())
        .with_model_cache_dir(cache_dir.display().to_string())
        .build()
        .error_on_failure()
}
566
/// Runtime info for a CPU-backed session; `fallback_reason` is populated when
/// CPU was not the first choice.
fn cpu_runtime_info(
    runtime_preference: String,
    fallback_reason: Option<String>,
) -> EmbeddingRuntimeInfo {
    EmbeddingRuntimeInfo {
        runtime_preference,
        backend: "cpu".to_string(),
        threads: configured_embedding_threads(),
        max_length: configured_embedding_max_length(),
        coreml_model_format: None,
        coreml_compute_units: None,
        coreml_static_input_shapes: None,
        coreml_profile_compute_plan: None,
        coreml_specialization_strategy: None,
        coreml_model_cache_dir: None,
        fallback_reason,
    }
}
585
/// Runtime info for a CoreML-preferred load. When `fallback_reason` is set,
/// the CoreML session failed and CPU is reported as the backend, but the
/// CoreML settings are still included for diagnostics.
#[cfg(target_os = "macos")]
fn coreml_runtime_info(
    runtime_preference: String,
    fallback_reason: Option<String>,
) -> EmbeddingRuntimeInfo {
    EmbeddingRuntimeInfo {
        runtime_preference,
        backend: if fallback_reason.is_some() {
            "cpu".to_string()
        } else {
            "coreml".to_string()
        },
        threads: configured_embedding_threads(),
        max_length: configured_embedding_max_length(),
        coreml_model_format: Some(configured_coreml_model_format_name()),
        coreml_compute_units: Some(configured_coreml_compute_units_name()),
        coreml_static_input_shapes: Some(configured_coreml_static_input_shapes()),
        coreml_profile_compute_plan: Some(configured_coreml_profile_compute_plan()),
        coreml_specialization_strategy: Some(configured_coreml_specialization_strategy_name()),
        coreml_model_cache_dir: Some(configured_coreml_model_cache_dir().display().to_string()),
        fallback_reason,
    }
}
609
/// Loads the CodeSearchNet embedding model from disk.
///
/// Configures the process runtime, reads the five model asset files, and on
/// macOS (when CoreML is at least preferred) first tries a CoreML-backed
/// session, falling back to CPU — recording the failure reason — if that
/// load fails. Returns the model, its output dimension, its name, and the
/// effective runtime info.
///
/// # Errors
/// Fails when the model directory cannot be resolved, an asset file cannot be
/// read, or the final (CPU) session cannot be created.
fn load_codesearch_model() -> Result<(TextEmbedding, usize, String, EmbeddingRuntimeInfo)> {
    configure_embedding_runtime();
    let model_dir = resolve_model_dir()?;

    let onnx_bytes =
        std::fs::read(model_dir.join("model.onnx")).context("failed to read model.onnx")?;
    let tokenizer_bytes =
        std::fs::read(model_dir.join("tokenizer.json")).context("failed to read tokenizer.json")?;
    let config_bytes =
        std::fs::read(model_dir.join("config.json")).context("failed to read config.json")?;
    let special_tokens_bytes = std::fs::read(model_dir.join("special_tokens_map.json"))
        .context("failed to read special_tokens_map.json")?;
    let tokenizer_config_bytes = std::fs::read(model_dir.join("tokenizer_config.json"))
        .context("failed to read tokenizer_config.json")?;

    let user_model = UserDefinedEmbeddingModel::new(
        onnx_bytes,
        TokenizerFiles {
            tokenizer_file: tokenizer_bytes,
            config_file: config_bytes,
            special_tokens_map_file: special_tokens_bytes,
            tokenizer_config_file: tokenizer_config_bytes,
        },
    );

    let runtime_preference = configured_embedding_runtime_preference();

    // CoreML path: try the accelerated session first; on failure fall back to
    // CPU and surface the reason in the runtime info.
    #[cfg(target_os = "macos")]
    if runtime_preference != "cpu" {
        let init_opts = InitOptionsUserDefined::new()
            .with_max_length(configured_embedding_max_length())
            .with_execution_providers(vec![build_coreml_execution_provider()]);
        match TextEmbedding::try_new_from_user_defined(user_model.clone(), init_opts) {
            Ok(model) => {
                let runtime_info = coreml_runtime_info(runtime_preference.clone(), None);
                debug!(
                    threads = runtime_info.threads,
                    runtime_preference = %runtime_info.runtime_preference,
                    backend = %runtime_info.backend,
                    coreml_compute_units = ?runtime_info.coreml_compute_units,
                    coreml_static_input_shapes = ?runtime_info.coreml_static_input_shapes,
                    coreml_profile_compute_plan = ?runtime_info.coreml_profile_compute_plan,
                    coreml_specialization_strategy = ?runtime_info.coreml_specialization_strategy,
                    coreml_model_cache_dir = ?runtime_info.coreml_model_cache_dir,
                    "loaded CodeSearchNet embedding model"
                );
                return Ok((
                    model,
                    CODESEARCH_DIMENSION,
                    CODESEARCH_MODEL_NAME.to_string(),
                    runtime_info,
                ));
            }
            Err(err) => {
                let reason = err.to_string();
                debug!(
                    runtime_preference = %runtime_preference,
                    fallback_reason = %reason,
                    "CoreML embedding load failed; falling back to CPU"
                );
                let model = TextEmbedding::try_new_from_user_defined(
                    user_model,
                    InitOptionsUserDefined::new()
                        .with_max_length(configured_embedding_max_length()),
                )
                .context("failed to load CodeSearchNet embedding model")?;
                let runtime_info = coreml_runtime_info(runtime_preference.clone(), Some(reason));
                debug!(
                    threads = runtime_info.threads,
                    runtime_preference = %runtime_info.runtime_preference,
                    backend = %runtime_info.backend,
                    coreml_compute_units = ?runtime_info.coreml_compute_units,
                    coreml_static_input_shapes = ?runtime_info.coreml_static_input_shapes,
                    coreml_profile_compute_plan = ?runtime_info.coreml_profile_compute_plan,
                    coreml_specialization_strategy = ?runtime_info.coreml_specialization_strategy,
                    coreml_model_cache_dir = ?runtime_info.coreml_model_cache_dir,
                    fallback_reason = ?runtime_info.fallback_reason,
                    "loaded CodeSearchNet embedding model"
                );
                return Ok((
                    model,
                    CODESEARCH_DIMENSION,
                    CODESEARCH_MODEL_NAME.to_string(),
                    runtime_info,
                ));
            }
        }
    }

    // CPU path (non-macOS, or CPU explicitly requested).
    let model = TextEmbedding::try_new_from_user_defined(
        user_model,
        InitOptionsUserDefined::new().with_max_length(configured_embedding_max_length()),
    )
    .context("failed to load CodeSearchNet embedding model")?;
    let runtime_info = cpu_runtime_info(runtime_preference.clone(), None);

    debug!(
        threads = runtime_info.threads,
        runtime_preference = %runtime_info.runtime_preference,
        backend = %runtime_info.backend,
        "loaded CodeSearchNet embedding model"
    );

    Ok((
        model,
        CODESEARCH_DIMENSION,
        CODESEARCH_MODEL_NAME.to_string(),
        runtime_info,
    ))
}
721
722pub fn configured_embedding_model_name() -> String {
723 std::env::var("CODELENS_EMBED_MODEL").unwrap_or_else(|_| CODESEARCH_MODEL_NAME.to_string())
724}
725
/// True when the CodeSearchNet model assets can be located on disk.
pub fn embedding_model_assets_available() -> bool {
    resolve_model_dir().is_ok()
}
729
730impl EmbeddingEngine {
    /// Embeds `texts`, serving repeats from the LRU cache and embedding only
    /// cache misses — each distinct missing text once — in a single model call.
    ///
    /// The cache lock is released before the model lock is taken; the model
    /// guard is a temporary held only for the `embed` call.
    ///
    /// # Errors
    /// Fails on a poisoned lock or when the model's embed call fails.
    fn embed_texts_cached(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        // resolved[i] receives the embedding of texts[i]; missing_order keeps
        // the first-seen order of distinct cache misses, and missing_positions
        // maps each missing text to every index where it occurred.
        let mut resolved: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        let mut missing_order: Vec<String> = Vec::new();
        let mut missing_positions: HashMap<String, Vec<usize>> = HashMap::new();

        {
            let mut cache = self
                .text_embed_cache
                .lock()
                .map_err(|_| anyhow::anyhow!("text embedding cache lock"))?;
            for (index, text) in texts.iter().enumerate() {
                if let Some(cached) = cache.get(text) {
                    resolved[index] = Some(cached);
                } else {
                    let key = (*text).to_owned();
                    if !missing_positions.contains_key(&key) {
                        missing_order.push(key.clone());
                    }
                    missing_positions.entry(key).or_default().push(index);
                }
            }
        }

        if !missing_order.is_empty() {
            let missing_refs: Vec<&str> = missing_order.iter().map(String::as_str).collect();
            let embeddings = self
                .model
                .lock()
                .map_err(|_| anyhow::anyhow!("model lock"))?
                .embed(missing_refs, None)
                .context("text embedding failed")?;

            // Write the fresh embeddings back into the cache and into every
            // position where the corresponding text occurred.
            let mut cache = self
                .text_embed_cache
                .lock()
                .map_err(|_| anyhow::anyhow!("text embedding cache lock"))?;
            for (text, embedding) in missing_order.into_iter().zip(embeddings.into_iter()) {
                cache.insert(text.clone(), embedding.clone());
                if let Some(indices) = missing_positions.remove(&text) {
                    for index in indices {
                        resolved[index] = Some(embedding.clone());
                    }
                }
            }
        }

        // Every slot must have been filled by the cache or the embed call.
        resolved
            .into_iter()
            .map(|item| item.ok_or_else(|| anyhow::anyhow!("missing embedding cache entry")))
            .collect()
    }
786
    /// Loads the embedding model and opens (creating if necessary) the
    /// project's embedding database at `.codelens/index/embeddings.db`.
    ///
    /// # Errors
    /// Fails when model assets are missing, the index directory cannot be
    /// created, or the vector store cannot be opened.
    pub fn new(project: &ProjectRoot) -> Result<Self> {
        let (model, dimension, model_name, runtime_info) = load_codesearch_model()?;

        let db_dir = project.as_path().join(".codelens/index");
        std::fs::create_dir_all(&db_dir)?;
        let db_path = db_dir.join("embeddings.db");

        let store = SqliteVecStore::new(&db_path, dimension, &model_name)?;

        Ok(Self {
            model: Mutex::new(model),
            store: Box::new(store),
            model_name,
            runtime_info,
            text_embed_cache: Mutex::new(TextEmbeddingCache::new(
                configured_embedding_text_cache_size(),
            )),
            indexing: std::sync::atomic::AtomicBool::new(false),
        })
    }
807
    /// Name of the loaded embedding model.
    pub fn model_name(&self) -> &str {
        &self.model_name
    }
811
    /// Backend/threading configuration the model was actually loaded with.
    pub fn runtime_info(&self) -> &EmbeddingRuntimeInfo {
        &self.runtime_info
    }
815
    /// Whether a full indexing pass is currently in progress.
    pub fn is_indexing(&self) -> bool {
        self.indexing.load(std::sync::atomic::Ordering::Relaxed)
    }
825
    /// Full reconciliation pass over the project's symbol index.
    ///
    /// Embeds every non-test symbol, reusing stored vectors whose reuse key is
    /// unchanged, and prunes embeddings for files no longer present in the
    /// symbol DB. Returns the number of rows written to the store.
    ///
    /// Concurrent runs are rejected via the `indexing` flag, which a drop
    /// guard clears on every exit path. Once the `max_embed_symbols()` budget
    /// would be exceeded, no further files are embedded (but files are still
    /// recorded as present so their embeddings are not pruned).
    pub fn index_from_project(&self, project: &ProjectRoot) -> Result<usize> {
        // Claim the single-run slot; losing the CAS means a run is active.
        if self
            .indexing
            .compare_exchange(
                false,
                true,
                std::sync::atomic::Ordering::AcqRel,
                std::sync::atomic::Ordering::Relaxed,
            )
            .is_err()
        {
            anyhow::bail!(
                "Embedding indexing already in progress — wait for the current run to complete before retrying."
            );
        }
        // Clears the flag even on early returns and panics.
        struct IndexGuard<'a>(&'a std::sync::atomic::AtomicBool);
        impl Drop for IndexGuard<'_> {
            fn drop(&mut self) {
                self.0.store(false, std::sync::atomic::Ordering::Release);
            }
        }
        let _guard = IndexGuard(&self.indexing);

        let db_path = crate::db::index_db_path(project.as_path());
        let symbol_db = IndexDb::open(&db_path)?;
        let batch_size = embed_batch_size();
        let max_symbols = max_embed_symbols();
        let mut total_indexed = 0usize;
        let mut total_seen = 0usize;
        // Model lock acquired lazily on first batch, then held for the pass.
        let mut model = None;
        // file path → (reuse key → stored chunk); entries are consumed as
        // files are reconciled, so leftovers are candidates for pruning.
        let mut existing_embeddings: HashMap<
            String,
            HashMap<ReusableEmbeddingKey, EmbeddingChunk>,
        > = HashMap::new();
        let mut current_db_files = HashSet::new();
        let mut capped = false;

        self.store
            .for_each_file_embeddings(&mut |file_path, chunks| {
                existing_embeddings.insert(
                    file_path,
                    chunks
                        .into_iter()
                        .map(|chunk| (reusable_embedding_key_for_chunk(&chunk), chunk))
                        .collect(),
                );
                Ok(())
            })?;

        symbol_db.for_each_file_symbols_with_bytes(|file_path, symbols| {
            current_db_files.insert(file_path.clone());
            if capped {
                return Ok(());
            }

            let source = std::fs::read_to_string(project.as_path().join(&file_path)).ok();
            let relevant_symbols: Vec<_> = symbols
                .into_iter()
                .filter(|sym| !is_test_only_symbol(sym, source.as_deref()))
                .collect();

            // A file with only test symbols loses its stored embeddings.
            if relevant_symbols.is_empty() {
                self.store.delete_by_file(&[file_path.as_str()])?;
                existing_embeddings.remove(&file_path);
                return Ok(());
            }

            // Stop embedding (not scanning) once the budget would be blown.
            if total_seen + relevant_symbols.len() > max_symbols {
                capped = true;
                return Ok(());
            }
            total_seen += relevant_symbols.len();

            let existing_for_file = existing_embeddings.remove(&file_path).unwrap_or_default();
            total_indexed += self.reconcile_file_embeddings(
                &file_path,
                relevant_symbols,
                source.as_deref(),
                existing_for_file,
                batch_size,
                &mut model,
            )?;
            Ok(())
        })?;

        // Prune embeddings for files that vanished from the symbol DB.
        let removed_files: Vec<String> = existing_embeddings
            .into_keys()
            .filter(|file_path| !current_db_files.contains(file_path))
            .collect();
        if !removed_files.is_empty() {
            let removed_refs: Vec<&str> = removed_files.iter().map(String::as_str).collect();
            self.store.delete_by_file(&removed_refs)?;
        }

        Ok(total_indexed)
    }
924
    /// Re-embeds one file's symbols, reusing stored embeddings whose reuse key
    /// is unchanged; stale rows for `file_path` are deleted before the
    /// reconciled chunks are inserted. Returns the number of rows inserted.
    ///
    /// The model lock is acquired lazily into `model` only when a batch
    /// actually needs embedding, and is then kept by the caller across files.
    fn reconcile_file_embeddings<'a>(
        &'a self,
        file_path: &str,
        symbols: Vec<crate::db::SymbolWithFile>,
        source: Option<&str>,
        mut existing_embeddings: HashMap<ReusableEmbeddingKey, EmbeddingChunk>,
        batch_size: usize,
        model: &mut Option<std::sync::MutexGuard<'a, TextEmbedding>>,
    ) -> Result<usize> {
        let mut reconciled_chunks = Vec::with_capacity(symbols.len());
        let mut batch_texts: Vec<String> = Vec::with_capacity(batch_size);
        let mut batch_meta: Vec<crate::db::SymbolWithFile> = Vec::with_capacity(batch_size);

        for sym in symbols {
            let text = build_embedding_text(&sym, source);
            // Unchanged symbol: carry the stored vectors over verbatim.
            if let Some(existing) =
                existing_embeddings.remove(&reusable_embedding_key_for_symbol(&sym, &text))
            {
                reconciled_chunks.push(EmbeddingChunk {
                    file_path: sym.file_path.clone(),
                    symbol_name: sym.name.clone(),
                    kind: sym.kind.clone(),
                    line: sym.line as usize,
                    signature: sym.signature.clone(),
                    name_path: sym.name_path.clone(),
                    text,
                    embedding: existing.embedding,
                    doc_embedding: existing.doc_embedding,
                });
                continue;
            }

            batch_texts.push(text);
            batch_meta.push(sym);

            if batch_texts.len() >= batch_size {
                if model.is_none() {
                    *model = Some(
                        self.model
                            .lock()
                            .map_err(|_| anyhow::anyhow!("model lock"))?,
                    );
                }
                reconciled_chunks.extend(Self::embed_chunks(
                    model.as_mut().expect("model lock initialized"),
                    &batch_texts,
                    &batch_meta,
                )?);
                batch_texts.clear();
                batch_meta.clear();
            }
        }

        // Flush the final partial batch.
        if !batch_texts.is_empty() {
            if model.is_none() {
                *model = Some(
                    self.model
                        .lock()
                        .map_err(|_| anyhow::anyhow!("model lock"))?,
                );
            }
            reconciled_chunks.extend(Self::embed_chunks(
                model.as_mut().expect("model lock initialized"),
                &batch_texts,
                &batch_meta,
            )?);
        }

        // Replace the file's rows atomically from this function's view:
        // delete old, then insert the reconciled set.
        self.store.delete_by_file(&[file_path])?;
        if reconciled_chunks.is_empty() {
            return Ok(0);
        }
        self.store.insert(&reconciled_chunks)
    }
999
1000 fn embed_chunks(
1001 model: &mut TextEmbedding,
1002 texts: &[String],
1003 meta: &[crate::db::SymbolWithFile],
1004 ) -> Result<Vec<EmbeddingChunk>> {
1005 let batch_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
1006 let embeddings = model.embed(batch_refs, None).context("embedding failed")?;
1007
1008 Ok(meta
1009 .iter()
1010 .zip(embeddings)
1011 .zip(texts.iter())
1012 .map(|((sym, emb), text)| EmbeddingChunk {
1013 file_path: sym.file_path.clone(),
1014 symbol_name: sym.name.clone(),
1015 kind: sym.kind.clone(),
1016 line: sym.line as usize,
1017 signature: sym.signature.clone(),
1018 name_path: sym.name_path.clone(),
1019 text: text.clone(),
1020 embedding: emb,
1021 doc_embedding: None,
1022 })
1023 .collect())
1024 }
1025
    /// Embeds one pending batch and writes the resulting chunks to `store`,
    /// returning the number of rows inserted.
    /// NOTE(review): not called within this chunk — presumably used by the
    /// incremental indexing path further down the file.
    fn flush_batch(
        model: &mut TextEmbedding,
        store: &dyn EmbeddingStore,
        texts: &[String],
        meta: &[crate::db::SymbolWithFile],
    ) -> Result<usize> {
        let chunks = Self::embed_chunks(model, texts, meta)?;
        store.insert(&chunks)
    }
1036
1037 pub fn search(&self, query: &str, max_results: usize) -> Result<Vec<SemanticMatch>> {
1039 let results = self.search_scored(query, max_results)?;
1040 Ok(results.into_iter().map(SemanticMatch::from).collect())
1041 }
1042
    /// Vector search with a lightweight lexical re-rank.
    ///
    /// Over-fetches up to 3× `max_results` candidates from the store, then —
    /// when there are more candidates than requested — blends each vector
    /// score with the fraction of query tokens found in the symbol name,
    /// signature, and path (80% / 20%) before sorting and truncating.
    pub fn search_scored(&self, query: &str, max_results: usize) -> Result<Vec<ScoredChunk>> {
        let query_embedding = self.embed_texts_cached(&[query])?;

        if query_embedding.is_empty() {
            return Ok(Vec::new());
        }

        let candidate_count = max_results.saturating_mul(3).max(max_results);
        let mut candidates = self.store.search(&query_embedding[0], candidate_count)?;

        // Nothing to re-rank when the store returned few enough hits.
        if candidates.len() <= max_results {
            return Ok(candidates);
        }

        // Tokenize on whitespace, underscores, and hyphens; drop 1-char tokens.
        let query_lower = query.to_lowercase();
        let query_tokens: Vec<&str> = query_lower
            .split(|c: char| c.is_whitespace() || c == '_' || c == '-')
            .filter(|t| t.len() >= 2)
            .collect();

        if query_tokens.is_empty() {
            candidates.truncate(max_results);
            return Ok(candidates);
        }

        for chunk in &mut candidates {
            let searchable = format!(
                "{} {} {}",
                chunk.symbol_name.to_lowercase(),
                chunk.signature.to_lowercase(),
                chunk.file_path.to_lowercase(),
            );
            let overlap = query_tokens
                .iter()
                .filter(|t| searchable.contains(**t))
                .count() as f64;
            let overlap_ratio = overlap / query_tokens.len().max(1) as f64;
            // Blend: 80% vector similarity, 20% lexical token overlap.
            chunk.score = chunk.score * 0.8 + overlap_ratio * 0.2;
        }

        candidates.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
        candidates.truncate(max_results);
        Ok(candidates)
    }
1098
    /// Incrementally re-embed the symbols of `changed_files`.
    ///
    /// Existing embeddings for those files are snapshotted first so symbols
    /// whose embedding text is unchanged can be re-inserted without running
    /// the model again; everything else is embedded via `flush_batch`.
    ///
    /// Returns the total number of chunks written to the store.
    pub fn index_changed_files(
        &self,
        project: &ProjectRoot,
        changed_files: &[&str],
    ) -> Result<usize> {
        if changed_files.is_empty() {
            return Ok(0);
        }
        let batch_size = embed_batch_size();
        // Snapshot current embeddings for the changed files, keyed so an
        // unchanged (symbol, text) pair can be matched up again below.
        let mut existing_embeddings: HashMap<ReusableEmbeddingKey, EmbeddingChunk> = HashMap::new();
        for file_chunk in changed_files.chunks(CHANGED_FILE_QUERY_CHUNK) {
            for chunk in self.store.embeddings_for_files(file_chunk)? {
                existing_embeddings.insert(reusable_embedding_key_for_chunk(&chunk), chunk);
            }
        }
        // Drop all rows for the changed files: reusable ones are re-inserted
        // from the snapshot, the rest are re-embedded fresh.
        self.store.delete_by_file(changed_files)?;

        let db_path = crate::db::index_db_path(project.as_path());
        let symbol_db = IndexDb::open(&db_path)?;

        let mut total_indexed = 0usize;
        let mut batch_texts: Vec<String> = Vec::with_capacity(batch_size);
        let mut batch_meta: Vec<crate::db::SymbolWithFile> = Vec::with_capacity(batch_size);
        let mut batch_reused: Vec<EmbeddingChunk> = Vec::with_capacity(batch_size);
        // Per-file source cache; `None` = unreadable file (the symbol is still
        // indexed, just without source-derived hints).
        let mut file_cache: std::collections::HashMap<String, Option<String>> =
            std::collections::HashMap::new();
        // The embedding model is locked lazily so a fully-reused run never
        // touches (or loads) it.
        let mut model = None;

        for file_chunk in changed_files.chunks(CHANGED_FILE_QUERY_CHUNK) {
            let relevant = symbol_db.symbols_for_files(file_chunk)?;
            for sym in relevant {
                let source = file_cache.entry(sym.file_path.clone()).or_insert_with(|| {
                    std::fs::read_to_string(project.as_path().join(&sym.file_path)).ok()
                });
                // Test-only code is deliberately excluded from the index.
                if is_test_only_symbol(&sym, source.as_deref()) {
                    continue;
                }
                let text = build_embedding_text(&sym, source.as_deref());
                // Reuse path: identical embedding text -> keep the old vectors.
                if let Some(existing) =
                    existing_embeddings.remove(&reusable_embedding_key_for_symbol(&sym, &text))
                {
                    batch_reused.push(EmbeddingChunk {
                        file_path: sym.file_path.clone(),
                        symbol_name: sym.name.clone(),
                        kind: sym.kind.clone(),
                        line: sym.line as usize,
                        signature: sym.signature.clone(),
                        name_path: sym.name_path.clone(),
                        text,
                        embedding: existing.embedding,
                        doc_embedding: existing.doc_embedding,
                    });
                    if batch_reused.len() >= batch_size {
                        total_indexed += self.store.insert(&batch_reused)?;
                        batch_reused.clear();
                    }
                    continue;
                }
                // Fresh path: queue for model embedding.
                batch_texts.push(text);
                batch_meta.push(sym);

                if batch_texts.len() >= batch_size {
                    if model.is_none() {
                        model = Some(
                            self.model
                                .lock()
                                .map_err(|_| anyhow::anyhow!("model lock"))?,
                        );
                    }
                    total_indexed += Self::flush_batch(
                        model.as_mut().expect("model lock initialized"),
                        &*self.store,
                        &batch_texts,
                        &batch_meta,
                    )?;
                    batch_texts.clear();
                    batch_meta.clear();
                }
            }
        }

        // Flush whatever is left in either partial batch.
        if !batch_reused.is_empty() {
            total_indexed += self.store.insert(&batch_reused)?;
        }

        if !batch_texts.is_empty() {
            if model.is_none() {
                model = Some(
                    self.model
                        .lock()
                        .map_err(|_| anyhow::anyhow!("model lock"))?,
                );
            }
            total_indexed += Self::flush_batch(
                model.as_mut().expect("model lock initialized"),
                &*self.store,
                &batch_texts,
                &batch_meta,
            )?;
        }

        Ok(total_indexed)
    }
1203
1204 pub fn is_indexed(&self) -> bool {
1206 self.store.count().unwrap_or(0) > 0
1207 }
1208
1209 pub fn index_info(&self) -> EmbeddingIndexInfo {
1210 EmbeddingIndexInfo {
1211 model_name: self.model_name.clone(),
1212 indexed_symbols: self.store.count().unwrap_or(0),
1213 }
1214 }
1215
1216 pub fn inspect_existing_index(project: &ProjectRoot) -> Result<Option<EmbeddingIndexInfo>> {
1217 let db_path = project.as_path().join(".codelens/index/embeddings.db");
1218 if !db_path.exists() {
1219 return Ok(None);
1220 }
1221
1222 let conn =
1223 crate::db::open_derived_sqlite_with_recovery(&db_path, "embedding index", || {
1224 ffi::register_sqlite_vec()?;
1225 let conn = Connection::open(&db_path)?;
1226 conn.execute_batch("PRAGMA busy_timeout=5000;")?;
1227 conn.query_row("PRAGMA schema_version", [], |_row| Ok(()))?;
1228 Ok(conn)
1229 })?;
1230
1231 let model_name: Option<String> = conn
1232 .query_row(
1233 "SELECT value FROM meta WHERE key = 'model' LIMIT 1",
1234 [],
1235 |row| row.get(0),
1236 )
1237 .ok();
1238 let indexed_symbols: usize = conn
1239 .query_row("SELECT COUNT(*) FROM symbols", [], |row| {
1240 row.get::<_, i64>(0)
1241 })
1242 .map(|count| count.max(0) as usize)
1243 .unwrap_or(0);
1244
1245 Ok(model_name.map(|model_name| EmbeddingIndexInfo {
1246 model_name,
1247 indexed_symbols,
1248 }))
1249 }
1250
1251 pub fn find_similar_code(
1255 &self,
1256 file_path: &str,
1257 symbol_name: &str,
1258 max_results: usize,
1259 ) -> Result<Vec<SemanticMatch>> {
1260 let target = self
1261 .store
1262 .get_embedding(file_path, symbol_name)?
1263 .ok_or_else(|| anyhow::anyhow!("Symbol '{}' not found in index", symbol_name))?;
1264
1265 let oversample = max_results.saturating_add(8).max(1);
1266 let scored = self
1267 .store
1268 .search(&target.embedding, oversample)?
1269 .into_iter()
1270 .filter(|c| !(c.file_path == file_path && c.symbol_name == symbol_name))
1271 .take(max_results)
1272 .map(SemanticMatch::from)
1273 .collect();
1274 Ok(scored)
1275 }
1276
    /// Find pairs of indexed symbols whose cosine similarity is at least
    /// `threshold`, up to `max_pairs` pairs, sorted most-similar first.
    ///
    /// The store is streamed in batches. For each chunk its nearest
    /// neighbours are fetched; their full embeddings are resolved through a
    /// cache so the similarity can be recomputed exactly (the search score is
    /// not trusted for the threshold decision).
    pub fn find_duplicates(&self, threshold: f64, max_pairs: usize) -> Result<Vec<DuplicatePair>> {
        let mut pairs = Vec::new();
        // Unordered (file, symbol) pair keys already emitted, so A~B and B~A
        // are reported only once.
        let mut seen_pairs = HashSet::new();
        let mut embedding_cache: HashMap<StoredChunkKey, Arc<EmbeddingChunk>> = HashMap::new();
        let candidate_limit = duplicate_candidate_limit(max_pairs);
        // The batch callback cannot break out of the store iteration, so a
        // flag marks completion and later batches become no-ops.
        let mut done = false;

        self.store
            .for_each_embedding_batch(DEFAULT_DUPLICATE_SCAN_BATCH_SIZE, &mut |batch| {
                if done {
                    return Ok(());
                }

                let mut candidate_lists = Vec::with_capacity(batch.len());
                let mut missing_candidates = Vec::new();
                let mut missing_keys = HashSet::new();

                // Phase 1: collect each chunk's neighbours (excluding the
                // chunk itself) and note which embeddings are not cached yet.
                for chunk in &batch {
                    if pairs.len() >= max_pairs {
                        done = true;
                        break;
                    }

                    let filtered: Vec<ScoredChunk> = self
                        .store
                        .search(&chunk.embedding, candidate_limit)?
                        .into_iter()
                        .filter(|candidate| {
                            !(chunk.file_path == candidate.file_path
                                && chunk.symbol_name == candidate.symbol_name
                                && chunk.line == candidate.line
                                && chunk.signature == candidate.signature
                                && chunk.name_path == candidate.name_path)
                        })
                        .collect();

                    for candidate in &filtered {
                        let cache_key = stored_chunk_key_for_score(candidate);
                        // `missing_keys` prevents queuing the same fetch twice
                        // within this batch.
                        if !embedding_cache.contains_key(&cache_key)
                            && missing_keys.insert(cache_key)
                        {
                            missing_candidates.push(candidate.clone());
                        }
                    }

                    candidate_lists.push(filtered);
                }

                // Phase 2: bulk-load the missing embeddings into the cache.
                if !missing_candidates.is_empty() {
                    for candidate_chunk in self
                        .store
                        .embeddings_for_scored_chunks(&missing_candidates)?
                    {
                        embedding_cache
                            .entry(stored_chunk_key(&candidate_chunk))
                            .or_insert_with(|| Arc::new(candidate_chunk));
                    }
                }

                // Phase 3: score each (chunk, candidate) pair exactly.
                for (chunk, candidates) in batch.iter().zip(candidate_lists.iter()) {
                    if pairs.len() >= max_pairs {
                        done = true;
                        break;
                    }

                    for candidate in candidates {
                        let pair_key = duplicate_pair_key(
                            &chunk.file_path,
                            &chunk.symbol_name,
                            &candidate.file_path,
                            &candidate.symbol_name,
                        );
                        if !seen_pairs.insert(pair_key) {
                            continue;
                        }

                        // A cache miss (e.g. row deleted mid-scan) is skipped
                        // rather than treated as an error.
                        let Some(candidate_chunk) =
                            embedding_cache.get(&stored_chunk_key_for_score(candidate))
                        else {
                            continue;
                        };

                        let sim = cosine_similarity(&chunk.embedding, &candidate_chunk.embedding);
                        if sim < threshold {
                            continue;
                        }

                        pairs.push(DuplicatePair {
                            symbol_a: format!("{}:{}", chunk.file_path, chunk.symbol_name),
                            symbol_b: format!(
                                "{}:{}",
                                candidate_chunk.file_path, candidate_chunk.symbol_name
                            ),
                            file_a: chunk.file_path.clone(),
                            file_b: candidate_chunk.file_path.clone(),
                            line_a: chunk.line,
                            line_b: candidate_chunk.line,
                            similarity: sim,
                        });
                        if pairs.len() >= max_pairs {
                            done = true;
                            break;
                        }
                    }
                }
                Ok(())
            })?;

        pairs.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        Ok(pairs)
    }
1394}
1395
/// How many nearest neighbours to fetch per chunk when scanning for
/// duplicates: four times the requested pair count, bounded to [32, 128].
fn duplicate_candidate_limit(max_pairs: usize) -> usize {
    let oversampled = max_pairs.saturating_mul(4);
    oversampled.max(32).min(128)
}
1399
/// Order-independent key for a pair of (file, symbol) entries, so A~B and
/// B~A deduplicate to the same key.
fn duplicate_pair_key(
    file_a: &str,
    symbol_a: &str,
    file_b: &str,
    symbol_b: &str,
) -> ((String, String), (String, String)) {
    let mut pair = [
        (file_a.to_owned(), symbol_a.to_owned()),
        (file_b.to_owned(), symbol_b.to_owned()),
    ];
    // Lexicographic tuple order; stable sort keeps (a, b) for equal entries.
    pair.sort();
    let [first, second] = pair;
    (first, second)
}
1414
1415type StoredChunkKey = (String, String, usize, String, String);
1416
1417fn stored_chunk_key(chunk: &EmbeddingChunk) -> StoredChunkKey {
1418 (
1419 chunk.file_path.clone(),
1420 chunk.symbol_name.clone(),
1421 chunk.line,
1422 chunk.signature.clone(),
1423 chunk.name_path.clone(),
1424 )
1425}
1426
1427fn stored_chunk_key_for_score(chunk: &ScoredChunk) -> StoredChunkKey {
1428 (
1429 chunk.file_path.clone(),
1430 chunk.symbol_name.clone(),
1431 chunk.line,
1432 chunk.signature.clone(),
1433 chunk.name_path.clone(),
1434 )
1435}
1436
1437impl EmbeddingEngine {
1438 pub fn classify_symbol(
1440 &self,
1441 file_path: &str,
1442 symbol_name: &str,
1443 categories: &[&str],
1444 ) -> Result<Vec<CategoryScore>> {
1445 let target = match self.store.get_embedding(file_path, symbol_name)? {
1446 Some(target) => target,
1447 None => self
1448 .store
1449 .all_with_embeddings()?
1450 .into_iter()
1451 .find(|c| c.file_path == file_path && c.symbol_name == symbol_name)
1452 .ok_or_else(|| anyhow::anyhow!("Symbol '{}' not found in index", symbol_name))?,
1453 };
1454
1455 let embeddings = self.embed_texts_cached(categories)?;
1456
1457 let mut scores: Vec<CategoryScore> = categories
1458 .iter()
1459 .zip(embeddings.iter())
1460 .map(|(cat, emb)| CategoryScore {
1461 category: cat.to_string(),
1462 score: cosine_similarity(&target.embedding, emb),
1463 })
1464 .collect();
1465
1466 scores.sort_by(|a, b| {
1467 b.score
1468 .partial_cmp(&a.score)
1469 .unwrap_or(std::cmp::Ordering::Equal)
1470 });
1471 Ok(scores)
1472 }
1473
1474 pub fn find_misplaced_code(&self, max_results: usize) -> Result<Vec<OutlierSymbol>> {
1476 let mut outliers = Vec::new();
1477
1478 self.store
1479 .for_each_file_embeddings(&mut |file_path, chunks| {
1480 if chunks.len() < 2 {
1481 return Ok(());
1482 }
1483
1484 for (idx, chunk) in chunks.iter().enumerate() {
1485 let mut sim_sum = 0.0;
1486 let mut count = 0;
1487 for (other_idx, other_chunk) in chunks.iter().enumerate() {
1488 if other_idx == idx {
1489 continue;
1490 }
1491 sim_sum += cosine_similarity(&chunk.embedding, &other_chunk.embedding);
1492 count += 1;
1493 }
1494 if count > 0 {
1495 let avg_sim = sim_sum / count as f64; outliers.push(OutlierSymbol {
1497 file_path: file_path.clone(),
1498 symbol_name: chunk.symbol_name.clone(),
1499 kind: chunk.kind.clone(),
1500 line: chunk.line,
1501 avg_similarity_to_file: avg_sim,
1502 });
1503 }
1504 }
1505 Ok(())
1506 })?;
1507
1508 outliers.sort_by(|a, b| {
1509 a.avg_similarity_to_file
1510 .partial_cmp(&b.avg_similarity_to_file)
1511 .unwrap_or(std::cmp::Ordering::Equal)
1512 });
1513 outliers.truncate(max_results);
1514 Ok(outliers)
1515 }
1516}
1517
/// One pair of symbols whose embedding similarity met the duplicate
/// threshold. `symbol_a` / `symbol_b` are "file:symbol" display strings.
#[derive(Debug, Clone, Serialize)]
pub struct DuplicatePair {
    pub symbol_a: String,
    pub symbol_b: String,
    pub file_a: String,
    pub file_b: String,
    pub line_a: usize,
    pub line_b: usize,
    /// Cosine similarity between the two embeddings (higher = more alike).
    pub similarity: f64,
}
1530
/// Similarity of one classification category to a target symbol.
#[derive(Debug, Clone, Serialize)]
pub struct CategoryScore {
    pub category: String,
    /// Cosine similarity between the category text and the symbol embedding.
    pub score: f64,
}
1536
/// A symbol that is semantically unlike the rest of its file.
#[derive(Debug, Clone, Serialize)]
pub struct OutlierSymbol {
    pub file_path: String,
    pub symbol_name: String,
    pub kind: String,
    pub line: usize,
    /// Mean cosine similarity to the file's other symbols (lower = more out of place).
    pub avg_similarity_to_file: f64,
}
1545
/// Cosine similarity between two equal-length embedding vectors.
///
/// Accumulation happens in `f64` (the previous `f32` accumulators lost
/// precision on long vectors and `norm_a * norm_b` could underflow to zero).
/// Returns 0.0 when either vector has zero magnitude, so callers never see
/// NaN/inf from a degenerate embedding.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    debug_assert_eq!(a.len(), b.len());

    let (mut dot, mut norm_a, mut norm_b) = (0.0f64, 0.0f64, 0.0f64);
    for (&x, &y) in a.iter().zip(b.iter()) {
        let (x, y) = (f64::from(x), f64::from(y));
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        // Zero-magnitude (or fully underflowed) vector: similarity undefined,
        // report "no similarity" instead of dividing by zero.
        0.0
    } else {
        dot / denom
    }
}
1570
/// Expand `snake_case` / `camelCase` / acronym-run identifiers into
/// space-separated words (e.g. `parse_json` -> "parse json", `HTTPServer` ->
/// "HTTP Server"). Identifiers that do not split into at least two words are
/// returned unchanged.
fn split_identifier(name: &str) -> String {
    // Fast path: nothing that could mark a word boundary.
    if !name.contains('_') && !name.chars().any(char::is_uppercase) {
        return name.to_string();
    }

    let chars: Vec<char> = name.chars().collect();
    let mut words: Vec<String> = Vec::new();
    let mut word = String::new();

    for (i, &ch) in chars.iter().enumerate() {
        if ch == '_' {
            // Underscores separate words and are never kept.
            if !word.is_empty() {
                words.push(std::mem::take(&mut word));
            }
            continue;
        }

        let prev_is_lower = word.chars().last().map_or(false, char::is_lowercase);
        let next_is_lower = chars.get(i + 1).map_or(false, |c| c.is_lowercase());
        // Boundary before `ch`: a camelCase transition, or the final capital
        // of an acronym run followed by lowercase (HTTPServer -> HTTP | Server).
        if ch.is_uppercase() && !word.is_empty() && (prev_is_lower || next_is_lower) {
            words.push(std::mem::take(&mut word));
        }
        word.push(ch);
    }
    if !word.is_empty() {
        words.push(word);
    }

    if words.len() <= 1 {
        return name.to_string();
    }
    words.join(" ")
}
1624
1625fn is_test_only_symbol(sym: &crate::db::SymbolWithFile, source: Option<&str>) -> bool {
1626 if sym.file_path.contains("/tests/") || sym.file_path.ends_with("_tests.rs") {
1627 return true;
1628 }
1629 if sym.name_path.starts_with("tests::")
1630 || sym.name_path.contains("::tests::")
1631 || sym.name_path.starts_with("test::")
1632 || sym.name_path.contains("::test::")
1633 {
1634 return true;
1635 }
1636
1637 let Some(source) = source else {
1638 return false;
1639 };
1640
1641 let start = usize::try_from(sym.start_byte.max(0))
1642 .unwrap_or(0)
1643 .min(source.len());
1644 let window_start = start.saturating_sub(2048);
1645 let attrs = String::from_utf8_lossy(&source.as_bytes()[window_start..start]);
1646 attrs.contains("#[test]")
1647 || attrs.contains("#[tokio::test]")
1648 || attrs.contains("#[cfg(test)]")
1649 || attrs.contains("#[cfg(all(test")
1650}
1651
/// Build the text that gets embedded for one symbol.
///
/// The base string is "kind name (split name) (in parent) [dir] in file:
/// signature"; when source is available and docstrings are not disabled via
/// `CODELENS_EMBED_DOCSTRINGS=0|false`, a docstring or body hint is appended,
/// followed by optional natural-language and `Type::method` API-call hints
/// (each independently env/auto gated inside its extractor).
fn build_embedding_text(sym: &crate::db::SymbolWithFile, source: Option<&str>) -> String {
    // " in <filename>" — just the basename, the directory goes in module_ctx.
    let file_ctx = if sym.file_path.is_empty() {
        String::new()
    } else {
        let filename = sym.file_path.rsplit('/').next().unwrap_or(&sym.file_path);
        format!(" in {}", filename)
    };

    // Include the word-split form so queries like "parse json" can match
    // `parse_json`.
    let split_name = split_identifier(&sym.name);
    let name_with_split = if split_name != sym.name {
        format!("{} ({})", sym.name, split_name)
    } else {
        sym.name.clone()
    };

    // " (in <parent>)" from the symbol's name path, when it has a parent.
    let parent_ctx = if !sym.name_path.is_empty() && sym.name_path.contains('/') {
        let parent = sym.name_path.rsplit_once('/').map(|x| x.0).unwrap_or("");
        if parent.is_empty() {
            String::new()
        } else {
            format!(" (in {})", parent)
        }
    } else {
        String::new()
    };

    // " [<dir>]" — the file's immediate parent directory, skipping generic
    // `src`/`crates` directories that carry no signal.
    let module_ctx = if sym.file_path.contains('/') {
        let parts: Vec<&str> = sym.file_path.rsplitn(3, '/').collect();
        if parts.len() >= 2 {
            let dir = parts[1];
            if dir != "src" && dir != "crates" {
                format!(" [{dir}]")
            } else {
                String::new()
            }
        } else {
            String::new()
        }
    } else {
        String::new()
    };

    let base = if sym.signature.is_empty() {
        format!(
            "{} {}{}{}{}", sym.kind, name_with_split, parent_ctx, module_ctx, file_ctx
        )
    } else {
        format!(
            "{} {}{}{}{}: {}",
            sym.kind, name_with_split, parent_ctx, module_ctx, file_ctx, sym.signature
        )
    };

    // Opt-out switch: "0" or "false" stops all source-derived enrichment.
    let docstrings_disabled = std::env::var("CODELENS_EMBED_DOCSTRINGS")
        .map(|v| v == "0" || v == "false")
        .unwrap_or(false);

    if docstrings_disabled {
        return base;
    }

    let docstring = source
        .and_then(|src| extract_leading_doc(src, sym.start_byte as usize, sym.end_byte as usize))
        .unwrap_or_default();

    // Prefer a docstring hint; fall back to the first body statements.
    let mut text = if docstring.is_empty() {
        let body_hint = source
            .and_then(|src| extract_body_hint(src, sym.start_byte as usize, sym.end_byte as usize))
            .unwrap_or_default();
        if body_hint.is_empty() {
            base
        } else {
            format!("{} — {}", base, body_hint)
        }
    } else {
        // Budget the docstring to the configured number of non-empty lines.
        let line_budget = hint_line_budget();
        let lines: Vec<String> = docstring
            .lines()
            .map(str::trim)
            .filter(|line| !line.is_empty())
            .take(line_budget)
            .map(str::to_string)
            .collect();
        let hint = join_hint_lines(&lines);
        if hint.is_empty() {
            base
        } else {
            format!("{} — {}", base, hint)
        }
    };

    // Natural-language fragments (comments / prose-like string literals).
    if let Some(src) = source
        && let Some(nl_tokens) =
            extract_nl_tokens(src, sym.start_byte as usize, sym.end_byte as usize)
        && !nl_tokens.is_empty()
    {
        text.push_str(" · NL: ");
        text.push_str(&nl_tokens);
    }

    // `Type::method` calls appearing in the body.
    if let Some(src) = source
        && let Some(api_calls) =
            extract_api_calls(src, sym.start_byte as usize, sym.end_byte as usize)
        && !api_calls.is_empty()
    {
        text.push_str(" · API: ");
        text.push_str(&api_calls);
    }

    text
}
1790
/// Default character budget for joined hint text appended to embedding text
/// (override via `CODELENS_EMBED_HINT_CHARS`, clamped to 60..=512).
const DEFAULT_HINT_TOTAL_CHAR_BUDGET: usize = 60;

/// Default number of hint lines taken from a docstring or body
/// (override via `CODELENS_EMBED_HINT_LINES`, clamped to 1..=10).
const DEFAULT_HINT_LINES: usize = 1;
1808
1809fn hint_char_budget() -> usize {
1810 std::env::var("CODELENS_EMBED_HINT_CHARS")
1811 .ok()
1812 .and_then(|raw| raw.parse::<usize>().ok())
1813 .map(|n| n.clamp(60, 512))
1814 .unwrap_or(DEFAULT_HINT_TOTAL_CHAR_BUDGET)
1815}
1816
1817fn hint_line_budget() -> usize {
1818 std::env::var("CODELENS_EMBED_HINT_LINES")
1819 .ok()
1820 .and_then(|raw| raw.parse::<usize>().ok())
1821 .map(|n| n.clamp(1, 10))
1822 .unwrap_or(DEFAULT_HINT_LINES)
1823}
1824
1825fn join_hint_lines(lines: &[String]) -> String {
1832 if lines.is_empty() {
1833 return String::new();
1834 }
1835 let joined = lines
1836 .iter()
1837 .map(String::as_str)
1838 .collect::<Vec<_>>()
1839 .join(" · ");
1840 let budget = hint_char_budget();
1841 if joined.chars().count() > budget {
1842 let truncated: String = joined.chars().take(budget).collect();
1843 format!("{truncated}...")
1844 } else {
1845 joined
1846 }
1847}
1848
1849fn extract_body_hint(source: &str, start: usize, end: usize) -> Option<String> {
1859 if start >= source.len() || end > source.len() || start >= end {
1860 return None;
1861 }
1862 let safe_start = if source.is_char_boundary(start) {
1863 start
1864 } else {
1865 source.floor_char_boundary(start)
1866 };
1867 let safe_end = end.min(source.len());
1868 let safe_end = if source.is_char_boundary(safe_end) {
1869 safe_end
1870 } else {
1871 source.floor_char_boundary(safe_end)
1872 };
1873 let body = &source[safe_start..safe_end];
1874
1875 let max_lines = hint_line_budget();
1876 let mut collected: Vec<String> = Vec::with_capacity(max_lines);
1877
1878 let mut past_signature = false;
1881 for line in body.lines() {
1882 let trimmed = line.trim();
1883 if !past_signature {
1884 if trimmed.ends_with('{') || trimmed.ends_with(':') || trimmed == "{" {
1886 past_signature = true;
1887 }
1888 continue;
1889 }
1890 if trimmed.is_empty()
1892 || trimmed.starts_with("//")
1893 || trimmed.starts_with('#')
1894 || trimmed.starts_with("/*")
1895 || trimmed.starts_with('*')
1896 || trimmed == "}"
1897 {
1898 continue;
1899 }
1900 collected.push(trimmed.to_string());
1901 if collected.len() >= max_lines {
1902 break;
1903 }
1904 }
1905
1906 if collected.is_empty() {
1907 None
1908 } else {
1909 Some(join_hint_lines(&collected))
1910 }
1911}
1912
1913fn nl_tokens_enabled() -> bool {
1923 if let Some(explicit) = parse_bool_env("CODELENS_EMBED_HINT_INCLUDE_COMMENTS") {
1924 return explicit;
1925 }
1926 auto_hint_should_enable()
1927}
1928
1929pub(super) fn auto_hint_mode_enabled() -> bool {
1971 parse_bool_env("CODELENS_EMBED_HINT_AUTO").unwrap_or(true)
1972}
1973
1974pub(super) fn auto_hint_lang() -> Option<String> {
1985 std::env::var("CODELENS_EMBED_HINT_AUTO_LANG")
1986 .ok()
1987 .map(|raw| raw.trim().to_ascii_lowercase())
1988}
1989
1990pub(super) fn language_supports_nl_stack(lang: &str) -> bool {
2026 matches!(
2027 lang.trim().to_ascii_lowercase().as_str(),
2028 "rs" | "rust"
2029 | "cpp"
2030 | "cc"
2031 | "cxx"
2032 | "c++"
2033 | "c"
2034 | "go"
2035 | "golang"
2036 | "java"
2037 | "kt"
2038 | "kotlin"
2039 | "scala"
2040 | "cs"
2041 | "csharp"
2042 | "ts"
2043 | "typescript"
2044 | "tsx"
2045 | "js"
2046 | "javascript"
2047 | "jsx"
2048 )
2049}
2050
2051pub(super) fn language_supports_sparse_weighting(lang: &str) -> bool {
2069 matches!(
2070 lang.trim().to_ascii_lowercase().as_str(),
2071 "rs" | "rust"
2072 | "cpp"
2073 | "cc"
2074 | "cxx"
2075 | "c++"
2076 | "c"
2077 | "go"
2078 | "golang"
2079 | "java"
2080 | "kt"
2081 | "kotlin"
2082 | "scala"
2083 | "cs"
2084 | "csharp"
2085 )
2086}
2087
2088pub(super) fn auto_hint_should_enable() -> bool {
2093 if !auto_hint_mode_enabled() {
2094 return false;
2095 }
2096 match auto_hint_lang() {
2097 Some(lang) => language_supports_nl_stack(&lang),
2098 None => false, }
2100}
2101
2102pub(super) fn auto_sparse_should_enable() -> bool {
2109 if !auto_hint_mode_enabled() {
2110 return false;
2111 }
2112 match auto_hint_lang() {
2113 Some(lang) => language_supports_sparse_weighting(&lang),
2114 None => false,
2115 }
2116}
2117
2118pub(super) fn is_nl_shaped(s: &str) -> bool {
2127 let s = s.trim();
2128 if s.chars().count() < 4 {
2129 return false;
2130 }
2131 if s.contains('/') || s.contains('\\') || s.contains("::") {
2132 return false;
2133 }
2134 if !s.contains(' ') {
2135 return false;
2136 }
2137 let non_ws: usize = s.chars().filter(|c| !c.is_whitespace()).count();
2138 if non_ws == 0 {
2139 return false;
2140 }
2141 let alpha: usize = s.chars().filter(|c| c.is_alphabetic()).count();
2142 (alpha * 100) / non_ws >= 60
2143}
2144
/// Strict comment filtering toggle (`CODELENS_EMBED_HINT_STRICT_COMMENTS`):
/// true only for "1"/"true"/"yes"/"on" (case-insensitive), otherwise false.
fn strict_comments_enabled() -> bool {
    match std::env::var("CODELENS_EMBED_HINT_STRICT_COMMENTS") {
        Ok(raw) => matches!(
            raw.to_ascii_lowercase().as_str(),
            "1" | "true" | "yes" | "on"
        ),
        Err(_) => false,
    }
}
2167
2168pub(super) fn looks_like_meta_annotation(body: &str) -> bool {
2189 let trimmed = body.trim_start();
2190 let word_end = trimmed
2193 .find(|c: char| !c.is_ascii_alphabetic())
2194 .unwrap_or(trimmed.len());
2195 if word_end == 0 {
2196 return false;
2197 }
2198 let first_word = &trimmed[..word_end];
2199 let upper = first_word.to_ascii_uppercase();
2200 matches!(
2201 upper.as_str(),
2202 "TODO"
2203 | "FIXME"
2204 | "HACK"
2205 | "XXX"
2206 | "BUG"
2207 | "REVIEW"
2208 | "REFACTOR"
2209 | "TEMP"
2210 | "TEMPORARY"
2211 | "DEPRECATED"
2212 )
2213}
2214
/// Strict string-literal filtering toggle
/// (`CODELENS_EMBED_HINT_STRICT_LITERALS`): true only for
/// "1"/"true"/"yes"/"on" (case-insensitive), otherwise false.
fn strict_literal_filter_enabled() -> bool {
    match std::env::var("CODELENS_EMBED_HINT_STRICT_LITERALS") {
        Ok(raw) => matches!(
            raw.to_ascii_lowercase().as_str(),
            "1" | "true" | "yes" | "on"
        ),
        Err(_) => false,
    }
}
2237
2238pub(super) fn contains_format_specifier(s: &str) -> bool {
2250 let bytes = s.as_bytes();
2251 let len = bytes.len();
2252 let mut i = 0;
2253 while i + 1 < len {
2254 if bytes[i] == b'%' {
2255 let next = bytes[i + 1];
2256 if matches!(next, b's' | b'd' | b'r' | b'f' | b'x' | b'o' | b'i' | b'u') {
2257 return true;
2258 }
2259 }
2260 i += 1;
2261 }
2262 for window in s.split('{').skip(1) {
2270 let Some(close_idx) = window.find('}') else {
2271 continue;
2272 };
2273 let inside = &window[..close_idx];
2274 if inside.is_empty() {
2276 return true;
2277 }
2278 if inside.chars().any(|c| c.is_whitespace()) {
2280 continue;
2281 }
2282 if inside.starts_with(':') {
2284 return true;
2285 }
2286 let ident_end = inside.find(':').unwrap_or(inside.len());
2290 let ident = &inside[..ident_end];
2291 if !ident.is_empty()
2292 && ident
2293 .chars()
2294 .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '.')
2295 {
2296 return true;
2297 }
2298 }
2299 false
2300}
2301
2302pub(super) fn looks_like_error_or_log_prefix(s: &str) -> bool {
2313 let lower = s.trim().to_lowercase();
2314 const PREFIXES: &[&str] = &[
2315 "invalid ",
2316 "cannot ",
2317 "could not ",
2318 "unable to ",
2319 "failed to ",
2320 "expected ",
2321 "unexpected ",
2322 "missing ",
2323 "not found",
2324 "error: ",
2325 "error ",
2326 "warning: ",
2327 "warning ",
2328 "sending ",
2329 "received ",
2330 "starting ",
2331 "stopping ",
2332 "calling ",
2333 "connecting ",
2334 "disconnecting ",
2335 ];
2336 PREFIXES.iter().any(|p| lower.starts_with(p))
2337}
2338
/// Test helper: a literal is rejected under strict filtering when it looks
/// like a format template or an error/log message.
#[cfg(test)]
pub(super) fn should_reject_literal_strict(s: &str) -> bool {
    let is_template = contains_format_specifier(s);
    let is_log_like = looks_like_error_or_log_prefix(s);
    is_template || is_log_like
}
2347
2348fn extract_nl_tokens(source: &str, start: usize, end: usize) -> Option<String> {
2362 if !nl_tokens_enabled() {
2363 return None;
2364 }
2365 extract_nl_tokens_inner(source, start, end)
2366}
2367
/// Collect natural-language fragments (comments and string literals) from
/// `source[start..end]` and join them under the hint budget.
///
/// Two passes over the boundary-clamped body:
///  1. line comments whose cleaned text "reads like prose" (see
///     [`is_nl_shaped`]); strict mode also drops TODO-style annotations;
///  2. double-quoted string literals, scanned with escape handling; strict
///     mode also drops format templates and error/log-looking strings.
/// Returns `None` for invalid ranges or when nothing qualifies.
pub(super) fn extract_nl_tokens_inner(source: &str, start: usize, end: usize) -> Option<String> {
    if start >= source.len() || end > source.len() || start >= end {
        return None;
    }
    // Clamp offsets to char boundaries so the slice below cannot panic.
    let safe_start = if source.is_char_boundary(start) {
        start
    } else {
        source.floor_char_boundary(start)
    };
    let safe_end = end.min(source.len());
    let safe_end = if source.is_char_boundary(safe_end) {
        safe_end
    } else {
        source.floor_char_boundary(safe_end)
    };
    let body = &source[safe_start..safe_end];

    let mut tokens: Vec<String> = Vec::new();

    // Pass 1: prose-shaped comment lines.
    let strict_comments = strict_comments_enabled();
    for line in body.lines() {
        let trimmed = line.trim();
        if let Some(cleaned) = extract_comment_body(trimmed)
            && is_nl_shaped(&cleaned)
            && (!strict_comments || !looks_like_meta_annotation(&cleaned))
        {
            tokens.push(cleaned);
        }
    }

    // Pass 2: prose-shaped double-quoted string literals. A `\X` escape is
    // consumed as two chars (so `\"` does not close the literal); an
    // unterminated literal at end of body is silently discarded.
    let strict_literals = strict_literal_filter_enabled();
    let mut chars = body.chars().peekable();
    let mut in_string = false;
    let mut current = String::new();
    while let Some(c) = chars.next() {
        if in_string {
            if c == '\\' {
                let _ = chars.next();
            } else if c == '"' {
                if is_nl_shaped(&current)
                    && (!strict_literals
                        || (!contains_format_specifier(&current)
                            && !looks_like_error_or_log_prefix(&current)))
                {
                    tokens.push(current.clone());
                }
                current.clear();
                in_string = false;
            } else {
                current.push(c);
            }
        } else if c == '"' {
            in_string = true;
        }
    }

    if tokens.is_empty() {
        return None;
    }
    Some(join_hint_lines(&tokens))
}
2450
2451fn api_calls_enabled() -> bool {
2460 if let Some(explicit) = parse_bool_env("CODELENS_EMBED_HINT_INCLUDE_API_CALLS") {
2461 return explicit;
2462 }
2463 auto_hint_should_enable()
2464}
2465
2466pub(super) fn is_static_method_ident(ident: &str) -> bool {
2476 ident.chars().next().is_some_and(|c| c.is_ascii_uppercase())
2477}
2478
2479fn extract_api_calls(source: &str, start: usize, end: usize) -> Option<String> {
2491 if !api_calls_enabled() {
2492 return None;
2493 }
2494 extract_api_calls_inner(source, start, end)
2495}
2496
/// Scan `source[start..end]` for `Type::method` style calls and return the
/// unique ones joined under the hint budget.
///
/// A byte-level scan: an identifier starting with an ASCII uppercase letter
/// (see [`is_static_method_ident`]) immediately followed by `::` and another
/// identifier is recorded once, in first-seen order. Returns `None` for
/// invalid ranges or when no calls are found.
pub(super) fn extract_api_calls_inner(source: &str, start: usize, end: usize) -> Option<String> {
    if start >= source.len() || end > source.len() || start >= end {
        return None;
    }
    // Clamp offsets to char boundaries so slicing cannot panic mid-UTF-8.
    let safe_start = if source.is_char_boundary(start) {
        start
    } else {
        source.floor_char_boundary(start)
    };
    let safe_end = end.min(source.len());
    let safe_end = if source.is_char_boundary(safe_end) {
        safe_end
    } else {
        source.floor_char_boundary(safe_end)
    };
    if safe_start >= safe_end {
        return None;
    }
    let body = &source[safe_start..safe_end];
    let bytes = body.as_bytes();
    let len = bytes.len();

    let mut calls: Vec<String> = Vec::new();
    // Dedup set; `calls` preserves first-seen order for the joined output.
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();

    let mut i = 0usize;
    while i < len {
        let b = bytes[i];
        // Skip ahead to the next identifier start (letter or underscore).
        if !(b == b'_' || b.is_ascii_alphabetic()) {
            i += 1;
            continue;
        }
        let ident_start = i;
        while i < len {
            let bb = bytes[i];
            if bb == b'_' || bb.is_ascii_alphanumeric() {
                i += 1;
            } else {
                break;
            }
        }
        let ident_end = i;

        // Require `::` directly after the identifier. `i` already moved past
        // the identifier, so `continue` here still makes progress.
        if i + 1 >= len || bytes[i] != b':' || bytes[i + 1] != b':' {
            continue;
        }

        let type_ident = &body[ident_start..ident_end];
        if !is_static_method_ident(type_ident) {
            // lowercase::path (modules, crates, ...) — skip past the `::`.
            i += 2;
            continue;
        }

        // Parse the method identifier following `::`.
        let mut j = i + 2;
        if j >= len || !(bytes[j] == b'_' || bytes[j].is_ascii_alphabetic()) {
            i = j;
            continue;
        }
        let method_start = j;
        while j < len {
            let bb = bytes[j];
            if bb == b'_' || bb.is_ascii_alphanumeric() {
                j += 1;
            } else {
                break;
            }
        }
        let method_end = j;

        let method_ident = &body[method_start..method_end];
        let call = format!("{type_ident}::{method_ident}");
        if seen.insert(call.clone()) {
            calls.push(call);
        }
        i = j;
    }

    if calls.is_empty() {
        return None;
    }
    Some(join_hint_lines(&calls))
}
2597
/// Strip comment markers from a single pre-trimmed line, returning the
/// comment text, or `None` when the line is not a comment (or is a Rust
/// attribute). Marker order matters: `///`/`//!` before `//`, `/**` before
/// `/*`.
fn extract_comment_body(trimmed: &str) -> Option<String> {
    if trimmed.is_empty() {
        return None;
    }
    for marker in ["///", "//!", "//"] {
        if let Some(rest) = trimmed.strip_prefix(marker) {
            return Some(rest.trim().to_string());
        }
    }
    // `#[...]` / `#![...]` are attributes, not comments.
    if trimmed.starts_with("#[") || trimmed.starts_with("#!") {
        return None;
    }
    if let Some(rest) = trimmed.strip_prefix('#') {
        return Some(rest.trim().to_string());
    }
    for marker in ["/**", "/*"] {
        if let Some(rest) = trimmed.strip_prefix(marker) {
            return Some(rest.trim_end_matches("*/").trim().to_string());
        }
    }
    if let Some(rest) = trimmed.strip_prefix('*') {
        // Interior line of a block comment; reject lines that look like code.
        let rest = rest.trim_end_matches("*/").trim();
        if rest.is_empty() || rest.contains(';') || rest.contains('{') {
            return None;
        }
        return Some(rest.to_string());
    }
    None
}
2644
/// Extract a leading documentation comment from the slice of `source`
/// spanning `start..end` (a symbol body whose first line is the signature).
///
/// The signature line is skipped; the following lines are matched against
/// the supported doc styles: Python `"""`/`'''` docstrings, Rust `///`/`//!`
/// doc comments, `/** ... */` block comments, or plain `//`/`#` comment runs.
/// Returns the collected doc text joined with single spaces, or `None` when
/// the range is invalid or no doc lines are found.
fn extract_leading_doc(source: &str, start: usize, end: usize) -> Option<String> {
    if start >= source.len() || end > source.len() || start >= end {
        return None;
    }
    // Offsets come from stored byte positions and may land inside a
    // multi-byte character; walk each end down to the nearest UTF-8
    // boundary. `is_char_boundary(0)` is always true, so this terminates.
    // (Manual scan instead of the unstable `str::floor_char_boundary`.)
    let mut safe_start = start;
    while !source.is_char_boundary(safe_start) {
        safe_start -= 1;
    }
    let mut safe_end = end.min(source.len());
    while !source.is_char_boundary(safe_end) {
        safe_end -= 1;
    }
    if safe_start >= safe_end {
        return None;
    }
    let body = &source[safe_start..safe_end];
    // Skip the signature line; docs start on the following line.
    let lines: Vec<&str> = body.lines().skip(1).collect();
    if lines.is_empty() {
        return None;
    }

    let mut doc_lines = Vec::new();

    let first_trimmed = lines.first().map(|l| l.trim()).unwrap_or_default();
    if first_trimmed.starts_with("\"\"\"") || first_trimmed.starts_with("'''") {
        // Python docstring.
        let quote = &first_trimmed[..3];
        for (idx, line) in lines.iter().enumerate() {
            let t = line.trim();
            doc_lines.push(t.trim_start_matches(quote).trim_end_matches(quote));
            // The docstring is closed when a line ends with the quote; on the
            // opening line the closer must be a *second* quote (a bare `"""`
            // opener is not closed). This also stops a one-line docstring
            // immediately, so following code lines are never collected —
            // previously `doc_lines.len() > 1` skipped the break on the first
            // line and code leaked into the doc text.
            if t.ends_with(quote) && (idx > 0 || t.len() >= quote.len() * 2) {
                break;
            }
        }
    } else if first_trimmed.starts_with("///") || first_trimmed.starts_with("//!") {
        // Rust doc comments: take the contiguous run of doc-comment lines.
        for line in &lines {
            let t = line.trim();
            if t.starts_with("///") || t.starts_with("//!") {
                doc_lines.push(t.trim_start_matches("///").trim_start_matches("//!").trim());
            } else {
                break;
            }
        }
    } else if first_trimmed.starts_with("/**") {
        // JSDoc/Javadoc-style block comment: strip `/**`, leading `*`, and
        // the closing `*/`; drop lines that are empty after cleaning.
        for line in &lines {
            let t = line.trim();
            let cleaned = t
                .trim_start_matches("/**")
                .trim_start_matches('*')
                .trim_end_matches("*/")
                .trim();
            if !cleaned.is_empty() {
                doc_lines.push(cleaned);
            }
            if t.ends_with("*/") {
                break;
            }
        }
    } else {
        // Plain line comments (`//` or `#`): take the contiguous run.
        for line in &lines {
            let t = line.trim();
            if t.starts_with("//") || t.starts_with('#') {
                doc_lines.push(t.trim_start_matches("//").trim_start_matches('#').trim());
            } else {
                break;
            }
        }
    }

    if doc_lines.is_empty() {
        return None;
    }
    Some(doc_lines.join(" ").trim().to_owned())
}
2731
2732pub(super) fn embedding_to_bytes(embedding: &[f32]) -> Vec<u8> {
2733 embedding.iter().flat_map(|f| f.to_le_bytes()).collect()
2734}
2735
2736#[cfg(test)]
2737mod tests {
2738 use super::*;
2739 use crate::db::{IndexDb, NewSymbol};
2740 use std::sync::Mutex;
2741
    // Serializes tests that load the embedding model.
    // NOTE(review): not referenced in the portion of the suite visible here —
    // presumably held by model-loading tests elsewhere in this module.
    static MODEL_LOCK: Mutex<()> = Mutex::new(());

    // Serializes tests that read/write the CODELENS_EMBED_* environment
    // variables; env mutation is process-wide and would otherwise race
    // under the parallel test runner.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    // Early-returns from the enclosing test when the CodeSearchNet model
    // assets are not available locally, so environments without the model
    // still pass the suite.
    macro_rules! skip_without_embedding_model {
        () => {
            if !super::embedding_model_assets_available() {
                eprintln!("skipping embedding test: CodeSearchNet model assets unavailable");
                return;
            }
        };
    }
2761
    // Builds a throwaway project containing one Python file (`main.py`)
    // with two indexed function symbols (`hello`, `world`). Returns the
    // TempDir (kept alive by the caller) and the opened ProjectRoot.
    fn make_project_with_source() -> (tempfile::TempDir, ProjectRoot) {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path();

        let source = "def hello():\n    print('hi')\n\ndef world():\n    return 42\n";
        write_python_file_with_symbols(
            root,
            "main.py",
            source,
            "hash1",
            &[
                ("hello", "def hello():", "hello"),
                ("world", "def world():", "world"),
            ],
        );

        let project = ProjectRoot::new_exact(root).unwrap();
        (dir, project)
    }
2783
    // Writes `source` to `relative_path` under `root` and registers each
    // `(name, signature, name_path)` triple as a "function" symbol in the
    // index DB, deriving byte offsets and line numbers from the source text.
    fn write_python_file_with_symbols(
        root: &std::path::Path,
        relative_path: &str,
        source: &str,
        hash: &str,
        symbols: &[(&str, &str, &str)],
    ) {
        std::fs::write(root.join(relative_path), source).unwrap();
        let db_path = crate::db::index_db_path(root);
        let db = IndexDb::open(&db_path).unwrap();
        // NOTE(review): the literal 100 is the second upsert_file argument —
        // presumably a timestamp/mtime; confirm against IndexDb::upsert_file.
        let file_id = db
            .upsert_file(relative_path, 100, hash, source.len() as i64, Some("py"))
            .unwrap();

        let new_symbols: Vec<NewSymbol<'_>> = symbols
            .iter()
            .map(|(name, signature, name_path)| {
                // Start at the signature's byte offset; end at the next
                // top-level "\n\ndef " separator or end-of-file.
                let start = source.find(signature).unwrap() as i64;
                let end = source[start as usize..]
                    .find("\n\ndef ")
                    .map(|offset| start + offset as i64)
                    .unwrap_or(source.len() as i64);
                // 1-based line number = newlines before the start offset + 1.
                let line = source[..start as usize]
                    .bytes()
                    .filter(|&b| b == b'\n')
                    .count() as i64
                    + 1;
                NewSymbol {
                    name,
                    kind: "function",
                    line,
                    column_num: 0,
                    start_byte: start,
                    end_byte: end,
                    signature,
                    name_path,
                    parent_id: None,
                }
            })
            .collect();
        db.insert_symbols(file_id, &new_symbols).unwrap();
    }
2826
    // Rewrites the stored embeddings for `file_path`, replacing each listed
    // symbol's vector with a constant sentinel value (same dimension) so
    // similarity scores become deterministic for assertions. Chunks not
    // named in `sentinels` are re-inserted unchanged.
    fn replace_file_embeddings_with_sentinels(
        engine: &EmbeddingEngine,
        file_path: &str,
        sentinels: &[(&str, f32)],
    ) {
        let mut chunks = engine.store.embeddings_for_files(&[file_path]).unwrap();
        for chunk in &mut chunks {
            if let Some((_, value)) = sentinels
                .iter()
                .find(|(symbol_name, _)| *symbol_name == chunk.symbol_name)
            {
                // Keep the original dimensionality; only the values change.
                chunk.embedding = vec![*value; chunk.embedding.len()];
            }
        }
        // Delete-then-insert: the store presumably has no in-place update.
        engine.store.delete_by_file(&[file_path]).unwrap();
        engine.store.insert(&chunks).unwrap();
    }
2844
    #[test]
    fn build_embedding_text_with_signature() {
        // Expected shape: "<kind> <name> in <file>: <signature>".
        let sym = crate::db::SymbolWithFile {
            name: "hello".into(),
            kind: "function".into(),
            file_path: "main.py".into(),
            line: 1,
            signature: "def hello():".into(),
            name_path: "hello".into(),
            start_byte: 0,
            end_byte: 10,
        };
        let text = build_embedding_text(&sym, Some("def hello(): pass"));
        assert_eq!(text, "function hello in main.py: def hello():");
    }

    #[test]
    fn build_embedding_text_without_source() {
        // The PascalCase name also appears as a split-word alias
        // ("(My Class)") in the expected output.
        let sym = crate::db::SymbolWithFile {
            name: "MyClass".into(),
            kind: "class".into(),
            file_path: "app.py".into(),
            line: 5,
            signature: "class MyClass:".into(),
            name_path: "MyClass".into(),
            start_byte: 0,
            end_byte: 50,
        };
        let text = build_embedding_text(&sym, None);
        assert_eq!(text, "class MyClass (My Class) in app.py: class MyClass:");
    }

    #[test]
    fn build_embedding_text_empty_signature() {
        // An empty signature drops the ": <signature>" suffix entirely.
        let sym = crate::db::SymbolWithFile {
            name: "CONFIG".into(),
            kind: "variable".into(),
            file_path: "config.py".into(),
            line: 1,
            signature: String::new(),
            name_path: "CONFIG".into(),
            start_byte: 0,
            end_byte: 0,
        };
        let text = build_embedding_text(&sym, None);
        assert_eq!(text, "variable CONFIG in config.py");
    }
2892
    #[test]
    fn filters_direct_test_symbols_from_embedding_index() {
        // A fn immediately preceded by #[test] is classified test-only.
        let source = "#[test]\nfn alias_case() {}\n";
        let sym = crate::db::SymbolWithFile {
            name: "alias_case".into(),
            kind: "function".into(),
            file_path: "src/lib.rs".into(),
            line: 2,
            signature: "fn alias_case() {}".into(),
            name_path: "alias_case".into(),
            start_byte: source.find("fn alias_case").unwrap() as i64,
            end_byte: source.len() as i64,
        };

        assert!(is_test_only_symbol(&sym, Some(source)));
    }

    #[test]
    fn filters_cfg_test_module_symbols_from_embedding_index() {
        // A symbol nested inside a #[cfg(...test...)] module is classified
        // test-only even without its own #[test] attribute.
        let source = "#[cfg(all(test, feature = \"semantic\"))]\nmod semantic_tests {\n    fn helper_case() {}\n}\n";
        let sym = crate::db::SymbolWithFile {
            name: "helper_case".into(),
            kind: "function".into(),
            file_path: "src/lib.rs".into(),
            line: 3,
            signature: "fn helper_case() {}".into(),
            name_path: "helper_case".into(),
            start_byte: source.find("fn helper_case").unwrap() as i64,
            end_byte: source.len() as i64,
        };

        assert!(is_test_only_symbol(&sym, Some(source)));
    }
2926
    #[test]
    fn extract_python_docstring() {
        // Triple-quoted docstring on the line after the signature is found.
        let source =
            "def greet(name):\n    \"\"\"Say hello to a person.\"\"\"\n    print(f'hi {name}')\n";
        let doc = extract_leading_doc(source, 0, source.len()).unwrap();
        assert!(doc.contains("Say hello to a person"));
    }

    #[test]
    fn extract_rust_doc_comment() {
        // A contiguous run of /// lines is collected and joined.
        let source = "fn dispatch_tool() {\n    /// Route incoming tool requests.\n    /// Handles all MCP methods.\n    let x = 1;\n}\n";
        let doc = extract_leading_doc(source, 0, source.len()).unwrap();
        assert!(doc.contains("Route incoming tool requests"));
        assert!(doc.contains("Handles all MCP methods"));
    }

    #[test]
    fn extract_leading_doc_returns_none_for_no_doc() {
        // A body with no comment/docstring lines yields None.
        let source = "def f():\n    return 1\n";
        assert!(extract_leading_doc(source, 0, source.len()).is_none());
    }
2948
    #[test]
    fn extract_body_hint_finds_first_meaningful_line() {
        // The hint comes from the first substantive statement inside the
        // body, not from the (multi-line) signature.
        let source = "pub fn parse_symbols(\n    project: &ProjectRoot,\n) -> Vec<SymbolInfo> {\n    let mut parser = tree_sitter::Parser::new();\n    parser.set_language(lang);\n}\n";
        let hint = extract_body_hint(source, 0, source.len());
        assert!(hint.is_some());
        assert!(hint.unwrap().contains("tree_sitter::Parser"));
    }

    #[test]
    fn extract_body_hint_skips_comments() {
        // Comment lines are skipped; the first code line is the hint.
        let source = "fn foo() {\n    // setup\n    let x = bar();\n}\n";
        let hint = extract_body_hint(source, 0, source.len());
        assert_eq!(hint.unwrap(), "let x = bar();");
    }

    #[test]
    fn extract_body_hint_returns_none_for_empty() {
        // A body containing only braces yields no hint.
        let source = "fn empty() {\n}\n";
        let hint = extract_body_hint(source, 0, source.len());
        assert!(hint.is_none());
    }
2970
2971 #[test]
2972 fn extract_body_hint_multi_line_collection_via_env_override() {
2973 let previous_lines = std::env::var("CODELENS_EMBED_HINT_LINES").ok();
2978 let previous_chars = std::env::var("CODELENS_EMBED_HINT_CHARS").ok();
2979 unsafe {
2980 std::env::set_var("CODELENS_EMBED_HINT_LINES", "3");
2981 std::env::set_var("CODELENS_EMBED_HINT_CHARS", "200");
2982 }
2983
2984 let source = "\
2985fn route_request() {
2986 let kind = detect_request_kind();
2987 let target = dispatch_table.get(&kind);
2988 return target.handle();
2989}
2990";
2991 let hint = extract_body_hint(source, 0, source.len()).expect("hint present");
2992
2993 let env_restore = || unsafe {
2994 match &previous_lines {
2995 Some(value) => std::env::set_var("CODELENS_EMBED_HINT_LINES", value),
2996 None => std::env::remove_var("CODELENS_EMBED_HINT_LINES"),
2997 }
2998 match &previous_chars {
2999 Some(value) => std::env::set_var("CODELENS_EMBED_HINT_CHARS", value),
3000 None => std::env::remove_var("CODELENS_EMBED_HINT_CHARS"),
3001 }
3002 };
3003
3004 let all_three = hint.contains("detect_request_kind")
3005 && hint.contains("dispatch_table")
3006 && hint.contains("target.handle");
3007 let has_separator = hint.contains(" · ");
3008 env_restore();
3009
3010 assert!(all_three, "missing one of three body lines: {hint}");
3011 assert!(has_separator, "missing · separator: {hint}");
3012 }
3013
3014 #[test]
3025 fn hint_line_budget_respects_env_override() {
3026 let previous = std::env::var("CODELENS_EMBED_HINT_LINES").ok();
3029 unsafe {
3030 std::env::set_var("CODELENS_EMBED_HINT_LINES", "5");
3031 }
3032 let budget = super::hint_line_budget();
3033 unsafe {
3034 match previous {
3035 Some(value) => std::env::set_var("CODELENS_EMBED_HINT_LINES", value),
3036 None => std::env::remove_var("CODELENS_EMBED_HINT_LINES"),
3037 }
3038 }
3039 assert_eq!(budget, 5);
3040 }
3041
3042 #[test]
3043 fn is_nl_shaped_accepts_multi_word_prose() {
3044 assert!(super::is_nl_shaped("skip comments and string literals"));
3045 assert!(super::is_nl_shaped("failed to open database"));
3046 assert!(super::is_nl_shaped("detect client version"));
3047 }
3048
3049 #[test]
3050 fn is_nl_shaped_rejects_code_and_paths() {
3051 assert!(!super::is_nl_shaped("crates/codelens-engine/src"));
3053 assert!(!super::is_nl_shaped("C:\\Users\\foo"));
3054 assert!(!super::is_nl_shaped("std::sync::Mutex"));
3056 assert!(!super::is_nl_shaped("detect_client"));
3058 assert!(!super::is_nl_shaped("ok"));
3060 assert!(!super::is_nl_shaped(""));
3061 assert!(!super::is_nl_shaped("1 2 3 4 5"));
3063 }
3064
3065 #[test]
3066 fn extract_comment_body_strips_comment_markers() {
3067 assert_eq!(
3068 super::extract_comment_body("/// rust doc comment"),
3069 Some("rust doc comment".to_string())
3070 );
3071 assert_eq!(
3072 super::extract_comment_body("// regular line comment"),
3073 Some("regular line comment".to_string())
3074 );
3075 assert_eq!(
3076 super::extract_comment_body("# python line comment"),
3077 Some("python line comment".to_string())
3078 );
3079 assert_eq!(
3080 super::extract_comment_body("/* inline block */"),
3081 Some("inline block".to_string())
3082 );
3083 assert_eq!(
3084 super::extract_comment_body("* continuation line"),
3085 Some("continuation line".to_string())
3086 );
3087 }
3088
3089 #[test]
3090 fn extract_comment_body_rejects_rust_attributes_and_shebangs() {
3091 assert!(super::extract_comment_body("#[derive(Debug)]").is_none());
3092 assert!(super::extract_comment_body("#[test]").is_none());
3093 assert!(super::extract_comment_body("#!/usr/bin/env python").is_none());
3094 }
3095
    #[test]
    fn extract_nl_tokens_gated_off_by_default() {
        // With CODELENS_EMBED_HINT_INCLUDE_COMMENTS unset, comment/literal
        // extraction must stay disabled even for NL-looking content.
        let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let previous = std::env::var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS").ok();
        unsafe {
            std::env::remove_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS");
        }
        let source = "\
fn skip_things() {
    // skip comments and string literals during search
    let lit = \"scan for matching tokens\";
}
";
        let result = extract_nl_tokens(source, 0, source.len());
        // Restore only when a prior value existed — the var was removed above.
        unsafe {
            if let Some(value) = previous {
                std::env::set_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS", value);
            }
        }
        assert!(result.is_none(), "gate leaked: {result:?}");
    }
3118
    #[test]
    fn auto_hint_mode_defaults_on_unless_explicit_off() {
        // CODELENS_EMBED_HINT_AUTO: unset => on (default flipped in v1.6.0),
        // "0" => off (explicit opt-out), "1" => on.
        let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let previous = std::env::var("CODELENS_EMBED_HINT_AUTO").ok();

        unsafe {
            std::env::remove_var("CODELENS_EMBED_HINT_AUTO");
        }
        let default_enabled = super::auto_hint_mode_enabled();
        assert!(
            default_enabled,
            "v1.6.0 default flip: auto hint mode should be ON when env unset"
        );

        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "0");
        }
        let explicit_off = super::auto_hint_mode_enabled();
        assert!(
            !explicit_off,
            "explicit CODELENS_EMBED_HINT_AUTO=0 must still disable (opt-out escape hatch)"
        );

        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
        }
        let explicit_on = super::auto_hint_mode_enabled();
        assert!(
            explicit_on,
            "explicit CODELENS_EMBED_HINT_AUTO=1 must still enable"
        );

        // Restore the original value (or unset state) for later tests.
        unsafe {
            match previous {
                Some(v) => std::env::set_var("CODELENS_EMBED_HINT_AUTO", v),
                None => std::env::remove_var("CODELENS_EMBED_HINT_AUTO"),
            }
        }
    }
3169
3170 #[test]
3171 fn language_supports_nl_stack_classifies_correctly() {
3172 assert!(super::language_supports_nl_stack("rs"));
3174 assert!(super::language_supports_nl_stack("rust"));
3175 assert!(super::language_supports_nl_stack("cpp"));
3176 assert!(super::language_supports_nl_stack("c++"));
3177 assert!(super::language_supports_nl_stack("c"));
3178 assert!(super::language_supports_nl_stack("go"));
3179 assert!(super::language_supports_nl_stack("golang"));
3180 assert!(super::language_supports_nl_stack("java"));
3181 assert!(super::language_supports_nl_stack("kt"));
3182 assert!(super::language_supports_nl_stack("kotlin"));
3183 assert!(super::language_supports_nl_stack("scala"));
3184 assert!(super::language_supports_nl_stack("cs"));
3185 assert!(super::language_supports_nl_stack("csharp"));
3186 assert!(super::language_supports_nl_stack("ts"));
3189 assert!(super::language_supports_nl_stack("typescript"));
3190 assert!(super::language_supports_nl_stack("tsx"));
3191 assert!(super::language_supports_nl_stack("js"));
3192 assert!(super::language_supports_nl_stack("javascript"));
3193 assert!(super::language_supports_nl_stack("jsx"));
3194 assert!(super::language_supports_nl_stack("Rust"));
3196 assert!(super::language_supports_nl_stack("RUST"));
3197 assert!(super::language_supports_nl_stack("TypeScript"));
3198 assert!(super::language_supports_nl_stack(" rust "));
3200 assert!(super::language_supports_nl_stack(" ts "));
3201
3202 assert!(!super::language_supports_nl_stack("py"));
3204 assert!(!super::language_supports_nl_stack("python"));
3205 assert!(!super::language_supports_nl_stack("rb"));
3206 assert!(!super::language_supports_nl_stack("ruby"));
3207 assert!(!super::language_supports_nl_stack("php"));
3208 assert!(!super::language_supports_nl_stack("lua"));
3209 assert!(!super::language_supports_nl_stack("sh"));
3210 assert!(!super::language_supports_nl_stack("klingon"));
3212 assert!(!super::language_supports_nl_stack(""));
3213 }
3214
3215 #[test]
3216 fn language_supports_sparse_weighting_classifies_correctly() {
3217 assert!(super::language_supports_sparse_weighting("rs"));
3218 assert!(super::language_supports_sparse_weighting("rust"));
3219 assert!(super::language_supports_sparse_weighting("cpp"));
3220 assert!(super::language_supports_sparse_weighting("go"));
3221 assert!(super::language_supports_sparse_weighting("java"));
3222 assert!(super::language_supports_sparse_weighting("kotlin"));
3223 assert!(super::language_supports_sparse_weighting("csharp"));
3224
3225 assert!(!super::language_supports_sparse_weighting("ts"));
3226 assert!(!super::language_supports_sparse_weighting("typescript"));
3227 assert!(!super::language_supports_sparse_weighting("tsx"));
3228 assert!(!super::language_supports_sparse_weighting("js"));
3229 assert!(!super::language_supports_sparse_weighting("javascript"));
3230 assert!(!super::language_supports_sparse_weighting("jsx"));
3231 assert!(!super::language_supports_sparse_weighting("py"));
3232 assert!(!super::language_supports_sparse_weighting("python"));
3233 assert!(!super::language_supports_sparse_weighting("klingon"));
3234 assert!(!super::language_supports_sparse_weighting(""));
3235 }
3236
    #[test]
    fn auto_hint_should_enable_requires_both_gate_and_supported_lang() {
        // The auto hint gate requires BOTH the AUTO flag on (or defaulted on)
        // AND a supported language tag in CODELENS_EMBED_HINT_AUTO_LANG.
        let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let prev_auto = std::env::var("CODELENS_EMBED_HINT_AUTO").ok();
        let prev_lang = std::env::var("CODELENS_EMBED_HINT_AUTO_LANG").ok();

        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "0");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "rust");
        }
        assert!(
            !super::auto_hint_should_enable(),
            "gate-off (explicit =0) with lang=rust must stay disabled"
        );

        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "rust");
        }
        assert!(
            super::auto_hint_should_enable(),
            "gate-on + lang=rust must enable"
        );

        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "typescript");
        }
        assert!(
            super::auto_hint_should_enable(),
            "gate-on + lang=typescript must keep Phase 2b/2c enabled"
        );

        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "python");
        }
        assert!(
            !super::auto_hint_should_enable(),
            "gate-on + lang=python must stay disabled"
        );

        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::remove_var("CODELENS_EMBED_HINT_AUTO_LANG");
        }
        assert!(
            !super::auto_hint_should_enable(),
            "gate-on + no lang tag must stay disabled"
        );

        // Restore both vars for later tests.
        unsafe {
            match prev_auto {
                Some(v) => std::env::set_var("CODELENS_EMBED_HINT_AUTO", v),
                None => std::env::remove_var("CODELENS_EMBED_HINT_AUTO"),
            }
            match prev_lang {
                Some(v) => std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", v),
                None => std::env::remove_var("CODELENS_EMBED_HINT_AUTO_LANG"),
            }
        }
    }
3306
    #[test]
    fn auto_sparse_should_enable_requires_both_gate_and_sparse_supported_lang() {
        // Sparse auto gate needs the AUTO flag on AND a sparse-supported
        // language (systems languages only — unlike the hint gate, TS/JS
        // do NOT qualify here).
        let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let prev_auto = std::env::var("CODELENS_EMBED_HINT_AUTO").ok();
        let prev_lang = std::env::var("CODELENS_EMBED_HINT_AUTO_LANG").ok();

        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "0");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "rust");
        }
        assert!(
            !super::auto_sparse_should_enable(),
            "gate-off (explicit =0) must disable sparse auto gate"
        );

        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "rust");
        }
        assert!(
            super::auto_sparse_should_enable(),
            "gate-on + lang=rust must enable sparse auto gate"
        );

        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "typescript");
        }
        assert!(
            !super::auto_sparse_should_enable(),
            "gate-on + lang=typescript must keep sparse auto gate disabled"
        );

        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "python");
        }
        assert!(
            !super::auto_sparse_should_enable(),
            "gate-on + lang=python must keep sparse auto gate disabled"
        );

        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::remove_var("CODELENS_EMBED_HINT_AUTO_LANG");
        }
        assert!(
            !super::auto_sparse_should_enable(),
            "gate-on + no lang tag must keep sparse auto gate disabled"
        );

        // Restore both vars for later tests.
        unsafe {
            match prev_auto {
                Some(v) => std::env::set_var("CODELENS_EMBED_HINT_AUTO", v),
                None => std::env::remove_var("CODELENS_EMBED_HINT_AUTO"),
            }
            match prev_lang {
                Some(v) => std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", v),
                None => std::env::remove_var("CODELENS_EMBED_HINT_AUTO_LANG"),
            }
        }
    }
3369
    #[test]
    fn nl_tokens_enabled_explicit_env_wins_over_auto() {
        // Precedence: an explicit CODELENS_EMBED_HINT_INCLUDE_COMMENTS value
        // ("0"/"1") overrides the auto gate; when unset, the auto gate plus
        // language tag decide.
        let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let prev_explicit = std::env::var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS").ok();
        let prev_auto = std::env::var("CODELENS_EMBED_HINT_AUTO").ok();
        let prev_lang = std::env::var("CODELENS_EMBED_HINT_AUTO_LANG").ok();

        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "python");
        }
        assert!(
            super::nl_tokens_enabled(),
            "explicit=1 must win over auto+python=off"
        );

        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS", "0");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "rust");
        }
        assert!(
            !super::nl_tokens_enabled(),
            "explicit=0 must win over auto+rust=on"
        );

        unsafe {
            std::env::remove_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "rust");
        }
        assert!(
            super::nl_tokens_enabled(),
            "no explicit + auto+rust must enable"
        );

        unsafe {
            std::env::remove_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "python");
        }
        assert!(
            !super::nl_tokens_enabled(),
            "no explicit + auto+python must stay disabled"
        );

        // Restore all three vars for later tests.
        unsafe {
            match prev_explicit {
                Some(v) => std::env::set_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS", v),
                None => std::env::remove_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS"),
            }
            match prev_auto {
                Some(v) => std::env::set_var("CODELENS_EMBED_HINT_AUTO", v),
                None => std::env::remove_var("CODELENS_EMBED_HINT_AUTO"),
            }
            match prev_lang {
                Some(v) => std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", v),
                None => std::env::remove_var("CODELENS_EMBED_HINT_AUTO_LANG"),
            }
        }
    }
3437
    #[test]
    fn strict_comments_gated_off_by_default() {
        // With CODELENS_EMBED_HINT_STRICT_COMMENTS unset, the strict comment
        // filter must be disabled.
        let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let previous = std::env::var("CODELENS_EMBED_HINT_STRICT_COMMENTS").ok();
        unsafe {
            std::env::remove_var("CODELENS_EMBED_HINT_STRICT_COMMENTS");
        }
        let enabled = super::strict_comments_enabled();
        // Restore only when a prior value existed — the var was removed above.
        unsafe {
            if let Some(value) = previous {
                std::env::set_var("CODELENS_EMBED_HINT_STRICT_COMMENTS", value);
            }
        }
        assert!(!enabled, "strict comments gate leaked");
    }
3453
3454 #[test]
3455 fn looks_like_meta_annotation_detects_rejected_prefixes() {
3456 assert!(super::looks_like_meta_annotation("TODO: fix later"));
3458 assert!(super::looks_like_meta_annotation("todo handle edge case"));
3459 assert!(super::looks_like_meta_annotation("FIXME this is broken"));
3460 assert!(super::looks_like_meta_annotation(
3461 "HACK: workaround for bug"
3462 ));
3463 assert!(super::looks_like_meta_annotation("XXX not implemented yet"));
3464 assert!(super::looks_like_meta_annotation(
3465 "BUG in the upstream crate"
3466 ));
3467 assert!(super::looks_like_meta_annotation("REVIEW before merging"));
3468 assert!(super::looks_like_meta_annotation(
3469 "REFACTOR this block later"
3470 ));
3471 assert!(super::looks_like_meta_annotation("TEMP: remove before v2"));
3472 assert!(super::looks_like_meta_annotation(
3473 "DEPRECATED use new_api instead"
3474 ));
3475 assert!(super::looks_like_meta_annotation(
3477 " TODO: with leading ws"
3478 ));
3479 }
3480
3481 #[test]
3482 fn looks_like_meta_annotation_preserves_behaviour_prefixes() {
3483 assert!(!super::looks_like_meta_annotation(
3485 "NOTE: this branch handles empty input"
3486 ));
3487 assert!(!super::looks_like_meta_annotation(
3488 "WARN: overflow is possible"
3489 ));
3490 assert!(!super::looks_like_meta_annotation(
3491 "SAFETY: caller must hold the lock"
3492 ));
3493 assert!(!super::looks_like_meta_annotation(
3494 "PANIC: unreachable by construction"
3495 ));
3496 assert!(!super::looks_like_meta_annotation(
3498 "parse json body from request"
3499 ));
3500 assert!(!super::looks_like_meta_annotation(
3501 "walk directory respecting gitignore"
3502 ));
3503 assert!(!super::looks_like_meta_annotation(
3504 "compute cosine similarity between vectors"
3505 ));
3506 assert!(!super::looks_like_meta_annotation(""));
3508 assert!(!super::looks_like_meta_annotation(" "));
3509 assert!(!super::looks_like_meta_annotation("123 numeric prefix"));
3510 }
3511
    #[test]
    fn strict_comments_filters_meta_annotations_during_extraction() {
        // With STRICT_COMMENTS=1, TODO/FIXME-style annotations are dropped
        // during extraction while behaviour-describing comments survive.
        let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let previous = std::env::var("CODELENS_EMBED_HINT_STRICT_COMMENTS").ok();
        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_STRICT_COMMENTS", "1");
        }
        let source = "\
fn handle_request() {
    // TODO: handle the error path properly
    // parse json body from the incoming request
    // FIXME: this can panic on empty input
    // walk directory respecting the gitignore rules
    let _x = 1;
}
";
        let result = super::extract_nl_tokens_inner(source, 0, source.len());
        // Restore the env var before asserting so a failure cannot leak it.
        unsafe {
            match previous {
                Some(value) => std::env::set_var("CODELENS_EMBED_HINT_STRICT_COMMENTS", value),
                None => std::env::remove_var("CODELENS_EMBED_HINT_STRICT_COMMENTS"),
            }
        }
        let hint = result.expect("behaviour comments must survive");
        assert!(
            hint.contains("parse json body"),
            "behaviour comment dropped: {hint}"
        );
        assert!(!hint.contains("TODO"), "TODO annotation leaked: {hint}");
        assert!(!hint.contains("FIXME"), "FIXME annotation leaked: {hint}");
    }
3548
    #[test]
    fn strict_comments_is_orthogonal_to_strict_literals() {
        // Enabling STRICT_COMMENTS alone must not activate the literal
        // filter: both the comment and the string literal survive.
        let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let prev_c = std::env::var("CODELENS_EMBED_HINT_STRICT_COMMENTS").ok();
        let prev_l = std::env::var("CODELENS_EMBED_HINT_STRICT_LITERALS").ok();
        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_STRICT_COMMENTS", "1");
            std::env::remove_var("CODELENS_EMBED_HINT_STRICT_LITERALS");
        }
        let source = "\
fn handle() {
    // handles real behaviour
    let fmt = \"format error string\";
}
";
        let result = super::extract_nl_tokens_inner(source, 0, source.len());
        // Restore both vars before asserting so a failure cannot leak state.
        unsafe {
            match prev_c {
                Some(v) => std::env::set_var("CODELENS_EMBED_HINT_STRICT_COMMENTS", v),
                None => std::env::remove_var("CODELENS_EMBED_HINT_STRICT_COMMENTS"),
            }
            match prev_l {
                Some(v) => std::env::set_var("CODELENS_EMBED_HINT_STRICT_LITERALS", v),
                None => std::env::remove_var("CODELENS_EMBED_HINT_STRICT_LITERALS"),
            }
        }
        let hint = result.expect("tokens must exist");
        assert!(hint.contains("handles real"), "comment dropped: {hint}");
        assert!(
            hint.contains("format error string"),
            "literal dropped: {hint}"
        );
    }
3590
    #[test]
    fn strict_literal_filter_gated_off_by_default() {
        // With CODELENS_EMBED_HINT_STRICT_LITERALS unset, the strict literal
        // filter must be disabled.
        let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let previous = std::env::var("CODELENS_EMBED_HINT_STRICT_LITERALS").ok();
        unsafe {
            std::env::remove_var("CODELENS_EMBED_HINT_STRICT_LITERALS");
        }
        let enabled = super::strict_literal_filter_enabled();
        // Restore only when a prior value existed — the var was removed above.
        unsafe {
            if let Some(value) = previous {
                std::env::set_var("CODELENS_EMBED_HINT_STRICT_LITERALS", value);
            }
        }
        assert!(!enabled, "strict literal filter gate leaked");
    }
3606
3607 #[test]
3608 fn contains_format_specifier_detects_c_and_python_style() {
3609 assert!(super::contains_format_specifier("Invalid URL %s"));
3611 assert!(super::contains_format_specifier("got %d matches"));
3612 assert!(super::contains_format_specifier("value=%r"));
3613 assert!(super::contains_format_specifier("size=%f"));
3614 assert!(super::contains_format_specifier("sending request to {url}"));
3616 assert!(super::contains_format_specifier("got {0} items"));
3617 assert!(super::contains_format_specifier("{:?}"));
3618 assert!(super::contains_format_specifier("value: {x:.2f}"));
3619 assert!(super::contains_format_specifier("{}"));
3620 assert!(!super::contains_format_specifier(
3622 "skip comments and string literals"
3623 ));
3624 assert!(!super::contains_format_specifier("failed to open database"));
3625 assert!(!super::contains_format_specifier("{name: foo, id: 1}"));
3628 }
3629
3630 #[test]
3631 fn looks_like_error_or_log_prefix_rejects_common_patterns() {
3632 assert!(super::looks_like_error_or_log_prefix("Invalid URL format"));
3633 assert!(super::looks_like_error_or_log_prefix(
3634 "Cannot decode response"
3635 ));
3636 assert!(super::looks_like_error_or_log_prefix("could not open file"));
3637 assert!(super::looks_like_error_or_log_prefix(
3638 "Failed to send request"
3639 ));
3640 assert!(super::looks_like_error_or_log_prefix(
3641 "Expected int, got str"
3642 ));
3643 assert!(super::looks_like_error_or_log_prefix(
3644 "sending request to server"
3645 ));
3646 assert!(super::looks_like_error_or_log_prefix(
3647 "received response headers"
3648 ));
3649 assert!(super::looks_like_error_or_log_prefix(
3650 "starting worker pool"
3651 ));
3652 assert!(!super::looks_like_error_or_log_prefix(
3654 "parse json body from request"
3655 ));
3656 assert!(!super::looks_like_error_or_log_prefix(
3657 "compute cosine similarity between vectors"
3658 ));
3659 assert!(!super::looks_like_error_or_log_prefix(
3660 "walk directory tree respecting gitignore"
3661 ));
3662 }
3663
    #[test]
    fn strict_mode_rejects_format_and_error_literals_during_extraction() {
        // With STRICT_LITERALS=1, format-specifier, log-prefix, and fstring
        // literals are filtered out; plain behavioural literals survive.
        let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let previous = std::env::var("CODELENS_EMBED_HINT_STRICT_LITERALS").ok();
        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_STRICT_LITERALS", "1");
        }
        let source = "\
fn handle_request() {
    let err = \"Invalid URL %s\";
    let log = \"sending request to the upstream server\";
    let fmt = \"received {count} items in batch\";
    let real = \"parse json body from the incoming request\";
}
";
        let result = super::extract_nl_tokens_inner(source, 0, source.len());
        // Restore the env var before asserting so a failure cannot leak it.
        unsafe {
            match previous {
                Some(value) => std::env::set_var("CODELENS_EMBED_HINT_STRICT_LITERALS", value),
                None => std::env::remove_var("CODELENS_EMBED_HINT_STRICT_LITERALS"),
            }
        }
        let hint = result.expect("some token should survive");
        assert!(
            hint.contains("parse json body"),
            "real literal was filtered out: {hint}"
        );
        assert!(
            !hint.contains("Invalid URL"),
            "format-specifier literal leaked: {hint}"
        );
        assert!(
            !hint.contains("sending request"),
            "log-prefix literal leaked: {hint}"
        );
        assert!(
            !hint.contains("received {count}"),
            "python fstring literal leaked: {hint}"
        );
    }
3709
    #[test]
    fn strict_mode_leaves_comments_untouched() {
        // STRICT_LITERALS must only filter string literals; comments pass
        // through even when they start with error-like words.
        let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let previous = std::env::var("CODELENS_EMBED_HINT_STRICT_LITERALS").ok();
        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_STRICT_LITERALS", "1");
        }
        let source = "\
fn do_work() {
    // Invalid inputs are rejected by this guard clause.
    // sending requests in parallel across worker threads.
    let _lit = \"format spec %s\";
}
";
        let result = super::extract_nl_tokens_inner(source, 0, source.len());
        // Restore the env var before asserting so a failure cannot leak it.
        unsafe {
            match previous {
                Some(value) => std::env::set_var("CODELENS_EMBED_HINT_STRICT_LITERALS", value),
                None => std::env::remove_var("CODELENS_EMBED_HINT_STRICT_LITERALS"),
            }
        }
        let hint = result.expect("comments should survive strict mode");
        assert!(
            hint.contains("Invalid inputs") || hint.contains("rejected by this guard"),
            "strict mode swallowed a comment: {hint}"
        );
        assert!(
            !hint.contains("format spec"),
            "format-specifier literal leaked under strict mode: {hint}"
        );
    }
3747
3748 #[test]
3749 fn should_reject_literal_strict_composes_format_and_prefix() {
3750 assert!(super::should_reject_literal_strict("Invalid URL %s"));
3754 assert!(super::should_reject_literal_strict(
3755 "sending request to server"
3756 ));
3757 assert!(super::should_reject_literal_strict("value: {x:.2f}"));
3758 assert!(!super::should_reject_literal_strict(
3760 "parse json body from the incoming request"
3761 ));
3762 assert!(!super::should_reject_literal_strict(
3763 "compute cosine similarity between vectors"
3764 ));
3765 }
3766
3767 #[test]
3768 fn is_static_method_ident_accepts_pascal_and_rejects_snake() {
3769 assert!(super::is_static_method_ident("HashMap"));
3770 assert!(super::is_static_method_ident("Parser"));
3771 assert!(super::is_static_method_ident("A"));
3772 assert!(!super::is_static_method_ident("std"));
3775 assert!(!super::is_static_method_ident("fs"));
3776 assert!(!super::is_static_method_ident("_private"));
3777 assert!(!super::is_static_method_ident(""));
3778 }
3779
    #[test]
    fn extract_api_calls_gated_off_by_default() {
        // Serialize process-global env mutation; recover from poisoning.
        let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        // Remember any caller-provided value, then clear the gate so the
        // default (disabled) behavior is exercised.
        let previous = std::env::var("CODELENS_EMBED_HINT_INCLUDE_API_CALLS").ok();
        unsafe {
            std::env::remove_var("CODELENS_EMBED_HINT_INCLUDE_API_CALLS");
        }
        let source = "\
fn make_parser() {
    let p = Parser::new();
    let _ = HashMap::with_capacity(8);
}
";
        let result = extract_api_calls(source, 0, source.len());
        // Restore the caller's value; when it was unset there is nothing to
        // do (the variable is already removed).
        unsafe {
            if let Some(value) = previous {
                std::env::set_var("CODELENS_EMBED_HINT_INCLUDE_API_CALLS", value);
            }
        }
        assert!(result.is_none(), "gate leaked: {result:?}");
    }
3802
3803 #[test]
3804 fn extract_api_calls_captures_type_method_patterns() {
3805 let source = "\
3807fn open_db() {
3808 let p = Parser::new();
3809 let map = HashMap::with_capacity(16);
3810 let _ = tree_sitter::Parser::new();
3811}
3812";
3813 let hint = super::extract_api_calls_inner(source, 0, source.len())
3814 .expect("api calls should be produced");
3815 assert!(hint.contains("Parser::new"), "missing Parser::new: {hint}");
3816 assert!(
3817 hint.contains("HashMap::with_capacity"),
3818 "missing HashMap::with_capacity: {hint}"
3819 );
3820 }
3821
3822 #[test]
3823 fn extract_api_calls_rejects_module_prefixed_free_functions() {
3824 let source = "\
3827fn read_config() {
3828 let _ = std::fs::read_to_string(\"foo\");
3829 let _ = crate::util::parse();
3830}
3831";
3832 let hint = super::extract_api_calls_inner(source, 0, source.len());
3833 if let Some(hint) = hint {
3836 assert!(!hint.contains("std::fs"), "lowercase module leaked: {hint}");
3837 assert!(
3838 !hint.contains("fs::read_to_string"),
3839 "module-prefixed free function leaked: {hint}"
3840 );
3841 assert!(!hint.contains("crate::util"), "crate path leaked: {hint}");
3842 }
3843 }
3844
3845 #[test]
3846 fn extract_api_calls_deduplicates_repeated_calls() {
3847 let source = "\
3848fn hot_loop() {
3849 for _ in 0..10 {
3850 let _ = Parser::new();
3851 let _ = Parser::new();
3852 }
3853 let _ = Parser::new();
3854}
3855";
3856 let hint = super::extract_api_calls_inner(source, 0, source.len())
3857 .expect("api calls should be produced");
3858 let first = hint.find("Parser::new").expect("hit");
3859 let rest = &hint[first + "Parser::new".len()..];
3860 assert!(
3861 !rest.contains("Parser::new"),
3862 "duplicate not deduplicated: {hint}"
3863 );
3864 }
3865
3866 #[test]
3867 fn extract_api_calls_returns_none_when_body_has_no_type_calls() {
3868 let source = "\
3869fn plain() {
3870 let x = 1;
3871 let y = x + 2;
3872}
3873";
3874 assert!(super::extract_api_calls_inner(source, 0, source.len()).is_none());
3875 }
3876
3877 #[test]
3878 fn extract_nl_tokens_collects_comments_and_string_literals() {
3879 let source = "\
3883fn search_for_matches() {
3884 // skip comments and string literals during search
3885 let error = \"failed to open database\";
3886 let single = \"tok\";
3887 let path = \"src/foo/bar\";
3888 let keyword = match kind {
3889 Kind::Ident => \"detect client version\",
3890 _ => \"\",
3891 };
3892}
3893";
3894 let hint = super::extract_nl_tokens_inner(source, 0, source.len())
3900 .expect("nl tokens should be produced");
3901 let has_first_nl_signal = hint.contains("skip comments")
3905 || hint.contains("failed to open")
3906 || hint.contains("detect client");
3907 assert!(has_first_nl_signal, "no NL signal produced: {hint}");
3908 assert!(!hint.contains(" tok "), "short literal leaked: {hint}");
3910 assert!(!hint.contains("src/foo/bar"), "path literal leaked: {hint}");
3912 }
3913
3914 #[test]
3915 fn hint_char_budget_respects_env_override() {
3916 let previous = std::env::var("CODELENS_EMBED_HINT_CHARS").ok();
3917 unsafe {
3918 std::env::set_var("CODELENS_EMBED_HINT_CHARS", "120");
3919 }
3920 let budget = super::hint_char_budget();
3921 unsafe {
3922 match previous {
3923 Some(value) => std::env::set_var("CODELENS_EMBED_HINT_CHARS", value),
3924 None => std::env::remove_var("CODELENS_EMBED_HINT_CHARS"),
3925 }
3926 }
3927 assert_eq!(budget, 120);
3928 }
3929
3930 #[test]
3931 fn embedding_to_bytes_roundtrip() {
3932 let floats = vec![1.0f32, -0.5, 0.0, 3.25];
3933 let bytes = embedding_to_bytes(&floats);
3934 assert_eq!(bytes.len(), 4 * 4);
3935 let recovered: Vec<f32> = bytes
3937 .chunks_exact(4)
3938 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
3939 .collect();
3940 assert_eq!(floats, recovered);
3941 }
3942
3943 #[test]
3944 fn duplicate_pair_key_is_order_independent() {
3945 let a = duplicate_pair_key("a.py", "foo", "b.py", "bar");
3946 let b = duplicate_pair_key("b.py", "bar", "a.py", "foo");
3947 assert_eq!(a, b);
3948 }
3949
3950 #[test]
3951 fn text_embedding_cache_updates_recency() {
3952 let mut cache = TextEmbeddingCache::new(2);
3953 cache.insert("a".into(), vec![1.0]);
3954 cache.insert("b".into(), vec![2.0]);
3955 assert_eq!(cache.get("a"), Some(vec![1.0]));
3956 cache.insert("c".into(), vec![3.0]);
3957
3958 assert_eq!(cache.get("a"), Some(vec![1.0]));
3959 assert_eq!(cache.get("b"), None);
3960 assert_eq!(cache.get("c"), Some(vec![3.0]));
3961 }
3962
3963 #[test]
3964 fn text_embedding_cache_can_be_disabled() {
3965 let mut cache = TextEmbeddingCache::new(0);
3966 cache.insert("a".into(), vec![1.0]);
3967 assert_eq!(cache.get("a"), None);
3968 }
3969
3970 #[test]
3971 fn engine_new_and_index() {
3972 let _lock = MODEL_LOCK.lock().unwrap();
3973 skip_without_embedding_model!();
3974 let (_dir, project) = make_project_with_source();
3975 let engine = EmbeddingEngine::new(&project).expect("engine should load");
3976 assert!(!engine.is_indexed());
3977
3978 let count = engine.index_from_project(&project).unwrap();
3979 assert_eq!(count, 2, "should index 2 symbols");
3980 assert!(engine.is_indexed());
3981 }
3982
3983 #[test]
3984 fn engine_search_returns_results() {
3985 let _lock = MODEL_LOCK.lock().unwrap();
3986 skip_without_embedding_model!();
3987 let (_dir, project) = make_project_with_source();
3988 let engine = EmbeddingEngine::new(&project).unwrap();
3989 engine.index_from_project(&project).unwrap();
3990
3991 let results = engine.search("hello function", 10).unwrap();
3992 assert!(!results.is_empty(), "search should return results");
3993 for r in &results {
3994 assert!(
3995 r.score >= -1.0 && r.score <= 1.0,
3996 "score should be in [-1,1]: {}",
3997 r.score
3998 );
3999 }
4000 }
4001
4002 #[test]
4003 fn engine_incremental_index() {
4004 let _lock = MODEL_LOCK.lock().unwrap();
4005 skip_without_embedding_model!();
4006 let (_dir, project) = make_project_with_source();
4007 let engine = EmbeddingEngine::new(&project).unwrap();
4008 engine.index_from_project(&project).unwrap();
4009 assert_eq!(engine.store.count().unwrap(), 2);
4010
4011 let count = engine.index_changed_files(&project, &["main.py"]).unwrap();
4013 assert_eq!(count, 2);
4014 assert_eq!(engine.store.count().unwrap(), 2);
4015 }
4016
4017 #[test]
4018 fn engine_reindex_preserves_symbol_count() {
4019 let _lock = MODEL_LOCK.lock().unwrap();
4020 skip_without_embedding_model!();
4021 let (_dir, project) = make_project_with_source();
4022 let engine = EmbeddingEngine::new(&project).unwrap();
4023 engine.index_from_project(&project).unwrap();
4024 assert_eq!(engine.store.count().unwrap(), 2);
4025
4026 let count = engine.index_from_project(&project).unwrap();
4027 assert_eq!(count, 2);
4028 assert_eq!(engine.store.count().unwrap(), 2);
4029 }
4030
4031 #[test]
4032 fn full_reindex_reuses_unchanged_embeddings() {
4033 let _lock = MODEL_LOCK.lock().unwrap();
4034 skip_without_embedding_model!();
4035 let (_dir, project) = make_project_with_source();
4036 let engine = EmbeddingEngine::new(&project).unwrap();
4037 engine.index_from_project(&project).unwrap();
4038
4039 replace_file_embeddings_with_sentinels(
4040 &engine,
4041 "main.py",
4042 &[("hello", 11.0), ("world", 22.0)],
4043 );
4044
4045 let count = engine.index_from_project(&project).unwrap();
4046 assert_eq!(count, 2);
4047
4048 let hello = engine
4049 .store
4050 .get_embedding("main.py", "hello")
4051 .unwrap()
4052 .expect("hello should exist");
4053 let world = engine
4054 .store
4055 .get_embedding("main.py", "world")
4056 .unwrap()
4057 .expect("world should exist");
4058 assert!(hello.embedding.iter().all(|value| *value == 11.0));
4059 assert!(world.embedding.iter().all(|value| *value == 22.0));
4060 }
4061
    #[test]
    fn full_reindex_reuses_unchanged_sibling_after_edit() {
        let _lock = MODEL_LOCK.lock().unwrap();
        skip_without_embedding_model!();
        let (dir, project) = make_project_with_source();
        let engine = EmbeddingEngine::new(&project).unwrap();
        engine.index_from_project(&project).unwrap();

        // Stamp sentinel vectors over both symbols so reuse vs. recompute is
        // observable after the edit below.
        replace_file_embeddings_with_sentinels(
            &engine,
            "main.py",
            &[("hello", 11.0), ("world", 22.0)],
        );

        // Rewrite main.py so that `world` changes while `hello` stays
        // byte-identical (new content hash "hash2" marks the file as edited).
        let updated_source =
            "def hello():\n print('hi')\n\ndef world(name):\n return name.upper()\n";
        write_python_file_with_symbols(
            dir.path(),
            "main.py",
            updated_source,
            "hash2",
            &[
                ("hello", "def hello():", "hello"),
                ("world", "def world(name):", "world"),
            ],
        );

        let count = engine.index_from_project(&project).unwrap();
        assert_eq!(count, 2);

        let hello = engine
            .store
            .get_embedding("main.py", "hello")
            .unwrap()
            .expect("hello should exist");
        let world = engine
            .store
            .get_embedding("main.py", "world")
            .unwrap()
            .expect("world should exist");
        // The untouched sibling keeps its sentinel (embedding was reused)...
        assert!(hello.embedding.iter().all(|value| *value == 11.0));
        // ...while the edited symbol was recomputed (sentinel overwritten).
        assert!(world.embedding.iter().any(|value| *value != 22.0));
        assert_eq!(engine.store.count().unwrap(), 2);
    }
4106
4107 #[test]
4108 fn full_reindex_removes_deleted_files() {
4109 let _lock = MODEL_LOCK.lock().unwrap();
4110 skip_without_embedding_model!();
4111 let (dir, project) = make_project_with_source();
4112 write_python_file_with_symbols(
4113 dir.path(),
4114 "extra.py",
4115 "def bonus():\n return 7\n",
4116 "hash-extra",
4117 &[("bonus", "def bonus():", "bonus")],
4118 );
4119
4120 let engine = EmbeddingEngine::new(&project).unwrap();
4121 engine.index_from_project(&project).unwrap();
4122 assert_eq!(engine.store.count().unwrap(), 3);
4123
4124 std::fs::remove_file(dir.path().join("extra.py")).unwrap();
4125 let db_path = crate::db::index_db_path(dir.path());
4126 let db = IndexDb::open(&db_path).unwrap();
4127 db.delete_file("extra.py").unwrap();
4128
4129 let count = engine.index_from_project(&project).unwrap();
4130 assert_eq!(count, 2);
4131 assert_eq!(engine.store.count().unwrap(), 2);
4132 assert!(
4133 engine
4134 .store
4135 .embeddings_for_files(&["extra.py"])
4136 .unwrap()
4137 .is_empty()
4138 );
4139 }
4140
4141 #[test]
4142 fn engine_model_change_recreates_db() {
4143 let _lock = MODEL_LOCK.lock().unwrap();
4144 skip_without_embedding_model!();
4145 let (_dir, project) = make_project_with_source();
4146
4147 let engine1 = EmbeddingEngine::new(&project).unwrap();
4149 engine1.index_from_project(&project).unwrap();
4150 assert_eq!(engine1.store.count().unwrap(), 2);
4151 drop(engine1);
4152
4153 let engine2 = EmbeddingEngine::new(&project).unwrap();
4155 assert!(engine2.store.count().unwrap() >= 2);
4156 }
4157
4158 #[test]
4159 fn inspect_existing_index_returns_model_and_count() {
4160 let _lock = MODEL_LOCK.lock().unwrap();
4161 skip_without_embedding_model!();
4162 let (_dir, project) = make_project_with_source();
4163 let engine = EmbeddingEngine::new(&project).unwrap();
4164 engine.index_from_project(&project).unwrap();
4165
4166 let info = EmbeddingEngine::inspect_existing_index(&project)
4167 .unwrap()
4168 .expect("index info should exist");
4169 assert_eq!(info.model_name, engine.model_name());
4170 assert_eq!(info.indexed_symbols, 2);
4171 }
4172
4173 #[test]
4174 fn inspect_existing_index_recovers_from_corrupt_db() {
4175 let (_dir, project) = make_project_with_source();
4176 let index_dir = project.as_path().join(".codelens/index");
4177 let db_path = index_dir.join("embeddings.db");
4178 let wal_path = index_dir.join("embeddings.db-wal");
4179 let shm_path = index_dir.join("embeddings.db-shm");
4180
4181 std::fs::write(&db_path, b"not a sqlite database").unwrap();
4182 std::fs::write(&wal_path, b"bad wal").unwrap();
4183 std::fs::write(&shm_path, b"bad shm").unwrap();
4184
4185 let info = EmbeddingEngine::inspect_existing_index(&project).unwrap();
4186 assert!(info.is_none());
4187
4188 assert!(db_path.is_file());
4189
4190 let backup_names: Vec<String> = std::fs::read_dir(&index_dir)
4191 .unwrap()
4192 .map(|entry| entry.unwrap().file_name().to_string_lossy().into_owned())
4193 .filter(|name| name.contains(".corrupt-"))
4194 .collect();
4195
4196 assert!(
4197 backup_names
4198 .iter()
4199 .any(|name| name.starts_with("embeddings.db.corrupt-")),
4200 "expected quarantined embedding db, found {backup_names:?}"
4201 );
4202 }
4203
4204 #[test]
4205 fn store_can_fetch_single_embedding_without_loading_all() {
4206 let _lock = MODEL_LOCK.lock().unwrap();
4207 skip_without_embedding_model!();
4208 let (_dir, project) = make_project_with_source();
4209 let engine = EmbeddingEngine::new(&project).unwrap();
4210 engine.index_from_project(&project).unwrap();
4211
4212 let chunk = engine
4213 .store
4214 .get_embedding("main.py", "hello")
4215 .unwrap()
4216 .expect("embedding should exist");
4217 assert_eq!(chunk.file_path, "main.py");
4218 assert_eq!(chunk.symbol_name, "hello");
4219 assert!(!chunk.embedding.is_empty());
4220 }
4221
4222 #[test]
4223 fn find_similar_code_uses_index_and_excludes_target_symbol() {
4224 let _lock = MODEL_LOCK.lock().unwrap();
4225 skip_without_embedding_model!();
4226 let (_dir, project) = make_project_with_source();
4227 let engine = EmbeddingEngine::new(&project).unwrap();
4228 engine.index_from_project(&project).unwrap();
4229
4230 let matches = engine.find_similar_code("main.py", "hello", 5).unwrap();
4231 assert!(!matches.is_empty());
4232 assert!(
4233 matches
4234 .iter()
4235 .all(|m| !(m.file_path == "main.py" && m.symbol_name == "hello"))
4236 );
4237 }
4238
4239 #[test]
4240 fn delete_by_file_removes_rows_in_one_batch() {
4241 let _lock = MODEL_LOCK.lock().unwrap();
4242 skip_without_embedding_model!();
4243 let (_dir, project) = make_project_with_source();
4244 let engine = EmbeddingEngine::new(&project).unwrap();
4245 engine.index_from_project(&project).unwrap();
4246
4247 let deleted = engine.store.delete_by_file(&["main.py"]).unwrap();
4248 assert_eq!(deleted, 2);
4249 assert_eq!(engine.store.count().unwrap(), 0);
4250 }
4251
4252 #[test]
4253 fn store_streams_embeddings_grouped_by_file() {
4254 let _lock = MODEL_LOCK.lock().unwrap();
4255 skip_without_embedding_model!();
4256 let (_dir, project) = make_project_with_source();
4257 let engine = EmbeddingEngine::new(&project).unwrap();
4258 engine.index_from_project(&project).unwrap();
4259
4260 let mut groups = Vec::new();
4261 engine
4262 .store
4263 .for_each_file_embeddings(&mut |file_path, chunks| {
4264 groups.push((file_path, chunks.len()));
4265 Ok(())
4266 })
4267 .unwrap();
4268
4269 assert_eq!(groups, vec![("main.py".to_string(), 2)]);
4270 }
4271
4272 #[test]
4273 fn store_fetches_embeddings_for_specific_files() {
4274 let _lock = MODEL_LOCK.lock().unwrap();
4275 skip_without_embedding_model!();
4276 let (_dir, project) = make_project_with_source();
4277 let engine = EmbeddingEngine::new(&project).unwrap();
4278 engine.index_from_project(&project).unwrap();
4279
4280 let chunks = engine.store.embeddings_for_files(&["main.py"]).unwrap();
4281 assert_eq!(chunks.len(), 2);
4282 assert!(chunks.iter().all(|chunk| chunk.file_path == "main.py"));
4283 }
4284
4285 #[test]
4286 fn store_fetches_embeddings_for_scored_chunks() {
4287 let _lock = MODEL_LOCK.lock().unwrap();
4288 skip_without_embedding_model!();
4289 let (_dir, project) = make_project_with_source();
4290 let engine = EmbeddingEngine::new(&project).unwrap();
4291 engine.index_from_project(&project).unwrap();
4292
4293 let scored = engine.search_scored("hello world function", 2).unwrap();
4294 let chunks = engine.store.embeddings_for_scored_chunks(&scored).unwrap();
4295
4296 assert_eq!(chunks.len(), scored.len());
4297 assert!(scored.iter().all(|candidate| chunks.iter().any(|chunk| {
4298 chunk.file_path == candidate.file_path
4299 && chunk.symbol_name == candidate.symbol_name
4300 && chunk.line == candidate.line
4301 && chunk.signature == candidate.signature
4302 && chunk.name_path == candidate.name_path
4303 })));
4304 }
4305
4306 #[test]
4307 fn find_misplaced_code_returns_per_file_outliers() {
4308 let _lock = MODEL_LOCK.lock().unwrap();
4309 skip_without_embedding_model!();
4310 let (_dir, project) = make_project_with_source();
4311 let engine = EmbeddingEngine::new(&project).unwrap();
4312 engine.index_from_project(&project).unwrap();
4313
4314 let outliers = engine.find_misplaced_code(5).unwrap();
4315 assert_eq!(outliers.len(), 2);
4316 assert!(outliers.iter().all(|item| item.file_path == "main.py"));
4317 }
4318
4319 #[test]
4320 fn find_duplicates_uses_batched_candidate_embeddings() {
4321 let _lock = MODEL_LOCK.lock().unwrap();
4322 skip_without_embedding_model!();
4323 let (_dir, project) = make_project_with_source();
4324 let engine = EmbeddingEngine::new(&project).unwrap();
4325 engine.index_from_project(&project).unwrap();
4326
4327 replace_file_embeddings_with_sentinels(
4328 &engine,
4329 "main.py",
4330 &[("hello", 5.0), ("world", 5.0)],
4331 );
4332
4333 let duplicates = engine.find_duplicates(0.99, 4).unwrap();
4334 assert!(!duplicates.is_empty());
4335 assert!(duplicates.iter().any(|pair| {
4336 (pair.symbol_a == "main.py:hello" && pair.symbol_b == "main.py:world")
4337 || (pair.symbol_a == "main.py:world" && pair.symbol_b == "main.py:hello")
4338 }));
4339 }
4340
4341 #[test]
4342 fn search_scored_returns_raw_chunks() {
4343 let _lock = MODEL_LOCK.lock().unwrap();
4344 skip_without_embedding_model!();
4345 let (_dir, project) = make_project_with_source();
4346 let engine = EmbeddingEngine::new(&project).unwrap();
4347 engine.index_from_project(&project).unwrap();
4348
4349 let chunks = engine.search_scored("world function", 5).unwrap();
4350 assert!(!chunks.is_empty());
4351 for c in &chunks {
4352 assert!(!c.file_path.is_empty());
4353 assert!(!c.symbol_name.is_empty());
4354 }
4355 }
4356
4357 #[test]
4358 fn configured_embedding_model_name_defaults_to_codesearchnet() {
4359 assert_eq!(configured_embedding_model_name(), CODESEARCH_MODEL_NAME);
4360 }
4361
4362 #[test]
4363 fn recommended_embed_threads_caps_macos_style_load() {
4364 let threads = recommended_embed_threads();
4365 assert!(threads >= 1);
4366 assert!(threads <= 8);
4367 }
4368
4369 #[test]
4370 fn embed_batch_size_has_safe_default_floor() {
4371 assert!(embed_batch_size() >= 1);
4372 if cfg!(target_os = "macos") {
4373 assert!(embed_batch_size() <= DEFAULT_MACOS_EMBED_BATCH_SIZE);
4374 }
4375 }
4376}