Skip to main content

storage/
redis_cache.rs

//! Redis L1.5 distributed cache layer.
//!
//! Sits between the Moka L1 in-process cache and the `VectorStorage` backend.
//! Provides cross-node cache sharing and pub/sub invalidation for HA deployments.
5
6use common::Vector;
7use futures_util::StreamExt;
8use redis::aio::ConnectionManager;
9use redis::AsyncCommands;
10use serde::{Deserialize, Serialize};
11
/// Error returned by fallible Redis cache operations.
///
/// Carries a human-readable description only; an underlying `redis` error, if
/// any, is rendered into the message rather than kept as a `source()`.
#[derive(Debug)]
pub struct RedisError(pub String);

impl std::fmt::Display for RedisError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Redis error: {}", self.0)
    }
}

// Implementing the standard error trait lets `RedisError` be boxed into
// `Box<dyn std::error::Error>` and propagated by error-trait-based wrappers.
impl std::error::Error for RedisError {}
21
22impl From<redis::RedisError> for RedisError {
23    fn from(e: redis::RedisError) -> Self {
24        RedisError(e.to_string())
25    }
26}
27
28/// Cache invalidation messages for pub/sub
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub enum CacheInvalidation {
31    /// Specific vectors in a namespace were changed
32    Vectors { namespace: String, ids: Vec<String> },
33    /// An entire namespace was invalidated
34    Namespace(String),
35    /// Full cache invalidation
36    All,
37}
38
/// Point-in-time Redis statistics, as assembled by `RedisCache::stats`.
///
/// The `Default` value (all zeros, `connected == false`) is what callers see
/// when the `INFO` command fails.
#[derive(Debug, Clone, Default)]
pub struct RedisCacheStats {
    /// `true` when the `INFO` command succeeded.
    pub connected: bool,
    /// The `used_memory` field of `INFO`, in bytes.
    pub used_memory_bytes: u64,
    /// Result of `DBSIZE` for the connection's database.
    pub total_keys: u64,
    /// The `keyspace_hits` field of `INFO`.
    pub hits: u64,
    /// The `keyspace_misses` field of `INFO`.
    pub misses: u64,
    /// `hits / (hits + misses)` expressed as a percentage (0.0 when both are 0).
    pub hit_rate: f64,
}
49
/// Leading segment of every cache key: keys look like `buf:{namespace}:{id}`.
const REDIS_KEY_PREFIX: &str = "buf";
/// Pub/sub channel carrying JSON-encoded `CacheInvalidation` messages.
const REDIS_PUBSUB_CHANNEL: &str = "buffer:cache:invalidate";
/// Default lifetime of a cached entry, in seconds (one hour).
const DEFAULT_TTL_SECS: u64 = 60 * 60;
53
54/// Redis L1.5 distributed cache
55#[derive(Clone)]
56pub struct RedisCache {
57    conn: ConnectionManager,
58    url: String,
59    default_ttl_secs: u64,
60}
61
62impl RedisCache {
63    /// Connect to Redis and create a new cache instance
64    pub async fn new(redis_url: &str) -> Result<Self, RedisError> {
65        let client = redis::Client::open(redis_url)
66            .map_err(|e| RedisError(format!("Failed to create Redis client: {}", e)))?;
67        let conn = ConnectionManager::new(client)
68            .await
69            .map_err(|e| RedisError(format!("Failed to connect to Redis: {}", e)))?;
70        Ok(Self {
71            conn,
72            url: redis_url.to_string(),
73            default_ttl_secs: DEFAULT_TTL_SECS,
74        })
75    }
76
77    /// Returns a clone of the underlying Redis connection manager.
78    ///
79    /// Cloning is cheap — `ConnectionManager` is backed by an `Arc`.
80    /// Used by the rate-limit middleware to share the connection pool.
81    pub fn connection(&self) -> ConnectionManager {
82        self.conn.clone()
83    }
84
85    /// Build a Redis key from namespace and vector ID
86    fn key(namespace: &str, id: &str) -> String {
87        format!("{}:{}:{}", REDIS_KEY_PREFIX, namespace, id)
88    }
89
90    /// Build a namespace scan pattern
91    fn namespace_pattern(namespace: &str) -> String {
92        format!("{}:{}:*", REDIS_KEY_PREFIX, namespace)
93    }
94
95    /// Get a single vector from Redis
96    pub async fn get(&self, namespace: &str, id: &str) -> Option<Vector> {
97        let key = Self::key(namespace, id);
98        let mut conn = self.conn.clone();
99        match conn.get::<_, Option<String>>(&key).await {
100            Ok(Some(json)) => {
101                metrics::counter!("buffer_redis_hits_total").increment(1);
102                match serde_json::from_str(&json) {
103                    Ok(v) => Some(v),
104                    Err(e) => {
105                        tracing::warn!(key = %key, error = %e, "Failed to deserialize vector from Redis");
106                        None
107                    }
108                }
109            }
110            Ok(None) => {
111                metrics::counter!("buffer_redis_misses_total").increment(1);
112                None
113            }
114            Err(e) => {
115                tracing::debug!(key = %key, error = %e, "Redis GET failed");
116                metrics::counter!("buffer_redis_misses_total").increment(1);
117                None
118            }
119        }
120    }
121
122    /// Get multiple vectors from Redis
123    pub async fn get_multi(&self, namespace: &str, ids: &[String]) -> Vec<Vector> {
124        if ids.is_empty() {
125            return Vec::new();
126        }
127        let keys: Vec<String> = ids.iter().map(|id| Self::key(namespace, id)).collect();
128        let mut conn = self.conn.clone();
129
130        // Use MGET for batch retrieval
131        let results: Result<Vec<Option<String>>, _> =
132            redis::cmd("MGET").arg(&keys).query_async(&mut conn).await;
133
134        match results {
135            Ok(values) => {
136                let mut vectors = Vec::new();
137                for (i, val) in values.into_iter().enumerate() {
138                    match val {
139                        Some(json) => {
140                            metrics::counter!("buffer_redis_hits_total").increment(1);
141                            match serde_json::from_str::<Vector>(&json) {
142                                Ok(v) => vectors.push(v),
143                                Err(e) => {
144                                    tracing::warn!(key = %keys[i], error = %e, "Failed to deserialize vector from Redis");
145                                }
146                            }
147                        }
148                        None => {
149                            metrics::counter!("buffer_redis_misses_total").increment(1);
150                        }
151                    }
152                }
153                vectors
154            }
155            Err(e) => {
156                tracing::debug!(error = %e, "Redis MGET failed");
157                metrics::counter!("buffer_redis_misses_total").increment(ids.len() as u64);
158                Vec::new()
159            }
160        }
161    }
162
163    /// Store a single vector in Redis with TTL
164    pub async fn set(&self, namespace: &str, vector: &Vector) {
165        let key = Self::key(namespace, &vector.id);
166        let json = match serde_json::to_string(vector) {
167            Ok(j) => j,
168            Err(e) => {
169                tracing::warn!(key = %key, error = %e, "Failed to serialize vector for Redis");
170                return;
171            }
172        };
173        let mut conn = self.conn.clone();
174        if let Err(e) = conn
175            .set_ex::<_, _, ()>(&key, &json, self.default_ttl_secs)
176            .await
177        {
178            tracing::debug!(key = %key, error = %e, "Redis SET failed");
179        }
180    }
181
182    /// Store multiple vectors in Redis using pipeline
183    pub async fn set_batch(&self, namespace: &str, vectors: &[Vector]) {
184        if vectors.is_empty() {
185            return;
186        }
187        let mut conn = self.conn.clone();
188        let mut pipe = redis::pipe();
189        for vector in vectors {
190            let key = Self::key(namespace, &vector.id);
191            let json = match serde_json::to_string(vector) {
192                Ok(j) => j,
193                Err(_) => continue,
194            };
195            pipe.cmd("SET")
196                .arg(&key)
197                .arg(&json)
198                .arg("EX")
199                .arg(self.default_ttl_secs)
200                .ignore();
201        }
202        if let Err(e) = pipe.query_async::<()>(&mut conn).await {
203            tracing::debug!(error = %e, count = vectors.len(), "Redis pipeline SET failed");
204        }
205    }
206
207    /// Delete specific vectors from Redis
208    pub async fn delete(&self, namespace: &str, ids: &[String]) {
209        if ids.is_empty() {
210            return;
211        }
212        let keys: Vec<String> = ids.iter().map(|id| Self::key(namespace, id)).collect();
213        let mut conn = self.conn.clone();
214        if let Err(e) = conn.del::<_, ()>(&keys).await {
215            tracing::debug!(error = %e, count = ids.len(), "Redis DEL failed");
216        }
217    }
218
219    /// Invalidate all entries for a namespace using SCAN + DEL
220    pub async fn invalidate_namespace(&self, namespace: &str) {
221        let pattern = Self::namespace_pattern(namespace);
222        let mut conn = self.conn.clone();
223        let mut cursor: u64 = 0;
224        let mut total_deleted = 0u64;
225
226        loop {
227            let result: Result<(u64, Vec<String>), _> = redis::cmd("SCAN")
228                .arg(cursor)
229                .arg("MATCH")
230                .arg(&pattern)
231                .arg("COUNT")
232                .arg(500)
233                .query_async(&mut conn)
234                .await;
235
236            match result {
237                Ok((next_cursor, keys)) => {
238                    if !keys.is_empty() {
239                        let _ = conn.del::<_, ()>(&keys).await;
240                        total_deleted += keys.len() as u64;
241                    }
242                    cursor = next_cursor;
243                    if cursor == 0 {
244                        break;
245                    }
246                }
247                Err(e) => {
248                    tracing::warn!(namespace, error = %e, "Redis SCAN+DEL failed during namespace invalidation");
249                    break;
250                }
251            }
252        }
253
254        if total_deleted > 0 {
255            tracing::debug!(
256                namespace,
257                deleted = total_deleted,
258                "Redis namespace invalidated"
259            );
260        }
261    }
262
263    /// Clear all buffer keys from Redis using SCAN + DEL
264    pub async fn clear_all(&self) {
265        let pattern = format!("{}:*", REDIS_KEY_PREFIX);
266        let mut conn = self.conn.clone();
267        let mut cursor: u64 = 0;
268
269        loop {
270            let result: Result<(u64, Vec<String>), _> = redis::cmd("SCAN")
271                .arg(cursor)
272                .arg("MATCH")
273                .arg(&pattern)
274                .arg("COUNT")
275                .arg(500)
276                .query_async(&mut conn)
277                .await;
278
279            match result {
280                Ok((next_cursor, keys)) => {
281                    if !keys.is_empty() {
282                        let _ = conn.del::<_, ()>(&keys).await;
283                    }
284                    cursor = next_cursor;
285                    if cursor == 0 {
286                        break;
287                    }
288                }
289                Err(e) => {
290                    tracing::warn!(error = %e, "Redis SCAN+DEL failed during full cache clear");
291                    break;
292                }
293            }
294        }
295
296        tracing::info!("Redis cache cleared");
297    }
298
299    /// Get Redis cache statistics from INFO command
300    pub async fn stats(&self) -> RedisCacheStats {
301        let mut conn = self.conn.clone();
302        let info: Result<String, _> = redis::cmd("INFO").query_async(&mut conn).await;
303
304        match info {
305            Ok(info_str) => {
306                let used_memory = Self::parse_info_field(&info_str, "used_memory")
307                    .and_then(|s| s.parse::<u64>().ok())
308                    .unwrap_or(0);
309                let hits = Self::parse_info_field(&info_str, "keyspace_hits")
310                    .and_then(|s| s.parse::<u64>().ok())
311                    .unwrap_or(0);
312                let misses = Self::parse_info_field(&info_str, "keyspace_misses")
313                    .and_then(|s| s.parse::<u64>().ok())
314                    .unwrap_or(0);
315
316                // Get key count via DBSIZE
317                let total_keys: u64 = redis::cmd("DBSIZE")
318                    .query_async(&mut conn)
319                    .await
320                    .unwrap_or(0);
321
322                let hit_rate = if hits + misses > 0 {
323                    hits as f64 / (hits + misses) as f64 * 100.0
324                } else {
325                    0.0
326                };
327
328                RedisCacheStats {
329                    connected: true,
330                    used_memory_bytes: used_memory,
331                    total_keys,
332                    hits,
333                    misses,
334                    hit_rate,
335                }
336            }
337            Err(e) => {
338                tracing::debug!(error = %e, "Redis INFO command failed");
339                RedisCacheStats {
340                    connected: false,
341                    ..Default::default()
342                }
343            }
344        }
345    }
346
347    /// Parse a field from Redis INFO output
348    fn parse_info_field<'a>(info: &'a str, field: &str) -> Option<&'a str> {
349        for line in info.lines() {
350            if let Some(value) = line.strip_prefix(&format!("{}:", field)) {
351                return Some(value.trim());
352            }
353        }
354        None
355    }
356
357    /// Publish a cache invalidation message via Redis pub/sub
358    pub async fn publish_invalidation(&self, msg: &CacheInvalidation) {
359        let json = match serde_json::to_string(msg) {
360            Ok(j) => j,
361            Err(e) => {
362                tracing::warn!(error = %e, "Failed to serialize cache invalidation message");
363                return;
364            }
365        };
366        let mut conn = self.conn.clone();
367        if let Err(e) = conn.publish::<_, _, ()>(REDIS_PUBSUB_CHANNEL, &json).await {
368            tracing::debug!(error = %e, "Redis PUBLISH failed for cache invalidation");
369        }
370    }
371
372    /// Publish a raw string message to any Redis channel.
373    /// Used for backup cache invalidation and other cross-node signaling.
374    pub async fn publish_raw(&self, channel: &str, message: &str) {
375        let mut conn = self.conn.clone();
376        if let Err(e) = conn.publish::<_, _, ()>(channel, message).await {
377            tracing::debug!(channel = %channel, error = %e, "Redis PUBLISH failed");
378        }
379    }
380
381    /// Subscribe to a Redis channel and return a receiver for raw string messages.
382    /// Spawns a background task that listens for messages on the channel.
383    pub async fn subscribe_raw(
384        &self,
385        channel: &str,
386    ) -> Result<tokio::sync::mpsc::Receiver<String>, RedisError> {
387        let client = redis::Client::open(self.url.as_str())
388            .map_err(|e| RedisError(format!("Failed to create Redis client for pub/sub: {}", e)))?;
389        let mut pubsub_conn = client
390            .get_async_pubsub()
391            .await
392            .map_err(|e| RedisError(format!("Failed to get Redis pub/sub connection: {}", e)))?;
393        pubsub_conn
394            .subscribe(channel)
395            .await
396            .map_err(|e| RedisError(format!("Failed to subscribe to {}: {}", channel, e)))?;
397
398        let (tx, rx) = tokio::sync::mpsc::channel(256);
399        let channel_name = channel.to_string();
400
401        tokio::spawn(async move {
402            let mut msg_stream = pubsub_conn.on_message();
403            while let Some(msg) = msg_stream.next().await {
404                let payload: String = match msg.get_payload() {
405                    Ok(p) => p,
406                    Err(e) => {
407                        tracing::debug!(error = %e, "Failed to get pub/sub message payload");
408                        continue;
409                    }
410                };
411                if tx.send(payload).await.is_err() {
412                    tracing::debug!(channel = %channel_name, "Pub/sub receiver dropped, stopping");
413                    break;
414                }
415            }
416            tracing::warn!(channel = %channel_name, "Redis pub/sub raw stream ended");
417        });
418
419        tracing::info!(channel = %channel, "Redis raw pub/sub subscription started");
420        Ok(rx)
421    }
422
423    /// Subscribe to cache invalidation messages.
424    /// This is a long-running async function that calls the handler for each message.
425    /// Should be spawned as a background task.
426    pub async fn subscribe_invalidations<F>(&self, mut handler: F)
427    where
428        F: FnMut(CacheInvalidation) + Send + 'static,
429    {
430        // Create a separate connection for pub/sub (can't reuse ConnectionManager)
431        let client = match redis::Client::open(self.url.as_str()) {
432            Ok(c) => c,
433            Err(e) => {
434                tracing::error!(error = %e, "Failed to create Redis client for pub/sub");
435                return;
436            }
437        };
438
439        let mut pubsub_conn = match client.get_async_pubsub().await {
440            Ok(c) => c,
441            Err(e) => {
442                tracing::error!(error = %e, "Failed to get Redis pub/sub connection");
443                return;
444            }
445        };
446
447        if let Err(e) = pubsub_conn.subscribe(REDIS_PUBSUB_CHANNEL).await {
448            tracing::error!(error = %e, "Failed to subscribe to Redis invalidation channel");
449            return;
450        }
451
452        tracing::info!("Redis pub/sub subscribed to {}", REDIS_PUBSUB_CHANNEL);
453
454        let mut msg_stream = pubsub_conn.on_message();
455        while let Some(msg) = msg_stream.next().await {
456            let payload: String = match msg.get_payload() {
457                Ok(p) => p,
458                Err(e) => {
459                    tracing::debug!(error = %e, "Failed to get pub/sub message payload");
460                    continue;
461                }
462            };
463            match serde_json::from_str::<CacheInvalidation>(&payload) {
464                Ok(invalidation) => handler(invalidation),
465                Err(e) => {
466                    tracing::debug!(error = %e, "Failed to deserialize invalidation message");
467                }
468            }
469        }
470
471        tracing::warn!("Redis pub/sub stream ended");
472    }
473
474    /// Try to acquire a distributed lock via SET NX EX.
475    ///
476    /// Returns `true` if the lock was acquired (this replica is the leader),
477    /// `false` if another replica already holds it.
478    ///
479    /// On Redis error (connection failure, timeout) returns `true` so callers
480    /// gracefully degrade to in-process mode — the operation runs on every
481    /// replica independently rather than not running at all.
482    /// Callers MUST call `release_lock` after their critical section completes.
483    pub async fn try_acquire_lock(&self, key: &str, owner: &str, ttl_secs: u64) -> bool {
484        let mut conn = self.conn.clone();
485        let result: Result<Option<String>, _> = redis::cmd("SET")
486            .arg(key)
487            .arg(owner)
488            .arg("EX")
489            .arg(ttl_secs)
490            .arg("NX")
491            .query_async(&mut conn)
492            .await;
493        match result {
494            Ok(Some(_)) => {
495                tracing::debug!(key = %key, owner = %owner, "Distributed lock acquired");
496                true
497            }
498            Ok(None) => false, // Another replica holds the lock
499            Err(e) => {
500                tracing::warn!(
501                    key = %key,
502                    error = %e,
503                    "Redis lock acquire failed — running as single-node fallback"
504                );
505                true // Graceful degradation: run anyway rather than skip the operation
506            }
507        }
508    }
509
510    /// Release a distributed lock, but only if this replica still owns it.
511    ///
512    /// Uses a Lua script for atomic check-and-delete.
513    pub async fn release_lock(&self, key: &str, owner: &str) {
514        let mut conn = self.conn.clone();
515        let script = redis::Script::new(
516            r#"if redis.call('get', KEYS[1]) == ARGV[1] then
517                 return redis.call('del', KEYS[1])
518               else
519                 return 0
520               end"#,
521        );
522        if let Err(e) = script
523            .key(key)
524            .arg(owner)
525            .invoke_async::<i64>(&mut conn)
526            .await
527        {
528            tracing::debug!(key = %key, error = %e, "Redis lock release failed (lock may have already expired)");
529        }
530    }
531}
532
533impl std::fmt::Debug for RedisCache {
534    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
535        f.debug_struct("RedisCache")
536            .field("url", &self.url)
537            .field("default_ttl_secs", &self.default_ttl_secs)
538            .finish()
539    }
540}