use async_trait::async_trait;
#[cfg(feature = "redis-cache")]
use redis::aio::ConnectionManager;
#[cfg(feature = "redis-cache")]
use super::error::{StorageError, StorageResult};
use super::policy::PolicyDocument;
/// Async interface shared by every policy-cache backend in this module.
///
/// None of the operations expose an error channel: lookups yield `Option`
/// and mutations yield `()`, so backends must fold failures into a miss or a
/// no-op themselves.
#[async_trait]
pub trait PolicyCacheLike: Send + Sync {
    /// Reports whether this backend actually stores anything.
    fn is_enabled(&self) -> bool;

    /// Looks up the cached document for `name`; `None` means "not cached".
    async fn get(&self, name: &str) -> Option<PolicyDocument>;

    /// Stores `doc`, keyed by its own name.
    async fn set(&self, doc: &PolicyDocument);

    /// Drops any cached entry for `name`.
    async fn invalidate(&self, name: &str);
}
/// Policy cache backend selected at runtime.
///
/// `Disabled` (the default) turns every operation into a no-op. The `Redis`
/// variant only exists when the `redis-cache` feature is compiled in.
#[non_exhaustive]
pub enum PolicyCache {
    /// No cache: reads always miss; writes and invalidations are ignored.
    Disabled,
    /// Redis-backed cache (requires the `redis-cache` feature).
    #[cfg(feature = "redis-cache")]
    Redis(RedisPolicyCache),
}

impl PolicyCache {
    /// Borrows the live backend, or `None` when caching is disabled.
    ///
    /// Keeping the variant dispatch (and its `cfg` arm) in one place lets the
    /// trait impl below stay identical whether or not `redis-cache` is enabled.
    fn active_backend(&self) -> Option<&dyn PolicyCacheLike> {
        match self {
            PolicyCache::Disabled => None,
            #[cfg(feature = "redis-cache")]
            PolicyCache::Redis(redis) => Some(redis as &dyn PolicyCacheLike),
        }
    }
}

#[async_trait]
impl PolicyCacheLike for PolicyCache {
    /// Delegates to the active backend; a disabled cache always misses.
    async fn get(&self, name: &str) -> Option<PolicyDocument> {
        match self.active_backend() {
            Some(backend) => backend.get(name).await,
            None => None,
        }
    }

    /// Delegates to the active backend; a disabled cache ignores the write.
    async fn set(&self, doc: &PolicyDocument) {
        if let Some(backend) = self.active_backend() {
            backend.set(doc).await;
        }
    }

    /// Delegates to the active backend; a disabled cache has nothing to drop.
    async fn invalidate(&self, name: &str) {
        if let Some(backend) = self.active_backend() {
            backend.invalidate(name).await;
        }
    }

    /// `true` exactly when a real backend is present.
    fn is_enabled(&self) -> bool {
        self.active_backend().is_some()
    }
}

impl Default for PolicyCache {
    /// Caching is opt-in, so the default is the disabled cache.
    fn default() -> Self {
        PolicyCache::Disabled
    }
}
impl PolicyCache {
    /// Builds a cache synchronously from `config`.
    ///
    /// This constructor never opens a connection, so it always returns
    /// [`PolicyCache::Disabled`]. If `config.enabled` is `true`, it logs a
    /// warning so that an enabled config silently producing a disabled cache
    /// does not go unnoticed. Use [`PolicyCache::from_config_async`] to obtain
    /// a live backend.
    pub fn from_config(config: &RedisConfig) -> Self {
        if config.enabled {
            tracing::warn!(
                "storage.redis.enabled = true but PolicyCache::from_config cannot connect synchronously (use from_config_async); falling back to disabled cache"
            );
        }
        Self::Disabled
    }

