use crate::error::{RaxitError, Result};
use std::fs;
use std::path::Path;
use tree_sitter::{Parser, Tree};
use super::AstNode;
pub fn parse_file(path: &Path) -> Result<Vec<AstNode>> {
let source_code = fs::read_to_string(path).map_err(|e| RaxitError::Parse {
file: path.to_path_buf(),
message: e.to_string(),
})?;
parse_source(&source_code, path)
}
pub fn parse_source(source_code: &str, path: &Path) -> Result<Vec<AstNode>> {
let mut parser = Parser::new();
let language = tree_sitter_python::LANGUAGE;
parser
.set_language(&language.into())
.map_err(|e| RaxitError::Parse {
file: path.to_path_buf(),
message: format!("Failed to set language: {e}"),
})?;
let tree = parser
.parse(source_code, None)
.ok_or_else(|| RaxitError::Parse {
file: path.to_path_buf(),
message: "Failed to parse source code".to_string(),
})?;
extract_nodes(&tree, source_code)
}
/// Collects the [`AstNode`]s found in a parsed syntax tree.
///
/// Starts at the root of `tree` and lets `visit_node` append every node it
/// records to the returned vector. `source_code` is forwarded so node text can
/// be sliced out of it; it should be the text `tree` was parsed from.
///
/// This function currently never produces an `Err`; the `Result` return type
/// is kept so callers can propagate it uniformly with `?`.
fn extract_nodes(tree: &Tree, source_code: &str) -> Result<Vec<AstNode>> {
    let mut nodes = Vec::new();
    visit_node(&tree.root_node(), source_code, &mut nodes);
    Ok(nodes)
}
/// Walks the subtree rooted at `node` in pre-order (a node before its
/// children, children left to right) and appends an [`AstNode`] to `nodes` for
/// every node whose kind is `function_definition`, `class_definition`,
/// `decorated_definition`, or `call`.
///
/// Each recorded entry holds the node kind, its 1-based start and end lines,
/// and the node's text sliced from `source_code`. If the text cannot be
/// decoded as UTF-8 the entry is still recorded, with an empty `text`.
///
/// The walk is iterative (driven by a `TreeCursor`) rather than recursive, so
/// its call-stack usage does not grow with the nesting depth of the tree.
fn visit_node(node: &tree_sitter::Node, source_code: &str, nodes: &mut Vec<AstNode>) {
    let source_bytes = source_code.as_bytes();
    // A cursor created from `node` treats it as its root: it cannot move to
    // `node`'s siblings or parent, so the walk stays inside this subtree.
    let mut cursor = node.walk();

    loop {
        let current = cursor.node();
        let kind = current.kind();
        if matches!(
            kind,
            "function_definition" | "class_definition" | "decorated_definition" | "call"
        ) {
            nodes.push(AstNode {
                kind: kind.to_string(),
                start_line: one_based_line(current.start_position().row),
                end_line: one_based_line(current.end_position().row),
                // Best effort: an undecodable span yields empty text instead of
                // aborting the whole traversal.
                text: current.utf8_text(source_bytes).unwrap_or("").to_string(),
            });
        }

        // Descend first so children are visited before later siblings.
        if cursor.goto_first_child() {
            continue;
        }

        // Leaf reached: move to the next sibling, climbing towards the root
        // until one exists. Failing to climb means the subtree is exhausted.
        while !cursor.goto_next_sibling() {
            if !cursor.goto_parent() {
                return;
            }
        }
    }
}

/// Converts a 0-based row into a 1-based line number, saturating at
/// `u32::MAX` instead of wrapping when the value does not fit in a `u32`.
fn one_based_line(row: usize) -> u32 {
    // The cast cannot truncate: the value is clamped to `u32::MAX` first.
    row.saturating_add(1).min(u32::MAX as usize) as u32
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` under the placeholder path `test.py`, asserting that
    /// parsing succeeded, and returns the extracted nodes.
    fn parse_ok(source: &str) -> Vec<AstNode> {
        let outcome = parse_source(source, Path::new("test.py"));
        assert!(outcome.is_ok());
        outcome.unwrap()
    }

    /// Counts how many of `nodes` have exactly the given `kind`.
    fn count_kind(nodes: &[AstNode], kind: &str) -> usize {
        nodes.iter().filter(|n| n.kind == kind).count()
    }

    /// A single `def` should produce exactly one `function_definition` node.
    #[test]
    fn test_parse_simple_function() {
        let source = r#"
def hello():
print("Hello, world!")
"#;
        let extracted = parse_ok(source);
        assert!(!extracted.is_empty());
        assert_eq!(count_kind(&extracted, "function_definition"), 1);
    }

    /// A single `class` should produce exactly one `class_definition` node.
    #[test]
    fn test_parse_class() {
        let source = r#"
class MyClass:
def __init__(self):
pass
"#;
        let extracted = parse_ok(source);
        assert_eq!(count_kind(&extracted, "class_definition"), 1);
    }
}