// better_auth_core/adapters/cache.rs
use std::collections::HashMap;
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};

use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};

use crate::error::{AuthError, AuthResult};

8#[async_trait]
10pub trait CacheAdapter: Send + Sync {
11 async fn set(&self, key: &str, value: &str, expires_in: Duration) -> AuthResult<()>;
13
14 async fn get(&self, key: &str) -> AuthResult<Option<String>>;
16
17 async fn delete(&self, key: &str) -> AuthResult<()>;
19
20 async fn exists(&self, key: &str) -> AuthResult<bool>;
22
23 async fn expire(&self, key: &str, expires_in: Duration) -> AuthResult<()>;
25
26 async fn clear(&self) -> AuthResult<()>;
28}
29
30pub struct MemoryCacheAdapter {
32 data: Arc<Mutex<HashMap<String, CacheEntry>>>,
33}
34
35#[derive(Debug, Clone)]
36struct CacheEntry {
37 value: String,
38 expires_at: DateTime<Utc>,
39}
40
41impl MemoryCacheAdapter {
42 pub fn new() -> Self {
43 Self {
44 data: Arc::new(Mutex::new(HashMap::new())),
45 }
46 }
47
48 fn cleanup_expired(&self) {
50 let mut data = self.data.lock().unwrap();
51 let now = Utc::now();
52 data.retain(|_, entry| entry.expires_at > now);
53 }
54}
55
56impl Default for MemoryCacheAdapter {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62#[async_trait]
63impl CacheAdapter for MemoryCacheAdapter {
64 async fn set(&self, key: &str, value: &str, expires_in: Duration) -> AuthResult<()> {
65 self.cleanup_expired();
66
67 let expires_at = Utc::now() + expires_in;
68 let entry = CacheEntry {
69 value: value.to_string(),
70 expires_at,
71 };
72
73 let mut data = self.data.lock().unwrap();
74 data.insert(key.to_string(), entry);
75
76 Ok(())
77 }
78
79 async fn get(&self, key: &str) -> AuthResult<Option<String>> {
80 self.cleanup_expired();
81
82 let data = self.data.lock().unwrap();
83 let now = Utc::now();
84
85 if let Some(entry) = data.get(key) {
86 if entry.expires_at > now {
87 Ok(Some(entry.value.clone()))
88 } else {
89 Ok(None)
90 }
91 } else {
92 Ok(None)
93 }
94 }
95
96 async fn delete(&self, key: &str) -> AuthResult<()> {
97 let mut data = self.data.lock().unwrap();
98 data.remove(key);
99 Ok(())
100 }
101
102 async fn exists(&self, key: &str) -> AuthResult<bool> {
103 self.cleanup_expired();
104
105 let data = self.data.lock().unwrap();
106 let now = Utc::now();
107
108 if let Some(entry) = data.get(key) {
109 Ok(entry.expires_at > now)
110 } else {
111 Ok(false)
112 }
113 }
114
115 async fn expire(&self, key: &str, expires_in: Duration) -> AuthResult<()> {
116 let mut data = self.data.lock().unwrap();
117
118 if let Some(entry) = data.get_mut(key) {
119 entry.expires_at = Utc::now() + expires_in;
120 }
121
122 Ok(())
123 }
124
125 async fn clear(&self) -> AuthResult<()> {
126 let mut data = self.data.lock().unwrap();
127 data.clear();
128 Ok(())
129 }
130}
131
#[cfg(feature = "redis-cache")]
pub mod redis_adapter {
    use super::*;
    use crate::error::AuthError;
    use redis::{Client, Commands};

    /// [`CacheAdapter`] backed by a Redis server.
    ///
    /// NOTE(review): every operation opens a new *blocking* connection with
    /// `Client::get_connection` and issues synchronous commands from inside an
    /// `async fn`. That stalls the executor thread for the whole round-trip.
    /// Moving to `redis::aio` (for example a multiplexed async connection)
    /// requires the redis crate's async feature, so confirm it is enabled first.
    pub struct RedisAdapter {
        client: Client,
    }

    impl RedisAdapter {
        /// Creates an adapter for `redis_url`.
        ///
        /// `Client::open` only validates the URL; no connection is made until
        /// the first cache operation.
        ///
        /// # Errors
        ///
        /// Returns the Redis error if `redis_url` cannot be parsed.
        pub async fn new(redis_url: &str) -> Result<Self, redis::RedisError> {
            let client = Client::open(redis_url)?;
            Ok(Self { client })
        }

        /// Opens a connection for a single cache operation.
        fn connection(&self) -> Result<redis::Connection, AuthError> {
            self.client
                .get_connection()
                .map_err(|e| AuthError::internal(format!("Redis connection error: {}", e)))
        }
    }

    #[async_trait]
    impl CacheAdapter for RedisAdapter {
        async fn set(&self, key: &str, value: &str, expires_in: Duration) -> AuthResult<()> {
            let mut conn = self.connection()?;

            // Use millisecond precision (`PSETEX`). `num_seconds()` truncates, so
            // a 500 ms TTL used to become `SETEX key 0`, which Redis rejects.
            let millis = u64::try_from(expires_in.num_milliseconds())
                .map_err(|_| AuthError::internal("Redis set_ex requires non-negative TTL"))?;
            let _: () = conn
                .pset_ex(key, value, millis)
                .map_err(|e| AuthError::internal(format!("Redis set error: {}", e)))?;

            Ok(())
        }

        async fn get(&self, key: &str) -> AuthResult<Option<String>> {
            let mut conn = self.connection()?;

            let result: Option<String> = conn
                .get(key)
                .map_err(|e| AuthError::internal(format!("Redis get error: {}", e)))?;

            Ok(result)
        }

        async fn delete(&self, key: &str) -> AuthResult<()> {
            let mut conn = self.connection()?;

            let _: usize = conn
                .del(key)
                .map_err(|e| AuthError::internal(format!("Redis delete error: {}", e)))?;

            Ok(())
        }

        async fn exists(&self, key: &str) -> AuthResult<bool> {
            let mut conn = self.connection()?;

            let exists: bool = conn
                .exists(key)
                .map_err(|e| AuthError::internal(format!("Redis exists error: {}", e)))?;

            Ok(exists)
        }

        async fn expire(&self, key: &str, expires_in: Duration) -> AuthResult<()> {
            let mut conn = self.connection()?;

            // Use `PEXPIRE` rather than `EXPIRE`. Truncating a sub-second TTL to
            // 0 seconds made Redis delete the key at once instead of keeping it.
            let _: bool = conn
                .pexpire(key, expires_in.num_milliseconds())
                .map_err(|e| AuthError::internal(format!("Redis expire error: {}", e)))?;

            Ok(())
        }

        /// NOTE(review): `FLUSHDB` wipes the *entire* selected Redis database,
        /// not only keys written through this adapter. Keep the cache in a
        /// dedicated database, or switch to prefixed keys with `SCAN` + `DEL`.
        async fn clear(&self) -> AuthResult<()> {
            let mut conn = self.connection()?;

            redis::cmd("FLUSHDB")
                .query::<()>(&mut conn)
                .map_err(|e| AuthError::internal(format!("Redis flushdb error: {}", e)))?;

            Ok(())
        }
    }
}

#[cfg(feature = "redis-cache")]
pub use redis_adapter::RedisAdapter;