probe_code/language/
rust.rs

1use super::language_trait::LanguageImpl;
2use tree_sitter::{Language as TSLanguage, Node};
3
/// Implementation of `LanguageImpl` for Rust.
///
/// This is a field-less unit struct: it carries no state and only serves as
/// the type the Rust-specific trait methods are attached to. Because it is
/// zero-sized, it is cheap to copy and safe to debug-print.
#[derive(Debug, Clone, Copy)]
pub struct RustLanguage;
6
7impl Default for RustLanguage {
8    fn default() -> Self {
9        Self::new()
10    }
11}
12
13impl RustLanguage {
14    pub fn new() -> Self {
15        RustLanguage
16    }
17}
18
19impl LanguageImpl for RustLanguage {
20    fn get_tree_sitter_language(&self) -> TSLanguage {
21        tree_sitter_rust::LANGUAGE.into()
22    }
23
24    fn get_extension(&self) -> &'static str {
25        "rs"
26    }
27
28    fn is_acceptable_parent(&self, node: &Node) -> bool {
29        let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
30
31        // Check for standard Rust items
32        if matches!(
33            node.kind(),
34            "function_item"
35                | "struct_item"
36                | "impl_item"
37                | "trait_item"
38                | "enum_item"
39                | "mod_item"
40                | "macro_definition"
41        ) {
42            return true;
43        }
44
45        // For expression_statement nodes, we need to find the parent function
46        if node.kind() == "expression_statement" {
47            if debug_mode {
48                println!(
49                    "DEBUG: Found expression_statement at lines {}-{}",
50                    node.start_position().row + 1,
51                    node.end_position().row + 1
52                );
53            }
54
55            // Instead of returning true directly, we'll look for the parent function
56            // and return that node in the parser.rs code
57            return false;
58        }
59
60        // Special handling for token trees inside macros
61        if node.kind() == "token_tree" {
62            // Check if this token tree is inside a macro invocation
63            if let Some(parent) = node.parent() {
64                if parent.kind() == "macro_invocation" {
65                    let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
66
67                    // For Rust property tests, we want to consider token trees inside macros
68                    // as acceptable parents, especially for proptest! macros
69                    if debug_mode {
70                        println!(
71                            "DEBUG: Found token_tree in macro_invocation at lines {}-{}",
72                            node.start_position().row + 1,
73                            node.end_position().row + 1
74                        );
75                    }
76
77                    // We previously tried to use the file path as a heuristic,
78                    // but we don't have access to the actual file path here
79
80                    // If the token tree is large enough (contains multiple lines of code),
81                    // it's likely a meaningful code block that should be extracted
82                    let node_size = node.end_position().row - node.start_position().row;
83                    if node_size > 5 {
84                        if debug_mode {
85                            println!(
86                                "DEBUG: Considering large token_tree in macro as acceptable parent (size: {node_size} lines)"
87                            );
88                        }
89                        return true;
90                    }
91                }
92            }
93        }
94
95        false
96    }
97
98    fn is_test_node(&self, node: &Node, source: &[u8]) -> bool {
99        let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
100        let node_type = node.kind();
101
102        // Rust: Check for #[test] attribute on function_item nodes
103        if node_type == "function_item" {
104            let mut cursor = node.walk();
105            let mut has_test_attribute = false;
106
107            // Look for attribute nodes
108            for child in node.children(&mut cursor) {
109                if child.kind() == "attribute_item" {
110                    let attr_text = child.utf8_text(source).unwrap_or("");
111                    if attr_text.contains("#[test") {
112                        has_test_attribute = true;
113                        break;
114                    }
115                }
116            }
117
118            if has_test_attribute {
119                if debug_mode {
120                    println!("DEBUG: Test node detected (Rust): #[test] attribute");
121                }
122                return true;
123            }
124
125            // Also check function name starting with "test_"
126            for child in node.children(&mut cursor) {
127                if child.kind() == "identifier" {
128                    let name = child.utf8_text(source).unwrap_or("");
129                    if name.starts_with("test_") {
130                        if debug_mode {
131                            println!("DEBUG: Test node detected (Rust): test_ function");
132                        }
133                        return true;
134                    }
135                }
136            }
137        }
138
139        false
140    }
141}