use std::collections::VecDeque;
use tracing::info_span;
use zeph_config::TrajectoryRiskAccumulatorConfig;
pub use zeph_common::audit::{AuditSignalType, Severity};
/// Converts an [`AuditSignalType`] into the `type` label attached to the
/// `shadow_memory_signals_total` counter.
///
/// Variants without an explicit arm map to `"unknown"`.
fn signal_type_label(signal_type: AuditSignalType) -> &'static str {
    match signal_type {
        AuditSignalType::ConfidenceDrop => "confidence_drop",
        AuditSignalType::ToolChainAnomaly => "tool_chain_anomaly",
        AuditSignalType::PromptInjectionPattern => "prompt_injection",
        AuditSignalType::PolicyViolation => "policy_violation",
        _ => "unknown",
    }
}
/// Converts a [`Severity`] into the `severity` label attached to the
/// `shadow_memory_signals_total` counter.
///
/// Variants without an explicit arm map to `"unknown"`.
fn severity_label(severity: Severity) -> &'static str {
    match severity {
        Severity::High => "high",
        Severity::Medium => "medium",
        Severity::Low => "low",
        _ => "unknown",
    }
}
/// One audit signal as recorded by [`TrajectoryRiskAccumulator::ingest`].
#[derive(Debug, Clone)]
pub struct SignalEvent {
    /// Value of the accumulator's turn counter at the moment the signal was ingested.
    pub turn_id: u32,
    /// Kind of audit signal observed.
    pub signal_type: AuditSignalType,
    /// Severity reported alongside the signal.
    pub severity: Severity,
    /// Undecayed contribution of this signal: configured signal weight multiplied by the
    /// configured severity multiplier.
    pub raw_score: f64,
}
/// Accumulates a decaying risk score from audit signals across turns.
///
/// Each ingested signal adds its score to the running risk (capped at `1.0`), and every
/// new turn multiplies the risk by a factor that halves it every `risk_halflife_turns`
/// turns. Callers compare the result against the configured thresholds via
/// `is_blocked` / `should_escalate`.
pub struct TrajectoryRiskAccumulator {
    /// `None` when the accumulator is disabled (see `new_noop`).
    config: Option<TrajectoryRiskAccumulatorConfig>,
    /// Current accumulated risk; never exceeds `1.0`.
    trajectory_risk: f64,
    /// Number of turns started via `advance_turn`; stamped onto each `SignalEvent`.
    turn_count: u32,
    /// Recently ingested signals, oldest at the front; trimmed from the front once it
    /// reaches the configured `signal_history_cap`.
    signal_history: VecDeque<SignalEvent>,
}
impl TrajectoryRiskAccumulator {
    /// Creates a disabled accumulator.
    ///
    /// A no-op accumulator never accumulates risk: `current_risk` is `0.0`, `is_blocked`
    /// and `should_escalate` are `false`, and `advance_turn`, `ingest` and `reset` do
    /// nothing. `record_block` and `record_escalation` still emit their metrics.
    #[must_use]
    pub fn new_noop() -> Self {
        Self {
            config: None,
            trajectory_risk: 0.0,
            turn_count: 0,
            signal_history: VecDeque::new(),
        }
    }

    /// Creates an accumulator driven by `config`.
    ///
    /// Returns [`Self::new_noop`] when `config.enabled` is `false`. A
    /// `risk_halflife_turns` of `0` is invalid; it is reported once here and treated
    /// as `1` by [`Self::advance_turn`].
    #[must_use]
    pub fn new(config: TrajectoryRiskAccumulatorConfig) -> Self {
        if !config.enabled {
            return Self::new_noop();
        }
        if config.risk_halflife_turns == 0 {
            // The config is immutable after construction, so one warning suffices;
            // `advance_turn` clamps silently instead of logging on every turn.
            tracing::warn!("risk_halflife_turns = 0 is invalid, clamping to 1");
        }
        let cap = config.signal_history_cap;
        Self {
            config: Some(config),
            trajectory_risk: 0.0,
            turn_count: 0,
            // Bound the up-front reservation; larger caps grow on demand.
            signal_history: VecDeque::with_capacity(cap.min(1024)),
        }
    }

    /// Starts a new turn.
    ///
    /// Increments the turn counter (saturating) and multiplies the accumulated risk by
    /// `2^(-1 / risk_halflife_turns)`, so risk halves every `risk_halflife_turns` turns.
    /// A half-life of `0` is treated as `1`. No-op when disabled.
    pub fn advance_turn(&mut self) {
        let _span = info_span!("memory.shadow.advance_turn").entered();
        let Some(config) = &self.config else { return };
        self.turn_count = self.turn_count.saturating_add(1);
        // A zero half-life was already reported by `new`; clamp silently here.
        let halflife = config.risk_halflife_turns.max(1);
        // exp(-ln2 / h) raised to the h-th power is exactly 1/2.
        let decay = (-std::f64::consts::LN_2 / f64::from(halflife)).exp();
        self.trajectory_risk *= decay;
    }

