use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::label::Label;
/// The direction in which an action moves risk.
///
/// NOTE(review): the variant meanings below are inferred from their names;
/// confirm against the call sites. Serialized in `snake_case`
/// (`"increases"`, `"reduces"`, `"neutral"`).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum RiskDirection {
    /// The action adds risk.
    Increases,
    /// The action removes risk.
    Reduces,
    /// The action leaves risk unchanged.
    Neutral,
}
/// Graduated friction applied before an action goes through, from `L0`
/// (no friction) up to `L4` (refusal).
///
/// Declaration order is significant: the derived `PartialOrd`/`Ord` rank the
/// levels in the order written here, so a later level compares greater.
/// Serialized in `snake_case` (`"l0"` … `"l4"`).
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum FrictionLevel {
    /// No pause, no typed confirmation.
    L0,
    /// Short (3 s) pause.
    L1,
    /// 10 s pause plus typed confirmation.
    L2,
    /// 30 s pause plus typed confirmation.
    L3,
    /// Refusal (also reports a 15-minute pause and a typed confirmation).
    L4,
}
/// Risk signals consulted when deciding whether to escalate friction.
///
/// `#[serde(default)]` lets either field be omitted on deserialization; a
/// missing field takes its `Default` value (`None` / `false`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Deserialize, Serialize)]
#[serde(default)]
pub struct RiskContext {
    /// Absolute distance between the current drawdown and the last drawdown
    /// alert level, or `None` when that distance is unknown.
    pub guardrail_proximity_pct: Option<f64>,
    /// Halt flag as reported by the engine.
    pub halted: bool,
}
impl RiskContext {
    /// Inclusive threshold for [`Self::near_guardrail`]: a proximity at or
    /// below this value counts as "near". Expressed in the same units as the
    /// engine's `*_pct` readings.
    pub const PROXIMITY_PCT: f64 = 1.0;

    /// Builds a [`RiskContext`] from raw engine readings.
    ///
    /// `guardrail_proximity_pct` is the absolute distance between
    /// `drawdown_pct` and `last_drawdown_alert_pct`, so the result is the same
    /// whichever side of the alert level the drawdown sits on. It is `None`
    /// when either reading is missing, or when the distance is not finite
    /// (a NaN or infinite reading). `halted` is passed through unchanged.
    #[must_use]
    pub fn from_engine(
        drawdown_pct: Option<f64>,
        last_drawdown_alert_pct: Option<f64>,
        halted: bool,
    ) -> Self {
        let proximity = drawdown_pct
            .zip(last_drawdown_alert_pct)
            .map(|(dd, alert)| (alert - dd).abs())
            // A NaN/infinite distance can never satisfy `near_guardrail`, so
            // dropping it does not change escalation. It does keep the field
            // comparable (NaN != NaN would break the derived `PartialEq`) and
            // representable by serializers that reject non-finite floats.
            .filter(|distance| distance.is_finite());
        Self {
            guardrail_proximity_pct: proximity,
            halted,
        }
    }

    /// `true` when a proximity is known and is at most [`Self::PROXIMITY_PCT`]
    /// (the threshold itself counts as near).
    #[must_use]
    pub fn near_guardrail(&self) -> bool {
        self.guardrail_proximity_pct
            .is_some_and(|pp| pp <= Self::PROXIMITY_PCT)
    }
}
impl FrictionLevel {
    /// Baseline friction implied by the label alone.
    ///
    /// Never goes above `L2`: the `L3`/`L4` escalations need risk context and
    /// are only produced by [`Self::from_label_and_risk`].
    #[must_use]
    pub const fn from_label(label: Label) -> Self {
        match label {
            Label::Tilt => Self::L2,
            Label::Elevated | Label::Fatigued => Self::L1,
            Label::Fresh | Label::Steady | Label::Recovery => Self::L0,
        }
    }

    /// Friction for `label`, escalated by `risk` only when the label is `Tilt`.
    ///
    /// * `Tilt` and halted → `L4` (checked first, so it beats proximity).
    /// * `Tilt` and near the guardrail → `L3`.
    /// * Otherwise → [`Self::from_label`]. Non-`Tilt` labels are never
    ///   escalated, whatever the risk context says.
    #[must_use]
    pub fn from_label_and_risk(label: Label, risk: RiskContext) -> Self {
        match label {
            Label::Tilt if risk.halted => Self::L4,
            Label::Tilt if risk.near_guardrail() => Self::L3,
            unescalated => Self::from_label(unescalated),
        }
    }

    /// Mandatory wait associated with this level.
    #[must_use]
    pub const fn pause(self) -> Duration {
        let seconds = match self {
            Self::L0 => 0,
            Self::L1 => 3,
            Self::L2 => 10,
            Self::L3 => 30,
            Self::L4 => 15 * 60,
        };
        Duration::from_secs(seconds)
    }

