// numopt/model/node_cmp.rs

//! Trait for comparing expression nodes and constructing
//! optimization constraints.

4use num_traits::cast::ToPrimitive;
5
6use crate::model::node::Node;
7use crate::model::constant::ConstantScalar;
8use crate::model::constraint::{Constraint, ConstraintKind};
9
10/// Trait for comparing expression nodes.
11pub trait NodeCmp<T> {
12
13    /// Creates equality constraint and tags it.
14    fn equal_and_tag(&self, other: T, tag: &str) -> Constraint;
15
16    /// Creates equality constraint.
17    fn equal(&self, other: T) -> Constraint { self.equal_and_tag(other, "") }
18
19    /// Creates greater-than-or-equal constraint and tags it.
20    fn geq_and_tag(&self, other: T, tag: &str) -> Constraint;
21
22    /// Creates greater-than-or-equal constraint.
23    fn geq(&self, other: T) -> Constraint { self.geq_and_tag(other, "") }
24
25    /// Creates less-than-or-equal constraint and tags it.
26    fn leq_and_tag(&self, other: T, tag: &str) -> Constraint;
27
28    /// Creates less-than-or-equal constraint.
29    fn leq(&self, other: T) -> Constraint { self.leq_and_tag(other, "") }
30}
31
32macro_rules! impl_node_cmp_scalar {
33    ($x: ty, $y: ty) => {
34        impl NodeCmp<$y> for $x {
35    
36            fn equal_and_tag(&self, other: $y, tag: &str) -> Constraint {
37                Constraint::new(self.clone(),
38                                ConstraintKind::Equal,
39                                ConstantScalar::new(other.to_f64().unwrap()),
40                                tag)
41            }
42
43            fn geq_and_tag(&self, other: $y, tag: &str) -> Constraint {
44                Constraint::new(self.clone(),
45                                ConstraintKind::GreaterEqual,
46                                ConstantScalar::new(other.to_f64().unwrap()),
47                                tag)
48            }
49
50            fn leq_and_tag(&self, other: $y, tag: &str) -> Constraint {
51                Constraint::new(self.clone(),
52                                ConstraintKind::LessEqual,
53                                ConstantScalar::new(other.to_f64().unwrap()),
54                                tag)
55            }
56        }
57    };
58}
59
60impl_node_cmp_scalar!(Node, f64);
61
62macro_rules! impl_node_cmp_node {
63    ($x: ty, $y: ty) => {
64        impl NodeCmp<$y> for $x {
65    
66            fn equal_and_tag(&self, other: $y, tag: &str) -> Constraint {
67                Constraint::new(self.clone(),
68                                ConstraintKind::Equal,
69                                other.clone(),
70                                tag)
71            }
72
73            fn geq_and_tag(&self, other: $y, tag: &str) -> Constraint {
74                Constraint::new(self.clone(),
75                                ConstraintKind::GreaterEqual,
76                                other.clone(),
77                                tag)
78            }
79
80            fn leq_and_tag(&self, other: $y, tag: &str) -> Constraint {
81                Constraint::new(self.clone(),
82                                ConstraintKind::LessEqual,
83                                other.clone(),
84                                tag)
85            }
86        }
87    };
88}
89
90impl_node_cmp_node!(Node, Node);
91impl_node_cmp_node!(Node, &Node);
92
93macro_rules! impl_scalar_cmp_node {
94    ($x: ty, $y: ty) => {
95        impl NodeCmp<$y> for $x {
96    
97            fn equal_and_tag(&self, other: $y, tag: &str) -> Constraint {
98                Constraint::new(ConstantScalar::new(self.to_f64().unwrap()),
99                                ConstraintKind::Equal,
100                                other.clone(),
101                                tag)
102            }
103
104            fn geq_and_tag(&self, other: $y, tag: &str) -> Constraint {
105                Constraint::new(ConstantScalar::new(self.to_f64().unwrap()),
106                                ConstraintKind::GreaterEqual,
107                                other.clone(),
108                                tag)
109            }
110
111            fn leq_and_tag(&self, other: $y, tag: &str) -> Constraint {
112                Constraint::new(ConstantScalar::new(self.to_f64().unwrap()),
113                                ConstraintKind::LessEqual,
114                                other.clone(),
115                                tag)
116            }
117        }
118    };
119}
120
121impl_scalar_cmp_node!(f64, Node);
122impl_scalar_cmp_node!(f64, &Node);
123
#[cfg(test)]
mod tests {
    use crate::model::node_cmp::NodeCmp;
    use crate::model::variable::VariableScalar;
    use crate::model::constant::ConstantScalar;

    /// Node compared against borrowed and owned nodes.
    #[test]
    fn node_cmp_node() {
        let x = VariableScalar::new_continuous("x");
        let c = ConstantScalar::new(5.);

        // Borrowed node on the right-hand side.
        let z1 = x.equal(&c);
        assert_eq!(z1.to_string(), "x == 5");

        let z2 = x.leq(&c);
        assert_eq!(z2.to_string(), "x <= 5");

        // Owned (temporary) node on the right-hand side.
        let z3 = x.geq(&x + 3.);
        assert_eq!(z3.to_string(), "x >= x + 3");

        let z4 = x.leq(5. * &x);
        assert_eq!(z4.to_string(), "x <= 5*x");

        // Owned named node on the right-hand side (consumes `c`).
        let z5 = x.equal(c);
        assert_eq!(z5.to_string(), "x == 5");
    }

    /// Node compared against an f64 scalar, covering all three relations.
    #[test]
    fn node_cmp_scalar() {
        let x = VariableScalar::new_continuous("x");

        let z1 = x.equal(6.);
        assert_eq!(z1.to_string(), "x == 6");

        let z2 = x.equal(10.);
        assert_eq!(z2.to_string(), "x == 10");

        let z3 = (&x + 11.).equal(12.);
        assert_eq!(z3.to_string(), "x + 11 == 12");

        let z4 = x.leq(7.);
        assert_eq!(z4.to_string(), "x <= 7");

        let z5 = x.geq(8.);
        assert_eq!(z5.to_string(), "x >= 8");
    }

    /// f64 scalar on the left compared against borrowed and owned nodes.
    #[test]
    fn scalar_cmp_node() {
        let x = VariableScalar::new_continuous("x");

        let z1 = 4_f64.equal(&x);
        assert_eq!(z1.to_string(), "4 == x");

        let z2 = 4_f64.leq(&x + 3.);
        assert_eq!(z2.to_string(), "4 <= x + 3");

        let z3 = 5_f64.geq(&x * 5.);
        assert_eq!(z3.to_string(), "5 >= x*5");

        let z4 = 2_f64.geq(&x);
        assert_eq!(z4.to_string(), "2 >= x");
    }
}