//! pyra-redis 0.3.0
//!
//! Shared Redis client, key builders, and common operations for Pyra services.
use std::collections::HashMap;

use deadpool_redis::Pool;
use redis::AsyncCommands;

use crate::error::RedisResult;

/// Shared Redis client wrapping a `deadpool_redis::Pool`.
///
/// Provides typed helpers for common operations used across Pyra services:
/// get/set with optional TTL, JSON serialization, MSET/MGET, SCAN, sets, hashes,
/// lists, streams, Lua scripting, and distributed locks (SET NX).
///
/// Each async method checks a connection out of the pool for the duration of
/// that single call; the connection guard is dropped (returned to the pool) when
/// the method returns. Clones of the client share the same underlying pool.
///
/// # Errors
///
/// Unless a method states otherwise, every async method returns an error when a
/// pooled connection cannot be obtained, when Redis rejects the command, or when
/// the reply cannot be converted into the requested Rust type.
#[derive(Clone)]
pub struct RedisClient {
    /// Pool every operation checks a connection out of.
    pool: Pool,
}

impl RedisClient {
    /// `COUNT` hint passed to `SCAN` by [`RedisClient::scan_keys`].
    const DEFAULT_SCAN_COUNT: usize = 1000;

    /// Wrap an existing connection pool.
    pub fn new(pool: Pool) -> Self {
        Self { pool }
    }

    /// Get a reference to the underlying pool.
    pub fn pool(&self) -> &Pool {
        &self.pool
    }

    // ── String operations ─────────────────────────────────────────────

    /// GET — fetch a string value; `Ok(None)` when the key does not exist.
    pub async fn get(&self, key: &str) -> RedisResult<Option<String>> {
        let mut conn = self.pool.get().await?;
        let val: Option<String> = conn.get(key).await?;
        Ok(val)
    }

    /// Store a string value, optionally with an expiry in seconds.
    ///
    /// `Some(ttl)` issues `SETEX key ttl value`. `None` issues a plain `SET`,
    /// which (per Redis semantics) also discards any TTL the key already had.
    ///
    /// # Errors
    ///
    /// Besides the common failures, Redis rejects `Some(0)` with an
    /// "invalid expire time" error.
    pub async fn set(&self, key: &str, value: &str, ttl_seconds: Option<u64>) -> RedisResult<()> {
        let mut conn = self.pool.get().await?;
        match ttl_seconds {
            Some(ttl) => {
                let _: () = conn.set_ex(key, value, ttl).await?;
            }
            None => {
                let _: () = conn.set(key, value).await?;
            }
        }
        Ok(())
    }

    /// SET key value NX EX ttl — returns `true` if the key was set (lock acquired).
    ///
    /// Returns `false` when the key already exists (Redis replies with nil).
    ///
    /// # Errors
    ///
    /// Besides the common failures, Redis rejects a `ttl_seconds` of 0 with an
    /// "invalid expire time" error.
    pub async fn set_nx(&self, key: &str, value: &str, ttl_seconds: u64) -> RedisResult<bool> {
        let mut conn = self.pool.get().await?;
        // Success replies `OK` (→ `Some(())`); a lost race replies nil (→ `None`).
        let result: Option<()> = redis::cmd("SET")
            .arg(key)
            .arg(value)
            .arg("NX")
            .arg("EX")
            .arg(ttl_seconds)
            .query_async(&mut *conn)
            .await?;
        Ok(result.is_some())
    }

    /// DEL — returns `true` if the key existed and was removed.
    pub async fn delete(&self, key: &str) -> RedisResult<bool> {
        let mut conn = self.pool.get().await?;
        let deleted: i64 = conn.del(key).await?;
        Ok(deleted > 0)
    }

    /// EXISTS — returns `true` if the key exists.
    pub async fn exists(&self, key: &str) -> RedisResult<bool> {
        let mut conn = self.pool.get().await?;
        let exists: bool = conn.exists(key).await?;
        Ok(exists)
    }

