Skip to main content

forgekit_core/treesitter/
mod.rs

1//! Tree-sitter based CFG extraction for C, Java, and Rust
2//!
3//! This module provides real control flow graph extraction using tree-sitter parsers.
4//! Supports C, Java, and Rust languages with full CFG construction.
5
6mod c;
7mod cfg_builder;
8mod java;
9mod rust;
10
11use crate::cfg::TestCfg;
12use crate::error::Result;
13
14/// Extracted function information
15#[derive(Debug, Clone)]
16pub struct FunctionInfo {
17    pub name: String,
18    pub start_byte: usize,
19    pub end_byte: usize,
20    pub cfg: TestCfg,
21}
22
/// Source languages for which CFG extraction is available.
///
/// The enum is a plain, field-less tag: it is `Copy`, comparable, and hashable, so it
/// can be passed by value and used as a key in `HashMap`/`HashSet` (for example to
/// group files or cache parsers per language).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SupportedLanguage {
    /// The C language.
    C,
    /// The Java language.
    Java,
    /// The Rust language.
    Rust,
}
30
/// CFG extractor using tree-sitter.
///
/// This is a data-less unit type; the operations defined on it in this module are
/// associated functions that do not take `self`. The common traits are derived so the
/// type can be logged with `{:?}`, copied freely, and embedded in structs that derive
/// `Debug`/`Default` themselves.
#[derive(Debug, Clone, Copy, Default)]
pub struct CfgExtractor;
33
34impl CfgExtractor {
35    /// Detect language from file extension
36    pub fn detect_language(path: &std::path::Path) -> Option<SupportedLanguage> {
37        match path.extension()?.to_str()? {
38            "c" | "h" => Some(SupportedLanguage::C),
39            "java" => Some(SupportedLanguage::Java),
40            "rs" => Some(SupportedLanguage::Rust),
41            _ => None,
42        }
43    }
44
45    /// Extract CFG based on language
46    pub fn extract(source: &str, lang: SupportedLanguage) -> Result<Vec<FunctionInfo>> {
47        match lang {
48            SupportedLanguage::C => Self::extract_c(source),
49            SupportedLanguage::Java => Self::extract_java(source),
50            SupportedLanguage::Rust => Self::extract_rust(source),
51        }
52    }
53
54    fn node_text(source: &str, node: &tree_sitter::Node) -> String {
55        source[node.start_byte()..node.end_byte()].to_string()
56    }
57}
58
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::BlockId;

    /// Every recognised extension maps to its language; anything else yields `None`.
    #[test]
    fn test_language_detection() {
        use std::path::Path;

        let recognised = [
            ("test.c", SupportedLanguage::C),
            ("test.h", SupportedLanguage::C),
            ("Test.java", SupportedLanguage::Java),
            ("test.rs", SupportedLanguage::Rust),
        ];
        for &(file, expected) in recognised.iter() {
            assert_eq!(
                CfgExtractor::detect_language(Path::new(file)),
                Some(expected),
                "unexpected language for {}",
                file
            );
        }

        // An unknown extension and a path without any extension must both be rejected.
        let unrecognised = ["notes.txt", "Makefile"];
        for &file in unrecognised.iter() {
            assert_eq!(
                CfgExtractor::detect_language(Path::new(file)),
                None,
                "{} should not map to a supported language",
                file
            );
        }
    }

    /// The public `extract` entry point routes to the language-specific extractor.
    #[test]
    fn test_extract_dispatches_by_language() {
        let source = r#"
            int add(int a, int b) {
                return a + b;
            }
        "#;

        let funcs = CfgExtractor::extract(source, SupportedLanguage::C)
            .expect("invariant: valid C source parses");
        assert_eq!(funcs.len(), 1);
        assert_eq!(funcs[0].name, "add");
    }

    /// A single straight-line C function is found and named correctly.
    #[test]
    fn test_extract_c_simple_function() {
        let source = r#"
            int add(int a, int b) {
                return a + b;
            }
        "#;

        let funcs = CfgExtractor::extract_c(source).expect("invariant: valid C source parses");
        assert_eq!(funcs.len(), 1);
        assert_eq!(funcs[0].name, "add");
    }

    /// A C `if`/`else` produces a CFG with branching structure.
    #[test]
    fn test_extract_c_with_if() {
        let source = r#"
            int max(int a, int b) {
                if (a > b) {
                    return a;
                } else {
                    return b;
                }
            }
        "#;

        let funcs = CfgExtractor::extract_c(source).expect("invariant: valid C source parses");
        assert_eq!(funcs.len(), 1);

        let cfg = &funcs[0].cfg;
        // Should have entry, condition, then, else, merge, exit blocks; only a loose
        // lower bound is asserted so the test does not pin the exact block layout.
        assert!(cfg.successors.len() >= 2);
    }

    /// A method nested in a Java class is found and named correctly.
    #[test]
    fn test_extract_java_simple_method() {
        let source = r#"
            public class Test {
                public int add(int a, int b) {
                    return a + b;
                }
            }
        "#;

        let funcs =
            CfgExtractor::extract_java(source).expect("invariant: valid Java source parses");
        assert_eq!(funcs.len(), 1);
        assert_eq!(funcs[0].name, "add");
    }

    /// A Java `for` loop shows up as a loop in the extracted CFG.
    #[test]
    fn test_extract_java_with_loop() {
        let source = r#"
            public class Test {
                public int sum(int n) {
                    int total = 0;
                    for (int i = 0; i < n; i++) {
                        total += i;
                    }
                    return total;
                }
            }
        "#;

        let funcs =
            CfgExtractor::extract_java(source).expect("invariant: valid Java source parses");
        assert_eq!(funcs.len(), 1);

        // Check that loop was detected
        let cfg = &funcs[0].cfg;
        let loops = cfg.detect_loops();
        assert!(!loops.is_empty(), "Should detect at least one loop");
    }

    /// A Rust function with a tail expression is found and named correctly.
    #[test]
    fn test_extract_rust_simple_function() {
        let source = r#"
            fn add(a: i32, b: i32) -> i32 {
                a + b
            }
        "#;

        let funcs =
            CfgExtractor::extract_rust(source).expect("invariant: valid Rust source parses");
        assert_eq!(funcs.len(), 1);
        assert_eq!(funcs[0].name, "add");
    }

    /// A Rust `if` expression is extracted; only the entry block is checked for now.
    #[test]
    fn test_extract_rust_if_expression() {
        let source = r#"
            fn max(a: i32, b: i32) -> i32 {
                if a > b {
                    a
                } else {
                    b
                }
            }
        "#;

        let funcs =
            CfgExtractor::extract_rust(source).expect("invariant: valid Rust source parses");
        assert_eq!(funcs.len(), 1);
        assert_eq!(funcs[0].name, "max");

        // CFG extraction for Rust if expressions works but needs refinement
        let cfg = &funcs[0].cfg;
        assert!(cfg.entry == BlockId(0));
    }

    /// A Rust `loop` with `break` is extracted; only the entry block is checked for now.
    #[test]
    fn test_extract_rust_loop() {
        let source = r#"
            fn countdown(mut n: i32) -> i32 {
                loop {
                    if n <= 0 {
                        break;
                    }
                    n -= 1;
                }
                n
            }
        "#;

        let funcs =
            CfgExtractor::extract_rust(source).expect("invariant: valid Rust source parses");
        assert_eq!(funcs.len(), 1);
        assert_eq!(funcs[0].name, "countdown");

        // Loop detection for Rust is a work in progress
        let cfg = &funcs[0].cfg;
        assert!(cfg.entry == BlockId(0));
    }

    /// A Rust `for` loop is extracted; only the entry block is checked for now.
    #[test]
    fn test_extract_rust_for_loop() {
        let source = r#"
            fn sum(n: i32) -> i32 {
                let mut total = 0;
                for i in 0..n {
                    total += i;
                }
                total
            }
        "#;

        let funcs =
            CfgExtractor::extract_rust(source).expect("invariant: valid Rust source parses");
        assert_eq!(funcs.len(), 1);
        assert_eq!(funcs[0].name, "sum");

        // For loop detection for Rust is a work in progress
        let cfg = &funcs[0].cfg;
        assert!(cfg.entry == BlockId(0));
    }

    /// A Rust `match` expression does not prevent the function from being extracted.
    #[test]
    fn test_extract_rust_match_expression() {
        let source = r#"
            fn classify(n: i32) -> &'static str {
                match n {
                    0 => "zero",
                    1..=9 => "single digit",
                    _ => "other",
                }
            }
        "#;

        let funcs =
            CfgExtractor::extract_rust(source).expect("invariant: valid Rust source parses");
        assert_eq!(funcs.len(), 1);
        assert_eq!(funcs[0].name, "classify");
    }
}