1use anyhow::{Context, Result};
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use sha2::{Digest, Sha256};
10use std::collections::{HashMap, HashSet};
11use std::path::{Path, PathBuf};
12use std::sync::Arc;
13use std::time::Duration;
14use walkdir::{DirEntry, WalkDir};
15
/// Bumped whenever the serialized index layout changes incompatibly.
const INDEX_VERSION: u32 = 3;
/// Version tag for the embedded knowledge-graph payload.
const KNOWLEDGE_GRAPH_VERSION: u32 = 1;
// Provider/model identifiers: the built-in hash embedder, the defaults used
// when CLI values are blank, and the placeholders recorded when embeddings
// are disabled.
const LOCAL_EMBEDDING_PROVIDER: &str = "local/hash-embedding";
const DEFAULT_EMBEDDING_PROVIDER: &str = "local";
const DEFAULT_EMBEDDING_MODEL: &str = "hash-v1";
const DISABLED_EMBEDDING_PROVIDER: &str = "disabled";
const DISABLED_EMBEDDING_MODEL: &str = "disabled";
// Embedding tuning defaults: vector size, batching, input truncation, and
// retry/backoff bounds for remote providers.
const DEFAULT_EMBEDDING_DIMENSIONS: usize = 384;
const DEFAULT_EMBEDDING_BATCH_SIZE: usize = 32;
const DEFAULT_EMBEDDING_INPUT_CHARS: usize = 8_000;
const DEFAULT_EMBEDDING_MAX_RETRIES: u32 = 3;
const DEFAULT_EMBEDDING_RETRY_INITIAL_MS: u64 = 250;
const DEFAULT_EMBEDDING_RETRY_MAX_MS: u64 = 2_000;
/// Per-file size cap (512 KiB) for the implicit per-run knowledge snapshot.
const DEFAULT_RUN_KNOWLEDGE_MAX_FILE_SIZE_BYTES: u64 = 512 * 1024;
30
/// Top-level serializable snapshot of an indexed workspace: per-file entries,
/// aggregate stats, and the derived knowledge graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodebaseIndex {
    /// Format version (`INDEX_VERSION`).
    pub version: u32,
    /// Workspace root rendered as a display string.
    pub root: String,
    /// Timestamp of index construction.
    pub generated_at: DateTime<Utc>,
    // Backend that produced the embeddings ("disabled"/"disabled" when
    // embeddings were turned off).
    pub embedding_provider: String,
    pub embedding_model: String,
    pub stats: IndexStats,
    /// Indexed files, sorted by path for deterministic output.
    pub files: Vec<IndexedFile>,
    pub knowledge_graph: KnowledgeGraph,
}
42
/// Aggregate counters accumulated while walking the tree and embedding files.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct IndexStats {
    /// Every file the walker visited, including ones later skipped.
    pub total_seen_files: u64,
    /// Files that made it into `CodebaseIndex::files`.
    pub indexed_files: u64,
    // Skip counters, broken down by reason.
    pub skipped_hidden: u64,
    pub skipped_non_text: u64,
    pub skipped_large: u64,
    pub skipped_io_errors: u64,
    // Totals over indexed files only.
    pub total_bytes: u64,
    pub total_lines: u64,
    /// Files that received an embedding vector.
    pub embedded_files: u64,
    /// Dimension every embedding vector is validated against.
    pub embedding_dimensions: u32,
    // Token accounting from the embedding backend (approximated by word
    // count for the local hash embedder).
    pub embedding_prompt_tokens: u64,
    pub embedding_total_tokens: u64,
    /// Indexed-file count per detected language.
    pub language_counts: HashMap<String, u64>,
}
59
/// One indexed source file plus its (possibly empty) embedding vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexedFile {
    /// Path relative to the index root (lossy UTF-8 conversion).
    pub path: String,
    /// Language detected from the file path.
    pub language: String,
    pub bytes: u64,
    pub lines: u32,
    /// Heuristic count of symbol definitions (see `estimate_symbol_hints`).
    pub symbol_hints: u32,
    /// mtime in Unix milliseconds, when the filesystem provides one.
    pub modified_unix_ms: Option<i64>,
    /// Embedding vector; left empty when embeddings are disabled.
    pub embedding: Vec<f32>,
}
70
/// Node and edge counts for a built `KnowledgeGraph`, by node kind.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct KnowledgeGraphStats {
    pub file_nodes: u64,
    /// Symbols defined by indexed files.
    pub symbol_nodes: u64,
    /// External module nodes created for imports.
    pub module_nodes: u64,
    /// Imported symbols that did not resolve to a local definition.
    pub symbol_reference_nodes: u64,
    pub edges: u64,
}
79
/// File/symbol/module graph derived from lightweight per-line source analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeGraph {
    /// Format version (`KNOWLEDGE_GRAPH_VERSION`).
    pub version: u32,
    /// Nodes sorted by id for deterministic serialization.
    pub nodes: Vec<KnowledgeNode>,
    /// Deduplicated edges sorted by (source, target, kind).
    pub edges: Vec<KnowledgeEdge>,
    pub stats: KnowledgeGraphStats,
}
87
/// A graph node: a file, a locally defined symbol, an imported module, or an
/// unresolved imported-symbol reference.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeNode {
    /// Stable id, e.g. `file:<path>` or `symbol:<path>:<line>:<name>`.
    pub id: String,
    /// One of "file", "symbol", "module", "symbol_ref".
    pub kind: String,
    /// Display name: file path, symbol name, or module path.
    pub label: String,
    // Location metadata; None for external module / symbol_ref nodes.
    pub file_path: Option<String>,
    pub language: Option<String>,
    /// For symbol nodes: "function", "struct", "class", etc.
    pub symbol_kind: Option<String>,
    /// 1-based definition line for symbol nodes.
    pub line: Option<u32>,
    /// True for nodes not backed by an indexed file.
    pub external: bool,
}
99
/// Directed edge; `kind` is "defines", "imports_module", or "imports_symbol".
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeEdge {
    pub source: String,
    pub target: String,
    pub kind: String,
}
106
/// Tunables for `build_index`; see the `Default` impl for baseline values.
#[derive(Debug, Clone)]
pub struct BuildOptions {
    /// Also index dot-files/dot-directories.
    pub include_hidden: bool,
    /// Compute embeddings; when false the index records "disabled"/"disabled".
    pub include_embeddings: bool,
    /// Files larger than this many bytes are skipped.
    pub max_file_size_bytes: u64,
    // Backend selection: "local"/"hash"-style values pick the built-in hash
    // embedder, anything else is resolved through the provider registry.
    pub embedding_provider: String,
    pub embedding_model: String,
    /// Expected vector dimension (clamped to >= 64 where it is consumed).
    pub embedding_dimensions: usize,
    /// Inputs per embedding request.
    pub embedding_batch_size: usize,
    /// Maximum characters of file content fed to the embedder.
    pub embedding_input_chars: usize,
    // Retry policy for remote embedding calls (exponential backoff).
    pub embedding_max_retries: u32,
    pub embedding_retry_initial_ms: u64,
    pub embedding_retry_max_ms: u64,
}
121
122impl Default for BuildOptions {
123 fn default() -> Self {
124 Self {
125 include_hidden: false,
126 include_embeddings: true,
127 max_file_size_bytes: 1024 * 1024,
128 embedding_provider: DEFAULT_EMBEDDING_PROVIDER.to_string(),
129 embedding_model: DEFAULT_EMBEDDING_MODEL.to_string(),
130 embedding_dimensions: DEFAULT_EMBEDDING_DIMENSIONS,
131 embedding_batch_size: DEFAULT_EMBEDDING_BATCH_SIZE,
132 embedding_input_chars: DEFAULT_EMBEDDING_INPUT_CHARS,
133 embedding_max_retries: DEFAULT_EMBEDDING_MAX_RETRIES,
134 embedding_retry_initial_ms: DEFAULT_EMBEDDING_RETRY_INITIAL_MS,
135 embedding_retry_max_ms: DEFAULT_EMBEDDING_RETRY_MAX_MS,
136 }
137 }
138}
139
/// Entry point for the `index` CLI command: builds a full index (embeddings
/// enabled) for the target directory, writes it as pretty JSON, then prints a
/// JSON or human-readable summary depending on `args.json`.
pub async fn run(args: crate::cli::IndexArgs) -> Result<()> {
    // Default to the current directory; fall back to "." if even that fails.
    let root = args
        .path
        .clone()
        .unwrap_or_else(|| std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")));
    // Canonicalize for stable paths; keep the original path on failure.
    let root = root.canonicalize().unwrap_or_else(|_| root.clone());

    let options = BuildOptions {
        include_hidden: args.include_hidden,
        include_embeddings: true,
        max_file_size_bytes: args.max_file_size_kib.saturating_mul(1024),
        // Blank provider/model strings fall back to the local hash embedder.
        embedding_provider: if args.embedding_provider.trim().is_empty() {
            DEFAULT_EMBEDDING_PROVIDER.to_string()
        } else {
            args.embedding_provider.clone()
        },
        embedding_model: if args.embedding_model.trim().is_empty() {
            DEFAULT_EMBEDDING_MODEL.to_string()
        } else {
            args.embedding_model.clone()
        },
        // Clamp user-supplied tuning values to sane minimums.
        embedding_dimensions: args.embedding_dimensions.max(64),
        embedding_batch_size: args.embedding_batch_size.max(1),
        embedding_input_chars: args.embedding_input_chars.max(256),
        embedding_max_retries: args.embedding_max_retries,
        embedding_retry_initial_ms: args.embedding_retry_initial_ms.max(1),
        // The backoff ceiling must be at least the initial delay.
        embedding_retry_max_ms: args
            .embedding_retry_max_ms
            .max(args.embedding_retry_initial_ms.max(1)),
    };

    let index = build_index(&root, &options).await?;
    let output_path = args.output.unwrap_or_else(|| default_index_path(&root));

    // Make sure the destination directory exists before writing.
    if let Some(parent) = output_path.parent() {
        tokio::fs::create_dir_all(parent).await?;
    }

    let encoded = serde_json::to_string_pretty(&index)?;
    tokio::fs::write(&output_path, encoded).await?;

    if args.json {
        // Machine-readable summary: stats and graph stats only, not the
        // (potentially huge) per-file list.
        let payload = serde_json::json!({
            "index_path": output_path,
            "root": index.root,
            "generated_at": index.generated_at,
            "embedding_provider": index.embedding_provider,
            "embedding_model": index.embedding_model,
            "stats": index.stats,
            "knowledge_graph": index.knowledge_graph.stats,
        });
        println!("{}", serde_json::to_string_pretty(&payload)?);
    } else {
        println!("# Codebase Index Built\n");
        println!("- Root: {}", index.root);
        println!("- Output: {}", output_path.display());
        println!(
            "- Embeddings: {}/{}",
            index.embedding_provider, index.embedding_model
        );
        println!("- Indexed files: {}", index.stats.indexed_files);
        println!("- Embedded files: {}", index.stats.embedded_files);
        println!(
            "- Embedding dimensions: {}",
            index.stats.embedding_dimensions
        );
        println!("- Total lines: {}", index.stats.total_lines);
        println!("- Total bytes: {}", index.stats.total_bytes);
        println!(
            "- Knowledge graph: {} nodes / {} edges",
            index.knowledge_graph.nodes.len(),
            index.knowledge_graph.edges.len()
        );
        if !index.stats.language_counts.is_empty() {
            // Top 10 languages by file count, ties broken alphabetically.
            let mut langs: Vec<_> = index.stats.language_counts.iter().collect();
            langs.sort_by(|a, b| b.1.cmp(a.1).then_with(|| a.0.cmp(b.0)));
            println!("\nTop languages:");
            for (lang, count) in langs.into_iter().take(10) {
                println!("- {}: {} files", lang, count);
            }
        }
    }

    Ok(())
}
225
226pub async fn refresh_workspace_knowledge_snapshot(root: &Path) -> Result<PathBuf> {
227 let root = root.canonicalize().unwrap_or_else(|_| root.to_path_buf());
228 let options = BuildOptions {
229 include_hidden: false,
230 include_embeddings: false,
231 max_file_size_bytes: DEFAULT_RUN_KNOWLEDGE_MAX_FILE_SIZE_BYTES,
232 ..BuildOptions::default()
233 };
234 let index = build_index(&root, &options).await?;
235 let output_path = default_knowledge_graph_path(&root);
236
237 if let Some(parent) = output_path.parent() {
238 tokio::fs::create_dir_all(parent).await?;
239 }
240
241 let encoded = serde_json::to_string_pretty(&index)?;
242 tokio::fs::write(&output_path, encoded).await?;
243 Ok(output_path)
244}
245
/// Per-file analysis output consumed by `build_knowledge_graph`.
#[derive(Debug, Clone)]
struct AnalyzedFileKnowledge {
    /// The file's own graph node.
    file_node: KnowledgeNode,
    /// Symbols defined in this file.
    symbol_nodes: Vec<KnowledgeNode>,
    /// Sorted, deduplicated module paths the file imports.
    imported_modules: Vec<String>,
    /// Sorted, deduplicated symbol names the file imports.
    imported_symbols: Vec<String>,
}
253
/// Walks `root`, indexes every text file that passes the filters, optionally
/// embeds each file's content, and derives the knowledge graph.
///
/// Skip reasons (hidden, too large, non-text, IO error) are tallied in the
/// returned `IndexStats` rather than failing the build.
pub async fn build_index(root: &Path, options: &BuildOptions) -> Result<CodebaseIndex> {
    let mut stats = IndexStats::default();
    let mut files = Vec::new();
    // Parallel arrays: entry i in each corresponds to files[i] at push time.
    let mut embedding_inputs = Vec::new();
    let mut knowledge_inputs = Vec::new();

    let walker = WalkDir::new(root)
        .follow_links(false)
        .into_iter()
        // Prune whole subtrees (e.g. hidden dirs) before descending.
        .filter_entry(|entry| should_descend(entry, root, options.include_hidden));

    // Walk errors are silently dropped; per-file IO errors are counted below.
    for entry in walker.filter_map(std::result::Result::ok) {
        let path = entry.path();
        if !path.is_file() {
            continue;
        }

        stats.total_seen_files += 1;

        let rel_path = path.strip_prefix(root).unwrap_or(path);

        if !options.include_hidden && is_hidden_path(rel_path) {
            stats.skipped_hidden += 1;
            continue;
        }

        let metadata = match std::fs::metadata(path) {
            Ok(meta) => meta,
            Err(_) => {
                stats.skipped_io_errors += 1;
                continue;
            }
        };

        if metadata.len() > options.max_file_size_bytes {
            stats.skipped_large += 1;
            continue;
        }

        if !is_probably_text_file(path) {
            stats.skipped_non_text += 1;
            continue;
        }

        // A read failure here (e.g. invalid UTF-8) also counts as non-text.
        let content = match std::fs::read_to_string(path) {
            Ok(text) => text,
            Err(_) => {
                stats.skipped_non_text += 1;
                continue;
            }
        };

        let language = detect_language(path);
        // Newline count + 1; a file with a trailing newline thus counts the
        // empty final line as well.
        let lines = if content.is_empty() {
            0
        } else {
            (content.as_bytes().iter().filter(|b| **b == b'\n').count() + 1) as u32
        };
        let symbol_hints = estimate_symbol_hints(path, &content);

        let rel_path = rel_path.to_string_lossy().to_string();

        let modified_unix_ms = metadata
            .modified()
            .ok()
            .and_then(|ts| ts.duration_since(std::time::UNIX_EPOCH).ok())
            .map(|dur| dur.as_millis() as i64);

        files.push(IndexedFile {
            path: rel_path.clone(),
            language: language.clone(),
            bytes: metadata.len(),
            lines,
            symbol_hints,
            modified_unix_ms,
            embedding: Vec::new(),
        });
        knowledge_inputs.push(analyze_file_knowledge(
            &rel_path,
            &language,
            lines,
            metadata.len(),
            modified_unix_ms,
            &content,
        ));
        embedding_inputs.push(build_embedding_input(
            &rel_path,
            &language,
            &content,
            options.embedding_input_chars,
        ));

        stats.indexed_files += 1;
        stats.total_bytes += metadata.len();
        stats.total_lines += u64::from(lines);
        *stats.language_counts.entry(language).or_insert(0) += 1;
    }

    let (embedding_provider, embedding_model) = if options.include_embeddings {
        let backend = resolve_embedding_backend(options).await?;
        let batch_size = options.embedding_batch_size.max(1);
        stats.embedding_dimensions = options.embedding_dimensions.max(64) as u32;

        // Embed in batches; `start` doubles as the index of the first file
        // of the batch, so vectors can be written back positionally.
        for start in (0..embedding_inputs.len()).step_by(batch_size) {
            let end = (start + batch_size).min(embedding_inputs.len());
            let embedding_slice = &embedding_inputs[start..end];
            let (vectors, usage) = match &backend {
                EmbeddingBackend::Local { engine, .. } => {
                    let vectors = engine.embed_batch(embedding_slice);
                    // The local engine reports no usage; approximate both
                    // token counts from whitespace-separated words.
                    let mut local_prompt_tokens = 0u64;
                    let mut local_total_tokens = 0u64;
                    for input in embedding_slice {
                        let approx_tokens = approximate_token_count(input);
                        local_prompt_tokens += approx_tokens;
                        local_total_tokens += approx_tokens;
                    }
                    (vectors, (local_prompt_tokens, local_total_tokens))
                }
                EmbeddingBackend::Remote(engine) => {
                    let response =
                        engine.embed_batch(embedding_slice).await.with_context(|| {
                            format!(
                                "failed embedding batch {}-{} via provider {}/{}",
                                start, end, engine.provider_name, engine.model
                            )
                        })?;

                    let vectors = response.embeddings;
                    let prompt_tokens = response.usage.prompt_tokens as u64;
                    let total_tokens = response.usage.total_tokens as u64;
                    (vectors, (prompt_tokens, total_tokens))
                }
            };

            stats.embedding_prompt_tokens += usage.0;
            stats.embedding_total_tokens += usage.1;

            for (offset, vector) in vectors.into_iter().enumerate() {
                // Every vector must match the configured dimension; a
                // mismatch aborts the whole build.
                let dim = vector.len() as u32;
                if dim != stats.embedding_dimensions {
                    anyhow::bail!(
                        "embedding dimension mismatch: expected {}, got {} (provider: {}, model: {})",
                        stats.embedding_dimensions,
                        dim,
                        backend.provider_name(),
                        backend.model_name(),
                    );
                }

                files[start + offset].embedding = vector;
                stats.embedded_files += 1;
            }
        }

        (
            backend.provider_name().to_string(),
            backend.model_name().to_string(),
        )
    } else {
        (
            DISABLED_EMBEDDING_PROVIDER.to_string(),
            DISABLED_EMBEDDING_MODEL.to_string(),
        )
    };

    // Sort after embeddings are attached so the positional write-back above
    // stays valid; sorting makes the serialized index deterministic.
    files.sort_by(|a, b| a.path.cmp(&b.path));
    let knowledge_graph = build_knowledge_graph(knowledge_inputs);

    Ok(CodebaseIndex {
        version: INDEX_VERSION,
        root: root.display().to_string(),
        generated_at: Utc::now(),
        embedding_provider,
        embedding_model,
        stats,
        files,
        knowledge_graph,
    })
}
433
/// Assembles the cross-file graph from per-file analyses.
///
/// Two passes: the first registers every file and symbol node (so imports can
/// later resolve against local definitions); the second wires import edges,
/// creating external module / symbol-ref nodes lazily. Nodes and edges are
/// sorted at the end for deterministic serialization.
fn build_knowledge_graph(files: Vec<AnalyzedFileKnowledge>) -> KnowledgeGraph {
    let mut nodes = Vec::new();
    let mut edges = Vec::new();
    let mut seen_edges: HashSet<(String, String, String)> = HashSet::new();
    // symbol label -> ids of local definitions with that name.
    let mut symbol_index: HashMap<String, Vec<String>> = HashMap::new();
    // module path -> lazily created module node id.
    let mut module_nodes: HashMap<String, String> = HashMap::new();
    // unresolved symbol name -> lazily created symbol-ref node id.
    let mut symbol_ref_nodes: HashMap<String, String> = HashMap::new();
    let mut stats = KnowledgeGraphStats::default();

    // Pass 1: file and symbol nodes, plus "defines" edges.
    for file in &files {
        nodes.push(file.file_node.clone());
        stats.file_nodes += 1;

        for symbol in &file.symbol_nodes {
            nodes.push(symbol.clone());
            stats.symbol_nodes += 1;
            symbol_index
                .entry(symbol.label.clone())
                .or_default()
                .push(symbol.id.clone());
            push_knowledge_edge(
                &mut edges,
                &mut seen_edges,
                &file.file_node.id,
                &symbol.id,
                "defines",
            );
        }
    }

    // Pass 2: import edges (consumes `files`).
    for file in files {
        for module in file.imported_modules {
            let module_id = module_nodes
                .entry(module.clone())
                .or_insert_with(|| {
                    stats.module_nodes += 1;
                    let id = module_node_id(&module);
                    nodes.push(KnowledgeNode {
                        id: id.clone(),
                        kind: "module".to_string(),
                        label: module.clone(),
                        file_path: None,
                        language: None,
                        symbol_kind: None,
                        line: None,
                        external: true,
                    });
                    id
                })
                .clone();

            push_knowledge_edge(
                &mut edges,
                &mut seen_edges,
                &file.file_node.id,
                &module_id,
                "imports_module",
            );
        }

        for imported_symbol in file.imported_symbols {
            // Link to local definitions only when the name is reasonably
            // unambiguous (at most 8 candidates); otherwise fall through to
            // a shared external symbol-ref node.
            let target_ids = symbol_index
                .get(&imported_symbol)
                .filter(|targets| !targets.is_empty() && targets.len() <= 8)
                .cloned();

            if let Some(target_ids) = target_ids {
                for target_id in target_ids {
                    push_knowledge_edge(
                        &mut edges,
                        &mut seen_edges,
                        &file.file_node.id,
                        &target_id,
                        "imports_symbol",
                    );
                }
                continue;
            }

            let symbol_ref_id = symbol_ref_nodes
                .entry(imported_symbol.clone())
                .or_insert_with(|| {
                    stats.symbol_reference_nodes += 1;
                    let id = external_symbol_node_id(&imported_symbol);
                    nodes.push(KnowledgeNode {
                        id: id.clone(),
                        kind: "symbol_ref".to_string(),
                        label: imported_symbol.clone(),
                        file_path: None,
                        language: None,
                        symbol_kind: None,
                        line: None,
                        external: true,
                    });
                    id
                })
                .clone();

            push_knowledge_edge(
                &mut edges,
                &mut seen_edges,
                &file.file_node.id,
                &symbol_ref_id,
                "imports_symbol",
            );
        }
    }

    nodes.sort_by(|a, b| a.id.cmp(&b.id));
    edges.sort_by(|a, b| {
        a.source
            .cmp(&b.source)
            .then_with(|| a.target.cmp(&b.target))
            .then_with(|| a.kind.cmp(&b.kind))
    });
    stats.edges = edges.len() as u64;

    KnowledgeGraph {
        version: KNOWLEDGE_GRAPH_VERSION,
        nodes,
        edges,
        stats,
    }
}
558
559fn push_knowledge_edge(
560 edges: &mut Vec<KnowledgeEdge>,
561 seen_edges: &mut HashSet<(String, String, String)>,
562 source: &str,
563 target: &str,
564 kind: &str,
565) {
566 let key = (source.to_string(), target.to_string(), kind.to_string());
567 if seen_edges.insert(key.clone()) {
568 edges.push(KnowledgeEdge {
569 source: key.0,
570 target: key.1,
571 kind: key.2,
572 });
573 }
574}
575
/// Scans one file line-by-line for symbol definitions and import statements,
/// producing the inputs `build_knowledge_graph` consumes.
///
/// `_lines`, `_bytes`, `_modified_unix_ms` are currently unused but kept in
/// the signature by the caller.
fn analyze_file_knowledge(
    rel_path: &str,
    language: &str,
    _lines: u32,
    _bytes: u64,
    _modified_unix_ms: Option<i64>,
    content: &str,
) -> AnalyzedFileKnowledge {
    let file_id = file_node_id(rel_path);
    let mut symbol_nodes = Vec::new();
    let mut imported_modules = Vec::new();
    let mut imported_symbols = Vec::new();
    // Guard against duplicate (name, line, kind) triples on the same line.
    let mut seen_symbols: HashSet<(String, u32, String)> = HashSet::new();
    // Go-only state: inside a multi-line `import ( ... )` block.
    let mut go_import_block = false;

    for (idx, raw_line) in content.lines().enumerate() {
        let line_no = idx as u32 + 1;
        let line = raw_line.trim();
        if line.is_empty() {
            continue;
        }

        if let Some((symbol_kind, name)) = extract_symbol_definition(language, line) {
            let key = (name.clone(), line_no, symbol_kind.to_string());
            if seen_symbols.insert(key) {
                symbol_nodes.push(KnowledgeNode {
                    id: symbol_node_id(rel_path, &name, line_no),
                    kind: "symbol".to_string(),
                    label: name,
                    file_path: Some(rel_path.to_string()),
                    language: Some(language.to_string()),
                    symbol_kind: Some(symbol_kind.to_string()),
                    line: Some(line_no),
                    external: false,
                });
            }
        }

        extract_import_references(
            language,
            line,
            &mut go_import_block,
            &mut imported_modules,
            &mut imported_symbols,
        );
    }

    // Normalize for deterministic, duplicate-free output.
    imported_modules.sort();
    imported_modules.dedup();
    imported_symbols.sort();
    imported_symbols.dedup();

    let file_node = KnowledgeNode {
        id: file_id,
        kind: "file".to_string(),
        label: rel_path.to_string(),
        file_path: Some(rel_path.to_string()),
        language: Some(language.to_string()),
        symbol_kind: None,
        line: None,
        external: false,
    };

    AnalyzedFileKnowledge {
        file_node,
        symbol_nodes,
        imported_modules,
        imported_symbols,
    }
}
646
647fn extract_symbol_definition(language: &str, line: &str) -> Option<(&'static str, String)> {
648 match language {
649 "rust" => extract_rust_symbol_definition(line),
650 "python" => extract_python_symbol_definition(line),
651 "typescript" | "javascript" => extract_script_symbol_definition(line),
652 "go" => extract_go_symbol_definition(line),
653 _ => None,
654 }
655}
656
657fn extract_rust_symbol_definition(line: &str) -> Option<(&'static str, String)> {
658 let normalized = strip_prefixes(
659 line,
660 &[
661 "pub(crate) ",
662 "pub(super) ",
663 "pub(self) ",
664 "pub ",
665 "async ",
666 "unsafe ",
667 ],
668 );
669
670 for (keyword, kind) in [
671 ("fn", "function"),
672 ("struct", "struct"),
673 ("enum", "enum"),
674 ("trait", "trait"),
675 ("mod", "module"),
676 ("type", "type"),
677 ("const", "const"),
678 ("static", "static"),
679 ] {
680 if let Some(name) = extract_identifier_after_keyword(normalized, keyword) {
681 return Some((kind, name));
682 }
683 }
684
685 None
686}
687
688fn extract_python_symbol_definition(line: &str) -> Option<(&'static str, String)> {
689 let normalized = strip_prefixes(line, &["async "]);
690 if let Some(name) = extract_identifier_after_keyword(normalized, "def") {
691 return Some(("function", name));
692 }
693 if let Some(name) = extract_identifier_after_keyword(normalized, "class") {
694 return Some(("class", name));
695 }
696 None
697}
698
699fn extract_script_symbol_definition(line: &str) -> Option<(&'static str, String)> {
700 let normalized = strip_prefixes(line, &["export default ", "export ", "default ", "async "]);
701
702 for (keyword, kind) in [
703 ("function", "function"),
704 ("class", "class"),
705 ("interface", "interface"),
706 ("type", "type"),
707 ("enum", "enum"),
708 ] {
709 if let Some(name) = extract_identifier_after_keyword(normalized, keyword) {
710 return Some((kind, name));
711 }
712 }
713
714 for keyword in ["const", "let", "var"] {
715 if let Some(name) = extract_identifier_after_keyword(normalized, keyword)
716 && (normalized.contains("=>") || normalized.contains("function("))
717 {
718 return Some(("variable", name));
719 }
720 }
721
722 None
723}
724
725fn extract_go_symbol_definition(line: &str) -> Option<(&'static str, String)> {
726 if let Some(name) = extract_identifier_after_keyword(line, "func") {
727 return Some(("function", name));
728 }
729 if let Some(name) = extract_identifier_after_keyword(line, "type") {
730 return Some(("type", name));
731 }
732 if let Some(name) = extract_identifier_after_keyword(line, "const") {
733 return Some(("const", name));
734 }
735 if let Some(name) = extract_identifier_after_keyword(line, "var") {
736 return Some(("variable", name));
737 }
738 None
739}
740
741fn extract_import_references(
742 language: &str,
743 line: &str,
744 go_import_block: &mut bool,
745 imported_modules: &mut Vec<String>,
746 imported_symbols: &mut Vec<String>,
747) {
748 match language {
749 "rust" => extract_rust_imports(line, imported_modules, imported_symbols),
750 "python" => extract_python_imports(line, imported_modules, imported_symbols),
751 "typescript" | "javascript" => {
752 extract_script_imports(line, imported_modules, imported_symbols);
753 }
754 "go" => extract_go_imports(line, go_import_block, imported_modules, imported_symbols),
755 _ => {}
756 }
757}
758
759fn extract_rust_imports(
760 line: &str,
761 imported_modules: &mut Vec<String>,
762 imported_symbols: &mut Vec<String>,
763) {
764 let normalized = strip_prefixes(line, &["pub "]);
765 let Some(spec) = normalized.strip_prefix("use ") else {
766 return;
767 };
768 let spec = spec.trim_end_matches(';').trim();
769 if spec.is_empty() {
770 return;
771 }
772
773 imported_modules.push(spec.to_string());
774 for segment in spec.split(&['{', '}', ','][..]) {
775 let segment = segment.trim();
776 if segment.is_empty() {
777 continue;
778 }
779
780 let alias_free = segment.split(" as ").next().unwrap_or(segment).trim();
781 let last = alias_free.rsplit("::").next().unwrap_or(alias_free).trim();
782 if last.is_empty() || matches!(last, "self" | "super" | "crate" | "*") {
783 continue;
784 }
785 if let Some(name) = sanitize_identifier(last) {
786 imported_symbols.push(name);
787 }
788 }
789}
790
791fn extract_python_imports(
792 line: &str,
793 imported_modules: &mut Vec<String>,
794 imported_symbols: &mut Vec<String>,
795) {
796 if let Some(rest) = line.strip_prefix("import ") {
797 for module in rest.split(',') {
798 let module = module.trim();
799 let module = module.split_whitespace().next().unwrap_or("");
800 if module.is_empty() {
801 continue;
802 }
803 imported_modules.push(module.to_string());
804 if let Some(name) = module.rsplit('.').next().and_then(sanitize_identifier) {
805 imported_symbols.push(name);
806 }
807 }
808 return;
809 }
810
811 let Some(rest) = line.strip_prefix("from ") else {
812 return;
813 };
814 let Some((module, names)) = rest.split_once(" import ") else {
815 return;
816 };
817 let module = module.trim();
818 if !module.is_empty() {
819 imported_modules.push(module.to_string());
820 }
821 for name in names.split(',') {
822 let name = name.trim();
823 let alias_free = name.split(" as ").next().unwrap_or(name).trim();
824 if let Some(clean) = sanitize_identifier(alias_free) {
825 imported_symbols.push(clean);
826 }
827 }
828}
829
830fn extract_script_imports(
831 line: &str,
832 imported_modules: &mut Vec<String>,
833 imported_symbols: &mut Vec<String>,
834) {
835 let trimmed = line.trim();
836 let is_module_import = trimmed.starts_with("import ")
837 || (trimmed.starts_with("export ") && trimmed.contains(" from "));
838 if !is_module_import && !trimmed.contains("require(") {
839 return;
840 }
841
842 if let Some(module) = extract_quoted_literal(trimmed) {
843 imported_modules.push(module.clone());
844 if let Some(name) = module.rsplit('/').next().and_then(sanitize_identifier) {
845 imported_symbols.push(name);
846 }
847 }
848
849 if let Some((before_from, _)) = trimmed.split_once(" from ") {
850 if let Some((default_import, _)) = before_from
851 .trim_start_matches("import ")
852 .trim_start_matches("export ")
853 .split_once(',')
854 {
855 let default_import = default_import.trim();
856 if !default_import.is_empty() && !default_import.starts_with('{') {
857 if let Some(name) = sanitize_identifier(default_import) {
858 imported_symbols.push(name);
859 }
860 }
861 }
862 }
863
864 if let Some(braced) = extract_braced_section(trimmed) {
865 for name in braced.split(',') {
866 let name = name.trim();
867 let alias_free = name.split(" as ").next().unwrap_or(name).trim();
868 let alias_free = alias_free.trim_start_matches("type ").trim();
869 if let Some(clean) = sanitize_identifier(alias_free) {
870 imported_symbols.push(clean);
871 }
872 }
873 }
874}
875
876fn extract_go_imports(
877 line: &str,
878 go_import_block: &mut bool,
879 imported_modules: &mut Vec<String>,
880 imported_symbols: &mut Vec<String>,
881) {
882 let trimmed = line.trim();
883
884 if *go_import_block {
885 if trimmed == ")" {
886 *go_import_block = false;
887 return;
888 }
889 extract_go_import_entry(trimmed, imported_modules, imported_symbols);
890 return;
891 }
892
893 if trimmed == "import (" {
894 *go_import_block = true;
895 return;
896 }
897
898 if let Some(rest) = trimmed.strip_prefix("import ") {
899 extract_go_import_entry(rest.trim(), imported_modules, imported_symbols);
900 }
901}
902
903fn extract_go_import_entry(
904 line: &str,
905 imported_modules: &mut Vec<String>,
906 imported_symbols: &mut Vec<String>,
907) {
908 let Some(module) = extract_quoted_literal(line) else {
909 return;
910 };
911 imported_modules.push(module.clone());
912
913 let alias = line.split_whitespace().next().unwrap_or("");
914 if !alias.is_empty() && !alias.starts_with('"') && !matches!(alias, "_" | ".") {
915 if let Some(clean) = sanitize_identifier(alias) {
916 imported_symbols.push(clean);
917 return;
918 }
919 }
920
921 if let Some(name) = module.rsplit('/').next().and_then(sanitize_identifier) {
922 imported_symbols.push(name);
923 }
924}
925
926fn extract_identifier_after_keyword(line: &str, keyword: &str) -> Option<String> {
927 let prefix = format!("{keyword} ");
928 let rest = line.strip_prefix(&prefix)?;
929 sanitize_identifier(rest)
930}
931
/// Returns the longest leading run of identifier characters (ASCII
/// alphanumerics, `_`, `$`), or `None` when that run is empty.
fn sanitize_identifier(input: &str) -> Option<String> {
    let ident: String = input
        .chars()
        .take_while(|&ch| ch.is_ascii_alphanumeric() || ch == '_' || ch == '$')
        .collect();
    (!ident.is_empty()).then_some(ident)
}
944
/// Repeatedly removes any of `prefixes` (plus trailing whitespace) from the
/// front of `input` until none matches, then returns the remainder.
fn strip_prefixes<'a>(mut input: &'a str, prefixes: &[&str]) -> &'a str {
    'strip: loop {
        for prefix in prefixes {
            if let Some(rest) = input.strip_prefix(prefix) {
                input = rest.trim_start();
                continue 'strip;
            }
        }
        return input;
    }
}
961
/// Returns the trimmed text after the first double quote (or, failing that,
/// the first single quote) up to the next matching quote — or to the end of
/// the line when no closing quote exists. Empty candidates are skipped.
fn extract_quoted_literal(line: &str) -> Option<String> {
    for quote in ['"', '\''] {
        if let Some(candidate) = line.split(quote).nth(1) {
            let trimmed = candidate.trim();
            if !trimmed.is_empty() {
                return Some(trimmed.to_string());
            }
        }
    }
    None
}
974
/// Returns the text between the first `{` and the first `}` that follows it.
fn extract_braced_section(line: &str) -> Option<String> {
    let (_, after_open) = line.split_once('{')?;
    let (inner, _) = after_open.split_once('}')?;
    Some(inner.to_string())
}
980
/// Graph id for a file node.
fn file_node_id(path: &str) -> String {
    ["file:", path].concat()
}
984
/// Graph id for a symbol defined at `path:line`.
fn symbol_node_id(path: &str, name: &str, line: u32) -> String {
    format!("symbol:{}:{}:{}", path, line, name)
}
988
/// Graph id for an external module node.
fn module_node_id(module: &str) -> String {
    ["module:", module].concat()
}
992
/// Graph id for an unresolved symbol-reference node.
fn external_symbol_node_id(symbol: &str) -> String {
    String::from("symbol-ref:") + symbol
}
996
/// Active embedding implementation: the in-process hash embedder or a
/// registry-resolved remote provider.
enum EmbeddingBackend {
    Local {
        engine: LocalEmbeddingEngine,
        // Model name recorded in the index; the local engine itself has no
        // real model selection.
        model: String,
    },
    Remote(RemoteEmbeddingEngine),
}
1004
1005impl EmbeddingBackend {
1006 fn provider_name(&self) -> &str {
1007 match self {
1008 Self::Local { .. } => LOCAL_EMBEDDING_PROVIDER,
1009 Self::Remote(engine) => &engine.provider_name,
1010 }
1011 }
1012
1013 fn model_name(&self) -> &str {
1014 match self {
1015 Self::Local { model, .. } => model,
1016 Self::Remote(engine) => &engine.model,
1017 }
1018 }
1019}
1020
/// Provider-backed embedding client with per-batch retry/backoff settings.
#[derive(Clone)]
struct RemoteEmbeddingEngine {
    provider: Arc<dyn crate::provider::Provider>,
    /// Cached provider name for logging and error messages.
    provider_name: String,
    model: String,
    /// Retries allowed per batch on retryable errors.
    max_retries: u32,
    // Exponential-backoff bounds.
    retry_initial: Duration,
    retry_max: Duration,
}
1030
impl RemoteEmbeddingEngine {
    /// Embeds one batch of inputs, retrying transient failures with
    /// exponential backoff up to `max_retries` extra attempts.
    async fn embed_batch(&self, inputs: &[String]) -> Result<crate::provider::EmbeddingResponse> {
        // Never hit the provider for an empty batch.
        if inputs.is_empty() {
            return Ok(crate::provider::EmbeddingResponse {
                embeddings: Vec::new(),
                usage: crate::provider::Usage::default(),
            });
        }

        let mut attempt = 0u32;
        loop {
            let request = crate::provider::EmbeddingRequest {
                model: self.model.clone(),
                inputs: inputs.to_vec(),
            };

            match self.provider.embed(request).await {
                Ok(response) => return Ok(response),
                Err(err) => {
                    // Retry only message-sniffed transient errors, and only
                    // while attempts remain.
                    let should_retry =
                        attempt < self.max_retries && is_retryable_embedding_error(&err);
                    if !should_retry {
                        return Err(anyhow::anyhow!(
                            "remote embedding failed via {}/{} after {} attempt(s): {}",
                            self.provider_name,
                            self.model,
                            attempt + 1,
                            err
                        ));
                    }

                    // Exponential backoff capped at `retry_max`.
                    let delay = retry_delay(attempt, self.retry_initial, self.retry_max);
                    tracing::warn!(
                        provider = %self.provider_name,
                        model = %self.model,
                        attempt = attempt + 1,
                        retry_in_ms = delay.as_millis(),
                        error = %err,
                        "Embedding batch failed, retrying"
                    );

                    tokio::time::sleep(delay).await;
                    attempt += 1;
                }
            }
        }
    }
}
1079
/// Chooses between the local hash embedder and a registry-resolved remote
/// provider based on `options.embedding_provider`.
async fn resolve_embedding_backend(options: &BuildOptions) -> Result<EmbeddingBackend> {
    // Same dimension floor as build_index applies to its stats.
    let dimensions = options.embedding_dimensions.max(64);
    if is_local_embedding_provider(&options.embedding_provider) {
        return Ok(EmbeddingBackend::Local {
            engine: LocalEmbeddingEngine::new(dimensions),
            model: options.embedding_model.clone(),
        });
    }

    // Remote path: build a "provider/model" selector and resolve it through
    // the vault-backed provider registry.
    let model_selector =
        build_model_selector(&options.embedding_provider, &options.embedding_model)?;
    let registry = crate::provider::ProviderRegistry::from_vault().await?;
    let (provider, model) = registry
        .resolve_model(&model_selector)
        .with_context(|| format!("failed resolving embedding model '{model_selector}'"))?;

    let retry_initial = Duration::from_millis(options.embedding_retry_initial_ms.max(1));
    let retry_max = Duration::from_millis(options.embedding_retry_max_ms.max(1));

    Ok(EmbeddingBackend::Remote(RemoteEmbeddingEngine {
        provider_name: provider.name().to_string(),
        provider,
        model,
        max_retries: options.embedding_max_retries,
        retry_initial,
        retry_max,
    }))
}
1108
/// True when `value` names the built-in hash embedder (accepts several
/// aliases, case-insensitively and ignoring surrounding whitespace).
fn is_local_embedding_provider(value: &str) -> bool {
    let normalized = value.trim().to_ascii_lowercase();
    ["local", "hash", "hash-embedding", "local/hash-embedding"]
        .contains(&normalized.as_str())
}
1115
1116fn build_model_selector(provider: &str, model: &str) -> Result<String> {
1117 let provider = provider.trim();
1118 let model = model.trim();
1119
1120 if model.is_empty() {
1121 anyhow::bail!("embedding model cannot be empty");
1122 }
1123
1124 if model.contains('/') {
1125 return Ok(model.to_string());
1126 }
1127
1128 if provider.is_empty() {
1129 anyhow::bail!(
1130 "embedding provider cannot be empty when model does not include provider prefix"
1131 );
1132 }
1133
1134 Ok(format!("{provider}/{model}"))
1135}
1136
/// Exponential backoff: `initial * 2^attempt`, clamped to `max` (with `max`
/// itself raised to at least `initial`). Saturating math keeps very large
/// attempts from overflowing before the cap is applied.
fn retry_delay(attempt: u32, initial: Duration, max: Duration) -> Duration {
    let base_ms = initial.as_millis();
    let ceiling_ms = max.as_millis().max(base_ms);
    let scaled_ms = base_ms.saturating_mul(2u128.saturating_pow(attempt));
    Duration::from_millis(scaled_ms.min(ceiling_ms) as u64)
}
1144
1145fn is_retryable_embedding_error(error: &anyhow::Error) -> bool {
1146 let msg = error.to_string().to_ascii_lowercase();
1147 [
1148 "timeout",
1149 "timed out",
1150 "connection reset",
1151 "connection refused",
1152 "temporary",
1153 "temporarily unavailable",
1154 "rate limit",
1155 "too many requests",
1156 " 429",
1157 " 500",
1158 " 502",
1159 " 503",
1160 " 504",
1161 ]
1162 .iter()
1163 .any(|needle| msg.contains(needle))
1164}
1165
/// Rough token estimate for usage accounting: one token per
/// whitespace-separated word, never reporting fewer than one.
fn approximate_token_count(text: &str) -> u64 {
    match text.split_whitespace().count() as u64 {
        0 => 1,
        count => count,
    }
}
1170
/// Assemble the text fed to the embedding model: a small metadata header
/// (path and language) followed by a char-boundary-safe prefix of the
/// file content, capped at `max_chars` characters.
fn build_embedding_input(path: &str, language: &str, content: &str, max_chars: usize) -> String {
    let snippet: String = content.chars().take(max_chars).collect();
    let mut input = format!("path:{path}\nlanguage:{language}\n\n");
    input.push_str(&snippet);
    input
}
1175
/// First `max_chars` characters of `input`, split on char boundaries so a
/// multi-byte UTF-8 sequence is never cut in half.
fn safe_char_prefix(input: &str, max_chars: usize) -> String {
    let mut prefix = String::new();
    for ch in input.chars().take(max_chars) {
        prefix.push(ch);
    }
    prefix
}
1179
/// Deterministic, dependency-free embedding engine based on feature
/// hashing: tokens are hashed into slots of a fixed-size vector rather
/// than run through a learned model. Used when the embedding provider is
/// one of the "local" aliases.
#[derive(Debug, Clone)]
struct LocalEmbeddingEngine {
    // Output vector length; every embedding produced has exactly this
    // many entries (callers clamp it to at least 64 before construction).
    dimensions: usize,
}
1184
1185impl LocalEmbeddingEngine {
1186 fn new(dimensions: usize) -> Self {
1187 Self { dimensions }
1188 }
1189
1190 fn embed_batch(&self, inputs: &[String]) -> Vec<Vec<f32>> {
1191 inputs
1192 .iter()
1193 .map(|input| self.embed_single(input))
1194 .collect()
1195 }
1196
1197 fn embed_single(&self, input: &str) -> Vec<f32> {
1198 let mut vector = vec![0.0f32; self.dimensions];
1199 let tokens = tokenize_for_embedding(input);
1200
1201 if tokens.is_empty() {
1202 self.accumulate_char_ngrams(&mut vector, input);
1203 } else {
1204 for (idx, token) in tokens.iter().enumerate() {
1205 let positional_weight = 1.0f32 / (1.0 + (idx as f32 / 128.0));
1206 self.accumulate_token(&mut vector, token, positional_weight);
1207
1208 if let Some(next) = tokens.get(idx + 1) {
1209 let bigram = format!("{token} {next}");
1210 self.accumulate_token(&mut vector, &bigram, positional_weight * 0.65);
1211 }
1212 }
1213 }
1214
1215 l2_normalize(&mut vector);
1216 vector
1217 }
1218
1219 fn accumulate_char_ngrams(&self, vector: &mut [f32], input: &str) {
1220 for ngram in input.as_bytes().windows(3).take(2048) {
1221 let key = String::from_utf8_lossy(ngram);
1222 self.accumulate_token(vector, &key, 0.5);
1223 }
1224 }
1225
1226 fn accumulate_token(&self, vector: &mut [f32], token: &str, weight: f32) {
1227 if token.is_empty() {
1228 return;
1229 }
1230
1231 let digest = Sha256::digest(token.as_bytes());
1232 let len = vector.len();
1233
1234 let idx_a = (u16::from_le_bytes([digest[0], digest[1]]) as usize) % len;
1235 let idx_b = (u16::from_le_bytes([digest[2], digest[3]]) as usize) % len;
1236 let idx_c = (u16::from_le_bytes([digest[4], digest[5]]) as usize) % len;
1237
1238 let sign_a = if digest[6] & 1 == 0 { 1.0 } else { -1.0 };
1239 let sign_b = if digest[7] & 1 == 0 { 1.0 } else { -1.0 };
1240 let sign_c = if digest[8] & 1 == 0 { 1.0 } else { -1.0 };
1241
1242 vector[idx_a] += sign_a * weight;
1243 vector[idx_b] += sign_b * (weight * 0.7);
1244 vector[idx_c] += sign_c * (weight * 0.4);
1245 }
1246}
1247
/// Split `input` into lowercase word tokens for feature hashing.
///
/// A token is a maximal run of ASCII alphanumerics/underscores; every
/// other character is a separator. Output is capped at 4096 tokens to
/// bound work on pathological inputs.
fn tokenize_for_embedding(input: &str) -> Vec<String> {
    const MAX_TOKENS: usize = 4096;

    let mut tokens = Vec::new();
    let mut current = String::new();

    for ch in input.chars() {
        match ch {
            c if c.is_ascii_alphanumeric() || c == '_' => current.push(c.to_ascii_lowercase()),
            _ => {
                if !current.is_empty() {
                    tokens.push(std::mem::take(&mut current));
                    if tokens.len() >= MAX_TOKENS {
                        return tokens;
                    }
                }
            }
        }
    }

    // Flush a trailing token that ran to end-of-input.
    if !current.is_empty() {
        tokens.push(current);
    }

    tokens
}
1269
/// Scale `values` in place to unit Euclidean length.
///
/// A zero (or NaN-containing) vector is left untouched: the guard only
/// fires when the norm is strictly positive.
fn l2_normalize(values: &mut [f32]) {
    let norm = values.iter().map(|v| v * v).sum::<f32>().sqrt();
    if norm > 0.0 {
        values.iter_mut().for_each(|value| *value /= norm);
    }
}
1278
1279fn default_index_path(root: &Path) -> PathBuf {
1280 let mut hasher = Sha256::new();
1281 hasher.update(root.to_string_lossy().as_bytes());
1282 let digest = hasher.finalize();
1283 let short = hex::encode(digest);
1284 let short = &short[..16];
1285
1286 let base = crate::config::Config::data_dir().unwrap_or_else(|| root.join(".codetether-agent"));
1287 base.join("indexes")
1288 .join(format!("codebase-index-{short}.json"))
1289}
1290
1291fn default_knowledge_graph_path(root: &Path) -> PathBuf {
1292 let mut hasher = Sha256::new();
1293 hasher.update(root.to_string_lossy().as_bytes());
1294 let digest = hasher.finalize();
1295 let short = hex::encode(digest);
1296 let short = &short[..16];
1297
1298 let base = crate::config::Config::data_dir().unwrap_or_else(|| root.join(".codetether-agent"));
1299 base.join("indexes")
1300 .join(format!("workspace-knowledge-{short}.json"))
1301}
1302
1303fn should_descend(entry: &DirEntry, root: &Path, include_hidden: bool) -> bool {
1304 let path = entry.path();
1305 let rel_path = path.strip_prefix(root).unwrap_or(path);
1306
1307 if !include_hidden && is_hidden_path(rel_path) {
1308 return false;
1309 }
1310
1311 let skip_dirs = [
1312 ".git",
1313 ".hg",
1314 ".svn",
1315 "node_modules",
1316 "target",
1317 "dist",
1318 "build",
1319 ".next",
1320 "vendor",
1321 "__pycache__",
1322 ".venv",
1323 ".codetether-agent",
1324 ];
1325
1326 !path
1327 .components()
1328 .any(|c| skip_dirs.contains(&c.as_os_str().to_str().unwrap_or("")))
1329}
1330
/// True when any component of `path` starts with a dot (dotfile or
/// dot-directory). Components that are not valid UTF-8 are treated as
/// not hidden, matching the lossy handling elsewhere in the walker.
fn is_hidden_path(path: &Path) -> bool {
    path.components()
        .filter_map(|component| component.as_os_str().to_str())
        .any(|name| name.starts_with('.'))
}
1339
/// Heuristic: is this path likely a text file worth indexing?
///
/// Matches a fixed allow-list of source/config extensions
/// (case-insensitively, consistent with `detect_language`) plus a few
/// well-known extensionless filenames. Everything else is treated as
/// non-text.
fn is_probably_text_file(path: &Path) -> bool {
    const TEXT_EXTS: &[&str] = &[
        "rs", "ts", "js", "tsx", "jsx", "py", "go", "java", "kt", "c", "cpp", "h", "hpp", "md",
        "txt", "json", "yaml", "yml", "toml", "sh", "bash", "zsh", "html", "css", "scss", "sql",
        "proto", "xml", "ini", "env", "lock",
    ];

    if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
        // Bug fix: normalize case so `MAIN.RS` matches like `main.rs`,
        // mirroring the lowercasing done in `detect_language`.
        if TEXT_EXTS.contains(&ext.to_ascii_lowercase().as_str()) {
            return true;
        }
    }

    if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
        return matches!(name, "Dockerfile" | "Makefile" | "Jenkinsfile" | "README");
    }

    false
}
1359
/// Map a file path to a language label for indexing statistics.
///
/// The primary signal is the lowercased extension; a handful of
/// well-known extensionless filenames are special-cased, and everything
/// else is labeled "text".
fn detect_language(path: &Path) -> String {
    // Fallback used when no extension matched: check special filenames.
    let by_name = || match path.file_name().and_then(|n| n.to_str()) {
        Some("Dockerfile") => "dockerfile",
        Some("Makefile") => "makefile",
        Some("Jenkinsfile") => "groovy",
        _ => "text",
    };

    let ext = path
        .extension()
        .and_then(|e| e.to_str())
        .map(str::to_ascii_lowercase)
        .unwrap_or_default();

    let label = match ext.as_str() {
        "rs" => "rust",
        "ts" | "tsx" => "typescript",
        "js" | "jsx" => "javascript",
        "py" => "python",
        "go" => "go",
        "java" => "java",
        "kt" => "kotlin",
        "c" | "h" => "c",
        "cpp" | "hpp" | "cc" | "cxx" => "cpp",
        "json" => "json",
        "yaml" | "yml" => "yaml",
        "toml" => "toml",
        "md" => "markdown",
        "sh" | "bash" | "zsh" => "shell",
        "proto" => "proto",
        "sql" => "sql",
        "html" => "html",
        "css" | "scss" => "css",
        _ => by_name(),
    };

    label.to_string()
}
1401
1402fn estimate_symbol_hints(path: &Path, content: &str) -> u32 {
1403 let ext = path
1404 .extension()
1405 .and_then(|e| e.to_str())
1406 .unwrap_or_default()
1407 .to_ascii_lowercase();
1408
1409 let mut count = 0u32;
1410 for line in content.lines().map(str::trim_start) {
1411 let hit = match ext.as_str() {
1412 "rs" => estimate_rust_symbol_hint(line),
1413 "py" => line.starts_with("def ") || line.starts_with("class "),
1414 "ts" | "tsx" | "js" | "jsx" => {
1415 line.starts_with("function ")
1416 || line.contains("=>")
1417 || line.starts_with("class ")
1418 || line.starts_with("export function ")
1419 }
1420 "go" => line.starts_with("func ") || line.starts_with("type "),
1421 "java" | "kt" => {
1422 line.contains(" class ")
1423 || line.starts_with("class ")
1424 || line.starts_with("interface ")
1425 || line.contains(" fun ")
1426 }
1427 _ => false,
1428 };
1429
1430 if hit {
1431 count = count.saturating_add(1);
1432 }
1433 }
1434
1435 count
1436}
1437
1438fn estimate_rust_symbol_hint(line: &str) -> bool {
1439 let normalized = strip_prefixes(
1440 line,
1441 &[
1442 "pub(crate) ",
1443 "pub(super) ",
1444 "pub(self) ",
1445 "pub ",
1446 "async ",
1447 "unsafe ",
1448 ],
1449 );
1450
1451 normalized.starts_with("fn ")
1452 || normalized.starts_with("struct ")
1453 || normalized.starts_with("enum ")
1454 || normalized.starts_with("trait ")
1455 || normalized.starts_with("impl ")
1456 || normalized.starts_with("mod ")
1457 || normalized.starts_with("type ")
1458 || normalized.starts_with("const ")
1459 || normalized.starts_with("static ")
1460}
1461
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use tempfile::tempdir;

    // Pure-helper tests: path classification, language detection, symbol
    // counting, local embeddings, and selector/retry utilities.

    #[test]
    fn detects_hidden_paths() {
        assert!(is_hidden_path(Path::new(".git/config")));
        assert!(is_hidden_path(Path::new("src/.cache/file")));
        assert!(!is_hidden_path(Path::new("src/main.rs")));
    }

    #[test]
    fn language_detection_works() {
        assert_eq!(detect_language(Path::new("src/main.rs")), "rust");
        assert_eq!(detect_language(Path::new("app.py")), "python");
        // Extensionless special-case filenames are recognized by name.
        assert_eq!(detect_language(Path::new("Dockerfile")), "dockerfile");
    }

    #[test]
    fn symbol_hint_estimation_works_for_rust() {
        // Three item-definition lines: struct, impl, fn.
        let content = "pub struct A;\nimpl A {}\nfn run() {}\n";
        assert_eq!(estimate_symbol_hints(Path::new("src/lib.rs"), content), 3);
    }

    #[test]
    fn local_embeddings_have_expected_dimensions() {
        let engine = LocalEmbeddingEngine::new(384);
        let vectors = engine.embed_batch(&["fn main() { println!(\"hi\") }".to_string()]);
        assert_eq!(vectors.len(), 1);
        assert_eq!(vectors[0].len(), 384);
    }

    #[test]
    fn embedding_input_prefix_is_char_safe() {
        // Multi-byte chars must survive truncation at a char (not byte) count.
        let input = "✓✓✓hello";
        let prefixed = build_embedding_input("src/main.rs", "rust", input, 2);
        assert!(prefixed.contains("✓✓"));
    }

    #[test]
    fn local_embedding_provider_aliases_are_supported() {
        assert!(is_local_embedding_provider("local"));
        assert!(is_local_embedding_provider("local/hash-embedding"));
        // Matching is case-insensitive.
        assert!(is_local_embedding_provider("HASH"));
        assert!(!is_local_embedding_provider("huggingface"));
    }

    #[test]
    fn model_selector_uses_explicit_prefix_when_missing() {
        // A model already containing '/' is taken as fully qualified.
        let selector = build_model_selector("huggingface", "BAAI/bge-small-en-v1.5")
            .expect("model selector should build");
        assert_eq!(selector, "BAAI/bge-small-en-v1.5");

        // Otherwise the provider name is prepended.
        let selector = build_model_selector("huggingface", "text-embedding-3-large")
            .expect("model selector should build");
        assert_eq!(selector, "huggingface/text-embedding-3-large");
    }

    #[test]
    fn retryable_embedding_error_detection_matches_transient_signals() {
        assert!(is_retryable_embedding_error(&anyhow!(
            "HTTP 429 too many requests"
        )));
        assert!(is_retryable_embedding_error(&anyhow!("gateway timeout")));
        assert!(!is_retryable_embedding_error(&anyhow!(
            "invalid embedding model"
        )));
    }

    // End-to-end: building an index over a two-file temp workspace should
    // produce symbol nodes and import edges in the knowledge graph even
    // with embeddings disabled.
    #[tokio::test]
    async fn build_index_emits_workspace_knowledge_graph() {
        let temp = tempdir().expect("tempdir");
        std::fs::write(temp.path().join("types.rs"), "pub struct Session;\n").expect("write");
        std::fs::write(
            temp.path().join("main.rs"),
            "use crate::types::Session;\nfn run() {}\n",
        )
        .expect("write");

        let index = build_index(
            temp.path(),
            &BuildOptions {
                include_embeddings: false,
                ..BuildOptions::default()
            },
        )
        .await
        .expect("index should build");

        assert_eq!(index.embedding_provider, DISABLED_EMBEDDING_PROVIDER);
        assert!(
            index
                .knowledge_graph
                .nodes
                .iter()
                .any(|node| node.kind == "symbol" && node.label == "Session")
        );
        assert!(
            index
                .knowledge_graph
                .edges
                .iter()
                .any(|edge| edge.kind == "imports_symbol" && edge.target.contains("Session"))
        );
    }

    // Snapshot refresh should write a JSON file containing the knowledge
    // graph. The data dir is redirected via CODETETHER_DATA_DIR.
    //
    // NOTE(review): mutating process env is not thread-safe; if other tests
    // read env concurrently this could race — consider serializing env-using
    // tests. The var is also not restored if the body panics before the
    // remove_var call. TODO confirm whether that matters for this suite.
    #[tokio::test]
    async fn refresh_workspace_knowledge_snapshot_writes_json() {
        let temp = tempdir().expect("tempdir");
        let data_dir = temp.path().join("data");
        std::fs::write(temp.path().join("lib.rs"), "pub fn run() {}\n").expect("write");

        // SAFETY: set_var is unsafe because env mutation can race with
        // concurrent env reads in other threads; accepted here as test-only.
        unsafe {
            std::env::set_var("CODETETHER_DATA_DIR", data_dir.display().to_string());
        }

        let output_path = refresh_workspace_knowledge_snapshot(temp.path())
            .await
            .expect("snapshot should write");
        let payload = std::fs::read_to_string(&output_path).expect("snapshot payload");

        // SAFETY: same caveat as set_var above; restores global state.
        unsafe {
            std::env::remove_var("CODETETHER_DATA_DIR");
        }

        assert_eq!(
            output_path.extension().and_then(|ext| ext.to_str()),
            Some("json")
        );
        assert!(payload.contains("\"knowledge_graph\""));
        assert!(payload.contains("\"symbol\""));
    }
}