use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use reqwest::header::{HeaderName, HeaderValue};
use tokio::sync::{broadcast, oneshot};
use tokio_util::sync::CancellationToken;
use crate::cli::args::ExecutionArgs;
use crate::engine::mcp_server::McpServerDriver;
use crate::engine::phase::PhaseEngine;
use crate::engine::phase_loop::{PhaseLoop, PhaseLoopConfig};
use crate::engine::trace::SharedTrace;
use crate::engine::types::{ActorResult, AwaitExtractor};
use crate::error::EngineError;
use crate::loader::document_actors;
use crate::observability::events::{EventEmitter, ThoughtJackEvent};
use crate::orchestration::store::ExtractorStore;
use crate::protocol::{a2a_client, a2a_server, agui, mcp_client};
use crate::transport::http::HttpConfig;
use crate::transport::{HttpTransport, StdioTransport};
/// Test-only wrapper around a shareable closure that yields a transport.
///
/// Stored in `ActorConfig::transport_factory`; when present, the MCP server
/// actor calls the closure instead of constructing a `StdioTransport` for the
/// no-bind-address case. The closure must be `Send + Sync` so the factory can
/// be cloned and shared.
#[cfg(test)]
#[derive(Clone)]
pub struct TransportFactory(
// The wrapped constructor closure; invoked as `(factory.0)()`.
pub Arc<dyn Fn() -> Arc<dyn crate::transport::Transport> + Send + Sync>,
);
#[cfg(test)]
impl std::fmt::Debug for TransportFactory {
    /// Writes a fixed placeholder, because the wrapped closure itself cannot
    /// be formatted.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(out, "TransportFactory(<fn>)")
    }
}
/// Builds a [`TransportFactory`] whose closure returns a `NullTransport`
/// wrapped in an `Arc` every time it is called.
#[cfg(test)]
#[must_use]
pub fn null_transport_factory() -> TransportFactory {
    let make_null = || -> Arc<dyn crate::transport::Transport> {
        Arc::new(crate::transport::NullTransport)
    };
    TransportFactory(Arc::new(make_null))
}
/// Per-run configuration shared by all actor drivers.
///
/// Built from the CLI `ExecutionArgs` by `build_actor_config`; each actor
/// runner reads only the fields relevant to its mode.
#[derive(Debug, Clone)]
pub struct ActorConfig {
/// HTTP bind address for the MCP server actor (from `args.mcp_server`).
/// When `None`, the MCP server actor uses a stdio transport instead.
pub mcp_server_bind: Option<String>,
/// Endpoint for the `ag_ui_client` actor (`--agui-client-endpoint`); required in that mode.
pub agui_client_endpoint: Option<String>,
/// Bind address for the A2A server actor (`--a2a-server <ADDR:PORT>`); required in that mode.
pub a2a_server_bind: Option<String>,
/// Endpoint for the `a2a_client` actor (`--a2a-client-endpoint`); required in that mode.
pub a2a_client_endpoint: Option<String>,
/// Command line for the stdio MCP client (`--mcp-client-command`); shell-split
/// into program + arguments. Takes precedence over `mcp_client_endpoint`.
pub mcp_client_command: Option<String>,
/// Extra arguments (`--mcp-client-args`), shell-split and appended after any
/// arguments embedded in `mcp_client_command`.
pub mcp_client_args: Option<String>,
/// HTTP endpoint for the MCP client (`--mcp-client-endpoint`), used when no command is set.
pub mcp_client_endpoint: Option<String>,
/// Validated `(name, value)` pairs parsed from `--header KEY:VALUE` flags.
/// Client actors merge mode-specific environment headers on top of these.
pub headers: Vec<(String, String)>,
/// Forwarded unchanged to every driver constructor (from `args.raw_synthesize`).
pub raw_synthesize: bool,
/// Copied from `args.grace_period`; not read by the code in this part of the file.
pub grace_period: Option<Duration>,
/// Copied from `args.max_session`; not read by the code in this part of the file.
pub max_session: Duration,
/// Copied from `args.readiness_timeout`; not read by the code in this part of the file.
pub readiness_timeout: Duration,
/// Mirrors the `--context` flag (`args.context`).
pub context_mode: bool,
/// Provider settings assembled by `build_actor_config`; `Some` only when
/// `--context` is enabled.
pub context_provider_config: Option<crate::transport::provider::ProviderConfig>,
/// Copied from `args.max_turns`; not read by the code in this part of the file.
pub max_turns: Option<u32>,
/// Copied from `args.context_system_prompt`; not read by the code in this part of the file.
pub context_system_prompt: Option<String>,
/// Test-only override used by the MCP server actor in place of stdio.
#[cfg(test)]
pub transport_factory: Option<TransportFactory>,
}
/// Converts parsed CLI execution flags into an [`ActorConfig`].
///
/// Headers are validated in the order given; positions reported in error
/// messages are 1-based. The context provider configuration is assembled only
/// when `--context` is enabled.
///
/// # Errors
///
/// Returns [`EngineError::Driver`] when a `--header` value is malformed, when
/// `--context` is set without `--context-model`, or when no API key can be
/// resolved for context mode.
pub fn build_actor_config(args: &ExecutionArgs) -> Result<ActorConfig, EngineError> {
    let headers = args
        .header
        .iter()
        .enumerate()
        .map(|(idx, raw)| parse_cli_header(raw, idx + 1))
        .collect::<Result<Vec<(String, String)>, EngineError>>()?;

    let context_provider_config = args
        .context
        .then(
            || -> Result<crate::transport::provider::ProviderConfig, EngineError> {
                let Some(model) = args.context_model.clone() else {
                    return Err(EngineError::Driver(
                        "--context-model is required when --context is enabled".into(),
                    ));
                };
                let api_key = resolve_context_api_key(
                    args.context_api_key.as_deref(),
                    args.context_base_url.as_deref(),
                )?;
                Ok(crate::transport::provider::ProviderConfig {
                    provider_type: args.context_provider.clone(),
                    api_key,
                    model,
                    base_url: args.context_base_url.clone(),
                    temperature: args.context_temperature.unwrap_or(0.0),
                    max_tokens: args.context_max_tokens.or(Some(4096)),
                    timeout_secs: args.context_timeout.unwrap_or(120),
                })
            },
        )
        .transpose()?;

    let actor_config = ActorConfig {
        mcp_server_bind: args.mcp_server.clone(),
        agui_client_endpoint: args.agui_client_endpoint.clone(),
        a2a_server_bind: args.a2a_server.clone(),
        a2a_client_endpoint: args.a2a_client_endpoint.clone(),
        mcp_client_command: args.mcp_client_command.clone(),
        mcp_client_args: args.mcp_client_args.clone(),
        mcp_client_endpoint: args.mcp_client_endpoint.clone(),
        headers,
        raw_synthesize: args.raw_synthesize,
        grace_period: args.grace_period.map(Into::into),
        max_session: args.max_session.into(),
        readiness_timeout: args.readiness_timeout.into(),
        context_mode: args.context,
        context_provider_config,
        max_turns: args.max_turns,
        context_system_prompt: args.context_system_prompt.clone(),
        #[cfg(test)]
        transport_factory: None,
    };
    Ok(actor_config)
}
/// Picks the API key to hand to the context provider.
///
/// * An explicit key is used as-is, except that an empty string becomes the
///   placeholder `"no-key"`.
/// * With no key, the placeholder is used only when `base_url` points at a
///   local endpoint.
///
/// # Errors
///
/// Returns [`EngineError::Driver`] when no key was supplied and the endpoint
/// is absent or not local.
fn resolve_context_api_key(
    api_key: Option<&str>,
    base_url: Option<&str>,
) -> Result<String, EngineError> {
    match api_key {
        Some("") => Ok("no-key".to_string()),
        Some(key) => Ok(key.to_string()),
        None if base_url.is_some_and(is_local_endpoint) => Ok("no-key".to_string()),
        None => Err(EngineError::Driver(
            "--context-api-key is required (omit only for local endpoints like localhost, \
             127.0.0.1, or private IP ranges)"
                .into(),
        )),
    }
}
/// Reports whether `url` targets a local or private-network host.
///
/// The host is extracted from the authority component (an optional
/// `http://` / `https://` scheme, userinfo, port, path, query and fragment are
/// all ignored) and is considered local when it is:
///
/// * `localhost` or `host.docker.internal` (case-insensitive),
/// * an IPv4 loopback (`127.0.0.0/8`), unspecified (`0.0.0.0`) or RFC 1918
///   private address (`10/8`, `172.16/12`, `192.168/16`), or
/// * the IPv6 loopback address, with or without brackets around it.
///
/// Numeric ranges are checked on the *parsed* address, so DNS names that merely
/// start with a private-looking prefix (e.g. `10.example.com`) are not local.
fn is_local_endpoint(url: &str) -> bool {
    let without_scheme = url
        .strip_prefix("http://")
        .or_else(|| url.strip_prefix("https://"))
        .unwrap_or(url);

    // The authority ends at the first path, query or fragment delimiter.
    let authority = without_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or("");

    // Drop any `user:password@` prefix; the host follows the last '@'.
    let host_port = authority.rsplit('@').next().unwrap_or(authority);

    // IPv6 literals are bracketed so their colons are not mistaken for a port.
    let host = match host_port.strip_prefix('[') {
        Some(bracketed) => bracketed.split(']').next().unwrap_or(""),
        None => host_port.split(':').next().unwrap_or(""),
    };

    if host.eq_ignore_ascii_case("localhost") || host.eq_ignore_ascii_case("host.docker.internal")
    {
        return true;
    }

    match host.parse::<std::net::IpAddr>() {
        Ok(std::net::IpAddr::V4(ip)) => ip.is_loopback() || ip.is_private() || ip.is_unspecified(),
        Ok(std::net::IpAddr::V6(ip)) => ip.is_loopback(),
        Err(_) => false,
    }
}
/// Parses one `--header KEY:VALUE` argument into a trimmed `(name, value)` pair.
///
/// The split happens at the first `:` only, so values may themselves contain
/// colons. Both parts are trimmed and then checked against the HTTP header
/// name/value grammars. `position` is the 1-based index of the flag and is
/// used only in error messages.
///
/// # Errors
///
/// Returns [`EngineError::Driver`] when the colon is missing, the name is
/// empty, or either part is not valid for an HTTP header.
fn parse_cli_header(raw: &str, position: usize) -> Result<(String, String), EngineError> {
    let (key, value) = match raw.split_once(':') {
        Some((lhs, rhs)) => (lhs.trim(), rhs.trim()),
        None => {
            return Err(EngineError::Driver(format!(
                "invalid --header value at position {position}: '{raw}' (expected KEY:VALUE)"
            )));
        }
    };

    if key.is_empty() {
        return Err(EngineError::Driver(format!(
            "invalid --header value at position {position}: empty header name"
        )));
    }

    if HeaderName::from_bytes(key.as_bytes()).is_err() {
        return Err(EngineError::Driver(format!(
            "invalid --header value at position {position}: '{key}' is not a valid HTTP header name"
        )));
    }

    if HeaderValue::from_str(value).is_err() {
        return Err(EngineError::Driver(format!(
            "invalid --header value at position {position}: value for '{key}' is not a valid HTTP header value"
        )));
    }

    Ok((key.to_owned(), value.to_owned()))
}
/// Maps a client actor mode to the infix used in its `THOUGHTJACK_<PREFIX>_…`
/// environment variables. Modes without environment-header support yield `None`.
fn mode_env_prefix(mode: &str) -> Option<&'static str> {
    const PREFIXES: [(&str, &str); 3] = [
        ("mcp_client", "MCP_CLIENT"),
        ("a2a_client", "A2A_CLIENT"),
        ("ag_ui_client", "AGUI"),
    ];

    PREFIXES
        .iter()
        .find(|(known_mode, _)| *known_mode == mode)
        .map(|&(_, prefix)| prefix)
}
/// Layers `overrides` on top of `base`.
///
/// Each override removes every entry already collected whose name matches
/// case-insensitively (including earlier overrides) and is then appended, so
/// the last occurrence of a name wins and keeps its own spelling.
fn merge_headers(
    base: &[(String, String)],
    overrides: &[(String, String)],
) -> Vec<(String, String)> {
    overrides
        .iter()
        .fold(base.to_vec(), |mut combined, (name, value)| {
            combined.retain(|(existing, _)| !existing.eq_ignore_ascii_case(name));
            combined.push((name.clone(), value.clone()));
            combined
        })
}
/// Collects HTTP headers configured through environment variables for `prefix`.
///
/// * `THOUGHTJACK_<prefix>_AUTHORIZATION` becomes the `Authorization` header
///   and is always listed first when present.
/// * Every `THOUGHTJACK_<prefix>_HEADER_<NAME>` becomes a header named `<NAME>`
///   with underscores replaced by dashes. These follow in environment
///   iteration order, which is not specified.
///
/// The environment is read via `vars_os`, so variables that are not valid
/// Unicode are skipped instead of panicking. Variables with an empty `<NAME>`
/// are also skipped, because they would produce an empty header name.
fn collect_env_headers(prefix: &str) -> Vec<(String, String)> {
    let mut headers = Vec::new();

    let auth_var = format!("THOUGHTJACK_{prefix}_AUTHORIZATION");
    if let Ok(value) = std::env::var(&auth_var) {
        headers.push(("Authorization".to_string(), value));
    }

    let header_prefix = format!("THOUGHTJACK_{prefix}_HEADER_");
    for (key, value) in std::env::vars_os() {
        // Non-Unicode keys cannot match our ASCII prefix; ignore them.
        let Some(suffix) = key
            .to_str()
            .and_then(|k| k.strip_prefix(header_prefix.as_str()))
        else {
            continue;
        };
        if suffix.is_empty() {
            continue;
        }
        // A non-Unicode value cannot be represented as a `String`; skip it.
        let Some(value) = value.to_str() else {
            continue;
        };
        headers.push((suffix.replace('_', "-"), value.to_owned()));
    }

    headers
}
/// Produces the effective header list for an actor `mode`.
///
/// Modes with an environment prefix get their environment-supplied headers
/// merged over `base`; every other mode receives a plain copy of `base`.
fn resolve_headers_for_mode(base: &[(String, String)], mode: &str) -> Vec<(String, String)> {
    match mode_env_prefix(mode) {
        Some(prefix) => merge_headers(base, &collect_env_headers(prefix)),
        None => base.to_vec(),
    }
}
/// Splits the `--mcp-client-args` string into individual arguments using
/// shell-style quoting rules.
///
/// # Errors
///
/// Returns [`EngineError::Driver`] when the input cannot be split.
fn parse_mcp_client_args(raw: &str) -> Result<Vec<String>, EngineError> {
    match shlex::split(raw) {
        Some(tokens) => Ok(tokens),
        None => Err(EngineError::Driver(
            "invalid --mcp-client-args: unbalanced quotes".into(),
        )),
    }
}
/// Splits the `--mcp-client-command` string into the program to run and the
/// arguments embedded after it, using shell-style quoting rules.
///
/// # Errors
///
/// Returns [`EngineError::Driver`] when the input cannot be split or contains
/// no tokens at all.
fn parse_mcp_client_command(raw: &str) -> Result<(String, Vec<String>), EngineError> {
    let Some(mut tokens) = shlex::split(raw) else {
        return Err(EngineError::Driver(
            "invalid --mcp-client-command: unbalanced quotes".into(),
        ));
    };
    if tokens.is_empty() {
        return Err(EngineError::Driver(
            "invalid --mcp-client-command: empty command".into(),
        ));
    }
    // The first token is the program; whatever remains are its arguments.
    let command = tokens.remove(0);
    Ok((command, tokens))
}
async fn wait_for_readiness_gate(
actor_name: &str,
gate_rx: Option<broadcast::Receiver<()>>,
) -> Result<(), EngineError> {
if let Some(mut rx) = gate_rx {
tracing::debug!(actor = %actor_name, "waiting for server readiness gate");
rx.recv().await.map_err(|err| {
EngineError::Phase(format!(
"readiness gate closed before actor '{actor_name}' started: {err}"
))
})?;
tracing::debug!(actor = %actor_name, "readiness gate opened");
}
Ok(())
}
/// Everything a mode-specific actor runner needs, bundled so `run_actor` can
/// hand it over as a single value. Each runner destructures it on entry.
pub(crate) struct ActorRunContext<'a> {
/// Index of the actor within the document's actor list.
pub actor_index: usize,
/// The loaded document; moved into the actor's `PhaseEngine`.
pub document: oatf::Document,
/// Shared run configuration (endpoints, headers, flags).
pub config: &'a ActorConfig,
/// Trace handle passed on to the phase loop.
pub trace: SharedTrace,
/// Extractor store passed on to the phase loop.
pub extractor_store: ExtractorStore,
/// Await-extractor configuration, passed to the phase loop as
/// `await_extractors_config`.
pub await_config: HashMap<usize, Vec<AwaitExtractor>>,
/// Cancellation token passed on to the phase loop (and to the HTTP transport
/// when the MCP server binds one).
pub cancel: CancellationToken,
/// Optional one-shot used by the actor to signal that it is ready.
pub ready_tx: Option<oneshot::Sender<()>>,
/// Optional readiness gate that client actors wait on before starting;
/// server actors ignore it.
pub gate_rx: Option<broadcast::Receiver<()>>,
/// Event sink for actor lifecycle events.
pub events: &'a Arc<EventEmitter>,
}
/// Runs the actor at `actor_index` of `document` to completion.
///
/// Looks the actor up, emits `ActorInit`, packs the arguments into an
/// [`ActorRunContext`] and dispatches to the runner matching the actor's mode.
///
/// # Errors
///
/// Returns [`EngineError::Driver`] when `actor_index` is out of range or the
/// mode has no driver; otherwise propagates whatever the selected runner
/// returns.
#[allow(clippy::too_many_arguments, clippy::implicit_hasher)]
pub async fn run_actor(
    actor_index: usize,
    document: oatf::Document,
    config: &ActorConfig,
    trace: SharedTrace,
    extractor_store: ExtractorStore,
    await_config: HashMap<usize, Vec<AwaitExtractor>>,
    cancel: CancellationToken,
    ready_tx: Option<oneshot::Sender<()>>,
    gate_rx: Option<broadcast::Receiver<()>>,
    events: &Arc<EventEmitter>,
) -> Result<ActorResult, EngineError> {
    let actors = document_actors(&document);
    let Some(actor) = actors.get(actor_index) else {
        return Err(EngineError::Driver(format!(
            "actor index {actor_index} out of bounds (have {} actors)",
            actors.len()
        )));
    };
    // Copy out what is needed so `document` can be moved into the context.
    let (actor_name, mode) = (actor.name.clone(), actor.mode.clone());

    events.emit(ThoughtJackEvent::ActorInit {
        actor_name: actor_name.clone(),
        mode: mode.clone(),
    });

    let ctx = ActorRunContext {
        actor_index,
        document,
        config,
        trace,
        extractor_store,
        await_config,
        cancel,
        ready_tx,
        gate_rx,
        events,
    };

    let outcome = match mode.as_str() {
        "mcp_server" => run_mcp_server_actor(&actor_name, ctx).await,
        "mcp_client" => run_mcp_client_actor(&actor_name, ctx).await,
        "a2a_server" => run_a2a_server_actor(&actor_name, ctx).await,
        "a2a_client" => run_a2a_client_actor(&actor_name, ctx).await,
        "ag_ui_client" => run_agui_client_actor(&actor_name, ctx).await,
        other => Err(EngineError::Driver(format!(
            "driver for mode '{other}' not yet implemented"
        ))),
    };
    outcome
}
/// Runs an `mcp_server` actor.
///
/// Chooses the transport (HTTP when `config.mcp_server_bind` is set, stdio
/// otherwise, or the test transport factory under `cfg(test)`), announces
/// readiness, then drives the phase loop and emits the lifecycle events.
/// The readiness gate is not used by server actors.
///
/// # Errors
///
/// Returns [`EngineError::Driver`] when the HTTP bind fails, and propagates any
/// error from the phase loop.
async fn run_mcp_server_actor(
    actor_name: &str,
    ctx: ActorRunContext<'_>,
) -> Result<ActorResult, EngineError> {
    let ActorRunContext {
        actor_index,
        document,
        config,
        trace,
        extractor_store,
        await_config,
        cancel,
        ready_tx,
        gate_rx: _gate_rx,
        events,
    } = ctx;

    // Pick the transport together with the address reported in `ActorReady`.
    let (transport, bind_address): (Arc<dyn crate::transport::Transport>, String) =
        match config.mcp_server_bind.as_deref() {
            Some(bind_addr) => {
                let (http_transport, local_addr) = HttpTransport::bind(
                    HttpConfig {
                        bind_addr: bind_addr.to_owned(),
                        max_message_size: crate::transport::DEFAULT_MAX_MESSAGE_SIZE,
                    },
                    cancel.clone(),
                )
                .await
                .map_err(|e| EngineError::Driver(format!("HTTP bind failed: {e}")))?;
                (
                    Arc::new(http_transport) as Arc<dyn crate::transport::Transport>,
                    local_addr.to_string(),
                )
            }
            None => {
                #[cfg(test)]
                let stdio: Arc<dyn crate::transport::Transport> =
                    match config.transport_factory.as_ref() {
                        Some(factory) => (factory.0)(),
                        None => Arc::new(StdioTransport::new()),
                    };
                #[cfg(not(test))]
                let stdio: Arc<dyn crate::transport::Transport> =
                    Arc::new(StdioTransport::new());
                (stdio, "stdio".to_string())
            }
        };

    // Readiness is announced only once a transport exists.
    events.emit(ThoughtJackEvent::ActorReady {
        actor_name: actor_name.to_string(),
        bind_address,
    });
    if let Some(ready) = ready_tx {
        // The receiver may already be gone; that is not an error here.
        let _ = ready.send(());
    }

    let driver = McpServerDriver::new(Arc::clone(&transport), config.raw_synthesize);
    let entry_action_sender = driver.entry_action_sender();
    let engine = PhaseEngine::new(document, actor_index);

    events.emit(ThoughtJackEvent::ActorStarted {
        actor_name: actor_name.to_string(),
        phase_count: engine.actor().phases.len(),
    });

    let mut phase_loop = PhaseLoop::new(
        driver,
        engine,
        PhaseLoopConfig {
            trace,
            extractor_store,
            actor_name: actor_name.to_string(),
            await_extractors_config: await_config,
            cancel,
            entry_action_sender: Some(Box::new(entry_action_sender)),
            events: Arc::clone(events),
            tool_watch_tx: None,
            a2a_skill_tx: None,
            context_mode: false,
        },
    );
    let outcome = phase_loop.run().await?;

    events.emit(ThoughtJackEvent::ActorCompleted {
        actor_name: actor_name.to_string(),
        reason: outcome.termination.to_string(),
        phases_completed: outcome.phases_completed,
    });
    Ok(outcome)
}
/// Runs an `ag_ui_client` actor.
///
/// Requires `config.agui_client_endpoint`. Signals readiness, waits for the
/// server readiness gate, then builds the AG-UI driver (with CLI headers
/// overlaid by environment headers) and drives the phase loop.
///
/// # Errors
///
/// Returns [`EngineError::Driver`] when the endpoint is missing, and propagates
/// errors from the readiness gate and the phase loop.
async fn run_agui_client_actor(
    actor_name: &str,
    ctx: ActorRunContext<'_>,
) -> Result<ActorResult, EngineError> {
    let ActorRunContext {
        actor_index,
        document,
        config,
        trace,
        extractor_store,
        await_config,
        cancel,
        ready_tx,
        gate_rx,
        events,
    } = ctx;

    let Some(endpoint) = config.agui_client_endpoint.as_deref() else {
        return Err(EngineError::Driver(
            "ag_ui_client mode requires --agui-client-endpoint".to_string(),
        ));
    };

    // Clients have nothing to bind, so they report ready before gating.
    if let Some(ready) = ready_tx {
        let _ = ready.send(());
    }
    wait_for_readiness_gate(actor_name, gate_rx).await?;

    events.emit(ThoughtJackEvent::ActorReady {
        actor_name: actor_name.to_string(),
        bind_address: endpoint.to_string(),
    });

    let driver = agui::create_agui_driver(
        endpoint,
        resolve_headers_for_mode(&config.headers, "ag_ui_client"),
        config.raw_synthesize,
    );
    let engine = PhaseEngine::new(document, actor_index);

    events.emit(ThoughtJackEvent::ActorStarted {
        actor_name: actor_name.to_string(),
        phase_count: engine.actor().phases.len(),
    });

    let mut phase_loop = PhaseLoop::new(
        driver,
        engine,
        PhaseLoopConfig {
            trace,
            extractor_store,
            actor_name: actor_name.to_string(),
            await_extractors_config: await_config,
            cancel,
            entry_action_sender: None,
            events: Arc::clone(events),
            tool_watch_tx: None,
            a2a_skill_tx: None,
            context_mode: false,
        },
    );
    let outcome = phase_loop.run().await?;

    events.emit(ThoughtJackEvent::ActorCompleted {
        actor_name: actor_name.to_string(),
        reason: outcome.termination.to_string(),
        phases_completed: outcome.phases_completed,
    });
    Ok(outcome)
}
/// Runs an `a2a_server` actor.
///
/// Requires `config.a2a_server_bind`. The driver itself owns the readiness
/// sender and reports its bound address through a one-shot channel; this
/// function forwards that address as an `ActorReady` event while the phase
/// loop is running. The readiness gate is not used by server actors.
///
/// # Errors
///
/// Returns [`EngineError::Driver`] when no bind address is configured, and
/// propagates any error from the phase loop.
async fn run_a2a_server_actor(
actor_name: &str,
ctx: ActorRunContext<'_>,
) -> Result<ActorResult, EngineError> {
let ActorRunContext {
actor_index,
document,
config,
trace,
extractor_store,
await_config,
cancel,
ready_tx,
gate_rx: _gate_rx,
events,
} = ctx;
let bind_addr = config.a2a_server_bind.as_deref().ok_or_else(|| {
EngineError::Driver("A2A server actor requires --a2a-server <ADDR:PORT>".to_string())
})?;
// The driver sends the actual bound address here once it is known.
let (bound_addr_tx, mut bound_addr_rx) = oneshot::channel();
let mut driver = a2a_server::create_a2a_server_driver(bind_addr, config.raw_synthesize);
// Unlike the other runners, readiness is signalled by the driver, not here.
if let Some(tx) = ready_tx {
driver.set_ready_sender(tx);
}
driver.set_bound_addr_sender(bound_addr_tx);
let engine = PhaseEngine::new(document, actor_index);
let phase_count = engine.actor().phases.len();
events.emit(ThoughtJackEvent::ActorStarted {
actor_name: actor_name.to_string(),
phase_count,
});
let loop_config = PhaseLoopConfig {
trace,
extractor_store,
actor_name: actor_name.to_string(),
await_extractors_config: await_config,
cancel,
entry_action_sender: None,
events: Arc::clone(events),
tool_watch_tx: None,
a2a_skill_tx: None,
context_mode: false,
};
let mut phase_loop = PhaseLoop::new(driver, engine, loop_config);
// Pin the run future so it can be polled by reference across loop iterations.
let run_fut = phase_loop.run();
tokio::pin!(run_fut);
let mut ready_emitted = false;
// Drive the phase loop while also listening for the bound address.
let result = loop {
tokio::select! {
// Guarded by `!ready_emitted`, so this receiver is awaited at most once.
ready = &mut bound_addr_rx, if !ready_emitted => {
// An `Err` means the sender was dropped without reporting an address;
// no `ActorReady` event is emitted in that case.
if let Ok(addr) = ready {
events.emit(ThoughtJackEvent::ActorReady {
actor_name: actor_name.to_string(),
bind_address: addr.to_string(),
});
}
ready_emitted = true;
}
// The phase loop finishing ends the select loop; errors propagate via `?`.
run_result = &mut run_fut => {
break run_result?;
}
}
};
events.emit(ThoughtJackEvent::ActorCompleted {
actor_name: actor_name.to_string(),
reason: result.termination.to_string(),
phases_completed: result.phases_completed,
});
Ok(result)
}
/// Runs an `a2a_client` actor.
///
/// Requires `config.a2a_client_endpoint`. Signals readiness, waits for the
/// server readiness gate, then builds the A2A client driver (with CLI headers
/// overlaid by environment headers) and drives the phase loop.
///
/// # Errors
///
/// Returns [`EngineError::Driver`] when the endpoint is missing, and propagates
/// errors from the readiness gate and the phase loop.
async fn run_a2a_client_actor(
    actor_name: &str,
    ctx: ActorRunContext<'_>,
) -> Result<ActorResult, EngineError> {
    let ActorRunContext {
        actor_index,
        document,
        config,
        trace,
        extractor_store,
        await_config,
        cancel,
        ready_tx,
        gate_rx,
        events,
    } = ctx;

    let Some(endpoint) = config.a2a_client_endpoint.as_deref() else {
        return Err(EngineError::Driver(
            "a2a_client mode requires --a2a-client-endpoint".to_string(),
        ));
    };

    // Clients have nothing to bind, so they report ready before gating.
    if let Some(ready) = ready_tx {
        let _ = ready.send(());
    }
    wait_for_readiness_gate(actor_name, gate_rx).await?;

    events.emit(ThoughtJackEvent::ActorReady {
        actor_name: actor_name.to_string(),
        bind_address: endpoint.to_string(),
    });

    let driver = a2a_client::create_a2a_client_driver(
        endpoint,
        resolve_headers_for_mode(&config.headers, "a2a_client"),
        config.raw_synthesize,
    );
    let engine = PhaseEngine::new(document, actor_index);

    events.emit(ThoughtJackEvent::ActorStarted {
        actor_name: actor_name.to_string(),
        phase_count: engine.actor().phases.len(),
    });

    let mut phase_loop = PhaseLoop::new(
        driver,
        engine,
        PhaseLoopConfig {
            trace,
            extractor_store,
            actor_name: actor_name.to_string(),
            await_extractors_config: await_config,
            cancel,
            entry_action_sender: None,
            events: Arc::clone(events),
            tool_watch_tx: None,
            a2a_skill_tx: None,
            context_mode: false,
        },
    );
    let outcome = phase_loop.run().await?;

    events.emit(ThoughtJackEvent::ActorCompleted {
        actor_name: actor_name.to_string(),
        reason: outcome.termination.to_string(),
        phases_completed: outcome.phases_completed,
    });
    Ok(outcome)
}
async fn run_mcp_client_actor(
actor_name: &str,
ctx: ActorRunContext<'_>,
) -> Result<ActorResult, EngineError> {
let ActorRunContext {
actor_index,
document,
config,
trace,
extractor_store,
await_config,
cancel,
ready_tx,
gate_rx,
events,
} = ctx;
let bind_address = config
.mcp_client_command
.as_deref()
.map(|command| format!("stdio:{command}"))
.or_else(|| config.mcp_client_endpoint.clone())
.ok_or_else(|| {
EngineError::Driver(
"mcp_client mode requires --mcp-client-command (stdio) \
or --mcp-client-endpoint (HTTP)"
.to_string(),
)
})?;
if let Some(tx) = ready_tx {
let _ = tx.send(());
}
wait_for_readiness_gate(actor_name, gate_rx).await?;
let driver = match (
config.mcp_client_command.as_deref(),
config.mcp_client_endpoint.as_deref(),
) {
(Some(command_raw), _) => {
let (command, mut args) = parse_mcp_client_command(command_raw)?;
let extra_args: Vec<String> = config
.mcp_client_args
.as_deref()
.map(parse_mcp_client_args)
.transpose()?
.unwrap_or_default();
args.extend(extra_args);
mcp_client::create_mcp_client_driver(
Some(command.as_str()),
&args,
None,
&[],
config.raw_synthesize,
)?
}
(None, Some(endpoint)) => {
let headers = resolve_headers_for_mode(&config.headers, "mcp_client");
mcp_client::create_mcp_client_driver(
None,
&[],
Some(endpoint),
&headers,
config.raw_synthesize,
)?
}
(None, None) => unreachable!("validated above"),
};
events.emit(ThoughtJackEvent::ActorReady {
actor_name: actor_name.to_string(),
bind_address,
});
let engine = PhaseEngine::new(document, actor_index);
let phase_count = engine.actor().phases.len();
events.emit(ThoughtJackEvent::ActorStarted {
actor_name: actor_name.to_string(),
phase_count,
});
let loop_config = PhaseLoopConfig {
trace,
extractor_store,
actor_name: actor_name.to_string(),
await_extractors_config: await_config,
cancel,
entry_action_sender: None,
events: Arc::clone(events),
tool_watch_tx: None,
a2a_skill_tx: None,
context_mode: false,
};
let mut phase_loop = PhaseLoop::new(driver, engine, loop_config);
let result = phase_loop.run().await?;
events.emit(ThoughtJackEvent::ActorCompleted {
actor_name: actor_name.to_string(),
reason: result.termination.to_string(),
phases_completed: result.phases_completed,
});
Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn build_actor_config_maps_flags() {
let args = ExecutionArgs {
mcp_server: Some("0.0.0.0:8080".to_string()),
mcp_client_command: None,
mcp_client_args: None,
mcp_client_endpoint: None,
agui_client_endpoint: Some("http://localhost:3000".to_string()),
a2a_server: None,
a2a_client_endpoint: None,
grace_period: None,
max_session: humantime::Duration::from(Duration::from_secs(300)),
readiness_timeout: humantime::Duration::from(Duration::from_secs(30)),
output: None,
header: vec!["Authorization: Bearer token123".to_string()],
no_semantic: false,
raw_synthesize: true,
metrics_port: None,
events_file: None,
export_trace: None,
progress: crate::cli::args::ProgressLevel::Off,
context: false,
context_model: None,
context_api_key: None,
context_base_url: None,
context_provider: "openai".to_string(),
context_temperature: None,
context_max_tokens: None,
context_system_prompt: None,
context_timeout: None,
max_turns: None,
};
let config = build_actor_config(&args).expect("valid headers should parse");
assert_eq!(config.mcp_server_bind, Some("0.0.0.0:8080".to_string()));
assert_eq!(
config.agui_client_endpoint,
Some("http://localhost:3000".to_string())
);
assert!(config.raw_synthesize);
assert_eq!(config.headers.len(), 1);
assert_eq!(config.headers[0].0, "Authorization");
assert_eq!(config.headers[0].1, "Bearer token123");
assert_eq!(config.max_session, Duration::from_secs(300));
}
#[tokio::test]
async fn unsupported_mode_errors() {
let yaml = r#"
oatf: "0.1"
attack:
name: test
execution:
actors:
- name: unknown_actor
mode: future_protocol_client
phases:
- name: setup
state:
tools:
- name: test_tool
description: "test"
inputSchema:
type: object
"#;
let doc = oatf::load(yaml).unwrap().document;
let config = ActorConfig {
mcp_server_bind: None,
agui_client_endpoint: None,
a2a_server_bind: None,
a2a_client_endpoint: None,
mcp_client_command: None,
mcp_client_args: None,
mcp_client_endpoint: None,
headers: vec![],
raw_synthesize: false,
grace_period: None,
max_session: Duration::from_secs(300),
readiness_timeout: Duration::from_secs(30),
context_mode: false,
context_provider_config: None,
max_turns: None,
context_system_prompt: None,
transport_factory: None,
};
let result = run_actor(
0,
doc,
&config,
SharedTrace::new(),
ExtractorStore::new(),
HashMap::new(),
CancellationToken::new(),
None,
None,
&Arc::new(EventEmitter::noop()),
)
.await;
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.to_string().contains("not yet implemented"),
"Expected 'not yet implemented', got: {err}"
);
}
#[tokio::test]
async fn actor_index_out_of_bounds_returns_error() {
let yaml = r#"
oatf: "0.1"
attack:
name: test
execution:
actors:
- name: server_actor
mode: mcp_server
phases:
- name: setup
state:
tools: []
"#;
let doc = oatf::load(yaml).unwrap().document;
let config = ActorConfig {
mcp_server_bind: None,
agui_client_endpoint: None,
a2a_server_bind: None,
a2a_client_endpoint: None,
mcp_client_command: None,
mcp_client_args: None,
mcp_client_endpoint: None,
headers: vec![],
raw_synthesize: false,
grace_period: None,
max_session: Duration::from_secs(300),
readiness_timeout: Duration::from_secs(30),
context_mode: false,
context_provider_config: None,
max_turns: None,
context_system_prompt: None,
transport_factory: Some(null_transport_factory()),
};
let result = run_actor(
1,
doc,
&config,
SharedTrace::new(),
ExtractorStore::new(),
HashMap::new(),
CancellationToken::new(),
None,
None,
&Arc::new(EventEmitter::noop()),
)
.await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("actor index 1 out of bounds"),
"unexpected error: {err}"
);
assert!(err.contains("have 1 actors"), "unexpected error: {err}");
}
#[tokio::test]
async fn agui_client_requires_endpoint() {
let yaml = r#"
oatf: "0.1"
attack:
name: test
execution:
actors:
- name: agui_actor
mode: ag_ui_client
phases:
- name: setup
state:
run_agent_input:
messages:
- role: user
content: "Hello"
"#;
let doc = oatf::load(yaml).unwrap().document;
let config = ActorConfig {
mcp_server_bind: None,
agui_client_endpoint: None,
a2a_server_bind: None,
a2a_client_endpoint: None,
mcp_client_command: None,
mcp_client_args: None,
mcp_client_endpoint: None,
headers: vec![],
raw_synthesize: false,
grace_period: None,
max_session: Duration::from_secs(300),
readiness_timeout: Duration::from_secs(30),
context_mode: false,
context_provider_config: None,
max_turns: None,
context_system_prompt: None,
transport_factory: None,
};
let result = run_actor(
0,
doc,
&config,
SharedTrace::new(),
ExtractorStore::new(),
HashMap::new(),
CancellationToken::new(),
None,
None,
&Arc::new(EventEmitter::noop()),
)
.await;
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.to_string().contains("--agui-client-endpoint"),
"Expected endpoint error, got: {err}"
);
}
#[tokio::test]
async fn agui_client_fails_when_readiness_gate_closes() {
let yaml = r#"
oatf: "0.1"
attack:
name: test
execution:
actors:
- name: agui_actor
mode: ag_ui_client
phases:
- name: setup
state:
run_agent_input:
messages:
- role: user
content: "Hello"
"#;
let doc = oatf::load(yaml).unwrap().document;
let config = ActorConfig {
mcp_server_bind: None,
agui_client_endpoint: Some("http://localhost:3000".to_string()),
a2a_server_bind: None,
a2a_client_endpoint: None,
mcp_client_command: None,
mcp_client_args: None,
mcp_client_endpoint: None,
headers: vec![],
raw_synthesize: false,
grace_period: None,
max_session: Duration::from_secs(300),
readiness_timeout: Duration::from_secs(30),
context_mode: false,
context_provider_config: None,
max_turns: None,
context_system_prompt: None,
transport_factory: None,
};
let (tx, rx) = tokio::sync::broadcast::channel(1);
drop(tx);
let result = run_actor(
0,
doc,
&config,
SharedTrace::new(),
ExtractorStore::new(),
HashMap::new(),
CancellationToken::new(),
None,
Some(rx),
&Arc::new(EventEmitter::noop()),
)
.await;
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.to_string().contains("readiness gate closed"),
"Expected gate closed error, got: {err}"
);
}
#[tokio::test]
async fn a2a_client_fails_when_readiness_gate_closes() {
let yaml = r#"
oatf: "0.1"
attack:
name: test
execution:
actors:
- name: a2a_actor
mode: a2a_client
phases:
- name: send
state:
task_message:
role: user
parts:
- kind: text
text: "Hello"
"#;
let doc = oatf::load(yaml).unwrap().document;
let config = ActorConfig {
mcp_server_bind: None,
agui_client_endpoint: None,
a2a_server_bind: None,
a2a_client_endpoint: Some("http://localhost:9090".to_string()),
mcp_client_command: None,
mcp_client_args: None,
mcp_client_endpoint: None,
headers: vec![],
raw_synthesize: false,
grace_period: None,
max_session: Duration::from_secs(300),
readiness_timeout: Duration::from_secs(30),
context_mode: false,
context_provider_config: None,
max_turns: None,
context_system_prompt: None,
transport_factory: None,
};
let (tx, rx) = tokio::sync::broadcast::channel(1);
drop(tx);
let result = run_actor(
0,
doc,
&config,
SharedTrace::new(),
ExtractorStore::new(),
HashMap::new(),
CancellationToken::new(),
None,
Some(rx),
&Arc::new(EventEmitter::noop()),
)
.await;
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.to_string().contains("readiness gate closed"),
"Expected gate closed error, got: {err}"
);
}
#[tokio::test]
async fn mcp_client_fails_when_readiness_gate_closes() {
let yaml = r#"
oatf: "0.1"
attack:
name: test
execution:
actors:
- name: mcp_actor
mode: mcp_client
phases:
- name: probe
state:
actions:
- list_tools
"#;
let doc = oatf::load(yaml).unwrap().document;
let config = ActorConfig {
mcp_server_bind: None,
agui_client_endpoint: None,
a2a_server_bind: None,
a2a_client_endpoint: None,
mcp_client_command: None,
mcp_client_args: None,
mcp_client_endpoint: Some("http://localhost:8080/mcp".to_string()),
headers: vec![],
raw_synthesize: false,
grace_period: None,
max_session: Duration::from_secs(300),
readiness_timeout: Duration::from_secs(30),
context_mode: false,
context_provider_config: None,
max_turns: None,
context_system_prompt: None,
transport_factory: None,
};
let (tx, rx) = tokio::sync::broadcast::channel(1);
drop(tx);
let result = run_actor(
0,
doc,
&config,
SharedTrace::new(),
ExtractorStore::new(),
HashMap::new(),
CancellationToken::new(),
None,
Some(rx),
&Arc::new(EventEmitter::noop()),
)
.await;
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.to_string().contains("readiness gate closed"),
"Expected gate closed error, got: {err}"
);
}
#[tokio::test]
async fn a2a_server_mode_recognized() {
let yaml = r#"
oatf: "0.1"
attack:
name: test
execution:
actors:
- name: a2a_actor
mode: a2a_server
phases:
- name: serve
state:
agent_card:
name: "Test Agent"
skills: []
defaultInputModes: ["text/plain"]
defaultOutputModes: ["text/plain"]
"#;
let doc = oatf::load(yaml).unwrap().document;
let config = ActorConfig {
mcp_server_bind: None,
agui_client_endpoint: None,
a2a_server_bind: Some("127.0.0.1:0".to_string()),
a2a_client_endpoint: None,
mcp_client_command: None,
mcp_client_args: None,
mcp_client_endpoint: None,
headers: vec![],
raw_synthesize: false,
grace_period: None,
max_session: Duration::from_secs(300),
readiness_timeout: Duration::from_secs(30),
context_mode: false,
context_provider_config: None,
max_turns: None,
context_system_prompt: None,
transport_factory: None,
};
let cancel = CancellationToken::new();
let cancel_clone = cancel.clone();
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(100)).await;
cancel_clone.cancel();
});
let result = run_actor(
0,
doc,
&config,
SharedTrace::new(),
ExtractorStore::new(),
HashMap::new(),
cancel,
None,
None,
&Arc::new(EventEmitter::noop()),
)
.await;
let actor_result = result.expect("a2a_server actor should run to a graceful termination");
assert_eq!(actor_result.actor_name, "a2a_actor");
assert!(
matches!(
actor_result.termination,
crate::engine::types::TerminationReason::Cancelled
| crate::engine::types::TerminationReason::TerminalPhaseReached
),
"unexpected termination: {:?}",
actor_result.termination
);
}
#[tokio::test]
async fn a2a_client_requires_endpoint() {
let yaml = r#"
oatf: "0.1"
attack:
name: test
execution:
actors:
- name: a2a_actor
mode: a2a_client
phases:
- name: send
state:
task_message:
role: user
parts:
- kind: text
text: "Hello"
"#;
let doc = oatf::load(yaml).unwrap().document;
let config = ActorConfig {
mcp_server_bind: None,
agui_client_endpoint: None,
a2a_server_bind: None,
a2a_client_endpoint: None,
mcp_client_command: None,
mcp_client_args: None,
mcp_client_endpoint: None,
headers: vec![],
raw_synthesize: false,
grace_period: None,
max_session: Duration::from_secs(300),
readiness_timeout: Duration::from_secs(30),
context_mode: false,
context_provider_config: None,
max_turns: None,
context_system_prompt: None,
transport_factory: None,
};
let result = run_actor(
0,
doc,
&config,
SharedTrace::new(),
ExtractorStore::new(),
HashMap::new(),
CancellationToken::new(),
None,
None,
&Arc::new(EventEmitter::noop()),
)
.await;
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.to_string().contains("--a2a-client-endpoint"),
"Expected endpoint error, got: {err}"
);
}
#[tokio::test]
async fn mcp_server_stdio_runs_to_completion() {
let yaml = r#"
oatf: "0.1"
attack:
name: test
execution:
mode: mcp_server
state:
tools:
- name: test_tool
description: "test"
inputSchema:
type: object
"#;
let doc = oatf::load(yaml).unwrap().document;
let config = ActorConfig {
mcp_server_bind: None,
agui_client_endpoint: None,
a2a_server_bind: None,
a2a_client_endpoint: None,
mcp_client_command: None,
mcp_client_args: None,
mcp_client_endpoint: None,
headers: vec![],
raw_synthesize: false,
grace_period: None,
max_session: Duration::from_secs(5),
readiness_timeout: Duration::from_secs(5),
context_mode: false,
context_provider_config: None,
max_turns: None,
context_system_prompt: None,
transport_factory: Some(null_transport_factory()),
};
let cancel = CancellationToken::new();
let cancel_clone = cancel.clone();
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(200)).await;
cancel_clone.cancel();
});
let result = tokio::time::timeout(
Duration::from_secs(10),
run_actor(
0,
doc,
&config,
SharedTrace::new(),
ExtractorStore::new(),
HashMap::new(),
cancel,
None,
None,
&Arc::new(EventEmitter::noop()),
),
)
.await
.expect("test timed out — stdio actor did not respond to cancellation within 10s");
let actor_result = result.expect("mcp_server actor should run to a graceful termination");
assert_eq!(actor_result.actor_name, "default");
assert!(
matches!(
actor_result.termination,
crate::engine::types::TerminationReason::Cancelled
| crate::engine::types::TerminationReason::TerminalPhaseReached
),
"unexpected termination: {:?}",
actor_result.termination
);
}
/// Three `--header` values must come back as three `(name, value)` pairs.
///
/// The expectations cover a plain header, a value that itself contains colons
/// (only the first colon separates name from value), and a name followed by
/// whitespace before the colon (expected to be trimmed away).
#[test]
fn build_actor_config_parses_multiple_headers() {
    let header = vec![
        String::from("Authorization: Bearer token123"),
        String::from("X-Custom: value with : colons"),
        String::from("Accept : application/json"),
    ];
    let cli_args = ExecutionArgs {
        // No transport is selected; only header parsing is under test.
        mcp_server: None,
        mcp_client_command: None,
        mcp_client_args: None,
        mcp_client_endpoint: None,
        agui_client_endpoint: None,
        a2a_server: None,
        a2a_client_endpoint: None,
        grace_period: None,
        max_session: Duration::from_secs(300).into(),
        readiness_timeout: Duration::from_secs(30).into(),
        output: None,
        header,
        no_semantic: false,
        raw_synthesize: false,
        metrics_port: None,
        events_file: None,
        export_trace: None,
        progress: crate::cli::args::ProgressLevel::Off,
        // Context mode is disabled.
        context: false,
        context_model: None,
        context_api_key: None,
        context_base_url: None,
        context_provider: String::from("openai"),
        context_temperature: None,
        context_max_tokens: None,
        context_system_prompt: None,
        context_timeout: None,
        max_turns: None,
    };

    let parsed = build_actor_config(&cli_args).expect("valid headers should parse");

    // Comparing the whole list checks both the count and every name/value pair.
    let pairs: Vec<(&str, &str)> = parsed
        .headers
        .iter()
        .map(|(name, value)| (name.as_str(), value.as_str()))
        .collect();
    assert_eq!(
        pairs,
        [
            ("Authorization", "Bearer token123"),
            ("X-Custom", "value with : colons"),
            ("Accept", "application/json"),
        ]
    );
}
/// A `--header` entry with no colon must make `build_actor_config` fail, even
/// when another entry in the same list is well-formed.
#[test]
fn build_actor_config_header_without_colon_rejected() {
    let header = vec![String::from("NoColonHere"), String::from("Valid: header")];
    let cli_args = ExecutionArgs {
        mcp_server: None,
        mcp_client_command: None,
        mcp_client_args: None,
        mcp_client_endpoint: None,
        agui_client_endpoint: None,
        a2a_server: None,
        a2a_client_endpoint: None,
        grace_period: None,
        max_session: Duration::from_secs(300).into(),
        readiness_timeout: Duration::from_secs(30).into(),
        output: None,
        header,
        no_semantic: false,
        raw_synthesize: false,
        metrics_port: None,
        events_file: None,
        export_trace: None,
        progress: crate::cli::args::ProgressLevel::Off,
        context: false,
        context_model: None,
        context_api_key: None,
        context_base_url: None,
        context_provider: String::from("openai"),
        context_temperature: None,
        context_max_tokens: None,
        context_system_prompt: None,
        context_timeout: None,
        max_turns: None,
    };

    let outcome = build_actor_config(&cli_args);
    let err = outcome.expect_err("missing colon should be rejected");

    // The rendered error is expected to name the required KEY:VALUE shape.
    let mentions_format = err.to_string().contains("expected KEY:VALUE");
    assert!(mentions_format, "unexpected error: {err}");
}
/// A header whose name contains a space (`Bad Name`) must be rejected with an
/// error that mentions the header-name requirement.
#[test]
fn build_actor_config_invalid_header_name_rejected() {
    let header = vec![String::from("Bad Name: value")];
    let cli_args = ExecutionArgs {
        mcp_server: None,
        mcp_client_command: None,
        mcp_client_args: None,
        mcp_client_endpoint: None,
        agui_client_endpoint: None,
        a2a_server: None,
        a2a_client_endpoint: None,
        grace_period: None,
        max_session: Duration::from_secs(300).into(),
        readiness_timeout: Duration::from_secs(30).into(),
        output: None,
        header,
        no_semantic: false,
        raw_synthesize: false,
        metrics_port: None,
        events_file: None,
        export_trace: None,
        progress: crate::cli::args::ProgressLevel::Off,
        context: false,
        context_model: None,
        context_api_key: None,
        context_base_url: None,
        context_provider: String::from("openai"),
        context_temperature: None,
        context_max_tokens: None,
        context_system_prompt: None,
        context_timeout: None,
        max_turns: None,
    };

    let outcome = build_actor_config(&cli_args);
    let err = outcome.expect_err("invalid header name should be rejected");

    let names_the_problem = err.to_string().contains("valid HTTP header name");
    assert!(names_the_problem, "unexpected error: {err}");
}
/// A header value carrying an embedded CR/LF sequence must be rejected with an
/// error that mentions the header-value requirement.
#[test]
fn build_actor_config_invalid_header_value_rejected() {
    let header = vec![String::from("X-Test: value\r\ninjected")];
    let cli_args = ExecutionArgs {
        mcp_server: None,
        mcp_client_command: None,
        mcp_client_args: None,
        mcp_client_endpoint: None,
        agui_client_endpoint: None,
        a2a_server: None,
        a2a_client_endpoint: None,
        grace_period: None,
        max_session: Duration::from_secs(300).into(),
        readiness_timeout: Duration::from_secs(30).into(),
        output: None,
        header,
        no_semantic: false,
        raw_synthesize: false,
        metrics_port: None,
        events_file: None,
        export_trace: None,
        progress: crate::cli::args::ProgressLevel::Off,
        context: false,
        context_model: None,
        context_api_key: None,
        context_base_url: None,
        context_provider: String::from("openai"),
        context_temperature: None,
        context_max_tokens: None,
        context_system_prompt: None,
        context_timeout: None,
        max_turns: None,
    };

    let outcome = build_actor_config(&cli_args);
    let err = outcome.expect_err("invalid header value should be rejected");

    let names_the_problem = err.to_string().contains("valid HTTP header value");
    assert!(names_the_problem, "unexpected error: {err}");
}
/// With no transport flags and no headers supplied, the resulting `ActorConfig`
/// must leave every bind/endpoint/command unset, keep `raw_synthesize` off,
/// carry no headers, and have no grace period.
#[test]
fn build_actor_config_defaults() {
    let cli_args = ExecutionArgs {
        mcp_server: None,
        mcp_client_command: None,
        mcp_client_args: None,
        mcp_client_endpoint: None,
        agui_client_endpoint: None,
        a2a_server: None,
        a2a_client_endpoint: None,
        grace_period: None,
        max_session: Duration::from_secs(300).into(),
        readiness_timeout: Duration::from_secs(30).into(),
        output: None,
        header: Vec::new(),
        no_semantic: false,
        raw_synthesize: false,
        metrics_port: None,
        events_file: None,
        export_trace: None,
        progress: crate::cli::args::ProgressLevel::Off,
        context: false,
        context_model: None,
        context_api_key: None,
        context_base_url: None,
        context_provider: String::from("openai"),
        context_temperature: None,
        context_max_tokens: None,
        context_system_prompt: None,
        context_timeout: None,
        max_turns: None,
    };

    // Pull out exactly the fields this test makes claims about.
    let ActorConfig {
        mcp_server_bind,
        agui_client_endpoint,
        a2a_server_bind,
        a2a_client_endpoint,
        mcp_client_command,
        mcp_client_args,
        mcp_client_endpoint,
        raw_synthesize,
        headers,
        grace_period,
        ..
    } = build_actor_config(&cli_args).expect("empty header list should be valid");

    assert!(mcp_server_bind.is_none());
    assert!(agui_client_endpoint.is_none());
    assert!(a2a_server_bind.is_none());
    assert!(a2a_client_endpoint.is_none());
    assert!(mcp_client_command.is_none());
    assert!(mcp_client_args.is_none());
    assert!(mcp_client_endpoint.is_none());
    assert!(!raw_synthesize);
    assert!(headers.is_empty());
    assert!(grace_period.is_none());
}
/// An `mcp_client` actor started with neither `mcp_client_command` nor
/// `mcp_client_endpoint` configured must fail, and the error text must contain
/// "mcp_client mode requires".
#[tokio::test]
async fn mcp_client_requires_command_or_endpoint() {
// Document with a single actor in `mcp_client` mode and one `probe` phase.
let yaml = r#"
oatf: "0.1"
attack:
name: test
execution:
actors:
- name: mcp_actor
mode: mcp_client
phases:
- name: probe
state:
actions:
- list_tools
"#;
let doc = oatf::load(yaml).unwrap().document;
// Every transport-related option is left unset, and no test transport factory
// is injected.
let config = ActorConfig {
mcp_server_bind: None,
agui_client_endpoint: None,
a2a_server_bind: None,
a2a_client_endpoint: None,
mcp_client_command: None,
mcp_client_args: None,
mcp_client_endpoint: None,
headers: vec![],
raw_synthesize: false,
grace_period: None,
max_session: Duration::from_secs(300),
readiness_timeout: Duration::from_secs(30),
context_mode: false,
context_provider_config: None,
max_turns: None,
context_system_prompt: None,
transport_factory: None,
};
let result = run_actor(
0,
doc,
&config,
SharedTrace::new(),
ExtractorStore::new(),
HashMap::new(),
CancellationToken::new(),
None,
None,
&Arc::new(EventEmitter::noop()),
)
.await;
assert!(result.is_err());
let err = result.unwrap_err();
// The message is matched by substring only.
assert!(
err.to_string().contains("mcp_client mode requires"),
"Expected transport error, got: {err}"
);
}
/// For the `mcp_server` mode, `resolve_headers_for_mode` must hand back the
/// input headers unchanged.
#[test]
fn resolve_headers_passthrough_for_server_mode() {
    let cli_headers = vec![(String::from("X-Custom"), String::from("val"))];

    let outcome = resolve_headers_for_mode(&cli_headers, "mcp_server");

    assert_eq!(
        outcome, cli_headers,
        "server modes should pass through unchanged"
    );
}
/// Client modes map to an environment-variable prefix; server modes map to `None`.
#[test]
fn mode_env_prefix_maps_client_modes() {
    // (mode name, expected prefix) pairs, checked in order.
    let cases = [
        ("mcp_client", Some("MCP_CLIENT")),
        ("a2a_client", Some("A2A_CLIENT")),
        ("ag_ui_client", Some("AGUI")),
        ("mcp_server", None),
        ("a2a_server", None),
    ];

    for (mode, expected_prefix) in cases {
        assert_eq!(mode_env_prefix(mode), expected_prefix);
    }
}
/// An override with the same name as a base header replaces it rather than
/// adding a second entry.
#[test]
fn merge_headers_override_replaces_base() {
    let pair = |name: &str, value: &str| (name.to_string(), value.to_string());
    let from_cli = vec![pair("Authorization", "Bearer cli-token")];
    let from_env = vec![pair("Authorization", "Bearer env-token")];

    let combined = merge_headers(&from_cli, &from_env);

    assert_eq!(combined.len(), 1);
    let (_, value) = &combined[0];
    assert_eq!(value.as_str(), "Bearer env-token");
}
/// Header names are matched without regard to case, and the surviving entry
/// carries the override's spelling and value.
#[test]
fn merge_headers_case_insensitive() {
    let pair = |name: &str, value: &str| (name.to_string(), value.to_string());
    let from_cli = vec![pair("authorization", "Bearer cli")];
    let from_env = vec![pair("Authorization", "Bearer env")];

    let combined = merge_headers(&from_cli, &from_env);

    // One entry, with both the name and the value taken from the override.
    assert_eq!(combined.len(), 1);
    let (name, value) = &combined[0];
    assert_eq!(name.as_str(), "Authorization");
    assert_eq!(value.as_str(), "Bearer env");
}
/// An override whose name is absent from the base list is appended after the
/// existing base entries.
#[test]
fn merge_headers_appends_new() {
    let pair = |name: &str, value: &str| (name.to_string(), value.to_string());
    let from_cli = vec![pair("Accept", "application/json")];
    let from_env = vec![pair("X-Api-Key", "key-123")];

    let combined = merge_headers(&from_cli, &from_env);

    // Checking the ordered list of names covers both the count and the order.
    let names: Vec<&str> = combined.iter().map(|(name, _)| name.as_str()).collect();
    assert_eq!(names, ["Accept", "X-Api-Key"]);
}
/// Merging with an empty override list yields the base list unchanged.
#[test]
fn merge_headers_empty_override_preserves_base() {
    let pair = |name: &str, value: &str| (name.to_string(), value.to_string());
    let from_cli = vec![pair("Accept", "application/json"), pair("X-Custom", "value")];

    let combined = merge_headers(&from_cli, &[]);

    assert_eq!(combined, from_cli);
}
/// With an empty base list, the merged result consists of the overrides alone.
#[test]
fn merge_headers_empty_base_uses_overrides() {
    let from_env = vec![(String::from("Authorization"), String::from("Bearer token"))];

    let combined = merge_headers(&[], &from_env);

    assert_eq!(combined.len(), 1);
    let (_, value) = &combined[0];
    assert_eq!(value.as_str(), "Bearer token");
}
/// Double- and single-quoted segments each become one argument with the quotes
/// removed; unquoted text is split as usual.
#[test]
fn parse_mcp_client_args_respects_quotes() {
    let expected: Vec<String> = ["--flag", "two words", "three words"]
        .iter()
        .map(|token| String::from(*token))
        .collect();

    let tokens = parse_mcp_client_args(r#"--flag "two words" 'three words'"#).unwrap();

    assert_eq!(tokens, expected);
}
/// An argument string with an unterminated double quote must be rejected with
/// an error that names the `--mcp-client-args` flag.
///
/// On failure the assertion prints the actual error text, so a changed message
/// can be diagnosed straight from the test output.
#[test]
fn parse_mcp_client_args_rejects_unbalanced_quotes() {
    let err = parse_mcp_client_args("\"oops")
        .expect_err("unbalanced quote in --mcp-client-args should be rejected");
    assert!(
        err.to_string().contains("invalid --mcp-client-args"),
        "unexpected error: {err}"
    );
}
/// A command string containing arguments is split into the program name and
/// the remaining arguments, in order.
#[test]
fn parse_mcp_client_command_supports_inline_args() {
    let expected_rest: Vec<String> = ["-y", "@modelcontextprotocol/server-everything"]
        .iter()
        .map(|token| String::from(*token))
        .collect();

    let parsed = parse_mcp_client_command("npx -y @modelcontextprotocol/server-everything");
    let (program, rest) = parsed.unwrap();

    assert_eq!(program, "npx");
    assert_eq!(rest, expected_rest);
}
/// An empty command string must be rejected with an error that names the
/// `--mcp-client-command` flag.
///
/// On failure the assertion prints the actual error text, so a changed message
/// can be diagnosed straight from the test output.
#[test]
fn parse_mcp_client_command_rejects_empty() {
    let err = parse_mcp_client_command("")
        .expect_err("empty --mcp-client-command should be rejected");
    assert!(
        err.to_string().contains("invalid --mcp-client-command"),
        "unexpected error: {err}"
    );
}
/// A command string with an unterminated double quote must be rejected with an
/// error that names the `--mcp-client-command` flag.
///
/// On failure the assertion prints the actual error text, so a changed message
/// can be diagnosed straight from the test output.
#[test]
fn parse_mcp_client_command_rejects_unbalanced_quotes() {
    let err = parse_mcp_client_command("\"oops")
        .expect_err("unbalanced quote in --mcp-client-command should be rejected");
    assert!(
        err.to_string().contains("invalid --mcp-client-command"),
        "unexpected error: {err}"
    );
}
}