Skip to main content

reflex/parsers/
cpp.rs

//! C++ language parser using Tree-sitter
//!
//! Extracts symbols from C++ source code:
//! - Functions (regular, qualified out-of-line definitions, and template)
//! - Classes (regular, abstract, template)
//! - Structs
//! - Namespaces
//! - Templates (class and function)
//! - Methods defined inside class/struct bodies (with class scope; virtual and override methods included)
//! - Constructors/Destructors
//! - Operators
//! - Enums (enum and enum class)
//! - Local variables (inside functions and methods)
//! - Using declarations (`using Name = Type;` aliases)
//! - Type aliases (`typedef`)
16
17use anyhow::{Context, Result};
18use streaming_iterator::StreamingIterator;
19use tree_sitter::{Parser, Query, QueryCursor};
20use crate::models::{Language, SearchResult, Span, SymbolKind};
21
22/// Parse C++ source code and extract symbols
23pub fn parse(path: &str, source: &str) -> Result<Vec<SearchResult>> {
24    let mut parser = Parser::new();
25    let language = tree_sitter_cpp::LANGUAGE;
26
27    parser
28        .set_language(&language.into())
29        .context("Failed to set C++ language")?;
30
31    let tree = parser
32        .parse(source, None)
33        .context("Failed to parse C++ source")?;
34
35    let root_node = tree.root_node();
36
37    let mut symbols = Vec::new();
38
39    // Extract different types of symbols using Tree-sitter queries
40    symbols.extend(extract_functions(source, &root_node, &language.into())?);
41    symbols.extend(extract_classes(source, &root_node, &language.into())?);
42    symbols.extend(extract_structs(source, &root_node, &language.into())?);
43    symbols.extend(extract_namespaces(source, &root_node, &language.into())?);
44    symbols.extend(extract_enums(source, &root_node, &language.into())?);
45    symbols.extend(extract_methods(source, &root_node, &language.into())?);
46    symbols.extend(extract_local_variables(source, &root_node, &language.into())?);
47    symbols.extend(extract_type_aliases(source, &root_node, &language.into())?);
48
49    // Add file path to all symbols
50    for symbol in &mut symbols {
51        symbol.path = path.to_string();
52        symbol.lang = Language::Cpp;
53    }
54
55    Ok(symbols)
56}
57
58/// Extract function declarations and definitions
59fn extract_functions(
60    source: &str,
61    root: &tree_sitter::Node,
62    language: &tree_sitter::Language,
63) -> Result<Vec<SearchResult>> {
64    let query_str = r#"
65        (function_definition
66            declarator: (function_declarator
67                declarator: (identifier) @name)) @function
68
69        (function_definition
70            declarator: (function_declarator
71                declarator: (qualified_identifier
72                    name: (identifier) @name))) @function
73
74        (template_declaration
75            (function_definition
76                declarator: (function_declarator
77                    declarator: (identifier) @name))) @function
78    "#;
79
80    let query = Query::new(language, query_str)
81        .context("Failed to create function query")?;
82
83    extract_symbols(source, root, &query, SymbolKind::Function, None)
84}
85
86/// Extract class declarations
87fn extract_classes(
88    source: &str,
89    root: &tree_sitter::Node,
90    language: &tree_sitter::Language,
91) -> Result<Vec<SearchResult>> {
92    let query_str = r#"
93        (class_specifier
94            name: (type_identifier) @name) @class
95
96        (template_declaration
97            (class_specifier
98                name: (type_identifier) @name)) @class
99    "#;
100
101    let query = Query::new(language, query_str)
102        .context("Failed to create class query")?;
103
104    extract_symbols(source, root, &query, SymbolKind::Class, None)
105}
106
107/// Extract struct declarations
108fn extract_structs(
109    source: &str,
110    root: &tree_sitter::Node,
111    language: &tree_sitter::Language,
112) -> Result<Vec<SearchResult>> {
113    let query_str = r#"
114        (struct_specifier
115            name: (type_identifier) @name) @struct
116
117        (template_declaration
118            (struct_specifier
119                name: (type_identifier) @name)) @struct
120    "#;
121
122    let query = Query::new(language, query_str)
123        .context("Failed to create struct query")?;
124
125    extract_symbols(source, root, &query, SymbolKind::Struct, None)
126}
127
128/// Extract namespace definitions
129fn extract_namespaces(
130    source: &str,
131    root: &tree_sitter::Node,
132    language: &tree_sitter::Language,
133) -> Result<Vec<SearchResult>> {
134    let query_str = r#"
135        (namespace_definition
136            name: (_) @name) @namespace
137    "#;
138
139    let query = Query::new(language, query_str)
140        .context("Failed to create namespace query")?;
141
142    extract_symbols(source, root, &query, SymbolKind::Namespace, None)
143}
144
145/// Extract enum declarations
146fn extract_enums(
147    source: &str,
148    root: &tree_sitter::Node,
149    language: &tree_sitter::Language,
150) -> Result<Vec<SearchResult>> {
151    let query_str = r#"
152        (enum_specifier
153            name: (type_identifier) @name) @enum
154    "#;
155
156    let query = Query::new(language, query_str)
157        .context("Failed to create enum query")?;
158
159    extract_symbols(source, root, &query, SymbolKind::Enum, None)
160}
161
162/// Extract method definitions from classes and structs
163fn extract_methods(
164    source: &str,
165    root: &tree_sitter::Node,
166    language: &tree_sitter::Language,
167) -> Result<Vec<SearchResult>> {
168    let query_str = r#"
169        (class_specifier
170            name: (type_identifier) @class_name
171            body: (field_declaration_list
172                (function_definition
173                    declarator: (function_declarator
174                        declarator: (field_identifier) @method_name)))) @class
175
176        (class_specifier
177            name: (type_identifier) @class_name
178            body: (field_declaration_list
179                (function_definition
180                    declarator: (function_declarator
181                        declarator: (destructor_name) @method_name)))) @class
182
183        (struct_specifier
184            name: (type_identifier) @struct_name
185            body: (field_declaration_list
186                (function_definition
187                    declarator: (function_declarator
188                        declarator: (field_identifier) @method_name)))) @struct
189
190        (struct_specifier
191            name: (type_identifier) @struct_name
192            body: (field_declaration_list
193                (function_definition
194                    declarator: (function_declarator
195                        declarator: (destructor_name) @method_name)))) @struct
196    "#;
197
198    let query = Query::new(language, query_str)
199        .context("Failed to create method query")?;
200
201    let mut cursor = QueryCursor::new();
202    let mut matches = cursor.matches(&query, *root, source.as_bytes());
203
204    let mut symbols = Vec::new();
205
206    while let Some(match_) = matches.next() {
207        let mut scope_name = None;
208        let mut scope_type = None;
209        let mut method_name = None;
210        let mut method_node = None;
211
212        for capture in match_.captures {
213            let capture_name: &str = &query.capture_names()[capture.index as usize];
214            match capture_name {
215                "class_name" => {
216                    scope_name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
217                    scope_type = Some("class");
218                }
219                "struct_name" => {
220                    scope_name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
221                    scope_type = Some("struct");
222                }
223                "method_name" => {
224                    method_name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
225                    // Find the parent function_definition node
226                    let mut current = capture.node;
227                    while let Some(parent) = current.parent() {
228                        if parent.kind() == "function_definition" {
229                            method_node = Some(parent);
230                            break;
231                        }
232                        current = parent;
233                    }
234                }
235                _ => {}
236            }
237        }
238
239        if let (Some(scope_name), Some(scope_type), Some(method_name), Some(node)) =
240            (scope_name, scope_type, method_name, method_node) {
241            let scope = format!("{} {}", scope_type, scope_name);
242            let span = node_to_span(&node);
243            let preview = extract_preview(source, &span);
244
245            symbols.push(SearchResult::new(
246                String::new(),
247                Language::Cpp,
248                SymbolKind::Method,
249                Some(method_name),
250                span,
251                Some(scope),
252                preview,
253            ));
254        }
255    }
256
257    Ok(symbols)
258}
259
260/// Extract local variable declarations inside functions and methods
261fn extract_local_variables(
262    source: &str,
263    root: &tree_sitter::Node,
264    language: &tree_sitter::Language,
265) -> Result<Vec<SearchResult>> {
266    let query_str = r#"
267        (declaration
268            declarator: (init_declarator
269                declarator: (identifier) @name)) @var
270    "#;
271
272    let query = Query::new(language, query_str)
273        .context("Failed to create local variable query")?;
274
275    let mut cursor = QueryCursor::new();
276    let mut matches = cursor.matches(&query, *root, source.as_bytes());
277
278    let mut symbols = Vec::new();
279
280    while let Some(match_) = matches.next() {
281        let mut name = None;
282        let mut var_node = None;
283
284        for capture in match_.captures {
285            let capture_name: &str = &query.capture_names()[capture.index as usize];
286            match capture_name {
287                "name" => {
288                    name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
289                }
290                "var" => {
291                    var_node = Some(capture.node);
292                }
293                _ => {}
294            }
295        }
296
297        // Only extract variables that are inside function definitions (local variables)
298        if let (Some(name), Some(node)) = (name, var_node) {
299            let mut is_local_var = false;
300            let mut current = node;
301
302            while let Some(parent) = current.parent() {
303                if parent.kind() == "function_definition" {
304                    is_local_var = true;
305                    break;
306                }
307                current = parent;
308            }
309
310            if is_local_var {
311                let span = node_to_span(&node);
312                let preview = extract_preview(source, &span);
313
314                symbols.push(SearchResult::new(
315                    String::new(),
316                    Language::Cpp,
317                    SymbolKind::Variable,
318                    Some(name),
319                    span,
320                    None,  // No scope for local variables
321                    preview,
322                ));
323            }
324        }
325    }
326
327    Ok(symbols)
328}
329
330/// Extract type aliases (using and typedef)
331fn extract_type_aliases(
332    source: &str,
333    root: &tree_sitter::Node,
334    language: &tree_sitter::Language,
335) -> Result<Vec<SearchResult>> {
336    let query_str = r#"
337        (type_definition
338            declarator: (type_identifier) @name) @typedef
339
340        (alias_declaration
341            name: (type_identifier) @name) @using
342    "#;
343
344    let query = Query::new(language, query_str)
345        .context("Failed to create type alias query")?;
346
347    extract_symbols(source, root, &query, SymbolKind::Type, None)
348}
349
350/// Generic symbol extraction helper
351fn extract_symbols(
352    source: &str,
353    root: &tree_sitter::Node,
354    query: &Query,
355    kind: SymbolKind,
356    scope: Option<String>,
357) -> Result<Vec<SearchResult>> {
358    let mut cursor = QueryCursor::new();
359    let mut matches = cursor.matches(query, *root, source.as_bytes());
360
361    let mut symbols = Vec::new();
362    let mut seen_names = std::collections::HashSet::new();
363
364    while let Some(match_) = matches.next() {
365        // Find the name capture and the full node
366        let mut name = None;
367        let mut name_node = None;
368        let mut full_node = None;
369
370        for capture in match_.captures {
371            let capture_name: &str = &query.capture_names()[capture.index as usize];
372            if capture_name == "name" {
373                name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
374                name_node = Some(capture.node);
375            } else {
376                // Assume any other capture is the full node
377                full_node = Some(capture.node);
378            }
379        }
380
381        if let (Some(name), Some(name_node), Some(node)) = (name, name_node, full_node) {
382            // Deduplicate by name position - this handles cases where template patterns
383            // match the same symbol twice (e.g., both template_declaration and class_specifier)
384            let name_key = (name_node.start_byte(), name_node.end_byte(), name.clone());
385            if seen_names.contains(&name_key) {
386                continue; // Skip duplicate
387            }
388            seen_names.insert(name_key);
389
390            let span = node_to_span(&node);
391            let preview = extract_preview(source, &span);
392
393            symbols.push(SearchResult::new(
394                String::new(),
395                Language::Cpp,
396                kind.clone(),
397                Some(name),
398                span,
399                scope.clone(),
400                preview,
401            ));
402        }
403    }
404
405    Ok(symbols)
406}
407
408/// Convert a Tree-sitter node to a Span
409fn node_to_span(node: &tree_sitter::Node) -> Span {
410    let start = node.start_position();
411    let end = node.end_position();
412
413    Span::new(
414        start.row + 1,  // Convert 0-indexed to 1-indexed
415        start.column,
416        end.row + 1,
417        end.column,
418    )
419}
420
421/// Extract a preview (7 lines) around the symbol
422fn extract_preview(source: &str, span: &Span) -> String {
423    let lines: Vec<&str> = source.lines().collect();
424
425    // Extract 7 lines: the start line and 6 following lines
426    let start_idx = (span.start_line - 1) as usize; // Convert back to 0-indexed
427    let end_idx = (start_idx + 7).min(lines.len());
428
429    lines[start_idx..end_idx].join("\n")
430}
431
// Unit tests for the C++ symbol parser: each test parses a small snippet with
// `parse` and checks the extracted symbols.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_function() {
        let source = r#"
