use super::types::{DroConfig, DroResult};
use crate::error::{OptimizeError, OptimizeResult};
/// Minimal linear congruential generator used only by the tests below, so
/// sample data stays deterministic without a `rand` dependency.
#[cfg(test)]
struct Lcg {
    state: u64,
}
#[cfg(test)]
impl Lcg {
fn new(seed: u64) -> Self {
Self {
state: seed.wrapping_add(1),
}
}
    /// Advances the state with Knuth's MMIX LCG constants and maps the top
    /// 53 bits of the new state to a float in [0, 1).
    fn next_f64(&mut self) -> f64 {
self.state = self
.state
.wrapping_mul(6_364_136_223_846_793_005)
.wrapping_add(1_442_695_040_888_963_407);
((self.state >> 11) as f64) / ((1u64 << 53) as f64)
}
}
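
/// Empirical Conditional Value-at-Risk estimator at level `alpha`: the mean
/// of the worst `(1 - alpha)` fraction of losses, computed through the
/// Rockafellar-Uryasev variational form
/// `CVaR_alpha(L) = min_t { t + E[(L - t)_+] / (1 - alpha) }`.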
#[derive(Debug, Clone)]
pub struct CvarEstimator {
pub alpha: f64,
}
impl CvarEstimator {
pub fn new(alpha: f64) -> OptimizeResult<Self> {
if !(0.0 < alpha && alpha < 1.0) {
return Err(OptimizeError::InvalidParameter(format!(
"alpha must be in (0, 1), got {alpha}"
)));
}
Ok(Self { alpha })
}
pub fn compute_cvar(&self, losses: &[f64]) -> OptimizeResult<f64> {
if losses.is_empty() {
return Err(OptimizeError::InvalidParameter(
"losses must be non-empty".into(),
));
}
        // Ascending sort; NaNs compare as equal instead of panicking.
        let mut sorted = losses.to_vec();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
Ok(cvar_sorted(&sorted, self.alpha))
}
pub fn cvar_gradient(&self, losses: &[f64]) -> OptimizeResult<(f64, Vec<f64>)> {
if losses.is_empty() {
return Err(OptimizeError::InvalidParameter(
"losses must be non-empty".into(),
));
}
let n = losses.len();
let mut sorted = losses.to_vec();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let cvar = cvar_sorted(&sorted, self.alpha);
        let t_star = quantile_sorted(&sorted, self.alpha);
        // Danskin subgradient: every loss strictly above the empirical VaR
        // contributes 1 / (n * (1 - alpha)). A loss exactly at the VaR may
        // take any weight in [0, scale]; this implementation picks zero.
        let scale = 1.0 / (n as f64 * (1.0 - self.alpha));
        let grad: Vec<f64> = losses
            .iter()
            .map(|&l| if l > t_star { scale } else { 0.0 })
            .collect();
Ok((cvar, grad))
}
}
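
/// Evaluates the Rockafellar-Uryasev objective `t + E[(L - t)_+] / (1 - alpha)`
/// at every order statistic and returns the minimum. The objective is piecewise
/// linear and convex in `t` with breakpoints at the sample points, so this
/// O(n^2) scan finds the exact minimum.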
fn cvar_sorted(sorted: &[f64], alpha: f64) -> f64 {
let n = sorted.len();
if n == 0 {
return 0.0;
}
let scale = 1.0 / ((1.0 - alpha) * n as f64);
    sorted
        .iter()
        .map(|&t| {
            let excess: f64 = sorted.iter().map(|&l| (l - t).max(0.0)).sum();
            t + scale * excess
        })
        .fold(f64::INFINITY, f64::min)
}
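
/// Empirical `alpha`-quantile (VaR) of an ascending-sorted slice, taken at the
/// floor index and clamped to the last element.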
fn quantile_sorted(sorted: &[f64], alpha: f64) -> f64 {
let n = sorted.len();
if n == 0 {
return 0.0;
}
let idx = ((alpha * n as f64).floor() as usize).min(n - 1);
sorted[idx]
}
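
/// Distributionally robust CVaR minimizer: solves
/// `min_w CVaR_alpha(loss(w, x)) + radius * ||w||_2` by subgradient descent.
/// The L2 penalty scaled by the ambiguity radius is a commonly used
/// regularization surrogate for Wasserstein-type robustness around the
/// empirical samples.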
pub struct CvarDro<'a> {
config: DroConfig,
alpha: f64,
loss_fn: &'a dyn Fn(&[f64], &[f64]) -> f64,
grad_fn: &'a dyn Fn(&[f64], &[f64]) -> Vec<f64>,
}
impl<'a> CvarDro<'a> {
pub fn new(
config: DroConfig,
alpha: f64,
loss_fn: &'a dyn Fn(&[f64], &[f64]) -> f64,
grad_fn: &'a dyn Fn(&[f64], &[f64]) -> Vec<f64>,
) -> OptimizeResult<Self> {
config.validate()?;
if !(0.0 < alpha && alpha < 1.0) {
return Err(OptimizeError::InvalidParameter(format!(
"alpha must be in (0, 1), got {alpha}"
)));
}
Ok(Self {
config,
alpha,
loss_fn,
grad_fn,
})
}
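
    /// Minimizes the regularized CVaR objective starting from uniform weights,
    /// using a diminishing step `c / sqrt(t)` unless `config.step_size` pins a
    /// constant step. The best iterate seen is tracked, and the loop exits
    /// early once the subgradient norm falls below `config.tol`.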
pub fn solve(&self, n_features: usize, samples: &[Vec<f64>]) -> OptimizeResult<DroResult> {
if n_features == 0 {
return Err(OptimizeError::InvalidParameter(
"n_features must be positive".into(),
));
}
if samples.is_empty() {
return Err(OptimizeError::InvalidParameter(
"samples must be non-empty".into(),
));
}
let eps = self.config.radius;
let estimator = CvarEstimator::new(self.alpha)?;
        // Start from uniform weights and track the best iterate; `c` scales
        // the default diminishing step schedule eta_t = c / sqrt(t).
        let mut w: Vec<f64> = vec![1.0 / n_features as f64; n_features];
        let mut best_w = w.clone();
        let mut best_obj = f64::INFINITY;
        let c = 0.3_f64;
for t in 1..=self.config.max_iter {
            let losses: Vec<f64> = samples.iter().map(|s| (self.loss_fn)(&w, s)).collect();
            let cvar_val = estimator.compute_cvar(&losses)?;
            // Regularized objective: CVaR plus the norm penalty.
            let wn = w.iter().map(|x| x * x).sum::<f64>().sqrt();
            let obj = cvar_val + eps * wn;
if obj < best_obj {
best_obj = obj;
best_w = w.clone();
}
let (_, loss_grad) = estimator.cvar_gradient(&losses)?;
            // Chain rule: accumulate (dCVaR/dloss_i) * (dloss_i/dw), skipping
            // samples that carry zero CVaR weight.
            let mut param_grad: Vec<f64> = vec![0.0; n_features];
for (i, sample) in samples.iter().enumerate() {
let lg = loss_grad[i];
if lg.abs() < 1e-14 {
continue;
}
let g = (self.grad_fn)(&w, sample);
for (pg, gi) in param_grad.iter_mut().zip(g.iter()) {
*pg += lg * gi;
}
}
            // Subgradient of eps * ||w||_2, with the norm guarded away from zero.
            let wn_safe = wn.max(1e-12);
for (pg, &wi) in param_grad.iter_mut().zip(w.iter()) {
*pg += eps * wi / wn_safe;
}
let grad_norm = param_grad.iter().map(|g| g * g).sum::<f64>().sqrt();
            if grad_norm < self.config.tol {
                // Report the best iterate seen so far; recompute its CVaR so
                // `primal_obj` stays consistent with `optimal_weights` even
                // when the final iterate is not the incumbent.
                let best_losses: Vec<f64> =
                    samples.iter().map(|s| (self.loss_fn)(&best_w, s)).collect();
                let best_cvar = estimator.compute_cvar(&best_losses)?;
                return Ok(DroResult {
                    optimal_weights: best_w,
                    worst_case_loss: best_obj,
                    primal_obj: best_cvar,
                    n_iter: t,
                    converged: true,
                });
            }
let eta = self
.config
.step_size
.unwrap_or_else(|| c / (t as f64).sqrt());
for (wi, gi) in w.iter_mut().zip(param_grad.iter()) {
*wi -= eta * gi;
}
        }
        // Iteration budget exhausted: report the best iterate and its objective.
        let losses: Vec<f64> = samples.iter().map(|s| (self.loss_fn)(&best_w, s)).collect();
        let final_cvar = estimator.compute_cvar(&losses)?;
        let final_wn = best_w.iter().map(|x| x * x).sum::<f64>().sqrt();
Ok(DroResult {
optimal_weights: best_w,
worst_case_loss: final_cvar + eps * final_wn,
primal_obj: final_cvar,
n_iter: self.config.max_iter,
converged: false,
})
}
}
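
/// Convenience wrapper that runs CVaR-DRO with a linear loss
/// `loss(w, x) = -<w, x>` (the negative return of weights `w` on scenario `x`).
/// Note that `radius` is consulted only when `config` is `None`; an explicit
/// `config` carries its own radius.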
pub fn solve_cvar_dro(
n_features: usize,
samples: &[Vec<f64>],
alpha: f64,
radius: f64,
config: Option<DroConfig>,
) -> OptimizeResult<DroResult> {
let cfg = config.unwrap_or_else(|| DroConfig {
radius,
n_samples: samples.len(),
max_iter: 500,
tol: 1e-6,
step_size: None,
});
    // Linear loss: the negative inner product <w, x>; its gradient in w is -x.
    let loss_fn = |w: &[f64], x: &[f64]| -> f64 {
        -w.iter().zip(x.iter()).map(|(wi, xi)| wi * xi).sum::<f64>()
    };
    let grad_fn = |_w: &[f64], x: &[f64]| -> Vec<f64> { x.iter().map(|xi| -xi).collect() };
let solver = CvarDro::new(cfg, alpha, &loss_fn, &grad_fn)?;
solver.solve(n_features, samples)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cvar_computation() {
let losses: Vec<f64> = (0..10).map(|i| i as f64).collect();
let est = CvarEstimator::new(0.9).expect("valid alpha");
let cvar = est.compute_cvar(&losses).expect("cvar ok");
assert!(
(cvar - 9.0).abs() < 0.5,
"CVaR_0.9 of [0..9] should be ~9, got {cvar}"
);
}
#[test]
fn test_cvar_symmetry_ge_mean() {
let losses = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let mean = losses.iter().sum::<f64>() / losses.len() as f64;
let est = CvarEstimator::new(0.8).expect("valid");
let cvar = est.compute_cvar(&losses).expect("cvar ok");
assert!(
cvar >= mean - 1e-9,
"CVaR should be >= mean ({mean}), got {cvar}"
);
}
#[test]
fn test_cvar_alpha_invalid_errors() {
assert!(CvarEstimator::new(0.0).is_err());
assert!(CvarEstimator::new(1.0).is_err());
assert!(CvarEstimator::new(-0.1).is_err());
assert!(CvarEstimator::new(1.5).is_err());
}
#[test]
fn test_cvar_at_alpha_near_one_gives_max() {
let losses = vec![1.0, 2.0, 5.0, 10.0, 3.0];
let est = CvarEstimator::new(0.99).expect("valid");
let cvar = est.compute_cvar(&losses).expect("cvar ok");
let max_loss = losses.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
assert!(
cvar >= max_loss - 1e-6,
"CVaR at alpha≈1 should be >= max loss ({max_loss}), got {cvar}"
);
}
#[test]
fn test_cvar_at_alpha_near_zero_gives_mean() {
let losses = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let mean = losses.iter().sum::<f64>() / losses.len() as f64;
let est = CvarEstimator::new(0.01).expect("valid");
let cvar = est.compute_cvar(&losses).expect("cvar ok");
assert!(
(cvar - mean).abs() < 0.5,
"CVaR at alpha≈0 should be close to mean ({mean}), got {cvar}"
);
}
#[test]
fn test_cvar_gradient_shape() {
let losses = vec![1.0, 5.0, 2.0, 8.0, 3.0];
let est = CvarEstimator::new(0.6).expect("valid");
let (cvar, grad) = est.cvar_gradient(&losses).expect("grad ok");
assert_eq!(grad.len(), losses.len());
assert!(cvar.is_finite(), "CVaR should be finite");
}
#[test]
fn test_cvar_gradient_non_negative_entries() {
let losses = vec![0.0, 1.0, 2.0, 3.0, 4.0];
let est = CvarEstimator::new(0.6).expect("valid");
let (_, grad) = est.cvar_gradient(&losses).expect("grad ok");
for &g in &grad {
assert!(g >= 0.0, "CVaR gradient entries should be non-negative");
}
}
#[test]
fn test_cvar_dro_converges() {
let loss_fn = |w: &[f64], x: &[f64]| -> f64 {
-w.iter().zip(x.iter()).map(|(wi, xi)| wi * xi).sum::<f64>()
};
let grad_fn = |_w: &[f64], x: &[f64]| -> Vec<f64> { x.iter().map(|xi| -xi).collect() };
let mut rng = Lcg::new(77);
let samples: Vec<Vec<f64>> = (0..30)
.map(|_| vec![rng.next_f64(), rng.next_f64()])
.collect();
let cfg = DroConfig {
radius: 0.05,
max_iter: 300,
tol: 1e-5,
..Default::default()
};
let solver = CvarDro::new(cfg, 0.8, &loss_fn, &grad_fn).expect("valid");
let result = solver.solve(2, &samples).expect("solve ok");
assert!(!result.primal_obj.is_nan(), "primal_obj is NaN");
assert!(!result.worst_case_loss.is_nan(), "worst_case is NaN");
assert_eq!(result.optimal_weights.len(), 2);
}
#[test]
fn test_cvar_dro_fields_non_nan() {
let result = solve_cvar_dro(2, &[vec![0.1, 0.2], vec![0.3, 0.4]], 0.8, 0.05, None)
.expect("solve ok");
assert!(!result.worst_case_loss.is_nan());
assert!(!result.primal_obj.is_nan());
for &w in &result.optimal_weights {
assert!(!w.is_nan());
}
}
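
    // Extra sanity check beyond the original suite: CVaR is nondecreasing in
    // alpha, since a higher level averages over a thinner, worse tail.
    #[test]
    fn test_cvar_monotone_in_alpha() {
        let losses = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let est_lo = CvarEstimator::new(0.5).expect("valid");
        let est_hi = CvarEstimator::new(0.9).expect("valid");
        let cvar_lo = est_lo.compute_cvar(&losses).expect("cvar ok");
        let cvar_hi = est_hi.compute_cvar(&losses).expect("cvar ok");
        assert!(
            cvar_hi >= cvar_lo - 1e-9,
            "CVaR must be nondecreasing in alpha: got {cvar_lo} then {cvar_hi}"
        );
    }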
}