1use crate::models::{Language, SearchResult, Span, SymbolKind};
18use anyhow::{Context, Result};
19use streaming_iterator::StreamingIterator;
20use tree_sitter::{Parser, Query, QueryCursor};
21
22pub fn parse(path: &str, source: &str) -> Result<Vec<SearchResult>> {
24 let mut parser = Parser::new();
25 let language = tree_sitter_cpp::LANGUAGE;
26
27 parser
28 .set_language(&language.into())
29 .context("Failed to set C++ language")?;
30
31 let tree = parser
32 .parse(source, None)
33 .context("Failed to parse C++ source")?;
34
35 let root_node = tree.root_node();
36
37 let mut symbols = Vec::new();
38
39 symbols.extend(extract_functions(source, &root_node, &language.into())?);
41 symbols.extend(extract_classes(source, &root_node, &language.into())?);
42 symbols.extend(extract_structs(source, &root_node, &language.into())?);
43 symbols.extend(extract_namespaces(source, &root_node, &language.into())?);
44 symbols.extend(extract_enums(source, &root_node, &language.into())?);
45 symbols.extend(extract_methods(source, &root_node, &language.into())?);
46 symbols.extend(extract_local_variables(
47 source,
48 &root_node,
49 &language.into(),
50 )?);
51 symbols.extend(extract_type_aliases(source, &root_node, &language.into())?);
52
53 for symbol in &mut symbols {
55 symbol.path = path.to_string();
56 symbol.lang = Language::Cpp;
57 }
58
59 Ok(symbols)
60}
61
62fn extract_functions(
64 source: &str,
65 root: &tree_sitter::Node,
66 language: &tree_sitter::Language,
67) -> Result<Vec<SearchResult>> {
68 let query_str = r#"
69 (function_definition
70 declarator: (function_declarator
71 declarator: (identifier) @name)) @function
72
73 (function_definition
74 declarator: (function_declarator
75 declarator: (qualified_identifier
76 name: (identifier) @name))) @function
77
78 (template_declaration
79 (function_definition
80 declarator: (function_declarator
81 declarator: (identifier) @name))) @function
82 "#;
83
84 let query = Query::new(language, query_str).context("Failed to create function query")?;
85
86 extract_symbols(source, root, &query, SymbolKind::Function, None)
87}
88
89fn extract_classes(
91 source: &str,
92 root: &tree_sitter::Node,
93 language: &tree_sitter::Language,
94) -> Result<Vec<SearchResult>> {
95 let query_str = r#"
96 (class_specifier
97 name: (type_identifier) @name) @class
98
99 (template_declaration
100 (class_specifier
101 name: (type_identifier) @name)) @class
102 "#;
103
104 let query = Query::new(language, query_str).context("Failed to create class query")?;
105
106 extract_symbols(source, root, &query, SymbolKind::Class, None)
107}
108
109fn extract_structs(
111 source: &str,
112 root: &tree_sitter::Node,
113 language: &tree_sitter::Language,
114) -> Result<Vec<SearchResult>> {
115 let query_str = r#"
116 (struct_specifier
117 name: (type_identifier) @name) @struct
118
119 (template_declaration
120 (struct_specifier
121 name: (type_identifier) @name)) @struct
122 "#;
123
124 let query = Query::new(language, query_str).context("Failed to create struct query")?;
125
126 extract_symbols(source, root, &query, SymbolKind::Struct, None)
127}
128
129fn extract_namespaces(
131 source: &str,
132 root: &tree_sitter::Node,
133 language: &tree_sitter::Language,
134) -> Result<Vec<SearchResult>> {
135 let query_str = r#"
136 (namespace_definition
137 name: (_) @name) @namespace
138 "#;
139
140 let query = Query::new(language, query_str).context("Failed to create namespace query")?;
141
142 extract_symbols(source, root, &query, SymbolKind::Namespace, None)
143}
144
145fn extract_enums(
147 source: &str,
148 root: &tree_sitter::Node,
149 language: &tree_sitter::Language,
150) -> Result<Vec<SearchResult>> {
151 let query_str = r#"
152 (enum_specifier
153 name: (type_identifier) @name) @enum
154 "#;
155
156 let query = Query::new(language, query_str).context("Failed to create enum query")?;
157
158 extract_symbols(source, root, &query, SymbolKind::Enum, None)
159}
160
161fn extract_methods(
163 source: &str,
164 root: &tree_sitter::Node,
165 language: &tree_sitter::Language,
166) -> Result<Vec<SearchResult>> {
167 let query_str = r#"
168 (class_specifier
169 name: (type_identifier) @class_name
170 body: (field_declaration_list
171 (function_definition
172 declarator: (function_declarator
173 declarator: (field_identifier) @method_name)))) @class
174
175 (class_specifier
176 name: (type_identifier) @class_name
177 body: (field_declaration_list
178 (function_definition
179 declarator: (function_declarator
180 declarator: (destructor_name) @method_name)))) @class
181
182 (struct_specifier
183 name: (type_identifier) @struct_name
184 body: (field_declaration_list
185 (function_definition
186 declarator: (function_declarator
187 declarator: (field_identifier) @method_name)))) @struct
188
189 (struct_specifier
190 name: (type_identifier) @struct_name
191 body: (field_declaration_list
192 (function_definition
193 declarator: (function_declarator
194 declarator: (destructor_name) @method_name)))) @struct
195 "#;
196
197 let query = Query::new(language, query_str).context("Failed to create method query")?;
198
199 let mut cursor = QueryCursor::new();
200 let mut matches = cursor.matches(&query, *root, source.as_bytes());
201
202 let mut symbols = Vec::new();
203
204 while let Some(match_) = matches.next() {
205 let mut scope_name = None;
206 let mut scope_type = None;
207 let mut method_name = None;
208 let mut method_node = None;
209
210 for capture in match_.captures {
211 let capture_name: &str = &query.capture_names()[capture.index as usize];
212 match capture_name {
213 "class_name" => {
214 scope_name = Some(
215 capture
216 .node
217 .utf8_text(source.as_bytes())
218 .unwrap_or("")
219 .to_string(),
220 );
221 scope_type = Some("class");
222 }
223 "struct_name" => {
224 scope_name = Some(
225 capture
226 .node
227 .utf8_text(source.as_bytes())
228 .unwrap_or("")
229 .to_string(),
230 );
231 scope_type = Some("struct");
232 }
233 "method_name" => {
234 method_name = Some(
235 capture
236 .node
237 .utf8_text(source.as_bytes())
238 .unwrap_or("")
239 .to_string(),
240 );
241 let mut current = capture.node;
243 while let Some(parent) = current.parent() {
244 if parent.kind() == "function_definition" {
245 method_node = Some(parent);
246 break;
247 }
248 current = parent;
249 }
250 }
251 _ => {}
252 }
253 }
254
255 if let (Some(scope_name), Some(scope_type), Some(method_name), Some(node)) =
256 (scope_name, scope_type, method_name, method_node)
257 {
258 let scope = format!("{} {}", scope_type, scope_name);
259 let span = node_to_span(&node);
260 let preview = extract_preview(source, &span);
261
262 symbols.push(SearchResult::new(
263 String::new(),
264 Language::Cpp,
265 SymbolKind::Method,
266 Some(method_name),
267 span,
268 Some(scope),
269 preview,
270 ));
271 }
272 }
273
274 Ok(symbols)
275}
276
277fn extract_local_variables(
279 source: &str,
280 root: &tree_sitter::Node,
281 language: &tree_sitter::Language,
282) -> Result<Vec<SearchResult>> {
283 let query_str = r#"
284 (declaration
285 declarator: (init_declarator
286 declarator: (identifier) @name)) @var
287 "#;
288
289 let query = Query::new(language, query_str).context("Failed to create local variable query")?;
290
291 let mut cursor = QueryCursor::new();
292 let mut matches = cursor.matches(&query, *root, source.as_bytes());
293
294 let mut symbols = Vec::new();
295
296 while let Some(match_) = matches.next() {
297 let mut name = None;
298 let mut var_node = None;
299
300 for capture in match_.captures {
301 let capture_name: &str = &query.capture_names()[capture.index as usize];
302 match capture_name {
303 "name" => {
304 name = Some(
305 capture
306 .node
307 .utf8_text(source.as_bytes())
308 .unwrap_or("")
309 .to_string(),
310 );
311 }
312 "var" => {
313 var_node = Some(capture.node);
314 }
315 _ => {}
316 }
317 }
318
319 if let (Some(name), Some(node)) = (name, var_node) {
321 let mut is_local_var = false;
322 let mut current = node;
323
324 while let Some(parent) = current.parent() {
325 if parent.kind() == "function_definition" {
326 is_local_var = true;
327 break;
328 }
329 current = parent;
330 }
331
332 if is_local_var {
333 let span = node_to_span(&node);
334 let preview = extract_preview(source, &span);
335
336 symbols.push(SearchResult::new(
337 String::new(),
338 Language::Cpp,
339 SymbolKind::Variable,
340 Some(name),
341 span,
342 None, preview,
344 ));
345 }
346 }
347 }
348
349 Ok(symbols)
350}
351
352fn extract_type_aliases(
354 source: &str,
355 root: &tree_sitter::Node,
356 language: &tree_sitter::Language,
357) -> Result<Vec<SearchResult>> {
358 let query_str = r#"
359 (type_definition
360 declarator: (type_identifier) @name) @typedef
361
362 (alias_declaration
363 name: (type_identifier) @name) @using
364 "#;
365
366 let query = Query::new(language, query_str).context("Failed to create type alias query")?;
367
368 extract_symbols(source, root, &query, SymbolKind::Type, None)
369}
370
371fn extract_symbols(
373 source: &str,
374 root: &tree_sitter::Node,
375 query: &Query,
376 kind: SymbolKind,
377 scope: Option<String>,
378) -> Result<Vec<SearchResult>> {
379 let mut cursor = QueryCursor::new();
380 let mut matches = cursor.matches(query, *root, source.as_bytes());
381
382 let mut symbols = Vec::new();
383 let mut seen_names = std::collections::HashSet::new();
384
385 while let Some(match_) = matches.next() {
386 let mut name = None;
388 let mut name_node = None;
389 let mut full_node = None;
390
391 for capture in match_.captures {
392 let capture_name: &str = &query.capture_names()[capture.index as usize];
393 if capture_name == "name" {
394 name = Some(
395 capture
396 .node
397 .utf8_text(source.as_bytes())
398 .unwrap_or("")
399 .to_string(),
400 );
401 name_node = Some(capture.node);
402 } else {
403 full_node = Some(capture.node);
405 }
406 }
407
408 if let (Some(name), Some(name_node), Some(node)) = (name, name_node, full_node) {
409 let name_key = (name_node.start_byte(), name_node.end_byte(), name.clone());
412 if seen_names.contains(&name_key) {
413 continue; }
415 seen_names.insert(name_key);
416
417 let span = node_to_span(&node);
418 let preview = extract_preview(source, &span);
419
420 symbols.push(SearchResult::new(
421 String::new(),
422 Language::Cpp,
423 kind.clone(),
424 Some(name),
425 span,
426 scope.clone(),
427 preview,
428 ));
429 }
430 }
431
432 Ok(symbols)
433}
434
435fn node_to_span(node: &tree_sitter::Node) -> Span {
437 let start = node.start_position();
438 let end = node.end_position();
439
440 Span::new(
441 start.row + 1, start.column,
443 end.row + 1,
444 end.column,
445 )
446}
447
448fn extract_preview(source: &str, span: &Span) -> String {
450 let lines: Vec<&str> = source.lines().collect();
451
452 let start_idx = (span.start_line - 1) as usize; let end_idx = (start_idx + 7).min(lines.len());
455
456 lines[start_idx..end_idx].join("\n")
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the parsed symbols whose kind equals `kind`.
    fn symbols_of_kind(symbols: &[SearchResult], kind: SymbolKind) -> Vec<&SearchResult> {
        symbols.iter().filter(|s| s.kind == kind).collect()
    }

    /// Reports whether any of `symbols` is named `name`.
    fn has_symbol(symbols: &[&SearchResult], name: &str) -> bool {
        symbols.iter().any(|s| s.symbol.as_deref() == Some(name))
    }

    #[test]
    fn test_parse_function() {
        let source = r#"
int add(int a, int b) {
    return a + b;
}
        "#;

        let symbols = parse("test.cpp", source).unwrap();
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].symbol.as_deref(), Some("add"));
        assert!(matches!(symbols[0].kind, SymbolKind::Function));
    }

    #[test]
    fn test_parse_class() {
        let source = r#"
class User {
private:
    std::string name;
    int age;

public:
    User(std::string n, int a) : name(n), age(a) {}
};
        "#;

        let symbols = parse("test.cpp", source).unwrap();
        let class_symbols = symbols_of_kind(&symbols, SymbolKind::Class);

        assert_eq!(class_symbols.len(), 1);
        assert_eq!(class_symbols[0].symbol.as_deref(), Some("User"));
    }

    #[test]
    fn test_parse_namespace() {
        let source = r#"
namespace MyNamespace {
    int value = 42;
}

namespace Nested::Namespace {
    void function() {}
}
        "#;

        let symbols = parse("test.cpp", source).unwrap();
        let namespace_symbols = symbols_of_kind(&symbols, SymbolKind::Namespace);

        assert!(!namespace_symbols.is_empty());
        assert!(has_symbol(&namespace_symbols, "MyNamespace"));
    }

    #[test]
    fn test_parse_struct() {
        let source = r#"
struct Point {
    int x;
    int y;
};
        "#;

        let symbols = parse("test.cpp", source).unwrap();
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].symbol.as_deref(), Some("Point"));
        assert!(matches!(symbols[0].kind, SymbolKind::Struct));
    }

    #[test]
    fn test_parse_enum() {
        let source = r#"
enum Color {
    RED,
    GREEN,
    BLUE
};

enum class Status {
    Active,
    Inactive
};
        "#;

        let symbols = parse("test.cpp", source).unwrap();
        let enum_symbols = symbols_of_kind(&symbols, SymbolKind::Enum);

        assert_eq!(enum_symbols.len(), 2);
        assert!(has_symbol(&enum_symbols, "Color"));
        assert!(has_symbol(&enum_symbols, "Status"));
    }

    #[test]
    fn test_parse_template_class() {
        let source = r#"
template <typename T>
class Container {
private:
    T value;

public:
    Container(T v) : value(v) {}
    T getValue() { return value; }
};
        "#;

        let symbols = parse("test.cpp", source).unwrap();
        let class_symbols = symbols_of_kind(&symbols, SymbolKind::Class);

        assert_eq!(class_symbols.len(), 1);
        assert_eq!(class_symbols[0].symbol.as_deref(), Some("Container"));
    }

    #[test]
    fn test_parse_template_function() {
        let source = r#"
template <typename T>
T max(T a, T b) {
    return (a > b) ? a : b;
}
        "#;

        let symbols = parse("test.cpp", source).unwrap();
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].symbol.as_deref(), Some("max"));
        assert!(matches!(symbols[0].kind, SymbolKind::Function));
    }

    #[test]
    fn test_parse_class_with_methods() {
        let source = r#"
class Calculator {
public:
    int add(int a, int b) {
        return a + b;
    }

    int subtract(int a, int b) {
        return a - b;
    }
};
        "#;

        let symbols = parse("test.cpp", source).unwrap();
        let method_symbols = symbols_of_kind(&symbols, SymbolKind::Method);

        assert_eq!(method_symbols.len(), 2);
        assert!(has_symbol(&method_symbols, "add"));
        assert!(has_symbol(&method_symbols, "subtract"));
    }

    #[test]
    fn test_parse_using_declaration() {
        let source = r#"
using StringVector = std::vector<std::string>;
using IntPtr = int*;
        "#;

        let symbols = parse("test.cpp", source).unwrap();
        let type_symbols = symbols_of_kind(&symbols, SymbolKind::Type);

        assert!(!type_symbols.is_empty());
        assert!(has_symbol(&type_symbols, "StringVector"));
    }

    #[test]
    fn test_parse_typedef() {
        let source = r#"
typedef unsigned int uint;
typedef struct {
    int x, y;
} Point;
        "#;

        let symbols = parse("test.cpp", source).unwrap();
        let type_symbols = symbols_of_kind(&symbols, SymbolKind::Type);

        assert!(!type_symbols.is_empty());
    }

    #[test]
    fn test_parse_mixed_symbols() {
        let source = r#"
namespace Math {
    class Vector {
    private:
        double x, y;

    public:
        Vector(double x, double y) : x(x), y(y) {}

        double magnitude() {
            return sqrt(x*x + y*y);
        }
    };

    enum Operation {
        ADD,
        SUBTRACT
    };

    template <typename T>
    T multiply(T a, T b) {
        return a * b;
    }
}
        "#;

        let symbols = parse("test.cpp", source).unwrap();

        assert!(symbols.len() >= 5);

        // Every one of these kinds must show up at least once.
        for kind in [
            SymbolKind::Namespace,
            SymbolKind::Class,
            SymbolKind::Enum,
            SymbolKind::Method,
            SymbolKind::Function,
        ] {
            assert!(symbols.iter().any(|s| s.kind == kind));
        }
    }

    #[test]
    fn test_parse_nested_namespace() {
        let source = r#"
namespace Outer {
    namespace Inner {
        void function() {}
    }
}
        "#;

        let symbols = parse("test.cpp", source).unwrap();
        let namespace_symbols = symbols_of_kind(&symbols, SymbolKind::Namespace);

        assert_eq!(namespace_symbols.len(), 2);
        assert!(has_symbol(&namespace_symbols, "Outer"));
        assert!(has_symbol(&namespace_symbols, "Inner"));
    }

    #[test]
    fn test_parse_virtual_methods() {
        let source = r#"
class Base {
public:
    virtual void draw() = 0;
    virtual void update() {}
};

class Derived : public Base {
public:
    void draw() override {
        // Implementation
    }
};
        "#;

        let symbols = parse("test.cpp", source).unwrap();

        let class_symbols = symbols_of_kind(&symbols, SymbolKind::Class);
        assert_eq!(class_symbols.len(), 2);
        assert!(has_symbol(&class_symbols, "Base"));
        assert!(has_symbol(&class_symbols, "Derived"));

        let method_symbols = symbols_of_kind(&symbols, SymbolKind::Method);
        assert!(method_symbols.len() >= 2);
    }

    #[test]
    fn test_parse_operator_overload() {
        let source = r#"
class Complex {
private:
    double real, imag;

public:
    Complex operator+(const Complex& other) {
        return Complex(real + other.real, imag + other.imag);
    }
};
        "#;

        let symbols = parse("test.cpp", source).unwrap();
        let class_symbols = symbols_of_kind(&symbols, SymbolKind::Class);

        assert_eq!(class_symbols.len(), 1);
        assert_eq!(class_symbols[0].symbol.as_deref(), Some("Complex"));
    }

    #[test]
    fn test_local_variables_included() {
        let source = r#"
int calculate(int input) {
    int localVar = input * 2;
    auto result = localVar + 10;
    return result;
}

class Calculator {
public:
    int compute(int value) {
        int temp = value * 3;
        auto final = temp + 5;
        return final;
    }
};
        "#;

        let symbols = parse("test.cpp", source).unwrap();
        let variables = symbols_of_kind(&symbols, SymbolKind::Variable);

        // Locals from both the free function and the in-class method.
        assert!(has_symbol(&variables, "localVar"));
        assert!(has_symbol(&variables, "result"));
        assert!(has_symbol(&variables, "temp"));
        assert!(has_symbol(&variables, "final"));
    }

    #[test]
    fn test_parse_destructor() {
        let source = r#"
class Resource {
private:
    int* data;

public:
    Resource() {
        data = new int[100];
    }

    ~Resource() {
        delete[] data;
    }
};
        "#;

        let symbols = parse("test.cpp", source).unwrap();

        let class_symbols = symbols_of_kind(&symbols, SymbolKind::Class);
        assert_eq!(class_symbols.len(), 1);
        assert_eq!(class_symbols[0].symbol.as_deref(), Some("Resource"));

        let method_symbols = symbols_of_kind(&symbols, SymbolKind::Method);
        assert!(
            !method_symbols.is_empty(),
            "Expected at least constructor or destructor to be extracted"
        );

        for method in &method_symbols {
            println!("Found method: {:?}", method.symbol);
        }

        // Diagnostic only: a missing destructor is reported, not failed on.
        let has_destructor = method_symbols.iter().any(|s| {
            s.symbol
                .as_deref()
                .map(|name| name.contains('~') || name == "Resource")
                .unwrap_or(false)
        });

        if !has_destructor {
            println!("WARNING: Destructor extraction may not be working");
        }
    }
}
953
954use crate::models::ImportType;
959use crate::parsers::{DependencyExtractor, ImportInfo};
960
/// Field-less type that carries the [`DependencyExtractor`] implementation for
/// C++ sources; that implementation reports the file's `#include` directives.
pub struct CppDependencyExtractor;
963
964impl DependencyExtractor for CppDependencyExtractor {
965 fn extract_dependencies(source: &str) -> Result<Vec<ImportInfo>> {
966 let mut parser = Parser::new();
967 let language = tree_sitter_cpp::LANGUAGE;
968
969 parser
970 .set_language(&language.into())
971 .context("Failed to set C++ language")?;
972
973 let tree = parser
974 .parse(source, None)
975 .context("Failed to parse C++ source")?;
976
977 let root_node = tree.root_node();
978
979 let mut imports = Vec::new();
980
981 imports.extend(extract_cpp_includes(source, &root_node)?);
983
984 Ok(imports)
985 }
986}
987
988fn extract_cpp_includes(source: &str, root: &tree_sitter::Node) -> Result<Vec<ImportInfo>> {
990 let language = tree_sitter_cpp::LANGUAGE;
991
992 let query_str = r#"
993 (preproc_include
994 path: (string_literal) @include_path) @include
995
996 (preproc_include
997 path: (system_lib_string) @include_path) @include
998 "#;
999
1000 let query =
1001 Query::new(&language.into(), query_str).context("Failed to create C++ include query")?;
1002
1003 let mut cursor = QueryCursor::new();
1004 let mut matches = cursor.matches(&query, *root, source.as_bytes());
1005
1006 let mut imports = Vec::new();
1007
1008 while let Some(match_) = matches.next() {
1009 let mut include_path = None;
1010 let mut include_node = None;
1011
1012 for capture in match_.captures {
1013 let capture_name: &str = &query.capture_names()[capture.index as usize];
1014 match capture_name {
1015 "include_path" => {
1016 let raw_path = capture.node.utf8_text(source.as_bytes()).unwrap_or("");
1018 include_path = Some(
1019 raw_path
1020 .trim_matches(|c| c == '"' || c == '<' || c == '>')
1021 .to_string(),
1022 );
1023 }
1024 "include" => {
1025 include_node = Some(capture.node);
1026 }
1027 _ => {}
1028 }
1029 }
1030
1031 if let (Some(path), Some(node)) = (include_path, include_node) {
1032 let import_type = classify_cpp_include(&path, source, &node);
1033 let line_number = node.start_position().row + 1;
1034
1035 imports.push(ImportInfo {
1036 imported_path: path,
1037 import_type,
1038 line_number,
1039 imported_symbols: None, });
1041 }
1042 }
1043
1044 Ok(imports)
1045}
1046
1047fn classify_cpp_include(include_path: &str, source: &str, node: &tree_sitter::Node) -> ImportType {
1049 let line_start = node.start_position();
1051 let lines: Vec<&str> = source.lines().collect();
1052
1053 if line_start.row < lines.len() {
1054 let line = lines[line_start.row];
1055
1056 if line.contains(&format!("\"{}\"", include_path)) {
1058 return ImportType::Internal;
1059 }
1060 }
1061
1062 const CPP_STDLIB_HEADERS: &[&str] = &[
1064 "stdio.h",
1066 "stdlib.h",
1067 "string.h",
1068 "math.h",
1069 "time.h",
1070 "ctype.h",
1071 "assert.h",
1072 "errno.h",
1073 "limits.h",
1074 "float.h",
1075 "stddef.h",
1076 "stdint.h",
1077 "stdbool.h",
1078 "stdarg.h",
1079 "setjmp.h",
1080 "signal.h",
1081 "locale.h",
1082 "wchar.h",
1083 "wctype.h",
1084 "complex.h",
1085 "fenv.h",
1086 "inttypes.h",
1087 "iso646.h",
1088 "tgmath.h",
1089 "threads.h",
1090 "algorithm",
1092 "any",
1093 "array",
1094 "atomic",
1095 "barrier",
1096 "bit",
1097 "bitset",
1098 "charconv",
1099 "chrono",
1100 "codecvt",
1101 "compare",
1102 "complex",
1103 "concepts",
1104 "condition_variable",
1105 "coroutine",
1106 "deque",
1107 "exception",
1108 "execution",
1109 "expected",
1110 "filesystem",
1111 "format",
1112 "forward_list",
1113 "fstream",
1114 "functional",
1115 "future",
1116 "initializer_list",
1117 "iomanip",
1118 "ios",
1119 "iosfwd",
1120 "iostream",
1121 "istream",
1122 "iterator",
1123 "latch",
1124 "limits",
1125 "list",
1126 "locale",
1127 "map",
1128 "mdspan",
1129 "memory",
1130 "memory_resource",
1131 "mutex",
1132 "new",
1133 "numbers",
1134 "numeric",
1135 "optional",
1136 "ostream",
1137 "queue",
1138 "random",
1139 "ranges",
1140 "ratio",
1141 "regex",
1142 "scoped_allocator",
1143 "semaphore",
1144 "set",
1145 "shared_mutex",
1146 "source_location",
1147 "span",
1148 "sstream",
1149 "stack",
1150 "stacktrace",
1151 "stdexcept",
1152 "stop_token",
1153 "streambuf",
1154 "string",
1155 "string_view",
1156 "strstream",
1157 "syncstream",
1158 "system_error",
1159 "thread",
1160 "tuple",
1161 "type_traits",
1162 "typeindex",
1163 "typeinfo",
1164 "unordered_map",
1165 "unordered_set",
1166 "utility",
1167 "valarray",
1168 "variant",
1169 "vector",
1170 "version",
1171 "cassert",
1173 "cctype",
1174 "cerrno",
1175 "cfenv",
1176 "cfloat",
1177 "cinttypes",
1178 "climits",
1179 "clocale",
1180 "cmath",
1181 "csetjmp",
1182 "csignal",
1183 "cstdarg",
1184 "cstddef",
1185 "cstdint",
1186 "cstdio",
1187 "cstdlib",
1188 "cstring",
1189 "ctime",
1190 "cuchar",
1191 "cwchar",
1192 "cwctype",
1193 ];
1194
1195 if CPP_STDLIB_HEADERS.contains(&include_path) {
1196 return ImportType::Stdlib;
1197 }
1198
1199 ImportType::External
1201}
1202
/// Resolves `include_path` relative to the directory of `current_file_path`.
///
/// Returns `None` when no current file is given or when it has no parent
/// directory. Otherwise the include path is joined onto that directory. The
/// result is canonicalised when the filesystem allows it; if canonicalisation
/// fails (for example because the file does not exist), the plain joined path
/// is returned instead.
pub fn resolve_cpp_include_to_path(
    include_path: &str,
    current_file_path: Option<&str>,
) -> Option<String> {
    let base_dir = std::path::Path::new(current_file_path?).parent()?;
    let candidate = base_dir.join(include_path);

    // Fall back to the unnormalised path when canonicalisation fails.
    let best = candidate.canonicalize().unwrap_or(candidate);
    Some(best.display().to_string())
}
1240
#[cfg(test)]
mod resolution_tests {
    use super::*;

