agent_stream_kit/
test_utils.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use tokio::{
6    sync::{Mutex as AsyncMutex, mpsc},
7    time::timeout,
8};
9
10use crate::{ASKit, AgentContext, AgentData, AgentError, AgentSpec, AgentValue, AsAgent};
11use askit_macros::askit_agent;
12
/// One captured probe event: the context and value that were passed to
/// `TestProbeAgent::process`, in that order.
pub type ProbeEvent = (AgentContext, AgentValue);
14
/// Deadline applied by [`ProbeReceiver::recv`] (and therefore [`recv_probe`]) when the
/// caller does not supply an explicit duration: one second.
pub const DEFAULT_PROBE_TIMEOUT: Duration = Duration::from_millis(1_000);
16
17#[derive(Clone)]
18pub struct ProbeReceiver(Arc<AsyncMutex<mpsc::UnboundedReceiver<ProbeEvent>>>);
19
20impl ProbeReceiver {
21    pub async fn recv_with_timeout(&self, duration: Duration) -> Result<ProbeEvent, AgentError> {
22        let mut rx = self.0.lock().await;
23        timeout(duration, rx.recv())
24            .await
25            .map_err(|_| AgentError::SendMessageFailed("probe receive timed out".into()))?
26            .ok_or_else(|| AgentError::SendMessageFailed("probe channel closed".into()))
27    }
28
29    pub async fn recv(&self) -> Result<ProbeEvent, AgentError> {
30        self.recv_with_timeout(DEFAULT_PROBE_TIMEOUT).await
31    }
32}
33
/// Sink agent for tests that records every `(context, value)` pair passed to `process`.
///
/// Recorded events can be read back with [`TestProbeAgent::recv_with_timeout`] or through
/// a cloned [`ProbeReceiver`] obtained from [`TestProbeAgent::probe_receiver`].
// NOTE(review): `inputs = ["*"]` presumably declares a wildcard input pin; `process`
// ignores the pin name either way.
#[askit_agent(
    title = "TestProbeAgent",
    category = "Test",
    inputs = ["*"],
    outputs = []
)]
pub struct TestProbeAgent {
    // Agent state built with `AgentData::new` in `AsAgent::new`.
    data: AgentData,
    // Sending half of the probe channel; `process` pushes each event here.
    tx: mpsc::UnboundedSender<ProbeEvent>,
    // Shared receiving half; clones are handed out by `probe_receiver`.
    rx: ProbeReceiver,
}
45
46impl TestProbeAgent {
47    /// Receive next probe event using the instance's own receiver.
48    pub async fn recv_with_timeout(&self, duration: Duration) -> Result<ProbeEvent, AgentError> {
49        self.rx.recv_with_timeout(duration).await
50    }
51
52    /// Clone the internal receiver so callers can drop agent locks before awaiting.
53    pub fn probe_receiver(&self) -> ProbeReceiver {
54        self.rx.clone()
55    }
56}
57
58/// Helper to fetch the probe receiver for a TestProbeAgent by id.
59pub async fn probe_receiver(askit: &ASKit, agent_id: &str) -> Result<ProbeReceiver, AgentError> {
60    let probe = askit
61        .get_agent(agent_id)
62        .ok_or_else(|| AgentError::AgentNotFound(agent_id.to_string()))?;
63    let probe_guard = probe.lock().await;
64    let probe_agent = probe_guard
65        .as_agent::<TestProbeAgent>()
66        .ok_or_else(|| AgentError::AgentNotFound(agent_id.to_string()))?;
67    Ok(probe_agent.probe_receiver())
68}
69
70/// Await one probe event with timeout on the given receiver.
71pub async fn recv_probe_with_timeout(
72    probe_rec: &ProbeReceiver,
73    duration: Duration,
74) -> Result<ProbeEvent, AgentError> {
75    probe_rec.recv_with_timeout(duration).await
76}
77
78/// Receive one probe event with the default timeout.
79pub async fn recv_probe(probe_rec: &ProbeReceiver) -> Result<ProbeEvent, AgentError> {
80    probe_rec.recv().await
81}
82
83#[async_trait]
84impl AsAgent for TestProbeAgent {
85    fn new(askit: crate::ASKit, id: String, spec: AgentSpec) -> Result<Self, AgentError> {
86        let (tx, rx) = mpsc::unbounded_channel();
87        let rx = ProbeReceiver(Arc::new(AsyncMutex::new(rx)));
88
89        Ok(Self {
90            data: AgentData::new(askit, id, spec),
91            tx,
92            rx,
93        })
94    }
95
96    async fn process(
97        &mut self,
98        ctx: AgentContext,
99        _pin: String,
100        value: AgentValue,
101    ) -> Result<(), AgentError> {
102        // Ignore send failures in tests; probe won't fail the pipeline
103        let _ = self.tx.send((ctx, value));
104        Ok(())
105    }
106}
107
#[cfg(test)]
mod tests {
    // `super::*` already provides `TestProbeAgent`, `ASKit`, `AgentContext`, `AgentError`,
    // `AgentValue`, the `AsAgent` trait and `Duration`. Re-importing them through the crate's
    // external name (`agent_stream_kit::…`) is redundant, and that path only resolves when
    // the crate root declares an `extern crate self` alias.
    use super::*;

    /// Builds a standalone probe; it is not registered with the `ASKit` instance.
    fn new_probe() -> TestProbeAgent {
        let def = TestProbeAgent::agent_definition();
        TestProbeAgent::new(ASKit::new(), "p1".into(), def.to_spec()).unwrap()
    }

    #[tokio::test]
    async fn probe_receives_in_order() {
        let mut probe = new_probe();
        let rx = probe.probe_receiver();

        // Queue both events before reading anything, so the test actually exercises FIFO
        // ordering rather than a single send/receive round-trip.
        for n in [1, 2] {
            probe
                .process(AgentContext::new(), "in".into(), AgentValue::integer(n))
                .await
                .unwrap();
        }

        let (_ctx, v1) = rx.recv().await.unwrap();
        assert_eq!(v1, AgentValue::integer(1));
        let (_ctx, v2) = rx.recv().await.unwrap();
        assert_eq!(v2, AgentValue::integer(2));
    }

    #[tokio::test]
    async fn probe_times_out() {
        let probe = new_probe();
        let err = probe
            .recv_with_timeout(Duration::from_millis(10))
            .await
            .unwrap_err();
        assert!(matches!(err, AgentError::SendMessageFailed(_)));
    }

    #[tokio::test]
    async fn probe_receiver_clone_works() {
        let mut probe = new_probe();
        let rx1 = probe.probe_receiver();
        let rx2 = probe.probe_receiver();

        probe
            .process(AgentContext::new(), "in".into(), AgentValue::integer(42))
            .await
            .unwrap();

        // Either receiver can consume the message (both clone the same inner receiver).
        let (_ctx, v) = rx1.recv().await.unwrap();
        assert_eq!(v, AgentValue::integer(42));

        // The message was consumed through rx1, so rx2 sees nothing and times out.
        let err = rx2
            .recv_with_timeout(Duration::from_millis(10))
            .await
            .unwrap_err();
        assert!(matches!(err, AgentError::SendMessageFailed(_)));
    }

    #[tokio::test]
    async fn probe_reports_closed_channel_after_agent_drop() {
        let probe = new_probe();
        let rx = probe.probe_receiver();

        // Dropping the agent drops the only sender, closing the channel.
        drop(probe);

        let err = rx
            .recv_with_timeout(Duration::from_millis(10))
            .await
            .unwrap_err();
        assert!(matches!(err, AgentError::SendMessageFailed(_)));
    }
}