use super::{Expression, Number, MathConstant, BinaryOperator, UnaryOperator};
use super::memory::{SharedExpression, MemoryManager};
use std::collections::HashMap;
use num_traits::{ToPrimitive, Zero};
/// Builds [`SharedExpression`] trees, applying light algebraic simplification on
/// construction and reusing cached instances for a small set of leaf expressions.
pub struct ExpressionBuilder {
    // Every `SharedExpression` the builder creates is obtained from this manager.
    memory_manager: MemoryManager,
    // Cache of reusable leaf expressions, keyed by string: integer literals by their
    // decimal text (e.g. "0", "-1"), variables as "var_<name>", and constants under a
    // "const_" prefix.
    common_expressions: HashMap<String, SharedExpression>,
}
impl ExpressionBuilder {
    /// Cap on the number of cache entries. Once reached, [`Self::variable`] stops
    /// caching newly seen names (it still returns a fresh expression).
    const MAX_CACHED_EXPRESSIONS: usize = 1000;

    /// Creates a builder with a fresh [`MemoryManager`] and a cache pre-populated
    /// with common integers, variable names and the built-in constants.
    pub fn new() -> Self {
        let mut builder = Self {
            memory_manager: MemoryManager::new(),
            common_expressions: HashMap::new(),
        };
        builder.preload_common_expressions();
        builder
    }

    /// Cache key used for a variable named `name`.
    fn variable_cache_key(name: &str) -> String {
        format!("var_{}", name)
    }

    /// Cache key used for constant `c`.
    ///
    /// Both preloading and [`Self::constant`] must derive keys through this helper.
    /// Otherwise the preloaded entries are never found by lookups.
    fn constant_cache_key(c: &MathConstant) -> String {
        format!("const_{:?}", c)
    }

    /// Inserts a shared copy of `expr` under `key` only if `key` is not cached yet.
    ///
    /// Existing entries are left untouched, so callers that still hold a previously
    /// returned instance keep sharing it.
    fn seed_cache(&mut self, key: String, expr: Expression) {
        // Split the borrow of `self` so the closure can use the manager while the
        // map entry is held.
        let memory_manager = &mut self.memory_manager;
        self.common_expressions
            .entry(key)
            .or_insert_with(|| memory_manager.create_shared(expr));
    }

    /// Fills any missing cache entries for small integers, common variable names and
    /// the built-in constants. It only adds missing keys, so it is safe to call again,
    /// e.g. after [`Self::cleanup`].
    fn preload_common_expressions(&mut self) {
        let common_numbers = [0, 1, -1, 2, -2, 10];
        for &n in &common_numbers {
            let expr = Expression::Number(Number::integer(n));
            // Must match the lookup key in `number`: the integer's decimal text.
            self.seed_cache(n.to_string(), expr);
        }
        let common_vars = ["x", "y", "z", "t", "n"];
        for &var in &common_vars {
            let expr = Expression::Variable(var.to_string());
            self.seed_cache(Self::variable_cache_key(var), expr);
        }
        for constant in [MathConstant::Pi, MathConstant::E, MathConstant::I] {
            let key = Self::constant_cache_key(&constant);
            self.seed_cache(key, Expression::Constant(constant));
        }
    }

    /// Returns a numeric literal expression.
    ///
    /// Integers that fit in `i32` and are present in the cache (the preloaded small
    /// values) are returned as the cached shared instance. Other numbers are freshly
    /// allocated and not cached.
    pub fn number(&mut self, n: Number) -> SharedExpression {
        if let Number::Integer(ref big_int) = n {
            if let Some(small_int) = big_int.to_i32() {
                if let Some(cached) = self.common_expressions.get(&small_int.to_string()) {
                    return cached.clone_shared();
                }
            }
        }
        let expr = Expression::Number(n);
        self.memory_manager.create_shared(expr)
    }

