1use std::ops::{Add, Div, Mul, Neg, Sub};
2
3use crate::handle::Expr;
4use crate::linear::{add_into, div_into, mul_into, neg_into, sub_into};
5
6impl<'a> Add for Expr<'a> {
11 type Output = Self;
12 fn add(self, rhs: Self) -> Self {
13 let id = add_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
14 Self::new(id, self.arena)
15 }
16}
17
18impl<'a> Sub for Expr<'a> {
19 type Output = Self;
20 fn sub(self, rhs: Self) -> Self {
21 let id = sub_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
22 Self::new(id, self.arena)
23 }
24}
25
26impl<'a> Mul for Expr<'a> {
27 type Output = Self;
28 fn mul(self, rhs: Self) -> Self {
29 let id = mul_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
30 Self::new(id, self.arena)
31 }
32}
33
34impl<'a> Div for Expr<'a> {
35 type Output = Self;
36 fn div(self, rhs: Self) -> Self {
37 let id = div_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
38 Self::new(id, self.arena)
39 }
40}
41
42impl<'a> Neg for Expr<'a> {
43 type Output = Self;
44 fn neg(self) -> Self {
45 let id = neg_into(&mut self.arena.borrow_mut(), self.id);
46 Self::new(id, self.arena)
47 }
48}
49
/// Implements mixed `Expr`/scalar arithmetic (`+`, `-`, `*`, `/`) in both
/// operand orders for one scalar type.
///
/// Usage: `impl_scalar_ops!(ScalarType, conversion_to_f64);`. The conversion
/// expression is called on the scalar, and the resulting `f64` is stored in
/// the arena via `constant` before the operation node is built.
///
/// For the commutative operators (`+`, `*`), the scalar-on-the-left form
/// forwards to the `Expr`-on-the-left impl, so `2.0 + x` is recorded as the
/// node `x + 2.0`. For `-` and `/`, the scalar stays on the left.
macro_rules! impl_scalar_ops {
    // `Expr op scalar`: materialise the scalar as a constant node, then combine.
    (@expr_op_scalar $scalar:ty, $to_f64:expr, $Op:ident, $method:ident, $build:ident) => {
        impl<'a> $Op<$scalar> for Expr<'a> {
            type Output = Self;
            fn $method(self, rhs: $scalar) -> Self {
                let mut guard = self.arena.borrow_mut();
                let rhs_node = guard.constant($to_f64(rhs));
                let node = $build(&mut guard, self.id, rhs_node);
                // Release the mutable borrow before building the new handle.
                drop(guard);
                Self::new(node, self.arena)
            }
        }
    };
    // `scalar op Expr` for commutative ops: swap operands and reuse the
    // `Expr op scalar` impl.
    (@scalar_op_expr_commuted $scalar:ty, $Op:ident, $method:ident) => {
        impl<'a> $Op<Expr<'a>> for $scalar {
            type Output = Expr<'a>;
            fn $method(self, rhs: Expr<'a>) -> Expr<'a> {
                $Op::$method(rhs, self)
            }
        }
    };
    // `scalar op Expr` for non-commutative ops: the scalar stays the left operand.
    (@scalar_op_expr $scalar:ty, $to_f64:expr, $Op:ident, $method:ident, $build:ident) => {
        impl<'a> $Op<Expr<'a>> for $scalar {
            type Output = Expr<'a>;
            fn $method(self, rhs: Expr<'a>) -> Expr<'a> {
                let mut guard = rhs.arena.borrow_mut();
                let lhs_node = guard.constant($to_f64(self));
                let node = $build(&mut guard, lhs_node, rhs.id);
                drop(guard);
                Expr::new(node, rhs.arena)
            }
        }
    };
    // Public entry point.
    ($scalar:ty, $to_f64:expr) => {
        impl_scalar_ops!(@expr_op_scalar $scalar, $to_f64, Add, add, add_into);
        impl_scalar_ops!(@scalar_op_expr_commuted $scalar, Add, add);

        impl_scalar_ops!(@expr_op_scalar $scalar, $to_f64, Sub, sub, sub_into);
        impl_scalar_ops!(@scalar_op_expr $scalar, $to_f64, Sub, sub, sub_into);

        impl_scalar_ops!(@expr_op_scalar $scalar, $to_f64, Mul, mul, mul_into);
        impl_scalar_ops!(@scalar_op_expr_commuted $scalar, Mul, mul);

        impl_scalar_ops!(@expr_op_scalar $scalar, $to_f64, Div, div, div_into);
        impl_scalar_ops!(@scalar_op_expr $scalar, $to_f64, Div, div, div_into);
    };
}
144
145impl_scalar_ops!(f64, core::convert::identity);
146impl_scalar_ops!(i32, f64::from);
147
148impl<'a> std::iter::Sum for Expr<'a> {
154 fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
155 let first = iter.next().expect("Expr::sum on empty iterator");
156 iter.fold(first, |acc, e| acc + e)
157 }
158}
159
160impl<'a, 'b> std::iter::Sum<&'b Expr<'a>> for Expr<'a> {
161 fn sum<I: Iterator<Item = &'b Expr<'a>>>(iter: I) -> Self {
162 iter.copied().sum()
163 }
164}
165
166pub fn dot<'a>(exprs: &[Expr<'a>], coeffs: &[f64]) -> Expr<'a> {
175 assert_eq!(
176 exprs.len(),
177 coeffs.len(),
178 "dot: length mismatch (exprs.len() = {}, coeffs.len() = {})",
179 exprs.len(),
180 coeffs.len(),
181 );
182 exprs.iter().zip(coeffs).map(|(e, c)| *c * *e).sum()
183}