use std::sync::{Arc, Mutex};
use crate::events::AgentEvent;
use crate::model::ModelResponse;
use crate::provider::{ModelProvider, ProviderError};
use crate::types::{ContentBlock, Message, Role, StopReason, ToolDefinition, ToolUseBlock};
/// Scripted [`ModelProvider`] for tests.
///
/// Responses are queued with the `with_*` builder methods and handed out in
/// insertion order by `generate`. Both the queue and the call counter live
/// behind `Arc<Mutex<_>>`, so every clone of a `MockProvider` shares the same
/// queue and the same counter.
#[derive(Clone)]
pub struct MockProvider {
    /// How many times `generate` has been invoked, failed calls included.
    call_count: Arc<Mutex<usize>>,
    /// Pending responses; `generate` takes them from the front.
    responses: Arc<Mutex<Vec<ModelResponse>>>,
}
impl MockProvider {
pub fn new() -> Self {
Self {
responses: Arc::new(Mutex::new(Vec::new())),
call_count: Arc::new(Mutex::new(0)),
}
}
pub fn with_text(self, text: impl Into<String>) -> Self {
let message = Message::assistant(text);
let response = ModelResponse {
message,
stop_reason: StopReason::EndTurn,
usage: None,
};
self.responses.lock().unwrap().push(response);
self
}
pub fn with_tool_use(
self,
tool_name: impl Into<String>,
tool_input: serde_json::Value,
) -> Self {
let tool_use = ToolUseBlock {
id: format!("tool_{}", uuid::Uuid::new_v4()),
name: tool_name.into(),
input: tool_input,
};
let message = Message {
role: Role::Assistant,
content: vec![ContentBlock::ToolUse(tool_use)],
};
let response = ModelResponse {
message,
stop_reason: StopReason::ToolUse,
usage: None,
};
self.responses.lock().unwrap().push(response);
self
}
pub fn call_count(&self) -> usize {
*self.call_count.lock().unwrap()
}
}
impl Default for MockProvider {
fn default() -> Self {
Self::new()
}
}
#[async_trait::async_trait]
impl ModelProvider for MockProvider {
    fn name(&self) -> &str {
        "MockProvider"
    }

    fn max_context_tokens(&self) -> usize {
        200_000
    }

    fn max_output_tokens(&self) -> usize {
        8_192
    }

