Skip to main content

franken_decision/
lib.rs

1//! Decision Contract schema and runtime for FrankenSuite (bd-3ai21).
2//!
3//! The third leg of the foundation tripod alongside `franken_kernel` (types)
4//! and `franken_evidence` (audit ledger). Every FrankenSuite project that
5//! makes runtime decisions uses this crate's contract schema.
6//!
7//! # Core abstractions
8//!
9//! - [`DecisionContract`] — trait defining state space, actions, losses, and
10//!   posterior updates. Implementable in <50 lines.
11//! - [`LossMatrix`] — non-negative loss values indexed by (state, action),
12//!   serializable to TOML for runtime reconfiguration.
13//! - [`Posterior`] — discrete probability distribution with O(|S|)
14//!   no-allocation Bayesian updates.
15//! - [`FallbackPolicy`] — calibration drift, e-process breach, and
16//!   confidence interval width thresholds.
17//! - [`DecisionAuditEntry`] — links decisions to [`EvidenceLedger`] entries.
18//!
19//! # Example
20//!
21//! ```
22//! use franken_decision::{
23//!     DecisionContract, EvalContext, FallbackPolicy, LossMatrix, Posterior, evaluate,
24//! };
25//! use franken_kernel::DecisionId;
26//!
27//! // Define a simple 2-state, 2-action contract.
28//! struct MyContract {
29//!     states: Vec<String>,
30//!     actions: Vec<String>,
31//!     losses: LossMatrix,
32//!     policy: FallbackPolicy,
33//! }
34//!
35//! impl DecisionContract for MyContract {
36//!     fn name(&self) -> &str { "example" }
37//!     fn state_space(&self) -> &[String] { &self.states }
38//!     fn action_set(&self) -> &[String] { &self.actions }
39//!     fn loss_matrix(&self) -> &LossMatrix { &self.losses }
40//!     fn update_posterior(&self, posterior: &mut Posterior, observation: usize) {
41//!         let likelihoods = [0.9, 0.1];
42//!         posterior.bayesian_update(&likelihoods);
43//!     }
44//!     fn choose_action(&self, posterior: &Posterior) -> usize {
45//!         self.losses.bayes_action(posterior)
46//!     }
47//!     fn fallback_action(&self) -> usize { 0 }
48//!     fn fallback_policy(&self) -> &FallbackPolicy { &self.policy }
49//! }
50//!
51//! let contract = MyContract {
52//!     states: vec!["good".into(), "bad".into()],
53//!     actions: vec!["continue".into(), "stop".into()],
54//!     losses: LossMatrix::new(
55//!         vec!["good".into(), "bad".into()],
56//!         vec!["continue".into(), "stop".into()],
57//!         vec![0.0, 0.3, 0.8, 0.1],
58//!     ).unwrap(),
59//!     policy: FallbackPolicy::default(),
60//! };
61//!
62//! let posterior = Posterior::uniform(2);
63//! let decision_id = DecisionId::from_parts(1_700_000_000_000, 42);
64//! let trace_id = franken_kernel::TraceId::from_parts(1_700_000_000_000, 1);
65//!
66//! let ctx = EvalContext {
67//!     calibration_score: 0.9,
68//!     e_process: 0.5,
69//!     ci_width: 0.1,
70//!     decision_id,
71//!     trace_id,
72//!     ts_unix_ms: 1_700_000_000_000,
73//! };
74//! let outcome = evaluate(&contract, &posterior, &ctx);
75//! assert!(!outcome.fallback_active);
76//! ```
77
78#![forbid(unsafe_code)]
79
80use std::collections::BTreeMap;
81use std::fmt;
82
83use franken_evidence::{EvidenceLedger, EvidenceLedgerBuilder};
84use franken_kernel::{DecisionId, TraceId};
85use serde::{Deserialize, Deserializer, Serialize};
86
87// ---------------------------------------------------------------------------
88// Validation errors
89// ---------------------------------------------------------------------------
90
/// Validation errors produced when constructing or deserializing the
/// decision types in this crate.
///
/// Several variants carry `f64` payloads, so only `PartialEq` (not `Eq`) is
/// derived. A variant holding a NaN `value` compares unequal to itself; match
/// on it (e.g. with `matches!`) instead of comparing with `==`.
#[derive(Clone, Debug, PartialEq)]
pub enum ValidationError {
    /// Loss matrix contains a non-finite value (NaN or ±infinity).
    InvalidLoss {
        /// State (row) index of the invalid entry.
        state: usize,
        /// Action (column) index of the invalid entry.
        action: usize,
        /// The invalid value.
        value: f64,
    },
    /// Loss matrix contains a negative value.
    NegativeLoss {
        /// State (row) index of the negative entry.
        state: usize,
        /// Action (column) index of the negative entry.
        action: usize,
        /// The negative value.
        value: f64,
    },
    /// Loss matrix value count does not match `n_states * n_actions`.
    DimensionMismatch {
        /// Expected number of values (states * actions).
        expected: usize,
        /// Actual number of values provided.
        got: usize,
    },
    /// Posterior probabilities do not sum to 1.0 within the normalization
    /// tolerance.
    PosteriorNotNormalized {
        /// Actual sum of the posterior.
        sum: f64,
    },
    /// Posterior contains a negative or non-finite probability.
    InvalidPosteriorProbability {
        /// Index of the first invalid probability.
        index: usize,
        /// The invalid value.
        value: f64,
    },
    /// Posterior length does not match state space size.
    ///
    /// NOTE(review): nothing in this part of the file constructs this
    /// variant; presumably reserved for validating a posterior against a
    /// contract's state space — confirm intended use.
    PosteriorLengthMismatch {
        /// Expected length.
        expected: usize,
        /// Actual length.
        got: usize,
    },
    /// State space or action set is empty.
    EmptySpace {
        /// Name of the empty field (e.g. `"state_names"`).
        field: &'static str,
    },
    /// Threshold value is out of valid range.
    ThresholdOutOfRange {
        /// Name of the offending threshold field.
        field: &'static str,
        /// The invalid value.
        value: f64,
    },
}
151
152impl fmt::Display for ValidationError {
153    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
154        match self {
155            Self::InvalidLoss {
156                state,
157                action,
158                value,
159            } => write!(
160                f,
161                "loss must be finite at state={state}, action={action}, got {value}"
162            ),
163            Self::NegativeLoss {
164                state,
165                action,
166                value,
167            } => write!(f, "negative loss {value} at state={state}, action={action}"),
168            Self::DimensionMismatch { expected, got } => {
169                write!(
170                    f,
171                    "dimension mismatch: expected {expected} values, got {got}"
172                )
173            }
174            Self::PosteriorNotNormalized { sum } => {
175                write!(f, "posterior sums to {sum}, expected 1.0")
176            }
177            Self::InvalidPosteriorProbability { index, value } => {
178                write!(
179                    f,
180                    "posterior[{index}] must be finite and non-negative, got {value}"
181                )
182            }
183            Self::PosteriorLengthMismatch { expected, got } => {
184                write!(
185                    f,
186                    "posterior length {got} does not match state count {expected}"
187                )
188            }
189            Self::EmptySpace { field } => write!(f, "{field} must not be empty"),
190            Self::ThresholdOutOfRange { field, value } => {
191                write!(f, "{field} threshold {value} out of valid range")
192            }
193        }
194    }
195}
196
/// `ValidationError` is a leaf error: it keeps the default
/// `Error::source`, which returns `None`.
impl std::error::Error for ValidationError {}
198
199// ---------------------------------------------------------------------------
200// LossMatrix
201// ---------------------------------------------------------------------------
202
/// A loss matrix indexed by (state, action) pairs.
///
/// Stored in row-major order: `values[state * n_actions + action]`.
/// All values are finite and non-negative. [`LossMatrix::new`] enforces this,
/// and deserialization goes through `new` as well. Serializable to TOML/JSON
/// for runtime reconfiguration.
///
/// Duplicate state or action names are not rejected by `new`.
#[derive(Clone, Debug, Serialize, PartialEq)]
pub struct LossMatrix {
    /// Row labels, one per state.
    state_names: Vec<String>,
    /// Column labels, one per action.
    action_names: Vec<String>,
    /// Row-major losses; length is `state_names.len() * action_names.len()`.
    values: Vec<f64>,
}
214
/// Unvalidated wire form of [`LossMatrix`].
///
/// Deserialization first fills this plain struct and then hands its fields
/// to [`LossMatrix::new`]. Malformed input is therefore rejected with a
/// [`ValidationError`] instead of producing an invalid matrix. Field names
/// must stay identical to `LossMatrix`'s so that serialized output
/// deserializes again.
#[derive(Deserialize)]
struct LossMatrixRepr {
    state_names: Vec<String>,
    action_names: Vec<String>,
    values: Vec<f64>,
}
221
222impl<'de> Deserialize<'de> for LossMatrix {
223    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
224    where
225        D: Deserializer<'de>,
226    {
227        let repr = LossMatrixRepr::deserialize(deserializer)?;
228        Self::new(repr.state_names, repr.action_names, repr.values)
229            .map_err(serde::de::Error::custom)
230    }
231}
232
233impl LossMatrix {
234    /// Create a new loss matrix.
235    ///
236    /// `values` must have exactly `state_names.len() * action_names.len()`
237    /// elements, all non-negative. Laid out in row-major order:
238    /// `values[s * n_actions + a]` is the loss for state `s`, action `a`.
239    pub fn new(
240        state_names: Vec<String>,
241        action_names: Vec<String>,
242        values: Vec<f64>,
243    ) -> Result<Self, ValidationError> {
244        if state_names.is_empty() {
245            return Err(ValidationError::EmptySpace {
246                field: "state_names",
247            });
248        }
249        if action_names.is_empty() {
250            return Err(ValidationError::EmptySpace {
251                field: "action_names",
252            });
253        }
254        let expected = state_names.len() * action_names.len();
255        if values.len() != expected {
256            return Err(ValidationError::DimensionMismatch {
257                expected,
258                got: values.len(),
259            });
260        }
261        let n_actions = action_names.len();
262        for (i, &v) in values.iter().enumerate() {
263            if !v.is_finite() {
264                return Err(ValidationError::InvalidLoss {
265                    state: i / n_actions,
266                    action: i % n_actions,
267                    value: v,
268                });
269            }
270            if v < 0.0 {
271                return Err(ValidationError::NegativeLoss {
272                    state: i / n_actions,
273                    action: i % n_actions,
274                    value: v,
275                });
276            }
277        }
278        Ok(Self {
279            state_names,
280            action_names,
281            values,
282        })
283    }
284
285    /// Get the loss for a specific (state, action) pair.
286    pub fn get(&self, state: usize, action: usize) -> f64 {
287        self.values[state * self.action_names.len() + action]
288    }
289
290    /// Number of states.
291    pub fn n_states(&self) -> usize {
292        self.state_names.len()
293    }
294
295    /// Number of actions.
296    pub fn n_actions(&self) -> usize {
297        self.action_names.len()
298    }
299
300    /// State labels.
301    pub fn state_names(&self) -> &[String] {
302        &self.state_names
303    }
304
305    /// Action labels.
306    pub fn action_names(&self) -> &[String] {
307        &self.action_names
308    }
309
310    /// Compute expected loss for a specific action given a posterior.
311    ///
312    /// `E[loss|a] = sum_s posterior(s) * loss(s, a)`
313    pub fn expected_loss(&self, posterior: &Posterior, action: usize) -> f64 {
314        posterior
315            .probs()
316            .iter()
317            .enumerate()
318            .map(|(s, &p)| p * self.get(s, action))
319            .sum()
320    }
321
322    /// Compute expected losses for all actions as a name-indexed map.
323    pub fn expected_losses(&self, posterior: &Posterior) -> BTreeMap<String, f64> {
324        self.action_names
325            .iter()
326            .enumerate()
327            .map(|(a, name)| (name.clone(), self.expected_loss(posterior, a)))
328            .collect()
329    }
330
331    /// Choose the Bayes-optimal action (minimum expected loss).
332    ///
333    /// Returns the action index. Ties are broken by lowest index.
334    pub fn bayes_action(&self, posterior: &Posterior) -> usize {
335        (0..self.action_names.len())
336            .min_by(|&a, &b| {
337                self.expected_loss(posterior, a)
338                    .partial_cmp(&self.expected_loss(posterior, b))
339                    .unwrap_or(std::cmp::Ordering::Equal)
340            })
341            .unwrap_or(0)
342    }
343}
344
345// ---------------------------------------------------------------------------
346// Posterior
347// ---------------------------------------------------------------------------
348
/// Maximum absolute deviation of a posterior's sum from 1.0 accepted by
/// [`Posterior::new`] (and therefore by `Posterior` deserialization).
const NORMALIZATION_TOLERANCE: f64 = 1e-6;
351
/// A discrete probability distribution over states.
///
/// Supports in-place Bayesian updates in O(|S|) with no allocation.
///
/// [`Posterior::new`], which deserialization also uses, guarantees that
/// entries are finite, non-negative, and sum to 1.0 within
/// `NORMALIZATION_TOLERANCE`. That invariant is not re-checked by
/// [`Posterior::probs_mut`], [`Posterior::uniform`] with `n == 0`, or
/// [`Posterior::bayesian_update`] with unvalidated likelihoods. Callers using
/// those methods are responsible for keeping it.
#[derive(Clone, Debug, Serialize, PartialEq)]
pub struct Posterior {
    /// Probability of each state, indexed by state position.
    probs: Vec<f64>,
}
359
/// Unvalidated wire form of [`Posterior`]. Deserialization passes `probs`
/// through [`Posterior::new`] so invalid distributions are rejected.
#[derive(Deserialize)]
struct PosteriorRepr {
    probs: Vec<f64>,
}
364
365impl<'de> Deserialize<'de> for Posterior {
366    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
367    where
368        D: Deserializer<'de>,
369    {
370        let repr = PosteriorRepr::deserialize(deserializer)?;
371        Self::new(repr.probs).map_err(serde::de::Error::custom)
372    }
373}
374
375impl Posterior {
376    /// Create from explicit probabilities.
377    ///
378    /// Probabilities must sum to ~1.0 (within tolerance) and be non-negative.
379    pub fn new(probs: Vec<f64>) -> Result<Self, ValidationError> {
380        for (index, &value) in probs.iter().enumerate() {
381            if !value.is_finite() || value < 0.0 {
382                return Err(ValidationError::InvalidPosteriorProbability { index, value });
383            }
384        }
385        let sum: f64 = probs.iter().sum();
386        if (sum - 1.0).abs() > NORMALIZATION_TOLERANCE {
387            return Err(ValidationError::PosteriorNotNormalized { sum });
388        }
389        Ok(Self { probs })
390    }
391
392    /// Create a uniform prior over `n` states.
393    #[allow(clippy::cast_precision_loss)]
394    pub fn uniform(n: usize) -> Self {
395        let p = 1.0 / n as f64;
396        Self { probs: vec![p; n] }
397    }
398
399    /// Probability values (immutable).
400    pub fn probs(&self) -> &[f64] {
401        &self.probs
402    }
403
404    /// Mutable access to probability values for in-place updates.
405    pub fn probs_mut(&mut self) -> &mut [f64] {
406        &mut self.probs
407    }
408
409    /// Number of states in the distribution.
410    pub fn len(&self) -> usize {
411        self.probs.len()
412    }
413
414    /// Whether the distribution is empty.
415    pub fn is_empty(&self) -> bool {
416        self.probs.is_empty()
417    }
418
419    /// Bayesian update: multiply by likelihoods and renormalize.
420    ///
421    /// `likelihoods[s]` = P(observation | state = s).
422    /// Runs in O(|S|) with no allocation.
423    ///
424    /// # Panics
425    ///
426    /// Panics if `likelihoods.len() != self.len()`.
427    pub fn bayesian_update(&mut self, likelihoods: &[f64]) {
428        assert_eq!(likelihoods.len(), self.probs.len());
429        for (p, &l) in self.probs.iter_mut().zip(likelihoods) {
430            *p *= l;
431        }
432        self.normalize();
433    }
434
435    /// Renormalize probabilities to sum to 1.0.
436    pub fn normalize(&mut self) {
437        let sum: f64 = self.probs.iter().sum();
438        if sum > 0.0 {
439            for p in &mut self.probs {
440                *p /= sum;
441            }
442        }
443    }
444
445    /// Shannon entropy: -sum p * log2(p).
446    pub fn entropy(&self) -> f64 {
447        self.probs
448            .iter()
449            .filter(|&&p| p > 0.0)
450            .map(|&p| -p * p.log2())
451            .sum()
452    }
453
454    /// Index of the most probable state (MAP estimate).
455    ///
456    /// Ties are broken by lowest index.
457    pub fn map_state(&self) -> usize {
458        self.probs
459            .iter()
460            .enumerate()
461            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
462            .map_or(0, |(i, _)| i)
463    }
464}
465
466// ---------------------------------------------------------------------------
467// FallbackPolicy
468// ---------------------------------------------------------------------------
469
/// Conditions under which to activate fallback heuristics.
///
/// A decision engine should switch to [`DecisionContract::fallback_action`]
/// when any threshold is breached (see [`FallbackPolicy::should_fallback`]).
///
/// Use [`FallbackPolicy::new`] (or deserialization, which calls it) to get
/// range-checked thresholds. The fields are public, so a struct literal or a
/// direct field assignment bypasses that validation.
#[derive(Clone, Debug, Serialize, PartialEq)]
pub struct FallbackPolicy {
    /// Activate fallback if calibration score drops below this value.
    /// `new` requires it to lie in `[0, 1]`.
    pub calibration_drift_threshold: f64,
    /// Activate fallback if e-process statistic exceeds this value.
    /// `new` requires it to be finite and non-negative.
    pub e_process_breach_threshold: f64,
    /// Activate fallback if confidence interval width exceeds this value.
    /// `new` requires it to be finite and non-negative.
    pub confidence_width_threshold: f64,
}
483
/// Unvalidated wire form of [`FallbackPolicy`]. Deserialization routes its
/// fields through [`FallbackPolicy::new`] for range checking.
#[derive(Deserialize)]
// Every field ends in `_threshold`, which `clippy::struct_field_names` flags.
// The names must mirror `FallbackPolicy`'s public fields, because those
// define the serialized keys.
#[allow(clippy::struct_field_names)]
struct FallbackPolicyRepr {
    calibration_drift_threshold: f64,
    e_process_breach_threshold: f64,
    confidence_width_threshold: f64,
}
491
492impl<'de> Deserialize<'de> for FallbackPolicy {
493    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
494    where
495        D: Deserializer<'de>,
496    {
497        let repr = FallbackPolicyRepr::deserialize(deserializer)?;
498        Self::new(
499            repr.calibration_drift_threshold,
500            repr.e_process_breach_threshold,
501            repr.confidence_width_threshold,
502        )
503        .map_err(serde::de::Error::custom)
504    }
505}
506
507impl FallbackPolicy {
508    /// Create a new fallback policy.
509    ///
510    /// `calibration_drift_threshold` must be in [0, 1].
511    /// Other thresholds must be non-negative.
512    pub fn new(
513        calibration_drift_threshold: f64,
514        e_process_breach_threshold: f64,
515        confidence_width_threshold: f64,
516    ) -> Result<Self, ValidationError> {
517        if !calibration_drift_threshold.is_finite()
518            || !(0.0..=1.0).contains(&calibration_drift_threshold)
519        {
520            return Err(ValidationError::ThresholdOutOfRange {
521                field: "calibration_drift_threshold",
522                value: calibration_drift_threshold,
523            });
524        }
525        if !e_process_breach_threshold.is_finite() || e_process_breach_threshold < 0.0 {
526            return Err(ValidationError::ThresholdOutOfRange {
527                field: "e_process_breach_threshold",
528                value: e_process_breach_threshold,
529            });
530        }
531        if !confidence_width_threshold.is_finite() || confidence_width_threshold < 0.0 {
532            return Err(ValidationError::ThresholdOutOfRange {
533                field: "confidence_width_threshold",
534                value: confidence_width_threshold,
535            });
536        }
537        Ok(Self {
538            calibration_drift_threshold,
539            e_process_breach_threshold,
540            confidence_width_threshold,
541        })
542    }
543
544    /// Check if fallback should be activated based on current metrics.
545    pub fn should_fallback(&self, calibration_score: f64, e_process: f64, ci_width: f64) -> bool {
546        calibration_score < self.calibration_drift_threshold
547            || e_process > self.e_process_breach_threshold
548            || ci_width > self.confidence_width_threshold
549    }
550}
551
552impl Default for FallbackPolicy {
553    fn default() -> Self {
554        Self {
555            calibration_drift_threshold: 0.7,
556            e_process_breach_threshold: 20.0,
557            confidence_width_threshold: 0.5,
558        }
559    }
560}
561
562// ---------------------------------------------------------------------------
563// DecisionContract trait
564// ---------------------------------------------------------------------------
565
/// A contract defining the decision-making framework for a component.
///
/// Implementors define the state space, action set, loss matrix, and
/// posterior update logic. The [`evaluate`] function orchestrates the
/// full decision pipeline and produces an auditable outcome.
///
/// The trait does not check consistency between its methods:
/// - The indices returned by [`choose_action`](Self::choose_action) and
///   [`fallback_action`](Self::fallback_action) index into
///   [`action_set`](Self::action_set).
/// - `action_set` is expected to list the same actions, in the same order,
///   as `loss_matrix().action_names()`.
/// - Likewise, `state_space` should match `loss_matrix().state_names()`.
///
/// NOTE(review): consider validating this correspondence once, at
/// construction time.
pub trait DecisionContract {
    /// Human-readable contract name (e.g., "scheduler", "load_balancer").
    ///
    /// Recorded as `contract_name` in the [`DecisionAuditEntry`] that
    /// [`evaluate`] produces.
    fn name(&self) -> &str;

