Skip to main content

scirs2_optimize/constrained/
penalty.rs

1//! Penalty Methods for Constrained Optimization
2//!
3//! Penalty methods transform constrained problems into a sequence of unconstrained
4//! problems by adding a penalty term to the objective function that penalizes
5//! constraint violations.
6//!
7//! Two major classes are implemented:
8//!
9//! ## External Penalty Method
10//! Adds a penalty for constraint violations:
11//! ```text
12//! P(x, mu) = f(x) + mu * sum_i max(0, g_i(x))^2 + mu * sum_j h_j(x)^2
13//! ```
14//! The penalty parameter mu is increased until the constraints are satisfied.
15//!
16//! ## Interior Penalty (Barrier) Method
17//! Adds a barrier function that prevents leaving the feasible region:
18//! ```text
19//! B(x, mu) = f(x) - mu * sum_i ln(-g_i(x))   (log barrier)
20//! ```
21//! The barrier parameter mu is decreased to zero.
22//!
23//! # References
24//! - Fiacco, A.V. & McCormick, G.P. (1968). "Nonlinear Programming: Sequential
25//!   Unconstrained Minimization Techniques." SIAM.
26//! - Nocedal, J. & Wright, S.J. (2006). "Numerical Optimization." Chapter 17.
27
28use crate::error::{OptimizeError, OptimizeResult};
29use crate::result::OptimizeResults;
30use scirs2_core::ndarray::{Array1, ArrayView1};
31
/// Type of penalty method.
///
/// Selects the *form* of the term that is added to the objective; the
/// variants carry no data, so the enum is cheap to copy and fully comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PenaltyKind {
    /// External penalty: quadratic penalty for violations
    External,
    /// Interior penalty (log-barrier): only works from strictly feasible start
    Interior,
    /// Exact L1 penalty: |max(0, g_i(x))| + |h_j(x)|
    ExactL1,
}
42
/// Options for the penalty method
///
/// One option set is shared by every [`PenaltyKind`]; which fields are read
/// depends on the selected kind.
#[derive(Debug, Clone)]
pub struct PenaltyOptions {
    /// Penalty method kind
    pub kind: PenaltyKind,
    /// Initial penalty parameter (initial barrier weight for the interior method)
    pub mu_init: f64,
    /// Maximum penalty parameter (external) / min barrier (interior)
    ///
    /// NOTE(review): the interior solver in this file stops on a hard-coded
    /// floor of `1e-12` and does not appear to read this field — confirm
    /// whether the "min barrier" role is intended.
    pub mu_max: f64,
    /// Penalty increase factor (external) or decrease factor (interior):
    /// the external method multiplies `mu` by it, the interior method divides
    /// `mu` by it, once per outer iteration.
    pub mu_factor: f64,
    /// Maximum number of outer iterations
    pub max_outer_iter: usize,
    /// Tolerance for constraint violation
    pub constraint_tol: f64,
    /// Tolerance for optimality of subproblem (gradient-norm threshold of the
    /// inner minimizer)
    pub optimality_tol: f64,
    /// Finite difference step for gradient
    pub eps: f64,
}
63
64impl Default for PenaltyOptions {
65    fn default() -> Self {
66        PenaltyOptions {
67            kind: PenaltyKind::External,
68            mu_init: 1.0,
69            mu_max: 1e10,
70            mu_factor: 10.0,
71            max_outer_iter: 100,
72            constraint_tol: 1e-6,
73            optimality_tol: 1e-8,
74            eps: 1e-7,
75        }
76    }
77}
78
/// Result from penalty method
#[derive(Debug, Clone)]
pub struct PenaltyResult {
    /// Optimal solution
    pub x: Array1<f64>,
    /// Optimal objective value (the plain objective at `x`, without penalty terms)
    pub fun: f64,
    /// Number of outer iterations
    pub nit: usize,
    /// Total number of function evaluations (as counted by the inner minimizer)
    pub nfev: usize,
    /// Success flag
    pub success: bool,
    /// Status message
    pub message: String,
    /// Final penalty parameter
    pub mu: f64,
    /// Final constraint violation.
    ///
    /// The measure depends on the method: the external solver reports the
    /// Euclidean norm of the violations, the L1 solver reports their sum of
    /// absolute values, and the interior solver always reports `0.0`.
    pub constraint_violation: f64,
}
99
100impl From<PenaltyResult> for OptimizeResults<f64> {
101    fn from(r: PenaltyResult) -> Self {
102        OptimizeResults {
103            x: r.x,
104            fun: r.fun,
105            jac: None,
106            hess: None,
107            constr: None,
108            nit: r.nit,
109            nfev: r.nfev,
110            njev: 0,
111            nhev: 0,
112            maxcv: 0,
113            message: r.message,
114            success: r.success,
115            status: if r.success { 0 } else { 1 },
116        }
117    }
118}
119
/// Internal gradient-based minimizer for penalty subproblems.
///
/// Runs L-BFGS (two-loop recursion over the last 5 curvature pairs) with
/// central finite-difference gradients and an Armijo backtracking line search.
///
/// # Arguments
/// - `penalty_fn`: scalar function to minimize
/// - `x0`: starting point
/// - `max_iter`: maximum number of iterations
/// - `gtol`: stop as soon as the Euclidean gradient norm drops below this
/// - `eps`: finite-difference step used for the gradient
/// - `nfev`: running counter, incremented for every `penalty_fn` evaluation
///
/// # Returns
/// The last *accepted* iterate. A step is only taken when the line search
/// finds a trial point that passes the sufficient-decrease test, so the
/// returned point never has a larger `penalty_fn` value than `x0`. If no such
/// point exists along the quasi-Newton direction the curvature history is
/// discarded and steepest descent is tried; if that fails too the iteration
/// stops at the current point.
fn minimize_penalty_subproblem<P>(
    penalty_fn: P,
    x0: &[f64],
    max_iter: usize,
    gtol: f64,
    eps: f64,
    nfev: &mut usize,
) -> Vec<f64>
where
    P: Fn(&[f64]) -> f64,
{
    /// Number of (s, y) pairs kept for the L-BFGS approximation.
    const HISTORY: usize = 5;
    /// Maximum number of step halvings per line search.
    const MAX_BACKTRACKS: usize = 20;
    /// Armijo sufficient-decrease constant.
    const ARMIJO_C1: f64 = 1e-4;

    fn dot(a: &[f64], b: &[f64]) -> f64 {
        a.iter().zip(b.iter()).map(|(&u, &v)| u * v).sum()
    }

    let n = x0.len();
    let mut x = x0.to_vec();

    let mut s_hist: Vec<Vec<f64>> = Vec::with_capacity(HISTORY);
    let mut y_hist: Vec<Vec<f64>> = Vec::with_capacity(HISTORY);
    let mut rho_hist: Vec<f64> = Vec::with_capacity(HISTORY);

    // Central differences: two evaluations per coordinate.
    let compute_grad = |xv: &[f64], nfev: &mut usize| -> Vec<f64> {
        let mut xp = xv.to_vec();
        let mut xm = xv.to_vec();
        *nfev += 2 * n;
        (0..n)
            .map(|i| {
                xp[i] = xv[i] + eps;
                xm[i] = xv[i] - eps;
                let gi = (penalty_fn(&xp) - penalty_fn(&xm)) / (2.0 * eps);
                xp[i] = xv[i];
                xm[i] = xv[i];
                gi
            })
            .collect()
    };

    let mut g = compute_grad(&x, nfev);
    // Value of `penalty_fn` at `x`, when already known from an accepted trial.
    let mut fx_cache: Option<f64> = None;

    for _ in 0..max_iter {
        let gnorm = dot(&g, &g).sqrt();
        if gnorm < gtol {
            break;
        }

        // L-BFGS two-loop recursion: first pass, newest pair to oldest.
        let hist_len = s_hist.len();
        let mut q = g.clone();
        let mut alpha_hist = vec![0.0_f64; hist_len];
        for i in (0..hist_len).rev() {
            let a = rho_hist[i] * dot(&s_hist[i], &q);
            alpha_hist[i] = a;
            for (qj, &yj) in q.iter_mut().zip(y_hist[i].iter()) {
                *qj -= a * yj;
            }
        }

        // Initial Hessian scaling gamma = s'y / y'y from the newest pair.
        let mut r = q;
        if let (Some(last_s), Some(last_y)) = (s_hist.last(), y_hist.last()) {
            let yy = dot(last_y, last_y);
            if yy > 1e-15 {
                let scale = dot(last_s, last_y) / yy;
                for ri in r.iter_mut() {
                    *ri *= scale;
                }
            }
        }

        // Second pass, oldest pair to newest.
        for i in 0..hist_len {
            let beta = rho_hist[i] * dot(&y_hist[i], &r);
            let diff = alpha_hist[i] - beta;
            for (rj, &sj) in r.iter_mut().zip(s_hist[i].iter()) {
                *rj += sj * diff;
            }
        }

        // Direction: d = -r
        let d: Vec<f64> = r.iter().map(|v| -v).collect();

        let fx = match fx_cache {
            Some(value) => value,
            None => {
                *nfev += 1;
                penalty_fn(&x)
            }
        };
        let dg = dot(&d, &g);

        // Backtracking line search; only a trial that passed the test is kept.
        // NaN trial values fail the comparison and are therefore rejected.
        let mut alpha = 1.0_f64;
        let mut accepted: Option<(Vec<f64>, f64)> = None;
        for _ in 0..MAX_BACKTRACKS {
            let trial: Vec<f64> = x
                .iter()
                .zip(d.iter())
                .map(|(&xi, &di)| xi + alpha * di)
                .collect();
            *nfev += 1;
            let f_trial = penalty_fn(&trial);
            if f_trial <= fx + ARMIJO_C1 * alpha * dg.min(0.0) {
                accepted = Some((trial, f_trial));
                break;
            }
            alpha *= 0.5;
        }

        let (xnew, fnew) = match accepted {
            Some(step) => step,
            None if hist_len > 0 => {
                // The quasi-Newton direction produced no decrease: drop the
                // curvature model and retry from here with steepest descent.
                s_hist.clear();
                y_hist.clear();
                rho_hist.clear();
                fx_cache = Some(fx);
                continue;
            }
            // Steepest descent failed as well: no further progress possible.
            None => break,
        };

        let gnew = compute_grad(&xnew, nfev);

        // L-BFGS update; skip pairs that violate the curvature condition or
        // contain non-finite values (e.g. a gradient sampled across a barrier).
        let s: Vec<f64> = xnew.iter().zip(x.iter()).map(|(&a, &b)| a - b).collect();
        let y: Vec<f64> = gnew.iter().zip(g.iter()).map(|(&a, &b)| a - b).collect();
        let sy = dot(&s, &y);
        if sy.is_finite() && sy > 1e-10 {
            if s_hist.len() >= HISTORY {
                s_hist.remove(0);
                y_hist.remove(0);
                rho_hist.remove(0);
            }
            s_hist.push(s);
            y_hist.push(y);
            rho_hist.push(1.0 / sy);
        }

        x = xnew;
        g = gnew;
        fx_cache = Some(fnew);
    }

    x
}
266
/// Penalty method solver
///
/// Holds the configuration; the penalty form actually used is selected by
/// `options.kind` when one of the `solve*` methods is called.
pub struct PenaltyMethod {
    /// Solver configuration.
    pub options: PenaltyOptions,
}
271
272impl PenaltyMethod {
273    /// Create with default options
274    pub fn new() -> Self {
275        PenaltyMethod {
276            options: PenaltyOptions::default(),
277        }
278    }
279
280    /// Create with custom options
281    pub fn with_options(options: PenaltyOptions) -> Self {
282        PenaltyMethod { options }
283    }
284
285    /// Compute constraint violation at x
286    fn compute_violation<E, G>(&self, x: &[f64], eq_cons: &[E], ineq_cons: &[G]) -> f64
287    where
288        E: Fn(&[f64]) -> f64,
289        G: Fn(&[f64]) -> f64,
290    {
291        let eq_viol: f64 = eq_cons.iter().map(|e| e(x).powi(2)).sum();
292        let ineq_viol: f64 = ineq_cons.iter().map(|g| g(x).max(0.0).powi(2)).sum();
293        (eq_viol + ineq_viol).sqrt()
294    }
295
296    /// Solve with equality and inequality constraints.
297    ///
298    /// # Arguments
299    /// - `f`: Objective function taking a `&[f64]`
300    /// - `eq_cons`: Equality constraint functions h_j(x) = 0 (slices)
301    /// - `ineq_cons`: Inequality constraint functions g_i(x) <= 0 (slices)
302    /// - `x0`: Initial point
303    pub fn solve_slice<F, E, G>(
304        &self,
305        f: F,
306        eq_cons: &[E],
307        ineq_cons: &[G],
308        x0: &[f64],
309    ) -> OptimizeResult<PenaltyResult>
310    where
311        F: Fn(&[f64]) -> f64,
312        E: Fn(&[f64]) -> f64,
313        G: Fn(&[f64]) -> f64,
314    {
315        match self.options.kind {
316            PenaltyKind::External => self.solve_external_slice(f, eq_cons, ineq_cons, x0),
317            PenaltyKind::Interior => self.solve_interior_slice(f, ineq_cons, x0),
318            PenaltyKind::ExactL1 => self.solve_l1_slice(f, eq_cons, ineq_cons, x0),
319        }
320    }
321
322    /// Solve with ArrayView1 interface (for compatibility)
323    pub fn solve<F, E, G>(
324        &self,
325        f: F,
326        eq_cons: &[E],
327        ineq_cons: &[G],
328        x0: &Array1<f64>,
329    ) -> OptimizeResult<PenaltyResult>
330    where
331        F: Fn(&ArrayView1<f64>) -> f64,
332        E: Fn(&ArrayView1<f64>) -> f64,
333        G: Fn(&ArrayView1<f64>) -> f64,
334    {
335        // Wrap ArrayView1 closures to slice-based
336        let f_slice = |x: &[f64]| {
337            let arr = Array1::from_vec(x.to_vec());
338            f(&arr.view())
339        };
340        let eq_slice: Vec<Box<dyn Fn(&[f64]) -> f64>> = eq_cons
341            .iter()
342            .map(|e| {
343                Box::new(move |x: &[f64]| {
344                    let arr = Array1::from_vec(x.to_vec());
345                    e(&arr.view())
346                }) as Box<dyn Fn(&[f64]) -> f64>
347            })
348            .collect();
349        let ineq_slice: Vec<Box<dyn Fn(&[f64]) -> f64>> = ineq_cons
350            .iter()
351            .map(|g| {
352                Box::new(move |x: &[f64]| {
353                    let arr = Array1::from_vec(x.to_vec());
354                    g(&arr.view())
355                }) as Box<dyn Fn(&[f64]) -> f64>
356            })
357            .collect();
358
359        let x0_slice: Vec<f64> = x0.iter().copied().collect();
360        self.solve_slice(f_slice, &eq_slice, &ineq_slice, &x0_slice)
361    }
362
363    fn solve_external_slice<F, E, G>(
364        &self,
365        f: F,
366        eq_cons: &[E],
367        ineq_cons: &[G],
368        x0: &[f64],
369    ) -> OptimizeResult<PenaltyResult>
370    where
371        F: Fn(&[f64]) -> f64,
372        E: Fn(&[f64]) -> f64,
373        G: Fn(&[f64]) -> f64,
374    {
375        let mut x = x0.to_vec();
376        let mut mu = self.options.mu_init;
377        let mut nfev_total = 0usize;
378        let mut nit = 0usize;
379
380        for _outer in 0..self.options.max_outer_iter {
381            nit += 1;
382            let mu_local = mu;
383
384            // Build penalty function capturing current state
385            let penalty_at = |xv: &[f64]| -> f64 {
386                let obj = f(xv);
387                let penalty: f64 = eq_cons
388                    .iter()
389                    .map(|e| mu_local * e(xv).powi(2))
390                    .sum::<f64>()
391                    + ineq_cons
392                        .iter()
393                        .map(|g| mu_local * g(xv).max(0.0).powi(2))
394                        .sum::<f64>();
395                obj + penalty
396            };
397
398            let new_x = minimize_penalty_subproblem(
399                penalty_at,
400                &x,
401                1000,
402                self.options.optimality_tol,
403                self.options.eps,
404                &mut nfev_total,
405            );
406            x = new_x;
407
408            // Check convergence after minimization
409            let cv = self.compute_violation(&x, eq_cons, ineq_cons);
410            if cv <= self.options.constraint_tol {
411                let fun = f(&x);
412                return Ok(PenaltyResult {
413                    x: Array1::from_vec(x),
414                    fun,
415                    nit,
416                    nfev: nfev_total,
417                    success: true,
418                    message: "Converged: constraint violation below tolerance".to_string(),
419                    mu,
420                    constraint_violation: cv,
421                });
422            }
423
424            mu = (mu * self.options.mu_factor).min(self.options.mu_max);
425        }
426
427        let cv = self.compute_violation(&x, eq_cons, ineq_cons);
428        let fun = f(&x);
429        let success = cv <= self.options.constraint_tol;
430        Ok(PenaltyResult {
431            x: Array1::from_vec(x),
432            fun,
433            nit,
434            nfev: nfev_total,
435            success,
436            message: if success {
437                "Converged".to_string()
438            } else {
439                format!("Maximum outer iterations reached (cv={:.2e})", cv)
440            },
441            mu,
442            constraint_violation: cv,
443        })
444    }
445
446    fn solve_interior_slice<F, G>(
447        &self,
448        f: F,
449        ineq_cons: &[G],
450        x0: &[f64],
451    ) -> OptimizeResult<PenaltyResult>
452    where
453        F: Fn(&[f64]) -> f64,
454        G: Fn(&[f64]) -> f64,
455    {
456        // Verify strict feasibility at starting point
457        for g in ineq_cons.iter() {
458            if g(x0) >= 0.0 {
459                return Err(OptimizeError::InvalidInput(
460                    "Interior penalty requires strictly feasible starting point (g_i(x0) < 0)"
461                        .to_string(),
462                ));
463            }
464        }
465
466        let mut x = x0.to_vec();
467        let mut mu = self.options.mu_init;
468        let mut nfev_total = 0usize;
469        let mut nit = 0usize;
470
471        for _outer in 0..self.options.max_outer_iter {
472            nit += 1;
473            if mu < 1e-12 {
474                break;
475            }
476
477            let mu_local = mu;
478
479            let barrier_fn = |xv: &[f64]| -> f64 {
480                let obj = f(xv);
481                let barrier: f64 = ineq_cons
482                    .iter()
483                    .map(|g| {
484                        let gv = g(xv);
485                        if gv < -1e-15 {
486                            -mu_local * gv.abs().ln()
487                        } else {
488                            f64::INFINITY
489                        }
490                    })
491                    .sum();
492                obj + barrier
493            };
494
495            let new_x = minimize_penalty_subproblem(
496                barrier_fn,
497                &x,
498                500,
499                mu * 0.01,
500                self.options.eps,
501                &mut nfev_total,
502            );
503
504            // Accept only if still feasible
505            let feasible = ineq_cons.iter().all(|g| g(&new_x) < 0.0);
506            if feasible {
507                x = new_x;
508            }
509
510            mu /= self.options.mu_factor;
511        }
512
513        let fun = f(&x);
514        Ok(PenaltyResult {
515            fun,
516            x: Array1::from_vec(x),
517            nit,
518            nfev: nfev_total,
519            success: true,
520            message: "Interior penalty completed".to_string(),
521            mu,
522            constraint_violation: 0.0,
523        })
524    }
525
526    fn solve_l1_slice<F, E, G>(
527        &self,
528        f: F,
529        eq_cons: &[E],
530        ineq_cons: &[G],
531        x0: &[f64],
532    ) -> OptimizeResult<PenaltyResult>
533    where
534        F: Fn(&[f64]) -> f64,
535        E: Fn(&[f64]) -> f64,
536        G: Fn(&[f64]) -> f64,
537    {
538        let mut x = x0.to_vec();
539        let mu = self.options.mu_init;
540        let mut nfev_total = 0usize;
541        let mut nit = 0usize;
542
543        for _outer in 0..self.options.max_outer_iter {
544            nit += 1;
545
546            let l1_fn = |xv: &[f64]| -> f64 {
547                let obj = f(xv);
548                let penalty: f64 = eq_cons.iter().map(|e| mu * e(xv).abs()).sum::<f64>()
549                    + ineq_cons.iter().map(|g| mu * g(xv).max(0.0)).sum::<f64>();
550                obj + penalty
551            };
552
553            let new_x = minimize_penalty_subproblem(
554                l1_fn,
555                &x,
556                1000,
557                self.options.optimality_tol,
558                self.options.eps,
559                &mut nfev_total,
560            );
561            x = new_x;
562
563            // Check convergence
564            let eq_cv: f64 = eq_cons.iter().map(|e| e(&x).abs()).sum();
565            let ineq_cv: f64 = ineq_cons.iter().map(|g| g(&x).max(0.0)).sum();
566            let cv = eq_cv + ineq_cv;
567
568            if cv <= self.options.constraint_tol {
569                let fun = f(&x);
570                return Ok(PenaltyResult {
571                    x: Array1::from_vec(x),
572                    fun,
573                    nit,
574                    nfev: nfev_total,
575                    success: true,
576                    message: "L1 penalty converged".to_string(),
577                    mu,
578                    constraint_violation: cv,
579                });
580            }
581        }
582
583        let eq_cv: f64 = eq_cons.iter().map(|e| e(&x).abs()).sum();
584        let ineq_cv: f64 = ineq_cons.iter().map(|g| g(&x).max(0.0)).sum();
585        let cv = eq_cv + ineq_cv;
586        let fun = f(&x);
587        Ok(PenaltyResult {
588            x: Array1::from_vec(x),
589            fun,
590            nit,
591            nfev: nfev_total,
592            success: cv <= self.options.constraint_tol,
593            message: "L1 penalty max iterations reached".to_string(),
594            mu,
595            constraint_violation: cv,
596        })
597    }
598}
599
600impl Default for PenaltyMethod {
601    fn default() -> Self {
602        PenaltyMethod::new()
603    }
604}
605
606// ─────────────────────────────────────────────────────────────────────────────
607// PenaltyMethodKind enum (Static / Dynamic / Adaptive / AugmentedLagrangian)
608// ─────────────────────────────────────────────────────────────────────────────
609
/// Strategy for managing the penalty parameter sequence.
///
/// This enum governs *how* the penalty coefficient is updated across outer
/// iterations, distinct from [`PenaltyKind`] which controls the *form* of the
/// penalty term (quadratic / barrier / L1).
///
/// NOTE(review): no code in the visible part of this file matches on this
/// enum; the per-variant descriptions below state the intended strategies —
/// confirm against the code that consumes it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PenaltyMethodKind {
    /// Fixed penalty coefficient — never updated.
    /// Suitable when a good penalty value is known in advance.
    Static,

    /// Penalty is multiplied by a constant factor each outer iteration.
    /// Classic "increasing penalty" approach for exterior methods.
    Dynamic,

    /// Penalty is updated based on constraint violation progress.
    /// Increases faster when violations are not decreasing, avoids ill-
    /// conditioning by capping growth when convergence is progressing well.
    Adaptive,

    /// Augmented Lagrangian — maintains Lagrange multiplier estimates and
    /// updates them after each outer iteration via dual ascent.  The penalty
    /// parameter increases only when the multiplier update stalls.
    AugmentedLagrangian,
}
635
636// ─────────────────────────────────────────────────────────────────────────────
637// penalty_function — standalone evaluation helper
638// ─────────────────────────────────────────────────────────────────────────────
639
640/// Evaluate the penalised objective at point `x`.
641///
642/// For *exterior* (quadratic) penalties:
643/// ```text
644/// P(x, mu) = f(x)
645///          + mu * Σ_i max(0, g_i(x))²     [inequality: g_i(x) <= 0]
646///          + mu * Σ_j h_j(x)²              [equality:   h_j(x)  = 0]
647/// ```
648///
649/// For *interior* (log-barrier) penalties:
650/// ```text
651/// B(x, mu) = f(x) - mu * Σ_i ln(-g_i(x))     [only for strictly feasible x]
652/// ```
653///
654/// # Arguments
655/// * `x`              - Current iterate (decision vector).
656/// * `obj`            - Objective function `f: &[f64] -> f64`.
657/// * `ineq_cons`      - Inequality constraints `g_i(x) <= 0`.
658/// * `eq_cons`        - Equality constraints `h_j(x) = 0`.
659/// * `penalty_coeff`  - Scalar penalty / barrier weight `μ`.
660/// * `kind`           - Which form of penalty to apply.
661///
662/// # Returns
663/// Penalised scalar value.  Returns `f64::INFINITY` if the log-barrier is
664/// requested but `x` is infeasible (any `g_i(x) >= 0`).
665pub fn penalty_function<F, G, H>(
666    x: &[f64],
667    obj: F,
668    ineq_cons: &[G],
669    eq_cons: &[H],
670    penalty_coeff: f64,
671    kind: PenaltyKind,
672) -> f64
673where
674    F: Fn(&[f64]) -> f64,
675    G: Fn(&[f64]) -> f64,
676    H: Fn(&[f64]) -> f64,
677{
678    let f_val = obj(x);
679    match kind {
680        PenaltyKind::External => {
681            let ineq_pen: f64 = ineq_cons
682                .iter()
683                .map(|g| penalty_coeff * g(x).max(0.0).powi(2))
684                .sum();
685            let eq_pen: f64 = eq_cons.iter().map(|h| penalty_coeff * h(x).powi(2)).sum();
686            f_val + ineq_pen + eq_pen
687        }
688        PenaltyKind::Interior => {
689            let barrier: f64 = ineq_cons
690                .iter()
691                .map(|g| {
692                    let gv = g(x);
693                    if gv < -1e-15 {
694                        -penalty_coeff * gv.abs().ln()
695                    } else {
696                        f64::INFINITY
697                    }
698                })
699                .sum();
700            f_val + barrier
701        }
702        PenaltyKind::ExactL1 => {
703            let ineq_pen: f64 = ineq_cons
704                .iter()
705                .map(|g| penalty_coeff * g(x).max(0.0))
706                .sum();
707            let eq_pen: f64 = eq_cons.iter().map(|h| penalty_coeff * h(x).abs()).sum();
708            f_val + ineq_pen + eq_pen
709        }
710    }
711}
712
713// ─────────────────────────────────────────────────────────────────────────────
714// AdaptivePenalty
715// ─────────────────────────────────────────────────────────────────────────────
716
/// Adaptive penalty controller.
///
/// Tracks the constraint violation of the previous outer iteration and
/// adjusts the penalty coefficient on every [`AdaptivePenalty::update`]:
/// - If the violation has not improved by the `improvement_threshold`
///   fraction for `stall_patience` consecutive calls → multiply the penalty
///   by `increase_factor` (capped at `max_penalty`).
/// - If the violation improved by more than 50% and `allow_decrease` is set →
///   multiply the penalty by `decrease_factor` (floored at `min_penalty`) to
///   limit ill-conditioning.
///
/// This implements a simplified version of the adaptive penalty from:
/// Farmani & Wright (2003), "Self-adaptive fitness formulation for constrained
/// optimization", IEEE TEC 7(5):445-455.
#[derive(Debug, Clone)]
pub struct AdaptivePenalty {
    /// Current penalty coefficient.
    pub penalty_coeff: f64,
    /// Minimum allowed penalty (floor applied when the penalty is decreased).
    pub min_penalty: f64,
    /// Maximum allowed penalty (avoids numerical ill-conditioning).
    pub max_penalty: f64,
    /// Multiplicative increase factor applied when violations stall.
    pub increase_factor: f64,
    /// Multiplicative decrease factor applied when violations decrease rapidly.
    pub decrease_factor: f64,
    /// Relative improvement threshold below which penalty is increased.
    /// E.g., 0.1 means "less than 10% improvement triggers increase".
    pub improvement_threshold: f64,
    /// Whether to allow penalty *decreases* (can reduce ill-conditioning).
    pub allow_decrease: bool,
    /// Last usable constraint violation; `f64::INFINITY` means "none yet".
    prev_violation: f64,
    /// Number of consecutive non-improving iterations.
    stall_count: usize,
    /// Stall patience: how many non-improving iters before increasing penalty.
    pub stall_patience: usize,
}