    /// Records one audit signal observed in the current turn.
    ///
    /// The signal's score is the configured weight for `signal_type` (`0.0` for signal
    /// types without a configured weight) times the configured multiplier for `severity`
    /// (`1.0` for severities without one). The score is added to the accumulated risk,
    /// which is kept within `[0.0, 1.0]`. The event is appended to the history, evicting
    /// the oldest entry when full; a `signal_history_cap` of `0` keeps no history.
    /// `shadow_memory_signals_total` is incremented with `type`/`severity` labels.
    ///
    /// No-op when disabled.
    pub fn ingest(&mut self, signal_type: AuditSignalType, severity: Severity) {
        let _span = info_span!("memory.shadow.ingest").entered();
        let Some(config) = &self.config else { return };
        let base_weight = match signal_type {
            AuditSignalType::PolicyViolation => config.signal_weights.policy_violation,
            AuditSignalType::PromptInjectionPattern => config.signal_weights.prompt_injection,
            AuditSignalType::ToolChainAnomaly => config.signal_weights.tool_chain_anomaly,
            AuditSignalType::ConfidenceDrop => config.signal_weights.confidence_drop,
            _ => 0.0,
        };
        let severity_mult = match severity {
            Severity::Low => config.severity_multipliers.low,
            Severity::Medium => config.severity_multipliers.medium,
            Severity::High => config.severity_multipliers.high,
            _ => 1.0,
        };
        let raw_score = base_weight * severity_mult;
        let updated = self.trajectory_risk + raw_score;
        // Fail closed on a NaN score (e.g. NaN weights in config): saturate instead of
        // storing NaN, which would make every threshold comparison false. Otherwise keep
        // risk in [0, 1] so negative configured weights cannot push it below zero.
        self.trajectory_risk = if updated.is_nan() {
            1.0
        } else {
            updated.clamp(0.0, 1.0)
        };
        let cap = config.signal_history_cap;
        if cap > 0 {
            if self.signal_history.len() >= cap {
                self.signal_history.pop_front();
            }
            self.signal_history.push_back(SignalEvent {
                turn_id: self.turn_count,
                signal_type,
                severity,
                raw_score,
            });
        }
        metrics::counter!(
            "shadow_memory_signals_total",
            "type" => signal_type_label(signal_type),
            "severity" => severity_label(severity),
        )
        .increment(1);
    }

    /// Increments the `shadow_memory_blocks_total` counter.
    ///
    /// Emitted whether or not the accumulator is enabled; the caller decides when a
    /// block actually happened.
    pub fn record_block(&self) {
        metrics::counter!("shadow_memory_blocks_total").increment(1);
    }

    /// Increments the `shadow_memory_escalations_total` counter.
    ///
    /// Emitted whether or not the accumulator is enabled; the caller decides when an
    /// escalation actually happened.
    pub fn record_escalation(&self) {
        metrics::counter!("shadow_memory_escalations_total").increment(1);
    }

    /// Returns the current accumulated risk, or `0.0` when disabled.
    #[must_use]
    pub fn current_risk(&self) -> f64 {
        if self.config.is_none() {
            return 0.0;
        }
        self.trajectory_risk
    }

    /// Returns `true` when the accumulated risk has reached `risk_threshold`.
    ///
    /// Always `false` when disabled.
    #[must_use]
    pub fn is_blocked(&self) -> bool {
        let Some(config) = &self.config else {
            return false;
        };
        self.trajectory_risk >= config.risk_threshold
    }

    /// Returns `true` when the accumulated risk lies in
    /// `[escalation_threshold, risk_threshold)`, i.e. it warrants escalation but is not
    /// yet blocking.
    ///
    /// Always `false` when disabled.
    #[must_use]
    pub fn should_escalate(&self) -> bool {
        let Some(config) = &self.config else {
            return false;
        };
        self.trajectory_risk >= config.escalation_threshold
            && self.trajectory_risk < config.risk_threshold
    }

    /// Returns up to `n` retained signals, highest `raw_score` first.
    ///
    /// Signals with identical scores keep their insertion order (oldest first). The sort
    /// uses a total order on `f64`, so a NaN score sorts deterministically instead of
    /// breaking the comparator contract of `sort_by`.
    #[must_use]
    pub fn top_signals(&self, n: usize) -> Vec<&SignalEvent> {
        let mut signals: Vec<&SignalEvent> = self.signal_history.iter().collect();
        // Descending by score; `sort_by` is stable, which preserves insertion order on ties.
        signals.sort_by(|a, b| b.raw_score.total_cmp(&a.raw_score));
        signals.truncate(n);
        signals
    }

    /// Clears the accumulated risk and the signal history.
    ///
    /// The turn counter is left untouched, so `turn_id`s of later signals keep increasing.
    /// No-op when disabled.
    pub fn reset(&mut self) {
        if self.config.is_none() {
            return;
        }
        self.trajectory_risk = 0.0;
        self.signal_history.clear();
    }

    /// Returns `true` when the accumulator was built from an enabled config.
    #[must_use]
    pub fn is_enabled(&self) -> bool {
        self.config.is_some()
    }

