1use crate::db::IndexDb;
5use crate::embedding_store::{EmbeddingChunk, EmbeddingStore, ScoredChunk};
6use crate::project::ProjectRoot;
7use anyhow::{Context, Result};
8use fastembed::{
9 ExecutionProviderDispatch, InitOptionsUserDefined, TextEmbedding, TokenizerFiles,
10 UserDefinedEmbeddingModel,
11};
12use rusqlite::Connection;
13use serde::Serialize;
14use std::collections::{HashMap, HashSet, VecDeque};
15use std::sync::{Arc, Mutex, Once};
16use std::thread::available_parallelism;
17use tracing::debug;
18
pub(super) mod ffi {
    use anyhow::Result;

    /// Registers the `sqlite-vec` extension as a SQLite auto-extension so
    /// every subsequently opened connection gains the vector functions.
    ///
    /// # Errors
    /// Fails when `sqlite3_auto_extension` returns a non-OK result code.
    pub fn register_sqlite_vec() -> Result<()> {
        // SAFETY: `sqlite3_vec_init` is an extension entry point; the
        // transmute only recasts the function pointer to the exact signature
        // `sqlite3_auto_extension` expects, and SQLite stores (not calls)
        // the pointer here.
        let rc = unsafe {
            rusqlite::ffi::sqlite3_auto_extension(Some(std::mem::transmute::<
                *const (),
                unsafe extern "C" fn(
                    *mut rusqlite::ffi::sqlite3,
                    *mut *mut i8,
                    *const rusqlite::ffi::sqlite3_api_routines,
                ) -> i32,
            >(
                sqlite_vec::sqlite3_vec_init as *const ()
            )))
        };
        if rc != rusqlite::ffi::SQLITE_OK {
            anyhow::bail!("failed to register sqlite-vec extension (SQLite error code: {rc})");
        }
        Ok(())
    }

    /// Reads an unsigned-integer sysctl value by name (macOS only).
    ///
    /// `name` must be a NUL-terminated byte string, e.g. `b"hw.physicalcpu\0"`.
    /// Returns `None` when the call fails or the reported size is not that of
    /// a `c_uint` (i.e. the value is not a plain unsigned int).
    #[cfg(target_os = "macos")]
    pub fn sysctl_usize(name: &[u8]) -> Option<usize> {
        let mut value: libc::c_uint = 0;
        let mut size = std::mem::size_of::<libc::c_uint>();
        // SAFETY: `value` and `size` are valid for writes, `size` describes
        // the output buffer, and the caller supplies a NUL-terminated name.
        let rc = unsafe {
            libc::sysctlbyname(
                name.as_ptr().cast(),
                (&mut value as *mut libc::c_uint).cast(),
                &mut size,
                std::ptr::null_mut(),
                0,
            )
        };
        (rc == 0 && size == std::mem::size_of::<libc::c_uint>()).then_some(value as usize)
    }
}
58
/// One semantic-search hit, shaped for serialized (API/CLI) output.
#[derive(Debug, Clone, Serialize)]
pub struct SemanticMatch {
    pub file_path: String,
    pub symbol_name: String,
    /// Symbol kind (e.g. function, struct) as recorded in the index.
    pub kind: String,
    pub line: usize,
    pub signature: String,
    pub name_path: String,
    /// Similarity score copied from the vector store's `ScoredChunk`.
    pub score: f64,
}
70
71impl From<ScoredChunk> for SemanticMatch {
72 fn from(c: ScoredChunk) -> Self {
73 Self {
74 file_path: c.file_path,
75 symbol_name: c.symbol_name,
76 kind: c.kind,
77 line: c.line,
78 signature: c.signature,
79 name_path: c.name_path,
80 score: c.score,
81 }
82 }
83}
84
85mod vec_store;
86use vec_store::SqliteVecStore;
87
/// Identity key for a stored embedding: any change to these six fields means
/// the cached vector can no longer be reused.
type ReusableEmbeddingKey = (String, String, String, String, String, String);

/// Builds the owned reuse key for one symbol's embedding text.
fn reusable_embedding_key(
    file_path: &str,
    symbol_name: &str,
    kind: &str,
    signature: &str,
    name_path: &str,
    text: &str,
) -> ReusableEmbeddingKey {
    let own = |part: &str| part.to_owned();
    (
        own(file_path),
        own(symbol_name),
        own(kind),
        own(signature),
        own(name_path),
        own(text),
    )
}
107
108fn reusable_embedding_key_for_chunk(chunk: &EmbeddingChunk) -> ReusableEmbeddingKey {
109 reusable_embedding_key(
110 &chunk.file_path,
111 &chunk.symbol_name,
112 &chunk.kind,
113 &chunk.signature,
114 &chunk.name_path,
115 &chunk.text,
116 )
117}
118
119fn reusable_embedding_key_for_symbol(
120 sym: &crate::db::SymbolWithFile,
121 text: &str,
122) -> ReusableEmbeddingKey {
123 reusable_embedding_key(
124 &sym.file_path,
125 &sym.name,
126 &sym.kind,
127 &sym.signature,
128 &sym.name_path,
129 text,
130 )
131}
132
// Default number of texts embedded per model invocation (non-macOS / macOS).
const DEFAULT_EMBED_BATCH_SIZE: usize = 128;
const DEFAULT_MACOS_EMBED_BATCH_SIZE: usize = 128;
// Default LRU capacity (entries) for the query-text embedding cache.
const DEFAULT_TEXT_EMBED_CACHE_SIZE: usize = 256;
const DEFAULT_MACOS_TEXT_EMBED_CACHE_SIZE: usize = 1024;
// Output vector dimension of the bundled CodeSearchNet MiniLM model.
const CODESEARCH_DIMENSION: usize = 384;
// Safety cap on how many symbols a single indexing run will embed.
const DEFAULT_MAX_EMBED_SYMBOLS: usize = 50_000;
// NOTE(review): not referenced in this part of the file — presumably used
// further down (changed-file queries / duplicate scans); confirm before removal.
const CHANGED_FILE_QUERY_CHUNK: usize = 128;
const DEFAULT_DUPLICATE_SCAN_BATCH_SIZE: usize = 128;
// Guards one-time initialization of the global ONNX Runtime environment.
static ORT_ENV_INIT: Once = Once::new();

// Identifier of the bundled, INT8-quantized CodeSearchNet embedding model.
const CODESEARCH_MODEL_NAME: &str = "MiniLM-L12-CodeSearchNet-INT8";
148
/// Owns the embedding model, the on-disk vector store, and runtime metadata.
pub struct EmbeddingEngine {
    // Mutex because fastembed's `embed` requires `&mut self`.
    model: Mutex<TextEmbedding>,
    store: Box<dyn EmbeddingStore>,
    model_name: String,
    runtime_info: EmbeddingRuntimeInfo,
    // LRU cache of text -> embedding, used by `embed_texts_cached`.
    text_embed_cache: Mutex<TextEmbeddingCache>,
    // True while `index_from_project` runs; prevents concurrent index runs.
    indexing: std::sync::atomic::AtomicBool,
}
157
/// Summary of the current embedding index, for status reporting.
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
pub struct EmbeddingIndexInfo {
    pub model_name: String,
    pub indexed_symbols: usize,
}
163
/// Describes how the embedding runtime is (or would be) configured.
///
/// The `coreml_*` fields are populated only on macOS when CoreML is in play;
/// `fallback_reason` is set when CoreML was requested but CPU is being used.
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
pub struct EmbeddingRuntimeInfo {
    pub runtime_preference: String,
    pub backend: String,
    pub threads: usize,
    pub max_length: usize,
    pub coreml_model_format: Option<String>,
    pub coreml_compute_units: Option<String>,
    pub coreml_static_input_shapes: Option<bool>,
    pub coreml_profile_compute_plan: Option<bool>,
    pub coreml_specialization_strategy: Option<String>,
    pub coreml_model_cache_dir: Option<String>,
    pub fallback_reason: Option<String>,
}
178
/// Small LRU cache mapping embedded text to its embedding vector.
///
/// Recency lives in `order` (front = least recently used). A `capacity` of
/// zero disables caching entirely.
struct TextEmbeddingCache {
    capacity: usize,
    order: VecDeque<String>,
    entries: HashMap<String, Vec<f32>>,
}

impl TextEmbeddingCache {
    /// Creates an empty cache holding at most `capacity` entries.
    fn new(capacity: usize) -> Self {
        Self {
            capacity,
            order: VecDeque::new(),
            entries: HashMap::new(),
        }
    }

    /// Returns a clone of the cached vector and marks `key` most recent.
    fn get(&mut self, key: &str) -> Option<Vec<f32>> {
        let hit = self.entries.get(key).cloned()?;
        self.touch(key);
        Some(hit)
    }

    /// Stores `value` under `key` as the most-recent entry, then evicts from
    /// the least-recent end until the cache fits its capacity again.
    fn insert(&mut self, key: String, value: Vec<f32>) {
        if self.capacity == 0 {
            // Caching disabled.
            return;
        }

        self.entries.insert(key.clone(), value);
        self.touch(&key);

        while self.entries.len() > self.capacity {
            let Some(evicted) = self.order.pop_front() else {
                break;
            };
            self.entries.remove(&evicted);
        }
    }

    /// Moves `key` to the most-recently-used end of the recency queue.
    fn touch(&mut self, key: &str) {
        if let Some(index) = self.order.iter().position(|queued| queued == key) {
            self.order.remove(index);
        }
        self.order.push_back(key.to_owned());
    }
}
224
/// Locates the directory containing the CodeSearchNet model assets.
///
/// Probes, in priority order: `$CODELENS_MODEL_DIR/codesearch`,
/// `<executable dir>/models/codesearch`, `~/.cache/codelens/models/codesearch`,
/// and finally the crate's own `models/codesearch` (development checkout).
/// A directory qualifies only if it contains `model.onnx`.
///
/// # Errors
/// Returns an error listing the searched locations when none qualifies.
fn resolve_model_dir() -> Result<std::path::PathBuf> {
    if let Ok(dir) = std::env::var("CODELENS_MODEL_DIR") {
        let p = std::path::PathBuf::from(dir).join("codesearch");
        if p.join("model.onnx").exists() {
            return Ok(p);
        }
    }

    if let Ok(exe) = std::env::current_exe()
        && let Some(exe_dir) = exe.parent()
    {
        let p = exe_dir.join("models").join("codesearch");
        if p.join("model.onnx").exists() {
            return Ok(p);
        }
    }

    if let Some(home) = dirs_fallback() {
        let p = home
            .join(".cache")
            .join("codelens")
            .join("models")
            .join("codesearch");
        if p.join("model.onnx").exists() {
            return Ok(p);
        }
    }

    // Development fallback: the models directory inside the source checkout.
    let dev_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
        .join("models")
        .join("codesearch");
    if dev_path.join("model.onnx").exists() {
        return Ok(dev_path);
    }

    anyhow::bail!(
        "CodeSearchNet model not found. Place model files in one of:\n\
         - $CODELENS_MODEL_DIR/codesearch/\n\
         - <executable>/models/codesearch/\n\
         - ~/.cache/codelens/models/codesearch/\n\
         Required files: model.onnx, tokenizer.json, config.json, special_tokens_map.json, tokenizer_config.json"
    )
}
279
/// Minimal home-directory lookup: `$HOME`, or `None` when unset.
fn dirs_fallback() -> Option<std::path::PathBuf> {
    let home = std::env::var_os("HOME")?;
    Some(std::path::PathBuf::from(home))
}
283
/// Parses a positive integer from the named env var.
///
/// Returns `None` when the variable is unset, unparsable, or zero — zero is
/// rejected so callers can treat the result as a usable count/size.
fn parse_usize_env(name: &str) -> Option<usize> {
    let raw = std::env::var(name).ok()?;
    let parsed: usize = raw.trim().parse().ok()?;
    if parsed > 0 { Some(parsed) } else { None }
}
290
/// Parses a boolean from the named env var.
///
/// Accepts `1/true/yes/on` and `0/false/no/off` (case-insensitive, trimmed);
/// anything else — including an unset variable — yields `None`.
fn parse_bool_env(name: &str) -> Option<bool> {
    let raw = std::env::var(name).ok()?;
    match raw.trim().to_ascii_lowercase().as_str() {
        "1" | "true" | "yes" | "on" => Some(true),
        "0" | "false" | "no" | "off" => Some(false),
        _ => None,
    }
}
301
/// Physical performance-core count on Apple hardware, via sysctl.
///
/// Prefers `hw.perflevel0.physicalcpu` (performance cluster) and falls back
/// to `hw.physicalcpu`; zero or failed reads are treated as unavailable.
#[cfg(target_os = "macos")]
fn apple_perf_cores() -> Option<usize> {
    let positive = |count: &usize| *count > 0;
    ffi::sysctl_usize(b"hw.perflevel0.physicalcpu\0")
        .filter(positive)
        .or_else(|| ffi::sysctl_usize(b"hw.physicalcpu\0").filter(positive))
}

