Skip to main content

aether_core/testing/
utils.rs

1use std::error::Error;
2use std::sync::{Arc, Mutex};
3use std::time::Duration;
4
5use futures::future::join_all;
6
7use crate::core::{RetryConfig, agent};
8use crate::events::{AgentMessage, UserMessage};
9use crate::mcp::McpSpawnResult;
10use crate::mcp::mcp;
11use crate::testing::FakeMcpServer;
12use crate::testing::fake_mcp::fake_mcp;
13use llm::{Context, LlmError, LlmResponse};
14
15use llm::testing::FakeLlmProvider;
16
17pub fn test_agent() -> TestAgentBuilder {
18    TestAgentBuilder::new()
19}
20
/// Result of running a test agent, including messages and captured contexts.
pub struct TestAgentResult {
    /// Agent output collected by the run, up to and including the first
    /// `AgentMessage::Done` (or until the agent's output channel closed).
    pub messages: Vec<AgentMessage>,
    /// Contexts recorded by the fake LLM provider during the run (obtained from
    /// `FakeLlmProvider::captured_contexts`). Presumably one entry per LLM
    /// call — confirm against `FakeLlmProvider`.
    pub captured_contexts: Arc<Mutex<Vec<Context>>>,
}
26
/// Builder for a one-shot, scripted agent test run.
///
/// Collects user messages, scripted LLM responses, and optional agent
/// settings; `run` / `run_with_context` then spawn the agent and drive it.
pub struct TestAgentBuilder {
    /// User messages sent to the spawned agent.
    messages: Vec<UserMessage>,
    /// Scripted LLM results handed to `FakeLlmProvider::from_results`.
    responses: Vec<Vec<Result<LlmResponse, LlmError>>>,
    /// Tool timeout override; `None` leaves the agent builder's setting untouched.
    timeout: Option<Duration>,
    /// Auto-continue limit override; `None` leaves the agent builder's setting untouched.
    max_auto_continues: Option<u32>,
    /// Retry policy; `None` means `RetryConfig::disabled()` is used for the run.
    retry_config: Option<RetryConfig>,
}
34
35impl Default for TestAgentBuilder {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl TestAgentBuilder {
42    pub fn new() -> Self {
43        Self {
44            messages: Vec::new(),
45            responses: Vec::new(),
46            timeout: None,
47            max_auto_continues: None,
48            retry_config: None,
49        }
50    }
51
52    pub fn user_messages(mut self, user_messages: Vec<UserMessage>) -> Self {
53        self.messages = user_messages;
54        self
55    }
56
57    pub fn llm_responses(mut self, llm_responses: &[Vec<LlmResponse>]) -> Self {
58        self.responses = llm_responses.iter().map(|turn| turn.iter().cloned().map(Ok).collect()).collect();
59        self
60    }
61
62    pub fn llm_result_responses(mut self, llm_responses: &[Vec<Result<LlmResponse, LlmError>>]) -> Self {
63        self.responses = Vec::from(llm_responses);
64        self
65    }
66
67    pub fn tool_timeout(mut self, timeout: Duration) -> Self {
68        self.timeout = Some(timeout);
69        self
70    }
71
72    pub fn max_auto_continues(mut self, max: u32) -> Self {
73        self.max_auto_continues = Some(max);
74        self
75    }
76
77    pub fn retry_config(mut self, config: RetryConfig) -> Self {
78        self.retry_config = Some(config);
79        self
80    }
81
82    pub async fn run(self) -> Result<Vec<AgentMessage>, Box<dyn Error>> {
83        let result = self.run_with_context().await?;
84        Ok(result.messages)
85    }
86
87    /// Runs the test agent and returns both messages and captured contexts.
88    ///
89    /// Use this when you need to verify what context was passed to the LLM,
90    /// for example when testing that file attachments are properly formatted.
91    pub async fn run_with_context(self) -> Result<TestAgentResult, Box<dyn Error>> {
92        let llm = FakeLlmProvider::from_results(self.responses);
93        let captured_contexts = llm.captured_contexts();
94
95        let McpSpawnResult {
96            tool_definitions,
97            instructions: _,
98            server_statuses: _,
99            command_tx: mcp_tx,
100            event_rx: _,
101            handle: _mcp_handle,
102        } = mcp().with_servers(vec![fake_mcp("test", FakeMcpServer::new())]).spawn().await?;
103
104        let mut builder = agent(llm).tools(mcp_tx, tool_definitions);
105        if let Some(timeout) = self.timeout {
106            builder = builder.tool_timeout(timeout);
107        }
108        if let Some(max) = self.max_auto_continues {
109            builder = builder.max_auto_continues(max);
110        }
111        if let Some(retry) = self.retry_config {
112            builder = builder.retry(retry);
113        } else {
114            builder = builder.retry(RetryConfig::disabled());
115        }
116
117        let (tx, mut rx, _handle) = builder.spawn().await?;
118        let futures: Vec<_> = self.messages.into_iter().map(|m| tx.send(m)).collect();
119
120        join_all(futures).await;
121        drop(tx);
122
123        let mut messages = Vec::new();
124        while let Some(message) = rx.recv().await {
125            messages.push(message.clone());
126            if matches!(message, AgentMessage::Done) {
127                break;
128            }
129        }
130
131        Ok(TestAgentResult { messages, captured_contexts })
132    }
133}