kizzasi_logic/constraint/
geometric.rs1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
9pub enum GeometricSet {
10 Box { lower: Vec<f32>, upper: Vec<f32> },
12 Ball { center: Vec<f32>, radius: f32 },
14 Ellipsoid {
16 center: Vec<f32>,
17 shape_inv: Vec<f32>,
19 },
20 Polytope {
22 a_matrix: Vec<f32>,
24 b_vector: Vec<f32>,
26 num_constraints: usize,
28 dimension: usize,
30 },
31 LInfBall { center: Vec<f32>, radius: f32 },
33 Simplex { dimension: usize },
35}
36
37impl GeometricSet {
38 pub fn box_constraint(lower: Vec<f32>, upper: Vec<f32>) -> Self {
40 assert_eq!(
41 lower.len(),
42 upper.len(),
43 "Lower and upper bounds must have same dimension"
44 );
45 Self::Box { lower, upper }
46 }
47
48 pub fn ball(center: Vec<f32>, radius: f32) -> Self {
50 assert!(radius > 0.0, "Radius must be positive");
51 Self::Ball { center, radius }
52 }
53
54 pub fn ellipsoid(center: Vec<f32>, shape_inv: Vec<f32>) -> Self {
56 let dim = center.len();
57 assert_eq!(shape_inv.len(), dim * dim, "Shape matrix must be dim × dim");
58 Self::Ellipsoid { center, shape_inv }
59 }
60
61 pub fn polytope(
63 a_matrix: Vec<f32>,
64 b_vector: Vec<f32>,
65 num_constraints: usize,
66 dimension: usize,
67 ) -> Self {
68 assert_eq!(
69 a_matrix.len(),
70 num_constraints * dimension,
71 "A matrix size mismatch"
72 );
73 assert_eq!(b_vector.len(), num_constraints, "b vector size mismatch");
74 Self::Polytope {
75 a_matrix,
76 b_vector,
77 num_constraints,
78 dimension,
79 }
80 }
81
82 pub fn l_inf_ball(center: Vec<f32>, radius: f32) -> Self {
84 assert!(radius > 0.0, "Radius must be positive");
85 Self::LInfBall { center, radius }
86 }
87
88 pub fn simplex(dimension: usize) -> Self {
90 Self::Simplex { dimension }
91 }
92
93 pub fn contains(&self, x: &[f32]) -> bool {
95 match self {
96 Self::Box { lower, upper } => x
97 .iter()
98 .zip(lower.iter())
99 .zip(upper.iter())
100 .all(|((&xi, &li), &ui)| xi >= li && xi <= ui),
101 Self::Ball { center, radius } => {
102 let dist_sq: f32 = x
103 .iter()
104 .zip(center.iter())
105 .map(|(&xi, &ci)| (xi - ci).powi(2))
106 .sum();
107 dist_sq <= radius * radius
108 }
109 Self::Ellipsoid { center, shape_inv } => {
110 let dim = center.len();
111 let diff: Vec<f32> = x
112 .iter()
113 .zip(center.iter())
114 .map(|(&xi, &ci)| xi - ci)
115 .collect();
116
117 let mut quad_form = 0.0;
119 for i in 0..dim {
120 for j in 0..dim {
121 quad_form += diff[i] * shape_inv[i * dim + j] * diff[j];
122 }
123 }
124 quad_form <= 1.0
125 }
126 Self::Polytope {
127 a_matrix,
128 b_vector,
129 num_constraints,
130 dimension,
131 } => {
132 for i in 0..*num_constraints {
133 let mut ax = 0.0;
134 for j in 0..*dimension {
135 ax += a_matrix[i * dimension + j] * x[j];
136 }
137 if ax > b_vector[i] {
138 return false;
139 }
140 }
141 true
142 }
143 Self::LInfBall { center, radius } => x
144 .iter()
145 .zip(center.iter())
146 .all(|(&xi, &ci)| (xi - ci).abs() <= *radius),
147 Self::Simplex { dimension } => {
148 if x.len() != *dimension {
149 return false;
150 }
151 let sum: f32 = x.iter().sum();
152 x.iter().all(|&xi| xi >= 0.0) && sum <= 1.0
153 }
154 }
155 }
156
157 pub fn distance(&self, x: &[f32]) -> f32 {
159 match self {
160 Self::Box { lower, upper } => x
161 .iter()
162 .zip(lower.iter())
163 .zip(upper.iter())
164 .map(|((&xi, &li), &ui)| {
165 if xi < li {
166 li - xi
167 } else if xi > ui {
168 xi - ui
169 } else {
170 0.0
171 }
172 })
173 .map(|d| d * d)
174 .sum::<f32>()
175 .sqrt(),
176 Self::Ball { center, radius } => {
177 let dist_sq: f32 = x
178 .iter()
179 .zip(center.iter())
180 .map(|(&xi, &ci)| (xi - ci).powi(2))
181 .sum();
182 let dist = dist_sq.sqrt();
183 (dist - radius).max(0.0)
184 }
185 Self::Ellipsoid { .. } => {
186 if self.contains(x) {
188 0.0
189 } else {
190 1.0
192 }
193 }
194 Self::Polytope { .. } => {
195 if self.contains(x) {
197 0.0
198 } else {
199 1.0
200 }
201 }
202 Self::LInfBall { center, radius } => {
203 let max_diff = x
204 .iter()
205 .zip(center.iter())
206 .map(|(&xi, &ci)| (xi - ci).abs())
207 .fold(0.0f32, |a, b| a.max(b));
208 (max_diff - radius).max(0.0)
209 }
210 Self::Simplex { dimension } => {
211 if x.len() != *dimension {
212 return f32::MAX;
213 }
214 let neg_sum: f32 = x.iter().filter(|&&xi| xi < 0.0).map(|&xi| -xi).sum();
215 let sum: f32 = x.iter().sum();
216 let excess = (sum - 1.0).max(0.0);
217 (neg_sum.powi(2) + excess.powi(2)).sqrt()
218 }
219 }
220 }
221
222 pub fn project(&self, x: &[f32]) -> Vec<f32> {
224 match self {
225 Self::Box { lower, upper } => x
226 .iter()
227 .zip(lower.iter())
228 .zip(upper.iter())
229 .map(|((&xi, &li), &ui)| xi.clamp(li, ui))
230 .collect(),
231 Self::Ball { center, radius } => {
232 let diff: Vec<f32> = x
233 .iter()
234 .zip(center.iter())
235 .map(|(&xi, &ci)| xi - ci)
236 .collect();
237 let dist_sq: f32 = diff.iter().map(|&d| d * d).sum();
238 let dist = dist_sq.sqrt();
239
240 if dist <= *radius {
241 x.to_vec()
242 } else {
243 center
244 .iter()
245 .zip(diff.iter())
246 .map(|(&ci, &di)| ci + di * radius / dist)
247 .collect()
248 }
249 }
250 Self::Ellipsoid { .. } | Self::Polytope { .. } => {
251 x.to_vec()
254 }
255 Self::LInfBall { center, radius } => x
256 .iter()
257 .zip(center.iter())
258 .map(|(&xi, &ci)| {
259 let diff = xi - ci;
260 ci + diff.clamp(-*radius, *radius)
261 })
262 .collect(),
263 Self::Simplex { dimension } => {
264 if x.len() != *dimension {
265 return x.to_vec();
266 }
267
268 let mut sorted: Vec<f32> = x.to_vec();
270 sorted.sort_by(|a, b| b.partial_cmp(a).unwrap());
271
272 let mut theta = 0.0;
273 let mut t_sum = 0.0;
274
275 for (i, &val) in sorted.iter().enumerate() {
276 t_sum += val;
277 let candidate = (t_sum - 1.0) / (i + 1) as f32;
278 if i + 1 == sorted.len() || sorted[i + 1] < val - candidate {
279 theta = candidate;
280 break;
281 }
282 }
283
284 x.iter().map(|&xi| (xi - theta).max(0.0)).collect()
285 }
286 }
287 }
288}
289
290#[derive(Debug, Clone, Serialize, Deserialize)]
292pub struct SetMembershipConstraint {
293 name: String,
294 set: GeometricSet,
295 weight: f32,
296}
297
298impl SetMembershipConstraint {
299 pub fn new(name: impl Into<String>, set: GeometricSet) -> Self {
301 Self {
302 name: name.into(),
303 set,
304 weight: 1.0,
305 }
306 }
307
308 pub fn with_weight(mut self, weight: f32) -> Self {
310 self.weight = weight;
311 self
312 }
313
314 pub fn check(&self, x: &[f32]) -> bool {
316 self.set.contains(x)
317 }
318
319 pub fn violation(&self, x: &[f32]) -> f32 {
321 self.set.distance(x)
322 }
323
324 pub fn project(&self, x: &[f32]) -> Vec<f32> {
326 self.set.project(x)
327 }
328
329 pub fn name(&self) -> &str {
331 &self.name
332 }
333
334 pub fn weight(&self) -> f32 {
336 self.weight
337 }
338
339 pub fn set(&self) -> &GeometricSet {
341 &self.set
342 }
343}