1#![forbid(unsafe_code)]
79
80use std::collections::BTreeMap;
81use std::fmt;
82
83use franken_evidence::{EvidenceLedger, EvidenceLedgerBuilder};
84use franken_kernel::{DecisionId, TraceId};
85use serde::{Deserialize, Deserializer, Serialize};
86
/// Errors produced when validating decision-contract inputs.
///
/// Returned by `LossMatrix::new`, `Posterior::new` and `FallbackPolicy::new`, and
/// surfaced (as custom deserializer errors) by the validating `Deserialize` impls
/// of those types.
#[derive(Clone, Debug, PartialEq)]
pub enum ValidationError {
    /// A loss entry is NaN or infinite.
    InvalidLoss {
        /// Row (state) index of the offending entry.
        state: usize,
        /// Column (action) index of the offending entry.
        action: usize,
        /// The rejected value.
        value: f64,
    },
    /// A loss entry is negative.
    NegativeLoss {
        /// Row (state) index of the offending entry.
        state: usize,
        /// Column (action) index of the offending entry.
        action: usize,
        /// The rejected value.
        value: f64,
    },
    /// The number of supplied values does not match the expected count.
    DimensionMismatch {
        /// Number of values required.
        expected: usize,
        /// Number of values supplied.
        got: usize,
    },
    /// Posterior probabilities do not sum to 1.0 within tolerance.
    PosteriorNotNormalized {
        /// The actual sum of the probabilities.
        sum: f64,
    },
    /// A posterior probability is NaN, infinite, or negative.
    InvalidPosteriorProbability {
        /// Index of the offending probability.
        index: usize,
        /// The rejected value.
        value: f64,
    },
    /// A posterior's length does not match the number of states it is used with.
    PosteriorLengthMismatch {
        /// Number of states expected.
        expected: usize,
        /// Length of the supplied posterior.
        got: usize,
    },
    /// A required name list (state or action space) is empty.
    EmptySpace {
        /// Name of the empty field.
        field: &'static str,
    },
    /// A fallback threshold is non-finite or outside its allowed range.
    ThresholdOutOfRange {
        /// Name of the offending threshold field.
        field: &'static str,
        /// The rejected value.
        value: f64,
    },
}

impl fmt::Display for ValidationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidLoss {
                state,
                action,
                value,
            } => write!(
                f,
                "loss must be finite at state={state}, action={action}, got {value}"
            ),
            Self::NegativeLoss {
                state,
                action,
                value,
            } => write!(f, "negative loss {value} at state={state}, action={action}"),
            Self::DimensionMismatch { expected, got } => {
                write!(
                    f,
                    "dimension mismatch: expected {expected} values, got {got}"
                )
            }
            Self::PosteriorNotNormalized { sum } => {
                write!(f, "posterior sums to {sum}, expected 1.0")
            }
            Self::InvalidPosteriorProbability { index, value } => {
                write!(
                    f,
                    "posterior[{index}] must be finite and non-negative, got {value}"
                )
            }
            Self::PosteriorLengthMismatch { expected, got } => {
                write!(
                    f,
                    "posterior length {got} does not match state count {expected}"
                )
            }
            Self::EmptySpace { field } => write!(f, "{field} must not be empty"),
            // Field names already end in "_threshold"; repeating the word produced
            // messages like "e_process_breach_threshold threshold -1 ...".
            Self::ThresholdOutOfRange { field, value } => {
                write!(f, "{field} value {value} out of valid range")
            }
        }
    }
}

