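//! Session cache adapters (`better_auth_core/adapters/cache.rs`): the [`CacheAdapter`] trait,
//! an in-memory implementation, and a Redis implementation behind the `redis-cache` feature.
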
use std::collections::HashMap;
use std::sync::{Arc, Mutex, PoisonError};

use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};

use crate::error::AuthResult;

8/// Cache adapter trait for session caching
9#[async_trait]
10pub trait CacheAdapter: Send + Sync {
11    /// Set a value with expiration
12    async fn set(&self, key: &str, value: &str, expires_in: Duration) -> AuthResult<()>;
13
14    /// Get a value by key
15    async fn get(&self, key: &str) -> AuthResult<Option<String>>;
16
17    /// Delete a value by key
18    async fn delete(&self, key: &str) -> AuthResult<()>;
19
20    /// Check if key exists
21    async fn exists(&self, key: &str) -> AuthResult<bool>;
22
23    /// Set expiration for a key
24    async fn expire(&self, key: &str, expires_in: Duration) -> AuthResult<()>;
25
26    /// Clear all cached values
27    async fn clear(&self) -> AuthResult<()>;
28}
29
30/// In-memory cache adapter for testing and development
31pub struct MemoryCacheAdapter {
32    data: Arc<Mutex<HashMap<String, CacheEntry>>>,
33}
34
35#[derive(Debug, Clone)]
36struct CacheEntry {
37    value: String,
38    expires_at: DateTime<Utc>,
39}
40
41impl MemoryCacheAdapter {
42    pub fn new() -> Self {
43        Self {
44            data: Arc::new(Mutex::new(HashMap::new())),
45        }
46    }
47
48    /// Clean up expired entries
49    fn cleanup_expired(&self) {
50        let mut data = self.data.lock().unwrap();
51        let now = Utc::now();
52        data.retain(|_, entry| entry.expires_at > now);
53    }
54}
55
56impl Default for MemoryCacheAdapter {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62#[async_trait]
63impl CacheAdapter for MemoryCacheAdapter {
64    async fn set(&self, key: &str, value: &str, expires_in: Duration) -> AuthResult<()> {
65        self.cleanup_expired();
66
67        let expires_at = Utc::now() + expires_in;
68        let entry = CacheEntry {
69            value: value.to_string(),
70            expires_at,
71        };
72
73        let mut data = self.data.lock().unwrap();
74        data.insert(key.to_string(), entry);
75
76        Ok(())
77    }
78
79    async fn get(&self, key: &str) -> AuthResult<Option<String>> {
80        self.cleanup_expired();
81
82        let data = self.data.lock().unwrap();
83        let now = Utc::now();
84
85        if let Some(entry) = data.get(key) {
86            if entry.expires_at > now {
87                Ok(Some(entry.value.clone()))
88            } else {
89                Ok(None)
90            }
91        } else {
92            Ok(None)
93        }
94    }
95
96    async fn delete(&self, key: &str) -> AuthResult<()> {
97        let mut data = self.data.lock().unwrap();
98        data.remove(key);
99        Ok(())
100    }
101
102    async fn exists(&self, key: &str) -> AuthResult<bool> {
103        self.cleanup_expired();
104
105        let data = self.data.lock().unwrap();
106        let now = Utc::now();
107
108        if let Some(entry) = data.get(key) {
109            Ok(entry.expires_at > now)
110        } else {
111            Ok(false)
112        }
113    }
114
115    async fn expire(&self, key: &str, expires_in: Duration) -> AuthResult<()> {
116        let mut data = self.data.lock().unwrap();
117
118        if let Some(entry) = data.get_mut(key) {
119            entry.expires_at = Utc::now() + expires_in;
120        }
121
122        Ok(())
123    }
124
125    async fn clear(&self) -> AuthResult<()> {
126        let mut data = self.data.lock().unwrap();
127        data.clear();
128        Ok(())
129    }
130}
131
#[cfg(feature = "redis-cache")]
pub mod redis_adapter {
    use super::*;
    // `AuthError` is used by every error path below, but the parent module only imports
    // `AuthResult`, so the glob import above does not provide it.
    use crate::error::AuthError;
    use redis::{Client, Commands, Connection};

    /// [`CacheAdapter`] backed by a Redis server.
    ///
    /// NOTE(review): each method opens a fresh *synchronous* connection and issues
    /// blocking commands from inside an `async fn`. That stalls the executor thread for
    /// the whole round-trip. Consider a pooled or async (`redis::aio`) connection if
    /// this sits on a hot path.
    pub struct RedisAdapter {
        client: Client,
    }

    impl RedisAdapter {
        /// Creates an adapter for the Redis server at `redis_url`.
        ///
        /// # Errors
        ///
        /// Returns the [`redis::RedisError`] from [`Client::open`] when `redis_url` is
        /// rejected.
        pub async fn new(redis_url: &str) -> Result<Self, redis::RedisError> {
            let client = Client::open(redis_url)?;
            Ok(Self { client })
        }

        /// Opens a new connection, mapping failure to an internal auth error.
        fn connection(&self) -> AuthResult<Connection> {
            self.client
                .get_connection()
                .map_err(|e| AuthError::internal(format!("Redis connection error: {}", e)))
        }
    }

    /// Converts a TTL to the whole number of seconds Redis expects.
    ///
    /// Returns `None` for zero or negative TTLs; such an entry is already expired.
    /// A plain `num_seconds() as u64` cast would have two problems:
    /// - it wraps negative TTLs into near-infinite ones;
    /// - it truncates sub-second TTLs to 0, which Redis rejects as an invalid expire
    ///   time.
    ///
    /// Sub-second remainders are therefore rounded up.
    fn ttl_seconds(expires_in: Duration) -> Option<u64> {
        // `try_from` rejects negative values; zero is handled explicitly below.
        let millis = u64::try_from(expires_in.num_milliseconds()).ok()?;
        if millis == 0 {
            return None;
        }
        // `millis` fits in an i64, so adding 999 cannot overflow a u64.
        Some((millis + 999) / 1000)
    }

    // Unit-returning commands are bound as `let _: () = ...` so the reply type is
    // explicit. Otherwise it is inferred only through never-type fallback, which
    // newer editions no longer provide.
    #[async_trait]
    impl CacheAdapter for RedisAdapter {
        /// Stores `value` under `key` with `SETEX`.
        ///
        /// A non-positive TTL deletes the key instead, so no stale value survives.
        async fn set(&self, key: &str, value: &str, expires_in: Duration) -> AuthResult<()> {
            let mut conn = self.connection()?;

            match ttl_seconds(expires_in) {
                Some(seconds) => {
                    let _: () = conn
                        .set_ex(key, value, seconds)
                        .map_err(|e| AuthError::internal(format!("Redis set error: {}", e)))?;
                }
                None => {
                    let _: () = conn
                        .del(key)
                        .map_err(|e| AuthError::internal(format!("Redis set error: {}", e)))?;
                }
            }

            Ok(())
        }

        /// Fetches `key` with `GET`; a missing or expired key yields `None`.
        async fn get(&self, key: &str) -> AuthResult<Option<String>> {
            let mut conn = self.connection()?;

            let result: Option<String> = conn
                .get(key)
                .map_err(|e| AuthError::internal(format!("Redis get error: {}", e)))?;

            Ok(result)
        }

        /// Removes `key` with `DEL`.
        async fn delete(&self, key: &str) -> AuthResult<()> {
            let mut conn = self.connection()?;

            let _: () = conn
                .del(key)
                .map_err(|e| AuthError::internal(format!("Redis delete error: {}", e)))?;

            Ok(())
        }

        /// Checks `key` with `EXISTS`.
        async fn exists(&self, key: &str) -> AuthResult<bool> {
            let mut conn = self.connection()?;

            let exists: bool = conn
                .exists(key)
                .map_err(|e| AuthError::internal(format!("Redis exists error: {}", e)))?;

            Ok(exists)
        }

        /// Resets the TTL of `key` with `EXPIRE`.
        ///
        /// A non-positive TTL deletes the key, because such an entry is already expired.
        async fn expire(&self, key: &str, expires_in: Duration) -> AuthResult<()> {
            let mut conn = self.connection()?;

            match ttl_seconds(expires_in) {
                Some(seconds) => {
                    let _: () = conn
                        .expire(key, seconds)
                        .map_err(|e| AuthError::internal(format!("Redis expire error: {}", e)))?;
                }
                None => {
                    let _: () = conn
                        .del(key)
                        .map_err(|e| AuthError::internal(format!("Redis expire error: {}", e)))?;
                }
            }

            Ok(())
        }

        /// Empties the connected Redis logical database with `FLUSHDB`.
        ///
        /// NOTE(review): this is not scoped to cache keys. Anything else stored in the
        /// same database is deleted too.
        async fn clear(&self) -> AuthResult<()> {
            let mut conn = self.connection()?;

            // `Cmd::execute` panics when the command fails; `query` lets the error propagate.
            let _: () = redis::cmd("FLUSHDB")
                .query(&mut conn)
                .map_err(|e| AuthError::internal(format!("Redis clear error: {}", e)))?;

            Ok(())
        }
    }
}

#[cfg(feature = "redis-cache")]
pub use redis_adapter::RedisAdapter;