1use std::sync::Arc;
32use std::time::Instant;
33
34use numra_core::Scalar;
35use rand::rngs::StdRng;
36use rand::SeedableRng;
37
38use crate::error::OptimError;
39use crate::problem::{ConstraintKind, OptimProblem};
40
/// Function of the decision variables `x` (first slice) and one scenario's
/// stochastic parameter values (second slice, one entry per parameter in the
/// order the parameters were added). Used for objectives and chance
/// constraints; wrapped in `Arc` so it can be shared with the SAA closure
/// handed to the inner solver while still being evaluated afterwards.
type ParamFn<S> = Arc<dyn Fn(&[S], &[S]) -> S + Send + Sync>;
/// Function of the decision variables only, used for deterministic constraints.
type DetFn<S> = Box<dyn Fn(&[S]) -> S + Send + Sync>;
45
/// An uncertain model parameter ξ_i together with a way to sample it.
///
/// Build one with [`param_normal`], [`param_uniform`] or [`param_sampled`],
/// or add it through the corresponding `StochasticProblem` builder methods.
pub struct StochasticParam<S: Scalar> {
    /// Human-readable identifier of the parameter.
    pub name: String,
    /// Draws one realisation of the parameter from the problem's seeded RNG.
    // NOTE(review): no `Send + Sync` bound here, unlike the objective/constraint
    // closures, so a problem holding parameters is not `Send` — confirm intended.
    pub sampler: Box<dyn Fn(&mut StdRng) -> S>,
    /// Representative value of the parameter (the mean for `param_normal`, the
    /// interval midpoint for `param_uniform`). `StochasticProblem::solve` as
    /// shown does not read it.
    pub nominal: S,
}
59
/// Settings controlling the sample average approximation performed by
/// `StochasticProblem::solve`.
#[derive(Clone, Debug)]
pub struct StochasticOptions {
    /// How many scenarios are drawn and averaged over.
    pub n_samples: usize,
    /// Seed of the `StdRng` that draws the scenarios; fixing it makes runs reproducible.
    pub seed: u64,
    /// Iteration cap forwarded to the inner deterministic optimizer.
    pub max_iter: usize,
}

impl Default for StochasticOptions {
    /// 100 scenarios, seed 42, and at most 1000 inner-solver iterations.
    fn default() -> Self {
        let (n_samples, seed, max_iter) = (100, 42, 1000);
        StochasticOptions {
            n_samples,
            seed,
            max_iter,
        }
    }
}
80
/// Outcome of `StochasticProblem::solve`.
#[derive(Clone, Debug)]
pub struct StochasticResult<S: Scalar> {
    /// Optimal decision variables. Length is the problem dimension `n`; in CVaR
    /// mode the auxiliary variable `t` has already been removed.
    pub x: Vec<S>,
    /// Sample mean of the objective over all scenarios at `x`. This is the
    /// expected-value estimate even when CVaR was minimized.
    pub f_mean: S,
    /// Standard error of `f_mean`: sample standard deviation (n - 1 divisor)
    /// divided by sqrt(n); zero when only one scenario was drawn.
    pub f_std_error: S,
    /// Objective value at `x` for each scenario, in sampling order.
    pub scenario_objectives: Vec<S>,
    /// For each chance constraint (in the order added), the fraction of
    /// scenarios with `g(x, ξ) <= 0`.
    pub chance_satisfaction: Vec<S>,
    /// Convergence flag reported by the inner deterministic solver.
    pub converged: bool,
    /// Status message reported by the inner deterministic solver.
    pub message: String,
    /// Iteration count reported by the inner deterministic solver.
    pub iterations: usize,
    /// Wall-clock time of the whole `solve` call (sampling, optimization and
    /// post-processing), in seconds.
    pub wall_time_secs: f64,
}
103
/// A constraint on the decision variables alone, forwarded unchanged (apart
/// from slicing off the CVaR variable) to the inner solver.
struct DetConstraint<S: Scalar> {
    /// Constraint function `h(x)`.
    func: DetFn<S>,
    /// Whether `func` is registered as an inequality or an equality constraint.
    kind: ConstraintKind,
}

