use crate::error::{OptimizeError, OptimizeResult};
/// Configuration for a distributionally robust optimization (DRO) solve.
///
/// Typically built from [`Default`] plus struct-update syntax, then checked
/// with [`DroConfig::validate`].
#[derive(Debug, Clone)]
pub struct DroConfig {
    /// Ambiguity-set radius; must be finite and non-negative.
    pub radius: f64,
    /// Sample count; must be positive.
    pub n_samples: usize,
    /// Iteration cap; must be positive.
    pub max_iter: usize,
    /// Convergence tolerance; must be finite and strictly positive.
    pub tol: f64,
    /// Optional step size. When `Some`, it must be finite and strictly positive.
    /// NOTE(review): `None` presumably means "let the solver pick" — confirm
    /// against the solver implementation.
    pub step_size: Option<f64>,
}

impl Default for DroConfig {
    /// `radius = 0.1`, `n_samples = 100`, `max_iter = 500`, `tol = 1e-6`,
    /// `step_size = None`.
    fn default() -> Self {
        Self {
            radius: 0.1,
            n_samples: 100,
            max_iter: 500,
            tol: 1e-6,
            step_size: None,
        }
    }
}

impl DroConfig {
    /// Checks that every field holds a usable value.
    ///
    /// # Errors
    ///
    /// Returns [`OptimizeError::InvalidParameter`] when:
    /// - `radius` is negative, NaN, or infinite;
    /// - `n_samples` or `max_iter` is zero;
    /// - `tol` is not a finite, strictly positive number;
    /// - `step_size` is `Some` but not a finite, strictly positive number.
    pub fn validate(&self) -> OptimizeResult<()> {
        // NaN compares false against everything, so a bare `radius < 0.0`
        // check would accept it; `is_finite` rejects NaN and +/-inf up front.
        if !self.radius.is_finite() || self.radius < 0.0 {
            return Err(OptimizeError::InvalidParameter(
                "radius must be finite and non-negative".into(),
            ));
        }
        if self.n_samples == 0 {
            return Err(OptimizeError::InvalidParameter(
                "n_samples must be positive".into(),
            ));
        }
        if self.max_iter == 0 {
            return Err(OptimizeError::InvalidParameter(
                "max_iter must be positive".into(),
            ));
        }
        if !self.tol.is_finite() || self.tol <= 0.0 {
            return Err(OptimizeError::InvalidParameter(
                "tol must be finite and positive".into(),
            ));
        }
        // An explicit step size must be usable as-is; `None` is always allowed.
        if matches!(self.step_size, Some(step) if !step.is_finite() || step <= 0.0) {
            return Err(OptimizeError::InvalidParameter(
                "step_size must be finite and positive".into(),
            ));
        }
        Ok(())
    }
}
/// Outcome of a DRO solve.
///
/// NOTE(review): field meanings below are inferred from their names; the
/// code that fills this struct is not visible here — confirm against it.
#[derive(Clone, Debug)]
pub struct DroResult {
    /// Decision vector found by the solver (presumably the optimal weights).
    pub optimal_weights: Vec<f64>,
    /// Presumably the worst-case loss over the ambiguity set at the solution.
    pub worst_case_loss: f64,
    /// Presumably the primal objective value at the solution.
    pub primal_obj: f64,
    /// Number of iterations performed.
    pub n_iter: usize,
    /// Whether the solver reported convergence.
    pub converged: bool,
}
/// A collection of center samples together with a radius.
///
/// NOTE(review): despite the name, the membership test implemented here is a
/// *point* test — `q` is inside when its Euclidean distance to the nearest
/// center sample is at most `radius` — not a Wasserstein distance between
/// distributions. Confirm this is the intended semantics.
#[derive(Debug, Clone)]
pub struct WassersteinBall {
    /// Center samples. [`WassersteinBall::new`] requires them to be non-empty
    /// and all of the same length.
    pub center_samples: Vec<Vec<f64>>,
    /// Radius. [`WassersteinBall::new`] requires it to be finite and non-negative.
    pub radius: f64,
}

