Skip to main content

rustvello_mem/
state_backend.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::sync::Mutex;
4
5use async_trait::async_trait;
6use tracing::instrument;
7
8use rustvello_core::error::{RustvelloError, RustvelloResult, TaskError};
9use rustvello_core::state_backend::{
10    StateBackendCore, StateBackendQuery, StateBackendRunner, StoredRunnerContext,
11};
12use rustvello_proto::call::CallDTO;
13use rustvello_proto::identifiers::{CallId, InvocationId, TaskId};
14use rustvello_proto::invocation::{InvocationDTO, InvocationHistory, WorkflowIdentity};
15
16struct BackendState {
17    invocations: HashMap<Arc<str>, InvocationDTO>,
18    calls: HashMap<String, CallDTO>,
19    results: HashMap<Arc<str>, String>,
20    errors: HashMap<Arc<str>, TaskError>,
21    histories: HashMap<Arc<str>, Vec<InvocationHistory>>,
22    /// workflow_id → member invocation IDs
23    workflow_members: HashMap<Arc<str>, Vec<InvocationId>>,
24    /// parent_invocation_id → child invocation IDs
25    children: HashMap<Arc<str>, Vec<InvocationId>>,
26    /// runner_id → StoredRunnerContext
27    runner_contexts: HashMap<String, StoredRunnerContext>,
28    /// runner_id → invocation IDs processed by that runner
29    runner_invocations: HashMap<String, Vec<InvocationId>>,
30    /// workflow_type (TaskId key) → set of WorkflowIdentity runs
31    workflow_types: Vec<TaskId>,
32    workflow_runs: HashMap<String, Vec<WorkflowIdentity>>,
33    /// workflow_id → { key → value }
34    workflow_data: HashMap<Arc<str>, HashMap<String, String>>,
35    /// app_id → info_json
36    app_infos: HashMap<String, String>,
37    /// workflow_id → sub-invocation IDs
38    workflow_sub_invocations: HashMap<Arc<str>, Vec<InvocationId>>,
39}
40
41/// In-memory state backend.
42///
43/// Stores invocations, calls, results, and history in process memory.
44/// Suitable for testing and development only.
45pub struct MemStateBackend {
46    state: Mutex<BackendState>,
47}
48
49impl MemStateBackend {
50    pub fn new() -> Self {
51        Self {
52            state: Mutex::new(BackendState {
53                invocations: HashMap::with_capacity(64),
54                calls: HashMap::with_capacity(64),
55                results: HashMap::with_capacity(32),
56                errors: HashMap::new(),
57                histories: HashMap::with_capacity(64),
58                workflow_members: HashMap::new(),
59                children: HashMap::new(),
60                runner_contexts: HashMap::new(),
61                runner_invocations: HashMap::new(),
62                workflow_types: Vec::new(),
63                workflow_runs: HashMap::new(),
64                workflow_data: HashMap::new(),
65                app_infos: HashMap::new(),
66                workflow_sub_invocations: HashMap::new(),
67            }),
68        }
69    }
70}
71
72impl Default for MemStateBackend {
73    fn default() -> Self {
74        Self::new()
75    }
76}
77
78#[async_trait]
79impl StateBackendCore for MemStateBackend {
80    #[instrument(skip(self, invocation, call), fields(%invocation.invocation_id))]
81    async fn upsert_invocation(
82        &self,
83        invocation: &InvocationDTO,
84        call: &CallDTO,
85    ) -> RustvelloResult<()> {
86        let mut state = self.state.lock().await;
87
88        let is_new = !state
89            .invocations
90            .contains_key(invocation.invocation_id.as_str());
91
92        // Only index on first insert to avoid duplicate entries on upsert
93        if is_new {
94            // Index workflow membership
95            if let Some(ref wf) = invocation.workflow {
96                state
97                    .workflow_members
98                    .entry(Arc::from(wf.workflow_id.as_str()))
99                    .or_default()
100                    .push(invocation.invocation_id.clone());
101            }
102
103            // Index parent-child relationship
104            if let Some(ref parent_id) = invocation.parent_invocation_id {
105                state
106                    .children
107                    .entry(Arc::from(parent_id.as_str()))
108                    .or_default()
109                    .push(invocation.invocation_id.clone());
110            }
111        }
112
113        state.invocations.insert(
114            Arc::from(invocation.invocation_id.as_str()),
115            invocation.clone(),
116        );
117        state.calls.insert(call.call_id.to_string(), call.clone());
118        Ok(())
119    }
120
121    #[instrument(skip(self), fields(%invocation_id))]
122    async fn get_invocation(&self, invocation_id: &InvocationId) -> RustvelloResult<InvocationDTO> {
123        let state = self.state.lock().await;
124        state
125            .invocations
126            .get(invocation_id.as_str())
127            .cloned()
128            .ok_or_else(|| RustvelloError::InvocationNotFound {
129                invocation_id: invocation_id.clone(),
130            })
131    }
132
133    async fn get_call(&self, call_id: &CallId) -> RustvelloResult<CallDTO> {
134        let state = self.state.lock().await;
135        state
136            .calls
137            .get(&call_id.to_string())
138            .cloned()
139            .ok_or_else(|| RustvelloError::state_backend(format!("call not found: {}", call_id)))
140    }
141
142    #[instrument(skip(self, result), fields(%invocation_id))]
143    async fn store_result(
144        &self,
145        invocation_id: &InvocationId,
146        result: &str,
147    ) -> RustvelloResult<()> {
148        let mut state = self.state.lock().await;
149        state
150            .results
151            .insert(Arc::from(invocation_id.as_str()), result.to_string());
152        Ok(())
153    }
154
155    #[instrument(skip(self), fields(%invocation_id))]
156    async fn get_result(&self, invocation_id: &InvocationId) -> RustvelloResult<Option<String>> {
157        let state = self.state.lock().await;
158        Ok(state.results.get(invocation_id.as_str()).cloned())
159    }
160
161    async fn store_error(
162        &self,
163        invocation_id: &InvocationId,
164        error: &TaskError,
165    ) -> RustvelloResult<()> {
166        let mut state = self.state.lock().await;
167        state
168            .errors
169            .insert(Arc::from(invocation_id.as_str()), error.clone());
170        Ok(())
171    }
172
173    async fn get_error(&self, invocation_id: &InvocationId) -> RustvelloResult<Option<TaskError>> {
174        let state = self.state.lock().await;
175        Ok(state.errors.get(invocation_id.as_str()).cloned())
176    }
177
178    async fn add_history(&self, history: &InvocationHistory) -> RustvelloResult<()> {
179        let mut state = self.state.lock().await;
180        // Update runner → invocation reverse index
181        let rid = history
182            .runner_id
183            .as_ref()
184            .or(history.status_record.runner_id.as_ref());
185        if let Some(r) = rid {
186            let inv_id = history.invocation_id.clone();
187            let entries = state.runner_invocations.entry(r.to_string()).or_default();
188            if !entries.iter().any(|e| e == &inv_id) {
189                entries.push(inv_id);
190            }
191        }
192        state
193            .histories
194            .entry(Arc::from(history.invocation_id.as_str()))
195            .or_default()
196            .push(history.clone());
197        Ok(())
198    }
199
200    async fn get_history(
201        &self,
202        invocation_id: &InvocationId,
203    ) -> RustvelloResult<Vec<InvocationHistory>> {
204        let state = self.state.lock().await;
205        Ok(state
206            .histories
207            .get(invocation_id.as_str())
208            .cloned()
209            .unwrap_or_default())
210    }
211
212    async fn purge(&self) -> RustvelloResult<()> {
213        let mut state = self.state.lock().await;
214        state.invocations.clear();
215        state.calls.clear();
216        state.results.clear();
217        state.errors.clear();
218        state.histories.clear();
219        state.workflow_members.clear();
220        state.children.clear();
221        state.runner_contexts.clear();
222        state.runner_invocations.clear();
223        state.workflow_types.clear();
224        state.workflow_runs.clear();
225        state.workflow_data.clear();
226        state.app_infos.clear();
227        state.workflow_sub_invocations.clear();
228        Ok(())
229    }
230
231    fn backend_name(&self) -> &'static str {
232        "In-Memory"
233    }
234
235    async fn usage_stats(&self) -> Vec<(&'static str, String)> {
236        let state = self.state.lock().await;
237        let history_entries: usize = state.histories.values().map(std::vec::Vec::len).sum();
238        let mut oldest: Option<chrono::DateTime<chrono::Utc>> = None;
239        let mut newest: Option<chrono::DateTime<chrono::Utc>> = None;
240        for entries in state.histories.values() {
241            for h in entries {
242                let ts = h.status_record.timestamp;
243                oldest = Some(oldest.map_or(ts, |o| o.min(ts)));
244                newest = Some(newest.map_or(ts, |n| n.max(ts)));
245            }
246        }
247        let mut stats = vec![
248            ("Invocations", state.invocations.len().to_string()),
249            ("Calls", state.calls.len().to_string()),
250            ("Results", state.results.len().to_string()),
251            ("Errors", state.errors.len().to_string()),
252            ("History Entries", history_entries.to_string()),
253            ("Workflows", state.workflow_members.len().to_string()),
254            ("Runner Contexts", state.runner_contexts.len().to_string()),
255        ];
256        if let Some(dt) = oldest {
257            stats.push(("Oldest Record", dt.format("%Y-%m-%d %H:%M:%S").to_string()));
258        }
259        if let Some(dt) = newest {
260            stats.push(("Newest Record", dt.format("%Y-%m-%d %H:%M:%S").to_string()));
261        }
262        stats
263    }
264}
265
266#[async_trait]
267impl StateBackendQuery for MemStateBackend {
268    async fn get_workflow_invocations(
269        &self,
270        workflow_id: &InvocationId,
271    ) -> RustvelloResult<Vec<InvocationId>> {
272        let state = self.state.lock().await;
273        Ok(state
274            .workflow_members
275            .get(workflow_id.as_str())
276            .cloned()
277            .unwrap_or_default())
278    }
279
280    async fn get_child_invocations(
281        &self,
282        parent_invocation_id: &InvocationId,
283    ) -> RustvelloResult<Vec<InvocationId>> {
284        let state = self.state.lock().await;
285        Ok(state
286            .children
287            .get(parent_invocation_id.as_str())
288            .cloned()
289            .unwrap_or_default())
290    }
291
292    async fn store_workflow_run(&self, workflow: &WorkflowIdentity) -> RustvelloResult<()> {
293        let mut state = self.state.lock().await;
294        let type_key = workflow.workflow_type.to_string();
295        if !state
296            .workflow_types
297            .iter()
298            .any(|t| t.to_string() == type_key)
299        {
300            state.workflow_types.push(workflow.workflow_type.clone());
301        }
302        let runs = state.workflow_runs.entry(type_key).or_default();
303        if !runs.iter().any(|r| r.workflow_id == workflow.workflow_id) {
304            runs.push(workflow.clone());
305        }
306        Ok(())
307    }
308
309    async fn get_all_workflow_types(&self) -> RustvelloResult<Vec<TaskId>> {
310        let state = self.state.lock().await;
311        Ok(state.workflow_types.clone())
312    }
313
314    async fn get_workflow_runs(
315        &self,
316        workflow_type: &TaskId,
317    ) -> RustvelloResult<Vec<WorkflowIdentity>> {
318        let state = self.state.lock().await;
319        Ok(state
320            .workflow_runs
321            .get(&workflow_type.to_string())
322            .cloned()
323            .unwrap_or_default())
324    }
325
326    async fn set_workflow_data(
327        &self,
328        workflow_id: &InvocationId,
329        key: &str,
330        value: &str,
331    ) -> RustvelloResult<()> {
332        let mut state = self.state.lock().await;
333        state
334            .workflow_data
335            .entry(Arc::from(workflow_id.as_str()))
336            .or_default()
337            .insert(key.to_string(), value.to_string());
338        Ok(())
339    }
340
341    async fn get_workflow_data(
342        &self,
343        workflow_id: &InvocationId,
344        key: &str,
345    ) -> RustvelloResult<Option<String>> {
346        let state = self.state.lock().await;
347        Ok(state
348            .workflow_data
349            .get(workflow_id.as_str())
350            .and_then(|m| m.get(key).cloned()))
351    }
352
353    async fn store_app_info(&self, app_id: &str, info_json: &str) -> RustvelloResult<()> {
354        let mut state = self.state.lock().await;
355        state
356            .app_infos
357            .insert(app_id.to_string(), info_json.to_string());
358        Ok(())
359    }
360
361    async fn get_app_info(&self, app_id: &str) -> RustvelloResult<Option<String>> {
362        let state = self.state.lock().await;
363        Ok(state.app_infos.get(app_id).cloned())
364    }
365
366    async fn get_all_app_infos(&self) -> RustvelloResult<Vec<(String, String)>> {
367        let state = self.state.lock().await;
368        Ok(state
369            .app_infos
370            .iter()
371            .map(|(k, v)| (k.clone(), v.clone()))
372            .collect())
373    }
374
375    async fn store_workflow_sub_invocation(
376        &self,
377        workflow_id: &InvocationId,
378        sub_inv_id: &InvocationId,
379    ) -> RustvelloResult<()> {
380        let mut state = self.state.lock().await;
381        state
382            .workflow_sub_invocations
383            .entry(Arc::from(workflow_id.as_str()))
384            .or_default()
385            .push(sub_inv_id.clone());
386        Ok(())
387    }
388
389    async fn get_workflow_sub_invocations(
390        &self,
391        workflow_id: &InvocationId,
392    ) -> RustvelloResult<Vec<InvocationId>> {
393        let state = self.state.lock().await;
394        Ok(state
395            .workflow_sub_invocations
396            .get(workflow_id.as_str())
397            .cloned()
398            .unwrap_or_default())
399    }
400
401    async fn get_all_workflow_runs(&self) -> RustvelloResult<Vec<WorkflowIdentity>> {
402        let state = self.state.lock().await;
403        Ok(state.workflow_runs.values().flatten().cloned().collect())
404    }
405}
406
407#[async_trait]
408impl StateBackendRunner for MemStateBackend {
409    async fn store_runner_context(&self, context: &StoredRunnerContext) -> RustvelloResult<()> {
410        let mut state = self.state.lock().await;
411        state
412            .runner_contexts
413            .insert(context.runner_id.clone(), context.clone());
414        Ok(())
415    }
416
417    async fn get_runner_context(
418        &self,
419        runner_id: &str,
420    ) -> RustvelloResult<Option<StoredRunnerContext>> {
421        let state = self.state.lock().await;
422        Ok(state.runner_contexts.get(runner_id).cloned())
423    }
424
425    async fn get_runner_contexts_by_parent(
426        &self,
427        parent_runner_id: &str,
428    ) -> RustvelloResult<Vec<StoredRunnerContext>> {
429        let state = self.state.lock().await;
430        Ok(state
431            .runner_contexts
432            .values()
433            .filter(|ctx| ctx.parent_runner_id.as_deref() == Some(parent_runner_id))
434            .cloned()
435            .collect())
436    }
437
438    async fn get_invocation_ids_by_runner(
439        &self,
440        runner_id: &str,
441        limit: usize,
442        offset: usize,
443    ) -> RustvelloResult<Vec<InvocationId>> {
444        let state = self.state.lock().await;
445        let ids = state
446            .runner_invocations
447            .get(runner_id)
448            .map(|v| {
449                let iter = v.iter().skip(offset);
450                if limit > 0 {
451                    iter.take(limit).cloned().collect()
452                } else {
453                    iter.cloned().collect()
454                }
455            })
456            .unwrap_or_default();
457        Ok(ids)
458    }
459
460    async fn count_invocations_by_runner(&self, runner_id: &str) -> RustvelloResult<usize> {
461        let state = self.state.lock().await;
462        Ok(state
463            .runner_invocations
464            .get(runner_id)
465            .map_or(0, std::vec::Vec::len))
466    }
467
468    async fn get_history_in_timerange(
469        &self,
470        start: chrono::DateTime<chrono::Utc>,
471        end: chrono::DateTime<chrono::Utc>,
472        limit: usize,
473        offset: usize,
474    ) -> RustvelloResult<Vec<InvocationHistory>> {
475        let state = self.state.lock().await;
476        let mut all: Vec<&InvocationHistory> = state
477            .histories
478            .values()
479            .flat_map(|v| v.iter())
480            .filter(|h| {
481                let ts = h.history_timestamp.unwrap_or(h.status_record.timestamp);
482                ts >= start && ts <= end
483            })
484            .collect();
485        // Sort by timestamp ascending
486        all.sort_by_key(|h| h.history_timestamp.unwrap_or(h.status_record.timestamp));
487        let result = all
488            .into_iter()
489            .skip(offset)
490            .take(if limit > 0 { limit } else { usize::MAX })
491            .cloned()
492            .collect();
493        Ok(result)
494    }
495
496    async fn get_matching_runner_contexts(
497        &self,
498        partial_id: &str,
499    ) -> RustvelloResult<Vec<StoredRunnerContext>> {
500        let state = self.state.lock().await;
501        Ok(state
502            .runner_contexts
503            .values()
504            .filter(|ctx| ctx.runner_id.contains(partial_id))
505            .cloned()
506            .collect())
507    }
508}
509
#[cfg(test)]
mod tests {
    use super::*;
    use rustvello_proto::call::SerializedArguments;
    use rustvello_proto::identifiers::TaskId;
    use rustvello_proto::invocation::{InvocationDTO, WorkflowIdentity};
    use rustvello_proto::status::{InvocationStatus, InvocationStatusRecord};

