Skip to main content

pyra_redis/
client.rs

1use std::collections::HashMap;
2
3use deadpool_redis::Pool;
4use redis::AsyncCommands;
5
6use crate::error::RedisResult;
7
8/// Shared Redis client wrapping a `deadpool_redis::Pool`.
9///
10/// Provides typed helpers for common operations used across Pyra services:
11/// get/set with optional TTL, JSON serialization, MGET, SCAN, sets, hashes,
12/// streams, and distributed locks (SET NX).
13#[derive(Clone)]
14pub struct RedisClient {
15    pool: Pool,
16}
17
18impl RedisClient {
19    pub fn new(pool: Pool) -> Self {
20        Self { pool }
21    }
22
23    /// Get a reference to the underlying pool.
24    pub fn pool(&self) -> &Pool {
25        &self.pool
26    }
27
28    // ── String operations ─────────────────────────────────────────────
29
30    pub async fn get(&self, key: &str) -> RedisResult<Option<String>> {
31        let mut conn = self.pool.get().await?;
32        let val: Option<String> = conn.get(key).await?;
33        Ok(val)
34    }
35
36    pub async fn set(&self, key: &str, value: &str, ttl_seconds: Option<u64>) -> RedisResult<()> {
37        let mut conn = self.pool.get().await?;
38        match ttl_seconds {
39            Some(ttl) => {
40                let _: () = conn.set_ex(key, value, ttl).await?;
41            }
42            None => {
43                let _: () = conn.set(key, value).await?;
44            }
45        }
46        Ok(())
47    }
48
49    /// SET key value NX EX ttl — returns `true` if the key was set (lock acquired).
50    pub async fn set_nx(&self, key: &str, value: &str, ttl_seconds: u64) -> RedisResult<bool> {
51        let mut conn = self.pool.get().await?;
52        let result: Option<()> = redis::cmd("SET")
53            .arg(key)
54            .arg(value)
55            .arg("NX")
56            .arg("EX")
57            .arg(ttl_seconds)
58            .query_async(&mut *conn)
59            .await?;
60        Ok(result.is_some())
61    }
62
63    pub async fn delete(&self, key: &str) -> RedisResult<bool> {
64        let mut conn = self.pool.get().await?;
65        let deleted: i64 = conn.del(key).await?;
66        Ok(deleted > 0)
67    }
68
69    pub async fn exists(&self, key: &str) -> RedisResult<bool> {
70        let mut conn = self.pool.get().await?;
71        let exists: bool = conn.exists(key).await?;
72        Ok(exists)
73    }
74
75    pub async fn expire(&self, key: &str, ttl_seconds: i64) -> RedisResult<bool> {
76        let mut conn = self.pool.get().await?;
77        let set: bool = conn.expire(key, ttl_seconds).await?;
78        Ok(set)
79    }
80
81    pub async fn increment(&self, key: &str) -> RedisResult<i64> {
82        let mut conn = self.pool.get().await?;
83        let val: i64 = conn.incr(key, 1i64).await?;
84        Ok(val)
85    }
86
87    /// INCRBY — increment key by a specific amount.
88    pub async fn increment_by(&self, key: &str, amount: i64) -> RedisResult<i64> {
89        let mut conn = self.pool.get().await?;
90        let val: i64 = conn.incr(key, amount).await?;
91        Ok(val)
92    }
93
94    /// DECRBY — decrement key by a specific amount.
95    pub async fn decrement_by(&self, key: &str, amount: i64) -> RedisResult<i64> {
96        let mut conn = self.pool.get().await?;
97        let val: i64 = conn.decr(key, amount).await?;
98        Ok(val)
99    }
100
101    /// TTL — get remaining time-to-live in seconds. Returns -1 if no expiry, -2 if key doesn't exist.
102    pub async fn ttl(&self, key: &str) -> RedisResult<i64> {
103        let mut conn = self.pool.get().await?;
104        let val: i64 = conn.ttl(key).await?;
105        Ok(val)
106    }
107
108    // ── JSON helpers ──────────────────────────────────────────────────
109
110    pub async fn set_json<T: serde::Serialize>(
111        &self,
112        key: &str,
113        value: &T,
114        ttl_seconds: Option<u64>,
115    ) -> RedisResult<()> {
116        let json = serde_json::to_string(value)?;
117        self.set(key, &json, ttl_seconds).await
118    }
119
120    pub async fn get_json<T: serde::de::DeserializeOwned>(
121        &self,
122        key: &str,
123    ) -> RedisResult<Option<T>> {
124        match self.get(key).await? {
125            Some(raw) => Ok(Some(serde_json::from_str(&raw)?)),
126            None => Ok(None),
127        }
128    }
129
130    // ── Bulk operations ───────────────────────────────────────────────
131
132    /// MSET — set multiple key-value pairs in a single round-trip.
133    pub async fn set_multiple(&self, pairs: &[(String, String)]) -> RedisResult<()> {
134        if pairs.is_empty() {
135            return Ok(());
136        }
137        let mut conn = self.pool.get().await?;
138        let _: () = conn.mset(pairs).await?;
139        Ok(())
140    }
141
142    /// MGET — fetch multiple keys in a single round-trip.
143    pub async fn mget(&self, keys: &[String]) -> RedisResult<Vec<Option<String>>> {
144        if keys.is_empty() {
145            return Ok(Vec::new());
146        }
147        let mut conn = self.pool.get().await?;
148        let values: Vec<Option<String>> = conn.mget(keys).await?;
149        Ok(values)
150    }
151
152    /// SCAN with a glob pattern. Returns deduplicated keys.
153    pub async fn scan_keys(&self, pattern: &str) -> RedisResult<Vec<String>> {
154        let mut conn = self.pool.get().await?;
155        let mut keys = Vec::new();
156        let mut cursor: u64 = 0;
157
158        loop {
159            let (next_cursor, batch): (u64, Vec<String>) = redis::cmd("SCAN")
160                .arg(cursor)
161                .arg("MATCH")
162                .arg(pattern)
163                .arg("COUNT")
164                .arg(1000)
165                .query_async(&mut *conn)
166                .await?;
167
168            keys.extend(batch);
169            cursor = next_cursor;
170            if cursor == 0 {
171                break;
172            }
173        }
174
175        // SCAN can return duplicates during hash table resizes.
176        keys.sort_unstable();
177        keys.dedup();
178
179        Ok(keys)
180    }
181
182    // ── Set operations ────────────────────────────────────────────────
183
184    pub async fn set_add(&self, key: &str, members: &[String]) -> RedisResult<usize> {
185        if members.is_empty() {
186            return Ok(0);
187        }
188        let mut conn = self.pool.get().await?;
189        let count: usize = conn.sadd(key, members).await?;
190        Ok(count)
191    }
192
193    pub async fn set_members(&self, key: &str) -> RedisResult<Vec<String>> {
194        let mut conn = self.pool.get().await?;
195        let members: Vec<String> = conn.smembers(key).await?;
196        Ok(members)
197    }
198
199    pub async fn set_is_member(&self, key: &str, member: &str) -> RedisResult<bool> {
200        let mut conn = self.pool.get().await?;
201        let is_member: bool = conn.sismember(key, member).await?;
202        Ok(is_member)
203    }
204
205    // ── Hash operations ───────────────────────────────────────────────
206
207    pub async fn hash_set(&self, key: &str, field: &str, value: &str) -> RedisResult<()> {
208        let mut conn = self.pool.get().await?;
209        let _: () = conn.hset(key, field, value).await?;
210        Ok(())
211    }
212
213    pub async fn hash_get(&self, key: &str, field: &str) -> RedisResult<Option<String>> {
214        let mut conn = self.pool.get().await?;
215        let val: Option<String> = conn.hget(key, field).await?;
216        Ok(val)
217    }
218
219    pub async fn hash_get_all(&self, key: &str) -> RedisResult<HashMap<String, String>> {
220        let mut conn = self.pool.get().await?;
221        let map: HashMap<String, String> = conn.hgetall(key).await?;
222        Ok(map)
223    }
224
225    // ── Stream operations ─────────────────────────────────────────────
226
227    /// XADD with approximate MAXLEN trimming.
228    pub async fn xadd(
229        &self,
230        key: &str,
231        max_len: usize,
232        fields: &[(&str, &str)],
233    ) -> RedisResult<String> {
234        let mut conn = self.pool.get().await?;
235        let mut cmd = redis::cmd("XADD");
236        cmd.arg(key).arg("MAXLEN").arg("~").arg(max_len).arg("*");
237        for &(field, value) in fields {
238            cmd.arg(field).arg(value);
239        }
240        let id: String = cmd.query_async(&mut *conn).await?;
241        Ok(id)
242    }
243
244    // ── List operations ───────────────────────────────────────────────
245
246    /// RPUSH — append one or more values to a list.
247    pub async fn list_push(&self, key: &str, values: &[String]) -> RedisResult<i64> {
248        if values.is_empty() {
249            return Ok(0);
250        }
251        let mut conn = self.pool.get().await?;
252        let len: i64 = conn.rpush(key, values).await?;
253        Ok(len)
254    }
255
256    pub async fn list_pop(&self, key: &str) -> RedisResult<Option<String>> {
257        let mut conn = self.pool.get().await?;
258        let val: Option<String> = conn.lpop(key, None).await?;
259        Ok(val)
260    }
261
262    pub async fn list_length(&self, key: &str) -> RedisResult<i64> {
263        let mut conn = self.pool.get().await?;
264        let len: i64 = conn.llen(key).await?;
265        Ok(len)
266    }
267
268    // ── Lua scripting ─────────────────────────────────────────────────
269
270    /// Execute a Lua script via EVAL.
271    pub async fn eval<T: redis::FromRedisValue>(
272        &self,
273        script: &str,
274        keys: &[&str],
275        args: &[&str],
276    ) -> RedisResult<T> {
277        let mut conn = self.pool.get().await?;
278        let result: T = redis::cmd("EVAL")
279            .arg(script)
280            .arg(keys.len())
281            .arg(keys)
282            .arg(args)
283            .query_async(&mut *conn)
284            .await?;
285        Ok(result)
286    }
287
288    // ── Health ─────────────────────────────────────────────────────────
289
290    pub async fn ping(&self) -> RedisResult<bool> {
291        let mut conn = self.pool.get().await?;
292        let response: String = redis::cmd("PING").query_async(&mut *conn).await?;
293        Ok(response == "PONG")
294    }
295
296    /// Health check — attempts a GET and returns false on error.
297    pub async fn health_check(&self) -> bool {
298        self.get("health_check").await.is_ok()
299    }
300
301    // ── Pool monitoring ──────────────────────────────────────────────
302
303    /// Total number of connections in the pool.
304    pub fn pool_size(&self) -> usize {
305        self.pool.status().size
306    }
307
308    /// Number of idle connections available for use.
309    pub fn available_connections(&self) -> usize {
310        self.pool.status().available
311    }
312
313    /// Number of tasks waiting for a connection.
314    pub fn waiting_connections(&self) -> usize {
315        self.pool.status().waiting
316    }
317}