use crate::final_tagless::traits::NumericType;
use num_traits::Float;
/// Expression tree over a value type `T`.
///
/// Leaves are either literal constants or variables identified by index;
/// every other variant is an operation node that owns its operand
/// sub-expressions through `Box`. This definition only stores structure —
/// nothing here evaluates the tree.
#[derive(Debug, Clone, PartialEq)]
pub enum ASTRepr<T> {
    /// A literal value of the underlying type.
    Constant(T),
    /// A variable, identified by its index.
    Variable(usize),
    /// Binary node: addition of two sub-expressions.
    Add(Box<Self>, Box<Self>),
    /// Binary node: subtraction (first operand minus second).
    Sub(Box<Self>, Box<Self>),
    /// Binary node: multiplication of two sub-expressions.
    Mul(Box<Self>, Box<Self>),
    /// Binary node: division (first operand over second).
    Div(Box<Self>, Box<Self>),
    /// Binary node: power, stored as `(base, exponent)`.
    Pow(Box<Self>, Box<Self>),
    /// Unary node: negation.
    Neg(Box<Self>),
    /// Unary node: natural logarithm.
    Ln(Box<Self>),
    /// Unary node: exponential.
    Exp(Box<Self>),
    /// Unary node: square root.
    Sqrt(Box<Self>),
    /// Unary node: sine.
    Sin(Box<Self>),
    /// Unary node: cosine.
    Cos(Box<Self>),
}

impl<T> ASTRepr<T> {
    /// Returns the number of operation nodes in this expression.
    ///
    /// Each unary or binary node contributes exactly one; constants and
    /// variables contribute nothing.
    pub fn count_operations(&self) -> usize {
        let mut total = 0;
        // Explicit worklist of sub-expressions that still have to be visited.
        let mut pending: Vec<&Self> = vec![self];

        while let Some(node) = pending.pop() {
            match node {
                Self::Constant(_) | Self::Variable(_) => {}
                Self::Add(lhs, rhs)
                | Self::Sub(lhs, rhs)
                | Self::Mul(lhs, rhs)
                | Self::Div(lhs, rhs)
                | Self::Pow(lhs, rhs) => {
                    total += 1;
                    pending.push(lhs);
                    pending.push(rhs);
                }
                Self::Neg(operand)
                | Self::Ln(operand)
                | Self::Exp(operand)
                | Self::Sqrt(operand)
                | Self::Sin(operand)
                | Self::Cos(operand) => {
                    total += 1;
                    pending.push(operand);
                }
            }
        }

        total
    }

    /// Returns the index when this node is itself a [`ASTRepr::Variable`],
    /// and `None` for every other variant (sub-expressions are not searched).
    pub fn variable_index(&self) -> Option<usize> {
        if let Self::Variable(index) = self {
            Some(*index)
        } else {
            None
        }
    }

