use std::time::Duration;
use anyhow::Result;
use futures::future::join_all;
use serde::Serialize;
use serde_json::{json, Value};
use tokio_util::sync::CancellationToken;
use crate::{
client::CallOutcome,
corpus::Corpus,
differential::response_value,
finding::{Finding, FindingKind, ReproInfo},
};
use super::{
exec::McpExec,
reporter::{Reporter, RunInfo},
};
/// Summary produced at the end of a torture run.
#[derive(Debug, Default, Serialize)]
pub struct TortureReport {
    /// Number of findings recorded (each one is also written to the corpus
    /// and streamed to the reporter as it is counted).
    pub findings_count: usize,
}
/// Which stress scenario a torture run executes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TortureMode {
    /// Fire `concurrency` identical calls at the target tool at once and
    /// verify they all succeed (plus a lost-update check for `counter_inc`).
    Parallel,
    /// Write session data via `session_set`, then probe with `session_get`
    /// to see whether the value is observable where it should not be.
    StateLeak,
}
/// The whole-run deadline is this many per-call timeouts; a watchdog uses it
/// to bound the run even if individual calls hang past their own timeout.
pub const GLOBAL_DEADLINE_FACTOR: u32 = 4;
/// Configuration for a single torture run against one target tool.
pub struct TortureRun {
    /// Scenario to execute.
    pub mode: TortureMode,
    /// Name of the tool under test.
    pub target_tool: String,
    /// Number of parallel calls issued in `TortureMode::Parallel`.
    pub concurrency: usize,
    /// Per-call timeout handed to the client for every tool call.
    pub timeout: Duration,
    /// Upper bound for the entire run; derived from `timeout` via
    /// `GLOBAL_DEADLINE_FACTOR` in `TortureRun::new`.
    pub global_deadline: Duration,
    /// Transport label recorded in each finding's repro info.
    pub transport_name: String,
}
impl TortureRun {
#[must_use]
pub fn new(
mode: TortureMode,
target_tool: String,
concurrency: usize,
timeout: Duration,
transport_name: String,
) -> Self {
Self {
mode,
target_tool,
concurrency,
timeout,
global_deadline: timeout
.checked_mul(GLOBAL_DEADLINE_FACTOR)
.unwrap_or(timeout),
transport_name,
}
}
pub async fn execute<C: McpExec + ?Sized>(
self,
client: &C,
corpus: &Corpus,
reporter: &mut dyn Reporter,
) -> Result<TortureReport> {
reporter.on_run_start(&RunInfo {
kind: "torture",
total_iterations: self.concurrency as u64,
tools: vec![self.target_tool.clone()],
blocked: Vec::new(),
master_seed: None,
});
let token = CancellationToken::new();
let watchdog = {
let token = token.clone();
let deadline = self.global_deadline;
tokio::spawn(async move {
tokio::select! {
() = tokio::time::sleep(deadline) => {
token.cancel();
}
() = token.cancelled() => {
}
}
})
};
let findings = match self.mode {
TortureMode::Parallel => self.run_parallel(client, &token).await,
TortureMode::StateLeak => self.run_state_leak(client).await,
};
token.cancel();
let _ = watchdog.await;
let mut report = TortureReport::default();
for finding in findings {
corpus.write_finding(&finding)?;
reporter.on_finding(&finding);
report.findings_count += 1;
}
reporter.on_run_end();
Ok(report)
}
async fn run_parallel<C: McpExec + ?Sized>(
&self,
client: &C,
token: &CancellationToken,
) -> Vec<Finding> {
let payload = json!({});
let calls = (0..self.concurrency)
.map(|_| {
guarded_call(
client,
&self.target_tool,
payload.clone(),
self.timeout,
token,
)
})
.collect::<Vec<_>>();
let outcomes = join_all(calls).await;
let success_count = outcomes
.iter()
.filter(|outcome| matches!(outcome, CallOutcome::Ok(_)))
.count();
let mut findings = Vec::new();
if success_count < self.concurrency {
findings.push(Finding::new(
FindingKind::ProtocolError,
self.target_tool.clone(),
"parallel calls did not all complete successfully",
format!("{success_count}/{} calls completed", self.concurrency),
ReproInfo {
seed: 0,
tool_call: payload.clone(),
transport: self.transport_name.clone(),
composition_trail: Vec::new(),
},
));
}
if self.target_tool == "counter_inc" {
let counter =
match guarded_call(client, "counter_get", json!({}), self.timeout, token).await {
CallOutcome::Ok(result) => response_value(&result)
.get("counter")
.and_then(Value::as_u64)
.unwrap_or(0) as usize,
_ => 0,
};
if counter != self.concurrency {
findings.push(Finding::new(
FindingKind::PropertyFailure {
invariant: "counter_inc must be atomic".to_string(),
},
self.target_tool.clone(),
"counter lost updates under parallel calls",
format!("expected counter {}, observed {counter}", self.concurrency),
ReproInfo {
seed: 0,
tool_call: payload,
transport: self.transport_name.clone(),
composition_trail: Vec::new(),
},
));
}
}
findings
}
async fn run_state_leak<C: McpExec + ?Sized>(&self, client: &C) -> Vec<Finding> {
let set_payload = json!({"key": "secret", "value": "alice-data"});
let get_payload = json!({"key": "secret"});
let _ = client
.call_tool("session_set", set_payload, self.timeout)
.await;
let observed = match client
.call_tool("session_get", get_payload.clone(), self.timeout)
.await
{
CallOutcome::Ok(result) => response_value(&result),
other => json!({"unexpected": format!("{other:?}")}),
};
let leaked = observed.get("value").is_some_and(|value| !value.is_null());
if !leaked {
return Vec::new();
}
vec![Finding::new(
FindingKind::StateLeak,
"session_get",
"session data is visible outside its expected boundary",
format!(
"observed response: {}",
serde_json::to_string_pretty(&observed).unwrap_or_default()
),
ReproInfo {
seed: 0,
tool_call: get_payload,
transport: self.transport_name.clone(),
composition_trail: Vec::new(),
},
)]
}
}
/// Invokes `tool` on `client`, racing the call against `token`.
///
/// Returns `CallOutcome::Hang` immediately if the token is already
/// cancelled, or if cancellation wins the race before the call completes;
/// otherwise the call's own outcome is returned. `select!` polls its arms
/// in random order, so neither side is starved.
async fn guarded_call<C: McpExec + ?Sized>(
    client: &C,
    tool: &str,
    args: Value,
    timeout: Duration,
    token: &CancellationToken,
) -> CallOutcome {
    if token.is_cancelled() {
        return CallOutcome::Hang(timeout);
    }
    let pending = client.call_tool(tool, args, timeout);
    tokio::select! {
        () = token.cancelled() => CallOutcome::Hang(timeout),
        outcome = pending => outcome,
    }
}
/// Parses a human-friendly duration string: `"500ms"`, `"30s"`, or a bare
/// number of seconds such as `"5"`. Returns `None` for anything that does
/// not parse as an unsigned integer with one of those suffixes.
pub fn parse_duration(value: &str) -> Option<Duration> {
    // "ms" must be tried before the bare 's' suffix, because "500ms" also
    // ends with 's'.
    if let Some(milliseconds) = value.strip_suffix("ms") {
        milliseconds.parse().ok().map(Duration::from_millis)
    } else {
        // With or without a trailing 's', the remainder is whole seconds.
        let seconds = value.strip_suffix('s').unwrap_or(value);
        seconds.parse().ok().map(Duration::from_secs)
    }
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Each supported suffix maps to the right `Duration` constructor.
    #[test]
    fn parse_duration_handles_units() {
        assert_eq!(parse_duration("30s"), Some(Duration::from_secs(30)));
        assert_eq!(parse_duration("500ms"), Some(Duration::from_millis(500)));
        assert_eq!(parse_duration("5"), Some(Duration::from_secs(5)));
        assert_eq!(parse_duration("nope"), None);
    }

    /// Added coverage: degenerate inputs must be rejected, not panic or
    /// silently default.
    #[test]
    fn parse_duration_rejects_degenerate_inputs() {
        // A suffix with no digits leaves nothing to parse.
        assert_eq!(parse_duration(""), None);
        assert_eq!(parse_duration("s"), None);
        assert_eq!(parse_duration("ms"), None);
        // Negative values are rejected by the u64 parse.
        assert_eq!(parse_duration("-5s"), None);
    }
}