Skip to main content

numra_optim/
robust.rs

1//! Robust optimization with worst-case constraint reformulation.
2//!
3//! When problem parameters are uncertain, robust optimization tightens
4//! constraints to ensure feasibility at a specified confidence level.
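//! For each inequality constraint `g(x, p) <= 0`, the solver determines
//! worst-case parameter values — each uncertain parameter is moved `k`
//! standard deviations from its mean, towards the side that makes `g` larger,
//! where `k` is the standard normal quantile of the confidence level — and
//! enforces `g(x, p_worst) <= 0` instead. Equality constraints are enforced at
//! the nominal (mean) parameter values.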
8//!
9//! # Example
10//!
11//! ```rust
12//! use numra_optim::robust::RobustProblem;
13//!
14//! let result = RobustProblem::<f64>::new(1)
15//!     .x0(&[5.0])
16//!     .objective(|x: &[f64], _p: &[f64]| (x[0] - 5.0) * (x[0] - 5.0))
17//!     .param("target", 5.0, 1.0)
18//!     .solve()
19//!     .unwrap();
20//! ```
21//!
22//! Author: Moussa Leblouba
23//! Date: 8 February 2026
24//! Modified: 2 May 2026
25
26use std::sync::Arc;
27
28use numra_core::Scalar;
29
30use crate::error::OptimError;
31use crate::optim_sensitivity::compute_param_sensitivity;
32use crate::problem::{ConstraintKind, OptimProblem};
33use crate::types::ParamSensitivity;
34
/// Reference-counted, thread-safe scalar function of the decision variables
/// and the parameters, `f(x, params) -> S` (used for objectives and constraints).
type ParamObjFn<S> = Arc<dyn Fn(&[S], &[S]) -> S + Sync + Send>;
/// Reference-counted, thread-safe gradient callback `g(x, params, grad_out)`
/// that writes the gradient with respect to `x` into `grad_out`.
type ParamGradFn<S> = Arc<dyn Fn(&[S], &[S], &mut [S]) + Sync + Send>;
39
40// ---------------------------------------------------------------------------
41// Types
42// ---------------------------------------------------------------------------
43
44/// An uncertain parameter for robust optimization.
45#[derive(Clone, Debug)]
46pub struct UncertainParam<S: Scalar> {
47    /// Parameter name (for reporting).
48    pub name: String,
49    /// Nominal (mean) value.
50    pub mean: S,
51    /// Standard deviation.
52    pub std: S,
53}
54
55/// Options for robust optimization.
56#[derive(Clone, Debug)]
57pub struct RobustOptions<S: Scalar> {
58    /// Confidence level in (0, 1). Default: 0.95.
59    pub confidence: S,
60    /// Maximum optimizer iterations. Default: 1000.
61    pub max_iter: usize,
62}
63
64impl<S: Scalar> Default for RobustOptions<S> {
65    fn default() -> Self {
66        Self {
67            confidence: S::from_f64(0.95),
68            max_iter: 1000,
69        }
70    }
71}
72
73/// Result of robust optimization.
74#[derive(Clone, Debug)]
75pub struct RobustResult<S: Scalar> {
76    /// Optimal decision variables.
77    pub x: Vec<S>,
78    /// Nominal objective value (at mean parameters).
79    pub f_nominal: S,
80    /// Worst-case objective value.
81    pub f_worst_case: S,
82    /// Solution uncertainty: std dev of each x_i due to parameter uncertainty.
83    pub x_std: Vec<S>,
84    /// Whether the optimizer converged.
85    pub converged: bool,
86    /// Status message.
87    pub message: String,
88    /// Iterations.
89    pub iterations: usize,
90    /// Wall time.
91    pub wall_time_secs: f64,
92    /// Parametric sensitivity (dx*/dp) if computed.
93    pub sensitivity: Option<ParamSensitivity<S>>,
94}
95
96// ---------------------------------------------------------------------------
97// Builder
98// ---------------------------------------------------------------------------
99
100/// A parameterized constraint for robust optimization.
101struct RobustConstraint<S: Scalar> {
102    func: ParamObjFn<S>,
103    kind: ConstraintKind,
104}
105
106/// Declarative builder for robust optimization problems.
107///
108/// The objective and constraints are functions of both decision variables `x`
109/// and uncertain parameters `p`. The solver reformulates the problem so that
110/// constraints hold under worst-case parameter perturbations at the specified
111/// confidence level.
112pub struct RobustProblem<S: Scalar> {
113    n: usize,
114    x0: Option<Vec<S>>,
115    bounds: Vec<Option<(S, S)>>,
116    objective: Option<ParamObjFn<S>>,
117    gradient: Option<ParamGradFn<S>>,
118    constraints: Vec<RobustConstraint<S>>,
119    params: Vec<UncertainParam<S>>,
120    options: RobustOptions<S>,
121}
122
123impl<S: Scalar> RobustProblem<S> {
124    /// Create a new robust optimization problem with `n` decision variables.
125    pub fn new(n: usize) -> Self {
126        Self {
127            n,
128            x0: None,
129            bounds: vec![None; n],
130            objective: None,
131            gradient: None,
132            constraints: Vec::new(),
133            params: Vec::new(),
134            options: RobustOptions::default(),
135        }
136    }
137
138    /// Set the initial point.
139    pub fn x0(mut self, x0: &[S]) -> Self {
140        self.x0 = Some(x0.to_vec());
141        self
142    }
143
144    /// Set bounds for variable `i`.
145    pub fn bounds(mut self, i: usize, lo_hi: (S, S)) -> Self {
146        self.bounds[i] = Some(lo_hi);
147        self
148    }
149
150    /// Set bounds for all variables at once.
151    pub fn all_bounds(mut self, bounds: &[(S, S)]) -> Self {
152        for (i, &b) in bounds.iter().enumerate() {
153            self.bounds[i] = Some(b);
154        }
155        self
156    }
157
158    /// Set the parameterized objective function `f(x, params)`.
159    pub fn objective<F>(mut self, f: F) -> Self
160    where
161        F: Fn(&[S], &[S]) -> S + Send + Sync + 'static,
162    {
163        self.objective = Some(Arc::new(f));
164        self
165    }
166
167    /// Set the gradient of the objective w.r.t. `x`.
168    ///
169    /// `g(x, params, grad_out)` writes the gradient into `grad_out`.
170    pub fn gradient<G>(mut self, g: G) -> Self
171    where
172        G: Fn(&[S], &[S], &mut [S]) + Send + Sync + 'static,
173    {
174        self.gradient = Some(Arc::new(g));
175        self
176    }
177
178    /// Add an inequality constraint `g(x, params) <= 0`.
179    pub fn constraint_ineq<F>(mut self, f: F) -> Self
180    where
181        F: Fn(&[S], &[S]) -> S + Send + Sync + 'static,
182    {
183        self.constraints.push(RobustConstraint {
184            func: Arc::new(f),
185            kind: ConstraintKind::Inequality,
186        });
187        self
188    }
189
190    /// Add an equality constraint `h(x, params) = 0`.
191    pub fn constraint_eq<F>(mut self, f: F) -> Self
192    where
193        F: Fn(&[S], &[S]) -> S + Send + Sync + 'static,
194    {
195        self.constraints.push(RobustConstraint {
196            func: Arc::new(f),
197            kind: ConstraintKind::Equality,
198        });
199        self
200    }
201
202    /// Add a single uncertain parameter.
203    pub fn param(mut self, name: &str, mean: S, std: S) -> Self {
204        self.params.push(UncertainParam {
205            name: name.to_string(),
206            mean,
207            std,
208        });
209        self
210    }
211
212    /// Add multiple uncertain parameters at once.
213    pub fn params(mut self, params: Vec<UncertainParam<S>>) -> Self {
214        self.params.extend(params);
215        self
216    }
217
218    /// Set the confidence level (must be in (0, 1)).
219    pub fn confidence(mut self, level: S) -> Self {
220        self.options.confidence = level;
221        self
222    }
223
224    /// Set the maximum number of optimizer iterations.
225    pub fn max_iter(mut self, n: usize) -> Self {
226        self.options.max_iter = n;
227        self
228    }
229}
230
231impl<S: Scalar + faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField>
232    RobustProblem<S>
233{
234    /// Solve the robust optimization problem.
235    ///
236    /// 1. Computes worst-case parameter vectors for each inequality constraint.
237    /// 2. Reformulates as a standard `OptimProblem` with tightened constraints.
238    /// 3. Solves the reformulated problem.
239    /// 4. Computes parametric sensitivity and solution uncertainty.
240    pub fn solve(self) -> Result<RobustResult<S>, OptimError> {
241        let start = std::time::Instant::now();
242
243        let obj = self.objective.ok_or(OptimError::NoObjective)?;
244        let x0 = self.x0.clone().ok_or(OptimError::NoInitialPoint)?;
245        let n = self.n;
246
247        // 1. Compute k-factor from confidence level.
248        let k = normal_quantile(self.options.confidence);
249
250        // 2. Extract nominal parameter values.
251        let p_nom: Vec<S> = self.params.iter().map(|p| p.mean).collect();
252        let p_stds: Vec<S> = self.params.iter().map(|p| p.std).collect();
253        let n_params = self.params.len();
254
255        // 3. Build a standard OptimProblem.
256        // a. Objective at nominal params.
257        let obj_for_problem = Arc::clone(&obj);
258        let p_nom_obj = p_nom.clone();
259        let mut problem = OptimProblem::new(n)
260            .x0(&x0)
261            .objective(move |x: &[S]| obj_for_problem(x, &p_nom_obj))
262            .max_iter(self.options.max_iter);
263
264        // b. If gradient is provided, set it with nominal params.
265        if let Some(grad_fn) = &self.gradient {
266            let grad_fn = Arc::clone(grad_fn);
267            let p_nom_grad = p_nom.clone();
268            problem = problem.gradient(move |x: &[S], g: &mut [S]| {
269                grad_fn(x, &p_nom_grad, g);
270            });
271        }
272
273        // c. Apply bounds.
274        for (i, b) in self.bounds.iter().enumerate() {
275            if let Some(lo_hi) = b {
276                problem = problem.bounds(i, *lo_hi);
277            }
278        }
279
280        // d. Add constraints.
281        for rc in &self.constraints {
282            match rc.kind {
283                ConstraintKind::Equality => {
284                    // Equality constraints use nominal params (not robustified).
285                    let func = Arc::clone(&rc.func);
286                    let p_nom_eq = p_nom.clone();
287                    problem = problem.constraint_eq(move |x: &[S]| func(x, &p_nom_eq));
288                }
289                ConstraintKind::Inequality => {
290                    // Compute worst-case params for this constraint.
291                    let p_worst =
292                        compute_worst_case_params(&*rc.func, &x0, &p_nom, &p_stds, k, n_params);
293                    let func = Arc::clone(&rc.func);
294                    problem = problem.constraint_ineq(move |x: &[S]| func(x, &p_worst));
295                }
296            }
297        }
298
299        // 4. Solve the reformulated problem.
300        let result = problem.solve()?;
301        let x_star = result.x.clone();
302
303        // 5. Compute parametric sensitivity (dx*/dp).
304        let sensitivity = if !self.params.is_empty() {
305            let obj_sens = Arc::clone(&obj);
306            let bounds_sens = self.bounds.clone();
307            let grad_sens = self.gradient.clone();
308            let max_iter = self.options.max_iter;
309            let param_names: Vec<&str> = self.params.iter().map(|p| p.name.as_str()).collect();
310
311            let sens_result = compute_param_sensitivity(
312                |params: &[S]| {
313                    let obj_inner = Arc::clone(&obj_sens);
314                    let p_inner = params.to_vec();
315                    let mut prob = OptimProblem::new(n)
316                        .x0(&x_star)
317                        .objective(move |x: &[S]| obj_inner(x, &p_inner))
318                        .max_iter(max_iter);
319
320                    if let Some(ref gf) = grad_sens {
321                        let gf = Arc::clone(gf);
322                        let p_g = params.to_vec();
323                        prob = prob.gradient(move |x: &[S], g: &mut [S]| {
324                            gf(x, &p_g, g);
325                        });
326                    }
327
328                    for (i, b) in bounds_sens.iter().enumerate() {
329                        if let Some(lo_hi) = b {
330                            prob = prob.bounds(i, *lo_hi);
331                        }
332                    }
333                    prob
334                },
335                &p_nom,
336                &param_names,
337                None,
338            );
339
340            sens_result.ok()
341        } else {
342            None
343        };
344
345        // 6. Compute x_std from sensitivity: x_std[i] = sqrt(sum_j (dx_i/dp_j)^2 * sigma_j^2).
346        let x_std = if let Some(ref sens) = sensitivity {
347            (0..n)
348                .map(|i| {
349                    let var: S = (0..n_params)
350                        .map(|j| {
351                            let dxdp = sens.get(i, j);
352                            dxdp * dxdp * p_stds[j] * p_stds[j]
353                        })
354                        .sum();
355                    var.sqrt()
356                })
357                .collect()
358        } else {
359            vec![S::ZERO; n]
360        };
361
362        // 7. Compute nominal and worst-case objective values.
363        let f_nominal = obj(&x_star, &p_nom);
364
365        // Worst-case objective: find worst direction for each param.
366        let f_worst_case = if !self.params.is_empty() {
367            let obj_worst = |_x: &[S], p: &[S]| obj(&x_star, p);
368            let p_worst_obj = compute_worst_case_params_for_obj(
369                &obj_worst, &x_star, &p_nom, &p_stds, k, n_params,
370            );
371            obj(&x_star, &p_worst_obj)
372        } else {
373            f_nominal
374        };
375
376        Ok(RobustResult {
377            x: x_star,
378            f_nominal,
379            f_worst_case,
380            x_std,
381            converged: result.converged,
382            message: result.message,
383            iterations: result.iterations,
384            wall_time_secs: start.elapsed().as_secs_f64(),
385            sensitivity,
386        })
387    }
388}
389
390// ---------------------------------------------------------------------------
391// Helper: compute worst-case parameters for a constraint
392// ---------------------------------------------------------------------------
393
394/// For an inequality constraint `g(x, p) <= 0`, determine the worst-case
395/// parameter vector (within the k-sigma confidence region) that maximises `g`.
396///
397/// For each parameter j, estimate the sign of dg/dp_j via finite differences
398/// at the representative point `(x0, p_nom)`, then set p_worst_j to
399/// `p_nom_j + k * std_j` if dg/dp_j > 0, or `p_nom_j - k * std_j` otherwise.
400fn compute_worst_case_params<S: Scalar>(
401    g: &dyn Fn(&[S], &[S]) -> S,
402    x0: &[S],
403    p_nom: &[S],
404    p_stds: &[S],
405    k: S,
406    n_params: usize,
407) -> Vec<S> {
408    let mut p_worst = p_nom.to_vec();
409    let fd_eps = S::from_f64(1e-8);
410
411    for j in 0..n_params {
412        if p_stds[j] <= S::ZERO {
413            continue;
414        }
415        let h = fd_eps * (S::ONE + p_nom[j].abs());
416
417        let mut p_plus = p_nom.to_vec();
418        p_plus[j] += h;
419        let g_plus = g(x0, &p_plus);
420
421        let mut p_minus = p_nom.to_vec();
422        p_minus[j] -= h;
423        let g_minus = g(x0, &p_minus);
424
425        // Choose the direction that makes g larger (more violated).
426        if g_plus > g_minus {
427            p_worst[j] = p_nom[j] + k * p_stds[j];
428        } else {
429            p_worst[j] = p_nom[j] - k * p_stds[j];
430        }
431    }
432
433    p_worst
434}
435
436/// For the objective function, find worst-case params that maximise f (worst case).
437fn compute_worst_case_params_for_obj<S: Scalar>(
438    _f_wrapper: &dyn Fn(&[S], &[S]) -> S,
439    x_star: &[S],
440    p_nom: &[S],
441    p_stds: &[S],
442    k: S,
443    n_params: usize,
444) -> Vec<S> {
445    let mut p_worst = p_nom.to_vec();
446    let fd_eps = S::from_f64(1e-8);
447
448    // We need the actual objective. Since _f_wrapper just calls obj(x_star, p),
449    // we use it directly.
450    let f_at = |p: &[S]| _f_wrapper(x_star, p);
451
452    for j in 0..n_params {
453        if p_stds[j] <= S::ZERO {
454            continue;
455        }
456        let h = fd_eps * (S::ONE + p_nom[j].abs());
457
458        let mut p_plus = p_nom.to_vec();
459        p_plus[j] += h;
460        let f_plus = f_at(&p_plus);
461
462        let mut p_minus = p_nom.to_vec();
463        p_minus[j] -= h;
464        let f_minus = f_at(&p_minus);
465
466        if f_plus > f_minus {
467            p_worst[j] = p_nom[j] + k * p_stds[j];
468        } else {
469            p_worst[j] = p_nom[j] - k * p_stds[j];
470        }
471    }
472
473    p_worst
474}
475
476// ---------------------------------------------------------------------------
477// Normal quantile (inverse CDF)
478// ---------------------------------------------------------------------------
479
480/// Compute the inverse of the standard normal CDF (quantile function).
481///
482/// Uses the Abramowitz & Stegun 26.2.23 rational approximation.
483///
484/// # Arguments
485///
486/// * `p` - Probability in (0, 1).
487///
488/// # Returns
489///
490/// The value `z` such that `Phi(z) = p` where `Phi` is the standard normal CDF.
491///
492/// # Panics
493///
494/// Panics if `p` is not in (0, 1).
495pub fn normal_quantile<S: Scalar>(p: S) -> S {
496    assert!(
497        p > S::ZERO && p < S::ONE,
498        "p must be in (0, 1), got {}",
499        p.to_f64()
500    );
501
502    if (p - S::HALF).abs() < S::from_f64(1e-15) {
503        return S::ZERO;
504    }
505
506    if p < S::HALF {
507        return -normal_quantile(S::ONE - p);
508    }
509
510    // p > 0.5: use the Abramowitz & Stegun rational approximation.
511    let t = (S::from_f64(-2.0) * (S::ONE - p).ln()).sqrt();
512
513    let c0 = S::from_f64(2.515517);
514    let c1 = S::from_f64(0.802853);
515    let c2 = S::from_f64(0.010328);
516    let d1 = S::from_f64(1.432788);
517    let d2 = S::from_f64(0.189269);
518    let d3 = S::from_f64(0.001308);
519
520    t - (c0 + c1 * t + c2 * t * t) / (S::ONE + d1 * t + d2 * t * t + d3 * t * t * t)
521}
522
523// ---------------------------------------------------------------------------
524// Tests
525// ---------------------------------------------------------------------------
526
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normal_quantile() {
        // Median of the standard normal.
        let q50 = normal_quantile(0.5_f64);
        assert!(q50.abs() < 1e-10, "q(0.5) = {}, expected 0.0", q50);

        // Upper quantiles, checked to the accuracy of the rational approximation.
        let q95 = normal_quantile(0.95_f64);
        assert!((q95 - 1.6449).abs() < 1e-3, "q(0.95) = {}, expected ~1.6449", q95);

        let q99 = normal_quantile(0.99_f64);
        assert!((q99 - 2.3263).abs() < 1e-3, "q(0.99) = {}, expected ~2.3263", q99);

        let q975 = normal_quantile(0.975_f64);
        assert!((q975 - 1.9600).abs() < 1e-3, "q(0.975) = {}, expected ~1.9600", q975);
    }

    #[test]
    fn test_robust_unconstrained() {
        // Minimise (x - p)^2 with p ~ 5 +/- 1: x* follows p, so dx*/dp = 1 and
        // the propagated x_std should be close to the parameter's std of 1.
        let sol = RobustProblem::<f64>::new(1)
            .x0(&[0.0])
            .objective(|x: &[f64], p: &[f64]| (x[0] - p[0]) * (x[0] - p[0]))
            .gradient(|x: &[f64], p: &[f64], g: &mut [f64]| g[0] = 2.0 * (x[0] - p[0]))
            .param("p", 5.0, 1.0)
            .solve()
            .unwrap();

        let x = sol.x[0];
        let x_std = sol.x_std[0];
        assert!((x - 5.0).abs() < 0.1, "x* = {}, expected ~5.0", x);
        assert!((x_std - 1.0).abs() < 0.3, "x_std = {}, expected ~1.0", x_std);
        assert!(sol.converged, "solver should converge");
    }

    #[test]
    fn test_robust_constraint_tightening() {
        // Maximise x subject to x <= p, with p ~ 10 +/- 2 at 95% confidence
        // (k ~ 1.645). The nominal answer is 10; the robust one is about
        // 10 - 1.645 * 2 = 6.71.
        let sol = RobustProblem::<f64>::new(1)
            .x0(&[5.0])
            .objective(|x: &[f64], _p: &[f64]| -x[0])
            .gradient(|_x: &[f64], _p: &[f64], g: &mut [f64]| g[0] = -1.0)
            .constraint_ineq(|x: &[f64], p: &[f64]| x[0] - p[0])
            .param("p", 10.0, 2.0)
            .confidence(0.95)
            .bounds(0, (-100.0, 100.0))
            .solve()
            .unwrap();

        let x = sol.x[0];
        // Clearly below the nominal bound ...
        assert!(x < 8.5, "x* = {}, expected < 8.5 (robust tightening)", x);
        // ... but not wildly conservative.
        assert!(x > 4.0, "x* = {}, should be > 4.0 (not overly conservative)", x);
    }

    #[test]
    fn test_robust_two_params() {
        // Minimise x^2 subject to x <= p1 + p2, with p1, p2 ~ 5 +/- 1 each.
        // The robust bound (~ 10 - 2 * 1.645 = 6.71) is still positive, so the
        // unconstrained minimiser x = 0 stays feasible and x* < 10 regardless.
        let sol = RobustProblem::<f64>::new(1)
            .x0(&[0.0])
            .objective(|x: &[f64], _p: &[f64]| x[0] * x[0])
            .gradient(|x: &[f64], _p: &[f64], g: &mut [f64]| g[0] = 2.0 * x[0])
            .constraint_ineq(|x: &[f64], p: &[f64]| x[0] - (p[0] + p[1]))
            .param("p1", 5.0, 1.0)
            .param("p2", 5.0, 1.0)
            .confidence(0.95)
            .solve()
            .unwrap();

        let x = sol.x[0];
        assert!(
            x < 10.0,
            "x* = {}, expected < 10 (robust tightening with two params)",
            x
        );
    }
}