    /// Returns the number of summation operations in this expression.
    ///
    /// Unconditionally `0`: the enum as defined here has no summation variant.
    pub fn count_summation_operations(&self) -> usize {
        0
    }
}
/// Builder-style constructors for the power and elementary-function nodes.
///
/// Each operation comes in two flavours: a consuming version that moves
/// `self` into the new node, and a `_ref` version that clones its operands
/// and leaves the originals usable. All of them only build tree nodes; no
/// numeric evaluation happens here.
impl<T> ASTRepr<T>
where
    T: NumericType,
{
    /// Builds a [`ASTRepr::Pow`] node with `self` as the first operand and
    /// `exp` as the second, consuming both.
    #[must_use]
    pub fn pow(self, exp: Self) -> Self
    where
        T: Float,
    {
        Self::Pow(Box::new(self), Box::new(exp))
    }

    /// Like [`ASTRepr::pow`], but clones `self` and `exp` instead of
    /// consuming them.
    #[must_use]
    pub fn pow_ref(&self, exp: &Self) -> Self
    where
        T: Float,
    {
        self.clone().pow(exp.clone())
    }

    /// Wraps `self` in a [`ASTRepr::Ln`] node, consuming it.
    #[must_use]
    pub fn ln(self) -> Self
    where
        T: Float,
    {
        Self::Ln(Box::new(self))
    }

    /// Like [`ASTRepr::ln`], but wraps a clone of `self`.
    #[must_use]
    pub fn ln_ref(&self) -> Self
    where
        T: Float,
    {
        self.clone().ln()
    }

    /// Wraps `self` in a [`ASTRepr::Exp`] node, consuming it.
    #[must_use]
    pub fn exp(self) -> Self
    where
        T: Float,
    {
        Self::Exp(Box::new(self))
    }

    /// Like [`ASTRepr::exp`], but wraps a clone of `self`.
    #[must_use]
    pub fn exp_ref(&self) -> Self
    where
        T: Float,
    {
        self.clone().exp()
    }

    /// Wraps `self` in a [`ASTRepr::Sqrt`] node, consuming it.
    #[must_use]
    pub fn sqrt(self) -> Self
    where
        T: Float,
    {
        Self::Sqrt(Box::new(self))
    }

    /// Like [`ASTRepr::sqrt`], but wraps a clone of `self`.
    #[must_use]
    pub fn sqrt_ref(&self) -> Self
    where
        T: Float,
    {
        self.clone().sqrt()
    }

    /// Wraps `self` in a [`ASTRepr::Sin`] node, consuming it.
    #[must_use]
    pub fn sin(self) -> Self
    where
        T: Float,
    {
        Self::Sin(Box::new(self))
    }

    /// Like [`ASTRepr::sin`], but wraps a clone of `self`.
    #[must_use]
    pub fn sin_ref(&self) -> Self
    where
        T: Float,
    {
        self.clone().sin()
    }

    /// Wraps `self` in a [`ASTRepr::Cos`] node, consuming it.
    #[must_use]
    pub fn cos(self) -> Self
    where
        T: Float,
    {
        Self::Cos(Box::new(self))
    }

    /// Like [`ASTRepr::cos`], but wraps a clone of `self`.
    #[must_use]
    pub fn cos_ref(&self) -> Self
    where
        T: Float,
    {
        self.clone().cos()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Operation counting on a single binary node and on a two-level tree.
    #[test]
    fn test_ast_repr_basic_operations() {
        // Small factories so each expression owns freshly built operands.
        let var = |index| Box::new(ASTRepr::<f64>::Variable(index));
        let two = || Box::new(ASTRepr::<f64>::Constant(2.0));

        let sum = ASTRepr::Add(var(0), var(1));
        assert_eq!(sum.count_operations(), 1);

        let scaled = ASTRepr::Mul(var(0), two());
        assert_eq!(scaled.count_operations(), 1);

        // (x + y) * 2 holds one addition and one multiplication.
        let scaled_sum = ASTRepr::Mul(Box::new(sum), two());
        assert_eq!(scaled_sum.count_operations(), 2);
    }

    // `variable_index` yields the index for variables and `None` for constants.
    #[test]
    fn test_variable_index_access() {
        let variable = ASTRepr::<f64>::Variable(5);
        assert_eq!(variable.variable_index(), Some(5));

        let constant = ASTRepr::<f64>::Constant(42.0);
        assert_eq!(constant.variable_index(), None);
    }

    // The consuming builders produce the matching top-level variant.
    #[test]
    fn test_transcendental_functions() {
        let x = ASTRepr::<f64>::Variable(0);

        assert!(
            matches!(x.clone().sin(), ASTRepr::Sin(_)),
            "Expected sine expression"
        );
        assert!(
            matches!(x.clone().exp(), ASTRepr::Exp(_)),
            "Expected exponential expression"
        );
        assert!(
            matches!(x.ln(), ASTRepr::Ln(_)),
            "Expected natural logarithm expression"
        );
    }

    // The borrowing builders produce the matching top-level variant.
    #[test]
    fn test_convenience_methods() {
        let x = ASTRepr::<f64>::Variable(0);
        let two = ASTRepr::<f64>::Constant(2.0);

        assert!(
            matches!(x.pow_ref(&two), ASTRepr::Pow(_, _)),
            "Expected power expression"
        );
        assert!(
            matches!(x.sqrt_ref(), ASTRepr::Sqrt(_)),
            "Expected square root expression"
        );
    }
}