Skip to main content

franken_decision/
lib.rs

1//! Decision Contract schema and runtime for FrankenSuite (bd-3ai21).
2//!
3//! The third leg of the foundation tripod alongside `franken_kernel` (types)
4//! and `franken_evidence` (audit ledger). Every FrankenSuite project that
5//! makes runtime decisions uses this crate's contract schema.
6//!
7//! # Core abstractions
8//!
9//! - [`DecisionContract`] — trait defining state space, actions, losses, and
10//!   posterior updates. Implementable in <50 lines.
11//! - [`LossMatrix`] — non-negative loss values indexed by (state, action),
12//!   serializable to TOML for runtime reconfiguration.
13//! - [`Posterior`] — discrete probability distribution with O(|S|)
14//!   no-allocation Bayesian updates.
15//! - [`FallbackPolicy`] — calibration drift, e-process breach, and
16//!   confidence interval width thresholds.
17//! - [`DecisionAuditEntry`] — links decisions to [`EvidenceLedger`] entries.
18//!
19//! # Example
20//!
21//! ```
22//! use franken_decision::{
23//!     DecisionContract, EvalContext, FallbackPolicy, LossMatrix, Posterior, evaluate,
24//! };
25//! use franken_kernel::DecisionId;
26//!
27//! // Define a simple 2-state, 2-action contract.
28//! struct MyContract {
29//!     states: Vec<String>,
30//!     actions: Vec<String>,
31//!     losses: LossMatrix,
32//!     policy: FallbackPolicy,
33//! }
34//!
35//! impl DecisionContract for MyContract {
36//!     fn name(&self) -> &str { "example" }
37//!     fn state_space(&self) -> &[String] { &self.states }
38//!     fn action_set(&self) -> &[String] { &self.actions }
39//!     fn loss_matrix(&self) -> &LossMatrix { &self.losses }
40//!     fn update_posterior(&self, posterior: &mut Posterior, observation: usize) {
41//!         let likelihoods = [0.9, 0.1];
42//!         posterior.bayesian_update(&likelihoods);
43//!     }
44//!     fn choose_action(&self, posterior: &Posterior) -> usize {
45//!         self.losses.bayes_action(posterior)
46//!     }
47//!     fn fallback_action(&self) -> usize { 0 }
48//!     fn fallback_policy(&self) -> &FallbackPolicy { &self.policy }
49//! }
50//!
51//! let contract = MyContract {
52//!     states: vec!["good".into(), "bad".into()],
53//!     actions: vec!["continue".into(), "stop".into()],
54//!     losses: LossMatrix::new(
55//!         vec!["good".into(), "bad".into()],
56//!         vec!["continue".into(), "stop".into()],
57//!         vec![0.0, 0.3, 0.8, 0.1],
58//!     ).unwrap(),
59//!     policy: FallbackPolicy::default(),
60//! };
61//!
62//! let posterior = Posterior::uniform(2);
63//! let decision_id = DecisionId::from_parts(1_700_000_000_000, 42);
64//! let trace_id = franken_kernel::TraceId::from_parts(1_700_000_000_000, 1);
65//!
66//! let ctx = EvalContext {
67//!     calibration_score: 0.9,
68//!     e_process: 0.5,
69//!     ci_width: 0.1,
70//!     decision_id,
71//!     trace_id,
72//!     ts_unix_ms: 1_700_000_000_000,
73//! };
74//! let outcome = evaluate(&contract, &posterior, &ctx);
75//! assert!(!outcome.fallback_active);
76//! ```
77
78#![forbid(unsafe_code)]
79
80use std::collections::HashMap;
81use std::fmt;
82
83use franken_evidence::{EvidenceLedger, EvidenceLedgerBuilder};
84use franken_kernel::{DecisionId, TraceId};
85use serde::{Deserialize, Serialize};
86
87// ---------------------------------------------------------------------------
88// Validation errors
89// ---------------------------------------------------------------------------
90
/// Errors raised when constructing or validating decision types.
///
/// Every variant carries enough context (indices, offending value, field
/// name) to render a self-explanatory [`Display`](fmt::Display) message.
#[derive(Clone, Debug, PartialEq)]
pub enum ValidationError {
    /// A loss-matrix entry is below zero.
    NegativeLoss {
        /// Row (state) index of the offending entry.
        state: usize,
        /// Column (action) index of the offending entry.
        action: usize,
        /// The offending value.
        value: f64,
    },
    /// The number of loss values differs from `states * actions`.
    DimensionMismatch {
        /// Number of values required by the declared dimensions.
        expected: usize,
        /// Number of values actually supplied.
        got: usize,
    },
    /// Posterior probabilities do not add up to 1.0 within tolerance.
    PosteriorNotNormalized {
        /// The total that was observed.
        sum: f64,
    },
    /// Posterior length differs from the number of states.
    PosteriorLengthMismatch {
        /// Number of states expected.
        expected: usize,
        /// Number of probabilities supplied.
        got: usize,
    },
    /// A state space or action set has no elements.
    EmptySpace {
        /// Name of the empty collection.
        field: &'static str,
    },
    /// A policy threshold lies outside its permitted range.
    ThresholdOutOfRange {
        /// Name of the threshold.
        field: &'static str,
        /// The rejected value.
        value: f64,
    },
}

impl fmt::Display for ValidationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // All payload fields are `Copy`, so matching on `*self` binds by value.
        match *self {
            Self::NegativeLoss {
                state,
                action,
                value,
            } => write!(
                f,
                "negative loss {} at state={}, action={}",
                value, state, action
            ),
            Self::DimensionMismatch { expected, got } => write!(
                f,
                "dimension mismatch: expected {} values, got {}",
                expected, got
            ),
            Self::PosteriorNotNormalized { sum } => {
                write!(f, "posterior sums to {}, expected 1.0", sum)
            }
            Self::PosteriorLengthMismatch { expected, got } => write!(
                f,
                "posterior length {} does not match state count {}",
                got, expected
            ),
            Self::EmptySpace { field } => write!(f, "{} must not be empty", field),
            Self::ThresholdOutOfRange { field, value } => {
                write!(f, "{} threshold {} out of valid range", field, value)
            }
        }
    }
}

