use crate::{AsgConvertError, ConstValue, Expression, ExpressionNode, FromAst, Node, PartialType, Scope, Span, Type};
pub use leo_ast::{BinaryOperation, BinaryOperationClass};
use std::cell::Cell;
/// A binary expression node: one operator applied to a left and a right operand.
///
/// The operand and parent links are stored in `Cell`s so they can be read and
/// replaced through a shared reference to the node.
#[derive(Clone)]
pub struct BinaryExpression<'a> {
/// The expression that contains this one; `None` until `set_parent` is called.
pub parent: Cell<Option<&'a Expression<'a>>>,
/// Source span of the expression, if one is attached.
pub span: Option<Span>,
/// The operator applied to `left` and `right`.
pub operation: BinaryOperation,
/// Left-hand operand.
pub left: Cell<&'a Expression<'a>>,
/// Right-hand operand.
pub right: Cell<&'a Expression<'a>>,
}
impl Node for BinaryExpression<'_> {
    /// Borrows the source span attached to this expression, if there is one.
    fn span(&self) -> Option<&Span> {
        Option::as_ref(&self.span)
    }
}
impl<'a> ExpressionNode<'a> for BinaryExpression<'a> {
fn set_parent(&self, parent: &'a Expression<'a>) {
self.parent.replace(Some(parent));
}
fn get_parent(&self) -> Option<&'a Expression<'a>> {
self.parent.get()
}
fn enforce_parents(&self, expr: &'a Expression<'a>) {
self.left.get().set_parent(expr);
self.right.get().set_parent(expr);
}
fn get_type(&self) -> Option<Type<'a>> {
match self.operation.class() {
BinaryOperationClass::Boolean => Some(Type::Boolean),
BinaryOperationClass::Numeric => self.left.get().get_type(),
}
}
fn is_mut_ref(&self) -> bool {
false
}
fn const_value(&self) -> Option<ConstValue> {
use BinaryOperation::*;
let left = self.left.get().const_value()?;
let right = self.right.get().const_value()?;
match (left, right) {
(ConstValue::Int(left), ConstValue::Int(right)) => Some(match self.operation {
Add => ConstValue::Int(left.value_add(&right)?),
Sub => ConstValue::Int(left.value_sub(&right)?),
Mul => ConstValue::Int(left.value_mul(&right)?),
Div => ConstValue::Int(left.value_div(&right)?),
Pow => ConstValue::Int(left.value_pow(&right)?),
Eq => ConstValue::Boolean(left == right),
Ne => ConstValue::Boolean(left != right),
Ge => ConstValue::Boolean(left.value_ge(&right)?),
Gt => ConstValue::Boolean(left.value_gt(&right)?),
Le => ConstValue::Boolean(left.value_le(&right)?),
Lt => ConstValue::Boolean(left.value_lt(&right)?),
_ => return None,
}),
(ConstValue::Boolean(left), ConstValue::Boolean(right)) => Some(match self.operation {
Eq => ConstValue::Boolean(left == right),
Ne => ConstValue::Boolean(left != right),
And => ConstValue::Boolean(left && right),
Or => ConstValue::Boolean(left || right),
_ => return None,
}),
(left, right) => Some(match self.operation {
Eq => ConstValue::Boolean(left == right),
Ne => ConstValue::Boolean(left != right),
_ => return None,
}),
}
}
fn is_consty(&self) -> bool {
self.left.get().is_consty() && self.right.get().is_consty()
}
}
impl<'a> FromAst<'a, leo_ast::BinaryExpression> for BinaryExpression<'a> {
fn from_ast(
scope: &'a Scope<'a>,
value: &leo_ast::BinaryExpression,
expected_type: Option<PartialType<'a>>,
) -> Result<BinaryExpression<'a>, AsgConvertError> {
let class = value.op.class();
let expected_type = match class {
BinaryOperationClass::Boolean => match expected_type {
Some(PartialType::Type(Type::Boolean)) | None => None,
Some(x) => {
return Err(AsgConvertError::unexpected_type(
&x.to_string(),
Some(&*Type::Boolean.to_string()),
&value.span,
));
}
},
BinaryOperationClass::Numeric => match expected_type {
Some(x @ PartialType::Integer(_, _)) => Some(x),
Some(x @ PartialType::Type(Type::Field)) => Some(x),
Some(x @ PartialType::Type(Type::Group)) => Some(x),
Some(x) => {
return Err(AsgConvertError::unexpected_type(
&x.to_string(),
Some("integer, field, or group"),
&value.span,
));
}
None => None,
},
};
let (left, right) = match <&Expression<'a>>::from_ast(scope, &*value.left, expected_type.clone()) {
Ok(left) => {
if let Some(left_type) = left.get_type() {
let right = <&Expression<'a>>::from_ast(scope, &*value.right, Some(left_type.partial()))?;
(left, right)
} else {
let right = <&Expression<'a>>::from_ast(scope, &*value.right, expected_type)?;
if let Some(right_type) = right.get_type() {
(
<&Expression<'a>>::from_ast(scope, &*value.left, Some(right_type.partial()))?,
right,
)
} else {
(left, right)
}
}
}
Err(e) => {
let right = <&Expression<'a>>::from_ast(scope, &*value.right, expected_type)?;
if let Some(right_type) = right.get_type() {
(
<&Expression<'a>>::from_ast(scope, &*value.left, Some(right_type.partial()))?,
right,
)
} else {
return Err(e);
}
}
};
let left_type = left.get_type();
#[allow(clippy::unused_unit)]
match class {
BinaryOperationClass::Numeric => match left_type {
Some(Type::Integer(_)) => (),
Some(Type::Group) | Some(Type::Field)
if value.op == BinaryOperation::Add || value.op == BinaryOperation::Sub =>
{
()
}
Some(Type::Field) if value.op == BinaryOperation::Mul || value.op == BinaryOperation::Div => (),
type_ => {
return Err(AsgConvertError::unexpected_type(
"integer",
type_.map(|x| x.to_string()).as_deref(),
&value.span,
));
}
},
BinaryOperationClass::Boolean => match &value.op {
BinaryOperation::And | BinaryOperation::Or => match left_type {
Some(Type::Boolean) | None => (),
Some(x) => {
return Err(AsgConvertError::unexpected_type(
&x.to_string(),
Some(&*Type::Boolean.to_string()),
&value.span,
));
}
},
BinaryOperation::Eq | BinaryOperation::Ne => (),
_ => match left_type {
Some(Type::Integer(_)) | None => (),
Some(x) => {
return Err(AsgConvertError::unexpected_type(
&x.to_string(),
Some("integer"),
&value.span,
));
}
},
},
}
let right_type = right.get_type();
match (left_type, right_type) {
(Some(left_type), Some(right_type)) => {
if !left_type.is_assignable_from(&right_type) {
return Err(AsgConvertError::unexpected_type(
&left_type.to_string(),
Some(&*right_type.to_string()),
&value.span,
));
}
}
(None, None) => {
return Err(AsgConvertError::unexpected_type(
"any type",
Some("unknown type"),
&value.span,
));
}
(_, _) => (),
}
Ok(BinaryExpression {
parent: Cell::new(None),
span: Some(value.span.clone()),
operation: value.op.clone(),
left: Cell::new(left),
right: Cell::new(right),
})
}
}
impl<'a> Into<leo_ast::BinaryExpression> for &BinaryExpression<'a> {
    /// Rebuilds the AST form of this expression: the operator is cloned, both
    /// operands are converted and boxed, and a missing span becomes the
    /// default span.
    fn into(self) -> leo_ast::BinaryExpression {
        let op = self.operation.clone();
        let left = Box::new(self.left.get().into());
        let right = Box::new(self.right.get().into());
        let span = self.span.clone().unwrap_or_default();
        leo_ast::BinaryExpression { op, left, right, span }
    }
}