use std::fmt;
use crate::unified_evidence::{DecisionDomain, EvidenceEntry, EvidenceEntryBuilder, EvidenceTerm};
/// Latent state that a [`DecisionCore`] reasons about.
///
/// Marker trait: any owned (`'static`), cloneable, debuggable type may implement it.
pub trait State: Clone + fmt::Debug + 'static {}
/// A choice that a [`DecisionCore`] can make.
pub trait Action: Clone + fmt::Debug + 'static {
    /// Static name of this action; used as the action label when a decision is
    /// converted into an evidence entry.
    fn label(&self) -> &'static str;
}
/// Summary of a controller's belief about the state `S` after seeing evidence.
#[derive(Clone, Debug)]
pub struct Posterior<S: State> {
    /// Single best guess of the state.
    pub point_estimate: S,
    /// Log-scale posterior score (the test controller in this file stores log-odds here).
    pub log_posterior: f64,
    /// Interval around the estimate as a `(low, high)` pair.
    pub confidence_interval: (f64, f64),
    /// Evidence terms that contributed to this posterior.
    pub evidence: Vec<EvidenceTerm>,
}
/// Outcome of a single decision: the chosen action plus the numbers that justify it.
#[derive(Clone, Debug)]
pub struct Decision<A: Action> {
    /// The action that was selected.
    pub action: A,
    /// Loss associated with `action`.
    pub expected_loss: f64,
    /// Loss of the best alternative to `action` (see [`Decision::loss_avoided`]).
    pub next_best_loss: f64,
    /// Log-scale posterior score carried over from the posterior used to decide.
    pub log_posterior: f64,
    /// Interval carried over from the posterior, as a `(low, high)` pair.
    pub confidence_interval: (f64, f64),
    /// Evidence terms carried over from the posterior.
    pub evidence: Vec<EvidenceTerm>,
}
impl<A: Action> Decision<A> {
    /// Margin by which the chosen action beat the runner-up, floored at zero.
    ///
    /// Computed as `next_best_loss - expected_loss`, clamped to be non-negative.
    /// `f64::max` returns the non-NaN operand, so a `NaN` margin yields `0.0`.
    #[must_use]
    pub fn loss_avoided(&self) -> f64 {
        let margin = self.next_best_loss - self.expected_loss;
        margin.max(0.0)
    }

    /// Converts this decision into an [`EvidenceEntry`] tagged with `domain` and `timestamp_ns`.
    ///
    /// The entry carries the log posterior, the action's label, [`Decision::loss_avoided`],
    /// the confidence interval, and every evidence term (label and Bayes factor) in order.
    #[must_use]
    pub fn to_evidence_entry(&self, domain: DecisionDomain, timestamp_ns: u64) -> EvidenceEntry {
        let (ci_low, ci_high) = self.confidence_interval;
        // NOTE(review): the second argument to `EvidenceEntryBuilder::new` is always `0`
        // here; its meaning is defined in `unified_evidence` — confirm this is intended.
        let header = EvidenceEntryBuilder::new(domain, 0, timestamp_ns)
            .log_posterior(self.log_posterior)
            .action(self.action.label())
            .loss_avoided(self.loss_avoided())
            .confidence_interval(ci_low, ci_high);
        self.evidence
            .iter()
            .fold(header, |entry, term| entry.evidence(term.label, term.bayes_factor))
            .build()
    }
}
/// Observation type accepted by [`DecisionCore::calibrate`].
///
/// Implemented in this module for the primitives `bool`, `f64`, `u64` and `u32`.
pub trait Outcome: fmt::Debug + 'static {}

/// Implements the marker trait [`Outcome`] for each listed type.
macro_rules! impl_outcome {
    ($($ty:ty),* $(,)?) => {
        $(impl Outcome for $ty {})*
    };
}

impl_outcome!(bool, f64, u64, u32);
/// A decision controller over states `S` and actions `A`.
///
/// Implementors supply the model pieces (posterior, loss, action set, calibration);
/// [`DecisionCore::decide_and_record`] is provided on top of [`DecisionCore::decide`]
/// to log each decision into the unified evidence ledger.
pub trait DecisionCore<S: State, A: Action> {
    /// Observation type fed back through [`DecisionCore::calibrate`].
    type Outcome: Outcome;

