1use anyhow::{Context, Result};
18use streaming_iterator::StreamingIterator;
19use tree_sitter::{Parser, Query, QueryCursor};
20use crate::models::{Language, SearchResult, Span, SymbolKind};
21
22pub fn parse(path: &str, source: &str) -> Result<Vec<SearchResult>> {
24 let mut parser = Parser::new();
25 let language = tree_sitter_cpp::LANGUAGE;
26
27 parser
28 .set_language(&language.into())
29 .context("Failed to set C++ language")?;
30
31 let tree = parser
32 .parse(source, None)
33 .context("Failed to parse C++ source")?;
34
35 let root_node = tree.root_node();
36
37 let mut symbols = Vec::new();
38
39 symbols.extend(extract_functions(source, &root_node, &language.into())?);
41 symbols.extend(extract_classes(source, &root_node, &language.into())?);
42 symbols.extend(extract_structs(source, &root_node, &language.into())?);
43 symbols.extend(extract_namespaces(source, &root_node, &language.into())?);
44 symbols.extend(extract_enums(source, &root_node, &language.into())?);
45 symbols.extend(extract_methods(source, &root_node, &language.into())?);
46 symbols.extend(extract_local_variables(source, &root_node, &language.into())?);
47 symbols.extend(extract_type_aliases(source, &root_node, &language.into())?);
48
49 for symbol in &mut symbols {
51 symbol.path = path.to_string();
52 symbol.lang = Language::Cpp;
53 }
54
55 Ok(symbols)
56}
57
58fn extract_functions(
60 source: &str,
61 root: &tree_sitter::Node,
62 language: &tree_sitter::Language,
63) -> Result<Vec<SearchResult>> {
64 let query_str = r#"
65 (function_definition
66 declarator: (function_declarator
67 declarator: (identifier) @name)) @function
68
69 (function_definition
70 declarator: (function_declarator
71 declarator: (qualified_identifier
72 name: (identifier) @name))) @function
73
74 (template_declaration
75 (function_definition
76 declarator: (function_declarator
77 declarator: (identifier) @name))) @function
78 "#;
79
80 let query = Query::new(language, query_str)
81 .context("Failed to create function query")?;
82
83 extract_symbols(source, root, &query, SymbolKind::Function, None)
84}
85
86fn extract_classes(
88 source: &str,
89 root: &tree_sitter::Node,
90 language: &tree_sitter::Language,
91) -> Result<Vec<SearchResult>> {
92 let query_str = r#"
93 (class_specifier
94 name: (type_identifier) @name) @class
95
96 (template_declaration
97 (class_specifier
98 name: (type_identifier) @name)) @class
99 "#;
100
101 let query = Query::new(language, query_str)
102 .context("Failed to create class query")?;
103
104 extract_symbols(source, root, &query, SymbolKind::Class, None)
105}
106
107fn extract_structs(
109 source: &str,
110 root: &tree_sitter::Node,
111 language: &tree_sitter::Language,
112) -> Result<Vec<SearchResult>> {
113 let query_str = r#"
114 (struct_specifier
115 name: (type_identifier) @name) @struct
116
117 (template_declaration
118 (struct_specifier
119 name: (type_identifier) @name)) @struct
120 "#;
121
122 let query = Query::new(language, query_str)
123 .context("Failed to create struct query")?;
124
125 extract_symbols(source, root, &query, SymbolKind::Struct, None)
126}
127
128fn extract_namespaces(
130 source: &str,
131 root: &tree_sitter::Node,
132 language: &tree_sitter::Language,
133) -> Result<Vec<SearchResult>> {
134 let query_str = r#"
135 (namespace_definition
136 name: (_) @name) @namespace
137 "#;
138
139 let query = Query::new(language, query_str)
140 .context("Failed to create namespace query")?;
141
142 extract_symbols(source, root, &query, SymbolKind::Namespace, None)
143}
144
145fn extract_enums(
147 source: &str,
148 root: &tree_sitter::Node,
149 language: &tree_sitter::Language,
150) -> Result<Vec<SearchResult>> {
151 let query_str = r#"
152 (enum_specifier
153 name: (type_identifier) @name) @enum
154 "#;
155
156 let query = Query::new(language, query_str)
157 .context("Failed to create enum query")?;
158
159 extract_symbols(source, root, &query, SymbolKind::Enum, None)
160}
161
162fn extract_methods(
164 source: &str,
165 root: &tree_sitter::Node,
166 language: &tree_sitter::Language,
167) -> Result<Vec<SearchResult>> {
168 let query_str = r#"
169 (class_specifier
170 name: (type_identifier) @class_name
171 body: (field_declaration_list
172 (function_definition
173 declarator: (function_declarator
174 declarator: (field_identifier) @method_name)))) @class
175
176 (class_specifier
177 name: (type_identifier) @class_name
178 body: (field_declaration_list
179 (function_definition
180 declarator: (function_declarator
181 declarator: (destructor_name) @method_name)))) @class
182
183 (struct_specifier
184 name: (type_identifier) @struct_name
185 body: (field_declaration_list
186 (function_definition
187 declarator: (function_declarator
188 declarator: (field_identifier) @method_name)))) @struct
189
190 (struct_specifier
191 name: (type_identifier) @struct_name
192 body: (field_declaration_list
193 (function_definition
194 declarator: (function_declarator
195 declarator: (destructor_name) @method_name)))) @struct
196 "#;
197
198 let query = Query::new(language, query_str)
199 .context("Failed to create method query")?;
200
201 let mut cursor = QueryCursor::new();
202 let mut matches = cursor.matches(&query, *root, source.as_bytes());
203
204 let mut symbols = Vec::new();
205
206 while let Some(match_) = matches.next() {
207 let mut scope_name = None;
208 let mut scope_type = None;
209 let mut method_name = None;
210 let mut method_node = None;
211
212 for capture in match_.captures {
213 let capture_name: &str = &query.capture_names()[capture.index as usize];
214 match capture_name {
215 "class_name" => {
216 scope_name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
217 scope_type = Some("class");
218 }
219 "struct_name" => {
220 scope_name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
221 scope_type = Some("struct");
222 }
223 "method_name" => {
224 method_name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
225 let mut current = capture.node;
227 while let Some(parent) = current.parent() {
228 if parent.kind() == "function_definition" {
229 method_node = Some(parent);
230 break;
231 }
232 current = parent;
233 }
234 }
235 _ => {}
236 }
237 }
238
239 if let (Some(scope_name), Some(scope_type), Some(method_name), Some(node)) =
240 (scope_name, scope_type, method_name, method_node) {
241 let scope = format!("{} {}", scope_type, scope_name);
242 let span = node_to_span(&node);
243 let preview = extract_preview(source, &span);
244
245 symbols.push(SearchResult::new(
246 String::new(),
247 Language::Cpp,
248 SymbolKind::Method,
249 Some(method_name),
250 span,
251 Some(scope),
252 preview,
253 ));
254 }
255 }
256
257 Ok(symbols)
258}
259
260fn extract_local_variables(
262 source: &str,
263 root: &tree_sitter::Node,
264 language: &tree_sitter::Language,
265) -> Result<Vec<SearchResult>> {
266 let query_str = r#"
267 (declaration
268 declarator: (init_declarator
269 declarator: (identifier) @name)) @var
270 "#;
271
272 let query = Query::new(language, query_str)
273 .context("Failed to create local variable query")?;
274
275 let mut cursor = QueryCursor::new();
276 let mut matches = cursor.matches(&query, *root, source.as_bytes());
277
278 let mut symbols = Vec::new();
279
280 while let Some(match_) = matches.next() {
281 let mut name = None;
282 let mut var_node = None;
283
284 for capture in match_.captures {
285 let capture_name: &str = &query.capture_names()[capture.index as usize];
286 match capture_name {
287 "name" => {
288 name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
289 }
290 "var" => {
291 var_node = Some(capture.node);
292 }
293 _ => {}
294 }
295 }
296
297 if let (Some(name), Some(node)) = (name, var_node) {
299 let mut is_local_var = false;
300 let mut current = node;
301
302 while let Some(parent) = current.parent() {
303 if parent.kind() == "function_definition" {
304 is_local_var = true;
305 break;
306 }
307 current = parent;
308 }
309
310 if is_local_var {
311 let span = node_to_span(&node);
312 let preview = extract_preview(source, &span);
313
314 symbols.push(SearchResult::new(
315 String::new(),
316 Language::Cpp,
317 SymbolKind::Variable,
318 Some(name),
319 span,
320 None, preview,
322 ));
323 }
324 }
325 }
326
327 Ok(symbols)
328}
329
330fn extract_type_aliases(
332 source: &str,
333 root: &tree_sitter::Node,
334 language: &tree_sitter::Language,
335) -> Result<Vec<SearchResult>> {
336 let query_str = r#"
337 (type_definition
338 declarator: (type_identifier) @name) @typedef
339
340 (alias_declaration
341 name: (type_identifier) @name) @using
342 "#;
343
344 let query = Query::new(language, query_str)
345 .context("Failed to create type alias query")?;
346
347 extract_symbols(source, root, &query, SymbolKind::Type, None)
348}
349
350fn extract_symbols(
352 source: &str,
353 root: &tree_sitter::Node,
354 query: &Query,
355 kind: SymbolKind,
356 scope: Option<String>,
357) -> Result<Vec<SearchResult>> {
358 let mut cursor = QueryCursor::new();
359 let mut matches = cursor.matches(query, *root, source.as_bytes());
360
361 let mut symbols = Vec::new();
362 let mut seen_names = std::collections::HashSet::new();
363
364 while let Some(match_) = matches.next() {
365 let mut name = None;
367 let mut name_node = None;
368 let mut full_node = None;
369
370 for capture in match_.captures {
371 let capture_name: &str = &query.capture_names()[capture.index as usize];
372 if capture_name == "name" {
373 name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
374 name_node = Some(capture.node);
375 } else {
376 full_node = Some(capture.node);
378 }
379 }
380
381 if let (Some(name), Some(name_node), Some(node)) = (name, name_node, full_node) {
382 let name_key = (name_node.start_byte(), name_node.end_byte(), name.clone());
385 if seen_names.contains(&name_key) {
386 continue; }
388 seen_names.insert(name_key);
389
390 let span = node_to_span(&node);
391 let preview = extract_preview(source, &span);
392
393 symbols.push(SearchResult::new(
394 String::new(),
395 Language::Cpp,
396 kind.clone(),
397 Some(name),
398 span,
399 scope.clone(),
400 preview,
401 ));
402 }
403 }
404
405 Ok(symbols)
406}
407
408fn node_to_span(node: &tree_sitter::Node) -> Span {
410 let start = node.start_position();
411 let end = node.end_position();
412
413 Span::new(
414 start.row + 1, start.column,
416 end.row + 1,
417 end.column,
418 )
419}
420
421fn extract_preview(source: &str, span: &Span) -> String {
423 let lines: Vec<&str> = source.lines().collect();
424
425 let start_idx = (span.start_line - 1) as usize; let end_idx = (start_idx + 7).min(lines.len());
428
429 lines[start_idx..end_idx].join("\n")
430}
431
// Unit tests for the C++ symbol extractor (`parse`).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_function() {
        let source = r#"
