use crate::extensions::flow_state::FlowStore;
use anyhow::{Context, Result};
use redis::{Commands, Connection};
use serde_json::Value;
use std::sync::Mutex;
struct RedisConnectionManager {
client: redis::Client,
}
impl RedisConnectionManager {
fn new(client: redis::Client) -> Self {
Self { client }
}
}
impl r2d2::ManageConnection for RedisConnectionManager {
type Connection = Mutex<Connection>;
type Error = redis::RedisError;
fn connect(&self) -> Result<Self::Connection, Self::Error> {
let conn = self.client.get_connection()?;
Ok(Mutex::new(conn))
}
fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
redis::cmd("PING").query(conn.get_mut().unwrap())
}
fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
false
}
}
pub struct RedisFlowStore {
pool: r2d2::Pool<RedisConnectionManager>,
key_prefix: String,
default_ttl_seconds: i64,
}
impl RedisFlowStore {
pub fn new(
url: &str,
pool_size: usize,
key_prefix: String,
default_ttl_seconds: i64,
) -> Result<Self> {
let client = redis::Client::open(url).context("Failed to parse Redis URL")?;
let manager = RedisConnectionManager::new(client);
let pool = r2d2::Pool::builder()
.max_size(pool_size as u32)
.connection_timeout(std::time::Duration::from_secs(5))
.build(manager)
.context("Failed to create Redis connection pool")?;
{
let conn = pool.get().context("Failed to get connection from pool")?;
let _: String = redis::cmd("PING")
.query(&mut *conn.lock().unwrap())
.context("Failed to PING Redis")?;
}
tracing::info!(
"Connected to Redis with prefix={}, ttl={}s, pool_size={}",
key_prefix,
default_ttl_seconds,
pool_size
);
Ok(Self {
pool,
key_prefix,
default_ttl_seconds,
})
}
fn make_key(&self, flow_id: &str, key: &str) -> String {
format!("{}flow:{}:{}", self.key_prefix, flow_id, key)
}
}
impl FlowStore for RedisFlowStore {
fn get(&self, flow_id: &str, key: &str) -> Result<Option<Value>> {
let key_str = self.make_key(flow_id, key);
let conn = self
.pool
.get()
.context("Failed to get Redis connection from pool")?;
let value: Option<String> = conn
.lock()
.unwrap()
.get(&key_str)
.context("Redis GET failed")?;
if let Some(json_str) = value {
let val = serde_json::from_str(&json_str).context("Failed to parse JSON from Redis")?;
Ok(Some(val))
} else {
Ok(None)
}
}
fn set(&self, flow_id: &str, key: &str, value: Value) -> Result<()> {
let key_str = self.make_key(flow_id, key);
let json_str =
serde_json::to_string(&value).context("Failed to serialize value to JSON")?;
let conn = self
.pool
.get()
.context("Failed to get Redis connection from pool")?;
let _: () = redis::cmd("SETEX")
.arg(&key_str)
.arg(self.default_ttl_seconds)
.arg(json_str)
.query(&mut *conn.lock().unwrap())
.context("Redis SETEX failed")?;
Ok(())
}
fn exists(&self, flow_id: &str, key: &str) -> Result<bool> {
let key_str = self.make_key(flow_id, key);
let conn = self
.pool
.get()
.context("Failed to get Redis connection from pool")?;
let count: i64 = conn
.lock()
.unwrap()
.exists(&key_str)
.context("Redis EXISTS failed")?;
Ok(count > 0)
}
fn delete(&self, flow_id: &str, key: &str) -> Result<()> {
let key_str = self.make_key(flow_id, key);
let conn = self
.pool
.get()
.context("Failed to get Redis connection from pool")?;
let _: () = conn
.lock()
.unwrap()
.del(&key_str)
.context("Redis DEL failed")?;
Ok(())
}
fn increment(&self, flow_id: &str, key: &str) -> Result<i64> {
let key_str = self.make_key(flow_id, key);
let conn = self
.pool
.get()
.context("Failed to get Redis connection from pool")?;
let mut conn_guard = conn.lock().unwrap();
let new_value: i64 = conn_guard.incr(&key_str, 1).context("Redis INCR failed")?;
let _: () = redis::cmd("EXPIRE")
.arg(&key_str)
.arg(self.default_ttl_seconds)
.query(&mut *conn_guard)
.context("Redis EXPIRE failed")?;
Ok(new_value)
}
fn set_ttl(&self, flow_id: &str, ttl_seconds: i64) -> Result<()> {
tracing::debug!(
"set_ttl called for flow_id={} with ttl={}s - individual operations already set TTL",
flow_id,
ttl_seconds
);
Ok(())
}
}
/// Pings Redis through a connection from `pool` and reports whether it answered.
///
/// Returns `Ok(true)` on a successful `PING`. Returns `Ok(false)`, after logging a warning,
/// when the `PING` fails or the checked-out connection's mutex is poisoned.
///
/// # Errors
/// Returns an error only when no connection can be checked out of the pool.
// `dead_code`: the function may have no callers yet. `private_interfaces`: the signature
// exposes the module-private `RedisConnectionManager` from a `pub(crate)` item.
#[allow(dead_code, private_interfaces)]
pub(crate) fn health_check(pool: &r2d2::Pool<RedisConnectionManager>) -> Result<bool> {
    let conn = pool.get().context("Failed to get connection from pool")?;
    // A poisoned lock means the connection is untrustworthy: report unhealthy, don't panic.
    let mut guard = match conn.lock() {
        Ok(guard) => guard,
        Err(_) => {
            tracing::warn!("Redis health check failed: connection mutex poisoned");
            return Ok(false);
        }
    };
    match redis::cmd("PING").query::<String>(&mut *guard) {
        Ok(_) => Ok(true),
        Err(e) => {
            tracing::warn!("Redis health check failed: {}", e);
            Ok(false)
        }
    }
}