/// Non-macOS stub: performance-core detection is unavailable.
#[cfg(not(target_os = "macos"))]
fn apple_perf_cores() -> Option<usize> {
    None
}
313
/// Resolves the embedding runtime preference from `CODELENS_EMBED_PROVIDER`.
///
/// Explicit `cpu` always wins; `coreml` is honored only on macOS (elsewhere it
/// degrades to `cpu`); with no usable request, macOS defaults to
/// `coreml_preferred` and every other platform to `cpu`.
pub fn configured_embedding_runtime_preference() -> String {
    let on_macos = cfg!(target_os = "macos");
    let requested = std::env::var("CODELENS_EMBED_PROVIDER")
        .ok()
        .map(|value| value.trim().to_ascii_lowercase());

    let preference = match requested.as_deref() {
        Some("cpu") => "cpu",
        Some("coreml") => {
            if on_macos { "coreml" } else { "cpu" }
        }
        _ => {
            if on_macos { "coreml_preferred" } else { "cpu" }
        }
    };
    preference.to_string()
}
327
/// Public view of the intra-op thread count the embedding runtime will use.
pub fn configured_embedding_threads() -> usize {
    recommended_embed_threads()
}
331
332fn configured_embedding_max_length() -> usize {
333 parse_usize_env("CODELENS_EMBED_MAX_LENGTH")
334 .unwrap_or(256)
335 .clamp(32, 512)
336}
337
338fn configured_embedding_text_cache_size() -> usize {
339 std::env::var("CODELENS_EMBED_TEXT_CACHE_SIZE")
340 .ok()
341 .and_then(|value| value.trim().parse::<usize>().ok())
342 .unwrap_or({
343 if cfg!(target_os = "macos") {
344 DEFAULT_MACOS_TEXT_EMBED_CACHE_SIZE
345 } else {
346 DEFAULT_TEXT_EMBED_CACHE_SIZE
347 }
348 })
349 .min(8192)
350}
351
/// CoreML compute-unit selection from `CODELENS_EMBED_COREML_COMPUTE_UNITS`;
/// unknown or unset values default to `cpu_and_neural_engine`.
#[cfg(target_os = "macos")]
fn configured_coreml_compute_units_name() -> String {
    let requested = std::env::var("CODELENS_EMBED_COREML_COMPUTE_UNITS")
        .ok()
        .map(|raw| raw.trim().to_ascii_lowercase());
    let resolved = match requested.as_deref() {
        Some("all") => "all",
        Some("cpu") | Some("cpu_only") => "cpu_only",
        Some("gpu") | Some("cpu_and_gpu") => "cpu_and_gpu",
        // "ane" / "neural_engine" / "cpu_and_neural_engine" and everything
        // else all resolve to the ANE-enabled default.
        _ => "cpu_and_neural_engine",
    };
    resolved.to_string()
}
368
/// CoreML model format from `CODELENS_EMBED_COREML_MODEL_FORMAT`;
/// defaults to `mlprogram` unless a neural-network format is requested.
#[cfg(target_os = "macos")]
fn configured_coreml_model_format_name() -> String {
    let requested = std::env::var("CODELENS_EMBED_COREML_MODEL_FORMAT")
        .ok()
        .map(|raw| raw.trim().to_ascii_lowercase());
    match requested.as_deref() {
        Some("neuralnetwork") | Some("neural_network") => "neural_network".to_string(),
        _ => "mlprogram".to_string(),
    }
}
380
/// Whether CoreML compute-plan profiling is enabled (off unless explicitly
/// turned on via `CODELENS_EMBED_COREML_PROFILE_PLAN`).
#[cfg(target_os = "macos")]
fn configured_coreml_profile_compute_plan() -> bool {
    match parse_bool_env("CODELENS_EMBED_COREML_PROFILE_PLAN") {
        Some(explicit) => explicit,
        None => false,
    }
}
385
/// Whether CoreML should assume static input shapes (on by default; override
/// via `CODELENS_EMBED_COREML_STATIC_INPUT_SHAPES`).
#[cfg(target_os = "macos")]
fn configured_coreml_static_input_shapes() -> bool {
    match parse_bool_env("CODELENS_EMBED_COREML_STATIC_INPUT_SHAPES") {
        Some(explicit) => explicit,
        None => true,
    }
}
390
/// CoreML specialization strategy from `CODELENS_EMBED_COREML_SPECIALIZATION`;
/// anything other than an explicit `default` selects `fast_prediction`.
#[cfg(target_os = "macos")]
fn configured_coreml_specialization_strategy_name() -> String {
    let requested = std::env::var("CODELENS_EMBED_COREML_SPECIALIZATION")
        .ok()
        .map(|raw| raw.trim().to_ascii_lowercase());
    if requested.as_deref() == Some("default") {
        "default".to_string()
    } else {
        "fast_prediction".to_string()
    }
}
402
/// On-disk cache directory for compiled CoreML models:
/// `<home-or-tmp>/.cache/codelens/coreml-cache/codesearch`.
#[cfg(target_os = "macos")]
fn configured_coreml_model_cache_dir() -> std::path::PathBuf {
    let mut dir = dirs_fallback().unwrap_or_else(std::env::temp_dir);
    for part in [".cache", "codelens", "coreml-cache", "codesearch"] {
        dir.push(part);
    }
    dir
}
412
413fn recommended_embed_threads() -> usize {
414 if let Some(explicit) = parse_usize_env("CODELENS_EMBED_THREADS") {
415 return explicit.max(1);
416 }
417
418 let available = available_parallelism().map(|n| n.get()).unwrap_or(1);
419 if cfg!(target_os = "macos") {
420 apple_perf_cores()
421 .unwrap_or(available)
422 .min(available)
423 .clamp(1, 8)
424 } else {
425 available.div_ceil(2).clamp(1, 8)
426 }
427}
428
429fn embed_batch_size() -> usize {
430 parse_usize_env("CODELENS_EMBED_BATCH_SIZE").unwrap_or({
431 if cfg!(target_os = "macos") {
432 DEFAULT_MACOS_EMBED_BATCH_SIZE
433 } else {
434 DEFAULT_EMBED_BATCH_SIZE
435 }
436 })
437}
438
439fn max_embed_symbols() -> usize {
440 parse_usize_env("CODELENS_MAX_EMBED_SYMBOLS").unwrap_or(DEFAULT_MAX_EMBED_SYMBOLS)
441}
442
/// Sets `name` in the process environment only when the user has not already
/// set it, so explicit configuration always wins over our defaults.
fn set_env_if_unset(name: &str, value: impl Into<String>) {
    if std::env::var_os(name).is_some() {
        return;
    }
    // SAFETY: process-environment mutation; done during single-threaded
    // runtime configuration before worker threads read these variables.
    unsafe {
        std::env::set_var(name, value.into());
    }
}
452
/// Applies thread and environment tuning for the embedding runtime; must run
/// before the model is loaded for the ORT settings to take effect.
fn configure_embedding_runtime() {
    let threads = recommended_embed_threads();
    let runtime_preference = configured_embedding_runtime_preference();

    // Only fill in defaults — explicit user configuration is never overridden.
    set_env_if_unset("OMP_NUM_THREADS", threads.to_string());
    set_env_if_unset("OMP_WAIT_POLICY", "PASSIVE");
    set_env_if_unset("OMP_DYNAMIC", "FALSE");
    set_env_if_unset("TOKENIZERS_PARALLELISM", "false");
    if cfg!(target_os = "macos") {
        set_env_if_unset("VECLIB_MAXIMUM_THREADS", threads.to_string());
    }

    // The global ONNX Runtime environment can only be committed once per
    // process, hence the `Once` guard.
    ORT_ENV_INIT.call_once(|| {
        let pool = ort::environment::GlobalThreadPoolOptions::default()
            .with_intra_threads(threads)
            .and_then(|pool| pool.with_inter_threads(1))
            .and_then(|pool| pool.with_spin_control(false));

        // Best-effort: if pool configuration or commit fails, ORT falls back
        // to its defaults, so errors are deliberately ignored.
        if let Ok(pool) = pool {
            let _ = ort::init()
                .with_name("codelens-embedding")
                .with_telemetry(false)
                .with_global_thread_pool(pool)
                .commit();
        }
    });

    debug!(
        threads,
        runtime_preference = %runtime_preference,
        "configured embedding runtime"
    );
}
488
/// Returns the model id requested via `CODELENS_EMBED_MODEL`, if it actually
/// differs from the bundled default.
///
/// # Errors
/// Errors when an override is requested but the binary lacks the
/// `model-bakeoff` feature needed to load alternative models.
fn requested_embedding_model_override() -> Result<Option<String>> {
    let env_model = std::env::var("CODELENS_EMBED_MODEL").ok();
    let Some(model_id) = env_model else {
        return Ok(None);
    };
    // Empty or default-model values count as "no override".
    if model_id.is_empty() || model_id == CODESEARCH_MODEL_NAME {
        return Ok(None);
    }

    #[cfg(feature = "model-bakeoff")]
    {
        return Ok(Some(model_id));
    }

    #[cfg(not(feature = "model-bakeoff"))]
    {
        anyhow::bail!(
            "CODELENS_EMBED_MODEL={model_id} requires the `model-bakeoff` feature; \
             rebuild the binary with `--features model-bakeoff` to run alternative model bake-offs"
        );
    }
}
511
/// Describes the embedding runtime that *would* be used, without loading a
/// model — `backend` is therefore reported as `"not_loaded"`.
pub fn configured_embedding_runtime_info() -> EmbeddingRuntimeInfo {
    let runtime_preference = configured_embedding_runtime_preference();
    let threads = configured_embedding_threads();

    #[cfg(target_os = "macos")]
    {
        // CoreML fields are only meaningful when CoreML may actually be used.
        let coreml_enabled = runtime_preference != "cpu";
        EmbeddingRuntimeInfo {
            runtime_preference,
            backend: "not_loaded".to_string(),
            threads,
            max_length: configured_embedding_max_length(),
            coreml_model_format: coreml_enabled.then(configured_coreml_model_format_name),
            coreml_compute_units: coreml_enabled.then(configured_coreml_compute_units_name),
            coreml_static_input_shapes: coreml_enabled.then(configured_coreml_static_input_shapes),
            coreml_profile_compute_plan: coreml_enabled
                .then(configured_coreml_profile_compute_plan),
            coreml_specialization_strategy: coreml_enabled
                .then(configured_coreml_specialization_strategy_name),
            coreml_model_cache_dir: coreml_enabled
                .then(|| configured_coreml_model_cache_dir().display().to_string()),
            fallback_reason: None,
        }
    }

    #[cfg(not(target_os = "macos"))]
    {
        EmbeddingRuntimeInfo {
            runtime_preference,
            backend: "not_loaded".to_string(),
            threads,
            max_length: configured_embedding_max_length(),
            coreml_model_format: None,
            coreml_compute_units: None,
            coreml_static_input_shapes: None,
            coreml_profile_compute_plan: None,
            coreml_specialization_strategy: None,
            coreml_model_cache_dir: None,
            fallback_reason: None,
        }
    }
}
554
/// Builds the CoreML execution provider from the `CODELENS_EMBED_COREML_*`
/// environment configuration.
///
/// `error_on_failure()` makes provider registration failures surface as model
/// load errors, which the caller handles via CPU fallback.
#[cfg(target_os = "macos")]
fn build_coreml_execution_provider() -> ExecutionProviderDispatch {
    use ort::ep::{
        CoreML,
        coreml::{ComputeUnits, ModelFormat, SpecializationStrategy},
    };

    let compute_units = match configured_coreml_compute_units_name().as_str() {
        "all" => ComputeUnits::All,
        "cpu_only" => ComputeUnits::CPUOnly,
        "cpu_and_gpu" => ComputeUnits::CPUAndGPU,
        _ => ComputeUnits::CPUAndNeuralEngine,
    };
    let model_format = match configured_coreml_model_format_name().as_str() {
        "neural_network" => ModelFormat::NeuralNetwork,
        _ => ModelFormat::MLProgram,
    };
    let specialization = match configured_coreml_specialization_strategy_name().as_str() {
        "default" => SpecializationStrategy::Default,
        _ => SpecializationStrategy::FastPrediction,
    };
    let cache_dir = configured_coreml_model_cache_dir();
    // Best-effort: a failed create just leaves the cache dir absent —
    // presumably CoreML then skips on-disk model caching (TODO confirm).
    let _ = std::fs::create_dir_all(&cache_dir);

    CoreML::default()
        .with_model_format(model_format)
        .with_compute_units(compute_units)
        .with_static_input_shapes(configured_coreml_static_input_shapes())
        .with_specialization_strategy(specialization)
        .with_profile_compute_plan(configured_coreml_profile_compute_plan())
        .with_model_cache_dir(cache_dir.display().to_string())
        .build()
        .error_on_failure()
}
589
590fn cpu_runtime_info(
591 runtime_preference: String,
592 fallback_reason: Option<String>,
593) -> EmbeddingRuntimeInfo {
594 EmbeddingRuntimeInfo {
595 runtime_preference,
596 backend: "cpu".to_string(),
597 threads: configured_embedding_threads(),
598 max_length: configured_embedding_max_length(),
599 coreml_model_format: None,
600 coreml_compute_units: None,
601 coreml_static_input_shapes: None,
602 coreml_profile_compute_plan: None,
603 coreml_specialization_strategy: None,
604 coreml_model_cache_dir: None,
605 fallback_reason,
606 }
607}
608
/// Runtime metadata for a CoreML-configured load.
///
/// A present `fallback_reason` means the CoreML load failed and the CPU
/// backend is actually in use; the CoreML settings are still reported.
#[cfg(target_os = "macos")]
fn coreml_runtime_info(
    runtime_preference: String,
    fallback_reason: Option<String>,
) -> EmbeddingRuntimeInfo {
    let backend = if fallback_reason.is_none() {
        "coreml"
    } else {
        "cpu"
    };
    EmbeddingRuntimeInfo {
        runtime_preference,
        backend: backend.to_string(),
        threads: configured_embedding_threads(),
        max_length: configured_embedding_max_length(),
        coreml_model_format: Some(configured_coreml_model_format_name()),
        coreml_compute_units: Some(configured_coreml_compute_units_name()),
        coreml_static_input_shapes: Some(configured_coreml_static_input_shapes()),
        coreml_profile_compute_plan: Some(configured_coreml_profile_compute_plan()),
        coreml_specialization_strategy: Some(configured_coreml_specialization_strategy_name()),
        coreml_model_cache_dir: Some(configured_coreml_model_cache_dir().display().to_string()),
        fallback_reason,
    }
}
632
/// Loads one of fastembed's built-in models for A/B ("bake-off") comparison
/// against the bundled CodeSearchNet model. Downloads model files into a
/// temp-dir cache on first use; always runs on the CPU backend.
///
/// Returns `(model, embedding dimension, model name, runtime info)`.
///
/// # Errors
/// Fails for an unrecognized `model_id` or when the model cannot be loaded.
#[cfg(feature = "model-bakeoff")]
fn load_fastembed_builtin(
    model_id: &str,
) -> Result<(TextEmbedding, usize, String, EmbeddingRuntimeInfo)> {
    use fastembed::EmbeddingModel;

    // Accept both short names and HuggingFace-style ids.
    let (model_enum, expected_dim) = match model_id {
        "all-MiniLM-L6-v2" | "sentence-transformers/all-MiniLM-L6-v2" => {
            (EmbeddingModel::AllMiniLML6V2, 384)
        }
        "all-MiniLM-L12-v2" | "sentence-transformers/all-MiniLM-L12-v2" => {
            (EmbeddingModel::AllMiniLML12V2, 384)
        }
        "bge-small-en-v1.5" | "BAAI/bge-small-en-v1.5" => {
            (EmbeddingModel::BGESmallENV15, 384)
        }
        "bge-base-en-v1.5" | "BAAI/bge-base-en-v1.5" => {
            (EmbeddingModel::BGEBaseENV15, 768)
        }
        "nomic-embed-text-v1.5" | "nomic-ai/nomic-embed-text-v1.5" => {
            (EmbeddingModel::NomicEmbedTextV15, 768)
        }
        other => {
            anyhow::bail!(
                "Unknown fastembed model: {other}. \
                 Supported: all-MiniLM-L6-v2, all-MiniLM-L12-v2, bge-small-en-v1.5, \
                 bge-base-en-v1.5, nomic-embed-text-v1.5"
            );
        }
    };

    let init = fastembed::InitOptionsWithLength::new(model_enum)
        .with_max_length(configured_embedding_max_length())
        .with_cache_dir(std::env::temp_dir().join("codelens-fastembed-cache"))
        .with_show_download_progress(true);
    let model =
        TextEmbedding::try_new(init).with_context(|| format!("failed to load {model_id}"))?;

    let runtime_info = cpu_runtime_info("cpu".to_string(), None);

    tracing::info!(
        model = model_id,
        dimension = expected_dim,
        "loaded fastembed built-in model for A/B comparison"
    );

    Ok((model, expected_dim, model_id.to_string(), runtime_info))
}
686
/// Loads the bundled CodeSearchNet embedding model from disk.
///
/// Configures the runtime first, honors a `CODELENS_EMBED_MODEL` override
/// (bake-off builds only), then reads the ONNX and tokenizer assets from
/// `resolve_model_dir`. On macOS with a non-CPU preference it attempts a
/// CoreML-backed load and falls back to CPU on failure, recording the reason.
///
/// Returns `(model, embedding dimension, model name, runtime info)`.
///
/// # Errors
/// Fails when model assets are missing/unreadable or the model cannot load.
fn load_codesearch_model() -> Result<(TextEmbedding, usize, String, EmbeddingRuntimeInfo)> {
    configure_embedding_runtime();

    if let Some(model_id) = requested_embedding_model_override()? {
        #[cfg(feature = "model-bakeoff")]
        {
            return load_fastembed_builtin(&model_id);
        }

        // Without the feature, `requested_embedding_model_override` already
        // bailed, so this arm cannot be reached.
        #[cfg(not(feature = "model-bakeoff"))]
        unreachable!("alternative embedding model override should have errored");
    }

    let model_dir = resolve_model_dir()?;

    let onnx_bytes =
        std::fs::read(model_dir.join("model.onnx")).context("failed to read model.onnx")?;
    let tokenizer_bytes =
        std::fs::read(model_dir.join("tokenizer.json")).context("failed to read tokenizer.json")?;
    let config_bytes =
        std::fs::read(model_dir.join("config.json")).context("failed to read config.json")?;
    let special_tokens_bytes = std::fs::read(model_dir.join("special_tokens_map.json"))
        .context("failed to read special_tokens_map.json")?;
    let tokenizer_config_bytes = std::fs::read(model_dir.join("tokenizer_config.json"))
        .context("failed to read tokenizer_config.json")?;

    let user_model = UserDefinedEmbeddingModel::new(
        onnx_bytes,
        TokenizerFiles {
            tokenizer_file: tokenizer_bytes,
            config_file: config_bytes,
            special_tokens_map_file: special_tokens_bytes,
            tokenizer_config_file: tokenizer_config_bytes,
        },
    );

    let runtime_preference = configured_embedding_runtime_preference();

    // CoreML-first path; any load error falls through to a CPU load with the
    // failure recorded as `fallback_reason`.
    #[cfg(target_os = "macos")]
    if runtime_preference != "cpu" {
        let init_opts = InitOptionsUserDefined::new()
            .with_max_length(configured_embedding_max_length())
            .with_execution_providers(vec![build_coreml_execution_provider()]);
        match TextEmbedding::try_new_from_user_defined(user_model.clone(), init_opts) {
            Ok(model) => {
                let runtime_info = coreml_runtime_info(runtime_preference.clone(), None);
                debug!(
                    threads = runtime_info.threads,
                    runtime_preference = %runtime_info.runtime_preference,
                    backend = %runtime_info.backend,
                    coreml_compute_units = ?runtime_info.coreml_compute_units,
                    coreml_static_input_shapes = ?runtime_info.coreml_static_input_shapes,
                    coreml_profile_compute_plan = ?runtime_info.coreml_profile_compute_plan,
                    coreml_specialization_strategy = ?runtime_info.coreml_specialization_strategy,
                    coreml_model_cache_dir = ?runtime_info.coreml_model_cache_dir,
                    "loaded CodeSearchNet embedding model"
                );
                return Ok((
                    model,
                    CODESEARCH_DIMENSION,
                    CODESEARCH_MODEL_NAME.to_string(),
                    runtime_info,
                ));
            }
            Err(err) => {
                let reason = err.to_string();
                debug!(
                    runtime_preference = %runtime_preference,
                    fallback_reason = %reason,
                    "CoreML embedding load failed; falling back to CPU"
                );
                let model = TextEmbedding::try_new_from_user_defined(
                    user_model,
                    InitOptionsUserDefined::new()
                        .with_max_length(configured_embedding_max_length()),
                )
                .context("failed to load CodeSearchNet embedding model")?;
                let runtime_info = coreml_runtime_info(runtime_preference.clone(), Some(reason));
                debug!(
                    threads = runtime_info.threads,
                    runtime_preference = %runtime_info.runtime_preference,
                    backend = %runtime_info.backend,
                    coreml_compute_units = ?runtime_info.coreml_compute_units,
                    coreml_static_input_shapes = ?runtime_info.coreml_static_input_shapes,
                    coreml_profile_compute_plan = ?runtime_info.coreml_profile_compute_plan,
                    coreml_specialization_strategy = ?runtime_info.coreml_specialization_strategy,
                    coreml_model_cache_dir = ?runtime_info.coreml_model_cache_dir,
                    fallback_reason = ?runtime_info.fallback_reason,
                    "loaded CodeSearchNet embedding model"
                );
                return Ok((
                    model,
                    CODESEARCH_DIMENSION,
                    CODESEARCH_MODEL_NAME.to_string(),
                    runtime_info,
                ));
            }
        }
    }

    // Plain CPU load (non-macOS, or an explicit "cpu" preference).
    let model = TextEmbedding::try_new_from_user_defined(
        user_model,
        InitOptionsUserDefined::new().with_max_length(configured_embedding_max_length()),
    )
    .context("failed to load CodeSearchNet embedding model")?;
    let runtime_info = cpu_runtime_info(runtime_preference.clone(), None);

    debug!(
        threads = runtime_info.threads,
        runtime_preference = %runtime_info.runtime_preference,
        backend = %runtime_info.backend,
        "loaded CodeSearchNet embedding model"
    );

    Ok((
        model,
        CODESEARCH_DIMENSION,
        CODESEARCH_MODEL_NAME.to_string(),
        runtime_info,
    ))
}
810
811pub fn configured_embedding_model_name() -> String {
812 std::env::var("CODELENS_EMBED_MODEL").unwrap_or_else(|_| CODESEARCH_MODEL_NAME.to_string())
813}
814
/// True when the bundled model files can be located on disk
/// (see `resolve_model_dir` for the search order).
pub fn embedding_model_assets_available() -> bool {
    resolve_model_dir().is_ok()
}
818
819impl EmbeddingEngine {
    /// Embeds `texts`, serving repeats from the LRU cache and sending only
    /// distinct cache misses to the model (each missing text embedded once).
    ///
    /// Returns one embedding per input, in input order.
    ///
    /// # Errors
    /// Fails when either internal lock is poisoned or the model call fails.
    fn embed_texts_cached(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        // One slot per input; filled from the cache now, from the model below.
        let mut resolved: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        // Distinct missing texts in first-seen order, plus every input
        // position each missing text must eventually fill.
        let mut missing_order: Vec<String> = Vec::new();
        let mut missing_positions: HashMap<String, Vec<usize>> = HashMap::new();

        {
            // Scope the cache lock so it is released before the model call.
            let mut cache = self
                .text_embed_cache
                .lock()
                .map_err(|_| anyhow::anyhow!("text embedding cache lock"))?;
            for (index, text) in texts.iter().enumerate() {
                if let Some(cached) = cache.get(text) {
                    resolved[index] = Some(cached);
                } else {
                    let key = (*text).to_owned();
                    if !missing_positions.contains_key(&key) {
                        missing_order.push(key.clone());
                    }
                    missing_positions.entry(key).or_default().push(index);
                }
            }
        }

        if !missing_order.is_empty() {
            let missing_refs: Vec<&str> = missing_order.iter().map(String::as_str).collect();
            let embeddings = self
                .model
                .lock()
                .map_err(|_| anyhow::anyhow!("model lock"))?
                .embed(missing_refs, None)
                .context("text embedding failed")?;

            // Re-acquire the cache lock to store each fresh embedding and fan
            // it out to every position that requested that text.
            let mut cache = self
                .text_embed_cache
                .lock()
                .map_err(|_| anyhow::anyhow!("text embedding cache lock"))?;
            for (text, embedding) in missing_order.into_iter().zip(embeddings.into_iter()) {
                cache.insert(text.clone(), embedding.clone());
                if let Some(indices) = missing_positions.remove(&text) {
                    for index in indices {
                        resolved[index] = Some(embedding.clone());
                    }
                }
            }
        }

        // Every slot must be filled by now; a gap indicates a logic error
        // (e.g. the model returned fewer embeddings than requested).
        resolved
            .into_iter()
            .map(|item| item.ok_or_else(|| anyhow::anyhow!("missing embedding cache entry")))
            .collect()
    }
875
    /// Loads the embedding model and opens (creating if needed) the on-disk
    /// vector store at `<project>/.codelens/index/embeddings.db`.
    ///
    /// # Errors
    /// Fails when model assets are missing, the index directory cannot be
    /// created, or the store cannot be opened.
    pub fn new(project: &ProjectRoot) -> Result<Self> {
        let (model, dimension, model_name, runtime_info) = load_codesearch_model()?;

        let db_dir = project.as_path().join(".codelens/index");
        std::fs::create_dir_all(&db_dir)?;
        let db_path = db_dir.join("embeddings.db");

        let store = SqliteVecStore::new(&db_path, dimension, &model_name)?;

        Ok(Self {
            model: Mutex::new(model),
            store: Box::new(store),
            model_name,
            runtime_info,
            text_embed_cache: Mutex::new(TextEmbeddingCache::new(
                configured_embedding_text_cache_size(),
            )),
            indexing: std::sync::atomic::AtomicBool::new(false),
        })
    }
896
    /// Name of the embedding model currently loaded.
    pub fn model_name(&self) -> &str {
        &self.model_name
    }
900
    /// Runtime metadata captured when the model was loaded.
    pub fn runtime_info(&self) -> &EmbeddingRuntimeInfo {
        &self.runtime_info
    }
904
    /// Whether an embedding indexing run is currently in progress.
    /// Relaxed load: callers only need a best-effort snapshot of the flag.
    pub fn is_indexing(&self) -> bool {
        self.indexing.load(std::sync::atomic::Ordering::Relaxed)
    }
914
    /// Rebuilds the embedding index from the project's symbol database,
    /// reusing stored embeddings for symbols whose reuse key is unchanged.
    ///
    /// Returns the number of chunks written. At most one run may be active at
    /// a time; a second concurrent call errors immediately.
    ///
    /// # Errors
    /// Fails when a run is already in progress, or on database/store errors.
    pub fn index_from_project(&self, project: &ProjectRoot) -> Result<usize> {
        // Atomically claim the indexing flag; losing the race means another
        // run is already active.
        if self
            .indexing
            .compare_exchange(
                false,
                true,
                std::sync::atomic::Ordering::AcqRel,
                std::sync::atomic::Ordering::Relaxed,
            )
            .is_err()
        {
            anyhow::bail!(
                "Embedding indexing already in progress — wait for the current run to complete before retrying."
            );
        }
        // Drop guard clears the flag on every exit path, including `?` errors.
        struct IndexGuard<'a>(&'a std::sync::atomic::AtomicBool);
        impl Drop for IndexGuard<'_> {
            fn drop(&mut self) {
                self.0.store(false, std::sync::atomic::Ordering::Release);
            }
        }
        let _guard = IndexGuard(&self.indexing);

        let db_path = crate::db::index_db_path(project.as_path());
        let symbol_db = IndexDb::open(&db_path)?;
        let batch_size = embed_batch_size();
        let max_symbols = max_embed_symbols();
        let mut total_indexed = 0usize;
        let mut total_seen = 0usize;
        // The model lock is taken lazily — only if something must be embedded.
        let mut model = None;
        // file path -> (reuse key -> stored chunk) for all stored embeddings.
        let mut existing_embeddings: HashMap<
            String,
            HashMap<ReusableEmbeddingKey, EmbeddingChunk>,
        > = HashMap::new();
        let mut current_db_files = HashSet::new();
        let mut capped = false;

        // Snapshot everything currently in the store, keyed for reuse.
        self.store
            .for_each_file_embeddings(&mut |file_path, chunks| {
                existing_embeddings.insert(
                    file_path,
                    chunks
                        .into_iter()
                        .map(|chunk| (reusable_embedding_key_for_chunk(&chunk), chunk))
                        .collect(),
                );
                Ok(())
            })?;

        symbol_db.for_each_file_symbols_with_bytes(|file_path, symbols| {
            current_db_files.insert(file_path.clone());
            // After hitting the cap, keep collecting file names (needed for
            // the removed-file sweep below) but stop embedding.
            if capped {
                return Ok(());
            }

            let source = std::fs::read_to_string(project.as_path().join(&file_path)).ok();
            let relevant_symbols: Vec<_> = symbols
                .into_iter()
                .filter(|sym| !is_test_only_symbol(sym, source.as_deref()))
                .collect();

            // A file with no embeddable symbols loses its stored embeddings.
            if relevant_symbols.is_empty() {
                self.store.delete_by_file(&[file_path.as_str()])?;
                existing_embeddings.remove(&file_path);
                return Ok(());
            }

            if total_seen + relevant_symbols.len() > max_symbols {
                capped = true;
                return Ok(());
            }
            total_seen += relevant_symbols.len();

            let existing_for_file = existing_embeddings.remove(&file_path).unwrap_or_default();
            total_indexed += self.reconcile_file_embeddings(
                &file_path,
                relevant_symbols,
                source.as_deref(),
                existing_for_file,
                batch_size,
                &mut model,
            )?;
            Ok(())
        })?;

        // Embeddings whose file no longer exists in the symbol DB belong to
        // deleted files — purge them from the store.
        let removed_files: Vec<String> = existing_embeddings
            .into_keys()
            .filter(|file_path| !current_db_files.contains(file_path))
            .collect();
        if !removed_files.is_empty() {
            let removed_refs: Vec<&str> = removed_files.iter().map(String::as_str).collect();
            self.store.delete_by_file(&removed_refs)?;
        }

        Ok(total_indexed)
    }
