Skip to main content

rustvello_redis/
trigger.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use chrono::{DateTime, Utc};
5use redis::AsyncCommands;
6use tracing;
7
8use rustvello_core::error::{RustvelloError, RustvelloResult};
9use rustvello_core::trigger::TriggerStore;
10use rustvello_proto::identifiers::TaskId;
11use rustvello_proto::trigger::{
12    ConditionId, TriggerCondition, TriggerDefinitionDTO, TriggerDefinitionId, TriggerRunId,
13    ValidCondition,
14};
15
16use crate::connection::{redis_err, scan_keys, RedisPool};
17
18/// Batch-fetch conditions by ID using MGET instead of N individual GETs.
19async fn batch_get_conditions(
20    conn: &mut redis::aio::MultiplexedConnection,
21    member_ids: &[String],
22    cond_prefix: &str,
23) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
24    if member_ids.is_empty() {
25        return Ok(Vec::new());
26    }
27    let keys: Vec<String> = member_ids
28        .iter()
29        .map(|id| {
30            let mut k = String::with_capacity(cond_prefix.len() + id.len());
31            k.push_str(cond_prefix);
32            k.push_str(id);
33            k
34        })
35        .collect();
36    let values: Vec<Option<String>> = redis::cmd("MGET")
37        .arg(&keys)
38        .query_async(conn)
39        .await
40        .map_err(redis_err)?;
41    let mut result = Vec::with_capacity(member_ids.len());
42    for (cid_str, val) in member_ids.iter().zip(values) {
43        if let Some(json) = val {
44            match serde_json::from_str::<TriggerCondition>(&json) {
45                Ok(cond) => result.push((ConditionId::from(cid_str.clone()), cond)),
46                Err(e) => {
47                    tracing::warn!("Failed to deserialize condition {}: {}", cid_str, e);
48                }
49            }
50        }
51    }
52    Ok(result)
53}
54
55/// Redis-backed trigger store.
56#[non_exhaustive]
57pub struct RedisTriggerStore {
58    pool: Arc<RedisPool>,
59    cond_prefix: String,
60    cond_task_prefix: String,
61    trigger_prefix: String,
62    cond_trigger_prefix: String,
63    valid_cond_prefix: String,
64    cron_exec_prefix: String,
65    run_prefix: String,
66    trigger_task_prefix: String,
67    cron_index: String,
68    event_index_prefix: String,
69}
70
71impl RedisTriggerStore {
72    pub fn new(pool: Arc<RedisPool>) -> Self {
73        let p = pool.prefix();
74        Self {
75            cond_prefix: format!("{p}trg:cond:"),
76            cond_task_prefix: format!("{p}trg:cond_task:"),
77            trigger_prefix: format!("{p}trg:def:"),
78            cond_trigger_prefix: format!("{p}trg:cond_trg:"),
79            valid_cond_prefix: format!("{p}trg:valid:"),
80            cron_exec_prefix: format!("{p}trg:cron_exec:"),
81            run_prefix: format!("{p}trg:run:"),
82            trigger_task_prefix: format!("{p}trg:trg_task:"),
83            cron_index: format!("{p}trg:cron_ids"),
84            event_index_prefix: format!("{p}trg:event:"),
85            pool,
86        }
87    }
88}
89
90#[async_trait]
91impl TriggerStore for RedisTriggerStore {
92    async fn register_condition(
93        &self,
94        condition: &TriggerCondition,
95    ) -> RustvelloResult<ConditionId> {
96        let cond_id = condition.condition_id();
97        let mut conn = self.pool.conn().await?;
98        let json = serde_json::to_string(condition).map_err(|e| RustvelloError::Serialization {
99            message: e.to_string(),
100        })?;
101        conn.set::<_, _, ()>(format!("{}{}", &self.cond_prefix, cond_id.as_str()), &json)
102            .await
103            .map_err(redis_err)?;
104
105        // Index by task if the condition references one
106        for task_id in condition.source_task_ids() {
107            conn.sadd::<_, _, ()>(
108                format!("{}{}", &self.cond_task_prefix, task_id),
109                cond_id.as_str().to_owned(),
110            )
111            .await
112            .map_err(redis_err)?;
113        }
114
115        // Secondary index by condition type for efficient lookup
116        if matches!(condition, TriggerCondition::Cron(_)) {
117            conn.sadd::<_, _, ()>(&self.cron_index, cond_id.as_str().to_owned())
118                .await
119                .map_err(redis_err)?;
120        }
121        if let TriggerCondition::Event(ev) = condition {
122            conn.sadd::<_, _, ()>(
123                format!("{}{}", &self.event_index_prefix, ev.event_code),
124                cond_id.as_str().to_owned(),
125            )
126            .await
127            .map_err(redis_err)?;
128        }
129
130        Ok(cond_id)
131    }
132
133    async fn get_condition(&self, id: &ConditionId) -> RustvelloResult<Option<TriggerCondition>> {
134        let mut conn = self.pool.conn().await?;
135        let val: Option<String> = conn
136            .get(format!("{}{}", &self.cond_prefix, id.as_str()))
137            .await
138            .map_err(redis_err)?;
139        match val {
140            Some(s) => {
141                let c: TriggerCondition =
142                    serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
143                        message: e.to_string(),
144                    })?;
145                Ok(Some(c))
146            }
147            None => Ok(None),
148        }
149    }
150
151    async fn get_conditions_for_task(
152        &self,
153        task_id: &TaskId,
154    ) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
155        let mut conn = self.pool.conn().await?;
156        let members: Vec<String> = conn
157            .smembers(format!("{}{}", &self.cond_task_prefix, task_id))
158            .await
159            .map_err(redis_err)?;
160
161        batch_get_conditions(&mut conn, &members, &self.cond_prefix).await
162    }
163
164    async fn get_cron_conditions(&self) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
165        let mut conn = self.pool.conn().await?;
166        // Use secondary index instead of full keyspace SCAN
167        let members: Vec<String> = conn.smembers(&self.cron_index).await.map_err(redis_err)?;
168
169        batch_get_conditions(&mut conn, &members, &self.cond_prefix).await
170    }
171
172    async fn get_event_conditions(
173        &self,
174        event_code: &str,
175    ) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
176        let mut conn = self.pool.conn().await?;
177        // Use per-event-code secondary index instead of full keyspace SCAN
178        let members: Vec<String> = conn
179            .smembers(format!("{}{}", &self.event_index_prefix, event_code))
180            .await
181            .map_err(redis_err)?;
182
183        batch_get_conditions(&mut conn, &members, &self.cond_prefix).await
184    }
185
186    async fn register_trigger(&self, trigger: &TriggerDefinitionDTO) -> RustvelloResult<()> {
187        let mut conn = self.pool.conn().await?;
188        let json = serde_json::to_string(trigger).map_err(|e| RustvelloError::Serialization {
189            message: e.to_string(),
190        })?;
191        conn.set::<_, _, ()>(
192            format!("{}{}", &self.trigger_prefix, trigger.trigger_id.as_str()),
193            &json,
194        )
195        .await
196        .map_err(redis_err)?;
197
198        // Index by condition
199        for cid in &trigger.condition_ids {
200            conn.sadd::<_, _, ()>(
201                format!("{}{}", &self.cond_trigger_prefix, cid.as_str()),
202                trigger.trigger_id.as_str().to_owned(),
203            )
204            .await
205            .map_err(redis_err)?;
206        }
207
208        // Index by task
209        conn.sadd::<_, _, ()>(
210            format!("{}{}", &self.trigger_task_prefix, trigger.task_id),
211            trigger.trigger_id.as_str().to_owned(),
212        )
213        .await
214        .map_err(redis_err)?;
215
216        Ok(())
217    }
218
219    async fn get_trigger(
220        &self,
221        id: &TriggerDefinitionId,
222    ) -> RustvelloResult<Option<TriggerDefinitionDTO>> {
223        let mut conn = self.pool.conn().await?;
224        let val: Option<String> = conn
225            .get(format!("{}{}", &self.trigger_prefix, id.as_str()))
226            .await
227            .map_err(redis_err)?;
228        match val {
229            Some(s) => {
230                let t: TriggerDefinitionDTO =
231                    serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
232                        message: e.to_string(),
233                    })?;
234                Ok(Some(t))
235            }
236            None => Ok(None),
237        }
238    }
239
240    async fn get_triggers_for_condition(
241        &self,
242        cond_id: &ConditionId,
243    ) -> RustvelloResult<Vec<TriggerDefinitionDTO>> {
244        let mut conn = self.pool.conn().await?;
245        let members: Vec<String> = conn
246            .smembers(format!("{}{}", &self.cond_trigger_prefix, cond_id.as_str()))
247            .await
248            .map_err(redis_err)?;
249
250        if members.is_empty() {
251            return Ok(Vec::new());
252        }
253
254        let keys: Vec<String> = members
255            .iter()
256            .map(|tid| format!("{}{}", &self.trigger_prefix, tid))
257            .collect();
258        let values: Vec<Option<String>> = redis::cmd("MGET")
259            .arg(&keys)
260            .query_async(&mut conn)
261            .await
262            .map_err(redis_err)?;
263
264        let mut result = Vec::new();
265        for val in values.into_iter().flatten() {
266            let t: TriggerDefinitionDTO =
267                serde_json::from_str(&val).map_err(|e| RustvelloError::Serialization {
268                    message: e.to_string(),
269                })?;
270            result.push(t);
271        }
272        Ok(result)
273    }
274
275    async fn remove_triggers_for_task(&self, task_id: &TaskId) -> RustvelloResult<u32> {
276        let mut conn = self.pool.conn().await?;
277        let members: Vec<String> = conn
278            .smembers(format!("{}{}", &self.trigger_task_prefix, task_id))
279            .await
280            .map_err(redis_err)?;
281
282        let count = u32::try_from(members.len()).unwrap_or(u32::MAX);
283        for tid_str in &members {
284            // Read trigger data to find its condition_ids so we can clean up
285            // the reverse index (&self.cond_trigger_prefix:{cond_id} sets)
286            let val: Option<String> = conn
287                .get(format!("{}{}", &self.trigger_prefix, tid_str))
288                .await
289                .map_err(redis_err)?;
290            if let Some(json) = val {
291                if let Ok(trigger) = serde_json::from_str::<TriggerDefinitionDTO>(&json) {
292                    for cid in &trigger.condition_ids {
293                        conn.srem::<_, _, ()>(
294                            format!("{}{}", &self.cond_trigger_prefix, cid.as_str()),
295                            tid_str.as_str(),
296                        )
297                        .await
298                        .map_err(redis_err)?;
299                    }
300                }
301            }
302            conn.del::<_, ()>(format!("{}{}", &self.trigger_prefix, tid_str))
303                .await
304                .map_err(redis_err)?;
305        }
306        conn.del::<_, ()>(format!("{}{}", &self.trigger_task_prefix, task_id))
307            .await
308            .map_err(redis_err)?;
309        Ok(count)
310    }
311
312    async fn record_valid_condition(&self, vc: &ValidCondition) -> RustvelloResult<()> {
313        let mut conn = self.pool.conn().await?;
314        let json = serde_json::to_string(vc).map_err(|e| RustvelloError::Serialization {
315            message: e.to_string(),
316        })?;
317        let key = format!("{}{}", &self.valid_cond_prefix, vc.valid_condition_id);
318        conn.set::<_, _, ()>(&key, &json).await.map_err(redis_err)
319    }
320
321    async fn get_valid_conditions(&self) -> RustvelloResult<Vec<ValidCondition>> {
322        let mut conn = self.pool.conn().await?;
323        let keys = scan_keys(&mut conn, &format!("{}*", &self.valid_cond_prefix)).await?;
324
325        if keys.is_empty() {
326            return Ok(Vec::new());
327        }
328
329        let values: Vec<Option<String>> = redis::cmd("MGET")
330            .arg(&keys)
331            .query_async(&mut conn)
332            .await
333            .map_err(redis_err)?;
334
335        let mut result = Vec::new();
336        for val in values.into_iter().flatten() {
337            let vc: ValidCondition =
338                serde_json::from_str(&val).map_err(|e| RustvelloError::Serialization {
339                    message: e.to_string(),
340                })?;
341            result.push(vc);
342        }
343        Ok(result)
344    }
345
346    async fn clear_valid_conditions(&self, ids: &[String]) -> RustvelloResult<()> {
347        if ids.is_empty() {
348            return Ok(());
349        }
350        let mut conn = self.pool.conn().await?;
351        let keys: Vec<String> = ids
352            .iter()
353            .map(|id| format!("{}{}", &self.valid_cond_prefix, id))
354            .collect();
355        conn.del::<_, ()>(keys).await.map_err(redis_err)
356    }
357
358    async fn get_last_cron_execution(
359        &self,
360        cond_id: &ConditionId,
361    ) -> RustvelloResult<Option<DateTime<Utc>>> {
362        let mut conn = self.pool.conn().await?;
363        let val: Option<String> = conn
364            .get(format!("{}{}", &self.cron_exec_prefix, cond_id.as_str()))
365            .await
366            .map_err(redis_err)?;
367        match val {
368            Some(s) => {
369                let dt = DateTime::parse_from_rfc3339(&s)
370                    .map(|d| d.with_timezone(&Utc))
371                    .map_err(|e| RustvelloError::Serialization {
372                        message: format!("cron timestamp: {}", e),
373                    })?;
374                Ok(Some(dt))
375            }
376            None => Ok(None),
377        }
378    }
379
380    async fn store_cron_execution(
381        &self,
382        cond_id: &ConditionId,
383        time: DateTime<Utc>,
384        expected_last: Option<DateTime<Utc>>,
385    ) -> RustvelloResult<bool> {
386        let key = format!("{}{}", &self.cron_exec_prefix, cond_id.as_str());
387        let mut conn = self.pool.conn().await?;
388
389        // Atomic compare-and-swap via Lua script
390        let expected_val = match expected_last {
391            Some(dt) => dt.to_rfc3339(),
392            None => String::new(), // empty string sentinel for "no previous value"
393        };
394        let new_val = time.to_rfc3339();
395
396        // Lua CAS: if expected is "" (no previous), check key doesn't exist;
397        // otherwise check current value matches expected. Set new value on match.
398        let script = redis::Script::new(
399            r"
400            local current = redis.call('GET', KEYS[1])
401            local expected = ARGV[1]
402            if expected == '' then
403                if current == false then
404                    redis.call('SET', KEYS[1], ARGV[2])
405                    return 1
406                else
407                    return 0
408                end
409            else
410                if current == expected then
411                    redis.call('SET', KEYS[1], ARGV[2])
412                    return 1
413                else
414                    return 0
415                end
416            end
417            ",
418        );
419        let result: i32 = script
420            .key(&key)
421            .arg(&expected_val)
422            .arg(&new_val)
423            .invoke_async(&mut conn)
424            .await
425            .map_err(redis_err)?;
426        Ok(result == 1)
427    }
428
429    async fn claim_trigger_run(&self, run_id: &TriggerRunId) -> RustvelloResult<bool> {
430        let key = format!("{}{}", &self.run_prefix, run_id.as_str());
431        let mut conn = self.pool.conn().await?;
432        // SETNX — returns true if the key was set (i.e. we claimed it)
433        let set: bool = conn.set_nx(&key, "1").await.map_err(redis_err)?;
434        if set {
435            // Auto-expire after 1 hour to prevent leaks
436            conn.expire::<_, ()>(&key, 3600).await.map_err(redis_err)?;
437        }
438        Ok(set)
439    }
440
441    async fn purge(&self) -> RustvelloResult<()> {
442        let prefixes = [
443            &self.cond_prefix,
444            &self.cond_task_prefix,
445            &self.trigger_prefix,
446            &self.cond_trigger_prefix,
447            &self.valid_cond_prefix,
448            &self.cron_exec_prefix,
449            &self.run_prefix,
450            &self.trigger_task_prefix,
451            &self.event_index_prefix,
452        ];
453        let mut conn = self.pool.conn().await?;
454        for prefix in prefixes {
455            let keys = scan_keys(&mut conn, &format!("{}*", prefix)).await?;
456            if !keys.is_empty() {
457                conn.del::<_, ()>(keys).await.map_err(redis_err)?;
458            }
459        }
460        // Also delete the singleton cron index key
461        conn.del::<_, ()>(&self.cron_index)
462            .await
463            .map_err(redis_err)?;
464        Ok(())
465    }
466
467    async fn get_all_conditions(&self) -> RustvelloResult<Vec<(ConditionId, TriggerCondition)>> {
468        let mut conn = self.pool.conn().await?;
469        let keys = scan_keys(&mut conn, &format!("{}*", &self.cond_prefix)).await?;
470        let ids: Vec<String> = keys
471            .iter()
472            .filter_map(|k| k.strip_prefix(&self.cond_prefix).map(String::from))
473            .collect();
474        batch_get_conditions(&mut conn, &ids, &self.cond_prefix).await
475    }
476}