impl Default for AdaptivePenalty {
    fn default() -> Self {
        AdaptivePenalty {
            penalty_coeff: 1.0,
            min_penalty: 1e-3,
            max_penalty: 1e10,
            increase_factor: 10.0,
            decrease_factor: 0.5,
            improvement_threshold: 0.25,
            allow_decrease: false,
            prev_violation: f64::INFINITY,
            stall_count: 0,
            stall_patience: 1,
        }
    }
}

impl AdaptivePenalty {
    /// Create a new adaptive penalty controller with the given initial
    /// coefficient; every other setting takes its default value.
    pub fn new(initial_penalty: f64) -> Self {
        AdaptivePenalty {
            penalty_coeff: initial_penalty,
            ..Default::default()
        }
    }

    /// Create with full configuration.
    // One argument per tunable keeps this constructor aligned with the struct.
    #[allow(clippy::too_many_arguments)]
    pub fn with_config(
        initial_penalty: f64,
        min_penalty: f64,
        max_penalty: f64,
        increase_factor: f64,
        decrease_factor: f64,
        improvement_threshold: f64,
        allow_decrease: bool,
        stall_patience: usize,
    ) -> Self {
        AdaptivePenalty {
            penalty_coeff: initial_penalty,
            min_penalty,
            max_penalty,
            increase_factor,
            decrease_factor,
            improvement_threshold,
            allow_decrease,
            prev_violation: f64::INFINITY,
            stall_count: 0,
            stall_patience,
        }
    }

