kizzasi_logic/
projection.rs

1//! Projection onto constraint manifolds
2
3use crate::constraint::Constraint;
4use crate::error::LogicResult;
5use scirs2_core::ndarray::Array1;
6
/// Projects predictions onto valid constraint manifolds
///
/// Bundles a list of constraints with the iteration limit and convergence
/// tolerance that `project` uses when repeatedly applying them.
#[derive(Debug, Clone)]
pub struct ConstrainedProjection {
    /// Constraints applied, in stored order, on every projection sweep.
    constraints: Vec<Constraint>,
    /// Upper bound on the number of sweeps `project` performs (`new` sets 10).
    max_iterations: usize,
    /// `project` stops early once the summed absolute per-element change of
    /// a sweep drops below this value (`new` sets 1e-6).
    tolerance: f32,
}
14
15impl ConstrainedProjection {
16    /// Create a new constrained projection
17    pub fn new(constraints: Vec<Constraint>) -> Self {
18        Self {
19            constraints,
20            max_iterations: 10,
21            tolerance: 1e-6,
22        }
23    }
24
25    /// Set maximum projection iterations
26    pub fn max_iterations(mut self, n: usize) -> Self {
27        self.max_iterations = n;
28        self
29    }
30
31    /// Set convergence tolerance
32    pub fn tolerance(mut self, tol: f32) -> Self {
33        self.tolerance = tol;
34        self
35    }
36
37    /// Project a vector onto the valid manifold using iterative projection
38    pub fn project(&self, input: &Array1<f32>) -> LogicResult<Array1<f32>> {
39        let mut result = input.clone();
40
41        for _ in 0..self.max_iterations {
42            let prev = result.clone();
43
44            // Apply each constraint's projection
45            for constraint in &self.constraints {
46                if let Some(dim) = constraint.dimension() {
47                    if dim < result.len() {
48                        result[dim] = constraint.project(result[dim]);
49                    }
50                } else {
51                    // Apply to all dimensions
52                    for val in result.iter_mut() {
53                        *val = constraint.project(*val);
54                    }
55                }
56            }
57
58            // Check convergence
59            let diff: f32 = result
60                .iter()
61                .zip(prev.iter())
62                .map(|(a, b)| (a - b).abs())
63                .sum();
64
65            if diff < self.tolerance {
66                break;
67            }
68        }
69
70        Ok(result)
71    }
72
73    /// Compute total violation of all constraints
74    pub fn total_violation(&self, input: &Array1<f32>) -> f32 {
75        let mut total = 0.0;
76
77        for constraint in &self.constraints {
78            if let Some(dim) = constraint.dimension() {
79                if dim < input.len() {
80                    total += constraint.violation(input[dim]) * constraint.weight();
81                }
82            } else {
83                for val in input.iter() {
84                    total += constraint.violation(*val) * constraint.weight();
85                }
86            }
87        }
88
89        total
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constraint::ConstraintBuilder;

    /// Dimension 0 is clamped into [-1, 1], dimension 1 is capped at 0.5,
    /// and the unconstrained dimension 2 passes through unchanged.
    #[test]
    fn test_projection() {
        let range_on_first = ConstraintBuilder::new()
            .name("bound1")
            .dimension(0)
            .in_range(-1.0, 1.0)
            .build()
            .unwrap();
        let cap_on_second = ConstraintBuilder::new()
            .name("bound2")
            .dimension(1)
            .less_eq(0.5)
            .build()
            .unwrap();

        let projector = ConstrainedProjection::new(vec![range_on_first, cap_on_second]);
        let raw = Array1::from_vec(vec![2.0, 1.0, 0.3]);
        let projected = projector.project(&raw).unwrap();

        let expected: [f32; 3] = [1.0, 0.5, 0.3];
        for (index, want) in expected.into_iter().enumerate() {
            assert_eq!(projected[index], want);
        }
    }
}