    /// EXPIRE — set a timeout in seconds; returns `true` if the timeout was set.
    ///
    /// Returns `false` when the key does not exist. Per Redis semantics, a
    /// non-positive `ttl_seconds` deletes the key immediately.
    pub async fn expire(&self, key: &str, ttl_seconds: i64) -> RedisResult<bool> {
        let mut conn = self.pool.get().await?;
        let set: bool = conn.expire(key, ttl_seconds).await?;
        Ok(set)
    }

    /// Increment an integer key by one and return the new value.
    ///
    /// A missing key is treated as 0 by Redis before incrementing.
    pub async fn increment(&self, key: &str) -> RedisResult<i64> {
        self.increment_by(key, 1).await
    }

    /// INCRBY — increment key by a specific amount and return the new value.
    pub async fn increment_by(&self, key: &str, amount: i64) -> RedisResult<i64> {
        let mut conn = self.pool.get().await?;
        let val: i64 = conn.incr(key, amount).await?;
        Ok(val)
    }

    /// DECRBY — decrement key by a specific amount and return the new value.
    pub async fn decrement_by(&self, key: &str, amount: i64) -> RedisResult<i64> {
        let mut conn = self.pool.get().await?;
        let val: i64 = conn.decr(key, amount).await?;
        Ok(val)
    }

    /// TTL — get remaining time-to-live in seconds. Returns -1 if no expiry, -2 if key doesn't exist.
    pub async fn ttl(&self, key: &str) -> RedisResult<i64> {
        let mut conn = self.pool.get().await?;
        let val: i64 = conn.ttl(key).await?;
        Ok(val)
    }

    // ── JSON helpers ──────────────────────────────────────────────────

    /// Serialize `value` to a JSON string and store it via [`RedisClient::set`].
    ///
    /// # Errors
    ///
    /// Also fails if `value` cannot be serialized by `serde_json`.
    pub async fn set_json<T: serde::Serialize>(
        &self,
        key: &str,
        value: &T,
        ttl_seconds: Option<u64>,
    ) -> RedisResult<()> {
        let json = serde_json::to_string(value)?;
        self.set(key, &json, ttl_seconds).await
    }

    /// Fetch a key via [`RedisClient::get`] and deserialize it from JSON.
    ///
    /// Returns `Ok(None)` when the key does not exist.
    ///
    /// # Errors
    ///
    /// Also fails if the stored string is not valid JSON for `T`.
    pub async fn get_json<T: serde::de::DeserializeOwned>(
        &self,
        key: &str,
    ) -> RedisResult<Option<T>> {
        match self.get(key).await? {
            Some(raw) => Ok(Some(serde_json::from_str(&raw)?)),
            None => Ok(None),
        }
    }

    // ── Bulk operations ───────────────────────────────────────────────

    /// MSET — set multiple key-value pairs in a single round-trip.
    ///
    /// An empty slice is a no-op and does not touch the pool.
    pub async fn set_multiple(&self, pairs: &[(String, String)]) -> RedisResult<()> {
        if pairs.is_empty() {
            return Ok(());
        }
        let mut conn = self.pool.get().await?;
        let _: () = conn.mset(pairs).await?;
        Ok(())
    }

    /// MGET — fetch multiple keys in a single round-trip.
    ///
    /// The result is positionally aligned with `keys`; missing keys yield `None`.
    /// An empty slice returns an empty vector without touching the pool.
    pub async fn mget(&self, keys: &[String]) -> RedisResult<Vec<Option<String>>> {
        if keys.is_empty() {
            return Ok(Vec::new());
        }
        let mut conn = self.pool.get().await?;
        let values: Vec<Option<String>> = conn.mget(keys).await?;
        Ok(values)
    }

