floxide_redis/work_queue.rs

//! Redis implementation of the WorkQueue trait.

use crate::client::RedisClient;
use async_trait::async_trait;
use floxide_core::{
    context::Context,
    distributed::{WorkQueue, WorkQueueError},
    workflow::WorkItem,
};
use redis::AsyncCommands;
use serde::de::DeserializeOwned;
use serde::Serialize;
use tracing::{error, instrument, trace};

/// Redis implementation of the WorkQueue trait.
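///
/// Each workflow run gets its own Redis LIST of serialized work items, and a
/// global Redis SET tracks which runs currently have pending work, so
/// `dequeue` can scan candidate runs without listing every queue key.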
#[derive(Clone)]
pub struct RedisWorkQueue<WI: WorkItem> {
    client: RedisClient,
    _phantom: std::marker::PhantomData<WI>,
}

impl<WI: WorkItem> RedisWorkQueue<WI> {
    /// Create a new Redis work queue with the given client.
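    ///
    /// A minimal usage sketch; `MyWorkItem` is a hypothetical `WorkItem`
    /// type, and the `RedisClient` constructor shown is an assumption, not
    /// the verified API:
    ///
    /// ```rust,ignore
    /// let client = RedisClient::new("redis://127.0.0.1/").await?;
    /// let queue: RedisWorkQueue<MyWorkItem> = RedisWorkQueue::new(client);
    /// queue.enqueue("run-1", MyWorkItem::default()).await?;
    /// while let Some((run_id, item)) = queue.dequeue().await? {
    ///     // hand (run_id, item) to a worker here
    /// }
    /// ```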
    pub fn new(client: RedisClient) -> Self {
        Self {
            client,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Get the Redis key for the work queue for a specific workflow run.
    fn queue_key(&self, workflow_id: &str) -> String {
        self.client.prefixed_key(&format!("queue:{}", workflow_id))
    }

    /// Get the Redis key for the global work queue.
    fn global_queue_key(&self) -> String {
        self.client.prefixed_key("global_queue")
    }
}
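
// Sketch of the resulting key layout, assuming `prefixed_key` joins the
// client's prefix (shown here as "floxide") with a colon:
//
//   floxide:queue:<workflow_id>   LIST of serialized work items for one run
//   floxide:global_queue          SET of run ids with pending work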

#[async_trait]
impl<C: Context, WI: WorkItem + 'static> WorkQueue<C, WI> for RedisWorkQueue<WI>
where
    WI: Serialize + DeserializeOwned + Send + Sync,
{
    #[instrument(skip(self, work), level = "trace")]
    async fn enqueue(&self, workflow_id: &str, work: WI) -> Result<(), WorkQueueError> {
        let queue_key = self.queue_key(workflow_id);
        let global_queue_key = self.global_queue_key();

        // Serialize the work item
        let serialized = serde_json::to_string(&work).map_err(|e| {
            error!("Failed to serialize work item: {}", e);
            WorkQueueError::Other(format!("Serialization error: {}", e))
        })?;

        // Use a Redis pipeline to atomically:
        // 1. Push the work item to the workflow-specific queue
        // 2. Add the workflow ID to the global queue if not already present
        let mut conn = self.client.conn.clone();
        let _result: () = redis::pipe()
            .atomic() // MULTI/EXEC: without this, a pipeline is not atomic
            .rpush(&queue_key, serialized)
            .sadd(&global_queue_key, workflow_id)
            .query_async(&mut conn)
            .await
            .map_err(|e| {
                error!("Redis error while enqueueing work: {}", e);
                WorkQueueError::Io(e.to_string())
            })?;

        trace!("Enqueued work item for workflow {}", workflow_id);
        Ok(())
    }

    #[instrument(skip(self), level = "trace")]
    async fn dequeue(&self) -> Result<Option<(String, WI)>, WorkQueueError> {
        let global_queue_key = self.global_queue_key();
        let mut conn = self.client.conn.clone();

        // Get all workflow IDs from the global queue
        let workflow_ids: Vec<String> = conn.smembers(&global_queue_key).await.map_err(|e| {
            error!("Redis error while getting workflow IDs: {}", e);
            WorkQueueError::Io(e.to_string())
        })?;

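        // Note: SMEMBERS returns members in unspecified order, so the scan
        // below makes no fairness guarantee across runs.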
        // Try to dequeue from each workflow queue in turn
        for workflow_id in workflow_ids {
            let queue_key = self.queue_key(&workflow_id);

            // Use LPOP to get the next item from the queue
            let result: Option<String> = conn.lpop(&queue_key, None).await.map_err(|e| {
                error!("Redis error while dequeueing work: {}", e);
                WorkQueueError::Io(e.to_string())
            })?;

            if let Some(serialized) = result {
                // Deserialize the work item
                let work_item = serde_json::from_str(&serialized).map_err(|e| {
                    error!("Failed to deserialize work item: {}", e);
                    WorkQueueError::Other(format!("Deserialization error: {}", e))
                })?;

                // Check if the queue is now empty, and if so, remove it from the global queue
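                // Caveat: LLEN and SREM are not atomic with the LPOP above.
                // An item enqueued between them can be left orphaned until
                // the next enqueue re-adds the run id via SADD.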
                let queue_len: usize = conn.llen(&queue_key).await.map_err(|e| {
                    error!("Redis error while checking queue length: {}", e);
                    WorkQueueError::Io(e.to_string())
                })?;

                if queue_len == 0 {
                    let _result: () =
                        conn.srem(&global_queue_key, &workflow_id)
                            .await
                            .map_err(|e| {
                                error!(
                                    "Redis error while removing workflow from global queue: {}",
                                    e
                                );
                                WorkQueueError::Io(e.to_string())
                            })?;
                }

                trace!("Dequeued work item for workflow {}", workflow_id);
                return Ok(Some((workflow_id, work_item)));
            }
        }

        // No work items found
        trace!("No work items available");
        Ok(None)
    }

    #[instrument(skip(self), level = "trace")]
    async fn purge_run(&self, run_id: &str) -> Result<(), WorkQueueError> {
        let queue_key = self.queue_key(run_id);
        let global_queue_key = self.global_queue_key();
        let mut conn = self.client.conn.clone();

        // Use a Redis pipeline to atomically:
        // 1. Delete the workflow-specific queue
        // 2. Remove the workflow ID from the global queue
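        // Items already handed out by `dequeue` are not recalled; purge only
        // drops work still sitting in the queue.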
        let _result: () = redis::pipe()
            .atomic() // MULTI/EXEC, so the delete and removal happen together
            .del(&queue_key)
            .srem(&global_queue_key, run_id)
            .query_async(&mut conn)
            .await
            .map_err(|e| {
                error!("Redis error while purging run: {}", e);
                WorkQueueError::Io(e.to_string())
            })?;

        trace!("Purged work items for workflow {}", run_id);
        Ok(())
    }

    #[instrument(skip(self), level = "trace")]
    async fn pending_work(&self, run_id: &str) -> Result<Vec<WI>, WorkQueueError> {
        let queue_key = self.queue_key(run_id);
        let mut conn = self.client.conn.clone();

        // Get all items from the queue
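        // (LRANGE is non-destructive: this is a snapshot, and the items stay
        // queued for `dequeue`.)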
        let items: Vec<String> = conn.lrange(&queue_key, 0, -1).await.map_err(|e| {
            error!("Redis error while getting pending work: {}", e);
            WorkQueueError::Io(e.to_string())
        })?;

        // Deserialize each item
        let mut result = Vec::with_capacity(items.len());
        for item in items {
            let work_item = serde_json::from_str(&item).map_err(|e| {
                error!("Failed to deserialize work item: {}", e);
                WorkQueueError::Other(format!("Deserialization error: {}", e))
            })?;
            result.push(work_item);
        }

        trace!(
            "Found {} pending work items for workflow {}",
            result.len(),
            run_id
        );
        Ok(result)
    }
}