1use numra_core::Scalar;
12
13use crate::error::OptimError;
14use crate::lbfgs::{lbfgs_minimize, LbfgsOptions};
15use crate::lbfgsb::{lbfgsb_minimize, LbfgsBOptions};
16use crate::problem::{
17 finite_diff_gradient, Constraint, ConstraintKind, ObjectiveKind, OptimProblem,
18};
19use crate::types::{IterationRecord, OptimOptions, OptimResult, OptimStatus};
20
21#[derive(Clone, Debug)]
23pub struct AugLagOptions<S: Scalar> {
24 pub inner_opts: OptimOptions<S>,
25 pub max_outer_iter: usize,
26 pub sigma_init: S,
27 pub sigma_factor: S,
28 pub sigma_max: S,
29 pub ctol: S,
30}
31
32impl<S: Scalar> Default for AugLagOptions<S> {
33 fn default() -> Self {
34 Self {
35 inner_opts: OptimOptions::default().max_iter(500),
36 max_outer_iter: 50,
37 sigma_init: S::ONE,
38 sigma_factor: S::from_f64(10.0),
39 sigma_max: S::from_f64(1e12),
40 ctol: S::from_f64(1e-6),
41 }
42 }
43}
44
45pub fn augmented_lagrangian_minimize<S: Scalar>(
49 problem: OptimProblem<S>,
50 opts: &AugLagOptions<S>,
51) -> Result<OptimResult<S>, OptimError> {
52 let start = std::time::Instant::now();
53 let OptimProblem {
55 n,
56 x0,
57 bounds,
58 objective,
59 constraints,
60 ..
61 } = problem;
62
63 let x0 = x0.ok_or(OptimError::NoInitialPoint)?;
64
65 let (obj_func, obj_grad) = match objective {
67 Some(ObjectiveKind::Minimize { func, grad }) => (func, grad),
68 Some(ObjectiveKind::LeastSquares { .. }) => {
69 return Err(OptimError::Other(
70 "augmented Lagrangian requires scalar objective, not least squares".into(),
71 ));
72 }
73 Some(ObjectiveKind::Linear { .. }) | Some(ObjectiveKind::Quadratic { .. }) => {
74 return Err(OptimError::Other(
75 "augmented Lagrangian requires scalar objective; use Simplex for LP or ActiveSetQP for QP".into(),
76 ));
77 }
78 Some(ObjectiveKind::MultiObjective { .. }) => {
79 return Err(OptimError::Other(
80 "augmented Lagrangian requires scalar objective; use NSGA-II for multi-objective"
81 .into(),
82 ));
83 }
84 None => return Err(OptimError::NoObjective),
85 };
86
87 let has_bounds = bounds.iter().any(|b| b.is_some());
88
89 let eq_constraints: Vec<&Constraint<S>> = constraints
91 .iter()
92 .filter(|c| c.kind == ConstraintKind::Equality)
93 .collect();
94 let ineq_constraints: Vec<&Constraint<S>> = constraints
95 .iter()
96 .filter(|c| c.kind == ConstraintKind::Inequality)
97 .collect();
98
99 let n_eq = eq_constraints.len();
100 let n_ineq = ineq_constraints.len();
101
102 let mut lambda_eq = vec![S::ZERO; n_eq];
104 let mut mu_ineq = vec![S::ZERO; n_ineq];
105 let mut sigma = opts.sigma_init;
106 let mut x = x0;
107
108 let mut total_feval = 0_usize;
109 let mut total_geval = 0_usize;
110 let mut history = Vec::new();
111
112 let two = S::TWO;
113
114 for outer in 0..opts.max_outer_iter {
115 let lam_eq = lambda_eq.clone();
117 let mu_in = mu_ineq.clone();
118 let sig = sigma;
119
120 let aug_f = |xv: &[S]| -> S {
121 let mut val = (obj_func)(xv);
122
123 for (j, c) in eq_constraints.iter().enumerate() {
125 let h = (c.func)(xv);
126 val = val + lam_eq[j] * h + (sig / two) * h * h;
127 }
128
129 for (i, c) in ineq_constraints.iter().enumerate() {
131 let g = (c.func)(xv);
132 let shifted = g + mu_in[i] / sig;
133 if shifted > S::ZERO {
134 val = val + (sig / two) * shifted * shifted - mu_in[i] * mu_in[i] / (two * sig);
135 }
136 }
137
138 val
139 };
140
141 let aug_grad = |xv: &[S], gout: &mut [S]| {
142 if let Some(ref og) = obj_grad {
144 og(xv, gout);
145 } else {
146 finite_diff_gradient(&*obj_func, xv, gout);
147 }
148
149 let mut cgrad = vec![S::ZERO; n];
150
151 for (j, c) in eq_constraints.iter().enumerate() {
153 let h = (c.func)(xv);
154 let mult = lam_eq[j] + sig * h;
155 if let Some(ref cg) = c.grad {
156 cg(xv, &mut cgrad);
157 } else {
158 finite_diff_gradient(&*c.func, xv, &mut cgrad);
159 }
160 for k in 0..n {
161 gout[k] += mult * cgrad[k];
162 }
163 }
164
165 for (i, c) in ineq_constraints.iter().enumerate() {
167 let g_val = (c.func)(xv);
168 let shifted = g_val + mu_in[i] / sig;
169 if shifted > S::ZERO {
170 let mult = sig * shifted;
171 if let Some(ref cg) = c.grad {
172 cg(xv, &mut cgrad);
173 } else {
174 finite_diff_gradient(&*c.func, xv, &mut cgrad);
175 }
176 for k in 0..n {
177 gout[k] += mult * cgrad[k];
178 }
179 }
180 }
181 };
182
183 let sub_result = if has_bounds {
185 let sub_opts = LbfgsBOptions {
186 base: opts.inner_opts.clone(),
187 memory: 10,
188 };
189 lbfgsb_minimize(aug_f, aug_grad, &x, &bounds, &sub_opts)?
190 } else {
191 let sub_opts = LbfgsOptions {
192 base: opts.inner_opts.clone(),
193 memory: 10,
194 };
195 lbfgs_minimize(aug_f, aug_grad, &x, &sub_opts)?
196 };
197
198 total_feval += sub_result.n_feval;
199 total_geval += sub_result.n_geval;
200 x = sub_result.x;
201
202 let mut max_violation = S::ZERO;
204
205 for (j, c) in eq_constraints.iter().enumerate() {
206 let h = (c.func)(&x);
207 let abs_h = h.abs();
208 if abs_h > max_violation {
209 max_violation = abs_h;
210 }
211 lambda_eq[j] += sigma * h;
212 }
213
214 for (i, c) in ineq_constraints.iter().enumerate() {
215 let g_val = (c.func)(&x);
216 let shifted = g_val + mu_ineq[i] / sigma;
217 if shifted > S::ZERO {
218 let g_pos = if g_val > S::ZERO { g_val } else { S::ZERO };
219 if g_pos > max_violation {
220 max_violation = g_pos;
221 }
222 let new_mu = mu_ineq[i] + sigma * g_val;
223 mu_ineq[i] = if new_mu > S::ZERO { new_mu } else { S::ZERO };
224 } else {
225 mu_ineq[i] = S::ZERO;
226 }
227 }
228
229 history.push(IterationRecord {
230 iteration: outer,
231 objective: (obj_func)(&x),
232 gradient_norm: S::ZERO,
233 step_size: sigma,
234 constraint_violation: max_violation,
235 });
236
237 if max_violation < opts.ctol {
239 let fval = (obj_func)(&x);
240 let mut g_buf = vec![S::ZERO; n];
241 if let Some(ref og) = obj_grad {
242 og(&x, &mut g_buf);
243 } else {
244 finite_diff_gradient(&*obj_func, &x, &mut g_buf);
245 }
246
247 return Ok((OptimResult {
248 lambda_eq,
249 lambda_ineq: mu_ineq,
250 constraint_violation: max_violation,
251 history,
252 ..OptimResult::unconstrained(
253 x,
254 fval,
255 g_buf,
256 outer + 1,
257 total_feval,
258 total_geval,
259 true,
260 format!(
261 "Converged: constraint violation {:.2e} after {} outer iterations",
262 max_violation.to_f64(),
263 outer + 1
264 ),
265 OptimStatus::GradientConverged,
266 )
267 })
268 .with_wall_time(start));
269 }
270
271 sigma *= opts.sigma_factor;
273 if sigma > opts.sigma_max {
274 sigma = opts.sigma_max;
275 }
276 }
277
278 let max_violation: S = eq_constraints
280 .iter()
281 .map(|c| (c.func)(&x).abs())
282 .chain(ineq_constraints.iter().map(|c| {
283 let v = (c.func)(&x);
284 if v > S::ZERO {
285 v
286 } else {
287 S::ZERO
288 }
289 }))
290 .fold(S::ZERO, |a, b| if b > a { b } else { a });
291
292 if max_violation.to_f64() > 0.1 {
293 return Err(OptimError::Infeasible {
294 violation: max_violation.to_f64(),
295 });
296 }
297
298 let fval = (obj_func)(&x);
299 let mut g_buf = vec![S::ZERO; n];
300 if let Some(ref og) = obj_grad {
301 og(&x, &mut g_buf);
302 } else {
303 finite_diff_gradient(&*obj_func, &x, &mut g_buf);
304 }
305
306 Ok((OptimResult {
307 lambda_eq,
308 lambda_ineq: mu_ineq,
309 constraint_violation: max_violation,
310 history,
311 ..OptimResult::unconstrained(
312 x,
313 fval,
314 g_buf,
315 opts.max_outer_iter,
316 total_feval,
317 total_geval,
318 false,
319 format!(
320 "Maximum outer iterations ({}) reached, violation={:.2e}",
321 opts.max_outer_iter,
322 max_violation.to_f64()
323 ),
324 OptimStatus::MaxIterations,
325 )
326 })
327 .with_wall_time(start))
328}
329
#[cfg(test)]
mod tests {
    use crate::problem::OptimProblem;

