Skip to main content

numra_optim/
augmented_lagrangian.rs

1//! Augmented Lagrangian method for constrained optimization.
2//!
3//! Converts a constrained problem into a sequence of unconstrained (or
4//! bound-constrained) subproblems by penalizing constraint violations
5//! and maintaining Lagrange multiplier estimates.
6//!
7//! Author: Moussa Leblouba
8//! Date: 8 February 2026
9//! Modified: 2 May 2026
10
11use numra_core::Scalar;
12
13use crate::error::OptimError;
14use crate::lbfgs::{lbfgs_minimize, LbfgsOptions};
15use crate::lbfgsb::{lbfgsb_minimize, LbfgsBOptions};
16use crate::problem::{
17    finite_diff_gradient, Constraint, ConstraintKind, ObjectiveKind, OptimProblem,
18};
19use crate::types::{IterationRecord, OptimOptions, OptimResult, OptimStatus};
20
/// Options for the Augmented Lagrangian outer loop.
#[derive(Clone, Debug)]
pub struct AugLagOptions<S: Scalar> {
    /// Options forwarded as the `base` options of the inner L-BFGS / L-BFGS-B
    /// solver that minimizes each augmented Lagrangian subproblem.
    pub inner_opts: OptimOptions<S>,
    /// Maximum number of outer (multiplier / penalty update) iterations.
    pub max_outer_iter: usize,
    /// Penalty parameter `sigma` used for the first subproblem.
    pub sigma_init: S,
    /// Factor by which `sigma` is multiplied after every outer iteration that
    /// did not meet the constraint tolerance.
    pub sigma_factor: S,
    /// Upper cap applied to `sigma` after each increase.
    pub sigma_max: S,
    /// Convergence tolerance: the outer loop stops once the largest constraint
    /// violation is strictly below this value.
    pub ctol: S,
}
31
32impl<S: Scalar> Default for AugLagOptions<S> {
33    fn default() -> Self {
34        Self {
35            inner_opts: OptimOptions::default().max_iter(500),
36            max_outer_iter: 50,
37            sigma_init: S::ONE,
38            sigma_factor: S::from_f64(10.0),
39            sigma_max: S::from_f64(1e12),
40            ctol: S::from_f64(1e-6),
41        }
42    }
43}
44
45/// Solve a constrained optimization problem using the Augmented Lagrangian method.
46///
47/// Expects `problem` to have a `Minimize` objective and constraints.
48pub fn augmented_lagrangian_minimize<S: Scalar>(
49    problem: OptimProblem<S>,
50    opts: &AugLagOptions<S>,
51) -> Result<OptimResult<S>, OptimError> {
52    let start = std::time::Instant::now();
53    // Destructure to avoid partial move issues
54    let OptimProblem {
55        n,
56        x0,
57        bounds,
58        objective,
59        constraints,
60        ..
61    } = problem;
62
63    let x0 = x0.ok_or(OptimError::NoInitialPoint)?;
64
65    // Extract objective
66    let (obj_func, obj_grad) = match objective {
67        Some(ObjectiveKind::Minimize { func, grad }) => (func, grad),
68        Some(ObjectiveKind::LeastSquares { .. }) => {
69            return Err(OptimError::Other(
70                "augmented Lagrangian requires scalar objective, not least squares".into(),
71            ));
72        }
73        Some(ObjectiveKind::Linear { .. }) | Some(ObjectiveKind::Quadratic { .. }) => {
74            return Err(OptimError::Other(
75                "augmented Lagrangian requires scalar objective; use Simplex for LP or ActiveSetQP for QP".into(),
76            ));
77        }
78        Some(ObjectiveKind::MultiObjective { .. }) => {
79            return Err(OptimError::Other(
80                "augmented Lagrangian requires scalar objective; use NSGA-II for multi-objective"
81                    .into(),
82            ));
83        }
84        None => return Err(OptimError::NoObjective),
85    };
86
87    let has_bounds = bounds.iter().any(|b| b.is_some());
88
89    // Separate equality and inequality constraints
90    let eq_constraints: Vec<&Constraint<S>> = constraints
91        .iter()
92        .filter(|c| c.kind == ConstraintKind::Equality)
93        .collect();
94    let ineq_constraints: Vec<&Constraint<S>> = constraints
95        .iter()
96        .filter(|c| c.kind == ConstraintKind::Inequality)
97        .collect();
98
99    let n_eq = eq_constraints.len();
100    let n_ineq = ineq_constraints.len();
101
102    // Multiplier estimates
103    let mut lambda_eq = vec![S::ZERO; n_eq];
104    let mut mu_ineq = vec![S::ZERO; n_ineq];
105    let mut sigma = opts.sigma_init;
106    let mut x = x0;
107
108    let mut total_feval = 0_usize;
109    let mut total_geval = 0_usize;
110    let mut history = Vec::new();
111
112    let two = S::TWO;
113
114    for outer in 0..opts.max_outer_iter {
115        // Build augmented Lagrangian subproblem
116        let lam_eq = lambda_eq.clone();
117        let mu_in = mu_ineq.clone();
118        let sig = sigma;
119
120        let aug_f = |xv: &[S]| -> S {
121            let mut val = (obj_func)(xv);
122
123            // Equality: lambda_j * h_j(x) + (sigma/2) * h_j(x)^2
124            for (j, c) in eq_constraints.iter().enumerate() {
125                let h = (c.func)(xv);
126                val = val + lam_eq[j] * h + (sig / two) * h * h;
127            }
128
129            // Inequality: (sigma/2) * max(0, g_i(x) + mu_i/sigma)^2 - mu_i^2/(2*sigma)
130            for (i, c) in ineq_constraints.iter().enumerate() {
131                let g = (c.func)(xv);
132                let shifted = g + mu_in[i] / sig;
133                if shifted > S::ZERO {
134                    val = val + (sig / two) * shifted * shifted - mu_in[i] * mu_in[i] / (two * sig);
135                }
136            }
137
138            val
139        };
140
141        let aug_grad = |xv: &[S], gout: &mut [S]| {
142            // Base gradient
143            if let Some(ref og) = obj_grad {
144                og(xv, gout);
145            } else {
146                finite_diff_gradient(&*obj_func, xv, gout);
147            }
148
149            let mut cgrad = vec![S::ZERO; n];
150
151            // Equality constraint gradients
152            for (j, c) in eq_constraints.iter().enumerate() {
153                let h = (c.func)(xv);
154                let mult = lam_eq[j] + sig * h;
155                if let Some(ref cg) = c.grad {
156                    cg(xv, &mut cgrad);
157                } else {
158                    finite_diff_gradient(&*c.func, xv, &mut cgrad);
159                }
160                for k in 0..n {
161                    gout[k] += mult * cgrad[k];
162                }
163            }
164
165            // Inequality constraint gradients
166            for (i, c) in ineq_constraints.iter().enumerate() {
167                let g_val = (c.func)(xv);
168                let shifted = g_val + mu_in[i] / sig;
169                if shifted > S::ZERO {
170                    let mult = sig * shifted;
171                    if let Some(ref cg) = c.grad {
172                        cg(xv, &mut cgrad);
173                    } else {
174                        finite_diff_gradient(&*c.func, xv, &mut cgrad);
175                    }
176                    for k in 0..n {
177                        gout[k] += mult * cgrad[k];
178                    }
179                }
180            }
181        };
182
183        // Solve subproblem
184        let sub_result = if has_bounds {
185            let sub_opts = LbfgsBOptions {
186                base: opts.inner_opts.clone(),
187                memory: 10,
188            };
189            lbfgsb_minimize(aug_f, aug_grad, &x, &bounds, &sub_opts)?
190        } else {
191            let sub_opts = LbfgsOptions {
192                base: opts.inner_opts.clone(),
193                memory: 10,
194            };
195            lbfgs_minimize(aug_f, aug_grad, &x, &sub_opts)?
196        };
197
198        total_feval += sub_result.n_feval;
199        total_geval += sub_result.n_geval;
200        x = sub_result.x;
201
202        // Compute constraint violations and update multipliers
203        let mut max_violation = S::ZERO;
204
205        for (j, c) in eq_constraints.iter().enumerate() {
206            let h = (c.func)(&x);
207            let abs_h = h.abs();
208            if abs_h > max_violation {
209                max_violation = abs_h;
210            }
211            lambda_eq[j] += sigma * h;
212        }
213
214        for (i, c) in ineq_constraints.iter().enumerate() {
215            let g_val = (c.func)(&x);
216            let shifted = g_val + mu_ineq[i] / sigma;
217            if shifted > S::ZERO {
218                let g_pos = if g_val > S::ZERO { g_val } else { S::ZERO };
219                if g_pos > max_violation {
220                    max_violation = g_pos;
221                }
222                let new_mu = mu_ineq[i] + sigma * g_val;
223                mu_ineq[i] = if new_mu > S::ZERO { new_mu } else { S::ZERO };
224            } else {
225                mu_ineq[i] = S::ZERO;
226            }
227        }
228
229        history.push(IterationRecord {
230            iteration: outer,
231            objective: (obj_func)(&x),
232            gradient_norm: S::ZERO,
233            step_size: sigma,
234            constraint_violation: max_violation,
235        });
236
237        // Check convergence
238        if max_violation < opts.ctol {
239            let fval = (obj_func)(&x);
240            let mut g_buf = vec![S::ZERO; n];
241            if let Some(ref og) = obj_grad {
242                og(&x, &mut g_buf);
243            } else {
244                finite_diff_gradient(&*obj_func, &x, &mut g_buf);
245            }
246
247            return Ok((OptimResult {
248                lambda_eq,
249                lambda_ineq: mu_ineq,
250                constraint_violation: max_violation,
251                history,
252                ..OptimResult::unconstrained(
253                    x,
254                    fval,
255                    g_buf,
256                    outer + 1,
257                    total_feval,
258                    total_geval,
259                    true,
260                    format!(
261                        "Converged: constraint violation {:.2e} after {} outer iterations",
262                        max_violation.to_f64(),
263                        outer + 1
264                    ),
265                    OptimStatus::GradientConverged,
266                )
267            })
268            .with_wall_time(start));
269        }
270
271        // Increase penalty
272        sigma *= opts.sigma_factor;
273        if sigma > opts.sigma_max {
274            sigma = opts.sigma_max;
275        }
276    }
277
278    // Did not converge
279    let max_violation: S = eq_constraints
280        .iter()
281        .map(|c| (c.func)(&x).abs())
282        .chain(ineq_constraints.iter().map(|c| {
283            let v = (c.func)(&x);
284            if v > S::ZERO {
285                v
286            } else {
287                S::ZERO
288            }
289        }))
290        .fold(S::ZERO, |a, b| if b > a { b } else { a });
291
292    if max_violation.to_f64() > 0.1 {
293        return Err(OptimError::Infeasible {
294            violation: max_violation.to_f64(),
295        });
296    }
297
298    let fval = (obj_func)(&x);
299    let mut g_buf = vec![S::ZERO; n];
300    if let Some(ref og) = obj_grad {
301        og(&x, &mut g_buf);
302    } else {
303        finite_diff_gradient(&*obj_func, &x, &mut g_buf);
304    }
305
306    Ok((OptimResult {
307        lambda_eq,
308        lambda_ineq: mu_ineq,
309        constraint_violation: max_violation,
310        history,
311        ..OptimResult::unconstrained(
312            x,
313            fval,
314            g_buf,
315            opts.max_outer_iter,
316            total_feval,
317            total_geval,
318            false,
319            format!(
320                "Maximum outer iterations ({}) reached, violation={:.2e}",
321                opts.max_outer_iter,
322                max_violation.to_f64()
323            ),
324            OptimStatus::MaxIterations,
325        )
326    })
327    .with_wall_time(start))
328}
329
#[cfg(test)]
mod tests {
    use crate::augmented_lagrangian::AugLagOptions;
    use crate::problem::OptimProblem;

