use crate::client::RedisClient;
use async_trait::async_trait;
use floxide_core::{
context::Context,
distributed::context_store::{ContextStore, ContextStoreError},
merge::Merge,
};
use rand::Rng;
use redis::{AsyncCommands, Value};
use serde::{de::DeserializeOwned, Serialize};
use tokio::time::{sleep, Duration};
use tracing::{error, instrument, trace, warn};
/// Expiry of the per-run merge lock in Redis, in milliseconds (sent as the `PX` argument).
const LOCK_TIMEOUT_MS: usize = 5_000;
/// Upper bound on the number of lock-acquisition attempts made by `merge`.
const MAX_LOCK_RETRIES: usize = 10;
/// Inclusive lower bound of the randomized pause between lock attempts, in milliseconds.
const BASE_RETRY_DELAY_MS: u64 = 50;
/// Inclusive upper bound of the randomized pause between lock attempts, in milliseconds.
const MAX_RETRY_DELAY_MS: u64 = 500;
/// Redis-backed [`ContextStore`] that keeps each run's context as a JSON string.
///
/// Keys are built through [`RedisClient::prefixed_key`]: the context itself lives under
/// `context:<run_id>`, and `merge` calls coordinate through a lock stored under
/// `lock:context:<run_id>`.
pub struct RedisContextStore<C: Context + Merge + Default> {
    /// Client whose connection handle is cloned for each operation.
    client: RedisClient,
    /// Binds the store to its context type `C` without holding a value of it.
    _phantom: std::marker::PhantomData<C>,
}

// Implemented by hand because `#[derive(Clone)]` would require `C: Clone`, even though
// only the Redis client is actually cloned; `PhantomData<C>` is `Clone` for any `C`.
impl<C: Context + Merge + Default> Clone for RedisContextStore<C> {
    fn clone(&self) -> Self {
        Self {
            client: self.client.clone(),
            _phantom: std::marker::PhantomData,
        }
    }
}
impl<C: Context + Merge + Default> RedisContextStore<C> {
    /// Creates a store that issues all of its Redis commands through `client`.
    pub fn new(client: RedisClient) -> Self {
        Self {
            _phantom: std::marker::PhantomData,
            client,
        }
    }

    /// Redis key holding the serialized context of `run_id`:
    /// `context:<run_id>`, passed through the client's key prefixing.
    fn context_key(&self, run_id: &str) -> String {
        let unprefixed = ["context:", run_id].concat();
        self.client.prefixed_key(&unprefixed)
    }