    /// Objective `x0 + x1` used by the circle-constrained tests.
    fn coordinate_sum(x: &[f64]) -> f64 {
        x[0] + x[1]
    }

    /// Gradient of [`coordinate_sum`]: the constant vector `(1, 1)`.
    fn coordinate_sum_grad(_x: &[f64], g: &mut [f64]) {
        g[0] = 1.0;
        g[1] = 1.0;
    }

    /// Unit-circle residual `x0^2 + x1^2 - 1`, zero exactly on the circle.
    fn unit_circle(x: &[f64]) -> f64 {
        x[0] * x[0] + x[1] * x[1] - 1.0
    }

    /// Gradient of [`unit_circle`].
    fn unit_circle_grad(x: &[f64], g: &mut [f64]) {
        g[0] = 2.0 * x[0];
        g[1] = 2.0 * x[1];
    }

    /// Constant gradient `(1, 1)` of the linear constraints `x0 + x1 - c`.
    fn ones_grad(_x: &[f64], g: &mut [f64]) {
        g[0] = 1.0;
        g[1] = 1.0;
    }

    /// Minimize `x0 + x1` on the unit circle; the minimizer is
    /// `(-1/sqrt(2), -1/sqrt(2))`.
    #[test]
    fn test_equality_constrained_circle() {
        let result = OptimProblem::new(2)
            .x0(&[1.0, 0.0])
            .objective(coordinate_sum)
            .gradient(coordinate_sum_grad)
            .constraint_eq_with_grad(unit_circle, unit_circle_grad)
            .solve()
            .unwrap();

        assert!(result.converged, "did not converge: {}", result.message);

        let expected = -1.0 / 2.0_f64.sqrt();
        let (x0, x1) = (result.x[0], result.x[1]);
        assert!(
            (x0 - expected).abs() < 1e-3,
            "x0={}, expected {}",
            x0,
            expected
        );
        assert!(
            (x1 - expected).abs() < 1e-3,
            "x1={}, expected {}",
            x1,
            expected
        );
        assert!(result.constraint_violation < 1e-5);
    }