1013
1014 fn reconcile_file_embeddings<'a>(
1015 &'a self,
1016 file_path: &str,
1017 symbols: Vec<crate::db::SymbolWithFile>,
1018 source: Option<&str>,
1019 mut existing_embeddings: HashMap<ReusableEmbeddingKey, EmbeddingChunk>,
1020 batch_size: usize,
1021 model: &mut Option<std::sync::MutexGuard<'a, TextEmbedding>>,
1022 ) -> Result<usize> {
1023 let mut reconciled_chunks = Vec::with_capacity(symbols.len());
1024 let mut batch_texts: Vec<String> = Vec::with_capacity(batch_size);
1025 let mut batch_meta: Vec<crate::db::SymbolWithFile> = Vec::with_capacity(batch_size);
1026
1027 for sym in symbols {
1028 let text = build_embedding_text(&sym, source);
1029 if let Some(existing) =
1030 existing_embeddings.remove(&reusable_embedding_key_for_symbol(&sym, &text))
1031 {
1032 reconciled_chunks.push(EmbeddingChunk {
1033 file_path: sym.file_path.clone(),
1034 symbol_name: sym.name.clone(),
1035 kind: sym.kind.clone(),
1036 line: sym.line as usize,
1037 signature: sym.signature.clone(),
1038 name_path: sym.name_path.clone(),
1039 text,
1040 embedding: existing.embedding,
1041 doc_embedding: existing.doc_embedding,
1042 });
1043 continue;
1044 }
1045
1046 batch_texts.push(text);
1047 batch_meta.push(sym);
1048
1049 if batch_texts.len() >= batch_size {
1050 if model.is_none() {
1051 *model = Some(
1052 self.model
1053 .lock()
1054 .map_err(|_| anyhow::anyhow!("model lock"))?,
1055 );
1056 }
1057 reconciled_chunks.extend(Self::embed_chunks(
1058 model.as_mut().expect("model lock initialized"),
1059 &batch_texts,
1060 &batch_meta,
1061 )?);
1062 batch_texts.clear();
1063 batch_meta.clear();
1064 }
1065 }
1066
1067 if !batch_texts.is_empty() {
1068 if model.is_none() {
1069 *model = Some(
1070 self.model
1071 .lock()
1072 .map_err(|_| anyhow::anyhow!("model lock"))?,
1073 );
1074 }
1075 reconciled_chunks.extend(Self::embed_chunks(
1076 model.as_mut().expect("model lock initialized"),
1077 &batch_texts,
1078 &batch_meta,
1079 )?);
1080 }
1081
1082 self.store.delete_by_file(&[file_path])?;
1083 if reconciled_chunks.is_empty() {
1084 return Ok(0);
1085 }
1086 self.store.insert(&reconciled_chunks)
1087 }
1088
1089 fn embed_chunks(
1090 model: &mut TextEmbedding,
1091 texts: &[String],
1092 meta: &[crate::db::SymbolWithFile],
1093 ) -> Result<Vec<EmbeddingChunk>> {
1094 let batch_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
1095 let embeddings = model.embed(batch_refs, None).context("embedding failed")?;
1096
1097 Ok(meta
1098 .iter()
1099 .zip(embeddings)
1100 .zip(texts.iter())
1101 .map(|((sym, emb), text)| EmbeddingChunk {
1102 file_path: sym.file_path.clone(),
1103 symbol_name: sym.name.clone(),
1104 kind: sym.kind.clone(),
1105 line: sym.line as usize,
1106 signature: sym.signature.clone(),
1107 name_path: sym.name_path.clone(),
1108 text: text.clone(),
1109 embedding: emb,
1110 doc_embedding: None,
1111 })
1112 .collect())
1113 }
1114
1115 fn flush_batch(
1117 model: &mut TextEmbedding,
1118 store: &dyn EmbeddingStore,
1119 texts: &[String],
1120 meta: &[crate::db::SymbolWithFile],
1121 ) -> Result<usize> {
1122 let chunks = Self::embed_chunks(model, texts, meta)?;
1123 store.insert(&chunks)
1124 }
1125
1126 pub fn search(&self, query: &str, max_results: usize) -> Result<Vec<SemanticMatch>> {
1128 let results = self.search_scored(query, max_results)?;
1129 Ok(results.into_iter().map(SemanticMatch::from).collect())
1130 }
1131
1132 pub fn search_scored(&self, query: &str, max_results: usize) -> Result<Vec<ScoredChunk>> {
1139 let query_embedding = self.embed_texts_cached(&[query])?;
1140
1141 if query_embedding.is_empty() {
1142 return Ok(Vec::new());
1143 }
1144
1145 let candidate_count = max_results.saturating_mul(3).max(max_results);
1147 let mut candidates = self.store.search(&query_embedding[0], candidate_count)?;
1148
1149 if candidates.len() <= max_results {
1150 return Ok(candidates);
1151 }
1152
1153 let query_lower = query.to_lowercase();
1156 let query_tokens: Vec<&str> = query_lower
1157 .split(|c: char| c.is_whitespace() || c == '_' || c == '-')
1158 .filter(|t| t.len() >= 2)
1159 .collect();
1160
1161 if query_tokens.is_empty() {
1162 candidates.truncate(max_results);
1163 return Ok(candidates);
1164 }
1165
1166 for chunk in &mut candidates {
1167 let searchable = format!(
1169 "{} {} {}",
1170 chunk.symbol_name.to_lowercase(),
1171 chunk.signature.to_lowercase(),
1172 chunk.file_path.to_lowercase(),
1173 );
1174 let overlap = query_tokens
1175 .iter()
1176 .filter(|t| searchable.contains(**t))
1177 .count() as f64;
1178 let overlap_ratio = overlap / query_tokens.len().max(1) as f64;
1179 chunk.score = chunk.score * 0.8 + overlap_ratio * 0.2;
1181 }
1182
1183 candidates.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
1184 candidates.truncate(max_results);
1185 Ok(candidates)
1186 }
1187
    /// Incrementally re-index `changed_files`: reuse stored embeddings whose
    /// embedding text is unchanged, re-embed the rest, and replace all rows
    /// for those files in the store. Returns the number of chunks written.
    ///
    /// The embedding model's mutex is only locked once a batch actually
    /// needs the model; the guard is then held for the rest of the walk.
    pub fn index_changed_files(
        &self,
        project: &ProjectRoot,
        changed_files: &[&str],
    ) -> Result<usize> {
        if changed_files.is_empty() {
            return Ok(0);
        }
        let batch_size = embed_batch_size();
        // Snapshot current embeddings for the changed files so symbols whose
        // text is unchanged can be re-inserted without re-running the model.
        let mut existing_embeddings: HashMap<ReusableEmbeddingKey, EmbeddingChunk> = HashMap::new();
        for file_chunk in changed_files.chunks(CHANGED_FILE_QUERY_CHUNK) {
            for chunk in self.store.embeddings_for_files(file_chunk)? {
                existing_embeddings.insert(reusable_embedding_key_for_chunk(&chunk), chunk);
            }
        }
        // Old rows are removed up front; reused + fresh chunks are re-inserted below.
        self.store.delete_by_file(changed_files)?;

        let db_path = crate::db::index_db_path(project.as_path());
        let symbol_db = IndexDb::open(&db_path)?;

        let mut total_indexed = 0usize;
        // Pending texts/metadata awaiting an embedding run (parallel vecs).
        let mut batch_texts: Vec<String> = Vec::with_capacity(batch_size);
        let mut batch_meta: Vec<crate::db::SymbolWithFile> = Vec::with_capacity(batch_size);
        // Chunks reassembled from reusable embeddings, inserted in batches.
        let mut batch_reused: Vec<EmbeddingChunk> = Vec::with_capacity(batch_size);
        // Each file's source is read at most once; `None` caches read failures.
        let mut file_cache: std::collections::HashMap<String, Option<String>> =
            std::collections::HashMap::new();
        // Lazily-acquired model lock guard (only taken when embedding is needed).
        let mut model = None;

        for file_chunk in changed_files.chunks(CHANGED_FILE_QUERY_CHUNK) {
            let relevant = symbol_db.symbols_for_files(file_chunk)?;
            for sym in relevant {
                let source = file_cache.entry(sym.file_path.clone()).or_insert_with(|| {
                    std::fs::read_to_string(project.as_path().join(&sym.file_path)).ok()
                });
                // Test-only symbols are excluded from the semantic index.
                if is_test_only_symbol(&sym, source.as_deref()) {
                    continue;
                }
                let text = build_embedding_text(&sym, source.as_deref());
                // Same embedding text as before: reuse the stored vectors.
                if let Some(existing) =
                    existing_embeddings.remove(&reusable_embedding_key_for_symbol(&sym, &text))
                {
                    batch_reused.push(EmbeddingChunk {
                        file_path: sym.file_path.clone(),
                        symbol_name: sym.name.clone(),
                        kind: sym.kind.clone(),
                        line: sym.line as usize,
                        signature: sym.signature.clone(),
                        name_path: sym.name_path.clone(),
                        text,
                        embedding: existing.embedding,
                        doc_embedding: existing.doc_embedding,
                    });
                    if batch_reused.len() >= batch_size {
                        total_indexed += self.store.insert(&batch_reused)?;
                        batch_reused.clear();
                    }
                    continue;
                }
                batch_texts.push(text);
                batch_meta.push(sym);

                if batch_texts.len() >= batch_size {
                    // First batch that needs the model: take the lock now.
                    if model.is_none() {
                        model = Some(
                            self.model
                                .lock()
                                .map_err(|_| anyhow::anyhow!("model lock"))?,
                        );
                    }
                    total_indexed += Self::flush_batch(
                        model.as_mut().expect("model lock initialized"),
                        &*self.store,
                        &batch_texts,
                        &batch_meta,
                    )?;
                    batch_texts.clear();
                    batch_meta.clear();
                }
            }
        }

        // Flush whatever is left of both partial batches.
        if !batch_reused.is_empty() {
            total_indexed += self.store.insert(&batch_reused)?;
        }

        if !batch_texts.is_empty() {
            if model.is_none() {
                model = Some(
                    self.model
                        .lock()
                        .map_err(|_| anyhow::anyhow!("model lock"))?,
                );
            }
            total_indexed += Self::flush_batch(
                model.as_mut().expect("model lock initialized"),
                &*self.store,
                &batch_texts,
                &batch_meta,
            )?;
        }

        Ok(total_indexed)
    }
1292
1293 pub fn is_indexed(&self) -> bool {
1295 self.store.count().unwrap_or(0) > 0
1296 }
1297
1298 pub fn index_info(&self) -> EmbeddingIndexInfo {
1299 EmbeddingIndexInfo {
1300 model_name: self.model_name.clone(),
1301 indexed_symbols: self.store.count().unwrap_or(0),
1302 }
1303 }
1304
    /// Read the model name and symbol count from an existing on-disk
    /// embedding index without constructing an engine.
    ///
    /// Returns `Ok(None)` when no index database exists or it records no
    /// model name; a missing/failed `COUNT(*)` degrades to 0 rather than
    /// erroring.
    pub fn inspect_existing_index(project: &ProjectRoot) -> Result<Option<EmbeddingIndexInfo>> {
        let db_path = project.as_path().join(".codelens/index/embeddings.db");
        if !db_path.exists() {
            return Ok(None);
        }

        // Open through the corruption-recovery helper; the closure is the
        // plain "open" strategy that the helper can retry after recovery.
        let conn =
            crate::db::open_derived_sqlite_with_recovery(&db_path, "embedding index", || {
                ffi::register_sqlite_vec()?;
                let conn = Connection::open(&db_path)?;
                conn.execute_batch("PRAGMA busy_timeout=5000;")?;
                // Reading schema_version forces SQLite to parse the header,
                // surfacing corruption before any real query runs.
                conn.query_row("PRAGMA schema_version", [], |_row| Ok(()))?;
                Ok(conn)
            })?;

        // `meta.model` is the fingerprint of the model that built the index.
        let model_name: Option<String> = conn
            .query_row(
                "SELECT value FROM meta WHERE key = 'model' LIMIT 1",
                [],
                |row| row.get(0),
            )
            .ok();
        let indexed_symbols: usize = conn
            .query_row("SELECT COUNT(*) FROM symbols", [], |row| {
                row.get::<_, i64>(0)
            })
            .map(|count| count.max(0) as usize)
            .unwrap_or(0);

        // An index without a recorded model name is treated as absent.
        Ok(model_name.map(|model_name| EmbeddingIndexInfo {
            model_name,
            indexed_symbols,
        }))
    }
1339
1340 pub fn find_similar_code(
1344 &self,
1345 file_path: &str,
1346 symbol_name: &str,
1347 max_results: usize,
1348 ) -> Result<Vec<SemanticMatch>> {
1349 let target = self
1350 .store
1351 .get_embedding(file_path, symbol_name)?
1352 .ok_or_else(|| anyhow::anyhow!("Symbol '{}' not found in index", symbol_name))?;
1353
1354 let oversample = max_results.saturating_add(8).max(1);
1355 let scored = self
1356 .store
1357 .search(&target.embedding, oversample)?
1358 .into_iter()
1359 .filter(|c| !(c.file_path == file_path && c.symbol_name == symbol_name))
1360 .take(max_results)
1361 .map(SemanticMatch::from)
1362 .collect();
1363 Ok(scored)
1364 }
1365
    /// Scan the whole index for near-duplicate symbol pairs (cosine
    /// similarity >= `threshold`), returning at most `max_pairs` pairs
    /// sorted most-similar first.
    ///
    /// Runs in store-driven batches to bound memory. For each batch:
    /// 1) vector-search candidates per chunk, 2) bulk-fetch any candidate
    /// embeddings not yet cached, 3) compare exactly with cosine similarity.
    /// `seen_pairs` canonicalizes (a,b)/(b,a) so each pair is reported once.
    pub fn find_duplicates(&self, threshold: f64, max_pairs: usize) -> Result<Vec<DuplicatePair>> {
        let mut pairs = Vec::new();
        let mut seen_pairs = HashSet::new();
        // Full embeddings fetched for candidates, shared across batches.
        let mut embedding_cache: HashMap<StoredChunkKey, Arc<EmbeddingChunk>> = HashMap::new();
        let candidate_limit = duplicate_candidate_limit(max_pairs);
        // The batch callback cannot early-exit the iteration, so `done`
        // turns later batches into no-ops once `max_pairs` is reached.
        let mut done = false;

        self.store
            .for_each_embedding_batch(DEFAULT_DUPLICATE_SCAN_BATCH_SIZE, &mut |batch| {
                if done {
                    return Ok(());
                }

                // Phase 1: collect candidate lists and note which candidate
                // embeddings still need fetching.
                let mut candidate_lists = Vec::with_capacity(batch.len());
                let mut missing_candidates = Vec::new();
                let mut missing_keys = HashSet::new();

                for chunk in &batch {
                    if pairs.len() >= max_pairs {
                        done = true;
                        break;
                    }

                    let filtered: Vec<ScoredChunk> = self
                        .store
                        .search(&chunk.embedding, candidate_limit)?
                        .into_iter()
                        .filter(|candidate| {
                            // Exclude the chunk's own row from its candidates.
                            !(chunk.file_path == candidate.file_path
                                && chunk.symbol_name == candidate.symbol_name
                                && chunk.line == candidate.line
                                && chunk.signature == candidate.signature
                                && chunk.name_path == candidate.name_path)
                        })
                        .collect();

                    for candidate in &filtered {
                        let cache_key = stored_chunk_key_for_score(candidate);
                        // `missing_keys` dedupes fetch requests within this batch.
                        if !embedding_cache.contains_key(&cache_key)
                            && missing_keys.insert(cache_key)
                        {
                            missing_candidates.push(candidate.clone());
                        }
                    }

                    candidate_lists.push(filtered);
                }

                // Phase 2: bulk-fetch the missing candidate embeddings.
                if !missing_candidates.is_empty() {
                    for candidate_chunk in self
                        .store
                        .embeddings_for_scored_chunks(&missing_candidates)?
                    {
                        embedding_cache
                            .entry(stored_chunk_key(&candidate_chunk))
                            .or_insert_with(|| Arc::new(candidate_chunk));
                    }
                }

                // Phase 3: exact cosine comparison against the threshold.
                for (chunk, candidates) in batch.iter().zip(candidate_lists.iter()) {
                    if pairs.len() >= max_pairs {
                        done = true;
                        break;
                    }

                    for candidate in candidates {
                        let pair_key = duplicate_pair_key(
                            &chunk.file_path,
                            &chunk.symbol_name,
                            &candidate.file_path,
                            &candidate.symbol_name,
                        );
                        if !seen_pairs.insert(pair_key) {
                            continue;
                        }

                        // Candidate may be absent if the bulk fetch missed it.
                        let Some(candidate_chunk) =
                            embedding_cache.get(&stored_chunk_key_for_score(candidate))
                        else {
                            continue;
                        };

                        let sim = cosine_similarity(&chunk.embedding, &candidate_chunk.embedding);
                        if sim < threshold {
                            continue;
                        }

                        pairs.push(DuplicatePair {
                            symbol_a: format!("{}:{}", chunk.file_path, chunk.symbol_name),
                            symbol_b: format!(
                                "{}:{}",
                                candidate_chunk.file_path, candidate_chunk.symbol_name
                            ),
                            file_a: chunk.file_path.clone(),
                            file_b: candidate_chunk.file_path.clone(),
                            line_a: chunk.line,
                            line_b: candidate_chunk.line,
                            similarity: sim,
                        });
                        if pairs.len() >= max_pairs {
                            done = true;
                            break;
                        }
                    }
                }
                Ok(())
            })?;

        // Most similar first.
        pairs.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        Ok(pairs)
    }
