1use std::collections::HashMap;
4use tokio::sync::Mutex;
5
6use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8
9use rustvello_core::error::RustvelloResult;
10use rustvello_core::trigger::TriggerStore;
11use rustvello_proto::identifiers::TaskId;
12use rustvello_proto::trigger::{
13 ConditionId, TriggerCondition, TriggerDefinitionDTO, TriggerDefinitionId, TriggerRunId,
14 ValidCondition,
15};
16
17struct TriggerState {
18 conditions: HashMap<String, TriggerCondition>,
19 source_task_conditions: HashMap<String, Vec<ConditionId>>,
21 event_conditions: HashMap<String, Vec<ConditionId>>,
23 cron_condition_ids: Vec<ConditionId>,
25 triggers: HashMap<String, TriggerDefinitionDTO>,
26 condition_triggers: HashMap<String, Vec<TriggerDefinitionId>>,
28 valid_conditions: HashMap<String, ValidCondition>,
29 cron_executions: HashMap<String, DateTime<Utc>>,
30 trigger_run_claims: HashMap<String, DateTime<Utc>>,
31}
32
33pub struct MemTriggerStore {
35 state: Mutex<TriggerState>,
36}
37
38impl MemTriggerStore {
39 pub fn new() -> Self {
40 Self {
41 state: Mutex::new(TriggerState {
42 conditions: HashMap::new(),
43 source_task_conditions: HashMap::new(),
44 event_conditions: HashMap::new(),
45 cron_condition_ids: Vec::new(),
46 triggers: HashMap::new(),
47 condition_triggers: HashMap::new(),
48 valid_conditions: HashMap::new(),
49 cron_executions: HashMap::new(),
50 trigger_run_claims: HashMap::new(),
51 }),
52 }
53 }
54}
55
56impl Default for MemTriggerStore {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62#[async_trait]
63impl TriggerStore for MemTriggerStore {
64 async fn register_condition(
65 &self,
66 condition: &TriggerCondition,
67 ) -> RustvelloResult<ConditionId> {
68 let cond_id = condition.condition_id();
69 let mut state = self.state.lock().await;
70
71 state
72 .conditions
73 .insert(cond_id.as_str().to_owned(), condition.clone());
74
75 for task_id in condition.source_task_ids() {
77 let vec = state
78 .source_task_conditions
79 .entry(task_id.to_string())
80 .or_default();
81 if !vec.contains(&cond_id) {
82 vec.push(cond_id.clone());
83 }
84 }
85
86 if let TriggerCondition::Event(evt) = condition {
88 let vec = state
89 .event_conditions
90 .entry(evt.event_code.clone())
91 .or_default();
92 if !vec.contains(&cond_id) {
93 vec.push(cond_id.clone());
94 }
95 }
96
97 if matches!(condition, TriggerCondition::Cron(_))
99 && !state.cron_condition_ids.contains(&cond_id)
100 {
101 state.cron_condition_ids.push(cond_id.clone());
102 }
103
104 Ok(cond_id)
105 }
106
107 async fn get_condition(&self, id: &ConditionId) -> RustvelloResult<Option<TriggerCondition>> {
108 let state = self.state.lock().await;
109 Ok(state.conditions.get(id.as_str()).cloned())
110 }
111
112 async fn get_conditions_for_task(
113 &self,
114 task_id: &TaskId,
115 ) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
116 let state = self.state.lock().await;
117 let key = task_id.to_string();
118 let cond_ids = state.source_task_conditions.get(&key);
119
120 let mut result = Vec::new();
121 if let Some(ids) = cond_ids {
122 for cid in ids {
123 if let Some(cond) = state.conditions.get(cid.as_str()) {
124 result.push((cid.clone(), cond.clone()));
125 }
126 }
127 }
128 Ok(result)
129 }
130
131 async fn get_cron_conditions(&self) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
132 let state = self.state.lock().await;
133 let mut result = Vec::new();
134 for cid in &state.cron_condition_ids {
135 if let Some(cond) = state.conditions.get(cid.as_str()) {
136 result.push((cid.clone(), cond.clone()));
137 }
138 }
139 Ok(result)
140 }
141
142 async fn get_event_conditions(
143 &self,
144 event_code: &str,
145 ) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
146 let state = self.state.lock().await;
147 let cond_ids = state.event_conditions.get(event_code);
148
149 let mut result = Vec::new();
150 if let Some(ids) = cond_ids {
151 for cid in ids {
152 if let Some(cond) = state.conditions.get(cid.as_str()) {
153 result.push((cid.clone(), cond.clone()));
154 }
155 }
156 }
157 Ok(result)
158 }
159
160 async fn register_trigger(&self, trigger: &TriggerDefinitionDTO) -> RustvelloResult<()> {
161 let mut state = self.state.lock().await;
162
163 state
164 .triggers
165 .insert(trigger.trigger_id.as_str().to_owned(), trigger.clone());
166
167 for cid in &trigger.condition_ids {
169 state
170 .condition_triggers
171 .entry(cid.as_str().to_owned())
172 .or_default()
173 .push(trigger.trigger_id.clone());
174 }
175
176 Ok(())
177 }
178
179 async fn get_trigger(
180 &self,
181 id: &TriggerDefinitionId,
182 ) -> RustvelloResult<Option<TriggerDefinitionDTO>> {
183 let state = self.state.lock().await;
184 Ok(state.triggers.get(id.as_str()).cloned())
185 }
186
187 async fn get_triggers_for_condition(
188 &self,
189 cond_id: &ConditionId,
190 ) -> RustvelloResult<Vec<TriggerDefinitionDTO>> {
191 let state = self.state.lock().await;
192 let trigger_ids = state.condition_triggers.get(cond_id.as_str());
193
194 let mut result = Vec::new();
195 if let Some(ids) = trigger_ids {
196 for tid in ids {
197 if let Some(trigger) = state.triggers.get(tid.as_str()) {
198 result.push(trigger.clone());
199 }
200 }
201 }
202 Ok(result)
203 }
204
205 async fn remove_triggers_for_task(&self, task_id: &TaskId) -> RustvelloResult<u32> {
206 let mut state = self.state.lock().await;
207 let task_str = task_id.to_string();
208
209 let ids_to_remove: Vec<String> = state
210 .triggers
211 .iter()
212 .filter(|(_, t)| t.task_id.to_string() == task_str)
213 .map(|(id, _)| id.clone())
214 .collect();
215
216 let count = u32::try_from(ids_to_remove.len()).unwrap_or(u32::MAX);
217 for id in &ids_to_remove {
218 if let Some(trigger) = state.triggers.remove(id) {
219 for cid in &trigger.condition_ids {
221 if let Some(tids) = state.condition_triggers.get_mut(cid.as_str()) {
222 tids.retain(|tid| tid.as_str() != *id);
223 }
224 }
225 }
226 }
227
228 Ok(count)
229 }
230
231 async fn record_valid_condition(&self, vc: &ValidCondition) -> RustvelloResult<()> {
232 let mut state = self.state.lock().await;
233 state
234 .valid_conditions
235 .insert(vc.valid_condition_id.clone(), vc.clone());
236 Ok(())
237 }
238
239 async fn get_valid_conditions(&self) -> RustvelloResult<Vec<ValidCondition>> {
240 let state = self.state.lock().await;
241 Ok(state.valid_conditions.values().cloned().collect())
242 }
243
244 async fn clear_valid_conditions(&self, ids: &[String]) -> RustvelloResult<()> {
245 let mut state = self.state.lock().await;
246 for id in ids {
247 state.valid_conditions.remove(id);
248 }
249 Ok(())
250 }
251
252 async fn get_last_cron_execution(
253 &self,
254 cond_id: &ConditionId,
255 ) -> RustvelloResult<Option<DateTime<Utc>>> {
256 let state = self.state.lock().await;
257 Ok(state.cron_executions.get(cond_id.as_str()).copied())
258 }
259
260 async fn store_cron_execution(
261 &self,
262 cond_id: &ConditionId,
263 time: DateTime<Utc>,
264 expected_last: Option<DateTime<Utc>>,
265 ) -> RustvelloResult<bool> {
266 let mut state = self.state.lock().await;
267 let current = state.cron_executions.get(cond_id.as_str()).copied();
268
269 if current == expected_last {
271 state
272 .cron_executions
273 .insert(cond_id.as_str().to_owned(), time);
274 Ok(true)
275 } else {
276 Ok(false)
277 }
278 }
279
280 async fn claim_trigger_run(&self, run_id: &TriggerRunId) -> RustvelloResult<bool> {
281 let mut state = self.state.lock().await;
282 if state.trigger_run_claims.contains_key(run_id.as_str()) {
283 Ok(false)
284 } else {
285 state
286 .trigger_run_claims
287 .insert(run_id.as_str().to_owned(), Utc::now());
288 Ok(true)
289 }
290 }
291
292 async fn purge(&self) -> RustvelloResult<()> {
293 let mut state = self.state.lock().await;
294 state.conditions.clear();
295 state.source_task_conditions.clear();
296 state.event_conditions.clear();
297 state.cron_condition_ids.clear();
298 state.triggers.clear();
299 state.condition_triggers.clear();
300 state.valid_conditions.clear();
301 state.cron_executions.clear();
302 state.trigger_run_claims.clear();
303 Ok(())
304 }
305
306 async fn get_all_conditions(&self) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
307 let state = self.state.lock().await;
308 Ok(state
309 .conditions
310 .iter()
311 .map(|(id, cond)| (ConditionId::from(id.clone()), cond.clone()))
312 .collect())
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use rustvello_proto::trigger::*;

    /// Event condition on `code` with no payload filter.
    fn event_on(code: &str) -> TriggerCondition {
        TriggerCondition::Event(EventCondition {
            event_code: code.to_string(),
            payload_filter: None,
        })
    }

    /// `ConditionId` built from a string literal.
    fn cid(raw: &str) -> ConditionId {
        ConditionId::from(raw.to_string())
    }

    /// OR-logic trigger on `task_id` whose id is derived with `compute_trigger_id`.
    fn derived_or_trigger(task_id: TaskId, condition_ids: Vec<ConditionId>) -> TriggerDefinitionDTO {
        let trigger_id =
            TriggerDefinitionDTO::compute_trigger_id(&task_id, &condition_ids, TriggerLogic::Or);
        TriggerDefinitionDTO {
            trigger_id,
            task_id,
            condition_ids,
            logic: TriggerLogic::Or,
            argument_template: None,
        }
    }

    #[tokio::test]
    async fn register_and_get_condition() {
        let store = MemTriggerStore::new();
        let payment = event_on("payment");

        let id = store.register_condition(&payment).await.unwrap();
        let stored = store
            .get_condition(&id)
            .await
            .unwrap()
            .expect("registered condition should be retrievable");
        assert_eq!(stored.condition_id(), id);
    }

    #[tokio::test]
    async fn get_conditions_for_task() {
        let store = MemTriggerStore::new();
        let watched = TaskId::new("mod", "task");
        let on_success = TriggerCondition::Status(StatusCondition {
            task_id: watched.clone(),
            statuses: vec![rustvello_proto::status::InvocationStatus::Success],
            argument_filter: None,
        });
        store.register_condition(&on_success).await.unwrap();

        let matched = store.get_conditions_for_task(&watched).await.unwrap();
        assert_eq!(matched.len(), 1);

        let unrelated = TaskId::new("mod", "other");
        assert!(store
            .get_conditions_for_task(&unrelated)
            .await
            .unwrap()
            .is_empty());
    }

    #[tokio::test]
    async fn get_event_conditions() {
        let store = MemTriggerStore::new();
        store.register_condition(&event_on("payment")).await.unwrap();

        assert_eq!(store.get_event_conditions("payment").await.unwrap().len(), 1);
        assert!(store.get_event_conditions("other").await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn get_cron_conditions() {
        let store = MemTriggerStore::new();
        let every_minute = TriggerCondition::Cron(CronCondition {
            cron_expression: "* * * * *".to_string(),
            min_interval_seconds: 50,
        });
        store.register_condition(&every_minute).await.unwrap();

        assert_eq!(store.get_cron_conditions().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn register_and_get_trigger() {
        let store = MemTriggerStore::new();
        let trigger = derived_or_trigger(TaskId::new("mod", "target"), vec![cid("c1")]);
        let id = trigger.trigger_id.clone();

        store.register_trigger(&trigger).await.unwrap();
        assert!(store.get_trigger(&id).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn get_triggers_for_condition() {
        let store = MemTriggerStore::new();
        let c1 = cid("c1");
        let trigger = derived_or_trigger(TaskId::new("mod", "target"), vec![c1.clone()]);

        store.register_trigger(&trigger).await.unwrap();
        assert_eq!(store.get_triggers_for_condition(&c1).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn remove_triggers_for_task() {
        let store = MemTriggerStore::new();
        let target = TaskId::new("mod", "target");
        let t1 = TriggerDefinitionId::from("t1".to_string());
        let trigger = TriggerDefinitionDTO {
            trigger_id: t1.clone(),
            task_id: target.clone(),
            condition_ids: Vec::new(),
            logic: TriggerLogic::And,
            argument_template: None,
        };
        store.register_trigger(&trigger).await.unwrap();

        assert_eq!(store.remove_triggers_for_task(&target).await.unwrap(), 1);
        assert!(store.get_trigger(&t1).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn valid_condition_lifecycle() {
        let store = MemTriggerStore::new();
        let vc = ValidCondition::new(
            cid("c1"),
            ConditionContext::Event(EventContext {
                event_id: "e1".to_string(),
                event_code: "test".to_string(),
                payload: serde_json::json!({}),
            }),
        );

        store.record_valid_condition(&vc).await.unwrap();
        assert_eq!(store.get_valid_conditions().await.unwrap().len(), 1);

        store
            .clear_valid_conditions(&[vc.valid_condition_id.clone()])
            .await
            .unwrap();
        assert!(store.get_valid_conditions().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn cron_execution_optimistic_lock() {
        let store = MemTriggerStore::new();
        let cron = cid("cron1");
        let first = Utc::now();
        let second = first + chrono::Duration::seconds(60);

        // The first writer succeeds against an empty slot.
        assert!(store.store_cron_execution(&cron, first, None).await.unwrap());
        // A writer still expecting an empty slot loses.
        assert!(!store.store_cron_execution(&cron, first, None).await.unwrap());
        // A writer that saw the latest value wins.
        assert!(store
            .store_cron_execution(&cron, second, Some(first))
            .await
            .unwrap());
    }

    #[tokio::test]
    async fn claim_trigger_run_dedup() {
        let store = MemTriggerStore::new();
        let run = TriggerRunId::from("run-1".to_string());

        let outcomes = [
            store.claim_trigger_run(&run).await.unwrap(),
            store.claim_trigger_run(&run).await.unwrap(),
        ];
        assert_eq!(outcomes, [true, false]);
    }

    #[tokio::test]
    async fn purge_clears_all() {
        let store = MemTriggerStore::new();
        store.register_condition(&event_on("test")).await.unwrap();
        store.purge().await.unwrap();

        assert!(store.get_event_conditions("test").await.unwrap().is_empty());
    }
}