use crate::agent::{Agent, AgentEvent};
use crate::error::{AgentError, ReactError, Result};
use futures::future::BoxFuture;
use futures::stream;
use futures::stream::BoxStream;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
/// Scripted in-memory [`Agent`] for tests.
///
/// Responses queued with [`MockAgent::with_response`] / [`MockAgent::with_responses`]
/// are handed out in FIFO order; once the queue is empty every call answers
/// `"mock agent response"`. Every task or message the agent receives is recorded in
/// a call log that can be inspected with [`MockAgent::calls`] and friends.
pub struct MockAgent {
    /// Name reported by `Agent::name`.
    name: String,
    /// Model identifier reported by `Agent::model_name` (`"mock-model"` unless overridden).
    model_name: String,
    /// System prompt reported by `Agent::system_prompt`.
    system_prompt: String,
    /// Scripted replies, consumed front to back.
    responses: Arc<Mutex<VecDeque<String>>>,
    /// Every task/message received, in arrival order.
    calls: Arc<Mutex<Vec<String>>>,
}

impl MockAgent {
    /// Creates a mock named `name` with model `"mock-model"`, the prompt
    /// `"You are a mock agent"`, no scripted responses and an empty call log.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            model_name: "mock-model".to_string(),
            system_prompt: "You are a mock agent".to_string(),
            responses: Arc::new(Mutex::new(VecDeque::new())),
            calls: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Overrides the model name reported by the agent.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model_name = model.into();
        self
    }

    /// Overrides the system prompt reported by the agent.
    #[must_use]
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = prompt.into();
        self
    }

    /// Appends one scripted response to the back of the queue.
    #[must_use]
    pub fn with_response(self, text: impl Into<String>) -> Self {
        self.responses.lock().unwrap().push_back(text.into());
        self
    }

    /// Appends several scripted responses to the back of the queue, preserving
    /// their iteration order.
    #[must_use]
    pub fn with_responses(self, texts: impl IntoIterator<Item = impl Into<String>>) -> Self {
        let incoming = texts.into_iter().map(Into::<String>::into);
        self.responses.lock().unwrap().extend(incoming);
        self
    }

    /// Number of tasks/messages received since construction or the last reset.
    pub fn call_count(&self) -> usize {
        self.calls.lock().unwrap().len()
    }

    /// Snapshot (a clone) of every task/message received, oldest first.
    pub fn calls(&self) -> Vec<String> {
        self.calls.lock().unwrap().clone()
    }

    /// The most recent task/message, or `None` if nothing has been received.
    pub fn last_task(&self) -> Option<String> {
        self.calls.lock().unwrap().last().cloned()
    }

    /// Forgets all recorded calls. Queued responses are left untouched.
    pub fn reset_calls(&self) {
        self.calls.lock().unwrap().clear();
    }

    /// Pops the next scripted response, falling back to `"mock agent response"`
    /// once the queue is exhausted.
    fn next_response(&self) -> String {
        self.responses
            .lock()
            .unwrap()
            .pop_front()
            .unwrap_or_else(|| "mock agent response".to_string())
    }
}
impl Agent for MockAgent {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn model_name(&self) -> &str {
        self.model_name.as_str()
    }

    fn system_prompt(&self) -> &str {
        self.system_prompt.as_str()
    }

    /// Logs `task` in the call record, then replies with the next scripted response
    /// (or the stock fallback once the queue is drained). Never returns an error.
    fn execute<'a>(&'a self, task: &'a str) -> BoxFuture<'a, Result<String>> {
        Box::pin(async move {
            {
                let mut log = self.calls.lock().unwrap();
                log.push(String::from(task));
            }
            let reply = self.next_response();
            Ok(reply)
        })
    }

    /// Runs `execute` and wraps its reply in a single-event stream containing
    /// `AgentEvent::FinalAnswer`.
    fn execute_stream<'a>(
        &'a self,
        task: &'a str,
    ) -> BoxFuture<'a, Result<BoxStream<'a, Result<AgentEvent>>>> {
        Box::pin(async move {
            let text = self.execute(task).await?;
            let events: BoxStream<'a, Result<AgentEvent>> =
                Box::pin(stream::once(async move { Ok(AgentEvent::FinalAnswer(text)) }));
            Ok(events)
        })
    }

    /// For the mock, chatting behaves exactly like `execute`: the message is logged
    /// and answered from the same response queue.
    fn chat<'a>(&'a self, message: &'a str) -> BoxFuture<'a, Result<String>> {
        self.execute(message)
    }

    /// Runs `chat` and wraps its reply in a single-event stream containing
    /// `AgentEvent::FinalAnswer`.
    fn chat_stream<'a>(
        &'a self,
        message: &'a str,
    ) -> BoxFuture<'a, Result<BoxStream<'a, Result<AgentEvent>>>> {
        Box::pin(async move {
            let text = self.chat(message).await?;
            let events: BoxStream<'a, Result<AgentEvent>> =
                Box::pin(stream::once(async move { Ok(AgentEvent::FinalAnswer(text)) }));
            Ok(events)
        })
    }

    /// Clears the recorded call log. Queued responses are not touched.
    fn reset(&self) {
        let mut log = self.calls.lock().unwrap();
        log.clear();
    }
}
/// Mock [`Agent`] whose `execute` and `chat` always fail.
///
/// Each call is recorded before the error is returned, so tests can assert how
/// many times (and with what input) the agent was invoked.
pub struct FailingMockAgent {
    /// Name reported by `Agent::name`.
    name: String,
    /// Message wrapped in the `AgentError::InitializationFailed` returned by every call.
    error_message: String,
    /// Every task/message received, in arrival order.
    calls: Arc<Mutex<Vec<String>>>,
}

impl FailingMockAgent {
    /// Creates a failing mock named `name` whose calls fail with `error_message`.
    pub fn new(name: impl Into<String>, error_message: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            error_message: error_message.into(),
            calls: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Number of tasks/messages received since construction or the last `reset`.
    pub fn call_count(&self) -> usize {
        self.calls.lock().unwrap().len()
    }

    /// Snapshot (a clone) of every task/message received, oldest first.
    /// Mirrors `MockAgent::calls`.
    pub fn calls(&self) -> Vec<String> {
        self.calls.lock().unwrap().clone()
    }

    /// The most recent task/message, or `None` if the agent has not been called.
    /// Mirrors `MockAgent::last_task`.
    pub fn last_task(&self) -> Option<String> {
        self.calls.lock().unwrap().last().cloned()
    }
}
impl Agent for FailingMockAgent {
fn name(&self) -> &str {
&self.name
}
fn model_name(&self) -> &str {
"mock-model"
}
fn system_prompt(&self) -> &str {
"failing mock agent"
}
fn execute<'a>(&'a self, task: &'a str) -> BoxFuture<'a, Result<String>> {
Box::pin(async move {
self.calls.lock().unwrap().push(task.to_string());
Err(ReactError::Agent(AgentError::InitializationFailed(
self.error_message.clone(),
)))
})
}
fn execute_stream<'a>(
&'a self,
task: &'a str,
) -> BoxFuture<'a, Result<BoxStream<'a, Result<AgentEvent>>>> {
Box::pin(async move {
let err = self.execute(task).await.unwrap_err();
let event_stream = stream::once(async move { Err(err) });
Ok(Box::pin(event_stream) as BoxStream<'a, Result<AgentEvent>>)
})
}
fn chat<'a>(&'a self, message: &'a str) -> BoxFuture<'a, Result<String>> {
Box::pin(async move { self.execute(message).await })
}
fn reset(&self) {
self.calls.lock().unwrap().clear();
}
}