use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicUsize, Ordering};

use rayon::prelude::*;

use crate::parser::{parse_file_symbols, Language};
use crate::security::SecurityScanner;
use crate::tokenizer::{TokenModel, Tokenizer};
use crate::types::Symbol;

use super::error::EmbedError;
use super::hasher::hash_content;
use super::hierarchy::{HierarchyBuilder, HierarchyConfig};
use super::limits::ResourceLimits;
use super::progress::ProgressReporter;
use super::types::{
    ChunkContext, ChunkKind, ChunkPart, ChunkSource, EmbedChunk, EmbedSettings, RepoIdentifier,
    Visibility,
};

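/// Chunks a repository into embedding-ready pieces: discovers source files,
/// parses symbols in parallel, splits oversized symbols into parts, and
/// enriches every chunk with call-graph, tag, and hierarchy metadata.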
pub struct EmbedChunker {
    settings: EmbedSettings,
    limits: ResourceLimits,
    tokenizer: Tokenizer,
    security_scanner: Option<SecurityScanner>,
    repo_id: RepoIdentifier,
}

impl EmbedChunker {
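    /// Creates a chunker with explicit settings and resource limits.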
    pub fn new(settings: EmbedSettings, limits: ResourceLimits) -> Self {
        // A secret scanner is only constructed when secret scanning is enabled.
        let security_scanner = if settings.scan_secrets {
            Some(SecurityScanner::new())
        } else {
            None
        };

        Self {
            settings,
            limits,
            tokenizer: Tokenizer::new(),
            security_scanner,
            repo_id: RepoIdentifier::default(),
        }
    }

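    /// Creates a chunker with the given settings and default resource limits.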
    pub fn with_defaults(settings: EmbedSettings) -> Self {
        Self::new(settings, ResourceLimits::default())
    }

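    /// Builder-style setter for the repository identifier stamped onto every chunk.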
    pub fn with_repo_id(mut self, repo_id: RepoIdentifier) -> Self {
        self.repo_id = repo_id;
        self
    }

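    /// Replaces the repository identifier in place.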
    pub fn set_repo_id(&mut self, repo_id: RepoIdentifier) {
        self.repo_id = repo_id;
    }

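    /// Returns the repository identifier attached to generated chunks.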
    pub fn repo_id(&self) -> &RepoIdentifier {
        &self.repo_id
    }

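    /// Chunks every supported source file under `repo_path`.
    ///
    /// Files are discovered with ignore rules and include/exclude patterns,
    /// parsed and chunked in parallel, enriched with call-graph data and
    /// optional hierarchy summaries, and finally sorted deterministically.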
    pub fn chunk_repository(
        &self,
        repo_path: &Path,
        progress: &dyn ProgressReporter,
    ) -> Result<Vec<EmbedChunk>, EmbedError> {
        let repo_root = self.validate_repo_path(repo_path)?;

        progress.set_phase("Scanning repository...");
        let mut files = self.discover_files(&repo_root)?;
        // Sort for a deterministic processing order.
        files.sort();
        progress.set_total(files.len());

        if files.is_empty() {
            return Err(EmbedError::NoChunksGenerated {
                include_patterns: "default".to_owned(),
                exclude_patterns: "default".to_owned(),
            });
        }

        if !self.limits.check_file_count(files.len()) {
            return Err(EmbedError::TooManyFiles {
                count: files.len(),
                max: self.limits.max_files,
            });
        }

        progress.set_phase("Parsing and chunking...");
        let chunk_count = AtomicUsize::new(0);
        let processed = AtomicUsize::new(0);

        let results: Vec<Result<Vec<EmbedChunk>, (PathBuf, EmbedError)>> = files
            .par_iter()
            .map(|file| {
                let result = self.chunk_file(file, &repo_root);

                let done = processed.fetch_add(1, Ordering::Relaxed) + 1;
                progress.set_progress(done);

                match result {
                    Ok(chunks) => {
                        let chunks_to_add = chunks.len();
                        // Reserve the new chunks against the global limit with a
                        // compare-exchange loop so parallel workers never overshoot.
                        loop {
                            let current = chunk_count.load(Ordering::Acquire);
                            let new_count = current + chunks_to_add;

                            if !self.limits.check_chunk_count(new_count) {
                                return Err((
                                    file.clone(),
                                    EmbedError::TooManyChunks {
                                        count: new_count,
                                        max: self.limits.max_total_chunks,
                                    },
                                ));
                            }

                            match chunk_count.compare_exchange(
                                current,
                                new_count,
                                Ordering::AcqRel,
                                Ordering::Acquire,
                            ) {
                                Ok(_) => break,
                                Err(_) => continue,
                            }
                        }

                        Ok(chunks)
                    },
                    Err(e) => Err((file.clone(), e)),
                }
            })
            .collect();

        let mut all_chunks = Vec::new();
        let mut errors = Vec::new();

        for result in results {
            match result {
                Ok(chunks) => all_chunks.extend(chunks),
                Err((path, err)) => errors.push((path, err)),
            }
        }

        if !errors.is_empty() {
            let critical: Vec<_> = errors
                .iter()
                .filter(|(_, e)| e.is_critical())
                .cloned()
                .collect();

            if !critical.is_empty() {
                return Err(EmbedError::from_file_errors(critical));
            }

            // Non-critical errors only skip the offending file.
            for (path, err) in &errors {
                if err.is_skippable() {
                    progress.warn(&format!("Skipped {}: {}", path.display(), err));
                }
            }
        }

        if all_chunks.is_empty() {
            return Err(EmbedError::NoChunksGenerated {
                include_patterns: "default".to_owned(),
                exclude_patterns: "default".to_owned(),
            });
        }

        progress.set_phase("Building call graph...");
        self.populate_called_by(&mut all_chunks);

        if self.settings.enable_hierarchy {
            progress.set_phase("Building hierarchy summaries...");
            let hierarchy_config = HierarchyConfig {
                min_children_for_summary: self.settings.hierarchy_min_children,
                ..Default::default()
            };
            let builder = HierarchyBuilder::with_config(hierarchy_config);

            builder.enrich_chunks(&mut all_chunks);

            let mut summaries = builder.build_hierarchy(&all_chunks);

            let token_model = self.parse_token_model(&self.settings.token_model);
            for summary in &mut summaries {
                summary.tokens = self.tokenizer.count(&summary.content, token_model);
            }

            all_chunks.extend(summaries);
        }

        progress.set_phase("Sorting chunks...");
        // A total order over (file, start line, end line, symbol, id) keeps
        // output deterministic across runs and thread schedules.
        all_chunks.par_sort_by(|a, b| {
            a.source
                .file
                .cmp(&b.source.file)
                .then_with(|| a.source.lines.0.cmp(&b.source.lines.0))
                .then_with(|| a.source.lines.1.cmp(&b.source.lines.1))
                .then_with(|| a.source.symbol.cmp(&b.source.symbol))
                .then_with(|| a.id.cmp(&b.id))
        });

        progress.set_phase("Complete");
        Ok(all_chunks)
    }

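    /// Populates `context.called_by` by inverting the `context.calls` edges,
    /// matching callees by fully qualified name or bare symbol name.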
    fn populate_called_by(&self, chunks: &mut [EmbedChunk]) {
        use std::collections::{BTreeMap, BTreeSet};

        // callee -> set of callers; BTree collections keep the order deterministic.
        let mut reverse_calls: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();

        for chunk in chunks.iter() {
            let caller_fqn = chunk.source.fqn.as_deref().unwrap_or(&chunk.source.symbol);
            for callee in &chunk.context.calls {
                reverse_calls
                    .entry(callee.clone())
                    .or_default()
                    .insert(caller_fqn.to_owned());
            }
        }

        for chunk in chunks.iter_mut() {
            let fqn = chunk.source.fqn.as_deref().unwrap_or("");
            let symbol = &chunk.source.symbol;

            let mut called_by_set: BTreeSet<String> = BTreeSet::new();

            if let Some(callers) = reverse_calls.get(fqn) {
                called_by_set.extend(callers.iter().cloned());
            }

            if let Some(callers) = reverse_calls.get(symbol) {
                called_by_set.extend(callers.iter().cloned());
            }

            chunk.context.called_by = called_by_set.into_iter().collect();
        }
    }

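    /// Chunks a single file: enforces size and line-length limits, optionally
    /// scans for (and redacts) secrets, parses symbols, and emits one chunk
    /// per symbol, splitting any symbol that exceeds the token budget.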
    fn chunk_file(&self, path: &Path, repo_root: &Path) -> Result<Vec<EmbedChunk>, EmbedError> {
        let metadata = std::fs::metadata(path)
            .map_err(|e| EmbedError::IoError { path: path.to_path_buf(), source: e })?;

        if !self.limits.check_file_size(metadata.len()) {
            return Err(EmbedError::FileTooLarge {
                path: path.to_path_buf(),
                size: metadata.len(),
                max: self.limits.max_file_size,
            });
        }

        let mut content = std::fs::read_to_string(path)
            .map_err(|e| EmbedError::IoError { path: path.to_path_buf(), source: e })?;

        // Reject files with pathologically long lines (typically minified or generated).
        if let Some(max_line_len) = content.lines().map(|l| l.len()).max() {
            if !self.limits.check_line_length(max_line_len) {
                return Err(EmbedError::LineTooLong {
                    path: path.to_path_buf(),
                    length: max_line_len,
                    max: self.limits.max_line_length,
                });
            }
        }

        let relative_path = self.safe_relative_path(path, repo_root)?;

        if let Some(ref scanner) = self.security_scanner {
            let findings = scanner.scan(&content, &relative_path);
            if !findings.is_empty() {
                if self.settings.fail_on_secrets {
                    let files = findings
                        .iter()
                        .map(|f| format!(" {}:{} - {}", f.file, f.line, f.kind.name()))
                        .collect::<Vec<_>>()
                        .join("\n");
                    return Err(EmbedError::SecretsDetected { count: findings.len(), files });
                }

                if self.settings.redact_secrets {
                    content = scanner.redact_content(&content, &relative_path);
                }
            }
        }

        let language = self.detect_language(path);

        let mut symbols = parse_file_symbols(&content, path);

        // Deterministic symbol order: by position, then by name.
        symbols.sort_by(|a, b| {
            a.start_line
                .cmp(&b.start_line)
                .then_with(|| a.end_line.cmp(&b.end_line))
                .then_with(|| a.name.cmp(&b.name))
        });

        let lines: Vec<&str> = content.lines().collect();
        let mut chunks = Vec::with_capacity(symbols.len() + 2);

        for symbol in &symbols {
            if !self.settings.include_imports
                && matches!(symbol.kind, crate::types::SymbolKind::Import)
            {
                continue;
            }

            let (chunk_content, start_line, end_line) =
                self.extract_symbol_content(&lines, symbol, self.settings.context_lines);

            let token_model = self.parse_token_model(&self.settings.token_model);
            let tokens = self.tokenizer.count(&chunk_content, token_model);

            if self.settings.max_tokens > 0 && tokens > self.settings.max_tokens {
                let split_chunks = self.split_large_symbol(
                    &chunk_content,
                    symbol,
                    &relative_path,
                    &language,
                    start_line,
                    0, // initial recursion depth
                )?;
                chunks.extend(split_chunks);
            } else {
                let hash = hash_content(&chunk_content);

                let context = self.extract_context(symbol, &chunk_content);

                let fqn = self.compute_fqn(&relative_path, symbol);

                chunks.push(EmbedChunk {
                    id: hash.short_id,
                    full_hash: hash.full_hash,
                    content: chunk_content,
                    tokens,
                    kind: symbol.kind.into(),
                    source: ChunkSource {
                        repo: self.repo_id.clone(),
                        file: relative_path.clone(),
                        lines: (start_line, end_line),
                        symbol: symbol.name.clone(),
                        fqn: Some(fqn),
                        language: language.clone(),
                        parent: symbol.parent.clone(),
                        visibility: symbol.visibility.into(),
                        is_test: self.is_test_code(path, symbol),
                    },
                    context,
                    part: None,
                });
            }
        }

        if self.settings.include_top_level && !symbols.is_empty() {
            if let Some(top_level) =
                self.extract_top_level(&lines, &symbols, &relative_path, &language)
            {
                chunks.push(top_level);
            }
        }

        Ok(chunks)
    }

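    /// Extracts a symbol's source text plus `context_lines` of surrounding
    /// context, returning the text and its 1-based (start, end) line range.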
    fn extract_symbol_content(
        &self,
        lines: &[&str],
        symbol: &Symbol,
        context_lines: u32,
    ) -> (String, u32, u32) {
        let start_line = symbol.start_line.saturating_sub(1) as usize;
        let end_line = (symbol.end_line as usize).min(lines.len());

        let context_start = start_line.saturating_sub(context_lines as usize);
        let context_end = (end_line + context_lines as usize).min(lines.len());

        let content = lines[context_start..context_end].join("\n");

        (content, (context_start + 1) as u32, context_end as u32)
    }

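    /// Splits an oversized symbol into overlapping parts that each fit the
    /// token budget. Per-part line counts are estimated proportionally from
    /// the token totals, and every part is linked back to its parent symbol.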
    fn split_large_symbol(
        &self,
        content: &str,
        symbol: &Symbol,
        file: &str,
        language: &str,
        base_line: u32,
        depth: u32,
    ) -> Result<Vec<EmbedChunk>, EmbedError> {
        if !self.limits.check_recursion_depth(depth) {
            return Err(EmbedError::RecursionLimitExceeded {
                depth,
                max: self.limits.max_recursion_depth,
                context: format!("splitting symbol {}", symbol.name),
            });
        }

        let lines: Vec<&str> = content.lines().collect();
        let total_lines = lines.len();

        let token_model = self.parse_token_model(&self.settings.token_model);
        let total_tokens = self.tokenizer.count(content, token_model) as usize;
        let target_tokens = self.settings.max_tokens as usize;

        if total_tokens == 0 || target_tokens == 0 {
            return Ok(Vec::new());
        }

        // Estimate how many lines fit in the token budget, assuming a roughly
        // uniform token density across the symbol.
        let target_lines = ((total_lines * target_tokens) / total_tokens).max(1);

        // Overlap between consecutive parts, capped at half a part.
        let overlap_tokens = self.settings.overlap_tokens as usize;
        let overlap_lines = if overlap_tokens > 0 && total_tokens > 0 {
            ((total_lines * overlap_tokens) / total_tokens)
                .max(1)
                .min(target_lines / 2)
        } else {
            0
        };

        let mut chunks = Vec::new();
        let mut current_start = 0usize;
        let mut part_num = 1u32;

        let parent_hash = hash_content(content);

        while current_start < total_lines {
            // Parts after the first start early to overlap with the previous part.
            let content_start = if part_num > 1 && overlap_lines > 0 {
                current_start.saturating_sub(overlap_lines)
            } else {
                current_start
            };
            let content_end = (current_start + target_lines).min(total_lines);

            let part_content = lines[content_start..content_end].join("\n");

            let tokens = self.tokenizer.count(&part_content, token_model);

            // Drop trailing slivers that fall below the minimum token threshold.
            if tokens >= self.settings.min_tokens {
                let hash = hash_content(&part_content);

                let actual_overlap = if part_num > 1 {
                    current_start.saturating_sub(content_start) as u32
                } else {
                    0
                };

                chunks.push(EmbedChunk {
                    id: hash.short_id,
                    full_hash: hash.full_hash,
                    content: part_content,
                    tokens,
                    kind: ChunkKind::FunctionPart,
                    source: ChunkSource {
                        repo: self.repo_id.clone(),
                        file: file.to_owned(),
                        lines: (
                            base_line + content_start as u32,
                            base_line + content_end as u32 - 1,
                        ),
                        symbol: format!("{}_part{}", symbol.name, part_num),
                        fqn: None,
                        language: language.to_owned(),
                        parent: Some(symbol.name.clone()),
                        visibility: symbol.visibility.into(),
                        is_test: false,
                    },
                    context: ChunkContext {
                        signature: symbol.signature.clone(),
                        docstring: symbol.docstring.clone(),
                        ..Default::default()
                    },
                    part: Some(ChunkPart {
                        part: part_num,
                        of: 0, // placeholder, patched below once the total is known
                        parent_id: parent_hash.short_id.clone(),
                        parent_signature: symbol.signature.clone().unwrap_or_default(),
                        overlap_lines: actual_overlap,
                    }),
                });

                part_num += 1;
            }

            current_start = content_end;
        }

        // Now that the part count is known, patch the `of` field on every part.
        let total_parts = chunks.len() as u32;
        for chunk in &mut chunks {
            if let Some(ref mut part) = chunk.part {
                part.of = total_parts;
            }
        }

        Ok(chunks)
    }

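    /// Collects file content not covered by any symbol (imports, constants,
    /// free-standing statements) into a single top-level chunk, if any remains.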
    fn extract_top_level(
        &self,
        lines: &[&str],
        symbols: &[Symbol],
        file: &str,
        language: &str,
    ) -> Option<EmbedChunk> {
        if lines.is_empty() || symbols.is_empty() {
            return None;
        }

        // Mark every line that belongs to some symbol.
        let mut covered = vec![false; lines.len()];
        for symbol in symbols {
            let start = symbol.start_line.saturating_sub(1) as usize;
            let end = (symbol.end_line as usize).min(lines.len());
            for i in start..end {
                covered[i] = true;
            }
        }

        let top_level_lines: Vec<&str> = lines
            .iter()
            .enumerate()
            .filter(|(i, _)| !covered[*i])
            .map(|(_, line)| *line)
            .collect();

        if top_level_lines.is_empty() {
            return None;
        }

        let content = top_level_lines.join("\n").trim().to_owned();
        if content.is_empty() {
            return None;
        }

        let token_model = self.parse_token_model(&self.settings.token_model);
        let tokens = self.tokenizer.count(&content, token_model);

        if tokens < self.settings.min_tokens {
            return None;
        }

        let hash = hash_content(&content);

        Some(EmbedChunk {
            id: hash.short_id,
            full_hash: hash.full_hash,
            content,
            tokens,
            kind: ChunkKind::TopLevel,
            source: ChunkSource {
                repo: self.repo_id.clone(),
                file: file.to_owned(),
                lines: (1, lines.len() as u32),
                symbol: "<top_level>".to_owned(),
                fqn: None,
                language: language.to_owned(),
                parent: None,
                visibility: Visibility::Public,
                is_test: false,
            },
            context: ChunkContext::default(),
            part: None,
        })
    }

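    /// Builds per-chunk context metadata from the parsed symbol and chunk text.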
    fn extract_context(&self, symbol: &Symbol, content: &str) -> ChunkContext {
        ChunkContext {
            docstring: symbol.docstring.clone(),
            comments: Vec::new(),
            signature: symbol.signature.clone(),
            calls: symbol.calls.clone(),
            called_by: Vec::new(), // filled in later by populate_called_by
            imports: Vec::new(),
            tags: self.generate_tags(symbol),
            lines_of_code: self.count_lines_of_code(content),
            max_nesting_depth: self.calculate_nesting_depth(content),
        }
    }

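    /// Counts non-blank, non-comment lines, recognizing common comment prefixes.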
    fn count_lines_of_code(&self, content: &str) -> u32 {
        content
            .lines()
            .filter(|line| {
                let trimmed = line.trim();
                !trimmed.is_empty()
                    && !trimmed.starts_with("//")
                    && !trimmed.starts_with('#')
                    && !trimmed.starts_with("/*")
                    && !trimmed.starts_with('*')
            })
            .count() as u32
    }

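    /// Estimates maximum nesting depth: brace counting first, with an
    /// indentation-based fallback for brace-light languages such as Python.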
    fn calculate_nesting_depth(&self, content: &str) -> u32 {
        let brace_depth = self.calculate_brace_depth(content);

        // A depth of 0 or 1 usually means an indentation-based language,
        // so fall back to the indentation estimate.
        if brace_depth <= 1 {
            let indent_depth = self.calculate_indent_depth(content);
            brace_depth.max(indent_depth)
        } else {
            brace_depth
        }
    }

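    /// Maximum bracket nesting depth, counting `{`, `(`, and `[` pairs.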
    fn calculate_brace_depth(&self, content: &str) -> u32 {
        let mut max_depth = 0u32;
        let mut current_depth = 0i32;

        for ch in content.chars() {
            match ch {
                '{' | '(' | '[' => {
                    current_depth += 1;
                    max_depth = max_depth.max(current_depth as u32);
                },
                '}' | ')' | ']' => {
                    current_depth = (current_depth - 1).max(0);
                },
                _ => {},
            }
        }

        max_depth
    }

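    /// Estimates nesting from indentation relative to the first non-comment
    /// line, assuming roughly two to four spaces per level.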
    fn calculate_indent_depth(&self, content: &str) -> u32 {
        let mut max_depth = 0u32;
        let mut base_indent: Option<usize> = None;

        for line in content.lines() {
            let trimmed = line.trim();
            if trimmed.is_empty() || trimmed.starts_with('#') || trimmed.starts_with("--") {
                continue;
            }

            let leading_spaces = line.len() - line.trim_start().len();

            // The first significant line establishes the baseline indent.
            if base_indent.is_none() {
                base_indent = Some(leading_spaces);
            }

            let base = base_indent.unwrap_or(0);
            if leading_spaces >= base {
                let relative_indent = leading_spaces - base;
                let depth = (relative_indent / 4).max(relative_indent / 2) as u32;
                max_depth = max_depth.max(depth + 1);
            }
        }

        max_depth
    }

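    /// Derives topical tags (async, concurrency, security, ...) from the
    /// symbol name and signature via keyword heuristics.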
    fn generate_tags(&self, symbol: &Symbol) -> Vec<String> {
        let mut tags = Vec::new();
        let signature = symbol.signature.as_deref().unwrap_or("");
        let name_lower = symbol.name.to_lowercase();

        // Async / coroutine markers (Rust and JS async/await, Kotlin suspend).
        if signature.contains("async")
            || signature.contains("await")
            || signature.contains("suspend")
        {
            tags.push("async".to_owned());
        }

        // Concurrency primitives across languages, including Go channels and
        // the sync package.
        if name_lower.contains("thread")
            || name_lower.contains("mutex")
            || name_lower.contains("lock")
            || name_lower.contains("spawn")
            || name_lower.contains("parallel")
            || name_lower.contains("goroutine")
            || name_lower.contains("channel")
            || signature.contains("Mutex")
            || signature.contains("RwLock")
            || signature.contains("Arc")
            || signature.contains("chan ")
            || signature.contains("<-chan")
            || signature.contains("chan<-")
            || signature.contains("sync.")
            || signature.contains("WaitGroup")
        {
            tags.push("concurrency".to_owned());
        }

        if name_lower.contains("password")
            || name_lower.contains("token")
            || name_lower.contains("secret")
            || name_lower.contains("auth")
            || name_lower.contains("crypt")
            || name_lower.contains("hash")
            || name_lower.contains("permission")
            || signature.contains("password")
            || signature.contains("token")
            || signature.contains("secret")
        {
            tags.push("security".to_owned());
        }

        if signature.contains("Error")
            || signature.contains("Result")
            || name_lower.contains("error")
            || name_lower.contains("exception")
            || name_lower.contains("panic")
            || name_lower.contains("unwrap")
        {
            tags.push("error-handling".to_owned());
        }

        if name_lower.contains("query")
            || name_lower.contains("sql")
            || name_lower.contains("database")
            || name_lower.contains("db_")
            || name_lower.starts_with("db")
            || name_lower.contains("repository")
            || name_lower.contains("transaction")
        {
            tags.push("database".to_owned());
        }

        if name_lower.contains("http")
            || name_lower.contains("request")
            || name_lower.contains("response")
            || name_lower.contains("endpoint")
            || name_lower.contains("route")
            || name_lower.contains("handler")
            || name_lower.contains("middleware")
        {
            tags.push("http".to_owned());
        }

        if name_lower.contains("command")
            || name_lower.contains("cli")
            || name_lower.contains("arg")
            || name_lower.contains("flag")
            || name_lower.contains("option")
            || name_lower.contains("subcommand")
        {
            tags.push("cli".to_owned());
        }

        if name_lower.contains("config")
            || name_lower.contains("setting")
            || name_lower.contains("preference")
            || name_lower.contains("option")
            || name_lower.contains("env")
        {
            tags.push("config".to_owned());
        }

        if name_lower.contains("log")
            || name_lower.contains("trace")
            || name_lower.contains("debug")
            || name_lower.contains("warn")
            || name_lower.contains("info")
            || name_lower.contains("metric")
        {
            tags.push("logging".to_owned());
        }

        if name_lower.contains("cache")
            || name_lower.contains("memoize")
            || name_lower.contains("invalidate")
        {
            tags.push("cache".to_owned());
        }

        if name_lower.contains("valid")
            || name_lower.contains("check")
            || name_lower.contains("verify")
            || name_lower.contains("assert")
            || name_lower.contains("sanitize")
        {
            tags.push("validation".to_owned());
        }

        if name_lower.contains("serial")
            || name_lower.contains("deserial")
            || name_lower.contains("json")
            || name_lower.contains("xml")
            || name_lower.contains("yaml")
            || name_lower.contains("toml")
            || name_lower.contains("encode")
            || name_lower.contains("decode")
            || name_lower.contains("parse")
            || name_lower.contains("format")
        {
            tags.push("serialization".to_owned());
        }

        if name_lower.contains("file")
            || name_lower.contains("read")
            || name_lower.contains("write")
            || name_lower.contains("path")
            || name_lower.contains("dir")
            || name_lower.contains("fs")
            || name_lower.contains("io")
        {
            tags.push("io".to_owned());
        }

        if name_lower.contains("socket")
            || name_lower.contains("connect")
            || name_lower.contains("network")
            || name_lower.contains("tcp")
            || name_lower.contains("udp")
            || name_lower.contains("client")
            || name_lower.contains("server")
        {
            tags.push("network".to_owned());
        }

        // Constructor / initialization conventions.
        if name_lower == "new"
            || name_lower == "init"
            || name_lower == "setup"
            || name_lower == "create"
            || name_lower.starts_with("new_")
            || name_lower.starts_with("init_")
            || name_lower.starts_with("create_")
            || name_lower.ends_with("_new")
        {
            tags.push("init".to_owned());
        }

        if name_lower.contains("cleanup")
            || name_lower.contains("teardown")
            || name_lower.contains("close")
            || name_lower.contains("dispose")
            || name_lower.contains("shutdown")
            || name_lower == "drop"
        {
            tags.push("cleanup".to_owned());
        }

        if symbol.name.starts_with("test_")
            || symbol.name.ends_with("_test")
            || symbol.name.contains("Test")
            || name_lower.contains("mock")
            || name_lower.contains("stub")
            || name_lower.contains("fixture")
        {
            tags.push("test".to_owned());
        }

        if signature.contains("deprecated") || signature.contains("Deprecated") {
            tags.push("deprecated".to_owned());
        }

        if signature.starts_with("pub fn")
            || signature.starts_with("pub async fn")
            || signature.starts_with("export")
        {
            tags.push("public-api".to_owned());
        }

        // Machine-learning vocabulary in names, or framework markers in signatures.
        if name_lower.contains("model")
            || name_lower.contains("train")
            || name_lower.contains("predict")
            || name_lower.contains("inference")
            || name_lower.contains("neural")
            || name_lower.contains("embedding")
            || name_lower.contains("classifier")
            || name_lower.contains("regressor")
            || name_lower.contains("optimizer")
            || name_lower.contains("loss")
            || name_lower.contains("gradient")
            || name_lower.contains("backprop")
            || name_lower.contains("forward")
            || name_lower.contains("layer")
            || name_lower.contains("activation")
            || name_lower.contains("weight")
            || name_lower.contains("bias")
            || name_lower.contains("epoch")
            || name_lower.contains("batch")
            || signature.contains("torch")
            || signature.contains("tensorflow")
            || signature.contains("keras")
            || signature.contains("sklearn")
            || signature.contains("nn.")
            || signature.contains("nn::")
        {
            tags.push("ml".to_owned());
        }

        // Data-wrangling vocabulary, plus pandas/numpy markers.
        if name_lower.contains("dataframe")
            || name_lower.contains("dataset")
            || name_lower.contains("tensor")
            || name_lower.contains("numpy")
            || name_lower.contains("pandas")
            || name_lower.contains("array")
            || name_lower.contains("matrix")
            || name_lower.contains("vector")
            || name_lower.contains("feature")
            || name_lower.contains("preprocess")
            || name_lower.contains("normalize")
            || name_lower.contains("transform")
            || name_lower.contains("pipeline")
            || name_lower.contains("etl")
            || name_lower.contains("aggregate")
            || name_lower.contains("groupby")
            || name_lower.contains("pivot")
            || signature.contains("pd.")
            || signature.contains("np.")
            || signature.contains("DataFrame")
            || signature.contains("ndarray")
        {
            tags.push("data-science".to_owned());
        }

        tags
    }

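    /// Computes a fully qualified name: the file path with its extension
    /// stripped and separators mapped to `::`, then the optional parent,
    /// then the symbol name.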
    fn compute_fqn(&self, file: &str, symbol: &Symbol) -> String {
        let module_path = file
            .strip_suffix(".rs")
            .or_else(|| file.strip_suffix(".py"))
            .or_else(|| file.strip_suffix(".ts"))
            .or_else(|| file.strip_suffix(".tsx"))
            .or_else(|| file.strip_suffix(".js"))
            .or_else(|| file.strip_suffix(".jsx"))
            .or_else(|| file.strip_suffix(".go"))
            .or_else(|| file.strip_suffix(".java"))
            .or_else(|| file.strip_suffix(".c"))
            .or_else(|| file.strip_suffix(".cpp"))
            .or_else(|| file.strip_suffix(".h"))
            .or_else(|| file.strip_suffix(".hpp"))
            .or_else(|| file.strip_suffix(".rb"))
            .or_else(|| file.strip_suffix(".php"))
            .or_else(|| file.strip_suffix(".cs"))
            .or_else(|| file.strip_suffix(".swift"))
            .or_else(|| file.strip_suffix(".kt"))
            .or_else(|| file.strip_suffix(".scala"))
            .unwrap_or(file)
            .replace(['\\', '/'], "::");

        if let Some(ref parent) = symbol.parent {
            format!("{}::{}::{}", module_path, parent, symbol.name)
        } else {
            format!("{}::{}", module_path, symbol.name)
        }
    }

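    /// Heuristically flags test code by path components and naming conventions.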
    fn is_test_code(&self, path: &Path, symbol: &Symbol) -> bool {
        let path_str = path.to_string_lossy().to_lowercase();

        if path_str.contains("test") || path_str.contains("spec") || path_str.contains("__tests__")
        {
            return true;
        }

        let name = symbol.name.to_lowercase();
        if name.starts_with("test_") || name.ends_with("_test") || name.contains("_test_") {
            return true;
        }

        false
    }

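    /// Canonicalizes the repository path and verifies it is a directory.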
    fn validate_repo_path(&self, path: &Path) -> Result<PathBuf, EmbedError> {
        let canonical = path
            .canonicalize()
            .map_err(|e| EmbedError::IoError { path: path.to_path_buf(), source: e })?;

        if !canonical.is_dir() {
            return Err(EmbedError::NotADirectory { path: path.to_path_buf() });
        }

        Ok(canonical)
    }

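    /// Returns the path relative to the repository root with forward slashes,
    /// rejecting canonicalized paths that escape the root.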
    fn safe_relative_path(&self, path: &Path, repo_root: &Path) -> Result<String, EmbedError> {
        let canonical = path
            .canonicalize()
            .map_err(|e| EmbedError::IoError { path: path.to_path_buf(), source: e })?;

        if !canonical.starts_with(repo_root) {
            return Err(EmbedError::PathTraversal {
                path: canonical,
                repo_root: repo_root.to_path_buf(),
            });
        }

        Ok(canonical
            .strip_prefix(repo_root)
            .unwrap_or(&canonical)
            .to_string_lossy()
            .replace('\\', "/"))
    }

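    /// Walks the repository honoring gitignore rules, then filters by
    /// include/exclude globs, the test-file setting, and supported languages.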
    fn discover_files(&self, repo_root: &Path) -> Result<Vec<PathBuf>, EmbedError> {
        use glob::Pattern;
        use ignore::WalkBuilder;

        let mut files = Vec::new();

        let mut include_patterns = Vec::new();
        for pattern_str in &self.settings.include_patterns {
            match Pattern::new(pattern_str) {
                Ok(pattern) => include_patterns.push(pattern),
                Err(e) => {
                    return Err(EmbedError::InvalidPattern {
                        pattern: pattern_str.clone(),
                        reason: e.to_string(),
                    });
                },
            }
        }

        let mut exclude_patterns = Vec::new();
        for pattern_str in &self.settings.exclude_patterns {
            match Pattern::new(pattern_str) {
                Ok(pattern) => exclude_patterns.push(pattern),
                Err(e) => {
                    return Err(EmbedError::InvalidPattern {
                        pattern: pattern_str.clone(),
                        reason: e.to_string(),
                    });
                },
            }
        }

        let walker = WalkBuilder::new(repo_root)
            .hidden(false)
            .git_ignore(true)
            .git_global(true)
            .git_exclude(true)
            .follow_links(false)
            .build();

        for entry in walker {
            let entry = entry.map_err(|e| EmbedError::IoError {
                path: repo_root.to_path_buf(),
                source: std::io::Error::other(e.to_string()),
            })?;

            let path = entry.path();

            if !path.is_file() {
                continue;
            }

            let relative_path = path
                .strip_prefix(repo_root)
                .unwrap_or(path)
                .to_string_lossy();

            if !include_patterns.is_empty()
                && !include_patterns.iter().any(|p| p.matches(&relative_path))
            {
                continue;
            }

            if exclude_patterns.iter().any(|p| p.matches(&relative_path)) {
                continue;
            }

            if !self.settings.include_tests && self.is_test_file(path) {
                continue;
            }

            // Only keep files whose extension maps to a supported language.
            let ext = match path.extension().and_then(|e| e.to_str()) {
                Some(e) => e,
                None => continue,
            };
            if Language::from_extension(ext).is_none() {
                continue;
            }

            files.push(path.to_path_buf());
        }

        Ok(files)
    }

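    /// Detects test files by directory name or common test-file naming patterns.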
    fn is_test_file(&self, path: &Path) -> bool {
        let path_str = path.to_string_lossy().to_lowercase();

        if path_str.contains("/tests/")
            || path_str.contains("\\tests\\")
            || path_str.contains("/test/")
            || path_str.contains("\\test\\")
            || path_str.contains("/__tests__/")
            || path_str.contains("\\__tests__\\")
            || path_str.contains("/spec/")
            || path_str.contains("\\spec\\")
        {
            return true;
        }

        let filename = path
            .file_name()
            .and_then(|n| n.to_str())
            .unwrap_or("")
            .to_lowercase();

        filename.starts_with("test_")
            || filename.ends_with("_test.rs")
            || filename.ends_with("_test.py")
            || filename.ends_with("_test.go")
            || filename.ends_with(".test.ts")
            || filename.ends_with(".test.js")
            || filename.ends_with(".test.tsx")
            || filename.ends_with(".test.jsx")
            || filename.ends_with(".spec.ts")
            || filename.ends_with(".spec.js")
            || filename.ends_with("_spec.rb")
    }

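    /// Maps a file extension to its display language name, or "unknown".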
    fn detect_language(&self, path: &Path) -> String {
        path.extension()
            .and_then(|e| e.to_str())
            .and_then(Language::from_extension)
            .map_or_else(|| "unknown".to_owned(), |l| l.display_name().to_owned())
    }

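    /// Parses a token-model name, falling back to Claude when unrecognized.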
    fn parse_token_model(&self, model: &str) -> TokenModel {
        TokenModel::from_model_name(model).unwrap_or(TokenModel::Claude)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::embedding::progress::QuietProgress;
    use tempfile::TempDir;

    fn create_test_file(dir: &Path, name: &str, content: &str) {
        let path = dir.join(name);
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).unwrap();
        }
        std::fs::write(path, content).unwrap();
    }

    #[test]
    fn test_chunker_creation() {
        let settings = EmbedSettings::default();
        let limits = ResourceLimits::default();
        let chunker = EmbedChunker::new(settings, limits);
        assert!(chunker.settings.max_tokens > 0);
    }

    #[test]
    fn test_chunk_single_file() {
        let temp_dir = TempDir::new().unwrap();
        let rust_code = r#"
/// A test function
fn hello() {
    println!("Hello, world!");
}

fn goodbye() {
    println!("Goodbye!");
}
"#;
        create_test_file(temp_dir.path(), "test.rs", rust_code);

        let settings = EmbedSettings::default();
        let chunker = EmbedChunker::with_defaults(settings);
        let progress = QuietProgress;

        let chunks = chunker
            .chunk_repository(temp_dir.path(), &progress)
            .unwrap();

        assert!(!chunks.is_empty());

        // Chunks must come back sorted by file.
        for i in 1..chunks.len() {
            assert!(chunks[i - 1].source.file <= chunks[i].source.file);
        }
    }

    #[test]
    fn test_determinism() {
        let temp_dir = TempDir::new().unwrap();
        create_test_file(temp_dir.path(), "a.rs", "fn foo() {}");
        create_test_file(temp_dir.path(), "b.rs", "fn bar() {}");

        let settings = EmbedSettings::default();
        let progress = QuietProgress;

        let results: Vec<Vec<EmbedChunk>> = (0..3)
            .map(|_| {
                let chunker = EmbedChunker::with_defaults(settings.clone());
                chunker
                    .chunk_repository(temp_dir.path(), &progress)
                    .unwrap()
            })
            .collect();

        // Every run must produce identical chunk IDs in identical order.
        for i in 1..results.len() {
            assert_eq!(results[0].len(), results[i].len());
            for j in 0..results[0].len() {
                assert_eq!(results[0][j].id, results[i][j].id);
            }
        }
    }

    #[test]
    fn test_file_too_large() {
        let temp_dir = TempDir::new().unwrap();
        let large_content = "x".repeat(200);
        create_test_file(temp_dir.path(), "large.rs", &large_content);

        let settings = EmbedSettings::default();
        let limits = ResourceLimits::default().with_max_file_size(100);
        let chunker = EmbedChunker::new(settings, limits);
        let progress = QuietProgress;

        let result = chunker.chunk_repository(temp_dir.path(), &progress);

        assert!(result.is_err());
    }

    #[test]
    fn test_empty_directory() {
        let temp_dir = TempDir::new().unwrap();

        let settings = EmbedSettings::default();
        let chunker = EmbedChunker::with_defaults(settings);
        let progress = QuietProgress;

        let result = chunker.chunk_repository(temp_dir.path(), &progress);

        assert!(matches!(result, Err(EmbedError::NoChunksGenerated { .. })));
    }

    #[test]
    fn test_language_detection() {
        let chunker = EmbedChunker::with_defaults(EmbedSettings::default());

        assert_eq!(chunker.detect_language(Path::new("test.rs")), "Rust");
        assert_eq!(chunker.detect_language(Path::new("test.py")), "Python");
        assert_eq!(chunker.detect_language(Path::new("test.unknown")), "unknown");
    }

    #[test]
    fn test_is_test_code() {
        let chunker = EmbedChunker::with_defaults(EmbedSettings::default());

        let test_symbol = Symbol::new("test_foo", crate::types::SymbolKind::Function);
        assert!(chunker.is_test_code(Path::new("foo.rs"), &test_symbol));

        let normal_symbol = Symbol::new("foo", crate::types::SymbolKind::Function);
        assert!(!chunker.is_test_code(Path::new("src/lib.rs"), &normal_symbol));

        assert!(chunker.is_test_code(Path::new("tests/test_foo.rs"), &normal_symbol));
    }

    #[test]
    fn test_generate_tags() {
        let chunker = EmbedChunker::with_defaults(EmbedSettings::default());

        let mut symbol = Symbol::new("authenticate_user", crate::types::SymbolKind::Function);
        symbol.signature = Some("async fn authenticate_user(password: &str)".to_owned());

        let tags = chunker.generate_tags(&symbol);
        assert!(tags.contains(&"async".to_owned()));
        assert!(tags.contains(&"security".to_owned()));
    }

    #[test]
    fn test_generate_tags_kotlin_suspend() {
        let chunker = EmbedChunker::with_defaults(EmbedSettings::default());

        let mut symbol = Symbol::new("fetchData", crate::types::SymbolKind::Function);
        symbol.signature = Some("suspend fun fetchData(): Result<Data>".to_owned());

        let tags = chunker.generate_tags(&symbol);
        assert!(tags.contains(&"async".to_owned()), "Kotlin suspend should be tagged as async");
    }

    #[test]
    fn test_generate_tags_go_concurrency() {
        let chunker = EmbedChunker::with_defaults(EmbedSettings::default());

        let mut symbol = Symbol::new("processMessages", crate::types::SymbolKind::Function);
        symbol.signature = Some("func processMessages(ch chan string)".to_owned());

        let tags = chunker.generate_tags(&symbol);
        assert!(
            tags.contains(&"concurrency".to_owned()),
            "Go channels should be tagged as concurrency"
        );
    }

    #[test]
    fn test_generate_tags_ml() {
        let chunker = EmbedChunker::with_defaults(EmbedSettings::default());

        let mut symbol = Symbol::new("train_model", crate::types::SymbolKind::Function);
        symbol.signature = Some("def train_model(epochs: int, batch_size: int)".to_owned());
        let tags = chunker.generate_tags(&symbol);
        assert!(tags.contains(&"ml".to_owned()), "train_model should be tagged as ml");

        let mut symbol2 = Symbol::new("forward_pass", crate::types::SymbolKind::Function);
        symbol2.signature = Some("def forward_pass(self, x: torch.Tensor)".to_owned());
        let tags2 = chunker.generate_tags(&symbol2);
        assert!(
            tags2.contains(&"ml".to_owned()),
            "torch.Tensor in signature should be tagged as ml"
        );

        let mut symbol3 = Symbol::new("ImageClassifier", crate::types::SymbolKind::Class);
        symbol3.signature = Some("class ImageClassifier(nn.Module)".to_owned());
        let tags3 = chunker.generate_tags(&symbol3);
        assert!(tags3.contains(&"ml".to_owned()), "nn.Module should be tagged as ml");
    }

    #[test]
    fn test_generate_tags_data_science() {
        let chunker = EmbedChunker::with_defaults(EmbedSettings::default());

        let mut symbol = Symbol::new("preprocess_dataframe", crate::types::SymbolKind::Function);
        symbol.signature = Some("def preprocess_dataframe(df: pd.DataFrame)".to_owned());
        let tags = chunker.generate_tags(&symbol);
        assert!(
            tags.contains(&"data-science".to_owned()),
            "DataFrame should be tagged as data-science"
        );

        let mut symbol2 = Symbol::new("normalize_array", crate::types::SymbolKind::Function);
        symbol2.signature = Some("def normalize_array(arr: np.ndarray)".to_owned());
        let tags2 = chunker.generate_tags(&symbol2);
        assert!(
            tags2.contains(&"data-science".to_owned()),
            "np.ndarray should be tagged as data-science"
        );

        let symbol3 = Symbol::new("run_etl_pipeline", crate::types::SymbolKind::Function);
        let tags3 = chunker.generate_tags(&symbol3);
        assert!(tags3.contains(&"data-science".to_owned()), "etl should be tagged as data-science");
    }

    #[test]
    fn test_brace_nesting_depth() {
        let chunker = EmbedChunker::with_defaults(EmbedSettings::default());

        let code = "fn foo() { if x { if y { } } }";
        assert_eq!(chunker.calculate_brace_depth(code), 3);

        let flat = "let x = 1;";
        assert_eq!(chunker.calculate_brace_depth(flat), 0);

        let deep = "fn f() { let a = vec![HashMap::new()]; }";
        assert!(chunker.calculate_brace_depth(deep) >= 2);
    }

    #[test]
    fn test_indent_nesting_depth() {
        let chunker = EmbedChunker::with_defaults(EmbedSettings::default());

        let python_code = r#"
def foo():
    if x:
        if y:
            do_something()
    else:
        other()
"#;
        let depth = chunker.calculate_indent_depth(python_code);
        assert!(depth >= 3, "Should detect indentation nesting, got {}", depth);

        let flat = "x = 1\ny = 2\n";
        assert!(chunker.calculate_indent_depth(flat) <= 1);
    }

    #[test]
    fn test_combined_nesting_depth() {
        let chunker = EmbedChunker::with_defaults(EmbedSettings::default());

        let rust_code = "fn foo() { if x { match y { A => {}, B => {} } } }";
        let depth = chunker.calculate_nesting_depth(rust_code);
        assert!(depth >= 3, "Should use brace depth for Rust-like code");

        let python_code = "def foo():\n    if x:\n        y()\n";
        let depth = chunker.calculate_nesting_depth(python_code);
        assert!(depth >= 1, "Should use indent depth for Python-like code");
    }

    #[test]
    fn test_lines_of_code() {
        let chunker = EmbedChunker::with_defaults(EmbedSettings::default());

        let code = r#"
// This is a comment
fn foo() {
    let x = 1;

    // Another comment
    let y = 2;
}
"#;
        let loc = chunker.count_lines_of_code(code);
        assert!((4..=5).contains(&loc), "LOC should be ~4, got {}", loc);
    }

    #[test]
    fn test_line_too_long_error() {
        let temp_dir = TempDir::new().unwrap();

        let long_line = "x".repeat(50_000);
        let content = format!("fn foo() {{ {} }}", long_line);
        create_test_file(temp_dir.path(), "minified.rs", &content);

        let settings = EmbedSettings::default();
        let limits = ResourceLimits::default().with_max_line_length(10_000);
        let chunker = EmbedChunker::new(settings, limits);
        let progress = QuietProgress;

        let result = chunker.chunk_repository(temp_dir.path(), &progress);

        assert!(result.is_err(), "Should reject files with very long lines");
    }

    #[test]
    fn test_hierarchical_chunking_integration() {
        let temp_dir = TempDir::new().unwrap();

        let rust_code = r#"
/// A user account
pub struct User {
    pub name: String,
    pub email: String,
}

impl User {
    /// Create a new user
    pub fn new(name: String, email: String) -> Self {
        Self { name, email }
    }

    /// Get the user's display name
    pub fn display_name(&self) -> &str {
        &self.name
    }

    /// Validate the user's email
    pub fn validate_email(&self) -> bool {
        self.email.contains('@')
    }
}
"#;
        create_test_file(temp_dir.path(), "user.rs", rust_code);

        let settings_no_hierarchy = EmbedSettings { enable_hierarchy: false, ..Default::default() };
        let chunker_no_hierarchy = EmbedChunker::with_defaults(settings_no_hierarchy);
        let progress = QuietProgress;
        let chunks_no_hierarchy = chunker_no_hierarchy
            .chunk_repository(temp_dir.path(), &progress)
            .unwrap();

        let settings_with_hierarchy = EmbedSettings {
            enable_hierarchy: true,
            hierarchy_min_children: 2,
            ..Default::default()
        };
        let chunker_with_hierarchy = EmbedChunker::with_defaults(settings_with_hierarchy);
        let chunks_with_hierarchy = chunker_with_hierarchy
            .chunk_repository(temp_dir.path(), &progress)
            .unwrap();

        assert!(
            chunks_with_hierarchy.len() >= chunks_no_hierarchy.len(),
            "Hierarchy should produce at least as many chunks: {} vs {}",
            chunks_with_hierarchy.len(),
            chunks_no_hierarchy.len()
        );

        let summary_chunks: Vec<_> = chunks_with_hierarchy
            .iter()
            .filter(|c| matches!(c.kind, ChunkKind::Module))
            .collect();

        // Summaries are only produced when enough children exist.
        if !summary_chunks.is_empty() {
            for summary in &summary_chunks {
                assert!(!summary.content.is_empty(), "Summary chunk should have content");
            }
        }

        // Re-running must be deterministic.
        let chunks_with_hierarchy_2 = chunker_with_hierarchy
            .chunk_repository(temp_dir.path(), &progress)
            .unwrap();
        assert_eq!(
            chunks_with_hierarchy.len(),
            chunks_with_hierarchy_2.len(),
            "Hierarchical chunking should be deterministic"
        );
        for (c1, c2) in chunks_with_hierarchy
            .iter()
            .zip(chunks_with_hierarchy_2.iter())
        {
            assert_eq!(c1.id, c2.id, "Chunk IDs should be identical across runs");
        }
    }
}