use super::ast_repr::ASTRepr;
use crate::final_tagless::traits::NumericType;
use std::ops::{Add, Div, Mul, Neg, Sub};
/// `ASTRepr<T> + ASTRepr<T>`: both operands are moved (no cloning) into a
/// new `ASTRepr::Add` node, left operand first.
impl<T: NumericType + Add<Output = T>> Add for ASTRepr<T> {
    type Output = Self;

    fn add(self, rhs: Self) -> Self {
        let (lhs, rhs) = (Box::new(self), Box::new(rhs));
        Self::Add(lhs, rhs)
    }
}
/// `&ASTRepr<T> + R`, where `R` is anything viewable as an `&ASTRepr<T>`
/// (e.g. an owned `ASTRepr<T>` or a reference to one).
///
/// Both operands are cloned into the new `ASTRepr::Add` node. The
/// right-hand side is cloned through `as_ref()` even when passed by value,
/// so `&a + b` copies `b` rather than moving it.
impl<T, R> Add<R> for &ASTRepr<T>
where
    T: NumericType + Add<Output = T>,
    R: AsRef<ASTRepr<T>>,
{
    type Output = ASTRepr<T>;

    fn add(self, rhs: R) -> Self::Output {
        let left = Box::new(self.clone());
        let right = Box::new(rhs.as_ref().clone());
        ASTRepr::Add(left, right)
    }
}
/// `ASTRepr<T> + &ASTRepr<T>`: the left operand is moved and the borrowed
/// right operand is cloned into a new `ASTRepr::Add` node.
impl<T> Add<&ASTRepr<T>> for ASTRepr<T>
where
    T: NumericType + Add<Output = T>,
{
    type Output = Self;

    fn add(self, rhs: &Self) -> Self {
        let rhs = rhs.clone();
        Self::Add(Box::new(self), Box::new(rhs))
    }
}
/// `ASTRepr<T> - ASTRepr<T>`: both operands are moved into a new
/// `ASTRepr::Sub` node, preserving the order `self - rhs`.
impl<T: NumericType + Sub<Output = T>> Sub for ASTRepr<T> {
    type Output = Self;

    fn sub(self, rhs: Self) -> Self {
        let (minuend, subtrahend) = (Box::new(self), Box::new(rhs));
        Self::Sub(minuend, subtrahend)
    }
}
/// `&ASTRepr<T> - R`, where `R` is anything viewable as an `&ASTRepr<T>`
/// (e.g. an owned `ASTRepr<T>` or a reference to one).
///
/// Both operands are cloned into the new `ASTRepr::Sub` node; an owned
/// right-hand side is cloned through `as_ref()` rather than moved.
impl<T, R> Sub<R> for &ASTRepr<T>
where
    T: NumericType + Sub<Output = T>,
    R: AsRef<ASTRepr<T>>,
{
    type Output = ASTRepr<T>;

    fn sub(self, rhs: R) -> Self::Output {
        let minuend = Box::new(self.clone());
        let subtrahend = Box::new(rhs.as_ref().clone());
        ASTRepr::Sub(minuend, subtrahend)
    }
}
/// `ASTRepr<T> - &ASTRepr<T>`: the left operand is moved and the borrowed
/// right operand is cloned into a new `ASTRepr::Sub` node.
impl<T> Sub<&ASTRepr<T>> for ASTRepr<T>
where
    T: NumericType + Sub<Output = T>,
{
    type Output = Self;

    fn sub(self, rhs: &Self) -> Self {
        let subtrahend = rhs.clone();
        Self::Sub(Box::new(self), Box::new(subtrahend))
    }
}
/// `ASTRepr<T> * ASTRepr<T>`: both operands are moved into a new
/// `ASTRepr::Mul` node, left operand first.
impl<T: NumericType + Mul<Output = T>> Mul for ASTRepr<T> {
    type Output = Self;

