1use crate::client::RedisClient;
4use async_trait::async_trait;
5use floxide_core::distributed::{
6 WorkItemState, WorkItemStateStore, WorkItemStateStoreError, WorkItemStatus,
7};
8use floxide_core::workflow::WorkItem;
9use redis::AsyncCommands;
10use serde::de::DeserializeOwned;
11use serde::Serialize;
12use tracing::{error, instrument, trace};
13
14#[derive(Clone)]
16pub struct RedisWorkItemStateStore<W: WorkItem> {
17 client: RedisClient,
18 _phantom: std::marker::PhantomData<W>,
19}
20
21impl<W: WorkItem> RedisWorkItemStateStore<W> {
22 pub fn new(client: RedisClient) -> Self {
24 Self {
25 client,
26 _phantom: std::marker::PhantomData,
27 }
28 }
29
30 fn work_item_states_key(&self, run_id: &str) -> String {
32 self.client
33 .prefixed_key(&format!("work_item_states:{}", run_id))
34 }
35
36 fn work_item_state_key(&self, run_id: &str, item_id: &str) -> String {
38 self.client
39 .prefixed_key(&format!("work_item_state:{}:{}", run_id, item_id))
40 }
41}
42
43#[async_trait]
44impl<W: WorkItem + Serialize + DeserializeOwned> WorkItemStateStore<W>
45 for RedisWorkItemStateStore<W>
46{
47 #[instrument(skip(self, item), level = "trace")]
48 async fn get_status(
49 &self,
50 run_id: &str,
51 item: &W,
52 ) -> Result<WorkItemStatus, WorkItemStateStoreError> {
53 let item_id = item.instance_id();
54 let key = self.work_item_state_key(run_id, &item_id);
55 let mut conn = self.client.conn.clone();
56
57 let result: Option<String> = conn.get(&key).await.map_err(|e| {
59 error!("Redis error while getting work item state: {}", e);
60 WorkItemStateStoreError::Io(e.to_string())
61 })?;
62
63 if let Some(serialized) = result {
65 let state = serde_json::from_str::<WorkItemState<W>>(&serialized).map_err(|e| {
66 error!("Failed to deserialize work item state: {}", e);
67 WorkItemStateStoreError::Other(format!("Deserialization error: {}", e))
68 })?;
69
70 trace!(
71 "Got status for work item {} in run {}: {:?}",
72 item_id,
73 run_id,
74 state.status
75 );
76 Ok(state.status)
77 } else {
78 trace!(
80 "No status found for work item {} in run {}, returning default",
81 item_id,
82 run_id
83 );
84 Ok(WorkItemStatus::default())
85 }
86 }
87
88 #[instrument(skip(self, item, status), level = "trace")]
89 async fn set_status(
90 &self,
91 run_id: &str,
92 item: &W,
93 status: WorkItemStatus,
94 ) -> Result<(), WorkItemStateStoreError> {
95 let item_id = item.instance_id();
96 let key = self.work_item_state_key(run_id, &item_id);
97 let states_key = self.work_item_states_key(run_id);
98 let mut conn = self.client.conn.clone();
99
100 let state = match conn.get::<_, Option<String>>(&key).await {
102 Ok(Some(serialized)) => {
103 let mut state =
104 serde_json::from_str::<WorkItemState<W>>(&serialized).map_err(|e| {
105 error!("Failed to deserialize work item state: {}", e);
106 WorkItemStateStoreError::Other(format!("Deserialization error: {}", e))
107 })?;
108 state.status = status;
109 state
110 }
111 _ => WorkItemState {
112 status,
113 attempts: 0,
114 work_item: item.clone(),
115 },
116 };
117 let status_for_log = state.status.clone();
119 let serialized = serde_json::to_string(&state).map_err(|e| {
121 error!("Failed to serialize work item state: {}", e);
122 WorkItemStateStoreError::Other(format!("Serialization error: {}", e))
123 })?;
124 let _result: () = redis::pipe()
128 .set(&key, serialized)
129 .sadd(&states_key, &item_id)
130 .query_async(&mut conn)
131 .await
132 .map_err(|e| {
133 error!("Redis error while updating work item status: {}", e);
134 WorkItemStateStoreError::Io(e.to_string())
135 })?;
136 trace!(
137 "Updated status for work item {} in run {} to {:?}",
138 item_id,
139 run_id,
140 status_for_log
141 );
142 Ok(())
143 }
144
145 #[instrument(skip(self, item), level = "trace")]
146 async fn increment_attempts(
147 &self,
148 run_id: &str,
149 item: &W,
150 ) -> Result<u32, WorkItemStateStoreError> {
151 let item_id = item.instance_id();
152 let key = self.work_item_state_key(run_id, &item_id);
153 let states_key = self.work_item_states_key(run_id);
154 let mut conn = self.client.conn.clone();
155
156 let mut state = match conn.get::<_, Option<String>>(&key).await {
158 Ok(Some(serialized)) => {
159 serde_json::from_str::<WorkItemState<W>>(&serialized).map_err(|e| {
160 error!("Failed to deserialize work item state: {}", e);
161 WorkItemStateStoreError::Other(format!("Deserialization error: {}", e))
162 })?
163 }
164 _ => WorkItemState {
165 status: WorkItemStatus::default(),
166 attempts: 0,
167 work_item: item.clone(),
168 },
169 };
170
171 state.attempts += 1;
173
174 let serialized = serde_json::to_string(&state).map_err(|e| {
176 error!("Failed to serialize work item state: {}", e);
177 WorkItemStateStoreError::Other(format!("Serialization error: {}", e))
178 })?;
179
180 let _result: () = redis::pipe()
184 .set(&key, serialized)
185 .sadd(&states_key, &item_id)
186 .query_async(&mut conn)
187 .await
188 .map_err(|e| {
189 error!("Redis error while incrementing attempts: {}", e);
190 WorkItemStateStoreError::Io(e.to_string())
191 })?;
192
193 trace!(
194 "Incremented attempts for work item {} in run {} to {}",
195 item_id,
196 run_id,
197 state.attempts
198 );
199 Ok(state.attempts)
200 }
201
202 #[instrument(skip(self, item), level = "trace")]
203 async fn get_attempts(&self, run_id: &str, item: &W) -> Result<u32, WorkItemStateStoreError> {
204 let item_id = item.instance_id();
205 let key = self.work_item_state_key(run_id, &item_id);
206 let mut conn = self.client.conn.clone();
207
208 let result: Option<String> = conn.get(&key).await.map_err(|e| {
210 error!("Redis error while getting work item state: {}", e);
211 WorkItemStateStoreError::Io(e.to_string())
212 })?;
213
214 if let Some(serialized) = result {
216 let state = serde_json::from_str::<WorkItemState<W>>(&serialized).map_err(|e| {
217 error!("Failed to deserialize work item state: {}", e);
218 WorkItemStateStoreError::Other(format!("Deserialization error: {}", e))
219 })?;
220
221 trace!(
222 "Got attempts for work item {} in run {}: {}",
223 item_id,
224 run_id,
225 state.attempts
226 );
227 Ok(state.attempts)
228 } else {
229 trace!(
231 "No attempts found for work item {} in run {}, returning 0",
232 item_id,
233 run_id
234 );
235 Ok(0)
236 }
237 }
238
239 #[instrument(skip(self, item), level = "trace")]
240 async fn reset_attempts(&self, run_id: &str, item: &W) -> Result<(), WorkItemStateStoreError> {
241 let item_id = item.instance_id();
242 let key = self.work_item_state_key(run_id, &item_id);
243 let states_key = self.work_item_states_key(run_id);
244 let mut conn = self.client.conn.clone();
245
246 let mut state = match conn.get::<_, Option<String>>(&key).await {
248 Ok(Some(serialized)) => {
249 serde_json::from_str::<WorkItemState<W>>(&serialized).map_err(|e| {
250 error!("Failed to deserialize work item state: {}", e);
251 WorkItemStateStoreError::Other(format!("Deserialization error: {}", e))
252 })?
253 }
254 _ => WorkItemState {
255 status: WorkItemStatus::default(),
256 attempts: 0,
257 work_item: item.clone(),
258 },
259 };
260
261 state.attempts = 0;
263
264 let serialized = serde_json::to_string(&state).map_err(|e| {
266 error!("Failed to serialize work item state: {}", e);
267 WorkItemStateStoreError::Other(format!("Serialization error: {}", e))
268 })?;
269
270 let _result: () = redis::pipe()
274 .set(&key, serialized)
275 .sadd(&states_key, &item_id)
276 .query_async(&mut conn)
277 .await
278 .map_err(|e| {
279 error!("Redis error while resetting attempts: {}", e);
280 WorkItemStateStoreError::Io(e.to_string())
281 })?;
282
283 trace!(
284 "Reset attempts for work item {} in run {} to 0",
285 item_id,
286 run_id
287 );
288 Ok(())
289 }
290
291 #[instrument(skip(self), level = "trace")]
292 async fn get_all(
293 &self,
294 run_id: &str,
295 ) -> Result<Vec<WorkItemState<W>>, WorkItemStateStoreError> {
296 let states_key = self.work_item_states_key(run_id);
297 let mut conn = self.client.conn.clone();
298
299 let item_ids: Vec<String> = conn.smembers(&states_key).await.map_err(|e| {
301 error!("Redis error while getting all work items: {}", e);
302 WorkItemStateStoreError::Io(e.to_string())
303 })?;
304
305 let mut items = Vec::with_capacity(item_ids.len());
307 for item_id in item_ids {
308 let key = self.work_item_state_key(run_id, &item_id);
309 let result: Option<String> = conn.get(&key).await.map_err(|e| {
310 error!("Redis error while getting work item state: {}", e);
311 WorkItemStateStoreError::Io(e.to_string())
312 })?;
313
314 if let Some(serialized) = result {
315 let state = serde_json::from_str::<WorkItemState<W>>(&serialized).map_err(|e| {
316 error!("Failed to deserialize work item state: {}", e);
317 WorkItemStateStoreError::Other(format!("Deserialization error: {}", e))
318 })?;
319
320 items.push(state);
321 }
322 }
323
324 trace!("Got {} work item states for run {}", items.len(), run_id);
325 Ok(items)
326 }
327}