use serde::{Deserialize, Serialize};

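/// A chance constraint: the underlying constraint must hold with probability
/// at least `confidence`. The `method` field selects the deterministic
/// reformulation used to enforce it.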
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChanceConstraint {
    name: String,
    confidence: f32,
    method: ChanceConstraintMethod,
    weight: f32,
}

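/// Deterministic reformulation strategies for a [`ChanceConstraint`].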
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ChanceConstraintMethod {
    ScenarioBased {
        num_scenarios: usize,
        violation_tolerance: f32,
    },
    Gaussian {
        mean: f32,
        std_dev: f32,
    },
    Conservative {
        tightening_factor: f32,
    },
}

impl ChanceConstraint {
    pub fn gaussian(name: impl Into<String>, confidence: f32, mean: f32, std_dev: f32) -> Self {
        assert!(
            confidence > 0.0 && confidence < 1.0,
            "Confidence must be in (0, 1)"
        );
        assert!(std_dev > 0.0, "Standard deviation must be positive");

        Self {
            name: name.into(),
            confidence,
            method: ChanceConstraintMethod::Gaussian { mean, std_dev },
            weight: 1.0,
        }
    }

    pub fn scenario_based(name: impl Into<String>, confidence: f32, num_scenarios: usize) -> Self {
        assert!(
            confidence > 0.0 && confidence < 1.0,
            "Confidence must be in (0, 1)"
        );
        assert!(num_scenarios > 0, "Number of scenarios must be positive");

        Self {
            name: name.into(),
            confidence,
            method: ChanceConstraintMethod::ScenarioBased {
                num_scenarios,
                violation_tolerance: 1.0 - confidence,
            },
            weight: 1.0,
        }
    }

    pub fn with_weight(mut self, weight: f32) -> Self {
        self.weight = weight;
        self
    }

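    /// Deterministic bound implied by the chosen reformulation
    /// (for example, `mean + z * std_dev` in the Gaussian case).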
    pub fn get_tightened_bound(&self) -> f32 {
        match &self.method {
            ChanceConstraintMethod::Gaussian { mean, std_dev } => {
                let z_alpha = self.confidence_to_quantile(self.confidence);
                mean + z_alpha * std_dev
            }
            ChanceConstraintMethod::Conservative { tightening_factor } => *tightening_factor,
            // Heuristic placeholder: no analytic tightening for the scenario-based method.
            ChanceConstraintMethod::ScenarioBased { .. } => self.confidence * 10.0,
        }
    }

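    /// Maps a confidence level to an approximate standard-normal critical
    /// value: a lookup table for common levels, with a crude linear fallback
    /// below 0.80.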
    fn confidence_to_quantile(&self, confidence: f32) -> f32 {
        if confidence >= 0.99 {
            2.58
        } else if confidence >= 0.95 {
            1.96
        } else if confidence >= 0.90 {
            1.64
        } else if confidence >= 0.80 {
            1.28
        } else {
            confidence * 3.0 - 1.5
        }
    }

    pub fn name(&self) -> &str {
        &self.name
    }

    pub fn confidence(&self) -> f32 {
        self.confidence
    }

    pub fn weight(&self) -> f32 {
        self.weight
    }
}

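/// A robust constraint: the underlying constraint must hold for every
/// realization of the uncertain parameters inside `uncertainty_set`.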
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RobustConstraint {
    name: String,
    uncertainty_set: UncertaintySet,
    approach: RobustnessApproach,
    weight: f32,
}

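/// Geometry of the uncertainty set over which a [`RobustConstraint`] must hold.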
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum UncertaintySet {
    Box { min: Vec<f32>, max: Vec<f32> },
    Ellipsoidal {
        nominal: Vec<f32>,
        shape_matrix: Vec<f32>,
        radius: f32,
    },
    Polyhedral {
        a_matrix: Vec<f32>,
        b_vector: Vec<f32>,
        dim: usize,
    },
    Budget {
        nominal: Vec<f32>,
        max_deviations: Vec<f32>,
        budget: usize,
    },
}

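/// Solution approach used when enforcing a [`RobustConstraint`].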
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RobustnessApproach {
    WorstCase,
    AffinelyAdjustable,
    Scenarios { num_scenarios: usize },
}

impl RobustConstraint {
    pub fn box_uncertain(name: impl Into<String>, min: Vec<f32>, max: Vec<f32>) -> Self {
        assert_eq!(min.len(), max.len(), "Min and max must have same dimension");
        for (mi, ma) in min.iter().zip(max.iter()) {
            assert!(mi <= ma, "Min must be <= max");
        }

        Self {
            name: name.into(),
            uncertainty_set: UncertaintySet::Box { min, max },
            approach: RobustnessApproach::WorstCase,
            weight: 1.0,
        }
    }

    pub fn ellipsoidal_uncertain(name: impl Into<String>, nominal: Vec<f32>, radius: f32) -> Self {
        assert!(radius > 0.0, "Radius must be positive");
        let dim = nominal.len();
        // Default shape matrix: the dim x dim identity, stored row-major.
        let mut identity = vec![0.0; dim * dim];
        for i in 0..dim {
            identity[i * dim + i] = 1.0;
        }

        Self {
            name: name.into(),
            uncertainty_set: UncertaintySet::Ellipsoidal {
                nominal,
                shape_matrix: identity,
                radius,
            },
            approach: RobustnessApproach::WorstCase,
            weight: 1.0,
        }
    }

    pub fn with_approach(mut self, approach: RobustnessApproach) -> Self {
        self.approach = approach;
        self
    }

    pub fn with_weight(mut self, weight: f32) -> Self {
        self.weight = weight;
        self
    }

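    /// Returns an approximate worst-case realization of the uncertain
    /// parameters; `x` is only used to size the polyhedral placeholder.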
    pub fn worst_case_scenario(&self, x: &[f32]) -> Vec<f32> {
        match &self.uncertainty_set {
            // Simplification: take the upper corner of the box.
            UncertaintySet::Box { min: _, max } => max.clone(),
            UncertaintySet::Ellipsoidal {
                nominal, radius, ..
            } => {
                // Simplification: shift every component by the radius,
                // ignoring the shape matrix.
                nominal.iter().map(|&v| v + radius).collect()
            }
            // Placeholder: no worst case computed for polyhedral sets.
            UncertaintySet::Polyhedral { .. } => vec![0.0; x.len()],
            UncertaintySet::Budget {
                nominal,
                max_deviations,
                budget,
            } => {
                // Apply the first `budget` deviations at their maxima.
                let mut result = nominal.clone();
                for (i, &dev) in max_deviations.iter().enumerate().take(*budget) {
                    if i < result.len() {
                        result[i] += dev;
                    }
                }
                result
            }
        }
    }

    pub fn name(&self) -> &str {
        &self.name
    }

    pub fn weight(&self) -> f32 {
        self.weight
    }

    pub fn uncertainty_set(&self) -> &UncertaintySet {
        &self.uncertainty_set
    }
}

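/// A CVaR constraint: the conditional value-at-risk of the losses at level
/// `alpha` (the average of the worst `alpha` fraction of scenarios) must not
/// exceed `threshold`.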
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CVaRConstraint {
    name: String,
    alpha: f32,
    threshold: f32,
    num_scenarios: usize,
    weight: f32,
}

impl CVaRConstraint {
    pub fn new(name: impl Into<String>, alpha: f32, threshold: f32, num_scenarios: usize) -> Self {
        assert!(alpha > 0.0 && alpha < 1.0, "Alpha must be in (0, 1)");
        assert!(num_scenarios > 0, "Number of scenarios must be positive");

        Self {
            name: name.into(),
            alpha,
            threshold,
            num_scenarios,
            weight: 1.0,
        }
    }

    pub fn with_weight(mut self, weight: f32) -> Self {
        self.weight = weight;
        self
    }

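    /// Empirical CVaR: the mean of the worst `ceil(alpha * n)` losses.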
    pub fn compute_cvar(&self, losses: &[f32]) -> f32 {
        if losses.is_empty() {
            return 0.0;
        }

        let mut sorted_losses = losses.to_vec();
        sorted_losses.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));

        let cutoff = (self.alpha * sorted_losses.len() as f32).ceil() as usize;
        let cutoff = cutoff.max(1).min(sorted_losses.len());

        sorted_losses.iter().take(cutoff).sum::<f32>() / cutoff as f32
    }

    pub fn check(&self, losses: &[f32]) -> bool {
        let cvar = self.compute_cvar(losses);
        cvar <= self.threshold
    }

    pub fn violation(&self, losses: &[f32]) -> f32 {
        let cvar = self.compute_cvar(losses);
        (cvar - self.threshold).max(0.0)
    }

    pub fn name(&self) -> &str {
        &self.name
    }

    pub fn alpha(&self) -> f32 {
        self.alpha
    }

    pub fn threshold(&self) -> f32 {
        self.threshold
    }

    pub fn weight(&self) -> f32 {
        self.weight
    }
}

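/// A distributionally robust constraint: the worst-case expected loss over an
/// ambiguity set of distributions must not exceed `threshold`.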
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributionallyRobustConstraint {
    name: String,
    ambiguity_set: AmbiguitySet,
    threshold: f32,
    weight: f32,
}

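/// Ambiguity set describing which distributions are considered plausible.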
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AmbiguitySet {
    Wasserstein {
        num_samples: usize,
        radius: f32,
    },
    MomentBased {
        mean: Vec<f32>,
        cov_radius: f32,
    },
    PhiDivergence {
        num_samples: usize,
        radius: f32,
        divergence_type: DivergenceType,
    },
}

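/// Phi-divergence used to measure distance between distributions.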
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DivergenceType {
    KL,
    ChiSquared,
    ModifiedChiSquared,
}

impl DistributionallyRobustConstraint {
    pub fn wasserstein(
        name: impl Into<String>,
        threshold: f32,
        num_samples: usize,
        radius: f32,
    ) -> Self {
        assert!(radius > 0.0, "Radius must be positive");
        assert!(num_samples > 0, "Number of samples must be positive");

        Self {
            name: name.into(),
            ambiguity_set: AmbiguitySet::Wasserstein {
                num_samples,
                radius,
            },
            threshold,
            weight: 1.0,
        }
    }

    pub fn moment_based(
        name: impl Into<String>,
        threshold: f32,
        mean: Vec<f32>,
        cov_radius: f32,
    ) -> Self {
        assert!(cov_radius > 0.0, "Covariance radius must be positive");

        Self {
            name: name.into(),
            ambiguity_set: AmbiguitySet::MomentBased { mean, cov_radius },
            threshold,
            weight: 1.0,
        }
    }

    pub fn with_weight(mut self, weight: f32) -> Self {
        self.weight = weight;
        self
    }

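    /// Rough upper bound on the expected loss over the ambiguity set,
    /// computed from the empirical sample of `losses`.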
    pub fn worst_case_expectation(&self, losses: &[f32]) -> f32 {
        if losses.is_empty() {
            return 0.0;
        }

        match &self.ambiguity_set {
            AmbiguitySet::Wasserstein { radius, .. } => {
                // Crude bound: empirical mean inflated by the radius times the
                // largest observed loss magnitude.
                let mean: f32 = losses.iter().sum::<f32>() / losses.len() as f32;
                let max_loss = losses.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                mean + radius * max_loss.abs()
            }
            AmbiguitySet::MomentBased { cov_radius, .. } => {
                // Empirical mean plus a standard-deviation penalty scaled by
                // the covariance radius.
                let mean: f32 = losses.iter().sum::<f32>() / losses.len() as f32;
                let variance: f32 =
                    losses.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / losses.len() as f32;
                mean + cov_radius.sqrt() * variance.sqrt()
            }
            AmbiguitySet::PhiDivergence { radius, .. } => {
                // Simple multiplicative inflation of the empirical mean.
                let mean: f32 = losses.iter().sum::<f32>() / losses.len() as f32;
                mean * (1.0 + radius)
            }
        }
    }

    pub fn check(&self, losses: &[f32]) -> bool {
        let wc_exp = self.worst_case_expectation(losses);
        wc_exp <= self.threshold
    }

    pub fn violation(&self, losses: &[f32]) -> f32 {
        let wc_exp = self.worst_case_expectation(losses);
        (wc_exp - self.threshold).max(0.0)
    }

    pub fn name(&self) -> &str {
        &self.name
    }

    pub fn threshold(&self) -> f32 {
        self.threshold
    }

    pub fn weight(&self) -> f32 {
        self.weight
    }

    pub fn ambiguity_set(&self) -> &AmbiguitySet {
        &self.ambiguity_set
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chance_constraint_gaussian() {
        let cc = ChanceConstraint::gaussian("test", 0.95, 10.0, 2.0);
        assert_eq!(cc.name(), "test");
        assert_eq!(cc.confidence(), 0.95);

        let bound = cc.get_tightened_bound();
        assert!(bound > 10.0);
        assert!(bound < 15.0);
    }

    #[test]
    fn test_robust_constraint_box() {
        let rc = RobustConstraint::box_uncertain("test", vec![-1.0, -2.0], vec![1.0, 2.0]);
        assert_eq!(rc.name(), "test");

        let worst_case = rc.worst_case_scenario(&[0.0, 0.0]);
        assert_eq!(worst_case, vec![1.0, 2.0]);
    }

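    // Added sketch test: with the identity shape matrix built by
    // `ellipsoidal_uncertain`, the worst-case scenario shifts every nominal
    // component by the radius.
    #[test]
    fn test_robust_constraint_ellipsoidal() {
        let rc = RobustConstraint::ellipsoidal_uncertain("test", vec![1.0, 2.0], 0.5);
        let worst_case = rc.worst_case_scenario(&[0.0, 0.0]);
        assert_eq!(worst_case, vec![1.5, 2.5]);
    }
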
    #[test]
    fn test_cvar_constraint() {
        let cvar = CVaRConstraint::new("test", 0.1, 10.0, 100);

        let losses = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let cvar_value = cvar.compute_cvar(&losses);

        assert!(cvar_value >= 9.0);
        assert!(cvar_value <= 10.0);
    }

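    // Added sketch test for the scenario-based constructor and the weight
    // builder; the expected values follow directly from the builders above.
    #[test]
    fn test_chance_constraint_scenario_based() {
        let cc = ChanceConstraint::scenario_based("test", 0.9, 50).with_weight(2.0);
        assert_eq!(cc.name(), "test");
        assert_eq!(cc.confidence(), 0.9);
        assert_eq!(cc.weight(), 2.0);
    }
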
    #[test]
    fn test_distributionally_robust_wasserstein() {
        let drc = DistributionallyRobustConstraint::wasserstein("test", 15.0, 100, 0.5);

        let losses = vec![5.0, 10.0, 15.0];
        let wc_exp = drc.worst_case_expectation(&losses);

        // mean = 10.0, max loss = 15.0, radius = 0.5 => wc_exp = 17.5 > 15.0.
        assert!(wc_exp >= 10.0);
        assert!(!drc.check(&losses));
    }
}