    /// Builds a cache from `config`, connecting to Redis when it is enabled.
    ///
    /// This never fails. A disabled config, a failed connection, or a build
    /// without the `redis-cache` feature all yield [`PolicyCache::Disabled`];
    /// the latter two cases also log a warning.
    pub async fn from_config_async(config: &RedisConfig) -> Self {
        if !config.enabled {
            return Self::Disabled;
        }
        #[cfg(feature = "redis-cache")]
        {
            match RedisPolicyCache::connect(config).await {
                Ok(cache) => Self::Redis(cache),
                Err(err) => {
                    // Cache outages must not block startup: degrade to no cache.
                    tracing::warn!(
                        error = %err,
                        "redis policy cache connect failed — falling back to disabled cache"
                    );
                    Self::Disabled
                }
            }
        }
        #[cfg(not(feature = "redis-cache"))]
        {
            tracing::warn!(
                "storage.redis.enabled = true but the `redis-cache` feature is not compiled in; falling back to disabled cache"
            );
            Self::Disabled
        }
    }
}
/// Settings for the optional Redis-backed policy cache.
///
/// By default the cache is off: `enabled` is `false` and no URL is set.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RedisConfig {
    /// Master switch. When `false`, the `PolicyCache` constructors return the disabled cache.
    pub enabled: bool,
    /// Redis connection URL; `RedisPolicyCache::connect` fails without it.
    pub url: Option<String>,
    /// Expiry, in seconds, applied to every cached policy entry.
    pub policy_cache_ttl_secs: u64,
    /// Connection limit setting.
    /// NOTE(review): not read by the Redis backend in this module; confirm it
    /// is honoured elsewhere.
    pub max_connections: u32,
}

impl RedisConfig {
    const DEFAULT_POLICY_CACHE_TTL_SECS: u64 = 30;
    const DEFAULT_MAX_CONNECTIONS: u32 = 10;
}

impl Default for RedisConfig {
    fn default() -> Self {
        RedisConfig {
            enabled: false,
            url: None,
            policy_cache_ttl_secs: Self::DEFAULT_POLICY_CACHE_TTL_SECS,
            max_connections: Self::DEFAULT_MAX_CONNECTIONS,
        }
    }
}
/// Number of lowercase hex characters of the SHA-256 digest kept in a
/// content-addressed key (two characters per digest byte).
const POLICY_CACHE_HASH_HEX_LEN: usize = 16;

/// Builds a content-addressed cache key of the form `policy:{name}:{hex}`.
///
/// `hex` is the first [`POLICY_CACHE_HASH_HEX_LEN`] lowercase hex characters
/// of `SHA-256(bytes)`. The key therefore changes whenever the policy body
/// changes, while `name` keeps different policies in separate namespaces.
pub fn policy_cache_key(name: &str, bytes: &[u8]) -> String {
    use sha2::{Digest, Sha256};
    use std::fmt::Write as _;

    let digest_bytes_kept = POLICY_CACHE_HASH_HEX_LEN / 2;
    let hash_prefix = Sha256::digest(bytes)
        .iter()
        .take(digest_bytes_kept)
        .fold(
            String::with_capacity(POLICY_CACHE_HASH_HEX_LEN),
            |mut acc, b| {
                // Writing into a String cannot fail.
                let _ = write!(acc, "{b:02x}");
                acc
            },
        );
    format!("policy:{name}:{hash_prefix}")
}
/// Returns a Redis glob pattern matching every versioned key of `name`.
///
/// The pattern has the form `policy:{name}:*`, matching the keys produced by
/// `policy_cache_key`. Glob metacharacters in `name` (`*`, `?`, `[`, `]`,
/// `\`) are escaped with a backslash so they match literally. Without this, a
/// policy name such as `*` would yield `policy:*:*` and sweep up every other
/// policy's entries.
pub fn policy_invalidation_pattern(name: &str) -> String {
    const PREFIX: &str = "policy:";
    const SUFFIX: &str = ":*";
    let mut pattern = String::with_capacity(PREFIX.len() + name.len() + SUFFIX.len());
    pattern.push_str(PREFIX);
    for ch in name.chars() {
        if matches!(ch, '*' | '?' | '[' | ']' | '\\') {
            pattern.push('\\');
        }
        pattern.push(ch);
    }
    pattern.push_str(SUFFIX);
    pattern
}
/// Returns the key under which the current document for `name` is stored.
///
/// Unlike the content-addressed keys from `policy_cache_key`, this key has
/// no hash suffix: it is exactly `policy:{name}`.
pub fn active_policy_key(name: &str) -> String {
    let mut key = String::with_capacity("policy:".len() + name.len());
    key.push_str("policy:");
    key.push_str(name);
    key
}
/// Redis-backed policy cache (only with the `redis-cache` feature).
#[cfg(feature = "redis-cache")]
pub struct RedisPolicyCache {
    /// Connection handle; each cache operation works on its own clone.
    conn: ConnectionManager,
    /// Expiry passed to `SETEX` for every stored entry.
    ttl_secs: u64,
}