    /// Linear objective on the unit circle:
    /// minimize `x0 + x1` subject to `x0^2 + x1^2 = 1`.
    /// The minimizer is `(-1/sqrt(2), -1/sqrt(2))`.
    #[test]
    fn test_equality_constrained_circle() {
        let sol = OptimProblem::new(2)
            .x0(&[1.0, 0.0])
            .objective(|x: &[f64]| x[0] + x[1])
            .gradient(|_x: &[f64], g: &mut [f64]| {
                g[0] = 1.0;
                g[1] = 1.0;
            })
            .constraint_eq_with_grad(
                |x: &[f64]| x[0] * x[0] + x[1] * x[1] - 1.0,
                |x: &[f64], g: &mut [f64]| {
                    g[0] = 2.0 * x[0];
                    g[1] = 2.0 * x[1];
                },
            )
            .solve()
            .unwrap();

        assert!(sol.converged, "did not converge: {}", sol.message);

        let expected = -1.0 / 2.0_f64.sqrt();
        let err0 = (sol.x[0] - expected).abs();
        let err1 = (sol.x[1] - expected).abs();
        assert!(err0 < 1e-3, "x0={}, expected {}", sol.x[0], expected);
        assert!(err1 < 1e-3, "x1={}, expected {}", sol.x[1], expected);
        assert!(sol.constraint_violation < 1e-5);
    }