    /// File that every include in these tests is resolved against.
    const CURRENT_FILE: &str = "/project/src/main.cpp";

    /// Resolves `include` against [`CURRENT_FILE`], failing the test on `None`.
    fn resolve(include: &str) -> String {
        let resolved = resolve_cpp_include_to_path(include, Some(CURRENT_FILE));
        assert!(resolved.is_some());
        resolved.unwrap()
    }

    #[test]
    fn test_resolve_cpp_include_same_directory() {
        let path = resolve("helper.hpp");
        assert!(path.ends_with("src/helper.hpp") || path.ends_with("src\\helper.hpp"));
    }

    #[test]
    fn test_resolve_cpp_include_subdirectory() {
        let path = resolve("utils/helper.hpp");
        assert!(path.ends_with("src/utils/helper.hpp") || path.ends_with("src\\utils\\helper.hpp"));
    }

    #[test]
    fn test_resolve_cpp_include_parent_directory() {
        let path = resolve("../include/common.hpp");
        assert!(path.contains("include") && path.contains("common.hpp"));
    }

    #[test]
    fn test_resolve_cpp_include_h_extension() {
        let path = resolve("legacy.h");
        assert!(path.ends_with("src/legacy.h") || path.ends_with("src\\legacy.h"));
    }

    #[test]
    fn test_resolve_cpp_include_no_current_file() {
        assert!(resolve_cpp_include_to_path("helper.hpp", None).is_none());
    }
}