floxide_redis/
work_queue.rs1use crate::client::RedisClient;
4use async_trait::async_trait;
5use floxide_core::{
6 context::Context,
7 distributed::{WorkQueue, WorkQueueError},
8 workflow::WorkItem,
9};
10use redis::AsyncCommands;
11use serde::de::DeserializeOwned;
12use serde::Serialize;
13use tracing::{error, instrument, trace};
14
15#[derive(Clone)]
17pub struct RedisWorkQueue<WI: WorkItem> {
18 client: RedisClient,
19 _phantom: std::marker::PhantomData<WI>,
20}
21
22impl<WI: WorkItem> RedisWorkQueue<WI> {
23 pub fn new(client: RedisClient) -> Self {
25 Self {
26 client,
27 _phantom: std::marker::PhantomData,
28 }
29 }
30
31 fn queue_key(&self, workflow_id: &str) -> String {
33 self.client.prefixed_key(&format!("queue:{}", workflow_id))
34 }
35
36 fn global_queue_key(&self) -> String {
38 self.client.prefixed_key("global_queue")
39 }
40}
41
42#[async_trait]
43impl<C: Context, WI: WorkItem + 'static> WorkQueue<C, WI> for RedisWorkQueue<WI>
44where
45 WI: Serialize + DeserializeOwned + Send + Sync,
46{
47 #[instrument(skip(self, work), level = "trace")]
48 async fn enqueue(&self, workflow_id: &str, work: WI) -> Result<(), WorkQueueError> {
49 let queue_key = self.queue_key(workflow_id);
50 let global_queue_key = self.global_queue_key();
51
52 let serialized = serde_json::to_string(&work).map_err(|e| {
54 error!("Failed to serialize work item: {}", e);
55 WorkQueueError::Other(format!("Serialization error: {}", e))
56 })?;
57
58 let mut conn = self.client.conn.clone();
62 let _result: () = redis::pipe()
63 .rpush(&queue_key, serialized)
64 .sadd(&global_queue_key, workflow_id)
65 .query_async(&mut conn)
66 .await
67 .map_err(|e| {
68 error!("Redis error while enqueueing work: {}", e);
69 WorkQueueError::Io(e.to_string())
70 })?;
71
72 trace!("Enqueued work item for workflow {}", workflow_id);
73 Ok(())
74 }
75
76 #[instrument(skip(self), level = "trace")]
77 async fn dequeue(&self) -> Result<Option<(String, WI)>, WorkQueueError> {
78 let global_queue_key = self.global_queue_key();
79 let mut conn = self.client.conn.clone();
80
81 let workflow_ids: Vec<String> = conn.smembers(&global_queue_key).await.map_err(|e| {
83 error!("Redis error while getting workflow IDs: {}", e);
84 WorkQueueError::Io(e.to_string())
85 })?;
86
87 for workflow_id in workflow_ids {
89 let queue_key = self.queue_key(&workflow_id);
90
91 let result: Option<String> = conn.lpop(&queue_key, None).await.map_err(|e| {
93 error!("Redis error while dequeueing work: {}", e);
94 WorkQueueError::Io(e.to_string())
95 })?;
96
97 if let Some(serialized) = result {
98 let work_item = serde_json::from_str(&serialized).map_err(|e| {
100 error!("Failed to deserialize work item: {}", e);
101 WorkQueueError::Other(format!("Deserialization error: {}", e))
102 })?;
103
104 let queue_len: usize = conn.llen(&queue_key).await.map_err(|e| {
106 error!("Redis error while checking queue length: {}", e);
107 WorkQueueError::Io(e.to_string())
108 })?;
109
110 if queue_len == 0 {
111 let _result: () =
112 conn.srem(&global_queue_key, &workflow_id)
113 .await
114 .map_err(|e| {
115 error!(
116 "Redis error while removing workflow from global queue: {}",
117 e
118 );
119 WorkQueueError::Io(e.to_string())
120 })?;
121 }
122
123 trace!("Dequeued work item for workflow {}", workflow_id);
124 return Ok(Some((workflow_id, work_item)));
125 }
126 }
127
128 trace!("No work items available");
130 Ok(None)
131 }
132
133 #[instrument(skip(self), level = "trace")]
134 async fn purge_run(&self, run_id: &str) -> Result<(), WorkQueueError> {
135 let queue_key = self.queue_key(run_id);
136 let global_queue_key = self.global_queue_key();
137 let mut conn = self.client.conn.clone();
138
139 let _result: () = redis::pipe()
143 .del(&queue_key)
144 .srem(&global_queue_key, run_id)
145 .query_async(&mut conn)
146 .await
147 .map_err(|e| {
148 error!("Redis error while purging run: {}", e);
149 WorkQueueError::Io(e.to_string())
150 })?;
151
152 trace!("Purged work items for workflow {}", run_id);
153 Ok(())
154 }
155
156 #[instrument(skip(self), level = "trace")]
157 async fn pending_work(&self, run_id: &str) -> Result<Vec<WI>, WorkQueueError> {
158 let queue_key = self.queue_key(run_id);
159 let mut conn = self.client.conn.clone();
160
161 let items: Vec<String> = conn.lrange(&queue_key, 0, -1).await.map_err(|e| {
163 error!("Redis error while getting pending work: {}", e);
164 WorkQueueError::Io(e.to_string())
165 })?;
166
167 let mut result = Vec::with_capacity(items.len());
169 for item in items {
170 let work_item = serde_json::from_str(&item).map_err(|e| {
171 error!("Failed to deserialize work item: {}", e);
172 WorkQueueError::Other(format!("Deserialization error: {}", e))
173 })?;
174 result.push(work_item);
175 }
176
177 trace!(
178 "Found {} pending work items for workflow {}",
179 result.len(),
180 run_id
181 );
182 Ok(result)
183 }
184}