puzzle_solver/
linexpr.rs

1//! Linear expressions.
2
3use std::collections::HashMap;
4use std::collections::hash_map::Entry;
5use std::convert::From;
6use std::ops::{Add,Mul,Neg,Sub};
7use num_rational::{Ratio,Rational32};
8use num_traits::{One,Zero};
9
10use ::{Coef,LinExpr,VarToken};
11
/// Implements a commutative operator (`+` or `*`) for `LHS op RHS` by
/// swapping the operands and delegating to the existing `RHS op LHS` impl.
///
/// Invoked as `impl_commutative_op!(i32 + VarToken);`; the generated impl
/// always has `Output = LinExpr`.
macro_rules! impl_commutative_op {
    // Internal arm: emit `impl $Trait<$rhs> for $lhs` that forwards to the
    // mirrored implementation with the operands swapped.
    (@swapped $Trait:ident, $method:ident, $lhs:ident, $rhs:ident) => {
        impl $Trait<$rhs> for $lhs {
            type Output = LinExpr;

            fn $method(self, other: $rhs) -> LinExpr {
                $Trait::$method(other, self)
            }
        }
    };
    ($lhs:ident + $rhs:ident) => {
        impl_commutative_op!(@swapped Add, add, $lhs, $rhs);
    };
    ($lhs:ident * $rhs:ident) => {
        impl_commutative_op!(@swapped Mul, mul, $lhs, $rhs);
    };
}
26
/// Implements `LHS - RHS` as `LHS + (-RHS)`, so subtraction reuses the
/// existing `Neg` and `Add` implementations instead of duplicating any
/// arithmetic. The generated impl always has `Output = LinExpr`.
macro_rules! impl_subtract_op {
    ($lhs:ident - $rhs:ident) => {
        impl Sub<$rhs> for $lhs {
            type Output = LinExpr;

            fn sub(self, subtrahend: $rhs) -> LinExpr {
                let negated = -subtrahend;
                self + negated
            }
        }
    };
}
35
36pub trait IntoCoef: Zero {
37    fn into_coef(self) -> Coef;
38}
39
40impl IntoCoef for i32 {
41    fn into_coef(self) -> Coef { Ratio::from_integer(self) }
42}
43
44impl IntoCoef for Rational32 {
45    fn into_coef(self) -> Coef { self }
46}
47
48/*--------------------------------------------------------------*/
49
50impl<T: IntoCoef> From<T> for LinExpr {
51    fn from(constant: T) -> Self {
52        LinExpr {
53            constant: constant.into_coef(),
54            coef: HashMap::new(),
55        }
56    }
57}
58
59impl From<VarToken> for LinExpr {
60    fn from(var: VarToken) -> Self {
61        let mut coef = HashMap::new();
62        coef.insert(var, Ratio::one());
63
64        LinExpr {
65            constant: Ratio::zero(),
66            coef: coef,
67        }
68    }
69}
70
71/*--------------------------------------------------------------*/
72/* Var-Coef                                                     */
73/*--------------------------------------------------------------*/
74
75impl Neg for VarToken {
76    type Output = LinExpr;
77    fn neg(self) -> Self::Output {
78        -LinExpr::from(self)
79    }
80}
81
82impl<T: IntoCoef> Add<T> for VarToken {
83    type Output = LinExpr;
84    fn add(self, rhs: T) -> Self::Output {
85        LinExpr::from(self) + rhs
86    }
87}
88
89impl_commutative_op!(i32 + VarToken);
90impl_commutative_op!(Rational32 + VarToken);
91
92impl_subtract_op!(VarToken - i32);
93impl_subtract_op!(i32 - VarToken);
94impl_subtract_op!(VarToken - Rational32);
95impl_subtract_op!(Rational32 - VarToken);
96
97impl<T: IntoCoef> Mul<T> for VarToken {
98    type Output = LinExpr;
99    fn mul(self, rhs: T) -> Self::Output {
100        LinExpr::from(self) * rhs
101    }
102}
103
104impl_commutative_op!(i32 * VarToken);
105impl_commutative_op!(Rational32 * VarToken);
106
107/*--------------------------------------------------------------*/
108/* Var-Var                                                      */
109/*--------------------------------------------------------------*/
110
111impl Add for VarToken {
112    type Output = LinExpr;
113    fn add(self, rhs: VarToken) -> Self::Output {
114        LinExpr::from(self) + LinExpr::from(rhs)
115    }
116}
117
118impl_subtract_op!(VarToken - VarToken);
119
120/*--------------------------------------------------------------*/
121/* Expr-Coef                                                    */
122/*--------------------------------------------------------------*/
123
124impl Neg for LinExpr {
125    type Output = LinExpr;
126    fn neg(self) -> Self::Output {
127        -1 * self
128    }
129}
130
131impl<T: IntoCoef> Add<T> for LinExpr {
132    type Output = LinExpr;
133    fn add(mut self, rhs: T) -> Self::Output {
134        self.constant = self.constant + rhs.into_coef();
135        self
136    }
137}
138
139impl_commutative_op!(i32 + LinExpr);
140impl_commutative_op!(Rational32 + LinExpr);
141
142impl_subtract_op!(LinExpr - i32);
143impl_subtract_op!(i32 - LinExpr);
144impl_subtract_op!(LinExpr - Rational32);
145impl_subtract_op!(Rational32 - LinExpr);
146
147impl<T: IntoCoef> Mul<T> for LinExpr {
148    type Output = LinExpr;
149    fn mul(mut self, rhs: T) -> Self::Output {
150        if rhs.is_zero() {
151            self.constant = Ratio::zero();
152            self.coef = HashMap::new();
153        } else {
154            let rhs = rhs.into_coef();
155            if rhs != Ratio::one() {
156                self.constant = self.constant * rhs;
157                for coef in self.coef.values_mut() {
158                    *coef = *coef * rhs;
159                }
160            }
161        }
162
163        self
164    }
165}
166
167impl_commutative_op!(i32 * LinExpr);
168impl_commutative_op!(Rational32 * LinExpr);
169
170/*--------------------------------------------------------------*/
171/* Expr-Var                                                     */
172/*--------------------------------------------------------------*/
173
174impl Add<VarToken> for LinExpr {
175    type Output = LinExpr;
176    fn add(self, rhs: VarToken) -> Self::Output {
177        self + LinExpr::from(rhs)
178    }
179}
180
181impl_commutative_op!(VarToken + LinExpr);
182
183impl_subtract_op!(LinExpr - VarToken);
184impl_subtract_op!(VarToken - LinExpr);
185
186/*--------------------------------------------------------------*/
187/* Expr-Expr                                                    */
188/*--------------------------------------------------------------*/
189
190impl Add for LinExpr {
191    type Output = LinExpr;
192    fn add(mut self, mut rhs: LinExpr) -> Self::Output {
193        self.constant = self.constant + rhs.constant;
194
195        for (x2, a2) in rhs.coef.drain() {
196            match self.coef.entry(x2) {
197                Entry::Vacant(e) => {
198                    e.insert(a2);
199                },
200                Entry::Occupied(mut e) => {
201                    let new_coef = *e.get() + a2;
202                    if new_coef.is_zero() {
203                        e.remove();
204                    } else {
205                        *e.get_mut() = new_coef;
206                    }
207                },
208            }
209        }
210
211        self
212    }
213}
214
215impl_subtract_op!(LinExpr - LinExpr);
216
217/*--------------------------------------------------------------*/
218
#[cfg(test)]
mod tests {
    use num_rational::Ratio;
    use ::Puzzle;

