// kizzasi_logic/constraint/geometric.rs

1use serde::{Deserialize, Serialize};
2
// ============================================================================
// Set Membership Constraints
// ============================================================================
6
7/// Geometric set types for membership constraints
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub enum GeometricSet {
10    /// Axis-aligned box: l <= x <= u
11    Box { lower: Vec<f32>, upper: Vec<f32> },
12    /// Euclidean ball: ||x - center||₂ <= radius
13    Ball { center: Vec<f32>, radius: f32 },
14    /// Ellipsoid: (x-c)ᵀ P (x-c) <= 1
15    Ellipsoid {
16        center: Vec<f32>,
17        /// Inverse covariance matrix (flattened)
18        shape_inv: Vec<f32>,
19    },
20    /// Polytope: Ax <= b
21    Polytope {
22        /// Constraint matrix A (row-major)
23        a_matrix: Vec<f32>,
24        /// Right-hand side b
25        b_vector: Vec<f32>,
26        /// Number of rows in A
27        num_constraints: usize,
28        /// Number of columns in A (dimension)
29        dimension: usize,
30    },
31    /// L-infinity ball: ||x - center||_∞ <= radius
32    LInfBall { center: Vec<f32>, radius: f32 },
33    /// Simplex: x_i >= 0, Σx_i <= 1
34    Simplex { dimension: usize },
35}
36
37impl GeometricSet {
38    /// Create a box constraint
39    pub fn box_constraint(lower: Vec<f32>, upper: Vec<f32>) -> Self {
40        assert_eq!(
41            lower.len(),
42            upper.len(),
43            "Lower and upper bounds must have same dimension"
44        );
45        Self::Box { lower, upper }
46    }
47
48    /// Create a ball constraint
49    pub fn ball(center: Vec<f32>, radius: f32) -> Self {
50        assert!(radius > 0.0, "Radius must be positive");
51        Self::Ball { center, radius }
52    }
53
54    /// Create an ellipsoid constraint
55    pub fn ellipsoid(center: Vec<f32>, shape_inv: Vec<f32>) -> Self {
56        let dim = center.len();
57        assert_eq!(shape_inv.len(), dim * dim, "Shape matrix must be dim × dim");
58        Self::Ellipsoid { center, shape_inv }
59    }
60
61    /// Create a polytope constraint Ax <= b
62    pub fn polytope(
63        a_matrix: Vec<f32>,
64        b_vector: Vec<f32>,
65        num_constraints: usize,
66        dimension: usize,
67    ) -> Self {
68        assert_eq!(
69            a_matrix.len(),
70            num_constraints * dimension,
71            "A matrix size mismatch"
72        );
73        assert_eq!(b_vector.len(), num_constraints, "b vector size mismatch");
74        Self::Polytope {
75            a_matrix,
76            b_vector,
77            num_constraints,
78            dimension,
79        }
80    }
81
82    /// Create an L-infinity ball
83    pub fn l_inf_ball(center: Vec<f32>, radius: f32) -> Self {
84        assert!(radius > 0.0, "Radius must be positive");
85        Self::LInfBall { center, radius }
86    }
87
88    /// Create a simplex constraint
89    pub fn simplex(dimension: usize) -> Self {
90        Self::Simplex { dimension }
91    }
92
93    /// Check if a point is in the set
94    pub fn contains(&self, x: &[f32]) -> bool {
95        match self {
96            Self::Box { lower, upper } => x
97                .iter()
98                .zip(lower.iter())
99                .zip(upper.iter())
100                .all(|((&xi, &li), &ui)| xi >= li && xi <= ui),
101            Self::Ball { center, radius } => {
102                let dist_sq: f32 = x
103                    .iter()
104                    .zip(center.iter())
105                    .map(|(&xi, &ci)| (xi - ci).powi(2))
106                    .sum();
107                dist_sq <= radius * radius
108            }
109            Self::Ellipsoid { center, shape_inv } => {
110                let dim = center.len();
111                let diff: Vec<f32> = x
112                    .iter()
113                    .zip(center.iter())
114                    .map(|(&xi, &ci)| xi - ci)
115                    .collect();
116
117                // Compute (x-c)ᵀ P (x-c)
118                let mut quad_form = 0.0;
119                for i in 0..dim {
120                    for j in 0..dim {
121                        quad_form += diff[i] * shape_inv[i * dim + j] * diff[j];
122                    }
123                }
124                quad_form <= 1.0
125            }
126            Self::Polytope {
127                a_matrix,
128                b_vector,
129                num_constraints,
130                dimension,
131            } => {
132                for i in 0..*num_constraints {
133                    let mut ax = 0.0;
134                    for j in 0..*dimension {
135                        ax += a_matrix[i * dimension + j] * x[j];
136                    }
137                    if ax > b_vector[i] {
138                        return false;
139                    }
140                }
141                true
142            }
143            Self::LInfBall { center, radius } => x
144                .iter()
145                .zip(center.iter())
146                .all(|(&xi, &ci)| (xi - ci).abs() <= *radius),
147            Self::Simplex { dimension } => {
148                if x.len() != *dimension {
149                    return false;
150                }
151                let sum: f32 = x.iter().sum();
152                x.iter().all(|&xi| xi >= 0.0) && sum <= 1.0
153            }
154        }
155    }
156
157    /// Compute distance to the set (0 if inside)
158    pub fn distance(&self, x: &[f32]) -> f32 {
159        match self {
160            Self::Box { lower, upper } => x
161                .iter()
162                .zip(lower.iter())
163                .zip(upper.iter())
164                .map(|((&xi, &li), &ui)| {
165                    if xi < li {
166                        li - xi
167                    } else if xi > ui {
168                        xi - ui
169                    } else {
170                        0.0
171                    }
172                })
173                .map(|d| d * d)
174                .sum::<f32>()
175                .sqrt(),
176            Self::Ball { center, radius } => {
177                let dist_sq: f32 = x
178                    .iter()
179                    .zip(center.iter())
180                    .map(|(&xi, &ci)| (xi - ci).powi(2))
181                    .sum();
182                let dist = dist_sq.sqrt();
183                (dist - radius).max(0.0)
184            }
185            Self::Ellipsoid { .. } => {
186                // For ellipsoid, use simple check-based distance
187                if self.contains(x) {
188                    0.0
189                } else {
190                    // Approximation: would need proper optimization
191                    1.0
192                }
193            }
194            Self::Polytope { .. } => {
195                // For polytope, use simple check-based distance
196                if self.contains(x) {
197                    0.0
198                } else {
199                    1.0
200                }
201            }
202            Self::LInfBall { center, radius } => {
203                let max_diff = x
204                    .iter()
205                    .zip(center.iter())
206                    .map(|(&xi, &ci)| (xi - ci).abs())
207                    .fold(0.0f32, |a, b| a.max(b));
208                (max_diff - radius).max(0.0)
209            }
210            Self::Simplex { dimension } => {
211                if x.len() != *dimension {
212                    return f32::MAX;
213                }
214                let neg_sum: f32 = x.iter().filter(|&&xi| xi < 0.0).map(|&xi| -xi).sum();
215                let sum: f32 = x.iter().sum();
216                let excess = (sum - 1.0).max(0.0);
217                (neg_sum.powi(2) + excess.powi(2)).sqrt()
218            }
219        }
220    }
221
222    /// Project a point onto the set
223    pub fn project(&self, x: &[f32]) -> Vec<f32> {
224        match self {
225            Self::Box { lower, upper } => x
226                .iter()
227                .zip(lower.iter())
228                .zip(upper.iter())
229                .map(|((&xi, &li), &ui)| xi.clamp(li, ui))
230                .collect(),
231            Self::Ball { center, radius } => {
232                let diff: Vec<f32> = x
233                    .iter()
234                    .zip(center.iter())
235                    .map(|(&xi, &ci)| xi - ci)
236                    .collect();
237                let dist_sq: f32 = diff.iter().map(|&d| d * d).sum();
238                let dist = dist_sq.sqrt();
239
240                if dist <= *radius {
241                    x.to_vec()
242                } else {
243                    center
244                        .iter()
245                        .zip(diff.iter())
246                        .map(|(&ci, &di)| ci + di * radius / dist)
247                        .collect()
248                }
249            }
250            Self::Ellipsoid { .. } | Self::Polytope { .. } => {
251                // For ellipsoid and polytope, proper projection requires optimization/QP solver
252                // Simple fallback: return input (would need iterative projection for exact solution)
253                x.to_vec()
254            }
255            Self::LInfBall { center, radius } => x
256                .iter()
257                .zip(center.iter())
258                .map(|(&xi, &ci)| {
259                    let diff = xi - ci;
260                    ci + diff.clamp(-*radius, *radius)
261                })
262                .collect(),
263            Self::Simplex { dimension } => {
264                if x.len() != *dimension {
265                    return x.to_vec();
266                }
267
268                // Project onto simplex using efficient algorithm
269                let mut sorted: Vec<f32> = x.to_vec();
270                sorted.sort_by(|a, b| b.partial_cmp(a).unwrap());
271
272                let mut theta = 0.0;
273                let mut t_sum = 0.0;
274
275                for (i, &val) in sorted.iter().enumerate() {
276                    t_sum += val;
277                    let candidate = (t_sum - 1.0) / (i + 1) as f32;
278                    if i + 1 == sorted.len() || sorted[i + 1] < val - candidate {
279                        theta = candidate;
280                        break;
281                    }
282                }
283
284                x.iter().map(|&xi| (xi - theta).max(0.0)).collect()
285            }
286        }
287    }
288}
289
290/// Set membership constraint
291#[derive(Debug, Clone, Serialize, Deserialize)]
292pub struct SetMembershipConstraint {
293    name: String,
294    set: GeometricSet,
295    weight: f32,
296}
297
298impl SetMembershipConstraint {
299    /// Create a new set membership constraint
300    pub fn new(name: impl Into<String>, set: GeometricSet) -> Self {
301        Self {
302            name: name.into(),
303            set,
304            weight: 1.0,
305        }
306    }
307
308    /// Set the constraint weight
309    pub fn with_weight(mut self, weight: f32) -> Self {
310        self.weight = weight;
311        self
312    }
313
314    /// Check if point is in the set
315    pub fn check(&self, x: &[f32]) -> bool {
316        self.set.contains(x)
317    }
318
319    /// Compute violation (distance to set)
320    pub fn violation(&self, x: &[f32]) -> f32 {
321        self.set.distance(x)
322    }
323
324    /// Project point onto the set
325    pub fn project(&self, x: &[f32]) -> Vec<f32> {
326        self.set.project(x)
327    }
328
329    /// Get constraint name
330    pub fn name(&self) -> &str {
331        &self.name
332    }
333
334    /// Get weight
335    pub fn weight(&self) -> f32 {
336        self.weight
337    }
338
339    /// Get the geometric set
340    pub fn set(&self) -> &GeometricSet {
341        &self.set
342    }
343}