agent_stream_kit/
test_utils.rs1use std::sync::Arc;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use tokio::{
6 sync::{Mutex as AsyncMutex, mpsc},
7 time::timeout,
8};
9
10use crate::{ASKit, AgentContext, AgentData, AgentError, AgentSpec, AgentValue, AsAgent};
11use askit_macros::askit_agent;
12
13pub type ProbeEvent = (AgentContext, AgentValue);
14
/// How long a receive waits for an event when the caller does not supply a
/// timeout of its own (one second).
pub const DEFAULT_PROBE_TIMEOUT: Duration = Duration::from_millis(1_000);
16
17#[derive(Clone)]
18pub struct ProbeReceiver(Arc<AsyncMutex<mpsc::UnboundedReceiver<ProbeEvent>>>);
19
20impl ProbeReceiver {
21 pub async fn recv_with_timeout(&self, duration: Duration) -> Result<ProbeEvent, AgentError> {
22 let mut rx = self.0.lock().await;
23 timeout(duration, rx.recv())
24 .await
25 .map_err(|_| AgentError::SendMessageFailed("probe receive timed out".into()))?
26 .ok_or_else(|| AgentError::SendMessageFailed("probe channel closed".into()))
27 }
28
29 pub async fn recv(&self) -> Result<ProbeEvent, AgentError> {
30 self.recv_with_timeout(DEFAULT_PROBE_TIMEOUT).await
31 }
32}
33
34#[askit_agent(
35 title = "TestProbeAgent",
36 category = "Test",
37 inputs = ["*"],
38 outputs = []
39)]
40pub struct TestProbeAgent {
41 data: AgentData,
42 tx: mpsc::UnboundedSender<ProbeEvent>,
43 rx: ProbeReceiver,
44}
45
46impl TestProbeAgent {
47 pub async fn recv_with_timeout(&self, duration: Duration) -> Result<ProbeEvent, AgentError> {
49 self.rx.recv_with_timeout(duration).await
50 }
51
52 pub fn probe_receiver(&self) -> ProbeReceiver {
54 self.rx.clone()
55 }
56}
57
58pub async fn probe_receiver(askit: &ASKit, agent_id: &str) -> Result<ProbeReceiver, AgentError> {
60 let probe = askit
61 .get_agent(agent_id)
62 .ok_or_else(|| AgentError::AgentNotFound(agent_id.to_string()))?;
63 let probe_guard = probe.lock().await;
64 let probe_agent = probe_guard
65 .as_agent::<TestProbeAgent>()
66 .ok_or_else(|| AgentError::AgentNotFound(agent_id.to_string()))?;
67 Ok(probe_agent.probe_receiver())
68}
69
70pub async fn recv_probe_with_timeout(
72 probe_rec: &ProbeReceiver,
73 duration: Duration,
74) -> Result<ProbeEvent, AgentError> {
75 probe_rec.recv_with_timeout(duration).await
76}
77
78pub async fn recv_probe(probe_rec: &ProbeReceiver) -> Result<ProbeEvent, AgentError> {
80 probe_rec.recv().await
81}
82
83#[async_trait]
84impl AsAgent for TestProbeAgent {
85 fn new(askit: crate::ASKit, id: String, spec: AgentSpec) -> Result<Self, AgentError> {
86 let (tx, rx) = mpsc::unbounded_channel();
87 let rx = ProbeReceiver(Arc::new(AsyncMutex::new(rx)));
88
89 Ok(Self {
90 data: AgentData::new(askit, id, spec),
91 tx,
92 rx,
93 })
94 }
95
96 async fn process(
97 &mut self,
98 ctx: AgentContext,
99 _pin: String,
100 value: AgentValue,
101 ) -> Result<(), AgentError> {
102 let _ = self.tx.send((ctx, value));
104 Ok(())
105 }
106}
107
#[cfg(test)]
mod tests {
    // The parent module already has every needed name in scope (including
    // `Duration`), so the glob import is sufficient; re-importing the crate
    // under its external name from inside itself is unnecessary.
    use super::*;

    /// Builds a standalone probe agent with id `"p1"` from its own definition.
    fn new_probe() -> TestProbeAgent {
        let askit = ASKit::new();
        let def = TestProbeAgent::agent_definition();
        let spec = AgentSpec::from_def(&def);
        TestProbeAgent::new(askit, "p1".into(), spec).unwrap()
    }

    #[tokio::test]
    async fn probe_receives_in_order() {
        let mut probe = new_probe();

        // Queue both events before reading so ordering is actually exercised.
        for n in [1, 2] {
            probe
                .process(AgentContext::new(), "in".into(), AgentValue::integer(n))
                .await
                .unwrap();
        }

        for expected in [1, 2] {
            let (_ctx, got) = probe.probe_receiver().recv().await.unwrap();
            assert_eq!(got, AgentValue::integer(expected));
        }
    }

    #[tokio::test]
    async fn probe_times_out() {
        let probe = new_probe();

        let err = probe
            .recv_with_timeout(Duration::from_millis(10))
            .await
            .unwrap_err();
        assert!(matches!(err, AgentError::SendMessageFailed(_)));
    }

    #[tokio::test]
    async fn probe_receiver_clone_works() {
        let mut probe = new_probe();
        let rx1 = probe.probe_receiver();
        let rx2 = probe.probe_receiver();

        probe
            .process(AgentContext::new(), "in".into(), AgentValue::integer(42))
            .await
            .unwrap();

        // The first handle consumes the only event...
        let (_ctx, v) = rx1.recv().await.unwrap();
        assert_eq!(v, AgentValue::integer(42));

        // ...so the second handle, which shares the same queue, times out.
        let err = rx2
            .recv_with_timeout(Duration::from_millis(10))
            .await
            .unwrap_err();
        assert!(matches!(err, AgentError::SendMessageFailed(_)));
    }
}