use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use async_trait::async_trait;
use futures_util::stream::StreamExt;
use tokio::sync::{broadcast, Notify};
use tokio_stream::wrappers::BroadcastStream;
pub use crate::agent::MockAgentConfig;
use crate::connections::{Connection, ConnectionStrategy, StepStream};
use crate::content::Content;
use crate::error::{Error, Result};
use crate::hooks::{HookRunner, SessionContext};
use crate::tools::ToolRunner;
use crate::types::{
Step, StepSource, StepStatus, StepTarget, StepType, ToolCall, ToolResult, UsageMetadata,
};
/// Capacity of each mock connection's step broadcast channel.
///
/// A subscriber that falls further behind than this receives a lag error item,
/// which `subscribe_steps` surfaces as a "mock step lag" error.
const STEP_BROADCAST_CAPACITY: usize = 256;
/// One scripted action inside a [`ScriptedTurn`]; actions are replayed in order.
#[derive(Debug, Clone)]
enum ScriptAction {
/// Streamed as a single text delta step; also concatenated into the turn's
/// terminal content.
Text(String),
/// Emitted as a tool-call step and, when a tool runner is configured,
/// dispatched through it.
ToolCall {
/// Name of the tool to invoke.
name: String,
/// JSON arguments handed to the tool.
args: serde_json::Value,
},
}
/// A single scripted model turn: an ordered list of actions plus optional usage.
///
/// Built fluently with [`ScriptedTurn::text`], [`ScriptedTurn::tool_call`] and
/// [`ScriptedTurn::with_usage`]. The default value is an empty turn.
#[derive(Debug, Clone, Default)]
pub struct ScriptedTurn {
// Actions, replayed in insertion order.
actions: Vec<ScriptAction>,
// Usage attached to the turn's terminal step, if any.
usage: Option<UsageMetadata>,
}
impl ScriptedTurn {
pub fn new() -> Self {
Self::default()
}
pub fn text(mut self, text: impl Into<String>) -> Self {
self.actions.push(ScriptAction::Text(text.into()));
self
}
pub fn tool_call(mut self, name: impl Into<String>, args: serde_json::Value) -> Self {
self.actions.push(ScriptAction::ToolCall {
name: name.into(),
args,
});
self
}
pub fn with_usage(mut self, usage: UsageMetadata) -> Self {
self.usage = Some(usage);
self
}
fn content(&self) -> String {
let mut out = String::new();
for a in &self.actions {
if let ScriptAction::Text(t) = a {
out.push_str(t);
}
}
out
}
}
/// Fluent builder for a [`MockConnectionStrategy`].
///
/// Obtain one via [`MockConnection::builder`] or `MockConnectionBuilder::new()`.
#[derive(Default)]
pub struct MockConnectionBuilder {
// Turns in the order connections will replay them.
turns: Vec<ScriptedTurn>,
// Conversation-id override; `build` falls back to "mock-conversation" when unset.
conversation_id: Option<String>,
}
impl MockConnectionBuilder {
    /// Starts a builder with no turns and the default conversation id.
    pub fn new() -> Self {
        Default::default()
    }

    /// Appends a turn configured by `f`, which is handed an empty [`ScriptedTurn`].
    pub fn turn(self, f: impl FnOnce(ScriptedTurn) -> ScriptedTurn) -> Self {
        let scripted = f(ScriptedTurn::new());
        self.push_turn(scripted)
    }

    /// Appends an already-built turn after the existing ones.
    pub fn push_turn(mut self, turn: ScriptedTurn) -> Self {
        self.turns.push(turn);
        self
    }

    /// Replaces every previously added turn with `turns`.
    pub fn turns(self, turns: Vec<ScriptedTurn>) -> Self {
        Self { turns, ..self }
    }

    /// Overrides the conversation id reported by created connections
    /// (defaults to `"mock-conversation"`).
    pub fn conversation_id(self, id: impl Into<String>) -> Self {
        Self {
            conversation_id: Some(id.into()),
            ..self
        }
    }