    /// Update the penalty coefficient based on `current_violation`.
    ///
    /// Should be called once per outer iteration after the sub-problem has been
    /// solved.  Returns the updated penalty coefficient.
    ///
    /// The first call after construction or [`AdaptivePenalty::reset`] only
    /// records the violation. Afterwards the relative improvement
    /// `(previous - current) / previous` decides between "stalled" and
    /// "improving". Two special cases are treated as *no progress*:
    /// - the previous violation was (numerically) zero and the current one is
    ///   not — the iterate became infeasible again;
    /// - `current_violation` is NaN — the measurement is unusable, and it is
    ///   not recorded as the new reference value.
    pub fn update(&mut self, current_violation: f64) -> f64 {
        /// Violations at or below this are considered zero.
        const NEAR_ZERO: f64 = 1e-15;

        let previous = self.prev_violation;
        if !current_violation.is_nan() {
            self.prev_violation = current_violation;
        }

        if previous.is_infinite() {
            // First call — nothing to compare against yet.
            return self.penalty_coeff;
        }

        let relative_improvement = if current_violation.is_nan() {
            f64::NEG_INFINITY
        } else if previous > NEAR_ZERO {
            (previous - current_violation) / previous
        } else if current_violation <= NEAR_ZERO {
            // Was feasible and still is — treat as converged.
            1.0
        } else {
            // Was feasible, now violated.
            f64::NEG_INFINITY
        };

        if relative_improvement < self.improvement_threshold {
            // Not improving fast enough
            self.stall_count += 1;
            if self.stall_count >= self.stall_patience {
                self.penalty_coeff =
                    (self.penalty_coeff * self.increase_factor).min(self.max_penalty);
                self.stall_count = 0;
            }
        } else {
            // Good improvement
            self.stall_count = 0;
            if self.allow_decrease && relative_improvement > 0.5 {
                self.penalty_coeff =
                    (self.penalty_coeff * self.decrease_factor).max(self.min_penalty);
            }
        }

        self.penalty_coeff
    }

