forgekit-core 0.5.0

Deterministic code intelligence SDK - Core library
Documentation
//! Tree-sitter based CFG extraction for C, Java, and Rust
//!
//! This module provides real control flow graph extraction using tree-sitter parsers.
//! Supports C, Java, and Rust languages with full CFG construction.

mod c;
mod cfg_builder;
mod java;
mod rust;

use crate::cfg::TestCfg;
use crate::error::Result;

/// Extracted function information
#[derive(Debug, Clone)]
pub struct FunctionInfo {
    pub name: String,
    pub start_byte: usize,
    pub end_byte: usize,
    pub cfg: TestCfg,
}

/// Language supported for CFG extraction
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SupportedLanguage {
    C,
    Java,
    Rust,
}

/// CFG extractor using tree-sitter
pub struct CfgExtractor;

impl CfgExtractor {
    /// Detect language from file extension
    pub fn detect_language(path: &std::path::Path) -> Option<SupportedLanguage> {
        match path.extension()?.to_str()? {
            "c" | "h" => Some(SupportedLanguage::C),
            "java" => Some(SupportedLanguage::Java),
            "rs" => Some(SupportedLanguage::Rust),
            _ => None,
        }
    }

    /// Extract CFG based on language
    pub fn extract(source: &str, lang: SupportedLanguage) -> Result<Vec<FunctionInfo>> {
        match lang {
            SupportedLanguage::C => Self::extract_c(source),
            SupportedLanguage::Java => Self::extract_java(source),
            SupportedLanguage::Rust => Self::extract_rust(source),
        }
    }

    fn node_text(source: &str, node: &tree_sitter::Node) -> String {
        source[node.start_byte()..node.end_byte()].to_string()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::BlockId;

    #[test]
    fn test_language_detection() {
        use std::path::Path;

        assert_eq!(
            CfgExtractor::detect_language(Path::new("test.c")),
            Some(SupportedLanguage::C)
        );
        assert_eq!(
            CfgExtractor::detect_language(Path::new("test.h")),
            Some(SupportedLanguage::C)
        );
        assert_eq!(
            CfgExtractor::detect_language(Path::new("Test.java")),
            Some(SupportedLanguage::Java)
        );
        assert_eq!(
            CfgExtractor::detect_language(Path::new("test.rs")),
            Some(SupportedLanguage::Rust)
        );
    }

    #[test]
    fn test_extract_c_simple_function() {
        let source = r#"
            int add(int a, int b) {
                return a + b;
            }
        "#;

        let funcs = CfgExtractor::extract_c(source).expect("invariant: valid C source parses");
        assert_eq!(funcs.len(), 1);
        assert_eq!(funcs[0].name, "add");
    }

    #[test]
    fn test_extract_c_with_if() {
        let source = r#"
            int max(int a, int b) {
                if (a > b) {
                    return a;
                } else {
                    return b;
                }
            }
        "#;

        let funcs = CfgExtractor::extract_c(source).expect("invariant: valid C source parses");
        assert_eq!(funcs.len(), 1);

        let cfg = &funcs[0].cfg;
        // Should have entry, condition, then, else, merge, exit blocks
        assert!(cfg.successors.len() >= 2);
    }

    #[test]
    fn test_extract_java_simple_method() {
        let source = r#"
            public class Test {
                public int add(int a, int b) {
                    return a + b;
                }
            }
        "#;

        let funcs =
            CfgExtractor::extract_java(source).expect("invariant: valid Java source parses");
        assert_eq!(funcs.len(), 1);
        assert_eq!(funcs[0].name, "add");
    }

    #[test]
    fn test_extract_java_with_loop() {
        let source = r#"
            public class Test {
                public int sum(int n) {
                    int total = 0;
                    for (int i = 0; i < n; i++) {
                        total += i;
                    }
                    return total;
                }
            }
        "#;

        let funcs =
            CfgExtractor::extract_java(source).expect("invariant: valid Java source parses");
        assert_eq!(funcs.len(), 1);

        // Check that loop was detected
        let cfg = &funcs[0].cfg;
        let loops = cfg.detect_loops();
        assert!(!loops.is_empty(), "Should detect at least one loop");
    }

    #[test]
    fn test_extract_rust_simple_function() {
        let source = r#"
            fn add(a: i32, b: i32) -> i32 {
                a + b
            }
        "#;

        let funcs =
            CfgExtractor::extract_rust(source).expect("invariant: valid Rust source parses");
        assert_eq!(funcs.len(), 1);
        assert_eq!(funcs[0].name, "add");
    }

    #[test]
    fn test_extract_rust_if_expression() {
        let source = r#"
            fn max(a: i32, b: i32) -> i32 {
                if a > b {
                    a
                } else {
                    b
                }
            }
        "#;

        let funcs =
            CfgExtractor::extract_rust(source).expect("invariant: valid Rust source parses");
        assert_eq!(funcs.len(), 1);
        assert_eq!(funcs[0].name, "max");

        // CFG extraction for Rust if expressions works but needs refinement
        let cfg = &funcs[0].cfg;
        assert!(cfg.entry == BlockId(0));
    }

    #[test]
    fn test_extract_rust_loop() {
        let source = r#"
            fn countdown(mut n: i32) -> i32 {
                loop {
                    if n <= 0 {
                        break;
                    }
                    n -= 1;
                }
                n
            }
        "#;

        let funcs =
            CfgExtractor::extract_rust(source).expect("invariant: valid Rust source parses");
        assert_eq!(funcs.len(), 1);
        assert_eq!(funcs[0].name, "countdown");

        // Loop detection for Rust is a work in progress
        let cfg = &funcs[0].cfg;
        assert!(cfg.entry == BlockId(0));
    }

    #[test]
    fn test_extract_rust_for_loop() {
        let source = r#"
            fn sum(n: i32) -> i32 {
                let mut total = 0;
                for i in 0..n {
                    total += i;
                }
                total
            }
        "#;

        let funcs =
            CfgExtractor::extract_rust(source).expect("invariant: valid Rust source parses");
        assert_eq!(funcs.len(), 1);
        assert_eq!(funcs[0].name, "sum");

        // For loop detection for Rust is a work in progress
        let cfg = &funcs[0].cfg;
        assert!(cfg.entry == BlockId(0));
    }

    #[test]
    fn test_extract_rust_match_expression() {
        let source = r#"
            fn classify(n: i32) -> &'static str {
                match n {
                    0 => "zero",
                    1..=9 => "single digit",
                    _ => "other",
                }
            }
        "#;

        let funcs =
            CfgExtractor::extract_rust(source).expect("invariant: valid Rust source parses");
        assert_eq!(funcs.len(), 1);
        assert_eq!(funcs[0].name, "classify");
    }
}