use std::sync::Arc;
use numra_core::Scalar;
use crate::error::OptimError;
use crate::optim_sensitivity::compute_param_sensitivity;
use crate::problem::{ConstraintKind, OptimProblem};
use crate::types::ParamSensitivity;
/// Shared, thread-safe scalar function of the decision vector `x` and the
/// parameter vector `p`, called as `f(x, p)`.
type ParamObjFn<S> = Arc<dyn Fn(&[S], &[S]) -> S + Sync + Send>;
/// Shared, thread-safe gradient callback `g(x, p, out)` that writes its
/// result into the `out` slice.
type ParamGradFn<S> = Arc<dyn Fn(&[S], &[S], &mut [S]) + Sync + Send>;
/// A model parameter whose value is uncertain, summarised by a mean and a
/// standard deviation.
#[derive(Debug, Clone)]
pub struct UncertainParam<S: Scalar> {
    /// Identifier of the parameter; passed along as the parameter name when
    /// sensitivities are computed.
    pub name: String,
    /// Nominal (mean) value of the parameter.
    pub mean: S,
    /// Standard deviation. A non-positive value means the parameter is never
    /// moved away from its mean when building worst-case parameter vectors.
    pub std: S,
}
/// Tuning knobs for [`RobustProblem`].
#[derive(Debug, Clone)]
pub struct RobustOptions<S: Scalar> {
    /// Confidence level, strictly between 0 and 1, used to choose how many
    /// standard deviations parameters are shifted in the worst case.
    pub confidence: S,
    /// Iteration cap passed to the underlying optimiser.
    pub max_iter: usize,
}

impl<S: Scalar> Default for RobustOptions<S> {
    /// A 95% confidence level and an iteration cap of 1000.
    fn default() -> Self {
        let confidence = S::from_f64(0.95);
        let max_iter = 1000;
        Self {
            confidence,
            max_iter,
        }
    }
}
/// Outcome of [`RobustProblem::solve`].
#[derive(Debug, Clone)]
pub struct RobustResult<S: Scalar> {
    /// Decision vector returned by the underlying solver.
    pub x: Vec<S>,
    /// Objective evaluated at `x` with every parameter at its mean.
    pub f_nominal: S,
    /// Objective evaluated at `x` with the parameters shifted towards their
    /// objective-increasing worst case; equals `f_nominal` when there are no
    /// uncertain parameters.
    pub f_worst_case: S,
    /// Per-component standard deviation estimate of `x` obtained from the
    /// parameter sensitivities; all zeros when no sensitivity is available.
    pub x_std: Vec<S>,
    /// Convergence flag reported by the underlying solver.
    pub converged: bool,
    /// Status message reported by the underlying solver.
    pub message: String,
    /// Iteration count reported by the underlying solver.
    pub iterations: usize,
    /// Wall-clock duration of the whole robust solve, in seconds.
    pub wall_time_secs: f64,
    /// Parameter sensitivities of the solution, when they could be computed.
    pub sensitivity: Option<ParamSensitivity<S>>,
}
/// A user-supplied constraint `c(x, p)` together with its kind.
struct RobustConstraint<S: Scalar> {
    /// Constraint function, called as `func(x, p)`.
    func: ParamObjFn<S>,
    /// Whether the constraint is an equality or an inequality.
    kind: ConstraintKind,
}

