Skip to main content

oximo_expr/
ops.rs

1use std::ops::{Add, Mul, Neg, Sub};
2
3use crate::handle::Expr;
4use crate::linear::{add_into, mul_into, neg_into, sub_into};
5
6// -----------------------------------------------------------------------------
7// Expr <op> Expr
8// -----------------------------------------------------------------------------
9
10impl<'a> Add for Expr<'a> {
11    type Output = Self;
12    fn add(self, rhs: Self) -> Self {
13        let id = add_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
14        Self::new(id, self.arena)
15    }
16}
17
18impl<'a> Sub for Expr<'a> {
19    type Output = Self;
20    fn sub(self, rhs: Self) -> Self {
21        let id = sub_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
22        Self::new(id, self.arena)
23    }
24}
25
26impl<'a> Mul for Expr<'a> {
27    type Output = Self;
28    fn mul(self, rhs: Self) -> Self {
29        let id = mul_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
30        Self::new(id, self.arena)
31    }
32}
33
34impl<'a> Neg for Expr<'a> {
35    type Output = Self;
36    fn neg(self) -> Self {
37        let id = neg_into(&mut self.arena.borrow_mut(), self.id);
38        Self::new(id, self.arena)
39    }
40}
41
42// -----------------------------------------------------------------------------
43// Expr <op> f64 / f64 <op> Expr, and the same for i32 because `2 * x`
44// without type annotation is the most common ergonomic case.
45// -----------------------------------------------------------------------------
46
47macro_rules! impl_scalar_ops {
48    ($scalar:ty) => {
49        impl<'a> Add<$scalar> for Expr<'a> {
50            type Output = Self;
51            fn add(self, rhs: $scalar) -> Self {
52                #[allow(clippy::cast_lossless)]
53                let id = {
54                    let mut a = self.arena.borrow_mut();
55                    let rhs_id = a.constant(rhs as f64);
56                    add_into(&mut a, self.id, rhs_id)
57                };
58                Self::new(id, self.arena)
59            }
60        }
61
62        impl<'a> Add<Expr<'a>> for $scalar {
63            type Output = Expr<'a>;
64            fn add(self, rhs: Expr<'a>) -> Expr<'a> {
65                rhs + self
66            }
67        }
68
69        impl<'a> Sub<$scalar> for Expr<'a> {
70            type Output = Self;
71            fn sub(self, rhs: $scalar) -> Self {
72                #[allow(clippy::cast_lossless)]
73                let id = {
74                    let mut a = self.arena.borrow_mut();
75                    let rhs_id = a.constant(rhs as f64);
76                    sub_into(&mut a, self.id, rhs_id)
77                };
78                Self::new(id, self.arena)
79            }
80        }
81
82        impl<'a> Sub<Expr<'a>> for $scalar {
83            type Output = Expr<'a>;
84            fn sub(self, rhs: Expr<'a>) -> Expr<'a> {
85                #[allow(clippy::cast_lossless)]
86                let id = {
87                    let mut a = rhs.arena.borrow_mut();
88                    let lhs_id = a.constant(self as f64);
89                    sub_into(&mut a, lhs_id, rhs.id)
90                };
91                Expr::new(id, rhs.arena)
92            }
93        }
94
95        impl<'a> Mul<$scalar> for Expr<'a> {
96            type Output = Self;
97            fn mul(self, rhs: $scalar) -> Self {
98                #[allow(clippy::cast_lossless)]
99                let id = {
100                    let mut a = self.arena.borrow_mut();
101                    let rhs_id = a.constant(rhs as f64);
102                    mul_into(&mut a, self.id, rhs_id)
103                };
104                Self::new(id, self.arena)
105            }
106        }
107
108        impl<'a> Mul<Expr<'a>> for $scalar {
109            type Output = Expr<'a>;
110            fn mul(self, rhs: Expr<'a>) -> Expr<'a> {
111                rhs * self
112            }
113        }
114    };
115}
116
117impl_scalar_ops!(f64);
118impl_scalar_ops!(i32);
119
120// -----------------------------------------------------------------------------
121// Sum support for `iter.sum::<Expr>()` would be nice but requires a starting
122// value tied to the arena. Provide a free function instead.
123// -----------------------------------------------------------------------------
124
125/// Sum a non-empty iterator of expressions sharing the same arena.
126///
127/// # Panics
128/// Panics if the iterator is empty.
129pub fn sum<'a, I: IntoIterator<Item = Expr<'a>>>(iter: I) -> Expr<'a> {
130    let mut it = iter.into_iter();
131    let first = it.next().expect("oximo_expr::sum on empty iterator");
132    it.fold(first, |acc, e| acc + e)
133}