    /// Finalizes the script into a [`MockConnectionStrategy`] with no runners attached.
    pub fn build(self) -> MockConnectionStrategy {
        let Self {
            turns,
            conversation_id,
        } = self;
        let conversation_id =
            conversation_id.unwrap_or_else(|| String::from("mock-conversation"));
        MockConnectionStrategy {
            turns: Arc::new(turns),
            conversation_id,
            runners: MockRunners::default(),
        }
    }
}
/// Optional collaborators a mock connection uses to execute scripted tool calls.
///
/// Every field defaults to `None`.
#[derive(Default, Clone)]
pub struct MockRunners {
/// Executes scripted tool calls. When `None`, tool calls are only emitted as
/// steps and never run.
pub tool_runner: Option<Arc<ToolRunner>>,
/// Runs pre/post tool-call hooks around each dispatched call. When `None`,
/// every call is allowed.
pub hook_runner: Option<Arc<HookRunner>>,
/// Parent session context. Each dispatched tool call runs in a child of it,
/// or in a default context when `None`.
pub session_ctx: Option<SessionContext>,
}
/// [`ConnectionStrategy`] that hands out [`MockConnection`]s replaying a fixed script.
///
/// Each `connect()` creates an independent connection that starts at the first turn.
pub struct MockConnectionStrategy {
// The script, shared (not copied) with every connection created.
turns: Arc<Vec<ScriptedTurn>>,
// Conversation id reported by created connections.
conversation_id: String,
// Runners cloned into every created connection.
runners: MockRunners,
}
impl MockConnectionStrategy {
    /// Replaces the runners handed to connections created by this strategy.
    pub fn with_runners(self, runners: MockRunners) -> Self {
        Self { runners, ..self }
    }
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl ConnectionStrategy for MockConnectionStrategy {
    /// Opens a fresh, idle [`MockConnection`] with its own step channel and counters.
    ///
    /// The connection replays the shared script from its first turn. This never fails.
    async fn connect(&self) -> Result<Arc<dyn Connection>> {
        // The initial receiver is dropped right away; consumers attach later through
        // `subscribe_steps`.
        let (steps, _) = broadcast::channel(STEP_BROADCAST_CAPACITY);
        let conversation_id: Arc<str> = Arc::from(self.conversation_id.as_str());
        let inner = MockInner {
            turns: Arc::clone(&self.turns),
            next_turn: AtomicUsize::new(0),
            step_index: AtomicUsize::new(0),
            steps,
            idle: AtomicBool::new(true),
            idle_notify: Notify::new(),
            conversation_id,
            runners: self.runners.clone(),
        };
        let connection = MockConnection {
            inner: Arc::new(inner),
        };
        Ok(Arc::new(connection))
    }
}
/// Scripted, offline [`Connection`] that replays [`ScriptedTurn`]s as a stream of steps.
///
/// Created by [`MockConnectionStrategy`]; start scripting with [`MockConnection::builder`].
pub struct MockConnection {
// Shared with every spawned turn task.
inner: Arc<MockInner>,
}
/// State shared between a [`MockConnection`] and the turn tasks it spawns.
struct MockInner {
// The script, shared with the strategy that created this connection.
turns: Arc<Vec<ScriptedTurn>>,
// Index of the next scripted turn `send` will replay.
next_turn: AtomicUsize,
// Step counter shared by every turn of this connection (monotonic across turns).
step_index: AtomicUsize,
// Broadcast sender for emitted steps; receivers are created by `subscribe_steps`.
steps: broadcast::Sender<Step>,
// Idle flag read by `is_idle`; initialized to `true` in `connect`.
idle: AtomicBool,
// `notify_waiters` is called on this whenever `idle` is stored as `true`.
idle_notify: Notify,
// Conversation id reported by `Connection::conversation_id`.
conversation_id: Arc<str>,
// Tool/hook runners used to execute scripted tool calls.
runners: MockRunners,
}
impl MockConnection {
    /// Entry point for scripting a mock backend; an empty [`MockConnectionBuilder`].
    pub fn builder() -> MockConnectionBuilder {
        MockConnectionBuilder::default()
    }
}
impl MockInner {
    /// Reserves the next step index for this connection.
    ///
    /// The counter is shared by all turns, so indices keep increasing across turns.
    fn alloc_step_index(&self) -> u32 {
        // NOTE(review): `as u32` wraps after u32::MAX steps; not reachable for a mock.
        self.step_index.fetch_add(1, Ordering::Relaxed) as u32
    }