    /// Removes and returns the oldest queued response. Messages, tools and the
    /// system prompt are all ignored.
    ///
    /// Every call increments the shared call counter, including calls that fail.
    ///
    /// # Errors
    ///
    /// Returns `ProviderError::Other` when no responses remain in the queue.
    async fn generate(
        &self,
        _messages: Vec<Message>,
        _tools: Vec<ToolDefinition>,
        _system_prompt: Option<String>,
    ) -> Result<ModelResponse, ProviderError> {
        // The counter guard is a temporary, so it drops at the end of this
        // statement. The two mutexes are therefore never held at the same time
        // and impose no lock-ordering constraint on each other.
        *self.call_count.lock().unwrap() += 1;

        let mut responses = self.responses.lock().unwrap();
        if responses.is_empty() {
            return Err(ProviderError::Other(
                "MockProvider: No more responses configured".to_string(),
            ));
        }
        // Test queues are short, so the O(n) front removal is acceptable here.
        Ok(responses.remove(0))
    }
}
/// Hook that records a copy of every [`AgentEvent`] passed to it.
///
/// The log sits behind `Arc<Mutex<_>>`, so all clones of a collector share a
/// single log. For example, a clone handed to an agent and a clone kept by the
/// test both observe the same events.
#[derive(Clone)]
pub struct EventCollector {
    /// Recorded events, oldest first.
    events: Arc<Mutex<Vec<AgentEvent>>>,
}
impl EventCollector {
pub fn new() -> Self {
Self {
events: Arc::new(Mutex::new(Vec::new())),
}
}
pub fn events(&self) -> Vec<AgentEvent> {
self.events.lock().unwrap().clone()
}
pub fn event_types(&self) -> Vec<String> {
self.events
.lock()
.unwrap()
.iter()
.map(|e| Self::event_type_name(e).to_string())
.collect()
}
pub fn clear(&self) {
self.events.lock().unwrap().clear();
}
pub fn has_event(&self, event_type: &str) -> bool {
self.events
.lock()
.unwrap()
.iter()
.any(|e| Self::event_type_name(e) == event_type)
}
pub fn count_event(&self, event_type: &str) -> usize {
self.events
.lock()
.unwrap()
.iter()
.filter(|e| Self::event_type_name(e) == event_type)
.count()
}
pub fn len(&self) -> usize {
self.events.lock().unwrap().len()
}
pub fn is_empty(&self) -> bool {
self.events.lock().unwrap().is_empty()
}
fn event_type_name(event: &AgentEvent) -> &'static str {
match event {
AgentEvent::RunStarted { .. } => "run_started",
AgentEvent::RunCompleted { .. } => "run_completed",
AgentEvent::RunFailed { .. } => "run_failed",
AgentEvent::ModelCallStarted { .. } => "model_call_started",
AgentEvent::ModelCallStreaming { .. } => "model_streaming",
AgentEvent::ModelCallCompleted { .. } => "model_call_completed",
AgentEvent::ToolRequested { .. } => "tool_requested",
AgentEvent::ToolExecuting { .. } => "tool_executing",
AgentEvent::ToolCompleted { .. } => "tool_completed",
AgentEvent::ToolFailed { .. } => "tool_failed",
AgentEvent::PermissionRequired { .. } => "permission_required",
AgentEvent::PermissionGranted { .. } => "permission_granted",
AgentEvent::PermissionDenied { .. } => "permission_denied",
#[cfg(feature = "session")]
AgentEvent::SessionResumed { .. } => "session_resumed",
#[cfg(feature = "session")]
AgentEvent::SessionSaved { .. } => "session_saved",
}
}
}
impl Default for EventCollector {
fn default() -> Self {
Self::new()
}
}
impl crate::events::AgentHook for EventCollector {
    /// Appends a clone of `event` to the end of the shared log.
    fn on_event(&self, event: &AgentEvent) {
        // Clone before locking so the lock is held only for the push itself.
        let recorded = event.clone();
        let mut log = self.events.lock().unwrap();
        log.push(recorded);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::AgentHook;

    /// A queued text reply comes back with `EndTurn` and the queued text.
    #[tokio::test]
    async fn test_mock_provider_text_response() {
        let provider = MockProvider::new().with_text("Hello!");
        assert_eq!(provider.call_count(), 0);

        let response = provider.generate(vec![], vec![], None).await.unwrap();
        assert_eq!(provider.call_count(), 1);
        assert!(matches!(response.stop_reason, StopReason::EndTurn));
        assert!(response.usage.is_none());
        assert!(response.message.text().contains("Hello!"));
    }

    /// Responses are returned in insertion order, and the tool-use response
    /// carries the configured name and input.
    #[tokio::test]
    async fn test_mock_provider_chained_responses() {
        let provider = MockProvider::new()
            .with_tool_use("calculator", serde_json::json!({"expr": "2+2"}))
            .with_text("The answer is 4");
        assert_eq!(provider.call_count(), 0);

        let first = provider.generate(vec![], vec![], None).await.unwrap();
        assert!(matches!(first.stop_reason, StopReason::ToolUse));
        assert!(matches!(first.message.role, Role::Assistant));
        assert_eq!(first.message.content.len(), 1);
        match &first.message.content[0] {
            ContentBlock::ToolUse(tool_use) => {
                assert_eq!(tool_use.name, "calculator");
                assert_eq!(tool_use.input, serde_json::json!({"expr": "2+2"}));
                assert!(tool_use.id.starts_with("tool_"));
            }
            #[allow(unreachable_patterns)]
            _ => panic!("expected a tool-use content block"),
        }

        let second = provider.generate(vec![], vec![], None).await.unwrap();
        assert!(matches!(second.stop_reason, StopReason::EndTurn));
        assert!(second.message.text().contains("The answer is 4"));
        assert_eq!(provider.call_count(), 2);
    }

    /// Draining the queue yields an error, and the failed call is still counted.
    #[tokio::test]
    async fn test_mock_provider_generate() {
        let provider = MockProvider::new()
            .with_text("Response 1")
            .with_text("Response 2");

        let response1 = provider.generate(vec![], vec![], None).await.unwrap();
        assert_eq!(provider.call_count(), 1);
        assert!(response1.message.text().contains("Response 1"));

        let response2 = provider.generate(vec![], vec![], None).await.unwrap();
        assert_eq!(provider.call_count(), 2);
        assert!(response2.message.text().contains("Response 2"));

        let result = provider.generate(vec![], vec![], None).await;
        assert!(result.is_err());
        assert_eq!(provider.call_count(), 3);
    }

    /// Clones share both the response queue and the call counter.
    #[tokio::test]
    async fn test_mock_provider_clones_share_state() {
        let provider = MockProvider::new().with_text("shared");
        let clone = provider.clone();

        let response = clone.generate(vec![], vec![], None).await.unwrap();
        assert!(response.message.text().contains("shared"));
        assert_eq!(provider.call_count(), 1);
        assert!(provider.generate(vec![], vec![], None).await.is_err());
        assert_eq!(clone.call_count(), 2);
    }

    /// Events fed through the public hook are recorded, queried, and cleared,
    /// and clones observe the same log.
    #[test]
    fn test_event_collector() {
        let collector = EventCollector::new();
        assert!(collector.is_empty());

        collector.on_event(&AgentEvent::RunStarted {
            input: "test".to_string(),
            timestamp: std::time::Instant::now(),
        });
        collector.on_event(&AgentEvent::RunCompleted {
            output: "done".to_string(),
            duration: std::time::Duration::from_secs(1),
        });

        assert_eq!(collector.len(), 2);
        assert_eq!(collector.events().len(), 2);
        assert!(collector.has_event("run_started"));
        assert!(collector.has_event("run_completed"));
        assert!(!collector.has_event("run_failed"));
        assert_eq!(collector.count_event("run_started"), 1);
        assert_eq!(collector.count_event("run_failed"), 0);
        assert_eq!(
            collector.event_types(),
            vec!["run_started", "run_completed"]
        );

        let clone = collector.clone();
        clone.clear();
        assert!(collector.is_empty());
        assert_eq!(collector.len(), 0);
    }
}