1use crate::error::{OptimizeError, OptimizeResult};
29use crate::result::OptimizeResults;
30use scirs2_core::ndarray::{Array1, ArrayView1};
31
/// Penalty formulation used by [`PenaltyMethod`] and [`penalty_function`].
///
/// Constraints follow the convention `h_j(x) = 0` (equality) and `g_i(x) <= 0`
/// (inequality).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PenaltyKind {
    /// Quadratic exterior penalty `mu * (sum h_j(x)^2 + sum max(g_i(x), 0)^2)`.
    /// Iterates may be infeasible; the penalty parameter is raised between outer
    /// iterations until the violation is small enough.
    External,
    /// Logarithmic interior barrier `-mu * sum ln(-g_i(x))`. Needs a strictly feasible
    /// starting point and handles inequality constraints only.
    Interior,
    /// Exact, non-smooth L1 penalty `mu * (sum |h_j(x)| + sum max(g_i(x), 0))`.
    ExactL1,
}
42
43#[derive(Debug, Clone)]
45pub struct PenaltyOptions {
46 pub kind: PenaltyKind,
48 pub mu_init: f64,
50 pub mu_max: f64,
52 pub mu_factor: f64,
54 pub max_outer_iter: usize,
56 pub constraint_tol: f64,
58 pub optimality_tol: f64,
60 pub eps: f64,
62}
63
64impl Default for PenaltyOptions {
65 fn default() -> Self {
66 PenaltyOptions {
67 kind: PenaltyKind::External,
68 mu_init: 1.0,
69 mu_max: 1e10,
70 mu_factor: 10.0,
71 max_outer_iter: 100,
72 constraint_tol: 1e-6,
73 optimality_tol: 1e-8,
74 eps: 1e-7,
75 }
76 }
77}
78
/// Outcome of a [`PenaltyMethod`] solve.
#[derive(Debug, Clone)]
pub struct PenaltyResult {
    /// Final iterate.
    pub x: Array1<f64>,
    /// Objective `f(x)` at the final iterate (without any penalty or barrier term).
    pub fun: f64,
    /// Number of outer (penalty-parameter) iterations performed.
    pub nit: usize,
    /// Evaluations of the penalised function counted by the inner solver; evaluations made
    /// directly by the outer loop (final `f(x)`, violation checks) are not included.
    pub nfev: usize,
    /// Whether the solver considers the constraints satisfied (see `constraint_violation`).
    pub success: bool,
    /// Human-readable termination reason.
    pub message: String,
    /// Penalty (or barrier) parameter held by the solver when it returned.
    pub mu: f64,
    /// Constraint violation at `x`: Euclidean norm of the violations for `External`
    /// (`Interior` keeps iterates strictly feasible), sum of absolute violations for
    /// `ExactL1`.
    pub constraint_violation: f64,
}
99
100impl From<PenaltyResult> for OptimizeResults<f64> {
101 fn from(r: PenaltyResult) -> Self {
102 OptimizeResults {
103 x: r.x,
104 fun: r.fun,
105 jac: None,
106 hess: None,
107 constr: None,
108 nit: r.nit,
109 nfev: r.nfev,
110 njev: 0,
111 nhev: 0,
112 maxcv: 0,
113 message: r.message,
114 success: r.success,
115 status: if r.success { 0 } else { 1 },
116 }
117 }
118}
119
/// Minimise a scalar function with a finite-difference L-BFGS method.
///
/// This is the inner (unconstrained) solver shared by every penalty, barrier and
/// augmented-Lagrangian loop in this module.
///
/// * Gradients are central differences with step `eps` (`2 * n` evaluations each).
/// * Search directions come from the L-BFGS two-loop recursion over the last five
///   curvature pairs `(s, y)`; pairs with `s·y <= 1e-10` are not stored.
/// * Step lengths come from Armijo backtracking (constant `1e-4`), starting at 1 and
///   halving up to 20 times.
///
/// # Arguments
/// * `penalty_fn` - function to minimise. It may return `+inf` to mark forbidden points
///   (the log barrier does this outside the strictly feasible region).
/// * `x0` - starting point.
/// * `max_iter` - maximum number of quasi-Newton iterations.
/// * `gtol` - iteration stops once the Euclidean norm of the gradient estimate is `<= gtol`.
/// * `eps` - central-difference step.
/// * `nfev` - incremented by the number of `penalty_fn` evaluations performed.
///
/// # Returns
/// The final iterate. The iterate is never moved to a point where `penalty_fn` is not
/// finite, and the current point is returned as soon as the gradient estimate stops being
/// finite (no usable search direction exists then).
fn minimize_penalty_subproblem<P>(
    penalty_fn: P,
    x0: &[f64],
    max_iter: usize,
    gtol: f64,
    eps: f64,
    nfev: &mut usize,
) -> Vec<f64>
where
    P: Fn(&[f64]) -> f64,
{
    // Number of curvature pairs kept in the L-BFGS memory.
    const HISTORY: usize = 5;
    // Maximum number of step halvings tried by the backtracking line search.
    const MAX_BACKTRACKS: usize = 20;
    // Armijo sufficient-decrease constant.
    const ARMIJO_C1: f64 = 1e-4;
    // Minimum s·y for a pair to be stored; keeps the implicit inverse Hessian positive definite.
    const MIN_CURVATURE: f64 = 1e-10;

    fn dot(a: &[f64], b: &[f64]) -> f64 {
        a.iter().zip(b).map(|(&p, &q)| p * q).sum()
    }

    // Point reached by moving `alpha` along `d` from `x`.
    fn step(x: &[f64], d: &[f64], alpha: f64) -> Vec<f64> {
        x.iter().zip(d).map(|(&xi, &di)| xi + alpha * di).collect()
    }

    let n = x0.len();

    // Central-difference gradient estimate; costs 2 * n evaluations.
    let compute_grad = |xv: &[f64], nfev: &mut usize| -> Vec<f64> {
        let h = eps;
        let mut grad = vec![0.0; n];
        let mut xp = xv.to_vec();
        let mut xm = xv.to_vec();
        *nfev += 2 * n;
        for i in 0..n {
            xp[i] = xv[i] + h;
            xm[i] = xv[i] - h;
            grad[i] = (penalty_fn(&xp) - penalty_fn(&xm)) / (2.0 * h);
            xp[i] = xv[i];
            xm[i] = xv[i];
        }
        grad
    };

    let mut x = x0.to_vec();
    let mut s_hist: Vec<Vec<f64>> = Vec::with_capacity(HISTORY);
    let mut y_hist: Vec<Vec<f64>> = Vec::with_capacity(HISTORY);
    let mut rho_hist: Vec<f64> = Vec::with_capacity(HISTORY);

    // f(x) is carried between iterations: the accepted line-search value is reused instead of
    // re-evaluating the function at the start of every iteration.
    *nfev += 1;
    let mut fx = penalty_fn(&x);
    let mut g = compute_grad(&x, nfev);

    for _iter in 0..max_iter {
        let gnorm = dot(&g, &g).sqrt();
        if !gnorm.is_finite() {
            // A NaN/inf gradient (e.g. a barrier evaluated next to its boundary) yields no
            // usable direction; stepping along it would turn the iterate into NaN.
            break;
        }
        if gnorm <= gtol {
            break;
        }

        // L-BFGS two-loop recursion: r ≈ H^{-1} g.
        let hist_len = s_hist.len();
        let mut alphas = vec![0.0_f64; hist_len];
        let mut q = g.clone();
        for i in (0..hist_len).rev() {
            let a = rho_hist[i] * dot(&s_hist[i], &q);
            alphas[i] = a;
            for (qj, &yj) in q.iter_mut().zip(&y_hist[i]) {
                *qj -= a * yj;
            }
        }

        // Initial inverse-Hessian scaling gamma = s·y / y·y from the newest pair.
        let gamma = match (s_hist.last(), y_hist.last()) {
            (Some(s_last), Some(y_last)) => {
                let yy = dot(y_last, y_last);
                if yy > 1e-15 {
                    dot(s_last, y_last) / yy
                } else {
                    1.0
                }
            }
            _ => 1.0,
        };
        let mut r: Vec<f64> = q.iter().map(|&qi| gamma * qi).collect();

        for i in 0..hist_len {
            let beta = rho_hist[i] * dot(&y_hist[i], &r);
            let correction = alphas[i] - beta;
            for (rj, &sj) in r.iter_mut().zip(&s_hist[i]) {
                *rj += sj * correction;
            }
        }

        let mut d: Vec<f64> = r.iter().map(|v| -v).collect();
        let mut dg = dot(&d, &g);
        if dg.is_nan() || dg >= 0.0 {
            // Not a descent direction (rounding or noisy finite-difference curvature broke
            // positive definiteness): discard the memory and use steepest descent.
            s_hist.clear();
            y_hist.clear();
            rho_hist.clear();
            d = g.iter().map(|v| -v).collect();
            dg = dot(&d, &g);
        }

        // Armijo backtracking.
        let mut alpha = 1.0_f64;
        let mut accepted: Option<(Vec<f64>, f64)> = None;
        for _ in 0..MAX_BACKTRACKS {
            let trial = step(&x, &d, alpha);
            *nfev += 1;
            let f_trial = penalty_fn(&trial);
            if f_trial <= fx + ARMIJO_C1 * alpha * dg {
                accepted = Some((trial, f_trial));
                break;
            }
            alpha *= 0.5;
        }

        let (xnew, fnew) = match accepted {
            Some(pair) => pair,
            None => {
                // Backtracking budget exhausted. The final, very small step is still taken so
                // curvature information keeps flowing (useful when the Armijo test fails only
                // because of rounding or a kink in a nonsmooth penalty), but never onto a point
                // where the function is not finite: every later gradient estimate would be NaN.
                let trial = step(&x, &d, alpha);
                *nfev += 1;
                let f_trial = penalty_fn(&trial);
                if !f_trial.is_finite() {
                    break;
                }
                (trial, f_trial)
            }
        };

        let gnew = compute_grad(&xnew, nfev);

        let s: Vec<f64> = xnew.iter().zip(&x).map(|(&xni, &xi)| xni - xi).collect();
        let y: Vec<f64> = gnew.iter().zip(&g).map(|(&gni, &gi)| gni - gi).collect();
        let sy = dot(&s, &y);
        if sy > MIN_CURVATURE {
            if s_hist.len() >= HISTORY {
                s_hist.remove(0);
                y_hist.remove(0);
                rho_hist.remove(0);
            }
            s_hist.push(s);
            y_hist.push(y);
            rho_hist.push(1.0 / sy);
        }

        x = xnew;
        g = gnew;
        fx = fnew;
    }

    x
}
266
267pub struct PenaltyMethod {
269 pub options: PenaltyOptions,
270}
271
272impl PenaltyMethod {
273 pub fn new() -> Self {
275 PenaltyMethod {
276 options: PenaltyOptions::default(),
277 }
278 }
279
280 pub fn with_options(options: PenaltyOptions) -> Self {
282 PenaltyMethod { options }
283 }
284
285 fn compute_violation<E, G>(&self, x: &[f64], eq_cons: &[E], ineq_cons: &[G]) -> f64
287 where
288 E: Fn(&[f64]) -> f64,
289 G: Fn(&[f64]) -> f64,
290 {
291 let eq_viol: f64 = eq_cons.iter().map(|e| e(x).powi(2)).sum();
292 let ineq_viol: f64 = ineq_cons.iter().map(|g| g(x).max(0.0).powi(2)).sum();
293 (eq_viol + ineq_viol).sqrt()
294 }
295
296 pub fn solve_slice<F, E, G>(
304 &self,
305 f: F,
306 eq_cons: &[E],
307 ineq_cons: &[G],
308 x0: &[f64],
309 ) -> OptimizeResult<PenaltyResult>
310 where
311 F: Fn(&[f64]) -> f64,
312 E: Fn(&[f64]) -> f64,
313 G: Fn(&[f64]) -> f64,
314 {
315 match self.options.kind {
316 PenaltyKind::External => self.solve_external_slice(f, eq_cons, ineq_cons, x0),
317 PenaltyKind::Interior => self.solve_interior_slice(f, ineq_cons, x0),
318 PenaltyKind::ExactL1 => self.solve_l1_slice(f, eq_cons, ineq_cons, x0),
319 }
320 }
321
322 pub fn solve<F, E, G>(
324 &self,
325 f: F,
326 eq_cons: &[E],
327 ineq_cons: &[G],
328 x0: &Array1<f64>,
329 ) -> OptimizeResult<PenaltyResult>
330 where
331 F: Fn(&ArrayView1<f64>) -> f64,
332 E: Fn(&ArrayView1<f64>) -> f64,
333 G: Fn(&ArrayView1<f64>) -> f64,
334 {
335 let f_slice = |x: &[f64]| {
337 let arr = Array1::from_vec(x.to_vec());
338 f(&arr.view())
339 };
340 let eq_slice: Vec<Box<dyn Fn(&[f64]) -> f64>> = eq_cons
341 .iter()
342 .map(|e| {
343 Box::new(move |x: &[f64]| {
344 let arr = Array1::from_vec(x.to_vec());
345 e(&arr.view())
346 }) as Box<dyn Fn(&[f64]) -> f64>
347 })
348 .collect();
349 let ineq_slice: Vec<Box<dyn Fn(&[f64]) -> f64>> = ineq_cons
350 .iter()
351 .map(|g| {
352 Box::new(move |x: &[f64]| {
353 let arr = Array1::from_vec(x.to_vec());
354 g(&arr.view())
355 }) as Box<dyn Fn(&[f64]) -> f64>
356 })
357 .collect();
358
359 let x0_slice: Vec<f64> = x0.iter().copied().collect();
360 self.solve_slice(f_slice, &eq_slice, &ineq_slice, &x0_slice)
361 }
362
363 fn solve_external_slice<F, E, G>(
364 &self,
365 f: F,
366 eq_cons: &[E],
367 ineq_cons: &[G],
368 x0: &[f64],
369 ) -> OptimizeResult<PenaltyResult>
370 where
371 F: Fn(&[f64]) -> f64,
372 E: Fn(&[f64]) -> f64,
373 G: Fn(&[f64]) -> f64,
374 {
375 let mut x = x0.to_vec();
376 let mut mu = self.options.mu_init;
377 let mut nfev_total = 0usize;
378 let mut nit = 0usize;
379
380 for _outer in 0..self.options.max_outer_iter {
381 nit += 1;
382 let mu_local = mu;
383
384 let penalty_at = |xv: &[f64]| -> f64 {
386 let obj = f(xv);
387 let penalty: f64 = eq_cons
388 .iter()
389 .map(|e| mu_local * e(xv).powi(2))
390 .sum::<f64>()
391 + ineq_cons
392 .iter()
393 .map(|g| mu_local * g(xv).max(0.0).powi(2))
394 .sum::<f64>();
395 obj + penalty
396 };
397
398 let new_x = minimize_penalty_subproblem(
399 penalty_at,
400 &x,
401 1000,
402 self.options.optimality_tol,
403 self.options.eps,
404 &mut nfev_total,
405 );
406 x = new_x;
407
408 let cv = self.compute_violation(&x, eq_cons, ineq_cons);
410 if cv <= self.options.constraint_tol {
411 let fun = f(&x);
412 return Ok(PenaltyResult {
413 x: Array1::from_vec(x),
414 fun,
415 nit,
416 nfev: nfev_total,
417 success: true,
418 message: "Converged: constraint violation below tolerance".to_string(),
419 mu,
420 constraint_violation: cv,
421 });
422 }
423
424 mu = (mu * self.options.mu_factor).min(self.options.mu_max);
425 }
426
427 let cv = self.compute_violation(&x, eq_cons, ineq_cons);
428 let fun = f(&x);
429 let success = cv <= self.options.constraint_tol;
430 Ok(PenaltyResult {
431 x: Array1::from_vec(x),
432 fun,
433 nit,
434 nfev: nfev_total,
435 success,
436 message: if success {
437 "Converged".to_string()
438 } else {
439 format!("Maximum outer iterations reached (cv={:.2e})", cv)
440 },
441 mu,
442 constraint_violation: cv,
443 })
444 }
445
446 fn solve_interior_slice<F, G>(
447 &self,
448 f: F,
449 ineq_cons: &[G],
450 x0: &[f64],
451 ) -> OptimizeResult<PenaltyResult>
452 where
453 F: Fn(&[f64]) -> f64,
454 G: Fn(&[f64]) -> f64,
455 {
456 for g in ineq_cons.iter() {
458 if g(x0) >= 0.0 {
459 return Err(OptimizeError::InvalidInput(
460 "Interior penalty requires strictly feasible starting point (g_i(x0) < 0)"
461 .to_string(),
462 ));
463 }
464 }
465
466 let mut x = x0.to_vec();
467 let mut mu = self.options.mu_init;
468 let mut nfev_total = 0usize;
469 let mut nit = 0usize;
470
471 for _outer in 0..self.options.max_outer_iter {
472 nit += 1;
473 if mu < 1e-12 {
474 break;
475 }
476
477 let mu_local = mu;
478
479 let barrier_fn = |xv: &[f64]| -> f64 {
480 let obj = f(xv);
481 let barrier: f64 = ineq_cons
482 .iter()
483 .map(|g| {
484 let gv = g(xv);
485 if gv < -1e-15 {
486 -mu_local * gv.abs().ln()
487 } else {
488 f64::INFINITY
489 }
490 })
491 .sum();
492 obj + barrier
493 };
494
495 let new_x = minimize_penalty_subproblem(
496 barrier_fn,
497 &x,
498 500,
499 mu * 0.01,
500 self.options.eps,
501 &mut nfev_total,
502 );
503
504 let feasible = ineq_cons.iter().all(|g| g(&new_x) < 0.0);
506 if feasible {
507 x = new_x;
508 }
509
510 mu /= self.options.mu_factor;
511 }
512
513 let fun = f(&x);
514 Ok(PenaltyResult {
515 fun,
516 x: Array1::from_vec(x),
517 nit,
518 nfev: nfev_total,
519 success: true,
520 message: "Interior penalty completed".to_string(),
521 mu,
522 constraint_violation: 0.0,
523 })
524 }
525
526 fn solve_l1_slice<F, E, G>(
527 &self,
528 f: F,
529 eq_cons: &[E],
530 ineq_cons: &[G],
531 x0: &[f64],
532 ) -> OptimizeResult<PenaltyResult>
533 where
534 F: Fn(&[f64]) -> f64,
535 E: Fn(&[f64]) -> f64,
536 G: Fn(&[f64]) -> f64,
537 {
538 let mut x = x0.to_vec();
539 let mu = self.options.mu_init;
540 let mut nfev_total = 0usize;
541 let mut nit = 0usize;
542
543 for _outer in 0..self.options.max_outer_iter {
544 nit += 1;
545
546 let l1_fn = |xv: &[f64]| -> f64 {
547 let obj = f(xv);
548 let penalty: f64 = eq_cons.iter().map(|e| mu * e(xv).abs()).sum::<f64>()
549 + ineq_cons.iter().map(|g| mu * g(xv).max(0.0)).sum::<f64>();
550 obj + penalty
551 };
552
553 let new_x = minimize_penalty_subproblem(
554 l1_fn,
555 &x,
556 1000,
557 self.options.optimality_tol,
558 self.options.eps,
559 &mut nfev_total,
560 );
561 x = new_x;
562
563 let eq_cv: f64 = eq_cons.iter().map(|e| e(&x).abs()).sum();
565 let ineq_cv: f64 = ineq_cons.iter().map(|g| g(&x).max(0.0)).sum();
566 let cv = eq_cv + ineq_cv;
567
568 if cv <= self.options.constraint_tol {
569 let fun = f(&x);
570 return Ok(PenaltyResult {
571 x: Array1::from_vec(x),
572 fun,
573 nit,
574 nfev: nfev_total,
575 success: true,
576 message: "L1 penalty converged".to_string(),
577 mu,
578 constraint_violation: cv,
579 });
580 }
581 }
582
583 let eq_cv: f64 = eq_cons.iter().map(|e| e(&x).abs()).sum();
584 let ineq_cv: f64 = ineq_cons.iter().map(|g| g(&x).max(0.0)).sum();
585 let cv = eq_cv + ineq_cv;
586 let fun = f(&x);
587 Ok(PenaltyResult {
588 x: Array1::from_vec(x),
589 fun,
590 nit,
591 nfev: nfev_total,
592 success: cv <= self.options.constraint_tol,
593 message: "L1 penalty max iterations reached".to_string(),
594 mu,
595 constraint_violation: cv,
596 })
597 }
598}
599
600impl Default for PenaltyMethod {
601 fn default() -> Self {
602 PenaltyMethod::new()
603 }
604}
605
/// Classification of penalty-parameter strategies.
///
/// NOTE(review): nothing in the code visible here consumes this enum; the variant
/// descriptions below are inferred from their names and should be confirmed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PenaltyMethodKind {
    /// Presumably a fixed penalty coefficient (cf. `penalty_function` with a constant
    /// `penalty_coeff`) — TODO confirm.
    Static,

    /// Presumably a coefficient following a predetermined schedule (as `PenaltyMethod` does
    /// with `mu_factor`) — TODO confirm.
    Dynamic,

    /// Presumably a coefficient driven by observed progress, as implemented by
    /// `AdaptivePenalty` — TODO confirm.
    Adaptive,

    /// Presumably the multiplier-based approach of `AugmentedLagrangianSolver` — TODO confirm.
    AugmentedLagrangian,
}
635
636pub fn penalty_function<F, G, H>(
666 x: &[f64],
667 obj: F,
668 ineq_cons: &[G],
669 eq_cons: &[H],
670 penalty_coeff: f64,
671 kind: PenaltyKind,
672) -> f64
673where
674 F: Fn(&[f64]) -> f64,
675 G: Fn(&[f64]) -> f64,
676 H: Fn(&[f64]) -> f64,
677{
678 let f_val = obj(x);
679 match kind {
680 PenaltyKind::External => {
681 let ineq_pen: f64 = ineq_cons
682 .iter()
683 .map(|g| penalty_coeff * g(x).max(0.0).powi(2))
684 .sum();
685 let eq_pen: f64 = eq_cons.iter().map(|h| penalty_coeff * h(x).powi(2)).sum();
686 f_val + ineq_pen + eq_pen
687 }
688 PenaltyKind::Interior => {
689 let barrier: f64 = ineq_cons
690 .iter()
691 .map(|g| {
692 let gv = g(x);
693 if gv < -1e-15 {
694 -penalty_coeff * gv.abs().ln()
695 } else {
696 f64::INFINITY
697 }
698 })
699 .sum();
700 f_val + barrier
701 }
702 PenaltyKind::ExactL1 => {
703 let ineq_pen: f64 = ineq_cons
704 .iter()
705 .map(|g| penalty_coeff * g(x).max(0.0))
706 .sum();
707 let eq_pen: f64 = eq_cons.iter().map(|h| penalty_coeff * h(x).abs()).sum();
708 f_val + ineq_pen + eq_pen
709 }
710 }
711}
712
/// Adaptive penalty-coefficient controller.
///
/// Feed it the constraint violation observed after each outer iteration via
/// [`AdaptivePenalty::update`]. The coefficient is multiplied by `increase_factor` (capped
/// at `max_penalty`) once the relative reduction in violation has fallen short of
/// `improvement_threshold` for `stall_patience` consecutive updates. When `allow_decrease`
/// is set, it is multiplied by `decrease_factor` (floored at `min_penalty`) after a strong
/// improvement (relative reduction above one half).
#[derive(Debug, Clone)]
pub struct AdaptivePenalty {
    /// Current penalty coefficient (also the return value of `update`).
    pub penalty_coeff: f64,
    /// Lower bound applied when the coefficient is decreased.
    pub min_penalty: f64,
    /// Upper bound applied when the coefficient is increased.
    pub max_penalty: f64,
    /// Multiplier applied on a stall.
    pub increase_factor: f64,
    /// Multiplier applied after a strong improvement (only if `allow_decrease`).
    pub decrease_factor: f64,
    /// Minimum relative reduction `(prev - current) / prev` counted as progress.
    pub improvement_threshold: f64,
    /// Whether the coefficient may be decreased after strong improvement.
    pub allow_decrease: bool,
    /// Violation passed to the previous `update`; `+inf` means "no observation yet".
    prev_violation: f64,
    /// Consecutive updates without sufficient progress.
    stall_count: usize,
    /// Number of consecutive stalls required before the coefficient is increased.
    pub stall_patience: usize,
}