/// A probabilistic constraint `P[g(x, ξ) <= 0] >= probability`, approximated
/// in `solve` by a scenario penalty.
struct ChanceConstraint<S: Scalar> {
    /// Constraint function `g(x, ξ)`; a scenario satisfies it when `g <= 0`.
    func: ParamFn<S>,
    /// Requested satisfaction probability; `solve` also uses it to scale the
    /// penalty weight.
    probability: S,
}
119
120pub fn param_normal<S: Scalar>(name: &str, mean: S, std: S) -> StochasticParam<S> {
129 use rand_distr::{Distribution, Normal};
130 let mean_f64 = mean.to_f64();
131 let std_f64 = std.to_f64();
132 let dist = Normal::new(mean_f64, std_f64).unwrap();
133 StochasticParam {
134 name: name.to_string(),
135 sampler: Box::new(move |rng: &mut StdRng| S::from_f64(dist.sample(rng))),
136 nominal: mean,
137 }
138}
139
140pub fn param_sampled<S: Scalar>(
155 name: &str,
156 nominal: S,
157 sampler: impl Fn(&mut StdRng) -> S + 'static,
158) -> StochasticParam<S> {
159 StochasticParam {
160 name: name.to_string(),
161 sampler: Box::new(sampler),
162 nominal,
163 }
164}
165
166pub fn param_uniform<S: Scalar>(name: &str, lo: S, hi: S) -> StochasticParam<S> {
171 use rand_distr::{Distribution, Uniform};
172 let lo_f64 = lo.to_f64();
173 let hi_f64 = hi.to_f64();
174 let dist = Uniform::new(lo_f64, hi_f64);
175 let two = S::from_f64(2.0);
176 StochasticParam {
177 name: name.to_string(),
178 sampler: Box::new(move |rng: &mut StdRng| S::from_f64(dist.sample(rng))),
179 nominal: (lo + hi) / two,
180 }
181}
182
/// Builder for a stochastic optimization problem solved by sample average
/// approximation (SAA).
///
/// Configure it with the chained builder methods, then call `solve`.
pub struct StochasticProblem<S: Scalar> {
    /// Number of decision variables.
    n: usize,
    /// Initial guess for the decision variables.
    x0: Option<Vec<S>>,
    /// Optional `(lower, upper)` bounds per decision variable (length `n`).
    bounds: Vec<Option<(S, S)>>,
    /// Objective `f(x, ξ)`, averaged (or CVaR-aggregated) over scenarios.
    objective: Option<ParamFn<S>>,
    /// Constraints depending only on `x`.
    deterministic_constraints: Vec<DetConstraint<S>>,
    /// Constraints of the form `P[g(x, ξ) <= 0] >= p`.
    chance_constraints: Vec<ChanceConstraint<S>>,
    /// Stochastic parameters; each scenario holds one draw per entry, in this order.
    params: Vec<StochasticParam<S>>,
    /// Sampling and solver settings.
    options: StochasticOptions,
    /// `Some(alpha)` to minimize CVaR at level `alpha` instead of the expectation.
    cvar_alpha: Option<S>,
}
204
205impl<S: Scalar> StochasticProblem<S> {
206 pub fn new(n: usize) -> Self {
208 Self {
209 n,
210 x0: None,
211 bounds: vec![None; n],
212 objective: None,
213 deterministic_constraints: Vec::new(),
214 chance_constraints: Vec::new(),
215 params: Vec::new(),
216 options: StochasticOptions::default(),
217 cvar_alpha: None,
218 }
219 }
220
221 pub fn x0(mut self, x0: &[S]) -> Self {
223 self.x0 = Some(x0.to_vec());
224 self
225 }
226
227 pub fn bounds(mut self, i: usize, lo_hi: (S, S)) -> Self {
229 self.bounds[i] = Some(lo_hi);
230 self
231 }
232
233 pub fn all_bounds(mut self, bounds: &[(S, S)]) -> Self {
235 for (i, &b) in bounds.iter().enumerate() {
236 self.bounds[i] = Some(b);
237 }
238 self
239 }
240
241 pub fn objective<F>(mut self, f: F) -> Self
246 where
247 F: Fn(&[S], &[S]) -> S + Send + Sync + 'static,
248 {
249 self.objective = Some(Arc::new(f));
250 self
251 }
252
253 pub fn constraint_det_ineq<F>(mut self, f: F) -> Self
255 where
256 F: Fn(&[S]) -> S + Send + Sync + 'static,
257 {
258 self.deterministic_constraints.push(DetConstraint {
259 func: Box::new(f),
260 kind: ConstraintKind::Inequality,
261 });
262 self
263 }
264
265 pub fn constraint_det_eq<F>(mut self, f: F) -> Self
267 where
268 F: Fn(&[S]) -> S + Send + Sync + 'static,
269 {
270 self.deterministic_constraints.push(DetConstraint {
271 func: Box::new(f),
272 kind: ConstraintKind::Equality,
273 });
274 self
275 }
276
277 pub fn chance_constraint<F>(mut self, f: F, probability: S) -> Self
282 where
283 F: Fn(&[S], &[S]) -> S + Send + Sync + 'static,
284 {
285 self.chance_constraints.push(ChanceConstraint {
286 func: Arc::new(f),
287 probability,
288 });
289 self
290 }
291
292 pub fn param(mut self, p: StochasticParam<S>) -> Self {
294 self.params.push(p);
295 self
296 }
297
298 pub fn param_normal(mut self, name: &str, mean: S, std: S) -> Self {
300 self.params.push(param_normal(name, mean, std));
301 self
302 }
303
304 pub fn param_uniform(mut self, name: &str, lo: S, hi: S) -> Self {
306 self.params.push(param_uniform(name, lo, hi));
307 self
308 }
309
310 pub fn n_samples(mut self, n: usize) -> Self {
312 self.options.n_samples = n;
313 self
314 }
315
316 pub fn seed(mut self, s: u64) -> Self {
318 self.options.seed = s;
319 self
320 }
321
322 pub fn max_iter(mut self, n: usize) -> Self {
324 self.options.max_iter = n;
325 self
326 }
327
328 pub fn minimize_cvar(mut self, alpha: S) -> Self {
338 self.cvar_alpha = Some(alpha);
339 self
340 }
341}
342
343impl<S: Scalar + faer::SimpleEntity + faer::Conjugate<Canonical = S> + faer::ComplexField>
344 StochasticProblem<S>
345{
346 pub fn solve(self) -> Result<StochasticResult<S>, OptimError> {
353 let start = Instant::now();
354
355 let obj = self.objective.ok_or(OptimError::NoObjective)?;
357 let x0 = self.x0.clone().ok_or(OptimError::NoInitialPoint)?;
358 if self.params.is_empty() {
359 return Err(OptimError::Other(
360 "at least one stochastic parameter is required".to_string(),
361 ));
362 }
363
364 let n = self.n;
365 let n_samples = self.options.n_samples;
366 let cvar_alpha = self.cvar_alpha;
367
368 let mut rng = StdRng::seed_from_u64(self.options.seed);
370 let scenarios: Vec<Vec<S>> = (0..n_samples)
371 .map(|_| self.params.iter().map(|p| (p.sampler)(&mut rng)).collect())
372 .collect();
373 let scenarios = Arc::new(scenarios);
374
375 let obj_for_saa = Arc::clone(&obj);
377 let scenarios_for_saa = Arc::clone(&scenarios);
378
379 let chance_fns: Vec<ParamFn<S>> = self
381 .chance_constraints
382 .iter()
383 .map(|cc| Arc::clone(&cc.func))
384 .collect();
385 let chance_fns = Arc::new(chance_fns);
386 let chance_fns_for_saa = Arc::clone(&chance_fns);
387
388 let base_penalty = S::from_f64(1e4);
390 let chance_weights: Arc<Vec<S>> = Arc::new(
391 self.chance_constraints
392 .iter()
393 .map(|cc| base_penalty * cc.probability)
394 .collect(),
395 );
396 let chance_weights_for_saa = Arc::clone(&chance_weights);
397 let n_chance = self.chance_constraints.len();
398
399 let n_opt = if cvar_alpha.is_some() { n + 1 } else { n };
401 let mut x0_ext = x0.clone();
402 if cvar_alpha.is_some() {
403 let mut f_scenarios: Vec<S> = scenarios.iter().map(|s| obj(&x0, s)).collect();
405 f_scenarios.sort_by(|a, b| a.to_f64().partial_cmp(&b.to_f64()).unwrap());
406 let median_idx = f_scenarios.len() / 2;
407 x0_ext.push(f_scenarios[median_idx]);
408 }
409
410 let saa_objective = move |x_ext: &[S]| -> S {
411 let ns = scenarios_for_saa.len();
412 let inv_n = S::ONE / S::from_usize(ns);
413
414 match cvar_alpha {
415 Some(alpha) => {
416 let x = &x_ext[..x_ext.len() - 1];
420 let t = x_ext[x_ext.len() - 1];
421 let inv_one_minus_alpha = S::ONE / (S::ONE - alpha);
422 let eps = S::from_f64(1e-6);
423
424 let mut cvar_sum = S::ZERO;
425 for s in 0..ns {
426 let fs = obj_for_saa(x, &scenarios_for_saa[s]);
427 let z = fs - t;
428 let smooth_max = (z + (z * z + eps).sqrt()) * S::HALF;
430 cvar_sum += smooth_max;
431 }
432
433 let mut penalty = S::ZERO;
435 for c in 0..n_chance {
436 let w = chance_weights_for_saa[c];
437 for s in 0..ns {
438 let g = chance_fns_for_saa[c](x, &scenarios_for_saa[s]);
439 if g > S::ZERO {
440 penalty += w * g * g;
441 }
442 }
443 }
444
445 t + inv_one_minus_alpha * inv_n * cvar_sum + inv_n * penalty
446 }
447 None => {
448 let x = x_ext;
450 let mut f_sum = S::ZERO;
451 for s in 0..ns {
452 f_sum += obj_for_saa(x, &scenarios_for_saa[s]);
453 }
454
455 let mut penalty = S::ZERO;
457 for c in 0..n_chance {
458 let w = chance_weights_for_saa[c];
459 for s in 0..ns {
460 let g = chance_fns_for_saa[c](x, &scenarios_for_saa[s]);
461 if g > S::ZERO {
462 penalty += w * g * g;
463 }
464 }
465 }
466
467 inv_n * f_sum + inv_n * penalty
468 }
469 }
470 };
471
472 let mut problem = OptimProblem::<S>::new(n_opt)
474 .x0(&x0_ext)
475 .objective(saa_objective)
476 .max_iter(self.options.max_iter);
477
478 for (i, b) in self.bounds.iter().enumerate() {
480 if let Some(lo_hi) = b {
481 problem = problem.bounds(i, *lo_hi);
482 }
483 }
484
485 for dc in self.deterministic_constraints {
488 if cvar_alpha.is_some() {
489 let dc_n = n;
490 match dc.kind {
491 ConstraintKind::Inequality => {
492 let func = dc.func;
493 problem = problem.constraint_ineq(move |x_ext: &[S]| func(&x_ext[..dc_n]));
494 }
495 ConstraintKind::Equality => {
496 let func = dc.func;
497 problem = problem.constraint_eq(move |x_ext: &[S]| func(&x_ext[..dc_n]));
498 }
499 }
500 } else {
501 match dc.kind {
502 ConstraintKind::Inequality => {
503 problem = problem.constraint_ineq(dc.func);
504 }
505 ConstraintKind::Equality => {
506 problem = problem.constraint_eq(dc.func);
507 }
508 }
509 }
510 }
511
512 let result = problem.solve()?;
514 let x_star = if cvar_alpha.is_some() {
516 result.x[..n].to_vec()
517 } else {
518 result.x.clone()
519 };
520
521 let scenario_objectives: Vec<S> = scenarios.iter().map(|s| obj(&x_star, s)).collect();
523
524 let f_mean = scenario_objectives.iter().copied().sum::<S>() / S::from_usize(n_samples);
525
526 let f_std_error = if n_samples > 1 {
527 let variance = scenario_objectives
528 .iter()
529 .map(|&fi| (fi - f_mean) * (fi - f_mean))
530 .sum::<S>()
531 / S::from_usize(n_samples - 1);
532 variance.sqrt() / S::from_usize(n_samples).sqrt()
533 } else {
534 S::ZERO
535 };
536
537 let chance_satisfaction: Vec<S> = self
539 .chance_constraints
540 .iter()
541 .map(|cc| {
542 let n_satisfied = scenarios
543 .iter()
544 .filter(|s| (cc.func)(&x_star, s) <= S::ZERO)
545 .count();
546 S::from_usize(n_satisfied) / S::from_usize(n_samples)
547 })
548 .collect();
549
550 Ok(StochasticResult {
551 x: x_star,
552 f_mean,
553 f_std_error,
554 scenario_objectives,
555 chance_satisfaction,
556 converged: result.converged,
557 message: result.message,
558 iterations: result.iterations,
559 wall_time_secs: start.elapsed().as_secs_f64(),
560 })
561 }
562}
563
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;

    /// Unconstrained SAA: minimizing mean((x - ξ)²) with ξ ~ N(5, 1) should put
    /// x* near the sample mean of ξ, i.e. about 5.
    #[test]
    fn test_saa_expected_value() {
        let result = StochasticProblem::new(1)
            .x0(&[0.0])
            .objective(|x: &[f64], p: &[f64]| (x[0] - p[0]) * (x[0] - p[0]))
            .param_normal("xi", 5.0, 1.0)
            .n_samples(200)
            .solve()
            .unwrap();

        assert!(
            (result.x[0] - 5.0).abs() < 0.5,
            "x* = {}, expected ~5.0",
            result.x[0]
        );
        assert!(result.converged, "solver should converge");
        assert_eq!(result.scenario_objectives.len(), 200);
        assert!(result.f_std_error > 0.0, "std error should be positive");
    }

    /// Same objective, but the bound x <= 3 lies below the unconstrained
    /// optimum (~5), so the bound should be active at the solution.
    #[test]
    fn test_saa_bounded() {
        let result = StochasticProblem::new(1)
            .x0(&[1.0])
            .objective(|x: &[f64], p: &[f64]| (x[0] - p[0]) * (x[0] - p[0]))
            .param_normal("xi", 5.0, 1.0)
            .bounds(0, (0.0, 3.0))
            .n_samples(200)
            .solve()
            .unwrap();

        assert!(
            (result.x[0] - 3.0).abs() < 0.1,
            "x* = {}, expected ~3.0 (bound active)",
            result.x[0]
        );
    }

    /// Maximize x subject to P[x - ξ <= 0] >= 0.95 with ξ ~ N(10, 2).
    /// Because the chance constraint is only penalized, the test checks just
    /// that x* stays conservatively below the mean of ξ, not an exact quantile.
    #[test]
    fn test_chance_constraint() {
        let result = StochasticProblem::new(1)
            .x0(&[5.0])
            .objective(|x: &[f64], _p: &[f64]| -x[0])
            .chance_constraint(
                |x: &[f64], p: &[f64]| x[0] - p[0], 0.95,
            )
            .param_normal("xi", 10.0, 2.0)
            .bounds(0, (0.0, 20.0))
            .n_samples(500)
            .max_iter(2000)
            .solve()
            .unwrap();

        assert!(
            result.x[0] < 10.0,
            "x* = {}, expected < 10.0 (must be conservative)",
            result.x[0]
        );
        assert!(result.x[0] > 0.0, "x* = {}, expected > 0.0", result.x[0]);
        // Only the length is checked; the satisfaction rate itself is not asserted.
        assert!(
            result.chance_satisfaction.len() == 1,
            "should have one chance constraint"
        );
    }

    /// The first two sub-blocks sanity-check `rand_distr` directly; the rest
    /// checks the `param_normal` / `param_uniform` builder helpers (names,
    /// nominal values and sample statistics).
    #[test]
    fn test_param_normal_uniform_helpers() {
        // Shared by the two sub-blocks below, so their draws depend on order.
        let mut rng = StdRng::seed_from_u64(123);

        {
            use rand_distr::{Distribution, Normal};
            let dist = Normal::new(5.0, 1.0).unwrap();
            let samples: Vec<f64> = (0..1000).map(|_| dist.sample(&mut rng)).collect();
            let mean = samples.iter().sum::<f64>() / 1000.0;
            let variance = samples
                .iter()
                .map(|&s| (s - mean) * (s - mean))
                .sum::<f64>()
                / 999.0;
            let std = variance.sqrt();
            assert!(
                (mean - 5.0).abs() < 0.3,
                "Normal mean = {}, expected ~5.0",
                mean
            );
            assert!(
                (std - 1.0).abs() < 0.3,
                "Normal std = {}, expected ~1.0",
                std
            );
        }

        {
            use rand_distr::{Distribution, Uniform};
            let dist = Uniform::new(0.0, 10.0);
            let samples: Vec<f64> = (0..1000).map(|_| dist.sample(&mut rng)).collect();
            let mean = samples.iter().sum::<f64>() / 1000.0;
            assert!(
                samples.iter().all(|&s| (0.0..10.0).contains(&s)),
                "Uniform samples should be in [0, 10)"
            );
            assert!(
                (mean - 5.0).abs() < 1.0,
                "Uniform mean = {}, expected ~5.0",
                mean
            );
        }

        // Builder helpers: check stored metadata, then sample statistics.
        let problem = StochasticProblem::<f64>::new(1)
            .param_normal("n1", 5.0, 1.0)
            .param_uniform("u1", 0.0, 10.0);
        assert_eq!(problem.params.len(), 2);
        assert_eq!(problem.params[0].name, "n1");
        assert!((problem.params[0].nominal - 5.0).abs() < 1e-15);
        assert_eq!(problem.params[1].name, "u1");
        assert!((problem.params[1].nominal - 5.0).abs() < 1e-15);

        let mut rng2 = StdRng::seed_from_u64(456);
        let normal_samples: Vec<f64> = (0..1000)
            .map(|_| (problem.params[0].sampler)(&mut rng2))
            .collect();
        let n_mean = normal_samples.iter().sum::<f64>() / 1000.0;
        assert!(
            (n_mean - 5.0).abs() < 0.3,
            "StochasticParam Normal mean = {}, expected ~5.0",
            n_mean
        );

        let uniform_samples: Vec<f64> = (0..1000)
            .map(|_| (problem.params[1].sampler)(&mut rng2))
            .collect();
        let u_mean = uniform_samples.iter().sum::<f64>() / 1000.0;
        assert!(
            uniform_samples.iter().all(|&s| (0.0..10.0).contains(&s)),
            "StochasticParam Uniform samples should be in [0, 10)"
        );
        assert!(
            (u_mean - 5.0).abs() < 1.0,
            "StochasticParam Uniform mean = {}, expected ~5.0",
            u_mean
        );
    }

    /// `param_sampled` stores name/nominal verbatim and uses the given sampler.
    #[test]
    fn test_param_sampled() {
        use rand_distr::{Distribution, Normal};
        let dist = Normal::new(5.0_f64, 1.0).unwrap();
        let p = param_sampled("xi", 5.0, move |rng: &mut StdRng| dist.sample(rng));

        assert_eq!(p.name, "xi");
        assert!((p.nominal - 5.0).abs() < 1e-15);

        let mut rng = StdRng::seed_from_u64(42);
        let samples: Vec<f64> = (0..1000).map(|_| (p.sampler)(&mut rng)).collect();
        let mean = samples.iter().sum::<f64>() / 1000.0;
        assert!((mean - 5.0).abs() < 0.3);
    }

    /// A custom-sampler parameter used end to end; should match
    /// `test_saa_expected_value`'s optimum of about 5.
    #[test]
    fn test_param_sampled_in_problem() {
        use rand_distr::{Distribution, Normal};
        let dist = Normal::new(5.0_f64, 1.0).unwrap();

        let result = StochasticProblem::new(1)
            .x0(&[0.0])
            .objective(|x: &[f64], p: &[f64]| (x[0] - p[0]) * (x[0] - p[0]))
            .param(param_sampled("xi", 5.0, move |rng: &mut StdRng| {
                dist.sample(rng)
            }))
            .n_samples(200)
            .solve()
            .unwrap();

        assert!((result.x[0] - 5.0).abs() < 0.5);
        assert!(result.converged);
    }

    /// CVaR(0.9) minimization of a symmetric quadratic loss. The returned x
    /// must have the original dimension (t stripped) and stay near 5. Its
    /// expected-value estimate is then compared loosely against a plain
    /// expected-value solve.
    #[test]
    fn test_cvar_minimization() {
        let result_cvar = StochasticProblem::new(1)
            .x0(&[3.0])
            .objective(|x: &[f64], p: &[f64]| (x[0] - p[0]) * (x[0] - p[0]))
            .param_normal("xi", 5.0, 2.0)
            .n_samples(500)
            .max_iter(2000)
            .minimize_cvar(0.9)
            .solve()
            .unwrap();

        assert_eq!(result_cvar.x.len(), 1, "CVaR should return original dim");
        assert!(
            result_cvar.converged,
            "CVaR should converge: {}",
            result_cvar.message
        );
        assert!(
            (result_cvar.x[0] - 5.0).abs() < 1.0,
            "CVaR x* = {}, expected ~5.0",
            result_cvar.x[0]
        );

        let result_ev = StochasticProblem::new(1)
            .x0(&[0.0])
            .objective(|x: &[f64], p: &[f64]| (x[0] - p[0]) * (x[0] - p[0]))
            .param_normal("xi", 5.0, 2.0)
            .n_samples(500)
            .solve()
            .unwrap();

        // Loose sanity bound only: f_mean is the expected value in both runs.
        assert!(
            result_cvar.f_mean >= result_ev.f_mean * 0.5,
            "CVaR f_mean={} should be comparable to EV f_mean={}",
            result_cvar.f_mean,
            result_ev.f_mean
        );
    }
}