1use std::ops::{Add, Div, Mul, Neg, Sub};
2
3use crate::arena::ExprId;
4use crate::handle::Expr;
5use crate::linear::{add_into, add_n, div_into, mul_into, neg_into, sub_into};
6
7impl<'a> Add for Expr<'a> {
12 type Output = Self;
13 fn add(self, rhs: Self) -> Self {
14 let id = add_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
15 Self::new(id, self.arena)
16 }
17}
18
19impl<'a> Sub for Expr<'a> {
20 type Output = Self;
21 fn sub(self, rhs: Self) -> Self {
22 let id = sub_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
23 Self::new(id, self.arena)
24 }
25}
26
27impl<'a> Mul for Expr<'a> {
28 type Output = Self;
29 fn mul(self, rhs: Self) -> Self {
30 let id = mul_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
31 Self::new(id, self.arena)
32 }
33}
34
35impl<'a> Div for Expr<'a> {
36 type Output = Self;
37 fn div(self, rhs: Self) -> Self {
38 let id = div_into(&mut self.arena.borrow_mut(), self.id, rhs.id);
39 Self::new(id, self.arena)
40 }
41}
42
43impl<'a> Neg for Expr<'a> {
44 type Output = Self;
45 fn neg(self) -> Self {
46 let id = neg_into(&mut self.arena.borrow_mut(), self.id);
47 Self::new(id, self.arena)
48 }
49}
50
/// Implements `+`, `-`, `*` and `/` between `Expr` and a plain scalar type,
/// with the scalar on either side of the operator.
///
/// `$scalar` is the scalar type. `$to_f64` is a callable that turns a
/// `$scalar` into the `f64` passed to the arena's `constant`. Every scalar
/// operand becomes a fresh constant node in the arena of the `Expr` operand.
macro_rules! impl_scalar_ops {
    ($scalar:ty, $to_f64:expr) => {
        // expr + scalar
        impl<'a> Add<$scalar> for Expr<'a> {
            type Output = Self;
            fn add(self, rhs: $scalar) -> Self {
                let mut guard = self.arena.borrow_mut();
                let k = guard.constant($to_f64(rhs));
                let node = add_into(&mut guard, self.id, k);
                // Release the arena borrow before handing the arena to `Expr::new`.
                drop(guard);
                Expr::new(node, self.arena)
            }
        }

        // scalar + expr: forwarded to `expr + scalar`, so the constant ends up
        // as the right-hand operand of the addition node.
        impl<'a> Add<Expr<'a>> for $scalar {
            type Output = Expr<'a>;
            fn add(self, rhs: Expr<'a>) -> Expr<'a> {
                <Expr<'a> as Add<$scalar>>::add(rhs, self)
            }
        }

        // expr - scalar
        impl<'a> Sub<$scalar> for Expr<'a> {
            type Output = Self;
            fn sub(self, rhs: $scalar) -> Self {
                let mut guard = self.arena.borrow_mut();
                let k = guard.constant($to_f64(rhs));
                let node = sub_into(&mut guard, self.id, k);
                drop(guard);
                Expr::new(node, self.arena)
            }
        }

        // scalar - expr: not commutative, so the constant is built explicitly
        // as the left-hand operand in `rhs`'s arena.
        impl<'a> Sub<Expr<'a>> for $scalar {
            type Output = Expr<'a>;
            fn sub(self, rhs: Expr<'a>) -> Expr<'a> {
                let mut guard = rhs.arena.borrow_mut();
                let k = guard.constant($to_f64(self));
                let node = sub_into(&mut guard, k, rhs.id);
                drop(guard);
                Expr::new(node, rhs.arena)
            }
        }

        // expr * scalar
        impl<'a> Mul<$scalar> for Expr<'a> {
            type Output = Self;
            fn mul(self, rhs: $scalar) -> Self {
                let mut guard = self.arena.borrow_mut();
                let k = guard.constant($to_f64(rhs));
                let node = mul_into(&mut guard, self.id, k);
                drop(guard);
                Expr::new(node, self.arena)
            }
        }

        // scalar * expr: forwarded to `expr * scalar`.
        impl<'a> Mul<Expr<'a>> for $scalar {
            type Output = Expr<'a>;
            fn mul(self, rhs: Expr<'a>) -> Expr<'a> {
                <Expr<'a> as Mul<$scalar>>::mul(rhs, self)
            }
        }

        // expr / scalar
        impl<'a> Div<$scalar> for Expr<'a> {
            type Output = Self;
            fn div(self, rhs: $scalar) -> Self {
                let mut guard = self.arena.borrow_mut();
                let k = guard.constant($to_f64(rhs));
                let node = div_into(&mut guard, self.id, k);
                drop(guard);
                Expr::new(node, self.arena)
            }
        }

        // scalar / expr: not commutative, so the constant is built explicitly
        // as the numerator in `rhs`'s arena.
        impl<'a> Div<Expr<'a>> for $scalar {
            type Output = Expr<'a>;
            fn div(self, rhs: Expr<'a>) -> Expr<'a> {
                let mut guard = rhs.arena.borrow_mut();
                let k = guard.constant($to_f64(self));
                let node = div_into(&mut guard, k, rhs.id);
                drop(guard);
                Expr::new(node, rhs.arena)
            }
        }
    };
}
145
146impl_scalar_ops!(f64, core::convert::identity);
147impl_scalar_ops!(i32, f64::from);
148
149impl<'a> std::iter::Sum for Expr<'a> {
155 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
156 let items: Vec<Self> = iter.collect();
157 let first = *items.first().expect("Expr::sum on empty iterator");
158 let ids: Vec<ExprId> = items.iter().map(|e| e.id).collect();
159 let id = add_n(&mut first.arena.borrow_mut(), &ids);
160 Self::new(id, first.arena)
161 }
162}
163
164impl<'a, 'b> std::iter::Sum<&'b Expr<'a>> for Expr<'a> {
165 fn sum<I: Iterator<Item = &'b Expr<'a>>>(iter: I) -> Self {
166 iter.copied().sum()
167 }
168}
169
170pub fn dot<'a>(exprs: &[Expr<'a>], coeffs: &[f64]) -> Expr<'a> {
179 assert_eq!(
180 exprs.len(),
181 coeffs.len(),
182 "dot: length mismatch (exprs.len() = {}, coeffs.len() = {})",
183 exprs.len(),
184 coeffs.len(),
185 );
186 exprs.iter().zip(coeffs).map(|(e, c)| *c * *e).sum()
187}