    /// Domain this controller reports under; used to tag recorded ledger entries.
    fn domain(&self) -> DecisionDomain;

    /// Posterior over the state given `evidence`.
    fn posterior(&self, evidence: &[EvidenceTerm]) -> Posterior<S>;

    /// Loss of taking `action` when the state is `state`.
    fn loss(&self, action: &A, state: &S) -> f64;

    /// Chooses an action for `evidence` (may mutate the controller).
    fn decide(&mut self, evidence: &[EvidenceTerm]) -> Decision<A>;

    /// Feeds an observed `outcome` back into the controller.
    fn calibrate(&mut self, outcome: &Self::Outcome);

    /// Designated fallback action.
    ///
    /// NOTE(review): not used by the provided methods of this trait; presumably callers use
    /// it when a normal decision cannot be made — confirm.
    fn fallback_action(&self) -> A;

    /// Candidate actions the controller chooses among.
    fn actions(&self) -> Vec<A>;

    /// Runs [`DecisionCore::decide`], records the resulting evidence entry in `ledger`
    /// (tagged with [`DecisionCore::domain`] and `timestamp_ns`), and returns the decision.
    fn decide_and_record(
        &mut self,
        evidence: &[EvidenceTerm],
        ledger: &mut crate::unified_evidence::UnifiedEvidenceLedger,
        timestamp_ns: u64,
    ) -> Decision<A> {
        let chosen = self.decide(evidence);
        ledger.record(chosen.to_evidence_entry(self.domain(), timestamp_ns));
        chosen
    }
}
/// Finds the action with the smallest loss under `state_estimate`.
///
/// Returns `Some((index, loss))` for the minimizing action, or `None` if `actions` is empty.
///
/// * `loss_fn` is evaluated exactly once per action, in order.
/// * Ties are broken in favour of the earliest action.
/// * A `NaN` loss is treated as worse than any comparable loss, so a `NaN` action is only
///   chosen when every loss is `NaN`.
/// * The returned loss is always the actual loss of the returned action. It is never a
///   synthetic sentinel. (Previously, an all-`NaN` input reported `+inf`.)
pub fn argmin_expected_loss<S, A, F>(
    actions: &[A],
    state_estimate: &S,
    loss_fn: F,
) -> Option<(usize, f64)>
where
    S: State,
    A: Action,
    F: Fn(&A, &S) -> f64,
{
    let mut best: Option<(usize, f64)> = None;
    for (idx, action) in actions.iter().enumerate() {
        let loss = loss_fn(action, state_estimate);
        let improves = match best {
            // Seed from the first real loss rather than an artificial +inf.
            None => true,
            // `NaN` never compares less than anything, so a NaN incumbent must be
            // displaced explicitly by the first comparable loss.
            Some((_, incumbent)) => loss < incumbent || (incumbent.is_nan() && !loss.is_nan()),
        };
        if improves {
            best = Some((idx, loss));
        }
    }
    best
}
/// Smallest loss among all actions except the one at `best_idx`.
///
/// Each remaining action's loss is evaluated once, in order, and folded with a strict `<`.
/// As a result, `NaN` losses never win and the earliest of equal losses is kept.
/// Returns `f64::INFINITY` when there is no other action, or when every other loss is `NaN`.
/// If `best_idx` is out of range, no action is excluded.
pub fn second_best_loss<S, A, F>(
    actions: &[A],
    state_estimate: &S,
    best_idx: usize,
    loss_fn: F,
) -> f64
where
    S: State,
    A: Action,
    F: Fn(&A, &S) -> f64,
{
    actions
        .iter()
        .enumerate()
        .filter(|&(idx, _)| idx != best_idx)
        .map(|(_, candidate)| loss_fn(candidate, state_estimate))
        .fold(f64::INFINITY, |lowest, loss| if loss < lowest { loss } else { lowest })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test state: a single rate.
    #[derive(Debug, Clone)]
    struct TestRate(f64);

    impl State for TestRate {}

    #[derive(Debug, Clone, PartialEq)]
    enum TestAction {
        Low,
        High,
    }

    impl Action for TestAction {
        fn label(&self) -> &'static str {
            match self {
                Self::Low => "low",
                Self::High => "high",
            }
        }
    }

    impl Outcome for TestRate {}

    /// Minimal controller: one rate, moved toward each observed outcome by `calibrate`.
    struct TestController {
        rate: f64,
        calibration_count: u32,
    }

    impl TestController {
        fn new(initial_rate: f64) -> Self {
            Self {
                rate: initial_rate,
                calibration_count: 0,
            }
        }
    }

    impl DecisionCore<TestRate, TestAction> for TestController {
        type Outcome = f64;

        fn domain(&self) -> DecisionDomain {
            DecisionDomain::DiffStrategy
        }

        fn posterior(&self, _evidence: &[EvidenceTerm]) -> Posterior<TestRate> {
            // Log-odds of the rate. Clamp once and use the clamped value in both the
            // numerator and the denominator. Clamping only the denominator let
            // rate = 0.0 produce ln(0) = -inf.
            let p = self.rate.clamp(0.001, 0.999);
            let log_post = (p / (1.0 - p)).ln();
            Posterior {
                point_estimate: TestRate(self.rate),
                log_posterior: log_post,
                confidence_interval: (self.rate - 0.1, self.rate + 0.1),
                evidence: Vec::new(),
            }
        }

        fn loss(&self, action: &TestAction, state: &TestRate) -> f64 {
            // Break-even where 10r == 5(1 - r), i.e. r = 1/3.
            match action {
                TestAction::Low => state.0 * 10.0,
                TestAction::High => (1.0 - state.0) * 5.0,
            }
        }

        fn decide(&mut self, evidence: &[EvidenceTerm]) -> Decision<TestAction> {
            let posterior = self.posterior(evidence);
            let actions = self.actions();
            let state = &posterior.point_estimate;
            let (best_idx, best_loss) =
                argmin_expected_loss(&actions, state, |a, s| self.loss(a, s))
                    .expect("TestController::actions is never empty");
            let next_best = second_best_loss(&actions, state, best_idx, |a, s| self.loss(a, s));
            Decision {
                action: actions[best_idx].clone(),
                expected_loss: best_loss,
                next_best_loss: next_best,
                log_posterior: posterior.log_posterior,
                confidence_interval: posterior.confidence_interval,
                evidence: posterior.evidence,
            }
        }

        fn calibrate(&mut self, outcome: &f64) {
            self.rate = self.rate * 0.9 + outcome * 0.1;
            self.calibration_count += 1;
        }

        fn fallback_action(&self) -> TestAction {
            TestAction::High
        }

        fn actions(&self) -> Vec<TestAction> {
            vec![TestAction::Low, TestAction::High]
        }
    }

    #[test]
    fn decide_chooses_low_for_low_rate() {
        let mut ctrl = TestController::new(0.1);
        let decision = ctrl.decide(&[]);
        assert_eq!(decision.action, TestAction::Low);
        assert!(decision.expected_loss < decision.next_best_loss);
    }

    #[test]
    fn decide_chooses_high_for_high_rate() {
        let mut ctrl = TestController::new(0.8);
        let decision = ctrl.decide(&[]);
        assert_eq!(decision.action, TestAction::High);
    }

    #[test]
    fn loss_avoided_nonnegative() {
        let mut ctrl = TestController::new(0.3);
        let decision = ctrl.decide(&[]);
        assert!(decision.loss_avoided() >= 0.0);
    }

    #[test]
    fn calibrate_updates_rate() {
        let mut ctrl = TestController::new(0.5);
        ctrl.calibrate(&1.0);
        assert!((ctrl.rate - 0.55).abs() < 1e-10);
        assert_eq!(ctrl.calibration_count, 1);
    }

    #[test]
    fn fallback_is_conservative() {
        let ctrl = TestController::new(0.5);
        assert_eq!(ctrl.fallback_action(), TestAction::High);
    }

    #[test]
    fn posterior_reflects_rate() {
        let ctrl = TestController::new(0.7);
        let post = ctrl.posterior(&[]);
        assert!((post.point_estimate.0 - 0.7).abs() < 1e-10);
        assert!(post.log_posterior > 0.0);
    }

    #[test]
    fn posterior_negative_log_odds_for_low_rate() {
        let ctrl = TestController::new(0.2);
        let post = ctrl.posterior(&[]);
        assert!(post.log_posterior < 0.0);
    }

    #[test]
    fn posterior_log_odds_finite_at_extremes() {
        assert!(TestController::new(0.0).posterior(&[]).log_posterior.is_finite());
        assert!(TestController::new(1.0).posterior(&[]).log_posterior.is_finite());
    }

    #[test]
    fn evidence_entry_conversion() {
        let mut ctrl = TestController::new(0.3);
        let decision = ctrl.decide(&[]);
        let entry = decision.to_evidence_entry(DecisionDomain::DiffStrategy, 42_000);
        assert_eq!(entry.domain, DecisionDomain::DiffStrategy);
        assert_eq!(entry.timestamp_ns, 42_000);
        assert_eq!(entry.action, "low");
        assert!(entry.loss_avoided >= 0.0);
    }

    #[test]
    fn decide_and_record_adds_to_ledger() {
        let mut ctrl = TestController::new(0.3);
        let mut ledger = crate::unified_evidence::UnifiedEvidenceLedger::new(100);
        assert_eq!(ledger.len(), 0);
        let _decision = ctrl.decide_and_record(&[], &mut ledger, 1000);
        assert_eq!(ledger.len(), 1);
    }

    #[test]
    fn argmin_empty_returns_none() {
        let actions: Vec<TestAction> = vec![];
        let state = TestRate(0.5);
        let result = argmin_expected_loss(&actions, &state, |_, _| 0.0);
        assert!(result.is_none());
    }

    #[test]
    fn argmin_single_action() {
        let actions = vec![TestAction::Low];
        let state = TestRate(0.5);
        let result = argmin_expected_loss(&actions, &state, |a, s| match a {
            TestAction::Low => s.0 * 10.0,
            TestAction::High => (1.0 - s.0) * 5.0,
        });
        assert_eq!(result, Some((0, 5.0)));
    }

    #[test]
    fn second_best_with_two_actions() {
        let actions = vec![TestAction::Low, TestAction::High];
        let state = TestRate(0.3);
        let sb = second_best_loss(&actions, &state, 0, |a, s| match a {
            TestAction::Low => s.0 * 10.0,
            TestAction::High => (1.0 - s.0) * 5.0,
        });
        assert!((sb - 3.5).abs() < 1e-10);
    }

    #[test]
    fn decision_to_jsonl_roundtrip() {
        let mut ctrl = TestController::new(0.3);
        let decision = ctrl.decide(&[]);
        let entry = decision.to_evidence_entry(DecisionDomain::DiffStrategy, 42_000);
        let jsonl = entry.to_jsonl();
        assert!(jsonl.contains("\"schema\":\"ftui-evidence-v2\""));
        assert!(jsonl.contains("\"domain\":\"diff_strategy\""));
        assert!(jsonl.contains("\"action\":\"low\""));
    }

    #[test]
    fn calibrate_multiple_rounds() {
        let mut ctrl = TestController::new(0.5);
        for _ in 0..10 {
            ctrl.calibrate(&1.0);
        }
        assert!(ctrl.rate > 0.8);
        assert_eq!(ctrl.calibration_count, 10);
    }

    #[test]
    fn decision_crossover_point() {
        let mut ctrl = TestController::new(1.0 / 3.0);
        let decision = ctrl.decide(&[]);
        assert!(decision.loss_avoided() < 0.01);
    }

    #[test]
    fn domain_reports_correctly() {
        let ctrl = TestController::new(0.5);
        assert_eq!(ctrl.domain(), DecisionDomain::DiffStrategy);
    }

    #[test]
    fn deterministic_decide() {
        let mut ctrl_a = TestController::new(0.4);
        let mut ctrl_b = TestController::new(0.4);
        let d_a = ctrl_a.decide(&[]);
        let d_b = ctrl_b.decide(&[]);
        assert_eq!(d_a.action, d_b.action);
        assert!((d_a.expected_loss - d_b.expected_loss).abs() < 1e-10);
    }
}