    /// Whether a typed confirmation is required (every level from `L2` up).
    #[must_use]
    pub const fn requires_typed_confirm(self) -> bool {
        match self {
            Self::L0 | Self::L1 => false,
            Self::L2 | Self::L3 | Self::L4 => true,
        }
    }

    /// Whether this level refuses the action outright (`L4` only).
    #[must_use]
    pub const fn is_refusal(self) -> bool {
        match self {
            Self::L4 => true,
            Self::L0 | Self::L1 | Self::L2 | Self::L3 => false,
        }
    }
}
mod sealed {
    /// Private supertrait of `GateableDirection`. Because this module is not
    /// exported, `Sealed` cannot be named (and so cannot be implemented)
    /// outside the enclosing module, which closes the set of gateable
    /// directions.
    pub trait Sealed {}
}

/// Marker trait for the directions a [`FrictionGate`] can be typed by.
///
/// Sealed via `sealed::Sealed`; [`Increases`] is its implementor here.
pub trait GateableDirection: sealed::Sealed {}

/// Type-level marker for risk-increasing actions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Increases;

impl sealed::Sealed for Increases {}
impl GateableDirection for Increases {}

/// A friction requirement tagged at the type level with the direction `D` it
/// applies to. `D` is carried only as a zero-sized `PhantomData`.
#[derive(Clone, Copy, Debug)]
pub struct FrictionGate<D: GateableDirection> {
    level: FrictionLevel,
    _direction: std::marker::PhantomData<D>,
}
impl FrictionGate<Increases> {
    /// Wraps an already-computed [`FrictionLevel`].
    ///
    /// Use this with [`FrictionLevel::from_label_and_risk`] when the gate has
    /// to reflect the risk-based `L3`/`L4` escalations; [`Self::for_label`]
    /// alone never exceeds `L2`.
    #[must_use]
    pub const fn new(level: FrictionLevel) -> Self {
        Self {
            level,
            _direction: std::marker::PhantomData,
        }
    }

    /// Gate derived from the label only (via [`FrictionLevel::from_label`],
    /// so capped at `L2`).
    #[must_use]
    pub const fn for_label(label: Label) -> Self {
        Self::new(FrictionLevel::from_label(label))
    }

    /// The wrapped friction level.
    #[must_use]
    pub const fn level(&self) -> FrictionLevel {
        self.level
    }

    /// Mandatory wait for this gate; forwards to [`FrictionLevel::pause`].
    #[must_use]
    pub const fn pause(&self) -> Duration {
        self.level.pause()
    }

    /// Whether a typed confirmation is required; forwards to
    /// [`FrictionLevel::requires_typed_confirm`].
    #[must_use]
    pub const fn requires_typed_confirm(&self) -> bool {
        self.level.requires_typed_confirm()
    }

