1use common::Vector;
7use futures_util::StreamExt;
8use redis::aio::ConnectionManager;
9use redis::AsyncCommands;
10use serde::{Deserialize, Serialize};
11
/// Error returned by the fallible [`RedisCache`] operations (construction and
/// pub/sub subscription), carrying a human-readable description of the failure.
#[derive(Debug)]
pub struct RedisError(pub String);

impl std::fmt::Display for RedisError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Redis error: {}", self.0)
    }
}

// Implementing `Error` (on top of `Debug` + `Display`) lets this type be
// propagated with `?` into `Box<dyn std::error::Error>` and used with
// error-reporting crates. It wraps only a message, so there is no `source()`.
impl std::error::Error for RedisError {}
21
22impl From<redis::RedisError> for RedisError {
23 fn from(e: redis::RedisError) -> Self {
24 RedisError(e.to_string())
25 }
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub enum CacheInvalidation {
31 Vectors { namespace: String, ids: Vec<String> },
33 Namespace(String),
35 All,
37}
38
/// Snapshot of Redis server statistics produced by `RedisCache::stats`.
///
/// The `Default` value (all zeros, `connected == false`) is what `stats`
/// reports when the `INFO` command fails.
#[derive(Clone, Debug, Default)]
pub struct RedisCacheStats {
    /// `true` when the `INFO` command succeeded.
    pub connected: bool,
    /// `used_memory` from `INFO`, in bytes.
    pub used_memory_bytes: u64,
    /// `DBSIZE` result: all keys in the selected database, not only cache keys.
    pub total_keys: u64,
    /// Server-wide `keyspace_hits` counter from `INFO`.
    pub hits: u64,
    /// Server-wide `keyspace_misses` counter from `INFO`.
    pub misses: u64,
    /// `hits / (hits + misses)` expressed as a percentage (0–100); `0.0` when
    /// there have been no lookups.
    pub hit_rate: f64,
}
49
50const REDIS_KEY_PREFIX: &str = "buf";
51const REDIS_PUBSUB_CHANNEL: &str = "buffer:cache:invalidate";
52const DEFAULT_TTL_SECS: u64 = 3600; #[derive(Clone)]
56pub struct RedisCache {
57 conn: ConnectionManager,
58 url: String,
59 default_ttl_secs: u64,
60}
61
62impl RedisCache {
63 pub async fn new(redis_url: &str) -> Result<Self, RedisError> {
65 let client = redis::Client::open(redis_url)
66 .map_err(|e| RedisError(format!("Failed to create Redis client: {}", e)))?;
67 let conn = ConnectionManager::new(client)
68 .await
69 .map_err(|e| RedisError(format!("Failed to connect to Redis: {}", e)))?;
70 Ok(Self {
71 conn,
72 url: redis_url.to_string(),
73 default_ttl_secs: DEFAULT_TTL_SECS,
74 })
75 }
76
77 pub fn connection(&self) -> ConnectionManager {
82 self.conn.clone()
83 }
84
85 fn key(namespace: &str, id: &str) -> String {
87 format!("{}:{}:{}", REDIS_KEY_PREFIX, namespace, id)
88 }
89
90 fn namespace_pattern(namespace: &str) -> String {
92 format!("{}:{}:*", REDIS_KEY_PREFIX, namespace)
93 }
94
95 pub async fn get(&self, namespace: &str, id: &str) -> Option<Vector> {
97 let key = Self::key(namespace, id);
98 let mut conn = self.conn.clone();
99 match conn.get::<_, Option<String>>(&key).await {
100 Ok(Some(json)) => {
101 metrics::counter!("buffer_redis_hits_total").increment(1);
102 match serde_json::from_str(&json) {
103 Ok(v) => Some(v),
104 Err(e) => {
105 tracing::warn!(key = %key, error = %e, "Failed to deserialize vector from Redis");
106 None
107 }
108 }
109 }
110 Ok(None) => {
111 metrics::counter!("buffer_redis_misses_total").increment(1);
112 None
113 }
114 Err(e) => {
115 tracing::debug!(key = %key, error = %e, "Redis GET failed");
116 metrics::counter!("buffer_redis_misses_total").increment(1);
117 None
118 }
119 }
120 }
121
122 pub async fn get_multi(&self, namespace: &str, ids: &[String]) -> Vec<Vector> {
124 if ids.is_empty() {
125 return Vec::new();
126 }
127 let keys: Vec<String> = ids.iter().map(|id| Self::key(namespace, id)).collect();
128 let mut conn = self.conn.clone();
129
130 let results: Result<Vec<Option<String>>, _> =
132 redis::cmd("MGET").arg(&keys).query_async(&mut conn).await;
133
134 match results {
135 Ok(values) => {
136 let mut vectors = Vec::new();
137 for (i, val) in values.into_iter().enumerate() {
138 match val {
139 Some(json) => {
140 metrics::counter!("buffer_redis_hits_total").increment(1);
141 match serde_json::from_str::<Vector>(&json) {
142 Ok(v) => vectors.push(v),
143 Err(e) => {
144 tracing::warn!(key = %keys[i], error = %e, "Failed to deserialize vector from Redis");
145 }
146 }
147 }
148 None => {
149 metrics::counter!("buffer_redis_misses_total").increment(1);
150 }
151 }
152 }
153 vectors
154 }
155 Err(e) => {
156 tracing::debug!(error = %e, "Redis MGET failed");
157 metrics::counter!("buffer_redis_misses_total").increment(ids.len() as u64);
158 Vec::new()
159 }
160 }
161 }
162
163 pub async fn set(&self, namespace: &str, vector: &Vector) {
165 let key = Self::key(namespace, &vector.id);
166 let json = match serde_json::to_string(vector) {
167 Ok(j) => j,
168 Err(e) => {
169 tracing::warn!(key = %key, error = %e, "Failed to serialize vector for Redis");
170 return;
171 }
172 };
173 let mut conn = self.conn.clone();
174 if let Err(e) = conn
175 .set_ex::<_, _, ()>(&key, &json, self.default_ttl_secs)
176 .await
177 {
178 tracing::debug!(key = %key, error = %e, "Redis SET failed");
179 }
180 }
181
182 pub async fn set_batch(&self, namespace: &str, vectors: &[Vector]) {
184 if vectors.is_empty() {
185 return;
186 }
187 let mut conn = self.conn.clone();
188 let mut pipe = redis::pipe();
189 for vector in vectors {
190 let key = Self::key(namespace, &vector.id);
191 let json = match serde_json::to_string(vector) {
192 Ok(j) => j,
193 Err(_) => continue,
194 };
195 pipe.cmd("SET")
196 .arg(&key)
197 .arg(&json)
198 .arg("EX")
199 .arg(self.default_ttl_secs)
200 .ignore();
201 }
202 if let Err(e) = pipe.query_async::<()>(&mut conn).await {
203 tracing::debug!(error = %e, count = vectors.len(), "Redis pipeline SET failed");
204 }
205 }
206
207 pub async fn delete(&self, namespace: &str, ids: &[String]) {
209 if ids.is_empty() {
210 return;
211 }
212 let keys: Vec<String> = ids.iter().map(|id| Self::key(namespace, id)).collect();
213 let mut conn = self.conn.clone();
214 if let Err(e) = conn.del::<_, ()>(&keys).await {
215 tracing::debug!(error = %e, count = ids.len(), "Redis DEL failed");
216 }
217 }
218
219 pub async fn invalidate_namespace(&self, namespace: &str) {
221 let pattern = Self::namespace_pattern(namespace);
222 let mut conn = self.conn.clone();
223 let mut cursor: u64 = 0;
224 let mut total_deleted = 0u64;
225
226 loop {
227 let result: Result<(u64, Vec<String>), _> = redis::cmd("SCAN")
228 .arg(cursor)
229 .arg("MATCH")
230 .arg(&pattern)
231 .arg("COUNT")
232 .arg(500)
233 .query_async(&mut conn)
234 .await;
235
236 match result {
237 Ok((next_cursor, keys)) => {
238 if !keys.is_empty() {
239 let _ = conn.del::<_, ()>(&keys).await;
240 total_deleted += keys.len() as u64;
241 }
242 cursor = next_cursor;
243 if cursor == 0 {
244 break;
245 }
246 }
247 Err(e) => {
248 tracing::warn!(namespace, error = %e, "Redis SCAN+DEL failed during namespace invalidation");
249 break;
250 }
251 }
252 }
253
254 if total_deleted > 0 {
255 tracing::debug!(
256 namespace,
257 deleted = total_deleted,
258 "Redis namespace invalidated"
259 );
260 }
261 }
262
263 pub async fn clear_all(&self) {
265 let pattern = format!("{}:*", REDIS_KEY_PREFIX);
266 let mut conn = self.conn.clone();
267 let mut cursor: u64 = 0;
268
269 loop {
270 let result: Result<(u64, Vec<String>), _> = redis::cmd("SCAN")
271 .arg(cursor)
272 .arg("MATCH")
273 .arg(&pattern)
274 .arg("COUNT")
275 .arg(500)
276 .query_async(&mut conn)
277 .await;
278
279 match result {
280 Ok((next_cursor, keys)) => {
281 if !keys.is_empty() {
282 let _ = conn.del::<_, ()>(&keys).await;
283 }
284 cursor = next_cursor;
285 if cursor == 0 {
286 break;
287 }
288 }
289 Err(e) => {
290 tracing::warn!(error = %e, "Redis SCAN+DEL failed during full cache clear");
291 break;
292 }
293 }
294 }
295
296 tracing::info!("Redis cache cleared");
297 }
298
299 pub async fn stats(&self) -> RedisCacheStats {
301 let mut conn = self.conn.clone();
302 let info: Result<String, _> = redis::cmd("INFO").query_async(&mut conn).await;
303
304 match info {
305 Ok(info_str) => {
306 let used_memory = Self::parse_info_field(&info_str, "used_memory")
307 .and_then(|s| s.parse::<u64>().ok())
308 .unwrap_or(0);
309 let hits = Self::parse_info_field(&info_str, "keyspace_hits")
310 .and_then(|s| s.parse::<u64>().ok())
311 .unwrap_or(0);
312 let misses = Self::parse_info_field(&info_str, "keyspace_misses")
313 .and_then(|s| s.parse::<u64>().ok())
314 .unwrap_or(0);
315
316 let total_keys: u64 = redis::cmd("DBSIZE")
318 .query_async(&mut conn)
319 .await
320 .unwrap_or(0);
321
322 let hit_rate = if hits + misses > 0 {
323 hits as f64 / (hits + misses) as f64 * 100.0
324 } else {
325 0.0
326 };
327
328 RedisCacheStats {
329 connected: true,
330 used_memory_bytes: used_memory,
331 total_keys,
332 hits,
333 misses,
334 hit_rate,
335 }
336 }
337 Err(e) => {
338 tracing::debug!(error = %e, "Redis INFO command failed");
339 RedisCacheStats {
340 connected: false,
341 ..Default::default()
342 }
343 }
344 }
345 }
346
347 fn parse_info_field<'a>(info: &'a str, field: &str) -> Option<&'a str> {
349 for line in info.lines() {
350 if let Some(value) = line.strip_prefix(&format!("{}:", field)) {
351 return Some(value.trim());
352 }
353 }
354 None
355 }
356
357 pub async fn publish_invalidation(&self, msg: &CacheInvalidation) {
359 let json = match serde_json::to_string(msg) {
360 Ok(j) => j,
361 Err(e) => {
362 tracing::warn!(error = %e, "Failed to serialize cache invalidation message");
363 return;
364 }
365 };
366 let mut conn = self.conn.clone();
367 if let Err(e) = conn.publish::<_, _, ()>(REDIS_PUBSUB_CHANNEL, &json).await {
368 tracing::debug!(error = %e, "Redis PUBLISH failed for cache invalidation");
369 }
370 }
371
372 pub async fn publish_raw(&self, channel: &str, message: &str) {
375 let mut conn = self.conn.clone();
376 if let Err(e) = conn.publish::<_, _, ()>(channel, message).await {
377 tracing::debug!(channel = %channel, error = %e, "Redis PUBLISH failed");
378 }
379 }
380
381 pub async fn subscribe_raw(
384 &self,
385 channel: &str,
386 ) -> Result<tokio::sync::mpsc::Receiver<String>, RedisError> {
387 let client = redis::Client::open(self.url.as_str())
388 .map_err(|e| RedisError(format!("Failed to create Redis client for pub/sub: {}", e)))?;
389 let mut pubsub_conn = client
390 .get_async_pubsub()
391 .await
392 .map_err(|e| RedisError(format!("Failed to get Redis pub/sub connection: {}", e)))?;
393 pubsub_conn
394 .subscribe(channel)
395 .await
396 .map_err(|e| RedisError(format!("Failed to subscribe to {}: {}", channel, e)))?;
397
398 let (tx, rx) = tokio::sync::mpsc::channel(256);
399 let channel_name = channel.to_string();
400
401 tokio::spawn(async move {
402 let mut msg_stream = pubsub_conn.on_message();
403 while let Some(msg) = msg_stream.next().await {
404 let payload: String = match msg.get_payload() {
405 Ok(p) => p,
406 Err(e) => {
407 tracing::debug!(error = %e, "Failed to get pub/sub message payload");
408 continue;
409 }
410 };
411 if tx.send(payload).await.is_err() {
412 tracing::debug!(channel = %channel_name, "Pub/sub receiver dropped, stopping");
413 break;
414 }
415 }
416 tracing::warn!(channel = %channel_name, "Redis pub/sub raw stream ended");
417 });
418
419 tracing::info!(channel = %channel, "Redis raw pub/sub subscription started");
420 Ok(rx)
421 }
422
423 pub async fn subscribe_invalidations<F>(&self, mut handler: F)
427 where
428 F: FnMut(CacheInvalidation) + Send + 'static,
429 {
430 let client = match redis::Client::open(self.url.as_str()) {
432 Ok(c) => c,
433 Err(e) => {
434 tracing::error!(error = %e, "Failed to create Redis client for pub/sub");
435 return;
436 }
437 };
438
439 let mut pubsub_conn = match client.get_async_pubsub().await {
440 Ok(c) => c,
441 Err(e) => {
442 tracing::error!(error = %e, "Failed to get Redis pub/sub connection");
443 return;
444 }
445 };
446
447 if let Err(e) = pubsub_conn.subscribe(REDIS_PUBSUB_CHANNEL).await {
448 tracing::error!(error = %e, "Failed to subscribe to Redis invalidation channel");
449 return;
450 }
451
452 tracing::info!("Redis pub/sub subscribed to {}", REDIS_PUBSUB_CHANNEL);
453
454 let mut msg_stream = pubsub_conn.on_message();
455 while let Some(msg) = msg_stream.next().await {
456 let payload: String = match msg.get_payload() {
457 Ok(p) => p,
458 Err(e) => {
459 tracing::debug!(error = %e, "Failed to get pub/sub message payload");
460 continue;
461 }
462 };
463 match serde_json::from_str::<CacheInvalidation>(&payload) {
464 Ok(invalidation) => handler(invalidation),
465 Err(e) => {
466 tracing::debug!(error = %e, "Failed to deserialize invalidation message");
467 }
468 }
469 }
470
471 tracing::warn!("Redis pub/sub stream ended");
472 }
473
474 pub async fn try_acquire_lock(&self, key: &str, owner: &str, ttl_secs: u64) -> bool {
484 let mut conn = self.conn.clone();
485 let result: Result<Option<String>, _> = redis::cmd("SET")
486 .arg(key)
487 .arg(owner)
488 .arg("EX")
489 .arg(ttl_secs)
490 .arg("NX")
491 .query_async(&mut conn)
492 .await;
493 match result {
494 Ok(Some(_)) => {
495 tracing::debug!(key = %key, owner = %owner, "Distributed lock acquired");
496 true
497 }
498 Ok(None) => false, Err(e) => {
500 tracing::warn!(
501 key = %key,
502 error = %e,
503 "Redis lock acquire failed — running as single-node fallback"
504 );
505 true }
507 }
508 }
509
510 pub async fn release_lock(&self, key: &str, owner: &str) {
514 let mut conn = self.conn.clone();
515 let script = redis::Script::new(
516 r#"if redis.call('get', KEYS[1]) == ARGV[1] then
517 return redis.call('del', KEYS[1])
518 else
519 return 0
520 end"#,
521 );
522 if let Err(e) = script
523 .key(key)
524 .arg(owner)
525 .invoke_async::<i64>(&mut conn)
526 .await
527 {
528 tracing::debug!(key = %key, error = %e, "Redis lock release failed (lock may have already expired)");
529 }
530 }
531}
532
533impl std::fmt::Debug for RedisCache {
534 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
535 f.debug_struct("RedisCache")
536 .field("url", &self.url)
537 .field("default_ttl_secs", &self.default_ttl_secs)
538 .finish()
539 }
540}