use crate::{Error, Result};
use fred::{
interfaces::*,
prelude::*,
types::{RedisConfig as FredRedisConfig, ReconnectPolicy},
};
use std::sync::Arc;
/// Kind of Redis deployment a connection targets.
///
/// The derived `Default` is [`RedisMode::Standalone`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RedisMode {
    /// A single Redis server.
    #[default]
    Standalone,
    /// A Redis Cluster deployment.
    Cluster,
    /// A deployment reached through Redis Sentinel.
    Sentinel,
}
/// Sizing and lifetime options for a Redis connection pool.
///
/// NOTE(review): the `RedisClient` constructors in this file forward only
/// `pool_size` to `fred`'s `RedisPool::new`; the optional limits are stored
/// but not applied there — confirm whether anything else reads them.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Number of connections in the pool.
    pub pool_size: usize,
    /// Minimum number of idle connections to keep; unset by default.
    pub min_idle: Option<usize>,
    /// Connection timeout. Unit is not established in this file —
    /// presumably seconds or milliseconds; confirm with callers.
    pub connection_timeout: Option<u64>,
    /// Idle timeout (same unit caveat as `connection_timeout`).
    pub idle_timeout: Option<u64>,
    /// Maximum connection lifetime (same unit caveat as `connection_timeout`).
    pub max_lifetime: Option<u64>,
}

impl Default for PoolConfig {
    /// Ten connections with every optional limit unset.
    fn default() -> Self {
        Self::new(10)
    }
}

impl PoolConfig {
    /// Creates a config with `pool_size` connections and no optional limits.
    pub fn new(pool_size: usize) -> Self {
        PoolConfig {
            pool_size,
            min_idle: None,
            connection_timeout: None,
            idle_timeout: None,
            max_lifetime: None,
        }
    }

    /// Returns a copy of this config with `min_idle` set.
    #[must_use]
    pub fn min_idle(self, min_idle: usize) -> Self {
        Self {
            min_idle: Some(min_idle),
            ..self
        }
    }

    /// Returns a copy of this config with `connection_timeout` set.
    #[must_use]
    pub fn connection_timeout(self, timeout: u64) -> Self {
        Self {
            connection_timeout: Some(timeout),
            ..self
        }
    }

    /// Returns a copy of this config with `idle_timeout` set.
    #[must_use]
    pub fn idle_timeout(self, timeout: u64) -> Self {
        Self {
            idle_timeout: Some(timeout),
            ..self
        }
    }

    /// Returns a copy of this config with `max_lifetime` set.
    #[must_use]
    pub fn max_lifetime(self, lifetime: u64) -> Self {
        Self {
            max_lifetime: Some(lifetime),
            ..self
        }
    }
}
/// Connection settings accepted by `RedisClient::new`.
#[derive(Debug, Clone)]
pub struct RedisConfig {
    /// Connection URL.
    pub url: String,
    /// Pool sizing options.
    pub pool_config: PoolConfig,
    /// Kind of deployment the URL refers to.
    pub mode: RedisMode,
}