impl std::error::Error for ValidationError {}
168
169// ---------------------------------------------------------------------------
170// LossMatrix
171// ---------------------------------------------------------------------------
172
173/// A loss matrix indexed by (state, action) pairs.
174///
175/// Stored in row-major order: `values[state * n_actions + action]`.
176/// All values must be non-negative. Serializable to TOML/JSON for
177/// runtime reconfiguration.
178#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
179pub struct LossMatrix {
180    state_names: Vec<String>,
181    action_names: Vec<String>,
182    values: Vec<f64>,
183}
184
185impl LossMatrix {
186    /// Create a new loss matrix.
187    ///
188    /// `values` must have exactly `state_names.len() * action_names.len()`
189    /// elements, all non-negative. Laid out in row-major order:
190    /// `values[s * n_actions + a]` is the loss for state `s`, action `a`.
191    pub fn new(
192        state_names: Vec<String>,
193        action_names: Vec<String>,
194        values: Vec<f64>,
195    ) -> Result<Self, ValidationError> {
196        if state_names.is_empty() {
197            return Err(ValidationError::EmptySpace {
198                field: "state_names",
199            });
200        }
201        if action_names.is_empty() {
202            return Err(ValidationError::EmptySpace {
203                field: "action_names",
204            });
205        }
206        let expected = state_names.len() * action_names.len();
207        if values.len() != expected {
208            return Err(ValidationError::DimensionMismatch {
209                expected,
210                got: values.len(),
211            });
212        }
213        let n_actions = action_names.len();
214        for (i, &v) in values.iter().enumerate() {
215            if v < 0.0 {
216                return Err(ValidationError::NegativeLoss {
217                    state: i / n_actions,
218                    action: i % n_actions,
219                    value: v,
220                });
221            }
222        }
223        Ok(Self {
224            state_names,
225            action_names,
226            values,
227        })
228    }
229
230    /// Get the loss for a specific (state, action) pair.
231    pub fn get(&self, state: usize, action: usize) -> f64 {
232        self.values[state * self.action_names.len() + action]
233    }
234
235    /// Number of states.
236    pub fn n_states(&self) -> usize {
237        self.state_names.len()
238    }
239
240    /// Number of actions.
241    pub fn n_actions(&self) -> usize {
242        self.action_names.len()
243    }
244
245    /// State labels.
246    pub fn state_names(&self) -> &[String] {
247        &self.state_names
248    }
249
250    /// Action labels.
251    pub fn action_names(&self) -> &[String] {
252        &self.action_names
253    }
254
255    /// Compute expected loss for a specific action given a posterior.
256    ///
257    /// `E[loss|a] = sum_s posterior(s) * loss(s, a)`
258    pub fn expected_loss(&self, posterior: &Posterior, action: usize) -> f64 {
259        posterior
260            .probs()
261            .iter()
262            .enumerate()
263            .map(|(s, &p)| p * self.get(s, action))
264            .sum()
265    }
266
267    /// Compute expected losses for all actions as a name-indexed map.
268    pub fn expected_losses(&self, posterior: &Posterior) -> HashMap<String, f64> {
269        self.action_names
270            .iter()
271            .enumerate()
272            .map(|(a, name)| (name.clone(), self.expected_loss(posterior, a)))
273            .collect()
274    }
275
276    /// Choose the Bayes-optimal action (minimum expected loss).
277    ///
278    /// Returns the action index. Ties are broken by lowest index.
279    pub fn bayes_action(&self, posterior: &Posterior) -> usize {
280        (0..self.action_names.len())
281            .min_by(|&a, &b| {
282                self.expected_loss(posterior, a)
283                    .partial_cmp(&self.expected_loss(posterior, b))
284                    .unwrap_or(std::cmp::Ordering::Equal)
285            })
286            .unwrap_or(0)
287    }
288}
289
290// ---------------------------------------------------------------------------
291// Posterior
292// ---------------------------------------------------------------------------
293
/// Maximum allowed absolute deviation of a posterior's total from 1.0.
const NORMALIZATION_TOLERANCE: f64 = 1.0e-6;
296
297/// A discrete probability distribution over states.
298///
299/// Supports in-place Bayesian updates in O(|S|) with no allocation.
300#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
301pub struct Posterior {
302    probs: Vec<f64>,
303}
304
305impl Posterior {
306    /// Create from explicit probabilities.
307    ///
308    /// Probabilities must sum to ~1.0 (within tolerance) and be non-negative.
309    pub fn new(probs: Vec<f64>) -> Result<Self, ValidationError> {
310        let sum: f64 = probs.iter().sum();
311        if (sum - 1.0).abs() > NORMALIZATION_TOLERANCE {
312            return Err(ValidationError::PosteriorNotNormalized { sum });
313        }
314        Ok(Self { probs })
315    }
316
317    /// Create a uniform prior over `n` states.
318    #[allow(clippy::cast_precision_loss)]
319    pub fn uniform(n: usize) -> Self {
320        let p = 1.0 / n as f64;
321        Self { probs: vec![p; n] }
322    }
323
324    /// Probability values (immutable).
325    pub fn probs(&self) -> &[f64] {
326        &self.probs
327    }
328
329    /// Mutable access to probability values for in-place updates.
330    pub fn probs_mut(&mut self) -> &mut [f64] {
331        &mut self.probs
332    }
333
334    /// Number of states in the distribution.
335    pub fn len(&self) -> usize {
336        self.probs.len()
337    }
338
339    /// Whether the distribution is empty.
340    pub fn is_empty(&self) -> bool {
341        self.probs.is_empty()
342    }
343
344    /// Bayesian update: multiply by likelihoods and renormalize.
345    ///
346    /// `likelihoods[s]` = P(observation | state = s).
347    /// Runs in O(|S|) with no allocation.
348    ///
349    /// # Panics
350    ///
351    /// Panics if `likelihoods.len() != self.len()`.
352    pub fn bayesian_update(&mut self, likelihoods: &[f64]) {
353        assert_eq!(likelihoods.len(), self.probs.len());
354        for (p, &l) in self.probs.iter_mut().zip(likelihoods) {
355            *p *= l;
356        }
357        self.normalize();
358    }
359
360    /// Renormalize probabilities to sum to 1.0.
361    pub fn normalize(&mut self) {
362        let sum: f64 = self.probs.iter().sum();
363        if sum > 0.0 {
364            for p in &mut self.probs {
365                *p /= sum;
366            }
367        }
368    }
369
370    /// Shannon entropy: -sum p * log2(p).
371    pub fn entropy(&self) -> f64 {
372        self.probs
373            .iter()
374            .filter(|&&p| p > 0.0)
375            .map(|&p| -p * p.log2())
376            .sum()
377    }
378
379    /// Index of the most probable state (MAP estimate).
380    ///
381    /// Ties are broken by lowest index.
382    pub fn map_state(&self) -> usize {
383        self.probs
384            .iter()
385            .enumerate()
386            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
387            .map_or(0, |(i, _)| i)
388    }
389}
390
391// ---------------------------------------------------------------------------
392// FallbackPolicy
393// ---------------------------------------------------------------------------
394
395/// Conditions under which to activate fallback heuristics.
396///
397/// A decision engine should switch to [`DecisionContract::fallback_action`]
398/// when any threshold is breached.
399#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
400pub struct FallbackPolicy {
401    /// Activate fallback if calibration score drops below this value.
402    pub calibration_drift_threshold: f64,
403    /// Activate fallback if e-process statistic exceeds this value.
404    pub e_process_breach_threshold: f64,
405    /// Activate fallback if confidence interval width exceeds this value.
406    pub confidence_width_threshold: f64,
407}
408
409impl FallbackPolicy {
410    /// Create a new fallback policy.
411    ///
412    /// `calibration_drift_threshold` must be in [0, 1].
413    /// Other thresholds must be non-negative.
414    pub fn new(
415        calibration_drift_threshold: f64,
416        e_process_breach_threshold: f64,
417        confidence_width_threshold: f64,
418    ) -> Result<Self, ValidationError> {
419        if !(0.0..=1.0).contains(&calibration_drift_threshold) {
420            return Err(ValidationError::ThresholdOutOfRange {
421                field: "calibration_drift_threshold",
422                value: calibration_drift_threshold,
423            });
424        }
425        if e_process_breach_threshold < 0.0 {
426            return Err(ValidationError::ThresholdOutOfRange {
427                field: "e_process_breach_threshold",
428                value: e_process_breach_threshold,
429            });
430        }
431        if confidence_width_threshold < 0.0 {
432            return Err(ValidationError::ThresholdOutOfRange {
433                field: "confidence_width_threshold",
434                value: confidence_width_threshold,
435            });
436        }
437        Ok(Self {
438            calibration_drift_threshold,
439            e_process_breach_threshold,
440            confidence_width_threshold,
441        })
442    }
443
444    /// Check if fallback should be activated based on current metrics.
445    pub fn should_fallback(&self, calibration_score: f64, e_process: f64, ci_width: f64) -> bool {
446        calibration_score < self.calibration_drift_threshold
447            || e_process > self.e_process_breach_threshold
448            || ci_width > self.confidence_width_threshold
449    }
450}
451
452impl Default for FallbackPolicy {
453    fn default() -> Self {
454        Self {
455            calibration_drift_threshold: 0.7,
456            e_process_breach_threshold: 20.0,
457            confidence_width_threshold: 0.5,
458        }
459    }
460}
461
462// ---------------------------------------------------------------------------
463// DecisionContract trait
464// ---------------------------------------------------------------------------
465
466/// A contract defining the decision-making framework for a component.
467///
468/// Implementors define the state space, action set, loss matrix, and
469/// posterior update logic. The [`evaluate`] function orchestrates the
470/// full decision pipeline and produces an auditable outcome.
471pub trait DecisionContract {
472    /// Human-readable contract name (e.g., "scheduler", "load_balancer").
473    fn name(&self) -> &str;
474
475    /// Ordered labels for the state space.
476    fn state_space(&self) -> &[String];
477
478    /// Ordered labels for the action set.
479    fn action_set(&self) -> &[String];
480
481    /// The loss matrix for this contract.
482    fn loss_matrix(&self) -> &LossMatrix;
483
484    /// Update the posterior given an observation at `state_index`.
485    fn update_posterior(&self, posterior: &mut Posterior, state_index: usize);
486
487    /// Choose the optimal action given the current posterior.
488    ///
489    /// Returns an action index into [`action_set`](Self::action_set).
490    fn choose_action(&self, posterior: &Posterior) -> usize;
491
492    /// The fallback action when the model is unreliable.
493    ///
494    /// Returns an action index into [`action_set`](Self::action_set).
495    fn fallback_action(&self) -> usize;
496
497    /// Policy governing fallback activation.
498    fn fallback_policy(&self) -> &FallbackPolicy;
499}
500
501// ---------------------------------------------------------------------------
502// DecisionAuditEntry
503// ---------------------------------------------------------------------------
504
505/// Structured audit record linking a decision to the evidence ledger.
506///
507/// Captures the full context of a runtime decision for offline analysis
508/// and replay.
509#[derive(Clone, Debug, Serialize, Deserialize)]
510pub struct DecisionAuditEntry {
511    /// Unique identifier for this decision.
512    pub decision_id: DecisionId,
513    /// Trace context for distributed tracing.
514    pub trace_id: TraceId,
515    /// Name of the decision contract that was evaluated.
516    pub contract_name: String,
517    /// The action that was chosen.
518    pub action_chosen: String,
519    /// Expected loss of the chosen action.
520    pub expected_loss: f64,
521    /// Current calibration score at decision time.
522    pub calibration_score: f64,
523    /// Whether the fallback heuristic was active.
524    pub fallback_active: bool,
525    /// Snapshot of the posterior at decision time.
526    pub posterior_snapshot: Vec<f64>,
527    /// Expected loss for each candidate action.
528    pub expected_loss_by_action: HashMap<String, f64>,
529    /// Unix timestamp in milliseconds.
530    pub ts_unix_ms: u64,
531}
532
533impl DecisionAuditEntry {
534    /// Convert to an [`EvidenceLedger`] entry for structured tracing.
535    pub fn to_evidence_ledger(&self) -> EvidenceLedger {
536        let mut builder = EvidenceLedgerBuilder::new()
537            .ts_unix_ms(self.ts_unix_ms)
538            .component(&self.contract_name)
539            .action(&self.action_chosen)
540            .posterior(self.posterior_snapshot.clone())
541            .chosen_expected_loss(self.expected_loss)
542            .calibration_score(self.calibration_score)
543            .fallback_active(self.fallback_active);
544
545        for (action, &loss) in &self.expected_loss_by_action {
546            builder = builder.expected_loss(action, loss);
547        }
548
549        builder
550            .build()
551            .expect("audit entry should produce valid evidence ledger")
552    }
553}
554
555// ---------------------------------------------------------------------------
556// DecisionOutcome
557// ---------------------------------------------------------------------------
558
559/// Result of evaluating a decision contract.
560#[derive(Clone, Debug)]
561pub struct DecisionOutcome {
562    /// Index of the chosen action.
563    pub action_index: usize,
564    /// Name of the chosen action.
565    pub action_name: String,
566    /// Expected loss of the chosen action.
567    pub expected_loss: f64,
568    /// Expected losses for all candidate actions.
569    pub expected_losses: HashMap<String, f64>,
570    /// Whether fallback was activated.
571    pub fallback_active: bool,
572    /// Full audit entry for this decision.
573    pub audit_entry: DecisionAuditEntry,
574}
575
576// ---------------------------------------------------------------------------
577// EvalContext
578// ---------------------------------------------------------------------------
579
580/// Runtime context for a single decision evaluation.
581///
582/// Bundles the monitoring metrics and tracing identifiers needed by
583/// [`evaluate`].
584#[derive(Clone, Debug)]
585pub struct EvalContext {
586    /// Current calibration score.
587    pub calibration_score: f64,
588    /// Current e-process statistic.
589    pub e_process: f64,
590    /// Current confidence interval width.
591    pub ci_width: f64,
592    /// Unique identifier for this decision.
593    pub decision_id: DecisionId,
594    /// Trace context for distributed tracing.
595    pub trace_id: TraceId,
596    /// Unix timestamp in milliseconds.
597    pub ts_unix_ms: u64,
598}
599
600// ---------------------------------------------------------------------------
601// Evaluate
602// ---------------------------------------------------------------------------
603
604/// Evaluate a decision contract and produce a full audit trail.
605///
606/// This is the primary entry point for making auditable decisions.
607/// It computes expected losses, checks fallback conditions, and produces
608/// a [`DecisionOutcome`] with a linked [`DecisionAuditEntry`].
609pub fn evaluate<C: DecisionContract>(
610    contract: &C,
611    posterior: &Posterior,
612    ctx: &EvalContext,
613) -> DecisionOutcome {
614    let loss_matrix = contract.loss_matrix();
615    let expected_losses = loss_matrix.expected_losses(posterior);
616
617    let fallback_active = contract.fallback_policy().should_fallback(
618        ctx.calibration_score,
619        ctx.e_process,
620        ctx.ci_width,
621    );
622
623    let action_index = if fallback_active {
624        contract.fallback_action()
625    } else {
626        contract.choose_action(posterior)
627    };
628
629    let action_name = contract.action_set()[action_index].clone();
630    let expected_loss = expected_losses[&action_name];
631
632    let audit_entry = DecisionAuditEntry {
633        decision_id: ctx.decision_id,
634        trace_id: ctx.trace_id,
635        contract_name: contract.name().to_string(),
636        action_chosen: action_name.clone(),
637        expected_loss,
638        calibration_score: ctx.calibration_score,
639        fallback_active,
640        posterior_snapshot: posterior.probs().to_vec(),
641        expected_loss_by_action: expected_losses.clone(),
642        ts_unix_ms: ctx.ts_unix_ms,
643    };
644
645    DecisionOutcome {
646        action_index,
647        action_name,
648        expected_loss,
649        expected_losses,
650        fallback_active,
651        audit_entry,
652    }
653}
654
655// ---------------------------------------------------------------------------
656// Tests
657// ---------------------------------------------------------------------------
658
659#[cfg(test)]
660#[allow(clippy::float_cmp)]
661mod tests {
662    use super::*;
663
664    // -- Helpers --
665
666    fn two_state_matrix() -> LossMatrix {
667        // States: [good, bad], Actions: [continue, stop]
668        // loss(good, continue) = 0.0, loss(good, stop) = 0.3
669        // loss(bad, continue)  = 0.8, loss(bad, stop)  = 0.1
670        LossMatrix::new(
671            vec!["good".into(), "bad".into()],
672            vec!["continue".into(), "stop".into()],
673            vec![0.0, 0.3, 0.8, 0.1],
674        )
675        .unwrap()
676    }
677
678    struct TestContract {
679        states: Vec<String>,
680        actions: Vec<String>,
681        losses: LossMatrix,
682        policy: FallbackPolicy,
683    }
684
685    impl TestContract {
686        fn new() -> Self {
687            Self {
688                states: vec!["good".into(), "bad".into()],
689                actions: vec!["continue".into(), "stop".into()],
690                losses: two_state_matrix(),
691                policy: FallbackPolicy::default(),
692            }
693        }
694    }
695
696    #[allow(clippy::unnecessary_literal_bound)]
697    impl DecisionContract for TestContract {
698        fn name(&self) -> &str {
699            "test_contract"
700        }
701        fn state_space(&self) -> &[String] {
702            &self.states
703        }
704        fn action_set(&self) -> &[String] {
705            &self.actions
706        }
707        fn loss_matrix(&self) -> &LossMatrix {
708            &self.losses
709        }
710        fn update_posterior(&self, posterior: &mut Posterior, observation: usize) {
711            // Simple likelihood model: observed state gets high likelihood.
712            let mut likelihoods = vec![0.1; self.states.len()];
713            likelihoods[observation] = 0.9;
714            posterior.bayesian_update(&likelihoods);
715        }
716        fn choose_action(&self, posterior: &Posterior) -> usize {
717            self.losses.bayes_action(posterior)
718        }
719        fn fallback_action(&self) -> usize {
720            0 // "continue"
721        }
722        fn fallback_policy(&self) -> &FallbackPolicy {
723            &self.policy
724        }
725    }
726
727    // -- LossMatrix tests --
728
729    #[test]
730    fn loss_matrix_creation() {
731        let m = two_state_matrix();
732        assert_eq!(m.n_states(), 2);
733        assert_eq!(m.n_actions(), 2);
734        assert_eq!(m.get(0, 0), 0.0);
735        assert_eq!(m.get(0, 1), 0.3);
736        assert_eq!(m.get(1, 0), 0.8);
737        assert_eq!(m.get(1, 1), 0.1);
738    }
739
740    #[test]
741    fn loss_matrix_empty_states_rejected() {
742        let err = LossMatrix::new(vec![], vec!["a".into()], vec![]).unwrap_err();
743        assert!(matches!(
744            err,
745            ValidationError::EmptySpace {
746                field: "state_names"
747            }
748        ));
749    }
750
751    #[test]
752    fn loss_matrix_empty_actions_rejected() {
753        let err = LossMatrix::new(vec!["s".into()], vec![], vec![]).unwrap_err();
754        assert!(matches!(
755            err,
756            ValidationError::EmptySpace {
757                field: "action_names"
758            }
759        ));
760    }
761
762    #[test]
763    fn loss_matrix_dimension_mismatch() {
764        let err = LossMatrix::new(
765            vec!["s1".into(), "s2".into()],
766            vec!["a1".into()],
767            vec![0.1], // needs 2 values
768        )
769        .unwrap_err();
770        assert!(matches!(
771            err,
772            ValidationError::DimensionMismatch {
773                expected: 2,
774                got: 1
775            }
776        ));
777    }
778
779    #[test]
780    fn loss_matrix_negative_rejected() {
781        let err = LossMatrix::new(vec!["s".into()], vec!["a".into()], vec![-0.5]).unwrap_err();
782        assert!(matches!(
783            err,
784            ValidationError::NegativeLoss {
785                state: 0,
786                action: 0,
787                ..
788            }
789        ));
790    }
791
792    #[test]
793    fn loss_matrix_expected_loss() {
794        let m = two_state_matrix();
795        let posterior = Posterior::new(vec![0.8, 0.2]).unwrap();
796        // E[loss|continue] = 0.8*0.0 + 0.2*0.8 = 0.16
797        let el_continue = m.expected_loss(&posterior, 0);
798        assert!((el_continue - 0.16).abs() < 1e-10);
799        // E[loss|stop] = 0.8*0.3 + 0.2*0.1 = 0.26
800        let el_stop = m.expected_loss(&posterior, 1);
801        assert!((el_stop - 0.26).abs() < 1e-10);
802    }
803
804    #[test]
805    fn loss_matrix_bayes_action() {
806        let m = two_state_matrix();
807        // When mostly good, continue is optimal.
808        let mostly_good = Posterior::new(vec![0.9, 0.1]).unwrap();
809        assert_eq!(m.bayes_action(&mostly_good), 0); // continue
810                                                     // When mostly bad, stop is optimal.
811        let mostly_bad = Posterior::new(vec![0.2, 0.8]).unwrap();
812        assert_eq!(m.bayes_action(&mostly_bad), 1); // stop
813    }
814
815    #[test]
816    fn loss_matrix_expected_losses_map() {
817        let m = two_state_matrix();
818        let posterior = Posterior::uniform(2);
819        let losses = m.expected_losses(&posterior);
820        assert_eq!(losses.len(), 2);
821        assert!(losses.contains_key("continue"));
822        assert!(losses.contains_key("stop"));
823    }
824
825    #[test]
826    fn loss_matrix_names() {
827        let m = two_state_matrix();
828        assert_eq!(m.state_names(), &["good", "bad"]);
829        assert_eq!(m.action_names(), &["continue", "stop"]);
830    }
831
832    #[test]
833    fn loss_matrix_toml_roundtrip() {
834        let m = two_state_matrix();
835        let toml_str = toml::to_string(&m).unwrap();
836        let parsed: LossMatrix = toml::from_str(&toml_str).unwrap();
837        assert_eq!(m, parsed);
838    }
839
840    #[test]
841    fn loss_matrix_json_roundtrip() {
842        let m = two_state_matrix();
843        let json = serde_json::to_string(&m).unwrap();
844        let parsed: LossMatrix = serde_json::from_str(&json).unwrap();
845        assert_eq!(m, parsed);
846    }
847
848    // -- Posterior tests --
849
850    #[test]
851    fn posterior_uniform() {
852        let p = Posterior::uniform(4);
853        assert_eq!(p.len(), 4);
854        for &v in p.probs() {
855            assert!((v - 0.25).abs() < 1e-10);
856        }
857    }
858
859    #[test]
860    fn posterior_new_valid() {
861        let p = Posterior::new(vec![0.3, 0.7]).unwrap();
862        assert_eq!(p.probs(), &[0.3, 0.7]);
863    }
864
865    #[test]
866    fn posterior_new_not_normalized() {
867        let err = Posterior::new(vec![0.5, 0.3]).unwrap_err();
868        assert!(matches!(
869            err,
870            ValidationError::PosteriorNotNormalized { .. }
871        ));
872    }
873
874    #[test]
875    fn posterior_bayesian_update() {
876        let mut p = Posterior::uniform(2);
877        // Likelihood: state 0 very likely given observation.
878        p.bayesian_update(&[0.9, 0.1]);
879        // After update: p(0) = 0.5*0.9 / (0.5*0.9 + 0.5*0.1) = 0.9
880        assert!((p.probs()[0] - 0.9).abs() < 1e-10);
881        assert!((p.probs()[1] - 0.1).abs() < 1e-10);
882    }
883
884    #[test]
885    fn posterior_bayesian_update_no_alloc() {
886        // Verify the update works in-place by checking pointer stability.
887        let mut p = Posterior::uniform(3);
888        let ptr_before = p.probs().as_ptr();
889        p.bayesian_update(&[0.5, 0.3, 0.2]);
890        let ptr_after = p.probs().as_ptr();
891        assert_eq!(ptr_before, ptr_after);
892    }
893
894    #[test]
895    fn posterior_entropy() {
896        // Uniform over 2 states: entropy = 1.0 bit.
897        let p = Posterior::uniform(2);
898        assert!((p.entropy() - 1.0).abs() < 1e-10);
899        // Deterministic: entropy = 0.
900        let det = Posterior::new(vec![1.0, 0.0]).unwrap();
901        assert!((det.entropy()).abs() < 1e-10);
902    }
903
904    #[test]
905    fn posterior_map_state() {
906        let p = Posterior::new(vec![0.1, 0.7, 0.2]).unwrap();
907        assert_eq!(p.map_state(), 1);
908    }
909
910    #[test]
911    fn posterior_is_empty() {
912        let p = Posterior { probs: vec![] };
913        assert!(p.is_empty());
914        let p2 = Posterior::uniform(1);
915        assert!(!p2.is_empty());
916    }
917
918    #[test]
919    fn posterior_probs_mut() {
920        let mut p = Posterior::uniform(2);
921        p.probs_mut()[0] = 0.8;
922        p.probs_mut()[1] = 0.2;
923        assert_eq!(p.probs(), &[0.8, 0.2]);
924    }
925
926    // -- FallbackPolicy tests --
927
928    #[test]
929    fn fallback_policy_default() {
930        let fp = FallbackPolicy::default();
931        assert_eq!(fp.calibration_drift_threshold, 0.7);
932        assert_eq!(fp.e_process_breach_threshold, 20.0);
933        assert_eq!(fp.confidence_width_threshold, 0.5);
934    }
935
936    #[test]
937    fn fallback_policy_new_valid() {
938        let fp = FallbackPolicy::new(0.8, 10.0, 0.3).unwrap();
939        assert_eq!(fp.calibration_drift_threshold, 0.8);
940    }
941
942    #[test]
943    fn fallback_policy_calibration_out_of_range() {
944        let err = FallbackPolicy::new(1.5, 10.0, 0.3).unwrap_err();
945        assert!(matches!(
946            err,
947            ValidationError::ThresholdOutOfRange {
948                field: "calibration_drift_threshold",
949                ..
950            }
951        ));
952    }
953
954    #[test]
955    fn fallback_policy_negative_e_process() {
956        let err = FallbackPolicy::new(0.7, -1.0, 0.3).unwrap_err();
957        assert!(matches!(
958            err,
959            ValidationError::ThresholdOutOfRange {
960                field: "e_process_breach_threshold",
961                ..
962            }
963        ));
964    }
965
966    #[test]
967    fn fallback_policy_negative_ci_width() {
968        let err = FallbackPolicy::new(0.7, 10.0, -0.1).unwrap_err();
969        assert!(matches!(
970            err,
971            ValidationError::ThresholdOutOfRange {
972                field: "confidence_width_threshold",
973                ..
974            }
975        ));
976    }
977
978    #[test]
979    fn fallback_triggered_by_low_calibration() {
980        let fp = FallbackPolicy::default();
981        assert!(fp.should_fallback(0.5, 1.0, 0.1)); // cal < 0.7
982        assert!(!fp.should_fallback(0.9, 1.0, 0.1)); // cal OK
983    }
984
985    #[test]
986    fn fallback_triggered_by_e_process() {
987        let fp = FallbackPolicy::default();
988        assert!(fp.should_fallback(0.9, 25.0, 0.1)); // e_process > 20
989        assert!(!fp.should_fallback(0.9, 15.0, 0.1)); // e_process OK
990    }
991
992    #[test]
993    fn fallback_triggered_by_ci_width() {
994        let fp = FallbackPolicy::default();
995        assert!(fp.should_fallback(0.9, 1.0, 0.6)); // ci > 0.5
996        assert!(!fp.should_fallback(0.9, 1.0, 0.3)); // ci OK
997    }
998
999    // -- DecisionContract + evaluate tests --
1000
1001    #[test]
1002    fn contract_implementable_under_50_lines() {
1003        // The TestContract impl above is 22 lines — well under 50.
1004        let contract = TestContract::new();
1005        assert_eq!(contract.name(), "test_contract");
1006        assert_eq!(contract.state_space().len(), 2);
1007        assert_eq!(contract.action_set().len(), 2);
1008    }
1009
1010    fn test_ctx(cal: f64, random: u128) -> EvalContext {
1011        EvalContext {
1012            calibration_score: cal,
1013            e_process: 1.0,
1014            ci_width: 0.1,
1015            decision_id: DecisionId::from_parts(1_700_000_000_000, random),
1016            trace_id: TraceId::from_parts(1_700_000_000_000, random),
1017            ts_unix_ms: 1_700_000_000_000,
1018        }
1019    }
1020
1021    #[test]
1022    fn evaluate_normal_decision() {
1023        let contract = TestContract::new();
1024        let posterior = Posterior::new(vec![0.9, 0.1]).unwrap();
1025        let ctx = test_ctx(0.95, 42);
1026
1027        let outcome = evaluate(&contract, &posterior, &ctx);
1028
1029        assert!(!outcome.fallback_active);
1030        assert_eq!(outcome.action_name, "continue"); // low loss when mostly good
1031        assert_eq!(outcome.action_index, 0);
1032        assert!(outcome.expected_loss < 0.1);
1033        assert_eq!(outcome.expected_losses.len(), 2);
1034    }
1035
1036    #[test]
1037    fn evaluate_fallback_decision() {
1038        let contract = TestContract::new();
1039        let posterior = Posterior::new(vec![0.2, 0.8]).unwrap();
1040        let ctx = test_ctx(0.5, 43); // low calibration triggers fallback
1041
1042        let outcome = evaluate(&contract, &posterior, &ctx);
1043
1044        assert!(outcome.fallback_active);
1045        assert_eq!(outcome.action_name, "continue"); // fallback action = 0
1046        assert_eq!(outcome.action_index, 0);
1047    }
1048
1049    #[test]
1050    fn evaluate_without_fallback_chooses_optimal() {
1051        let contract = TestContract::new();
1052        let posterior = Posterior::new(vec![0.2, 0.8]).unwrap();
1053        let ctx = test_ctx(0.95, 44); // good calibration, no fallback
1054
1055        let outcome = evaluate(&contract, &posterior, &ctx);
1056
1057        assert!(!outcome.fallback_active);
1058        assert_eq!(outcome.action_name, "stop"); // optimal when mostly bad
1059    }
1060
1061    #[test]
1062    fn evaluate_audit_entry_fields() {
1063        let contract = TestContract::new();
1064        let posterior = Posterior::uniform(2);
1065        let ctx = test_ctx(0.85, 99);
1066
1067        let outcome = evaluate(&contract, &posterior, &ctx);
1068
1069        let audit = &outcome.audit_entry;
1070        assert_eq!(audit.decision_id, ctx.decision_id);
1071        assert_eq!(audit.trace_id, ctx.trace_id);
1072        assert_eq!(audit.contract_name, "test_contract");
1073        assert_eq!(audit.calibration_score, 0.85);
1074        assert_eq!(audit.ts_unix_ms, 1_700_000_000_000);
1075        assert_eq!(audit.posterior_snapshot.len(), 2);
1076    }
1077
1078    // -- DecisionAuditEntry → EvidenceLedger --
1079
1080    #[test]
1081    fn audit_entry_to_evidence_ledger() {
1082        let contract = TestContract::new();
1083        let posterior = Posterior::new(vec![0.6, 0.4]).unwrap();
1084        let ctx = test_ctx(0.92, 100);
1085
1086        let outcome = evaluate(&contract, &posterior, &ctx);
1087        let evidence = outcome.audit_entry.to_evidence_ledger();
1088
1089        assert_eq!(evidence.ts_unix_ms, 1_700_000_000_000);
1090        assert_eq!(evidence.component, "test_contract");
1091        assert_eq!(evidence.action, outcome.action_name);
1092        assert_eq!(evidence.calibration_score, 0.92);
1093        assert!(!evidence.fallback_active);
1094        assert_eq!(evidence.posterior, vec![0.6, 0.4]);
1095        assert!(evidence.is_valid());
1096    }
1097
1098    #[test]
1099    fn audit_entry_serde_roundtrip() {
1100        let contract = TestContract::new();
1101        let posterior = Posterior::uniform(2);
1102        let ctx = test_ctx(0.88, 101);
1103
1104        let outcome = evaluate(&contract, &posterior, &ctx);
1105        let json = serde_json::to_string(&outcome.audit_entry).unwrap();
1106        let parsed: DecisionAuditEntry = serde_json::from_str(&json).unwrap();
1107        assert_eq!(parsed.contract_name, "test_contract");
1108        assert_eq!(parsed.decision_id, ctx.decision_id);
1109        assert_eq!(parsed.trace_id, ctx.trace_id);
1110    }
1111
1112    // -- Update posterior via contract --
1113
1114    #[test]
1115    fn contract_update_posterior() {
1116        let contract = TestContract::new();
1117        let mut posterior = Posterior::uniform(2);
1118        contract.update_posterior(&mut posterior, 0); // observe "good"
1119                                                      // After update: state 0 should be more probable.
1120        assert!(posterior.probs()[0] > posterior.probs()[1]);
1121    }
1122
1123    // -- Validation error display --
1124
1125    #[test]
1126    fn validation_error_display() {
1127        let err = ValidationError::NegativeLoss {
1128            state: 1,
1129            action: 2,
1130            value: -0.5,
1131        };
1132        let msg = format!("{err}");
1133        assert!(msg.contains("-0.5"));
1134        assert!(msg.contains("state=1"));
1135        assert!(msg.contains("action=2"));
1136    }
1137
1138    #[test]
1139    fn dimension_mismatch_display() {
1140        let err = ValidationError::DimensionMismatch {
1141            expected: 6,
1142            got: 4,
1143        };
1144        let msg = format!("{err}");
1145        assert!(msg.contains('6'));
1146        assert!(msg.contains('4'));
1147    }
1148
1149    // -- FallbackPolicy serde --
1150
1151    #[test]
1152    fn fallback_policy_toml_roundtrip() {
1153        let fp = FallbackPolicy::default();
1154        let toml_str = toml::to_string(&fp).unwrap();
1155        let parsed: FallbackPolicy = toml::from_str(&toml_str).unwrap();
1156        assert_eq!(fp, parsed);
1157    }
1158
1159    #[test]
1160    fn fallback_policy_json_roundtrip() {
1161        let fp = FallbackPolicy::default();
1162        let json = serde_json::to_string(&fp).unwrap();
1163        let parsed: FallbackPolicy = serde_json::from_str(&json).unwrap();
1164        assert_eq!(fp, parsed);
1165    }
1166
1167    // -- argmin correctness with known posteriors --
1168
1169    #[test]
1170    fn argmin_correctness_deterministic_posterior() {
1171        let m = two_state_matrix();
1172        // Fully certain state=good: E[continue]=0.0, E[stop]=0.3 → continue wins.
1173        let certain_good = Posterior::new(vec![1.0, 0.0]).unwrap();
1174        assert_eq!(m.bayes_action(&certain_good), 0);
1175        // Fully certain state=bad: E[continue]=0.8, E[stop]=0.1 → stop wins.
1176        let certain_bad = Posterior::new(vec![0.0, 1.0]).unwrap();
1177        assert_eq!(m.bayes_action(&certain_bad), 1);
1178    }
1179
1180    #[test]
1181    fn argmin_correctness_breakeven_point() {
1182        let m = two_state_matrix();
1183        // Find crossover: at p(good)=x, E[continue]=0.8(1-x) and E[stop]=0.3x+0.1(1-x).
1184        // Crossover: 0.8-0.8x = 0.3x+0.1-0.1x → 0.8-0.8x = 0.2x+0.1 → 0.7=x → x=0.7
1185        // At p(good)=0.71, continue is better.
1186        let above = Posterior::new(vec![0.71, 0.29]).unwrap();
1187        assert_eq!(m.bayes_action(&above), 0);
1188        // At p(good)=0.69, stop is better.
1189        let below = Posterior::new(vec![0.69, 0.31]).unwrap();
1190        assert_eq!(m.bayes_action(&below), 1);
1191    }
1192
1193    #[test]
1194    fn argmin_three_state_three_action() {
1195        // 3 states, 3 actions: verify argmin in a bigger space.
1196        let m = LossMatrix::new(
1197            vec!["s0".into(), "s1".into(), "s2".into()],
1198            vec!["a0".into(), "a1".into(), "a2".into()],
1199            vec![
1200                1.0, 2.0, 3.0, // state 0
1201                3.0, 1.0, 2.0, // state 1
1202                2.0, 3.0, 1.0, // state 2
1203            ],
1204        )
1205        .unwrap();
1206        // Uniform posterior: E[a0]=2.0, E[a1]=2.0, E[a2]=2.0 → all tied.
1207        // Rust's min_by returns the last equal element, so index 2.
1208        let uniform = Posterior::uniform(3);
1209        let action = m.bayes_action(&uniform);
1210        // Any action is valid since all expected losses are equal.
1211        assert!(action < 3);
1212        // Posterior concentrated on state 1: a1 has loss 1.0 → a1 wins.
1213        let state1 = Posterior::new(vec![0.0, 1.0, 0.0]).unwrap();
1214        assert_eq!(m.bayes_action(&state1), 1);
1215        // Posterior concentrated on state 2: a2 has loss 1.0 → a2 wins.
1216        let state2 = Posterior::new(vec![0.0, 0.0, 1.0]).unwrap();
1217        assert_eq!(m.bayes_action(&state2), 2);
1218    }
1219
1220    // -- Bayesian update hand-computed --
1221
1222    #[test]
1223    fn bayesian_update_hand_computed_three_state() {
1224        // Prior: [0.5, 0.3, 0.2]
1225        // Likelihoods: [0.1, 0.6, 0.3]
1226        // Unnorm: [0.05, 0.18, 0.06]  sum=0.29
1227        // Posterior: [0.05/0.29, 0.18/0.29, 0.06/0.29]
1228        let mut p = Posterior::new(vec![0.5, 0.3, 0.2]).unwrap();
1229        p.bayesian_update(&[0.1, 0.6, 0.3]);
1230        let expected = [0.05 / 0.29, 0.18 / 0.29, 0.06 / 0.29];
1231        for (i, &e) in expected.iter().enumerate() {
1232            assert!(
1233                (p.probs()[i] - e).abs() < 1e-10,
1234                "state {i}: got {}, expected {e}",
1235                p.probs()[i]
1236            );
1237        }
1238    }
1239
1240    #[test]
1241    fn bayesian_update_successive_convergence() {
1242        // Repeated observations of state 0 should drive posterior toward certainty.
1243        let mut p = Posterior::uniform(3);
1244        for _ in 0..20 {
1245            p.bayesian_update(&[0.9, 0.05, 0.05]);
1246        }
1247        assert!(p.probs()[0] > 0.999);
1248        assert!(p.probs()[1] < 0.001);
1249        assert!(p.probs()[2] < 0.001);
1250    }
1251
1252    // -- End-to-end decision pipeline --
1253
1254    #[test]
1255    fn end_to_end_pipeline() {
1256        let contract = TestContract::new();
1257        let mut posterior = Posterior::uniform(2);
1258
1259        // Feed 5 "good" observations: posterior should shift toward state 0.
1260        for _ in 0..5 {
1261            contract.update_posterior(&mut posterior, 0);
1262        }
1263        assert!(posterior.probs()[0] > 0.99);
1264
1265        // Make a decision: should be "continue" (low loss when good).
1266        let ctx = test_ctx(0.95, 200);
1267        let outcome = evaluate(&contract, &posterior, &ctx);
1268        assert!(!outcome.fallback_active);
1269        assert_eq!(outcome.action_name, "continue");
1270        assert!(outcome.expected_loss < 0.01);
1271
1272        // Verify evidence ledger entry.
1273        let evidence = outcome.audit_entry.to_evidence_ledger();
1274        assert_eq!(evidence.component, "test_contract");
1275        assert_eq!(evidence.action, "continue");
1276        assert!(evidence.is_valid());
1277
1278        // Now feed "bad" observations to shift posterior.
1279        for _ in 0..20 {
1280            contract.update_posterior(&mut posterior, 1);
1281        }
1282        assert!(posterior.probs()[1] > 0.99);
1283
1284        // Decision should now be "stop".
1285        let ctx2 = test_ctx(0.95, 201);
1286        let outcome2 = evaluate(&contract, &posterior, &ctx2);
1287        assert_eq!(outcome2.action_name, "stop");
1288    }
1289
1290    // -- Concurrent decision safety --
1291
1292    #[test]
1293    fn concurrent_decision_safety() {
1294        use std::sync::Arc;
1295        use std::thread;
1296
1297        let contract = Arc::new(TestContract::new());
1298        let results: Vec<_> = (0..10)
1299            .map(|i| {
1300                let c = Arc::clone(&contract);
1301                thread::spawn(move || {
1302                    let posterior = Posterior::uniform(2);
1303                    let ctx = EvalContext {
1304                        calibration_score: 0.9,
1305                        e_process: 1.0,
1306                        ci_width: 0.1,
1307                        decision_id: DecisionId::from_parts(1_700_000_000_000, u128::from(i)),
1308                        trace_id: TraceId::from_parts(1_700_000_000_000, u128::from(i)),
1309                        ts_unix_ms: 1_700_000_000_000 + i,
1310                    };
1311                    let outcome = evaluate(c.as_ref(), &posterior, &ctx);
1312                    assert!(!outcome.action_name.is_empty());
1313                    assert_eq!(outcome.expected_losses.len(), 2);
1314                    let evidence = outcome.audit_entry.to_evidence_ledger();
1315                    assert!(evidence.is_valid());
1316                    outcome
1317                })
1318            })
1319            .map(|h| h.join().unwrap())
1320            .collect();
1321        assert_eq!(results.len(), 10);
1322        // All should agree on the same action for uniform posterior.
1323        let actions: std::collections::HashSet<_> =
1324            results.iter().map(|r| r.action_name.clone()).collect();
1325        assert_eq!(
1326            actions.len(),
1327            1,
1328            "all threads should choose the same action"
1329        );
1330    }
1331
1332    // -- Cross-crate type verification --
1333
1334    #[test]
1335    fn cross_crate_franken_kernel_types() {
1336        // Verify DecisionId and TraceId are the franken_kernel versions.
1337        let did = DecisionId::from_parts(1_700_000_000_000, 42);
1338        assert_eq!(did.timestamp_ms(), 1_700_000_000_000);
1339        let tid = TraceId::from_parts(1_700_000_000_000, 1);
1340        assert_eq!(tid.timestamp_ms(), 1_700_000_000_000);
1341
1342        // Verify they work correctly in DecisionAuditEntry.
1343        let contract = TestContract::new();
1344        let posterior = Posterior::uniform(2);
1345        let ctx = EvalContext {
1346            calibration_score: 0.9,
1347            e_process: 1.0,
1348            ci_width: 0.1,
1349            decision_id: did,
1350            trace_id: tid,
1351            ts_unix_ms: 1_700_000_000_000,
1352        };
1353        let outcome = evaluate(&contract, &posterior, &ctx);
1354        assert_eq!(outcome.audit_entry.decision_id, did);
1355        assert_eq!(outcome.audit_entry.trace_id, tid);
1356    }
1357
1358    // -- Posterior serde roundtrips --
1359
1360    #[test]
1361    fn posterior_json_roundtrip() {
1362        let p = Posterior::new(vec![0.25, 0.75]).unwrap();
1363        let json = serde_json::to_string(&p).unwrap();
1364        let parsed: Posterior = serde_json::from_str(&json).unwrap();
1365        assert_eq!(p, parsed);
1366    }
1367
1368    // -- LossMatrix 3x3 TOML --
1369
1370    #[test]
1371    fn loss_matrix_3x3_toml_roundtrip() {
1372        let m = LossMatrix::new(
1373            vec!["s0".into(), "s1".into(), "s2".into()],
1374            vec!["a0".into(), "a1".into(), "a2".into()],
1375            vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
1376        )
1377        .unwrap();
1378        let toml_str = toml::to_string(&m).unwrap();
1379        let parsed: LossMatrix = toml::from_str(&toml_str).unwrap();
1380        assert_eq!(m, parsed);
1381    }
1382
1383    // -- DecisionOutcome debug --
1384
1385    #[test]
1386    fn decision_outcome_debug() {
1387        let contract = TestContract::new();
1388        let posterior = Posterior::uniform(2);
1389        let ctx = test_ctx(0.9, 300);
1390        let outcome = evaluate(&contract, &posterior, &ctx);
1391        let dbg = format!("{outcome:?}");
1392        assert!(dbg.contains("DecisionOutcome"));
1393        assert!(dbg.contains("action_name"));
1394    }
1395
1396    // -- Fallback all three triggers --
1397
1398    #[test]
1399    fn fallback_multiple_triggers_simultaneously() {
1400        let fp = FallbackPolicy::default();
1401        // All three conditions breached simultaneously.
1402        assert!(fp.should_fallback(0.3, 30.0, 0.9));
1403    }
1404
1405    #[test]
1406    fn fallback_no_trigger_at_exact_thresholds() {
1407        let fp = FallbackPolicy::default();
1408        // Exactly at thresholds: cal=0.7 (not < 0.7), e=20 (not > 20), ci=0.5 (not > 0.5).
1409        assert!(!fp.should_fallback(0.7, 20.0, 0.5));
1410    }
1411
1412    // -- Entropy edge cases --
1413
1414    #[test]
1415    fn posterior_entropy_three_state_uniform() {
1416        let p = Posterior::uniform(3);
1417        // entropy = log2(3) ≈ 1.585
1418        assert!((p.entropy() - 3.0_f64.log2()).abs() < 1e-10);
1419    }
1420
1421    #[test]
1422    fn posterior_entropy_single_state() {
1423        let p = Posterior::new(vec![1.0]).unwrap();
1424        assert!((p.entropy()).abs() < 1e-10);
1425    }
1426
1427    // -- ValidationError is std::error::Error --
1428
1429    #[test]
1430    fn validation_error_is_std_error() {
1431        fn assert_error<E: std::error::Error>() {}
1432        assert_error::<ValidationError>();
1433    }
1434}
1435
1436// ---------------------------------------------------------------------------
1437// Property-based tests (proptest)
1438// ---------------------------------------------------------------------------
1439
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod proptest_tests {
    use super::*;
    use proptest::prelude::*;

    /// Generate a valid probability vector of length `n`.
    ///
    /// Each entry is drawn from `[0.01, 1.0]` and the vector is divided by its
    /// sum, so every state ends up with strictly positive mass.
    fn arb_posterior(n: usize) -> impl Strategy<Value = Posterior> {
        proptest::collection::vec(0.01_f64..=1.0, n).prop_map(|mut v| {
            let sum: f64 = v.iter().sum();
            for p in &mut v {
                *p /= sum;
            }
            Posterior::new(v).unwrap()
        })
    }

    /// Generate a valid loss matrix of the given dimensions, with states named
    /// `s0..`, actions named `a0..`, and every loss drawn from `[0.0, 10.0]`.
    fn arb_loss_matrix(n_states: usize, n_actions: usize) -> impl Strategy<Value = LossMatrix> {
        let states: Vec<String> = (0..n_states).map(|i| format!("s{i}")).collect();
        let actions: Vec<String> = (0..n_actions).map(|i| format!("a{i}")).collect();
        proptest::collection::vec(0.0_f64..=10.0, n_states * n_actions).prop_map(move |values| {
            LossMatrix::new(states.clone(), actions.clone(), values).unwrap()
        })
    }

    /// Property check shared by the argmin tests: the action returned by
    /// `bayes_action` has an expected loss no larger (within `1e-10`) than any
    /// other action's.
    fn check_bayes_action_is_optimal(
        matrix: &LossMatrix,
        posterior: &Posterior,
    ) -> Result<(), proptest::test_runner::TestCaseError> {
        let chosen = matrix.bayes_action(posterior);
        let chosen_loss = matrix.expected_loss(posterior, chosen);
        for a in 0..matrix.n_actions() {
            let other_loss = matrix.expected_loss(posterior, a);
            prop_assert!(
                chosen_loss <= other_loss + 1e-10,
                "action {chosen} (loss {chosen_loss}) should be <= action {a} (loss {other_loss})"
            );
        }
        Ok(())
    }

    // Properties that are cheap enough to run 10k cases each.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(10_000))]

        // -- Argmin: chosen action minimizes expected loss for any valid posterior --

        #[test]
        fn bayes_action_minimizes_expected_loss(
            matrix in arb_loss_matrix(3, 3),
            posterior in arb_posterior(3),
        ) {
            check_bayes_action_is_optimal(&matrix, &posterior)?;
        }

        #[test]
        fn bayes_action_minimizes_2x2(
            matrix in arb_loss_matrix(2, 2),
            posterior in arb_posterior(2),
        ) {
            check_bayes_action_is_optimal(&matrix, &posterior)?;
        }

        // -- Posterior update preserves normalization --

        #[test]
        fn bayesian_update_preserves_normalization(
            prior in arb_posterior(4),
            likelihoods in proptest::collection::vec(0.01_f64..=1.0, 4usize),
        ) {
            let mut p = prior;
            p.bayesian_update(&likelihoods);
            let sum: f64 = p.probs().iter().sum();
            prop_assert!(
                (sum - 1.0).abs() < 1e-10,
                "posterior sum = {sum}, expected 1.0"
            );
            for &prob in p.probs() {
                prop_assert!(prob >= 0.0, "negative probability: {prob}");
            }
        }

        // -- Posterior: all elements non-negative after update --

        #[test]
        fn posterior_all_non_negative_after_update(
            prior in arb_posterior(3),
            likelihoods in proptest::collection::vec(0.0_f64..=1.0, 3usize),
        ) {
            // An all-zero likelihood vector carries no evidence and cannot be
            // normalized. Reject that case explicitly so proptest counts it as
            // rejected instead of treating a skipped check as a pass.
            let lik_sum: f64 = likelihoods.iter().sum();
            prop_assume!(lik_sum > 0.0);

            let mut p = prior;
            p.bayesian_update(&likelihoods);
            for &prob in p.probs() {
                prop_assert!(prob >= 0.0, "negative probability: {prob}");
            }
        }

        // -- Expected loss is a convex combination --

        #[test]
        fn expected_loss_within_loss_range(
            matrix in arb_loss_matrix(3, 3),
            posterior in arb_posterior(3),
        ) {
            for a in 0..matrix.n_actions() {
                let el = matrix.expected_loss(&posterior, a);
                let min_loss = (0..matrix.n_states())
                    .map(|s| matrix.get(s, a))
                    .fold(f64::INFINITY, f64::min);
                let max_loss = (0..matrix.n_states())
                    .map(|s| matrix.get(s, a))
                    .fold(f64::NEG_INFINITY, f64::max);
                prop_assert!(
                    el >= min_loss - 1e-10 && el <= max_loss + 1e-10,
                    "expected loss {el} outside [{min_loss}, {max_loss}]"
                );
            }
        }
    }

    // Serde round-trips run with proptest's default case count.
    proptest! {
        // -- FallbackPolicy serde roundtrip --

        #[test]
        fn fallback_policy_serde_roundtrip(
            cal in 0.0_f64..=1.0,
            e_proc in 0.0_f64..=100.0,
            ci in 0.0_f64..=10.0,
        ) {
            let fp = FallbackPolicy::new(cal, e_proc, ci).unwrap();
            let json = serde_json::to_string(&fp).unwrap();
            let parsed: FallbackPolicy = serde_json::from_str(&json).unwrap();
            // serde_json only guarantees bit-exact f64 parsing with its
            // `float_roundtrip` feature, so compare approximately.
            prop_assert!((fp.calibration_drift_threshold - parsed.calibration_drift_threshold).abs() < 1e-12);
            prop_assert!((fp.e_process_breach_threshold - parsed.e_process_breach_threshold).abs() < 1e-12);
            prop_assert!((fp.confidence_width_threshold - parsed.confidence_width_threshold).abs() < 1e-12);
        }

        // -- LossMatrix serde roundtrip --

        #[test]
        fn loss_matrix_serde_roundtrip(
            matrix in arb_loss_matrix(2, 3),
        ) {
            let json = serde_json::to_string(&matrix).unwrap();
            let parsed: LossMatrix = serde_json::from_str(&json).unwrap();
            prop_assert_eq!(matrix.state_names(), parsed.state_names());
            prop_assert_eq!(matrix.action_names(), parsed.action_names());
            // Approximate comparison for the f64 losses (see note above).
            for s in 0..matrix.n_states() {
                for a in 0..matrix.n_actions() {
                    prop_assert!((matrix.get(s, a) - parsed.get(s, a)).abs() < 1e-12);
                }
            }
        }
    }
}