use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicUsize, Ordering};

use rayon::prelude::*;

use crate::parser::{parse_file_symbols, Language};
use crate::security::SecurityScanner;
use crate::tokenizer::{TokenModel, Tokenizer};
use crate::types::Symbol;

use super::error::EmbedError;
use super::hasher::hash_content;
use super::hierarchy::{HierarchyBuilder, HierarchyConfig};
use super::limits::ResourceLimits;
use super::progress::ProgressReporter;
use super::types::{
    ChunkContext, ChunkKind, ChunkPart, ChunkSource, EmbedChunk, EmbedSettings, RepoIdentifier,
    Visibility,
};

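/// Deterministic, parallel chunker that turns a source repository into
/// embedding-ready [`EmbedChunk`]s, enforcing [`ResourceLimits`] along the
/// way and optionally scanning or redacting secrets before chunking.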
pub struct EmbedChunker {
    settings: EmbedSettings,
    limits: ResourceLimits,
    tokenizer: Tokenizer,
    security_scanner: Option<SecurityScanner>,
    repo_id: RepoIdentifier,
}

impl EmbedChunker {
    pub fn new(settings: EmbedSettings, limits: ResourceLimits) -> Self {
        let security_scanner = if settings.scan_secrets {
            Some(SecurityScanner::new())
        } else {
            None
        };

        Self {
            settings,
            limits,
            tokenizer: Tokenizer::new(),
            security_scanner,
            repo_id: RepoIdentifier::default(),
        }
    }

    pub fn with_defaults(settings: EmbedSettings) -> Self {
        Self::new(settings, ResourceLimits::default())
    }

    pub fn with_repo_id(mut self, repo_id: RepoIdentifier) -> Self {
        self.repo_id = repo_id;
        self
    }

    pub fn set_repo_id(&mut self, repo_id: RepoIdentifier) {
        self.repo_id = repo_id;
    }

    pub fn repo_id(&self) -> &RepoIdentifier {
        &self.repo_id
    }

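    /// Walk, parse, and chunk every supported file under `repo_path`,
    /// returning a deterministically sorted list of chunks.
    ///
    /// A minimal usage sketch (the `QuietProgress` reporter comes from this
    /// crate's `progress` module, as used in the tests below; error handling
    /// elided):
    ///
    /// ```ignore
    /// let chunker = EmbedChunker::with_defaults(EmbedSettings::default());
    /// let chunks = chunker.chunk_repository(Path::new("."), &QuietProgress)?;
    /// println!("{} chunks", chunks.len());
    /// ```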
    pub fn chunk_repository(
        &self,
        repo_path: &Path,
        progress: &dyn ProgressReporter,
    ) -> Result<Vec<EmbedChunk>, EmbedError> {
        let repo_root = self.validate_repo_path(repo_path)?;

        progress.set_phase("Scanning repository...");
        let mut files = self.discover_files(&repo_root)?;
        files.sort();
        progress.set_total(files.len());

        if files.is_empty() {
            return Err(EmbedError::NoChunksGenerated {
                include_patterns: "default".to_string(),
                exclude_patterns: "default".to_string(),
            });
        }

        if !self.limits.check_file_count(files.len()) {
            return Err(EmbedError::TooManyFiles {
                count: files.len(),
                max: self.limits.max_files,
            });
        }

        progress.set_phase("Parsing and chunking...");
        let chunk_count = AtomicUsize::new(0);
        let processed = AtomicUsize::new(0);

        let results: Vec<Result<Vec<EmbedChunk>, (PathBuf, EmbedError)>> = files
            .par_iter()
            .map(|file| {
                let result = self.chunk_file(file, &repo_root);

                let done = processed.fetch_add(1, Ordering::Relaxed) + 1;
                progress.set_progress(done);

                match result {
                    Ok(chunks) => {
                        let chunks_to_add = chunks.len();
                        // Reserve the new chunks against the global limit with
                        // a compare-exchange loop so concurrent workers never
                        // overshoot `max_total_chunks`.
                        loop {
                            let current = chunk_count.load(Ordering::Acquire);
                            let new_count = current + chunks_to_add;

                            if !self.limits.check_chunk_count(new_count) {
                                return Err((
                                    file.clone(),
                                    EmbedError::TooManyChunks {
                                        count: new_count,
                                        max: self.limits.max_total_chunks,
                                    },
                                ));
                            }

                            match chunk_count.compare_exchange(
                                current,
                                new_count,
                                Ordering::AcqRel,
                                Ordering::Acquire,
                            ) {
                                Ok(_) => break,
                                Err(_) => continue,
                            }
                        }

                        Ok(chunks)
                    }
                    Err(e) => Err((file.clone(), e)),
                }
            })
            .collect();

        let mut all_chunks = Vec::new();
        let mut errors = Vec::new();

        for result in results {
            match result {
                Ok(chunks) => all_chunks.extend(chunks),
                Err((path, err)) => errors.push((path, err)),
            }
        }

        if !errors.is_empty() {
            let critical: Vec<_> = errors
                .iter()
                .filter(|(_, e)| e.is_critical())
                .cloned()
                .collect();

            if !critical.is_empty() {
                return Err(EmbedError::from_file_errors(critical));
            }

            for (path, err) in &errors {
                if err.is_skippable() {
                    progress.warn(&format!("Skipped {}: {}", path.display(), err));
                }
            }
        }

        if all_chunks.is_empty() {
            return Err(EmbedError::NoChunksGenerated {
                include_patterns: "default".to_string(),
                exclude_patterns: "default".to_string(),
            });
        }

        progress.set_phase("Building call graph...");
        self.populate_called_by(&mut all_chunks);

        if self.settings.enable_hierarchy {
            progress.set_phase("Building hierarchy summaries...");
            let hierarchy_config = HierarchyConfig {
                min_children_for_summary: self.settings.hierarchy_min_children,
                ..Default::default()
            };
            let builder = HierarchyBuilder::with_config(hierarchy_config);

            builder.enrich_chunks(&mut all_chunks);

            let mut summaries = builder.build_hierarchy(&all_chunks);

            let token_model = self.parse_token_model(&self.settings.token_model);
            for summary in &mut summaries {
                summary.tokens = self.tokenizer.count(&summary.content, token_model);
            }

            all_chunks.extend(summaries);
        }

        progress.set_phase("Sorting chunks...");
        all_chunks.par_sort_by(|a, b| {
            a.source
                .file
                .cmp(&b.source.file)
                .then_with(|| a.source.lines.0.cmp(&b.source.lines.0))
                .then_with(|| a.source.lines.1.cmp(&b.source.lines.1))
                .then_with(|| a.source.symbol.cmp(&b.source.symbol))
                .then_with(|| a.id.cmp(&b.id))
        });

        progress.set_phase("Complete");
        Ok(all_chunks)
    }

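    /// Invert the per-chunk `calls` lists into `called_by` lists, matching
    /// callers by FQN first and by bare symbol name second. `BTreeMap` and
    /// `BTreeSet` keep the result deterministic across runs.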
    fn populate_called_by(&self, chunks: &mut [EmbedChunk]) {
        use std::collections::{BTreeMap, BTreeSet};

        let mut reverse_calls: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();

        for chunk in chunks.iter() {
            let caller_fqn = chunk.source.fqn.as_deref().unwrap_or(&chunk.source.symbol);
            for callee in &chunk.context.calls {
                reverse_calls
                    .entry(callee.clone())
                    .or_default()
                    .insert(caller_fqn.to_string());
            }
        }

        for chunk in chunks.iter_mut() {
            let fqn = chunk.source.fqn.as_deref().unwrap_or("");
            let symbol = &chunk.source.symbol;

            let mut called_by_set: BTreeSet<String> = BTreeSet::new();

            if let Some(callers) = reverse_calls.get(fqn) {
                called_by_set.extend(callers.iter().cloned());
            }

            if let Some(callers) = reverse_calls.get(symbol) {
                called_by_set.extend(callers.iter().cloned());
            }

            chunk.context.called_by = called_by_set.into_iter().collect();
        }
    }

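    /// Chunk a single file: enforce size and line-length limits, scan or
    /// redact secrets if configured, parse symbols, and emit one chunk per
    /// symbol, splitting any symbol that exceeds `max_tokens`.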
    fn chunk_file(&self, path: &Path, repo_root: &Path) -> Result<Vec<EmbedChunk>, EmbedError> {
        let metadata = std::fs::metadata(path)
            .map_err(|e| EmbedError::IoError { path: path.to_path_buf(), source: e })?;

        if !self.limits.check_file_size(metadata.len()) {
            return Err(EmbedError::FileTooLarge {
                path: path.to_path_buf(),
                size: metadata.len(),
                max: self.limits.max_file_size,
            });
        }

        let mut content = std::fs::read_to_string(path)
            .map_err(|e| EmbedError::IoError { path: path.to_path_buf(), source: e })?;

        if let Some(max_line_len) = content.lines().map(|l| l.len()).max() {
            if !self.limits.check_line_length(max_line_len) {
                return Err(EmbedError::LineTooLong {
                    path: path.to_path_buf(),
                    length: max_line_len,
                    max: self.limits.max_line_length,
                });
            }
        }

        let relative_path = self.safe_relative_path(path, repo_root)?;

        if let Some(ref scanner) = self.security_scanner {
            let findings = scanner.scan(&content, &relative_path);
            if !findings.is_empty() {
                if self.settings.fail_on_secrets {
                    let files = findings
                        .iter()
                        .map(|f| format!(" {}:{} - {}", f.file, f.line, f.kind.name()))
                        .collect::<Vec<_>>()
                        .join("\n");
                    return Err(EmbedError::SecretsDetected {
                        count: findings.len(),
                        files,
                    });
                }

                if self.settings.redact_secrets {
                    content = scanner.redact_content(&content, &relative_path);
                }
            }
        }

        let language = self.detect_language(path);

        let mut symbols = parse_file_symbols(&content, path);

        symbols.sort_by(|a, b| {
            a.start_line
                .cmp(&b.start_line)
                .then_with(|| a.end_line.cmp(&b.end_line))
                .then_with(|| a.name.cmp(&b.name))
        });

        let lines: Vec<&str> = content.lines().collect();
        let mut chunks = Vec::with_capacity(symbols.len() + 2);

        for symbol in &symbols {
            if !self.settings.include_imports
                && matches!(symbol.kind, crate::types::SymbolKind::Import)
            {
                continue;
            }

            let (chunk_content, start_line, end_line) =
                self.extract_symbol_content(&lines, symbol, self.settings.context_lines);

            let token_model = self.parse_token_model(&self.settings.token_model);
            let tokens = self.tokenizer.count(&chunk_content, token_model);

            if self.settings.max_tokens > 0 && tokens > self.settings.max_tokens {
                let split_chunks = self.split_large_symbol(
                    &chunk_content,
                    symbol,
                    &relative_path,
                    &language,
                    start_line,
                    0,
                )?;
                chunks.extend(split_chunks);
            } else {
                let hash = hash_content(&chunk_content);

                let context = self.extract_context(symbol, &chunk_content);

                let fqn = self.compute_fqn(&relative_path, symbol);

                chunks.push(EmbedChunk {
                    id: hash.short_id,
                    full_hash: hash.full_hash,
                    content: chunk_content,
                    tokens,
                    kind: symbol.kind.into(),
                    source: ChunkSource {
                        repo: self.repo_id.clone(),
                        file: relative_path.clone(),
                        lines: (start_line, end_line),
                        symbol: symbol.name.clone(),
                        fqn: Some(fqn),
                        language: language.clone(),
                        parent: symbol.parent.clone(),
                        visibility: symbol.visibility.into(),
                        is_test: self.is_test_code(path, symbol),
                    },
                    context,
                    part: None,
                });
            }
        }

        if self.settings.include_top_level && !symbols.is_empty() {
            if let Some(top_level) =
                self.extract_top_level(&lines, &symbols, &relative_path, &language)
            {
                chunks.push(top_level);
            }
        }

        Ok(chunks)
    }

    fn extract_symbol_content(
        &self,
        lines: &[&str],
        symbol: &Symbol,
        context_lines: u32,
    ) -> (String, u32, u32) {
        let start_line = symbol.start_line.saturating_sub(1) as usize;
        let end_line = (symbol.end_line as usize).min(lines.len());

        let context_start = start_line.saturating_sub(context_lines as usize);
        let context_end = (end_line + context_lines as usize).min(lines.len());

        let content = lines[context_start..context_end].join("\n");

        (content, (context_start + 1) as u32, context_end as u32)
    }

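    /// Split an oversized symbol into overlapping parts sized by a
    /// lines-per-token ratio. For example, a 300-line symbol measured at
    /// 3000 tokens with `max_tokens = 1000` yields a target of
    /// `300 * 1000 / 3000 = 100` lines per part; the configured
    /// `overlap_tokens` is converted to lines the same way, capped at half a
    /// part.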
    fn split_large_symbol(
        &self,
        content: &str,
        symbol: &Symbol,
        file: &str,
        language: &str,
        base_line: u32,
        depth: u32,
    ) -> Result<Vec<EmbedChunk>, EmbedError> {
        if !self.limits.check_recursion_depth(depth) {
            return Err(EmbedError::RecursionLimitExceeded {
                depth,
                max: self.limits.max_recursion_depth,
                context: format!("splitting symbol {}", symbol.name),
            });
        }

        let lines: Vec<&str> = content.lines().collect();
        let total_lines = lines.len();

        let token_model = self.parse_token_model(&self.settings.token_model);
        let total_tokens = self.tokenizer.count(content, token_model) as usize;
        let target_tokens = self.settings.max_tokens as usize;

        if total_tokens == 0 || target_tokens == 0 {
            return Ok(Vec::new());
        }

        let target_lines = ((total_lines * target_tokens) / total_tokens).max(1);

        let overlap_tokens = self.settings.overlap_tokens as usize;
        let overlap_lines = if overlap_tokens > 0 && total_tokens > 0 {
            ((total_lines * overlap_tokens) / total_tokens)
                .max(1)
                .min(target_lines / 2)
        } else {
            0
        };

        let mut chunks = Vec::new();
        let mut current_start = 0usize;
        let mut part_num = 1u32;

        let parent_hash = hash_content(content);

        while current_start < total_lines {
            let content_start = if part_num > 1 && overlap_lines > 0 {
                current_start.saturating_sub(overlap_lines)
            } else {
                current_start
            };
            let content_end = (current_start + target_lines).min(total_lines);

            let part_content = lines[content_start..content_end].join("\n");

            let tokens = self.tokenizer.count(&part_content, token_model);

            if tokens >= self.settings.min_tokens {
                let hash = hash_content(&part_content);

                let actual_overlap = if part_num > 1 {
                    current_start.saturating_sub(content_start) as u32
                } else {
                    0
                };

                chunks.push(EmbedChunk {
                    id: hash.short_id,
                    full_hash: hash.full_hash,
                    content: part_content,
                    tokens,
                    kind: ChunkKind::FunctionPart,
                    source: ChunkSource {
                        repo: self.repo_id.clone(),
                        file: file.to_string(),
                        lines: (
                            base_line + content_start as u32,
                            base_line + content_end as u32 - 1,
                        ),
                        symbol: format!("{}_part{}", symbol.name, part_num),
                        fqn: None,
                        language: language.to_string(),
                        parent: Some(symbol.name.clone()),
                        visibility: symbol.visibility.into(),
                        is_test: false,
                    },
                    context: ChunkContext {
                        signature: symbol.signature.clone(),
                        docstring: symbol.docstring.clone(),
                        ..Default::default()
                    },
                    part: Some(ChunkPart {
                        part: part_num,
                        of: 0, // patched with the real part count below
                        parent_id: parent_hash.short_id.clone(),
                        parent_signature: symbol.signature.clone().unwrap_or_default(),
                        overlap_lines: actual_overlap,
                    }),
                });

                part_num += 1;
            }

            current_start = content_end;
        }

        let total_parts = chunks.len() as u32;
        for chunk in &mut chunks {
            if let Some(ref mut part) = chunk.part {
                part.of = total_parts;
            }
        }

        Ok(chunks)
    }

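    /// Collect the lines not covered by any parsed symbol (module-level
    /// statements between symbols) into a single `<top_level>` chunk, or
    /// `None` if nothing substantial remains.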
    fn extract_top_level(
        &self,
        lines: &[&str],
        symbols: &[Symbol],
        file: &str,
        language: &str,
    ) -> Option<EmbedChunk> {
        if lines.is_empty() || symbols.is_empty() {
            return None;
        }

        let mut covered = vec![false; lines.len()];
        for symbol in symbols {
            let start = symbol.start_line.saturating_sub(1) as usize;
            let end = (symbol.end_line as usize).min(lines.len());
            for i in start..end {
                covered[i] = true;
            }
        }

        let top_level_lines: Vec<&str> = lines
            .iter()
            .enumerate()
            .filter(|(i, _)| !covered[*i])
            .map(|(_, line)| *line)
            .collect();

        if top_level_lines.is_empty() {
            return None;
        }

        let content = top_level_lines.join("\n").trim().to_string();
        if content.is_empty() {
            return None;
        }

        let token_model = self.parse_token_model(&self.settings.token_model);
        let tokens = self.tokenizer.count(&content, token_model);

        if tokens < self.settings.min_tokens {
            return None;
        }

        let hash = hash_content(&content);

        Some(EmbedChunk {
            id: hash.short_id,
            full_hash: hash.full_hash,
            content,
            tokens,
            kind: ChunkKind::TopLevel,
            source: ChunkSource {
                repo: self.repo_id.clone(),
                file: file.to_string(),
                lines: (1, lines.len() as u32),
                symbol: "<top_level>".to_string(),
                fqn: None,
                language: language.to_string(),
                parent: None,
                visibility: Visibility::Public,
                is_test: false,
            },
            context: ChunkContext::default(),
            part: None,
        })
    }

    fn extract_context(&self, symbol: &Symbol, content: &str) -> ChunkContext {
        ChunkContext {
            docstring: symbol.docstring.clone(),
            comments: Vec::new(),
            signature: symbol.signature.clone(),
            calls: symbol.calls.clone(),
            called_by: Vec::new(), // filled in later by `populate_called_by`
            imports: Vec::new(),
            tags: self.generate_tags(symbol),
            lines_of_code: self.count_lines_of_code(content),
            max_nesting_depth: self.calculate_nesting_depth(content),
        }
    }

    fn count_lines_of_code(&self, content: &str) -> u32 {
        content
            .lines()
            .filter(|line| {
                let trimmed = line.trim();
                !trimmed.is_empty()
                    && !trimmed.starts_with("//")
                    && !trimmed.starts_with('#')
                    && !trimmed.starts_with("/*")
                    && !trimmed.starts_with('*')
            })
            .count() as u32
    }

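    /// Estimate nesting depth, preferring brace counting but falling back to
    /// indentation when braces are scarce, since a brace depth of 0 or 1
    /// usually means the language does not delimit blocks with braces
    /// (e.g. Python).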
    fn calculate_nesting_depth(&self, content: &str) -> u32 {
        let brace_depth = self.calculate_brace_depth(content);

        if brace_depth <= 1 {
            let indent_depth = self.calculate_indent_depth(content);
            brace_depth.max(indent_depth)
        } else {
            brace_depth
        }
    }

    fn calculate_brace_depth(&self, content: &str) -> u32 {
        let mut max_depth = 0u32;
        let mut current_depth = 0i32;

        for ch in content.chars() {
            match ch {
                '{' | '(' | '[' => {
                    current_depth += 1;
                    max_depth = max_depth.max(current_depth as u32);
                }
                '}' | ')' | ']' => {
                    current_depth = (current_depth - 1).max(0);
                }
                _ => {}
            }
        }

        max_depth
    }

    fn calculate_indent_depth(&self, content: &str) -> u32 {
        let mut max_depth = 0u32;
        let mut base_indent: Option<usize> = None;

        for line in content.lines() {
            let trimmed = line.trim();
            if trimmed.is_empty() || trimmed.starts_with('#') || trimmed.starts_with("--") {
                continue;
            }

            let leading_spaces = line.len() - line.trim_start().len();

            if base_indent.is_none() {
                base_indent = Some(leading_spaces);
            }

            let base = base_indent.unwrap_or(0);
            if leading_spaces >= base {
                let relative_indent = leading_spaces - base;
                // Heuristic: treat roughly every two spaces of relative
                // indent as one nesting level.
                let depth = (relative_indent / 4).max(relative_indent / 2) as u32;
                max_depth = max_depth.max(depth + 1);
            }
        }

        max_depth
    }

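    /// Derive heuristic topic tags (async, concurrency, security, ml, ...)
    /// from a symbol's name and signature via case-insensitive substring
    /// matching. Purely lexical: no type information is consulted.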
    fn generate_tags(&self, symbol: &Symbol) -> Vec<String> {
        let mut tags = Vec::new();
        let signature = symbol.signature.as_deref().unwrap_or("");
        let name_lower = symbol.name.to_lowercase();

        if signature.contains("async")
            || signature.contains("await")
            || signature.contains("suspend")
        {
            tags.push("async".to_string());
        }
        if name_lower.contains("thread")
            || name_lower.contains("mutex")
            || name_lower.contains("lock")
            || name_lower.contains("spawn")
            || name_lower.contains("parallel")
            || name_lower.contains("goroutine")
            || name_lower.contains("channel")
            || signature.contains("Mutex")
            || signature.contains("RwLock")
            || signature.contains("Arc")
            || signature.contains("chan ")
            || signature.contains("<-chan")
            || signature.contains("chan<-")
            || signature.contains("sync.")
            || signature.contains("WaitGroup")
        {
            tags.push("concurrency".to_string());
        }

        if name_lower.contains("password")
            || name_lower.contains("token")
            || name_lower.contains("secret")
            || name_lower.contains("auth")
            || name_lower.contains("crypt")
            || name_lower.contains("hash")
            || name_lower.contains("permission")
            || signature.contains("password")
            || signature.contains("token")
            || signature.contains("secret")
        {
            tags.push("security".to_string());
        }

        if signature.contains("Error")
            || signature.contains("Result")
            || name_lower.contains("error")
            || name_lower.contains("exception")
            || name_lower.contains("panic")
            || name_lower.contains("unwrap")
        {
            tags.push("error-handling".to_string());
        }

        if name_lower.contains("query")
            || name_lower.contains("sql")
            || name_lower.contains("database")
            || name_lower.contains("db_")
            || name_lower.starts_with("db")
            || name_lower.contains("repository")
            || name_lower.contains("transaction")
        {
            tags.push("database".to_string());
        }

        if name_lower.contains("http")
            || name_lower.contains("request")
            || name_lower.contains("response")
            || name_lower.contains("endpoint")
            || name_lower.contains("route")
            || name_lower.contains("handler")
            || name_lower.contains("middleware")
        {
            tags.push("http".to_string());
        }

        if name_lower.contains("command")
            || name_lower.contains("cli")
            || name_lower.contains("arg")
            || name_lower.contains("flag")
            || name_lower.contains("option")
            || name_lower.contains("subcommand")
        {
            tags.push("cli".to_string());
        }

        if name_lower.contains("config")
            || name_lower.contains("setting")
            || name_lower.contains("preference")
            || name_lower.contains("option")
            || name_lower.contains("env")
        {
            tags.push("config".to_string());
        }

        if name_lower.contains("log")
            || name_lower.contains("trace")
            || name_lower.contains("debug")
            || name_lower.contains("warn")
            || name_lower.contains("info")
            || name_lower.contains("metric")
        {
            tags.push("logging".to_string());
        }

        if name_lower.contains("cache")
            || name_lower.contains("memoize")
            || name_lower.contains("invalidate")
        {
            tags.push("cache".to_string());
        }

        if name_lower.contains("valid")
            || name_lower.contains("check")
            || name_lower.contains("verify")
            || name_lower.contains("assert")
            || name_lower.contains("sanitize")
        {
            tags.push("validation".to_string());
        }

        if name_lower.contains("serial")
            || name_lower.contains("deserial")
            || name_lower.contains("json")
            || name_lower.contains("xml")
            || name_lower.contains("yaml")
            || name_lower.contains("toml")
            || name_lower.contains("encode")
            || name_lower.contains("decode")
            || name_lower.contains("parse")
            || name_lower.contains("format")
        {
            tags.push("serialization".to_string());
        }

        if name_lower.contains("file")
            || name_lower.contains("read")
            || name_lower.contains("write")
            || name_lower.contains("path")
            || name_lower.contains("dir")
            || name_lower.contains("fs")
            || name_lower.contains("io")
        {
            tags.push("io".to_string());
        }

        if name_lower.contains("socket")
            || name_lower.contains("connect")
            || name_lower.contains("network")
            || name_lower.contains("tcp")
            || name_lower.contains("udp")
            || name_lower.contains("client")
            || name_lower.contains("server")
        {
            tags.push("network".to_string());
        }

        if name_lower == "new"
            || name_lower == "init"
            || name_lower == "setup"
            || name_lower == "create"
            || name_lower.starts_with("new_")
            || name_lower.starts_with("init_")
            || name_lower.starts_with("create_")
            || name_lower.ends_with("_new")
        {
            tags.push("init".to_string());
        }

        if name_lower.contains("cleanup")
            || name_lower.contains("teardown")
            || name_lower.contains("close")
            || name_lower.contains("dispose")
            || name_lower.contains("shutdown")
            || name_lower == "drop"
        {
            tags.push("cleanup".to_string());
        }

        if symbol.name.starts_with("test_")
            || symbol.name.ends_with("_test")
            || symbol.name.contains("Test")
            || name_lower.contains("mock")
            || name_lower.contains("stub")
            || name_lower.contains("fixture")
        {
            tags.push("test".to_string());
        }

        if signature.contains("deprecated") || signature.contains("Deprecated") {
            tags.push("deprecated".to_string());
        }

        if signature.starts_with("pub fn")
            || signature.starts_with("pub async fn")
            || signature.starts_with("export")
        {
            tags.push("public-api".to_string());
        }

        if name_lower.contains("model")
            || name_lower.contains("train")
            || name_lower.contains("predict")
            || name_lower.contains("inference")
            || name_lower.contains("neural")
            || name_lower.contains("embedding")
            || name_lower.contains("classifier")
            || name_lower.contains("regressor")
            || name_lower.contains("optimizer")
            || name_lower.contains("loss")
            || name_lower.contains("gradient")
            || name_lower.contains("backprop")
            || name_lower.contains("forward")
            || name_lower.contains("layer")
            || name_lower.contains("activation")
            || name_lower.contains("weight")
            || name_lower.contains("bias")
            || name_lower.contains("epoch")
            || name_lower.contains("batch")
            || signature.contains("torch")
            || signature.contains("tensorflow")
            || signature.contains("keras")
            || signature.contains("sklearn")
            || signature.contains("nn.")
            || signature.contains("nn::")
        {
            tags.push("ml".to_string());
        }

        if name_lower.contains("dataframe")
            || name_lower.contains("dataset")
            || name_lower.contains("tensor")
            || name_lower.contains("numpy")
            || name_lower.contains("pandas")
            || name_lower.contains("array")
            || name_lower.contains("matrix")
            || name_lower.contains("vector")
            || name_lower.contains("feature")
            || name_lower.contains("preprocess")
            || name_lower.contains("normalize")
            || name_lower.contains("transform")
            || name_lower.contains("pipeline")
            || name_lower.contains("etl")
            || name_lower.contains("aggregate")
            || name_lower.contains("groupby")
            || name_lower.contains("pivot")
            || signature.contains("pd.")
            || signature.contains("np.")
            || signature.contains("DataFrame")
            || signature.contains("ndarray")
        {
            tags.push("data-science".to_string());
        }

        tags
    }

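    /// Build a `::`-separated fully qualified name from the relative file
    /// path (extension stripped, separators normalized) plus any parent
    /// symbol. For example, `src/auth/login.rs` with parent `Session` and
    /// symbol `refresh` becomes `src::auth::login::Session::refresh`.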
    fn compute_fqn(&self, file: &str, symbol: &Symbol) -> String {
        let module_path = file
            .strip_suffix(".rs")
            .or_else(|| file.strip_suffix(".py"))
            .or_else(|| file.strip_suffix(".ts"))
            .or_else(|| file.strip_suffix(".tsx"))
            .or_else(|| file.strip_suffix(".js"))
            .or_else(|| file.strip_suffix(".jsx"))
            .or_else(|| file.strip_suffix(".go"))
            .or_else(|| file.strip_suffix(".java"))
            .or_else(|| file.strip_suffix(".c"))
            .or_else(|| file.strip_suffix(".cpp"))
            .or_else(|| file.strip_suffix(".h"))
            .or_else(|| file.strip_suffix(".hpp"))
            .or_else(|| file.strip_suffix(".rb"))
            .or_else(|| file.strip_suffix(".php"))
            .or_else(|| file.strip_suffix(".cs"))
            .or_else(|| file.strip_suffix(".swift"))
            .or_else(|| file.strip_suffix(".kt"))
            .or_else(|| file.strip_suffix(".scala"))
            .unwrap_or(file)
            .replace('\\', "::")
            .replace('/', "::");

        if let Some(ref parent) = symbol.parent {
            format!("{}::{}::{}", module_path, parent, symbol.name)
        } else {
            format!("{}::{}", module_path, symbol.name)
        }
    }

    fn is_test_code(&self, path: &Path, symbol: &Symbol) -> bool {
        let path_str = path.to_string_lossy().to_lowercase();

        if path_str.contains("test")
            || path_str.contains("spec")
            || path_str.contains("__tests__")
        {
            return true;
        }

        let name = symbol.name.to_lowercase();
        if name.starts_with("test_") || name.ends_with("_test") || name.contains("_test_") {
            return true;
        }

        false
    }

    fn validate_repo_path(&self, path: &Path) -> Result<PathBuf, EmbedError> {
        let canonical = path
            .canonicalize()
            .map_err(|e| EmbedError::IoError { path: path.to_path_buf(), source: e })?;

        if !canonical.is_dir() {
            return Err(EmbedError::NotADirectory {
                path: path.to_path_buf(),
            });
        }

        Ok(canonical)
    }

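    /// Canonicalize `path` and verify it stays inside `repo_root`, rejecting
    /// symlink or `..` escapes, then return the repo-relative path with
    /// forward slashes so chunk metadata is platform-independent.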
    fn safe_relative_path(&self, path: &Path, repo_root: &Path) -> Result<String, EmbedError> {
        let canonical = path
            .canonicalize()
            .map_err(|e| EmbedError::IoError { path: path.to_path_buf(), source: e })?;

        if !canonical.starts_with(repo_root) {
            return Err(EmbedError::PathTraversal {
                path: canonical,
                repo_root: repo_root.to_path_buf(),
            });
        }

        Ok(canonical
            .strip_prefix(repo_root)
            .unwrap_or(&canonical)
            .to_string_lossy()
            .replace('\\', "/"))
    }

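    /// Enumerate candidate files under `repo_root`, honoring gitignore rules,
    /// the configured include/exclude glob patterns, the `include_tests`
    /// setting, and the set of extensions the parser understands.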
    fn discover_files(&self, repo_root: &Path) -> Result<Vec<PathBuf>, EmbedError> {
        use glob::Pattern;
        use ignore::WalkBuilder;

        let mut files = Vec::new();

        let mut include_patterns = Vec::new();
        for pattern_str in &self.settings.include_patterns {
            match Pattern::new(pattern_str) {
                Ok(pattern) => include_patterns.push(pattern),
                Err(e) => {
                    return Err(EmbedError::InvalidPattern {
                        pattern: pattern_str.clone(),
                        reason: e.to_string(),
                    });
                }
            }
        }

        let mut exclude_patterns = Vec::new();
        for pattern_str in &self.settings.exclude_patterns {
            match Pattern::new(pattern_str) {
                Ok(pattern) => exclude_patterns.push(pattern),
                Err(e) => {
                    return Err(EmbedError::InvalidPattern {
                        pattern: pattern_str.clone(),
                        reason: e.to_string(),
                    });
                }
            }
        }

        let walker = WalkBuilder::new(repo_root)
            .hidden(false)
            .git_ignore(true)
            .git_global(true)
            .git_exclude(true)
            .follow_links(false)
            .build();

        for entry in walker {
            let entry = entry.map_err(|e| EmbedError::IoError {
                path: repo_root.to_path_buf(),
                source: std::io::Error::new(std::io::ErrorKind::Other, e.to_string()),
            })?;

            let path = entry.path();

            if !path.is_file() {
                continue;
            }

            let relative_path = path
                .strip_prefix(repo_root)
                .unwrap_or(path)
                .to_string_lossy();

            if !include_patterns.is_empty()
                && !include_patterns
                    .iter()
                    .any(|p| p.matches(&relative_path))
            {
                continue;
            }

            if exclude_patterns.iter().any(|p| p.matches(&relative_path)) {
                continue;
            }

            if !self.settings.include_tests && self.is_test_file(path) {
                continue;
            }

            let ext = match path.extension().and_then(|e| e.to_str()) {
                Some(e) => e,
                None => continue,
            };
            if Language::from_extension(ext).is_none() {
                continue;
            }

            files.push(path.to_path_buf());
        }

        Ok(files)
    }

    fn is_test_file(&self, path: &Path) -> bool {
        let path_str = path.to_string_lossy().to_lowercase();

        if path_str.contains("/tests/")
            || path_str.contains("\\tests\\")
            || path_str.contains("/test/")
            || path_str.contains("\\test\\")
            || path_str.contains("/__tests__/")
            || path_str.contains("\\__tests__\\")
            || path_str.contains("/spec/")
            || path_str.contains("\\spec\\")
        {
            return true;
        }

        let filename = path
            .file_name()
            .and_then(|n| n.to_str())
            .unwrap_or("")
            .to_lowercase();

        filename.starts_with("test_")
            || filename.ends_with("_test.rs")
            || filename.ends_with("_test.py")
            || filename.ends_with("_test.go")
            || filename.ends_with(".test.ts")
            || filename.ends_with(".test.js")
            || filename.ends_with(".test.tsx")
            || filename.ends_with(".test.jsx")
            || filename.ends_with(".spec.ts")
            || filename.ends_with(".spec.js")
            || filename.ends_with("_spec.rb")
    }

    fn detect_language(&self, path: &Path) -> String {
        path.extension()
            .and_then(|e| e.to_str())
            .and_then(Language::from_extension)
            .map(|l| l.display_name().to_string())
            .unwrap_or_else(|| "unknown".to_string())
    }

    fn parse_token_model(&self, model: &str) -> TokenModel {
        TokenModel::from_model_name(model).unwrap_or(TokenModel::Claude)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::embedding::progress::QuietProgress;
    use tempfile::TempDir;

    fn create_test_file(dir: &Path, name: &str, content: &str) {
        let path = dir.join(name);
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).unwrap();
        }
        std::fs::write(path, content).unwrap();
    }

    #[test]
    fn test_chunker_creation() {
        let settings = EmbedSettings::default();
        let limits = ResourceLimits::default();
        let chunker = EmbedChunker::new(settings, limits);
        assert!(chunker.settings.max_tokens > 0);
    }

    #[test]
    fn test_chunk_single_file() {
        let temp_dir = TempDir::new().unwrap();
        let rust_code = r#"
/// A test function
fn hello() {
    println!("Hello, world!");
}

fn goodbye() {
    println!("Goodbye!");
}
"#;
        create_test_file(temp_dir.path(), "test.rs", rust_code);

        let settings = EmbedSettings::default();
        let chunker = EmbedChunker::with_defaults(settings);
        let progress = QuietProgress;

        let chunks = chunker.chunk_repository(temp_dir.path(), &progress).unwrap();

        assert!(!chunks.is_empty());

        for i in 1..chunks.len() {
            assert!(chunks[i - 1].source.file <= chunks[i].source.file);
        }
    }

    #[test]
    fn test_determinism() {
        let temp_dir = TempDir::new().unwrap();
        create_test_file(temp_dir.path(), "a.rs", "fn foo() {}");
        create_test_file(temp_dir.path(), "b.rs", "fn bar() {}");

        let settings = EmbedSettings::default();
        let progress = QuietProgress;

        let results: Vec<Vec<EmbedChunk>> = (0..3)
            .map(|_| {
                let chunker = EmbedChunker::with_defaults(settings.clone());
                chunker.chunk_repository(temp_dir.path(), &progress).unwrap()
            })
            .collect();

        for i in 1..results.len() {
            assert_eq!(results[0].len(), results[i].len());
            for j in 0..results[0].len() {
                assert_eq!(results[0][j].id, results[i][j].id);
            }
        }
    }

    #[test]
    fn test_file_too_large() {
        let temp_dir = TempDir::new().unwrap();
        let large_content = "x".repeat(200);
        create_test_file(temp_dir.path(), "large.rs", &large_content);

        let settings = EmbedSettings::default();
        let limits = ResourceLimits::default().with_max_file_size(100);
        let chunker = EmbedChunker::new(settings, limits);
        let progress = QuietProgress;

        let result = chunker.chunk_repository(temp_dir.path(), &progress);

        assert!(result.is_err());
    }

    #[test]
    fn test_empty_directory() {
        let temp_dir = TempDir::new().unwrap();

        let settings = EmbedSettings::default();
        let chunker = EmbedChunker::with_defaults(settings);
        let progress = QuietProgress;

        let result = chunker.chunk_repository(temp_dir.path(), &progress);

        assert!(matches!(result, Err(EmbedError::NoChunksGenerated { .. })));
    }

    #[test]
    fn test_language_detection() {
        let chunker = EmbedChunker::with_defaults(EmbedSettings::default());

        assert_eq!(chunker.detect_language(Path::new("test.rs")), "Rust");
        assert_eq!(chunker.detect_language(Path::new("test.py")), "Python");
        assert_eq!(chunker.detect_language(Path::new("test.unknown")), "unknown");
    }

    #[test]
    fn test_is_test_code() {
        let chunker = EmbedChunker::with_defaults(EmbedSettings::default());

        let test_symbol = Symbol::new("test_foo", crate::types::SymbolKind::Function);
        assert!(chunker.is_test_code(Path::new("foo.rs"), &test_symbol));

        let normal_symbol = Symbol::new("foo", crate::types::SymbolKind::Function);
        assert!(!chunker.is_test_code(Path::new("src/lib.rs"), &normal_symbol));

        assert!(chunker.is_test_code(Path::new("tests/test_foo.rs"), &normal_symbol));
    }

1470
1471 #[test]
1472 fn test_generate_tags() {
1473 let chunker = EmbedChunker::with_defaults(EmbedSettings::default());
1474
1475 let mut symbol = Symbol::new("authenticate_user", crate::types::SymbolKind::Function);
1476 symbol.signature = Some("async fn authenticate_user(password: &str)".to_string());
1477
1478 let tags = chunker.generate_tags(&symbol);
1479 assert!(tags.contains(&"async".to_string()));
1480 assert!(tags.contains(&"security".to_string()));
1481 }
1482
1483 #[test]
1484 fn test_generate_tags_kotlin_suspend() {
1485 let chunker = EmbedChunker::with_defaults(EmbedSettings::default());
1486
1487 let mut symbol = Symbol::new("fetchData", crate::types::SymbolKind::Function);
1488 symbol.signature = Some("suspend fun fetchData(): Result<Data>".to_string());
1489
1490 let tags = chunker.generate_tags(&symbol);
1491 assert!(tags.contains(&"async".to_string()), "Kotlin suspend should be tagged as async");
1492 }
1493
1494 #[test]
1495 fn test_generate_tags_go_concurrency() {
1496 let chunker = EmbedChunker::with_defaults(EmbedSettings::default());
1497
1498 let mut symbol = Symbol::new("processMessages", crate::types::SymbolKind::Function);
1499 symbol.signature = Some("func processMessages(ch chan string)".to_string());
1500
1501 let tags = chunker.generate_tags(&symbol);
1502 assert!(tags.contains(&"concurrency".to_string()), "Go channels should be tagged as concurrency");
1503 }
1504
1505 #[test]
1506 fn test_generate_tags_ml() {
1507 let chunker = EmbedChunker::with_defaults(EmbedSettings::default());
1508
1509 let mut symbol = Symbol::new("train_model", crate::types::SymbolKind::Function);
1511 symbol.signature = Some("def train_model(epochs: int, batch_size: int)".to_string());
1512 let tags = chunker.generate_tags(&symbol);
1513 assert!(tags.contains(&"ml".to_string()), "train_model should be tagged as ml");
1514
1515 let mut symbol2 = Symbol::new("forward_pass", crate::types::SymbolKind::Function);
1517 symbol2.signature = Some("def forward_pass(self, x: torch.Tensor)".to_string());
1518 let tags2 = chunker.generate_tags(&symbol2);
1519 assert!(tags2.contains(&"ml".to_string()), "torch.Tensor in signature should be tagged as ml");
1520
1521 let mut symbol3 = Symbol::new("ImageClassifier", crate::types::SymbolKind::Class);
1523 symbol3.signature = Some("class ImageClassifier(nn.Module)".to_string());
1524 let tags3 = chunker.generate_tags(&symbol3);
1525 assert!(tags3.contains(&"ml".to_string()), "nn.Module should be tagged as ml");
1526 }
1527
1528 #[test]
1529 fn test_generate_tags_data_science() {
1530 let chunker = EmbedChunker::with_defaults(EmbedSettings::default());
1531
1532 let mut symbol = Symbol::new("preprocess_dataframe", crate::types::SymbolKind::Function);
1534 symbol.signature = Some("def preprocess_dataframe(df: pd.DataFrame)".to_string());
1535 let tags = chunker.generate_tags(&symbol);
1536 assert!(tags.contains(&"data-science".to_string()), "DataFrame should be tagged as data-science");
1537
1538 let mut symbol2 = Symbol::new("normalize_array", crate::types::SymbolKind::Function);
1540 symbol2.signature = Some("def normalize_array(arr: np.ndarray)".to_string());
1541 let tags2 = chunker.generate_tags(&symbol2);
1542 assert!(tags2.contains(&"data-science".to_string()), "np.ndarray should be tagged as data-science");
1543
1544 let symbol3 = Symbol::new("run_etl_pipeline", crate::types::SymbolKind::Function);
1546 let tags3 = chunker.generate_tags(&symbol3);
1547 assert!(tags3.contains(&"data-science".to_string()), "etl should be tagged as data-science");
1548 }
1549
1550 #[test]
1551 fn test_brace_nesting_depth() {
1552 let chunker = EmbedChunker::with_defaults(EmbedSettings::default());
1553
1554 let code = "fn foo() { if x { if y { } } }";
1556 assert_eq!(chunker.calculate_brace_depth(code), 3);
1557
1558 let flat = "let x = 1;";
1560 assert_eq!(chunker.calculate_brace_depth(flat), 0);
1561
1562 let deep = "fn f() { let a = vec![HashMap::new()]; }";
1564 assert!(chunker.calculate_brace_depth(deep) >= 2);
1565 }
1566
1567 #[test]
1568 fn test_indent_nesting_depth() {
1569 let chunker = EmbedChunker::with_defaults(EmbedSettings::default());
1570
1571 let python_code = r#"
1573def foo():
1574 if x:
1575 if y:
1576 do_something()
1577 else:
1578 other()
1579"#;
1580 let depth = chunker.calculate_indent_depth(python_code);
1581 assert!(depth >= 3, "Should detect indentation nesting, got {}", depth);
1582
1583 let flat = "x = 1\ny = 2\n";
1585 assert!(chunker.calculate_indent_depth(flat) <= 1);
1586 }
1587
1588 #[test]
1589 fn test_combined_nesting_depth() {
1590 let chunker = EmbedChunker::with_defaults(EmbedSettings::default());
1591
1592 let rust_code = "fn foo() { if x { match y { A => {}, B => {} } } }";
1594 let depth = chunker.calculate_nesting_depth(rust_code);
1595 assert!(depth >= 3, "Should use brace depth for Rust-like code");
1596
1597 let python_code = "def foo():\n if x:\n y()\n";
1599 let depth = chunker.calculate_nesting_depth(python_code);
1600 assert!(depth >= 1, "Should use indent depth for Python-like code");
1601 }
1602
1603 #[test]
1604 fn test_lines_of_code() {
1605 let chunker = EmbedChunker::with_defaults(EmbedSettings::default());
1606
1607 let code = r#"
1608// This is a comment
1609fn foo() {
1610 let x = 1;
1611
1612 // Another comment
1613 let y = 2;
1614}
1615"#;
1616 let loc = chunker.count_lines_of_code(code);
1617 assert!(loc >= 4 && loc <= 5, "LOC should be ~4, got {}", loc);
1620 }
1621
1622 #[test]
1623 fn test_line_too_long_error() {
1624 let temp_dir = TempDir::new().unwrap();
1625
1626 let long_line = "x".repeat(50_000);
1628 let content = format!("fn foo() {{ {} }}", long_line);
1629 create_test_file(temp_dir.path(), "minified.rs", &content);
1630
1631 let settings = EmbedSettings::default();
1632 let limits = ResourceLimits::default().with_max_line_length(10_000);
1634 let chunker = EmbedChunker::new(settings, limits);
1635 let progress = QuietProgress;
1636
1637 let result = chunker.chunk_repository(temp_dir.path(), &progress);
1638
1639 assert!(result.is_err(), "Should reject files with very long lines");
1641 }
1642
1643 #[test]
1644 fn test_hierarchical_chunking_integration() {
1645 let temp_dir = TempDir::new().unwrap();
1646
1647 let rust_code = r#"
1649/// A user account
1650pub struct User {
1651 pub name: String,
1652 pub email: String,
1653}
1654
1655impl User {
1656 /// Create a new user
1657 pub fn new(name: String, email: String) -> Self {
1658 Self { name, email }
1659 }
1660
1661 /// Get the user's display name
1662 pub fn display_name(&self) -> &str {
1663 &self.name
1664 }
1665
1666 /// Validate the user's email
1667 pub fn validate_email(&self) -> bool {
1668 self.email.contains('@')
1669 }
1670}
1671"#;
1672 create_test_file(temp_dir.path(), "user.rs", rust_code);
1673
1674 let settings_no_hierarchy = EmbedSettings {
1676 enable_hierarchy: false,
1677 ..Default::default()
1678 };
1679 let chunker_no_hierarchy = EmbedChunker::with_defaults(settings_no_hierarchy);
1680 let progress = QuietProgress;
1681 let chunks_no_hierarchy = chunker_no_hierarchy
1682 .chunk_repository(temp_dir.path(), &progress)
1683 .unwrap();
1684
1685 let settings_with_hierarchy = EmbedSettings {
1687 enable_hierarchy: true,
1688 hierarchy_min_children: 2,
1689 ..Default::default()
1690 };
1691 let chunker_with_hierarchy = EmbedChunker::with_defaults(settings_with_hierarchy);
1692 let chunks_with_hierarchy = chunker_with_hierarchy
1693 .chunk_repository(temp_dir.path(), &progress)
1694 .unwrap();
1695
1696 assert!(
1698 chunks_with_hierarchy.len() >= chunks_no_hierarchy.len(),
1699 "Hierarchy should produce at least as many chunks: {} vs {}",
1700 chunks_with_hierarchy.len(),
1701 chunks_no_hierarchy.len()
1702 );
1703
1704 let summary_chunks: Vec<_> = chunks_with_hierarchy
1706 .iter()
1707 .filter(|c| matches!(c.kind, ChunkKind::Module)) .collect();
1709
1710 if !summary_chunks.is_empty() {
1713 for summary in &summary_chunks {
1715 assert!(
1716 !summary.content.is_empty(),
1717 "Summary chunk should have content"
1718 );
1719 }
1720 }
1721
1722 let chunks_with_hierarchy_2 = chunker_with_hierarchy
1724 .chunk_repository(temp_dir.path(), &progress)
1725 .unwrap();
1726 assert_eq!(
1727 chunks_with_hierarchy.len(),
1728 chunks_with_hierarchy_2.len(),
1729 "Hierarchical chunking should be deterministic"
1730 );
1731 for (c1, c2) in chunks_with_hierarchy.iter().zip(chunks_with_hierarchy_2.iter()) {
1732 assert_eq!(c1.id, c2.id, "Chunk IDs should be identical across runs");
1733 }
1734 }
1735}