mod types;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use serde_json::json;
use crate::error::{Result, RustAgentsError};
use crate::harness::events::{AgentEvent, EventSink, RecordingListener};
use crate::harness::model::{ChatModel, ModelRequest, ModelResponse};
use crate::harness::tool::{Tool, ToolCall, ToolResult, ToolSchema};
pub use types::*;
impl ScriptedModel {
    /// Creates a model that hands out `responses` front-to-back, one per `invoke`.
    ///
    /// The request log starts empty.
    pub fn new(responses: Vec<ModelResponse>) -> Self {
        let queue = Mutex::new(VecDeque::from(responses));
        let received = Mutex::new(Vec::new());
        Self { queue, received }
    }

    /// Convenience constructor: wraps each text in `ModelResponse::assistant`,
    /// preserving order, and scripts the resulting responses.
    pub fn replies<S: AsRef<str>>(texts: Vec<S>) -> Self {
        let mut responses = Vec::new();
        for text in texts {
            responses.push(ModelResponse::assistant(text.as_ref()));
        }
        Self::new(responses)
    }

    /// Returns a snapshot (clone) of every request passed to `invoke` so far, in call order.
    ///
    /// # Panics
    ///
    /// Panics if the request-log mutex has been poisoned.
    pub fn requests(&self) -> Vec<ModelRequest> {
        let log = self
            .received
            .lock()
            .expect("ScriptedModel received lock poisoned");
        log.to_vec()
    }
}
#[async_trait]
impl<State: Send + Sync> ChatModel<State> for ScriptedModel {
    /// Logs `request`, then returns the next scripted response.
    ///
    /// The request is recorded before the queue is consulted, so it shows up in
    /// `requests()` even when the script has run dry.
    ///
    /// # Errors
    ///
    /// Returns `RustAgentsError::Model` once every scripted response has been consumed.
    ///
    /// # Panics
    ///
    /// Panics if either internal mutex has been poisoned.
    async fn invoke(&self, _state: &State, request: ModelRequest) -> Result<ModelResponse> {
        self.received
            .lock()
            .expect("ScriptedModel received lock poisoned")
            .push(request);

        // The lock guard is a temporary, released at the end of this statement.
        let next = self
            .queue
            .lock()
            .expect("ScriptedModel queue lock poisoned")
            .pop_front();

        match next {
            Some(response) => Ok(response),
            None => Err(RustAgentsError::Model(
                "ScriptedModel: response queue is exhausted; no more scripted responses"
                    .to_string(),
            )),
        }
    }
}
impl FakeTool {
pub fn new(name: impl Into<String>) -> Self {
let name = name.into();
Self {
tool_description: format!("Fake tool: {name}"),
tool_name: name,
behavior: FakeToolBehavior::Return(String::new()),
received: Mutex::new(Vec::new()),
}
}
pub fn returning(name: impl Into<String>, content: impl Into<String>) -> Self {
let name = name.into();
Self {
tool_description: format!("Fake tool: {name}"),
tool_name: name,
behavior: FakeToolBehavior::Return(content.into()),
received: Mutex::new(Vec::new()),
}
}
pub fn failing(name: impl Into<String>, message: impl Into<String>) -> Self {
let name = name.into();
Self {
tool_description: format!("Fake tool: {name}"),
tool_name: name,
behavior: FakeToolBehavior::Fail(message.into()),
received: Mutex::new(Vec::new()),
}
}
pub fn calls(&self) -> Vec<ToolCall> {
self.received
.lock()
.expect("FakeTool received lock poisoned")
.clone()
}
}
#[async_trait]
impl<State: Send + Sync> Tool<State> for FakeTool {
    /// The name this tool was constructed with.
    fn name(&self) -> &str {
        self.tool_name.as_str()
    }

    /// The generated description (`"Fake tool: <name>"`).
    fn description(&self) -> &str {
        self.tool_description.as_str()
    }

    /// Advertises a parameter schema that is an object with no properties and no
    /// required fields.
    fn schema(&self) -> ToolSchema {
        let parameters = json!({ "type": "object", "properties": {}, "required": [] });
        ToolSchema::new(
            self.tool_name.clone(),
            self.tool_description.clone(),
            parameters,
        )
    }

