1use crate::{RelationTuple, Subject};
10use deadpool_redis::{Config, Pool, Runtime};
11use redis::AsyncCommands;
12use serde::{Deserialize, Serialize};
13use std::sync::Arc;
14use std::time::Duration;
15use thiserror::Error;
16use tokio::sync::RwLock;
17
18#[derive(Error, Debug)]
19pub enum RedisCacheError {
20 #[error("Redis connection error: {0}")]
21 ConnectionError(String),
22
23 #[error("Serialization error: {0}")]
24 SerializationError(String),
25
26 #[error("Cache operation failed: {0}")]
27 OperationError(String),
28}
29
30pub type Result<T> = std::result::Result<T, RedisCacheError>;
31
/// Settings consumed by [`RedisCache`].
#[derive(Clone, Debug)]
pub struct RedisCacheConfig {
    /// Redis connection URL from which the connection pool is built.
    pub url: String,

    /// Expiry applied to a cached permission when the caller supplies no TTL.
    pub default_ttl: Duration,

    /// String prepended to every Redis key this cache builds.
    pub key_prefix: String,

    /// When `true`, invalidating a single key also publishes that key on
    /// `invalidation_channel`.
    pub enable_pubsub: bool,

    /// Pub/sub channel that receives invalidation messages.
    pub invalidation_channel: String,

