// kizzasi_logic/constraint/linear.rs
use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
9pub enum LinearConstraintType {
10 LessEq,
12 GreaterEq,
14 Equality { tolerance: f32 },
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct LinearConstraint {
25 coefficients: Vec<f32>,
27 rhs: f32,
29 constraint_type: LinearConstraintType,
31 weight: f32,
33}
34
35impl LinearConstraint {
36 pub fn less_eq(coefficients: Vec<f32>, rhs: f32) -> Self {
38 Self {
39 coefficients,
40 rhs,
41 constraint_type: LinearConstraintType::LessEq,
42 weight: 1.0,
43 }
44 }
45
46 pub fn greater_eq(coefficients: Vec<f32>, rhs: f32) -> Self {
48 Self {
49 coefficients,
50 rhs,
51 constraint_type: LinearConstraintType::GreaterEq,
52 weight: 1.0,
53 }
54 }
55
56 pub fn equality(coefficients: Vec<f32>, rhs: f32, tolerance: f32) -> Self {
58 Self {
59 coefficients,
60 rhs,
61 constraint_type: LinearConstraintType::Equality { tolerance },
62 weight: 1.0,
63 }
64 }
65
66 pub fn with_weight(mut self, weight: f32) -> Self {
68 self.weight = weight;
69 self
70 }
71
72 fn dot(&self, x: &[f32]) -> f32 {
74 self.coefficients
75 .iter()
76 .zip(x.iter())
77 .map(|(a, xi)| a * xi)
78 .sum()
79 }
80
81 fn norm_sq(&self) -> f32 {
83 self.coefficients.iter().map(|a| a * a).sum()
84 }
85
86 pub fn check(&self, x: &[f32]) -> bool {
88 let ax = self.dot(x);
89 match &self.constraint_type {
90 LinearConstraintType::LessEq => ax <= self.rhs,
91 LinearConstraintType::GreaterEq => ax >= self.rhs,
92 LinearConstraintType::Equality { tolerance } => (ax - self.rhs).abs() <= *tolerance,
93 }
94 }
95
96 pub fn violation(&self, x: &[f32]) -> f32 {
98 let ax = self.dot(x);
99 match &self.constraint_type {
100 LinearConstraintType::LessEq => (ax - self.rhs).max(0.0),
101 LinearConstraintType::GreaterEq => (self.rhs - ax).max(0.0),
102 LinearConstraintType::Equality { tolerance } => {
103 let diff = (ax - self.rhs).abs();
104 (diff - tolerance).max(0.0)
105 }
106 }
107 }
108
109 pub fn project(&self, x: &[f32]) -> Vec<f32> {
113 let ax = self.dot(x);
114 let norm_sq = self.norm_sq();
115
116 if norm_sq < f32::EPSILON {
117 return x.to_vec();
118 }
119
120 let needs_projection = match &self.constraint_type {
121 LinearConstraintType::LessEq => ax > self.rhs,
122 LinearConstraintType::GreaterEq => ax < self.rhs,
123 LinearConstraintType::Equality { tolerance } => (ax - self.rhs).abs() > *tolerance,
124 };
125
126 if !needs_projection {
127 return x.to_vec();
128 }
129
130 let factor = (ax - self.rhs) / norm_sq;
132 x.iter()
133 .zip(self.coefficients.iter())
134 .map(|(xi, ai)| xi - factor * ai)
135 .collect()
136 }
137
138 pub fn weight(&self) -> f32 {
140 self.weight
141 }
142
143 pub fn coefficients(&self) -> &[f32] {
145 &self.coefficients
146 }
147
148 pub fn rhs(&self) -> f32 {
150 self.rhs
151 }
152}
153
154#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct LinearConstraintSet {
159 constraints: Vec<LinearConstraint>,
160}
161
162impl LinearConstraintSet {
163 pub fn new(constraints: Vec<LinearConstraint>) -> Self {
165 Self { constraints }
166 }
167
168 pub fn check_all(&self, x: &[f32]) -> bool {
170 self.constraints.iter().all(|c| c.check(x))
171 }
172
173 pub fn check_each(&self, x: &[f32]) -> Vec<bool> {
175 self.constraints.iter().map(|c| c.check(x)).collect()
176 }
177
178 pub fn total_violation(&self, x: &[f32]) -> f32 {
180 self.constraints
181 .iter()
182 .map(|c| c.violation(x) * c.weight())
183 .sum()
184 }
185
186 pub fn project(&self, x: &[f32], max_iters: usize) -> Vec<f32> {
191 let mut current = x.to_vec();
192
193 for _ in 0..max_iters {
194 let prev = current.clone();
195 for c in &self.constraints {
196 current = c.project(¤t);
197 }
198 let diff: f32 = current
200 .iter()
201 .zip(prev.iter())
202 .map(|(a, b)| (a - b).abs())
203 .sum();
204 if diff < 1e-6 {
205 break;
206 }
207 }
208
209 current
210 }
211
212 pub fn len(&self) -> usize {
214 self.constraints.len()
215 }
216
217 pub fn is_empty(&self) -> bool {
219 self.constraints.is_empty()
220 }
221
222 pub fn add(&mut self, constraint: LinearConstraint) {
224 self.constraints.push(constraint);
225 }
226
227 pub fn constraints(&self) -> &[LinearConstraint] {
229 &self.constraints
230 }
231}
232
233impl Default for LinearConstraintSet {
234 fn default() -> Self {
235 Self::new(Vec::new())
236 }
237}
238
239#[derive(Debug, Clone, Serialize, Deserialize)]
244pub struct AffineEquality {
245 matrix: Vec<Vec<f32>>,
247 rhs: Vec<f32>,
249 tolerance: f32,
251}
252
253impl AffineEquality {
254 pub fn new(matrix: Vec<Vec<f32>>, rhs: Vec<f32>, tolerance: f32) -> Self {
261 Self {
262 matrix,
263 rhs,
264 tolerance,
265 }
266 }
267
268 fn multiply(&self, x: &[f32]) -> Vec<f32> {
270 self.matrix
271 .iter()
272 .map(|row| row.iter().zip(x.iter()).map(|(a, xi)| a * xi).sum())
273 .collect()
274 }
275
276 pub fn check(&self, x: &[f32]) -> bool {
278 let ax = self.multiply(x);
279 ax.iter()
280 .zip(self.rhs.iter())
281 .all(|(axi, bi)| (axi - bi).abs() <= self.tolerance)
282 }
283
284 pub fn residual(&self, x: &[f32]) -> f32 {
286 let ax = self.multiply(x);
287 ax.iter()
288 .zip(self.rhs.iter())
289 .map(|(axi, bi)| (axi - bi).powi(2))
290 .sum::<f32>()
291 .sqrt()
292 }
293
294 pub fn violations(&self, x: &[f32]) -> Vec<f32> {
296 let ax = self.multiply(x);
297 ax.iter()
298 .zip(self.rhs.iter())
299 .map(|(axi, bi)| (axi - bi).abs())
300 .collect()
301 }
302
303 pub fn matrix(&self) -> &[Vec<f32>] {
305 &self.matrix
306 }
307
308 pub fn rhs(&self) -> &[f32] {
310 &self.rhs
311 }
312
313 pub fn num_equations(&self) -> usize {
315 self.matrix.len()
316 }
317
318 pub fn num_variables(&self) -> Option<usize> {
320 self.matrix.first().map(|row| row.len())
321 }
322}