agent_stream_kit/
test_utils.rs1use std::sync::Arc;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use tokio::{
6 sync::{Mutex as AsyncMutex, mpsc},
7 time::timeout,
8};
9
10use crate::{ASKit, AgentContext, AgentData, AgentError, AgentSpec, AgentValue, AsAgent};
11use askit_macros::askit_agent;
12
/// One event recorded by a [`TestProbeAgent`]: the context and value it received in
/// `process`, in that order.
pub type ProbeEvent = (AgentContext, AgentValue);
14
/// How long the convenience receive helpers (`ProbeReceiver::recv`, `recv_probe`) wait
/// for an event before giving up.
pub const DEFAULT_PROBE_TIMEOUT: Duration = Duration::from_millis(1_000);
16
17#[derive(Clone)]
18pub struct ProbeReceiver(Arc<AsyncMutex<mpsc::UnboundedReceiver<ProbeEvent>>>);
19
20impl ProbeReceiver {
21 pub async fn recv_with_timeout(&self, duration: Duration) -> Result<ProbeEvent, AgentError> {
22 let mut rx = self.0.lock().await;
23 timeout(duration, rx.recv())
24 .await
25 .map_err(|_| AgentError::SendMessageFailed("probe receive timed out".into()))?
26 .ok_or_else(|| AgentError::SendMessageFailed("probe channel closed".into()))
27 }
28
29 pub async fn recv(&self) -> Result<ProbeEvent, AgentError> {
30 self.recv_with_timeout(DEFAULT_PROBE_TIMEOUT).await
31 }
32}
33
/// Test agent that records every value it receives so tests can observe it.
///
/// It accepts input on any pin (`inputs = ["*"]`) and has no outputs. Each call to
/// `process` pushes a `(ctx, value)` pair onto an internal unbounded channel. Tests read
/// those events back through a [`ProbeReceiver`].
#[askit_agent(
    title = "TestProbeAgent",
    category = "Test",
    inputs = ["*"],
    outputs = []
)]
pub struct TestProbeAgent {
    data: AgentData,
    /// Sending half of the probe channel; written to by `process`.
    tx: mpsc::UnboundedSender<ProbeEvent>,
    /// Shared receiving half. It is kept on the agent so the channel cannot close while
    /// the agent exists.
    rx: ProbeReceiver,
}
45
46impl TestProbeAgent {
47 pub async fn recv_with_timeout(&self, duration: Duration) -> Result<ProbeEvent, AgentError> {
49 self.rx.recv_with_timeout(duration).await
50 }
51
52 pub fn probe_receiver(&self) -> ProbeReceiver {
54 self.rx.clone()
55 }
56}
57
58pub async fn probe_receiver(askit: &ASKit, agent_id: &str) -> Result<ProbeReceiver, AgentError> {
60 let probe = askit
61 .get_agent(agent_id)
62 .ok_or_else(|| AgentError::AgentNotFound(agent_id.to_string()))?;
63 let probe_guard = probe.lock().await;
64 let probe_agent = probe_guard
65 .as_agent::<TestProbeAgent>()
66 .ok_or_else(|| AgentError::AgentNotFound(agent_id.to_string()))?;
67 Ok(probe_agent.probe_receiver())
68}
69
70pub async fn recv_probe_with_timeout(
72 probe_rec: &ProbeReceiver,
73 duration: Duration,
74) -> Result<ProbeEvent, AgentError> {
75 probe_rec.recv_with_timeout(duration).await
76}
77
78pub async fn recv_probe(probe_rec: &ProbeReceiver) -> Result<ProbeEvent, AgentError> {
80 probe_rec.recv().await
81}
82
83#[async_trait]
84impl AsAgent for TestProbeAgent {
85 fn new(askit: crate::ASKit, id: String, spec: AgentSpec) -> Result<Self, AgentError> {
86 let (tx, rx) = mpsc::unbounded_channel();
87 let rx = ProbeReceiver(Arc::new(AsyncMutex::new(rx)));
88
89 Ok(Self {
90 data: AgentData::new(askit, id, spec),
91 tx,
92 rx,
93 })
94 }
95
96 async fn process(
97 &mut self,
98 ctx: AgentContext,
99 _pin: String,
100 value: AgentValue,
101 ) -> Result<(), AgentError> {
102 let _ = self.tx.send((ctx, value));
104 Ok(())
105 }
106}
107
#[cfg(test)]
mod tests {
    // The parent module already has everything these tests use in scope: `TestProbeAgent`,
    // `ASKit`, `AgentContext`, `AgentError`, `AgentValue`, `AsAgent` and `Duration`. One
    // glob import is enough, and it avoids naming this crate by its external
    // `agent_stream_kit::` path from inside itself.
    use super::*;

    /// Events come out of the probe in the order they were processed.
    #[tokio::test]
    async fn probe_receives_in_order() {
        let askit = ASKit::new();
        let def = TestProbeAgent::agent_definition();
        let mut probe = TestProbeAgent::new(askit, "p1".into(), def.to_spec()).unwrap();

        probe
            .process(AgentContext::new(), "in".into(), AgentValue::integer(1))
            .await
            .unwrap();
        let (_ctx, v1) = probe.probe_receiver().recv().await.unwrap();
        assert_eq!(v1, AgentValue::integer(1));

        probe
            .process(AgentContext::new(), "in".into(), AgentValue::integer(2))
            .await
            .unwrap();
        let (_ctx, v2) = probe.probe_receiver().recv().await.unwrap();
        assert_eq!(v2, AgentValue::integer(2));
    }

    /// Receiving with nothing queued fails with `SendMessageFailed` once the timeout elapses.
    #[tokio::test]
    async fn probe_times_out() {
        let askit = ASKit::new();
        let def = TestProbeAgent::agent_definition();
        let probe = TestProbeAgent::new(askit, "p1".into(), def.to_spec()).unwrap();
        let err = probe
            .recv_with_timeout(Duration::from_millis(10))
            .await
            .unwrap_err();
        assert!(matches!(err, AgentError::SendMessageFailed(_)));
    }

    /// Cloned receivers share one queue: an event consumed by one clone is gone for the other.
    #[tokio::test]
    async fn probe_receiver_clone_works() {
        let askit = ASKit::new();
        let def = TestProbeAgent::agent_definition();
        let mut probe = TestProbeAgent::new(askit, "p1".into(), def.to_spec()).unwrap();
        let rx1 = probe.probe_receiver();
        let rx2 = probe.probe_receiver();

        probe
            .process(AgentContext::new(), "in".into(), AgentValue::integer(42))
            .await
            .unwrap();

        let (_ctx, v) = rx1.recv().await.unwrap();
        assert_eq!(v, AgentValue::integer(42));

        let err = rx2
            .recv_with_timeout(Duration::from_millis(10))
            .await
            .unwrap_err();
        assert!(matches!(err, AgentError::SendMessageFailed(_)));
    }
}