    /// Every supported operator combination must type-check.
    #[test]
    fn test_ops() {
        let mut puzzle = Puzzle::new();
        let x = puzzle.new_var();
        let y = puzzle.new_var();

        // expr = var + const;
        let _ = x + 1;
        let _ = x - 1;
        let _ = x * 1;
        let _ = x + Ratio::new(1, 2);
        let _ = x - Ratio::new(1, 2);
        let _ = x * Ratio::new(1, 2);

        // expr = const + var;
        let _ = 1 + x;
        let _ = 1 - x;
        let _ = 1 * x;
        let _ = Ratio::new(1, 2) + x;
        let _ = Ratio::new(1, 2) - x;
        let _ = Ratio::new(1, 2) * x;

        // expr = var + var;
        let _ = -x;
        let _ = x + y;
        let _ = x - y;

        // expr = expr + const;
        let _ = (x + y) + 1;
        let _ = (x + y) - 1;
        let _ = (x + y) * 1;
        let _ = (x + y) + Ratio::new(1, 2);
        let _ = (x + y) - Ratio::new(1, 2);
        let _ = (x + y) * Ratio::new(1, 2);

        // expr = const + expr;
        let _ = 1 + (x + y);
        let _ = 1 - (x + y);
        let _ = 1 * (x + y);
        let _ = Ratio::new(1, 2) + (x + y);
        let _ = Ratio::new(1, 2) - (x + y);
        let _ = Ratio::new(1, 2) * (x + y);

        // expr = expr + var;
        let _ = (x + 1) + y;
        let _ = (x + 1) - y;

        // expr = var + expr;
        let _ = x + (y + 1);
        let _ = x - (y + 1);

        // expr = expr + expr;
        let _ = -(x + y);
        let _ = (x + y) + (x + y);
        let _ = (x + y) - (x + y);
    }

