use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use async_trait::async_trait;
use tokio::sync::{broadcast, Notify};
pub use crate::agent::MockAgentConfig;
use crate::connections::{Connection, ConnectionStrategy, StepStream};
use crate::content::Content;
use crate::error::Result;
use crate::types::{Step, StepStatus, ToolCall, ToolResult, UsageMetadata};
const STEP_BROADCAST_CAPACITY: usize = 256;
#[derive(Debug, Clone)]
enum ScriptAction {
Text(String),
ToolCall {
name: String,
args: serde_json::Value,
},
}
#[derive(Debug, Clone, Default)]
pub struct ScriptedTurn {
actions: Vec<ScriptAction>,
usage: Option<UsageMetadata>,
}
impl ScriptedTurn {
pub fn new() -> Self {
Self::default()
}
pub fn text(mut self, text: impl Into<String>) -> Self {
self.actions.push(ScriptAction::Text(text.into()));
self
}
pub fn tool_call(mut self, name: impl Into<String>, args: serde_json::Value) -> Self {
self.actions.push(ScriptAction::ToolCall {
name: name.into(),
args,
});
self
}
pub fn with_usage(mut self, usage: UsageMetadata) -> Self {
self.usage = Some(usage);
self
}
fn content(&self) -> String {
let mut out = String::new();
for a in &self.actions {
if let ScriptAction::Text(t) = a {
out.push_str(t);
}
}
out
}
}
#[derive(Default)]
pub struct MockConnectionBuilder {
turns: Vec<ScriptedTurn>,
conversation_id: Option<String>,
}
impl MockConnectionBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn turn(mut self, f: impl FnOnce(ScriptedTurn) -> ScriptedTurn) -> Self {
self.turns.push(f(ScriptedTurn::new()));
self
}
pub fn push_turn(mut self, turn: ScriptedTurn) -> Self {
self.turns.push(turn);
self
}
pub fn turns(mut self, turns: Vec<ScriptedTurn>) -> Self {
self.turns = turns;
self
}
pub fn conversation_id(mut self, id: impl Into<String>) -> Self {
self.conversation_id = Some(id.into());
self
}
pub fn build(self) -> MockConnectionStrategy {
MockConnectionStrategy {
turns: Arc::new(self.turns),
conversation_id: self
.conversation_id
.unwrap_or_else(|| "mock-conversation".to_string()),
runners: MockRunners::default(),
}
}
}
pub type MockRunners = crate::backends::BackendRunners;
pub struct MockConnectionStrategy {
turns: Arc<Vec<ScriptedTurn>>,
conversation_id: String,
runners: MockRunners,
}
impl MockConnectionStrategy {
pub fn with_runners(mut self, runners: MockRunners) -> Self {
self.runners = runners;
self
}
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl ConnectionStrategy for MockConnectionStrategy {
async fn connect(&self) -> Result<Arc<dyn Connection>> {
let (steps_tx, _) = broadcast::channel::<Step>(STEP_BROADCAST_CAPACITY);
let inner = Arc::new(MockInner {
turns: self.turns.clone(),
next_turn: AtomicUsize::new(0),
step_index: AtomicUsize::new(0),
steps: steps_tx,
idle: AtomicBool::new(true),
idle_notify: Notify::new(),
conversation_id: self.conversation_id.clone().into(),
runners: self.runners.clone(),
});
Ok(Arc::new(MockConnection { inner }))
}
}
pub struct MockConnection {
inner: Arc<MockInner>,
}
struct MockInner {
turns: Arc<Vec<ScriptedTurn>>,
next_turn: AtomicUsize,
step_index: AtomicUsize,
steps: broadcast::Sender<Step>,
idle: AtomicBool,
idle_notify: Notify,
conversation_id: Arc<str>,
runners: MockRunners,
}
impl MockConnection {
pub fn builder() -> MockConnectionBuilder {
MockConnectionBuilder::new()
}
}
impl MockInner {
fn alloc_step_index(&self) -> u32 {
self.step_index.fetch_add(1, Ordering::Relaxed) as u32
}
fn emit(&self, step: Step) {
let _ = self.steps.send(step);
}
async fn run_turn(&self, prompt: Content) {
let turn_ctx = self
.runners
.session_ctx
.as_ref()
.map(|s| s.child())
.unwrap_or_default();
if let Some(denied) = crate::backends::dispatch::gate_pre_turn(
self.runners.hook_runner.as_ref(),
&turn_ctx,
&prompt,
)
.await
{
self.emit(Step::turn_error(self.alloc_step_index(), denied));
return;
}
let idx = self.next_turn.fetch_add(1, Ordering::Relaxed);
let turn = self.turns.get(idx).cloned().unwrap_or_default();
self.idle.store(false, Ordering::Release);
let traj = uuid::Uuid::new_v4().to_string();
for action in &turn.actions {
match action {
ScriptAction::Text(delta) => {
self.emit(Step::text_delta(&traj, self.alloc_step_index(), delta));
}
ScriptAction::ToolCall { name, args } => {
let tool_call = ToolCall {
name: name.clone(),
args: args.clone(),
id: None,
canonical_path: None,
};
self.emit(Step::tool_call(
self.alloc_step_index(),
tool_call.clone(),
StepStatus::Done,
));
if self.runners.tool_runner.is_some() {
let _result = self.dispatch_tool(&turn_ctx, &tool_call).await;
}
}
}
}
let content = turn.content();
let finished_turn = turn.actions.iter().any(|a| {
matches!(a, ScriptAction::ToolCall { name, .. }
if name == crate::builtins::FINISH_TOOL_NAME)
});
self.emit(Step::turn_complete(
traj,
self.alloc_step_index(),
StepStatus::Done,
content.as_str(),
"",
finished_turn,
None,
turn.usage,
));
crate::backends::dispatch::dispatch_post_turn(
self.runners.hook_runner.as_ref(),
&turn_ctx,
&content,
)
.await;
self.idle.store(true, Ordering::Release);
self.idle_notify.notify_waiters();
}
async fn dispatch_tool(
&self,
turn_ctx: &crate::hooks::TurnContext,
call: &ToolCall,
) -> ToolResult {
crate::backends::dispatch::dispatch_tool_call(
self.runners.tool_runner.as_ref(),
self.runners.hook_runner.as_ref(),
turn_ctx,
call,
)
.await
}
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl Connection for MockConnection {
fn is_idle(&self) -> bool {
self.inner.idle.load(Ordering::Acquire)
}
fn conversation_id(&self) -> &str {
&self.inner.conversation_id
}
async fn send(&self, content: Content) -> Result<()> {
let inner = self.inner.clone();
crate::runtime::spawn(async move {
inner.run_turn(content).await;
});
Ok(())
}
async fn send_trigger(&self, content: String) -> Result<()> {
self.send(Content::text(content)).await
}
async fn send_tool_results(&self, _results: Vec<ToolResult>) -> Result<()> {
Ok(())
}
fn subscribe_steps(&self) -> StepStream {
crate::backends::subscribe_step_stream(self.inner.steps.subscribe(), "mock")
}
async fn wait_for_idle(&self) -> Result<()> {
loop {
if self.is_idle() {
return Ok(());
}
self.inner.idle_notify.notified().await;
}
}
async fn shutdown(&self) -> Result<()> {
self.inner.idle.store(true, Ordering::Release);
self.inner.idle_notify.notify_waiters();
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::agent::Agent;
use crate::policy;
use crate::tools::ClosureTool;
use parking_lot::Mutex;
use serde_json::json;
use std::sync::atomic::{AtomicBool, AtomicUsize};
#[tokio::test]
async fn scripted_tool_call_flow_runs_offline() {
let count = Arc::new(AtomicUsize::new(0));
let recorded: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
let count_c = count.clone();
let recorded_c = recorded.clone();
let record_fact = ClosureTool::new(
"record_fact",
"Persist a fact",
json!({"type": "object", "properties": {"fact": {"type": "string"}}}),
move |args, _ctx| {
let count_c = count_c.clone();
let recorded_c = recorded_c.clone();
async move {
count_c.fetch_add(1, Ordering::SeqCst);
let fact = args["fact"].as_str().unwrap_or_default().to_string();
*recorded_c.lock() = Some(fact);
Ok(json!({"ok": true}))
}
},
);
let backend = MockConnection::builder()
.turn(|t| {
t.tool_call("record_fact", json!({"fact": "the sky is blue"}))
.text("logged")
})
.build();
let agent = Agent::start_mock(
MockAgentConfig::new(backend)
.with_tool(record_fact)
.with_policies(vec![policy::allow_all()]),
)
.await
.expect("mock agent starts");
let reply = agent
.chat("remember a fact")
.await
.expect("chat starts")
.text()
.await
.expect("turn completes");
assert_eq!(
count.load(Ordering::SeqCst),
1,
"the scripted tool must run exactly once",
);
assert_eq!(
recorded.lock().as_deref(),
Some("the sky is blue"),
"the tool received the scripted args",
);
assert_eq!(reply, "logged", "the scripted terminal text is returned");
agent.shutdown().await.expect("clean shutdown");
}
#[tokio::test]
async fn denied_tool_call_does_not_execute() {
let ran = Arc::new(AtomicBool::new(false));
let ran_c = ran.clone();
let tool = ClosureTool::new(
"danger",
"A blocked tool",
json!({"type": "object"}),
move |_args, _ctx| {
let ran_c = ran_c.clone();
async move {
ran_c.store(true, Ordering::SeqCst);
Ok(json!({"ok": true}))
}
},
);
let backend = MockConnection::builder()
.turn(|t| t.tool_call("danger", json!({})).text("attempted"))
.build();
let agent = Agent::start_mock(
MockAgentConfig::new(backend)
.with_tool(tool)
.with_policies(vec![policy::deny_all()]),
)
.await
.expect("mock agent starts");
let reply = agent.chat("go").await.unwrap().text().await.unwrap();
assert!(
!ran.load(Ordering::SeqCst),
"a denied tool must NOT execute its body",
);
assert_eq!(reply, "attempted", "the turn still completes");
agent.shutdown().await.unwrap();
}
#[tokio::test]
async fn scripted_tool_call_is_visible_on_the_stream() {
use futures_util::StreamExt;
let tool = ClosureTool::new(
"search",
"Search",
json!({"type": "object", "properties": {"q": {"type": "string"}}}),
|_args, _ctx| async move { Ok(json!({"hits": 0})) },
);
let backend = MockConnection::builder()
.turn(|t| t.tool_call("search", json!({"q": "rust"})).text("none found"))
.build();
let agent = Agent::start_mock(
MockAgentConfig::new(backend)
.with_tool(tool)
.with_policies(vec![policy::allow_all()]),
)
.await
.unwrap();
let resp = agent.chat("find rust").await.unwrap();
let mut calls = resp.tool_calls();
let first = calls
.next()
.await
.expect("a tool call is surfaced")
.expect("ok");
assert_eq!(first.name, "search");
assert_eq!(first.args, json!({"q": "rust"}));
agent.shutdown().await.unwrap();
}
#[tokio::test]
async fn turns_replay_in_order_with_usage() {
let backend = MockConnection::builder()
.turn(|t| {
t.text("first").with_usage(UsageMetadata {
total_token_count: Some(10),
..Default::default()
})
})
.turn(|t| {
t.text("second").with_usage(UsageMetadata {
total_token_count: Some(20),
..Default::default()
})
})
.build();
let agent = Agent::start_mock(MockAgentConfig::new(backend))
.await
.expect("mock agent starts");
let r1 = agent.chat("a").await.unwrap().text().await.unwrap();
assert_eq!(r1, "first");
let r2 = agent.chat("b").await.unwrap().text().await.unwrap();
assert_eq!(r2, "second");
assert_eq!(
agent.cumulative_usage().total_token_count,
Some(30),
"10 + 20, each turn counted once",
);
agent.shutdown().await.unwrap();
}
}