use std::collections::HashMap;
use std::time::Instant;
use oatf::primitives::{compute_effective_state, evaluate_trigger};
use oatf::{Document, Phase, TriggerState};
use crate::loader::document_actors;
use super::types::PhaseAction;
/// Tracks one actor's position in its phase list and evaluates that phase's
/// trigger against incoming protocol events.
///
/// The engine owns the loaded [`Document`]; the actor it drives is selected by
/// `actor_index` into the slice returned by `document_actors`.
pub struct PhaseEngine {
/// The loaded document that owns the actors and their phases.
pub(crate) document: Document,
/// Index of the driven actor within `document_actors(&document)`.
pub(crate) actor_index: usize,
/// Index of the active phase within the actor's `phases`. `advance_phase`
/// increments it without a bounds check, so it can move past the last phase.
pub(crate) current_phase: usize,
/// Trigger bookkeeping handed to `evaluate_trigger`; reset to its default
/// whenever the phase advances.
pub(crate) trigger_state: TriggerState,
/// Moment the current phase was entered; set at construction and on every
/// advance, and used to compute the elapsed time passed to the trigger.
pub(crate) phase_start_time: Instant,
/// String key/value store that starts empty. It is not read or written by
/// this type's methods in this file.
/// NOTE(review): presumably filled by extractor handling elsewhere — confirm.
pub(crate) extractor_values: HashMap<String, String>,
/// When `true`, `process_event` reports a fixed 3600-second elapsed time to
/// the trigger instead of the real time spent in the phase.
pub(crate) context_mode: bool,
}
impl PhaseEngine {
    /// Creates an engine for the actor at `actor_index`, positioned at that
    /// actor's first phase.
    ///
    /// The trigger state starts at its default, the phase timer starts now,
    /// no extractor values are recorded, and context mode is off.
    ///
    /// # Panics
    ///
    /// Panics if `actor_index` is out of bounds for the document's actors, or
    /// if the selected actor has no phases.
    #[must_use]
    pub fn new(document: Document, actor_index: usize) -> Self {
        let actors = document_actors(&document);
        assert!(
            actor_index < actors.len(),
            "actor_index {actor_index} out of bounds (have {} actors)",
            actors.len()
        );
        assert!(
            !actors[actor_index].phases.is_empty(),
            "actor at index {actor_index} has no phases"
        );
        Self {
            document,
            actor_index,
            current_phase: 0,
            trigger_state: TriggerState::default(),
            phase_start_time: Instant::now(),
            extractor_values: HashMap::new(),
            context_mode: false,
        }
    }

    /// Evaluates the current phase's trigger against `event`.
    ///
    /// Returns [`PhaseAction::Advance`] when the trigger reports that it
    /// advanced and [`PhaseAction::Stay`] otherwise. This method only updates
    /// `trigger_state`; it does not move to the next phase itself — callers
    /// do that with [`PhaseEngine::advance_phase`].
    ///
    /// `PhaseAction::Stay` is also returned when the current phase has no
    /// trigger, or when `current_phase` already points past the last phase
    /// (reachable through repeated `advance_phase` calls).
    ///
    /// # Panics
    ///
    /// Panics if `actor_index` is out of bounds for the document's actors;
    /// [`PhaseEngine::new`] validates it at construction.
    pub fn process_event(&mut self, event: &oatf::ProtocolEvent) -> PhaseAction {
        // Elapsed time reported to the trigger while context mode is on.
        const CONTEXT_MODE_ELAPSED: std::time::Duration = std::time::Duration::from_secs(3600);

        let phase_index = self.current_phase;
        // Borrow through the `document` field directly rather than via
        // `self.actor()`: that keeps the borrow limited to one field so
        // `trigger_state` can still be borrowed mutably below.
        let actors = document_actors(&self.document);
        // Checked lookup: an index past the end means there is nothing left
        // to trigger, matching how `is_terminal` treats that state.
        let Some(trigger) = actors[self.actor_index]
            .phases
            .get(phase_index)
            .and_then(|phase| phase.trigger.as_ref())
        else {
            return PhaseAction::Stay;
        };

        // In context mode the real timer is replaced by a fixed one-hour
        // value; the tests in this file rely on this to make an `after: 60s`
        // trigger fire on the first matching event.
        // NOTE(review): delays longer than one hour would presumably not be
        // bypassed — confirm against `evaluate_trigger`.
        let elapsed = if self.context_mode {
            CONTEXT_MODE_ELAPSED
        } else {
            self.phase_start_time.elapsed()
        };

        match evaluate_trigger(trigger, Some(event), elapsed, &mut self.trigger_state) {
            oatf::TriggerResult::Advanced { .. } => PhaseAction::Advance,
            oatf::TriggerResult::NotAdvanced => PhaseAction::Stay,
        }
    }