impl Default for AdaptivePenalty {
    fn default() -> Self {
        AdaptivePenalty {
            penalty_coeff: 1.0,
            min_penalty: 1e-3,
            max_penalty: 1e10,
            increase_factor: 10.0,
            decrease_factor: 0.5,
            improvement_threshold: 0.25,
            allow_decrease: false,
            prev_violation: f64::INFINITY,
            stall_count: 0,
            stall_patience: 1,
        }
    }
}

impl AdaptivePenalty {
    /// Default controller starting from `initial_penalty`.
    pub fn new(initial_penalty: f64) -> Self {
        AdaptivePenalty {
            penalty_coeff: initial_penalty,
            ..Default::default()
        }
    }

    /// Fully configured controller.
    // One argument per public tuning knob; a builder would add little here.
    #[allow(clippy::too_many_arguments)]
    pub fn with_config(
        initial_penalty: f64,
        min_penalty: f64,
        max_penalty: f64,
        increase_factor: f64,
        decrease_factor: f64,
        improvement_threshold: f64,
        allow_decrease: bool,
        stall_patience: usize,
    ) -> Self {
        AdaptivePenalty {
            penalty_coeff: initial_penalty,
            min_penalty,
            max_penalty,
            increase_factor,
            decrease_factor,
            improvement_threshold,
            allow_decrease,
            prev_violation: f64::INFINITY,
            stall_count: 0,
            stall_patience,
        }
    }

