use crate::agent::{Agent, ModelSize};
use crate::output::{AgentOutput, ContentBlock, Event, ToolResult, Usage};
use crate::sandbox::SandboxConfig;
use anyhow::Result;
use async_trait::async_trait;
use std::sync::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
/// Model name a freshly constructed mock agent reports until another model is set.
pub const DEFAULT_MODEL: &str = "mock-default";

/// Every model name the mock agent advertises.
///
/// The first entry is [`DEFAULT_MODEL`]; the remaining three are the names returned for the
/// small / medium / large model sizes.
pub const AVAILABLE_MODELS: &[&str] = &[DEFAULT_MODEL, "mock-small", "mock-medium", "mock-large"];
/// A canned reply that the mock agent hands back from a run.
///
/// Each field is copied verbatim into the resulting `AgentOutput` by `into_output`.
#[derive(Debug, Clone)]
pub struct MockResponse {
    /// Final text surfaced as the output's `result`.
    pub result: Option<String>,
    /// Event stream surfaced as the output's `events`.
    pub events: Vec<Event>,
    /// Session identifier; the constructors fill this with a fresh UUID v4.
    pub session_id: String,
    /// Whether the output should be flagged as an error.
    pub is_error: bool,
    /// Optional reported cost in US dollars.
    pub total_cost_usd: Option<f64>,
    /// Optional token-usage summary.
    pub usage: Option<Usage>,
}
impl MockResponse {
pub fn text(text: &str) -> Self {
Self {
result: Some(text.to_string()),
events: vec![Event::Result {
success: true,
message: Some(text.to_string()),
duration_ms: Some(100),
num_turns: Some(1),
}],
session_id: uuid::Uuid::new_v4().to_string(),
is_error: false,
total_cost_usd: None,
usage: None,
}
}
pub fn error(message: &str) -> Self {
Self {
result: Some(message.to_string()),
events: vec![Event::Error {
message: message.to_string(),
details: None,
}],
session_id: uuid::Uuid::new_v4().to_string(),
is_error: true,
total_cost_usd: None,
usage: None,
}
}
pub fn with_events(events: Vec<Event>) -> Self {
let result = events.iter().find_map(|e| {
if let Event::Result { message, .. } = e {
message.clone()
} else {
None
}
});
Self {
result,
events,
session_id: uuid::Uuid::new_v4().to_string(),
is_error: false,
total_cost_usd: None,
usage: None,
}
}
pub fn with_usage(text: &str, usage: Usage) -> Self {
let mut resp = Self::text(text);
resp.usage = Some(usage);
resp
}
pub fn session_id(mut self, id: &str) -> Self {
self.session_id = id.to_string();
self
}
pub fn cost(mut self, cost: f64) -> Self {
self.total_cost_usd = Some(cost);
self
}
pub fn into_output(self) -> AgentOutput {
AgentOutput {
agent: "mock".to_string(),
session_id: self.session_id,
events: self.events,
result: self.result,
is_error: self.is_error,
exit_code: None,
error_message: None,
total_cost_usd: self.total_cost_usd,
usage: self.usage,
}
}
}
/// Scriptable in-memory agent for tests.
///
/// Configuration applied through the `Agent` setters is stored as-is so tests can read it back.
/// Scripted responses are served front-first. The public counters and prompt logs record how
/// the agent was driven.
pub struct MockAgent {
    system_prompt: String,
    model: String,
    root: Option<String>,
    skip_permissions: bool,
    output_format: Option<String>,
    add_dirs: Vec<String>,
    max_turns: Option<u32>,
    sandbox: Option<SandboxConfig>,
    /// Scripted responses, consumed front-first.
    responses: Mutex<Vec<MockResponse>>,
    /// Response cloned for every run once `responses` is exhausted.
    default_response: Mutex<MockResponse>,
    /// Number of non-interactive runs (resumes that carry a prompt are counted here too).
    pub run_count: AtomicUsize,
    /// Number of `run_interactive` calls.
    pub interactive_count: AtomicUsize,
    /// Number of `run_resume` calls.
    pub resume_count: AtomicUsize,
    /// Most recent prompt received by any prompt-carrying entry point.
    pub last_prompt: Mutex<Option<String>>,
    /// Every prompt received, in call order.
    pub all_prompts: Mutex<Vec<String>>,
    /// When true, `run` returns an error built from `run_error_message`.
    pub fail_on_run: bool,
    /// Error text used when `fail_on_run` is set.
    pub run_error_message: String,
    /// When true, `run_interactive` returns an error.
    pub fail_on_interactive: bool,
    /// Optional artificial latency awaited inside `run` before it produces its result.
    pub delay: Option<Duration>,
}
impl MockAgent {
    /// Creates a mock agent with the following configuration:
    ///
    /// - model set to [`DEFAULT_MODEL`]
    /// - no scripted responses; the default response is an empty successful text reply
    /// - all counters at zero
    /// - failure injection disabled
    pub fn new() -> Self {
        Self {
            system_prompt: String::new(),
            model: String::from(DEFAULT_MODEL),
            root: None,
            skip_permissions: false,
            output_format: None,
            add_dirs: Vec::new(),
            max_turns: None,
            sandbox: None,
            responses: Mutex::default(),
            default_response: Mutex::new(MockResponse::text("")),
            run_count: AtomicUsize::default(),
            interactive_count: AtomicUsize::default(),
            resume_count: AtomicUsize::default(),
            last_prompt: Mutex::default(),
            all_prompts: Mutex::default(),
            fail_on_run: false,
            run_error_message: String::from("Mock agent run failed"),
            fail_on_interactive: false,
            delay: None,
        }
    }

