agent_stream_kit/
test_utils.rs

1use crate::{ASKit, AgentConfigs, AgentContext, AgentData, AgentError, AgentValue, AsAgent};
2use askit_macros::askit_agent;
3use async_trait::async_trait;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::{
7    sync::{Mutex as AsyncMutex, mpsc},
8    time::timeout,
9};
10
/// One event captured by a probe: the context and value handed to the probe agent's `process`.
pub type ProbeEvent = (AgentContext, AgentValue);
/// Timeout used by `ProbeReceiver::recv` when the caller does not supply one (one second).
pub const DEFAULT_PROBE_TIMEOUT: Duration = Duration::from_secs(1);
13
14#[derive(Clone)]
15pub struct ProbeReceiver(Arc<AsyncMutex<mpsc::UnboundedReceiver<ProbeEvent>>>);
16
17impl ProbeReceiver {
18    pub async fn recv_with_timeout(&self, duration: Duration) -> Result<ProbeEvent, AgentError> {
19        let mut rx = self.0.lock().await;
20        timeout(duration, rx.recv())
21            .await
22            .map_err(|_| AgentError::SendMessageFailed("probe receive timed out".into()))?
23            .ok_or_else(|| AgentError::SendMessageFailed("probe channel closed".into()))
24    }
25
26    pub async fn recv(&self) -> Result<ProbeEvent, AgentError> {
27        self.recv_with_timeout(DEFAULT_PROBE_TIMEOUT).await
28    }
29}
30
31#[askit_agent(
32    title = "TestProbeAgent",
33    category = "Test",
34    inputs = ["*"],
35    outputs = []
36)]
37pub struct TestProbeAgent {
38    data: AgentData,
39    tx: mpsc::UnboundedSender<ProbeEvent>,
40    rx: ProbeReceiver,
41}
42
43impl TestProbeAgent {
44    /// Receive next probe event using the instance's own receiver.
45    pub async fn recv_with_timeout(&self, duration: Duration) -> Result<ProbeEvent, AgentError> {
46        self.rx.recv_with_timeout(duration).await
47    }
48
49    /// Clone the internal receiver so callers can drop agent locks before awaiting.
50    pub fn probe_receiver(&self) -> ProbeReceiver {
51        self.rx.clone()
52    }
53}
54
55/// Helper to fetch the probe receiver for a TestProbeAgent by id.
56pub async fn probe_receiver(askit: &ASKit, agent_id: &str) -> Result<ProbeReceiver, AgentError> {
57    let probe = askit
58        .get_agent(agent_id)
59        .ok_or_else(|| AgentError::AgentNotFound(agent_id.to_string()))?;
60    let probe_guard = probe.lock().await;
61    let probe_agent = probe_guard
62        .as_agent::<TestProbeAgent>()
63        .ok_or_else(|| AgentError::AgentNotFound(agent_id.to_string()))?;
64    Ok(probe_agent.probe_receiver())
65}
66
67/// Await one probe event with timeout on the given receiver.
68pub async fn recv_probe_with_timeout(
69    probe_rec: &ProbeReceiver,
70    duration: Duration,
71) -> Result<ProbeEvent, AgentError> {
72    probe_rec.recv_with_timeout(duration).await
73}
74
75/// Receive one probe event with the default timeout.
76pub async fn recv_probe(probe_rec: &ProbeReceiver) -> Result<ProbeEvent, AgentError> {
77    probe_rec.recv().await
78}
79
80#[async_trait]
81impl AsAgent for TestProbeAgent {
82    fn new(
83        askit: crate::ASKit,
84        id: String,
85        def_name: String,
86        config: Option<AgentConfigs>,
87    ) -> Result<Self, AgentError> {
88        let (tx, rx) = mpsc::unbounded_channel();
89        let rx = ProbeReceiver(Arc::new(AsyncMutex::new(rx)));
90
91        Ok(Self {
92            data: AgentData::new(askit, id, def_name, config),
93            tx,
94            rx,
95        })
96    }
97
98    async fn process(
99        &mut self,
100        ctx: AgentContext,
101        _pin: String,
102        value: AgentValue,
103    ) -> Result<(), AgentError> {
104        // Ignore send failures in tests; probe won't fail the pipeline
105        let _ = self.tx.send((ctx, value));
106        Ok(())
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    use agent_stream_kit::test_utils::TestProbeAgent;
    use agent_stream_kit::{ASKit, AgentContext, AgentError, AgentValue};
    use tokio::time::Duration;

    /// Builds a standalone probe agent with id "p1" and no config.
    fn new_probe() -> TestProbeAgent {
        TestProbeAgent::new(
            ASKit::new(),
            "p1".into(),
            TestProbeAgent::DEF_NAME.to_string(),
            None,
        )
        .unwrap()
    }

    /// Sends `value` through the probe's `process` on the "in" pin with a fresh context.
    async fn push(probe: &mut TestProbeAgent, value: AgentValue) {
        probe
            .process(AgentContext::new(), "in".into(), value)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn probe_receives_in_order() {
        let mut probe = new_probe();

        // Each value is pushed and then read back before the next one is sent.
        for expected in [1, 2] {
            push(&mut probe, AgentValue::integer(expected)).await;
            let (_ctx, received) = probe.probe_receiver().recv().await.unwrap();
            assert_eq!(received, AgentValue::integer(expected));
        }
    }

    #[tokio::test]
    async fn probe_times_out() {
        let probe = new_probe();

        // Nothing was processed, so the receive has to hit its deadline.
        let outcome = probe.recv_with_timeout(Duration::from_millis(10)).await;
        let err = outcome.unwrap_err();
        assert!(matches!(err, AgentError::SendMessageFailed(_)));
    }

    #[tokio::test]
    async fn probe_receiver_clone_works() {
        let mut probe = new_probe();
        let first = probe.probe_receiver();
        let second = probe.probe_receiver();

        push(&mut probe, AgentValue::integer(42)).await;

        // Both handles wrap the same inner receiver, so either may take the message.
        let (_ctx, received) = first.recv().await.unwrap();
        assert_eq!(received, AgentValue::integer(42));

        // The single message is gone; the other handle must time out.
        let outcome = second.recv_with_timeout(Duration::from_millis(10)).await;
        let err = outcome.unwrap_err();
        assert!(matches!(err, AgentError::SendMessageFailed(_)));
    }
}