Skip to main content

rustvello_mem/
trigger.rs

1//! In-memory trigger store implementation.
2
3use std::collections::HashMap;
4use tokio::sync::Mutex;
5
6use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8
9use rustvello_core::error::RustvelloResult;
10use rustvello_core::trigger::TriggerStore;
11use rustvello_proto::identifiers::TaskId;
12use rustvello_proto::trigger::{
13    ConditionId, TriggerCondition, TriggerDefinitionDTO, TriggerDefinitionId, TriggerRunId,
14    ValidCondition,
15};
16
17struct TriggerState {
18    conditions: HashMap<String, TriggerCondition>,
19    /// source_task_id -> list of condition IDs
20    source_task_conditions: HashMap<String, Vec<ConditionId>>,
21    /// event_code -> list of condition IDs
22    event_conditions: HashMap<String, Vec<ConditionId>>,
23    /// condition IDs of cron conditions
24    cron_condition_ids: Vec<ConditionId>,
25    triggers: HashMap<String, TriggerDefinitionDTO>,
26    /// condition_id -> list of trigger IDs
27    condition_triggers: HashMap<String, Vec<TriggerDefinitionId>>,
28    valid_conditions: HashMap<String, ValidCondition>,
29    cron_executions: HashMap<String, DateTime<Utc>>,
30    trigger_run_claims: HashMap<String, DateTime<Utc>>,
31}
32
/// In-memory trigger store for testing and development.
///
/// All state lives in one `TriggerState` behind a single `tokio::sync::Mutex`.
/// Each `TriggerStore` method acquires that lock for the duration of its work.
/// Nothing is persisted: the data exists only in the process's `HashMap`s and
/// `Vec`s.
pub struct MemTriggerStore {
    // Single async lock guarding every table and index together, so index
    // updates and primary-table updates are always observed consistently.
    state: Mutex<TriggerState>,
}
37
38impl MemTriggerStore {
39    pub fn new() -> Self {
40        Self {
41            state: Mutex::new(TriggerState {
42                conditions: HashMap::new(),
43                source_task_conditions: HashMap::new(),
44                event_conditions: HashMap::new(),
45                cron_condition_ids: Vec::new(),
46                triggers: HashMap::new(),
47                condition_triggers: HashMap::new(),
48                valid_conditions: HashMap::new(),
49                cron_executions: HashMap::new(),
50                trigger_run_claims: HashMap::new(),
51            }),
52        }
53    }
54}
55
56impl Default for MemTriggerStore {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62#[async_trait]
63impl TriggerStore for MemTriggerStore {
64    async fn register_condition(
65        &self,
66        condition: &TriggerCondition,
67    ) -> RustvelloResult<ConditionId> {
68        let cond_id = condition.condition_id();
69        let mut state = self.state.lock().await;
70
71        state
72            .conditions
73            .insert(cond_id.as_str().to_owned(), condition.clone());
74
75        // Index by source task IDs
76        for task_id in condition.source_task_ids() {
77            let vec = state
78                .source_task_conditions
79                .entry(task_id.to_string())
80                .or_default();
81            if !vec.contains(&cond_id) {
82                vec.push(cond_id.clone());
83            }
84        }
85
86        // Index by event code
87        if let TriggerCondition::Event(evt) = condition {
88            let vec = state
89                .event_conditions
90                .entry(evt.event_code.clone())
91                .or_default();
92            if !vec.contains(&cond_id) {
93                vec.push(cond_id.clone());
94            }
95        }
96
97        // Track cron conditions
98        if matches!(condition, TriggerCondition::Cron(_))
99            && !state.cron_condition_ids.contains(&cond_id)
100        {
101            state.cron_condition_ids.push(cond_id.clone());
102        }
103
104        Ok(cond_id)
105    }
106
107    async fn get_condition(&self, id: &ConditionId) -> RustvelloResult<Option<TriggerCondition>> {
108        let state = self.state.lock().await;
109        Ok(state.conditions.get(id.as_str()).cloned())
110    }
111
112    async fn get_conditions_for_task(
113        &self,
114        task_id: &TaskId,
115    ) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
116        let state = self.state.lock().await;
117        let key = task_id.to_string();
118        let cond_ids = state.source_task_conditions.get(&key);
119
120        let mut result = Vec::new();
121        if let Some(ids) = cond_ids {
122            for cid in ids {
123                if let Some(cond) = state.conditions.get(cid.as_str()) {
124                    result.push((cid.clone(), cond.clone()));
125                }
126            }
127        }
128        Ok(result)
129    }
130
131    async fn get_cron_conditions(&self) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
132        let state = self.state.lock().await;
133        let mut result = Vec::new();
134        for cid in &state.cron_condition_ids {
135            if let Some(cond) = state.conditions.get(cid.as_str()) {
136                result.push((cid.clone(), cond.clone()));
137            }
138        }
139        Ok(result)
140    }
141
142    async fn get_event_conditions(
143        &self,
144        event_code: &str,
145    ) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
146        let state = self.state.lock().await;
147        let cond_ids = state.event_conditions.get(event_code);
148
149        let mut result = Vec::new();
150        if let Some(ids) = cond_ids {
151            for cid in ids {
152                if let Some(cond) = state.conditions.get(cid.as_str()) {
153                    result.push((cid.clone(), cond.clone()));
154                }
155            }
156        }
157        Ok(result)
158    }
159
160    async fn register_trigger(&self, trigger: &TriggerDefinitionDTO) -> RustvelloResult<()> {
161        let mut state = self.state.lock().await;
162
163        state
164            .triggers
165            .insert(trigger.trigger_id.as_str().to_owned(), trigger.clone());
166
167        // Index condition -> trigger
168        for cid in &trigger.condition_ids {
169            state
170                .condition_triggers
171                .entry(cid.as_str().to_owned())
172                .or_default()
173                .push(trigger.trigger_id.clone());
174        }
175
176        Ok(())
177    }
178
179    async fn get_trigger(
180        &self,
181        id: &TriggerDefinitionId,
182    ) -> RustvelloResult<Option<TriggerDefinitionDTO>> {
183        let state = self.state.lock().await;
184        Ok(state.triggers.get(id.as_str()).cloned())
185    }
186
187    async fn get_triggers_for_condition(
188        &self,
189        cond_id: &ConditionId,
190    ) -> RustvelloResult<Vec<TriggerDefinitionDTO>> {
191        let state = self.state.lock().await;
192        let trigger_ids = state.condition_triggers.get(cond_id.as_str());
193
194        let mut result = Vec::new();
195        if let Some(ids) = trigger_ids {
196            for tid in ids {
197                if let Some(trigger) = state.triggers.get(tid.as_str()) {
198                    result.push(trigger.clone());
199                }
200            }
201        }
202        Ok(result)
203    }
204
205    async fn remove_triggers_for_task(&self, task_id: &TaskId) -> RustvelloResult<u32> {
206        let mut state = self.state.lock().await;
207        let task_str = task_id.to_string();
208
209        let ids_to_remove: Vec<String> = state
210            .triggers
211            .iter()
212            .filter(|(_, t)| t.task_id.to_string() == task_str)
213            .map(|(id, _)| id.clone())
214            .collect();
215
216        let count = u32::try_from(ids_to_remove.len()).unwrap_or(u32::MAX);
217        for id in &ids_to_remove {
218            if let Some(trigger) = state.triggers.remove(id) {
219                // Remove from condition_triggers index
220                for cid in &trigger.condition_ids {
221                    if let Some(tids) = state.condition_triggers.get_mut(cid.as_str()) {
222                        tids.retain(|tid| tid.as_str() != *id);
223                    }
224                }
225            }
226        }
227
228        Ok(count)
229    }
230
231    async fn record_valid_condition(&self, vc: &ValidCondition) -> RustvelloResult<()> {
232        let mut state = self.state.lock().await;
233        state
234            .valid_conditions
235            .insert(vc.valid_condition_id.clone(), vc.clone());
236        Ok(())
237    }
238
239    async fn get_valid_conditions(&self) -> RustvelloResult<Vec<ValidCondition>> {
240        let state = self.state.lock().await;
241        Ok(state.valid_conditions.values().cloned().collect())
242    }
243
244    async fn clear_valid_conditions(&self, ids: &[String]) -> RustvelloResult<()> {
245        let mut state = self.state.lock().await;
246        for id in ids {
247            state.valid_conditions.remove(id);
248        }
249        Ok(())
250    }
251
252    async fn get_last_cron_execution(
253        &self,
254        cond_id: &ConditionId,
255    ) -> RustvelloResult<Option<DateTime<Utc>>> {
256        let state = self.state.lock().await;
257        Ok(state.cron_executions.get(cond_id.as_str()).copied())
258    }
259
260    async fn store_cron_execution(
261        &self,
262        cond_id: &ConditionId,
263        time: DateTime<Utc>,
264        expected_last: Option<DateTime<Utc>>,
265    ) -> RustvelloResult<bool> {
266        let mut state = self.state.lock().await;
267        let current = state.cron_executions.get(cond_id.as_str()).copied();
268
269        // Optimistic locking: only update if current matches expected
270        if current == expected_last {
271            state
272                .cron_executions
273                .insert(cond_id.as_str().to_owned(), time);
274            Ok(true)
275        } else {
276            Ok(false)
277        }
278    }
279
280    async fn claim_trigger_run(&self, run_id: &TriggerRunId) -> RustvelloResult<bool> {
281        let mut state = self.state.lock().await;
282        if state.trigger_run_claims.contains_key(run_id.as_str()) {
283            Ok(false)
284        } else {
285            state
286                .trigger_run_claims
287                .insert(run_id.as_str().to_owned(), Utc::now());
288            Ok(true)
289        }
290    }
291
292    async fn purge(&self) -> RustvelloResult<()> {
293        let mut state = self.state.lock().await;
294        state.conditions.clear();
295        state.source_task_conditions.clear();
296        state.event_conditions.clear();
297        state.cron_condition_ids.clear();
298        state.triggers.clear();
299        state.condition_triggers.clear();
300        state.valid_conditions.clear();
301        state.cron_executions.clear();
302        state.trigger_run_claims.clear();
303        Ok(())
304    }
305
306    async fn get_all_conditions(&self) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
307        let state = self.state.lock().await;
308        Ok(state
309            .conditions
310            .iter()
311            .map(|(id, cond)| (ConditionId::from(id.clone()), cond.clone()))
312            .collect())
313    }
314}
315
#[cfg(test)]
mod tests {
    //! Unit tests for `MemTriggerStore`. Each test builds a fresh store and
    //! exercises one part of the `TriggerStore` contract.