/// Builder for an optimisation problem whose objective and constraints depend
/// on uncertain parameters. Configure it with the chained builder methods,
/// then call `solve`.
pub struct RobustProblem<S: Scalar> {
    /// Number of decision variables.
    n: usize,
    /// Starting point; required by `solve`.
    x0: Option<Vec<S>>,
    /// Optional `(lower, upper)` bound per decision variable.
    bounds: Vec<Option<(S, S)>>,
    /// Objective `f(x, p)`; required by `solve`.
    objective: Option<ParamObjFn<S>>,
    /// Optional analytic gradient of the objective with respect to `x`.
    gradient: Option<ParamGradFn<S>>,
    /// Constraints in insertion order.
    constraints: Vec<RobustConstraint<S>>,
    /// Uncertain parameters, in the order they appear in `p`.
    params: Vec<UncertainParam<S>>,
    /// Solver and robustness options.
    options: RobustOptions<S>,
}
impl<S: Scalar> RobustProblem<S> {
pub fn new(n: usize) -> Self {
Self {
n,
x0: None,
bounds: vec![None; n],
objective: None,
gradient: None,
constraints: Vec::new(),
params: Vec::new(),
options: RobustOptions::default(),
}
}
pub fn x0(mut self, x0: &[S]) -> Self {
self.x0 = Some(x0.to_vec());
self
}
pub fn bounds(mut self, i: usize, lo_hi: (S, S)) -> Self {
self.bounds[i] = Some(lo_hi);
self
}
pub fn all_bounds(mut self, bounds: &[(S, S)]) -> Self {
for (i, &b) in bounds.iter().enumerate() {
self.bounds[i] = Some(b);
}
self
}
pub fn objective<F>(mut self, f: F) -> Self
where
F: Fn(&[S], &[S]) -> S + Send + Sync + 'static,
{
self.objective = Some(Arc::new(f));
self
}
pub fn gradient<G>(mut self, g: G) -> Self
where
G: Fn(&[S], &[S], &mut [S]) + Send + Sync + 'static,
{
self.gradient = Some(Arc::new(g));
self
}
pub fn constraint_ineq<F>(mut self, f: F) -> Self
where
F: Fn(&[S], &[S]) -> S + Send + Sync + 'static,
{
self.constraints.push(RobustConstraint {
func: Arc::new(f),
kind: ConstraintKind::Inequality,
});
self
}
pub fn constraint_eq<F>(mut self, f: F) -> Self
where
F: Fn(&[S], &[S]) -> S + Send + Sync + 'static,
{
self.constraints.push(RobustConstraint {
func: Arc::new(f),
kind: ConstraintKind::Equality,
});
self
}
pub fn param(mut self, name: &str, mean: S, std: S) -> Self {
self.params.push(UncertainParam {
name: name.to_string(),
mean,
std,
});
self
}
pub fn params(mut self, params: Vec<UncertainParam<S>>) -> Self {
self.params.extend(params);
self
}
pub fn confidence(mut self, level: S) -> Self {
self.options.confidence = level;
self
}
pub fn max_iter(mut self, n: usize) -> Self {
self.options.max_iter = n;
self
}
}
impl<S: Scalar + faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField>
RobustProblem<S>
{
pub fn solve(self) -> Result<RobustResult<S>, OptimError> {
let start = std::time::Instant::now();
let obj = self.objective.ok_or(OptimError::NoObjective)?;
let x0 = self.x0.clone().ok_or(OptimError::NoInitialPoint)?;
let n = self.n;
let k = normal_quantile(self.options.confidence);
let p_nom: Vec<S> = self.params.iter().map(|p| p.mean).collect();
let p_stds: Vec<S> = self.params.iter().map(|p| p.std).collect();
let n_params = self.params.len();
let obj_for_problem = Arc::clone(&obj);
let p_nom_obj = p_nom.clone();
let mut problem = OptimProblem::new(n)
.x0(&x0)
.objective(move |x: &[S]| obj_for_problem(x, &p_nom_obj))
.max_iter(self.options.max_iter);
if let Some(grad_fn) = &self.gradient {
let grad_fn = Arc::clone(grad_fn);
let p_nom_grad = p_nom.clone();
problem = problem.gradient(move |x: &[S], g: &mut [S]| {
grad_fn(x, &p_nom_grad, g);
});
}
for (i, b) in self.bounds.iter().enumerate() {
if let Some(lo_hi) = b {
problem = problem.bounds(i, *lo_hi);
}
}
for rc in &self.constraints {
match rc.kind {
ConstraintKind::Equality => {
let func = Arc::clone(&rc.func);
let p_nom_eq = p_nom.clone();
problem = problem.constraint_eq(move |x: &[S]| func(x, &p_nom_eq));
}
ConstraintKind::Inequality => {
let p_worst =
compute_worst_case_params(&*rc.func, &x0, &p_nom, &p_stds, k, n_params);
let func = Arc::clone(&rc.func);
problem = problem.constraint_ineq(move |x: &[S]| func(x, &p_worst));
}
}
}
let result = problem.solve()?;
let x_star = result.x.clone();
let sensitivity = if !self.params.is_empty() {
let obj_sens = Arc::clone(&obj);
let bounds_sens = self.bounds.clone();
let grad_sens = self.gradient.clone();
let max_iter = self.options.max_iter;
let param_names: Vec<&str> = self.params.iter().map(|p| p.name.as_str()).collect();
let sens_result = compute_param_sensitivity(
|params: &[S]| {
let obj_inner = Arc::clone(&obj_sens);
let p_inner = params.to_vec();
let mut prob = OptimProblem::new(n)
.x0(&x_star)
.objective(move |x: &[S]| obj_inner(x, &p_inner))
.max_iter(max_iter);
if let Some(ref gf) = grad_sens {
let gf = Arc::clone(gf);
let p_g = params.to_vec();
prob = prob.gradient(move |x: &[S], g: &mut [S]| {
gf(x, &p_g, g);
});
}
for (i, b) in bounds_sens.iter().enumerate() {
if let Some(lo_hi) = b {
prob = prob.bounds(i, *lo_hi);
}
}
prob
},
&p_nom,
¶m_names,
None,
);
sens_result.ok()
} else {
None
};
let x_std = if let Some(ref sens) = sensitivity {
(0..n)
.map(|i| {
let var: S = (0..n_params)
.map(|j| {
let dxdp = sens.get(i, j);
dxdp * dxdp * p_stds[j] * p_stds[j]
})
.sum();
var.sqrt()
})
.collect()
} else {
vec![S::ZERO; n]
};
let f_nominal = obj(&x_star, &p_nom);
let f_worst_case = if !self.params.is_empty() {
let obj_worst = |_x: &[S], p: &[S]| obj(&x_star, p);
let p_worst_obj = compute_worst_case_params_for_obj(
&obj_worst, &x_star, &p_nom, &p_stds, k, n_params,
);
obj(&x_star, &p_worst_obj)
} else {
f_nominal
};
Ok(RobustResult {
x: x_star,
f_nominal,
f_worst_case,
x_std,
converged: result.converged,
message: result.message,
iterations: result.iterations,
wall_time_secs: start.elapsed().as_secs_f64(),
sensitivity,
})
}
}
/// Builds a worst-case parameter vector for `g`, linearised at `x0`.
///
/// The result starts as a copy of `p_nom`. For each of the first `n_params`
/// parameters whose standard deviation is positive, `g(x0, ·)` is compared at
/// the mean plus and minus a small step `h = 1e-8 * (1 + |mean|)` (all other
/// parameters nominal). The parameter is then set to `mean + k * std` if the
/// plus side is larger, and to `mean - k * std` otherwise (ties included).
/// Parameters with a non-positive standard deviation keep their mean.
fn compute_worst_case_params<S: Scalar>(
    g: &dyn Fn(&[S], &[S]) -> S,
    x0: &[S],
    p_nom: &[S],
    p_stds: &[S],
    k: S,
    n_params: usize,
) -> Vec<S> {
    let rel_step = S::from_f64(1e-8);
    let mut worst = p_nom.to_vec();
    for idx in 0..n_params {
        let sigma = p_stds[idx];
        if sigma <= S::ZERO {
            // No spread: leave the parameter at its mean.
            continue;
        }
        let mean = p_nom[idx];
        let h = rel_step * (S::ONE + mean.abs());
        // `g` at `x0` with only parameter `idx` replaced by `value`.
        let g_with = |value: S| {
            let mut probe = p_nom.to_vec();
            probe[idx] = value;
            g(x0, &probe)
        };
        let above = g_with(mean + h);
        let below = g_with(mean - h);
        let margin = k * sigma;
        worst[idx] = if above > below {
            mean + margin
        } else {
            mean - margin
        };
    }
    worst
}
fn compute_worst_case_params_for_obj<S: Scalar>(
_f_wrapper: &dyn Fn(&[S], &[S]) -> S,
x_star: &[S],
p_nom: &[S],
p_stds: &[S],
k: S,
n_params: usize,
) -> Vec<S> {
let mut p_worst = p_nom.to_vec();
let fd_eps = S::from_f64(1e-8);
let f_at = |p: &[S]| _f_wrapper(x_star, p);
for j in 0..n_params {
if p_stds[j] <= S::ZERO {
continue;
}
let h = fd_eps * (S::ONE + p_nom[j].abs());
let mut p_plus = p_nom.to_vec();
p_plus[j] += h;
let f_plus = f_at(&p_plus);
let mut p_minus = p_nom.to_vec();
p_minus[j] -= h;
let f_minus = f_at(&p_minus);
if f_plus > f_minus {
p_worst[j] = p_nom[j] + k * p_stds[j];
} else {
p_worst[j] = p_nom[j] - k * p_stds[j];
}
}
p_worst
}
/// Approximates the standard normal quantile (inverse CDF) at probability `p`.
///
/// Uses the rational approximation of Abramowitz & Stegun, formula 26.2.23,
/// whose absolute error is below about `4.5e-4`. Inputs within `1e-15` of 0.5
/// return exactly 0. The result is antisymmetric: `q(p) = -q(1 - p)`.
///
/// # Panics
///
/// Panics if `p` is not strictly between 0 and 1 (this includes NaN).
pub fn normal_quantile<S: Scalar>(p: S) -> S {
    assert!(
        p > S::ZERO && p < S::ONE,
        "p must be in (0, 1), got {}",
        p.to_f64()
    );
    // The approximation is not exactly zero at the median, so pin that value.
    if (p - S::HALF).abs() < S::from_f64(1e-15) {
        return S::ZERO;
    }
    // Work with the smaller tail probability directly. Routing a lower-tail `p`
    // through `1 - p` rounds tiny probabilities (below ~1.1e-16 in f64) to
    // exactly 1, which used to trip the range assertion on valid input, and
    // loses precision for slightly larger ones.
    let lower_tail = p < S::HALF;
    let tail = if lower_tail { p } else { S::ONE - p };
    let t = (S::from_f64(-2.0) * tail.ln()).sqrt();
    let c0 = S::from_f64(2.515517);
    let c1 = S::from_f64(0.802853);
    let c2 = S::from_f64(0.010328);
    let d1 = S::from_f64(1.432788);
    let d2 = S::from_f64(0.189269);
    let d3 = S::from_f64(0.001308);
    let z = t - (c0 + c1 * t + c2 * t * t) / (S::ONE + d1 * t + d2 * t * t + d3 * t * t * t);
    if lower_tail {
        -z
    } else {
        z
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normal_quantile() {
        assert!(
            normal_quantile(0.5_f64).abs() < 1e-10,
            "q(0.5) = {}, expected 0.0",
            normal_quantile(0.5_f64)
        );
        let q95 = normal_quantile(0.95_f64);
        assert!(
            (q95 - 1.6449).abs() < 1e-3,
            "q(0.95) = {}, expected ~1.6449",
            q95
        );
        let q99 = normal_quantile(0.99_f64);
        assert!(
            (q99 - 2.3263).abs() < 1e-3,
            "q(0.99) = {}, expected ~2.3263",
            q99
        );
        let q975 = normal_quantile(0.975_f64);
        assert!(
            (q975 - 1.9600).abs() < 1e-3,
            "q(0.975) = {}, expected ~1.9600",
            q975
        );
    }

    #[test]
    fn test_robust_unconstrained() {
        // x* tracks p directly, so dx*/dp = 1 and x_std should match std(p) = 1.
        let result = RobustProblem::<f64>::new(1)
            .x0(&[0.0])
            .objective(|x: &[f64], p: &[f64]| (x[0] - p[0]) * (x[0] - p[0]))
            .gradient(|x: &[f64], p: &[f64], g: &mut [f64]| {
                g[0] = 2.0 * (x[0] - p[0]);
            })
            .param("p", 5.0, 1.0)
            .solve()
            .unwrap();
        assert!(
            (result.x[0] - 5.0).abs() < 0.1,
            "x* = {}, expected ~5.0",
            result.x[0]
        );
        assert!(
            (result.x_std[0] - 1.0).abs() < 0.3,
            "x_std = {}, expected ~1.0",
            result.x_std[0]
        );
        assert!(result.converged, "solver should converge");
    }

    #[test]
    fn test_robust_constraint_tightening() {
        // Maximise x subject to x - p <= 0: nominally x* = 10, robustly
        // about 10 - 1.645 * 2.
        let result = RobustProblem::<f64>::new(1)
            .x0(&[5.0])
            .objective(|x: &[f64], _p: &[f64]| -x[0])
            .gradient(|_x: &[f64], _p: &[f64], g: &mut [f64]| {
                g[0] = -1.0;
            })
            .constraint_ineq(|x: &[f64], p: &[f64]| x[0] - p[0])
            .param("p", 10.0, 2.0)
            .confidence(0.95)
            .bounds(0, (-100.0, 100.0))
            .solve()
            .unwrap();
        assert!(
            result.x[0] < 8.5,
            "x* = {}, expected < 8.5 (robust tightening)",
            result.x[0]
        );
        assert!(
            result.x[0] > 4.0,
            "x* = {}, should be > 4.0 (not overly conservative)",
            result.x[0]
        );
    }

    #[test]
    fn test_robust_two_params() {
        // The objective pulls x towards 20, so the constraint x <= p1 + p2 is
        // active. Nominally x* = 10; the robust solution should sit near
        // 10 - 2 * 1.645 * 1.0 ~= 6.7. A bare `x^2` objective would give x* = 0
        // and never exercise the tightening.
        let result = RobustProblem::<f64>::new(1)
            .x0(&[0.0])
            .objective(|x: &[f64], _p: &[f64]| (x[0] - 20.0) * (x[0] - 20.0))
            .gradient(|x: &[f64], _p: &[f64], g: &mut [f64]| {
                g[0] = 2.0 * (x[0] - 20.0);
            })
            .constraint_ineq(|x: &[f64], p: &[f64]| x[0] - (p[0] + p[1]))
            .param("p1", 5.0, 1.0)
            .param("p2", 5.0, 1.0)
            .confidence(0.95)
            .bounds(0, (-100.0, 100.0))
            .solve()
            .unwrap();
        assert!(
            result.x[0] < 8.5,
            "x* = {}, expected < 8.5 (robust tightening with two params)",
            result.x[0]
        );
        assert!(
            result.x[0] > 4.0,
            "x* = {}, should be > 4.0 (not overly conservative)",
            result.x[0]
        );
    }
}