Skip to main content

oximo_expr/
ops.rs

1use std::ops::{Add, Div, Mul, Neg, Sub};
2
3use crate::arena::ExprId;
4use crate::handle::Expr;
5use crate::linear::{add_into, add_n, div_into, mul_into, neg_into, sub_into};
6
// -----------------------------------------------------------------------------
// Expr <op> Expr (binary operators) and unary negation
// -----------------------------------------------------------------------------
10
11impl<'a> Add for Expr<'a> {
12    type Output = Self;
13    fn add(self, rhs: Self) -> Self {
14        let id = add_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
15        Self::new(id, self.arena)
16    }
17}
18
19impl<'a> Sub for Expr<'a> {
20    type Output = Self;
21    fn sub(self, rhs: Self) -> Self {
22        let id = sub_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
23        Self::new(id, self.arena)
24    }
25}
26
27impl<'a> Mul for Expr<'a> {
28    type Output = Self;
29    fn mul(self, rhs: Self) -> Self {
30        let id = mul_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
31        Self::new(id, self.arena)
32    }
33}
34
35impl<'a> Div for Expr<'a> {
36    type Output = Self;
37    fn div(self, rhs: Self) -> Self {
38        let id = div_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
39        Self::new(id, self.arena)
40    }
41}
42
43impl<'a> Neg for Expr<'a> {
44    type Output = Self;
45    fn neg(self) -> Self {
46        let id = neg_into(&mut self.arena.borrow_mut(), self.id);
47        Self::new(id, self.arena)
48    }
49}
50
// -----------------------------------------------------------------------------
// Expr <op> scalar and scalar <op> Expr for f64, and the same for i32 because
// `2 * x` with an unannotated integer literal is the most common ergonomic
// case.
// -----------------------------------------------------------------------------
55
56macro_rules! impl_scalar_ops {
57    ($scalar:ty, $to_f64:expr) => {
58        impl<'a> Add<$scalar> for Expr<'a> {
59            type Output = Self;
60            fn add(self, rhs: $scalar) -> Self {
61                let id = {
62                    let mut a = self.arena.borrow_mut();
63                    let rhs_id = a.constant($to_f64(rhs));
64                    add_into(&mut a, self.id, rhs_id)
65                };
66                Self::new(id, self.arena)
67            }
68        }
69
70        impl<'a> Add<Expr<'a>> for $scalar {
71            type Output = Expr<'a>;
72            fn add(self, rhs: Expr<'a>) -> Expr<'a> {
73                rhs + self
74            }
75        }
76
77        impl<'a> Sub<$scalar> for Expr<'a> {
78            type Output = Self;
79            fn sub(self, rhs: $scalar) -> Self {
80                let id = {
81                    let mut a = self.arena.borrow_mut();
82                    let rhs_id = a.constant($to_f64(rhs));
83                    sub_into(&mut a, self.id, rhs_id)
84                };
85                Self::new(id, self.arena)
86            }
87        }
88
89        impl<'a> Sub<Expr<'a>> for $scalar {
90            type Output = Expr<'a>;
91            fn sub(self, rhs: Expr<'a>) -> Expr<'a> {
92                let id = {
93                    let mut a = rhs.arena.borrow_mut();
94                    let lhs_id = a.constant($to_f64(self));
95                    sub_into(&mut a, lhs_id, rhs.id)
96                };
97                Expr::new(id, rhs.arena)
98            }
99        }
100
101        impl<'a> Mul<$scalar> for Expr<'a> {
102            type Output = Self;
103            fn mul(self, rhs: $scalar) -> Self {
104                let id = {
105                    let mut a = self.arena.borrow_mut();
106                    let rhs_id = a.constant($to_f64(rhs));
107                    mul_into(&mut a, self.id, rhs_id)
108                };
109                Self::new(id, self.arena)
110            }
111        }
112
113        impl<'a> Mul<Expr<'a>> for $scalar {
114            type Output = Expr<'a>;
115            fn mul(self, rhs: Expr<'a>) -> Expr<'a> {
116                rhs * self
117            }
118        }
119
120        impl<'a> Div<$scalar> for Expr<'a> {
121            type Output = Self;
122            fn div(self, rhs: $scalar) -> Self {
123                let id = {
124                    let mut a = self.arena.borrow_mut();
125                    let rhs_id = a.constant($to_f64(rhs));
126                    div_into(&mut a, self.id, rhs_id)
127                };
128                Self::new(id, self.arena)
129            }
130        }
131
132        impl<'a> Div<Expr<'a>> for $scalar {
133            type Output = Expr<'a>;
134            fn div(self, rhs: Expr<'a>) -> Expr<'a> {
135                let id = {
136                    let mut a = rhs.arena.borrow_mut();
137                    let lhs_id = a.constant($to_f64(self));
138                    div_into(&mut a, lhs_id, rhs.id)
139                };
140                Expr::new(id, rhs.arena)
141            }
142        }
143    };
144}
145
146impl_scalar_ops!(f64, core::convert::identity);
147impl_scalar_ops!(i32, f64::from);
148
// -----------------------------------------------------------------------------
// std::iter::Sum: the first element of the iterator carries the arena handle,
// so no external zero is required; as a consequence, summing an empty iterator
// panics. The terms are collected into a single flat n-ary `Add`.
// -----------------------------------------------------------------------------
153
154impl<'a> std::iter::Sum for Expr<'a> {
155    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
156        let items: Vec<Self> = iter.collect();
157        let first = *items.first().expect("Expr::sum on empty iterator");
158        let ids: Vec<ExprId> = items.iter().map(|e| e.id).collect();
159        let id = add_n(&mut first.arena.borrow_mut(), &ids);
160        Self::new(id, first.arena)
161    }
162}
163
164impl<'a, 'b> std::iter::Sum<&'b Expr<'a>> for Expr<'a> {
165    fn sum<I: Iterator<Item = &'b Expr<'a>>>(iter: I) -> Self {
166        iter.copied().sum()
167    }
168}
169
170/// Dot product of expressions with scalar coefficients: `sum_{i} c_i * e_i`.
171///
172/// Both arguments are slices. Pass owned containers by reference:
173/// `&vec`, `vec.as_slice()`, or `&array`.
174///
175/// # Panics
176/// Panics if `exprs` and `coeffs` have different lengths, or if `exprs`
177/// is empty (the result needs an arena handle).
178pub fn dot<'a>(exprs: &[Expr<'a>], coeffs: &[f64]) -> Expr<'a> {
179    assert_eq!(
180        exprs.len(),
181        coeffs.len(),
182        "dot: length mismatch (exprs.len() = {}, coeffs.len() = {})",
183        exprs.len(),
184        coeffs.len(),
185    );
186    exprs.iter().zip(coeffs).map(|(e, c)| *c * *e).sum()
187}