int add(int a, int b) {
    return a + b;
}
        "#;

        let symbols = parse("test.cpp", source).unwrap();
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].symbol.as_deref(), Some("add"));
        assert!(matches!(symbols[0].kind, SymbolKind::Function));
    }

    #[test]
    fn test_parse_class() {
        let source = r#"
class User {
private:
    std::string name;
    int age;

public:
    User(std::string n, int a) : name(n), age(a) {}
};
        "#;

        let symbols = parse("test.cpp", source).unwrap();

        let class_symbols: Vec<_> = symbols.iter()
            .filter(|s| matches!(s.kind, SymbolKind::Class))
            .collect();

        assert_eq!(class_symbols.len(), 1);
        assert_eq!(class_symbols[0].symbol.as_deref(), Some("User"));
    }

    #[test]
    fn test_parse_namespace() {
        let source = r#"
namespace MyNamespace {
    int value = 42;
}

namespace Nested::Namespace {
    void function() {}
}
        "#;

        let symbols = parse("test.cpp", source).unwrap();

        let namespace_symbols: Vec<_> = symbols.iter()
            .filter(|s| matches!(s.kind, SymbolKind::Namespace))
            .collect();

        assert!(!namespace_symbols.is_empty());
        assert!(namespace_symbols.iter().any(|s| s.symbol.as_deref() == Some("MyNamespace")));
    }

    #[test]
    fn test_parse_struct() {
        let source = r#"
struct Point {
    int x;
    int y;
};
        "#;

        let symbols = parse("test.cpp", source).unwrap();
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].symbol.as_deref(), Some("Point"));
        assert!(matches!(symbols[0].kind, SymbolKind::Struct));
    }

    #[test]
    fn test_parse_enum() {
        let source = r#"
enum Color {
    RED,
    GREEN,
    BLUE
};

enum class Status {
    Active,
    Inactive
};
        "#;

        let symbols = parse("test.cpp", source).unwrap();

        let enum_symbols: Vec<_> = symbols.iter()
            .filter(|s| matches!(s.kind, SymbolKind::Enum))
            .collect();

        assert_eq!(enum_symbols.len(), 2);
        assert!(enum_symbols.iter().any(|s| s.symbol.as_deref() == Some("Color")));
        assert!(enum_symbols.iter().any(|s| s.symbol.as_deref() == Some("Status")));
    }

    #[test]
    fn test_parse_template_class() {
        let source = r#"
template <typename T>
class Container {
private:
    T value;

public:
    Container(T v) : value(v) {}
    T getValue() { return value; }
};
        "#;

        let symbols = parse("test.cpp", source).unwrap();

        let class_symbols: Vec<_> = symbols.iter()
            .filter(|s| matches!(s.kind, SymbolKind::Class))
            .collect();

        assert_eq!(class_symbols.len(), 1);
        assert_eq!(class_symbols[0].symbol.as_deref(), Some("Container"));
    }

    #[test]
    fn test_parse_template_function() {
        let source = r#"
template <typename T>
T max(T a, T b) {
    return (a > b) ? a : b;
}
        "#;

        let symbols = parse("test.cpp", source).unwrap();
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].symbol.as_deref(), Some("max"));
        assert!(matches!(symbols[0].kind, SymbolKind::Function));
    }

    #[test]
    fn test_parse_class_with_methods() {
        let source = r#"
class Calculator {
public:
    int add(int a, int b) {
        return a + b;
    }

    int subtract(int a, int b) {
        return a - b;
    }
};
        "#;

        let symbols = parse("test.cpp", source).unwrap();

        let method_symbols: Vec<_> = symbols.iter()
            .filter(|s| matches!(s.kind, SymbolKind::Method))
            .collect();

        assert_eq!(method_symbols.len(), 2);
        assert!(method_symbols.iter().any(|s| s.symbol.as_deref() == Some("add")));
        assert!(method_symbols.iter().any(|s| s.symbol.as_deref() == Some("subtract")));

        // The scope field is no longer exposed on results. Check instead that every
        // method carries the file path and language stamped on by `parse`.
        for method in &method_symbols {
            assert_eq!(method.path, "test.cpp");
            assert!(matches!(method.lang, Language::Cpp));
        }
    }

    #[test]
    fn test_parse_using_declaration() {
        let source = r#"
using StringVector = std::vector<std::string>;
using IntPtr = int*;
        "#;

        let symbols = parse("test.cpp", source).unwrap();

        let type_symbols: Vec<_> = symbols.iter()
            .filter(|s| matches!(s.kind, SymbolKind::Type))
            .collect();

        assert!(!type_symbols.is_empty());
        assert!(type_symbols.iter().any(|s| s.symbol.as_deref() == Some("StringVector")));
    }

    #[test]
    fn test_parse_typedef() {
        let source = r#"
typedef unsigned int uint;
typedef struct {
    int x, y;
} Point;
        "#;

        let symbols = parse("test.cpp", source).unwrap();

        let type_symbols: Vec<_> = symbols.iter()
            .filter(|s| matches!(s.kind, SymbolKind::Type))
            .collect();

        assert!(!type_symbols.is_empty());
    }

    #[test]
    fn test_parse_mixed_symbols() {
        let source = r#"
namespace Math {
    class Vector {
    private:
        double x, y;

    public:
        Vector(double x, double y) : x(x), y(y) {}

        double magnitude() {
            return sqrt(x*x + y*y);
        }
    };

    enum Operation {
        ADD,
        SUBTRACT
    };

    template <typename T>
    T multiply(T a, T b) {
        return a * b;
    }
}
        "#;

        let symbols = parse("test.cpp", source).unwrap();

        // Should find: namespace, class, enum, method, function
        assert!(symbols.len() >= 5);

        let kinds: Vec<&SymbolKind> = symbols.iter().map(|s| &s.kind).collect();
        assert!(kinds.contains(&&SymbolKind::Namespace));
        assert!(kinds.contains(&&SymbolKind::Class));
        assert!(kinds.contains(&&SymbolKind::Enum));
        assert!(kinds.contains(&&SymbolKind::Method));
        assert!(kinds.contains(&&SymbolKind::Function));
    }

    #[test]
    fn test_parse_nested_namespace() {
        let source = r#"
namespace Outer {
    namespace Inner {
        void function() {}
    }
}
        "#;

        let symbols = parse("test.cpp", source).unwrap();

        let namespace_symbols: Vec<_> = symbols.iter()
            .filter(|s| matches!(s.kind, SymbolKind::Namespace))
            .collect();

        assert_eq!(namespace_symbols.len(), 2);
        assert!(namespace_symbols.iter().any(|s| s.symbol.as_deref() == Some("Outer")));
        assert!(namespace_symbols.iter().any(|s| s.symbol.as_deref() == Some("Inner")));
    }

    #[test]
    fn test_parse_virtual_methods() {
        let source = r#"
class Base {
public:
    virtual void draw() = 0;
    virtual void update() {}
};

class Derived : public Base {
public:
    void draw() override {
        // Implementation
    }
};
        "#;

        let symbols = parse("test.cpp", source).unwrap();

        let class_symbols: Vec<_> = symbols.iter()
            .filter(|s| matches!(s.kind, SymbolKind::Class))
            .collect();

        assert_eq!(class_symbols.len(), 2);
        assert!(class_symbols.iter().any(|s| s.symbol.as_deref() == Some("Base")));
        assert!(class_symbols.iter().any(|s| s.symbol.as_deref() == Some("Derived")));

        let method_symbols: Vec<_> = symbols.iter()
            .filter(|s| matches!(s.kind, SymbolKind::Method))
            .collect();

        assert!(method_symbols.len() >= 2);
    }

    #[test]
    fn test_parse_operator_overload() {
        let source = r#"
class Complex {
private:
    double real, imag;

public:
    Complex operator+(const Complex& other) {
        return Complex(real + other.real, imag + other.imag);
    }
};
        "#;

        let symbols = parse("test.cpp", source).unwrap();

        let class_symbols: Vec<_> = symbols.iter()
            .filter(|s| matches!(s.kind, SymbolKind::Class))
            .collect();

        assert_eq!(class_symbols.len(), 1);
        assert_eq!(class_symbols[0].symbol.as_deref(), Some("Complex"));
    }

    #[test]
    fn test_local_variables_included() {
        let source = r#"
int calculate(int input) {
    int localVar = input * 2;
    auto result = localVar + 10;
    return result;
}

class Calculator {
public:
    int compute(int value) {
        int temp = value * 3;
        auto final = temp + 5;
        return final;
    }
};
        "#;

        let symbols = parse("test.cpp", source).unwrap();

        // Filter to just variables
        let variables: Vec<_> = symbols.iter()
            .filter(|s| matches!(s.kind, SymbolKind::Variable))
            .collect();

        // Check that local variables are captured
        assert!(variables.iter().any(|v| v.symbol.as_deref() == Some("localVar")));
        assert!(variables.iter().any(|v| v.symbol.as_deref() == Some("result")));
        assert!(variables.iter().any(|v| v.symbol.as_deref() == Some("temp")));
        assert!(variables.iter().any(|v| v.symbol.as_deref() == Some("final")));

        // The scope field is no longer exposed on results. Check instead that every
        // variable carries the file path and language stamped on by `parse`.
        for var in &variables {
            assert_eq!(var.path, "test.cpp");
            assert!(matches!(var.lang, Language::Cpp));
        }
    }

    #[test]
    fn test_parse_destructor() {
        let source = r#"
class Resource {
private:
    int* data;

public:
    Resource() {
        data = new int[100];
    }

    ~Resource() {
        delete[] data;
    }
};
        "#;

        let symbols = parse("test.cpp", source).unwrap();

        let class_symbols: Vec<_> = symbols.iter()
            .filter(|s| matches!(s.kind, SymbolKind::Class))
            .collect();

        assert_eq!(class_symbols.len(), 1);
        assert_eq!(class_symbols[0].symbol.as_deref(), Some("Resource"));

        // Check if destructor is extracted
        let method_symbols: Vec<_> = symbols.iter()
            .filter(|s| matches!(s.kind, SymbolKind::Method))
            .collect();

        // Should have both constructor and destructor
        assert!(!method_symbols.is_empty(), "Expected at least constructor or destructor to be extracted");

        // Print what methods we found for debugging
        for method in &method_symbols {
            println!("Found method: {:?}", method.symbol);
        }

        // Check if destructor is present (might be ~Resource or just Resource)
        let has_destructor = method_symbols.iter().any(|s| {
            s.symbol.as_deref()
                .map(|name| name.contains("~") || name == "Resource")
                .unwrap_or(false)
        });

        // This test documents current behavior - we'll fix if destructors aren't extracted
        if !has_destructor {
            println!("WARNING: Destructor extraction may not be working");
        }
    }
}
852
853// ============================================================================
854// Dependency Extraction
855// ============================================================================
856
857use crate::models::ImportType;
858use crate::parsers::{DependencyExtractor, ImportInfo};
859
860/// C++ dependency extractor
861pub struct CppDependencyExtractor;
862
863impl DependencyExtractor for CppDependencyExtractor {
864    fn extract_dependencies(source: &str) -> Result<Vec<ImportInfo>> {
865        let mut parser = Parser::new();
866        let language = tree_sitter_cpp::LANGUAGE;
867
868        parser
869            .set_language(&language.into())
870            .context("Failed to set C++ language")?;
871
872        let tree = parser
873            .parse(source, None)
874            .context("Failed to parse C++ source")?;
875
876        let root_node = tree.root_node();
877
878        let mut imports = Vec::new();
879
880        // Extract #include directives
881        imports.extend(extract_cpp_includes(source, &root_node)?);
882
883        Ok(imports)
884    }
885}
886
887/// Extract C++ #include directives
888fn extract_cpp_includes(
889    source: &str,
890    root: &tree_sitter::Node,
891) -> Result<Vec<ImportInfo>> {
892    let language = tree_sitter_cpp::LANGUAGE;
893
894    let query_str = r#"
895        (preproc_include
896            path: (string_literal) @include_path) @include
897
898        (preproc_include
899            path: (system_lib_string) @include_path) @include
900    "#;
901
902    let query = Query::new(&language.into(), query_str)
903        .context("Failed to create C++ include query")?;
904
905    let mut cursor = QueryCursor::new();
906    let mut matches = cursor.matches(&query, *root, source.as_bytes());
907
908    let mut imports = Vec::new();
909
910    while let Some(match_) = matches.next() {
911        let mut include_path = None;
912        let mut include_node = None;
913
914        for capture in match_.captures {
915            let capture_name: &str = &query.capture_names()[capture.index as usize];
916            match capture_name {
917                "include_path" => {
918                    // Remove quotes or angle brackets from path
919                    let raw_path = capture.node.utf8_text(source.as_bytes()).unwrap_or("");
920                    include_path = Some(raw_path.trim_matches(|c| c == '"' || c == '<' || c == '>').to_string());
921                }
922                "include" => {
923                    include_node = Some(capture.node);
924                }
925                _ => {}
926            }
927        }
928
929        if let (Some(path), Some(node)) = (include_path, include_node) {
930            let import_type = classify_cpp_include(&path, source, &node);
931            let line_number = node.start_position().row + 1;
932
933            imports.push(ImportInfo {
934                imported_path: path,
935                import_type,
936                line_number,
937                imported_symbols: None, // C++ includes entire header
938            });
939        }
940    }
941
942    Ok(imports)
943}
944
945/// Classify a C++ include as internal, external, or stdlib
946fn classify_cpp_include(include_path: &str, source: &str, node: &tree_sitter::Node) -> ImportType {
947    // Get the actual #include line to check if it uses quotes or angle brackets
948    let line_start = node.start_position();
949    let lines: Vec<&str> = source.lines().collect();
950
951    if line_start.row < lines.len() {
952        let line = lines[line_start.row];
953
954        // Internal: #include "..." (quotes = local project files)
955        if line.contains(&format!("\"{}\"", include_path)) {
956            return ImportType::Internal;
957        }
958    }
959
960    // C++ standard library headers (angle brackets)
961    const CPP_STDLIB_HEADERS: &[&str] = &[
962        // C standard library (inherited)
963        "stdio.h", "stdlib.h", "string.h", "math.h", "time.h",
964        "ctype.h", "assert.h", "errno.h", "limits.h", "float.h",
965        "stddef.h", "stdint.h", "stdbool.h", "stdarg.h", "setjmp.h",
966        "signal.h", "locale.h", "wchar.h", "wctype.h", "complex.h",
967        "fenv.h", "inttypes.h", "iso646.h", "tgmath.h", "threads.h",
968
969        // C++ standard library headers (no .h extension)
970        "algorithm", "any", "array", "atomic", "barrier", "bit",
971        "bitset", "charconv", "chrono", "codecvt", "compare", "complex",
972        "concepts", "condition_variable", "coroutine", "deque", "exception",
973        "execution", "expected", "filesystem", "format", "forward_list",
974        "fstream", "functional", "future", "initializer_list", "iomanip",
975        "ios", "iosfwd", "iostream", "istream", "iterator", "latch",
976        "limits", "list", "locale", "map", "mdspan", "memory",
977        "memory_resource", "mutex", "new", "numbers", "numeric", "optional",
978        "ostream", "queue", "random", "ranges", "ratio", "regex",
979        "scoped_allocator", "semaphore", "set", "shared_mutex", "source_location",
980        "span", "sstream", "stack", "stacktrace", "stdexcept", "stop_token",
981        "streambuf", "string", "string_view", "strstream", "syncstream",
982        "system_error", "thread", "tuple", "type_traits", "typeindex",
983        "typeinfo", "unordered_map", "unordered_set", "utility", "valarray",
984        "variant", "vector", "version",
985
986        // C++ C-compatibility headers (c-prefixed)
987        "cassert", "cctype", "cerrno", "cfenv", "cfloat", "cinttypes",
988        "climits", "clocale", "cmath", "csetjmp", "csignal", "cstdarg",
989        "cstddef", "cstdint", "cstdio", "cstdlib", "cstring", "ctime",
990        "cuchar", "cwchar", "cwctype",
991    ];
992
993    if CPP_STDLIB_HEADERS.contains(&include_path) {
994        return ImportType::Stdlib;
995    }
996
997    // Everything else with angle brackets is external (third-party libraries)
998    ImportType::External
999}
1000
1001// ============================================================================
1002// Path Resolution
1003// ============================================================================
1004
/// Resolve a C++ #include directive to a file path
///
/// The include is joined onto the directory of the including file. This function does not look
/// at the directive's delimiters, so callers should only pass quoted (internal) includes.
/// Angle-bracket includes live on system/library search paths that are not modelled here.
///
/// # Arguments
/// * `include_path` - The path from the #include directive (e.g., "utils/helper.hpp")
/// * `current_file_path` - Path to the file containing the #include directive
///
/// # Returns
/// * `Some(path)` - The canonical path if the resolved file exists on disk. Otherwise, the joined
///   path with `.` and `..` components removed lexically (no filesystem access).
/// * `None` - If `current_file_path` is `None` or has no parent directory
pub fn resolve_cpp_include_to_path(
    include_path: &str,
    current_file_path: Option<&str>,
) -> Option<String> {
    let current_file = current_file_path?;

    // Directory of the including file; quoted includes are searched relative to it.
    let current_dir = std::path::Path::new(current_file).parent()?;
    let resolved = current_dir.join(include_path);

    let normalized = match resolved.canonicalize() {
        Ok(canonical) => canonical,
        // The file may not exist (yet, or on this machine). Still collapse `.`/`..` so that e.g.
        // `src/../include/x.hpp` yields the same string as `include/x.hpp`.
        Err(_) => normalize_include_path_lexically(&resolved),
    };

    Some(normalized.display().to_string())
}

