oxify_authz/
redis_cache.rs

//! Redis-based L2 cache for distributed authorization.
//!
//! Provides shared caching across multiple API servers with:
//! - Distributed cache invalidation
//! - TTL-based expiration
//! - Pub/sub broadcasting of cache invalidation events
//! - High availability support
8
9use crate::{RelationTuple, Subject};
10use deadpool_redis::{Config, Pool, Runtime};
11use redis::AsyncCommands;
12use serde::{Deserialize, Serialize};
13use std::sync::Arc;
14use std::time::Duration;
15use thiserror::Error;
16use tokio::sync::RwLock;
17
18#[derive(Error, Debug)]
19pub enum RedisCacheError {
20    #[error("Redis connection error: {0}")]
21    ConnectionError(String),
22
23    #[error("Serialization error: {0}")]
24    SerializationError(String),
25
26    #[error("Cache operation failed: {0}")]
27    OperationError(String),
28}
29
30pub type Result<T> = std::result::Result<T, RedisCacheError>;
31
/// Settings used to construct a [`RedisCache`].
///
/// Start from [`RedisCacheConfig::default`] and override individual values
/// with the `with_*` builder methods.
#[derive(Clone, Debug)]
pub struct RedisCacheConfig {
    /// Connection URL of the Redis server, e.g. `"redis://localhost:6379"`.
    pub url: String,

    /// Expiry applied to an entry when the caller supplies no explicit TTL
    /// (300 seconds in the default configuration).
    pub default_ttl: Duration,

    /// String prepended to every key this cache writes, scans, or deletes.
    pub key_prefix: String,

    /// Whether single-entry invalidations are broadcast over pub/sub.
    pub enable_pubsub: bool,

    /// Pub/sub channel that invalidation messages are published to.
    pub invalidation_channel: String,