1483}
1484
/// How many nearest neighbours to request per chunk while scanning for
/// duplicates: four per requested pair, bounded to the range 32..=128.
fn duplicate_candidate_limit(max_pairs: usize) -> usize {
    let scaled = max_pairs.saturating_mul(4);
    scaled.clamp(32, 128)
}
1488
/// Order-independent identity for a duplicate pair: the two (file, symbol)
/// halves are sorted so (a, b) and (b, a) map to the same key.
fn duplicate_pair_key(
    file_a: &str,
    symbol_a: &str,
    file_b: &str,
    symbol_b: &str,
) -> ((String, String), (String, String)) {
    let mut halves = [
        (file_a.to_owned(), symbol_a.to_owned()),
        (file_b.to_owned(), symbol_b.to_owned()),
    ];
    halves.sort();
    let [first, second] = halves;
    (first, second)
}
1503
/// Identity key for a stored chunk: (file path, symbol name, line,
/// signature, name path). Equal keys refer to the same symbol row.
type StoredChunkKey = (String, String, usize, String, String);

/// Identity key for a chunk we hold the full embedding for.
fn stored_chunk_key(chunk: &EmbeddingChunk) -> StoredChunkKey {
    (
        chunk.file_path.clone(),
        chunk.symbol_name.clone(),
        chunk.line,
        chunk.signature.clone(),
        chunk.name_path.clone(),
    )
}
1515
/// Identity key for a search result; must produce the same tuple shape as
/// `stored_chunk_key` so scored results can be matched to cached embeddings.
fn stored_chunk_key_for_score(chunk: &ScoredChunk) -> StoredChunkKey {
    (
        chunk.file_path.clone(),
        chunk.symbol_name.clone(),
        chunk.line,
        chunk.signature.clone(),
        chunk.name_path.clone(),
    )
}
1525
1526impl EmbeddingEngine {
1527 pub fn classify_symbol(
1529 &self,
1530 file_path: &str,
1531 symbol_name: &str,
1532 categories: &[&str],
1533 ) -> Result<Vec<CategoryScore>> {
1534 let target = match self.store.get_embedding(file_path, symbol_name)? {
1535 Some(target) => target,
1536 None => self
1537 .store
1538 .all_with_embeddings()?
1539 .into_iter()
1540 .find(|c| c.file_path == file_path && c.symbol_name == symbol_name)
1541 .ok_or_else(|| anyhow::anyhow!("Symbol '{}' not found in index", symbol_name))?,
1542 };
1543
1544 let embeddings = self.embed_texts_cached(categories)?;
1545
1546 let mut scores: Vec<CategoryScore> = categories
1547 .iter()
1548 .zip(embeddings.iter())
1549 .map(|(cat, emb)| CategoryScore {
1550 category: cat.to_string(),
1551 score: cosine_similarity(&target.embedding, emb),
1552 })
1553 .collect();
1554
1555 scores.sort_by(|a, b| {
1556 b.score
1557 .partial_cmp(&a.score)
1558 .unwrap_or(std::cmp::Ordering::Equal)
1559 });
1560 Ok(scores)
1561 }
1562
1563 pub fn find_misplaced_code(&self, max_results: usize) -> Result<Vec<OutlierSymbol>> {
1565 let mut outliers = Vec::new();
1566
1567 self.store
1568 .for_each_file_embeddings(&mut |file_path, chunks| {
1569 if chunks.len() < 2 {
1570 return Ok(());
1571 }
1572
1573 for (idx, chunk) in chunks.iter().enumerate() {
1574 let mut sim_sum = 0.0;
1575 let mut count = 0;
1576 for (other_idx, other_chunk) in chunks.iter().enumerate() {
1577 if other_idx == idx {
1578 continue;
1579 }
1580 sim_sum += cosine_similarity(&chunk.embedding, &other_chunk.embedding);
1581 count += 1;
1582 }
1583 if count > 0 {
1584 let avg_sim = sim_sum / count as f64; outliers.push(OutlierSymbol {
1586 file_path: file_path.clone(),
1587 symbol_name: chunk.symbol_name.clone(),
1588 kind: chunk.kind.clone(),
1589 line: chunk.line,
1590 avg_similarity_to_file: avg_sim,
1591 });
1592 }
1593 }
1594 Ok(())
1595 })?;
1596
1597 outliers.sort_by(|a, b| {
1598 a.avg_similarity_to_file
1599 .partial_cmp(&b.avg_similarity_to_file)
1600 .unwrap_or(std::cmp::Ordering::Equal)
1601 });
1602 outliers.truncate(max_results);
1603 Ok(outliers)
1604 }
1605}
1606
/// A pair of symbols whose embeddings are similar enough to be reported as
/// duplicates by `EmbeddingEngine::find_duplicates`.
#[derive(Debug, Clone, Serialize)]
pub struct DuplicatePair {
    /// "file_path:symbol_name" label for the first symbol.
    pub symbol_a: String,
    /// "file_path:symbol_name" label for the second symbol.
    pub symbol_b: String,
    /// File containing the first symbol.
    pub file_a: String,
    /// File containing the second symbol.
    pub file_b: String,
    /// Line of the first symbol.
    pub line_a: usize,
    /// Line of the second symbol.
    pub line_b: usize,
    /// Cosine similarity between the two symbols' embeddings.
    pub similarity: f64,
}
1619
/// One category's similarity to a target symbol, produced by
/// `EmbeddingEngine::classify_symbol`.
#[derive(Debug, Clone, Serialize)]
pub struct CategoryScore {
    /// The category label that was scored.
    pub category: String,
    /// Cosine similarity between the category text and the symbol embedding.
    pub score: f64,
}
1625
/// A symbol flagged by `EmbeddingEngine::find_misplaced_code` as unusually
/// dissimilar to the rest of its file.
#[derive(Debug, Clone, Serialize)]
pub struct OutlierSymbol {
    /// File the symbol lives in.
    pub file_path: String,
    /// The symbol's name.
    pub symbol_name: String,
    /// Symbol kind (function, struct, ...).
    pub kind: String,
    /// Line of the symbol.
    pub line: usize,
    /// Mean cosine similarity to every other symbol in the same file.
    pub avg_similarity_to_file: f64,
}
1634
/// Cosine similarity of two equal-length vectors.
///
/// Fix: the previous version accumulated the dot product and norms in
/// `f32`, which loses precision as vector length grows; embedding vectors
/// here are long, so accumulation is now done in `f64` throughout.
///
/// Returns 0.0 when either vector has zero magnitude, so degenerate
/// embeddings never produce NaN. Length mismatch is a caller bug
/// (debug-asserted); in release the shorter length wins via `zip`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    debug_assert_eq!(a.len(), b.len());

    let (mut dot, mut norm_a, mut norm_b) = (0.0f64, 0.0f64, 0.0f64);
    for (&x, &y) in a.iter().zip(b.iter()) {
        let (x, y) = (x as f64, y as f64);
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    let norm_a = norm_a.sqrt();
    let norm_b = norm_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
1659
1660fn split_identifier(name: &str) -> String {
1675 if !name.contains('_') && !name.chars().any(|c| c.is_uppercase()) {
1677 return name.to_string();
1678 }
1679 let mut words = Vec::new();
1680 let mut current = String::new();
1681 let chars: Vec<char> = name.chars().collect();
1682 for (i, &ch) in chars.iter().enumerate() {
1683 if ch == '_' {
1684 if !current.is_empty() {
1685 words.push(current.clone());
1686 current.clear();
1687 }
1688 } else if ch.is_uppercase()
1689 && !current.is_empty()
1690 && (current
1691 .chars()
1692 .last()
1693 .map(|c| c.is_lowercase())
1694 .unwrap_or(false)
1695 || chars.get(i + 1).map(|c| c.is_lowercase()).unwrap_or(false))
1696 {
1697 words.push(current.clone());
1699 current.clear();
1700 current.push(ch);
1701 } else {
1702 current.push(ch);
1703 }
1704 }
1705 if !current.is_empty() {
1706 words.push(current);
1707 }
1708 if words.len() <= 1 {
1709 return name.to_string(); }
1711 words.join(" ")
1712}
1713
1714fn is_test_only_symbol(sym: &crate::db::SymbolWithFile, source: Option<&str>) -> bool {
1715 if sym.file_path.contains("/tests/") || sym.file_path.ends_with("_tests.rs") {
1716 return true;
1717 }
1718 if sym.name_path.starts_with("tests::")
1719 || sym.name_path.contains("::tests::")
1720 || sym.name_path.starts_with("test::")
1721 || sym.name_path.contains("::test::")
1722 {
1723 return true;
1724 }
1725
1726 let Some(source) = source else {
1727 return false;
1728 };
1729
1730 let start = usize::try_from(sym.start_byte.max(0))
1731 .unwrap_or(0)
1732 .min(source.len());
1733 let window_start = start.saturating_sub(2048);
1734 let attrs = String::from_utf8_lossy(&source.as_bytes()[window_start..start]);
1735 attrs.contains("#[test]")
1736 || attrs.contains("#[tokio::test]")
1737 || attrs.contains("#[cfg(test)]")
1738 || attrs.contains("#[cfg(all(test")
1739}
1740
/// Build the natural-language text that gets embedded for a symbol.
///
/// The text layers several signals: kind + name (with camel/snake-case
/// words split out), parent scope, module directory, filename, signature,
/// and — unless `CODELENS_EMBED_DOCSTRINGS` is "0"/"false" — a docstring
/// or body hint plus optional NL-token and API-call hints mined from the
/// symbol's source span.
fn build_embedding_text(sym: &crate::db::SymbolWithFile, source: Option<&str>) -> String {
    // " in <filename>" context from the last path component.
    let file_ctx = if sym.file_path.is_empty() {
        String::new()
    } else {
        let filename = sym.file_path.rsplit('/').next().unwrap_or(&sym.file_path);
        format!(" in {}", filename)
    };

    // "fooBar (foo Bar)" so the embedding also sees the words in the name.
    let split_name = split_identifier(&sym.name);
    let name_with_split = if split_name != sym.name {
        format!("{} ({})", sym.name, split_name)
    } else {
        sym.name.clone()
    };

    // " (in parent)" scope context from the slash-separated name path.
    let parent_ctx = if !sym.name_path.is_empty() && sym.name_path.contains('/') {
        let parent = sym.name_path.rsplit_once('/').map(|x| x.0).unwrap_or("");
        if parent.is_empty() {
            String::new()
        } else {
            format!(" (in {})", parent)
        }
    } else {
        String::new()
    };

    // " [dir]" context from the immediate parent directory, skipping the
    // uninformative `src` / `crates` names.
    let module_ctx = if sym.file_path.contains('/') {
        let parts: Vec<&str> = sym.file_path.rsplitn(3, '/').collect();
        if parts.len() >= 2 {
            let dir = parts[1];
            if dir != "src" && dir != "crates" {
                format!(" [{dir}]")
            } else {
                String::new()
            }
        } else {
            String::new()
        }
    } else {
        String::new()
    };

    // Core line: "<kind> <name><parent><module><file>[: <signature>]".
    let base = if sym.signature.is_empty() {
        format!(
            "{} {}{}{}{}", sym.kind, name_with_split, parent_ctx, module_ctx, file_ctx
        )
    } else {
        format!(
            "{} {}{}{}{}: {}",
            sym.kind, name_with_split, parent_ctx, module_ctx, file_ctx, sym.signature
        )
    };

    // Opt-out switch: docstring/body hints can be disabled entirely.
    let docstrings_disabled = std::env::var("CODELENS_EMBED_DOCSTRINGS")
        .map(|v| v == "0" || v == "false")
        .unwrap_or(false);

    if docstrings_disabled {
        return base;
    }

    let docstring = source
        .and_then(|src| extract_leading_doc(src, sym.start_byte as usize, sym.end_byte as usize))
        .unwrap_or_default();

    // Prefer the docstring; fall back to the first meaningful body lines.
    let mut text = if docstring.is_empty() {
        let body_hint = source
            .and_then(|src| extract_body_hint(src, sym.start_byte as usize, sym.end_byte as usize))
            .unwrap_or_default();
        if body_hint.is_empty() {
            base
        } else {
            format!("{} — {}", base, body_hint)
        }
    } else {
        // Keep only the first non-empty docstring lines, up to the budget.
        let line_budget = hint_line_budget();
        let lines: Vec<String> = docstring
            .lines()
            .map(str::trim)
            .filter(|line| !line.is_empty())
            .take(line_budget)
            .map(str::to_string)
            .collect();
        let hint = join_hint_lines(&lines);
        if hint.is_empty() {
            base
        } else {
            format!("{} — {}", base, hint)
        }
    };

    // Natural-language fragments (comments / string literals) from the body.
    if let Some(src) = source
        && let Some(nl_tokens) =
            extract_nl_tokens(src, sym.start_byte as usize, sym.end_byte as usize)
        && !nl_tokens.is_empty()
    {
        text.push_str(" · NL: ");
        text.push_str(&nl_tokens);
    }

    // Static/associated API calls (`Type::method`) used by the body.
    if let Some(src) = source
        && let Some(api_calls) =
            extract_api_calls(src, sym.start_byte as usize, sym.end_byte as usize)
        && !api_calls.is_empty()
    {
        text.push_str(" · API: ");
        text.push_str(&api_calls);
    }

    text
}
1879
/// Default character budget for hint text appended to an embedding input.
const DEFAULT_HINT_TOTAL_CHAR_BUDGET: usize = 60;

/// Default number of hint lines harvested per symbol.
const DEFAULT_HINT_LINES: usize = 1;

/// Character budget for hint text. Overridable via
/// `CODELENS_EMBED_HINT_CHARS`; parsed overrides are clamped to 60..=512.
fn hint_char_budget() -> usize {
    match std::env::var("CODELENS_EMBED_HINT_CHARS")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok())
    {
        Some(n) => n.clamp(60, 512),
        None => DEFAULT_HINT_TOTAL_CHAR_BUDGET,
    }
}

