use std::ops::{Add, Div, Mul, Neg, Sub};
use crate::handle::Expr;
use crate::linear::{add_into, div_into, mul_into, neg_into, sub_into};
/// `Expr + Expr`: builds `self + rhs` by calling `add_into` on the arena the
/// two operands share.
///
/// # Panics
///
/// Panics if `self` and `rhs` were created in different arenas. Without the
/// check, `rhs.id` would be looked up in `self`'s arena and silently refer to
/// an unrelated (or out-of-range) entry.
impl<'a> Add for Expr<'a> {
    type Output = Self;

    fn add(self, rhs: Self) -> Self {
        // Both ids are resolved against `self.arena`, so they must share it.
        assert!(
            std::ptr::eq(self.arena, rhs.arena),
            "Expr + Expr: operands belong to different arenas"
        );
        let id = add_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
        Self::new(id, self.arena)
    }
}
/// `Expr - Expr`: builds `self - rhs` by calling `sub_into` on the arena the
/// two operands share. `self` is passed as the left operand.
///
/// # Panics
///
/// Panics if `self` and `rhs` were created in different arenas. Without the
/// check, `rhs.id` would be looked up in `self`'s arena and silently refer to
/// an unrelated (or out-of-range) entry.
impl<'a> Sub for Expr<'a> {
    type Output = Self;

    fn sub(self, rhs: Self) -> Self {
        // Both ids are resolved against `self.arena`, so they must share it.
        assert!(
            std::ptr::eq(self.arena, rhs.arena),
            "Expr - Expr: operands belong to different arenas"
        );
        let id = sub_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
        Self::new(id, self.arena)
    }
}
/// `Expr * Expr`: builds `self * rhs` by calling `mul_into` on the arena the
/// two operands share.
///
/// # Panics
///
/// Panics if `self` and `rhs` were created in different arenas. Without the
/// check, `rhs.id` would be looked up in `self`'s arena and silently refer to
/// an unrelated (or out-of-range) entry.
impl<'a> Mul for Expr<'a> {
    type Output = Self;

    fn mul(self, rhs: Self) -> Self {
        // Both ids are resolved against `self.arena`, so they must share it.
        assert!(
            std::ptr::eq(self.arena, rhs.arena),
            "Expr * Expr: operands belong to different arenas"
        );
        let id = mul_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
        Self::new(id, self.arena)
    }
}
/// `Expr / Expr`: builds `self / rhs` by calling `div_into` on the arena the
/// two operands share. `self` is passed as the dividend.
///
/// # Panics
///
/// Panics if `self` and `rhs` were created in different arenas. Without the
/// check, `rhs.id` would be looked up in `self`'s arena and silently refer to
/// an unrelated (or out-of-range) entry.
impl<'a> Div for Expr<'a> {
    type Output = Self;

