1use crate::constraint::{
7 LinearConstraint, NonlinearConstraint, SetMembershipConstraint, ViolationComputable,
8};
9use crate::error::LogicResult;
10use scirs2_core::ndarray::Array1;
11
/// Dykstra's alternating-projection solver for the intersection of several constraints.
///
/// Build one with [`DykstraProjection::new`] and the `with_*` builder methods, then call
/// `project` (implemented for `LinearConstraint` and `SetMembershipConstraint`).
#[derive(Debug, Clone)]
pub struct DykstraProjection<C> {
    /// Constraints whose intersection is targeted, visited in this order on every sweep.
    constraints: Vec<C>,
    /// Upper bound on the number of full sweeps over `constraints`.
    max_iterations: usize,
    /// Early-stopping threshold (max-abs norm) on the change produced by one sweep.
    tolerance: f32,
}
25
26impl<C: ViolationComputable> DykstraProjection<C> {
27 pub fn new(constraints: Vec<C>) -> Self {
29 Self {
30 constraints,
31 max_iterations: 100,
32 tolerance: 1e-6,
33 }
34 }
35
36 pub fn with_max_iterations(mut self, max_iter: usize) -> Self {
38 self.max_iterations = max_iter;
39 self
40 }
41
42 pub fn with_tolerance(mut self, tol: f32) -> Self {
44 self.tolerance = tol;
45 self
46 }
47
48 pub fn num_constraints(&self) -> usize {
50 self.constraints.len()
51 }
52}
53
54impl DykstraProjection<LinearConstraint> {
55 pub fn project(&self, x: &Array1<f32>) -> LogicResult<Array1<f32>> {
57 let n = x.len();
58 let m = self.constraints.len();
59
60 let mut y = x.clone();
62
63 let mut increments: Vec<Array1<f32>> = vec![Array1::zeros(n); m];
65
66 for _iter in 0..self.max_iterations {
67 let y_old = y.clone();
68
69 for (i, constraint) in self.constraints.iter().enumerate() {
71 let z = &y + &increments[i];
73
74 let projected = constraint.project(
76 z.as_slice()
77 .expect("Array must have contiguous layout for projection"),
78 );
79 let p = Array1::from_vec(projected);
80
81 increments[i] = &z - &p;
83
84 y = p;
86 }
87
88 let diff = (&y - &y_old)
90 .iter()
91 .map(|&d| d.abs())
92 .fold(0.0f32, |a, b| a.max(b));
93
94 if diff < self.tolerance {
95 break;
96 }
97 }
98
99 Ok(y)
100 }
101}
102
103impl DykstraProjection<SetMembershipConstraint> {
104 pub fn project(&self, x: &Array1<f32>) -> LogicResult<Array1<f32>> {
106 let n = x.len();
107 let m = self.constraints.len();
108
109 let mut y = x.clone();
110 let mut increments: Vec<Array1<f32>> = vec![Array1::zeros(n); m];
111
112 for _iter in 0..self.max_iterations {
113 let y_old = y.clone();
114
115 for (i, constraint) in self.constraints.iter().enumerate() {
116 let z = &y + &increments[i];
117 let projected = constraint.project(
118 z.as_slice()
119 .expect("Array must have contiguous layout for projection"),
120 );
121 let p = Array1::from_vec(projected);
122 increments[i] = &z - &p;
123 y = p;
124 }
125
126 let diff = (&y - &y_old)
127 .iter()
128 .map(|&d| d.abs())
129 .fold(0.0f32, |a, b| a.max(b));
130
131 if diff < self.tolerance {
132 break;
133 }
134 }
135
136 Ok(y)
137 }
138}
139
/// Penalty-gradient method that moves a point toward the feasible region of a set of
/// `NonlinearConstraint`s.
///
/// Configure with `GradientProjection::new` / `Default` and the `with_*` builders.
#[derive(Debug, Clone, PartialEq)]
pub struct GradientProjection {
    /// Maximum number of descent iterations.
    max_iterations: usize,
    /// Fixed step length for `project`; initial trial step for `project_adaptive`.
    step_size: f32,
    /// `project` stops once an iteration moves no coordinate by this much or more.
    tolerance: f32,
}
152
153impl GradientProjection {
154 pub fn new() -> Self {
156 Self {
157 max_iterations: 1000,
158 step_size: 0.01,
159 tolerance: 1e-6,
160 }
161 }
162
163 pub fn with_max_iterations(mut self, max_iter: usize) -> Self {
165 self.max_iterations = max_iter;
166 self
167 }
168
169 pub fn with_step_size(mut self, step: f32) -> Self {
171 self.step_size = step;
172 self
173 }
174
175 pub fn with_tolerance(mut self, tol: f32) -> Self {
177 self.tolerance = tol;
178 self
179 }
180
181 pub fn project(
183 &self,
184 x: &Array1<f32>,
185 constraints: &[NonlinearConstraint],
186 ) -> LogicResult<Array1<f32>> {
187 let mut result = x.clone();
188
189 for _iter in 0..self.max_iterations {
190 if constraints.iter().all(|c| {
192 c.check(
193 result
194 .as_slice()
195 .expect("Array must have contiguous layout for projection"),
196 )
197 }) {
198 break;
199 }
200
201 let prev = result.clone();
202
203 let mut total_grad: Array1<f32> = Array1::zeros(x.len());
205 let mut has_gradient = false;
206
207 for constraint in constraints {
208 if !constraint.check(
209 result
210 .as_slice()
211 .expect("Array must have contiguous layout for projection"),
212 ) {
213 if let Some(grad) = constraint.gradient(
214 result
215 .as_slice()
216 .expect("Array must have contiguous layout for projection"),
217 ) {
218 let violation = constraint.violation(
219 result
220 .as_slice()
221 .expect("Array must have contiguous layout for projection"),
222 );
223 for (i, &gi) in grad.iter().enumerate() {
224 total_grad[i] += violation * gi;
225 }
226 has_gradient = true;
227 }
228 }
229 }
230
231 if !has_gradient {
232 break;
234 }
235
236 for (ri, &gi) in result.iter_mut().zip(total_grad.iter()) {
238 *ri -= self.step_size * gi;
239 }
240
241 let diff = (&result - &prev)
243 .iter()
244 .map(|&d| d.abs())
245 .fold(0.0f32, |a, b| a.max(b));
246
247 if diff < self.tolerance {
248 break;
249 }
250 }
251
252 Ok(result)
253 }
254
255 pub fn project_adaptive(
257 &self,
258 x: &Array1<f32>,
259 constraints: &[NonlinearConstraint],
260 ) -> LogicResult<Array1<f32>> {
261 let mut result = x.clone();
262 let mut step_size = self.step_size;
263
264 for _iter in 0..self.max_iterations {
265 if constraints.iter().all(|c| {
266 c.check(
267 result
268 .as_slice()
269 .expect("Array must have contiguous layout for projection"),
270 )
271 }) {
272 break;
273 }
274
275 let mut total_grad: Array1<f32> = Array1::zeros(x.len());
277 let mut current_violation = 0.0;
278
279 for constraint in constraints {
280 let viol = constraint.violation(
281 result
282 .as_slice()
283 .expect("Array must have contiguous layout for projection"),
284 );
285 current_violation += viol;
286
287 if viol > 0.0 {
288 if let Some(grad) = constraint.gradient(
289 result
290 .as_slice()
291 .expect("Array must have contiguous layout for projection"),
292 ) {
293 for (i, &gi) in grad.iter().enumerate() {
294 total_grad[i] += viol * gi;
295 }
296 }
297 }
298 }
299
300 let mut alpha = step_size;
302 for _ in 0..10 {
303 let mut candidate = result.clone();
304 for (ci, &gi) in candidate.iter_mut().zip(total_grad.iter()) {
305 *ci -= alpha * gi;
306 }
307
308 let new_violation: f32 = constraints
310 .iter()
311 .map(|c| {
312 c.violation(
313 candidate
314 .as_slice()
315 .expect("Array must have contiguous layout for projection"),
316 )
317 })
318 .sum();
319
320 if new_violation < current_violation {
321 result = candidate;
322 step_size = (alpha * 1.1).min(1.0); break;
324 }
325
326 alpha *= 0.5; }
328 }
329
330 Ok(result)
331 }
332}
333
334impl Default for GradientProjection {
335 fn default() -> Self {
336 Self::new()
337 }
338}
339
/// Augmented-Lagrangian solver for the point nearest to a start point `x0` that satisfies
/// a set of `NonlinearConstraint`s.
///
/// Configure with `AugmentedLagrangian::new` / `Default` and the `with_*` builders.
#[derive(Debug, Clone, PartialEq)]
pub struct AugmentedLagrangian {
    /// Maximum number of outer (multiplier and penalty update) iterations.
    max_outer_iterations: usize,
    /// Gradient steps taken per outer iteration.
    max_inner_iterations: usize,
    /// Initial penalty weight `rho`.
    penalty_parameter: f32,
    /// Factor `rho` is multiplied by after each outer iteration that misses the tolerance.
    penalty_increase_factor: f32,
    /// Largest constraint violation accepted as converged.
    tolerance: f32,
    /// Fixed step length of the inner gradient descent.
    step_size: f32,
}
356
357impl AugmentedLagrangian {
358 pub fn new() -> Self {
360 Self {
361 max_outer_iterations: 20,
362 max_inner_iterations: 100,
363 penalty_parameter: 1.0,
364 penalty_increase_factor: 10.0,
365 tolerance: 1e-5,
366 step_size: 0.01,
367 }
368 }
369
370 pub fn with_max_outer_iterations(mut self, max_iter: usize) -> Self {
372 self.max_outer_iterations = max_iter;
373 self
374 }
375
376 pub fn with_penalty_parameter(mut self, rho: f32) -> Self {
378 self.penalty_parameter = rho;
379 self
380 }
381
382 pub fn project(
386 &self,
387 x0: &Array1<f32>,
388 constraints: &[NonlinearConstraint],
389 ) -> LogicResult<Array1<f32>> {
390 let n = x0.len();
391 let m = constraints.len();
392
393 let mut x = x0.clone();
394 let mut lambda = vec![0.0f32; m]; let mut rho = self.penalty_parameter;
396
397 for _outer in 0..self.max_outer_iterations {
398 for _inner in 0..self.max_inner_iterations {
400 let mut grad = Array1::zeros(n);
402
403 for (i, (&xi, &x0i)) in x.iter().zip(x0.iter()).enumerate() {
405 grad[i] = 2.0 * (xi - x0i);
406 }
407
408 for (j, constraint) in constraints.iter().enumerate() {
410 let g_j = constraint.evaluate(
411 x.as_slice()
412 .expect("Array must have contiguous layout for projection"),
413 );
414
415 if let Some(grad_g) = constraint.gradient(
416 x.as_slice()
417 .expect("Array must have contiguous layout for projection"),
418 ) {
419 let factor = lambda[j] + rho * g_j.max(0.0);
420 for (i, &dg) in grad_g.iter().enumerate() {
421 grad[i] += factor * dg;
422 }
423 }
424 }
425
426 for (xi, &gi) in x.iter_mut().zip(grad.iter()) {
428 *xi -= self.step_size * gi;
429 }
430 }
431
432 for (j, constraint) in constraints.iter().enumerate() {
434 let g_j = constraint.evaluate(
435 x.as_slice()
436 .expect("Array must have contiguous layout for projection"),
437 );
438 lambda[j] = (lambda[j] + rho * g_j).max(0.0);
439 }
440
441 let max_violation: f32 = constraints
443 .iter()
444 .map(|c| {
445 c.violation(
446 x.as_slice()
447 .expect("Array must have contiguous layout for projection"),
448 )
449 })
450 .fold(0.0f32, |a, b| a.max(b));
451
452 if max_violation < self.tolerance {
453 break;
454 }
455
456 rho *= self.penalty_increase_factor;
458 }
459
460 Ok(x)
461 }
462}
463
464impl Default for AugmentedLagrangian {
465 fn default() -> Self {
466 Self::new()
467 }
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constraint::{GeometricSet, LinearConstraint};

    // The interval [0, 5], written as `x <= 5` and `x >= 0`. Points outside it should
    // land on the nearer endpoint.
    #[test]
    fn test_dykstra_linear_constraints() {
        let at_most_five = LinearConstraint::less_eq(vec![1.0], 5.0);
        let at_least_zero = LinearConstraint::greater_eq(vec![1.0], 0.0);
        let solver =
            DykstraProjection::new(vec![at_most_five, at_least_zero]).with_tolerance(1e-6);

        for &(start, target) in &[(-1.0f32, 0.0f32), (10.0, 5.0)] {
            let out = solver.project(&Array1::from_vec(vec![start])).unwrap();
            assert!((out[0] - target).abs() < 1e-5);
        }
    }

    // Feasible set: x^2 - 1 <= 0, with analytic gradient 2x. Starting from x = 2, the
    // result should move inward and end with x^2 within 0.5 of 1.
    #[test]
    fn test_gradient_projection() {
        let unit_interval =
            NonlinearConstraint::inequality("x_squared", |x: &[f32]| x[0] * x[0] - 1.0)
                .with_gradient(|x: &[f32]| vec![2.0 * x[0]]);

        let solver = GradientProjection::new()
            .with_max_iterations(100)
            .with_step_size(0.1);

        let out = solver
            .project(&Array1::from_vec(vec![2.0]), &[unit_interval])
            .unwrap();
        let value = out[0];
        assert!(value < 2.0);
        assert!((value * value - 1.0).abs() < 0.5);
    }

    // Two balls of radius 2 (assuming `ball(center, radius)`), centred at (0, 0) and
    // (3, 0). The start point (1.5, 0) lies in their overlap.
    #[test]
    fn test_dykstra_set_constraints() {
        let left = SetMembershipConstraint::new("ball1", GeometricSet::ball(vec![0.0, 0.0], 2.0));
        let right = SetMembershipConstraint::new("ball2", GeometricSet::ball(vec![3.0, 0.0], 2.0));
        let solver = DykstraProjection::new(vec![left, right]).with_tolerance(1e-5);

        let out = solver.project(&Array1::from_vec(vec![1.5, 0.0])).unwrap();

        assert!((1.0..=2.0).contains(&out[0]));
        assert!(out[1].abs() < 0.5);
    }
}