// kizzasi_logic/constraint/linear.rs

1use serde::{Deserialize, Serialize};
2
3// ============================================================================
4// Linear Constraints
5// ============================================================================
6
7/// Type of linear constraint
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub enum LinearConstraintType {
10    /// a·x <= b
11    LessEq,
12    /// a·x >= b
13    GreaterEq,
14    /// |a·x - b| <= tolerance
15    Equality { tolerance: f32 },
16}
17
18/// Linear constraint: a·x (op) b
19///
20/// Represents a single linear constraint on a vector of values.
21/// - Inequality: a·x <= b or a·x >= b
22/// - Equality: |a·x - b| <= tolerance
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct LinearConstraint {
25    /// Coefficient vector `a`
26    coefficients: Vec<f32>,
27    /// Right-hand side `b`
28    rhs: f32,
29    /// Constraint type
30    constraint_type: LinearConstraintType,
31    /// Weight for loss computation
32    weight: f32,
33}
34
35impl LinearConstraint {
36    /// Create a less-than-or-equal constraint: a·x <= b
37    pub fn less_eq(coefficients: Vec<f32>, rhs: f32) -> Self {
38        Self {
39            coefficients,
40            rhs,
41            constraint_type: LinearConstraintType::LessEq,
42            weight: 1.0,
43        }
44    }
45
46    /// Create a greater-than-or-equal constraint: a·x >= b
47    pub fn greater_eq(coefficients: Vec<f32>, rhs: f32) -> Self {
48        Self {
49            coefficients,
50            rhs,
51            constraint_type: LinearConstraintType::GreaterEq,
52            weight: 1.0,
53        }
54    }
55
56    /// Create an equality constraint: |a·x - b| <= tolerance
57    pub fn equality(coefficients: Vec<f32>, rhs: f32, tolerance: f32) -> Self {
58        Self {
59            coefficients,
60            rhs,
61            constraint_type: LinearConstraintType::Equality { tolerance },
62            weight: 1.0,
63        }
64    }
65
66    /// Set the constraint weight
67    pub fn with_weight(mut self, weight: f32) -> Self {
68        self.weight = weight;
69        self
70    }
71
72    /// Compute dot product a·x
73    fn dot(&self, x: &[f32]) -> f32 {
74        self.coefficients
75            .iter()
76            .zip(x.iter())
77            .map(|(a, xi)| a * xi)
78            .sum()
79    }
80
81    /// Compute squared norm of coefficients ||a||²
82    fn norm_sq(&self) -> f32 {
83        self.coefficients.iter().map(|a| a * a).sum()
84    }
85
86    /// Check if values satisfy the constraint
87    pub fn check(&self, x: &[f32]) -> bool {
88        let ax = self.dot(x);
89        match &self.constraint_type {
90            LinearConstraintType::LessEq => ax <= self.rhs,
91            LinearConstraintType::GreaterEq => ax >= self.rhs,
92            LinearConstraintType::Equality { tolerance } => (ax - self.rhs).abs() <= *tolerance,
93        }
94    }
95
96    /// Compute violation amount (0 if satisfied)
97    pub fn violation(&self, x: &[f32]) -> f32 {
98        let ax = self.dot(x);
99        match &self.constraint_type {
100            LinearConstraintType::LessEq => (ax - self.rhs).max(0.0),
101            LinearConstraintType::GreaterEq => (self.rhs - ax).max(0.0),
102            LinearConstraintType::Equality { tolerance } => {
103                let diff = (ax - self.rhs).abs();
104                (diff - tolerance).max(0.0)
105            }
106        }
107    }
108
109    /// Project x onto the constraint (closest point that satisfies)
110    ///
111    /// For a·x <= b or a·x >= b, uses orthogonal projection onto hyperplane.
112    pub fn project(&self, x: &[f32]) -> Vec<f32> {
113        let ax = self.dot(x);
114        let norm_sq = self.norm_sq();
115
116        if norm_sq < f32::EPSILON {
117            return x.to_vec();
118        }
119
120        let needs_projection = match &self.constraint_type {
121            LinearConstraintType::LessEq => ax > self.rhs,
122            LinearConstraintType::GreaterEq => ax < self.rhs,
123            LinearConstraintType::Equality { tolerance } => (ax - self.rhs).abs() > *tolerance,
124        };
125
126        if !needs_projection {
127            return x.to_vec();
128        }
129
130        // Orthogonal projection: x' = x - ((a·x - b) / ||a||²) * a
131        let factor = (ax - self.rhs) / norm_sq;
132        x.iter()
133            .zip(self.coefficients.iter())
134            .map(|(xi, ai)| xi - factor * ai)
135            .collect()
136    }
137
138    /// Get the weight
139    pub fn weight(&self) -> f32 {
140        self.weight
141    }
142
143    /// Get the coefficients
144    pub fn coefficients(&self) -> &[f32] {
145        &self.coefficients
146    }
147
148    /// Get the right-hand side
149    pub fn rhs(&self) -> f32 {
150        self.rhs
151    }
152}
153
154/// A set of linear constraints
155///
156/// Useful for representing polyhedral constraints (intersection of half-spaces).
157#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct LinearConstraintSet {
159    constraints: Vec<LinearConstraint>,
160}
161
162impl LinearConstraintSet {
163    /// Create a new constraint set
164    pub fn new(constraints: Vec<LinearConstraint>) -> Self {
165        Self { constraints }
166    }
167
168    /// Check if all constraints are satisfied
169    pub fn check_all(&self, x: &[f32]) -> bool {
170        self.constraints.iter().all(|c| c.check(x))
171    }
172
173    /// Get individual check results for each constraint
174    pub fn check_each(&self, x: &[f32]) -> Vec<bool> {
175        self.constraints.iter().map(|c| c.check(x)).collect()
176    }
177
178    /// Compute total weighted violation
179    pub fn total_violation(&self, x: &[f32]) -> f32 {
180        self.constraints
181            .iter()
182            .map(|c| c.violation(x) * c.weight())
183            .sum()
184    }
185
186    /// Project x onto feasible region using cyclic projection (Dykstra's)
187    ///
188    /// Iteratively projects onto each constraint. May not converge to
189    /// true projection for non-convex feasible regions.
190    pub fn project(&self, x: &[f32], max_iters: usize) -> Vec<f32> {
191        let mut current = x.to_vec();
192
193        for _ in 0..max_iters {
194            let prev = current.clone();
195            for c in &self.constraints {
196                current = c.project(&current);
197            }
198            // Check convergence
199            let diff: f32 = current
200                .iter()
201                .zip(prev.iter())
202                .map(|(a, b)| (a - b).abs())
203                .sum();
204            if diff < 1e-6 {
205                break;
206            }
207        }
208
209        current
210    }
211
212    /// Get the number of constraints
213    pub fn len(&self) -> usize {
214        self.constraints.len()
215    }
216
217    /// Check if the set is empty
218    pub fn is_empty(&self) -> bool {
219        self.constraints.is_empty()
220    }
221
222    /// Add a constraint to the set
223    pub fn add(&mut self, constraint: LinearConstraint) {
224        self.constraints.push(constraint);
225    }
226
227    /// Get constraints
228    pub fn constraints(&self) -> &[LinearConstraint] {
229        &self.constraints
230    }
231}
232
233impl Default for LinearConstraintSet {
234    fn default() -> Self {
235        Self::new(Vec::new())
236    }
237}
238
239/// Affine equality constraint: Ax = b
240///
241/// Represents a system of linear equality constraints where A is an m×n matrix
242/// and b is an m-vector.
243#[derive(Debug, Clone, Serialize, Deserialize)]
244pub struct AffineEquality {
245    /// Matrix A stored row-major (m rows, each with n elements)
246    matrix: Vec<Vec<f32>>,
247    /// Right-hand side vector b (m elements)
248    rhs: Vec<f32>,
249    /// Tolerance for equality check
250    tolerance: f32,
251}
252
253impl AffineEquality {
254    /// Create a new affine equality constraint Ax = b
255    ///
256    /// # Arguments
257    /// * `matrix` - Matrix A as row-major vec of vecs (m rows × n cols)
258    /// * `rhs` - Vector b with m elements
259    /// * `tolerance` - Tolerance for equality check
260    pub fn new(matrix: Vec<Vec<f32>>, rhs: Vec<f32>, tolerance: f32) -> Self {
261        Self {
262            matrix,
263            rhs,
264            tolerance,
265        }
266    }
267
268    /// Compute Ax
269    fn multiply(&self, x: &[f32]) -> Vec<f32> {
270        self.matrix
271            .iter()
272            .map(|row| row.iter().zip(x.iter()).map(|(a, xi)| a * xi).sum())
273            .collect()
274    }
275
276    /// Check if Ax ≈ b (within tolerance)
277    pub fn check(&self, x: &[f32]) -> bool {
278        let ax = self.multiply(x);
279        ax.iter()
280            .zip(self.rhs.iter())
281            .all(|(axi, bi)| (axi - bi).abs() <= self.tolerance)
282    }
283
284    /// Compute residual ||Ax - b||
285    pub fn residual(&self, x: &[f32]) -> f32 {
286        let ax = self.multiply(x);
287        ax.iter()
288            .zip(self.rhs.iter())
289            .map(|(axi, bi)| (axi - bi).powi(2))
290            .sum::<f32>()
291            .sqrt()
292    }
293
294    /// Compute element-wise violations (Ax - b)
295    pub fn violations(&self, x: &[f32]) -> Vec<f32> {
296        let ax = self.multiply(x);
297        ax.iter()
298            .zip(self.rhs.iter())
299            .map(|(axi, bi)| (axi - bi).abs())
300            .collect()
301    }
302
303    /// Get the matrix
304    pub fn matrix(&self) -> &[Vec<f32>] {
305        &self.matrix
306    }
307
308    /// Get the right-hand side
309    pub fn rhs(&self) -> &[f32] {
310        &self.rhs
311    }
312
313    /// Get the number of equations (rows)
314    pub fn num_equations(&self) -> usize {
315        self.matrix.len()
316    }
317
318    /// Get the number of variables (columns)
319    pub fn num_variables(&self) -> Option<usize> {
320        self.matrix.first().map(|row| row.len())
321    }
322}