floxide_redis/
work_item_store.rs

1//! Redis implementation of the WorkItemStateStore trait.
2
3use crate::client::RedisClient;
4use async_trait::async_trait;
5use floxide_core::distributed::{
6    WorkItemState, WorkItemStateStore, WorkItemStateStoreError, WorkItemStatus,
7};
8use floxide_core::workflow::WorkItem;
9use redis::AsyncCommands;
10use serde::de::DeserializeOwned;
11use serde::Serialize;
12use tracing::{error, instrument, trace};
13
14/// Redis implementation of the WorkItemStateStore trait.
15#[derive(Clone)]
16pub struct RedisWorkItemStateStore<W: WorkItem> {
17    client: RedisClient,
18    _phantom: std::marker::PhantomData<W>,
19}
20
21impl<W: WorkItem> RedisWorkItemStateStore<W> {
22    /// Create a new Redis work item state store with the given client.
23    pub fn new(client: RedisClient) -> Self {
24        Self {
25            client,
26            _phantom: std::marker::PhantomData,
27        }
28    }
29
30    /// Get the Redis key for work item states for a specific run.
31    fn work_item_states_key(&self, run_id: &str) -> String {
32        self.client
33            .prefixed_key(&format!("work_item_states:{}", run_id))
34    }
35
36    /// Get the Redis key for a specific work item state.
37    fn work_item_state_key(&self, run_id: &str, item_id: &str) -> String {
38        self.client
39            .prefixed_key(&format!("work_item_state:{}:{}", run_id, item_id))
40    }
41}
42
43#[async_trait]
44impl<W: WorkItem + Serialize + DeserializeOwned> WorkItemStateStore<W>
45    for RedisWorkItemStateStore<W>
46{
47    #[instrument(skip(self, item), level = "trace")]
48    async fn get_status(
49        &self,
50        run_id: &str,
51        item: &W,
52    ) -> Result<WorkItemStatus, WorkItemStateStoreError> {
53        let item_id = item.instance_id();
54        let key = self.work_item_state_key(run_id, &item_id);
55        let mut conn = self.client.conn.clone();
56
57        // Get the serialized work item state from Redis
58        let result: Option<String> = conn.get(&key).await.map_err(|e| {
59            error!("Redis error while getting work item state: {}", e);
60            WorkItemStateStoreError::Io(e.to_string())
61        })?;
62
63        // If the work item state exists, deserialize it and return the status
64        if let Some(serialized) = result {
65            let state = serde_json::from_str::<WorkItemState<W>>(&serialized).map_err(|e| {
66                error!("Failed to deserialize work item state: {}", e);
67                WorkItemStateStoreError::Other(format!("Deserialization error: {}", e))
68            })?;
69
70            trace!(
71                "Got status for work item {} in run {}: {:?}",
72                item_id,
73                run_id,
74                state.status
75            );
76            Ok(state.status)
77        } else {
78            // If the work item state doesn't exist, return the default status
79            trace!(
80                "No status found for work item {} in run {}, returning default",
81                item_id,
82                run_id
83            );
84            Ok(WorkItemStatus::default())
85        }
86    }
87
88    #[instrument(skip(self, item, status), level = "trace")]
89    async fn set_status(
90        &self,
91        run_id: &str,
92        item: &W,
93        status: WorkItemStatus,
94    ) -> Result<(), WorkItemStateStoreError> {
95        let item_id = item.instance_id();
96        let key = self.work_item_state_key(run_id, &item_id);
97        let states_key = self.work_item_states_key(run_id);
98        let mut conn = self.client.conn.clone();
99
100        // Get the current work item state or create a new one
101        let state = match conn.get::<_, Option<String>>(&key).await {
102            Ok(Some(serialized)) => {
103                let mut state =
104                    serde_json::from_str::<WorkItemState<W>>(&serialized).map_err(|e| {
105                        error!("Failed to deserialize work item state: {}", e);
106                        WorkItemStateStoreError::Other(format!("Deserialization error: {}", e))
107                    })?;
108                state.status = status;
109                state
110            }
111            _ => WorkItemState {
112                status,
113                attempts: 0,
114                work_item: item.clone(),
115            },
116        };
117        // Clone status for debug log before it is moved
118        let status_for_log = state.status.clone();
119        // Serialize the updated work item state
120        let serialized = serde_json::to_string(&state).map_err(|e| {
121            error!("Failed to serialize work item state: {}", e);
122            WorkItemStateStoreError::Other(format!("Serialization error: {}", e))
123        })?;
124        // Use a Redis pipeline to atomically:
125        // 1. Store the updated work item state
126        // 2. Add the work item ID to the set of work items for this run
127        let _result: () = redis::pipe()
128            .set(&key, serialized)
129            .sadd(&states_key, &item_id)
130            .query_async(&mut conn)
131            .await
132            .map_err(|e| {
133                error!("Redis error while updating work item status: {}", e);
134                WorkItemStateStoreError::Io(e.to_string())
135            })?;
136        trace!(
137            "Updated status for work item {} in run {} to {:?}",
138            item_id,
139            run_id,
140            status_for_log
141        );
142        Ok(())
143    }
144
145    #[instrument(skip(self, item), level = "trace")]
146    async fn increment_attempts(
147        &self,
148        run_id: &str,
149        item: &W,
150    ) -> Result<u32, WorkItemStateStoreError> {
151        let item_id = item.instance_id();
152        let key = self.work_item_state_key(run_id, &item_id);
153        let states_key = self.work_item_states_key(run_id);
154        let mut conn = self.client.conn.clone();
155
156        // Get the current work item state or create a new one
157        let mut state = match conn.get::<_, Option<String>>(&key).await {
158            Ok(Some(serialized)) => {
159                serde_json::from_str::<WorkItemState<W>>(&serialized).map_err(|e| {
160                    error!("Failed to deserialize work item state: {}", e);
161                    WorkItemStateStoreError::Other(format!("Deserialization error: {}", e))
162                })?
163            }
164            _ => WorkItemState {
165                status: WorkItemStatus::default(),
166                attempts: 0,
167                work_item: item.clone(),
168            },
169        };
170
171        // Increment the attempts counter
172        state.attempts += 1;
173
174        // Serialize the updated work item state
175        let serialized = serde_json::to_string(&state).map_err(|e| {
176            error!("Failed to serialize work item state: {}", e);
177            WorkItemStateStoreError::Other(format!("Serialization error: {}", e))
178        })?;
179
180        // Use a Redis pipeline to atomically:
181        // 1. Store the updated work item state
182        // 2. Add the work item ID to the set of work items for this run
183        let _result: () = redis::pipe()
184            .set(&key, serialized)
185            .sadd(&states_key, &item_id)
186            .query_async(&mut conn)
187            .await
188            .map_err(|e| {
189                error!("Redis error while incrementing attempts: {}", e);
190                WorkItemStateStoreError::Io(e.to_string())
191            })?;
192
193        trace!(
194            "Incremented attempts for work item {} in run {} to {}",
195            item_id,
196            run_id,
197            state.attempts
198        );
199        Ok(state.attempts)
200    }
201
202    #[instrument(skip(self, item), level = "trace")]
203    async fn get_attempts(&self, run_id: &str, item: &W) -> Result<u32, WorkItemStateStoreError> {
204        let item_id = item.instance_id();
205        let key = self.work_item_state_key(run_id, &item_id);
206        let mut conn = self.client.conn.clone();
207
208        // Get the serialized work item state from Redis
209        let result: Option<String> = conn.get(&key).await.map_err(|e| {
210            error!("Redis error while getting work item state: {}", e);
211            WorkItemStateStoreError::Io(e.to_string())
212        })?;
213
214        // If the work item state exists, deserialize it and return the attempts
215        if let Some(serialized) = result {
216            let state = serde_json::from_str::<WorkItemState<W>>(&serialized).map_err(|e| {
217                error!("Failed to deserialize work item state: {}", e);
218                WorkItemStateStoreError::Other(format!("Deserialization error: {}", e))
219            })?;
220
221            trace!(
222                "Got attempts for work item {} in run {}: {}",
223                item_id,
224                run_id,
225                state.attempts
226            );
227            Ok(state.attempts)
228        } else {
229            // If the work item state doesn't exist, return 0 attempts
230            trace!(
231                "No attempts found for work item {} in run {}, returning 0",
232                item_id,
233                run_id
234            );
235            Ok(0)
236        }
237    }
238
239    #[instrument(skip(self, item), level = "trace")]
240    async fn reset_attempts(&self, run_id: &str, item: &W) -> Result<(), WorkItemStateStoreError> {
241        let item_id = item.instance_id();
242        let key = self.work_item_state_key(run_id, &item_id);
243        let states_key = self.work_item_states_key(run_id);
244        let mut conn = self.client.conn.clone();
245
246        // Get the current work item state or create a new one
247        let mut state = match conn.get::<_, Option<String>>(&key).await {
248            Ok(Some(serialized)) => {
249                serde_json::from_str::<WorkItemState<W>>(&serialized).map_err(|e| {
250                    error!("Failed to deserialize work item state: {}", e);
251                    WorkItemStateStoreError::Other(format!("Deserialization error: {}", e))
252                })?
253            }
254            _ => WorkItemState {
255                status: WorkItemStatus::default(),
256                attempts: 0,
257                work_item: item.clone(),
258            },
259        };
260
261        // Reset the attempts counter
262        state.attempts = 0;
263
264        // Serialize the updated work item state
265        let serialized = serde_json::to_string(&state).map_err(|e| {
266            error!("Failed to serialize work item state: {}", e);
267            WorkItemStateStoreError::Other(format!("Serialization error: {}", e))
268        })?;
269
270        // Use a Redis pipeline to atomically:
271        // 1. Store the updated work item state
272        // 2. Add the work item ID to the set of work items for this run
273        let _result: () = redis::pipe()
274            .set(&key, serialized)
275            .sadd(&states_key, &item_id)
276            .query_async(&mut conn)
277            .await
278            .map_err(|e| {
279                error!("Redis error while resetting attempts: {}", e);
280                WorkItemStateStoreError::Io(e.to_string())
281            })?;
282
283        trace!(
284            "Reset attempts for work item {} in run {} to 0",
285            item_id,
286            run_id
287        );
288        Ok(())
289    }
290
291    #[instrument(skip(self), level = "trace")]
292    async fn get_all(
293        &self,
294        run_id: &str,
295    ) -> Result<Vec<WorkItemState<W>>, WorkItemStateStoreError> {
296        let states_key = self.work_item_states_key(run_id);
297        let mut conn = self.client.conn.clone();
298
299        // Get all work item IDs for this run
300        let item_ids: Vec<String> = conn.smembers(&states_key).await.map_err(|e| {
301            error!("Redis error while getting all work items: {}", e);
302            WorkItemStateStoreError::Io(e.to_string())
303        })?;
304
305        // Get the work item state for each ID
306        let mut items = Vec::with_capacity(item_ids.len());
307        for item_id in item_ids {
308            let key = self.work_item_state_key(run_id, &item_id);
309            let result: Option<String> = conn.get(&key).await.map_err(|e| {
310                error!("Redis error while getting work item state: {}", e);
311                WorkItemStateStoreError::Io(e.to_string())
312            })?;
313
314            if let Some(serialized) = result {
315                let state = serde_json::from_str::<WorkItemState<W>>(&serialized).map_err(|e| {
316                    error!("Failed to deserialize work item state: {}", e);
317                    WorkItemStateStoreError::Other(format!("Deserialization error: {}", e))
318                })?;
319
320                items.push(state);
321            }
322        }
323
324        trace!("Got {} work item states for run {}", items.len(), run_id);
325        Ok(items)
326    }
327}