    /// Starts a fluent [`MockAgentBuilder`].
    pub fn builder() -> MockAgentBuilder {
        MockAgentBuilder::new()
    }

    /// Current value of the `run_count` counter.
    pub fn run_count(&self) -> usize {
        let counter = &self.run_count;
        counter.load(Ordering::SeqCst)
    }

    /// Current value of the `interactive_count` counter.
    pub fn interactive_count(&self) -> usize {
        let counter = &self.interactive_count;
        counter.load(Ordering::SeqCst)
    }

    /// Current value of the `resume_count` counter.
    pub fn resume_count(&self) -> usize {
        let counter = &self.resume_count;
        counter.load(Ordering::SeqCst)
    }

    /// Copy of the most recently recorded prompt, if any.
    pub fn last_prompt(&self) -> Option<String> {
        let guard = self.last_prompt.lock().unwrap();
        guard.clone()
    }

    /// Copy of every recorded prompt, oldest first.
    pub fn all_prompts(&self) -> Vec<String> {
        let guard = self.all_prompts.lock().unwrap();
        guard.clone()
    }

    /// Turn limit last passed to `set_max_turns`, if any.
    pub fn max_turns(&self) -> Option<u32> {
        self.max_turns
    }

    /// Root directory last passed to `set_root`, if any.
    pub fn root(&self) -> Option<&str> {
        self.root.as_deref()
    }

    /// Value last passed to `set_skip_permissions`.
    pub fn skip_permissions(&self) -> bool {
        self.skip_permissions
    }

    /// Output format last passed to `set_output_format`, if any.
    pub fn output_format(&self) -> Option<&str> {
        self.output_format.as_deref()
    }

    /// Extra directories last passed to `set_add_dirs`.
    pub fn add_dirs(&self) -> &[String] {
        self.add_dirs.as_slice()
    }

