1use std::collections::HashMap;
4use std::collections::hash_map::Entry;
5use std::convert::From;
6use std::ops::{Add,Mul,Neg,Sub};
7use num_rational::{Ratio,Rational32};
8use num_traits::{One,Zero};
9
10use ::{Coef,LinExpr,VarToken};
11
/// Implements `$LHS op $RHS` by swapping the operands and delegating to the
/// mirrored `$RHS op $LHS` implementation.
///
/// Two forms are accepted:
///
/// * `impl_commutative_op!(A + B)` emits `impl Add<B> for A`.
/// * `impl_commutative_op!(A * B)` emits `impl Mul<B> for A`.
///
/// Both generated impls have `Output = LinExpr`, so the mirrored impl
/// (`B + A` / `B * A`) must exist and must itself produce a `LinExpr`.
/// Operands are matched as `ident`, so only single-identifier type names can
/// be passed.
macro_rules! impl_commutative_op {
    ($LHS:ident + $RHS:ident) => {
        impl Add<$RHS> for $LHS {
            type Output = LinExpr;

            fn add(self, other: $RHS) -> LinExpr {
                // Operand order is flipped so the existing impl does the work.
                other + self
            }
        }
    };
    ($LHS:ident * $RHS:ident) => {
        impl Mul<$RHS> for $LHS {
            type Output = LinExpr;

            fn mul(self, other: $RHS) -> LinExpr {
                // Operand order is flipped so the existing impl does the work.
                other * self
            }
        }
    };
}
26
/// Implements `$LHS - $RHS` as `$LHS + (-$RHS)`.
///
/// `impl_subtract_op!(A - B)` emits `impl Sub<B> for A` with
/// `Output = LinExpr`.  It relies on `B` implementing `Neg` and on `A`
/// implementing `Add` for whatever type `-B` evaluates to, with that addition
/// producing a `LinExpr`.  Operands are matched as `ident`, so only
/// single-identifier type names can be passed.
macro_rules! impl_subtract_op {
    ($LHS:ident - $RHS:ident) => {
        impl Sub<$RHS> for $LHS {
            type Output = LinExpr;

            fn sub(self, other: $RHS) -> LinExpr {
                let negated = -other;
                self + negated
            }
        }
    };
}
35
/// A plain numeric constant that can be turned into a `Coef`.
///
/// This trait is the bound used by the generic operator impls in this file
/// (`LinExpr + T`, `LinExpr * T`, `VarToken + T`, ...).  The `Zero`
/// supertrait makes `is_zero()` available on the unconverted value, which
/// `LinExpr * T` uses to detect multiplication by zero.
pub trait IntoCoef: Zero {
    /// Consumes the value and returns it as a `Coef`.
    fn into_coef(self) -> Coef;
}
39
40impl IntoCoef for i32 {
41 fn into_coef(self) -> Coef { Ratio::from_integer(self) }
42}
43
44impl IntoCoef for Rational32 {
45 fn into_coef(self) -> Coef { self }
46}
47
48impl<T: IntoCoef> From<T> for LinExpr {
51 fn from(constant: T) -> Self {
52 LinExpr {
53 constant: constant.into_coef(),
54 coef: HashMap::new(),
55 }
56 }
57}
58
59impl From<VarToken> for LinExpr {
60 fn from(var: VarToken) -> Self {
61 let mut coef = HashMap::new();
62 coef.insert(var, Ratio::one());
63
64 LinExpr {
65 constant: Ratio::zero(),
66 coef: coef,
67 }
68 }
69}
70
71impl Neg for VarToken {
76 type Output = LinExpr;
77 fn neg(self) -> Self::Output {
78 -LinExpr::from(self)
79 }
80}
81
82impl<T: IntoCoef> Add<T> for VarToken {
83 type Output = LinExpr;
84 fn add(self, rhs: T) -> Self::Output {
85 LinExpr::from(self) + rhs
86 }
87}
88
// `constant + var`: forwarded to `var + constant` with the operands swapped.
impl_commutative_op!(i32 + VarToken);
impl_commutative_op!(Rational32 + VarToken);

// Subtraction between a variable and a constant, in both operand orders.
// Each expands to `lhs + (-rhs)`; for `constant - var` the negated variable is
// a `LinExpr`, so those two rely on the `constant + LinExpr` impls.
impl_subtract_op!(VarToken - i32);
impl_subtract_op!(i32 - VarToken);
impl_subtract_op!(VarToken - Rational32);
impl_subtract_op!(Rational32 - VarToken);
96
97impl<T: IntoCoef> Mul<T> for VarToken {
98 type Output = LinExpr;
99 fn mul(self, rhs: T) -> Self::Output {
100 LinExpr::from(self) * rhs
101 }
102}
103
// `constant * var`: forwarded to `var * constant` with the operands swapped.
impl_commutative_op!(i32 * VarToken);
impl_commutative_op!(Rational32 * VarToken);
106
107impl Add for VarToken {
112 type Output = LinExpr;
113 fn add(self, rhs: VarToken) -> Self::Output {
114 LinExpr::from(self) + LinExpr::from(rhs)
115 }
116}
117
// `var - var` expands to `var + (-var)`; the negated variable is a `LinExpr`,
// so this relies on the `VarToken + LinExpr` impl.
impl_subtract_op!(VarToken - VarToken);
119
120impl Neg for LinExpr {
125 type Output = LinExpr;
126 fn neg(self) -> Self::Output {
127 -1 * self
128 }
129}
130
131impl<T: IntoCoef> Add<T> for LinExpr {
132 type Output = LinExpr;
133 fn add(mut self, rhs: T) -> Self::Output {
134 self.constant = self.constant + rhs.into_coef();
135 self
136 }
137}
138
// `constant + expr`: forwarded to `expr + constant` with the operands swapped.
impl_commutative_op!(i32 + LinExpr);
impl_commutative_op!(Rational32 + LinExpr);

