1use num_traits::cast::ToPrimitive;
5
6use crate::model::node::Node;
7use crate::model::constant::ConstantScalar;
8use crate::model::constraint::{Constraint, ConstraintKind};
9
10pub trait NodeCmp<T> {
12
13 fn equal_and_tag(&self, other: T, tag: &str) -> Constraint;
15
16 fn equal(&self, other: T) -> Constraint { self.equal_and_tag(other, "") }
18
19 fn geq_and_tag(&self, other: T, tag: &str) -> Constraint;
21
22 fn geq(&self, other: T) -> Constraint { self.geq_and_tag(other, "") }
24
25 fn leq_and_tag(&self, other: T, tag: &str) -> Constraint;
27
28 fn leq(&self, other: T) -> Constraint { self.leq_and_tag(other, "") }
30}
31
32macro_rules! impl_node_cmp_scalar {
33 ($x: ty, $y: ty) => {
34 impl NodeCmp<$y> for $x {
35
36 fn equal_and_tag(&self, other: $y, tag: &str) -> Constraint {
37 Constraint::new(self.clone(),
38 ConstraintKind::Equal,
39 ConstantScalar::new(other.to_f64().unwrap()),
40 tag)
41 }
42
43 fn geq_and_tag(&self, other: $y, tag: &str) -> Constraint {
44 Constraint::new(self.clone(),
45 ConstraintKind::GreaterEqual,
46 ConstantScalar::new(other.to_f64().unwrap()),
47 tag)
48 }
49
50 fn leq_and_tag(&self, other: $y, tag: &str) -> Constraint {
51 Constraint::new(self.clone(),
52 ConstraintKind::LessEqual,
53 ConstantScalar::new(other.to_f64().unwrap()),
54 tag)
55 }
56 }
57 };
58}
59
60impl_node_cmp_scalar!(Node, f64);
61
62macro_rules! impl_node_cmp_node {
63 ($x: ty, $y: ty) => {
64 impl NodeCmp<$y> for $x {
65
66 fn equal_and_tag(&self, other: $y, tag: &str) -> Constraint {
67 Constraint::new(self.clone(),
68 ConstraintKind::Equal,
69 other.clone(),
70 tag)
71 }
72
73 fn geq_and_tag(&self, other: $y, tag: &str) -> Constraint {
74 Constraint::new(self.clone(),
75 ConstraintKind::GreaterEqual,
76 other.clone(),
77 tag)
78 }
79
80 fn leq_and_tag(&self, other: $y, tag: &str) -> Constraint {
81 Constraint::new(self.clone(),
82 ConstraintKind::LessEqual,
83 other.clone(),
84 tag)
85 }
86 }
87 };
88}
89
90impl_node_cmp_node!(Node, Node);
91impl_node_cmp_node!(Node, &Node);
92
93macro_rules! impl_scalar_cmp_node {
94 ($x: ty, $y: ty) => {
95 impl NodeCmp<$y> for $x {
96
97 fn equal_and_tag(&self, other: $y, tag: &str) -> Constraint {
98 Constraint::new(ConstantScalar::new(self.to_f64().unwrap()),
99 ConstraintKind::Equal,
100 other.clone(),
101 tag)
102 }
103
104 fn geq_and_tag(&self, other: $y, tag: &str) -> Constraint {
105 Constraint::new(ConstantScalar::new(self.to_f64().unwrap()),
106 ConstraintKind::GreaterEqual,
107 other.clone(),
108 tag)
109 }
110
111 fn leq_and_tag(&self, other: $y, tag: &str) -> Constraint {
112 Constraint::new(ConstantScalar::new(self.to_f64().unwrap()),
113 ConstraintKind::LessEqual,
114 other.clone(),
115 tag)
116 }
117 }
118 };
119}
120
121impl_scalar_cmp_node!(f64, Node);
122impl_scalar_cmp_node!(f64, &Node);
123
#[cfg(test)]
mod tests {

    use crate::model::node_cmp::NodeCmp;
    use crate::model::variable::VariableScalar;
    use crate::model::constant::ConstantScalar;

    /// Asserts that `value`, rendered through `Display`, equals `expected` exactly.
    fn assert_renders<D: std::fmt::Display>(value: D, expected: &str) {
        assert_eq!(format!("{}", value), expected);
    }

    #[test]
    fn node_cmp_node() {

        let var = VariableScalar::new_continuous("x");
        let five = ConstantScalar::new(5.);

        // Owned constraint against a borrowed constant node.
        assert_renders(var.equal(&five), "x == 5");

        // The same kind of comparison, rendered through a reference.
        assert_renders(&var.leq(&five), "x <= 5");

        // Right-hand sides built from arithmetic expressions (owned nodes).
        assert_renders(var.geq(&var + 3.), "x >= x + 3");
        assert_renders(&var.leq(5. * &var), "x <= 5*x");
    }

    #[test]
    fn node_cmp_scalar() {

        let var = VariableScalar::new_continuous("x");

        assert_renders(var.equal(6.), "x == 6");
        assert_renders(&var.equal(10.), "x == 10");

        // An expression node on the left-hand side.
        assert_renders((&var + 11.).equal(12.), "x + 11 == 12");
    }

    #[test]
    fn scalar_cmp_node() {

        let var = VariableScalar::new_continuous("x");

        assert_renders(4_f64.equal(&var), "4 == x");
        assert_renders(4_f64.leq(&var + 3.), "4 <= x + 3");
        assert_renders(5_f64.geq(&var * 5.), "5 >= x*5");
    }
}