/// Line budget for hint text. Overridable via
/// `CODELENS_EMBED_HINT_LINES`; parsed overrides are clamped to 1..=10.
fn hint_line_budget() -> usize {
    match std::env::var("CODELENS_EMBED_HINT_LINES")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok())
    {
        Some(n) => n.clamp(1, 10),
        None => DEFAULT_HINT_LINES,
    }
}
1913
1914fn join_hint_lines(lines: &[String]) -> String {
1921 if lines.is_empty() {
1922 return String::new();
1923 }
1924 let joined = lines
1925 .iter()
1926 .map(String::as_str)
1927 .collect::<Vec<_>>()
1928 .join(" · ");
1929 let budget = hint_char_budget();
1930 if joined.chars().count() > budget {
1931 let truncated: String = joined.chars().take(budget).collect();
1932 format!("{truncated}...")
1933 } else {
1934 joined
1935 }
1936}
1937
/// Pull the first few meaningful body lines of a symbol as a hint, used
/// when the symbol has no docstring.
///
/// Byte offsets are clamped down to char boundaries before slicing. Lines
/// are skipped until the signature looks finished (a line ending in `{` or
/// `:`), then blanks, comment lines, and bare closing braces are filtered
/// out, keeping up to `hint_line_budget()` lines.
fn extract_body_hint(source: &str, start: usize, end: usize) -> Option<String> {
    if start >= source.len() || end > source.len() || start >= end {
        return None;
    }
    // Clamp both offsets to valid char boundaries so slicing cannot panic.
    let safe_start = if source.is_char_boundary(start) {
        start
    } else {
        source.floor_char_boundary(start)
    };
    let safe_end = end.min(source.len());
    let safe_end = if source.is_char_boundary(safe_end) {
        safe_end
    } else {
        source.floor_char_boundary(safe_end)
    };
    let body = &source[safe_start..safe_end];

    let max_lines = hint_line_budget();
    let mut collected: Vec<String> = Vec::with_capacity(max_lines);

    // Skip signature lines until the body opens ('{' for brace languages,
    // ':' for Python-style).
    let mut past_signature = false;
    for line in body.lines() {
        let trimmed = line.trim();
        if !past_signature {
            if trimmed.ends_with('{') || trimmed.ends_with(':') || trimmed == "{" {
                past_signature = true;
            }
            continue;
        }
        // Ignore blanks, comment lines, and structural closers.
        if trimmed.is_empty()
            || trimmed.starts_with("//")
            || trimmed.starts_with('#')
            || trimmed.starts_with("/*")
            || trimmed.starts_with('*')
            || trimmed == "}"
        {
            continue;
        }
        collected.push(trimmed.to_string());
        if collected.len() >= max_lines {
            break;
        }
    }

    if collected.is_empty() {
        None
    } else {
        Some(join_hint_lines(&collected))
    }
}
2001
2002fn nl_tokens_enabled() -> bool {
2012 if let Some(explicit) = parse_bool_env("CODELENS_EMBED_HINT_INCLUDE_COMMENTS") {
2013 return explicit;
2014 }
2015 auto_hint_should_enable()
2016}
2017
2018pub(super) fn auto_hint_mode_enabled() -> bool {
2060 parse_bool_env("CODELENS_EMBED_HINT_AUTO").unwrap_or(true)
2061}
2062
2063pub(super) fn auto_hint_lang() -> Option<String> {
2074 std::env::var("CODELENS_EMBED_HINT_AUTO_LANG")
2075 .ok()
2076 .map(|raw| raw.trim().to_ascii_lowercase())
2077}
2078
2079pub(super) fn language_supports_nl_stack(lang: &str) -> bool {
2115 matches!(
2116 lang.trim().to_ascii_lowercase().as_str(),
2117 "rs" | "rust"
2118 | "cpp"
2119 | "cc"
2120 | "cxx"
2121 | "c++"
2122 | "c"
2123 | "go"
2124 | "golang"
2125 | "java"
2126 | "kt"
2127 | "kotlin"
2128 | "scala"
2129 | "cs"
2130 | "csharp"
2131 | "ts"
2132 | "typescript"
2133 | "tsx"
2134 | "js"
2135 | "javascript"
2136 | "jsx"
2137 )
2138}
2139
2140pub(super) fn language_supports_sparse_weighting(lang: &str) -> bool {
2158 matches!(
2159 lang.trim().to_ascii_lowercase().as_str(),
2160 "rs" | "rust"
2161 | "cpp"
2162 | "cc"
2163 | "cxx"
2164 | "c++"
2165 | "c"
2166 | "go"
2167 | "golang"
2168 | "java"
2169 | "kt"
2170 | "kotlin"
2171 | "scala"
2172 | "cs"
2173 | "csharp"
2174 )
2175}
2176
2177pub(super) fn auto_hint_should_enable() -> bool {
2182 if !auto_hint_mode_enabled() {
2183 return false;
2184 }
2185 match auto_hint_lang() {
2186 Some(lang) => language_supports_nl_stack(&lang),
2187 None => false, }
2189}
2190
2191pub(super) fn auto_sparse_should_enable() -> bool {
2198 if !auto_hint_mode_enabled() {
2199 return false;
2200 }
2201 match auto_hint_lang() {
2202 Some(lang) => language_supports_sparse_weighting(&lang),
2203 None => false,
2204 }
2205}
2206
2207pub(super) fn is_nl_shaped(s: &str) -> bool {
2216 let s = s.trim();
2217 if s.chars().count() < 4 {
2218 return false;
2219 }
2220 if s.contains('/') || s.contains('\\') || s.contains("::") {
2221 return false;
2222 }
2223 if !s.contains(' ') {
2224 return false;
2225 }
2226 let non_ws: usize = s.chars().filter(|c| !c.is_whitespace()).count();
2227 if non_ws == 0 {
2228 return false;
2229 }
2230 let alpha: usize = s.chars().filter(|c| c.is_alphabetic()).count();
2231 (alpha * 100) / non_ws >= 60
2232}
2233
/// Opt-in filter (`CODELENS_EMBED_HINT_STRICT_COMMENTS`) that drops
/// TODO/FIXME-style meta comments from NL hints. Off unless the variable
/// is one of "1", "true", "yes", "on" (case-insensitive).
fn strict_comments_enabled() -> bool {
    std::env::var("CODELENS_EMBED_HINT_STRICT_COMMENTS").is_ok_and(|raw| {
        matches!(
            raw.to_ascii_lowercase().as_str(),
            "1" | "true" | "yes" | "on"
        )
    })
}
2256
2257pub(super) fn looks_like_meta_annotation(body: &str) -> bool {
2278 let trimmed = body.trim_start();
2279 let word_end = trimmed
2282 .find(|c: char| !c.is_ascii_alphabetic())
2283 .unwrap_or(trimmed.len());
2284 if word_end == 0 {
2285 return false;
2286 }
2287 let first_word = &trimmed[..word_end];
2288 let upper = first_word.to_ascii_uppercase();
2289 matches!(
2290 upper.as_str(),
2291 "TODO"
2292 | "FIXME"
2293 | "HACK"
2294 | "XXX"
2295 | "BUG"
2296 | "REVIEW"
2297 | "REFACTOR"
2298 | "TEMP"
2299 | "TEMPORARY"
2300 | "DEPRECATED"
2301 )
2302}
2303
/// Opt-in filter (`CODELENS_EMBED_HINT_STRICT_LITERALS`) that drops format
/// templates and error/log-style string literals from NL hints. Off unless
/// the variable is one of "1", "true", "yes", "on" (case-insensitive).
fn strict_literal_filter_enabled() -> bool {
    std::env::var("CODELENS_EMBED_HINT_STRICT_LITERALS").is_ok_and(|raw| {
        matches!(
            raw.to_ascii_lowercase().as_str(),
            "1" | "true" | "yes" | "on"
        )
    })
}
2326
2327pub(super) fn contains_format_specifier(s: &str) -> bool {
2339 let bytes = s.as_bytes();
2340 let len = bytes.len();
2341 let mut i = 0;
2342 while i + 1 < len {
2343 if bytes[i] == b'%' {
2344 let next = bytes[i + 1];
2345 if matches!(next, b's' | b'd' | b'r' | b'f' | b'x' | b'o' | b'i' | b'u') {
2346 return true;
2347 }
2348 }
2349 i += 1;
2350 }
2351 for window in s.split('{').skip(1) {
2359 let Some(close_idx) = window.find('}') else {
2360 continue;
2361 };
2362 let inside = &window[..close_idx];
2363 if inside.is_empty() {
2365 return true;
2366 }
2367 if inside.chars().any(|c| c.is_whitespace()) {
2369 continue;
2370 }
2371 if inside.starts_with(':') {
2373 return true;
2374 }
2375 let ident_end = inside.find(':').unwrap_or(inside.len());
2379 let ident = &inside[..ident_end];
2380 if !ident.is_empty()
2381 && ident
2382 .chars()
2383 .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '.')
2384 {
2385 return true;
2386 }
2387 }
2388 false
2389}
2390
2391pub(super) fn looks_like_error_or_log_prefix(s: &str) -> bool {
2402 let lower = s.trim().to_lowercase();
2403 const PREFIXES: &[&str] = &[
2404 "invalid ",
2405 "cannot ",
2406 "could not ",
2407 "unable to ",
2408 "failed to ",
2409 "expected ",
2410 "unexpected ",
2411 "missing ",
2412 "not found",
2413 "error: ",
2414 "error ",
2415 "warning: ",
2416 "warning ",
2417 "sending ",
2418 "received ",
2419 "starting ",
2420 "stopping ",
2421 "calling ",
2422 "connecting ",
2423 "disconnecting ",
2424 ];
2425 PREFIXES.iter().any(|p| lower.starts_with(p))
2426}
2427
/// Test-only helper: under strict filtering a literal is rejected when it
/// looks like a format template or an error/log message.
#[cfg(test)]
pub(super) fn should_reject_literal_strict(s: &str) -> bool {
    contains_format_specifier(s) || looks_like_error_or_log_prefix(s)
}
2436
2437fn extract_nl_tokens(source: &str, start: usize, end: usize) -> Option<String> {
2451 if !nl_tokens_enabled() {
2452 return None;
2453 }
2454 extract_nl_tokens_inner(source, start, end)
2455}
2456
/// Collect natural-language fragments from a symbol's body: comment lines
/// and double-quoted string literals that pass `is_nl_shaped`, joined
/// under the hint budget. Returns `None` when nothing qualifies.
pub(super) fn extract_nl_tokens_inner(source: &str, start: usize, end: usize) -> Option<String> {
    if start >= source.len() || end > source.len() || start >= end {
        return None;
    }
    // Clamp offsets to char boundaries so slicing cannot panic.
    let safe_start = if source.is_char_boundary(start) {
        start
    } else {
        source.floor_char_boundary(start)
    };
    let safe_end = end.min(source.len());
    let safe_end = if source.is_char_boundary(safe_end) {
        safe_end
    } else {
        source.floor_char_boundary(safe_end)
    };
    let body = &source[safe_start..safe_end];

    let mut tokens: Vec<String> = Vec::new();

    // Pass 1: comment lines that read like prose. Strict mode additionally
    // drops TODO/FIXME-style meta annotations.
    let strict_comments = strict_comments_enabled();
    for line in body.lines() {
        let trimmed = line.trim();
        if let Some(cleaned) = extract_comment_body(trimmed)
            && is_nl_shaped(&cleaned)
            && (!strict_comments || !looks_like_meta_annotation(&cleaned))
        {
            tokens.push(cleaned);
        }
    }

    // Pass 2: double-quoted string literals, via a simple scanner that
    // honors backslash escapes. Strict mode drops format templates and
    // error/log-style messages.
    let strict_literals = strict_literal_filter_enabled();
    let mut chars = body.chars().peekable();
    let mut in_string = false;
    let mut current = String::new();
    while let Some(c) = chars.next() {
        if in_string {
            if c == '\\' {
                // Skip the escaped character so `\"` doesn't end the string.
                let _ = chars.next();
            } else if c == '"' {
                if is_nl_shaped(&current)
                    && (!strict_literals
                        || (!contains_format_specifier(&current)
                            && !looks_like_error_or_log_prefix(&current)))
                {
                    tokens.push(current.clone());
                }
                current.clear();
                in_string = false;
            } else {
                current.push(c);
            }
        } else if c == '"' {
            in_string = true;
        }
    }

    if tokens.is_empty() {
        return None;
    }
    Some(join_hint_lines(&tokens))
}
2539
2540fn api_calls_enabled() -> bool {
2549 if let Some(explicit) = parse_bool_env("CODELENS_EMBED_HINT_INCLUDE_API_CALLS") {
2550 return explicit;
2551 }
2552 auto_hint_should_enable()
2553}
2554
2555pub(super) fn is_static_method_ident(ident: &str) -> bool {
2565 ident.chars().next().is_some_and(|c| c.is_ascii_uppercase())
2566}
2567
2568fn extract_api_calls(source: &str, start: usize, end: usize) -> Option<String> {
2580 if !api_calls_enabled() {
2581 return None;
2582 }
2583 extract_api_calls_inner(source, start, end)
2584}
2585
/// Scan a symbol body for `Type::method` associated-function calls and
/// return the unique ones, in first-seen order, joined under the hint
/// budget. This is a byte-level scanner: an identifier starting with an
/// uppercase ASCII letter, followed by `::` and another identifier, is
/// counted as a call. Returns `None` when nothing matches.
pub(super) fn extract_api_calls_inner(source: &str, start: usize, end: usize) -> Option<String> {
    if start >= source.len() || end > source.len() || start >= end {
        return None;
    }
    // Clamp offsets to char boundaries so slicing cannot panic.
    let safe_start = if source.is_char_boundary(start) {
        start
    } else {
        source.floor_char_boundary(start)
    };
    let safe_end = end.min(source.len());
    let safe_end = if source.is_char_boundary(safe_end) {
        safe_end
    } else {
        source.floor_char_boundary(safe_end)
    };
    if safe_start >= safe_end {
        return None;
    }
    let body = &source[safe_start..safe_end];
    let bytes = body.as_bytes();
    let len = bytes.len();

    // `seen` dedupes; `calls` preserves first-seen order.
    let mut calls: Vec<String> = Vec::new();
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();

    let mut i = 0usize;
    while i < len {
        let b = bytes[i];
        // Advance to the start of an identifier ([A-Za-z_]).
        if !(b == b'_' || b.is_ascii_alphabetic()) {
            i += 1;
            continue;
        }
        let ident_start = i;
        while i < len {
            let bb = bytes[i];
            if bb == b'_' || bb.is_ascii_alphanumeric() {
                i += 1;
            } else {
                break;
            }
        }
        let ident_end = i;

        // The identifier must be immediately followed by `::`. (After this
        // `continue`, `i` sits on a non-identifier byte, so the loop top
        // advances — no infinite loop.)
        if i + 1 >= len || bytes[i] != b':' || bytes[i + 1] != b':' {
            continue;
        }

        // Lowercase receivers (module paths, locals) are skipped: only an
        // uppercase identifier is treated as a type.
        let type_ident = &body[ident_start..ident_end];
        if !is_static_method_ident(type_ident) {
            i += 2;
            continue;
        }

        // Parse the method identifier after the `::`.
        let mut j = i + 2;
        if j >= len || !(bytes[j] == b'_' || bytes[j].is_ascii_alphabetic()) {
            i = j;
            continue;
        }
        let method_start = j;
        while j < len {
            let bb = bytes[j];
            if bb == b'_' || bb.is_ascii_alphanumeric() {
                j += 1;
            } else {
                break;
            }
        }
        let method_end = j;

        let method_ident = &body[method_start..method_end];
        let call = format!("{type_ident}::{method_ident}");
        if seen.insert(call.clone()) {
            calls.push(call);
        }
        i = j;
    }

    if calls.is_empty() {
        return None;
    }
    Some(join_hint_lines(&calls))
}
2686
/// Strip comment syntax from a trimmed line, returning the comment text,
/// or `None` when the line is not a comment (attributes and shebangs are
/// deliberately excluded).
fn extract_comment_body(trimmed: &str) -> Option<String> {
    if trimmed.is_empty() {
        return None;
    }
    // Line comments: doc markers (`///`, `//!`) before plain `//` so they
    // are stripped in full.
    for marker in ["///", "//!", "//"] {
        if let Some(rest) = trimmed.strip_prefix(marker) {
            return Some(rest.trim().to_string());
        }
    }
    // `#[attr]` / `#!…` are attributes or shebangs, not comments; a bare
    // `#` is a hash-style comment (scripts, config files, …).
    if trimmed.starts_with("#[") || trimmed.starts_with("#!") {
        return None;
    }
    if let Some(rest) = trimmed.strip_prefix('#') {
        return Some(rest.trim().to_string());
    }
    // Block comment openers: `/**` before `/*`; a trailing `*/` is removed
    // either way.
    for marker in ["/**", "/*"] {
        if let Some(rest) = trimmed.strip_prefix(marker) {
            return Some(rest.trim_end_matches("*/").trim().to_string());
        }
    }
    // Interior line of a block comment (`* …`) — but reject lines that look
    // like code, e.g. `*ptr = value;`.
    if let Some(rest) = trimmed.strip_prefix('*') {
        let body = rest.trim_end_matches("*/").trim();
        if body.is_empty() || body.contains(';') || body.contains('{') {
            return None;
        }
        return Some(body.to_string());
    }
    None
}
2733
/// Extracts the documentation block that appears on the lines *after* the
/// first line of `source[start..end]` (i.e. the doc following a definition's
/// signature line): Python triple-quoted docstrings, Rust `///`/`//!` runs,
/// JSDoc `/** ... */` blocks, or plain `//`/`#` comment runs.
///
/// Returns the joined, whitespace-normalized doc text, or `None` when no
/// recognizable documentation follows the first line.
fn extract_leading_doc(source: &str, start: usize, end: usize) -> Option<String> {
    if start >= source.len() || end > source.len() || start >= end {
        return None;
    }
    // Clamp both offsets onto UTF-8 char boundaries so slicing cannot panic
    // on multi-byte characters (stable equivalent of `floor_char_boundary`).
    let floor_boundary = |mut idx: usize| {
        while !source.is_char_boundary(idx) {
            idx -= 1;
        }
        idx
    };
    let safe_start = floor_boundary(start);
    let safe_end = floor_boundary(end.min(source.len()));
    if safe_start >= safe_end {
        return None;
    }
    let body = &source[safe_start..safe_end];
    // Skip the signature line itself; docs live on the following lines.
    let lines: Vec<&str> = body.lines().skip(1).collect();
    if lines.is_empty() {
        return None;
    }

    let mut doc_lines = Vec::new();

    let first_trimmed = lines.first().map(|l| l.trim()).unwrap_or_default();
    if first_trimmed.starts_with("\"\"\"") || first_trimmed.starts_with("'''") {
        // Python triple-quoted docstring.
        let quote = &first_trimmed[..3];
        for line in &lines {
            let t = line.trim();
            doc_lines.push(t.trim_start_matches(quote).trim_end_matches(quote));
            // Stop at the closing quote. A single-line docstring ("""doc""")
            // both opens and closes on its first line (len >= 6); previously
            // only `doc_lines.len() > 1` was checked, so single-line
            // docstrings leaked the following code lines into the doc.
            if t.ends_with(quote) && (doc_lines.len() > 1 || t.len() >= 6) {
                break;
            }
        }
    } else if first_trimmed.starts_with("///") || first_trimmed.starts_with("//!") {
        // Rust doc comments: take the contiguous run.
        for line in &lines {
            let t = line.trim();
            if t.starts_with("///") || t.starts_with("//!") {
                doc_lines.push(t.trim_start_matches("///").trim_start_matches("//!").trim());
            } else {
                break;
            }
        }
    } else if first_trimmed.starts_with("/**") {
        // JSDoc-style block comment: strip decoration, stop at "*/".
        for line in &lines {
            let t = line.trim();
            let cleaned = t
                .trim_start_matches("/**")
                .trim_start_matches('*')
                .trim_end_matches("*/")
                .trim();
            if !cleaned.is_empty() {
                doc_lines.push(cleaned);
            }
            if t.ends_with("*/") {
                break;
            }
        }
    } else {
        // Plain line-comment run ("//" or "#").
        for line in &lines {
            let t = line.trim();
            if t.starts_with("//") || t.starts_with('#') {
                doc_lines.push(t.trim_start_matches("//").trim_start_matches('#').trim());
            } else {
                break;
            }
        }
    }

    if doc_lines.is_empty() {
        return None;
    }
    Some(doc_lines.join(" ").trim().to_owned())
}
2820
2821pub(super) fn embedding_to_bytes(embedding: &[f32]) -> Vec<u8> {
2822 embedding.iter().flat_map(|f| f.to_le_bytes()).collect()
2823}
2824
2825#[cfg(test)]
2826mod tests {
2827 use super::*;
2828 use crate::db::{IndexDb, NewSymbol};
2829 use std::sync::Mutex;
2830
    // Serializes tests that instantiate the embedding model so only one loads
    // it at a time. NOTE(review): not taken by any test visible in this chunk
    // — presumably used by model-loading tests elsewhere in the module; confirm.
    static MODEL_LOCK: Mutex<()> = Mutex::new(());

    // Serializes tests that mutate CODELENS_EMBED_* environment variables;
    // the process environment is global, so parallel tests would race.
    static ENV_LOCK: Mutex<()> = Mutex::new(());
2841
    // Early-returns from the enclosing test when the CodeSearchNet model
    // assets are not available on disk, so machines without the model skip
    // the test instead of failing it.
    macro_rules! skip_without_embedding_model {
        () => {
            if !super::embedding_model_assets_available() {
                eprintln!("skipping embedding test: CodeSearchNet model assets unavailable");
                return;
            }
        };
    }
2850
2851 fn make_project_with_source() -> (tempfile::TempDir, ProjectRoot) {
2853 let dir = tempfile::tempdir().unwrap();
2854 let root = dir.path();
2855
2856 let source = "def hello():\n print('hi')\n\ndef world():\n return 42\n";
2858 write_python_file_with_symbols(
2859 root,
2860 "main.py",
2861 source,
2862 "hash1",
2863 &[
2864 ("hello", "def hello():", "hello"),
2865 ("world", "def world():", "world"),
2866 ],
2867 );
2868
2869 let project = ProjectRoot::new_exact(root).unwrap();
2870 (dir, project)
2871 }
2872
    /// Writes `source` to `root/relative_path` and registers the file plus its
    /// symbols in the on-disk index database, mimicking what indexing produces.
    ///
    /// `symbols` entries are `(name, signature, name_path)`; byte offsets and
    /// the 1-based line number are derived by locating `signature` in `source`.
    fn write_python_file_with_symbols(
        root: &std::path::Path,
        relative_path: &str,
        source: &str,
        hash: &str,
        symbols: &[(&str, &str, &str)],
    ) {
        std::fs::write(root.join(relative_path), source).unwrap();
        let db_path = crate::db::index_db_path(root);
        let db = IndexDb::open(&db_path).unwrap();
        let file_id = db
            .upsert_file(relative_path, 100, hash, source.len() as i64, Some("py"))
            .unwrap();

        let new_symbols: Vec<NewSymbol<'_>> = symbols
            .iter()
            .map(|(name, signature, name_path)| {
                // Start offset: first occurrence of the signature text.
                let start = source.find(signature).unwrap() as i64;
                // End offset: next blank-line-separated "def", or end of file.
                let end = source[start as usize..]
                    .find("\n\ndef ")
                    .map(|offset| start + offset as i64)
                    .unwrap_or(source.len() as i64);
                // 1-based line number: newlines before the start offset, plus one.
                let line = source[..start as usize]
                    .bytes()
                    .filter(|&b| b == b'\n')
                    .count() as i64
                    + 1;
                NewSymbol {
                    name,
                    kind: "function",
                    line,
                    column_num: 0,
                    start_byte: start,
                    end_byte: end,
                    signature,
                    name_path,
                    parent_id: None,
                }
            })
            .collect();
        db.insert_symbols(file_id, &new_symbols).unwrap();
    }
2915
2916 fn replace_file_embeddings_with_sentinels(
2917 engine: &EmbeddingEngine,
2918 file_path: &str,
2919 sentinels: &[(&str, f32)],
2920 ) {
2921 let mut chunks = engine.store.embeddings_for_files(&[file_path]).unwrap();
2922 for chunk in &mut chunks {
2923 if let Some((_, value)) = sentinels
2924 .iter()
2925 .find(|(symbol_name, _)| *symbol_name == chunk.symbol_name)
2926 {
2927 chunk.embedding = vec![*value; chunk.embedding.len()];
2928 }
2929 }
2930 engine.store.delete_by_file(&[file_path]).unwrap();
2931 engine.store.insert(&chunks).unwrap();
2932 }
2933
2934 #[test]
2935 fn build_embedding_text_with_signature() {
2936 let sym = crate::db::SymbolWithFile {
2937 name: "hello".into(),
2938 kind: "function".into(),
2939 file_path: "main.py".into(),
2940 line: 1,
2941 signature: "def hello():".into(),
2942 name_path: "hello".into(),
2943 start_byte: 0,
2944 end_byte: 10,
2945 };
2946 let text = build_embedding_text(&sym, Some("def hello(): pass"));
2947 assert_eq!(text, "function hello in main.py: def hello():");
2948 }
2949
2950 #[test]
2951 fn build_embedding_text_without_source() {
2952 let sym = crate::db::SymbolWithFile {
2953 name: "MyClass".into(),
2954 kind: "class".into(),
2955 file_path: "app.py".into(),
2956 line: 5,
2957 signature: "class MyClass:".into(),
2958 name_path: "MyClass".into(),
2959 start_byte: 0,
2960 end_byte: 50,
2961 };
2962 let text = build_embedding_text(&sym, None);
2963 assert_eq!(text, "class MyClass (My Class) in app.py: class MyClass:");
2964 }
2965
2966 #[test]
2967 fn build_embedding_text_empty_signature() {
2968 let sym = crate::db::SymbolWithFile {
2969 name: "CONFIG".into(),
2970 kind: "variable".into(),
2971 file_path: "config.py".into(),
2972 line: 1,
2973 signature: String::new(),
2974 name_path: "CONFIG".into(),
2975 start_byte: 0,
2976 end_byte: 0,
2977 };
2978 let text = build_embedding_text(&sym, None);
2979 assert_eq!(text, "variable CONFIG in config.py");
2980 }
2981
2982 #[test]
2983 fn filters_direct_test_symbols_from_embedding_index() {
2984 let source = "#[test]\nfn alias_case() {}\n";
2985 let sym = crate::db::SymbolWithFile {
2986 name: "alias_case".into(),
2987 kind: "function".into(),
2988 file_path: "src/lib.rs".into(),
2989 line: 2,
2990 signature: "fn alias_case() {}".into(),
2991 name_path: "alias_case".into(),
2992 start_byte: source.find("fn alias_case").unwrap() as i64,
2993 end_byte: source.len() as i64,
2994 };
2995
2996 assert!(is_test_only_symbol(&sym, Some(source)));
2997 }
2998
2999 #[test]
3000 fn filters_cfg_test_module_symbols_from_embedding_index() {
3001 let source = "#[cfg(all(test, feature = \"semantic\"))]\nmod semantic_tests {\n fn helper_case() {}\n}\n";
3002 let sym = crate::db::SymbolWithFile {
3003 name: "helper_case".into(),
3004 kind: "function".into(),
3005 file_path: "src/lib.rs".into(),
3006 line: 3,
3007 signature: "fn helper_case() {}".into(),
3008 name_path: "helper_case".into(),
3009 start_byte: source.find("fn helper_case").unwrap() as i64,
3010 end_byte: source.len() as i64,
3011 };
3012
3013 assert!(is_test_only_symbol(&sym, Some(source)));
3014 }
3015
3016 #[test]
3017 fn extract_python_docstring() {
3018 let source =
3019 "def greet(name):\n \"\"\"Say hello to a person.\"\"\"\n print(f'hi {name}')\n";
3020 let doc = extract_leading_doc(source, 0, source.len()).unwrap();
3021 assert!(doc.contains("Say hello to a person"));
3022 }
3023
3024 #[test]
3025 fn extract_rust_doc_comment() {
3026 let source = "fn dispatch_tool() {\n /// Route incoming tool requests.\n /// Handles all MCP methods.\n let x = 1;\n}\n";
3027 let doc = extract_leading_doc(source, 0, source.len()).unwrap();
3028 assert!(doc.contains("Route incoming tool requests"));
3029 assert!(doc.contains("Handles all MCP methods"));
3030 }
3031
3032 #[test]
3033 fn extract_leading_doc_returns_none_for_no_doc() {
3034 let source = "def f():\n return 1\n";
3035 assert!(extract_leading_doc(source, 0, source.len()).is_none());
3036 }
3037
3038 #[test]
3039 fn extract_body_hint_finds_first_meaningful_line() {
3040 let source = "pub fn parse_symbols(\n project: &ProjectRoot,\n) -> Vec<SymbolInfo> {\n let mut parser = tree_sitter::Parser::new();\n parser.set_language(lang);\n}\n";
3041 let hint = extract_body_hint(source, 0, source.len());
3042 assert!(hint.is_some());
3043 assert!(hint.unwrap().contains("tree_sitter::Parser"));
3044 }
3045
3046 #[test]
3047 fn extract_body_hint_skips_comments() {
3048 let source = "fn foo() {\n // setup\n let x = bar();\n}\n";
3049 let hint = extract_body_hint(source, 0, source.len());
3050 assert_eq!(hint.unwrap(), "let x = bar();");
3051 }
3052
3053 #[test]
3054 fn extract_body_hint_returns_none_for_empty() {
3055 let source = "fn empty() {\n}\n";
3056 let hint = extract_body_hint(source, 0, source.len());
3057 assert!(hint.is_none());
3058 }
3059
3060 #[test]
3061 fn extract_body_hint_multi_line_collection_via_env_override() {
3062 let previous_lines = std::env::var("CODELENS_EMBED_HINT_LINES").ok();
3067 let previous_chars = std::env::var("CODELENS_EMBED_HINT_CHARS").ok();
3068 unsafe {
3069 std::env::set_var("CODELENS_EMBED_HINT_LINES", "3");
3070 std::env::set_var("CODELENS_EMBED_HINT_CHARS", "200");
3071 }
3072
3073 let source = "\
3074fn route_request() {
3075 let kind = detect_request_kind();
3076 let target = dispatch_table.get(&kind);
3077 return target.handle();
3078}
3079";
3080 let hint = extract_body_hint(source, 0, source.len()).expect("hint present");
3081
3082 let env_restore = || unsafe {
3083 match &previous_lines {
3084 Some(value) => std::env::set_var("CODELENS_EMBED_HINT_LINES", value),
3085 None => std::env::remove_var("CODELENS_EMBED_HINT_LINES"),
3086 }
3087 match &previous_chars {
3088 Some(value) => std::env::set_var("CODELENS_EMBED_HINT_CHARS", value),
3089 None => std::env::remove_var("CODELENS_EMBED_HINT_CHARS"),
3090 }
3091 };
3092
3093 let all_three = hint.contains("detect_request_kind")
3094 && hint.contains("dispatch_table")
3095 && hint.contains("target.handle");
3096 let has_separator = hint.contains(" · ");
3097 env_restore();
3098
3099 assert!(all_three, "missing one of three body lines: {hint}");
3100 assert!(has_separator, "missing · separator: {hint}");
3101 }
3102
3103 #[test]
3114 fn hint_line_budget_respects_env_override() {
3115 let previous = std::env::var("CODELENS_EMBED_HINT_LINES").ok();
3118 unsafe {
3119 std::env::set_var("CODELENS_EMBED_HINT_LINES", "5");
3120 }
3121 let budget = super::hint_line_budget();
3122 unsafe {
3123 match previous {
3124 Some(value) => std::env::set_var("CODELENS_EMBED_HINT_LINES", value),
3125 None => std::env::remove_var("CODELENS_EMBED_HINT_LINES"),
3126 }
3127 }
3128 assert_eq!(budget, 5);
3129 }
3130
3131 #[test]
3132 fn is_nl_shaped_accepts_multi_word_prose() {
3133 assert!(super::is_nl_shaped("skip comments and string literals"));
3134 assert!(super::is_nl_shaped("failed to open database"));
3135 assert!(super::is_nl_shaped("detect client version"));
3136 }
3137
3138 #[test]
3139 fn is_nl_shaped_rejects_code_and_paths() {
3140 assert!(!super::is_nl_shaped("crates/codelens-engine/src"));
3142 assert!(!super::is_nl_shaped("C:\\Users\\foo"));
3143 assert!(!super::is_nl_shaped("std::sync::Mutex"));
3145 assert!(!super::is_nl_shaped("detect_client"));
3147 assert!(!super::is_nl_shaped("ok"));
3149 assert!(!super::is_nl_shaped(""));
3150 assert!(!super::is_nl_shaped("1 2 3 4 5"));
3152 }
3153
3154 #[test]
3155 fn extract_comment_body_strips_comment_markers() {
3156 assert_eq!(
3157 super::extract_comment_body("/// rust doc comment"),
3158 Some("rust doc comment".to_string())
3159 );
3160 assert_eq!(
3161 super::extract_comment_body("// regular line comment"),
3162 Some("regular line comment".to_string())
3163 );
3164 assert_eq!(
3165 super::extract_comment_body("# python line comment"),
3166 Some("python line comment".to_string())
3167 );
3168 assert_eq!(
3169 super::extract_comment_body("/* inline block */"),
3170 Some("inline block".to_string())
3171 );
3172 assert_eq!(
3173 super::extract_comment_body("* continuation line"),
3174 Some("continuation line".to_string())
3175 );
3176 }
3177
3178 #[test]
3179 fn extract_comment_body_rejects_rust_attributes_and_shebangs() {
3180 assert!(super::extract_comment_body("#[derive(Debug)]").is_none());
3181 assert!(super::extract_comment_body("#[test]").is_none());
3182 assert!(super::extract_comment_body("#!/usr/bin/env python").is_none());
3183 }
3184
3185 #[test]
3186 fn extract_nl_tokens_gated_off_by_default() {
3187 let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
3188 let previous = std::env::var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS").ok();
3190 unsafe {
3191 std::env::remove_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS");
3192 }
3193 let source = "\
3194fn skip_things() {
3195 // skip comments and string literals during search
3196 let lit = \"scan for matching tokens\";
3197}
3198";
3199 let result = extract_nl_tokens(source, 0, source.len());
3200 unsafe {
3201 if let Some(value) = previous {
3202 std::env::set_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS", value);
3203 }
3204 }
3205 assert!(result.is_none(), "gate leaked: {result:?}");
3206 }
3207
    #[test]
    fn auto_hint_mode_defaults_on_unless_explicit_off() {
        // ENV_LOCK serializes mutation of the process-global environment.
        let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let previous = std::env::var("CODELENS_EMBED_HINT_AUTO").ok();

        // Case 1: env unset -> default ON (the v1.6.0 default flip).
        unsafe {
            std::env::remove_var("CODELENS_EMBED_HINT_AUTO");
        }
        let default_enabled = super::auto_hint_mode_enabled();
        assert!(
            default_enabled,
            "v1.6.0 default flip: auto hint mode should be ON when env unset"
        );

        // Case 2: explicit "0" -> the opt-out escape hatch is still honoured.
        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "0");
        }
        let explicit_off = super::auto_hint_mode_enabled();
        assert!(
            !explicit_off,
            "explicit CODELENS_EMBED_HINT_AUTO=0 must still disable (opt-out escape hatch)"
        );

        // Case 3: explicit "1" -> explicit opt-in also works.
        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
        }
        let explicit_on = super::auto_hint_mode_enabled();
        assert!(
            explicit_on,
            "explicit CODELENS_EMBED_HINT_AUTO=1 must still enable"
        );

        // Restore whatever value the environment had before this test.
        unsafe {
            match previous {
                Some(v) => std::env::set_var("CODELENS_EMBED_HINT_AUTO", v),
                None => std::env::remove_var("CODELENS_EMBED_HINT_AUTO"),
            }
        }
    }