    use super::*;
    use rustvello_proto::trigger::*;

    /// Event condition on `code` with no payload filter.
    fn event_condition(code: &str) -> TriggerCondition {
        TriggerCondition::Event(EventCondition {
            event_code: code.to_string(),
            payload_filter: None,
        })
    }

    /// `ConditionId` built from a string literal.
    fn condition_id(raw: &str) -> ConditionId {
        ConditionId::from(raw.to_string())
    }

    /// OR-trigger on `task_id` whose id comes from
    /// `TriggerDefinitionDTO::compute_trigger_id`.
    fn or_trigger(task_id: TaskId, condition_ids: Vec<ConditionId>) -> TriggerDefinitionDTO {
        let trigger_id =
            TriggerDefinitionDTO::compute_trigger_id(&task_id, &condition_ids, TriggerLogic::Or);
        TriggerDefinitionDTO {
            trigger_id,
            task_id,
            condition_ids,
            logic: TriggerLogic::Or,
            argument_template: None,
        }
    }

    #[tokio::test]
    async fn register_and_get_condition() {
        let store = MemTriggerStore::new();
        let condition = event_condition("payment");

        let id = store.register_condition(&condition).await.unwrap();
        let stored = store
            .get_condition(&id)
            .await
            .unwrap()
            .expect("registered condition should be retrievable");

        assert_eq!(stored.condition_id(), id);
    }