    /// Sandbox configuration last passed to `set_sandbox`, if any.
    pub fn sandbox(&self) -> Option<&SandboxConfig> {
        self.sandbox.as_ref()
    }
}
impl Default for MockAgent {
fn default() -> Self {
Self::new()
}
}
/// Fluent builder for [`MockAgent`].
///
/// Only settings explicitly provided here override the defaults of `MockAgent::new`.
pub struct MockAgentBuilder {
    /// Responses queued in the order they were added.
    responses: Vec<MockResponse>,
    /// Replacement for the agent's default response, if set.
    default_response: Option<MockResponse>,
    /// Whether `run` should fail.
    fail_on_run: bool,
    /// Error text used when `fail_on_run` is set.
    run_error_message: String,
    /// Whether `run_interactive` should fail.
    fail_on_interactive: bool,
    /// Artificial latency for `run`, if any.
    delay: Option<Duration>,
    /// Model override, if any.
    model: Option<String>,
    /// System prompt override, if any.
    system_prompt: Option<String>,
}
impl MockAgentBuilder {
pub fn new() -> Self {
Self {
responses: Vec::new(),
default_response: None,
fail_on_run: false,
run_error_message: "Mock agent run failed".to_string(),
fail_on_interactive: false,
delay: None,
model: None,
system_prompt: None,
}
}
pub fn respond_with_text(mut self, text: &str) -> Self {
self.responses.push(MockResponse::text(text));
self
}
pub fn respond_with_error(mut self, message: &str) -> Self {
self.responses.push(MockResponse::error(message));
self
}
pub fn respond_with(mut self, response: MockResponse) -> Self {
self.responses.push(response);
self
}
pub fn default_response(mut self, response: MockResponse) -> Self {
self.default_response = Some(response);
self
}
pub fn fail_on_run(mut self, message: &str) -> Self {
self.fail_on_run = true;
self.run_error_message = message.to_string();
self
}
pub fn fail_on_interactive(mut self) -> Self {
self.fail_on_interactive = true;
self
}
pub fn with_delay(mut self, delay: Duration) -> Self {
self.delay = Some(delay);
self
}
pub fn model(mut self, model: &str) -> Self {
self.model = Some(model.to_string());
self
}
pub fn system_prompt(mut self, prompt: &str) -> Self {
self.system_prompt = Some(prompt.to_string());
self
}
pub fn build(self) -> MockAgent {
let mut agent = MockAgent::new();
*agent.responses.lock().unwrap() = self.responses;
if let Some(default) = self.default_response {
*agent.default_response.lock().unwrap() = default;
}
agent.fail_on_run = self.fail_on_run;
agent.run_error_message = self.run_error_message;
agent.fail_on_interactive = self.fail_on_interactive;
agent.delay = self.delay;
if let Some(model) = self.model {
agent.model = model;
}
if let Some(prompt) = self.system_prompt {
agent.system_prompt = prompt;
}
agent
}
}
impl Default for MockAgentBuilder {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Agent for MockAgent {
fn name(&self) -> &str {
"mock"
}
fn default_model() -> &'static str
where
Self: Sized,
{
DEFAULT_MODEL
}
fn model_for_size(size: ModelSize) -> &'static str
where
Self: Sized,
{
match size {
ModelSize::Small => "mock-small",
ModelSize::Medium => "mock-medium",
ModelSize::Large => "mock-large",
}
}
fn available_models() -> &'static [&'static str]
where
Self: Sized,
{
AVAILABLE_MODELS
}
fn system_prompt(&self) -> &str {
&self.system_prompt
}
fn set_system_prompt(&mut self, prompt: String) {
self.system_prompt = prompt;
}
fn get_model(&self) -> &str {
&self.model
}
fn set_model(&mut self, model: String) {
self.model = model;
}
fn set_root(&mut self, root: String) {
self.root = Some(root);
}
fn set_skip_permissions(&mut self, skip: bool) {
self.skip_permissions = skip;
}
fn set_output_format(&mut self, format: Option<String>) {
self.output_format = format;
}
fn set_max_turns(&mut self, turns: u32) {
self.max_turns = Some(turns);
}
fn set_sandbox(&mut self, config: SandboxConfig) {
self.sandbox = Some(config);
}
fn set_add_dirs(&mut self, dirs: Vec<String>) {
self.add_dirs = dirs;
}
fn as_any_ref(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
async fn run(&self, prompt: Option<&str>) -> Result<Option<AgentOutput>> {
self.run_count.fetch_add(1, Ordering::SeqCst);
if let Some(p) = prompt {
*self.last_prompt.lock().unwrap() = Some(p.to_string());
self.all_prompts.lock().unwrap().push(p.to_string());
}
if let Some(delay) = self.delay {
tokio::time::sleep(delay).await;
}
if self.fail_on_run {
anyhow::bail!("{}", self.run_error_message);
}
let response = {
let mut queue = self.responses.lock().unwrap();
if queue.is_empty() {
self.default_response.lock().unwrap().clone()
} else {
queue.remove(0)
}
};
Ok(Some(response.into_output()))
}
async fn run_interactive(&self, prompt: Option<&str>) -> Result<()> {
self.interactive_count.fetch_add(1, Ordering::SeqCst);
if let Some(p) = prompt {
*self.last_prompt.lock().unwrap() = Some(p.to_string());
self.all_prompts.lock().unwrap().push(p.to_string());
}
if self.fail_on_interactive {
anyhow::bail!("Mock agent interactive session failed");
}
Ok(())
}
async fn run_resume(&self, _session_id: Option<&str>, _last: bool) -> Result<()> {
self.resume_count.fetch_add(1, Ordering::SeqCst);
Ok(())
}
async fn run_resume_with_prompt(
&self,
_session_id: &str,
prompt: &str,
) -> Result<Option<AgentOutput>> {
self.run_count.fetch_add(1, Ordering::SeqCst);
*self.last_prompt.lock().unwrap() = Some(prompt.to_string());
self.all_prompts.lock().unwrap().push(prompt.to_string());
let response = {
let mut queue = self.responses.lock().unwrap();
if queue.is_empty() {
self.default_response.lock().unwrap().clone()
} else {
queue.remove(0)
}
};
Ok(Some(response.into_output()))
}
async fn cleanup(&self) -> Result<()> {
Ok(())
}
}
pub mod events {
use super::*;
pub fn init(model: &str) -> Event {
Event::Init {
model: model.to_string(),
tools: vec!["Bash".to_string(), "Read".to_string(), "Write".to_string()],
working_directory: Some("/tmp/test".to_string()),
metadata: std::collections::HashMap::new(),
}
}
pub fn assistant_message(text: &str) -> Event {
Event::AssistantMessage {
content: vec![ContentBlock::Text {
text: text.to_string(),
}],
usage: None,
parent_tool_use_id: None,
}
}
pub fn assistant_message_with_usage(
text: &str,
input_tokens: u64,
output_tokens: u64,
) -> Event {
Event::AssistantMessage {
content: vec![ContentBlock::Text {
text: text.to_string(),
}],
usage: Some(Usage {
input_tokens,
output_tokens,
cache_read_tokens: None,
cache_creation_tokens: None,
web_search_requests: None,
web_fetch_requests: None,
}),
parent_tool_use_id: None,
}
}
pub fn tool_execution(tool_name: &str, input: &str, output: &str) -> Event {
Event::ToolExecution {
tool_name: tool_name.to_string(),
tool_id: uuid::Uuid::new_v4().to_string(),
input: serde_json::json!({ "command": input }),
result: ToolResult {
success: true,
output: Some(output.to_string()),
error: None,
data: None,
},
parent_tool_use_id: None,
}
}
pub fn tool_execution_failed(tool_name: &str, error: &str) -> Event {
Event::ToolExecution {
tool_name: tool_name.to_string(),
tool_id: uuid::Uuid::new_v4().to_string(),
input: serde_json::Value::Null,
result: ToolResult {
success: false,
output: None,
error: Some(error.to_string()),
data: None,
},
parent_tool_use_id: None,
}
}
pub fn result_success(message: &str) -> Event {
Event::Result {
success: true,
message: Some(message.to_string()),
duration_ms: Some(100),
num_turns: Some(1),
}
}
pub fn user_message(text: &str) -> Event {
Event::UserMessage {
content: vec![ContentBlock::Text {
text: text.to_string(),
}],
}
}
pub fn permission_granted(tool_name: &str) -> Event {
Event::PermissionRequest {
tool_name: tool_name.to_string(),
description: format!("Allow {} to execute", tool_name),
granted: true,
}
}
}
// Unit tests for this module are compiled only under `cfg(test)` and are loaded from the
// sibling file `mock_tests.rs` rather than written inline.
#[cfg(test)]
#[path = "mock_tests.rs"]
mod tests;