impl std::error::Error for ValidationError {}
198
199#[derive(Clone, Debug, Serialize, PartialEq)]
209pub struct LossMatrix {
210 state_names: Vec<String>,
211 action_names: Vec<String>,
212 values: Vec<f64>,
213}
214
215#[derive(Deserialize)]
216struct LossMatrixRepr {
217 state_names: Vec<String>,
218 action_names: Vec<String>,
219 values: Vec<f64>,
220}
221
222impl<'de> Deserialize<'de> for LossMatrix {
223 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
224 where
225 D: Deserializer<'de>,
226 {
227 let repr = LossMatrixRepr::deserialize(deserializer)?;
228 Self::new(repr.state_names, repr.action_names, repr.values)
229 .map_err(serde::de::Error::custom)
230 }
231}
232
233impl LossMatrix {
234 pub fn new(
240 state_names: Vec<String>,
241 action_names: Vec<String>,
242 values: Vec<f64>,
243 ) -> Result<Self, ValidationError> {
244 if state_names.is_empty() {
245 return Err(ValidationError::EmptySpace {
246 field: "state_names",
247 });
248 }
249 if action_names.is_empty() {
250 return Err(ValidationError::EmptySpace {
251 field: "action_names",
252 });
253 }
254 let expected = state_names.len() * action_names.len();
255 if values.len() != expected {
256 return Err(ValidationError::DimensionMismatch {
257 expected,
258 got: values.len(),
259 });
260 }
261 let n_actions = action_names.len();
262 for (i, &v) in values.iter().enumerate() {
263 if !v.is_finite() {
264 return Err(ValidationError::InvalidLoss {
265 state: i / n_actions,
266 action: i % n_actions,
267 value: v,
268 });
269 }
270 if v < 0.0 {
271 return Err(ValidationError::NegativeLoss {
272 state: i / n_actions,
273 action: i % n_actions,
274 value: v,
275 });
276 }
277 }
278 Ok(Self {
279 state_names,
280 action_names,
281 values,
282 })
283 }
284
285 pub fn get(&self, state: usize, action: usize) -> f64 {
287 self.values[state * self.action_names.len() + action]
288 }
289
290 pub fn n_states(&self) -> usize {
292 self.state_names.len()
293 }
294
295 pub fn n_actions(&self) -> usize {
297 self.action_names.len()
298 }
299
300 pub fn state_names(&self) -> &[String] {
302 &self.state_names
303 }
304
305 pub fn action_names(&self) -> &[String] {
307 &self.action_names
308 }
309
310 pub fn expected_loss(&self, posterior: &Posterior, action: usize) -> f64 {
314 posterior
315 .probs()
316 .iter()
317 .enumerate()
318 .map(|(s, &p)| p * self.get(s, action))
319 .sum()
320 }
321
322 pub fn expected_losses(&self, posterior: &Posterior) -> BTreeMap<String, f64> {
324 self.action_names
325 .iter()
326 .enumerate()
327 .map(|(a, name)| (name.clone(), self.expected_loss(posterior, a)))
328 .collect()
329 }
330
331 pub fn bayes_action(&self, posterior: &Posterior) -> usize {
335 (0..self.action_names.len())
336 .min_by(|&a, &b| {
337 self.expected_loss(posterior, a)
338 .partial_cmp(&self.expected_loss(posterior, b))
339 .unwrap_or(std::cmp::Ordering::Equal)
340 })
341 .unwrap_or(0)
342 }
343}
344
/// Largest accepted deviation of a posterior's probability sum from exactly 1.0
/// (checked by `Posterior::new`).
const NORMALIZATION_TOLERANCE: f64 = 1.0e-6;
351
352#[derive(Clone, Debug, Serialize, PartialEq)]
356pub struct Posterior {
357 probs: Vec<f64>,
358}
359
360#[derive(Deserialize)]
361struct PosteriorRepr {
362 probs: Vec<f64>,
363}
364
365impl<'de> Deserialize<'de> for Posterior {
366 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
367 where
368 D: Deserializer<'de>,
369 {
370 let repr = PosteriorRepr::deserialize(deserializer)?;
371 Self::new(repr.probs).map_err(serde::de::Error::custom)
372 }
373}
374
375impl Posterior {
376 pub fn new(probs: Vec<f64>) -> Result<Self, ValidationError> {
380 for (index, &value) in probs.iter().enumerate() {
381 if !value.is_finite() || value < 0.0 {
382 return Err(ValidationError::InvalidPosteriorProbability { index, value });
383 }
384 }
385 let sum: f64 = probs.iter().sum();
386 if (sum - 1.0).abs() > NORMALIZATION_TOLERANCE {
387 return Err(ValidationError::PosteriorNotNormalized { sum });
388 }
389 Ok(Self { probs })
390 }
391
392 #[allow(clippy::cast_precision_loss)]
394 pub fn uniform(n: usize) -> Self {
395 let p = 1.0 / n as f64;
396 Self { probs: vec![p; n] }
397 }
398
399 pub fn probs(&self) -> &[f64] {
401 &self.probs
402 }
403
404 pub fn probs_mut(&mut self) -> &mut [f64] {
406 &mut self.probs
407 }
408
409 pub fn len(&self) -> usize {
411 self.probs.len()
412 }
413
414 pub fn is_empty(&self) -> bool {
416 self.probs.is_empty()
417 }
418
419 pub fn bayesian_update(&mut self, likelihoods: &[f64]) {
428 assert_eq!(likelihoods.len(), self.probs.len());
429 for (p, &l) in self.probs.iter_mut().zip(likelihoods) {
430 *p *= l;
431 }
432 self.normalize();
433 }
434
435 pub fn normalize(&mut self) {
437 let sum: f64 = self.probs.iter().sum();
438 if sum > 0.0 {
439 for p in &mut self.probs {
440 *p /= sum;
441 }
442 }
443 }
444
445 pub fn entropy(&self) -> f64 {
447 self.probs
448 .iter()
449 .filter(|&&p| p > 0.0)
450 .map(|&p| -p * p.log2())
451 .sum()
452 }
453
454 pub fn map_state(&self) -> usize {
458 self.probs
459 .iter()
460 .enumerate()
461 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
462 .map_or(0, |(i, _)| i)
463 }
464}
465
466#[derive(Clone, Debug, Serialize, PartialEq)]
475pub struct FallbackPolicy {
476 pub calibration_drift_threshold: f64,
478 pub e_process_breach_threshold: f64,
480 pub confidence_width_threshold: f64,
482}
483
484#[derive(Deserialize)]
485#[allow(clippy::struct_field_names)]
486struct FallbackPolicyRepr {
487 calibration_drift_threshold: f64,
488 e_process_breach_threshold: f64,
489 confidence_width_threshold: f64,
490}
491
492impl<'de> Deserialize<'de> for FallbackPolicy {
493 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
494 where
495 D: Deserializer<'de>,
496 {
497 let repr = FallbackPolicyRepr::deserialize(deserializer)?;
498 Self::new(
499 repr.calibration_drift_threshold,
500 repr.e_process_breach_threshold,
501 repr.confidence_width_threshold,
502 )
503 .map_err(serde::de::Error::custom)
504 }
505}
506
507impl FallbackPolicy {
508 pub fn new(
513 calibration_drift_threshold: f64,
514 e_process_breach_threshold: f64,
515 confidence_width_threshold: f64,
516 ) -> Result<Self, ValidationError> {
517 if !calibration_drift_threshold.is_finite()
518 || !(0.0..=1.0).contains(&calibration_drift_threshold)
519 {
520 return Err(ValidationError::ThresholdOutOfRange {
521 field: "calibration_drift_threshold",
522 value: calibration_drift_threshold,
523 });
524 }
525 if !e_process_breach_threshold.is_finite() || e_process_breach_threshold < 0.0 {
526 return Err(ValidationError::ThresholdOutOfRange {
527 field: "e_process_breach_threshold",
528 value: e_process_breach_threshold,
529 });
530 }
531 if !confidence_width_threshold.is_finite() || confidence_width_threshold < 0.0 {
532 return Err(ValidationError::ThresholdOutOfRange {
533 field: "confidence_width_threshold",
534 value: confidence_width_threshold,
535 });
536 }
537 Ok(Self {
538 calibration_drift_threshold,
539 e_process_breach_threshold,
540 confidence_width_threshold,
541 })
542 }
543
544 pub fn should_fallback(&self, calibration_score: f64, e_process: f64, ci_width: f64) -> bool {
546 calibration_score < self.calibration_drift_threshold
547 || e_process > self.e_process_breach_threshold
548 || ci_width > self.confidence_width_threshold
549 }
550}
551
552impl Default for FallbackPolicy {
553 fn default() -> Self {
554 Self {
555 calibration_drift_threshold: 0.7,
556 e_process_breach_threshold: 20.0,
557 confidence_width_threshold: 0.5,
558 }
559 }
560}
561
562pub trait DecisionContract {
572 fn name(&self) -> &str;
574
575 fn state_space(&self) -> &[String];
577
578 fn action_set(&self) -> &[String];
580
581 fn loss_matrix(&self) -> &LossMatrix;
583
584 fn update_posterior(&self, posterior: &mut Posterior, state_index: usize);
586
587 fn choose_action(&self, posterior: &Posterior) -> usize;
591
592 fn fallback_action(&self) -> usize;
596
597 fn fallback_policy(&self) -> &FallbackPolicy;
599}
600
601#[derive(Clone, Debug, Serialize, Deserialize)]
610pub struct DecisionAuditEntry {
611 pub decision_id: DecisionId,
613 pub trace_id: TraceId,
615 pub contract_name: String,
617 pub action_chosen: String,
619 pub expected_loss: f64,
621 pub calibration_score: f64,
623 pub fallback_active: bool,
625 pub posterior_snapshot: Vec<f64>,
627 pub expected_loss_by_action: BTreeMap<String, f64>,
629 pub ts_unix_ms: u64,
631}
632
633impl DecisionAuditEntry {
634 pub fn to_evidence_ledger(&self) -> EvidenceLedger {
636 let mut builder = EvidenceLedgerBuilder::new()
637 .ts_unix_ms(self.ts_unix_ms)
638 .component(&self.contract_name)
639 .action(&self.action_chosen)
640 .posterior(self.posterior_snapshot.clone())
641 .chosen_expected_loss(self.expected_loss)
642 .calibration_score(self.calibration_score)
643 .fallback_active(self.fallback_active);
644
645 for (action, &loss) in &self.expected_loss_by_action {
646 builder = builder.expected_loss(action, loss);
647 }
648
649 builder
650 .build()
651 .expect("audit entry should produce valid evidence ledger")
652 }
653}
654
655#[derive(Clone, Debug)]
661pub struct DecisionOutcome {
662 pub action_index: usize,
664 pub action_name: String,
666 pub expected_loss: f64,
668 pub expected_losses: BTreeMap<String, f64>,
670 pub fallback_active: bool,
672 pub audit_entry: DecisionAuditEntry,
674}
675
676#[derive(Clone, Debug)]
685pub struct EvalContext {
686 pub calibration_score: f64,
688 pub e_process: f64,
690 pub ci_width: f64,
692 pub decision_id: DecisionId,
694 pub trace_id: TraceId,
696 pub ts_unix_ms: u64,
698}
699
700pub fn evaluate<C: DecisionContract>(
710 contract: &C,
711 posterior: &Posterior,
712 ctx: &EvalContext,
713) -> DecisionOutcome {
714 let loss_matrix = contract.loss_matrix();
715 let expected_losses = loss_matrix.expected_losses(posterior);
716
717 let fallback_active = contract.fallback_policy().should_fallback(
718 ctx.calibration_score,
719 ctx.e_process,
720 ctx.ci_width,
721 );
722
723 let action_index = if fallback_active {
724 contract.fallback_action()
725 } else {
726 contract.choose_action(posterior)
727 };
728
729 let action_name = contract.action_set()[action_index].clone();
730 let expected_loss = expected_losses[&action_name];
731
732 let audit_entry = DecisionAuditEntry {
733 decision_id: ctx.decision_id,
734 trace_id: ctx.trace_id,
735 contract_name: contract.name().to_string(),
736 action_chosen: action_name.clone(),
737 expected_loss,
738 calibration_score: ctx.calibration_score,
739 fallback_active,
740 posterior_snapshot: posterior.probs().to_vec(),
741 expected_loss_by_action: expected_losses.clone(),
742 ts_unix_ms: ctx.ts_unix_ms,
743 };
744
745 DecisionOutcome {
746 action_index,
747 action_name,
748 expected_loss,
749 expected_losses,
750 fallback_active,
751 audit_entry,
752 }
753}
754
755#[cfg(test)]
760#[allow(clippy::float_cmp)]
761mod tests {
762 use super::*;
763
764 fn two_state_matrix() -> LossMatrix {
767 LossMatrix::new(
771 vec!["good".into(), "bad".into()],
772 vec!["continue".into(), "stop".into()],
773 vec![0.0, 0.3, 0.8, 0.1],
774 )
775 .unwrap()
776 }
777
778 struct TestContract {
779 states: Vec<String>,
780 actions: Vec<String>,
781 losses: LossMatrix,
782 policy: FallbackPolicy,
783 }
784
785 impl TestContract {
786 fn new() -> Self {
787 Self {
788 states: vec!["good".into(), "bad".into()],
789 actions: vec!["continue".into(), "stop".into()],
790 losses: two_state_matrix(),
791 policy: FallbackPolicy::default(),
792 }
793 }
794 }
795
796 #[allow(clippy::unnecessary_literal_bound)]
797 impl DecisionContract for TestContract {
798 fn name(&self) -> &str {
799 "test_contract"
800 }
801 fn state_space(&self) -> &[String] {
802 &self.states
803 }
804 fn action_set(&self) -> &[String] {
805 &self.actions
806 }
807 fn loss_matrix(&self) -> &LossMatrix {
808 &self.losses
809 }
810 fn update_posterior(&self, posterior: &mut Posterior, observation: usize) {
811 let mut likelihoods = vec![0.1; self.states.len()];
813 likelihoods[observation] = 0.9;
814 posterior.bayesian_update(&likelihoods);
815 }
816 fn choose_action(&self, posterior: &Posterior) -> usize {
817 self.losses.bayes_action(posterior)
818 }
819 fn fallback_action(&self) -> usize {
820 0 }
822 fn fallback_policy(&self) -> &FallbackPolicy {
823 &self.policy
824 }
825 }
826
827 #[test]
830 fn loss_matrix_creation() {
831 let m = two_state_matrix();
832 assert_eq!(m.n_states(), 2);
833 assert_eq!(m.n_actions(), 2);
834 assert_eq!(m.get(0, 0), 0.0);
835 assert_eq!(m.get(0, 1), 0.3);
836 assert_eq!(m.get(1, 0), 0.8);
837 assert_eq!(m.get(1, 1), 0.1);
838 }
839
840 #[test]
841 fn loss_matrix_empty_states_rejected() {
842 let err = LossMatrix::new(vec![], vec!["a".into()], vec![]).unwrap_err();
843 assert!(matches!(
844 err,
845 ValidationError::EmptySpace {
846 field: "state_names"
847 }
848 ));
849 }
850
851 #[test]
852 fn loss_matrix_empty_actions_rejected() {
853 let err = LossMatrix::new(vec!["s".into()], vec![], vec![]).unwrap_err();
854 assert!(matches!(
855 err,
856 ValidationError::EmptySpace {
857 field: "action_names"
858 }
859 ));
860 }
861
862 #[test]
863 fn loss_matrix_dimension_mismatch() {
864 let err = LossMatrix::new(
865 vec!["s1".into(), "s2".into()],
866 vec!["a1".into()],
867 vec![0.1], )
869 .unwrap_err();
870 assert!(matches!(
871 err,
872 ValidationError::DimensionMismatch {
873 expected: 2,
874 got: 1
875 }
876 ));
877 }
878
879 #[test]
880 fn loss_matrix_negative_rejected() {
881 let err = LossMatrix::new(vec!["s".into()], vec!["a".into()], vec![-0.5]).unwrap_err();
882 assert!(matches!(
883 err,
884 ValidationError::NegativeLoss {
885 state: 0,
886 action: 0,
887 ..
888 }
889 ));
890 }
891
892 #[test]
893 fn loss_matrix_non_finite_rejected() {
894 let err = LossMatrix::new(vec!["s".into()], vec!["a".into()], vec![f64::NAN]).unwrap_err();
895 assert!(matches!(
896 err,
897 ValidationError::InvalidLoss {
898 state: 0,
899 action: 0,
900 value
901 } if value.is_nan()
902 ));
903 }
904
905 #[test]
906 fn loss_matrix_expected_loss() {
907 let m = two_state_matrix();
908 let posterior = Posterior::new(vec![0.8, 0.2]).unwrap();
909 let el_continue = m.expected_loss(&posterior, 0);
911 assert!((el_continue - 0.16).abs() < 1e-10);
912 let el_stop = m.expected_loss(&posterior, 1);
914 assert!((el_stop - 0.26).abs() < 1e-10);
915 }
916
917 #[test]
918 fn loss_matrix_bayes_action() {
919 let m = two_state_matrix();
920 let mostly_good = Posterior::new(vec![0.9, 0.1]).unwrap();
922 assert_eq!(m.bayes_action(&mostly_good), 0); let mostly_bad = Posterior::new(vec![0.2, 0.8]).unwrap();
925 assert_eq!(m.bayes_action(&mostly_bad), 1); }
927
928 #[test]
929 fn loss_matrix_expected_losses_map() {
930 let m = two_state_matrix();
931 let posterior = Posterior::uniform(2);
932 let losses = m.expected_losses(&posterior);
933 assert_eq!(losses.len(), 2);
934 assert!(losses.contains_key("continue"));
935 assert!(losses.contains_key("stop"));
936 }
937
938 #[test]
939 fn loss_matrix_names() {
940 let m = two_state_matrix();
941 assert_eq!(m.state_names(), &["good", "bad"]);
942 assert_eq!(m.action_names(), &["continue", "stop"]);
943 }
944
945 #[test]
946 fn loss_matrix_toml_roundtrip() {
947 let m = two_state_matrix();
948 let toml_str = toml::to_string(&m).unwrap();
949 let parsed: LossMatrix = toml::from_str(&toml_str).unwrap();
950 assert_eq!(m, parsed);
951 }
952
953 #[test]
954 fn loss_matrix_json_roundtrip() {
955 let m = two_state_matrix();
956 let json = serde_json::to_string(&m).unwrap();
957 let parsed: LossMatrix = serde_json::from_str(&json).unwrap();
958 assert_eq!(m, parsed);
959 }
960
961 #[test]
962 fn loss_matrix_json_invalid_value_rejected_at_deserialize() {
963 let json = r#"{"state_names":["s"],"action_names":["a"],"values":[-0.5]}"#;
964 let err = serde_json::from_str::<LossMatrix>(json).unwrap_err();
965 assert!(err.to_string().contains("negative loss"));
966 }
967
968 #[test]
971 fn posterior_uniform() {
972 let p = Posterior::uniform(4);
973 assert_eq!(p.len(), 4);
974 for &v in p.probs() {
975 assert!((v - 0.25).abs() < 1e-10);
976 }
977 }
978
979 #[test]
980 fn posterior_new_valid() {
981 let p = Posterior::new(vec![0.3, 0.7]).unwrap();
982 assert_eq!(p.probs(), &[0.3, 0.7]);
983 }
984
985 #[test]
986 fn posterior_new_not_normalized() {
987 let err = Posterior::new(vec![0.5, 0.3]).unwrap_err();
988 assert!(matches!(
989 err,
990 ValidationError::PosteriorNotNormalized { .. }
991 ));
992 }
993
994 #[test]
995 fn posterior_new_negative_probability_rejected() {
996 let err = Posterior::new(vec![-0.1, 1.1]).unwrap_err();
997 assert!(matches!(
998 err,
999 ValidationError::InvalidPosteriorProbability {
1000 index: 0,
1001 value
1002 } if value == -0.1
1003 ));
1004 }
1005
1006 #[test]
1007 fn posterior_new_non_finite_probability_rejected() {
1008 let err = Posterior::new(vec![f64::NAN, 1.0]).unwrap_err();
1009 assert!(matches!(
1010 err,
1011 ValidationError::InvalidPosteriorProbability {
1012 index: 0,
1013 value
1014 } if value.is_nan()
1015 ));
1016 }
1017
1018 #[test]
1019 fn posterior_bayesian_update() {
1020 let mut p = Posterior::uniform(2);
1021 p.bayesian_update(&[0.9, 0.1]);
1023 assert!((p.probs()[0] - 0.9).abs() < 1e-10);
1025 assert!((p.probs()[1] - 0.1).abs() < 1e-10);
1026 }
1027
1028 #[test]
1029 fn posterior_bayesian_update_no_alloc() {
1030 let mut p = Posterior::uniform(3);
1032 let ptr_before = p.probs().as_ptr();
1033 p.bayesian_update(&[0.5, 0.3, 0.2]);
1034 let ptr_after = p.probs().as_ptr();
1035 assert_eq!(ptr_before, ptr_after);
1036 }
1037
1038 #[test]
1039 fn posterior_entropy() {
1040 let p = Posterior::uniform(2);
1042 assert!((p.entropy() - 1.0).abs() < 1e-10);
1043 let det = Posterior::new(vec![1.0, 0.0]).unwrap();
1045 assert!((det.entropy()).abs() < 1e-10);
1046 }
1047
1048 #[test]
1049 fn posterior_map_state() {
1050 let p = Posterior::new(vec![0.1, 0.7, 0.2]).unwrap();
1051 assert_eq!(p.map_state(), 1);
1052 }
1053
1054 #[test]
1055 fn posterior_is_empty() {
1056 let p = Posterior { probs: vec![] };
1057 assert!(p.is_empty());
1058 let p2 = Posterior::uniform(1);
1059 assert!(!p2.is_empty());
1060 }
1061
1062 #[test]
1063 fn posterior_probs_mut() {
1064 let mut p = Posterior::uniform(2);
1065 p.probs_mut()[0] = 0.8;
1066 p.probs_mut()[1] = 0.2;
1067 assert_eq!(p.probs(), &[0.8, 0.2]);
1068 }
1069
1070 #[test]
1073 fn fallback_policy_default() {
1074 let fp = FallbackPolicy::default();
1075 assert_eq!(fp.calibration_drift_threshold, 0.7);
1076 assert_eq!(fp.e_process_breach_threshold, 20.0);
1077 assert_eq!(fp.confidence_width_threshold, 0.5);
1078 }
1079
1080 #[test]
1081 fn fallback_policy_new_valid() {
1082 let fp = FallbackPolicy::new(0.8, 10.0, 0.3).unwrap();
1083 assert_eq!(fp.calibration_drift_threshold, 0.8);
1084 }
1085
1086 #[test]
1087 fn fallback_policy_calibration_out_of_range() {
1088 let err = FallbackPolicy::new(1.5, 10.0, 0.3).unwrap_err();
1089 assert!(matches!(
1090 err,
1091 ValidationError::ThresholdOutOfRange {
1092 field: "calibration_drift_threshold",
1093 ..
1094 }
1095 ));
1096 }
1097
1098 #[test]
1099 fn fallback_policy_negative_e_process() {
1100 let err = FallbackPolicy::new(0.7, -1.0, 0.3).unwrap_err();
1101 assert!(matches!(
1102 err,
1103 ValidationError::ThresholdOutOfRange {
1104 field: "e_process_breach_threshold",
1105 ..
1106 }
1107 ));
1108 }
1109
1110 #[test]
1111 fn fallback_policy_negative_ci_width() {
1112 let err = FallbackPolicy::new(0.7, 10.0, -0.1).unwrap_err();
1113 assert!(matches!(
1114 err,
1115 ValidationError::ThresholdOutOfRange {
1116 field: "confidence_width_threshold",
1117 ..
1118 }
1119 ));
1120 }
1121
1122 #[test]
1123 fn fallback_policy_non_finite_e_process_rejected() {
1124 let err = FallbackPolicy::new(0.7, f64::NAN, 0.3).unwrap_err();
1125 assert!(matches!(
1126 err,
1127 ValidationError::ThresholdOutOfRange {
1128 field: "e_process_breach_threshold",
1129 value
1130 } if value.is_nan()
1131 ));
1132 }
1133
1134 #[test]
1135 fn fallback_policy_non_finite_ci_width_rejected() {
1136 let err = FallbackPolicy::new(0.7, 10.0, f64::INFINITY).unwrap_err();
1137 assert!(matches!(
1138 err,
1139 ValidationError::ThresholdOutOfRange {
1140 field: "confidence_width_threshold",
1141 value
1142 } if value.is_infinite()
1143 ));
1144 }
1145
1146 #[test]
1147 fn fallback_policy_json_invalid_threshold_rejected_at_deserialize() {
1148 let json = r#"{
1149 "calibration_drift_threshold": 0.7,
1150 "e_process_breach_threshold": -1.0,
1151 "confidence_width_threshold": 0.3
1152 }"#;
1153 let err = serde_json::from_str::<FallbackPolicy>(json).unwrap_err();
1154 assert!(err.to_string().contains("threshold"));
1155 }
1156
1157 #[test]
1158 fn fallback_triggered_by_low_calibration() {
1159 let fp = FallbackPolicy::default();
1160 assert!(fp.should_fallback(0.5, 1.0, 0.1)); assert!(!fp.should_fallback(0.9, 1.0, 0.1)); }
1163
1164 #[test]
1165 fn fallback_triggered_by_e_process() {
1166 let fp = FallbackPolicy::default();
1167 assert!(fp.should_fallback(0.9, 25.0, 0.1)); assert!(!fp.should_fallback(0.9, 15.0, 0.1)); }
1170
1171 #[test]
1172 fn fallback_triggered_by_ci_width() {
1173 let fp = FallbackPolicy::default();
1174 assert!(fp.should_fallback(0.9, 1.0, 0.6)); assert!(!fp.should_fallback(0.9, 1.0, 0.3)); }
1177
1178 #[test]
1181 fn contract_implementable_under_50_lines() {
1182 let contract = TestContract::new();
1184 assert_eq!(contract.name(), "test_contract");
1185 assert_eq!(contract.state_space().len(), 2);
1186 assert_eq!(contract.action_set().len(), 2);
1187 }
1188
1189 fn test_ctx(cal: f64, random: u128) -> EvalContext {
1190 EvalContext {
1191 calibration_score: cal,
1192 e_process: 1.0,
1193 ci_width: 0.1,
1194 decision_id: DecisionId::from_parts(1_700_000_000_000, random),
1195 trace_id: TraceId::from_parts(1_700_000_000_000, random),
1196 ts_unix_ms: 1_700_000_000_000,
1197 }
1198 }
1199
1200 #[test]
1201 fn evaluate_normal_decision() {
1202 let contract = TestContract::new();
1203 let posterior = Posterior::new(vec![0.9, 0.1]).unwrap();
1204 let ctx = test_ctx(0.95, 42);
1205
1206 let outcome = evaluate(&contract, &posterior, &ctx);
1207
1208 assert!(!outcome.fallback_active);
1209 assert_eq!(outcome.action_name, "continue"); assert_eq!(outcome.action_index, 0);
1211 assert!(outcome.expected_loss < 0.1);
1212 assert_eq!(outcome.expected_losses.len(), 2);
1213 }
1214
1215 #[test]
1216 fn evaluate_fallback_decision() {
1217 let contract = TestContract::new();
1218 let posterior = Posterior::new(vec![0.2, 0.8]).unwrap();
1219 let ctx = test_ctx(0.5, 43); let outcome = evaluate(&contract, &posterior, &ctx);
1222
1223 assert!(outcome.fallback_active);
1224 assert_eq!(outcome.action_name, "continue"); assert_eq!(outcome.action_index, 0);
1226 }
1227
1228 #[test]
1229 fn evaluate_without_fallback_chooses_optimal() {
1230 let contract = TestContract::new();
1231 let posterior = Posterior::new(vec![0.2, 0.8]).unwrap();
1232 let ctx = test_ctx(0.95, 44); let outcome = evaluate(&contract, &posterior, &ctx);
1235
1236 assert!(!outcome.fallback_active);
1237 assert_eq!(outcome.action_name, "stop"); }
1239
1240 #[test]
1241 fn evaluate_audit_entry_fields() {
1242 let contract = TestContract::new();
1243 let posterior = Posterior::uniform(2);
1244 let ctx = test_ctx(0.85, 99);
1245
1246 let outcome = evaluate(&contract, &posterior, &ctx);
1247
1248 let audit = &outcome.audit_entry;
1249 assert_eq!(audit.decision_id, ctx.decision_id);
1250 assert_eq!(audit.trace_id, ctx.trace_id);
1251 assert_eq!(audit.contract_name, "test_contract");
1252 assert_eq!(audit.calibration_score, 0.85);
1253 assert_eq!(audit.ts_unix_ms, 1_700_000_000_000);
1254 assert_eq!(audit.posterior_snapshot.len(), 2);
1255 }
1256
1257 #[test]
1260 fn audit_entry_to_evidence_ledger() {
1261 let contract = TestContract::new();
1262 let posterior = Posterior::new(vec![0.6, 0.4]).unwrap();
1263 let ctx = test_ctx(0.92, 100);
1264
1265 let outcome = evaluate(&contract, &posterior, &ctx);
1266 let evidence = outcome.audit_entry.to_evidence_ledger();
1267
1268 assert_eq!(evidence.ts_unix_ms, 1_700_000_000_000);
1269 assert_eq!(evidence.component, "test_contract");
1270 assert_eq!(evidence.action, outcome.action_name);
1271 assert_eq!(evidence.calibration_score, 0.92);
1272 assert!(!evidence.fallback_active);
1273 assert_eq!(evidence.posterior, vec![0.6, 0.4]);
1274 assert!(evidence.is_valid());
1275 }
1276
1277 #[test]
1278 fn audit_entry_serde_roundtrip() {
1279 let contract = TestContract::new();
1280 let posterior = Posterior::uniform(2);
1281 let ctx = test_ctx(0.88, 101);
1282
1283 let outcome = evaluate(&contract, &posterior, &ctx);
1284 let json = serde_json::to_string(&outcome.audit_entry).unwrap();
1285 let parsed: DecisionAuditEntry = serde_json::from_str(&json).unwrap();
1286 assert_eq!(parsed.contract_name, "test_contract");
1287 assert_eq!(parsed.decision_id, ctx.decision_id);
1288 assert_eq!(parsed.trace_id, ctx.trace_id);
1289 }
1290
1291 #[test]
1294 fn contract_update_posterior() {
1295 let contract = TestContract::new();
1296 let mut posterior = Posterior::uniform(2);
1297 contract.update_posterior(&mut posterior, 0); assert!(posterior.probs()[0] > posterior.probs()[1]);
1300 }
1301
1302 #[test]
1305 fn validation_error_display() {
1306 let err = ValidationError::NegativeLoss {
1307 state: 1,
1308 action: 2,
1309 value: -0.5,
1310 };
1311 let msg = format!("{err}");
1312 assert!(msg.contains("-0.5"));
1313 assert!(msg.contains("state=1"));
1314 assert!(msg.contains("action=2"));
1315 }
1316
1317 #[test]
1318 fn dimension_mismatch_display() {
1319 let err = ValidationError::DimensionMismatch {
1320 expected: 6,
1321 got: 4,
1322 };
1323 let msg = format!("{err}");
1324 assert!(msg.contains('6'));
1325 assert!(msg.contains('4'));
1326 }
1327
1328 #[test]
1331 fn fallback_policy_toml_roundtrip() {
1332 let fp = FallbackPolicy::default();
1333 let toml_str = toml::to_string(&fp).unwrap();
1334 let parsed: FallbackPolicy = toml::from_str(&toml_str).unwrap();
1335 assert_eq!(fp, parsed);
1336 }
1337
1338 #[test]
1339 fn fallback_policy_json_roundtrip() {
1340 let fp = FallbackPolicy::default();
1341 let json = serde_json::to_string(&fp).unwrap();
1342 let parsed: FallbackPolicy = serde_json::from_str(&json).unwrap();
1343 assert_eq!(fp, parsed);
1344 }
1345
1346 #[test]
1349 fn argmin_correctness_deterministic_posterior() {
1350 let m = two_state_matrix();
1351 let certain_good = Posterior::new(vec![1.0, 0.0]).unwrap();
1353 assert_eq!(m.bayes_action(&certain_good), 0);
1354 let certain_bad = Posterior::new(vec![0.0, 1.0]).unwrap();
1356 assert_eq!(m.bayes_action(&certain_bad), 1);
1357 }
1358
1359 #[test]
1360 fn argmin_correctness_breakeven_point() {
1361 let m = two_state_matrix();
1362 let above = Posterior::new(vec![0.71, 0.29]).unwrap();
1366 assert_eq!(m.bayes_action(&above), 0);
1367 let below = Posterior::new(vec![0.69, 0.31]).unwrap();
1369 assert_eq!(m.bayes_action(&below), 1);
1370 }
1371
1372 #[test]
1373 fn argmin_three_state_three_action() {
1374 let m = LossMatrix::new(
1376 vec!["s0".into(), "s1".into(), "s2".into()],
1377 vec!["a0".into(), "a1".into(), "a2".into()],
1378 vec![
1379 1.0, 2.0, 3.0, 3.0, 1.0, 2.0, 2.0, 3.0, 1.0, ],
1383 )
1384 .unwrap();
1385 let uniform = Posterior::uniform(3);
1388 let action = m.bayes_action(&uniform);
1389 assert!(action < 3);
1391 let state1 = Posterior::new(vec![0.0, 1.0, 0.0]).unwrap();
1393 assert_eq!(m.bayes_action(&state1), 1);
1394 let state2 = Posterior::new(vec![0.0, 0.0, 1.0]).unwrap();
1396 assert_eq!(m.bayes_action(&state2), 2);
1397 }
1398
1399 #[test]
1402 fn bayesian_update_hand_computed_three_state() {
1403 let mut p = Posterior::new(vec![0.5, 0.3, 0.2]).unwrap();
1408 p.bayesian_update(&[0.1, 0.6, 0.3]);
1409 let expected = [0.05 / 0.29, 0.18 / 0.29, 0.06 / 0.29];
1410 for (i, &e) in expected.iter().enumerate() {
1411 assert!(
1412 (p.probs()[i] - e).abs() < 1e-10,
1413 "state {i}: got {}, expected {e}",
1414 p.probs()[i]
1415 );
1416 }
1417 }
1418
1419 #[test]
1420 fn bayesian_update_successive_convergence() {
1421 let mut p = Posterior::uniform(3);
1423 for _ in 0..20 {
1424 p.bayesian_update(&[0.9, 0.05, 0.05]);
1425 }
1426 assert!(p.probs()[0] > 0.999);
1427 assert!(p.probs()[1] < 0.001);
1428 assert!(p.probs()[2] < 0.001);
1429 }
1430
1431 #[test]
1434 fn end_to_end_pipeline() {
1435 let contract = TestContract::new();
1436 let mut posterior = Posterior::uniform(2);
1437
1438 for _ in 0..5 {
1440 contract.update_posterior(&mut posterior, 0);
1441 }
1442 assert!(posterior.probs()[0] > 0.99);
1443
1444 let ctx = test_ctx(0.95, 200);
1446 let outcome = evaluate(&contract, &posterior, &ctx);
1447 assert!(!outcome.fallback_active);
1448 assert_eq!(outcome.action_name, "continue");
1449 assert!(outcome.expected_loss < 0.01);
1450
1451 let evidence = outcome.audit_entry.to_evidence_ledger();
1453 assert_eq!(evidence.component, "test_contract");
1454 assert_eq!(evidence.action, "continue");
1455 assert!(evidence.is_valid());
1456
1457 for _ in 0..20 {
1459 contract.update_posterior(&mut posterior, 1);
1460 }
1461 assert!(posterior.probs()[1] > 0.99);
1462
1463 let ctx2 = test_ctx(0.95, 201);
1465 let outcome2 = evaluate(&contract, &posterior, &ctx2);
1466 assert_eq!(outcome2.action_name, "stop");
1467 }
1468
/// Evaluates one shared contract from ten threads and checks that every thread
/// produces a valid outcome and that all threads agree on the chosen action.
///
/// All worker threads are spawned before any of them is joined. A single lazy
/// `map(spawn).map(join)` chain would spawn and join one thread at a time, so the
/// workers would run sequentially and concurrent access would never be exercised.
#[test]
fn concurrent_decision_safety() {
    use std::collections::HashSet;
    use std::sync::Arc;
    use std::thread;

    let contract = Arc::new(TestContract::new());

    // Collect the handles eagerly so all ten workers are live at the same time.
    let handles: Vec<_> = (0..10)
        .map(|i| {
            let c = Arc::clone(&contract);
            thread::spawn(move || {
                let posterior = Posterior::uniform(2);
                let ctx = EvalContext {
                    calibration_score: 0.9,
                    e_process: 1.0,
                    ci_width: 0.1,
                    decision_id: DecisionId::from_parts(1_700_000_000_000, u128::from(i)),
                    trace_id: TraceId::from_parts(1_700_000_000_000, u128::from(i)),
                    ts_unix_ms: 1_700_000_000_000 + i,
                };
                let outcome = evaluate(c.as_ref(), &posterior, &ctx);
                assert!(!outcome.action_name.is_empty());
                assert_eq!(outcome.expected_losses.len(), 2);
                let evidence = outcome.audit_entry.to_evidence_ledger();
                assert!(evidence.is_valid());
                outcome
            })
        })
        .collect();

    // Join only after every worker has been started.
    let results: Vec<_> = handles
        .into_iter()
        .map(|h| h.join().expect("decision worker thread panicked"))
        .collect();
    assert_eq!(results.len(), 10);

    let actions: HashSet<_> = results.iter().map(|r| r.action_name.clone()).collect();
    assert_eq!(
        actions.len(),
        1,
        "all threads should choose the same action"
    );
}
1510
/// Identifiers from `franken_kernel` must report the timestamp they were built
/// from and must reach the audit entry of an evaluation unchanged.
#[test]
fn cross_crate_franken_kernel_types() {
    let decision_id = DecisionId::from_parts(1_700_000_000_000, 42);
    assert_eq!(decision_id.timestamp_ms(), 1_700_000_000_000);
    let trace_id = TraceId::from_parts(1_700_000_000_000, 1);
    assert_eq!(trace_id.timestamp_ms(), 1_700_000_000_000);

    let contract = TestContract::new();
    let posterior = Posterior::uniform(2);
    let ctx = EvalContext {
        calibration_score: 0.9,
        e_process: 1.0,
        ci_width: 0.1,
        decision_id,
        trace_id,
        ts_unix_ms: 1_700_000_000_000,
    };

    // The ids are copied into ctx above and must come back out of the audit entry.
    let audit = evaluate(&contract, &posterior, &ctx).audit_entry;
    assert_eq!(audit.decision_id, decision_id);
    assert_eq!(audit.trace_id, trace_id);
}
1536
/// A valid posterior survives a JSON serialize/deserialize cycle unchanged.
#[test]
fn posterior_json_roundtrip() {
    let original = Posterior::new(vec![0.25, 0.75]).unwrap();
    let encoded = serde_json::to_string(&original).unwrap();
    let restored = serde_json::from_str::<Posterior>(&encoded).unwrap();
    assert_eq!(original, restored);
}
1546
/// Deserialization validates each entry: a negative probability is rejected with
/// the per-element error message instead of producing a `Posterior`.
#[test]
fn posterior_json_invalid_value_rejected_at_deserialize() {
    let payload = r#"{"probs":[-0.1,1.1]}"#;
    let message = serde_json::from_str::<Posterior>(payload)
        .unwrap_err()
        .to_string();
    assert!(message.contains("finite and non-negative"));
}
1553
/// A 3x3 loss matrix survives a TOML serialize/deserialize cycle unchanged.
#[test]
fn loss_matrix_3x3_toml_roundtrip() {
    let losses = vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
    let matrix = LossMatrix::new(
        vec!["s0".into(), "s1".into(), "s2".into()],
        vec!["a0".into(), "a1".into(), "a2".into()],
        losses,
    )
    .unwrap();

    let encoded = toml::to_string(&matrix).unwrap();
    let decoded = toml::from_str::<LossMatrix>(&encoded).unwrap();
    assert_eq!(matrix, decoded);
}
1568
/// The `Debug` rendering of an outcome names the type and its `action_name` field.
#[test]
fn decision_outcome_debug() {
    let contract = TestContract::new();
    let posterior = Posterior::uniform(2);
    let ctx = test_ctx(0.9, 300);
    let rendered = format!("{:?}", evaluate(&contract, &posterior, &ctx));
    for needle in ["DecisionOutcome", "action_name"] {
        assert!(rendered.contains(needle));
    }
}
1581
/// The default policy must request fallback when all three monitored signals are
/// pushed past their thresholds at the same time.
#[test]
fn fallback_multiple_triggers_simultaneously() {
    let policy = FallbackPolicy::default();
    let triggered = policy.should_fallback(0.3, 30.0, 0.9);
    assert!(triggered);
}
1590
/// Signals sitting exactly on the default policy's thresholds must not trigger
/// fallback.
#[test]
fn fallback_no_trigger_at_exact_thresholds() {
    let (calibration, e_process, ci_width) = (0.7, 20.0, 0.5);
    let triggered = FallbackPolicy::default().should_fallback(calibration, e_process, ci_width);
    assert!(!triggered);
}
1597
/// A uniform posterior over three states has entropy `log2(3)`, i.e. the
/// expected value here is expressed in bits.
#[test]
fn posterior_entropy_three_state_uniform() {
    let expected = 3.0_f64.log2();
    let actual = Posterior::uniform(3).entropy();
    assert!((actual - expected).abs() < 1e-10);
}
1606
/// A point mass on a single state carries no uncertainty, so its entropy is zero.
#[test]
fn posterior_entropy_single_state() {
    let entropy = Posterior::new(vec![1.0]).unwrap().entropy();
    assert!(entropy.abs() < 1e-10);
}
1612
/// Compile-time check that `ValidationError` is a proper `std::error::Error`.
///
/// The test also requires `Send + Sync + 'static`, so the error can be boxed as
/// `Box<dyn Error + Send + Sync>` and propagated across threads with `?`. Every
/// variant holds only `usize`, `f64` or `&'static str` data, so these bounds must
/// hold; this test keeps them from regressing if a variant is added later.
#[test]
fn validation_error_is_std_error() {
    fn assert_error<E: std::error::Error>() {}
    fn assert_thread_safe_error<E: std::error::Error + Send + Sync + 'static>() {}
    assert_error::<ValidationError>();
    assert_thread_safe_error::<ValidationError>();
}
1620}
1621
#[cfg(test)]
// NOTE(review): no direct float `==` comparisons remain in this module; this allow
// looks precautionary. Confirm before removing.
#[allow(clippy::float_cmp)]
mod proptest_tests {
    use super::*;
    use proptest::prelude::*;

