Skip to main content

aether_core/testing/
utils.rs

1use std::collections::BTreeMap;
2use std::error::Error;
3use std::sync::{Arc, Mutex};
4use std::time::Duration;
5
6use futures::future::join_all;
7
8use crate::core::{RetryConfig, agent};
9use crate::events::{AgentMessage, Command};
10use crate::mcp::mcp;
11use crate::testing::FakeMcpServer;
12use crate::testing::fake_mcp::fake_mcp;
13use llm::{Context, LlmError, LlmResponse};
14
15use llm::testing::FakeLlmProvider;
16
/// Builds an owned instruction map from borrowed `(key, value)` string pairs.
///
/// Both halves of every pair are copied into freshly allocated `String`s.
/// When the same key appears more than once, the value of the last
/// occurrence is the one kept.
pub fn mcp_instructions(entries: &[(&str, &str)]) -> BTreeMap<String, String> {
    let mut instructions = BTreeMap::new();
    for &(key, value) in entries {
        instructions.insert(key.to_owned(), value.to_owned());
    }
    instructions
}
20
21pub fn test_agent() -> TestAgentBuilder {
22    TestAgentBuilder::new()
23}
24
25/// Result of running a test agent, including messages and captured contexts.
26pub struct TestAgentResult {
27    pub messages: Vec<AgentMessage>,
28    pub captured_contexts: Arc<Mutex<Vec<Context>>>,
29}
30
31pub struct TestAgentBuilder {
32    messages: Vec<Command>,
33    responses: Vec<Vec<Result<LlmResponse, LlmError>>>,
34    timeout: Option<Duration>,
35    max_auto_continues: Option<u32>,
36    retry_config: Option<RetryConfig>,
37}
38
39impl Default for TestAgentBuilder {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl TestAgentBuilder {
46    pub fn new() -> Self {
47        Self {
48            messages: Vec::new(),
49            responses: Vec::new(),
50            timeout: None,
51            max_auto_continues: None,
52            retry_config: None,
53        }
54    }
55
56    pub fn user_messages(mut self, user_messages: Vec<Command>) -> Self {
57        self.messages = user_messages;
58        self
59    }
60
61    pub fn llm_responses(mut self, llm_responses: &[Vec<LlmResponse>]) -> Self {
62        self.responses = llm_responses.iter().map(|turn| turn.iter().cloned().map(Ok).collect()).collect();
63        self
64    }
65
66    pub fn llm_result_responses(mut self, llm_responses: &[Vec<Result<LlmResponse, LlmError>>]) -> Self {
67        self.responses = Vec::from(llm_responses);
68        self
69    }
70
71    pub fn tool_timeout(mut self, timeout: Duration) -> Self {
72        self.timeout = Some(timeout);
73        self
74    }
75
76    pub fn max_auto_continues(mut self, max: u32) -> Self {
77        self.max_auto_continues = Some(max);
78        self
79    }
80
81    pub fn retry_config(mut self, config: RetryConfig) -> Self {
82        self.retry_config = Some(config);
83        self
84    }
85
86    pub async fn run(self) -> Result<Vec<AgentMessage>, Box<dyn Error>> {
87        let result = self.run_with_context().await?;
88        Ok(result.messages)
89    }
90
91    /// Runs the test agent and returns both messages and captured contexts.
92    ///
93    /// Use this when you need to verify what context was passed to the LLM,
94    /// for example when testing that file attachments are properly formatted.
95    pub async fn run_with_context(self) -> Result<TestAgentResult, Box<dyn Error>> {
96        let llm = FakeLlmProvider::from_results(self.responses);
97        let captured_contexts = llm.captured_contexts();
98
99        let mut spawn = mcp("/workspace").with_servers(vec![fake_mcp("test", FakeMcpServer::new())]).spawn().await?;
100        let snapshot = spawn.block_until_ready().await.expect("bootstrap completes");
101
102        let mut builder = agent(llm).tools(spawn.command_tx, snapshot.tool_definitions);
103        if let Some(timeout) = self.timeout {
104            builder = builder.tool_timeout(timeout);
105        }
106        if let Some(max) = self.max_auto_continues {
107            builder = builder.max_auto_continues(max);
108        }
109        if let Some(retry) = self.retry_config {
110            builder = builder.retry(retry);
111        } else {
112            builder = builder.retry(RetryConfig::disabled());
113        }
114
115        let (tx, mut rx, _handle) = builder.spawn().await?;
116        let futures: Vec<_> = self.messages.into_iter().map(|m| tx.send(m)).collect();
117
118        join_all(futures).await;
119        drop(tx);
120
121        let mut messages = Vec::new();
122        while let Some(message) = rx.recv().await {
123            messages.push(message.clone());
124            if matches!(message, AgentMessage::Done) {
125                break;
126            }
127        }
128
129        Ok(TestAgentResult { messages, captured_contexts })
130    }
131}