    /// Requested size of the connection pool.
    pub pool_size: u32,
}
53
54impl Default for RedisCacheConfig {
55 fn default() -> Self {
56 Self {
57 url: "redis://localhost:6379".to_string(),
58 default_ttl: Duration::from_secs(300),
59 key_prefix: "authz:".to_string(),
60 enable_pubsub: true,
61 invalidation_channel: "authz:invalidate".to_string(),
62 pool_size: 10,
63 }
64 }
65}
66
67impl RedisCacheConfig {
68 pub fn with_url(mut self, url: String) -> Self {
69 self.url = url;
70 self
71 }
72
73 pub fn with_ttl(mut self, ttl: Duration) -> Self {
74 self.default_ttl = ttl;
75 self
76 }
77
78 pub fn with_key_prefix(mut self, prefix: String) -> Self {
79 self.key_prefix = prefix;
80 self
81 }
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct CachedPermission {
87 pub allowed: bool,
88 pub cached_at: i64, }
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct PermissionCacheKey {
94 pub resource_type: String,
95 pub resource_id: String,
96 pub relation: String,
97 pub subject: Subject,
98}
99
100impl PermissionCacheKey {
101 pub fn new(
102 resource_type: String,
103 resource_id: String,
104 relation: String,
105 subject: Subject,
106 ) -> Self {
107 Self {
108 resource_type,
109 resource_id,
110 relation,
111 subject,
112 }
113 }
114
115 pub fn to_redis_key(&self, prefix: &str) -> String {
117 format!(
118 "{}check:{}:{}:{}:{}",
119 prefix,
120 self.resource_type,
121 self.resource_id,
122 self.relation,
123 subject_to_key(&self.subject)
124 )
125 }
126}
127
128fn subject_to_key(subject: &Subject) -> String {
129 match subject {
130 Subject::User(id) => format!("user:{}", id),
131 Subject::UserSet {
132 namespace,
133 object_id,
134 relation,
135 } => {
136 format!("userset:{}:{}:{}", namespace, object_id, relation)
137 }
138 }
139}
140
141pub struct RedisCache {
143 config: RedisCacheConfig,
144 pool: Pool,
145 stats: Arc<RwLock<RedisCacheStatsInternal>>,
146}
147
/// Lookup counters kept in process memory; both start at zero.
#[derive(Default, Debug)]
struct RedisCacheStatsInternal {
    /// Lookups that found a cached entry.
    hit_count: u64,
    /// Lookups that found no cached entry.
    miss_count: u64,
}
154
155impl RedisCache {
156 pub fn new(config: RedisCacheConfig) -> Result<Self> {
158 let redis_config = Config::from_url(&config.url);
159 let pool = redis_config
160 .create_pool(Some(Runtime::Tokio1))
161 .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
162
163 Ok(Self {
164 config,
165 pool,
166 stats: Arc::new(RwLock::new(RedisCacheStatsInternal::default())),
167 })
168 }
169
170 pub async fn connect(&mut self) -> Result<()> {
172 let mut conn = self
173 .pool
174 .get()
175 .await
176 .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
177
178 redis::cmd("PING")
180 .query_async::<String>(&mut conn)
181 .await
182 .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
183
184 Ok(())
185 }
186
187 pub async fn get_permission(&self, key: &PermissionCacheKey) -> Result<Option<bool>> {
189 let mut conn = self
190 .pool
191 .get()
192 .await
193 .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
194
195 let redis_key = key.to_redis_key(&self.config.key_prefix);
196
197 let result: Option<String> = conn
198 .get(&redis_key)
199 .await
200 .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
201
202 match result {
203 Some(json_str) => {
204 let cached: CachedPermission = serde_json::from_str(&json_str)
205 .map_err(|e| RedisCacheError::SerializationError(e.to_string()))?;
206
207 let mut stats = self.stats.write().await;
209 stats.hit_count += 1;
210
211 Ok(Some(cached.allowed))
212 }
213 None => {
214 let mut stats = self.stats.write().await;
216 stats.miss_count += 1;
217
218 Ok(None)
219 }
220 }
221 }
222
223 pub async fn set_permission(
225 &self,
226 key: &PermissionCacheKey,
227 allowed: bool,
228 ttl: Option<Duration>,
229 ) -> Result<()> {
230 let mut conn = self
231 .pool
232 .get()
233 .await
234 .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
235
236 let redis_key = key.to_redis_key(&self.config.key_prefix);
237 let ttl_secs = ttl.unwrap_or(self.config.default_ttl).as_secs() as i64;
238
239 let cached = CachedPermission {
240 allowed,
241 cached_at: chrono::Utc::now().timestamp(),
242 };
243
244 let json_str = serde_json::to_string(&cached)
245 .map_err(|e| RedisCacheError::SerializationError(e.to_string()))?;
246
247 let _: () = conn
249 .set_ex(&redis_key, json_str, ttl_secs as u64)
250 .await
251 .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
252
253 Ok(())
254 }
255
256 pub async fn invalidate(&self, key: &PermissionCacheKey) -> Result<()> {
258 let mut conn = self
259 .pool
260 .get()
261 .await
262 .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
263
264 let redis_key = key.to_redis_key(&self.config.key_prefix);
265
266 let _: () = conn
267 .del(&redis_key)
268 .await
269 .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
270
271 if self.config.enable_pubsub {
273 let msg = serde_json::to_string(&key)
274 .map_err(|e| RedisCacheError::SerializationError(e.to_string()))?;
275
276 let _: () = conn
277 .publish(&self.config.invalidation_channel, msg)
278 .await
279 .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
280 }
281
282 Ok(())
283 }
284
285 pub async fn invalidate_subject(&self, subject: &Subject) -> Result<()> {
287 let mut conn = self
288 .pool
289 .get()
290 .await
291 .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
292
293 let subject_key = subject_to_key(subject);
294 let pattern = format!("{}check:*:*:*:{}", self.config.key_prefix, subject_key);
295
296 let keys: Vec<String> = redis::cmd("SCAN")
298 .arg(0)
299 .arg("MATCH")
300 .arg(&pattern)
301 .arg("COUNT")
302 .arg(100)
303 .query_async(&mut conn)
304 .await
305 .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
306
307 if !keys.is_empty() {
309 let _: () = conn
310 .del(&keys)
311 .await
312 .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
313 }
314
315 Ok(())
316 }
317
318 pub async fn invalidate_resource(&self, namespace: &str, object_id: &str) -> Result<()> {
320 let mut conn = self
321 .pool
322 .get()
323 .await
324 .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
325
326 let pattern = format!(
327 "{}check:{}:{}:*:*",
328 self.config.key_prefix, namespace, object_id
329 );
330
331 let keys: Vec<String> = redis::cmd("SCAN")
333 .arg(0)
334 .arg("MATCH")
335 .arg(&pattern)
336 .arg("COUNT")
337 .arg(100)
338 .query_async(&mut conn)
339 .await
340 .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
341
342 if !keys.is_empty() {
344 let _: () = conn
345 .del(&keys)
346 .await
347 .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
348 }
349
350 Ok(())
351 }
352
353 pub async fn invalidate_tuple(&self, tuple: &RelationTuple) -> Result<()> {
355 self.invalidate_resource(&tuple.namespace, &tuple.object_id)
357 .await?;
358
359 self.invalidate_subject(&tuple.subject).await?;
361
362 Ok(())
363 }
364
365 pub async fn clear_all(&self) -> Result<()> {
367 let mut conn = self
368 .pool
369 .get()
370 .await
371 .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
372
373 let pattern = format!("{}*", self.config.key_prefix);
374
375 let keys: Vec<String> = redis::cmd("SCAN")
377 .arg(0)
378 .arg("MATCH")
379 .arg(&pattern)
380 .arg("COUNT")
381 .arg(1000)
382 .query_async(&mut conn)
383 .await
384 .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
385
386 if !keys.is_empty() {
388 let _: () = conn
389 .del(&keys)
390 .await
391 .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
392 }
393
394 Ok(())
395 }
396
397 pub async fn stats(&self) -> Result<RedisCacheStats> {
399 let mut conn = self
400 .pool
401 .get()
402 .await
403 .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
404
405 let pattern = format!("{}*", self.config.key_prefix);
407 let keys: Vec<String> = redis::cmd("SCAN")
408 .arg(0)
409 .arg("MATCH")
410 .arg(&pattern)
411 .arg("COUNT")
412 .arg(1000)
413 .query_async(&mut conn)
414 .await
415 .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
416
417 let total_keys = keys.len();
418
419 let stats = self.stats.read().await;
421 let hit_count = stats.hit_count;
422 let miss_count = stats.miss_count;
423 let total_checks = hit_count + miss_count;
424 let hit_rate = if total_checks > 0 {
425 hit_count as f64 / total_checks as f64
426 } else {
427 0.0
428 };
429
430 let info: String = redis::cmd("INFO")
432 .arg("memory")
433 .query_async(&mut conn)
434 .await
435 .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
436
437 let memory_usage_bytes = info
439 .lines()
440 .find(|line| line.starts_with("used_memory:"))
441 .and_then(|line| line.split(':').nth(1))
442 .and_then(|val| val.trim().parse::<usize>().ok())
443 .unwrap_or(0);
444
445 Ok(RedisCacheStats {
446 total_keys,
447 hit_count,
448 miss_count,
449 hit_rate,
450 memory_usage_bytes,
451 })
452 }
453
454 pub async fn health_check(&self) -> Result<bool> {
456 let mut conn = self
457 .pool
458 .get()
459 .await
460 .map_err(|e| RedisCacheError::ConnectionError(e.to_string()))?;
461
462 let response: String = redis::cmd("PING")
464 .query_async(&mut conn)
465 .await
466 .map_err(|e| RedisCacheError::OperationError(e.to_string()))?;
467
468 Ok(response == "PONG")
469 }
470}
471
472#[derive(Debug, Clone, Serialize, Deserialize)]
474pub struct RedisCacheStats {
475 pub total_keys: usize,
476 pub hit_count: u64,
477 pub miss_count: u64,
478 pub hit_rate: f64,
479 pub memory_usage_bytes: usize,
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the key used by the tests below: `viewer` on `document:123` for user `alice`.
    fn sample_key() -> PermissionCacheKey {
        PermissionCacheKey::new(
            "document".to_string(),
            "123".to_string(),
            "viewer".to_string(),
            Subject::User("alice".to_string()),
        )
    }

    /// Building the cache from the default configuration succeeds.
    #[tokio::test]
    async fn test_redis_cache_creation() {
        let cache = RedisCache::new(RedisCacheConfig::default());
        assert!(cache.is_ok());
    }

    /// Pins the exact key layout; substring checks alone would also pass for a
    /// key with segments in the wrong order or with the wrong delimiter.
    #[test]
    fn test_permission_cache_key() {
        let redis_key = sample_key().to_redis_key("authz:");
        assert_eq!(redis_key, "authz:check:document:123:viewer:user:alice");
    }

    /// Round-trips set / get / invalidate / stats against a live server.
    #[tokio::test]
    #[ignore = "Requires Redis server running at localhost:6379"]
    async fn test_redis_cache_integration() {
        let mut cache = RedisCache::new(RedisCacheConfig::default()).unwrap();

        // Nothing to exercise when no server is reachable.
        if cache.connect().await.is_err() {
            return;
        }

        let key = sample_key();

        // Store a decision and read it back.
        assert!(cache.set_permission(&key, true, None).await.is_ok());
        assert_eq!(cache.get_permission(&key).await.unwrap(), Some(true));

        // After invalidation the entry must be gone.
        assert!(cache.invalidate(&key).await.is_ok());
        assert!(cache.get_permission(&key).await.unwrap().is_none());

        assert!(cache.stats().await.is_ok());
    }
}