1use std::path::Path;
2use std::sync::{Arc, OnceLock};
3use std::time::Duration;
4
5use async_trait::async_trait;
6use tokio::{
7 sync::{Mutex as AsyncMutex, mpsc},
8 time::timeout,
9};
10
11use crate::{
12 ASKit, ASKitEvent, ASKitObserver, AgentContext, AgentData, AgentError, AgentSpec,
13 AgentStreamSpec, AgentValue, AsAgent, askit_agent,
14};
15
16pub async fn setup_askit() -> ASKit {
18 let askit = ASKit::init().unwrap();
19 askit.ready().await.unwrap();
20
21 subscribe_board_observer(&askit).unwrap();
23
24 askit
25}
26
27pub async fn load_and_start_stream(askit: &ASKit, path: &str) -> Result<String, AgentError> {
29 let stream_json = std::fs::read_to_string(path)
30 .map_err(|e| AgentError::IoError(format!("Failed to read stream file: {}", e)))?;
31 let mut spec = AgentStreamSpec::from_json(&stream_json)?;
32 spec.run_on_start = true;
33 let name = Path::new(path)
34 .file_stem()
35 .and_then(|s| s.to_str())
36 .unwrap_or("stream")
37 .to_string();
38 let id = askit.add_agent_stream(name, spec)?;
39 askit.start_agent_stream(&id).await?;
40 Ok(id)
41}
42
43static BOARD_RX: OnceLock<AsyncMutex<mpsc::UnboundedReceiver<(String, AgentValue)>>> =
46 OnceLock::new();
47
48#[derive(Clone)]
49pub struct BoardObserver {
50 sender: mpsc::UnboundedSender<(String, AgentValue)>,
51}
52
53#[allow(dead_code)]
54impl BoardObserver {
55 pub fn new(sender: mpsc::UnboundedSender<(String, AgentValue)>) -> Self {
56 Self { sender }
57 }
58}
59
60impl ASKitObserver for BoardObserver {
61 fn notify(&self, event: &ASKitEvent) {
62 if let ASKitEvent::Board(name, value) = event {
63 self.sender
64 .send((name.to_string(), value.clone()))
65 .unwrap_or_else(|e| {
66 eprintln!("BoardObserver failed to send board event: {}", e);
67 });
68 }
69 }
70}
71
72pub fn subscribe_board_observer(askit: &ASKit) -> Result<(), AgentError> {
73 let (tx, rx) = mpsc::unbounded_channel();
75 let observer = BoardObserver::new(tx);
76 askit.subscribe(Box::new(observer));
77 BOARD_RX
78 .set(AsyncMutex::new(rx))
79 .map_err(|_| AgentError::SendMessageFailed("board receiver already initialized".into()))
80}
81
/// Timeout used by `expect_board_value` when waiting for a board event (one second).
pub const DEFAULT_BOARD_TIMEOUT: Duration = Duration::from_millis(1_000);
83
84fn board_rx()
85-> Result<&'static AsyncMutex<mpsc::UnboundedReceiver<(String, AgentValue)>>, AgentError> {
86 BOARD_RX
87 .get()
88 .ok_or_else(|| AgentError::SendMessageFailed("board receiver not initialized".into()))
89}
90
91pub async fn recv_board_with_timeout(
92 duration: Duration,
93) -> Result<(String, AgentValue), AgentError> {
94 let rx = board_rx()?;
95 let mut rx = rx.lock().await;
96 timeout(duration, rx.recv())
97 .await
98 .map_err(|_| AgentError::SendMessageFailed("board receive timed out".into()))?
99 .ok_or_else(|| AgentError::SendMessageFailed("board channel closed".into()))
100}
101
102pub async fn expect_board_value(
103 expected_name: &str,
104 expected_value: &AgentValue,
105) -> Result<(), AgentError> {
106 let (name, value) = recv_board_with_timeout(DEFAULT_BOARD_TIMEOUT).await?;
107 if name == expected_name && &value == expected_value {
108 Ok(())
109 } else {
110 Err(AgentError::SendMessageFailed(format!(
111 "expected board '{}' with value {:?}, got '{}' with value {:?}",
112 expected_name, expected_value, name, value
113 )))
114 }
115}
116
117pub async fn expect_var_value(
118 flow_id: &str,
119 var_name: &str,
120 expected_value: &AgentValue,
121) -> Result<(), AgentError> {
122 let expected_name = format!("%{}/{}", flow_id, var_name);
123 expect_board_value(&expected_name, expected_value).await
124}
125
126pub type ProbeEvent = (AgentContext, AgentValue);
129
/// Timeout applied by `ProbeReceiver::recv` (one second).
pub const DEFAULT_PROBE_TIMEOUT: Duration = Duration::from_millis(1_000);
131
132#[derive(Clone)]
133pub struct ProbeReceiver(Arc<AsyncMutex<mpsc::UnboundedReceiver<ProbeEvent>>>);
134
135impl ProbeReceiver {
136 pub async fn recv_with_timeout(&self, duration: Duration) -> Result<ProbeEvent, AgentError> {
137 let mut rx = self.0.lock().await;
138 timeout(duration, rx.recv())
139 .await
140 .map_err(|_| AgentError::SendMessageFailed("probe receive timed out".into()))?
141 .ok_or_else(|| AgentError::SendMessageFailed("probe channel closed".into()))
142 }
143
144 pub async fn recv(&self) -> Result<ProbeEvent, AgentError> {
145 self.recv_with_timeout(DEFAULT_PROBE_TIMEOUT).await
146 }
147}
148
149#[askit_agent(
150 title = "TestProbeAgent",
151 category = "Test",
152 inputs = ["*"],
153 outputs = []
154)]
155pub struct TestProbeAgent {
156 data: AgentData,
157 tx: mpsc::UnboundedSender<ProbeEvent>,
158 rx: ProbeReceiver,
159}
160
161impl TestProbeAgent {
162 pub async fn recv_with_timeout(&self, duration: Duration) -> Result<ProbeEvent, AgentError> {
164 self.rx.recv_with_timeout(duration).await
165 }
166
167 pub fn probe_receiver(&self) -> ProbeReceiver {
169 self.rx.clone()
170 }
171}
172
173pub async fn probe_receiver(askit: &ASKit, agent_id: &str) -> Result<ProbeReceiver, AgentError> {
175 let probe = askit
176 .get_agent(agent_id)
177 .ok_or_else(|| AgentError::AgentNotFound(agent_id.to_string()))?;
178 let probe_guard = probe.lock().await;
179 let probe_agent = probe_guard
180 .as_agent::<TestProbeAgent>()
181 .ok_or_else(|| AgentError::AgentNotFound(agent_id.to_string()))?;
182 Ok(probe_agent.probe_receiver())
183}
184
185pub async fn recv_probe_with_timeout(
187 probe_rec: &ProbeReceiver,
188 duration: Duration,
189) -> Result<ProbeEvent, AgentError> {
190 probe_rec.recv_with_timeout(duration).await
191}
192
193pub async fn recv_probe(probe_rec: &ProbeReceiver) -> Result<ProbeEvent, AgentError> {
195 probe_rec.recv().await
196}
197
198#[async_trait]
199impl AsAgent for TestProbeAgent {
200 fn new(askit: crate::ASKit, id: String, spec: AgentSpec) -> Result<Self, AgentError> {
201 let (tx, rx) = mpsc::unbounded_channel();
202 let rx = ProbeReceiver(Arc::new(AsyncMutex::new(rx)));
203
204 Ok(Self {
205 data: AgentData::new(askit, id, spec),
206 tx,
207 rx,
208 })
209 }
210
211 async fn process(
212 &mut self,
213 ctx: AgentContext,
214 _pin: String,
215 value: AgentValue,
216 ) -> Result<(), AgentError> {
217 let _ = self.tx.send((ctx, value));
219 Ok(())
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    use agent_stream_kit::test_utils::TestProbeAgent;
    use agent_stream_kit::{ASKit, AgentContext, AgentError, AgentValue};
    use tokio::time::Duration;

    /// Builds a standalone probe agent with id `"p1"` from its own definition.
    fn make_probe() -> TestProbeAgent {
        let askit = ASKit::new();
        let def = TestProbeAgent::agent_definition();
        let spec = AgentSpec::from_def(&def);
        TestProbeAgent::new(askit, "p1".into(), spec).unwrap()
    }

    /// Feeds `value` into the probe on pin `"in"` with a fresh context.
    async fn feed(probe: &mut TestProbeAgent, value: AgentValue) {
        probe
            .process(AgentContext::new(), "in".into(), value)
            .await
            .unwrap();
    }

    /// Asserts that `err` is the variant used for receive failures.
    fn assert_send_message_failed(err: AgentError) {
        assert!(matches!(err, AgentError::SendMessageFailed(_)));
    }

    // Values come back in the order they were processed.
    #[tokio::test]
    async fn probe_receives_in_order() {
        let mut probe = make_probe();

        feed(&mut probe, AgentValue::integer(1)).await;
        let (_ctx, first) = probe.probe_receiver().recv().await.unwrap();
        assert_eq!(first, AgentValue::integer(1));

        feed(&mut probe, AgentValue::integer(2)).await;
        let (_ctx, second) = probe.probe_receiver().recv().await.unwrap();
        assert_eq!(second, AgentValue::integer(2));
    }

    // With nothing processed, a bounded receive fails.
    #[tokio::test]
    async fn probe_times_out() {
        let probe = make_probe();

        let outcome = probe.recv_with_timeout(Duration::from_millis(10)).await;
        assert_send_message_failed(outcome.unwrap_err());
    }

    // Two handles share one queue: after the first handle consumes the only
    // event, the second has nothing left to receive.
    #[tokio::test]
    async fn probe_receiver_clone_works() {
        let mut probe = make_probe();
        let first_handle = probe.probe_receiver();
        let second_handle = probe.probe_receiver();

        feed(&mut probe, AgentValue::integer(42)).await;

        let (_ctx, got) = first_handle.recv().await.unwrap();
        assert_eq!(got, AgentValue::integer(42));

        let outcome = second_handle
            .recv_with_timeout(Duration::from_millis(10))
            .await;
        assert_send_message_failed(outcome.unwrap_err());
    }
}