    /// Ordered labels for the state space.
    fn state_space(&self) -> &[String];

    /// Ordered labels for the action set.
    fn action_set(&self) -> &[String];

    /// The loss matrix for this contract.
    fn loss_matrix(&self) -> &LossMatrix;

    /// Update the posterior in place given an observation encoded as
    /// `state_index`.
    ///
    /// The meaning of `state_index` is up to the implementor. The test
    /// contract in this file treats it as the index of the observed state.
    /// [`evaluate`] does not call this method, so call it yourself beforehand
    /// if the posterior should reflect new observations.
    fn update_posterior(&self, posterior: &mut Posterior, state_index: usize);

    /// Choose the optimal action given the current posterior.
    ///
    /// Returns an action index into [`action_set`](Self::action_set).
    /// [`evaluate`] panics if the index is out of range.
    fn choose_action(&self, posterior: &Posterior) -> usize;

    /// The fallback action when the model is unreliable.
    ///
    /// Returns an action index into [`action_set`](Self::action_set).
    /// [`evaluate`] panics if the index is out of range.
    fn fallback_action(&self) -> usize;

    /// Policy governing fallback activation; consulted on every [`evaluate`] call.
    fn fallback_policy(&self) -> &FallbackPolicy;
}
600
601// ---------------------------------------------------------------------------
602// DecisionAuditEntry
603// ---------------------------------------------------------------------------
604
/// Structured audit record linking a decision to the evidence ledger.
///
/// Captures the full context of a runtime decision for offline analysis
/// and replay. [`evaluate`] fills every field from the contract, the
/// posterior, and the [`EvalContext`].
/// [`DecisionAuditEntry::to_evidence_ledger`] converts it into an
/// [`EvidenceLedger`] record.
///
/// NOTE(review): `to_evidence_ledger` does not forward `decision_id` or
/// `trace_id` — confirm the ledger does not need them.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DecisionAuditEntry {
    /// Unique identifier for this decision (copied from [`EvalContext::decision_id`]).
    pub decision_id: DecisionId,
    /// Trace context for distributed tracing (copied from [`EvalContext::trace_id`]).
    pub trace_id: TraceId,
    /// Name of the decision contract that was evaluated ([`DecisionContract::name`]).
    pub contract_name: String,
    /// Label of the action that was chosen (the fallback or the contract's choice).
    pub action_chosen: String,
    /// Expected loss of the chosen action under the posterior snapshot.
    pub expected_loss: f64,
    /// Calibration score at decision time, as supplied in the [`EvalContext`].
    pub calibration_score: f64,
    /// Whether the fallback heuristic was active.
    pub fallback_active: bool,
    /// Copy of the posterior probabilities at decision time.
    pub posterior_snapshot: Vec<f64>,
    /// Expected loss for each candidate action, keyed by action name.
    pub expected_loss_by_action: BTreeMap<String, f64>,
    /// Unix timestamp in milliseconds (copied from [`EvalContext::ts_unix_ms`]).
    pub ts_unix_ms: u64,
}
632
633impl DecisionAuditEntry {
634    /// Convert to an [`EvidenceLedger`] entry for structured tracing.
635    pub fn to_evidence_ledger(&self) -> EvidenceLedger {
636        let mut builder = EvidenceLedgerBuilder::new()
637            .ts_unix_ms(self.ts_unix_ms)
638            .component(&self.contract_name)
639            .action(&self.action_chosen)
640            .posterior(self.posterior_snapshot.clone())
641            .chosen_expected_loss(self.expected_loss)
642            .calibration_score(self.calibration_score)
643            .fallback_active(self.fallback_active);
644
645        for (action, &loss) in &self.expected_loss_by_action {
646            builder = builder.expected_loss(action, loss);
647        }
648
649        builder
650            .build()
651            .expect("audit entry should produce valid evidence ledger")
652    }
653}
654
655// ---------------------------------------------------------------------------
656// DecisionOutcome
657// ---------------------------------------------------------------------------
658
/// Result of evaluating a decision contract with [`evaluate`].
#[derive(Clone, Debug)]
pub struct DecisionOutcome {
    /// Index of the chosen action into the contract's `action_set()`.
    pub action_index: usize,
    /// Name of the chosen action (the `action_set()` label at `action_index`).
    pub action_name: String,
    /// Expected loss of the chosen action under the evaluated posterior.
    pub expected_loss: f64,
    /// Expected losses for all candidate actions, keyed by action name.
    pub expected_losses: BTreeMap<String, f64>,
    /// Whether fallback was activated. When `true`, `action_index` is the
    /// contract's `fallback_action()`.
    pub fallback_active: bool,
    /// Full audit entry for this decision. It repeats the fields above and
    /// adds tracing context.
    pub audit_entry: DecisionAuditEntry,
}
675
676// ---------------------------------------------------------------------------
677// EvalContext
678// ---------------------------------------------------------------------------
679
/// Runtime context for a single decision evaluation.
///
/// Bundles the monitoring metrics and tracing identifiers needed by
/// [`evaluate`]. The three metrics are checked against the contract's
/// [`FallbackPolicy`]. The identifiers, timestamp, and calibration score are
/// copied into the resulting [`DecisionAuditEntry`]; `e_process` and
/// `ci_width` are not recorded there.
#[derive(Clone, Debug)]
pub struct EvalContext {
    /// Current calibration score. Fallback triggers when it drops below
    /// [`FallbackPolicy::calibration_drift_threshold`].
    pub calibration_score: f64,
    /// Current e-process statistic. Fallback triggers when it exceeds
    /// [`FallbackPolicy::e_process_breach_threshold`].
    pub e_process: f64,
    /// Current confidence interval width. Fallback triggers when it exceeds
    /// [`FallbackPolicy::confidence_width_threshold`].
    pub ci_width: f64,
    /// Unique identifier for this decision.
    pub decision_id: DecisionId,
    /// Trace context for distributed tracing.
    pub trace_id: TraceId,
    /// Unix timestamp in milliseconds.
    pub ts_unix_ms: u64,
}
699
700// ---------------------------------------------------------------------------
701// Evaluate
702// ---------------------------------------------------------------------------
703
704/// Evaluate a decision contract and produce a full audit trail.
705///
706/// This is the primary entry point for making auditable decisions.
707/// It computes expected losses, checks fallback conditions, and produces
708/// a [`DecisionOutcome`] with a linked [`DecisionAuditEntry`].
709pub fn evaluate<C: DecisionContract>(
710    contract: &C,
711    posterior: &Posterior,
712    ctx: &EvalContext,
713) -> DecisionOutcome {
714    let loss_matrix = contract.loss_matrix();
715    let expected_losses = loss_matrix.expected_losses(posterior);
716
717    let fallback_active = contract.fallback_policy().should_fallback(
718        ctx.calibration_score,
719        ctx.e_process,
720        ctx.ci_width,
721    );
722
723    let action_index = if fallback_active {
724        contract.fallback_action()
725    } else {
726        contract.choose_action(posterior)
727    };
728
729    let action_name = contract.action_set()[action_index].clone();
730    let expected_loss = expected_losses[&action_name];
731
732    let audit_entry = DecisionAuditEntry {
733        decision_id: ctx.decision_id,
734        trace_id: ctx.trace_id,
735        contract_name: contract.name().to_string(),
736        action_chosen: action_name.clone(),
737        expected_loss,
738        calibration_score: ctx.calibration_score,
739        fallback_active,
740        posterior_snapshot: posterior.probs().to_vec(),
741        expected_loss_by_action: expected_losses.clone(),
742        ts_unix_ms: ctx.ts_unix_ms,
743    };
744
745    DecisionOutcome {
746        action_index,
747        action_name,
748        expected_loss,
749        expected_losses,
750        fallback_active,
751        audit_entry,
752    }
753}
754
755// ---------------------------------------------------------------------------
756// Tests
757// ---------------------------------------------------------------------------
758
759#[cfg(test)]
760#[allow(clippy::float_cmp)]
761mod tests {
762    use super::*;
763
764    // -- Helpers --
765
766    fn two_state_matrix() -> LossMatrix {
767        // States: [good, bad], Actions: [continue, stop]
768        // loss(good, continue) = 0.0, loss(good, stop) = 0.3
769        // loss(bad, continue)  = 0.8, loss(bad, stop)  = 0.1
770        LossMatrix::new(
771            vec!["good".into(), "bad".into()],
772            vec!["continue".into(), "stop".into()],
773            vec![0.0, 0.3, 0.8, 0.1],
774        )
775        .unwrap()
776    }
777
778    struct TestContract {
779        states: Vec<String>,
780        actions: Vec<String>,
781        losses: LossMatrix,
782        policy: FallbackPolicy,
783    }
784
785    impl TestContract {
786        fn new() -> Self {
787            Self {
788                states: vec!["good".into(), "bad".into()],
789                actions: vec!["continue".into(), "stop".into()],
790                losses: two_state_matrix(),
791                policy: FallbackPolicy::default(),
792            }
793        }
794    }
795
796    #[allow(clippy::unnecessary_literal_bound)]
797    impl DecisionContract for TestContract {
798        fn name(&self) -> &str {
799            "test_contract"
800        }
801        fn state_space(&self) -> &[String] {
802            &self.states
803        }
804        fn action_set(&self) -> &[String] {
805            &self.actions
806        }
807        fn loss_matrix(&self) -> &LossMatrix {
808            &self.losses
809        }
810        fn update_posterior(&self, posterior: &mut Posterior, observation: usize) {
811            // Simple likelihood model: observed state gets high likelihood.
812            let mut likelihoods = vec![0.1; self.states.len()];
813            likelihoods[observation] = 0.9;
814            posterior.bayesian_update(&likelihoods);
815        }
816        fn choose_action(&self, posterior: &Posterior) -> usize {
817            self.losses.bayes_action(posterior)
818        }
819        fn fallback_action(&self) -> usize {
820            0 // "continue"
821        }
822        fn fallback_policy(&self) -> &FallbackPolicy {
823            &self.policy
824        }
825    }
826
827    // -- LossMatrix tests --
828
829    #[test]
830    fn loss_matrix_creation() {
831        let m = two_state_matrix();
832        assert_eq!(m.n_states(), 2);
833        assert_eq!(m.n_actions(), 2);
834        assert_eq!(m.get(0, 0), 0.0);
835        assert_eq!(m.get(0, 1), 0.3);
836        assert_eq!(m.get(1, 0), 0.8);
837        assert_eq!(m.get(1, 1), 0.1);
838    }
839
840    #[test]
841    fn loss_matrix_empty_states_rejected() {
842        let err = LossMatrix::new(vec![], vec!["a".into()], vec![]).unwrap_err();
843        assert!(matches!(
844            err,
845            ValidationError::EmptySpace {
846                field: "state_names"
847            }
848        ));
849    }
850
851    #[test]
852    fn loss_matrix_empty_actions_rejected() {
853        let err = LossMatrix::new(vec!["s".into()], vec![], vec![]).unwrap_err();
854        assert!(matches!(
855            err,
856            ValidationError::EmptySpace {
857                field: "action_names"
858            }
859        ));
860    }
861
862    #[test]
863    fn loss_matrix_dimension_mismatch() {
864        let err = LossMatrix::new(
865            vec!["s1".into(), "s2".into()],
866            vec!["a1".into()],
867            vec![0.1], // needs 2 values
868        )
869        .unwrap_err();
870        assert!(matches!(
871            err,
872            ValidationError::DimensionMismatch {
873                expected: 2,
874                got: 1
875            }
876        ));
877    }
878
879    #[test]
880    fn loss_matrix_negative_rejected() {
881        let err = LossMatrix::new(vec!["s".into()], vec!["a".into()], vec![-0.5]).unwrap_err();
882        assert!(matches!(
883            err,
884            ValidationError::NegativeLoss {
885                state: 0,
886                action: 0,
887                ..
888            }
889        ));
890    }
891
892    #[test]
893    fn loss_matrix_non_finite_rejected() {
894        let err = LossMatrix::new(vec!["s".into()], vec!["a".into()], vec![f64::NAN]).unwrap_err();
895        assert!(matches!(
896            err,
897            ValidationError::InvalidLoss {
898                state: 0,
899                action: 0,
900                value
901            } if value.is_nan()
902        ));
903    }
904
905    #[test]
906    fn loss_matrix_expected_loss() {
907        let m = two_state_matrix();
908        let posterior = Posterior::new(vec![0.8, 0.2]).unwrap();
909        // E[loss|continue] = 0.8*0.0 + 0.2*0.8 = 0.16
910        let el_continue = m.expected_loss(&posterior, 0);
911        assert!((el_continue - 0.16).abs() < 1e-10);
912        // E[loss|stop] = 0.8*0.3 + 0.2*0.1 = 0.26
913        let el_stop = m.expected_loss(&posterior, 1);
914        assert!((el_stop - 0.26).abs() < 1e-10);
915    }
916
917    #[test]
918    fn loss_matrix_bayes_action() {
919        let m = two_state_matrix();
920        // When mostly good, continue is optimal.
921        let mostly_good = Posterior::new(vec![0.9, 0.1]).unwrap();
922        assert_eq!(m.bayes_action(&mostly_good), 0); // continue
923        // When mostly bad, stop is optimal.
924        let mostly_bad = Posterior::new(vec![0.2, 0.8]).unwrap();
925        assert_eq!(m.bayes_action(&mostly_bad), 1); // stop
926    }
927
928    #[test]
929    fn loss_matrix_expected_losses_map() {
930        let m = two_state_matrix();
931        let posterior = Posterior::uniform(2);
932        let losses = m.expected_losses(&posterior);
933        assert_eq!(losses.len(), 2);
934        assert!(losses.contains_key("continue"));
935        assert!(losses.contains_key("stop"));
936    }
937
938    #[test]
939    fn loss_matrix_names() {
940        let m = two_state_matrix();
941        assert_eq!(m.state_names(), &["good", "bad"]);
942        assert_eq!(m.action_names(), &["continue", "stop"]);
943    }
944
945    #[test]
946    fn loss_matrix_toml_roundtrip() {
947        let m = two_state_matrix();
948        let toml_str = toml::to_string(&m).unwrap();
949        let parsed: LossMatrix = toml::from_str(&toml_str).unwrap();
950        assert_eq!(m, parsed);
951    }
952
953    #[test]
954    fn loss_matrix_json_roundtrip() {
955        let m = two_state_matrix();
956        let json = serde_json::to_string(&m).unwrap();
957        let parsed: LossMatrix = serde_json::from_str(&json).unwrap();
958        assert_eq!(m, parsed);
959    }
960
961    #[test]
962    fn loss_matrix_json_invalid_value_rejected_at_deserialize() {
963        let json = r#"{"state_names":["s"],"action_names":["a"],"values":[-0.5]}"#;
964        let err = serde_json::from_str::<LossMatrix>(json).unwrap_err();
965        assert!(err.to_string().contains("negative loss"));
966    }
967
968    // -- Posterior tests --
969
970    #[test]
971    fn posterior_uniform() {
972        let p = Posterior::uniform(4);
973        assert_eq!(p.len(), 4);
974        for &v in p.probs() {
975            assert!((v - 0.25).abs() < 1e-10);
976        }
977    }
978
979    #[test]
980    fn posterior_new_valid() {
981        let p = Posterior::new(vec![0.3, 0.7]).unwrap();
982        assert_eq!(p.probs(), &[0.3, 0.7]);
983    }
984
985    #[test]
986    fn posterior_new_not_normalized() {
987        let err = Posterior::new(vec![0.5, 0.3]).unwrap_err();
988        assert!(matches!(
989            err,
990            ValidationError::PosteriorNotNormalized { .. }
991        ));
992    }
993
994    #[test]
995    fn posterior_new_negative_probability_rejected() {
996        let err = Posterior::new(vec![-0.1, 1.1]).unwrap_err();
997        assert!(matches!(
998            err,
999            ValidationError::InvalidPosteriorProbability {
1000                index: 0,
1001                value
1002            } if value == -0.1
1003        ));
1004    }
1005
1006    #[test]
1007    fn posterior_new_non_finite_probability_rejected() {
1008        let err = Posterior::new(vec![f64::NAN, 1.0]).unwrap_err();
1009        assert!(matches!(
1010            err,
1011            ValidationError::InvalidPosteriorProbability {
1012                index: 0,
1013                value
1014            } if value.is_nan()
1015        ));
1016    }
1017
1018    #[test]
1019    fn posterior_bayesian_update() {
1020        let mut p = Posterior::uniform(2);
1021        // Likelihood: state 0 very likely given observation.
1022        p.bayesian_update(&[0.9, 0.1]);
1023        // After update: p(0) = 0.5*0.9 / (0.5*0.9 + 0.5*0.1) = 0.9
1024        assert!((p.probs()[0] - 0.9).abs() < 1e-10);
1025        assert!((p.probs()[1] - 0.1).abs() < 1e-10);
1026    }
1027
1028    #[test]
1029    fn posterior_bayesian_update_no_alloc() {
1030        // Verify the update works in-place by checking pointer stability.
1031        let mut p = Posterior::uniform(3);
1032        let ptr_before = p.probs().as_ptr();
1033        p.bayesian_update(&[0.5, 0.3, 0.2]);
1034        let ptr_after = p.probs().as_ptr();
1035        assert_eq!(ptr_before, ptr_after);
1036    }
1037
1038    #[test]
1039    fn posterior_entropy() {
1040        // Uniform over 2 states: entropy = 1.0 bit.
1041        let p = Posterior::uniform(2);
1042        assert!((p.entropy() - 1.0).abs() < 1e-10);
1043        // Deterministic: entropy = 0.
1044        let det = Posterior::new(vec![1.0, 0.0]).unwrap();
1045        assert!((det.entropy()).abs() < 1e-10);
1046    }
1047
1048    #[test]
1049    fn posterior_map_state() {
1050        let p = Posterior::new(vec![0.1, 0.7, 0.2]).unwrap();
1051        assert_eq!(p.map_state(), 1);
1052    }
1053
1054    #[test]
1055    fn posterior_is_empty() {
1056        let p = Posterior { probs: vec![] };
1057        assert!(p.is_empty());
1058        let p2 = Posterior::uniform(1);
1059        assert!(!p2.is_empty());
1060    }
1061
1062    #[test]
1063    fn posterior_probs_mut() {
1064        let mut p = Posterior::uniform(2);
1065        p.probs_mut()[0] = 0.8;
1066        p.probs_mut()[1] = 0.2;
1067        assert_eq!(p.probs(), &[0.8, 0.2]);
1068    }
1069
1070    // -- FallbackPolicy tests --
1071
1072    #[test]
1073    fn fallback_policy_default() {
1074        let fp = FallbackPolicy::default();
1075        assert_eq!(fp.calibration_drift_threshold, 0.7);
1076        assert_eq!(fp.e_process_breach_threshold, 20.0);
1077        assert_eq!(fp.confidence_width_threshold, 0.5);
1078    }
1079
1080    #[test]
1081    fn fallback_policy_new_valid() {
1082        let fp = FallbackPolicy::new(0.8, 10.0, 0.3).unwrap();
1083        assert_eq!(fp.calibration_drift_threshold, 0.8);
1084    }
1085
1086    #[test]
1087    fn fallback_policy_calibration_out_of_range() {
1088        let err = FallbackPolicy::new(1.5, 10.0, 0.3).unwrap_err();
1089        assert!(matches!(
1090            err,
1091            ValidationError::ThresholdOutOfRange {
1092                field: "calibration_drift_threshold",
1093                ..
1094            }
1095        ));
1096    }
1097
1098    #[test]
1099    fn fallback_policy_negative_e_process() {
1100        let err = FallbackPolicy::new(0.7, -1.0, 0.3).unwrap_err();
1101        assert!(matches!(
1102            err,
1103            ValidationError::ThresholdOutOfRange {
1104                field: "e_process_breach_threshold",
1105                ..
1106            }
1107        ));
1108    }
1109
1110    #[test]
1111    fn fallback_policy_negative_ci_width() {
1112        let err = FallbackPolicy::new(0.7, 10.0, -0.1).unwrap_err();
1113        assert!(matches!(
1114            err,
1115            ValidationError::ThresholdOutOfRange {
1116                field: "confidence_width_threshold",
1117                ..
1118            }
1119        ));
1120    }
1121
1122    #[test]
1123    fn fallback_policy_non_finite_e_process_rejected() {
1124        let err = FallbackPolicy::new(0.7, f64::NAN, 0.3).unwrap_err();
1125        assert!(matches!(
1126            err,
1127            ValidationError::ThresholdOutOfRange {
1128                field: "e_process_breach_threshold",
1129                value
1130            } if value.is_nan()
1131        ));
1132    }
1133
1134    #[test]
1135    fn fallback_policy_non_finite_ci_width_rejected() {
1136        let err = FallbackPolicy::new(0.7, 10.0, f64::INFINITY).unwrap_err();
1137        assert!(matches!(
1138            err,
1139            ValidationError::ThresholdOutOfRange {
1140                field: "confidence_width_threshold",
1141                value
1142            } if value.is_infinite()
1143        ));
1144    }
1145
1146    #[test]
1147    fn fallback_policy_json_invalid_threshold_rejected_at_deserialize() {
1148        let json = r#"{
1149            "calibration_drift_threshold": 0.7,
1150            "e_process_breach_threshold": -1.0,
1151            "confidence_width_threshold": 0.3
1152        }"#;
1153        let err = serde_json::from_str::<FallbackPolicy>(json).unwrap_err();
1154        assert!(err.to_string().contains("threshold"));
1155    }
1156
1157    #[test]
1158    fn fallback_triggered_by_low_calibration() {
1159        let fp = FallbackPolicy::default();
1160        assert!(fp.should_fallback(0.5, 1.0, 0.1)); // cal < 0.7
1161        assert!(!fp.should_fallback(0.9, 1.0, 0.1)); // cal OK
1162    }
1163
1164    #[test]
1165    fn fallback_triggered_by_e_process() {
1166        let fp = FallbackPolicy::default();
1167        assert!(fp.should_fallback(0.9, 25.0, 0.1)); // e_process > 20
1168        assert!(!fp.should_fallback(0.9, 15.0, 0.1)); // e_process OK
1169    }
1170
1171    #[test]
1172    fn fallback_triggered_by_ci_width() {
1173        let fp = FallbackPolicy::default();
1174        assert!(fp.should_fallback(0.9, 1.0, 0.6)); // ci > 0.5
1175        assert!(!fp.should_fallback(0.9, 1.0, 0.3)); // ci OK
1176    }
1177
1178    // -- DecisionContract + evaluate tests --
1179
1180    #[test]
1181    fn contract_implementable_under_50_lines() {
1182        // The TestContract impl above is 22 lines — well under 50.
1183        let contract = TestContract::new();
1184        assert_eq!(contract.name(), "test_contract");
1185        assert_eq!(contract.state_space().len(), 2);
1186        assert_eq!(contract.action_set().len(), 2);
1187    }
1188
1189    fn test_ctx(cal: f64, random: u128) -> EvalContext {
1190        EvalContext {
1191            calibration_score: cal,
1192            e_process: 1.0,
1193            ci_width: 0.1,
1194            decision_id: DecisionId::from_parts(1_700_000_000_000, random),
1195            trace_id: TraceId::from_parts(1_700_000_000_000, random),
1196            ts_unix_ms: 1_700_000_000_000,
1197        }
1198    }
1199
1200    #[test]
1201    fn evaluate_normal_decision() {
1202        let contract = TestContract::new();
1203        let posterior = Posterior::new(vec![0.9, 0.1]).unwrap();
1204        let ctx = test_ctx(0.95, 42);
1205
1206        let outcome = evaluate(&contract, &posterior, &ctx);
1207
1208        assert!(!outcome.fallback_active);
1209        assert_eq!(outcome.action_name, "continue"); // low loss when mostly good
1210        assert_eq!(outcome.action_index, 0);
1211        assert!(outcome.expected_loss < 0.1);
1212        assert_eq!(outcome.expected_losses.len(), 2);
1213    }
1214
1215    #[test]
1216    fn evaluate_fallback_decision() {
1217        let contract = TestContract::new();
1218        let posterior = Posterior::new(vec![0.2, 0.8]).unwrap();
1219        let ctx = test_ctx(0.5, 43); // low calibration triggers fallback
1220
1221        let outcome = evaluate(&contract, &posterior, &ctx);
1222
1223        assert!(outcome.fallback_active);
1224        assert_eq!(outcome.action_name, "continue"); // fallback action = 0
1225        assert_eq!(outcome.action_index, 0);
1226    }
1227
1228    #[test]
1229    fn evaluate_without_fallback_chooses_optimal() {
1230        let contract = TestContract::new();
1231        let posterior = Posterior::new(vec![0.2, 0.8]).unwrap();
1232        let ctx = test_ctx(0.95, 44); // good calibration, no fallback
1233
1234        let outcome = evaluate(&contract, &posterior, &ctx);
1235
1236        assert!(!outcome.fallback_active);
1237        assert_eq!(outcome.action_name, "stop"); // optimal when mostly bad
1238    }
1239
1240    #[test]
1241    fn evaluate_audit_entry_fields() {
1242        let contract = TestContract::new();
1243        let posterior = Posterior::uniform(2);
1244        let ctx = test_ctx(0.85, 99);
1245
1246        let outcome = evaluate(&contract, &posterior, &ctx);
1247
1248        let audit = &outcome.audit_entry;
1249        assert_eq!(audit.decision_id, ctx.decision_id);
1250        assert_eq!(audit.trace_id, ctx.trace_id);
1251        assert_eq!(audit.contract_name, "test_contract");
1252        assert_eq!(audit.calibration_score, 0.85);
1253        assert_eq!(audit.ts_unix_ms, 1_700_000_000_000);
1254        assert_eq!(audit.posterior_snapshot.len(), 2);
1255    }
1256
1257    // -- DecisionAuditEntry → EvidenceLedger --
1258
1259    #[test]
1260    fn audit_entry_to_evidence_ledger() {
1261        let contract = TestContract::new();
1262        let posterior = Posterior::new(vec![0.6, 0.4]).unwrap();
1263        let ctx = test_ctx(0.92, 100);
1264
1265        let outcome = evaluate(&contract, &posterior, &ctx);
1266        let evidence = outcome.audit_entry.to_evidence_ledger();
1267
1268        assert_eq!(evidence.ts_unix_ms, 1_700_000_000_000);
1269        assert_eq!(evidence.component, "test_contract");
1270        assert_eq!(evidence.action, outcome.action_name);
1271        assert_eq!(evidence.calibration_score, 0.92);
1272        assert!(!evidence.fallback_active);
1273        assert_eq!(evidence.posterior, vec![0.6, 0.4]);
1274        assert!(evidence.is_valid());
1275    }
1276
1277    #[test]
1278    fn audit_entry_serde_roundtrip() {
1279        let contract = TestContract::new();
1280        let posterior = Posterior::uniform(2);
1281        let ctx = test_ctx(0.88, 101);
1282
1283        let outcome = evaluate(&contract, &posterior, &ctx);
1284        let json = serde_json::to_string(&outcome.audit_entry).unwrap();
1285        let parsed: DecisionAuditEntry = serde_json::from_str(&json).unwrap();
1286        assert_eq!(parsed.contract_name, "test_contract");
1287        assert_eq!(parsed.decision_id, ctx.decision_id);
1288        assert_eq!(parsed.trace_id, ctx.trace_id);
1289    }
1290
1291    // -- Update posterior via contract --
1292
1293    #[test]
1294    fn contract_update_posterior() {
1295        let contract = TestContract::new();
1296        let mut posterior = Posterior::uniform(2);
1297        contract.update_posterior(&mut posterior, 0); // observe "good"
1298        // After update: state 0 should be more probable.
1299        assert!(posterior.probs()[0] > posterior.probs()[1]);
1300    }
1301
1302    // -- Validation error display --
1303
1304    #[test]
1305    fn validation_error_display() {
1306        let err = ValidationError::NegativeLoss {
1307            state: 1,
1308            action: 2,
1309            value: -0.5,
1310        };
1311        let msg = format!("{err}");
1312        assert!(msg.contains("-0.5"));
1313        assert!(msg.contains("state=1"));
1314        assert!(msg.contains("action=2"));
1315    }
1316
1317    #[test]
1318    fn dimension_mismatch_display() {
1319        let err = ValidationError::DimensionMismatch {
1320            expected: 6,
1321            got: 4,
1322        };
1323        let msg = format!("{err}");
1324        assert!(msg.contains('6'));
1325        assert!(msg.contains('4'));
1326    }
1327
1328    // -- FallbackPolicy serde --
1329
1330    #[test]
1331    fn fallback_policy_toml_roundtrip() {
1332        let fp = FallbackPolicy::default();
1333        let toml_str = toml::to_string(&fp).unwrap();
1334        let parsed: FallbackPolicy = toml::from_str(&toml_str).unwrap();
1335        assert_eq!(fp, parsed);
1336    }
1337
1338    #[test]
1339    fn fallback_policy_json_roundtrip() {
1340        let fp = FallbackPolicy::default();
1341        let json = serde_json::to_string(&fp).unwrap();
1342        let parsed: FallbackPolicy = serde_json::from_str(&json).unwrap();
1343        assert_eq!(fp, parsed);
1344    }
1345
1346    // -- argmin correctness with known posteriors --
1347
1348    #[test]
1349    fn argmin_correctness_deterministic_posterior() {
1350        let m = two_state_matrix();
1351        // Fully certain state=good: E[continue]=0.0, E[stop]=0.3 → continue wins.
1352        let certain_good = Posterior::new(vec![1.0, 0.0]).unwrap();
1353        assert_eq!(m.bayes_action(&certain_good), 0);
1354        // Fully certain state=bad: E[continue]=0.8, E[stop]=0.1 → stop wins.
1355        let certain_bad = Posterior::new(vec![0.0, 1.0]).unwrap();
1356        assert_eq!(m.bayes_action(&certain_bad), 1);
1357    }
1358
1359    #[test]
1360    fn argmin_correctness_breakeven_point() {
1361        let m = two_state_matrix();
1362        // Find crossover: at p(good)=x, E[continue]=0.8(1-x) and E[stop]=0.3x+0.1(1-x).
1363        // Crossover: 0.8-0.8x = 0.3x+0.1-0.1x → 0.8-0.8x = 0.2x+0.1 → 0.7=x → x=0.7
1364        // At p(good)=0.71, continue is better.
1365        let above = Posterior::new(vec![0.71, 0.29]).unwrap();
1366        assert_eq!(m.bayes_action(&above), 0);
1367        // At p(good)=0.69, stop is better.
1368        let below = Posterior::new(vec![0.69, 0.31]).unwrap();
1369        assert_eq!(m.bayes_action(&below), 1);
1370    }
1371
1372    #[test]
1373    fn argmin_three_state_three_action() {
1374        // 3 states, 3 actions: verify argmin in a bigger space.
1375        let m = LossMatrix::new(
1376            vec!["s0".into(), "s1".into(), "s2".into()],
1377            vec!["a0".into(), "a1".into(), "a2".into()],
1378            vec![
1379                1.0, 2.0, 3.0, // state 0
1380                3.0, 1.0, 2.0, // state 1
1381                2.0, 3.0, 1.0, // state 2
1382            ],
1383        )
1384        .unwrap();
1385        // Uniform posterior: E[a0]=2.0, E[a1]=2.0, E[a2]=2.0 → all tied.
1386        // Rust's min_by returns the last equal element, so index 2.
1387        let uniform = Posterior::uniform(3);
1388        let action = m.bayes_action(&uniform);
1389        // Any action is valid since all expected losses are equal.
1390        assert!(action < 3);
1391        // Posterior concentrated on state 1: a1 has loss 1.0 → a1 wins.
1392        let state1 = Posterior::new(vec![0.0, 1.0, 0.0]).unwrap();
1393        assert_eq!(m.bayes_action(&state1), 1);
1394        // Posterior concentrated on state 2: a2 has loss 1.0 → a2 wins.
1395        let state2 = Posterior::new(vec![0.0, 0.0, 1.0]).unwrap();
1396        assert_eq!(m.bayes_action(&state2), 2);
1397    }
1398
1399    // -- Bayesian update hand-computed --
1400
1401    #[test]
1402    fn bayesian_update_hand_computed_three_state() {
1403        // Prior: [0.5, 0.3, 0.2]
1404        // Likelihoods: [0.1, 0.6, 0.3]
1405        // Unnorm: [0.05, 0.18, 0.06]  sum=0.29
1406        // Posterior: [0.05/0.29, 0.18/0.29, 0.06/0.29]
1407        let mut p = Posterior::new(vec![0.5, 0.3, 0.2]).unwrap();
1408        p.bayesian_update(&[0.1, 0.6, 0.3]);
1409        let expected = [0.05 / 0.29, 0.18 / 0.29, 0.06 / 0.29];
1410        for (i, &e) in expected.iter().enumerate() {
1411            assert!(
1412                (p.probs()[i] - e).abs() < 1e-10,
1413                "state {i}: got {}, expected {e}",
1414                p.probs()[i]
1415            );
1416        }
1417    }
1418
1419    #[test]
1420    fn bayesian_update_successive_convergence() {
1421        // Repeated observations of state 0 should drive posterior toward certainty.
1422        let mut p = Posterior::uniform(3);
1423        for _ in 0..20 {
1424            p.bayesian_update(&[0.9, 0.05, 0.05]);
1425        }
1426        assert!(p.probs()[0] > 0.999);
1427        assert!(p.probs()[1] < 0.001);
1428        assert!(p.probs()[2] < 0.001);
1429    }
1430
1431    // -- End-to-end decision pipeline --
1432
1433    #[test]
1434    fn end_to_end_pipeline() {
1435        let contract = TestContract::new();
1436        let mut posterior = Posterior::uniform(2);
1437
1438        // Feed 5 "good" observations: posterior should shift toward state 0.
1439        for _ in 0..5 {
1440            contract.update_posterior(&mut posterior, 0);
1441        }
1442        assert!(posterior.probs()[0] > 0.99);
1443
1444        // Make a decision: should be "continue" (low loss when good).
1445        let ctx = test_ctx(0.95, 200);
1446        let outcome = evaluate(&contract, &posterior, &ctx);
1447        assert!(!outcome.fallback_active);
1448        assert_eq!(outcome.action_name, "continue");
1449        assert!(outcome.expected_loss < 0.01);
1450
1451        // Verify evidence ledger entry.
1452        let evidence = outcome.audit_entry.to_evidence_ledger();
1453        assert_eq!(evidence.component, "test_contract");
1454        assert_eq!(evidence.action, "continue");
1455        assert!(evidence.is_valid());
1456
1457        // Now feed "bad" observations to shift posterior.
1458        for _ in 0..20 {
1459            contract.update_posterior(&mut posterior, 1);
1460        }
1461        assert!(posterior.probs()[1] > 0.99);
1462
1463        // Decision should now be "stop".
1464        let ctx2 = test_ctx(0.95, 201);
1465        let outcome2 = evaluate(&contract, &posterior, &ctx2);
1466        assert_eq!(outcome2.action_name, "stop");
1467    }
1468
1469    // -- Concurrent decision safety --
1470
1471    #[test]
1472    fn concurrent_decision_safety() {
1473        use std::sync::Arc;
1474        use std::thread;
1475
1476        let contract = Arc::new(TestContract::new());
1477        let results: Vec<_> = (0..10)
1478            .map(|i| {
1479                let c = Arc::clone(&contract);
1480                thread::spawn(move || {
1481                    let posterior = Posterior::uniform(2);
1482                    let ctx = EvalContext {
1483                        calibration_score: 0.9,
1484                        e_process: 1.0,
1485                        ci_width: 0.1,
1486                        decision_id: DecisionId::from_parts(1_700_000_000_000, u128::from(i)),
1487                        trace_id: TraceId::from_parts(1_700_000_000_000, u128::from(i)),
1488                        ts_unix_ms: 1_700_000_000_000 + i,
1489                    };
1490                    let outcome = evaluate(c.as_ref(), &posterior, &ctx);
1491                    assert!(!outcome.action_name.is_empty());
1492                    assert_eq!(outcome.expected_losses.len(), 2);
1493                    let evidence = outcome.audit_entry.to_evidence_ledger();
1494                    assert!(evidence.is_valid());
1495                    outcome
1496                })
1497            })
1498            .map(|h| h.join().unwrap())
1499            .collect();
1500        assert_eq!(results.len(), 10);
1501        // All should agree on the same action for uniform posterior.
1502        let actions: std::collections::HashSet<_> =
1503            results.iter().map(|r| r.action_name.clone()).collect();
1504        assert_eq!(
1505            actions.len(),
1506            1,
1507            "all threads should choose the same action"
1508        );
1509    }
1510
1511    // -- Cross-crate type verification --
1512
1513    #[test]
1514    fn cross_crate_franken_kernel_types() {
1515        // Verify DecisionId and TraceId are the franken_kernel versions.
1516        let did = DecisionId::from_parts(1_700_000_000_000, 42);
1517        assert_eq!(did.timestamp_ms(), 1_700_000_000_000);
1518        let tid = TraceId::from_parts(1_700_000_000_000, 1);
1519        assert_eq!(tid.timestamp_ms(), 1_700_000_000_000);
1520
1521        // Verify they work correctly in DecisionAuditEntry.
1522        let contract = TestContract::new();
1523        let posterior = Posterior::uniform(2);
1524        let ctx = EvalContext {
1525            calibration_score: 0.9,
1526            e_process: 1.0,
1527            ci_width: 0.1,
1528            decision_id: did,
1529            trace_id: tid,
1530            ts_unix_ms: 1_700_000_000_000,
1531        };
1532        let outcome = evaluate(&contract, &posterior, &ctx);
1533        assert_eq!(outcome.audit_entry.decision_id, did);
1534        assert_eq!(outcome.audit_entry.trace_id, tid);
1535    }
1536
1537    // -- Posterior serde roundtrips --
1538
1539    #[test]
1540    fn posterior_json_roundtrip() {
1541        let p = Posterior::new(vec![0.25, 0.75]).unwrap();
1542        let json = serde_json::to_string(&p).unwrap();
1543        let parsed: Posterior = serde_json::from_str(&json).unwrap();
1544        assert_eq!(p, parsed);
1545    }
1546
1547    #[test]
1548    fn posterior_json_invalid_value_rejected_at_deserialize() {
1549        let json = r#"{"probs":[-0.1,1.1]}"#;
1550        let err = serde_json::from_str::<Posterior>(json).unwrap_err();
1551        assert!(err.to_string().contains("finite and non-negative"));
1552    }
1553
1554    // -- LossMatrix 3x3 TOML --
1555
1556    #[test]
1557    fn loss_matrix_3x3_toml_roundtrip() {
1558        let m = LossMatrix::new(
1559            vec!["s0".into(), "s1".into(), "s2".into()],
1560            vec!["a0".into(), "a1".into(), "a2".into()],
1561            vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
1562        )
1563        .unwrap();
1564        let toml_str = toml::to_string(&m).unwrap();
1565        let parsed: LossMatrix = toml::from_str(&toml_str).unwrap();
1566        assert_eq!(m, parsed);
1567    }
1568
1569    // -- DecisionOutcome debug --
1570
1571    #[test]
1572    fn decision_outcome_debug() {
1573        let contract = TestContract::new();
1574        let posterior = Posterior::uniform(2);
1575        let ctx = test_ctx(0.9, 300);
1576        let outcome = evaluate(&contract, &posterior, &ctx);
1577        let dbg = format!("{outcome:?}");
1578        assert!(dbg.contains("DecisionOutcome"));
1579        assert!(dbg.contains("action_name"));
1580    }
1581
1582    // -- Fallback all three triggers --
1583
1584    #[test]
1585    fn fallback_multiple_triggers_simultaneously() {
1586        let fp = FallbackPolicy::default();
1587        // All three conditions breached simultaneously.
1588        assert!(fp.should_fallback(0.3, 30.0, 0.9));
1589    }
1590
1591    #[test]
1592    fn fallback_no_trigger_at_exact_thresholds() {
1593        let fp = FallbackPolicy::default();
1594        // Exactly at thresholds: cal=0.7 (not < 0.7), e=20 (not > 20), ci=0.5 (not > 0.5).
1595        assert!(!fp.should_fallback(0.7, 20.0, 0.5));
1596    }
1597
1598    // -- Entropy edge cases --
1599
1600    #[test]
1601    fn posterior_entropy_three_state_uniform() {
1602        let p = Posterior::uniform(3);
1603        // entropy = log2(3) ≈ 1.585
1604        assert!((p.entropy() - 3.0_f64.log2()).abs() < 1e-10);
1605    }
1606
1607    #[test]
1608    fn posterior_entropy_single_state() {
1609        let p = Posterior::new(vec![1.0]).unwrap();
1610        assert!((p.entropy()).abs() < 1e-10);
1611    }
1612
1613    // -- ValidationError is std::error::Error --
1614
1615    #[test]
1616    fn validation_error_is_std_error() {
1617        fn assert_error<E: std::error::Error>() {}
1618        assert_error::<ValidationError>();
1619    }
1620}
1621
1622// ---------------------------------------------------------------------------
1623// Property-based tests (proptest)
1624// ---------------------------------------------------------------------------
1625
#[cfg(test)]
mod proptest_tests {
    // Property-based checks over randomly generated loss matrices and posteriors.
    use super::*;
    use proptest::prelude::*;

