1use crate::error::{AuthError, Result};
4use crate::models::ApiKey;
5use argon2::{
6 password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
7 Argon2,
8};
9use chrono::{DateTime, Duration, Utc};
10use rand::{distributions::Alphanumeric, Rng};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use uuid::Uuid;
15
16pub struct ApiKeyManager {
17 keys: Arc<RwLock<HashMap<String, ApiKey>>>,
18 rate_limiter: Arc<RateLimiter>,
19 argon2: Argon2<'static>,
20}
21
22impl ApiKeyManager {
23 pub fn new() -> Self {
24 Self {
25 keys: Arc::new(RwLock::new(HashMap::new())),
26 rate_limiter: Arc::new(RateLimiter::new()),
27 argon2: Argon2::default(),
28 }
29 }
30
31 pub async fn generate_api_key(
32 &self,
33 user_id: Uuid,
34 name: String,
35 description: Option<String>,
36 scopes: Vec<String>,
37 rate_limit: Option<u32>,
38 expires_in: Option<Duration>,
39 ) -> Result<(String, ApiKey)> {
40 let key = self.generate_key_string();
42 let prefix = &key[..8]; let salt = SaltString::generate(&mut rand::thread_rng());
46 let key_hash = self.argon2
47 .hash_password(key.as_bytes(), &salt)?
48 .to_string();
49
50 let now = Utc::now();
51 let api_key = ApiKey {
52 id: Uuid::new_v4(),
53 key_hash,
54 prefix: prefix.to_string(),
55 user_id,
56 name,
57 description,
58 scopes,
59 rate_limit,
60 expires_at: expires_in.map(|d| now + d),
61 last_used_at: None,
62 created_at: now,
63 revoked: false,
64 };
65
66 let mut keys = self.keys.write().await;
67 keys.insert(api_key.id.to_string(), api_key.clone());
68
69 Ok((key, api_key))
70 }
71
72 fn generate_key_string(&self) -> String {
73 let random_part: String = rand::thread_rng()
74 .sample_iter(&Alphanumeric)
75 .take(48)
76 .map(char::from)
77 .collect();
78
79 format!("avl_{}", random_part)
80 }
81
82 pub async fn verify_api_key(&self, key: &str) -> Result<ApiKey> {
83 if !key.starts_with("avl_") {
84 return Err(AuthError::InvalidApiKey);
85 }
86
87 let prefix = &key[..8.min(key.len())];
88
89 let keys = self.keys.read().await;
90
91 for api_key in keys.values() {
93 if !api_key.prefix.starts_with(prefix) {
94 continue;
95 }
96
97 if api_key.revoked {
98 return Err(AuthError::InvalidApiKey);
99 }
100
101 if let Some(expires_at) = api_key.expires_at {
102 if Utc::now() > expires_at {
103 return Err(AuthError::ApiKeyExpired);
104 }
105 }
106
107 let parsed_hash = PasswordHash::new(&api_key.key_hash)
109 .map_err(|e| AuthError::CryptoError(e.to_string()))?;
110
111 if self.argon2.verify_password(key.as_bytes(), &parsed_hash).is_ok() {
112 if let Some(limit) = api_key.rate_limit {
114 if !self.rate_limiter.check_limit(&api_key.id.to_string(), limit).await {
115 return Err(AuthError::RateLimitExceeded);
116 }
117 }
118
119 return Ok(api_key.clone());
120 }
121 }
122
123 Err(AuthError::InvalidApiKey)
124 }
125
126 pub async fn revoke_api_key(&self, key_id: &Uuid) -> Result<()> {
127 let mut keys = self.keys.write().await;
128
129 if let Some(api_key) = keys.get_mut(&key_id.to_string()) {
130 api_key.revoked = true;
131 Ok(())
132 } else {
133 Err(AuthError::InvalidApiKey)
134 }
135 }
136
137 pub async fn update_last_used(&self, key_id: &Uuid) -> Result<()> {
138 let mut keys = self.keys.write().await;
139
140 if let Some(api_key) = keys.get_mut(&key_id.to_string()) {
141 api_key.last_used_at = Some(Utc::now());
142 Ok(())
143 } else {
144 Err(AuthError::InvalidApiKey)
145 }
146 }
147
148 pub async fn list_user_keys(&self, user_id: &Uuid) -> Vec<ApiKey> {
149 let keys = self.keys.read().await;
150 keys.values()
151 .filter(|k| k.user_id == *user_id && !k.revoked)
152 .cloned()
153 .collect()
154 }
155
156 pub async fn rotate_api_key(&self, old_key_id: &Uuid) -> Result<(String, ApiKey)> {
157 let (user_id, name, description, scopes, rate_limit, expires_duration) = {
159 let keys = self.keys.read().await;
160
161 let old_key = keys
162 .get(&old_key_id.to_string())
163 .ok_or(AuthError::InvalidApiKey)?;
164
165 (
166 old_key.user_id,
167 old_key.name.clone(),
168 old_key.description.clone(),
169 old_key.scopes.clone(),
170 old_key.rate_limit,
171 old_key.expires_at.map(|exp| exp - Utc::now()),
172 )
173 };
174
175 let (new_key, new_api_key) = self.generate_api_key(
177 user_id,
178 name,
179 description,
180 scopes,
181 rate_limit,
182 expires_duration,
183 ).await?;
184
185 self.revoke_api_key(old_key_id).await?;
187
188 Ok((new_key, new_api_key))
189 }
190}
191
192struct RateLimiter {
195 buckets: Arc<RwLock<HashMap<String, TokenBucket>>>,
196}
197
198struct TokenBucket {
199 tokens: f64,
200 capacity: f64,
201 refill_rate: f64,
202 last_refill: DateTime<Utc>,
203}
204
205impl RateLimiter {
206 fn new() -> Self {
207 Self {
208 buckets: Arc::new(RwLock::new(HashMap::new())),
209 }
210 }
211
212 async fn check_limit(&self, key: &str, limit: u32) -> bool {
213 let mut buckets = self.buckets.write().await;
214
215 let bucket = buckets.entry(key.to_string()).or_insert_with(|| {
216 TokenBucket {
217 tokens: limit as f64,
218 capacity: limit as f64,
219 refill_rate: limit as f64 / 60.0, last_refill: Utc::now(),
221 }
222 });
223
224 let now = Utc::now();
226 let elapsed = (now - bucket.last_refill).num_seconds() as f64;
227 bucket.tokens = (bucket.tokens + elapsed * bucket.refill_rate).min(bucket.capacity);
228 bucket.last_refill = now;
229
230 if bucket.tokens >= 1.0 {
232 bucket.tokens -= 1.0;
233 true
234 } else {
235 false
236 }
237 }
238
239 async fn reset(&self, key: &str) {
240 let mut buckets = self.buckets.write().await;
241 buckets.remove(key);
242 }
243
244 async fn get_remaining(&self, key: &str) -> Option<u32> {
245 let buckets = self.buckets.read().await;
246 buckets.get(key).map(|b| b.tokens as u32)
247 }
248}
249
250impl Default for ApiKeyManager {
251 fn default() -> Self {
252 Self::new()
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly generated key carries the `avl_` tag and verifies back to
    /// the user it was issued for.
    #[tokio::test]
    async fn test_generate_and_verify_api_key() {
        let owner = Uuid::new_v4();
        let manager = ApiKeyManager::new();

        let issued = manager
            .generate_api_key(
                owner,
                "Test Key".to_string(),
                None,
                vec!["read".to_string()],
                Some(100),
                None,
            )
            .await;
        let (key, _api_key) = issued.unwrap();

        assert!(key.starts_with("avl_"));

        let verified = manager.verify_api_key(&key).await.unwrap();
        assert_eq!(verified.user_id, owner);
    }

    /// Once a key has been revoked, verifying its plaintext must fail.
    #[tokio::test]
    async fn test_revoke_api_key() {
        let owner = Uuid::new_v4();
        let manager = ApiKeyManager::new();

        let issued = manager
            .generate_api_key(owner, "Test Key".to_string(), None, vec![], None, None)
            .await;
        let (key, api_key) = issued.unwrap();

        manager.revoke_api_key(&api_key.id).await.unwrap();

        let outcome = manager.verify_api_key(&key).await;
        assert!(outcome.is_err());
    }
}