raxit-core 0.1.2

Core security scanning engine for AI agent applications
Documentation
//! Python AST parsing with tree-sitter

use crate::error::{RaxitError, Result};
use std::fs;
use std::path::Path;
use tree_sitter::{Parser, Tree};

use super::AstNode;

/// Parse a Python file and extract AST nodes
pub fn parse_file(path: &Path) -> Result<Vec<AstNode>> {
    // Read file content
    let source_code = fs::read_to_string(path).map_err(|e| RaxitError::Parse {
        file: path.to_path_buf(),
        message: e.to_string(),
    })?;

    parse_source(&source_code, path)
}

/// Parse Python source code
pub fn parse_source(source_code: &str, path: &Path) -> Result<Vec<AstNode>> {
    // Create parser
    let mut parser = Parser::new();
    let language = tree_sitter_python::LANGUAGE;

    parser
        .set_language(&language.into())
        .map_err(|e| RaxitError::Parse {
            file: path.to_path_buf(),
            message: format!("Failed to set language: {e}"),
        })?;

    // Parse the source code
    let tree = parser
        .parse(source_code, None)
        .ok_or_else(|| RaxitError::Parse {
            file: path.to_path_buf(),
            message: "Failed to parse source code".to_string(),
        })?;

    // Extract nodes
    extract_nodes(&tree, source_code)
}

/// Extract the AST nodes of interest from a parsed tree.
///
/// Starts the traversal at the tree's root node and collects whatever
/// `visit_node` selects, in the order it visits them.
///
/// # Errors
///
/// Currently never returns an error; the `Result` return type is kept so callers
/// (and a future fallible extraction) are unaffected.
fn extract_nodes(tree: &Tree, source_code: &str) -> Result<Vec<AstNode>> {
    let mut nodes = Vec::new();
    visit_node(&tree.root_node(), source_code, &mut nodes);
    Ok(nodes)
}

/// Walk the subtree rooted at `node` in pre-order and append an [`AstNode`] for
/// every `function_definition`, `class_definition`, `decorated_definition`, and
/// `call` node found (including `node` itself).
///
/// Line numbers are 1-based. `text` is the node's source slice, or an empty
/// string if tree-sitter cannot return it as UTF-8.
///
/// The walk is iterative, driven by a single [`tree_sitter::TreeCursor`], rather
/// than recursive. Deeply nested input (e.g. a hostile file containing thousands
/// of nested brackets) therefore cannot overflow the thread's stack. Visit order
/// and the set of children visited are the same as recursing over
/// `Node::children`.
fn visit_node(node: &tree_sitter::Node, source_code: &str, nodes: &mut Vec<AstNode>) {
    let source_bytes = source_code.as_bytes();
    let mut cursor = node.walk();

    loop {
        let current = cursor.node();
        let kind = current.kind();

        if matches!(
            kind,
            "function_definition" | "class_definition" | "decorated_definition" | "call"
        ) {
            nodes.push(AstNode {
                kind: kind.to_string(),
                start_line: to_line_number(current.start_position().row),
                end_line: to_line_number(current.end_position().row),
                text: current.utf8_text(source_bytes).unwrap_or("").to_string(),
            });
        }

        // Pre-order: descend into the first child before anything else.
        if cursor.goto_first_child() {
            continue;
        }

        // No children: advance to the next sibling, climbing back up as needed.
        // The cursor was created from `node`, which is its root, so `goto_parent`
        // fails once we are back at `node`. The walk then ends without leaving
        // the subtree.
        loop {
            if cursor.goto_next_sibling() {
                break;
            }
            if !cursor.goto_parent() {
                return;
            }
        }
    }
}

/// Convert a 0-based tree-sitter row into a 1-based line number.
///
/// Saturates at `u32::MAX` instead of silently truncating (`as u32`) or
/// overflowing on `+ 1`.
fn to_line_number(row: usize) -> u32 {
    u32::try_from(row).map_or(u32::MAX, |r| r.saturating_add(1))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `source` as though it came from `test.py`, failing the test on error.
    fn parse(source: &str) -> Vec<AstNode> {
        parse_source(source, Path::new("test.py")).expect("parsing should succeed")
    }

    /// Count the extracted nodes whose kind is exactly `kind`.
    fn count_kind(nodes: &[AstNode], kind: &str) -> usize {
        nodes.iter().filter(|n| n.kind == kind).count()
    }

    #[test]
    fn test_parse_simple_function() {
        let source = r#"
def hello():
    print("Hello, world!")
"#;

        let nodes = parse(source);
        assert!(!nodes.is_empty());

        // One function definition whose body makes one call (to `print`).
        assert_eq!(count_kind(&nodes, "function_definition"), 1);
        assert_eq!(count_kind(&nodes, "call"), 1);
    }

    #[test]
    fn test_parse_class() {
        let source = r#"
class MyClass:
    def __init__(self):
        pass
"#;

        let nodes = parse(source);

        // Both the class and the method nested in its body must be found.
        assert_eq!(count_kind(&nodes, "class_definition"), 1);
        assert_eq!(count_kind(&nodes, "function_definition"), 1);
    }

    #[test]
    fn test_parse_decorated_function() {
        let source = "@staticmethod\ndef helper():\n    return 1\n";

        let nodes = parse(source);

        assert_eq!(count_kind(&nodes, "decorated_definition"), 1);
        assert_eq!(count_kind(&nodes, "function_definition"), 1);

        // Pre-order traversal: the decorated wrapper comes before the definition
        // it wraps, and line numbers are 1-based.
        assert_eq!(nodes[0].kind, "decorated_definition");
        assert_eq!(nodes[0].start_line, 1);
    }
}