#[cfg(feature = "redis-cache")]
impl RedisPolicyCache {
    /// Connects to the Redis instance described by `config`.
    ///
    /// # Errors
    ///
    /// Returns `StorageError::ConnectionFailed` in three cases: `config.url`
    /// is missing, the Redis client rejects the URL, or the connection manager
    /// cannot be created.
    ///
    /// NOTE(review): `policy_cache_ttl_secs` is copied verbatim and a value of
    /// 0 is not rejected here. Confirm that Redis accepts it for `SETEX`.
    pub async fn connect(config: &RedisConfig) -> StorageResult<Self> {
        let url = match config.url.as_deref() {
            Some(url) => url,
            None => {
                return Err(StorageError::ConnectionFailed(
                    "storage.redis.url is required when redis.enabled = true".into(),
                ))
            }
        };
        // Non-capturing closure, hence `Copy`: reusable for both fallible steps.
        let as_storage_err =
            |e: redis::RedisError| StorageError::ConnectionFailed(e.to_string());
        let client = redis::Client::open(url).map_err(as_storage_err)?;
        let conn = client
            .get_connection_manager()
            .await
            .map_err(as_storage_err)?;
        Ok(Self {
            ttl_secs: config.policy_cache_ttl_secs,
            conn,
        })
    }

    /// Exposes the configured TTL to unit tests.
    #[cfg(test)]
    pub fn ttl_secs(&self) -> u64 {
        self.ttl_secs
    }
}
#[cfg(feature = "redis-cache")]
#[async_trait]
impl PolicyCacheLike for RedisPolicyCache {
    /// Reads the document stored under `active_policy_key(name)`.
    ///
    /// Returns `None` on a miss or on any Redis error. Errors are logged at
    /// `debug` level, like `set` and `invalidate`, instead of being dropped.
    async fn get(&self, name: &str) -> Option<PolicyDocument> {
        use redis::AsyncCommands;
        // Commands need `&mut` access but the trait only hands out `&self`,
        // so work on a clone of the handle.
        let mut conn = self.conn.clone();
        let key = active_policy_key(name);
        let result: redis::RedisResult<Option<Vec<u8>>> = conn.get(&key).await;
        match result {
            Ok(bytes) => bytes.map(|b| PolicyDocument {
                name: name.into(),
                bytes: b,
            }),
            Err(err) => {
                // The cache is best-effort: a read failure degrades to a miss,
                // but leave a trace so a broken Redis is diagnosable.
                tracing::debug!(error = %err, key = %key, "redis policy cache get failed");
                None
            }
        }
    }

    /// Stores `doc.bytes` under `active_policy_key(&doc.name)` with the
    /// configured TTL (`SETEX`). Failures are logged at `debug` and ignored.
    async fn set(&self, doc: &PolicyDocument) {
        use redis::AsyncCommands;
        let mut conn = self.conn.clone();
        let key = active_policy_key(&doc.name);
        let result: redis::RedisResult<()> = conn.set_ex(&key, &doc.bytes, self.ttl_secs).await;
        if let Err(err) = result {
            tracing::debug!(error = %err, key = %key, "redis policy cache set failed");
        }
    }

    /// Deletes the key for `name`. Failures are logged at `debug` and ignored.
    async fn invalidate(&self, name: &str) {
        use redis::AsyncCommands;
        let mut conn = self.conn.clone();
        let key = active_policy_key(name);
        let result: redis::RedisResult<()> = conn.del(&key).await;
        if let Err(err) = result {
            tracing::debug!(error = %err, key = %key, "redis policy cache invalidate failed");
        }
    }

