use crate::core::error::{Error, Result};
use crate::optimized::jit::types::{JitNumeric, NumericValue, TypedVector};
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq)]
pub enum ExpressionNode {
Constant(NumericValue),
Variable {
name: String,
var_type: String,
index: usize,
},
ArrayAccess {
array: Box<ExpressionNode>,
index: Box<ExpressionNode>,
},
BinaryOp {
left: Box<ExpressionNode>,
right: Box<ExpressionNode>,
operator: BinaryOperator,
},
UnaryOp {
operand: Box<ExpressionNode>,
operator: UnaryOperator,
},
FunctionCall {
function: String,
arguments: Vec<ExpressionNode>,
},
Reduction {
array: Box<ExpressionNode>,
operation: ReductionOperation,
axis: Option<usize>,
},
Conditional {
condition: Box<ExpressionNode>,
true_expr: Box<ExpressionNode>,
false_expr: Box<ExpressionNode>,
},
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOperator {
Add,
Subtract,
Multiply,
Divide,
Modulo,
Power,
Equal,
NotEqual,
LessThan,
LessThanOrEqual,
GreaterThan,
GreaterThanOrEqual,
LogicalAnd,
LogicalOr,
BitwiseAnd,
BitwiseOr,
BitwiseXor,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOperator {
Negate,
LogicalNot,
BitwiseNot,
Abs,
Sqrt,
Sin,
Cos,
Tan,
Log,
Exp,
Floor,
Ceil,
Round,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReductionOperation {
Sum,
Product,
Mean,
Min,
Max,
Count,
Any,
All,
Variance,
StandardDeviation,
}
impl fmt::Display for BinaryOperator {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
BinaryOperator::Add => write!(f, "+"),
BinaryOperator::Subtract => write!(f, "-"),
BinaryOperator::Multiply => write!(f, "*"),
BinaryOperator::Divide => write!(f, "/"),
BinaryOperator::Modulo => write!(f, "%"),
BinaryOperator::Power => write!(f, "**"),
BinaryOperator::Equal => write!(f, "=="),
BinaryOperator::NotEqual => write!(f, "!="),
BinaryOperator::LessThan => write!(f, "<"),
BinaryOperator::LessThanOrEqual => write!(f, "<="),
BinaryOperator::GreaterThan => write!(f, ">"),
BinaryOperator::GreaterThanOrEqual => write!(f, ">="),
BinaryOperator::LogicalAnd => write!(f, "&&"),
BinaryOperator::LogicalOr => write!(f, "||"),
BinaryOperator::BitwiseAnd => write!(f, "&"),
BinaryOperator::BitwiseOr => write!(f, "|"),
BinaryOperator::BitwiseXor => write!(f, "^"),
}
}
}
impl fmt::Display for UnaryOperator {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
UnaryOperator::Negate => write!(f, "-"),
UnaryOperator::LogicalNot => write!(f, "!"),
UnaryOperator::BitwiseNot => write!(f, "~"),
UnaryOperator::Abs => write!(f, "abs"),
UnaryOperator::Sqrt => write!(f, "sqrt"),
UnaryOperator::Sin => write!(f, "sin"),
UnaryOperator::Cos => write!(f, "cos"),
UnaryOperator::Tan => write!(f, "tan"),
UnaryOperator::Log => write!(f, "log"),
UnaryOperator::Exp => write!(f, "exp"),
UnaryOperator::Floor => write!(f, "floor"),
UnaryOperator::Ceil => write!(f, "ceil"),
UnaryOperator::Round => write!(f, "round"),
}
}
}
impl Hash for ExpressionNode {
fn hash<H: Hasher>(&self, state: &mut H) {
std::mem::discriminant(self).hash(state);
match self {
ExpressionNode::Constant(value) => {
value.type_name().hash(state);
value.to_f64().to_bits().hash(state);
}
ExpressionNode::Variable {
name,
var_type,
index,
} => {
name.hash(state);
var_type.hash(state);
index.hash(state);
}
ExpressionNode::ArrayAccess { array, index } => {
array.hash(state);
index.hash(state);
}
ExpressionNode::BinaryOp {
left,
right,
operator,
} => {
left.hash(state);
right.hash(state);
operator.hash(state);
}
ExpressionNode::UnaryOp { operand, operator } => {
operand.hash(state);
operator.hash(state);
}
ExpressionNode::FunctionCall {
function,
arguments,
} => {
function.hash(state);
for arg in arguments {
arg.hash(state);
}
}
ExpressionNode::Reduction {
array,
operation,
axis,
} => {
array.hash(state);
operation.hash(state);
axis.hash(state);
}
ExpressionNode::Conditional {
condition,
true_expr,
false_expr,
} => {
condition.hash(state);
true_expr.hash(state);
false_expr.hash(state);
}
}
}
}
#[derive(Debug, Clone)]
pub struct ExpressionTree {
pub root: ExpressionNode,
pub variables: HashMap<String, VariableInfo>,
pub metadata: ExpressionMetadata,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VariableInfo {
pub name: String,
pub data_type: String,
pub shape: Option<Vec<usize>>,
pub is_mutable: bool,
}
#[derive(Debug, Clone)]
pub struct ExpressionMetadata {
pub complexity: u64,
pub is_vectorizable: bool,
pub is_parallelizable: bool,
pub estimated_memory_usage: usize,
pub optimizations: Vec<OptimizationOpportunity>,
}
#[derive(Debug, Clone)]
pub struct OptimizationOpportunity {
pub optimization_type: OptimizationType,
pub description: String,
pub estimated_improvement: f64,
pub node_path: Vec<usize>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationType {
CommonSubexpressionElimination,
ConstantFolding,
AlgebraicSimplification,
LoopFusion,
Vectorization,
MemoryLayoutOptimization,
BranchElimination,
StrengthReduction,
}
impl ExpressionTree {
pub fn new(root: ExpressionNode) -> Self {
let mut tree = Self {
root,
variables: HashMap::new(),
metadata: ExpressionMetadata {
complexity: 0,
is_vectorizable: false,
is_parallelizable: false,
estimated_memory_usage: 0,
optimizations: Vec::new(),
},
};
tree.analyze();
tree
}
pub fn add_variable(
&mut self,
name: String,
data_type: String,
shape: Option<Vec<usize>>,
is_mutable: bool,
) {
let var_info = VariableInfo {
name: name.clone(),
data_type,
shape,
is_mutable,
};
self.variables.insert(name, var_info);
}
pub fn analyze(&mut self) {
self.calculate_complexity();
self.analyze_vectorization();
self.analyze_parallelization();
self.estimate_memory_usage();
self.find_optimizations();
}
fn calculate_complexity(&mut self) {
self.metadata.complexity = self.calculate_node_complexity(&self.root);
}
fn calculate_node_complexity(&self, node: &ExpressionNode) -> u64 {
match node {
ExpressionNode::Constant(_) => 1,
ExpressionNode::Variable { .. } => 1,
ExpressionNode::ArrayAccess { array, index } => {
self.calculate_node_complexity(array) + self.calculate_node_complexity(index) + 2
}
ExpressionNode::BinaryOp {
left,
right,
operator,
} => {
let left_complexity = self.calculate_node_complexity(left);
let right_complexity = self.calculate_node_complexity(right);
let op_complexity = match operator {
BinaryOperator::Add | BinaryOperator::Subtract => 1,
BinaryOperator::Multiply => 2,
BinaryOperator::Divide => 5,
BinaryOperator::Power => 10,
_ => 1,
};
left_complexity + right_complexity + op_complexity
}
ExpressionNode::UnaryOp { operand, operator } => {
let operand_complexity = self.calculate_node_complexity(operand);
let op_complexity = match operator {
UnaryOperator::Negate | UnaryOperator::LogicalNot => 1,
UnaryOperator::Abs => 2,
UnaryOperator::Sqrt => 5,
UnaryOperator::Sin | UnaryOperator::Cos | UnaryOperator::Tan => 10,
UnaryOperator::Log | UnaryOperator::Exp => 8,
_ => 3,
};
operand_complexity + op_complexity
}
ExpressionNode::FunctionCall { arguments, .. } => {
let args_complexity: u64 = arguments
.iter()
.map(|arg| self.calculate_node_complexity(arg))
.sum();
args_complexity + 10 }
ExpressionNode::Reduction {
array, operation, ..
} => {
let array_complexity = self.calculate_node_complexity(array);
let reduction_complexity = match operation {
ReductionOperation::Sum | ReductionOperation::Count => 1,
ReductionOperation::Mean => 2,
ReductionOperation::Min | ReductionOperation::Max => 1,
ReductionOperation::Variance | ReductionOperation::StandardDeviation => 10,
_ => 5,
};
array_complexity * reduction_complexity
}
ExpressionNode::Conditional {
condition,
true_expr,
false_expr,
} => {
let condition_complexity = self.calculate_node_complexity(condition);
let true_complexity = self.calculate_node_complexity(true_expr);
let false_complexity = self.calculate_node_complexity(false_expr);
condition_complexity + true_complexity.max(false_complexity) + 2
}
}
}
fn analyze_vectorization(&mut self) {
self.metadata.is_vectorizable = self.is_node_vectorizable(&self.root);
}
fn is_node_vectorizable(&self, node: &ExpressionNode) -> bool {
match node {
ExpressionNode::Constant(_) => true,
ExpressionNode::Variable { .. } => true,
ExpressionNode::ArrayAccess { .. } => true,
ExpressionNode::BinaryOp {
left,
right,
operator,
} => {
let vectorizable_ops = [
BinaryOperator::Add,
BinaryOperator::Subtract,
BinaryOperator::Multiply,
BinaryOperator::Divide,
];
vectorizable_ops.contains(operator)
&& self.is_node_vectorizable(left)
&& self.is_node_vectorizable(right)
}
ExpressionNode::UnaryOp { operand, operator } => {
let vectorizable_ops = [
UnaryOperator::Negate,
UnaryOperator::Abs,
UnaryOperator::Sqrt,
];
vectorizable_ops.contains(operator) && self.is_node_vectorizable(operand)
}
ExpressionNode::Reduction { .. } => true, _ => false, }
}
fn analyze_parallelization(&mut self) {
self.metadata.is_parallelizable = self.is_node_parallelizable(&self.root);
}
fn is_node_parallelizable(&self, node: &ExpressionNode) -> bool {
match node {
ExpressionNode::Constant(_) => true,
ExpressionNode::Variable { .. } => true,
ExpressionNode::ArrayAccess { .. } => true,
ExpressionNode::BinaryOp { left, right, .. } => {
self.is_node_parallelizable(left) && self.is_node_parallelizable(right)
}
ExpressionNode::UnaryOp { operand, .. } => self.is_node_parallelizable(operand),
ExpressionNode::Reduction { .. } => true, ExpressionNode::FunctionCall { arguments, .. } => {
arguments.iter().all(|arg| self.is_node_parallelizable(arg))
}
ExpressionNode::Conditional { .. } => false, }
}
fn estimate_memory_usage(&mut self) {
self.metadata.estimated_memory_usage = self.estimate_node_memory(&self.root);
}
fn estimate_node_memory(&self, node: &ExpressionNode) -> usize {
match node {
ExpressionNode::Constant(_) => 8, ExpressionNode::Variable { var_type, .. } => {
let base_size = match var_type.as_str() {
"f64" => 8,
"f32" => 4,
"i64" => 8,
"i32" => 4,
_ => 8,
};
if let Some(var_info) = self.variables.get(&format!("{}", var_type)) {
if let Some(shape) = &var_info.shape {
let total_elements: usize = shape.iter().product();
base_size * total_elements
} else {
base_size
}
} else {
base_size
}
}
ExpressionNode::ArrayAccess { array, index } => {
self.estimate_node_memory(array) + self.estimate_node_memory(index)
}
ExpressionNode::BinaryOp { left, right, .. } => {
self.estimate_node_memory(left) + self.estimate_node_memory(right) + 8
}
ExpressionNode::UnaryOp { operand, .. } => self.estimate_node_memory(operand) + 8,
ExpressionNode::FunctionCall { arguments, .. } => {
arguments
.iter()
.map(|arg| self.estimate_node_memory(arg))
.sum::<usize>()
+ 64
}
ExpressionNode::Reduction { array, .. } => {
self.estimate_node_memory(array) + 64 }
ExpressionNode::Conditional {
condition,
true_expr,
false_expr,
} => {
self.estimate_node_memory(condition)
+ self
.estimate_node_memory(true_expr)
.max(self.estimate_node_memory(false_expr))
}
}
}
fn find_optimizations(&mut self) {
self.metadata.optimizations.clear();
let root_clone = self.root.clone();
self.find_node_optimizations(&root_clone, &mut Vec::new());
}
fn find_node_optimizations(&mut self, node: &ExpressionNode, path: &mut Vec<usize>) {
if self.can_constant_fold(node) {
self.metadata.optimizations.push(OptimizationOpportunity {
optimization_type: OptimizationType::ConstantFolding,
description: "Constant expression can be pre-computed".to_string(),
estimated_improvement: 0.2,
node_path: path.clone(),
});
}
if self.can_algebraic_simplify(node) {
self.metadata.optimizations.push(OptimizationOpportunity {
optimization_type: OptimizationType::AlgebraicSimplification,
description: "Expression can be algebraically simplified".to_string(),
estimated_improvement: 0.15,
node_path: path.clone(),
});
}
if self.metadata.is_vectorizable && self.can_vectorize(node) {
self.metadata.optimizations.push(OptimizationOpportunity {
optimization_type: OptimizationType::Vectorization,
description: "Operation can be vectorized with SIMD instructions".to_string(),
estimated_improvement: 0.4,
node_path: path.clone(),
});
}
match node {
ExpressionNode::ArrayAccess { array, index } => {
path.push(0);
self.find_node_optimizations(array, path);
path.pop();
path.push(1);
self.find_node_optimizations(index, path);
path.pop();
}
ExpressionNode::BinaryOp { left, right, .. } => {
path.push(0);
self.find_node_optimizations(left, path);
path.pop();
path.push(1);
self.find_node_optimizations(right, path);
path.pop();
}
ExpressionNode::UnaryOp { operand, .. } => {
path.push(0);
self.find_node_optimizations(operand, path);
path.pop();
}
ExpressionNode::FunctionCall { arguments, .. } => {
for (i, arg) in arguments.iter().enumerate() {
path.push(i);
self.find_node_optimizations(arg, path);
path.pop();
}
}
ExpressionNode::Reduction { array, .. } => {
path.push(0);
self.find_node_optimizations(array, path);
path.pop();
}
ExpressionNode::Conditional {
condition,
true_expr,
false_expr,
} => {
path.push(0);
self.find_node_optimizations(condition, path);
path.pop();
path.push(1);
self.find_node_optimizations(true_expr, path);
path.pop();
path.push(2);
self.find_node_optimizations(false_expr, path);
path.pop();
}
_ => {} }
}
fn can_constant_fold(&self, node: &ExpressionNode) -> bool {
match node {
ExpressionNode::BinaryOp { left, right, .. } => {
matches!(**left, ExpressionNode::Constant(_))
&& matches!(**right, ExpressionNode::Constant(_))
}
ExpressionNode::UnaryOp { operand, .. } => {
matches!(**operand, ExpressionNode::Constant(_))
}
_ => false,
}
}
fn can_algebraic_simplify(&self, node: &ExpressionNode) -> bool {
match node {
ExpressionNode::BinaryOp {
left,
right,
operator,
} => {
match operator {
BinaryOperator::Add => {
self.is_zero_constant(left) || self.is_zero_constant(right)
}
BinaryOperator::Multiply => {
self.is_zero_constant(left)
|| self.is_zero_constant(right)
|| self.is_one_constant(left)
|| self.is_one_constant(right)
}
BinaryOperator::Power => {
self.is_zero_constant(right) || self.is_one_constant(right)
}
_ => false,
}
}
_ => false,
}
}
fn can_vectorize(&self, node: &ExpressionNode) -> bool {
match node {
ExpressionNode::BinaryOp { operator, .. } => {
matches!(
operator,
BinaryOperator::Add
| BinaryOperator::Subtract
| BinaryOperator::Multiply
| BinaryOperator::Divide
)
}
ExpressionNode::UnaryOp { operator, .. } => {
matches!(
operator,
UnaryOperator::Negate | UnaryOperator::Abs | UnaryOperator::Sqrt
)
}
_ => false,
}
}
fn is_zero_constant(&self, node: &ExpressionNode) -> bool {
match node {
ExpressionNode::Constant(value) => value.to_f64() == 0.0,
_ => false,
}
}
fn is_one_constant(&self, node: &ExpressionNode) -> bool {
match node {
ExpressionNode::Constant(value) => value.to_f64() == 1.0,
_ => false,
}
}
pub fn optimize(&self) -> Result<ExpressionTree> {
let mut optimized_root = self.root.clone();
optimized_root = self.apply_constant_folding(optimized_root)?;
optimized_root = self.apply_algebraic_simplification(optimized_root)?;
optimized_root = self.apply_common_subexpression_elimination(optimized_root)?;
let mut optimized_tree = ExpressionTree::new(optimized_root);
optimized_tree.variables = self.variables.clone();
Ok(optimized_tree)
}
fn apply_constant_folding(&self, node: ExpressionNode) -> Result<ExpressionNode> {
match node {
ExpressionNode::BinaryOp {
left,
right,
operator,
} => {
let left = self.apply_constant_folding(*left)?;
let right = self.apply_constant_folding(*right)?;
if let (ExpressionNode::Constant(l_val), ExpressionNode::Constant(r_val)) =
(&left, &right)
{
let result = self.evaluate_binary_constant_op(l_val, r_val, operator)?;
Ok(ExpressionNode::Constant(result))
} else {
Ok(ExpressionNode::BinaryOp {
left: Box::new(left),
right: Box::new(right),
operator,
})
}
}
ExpressionNode::UnaryOp { operand, operator } => {
let operand = self.apply_constant_folding(*operand)?;
if let ExpressionNode::Constant(val) = &operand {
let result = self.evaluate_unary_constant_op(val, operator)?;
Ok(ExpressionNode::Constant(result))
} else {
Ok(ExpressionNode::UnaryOp {
operand: Box::new(operand),
operator,
})
}
}
_ => Ok(node), }
}
fn apply_algebraic_simplification(&self, node: ExpressionNode) -> Result<ExpressionNode> {
match node {
ExpressionNode::BinaryOp {
left,
right,
operator,
} => {
let left = self.apply_algebraic_simplification(*left)?;
let right = self.apply_algebraic_simplification(*right)?;
match operator {
BinaryOperator::Add => {
if self.is_zero_constant(&left) {
return Ok(right);
}
if self.is_zero_constant(&right) {
return Ok(left);
}
}
BinaryOperator::Multiply => {
if self.is_zero_constant(&left) || self.is_zero_constant(&right) {
return Ok(ExpressionNode::Constant(NumericValue::F64(0.0)));
}
if self.is_one_constant(&left) {
return Ok(right);
}
if self.is_one_constant(&right) {
return Ok(left);
}
}
BinaryOperator::Power => {
if self.is_zero_constant(&right) {
return Ok(ExpressionNode::Constant(NumericValue::F64(1.0)));
}
if self.is_one_constant(&right) {
return Ok(left);
}
}
_ => {}
}
Ok(ExpressionNode::BinaryOp {
left: Box::new(left),
right: Box::new(right),
operator,
})
}
_ => Ok(node),
}
}
fn apply_common_subexpression_elimination(
&self,
node: ExpressionNode,
) -> Result<ExpressionNode> {
Ok(node)
}
fn evaluate_binary_constant_op(
&self,
left: &NumericValue,
right: &NumericValue,
operator: BinaryOperator,
) -> Result<NumericValue> {
let l_val = left.to_f64();
let r_val = right.to_f64();
let result = match operator {
BinaryOperator::Add => l_val + r_val,
BinaryOperator::Subtract => l_val - r_val,
BinaryOperator::Multiply => l_val * r_val,
BinaryOperator::Divide => {
if r_val == 0.0 {
return Err(Error::InvalidOperation("Division by zero".to_string()));
}
l_val / r_val
}
BinaryOperator::Power => l_val.powf(r_val),
_ => {
return Err(Error::NotImplemented(format!(
"Constant evaluation for {:?}",
operator
)))
}
};
Ok(NumericValue::F64(result))
}
fn evaluate_unary_constant_op(
&self,
operand: &NumericValue,
operator: UnaryOperator,
) -> Result<NumericValue> {
let val = operand.to_f64();
let result = match operator {
UnaryOperator::Negate => -val,
UnaryOperator::Abs => val.abs(),
UnaryOperator::Sqrt => {
if val < 0.0 {
return Err(Error::InvalidOperation(
"Square root of negative number".to_string(),
));
}
val.sqrt()
}
UnaryOperator::Sin => val.sin(),
UnaryOperator::Cos => val.cos(),
UnaryOperator::Tan => val.tan(),
UnaryOperator::Log => {
if val <= 0.0 {
return Err(Error::InvalidOperation(
"Logarithm of non-positive number".to_string(),
));
}
val.ln()
}
UnaryOperator::Exp => val.exp(),
UnaryOperator::Floor => val.floor(),
UnaryOperator::Ceil => val.ceil(),
UnaryOperator::Round => val.round(),
_ => {
return Err(Error::NotImplemented(format!(
"Constant evaluation for {:?}",
operator
)))
}
};
Ok(NumericValue::F64(result))
}
pub fn to_string(&self) -> String {
self.node_to_string(&self.root)
}
fn node_to_string(&self, node: &ExpressionNode) -> String {
match node {
ExpressionNode::Constant(value) => format!("{}", value.to_f64()),
ExpressionNode::Variable { name, .. } => name.clone(),
ExpressionNode::ArrayAccess { array, index } => {
format!(
"{}[{}]",
self.node_to_string(array),
self.node_to_string(index)
)
}
ExpressionNode::BinaryOp {
left,
right,
operator,
} => {
format!(
"({} {} {})",
self.node_to_string(left),
operator,
self.node_to_string(right)
)
}
ExpressionNode::UnaryOp { operand, operator } => {
format!("{}({})", operator, self.node_to_string(operand))
}
ExpressionNode::FunctionCall {
function,
arguments,
} => {
let args: Vec<String> = arguments
.iter()
.map(|arg| self.node_to_string(arg))
.collect();
format!("{}({})", function, args.join(", "))
}
ExpressionNode::Reduction {
array,
operation,
axis,
} => {
let axis_str = if let Some(axis) = axis {
format!(", axis={})", axis)
} else {
")".to_string()
};
format!("{:?}({}{}", operation, self.node_to_string(array), axis_str)
}
ExpressionNode::Conditional {
condition,
true_expr,
false_expr,
} => {
format!(
"if {} then {} else {}",
self.node_to_string(condition),
self.node_to_string(true_expr),
self.node_to_string(false_expr)
)
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_expression_creation() {
let x = ExpressionNode::Variable {
name: "x".to_string(),
var_type: "f64".to_string(),
index: 0,
};
let constant = ExpressionNode::Constant(NumericValue::F64(2.0));
let expr = ExpressionNode::BinaryOp {
left: Box::new(x),
right: Box::new(constant),
operator: BinaryOperator::Multiply,
};
let tree = ExpressionTree::new(expr);
assert!(tree.metadata.complexity > 0);
}
#[test]
fn test_constant_folding() {
let left = ExpressionNode::Constant(NumericValue::F64(2.0));
let right = ExpressionNode::Constant(NumericValue::F64(3.0));
let expr = ExpressionNode::BinaryOp {
left: Box::new(left),
right: Box::new(right),
operator: BinaryOperator::Add,
};
let tree = ExpressionTree::new(expr);
let optimized = tree.optimize().expect("operation should succeed");
if let ExpressionNode::Constant(value) = optimized.root {
assert_eq!(value.to_f64(), 5.0);
} else {
panic!("Expected constant folded result");
}
}
#[test]
fn test_algebraic_simplification() {
let x = ExpressionNode::Variable {
name: "x".to_string(),
var_type: "f64".to_string(),
index: 0,
};
let zero = ExpressionNode::Constant(NumericValue::F64(0.0));
let expr = ExpressionNode::BinaryOp {
left: Box::new(x.clone()),
right: Box::new(zero),
operator: BinaryOperator::Add,
};
let tree = ExpressionTree::new(expr);
let optimized = tree.optimize().expect("operation should succeed");
if let ExpressionNode::Variable { name, .. } = optimized.root {
assert_eq!(name, "x");
} else {
panic!("Expected simplified result to be just variable x");
}
}
#[test]
fn test_expression_string_representation() {
let x = ExpressionNode::Variable {
name: "x".to_string(),
var_type: "f64".to_string(),
index: 0,
};
let constant = ExpressionNode::Constant(NumericValue::F64(2.0));
let expr = ExpressionNode::BinaryOp {
left: Box::new(x),
right: Box::new(constant),
operator: BinaryOperator::Multiply,
};
let tree = ExpressionTree::new(expr);
let string_repr = tree.to_string();
assert_eq!(string_repr, "(x * 2)");
}
}