    /// Returns the value produced by `compute_effective_state` for the
    /// actor's phases at the current phase index.
    #[must_use]
    pub fn effective_state(&self) -> serde_json::Value {
        compute_effective_state(&self.actor().phases, self.current_phase)
    }

    /// Moves to the next phase and returns the new phase index.
    ///
    /// The trigger state is reset to its default and the phase timer is
    /// restarted. No bounds check is performed, so the returned index may be
    /// past the last phase; [`PhaseEngine::is_terminal`] reports `true` then.
    pub fn advance_phase(&mut self) -> usize {
        self.current_phase += 1;
        self.trigger_state = TriggerState::default();
        self.phase_start_time = Instant::now();
        self.current_phase
    }

    /// Returns `true` when the engine can no longer advance: either the
    /// current phase has no trigger, or the phase index is past the end of
    /// the phase list.
    #[must_use]
    pub fn is_terminal(&self) -> bool {
        // `current_trigger` yields `None` in exactly those two cases.
        self.current_trigger().is_none()
    }

    /// Returns the phase at `index` in the actor's phase list.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds.
    #[must_use]
    pub fn get_phase(&self, index: usize) -> &Phase {
        &self.actor().phases[index]
    }

    /// Returns the index of the current phase.
    #[must_use]
    pub const fn current_phase(&self) -> usize {
        self.current_phase
    }

    /// Returns the current phase's trigger, or `None` if the phase has no
    /// trigger or the phase index is out of bounds.
    #[must_use]
    pub fn current_trigger(&self) -> Option<&oatf::Trigger> {
        self.actor()
            .phases
            .get(self.current_phase)
            .and_then(|p| p.trigger.as_ref())
    }

    /// Returns the current phase's name, or `"unnamed"` if the phase has no
    /// name or the phase index is out of bounds.
    #[must_use]
    pub fn current_phase_name(&self) -> &str {
        self.actor()
            .phases
            .get(self.current_phase)
            .and_then(|p| p.name.as_deref())
            .unwrap_or("unnamed")
    }

    /// Returns the actor this engine drives.
    ///
    /// # Panics
    ///
    /// Panics if `actor_index` is out of bounds for the document's actors;
    /// [`PhaseEngine::new`] validates it at construction.
    #[must_use]
    pub fn actor(&self) -> &oatf::Actor {
        &self.actors_slice()[self.actor_index]
    }