    /// Record `current_violation` and return the (possibly updated) penalty coefficient.
    ///
    /// The first call after construction or [`AdaptivePenalty::reset`] only records the
    /// baseline and leaves the coefficient unchanged.
    pub fn update(&mut self, current_violation: f64) -> f64 {
        // Violations at or below this are treated as "numerically feasible".
        const FEASIBLE_EPS: f64 = 1e-15;

        if self.prev_violation.is_infinite() {
            self.prev_violation = current_violation;
            return self.penalty_coeff;
        }

        let relative_improvement = if self.prev_violation > FEASIBLE_EPS {
            (self.prev_violation - current_violation) / self.prev_violation
        } else if current_violation <= FEASIBLE_EPS {
            // Feasible before and still feasible: full progress.
            1.0
        } else {
            // Feasible before but violated now: an unbounded relative deterioration. This
            // must count as a stall, not as progress (which could even shrink the penalty).
            f64::NEG_INFINITY
        };

        if relative_improvement < self.improvement_threshold {
            self.stall_count += 1;
            if self.stall_count >= self.stall_patience {
                self.penalty_coeff =
                    (self.penalty_coeff * self.increase_factor).min(self.max_penalty);
                self.stall_count = 0;
            }
        } else {
            self.stall_count = 0;
            if self.allow_decrease && relative_improvement > 0.5 {
                self.penalty_coeff =
                    (self.penalty_coeff * self.decrease_factor).max(self.min_penalty);
            }
        }

        self.prev_violation = current_violation;
        self.penalty_coeff
    }