    /// Minimize the squared distance to `(2, 2)` subject to `x0 + x1 <= 2`;
    /// the minimizer is `(1, 1)`.
    #[test]
    fn test_inequality_constrained() {
        fn distance_sq(x: &[f64]) -> f64 {
            (x[0] - 2.0).powi(2) + (x[1] - 2.0).powi(2)
        }
        fn distance_sq_grad(x: &[f64], g: &mut [f64]) {
            g[0] = 2.0 * (x[0] - 2.0);
            g[1] = 2.0 * (x[1] - 2.0);
        }
        fn half_plane(x: &[f64]) -> f64 {
            x[0] + x[1] - 2.0
        }

        let result = OptimProblem::new(2)
            .x0(&[0.0, 0.0])
            .objective(distance_sq)
            .gradient(distance_sq_grad)
            .constraint_ineq_with_grad(half_plane, ones_grad)
            .solve()
            .unwrap();

        assert!(result.converged, "did not converge: {}", result.message);

        let (x0, x1) = (result.x[0], result.x[1]);
        assert!((x0 - 1.0).abs() < 1e-2, "x0={}, expected 1.0", x0);
        assert!((x1 - 1.0).abs() < 1e-2, "x1={}, expected 1.0", x1);
    }

    /// Minimize `x0^2 + x1^2` subject to `x0 + x1 = 1` and `x0 >= 0.6`
    /// (written as `0.6 - x0 <= 0`); the minimizer is `(0.6, 0.4)`.
    #[test]
    fn test_mixed_constraints() {
        fn norm_sq(x: &[f64]) -> f64 {
            x[0] * x[0] + x[1] * x[1]
        }
        fn norm_sq_grad(x: &[f64], g: &mut [f64]) {
            g[0] = 2.0 * x[0];
            g[1] = 2.0 * x[1];
        }
        fn sum_to_one(x: &[f64]) -> f64 {
            x[0] + x[1] - 1.0
        }
        fn lower_limit(x: &[f64]) -> f64 {
            0.6 - x[0]
        }
        fn lower_limit_grad(_x: &[f64], g: &mut [f64]) {
            g[0] = -1.0;
            g[1] = 0.0;
        }

        let result = OptimProblem::new(2)
            .x0(&[1.0, 1.0])
            .objective(norm_sq)
            .gradient(norm_sq_grad)
            .constraint_eq_with_grad(sum_to_one, ones_grad)
            .constraint_ineq_with_grad(lower_limit, lower_limit_grad)
            .solve()
            .unwrap();

        assert!(result.converged, "did not converge: {}", result.message);

        let (x0, x1) = (result.x[0], result.x[1]);
        assert!((x0 - 0.6).abs() < 5e-2, "x0={}, expected 0.6", x0);
        assert!((x1 - 0.4).abs() < 5e-2, "x1={}, expected 0.4", x1);
        assert!(result.constraint_violation < 1e-3);
    }

    /// Same circle problem as above, but with a larger starting penalty, a
    /// tighter constraint tolerance, and no analytic constraint gradient.
    #[test]
    fn test_aug_lag_custom_options() {
        use crate::augmented_lagrangian::AugLagOptions;

        let opts = AugLagOptions {
            sigma_init: 10.0,
            ctol: 1e-8,
            ..AugLagOptions::default()
        };

        let result = OptimProblem::new(2)
            .x0(&[1.0, 0.0])
            .objective(coordinate_sum)
            .gradient(coordinate_sum_grad)
            .constraint_eq(unit_circle)
            .aug_lag_options(opts)
            .solve()
            .unwrap();

        assert!(result.converged, "did not converge: {}", result.message);
        assert!(result.constraint_violation < 1e-7);
    }
}