Skip to main content

oximo_expr/
ops.rs

1use std::ops::{Add, Div, Mul, Neg, Sub};
2
3use crate::handle::Expr;
4use crate::linear::{add_into, div_into, mul_into, neg_into, sub_into};
5
6// -----------------------------------------------------------------------------
7// Expr <op> Expr
8// -----------------------------------------------------------------------------
9
10impl<'a> Add for Expr<'a> {
11    type Output = Self;
12    fn add(self, rhs: Self) -> Self {
13        let id = add_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
14        Self::new(id, self.arena)
15    }
16}
17
18impl<'a> Sub for Expr<'a> {
19    type Output = Self;
20    fn sub(self, rhs: Self) -> Self {
21        let id = sub_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
22        Self::new(id, self.arena)
23    }
24}
25
26impl<'a> Mul for Expr<'a> {
27    type Output = Self;
28    fn mul(self, rhs: Self) -> Self {
29        let id = mul_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
30        Self::new(id, self.arena)
31    }
32}
33
34impl<'a> Div for Expr<'a> {
35    type Output = Self;
36    fn div(self, rhs: Self) -> Self {
37        let id = div_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
38        Self::new(id, self.arena)
39    }
40}
41
42impl<'a> Neg for Expr<'a> {
43    type Output = Self;
44    fn neg(self) -> Self {
45        let id = neg_into(&mut self.arena.borrow_mut(), self.id);
46        Self::new(id, self.arena)
47    }
48}
49
50// -----------------------------------------------------------------------------
51// Expr <op> f64 / f64 <op> Expr, and the same for i32 because `2 * x`
52// without type annotation is the most common ergonomic case.
53// -----------------------------------------------------------------------------
54
55macro_rules! impl_scalar_ops {
56    ($scalar:ty, $to_f64:expr) => {
57        impl<'a> Add<$scalar> for Expr<'a> {
58            type Output = Self;
59            fn add(self, rhs: $scalar) -> Self {
60                let id = {
61                    let mut a = self.arena.borrow_mut();
62                    let rhs_id = a.constant($to_f64(rhs));
63                    add_into(&mut a, self.id, rhs_id)
64                };
65                Self::new(id, self.arena)
66            }
67        }
68
69        impl<'a> Add<Expr<'a>> for $scalar {
70            type Output = Expr<'a>;
71            fn add(self, rhs: Expr<'a>) -> Expr<'a> {
72                rhs + self
73            }
74        }
75
76        impl<'a> Sub<$scalar> for Expr<'a> {
77            type Output = Self;
78            fn sub(self, rhs: $scalar) -> Self {
79                let id = {
80                    let mut a = self.arena.borrow_mut();
81                    let rhs_id = a.constant($to_f64(rhs));
82                    sub_into(&mut a, self.id, rhs_id)
83                };
84                Self::new(id, self.arena)
85            }
86        }
87
88        impl<'a> Sub<Expr<'a>> for $scalar {
89            type Output = Expr<'a>;
90            fn sub(self, rhs: Expr<'a>) -> Expr<'a> {
91                let id = {
92                    let mut a = rhs.arena.borrow_mut();
93                    let lhs_id = a.constant($to_f64(self));
94                    sub_into(&mut a, lhs_id, rhs.id)
95                };
96                Expr::new(id, rhs.arena)
97            }
98        }
99
100        impl<'a> Mul<$scalar> for Expr<'a> {
101            type Output = Self;
102            fn mul(self, rhs: $scalar) -> Self {
103                let id = {
104                    let mut a = self.arena.borrow_mut();
105                    let rhs_id = a.constant($to_f64(rhs));
106                    mul_into(&mut a, self.id, rhs_id)
107                };
108                Self::new(id, self.arena)
109            }
110        }
111
112        impl<'a> Mul<Expr<'a>> for $scalar {
113            type Output = Expr<'a>;
114            fn mul(self, rhs: Expr<'a>) -> Expr<'a> {
115                rhs * self
116            }
117        }
118
119        impl<'a> Div<$scalar> for Expr<'a> {
120            type Output = Self;
121            fn div(self, rhs: $scalar) -> Self {
122                let id = {
123                    let mut a = self.arena.borrow_mut();
124                    let rhs_id = a.constant($to_f64(rhs));
125                    div_into(&mut a, self.id, rhs_id)
126                };
127                Self::new(id, self.arena)
128            }
129        }
130
131        impl<'a> Div<Expr<'a>> for $scalar {
132            type Output = Expr<'a>;
133            fn div(self, rhs: Expr<'a>) -> Expr<'a> {
134                let id = {
135                    let mut a = rhs.arena.borrow_mut();
136                    let lhs_id = a.constant($to_f64(self));
137                    div_into(&mut a, lhs_id, rhs.id)
138                };
139                Expr::new(id, rhs.arena)
140            }
141        }
142    };
143}
144
145impl_scalar_ops!(f64, core::convert::identity);
146impl_scalar_ops!(i32, f64::from);
147
148// -----------------------------------------------------------------------------
// std::iter::Sum: the first element of the iterator carries the arena handle,
// so no external zero is required. The flip side is that summing an empty
// iterator panics.
151// -----------------------------------------------------------------------------
152
153impl<'a> std::iter::Sum for Expr<'a> {
154    fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
155        let first = iter.next().expect("Expr::sum on empty iterator");
156        iter.fold(first, |acc, e| acc + e)
157    }
158}
159
160impl<'a, 'b> std::iter::Sum<&'b Expr<'a>> for Expr<'a> {
161    fn sum<I: Iterator<Item = &'b Expr<'a>>>(iter: I) -> Self {
162        iter.copied().sum()
163    }
164}
165
166/// Dot product of expressions with scalar coefficients: `sum_{i} c_i * e_i`.
167///
168/// Both arguments are slices. Pass owned containers by reference:
169/// `&vec`, `vec.as_slice()`, or `&array`.
170///
171/// # Panics
172/// Panics if `exprs` and `coeffs` have different lengths, or if `exprs`
173/// is empty (the result needs an arena handle).
174pub fn dot<'a>(exprs: &[Expr<'a>], coeffs: &[f64]) -> Expr<'a> {
175    assert_eq!(
176        exprs.len(),
177        coeffs.len(),
178        "dot: length mismatch (exprs.len() = {}, coeffs.len() = {})",
179        exprs.len(),
180        coeffs.len(),
181    );
182    exprs.iter().zip(coeffs).map(|(e, c)| *c * *e).sum()
183}