mod c;
mod cfg_builder;
mod java;
mod rust;
use crate::cfg::TestCfg;
use crate::error::Result;
/// One function found by an extractor, together with its control-flow graph.
#[derive(Clone, Debug)]
pub struct FunctionInfo {
    /// Identifier of the function (e.g. `"add"` for `int add(int a, int b)`).
    pub name: String,
    /// NOTE(review): presumably the byte offset in the source text where the
    /// function starts; it is filled in by extractors outside this view.
    pub start_byte: usize,
    /// NOTE(review): presumably the byte offset where the function ends;
    /// whether it is inclusive or exclusive is not visible here.
    pub end_byte: usize,
    /// Control-flow graph associated with this function.
    pub cfg: TestCfg,
}
/// Source languages that `CfgExtractor` can dispatch on.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SupportedLanguage {
    /// C sources.
    C,
    /// Java sources.
    Java,
    /// Rust sources.
    Rust,
}
/// Stateless entry point for control-flow-graph extraction.
///
/// The type has no fields; everything is exposed as associated functions.
pub struct CfgExtractor;

impl CfgExtractor {
    /// Picks a [`SupportedLanguage`] from the extension of `path`.
    ///
    /// The comparison is case-sensitive: `c` and `h` map to C, `java` to
    /// Java and `rs` to Rust.
    ///
    /// Returns `None` when the path has no extension, when the extension is
    /// not valid UTF-8, or when it is not one of the extensions listed above.
    pub fn detect_language(path: &std::path::Path) -> Option<SupportedLanguage> {
        let ext = path.extension().and_then(|raw| raw.to_str())?;
        let lang = match ext {
            "c" | "h" => SupportedLanguage::C,
            "java" => SupportedLanguage::Java,
            "rs" => SupportedLanguage::Rust,
            _ => return None,
        };
        Some(lang)
    }

    /// Extracts every function of `source`, using the extractor for `lang`.
    ///
    /// The work is delegated to `extract_c`, `extract_java` or
    /// `extract_rust`, which are defined outside this block
    /// (NOTE(review): presumably in the sibling `c`, `java` and `rust`
    /// modules). Their result is returned unchanged.
    pub fn extract(source: &str, lang: SupportedLanguage) -> Result<Vec<FunctionInfo>> {
        match lang {
            SupportedLanguage::Rust => Self::extract_rust(source),
            SupportedLanguage::Java => Self::extract_java(source),
            SupportedLanguage::C => Self::extract_c(source),
        }
    }

