1use std::time::Duration;
2
3#[cfg(test)]
4use std::sync::Arc;
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value};
9use uuid::Uuid;
10
11use crate::manager::FlowManager;
12use crate::types::{Flow, FlowError, FlowStatus};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
21#[serde(tag = "kind", rename_all = "snake_case")]
22pub enum WaitCondition {
23 Timer { at: DateTime<Utc> },
25 ExternalEvent {
27 topic: String,
28 correlation_id: String,
29 },
30 Manual,
32}
33
34impl WaitCondition {
35 pub fn into_value(self) -> Value {
36 serde_json::to_value(self).expect("WaitCondition is always serializable")
37 }
38
39 pub fn from_value(v: &Value) -> Option<Self> {
40 serde_json::from_value(v.clone()).ok()
41 }
42}
43
44#[derive(Clone)]
52pub struct WaitEngine {
53 manager: FlowManager,
54}
55
/// Counters describing what a single wait-engine tick did.
///
/// `PartialEq`/`Eq` are derived so whole reports can be compared in one
/// assertion instead of field by field.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct TickReport {
    /// Number of waiting flows returned by the listing for this tick.
    pub scanned: usize,
    /// Flows that were successfully resumed.
    pub resumed: usize,
    /// Flows that were successfully cancelled.
    pub cancelled: usize,
    /// Flows whose condition is not yet satisfied.
    pub still_waiting: usize,
    /// Failed operations: the listing itself, or a resume/cancel call.
    pub errors: usize,
}
64
65impl WaitEngine {
66 pub fn new(manager: FlowManager) -> Self {
67 Self { manager }
68 }
69
70 pub fn manager(&self) -> &FlowManager {
71 &self.manager
72 }
73
74 pub async fn tick_at(&self, now: DateTime<Utc>) -> TickReport {
78 let mut report = TickReport::default();
79 let waiting = match self.manager.list_by_status(FlowStatus::Waiting).await {
80 Ok(v) => v,
81 Err(e) => {
82 tracing::warn!(error = %e, "wait engine: failed to list waiting flows");
83 report.errors += 1;
84 return report;
85 }
86 };
87 report.scanned = waiting.len();
88
89 for flow in waiting {
90 match self.evaluate(&flow, now).await {
91 Outcome::Resume => match self.manager.resume(flow.id, None).await {
92 Ok(_) => report.resumed += 1,
93 Err(FlowError::CancelPending { .. }) => {
94 match self.manager.cancel(flow.id).await {
96 Ok(_) => report.cancelled += 1,
97 Err(e) => {
98 tracing::warn!(flow_id = %flow.id, error = %e, "wait engine: cancel after CancelPending failed");
99 report.errors += 1;
100 }
101 }
102 }
103 Err(e) => {
104 tracing::warn!(flow_id = %flow.id, error = %e, "wait engine: resume failed");
105 report.errors += 1;
106 }
107 },
108 Outcome::Cancel => match self.manager.cancel(flow.id).await {
109 Ok(_) => report.cancelled += 1,
110 Err(e) => {
111 tracing::warn!(flow_id = %flow.id, error = %e, "wait engine: cancel failed");
112 report.errors += 1;
113 }
114 },
115 Outcome::Wait => {
116 report.still_waiting += 1;
117 }
118 Outcome::Skip(reason) => {
119 tracing::debug!(flow_id = %flow.id, reason, "wait engine: skipping flow");
120 }
121 }
122 }
123 report
124 }
125
126 pub async fn tick(&self) -> TickReport {
127 self.tick_at(Utc::now()).await
128 }
129
130 pub async fn try_resume_external(
137 &self,
138 flow_id: Uuid,
139 topic: &str,
140 correlation_id: &str,
141 payload: Option<Value>,
142 ) -> Result<Option<Flow>, FlowError> {
143 let Some(flow) = self.manager.get(flow_id).await? else {
144 return Ok(None);
145 };
146 if flow.status != FlowStatus::Waiting {
147 return Ok(None);
148 }
149 let cond = match flow.wait_json.as_ref().and_then(WaitCondition::from_value) {
150 Some(c) => c,
151 None => return Ok(None),
152 };
153 let matches = matches!(
154 &cond,
155 WaitCondition::ExternalEvent { topic: t, correlation_id: c }
156 if t == topic && c == correlation_id
157 );
158 if !matches {
159 return Ok(None);
160 }
161 let patch = payload.map(|p| json!({ "resume_event": p }));
162 self.manager.resume(flow.id, patch).await.map(Some)
163 }
164
165 pub async fn run(&self, interval: Duration, shutdown: tokio_util::sync::CancellationToken) {
168 let mut interval_timer = tokio::time::interval(interval);
169 interval_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
170 loop {
171 tokio::select! {
172 _ = shutdown.cancelled() => {
173 tracing::info!("wait engine: shutdown requested");
174 return;
175 }
176 _ = interval_timer.tick() => {
177 let report = self.tick().await;
178 if report.scanned > 0 {
179 tracing::debug!(
180 scanned = report.scanned,
181 resumed = report.resumed,
182 cancelled = report.cancelled,
183 still_waiting = report.still_waiting,
184 errors = report.errors,
185 "wait engine tick"
186 );
187 }
188 }
189 }
190 }
191 }
192
193 async fn evaluate(&self, flow: &Flow, now: DateTime<Utc>) -> Outcome {
194 if flow.cancel_requested {
196 return Outcome::Cancel;
197 }
198 let Some(wait_value) = flow.wait_json.as_ref() else {
199 return Outcome::Skip("missing wait_json");
200 };
201 let Some(cond) = WaitCondition::from_value(wait_value) else {
202 return Outcome::Skip("unparseable wait_json");
203 };
204 match cond {
205 WaitCondition::Timer { at } => {
206 if now >= at {
207 Outcome::Resume
208 } else {
209 Outcome::Wait
210 }
211 }
212 WaitCondition::ExternalEvent { .. } | WaitCondition::Manual => Outcome::Wait,
215 }
216 }
217}
218
/// What a tick should do with one waiting flow.
enum Outcome {
    /// The wait condition is satisfied: resume the flow.
    Resume,
    /// A cancel was requested: cancel the flow.
    Cancel,
    /// The condition is not satisfied yet: leave the flow waiting.
    Wait,
    /// The flow cannot be evaluated; the payload is the reason that gets logged.
    Skip(&'static str),
}
225
#[cfg(test)]
mod tests {
    use super::*;
    use crate::manager::CreateManagedInput;
    use crate::store::SqliteFlowStore;
    use chrono::Duration as ChronoDuration;