    /// Broadcasts `step` to the current subscribers.
    ///
    /// The send result is deliberately ignored: with no subscriber the step is dropped.
    fn emit(&self, step: Step) {
        let _ = self.steps.send(step);
    }

    /// Replays one scripted turn, emitting its steps in script order.
    ///
    /// Each text action becomes a delta step. Each tool call becomes a tool-call
    /// step and, when a tool runner is configured, is dispatched. The turn always
    /// ends with a terminal step carrying the concatenated text and the turn's usage.
    /// Afterwards the connection is marked idle and idle waiters are woken.
    async fn run_turn(&self, turn: ScriptedTurn) {
        self.idle.store(false, Ordering::Release);
        // Delta steps and the terminal step of this turn share one trajectory id.
        let traj = uuid::Uuid::new_v4().to_string();
        for action in &turn.actions {
            match action {
                ScriptAction::Text(delta) => {
                    self.emit(text_delta_step(&traj, self.alloc_step_index(), delta));
                }
                ScriptAction::ToolCall { name, args } => {
                    let tool_call = ToolCall {
                        name: name.clone(),
                        args: args.clone(),
                        id: None,
                        canonical_path: None,
                    };
                    // Surface the call on the stream first, then optionally execute it.
                    self.emit(tool_call_step(self.alloc_step_index(), tool_call.clone()));
                    if let Some(runner) = self.runners.tool_runner.as_ref() {
                        // The result is discarded; the remaining scripted actions do not
                        // depend on it.
                        let _result = self.dispatch_tool(&tool_call, runner).await;
                    }
                }
            }
        }
        self.emit(terminal_step(
            traj,
            self.alloc_step_index(),
            &turn.content(),
            turn.usage,
        ));
        self.idle.store(true, Ordering::Release);
        self.idle_notify.notify_waiters();
    }

