use crate::ir::*;
use crate::traits::{ReadError, Reader};
use tree_sitter::{Node, Parser, Tree};
/// Shared, stateless instance of the Python [`Reader`].
pub static PYTHON_READER: PythonReader = PythonReader;

/// [`Reader`] that turns Python source text into the IR; the actual work is
/// done by [`read_python`].
pub struct PythonReader;

impl Reader for PythonReader {
    fn language(&self) -> &'static str {
        "python"
    }

    fn extensions(&self) -> &'static [&'static str] {
        const PYTHON_EXTENSIONS: &[&str] = &["py"];
        PYTHON_EXTENSIONS
    }

    fn read(&self, source: &str) -> Result<Program, ReadError> {
        read_python(source)
    }
}
/// Parses `source` as Python and lowers the resulting tree into a [`Program`].
///
/// # Errors
/// Returns [`ReadError::Parse`] if the Python grammar cannot be loaded or
/// installed into the parser, if tree-sitter returns no tree, or if the tree
/// contains syntax errors; otherwise propagates errors from lowering.
pub fn read_python(source: &str) -> Result<Program, ReadError> {
    let grammar = normalize_languages::parsers::grammar_loader()
        .get("python")
        .map_err(|e| ReadError::Parse(format!("load python grammar: {e}")))?;

    let mut parser = Parser::new();
    if let Err(err) = parser.set_language(&grammar) {
        return Err(ReadError::Parse(err.to_string()));
    }

    let Some(tree) = parser.parse(source, None) else {
        return Err(ReadError::Parse("failed to parse".into()));
    };

    ReadContext::new(source).read_program(&tree)
}
struct ReadContext<'a> {
source: &'a str,
}
impl<'a> ReadContext<'a> {
fn new(source: &'a str) -> Self {
Self { source }
}
fn node_text(&self, node: Node) -> &str {
node.utf8_text(self.source.as_bytes()).unwrap_or("")
}
fn read_program(&self, tree: &Tree) -> Result<Program, ReadError> {
let root = tree.root_node();
if root.has_error() {
return Err(ReadError::Parse("syntax error in source".into()));
}
let mut statements = Vec::new();
let mut cursor = root.walk();
for child in root.children(&mut cursor) {
if child.is_named()
&& let Some(stmt) = self.read_stmt(child)?
{
statements.push(stmt);
}
}
Ok(Program::new(statements))
}
fn read_stmt(&self, node: Node) -> Result<Option<Stmt>, ReadError> {
match node.kind() {
"comment" => Ok(None),
"assignment" => self.read_assignment(node).map(Some),
"augmented_assignment" => self.read_augmented_assignment(node).map(Some),
"if_statement" => self.read_if_statement(node).map(Some),
"while_statement" => self.read_while_statement(node).map(Some),
"for_statement" => self.read_for_statement(node).map(Some),
"return_statement" => self.read_return_statement(node).map(Some),
"break_statement" => Ok(Some(Stmt::break_stmt())),
"continue_statement" => Ok(Some(Stmt::continue_stmt())),
"pass_statement" => Ok(None),
"function_definition" => self.read_function_definition(node).map(Some),
"import_statement" => self.read_import_statement(node).map(Some),
"import_from_statement" => self.read_import_from_statement(node).map(Some),
"class_definition" => self.read_class_definition(node).map(Some),
"try_statement" => Ok(None),
"with_statement" => Ok(None),
"decorated_definition" => {
if let Some(def) = node.child_by_field_name("definition") {
self.read_stmt(def)
} else {
Ok(None)
}
}
"call"
| "binary_operator"
| "comparison_operator"
| "boolean_operator"
| "identifier"
| "attribute" => Ok(Some(Stmt::expr(self.read_expr(node)?))),
_ => {
Ok(None)
}
}
}
fn read_assignment(&self, node: Node) -> Result<Stmt, ReadError> {
let left = node
.child_by_field_name("left")
.ok_or_else(|| ReadError::Parse("assignment missing left".into()))?;
let right = node
.child_by_field_name("right")
.ok_or_else(|| ReadError::Parse("assignment missing right".into()))?;
let value = self.read_expr(right)?;
match left.kind() {
"pattern_list" | "tuple_pattern" | "list_pattern" => {
let pat = self.read_py_pat(left)?;
let span = Span::from_ts(node.start_position(), node.end_position());
return Ok(Stmt::destructure(pat, value, true).with_span(span));
}
_ => {}
}
let name = self.node_text(left);
Ok(Stmt::let_decl(name, Some(value)))
}
fn read_py_pat(&self, node: Node) -> Result<Pat, ReadError> {
match node.kind() {
"identifier" => Ok(Pat::ident(self.node_text(node))),
"pattern_list" | "tuple_pattern" => {
let mut elements: Vec<Option<Pat>> = Vec::new();
let mut rest: Option<String> = None;
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if !child.is_named() {
continue;
}
match child.kind() {
"list_splat_pattern" => {
let mut inner_cursor = child.walk();
for inner in child.children(&mut inner_cursor) {
if inner.kind() == "identifier" {
rest = Some(self.node_text(inner).to_string());
break;
}
}
}
_ => {
elements.push(Some(self.read_py_pat(child)?));
}
}
}
Ok(Pat::Array(elements, rest))
}
"list_pattern" => {
let mut elements: Vec<Option<Pat>> = Vec::new();
let mut rest: Option<String> = None;
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if !child.is_named() {
continue;
}
match child.kind() {
"list_splat_pattern" => {
let mut inner_cursor = child.walk();
for inner in child.children(&mut inner_cursor) {
if inner.kind() == "identifier" {
rest = Some(self.node_text(inner).to_string());
break;
}
}
}
_ => {
elements.push(Some(self.read_py_pat(child)?));
}
}
}
Ok(Pat::Array(elements, rest))
}
other => Err(ReadError::Unsupported(format!(
"Python pattern type '{}'",
other
))),
}
}
fn read_augmented_assignment(&self, node: Node) -> Result<Stmt, ReadError> {
let left = node
.child_by_field_name("left")
.ok_or_else(|| ReadError::Parse("augmented_assignment missing left".into()))?;
let right = node
.child_by_field_name("right")
.ok_or_else(|| ReadError::Parse("augmented_assignment missing right".into()))?;
let op_node = node
.child_by_field_name("operator")
.ok_or_else(|| ReadError::Parse("augmented_assignment missing operator".into()))?;
let name = self.node_text(left);
let op_text = self.node_text(op_node);
let op = match op_text {
"+=" => BinaryOp::Add,
"-=" => BinaryOp::Sub,
"*=" => BinaryOp::Mul,
"/=" => BinaryOp::Div,
"%=" => BinaryOp::Mod,
_ => {
return Err(ReadError::Parse(format!(
"unknown augmented op: {}",
op_text
)));
}
};
let rhs = self.read_expr(right)?;
let value = Expr::binary(Expr::ident(name), op, rhs);
Ok(Stmt::expr(Expr::assign(Expr::ident(name), value)))
}
fn read_if_statement(&self, node: Node) -> Result<Stmt, ReadError> {
let condition = node
.child_by_field_name("condition")
.ok_or_else(|| ReadError::Parse("if missing condition".into()))?;
let consequence = node
.child_by_field_name("consequence")
.ok_or_else(|| ReadError::Parse("if missing consequence".into()))?;
let test = self.read_expr(condition)?;
let consequent = self.read_block(consequence)?;
let alternate = if let Some(alt) = node.child_by_field_name("alternative") {
match alt.kind() {
"else_clause" => {
if let Some(body) = alt.child_by_field_name("body") {
Some(self.read_block(body)?)
} else {
None
}
}
"elif_clause" => {
Some(self.read_elif_clause(alt)?)
}
_ => None,
}
} else {
None
};
Ok(Stmt::if_stmt(test, consequent, alternate))
}
fn read_elif_clause(&self, node: Node) -> Result<Stmt, ReadError> {
let condition = node
.child_by_field_name("condition")
.ok_or_else(|| ReadError::Parse("elif missing condition".into()))?;
let consequence = node
.child_by_field_name("consequence")
.ok_or_else(|| ReadError::Parse("elif missing consequence".into()))?;
let test = self.read_expr(condition)?;
let consequent = self.read_block(consequence)?;
let alternate = if let Some(alt) = node.child_by_field_name("alternative") {
match alt.kind() {
"else_clause" => {
if let Some(body) = alt.child_by_field_name("body") {
Some(self.read_block(body)?)
} else {
None
}
}
"elif_clause" => Some(self.read_elif_clause(alt)?),
_ => None,
}
} else {
None
};
Ok(Stmt::if_stmt(test, consequent, alternate))
}
fn read_while_statement(&self, node: Node) -> Result<Stmt, ReadError> {
let condition = node
.child_by_field_name("condition")
.ok_or_else(|| ReadError::Parse("while missing condition".into()))?;
let body = node
.child_by_field_name("body")
.ok_or_else(|| ReadError::Parse("while missing body".into()))?;
let test = self.read_expr(condition)?;
let body_stmt = self.read_block(body)?;
Ok(Stmt::while_loop(test, body_stmt))
}
fn read_for_statement(&self, node: Node) -> Result<Stmt, ReadError> {
let left = node
.child_by_field_name("left")
.ok_or_else(|| ReadError::Parse("for missing left".into()))?;
let right = node
.child_by_field_name("right")
.ok_or_else(|| ReadError::Parse("for missing right".into()))?;
let body = node
.child_by_field_name("body")
.ok_or_else(|| ReadError::Parse("for missing body".into()))?;
let variable = self.node_text(left).to_string();
let iterable = self.read_expr(right)?;
let body_stmt = self.read_block(body)?;
Ok(Stmt::for_in(variable, iterable, body_stmt))
}
fn read_return_statement(&self, node: Node) -> Result<Stmt, ReadError> {
let mut cursor = node.walk();
let expr_node = node
.children(&mut cursor)
.find(|c| c.is_named() && c.kind() != "return");
let expr = expr_node.map(|n| self.read_expr(n)).transpose()?;
Ok(Stmt::return_stmt(expr))
}
fn read_import_statement(&self, node: Node) -> Result<Stmt, ReadError> {
let mut names: Vec<ImportName> = Vec::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
match child.kind() {
"dotted_name" => {
let module = self.node_text(child).to_string();
names.push(ImportName::named(&module));
}
"aliased_import" => {
let name = child
.child_by_field_name("name")
.map(|n| self.node_text(n).to_string())
.unwrap_or_default();
let alias = child
.child_by_field_name("alias")
.map(|n| self.node_text(n).to_string());
names.push(ImportName::aliased(name, alias.unwrap_or_default()));
}
_ => {}
}
}
let source = names.first().map(|n| n.name.clone()).unwrap_or_default();
Ok(Stmt::import(source, names))
}
fn read_import_from_statement(&self, node: Node) -> Result<Stmt, ReadError> {
let mut source = String::new();
let mut names: Vec<ImportName> = Vec::new();
let mut past_import_kw = false;
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
match child.kind() {
"from" | "import" => {
if child.kind() == "import" {
past_import_kw = true;
}
}
"dotted_name" | "relative_import" | "import_prefix" if !past_import_kw => {
source = self.node_text(child).to_string();
}
"dotted_name" if past_import_kw => {
let name = self.node_text(child).to_string();
names.push(ImportName::named(name));
}
"wildcard_import" if past_import_kw => {
names.push(ImportName::namespace("*"));
}
"import_list" if past_import_kw => {
let mut c2 = child.walk();
for spec in child.children(&mut c2) {
match spec.kind() {
"identifier" => {
names.push(ImportName::named(self.node_text(spec)));
}
"aliased_import" => {
let name = spec
.child_by_field_name("name")
.map(|n| self.node_text(n).to_string())
.unwrap_or_default();
let alias = spec
.child_by_field_name("alias")
.map(|n| self.node_text(n).to_string());
names.push(ImportName::aliased(name, alias.unwrap_or_default()));
}
_ => {}
}
}
}
"aliased_import" if past_import_kw => {
let name = child
.child_by_field_name("name")
.map(|n| self.node_text(n).to_string())
.unwrap_or_default();
let alias = child
.child_by_field_name("alias")
.map(|n| self.node_text(n).to_string());
names.push(ImportName::aliased(name, alias.unwrap_or_default()));
}
_ => {}
}
}
Ok(Stmt::import(source, names))
}
fn read_class_definition(&self, node: Node) -> Result<Stmt, ReadError> {
let name = node
.child_by_field_name("name")
.map(|n| self.node_text(n).to_string())
.unwrap_or_else(|| "__class__".to_string());
let extends = {
let mut cur = node.walk();
node.children(&mut cur)
.find(|c| c.kind() == "argument_list")
}
.and_then(|args| {
let mut c = args.walk();
args.children(&mut c)
.find(|ch| ch.kind() == "identifier")
.map(|ch| self.node_text(ch).to_string())
});
let body = node
.child_by_field_name("body")
.ok_or_else(|| ReadError::Parse("class_definition missing body".into()))?;
let methods = self.read_class_body(body)?;
Ok(Stmt::class(name, extends, methods))
}
fn read_class_body(&self, body: Node) -> Result<Vec<Method>, ReadError> {
let mut methods = Vec::new();
let mut cursor = body.walk();
for child in body.children(&mut cursor) {
if child.kind() == "function_definition" {
if let Ok(Stmt::Function(f)) = self.read_function_definition(child) {
let method = Method::new(f.name.clone(), f.params.clone(), f.body.clone());
methods.push(method);
}
} else if child.kind() == "decorated_definition" {
if let Some(inner) = child.child_by_field_name("definition")
&& inner.kind() == "function_definition"
{
let is_static = {
let mut c2 = child.walk();
child.children(&mut c2).any(|dec| {
if dec.kind() == "decorator" {
let mut c3 = dec.walk();
dec.children(&mut c3).any(|id| {
id.kind() == "identifier"
&& self.node_text(id) == "staticmethod"
})
} else {
false
}
})
};
if let Ok(Stmt::Function(f)) = self.read_function_definition(inner) {
let mut method =
Method::new(f.name.clone(), f.params.clone(), f.body.clone());
method.is_static = is_static;
methods.push(method);
}
}
}
}
Ok(methods)
}
fn read_function_definition(&self, node: Node) -> Result<Stmt, ReadError> {
let name = node
.child_by_field_name("name")
.ok_or_else(|| ReadError::Parse("function missing name".into()))?;
let params = node.child_by_field_name("parameters");
let body = node
.child_by_field_name("body")
.ok_or_else(|| ReadError::Parse("function missing body".into()))?;
let fn_name = self.node_text(name).to_string();
let fn_params = params
.map(|p| self.read_parameters(p))
.transpose()?
.unwrap_or_default();
let fn_body = self.read_block_stmts(body)?;
let return_type = node
.child_by_field_name("return_type")
.map(|n| self.node_text(n).to_string());
let mut func = Function::new(fn_name, fn_params, fn_body);
func.return_type = return_type;
Ok(Stmt::function(func))
}
fn read_parameters(&self, node: Node) -> Result<Vec<Param>, ReadError> {
let mut params = Vec::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
match child.kind() {
"identifier" => {
params.push(Param::new(self.node_text(child)));
}
"default_parameter" => {
if let Some(name) = child.child_by_field_name("name") {
params.push(Param::new(self.node_text(name)));
}
}
"typed_parameter" | "typed_default_parameter" => {
if let Some(name) = child.child(0)
&& name.kind() == "identifier"
{
let type_annotation = child
.child_by_field_name("type")
.map(|n| self.node_text(n).to_string());
let mut param = Param::new(self.node_text(name));
param.type_annotation = type_annotation;
params.push(param);
}
}
_ => {}
}
}
Ok(params)
}
fn read_block(&self, node: Node) -> Result<Stmt, ReadError> {
Ok(Stmt::block(self.read_block_stmts(node)?))
}
fn read_block_stmts(&self, node: Node) -> Result<Vec<Stmt>, ReadError> {
let mut stmts = Vec::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.is_named()
&& let Some(stmt) = self.read_stmt(child)?
{
stmts.push(stmt);
}
}
Ok(stmts)
}
fn read_expr(&self, node: Node) -> Result<Expr, ReadError> {
match node.kind() {
"integer" | "float" => {
let text = self.node_text(node);
let num: f64 = text.parse().unwrap_or(0.0);
Ok(Expr::number(num))
}
"string" | "concatenated_string" => {
let text = self.node_text(node);
let inner = text
.trim_start_matches(['"', '\''])
.trim_start_matches("f\"")
.trim_start_matches("f'")
.trim_start_matches("r\"")
.trim_start_matches("r'")
.trim_end_matches(['"', '\'']);
Ok(Expr::string(inner))
}
"true" => Ok(Expr::bool(true)),
"false" => Ok(Expr::bool(false)),
"none" => Ok(Expr::null()),
"identifier" => Ok(Expr::ident(self.node_text(node))),
"binary_operator" => self.read_binary_operator(node),
"comparison_operator" => self.read_comparison_operator(node),
"boolean_operator" => self.read_boolean_operator(node),
"unary_operator" => self.read_unary_operator(node),
"not_operator" => {
let arg = node
.child_by_field_name("argument")
.ok_or_else(|| ReadError::Parse("not_operator missing argument".into()))?;
Ok(Expr::unary(UnaryOp::Not, self.read_expr(arg)?))
}
"call" => self.read_call(node),
"attribute" => self.read_attribute(node),
"subscript" => self.read_subscript(node),
"list" => self.read_list(node),
"dictionary" => self.read_dictionary(node),
"tuple" => self.read_tuple(node),
"parenthesized_expression" => {
let inner = node.child(1).ok_or_else(|| {
ReadError::Parse("parenthesized_expression missing inner".into())
})?;
self.read_expr(inner)
}
"conditional_expression" => self.read_conditional_expression(node),
"lambda" => self.read_lambda(node),
"named_expression" => {
let name = node
.child_by_field_name("name")
.ok_or_else(|| ReadError::Parse("named_expression missing name".into()))?;
let value = node
.child_by_field_name("value")
.ok_or_else(|| ReadError::Parse("named_expression missing value".into()))?;
Ok(Expr::assign(
Expr::ident(self.node_text(name)),
self.read_expr(value)?,
))
}
_ => {
Err(ReadError::Parse(format!(
"unsupported expression: {}",
node.kind()
)))
}
}
}
fn read_binary_operator(&self, node: Node) -> Result<Expr, ReadError> {
let left = node
.child_by_field_name("left")
.ok_or_else(|| ReadError::Parse("binary_operator missing left".into()))?;
let right = node
.child_by_field_name("right")
.ok_or_else(|| ReadError::Parse("binary_operator missing right".into()))?;
let op_node = node
.child_by_field_name("operator")
.ok_or_else(|| ReadError::Parse("binary_operator missing operator".into()))?;
let op = match self.node_text(op_node) {
"+" => BinaryOp::Add,
"-" => BinaryOp::Sub,
"*" => BinaryOp::Mul,
"/" | "//" => BinaryOp::Div,
"%" => BinaryOp::Mod,
_ => {
return Err(ReadError::Parse(format!(
"unknown binary op: {}",
self.node_text(op_node)
)));
}
};
Ok(Expr::binary(
self.read_expr(left)?,
op,
self.read_expr(right)?,
))
}
fn read_comparison_operator(&self, node: Node) -> Result<Expr, ReadError> {
let mut cursor = node.walk();
let children: Vec<_> = node.children(&mut cursor).collect();
if children.len() < 3 {
return Err(ReadError::Parse(
"comparison needs at least 3 children".into(),
));
}
let left = self.read_expr(children[0])?;
let op_text = self.node_text(children[1]);
let right = self.read_expr(children[2])?;
let op = match op_text {
"<" => BinaryOp::Lt,
"<=" => BinaryOp::Le,
">" => BinaryOp::Gt,
">=" => BinaryOp::Ge,
"==" => BinaryOp::Eq,
"!=" => BinaryOp::Ne,
_ => {
return Err(ReadError::Parse(format!(
"unknown comparison op: {}",
op_text
)));
}
};
if children.len() > 3 {
let mut result = Expr::binary(left, op, right.clone());
let mut prev_right = right;
for i in (3..children.len()).step_by(2) {
if i + 1 < children.len() {
let next_op_text = self.node_text(children[i]);
let next_right = self.read_expr(children[i + 1])?;
let next_op = match next_op_text {
"<" => BinaryOp::Lt,
"<=" => BinaryOp::Le,
">" => BinaryOp::Gt,
">=" => BinaryOp::Ge,
"==" => BinaryOp::Eq,
"!=" => BinaryOp::Ne,
_ => continue,
};
let next_cmp = Expr::binary(prev_right, next_op, next_right.clone());
result = Expr::binary(result, BinaryOp::And, next_cmp);
prev_right = next_right;
}
}
Ok(result)
} else {
Ok(Expr::binary(left, op, right))
}
}
fn read_boolean_operator(&self, node: Node) -> Result<Expr, ReadError> {
let left = node
.child_by_field_name("left")
.ok_or_else(|| ReadError::Parse("boolean_operator missing left".into()))?;
let right = node
.child_by_field_name("right")
.ok_or_else(|| ReadError::Parse("boolean_operator missing right".into()))?;
let op_node = node
.child_by_field_name("operator")
.ok_or_else(|| ReadError::Parse("boolean_operator missing operator".into()))?;
let op = match self.node_text(op_node) {
"and" => BinaryOp::And,
"or" => BinaryOp::Or,
_ => {
return Err(ReadError::Parse(format!(
"unknown boolean op: {}",
self.node_text(op_node)
)));
}
};
Ok(Expr::binary(
self.read_expr(left)?,
op,
self.read_expr(right)?,
))
}
fn read_unary_operator(&self, node: Node) -> Result<Expr, ReadError> {
let op_node = node
.child_by_field_name("operator")
.ok_or_else(|| ReadError::Parse("unary_operator missing operator".into()))?;
let arg = node
.child_by_field_name("argument")
.ok_or_else(|| ReadError::Parse("unary_operator missing argument".into()))?;
let op = match self.node_text(op_node) {
"-" => UnaryOp::Neg,
"+" => return self.read_expr(arg), "~" => {
return Err(ReadError::Parse(
"bitwise not (~) not supported in IR".into(),
));
}
_ => {
return Err(ReadError::Parse(format!(
"unknown unary op: {}",
self.node_text(op_node)
)));
}
};
Ok(Expr::unary(op, self.read_expr(arg)?))
}
fn read_call(&self, node: Node) -> Result<Expr, ReadError> {
let function = node
.child_by_field_name("function")
.ok_or_else(|| ReadError::Parse("call missing function".into()))?;
let arguments = node.child_by_field_name("arguments");
let callee = self.read_expr(function)?;
let args = arguments
.map(|a| self.read_arguments(a))
.transpose()?
.unwrap_or_default();
Ok(Expr::call(callee, args))
}
fn read_arguments(&self, node: Node) -> Result<Vec<Expr>, ReadError> {
let mut args = Vec::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.is_named() && child.kind() != "(" && child.kind() != ")" {
if child.kind() != "keyword_argument" {
args.push(self.read_expr(child)?);
}
}
}
Ok(args)
}
fn read_attribute(&self, node: Node) -> Result<Expr, ReadError> {
let object = node
.child_by_field_name("object")
.ok_or_else(|| ReadError::Parse("attribute missing object".into()))?;
let attribute = node
.child_by_field_name("attribute")
.ok_or_else(|| ReadError::Parse("attribute missing attribute".into()))?;
let obj_expr = self.read_expr(object)?;
let prop = self.node_text(attribute);
Ok(Expr::member(obj_expr, prop))
}
fn read_subscript(&self, node: Node) -> Result<Expr, ReadError> {
let value = node
.child_by_field_name("value")
.ok_or_else(|| ReadError::Parse("subscript missing value".into()))?;
let subscript = node
.child_by_field_name("subscript")
.ok_or_else(|| ReadError::Parse("subscript missing subscript".into()))?;
let obj_expr = self.read_expr(value)?;
let idx_expr = self.read_expr(subscript)?;
Ok(Expr::index(obj_expr, idx_expr))
}
fn read_list(&self, node: Node) -> Result<Expr, ReadError> {
let mut items = Vec::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.is_named() {
items.push(self.read_expr(child)?);
}
}
Ok(Expr::array(items))
}
fn read_dictionary(&self, node: Node) -> Result<Expr, ReadError> {
let mut pairs = Vec::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "pair" {
let key = child.child_by_field_name("key");
let value = child.child_by_field_name("value");
if let (Some(k), Some(v)) = (key, value) {
let key_text = self.node_text(k);
let key_str = key_text.trim_matches('"').trim_matches('\'').to_string();
pairs.push((key_str, self.read_expr(v)?));
}
}
}
Ok(Expr::object(pairs))
}
fn read_tuple(&self, node: Node) -> Result<Expr, ReadError> {
let mut items = Vec::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.is_named() {
items.push(self.read_expr(child)?);
}
}
Ok(Expr::array(items))
}
fn read_conditional_expression(&self, node: Node) -> Result<Expr, ReadError> {
let mut cursor = node.walk();
let children: Vec<_> = node
.children(&mut cursor)
.filter(|c| c.is_named())
.collect();
if children.len() >= 3 {
let consequent = self.read_expr(children[0])?;
let test = self.read_expr(children[1])?;
let alternate = self.read_expr(children[2])?;
Ok(Expr::conditional(test, consequent, alternate))
} else {
Err(ReadError::Parse(
"conditional_expression needs 3 parts".into(),
))
}
}
fn read_lambda(&self, node: Node) -> Result<Expr, ReadError> {
let params = node.child_by_field_name("parameters");
let body = node
.child_by_field_name("body")
.ok_or_else(|| ReadError::Parse("lambda missing body".into()))?;
let fn_params = params
.map(|p| self.read_lambda_parameters(p))
.transpose()?
.unwrap_or_default();
let body_expr = self.read_expr(body)?;
let fn_body = vec![Stmt::return_stmt(Some(body_expr))];
Ok(Expr::Function(Box::new(Function::anonymous(
fn_params, fn_body,
))))
}
fn read_lambda_parameters(&self, node: Node) -> Result<Vec<Param>, ReadError> {
let mut params = Vec::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "identifier" {
params.push(Param::new(self.node_text(child)));
}
}
Ok(params)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads `src` and asserts it lowered to exactly one top-level statement.
    fn parse_one(src: &str) -> Result<Program, ReadError> {
        let ir = read_python(src)?;
        assert_eq!(ir.body.len(), 1);
        Ok(ir)
    }

    #[test]
    fn test_simple_assignment() -> Result<(), ReadError> {
        let ir = parse_one("x = 42")?;
        let Stmt::Let { name, .. } = &ir.body[0] else {
            panic!("expected Let");
        };
        assert_eq!(name, "x");
        Ok(())
    }

    #[test]
    fn test_binary_expr() -> Result<(), ReadError> {
        parse_one("result = 1 + 2 * 3")?;
        Ok(())
    }

    #[test]
    fn test_function_call() -> Result<(), ReadError> {
        let ir = parse_one("print(\"hello\", 42)")?;
        let Stmt::Expr(Expr::Call { callee, args, .. }) = &ir.body[0] else {
            panic!("expected Call");
        };
        assert!(matches!(callee.as_ref(), Expr::Ident(n) if n == "print"));
        assert_eq!(args.len(), 2);
        Ok(())
    }

    #[test]
    fn test_function_declaration() -> Result<(), ReadError> {
        let ir = parse_one("def add(a, b):\n return a + b")?;
        let Stmt::Function(f) = &ir.body[0] else {
            panic!("expected Function");
        };
        assert_eq!(f.name, "add");
        assert_eq!(f.params.len(), 2);
        assert_eq!(f.params[0].name, "a");
        assert_eq!(f.params[1].name, "b");
        Ok(())
    }

    #[test]
    fn test_typed_function_declaration() -> Result<(), ReadError> {
        let ir = parse_one("def greet(name: str, age: int) -> str:\n return name")?;
        let Stmt::Function(f) = &ir.body[0] else {
            panic!("expected Function");
        };
        assert_eq!(f.name, "greet");
        assert_eq!(f.params.len(), 2);
        assert_eq!(f.params[0].name, "name");
        assert_eq!(f.params[0].type_annotation.as_deref(), Some("str"));
        assert_eq!(f.params[1].name, "age");
        assert_eq!(f.params[1].type_annotation.as_deref(), Some("int"));
        assert_eq!(f.return_type.as_deref(), Some("str"));
        Ok(())
    }

    #[test]
    fn test_if_statement() -> Result<(), ReadError> {
        let ir = parse_one("if x > 0:\n print(x)")?;
        assert!(matches!(&ir.body[0], Stmt::If { .. }));
        Ok(())
    }

    #[test]
    fn test_for_loop() -> Result<(), ReadError> {
        let ir = parse_one("for i in items:\n print(i)")?;
        let Stmt::ForIn { variable, .. } = &ir.body[0] else {
            panic!("expected ForIn");
        };
        assert_eq!(variable, "i");
        Ok(())
    }

    #[test]
    fn test_list_literal() -> Result<(), ReadError> {
        parse_one("arr = [1, 2, 3]")?;
        Ok(())
    }

    #[test]
    fn test_dict_literal() -> Result<(), ReadError> {
        parse_one("obj = {\"x\": 1, \"y\": 2}")?;
        Ok(())
    }

    #[test]
    fn test_import_statement() -> Result<(), ReadError> {
        let ir = parse_one("import os")?;
        let Stmt::Import { source, names, .. } = &ir.body[0] else {
            panic!("expected Import");
        };
        assert_eq!(source, "os");
        assert_eq!(names.len(), 1);
        assert_eq!(names[0].name, "os");
        Ok(())
    }

    #[test]
    fn test_import_from_statement() -> Result<(), ReadError> {
        let ir = parse_one("from os.path import join, exists")?;
        let Stmt::Import { source, names, .. } = &ir.body[0] else {
            panic!("expected Import");
        };
        assert_eq!(source, "os.path");
        assert_eq!(names.len(), 2);
        assert_eq!(names[0].name, "join");
        assert_eq!(names[1].name, "exists");
        Ok(())
    }

    #[test]
    fn test_class_definition() -> Result<(), ReadError> {
        let ir = parse_one(
            "class Animal:\n def __init__(self, name):\n self.name = name\n def speak(self):\n pass",
        )?;
        let Stmt::Class {
            name,
            extends,
            methods,
            ..
        } = &ir.body[0]
        else {
            panic!("expected Class");
        };
        assert_eq!(name, "Animal");
        assert!(extends.is_none());
        assert_eq!(methods.len(), 2);
        assert_eq!(methods[0].name, "__init__");
        assert_eq!(methods[1].name, "speak");
        Ok(())
    }

    #[test]
    fn test_class_with_base() -> Result<(), ReadError> {
        let ir = parse_one("class Dog(Animal):\n def speak(self):\n pass")?;
        let Stmt::Class {
            name,
            extends,
            methods,
            ..
        } = &ir.body[0]
        else {
            panic!("expected Class");
        };
        assert_eq!(name, "Dog");
        assert_eq!(extends.as_deref(), Some("Animal"));
        assert_eq!(methods.len(), 1);
        Ok(())
    }

    #[test]
    fn test_tuple_unpacking_ir() -> Result<(), ReadError> {
        let ir = parse_one("a, b = func()")?;
        let Stmt::Destructure { pat, .. } = &ir.body[0] else {
            panic!("expected Destructure, got {:?}", ir.body[0]);
        };
        let Pat::Array(elements, rest) = pat else {
            panic!("expected Pat::Array, got {:?}", pat);
        };
        assert_eq!(elements.len(), 2);
        assert!(matches!(&elements[0], Some(Pat::Ident(n)) if n == "a"));
        assert!(matches!(&elements[1], Some(Pat::Ident(n)) if n == "b"));
        assert!(rest.is_none());
        Ok(())
    }

    #[test]
    fn test_list_pattern_ir() -> Result<(), ReadError> {
        let ir = parse_one("[x, y] = arr")?;
        let Stmt::Destructure { pat, .. } = &ir.body[0] else {
            panic!("expected Destructure");
        };
        let Pat::Array(elements, rest) = pat else {
            panic!("expected Pat::Array");
        };
        assert_eq!(elements.len(), 2);
        assert!(matches!(&elements[0], Some(Pat::Ident(n)) if n == "x"));
        assert!(matches!(&elements[1], Some(Pat::Ident(n)) if n == "y"));
        assert!(rest.is_none());
        Ok(())
    }

    #[test]
    fn test_tuple_unpacking_with_rest() -> Result<(), ReadError> {
        let ir = parse_one("(first, *rest) = items")?;
        let Stmt::Destructure { pat, .. } = &ir.body[0] else {
            panic!("expected Destructure");
        };
        let Pat::Array(elements, rest) = pat else {
            panic!("expected Pat::Array");
        };
        assert_eq!(elements.len(), 1);
        assert_eq!(rest.as_deref(), Some("rest"));
        Ok(())
    }

    #[test]
    fn test_tuple_unpacking_round_trip() -> Result<(), ReadError> {
        use crate::output::python::PythonWriter;
        let emitted = PythonWriter::emit(&read_python("a, b = func()")?);
        assert_eq!(emitted.trim(), "a, b = func()");
        Ok(())
    }

    #[test]
    fn test_list_pattern_round_trip() -> Result<(), ReadError> {
        use crate::output::python::PythonWriter;
        let emitted = PythonWriter::emit(&read_python("[x, y] = arr")?);
        assert_eq!(emitted.trim(), "x, y = arr");
        Ok(())
    }
}