    /// Returns an owned copy of the part of `source` covered by `node`.
    ///
    /// # Panics
    ///
    /// Panics if the node's byte range lies outside `source` or does not
    /// fall on UTF-8 character boundaries.
    fn node_text(source: &str, node: &tree_sitter::Node) -> String {
        let (begin, end) = (node.start_byte(), node.end_byte());
        source[begin..end].to_owned()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::BlockId;

    // NOTE: the raw string fixtures below are intentionally flush-left so the
    // bytes handed to the extractors stay exactly as they were.

    /// Runs the C extractor over `src`, panicking if it reports an error.
    fn parse_c(src: &str) -> Vec<FunctionInfo> {
        CfgExtractor::extract_c(src).expect("invariant: valid C source parses")
    }

    /// Runs the Java extractor over `src`, panicking if it reports an error.
    fn parse_java(src: &str) -> Vec<FunctionInfo> {
        CfgExtractor::extract_java(src).expect("invariant: valid Java source parses")
    }

    /// Runs the Rust extractor over `src`, panicking if it reports an error.
    fn parse_rust(src: &str) -> Vec<FunctionInfo> {
        CfgExtractor::extract_rust(src).expect("invariant: valid Rust source parses")
    }

    /// Every recognised extension maps to the matching language variant.
    #[test]
    fn test_language_detection() {
        use std::path::Path;

        let cases = [
            ("test.c", SupportedLanguage::C),
            ("test.h", SupportedLanguage::C),
            ("Test.java", SupportedLanguage::Java),
            ("test.rs", SupportedLanguage::Rust),
        ];
        for (file_name, expected) in cases {
            assert_eq!(
                CfgExtractor::detect_language(Path::new(file_name)),
                Some(expected)
            );
        }
    }

    /// A single C function is reported once, under its own name.
    #[test]
    fn test_extract_c_simple_function() {
        let src = r#"
int add(int a, int b) {
return a + b;
}
"#;
        let functions = parse_c(src);
        assert_eq!(functions.len(), 1);
        assert_eq!(functions[0].name, "add");
    }

    /// A C `if`/`else` yields a graph whose `successors` has at least two entries.
    #[test]
    fn test_extract_c_with_if() {
        let src = r#"
int max(int a, int b) {
if (a > b) {
return a;
} else {
return b;
}
}
"#;
        let functions = parse_c(src);
        assert_eq!(functions.len(), 1);
        let graph = &functions[0].cfg;
        assert!(graph.successors.len() >= 2);
    }

    /// A single Java method is reported once, under its own name.
    #[test]
    fn test_extract_java_simple_method() {
        let src = r#"
public class Test {
public int add(int a, int b) {
return a + b;
}
}
"#;
        let functions = parse_java(src);
        assert_eq!(functions.len(), 1);
        assert_eq!(functions[0].name, "add");
    }

    /// A Java `for` loop shows up in the graph's loop detection.
    #[test]
    fn test_extract_java_with_loop() {
        let src = r#"
public class Test {
public int sum(int n) {
int total = 0;
for (int i = 0; i < n; i++) {
total += i;
}
return total;
}
}
"#;
        let functions = parse_java(src);
        assert_eq!(functions.len(), 1);
        let graph = &functions[0].cfg;
        let detected = graph.detect_loops();
        assert!(!detected.is_empty(), "Should detect at least one loop");
    }

    /// A single Rust function is reported once, under its own name.
    #[test]
    fn test_extract_rust_simple_function() {
        let src = r#"
fn add(a: i32, b: i32) -> i32 {
a + b
}
"#;
        let functions = parse_rust(src);
        assert_eq!(functions.len(), 1);
        assert_eq!(functions[0].name, "add");
    }

    /// A Rust `if` expression: one function, graph entry is block 0.
    #[test]
    fn test_extract_rust_if_expression() {
        let src = r#"
fn max(a: i32, b: i32) -> i32 {
if a > b {
a
} else {
b
}
}
"#;
        let functions = parse_rust(src);
        assert_eq!(functions.len(), 1);
        assert_eq!(functions[0].name, "max");
        let graph = &functions[0].cfg;
        assert!(graph.entry == BlockId(0));
    }

    /// A Rust `loop` with `break`: one function, graph entry is block 0.
    #[test]
    fn test_extract_rust_loop() {
        let src = r#"
fn countdown(mut n: i32) -> i32 {
loop {
if n <= 0 {
break;
}
n -= 1;
}
n
}
"#;
        let functions = parse_rust(src);
        assert_eq!(functions.len(), 1);
        assert_eq!(functions[0].name, "countdown");
        let graph = &functions[0].cfg;
        assert!(graph.entry == BlockId(0));
    }

    /// A Rust `for` loop: one function, graph entry is block 0.
    #[test]
    fn test_extract_rust_for_loop() {
        let src = r#"
fn sum(n: i32) -> i32 {
let mut total = 0;
for i in 0..n {
total += i;
}
total
}
"#;
        let functions = parse_rust(src);
        assert_eq!(functions.len(), 1);
        assert_eq!(functions[0].name, "sum");
        let graph = &functions[0].cfg;
        assert!(graph.entry == BlockId(0));
    }

    /// A Rust `match` expression still yields exactly one named function.
    #[test]
    fn test_extract_rust_match_expression() {
        let src = r#"
fn classify(n: i32) -> &'static str {
match n {
0 => "zero",
1..=9 => "single digit",
_ => "other",
}
}
"#;
        let functions = parse_rust(src);
        assert_eq!(functions.len(), 1);
        assert_eq!(functions[0].name, "classify");
    }
}