    /// SCAN with a glob pattern. Returns deduplicated keys, sorted ascending.
    ///
    /// Uses a `COUNT` hint of 1000 per iteration; see
    /// [`RedisClient::scan_keys_with_count`] to choose a different batch size.
    pub async fn scan_keys(&self, pattern: &str) -> RedisResult<Vec<String>> {
        self.scan_keys_with_count(pattern, Self::DEFAULT_SCAN_COUNT)
            .await
    }

    /// SCAN with a glob pattern and an explicit `COUNT` hint per iteration.
    ///
    /// Iterates the full cursor cycle on a single pooled connection and returns
    /// the matching keys deduplicated and sorted ascending. `count` is a hint to
    /// Redis about work per call, not a limit on the result size. A `count` of 0
    /// is clamped to 1, because Redis rejects `COUNT 0` with a syntax error.
    pub async fn scan_keys_with_count(
        &self,
        pattern: &str,
        count: usize,
    ) -> RedisResult<Vec<String>> {
        let count = count.max(1);
        let mut conn = self.pool.get().await?;
        let mut keys = Vec::new();
        let mut cursor: u64 = 0;

        loop {
            let (next_cursor, batch): (u64, Vec<String>) = redis::cmd("SCAN")
                .arg(cursor)
                .arg("MATCH")
                .arg(pattern)
                .arg("COUNT")
                .arg(count)
                .query_async(&mut *conn)
                .await?;

            keys.extend(batch);
            cursor = next_cursor;
            // A returned cursor of 0 marks the end of a full iteration.
            if cursor == 0 {
                break;
            }
        }

        // SCAN can return duplicates during hash table resizes.
        keys.sort_unstable();
        keys.dedup();

        Ok(keys)
    }

    // ── Set operations ────────────────────────────────────────────────

    /// SADD — add members to a set; returns how many were newly added.
    ///
    /// An empty slice returns 0 without touching the pool.
    pub async fn set_add(&self, key: &str, members: &[String]) -> RedisResult<usize> {
        if members.is_empty() {
            return Ok(0);
        }
        let mut conn = self.pool.get().await?;
        let count: usize = conn.sadd(key, members).await?;
        Ok(count)
    }

    /// SMEMBERS — return every member of a set (empty if the key is missing).
    pub async fn set_members(&self, key: &str) -> RedisResult<Vec<String>> {
        let mut conn = self.pool.get().await?;
        let members: Vec<String> = conn.smembers(key).await?;
        Ok(members)
    }

    /// SISMEMBER — returns `true` if `member` belongs to the set at `key`.
    pub async fn set_is_member(&self, key: &str, member: &str) -> RedisResult<bool> {
        let mut conn = self.pool.get().await?;
        let is_member: bool = conn.sismember(key, member).await?;
        Ok(is_member)
    }

    // ── Hash operations ───────────────────────────────────────────────

    /// HSET — set a single field of a hash, creating the hash if needed.
    pub async fn hash_set(&self, key: &str, field: &str, value: &str) -> RedisResult<()> {
        let mut conn = self.pool.get().await?;
        let _: () = conn.hset(key, field, value).await?;
        Ok(())
    }

    /// HGET — fetch a single hash field; `Ok(None)` if the key or field is missing.
    pub async fn hash_get(&self, key: &str, field: &str) -> RedisResult<Option<String>> {
        let mut conn = self.pool.get().await?;
        let val: Option<String> = conn.hget(key, field).await?;
        Ok(val)
    }

    /// HGETALL — fetch every field/value pair of a hash (empty if the key is missing).
    pub async fn hash_get_all(&self, key: &str) -> RedisResult<HashMap<String, String>> {
        let mut conn = self.pool.get().await?;
        let map: HashMap<String, String> = conn.hgetall(key).await?;
        Ok(map)
    }

    // ── Stream operations ─────────────────────────────────────────────

