floxide_redis/
context_store.rs

//! Redis implementation of the `ContextStore` trait.

3use crate::client::RedisClient;
4use async_trait::async_trait;
5use floxide_core::{
6    context::Context,
7    distributed::context_store::{ContextStore, ContextStoreError},
8    merge::Merge,
9};
10use rand::Rng;
11use redis::{AsyncCommands, Value};
12use serde::{de::DeserializeOwned, Serialize};
13use tokio::time::{sleep, Duration};
14use tracing::{error, instrument, trace, warn};
15
/// Validity of a context lock in milliseconds; Redis expires the lock key after this.
const LOCK_TIMEOUT_MS: usize = 5_000;
/// Maximum number of attempts made to acquire a context lock before giving up.
const MAX_LOCK_RETRIES: usize = 10;
/// Lower bound (milliseconds) of the randomized delay between lock attempts.
const BASE_RETRY_DELAY_MS: u64 = 50;
/// Upper bound (milliseconds, inclusive) of the randomized delay between lock attempts.
const MAX_RETRY_DELAY_MS: u64 = 500;
20
21/// Redis implementation of the ContextStore trait.
22#[derive(Clone)]
23pub struct RedisContextStore<C: Context + Merge + Default> {
24    client: RedisClient,
25    _phantom: std::marker::PhantomData<C>,
26}
27
28impl<C: Context + Merge + Default> RedisContextStore<C> {
29    /// Create a new Redis context store with the given client.
30    pub fn new(client: RedisClient) -> Self {
31        Self {
32            client,
33            _phantom: std::marker::PhantomData,
34        }
35    }
36
37    /// Get the Redis key for the context for a specific run.
38    fn context_key(&self, run_id: &str) -> String {
39        self.client.prefixed_key(&format!("context:{}", run_id))
40    }
41
42    /// Get the Redis key for the context lock for a specific run.
43    fn lock_key(&self, run_id: &str) -> String {
44        self.client
45            .prefixed_key(&format!("lock:context:{}", run_id))
46    }
47}
48
49#[async_trait]
50impl<C> ContextStore<C> for RedisContextStore<C>
51where
52    C: Context + Merge + Default + Serialize + DeserializeOwned + Send + Sync + 'static,
53{
54    #[instrument(skip(self), level = "trace")]
55    async fn get(&self, run_id: &str) -> Result<Option<C>, ContextStoreError> {
56        let key = self.context_key(run_id);
57        let mut conn = self.client.conn.clone();
58
59        // Get the serialized context from Redis
60        let result: Option<String> = conn.get(&key).await.map_err(|e| {
61            error!("Redis error while getting context: {}", e);
62            ContextStoreError::Io(e.to_string())
63        })?;
64
65        // If the context exists, deserialize it
66        if let Some(serialized) = result {
67            let context = serde_json::from_str(&serialized).map_err(|e| {
68                error!("Failed to deserialize context: {}", e);
69                ContextStoreError::Other(format!("Deserialization error: {}", e))
70            })?;
71            trace!("Got context for run {}", run_id);
72            Ok(Some(context))
73        } else {
74            trace!("No context found for run {}", run_id);
75            Ok(None)
76        }
77    }
78
79    #[instrument(skip(self, ctx), level = "trace")]
80    async fn set(&self, run_id: &str, ctx: C) -> Result<(), ContextStoreError> {
81        let key = self.context_key(run_id);
82        let mut conn = self.client.conn.clone();
83
84        // Serialize the context
85        let serialized = match serde_json::to_string(&ctx) {
86            Ok(s) => s,
87            Err(e) => {
88                error!("Failed to serialize context: {}", e);
89                return Err(ContextStoreError::Other(format!(
90                    "Serialization error: {}",
91                    e
92                )));
93            }
94        };
95
96        // Store the serialized context in Redis
97        if let Err(e) = conn.set(&key, serialized).await as Result<(), _> {
98            error!("Redis error while setting context: {}", e);
99            return Err(ContextStoreError::Other(format!(
100                "Redis error while setting context: {}",
101                e
102            )));
103        } else {
104            trace!("Set context for run {}", run_id);
105            Ok(())
106        }
107    }
108
109    #[instrument(skip(self, ctx), level = "trace")]
110    async fn merge(&self, run_id: &str, ctx: C) -> Result<(), ContextStoreError> {
111        let key = self.context_key(run_id);
112        let lock_key = self.lock_key(run_id);
113        let lock_value = format!("worker_{}", rand::thread_rng().gen::<u32>()); // Unique value for this attempt
114        let mut conn = self.client.conn.clone();
115        let mut acquired_lock = false;
116
117        // --- Attempt to acquire lock ---
118        for attempt in 0..MAX_LOCK_RETRIES {
119            trace!(run_id, attempt, "Attempting to acquire context lock");
120            // Use redis::cmd for SET NX PX
121            let result: Result<Value, redis::RedisError> = redis::cmd("SET")
122                .arg(&lock_key)
123                .arg(&lock_value)
124                .arg("NX")
125                .arg("PX")
126                .arg(LOCK_TIMEOUT_MS)
127                .query_async(&mut conn) // Pass mutable connection
128                .await;
129
130            match result {
131                Ok(Value::Okay) => {
132                    trace!(run_id, "Successfully acquired context lock");
133                    acquired_lock = true;
134                    break; // Lock acquired, exit loop
135                }
136                Ok(Value::Nil) => {
137                    trace!(run_id, "Context lock already held, retrying...");
138                }
139                Ok(other) => {
140                    warn!(
141                        run_id,
142                        ?other,
143                        "Unexpected response from Redis SET NX PX while acquiring lock"
144                    );
145                }
146                Err(e) => {
147                    error!(run_id, error = %e, "Redis error while acquiring context lock");
148                    // Depending on error, might want to break early
149                }
150            }
151
152            // Wait with random backoff before next attempt
153            let delay = rand::thread_rng().gen_range(BASE_RETRY_DELAY_MS..=MAX_RETRY_DELAY_MS);
154            trace!(
155                run_id,
156                attempt,
157                delay_ms = delay,
158                "Waiting before lock retry"
159            );
160            sleep(Duration::from_millis(delay)).await;
161        }
162
163        if !acquired_lock {
164            error!(
165                run_id,
166                "Failed to acquire context lock after {} retries, aborting merge", MAX_LOCK_RETRIES
167            );
168            return Err(ContextStoreError::Other(format!(
169                "Failed to acquire context lock after {} retries, aborting merge",
170                MAX_LOCK_RETRIES
171            )));
172        }
173
174        // --- Lock Acquired: Perform Read-Modify-Write ---
175        // Use a block to manage the RMW logic and ensure lock release
176        let rmw_result = async {
177            // Get the current context from Redis
178            let current: Option<String> = match conn.get(&key).await {
179                Ok(val) => val,
180                Err(e) => {
181                    error!(run_id, error = %e, "Redis error while getting context for merge");
182                    return Err(()); // Indicate error within RMW block
183                }
184            };
185
186            let merged = if let Some(serialized) = current {
187                match serde_json::from_str::<C>(&serialized) {
188                    Ok(mut existing) => {
189                        trace!(run_id, ?existing, ?ctx, "Context before merge");
190                        existing.merge(ctx);
191                        trace!(run_id, ?existing, "Context after merge");
192                        existing
193                    }
194                    Err(e) => {
195                        error!(run_id, error = %e, "Failed to deserialize context for merge");
196                        // Return error instead of using potentially incomplete new context
197                        return Err(()); // Indicate error within RMW block
198                    }
199                }
200            } else {
201                trace!(run_id, ?ctx, "No existing context found, using new context");
202                ctx
203            };
204
205            // Serialize the merged context
206            let serialized = match serde_json::to_string(&merged) {
207                Ok(s) => s,
208                Err(e) => {
209                    error!(run_id, error = %e, "Failed to serialize merged context");
210                    return Err(()); // Indicate error within RMW block
211                }
212            };
213
214            // Store the merged context in Redis
215            trace!(run_id, context_to_write=?merged, "Attempting to write merged context to Redis");
216            if let Err(e) = conn.set(&key, serialized).await as Result<(), _> {
217                error!(run_id, error = %e, "Redis error while setting merged context");
218                Err(()) // Indicate error within RMW block
219            } else {
220                trace!(
221                    run_id,
222                    "Successfully wrote merged context for run {}",
223                    run_id
224                );
225                Ok(()) // Indicate success within RMW block
226            }
227        }
228        .await; // End of RMW async block
229
230        // --- Release Lock ---
231        // Release the lock regardless of RMW outcome.
232        // A more robust implementation might check if the lock value still matches `lock_value` before deleting.
233        trace!(run_id, "Releasing context lock");
234        if let Err(e) = conn.del(&lock_key).await as Result<(), _> {
235            error!(run_id, error = %e, "Failed to release context lock");
236        } else {
237            trace!(run_id, "Successfully released context lock");
238        }
239
240        if rmw_result.is_err() {
241            error!(
242                run_id,
243                "Merge operation failed during read-modify-write phase"
244            );
245            // Potentially signal error further up? For now, just log.
246        }
247        Ok(())
248    }
249}