    /// Builds an engine backed by a fresh in-memory SQLite store.
    async fn engine() -> WaitEngine {
        let store = Arc::new(SqliteFlowStore::open(":memory:").await.unwrap());
        WaitEngine::new(FlowManager::new(store))
    }

    /// Minimal creation input shared by all tests.
    fn input() -> CreateManagedInput {
        CreateManagedInput {
            controller_id: "test".into(),
            goal: "test".into(),
            owner_session_key: "owner".into(),
            requester_origin: "user".into(),
            current_step: "init".into(),
            state_json: json!({}),
        }
    }

    /// Creates a flow, starts it, and parks it in the waiting state on `cond`.
    async fn put_into_waiting(eng: &WaitEngine, cond: WaitCondition) -> Flow {
        let m = eng.manager();
        let f = m.create_managed(input()).await.unwrap();
        let f = m.start_running(f.id).await.unwrap();
        m.set_waiting(f.id, cond.into_value()).await.unwrap()
    }

    #[tokio::test]
    async fn timer_fires_when_now_past_deadline() {
        let eng = engine().await;
        let past = Utc::now() - ChronoDuration::seconds(60);
        let f = put_into_waiting(&eng, WaitCondition::Timer { at: past }).await;

        let report = eng.tick().await;
        assert_eq!(report.scanned, 1);
        assert_eq!(report.resumed, 1);

        let after = eng.manager().get(f.id).await.unwrap().unwrap();
        assert_eq!(after.status, FlowStatus::Running);
        assert!(after.wait_json.is_none());
    }

    #[tokio::test]
    async fn timer_does_not_fire_before_deadline() {
        let eng = engine().await;
        let future = Utc::now() + ChronoDuration::seconds(60);
        let f = put_into_waiting(&eng, WaitCondition::Timer { at: future }).await;

        let report = eng.tick().await;
        assert_eq!(report.scanned, 1);
        assert_eq!(report.resumed, 0);
        assert_eq!(report.still_waiting, 1);

        let after = eng.manager().get(f.id).await.unwrap().unwrap();
        assert_eq!(after.status, FlowStatus::Waiting);
    }

