use std::{collections::BTreeMap, time::Duration};
use async_trait::async_trait;
use crate::{
cache::{CacheError, CacheKey, CacheResult, CacheStore},
cache_redis::{
RedisCacheError, RedisCacheResult, RedisCacheStore, RedisShardChoice, RedisShardPicker,
RedisShardedCacheConfig,
},
};
/// A single named shard: a node name from the sharded configuration paired with the
/// single-node store that serves every key routed to it.
#[derive(Clone, Debug)]
struct RedisShard {
    /// Node name; compared against the name returned by the shard picker when routing.
    name: String,
    /// Store used for all operations routed to this shard.
    store: RedisCacheStore,
}
/// Cache store that spreads keys across several Redis nodes.
///
/// Every key is routed to exactly one shard, chosen by a [`RedisShardPicker`] built from
/// the configured node list.
#[derive(Clone, Debug)]
pub struct RedisShardedCacheStore {
    /// Configuration the store was built from (validated in `new`).
    config: RedisShardedCacheConfig,
    /// Maps a rendered cache key to the name of the shard that owns it.
    picker: RedisShardPicker,
    /// One shard per configured node, in configuration order.
    shards: Vec<RedisShard>,
}
impl RedisShardedCacheStore {
    /// Builds a sharded store with one [`RedisCacheStore`] per configured node.
    ///
    /// The configuration is validated first. A [`RedisShardPicker`] is then built from the
    /// node list, and a per-node store is created from `config.node_cache_config(node)`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `config.validate()`, by the shard picker, or by
    /// any per-node `RedisCacheStore::new` call.
    pub fn new(config: RedisShardedCacheConfig) -> RedisCacheResult<Self> {
        config.validate()?;
        let picker = RedisShardPicker::new(
            config
                .nodes
                .iter()
                .map(RedisShardChoice::from)
                .collect::<Vec<_>>(),
        )?;
        let shards = config
            .nodes
            .iter()
            .map(|node| {
                let store = RedisCacheStore::new(config.node_cache_config(node))?;
                // The per-store shard name is only attached when `observability` is enabled.
                #[cfg(feature = "observability")]
                let store = store.with_shard_name(&node.name);
                Ok(RedisShard {
                    name: node.name.clone(),
                    store,
                })
            })
            .collect::<RedisCacheResult<Vec<_>>>()?;
        Ok(Self {
            config,
            picker,
            shards,
        })
    }

    /// Returns the cache namespace from the configuration.
    pub fn namespace(&self) -> &str {
        &self.config.namespace
    }

    /// Attaches `metrics` to every shard's store and re-applies that shard's name.
    ///
    /// Each store is moved out, reconfigured, and moved back rather than cloned.
    #[cfg(feature = "observability")]
    pub fn with_metrics(mut self, metrics: crate::observability::MetricsRegistry) -> Self {
        let shards = std::mem::take(&mut self.shards);
        self.shards = shards
            .into_iter()
            .map(|RedisShard { name, store }| {
                // Shard name is applied after metrics, matching the previous ordering.
                let store = store.with_metrics(metrics.clone()).with_shard_name(&name);
                RedisShard { name, store }
            })
            .collect();
        self
    }

    /// Returns the name of the shard that `key` is routed to.
    ///
    /// Routing is delegated to the shard picker using `key.render()`; no Redis command is
    /// issued here.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the shard picker.
    pub fn shard_name_for_key(&self, key: &CacheKey) -> RedisCacheResult<&str> {
        self.picker.select_name(&key.render())
    }

    /// Runs `health_check` on every shard in turn.
    ///
    /// All shards are checked even if earlier ones fail.
    ///
    /// # Errors
    ///
    /// If any shard fails, returns a single `RedisCacheError::Backend`. Its message lists
    /// each failing shard as `name: error`, separated by `; `.
    pub async fn health_check(&self) -> RedisCacheResult<()> {
        let mut failures = Vec::new();
        for shard in &self.shards {
            if let Err(error) = shard.store.health_check().await {
                failures.push(format!("{}: {error}", shard.name));
            }
        }
        if failures.is_empty() {
            Ok(())
        } else {
            Err(RedisCacheError::Backend(format!(
                "redis shard health check failed: {}",
                failures.join("; ")
            )))
        }
    }

    /// Resolves the store that owns `key`.
    ///
    /// Picker errors are converted into `CacheError`. A picked name with no matching shard
    /// yields `CacheError::Backend`.
    fn shard_for_key(&self, key: &CacheKey) -> CacheResult<&RedisCacheStore> {
        let name = self.shard_name_for_key(key).map_err(to_cache_error)?;
        self.shards
            .iter()
            .find(|shard| shard.name == name)
            .map(|shard| &shard.store)
            .ok_or_else(|| CacheError::Backend(format!("redis shard `{name}` not found")))
    }
}
#[async_trait]
impl CacheStore for RedisShardedCacheStore {
    /// Reads the raw bytes for `key` from the shard that owns it.
    async fn get_raw(&self, key: &CacheKey) -> CacheResult<Option<Vec<u8>>> {
        self.shard_for_key(key)?.get_raw(key).await
    }