    /// Returns the variable expression `name`, reusing a cached instance when one
    /// exists. New names are cached while the cache is below its size cap.
    pub fn variable(&mut self, name: &str) -> SharedExpression {
        let cache_key = Self::variable_cache_key(name);
        if let Some(cached) = self.common_expressions.get(&cache_key) {
            return cached.clone_shared();
        }
        let expr = Expression::Variable(name.to_string());
        let shared = self.memory_manager.create_shared(expr);
        if self.common_expressions.len() < Self::MAX_CACHED_EXPRESSIONS {
            self.common_expressions.insert(cache_key, shared.clone_shared());
        }
        shared
    }

    /// Returns the constant expression `c`, reusing (and if necessary creating) the
    /// cached instance.
    pub fn constant(&mut self, c: MathConstant) -> SharedExpression {
        let cache_key = Self::constant_cache_key(&c);
        if let Some(cached) = self.common_expressions.get(&cache_key) {
            return cached.clone_shared();
        }
        let expr = Expression::Constant(c);
        let shared = self.memory_manager.create_shared(expr);
        self.common_expressions.insert(cache_key, shared.clone_shared());
        shared
    }

    /// Builds `left <op> right`, returning a simplified expression when one of the
    /// identities in `try_simplify_binary_op` applies. Otherwise it returns a new
    /// `BinaryOp` node holding deep copies of both operands.
    pub fn binary_op(&mut self, op: BinaryOperator, left: SharedExpression, right: SharedExpression) -> SharedExpression {
        if let Some(simplified) = self.try_simplify_binary_op(&op, &left, &right) {
            return simplified;
        }
        let expr = Expression::BinaryOp {
            op,
            left: Box::new(left.as_ref().clone()),
            right: Box::new(right.as_ref().clone()),
        };
        self.memory_manager.create_shared(expr)
    }

    /// Builds `<op> operand`, returning a simplified expression when one of the
    /// identities in `try_simplify_unary_op` applies. Otherwise it returns a new
    /// `UnaryOp` node holding a deep copy of the operand.
    pub fn unary_op(&mut self, op: UnaryOperator, operand: SharedExpression) -> SharedExpression {
        if let Some(simplified) = self.try_simplify_unary_op(&op, &operand) {
            return simplified;
        }
        let expr = Expression::UnaryOp {
            op,
            operand: Box::new(operand.as_ref().clone()),
        };
        self.memory_manager.create_shared(expr)
    }

    /// Builds a call to function `name`, deep-copying each argument. No
    /// simplification is attempted.
    pub fn function(&mut self, name: &str, args: Vec<SharedExpression>) -> SharedExpression {
        let expr_args: Vec<Expression> = args.iter()
            .map(|arg| arg.as_ref().clone())
            .collect();
        let expr = Expression::Function {
            name: name.to_string(),
            args: expr_args,
        };
        self.memory_manager.create_shared(expr)
    }

    /// Applies identity-based simplifications for binary operators. Returns `None`
    /// when no rule matches.
    ///
    /// Rules: `0+x`, `x+0` → the other operand; `x+x` → `2*x`; `x-0` → `x`;
    /// `x-x` → `0`; `0*x`, `x*0` → `0`; `1*x`, `x*1` → the other operand;
    /// `-1*x`, `x*-1` → `-x`; `x/1` → `x`; `x/x` → `1` for non-zero literal `x`;
    /// `x^0` → `1` (including `0^0`, by convention); `x^1` → `x`; `1^x` → `1`.
    fn try_simplify_binary_op(&mut self, op: &BinaryOperator, left: &SharedExpression, right: &SharedExpression) -> Option<SharedExpression> {
        match op {
            BinaryOperator::Add => {
                if self.is_zero(left) {
                    return Some(right.clone_shared());
                }
                if self.is_zero(right) {
                    return Some(left.clone_shared());
                }
                if left == right {
                    let two = self.number(Number::integer(2));
                    return Some(self.binary_op(BinaryOperator::Multiply, two, left.clone_shared()));
                }
            }
            BinaryOperator::Subtract => {
                if self.is_zero(right) {
                    return Some(left.clone_shared());
                }
                if left == right {
                    return Some(self.number(Number::integer(0)));
                }
            }
            BinaryOperator::Multiply => {
                if self.is_zero(left) || self.is_zero(right) {
                    return Some(self.number(Number::integer(0)));
                }
                if self.is_one(left) {
                    return Some(right.clone_shared());
                }
                if self.is_one(right) {
                    return Some(left.clone_shared());
                }
                if self.is_negative_one(left) {
                    return Some(self.unary_op(UnaryOperator::Negate, right.clone_shared()));
                }
                if self.is_negative_one(right) {
                    return Some(self.unary_op(UnaryOperator::Negate, left.clone_shared()));
                }
            }
            BinaryOperator::Divide => {
                if self.is_one(right) {
                    return Some(left.clone_shared());
                }
                // `x / x` → 1, but a literal 0/0 is undefined and must not fold to 1.
                if left == right && !self.is_zero(left) {
                    return Some(self.number(Number::integer(1)));
                }
            }
            BinaryOperator::Power => {
                if self.is_zero(right) {
                    return Some(self.number(Number::integer(1)));
                }
                if self.is_one(right) {
                    return Some(left.clone_shared());
                }
                if self.is_one(left) {
                    return Some(self.number(Number::integer(1)));
                }
            }
            _ => {}
        }
        None
    }