    /// Whether the gated action must be refused outright; forwards to
    /// [`FrictionLevel::is_refusal`].
    ///
    /// Callers must check this and not rely on `pause` /
    /// `requires_typed_confirm` alone. A refusing gate also reports a
    /// non-zero pause and a typed-confirm requirement, so honouring only
    /// those would let a refused action proceed after the wait.
    #[must_use]
    pub const fn is_refusal(&self) -> bool {
        self.level.is_refusal()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fresh_steady_recovery_have_no_pause() {
        for label in [Label::Fresh, Label::Steady, Label::Recovery] {
            let gate = FrictionGate::<Increases>::for_label(label);
            assert_eq!(gate.pause(), Duration::ZERO);
            assert!(!gate.requires_typed_confirm());
        }
    }

    #[test]
    fn elevated_and_fatigued_get_three_seconds() {
        for label in [Label::Elevated, Label::Fatigued] {
            let gate = FrictionGate::<Increases>::for_label(label);
            assert_eq!(gate.pause(), Duration::from_secs(3));
            assert!(!gate.requires_typed_confirm());
        }
    }

    #[test]
    fn tilt_requires_typed_confirm_and_ten_seconds() {
        let gate = FrictionGate::<Increases>::for_label(Label::Tilt);
        assert_eq!(gate.pause(), Duration::from_secs(10));
        assert!(gate.requires_typed_confirm());
    }

    #[test]
    fn from_label_caps_at_l2_without_risk_context() {
        for label in [
            Label::Fresh,
            Label::Steady,
            Label::Elevated,
            Label::Fatigued,
            Label::Tilt,
            Label::Recovery,
        ] {
            assert!(
                FrictionLevel::from_label(label) <= FrictionLevel::L2,
                "from_label({label:?}) escaped the L2 cap — L3/L4 must flow through from_label_and_risk only"
            );
        }
    }

    #[test]
    fn non_tilt_labels_never_escalate_regardless_of_risk() {
        let haz = RiskContext {
            guardrail_proximity_pct: Some(0.1),
            halted: true,
        };
        for label in [
            Label::Fresh,
            Label::Steady,
            Label::Elevated,
            Label::Fatigued,
            Label::Recovery,
        ] {
            let level = FrictionLevel::from_label_and_risk(label, haz);
            assert_eq!(
                level,
                FrictionLevel::from_label(label),
                "label={label:?} must not escalate — TILT-gated invariant"
            );
        }
    }

    #[test]
    fn tilt_plus_halt_escalates_to_l4() {
        let ctx = RiskContext {
            guardrail_proximity_pct: None,
            halted: true,
        };
        assert_eq!(
            FrictionLevel::from_label_and_risk(Label::Tilt, ctx),
            FrictionLevel::L4
        );
    }

    #[test]
    fn tilt_plus_proximity_escalates_to_l3() {
        let ctx = RiskContext {
            guardrail_proximity_pct: Some(0.5),
            halted: false,
        };
        assert_eq!(
            FrictionLevel::from_label_and_risk(Label::Tilt, ctx),
            FrictionLevel::L3
        );
    }

    #[test]
    fn tilt_plus_halt_beats_tilt_plus_proximity() {
        let ctx = RiskContext {
            guardrail_proximity_pct: Some(0.1),
            halted: true,
        };
        assert_eq!(
            FrictionLevel::from_label_and_risk(Label::Tilt, ctx),
            FrictionLevel::L4
        );
    }

    #[test]
    fn tilt_with_distant_proximity_stays_at_l2() {
        let ctx = RiskContext {
            guardrail_proximity_pct: Some(RiskContext::PROXIMITY_PCT + 0.01),
            halted: false,
        };
        assert_eq!(
            FrictionLevel::from_label_and_risk(Label::Tilt, ctx),
            FrictionLevel::L2
        );
    }

    #[test]
    fn tilt_without_any_risk_signal_stays_at_l2() {
        assert_eq!(
            FrictionLevel::from_label_and_risk(Label::Tilt, RiskContext::default()),
            FrictionLevel::L2
        );
    }

    #[test]
    fn proximity_is_inclusive_at_the_threshold() {
        let ctx = RiskContext {
            guardrail_proximity_pct: Some(RiskContext::PROXIMITY_PCT),
            halted: false,
        };
        assert_eq!(
            FrictionLevel::from_label_and_risk(Label::Tilt, ctx),
            FrictionLevel::L3,
            "proximity at exactly the threshold must escalate"
        );
    }

    #[test]
    fn risk_context_from_engine_computes_absolute_distance() {
        let ctx = RiskContext::from_engine(Some(4.0), Some(5.0), false);
        assert_eq!(ctx.guardrail_proximity_pct, Some(1.0));
        assert!(ctx.near_guardrail());
        let ctx_reversed = RiskContext::from_engine(Some(5.0), Some(4.0), false);
        assert_eq!(
            ctx_reversed.guardrail_proximity_pct,
            Some(1.0),
            "from_engine must be sign-symmetric — absolute distance only"
        );
    }

    #[test]
    fn risk_context_from_engine_drops_proximity_when_either_field_missing() {
        let ctx = RiskContext::from_engine(None, Some(5.0), false);
        assert_eq!(ctx.guardrail_proximity_pct, None);
        assert!(!ctx.near_guardrail());
        let ctx = RiskContext::from_engine(Some(5.0), None, false);
        assert_eq!(ctx.guardrail_proximity_pct, None);
        assert!(!ctx.near_guardrail());
    }

    #[test]
    fn l3_pauses_thirty_seconds_and_l4_refuses() {
        assert_eq!(FrictionLevel::L3.pause(), Duration::from_secs(30));
        assert!(FrictionLevel::L3.requires_typed_confirm());
        assert!(!FrictionLevel::L3.is_refusal());
        assert!(FrictionLevel::L4.is_refusal());
        assert!(FrictionLevel::L4.requires_typed_confirm());
    }

    #[test]
    fn levels_are_ordered_and_pause_never_shrinks() {
        let levels = [
            FrictionLevel::L0,
            FrictionLevel::L1,
            FrictionLevel::L2,
            FrictionLevel::L3,
            FrictionLevel::L4,
        ];
        for pair in levels.windows(2) {
            let (lower, higher) = (pair[0], pair[1]);
            assert!(lower < higher, "{lower:?} must rank below {higher:?}");
            assert!(
                lower.pause() <= higher.pause(),
                "{higher:?} must not pause for less time than {lower:?}"
            );
        }
        assert_eq!(FrictionLevel::L4.pause(), Duration::from_secs(15 * 60));
    }

    #[test]
    fn gate_new_preserves_level() {
        for level in [
            FrictionLevel::L0,
            FrictionLevel::L1,
            FrictionLevel::L2,
            FrictionLevel::L3,
            FrictionLevel::L4,
        ] {
            let gate = FrictionGate::<Increases>::new(level);
            assert_eq!(gate.level(), level);
            assert_eq!(gate.pause(), level.pause());
            assert_eq!(gate.requires_typed_confirm(), level.requires_typed_confirm());
        }
    }
}