    /// Builds a fresh invocation of `test.module.my_task` with the call it
    /// refers to.
    fn make_fixtures() -> (InvocationDTO, CallDTO) {
        let task = TaskId::new("test.module", "my_task");
        let mut arguments = SerializedArguments::new();
        arguments.insert("x", "42");
        let call = CallDTO::new(task.clone(), arguments);
        let invocation = InvocationDTO::new(InvocationId::new(), task, call.call_id.clone());
        (invocation, call)
    }

    #[tokio::test]
    async fn test_upsert_and_get() {
        let backend = MemStateBackend::new();
        let (invocation, call) = make_fixtures();
        backend.upsert_invocation(&invocation, &call).await.unwrap();

        let fetched_invocation = backend
            .get_invocation(&invocation.invocation_id)
            .await
            .unwrap();
        assert_eq!(fetched_invocation.invocation_id, invocation.invocation_id);

        let fetched_call = backend.get_call(&call.call_id).await.unwrap();
        assert_eq!(fetched_call.call_id, call.call_id);
    }

    #[tokio::test]
    async fn test_results() {
        let backend = MemStateBackend::new();
        let id = InvocationId::new();

        // Nothing stored yet.
        assert_eq!(backend.get_result(&id).await.unwrap(), None);

        backend.store_result(&id, "42").await.unwrap();
        assert_eq!(
            backend.get_result(&id).await.unwrap().as_deref(),
            Some("42")
        );
    }