    /// Applies identity-based simplifications for unary operators. Returns `None`
    /// when no rule matches.
    ///
    /// Rules: `--x` → `x` (as a newly allocated copy of `x`); `-0` → `0`;
    /// `+x` → `x`; `|0|` → `0`.
    fn try_simplify_unary_op(&mut self, op: &UnaryOperator, operand: &SharedExpression) -> Option<SharedExpression> {
        match op {
            UnaryOperator::Negate => {
                if let Expression::UnaryOp { op: UnaryOperator::Negate, operand: inner } = operand.as_ref() {
                    let inner_expr = Expression::clone(inner);
                    return Some(self.memory_manager.create_shared(inner_expr));
                }
                if self.is_zero(operand) {
                    return Some(operand.clone_shared());
                }
            }
            UnaryOperator::Plus => {
                return Some(operand.clone_shared());
            }
            UnaryOperator::Abs => {
                if self.is_zero(operand) {
                    return Some(operand.clone_shared());
                }
            }
            _ => {}
        }
        None
    }

    /// True if `expr` is a numeric literal equal to zero (any numeric representation).
    fn is_zero(&self, expr: &SharedExpression) -> bool {
        match expr.as_ref() {
            Expression::Number(Number::Integer(n)) => n.is_zero(),
            Expression::Number(Number::Rational(r)) => r.is_zero(),
            Expression::Number(Number::Real(r)) => r.is_zero(),
            Expression::Number(Number::Float(f)) => *f == 0.0,
            _ => false,
        }
    }

    /// True if `expr` is an integer, rational or float literal equal to one.
    // NOTE(review): unlike `is_zero`, `Number::Real` is not recognised here — confirm
    // whether that is intentional.
    fn is_one(&self, expr: &SharedExpression) -> bool {
        match expr.as_ref() {
            Expression::Number(Number::Integer(n)) => n == &num_bigint::BigInt::from(1),
            Expression::Number(Number::Rational(r)) => r == &num_rational::BigRational::from_integer(num_bigint::BigInt::from(1)),
            Expression::Number(Number::Float(f)) => *f == 1.0,
            _ => false,
        }
    }

    /// True if `expr` is an integer, rational or float literal equal to minus one.
    fn is_negative_one(&self, expr: &SharedExpression) -> bool {
        match expr.as_ref() {
            Expression::Number(Number::Integer(n)) => n == &num_bigint::BigInt::from(-1),
            Expression::Number(Number::Rational(r)) => r == &num_rational::BigRational::from_integer(num_bigint::BigInt::from(-1)),
            Expression::Number(Number::Float(f)) => *f == -1.0,
            _ => false,
        }
    }

    /// `left + right`, simplified where possible.
    pub fn add(&mut self, left: SharedExpression, right: SharedExpression) -> SharedExpression {
        self.binary_op(BinaryOperator::Add, left, right)
    }