impl Default for RedisConfig {
    /// Standalone Redis at `redis://localhost:6379` with default pool settings.
    fn default() -> Self {
        let url = String::from("redis://localhost:6379");
        let pool_config = PoolConfig::default();
        let mode = RedisMode::Standalone;
        RedisConfig {
            url,
            pool_config,
            mode,
        }
    }
}
#[derive(Clone)]
pub struct RedisClient {
pool: Arc<RedisPool>,
}
impl RedisClient {
pub async fn new(config: RedisConfig) -> Result<Self> {
let redis_config = FredRedisConfig::from_url(&config.url)?;
let pool = RedisPool::new(
redis_config,
None,
None,
Some(ReconnectPolicy::default()),
config.pool_config.pool_size,
)?;
pool.init().await?;
match config.mode {
RedisMode::Standalone => {
tracing::info!("Connected to Redis at {}", config.url);
}
RedisMode::Cluster => {
tracing::info!("Connected to Redis Cluster at {}", config.url);
}
RedisMode::Sentinel => {
tracing::info!("Connected to Redis Sentinel at {}", config.url);
}
}
Ok(Self {
pool: Arc::new(pool),
})
}
pub async fn from_url(url: impl Into<String>) -> Result<Self> {
Self::from_url_with_pool_config(url, PoolConfig::default()).await
}
pub async fn from_url_with_pool(url: impl Into<String>, pool_size: usize) -> Result<Self> {
Self::from_url_with_pool_config(url, PoolConfig::new(pool_size)).await
}
pub async fn from_url_with_pool_config(url: impl Into<String>, pool_config: PoolConfig) -> Result<Self> {
let url = url.into();
let redis_config = FredRedisConfig::from_url(&url)?;
let pool = RedisPool::new(
redis_config,
None,
None,
Some(ReconnectPolicy::default()),
pool_config.pool_size,
)?;
pool.init().await?;
tracing::info!("Connected to Redis at {} (pool size: {})", url, pool_config.pool_size);
Ok(Self {
pool: Arc::new(pool),
})
}
pub async fn from_cluster_url(url: impl Into<String>) -> Result<Self> {
Self::from_cluster_url_with_pool_config(url, PoolConfig::default()).await
}
pub async fn from_cluster_url_with_pool(url: impl Into<String>, pool_size: usize) -> Result<Self> {
Self::from_cluster_url_with_pool_config(url, PoolConfig::new(pool_size)).await
}
pub async fn from_cluster_url_with_pool_config(url: impl Into<String>, pool_config: PoolConfig) -> Result<Self> {
let url = url.into();
let redis_config = FredRedisConfig::from_url(&url)?;
let pool = RedisPool::new(
redis_config,
None,
None,
Some(ReconnectPolicy::default()),
pool_config.pool_size,
)?;
pool.init().await?;
tracing::info!("Connected to Redis Cluster at {} (pool size: {})", url, pool_config.pool_size);
Ok(Self {
pool: Arc::new(pool),
})
}
pub async fn from_sentinel_url(url: impl Into<String>) -> Result<Self> {
Self::from_sentinel_url_with_pool_config(url, PoolConfig::default()).await
}
pub async fn from_sentinel_url_with_pool(url: impl Into<String>, pool_size: usize) -> Result<Self> {
Self::from_sentinel_url_with_pool_config(url, PoolConfig::new(pool_size)).await
}
pub async fn from_sentinel_url_with_pool_config(url: impl Into<String>, pool_config: PoolConfig) -> Result<Self> {
let url = url.into();
let redis_config = FredRedisConfig::from_url(&url)?;
let pool = RedisPool::new(
redis_config,
None,
None,
Some(ReconnectPolicy::default()),
pool_config.pool_size,
)?;
pool.init().await?;
tracing::info!("Connected to Redis Sentinel at {} (pool size: {})", url, pool_config.pool_size);
Ok(Self {
pool: Arc::new(pool),
})
}
pub fn pool(&self) -> &Arc<RedisPool> {
&self.pool
}
pub async fn ping(&self) -> Result<String> {
let result: String = self.pool.ping().await?;
Ok(result)
}
pub async fn set(&self, key: RedisKey, value: RedisValue) -> Result<()> {
let _: () = self.pool.set(key, value, None, None, false).await?;
Ok(())
}
pub async fn get(&self, key: RedisKey) -> Result<Option<RedisValue>> {
let result: Option<RedisValue> = self.pool.get(key).await?;
Ok(result)
}
pub async fn del(&self, keys: Vec<RedisKey>) -> Result<usize> {
let result: usize = self.pool.del(keys).await?;
Ok(result)
}
pub async fn exists(&self, key: RedisKey) -> Result<bool> {
let result: bool = self.pool.exists(key).await?;
Ok(result)
}
pub async fn expire(&self, key: RedisKey, seconds: u64) -> Result<bool> {
let result: bool = self.pool.expire(key, seconds as i64).await?;
Ok(result)
}
pub async fn rpush(&self, key: RedisKey, value: RedisValue) -> Result<u64> {
let result: u64 = self.pool.rpush(key, value).await?;
Ok(result)
}
pub async fn lpush(&self, key: RedisKey, value: RedisValue) -> Result<u64> {
let result: u64 = self.pool.lpush(key, value).await?;
Ok(result)
}
pub async fn blpop(&self, key: RedisKey, timeout: u64) -> Result<Option<(String, String)>> {
let result: Option<(String, String)> = self.pool.blpop(key, timeout as f64).await?;
Ok(result)
}
pub async fn brpop(&self, key: RedisKey, timeout: u64) -> Result<Option<(String, String)>> {
let result: Option<(String, String)> = self.pool.brpop(key, timeout as f64).await?;
Ok(result)
}
pub async fn llen(&self, key: RedisKey) -> Result<u64> {
let result: u64 = self.pool.llen(key).await?;
Ok(result)
}
pub async fn lrange(&self, key: RedisKey, start: i64, stop: i64) -> Result<Vec<String>> {
let result: Vec<RedisValue> = self.pool.lrange(key, start, stop).await?;
Ok(result.into_iter().filter_map(|v| v.as_string().map(|s| s.to_string())).collect())
}
pub async fn zadd(&self, key: RedisKey, member: RedisValue, score: i64) -> Result<()> {
let values: Vec<(f64, RedisValue)> = vec![(score as f64, member)];
let _: () = self.pool.zadd(key, None, None, false, false, values).await?;
Ok(())
}
pub async fn zrange_with_scores(&self, key: RedisKey, start: i64, stop: i64) -> Result<Vec<(String, f64)>> {
let result: Vec<RedisValue> = self
.pool
.zrange(key, start, stop, None, false, None, true)
.await?;
let mut output = Vec::new();
for chunk in result.chunks(2) {
if chunk.len() == 2 {
let member = chunk[0].as_string().map(|s| s.to_string());
let score = chunk[1].as_f64();
if let (Some(m), Some(s)) = (member, score) {
output.push((m, s));
}
}
}
Ok(output)
}
pub async fn zrange(&self, key: RedisKey, start: i64, stop: i64) -> Result<Vec<String>> {
let result: Vec<RedisValue> = self
.pool
.zrange(key, start, stop, None, false, None, false)
.await?;
Ok(result.into_iter().filter_map(|v| v.as_string().map(|s| s.to_string())).collect())
}
pub async fn zrangebyscore(&self, key: RedisKey, min: i64, max: i64) -> Result<Vec<String>> {
let result: Vec<RedisValue> = self
.pool
.zrangebyscore(key, min, max, false, None)
.await?;
Ok(result.into_iter().filter_map(|v| v.as_string().map(|s| s.to_string())).collect())
}
pub async fn zrem(&self, key: RedisKey, member: RedisValue) -> Result<bool> {
let result: u64 = self.pool.zrem(key, member).await?;
Ok(result > 0)
}
pub async fn zcard(&self, key: RedisKey) -> Result<u64> {
let result: u64 = self.pool.zcard(key).await?;
Ok(result)
}
pub async fn sadd(&self, key: RedisKey, member: RedisValue) -> Result<bool> {
let result: u64 = self.pool.sadd(key, member).await?;
Ok(result > 0)
}
pub async fn sismember(&self, key: RedisKey, member: RedisValue) -> Result<bool> {
let result: bool = self.pool.sismember(key, member).await?;
Ok(result)
}
pub async fn srem(&self, key: RedisKey, member: RedisValue) -> Result<bool> {
let result: u64 = self.pool.srem(key, member).await?;
Ok(result > 0)
}
pub async fn smembers(&self, key: RedisKey) -> Result<Vec<String>> {
let result: Vec<RedisValue> = self.pool.smembers(key).await?;
Ok(result.into_iter().filter_map(|v| v.as_string().map(|s| s.to_string())).collect())
}
pub async fn scard(&self, key: RedisKey) -> Result<u64> {
let result: u64 = self.pool.scard(key).await?;
Ok(result)
}
pub async fn hset(&self, key: RedisKey, values: Vec<(RedisKey, RedisValue)>) -> Result<bool> {
let result: u64 = self.pool.hset(key, values).await?;
Ok(result > 0)
}
pub async fn hincrby(&self, key: RedisKey, field: RedisKey, increment: i64) -> Result<i64> {
let result: i64 = self.pool.hincrby(key, field, increment).await?;
Ok(result)
}
pub async fn hget(&self, key: RedisKey, field: RedisKey) -> Result<Option<RedisValue>> {
let result: Option<RedisValue> = self.pool.hget(key, field).await?;
Ok(result)
}
pub async fn hgetall(&self, key: RedisKey) -> Result<Vec<RedisValue>> {
let result: Vec<RedisValue> = self.pool.hgetall(key).await?;
Ok(result)
}
pub async fn hmset(&self, key: RedisKey, values: Vec<(RedisKey, RedisValue)>) -> Result<()> {
let _: () = self.pool.hset(key, values).await?;
Ok(())
}
pub fn pipeline(&self) -> RedisPipeline {
RedisPipeline::new(self.pool.clone())
}
pub async fn lrem(&self, key: RedisKey, value: RedisValue, count: i64) -> Result<u64> {
let result: u64 = self.pool.lrem(key, count, value).await?;
Ok(result)
}
pub async fn eval_script(
&self,
script: &str,
keys: Vec<RedisKey>,
args: Vec<RedisValue>,
) -> Result<Option<String>> {
use fred::interfaces::LuaInterface;
let client = self.pool.next();
let result: fred::types::RedisValue = client.eval(script, keys, args).await?;
match result.as_string() {
Some(s) => Ok(Some(s.to_string())),
None => {
let type_str = format!("{:?}", result);
if type_str.contains("ERR_QUEUE_PAUSED") {
return Err(Error::QueuePaused("Queue paused".to_string()));
}
if type_str.contains("ERR_TIMEOUT") || type_str == "Nil" {
return Ok(None);
}
if type_str.starts_with("Error") {
return Err(Error::Redis(fred::error::RedisError::new(
fred::error::RedisErrorKind::Unknown,
type_str,
)));
}
Ok(None)
}
}
}
pub async fn dedup_add(&self, dedup_key: RedisKey, unique_key: RedisValue) -> Result<bool> {
const DEDUP_SCRIPT: &str = r#"
-- Atomic deduplication script
local dedup_key = KEYS[1]
local unique_key = ARGV[1]
-- Check if key already exists
if redis.call('SISMEMBER', dedup_key, unique_key) == 1 then
return 0 -- Already exists, return false
end
-- Add to set
redis.call('SADD', dedup_key, unique_key)
return 1 -- Successfully added, return true
"#;
let keys = vec![dedup_key];
let args = vec![unique_key];
tracing::debug!("dedup_add: executing atomic dedup check");
match self.eval_script(DEDUP_SCRIPT, keys, args).await {
Ok(Some(result)) => {
let added = result == "1";
tracing::debug!("dedup_add: result = {}", added);
Ok(added)
}
Ok(None) => {
tracing::warn!("dedup_add: unexpected null result");
Ok(false)
}
Err(e) => {
tracing::warn!("dedup_add: script execution failed: {}", e);
Err(e)
}
}
}
pub async fn move_expired_tasks_lua(
&self,
source: RedisKey,
dest: RedisKey,
now: i64,
batch_size: usize,
) -> Result<usize> {
const MOVE_EXPIRED_SCRIPT: &str = r#"
-- Atomic batch move expired tasks script
local source_key = KEYS[1]
local dest_key = KEYS[2]
local now = tonumber(ARGV[1])
local batch_size = tonumber(ARGV[2])
-- Get expired tasks (score <= now)
local tasks = redis.call('ZRANGEBYSCORE', source_key, '-inf', now, 'LIMIT', 0, batch_size)
if not tasks or #tasks == 0 then
return 0
end
-- Remove from source and add to destination
for _, task_id in ipairs(tasks) do
redis.call('ZREM', source_key, task_id)
redis.call('RPUSH', dest_key, task_id)
end
return #tasks
"#;
let keys = vec![source, dest];
let args = vec![
RedisValue::from(now.to_string()),
RedisValue::from(batch_size.to_string()),
];
tracing::debug!("move_expired_tasks_lua: executing batch move (now={}, batch={})", now, batch_size);
match self.eval_script(MOVE_EXPIRED_SCRIPT, keys, args).await {
Ok(Some(count)) => {
let moved = count.parse::<usize>().unwrap_or(0);
tracing::debug!("move_expired_tasks_lua: moved {} tasks", moved);
Ok(moved)
}
Ok(None) => {
tracing::debug!("move_expired_tasks_lua: no tasks moved");
Ok(0)
}
Err(e) => {
tracing::warn!("move_expired_tasks_lua: script execution failed: {}", e);
Err(e)
}
}
}
pub async fn pdequeue_lua(
&self,
pqueue: RedisKey,
active: RedisKey,
pause: RedisKey,
_task_prefix: RedisKey,
ttl: usize,
) -> Result<String> {
const PDEQUEUE_SCRIPT: &str = r#"
-- Atomic priority dequeue script
local pqueue_key = KEYS[1]
local active_key = KEYS[2]
local pause_key = KEYS[3]
local timeout = tonumber(ARGV[1])
local current_timestamp = tonumber(ARGV[2])
local task_ttl = tonumber(ARGV[3]) or 86400
-- Check if queue is paused
if redis.call('EXISTS', pause_key) == 1 then
return {err = 'ERR_QUEUE_PAUSED'}
end
-- Get task with highest priority (lowest score)
local results = redis.call('ZRANGE', pqueue_key, 0, 0)
if not results or #results == 0 then
return {err = 'ERR_TIMEOUT'}
end
local task_id = results[1]
-- Remove from priority queue (atomic with ZRANGE since in same script)
redis.call('ZREM', pqueue_key, task_id)
-- Move to active queue
redis.call('LPUSH', active_key, task_id)
-- Update task status
local task_key = 'rediq:task:' .. task_id
redis.call('HSET', task_key, 'status', 'active')
redis.call('HSET', task_key, 'processed_at', current_timestamp)
redis.call('EXPIRE', task_key, task_ttl)
return {ok = task_id}
"#;
let current_timestamp = chrono::Utc::now().timestamp();
let keys = vec![pqueue, active, pause];
let args = vec![
RedisValue::from("0"), RedisValue::from(current_timestamp.to_string()),
RedisValue::from(ttl.to_string()),
];
tracing::debug!("pdequeue_lua: executing atomic dequeue script");
match self.eval_script(PDEQUEUE_SCRIPT, keys, args).await {
Ok(Some(task_id)) => {
tracing::debug!("pdequeue_lua: successfully dequeued task {}", task_id);
Ok(task_id)
}
Ok(None) => {
tracing::debug!("pdequeue_lua: no tasks in priority queue");
Ok(String::new())
}
Err(Error::QueuePaused(_)) => {
tracing::debug!("pdequeue_lua: queue is paused");
Err(Error::QueuePaused("Queue paused".to_string()))
}
Err(e) => {
tracing::warn!("pdequeue_lua: script execution failed: {}", e);
Err(e)
}
}
}
pub async fn ttl(&self, key: &str) -> Result<Option<i64>> {
let key: RedisKey = key.into();
let result: Option<i64> = self.pool.ttl(key).await?;
match result {
Some(-2) => Ok(None), Some(-1) => Ok(None), Some(ttl) => Ok(Some(ttl)),
None => Ok(None),
}
}
pub async fn scan_match(&self, _cursor: u64, pattern: &str, count: u64) -> Result<(u64, Vec<String>)> {
use fred::types::Scanner;
use futures::StreamExt;
let client = self.pool.next();
let mut stream = client.scan(pattern, Some(count as u32), None);
match stream.next().await {
Some(Ok(scan_result)) => {
let has_more = scan_result.has_more();
let keys: Vec<String> = scan_result
.results()
.as_ref()
.map(|v| v.iter().filter_map(|k| k.as_str().map(|s| s.to_string())).collect())
.unwrap_or_default();
if has_more {
Ok((1, keys))
} else {
Ok((0, keys))
}
}
Some(Err(e)) => Err(Error::Redis(e)),
None => Ok((0, Vec::new())),
}
}
}
/// One command queued on a [`RedisPipeline`].
enum PipelineOp {
    Set(RedisKey, RedisValue),
    RPush(RedisKey, RedisValue),
    SAdd(RedisKey, RedisValue),
    Expire(RedisKey, u64),
}

