kizzasi_logic/
projection.rs1use crate::constraint::Constraint;
4use crate::error::LogicResult;
5use scirs2_core::ndarray::Array1;
6
7#[derive(Debug, Clone)]
9pub struct ConstrainedProjection {
10 constraints: Vec<Constraint>,
11 max_iterations: usize,
12 tolerance: f32,
13}
14
15impl ConstrainedProjection {
16 pub fn new(constraints: Vec<Constraint>) -> Self {
18 Self {
19 constraints,
20 max_iterations: 10,
21 tolerance: 1e-6,
22 }
23 }
24
25 pub fn max_iterations(mut self, n: usize) -> Self {
27 self.max_iterations = n;
28 self
29 }
30
31 pub fn tolerance(mut self, tol: f32) -> Self {
33 self.tolerance = tol;
34 self
35 }
36
37 pub fn project(&self, input: &Array1<f32>) -> LogicResult<Array1<f32>> {
39 let mut result = input.clone();
40
41 for _ in 0..self.max_iterations {
42 let prev = result.clone();
43
44 for constraint in &self.constraints {
46 if let Some(dim) = constraint.dimension() {
47 if dim < result.len() {
48 result[dim] = constraint.project(result[dim]);
49 }
50 } else {
51 for val in result.iter_mut() {
53 *val = constraint.project(*val);
54 }
55 }
56 }
57
58 let diff: f32 = result
60 .iter()
61 .zip(prev.iter())
62 .map(|(a, b)| (a - b).abs())
63 .sum();
64
65 if diff < self.tolerance {
66 break;
67 }
68 }
69
70 Ok(result)
71 }
72
73 pub fn total_violation(&self, input: &Array1<f32>) -> f32 {
75 let mut total = 0.0;
76
77 for constraint in &self.constraints {
78 if let Some(dim) = constraint.dimension() {
79 if dim < input.len() {
80 total += constraint.violation(input[dim]) * constraint.weight();
81 }
82 } else {
83 for val in input.iter() {
84 total += constraint.violation(*val) * constraint.weight();
85 }
86 }
87 }
88
89 total
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constraint::ConstraintBuilder;

    /// A range bound on index 0 and an upper bound on index 1 are enforced,
    /// while index 2, which no constraint targets, passes through unchanged.
    #[test]
    fn test_projection() {
        let range_on_first = ConstraintBuilder::new()
            .name("bound1")
            .dimension(0)
            .in_range(-1.0, 1.0)
            .build()
            .unwrap();
        let cap_on_second = ConstraintBuilder::new()
            .name("bound2")
            .dimension(1)
            .less_eq(0.5)
            .build()
            .unwrap();

        let projector = ConstrainedProjection::new(vec![range_on_first, cap_on_second]);
        let input = Array1::from_vec(vec![2.0, 1.0, 0.3]);
        let projected = projector.project(&input).unwrap();

        // 2.0 lies above the [-1, 1] range and is expected to land on its upper end.
        assert_eq!(projected[0], 1.0);
        // 1.0 exceeds the 0.5 cap and is expected to be reduced to it.
        assert_eq!(projected[1], 0.5);
        // No constraint targets index 2.
        assert_eq!(projected[2], 0.3);
    }
}