/// Remove `.` and `..` components from `path` purely lexically, without touching the filesystem.
///
/// `..` directly under the root is dropped (`/..` is `/`). Leading `..` components of a relative
/// path that cannot be collapsed are kept. An empty result becomes `.`.
fn normalize_include_path_lexically(path: &std::path::Path) -> std::path::PathBuf {
    use std::path::{Component, PathBuf};

    let mut normalized = PathBuf::new();
    for component in path.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => {
                let last = normalized.components().next_back();
                match last {
                    // `a/b/..` -> `a`
                    Some(Component::Normal(_)) => {
                        normalized.pop();
                    }
                    // Cannot go above the root directory.
                    Some(Component::RootDir) => {}
                    // Empty, already `..`, or a bare drive prefix: keep the `..`.
                    _ => normalized.push(".."),
                }
            }
            other => normalized.push(other.as_os_str()),
        }
    }

    if normalized.as_os_str().is_empty() {
        normalized.push(".");
    }
    normalized
}
1038
1039// ============================================================================
1040// Tests for Path Resolution
1041// ============================================================================
1042
#[cfg(test)]
mod resolution_tests {
    use super::*;

    /// Assert that `path` ends with `unix_suffix`, accepting either `/` or `\` separators so the
    /// same expectation holds on Unix and Windows.
    fn assert_ends_with_suffix(path: &str, unix_suffix: &str) {
        let windows_suffix = unix_suffix.replace('/', "\\");
        assert!(
            path.ends_with(unix_suffix) || path.ends_with(&windows_suffix),
            "expected {path:?} to end with {unix_suffix:?}"
        );
    }