    /// Connection pool size.
    pub pool_size: u32,
}
53
54impl Default for RedisCacheConfig {
55    fn default() -> Self {
56        Self {
57            url: "redis://localhost:6379".to_string(),
58            default_ttl: Duration::from_secs(300),
59            key_prefix: "authz:".to_string(),
60            enable_pubsub: true,
61            invalidation_channel: "authz:invalidate".to_string(),
62            pool_size: 10,
63        }
64    }
65}
66
67impl RedisCacheConfig {
68    pub fn with_url(mut self, url: String) -> Self {
69        self.url = url;
70        self
71    }
72
73    pub fn with_ttl(mut self, ttl: Duration) -> Self {
74        self.default_ttl = ttl;
75        self
76    }
77
78    pub fn with_key_prefix(mut self, prefix: String) -> Self {
79        self.key_prefix = prefix;
80        self
81    }
82}
83
84/// Cached permission check result
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct CachedPermission {
87    pub allowed: bool,
88    pub cached_at: i64, // Unix timestamp
89}
90
91/// Cache key for permission checks
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct PermissionCacheKey {
94    pub resource_type: String,
95    pub resource_id: String,
96    pub relation: String,
97    pub subject: Subject,
98}
99
100impl PermissionCacheKey {
101    pub fn new(
102        resource_type: String,
103        resource_id: String,
104        relation: String,
105        subject: Subject,
106    ) -> Self {
107        Self {
108            resource_type,
109            resource_id,
110            relation,
111            subject,
112        }
113    }
114
115    /// Generate Redis key string
116    pub fn to_redis_key(&self, prefix: &str) -> String {
117        format!(
118            "{}check:{}:{}:{}:{}",
119            prefix,
120            self.resource_type,
121            self.resource_id,
122            self.relation,
123            subject_to_key(&self.subject)
124        )
125    }
126}
127
128fn subject_to_key(subject: &Subject) -> String {
129    match subject {
130        Subject::User(id) => format!("user:{}", id),
131        Subject::UserSet {
132            namespace,
133            object_id,
134            relation,
135        } => {
136            format!("userset:{}:{}:{}", namespace, object_id, relation)
137        }
138    }
139}
140
141/// Redis cache for authorization (L2 cache)
142pub struct RedisCache {
143    config: RedisCacheConfig,
144    pool: Pool,
145    stats: Arc<RwLock<RedisCacheStatsInternal>>,
146}
147
/// Hit/miss counters kept in memory by a single `RedisCache` instance.
///
/// The counters are plain integers, not atomics. Concurrent access is
/// synchronized by the `RwLock` that `RedisCache` wraps this struct in.
/// They are not shared with other servers.
#[derive(Debug, Default)]
struct RedisCacheStatsInternal {
    /// Number of `get_permission` lookups that found a cached entry.
    hit_count: u64,
    /// Number of `get_permission` lookups that found no cached entry.
    miss_count: u64,
}
154
155impl RedisCache {
156    /// Create a new Redis cache instance
157    pub fn new(config: RedisCacheConfig) -> Result<Self> {
158        let redis_config = Config::from_url(&config.url);
159        let pool = redis_config
160            .create_pool(Some(Runtime::Tokio1))
161            .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
162
163        Ok(Self {
164            config,
165            pool,
166            stats: Arc::new(RwLock::new(RedisCacheStatsInternal::default())),
167        })
168    }
169
170    /// Connect to Redis server (validate connection)
171    pub async fn connect(&mut self) -> Result<()> {
172        let mut conn = self
173            .pool
174            .get()
175            .await
176            .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
177
178        // Verify connection with PING
179        redis::cmd("PING")
180            .query_async::<String>(&mut conn)
181            .await
182            .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
183
184        Ok(())
185    }
186
187    /// Get cached permission result
188    pub async fn get_permission(&self, key: &PermissionCacheKey) -> Result<Option<bool>> {
189        let mut conn = self
190            .pool
191            .get()
192            .await
193            .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
194
195        let redis_key = key.to_redis_key(&self.config.key_prefix);
196
197        let result: Option<String> = conn
198            .get(&redis_key)
199            .await
200            .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
201
202        match result {
203            Some(json_str) => {
204                let cached: CachedPermission = serde_json::from_str(&json_str)
205                    .map_err(|e| RedisCacheError::SerializationError(e.to_string()))?;
206
207                // Update stats: cache hit
208                let mut stats = self.stats.write().await;
209                stats.hit_count += 1;
210
211                Ok(Some(cached.allowed))
212            }
213            None => {
214                // Update stats: cache miss
215                let mut stats = self.stats.write().await;
216                stats.miss_count += 1;
217
218                Ok(None)
219            }
220        }
221    }
222
223    /// Set cached permission result
224    pub async fn set_permission(
225        &self,
226        key: &PermissionCacheKey,
227        allowed: bool,
228        ttl: Option<Duration>,
229    ) -> Result<()> {
230        let mut conn = self
231            .pool
232            .get()
233            .await
234            .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
235
236        let redis_key = key.to_redis_key(&self.config.key_prefix);
237        let ttl_secs = ttl.unwrap_or(self.config.default_ttl).as_secs() as i64;
238
239        let cached = CachedPermission {
240            allowed,
241            cached_at: chrono::Utc::now().timestamp(),
242        };
243
244        let json_str = serde_json::to_string(&cached)
245            .map_err(|e| RedisCacheError::SerializationError(e.to_string()))?;
246
247        // Set with TTL using SETEX
248        let _: () = conn
249            .set_ex(&redis_key, json_str, ttl_secs as u64)
250            .await
251            .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
252
253        Ok(())
254    }
255
256    /// Invalidate cache entry
257    pub async fn invalidate(&self, key: &PermissionCacheKey) -> Result<()> {
258        let mut conn = self
259            .pool
260            .get()
261            .await
262            .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
263
264        let redis_key = key.to_redis_key(&self.config.key_prefix);
265
266        let _: () = conn
267            .del(&redis_key)
268            .await
269            .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
270
271        // Optionally publish invalidation event
272        if self.config.enable_pubsub {
273            let msg = serde_json::to_string(&key)
274                .map_err(|e| RedisCacheError::SerializationError(e.to_string()))?;
275
276            let _: () = conn
277                .publish(&self.config.invalidation_channel, msg)
278                .await
279                .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
280        }
281
282        Ok(())
283    }
284
285    /// Invalidate all cache entries for a subject
286    pub async fn invalidate_subject(&self, subject: &Subject) -> Result<()> {
287        let mut conn = self
288            .pool
289            .get()
290            .await
291            .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
292
293        let subject_key = subject_to_key(subject);
294        let pattern = format!("{}check:*:*:*:{}", self.config.key_prefix, subject_key);
295
296        // Use SCAN to find all matching keys
297        let keys: Vec<String> = redis::cmd("SCAN")
298            .arg(0)
299            .arg("MATCH")
300            .arg(&pattern)
301            .arg("COUNT")
302            .arg(100)
303            .query_async(&mut conn)
304            .await
305            .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
306
307        // Delete all matching keys
308        if !keys.is_empty() {
309            let _: () = conn
310                .del(&keys)
311                .await
312                .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
313        }
314
315        Ok(())
316    }
317
318    /// Invalidate all cache entries for a resource
319    pub async fn invalidate_resource(&self, namespace: &str, object_id: &str) -> Result<()> {
320        let mut conn = self
321            .pool
322            .get()
323            .await
324            .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
325
326        let pattern = format!(
327            "{}check:{}:{}:*:*",
328            self.config.key_prefix, namespace, object_id
329        );
330
331        // Use SCAN to find all matching keys
332        let keys: Vec<String> = redis::cmd("SCAN")
333            .arg(0)
334            .arg("MATCH")
335            .arg(&pattern)
336            .arg("COUNT")
337            .arg(100)
338            .query_async(&mut conn)
339            .await
340            .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
341
342        // Delete all matching keys
343        if !keys.is_empty() {
344            let _: () = conn
345                .del(&keys)
346                .await
347                .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
348        }
349
350        Ok(())
351    }
352
353    /// Invalidate all cache entries affected by a tuple write
354    pub async fn invalidate_tuple(&self, tuple: &RelationTuple) -> Result<()> {
355        // Invalidate all permissions for the resource
356        self.invalidate_resource(&tuple.namespace, &tuple.object_id)
357            .await?;
358
359        // Invalidate all permissions for the subject
360        self.invalidate_subject(&tuple.subject).await?;
361
362        Ok(())
363    }
364
365    /// Clear all cache entries (use with caution)
366    pub async fn clear_all(&self) -> Result<()> {
367        let mut conn = self
368            .pool
369            .get()
370            .await
371            .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
372
373        let pattern = format!("{}*", self.config.key_prefix);
374
375        // Use SCAN to find all matching keys
376        let keys: Vec<String> = redis::cmd("SCAN")
377            .arg(0)
378            .arg("MATCH")
379            .arg(&pattern)
380            .arg("COUNT")
381            .arg(1000)
382            .query_async(&mut conn)
383            .await
384            .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
385
386        // Delete all matching keys
387        if !keys.is_empty() {
388            let _: () = conn
389                .del(&keys)
390                .await
391                .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
392        }
393
394        Ok(())
395    }
396
397    /// Get cache statistics
398    pub async fn stats(&self) -> Result<RedisCacheStats> {
399        let mut conn = self
400            .pool
401            .get()
402            .await
403            .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
404
405        // Get number of keys with our prefix
406        let pattern = format!("{}*", self.config.key_prefix);
407        let keys: Vec<String> = redis::cmd("SCAN")
408            .arg(0)
409            .arg("MATCH")
410            .arg(&pattern)
411            .arg("COUNT")
412            .arg(1000)
413            .query_async(&mut conn)
414            .await
415            .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
416
417        let total_keys = keys.len();
418
419        // Get hit/miss stats
420        let stats = self.stats.read().await;
421        let hit_count = stats.hit_count;
422        let miss_count = stats.miss_count;
423        let total_checks = hit_count + miss_count;
424        let hit_rate = if total_checks > 0 {
425            hit_count as f64 / total_checks as f64
426        } else {
427            0.0
428        };
429
430        // Get memory usage from Redis INFO
431        let info: String = redis::cmd("INFO")
432            .arg("memory")
433            .query_async(&mut conn)
434            .await
435            .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
436
437        // Parse memory usage from INFO output
438        let memory_usage_bytes = info
439            .lines()
440            .find(|line| line.starts_with("used_memory:"))
441            .and_then(|line| line.split(':').nth(1))
442            .and_then(|val| val.trim().parse::<usize>().ok())
443            .unwrap_or(0);
444
445        Ok(RedisCacheStats {
446            total_keys,
447            hit_count,
448            miss_count,
449            hit_rate,
450            memory_usage_bytes,
451        })
452    }
453
454    /// Health check
455    pub async fn health_check(&self) -> Result<bool> {
456        let mut conn = self
457            .pool
458            .get()
459            .await
460            .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
461
462        // PING Redis server
463        let response: String = redis::cmd("PING")
464            .query_async(&mut conn)
465            .await
466            .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
467
468        Ok(response == "PONG")
469    }
470}
471
472/// Redis cache statistics
473#[derive(Debug, Clone, Serialize, Deserialize)]
474pub struct RedisCacheStats {
475    pub total_keys: usize,
476    pub hit_count: u64,
477    pub miss_count: u64,
478    pub hit_rate: f64,
479    pub memory_usage_bytes: usize,
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_redis_cache_creation() {
        let config = RedisCacheConfig::default();
        let cache = RedisCache::new(config);
        assert!(cache.is_ok());
    }

    #[test]
    fn test_permission_cache_key() {
        let key = PermissionCacheKey::new(
            "document".to_string(),
            "123".to_string(),
            "viewer".to_string(),
            Subject::User("alice".to_string()),
        );

        // Pin the exact layout: every server sharing the cache must build
        // byte-identical keys, so a substring check is not strong enough.
        assert_eq!(
            key.to_redis_key("authz:"),
            "authz:check:document:123:viewer:user:alice"
        );
    }

    #[tokio::test]
    #[ignore = "Requires Redis server running at localhost:6379"]
    async fn test_redis_cache_integration() {
        // Use a dedicated key prefix and channel so the test never touches
        // or broadcasts about entries under the default `authz:` namespace.
        let mut config = RedisCacheConfig::default().with_key_prefix("authz-test:".to_string());
        config.invalidation_channel = "authz-test:invalidate".to_string();
        let mut cache = RedisCache::new(config).unwrap();

        // Try to connect, skip test if Redis is not available
        if cache.connect().await.is_err() {
            return;
        }

        let key = PermissionCacheKey::new(
            "document".to_string(),
            "123".to_string(),
            "viewer".to_string(),
            Subject::User("alice".to_string()),
        );

        // Set, then get: the value must round-trip.
        assert!(cache.set_permission(&key, true, None).await.is_ok());
        assert_eq!(cache.get_permission(&key).await.unwrap(), Some(true));

        // Single-key invalidation.
        assert!(cache.invalidate(&key).await.is_ok());
        assert!(cache.get_permission(&key).await.unwrap().is_none());

        // Pattern-based invalidations go through SCAN and must find the entry.
        assert!(cache.set_permission(&key, true, None).await.is_ok());
        assert!(cache.invalidate_resource("document", "123").await.is_ok());
        assert!(cache.get_permission(&key).await.unwrap().is_none());

        assert!(cache.set_permission(&key, false, None).await.is_ok());
        assert!(cache
            .invalidate_subject(&Subject::User("alice".to_string()))
            .await
            .is_ok());
        assert!(cache.get_permission(&key).await.unwrap().is_none());

        let stats = cache.stats().await;
        assert!(stats.is_ok());
    }
}