    /// Number of turns started via `advance_turn` (always `0` when disabled).
    #[must_use]
    pub fn turn_count(&self) -> u32 {
        self.turn_count
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_config::{
        TrajectoryRiskAccumulatorConfig, TrajectorySeverityMultipliers, TrajectorySignalWeights,
    };

    /// Enabled config shared by every test: block at 0.75, escalate at 0.50, ten-turn
    /// half-life, history of 200, default weights and multipliers.
    fn enabled_config() -> TrajectoryRiskAccumulatorConfig {
        TrajectoryRiskAccumulatorConfig {
            enabled: true,
            risk_threshold: 0.75,
            escalation_threshold: 0.50,
            risk_halflife_turns: 10,
            signal_history_cap: 200,
            tui_show_risk_gauge: true,
            reset_on_compaction: false,
            signal_weights: TrajectorySignalWeights::default(),
            severity_multipliers: TrajectorySeverityMultipliers::default(),
        }
    }

    /// Builds an enabled accumulator from [`enabled_config`].
    fn enabled_accumulator() -> TrajectoryRiskAccumulator {
        TrajectoryRiskAccumulator::new(enabled_config())
    }

    /// Runs `turns` turns, each starting with `advance_turn` followed by one ingest of
    /// `(signal_type, severity)`.
    fn run_turns(
        acc: &mut TrajectoryRiskAccumulator,
        turns: usize,
        signal_type: AuditSignalType,
        severity: Severity,
    ) {
        for _ in 0..turns {
            acc.advance_turn();
            acc.ingest(signal_type, severity);
        }
    }

    #[test]
    fn new_noop_returns_zero_risk() {
        let noop = TrajectoryRiskAccumulator::new_noop();
        assert!(noop.current_risk() < f64::EPSILON);
        assert!(!noop.is_blocked());
        assert!(!noop.is_enabled());
    }

    #[test]
    fn single_signal_below_threshold_not_blocked() {
        let mut acc = enabled_accumulator();
        run_turns(&mut acc, 1, AuditSignalType::PolicyViolation, Severity::Medium);
        let risk = acc.current_risk();
        assert!(risk > 0.0);
        assert!(risk < 0.75);
        assert!(!acc.is_blocked());
    }

    #[test]
    fn multi_turn_chain_accumulates_and_blocks() {
        let mut acc = enabled_accumulator();
        run_turns(&mut acc, 5, AuditSignalType::PromptInjectionPattern, Severity::High);
        assert!(acc.is_blocked(), "risk={}", acc.current_risk());
    }

    #[test]
    fn temporal_decay_reduces_score() {
        let mut acc = enabled_accumulator();
        run_turns(&mut acc, 1, AuditSignalType::PromptInjectionPattern, Severity::High);
        let peak = acc.current_risk();
        assert!(peak > 0.0);
        (0..100).for_each(|_| acc.advance_turn());
        let decayed = acc.current_risk();
        assert!(
            decayed < peak / 2.0,
            "expected significant decay, got {}",
            decayed
        );
    }

    #[test]
    fn risk_clamped_at_one() {
        let mut acc = enabled_accumulator();
        run_turns(&mut acc, 20, AuditSignalType::PromptInjectionPattern, Severity::High);
        let risk = acc.current_risk();
        assert!(risk <= 1.0, "trajectory_risk exceeded 1.0: {}", risk);
    }

    #[test]
    fn advance_turn_before_ingest_applies_decay() {
        let mut acc = enabled_accumulator();
        run_turns(&mut acc, 1, AuditSignalType::PolicyViolation, Severity::High);
        let risk_t1 = acc.current_risk();

        // A bare turn boundary must only decay.
        acc.advance_turn();
        let risk_after_decay = acc.current_risk();
        assert!(
            risk_after_decay < risk_t1,
            "decay should reduce risk before new ingest: {risk_after_decay} vs {risk_t1}"
        );

        // A fresh signal in the same turn must push risk back up.
        acc.ingest(AuditSignalType::PolicyViolation, Severity::High);
        let risk_t2 = acc.current_risk();
        assert!(
            risk_t2 > risk_after_decay,
            "ingest should increase risk: {} vs {}",
            risk_t2,
            risk_after_decay
        );
    }

    #[test]
    fn decay_formula_matches_spec() {
        let mut acc = enabled_accumulator();
        run_turns(&mut acc, 5, AuditSignalType::ConfidenceDrop, Severity::Medium);

        // The k-th most recent of five 0.15 contributions has decayed k times.
        let decay = (-std::f64::consts::LN_2 / 10.0_f64).exp();
        let expected: f64 = (0..5).fold(0.0_f64, |total, k| total + 0.15_f64 * decay.powi(k));
        assert!(
            expected < 1.0,
            "test precondition: expected sum {expected} must be < 1.0 (no clamping)"
        );

        let actual = acc.current_risk();
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {expected:.12}, got {:.12}",
            actual
        );
    }

    #[test]
    fn fifty_clean_turns_zero_risk() {
        let mut acc = enabled_accumulator();
        (0..50).for_each(|_| acc.advance_turn());
        assert!(
            acc.current_risk() < f64::EPSILON,
            "no signals → risk must stay 0.0"
        );
        assert!(!acc.is_blocked());
    }
}