    fn div(self, rhs: Self) -> Self {
        // Both ids are resolved against `self.arena`, so they must share it.
        assert!(
            std::ptr::eq(self.arena, rhs.arena),
            "Expr / Expr: operands belong to different arenas"
        );
        let id = div_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
        Self::new(id, self.arena)
    }
}
impl<'a> Neg for Expr<'a> {
type Output = Self;
fn neg(self) -> Self {
let id = neg_into(&mut self.arena.borrow_mut(), self.id);
Self::new(id, self.arena)
}
}
/// Implements `+`, `-`, `*` and `/` between [`Expr`] and a scalar type, with
/// the scalar on either side.
///
/// * `$scalar` — the scalar type to accept (for example `f64` or `i32`).
/// * `$to_f64` — an expression callable as `$to_f64(value)` that turns a
///   `$scalar` into the `f64` passed to the arena's `constant` method.
///
/// Each generated operator registers the scalar with `constant` on the
/// `Expr` operand's own arena. It then passes the returned id, together with
/// the expression's id, to the matching `*_into` helper on that same arena.
/// `scalar + expr` and `scalar * expr` forward to the `expr + scalar` and
/// `expr * scalar` forms. `scalar - expr` and `scalar / expr` keep the
/// scalar as the left operand.
macro_rules! impl_scalar_ops {
    ($scalar:ty, $to_f64:expr) => {
        impl<'a> Add<$scalar> for Expr<'a> {
            type Output = Expr<'a>;

            fn add(self, rhs: $scalar) -> Expr<'a> {
                let value = $to_f64(rhs);
                let arena = self.arena;
                let node = {
                    let mut guard = arena.borrow_mut();
                    let constant = guard.constant(value);
                    add_into(&mut guard, self.id, constant)
                };
                Expr::new(node, arena)
            }
        }

        impl<'a> Add<Expr<'a>> for $scalar {
            type Output = Expr<'a>;

            fn add(self, rhs: Expr<'a>) -> Expr<'a> {
                <Expr<'a> as Add<$scalar>>::add(rhs, self)
            }
        }

        impl<'a> Sub<$scalar> for Expr<'a> {
            type Output = Expr<'a>;

            fn sub(self, rhs: $scalar) -> Expr<'a> {
                let value = $to_f64(rhs);
                let arena = self.arena;
                let node = {
                    let mut guard = arena.borrow_mut();
                    let constant = guard.constant(value);
                    sub_into(&mut guard, self.id, constant)
                };
                Expr::new(node, arena)
            }
        }

        impl<'a> Sub<Expr<'a>> for $scalar {
            type Output = Expr<'a>;

            fn sub(self, rhs: Expr<'a>) -> Expr<'a> {
                // Not commutative: the scalar must stay on the left.
                let value = $to_f64(self);
                let arena = rhs.arena;
                let node = {
                    let mut guard = arena.borrow_mut();
                    let constant = guard.constant(value);
                    sub_into(&mut guard, constant, rhs.id)
                };
                Expr::new(node, arena)
            }
        }

        impl<'a> Mul<$scalar> for Expr<'a> {
            type Output = Expr<'a>;

            fn mul(self, rhs: $scalar) -> Expr<'a> {
                let value = $to_f64(rhs);
                let arena = self.arena;
                let node = {
                    let mut guard = arena.borrow_mut();
                    let constant = guard.constant(value);
                    mul_into(&mut guard, self.id, constant)
                };
                Expr::new(node, arena)
            }
        }

        impl<'a> Mul<Expr<'a>> for $scalar {
            type Output = Expr<'a>;

            fn mul(self, rhs: Expr<'a>) -> Expr<'a> {
                <Expr<'a> as Mul<$scalar>>::mul(rhs, self)
            }
        }

        impl<'a> Div<$scalar> for Expr<'a> {
            type Output = Expr<'a>;

            fn div(self, rhs: $scalar) -> Expr<'a> {
                let value = $to_f64(rhs);
                let arena = self.arena;
                let node = {
                    let mut guard = arena.borrow_mut();
                    let constant = guard.constant(value);
                    div_into(&mut guard, self.id, constant)
                };
                Expr::new(node, arena)
            }
        }

        impl<'a> Div<Expr<'a>> for $scalar {
            type Output = Expr<'a>;

            fn div(self, rhs: Expr<'a>) -> Expr<'a> {
                // Not commutative: the scalar must stay on the left.
                let value = $to_f64(self);
                let arena = rhs.arena;
                let node = {
                    let mut guard = arena.borrow_mut();
                    let constant = guard.constant(value);
                    div_into(&mut guard, constant, rhs.id)
                };
                Expr::new(node, arena)
            }
        }
    };
}
// Scalar types usable directly with `Expr` operators. `i32` is widened to
// `f64` through the std `From<i32>` impl; `f64` values are passed through
// unchanged.
impl_scalar_ops!(i32, <f64 as From<i32>>::from);
impl_scalar_ops!(f64, std::convert::identity);
/// Adds up an iterator of expressions with `+`, associating to the left:
/// `[a, b, c]` becomes `(a + b) + c`.
///
/// # Panics
///
/// Panics if the iterator is empty. There is then no `Expr` from which to
/// obtain an arena for the result.
impl<'a> std::iter::Sum for Expr<'a> {
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.reduce(|total, term| total + term)
            .expect("Expr::sum on empty iterator")
    }
}
/// Sums borrowed expressions by copying each (`Copy`) handle and delegating
/// to the by-value `Sum` implementation for `Expr`.
///
/// # Panics
///
/// Panics whenever the by-value `Sum` does (it does so for an empty iterator).
impl<'a, 'b> std::iter::Sum<&'b Expr<'a>> for Expr<'a> {
    fn sum<I: Iterator<Item = &'b Expr<'a>>>(iter: I) -> Self {
        <Self as std::iter::Sum>::sum(iter.copied())
    }
}
/// Builds the linear combination `exprs[0] * coeffs[0] + exprs[1] * coeffs[1] + ...`.
///
/// Each term uses the `Expr * f64` operator, and the terms are combined with
/// `Expr`'s `Sum` implementation (left to right).
///
/// # Panics
///
/// - If `exprs.len() != coeffs.len()`.
/// - If both slices are empty. An empty sum has no `Expr` from which to
///   obtain an arena for the result.
pub fn dot<'a>(exprs: &[Expr<'a>], coeffs: &[f64]) -> Expr<'a> {
    assert_eq!(
        exprs.len(),
        coeffs.len(),
        "dot: length mismatch (exprs.len() = {}, coeffs.len() = {})",
        exprs.len(),
        coeffs.len(),
    );
    // Fail with a dot-specific message rather than the generic one from `Sum`.
    assert!(!exprs.is_empty(), "dot: exprs and coeffs must not be empty");
    exprs.iter().zip(coeffs).map(|(&e, &c)| e * c).sum()
}