    #[tokio::test]
    async fn get_conditions_for_task() {
        let store = MemTriggerStore::new();
        let watched = TaskId::new("mod", "task");
        let on_success = TriggerCondition::Status(StatusCondition {
            task_id: watched.clone(),
            statuses: vec![rustvello_proto::status::InvocationStatus::Success],
            argument_filter: None,
        });
        store.register_condition(&on_success).await.unwrap();

        let for_watched = store.get_conditions_for_task(&watched).await.unwrap();
        assert_eq!(for_watched.len(), 1);

        let unrelated = TaskId::new("mod", "other");
        let for_unrelated = store.get_conditions_for_task(&unrelated).await.unwrap();
        assert!(for_unrelated.is_empty());
    }

    #[tokio::test]
    async fn get_event_conditions() {
        let store = MemTriggerStore::new();
        store
            .register_condition(&event_condition("payment"))
            .await
            .unwrap();

        let matching = store.get_event_conditions("payment").await.unwrap();
        assert_eq!(matching.len(), 1);

        let non_matching = store.get_event_conditions("other").await.unwrap();
        assert!(non_matching.is_empty());
    }

    #[tokio::test]
    async fn get_cron_conditions() {
        let store = MemTriggerStore::new();
        let every_minute = TriggerCondition::Cron(CronCondition {
            cron_expression: "* * * * *".to_string(),
            min_interval_seconds: 50,
        });
        store.register_condition(&every_minute).await.unwrap();

        let crons = store.get_cron_conditions().await.unwrap();
        assert_eq!(crons.len(), 1);
    }

