agent_stream_kit/
test_utils.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use tokio::{
6    sync::{Mutex as AsyncMutex, mpsc},
7    time::timeout,
8};
9
10use crate::{ASKit, AgentContext, AgentData, AgentError, AgentSpec, AgentValue, AsAgent};
11use askit_macros::askit_agent;
12
13pub type ProbeEvent = (AgentContext, AgentValue);
14
/// Deadline used by [`ProbeReceiver::recv`] and [`recv_probe`]: one second.
pub const DEFAULT_PROBE_TIMEOUT: Duration = Duration::from_millis(1_000);
16
17#[derive(Clone)]
18pub struct ProbeReceiver(Arc<AsyncMutex<mpsc::UnboundedReceiver<ProbeEvent>>>);
19
20impl ProbeReceiver {
21    pub async fn recv_with_timeout(&self, duration: Duration) -> Result<ProbeEvent, AgentError> {
22        let mut rx = self.0.lock().await;
23        timeout(duration, rx.recv())
24            .await
25            .map_err(|_| AgentError::SendMessageFailed("probe receive timed out".into()))?
26            .ok_or_else(|| AgentError::SendMessageFailed("probe channel closed".into()))
27    }
28
29    pub async fn recv(&self) -> Result<ProbeEvent, AgentError> {
30        self.recv_with_timeout(DEFAULT_PROBE_TIMEOUT).await
31    }
32}
33
34#[askit_agent(
35    title = "TestProbeAgent",
36    category = "Test",
37    inputs = ["*"],
38    outputs = []
39)]
40pub struct TestProbeAgent {
41    data: AgentData,
42    tx: mpsc::UnboundedSender<ProbeEvent>,
43    rx: ProbeReceiver,
44}
45
46impl TestProbeAgent {
47    /// Receive next probe event using the instance's own receiver.
48    pub async fn recv_with_timeout(&self, duration: Duration) -> Result<ProbeEvent, AgentError> {
49        self.rx.recv_with_timeout(duration).await
50    }
51
52    /// Clone the internal receiver so callers can drop agent locks before awaiting.
53    pub fn probe_receiver(&self) -> ProbeReceiver {
54        self.rx.clone()
55    }
56}
57
58/// Helper to fetch the probe receiver for a TestProbeAgent by id.
59pub async fn probe_receiver(askit: &ASKit, agent_id: &str) -> Result<ProbeReceiver, AgentError> {
60    let probe = askit
61        .get_agent(agent_id)
62        .ok_or_else(|| AgentError::AgentNotFound(agent_id.to_string()))?;
63    let probe_guard = probe.lock().await;
64    let probe_agent = probe_guard
65        .as_agent::<TestProbeAgent>()
66        .ok_or_else(|| AgentError::AgentNotFound(agent_id.to_string()))?;
67    Ok(probe_agent.probe_receiver())
68}
69
70/// Await one probe event with timeout on the given receiver.
71pub async fn recv_probe_with_timeout(
72    probe_rec: &ProbeReceiver,
73    duration: Duration,
74) -> Result<ProbeEvent, AgentError> {
75    probe_rec.recv_with_timeout(duration).await
76}
77
78/// Receive one probe event with the default timeout.
79pub async fn recv_probe(probe_rec: &ProbeReceiver) -> Result<ProbeEvent, AgentError> {
80    probe_rec.recv().await
81}
82
83#[async_trait]
84impl AsAgent for TestProbeAgent {
85    fn new(askit: crate::ASKit, id: String, spec: AgentSpec) -> Result<Self, AgentError> {
86        let (tx, rx) = mpsc::unbounded_channel();
87        let rx = ProbeReceiver(Arc::new(AsyncMutex::new(rx)));
88
89        Ok(Self {
90            data: AgentData::new(askit, id, spec),
91            tx,
92            rx,
93        })
94    }
95
96    async fn process(
97        &mut self,
98        ctx: AgentContext,
99        _pin: String,
100        value: AgentValue,
101    ) -> Result<(), AgentError> {
102        // Ignore send failures in tests; probe won't fail the pipeline
103        let _ = self.tx.send((ctx, value));
104        Ok(())
105    }
106}
107
#[cfg(test)]
mod tests {
    // `super::*` already brings in TestProbeAgent plus the parent's imports (ASKit,
    // AgentContext, AgentError, AgentSpec, AgentValue, AsAgent, Duration). Re-importing
    // them through the crate's external name (`agent_stream_kit::...`) from inside the
    // crate is redundant, and only resolves if the crate aliases itself.
    use super::*;

    /// Builds a probe agent directly through `AsAgent::new` for use in tests.
    fn new_probe() -> TestProbeAgent {
        let askit = ASKit::new();
        let def = TestProbeAgent::agent_definition();
        let spec = AgentSpec::from_def(&def);
        TestProbeAgent::new(askit, "p1".into(), spec).unwrap()
    }

    #[tokio::test]
    async fn probe_receives_in_order() {
        let mut probe = new_probe();

        // Queue every event before receiving any. This checks FIFO delivery, not just
        // a single send/receive round-trip at a time.
        for n in 1..=3 {
            probe
                .process(AgentContext::new(), "in".into(), AgentValue::integer(n))
                .await
                .unwrap();
        }

        let rx = probe.probe_receiver();
        for n in 1..=3 {
            let (_ctx, v) = rx.recv().await.unwrap();
            assert_eq!(v, AgentValue::integer(n));
        }
    }

    #[tokio::test]
    async fn probe_times_out() {
        let probe = new_probe();
        let err = probe
            .recv_with_timeout(Duration::from_millis(10))
            .await
            .unwrap_err();
        assert!(matches!(err, AgentError::SendMessageFailed(_)));
    }

    #[tokio::test]
    async fn probe_receiver_clone_works() {
        let mut probe = new_probe();
        let rx1 = probe.probe_receiver();
        let rx2 = probe.probe_receiver();

        probe
            .process(AgentContext::new(), "in".into(), AgentValue::integer(42))
            .await
            .unwrap();

        // Either handle can consume the message (both share the same inner receiver).
        let (_ctx, v) = rx1.recv().await.unwrap();
        assert_eq!(v, AgentValue::integer(42));

        // The message was consumed through rx1, so rx2 must time out.
        let err = rx2
            .recv_with_timeout(Duration::from_millis(10))
            .await
            .unwrap_err();
        assert!(matches!(err, AgentError::SendMessageFailed(_)));
    }
}