1use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8
9use rustvello_core::error::{RustvelloError, RustvelloResult};
10use rustvello_core::trigger::TriggerStore;
11use rustvello_proto::identifiers::TaskId;
12use rustvello_proto::trigger::{
13 ConditionContext, ConditionId, TriggerCondition, TriggerDefinitionDTO, TriggerDefinitionId,
14 TriggerLogic, TriggerRunId, ValidCondition,
15};
16
17use crate::db::{pg_err, Database};
18
19pub struct PostgresTriggerStore {
21 db: Arc<Database>,
22}
23
24impl PostgresTriggerStore {
25 pub fn new(db: Arc<Database>) -> Self {
26 Self { db }
27 }
28}
29
30fn parse_logic(s: &str) -> RustvelloResult<TriggerLogic> {
31 match s {
32 "AND" => Ok(TriggerLogic::And),
33 "OR" => Ok(TriggerLogic::Or),
34 other => Err(RustvelloError::state_backend(format!(
35 "unknown trigger logic value: {other:?}"
36 ))),
37 }
38}
39
40fn condition_type_tag(condition: &TriggerCondition) -> &'static str {
44 match condition {
45 TriggerCondition::Cron(_) => "Cron",
46 TriggerCondition::Status(_) => "Status",
47 TriggerCondition::Event(_) => "Event",
48 TriggerCondition::Result(_) => "Result",
49 TriggerCondition::Exception(_) => "Exception",
50 TriggerCondition::Composite(_) => "Composite",
51 _ => "Unknown",
52 }
53}
54
55#[async_trait]
56impl TriggerStore for PostgresTriggerStore {
57 async fn register_condition(
58 &self,
59 condition: &TriggerCondition,
60 ) -> RustvelloResult<ConditionId> {
61 let cond_id = condition.condition_id();
62 let json = serde_json::to_string(condition).map_err(|e| RustvelloError::Serialization {
63 message: e.to_string(),
64 })?;
65 let cond_type = condition_type_tag(condition);
66 let event_code: Option<&str> = match condition {
67 TriggerCondition::Event(evt) => Some(&evt.event_code),
68 _ => None,
69 };
70
71 let mut client = self.db.conn().await?;
72 let tx = client.transaction().await.map_err(pg_err)?;
73
74 tx
75 .execute(
76 "INSERT INTO trg_conditions (condition_id, condition_type, condition_json, event_code) VALUES ($1, $2, $3, $4)
77 ON CONFLICT (condition_id) DO UPDATE SET condition_type = $2, condition_json = $3, event_code = $4",
78 &[&cond_id.as_str(), &cond_type, &json, &event_code],
79 )
80 .await
81 .map_err(pg_err)?;
82
83 for task_id in condition.source_task_ids() {
85 tx.execute(
86 "INSERT INTO trg_source_task_conditions (task_id, condition_id) VALUES ($1, $2)
87 ON CONFLICT DO NOTHING",
88 &[&task_id.to_string(), &cond_id.as_str()],
89 )
90 .await
91 .map_err(pg_err)?;
92 }
93
94 tx.commit().await.map_err(pg_err)?;
95
96 Ok(cond_id)
97 }
98
99 async fn get_condition(&self, id: &ConditionId) -> RustvelloResult<Option<TriggerCondition>> {
100 let client = self.db.conn().await?;
101
102 let row = client
103 .query_opt(
104 "SELECT condition_json FROM trg_conditions WHERE condition_id = $1",
105 &[&id.as_str()],
106 )
107 .await
108 .map_err(pg_err)?;
109
110 match row {
111 Some(r) => {
112 let json: String = r.get(0);
113 let cond: TriggerCondition =
114 serde_json::from_str(&json).map_err(|e| RustvelloError::Serialization {
115 message: e.to_string(),
116 })?;
117 Ok(Some(cond))
118 }
119 None => Ok(None),
120 }
121 }
122
123 async fn get_conditions_for_task(
124 &self,
125 task_id: &TaskId,
126 ) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
127 let client = self.db.conn().await?;
128
129 let rows = client
130 .query(
131 "SELECT c.condition_id, c.condition_json
132 FROM trg_conditions c
133 INNER JOIN trg_source_task_conditions stc ON c.condition_id = stc.condition_id
134 WHERE stc.task_id = $1",
135 &[&task_id.to_string()],
136 )
137 .await
138 .map_err(pg_err)?;
139
140 let mut result = Vec::new();
141 for row in &rows {
142 let id: String = row.get(0);
143 let json: String = row.get(1);
144 let cond: TriggerCondition =
145 serde_json::from_str(&json).map_err(|e| RustvelloError::Serialization {
146 message: e.to_string(),
147 })?;
148 result.push((ConditionId::from(id), cond));
149 }
150 Ok(result)
151 }
152
153 async fn get_cron_conditions(&self) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
154 let client = self.db.conn().await?;
155
156 let rows = client
157 .query(
158 "SELECT condition_id, condition_json FROM trg_conditions WHERE condition_type = 'Cron'",
159 &[],
160 )
161 .await
162 .map_err(pg_err)?;
163
164 let mut result = Vec::new();
165 for row in &rows {
166 let id: String = row.get(0);
167 let json: String = row.get(1);
168 let cond: TriggerCondition =
169 serde_json::from_str(&json).map_err(|e| RustvelloError::Serialization {
170 message: e.to_string(),
171 })?;
172 result.push((ConditionId::from(id), cond));
173 }
174 Ok(result)
175 }
176
177 async fn get_event_conditions(
178 &self,
179 event_code: &str,
180 ) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
181 let client = self.db.conn().await?;
182
183 let rows = client
184 .query(
185 "SELECT condition_id, condition_json FROM trg_conditions \
186 WHERE condition_type = 'Event' AND event_code = $1",
187 &[&event_code],
188 )
189 .await
190 .map_err(pg_err)?;
191
192 let mut result = Vec::new();
193 for row in &rows {
194 let id: String = row.get(0);
195 let json: String = row.get(1);
196 let cond: TriggerCondition =
197 serde_json::from_str(&json).map_err(|e| RustvelloError::Serialization {
198 message: e.to_string(),
199 })?;
200 result.push((ConditionId::from(id), cond));
201 }
202 Ok(result)
203 }
204
205 async fn register_trigger(&self, trigger: &TriggerDefinitionDTO) -> RustvelloResult<()> {
206 let mut client = self.db.conn().await?;
207 let tx = client.transaction().await.map_err(pg_err)?;
208 let arg_tmpl = trigger.argument_template.as_ref().map(ToString::to_string);
209
210 tx
211 .execute(
212 "INSERT INTO trg_triggers (trigger_id, task_id, logic, argument_template) VALUES ($1, $2, $3, $4)
213 ON CONFLICT (trigger_id) DO UPDATE SET task_id = $2, logic = $3, argument_template = $4",
214 &[
215 &trigger.trigger_id.as_str(),
216 &trigger.task_id.to_string(),
217 &trigger.logic.to_string(),
218 &arg_tmpl as &(dyn tokio_postgres::types::ToSql + Sync),
219 ],
220 )
221 .await
222 .map_err(pg_err)?;
223
224 for cid in &trigger.condition_ids {
226 tx.execute(
227 "INSERT INTO trg_condition_triggers (condition_id, trigger_id) VALUES ($1, $2)
228 ON CONFLICT DO NOTHING",
229 &[&cid.as_str(), &trigger.trigger_id.as_str()],
230 )
231 .await
232 .map_err(pg_err)?;
233 }
234
235 tx.commit().await.map_err(pg_err)?;
236 Ok(())
237 }
238
239 async fn get_trigger(
240 &self,
241 id: &TriggerDefinitionId,
242 ) -> RustvelloResult<Option<TriggerDefinitionDTO>> {
243 let client = self.db.conn().await?;
244
245 let row = client
246 .query_opt(
247 "SELECT task_id, logic, argument_template FROM trg_triggers WHERE trigger_id = $1",
248 &[&id.as_str()],
249 )
250 .await
251 .map_err(pg_err)?;
252
253 match row {
254 Some(r) => {
255 let task_id_str: String = r.get(0);
256 let logic_str: String = r.get(1);
257 let arg_tmpl: Option<String> = r.get(2);
258
259 let task_id: TaskId = task_id_str.parse().map_err(|e| {
260 RustvelloError::state_backend(format!("invalid task_id in database: {e}"))
261 })?;
262 let logic = parse_logic(&logic_str)?;
263 let argument_template = arg_tmpl
264 .map(|s| serde_json::from_str(&s))
265 .transpose()
266 .map_err(|e| RustvelloError::Serialization {
267 message: e.to_string(),
268 })?;
269
270 let condition_ids = self
271 .get_condition_ids_for_trigger(&client, id.as_str())
272 .await?;
273
274 Ok(Some(TriggerDefinitionDTO {
275 trigger_id: id.clone(),
276 task_id,
277 condition_ids,
278 logic,
279 argument_template,
280 }))
281 }
282 None => Ok(None),
283 }
284 }
285
286 async fn get_triggers_for_condition(
287 &self,
288 cond_id: &ConditionId,
289 ) -> RustvelloResult<Vec<TriggerDefinitionDTO>> {
290 let client = self.db.conn().await?;
291
292 let rows = client
293 .query(
294 "SELECT t.trigger_id, t.task_id, t.logic, t.argument_template
295 FROM trg_triggers t
296 INNER JOIN trg_condition_triggers ct ON t.trigger_id = ct.trigger_id
297 WHERE ct.condition_id = $1",
298 &[&cond_id.as_str()],
299 )
300 .await
301 .map_err(pg_err)?;
302
303 let trigger_ids: Vec<String> = rows.iter().map(|r| r.get::<_, String>(0)).collect();
305 let mut cond_map = self
306 .get_condition_ids_for_triggers(&client, &trigger_ids)
307 .await?;
308
309 let mut result = Vec::new();
310 for row in &rows {
311 let trigger_id: String = row.get(0);
312 let task_id_str: String = row.get(1);
313 let logic_str: String = row.get(2);
314 let arg_tmpl: Option<String> = row.get(3);
315
316 let task_id: TaskId = task_id_str.parse().map_err(|e| {
317 RustvelloError::state_backend(format!("invalid task_id in database: {e}"))
318 })?;
319 let logic = parse_logic(&logic_str)?;
320 let argument_template = arg_tmpl
321 .map(|s| serde_json::from_str(&s))
322 .transpose()
323 .map_err(|e| RustvelloError::Serialization {
324 message: e.to_string(),
325 })?;
326 let condition_ids = cond_map.remove(&trigger_id).unwrap_or_default();
327
328 result.push(TriggerDefinitionDTO {
329 trigger_id: TriggerDefinitionId::from(trigger_id),
330 task_id,
331 condition_ids,
332 logic,
333 argument_template,
334 });
335 }
336 Ok(result)
337 }
338
339 async fn remove_triggers_for_task(&self, task_id: &TaskId) -> RustvelloResult<u32> {
340 let client = self.db.conn().await?;
341 let task_str = task_id.to_string();
342
343 client
345 .execute(
346 "DELETE FROM trg_condition_triggers WHERE trigger_id IN \
347 (SELECT trigger_id FROM trg_triggers WHERE task_id = $1)",
348 &[&task_str],
349 )
350 .await
351 .map_err(pg_err)?;
352
353 let deleted = client
355 .execute("DELETE FROM trg_triggers WHERE task_id = $1", &[&task_str])
356 .await
357 .map_err(pg_err)?;
358
359 Ok(u32::try_from(deleted).unwrap_or(u32::MAX))
360 }
361
362 async fn record_valid_condition(&self, vc: &ValidCondition) -> RustvelloResult<()> {
363 let context_json =
364 serde_json::to_string(&vc.context).map_err(|e| RustvelloError::Serialization {
365 message: e.to_string(),
366 })?;
367
368 let client = self.db.conn().await?;
369 client
370 .execute(
371 "INSERT INTO trg_valid_conditions (valid_condition_id, condition_id, context_json) VALUES ($1, $2, $3)
372 ON CONFLICT (valid_condition_id) DO UPDATE SET condition_id = $2, context_json = $3",
373 &[&vc.valid_condition_id, &vc.condition_id.as_str(), &context_json],
374 )
375 .await
376 .map_err(pg_err)?;
377 Ok(())
378 }
379
380 async fn get_valid_conditions(&self) -> RustvelloResult<Vec<ValidCondition>> {
381 let client = self.db.conn().await?;
382
383 let rows = client
384 .query(
385 "SELECT valid_condition_id, condition_id, context_json FROM trg_valid_conditions",
386 &[],
387 )
388 .await
389 .map_err(pg_err)?;
390
391 let mut result = Vec::new();
392 for row in &rows {
393 let vc_id: String = row.get(0);
394 let cond_id: String = row.get(1);
395 let ctx_json: String = row.get(2);
396 let context: ConditionContext =
397 serde_json::from_str(&ctx_json).map_err(|e| RustvelloError::Serialization {
398 message: e.to_string(),
399 })?;
400 result.push(ValidCondition {
401 valid_condition_id: vc_id,
402 condition_id: ConditionId::from(cond_id),
403 context,
404 });
405 }
406 Ok(result)
407 }
408
409 async fn clear_valid_conditions(&self, ids: &[String]) -> RustvelloResult<()> {
410 if ids.is_empty() {
411 return Ok(());
412 }
413 let client = self.db.conn().await?;
414 client
415 .execute(
416 "DELETE FROM trg_valid_conditions WHERE valid_condition_id = ANY($1)",
417 &[&ids],
418 )
419 .await
420 .map_err(pg_err)?;
421 Ok(())
422 }
423
424 async fn get_last_cron_execution(
425 &self,
426 cond_id: &ConditionId,
427 ) -> RustvelloResult<Option<DateTime<Utc>>> {
428 let client = self.db.conn().await?;
429
430 let row = client
431 .query_opt(
432 "SELECT last_execution FROM trg_cron_executions WHERE condition_id = $1",
433 &[&cond_id.as_str()],
434 )
435 .await
436 .map_err(pg_err)?;
437
438 Ok(row.map(|r| r.get::<_, DateTime<Utc>>(0)))
439 }
440
441 async fn store_cron_execution(
442 &self,
443 cond_id: &ConditionId,
444 time: DateTime<Utc>,
445 expected_last: Option<DateTime<Utc>>,
446 ) -> RustvelloResult<bool> {
447 let client = self.db.conn().await?;
448
449 let changed = match expected_last {
450 None => {
451 client
453 .execute(
454 "INSERT INTO trg_cron_executions (condition_id, last_execution) VALUES ($1, $2)
455 ON CONFLICT DO NOTHING",
456 &[&cond_id.as_str(), &time],
457 )
458 .await
459 .map_err(pg_err)?
460 }
461 Some(expected) => client
462 .execute(
463 "UPDATE trg_cron_executions SET last_execution = $1
464 WHERE condition_id = $2 AND last_execution = $3",
465 &[&time, &cond_id.as_str(), &expected],
466 )
467 .await
468 .map_err(pg_err)?,
469 };
470
471 Ok(changed > 0)
472 }
473
474 async fn claim_trigger_run(&self, run_id: &TriggerRunId) -> RustvelloResult<bool> {
475 let client = self.db.conn().await?;
476 let now = Utc::now();
477
478 let changed = client
479 .execute(
480 "INSERT INTO trg_trigger_run_claims (trigger_run_id, claimed_at) VALUES ($1, $2)
481 ON CONFLICT DO NOTHING",
482 &[&run_id.as_str(), &now],
483 )
484 .await
485 .map_err(pg_err)?;
486
487 Ok(changed > 0)
488 }
489
490 async fn purge(&self) -> RustvelloResult<()> {
491 let client = self.db.conn().await?;
492 client
493 .batch_execute(
494 "DELETE FROM trg_trigger_run_claims;
495 DELETE FROM trg_cron_executions;
496 DELETE FROM trg_valid_conditions;
497 DELETE FROM trg_source_task_conditions;
498 DELETE FROM trg_condition_triggers;
499 DELETE FROM trg_triggers;
500 DELETE FROM trg_conditions;",
501 )
502 .await
503 .map_err(pg_err)?;
504 Ok(())
505 }
506
507 async fn get_all_conditions(&self) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
508 let client = self.db.conn().await?;
509
510 let rows = client
511 .query(
512 "SELECT condition_id, condition_json FROM trg_conditions",
513 &[],
514 )
515 .await
516 .map_err(pg_err)?;
517
518 let mut result = Vec::new();
519 for row in &rows {
520 let id: String = row.get(0);
521 let json: String = row.get(1);
522 let cond: TriggerCondition =
523 serde_json::from_str(&json).map_err(|e| RustvelloError::Serialization {
524 message: e.to_string(),
525 })?;
526 result.push((ConditionId::from(id), cond));
527 }
528 Ok(result)
529 }
530}
531
532impl PostgresTriggerStore {
534 async fn get_condition_ids_for_trigger(
535 &self,
536 client: &deadpool_postgres::Client,
537 trigger_id: &str,
538 ) -> RustvelloResult<Vec<ConditionId>> {
539 let rows = client
540 .query(
541 "SELECT condition_id FROM trg_condition_triggers WHERE trigger_id = $1",
542 &[&trigger_id],
543 )
544 .await
545 .map_err(pg_err)?;
546
547 Ok(rows
548 .iter()
549 .map(|r| ConditionId::from(r.get::<_, String>(0)))
550 .collect())
551 }
552
553 async fn get_condition_ids_for_triggers(
555 &self,
556 client: &deadpool_postgres::Client,
557 trigger_ids: &[String],
558 ) -> RustvelloResult<HashMap<String, Vec<ConditionId>>> {
559 if trigger_ids.is_empty() {
560 return Ok(HashMap::new());
561 }
562 let rows = client
563 .query(
564 "SELECT trigger_id, condition_id FROM trg_condition_triggers \
565 WHERE trigger_id = ANY($1)",
566 &[&trigger_ids],
567 )
568 .await
569 .map_err(pg_err)?;
570
571 let mut map: HashMap<String, Vec<ConditionId>> = HashMap::new();
572 for row in &rows {
573 let tid: String = row.get(0);
574 let cid: String = row.get(1);
575 map.entry(tid).or_default().push(ConditionId::from(cid));
576 }
577 Ok(map)
578 }
579}