/// Builder that queues simple write commands and runs them against a pool.
///
/// Commands are sent one at a time, in the order they were queued. This is
/// neither a `MULTI`/`EXEC` transaction nor a wire-level pipeline: if one
/// command fails, the earlier ones have already been applied and the later
/// ones are not sent.
pub struct RedisPipeline {
    pool: Arc<RedisPool>,
    ops: Vec<PipelineOp>,
}

impl RedisPipeline {
    /// Creates an empty pipeline that will run against `pool`.
    pub fn new(pool: Arc<RedisPool>) -> Self {
        Self {
            pool,
            ops: Vec::new(),
        }
    }

    /// Queues `SET key value` (no expiry or condition options).
    #[must_use]
    pub fn set(mut self, key: RedisKey, value: RedisValue) -> Self {
        self.ops.push(PipelineOp::Set(key, value));
        self
    }

    /// Queues `RPUSH key value`.
    #[must_use]
    pub fn rpush(mut self, key: RedisKey, value: RedisValue) -> Self {
        self.ops.push(PipelineOp::RPush(key, value));
        self
    }

    /// Queues `SADD key member`.
    #[must_use]
    pub fn sadd(mut self, key: RedisKey, member: RedisValue) -> Self {
        self.ops.push(PipelineOp::SAdd(key, member));
        self
    }

