1use std::collections::HashMap;
12use std::sync::Arc;
13
14use async_trait::async_trait;
15use tokio::sync::RwLock;
16use uuid::Uuid;
17
18use crate::error::Result;
19use crate::result::FlowResult;
20
/// Persistence backend for [`FlowResult`]s, addressed by a [`Uuid`].
///
/// Implementors must be `Send + Sync` so a single store can be shared between
/// tasks. Every method is async (via `async_trait`) and reports failures through
/// the crate's `Result` type.
#[async_trait]
pub trait ExecutionStore: Send + Sync {
    /// Persists `result`.
    ///
    /// NOTE(review): the in-memory implementation keys entries by
    /// `result.execution_id` and silently replaces an existing entry with the
    /// same id; confirm other backends follow the same overwrite semantics.
    async fn save(&self, result: &FlowResult) -> Result<()>;

    /// Returns the result stored under `id`, or `Ok(None)` when nothing is stored
    /// under that id.
    async fn load(&self, id: Uuid) -> Result<Option<FlowResult>>;

    /// Returns the ids of every stored result. No ordering is promised by this
    /// signature; callers must not rely on one.
    async fn list(&self) -> Result<Vec<Uuid>>;

    /// Removes the result stored under `id`.
    ///
    /// NOTE(review): the in-memory implementation treats an unknown `id` as a
    /// no-op (`Ok(())`); confirm whether other backends should do the same.
    async fn delete(&self, id: Uuid) -> Result<()>;
}
51
52pub struct MemoryExecutionStore {
56 inner: Arc<RwLock<HashMap<Uuid, FlowResult>>>,
57}
58
59impl MemoryExecutionStore {
60 pub fn new() -> Self {
62 Self {
63 inner: Arc::new(RwLock::new(HashMap::new())),
64 }
65 }
66}
67
68impl Default for MemoryExecutionStore {
69 fn default() -> Self {
70 Self::new()
71 }
72}
73
74#[async_trait]
75impl ExecutionStore for MemoryExecutionStore {
76 async fn save(&self, result: &FlowResult) -> Result<()> {
77 self.inner
78 .write()
79 .await
80 .insert(result.execution_id, result.clone());
81 Ok(())
82 }
83
84 async fn load(&self, id: Uuid) -> Result<Option<FlowResult>> {
85 Ok(self.inner.read().await.get(&id).cloned())
86 }
87
88 async fn list(&self) -> Result<Vec<Uuid>> {
89 Ok(self.inner.read().await.keys().cloned().collect())
90 }
91
92 async fn delete(&self, id: Uuid) -> Result<()> {
93 self.inner.write().await.remove(&id);
94 Ok(())
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::{HashMap, HashSet};

    /// Builds an empty `FlowResult` with a fresh random execution id.
    fn make_result() -> FlowResult {
        FlowResult {
            execution_id: Uuid::new_v4(),
            outputs: HashMap::new(),
            completed_nodes: HashSet::new(),
            skipped_nodes: HashSet::new(),
        }
    }

    #[tokio::test]
    async fn save_and_load_round_trip() {
        let store = MemoryExecutionStore::new();
        let r = make_result();
        let id = r.execution_id;

        store.save(&r).await.unwrap();
        let loaded = store.load(id).await.unwrap().unwrap();
        assert_eq!(loaded.execution_id, id);
    }

    #[tokio::test]
    async fn load_unknown_id_returns_none() {
        let store = MemoryExecutionStore::new();
        let result = store.load(Uuid::new_v4()).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn list_on_empty_store_is_empty() {
        let store = MemoryExecutionStore::new();
        assert!(store.list().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn list_returns_all_saved_ids() {
        let store = MemoryExecutionStore::new();
        let r1 = make_result();
        let r2 = make_result();
        let id1 = r1.execution_id;
        let id2 = r2.execution_id;

        store.save(&r1).await.unwrap();
        store.save(&r2).await.unwrap();

        // Order is unspecified, so check membership rather than position.
        let ids = store.list().await.unwrap();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&id1));
        assert!(ids.contains(&id2));
    }

    #[tokio::test]
    async fn delete_removes_entry() {
        let store = MemoryExecutionStore::new();
        let r = make_result();
        let id = r.execution_id;

        store.save(&r).await.unwrap();
        store.delete(id).await.unwrap();

        assert!(store.load(id).await.unwrap().is_none());
        assert!(store.list().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn delete_unknown_id_is_noop() {
        let store = MemoryExecutionStore::new();
        let kept = make_result();
        let kept_id = kept.execution_id;
        store.save(&kept).await.unwrap();

        // Deleting an id that was never saved succeeds and leaves other entries alone.
        store.delete(Uuid::new_v4()).await.unwrap();

        assert_eq!(store.list().await.unwrap(), vec![kept_id]);
    }

    #[tokio::test]
    async fn save_overwrites_existing_entry() {
        let store = MemoryExecutionStore::new();
        let mut r = make_result();
        let id = r.execution_id;

        store.save(&r).await.unwrap();

        // Re-save the same execution id with modified outputs.
        r.outputs.insert("x".into(), serde_json::json!(42));
        store.save(&r).await.unwrap();

        let loaded = store.load(id).await.unwrap().unwrap();
        assert_eq!(loaded.outputs["x"], serde_json::json!(42));

        // Overwriting must not create a second entry.
        assert_eq!(store.list().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn saved_entry_is_independent_of_caller_copy() {
        let store = MemoryExecutionStore::new();
        let mut r = make_result();
        let id = r.execution_id;

        store.save(&r).await.unwrap();

        // Mutating the caller's value after saving must not leak into the store.
        r.outputs.insert("x".into(), serde_json::json!(1));

        let loaded = store.load(id).await.unwrap().unwrap();
        assert!(loaded.outputs.is_empty());
    }
}