    /// Strategy producing a valid `n`-state posterior.
    ///
    /// Raw weights are drawn from `[0.01, 1.0]`, so every state gets strictly
    /// positive mass. The vector is then divided by its sum.
    fn arb_posterior(n: usize) -> impl Strategy<Value = Posterior> {
        proptest::collection::vec(0.01_f64..=1.0, n).prop_map(|mut v| {
            let sum: f64 = v.iter().sum();
            for p in &mut v {
                *p /= sum;
            }
            Posterior::new(v).unwrap()
        })
    }

    /// Strategy producing an `n_states x n_actions` loss matrix.
    ///
    /// States are named `s0..` and actions `a0..`; losses are drawn from
    /// `[0.0, 10.0]`.
    fn arb_loss_matrix(n_states: usize, n_actions: usize) -> impl Strategy<Value = LossMatrix> {
        let states: Vec<String> = (0..n_states).map(|i| format!("s{i}")).collect();
        let actions: Vec<String> = (0..n_actions).map(|i| format!("a{i}")).collect();
        proptest::collection::vec(0.0_f64..=10.0, n_states * n_actions).prop_map(move |values| {
            LossMatrix::new(states.clone(), actions.clone(), values).unwrap()
        })
    }

    /// Shared property: the action returned by `bayes_action` has an expected loss
    /// no greater than any other action's, up to `1e-10` of rounding slack.
    fn check_bayes_action_is_optimal(
        matrix: &LossMatrix,
        posterior: &Posterior,
    ) -> Result<(), TestCaseError> {
        let chosen = matrix.bayes_action(posterior);
        let chosen_loss = matrix.expected_loss(posterior, chosen);
        for a in 0..matrix.n_actions() {
            let other_loss = matrix.expected_loss(posterior, a);
            prop_assert!(
                chosen_loss <= other_loss + 1e-10,
                "action {chosen} (loss {chosen_loss}) should be <= action {a} (loss {other_loss})"
            );
        }
        Ok(())
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(10_000))]

        // Bayes-optimality on 3x3 problems.
        #[test]
        fn bayes_action_minimizes_expected_loss(
            matrix in arb_loss_matrix(3, 3),
            posterior in arb_posterior(3),
        ) {
            check_bayes_action_is_optimal(&matrix, &posterior)?;
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(10_000))]

        // Bayes-optimality on the smallest non-trivial (2x2) problems.
        #[test]
        fn bayes_action_minimizes_2x2(
            matrix in arb_loss_matrix(2, 2),
            posterior in arb_posterior(2),
        ) {
            check_bayes_action_is_optimal(&matrix, &posterior)?;
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(10_000))]

        // With strictly positive likelihoods, an update keeps the posterior
        // normalized and non-negative.
        #[test]
        fn bayesian_update_preserves_normalization(
            prior in arb_posterior(4),
            likelihoods in proptest::collection::vec(0.01_f64..=1.0, 4usize),
        ) {
            let mut p = prior;
            p.bayesian_update(&likelihoods);
            let sum: f64 = p.probs().iter().sum();
            prop_assert!(
                (sum - 1.0).abs() < 1e-10,
                "posterior sum = {sum}, expected 1.0"
            );
            for &prob in p.probs() {
                prop_assert!(prob >= 0.0, "negative probability: {prob}");
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(10_000))]

        // Likelihoods may include zeros; the posterior must stay non-negative.
        #[test]
        fn posterior_all_non_negative_after_update(
            prior in arb_posterior(3),
            likelihoods in proptest::collection::vec(0.0_f64..=1.0, 3usize),
        ) {
            // All-zero likelihoods are outside this property's scope. Reject them
            // with `prop_assume!` so proptest counts the rejection, instead of the
            // case silently passing without any assertion.
            let lik_sum: f64 = likelihoods.iter().sum();
            prop_assume!(lik_sum > 0.0);

            let mut p = prior;
            p.bayesian_update(&likelihoods);
            for &prob in p.probs() {
                prop_assert!(prob >= 0.0, "negative probability: {prob}");
            }
        }
    }

    proptest! {
        // Thresholds survive a JSON round trip. The small tolerance absorbs
        // possible last-digit float parsing differences.
        #[test]
        fn fallback_policy_serde_roundtrip(
            cal in 0.0_f64..=1.0,
            e_proc in 0.0_f64..=100.0,
            ci in 0.0_f64..=10.0,
        ) {
            let fp = FallbackPolicy::new(cal, e_proc, ci).unwrap();
            let json = serde_json::to_string(&fp).unwrap();
            let parsed: FallbackPolicy = serde_json::from_str(&json).unwrap();
            prop_assert!((fp.calibration_drift_threshold - parsed.calibration_drift_threshold).abs() < 1e-12);
            prop_assert!((fp.e_process_breach_threshold - parsed.e_process_breach_threshold).abs() < 1e-12);
            prop_assert!((fp.confidence_width_threshold - parsed.confidence_width_threshold).abs() < 1e-12);
        }
    }

    proptest! {
        // Names and every loss entry survive a JSON round trip.
        #[test]
        fn loss_matrix_serde_roundtrip(
            matrix in arb_loss_matrix(2, 3),
        ) {
            let json = serde_json::to_string(&matrix).unwrap();
            let parsed: LossMatrix = serde_json::from_str(&json).unwrap();
            prop_assert_eq!(matrix.state_names(), parsed.state_names());
            prop_assert_eq!(matrix.action_names(), parsed.action_names());
            for s in 0..matrix.n_states() {
                for a in 0..matrix.n_actions() {
                    prop_assert!((matrix.get(s, a) - parsed.get(s, a)).abs() < 1e-12);
                }
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(10_000))]

        // An expected loss is a convex combination of its action's column, so it
        // must lie between that column's minimum and maximum.
        #[test]
        fn expected_loss_within_loss_range(
            matrix in arb_loss_matrix(3, 3),
            posterior in arb_posterior(3),
        ) {
            for a in 0..matrix.n_actions() {
                let el = matrix.expected_loss(&posterior, a);
                let min_loss = (0..matrix.n_states())
                    .map(|s| matrix.get(s, a))
                    .fold(f64::INFINITY, f64::min);
                let max_loss = (0..matrix.n_states())
                    .map(|s| matrix.get(s, a))
                    .fold(f64::NEG_INFINITY, f64::max);
                prop_assert!(
                    el >= min_loss - 1e-10 && el <= max_loss + 1e-10,
                    "expected loss {el} outside [{min_loss}, {max_loss}]"
                );
            }
        }
    }
}