    /// Records a clone of `call`, then either returns the configured text (echoing the
    /// call's id and name) or fails with `RustAgentsError::Tool`.
    ///
    /// # Panics
    ///
    /// Panics if the call-log mutex has been poisoned.
    async fn call(&self, _state: &State, call: ToolCall) -> Result<ToolResult> {
        self.received
            .lock()
            .expect("FakeTool received lock poisoned")
            .push(call.clone());

        match &self.behavior {
            FakeToolBehavior::Fail(message) => Err(RustAgentsError::Tool(message.clone())),
            FakeToolBehavior::Return(content) => {
                let text = content.clone();
                Ok(ToolResult::text(call.id, call.name, text))
            }
        }
    }
}
impl DeterministicClock {
    /// Creates a clock that reads `start_millis` until it is advanced.
    pub fn new(start_millis: u64) -> Self {
        let millis = Mutex::new(start_millis);
        Self { millis }
    }

    /// Returns the clock's current reading in milliseconds.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub fn now_millis(&self) -> u64 {
        let current = self
            .millis
            .lock()
            .expect("DeterministicClock lock poisoned");
        *current
    }

    /// Moves the clock forward by `ms` milliseconds.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned. Overflow follows ordinary
    /// `u64 +=` semantics: it panics when overflow checks are on and wraps otherwise.
    pub fn advance(&self, ms: u64) {
        let mut current = self
            .millis
            .lock()
            .expect("DeterministicClock lock poisoned");
        *current += ms;
    }
}
impl Default for DeterministicClock {
fn default() -> Self {
Self::new(0)
}
}
impl DeterministicIds {
    /// Creates a generator that yields `"<prefix>-0"`, `"<prefix>-1"`, … in sequence.
    pub fn new(prefix: impl Into<String>) -> Self {
        let prefix = prefix.into();
        let counter = Mutex::new(0);
        Self { prefix, counter }
    }

    /// Returns the next id and bumps the counter; both happen under one lock.
    ///
    /// # Panics
    ///
    /// Panics if the counter mutex has been poisoned.
    pub fn next(&self) -> String {
        let mut counter = self.counter.lock().expect("DeterministicIds lock poisoned");
        let issued = *counter;
        *counter += 1;
        format!("{}-{issued}", self.prefix)
    }
}
impl EventRecorder {
    /// Creates a fresh `EventSink` with an internal `RecordingListener` subscribed to it.
    pub fn new() -> Self {
        let listener = Arc::new(RecordingListener::new());
        let sink = EventSink::new();
        // `.clone()` rather than `Arc::clone(&listener)`, so the call site can still
        // unsize-coerce if `subscribe` expects a trait-object `Arc`.
        sink.subscribe(listener.clone());
        Self { listener, sink }
    }

    /// Returns a clone of the sink. Emit events through it to have them recorded.
    /// Presumably the clone shares subscribers with the original; confirm in `EventSink`.
    pub fn sink(&self) -> EventSink {
        let handle = self.sink.clone();
        handle
    }

    /// Returns the recorded events, unwrapped from the listener's records.
    pub fn events(&self) -> Vec<AgentEvent> {
        let records = self.listener.events();
        records.into_iter().map(|record| record.event).collect()
    }

    /// Returns `kind()` of each recorded event as an owned string, in recorded order.
    pub fn kinds(&self) -> Vec<String> {
        let mut kinds = Vec::new();
        for record in self.listener.events() {
            kinds.push(record.event.kind().to_string());
        }
        kinds
    }
}
impl Default for EventRecorder {
fn default() -> Self {
Self::new()
}
}
impl Trajectory {
    /// Builds a trajectory over an already-captured event sequence.
    pub fn from_events(events: Vec<AgentEvent>) -> Self {
        Self { events }
    }

    /// Returns `true` if at least one `ToolStarted` event names the tool `name`.
    pub fn tool_was_called(&self, name: &str) -> bool {
        self.tool_call_count(name) > 0
    }

