1use common::Vector;
7use futures_util::StreamExt;
8use redis::aio::ConnectionManager;
9use redis::AsyncCommands;
10use serde::{Deserialize, Serialize};
11
/// Error returned by [`RedisCache`] operations that surface failures to the caller.
///
/// Wraps a human-readable message (the underlying `redis` error is rendered to a
/// string, so its structured kind is not preserved).
#[derive(Debug)]
pub struct RedisError(pub String);

impl std::fmt::Display for RedisError {
    /// Renders as `Redis error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Redis error: {}", self.0)
    }
}

// Implementing `std::error::Error` lets callers box this as `dyn Error`, use
// `?` into `Box<dyn Error>`/`anyhow::Error`, and chain it like other errors.
// There is no underlying source to expose because only the message is kept.
impl std::error::Error for RedisError {}
21
22impl From<redis::RedisError> for RedisError {
23 fn from(e: redis::RedisError) -> Self {
24 RedisError(e.to_string())
25 }
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub enum CacheInvalidation {
31 Vectors { namespace: String, ids: Vec<String> },
33 Namespace(String),
35 All,
37}
38
/// Point-in-time statistics gathered from the Redis server by `RedisCache::stats`.
///
/// The `Default` value (all zeros, `connected == false`) is what is reported
/// when the server could not be queried.
#[derive(Debug, Clone, Default)]
pub struct RedisCacheStats {
    /// `true` when the `INFO` query succeeded.
    pub connected: bool,
    /// The `used_memory` value from `INFO`, in bytes.
    pub used_memory_bytes: u64,
    /// Result of `DBSIZE` (all keys in the selected database, not only cache keys).
    pub total_keys: u64,
    /// The `keyspace_hits` counter from `INFO`.
    pub hits: u64,
    /// The `keyspace_misses` counter from `INFO`.
    pub misses: u64,
    /// `hits / (hits + misses)` expressed as a percentage (0.0–100.0).
    pub hit_rate: f64,
}
49
50const REDIS_KEY_PREFIX: &str = "buf";
51const REDIS_PUBSUB_CHANNEL: &str = "buffer:cache:invalidate";
52const DEFAULT_TTL_SECS: u64 = 3600; #[derive(Clone)]
56pub struct RedisCache {
57 conn: ConnectionManager,
58 url: String,
59 default_ttl_secs: u64,
60}
61
62impl RedisCache {
63 pub async fn new(redis_url: &str) -> Result<Self, RedisError> {
65 let client = redis::Client::open(redis_url)
66 .map_err(|e| RedisError(format!("Failed to create Redis client: {}", e)))?;
67 let conn = ConnectionManager::new(client)
68 .await
69 .map_err(|e| RedisError(format!("Failed to connect to Redis: {}", e)))?;
70 Ok(Self {
71 conn,
72 url: redis_url.to_string(),
73 default_ttl_secs: DEFAULT_TTL_SECS,
74 })
75 }
76
77 fn key(namespace: &str, id: &str) -> String {
79 format!("{}:{}:{}", REDIS_KEY_PREFIX, namespace, id)
80 }
81
82 fn namespace_pattern(namespace: &str) -> String {
84 format!("{}:{}:*", REDIS_KEY_PREFIX, namespace)
85 }
86
87 pub async fn get(&self, namespace: &str, id: &str) -> Option<Vector> {
89 let key = Self::key(namespace, id);
90 let mut conn = self.conn.clone();
91 match conn.get::<_, Option<String>>(&key).await {
92 Ok(Some(json)) => {
93 metrics::counter!("buffer_redis_hits_total").increment(1);
94 match serde_json::from_str(&json) {
95 Ok(v) => Some(v),
96 Err(e) => {
97 tracing::warn!(key = %key, error = %e, "Failed to deserialize vector from Redis");
98 None
99 }
100 }
101 }
102 Ok(None) => {
103 metrics::counter!("buffer_redis_misses_total").increment(1);
104 None
105 }
106 Err(e) => {
107 tracing::debug!(key = %key, error = %e, "Redis GET failed");
108 metrics::counter!("buffer_redis_misses_total").increment(1);
109 None
110 }
111 }
112 }
113
114 pub async fn get_multi(&self, namespace: &str, ids: &[String]) -> Vec<Vector> {
116 if ids.is_empty() {
117 return Vec::new();
118 }
119 let keys: Vec<String> = ids.iter().map(|id| Self::key(namespace, id)).collect();
120 let mut conn = self.conn.clone();
121
122 let results: Result<Vec<Option<String>>, _> =
124 redis::cmd("MGET").arg(&keys).query_async(&mut conn).await;
125
126 match results {
127 Ok(values) => {
128 let mut vectors = Vec::new();
129 for (i, val) in values.into_iter().enumerate() {
130 match val {
131 Some(json) => {
132 metrics::counter!("buffer_redis_hits_total").increment(1);
133 match serde_json::from_str::<Vector>(&json) {
134 Ok(v) => vectors.push(v),
135 Err(e) => {
136 tracing::warn!(key = %keys[i], error = %e, "Failed to deserialize vector from Redis");
137 }
138 }
139 }
140 None => {
141 metrics::counter!("buffer_redis_misses_total").increment(1);
142 }
143 }
144 }
145 vectors
146 }
147 Err(e) => {
148 tracing::debug!(error = %e, "Redis MGET failed");
149 metrics::counter!("buffer_redis_misses_total").increment(ids.len() as u64);
150 Vec::new()
151 }
152 }
153 }
154
155 pub async fn set(&self, namespace: &str, vector: &Vector) {
157 let key = Self::key(namespace, &vector.id);
158 let json = match serde_json::to_string(vector) {
159 Ok(j) => j,
160 Err(e) => {
161 tracing::warn!(key = %key, error = %e, "Failed to serialize vector for Redis");
162 return;
163 }
164 };
165 let mut conn = self.conn.clone();
166 if let Err(e) = conn
167 .set_ex::<_, _, ()>(&key, &json, self.default_ttl_secs)
168 .await
169 {
170 tracing::debug!(key = %key, error = %e, "Redis SET failed");
171 }
172 }
173
174 pub async fn set_batch(&self, namespace: &str, vectors: &[Vector]) {
176 if vectors.is_empty() {
177 return;
178 }
179 let mut conn = self.conn.clone();
180 let mut pipe = redis::pipe();
181 for vector in vectors {
182 let key = Self::key(namespace, &vector.id);
183 let json = match serde_json::to_string(vector) {
184 Ok(j) => j,
185 Err(_) => continue,
186 };
187 pipe.cmd("SET")
188 .arg(&key)
189 .arg(&json)
190 .arg("EX")
191 .arg(self.default_ttl_secs)
192 .ignore();
193 }
194 if let Err(e) = pipe.query_async::<()>(&mut conn).await {
195 tracing::debug!(error = %e, count = vectors.len(), "Redis pipeline SET failed");
196 }
197 }
198
199 pub async fn delete(&self, namespace: &str, ids: &[String]) {
201 if ids.is_empty() {
202 return;
203 }
204 let keys: Vec<String> = ids.iter().map(|id| Self::key(namespace, id)).collect();
205 let mut conn = self.conn.clone();
206 if let Err(e) = conn.del::<_, ()>(&keys).await {
207 tracing::debug!(error = %e, count = ids.len(), "Redis DEL failed");
208 }
209 }
210
211 pub async fn invalidate_namespace(&self, namespace: &str) {
213 let pattern = Self::namespace_pattern(namespace);
214 let mut conn = self.conn.clone();
215 let mut cursor: u64 = 0;
216 let mut total_deleted = 0u64;
217
218 loop {
219 let result: Result<(u64, Vec<String>), _> = redis::cmd("SCAN")
220 .arg(cursor)
221 .arg("MATCH")
222 .arg(&pattern)
223 .arg("COUNT")
224 .arg(500)
225 .query_async(&mut conn)
226 .await;
227
228 match result {
229 Ok((next_cursor, keys)) => {
230 if !keys.is_empty() {
231 let _ = conn.del::<_, ()>(&keys).await;
232 total_deleted += keys.len() as u64;
233 }
234 cursor = next_cursor;
235 if cursor == 0 {
236 break;
237 }
238 }
239 Err(e) => {
240 tracing::warn!(namespace, error = %e, "Redis SCAN+DEL failed during namespace invalidation");
241 break;
242 }
243 }
244 }
245
246 if total_deleted > 0 {
247 tracing::debug!(
248 namespace,
249 deleted = total_deleted,
250 "Redis namespace invalidated"
251 );
252 }
253 }
254
255 pub async fn clear_all(&self) {
257 let pattern = format!("{}:*", REDIS_KEY_PREFIX);
258 let mut conn = self.conn.clone();
259 let mut cursor: u64 = 0;
260
261 loop {
262 let result: Result<(u64, Vec<String>), _> = redis::cmd("SCAN")
263 .arg(cursor)
264 .arg("MATCH")
265 .arg(&pattern)
266 .arg("COUNT")
267 .arg(500)
268 .query_async(&mut conn)
269 .await;
270
271 match result {
272 Ok((next_cursor, keys)) => {
273 if !keys.is_empty() {
274 let _ = conn.del::<_, ()>(&keys).await;
275 }
276 cursor = next_cursor;
277 if cursor == 0 {
278 break;
279 }
280 }
281 Err(e) => {
282 tracing::warn!(error = %e, "Redis SCAN+DEL failed during full cache clear");
283 break;
284 }
285 }
286 }
287
288 tracing::info!("Redis cache cleared");
289 }
290
291 pub async fn stats(&self) -> RedisCacheStats {
293 let mut conn = self.conn.clone();
294 let info: Result<String, _> = redis::cmd("INFO").query_async(&mut conn).await;
295
296 match info {
297 Ok(info_str) => {
298 let used_memory = Self::parse_info_field(&info_str, "used_memory")
299 .and_then(|s| s.parse::<u64>().ok())
300 .unwrap_or(0);
301 let hits = Self::parse_info_field(&info_str, "keyspace_hits")
302 .and_then(|s| s.parse::<u64>().ok())
303 .unwrap_or(0);
304 let misses = Self::parse_info_field(&info_str, "keyspace_misses")
305 .and_then(|s| s.parse::<u64>().ok())
306 .unwrap_or(0);
307
308 let total_keys: u64 = redis::cmd("DBSIZE")
310 .query_async(&mut conn)
311 .await
312 .unwrap_or(0);
313
314 let hit_rate = if hits + misses > 0 {
315 hits as f64 / (hits + misses) as f64 * 100.0
316 } else {
317 0.0
318 };
319
320 RedisCacheStats {
321 connected: true,
322 used_memory_bytes: used_memory,
323 total_keys,
324 hits,
325 misses,
326 hit_rate,
327 }
328 }
329 Err(e) => {
330 tracing::debug!(error = %e, "Redis INFO command failed");
331 RedisCacheStats {
332 connected: false,
333 ..Default::default()
334 }
335 }
336 }
337 }
338
339 fn parse_info_field<'a>(info: &'a str, field: &str) -> Option<&'a str> {
341 for line in info.lines() {
342 if let Some(value) = line.strip_prefix(&format!("{}:", field)) {
343 return Some(value.trim());
344 }
345 }
346 None
347 }
348
349 pub async fn publish_invalidation(&self, msg: &CacheInvalidation) {
351 let json = match serde_json::to_string(msg) {
352 Ok(j) => j,
353 Err(e) => {
354 tracing::warn!(error = %e, "Failed to serialize cache invalidation message");
355 return;
356 }
357 };
358 let mut conn = self.conn.clone();
359 if let Err(e) = conn.publish::<_, _, ()>(REDIS_PUBSUB_CHANNEL, &json).await {
360 tracing::debug!(error = %e, "Redis PUBLISH failed for cache invalidation");
361 }
362 }
363
364 pub async fn publish_raw(&self, channel: &str, message: &str) {
367 let mut conn = self.conn.clone();
368 if let Err(e) = conn.publish::<_, _, ()>(channel, message).await {
369 tracing::debug!(channel = %channel, error = %e, "Redis PUBLISH failed");
370 }
371 }
372
373 pub async fn subscribe_raw(
376 &self,
377 channel: &str,
378 ) -> Result<tokio::sync::mpsc::Receiver<String>, RedisError> {
379 let client = redis::Client::open(self.url.as_str())
380 .map_err(|e| RedisError(format!("Failed to create Redis client for pub/sub: {}", e)))?;
381 let mut pubsub_conn = client
382 .get_async_pubsub()
383 .await
384 .map_err(|e| RedisError(format!("Failed to get Redis pub/sub connection: {}", e)))?;
385 pubsub_conn
386 .subscribe(channel)
387 .await
388 .map_err(|e| RedisError(format!("Failed to subscribe to {}: {}", channel, e)))?;
389
390 let (tx, rx) = tokio::sync::mpsc::channel(256);
391 let channel_name = channel.to_string();
392
393 tokio::spawn(async move {
394 let mut msg_stream = pubsub_conn.on_message();
395 while let Some(msg) = msg_stream.next().await {
396 let payload: String = match msg.get_payload() {
397 Ok(p) => p,
398 Err(e) => {
399 tracing::debug!(error = %e, "Failed to get pub/sub message payload");
400 continue;
401 }
402 };
403 if tx.send(payload).await.is_err() {
404 tracing::debug!(channel = %channel_name, "Pub/sub receiver dropped, stopping");
405 break;
406 }
407 }
408 tracing::warn!(channel = %channel_name, "Redis pub/sub raw stream ended");
409 });
410
411 tracing::info!(channel = %channel, "Redis raw pub/sub subscription started");
412 Ok(rx)
413 }
414
415 pub async fn subscribe_invalidations<F>(&self, mut handler: F)
419 where
420 F: FnMut(CacheInvalidation) + Send + 'static,
421 {
422 let client = match redis::Client::open(self.url.as_str()) {
424 Ok(c) => c,
425 Err(e) => {
426 tracing::error!(error = %e, "Failed to create Redis client for pub/sub");
427 return;
428 }
429 };
430
431 let mut pubsub_conn = match client.get_async_pubsub().await {
432 Ok(c) => c,
433 Err(e) => {
434 tracing::error!(error = %e, "Failed to get Redis pub/sub connection");
435 return;
436 }
437 };
438
439 if let Err(e) = pubsub_conn.subscribe(REDIS_PUBSUB_CHANNEL).await {
440 tracing::error!(error = %e, "Failed to subscribe to Redis invalidation channel");
441 return;
442 }
443
444 tracing::info!("Redis pub/sub subscribed to {}", REDIS_PUBSUB_CHANNEL);
445
446 let mut msg_stream = pubsub_conn.on_message();
447 while let Some(msg) = msg_stream.next().await {
448 let payload: String = match msg.get_payload() {
449 Ok(p) => p,
450 Err(e) => {
451 tracing::debug!(error = %e, "Failed to get pub/sub message payload");
452 continue;
453 }
454 };
455 match serde_json::from_str::<CacheInvalidation>(&payload) {
456 Ok(invalidation) => handler(invalidation),
457 Err(e) => {
458 tracing::debug!(error = %e, "Failed to deserialize invalidation message");
459 }
460 }
461 }
462
463 tracing::warn!("Redis pub/sub stream ended");
464 }
465}
466
467impl std::fmt::Debug for RedisCache {
468 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
469 f.debug_struct("RedisCache")
470 .field("url", &self.url)
471 .field("default_ttl_secs", &self.default_ttl_secs)
472 .finish()
473 }
474}