    /// Forget the violation history and stall counter; `penalty_coeff` is kept.
    pub fn reset(&mut self) {
        self.prev_violation = f64::INFINITY;
        self.stall_count = 0;
    }
}
852
/// Tuning parameters for [`AugmentedLagrangianSolver`].
#[derive(Debug, Clone)]
pub struct AugLagOptions {
    /// Initial quadratic penalty weight `rho`.
    pub rho_init: f64,
    /// Upper bound on `rho`.
    pub rho_max: f64,
    /// Factor applied to `rho` when the constraint violation did not shrink enough.
    pub rho_factor: f64,
    /// Maximum number of outer (multiplier / penalty update) iterations.
    pub max_outer_iter: usize,
    /// Constraint violation (Euclidean norm) at or below which the solver stops successfully.
    pub constraint_tol: f64,
    /// Gradient-norm tolerance for the inner unconstrained solve.
    pub optimality_tol: f64,
    /// Finite-difference step for gradient estimates.
    pub eps: f64,
    /// Minimum relative reduction of the violation between outer iterations required to
    /// update the multipliers; otherwise `rho` is increased instead.
    pub violation_reduction_threshold: f64,
}

impl Default for AugLagOptions {
    fn default() -> Self {
        Self {
            // Penalty-weight schedule.
            rho_init: 1.0,
            rho_factor: 10.0,
            rho_max: 1e10,
            violation_reduction_threshold: 0.25,
            // Outer loop and inner-solver accuracy.
            max_outer_iter: 100,
            constraint_tol: 1e-6,
            optimality_tol: 1e-8,
            eps: 1e-7,
        }
    }
}
894
/// Outcome of an [`AugmentedLagrangianSolver`] solve.
#[derive(Debug, Clone)]
pub struct AugLagResult {
    /// Final iterate.
    pub x: Array1<f64>,
    /// Objective `f(x)` at the final iterate (without multiplier or penalty terms).
    pub fun: f64,
    /// Number of outer iterations performed.
    pub nit: usize,
    /// Evaluations of the augmented Lagrangian counted by the inner solver; constraint
    /// evaluations made directly by the outer loop are not included.
    pub nfev: usize,
    /// `true` when `constraint_violation <= constraint_tol`.
    pub success: bool,
    /// Human-readable termination reason.
    pub message: String,
    /// Equality multiplier estimates, one per equality constraint in input order.
    pub lambda_eq: Vec<f64>,
    /// Inequality multiplier estimates (kept non-negative), one per inequality constraint.
    pub lambda_ineq: Vec<f64>,
    /// Penalty weight `rho` held when the solver returned.
    pub rho: f64,
    /// Euclidean norm of the constraint violation at `x`.
    pub constraint_violation: f64,
}
919
920impl From<AugLagResult> for OptimizeResults<f64> {
921 fn from(r: AugLagResult) -> Self {
922 OptimizeResults {
923 x: r.x,
924 fun: r.fun,
925 jac: None,
926 hess: None,
927 constr: None,
928 nit: r.nit,
929 nfev: r.nfev,
930 njev: 0,
931 nhev: 0,
932 maxcv: 0,
933 message: r.message,
934 success: r.success,
935 status: if r.success { 0 } else { 1 },
936 }
937 }
938}
939
/// Augmented-Lagrangian (method of multipliers) solver for
/// `min f(x)` subject to `h_j(x) = 0` and `g_i(x) <= 0`.
///
/// Each outer iteration minimises the augmented Lagrangian with the module's
/// finite-difference L-BFGS inner solver, then either updates the multipliers or
/// increases `rho`, depending on how much the constraint violation shrank.
#[derive(Debug, Clone)]
pub struct AugmentedLagrangianSolver {
    /// Solver configuration.
    pub options: AugLagOptions,
}
974
975impl AugmentedLagrangianSolver {
976 pub fn new() -> Self {
978 AugmentedLagrangianSolver {
979 options: AugLagOptions::default(),
980 }
981 }
982
983 pub fn with_options(options: AugLagOptions) -> Self {
985 AugmentedLagrangianSolver { options }
986 }
987
988 fn compute_violation<E, G>(&self, x: &[f64], eq_cons: &[E], ineq_cons: &[G]) -> f64
990 where
991 E: Fn(&[f64]) -> f64,
992 G: Fn(&[f64]) -> f64,
993 {
994 let eq_sq: f64 = eq_cons.iter().map(|h| h(x).powi(2)).sum();
995 let ineq_sq: f64 = ineq_cons.iter().map(|g| g(x).max(0.0).powi(2)).sum();
996 (eq_sq + ineq_sq).sqrt()
997 }
998
999 pub fn solve<F, E, G>(
1007 &self,
1008 f: F,
1009 eq_cons: &[E],
1010 ineq_cons: &[G],
1011 x0: &[f64],
1012 ) -> OptimizeResult<AugLagResult>
1013 where
1014 F: Fn(&[f64]) -> f64,
1015 E: Fn(&[f64]) -> f64,
1016 G: Fn(&[f64]) -> f64,
1017 {
1018 let n_eq = eq_cons.len();
1019 let n_ineq = ineq_cons.len();
1020
1021 let mut lambda_eq = vec![0.0_f64; n_eq];
1023 let mut lambda_ineq = vec![0.0_f64; n_ineq];
1024 let mut rho = self.options.rho_init;
1025
1026 let mut x = x0.to_vec();
1027 let mut nfev_total = 0usize;
1028 let mut nit = 0usize;
1029 let mut prev_violation = f64::INFINITY;
1030
1031 for _outer in 0..self.options.max_outer_iter {
1032 nit += 1;
1033
1034 let lam_eq_snap = lambda_eq.clone();
1036 let lam_ineq_snap = lambda_ineq.clone();
1037 let rho_snap = rho;
1038
1039 let aug_lag = |xv: &[f64]| -> f64 {
1041 let mut val = f(xv);
1042
1043 for (j, h) in eq_cons.iter().enumerate() {
1045 let hv = h(xv);
1046 val += lam_eq_snap[j] * hv + 0.5 * rho_snap * hv * hv;
1047 }
1048
1049 for (i, g) in ineq_cons.iter().enumerate() {
1056 let gv = g(xv);
1057 let shifted = gv + lam_ineq_snap[i] / rho_snap;
1058 if shifted > 0.0 {
1059 val += lam_ineq_snap[i] * gv + 0.5 * rho_snap * gv * gv;
1060 }
1061 }
1063
1064 val
1065 };
1066
1067 let new_x = minimize_penalty_subproblem(
1069 aug_lag,
1070 &x,
1071 2000,
1072 self.options.optimality_tol,
1073 self.options.eps,
1074 &mut nfev_total,
1075 );
1076
1077 x = new_x;
1078
1079 let cv = self.compute_violation(&x, eq_cons, ineq_cons);
1081
1082 if cv <= self.options.constraint_tol {
1084 let fun = f(&x);
1085 return Ok(AugLagResult {
1086 x: Array1::from_vec(x),
1087 fun,
1088 nit,
1089 nfev: nfev_total,
1090 success: true,
1091 message: "Converged: constraint violation below tolerance".to_string(),
1092 lambda_eq,
1093 lambda_ineq,
1094 rho,
1095 constraint_violation: cv,
1096 });
1097 }
1098
1099 let violation_improvement = if prev_violation.is_finite() && prev_violation > 1e-15 {
1101 (prev_violation - cv) / prev_violation
1102 } else {
1103 0.0
1104 };
1105
1106 if violation_improvement >= self.options.violation_reduction_threshold {
1107 for (j, h) in eq_cons.iter().enumerate() {
1109 lambda_eq[j] += rho * h(&x);
1110 }
1111 for (i, g) in ineq_cons.iter().enumerate() {
1112 lambda_ineq[i] = (lambda_ineq[i] + rho * g(&x)).max(0.0);
1113 }
1114 } else {
1115 rho = (rho * self.options.rho_factor).min(self.options.rho_max);
1117 }
1118
1119 prev_violation = cv;
1120 }
1121
1122 let cv = self.compute_violation(&x, eq_cons, ineq_cons);
1124 let fun = f(&x);
1125 let success = cv <= self.options.constraint_tol;
1126 Ok(AugLagResult {
1127 x: Array1::from_vec(x),
1128 fun,
1129 nit,
1130 nfev: nfev_total,
1131 success,
1132 message: if success {
1133 "Converged".to_string()
1134 } else {
1135 format!(
1136 "Maximum outer iterations ({}) reached; cv={:.2e}",
1137 self.options.max_outer_iter, cv
1138 )
1139 },
1140 lambda_eq,
1141 lambda_ineq,
1142 rho,
1143 constraint_violation: cv,
1144 })
1145 }
1146}
1147
1148impl Default for AugmentedLagrangianSolver {
1149 fn default() -> Self {
1150 AugmentedLagrangianSolver::new()
1151 }
1152}
1153
// Unit tests for the penalty, barrier, adaptive-penalty and augmented-Lagrangian solvers.
// Each solver test uses a problem with a known analytical optimum, stated in its comment.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    // min x0^2 + x1^2  s.t.  x0 + x1 = 1  ->  x* = (0.5, 0.5), f* = 0.5.
    #[test]
    fn test_penalty_equality_constraint() {
        let f = |x: &[f64]| x[0].powi(2) + x[1].powi(2);
        let h = |x: &[f64]| x[0] + x[1] - 1.0;

        let opts = PenaltyOptions {
            kind: PenaltyKind::External,
            mu_init: 1.0,
            mu_factor: 10.0,
            mu_max: 1e8,
            max_outer_iter: 30,
            constraint_tol: 1e-4,
            ..Default::default()
        };
        let solver = PenaltyMethod::with_options(opts);
        let result = solver
            .solve_slice(f, &[h], &[] as &[fn(&[f64]) -> f64], &[0.0, 0.0])
            .expect("solve failed");

        assert_abs_diff_eq!(result.x[0], 0.5, epsilon = 1e-2);
        assert_abs_diff_eq!(result.x[1], 0.5, epsilon = 1e-2);
        assert_abs_diff_eq!(result.fun, 0.5, epsilon = 1e-2);
    }

    // min x0^2 + x1^2  s.t.  1 - x0 - x1 <= 0 (x0 + x1 >= 1). The constraint is active at
    // the optimum (0.5, 0.5), f* = 0.5; the start (2, 2) is feasible.
    #[test]
    fn test_penalty_inequality_constraint() {
        let f = |x: &[f64]| x[0].powi(2) + x[1].powi(2);
        let g = |x: &[f64]| 1.0 - x[0] - x[1];
        let opts = PenaltyOptions {
            kind: PenaltyKind::External,
            mu_init: 1.0,
            mu_factor: 10.0,
            mu_max: 1e8,
            max_outer_iter: 40,
            constraint_tol: 1e-3,
            ..Default::default()
        };
        let solver = PenaltyMethod::with_options(opts);
        let result = solver
            .solve_slice(f, &[] as &[fn(&[f64]) -> f64], &[g], &[2.0, 2.0])
            .expect("solve failed");

        assert_abs_diff_eq!(result.fun, 0.5, epsilon = 1e-2);
    }

    // Unconstrained: min (x0 - 3)^2 + (x1 - 4)^2  ->  x* = (3, 4), f* = 0.
    #[test]
    fn test_penalty_no_constraints() {
        let f = |x: &[f64]| (x[0] - 3.0).powi(2) + (x[1] - 4.0).powi(2);
        let solver = PenaltyMethod::new();
        let result = solver
            .solve_slice(
                f,
                &[] as &[fn(&[f64]) -> f64],
                &[] as &[fn(&[f64]) -> f64],
                &[0.0, 0.0],
            )
            .expect("solve failed");

        assert_abs_diff_eq!(result.fun, 0.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[0], 3.0, epsilon = 1e-2);
        assert_abs_diff_eq!(result.x[1], 4.0, epsilon = 1e-2);
    }

    // Exact L1 penalty. The unconstrained minimiser (1, 2) already satisfies x0 + x1 = 3,
    // so f* = 0.
    #[test]
    fn test_penalty_l1_equality() {
        let f = |x: &[f64]| (x[0] - 1.0).powi(2) + (x[1] - 2.0).powi(2);
        let h = |x: &[f64]| x[0] + x[1] - 3.0;

        let opts = PenaltyOptions {
            kind: PenaltyKind::ExactL1,
            mu_init: 10.0,
            max_outer_iter: 50,
            constraint_tol: 1e-3,
            ..Default::default()
        };
        let solver = PenaltyMethod::with_options(opts);
        let result = solver
            .solve_slice(f, &[h], &[] as &[fn(&[f64]) -> f64], &[0.0, 0.0])
            .expect("solve failed");

        assert_abs_diff_eq!(result.fun, 0.0, epsilon = 1e-2);
    }

    // Log barrier: min (x0 - 0.5)^2  s.t.  x0 - 0.999 <= 0. The constraint is inactive, so
    // x* = 0.5; the start 0.3 is strictly feasible as the interior method requires.
    #[test]
    fn test_penalty_interior_barrier() {
        let f = |x: &[f64]| (x[0] - 0.5).powi(2);
        let g = |x: &[f64]| x[0] - 0.999;
        let opts = PenaltyOptions {
            kind: PenaltyKind::Interior,
            mu_init: 0.1,
            mu_factor: 5.0,
            max_outer_iter: 20,
            ..Default::default()
        };
        let solver = PenaltyMethod::with_options(opts);
        let result = solver
            .solve_slice(f, &[] as &[fn(&[f64]) -> f64], &[g], &[0.3])
            .expect("solve failed");

        assert_abs_diff_eq!(result.x[0], 0.5, epsilon = 0.1);
    }

    // min |x|^2  s.t.  x0 + x1 + x2 = 3,  x2 - 1 <= 0  ->  x* = (1, 1, 1), f* = 3 (the
    // inequality holds with equality there). Only a loose upper bound on f is asserted.
    #[test]
    fn test_penalty_mixed_constraints() {
        let f = |x: &[f64]| x[0].powi(2) + x[1].powi(2) + x[2].powi(2);
        let h = |x: &[f64]| x[0] + x[1] + x[2] - 3.0;
        let g = |x: &[f64]| x[2] - 1.0;

        let opts = PenaltyOptions {
            kind: PenaltyKind::External,
            mu_init: 1.0,
            mu_factor: 5.0,
            mu_max: 1e7,
            max_outer_iter: 50,
            constraint_tol: 1e-3,
            ..Default::default()
        };
        let solver = PenaltyMethod::with_options(opts);
        let result = solver
            .solve_slice(f, &[h], &[g], &[0.0, 0.0, 0.0])
            .expect("solve failed");

        assert!(result.fun <= 4.0, "fun={}", result.fun);
    }

    // Same problem as test_penalty_equality_constraint, driven through the ArrayView1 API.
    #[test]
    fn test_penalty_arrayview1_interface() {
        use scirs2_core::ndarray::{array, ArrayView1};
        let f = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
        let h = |x: &ArrayView1<f64>| x[0] + x[1] - 1.0;

        let opts = PenaltyOptions {
            kind: PenaltyKind::External,
            max_outer_iter: 20,
            constraint_tol: 1e-3,
            ..Default::default()
        };
        let solver = PenaltyMethod::with_options(opts);
        let x0 = array![0.0, 0.0];
        let result = solver
            .solve(f, &[h], &[] as &[fn(&ArrayView1<f64>) -> f64], &x0)
            .expect("solve failed");

        assert_abs_diff_eq!(result.fun, 0.5, epsilon = 5e-2);
    }

    // At (0, 0), g = -1 <= 0 is satisfied, so the external penalty adds nothing: value = f = 0.
    #[test]
    fn test_penalty_function_external_feasible() {
        let f = |x: &[f64]| x[0].powi(2) + x[1].powi(2);
        let g = |x: &[f64]| x[0] + x[1] - 1.0;
        let val = penalty_function(
            &[0.0, 0.0],
            f,
            &[g],
            &[] as &[fn(&[f64]) -> f64],
            10.0,
            PenaltyKind::External,
        );
        assert_abs_diff_eq!(val, 0.0, epsilon = 1e-12);
    }

    // At (1, 1): f = 2 and g = 1 > 0, so the value is 2 + mu * 1^2.
    #[test]
    fn test_penalty_function_external_violated() {
        let f = |x: &[f64]| x[0].powi(2) + x[1].powi(2);
        let g = |x: &[f64]| x[0] + x[1] - 1.0;
        let mu = 5.0;
        let val = penalty_function(
            &[1.0, 1.0],
            f,
            &[g],
            &[] as &[fn(&[f64]) -> f64],
            mu,
            PenaltyKind::External,
        );
        assert_abs_diff_eq!(val, 2.0 + mu * 1.0_f64.powi(2), epsilon = 1e-12);
    }

    // Zero objective, h(2) = 1, so the value is mu * 1^2 = 3.
    #[test]
    fn test_penalty_function_equality() {
        let f = |_x: &[f64]| 0.0;
        let h = |x: &[f64]| x[0] - 1.0;
        let mu = 3.0;
        let val = penalty_function(
            &[2.0],
            f,
            &[] as &[fn(&[f64]) -> f64],
            &[h],
            mu,
            PenaltyKind::External,
        );
        assert_abs_diff_eq!(val, 3.0, epsilon = 1e-12);
    }

    // The first update only records a baseline; a 5% reduction (< 10% threshold) then
    // counts as a stall and, with patience 1, raises the penalty.
    #[test]
    fn test_adaptive_penalty_increases_on_stall() {
        let mut ap = AdaptivePenalty::new(1.0);
        ap.increase_factor = 5.0;
        ap.improvement_threshold = 0.1;
        ap.stall_patience = 1;

        let p0 = ap.update(1.0);
        assert_abs_diff_eq!(p0, 1.0, epsilon = 1e-12);

        let p1 = ap.update(0.95);
        assert!(p1 > 1.0, "Penalty should have increased; got {p1}");
    }

    // A 50% reduction exceeds the 10% threshold, so the penalty is left unchanged
    // (decreases are disabled by default).
    #[test]
    fn test_adaptive_penalty_no_increase_on_good_progress() {
        let mut ap = AdaptivePenalty::new(1.0);
        ap.improvement_threshold = 0.1;
        ap.stall_patience = 1;

        ap.update(1.0);
        let p1 = ap.update(0.5);
        assert_abs_diff_eq!(p1, 1.0, epsilon = 1e-12);
    }

    // 1e9 * 1000 would exceed max_penalty = 1e10; the increase must be capped.
    #[test]
    fn test_adaptive_penalty_capped_at_max() {
        let mut ap = AdaptivePenalty::with_config(1e9, 1e-3, 1e10, 1000.0, 0.5, 0.1, false, 1);
        ap.update(1.0);
        let p = ap.update(1.0);
        assert!(p <= 1e10, "Penalty exceeded max; got {p}");
    }

    // Augmented Lagrangian on min x0^2 + x1^2 s.t. x0 + x1 = 1  ->  f* = 0.5.
    #[test]
    fn test_aug_lag_equality_constraint() {
        let f = |x: &[f64]| x[0].powi(2) + x[1].powi(2);
        let h = |x: &[f64]| x[0] + x[1] - 1.0;

        let opts = AugLagOptions {
            rho_init: 1.0,
            rho_factor: 5.0,
            max_outer_iter: 50,
            constraint_tol: 1e-4,
            ..Default::default()
        };
        let solver = AugmentedLagrangianSolver::with_options(opts);
        let result = solver
            .solve(f, &[h], &[] as &[fn(&[f64]) -> f64], &[0.0, 0.0])
            .expect("AugLag solve failed");

        assert!(
            result.success || result.constraint_violation < 1e-2,
            "cv={}",
            result.constraint_violation
        );
        assert_abs_diff_eq!(result.fun, 0.5, epsilon = 0.05);
    }

    // min x0^2  s.t.  1 - x0 <= 0 (x0 >= 1)  ->  x* = 1 with the constraint active.
    #[test]
    fn test_aug_lag_inequality_constraint() {
        let f = |x: &[f64]| x[0].powi(2);
        let g = |x: &[f64]| 1.0 - x[0];
        let opts = AugLagOptions {
            rho_init: 1.0,
            rho_factor: 10.0,
            max_outer_iter: 60,
            constraint_tol: 1e-3,
            ..Default::default()
        };
        let solver = AugmentedLagrangianSolver::with_options(opts);
        let result = solver
            .solve(f, &[] as &[fn(&[f64]) -> f64], &[g], &[0.5])
            .expect("AugLag solve failed");

        assert!(
            result.x[0] >= 0.9 && result.x[0] <= 1.2,
            "Expected x~1.0, got x={}",
            result.x[0]
        );
    }

    // Unconstrained: min (x0 - 2)^2 + (x1 - 3)^2  ->  x* = (2, 3), f* = 0.
    #[test]
    fn test_aug_lag_no_constraints() {
        let f = |x: &[f64]| (x[0] - 2.0).powi(2) + (x[1] - 3.0).powi(2);
        let solver = AugmentedLagrangianSolver::new();
        let result = solver
            .solve(
                f,
                &[] as &[fn(&[f64]) -> f64],
                &[] as &[fn(&[f64]) -> f64],
                &[0.0, 0.0],
            )
            .expect("AugLag solve failed");

        assert_abs_diff_eq!(result.x[0], 2.0, epsilon = 1e-2);
        assert_abs_diff_eq!(result.x[1], 3.0, epsilon = 1e-2);
        assert_abs_diff_eq!(result.fun, 0.0, epsilon = 1e-3);
    }

    // Sanity check of the derived PartialEq on PenaltyMethodKind.
    #[test]
    fn test_penalty_method_kind_variants() {
        assert_eq!(PenaltyMethodKind::Static, PenaltyMethodKind::Static);
        assert_ne!(PenaltyMethodKind::Dynamic, PenaltyMethodKind::Adaptive);
        assert_ne!(
            PenaltyMethodKind::AugmentedLagrangian,
            PenaltyMethodKind::Static
        );
    }
}