    /// `left - right`, simplified where possible.
    pub fn subtract(&mut self, left: SharedExpression, right: SharedExpression) -> SharedExpression {
        self.binary_op(BinaryOperator::Subtract, left, right)
    }

    /// `left * right`, simplified where possible.
    pub fn multiply(&mut self, left: SharedExpression, right: SharedExpression) -> SharedExpression {
        self.binary_op(BinaryOperator::Multiply, left, right)
    }

    /// `left / right`, simplified where possible.
    pub fn divide(&mut self, left: SharedExpression, right: SharedExpression) -> SharedExpression {
        self.binary_op(BinaryOperator::Divide, left, right)
    }

    /// `base ^ exponent`, simplified where possible.
    pub fn power(&mut self, base: SharedExpression, exponent: SharedExpression) -> SharedExpression {
        self.binary_op(BinaryOperator::Power, base, exponent)
    }

    /// `-operand`, simplified where possible.
    pub fn negate(&mut self, operand: SharedExpression) -> SharedExpression {
        self.unary_op(UnaryOperator::Negate, operand)
    }

    /// Statistics reported by the underlying memory manager.
    pub fn memory_stats(&mut self) -> &super::memory::MemoryStats {
        self.memory_manager.get_stats()
    }

    /// Runs the memory manager's cleanup, then drops cache entries whose
    /// `ref_count()` is 1. Presumably those are held only by the cache — TODO confirm
    /// the `ref_count` semantics in `memory`.
    ///
    /// The preloaded entries are re-seeded afterwards. `number` never inserts into
    /// the cache, so an evicted small-integer entry would otherwise stay missing for
    /// the builder's lifetime.
    pub fn cleanup(&mut self) {
        self.memory_manager.cleanup();
        self.common_expressions.retain(|_, shared_expr| {
            shared_expr.ref_count() > 1
        });
        self.preload_common_expressions();
    }

    /// Mutable access to the underlying [`MemoryManager`].
    pub fn memory_manager(&mut self) -> &mut MemoryManager {
        &mut self.memory_manager
    }
}
impl Default for ExpressionBuilder {
fn default() -> Self {
Self::new()
}
}
/// Convenience facade over [`ExpressionBuilder`] with short constructor names
/// (`int`, `var`, `pi`, `add`, `sin`, ...). Every call is forwarded to the builder.
pub struct ExpressionFactory {
    // The builder that performs all construction, caching and simplification.
    builder: ExpressionBuilder,
}
impl ExpressionFactory {
pub fn new() -> Self {
Self {
builder: ExpressionBuilder::new(),
}
}
pub fn int(&mut self, value: i64) -> SharedExpression {
self.builder.number(Number::integer(value))
}
pub fn rational(&mut self, numerator: i64, denominator: i64) -> SharedExpression {
self.builder.number(Number::rational(numerator, denominator))
}
pub fn float(&mut self, value: f64) -> SharedExpression {
self.builder.number(Number::Float(value))
}
pub fn var(&mut self, name: &str) -> SharedExpression {
self.builder.variable(name)
}
pub fn pi(&mut self) -> SharedExpression {
self.builder.constant(MathConstant::Pi)
}
pub fn e(&mut self) -> SharedExpression {
self.builder.constant(MathConstant::E)
}
pub fn i(&mut self) -> SharedExpression {
self.builder.constant(MathConstant::I)
}
pub fn add(&mut self, left: SharedExpression, right: SharedExpression) -> SharedExpression {
self.builder.add(left, right)
}
pub fn sub(&mut self, left: SharedExpression, right: SharedExpression) -> SharedExpression {
self.builder.subtract(left, right)
}
pub fn mul(&mut self, left: SharedExpression, right: SharedExpression) -> SharedExpression {
self.builder.multiply(left, right)
}
pub fn div(&mut self, left: SharedExpression, right: SharedExpression) -> SharedExpression {
self.builder.divide(left, right)
}
pub fn pow(&mut self, base: SharedExpression, exponent: SharedExpression) -> SharedExpression {
self.builder.power(base, exponent)
}
pub fn sin(&mut self, operand: SharedExpression) -> SharedExpression {
self.builder.unary_op(UnaryOperator::Sin, operand)
}
pub fn cos(&mut self, operand: SharedExpression) -> SharedExpression {
self.builder.unary_op(UnaryOperator::Cos, operand)
}
pub fn ln(&mut self, operand: SharedExpression) -> SharedExpression {
self.builder.unary_op(UnaryOperator::Ln, operand)
}
pub fn exp(&mut self, operand: SharedExpression) -> SharedExpression {
self.builder.unary_op(UnaryOperator::Exp, operand)
}
pub fn memory_stats(&mut self) -> &super::memory::MemoryStats {
self.builder.memory_stats()
}
pub fn cleanup(&mut self) {
self.builder.cleanup();
}
}
impl Default for ExpressionFactory {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `x + 0` and `1 * x` both collapse back to `x`.
    #[test]
    fn test_expression_builder_basic() {
        let mut b = ExpressionBuilder::new();
        let x = b.variable("x");
        let unit = b.number(Number::integer(1));
        let nil = b.number(Number::integer(0));

        let sum = b.add(x.clone_shared(), nil);
        assert_eq!(sum.as_ref(), x.as_ref());

        let product = b.multiply(unit, x.clone_shared());
        assert_eq!(product.as_ref(), x.as_ref());
    }