    /// The operators must also produce the right constant and coefficients,
    /// not merely compile.
    #[test]
    fn test_values() {
        let mut puzzle = Puzzle::new();
        let x = puzzle.new_var();
        let y = puzzle.new_var();

        // var op const
        let expr = x + 1;
        assert_eq!(expr.constant, Ratio::from_integer(1));
        assert_eq!(expr.coef.len(), 1);
        assert_eq!(expr.coef[&x], Ratio::from_integer(1));

        let expr = x - 1;
        assert_eq!(expr.constant, Ratio::from_integer(-1));
        assert_eq!(expr.coef[&x], Ratio::from_integer(1));

        let expr = x * Ratio::new(1, 2);
        assert_eq!(expr.constant, Ratio::from_integer(0));
        assert_eq!(expr.coef[&x], Ratio::new(1, 2));

        // const op var
        let expr = 1 - x;
        assert_eq!(expr.constant, Ratio::from_integer(1));
        assert_eq!(expr.coef[&x], Ratio::from_integer(-1));

        // var op var
        let expr = x - y;
        assert_eq!(expr.constant, Ratio::from_integer(0));
        assert_eq!(expr.coef[&x], Ratio::from_integer(1));
        assert_eq!(expr.coef[&y], Ratio::from_integer(-1));

        let expr = x + x;
        assert_eq!(expr.coef.len(), 1);
        assert_eq!(expr.coef[&x], Ratio::from_integer(2));

        // expr op const
        let expr = (x + y) * Ratio::new(1, 2) + 3;
        assert_eq!(expr.constant, Ratio::from_integer(3));
        assert_eq!(expr.coef[&x], Ratio::new(1, 2));
        assert_eq!(expr.coef[&y], Ratio::new(1, 2));

        // const op expr
        let expr = Ratio::new(1, 2) - (x + y);
        assert_eq!(expr.constant, Ratio::new(1, 2));
        assert_eq!(expr.coef[&x], Ratio::from_integer(-1));
        assert_eq!(expr.coef[&y], Ratio::from_integer(-1));

        // expr op expr
        let expr = (2 * x + y) + (x - y);
        assert_eq!(expr.coef.len(), 1);
        assert_eq!(expr.coef[&x], Ratio::from_integer(3));

        let expr = -(x + 1);
        assert_eq!(expr.constant, Ratio::from_integer(-1));
        assert_eq!(expr.coef[&x], Ratio::from_integer(-1));
    }

    /// Coefficients that become zero must be removed from the map.
    #[test]
    fn test_coef_zero() {
        let mut puzzle = Puzzle::new();
        let x = puzzle.new_var();
        let y = puzzle.new_var();

        let expr = x * 0;
        assert_eq!(expr.coef.len(), 0);

        let expr = x - x;
        assert_eq!(expr.coef.len(), 0);

        let expr = (x + y) * 0;
        assert_eq!(expr.coef.len(), 0);

        let expr = (x + y) - (x + y);
        assert_eq!(expr.coef.len(), 0);

        // Partial cancellation keeps only the surviving term.
        let expr = (x + y) + (-1 * x);
        assert_eq!(expr.coef.len(), 1);
        assert_eq!(expr.coef[&y], Ratio::from_integer(1));

        // Multiplying by zero also clears the constant.
        let expr = (x + 3) * 0;
        assert_eq!(expr.coef.len(), 0);
        assert_eq!(expr.constant, Ratio::from_integer(0));
    }
}