Skip to main content

storage/
redis_cache.rs

1//! Redis L1.5 distributed cache layer
2//!
3//! Sits between the Moka L1 in-process cache and the VectorStorage backend.
4//! Provides cross-node cache sharing and pub/sub invalidation for HA deployments.
5
6use common::Vector;
7use futures_util::StreamExt;
8use redis::aio::ConnectionManager;
9use redis::AsyncCommands;
10use serde::{Deserialize, Serialize};
11
/// Error type for Redis cache operations.
///
/// Wraps a human-readable description of what failed (client creation, connection,
/// subscription, ...). It implements [`std::error::Error`] so it can be boxed as
/// `dyn Error` or converted by error-handling libraries.
#[derive(Debug)]
pub struct RedisError(pub String);

impl std::fmt::Display for RedisError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Redis error: {}", self.0)
    }
}

// The wrapped value is only a message, so there is no underlying `source()` to expose.
impl std::error::Error for RedisError {}
21
22impl From<redis::RedisError> for RedisError {
23    fn from(e: redis::RedisError) -> Self {
24        RedisError(e.to_string())
25    }
26}
27
28/// Cache invalidation messages for pub/sub
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub enum CacheInvalidation {
31    /// Specific vectors in a namespace were changed
32    Vectors { namespace: String, ids: Vec<String> },
33    /// An entire namespace was invalidated
34    Namespace(String),
35    /// Full cache invalidation
36    All,
37}
38
/// Snapshot of Redis server statistics, as reported by `RedisCache::stats`.
///
/// All fields are zero / `false` when the server could not be queried.
#[derive(Default, Clone, Debug)]
pub struct RedisCacheStats {
    /// Whether the `INFO` query succeeded.
    pub connected: bool,
    /// `used_memory` from `INFO`, in bytes.
    pub used_memory_bytes: u64,
    /// Key count of the selected database (`DBSIZE`), not just this cache's keys.
    pub total_keys: u64,
    /// Server-wide `keyspace_hits` counter.
    pub hits: u64,
    /// Server-wide `keyspace_misses` counter.
    pub misses: u64,
    /// Hit ratio expressed as a percentage in `0.0..=100.0`.
    pub hit_rate: f64,
}
49
/// Leading segment of every key this cache writes (`buf:<namespace>:<id>`).
const REDIS_KEY_PREFIX: &str = "buf";
/// Pub/sub channel carrying JSON-encoded `CacheInvalidation` messages between nodes.
const REDIS_PUBSUB_CHANNEL: &str = "buffer:cache:invalidate";
/// Expiry applied to cached entries by default: one hour, in seconds.
const DEFAULT_TTL_SECS: u64 = 60 * 60;
53
54/// Redis L1.5 distributed cache
55#[derive(Clone)]
56pub struct RedisCache {
57    conn: ConnectionManager,
58    url: String,
59    default_ttl_secs: u64,
60}
61
62impl RedisCache {
63    /// Connect to Redis and create a new cache instance
64    pub async fn new(redis_url: &str) -> Result<Self, RedisError> {
65        let client = redis::Client::open(redis_url)
66            .map_err(|e| RedisError(format!("Failed to create Redis client: {}", e)))?;
67        let conn = ConnectionManager::new(client)
68            .await
69            .map_err(|e| RedisError(format!("Failed to connect to Redis: {}", e)))?;
70        Ok(Self {
71            conn,
72            url: redis_url.to_string(),
73            default_ttl_secs: DEFAULT_TTL_SECS,
74        })
75    }
76
77    /// Build a Redis key from namespace and vector ID
78    fn key(namespace: &str, id: &str) -> String {
79        format!("{}:{}:{}", REDIS_KEY_PREFIX, namespace, id)
80    }
81
82    /// Build a namespace scan pattern
83    fn namespace_pattern(namespace: &str) -> String {
84        format!("{}:{}:*", REDIS_KEY_PREFIX, namespace)
85    }
86
87    /// Get a single vector from Redis
88    pub async fn get(&self, namespace: &str, id: &str) -> Option<Vector> {
89        let key = Self::key(namespace, id);
90        let mut conn = self.conn.clone();
91        match conn.get::<_, Option<String>>(&key).await {
92            Ok(Some(json)) => {
93                metrics::counter!("buffer_redis_hits_total").increment(1);
94                match serde_json::from_str(&json) {
95                    Ok(v) => Some(v),
96                    Err(e) => {
97                        tracing::warn!(key = %key, error = %e, "Failed to deserialize vector from Redis");
98                        None
99                    }
100                }
101            }
102            Ok(None) => {
103                metrics::counter!("buffer_redis_misses_total").increment(1);
104                None
105            }
106            Err(e) => {
107                tracing::debug!(key = %key, error = %e, "Redis GET failed");
108                metrics::counter!("buffer_redis_misses_total").increment(1);
109                None
110            }
111        }
112    }
113
114    /// Get multiple vectors from Redis
115    pub async fn get_multi(&self, namespace: &str, ids: &[String]) -> Vec<Vector> {
116        if ids.is_empty() {
117            return Vec::new();
118        }
119        let keys: Vec<String> = ids.iter().map(|id| Self::key(namespace, id)).collect();
120        let mut conn = self.conn.clone();
121
122        // Use MGET for batch retrieval
123        let results: Result<Vec<Option<String>>, _> =
124            redis::cmd("MGET").arg(&keys).query_async(&mut conn).await;
125
126        match results {
127            Ok(values) => {
128                let mut vectors = Vec::new();
129                for (i, val) in values.into_iter().enumerate() {
130                    match val {
131                        Some(json) => {
132                            metrics::counter!("buffer_redis_hits_total").increment(1);
133                            match serde_json::from_str::<Vector>(&json) {
134                                Ok(v) => vectors.push(v),
135                                Err(e) => {
136                                    tracing::warn!(key = %keys[i], error = %e, "Failed to deserialize vector from Redis");
137                                }
138                            }
139                        }
140                        None => {
141                            metrics::counter!("buffer_redis_misses_total").increment(1);
142                        }
143                    }
144                }
145                vectors
146            }
147            Err(e) => {
148                tracing::debug!(error = %e, "Redis MGET failed");
149                metrics::counter!("buffer_redis_misses_total").increment(ids.len() as u64);
150                Vec::new()
151            }
152        }
153    }
154
155    /// Store a single vector in Redis with TTL
156    pub async fn set(&self, namespace: &str, vector: &Vector) {
157        let key = Self::key(namespace, &vector.id);
158        let json = match serde_json::to_string(vector) {
159            Ok(j) => j,
160            Err(e) => {
161                tracing::warn!(key = %key, error = %e, "Failed to serialize vector for Redis");
162                return;
163            }
164        };
165        let mut conn = self.conn.clone();
166        if let Err(e) = conn
167            .set_ex::<_, _, ()>(&key, &json, self.default_ttl_secs)
168            .await
169        {
170            tracing::debug!(key = %key, error = %e, "Redis SET failed");
171        }
172    }
173
174    /// Store multiple vectors in Redis using pipeline
175    pub async fn set_batch(&self, namespace: &str, vectors: &[Vector]) {
176        if vectors.is_empty() {
177            return;
178        }
179        let mut conn = self.conn.clone();
180        let mut pipe = redis::pipe();
181        for vector in vectors {
182            let key = Self::key(namespace, &vector.id);
183            let json = match serde_json::to_string(vector) {
184                Ok(j) => j,
185                Err(_) => continue,
186            };
187            pipe.cmd("SET")
188                .arg(&key)
189                .arg(&json)
190                .arg("EX")
191                .arg(self.default_ttl_secs)
192                .ignore();
193        }
194        if let Err(e) = pipe.query_async::<()>(&mut conn).await {
195            tracing::debug!(error = %e, count = vectors.len(), "Redis pipeline SET failed");
196        }
197    }
198
199    /// Delete specific vectors from Redis
200    pub async fn delete(&self, namespace: &str, ids: &[String]) {
201        if ids.is_empty() {
202            return;
203        }
204        let keys: Vec<String> = ids.iter().map(|id| Self::key(namespace, id)).collect();
205        let mut conn = self.conn.clone();
206        if let Err(e) = conn.del::<_, ()>(&keys).await {
207            tracing::debug!(error = %e, count = ids.len(), "Redis DEL failed");
208        }
209    }
210
211    /// Invalidate all entries for a namespace using SCAN + DEL
212    pub async fn invalidate_namespace(&self, namespace: &str) {
213        let pattern = Self::namespace_pattern(namespace);
214        let mut conn = self.conn.clone();
215        let mut cursor: u64 = 0;
216        let mut total_deleted = 0u64;
217
218        loop {
219            let result: Result<(u64, Vec<String>), _> = redis::cmd("SCAN")
220                .arg(cursor)
221                .arg("MATCH")
222                .arg(&pattern)
223                .arg("COUNT")
224                .arg(500)
225                .query_async(&mut conn)
226                .await;
227
228            match result {
229                Ok((next_cursor, keys)) => {
230                    if !keys.is_empty() {
231                        let _ = conn.del::<_, ()>(&keys).await;
232                        total_deleted += keys.len() as u64;
233                    }
234                    cursor = next_cursor;
235                    if cursor == 0 {
236                        break;
237                    }
238                }
239                Err(e) => {
240                    tracing::warn!(namespace, error = %e, "Redis SCAN+DEL failed during namespace invalidation");
241                    break;
242                }
243            }
244        }
245
246        if total_deleted > 0 {
247            tracing::debug!(
248                namespace,
249                deleted = total_deleted,
250                "Redis namespace invalidated"
251            );
252        }
253    }
254
255    /// Clear all buffer keys from Redis using SCAN + DEL
256    pub async fn clear_all(&self) {
257        let pattern = format!("{}:*", REDIS_KEY_PREFIX);
258        let mut conn = self.conn.clone();
259        let mut cursor: u64 = 0;
260
261        loop {
262            let result: Result<(u64, Vec<String>), _> = redis::cmd("SCAN")
263                .arg(cursor)
264                .arg("MATCH")
265                .arg(&pattern)
266                .arg("COUNT")
267                .arg(500)
268                .query_async(&mut conn)
269                .await;
270
271            match result {
272                Ok((next_cursor, keys)) => {
273                    if !keys.is_empty() {
274                        let _ = conn.del::<_, ()>(&keys).await;
275                    }
276                    cursor = next_cursor;
277                    if cursor == 0 {
278                        break;
279                    }
280                }
281                Err(e) => {
282                    tracing::warn!(error = %e, "Redis SCAN+DEL failed during full cache clear");
283                    break;
284                }
285            }
286        }
287
288        tracing::info!("Redis cache cleared");
289    }
290
291    /// Get Redis cache statistics from INFO command
292    pub async fn stats(&self) -> RedisCacheStats {
293        let mut conn = self.conn.clone();
294        let info: Result<String, _> = redis::cmd("INFO").query_async(&mut conn).await;
295
296        match info {
297            Ok(info_str) => {
298                let used_memory = Self::parse_info_field(&info_str, "used_memory")
299                    .and_then(|s| s.parse::<u64>().ok())
300                    .unwrap_or(0);
301                let hits = Self::parse_info_field(&info_str, "keyspace_hits")
302                    .and_then(|s| s.parse::<u64>().ok())
303                    .unwrap_or(0);
304                let misses = Self::parse_info_field(&info_str, "keyspace_misses")
305                    .and_then(|s| s.parse::<u64>().ok())
306                    .unwrap_or(0);
307
308                // Get key count via DBSIZE
309                let total_keys: u64 = redis::cmd("DBSIZE")
310                    .query_async(&mut conn)
311                    .await
312                    .unwrap_or(0);
313
314                let hit_rate = if hits + misses > 0 {
315                    hits as f64 / (hits + misses) as f64 * 100.0
316                } else {
317                    0.0
318                };
319
320                RedisCacheStats {
321                    connected: true,
322                    used_memory_bytes: used_memory,
323                    total_keys,
324                    hits,
325                    misses,
326                    hit_rate,
327                }
328            }
329            Err(e) => {
330                tracing::debug!(error = %e, "Redis INFO command failed");
331                RedisCacheStats {
332                    connected: false,
333                    ..Default::default()
334                }
335            }
336        }
337    }
338
339    /// Parse a field from Redis INFO output
340    fn parse_info_field<'a>(info: &'a str, field: &str) -> Option<&'a str> {
341        for line in info.lines() {
342            if let Some(value) = line.strip_prefix(&format!("{}:", field)) {
343                return Some(value.trim());
344            }
345        }
346        None
347    }
348
349    /// Publish a cache invalidation message via Redis pub/sub
350    pub async fn publish_invalidation(&self, msg: &CacheInvalidation) {
351        let json = match serde_json::to_string(msg) {
352            Ok(j) => j,
353            Err(e) => {
354                tracing::warn!(error = %e, "Failed to serialize cache invalidation message");
355                return;
356            }
357        };
358        let mut conn = self.conn.clone();
359        if let Err(e) = conn.publish::<_, _, ()>(REDIS_PUBSUB_CHANNEL, &json).await {
360            tracing::debug!(error = %e, "Redis PUBLISH failed for cache invalidation");
361        }
362    }
363
364    /// Publish a raw string message to any Redis channel.
365    /// Used for backup cache invalidation and other cross-node signaling.
366    pub async fn publish_raw(&self, channel: &str, message: &str) {
367        let mut conn = self.conn.clone();
368        if let Err(e) = conn.publish::<_, _, ()>(channel, message).await {
369            tracing::debug!(channel = %channel, error = %e, "Redis PUBLISH failed");
370        }
371    }
372
373    /// Subscribe to a Redis channel and return a receiver for raw string messages.
374    /// Spawns a background task that listens for messages on the channel.
375    pub async fn subscribe_raw(
376        &self,
377        channel: &str,
378    ) -> Result<tokio::sync::mpsc::Receiver<String>, RedisError> {
379        let client = redis::Client::open(self.url.as_str())
380            .map_err(|e| RedisError(format!("Failed to create Redis client for pub/sub: {}", e)))?;
381        let mut pubsub_conn = client
382            .get_async_pubsub()
383            .await
384            .map_err(|e| RedisError(format!("Failed to get Redis pub/sub connection: {}", e)))?;
385        pubsub_conn
386            .subscribe(channel)
387            .await
388            .map_err(|e| RedisError(format!("Failed to subscribe to {}: {}", channel, e)))?;
389
390        let (tx, rx) = tokio::sync::mpsc::channel(256);
391        let channel_name = channel.to_string();
392
393        tokio::spawn(async move {
394            let mut msg_stream = pubsub_conn.on_message();
395            while let Some(msg) = msg_stream.next().await {
396                let payload: String = match msg.get_payload() {
397                    Ok(p) => p,
398                    Err(e) => {
399                        tracing::debug!(error = %e, "Failed to get pub/sub message payload");
400                        continue;
401                    }
402                };
403                if tx.send(payload).await.is_err() {
404                    tracing::debug!(channel = %channel_name, "Pub/sub receiver dropped, stopping");
405                    break;
406                }
407            }
408            tracing::warn!(channel = %channel_name, "Redis pub/sub raw stream ended");
409        });
410
411        tracing::info!(channel = %channel, "Redis raw pub/sub subscription started");
412        Ok(rx)
413    }
414
415    /// Subscribe to cache invalidation messages.
416    /// This is a long-running async function that calls the handler for each message.
417    /// Should be spawned as a background task.
418    pub async fn subscribe_invalidations<F>(&self, mut handler: F)
419    where
420        F: FnMut(CacheInvalidation) + Send + 'static,
421    {
422        // Create a separate connection for pub/sub (can't reuse ConnectionManager)
423        let client = match redis::Client::open(self.url.as_str()) {
424            Ok(c) => c,
425            Err(e) => {
426                tracing::error!(error = %e, "Failed to create Redis client for pub/sub");
427                return;
428            }
429        };
430
431        let mut pubsub_conn = match client.get_async_pubsub().await {
432            Ok(c) => c,
433            Err(e) => {
434                tracing::error!(error = %e, "Failed to get Redis pub/sub connection");
435                return;
436            }
437        };
438
439        if let Err(e) = pubsub_conn.subscribe(REDIS_PUBSUB_CHANNEL).await {
440            tracing::error!(error = %e, "Failed to subscribe to Redis invalidation channel");
441            return;
442        }
443
444        tracing::info!("Redis pub/sub subscribed to {}", REDIS_PUBSUB_CHANNEL);
445
446        let mut msg_stream = pubsub_conn.on_message();
447        while let Some(msg) = msg_stream.next().await {
448            let payload: String = match msg.get_payload() {
449                Ok(p) => p,
450                Err(e) => {
451                    tracing::debug!(error = %e, "Failed to get pub/sub message payload");
452                    continue;
453                }
454            };
455            match serde_json::from_str::<CacheInvalidation>(&payload) {
456                Ok(invalidation) => handler(invalidation),
457                Err(e) => {
458                    tracing::debug!(error = %e, "Failed to deserialize invalidation message");
459                }
460            }
461        }
462
463        tracing::warn!("Redis pub/sub stream ended");
464    }
465}
466
467impl std::fmt::Debug for RedisCache {
468    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
469        f.debug_struct("RedisCache")
470            .field("url", &self.url)
471            .field("default_ttl_secs", &self.default_ttl_secs)
472            .finish()
473    }
474}