    /// `2 * (x + pi)` keeps its shape: a product whose left side is an integer
    /// literal and whose right side is an addition.
    #[test]
    fn test_expression_factory() {
        let mut f = ExpressionFactory::new();
        let x = f.var("x");
        let two = f.int(2);
        let pi = f.pi();
        let inner = f.add(x, pi);
        let outer = f.mul(two, inner);

        if let Expression::BinaryOp { op: BinaryOperator::Multiply, left, right } = outer.as_ref() {
            assert!(matches!(left.as_ref(), Expression::Number(Number::Integer(_))));
            assert!(matches!(right.as_ref(), Expression::BinaryOp { op: BinaryOperator::Add, .. }));
        } else {
            panic!("期望乘法表达式");
        }
    }

    /// Repeated requests for a cached leaf yield equal expressions.
    #[test]
    fn test_common_expression_caching() {
        let mut b = ExpressionBuilder::new();

        let first_x = b.variable("x");
        let second_x = b.variable("x");
        assert_eq!(first_x, second_x);

        let first_one = b.number(Number::integer(1));
        let second_one = b.number(Number::integer(1));
        assert_eq!(first_one, second_one);
    }

    /// Identity rules: `x + 0 = x`, `x * 1 = x`, `x - x = 0`.
    #[test]
    fn test_algebraic_simplification() {
        let mut b = ExpressionBuilder::new();
        let x = b.variable("x");
        let nil = b.number(Number::integer(0));
        let unit = b.number(Number::integer(1));

        let plus_zero = b.add(x.clone_shared(), nil.clone_shared());
        assert_eq!(plus_zero.as_ref(), x.as_ref());

        let times_one = b.multiply(x.clone_shared(), unit.clone_shared());
        assert_eq!(times_one.as_ref(), x.as_ref());

        let difference = b.subtract(x.clone_shared(), x.clone_shared());
        assert_eq!(difference.as_ref(), nil.as_ref());
    }

    /// Smoke test: build many expressions, run cleanup, and print the statistics
    /// before and after (no assertions).
    #[test]
    fn test_memory_management() {
        let mut b = ExpressionBuilder::new();
        for idx in 0..1000 {
            let v = b.variable(&format!("x{}", idx));
            let literal = b.number(Number::integer(idx));
            let _sum = b.add(v, literal);
        }

        let before = b.memory_stats().clone();
        b.cleanup();
        let after = b.memory_stats().clone();

        println!("清理前: 活跃表达式 {}, 共享表达式 {}",
                 before.active_expressions, before.shared_expressions);
        println!("清理后: 活跃表达式 {}, 共享表达式 {}",
                 after.active_expressions, after.shared_expressions);
    }
}