3258
3259 #[test]
3260 fn language_supports_nl_stack_classifies_correctly() {
3261 assert!(super::language_supports_nl_stack("rs"));
3263 assert!(super::language_supports_nl_stack("rust"));
3264 assert!(super::language_supports_nl_stack("cpp"));
3265 assert!(super::language_supports_nl_stack("c++"));
3266 assert!(super::language_supports_nl_stack("c"));
3267 assert!(super::language_supports_nl_stack("go"));
3268 assert!(super::language_supports_nl_stack("golang"));
3269 assert!(super::language_supports_nl_stack("java"));
3270 assert!(super::language_supports_nl_stack("kt"));
3271 assert!(super::language_supports_nl_stack("kotlin"));
3272 assert!(super::language_supports_nl_stack("scala"));
3273 assert!(super::language_supports_nl_stack("cs"));
3274 assert!(super::language_supports_nl_stack("csharp"));
3275 assert!(super::language_supports_nl_stack("ts"));
3278 assert!(super::language_supports_nl_stack("typescript"));
3279 assert!(super::language_supports_nl_stack("tsx"));
3280 assert!(super::language_supports_nl_stack("js"));
3281 assert!(super::language_supports_nl_stack("javascript"));
3282 assert!(super::language_supports_nl_stack("jsx"));
3283 assert!(super::language_supports_nl_stack("Rust"));
3285 assert!(super::language_supports_nl_stack("RUST"));
3286 assert!(super::language_supports_nl_stack("TypeScript"));
3287 assert!(super::language_supports_nl_stack(" rust "));
3289 assert!(super::language_supports_nl_stack(" ts "));
3290
3291 assert!(!super::language_supports_nl_stack("py"));
3293 assert!(!super::language_supports_nl_stack("python"));
3294 assert!(!super::language_supports_nl_stack("rb"));
3295 assert!(!super::language_supports_nl_stack("ruby"));
3296 assert!(!super::language_supports_nl_stack("php"));
3297 assert!(!super::language_supports_nl_stack("lua"));
3298 assert!(!super::language_supports_nl_stack("sh"));
3299 assert!(!super::language_supports_nl_stack("klingon"));
3301 assert!(!super::language_supports_nl_stack(""));
3302 }
3303
3304 #[test]
3305 fn language_supports_sparse_weighting_classifies_correctly() {
3306 assert!(super::language_supports_sparse_weighting("rs"));
3307 assert!(super::language_supports_sparse_weighting("rust"));
3308 assert!(super::language_supports_sparse_weighting("cpp"));
3309 assert!(super::language_supports_sparse_weighting("go"));
3310 assert!(super::language_supports_sparse_weighting("java"));
3311 assert!(super::language_supports_sparse_weighting("kotlin"));
3312 assert!(super::language_supports_sparse_weighting("csharp"));
3313
3314 assert!(!super::language_supports_sparse_weighting("ts"));
3315 assert!(!super::language_supports_sparse_weighting("typescript"));
3316 assert!(!super::language_supports_sparse_weighting("tsx"));
3317 assert!(!super::language_supports_sparse_weighting("js"));
3318 assert!(!super::language_supports_sparse_weighting("javascript"));
3319 assert!(!super::language_supports_sparse_weighting("jsx"));
3320 assert!(!super::language_supports_sparse_weighting("py"));
3321 assert!(!super::language_supports_sparse_weighting("python"));
3322 assert!(!super::language_supports_sparse_weighting("klingon"));
3323 assert!(!super::language_supports_sparse_weighting(""));
3324 }
3325
    #[test]
    fn auto_hint_should_enable_requires_both_gate_and_supported_lang() {
        // Auto hint mode requires BOTH the auto gate on AND a supported
        // language tag; ENV_LOCK serializes the global env mutation.
        let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let prev_auto = std::env::var("CODELENS_EMBED_HINT_AUTO").ok();
        let prev_lang = std::env::var("CODELENS_EMBED_HINT_AUTO_LANG").ok();

        // Gate explicitly off: language support alone must not enable.
        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "0");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "rust");
        }
        assert!(
            !super::auto_hint_should_enable(),
            "gate-off (explicit =0) with lang=rust must stay disabled"
        );

        // Gate on + supported language: enabled.
        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "rust");
        }
        assert!(
            super::auto_hint_should_enable(),
            "gate-on + lang=rust must enable"
        );

        // TypeScript is in the NL-stack language set, so it qualifies too.
        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "typescript");
        }
        assert!(
            super::auto_hint_should_enable(),
            "gate-on + lang=typescript must keep Phase 2b/2c enabled"
        );

        // Unsupported language: the gate alone is not enough.
        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "python");
        }
        assert!(
            !super::auto_hint_should_enable(),
            "gate-on + lang=python must stay disabled"
        );

        // Missing language tag: conservative default is disabled.
        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::remove_var("CODELENS_EMBED_HINT_AUTO_LANG");
        }
        assert!(
            !super::auto_hint_should_enable(),
            "gate-on + no lang tag must stay disabled"
        );

        // Restore both env vars to their pre-test values.
        unsafe {
            match prev_auto {
                Some(v) => std::env::set_var("CODELENS_EMBED_HINT_AUTO", v),
                None => std::env::remove_var("CODELENS_EMBED_HINT_AUTO"),
            }
            match prev_lang {
                Some(v) => std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", v),
                None => std::env::remove_var("CODELENS_EMBED_HINT_AUTO_LANG"),
            }
        }
    }
3395
    #[test]
    fn auto_sparse_should_enable_requires_both_gate_and_sparse_supported_lang() {
        // Sparse weighting auto-enables only for the narrower sparse language
        // set (rust yes, typescript no) and only when the auto gate is on.
        let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let prev_auto = std::env::var("CODELENS_EMBED_HINT_AUTO").ok();
        let prev_lang = std::env::var("CODELENS_EMBED_HINT_AUTO_LANG").ok();

        // Gate explicitly off: never enabled.
        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "0");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "rust");
        }
        assert!(
            !super::auto_sparse_should_enable(),
            "gate-off (explicit =0) must disable sparse auto gate"
        );

        // Gate on + sparse-supported language: enabled.
        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "rust");
        }
        assert!(
            super::auto_sparse_should_enable(),
            "gate-on + lang=rust must enable sparse auto gate"
        );

        // TypeScript supports the NL stack but NOT sparse weighting.
        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "typescript");
        }
        assert!(
            !super::auto_sparse_should_enable(),
            "gate-on + lang=typescript must keep sparse auto gate disabled"
        );

        // Unsupported language stays disabled.
        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "python");
        }
        assert!(
            !super::auto_sparse_should_enable(),
            "gate-on + lang=python must keep sparse auto gate disabled"
        );

        // Missing language tag stays disabled.
        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::remove_var("CODELENS_EMBED_HINT_AUTO_LANG");
        }
        assert!(
            !super::auto_sparse_should_enable(),
            "gate-on + no lang tag must keep sparse auto gate disabled"
        );

        // Restore both env vars to their pre-test values.
        unsafe {
            match prev_auto {
                Some(v) => std::env::set_var("CODELENS_EMBED_HINT_AUTO", v),
                None => std::env::remove_var("CODELENS_EMBED_HINT_AUTO"),
            }
            match prev_lang {
                Some(v) => std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", v),
                None => std::env::remove_var("CODELENS_EMBED_HINT_AUTO_LANG"),
            }
        }
    }