    /// Quadratic bowl centred at `(2, 2)` with the half-plane constraint
    /// `x0 + x1 <= 2` (written as `x0 + x1 - 2 <= 0`).
    /// The unconstrained minimum `(2, 2)` is infeasible, so the solution sits
    /// on the boundary `x0 + x1 = 2` at `(1, 1)`.
    #[test]
    fn test_inequality_constrained() {
        let sol = OptimProblem::new(2)
            .x0(&[0.0, 0.0])
            .objective(|x: &[f64]| (x[0] - 2.0).powi(2) + (x[1] - 2.0).powi(2))
            .gradient(|x: &[f64], g: &mut [f64]| {
                g[0] = 2.0 * (x[0] - 2.0);
                g[1] = 2.0 * (x[1] - 2.0);
            })
            .constraint_ineq_with_grad(
                |x: &[f64]| x[0] + x[1] - 2.0,
                |_x: &[f64], g: &mut [f64]| {
                    g[0] = 1.0;
                    g[1] = 1.0;
                },
            )
            .solve()
            .unwrap();

        assert!(sol.converged, "did not converge: {}", sol.message);

        let err0 = (sol.x[0] - 1.0).abs();
        let err1 = (sol.x[1] - 1.0).abs();
        assert!(err0 < 1e-2, "x0={}, expected 1.0", sol.x[0]);
        assert!(err1 < 1e-2, "x1={}, expected 1.0", sol.x[1]);
    }

