agent_stream_kit/
test_utils.rs

1use std::path::Path;
2use std::sync::{Arc, OnceLock};
3use std::time::Duration;
4
5use async_trait::async_trait;
6use tokio::{
7    sync::{Mutex as AsyncMutex, mpsc},
8    time::timeout,
9};
10
11use crate::{
12    ASKit, ASKitEvent, ASKitObserver, AgentContext, AgentData, AgentError, AgentSpec,
13    AgentStreamSpec, AgentValue, AsAgent, askit_agent,
14};
15
16/// Setting up ASKit
17pub async fn setup_askit() -> ASKit {
18    let askit = ASKit::init().unwrap();
19    askit.ready().await.unwrap();
20
21    // set an observer to receive board events
22    subscribe_board_observer(&askit).unwrap();
23
24    askit
25}
26
27/// Load and start an agent stream from a file.
28pub async fn load_and_start_stream(askit: &ASKit, path: &str) -> Result<String, AgentError> {
29    let stream_json = std::fs::read_to_string(path)
30        .map_err(|e| AgentError::IoError(format!("Failed to read stream file: {}", e)))?;
31    let mut spec = AgentStreamSpec::from_json(&stream_json)?;
32    spec.run_on_start = true;
33    let name = Path::new(path)
34        .file_stem()
35        .and_then(|s| s.to_str())
36        .unwrap_or("stream")
37        .to_string();
38    let id = askit.add_agent_stream(name, spec)?;
39    askit.start_agent_stream(&id).await?;
40    Ok(id)
41}
42
43// BoardObserver
44
45static BOARD_RX: OnceLock<AsyncMutex<mpsc::UnboundedReceiver<(String, AgentValue)>>> =
46    OnceLock::new();
47
48#[derive(Clone)]
49pub struct BoardObserver {
50    sender: mpsc::UnboundedSender<(String, AgentValue)>,
51}
52
53#[allow(dead_code)]
54impl BoardObserver {
55    pub fn new(sender: mpsc::UnboundedSender<(String, AgentValue)>) -> Self {
56        Self { sender }
57    }
58}
59
60impl ASKitObserver for BoardObserver {
61    fn notify(&self, event: &ASKitEvent) {
62        if let ASKitEvent::Board(name, value) = event {
63            self.sender
64                .send((name.to_string(), value.clone()))
65                .unwrap_or_else(|e| {
66                    eprintln!("BoardObserver failed to send board event: {}", e);
67                });
68        }
69    }
70}
71
72pub fn subscribe_board_observer(askit: &ASKit) -> Result<(), AgentError> {
73    // set an observer to receive board events
74    let (tx, rx) = mpsc::unbounded_channel();
75    let observer = BoardObserver::new(tx);
76    askit.subscribe(Box::new(observer));
77    BOARD_RX
78        .set(AsyncMutex::new(rx))
79        .map_err(|_| AgentError::SendMessageFailed("board receiver already initialized".into()))
80}
81
/// Timeout used by `expect_board_value` when waiting for a board event (one second).
pub const DEFAULT_BOARD_TIMEOUT: Duration = Duration::from_millis(1_000);
83
84fn board_rx()
85-> Result<&'static AsyncMutex<mpsc::UnboundedReceiver<(String, AgentValue)>>, AgentError> {
86    BOARD_RX
87        .get()
88        .ok_or_else(|| AgentError::SendMessageFailed("board receiver not initialized".into()))
89}
90
91pub async fn recv_board_with_timeout(
92    duration: Duration,
93) -> Result<(String, AgentValue), AgentError> {
94    let rx = board_rx()?;
95    let mut rx = rx.lock().await;
96    timeout(duration, rx.recv())
97        .await
98        .map_err(|_| AgentError::SendMessageFailed("board receive timed out".into()))?
99        .ok_or_else(|| AgentError::SendMessageFailed("board channel closed".into()))
100}
101
102pub async fn expect_board_value(
103    expected_name: &str,
104    expected_value: &AgentValue,
105) -> Result<(), AgentError> {
106    let (name, value) = recv_board_with_timeout(DEFAULT_BOARD_TIMEOUT).await?;
107    if name == expected_name && &value == expected_value {
108        Ok(())
109    } else {
110        Err(AgentError::SendMessageFailed(format!(
111            "expected board '{}' with value {:?}, got '{}' with value {:?}",
112            expected_name, expected_value, name, value
113        )))
114    }
115}
116
117pub async fn expect_var_value(
118    flow_id: &str,
119    var_name: &str,
120    expected_value: &AgentValue,
121) -> Result<(), AgentError> {
122    let expected_name = format!("%{}/{}", flow_id, var_name);
123    expect_board_value(&expected_name, expected_value).await
124}
125
126// TestProbeAgent
127
128pub type ProbeEvent = (AgentContext, AgentValue);
129
130pub const DEFAULT_PROBE_TIMEOUT: Duration = Duration::from_secs(1);
131
132#[derive(Clone)]
133pub struct ProbeReceiver(Arc<AsyncMutex<mpsc::UnboundedReceiver<ProbeEvent>>>);
134
135impl ProbeReceiver {
136    pub async fn recv_with_timeout(&self, duration: Duration) -> Result<ProbeEvent, AgentError> {
137        let mut rx = self.0.lock().await;
138        timeout(duration, rx.recv())
139            .await
140            .map_err(|_| AgentError::SendMessageFailed("probe receive timed out".into()))?
141            .ok_or_else(|| AgentError::SendMessageFailed("probe channel closed".into()))
142    }
143
144    pub async fn recv(&self) -> Result<ProbeEvent, AgentError> {
145        self.recv_with_timeout(DEFAULT_PROBE_TIMEOUT).await
146    }
147}
148
149#[askit_agent(
150    title = "TestProbeAgent",
151    category = "Test",
152    inputs = ["*"],
153    outputs = []
154)]
155pub struct TestProbeAgent {
156    data: AgentData,
157    tx: mpsc::UnboundedSender<ProbeEvent>,
158    rx: ProbeReceiver,
159}
160
161impl TestProbeAgent {
162    /// Receive next probe event using the instance's own receiver.
163    pub async fn recv_with_timeout(&self, duration: Duration) -> Result<ProbeEvent, AgentError> {
164        self.rx.recv_with_timeout(duration).await
165    }
166
167    /// Clone the internal receiver so callers can drop agent locks before awaiting.
168    pub fn probe_receiver(&self) -> ProbeReceiver {
169        self.rx.clone()
170    }
171}
172
173/// Helper to fetch the probe receiver for a TestProbeAgent by id.
174pub async fn probe_receiver(askit: &ASKit, agent_id: &str) -> Result<ProbeReceiver, AgentError> {
175    let probe = askit
176        .get_agent(agent_id)
177        .ok_or_else(|| AgentError::AgentNotFound(agent_id.to_string()))?;
178    let probe_guard = probe.lock().await;
179    let probe_agent = probe_guard
180        .as_agent::<TestProbeAgent>()
181        .ok_or_else(|| AgentError::AgentNotFound(agent_id.to_string()))?;
182    Ok(probe_agent.probe_receiver())
183}
184
185/// Await one probe event with timeout on the given receiver.
186pub async fn recv_probe_with_timeout(
187    probe_rec: &ProbeReceiver,
188    duration: Duration,
189) -> Result<ProbeEvent, AgentError> {
190    probe_rec.recv_with_timeout(duration).await
191}
192
193/// Receive one probe event with the default timeout.
194pub async fn recv_probe(probe_rec: &ProbeReceiver) -> Result<ProbeEvent, AgentError> {
195    probe_rec.recv().await
196}
197
198#[async_trait]
199impl AsAgent for TestProbeAgent {
200    fn new(askit: crate::ASKit, id: String, spec: AgentSpec) -> Result<Self, AgentError> {
201        let (tx, rx) = mpsc::unbounded_channel();
202        let rx = ProbeReceiver(Arc::new(AsyncMutex::new(rx)));
203
204        Ok(Self {
205            data: AgentData::new(askit, id, spec),
206            tx,
207            rx,
208        })
209    }
210
211    async fn process(
212        &mut self,
213        ctx: AgentContext,
214        _pin: String,
215        value: AgentValue,
216    ) -> Result<(), AgentError> {
217        // Ignore send failures in tests; probe won't fail the pipeline
218        let _ = self.tx.send((ctx, value));
219        Ok(())
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    use agent_stream_kit::test_utils::TestProbeAgent;
    use agent_stream_kit::{ASKit, AgentContext, AgentError, AgentValue};
    use tokio::time::Duration;

    /// Builds a standalone probe agent with id `p1` from its own definition.
    fn new_probe() -> TestProbeAgent {
        let askit = ASKit::new();
        let def = TestProbeAgent::agent_definition();
        let spec = AgentSpec::from_def(&def);
        TestProbeAgent::new(askit, "p1".into(), spec).unwrap()
    }

    #[tokio::test]
    async fn probe_receives_in_order() {
        let mut probe = new_probe();

        // Feed one value at a time and read it straight back.
        for expected in [1, 2] {
            probe
                .process(
                    AgentContext::new(),
                    "in".into(),
                    AgentValue::integer(expected),
                )
                .await
                .unwrap();

            let (_ctx, received) = probe.probe_receiver().recv().await.unwrap();
            assert_eq!(received, AgentValue::integer(expected));
        }
    }

    #[tokio::test]
    async fn probe_times_out() {
        let probe = new_probe();

        // Nothing was processed, so the receive must time out.
        let outcome = probe.recv_with_timeout(Duration::from_millis(10)).await;
        let err = outcome.unwrap_err();
        assert!(matches!(err, AgentError::SendMessageFailed(_)));
    }

    #[tokio::test]
    async fn probe_receiver_clone_works() {
        let mut probe = new_probe();
        let first = probe.probe_receiver();
        let second = probe.probe_receiver();

        probe
            .process(AgentContext::new(), "in".into(), AgentValue::integer(42))
            .await
            .unwrap();

        // Both handles share one inner receiver, so the first one drains the message.
        let (_ctx, received) = first.recv().await.unwrap();
        assert_eq!(received, AgentValue::integer(42));

        // The second handle then has nothing left to read and times out.
        let outcome = second.recv_with_timeout(Duration::from_millis(10)).await;
        let err = outcome.unwrap_err();
        assert!(matches!(err, AgentError::SendMessageFailed(_)));
    }
}