1use std::sync::Arc;
27
28use numra_core::Scalar;
29
30use crate::error::OptimError;
31use crate::optim_sensitivity::compute_param_sensitivity;
32use crate::problem::{ConstraintKind, OptimProblem};
33use crate::types::ParamSensitivity;
34
/// Shared, thread-safe scalar callback `(x, p) -> value`.
///
/// Used for the objective and for every constraint: `x` holds the decision
/// variables and `p` the uncertain parameters (one entry per parameter, in the
/// order they were added to the problem).
type ParamObjFn<S> = Arc<dyn Fn(&[S], &[S]) -> S + Send + Sync>;

/// Shared, thread-safe gradient callback `(x, p, grad)`.
///
/// Writes the gradient of the objective with respect to `x` into `grad`.
type ParamGradFn<S> = Arc<dyn Fn(&[S], &[S], &mut [S]) + Send + Sync>;
39
40#[derive(Clone, Debug)]
46pub struct UncertainParam<S: Scalar> {
47 pub name: String,
49 pub mean: S,
51 pub std: S,
53}
54
55#[derive(Clone, Debug)]
57pub struct RobustOptions<S: Scalar> {
58 pub confidence: S,
60 pub max_iter: usize,
62}
63
64impl<S: Scalar> Default for RobustOptions<S> {
65 fn default() -> Self {
66 Self {
67 confidence: S::from_f64(0.95),
68 max_iter: 1000,
69 }
70 }
71}
72
73#[derive(Clone, Debug)]
75pub struct RobustResult<S: Scalar> {
76 pub x: Vec<S>,
78 pub f_nominal: S,
80 pub f_worst_case: S,
82 pub x_std: Vec<S>,
84 pub converged: bool,
86 pub message: String,
88 pub iterations: usize,
90 pub wall_time_secs: f64,
92 pub sensitivity: Option<ParamSensitivity<S>>,
94}
95
96struct RobustConstraint<S: Scalar> {
102 func: ParamObjFn<S>,
103 kind: ConstraintKind,
104}
105
106pub struct RobustProblem<S: Scalar> {
113 n: usize,
114 x0: Option<Vec<S>>,
115 bounds: Vec<Option<(S, S)>>,
116 objective: Option<ParamObjFn<S>>,
117 gradient: Option<ParamGradFn<S>>,
118 constraints: Vec<RobustConstraint<S>>,
119 params: Vec<UncertainParam<S>>,
120 options: RobustOptions<S>,
121}
122
123impl<S: Scalar> RobustProblem<S> {
124 pub fn new(n: usize) -> Self {
126 Self {
127 n,
128 x0: None,
129 bounds: vec![None; n],
130 objective: None,
131 gradient: None,
132 constraints: Vec::new(),
133 params: Vec::new(),
134 options: RobustOptions::default(),
135 }
136 }
137
138 pub fn x0(mut self, x0: &[S]) -> Self {
140 self.x0 = Some(x0.to_vec());
141 self
142 }
143
144 pub fn bounds(mut self, i: usize, lo_hi: (S, S)) -> Self {
146 self.bounds[i] = Some(lo_hi);
147 self
148 }
149
150 pub fn all_bounds(mut self, bounds: &[(S, S)]) -> Self {
152 for (i, &b) in bounds.iter().enumerate() {
153 self.bounds[i] = Some(b);
154 }
155 self
156 }
157
158 pub fn objective<F>(mut self, f: F) -> Self
160 where
161 F: Fn(&[S], &[S]) -> S + Send + Sync + 'static,
162 {
163 self.objective = Some(Arc::new(f));
164 self
165 }
166
167 pub fn gradient<G>(mut self, g: G) -> Self
171 where
172 G: Fn(&[S], &[S], &mut [S]) + Send + Sync + 'static,
173 {
174 self.gradient = Some(Arc::new(g));
175 self
176 }
177
178 pub fn constraint_ineq<F>(mut self, f: F) -> Self
180 where
181 F: Fn(&[S], &[S]) -> S + Send + Sync + 'static,
182 {
183 self.constraints.push(RobustConstraint {
184 func: Arc::new(f),
185 kind: ConstraintKind::Inequality,
186 });
187 self
188 }
189
190 pub fn constraint_eq<F>(mut self, f: F) -> Self
192 where
193 F: Fn(&[S], &[S]) -> S + Send + Sync + 'static,
194 {
195 self.constraints.push(RobustConstraint {
196 func: Arc::new(f),
197 kind: ConstraintKind::Equality,
198 });
199 self
200 }
201
202 pub fn param(mut self, name: &str, mean: S, std: S) -> Self {
204 self.params.push(UncertainParam {
205 name: name.to_string(),
206 mean,
207 std,
208 });
209 self
210 }
211
212 pub fn params(mut self, params: Vec<UncertainParam<S>>) -> Self {
214 self.params.extend(params);
215 self
216 }
217
218 pub fn confidence(mut self, level: S) -> Self {
220 self.options.confidence = level;
221 self
222 }
223
224 pub fn max_iter(mut self, n: usize) -> Self {
226 self.options.max_iter = n;
227 self
228 }
229}
230
231impl<S: Scalar + faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField>
232 RobustProblem<S>
233{
234 pub fn solve(self) -> Result<RobustResult<S>, OptimError> {
241 let start = std::time::Instant::now();
242
243 let obj = self.objective.ok_or(OptimError::NoObjective)?;
244 let x0 = self.x0.clone().ok_or(OptimError::NoInitialPoint)?;
245 let n = self.n;
246
247 let k = normal_quantile(self.options.confidence);
249
250 let p_nom: Vec<S> = self.params.iter().map(|p| p.mean).collect();
252 let p_stds: Vec<S> = self.params.iter().map(|p| p.std).collect();
253 let n_params = self.params.len();
254
255 let obj_for_problem = Arc::clone(&obj);
258 let p_nom_obj = p_nom.clone();
259 let mut problem = OptimProblem::new(n)
260 .x0(&x0)
261 .objective(move |x: &[S]| obj_for_problem(x, &p_nom_obj))
262 .max_iter(self.options.max_iter);
263
264 if let Some(grad_fn) = &self.gradient {
266 let grad_fn = Arc::clone(grad_fn);
267 let p_nom_grad = p_nom.clone();
268 problem = problem.gradient(move |x: &[S], g: &mut [S]| {
269 grad_fn(x, &p_nom_grad, g);
270 });
271 }
272
273 for (i, b) in self.bounds.iter().enumerate() {
275 if let Some(lo_hi) = b {
276 problem = problem.bounds(i, *lo_hi);
277 }
278 }
279
280 for rc in &self.constraints {
282 match rc.kind {
283 ConstraintKind::Equality => {
284 let func = Arc::clone(&rc.func);
286 let p_nom_eq = p_nom.clone();
287 problem = problem.constraint_eq(move |x: &[S]| func(x, &p_nom_eq));
288 }
289 ConstraintKind::Inequality => {
290 let p_worst =
292 compute_worst_case_params(&*rc.func, &x0, &p_nom, &p_stds, k, n_params);
293 let func = Arc::clone(&rc.func);
294 problem = problem.constraint_ineq(move |x: &[S]| func(x, &p_worst));
295 }
296 }
297 }
298
299 let result = problem.solve()?;
301 let x_star = result.x.clone();
302
303 let sensitivity = if !self.params.is_empty() {
305 let obj_sens = Arc::clone(&obj);
306 let bounds_sens = self.bounds.clone();
307 let grad_sens = self.gradient.clone();
308 let max_iter = self.options.max_iter;
309 let param_names: Vec<&str> = self.params.iter().map(|p| p.name.as_str()).collect();
310
311 let sens_result = compute_param_sensitivity(
312 |params: &[S]| {
313 let obj_inner = Arc::clone(&obj_sens);
314 let p_inner = params.to_vec();
315 let mut prob = OptimProblem::new(n)
316 .x0(&x_star)
317 .objective(move |x: &[S]| obj_inner(x, &p_inner))
318 .max_iter(max_iter);
319
320 if let Some(ref gf) = grad_sens {
321 let gf = Arc::clone(gf);
322 let p_g = params.to_vec();
323 prob = prob.gradient(move |x: &[S], g: &mut [S]| {
324 gf(x, &p_g, g);
325 });
326 }
327
328 for (i, b) in bounds_sens.iter().enumerate() {
329 if let Some(lo_hi) = b {
330 prob = prob.bounds(i, *lo_hi);
331 }
332 }
333 prob
334 },
335 &p_nom,
336 ¶m_names,
337 None,
338 );
339
340 sens_result.ok()
341 } else {
342 None
343 };
344
345 let x_std = if let Some(ref sens) = sensitivity {
347 (0..n)
348 .map(|i| {
349 let var: S = (0..n_params)
350 .map(|j| {
351 let dxdp = sens.get(i, j);
352 dxdp * dxdp * p_stds[j] * p_stds[j]
353 })
354 .sum();
355 var.sqrt()
356 })
357 .collect()
358 } else {
359 vec![S::ZERO; n]
360 };
361
362 let f_nominal = obj(&x_star, &p_nom);
364
365 let f_worst_case = if !self.params.is_empty() {
367 let obj_worst = |_x: &[S], p: &[S]| obj(&x_star, p);
368 let p_worst_obj = compute_worst_case_params_for_obj(
369 &obj_worst, &x_star, &p_nom, &p_stds, k, n_params,
370 );
371 obj(&x_star, &p_worst_obj)
372 } else {
373 f_nominal
374 };
375
376 Ok(RobustResult {
377 x: x_star,
378 f_nominal,
379 f_worst_case,
380 x_std,
381 converged: result.converged,
382 message: result.message,
383 iterations: result.iterations,
384 wall_time_secs: start.elapsed().as_secs_f64(),
385 sensitivity,
386 })
387 }
388}
389
390fn compute_worst_case_params<S: Scalar>(
401 g: &dyn Fn(&[S], &[S]) -> S,
402 x0: &[S],
403 p_nom: &[S],
404 p_stds: &[S],
405 k: S,
406 n_params: usize,
407) -> Vec<S> {
408 let mut p_worst = p_nom.to_vec();
409 let fd_eps = S::from_f64(1e-8);
410
411 for j in 0..n_params {
412 if p_stds[j] <= S::ZERO {
413 continue;
414 }
415 let h = fd_eps * (S::ONE + p_nom[j].abs());
416
417 let mut p_plus = p_nom.to_vec();
418 p_plus[j] += h;
419 let g_plus = g(x0, &p_plus);
420
421 let mut p_minus = p_nom.to_vec();
422 p_minus[j] -= h;
423 let g_minus = g(x0, &p_minus);
424
425 if g_plus > g_minus {
427 p_worst[j] = p_nom[j] + k * p_stds[j];
428 } else {
429 p_worst[j] = p_nom[j] - k * p_stds[j];
430 }
431 }
432
433 p_worst
434}
435
436fn compute_worst_case_params_for_obj<S: Scalar>(
438 _f_wrapper: &dyn Fn(&[S], &[S]) -> S,
439 x_star: &[S],
440 p_nom: &[S],
441 p_stds: &[S],
442 k: S,
443 n_params: usize,
444) -> Vec<S> {
445 let mut p_worst = p_nom.to_vec();
446 let fd_eps = S::from_f64(1e-8);
447
448 let f_at = |p: &[S]| _f_wrapper(x_star, p);
451
452 for j in 0..n_params {
453 if p_stds[j] <= S::ZERO {
454 continue;
455 }
456 let h = fd_eps * (S::ONE + p_nom[j].abs());
457
458 let mut p_plus = p_nom.to_vec();
459 p_plus[j] += h;
460 let f_plus = f_at(&p_plus);
461
462 let mut p_minus = p_nom.to_vec();
463 p_minus[j] -= h;
464 let f_minus = f_at(&p_minus);
465
466 if f_plus > f_minus {
467 p_worst[j] = p_nom[j] + k * p_stds[j];
468 } else {
469 p_worst[j] = p_nom[j] - k * p_stds[j];
470 }
471 }
472
473 p_worst
474}
475
476pub fn normal_quantile<S: Scalar>(p: S) -> S {
496 assert!(
497 p > S::ZERO && p < S::ONE,
498 "p must be in (0, 1), got {}",
499 p.to_f64()
500 );
501
502 if (p - S::HALF).abs() < S::from_f64(1e-15) {
503 return S::ZERO;
504 }
505
506 if p < S::HALF {
507 return -normal_quantile(S::ONE - p);
508 }
509
510 let t = (S::from_f64(-2.0) * (S::ONE - p).ln()).sqrt();
512
513 let c0 = S::from_f64(2.515517);
514 let c1 = S::from_f64(0.802853);
515 let c2 = S::from_f64(0.010328);
516 let d1 = S::from_f64(1.432788);
517 let d2 = S::from_f64(0.189269);
518 let d3 = S::from_f64(0.001308);
519
520 t - (c0 + c1 * t + c2 * t * t) / (S::ONE + d1 * t + d2 * t * t + d3 * t * t * t)
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normal_quantile() {
        // Reference values of the standard normal quantile function.
        assert!(
            normal_quantile(0.5_f64).abs() < 1e-10,
            "q(0.5) = {}, expected 0.0",
            normal_quantile(0.5_f64)
        );

        let q95 = normal_quantile(0.95_f64);
        assert!(
            (q95 - 1.6449).abs() < 1e-3,
            "q(0.95) = {}, expected ~1.6449",
            q95
        );

        let q99 = normal_quantile(0.99_f64);
        assert!(
            (q99 - 2.3263).abs() < 1e-3,
            "q(0.99) = {}, expected ~2.3263",
            q99
        );

        let q975 = normal_quantile(0.975_f64);
        assert!(
            (q975 - 1.9600).abs() < 1e-3,
            "q(0.975) = {}, expected ~1.9600",
            q975
        );
    }

    #[test]
    fn test_normal_quantile_antisymmetric() {
        // q(p) = -q(1 - p), covering both the tail and the central regions.
        for &p in &[0.001_f64, 0.01, 0.1, 0.3, 0.45] {
            let lo = normal_quantile(p);
            let hi = normal_quantile(1.0 - p);
            assert!(
                (lo + hi).abs() < 1e-9,
                "q({}) = {}, q({}) = {}",
                p,
                lo,
                1.0 - p,
                hi
            );
            assert!(lo < 0.0, "q({}) = {}, expected negative", p, lo);
        }
    }

    #[test]
    #[should_panic(expected = "p must be in (0, 1)")]
    fn test_normal_quantile_rejects_out_of_range() {
        let _ = normal_quantile(1.0_f64);
    }

    #[test]
    fn test_robust_unconstrained() {
        // min (x - p)^2 with p ~ N(5, 1): x* tracks p one-for-one, so
        // dx/dp = 1 and x_std ~ std(p) = 1.
        let result = RobustProblem::<f64>::new(1)
            .x0(&[0.0])
            .objective(|x: &[f64], p: &[f64]| (x[0] - p[0]) * (x[0] - p[0]))
            .gradient(|x: &[f64], p: &[f64], g: &mut [f64]| {
                g[0] = 2.0 * (x[0] - p[0]);
            })
            .param("p", 5.0, 1.0)
            .solve()
            .unwrap();

        assert!(
            (result.x[0] - 5.0).abs() < 0.1,
            "x* = {}, expected ~5.0",
            result.x[0]
        );
        assert!(
            (result.x_std[0] - 1.0).abs() < 0.3,
            "x_std = {}, expected ~1.0",
            result.x_std[0]
        );
        assert!(result.converged, "solver should converge");
    }

    #[test]
    fn test_robust_worst_case_objective() {
        // At x* ~ 5, moving p by 1.645 sigma in either direction raises
        // (x - p)^2 to ~2.7, well above the nominal value.
        let result = RobustProblem::<f64>::new(1)
            .x0(&[0.0])
            .objective(|x: &[f64], p: &[f64]| (x[0] - p[0]) * (x[0] - p[0]))
            .gradient(|x: &[f64], p: &[f64], g: &mut [f64]| {
                g[0] = 2.0 * (x[0] - p[0]);
            })
            .param("p", 5.0, 1.0)
            .solve()
            .unwrap();

        assert!(
            result.f_worst_case > result.f_nominal,
            "f_worst_case = {}, f_nominal = {}",
            result.f_worst_case,
            result.f_nominal
        );
        assert!(
            result.f_worst_case > 1.0,
            "f_worst_case = {}, expected ~2.7",
            result.f_worst_case
        );
    }

    #[test]
    fn test_robust_constraint_tightening() {
        // Maximize x subject to x - p <= 0 with p ~ N(10, 2). At 95%
        // confidence the constraint is tightened to x <= 10 - 1.645 * 2 ~ 6.71.
        let result = RobustProblem::<f64>::new(1)
            .x0(&[5.0])
            .objective(|x: &[f64], _p: &[f64]| -x[0])
            .gradient(|_x: &[f64], _p: &[f64], g: &mut [f64]| {
                g[0] = -1.0;
            })
            .constraint_ineq(|x: &[f64], p: &[f64]| x[0] - p[0])
            .param("p", 10.0, 2.0)
            .confidence(0.95)
            .bounds(0, (-100.0, 100.0))
            .solve()
            .unwrap();

        assert!(
            result.x[0] < 8.5,
            "x* = {}, expected < 8.5 (robust tightening)",
            result.x[0]
        );
        assert!(
            result.x[0] > 4.0,
            "x* = {}, should be > 4.0 (not overly conservative)",
            result.x[0]
        );
    }

    #[test]
    fn test_robust_two_params() {
        let result = RobustProblem::<f64>::new(1)
            .x0(&[0.0])
            .objective(|x: &[f64], _p: &[f64]| x[0] * x[0])
            .gradient(|x: &[f64], _p: &[f64], g: &mut [f64]| {
                g[0] = 2.0 * x[0];
            })
            .constraint_ineq(|x: &[f64], p: &[f64]| x[0] - (p[0] + p[1]))
            .param("p1", 5.0, 1.0)
            .param("p2", 5.0, 1.0)
            .confidence(0.95)
            .solve()
            .unwrap();

        assert!(
            result.x[0] < 10.0,
            "x* = {}, expected < 10 (robust tightening with two params)",
            result.x[0]
        );
    }

    #[test]
    #[should_panic]
    fn test_bounds_index_out_of_range_panics() {
        let _ = RobustProblem::<f64>::new(1).bounds(1, (0.0, 1.0));
    }
}