    fn mul(self, rhs: Self) -> Self {
        let (lhs, rhs) = (Box::new(self), Box::new(rhs));
        Self::Mul(lhs, rhs)
    }
}
/// `&ASTRepr<T> * R`, where `R` is anything viewable as an `&ASTRepr<T>`
/// (e.g. an owned `ASTRepr<T>` or a reference to one).
///
/// Both operands are cloned into the new `ASTRepr::Mul` node; an owned
/// right-hand side is cloned through `as_ref()` rather than moved.
impl<T, R> Mul<R> for &ASTRepr<T>
where
    T: NumericType + Mul<Output = T>,
    R: AsRef<ASTRepr<T>>,
{
    type Output = ASTRepr<T>;

    fn mul(self, rhs: R) -> Self::Output {
        let left = Box::new(self.clone());
        let right = Box::new(rhs.as_ref().clone());
        ASTRepr::Mul(left, right)
    }
}
/// `ASTRepr<T> * &ASTRepr<T>`: the left operand is moved and the borrowed
/// right operand is cloned into a new `ASTRepr::Mul` node.
impl<T> Mul<&ASTRepr<T>> for ASTRepr<T>
where
    T: NumericType + Mul<Output = T>,
{
    type Output = Self;

    fn mul(self, rhs: &Self) -> Self {
        let rhs = rhs.clone();
        Self::Mul(Box::new(self), Box::new(rhs))
    }
}
/// `ASTRepr<T> / ASTRepr<T>`: both operands are moved into a new
/// `ASTRepr::Div` node, preserving the order `self / rhs`.
impl<T: NumericType + Div<Output = T>> Div for ASTRepr<T> {
    type Output = Self;

    fn div(self, rhs: Self) -> Self {
        let (numerator, denominator) = (Box::new(self), Box::new(rhs));
        Self::Div(numerator, denominator)
    }
}
/// `&ASTRepr<T> / R`, where `R` is anything viewable as an `&ASTRepr<T>`
/// (e.g. an owned `ASTRepr<T>` or a reference to one).
///
/// Both operands are cloned into the new `ASTRepr::Div` node; an owned
/// right-hand side is cloned through `as_ref()` rather than moved.
impl<T, R> Div<R> for &ASTRepr<T>
where
    T: NumericType + Div<Output = T>,
    R: AsRef<ASTRepr<T>>,
{
    type Output = ASTRepr<T>;

    fn div(self, rhs: R) -> Self::Output {
        let numerator = Box::new(self.clone());
        let denominator = Box::new(rhs.as_ref().clone());
        ASTRepr::Div(numerator, denominator)
    }
}
/// `ASTRepr<T> / &ASTRepr<T>`: the left operand is moved and the borrowed
/// right operand is cloned into a new `ASTRepr::Div` node.
impl<T> Div<&ASTRepr<T>> for ASTRepr<T>
where
    T: NumericType + Div<Output = T>,
{
    type Output = Self;

    fn div(self, rhs: &Self) -> Self {
        let denominator = rhs.clone();
        Self::Div(Box::new(self), Box::new(denominator))
    }
}
/// `-ASTRepr<T>`: moves the operand into a new `ASTRepr::Neg` node.
impl<T: NumericType + Neg<Output = T>> Neg for ASTRepr<T> {
    type Output = Self;