    /// Executes `call` through `runner`, wrapped by the optional pre/post hooks.
    ///
    /// If a pre-hook denies the call, the tool body is not run and an error
    /// [`ToolResult`] with the hook's message is produced instead. Runner errors are
    /// also converted into error results. Post-hooks see the final result in both cases.
    async fn dispatch_tool(&self, call: &ToolCall, runner: &ToolRunner) -> ToolResult {
        // Each call runs in a child of the session context, or a default context when
        // no session context was provided.
        let turn_ctx = self
            .runners
            .session_ctx
            .as_ref()
            .map(|s| s.child())
            .unwrap_or_default();
        // Pre-hooks decide whether the call may run and return the context for the
        // post-hook. Without hooks the call is allowed, and `turn_ctx` is moved in:
        // it is not used again, so cloning it would be wasted work.
        let (decision, op_ctx) = match self.runners.hook_runner.as_ref() {
            Some(hooks) => hooks.dispatch_pre_tool_call(&turn_ctx, call).await,
            None => (crate::types::HookResult::allow(), turn_ctx),
        };
        let result = if !decision.allow {
            ToolResult::err(call.name.clone(), call.id.clone(), decision.message.clone())
        } else {
            match runner.execute(&call.name, call.args.clone()).await {
                Ok(v) => ToolResult::ok(call.name.clone(), call.id.clone(), v),
                Err(e) => ToolResult::err(call.name.clone(), call.id.clone(), e.to_string()),
            }
        };
        if let Some(hooks) = self.runners.hook_runner.as_ref() {
            hooks.dispatch_post_tool_call(&op_ctx, &result).await;
        }
        result
    }
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl Connection for MockConnection {
fn is_idle(&self) -> bool {
self.inner.idle.load(Ordering::Acquire)
}
fn conversation_id(&self) -> &str {
&self.inner.conversation_id
}
async fn send(&self, _content: Content) -> Result<()> {
let idx = self.inner.next_turn.fetch_add(1, Ordering::Relaxed);
let turn = self.inner.turns.get(idx).cloned().unwrap_or_default();
let inner = self.inner.clone();
crate::runtime::spawn(async move {
inner.run_turn(turn).await;
});
Ok(())
}
async fn send_trigger(&self, content: String) -> Result<()> {
self.send(Content::text(content)).await
}
async fn send_tool_results(&self, _results: Vec<ToolResult>) -> Result<()> {
Ok(())
}
fn subscribe_steps(&self) -> StepStream {
let rx = self.inner.steps.subscribe();
let mapped = BroadcastStream::new(rx)
.map(|r| r.map_err(|e| Error::other(format!("mock step lag: {e}"))));
#[cfg(not(target_arch = "wasm32"))]
{
mapped.boxed()
}
#[cfg(target_arch = "wasm32")]
{
mapped.boxed_local()
}
}
async fn wait_for_idle(&self) -> Result<()> {
loop {
if self.is_idle() {
return Ok(());
}
self.inner.idle_notify.notified().await;
}
}
async fn shutdown(&self) -> Result<()> {
self.inner.idle.store(true, Ordering::Release);
self.inner.idle_notify.notify_waiters();
Ok(())
}
}
/// Builds an in-progress (`Active`) model→user text step carrying one streamed delta.
///
/// `content` is left empty. Within this file, the turn's full text is carried by
/// `terminal_step` instead.
fn text_delta_step(traj: &str, idx: u32, delta: &str) -> Step {
    let id = String::from(traj);
    let content_delta = String::from(delta);
    Step {
        id,
        step_index: idx,
        kind: StepType::TextResponse,
        status: StepStatus::Active,
        source: StepSource::Model,
        target: StepTarget::User,
        content_delta,
        content: String::default(),
        thinking: String::default(),
        thinking_delta: String::default(),
        error: String::default(),
        tool_calls: vec![],
        usage_metadata: None,
        structured_output: None,
        is_complete_response: Some(false),
    }
}
/// Builds a completed (`Done`) model→environment step announcing a single tool call.
///
/// The step's `id` is empty, so it does not carry a trajectory id.
fn tool_call_step(idx: u32, tc: ToolCall) -> Step {
    let tool_calls = vec![tc];
    Step {
        kind: StepType::ToolCall,
        status: StepStatus::Done,
        source: StepSource::Model,
        target: StepTarget::Environment,
        id: String::default(),
        step_index: idx,
        tool_calls,
        content: String::default(),
        content_delta: String::default(),
        thinking: String::default(),
        thinking_delta: String::default(),
        error: String::default(),
        is_complete_response: Some(false),
        usage_metadata: None,
        structured_output: None,
    }
}
/// Builds the final text step of a turn.
///
/// The step is `Done`, marked `is_complete_response = Some(true)`, and carries the
/// full `content` plus the turn's optional `usage`.
fn terminal_step(traj: String, idx: u32, content: &str, usage: Option<UsageMetadata>) -> Step {
    let full_text = String::from(content);
    Step {
        is_complete_response: Some(true),
        status: StepStatus::Done,
        kind: StepType::TextResponse,
        source: StepSource::Model,
        target: StepTarget::User,
        id: traj,
        step_index: idx,
        content: full_text,
        usage_metadata: usage,
        content_delta: String::default(),
        thinking: String::default(),
        thinking_delta: String::default(),
        error: String::default(),
        tool_calls: vec![],
        structured_output: None,
    }
}
// End-to-end tests that drive the scripted mock through `Agent::start_mock`,
// with no network or real model involved.
#[cfg(test)]
mod tests {
use super::*;
use crate::agent::Agent;
use crate::policy;
use crate::tools::ClosureTool;
use parking_lot::Mutex;
use serde_json::json;
// Explicit import; it shadows the identical names already visible via `use super::*`.
use std::sync::atomic::{AtomicBool, AtomicUsize};
// A scripted tool call runs exactly once with the scripted arguments, and the
// turn's scripted text comes back as the reply.
#[tokio::test]
async fn scripted_tool_call_flow_runs_offline() {
// Observability for the tool body: invocation count and last received fact.
let count = Arc::new(AtomicUsize::new(0));
let recorded: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
let count_c = count.clone();
let recorded_c = recorded.clone();
let record_fact = ClosureTool::new(
"record_fact",
"Persist a fact",
json!({"type": "object", "properties": {"fact": {"type": "string"}}}),
move |args, _ctx| {
// Re-clone per invocation so the returned future owns its handles.
let count_c = count_c.clone();
let recorded_c = recorded_c.clone();
async move {
count_c.fetch_add(1, Ordering::SeqCst);
let fact = args["fact"].as_str().unwrap_or_default().to_string();
*recorded_c.lock() = Some(fact);
Ok(json!({"ok": true}))
}
},
);
// One turn: a tool call followed by the text that forms the reply.
let backend = MockConnection::builder()
.turn(|t| {
t.tool_call("record_fact", json!({"fact": "the sky is blue"}))
.text("logged")
})
.build();
let agent = Agent::start_mock(
MockAgentConfig::new(backend)
.with_tool(record_fact)
.with_policies(vec![policy::allow_all()]),
)
.await
.expect("mock agent starts");
let reply = agent
.chat("remember a fact")
.await
.expect("chat starts")
.text()
.await
.expect("turn completes");
assert_eq!(
count.load(Ordering::SeqCst),
1,
"the scripted tool must run exactly once",
);
assert_eq!(
recorded.lock().as_deref(),
Some("the sky is blue"),
"the tool received the scripted args",
);
assert_eq!(reply, "logged", "the scripted terminal text is returned");
agent.shutdown().await.expect("clean shutdown");
}
// Under a deny-all policy the tool body must not run, yet the turn still
// completes with its scripted text.
#[tokio::test]
async fn denied_tool_call_does_not_execute() {
let ran = Arc::new(AtomicBool::new(false));
let ran_c = ran.clone();
let tool = ClosureTool::new(
"danger",
"A blocked tool",
json!({"type": "object"}),
move |_args, _ctx| {
let ran_c = ran_c.clone();
async move {
ran_c.store(true, Ordering::SeqCst);
Ok(json!({"ok": true}))
}
},
);
let backend = MockConnection::builder()
.turn(|t| t.tool_call("danger", json!({})).text("attempted"))
.build();
let agent = Agent::start_mock(
MockAgentConfig::new(backend)
.with_tool(tool)
.with_policies(vec![policy::deny_all()]),
)
.await
.expect("mock agent starts");
let reply = agent.chat("go").await.unwrap().text().await.unwrap();
assert!(
!ran.load(Ordering::SeqCst),
"a denied tool must NOT execute its body",
);
assert_eq!(reply, "attempted", "the turn still completes");
agent.shutdown().await.unwrap();
}
// The scripted tool call appears on the response's tool-call stream with the
// scripted name and arguments.
#[tokio::test]
async fn scripted_tool_call_is_visible_on_the_stream() {
use futures_util::StreamExt;
let tool = ClosureTool::new(
"search",
"Search",
json!({"type": "object", "properties": {"q": {"type": "string"}}}),
|_args, _ctx| async move { Ok(json!({"hits": 0})) },
);
let backend = MockConnection::builder()
.turn(|t| t.tool_call("search", json!({"q": "rust"})).text("none found"))
.build();
let agent = Agent::start_mock(
MockAgentConfig::new(backend)
.with_tool(tool)
.with_policies(vec![policy::allow_all()]),
)
.await
.unwrap();
let resp = agent.chat("find rust").await.unwrap();
let mut calls = resp.tool_calls();
let first = calls
.next()
.await
.expect("a tool call is surfaced")
.expect("ok");
assert_eq!(first.name, "search");
assert_eq!(first.args, json!({"q": "rust"}));
agent.shutdown().await.unwrap();
}
// Consecutive chats consume the scripted turns in order, and each turn's usage
// is added to the agent's cumulative usage exactly once.
#[tokio::test]
async fn turns_replay_in_order_with_usage() {
let backend = MockConnection::builder()
.turn(|t| {
t.text("first").with_usage(UsageMetadata {
total_token_count: Some(10),
..Default::default()
})
})
.turn(|t| {
t.text("second").with_usage(UsageMetadata {
total_token_count: Some(20),
..Default::default()
})
})
.build();
let agent = Agent::start_mock(MockAgentConfig::new(backend))
.await
.expect("mock agent starts");
let r1 = agent.chat("a").await.unwrap().text().await.unwrap();
assert_eq!(r1, "first");
let r2 = agent.chat("b").await.unwrap().text().await.unwrap();
assert_eq!(r2, "second");
assert_eq!(
agent.cumulative_usage().total_token_count,
Some(30),
"10 + 20, each turn counted once",
);
agent.shutdown().await.unwrap();
}
}