    #[tokio::test]
    async fn external_event_matches_resumes() {
        let eng = engine().await;
        let f = put_into_waiting(
            &eng,
            WaitCondition::ExternalEvent {
                topic: "agent.delegate.reply".into(),
                correlation_id: "corr-42".into(),
            },
        )
        .await;

        // A plain tick must never resume an external-event wait.
        let report = eng.tick().await;
        assert_eq!(report.resumed, 0);
        assert_eq!(report.still_waiting, 1);

        let resumed = eng
            .try_resume_external(
                f.id,
                "agent.delegate.reply",
                "corr-42",
                Some(json!({"answer": 42})),
            )
            .await
            .unwrap()
            .expect("resumed");
        assert_eq!(resumed.status, FlowStatus::Running);
        assert!(resumed.wait_json.is_none());
        assert_eq!(resumed.state_json["resume_event"]["answer"], 42);
    }

    #[tokio::test]
    async fn external_event_with_wrong_topic_or_id_is_noop() {
        let eng = engine().await;
        let f = put_into_waiting(
            &eng,
            WaitCondition::ExternalEvent {
                topic: "topic-A".into(),
                correlation_id: "id-1".into(),
            },
        )
        .await;

        // Wrong topic, right correlation id.
        let r1 = eng
            .try_resume_external(f.id, "topic-B", "id-1", None)
            .await
            .unwrap();
        assert!(r1.is_none());

        // Right topic, wrong correlation id.
        let r2 = eng
            .try_resume_external(f.id, "topic-A", "id-99", None)
            .await
            .unwrap();
        assert!(r2.is_none());

        let after = eng.manager().get(f.id).await.unwrap().unwrap();
        assert_eq!(after.status, FlowStatus::Waiting);
    }

    #[tokio::test]
    async fn manual_wait_ignored_by_tick() {
        let eng = engine().await;
        let f = put_into_waiting(&eng, WaitCondition::Manual).await;
        let report = eng.tick().await;
        assert_eq!(report.scanned, 1);
        assert_eq!(report.resumed, 0);
        assert_eq!(report.still_waiting, 1);
        let after = eng.manager().get(f.id).await.unwrap().unwrap();
        assert_eq!(after.status, FlowStatus::Waiting);
    }

    #[tokio::test]
    async fn cancel_requested_waiting_flips_to_cancelled_on_tick() {
        let eng = engine().await;
        let future = Utc::now() + ChronoDuration::seconds(60);
        let f = put_into_waiting(&eng, WaitCondition::Timer { at: future }).await;
        eng.manager().request_cancel(f.id).await.unwrap();

        let report = eng.tick().await;
        assert_eq!(report.cancelled, 1);
        assert_eq!(report.resumed, 0);

        let after = eng.manager().get(f.id).await.unwrap().unwrap();
        assert_eq!(after.status, FlowStatus::Cancelled);
    }

    #[tokio::test]
    async fn run_loop_can_be_shut_down() {
        let eng = engine().await;
        let token = tokio_util::sync::CancellationToken::new();
        let token_clone = token.clone();
        let eng_clone = eng.clone();
        let handle = tokio::spawn(async move {
            eng_clone.run(Duration::from_millis(20), token_clone).await;
        });
        // Let the loop tick a few times before asking it to stop.
        tokio::time::sleep(Duration::from_millis(60)).await;
        token.cancel();
        let joined = tokio::time::timeout(Duration::from_millis(200), handle)
            .await
            .expect("engine did not shut down promptly");
        // A task that panicked also completes "promptly"; make sure the loop
        // actually returned cleanly.
        joined.expect("wait engine task panicked");
    }

    #[tokio::test]
    async fn try_resume_external_on_unknown_flow_is_noop() {
        let eng = engine().await;
        let r = eng
            .try_resume_external(Uuid::new_v4(), "t", "c", None)
            .await
            .unwrap();
        assert!(r.is_none());
    }

    #[tokio::test]
    async fn try_resume_external_on_running_flow_is_noop() {
        let eng = engine().await;
        let m = eng.manager();
        let f = m.create_managed(input()).await.unwrap();
        let f = m.start_running(f.id).await.unwrap();
        let r = eng.try_resume_external(f.id, "t", "c", None).await.unwrap();
        assert!(r.is_none(), "should ignore non-waiting flows");
    }

    #[test]
    fn wait_condition_round_trip() {
        let at = Utc::now();
        let v = WaitCondition::Timer { at }.into_value();
        // The variant is stored as a snake_case tag under "kind".
        assert_eq!(v["kind"], "timer");

        let parsed = WaitCondition::from_value(&v).expect("round trip");
        match parsed {
            WaitCondition::Timer { at: parsed_at } => {
                assert_eq!(parsed_at, at, "deadline must survive the round trip");
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }
}