use std::fmt;
use std::time::Duration;
use serde::{Deserialize, Serialize};
/// Which way a protocol message travelled.
///
/// On the wire (serde) the variants are spelled in `snake_case`:
/// `"incoming"` and `"outgoing"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Direction {
    /// A message that arrived.
    Incoming,
    /// A message that was sent.
    Outgoing,
}
/// Renders the direction as its lowercase name: `incoming` or `outgoing`.
///
/// The name is written through [`fmt::Formatter::pad`] rather than `write!`, so
/// width, fill, alignment and precision flags (e.g. `{:<10}`) are honored the
/// same way they are for `str`. A bare `write!` of a literal silently ignores
/// them, which breaks column-aligned output.
impl fmt::Display for Direction {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Incoming => "incoming",
            Self::Outgoing => "outgoing",
        };
        f.pad(name)
    }
}
/// One protocol message, tagged with the [`Direction`] it travelled in.
#[derive(Clone, Debug)]
pub struct ProtocolEvent {
    /// Whether the message was incoming or outgoing.
    pub direction: Direction,
    /// Method name carried by the message (e.g. `"tools/call"`).
    pub method: String,
    /// Raw JSON payload. Not validated here; presumably its shape depends on
    /// `method` — confirm against producers.
    pub content: serde_json::Value,
}
/// Whether to remain on the current phase or move past it.
///
/// NOTE(review): the meaning here follows the variant names; the code that
/// acts on this value lives outside this module.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum PhaseAction {
    /// Keep the current phase.
    Stay,
    /// Move on to the next phase.
    Advance,
}
/// How a drive run ended: it either ran to completion or the transport closed.
///
/// NOTE(review): the meaning here follows the variant names; the driver that
/// produces this value is defined elsewhere.
///
/// This is a fieldless enum, so it derives `Copy`, `PartialEq`, `Eq` and `Hash`
/// like the sibling `PhaseAction`/`Direction`. That lets callers compare
/// outcomes (`== DriveResult::Complete`, `assert_eq!`) and use them as keys
/// without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DriveResult {
    /// The run finished normally.
    Complete,
    /// The transport closed before the run finished.
    TransportClosed,
}
/// Why an actor stopped; carried in an actor's result.
///
/// Serialized in `snake_case` (e.g. `"transport_closed"`). The
/// [`fmt::Display`] form uses spaces instead (e.g. `transport closed`).
///
/// This is a fieldless enum, so it derives `Copy` and `Hash` to match
/// `Direction`. Reasons can then be passed around by value and used as map or
/// set keys (e.g. tallying outcomes) without `.clone()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum TerminationReason {
    /// A terminal phase was reached.
    TerminalPhaseReached,
    /// The run was cancelled.
    Cancelled,
    /// The maximum session lifetime expired.
    MaxSessionExpired,
    /// The transport closed.
    TransportClosed,
}
/// Renders the reason as lowercase, space-separated words
/// (e.g. `max session expired`).
///
/// The text is written through [`fmt::Formatter::pad`] so width, fill,
/// alignment and precision flags (e.g. `{:<24}`) are honored as they are for
/// `str`. A bare `write!` of a literal silently ignores them.
impl fmt::Display for TerminationReason {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match self {
            Self::TerminalPhaseReached => "terminal phase reached",
            Self::Cancelled => "cancelled",
            Self::MaxSessionExpired => "max session expired",
            Self::TransportClosed => "transport closed",
        };
        f.pad(text)
    }
}
/// Summary of how a single actor's run ended.
#[derive(Clone, Debug)]
pub struct ActorResult {
    /// Name of the actor this summary describes.
    pub actor_name: String,
    /// Why the actor stopped.
    pub termination: TerminationReason,
    /// Number of phases the actor completed.
    pub phases_completed: usize,
    /// Number of phases defined for the actor.
    pub total_phases: usize,
    /// Name of the phase the actor ended in, if known.
    /// NOTE(review): `None` semantics are set by the producer — confirm.
    pub final_phase: Option<String>,
}
/// Parameters for waiting on extractor values from another actor.
///
/// This type only carries the parameters; the waiting itself is performed
/// elsewhere.
#[derive(Clone, Debug)]
pub struct AwaitExtractor {
    /// Name of the actor whose extractors are awaited.
    pub actor: String,
    /// Names of the extractor values to wait for.
    pub extractors: Vec<String>,
    /// How long to wait. Presumably an upper bound enforced by the caller —
    /// confirm.
    pub timeout: Duration,
}
#[cfg(test)]
mod tests {
    //! Unit tests for the protocol/actor data types: Display text, equality,
    //! hashing, struct construction and serde wire names.

    use super::*;
    use std::collections::HashSet;