    #[tokio::test]
    async fn test_errors() {
        let backend = MemStateBackend::new();
        let id = InvocationId::new();

        backend
            .store_error(
                &id,
                &TaskError {
                    error_type: "ValueError".to_string(),
                    message: "something went wrong".to_string(),
                    traceback: None,
                },
            )
            .await
            .unwrap();

        let stored = backend
            .get_error(&id)
            .await
            .unwrap()
            .expect("error was just stored");
        assert_eq!(stored.error_type, "ValueError");
    }

    #[tokio::test]
    async fn test_history() {
        let backend = MemStateBackend::new();
        let id = InvocationId::new();

        let entry = InvocationHistory::new(
            id.clone(),
            InvocationStatusRecord::new(InvocationStatus::Registered, None),
            None,
        );
        backend.add_history(&entry).await.unwrap();

        assert_eq!(backend.get_history(&id).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_purge() {
        let backend = MemStateBackend::new();
        let (invocation, call) = make_fixtures();
        backend.upsert_invocation(&invocation, &call).await.unwrap();

        backend.purge().await.unwrap();

        let lookup = backend.get_invocation(&invocation.invocation_id).await;
        assert!(lookup.is_err());
    }

    #[tokio::test]
    async fn test_workflow_invocations() {
        let backend = MemStateBackend::new();
        let task = TaskId::new("mod", "task");
        let mut arguments = SerializedArguments::new();
        arguments.insert("x", "1");

        let root_id = InvocationId::from_string("root-1");
        let child_id = InvocationId::from_string("child-1");
        let workflow = WorkflowIdentity::root(root_id.clone(), task.clone());

        // A root invocation (no parent) and one child, both in the same workflow.
        let members_to_create = [
            (root_id.clone(), None),
            (child_id.clone(), Some(root_id.clone())),
        ];
        for (id, parent) in members_to_create {
            let call = CallDTO::new(task.clone(), arguments.clone());
            let invocation = InvocationDTO::with_workflow(
                id,
                task.clone(),
                call.call_id.clone(),
                parent,
                workflow.clone(),
            );
            backend.upsert_invocation(&invocation, &call).await.unwrap();
        }

        let members = backend.get_workflow_invocations(&root_id).await.unwrap();
        assert_eq!(members.len(), 2);

        let children = backend.get_child_invocations(&root_id).await.unwrap();
        assert_eq!(children, vec![child_id]);
    }

    #[tokio::test]
    async fn test_no_workflow_returns_empty() {
        let backend = MemStateBackend::new();
        let unknown = InvocationId::from_string("nonexistent");
        assert!(backend
            .get_workflow_invocations(&unknown)
            .await
            .unwrap()
            .is_empty());
    }

    // --- Workflow discovery tests ---

    #[tokio::test]
    async fn test_store_and_get_workflow_runs() {
        let backend = MemStateBackend::new();
        let task = TaskId::new("mod", "my_workflow");
        let run_id = InvocationId::from_string("wf-run-1");

        backend
            .store_workflow_run(&WorkflowIdentity::root(run_id.clone(), task.clone()))
            .await
            .unwrap();

        let types = backend.get_all_workflow_types().await.unwrap();
        assert_eq!(types.len(), 1);
        assert_eq!(types[0].to_string(), task.to_string());

        let runs = backend.get_workflow_runs(&task).await.unwrap();
        assert_eq!(runs.len(), 1);
        assert_eq!(runs[0].workflow_id, run_id);
    }

    #[tokio::test]
    async fn test_multiple_workflow_types() {
        let backend = MemStateBackend::new();
        let task_a = TaskId::new("mod", "workflow_a");
        let task_b = TaskId::new("mod", "workflow_b");

        for (task, run) in [(&task_a, "wf-a"), (&task_b, "wf-b")] {
            let wf = WorkflowIdentity::root(InvocationId::from_string(run), task.clone());
            backend.store_workflow_run(&wf).await.unwrap();
        }

        assert_eq!(backend.get_all_workflow_types().await.unwrap().len(), 2);
        assert_eq!(backend.get_workflow_runs(&task_a).await.unwrap().len(), 1);
        assert_eq!(backend.get_workflow_runs(&task_b).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_multiple_runs_same_type() {
        let backend = MemStateBackend::new();
        let task = TaskId::new("mod", "my_workflow");

        for run in 0..3 {
            let wf = WorkflowIdentity::root(
                InvocationId::from_string(format!("wf-{run}")),
                task.clone(),
            );
            backend.store_workflow_run(&wf).await.unwrap();
        }

        assert_eq!(backend.get_all_workflow_types().await.unwrap().len(), 1);
        assert_eq!(backend.get_workflow_runs(&task).await.unwrap().len(), 3);
    }

    // --- Workflow data tests ---

    #[tokio::test]
    async fn test_workflow_data_set_get() {
        let backend = MemStateBackend::new();
        let wf = InvocationId::from_string("wf-data-1");

        backend.set_workflow_data(&wf, "key1", "value1").await.unwrap();
        assert_eq!(
            backend.get_workflow_data(&wf, "key1").await.unwrap(),
            Some("value1".to_string())
        );

        // A key that was never set reads back as None.
        assert!(backend
            .get_workflow_data(&wf, "missing")
            .await
            .unwrap()
            .is_none());
    }

    #[tokio::test]
    async fn test_workflow_data_update() {
        let backend = MemStateBackend::new();
        let wf = InvocationId::from_string("wf-data-2");

        for value in ["1", "2"] {
            backend.set_workflow_data(&wf, "counter", value).await.unwrap();
        }

        assert_eq!(
            backend.get_workflow_data(&wf, "counter").await.unwrap(),
            Some("2".to_string())
        );
    }

    #[tokio::test]
    async fn test_workflow_data_isolation() {
        let backend = MemStateBackend::new();
        let first = InvocationId::from_string("wf-iso-1");
        let second = InvocationId::from_string("wf-iso-2");

        backend.set_workflow_data(&first, "key", "val1").await.unwrap();
        backend.set_workflow_data(&second, "key", "val2").await.unwrap();

        // The same key in different workflows must not collide.
        for (wf, expected) in [(&first, "val1"), (&second, "val2")] {
            assert_eq!(
                backend.get_workflow_data(wf, "key").await.unwrap(),
                Some(expected.to_string())
            );
        }
    }

    #[tokio::test]
    async fn test_workflow_data_purge() {
        let backend = MemStateBackend::new();
        let wf = InvocationId::from_string("wf-purge");

        backend.set_workflow_data(&wf, "key", "val").await.unwrap();
        backend.purge().await.unwrap();

        assert!(backend
            .get_workflow_data(&wf, "key")
            .await
            .unwrap()
            .is_none());
    }
}