    /// Writes raw bytes for `key` (with optional `ttl`) to the shard that owns it.
    async fn set_raw(
        &self,
        key: &CacheKey,
        value: Vec<u8>,
        ttl: Option<Duration>,
    ) -> CacheResult<()> {
        self.shard_for_key(key)?.set_raw(key, value, ttl).await
    }

    /// Deletes `key` from the shard that owns it.
    async fn delete(&self, key: &CacheKey) -> CacheResult<()> {
        self.shard_for_key(key)?.delete(key).await
    }

    /// Deletes `keys`, batching them into one `delete_many` call per shard.
    ///
    /// Every key is routed before any deletion runs, so a routing error aborts the whole
    /// operation without deleting anything. A failure on one shard does not stop the
    /// others. All per-shard failures are reported together in a single
    /// `CacheError::Backend`.
    async fn delete_many(&self, keys: &[CacheKey]) -> CacheResult<()> {
        // Groups are keyed by the shard name borrowed from `self`, so routing a key does
        // not allocate a new String. BTreeMap keeps the per-shard call order deterministic.
        let mut groups: BTreeMap<&str, Vec<CacheKey>> = BTreeMap::new();
        for key in keys {
            let shard = self.shard_name_for_key(key).map_err(to_cache_error)?;
            groups.entry(shard).or_default().push(key.clone());
        }
        let mut failures = Vec::new();
        for (name, keys) in groups {
            let Some(shard) = self.shards.iter().find(|shard| shard.name == name) else {
                failures.push(format!("{name}: shard not found"));
                continue;
            };
            if let Err(error) = shard.store.delete_many(&keys).await {
                failures.push(format!("{name}: {error}"));
            }
        }
        if failures.is_empty() {
            Ok(())
        } else {
            Err(CacheError::Backend(format!(
                "redis sharded delete failed: {}",
                failures.join("; ")
            )))
        }
    }
}
/// Converts a Redis-specific error into the generic cache error, keeping only its
/// rendered message.
fn to_cache_error(error: RedisCacheError) -> CacheError {
    let message = error.to_string();
    CacheError::Backend(message)
}
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use crate::{
        cache::CacheKey,
        cache_redis::{RedisNodeConfig, RedisShardedCacheConfig, RedisShardedCacheStore},
    };

    /// Names of the nodes configured by [`two_node_store`].
    const NODE_NAMES: [&str; 2] = ["a", "b"];

    /// Builds a store over two local nodes named `a` and `b`.
    ///
    /// These tests assume `RedisShardedCacheStore::new` does not need a live Redis server.
    fn two_node_store() -> RedisShardedCacheStore {
        RedisShardedCacheStore::new(RedisShardedCacheConfig {
            nodes: vec![
                RedisNodeConfig::new("a", "redis://127.0.0.1:6379"),
                RedisNodeConfig::new("b", "redis://127.0.0.1:6380"),
            ],
            ..RedisShardedCacheConfig::default()
        })
        .expect("store")
    }

    #[test]
    fn sharded_store_rejects_empty_nodes() {
        let error = RedisShardedCacheStore::new(RedisShardedCacheConfig::default())
            .expect_err("empty nodes");
        assert!(error.to_string().contains("at least one redis shard"));
    }

    #[test]
    fn sharded_store_selects_stable_node_without_connecting() {
        let store = two_node_store();
        let key = CacheKey::new(store.namespace(), ["user", "1"]);
        let first = store.shard_name_for_key(&key).expect("first shard");
        let second = store.shard_name_for_key(&key).expect("second shard");
        assert_eq!(first, second);
        assert!(NODE_NAMES.contains(&first), "unexpected shard `{first}`");
    }

    #[test]
    fn sharded_store_groups_keys_by_selected_node() {
        let store = two_node_store();
        let keys = (0..16)
            .map(|index| {
                CacheKey::new(
                    store.namespace(),
                    vec!["user".to_string(), index.to_string()],
                )
            })
            .collect::<Vec<_>>();
        let selected = keys
            .iter()
            .map(|key| store.shard_name_for_key(key).expect("shard"))
            .collect::<BTreeSet<_>>();
        // Every routed key must land on a configured node; `len() <= 2` alone could not fail.
        let known = BTreeSet::from(NODE_NAMES);
        assert!(
            selected.is_subset(&known),
            "keys routed to unknown shards: {selected:?}"
        );
        assert!(!selected.is_empty());
    }
}