    #[test]
    fn direction_display() {
        assert_eq!(Direction::Incoming.to_string(), "incoming");
        assert_eq!(Direction::Outgoing.to_string(), "outgoing");
    }

    #[test]
    fn direction_equality() {
        assert_eq!(Direction::Incoming, Direction::Incoming);
        assert_ne!(Direction::Incoming, Direction::Outgoing);
    }

    #[test]
    fn direction_hash_distinguishes_variants() {
        let mut seen = HashSet::new();
        seen.insert(Direction::Incoming);
        seen.insert(Direction::Outgoing);
        seen.insert(Direction::Incoming);
        assert_eq!(seen.len(), 2);
    }

    #[test]
    fn phase_action_equality() {
        assert_eq!(PhaseAction::Stay, PhaseAction::Stay);
        assert_eq!(PhaseAction::Advance, PhaseAction::Advance);
        assert_ne!(PhaseAction::Stay, PhaseAction::Advance);
    }

    #[test]
    fn drive_result_debug_names() {
        assert_eq!(format!("{:?}", DriveResult::Complete), "Complete");
        assert_eq!(format!("{:?}", DriveResult::TransportClosed), "TransportClosed");
    }

    #[test]
    fn termination_reason_display() {
        assert_eq!(
            TerminationReason::TerminalPhaseReached.to_string(),
            "terminal phase reached"
        );
        assert_eq!(TerminationReason::Cancelled.to_string(), "cancelled");
        assert_eq!(
            TerminationReason::MaxSessionExpired.to_string(),
            "max session expired"
        );
        assert_eq!(
            TerminationReason::TransportClosed.to_string(),
            "transport closed"
        );
    }

    #[test]
    fn protocol_event_construction() {
        let event = ProtocolEvent {
            direction: Direction::Incoming,
            method: "tools/call".to_string(),
            content: serde_json::json!({"name": "calculator"}),
        };
        assert_eq!(event.direction, Direction::Incoming);
        assert_eq!(event.method, "tools/call");
        assert_eq!(event.content, serde_json::json!({"name": "calculator"}));
    }

    #[test]
    fn actor_result_construction() {
        let result = ActorResult {
            actor_name: "mcp_poison".to_string(),
            termination: TerminationReason::TerminalPhaseReached,
            phases_completed: 2,
            total_phases: 3,
            final_phase: Some("exploit".to_string()),
        };
        assert_eq!(result.actor_name, "mcp_poison");
        assert_eq!(result.termination, TerminationReason::TerminalPhaseReached);
        assert_eq!(result.phases_completed, 2);
        assert_eq!(result.total_phases, 3);
        assert_eq!(result.final_phase.as_deref(), Some("exploit"));
    }

    #[test]
    fn await_extractor_construction() {
        let spec = AwaitExtractor {
            actor: "other_actor".to_string(),
            extractors: vec!["token".to_string(), "session_id".to_string()],
            timeout: Duration::from_secs(30),
        };
        assert_eq!(spec.actor, "other_actor");
        assert_eq!(spec.extractors, ["token", "session_id"]);
        assert_eq!(spec.timeout, Duration::from_secs(30));
    }

    #[test]
    fn direction_serialization() {
        let json = serde_json::to_string(&Direction::Incoming).unwrap();
        assert_eq!(json, "\"incoming\"");
        let deserialized: Direction = serde_json::from_str("\"outgoing\"").unwrap();
        assert_eq!(deserialized, Direction::Outgoing);
    }

    #[test]
    fn direction_round_trips_and_rejects_unknown_names() {
        for direction in [Direction::Incoming, Direction::Outgoing] {
            let json = serde_json::to_string(&direction).unwrap();
            let back: Direction = serde_json::from_str(&json).unwrap();
            assert_eq!(back, direction);
        }
        // `rename_all = "snake_case"` means the Rust variant spelling is not
        // accepted on the wire, and neither is an unknown name.
        assert!(serde_json::from_str::<Direction>("\"Incoming\"").is_err());
        assert!(serde_json::from_str::<Direction>("\"sideways\"").is_err());
    }

    #[test]
    fn direction_display_matches_serialized_name() {
        for direction in [Direction::Incoming, Direction::Outgoing] {
            assert_eq!(
                serde_json::to_value(direction).unwrap(),
                serde_json::Value::String(direction.to_string())
            );
        }
    }

    #[test]
    fn termination_reason_serialization() {
        let cases = [
            (
                TerminationReason::TerminalPhaseReached,
                "\"terminal_phase_reached\"",
            ),
            (TerminationReason::Cancelled, "\"cancelled\""),
            (TerminationReason::MaxSessionExpired, "\"max_session_expired\""),
            (TerminationReason::TransportClosed, "\"transport_closed\""),
        ];
        for (reason, expected) in cases {
            assert_eq!(serde_json::to_string(&reason).unwrap(), expected);
        }
    }
}