    fn neg(self) -> Self {
        let operand = Box::new(self);
        Self::Neg(operand)
    }
}
/// `-&ASTRepr<T>`: clones the borrowed operand into a new `ASTRepr::Neg` node.
impl<T> Neg for &ASTRepr<T>
where
    T: NumericType + Neg<Output = T>,
{
    type Output = ASTRepr<T>;

    fn neg(self) -> Self::Output {
        let operand = Box::new(self.clone());
        ASTRepr::Neg(operand)
    }
}
/// Identity `AsRef`, so an owned `ASTRepr<T>` satisfies the
/// `R: AsRef<ASTRepr<T>>` bound used by the borrowed-left-operand operator
/// impls (references are covered by std's blanket `impl AsRef<U> for &T`).
impl<T> AsRef<ASTRepr<T>> for ASTRepr<T> {
    fn as_ref(&self) -> &Self {
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that a binary node's children are `Variable(0)` then
    /// `Variable(1)`, i.e. that the operator kept `x op y` in source order.
    fn assert_x_then_y(lhs: &ASTRepr<f64>, rhs: &ASTRepr<f64>, op: &str) {
        assert!(
            matches!(lhs, ASTRepr::Variable(0)) && matches!(rhs, ASTRepr::Variable(1)),
            "{}: operands out of order",
            op
        );
    }

    #[test]
    fn test_unified_operator_overloading() {
        let x_f64 = ASTRepr::<f64>::Variable(0);
        let y_f64 = ASTRepr::<f64>::Variable(1);
        let const_f64 = ASTRepr::<f64>::Constant(2.5);

        // Every owned/borrowed combination of `+` builds a single node.
        let expr1 = &x_f64 + &y_f64;
        let expr2 = x_f64.clone() + &y_f64;
        let expr3 = &x_f64 + y_f64.clone();
        let expr4 = x_f64.clone() + y_f64.clone();
        assert_eq!(expr1.count_operations(), 1);
        assert_eq!(expr2.count_operations(), 1);
        assert_eq!(expr3.count_operations(), 1);
        assert_eq!(expr4.count_operations(), 1);

        let expr_complex = &const_f64 * &x_f64 + &y_f64;
        assert_eq!(expr_complex.count_operations(), 2);
    }

    #[test]
    fn test_unified_operators_all_types() {
        let x = ASTRepr::<f64>::Variable(0);
        let y = ASTRepr::<f64>::Variable(1);
        let two = ASTRepr::<f64>::Constant(2.0);

        assert!(matches!(&x + &y, ASTRepr::Add(_, _)), "Expected Add");
        assert!(matches!(&x - &y, ASTRepr::Sub(_, _)), "Expected Sub");
        assert!(matches!(&x * &two, ASTRepr::Mul(_, _)), "Expected Mul");
        assert!(matches!(&x / &two, ASTRepr::Div(_, _)), "Expected Div");
        assert!(matches!(-&x, ASTRepr::Neg(_)), "Expected Neg");
    }

    #[test]
    fn test_operand_order_for_every_ownership_combination() {
        let x = ASTRepr::<f64>::Variable(0);
        let y = ASTRepr::<f64>::Variable(1);

        // Sub and Div are non-commutative, so a swapped operand pair would
        // silently produce a different expression.
        for expr in [&x + &y, x.clone() + &y, &x + y.clone(), x.clone() + y.clone()] {
            match expr {
                ASTRepr::Add(l, r) => assert_x_then_y(&l, &r, "Add"),
                _ => panic!("Expected Add"),
            }
        }
        for expr in [&x - &y, x.clone() - &y, &x - y.clone(), x.clone() - y.clone()] {
            match expr {
                ASTRepr::Sub(l, r) => assert_x_then_y(&l, &r, "Sub"),
                _ => panic!("Expected Sub"),
            }
        }
        for expr in [&x * &y, x.clone() * &y, &x * y.clone(), x.clone() * y.clone()] {
            match expr {
                ASTRepr::Mul(l, r) => assert_x_then_y(&l, &r, "Mul"),
                _ => panic!("Expected Mul"),
            }
        }
        for expr in [&x / &y, x.clone() / &y, &x / y.clone(), x.clone() / y.clone()] {
            match expr {
                ASTRepr::Div(l, r) => assert_x_then_y(&l, &r, "Div"),
                _ => panic!("Expected Div"),
            }
        }
    }

    #[test]
    fn test_complex_expression_building() {
        let x = ASTRepr::<f64>::Variable(0);
        let y = ASTRepr::<f64>::Variable(1);
        let two = ASTRepr::<f64>::Constant(2.0);
        let three = ASTRepr::<f64>::Constant(3.0);

        let expr = &two * &x + &three * &y;
        assert_eq!(expr.count_operations(), 3);

        let neg_expr = -(&two * &x + &three * &y);
        assert_eq!(neg_expr.count_operations(), 4);
    }

    #[test]
    fn test_generic_numeric_types() {
        let x_f32 = ASTRepr::<f32>::Variable(0);
        let y_f32 = ASTRepr::<f32>::Variable(1);
        let const_f32 = ASTRepr::<f32>::Constant(2.5_f32);
        let expr_f32 = &x_f32 + &y_f32 * &const_f32;
        assert_eq!(expr_f32.count_operations(), 2);

        let x_i32 = ASTRepr::<i32>::Variable(0);
        let const_i32 = ASTRepr::<i32>::Constant(42);
        let expr_i32 = &x_i32 + &const_i32;
        assert_eq!(expr_i32.count_operations(), 1);
    }
}