forgekit_core/treesitter/
mod.rs1mod c;
7mod cfg_builder;
8mod java;
9mod rust;
10
11use crate::cfg::TestCfg;
12use crate::error::Result;
13
14#[derive(Debug, Clone)]
16pub struct FunctionInfo {
17 pub name: String,
18 pub start_byte: usize,
19 pub end_byte: usize,
20 pub cfg: TestCfg,
21}
22
/// Source languages for which CFG extraction is available.
///
/// The enum is `Copy` and comparable, so it can be passed by value and
/// matched or compared freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SupportedLanguage {
    /// C sources; chosen by `CfgExtractor::detect_language` for `.c` and `.h`.
    C,
    /// Java sources; chosen for the `.java` extension.
    Java,
    /// Rust sources; chosen for the `.rs` extension.
    Rust,
}
30
/// Field-less namespace type for language detection and CFG extraction.
///
/// Everything is exposed as associated functions (`CfgExtractor::extract`,
/// `CfgExtractor::detect_language`); no instance is ever needed.
pub struct CfgExtractor;
33
34impl CfgExtractor {
35 pub fn detect_language(path: &std::path::Path) -> Option<SupportedLanguage> {
37 match path.extension()?.to_str()? {
38 "c" | "h" => Some(SupportedLanguage::C),
39 "java" => Some(SupportedLanguage::Java),
40 "rs" => Some(SupportedLanguage::Rust),
41 _ => None,
42 }
43 }
44
45 pub fn extract(source: &str, lang: SupportedLanguage) -> Result<Vec<FunctionInfo>> {
47 match lang {
48 SupportedLanguage::C => Self::extract_c(source),
49 SupportedLanguage::Java => Self::extract_java(source),
50 SupportedLanguage::Rust => Self::extract_rust(source),
51 }
52 }
53
54 fn node_text(source: &str, node: &tree_sitter::Node) -> String {
55 source[node.start_byte()..node.end_byte()].to_string()
56 }
57}
58
59#[cfg(test)]
60mod tests {
61 use super::*;
62 use crate::types::BlockId;
63
64 #[test]
65 fn test_language_detection() {
66 use std::path::Path;
67
68 assert_eq!(
69 CfgExtractor::detect_language(Path::new("test.c")),
70 Some(SupportedLanguage::C)
71 );
72 assert_eq!(
73 CfgExtractor::detect_language(Path::new("test.h")),
74 Some(SupportedLanguage::C)
75 );
76 assert_eq!(
77 CfgExtractor::detect_language(Path::new("Test.java")),
78 Some(SupportedLanguage::Java)
79 );
80 assert_eq!(
81 CfgExtractor::detect_language(Path::new("test.rs")),
82 Some(SupportedLanguage::Rust)
83 );
84 }
85
86 #[test]
87 fn test_extract_c_simple_function() {
88 let source = r#"
89 int add(int a, int b) {
90 return a + b;
91 }
92 "#;
93
94 let funcs = CfgExtractor::extract_c(source).expect("invariant: valid C source parses");
95 assert_eq!(funcs.len(), 1);
96 assert_eq!(funcs[0].name, "add");
97 }
98
99 #[test]
100 fn test_extract_c_with_if() {
101 let source = r#"
102 int max(int a, int b) {
103 if (a > b) {
104 return a;
105 } else {
106 return b;
107 }
108 }
109 "#;
110
111 let funcs = CfgExtractor::extract_c(source).expect("invariant: valid C source parses");
112 assert_eq!(funcs.len(), 1);
113
114 let cfg = &funcs[0].cfg;
115 assert!(cfg.successors.len() >= 2);
117 }
118
119 #[test]
120 fn test_extract_java_simple_method() {
121 let source = r#"
122 public class Test {
123 public int add(int a, int b) {
124 return a + b;
125 }
126 }
127 "#;
128
129 let funcs =
130 CfgExtractor::extract_java(source).expect("invariant: valid Java source parses");
131 assert_eq!(funcs.len(), 1);
132 assert_eq!(funcs[0].name, "add");
133 }
134
135 #[test]
136 fn test_extract_java_with_loop() {
137 let source = r#"
138 public class Test {
139 public int sum(int n) {
140 int total = 0;
141 for (int i = 0; i < n; i++) {
142 total += i;
143 }
144 return total;
145 }
146 }
147 "#;
148
149 let funcs =
150 CfgExtractor::extract_java(source).expect("invariant: valid Java source parses");
151 assert_eq!(funcs.len(), 1);
152
153 let cfg = &funcs[0].cfg;
155 let loops = cfg.detect_loops();
156 assert!(!loops.is_empty(), "Should detect at least one loop");
157 }
158
159 #[test]
160 fn test_extract_rust_simple_function() {
161 let source = r#"
162 fn add(a: i32, b: i32) -> i32 {
163 a + b
164 }
165 "#;
166
167 let funcs =
168 CfgExtractor::extract_rust(source).expect("invariant: valid Rust source parses");
169 assert_eq!(funcs.len(), 1);
170 assert_eq!(funcs[0].name, "add");
171 }
172
173 #[test]
174 fn test_extract_rust_if_expression() {
175 let source = r#"
176 fn max(a: i32, b: i32) -> i32 {
177 if a > b {
178 a
179 } else {
180 b
181 }
182 }
183 "#;
184
185 let funcs =
186 CfgExtractor::extract_rust(source).expect("invariant: valid Rust source parses");
187 assert_eq!(funcs.len(), 1);
188 assert_eq!(funcs[0].name, "max");
189
190 let cfg = &funcs[0].cfg;
192 assert!(cfg.entry == BlockId(0));
193 }
194
195 #[test]
196 fn test_extract_rust_loop() {
197 let source = r#"
198 fn countdown(mut n: i32) -> i32 {
199 loop {
200 if n <= 0 {
201 break;
202 }
203 n -= 1;
204 }
205 n
206 }
207 "#;
208
209 let funcs =
210 CfgExtractor::extract_rust(source).expect("invariant: valid Rust source parses");
211 assert_eq!(funcs.len(), 1);
212 assert_eq!(funcs[0].name, "countdown");
213
214 let cfg = &funcs[0].cfg;
216 assert!(cfg.entry == BlockId(0));
217 }
218
219 #[test]
220 fn test_extract_rust_for_loop() {
221 let source = r#"
222 fn sum(n: i32) -> i32 {
223 let mut total = 0;
224 for i in 0..n {
225 total += i;
226 }
227 total
228 }
229 "#;
230
231 let funcs =
232 CfgExtractor::extract_rust(source).expect("invariant: valid Rust source parses");
233 assert_eq!(funcs.len(), 1);
234 assert_eq!(funcs[0].name, "sum");
235
236 let cfg = &funcs[0].cfg;
238 assert!(cfg.entry == BlockId(0));
239 }
240
241 #[test]
242 fn test_extract_rust_match_expression() {
243 let source = r#"
244 fn classify(n: i32) -> &'static str {
245 match n {
246 0 => "zero",
247 1..=9 => "single digit",
248 _ => "other",
249 }
250 }
251 "#;
252
253 let funcs =
254 CfgExtractor::extract_rust(source).expect("invariant: valid Rust source parses");
255 assert_eq!(funcs.len(), 1);
256 assert_eq!(funcs[0].name, "classify");
257 }
258}