    #[test]
    fn test_resolve_cpp_include_same_directory() {
        let path = resolve_cpp_include_to_path("helper.hpp", Some("/project/src/main.cpp"))
            .expect("include next to the current file should resolve");
        assert_ends_with_suffix(&path, "src/helper.hpp");
    }

    #[test]
    fn test_resolve_cpp_include_subdirectory() {
        let path = resolve_cpp_include_to_path("utils/helper.hpp", Some("/project/src/main.cpp"))
            .expect("include in a subdirectory should resolve");
        assert_ends_with_suffix(&path, "src/utils/helper.hpp");
    }

    #[test]
    fn test_resolve_cpp_include_parent_directory() {
        let path =
            resolve_cpp_include_to_path("../include/common.hpp", Some("/project/src/main.cpp"))
                .expect("include in a sibling directory should resolve");
        // Holds whether or not `..` has been collapsed in the returned path.
        assert_ends_with_suffix(&path, "include/common.hpp");
    }

    #[test]
    fn test_resolve_cpp_include_h_extension() {
        let path = resolve_cpp_include_to_path("legacy.h", Some("/project/src/main.cpp"))
            .expect("`.h` include should resolve");
        assert_ends_with_suffix(&path, "src/legacy.h");
    }

    #[test]
    fn test_resolve_cpp_include_no_current_file() {
        let result = resolve_cpp_include_to_path("helper.hpp", None);

        assert!(result.is_none());
    }

    #[test]
    fn test_resolve_cpp_include_current_file_without_parent() {
        // The filesystem root has no parent directory to resolve against.
        let result = resolve_cpp_include_to_path("helper.hpp", Some("/"));

        assert!(result.is_none());
    }
}