int add(int a, int b) {
    return a + b;
}
    "#;

        let symbols = parse("test.cpp", source).unwrap();
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].symbol.as_deref(), Some("add"));
        assert!(matches!(symbols[0].kind, SymbolKind::Function));
    }

    #[test]
    fn test_parse_class() {
        let source = r#"
class User {
private:
    std::string name;
    int age;

public:
    User(std::string n, int a) : name(n), age(a) {}
};
    "#;

        let symbols = parse("test.cpp", source).unwrap();

        let class_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Class))
            .collect();

        assert_eq!(class_symbols.len(), 1);
        assert_eq!(class_symbols[0].symbol.as_deref(), Some("User"));
    }

    #[test]
    fn test_parse_namespace() {
        let source = r#"
namespace MyNamespace {
    int value = 42;
}

namespace Nested::Namespace {
    void function() {}
}
    "#;

        let symbols = parse("test.cpp", source).unwrap();

        let namespace_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Namespace))
            .collect();

        assert!(!namespace_symbols.is_empty());
        assert!(namespace_symbols
            .iter()
            .any(|s| s.symbol.as_deref() == Some("MyNamespace")));
    }

    #[test]
    fn test_parse_struct() {
        let source = r#"
struct Point {
    int x;
    int y;
};
    "#;

        let symbols = parse("test.cpp", source).unwrap();
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].symbol.as_deref(), Some("Point"));
        assert!(matches!(symbols[0].kind, SymbolKind::Struct));
    }

    #[test]
    fn test_parse_enum() {
        let source = r#"
enum Color {
    RED,
    GREEN,
    BLUE
};

enum class Status {
    Active,
    Inactive
};
    "#;

        let symbols = parse("test.cpp", source).unwrap();

        let enum_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Enum))
            .collect();

        assert_eq!(enum_symbols.len(), 2);
        assert!(enum_symbols.iter().any(|s| s.symbol.as_deref() == Some("Color")));
        assert!(enum_symbols.iter().any(|s| s.symbol.as_deref() == Some("Status")));
    }

    #[test]
    fn test_parse_template_class() {
        let source = r#"
template <typename T>
class Container {
private:
    T value;

public:
    Container(T v) : value(v) {}
    T getValue() { return value; }
};
    "#;

        let symbols = parse("test.cpp", source).unwrap();

        let class_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Class))
            .collect();

        assert_eq!(class_symbols.len(), 1);
        assert_eq!(class_symbols[0].symbol.as_deref(), Some("Container"));
    }

    #[test]
    fn test_parse_template_function() {
        let source = r#"
template <typename T>
T max(T a, T b) {
    return (a > b) ? a : b;
}
    "#;

        let symbols = parse("test.cpp", source).unwrap();
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].symbol.as_deref(), Some("max"));
        assert!(matches!(symbols[0].kind, SymbolKind::Function));
    }

    #[test]
    fn test_parse_class_with_methods() {
        let source = r#"
class Calculator {
public:
    int add(int a, int b) {
        return a + b;
    }

    int subtract(int a, int b) {
        return a - b;
    }
};
    "#;

        let symbols = parse("test.cpp", source).unwrap();

        let method_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Method))
            .collect();

        assert_eq!(method_symbols.len(), 2);
        assert!(method_symbols.iter().any(|s| s.symbol.as_deref() == Some("add")));
        assert!(method_symbols.iter().any(|s| s.symbol.as_deref() == Some("subtract")));

        // `parse` stamps every result with the caller's path and the C++ language.
        for method in &method_symbols {
            assert_eq!(method.path, "test.cpp");
            assert!(matches!(method.lang, Language::Cpp));
        }
    }

    #[test]
    fn test_parse_using_declaration() {
        let source = r#"
using StringVector = std::vector<std::string>;
using IntPtr = int*;
    "#;

        let symbols = parse("test.cpp", source).unwrap();

        let type_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Type))
            .collect();

        assert!(!type_symbols.is_empty());
        assert!(type_symbols
            .iter()
            .any(|s| s.symbol.as_deref() == Some("StringVector")));
    }

    #[test]
    fn test_parse_typedef() {
        let source = r#"
typedef unsigned int uint;
typedef struct {
    int x, y;
} Point;
    "#;

        let symbols = parse("test.cpp", source).unwrap();

        let type_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Type))
            .collect();

        assert!(!type_symbols.is_empty());
    }

    #[test]
    fn test_parse_mixed_symbols() {
        let source = r#"
namespace Math {
    class Vector {
    private:
        double x, y;

    public:
        Vector(double x, double y) : x(x), y(y) {}

        double magnitude() {
            return sqrt(x*x + y*y);
        }
    };

    enum Operation {
        ADD,
        SUBTRACT
    };

    template <typename T>
    T multiply(T a, T b) {
        return a * b;
    }
}
    "#;

        let symbols = parse("test.cpp", source).unwrap();

        assert!(symbols.len() >= 5);

        let kinds: Vec<&SymbolKind> = symbols.iter().map(|s| &s.kind).collect();
        assert!(kinds.contains(&&SymbolKind::Namespace));
        assert!(kinds.contains(&&SymbolKind::Class));
        assert!(kinds.contains(&&SymbolKind::Enum));
        assert!(kinds.contains(&&SymbolKind::Method));
        assert!(kinds.contains(&&SymbolKind::Function));
    }

    #[test]
    fn test_parse_nested_namespace() {
        let source = r#"
namespace Outer {
    namespace Inner {
        void function() {}
    }
}
    "#;

        let symbols = parse("test.cpp", source).unwrap();

        let namespace_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Namespace))
            .collect();

        assert_eq!(namespace_symbols.len(), 2);
        assert!(namespace_symbols.iter().any(|s| s.symbol.as_deref() == Some("Outer")));
        assert!(namespace_symbols.iter().any(|s| s.symbol.as_deref() == Some("Inner")));
    }

    #[test]
    fn test_parse_virtual_methods() {
        let source = r#"
class Base {
public:
    virtual void draw() = 0;
    virtual void update() {}
};

class Derived : public Base {
public:
    void draw() override {
        // Implementation
    }
};
    "#;

        let symbols = parse("test.cpp", source).unwrap();

        let class_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Class))
            .collect();

        assert_eq!(class_symbols.len(), 2);
        assert!(class_symbols.iter().any(|s| s.symbol.as_deref() == Some("Base")));
        assert!(class_symbols.iter().any(|s| s.symbol.as_deref() == Some("Derived")));

        let method_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Method))
            .collect();

        assert!(method_symbols.len() >= 2);
    }

    #[test]
    fn test_parse_operator_overload() {
        let source = r#"
class Complex {
private:
    double real, imag;

public:
    Complex operator+(const Complex& other) {
        return Complex(real + other.real, imag + other.imag);
    }
};
    "#;

        let symbols = parse("test.cpp", source).unwrap();

        let class_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Class))
            .collect();

        assert_eq!(class_symbols.len(), 1);
        assert_eq!(class_symbols[0].symbol.as_deref(), Some("Complex"));
    }

    #[test]
    fn test_local_variables_included() {
        let source = r#"
int calculate(int input) {
    int localVar = input * 2;
    auto result = localVar + 10;
    return result;
}

class Calculator {
public:
    int compute(int value) {
        int temp = value * 3;
        auto final = temp + 5;
        return final;
    }
};
    "#;

        let symbols = parse("test.cpp", source).unwrap();

        let variables: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Variable))
            .collect();

        assert!(variables.iter().any(|v| v.symbol.as_deref() == Some("localVar")));
        assert!(variables.iter().any(|v| v.symbol.as_deref() == Some("result")));
        assert!(variables.iter().any(|v| v.symbol.as_deref() == Some("temp")));
        assert!(variables.iter().any(|v| v.symbol.as_deref() == Some("final")));

        // `parse` stamps every result with the caller's path and the C++ language.
        for var in &variables {
            assert_eq!(var.path, "test.cpp");
            assert!(matches!(var.lang, Language::Cpp));
        }
    }

    #[test]
    fn test_parse_destructor() {
        let source = r#"
class Resource {
private:
    int* data;

public:
    Resource() {
        data = new int[100];
    }

    ~Resource() {
        delete[] data;
    }
};
    "#;

        let symbols = parse("test.cpp", source).unwrap();

        let class_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Class))
            .collect();

        assert_eq!(class_symbols.len(), 1);
        assert_eq!(class_symbols[0].symbol.as_deref(), Some("Resource"));

        let method_symbols: Vec<_> = symbols
            .iter()
            .filter(|s| matches!(s.kind, SymbolKind::Method))
            .collect();

        assert!(
            !method_symbols.is_empty(),
            "Expected at least constructor or destructor to be extracted"
        );

        for method in &method_symbols {
            println!("Found method: {:?}", method.symbol);
        }

        let has_destructor = method_symbols.iter().any(|s| {
            s.symbol
                .as_deref()
                .map(|name| name.contains('~') || name == "Resource")
                .unwrap_or(false)
        });

        if !has_destructor {
            println!("WARNING: Destructor extraction may not be working");
        }
    }
}
852
853use crate::models::ImportType;
858use crate::parsers::{DependencyExtractor, ImportInfo};
859
860pub struct CppDependencyExtractor;
862
863impl DependencyExtractor for CppDependencyExtractor {
864 fn extract_dependencies(source: &str) -> Result<Vec<ImportInfo>> {
865 let mut parser = Parser::new();
866 let language = tree_sitter_cpp::LANGUAGE;
867
868 parser
869 .set_language(&language.into())
870 .context("Failed to set C++ language")?;
871
872 let tree = parser
873 .parse(source, None)
874 .context("Failed to parse C++ source")?;
875
876 let root_node = tree.root_node();
877
878 let mut imports = Vec::new();
879
880 imports.extend(extract_cpp_includes(source, &root_node)?);
882
883 Ok(imports)
884 }
885}
886
887fn extract_cpp_includes(
889 source: &str,
890 root: &tree_sitter::Node,
891) -> Result<Vec<ImportInfo>> {
892 let language = tree_sitter_cpp::LANGUAGE;
893
894 let query_str = r#"
895 (preproc_include
896 path: (string_literal) @include_path) @include
897
898 (preproc_include
899 path: (system_lib_string) @include_path) @include
900 "#;
901
902 let query = Query::new(&language.into(), query_str)
903 .context("Failed to create C++ include query")?;
904
905 let mut cursor = QueryCursor::new();
906 let mut matches = cursor.matches(&query, *root, source.as_bytes());
907
908 let mut imports = Vec::new();
909
910 while let Some(match_) = matches.next() {
911 let mut include_path = None;
912 let mut include_node = None;
913
914 for capture in match_.captures {
915 let capture_name: &str = &query.capture_names()[capture.index as usize];
916 match capture_name {
917 "include_path" => {
918 let raw_path = capture.node.utf8_text(source.as_bytes()).unwrap_or("");
920 include_path = Some(raw_path.trim_matches(|c| c == '"' || c == '<' || c == '>').to_string());
921 }
922 "include" => {
923 include_node = Some(capture.node);
924 }
925 _ => {}
926 }
927 }
928
929 if let (Some(path), Some(node)) = (include_path, include_node) {
930 let import_type = classify_cpp_include(&path, source, &node);
931 let line_number = node.start_position().row + 1;
932
933 imports.push(ImportInfo {
934 imported_path: path,
935 import_type,
936 line_number,
937 imported_symbols: None, });
939 }
940 }
941
942 Ok(imports)
943}
944
945fn classify_cpp_include(include_path: &str, source: &str, node: &tree_sitter::Node) -> ImportType {
947 let line_start = node.start_position();
949 let lines: Vec<&str> = source.lines().collect();
950
951 if line_start.row < lines.len() {
952 let line = lines[line_start.row];
953
954 if line.contains(&format!("\"{}\"", include_path)) {
956 return ImportType::Internal;
957 }
958 }
959
960 const CPP_STDLIB_HEADERS: &[&str] = &[
962 "stdio.h", "stdlib.h", "string.h", "math.h", "time.h",
964 "ctype.h", "assert.h", "errno.h", "limits.h", "float.h",
965 "stddef.h", "stdint.h", "stdbool.h", "stdarg.h", "setjmp.h",
966 "signal.h", "locale.h", "wchar.h", "wctype.h", "complex.h",
967 "fenv.h", "inttypes.h", "iso646.h", "tgmath.h", "threads.h",
968
969 "algorithm", "any", "array", "atomic", "barrier", "bit",
971 "bitset", "charconv", "chrono", "codecvt", "compare", "complex",
972 "concepts", "condition_variable", "coroutine", "deque", "exception",
973 "execution", "expected", "filesystem", "format", "forward_list",
974 "fstream", "functional", "future", "initializer_list", "iomanip",
975 "ios", "iosfwd", "iostream", "istream", "iterator", "latch",
976 "limits", "list", "locale", "map", "mdspan", "memory",
977 "memory_resource", "mutex", "new", "numbers", "numeric", "optional",
978 "ostream", "queue", "random", "ranges", "ratio", "regex",
979 "scoped_allocator", "semaphore", "set", "shared_mutex", "source_location",
980 "span", "sstream", "stack", "stacktrace", "stdexcept", "stop_token",
981 "streambuf", "string", "string_view", "strstream", "syncstream",
982 "system_error", "thread", "tuple", "type_traits", "typeindex",
983 "typeinfo", "unordered_map", "unordered_set", "utility", "valarray",
984 "variant", "vector", "version",
985
986 "cassert", "cctype", "cerrno", "cfenv", "cfloat", "cinttypes",
988 "climits", "clocale", "cmath", "csetjmp", "csignal", "cstdarg",
989 "cstddef", "cstdint", "cstdio", "cstdlib", "cstring", "ctime",
990 "cuchar", "cwchar", "cwctype",
991 ];
992
993 if CPP_STDLIB_HEADERS.contains(&include_path) {
994 return ImportType::Stdlib;
995 }
996
997 ImportType::External
999}
1000
/// Resolves an `#include` path relative to the directory of the file that
/// contains the directive.
///
/// Returns `None` when `current_file_path` is `None` or has no parent
/// component. Otherwise `include_path` is joined onto the including file's
/// directory. If the joined path can be canonicalized (this touches the
/// filesystem and succeeds only when the file exists), the canonical path is
/// returned. If not, `.` and `..` components are folded lexically, so the
/// result is normalized either way: `/p/src` + `../inc/a.h` → `/p/inc/a.h`.
///
/// Only the including file's own directory is considered; compiler include
/// search paths (`-I`) are not.
pub fn resolve_cpp_include_to_path(
    include_path: &str,
    current_file_path: Option<&str>,
) -> Option<String> {
    let current_dir = std::path::Path::new(current_file_path?).parent()?;
    let resolved = current_dir.join(include_path);

    let normalized = resolved
        .canonicalize()
        .unwrap_or_else(|_| normalize_lexically(&resolved));

    Some(normalized.display().to_string())
}