    /// One equality and one active inequality:
    /// minimize `x0^2 + x1^2` subject to `x0 + x1 = 1` and `x0 >= 0.6`
    /// (written as `0.6 - x0 <= 0`).
    /// With the equality alone the answer is `(0.5, 0.5)`; the active
    /// inequality moves it to `(0.6, 0.4)`.
    #[test]
    fn test_mixed_constraints() {
        let sol = OptimProblem::new(2)
            .x0(&[1.0, 1.0])
            .objective(|x: &[f64]| x[0] * x[0] + x[1] * x[1])
            .gradient(|x: &[f64], g: &mut [f64]| {
                g[0] = 2.0 * x[0];
                g[1] = 2.0 * x[1];
            })
            .constraint_eq_with_grad(
                |x: &[f64]| x[0] + x[1] - 1.0,
                |_x: &[f64], g: &mut [f64]| {
                    g[0] = 1.0;
                    g[1] = 1.0;
                },
            )
            .constraint_ineq_with_grad(
                |x: &[f64]| 0.6 - x[0],
                |_x: &[f64], g: &mut [f64]| {
                    g[0] = -1.0;
                    g[1] = 0.0;
                },
            )
            .solve()
            .unwrap();

        assert!(sol.converged, "did not converge: {}", sol.message);

        let err0 = (sol.x[0] - 0.6).abs();
        let err1 = (sol.x[1] - 0.4).abs();
        assert!(err0 < 5e-2, "x0={}, expected 0.6", sol.x[0]);
        assert!(err1 < 5e-2, "x1={}, expected 0.4", sol.x[1]);
        assert!(sol.constraint_violation < 1e-3);
    }

    /// Same circle problem, but with a larger initial penalty, a tighter
    /// constraint tolerance, and no analytic constraint gradient.
    #[test]
    fn test_aug_lag_custom_options() {
        let custom = AugLagOptions {
            sigma_init: 10.0,
            ctol: 1e-8,
            ..AugLagOptions::default()
        };

        let sol = OptimProblem::new(2)
            .x0(&[1.0, 0.0])
            .objective(|x: &[f64]| x[0] + x[1])
            .gradient(|_x: &[f64], g: &mut [f64]| {
                g[0] = 1.0;
                g[1] = 1.0;
            })
            .constraint_eq(|x: &[f64]| x[0] * x[0] + x[1] * x[1] - 1.0)
            .aug_lag_options(custom)
            .solve()
            .unwrap();

        assert!(sol.converged, "did not converge: {}", sol.message);
        assert!(sol.constraint_violation < 1e-7);
    }
}