kasuari/
term.rs

1use alloc::vec;
2use core::ops;
3
4use crate::{Expression, Variable};
5
6/// A variable and a coefficient to multiply that variable by.
7///
8/// This is a sub-expression in a constraint equation that represents:
9///
10/// ```text
11/// term = coefficient * variable
12/// ```
13#[derive(Copy, Clone, Debug, PartialEq)]
14pub struct Term {
15    pub variable: Variable,
16    pub coefficient: f64,
17}
18
19impl Term {
20    /// Construct a new Term from a variable and a coefficient.
21    #[inline]
22    pub const fn new(variable: Variable, coefficient: f64) -> Term {
23        Term {
24            variable,
25            coefficient,
26        }
27    }
28
29    /// Construct a new Term from a variable with a coefficient of 1.0.
30    #[inline]
31    pub const fn from_variable(variable: Variable) -> Term {
32        Term::new(variable, 1.0)
33    }
34}
35
36impl From<Variable> for Term {
37    #[inline]
38    fn from(variable: Variable) -> Term {
39        Term::from_variable(variable)
40    }
41}
42
43impl ops::Mul<f64> for Term {
44    type Output = Term;
45
46    #[inline]
47    fn mul(self, rhs: f64) -> Term {
48        Term::new(self.variable, self.coefficient * rhs)
49    }
50}
51
52impl ops::Mul<Term> for f64 {
53    type Output = Term;
54
55    #[inline]
56    fn mul(self, rhs: Term) -> Term {
57        Term::new(rhs.variable, self * rhs.coefficient)
58    }
59}
60
61impl ops::Mul<f32> for Term {
62    type Output = Term;
63
64    #[inline]
65    fn mul(self, rhs: f32) -> Term {
66        Term::new(self.variable, self.coefficient * rhs as f64)
67    }
68}
69
70impl ops::Mul<Term> for f32 {
71    type Output = Term;
72
73    #[inline]
74    fn mul(self, rhs: Term) -> Term {
75        Term::new(rhs.variable, self as f64 * rhs.coefficient)
76    }
77}
78
79impl ops::MulAssign<f64> for Term {
80    #[inline]
81    fn mul_assign(&mut self, rhs: f64) {
82        self.coefficient *= rhs;
83    }
84}
85
86impl ops::MulAssign<f32> for Term {
87    #[inline]
88    fn mul_assign(&mut self, rhs: f32) {
89        self.coefficient *= rhs as f64;
90    }
91}
92
93impl ops::Div<f64> for Term {
94    type Output = Term;
95
96    #[inline]
97    fn div(self, rhs: f64) -> Term {
98        Term::new(self.variable, self.coefficient / rhs)
99    }
100}
101impl ops::Div<f32> for Term {
102    type Output = Term;
103
104    #[inline]
105    fn div(self, rhs: f32) -> Term {
106        Term::new(self.variable, self.coefficient / rhs as f64)
107    }
108}
109
110impl ops::DivAssign<f64> for Term {
111    #[inline]
112    fn div_assign(&mut self, rhs: f64) {
113        self.coefficient /= rhs;
114    }
115}
116
117impl ops::DivAssign<f32> for Term {
118    #[inline]
119    fn div_assign(&mut self, rhs: f32) {
120        self.coefficient /= rhs as f64;
121    }
122}
123
124impl ops::Add<f64> for Term {
125    type Output = Expression;
126
127    #[inline]
128    fn add(self, rhs: f64) -> Expression {
129        Expression::new(vec![self], rhs)
130    }
131}
132
133impl ops::Add<f32> for Term {
134    type Output = Expression;
135
136    #[inline]
137    fn add(self, rhs: f32) -> Expression {
138        Expression::new(vec![self], rhs as f64)
139    }
140}
141
142impl ops::Add<Term> for f64 {
143    type Output = Expression;
144
145    #[inline]
146    fn add(self, rhs: Term) -> Expression {
147        Expression::new(vec![rhs], self)
148    }
149}
150
151impl ops::Add<Term> for f32 {
152    type Output = Expression;
153
154    #[inline]
155    fn add(self, rhs: Term) -> Expression {
156        Expression::new(vec![rhs], self as f64)
157    }
158}
159
160impl ops::Add<Term> for Term {
161    type Output = Expression;
162
163    #[inline]
164    fn add(self, rhs: Term) -> Expression {
165        Expression::from_terms(vec![self, rhs])
166    }
167}
168
169impl ops::Add<Expression> for Term {
170    type Output = Expression;
171
172    #[inline]
173    fn add(self, mut rhs: Expression) -> Expression {
174        rhs.terms.insert(0, self);
175        rhs
176    }
177}
178
179impl ops::Add<Term> for Expression {
180    type Output = Expression;
181
182    #[inline]
183    fn add(mut self, rhs: Term) -> Expression {
184        self.terms.push(rhs);
185        self
186    }
187}
188
189impl ops::AddAssign<Term> for Expression {
190    #[inline]
191    fn add_assign(&mut self, rhs: Term) {
192        self.terms.push(rhs);
193    }
194}
195
196impl ops::Neg for Term {
197    type Output = Term;
198
199    #[inline]
200    fn neg(mut self) -> Term {
201        self.coefficient = -self.coefficient;
202        self
203    }
204}
205
206impl ops::Sub<f64> for Term {
207    type Output = Expression;
208
209    #[inline]
210    fn sub(self, rhs: f64) -> Expression {
211        Expression::new(vec![self], -rhs)
212    }
213}
214
215impl ops::Sub<f32> for Term {
216    type Output = Expression;
217
218    #[inline]
219    fn sub(self, rhs: f32) -> Expression {
220        Expression::new(vec![self], -(rhs as f64))
221    }
222}
223
224impl ops::Sub<Term> for f64 {
225    type Output = Expression;
226
227    #[inline]
228    fn sub(self, rhs: Term) -> Expression {
229        Expression::new(vec![-rhs], self)
230    }
231}
232
233impl ops::Sub<Term> for f32 {
234    type Output = Expression;
235
236    #[inline]
237    fn sub(self, rhs: Term) -> Expression {
238        Expression::new(vec![-rhs], self as f64)
239    }
240}
241
242impl ops::Sub<Term> for Term {
243    type Output = Expression;
244
245    #[inline]
246    fn sub(self, rhs: Term) -> Expression {
247        Expression::from_terms(vec![self, -rhs])
248    }
249}
250
251impl ops::Sub<Expression> for Term {
252    type Output = Expression;
253
254    #[inline]
255    fn sub(self, mut rhs: Expression) -> Expression {
256        rhs = -rhs;
257        rhs.terms.insert(0, self);
258        rhs
259    }
260}
261
262impl ops::Sub<Term> for Expression {
263    type Output = Expression;
264
265    #[inline]
266    fn sub(mut self, rhs: Term) -> Expression {
267        self -= rhs;
268        self
269    }
270}
271
272impl ops::SubAssign<Term> for Expression {
273    #[inline]
274    fn sub_assign(&mut self, rhs: Term) {
275        self.terms.push(-rhs);
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    const LEFT: Variable = Variable::from_id(0);
    const RIGHT: Variable = Variable::from_id(1);
    const LEFT_TERM: Term = Term::from_variable(LEFT);
    const RIGHT_TERM: Term = Term::from_variable(RIGHT);

    /// Expected value for the scaling tests: `LEFT` paired with `coefficient`, built as a
    /// struct literal so that it does not depend on the constructors under test.
    const fn left_with(coefficient: f64) -> Term {
        Term {
            variable: LEFT,
            coefficient,
        }
    }

    #[test]
    fn new() {
        assert_eq!(Term::new(LEFT, 2.0), left_with(2.0));
    }

    #[test]
    fn from_variable() {
        assert_eq!(Term::from_variable(LEFT), left_with(1.0));
    }

    #[test]
    fn mul_f64() {
        let doubled = left_with(2.0);
        assert_eq!(LEFT_TERM * 2.0, doubled);
        assert_eq!(2.0 * LEFT_TERM, doubled);
    }

    #[test]
    fn mul_f32() {
        let doubled = left_with(2.0);
        assert_eq!(LEFT_TERM * 2.0f32, doubled);
        assert_eq!(2.0f32 * LEFT_TERM, doubled);
    }

    #[test]
    fn mul_assign_f64() {
        let mut term = LEFT_TERM;
        term *= 2.0;
        assert_eq!(term, left_with(2.0));
    }

    #[test]
    fn mul_assign_f32() {
        let mut term = LEFT_TERM;
        term *= 2.0f32;
        assert_eq!(term, left_with(2.0));
    }

    #[test]
    fn div_f64() {
        assert_eq!(LEFT_TERM / 2.0, left_with(0.5));
    }

    #[test]
    fn div_f32() {
        assert_eq!(LEFT_TERM / 2.0f32, left_with(0.5));
    }

    #[test]
    fn div_assign_f64() {
        let mut term = LEFT_TERM;
        term /= 2.0;
        assert_eq!(term, left_with(0.5));
    }

    #[test]
    fn div_assign_f32() {
        let mut term = LEFT_TERM;
        term /= 2.0f32;
        assert_eq!(term, left_with(0.5));
    }

    #[test]
    fn add_f64() {
        let expected = Expression::new(vec![LEFT_TERM], 2.0);
        assert_eq!(LEFT_TERM + 2.0, expected);
        assert_eq!(2.0 + LEFT_TERM, expected);
    }

    #[test]
    fn add_f32() {
        let expected = Expression::new(vec![LEFT_TERM], 2.0);
        assert_eq!(LEFT_TERM + 2.0f32, expected);
        assert_eq!(2.0f32 + LEFT_TERM, expected);
    }

    #[test]
    fn add_term() {
        let expected = Expression::from_terms(vec![LEFT_TERM, RIGHT_TERM]);
        assert_eq!(LEFT_TERM + RIGHT_TERM, expected);
    }

    #[test]
    fn add_expression() {
        let rhs = Expression::new(vec![RIGHT_TERM], 1.0);
        let expected = Expression::new(vec![LEFT_TERM, RIGHT_TERM], 1.0);
        assert_eq!(LEFT_TERM + rhs, expected);
    }

    #[test]
    fn sub_f64() {
        assert_eq!(LEFT_TERM - 2.0, Expression::new(vec![LEFT_TERM], -2.0));
        assert_eq!(2.0 - LEFT_TERM, Expression::new(vec![-LEFT_TERM], 2.0));
    }

    #[test]
    fn sub_f32() {
        assert_eq!(LEFT_TERM - 2.0f32, Expression::new(vec![LEFT_TERM], -2.0));
        assert_eq!(2.0f32 - LEFT_TERM, Expression::new(vec![-LEFT_TERM], 2.0));
    }

    #[test]
    fn sub_term() {
        let expected = Expression::from_terms(vec![LEFT_TERM, -RIGHT_TERM]);
        assert_eq!(LEFT_TERM - RIGHT_TERM, expected);
    }

    #[test]
    fn sub_expression() {
        let rhs = Expression::new(vec![RIGHT_TERM], 1.0);
        let expected = Expression::new(vec![LEFT_TERM, -RIGHT_TERM], -1.0);
        assert_eq!(LEFT_TERM - rhs, expected);
    }

    #[test]
    fn neg() {
        assert_eq!(-LEFT_TERM, left_with(-1.0));
    }
}
484        );
485    }
486}