    /// A connected Redis cache is always enabled.
    fn is_enabled(&self) -> bool {
        true
    }
}
// Unit tests for the policy cache: config defaults, key helpers, the disabled
// backend, an in-memory stub that pins the `PolicyCacheLike` contract, and
// (feature-gated) Redis connection-failure paths.
#[cfg(test)]
mod tests {
use super::*;
// `RedisConfig` defaults and struct-update behaviour.
mod config {
use super::*;
#[test]
fn default_is_off_posture() {
let cfg = RedisConfig::default();
assert!(!cfg.enabled, "cache must default to OFF");
assert!(cfg.url.is_none(), "no URL by default");
assert_eq!(cfg.policy_cache_ttl_secs, 30);
assert_eq!(cfg.max_connections, 10);
}
// Overriding `enabled`/`url` must leave the remaining defaults intact.
#[test]
fn explicit_url_is_preserved() {
let cfg = RedisConfig {
enabled: true,
url: Some("redis://10.0.0.5:6379".into()),
..RedisConfig::default()
};
assert!(cfg.enabled);
assert_eq!(cfg.url.as_deref(), Some("redis://10.0.0.5:6379"));
assert_eq!(cfg.policy_cache_ttl_secs, 30);
assert_eq!(cfg.max_connections, 10);
}
}
// Pure key-derivation helpers; no I/O involved.
mod key {
use super::*;
#[test]
fn same_inputs_yield_identical_key() {
let a = policy_cache_key("default", b"version-1-body");
let b = policy_cache_key("default", b"version-1-body");
assert_eq!(a, b);
}
#[test]
fn changing_bytes_changes_key() {
let v1 = policy_cache_key("default", b"version-1-body");
let v2 = policy_cache_key("default", b"version-2-body");
assert_ne!(v1, v2, "content-addressing must shift the key");
}
#[test]
fn name_namespaces_the_key() {
let same_bytes: &[u8] = b"shared-bytes";
let a = policy_cache_key("default", same_bytes);
let b = policy_cache_key("legacy", same_bytes);
assert_ne!(a, b, "different names must produce different keys");
}
// The hash is the last `:`-separated segment of the key.
#[test]
fn hash_slice_is_sixteen_hex_chars() {
let key = policy_cache_key("default", b"any-body");
let hex = key.rsplit(':').next().expect("key has a hex segment");
assert_eq!(hex.len(), 16, "expected 16 hex chars, got {hex:?}");
assert!(
hex.bytes().all(|b| b.is_ascii_hexdigit()),
"hex segment must be ascii hex: {hex:?}"
);
}
#[test]
fn invalidation_pattern_matches_every_version() {
assert_eq!(policy_invalidation_pattern("default"), "policy:default:*");
assert_eq!(policy_invalidation_pattern("legacy"), "policy:legacy:*");
}
}
// `PolicyCache::Disabled` must behave as a silent no-op cache.
mod disabled {
use super::*;
#[test]
fn default_is_disabled() {
let cache = PolicyCache::default();
assert!(matches!(cache, PolicyCache::Disabled));
assert!(!cache.is_enabled());
}
#[tokio::test]
async fn get_always_returns_none() {
let cache = PolicyCache::Disabled;
assert!(cache.get("default").await.is_none());
assert!(cache.get("any-other-name").await.is_none());
}
#[tokio::test]
async fn set_and_invalidate_do_not_panic() {
let cache = PolicyCache::Disabled;
let doc = PolicyDocument {
name: "default".into(),
bytes: b"any-body".to_vec(),
};
cache.set(&doc).await;
cache.invalidate("default").await;
}
#[test]
fn from_config_default_redis_is_disabled() {
let cache = PolicyCache::from_config(&RedisConfig::default());
assert!(matches!(cache, PolicyCache::Disabled));
assert!(!cache.is_enabled());
}
}
// In-memory `PolicyCacheLike` used to pin the trait contract without a Redis
// server. Expiry uses `tokio::time::Instant`, so tests can drive it with a
// paused tokio clock.
mod stub {
use super::*;
use std::collections::HashMap;
use std::sync::Mutex;
use tokio::time::{Duration, Instant};
pub struct StubPolicyCache {
ttl: Duration,
// name -> (bytes, expiry instant)
store: Mutex<HashMap<String, (Vec<u8>, Instant)>>,
}
impl StubPolicyCache {
pub fn new(ttl_secs: u64) -> Self {
Self {
ttl: Duration::from_secs(ttl_secs),
store: Mutex::new(HashMap::new()),
}
}
}
// None of these bodies await, so the std `Mutex` guard is never held across
// an `.await`.
#[async_trait]
impl PolicyCacheLike for StubPolicyCache {
async fn get(&self, name: &str) -> Option<PolicyDocument> {
let guard = self.store.lock().expect("stub lock");
let (bytes, expires_at) = guard.get(name)?;
// Expired entries read as a miss; they are not removed from the map.
if *expires_at <= Instant::now() {
return None;
}
Some(PolicyDocument {
name: name.into(),
bytes: bytes.clone(),
})
}
async fn set(&self, doc: &PolicyDocument) {
let mut guard = self.store.lock().expect("stub lock");
guard.insert(doc.name.clone(), (doc.bytes.clone(), Instant::now() + self.ttl));
}
async fn invalidate(&self, name: &str) {
let mut guard = self.store.lock().expect("stub lock");
guard.remove(name);
}
fn is_enabled(&self) -> bool {
true
}
}
}
// Behavioural contract an enabled backend is expected to satisfy, checked
// against the stub.
mod contract {
use super::stub::StubPolicyCache;
use super::*;
fn doc(name: &str, body: &[u8]) -> PolicyDocument {
PolicyDocument {
name: name.into(),
bytes: body.to_vec(),
}
}
#[tokio::test]
async fn round_trip_set_then_get() {
let cache = StubPolicyCache::new(30);
cache.set(&doc("default", b"v1-body")).await;
let fetched = cache.get("default").await.expect("cached entry present");
assert_eq!(fetched.name, "default");
assert_eq!(fetched.bytes, b"v1-body".to_vec());
assert!(cache.is_enabled());
}
#[tokio::test]
async fn invalidate_evicts_entry() {
let cache = StubPolicyCache::new(30);
cache.set(&doc("default", b"v1-body")).await;
cache.invalidate("default").await;
assert!(cache.get("default").await.is_none());
}
// `start_paused` freezes tokio's clock; `advance` moves it past the 30s TTL
// without real sleeping. The set -> get -> advance -> get order matters.
#[tokio::test(start_paused = true)]
async fn entry_expires_after_ttl() {
use tokio::time::Duration;
let cache = StubPolicyCache::new(30);
cache.set(&doc("default", b"v1-body")).await;
assert!(cache.get("default").await.is_some());
tokio::time::advance(Duration::from_secs(31)).await;
assert!(
cache.get("default").await.is_none(),
"entry must expire once ttl_secs elapses"
);
}
}
// Redis connection-failure paths. These are meant to run without a reachable
// Redis server.
#[cfg(feature = "redis-cache")]
mod redis_backend {
use super::*;
#[tokio::test]
async fn connect_with_none_url_returns_connection_failed() {
let config = RedisConfig {
enabled: true,
url: None,
..RedisConfig::default()
};
match RedisPolicyCache::connect(&config).await {
Ok(_) => panic!("None URL must surface as ConnectionFailed"),
Err(err) => assert!(matches!(err, StorageError::ConnectionFailed(_))),
}
}
#[tokio::test]
async fn connect_with_malformed_url_returns_connection_failed() {
let config = RedisConfig {
enabled: true,
url: Some("not-a-redis-url".into()),
..RedisConfig::default()
};
match RedisPolicyCache::connect(&config).await {
Ok(_) => panic!("malformed URL must surface as ConnectionFailed"),
Err(err) => assert!(matches!(err, StorageError::ConnectionFailed(_))),
}
}
// NOTE(review): assumes nothing is listening on 127.0.0.1:1, so the connect
// attempt fails.
#[tokio::test]
async fn from_config_async_falls_back_to_disabled_on_bad_url() {
let config = RedisConfig {
enabled: true,
url: Some("redis://127.0.0.1:1".into()),
..RedisConfig::default()
};
let cache = PolicyCache::from_config_async(&config).await;
assert!(
matches!(cache, PolicyCache::Disabled),
"connect failure must fall back to Disabled, not panic"
);
assert!(!cache.is_enabled());
}
}
}