// better_auth_core/adapters/cache.rs

use std::collections::HashMap;
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};

use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};

use crate::error::AuthResult;

8#[async_trait]
10pub trait CacheAdapter: Send + Sync {
11 async fn set(&self, key: &str, value: &str, expires_in: Duration) -> AuthResult<()>;
13
14 async fn get(&self, key: &str) -> AuthResult<Option<String>>;
16
17 async fn delete(&self, key: &str) -> AuthResult<()>;
19
20 async fn exists(&self, key: &str) -> AuthResult<bool>;
22
23 async fn expire(&self, key: &str, expires_in: Duration) -> AuthResult<()>;
25
26 async fn clear(&self) -> AuthResult<()>;
28}
29
30pub struct MemoryCacheAdapter {
32 data: Arc<Mutex<HashMap<String, CacheEntry>>>,
33}
34
35#[derive(Debug, Clone)]
36struct CacheEntry {
37 value: String,
38 expires_at: DateTime<Utc>,
39}
40
41impl MemoryCacheAdapter {
42 pub fn new() -> Self {
43 Self {
44 data: Arc::new(Mutex::new(HashMap::new())),
45 }
46 }
47
48 fn cleanup_expired(&self) {
50 let mut data = self.data.lock().unwrap();
51 let now = Utc::now();
52 data.retain(|_, entry| entry.expires_at > now);
53 }
54}
55
56impl Default for MemoryCacheAdapter {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62#[async_trait]
63impl CacheAdapter for MemoryCacheAdapter {
64 async fn set(&self, key: &str, value: &str, expires_in: Duration) -> AuthResult<()> {
65 self.cleanup_expired();
66
67 let expires_at = Utc::now() + expires_in;
68 let entry = CacheEntry {
69 value: value.to_string(),
70 expires_at,
71 };
72
73 let mut data = self.data.lock().unwrap();
74 data.insert(key.to_string(), entry);
75
76 Ok(())
77 }
78
79 async fn get(&self, key: &str) -> AuthResult<Option<String>> {
80 self.cleanup_expired();
81
82 let data = self.data.lock().unwrap();
83 let now = Utc::now();
84
85 if let Some(entry) = data.get(key) {
86 if entry.expires_at > now {
87 Ok(Some(entry.value.clone()))
88 } else {
89 Ok(None)
90 }
91 } else {
92 Ok(None)
93 }
94 }
95
96 async fn delete(&self, key: &str) -> AuthResult<()> {
97 let mut data = self.data.lock().unwrap();
98 data.remove(key);
99 Ok(())
100 }
101
102 async fn exists(&self, key: &str) -> AuthResult<bool> {
103 self.cleanup_expired();
104
105 let data = self.data.lock().unwrap();
106 let now = Utc::now();
107
108 if let Some(entry) = data.get(key) {
109 Ok(entry.expires_at > now)
110 } else {
111 Ok(false)
112 }
113 }
114
115 async fn expire(&self, key: &str, expires_in: Duration) -> AuthResult<()> {
116 let mut data = self.data.lock().unwrap();
117
118 if let Some(entry) = data.get_mut(key) {
119 entry.expires_at = Utc::now() + expires_in;
120 }
121
122 Ok(())
123 }
124
125 async fn clear(&self) -> AuthResult<()> {
126 let mut data = self.data.lock().unwrap();
127 data.clear();
128 Ok(())
129 }
130}
131
#[cfg(feature = "redis-cache")]
pub mod redis_adapter {
    //! Redis-backed implementation of [`CacheAdapter`], compiled only with the
    //! `redis-cache` feature.

    use super::*;
    use crate::error::AuthError;
    use redis::{Client, Commands, Connection};

    /// [`CacheAdapter`] that stores entries in Redis with server-side TTLs.
    ///
    /// NOTE(review): every call opens a new *blocking* connection through
    /// `Client::get_connection` from inside an `async fn`. This ties up the
    /// executor thread for the whole round trip. Consider `redis::aio` or a
    /// pooled connection if this is on a hot path.
    pub struct RedisAdapter {
        client: Client,
    }

    impl RedisAdapter {
        /// Creates an adapter for `redis_url`.
        ///
        /// # Errors
        ///
        /// Returns the error produced by `Client::open`, for instance when
        /// `redis_url` cannot be parsed.
        pub async fn new(redis_url: &str) -> Result<Self, redis::RedisError> {
            let client = Client::open(redis_url)?;
            Ok(Self { client })
        }

        /// Opens a fresh connection for a single command.
        fn connection(&self) -> AuthResult<Connection> {
            let conn = self
                .client
                .get_connection()
                .map_err(|e| AuthError::internal(format!("Redis connection error: {}", e)))?;
            Ok(conn)
        }
    }

    #[async_trait]
    impl CacheAdapter for RedisAdapter {
        async fn set(&self, key: &str, value: &str, expires_in: Duration) -> AuthResult<()> {
            let mut conn = self.connection()?;
            let millis = expires_in.num_milliseconds();

            if millis <= 0 {
                // PSETEX rejects a non-positive TTL. A value that is already
                // expired cannot be told apart from a missing key (the
                // in-memory adapter never returns one either), so drop the key.
                let _: () = conn
                    .del(key)
                    .map_err(|e| AuthError::internal(format!("Redis set error: {}", e)))?;
                return Ok(());
            }

            // Millisecond precision stops sub-second TTLs from truncating to 0.
            redis::cmd("PSETEX")
                .arg(key)
                .arg(millis)
                .arg(value)
                .query::<()>(&mut conn)
                .map_err(|e| AuthError::internal(format!("Redis set error: {}", e)))?;
            Ok(())
        }

        async fn get(&self, key: &str) -> AuthResult<Option<String>> {
            let mut conn = self.connection()?;
            let value: Option<String> = conn
                .get(key)
                .map_err(|e| AuthError::internal(format!("Redis get error: {}", e)))?;
            Ok(value)
        }

        async fn delete(&self, key: &str) -> AuthResult<()> {
            let mut conn = self.connection()?;
            let _: () = conn
                .del(key)
                .map_err(|e| AuthError::internal(format!("Redis delete error: {}", e)))?;
            Ok(())
        }

        async fn exists(&self, key: &str) -> AuthResult<bool> {
            let mut conn = self.connection()?;
            let exists: bool = conn
                .exists(key)
                .map_err(|e| AuthError::internal(format!("Redis exists error: {}", e)))?;
            Ok(exists)
        }

        async fn expire(&self, key: &str, expires_in: Duration) -> AuthResult<()> {
            let mut conn = self.connection()?;
            // PEXPIRE takes a signed TTL and deletes the key when the TTL is
            // not positive. Negative and sub-second durations therefore need no
            // special casing, and a missing key is simply left alone.
            redis::cmd("PEXPIRE")
                .arg(key)
                .arg(expires_in.num_milliseconds())
                .query::<()>(&mut conn)
                .map_err(|e| AuthError::internal(format!("Redis expire error: {}", e)))?;
            Ok(())
        }

        async fn clear(&self) -> AuthResult<()> {
            let mut conn = self.connection()?;
            // NOTE(review): FLUSHDB removes *every* key in the selected Redis
            // database, not only keys written through this adapter. Point the
            // adapter at a dedicated database if anything else is stored there.
            redis::cmd("FLUSHDB")
                .query::<()>(&mut conn)
                .map_err(|e| AuthError::internal(format!("Redis clear error: {}", e)))?;
            Ok(())
        }
    }
}

#[cfg(feature = "redis-cache")]
pub use redis_adapter::RedisAdapter;