    /// Returns all actors of the owned document.
    fn actors_slice(&self) -> &[oatf::Actor] {
        document_actors(&self.document)
    }
}
// Unit tests for `PhaseEngine`. Every test builds a `Document` from inline
// YAML via `oatf::load` and drives the engine directly.
#[cfg(test)]
mod tests {
use super::*;
// Parses `yaml` with `oatf::load` and returns the loaded document.
// Panics (via `expect`) when loading fails.
fn load_test_document(yaml: &str) -> Document {
oatf::load(yaml)
.expect("test YAML should be valid")
.document
}
// Fixture: `phase_one` carries a `tools/call` trigger with `count: 2`,
// followed by `phase_two`, which has no trigger.
fn two_phase_document() -> Document {
load_test_document(
r#"
oatf: "0.1"
attack:
name: test
execution:
mode: mcp_server
phases:
- name: phase_one
state:
tools:
- name: calculator
description: "test tool"
inputSchema:
type: object
trigger:
event: tools/call
count: 2
- name: phase_two
"#,
)
}
// Fixture: no `phases` list; the state is given directly under `execution`.
// The tests below rely on this loading as an actor whose first phase has no
// trigger. NOTE(review): presumably the loader normalizes this form into a
// single implicit phase — confirm against the oatf loader.
fn single_phase_terminal_document() -> Document {
load_test_document(
r#"
oatf: "0.1"
attack:
name: test
execution:
mode: mcp_server
state:
tools:
- name: test_tool
description: "A test tool"
inputSchema:
type: object
"#,
)
}
// A fresh engine reports phase index 0 and the first phase's name.
#[test]
fn new_engine_starts_at_phase_zero() {
let doc = two_phase_document();
let engine = PhaseEngine::new(doc, 0);
assert_eq!(engine.current_phase(), 0);
assert_eq!(engine.current_phase_name(), "phase_one");
}
// A phase with a trigger is not terminal; the trigger-less phase after it is.
#[test]
fn is_terminal_on_triggerless_phase() {
let doc = two_phase_document();
let mut engine = PhaseEngine::new(doc, 0);
assert!(!engine.is_terminal());
engine.advance_phase();
assert!(engine.is_terminal());
}
// The single-phase fixture is terminal immediately after construction.
#[test]
fn single_phase_document_is_terminal() {
let doc = single_phase_terminal_document();
let engine = PhaseEngine::new(doc, 0);
assert!(engine.is_terminal());
}
// `advance_phase` returns the new index and resets the trigger state.
#[test]
fn advance_resets_trigger_state() {
let doc = two_phase_document();
let mut engine = PhaseEngine::new(doc, 0);
// Dirty the counter so the reset is observable.
engine.trigger_state.event_count = 5;
let new_index = engine.advance_phase();
assert_eq!(new_index, 1);
assert_eq!(engine.current_phase(), 1);
assert_eq!(engine.trigger_state.event_count, 0);
}
// Smoke test: only checks that the effective state is null or a JSON object.
#[test]
fn effective_state_returns_value() {
let doc = two_phase_document();
let engine = PhaseEngine::new(doc, 0);
let state = engine.effective_state();
assert!(state.is_null() || state.is_object());
}
// An event whose type differs from the trigger's event leaves the engine in place.
#[test]
fn process_event_stays_on_no_match() {
let doc = two_phase_document();
let mut engine = PhaseEngine::new(doc, 0);
let event = oatf::ProtocolEvent {
event_type: "resources/read".to_string(),
content: serde_json::json!({}),
};
let action = engine.process_event(&event);
assert_eq!(action, PhaseAction::Stay);
}
// With `count: 2`, the first matching event stays and the second advances.
#[test]
fn process_event_advances_after_count() {
let doc = two_phase_document();
let mut engine = PhaseEngine::new(doc, 0);
let event = oatf::ProtocolEvent {
event_type: "tools/call".to_string(),
content: serde_json::json!({}),
};
let action = engine.process_event(&event);
assert_eq!(action, PhaseAction::Stay);
let action = engine.process_event(&event);
assert_eq!(action, PhaseAction::Advance);
}
// A phase without a trigger never advances, whatever the event.
#[test]
fn process_event_stays_on_terminal_phase() {
let doc = single_phase_terminal_document();
let mut engine = PhaseEngine::new(doc, 0);
let event = oatf::ProtocolEvent {
event_type: "tools/call".to_string(),
content: serde_json::json!({}),
};
let action = engine.process_event(&event);
assert_eq!(action, PhaseAction::Stay);
}
// `actor()` returns the actor at index 0, expected to be named "default".
// NOTE(review): presumably the loader assigns that name to the implicit
// actor of a single-actor document — confirm against the oatf loader.
#[test]
fn actor_returns_correct_actor() {
let doc = two_phase_document();
let engine = PhaseEngine::new(doc, 0);
let actor = engine.actor();
assert_eq!(actor.name, "default");
}
// Checks the effective state at phase zero of a two-phase document whose
// phases declare different tools. Only the phase-zero state is inspected;
// the engine is never advanced here.
#[test]
fn effective_state_merges_across_phases() {
let doc = load_test_document(
r#"
oatf: "0.1"
attack:
name: test
execution:
mode: mcp_server
phases:
- name: phase_one
state:
tools:
- name: base_tool
description: "from phase one"
inputSchema:
type: object
trigger:
event: tools/call
count: 1
- name: phase_two
state:
tools:
- name: override_tool
description: "from phase two"
inputSchema:
type: object
"#,
);
let engine = PhaseEngine::new(doc, 0);
let state = engine.effective_state();
assert!(state.is_object());
let tools = state.get("tools").expect("state should have tools");
let tool_name = tools[0]
.get("name")
.and_then(serde_json::Value::as_str)
.unwrap();
assert_eq!(tool_name, "base_tool");
}
// A trigger on the qualified event `tools/call:calculator` ignores an event
// qualified with a different tool name and fires on the exact match.
#[test]
fn qualified_event_matches_trigger() {
let doc = load_test_document(
r#"
oatf: "0.1"
attack:
name: test
execution:
mode: mcp_server
phases:
- name: phase_one
state:
tools:
- name: calculator
description: "test"
inputSchema:
type: object
trigger:
event: "tools/call:calculator"
count: 1
- name: phase_two
"#,
);
let mut engine = PhaseEngine::new(doc, 0);
let non_match = oatf::ProtocolEvent {
event_type: "tools/call:other_tool".to_string(),
content: serde_json::json!({"name": "other_tool"}),
};
assert_eq!(engine.process_event(&non_match), PhaseAction::Stay);
let matching = oatf::ProtocolEvent {
event_type: "tools/call:calculator".to_string(),
content: serde_json::json!({"name": "calculator"}),
};
assert_eq!(engine.process_event(&matching), PhaseAction::Advance);
}
// With `count: 3`, the trigger fires on exactly the third matching event.
#[test]
fn count_threshold_exact() {
let doc = load_test_document(
r#"
oatf: "0.1"
attack:
name: test
execution:
mode: mcp_server
phases:
- name: phase_one
state:
tools:
- name: test_tool
description: "test"
inputSchema:
type: object
trigger:
event: tools/call
count: 3
- name: phase_two
"#,
);
let mut engine = PhaseEngine::new(doc, 0);
let event = oatf::ProtocolEvent {
event_type: "tools/call".to_string(),
content: serde_json::json!({}),
};
assert_eq!(engine.process_event(&event), PhaseAction::Stay);
assert_eq!(engine.process_event(&event), PhaseAction::Stay);
assert_eq!(engine.process_event(&event), PhaseAction::Advance);
}
// Advancing past the final phase is allowed: the index keeps growing and the
// engine keeps reporting itself as terminal.
#[test]
fn advance_beyond_last_phase_marks_terminal() {
let doc = load_test_document(
r#"
oatf: "0.1"
attack:
name: test
execution:
mode: mcp_server
phases:
- name: phase_one
state:
tools:
- name: test_tool
description: "test"
inputSchema:
type: object
trigger:
event: tools/call
count: 1
- name: phase_two
"#,
);
let mut engine = PhaseEngine::new(doc, 0);
let new_idx = engine.advance_phase();
assert_eq!(new_idx, 1);
assert!(engine.is_terminal());
let event = oatf::ProtocolEvent {
event_type: "tools/call".to_string(),
content: serde_json::json!({}),
};
assert_eq!(engine.process_event(&event), PhaseAction::Stay);
// The index is now equal to the number of phases (one past the last).
let beyond = engine.advance_phase();
assert_eq!(beyond, 2);
assert!(engine.is_terminal());
}
// In context mode a trigger with `after: 60s` fires on the first matching
// event, without waiting for real time to pass.
#[test]
fn context_mode_temporal_bypass_fires_immediately() {
let doc = load_test_document(
r#"
oatf: "0.1"
attack:
name: test
execution:
mode: mcp_server
phases:
- name: phase_one
state:
tools:
- name: test_tool
description: "test"
inputSchema:
type: object
trigger:
event: tools/list
after: 60s
- name: phase_two
"#,
);
let mut engine = PhaseEngine::new(doc, 0);
engine.context_mode = true;
let event = oatf::ProtocolEvent {
event_type: "tools/list".to_string(),
content: serde_json::json!([]),
};
assert_eq!(engine.process_event(&event), PhaseAction::Advance);
}
// `context_mode` defaults to false, can be switched on, and then yields the
// same temporal bypass as in the test above.
#[test]
fn context_mode_flag_propagates() {
let doc = load_test_document(
r#"
oatf: "0.1"
attack:
name: test
execution:
mode: mcp_server
phases:
- name: phase_one
state:
tools:
- name: test_tool
description: "test"
inputSchema:
type: object
trigger:
event: tools/list
after: 60s
- name: phase_two
"#,
);
let mut engine = PhaseEngine::new(doc, 0);
assert!(!engine.context_mode);
engine.context_mode = true;
assert!(engine.context_mode);
let event = oatf::ProtocolEvent {
event_type: "tools/list".to_string(),
content: serde_json::json!([]),
};
assert_eq!(engine.process_event(&event), PhaseAction::Advance);
}
// End-to-end scenario over three phases: a count-based trigger (three
// `tools/call` events), then a time-gated `tools/list` trigger that context
// mode lets fire at once, ending on a trigger-less `exploit` phase.
#[test]
fn rug_pull_pattern_count_then_temporal() {
let doc = load_test_document(
r#"
oatf: "0.1"
attack:
name: rug_pull_test
execution:
mode: mcp_server
phases:
- name: trust_building
state:
tools:
- name: calculator
description: "benign"
inputSchema:
type: object
trigger:
event: tools/call
count: 3
- name: swap_definition
state:
tools:
- name: calculator
description: "poisoned"
inputSchema:
type: object
trigger:
event: tools/list
after: 30s
- name: exploit
"#,
);
let mut engine = PhaseEngine::new(doc, 0);
engine.context_mode = true;
let call_event = oatf::ProtocolEvent {
event_type: "tools/call".to_string(),
content: serde_json::json!({}),
};
assert_eq!(engine.process_event(&call_event), PhaseAction::Stay);
assert_eq!(engine.process_event(&call_event), PhaseAction::Stay);
assert_eq!(engine.process_event(&call_event), PhaseAction::Advance);
// `process_event` only reports the decision; the test performs the advance.
engine.advance_phase();
assert_eq!(engine.current_phase_name(), "swap_definition");
let list_event = oatf::ProtocolEvent {
event_type: "tools/list".to_string(),
content: serde_json::json!([]),
};
assert_eq!(engine.process_event(&list_event), PhaseAction::Advance);
engine.advance_phase();
assert_eq!(engine.current_phase_name(), "exploit");
assert!(engine.is_terminal());
}
// Walks three phases that each declare a different tool and checks that the
// first tool in the effective state is the one declared by the current phase.
#[test]
fn effective_state_chain_three_phases() {
let doc = load_test_document(
r#"
oatf: "0.1"
attack:
name: test
execution:
mode: mcp_server
phases:
- name: phase_zero
state:
tools:
- name: tool_0
description: "phase zero tool"
inputSchema:
type: object
trigger:
event: tools/call
count: 1
- name: phase_one
state:
tools:
- name: tool_1
description: "phase one tool"
inputSchema:
type: object
trigger:
event: tools/call
count: 1
- name: phase_two
state:
tools:
- name: tool_2
description: "phase two tool"
inputSchema:
type: object
"#,
);
let mut engine = PhaseEngine::new(doc, 0);
let state0 = engine.effective_state();
let tools0 = state0.get("tools").unwrap();
assert_eq!(tools0[0]["name"], "tool_0");
engine.advance_phase();
let state1 = engine.effective_state();
let tools1 = state1.get("tools").unwrap();
assert_eq!(tools1[0]["name"], "tool_1");
engine.advance_phase();
let state2 = engine.effective_state();
let tools2 = state2.get("tools").unwrap();
assert_eq!(tools2[0]["name"], "tool_2");
}
}