use common::Vector;
use futures_util::StreamExt;
use redis::aio::ConnectionManager;
use redis::AsyncCommands;
use serde::{Deserialize, Serialize};
/// Crate-local error type carrying a human-readable Redis failure message.
#[derive(Debug)]
pub struct RedisError(pub String);

impl std::fmt::Display for RedisError {
    /// Renders as `Redis error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Redis error: {}", self.0)
    }
}

// Fix: a public error type should implement `std::error::Error` so it can be
// boxed, chained with `?`, and used with `anyhow`-style consumers. Debug and
// Display are already provided above, so the impl is empty.
impl std::error::Error for RedisError {}
impl From<redis::RedisError> for RedisError {
fn from(e: redis::RedisError) -> Self {
RedisError(e.to_string())
}
}
/// Cache-invalidation message broadcast over Redis pub/sub so every node
/// sharing the cache can drop stale entries.
///
/// NOTE(review): serde derives define the wire format on the pub/sub
/// channel — renaming variants or fields is a breaking protocol change.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CacheInvalidation {
    /// Invalidate specific vector ids within one namespace.
    Vectors { namespace: String, ids: Vec<String> },
    /// Invalidate every cached vector in the named namespace.
    Namespace(String),
    /// Invalidate the entire cache across all namespaces.
    All,
}
/// Snapshot of Redis-side cache statistics assembled from `INFO` and `DBSIZE`.
#[derive(Debug, Clone, Default)]
pub struct RedisCacheStats {
    /// Whether the stats query succeeded; when false the numeric fields
    /// are all zero (see `RedisCache::stats`).
    pub connected: bool,
    /// `used_memory` from INFO — total memory used by the Redis server.
    pub used_memory_bytes: u64,
    /// `DBSIZE` — counts every key in the selected database, not only
    /// keys written by this cache.
    pub total_keys: u64,
    /// `keyspace_hits` from INFO — server-wide, not per-prefix.
    pub hits: u64,
    /// `keyspace_misses` from INFO — server-wide, not per-prefix.
    pub misses: u64,
    /// hits / (hits + misses) as a percentage (0.0 when no lookups yet).
    pub hit_rate: f64,
}
/// Prefix for every cache key; full keys look like `buf:<namespace>:<id>`.
const REDIS_KEY_PREFIX: &str = "buf";
/// Pub/sub channel on which `CacheInvalidation` messages are broadcast.
const REDIS_PUBSUB_CHANNEL: &str = "buffer:cache:invalidate";
/// Default expiry (seconds) applied to every cached vector — 1 hour.
const DEFAULT_TTL_SECS: u64 = 3600;
/// Shared Redis-backed vector cache with pub/sub invalidation and a simple
/// distributed lock.
///
/// Cloning is cheap: `ConnectionManager` is an internally shared handle that
/// multiplexes commands over one auto-reconnecting connection.
#[derive(Clone)]
pub struct RedisCache {
    // Multiplexed, auto-reconnecting command connection.
    conn: ConnectionManager,
    // Original connection URL, kept so pub/sub methods can open their own
    // dedicated connections (pub/sub cannot share the multiplexed one).
    url: String,
    // TTL (seconds) applied to every cached vector write.
    default_ttl_secs: u64,
}
impl RedisCache {
pub async fn new(redis_url: &str) -> Result<Self, RedisError> {
let client = redis::Client::open(redis_url)
.map_err(|e| RedisError(format!("Failed to create Redis client: {}", e)))?;
let conn = ConnectionManager::new(client)
.await
.map_err(|e| RedisError(format!("Failed to connect to Redis: {}", e)))?;
Ok(Self {
conn,
url: redis_url.to_string(),
default_ttl_secs: DEFAULT_TTL_SECS,
})
}
pub fn connection(&self) -> ConnectionManager {
self.conn.clone()
}
fn key(namespace: &str, id: &str) -> String {
format!("{}:{}:{}", REDIS_KEY_PREFIX, namespace, id)
}
fn namespace_pattern(namespace: &str) -> String {
format!("{}:{}:*", REDIS_KEY_PREFIX, namespace)
}
pub async fn get(&self, namespace: &str, id: &str) -> Option<Vector> {
let key = Self::key(namespace, id);
let mut conn = self.conn.clone();
match conn.get::<_, Option<String>>(&key).await {
Ok(Some(json)) => {
metrics::counter!("buffer_redis_hits_total").increment(1);
match serde_json::from_str(&json) {
Ok(v) => Some(v),
Err(e) => {
tracing::warn!(key = %key, error = %e, "Failed to deserialize vector from Redis");
None
}
}
}
Ok(None) => {
metrics::counter!("buffer_redis_misses_total").increment(1);
None
}
Err(e) => {
tracing::debug!(key = %key, error = %e, "Redis GET failed");
metrics::counter!("buffer_redis_misses_total").increment(1);
None
}
}
}
pub async fn get_multi(&self, namespace: &str, ids: &[String]) -> Vec<Vector> {
if ids.is_empty() {
return Vec::new();
}
let keys: Vec<String> = ids.iter().map(|id| Self::key(namespace, id)).collect();
let mut conn = self.conn.clone();
let results: Result<Vec<Option<String>>, _> =
redis::cmd("MGET").arg(&keys).query_async(&mut conn).await;
match results {
Ok(values) => {
let mut vectors = Vec::new();
for (i, val) in values.into_iter().enumerate() {
match val {
Some(json) => {
metrics::counter!("buffer_redis_hits_total").increment(1);
match serde_json::from_str::<Vector>(&json) {
Ok(v) => vectors.push(v),
Err(e) => {
tracing::warn!(key = %keys[i], error = %e, "Failed to deserialize vector from Redis");
}
}
}
None => {
metrics::counter!("buffer_redis_misses_total").increment(1);
}
}
}
vectors
}
Err(e) => {
tracing::debug!(error = %e, "Redis MGET failed");
metrics::counter!("buffer_redis_misses_total").increment(ids.len() as u64);
Vec::new()
}
}
}
pub async fn set(&self, namespace: &str, vector: &Vector) {
let key = Self::key(namespace, &vector.id);
let json = match serde_json::to_string(vector) {
Ok(j) => j,
Err(e) => {
tracing::warn!(key = %key, error = %e, "Failed to serialize vector for Redis");
return;
}
};
let mut conn = self.conn.clone();
if let Err(e) = conn
.set_ex::<_, _, ()>(&key, &json, self.default_ttl_secs)
.await
{
tracing::debug!(key = %key, error = %e, "Redis SET failed");
}
}
pub async fn set_batch(&self, namespace: &str, vectors: &[Vector]) {
if vectors.is_empty() {
return;
}
let mut conn = self.conn.clone();
let mut pipe = redis::pipe();
for vector in vectors {
let key = Self::key(namespace, &vector.id);
let json = match serde_json::to_string(vector) {
Ok(j) => j,
Err(_) => continue,
};
pipe.cmd("SET")
.arg(&key)
.arg(&json)
.arg("EX")
.arg(self.default_ttl_secs)
.ignore();
}
if let Err(e) = pipe.query_async::<()>(&mut conn).await {
tracing::debug!(error = %e, count = vectors.len(), "Redis pipeline SET failed");
}
}
pub async fn delete(&self, namespace: &str, ids: &[String]) {
if ids.is_empty() {
return;
}
let keys: Vec<String> = ids.iter().map(|id| Self::key(namespace, id)).collect();
let mut conn = self.conn.clone();
if let Err(e) = conn.del::<_, ()>(&keys).await {
tracing::debug!(error = %e, count = ids.len(), "Redis DEL failed");
}
}
pub async fn invalidate_namespace(&self, namespace: &str) {
let pattern = Self::namespace_pattern(namespace);
let mut conn = self.conn.clone();
let mut cursor: u64 = 0;
let mut total_deleted = 0u64;
loop {
let result: Result<(u64, Vec<String>), _> = redis::cmd("SCAN")
.arg(cursor)
.arg("MATCH")
.arg(&pattern)
.arg("COUNT")
.arg(500)
.query_async(&mut conn)
.await;
match result {
Ok((next_cursor, keys)) => {
if !keys.is_empty() {
let _ = conn.del::<_, ()>(&keys).await;
total_deleted += keys.len() as u64;
}
cursor = next_cursor;
if cursor == 0 {
break;
}
}
Err(e) => {
tracing::warn!(namespace, error = %e, "Redis SCAN+DEL failed during namespace invalidation");
break;
}
}
}
if total_deleted > 0 {
tracing::debug!(
namespace,
deleted = total_deleted,
"Redis namespace invalidated"
);
}
}
pub async fn clear_all(&self) {
let pattern = format!("{}:*", REDIS_KEY_PREFIX);
let mut conn = self.conn.clone();
let mut cursor: u64 = 0;
loop {
let result: Result<(u64, Vec<String>), _> = redis::cmd("SCAN")
.arg(cursor)
.arg("MATCH")
.arg(&pattern)
.arg("COUNT")
.arg(500)
.query_async(&mut conn)
.await;
match result {
Ok((next_cursor, keys)) => {
if !keys.is_empty() {
let _ = conn.del::<_, ()>(&keys).await;
}
cursor = next_cursor;
if cursor == 0 {
break;
}
}
Err(e) => {
tracing::warn!(error = %e, "Redis SCAN+DEL failed during full cache clear");
break;
}
}
}
tracing::info!("Redis cache cleared");
}
pub async fn stats(&self) -> RedisCacheStats {
let mut conn = self.conn.clone();
let info: Result<String, _> = redis::cmd("INFO").query_async(&mut conn).await;
match info {
Ok(info_str) => {
let used_memory = Self::parse_info_field(&info_str, "used_memory")
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(0);
let hits = Self::parse_info_field(&info_str, "keyspace_hits")
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(0);
let misses = Self::parse_info_field(&info_str, "keyspace_misses")
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(0);
let total_keys: u64 = redis::cmd("DBSIZE")
.query_async(&mut conn)
.await
.unwrap_or(0);
let hit_rate = if hits + misses > 0 {
hits as f64 / (hits + misses) as f64 * 100.0
} else {
0.0
};
RedisCacheStats {
connected: true,
used_memory_bytes: used_memory,
total_keys,
hits,
misses,
hit_rate,
}
}
Err(e) => {
tracing::debug!(error = %e, "Redis INFO command failed");
RedisCacheStats {
connected: false,
..Default::default()
}
}
}
}
fn parse_info_field<'a>(info: &'a str, field: &str) -> Option<&'a str> {
for line in info.lines() {
if let Some(value) = line.strip_prefix(&format!("{}:", field)) {
return Some(value.trim());
}
}
None
}
pub async fn publish_invalidation(&self, msg: &CacheInvalidation) {
let json = match serde_json::to_string(msg) {
Ok(j) => j,
Err(e) => {
tracing::warn!(error = %e, "Failed to serialize cache invalidation message");
return;
}
};
let mut conn = self.conn.clone();
if let Err(e) = conn.publish::<_, _, ()>(REDIS_PUBSUB_CHANNEL, &json).await {
tracing::debug!(error = %e, "Redis PUBLISH failed for cache invalidation");
}
}
pub async fn publish_raw(&self, channel: &str, message: &str) {
let mut conn = self.conn.clone();
if let Err(e) = conn.publish::<_, _, ()>(channel, message).await {
tracing::debug!(channel = %channel, error = %e, "Redis PUBLISH failed");
}
}
pub async fn subscribe_raw(
&self,
channel: &str,
) -> Result<tokio::sync::mpsc::Receiver<String>, RedisError> {
let client = redis::Client::open(self.url.as_str())
.map_err(|e| RedisError(format!("Failed to create Redis client for pub/sub: {}", e)))?;
let mut pubsub_conn = client
.get_async_pubsub()
.await
.map_err(|e| RedisError(format!("Failed to get Redis pub/sub connection: {}", e)))?;
pubsub_conn
.subscribe(channel)
.await
.map_err(|e| RedisError(format!("Failed to subscribe to {}: {}", channel, e)))?;
let (tx, rx) = tokio::sync::mpsc::channel(256);
let channel_name = channel.to_string();
tokio::spawn(async move {
let mut msg_stream = pubsub_conn.on_message();
while let Some(msg) = msg_stream.next().await {
let payload: String = match msg.get_payload() {
Ok(p) => p,
Err(e) => {
tracing::debug!(error = %e, "Failed to get pub/sub message payload");
continue;
}
};
if tx.send(payload).await.is_err() {
tracing::debug!(channel = %channel_name, "Pub/sub receiver dropped, stopping");
break;
}
}
tracing::warn!(channel = %channel_name, "Redis pub/sub raw stream ended");
});
tracing::info!(channel = %channel, "Redis raw pub/sub subscription started");
Ok(rx)
}
pub async fn subscribe_invalidations<F>(&self, mut handler: F)
where
F: FnMut(CacheInvalidation) + Send + 'static,
{
let client = match redis::Client::open(self.url.as_str()) {
Ok(c) => c,
Err(e) => {
tracing::error!(error = %e, "Failed to create Redis client for pub/sub");
return;
}
};
let mut pubsub_conn = match client.get_async_pubsub().await {
Ok(c) => c,
Err(e) => {
tracing::error!(error = %e, "Failed to get Redis pub/sub connection");
return;
}
};
if let Err(e) = pubsub_conn.subscribe(REDIS_PUBSUB_CHANNEL).await {
tracing::error!(error = %e, "Failed to subscribe to Redis invalidation channel");
return;
}
tracing::info!("Redis pub/sub subscribed to {}", REDIS_PUBSUB_CHANNEL);
let mut msg_stream = pubsub_conn.on_message();
while let Some(msg) = msg_stream.next().await {
let payload: String = match msg.get_payload() {
Ok(p) => p,
Err(e) => {
tracing::debug!(error = %e, "Failed to get pub/sub message payload");
continue;
}
};
match serde_json::from_str::<CacheInvalidation>(&payload) {
Ok(invalidation) => handler(invalidation),
Err(e) => {
tracing::debug!(error = %e, "Failed to deserialize invalidation message");
}
}
}
tracing::warn!("Redis pub/sub stream ended");
}
pub async fn try_acquire_lock(&self, key: &str, owner: &str, ttl_secs: u64) -> bool {
let mut conn = self.conn.clone();
let result: Result<Option<String>, _> = redis::cmd("SET")
.arg(key)
.arg(owner)
.arg("EX")
.arg(ttl_secs)
.arg("NX")
.query_async(&mut conn)
.await;
match result {
Ok(Some(_)) => {
tracing::debug!(key = %key, owner = %owner, "Distributed lock acquired");
true
}
Ok(None) => false, Err(e) => {
tracing::warn!(
key = %key,
error = %e,
"Redis lock acquire failed — running as single-node fallback"
);
true }
}
}
pub async fn release_lock(&self, key: &str, owner: &str) {
let mut conn = self.conn.clone();
let script = redis::Script::new(
r#"if redis.call('get', KEYS[1]) == ARGV[1] then
return redis.call('del', KEYS[1])
else
return 0
end"#,
);
if let Err(e) = script
.key(key)
.arg(owner)
.invoke_async::<i64>(&mut conn)
.await
{
tracing::debug!(key = %key, error = %e, "Redis lock release failed (lock may have already expired)");
}
}
}
impl std::fmt::Debug for RedisCache {
    /// Hand-written Debug: `ConnectionManager` is not `Debug`, so only the
    /// URL and default TTL are shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("RedisCache");
        dbg.field("url", &self.url);
        dbg.field("default_ttl_secs", &self.default_ttl_secs);
        dbg.finish()
    }
}