// better_auth_core/adapters/cache.rs

use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use std::collections::HashMap;
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};

use crate::error::AuthResult;

8/// Cache adapter trait for session caching
9#[async_trait]
10pub trait CacheAdapter: Send + Sync {
11    /// Set a value with expiration
12    async fn set(&self, key: &str, value: &str, expires_in: Duration) -> AuthResult<()>;
13
14    /// Get a value by key
15    async fn get(&self, key: &str) -> AuthResult<Option<String>>;
16
17    /// Delete a value by key
18    async fn delete(&self, key: &str) -> AuthResult<()>;
19
20    /// Check if key exists
21    async fn exists(&self, key: &str) -> AuthResult<bool>;
22
23    /// Set expiration for a key
24    async fn expire(&self, key: &str, expires_in: Duration) -> AuthResult<()>;
25
26    /// Clear all cached values
27    async fn clear(&self) -> AuthResult<()>;
28}
29
30/// In-memory cache adapter for testing and development
31pub struct MemoryCacheAdapter {
32    data: Arc<Mutex<HashMap<String, CacheEntry>>>,
33}
34
35#[derive(Debug, Clone)]
36struct CacheEntry {
37    value: String,
38    expires_at: DateTime<Utc>,
39}
40
41impl MemoryCacheAdapter {
42    pub fn new() -> Self {
43        Self {
44            data: Arc::new(Mutex::new(HashMap::new())),
45        }
46    }
47
48    /// Clean up expired entries
49    fn cleanup_expired(&self) {
50        let mut data = self.data.lock().unwrap();
51        let now = Utc::now();
52        data.retain(|_, entry| entry.expires_at > now);
53    }
54}
55
56impl Default for MemoryCacheAdapter {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62#[async_trait]
63impl CacheAdapter for MemoryCacheAdapter {
64    async fn set(&self, key: &str, value: &str, expires_in: Duration) -> AuthResult<()> {
65        self.cleanup_expired();
66
67        let expires_at = Utc::now() + expires_in;
68        let entry = CacheEntry {
69            value: value.to_string(),
70            expires_at,
71        };
72
73        let mut data = self.data.lock().unwrap();
74        data.insert(key.to_string(), entry);
75
76        Ok(())
77    }
78
79    async fn get(&self, key: &str) -> AuthResult<Option<String>> {
80        self.cleanup_expired();
81
82        let data = self.data.lock().unwrap();
83        let now = Utc::now();
84
85        if let Some(entry) = data.get(key) {
86            if entry.expires_at > now {
87                Ok(Some(entry.value.clone()))
88            } else {
89                Ok(None)
90            }
91        } else {
92            Ok(None)
93        }
94    }
95
96    async fn delete(&self, key: &str) -> AuthResult<()> {
97        let mut data = self.data.lock().unwrap();
98        data.remove(key);
99        Ok(())
100    }
101
102    async fn exists(&self, key: &str) -> AuthResult<bool> {
103        self.cleanup_expired();
104
105        let data = self.data.lock().unwrap();
106        let now = Utc::now();
107
108        if let Some(entry) = data.get(key) {
109            Ok(entry.expires_at > now)
110        } else {
111            Ok(false)
112        }
113    }
114
115    async fn expire(&self, key: &str, expires_in: Duration) -> AuthResult<()> {
116        let mut data = self.data.lock().unwrap();
117
118        if let Some(entry) = data.get_mut(key) {
119            entry.expires_at = Utc::now() + expires_in;
120        }
121
122        Ok(())
123    }
124
125    async fn clear(&self) -> AuthResult<()> {
126        let mut data = self.data.lock().unwrap();
127        data.clear();
128        Ok(())
129    }
130}
131
/// Redis-backed implementation of `CacheAdapter`, enabled by the `redis-cache` feature.
#[cfg(feature = "redis-cache")]
pub mod redis_adapter {
    use super::*;
    use crate::error::AuthError;
    use redis::{Client, Commands, Connection};

    /// Cache adapter that stores entries in Redis.
    ///
    /// Every call opens a fresh connection with [`Client::get_connection`] and talks to
    /// Redis through redis-rs's synchronous [`Commands`] API. Each operation therefore
    /// blocks the calling thread until Redis replies.
    /// NOTE(review): inside an async runtime this ties up a worker thread; consider an
    /// async (multiplexed) connection if the crate's redis features allow it.
    pub struct RedisAdapter {
        client: Client,
    }

    impl RedisAdapter {
        /// Creates an adapter for the server at `redis_url`.
        ///
        /// # Errors
        ///
        /// Returns the error produced by [`Client::open`] when `redis_url` is not a
        /// usable Redis connection URL.
        pub async fn new(redis_url: &str) -> Result<Self, redis::RedisError> {
            let client = Client::open(redis_url)?;
            Ok(Self { client })
        }

        /// Opens a new connection, mapping failures to an internal `AuthError`.
        fn connection(&self) -> Result<Connection, AuthError> {
            self.client
                .get_connection()
                .map_err(|e| AuthError::internal(format!("Redis connection error: {}", e)))
        }
    }

    #[async_trait]
    impl CacheAdapter for RedisAdapter {
        /// Stores `value` with `PSETEX`, so the TTL keeps millisecond precision.
        ///
        /// # Errors
        ///
        /// Fails when `expires_in` is shorter than one millisecond, since Redis rejects
        /// non-positive expiry times. Also fails when the connection or command fails.
        async fn set(&self, key: &str, value: &str, expires_in: Duration) -> AuthResult<()> {
            // Whole seconds (SETEX) truncated e.g. 1.5s to 1s and anything under 1s to
            // an invalid 0; milliseconds keep the TTL the caller asked for. Validate
            // before connecting so bad input costs no round trip.
            let millis = u64::try_from(expires_in.num_milliseconds())
                .ok()
                .filter(|&ms| ms > 0)
                .ok_or_else(|| AuthError::internal("Redis set requires a positive TTL"))?;

            let mut conn = self.connection()?;
            redis::cmd("PSETEX")
                .arg(key)
                .arg(millis)
                .arg(value)
                .query::<()>(&mut conn)
                .map_err(|e| AuthError::internal(format!("Redis set error: {}", e)))?;

            Ok(())
        }

        async fn get(&self, key: &str) -> AuthResult<Option<String>> {
            let mut conn = self.connection()?;

            let result: Option<String> = conn
                .get(key)
                .map_err(|e| AuthError::internal(format!("Redis get error: {}", e)))?;

            Ok(result)
        }

        async fn delete(&self, key: &str) -> AuthResult<()> {
            let mut conn = self.connection()?;

            let _: usize = conn
                .del(key)
                .map_err(|e| AuthError::internal(format!("Redis delete error: {}", e)))?;

            Ok(())
        }

        async fn exists(&self, key: &str) -> AuthResult<bool> {
            let mut conn = self.connection()?;

            let exists: bool = conn
                .exists(key)
                .map_err(|e| AuthError::internal(format!("Redis exists error: {}", e)))?;

            Ok(exists)
        }

        /// Resets the TTL of `key` with `PEXPIRE`; per Redis semantics a non-positive
        /// `expires_in` deletes the key.
        async fn expire(&self, key: &str, expires_in: Duration) -> AuthResult<()> {
            let mut conn = self.connection()?;

            // Millisecond precision: EXPIRE with whole seconds turned e.g. 500ms into 0,
            // which deleted the key instead of keeping it for half a second.
            let _: bool = redis::cmd("PEXPIRE")
                .arg(key)
                .arg(expires_in.num_milliseconds())
                .query(&mut conn)
                .map_err(|e| AuthError::internal(format!("Redis expire error: {}", e)))?;

            Ok(())
        }

        /// Removes every key with `FLUSHDB`.
        ///
        /// Warning: this empties the *entire* selected Redis database, including keys
        /// that were not written by this adapter. Point the adapter at a dedicated
        /// database if the Redis instance is shared.
        async fn clear(&self) -> AuthResult<()> {
            let mut conn = self.connection()?;

            redis::cmd("FLUSHDB")
                .query::<()>(&mut conn)
                .map_err(|e| AuthError::internal(format!("Redis flushdb error: {}", e)))?;

            Ok(())
        }
    }
}

#[cfg(feature = "redis-cache")]
pub use redis_adapter::RedisAdapter;