Skip to main content

rustvello_mongo/
trigger.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use chrono::{DateTime, Utc};
5use mongodb::bson::doc;
6use mongodb::error::{ErrorKind, WriteFailure};
7
8use rustvello_core::error::{RustvelloError, RustvelloResult};
9use rustvello_core::trigger::TriggerStore;
10use rustvello_proto::identifiers::TaskId;
11use rustvello_proto::trigger::{
12    ConditionId, TriggerCondition, TriggerDefinitionDTO, TriggerDefinitionId, TriggerRunId,
13    ValidCondition,
14};
15
16use crate::connection::{mongo_err, MongoPool};
17
// MongoDB collection names used by `MongoTriggerStore`.

/// Serialized trigger conditions, one document per condition id.
const COND_COL: &str = "trg_conditions";
/// Serialized trigger definitions, one document per trigger id.
const TRIGGER_COL: &str = "trg_definitions";
/// Recorded valid conditions, one document per valid-condition id.
const VALID_COL: &str = "trg_valid_conditions";
/// Last cron execution timestamp, one document per condition id.
const CRON_EXEC_COL: &str = "trg_cron_executions";
/// Claim markers for trigger runs, one document per run id.
const RUN_COL: &str = "trg_runs";
23
/// MongoDB-backed trigger store.
///
/// Implements [`TriggerStore`] on top of a shared [`MongoPool`]; every
/// operation requests its database handle from that pool.
#[non_exhaustive]
pub struct MongoTriggerStore {
    /// Pool from which each operation obtains a database handle.
    pool: Arc<MongoPool>,
}
29
impl MongoTriggerStore {
    /// Creates a store that uses `pool` for all database access.
    ///
    /// The pool is only stored here; no database call is made.
    pub fn new(pool: Arc<MongoPool>) -> Self {
        Self { pool }
    }
}
35
36#[async_trait]
37impl TriggerStore for MongoTriggerStore {
38    async fn register_condition(
39        &self,
40        condition: &TriggerCondition,
41    ) -> RustvelloResult<ConditionId> {
42        let cond_id = condition.condition_id();
43        let db = self.pool.db().await?;
44        let col = db.collection::<mongodb::bson::Document>(COND_COL);
45        let json = serde_json::to_string(condition).map_err(|e| RustvelloError::Serialization {
46            message: e.to_string(),
47        })?;
48
49        let task_ids: Vec<String> = condition
50            .source_task_ids()
51            .iter()
52            .map(ToString::to_string)
53            .collect();
54
55        // Classify condition type for efficient server-side queries
56        let condition_type = match condition {
57            TriggerCondition::Cron(_) => "Cron",
58            TriggerCondition::Event(_) => "Event",
59            TriggerCondition::Status(_) => "Status",
60            TriggerCondition::Result(_) => "Result",
61            TriggerCondition::Exception(_) => "Exception",
62            TriggerCondition::Composite(_) => "Composite",
63            _ => "Other",
64        };
65
66        let mut set_fields = doc! {
67            "data": &json,
68            "task_ids": &task_ids,
69            "condition_type": condition_type,
70        };
71
72        // Store event_code for direct querying
73        if let TriggerCondition::Event(ev) = condition {
74            set_fields.insert("event_code", ev.event_code.clone());
75        }
76
77        let update_doc = doc! { "$set": set_fields };
78
79        let filter = doc! { "_id": cond_id.as_str().to_owned() };
80        col.update_one(filter, update_doc)
81            .upsert(true)
82            .await
83            .map_err(mongo_err)?;
84
85        Ok(cond_id)
86    }
87
88    async fn get_condition(&self, id: &ConditionId) -> RustvelloResult<Option<TriggerCondition>> {
89        let db = self.pool.db().await?;
90        let col = db.collection::<mongodb::bson::Document>(COND_COL);
91        let filter = doc! { "_id": &id.as_str() };
92        let result = col.find_one(filter).await.map_err(mongo_err)?;
93        match result {
94            Some(d) => {
95                let s = d
96                    .get_str("data")
97                    .map_err(|e| RustvelloError::state_backend(e.to_string()))?;
98                let c: TriggerCondition =
99                    serde_json::from_str(s).map_err(|e| RustvelloError::Serialization {
100                        message: e.to_string(),
101                    })?;
102                Ok(Some(c))
103            }
104            None => Ok(None),
105        }
106    }
107
108    async fn get_conditions_for_task(
109        &self,
110        task_id: &TaskId,
111    ) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
112        let db = self.pool.db().await?;
113        let col = db.collection::<mongodb::bson::Document>(COND_COL);
114        let filter = doc! { "task_ids": task_id.to_string() };
115        let mut cursor = col.find(filter).await.map_err(mongo_err)?;
116
117        let mut result = Vec::new();
118        use futures_util::StreamExt;
119        while let Some(doc_result) = StreamExt::next(&mut cursor).await {
120            let d = doc_result.map_err(mongo_err)?;
121            if let (Ok(id), Ok(data)) = (d.get_str("_id"), d.get_str("data")) {
122                let cond: TriggerCondition =
123                    serde_json::from_str(data).map_err(|e| RustvelloError::Serialization {
124                        message: e.to_string(),
125                    })?;
126                result.push((ConditionId::from(id.to_string()), cond));
127            }
128        }
129        Ok(result)
130    }
131
132    async fn get_cron_conditions(&self) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
133        let db = self.pool.db().await?;
134        let col = db.collection::<mongodb::bson::Document>(COND_COL);
135        // Server-side filter by condition_type instead of fetching all and filtering client-side
136        let mut cursor = col
137            .find(doc! { "condition_type": "Cron" })
138            .await
139            .map_err(mongo_err)?;
140
141        let mut result = Vec::new();
142        use futures_util::StreamExt;
143        while let Some(doc_result) = StreamExt::next(&mut cursor).await {
144            let d = doc_result.map_err(mongo_err)?;
145            if let (Ok(id), Ok(data)) = (d.get_str("_id"), d.get_str("data")) {
146                let cond: TriggerCondition =
147                    serde_json::from_str(data).map_err(|e| RustvelloError::Serialization {
148                        message: e.to_string(),
149                    })?;
150                result.push((ConditionId::from(id.to_string()), cond));
151            }
152        }
153        Ok(result)
154    }
155
156    async fn get_event_conditions(
157        &self,
158        event_code: &str,
159    ) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
160        let db = self.pool.db().await?;
161        let col = db.collection::<mongodb::bson::Document>(COND_COL);
162        // Server-side filter by condition_type and event_code
163        let mut cursor = col
164            .find(doc! { "condition_type": "Event", "event_code": event_code })
165            .await
166            .map_err(mongo_err)?;
167
168        let mut result = Vec::new();
169        use futures_util::StreamExt;
170        while let Some(doc_result) = StreamExt::next(&mut cursor).await {
171            let d = doc_result.map_err(mongo_err)?;
172            if let (Ok(id), Ok(data)) = (d.get_str("_id"), d.get_str("data")) {
173                let cond: TriggerCondition =
174                    serde_json::from_str(data).map_err(|e| RustvelloError::Serialization {
175                        message: e.to_string(),
176                    })?;
177                result.push((ConditionId::from(id.to_string()), cond));
178            }
179        }
180        Ok(result)
181    }
182
183    async fn register_trigger(&self, trigger: &TriggerDefinitionDTO) -> RustvelloResult<()> {
184        let db = self.pool.db().await?;
185        let col = db.collection::<mongodb::bson::Document>(TRIGGER_COL);
186        let json = serde_json::to_string(trigger).map_err(|e| RustvelloError::Serialization {
187            message: e.to_string(),
188        })?;
189
190        let cond_ids: Vec<String> = trigger
191            .condition_ids
192            .iter()
193            .map(|c| c.as_str().to_owned())
194            .collect();
195
196        let filter = doc! { "_id": &trigger.trigger_id.as_str() };
197        let update = doc! {
198            "$set": {
199                "data": &json,
200                "task_id": trigger.task_id.to_string(),
201                "condition_ids": &cond_ids,
202            }
203        };
204        col.update_one(filter, update)
205            .upsert(true)
206            .await
207            .map_err(mongo_err)?;
208        Ok(())
209    }
210
211    async fn get_trigger(
212        &self,
213        id: &TriggerDefinitionId,
214    ) -> RustvelloResult<Option<TriggerDefinitionDTO>> {
215        let db = self.pool.db().await?;
216        let col = db.collection::<mongodb::bson::Document>(TRIGGER_COL);
217        let filter = doc! { "_id": &id.as_str() };
218        let result = col.find_one(filter).await.map_err(mongo_err)?;
219        match result {
220            Some(d) => {
221                let s = d
222                    .get_str("data")
223                    .map_err(|e| RustvelloError::state_backend(e.to_string()))?;
224                let t: TriggerDefinitionDTO =
225                    serde_json::from_str(s).map_err(|e| RustvelloError::Serialization {
226                        message: e.to_string(),
227                    })?;
228                Ok(Some(t))
229            }
230            None => Ok(None),
231        }
232    }
233
234    async fn get_triggers_for_condition(
235        &self,
236        cond_id: &ConditionId,
237    ) -> RustvelloResult<Vec<TriggerDefinitionDTO>> {
238        let db = self.pool.db().await?;
239        let col = db.collection::<mongodb::bson::Document>(TRIGGER_COL);
240        let filter = doc! { "condition_ids": cond_id.as_str() };
241        let mut cursor = col.find(filter).await.map_err(mongo_err)?;
242
243        let mut result = Vec::new();
244        use futures_util::StreamExt;
245        while let Some(doc_result) = StreamExt::next(&mut cursor).await {
246            let d = doc_result.map_err(mongo_err)?;
247            if let Ok(s) = d.get_str("data") {
248                let t: TriggerDefinitionDTO =
249                    serde_json::from_str(s).map_err(|e| RustvelloError::Serialization {
250                        message: e.to_string(),
251                    })?;
252                result.push(t);
253            }
254        }
255        Ok(result)
256    }
257
258    async fn remove_triggers_for_task(&self, task_id: &TaskId) -> RustvelloResult<u32> {
259        let db = self.pool.db().await?;
260        let col = db.collection::<mongodb::bson::Document>(TRIGGER_COL);
261        let filter = doc! { "task_id": task_id.to_string() };
262        let result = col.delete_many(filter).await.map_err(mongo_err)?;
263        Ok(u32::try_from(result.deleted_count).unwrap_or(u32::MAX))
264    }
265
266    async fn record_valid_condition(&self, vc: &ValidCondition) -> RustvelloResult<()> {
267        let db = self.pool.db().await?;
268        let col = db.collection::<mongodb::bson::Document>(VALID_COL);
269        let json = serde_json::to_string(vc).map_err(|e| RustvelloError::Serialization {
270            message: e.to_string(),
271        })?;
272        let filter = doc! { "_id": &vc.valid_condition_id };
273        let update = doc! { "$set": { "data": &json } };
274        col.update_one(filter, update)
275            .upsert(true)
276            .await
277            .map_err(mongo_err)?;
278        Ok(())
279    }
280
281    async fn get_valid_conditions(&self) -> RustvelloResult<Vec<ValidCondition>> {
282        let db = self.pool.db().await?;
283        let col = db.collection::<mongodb::bson::Document>(VALID_COL);
284        let mut cursor = col.find(doc! {}).await.map_err(mongo_err)?;
285
286        let mut result = Vec::new();
287        use futures_util::StreamExt;
288        while let Some(doc_result) = StreamExt::next(&mut cursor).await {
289            let d = doc_result.map_err(mongo_err)?;
290            if let Ok(s) = d.get_str("data") {
291                let vc: ValidCondition =
292                    serde_json::from_str(s).map_err(|e| RustvelloError::Serialization {
293                        message: e.to_string(),
294                    })?;
295                result.push(vc);
296            }
297        }
298        Ok(result)
299    }
300
301    async fn clear_valid_conditions(&self, ids: &[String]) -> RustvelloResult<()> {
302        if ids.is_empty() {
303            return Ok(());
304        }
305        let db = self.pool.db().await?;
306        let col = db.collection::<mongodb::bson::Document>(VALID_COL);
307        let bson_ids: Vec<mongodb::bson::Bson> = ids
308            .iter()
309            .map(|id| mongodb::bson::Bson::String(id.clone()))
310            .collect();
311        let filter = doc! { "_id": { "$in": bson_ids } };
312        col.delete_many(filter).await.map_err(mongo_err)?;
313        Ok(())
314    }
315
316    async fn get_last_cron_execution(
317        &self,
318        cond_id: &ConditionId,
319    ) -> RustvelloResult<Option<DateTime<Utc>>> {
320        let db = self.pool.db().await?;
321        let col = db.collection::<mongodb::bson::Document>(CRON_EXEC_COL);
322        let filter = doc! { "_id": cond_id.as_str() };
323        let result = col.find_one(filter).await.map_err(mongo_err)?;
324        match result {
325            Some(d) => {
326                let ts = d
327                    .get_str("timestamp")
328                    .map_err(|e| RustvelloError::state_backend(e.to_string()))?;
329                let dt = DateTime::parse_from_rfc3339(ts)
330                    .map(|d| d.with_timezone(&Utc))
331                    .map_err(|e| RustvelloError::Serialization {
332                        message: format!("cron timestamp: {}", e),
333                    })?;
334                Ok(Some(dt))
335            }
336            None => Ok(None),
337        }
338    }
339
340    async fn store_cron_execution(
341        &self,
342        cond_id: &ConditionId,
343        time: DateTime<Utc>,
344        expected_last: Option<DateTime<Utc>>,
345    ) -> RustvelloResult<bool> {
346        let db = self.pool.db().await?;
347        let col = db.collection::<mongodb::bson::Document>(CRON_EXEC_COL);
348
349        // Atomic compare-and-swap: filter includes the expected timestamp value
350        let filter = match expected_last {
351            Some(ts) => doc! { "_id": cond_id.as_str(), "timestamp": ts.to_rfc3339() },
352            None => doc! { "_id": cond_id.as_str(), "timestamp": { "$exists": false } },
353        };
354        let update = doc! { "$set": { "timestamp": time.to_rfc3339() } };
355
356        let result = col
357            .update_one(filter, update)
358            .upsert(expected_last.is_none())
359            .await;
360
361        match result {
362            Ok(r) => Ok(r.matched_count > 0 || r.upserted_id.is_some()),
363            Err(e) => {
364                // Duplicate key error (code 11000) means another runner raced us
365                // on the first execution — treat as "not claimed"
366                if matches!(*e.kind, ErrorKind::Write(WriteFailure::WriteError(ref we)) if we.code == 11000)
367                {
368                    Ok(false)
369                } else {
370                    Err(mongo_err(e))
371                }
372            }
373        }
374    }
375
376    async fn claim_trigger_run(&self, run_id: &TriggerRunId) -> RustvelloResult<bool> {
377        let db = self.pool.db().await?;
378        let col = db.collection::<mongodb::bson::Document>(RUN_COL);
379        let doc = doc! { "_id": run_id.as_str().to_owned(), "claimed": true };
380        match col.insert_one(doc).await {
381            Ok(_) => Ok(true),
382            Err(e) => {
383                // Duplicate key error (code 11000) means already claimed
384                if matches!(*e.kind, ErrorKind::Write(WriteFailure::WriteError(ref we)) if we.code == 11000)
385                {
386                    Ok(false)
387                } else {
388                    Err(mongo_err(e))
389                }
390            }
391        }
392    }
393
394    async fn purge(&self) -> RustvelloResult<()> {
395        let db = self.pool.db().await?;
396        for col_name in [COND_COL, TRIGGER_COL, VALID_COL, CRON_EXEC_COL, RUN_COL] {
397            let col = db.collection::<mongodb::bson::Document>(col_name);
398            col.delete_many(doc! {}).await.map_err(mongo_err)?;
399        }
400        Ok(())
401    }
402
403    async fn get_all_conditions(&self) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
404        let db = self.pool.db().await?;
405        let col = db.collection::<mongodb::bson::Document>(COND_COL);
406        let mut cursor = col.find(doc! {}).await.map_err(mongo_err)?;
407
408        let mut result = Vec::new();
409        use futures_util::StreamExt;
410        while let Some(doc_result) = StreamExt::next(&mut cursor).await {
411            let d = doc_result.map_err(mongo_err)?;
412            if let (Ok(id), Ok(data)) = (d.get_str("_id"), d.get_str("data")) {
413                let cond: TriggerCondition =
414                    serde_json::from_str(data).map_err(|e| RustvelloError::Serialization {
415                        message: e.to_string(),
416                    })?;
417                result.push((ConditionId::from(id.to_string()), cond));
418            }
419        }
420        Ok(result)
421    }
422}