    /// Reset internal state (violation history and stall counter).
    ///
    /// The current penalty coefficient is left untouched.
    pub fn reset(&mut self) {
        self.prev_violation = f64::INFINITY;
        self.stall_count = 0;
    }
}
852
853// ─────────────────────────────────────────────────────────────────────────────
854// AugmentedLagrangianSolver
855// ─────────────────────────────────────────────────────────────────────────────
856
/// Options for [`AugmentedLagrangianSolver`].
#[derive(Debug, Clone)]
pub struct AugLagOptions {
    /// Initial penalty parameter ρ.
    pub rho_init: f64,
    /// Maximum penalty parameter.
    pub rho_max: f64,
    /// Factor by which ρ is multiplied when the constraint violation does not
    /// decrease sufficiently.
    pub rho_factor: f64,
    /// Maximum number of outer (multiplier update) iterations.
    pub max_outer_iter: usize,
    /// Constraint violation tolerance (outer loop convergence criterion).
    pub constraint_tol: f64,
    /// Optimality tolerance passed to the inner unconstrained sub-solver.
    pub optimality_tol: f64,
    /// Finite-difference step for inner gradient computation.
    pub eps: f64,
    /// Required relative reduction in violation before multipliers are updated
    /// (otherwise only the penalty grows).
    pub violation_reduction_threshold: f64,
}
879
880impl Default for AugLagOptions {
881    fn default() -> Self {
882        AugLagOptions {
883            rho_init: 1.0,
884            rho_max: 1e10,
885            rho_factor: 10.0,
886            max_outer_iter: 100,
887            constraint_tol: 1e-6,
888            optimality_tol: 1e-8,
889            eps: 1e-7,
890            violation_reduction_threshold: 0.25,
891        }
892    }
893}
894
/// Result from [`AugmentedLagrangianSolver`].
#[derive(Debug, Clone)]
pub struct AugLagResult {
    /// Optimal decision vector.
    pub x: Array1<f64>,
    /// Objective value at x.
    pub fun: f64,
    /// Number of outer iterations performed.
    pub nit: usize,
    /// Total function evaluations (inner + outer).
    pub nfev: usize,
    /// Success flag.
    pub success: bool,
    /// Status message.
    pub message: String,
    /// Final Lagrange multipliers for equality constraints.
    pub lambda_eq: Vec<f64>,
    /// Final Lagrange multipliers for inequality constraints.
    pub lambda_ineq: Vec<f64>,
    /// Final penalty parameter.
    pub rho: f64,
    /// Final constraint violation.
    pub constraint_violation: f64,
}
919
920impl From<AugLagResult> for OptimizeResults<f64> {
921    fn from(r: AugLagResult) -> Self {
922        OptimizeResults {
923            x: r.x,
924            fun: r.fun,
925            jac: None,
926            hess: None,
927            constr: None,
928            nit: r.nit,
929            nfev: r.nfev,
930            njev: 0,
931            nhev: 0,
932            maxcv: 0,
933            message: r.message,
934            success: r.success,
935            status: if r.success { 0 } else { 1 },
936        }
937    }
938}
939
940/// Augmented Lagrangian method solver (Method of Multipliers).
941///
942/// Solves problems of the form:
943/// ```text
944/// min   f(x)
945/// s.t.  h_j(x) = 0      (equality)
946///       g_i(x) <= 0     (inequality)
947/// ```
948///
949/// The augmented Lagrangian for equality constraints is:
950/// ```text
951/// L_A(x, λ, ρ) = f(x) + Σ_j λ_j h_j(x) + (ρ/2) Σ_j h_j(x)²
952/// ```
953/// Inequality constraints are handled via the shifted/signed penalty
954/// (Rockafellar's form):
955/// ```text
956/// L_A += Σ_i [ λ_i g_i(x) + (ρ/2) g_i(x)² ]   when  g_i(x) + λ_i/ρ > 0
957///      + 0                                        otherwise
958/// ```
959///
960/// Multipliers are updated each outer iteration:
961/// ```text
962/// λ_j ← λ_j + ρ h_j(x*)     (equality)
963/// λ_i ← max(0, λ_i + ρ g_i(x*))   (inequality)
964/// ```
965///
966/// # References
967/// - Nocedal & Wright (2006), §17.4, "Augmented Lagrangian Methods".
968/// - Bertsekas (1982), "Constrained Optimization and Lagrange Multiplier Methods".
969#[derive(Debug, Clone)]
970pub struct AugmentedLagrangianSolver {
971    /// Solver configuration.
972    pub options: AugLagOptions,
973}
974
975impl AugmentedLagrangianSolver {
976    /// Create with default options.
977    pub fn new() -> Self {
978        AugmentedLagrangianSolver {
979            options: AugLagOptions::default(),
980        }
981    }
982
983    /// Create with custom options.
984    pub fn with_options(options: AugLagOptions) -> Self {
985        AugmentedLagrangianSolver { options }
986    }
987
988    /// Compute total constraint violation (L2 norm of violations).
989    fn compute_violation<E, G>(&self, x: &[f64], eq_cons: &[E], ineq_cons: &[G]) -> f64
990    where
991        E: Fn(&[f64]) -> f64,
992        G: Fn(&[f64]) -> f64,
993    {
994        let eq_sq: f64 = eq_cons.iter().map(|h| h(x).powi(2)).sum();
995        let ineq_sq: f64 = ineq_cons.iter().map(|g| g(x).max(0.0).powi(2)).sum();
996        (eq_sq + ineq_sq).sqrt()
997    }
998
999    /// Solve the augmented Lagrangian problem.
1000    ///
1001    /// # Arguments
1002    /// * `f`        - Objective function.
1003    /// * `eq_cons`  - Equality constraints `h_j(x) = 0`.
1004    /// * `ineq_cons`- Inequality constraints `g_i(x) <= 0`.
1005    /// * `x0`       - Initial iterate.
1006    pub fn solve<F, E, G>(
1007        &self,
1008        f: F,
1009        eq_cons: &[E],
1010        ineq_cons: &[G],
1011        x0: &[f64],
1012    ) -> OptimizeResult<AugLagResult>
1013    where
1014        F: Fn(&[f64]) -> f64,
1015        E: Fn(&[f64]) -> f64,
1016        G: Fn(&[f64]) -> f64,
1017    {
1018        let n_eq = eq_cons.len();
1019        let n_ineq = ineq_cons.len();
1020
1021        // Initialise multipliers at zero
1022        let mut lambda_eq = vec![0.0_f64; n_eq];
1023        let mut lambda_ineq = vec![0.0_f64; n_ineq];
1024        let mut rho = self.options.rho_init;
1025
1026        let mut x = x0.to_vec();
1027        let mut nfev_total = 0usize;
1028        let mut nit = 0usize;
1029        let mut prev_violation = f64::INFINITY;
1030
1031        for _outer in 0..self.options.max_outer_iter {
1032            nit += 1;
1033
1034            // Snapshot multipliers and penalty for the closure
1035            let lam_eq_snap = lambda_eq.clone();
1036            let lam_ineq_snap = lambda_ineq.clone();
1037            let rho_snap = rho;
1038
1039            // Augmented Lagrangian function
1040            let aug_lag = |xv: &[f64]| -> f64 {
1041                let mut val = f(xv);
1042
1043                // Equality terms: λ_j h_j + (ρ/2) h_j²
1044                for (j, h) in eq_cons.iter().enumerate() {
1045                    let hv = h(xv);
1046                    val += lam_eq_snap[j] * hv + 0.5 * rho_snap * hv * hv;
1047                }
1048
1049                // Inequality terms (Rockafellar shifted form)
1050                // If  g_i + λ_i/ρ > 0  (active or violated):
1051                //   contribution = λ_i g_i + (ρ/2) g_i²
1052                //               = (1/(2ρ)) [ (λ_i + ρ g_i)² - λ_i² ]
1053                // If  g_i + λ_i/ρ <= 0 (inactive, well inside feasible region):
1054                //   contribution = -λ_i²/(2ρ)  [constant w.r.t. x, omit]
1055                for (i, g) in ineq_cons.iter().enumerate() {
1056                    let gv = g(xv);
1057                    let shifted = gv + lam_ineq_snap[i] / rho_snap;
1058                    if shifted > 0.0 {
1059                        val += lam_ineq_snap[i] * gv + 0.5 * rho_snap * gv * gv;
1060                    }
1061                    // else: constraint is inactive, no contribution
1062                }
1063
1064                val
1065            };
1066
1067            // Minimise the augmented Lagrangian (unconstrained sub-problem)
1068            let new_x = minimize_penalty_subproblem(
1069                aug_lag,
1070                &x,
1071                2000,
1072                self.options.optimality_tol,
1073                self.options.eps,
1074                &mut nfev_total,
1075            );
1076
1077            x = new_x;
1078
1079            // Compute current violation
1080            let cv = self.compute_violation(&x, eq_cons, ineq_cons);
1081
1082            // Check for convergence
1083            if cv <= self.options.constraint_tol {
1084                let fun = f(&x);
1085                return Ok(AugLagResult {
1086                    x: Array1::from_vec(x),
1087                    fun,
1088                    nit,
1089                    nfev: nfev_total,
1090                    success: true,
1091                    message: "Converged: constraint violation below tolerance".to_string(),
1092                    lambda_eq,
1093                    lambda_ineq,
1094                    rho,
1095                    constraint_violation: cv,
1096                });
1097            }
1098
1099            // Determine whether to update multipliers or just increase penalty
1100            let violation_improvement = if prev_violation.is_finite() && prev_violation > 1e-15 {
1101                (prev_violation - cv) / prev_violation
1102            } else {
1103                0.0
1104            };
1105
1106            if violation_improvement >= self.options.violation_reduction_threshold {
1107                // Good progress: update multipliers (dual ascent step)
1108                for (j, h) in eq_cons.iter().enumerate() {
1109                    lambda_eq[j] += rho * h(&x);
1110                }
1111                for (i, g) in ineq_cons.iter().enumerate() {
1112                    lambda_ineq[i] = (lambda_ineq[i] + rho * g(&x)).max(0.0);
1113                }
1114            } else {
1115                // Poor progress: increase penalty to force constraint satisfaction
1116                rho = (rho * self.options.rho_factor).min(self.options.rho_max);
1117            }
1118
1119            prev_violation = cv;
1120        }
1121
1122        // Max iterations reached
1123        let cv = self.compute_violation(&x, eq_cons, ineq_cons);
1124        let fun = f(&x);
1125        let success = cv <= self.options.constraint_tol;
1126        Ok(AugLagResult {
1127            x: Array1::from_vec(x),
1128            fun,
1129            nit,
1130            nfev: nfev_total,
1131            success,
1132            message: if success {
1133                "Converged".to_string()
1134            } else {
1135                format!(
1136                    "Maximum outer iterations ({}) reached; cv={:.2e}",
1137                    self.options.max_outer_iter, cv
1138                )
1139            },
1140            lambda_eq,
1141            lambda_ineq,
1142            rho,
1143            constraint_violation: cv,
1144        })
1145    }
1146}
1147
1148impl Default for AugmentedLagrangianSolver {
1149    fn default() -> Self {
1150        AugmentedLagrangianSolver::new()
1151    }
1152}
1153
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Function-pointer type that gives empty constraint lists a concrete
    /// element type for the slice-based APIs.
    type SliceFn = fn(&[f64]) -> f64;

    /// Empty constraint list for the slice-based APIs.
    const NO_CONSTRAINTS: &[SliceFn] = &[];

    // ── PenaltyMethod ────────────────────────────────────────────────────────

    /// Quadratic penalty on `min x² + y²  s.t.  x + y = 1`; optimum at
    /// `(0.5, 0.5)` with objective `0.5`.
    #[test]
    fn test_penalty_equality_constraint() {
        let objective = |p: &[f64]| p[0].powi(2) + p[1].powi(2);
        let on_line = |p: &[f64]| p[0] + p[1] - 1.0;

        let solver = PenaltyMethod::with_options(PenaltyOptions {
            kind: PenaltyKind::External,
            mu_init: 1.0,
            mu_factor: 10.0,
            mu_max: 1e8,
            max_outer_iter: 30,
            constraint_tol: 1e-4,
            ..Default::default()
        });
        let outcome = solver
            .solve_slice(objective, &[on_line], NO_CONSTRAINTS, &[0.0, 0.0])
            .expect("solve failed");

        assert_abs_diff_eq!(outcome.x[0], 0.5, epsilon = 1e-2);
        assert_abs_diff_eq!(outcome.x[1], 0.5, epsilon = 1e-2);
        assert_abs_diff_eq!(outcome.fun, 0.5, epsilon = 1e-2);
    }

    /// Quadratic penalty on `min x² + y²  s.t.  x + y >= 1`, written as
    /// `1 - x - y <= 0`; the objective at the optimum `(0.5, 0.5)` is `0.5`.
    #[test]
    fn test_penalty_inequality_constraint() {
        let objective = |p: &[f64]| p[0].powi(2) + p[1].powi(2);
        let half_plane = |p: &[f64]| 1.0 - p[0] - p[1];

        let solver = PenaltyMethod::with_options(PenaltyOptions {
            kind: PenaltyKind::External,
            mu_init: 1.0,
            mu_factor: 10.0,
            mu_max: 1e8,
            max_outer_iter: 40,
            constraint_tol: 1e-3,
            ..Default::default()
        });
        let outcome = solver
            .solve_slice(objective, NO_CONSTRAINTS, &[half_plane], &[2.0, 2.0])
            .expect("solve failed");

        assert_abs_diff_eq!(outcome.fun, 0.5, epsilon = 1e-2);
    }

    /// Without constraints the method must find the plain minimiser `(3, 4)`
    /// of `(x-3)² + (y-4)²`.
    #[test]
    fn test_penalty_no_constraints() {
        let objective = |p: &[f64]| (p[0] - 3.0).powi(2) + (p[1] - 4.0).powi(2);

        let outcome = PenaltyMethod::new()
            .solve_slice(objective, NO_CONSTRAINTS, NO_CONSTRAINTS, &[0.0, 0.0])
            .expect("solve failed");

        assert_abs_diff_eq!(outcome.fun, 0.0, epsilon = 1e-3);
        assert_abs_diff_eq!(outcome.x[0], 3.0, epsilon = 1e-2);
        assert_abs_diff_eq!(outcome.x[1], 4.0, epsilon = 1e-2);
    }

    /// Exact L1 penalty where the unconstrained minimiser `(1, 2)` already
    /// satisfies `x + y = 3`, so the optimal objective is `0`.
    #[test]
    fn test_penalty_l1_equality() {
        let objective = |p: &[f64]| (p[0] - 1.0).powi(2) + (p[1] - 2.0).powi(2);
        let sum_is_three = |p: &[f64]| p[0] + p[1] - 3.0;

        let solver = PenaltyMethod::with_options(PenaltyOptions {
            kind: PenaltyKind::ExactL1,
            mu_init: 10.0,
            max_outer_iter: 50,
            constraint_tol: 1e-3,
            ..Default::default()
        });
        let outcome = solver
            .solve_slice(objective, &[sum_is_three], NO_CONSTRAINTS, &[0.0, 0.0])
            .expect("solve failed");

        assert_abs_diff_eq!(outcome.fun, 0.0, epsilon = 1e-2);
    }

    /// Log-barrier on `min (x-0.5)²  s.t.  x - 0.999 <= 0`, started strictly
    /// inside the feasible region at `x0 = 0.3`.
    #[test]
    fn test_penalty_interior_barrier() {
        let objective = |p: &[f64]| (p[0] - 0.5).powi(2);
        let upper_bound = |p: &[f64]| p[0] - 0.999;

        let solver = PenaltyMethod::with_options(PenaltyOptions {
            kind: PenaltyKind::Interior,
            mu_init: 0.1,
            mu_factor: 5.0,
            max_outer_iter: 20,
            ..Default::default()
        });
        let outcome = solver
            .solve_slice(objective, NO_CONSTRAINTS, &[upper_bound], &[0.3])
            .expect("solve failed");

        assert_abs_diff_eq!(outcome.x[0], 0.5, epsilon = 0.1);
    }

    /// One equality and one inequality together:
    /// `min x² + y² + z²  s.t.  x + y + z = 3,  z <= 1`.
    /// The optimum is `x = y = z = 1` with objective `3`; the check is loose.
    #[test]
    fn test_penalty_mixed_constraints() {
        let objective = |p: &[f64]| p[0].powi(2) + p[1].powi(2) + p[2].powi(2);
        let sum_is_three = |p: &[f64]| p[0] + p[1] + p[2] - 3.0;
        let z_at_most_one = |p: &[f64]| p[2] - 1.0;

        let solver = PenaltyMethod::with_options(PenaltyOptions {
            kind: PenaltyKind::External,
            mu_init: 1.0,
            mu_factor: 5.0,
            mu_max: 1e7,
            max_outer_iter: 50,
            constraint_tol: 1e-3,
            ..Default::default()
        });
        let outcome = solver
            .solve_slice(objective, &[sum_is_three], &[z_at_most_one], &[0.0, 0.0, 0.0])
            .expect("solve failed");

        assert!(outcome.fun <= 4.0, "fun={}", outcome.fun);
    }

    /// Same equality-constrained problem as above, driven through the
    /// `ArrayView1`-based `solve` entry point.
    #[test]
    fn test_penalty_arrayview1_interface() {
        use scirs2_core::ndarray::{array, ArrayView1};

        let objective = |p: &ArrayView1<f64>| p[0].powi(2) + p[1].powi(2);
        let on_line = |p: &ArrayView1<f64>| p[0] + p[1] - 1.0;

        let solver = PenaltyMethod::with_options(PenaltyOptions {
            kind: PenaltyKind::External,
            max_outer_iter: 20,
            constraint_tol: 1e-3,
            ..Default::default()
        });
        let start = array![0.0, 0.0];
        let outcome = solver
            .solve(
                objective,
                &[on_line],
                &[] as &[fn(&ArrayView1<f64>) -> f64],
                &start,
            )
            .expect("solve failed");

        // Optimum at (0.5, 0.5).
        assert_abs_diff_eq!(outcome.fun, 0.5, epsilon = 5e-2);
    }

    // ── penalty_function free function ───────────────────────────────────────

    /// At the feasible point `(0, 0)` the inequality `x + y - 1 <= 0` holds
    /// strictly, so the external penalty adds nothing to `f(0, 0) = 0`.
    #[test]
    fn test_penalty_function_external_feasible() {
        let objective = |p: &[f64]| p[0].powi(2) + p[1].powi(2);
        let below_line = |p: &[f64]| p[0] + p[1] - 1.0;

        let value = penalty_function(
            &[0.0, 0.0],
            objective,
            &[below_line],
            NO_CONSTRAINTS,
            10.0,
            PenaltyKind::External,
        );

        assert_abs_diff_eq!(value, 0.0, epsilon = 1e-12);
    }

    /// At `(1, 1)` the inequality is violated by `1`, so the value is
    /// `f + mu * 1² = 2 + mu`.
    #[test]
    fn test_penalty_function_external_violated() {
        let objective = |p: &[f64]| p[0].powi(2) + p[1].powi(2);
        let below_line = |p: &[f64]| p[0] + p[1] - 1.0;
        let mu = 5.0;

        let value = penalty_function(
            &[1.0, 1.0],
            objective,
            &[below_line],
            NO_CONSTRAINTS,
            mu,
            PenaltyKind::External,
        );

        assert_abs_diff_eq!(value, 2.0 + mu * 1.0_f64.powi(2), epsilon = 1e-12);
    }

    /// With a zero objective and `h(x) = x - 1` evaluated at `x = 2`, the
    /// value is the pure equality penalty `mu * h² = 3`.
    #[test]
    fn test_penalty_function_equality() {
        let objective = |_p: &[f64]| 0.0;
        let x_is_one = |p: &[f64]| p[0] - 1.0;
        let mu = 3.0;

        let value = penalty_function(
            &[2.0],
            objective,
            NO_CONSTRAINTS,
            &[x_is_one],
            mu,
            PenaltyKind::External,
        );

        assert_abs_diff_eq!(value, 3.0, epsilon = 1e-12);
    }

    // ── AdaptivePenalty ──────────────────────────────────────────────────────

    /// A 5 % improvement is below the 10 % threshold, so with a patience of
    /// one the second update must raise the penalty.
    #[test]
    fn test_adaptive_penalty_increases_on_stall() {
        let mut schedule = AdaptivePenalty::new(1.0);
        schedule.increase_factor = 5.0;
        schedule.improvement_threshold = 0.1;
        schedule.stall_patience = 1;

        // The first call only records the violation.
        let initial = schedule.update(1.0);
        assert_abs_diff_eq!(initial, 1.0, epsilon = 1e-12);

        let p1 = schedule.update(0.95);
        assert!(p1 > 1.0, "Penalty should have increased; got {p1}");
    }

    /// A 50 % improvement is well above the 10 % threshold, so the penalty
    /// stays where it is.
    #[test]
    fn test_adaptive_penalty_no_increase_on_good_progress() {
        let mut schedule = AdaptivePenalty::new(1.0);
        schedule.improvement_threshold = 0.1;
        schedule.stall_patience = 1;

        schedule.update(1.0); // seed the previous violation
        let p1 = schedule.update(0.5);

        assert_abs_diff_eq!(p1, 1.0, epsilon = 1e-12);
    }

    /// With no improvement the penalty is multiplied by 1000 from 1e9, which
    /// must be clamped to the configured maximum of 1e10.
    #[test]
    fn test_adaptive_penalty_capped_at_max() {
        let mut schedule =
            AdaptivePenalty::with_config(1e9, 1e-3, 1e10, 1000.0, 0.5, 0.1, false, 1);

        schedule.update(1.0);
        let p = schedule.update(1.0);

        assert!(p <= 1e10, "Penalty exceeded max; got {p}");
    }

    // ── AugmentedLagrangianSolver ────────────────────────────────────────────

    /// Augmented Lagrangian on `min x² + y²  s.t.  x + y = 1`; optimum at
    /// `(0.5, 0.5)` with objective `0.5`.
    #[test]
    fn test_aug_lag_equality_constraint() {
        let objective = |p: &[f64]| p[0].powi(2) + p[1].powi(2);
        let on_line = |p: &[f64]| p[0] + p[1] - 1.0;

        let solver = AugmentedLagrangianSolver::with_options(AugLagOptions {
            rho_init: 1.0,
            rho_factor: 5.0,
            max_outer_iter: 50,
            constraint_tol: 1e-4,
            ..Default::default()
        });
        let outcome = solver
            .solve(objective, &[on_line], NO_CONSTRAINTS, &[0.0, 0.0])
            .expect("AugLag solve failed");

        assert!(
            outcome.success || outcome.constraint_violation < 1e-2,
            "cv={}",
            outcome.constraint_violation
        );
        assert_abs_diff_eq!(outcome.fun, 0.5, epsilon = 0.05);
    }

    /// Augmented Lagrangian on `min x²  s.t.  x >= 1`, written as
    /// `1 - x <= 0`; the solution is `x = 1`.
    #[test]
    fn test_aug_lag_inequality_constraint() {
        let objective = |p: &[f64]| p[0].powi(2);
        let at_least_one = |p: &[f64]| 1.0 - p[0];

        let solver = AugmentedLagrangianSolver::with_options(AugLagOptions {
            rho_init: 1.0,
            rho_factor: 10.0,
            max_outer_iter: 60,
            constraint_tol: 1e-3,
            ..Default::default()
        });
        let outcome = solver
            .solve(objective, NO_CONSTRAINTS, &[at_least_one], &[0.5])
            .expect("AugLag solve failed");

        assert!(
            outcome.x[0] >= 0.9 && outcome.x[0] <= 1.2,
            "Expected x~1.0, got x={}",
            outcome.x[0]
        );
    }

    /// With no constraints the solver must return the plain minimiser
    /// `(2, 3)` of `(x-2)² + (y-3)²`.
    #[test]
    fn test_aug_lag_no_constraints() {
        let objective = |p: &[f64]| (p[0] - 2.0).powi(2) + (p[1] - 3.0).powi(2);

        let outcome = AugmentedLagrangianSolver::new()
            .solve(objective, NO_CONSTRAINTS, NO_CONSTRAINTS, &[0.0, 0.0])
            .expect("AugLag solve failed");

        assert_abs_diff_eq!(outcome.x[0], 2.0, epsilon = 1e-2);
        assert_abs_diff_eq!(outcome.x[1], 3.0, epsilon = 1e-2);
        assert_abs_diff_eq!(outcome.fun, 0.0, epsilon = 1e-3);
    }

    /// Smoke test: the enum variants can be constructed and compared.
    #[test]
    fn test_penalty_method_kind_variants() {
        assert_eq!(PenaltyMethodKind::Static, PenaltyMethodKind::Static);
        assert_ne!(PenaltyMethodKind::Dynamic, PenaltyMethodKind::Adaptive);
        assert_ne!(
            PenaltyMethodKind::AugmentedLagrangian,
            PenaltyMethodKind::Static
        );
    }
}