probe_code/language/
rust.rs1use super::language_trait::LanguageImpl;
2use tree_sitter::{Language as TSLanguage, Node};
3
/// Stateless marker type implementing Rust language support.
///
/// It carries no data, so it is cheap to copy and compare; `Default` is
/// implemented by hand elsewhere and therefore not derived here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RustLanguage;
6
7impl Default for RustLanguage {
8 fn default() -> Self {
9 Self::new()
10 }
11}
12
13impl RustLanguage {
14 pub fn new() -> Self {
15 RustLanguage
16 }
17}
18
19impl LanguageImpl for RustLanguage {
20 fn get_tree_sitter_language(&self) -> TSLanguage {
21 tree_sitter_rust::LANGUAGE.into()
22 }
23
24 fn get_extension(&self) -> &'static str {
25 "rs"
26 }
27
28 fn is_acceptable_parent(&self, node: &Node) -> bool {
29 let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
30
31 if matches!(
33 node.kind(),
34 "function_item"
35 | "struct_item"
36 | "impl_item"
37 | "trait_item"
38 | "enum_item"
39 | "mod_item"
40 | "macro_definition"
41 ) {
42 return true;
43 }
44
45 if node.kind() == "expression_statement" {
47 if debug_mode {
48 println!(
49 "DEBUG: Found expression_statement at lines {}-{}",
50 node.start_position().row + 1,
51 node.end_position().row + 1
52 );
53 }
54
55 return false;
58 }
59
60 if node.kind() == "token_tree" {
62 if let Some(parent) = node.parent() {
64 if parent.kind() == "macro_invocation" {
65 let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
66
67 if debug_mode {
70 println!(
71 "DEBUG: Found token_tree in macro_invocation at lines {}-{}",
72 node.start_position().row + 1,
73 node.end_position().row + 1
74 );
75 }
76
77 let node_size = node.end_position().row - node.start_position().row;
83 if node_size > 5 {
84 if debug_mode {
85 println!(
86 "DEBUG: Considering large token_tree in macro as acceptable parent (size: {node_size} lines)"
87 );
88 }
89 return true;
90 }
91 }
92 }
93 }
94
95 false
96 }
97
98 fn is_test_node(&self, node: &Node, source: &[u8]) -> bool {
99 let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
100 let node_type = node.kind();
101
102 if node_type == "function_item" {
104 let mut cursor = node.walk();
105 let mut has_test_attribute = false;
106
107 for child in node.children(&mut cursor) {
109 if child.kind() == "attribute_item" {
110 let attr_text = child.utf8_text(source).unwrap_or("");
111 if attr_text.contains("#[test") {
112 has_test_attribute = true;
113 break;
114 }
115 }
116 }
117
118 if has_test_attribute {
119 if debug_mode {
120 println!("DEBUG: Test node detected (Rust): #[test] attribute");
121 }
122 return true;
123 }
124
125 for child in node.children(&mut cursor) {
127 if child.kind() == "identifier" {
128 let name = child.utf8_text(source).unwrap_or("");
129 if name.starts_with("test_") {
130 if debug_mode {
131 println!("DEBUG: Test node detected (Rust): test_ function");
132 }
133 return true;
134 }
135 }
136 }
137 }
138
139 false
140 }
141}