1use crate::error::CacheError;
4use async_trait::async_trait;
5use redis::aio::ConnectionManager;
6use redis::AsyncCommands;
7use serde::{Deserialize, Serialize};
8use std::time::Duration;
9
10#[async_trait]
12pub trait Cache: Send + Sync {
13 async fn get<T>(&self, key: &str) -> Result<Option<T>, CacheError>
15 where
16 T: for<'de> Deserialize<'de>;
17
18 async fn set<T>(&self, key: &str, value: &T, ttl: Option<Duration>) -> Result<(), CacheError>
20 where
21 T: Serialize + Sync;
22
23 async fn delete(&self, key: &str) -> Result<(), CacheError>;
25
26 async fn exists(&self, key: &str) -> Result<bool, CacheError>;
28
29 async fn invalidate_pattern(&self, pattern: &str) -> Result<(), CacheError>;
31
32 async fn clear(&self) -> Result<(), CacheError>;
34}
35
36pub struct CacheClient {
38 connection: ConnectionManager,
39}
40
41impl CacheClient {
42 pub async fn new(url: &str) -> Result<Self, CacheError> {
44 let client =
45 redis::Client::open(url).map_err(|e| CacheError::ConnectionError(e.to_string()))?;
46 let connection = ConnectionManager::new(client)
47 .await
48 .map_err(|e| CacheError::ConnectionError(e.to_string()))?;
49
50 Ok(Self { connection })
51 }
52
53 pub async fn default() -> Result<Self, CacheError> {
55 Self::new("redis://127.0.0.1:6379").await
56 }
57}
58
59#[async_trait]
60impl Cache for CacheClient {
61 async fn get<T>(&self, key: &str) -> Result<Option<T>, CacheError>
62 where
63 T: for<'de> Deserialize<'de>,
64 {
65 let mut conn = self.connection.clone();
66 let value: Option<String> = conn
67 .get(key)
68 .await
69 .map_err(|e: redis::RedisError| CacheError::from(e))?;
70
71 match value {
72 Some(v) => {
73 let deserialized: T = serde_json::from_str(&v)?;
74 Ok(Some(deserialized))
75 }
76 None => Ok(None),
77 }
78 }
79
80 async fn set<T>(&self, key: &str, value: &T, ttl: Option<Duration>) -> Result<(), CacheError>
81 where
82 T: Serialize + Sync,
83 {
84 let mut conn = self.connection.clone();
85 let serialized = serde_json::to_string(value)?;
86
87 if let Some(ttl) = ttl {
88 conn.set_ex::<_, _, ()>(key, serialized, ttl.as_secs())
89 .await?;
90 } else {
91 conn.set::<_, _, ()>(key, serialized).await?;
92 }
93
94 Ok(())
95 }
96
97 async fn delete(&self, key: &str) -> Result<(), CacheError> {
98 let mut conn = self.connection.clone();
99 conn.del::<_, ()>(key).await?;
100 Ok(())
101 }
102
103 async fn exists(&self, key: &str) -> Result<bool, CacheError> {
104 let mut conn = self.connection.clone();
105 let exists: bool = conn.exists(key).await?;
106 Ok(exists)
107 }
108
109 async fn invalidate_pattern(&self, pattern: &str) -> Result<(), CacheError> {
110 let mut conn = self.connection.clone();
111 let keys: Vec<String> = conn.keys(pattern).await?;
112 if !keys.is_empty() {
113 conn.del::<_, ()>(keys).await?;
114 }
115 Ok(())
116 }
117
118 async fn clear(&self) -> Result<(), CacheError> {
119 let mut conn = self.connection.clone();
120 redis::cmd("FLUSHDB").query_async::<()>(&mut conn).await?;
121 Ok(())
122 }
123}