// Subtraction between an expression and a constant, in both operand orders.
// Each expands to `lhs + (-rhs)`.
impl_subtract_op!(LinExpr - i32);
impl_subtract_op!(i32 - LinExpr);
impl_subtract_op!(LinExpr - Rational32);
impl_subtract_op!(Rational32 - LinExpr);
146
147impl<T: IntoCoef> Mul<T> for LinExpr {
148 type Output = LinExpr;
149 fn mul(mut self, rhs: T) -> Self::Output {
150 if rhs.is_zero() {
151 self.constant = Ratio::zero();
152 self.coef = HashMap::new();
153 } else {
154 let rhs = rhs.into_coef();
155 if rhs != Ratio::one() {
156 self.constant = self.constant * rhs;
157 for coef in self.coef.values_mut() {
158 *coef = *coef * rhs;
159 }
160 }
161 }
162
163 self
164 }
165}
166
// `constant * expr`: forwarded to `expr * constant` with the operands swapped.
impl_commutative_op!(i32 * LinExpr);
impl_commutative_op!(Rational32 * LinExpr);
169
170impl Add<VarToken> for LinExpr {
175 type Output = LinExpr;
176 fn add(self, rhs: VarToken) -> Self::Output {
177 self + LinExpr::from(rhs)
178 }
179}
180
// `var + expr`: forwarded to `expr + var` with the operands swapped.
impl_commutative_op!(VarToken + LinExpr);

// Subtraction between an expression and a variable, in both operand orders.
// Each expands to `lhs + (-rhs)`.
impl_subtract_op!(LinExpr - VarToken);
impl_subtract_op!(VarToken - LinExpr);
185
186impl Add for LinExpr {
191 type Output = LinExpr;
192 fn add(mut self, mut rhs: LinExpr) -> Self::Output {
193 self.constant = self.constant + rhs.constant;
194
195 for (x2, a2) in rhs.coef.drain() {
196 match self.coef.entry(x2) {
197 Entry::Vacant(e) => {
198 e.insert(a2);
199 },
200 Entry::Occupied(mut e) => {
201 let new_coef = *e.get() + a2;
202 if new_coef.is_zero() {
203 e.remove();
204 } else {
205 *e.get_mut() = new_coef;
206 }
207 },
208 }
209 }
210
211 self
212 }
213}
214
// `expr - expr` expands to `expr + (-expr)`.
impl_subtract_op!(LinExpr - LinExpr);
216
#[cfg(test)]
mod tests {
    use num_rational::Ratio;
    use ::Puzzle;

    /// Exercises each supported operand pairing once.  Nothing is asserted:
    /// the test passes as long as every expression type-checks and evaluates
    /// without panicking.
    #[test]
    fn test_ops() {
        let mut sys = Puzzle::new();
        let a = sys.new_var();
        let b = sys.new_var();

        // variable (op) constant
        let _ = a + 1;
        let _ = a - 1;
        let _ = a * 1;
        let _ = a + Ratio::new(1, 2);
        let _ = a - Ratio::new(1, 2);
        let _ = a * Ratio::new(1, 2);

        // constant (op) variable
        let _ = 1 + a;
        let _ = 1 - a;
        let _ = 1 * a;
        let _ = Ratio::new(1, 2) + a;
        let _ = Ratio::new(1, 2) - a;
        let _ = Ratio::new(1, 2) * a;

        // negated variable, variable (op) variable
        let _ = -a;
        let _ = a + b;
        let _ = a - b;

        // expression (op) constant
        let _ = (a + b) + 1;
        let _ = (a + b) - 1;
        let _ = (a + b) * 1;
        let _ = (a + b) + Ratio::new(1, 2);
        let _ = (a + b) - Ratio::new(1, 2);
        let _ = (a + b) * Ratio::new(1, 2);

        // constant (op) expression
        let _ = 1 + (a + b);
        let _ = 1 - (a + b);
        let _ = 1 * (a + b);
        let _ = Ratio::new(1, 2) + (a + b);
        let _ = Ratio::new(1, 2) - (a + b);
        let _ = Ratio::new(1, 2) * (a + b);

        // expression (op) variable
        let _ = (a + 1) + b;
        let _ = (a + 1) - b;

        // variable (op) expression
        let _ = a + (b + 1);
        let _ = a - (b + 1);

        // negated expression, expression (op) expression
        let _ = -(a + b);
        let _ = (a + b) + (a + b);
        let _ = (a + b) - (a + b);
    }

    /// Terms whose coefficient ends up as zero must not remain in the
    /// coefficient map, whether they vanish through multiplication by zero or
    /// through cancellation.
    #[test]
    fn test_coef_zero() {
        let mut sys = Puzzle::new();
        let a = sys.new_var();
        let b = sys.new_var();

        let var_times_zero = a * 0;
        assert_eq!(var_times_zero.coef.len(), 0);

        let var_minus_itself = a - a;
        assert_eq!(var_minus_itself.coef.len(), 0);

        let sum_times_zero = (a + b) * 0;
        assert_eq!(sum_times_zero.coef.len(), 0);

        let sum_minus_itself = (a + b) - (a + b);
        assert_eq!(sum_minus_itself.coef.len(), 0);
    }
}