//! Redis-backed state backend for rustvello (`rustvello_redis/state_backend.rs`).

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use redis::AsyncCommands;
5
6use rustvello_core::error::{RustvelloError, RustvelloResult};
7use rustvello_core::state_backend::{
8    StateBackendCore, StateBackendQuery, StateBackendRunner, StoredRunnerContext,
9};
10use rustvello_proto::call::CallDTO;
11use rustvello_proto::identifiers::{CallId, InvocationId, TaskId};
12use rustvello_proto::invocation::{InvocationDTO, InvocationHistory, WorkflowIdentity};
13
14use crate::connection::{redis_err, scan_keys, RedisPool};
15use rustvello_core::error::TaskError;
16
/// Builds a Redis key by appending `suffix` directly to `prefix`.
///
/// No separator is inserted; callers pass prefixes that already end with the
/// delimiter they want (e.g. `"…state:inv:"`).
fn prefixed_key(prefix: &str, suffix: &str) -> String {
    [prefix, suffix].concat()
}
23
24/// Redis-backed state backend.
25///
26/// Stores invocations, calls, results, errors, and history as JSON strings
27/// in Redis keys with structured prefixes.
28#[non_exhaustive]
29pub struct RedisStateBackend {
30    pool: Arc<RedisPool>,
31    inv_prefix: String,
32    call_prefix: String,
33    result_prefix: String,
34    error_prefix: String,
35    history_prefix: String,
36    wf_prefix: String,
37    child_prefix: String,
38    wf_types_key: String,
39    wf_runs_prefix: String,
40    wf_data_prefix: String,
41    app_infos_key: String,
42    wf_sub_prefix: String,
43    runner_prefix: String,
44    runner_inv_prefix: String,
45    history_ts_key: String,
46}
47
48impl RedisStateBackend {
49    pub fn new(pool: Arc<RedisPool>) -> Self {
50        let p = pool.prefix();
51        Self {
52            inv_prefix: format!("{p}state:inv:"),
53            call_prefix: format!("{p}state:call:"),
54            result_prefix: format!("{p}state:result:"),
55            error_prefix: format!("{p}state:error:"),
56            history_prefix: format!("{p}state:history:"),
57            wf_prefix: format!("{p}state:wf:"),
58            child_prefix: format!("{p}state:child:"),
59            wf_types_key: format!("{p}state:wf_types"),
60            wf_runs_prefix: format!("{p}state:wf_runs:"),
61            wf_data_prefix: format!("{p}state:wf_data:"),
62            app_infos_key: format!("{p}state:app_infos"),
63            wf_sub_prefix: format!("{p}state:wf_sub:"),
64            runner_prefix: format!("{p}state:runner:"),
65            runner_inv_prefix: format!("{p}state:runner_inv:"),
66            history_ts_key: format!("{p}state:history_ts"),
67            pool,
68        }
69    }
70}
71
72#[async_trait]
73impl StateBackendCore for RedisStateBackend {
74    async fn upsert_invocation(
75        &self,
76        invocation: &InvocationDTO,
77        call: &CallDTO,
78    ) -> RustvelloResult<()> {
79        let mut conn = self.pool.conn().await?;
80        let inv_json =
81            serde_json::to_string(invocation).map_err(|e| RustvelloError::Serialization {
82                message: e.to_string(),
83            })?;
84        let call_json = serde_json::to_string(call).map_err(|e| RustvelloError::Serialization {
85            message: e.to_string(),
86        })?;
87
88        conn.set::<_, _, ()>(
89            prefixed_key(&self.inv_prefix, invocation.invocation_id.as_ref()),
90            &inv_json,
91        )
92        .await
93        .map_err(redis_err)?;
94
95        conn.set::<_, _, ()>(
96            prefixed_key(&self.call_prefix, &call.call_id.to_string()),
97            &call_json,
98        )
99        .await
100        .map_err(redis_err)?;
101
102        // Index by workflow if present
103        if let Some(wf) = &invocation.workflow {
104            conn.sadd::<_, _, ()>(
105                prefixed_key(&self.wf_prefix, wf.workflow_id.as_ref()),
106                invocation.invocation_id.as_str(),
107            )
108            .await
109            .map_err(redis_err)?;
110        }
111
112        // Index by parent
113        if let Some(parent_id) = &invocation.parent_invocation_id {
114            conn.sadd::<_, _, ()>(
115                prefixed_key(&self.child_prefix, parent_id.as_ref()),
116                invocation.invocation_id.as_str(),
117            )
118            .await
119            .map_err(redis_err)?;
120        }
121
122        Ok(())
123    }
124
125    async fn get_invocation(&self, invocation_id: &InvocationId) -> RustvelloResult<InvocationDTO> {
126        let mut conn = self.pool.conn().await?;
127        let val: Option<String> = conn
128            .get(prefixed_key(&self.inv_prefix, invocation_id.as_ref()))
129            .await
130            .map_err(redis_err)?;
131        match val {
132            Some(s) => serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
133                message: e.to_string(),
134            }),
135            None => Err(RustvelloError::InvocationNotFound {
136                invocation_id: invocation_id.clone(),
137            }),
138        }
139    }
140
141    async fn get_call(&self, call_id: &CallId) -> RustvelloResult<CallDTO> {
142        let mut conn = self.pool.conn().await?;
143        let val: Option<String> = conn
144            .get(prefixed_key(&self.call_prefix, &call_id.to_string()))
145            .await
146            .map_err(redis_err)?;
147        match val {
148            Some(s) => serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
149                message: e.to_string(),
150            }),
151            None => Err(RustvelloError::state_backend(format!(
152                "call not found: {}",
153                call_id
154            ))),
155        }
156    }
157
158    async fn store_result(
159        &self,
160        invocation_id: &InvocationId,
161        result: &str,
162    ) -> RustvelloResult<()> {
163        let mut conn = self.pool.conn().await?;
164        conn.set::<_, _, ()>(
165            prefixed_key(&self.result_prefix, invocation_id.as_ref()),
166            result,
167        )
168        .await
169        .map_err(redis_err)
170    }
171
172    async fn get_result(&self, invocation_id: &InvocationId) -> RustvelloResult<Option<String>> {
173        let mut conn = self.pool.conn().await?;
174        conn.get(prefixed_key(&self.result_prefix, invocation_id.as_ref()))
175            .await
176            .map_err(redis_err)
177    }
178
179    async fn store_error(
180        &self,
181        invocation_id: &InvocationId,
182        error: &TaskError,
183    ) -> RustvelloResult<()> {
184        let mut conn = self.pool.conn().await?;
185        let json = serde_json::to_string(error).map_err(|e| RustvelloError::Serialization {
186            message: e.to_string(),
187        })?;
188        conn.set::<_, _, ()>(
189            prefixed_key(&self.error_prefix, invocation_id.as_ref()),
190            &json,
191        )
192        .await
193        .map_err(redis_err)
194    }
195
196    async fn get_error(&self, invocation_id: &InvocationId) -> RustvelloResult<Option<TaskError>> {
197        let mut conn = self.pool.conn().await?;
198        let val: Option<String> = conn
199            .get(prefixed_key(&self.error_prefix, invocation_id.as_ref()))
200            .await
201            .map_err(redis_err)?;
202        match val {
203            Some(s) => {
204                let err: TaskError =
205                    serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
206                        message: e.to_string(),
207                    })?;
208                Ok(Some(err))
209            }
210            None => Ok(None),
211        }
212    }
213
214    async fn add_history(&self, history: &InvocationHistory) -> RustvelloResult<()> {
215        let mut conn = self.pool.conn().await?;
216        let json = serde_json::to_string(history).map_err(|e| RustvelloError::Serialization {
217            message: e.to_string(),
218        })?;
219        // Append to invocation-scoped list
220        conn.rpush::<_, _, ()>(
221            prefixed_key(&self.history_prefix, history.invocation_id.as_ref()),
222            &json,
223        )
224        .await
225        .map_err(redis_err)?;
226        // Index by timestamp for time-range queries
227        let ts = history
228            .history_timestamp
229            .unwrap_or(history.status_record.timestamp);
230        conn.zadd::<_, _, _, ()>(&self.history_ts_key, &json, ts.timestamp_millis() as f64)
231            .await
232            .map_err(redis_err)?;
233        // Maintain runner → invocation reverse index
234        let rid = history
235            .runner_id
236            .as_ref()
237            .or(history.status_record.runner_id.as_ref());
238        if let Some(r) = rid {
239            conn.sadd::<_, _, ()>(
240                prefixed_key(&self.runner_inv_prefix, r.as_str()),
241                history.invocation_id.as_str(),
242            )
243            .await
244            .map_err(redis_err)?;
245        }
246        Ok(())
247    }
248
249    async fn get_history(
250        &self,
251        invocation_id: &InvocationId,
252    ) -> RustvelloResult<Vec<InvocationHistory>> {
253        let mut conn = self.pool.conn().await?;
254        let vals: Vec<String> = conn
255            .lrange(
256                prefixed_key(&self.history_prefix, invocation_id.as_ref()),
257                0,
258                -1,
259            )
260            .await
261            .map_err(redis_err)?;
262        vals.into_iter()
263            .map(|s| {
264                serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
265                    message: e.to_string(),
266                })
267            })
268            .collect()
269    }
270
271    async fn purge(&self) -> RustvelloResult<()> {
272        let prefixes = [
273            &self.inv_prefix,
274            &self.call_prefix,
275            &self.result_prefix,
276            &self.error_prefix,
277            &self.history_prefix,
278            &self.wf_prefix,
279            &self.child_prefix,
280            &self.wf_runs_prefix,
281            &self.wf_data_prefix,
282            &self.wf_sub_prefix,
283            &self.runner_prefix,
284            &self.runner_inv_prefix,
285        ];
286        let mut conn = self.pool.conn().await?;
287        for prefix in prefixes {
288            let keys = scan_keys(&mut conn, &format!("{}*", prefix)).await?;
289            if !keys.is_empty() {
290                conn.del::<_, ()>(keys).await.map_err(redis_err)?;
291            }
292        }
293        for key in [
294            &self.wf_types_key,
295            &self.app_infos_key,
296            &self.history_ts_key,
297        ] {
298            conn.del::<_, ()>(key).await.map_err(redis_err)?;
299        }
300        Ok(())
301    }
302}
303
304#[async_trait]
305impl StateBackendQuery for RedisStateBackend {
306    async fn get_workflow_invocations(
307        &self,
308        workflow_id: &InvocationId,
309    ) -> RustvelloResult<Vec<InvocationId>> {
310        let mut conn = self.pool.conn().await?;
311        let members: Vec<String> = conn
312            .smembers(prefixed_key(&self.wf_prefix, workflow_id.as_ref()))
313            .await
314            .map_err(redis_err)?;
315        Ok(members.into_iter().map(InvocationId::from_string).collect())
316    }
317
318    async fn get_child_invocations(
319        &self,
320        parent_invocation_id: &InvocationId,
321    ) -> RustvelloResult<Vec<InvocationId>> {
322        let mut conn = self.pool.conn().await?;
323        let members: Vec<String> = conn
324            .smembers(format!("{}{}", &self.child_prefix, parent_invocation_id))
325            .await
326            .map_err(redis_err)?;
327        Ok(members.into_iter().map(InvocationId::from_string).collect())
328    }
329
330    async fn store_workflow_run(&self, workflow: &WorkflowIdentity) -> RustvelloResult<()> {
331        let mut conn = self.pool.conn().await?;
332        let type_key = workflow.workflow_type.to_string();
333        // Track the workflow type
334        conn.sadd::<_, _, ()>(&self.wf_types_key, &type_key)
335            .await
336            .map_err(redis_err)?;
337        // Store the run under its type (keyed by workflow_id to avoid dupes)
338        let json = serde_json::to_string(workflow).map_err(|e| RustvelloError::Serialization {
339            message: e.to_string(),
340        })?;
341        conn.hset::<_, _, _, ()>(
342            prefixed_key(&self.wf_runs_prefix, &type_key),
343            workflow.workflow_id.as_str(),
344            &json,
345        )
346        .await
347        .map_err(redis_err)?;
348        Ok(())
349    }
350
351    async fn get_all_workflow_types(&self) -> RustvelloResult<Vec<TaskId>> {
352        let mut conn = self.pool.conn().await?;
353        let members: Vec<String> = conn.smembers(&self.wf_types_key).await.map_err(redis_err)?;
354        members
355            .into_iter()
356            .map(|s| {
357                s.parse::<TaskId>()
358                    .map_err(|e| RustvelloError::state_backend(format!("invalid task_id: {e}")))
359            })
360            .collect()
361    }
362
363    async fn get_workflow_runs(
364        &self,
365        workflow_type: &TaskId,
366    ) -> RustvelloResult<Vec<WorkflowIdentity>> {
367        let mut conn = self.pool.conn().await?;
368        let vals: Vec<String> = conn
369            .hvals(prefixed_key(
370                &self.wf_runs_prefix,
371                &workflow_type.to_string(),
372            ))
373            .await
374            .map_err(redis_err)?;
375        vals.into_iter()
376            .map(|s| {
377                serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
378                    message: e.to_string(),
379                })
380            })
381            .collect()
382    }
383
384    async fn set_workflow_data(
385        &self,
386        workflow_id: &InvocationId,
387        key: &str,
388        value: &str,
389    ) -> RustvelloResult<()> {
390        let mut conn = self.pool.conn().await?;
391        conn.hset::<_, _, _, ()>(
392            prefixed_key(&self.wf_data_prefix, workflow_id.as_ref()),
393            key,
394            value,
395        )
396        .await
397        .map_err(redis_err)?;
398        Ok(())
399    }
400
401    async fn get_workflow_data(
402        &self,
403        workflow_id: &InvocationId,
404        key: &str,
405    ) -> RustvelloResult<Option<String>> {
406        let mut conn = self.pool.conn().await?;
407        conn.hget(
408            prefixed_key(&self.wf_data_prefix, workflow_id.as_ref()),
409            key,
410        )
411        .await
412        .map_err(redis_err)
413    }
414
415    async fn store_app_info(&self, app_id: &str, info_json: &str) -> RustvelloResult<()> {
416        let mut conn = self.pool.conn().await?;
417        conn.hset::<_, _, _, ()>(&self.app_infos_key, app_id, info_json)
418            .await
419            .map_err(redis_err)?;
420        Ok(())
421    }
422
423    async fn get_app_info(&self, app_id: &str) -> RustvelloResult<Option<String>> {
424        let mut conn = self.pool.conn().await?;
425        conn.hget(&self.app_infos_key, app_id)
426            .await
427            .map_err(redis_err)
428    }
429
430    async fn get_all_app_infos(&self) -> RustvelloResult<Vec<(String, String)>> {
431        let mut conn = self.pool.conn().await?;
432        let map: Vec<(String, String)> =
433            conn.hgetall(&self.app_infos_key).await.map_err(redis_err)?;
434        Ok(map)
435    }
436
437    async fn store_workflow_sub_invocation(
438        &self,
439        workflow_id: &InvocationId,
440        sub_inv_id: &InvocationId,
441    ) -> RustvelloResult<()> {
442        let mut conn = self.pool.conn().await?;
443        conn.sadd::<_, _, ()>(
444            prefixed_key(&self.wf_sub_prefix, workflow_id.as_ref()),
445            sub_inv_id.as_str(),
446        )
447        .await
448        .map_err(redis_err)?;
449        Ok(())
450    }
451
452    async fn get_workflow_sub_invocations(
453        &self,
454        workflow_id: &InvocationId,
455    ) -> RustvelloResult<Vec<InvocationId>> {
456        let mut conn = self.pool.conn().await?;
457        let members: Vec<String> = conn
458            .smembers(prefixed_key(&self.wf_sub_prefix, workflow_id.as_ref()))
459            .await
460            .map_err(redis_err)?;
461        Ok(members.into_iter().map(InvocationId::from_string).collect())
462    }
463
464    async fn get_all_workflow_runs(&self) -> RustvelloResult<Vec<WorkflowIdentity>> {
465        let mut conn = self.pool.conn().await?;
466        let types: Vec<String> = conn.smembers(&self.wf_types_key).await.map_err(redis_err)?;
467        let mut all = Vec::new();
468        for t in &types {
469            let vals: Vec<String> = conn
470                .hvals(prefixed_key(&self.wf_runs_prefix, t))
471                .await
472                .map_err(redis_err)?;
473            for s in vals {
474                let wf: WorkflowIdentity =
475                    serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
476                        message: e.to_string(),
477                    })?;
478                all.push(wf);
479            }
480        }
481        Ok(all)
482    }
483}
484
485#[async_trait]
486impl StateBackendRunner for RedisStateBackend {
487    async fn store_runner_context(&self, context: &StoredRunnerContext) -> RustvelloResult<()> {
488        let mut conn = self.pool.conn().await?;
489        let json = serde_json::to_string(context).map_err(|e| RustvelloError::Serialization {
490            message: e.to_string(),
491        })?;
492        conn.set::<_, _, ()>(prefixed_key(&self.runner_prefix, &context.runner_id), &json)
493            .await
494            .map_err(redis_err)?;
495        Ok(())
496    }
497
498    async fn get_runner_context(
499        &self,
500        runner_id: &str,
501    ) -> RustvelloResult<Option<StoredRunnerContext>> {
502        let mut conn = self.pool.conn().await?;
503        let val: Option<String> = conn
504            .get(prefixed_key(&self.runner_prefix, runner_id))
505            .await
506            .map_err(redis_err)?;
507        match val {
508            Some(s) => {
509                let ctx: StoredRunnerContext =
510                    serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
511                        message: e.to_string(),
512                    })?;
513                Ok(Some(ctx))
514            }
515            None => Ok(None),
516        }
517    }
518
519    async fn get_runner_contexts_by_parent(
520        &self,
521        parent_runner_id: &str,
522    ) -> RustvelloResult<Vec<StoredRunnerContext>> {
523        let mut conn = self.pool.conn().await?;
524        let keys = scan_keys(&mut conn, &format!("{}*", &self.runner_prefix)).await?;
525        let mut result = Vec::new();
526        for key in keys {
527            let val: Option<String> = conn.get(&key).await.map_err(redis_err)?;
528            if let Some(s) = val {
529                let ctx: StoredRunnerContext =
530                    serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
531                        message: e.to_string(),
532                    })?;
533                if ctx.parent_runner_id.as_deref() == Some(parent_runner_id) {
534                    result.push(ctx);
535                }
536            }
537        }
538        Ok(result)
539    }
540
541    async fn get_invocation_ids_by_runner(
542        &self,
543        runner_id: &str,
544        limit: usize,
545        offset: usize,
546    ) -> RustvelloResult<Vec<InvocationId>> {
547        let mut conn = self.pool.conn().await?;
548        let members: Vec<String> = conn
549            .smembers(prefixed_key(&self.runner_inv_prefix, runner_id))
550            .await
551            .map_err(redis_err)?;
552        let iter = members.into_iter().skip(offset);
553        let ids: Vec<InvocationId> = if limit > 0 {
554            iter.take(limit).map(InvocationId::from_string).collect()
555        } else {
556            iter.map(InvocationId::from_string).collect()
557        };
558        Ok(ids)
559    }
560
561    async fn count_invocations_by_runner(&self, runner_id: &str) -> RustvelloResult<usize> {
562        let mut conn = self.pool.conn().await?;
563        let count: usize = conn
564            .scard(prefixed_key(&self.runner_inv_prefix, runner_id))
565            .await
566            .map_err(redis_err)?;
567        Ok(count)
568    }
569
570    async fn get_history_in_timerange(
571        &self,
572        start: chrono::DateTime<chrono::Utc>,
573        end: chrono::DateTime<chrono::Utc>,
574        limit: usize,
575        offset: usize,
576    ) -> RustvelloResult<Vec<InvocationHistory>> {
577        let mut conn = self.pool.conn().await?;
578        let min = start.timestamp_millis() as f64;
579        let max = end.timestamp_millis() as f64;
580        // Fetch all in range, then paginate
581        let vals: Vec<String> = redis::cmd("ZRANGEBYSCORE")
582            .arg(&self.history_ts_key)
583            .arg(min)
584            .arg(max)
585            .query_async(&mut conn)
586            .await
587            .map_err(redis_err)?;
588        let iter = vals.into_iter().skip(offset);
589        let selected: Vec<String> = if limit > 0 {
590            iter.take(limit).collect()
591        } else {
592            iter.collect()
593        };
594        selected
595            .into_iter()
596            .map(|s| {
597                serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
598                    message: e.to_string(),
599                })
600            })
601            .collect()
602    }
603
604    async fn get_matching_runner_contexts(
605        &self,
606        partial_id: &str,
607    ) -> RustvelloResult<Vec<StoredRunnerContext>> {
608        let mut conn = self.pool.conn().await?;
609        let pattern = format!("{}*{}*", &self.runner_prefix, partial_id);
610        let keys = scan_keys(&mut conn, &pattern).await?;
611        let mut result = Vec::new();
612        for key in keys {
613            let val: Option<String> = conn.get(&key).await.map_err(redis_err)?;
614            if let Some(s) = val {
615                let ctx: StoredRunnerContext =
616                    serde_json::from_str(&s).map_err(|e| RustvelloError::Serialization {
617                        message: e.to_string(),
618                    })?;
619                result.push(ctx);
620            }
621        }
622        Ok(result)
623    }
624}