impl WassersteinBall {
    /// Builds a ball from `center_samples` and `radius`.
    ///
    /// # Errors
    ///
    /// Returns [`OptimizeError::InvalidParameter`] if `radius` is negative,
    /// NaN, or infinite; if `center_samples` is empty; or if the samples do
    /// not all have the same length.
    pub fn new(center_samples: Vec<Vec<f64>>, radius: f64) -> OptimizeResult<Self> {
        // `is_finite` is needed because NaN slips past `radius < 0.0`.
        if !radius.is_finite() || radius < 0.0 {
            return Err(OptimizeError::InvalidParameter(
                "Wasserstein ball radius must be finite and non-negative".into(),
            ));
        }
        if center_samples.is_empty() {
            return Err(OptimizeError::InvalidParameter(
                "center_samples must be non-empty".into(),
            ));
        }
        let dim = center_samples[0].len();
        if center_samples.iter().any(|c| c.len() != dim) {
            return Err(OptimizeError::InvalidParameter(
                "center_samples must all have the same dimension".into(),
            ));
        }
        Ok(Self {
            center_samples,
            radius,
        })
    }

    /// Returns the Euclidean distance from `q` to the nearest center sample.
    ///
    /// Samples whose length differs from `q.len()` are skipped, so a dimension
    /// mismatch can never produce a spuriously small distance (a plain `zip`
    /// would silently truncate to the shorter length). Returns
    /// `f64::INFINITY` when no sample has a matching length.
    pub fn distance_to_point(&self, q: &[f64]) -> f64 {
        self.center_samples
            .iter()
            .filter(|c| c.len() == q.len())
            .map(|c| {
                c.iter()
                    .zip(q)
                    .map(|(a, b)| (a - b).powi(2))
                    .sum::<f64>()
                    .sqrt()
            })
            .fold(f64::INFINITY, f64::min)
    }

    /// Returns `true` when `q` lies within `radius` of some center sample.
    ///
    /// A slack of `f64::EPSILON * max(radius, 1)` absorbs rounding in the
    /// distance computation. Scaling it with the radius matters for radii
    /// above 1: there an absolute `f64::EPSILON` is below one ulp of
    /// `radius` and would vanish in the addition.
    pub fn contains_point(&self, q: &[f64]) -> bool {
        let slack = f64::EPSILON * self.radius.max(1.0);
        self.distance_to_point(q) <= self.radius + slack
    }
}
/// Objective selected for a DRO solve.
///
/// `#[non_exhaustive]` lets new objectives be added later without breaking
/// exhaustive `match`es in downstream crates.
///
/// NOTE(review): variant parameter meanings are inferred from their names.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RobustObjective {
    /// Mean-variance objective; `lambda` presumably weights the variance term.
    MeanVariance { lambda: f64 },
    /// Conditional value-at-risk at level `alpha`.
    CVaR { alpha: f64 },
    /// Pure worst-case objective.
    WorstCase,
}

impl Default for RobustObjective {
    /// `CVaR` with `alpha = 0.95`.
    fn default() -> Self {
        let alpha = 0.95;
        RobustObjective::CVaR { alpha }
    }
}
#[derive(Debug, Clone)]
pub struct DroSolver {
pub config: DroConfig,
pub objective: RobustObjective,
}
impl Default for DroSolver {
fn default() -> Self {
Self {
config: DroConfig::default(),
objective: RobustObjective::default(),
}
}
}
impl DroSolver {
pub fn new(config: DroConfig, objective: RobustObjective) -> OptimizeResult<Self> {
config.validate()?;
Ok(Self { config, objective })
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dro_config_default_valid() {
        let cfg = DroConfig::default();
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_dro_config_negative_radius_error() {
        let cfg = DroConfig {
            radius: -0.1,
            ..Default::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_wasserstein_ball_contains_center() {
        let sample = vec![1.0, 2.0];
        let ball = WassersteinBall::new(vec![sample.clone()], 0.5).expect("valid ball");
        assert!(ball.contains_point(&sample));
    }

    #[test]
    fn test_wasserstein_ball_outside_radius() {
        let sample = vec![0.0, 0.0];
        let ball = WassersteinBall::new(vec![sample], 0.5).expect("valid ball");
        assert!(!ball.contains_point(&[1.0, 1.0]));
    }

    #[test]
    fn test_wasserstein_ball_negative_radius_error() {
        assert!(WassersteinBall::new(vec![vec![0.0]], -0.1).is_err());
    }

    #[test]
    fn test_robust_objective_default() {
        let obj = RobustObjective::default();
        // `matches!` only returns a bool; without `assert!` this test could never fail.
        assert!(matches!(obj, RobustObjective::CVaR { .. }));
        assert_eq!(obj, RobustObjective::CVaR { alpha: 0.95 });
    }

    #[test]
    fn test_dro_solver_default() {
        let solver = DroSolver::default();
        assert!(solver.config.radius >= 0.0);
        assert!(solver.config.validate().is_ok());
        assert!(DroSolver::new(solver.config.clone(), solver.objective).is_ok());
    }
}