Skip to main content

forgekit_core/treesitter/
rust.rs

1use crate::cfg::TestCfg;
2use crate::error::{ForgeError, Result};
3use crate::types::BlockId;
4
5use super::{CfgExtractor, FunctionInfo, SupportedLanguage};
6
7impl CfgExtractor {
8    /// Extract CFG from Rust source code
9    pub fn extract_rust(source: &str) -> Result<Vec<FunctionInfo>> {
10        use tree_sitter::Parser;
11        use tree_sitter_rust;
12
13        let mut parser = Parser::new();
14        parser
15            .set_language(&tree_sitter_rust::LANGUAGE.into())
16            .map_err(|e| {
17                ForgeError::DatabaseError(format!("Failed to set Rust language: {:?}", e))
18            })?;
19
20        let tree = parser
21            .parse(source, None)
22            .ok_or_else(|| ForgeError::DatabaseError("Failed to parse Rust code".to_string()))?;
23
24        let root = tree.root_node();
25        let mut functions = Vec::new();
26
27        Self::extract_rust_functions(source, &root, &mut functions)?;
28
29        Ok(functions)
30    }
31
32    fn extract_rust_functions(
33        source: &str,
34        node: &tree_sitter::Node,
35        functions: &mut Vec<FunctionInfo>,
36    ) -> Result<()> {
37        let kind = node.kind();
38
39        // Look for function and method definitions
40        if kind == "function_item" || kind == "method_declaration" {
41            if let Some(func) = Self::parse_rust_function(source, node)? {
42                functions.push(func);
43            }
44        }
45
46        // Recurse into children
47        let mut cursor = node.walk();
48        for child in node.children(&mut cursor) {
49            Self::extract_rust_functions(source, &child, functions)?;
50        }
51
52        Ok(())
53    }
54
55    fn parse_rust_function(source: &str, node: &tree_sitter::Node) -> Result<Option<FunctionInfo>> {
56        let start_byte = node.start_byte();
57        let end_byte = node.end_byte();
58
59        // Find function name - look for identifier after fn keyword
60        let mut name = "unknown".to_string();
61        let mut found_fn = false;
62        let mut cursor = node.walk();
63
64        for child in node.children(&mut cursor) {
65            if child.kind() == "fn" {
66                found_fn = true;
67                continue;
68            }
69            if found_fn && child.kind() == "identifier" {
70                name = Self::node_text(source, &child);
71                break;
72            }
73        }
74
75        // Find function body (block)
76        let mut body = None;
77        let mut cursor = node.walk();
78        for child in node.children(&mut cursor) {
79            if child.kind() == "block" {
80                body = Some(child);
81                break;
82            }
83        }
84
85        let cfg = if let Some(body) = body {
86            Self::build_cfg_from_body(source, &body, SupportedLanguage::Rust)?
87        } else {
88            // Function without body (trait method)
89            TestCfg::new(BlockId(0))
90        };
91
92        Ok(Some(FunctionInfo {
93            name,
94            start_byte,
95            end_byte,
96            cfg,
97        }))
98    }
99}