3458
    #[test]
    fn nl_tokens_enabled_explicit_env_wins_over_auto() {
        // Precedence: an explicit INCLUDE_COMMENTS value always beats the
        // auto gate (in both directions); auto only decides when unset.
        let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let prev_explicit = std::env::var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS").ok();
        let prev_auto = std::env::var("CODELENS_EMBED_HINT_AUTO").ok();
        let prev_lang = std::env::var("CODELENS_EMBED_HINT_AUTO_LANG").ok();

        // Explicit ON overrides auto+unsupported-language (which says off).
        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "python");
        }
        assert!(
            super::nl_tokens_enabled(),
            "explicit=1 must win over auto+python=off"
        );

        // Explicit OFF overrides auto+supported-language (which says on).
        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS", "0");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "rust");
        }
        assert!(
            !super::nl_tokens_enabled(),
            "explicit=0 must win over auto+rust=on"
        );

        // No explicit value: the auto gate decides (rust -> on).
        unsafe {
            std::env::remove_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "rust");
        }
        assert!(
            super::nl_tokens_enabled(),
            "no explicit + auto+rust must enable"
        );

        // No explicit value: the auto gate decides (python -> off).
        unsafe {
            std::env::remove_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO", "1");
            std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", "python");
        }
        assert!(
            !super::nl_tokens_enabled(),
            "no explicit + auto+python must stay disabled"
        );

        // Restore all three env vars to their pre-test values.
        unsafe {
            match prev_explicit {
                Some(v) => std::env::set_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS", v),
                None => std::env::remove_var("CODELENS_EMBED_HINT_INCLUDE_COMMENTS"),
            }
            match prev_auto {
                Some(v) => std::env::set_var("CODELENS_EMBED_HINT_AUTO", v),
                None => std::env::remove_var("CODELENS_EMBED_HINT_AUTO"),
            }
            match prev_lang {
                Some(v) => std::env::set_var("CODELENS_EMBED_HINT_AUTO_LANG", v),
                None => std::env::remove_var("CODELENS_EMBED_HINT_AUTO_LANG"),
            }
        }
    }
3526
    #[test]
    fn strict_comments_gated_off_by_default() {
        // With the strict-comments env var unset, the filter must be off.
        let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let previous = std::env::var("CODELENS_EMBED_HINT_STRICT_COMMENTS").ok();
        unsafe {
            std::env::remove_var("CODELENS_EMBED_HINT_STRICT_COMMENTS");
        }
        let enabled = super::strict_comments_enabled();
        // Restore before asserting so a failure does not leak env state.
        unsafe {
            if let Some(value) = previous {
                std::env::set_var("CODELENS_EMBED_HINT_STRICT_COMMENTS", value);
            }
        }
        assert!(!enabled, "strict comments gate leaked");
    }
3542
3543 #[test]
3544 fn looks_like_meta_annotation_detects_rejected_prefixes() {
3545 assert!(super::looks_like_meta_annotation("TODO: fix later"));
3547 assert!(super::looks_like_meta_annotation("todo handle edge case"));
3548 assert!(super::looks_like_meta_annotation("FIXME this is broken"));
3549 assert!(super::looks_like_meta_annotation(
3550 "HACK: workaround for bug"
3551 ));
3552 assert!(super::looks_like_meta_annotation("XXX not implemented yet"));
3553 assert!(super::looks_like_meta_annotation(
3554 "BUG in the upstream crate"
3555 ));
3556 assert!(super::looks_like_meta_annotation("REVIEW before merging"));
3557 assert!(super::looks_like_meta_annotation(
3558 "REFACTOR this block later"
3559 ));
3560 assert!(super::looks_like_meta_annotation("TEMP: remove before v2"));
3561 assert!(super::looks_like_meta_annotation(
3562 "DEPRECATED use new_api instead"
3563 ));
3564 assert!(super::looks_like_meta_annotation(
3566 " TODO: with leading ws"
3567 ));
3568 }
3569
3570 #[test]
3571 fn looks_like_meta_annotation_preserves_behaviour_prefixes() {
3572 assert!(!super::looks_like_meta_annotation(
3574 "NOTE: this branch handles empty input"
3575 ));
3576 assert!(!super::looks_like_meta_annotation(
3577 "WARN: overflow is possible"
3578 ));
3579 assert!(!super::looks_like_meta_annotation(
3580 "SAFETY: caller must hold the lock"
3581 ));
3582 assert!(!super::looks_like_meta_annotation(
3583 "PANIC: unreachable by construction"
3584 ));
3585 assert!(!super::looks_like_meta_annotation(
3587 "parse json body from request"
3588 ));
3589 assert!(!super::looks_like_meta_annotation(
3590 "walk directory respecting gitignore"
3591 ));
3592 assert!(!super::looks_like_meta_annotation(
3593 "compute cosine similarity between vectors"
3594 ));
3595 assert!(!super::looks_like_meta_annotation(""));
3597 assert!(!super::looks_like_meta_annotation(" "));
3598 assert!(!super::looks_like_meta_annotation("123 numeric prefix"));
3599 }
3600
    #[test]
    fn strict_comments_filters_meta_annotations_during_extraction() {
        // With strict comments ON, TODO/FIXME-style annotations are dropped
        // while behaviour-describing comments survive extraction.
        let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let previous = std::env::var("CODELENS_EMBED_HINT_STRICT_COMMENTS").ok();
        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_STRICT_COMMENTS", "1");
        }
        let source = "\
fn handle_request() {
    // TODO: handle the error path properly
    // parse json body from the incoming request
    // FIXME: this can panic on empty input
    // walk directory respecting the gitignore rules
    let _x = 1;
}
";
        let result = super::extract_nl_tokens_inner(source, 0, source.len());
        // Restore the gate before asserting so failures do not leak it.
        unsafe {
            match previous {
                Some(value) => std::env::set_var("CODELENS_EMBED_HINT_STRICT_COMMENTS", value),
                None => std::env::remove_var("CODELENS_EMBED_HINT_STRICT_COMMENTS"),
            }
        }
        let hint = result.expect("behaviour comments must survive");
        assert!(
            hint.contains("parse json body"),
            "behaviour comment dropped: {hint}"
        );
        assert!(!hint.contains("TODO"), "TODO annotation leaked: {hint}");
        assert!(!hint.contains("FIXME"), "FIXME annotation leaked: {hint}");
    }
3637
    #[test]
    fn strict_comments_is_orthogonal_to_strict_literals() {
        // Enabling strict comments must not implicitly enable the strict
        // literal filter: string literals still pass through unfiltered.
        let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let prev_c = std::env::var("CODELENS_EMBED_HINT_STRICT_COMMENTS").ok();
        let prev_l = std::env::var("CODELENS_EMBED_HINT_STRICT_LITERALS").ok();
        unsafe {
            std::env::set_var("CODELENS_EMBED_HINT_STRICT_COMMENTS", "1");
            std::env::remove_var("CODELENS_EMBED_HINT_STRICT_LITERALS");
        }
        let source = "\
fn handle() {
    // handles real behaviour
    let fmt = \"format error string\";
}
";
        let result = super::extract_nl_tokens_inner(source, 0, source.len());
        // Restore both gates before asserting.
        unsafe {
            match prev_c {
                Some(v) => std::env::set_var("CODELENS_EMBED_HINT_STRICT_COMMENTS", v),
                None => std::env::remove_var("CODELENS_EMBED_HINT_STRICT_COMMENTS"),
            }
            match prev_l {
                Some(v) => std::env::set_var("CODELENS_EMBED_HINT_STRICT_LITERALS", v),
                None => std::env::remove_var("CODELENS_EMBED_HINT_STRICT_LITERALS"),
            }
        }
        let hint = result.expect("tokens must exist");
        assert!(hint.contains("handles real"), "comment dropped: {hint}");
        assert!(
            hint.contains("format error string"),
            "literal dropped: {hint}"
        );
    }
3679
    #[test]
    fn strict_literal_filter_gated_off_by_default() {
        // With the strict-literals env var unset, the filter must be off.
        let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let previous = std::env::var("CODELENS_EMBED_HINT_STRICT_LITERALS").ok();
        unsafe {
            std::env::remove_var("CODELENS_EMBED_HINT_STRICT_LITERALS");
        }
        let enabled = super::strict_literal_filter_enabled();
        // Restore before asserting so a failure does not leak env state.
        unsafe {
            if let Some(value) = previous {
                std::env::set_var("CODELENS_EMBED_HINT_STRICT_LITERALS", value);
            }
        }
        assert!(!enabled, "strict literal filter gate leaked");
    }
3695
3696 #[test]
3697 fn contains_format_specifier_detects_c_and_python_style() {
3698 assert!(super::contains_format_specifier("Invalid URL %s"));
3700 assert!(super::contains_format_specifier("got %d matches"));
3701 assert!(super::contains_format_specifier("value=%r"));
3702 assert!(super::contains_format_specifier("size=%f"));
3703 assert!(super::contains_format_specifier("sending request to {url}"));
3705 assert!(super::contains_format_specifier("got {0} items"));
3706 assert!(super::contains_format_specifier("{:?}"));
3707 assert!(super::contains_format_specifier("value: {x:.2f}"));
3708 assert!(super::contains_format_specifier("{}"));
3709 assert!(!super::contains_format_specifier(
3711 "skip comments and string literals"
3712 ));
3713 assert!(!super::contains_format_specifier("failed to open database"));
3714 assert!(!super::contains_format_specifier("{name: foo, id: 1}"));
3717 }
3718
3719 #[test]
3720 fn looks_like_error_or_log_prefix_rejects_common_patterns() {
3721 assert!(super::looks_like_error_or_log_prefix("Invalid URL format"));
3722 assert!(super::looks_like_error_or_log_prefix(
3723 "Cannot decode response"
3724 ));
3725 assert!(super::looks_like_error_or_log_prefix("could not open file"));
3726 assert!(super::looks_like_error_or_log_prefix(
3727 "Failed to send request"
3728 ));
3729 assert!(super::looks_like_error_or_log_prefix(
3730 "Expected int, got str"
3731 ));
3732 assert!(super::looks_like_error_or_log_prefix(
3733 "sending request to server"
3734 ));
3735 assert!(super::looks_like_error_or_log_prefix(
3736 "received response headers"
3737 ));
3738 assert!(super::looks_like_error_or_log_prefix(
3739 "starting worker pool"
3740 ));
3741 assert!(!super::looks_like_error_or_log_prefix(
3743 "parse json body from request"
3744 ));
3745 assert!(!super::looks_like_error_or_log_prefix(
3746 "compute cosine similarity between vectors"
3747 ));
3748 assert!(!super::looks_like_error_or_log_prefix(
3749 "walk directory tree respecting gitignore"
3750 ));
3751 }
3752
3753 #[test]
3754 fn strict_mode_rejects_format_and_error_literals_during_extraction() {
3755 let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
3756 let previous = std::env::var("CODELENS_EMBED_HINT_STRICT_LITERALS").ok();
3760 unsafe {
3761 std::env::set_var("CODELENS_EMBED_HINT_STRICT_LITERALS", "1");
3762 }
3763 let source = "\
3764fn handle_request() {
3765 let err = \"Invalid URL %s\";
3766 let log = \"sending request to the upstream server\";
3767 let fmt = \"received {count} items in batch\";
3768 let real = \"parse json body from the incoming request\";
3769}
3770";
3771 let result = super::extract_nl_tokens_inner(source, 0, source.len());
3772 unsafe {
3773 match previous {
3774 Some(value) => std::env::set_var("CODELENS_EMBED_HINT_STRICT_LITERALS", value),
3775 None => std::env::remove_var("CODELENS_EMBED_HINT_STRICT_LITERALS"),
3776 }
3777 }
3778 let hint = result.expect("some token should survive");
3779 assert!(
3781 hint.contains("parse json body"),
3782 "real literal was filtered out: {hint}"
3783 );
3784 assert!(
3786 !hint.contains("Invalid URL"),
3787 "format-specifier literal leaked: {hint}"
3788 );
3789 assert!(
3790 !hint.contains("sending request"),
3791 "log-prefix literal leaked: {hint}"
3792 );
3793 assert!(
3794 !hint.contains("received {count}"),
3795 "python fstring literal leaked: {hint}"
3796 );
3797 }
3798
3799 #[test]
3800 fn strict_mode_leaves_comments_untouched() {
3801 let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
3802 let previous = std::env::var("CODELENS_EMBED_HINT_STRICT_LITERALS").ok();
3806 unsafe {
3807 std::env::set_var("CODELENS_EMBED_HINT_STRICT_LITERALS", "1");
3808 }
3809 let source = "\
3810fn do_work() {
3811 // Invalid inputs are rejected by this guard clause.
3812 // sending requests in parallel across worker threads.
3813 let _lit = \"format spec %s\";
3814}
3815";
3816 let result = super::extract_nl_tokens_inner(source, 0, source.len());
3817 unsafe {
3818 match previous {
3819 Some(value) => std::env::set_var("CODELENS_EMBED_HINT_STRICT_LITERALS", value),
3820 None => std::env::remove_var("CODELENS_EMBED_HINT_STRICT_LITERALS"),
3821 }
3822 }
3823 let hint = result.expect("comments should survive strict mode");
3824 assert!(
3827 hint.contains("Invalid inputs") || hint.contains("rejected by this guard"),
3828 "strict mode swallowed a comment: {hint}"
3829 );
3830 assert!(
3832 !hint.contains("format spec"),
3833 "format-specifier literal leaked under strict mode: {hint}"
3834 );
3835 }
3836
3837 #[test]
3838 fn should_reject_literal_strict_composes_format_and_prefix() {
3839 assert!(super::should_reject_literal_strict("Invalid URL %s"));
3843 assert!(super::should_reject_literal_strict(
3844 "sending request to server"
3845 ));
3846 assert!(super::should_reject_literal_strict("value: {x:.2f}"));
3847 assert!(!super::should_reject_literal_strict(
3849 "parse json body from the incoming request"
3850 ));
3851 assert!(!super::should_reject_literal_strict(
3852 "compute cosine similarity between vectors"
3853 ));
3854 }
3855
3856 #[test]
3857 fn is_static_method_ident_accepts_pascal_and_rejects_snake() {
3858 assert!(super::is_static_method_ident("HashMap"));
3859 assert!(super::is_static_method_ident("Parser"));
3860 assert!(super::is_static_method_ident("A"));
3861 assert!(!super::is_static_method_ident("std"));
3864 assert!(!super::is_static_method_ident("fs"));
3865 assert!(!super::is_static_method_ident("_private"));
3866 assert!(!super::is_static_method_ident(""));
3867 }
3868
3869 #[test]
3870 fn extract_api_calls_gated_off_by_default() {
3871 let _env_guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
3872 let previous = std::env::var("CODELENS_EMBED_HINT_INCLUDE_API_CALLS").ok();
3874 unsafe {
3875 std::env::remove_var("CODELENS_EMBED_HINT_INCLUDE_API_CALLS");
3876 }
3877 let source = "\
3878fn make_parser() {
3879 let p = Parser::new();
3880 let _ = HashMap::with_capacity(8);
3881}
3882";
3883 let result = extract_api_calls(source, 0, source.len());
3884 unsafe {
3885 if let Some(value) = previous {
3886 std::env::set_var("CODELENS_EMBED_HINT_INCLUDE_API_CALLS", value);
3887 }
3888 }
3889 assert!(result.is_none(), "gate leaked: {result:?}");
3890 }
3891
3892 #[test]
3893 fn extract_api_calls_captures_type_method_patterns() {
3894 let source = "\
3896fn open_db() {
3897 let p = Parser::new();
3898 let map = HashMap::with_capacity(16);
3899 let _ = tree_sitter::Parser::new();
3900}
3901";
3902 let hint = super::extract_api_calls_inner(source, 0, source.len())
3903 .expect("api calls should be produced");
3904 assert!(hint.contains("Parser::new"), "missing Parser::new: {hint}");
3905 assert!(
3906 hint.contains("HashMap::with_capacity"),
3907 "missing HashMap::with_capacity: {hint}"
3908 );
3909 }
3910
3911 #[test]
3912 fn extract_api_calls_rejects_module_prefixed_free_functions() {
3913 let source = "\
3916fn read_config() {
3917 let _ = std::fs::read_to_string(\"foo\");
3918 let _ = crate::util::parse();
3919}
3920";
3921 let hint = super::extract_api_calls_inner(source, 0, source.len());
3922 if let Some(hint) = hint {
3925 assert!(!hint.contains("std::fs"), "lowercase module leaked: {hint}");
3926 assert!(
3927 !hint.contains("fs::read_to_string"),
3928 "module-prefixed free function leaked: {hint}"
3929 );
3930 assert!(!hint.contains("crate::util"), "crate path leaked: {hint}");
3931 }
3932 }
3933
3934 #[test]
3935 fn extract_api_calls_deduplicates_repeated_calls() {
3936 let source = "\
3937fn hot_loop() {
3938 for _ in 0..10 {
3939 let _ = Parser::new();
3940 let _ = Parser::new();
3941 }
3942 let _ = Parser::new();
3943}
3944";
3945 let hint = super::extract_api_calls_inner(source, 0, source.len())
3946 .expect("api calls should be produced");
3947 let first = hint.find("Parser::new").expect("hit");
3948 let rest = &hint[first + "Parser::new".len()..];
3949 assert!(
3950 !rest.contains("Parser::new"),
3951 "duplicate not deduplicated: {hint}"
3952 );
3953 }
3954
3955 #[test]
3956 fn extract_api_calls_returns_none_when_body_has_no_type_calls() {
3957 let source = "\
3958fn plain() {
3959 let x = 1;
3960 let y = x + 2;
3961}
3962";
3963 assert!(super::extract_api_calls_inner(source, 0, source.len()).is_none());
3964 }
3965
3966 #[test]
3967 fn extract_nl_tokens_collects_comments_and_string_literals() {
3968 let source = "\
3972fn search_for_matches() {
3973 // skip comments and string literals during search
3974 let error = \"failed to open database\";
3975 let single = \"tok\";
3976 let path = \"src/foo/bar\";
3977 let keyword = match kind {
3978 Kind::Ident => \"detect client version\",
3979 _ => \"\",
3980 };
3981}
3982";
3983 let hint = super::extract_nl_tokens_inner(source, 0, source.len())
3989 .expect("nl tokens should be produced");
3990 let has_first_nl_signal = hint.contains("skip comments")
3994 || hint.contains("failed to open")
3995 || hint.contains("detect client");
3996 assert!(has_first_nl_signal, "no NL signal produced: {hint}");
3997 assert!(!hint.contains(" tok "), "short literal leaked: {hint}");
3999 assert!(!hint.contains("src/foo/bar"), "path literal leaked: {hint}");
4001 }
4002
4003 #[test]
4004 fn hint_char_budget_respects_env_override() {
4005 let previous = std::env::var("CODELENS_EMBED_HINT_CHARS").ok();
4006 unsafe {
4007 std::env::set_var("CODELENS_EMBED_HINT_CHARS", "120");
4008 }
4009 let budget = super::hint_char_budget();
4010 unsafe {
4011 match previous {
4012 Some(value) => std::env::set_var("CODELENS_EMBED_HINT_CHARS", value),
4013 None => std::env::remove_var("CODELENS_EMBED_HINT_CHARS"),
4014 }
4015 }
4016 assert_eq!(budget, 120);
4017 }
4018
4019 #[test]
4020 fn embedding_to_bytes_roundtrip() {
4021 let floats = vec![1.0f32, -0.5, 0.0, 3.25];
4022 let bytes = embedding_to_bytes(&floats);
4023 assert_eq!(bytes.len(), 4 * 4);
4024 let recovered: Vec<f32> = bytes
4026 .chunks_exact(4)
4027 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
4028 .collect();
4029 assert_eq!(floats, recovered);
4030 }
4031
4032 #[test]
4033 fn duplicate_pair_key_is_order_independent() {
4034 let a = duplicate_pair_key("a.py", "foo", "b.py", "bar");
4035 let b = duplicate_pair_key("b.py", "bar", "a.py", "foo");
4036 assert_eq!(a, b);
4037 }
4038
    #[test]
    fn text_embedding_cache_updates_recency() {
        // LRU cache with capacity 2: a successful get must refresh recency.
        let mut cache = TextEmbeddingCache::new(2);
        cache.insert("a".into(), vec![1.0]);
        cache.insert("b".into(), vec![2.0]);
        // Touch "a" so "b" becomes the least recently used entry.
        assert_eq!(cache.get("a"), Some(vec![1.0]));
        cache.insert("c".into(), vec![3.0]);

        // Inserting "c" at capacity must have evicted "b", not "a".
        assert_eq!(cache.get("a"), Some(vec![1.0]));
        assert_eq!(cache.get("b"), None);
        assert_eq!(cache.get("c"), Some(vec![3.0]));
    }
