aether_core/testing/
utils.rs1use std::error::Error;
2use std::sync::{Arc, Mutex};
3use std::time::Duration;
4
5use futures::future::join_all;
6
7use crate::core::agent;
8use crate::events::{AgentMessage, UserMessage};
9use crate::mcp::McpSpawnResult;
10use crate::mcp::mcp;
11use crate::testing::FakeMcpServer;
12use crate::testing::fake_mcp::fake_mcp;
13use llm::{Context, LlmError, LlmResponse};
14
15use llm::testing::FakeLlmProvider;
16
17pub fn test_agent() -> TestAgentBuilder {
18 TestAgentBuilder::new()
19}
20
21pub struct TestAgentResult {
23 pub messages: Vec<AgentMessage>,
24 pub captured_contexts: Arc<Mutex<Vec<Context>>>,
25}
26
27pub struct TestAgentBuilder {
28 messages: Vec<UserMessage>,
29 responses: Vec<Vec<Result<LlmResponse, LlmError>>>,
30 timeout: Option<Duration>,
31 max_auto_continues: Option<u32>,
32}
33
34impl Default for TestAgentBuilder {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40impl TestAgentBuilder {
41 pub fn new() -> Self {
42 Self { messages: Vec::new(), responses: Vec::new(), timeout: None, max_auto_continues: None }
43 }
44
45 pub fn user_messages(mut self, user_messages: Vec<UserMessage>) -> Self {
46 self.messages = user_messages;
47 self
48 }
49
50 pub fn llm_responses(mut self, llm_responses: &[Vec<LlmResponse>]) -> Self {
51 self.responses = llm_responses.iter().map(|turn| turn.iter().cloned().map(Ok).collect()).collect();
52 self
53 }
54
55 pub fn llm_result_responses(mut self, llm_responses: &[Vec<Result<LlmResponse, LlmError>>]) -> Self {
56 self.responses = Vec::from(llm_responses);
57 self
58 }
59
60 pub fn tool_timeout(mut self, timeout: Duration) -> Self {
61 self.timeout = Some(timeout);
62 self
63 }
64
65 pub fn max_auto_continues(mut self, max: u32) -> Self {
66 self.max_auto_continues = Some(max);
67 self
68 }
69
70 pub async fn run(self) -> Result<Vec<AgentMessage>, Box<dyn Error>> {
71 let result = self.run_with_context().await?;
72 Ok(result.messages)
73 }
74
75 pub async fn run_with_context(self) -> Result<TestAgentResult, Box<dyn Error>> {
80 let llm = FakeLlmProvider::from_results(self.responses);
81 let captured_contexts = llm.captured_contexts();
82
83 let McpSpawnResult {
84 tool_definitions,
85 instructions: _,
86 server_statuses: _,
87 command_tx: mcp_tx,
88 event_rx: _,
89 handle: _mcp_handle,
90 } = mcp().with_servers(vec![fake_mcp("test", FakeMcpServer::new()).into()]).spawn().await?;
91
92 let mut builder = agent(llm).tools(mcp_tx, tool_definitions);
93 if let Some(timeout) = self.timeout {
94 builder = builder.tool_timeout(timeout);
95 }
96 if let Some(max) = self.max_auto_continues {
97 builder = builder.max_auto_continues(max);
98 }
99
100 let (tx, mut rx, _handle) = builder.spawn().await?;
101 let futures: Vec<_> = self.messages.into_iter().map(|m| tx.send(m)).collect();
102
103 join_all(futures).await;
104 drop(tx);
105
106 let mut messages = Vec::new();
107 while let Some(message) = rx.recv().await {
108 messages.push(message.clone());
109 if matches!(message, AgentMessage::Done) {
110 break;
111 }
112 }
113
114 Ok(TestAgentResult { messages, captured_contexts })
115 }
116}