    /// Queues `EXPIRE key seconds`.
    #[must_use]
    pub fn expire(mut self, key: RedisKey, seconds: u64) -> Self {
        self.ops.push(PipelineOp::Expire(key, seconds));
        self
    }

    /// Sends the queued commands in insertion order.
    ///
    /// Returns one raw reply per command, in the same order.
    ///
    /// # Errors
    /// Stops at, and returns, the first failing command.
    pub async fn execute(self) -> Result<Vec<RedisValue>> {
        let mut replies = Vec::with_capacity(self.ops.len());
        for op in self.ops {
            let reply: RedisValue = match op {
                PipelineOp::Set(key, value) => {
                    self.pool.set(key, value, None, None, false).await?
                }
                PipelineOp::RPush(key, value) => self.pool.rpush(key, value).await?,
                PipelineOp::SAdd(key, member) => self.pool.sadd(key, member).await?,
                PipelineOp::Expire(key, seconds) => {
                    // Clamp instead of `as`, which would wrap to a negative TTL.
                    let seconds = seconds.min(i64::MAX as u64) as i64;
                    self.pool.expire(key, seconds).await?
                }
            };
            replies.push(reply);
        }
        Ok(replies)
    }
}

/// Alias kept for callers that refer to the pipeline as a builder.
pub type RedisPipelineBuilder = RedisPipeline;
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a PING against a live server at `REDIS_URL`, falling back to
    /// the local default when the variable is unset.
    #[tokio::test]
    #[ignore = "Requires Redis server"]
    async fn test_redis_ping() {
        let redis_url = match std::env::var("REDIS_URL") {
            Ok(url) => url,
            Err(_) => "redis://localhost:6379".to_string(),
        };
        let client = RedisClient::from_url(redis_url.as_str()).await.unwrap();
        let reply = client.ping().await.unwrap();
        assert_eq!(reply, "PONG");
    }
}