    /// Redis key of the lock guarding merges into `run_id`'s context:
    /// `lock:context:<run_id>`, passed through the client's key prefixing.
    fn lock_key(&self, run_id: &str) -> String {
        let unprefixed = ["lock:context:", run_id].concat();
        self.client.prefixed_key(&unprefixed)
    }
}
#[async_trait]
impl<C> ContextStore<C> for RedisContextStore<C>
where
    C: Context + Merge + Default + Serialize + DeserializeOwned + Send + Sync + 'static,
{
    /// Loads the context stored for `run_id`.
    ///
    /// Returns `Ok(None)` when nothing is stored under the run's context key.
    ///
    /// # Errors
    ///
    /// - [`ContextStoreError::Io`] if the Redis `GET` fails.
    /// - [`ContextStoreError::Other`] if the stored JSON cannot be deserialized into `C`.
    #[instrument(skip(self), level = "trace")]
    async fn get(&self, run_id: &str) -> Result<Option<C>, ContextStoreError> {
        let key = self.context_key(run_id);
        let mut conn = self.client.conn.clone();
        let result: Option<String> = conn.get(&key).await.map_err(|e| {
            error!("Redis error while getting context: {}", e);
            ContextStoreError::Io(e.to_string())
        })?;
        if let Some(serialized) = result {
            let context = serde_json::from_str(&serialized).map_err(|e| {
                error!("Failed to deserialize context: {}", e);
                ContextStoreError::Other(format!("Deserialization error: {}", e))
            })?;
            trace!("Got context for run {}", run_id);
            Ok(Some(context))
        } else {
            trace!("No context found for run {}", run_id);
            Ok(None)
        }
    }

    /// Overwrites the context stored for `run_id` with `ctx`, serialized as JSON.
    ///
    /// No lock is taken; the value is written with a plain Redis `SET`.
    ///
    /// # Errors
    ///
    /// - [`ContextStoreError::Other`] if `ctx` cannot be serialized.
    /// - [`ContextStoreError::Io`] if the Redis `SET` fails (same variant `get` uses for
    ///   Redis failures).
    #[instrument(skip(self, ctx), level = "trace")]
    async fn set(&self, run_id: &str, ctx: C) -> Result<(), ContextStoreError> {
        let key = self.context_key(run_id);
        let serialized = serde_json::to_string(&ctx).map_err(|e| {
            error!("Failed to serialize context: {}", e);
            ContextStoreError::Other(format!("Serialization error: {}", e))
        })?;
        let mut conn = self.client.conn.clone();
        let written: Result<(), redis::RedisError> = conn.set(&key, serialized).await;
        written.map_err(|e| {
            error!("Redis error while setting context: {}", e);
            ContextStoreError::Io(format!("Redis error while setting context: {}", e))
        })?;
        trace!("Set context for run {}", run_id);
        Ok(())
    }

    /// Merges `ctx` into the context stored for `run_id` while holding a Redis lock.
    ///
    /// The lock is taken with `SET <lock_key> <token> NX PX LOCK_TIMEOUT_MS` and retried up
    /// to `MAX_LOCK_RETRIES` times, with a random pause in
    /// `BASE_RETRY_DELAY_MS..=MAX_RETRY_DELAY_MS` between attempts. While the lock is held,
    /// the stored context is read and merged with `ctx` via [`Merge::merge`]. If no context
    /// exists, `ctx` is used as-is. The result is then written back. The lock is always
    /// released afterwards, but only if it still carries this call's token.
    ///
    /// # Errors
    ///
    /// - [`ContextStoreError::Other`] if the lock cannot be acquired, the stored context
    ///   cannot be deserialized, or the merged context cannot be serialized.
    /// - [`ContextStoreError::Io`] if reading or writing the context in Redis fails.
    #[instrument(skip(self, ctx), level = "trace")]
    async fn merge(&self, run_id: &str, ctx: C) -> Result<(), ContextStoreError> {
        // Deletes the lock atomically only if it still holds our token. A call whose lock
        // expired mid-merge must not delete a lock another worker has since acquired.
        const RELEASE_LOCK_SCRIPT: &str = r#"if redis.call("GET", KEYS[1]) == ARGV[1] then return redis.call("DEL", KEYS[1]) else return 0 end"#;

        let key = self.context_key(run_id);
        let lock_key = self.lock_key(run_id);
        // Random per-call token identifying this lock holder; checked on release.
        let lock_value = format!("worker_{}", rand::thread_rng().gen::<u64>());
        let mut conn = self.client.conn.clone();

        let mut acquired_lock = false;
        for attempt in 0..MAX_LOCK_RETRIES {
            trace!(run_id, attempt, "Attempting to acquire context lock");
            let result: Result<Value, redis::RedisError> = redis::cmd("SET")
                .arg(&lock_key)
                .arg(&lock_value)
                .arg("NX")
                .arg("PX")
                .arg(LOCK_TIMEOUT_MS)
                .query_async(&mut conn)
                .await;
            match result {
                Ok(Value::Okay) => {
                    trace!(run_id, "Successfully acquired context lock");
                    acquired_lock = true;
                    break;
                }
                Ok(Value::Nil) => {
                    trace!(run_id, "Context lock already held, retrying...");
                }
                Ok(other) => {
                    warn!(
                        run_id,
                        ?other,
                        "Unexpected response from Redis SET NX PX while acquiring lock"
                    );
                }
                Err(e) => {
                    error!(run_id, error = %e, "Redis error while acquiring context lock");
                }
            }
            // Waiting after the final failed attempt would only delay the error.
            if attempt + 1 == MAX_LOCK_RETRIES {
                break;
            }
            let delay = rand::thread_rng().gen_range(BASE_RETRY_DELAY_MS..=MAX_RETRY_DELAY_MS);
            trace!(
                run_id,
                attempt,
                delay_ms = delay,
                "Waiting before lock retry"
            );
            sleep(Duration::from_millis(delay)).await;
        }
        if !acquired_lock {
            error!(
                run_id,
                "Failed to acquire context lock after {} retries, aborting merge", MAX_LOCK_RETRIES
            );
            return Err(ContextStoreError::Other(format!(
                "Failed to acquire context lock after {} retries, aborting merge",
                MAX_LOCK_RETRIES
            )));
        }

        // Read-modify-write under the lock. Failures become the block's error value, not
        // an early return from `merge`, so the lock release below always runs.
        let rmw_result: Result<(), ContextStoreError> = async {
            let current: Option<String> = match conn.get(&key).await {
                Ok(val) => val,
                Err(e) => {
                    error!(run_id, error = %e, "Redis error while getting context for merge");
                    return Err(ContextStoreError::Io(format!(
                        "Redis error while getting context for merge: {}",
                        e
                    )));
                }
            };
            let merged = match current {
                Some(serialized) => match serde_json::from_str::<C>(&serialized) {
                    Ok(mut existing) => {
                        trace!(run_id, ?existing, ?ctx, "Context before merge");
                        existing.merge(ctx);
                        trace!(run_id, ?existing, "Context after merge");
                        existing
                    }
                    Err(e) => {
                        error!(run_id, error = %e, "Failed to deserialize context for merge");
                        return Err(ContextStoreError::Other(format!(
                            "Deserialization error: {}",
                            e
                        )));
                    }
                },
                None => {
                    trace!(run_id, ?ctx, "No existing context found, using new context");
                    ctx
                }
            };
            let serialized = match serde_json::to_string(&merged) {
                Ok(s) => s,
                Err(e) => {
                    error!(run_id, error = %e, "Failed to serialize merged context");
                    return Err(ContextStoreError::Other(format!(
                        "Serialization error: {}",
                        e
                    )));
                }
            };
            trace!(run_id, context_to_write=?merged, "Attempting to write merged context to Redis");
            let written: Result<(), redis::RedisError> = conn.set(&key, serialized).await;
            if let Err(e) = written {
                error!(run_id, error = %e, "Redis error while setting merged context");
                return Err(ContextStoreError::Io(format!(
                    "Redis error while setting merged context: {}",
                    e
                )));
            }
            trace!(
                run_id,
                "Successfully wrote merged context for run {}",
                run_id
            );
            Ok(())
        }
        .await;

        trace!(run_id, "Releasing context lock");
        let released: Result<i64, redis::RedisError> = redis::cmd("EVAL")
            .arg(RELEASE_LOCK_SCRIPT)
            .arg(1)
            .arg(&lock_key)
            .arg(&lock_value)
            .query_async(&mut conn)
            .await;
        match released {
            Ok(1) => {
                trace!(run_id, "Successfully released context lock");
            }
            Ok(_) => {
                // The key expired or now carries another holder's token; leave it alone.
                warn!(
                    run_id,
                    "Context lock no longer held by this call (expired or taken over); not released"
                );
            }
            Err(e) => {
                error!(run_id, error = %e, "Failed to release context lock");
            }
        }

        if rmw_result.is_err() {
            error!(
                run_id,
                "Merge operation failed during read-modify-write phase"
            );
        }
        // Report the read-modify-write outcome instead of unconditionally claiming success.
        rmw_result
    }
}