    /// Generate a valid probability vector of length `n`.
    ///
    /// Raw weights are drawn from `[0.01, 1.0]` and divided by their sum, so
    /// every state keeps strictly positive mass.
    fn arb_posterior(n: usize) -> impl Strategy<Value = Posterior> {
        proptest::collection::vec(0.01_f64..=1.0, n).prop_map(|mut v| {
            let sum: f64 = v.iter().sum();
            for p in &mut v {
                *p /= sum;
            }
            Posterior::new(v).unwrap()
        })
    }

    /// Generate a valid loss matrix of given dimensions.
    ///
    /// States are labelled `s0, s1, ...`, actions `a0, a1, ...`, and every loss
    /// is drawn from `[0.0, 10.0]`.
    fn arb_loss_matrix(n_states: usize, n_actions: usize) -> impl Strategy<Value = LossMatrix> {
        let states: Vec<String> = (0..n_states).map(|i| format!("s{i}")).collect();
        let actions: Vec<String> = (0..n_actions).map(|i| format!("a{i}")).collect();
        proptest::collection::vec(0.0_f64..=10.0, n_states * n_actions).prop_map(move |values| {
            LossMatrix::new(states.clone(), actions.clone(), values).unwrap()
        })
    }

    // -- Argmin: chosen action minimizes expected loss for any valid posterior --

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(10_000))]

        #[test]
        fn bayes_action_minimizes_expected_loss(
            matrix in arb_loss_matrix(3, 3),
            posterior in arb_posterior(3),
        ) {
            let chosen = matrix.bayes_action(&posterior);
            let chosen_loss = matrix.expected_loss(&posterior, chosen);
            for a in 0..matrix.n_actions() {
                let other_loss = matrix.expected_loss(&posterior, a);
                prop_assert!(
                    chosen_loss <= other_loss + 1e-10,
                    "action {chosen} (loss {chosen_loss}) should be <= action {a} (loss {other_loss})"
                );
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(10_000))]

        #[test]
        fn bayes_action_minimizes_2x2(
            matrix in arb_loss_matrix(2, 2),
            posterior in arb_posterior(2),
        ) {
            let chosen = matrix.bayes_action(&posterior);
            let chosen_loss = matrix.expected_loss(&posterior, chosen);
            for a in 0..matrix.n_actions() {
                prop_assert!(chosen_loss <= matrix.expected_loss(&posterior, a) + 1e-10);
            }
        }
    }

    // -- Posterior update preserves normalization --

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(10_000))]

        #[test]
        fn bayesian_update_preserves_normalization(
            prior in arb_posterior(4),
            likelihoods in proptest::collection::vec(0.01_f64..=1.0, 4usize),
        ) {
            let mut p = prior;
            p.bayesian_update(&likelihoods);
            let sum: f64 = p.probs().iter().sum();
            prop_assert!(
                (sum - 1.0).abs() < 1e-10,
                "posterior sum = {sum}, expected 1.0"
            );
            for &prob in p.probs() {
                prop_assert!(prob >= 0.0, "negative probability: {prob}");
            }
        }
    }

    // -- Posterior: all elements non-negative after update --

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(10_000))]

        #[test]
        fn posterior_all_non_negative_after_update(
            prior in arb_posterior(3),
            likelihoods in proptest::collection::vec(0.0_f64..=1.0, 3usize),
        ) {
            // An all-zero likelihood vector leaves nothing to renormalize. Discard
            // such inputs explicitly (proptest regenerates them) instead of letting
            // the case pass without checking anything.
            let lik_sum: f64 = likelihoods.iter().sum();
            prop_assume!(lik_sum > 0.0);

            let mut p = prior;
            p.bayesian_update(&likelihoods);
            for &prob in p.probs() {
                prop_assert!(prob >= 0.0, "negative probability: {prob}");
            }
        }
    }

    // -- FallbackPolicy serde roundtrip --

    proptest! {
        #[test]
        fn fallback_policy_serde_roundtrip(
            cal in 0.0_f64..=1.0,
            e_proc in 0.0_f64..=100.0,
            ci in 0.0_f64..=10.0,
        ) {
            let fp = FallbackPolicy::new(cal, e_proc, ci).unwrap();
            let json = serde_json::to_string(&fp).unwrap();
            let parsed: FallbackPolicy = serde_json::from_str(&json).unwrap();
            // Use approximate comparison due to f64 JSON round-trip precision.
            prop_assert!((fp.calibration_drift_threshold - parsed.calibration_drift_threshold).abs() < 1e-12);
            prop_assert!((fp.e_process_breach_threshold - parsed.e_process_breach_threshold).abs() < 1e-12);
            prop_assert!((fp.confidence_width_threshold - parsed.confidence_width_threshold).abs() < 1e-12);
        }
    }

    // -- LossMatrix serde roundtrip --

    proptest! {
        #[test]
        fn loss_matrix_serde_roundtrip(
            matrix in arb_loss_matrix(2, 3),
        ) {
            let json = serde_json::to_string(&matrix).unwrap();
            let parsed: LossMatrix = serde_json::from_str(&json).unwrap();
            prop_assert_eq!(matrix.state_names(), parsed.state_names());
            prop_assert_eq!(matrix.action_names(), parsed.action_names());
            // Use approximate comparison for f64 values.
            for s in 0..matrix.n_states() {
                for a in 0..matrix.n_actions() {
                    prop_assert!((matrix.get(s, a) - parsed.get(s, a)).abs() < 1e-12);
                }
            }
        }
    }

    // -- Expected loss is a convex combination --

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(10_000))]

        #[test]
        fn expected_loss_within_loss_range(
            matrix in arb_loss_matrix(3, 3),
            posterior in arb_posterior(3),
        ) {
            for a in 0..matrix.n_actions() {
                let el = matrix.expected_loss(&posterior, a);
                let min_loss = (0..matrix.n_states())
                    .map(|s| matrix.get(s, a))
                    .fold(f64::INFINITY, f64::min);
                let max_loss = (0..matrix.n_states())
                    .map(|s| matrix.get(s, a))
                    .fold(f64::NEG_INFINITY, f64::max);
                prop_assert!(
                    el >= min_loss - 1e-10 && el <= max_loss + 1e-10,
                    "expected loss {el} outside [{min_loss}, {max_loss}]"
                );
            }
        }
    }
}