    /// Asserts that the tool `name` was started at least once.
    ///
    /// # Panics
    ///
    /// Panics if no `ToolStarted` event for `name` is present.
    pub fn assert_tool_called(&self, name: &str) {
        assert!(
            self.tool_was_called(name),
            "Trajectory: expected tool '{name}' to have been called, but it was not found in the \
             event sequence"
        );
    }

    /// Counts the `ToolStarted` events whose `tool_name` equals `name`.
    ///
    /// Only start events are counted, so a call that started but never completed
    /// still counts.
    pub fn tool_call_count(&self, name: &str) -> usize {
        self.events
            .iter()
            .filter(|e| matches!(e, AgentEvent::ToolStarted { tool_name, .. } if tool_name == name))
            .count()
    }

    /// Counts the `ModelStarted` events in the trajectory.
    pub fn model_call_count(&self) -> usize {
        self.events
            .iter()
            .filter(|e| matches!(e, AgentEvent::ModelStarted { .. }))
            .count()
    }

    /// Asserts that exactly `n` `ModelStarted` events were recorded.
    ///
    /// # Panics
    ///
    /// Panics if the count differs from `n`.
    pub fn assert_model_called_times(&self, n: usize) {
        let actual = self.model_call_count();
        assert_eq!(
            actual, n,
            "Trajectory: expected {n} model call(s) but found {actual}"
        );
    }

    /// Checks that `labels` occur, in the given order, as a subsequence of the events.
    /// Other events may appear between matches.
    ///
    /// A label matches an event when it equals the event's `kind()`, the `tool_name`
    /// of a `ToolStarted`/`ToolCompleted` event, or the `route` of a `RouteSelected`
    /// event. Each label is searched for strictly after the event matched by the
    /// previous label. Because starts and completions both match a tool name,
    /// `["search", "search"]` is satisfied by the start and completion of a *single*
    /// `search` call.
    ///
    /// # Errors
    ///
    /// Returns `RustAgentsError::Validation` naming the first label that could not be
    /// found and, when there is one, the label matched just before it.
    pub fn assert_order(&self, labels: &[&str]) -> Result<()> {
        // One shared iterator: `any` consumes events up to and including each match,
        // so the search for the next label resumes right after it.
        let mut remaining = self.events.iter();
        let mut previous: Option<&str> = None;
        for &label in labels {
            if !remaining.any(|e| Self::event_matches_label(e, label)) {
                let message = match previous {
                    Some(previous) => format!(
                        "Trajectory: expected label '{label}' in order but it was not found after \
                         the previous matched label '{previous}'"
                    ),
                    None => format!(
                        "Trajectory: expected label '{label}' but it was not found in the event \
                         sequence"
                    ),
                };
                return Err(RustAgentsError::Validation(message));
            }
            previous = Some(label);
        }
        Ok(())
    }

    /// Returns `true` if any `RunCompleted` event was recorded.
    pub fn completed(&self) -> bool {
        self.events
            .iter()
            .any(|e| matches!(e, AgentEvent::RunCompleted { .. }))
    }

    /// Asserts that a `RunCompleted` event was recorded.
    ///
    /// # Panics
    ///
    /// Panics if no `RunCompleted` event is present.
    pub fn assert_completed(&self) {
        assert!(
            self.completed(),
            "Trajectory: expected RunCompleted event but none was found"
        );
    }

    /// Returns `true` if any `RunFailed` event was recorded.
    pub fn failed(&self) -> bool {
        self.events
            .iter()
            .any(|e| matches!(e, AgentEvent::RunFailed { .. }))
    }

    /// Label-matching rule used by `assert_order`. The event's kind string is checked
    /// first, then the tool name or route carried by the event, if any.
    fn event_matches_label(event: &AgentEvent, label: &str) -> bool {
        if event.kind() == label {
            return true;
        }
        match event {
            AgentEvent::ToolStarted { tool_name, .. }
            | AgentEvent::ToolCompleted { tool_name, .. } => tool_name == label,
            AgentEvent::RouteSelected { route } => route == label,
            // Other variants carry no name-like payload, so only their kind can match.
            _ => false,
        }
    }
}
#[cfg(test)]
mod test;