    /// XADD with approximate MAXLEN trimming; returns the generated entry ID.
    ///
    /// `MAXLEN ~ max_len` lets Redis trim lazily, so the stream may briefly hold
    /// somewhat more than `max_len` entries. The entry ID is auto-generated (`*`).
    ///
    /// # Errors
    ///
    /// Besides the common failures, Redis rejects an empty `fields` slice
    /// (a stream entry needs at least one field/value pair).
    pub async fn xadd(
        &self,
        key: &str,
        max_len: usize,
        fields: &[(&str, &str)],
    ) -> RedisResult<String> {
        let mut conn = self.pool.get().await?;
        let mut cmd = redis::cmd("XADD");
        cmd.arg(key).arg("MAXLEN").arg("~").arg(max_len).arg("*");
        for &(field, value) in fields {
            cmd.arg(field).arg(value);
        }
        let id: String = cmd.query_async(&mut *conn).await?;
        Ok(id)
    }

    // ── List operations ───────────────────────────────────────────────

    /// RPUSH — append one or more values to a list; returns the new list length.
    ///
    /// An empty slice returns 0 without touching the pool. Note that 0 is then
    /// not the list's actual length.
    pub async fn list_push(&self, key: &str, values: &[String]) -> RedisResult<i64> {
        if values.is_empty() {
            return Ok(0);
        }
        let mut conn = self.pool.get().await?;
        let len: i64 = conn.rpush(key, values).await?;
        Ok(len)
    }

    /// LPOP — remove and return the head of a list; `Ok(None)` if the list is empty.
    ///
    /// Paired with [`RedisClient::list_push`] (which appends at the tail), this
    /// gives FIFO queue semantics.
    pub async fn list_pop(&self, key: &str) -> RedisResult<Option<String>> {
        let mut conn = self.pool.get().await?;
        let val: Option<String> = conn.lpop(key, None).await?;
        Ok(val)
    }

    /// LLEN — length of a list (0 if the key is missing).
    pub async fn list_length(&self, key: &str) -> RedisResult<i64> {
        let mut conn = self.pool.get().await?;
        let len: i64 = conn.llen(key).await?;
        Ok(len)
    }

    // ── Lua scripting ─────────────────────────────────────────────────

    /// Execute a Lua script via EVAL.
    ///
    /// `keys` are exposed to the script as `KEYS[1..]` and `args` as `ARGV[1..]`.
    /// The full script body is sent on every call (no EVALSHA caching).
    pub async fn eval<T: redis::FromRedisValue>(
        &self,
        script: &str,
        keys: &[&str],
        args: &[&str],
    ) -> RedisResult<T> {
        let mut conn = self.pool.get().await?;
        let result: T = redis::cmd("EVAL")
            .arg(script)
            .arg(keys.len())
            .arg(keys)
            .arg(args)
            .query_async(&mut *conn)
            .await?;
        Ok(result)
    }

    // ── Health ─────────────────────────────────────────────────────────

    /// PING — returns `true` if the server replies `PONG`.
    pub async fn ping(&self) -> RedisResult<bool> {
        let mut conn = self.pool.get().await?;
        let response: String = redis::cmd("PING").query_async(&mut *conn).await?;
        Ok(response == "PONG")
    }

    /// Health check — returns `true` only when [`RedisClient::ping`] gets `PONG`.
    ///
    /// Never errors: pool, connection, and protocol failures all map to `false`.
    /// It uses `PING` rather than reading a key, so the probe cannot be skewed by
    /// application data stored under a colliding key name (for example a
    /// non-string value causing a `WRONGTYPE` false negative).
    pub async fn health_check(&self) -> bool {
        matches!(self.ping().await, Ok(true))
    }

    // ── Pool monitoring ──────────────────────────────────────────────

    /// Total number of connections in the pool.
    pub fn pool_size(&self) -> usize {
        self.pool.status().size
    }

    /// Number of idle connections available for use.
    pub fn available_connections(&self) -> usize {
        self.pool.status().available
    }

    /// Number of tasks waiting for a connection.
    pub fn waiting_connections(&self) -> usize {
        self.pool.status().waiting
    }
}