4051
4052 #[test]
4053 fn text_embedding_cache_can_be_disabled() {
4054 let mut cache = TextEmbeddingCache::new(0);
4055 cache.insert("a".into(), vec![1.0]);
4056 assert_eq!(cache.get("a"), None);
4057 }
4058
4059 #[test]
4060 fn engine_new_and_index() {
4061 let _lock = MODEL_LOCK.lock().unwrap();
4062 skip_without_embedding_model!();
4063 let (_dir, project) = make_project_with_source();
4064 let engine = EmbeddingEngine::new(&project).expect("engine should load");
4065 assert!(!engine.is_indexed());
4066
4067 let count = engine.index_from_project(&project).unwrap();
4068 assert_eq!(count, 2, "should index 2 symbols");
4069 assert!(engine.is_indexed());
4070 }
4071
4072 #[test]
4073 fn engine_search_returns_results() {
4074 let _lock = MODEL_LOCK.lock().unwrap();
4075 skip_without_embedding_model!();
4076 let (_dir, project) = make_project_with_source();
4077 let engine = EmbeddingEngine::new(&project).unwrap();
4078 engine.index_from_project(&project).unwrap();
4079
4080 let results = engine.search("hello function", 10).unwrap();
4081 assert!(!results.is_empty(), "search should return results");
4082 for r in &results {
4083 assert!(
4084 r.score >= -1.0 && r.score <= 1.0,
4085 "score should be in [-1,1]: {}",
4086 r.score
4087 );
4088 }
4089 }
4090
4091 #[test]
4092 fn engine_incremental_index() {
4093 let _lock = MODEL_LOCK.lock().unwrap();
4094 skip_without_embedding_model!();
4095 let (_dir, project) = make_project_with_source();
4096 let engine = EmbeddingEngine::new(&project).unwrap();
4097 engine.index_from_project(&project).unwrap();
4098 assert_eq!(engine.store.count().unwrap(), 2);
4099
4100 let count = engine.index_changed_files(&project, &["main.py"]).unwrap();
4102 assert_eq!(count, 2);
4103 assert_eq!(engine.store.count().unwrap(), 2);
4104 }
4105
4106 #[test]
4107 fn engine_reindex_preserves_symbol_count() {
4108 let _lock = MODEL_LOCK.lock().unwrap();
4109 skip_without_embedding_model!();
4110 let (_dir, project) = make_project_with_source();
4111 let engine = EmbeddingEngine::new(&project).unwrap();
4112 engine.index_from_project(&project).unwrap();
4113 assert_eq!(engine.store.count().unwrap(), 2);
4114
4115 let count = engine.index_from_project(&project).unwrap();
4116 assert_eq!(count, 2);
4117 assert_eq!(engine.store.count().unwrap(), 2);
4118 }
4119
4120 #[test]
4121 fn full_reindex_reuses_unchanged_embeddings() {
4122 let _lock = MODEL_LOCK.lock().unwrap();
4123 skip_without_embedding_model!();
4124 let (_dir, project) = make_project_with_source();
4125 let engine = EmbeddingEngine::new(&project).unwrap();
4126 engine.index_from_project(&project).unwrap();
4127
4128 replace_file_embeddings_with_sentinels(
4129 &engine,
4130 "main.py",
4131 &[("hello", 11.0), ("world", 22.0)],
4132 );
4133
4134 let count = engine.index_from_project(&project).unwrap();
4135 assert_eq!(count, 2);
4136
4137 let hello = engine
4138 .store
4139 .get_embedding("main.py", "hello")
4140 .unwrap()
4141 .expect("hello should exist");
4142 let world = engine
4143 .store
4144 .get_embedding("main.py", "world")
4145 .unwrap()
4146 .expect("world should exist");
4147 assert!(hello.embedding.iter().all(|value| *value == 11.0));
4148 assert!(world.embedding.iter().all(|value| *value == 22.0));
4149 }
4150
    #[test]
    fn full_reindex_reuses_unchanged_sibling_after_edit() {
        let _lock = MODEL_LOCK.lock().unwrap();
        skip_without_embedding_model!();
        let (dir, project) = make_project_with_source();
        let engine = EmbeddingEngine::new(&project).unwrap();
        engine.index_from_project(&project).unwrap();

        // Plant sentinel vectors so reuse is distinguishable from recompute.
        replace_file_embeddings_with_sentinels(
            &engine,
            "main.py",
            &[("hello", 11.0), ("world", 22.0)],
        );

        // Rewrite main.py under a new content hash; only `world` changes shape
        // (presumably `hello` matches the original fixture — confirm against
        // make_project_with_source).
        let updated_source =
            "def hello():\n print('hi')\n\ndef world(name):\n return name.upper()\n";
        write_python_file_with_symbols(
            dir.path(),
            "main.py",
            updated_source,
            "hash2",
            &[
                ("hello", "def hello():", "hello"),
                ("world", "def world(name):", "world"),
            ],
        );

        let count = engine.index_from_project(&project).unwrap();
        assert_eq!(count, 2);

        let hello = engine
            .store
            .get_embedding("main.py", "hello")
            .unwrap()
            .expect("hello should exist");
        let world = engine
            .store
            .get_embedding("main.py", "world")
            .unwrap()
            .expect("world should exist");
        // Unchanged `hello` keeps its sentinel (embedding reused); edited
        // `world` must have been re-embedded (sentinel overwritten).
        assert!(hello.embedding.iter().all(|value| *value == 11.0));
        assert!(world.embedding.iter().any(|value| *value != 22.0));
        assert_eq!(engine.store.count().unwrap(), 2);
    }
4195
    #[test]
    fn full_reindex_removes_deleted_files() {
        let _lock = MODEL_LOCK.lock().unwrap();
        skip_without_embedding_model!();
        let (dir, project) = make_project_with_source();
        // Add a third symbol in a second file so a deletion is observable.
        write_python_file_with_symbols(
            dir.path(),
            "extra.py",
            "def bonus():\n return 7\n",
            "hash-extra",
            &[("bonus", "def bonus():", "bonus")],
        );

        let engine = EmbeddingEngine::new(&project).unwrap();
        engine.index_from_project(&project).unwrap();
        assert_eq!(engine.store.count().unwrap(), 3);

        // Remove the file on disk AND drop it from the symbol index so the
        // reindex sees it as gone from both sources of truth.
        std::fs::remove_file(dir.path().join("extra.py")).unwrap();
        let db_path = crate::db::index_db_path(dir.path());
        let db = IndexDb::open(&db_path).unwrap();
        db.delete_file("extra.py").unwrap();

        // A full reindex must purge extra.py's embeddings entirely.
        let count = engine.index_from_project(&project).unwrap();
        assert_eq!(count, 2);
        assert_eq!(engine.store.count().unwrap(), 2);
        assert!(
            engine
                .store
                .embeddings_for_files(&["extra.py"])
                .unwrap()
                .is_empty()
        );
    }
4229
4230 #[test]
4231 fn engine_model_change_recreates_db() {
4232 let _lock = MODEL_LOCK.lock().unwrap();
4233 skip_without_embedding_model!();
4234 let (_dir, project) = make_project_with_source();
4235
4236 let engine1 = EmbeddingEngine::new(&project).unwrap();
4238 engine1.index_from_project(&project).unwrap();
4239 assert_eq!(engine1.store.count().unwrap(), 2);
4240 drop(engine1);
4241
4242 let engine2 = EmbeddingEngine::new(&project).unwrap();
4244 assert!(engine2.store.count().unwrap() >= 2);
4245 }
4246
4247 #[test]
4248 fn inspect_existing_index_returns_model_and_count() {
4249 let _lock = MODEL_LOCK.lock().unwrap();
4250 skip_without_embedding_model!();
4251 let (_dir, project) = make_project_with_source();
4252 let engine = EmbeddingEngine::new(&project).unwrap();
4253 engine.index_from_project(&project).unwrap();
4254
4255 let info = EmbeddingEngine::inspect_existing_index(&project)
4256 .unwrap()
4257 .expect("index info should exist");
4258 assert_eq!(info.model_name, engine.model_name());
4259 assert_eq!(info.indexed_symbols, 2);
4260 }
4261
    #[test]
    fn inspect_existing_index_recovers_from_corrupt_db() {
        let (_dir, project) = make_project_with_source();
        let index_dir = project.as_path().join(".codelens/index");
        let db_path = index_dir.join("embeddings.db");
        let wal_path = index_dir.join("embeddings.db-wal");
        let shm_path = index_dir.join("embeddings.db-shm");

        // Corrupt the database and its WAL/SHM companions with junk bytes.
        std::fs::write(&db_path, b"not a sqlite database").unwrap();
        std::fs::write(&wal_path, b"bad wal").unwrap();
        std::fs::write(&shm_path, b"bad shm").unwrap();

        // Inspection must not error out — it reports "no index" instead.
        let info = EmbeddingEngine::inspect_existing_index(&project).unwrap();
        assert!(info.is_none());

        // A db file should still exist at the original path afterwards.
        assert!(db_path.is_file());

        // The corrupt original should be quarantined under a `.corrupt-`
        // suffixed name rather than silently destroyed.
        let backup_names: Vec<String> = std::fs::read_dir(&index_dir)
            .unwrap()
            .map(|entry| entry.unwrap().file_name().to_string_lossy().into_owned())
            .filter(|name| name.contains(".corrupt-"))
            .collect();

        assert!(
            backup_names
                .iter()
                .any(|name| name.starts_with("embeddings.db.corrupt-")),
            "expected quarantined embedding db, found {backup_names:?}"
        );
    }
4292
4293 #[test]
4294 fn store_can_fetch_single_embedding_without_loading_all() {
4295 let _lock = MODEL_LOCK.lock().unwrap();
4296 skip_without_embedding_model!();
4297 let (_dir, project) = make_project_with_source();
4298 let engine = EmbeddingEngine::new(&project).unwrap();
4299 engine.index_from_project(&project).unwrap();
4300
4301 let chunk = engine
4302 .store
4303 .get_embedding("main.py", "hello")
4304 .unwrap()
4305 .expect("embedding should exist");
4306 assert_eq!(chunk.file_path, "main.py");
4307 assert_eq!(chunk.symbol_name, "hello");
4308 assert!(!chunk.embedding.is_empty());
4309 }
4310
4311 #[test]
4312 fn find_similar_code_uses_index_and_excludes_target_symbol() {
4313 let _lock = MODEL_LOCK.lock().unwrap();
4314 skip_without_embedding_model!();
4315 let (_dir, project) = make_project_with_source();
4316 let engine = EmbeddingEngine::new(&project).unwrap();
4317 engine.index_from_project(&project).unwrap();
4318
4319 let matches = engine.find_similar_code("main.py", "hello", 5).unwrap();
4320 assert!(!matches.is_empty());
4321 assert!(
4322 matches
4323 .iter()
4324 .all(|m| !(m.file_path == "main.py" && m.symbol_name == "hello"))
4325 );
4326 }
4327
4328 #[test]
4329 fn delete_by_file_removes_rows_in_one_batch() {
4330 let _lock = MODEL_LOCK.lock().unwrap();
4331 skip_without_embedding_model!();
4332 let (_dir, project) = make_project_with_source();
4333 let engine = EmbeddingEngine::new(&project).unwrap();
4334 engine.index_from_project(&project).unwrap();
4335
4336 let deleted = engine.store.delete_by_file(&["main.py"]).unwrap();
4337 assert_eq!(deleted, 2);
4338 assert_eq!(engine.store.count().unwrap(), 0);
4339 }
4340
4341 #[test]
4342 fn store_streams_embeddings_grouped_by_file() {
4343 let _lock = MODEL_LOCK.lock().unwrap();
4344 skip_without_embedding_model!();
4345 let (_dir, project) = make_project_with_source();
4346 let engine = EmbeddingEngine::new(&project).unwrap();
4347 engine.index_from_project(&project).unwrap();
4348
4349 let mut groups = Vec::new();
4350 engine
4351 .store
4352 .for_each_file_embeddings(&mut |file_path, chunks| {
4353 groups.push((file_path, chunks.len()));
4354 Ok(())
4355 })
4356 .unwrap();
4357
4358 assert_eq!(groups, vec![("main.py".to_string(), 2)]);
4359 }
4360
4361 #[test]
4362 fn store_fetches_embeddings_for_specific_files() {
4363 let _lock = MODEL_LOCK.lock().unwrap();
4364 skip_without_embedding_model!();
4365 let (_dir, project) = make_project_with_source();
4366 let engine = EmbeddingEngine::new(&project).unwrap();
4367 engine.index_from_project(&project).unwrap();
4368
4369 let chunks = engine.store.embeddings_for_files(&["main.py"]).unwrap();
4370 assert_eq!(chunks.len(), 2);
4371 assert!(chunks.iter().all(|chunk| chunk.file_path == "main.py"));
4372 }
4373
4374 #[test]
4375 fn store_fetches_embeddings_for_scored_chunks() {
4376 let _lock = MODEL_LOCK.lock().unwrap();
4377 skip_without_embedding_model!();
4378 let (_dir, project) = make_project_with_source();
4379 let engine = EmbeddingEngine::new(&project).unwrap();
4380 engine.index_from_project(&project).unwrap();
4381
4382 let scored = engine.search_scored("hello world function", 2).unwrap();
4383 let chunks = engine.store.embeddings_for_scored_chunks(&scored).unwrap();
4384
4385 assert_eq!(chunks.len(), scored.len());
4386 assert!(scored.iter().all(|candidate| chunks.iter().any(|chunk| {
4387 chunk.file_path == candidate.file_path
4388 && chunk.symbol_name == candidate.symbol_name
4389 && chunk.line == candidate.line
4390 && chunk.signature == candidate.signature
4391 && chunk.name_path == candidate.name_path
4392 })));
4393 }
4394
4395 #[test]
4396 fn find_misplaced_code_returns_per_file_outliers() {
4397 let _lock = MODEL_LOCK.lock().unwrap();
4398 skip_without_embedding_model!();
4399 let (_dir, project) = make_project_with_source();
4400 let engine = EmbeddingEngine::new(&project).unwrap();
4401 engine.index_from_project(&project).unwrap();
4402
4403 let outliers = engine.find_misplaced_code(5).unwrap();
4404 assert_eq!(outliers.len(), 2);
4405 assert!(outliers.iter().all(|item| item.file_path == "main.py"));
4406 }
4407
4408 #[test]
4409 fn find_duplicates_uses_batched_candidate_embeddings() {
4410 let _lock = MODEL_LOCK.lock().unwrap();
4411 skip_without_embedding_model!();
4412 let (_dir, project) = make_project_with_source();
4413 let engine = EmbeddingEngine::new(&project).unwrap();
4414 engine.index_from_project(&project).unwrap();
4415
4416 replace_file_embeddings_with_sentinels(
4417 &engine,
4418 "main.py",
4419 &[("hello", 5.0), ("world", 5.0)],
4420 );
4421
4422 let duplicates = engine.find_duplicates(0.99, 4).unwrap();
4423 assert!(!duplicates.is_empty());
4424 assert!(duplicates.iter().any(|pair| {
4425 (pair.symbol_a == "main.py:hello" && pair.symbol_b == "main.py:world")
4426 || (pair.symbol_a == "main.py:world" && pair.symbol_b == "main.py:hello")
4427 }));
4428 }
4429
4430 #[test]
4431 fn search_scored_returns_raw_chunks() {
4432 let _lock = MODEL_LOCK.lock().unwrap();
4433 skip_without_embedding_model!();
4434 let (_dir, project) = make_project_with_source();
4435 let engine = EmbeddingEngine::new(&project).unwrap();
4436 engine.index_from_project(&project).unwrap();
4437
4438 let chunks = engine.search_scored("world function", 5).unwrap();
4439 assert!(!chunks.is_empty());
4440 for c in &chunks {
4441 assert!(!c.file_path.is_empty());
4442 assert!(!c.symbol_name.is_empty());
4443 }
4444 }
4445
4446 #[test]
4447 fn configured_embedding_model_name_defaults_to_codesearchnet() {
4448 assert_eq!(configured_embedding_model_name(), CODESEARCH_MODEL_NAME);
4449 }
4450
4451 #[test]
4452 fn requested_embedding_model_override_ignores_default_model_name() {
4453 let previous = std::env::var("CODELENS_EMBED_MODEL").ok();
4454 unsafe {
4455 std::env::set_var("CODELENS_EMBED_MODEL", CODESEARCH_MODEL_NAME);
4456 }
4457
4458 let result = requested_embedding_model_override().unwrap();
4459
4460 unsafe {
4461 match previous {
4462 Some(value) => std::env::set_var("CODELENS_EMBED_MODEL", value),
4463 None => std::env::remove_var("CODELENS_EMBED_MODEL"),
4464 }
4465 }
4466
4467 assert_eq!(result, None);
4468 }
4469
4470 #[cfg(not(feature = "model-bakeoff"))]
4471 #[test]
4472 fn requested_embedding_model_override_requires_bakeoff_feature() {
4473 let previous = std::env::var("CODELENS_EMBED_MODEL").ok();
4474 unsafe {
4475 std::env::set_var("CODELENS_EMBED_MODEL", "all-MiniLM-L12-v2");
4476 }
4477
4478 let err = requested_embedding_model_override().unwrap_err();
4479
4480 unsafe {
4481 match previous {
4482 Some(value) => std::env::set_var("CODELENS_EMBED_MODEL", value),
4483 None => std::env::remove_var("CODELENS_EMBED_MODEL"),
4484 }
4485 }
4486
4487 assert!(err.to_string().contains("model-bakeoff"));
4488 }
4489
4490 #[cfg(feature = "model-bakeoff")]
4491 #[test]
4492 fn requested_embedding_model_override_accepts_alternative_model() {
4493 let previous = std::env::var("CODELENS_EMBED_MODEL").ok();
4494 unsafe {
4495 std::env::set_var("CODELENS_EMBED_MODEL", "all-MiniLM-L12-v2");
4496 }
4497
4498 let result = requested_embedding_model_override().unwrap();
4499
4500 unsafe {
4501 match previous {
4502 Some(value) => std::env::set_var("CODELENS_EMBED_MODEL", value),
4503 None => std::env::remove_var("CODELENS_EMBED_MODEL"),
4504 }
4505 }
4506
4507 assert_eq!(result.as_deref(), Some("all-MiniLM-L12-v2"));
4508 }
4509
4510 #[test]
4511 fn recommended_embed_threads_caps_macos_style_load() {
4512 let threads = recommended_embed_threads();
4513 assert!(threads >= 1);
4514 assert!(threads <= 8);
4515 }
4516
4517 #[test]
4518 fn embed_batch_size_has_safe_default_floor() {
4519 assert!(embed_batch_size() >= 1);
4520 if cfg!(target_os = "macos") {
4521 assert!(embed_batch_size() <= DEFAULT_MACOS_EMBED_BATCH_SIZE);
4522 }
4523 }
4524}