    #[tokio::test]
    async fn register_and_get_trigger() {
        let store = MemTriggerStore::new();
        let trigger = or_trigger(TaskId::new("mod", "target"), vec![condition_id("c1")]);
        store.register_trigger(&trigger).await.unwrap();

        let fetched = store.get_trigger(&trigger.trigger_id).await.unwrap();
        assert!(fetched.is_some());
    }

    #[tokio::test]
    async fn get_triggers_for_condition() {
        let store = MemTriggerStore::new();
        let c1 = condition_id("c1");
        let trigger = or_trigger(TaskId::new("mod", "target"), vec![c1.clone()]);
        store.register_trigger(&trigger).await.unwrap();

        let matching = store.get_triggers_for_condition(&c1).await.unwrap();
        assert_eq!(matching.len(), 1);
    }

    #[tokio::test]
    async fn remove_triggers_for_task() {
        let store = MemTriggerStore::new();
        let task_id = TaskId::new("mod", "target");
        let trigger_id = TriggerDefinitionId::from("t1".to_string());
        let trigger = TriggerDefinitionDTO {
            trigger_id: trigger_id.clone(),
            task_id: task_id.clone(),
            condition_ids: vec![],
            logic: TriggerLogic::And,
            argument_template: None,
        };
        store.register_trigger(&trigger).await.unwrap();

        let removed = store.remove_triggers_for_task(&task_id).await.unwrap();
        assert_eq!(removed, 1);

        let after_removal = store.get_trigger(&trigger_id).await.unwrap();
        assert!(after_removal.is_none());
    }

    #[tokio::test]
    async fn valid_condition_lifecycle() {
        let store = MemTriggerStore::new();
        let context = ConditionContext::Event(EventContext {
            event_id: "e1".to_string(),
            event_code: "test".to_string(),
            payload: serde_json::json!({}),
        });
        let vc = ValidCondition::new(condition_id("c1"), context);
        let vc_id = vc.valid_condition_id.clone();

        store.record_valid_condition(&vc).await.unwrap();
        let recorded = store.get_valid_conditions().await.unwrap();
        assert_eq!(recorded.len(), 1);

        store.clear_valid_conditions(&[vc_id]).await.unwrap();
        let remaining = store.get_valid_conditions().await.unwrap();
        assert!(remaining.is_empty());
    }

    #[tokio::test]
    async fn cron_execution_optimistic_lock() {
        let store = MemTriggerStore::new();
        let cron_id = condition_id("cron1");
        let first_run = Utc::now();
        let second_run = first_run + chrono::Duration::seconds(60);

        // Nothing stored yet, so expecting `None` succeeds.
        let accepted = store
            .store_cron_execution(&cron_id, first_run, None)
            .await
            .unwrap();
        assert!(accepted);

        // The expectation is now stale (a value is stored), so this is rejected.
        let accepted = store
            .store_cron_execution(&cron_id, first_run, None)
            .await
            .unwrap();
        assert!(!accepted);

        // Expecting the value that is actually stored succeeds.
        let accepted = store
            .store_cron_execution(&cron_id, second_run, Some(first_run))
            .await
            .unwrap();
        assert!(accepted);
    }

    #[tokio::test]
    async fn claim_trigger_run_dedup() {
        let store = MemTriggerStore::new();
        let run_id = TriggerRunId::from("run-1".to_string());

        let first_claim = store.claim_trigger_run(&run_id).await.unwrap();
        let second_claim = store.claim_trigger_run(&run_id).await.unwrap();

        assert!(first_claim);
        assert!(!second_claim);
    }

    #[tokio::test]
    async fn purge_clears_all() {
        let store = MemTriggerStore::new();
        store
            .register_condition(&event_condition("test"))
            .await
            .unwrap();

        store.purge().await.unwrap();

        let leftover = store.get_event_conditions("test").await.unwrap();
        assert!(leftover.is_empty());
    }
}