1use crate::client::RedisClient;
4use async_trait::async_trait;
5use floxide_core::{
6 context::Context,
7 distributed::context_store::{ContextStore, ContextStoreError},
8 merge::Merge,
9};
10use rand::Rng;
11use redis::{AsyncCommands, Value};
12use serde::{de::DeserializeOwned, Serialize};
13use tokio::time::{sleep, Duration};
14use tracing::{error, instrument, trace, warn};
15
16const LOCK_TIMEOUT_MS: usize = 5000; const MAX_LOCK_RETRIES: usize = 10; const BASE_RETRY_DELAY_MS: u64 = 50; const MAX_RETRY_DELAY_MS: u64 = 500; #[derive(Clone)]
23pub struct RedisContextStore<C: Context + Merge + Default> {
24 client: RedisClient,
25 _phantom: std::marker::PhantomData<C>,
26}
27
28impl<C: Context + Merge + Default> RedisContextStore<C> {
29 pub fn new(client: RedisClient) -> Self {
31 Self {
32 client,
33 _phantom: std::marker::PhantomData,
34 }
35 }
36
37 fn context_key(&self, run_id: &str) -> String {
39 self.client.prefixed_key(&format!("context:{}", run_id))
40 }
41
42 fn lock_key(&self, run_id: &str) -> String {
44 self.client
45 .prefixed_key(&format!("lock:context:{}", run_id))
46 }
47}
48
49#[async_trait]
50impl<C> ContextStore<C> for RedisContextStore<C>
51where
52 C: Context + Merge + Default + Serialize + DeserializeOwned + Send + Sync + 'static,
53{
54 #[instrument(skip(self), level = "trace")]
55 async fn get(&self, run_id: &str) -> Result<Option<C>, ContextStoreError> {
56 let key = self.context_key(run_id);
57 let mut conn = self.client.conn.clone();
58
59 let result: Option<String> = conn.get(&key).await.map_err(|e| {
61 error!("Redis error while getting context: {}", e);
62 ContextStoreError::Io(e.to_string())
63 })?;
64
65 if let Some(serialized) = result {
67 let context = serde_json::from_str(&serialized).map_err(|e| {
68 error!("Failed to deserialize context: {}", e);
69 ContextStoreError::Other(format!("Deserialization error: {}", e))
70 })?;
71 trace!("Got context for run {}", run_id);
72 Ok(Some(context))
73 } else {
74 trace!("No context found for run {}", run_id);
75 Ok(None)
76 }
77 }
78
79 #[instrument(skip(self, ctx), level = "trace")]
80 async fn set(&self, run_id: &str, ctx: C) -> Result<(), ContextStoreError> {
81 let key = self.context_key(run_id);
82 let mut conn = self.client.conn.clone();
83
84 let serialized = match serde_json::to_string(&ctx) {
86 Ok(s) => s,
87 Err(e) => {
88 error!("Failed to serialize context: {}", e);
89 return Err(ContextStoreError::Other(format!(
90 "Serialization error: {}",
91 e
92 )));
93 }
94 };
95
96 if let Err(e) = conn.set(&key, serialized).await as Result<(), _> {
98 error!("Redis error while setting context: {}", e);
99 return Err(ContextStoreError::Other(format!(
100 "Redis error while setting context: {}",
101 e
102 )));
103 } else {
104 trace!("Set context for run {}", run_id);
105 Ok(())
106 }
107 }
108
109 #[instrument(skip(self, ctx), level = "trace")]
110 async fn merge(&self, run_id: &str, ctx: C) -> Result<(), ContextStoreError> {
111 let key = self.context_key(run_id);
112 let lock_key = self.lock_key(run_id);
113 let lock_value = format!("worker_{}", rand::thread_rng().gen::<u32>()); let mut conn = self.client.conn.clone();
115 let mut acquired_lock = false;
116
117 for attempt in 0..MAX_LOCK_RETRIES {
119 trace!(run_id, attempt, "Attempting to acquire context lock");
120 let result: Result<Value, redis::RedisError> = redis::cmd("SET")
122 .arg(&lock_key)
123 .arg(&lock_value)
124 .arg("NX")
125 .arg("PX")
126 .arg(LOCK_TIMEOUT_MS)
127 .query_async(&mut conn) .await;
129
130 match result {
131 Ok(Value::Okay) => {
132 trace!(run_id, "Successfully acquired context lock");
133 acquired_lock = true;
134 break; }
136 Ok(Value::Nil) => {
137 trace!(run_id, "Context lock already held, retrying...");
138 }
139 Ok(other) => {
140 warn!(
141 run_id,
142 ?other,
143 "Unexpected response from Redis SET NX PX while acquiring lock"
144 );
145 }
146 Err(e) => {
147 error!(run_id, error = %e, "Redis error while acquiring context lock");
148 }
150 }
151
152 let delay = rand::thread_rng().gen_range(BASE_RETRY_DELAY_MS..=MAX_RETRY_DELAY_MS);
154 trace!(
155 run_id,
156 attempt,
157 delay_ms = delay,
158 "Waiting before lock retry"
159 );
160 sleep(Duration::from_millis(delay)).await;
161 }
162
163 if !acquired_lock {
164 error!(
165 run_id,
166 "Failed to acquire context lock after {} retries, aborting merge", MAX_LOCK_RETRIES
167 );
168 return Err(ContextStoreError::Other(format!(
169 "Failed to acquire context lock after {} retries, aborting merge",
170 MAX_LOCK_RETRIES
171 )));
172 }
173
174 let rmw_result = async {
177 let current: Option<String> = match conn.get(&key).await {
179 Ok(val) => val,
180 Err(e) => {
181 error!(run_id, error = %e, "Redis error while getting context for merge");
182 return Err(()); }
184 };
185
186 let merged = if let Some(serialized) = current {
187 match serde_json::from_str::<C>(&serialized) {
188 Ok(mut existing) => {
189 trace!(run_id, ?existing, ?ctx, "Context before merge");
190 existing.merge(ctx);
191 trace!(run_id, ?existing, "Context after merge");
192 existing
193 }
194 Err(e) => {
195 error!(run_id, error = %e, "Failed to deserialize context for merge");
196 return Err(()); }
199 }
200 } else {
201 trace!(run_id, ?ctx, "No existing context found, using new context");
202 ctx
203 };
204
205 let serialized = match serde_json::to_string(&merged) {
207 Ok(s) => s,
208 Err(e) => {
209 error!(run_id, error = %e, "Failed to serialize merged context");
210 return Err(()); }
212 };
213
214 trace!(run_id, context_to_write=?merged, "Attempting to write merged context to Redis");
216 if let Err(e) = conn.set(&key, serialized).await as Result<(), _> {
217 error!(run_id, error = %e, "Redis error while setting merged context");
218 Err(()) } else {
220 trace!(
221 run_id,
222 "Successfully wrote merged context for run {}",
223 run_id
224 );
225 Ok(()) }
227 }
228 .await; trace!(run_id, "Releasing context lock");
234 if let Err(e) = conn.del(&lock_key).await as Result<(), _> {
235 error!(run_id, error = %e, "Failed to release context lock");
236 } else {
237 trace!(run_id, "Successfully released context lock");
238 }
239
240 if rmw_result.is_err() {
241 error!(
242 run_id,
243 "Merge operation failed during read-modify-write phase"
244 );
245 }
247 Ok(())
248 }
249}