/// Drops `.` components and folds each `..` into the preceding normal
/// component, without touching the filesystem.
///
/// A `..` directly after the root is discarded, since the root's parent is the
/// root. A leading `..` in a relative path, or one following another `..`, is
/// kept. Symlinks are not consulted, so `a/link/..` becomes `a` even if `link`
/// points elsewhere; this is acceptable for a best-effort fallback.
fn normalize_lexically(path: &std::path::Path) -> std::path::PathBuf {
    use std::path::{Component, PathBuf};

    let mut out = PathBuf::new();
    for component in path.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => match out.components().next_back() {
                Some(Component::Normal(_)) => {
                    out.pop();
                }
                Some(Component::RootDir) => {}
                _ => out.push(".."),
            },
            other => out.push(other.as_os_str()),
        }
    }
    out
}
1038
// Tests for `resolve_cpp_include_to_path`.
#[cfg(test)]
mod resolution_tests {
    use super::*;

    /// Translation unit that every include below is resolved from.
    const MAIN_CPP: &str = "/project/src/main.cpp";

    /// Resolves `include` from [`MAIN_CPP`], failing the test on `None`.
    fn resolve_from_main(include: &str) -> String {
        resolve_cpp_include_to_path(include, Some(MAIN_CPP))
            .unwrap_or_else(|| panic!("expected `{include}` to resolve"))
    }

    /// True if `path` ends with `suffix` written with either `/` or `\`
    /// separators.
    fn ends_with_either_separator(path: &str, suffix: &str) -> bool {
        path.ends_with(suffix) || path.ends_with(&suffix.replace('/', "\\"))
    }

    #[test]
    fn test_resolve_cpp_include_same_directory() {
        let path = resolve_from_main("helper.hpp");
        assert!(ends_with_either_separator(&path, "src/helper.hpp"));
    }

    #[test]
    fn test_resolve_cpp_include_subdirectory() {
        let path = resolve_from_main("utils/helper.hpp");
        assert!(ends_with_either_separator(&path, "src/utils/helper.hpp"));
    }

    #[test]
    fn test_resolve_cpp_include_parent_directory() {
        let path = resolve_from_main("../include/common.hpp");
        assert!(path.contains("include") && path.contains("common.hpp"));
    }

    #[test]
    fn test_resolve_cpp_include_h_extension() {
        let path = resolve_from_main("legacy.h");
        assert!(ends_with_either_separator(&path, "src/legacy.h"));
    }

    #[test]
    fn test_resolve_cpp_include_no_current_file() {
        assert!(resolve_cpp_include_to_path("helper.hpp", None).is_none());
    }
}