1use async_trait::async_trait;
2use base64::Engine;
3use base64::engine::general_purpose::URL_SAFE_NO_PAD;
4use rand::RngCore;
5use serde::{Deserialize, Serialize};
6use sha2::{Digest, Sha256};
7use validator::Validate;
8
9use better_auth_core::adapters::DatabaseAdapter;
10use better_auth_core::entity::{AuthApiKey, AuthSession, AuthUser};
11use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
12use better_auth_core::{AuthError, AuthResult};
13use better_auth_core::{AuthRequest, AuthResponse, CreateApiKey, HttpMethod, UpdateApiKey};
14
15pub struct ApiKeyPlugin {
17 config: ApiKeyConfig,
18}
19
/// Settings controlling how the API-key plugin generates and initializes keys.
#[derive(Debug, Clone)]
pub struct ApiKeyConfig {
    /// Number of random bytes drawn per key, before base64url encoding.
    pub key_length: usize,
    /// Prefix prepended to generated keys unless a create request supplies its own.
    pub prefix: Option<String>,
    /// `remaining` value assigned to new keys whose create request omits one.
    pub default_remaining: Option<i64>,
}

impl Default for ApiKeyConfig {
    /// 32 random bytes, no prefix, and no default `remaining`.
    fn default() -> Self {
        let key_length = 32;
        Self {
            prefix: None,
            default_remaining: None,
            key_length,
        }
    }
}
36
37#[derive(Debug, Deserialize, Validate)]
40struct CreateKeyRequest {
41 name: Option<String>,
42 prefix: Option<String>,
43 #[serde(rename = "expiresIn")]
44 #[validate(range(min = 1, message = "expiresIn must be greater than 0"))]
45 expires_in: Option<i64>,
46 remaining: Option<i64>,
47 #[serde(rename = "rateLimitEnabled")]
48 rate_limit_enabled: Option<bool>,
49 #[serde(rename = "rateLimitTimeWindow")]
50 rate_limit_time_window: Option<i64>,
51 #[serde(rename = "rateLimitMax")]
52 rate_limit_max: Option<i64>,
53 #[serde(rename = "refillInterval")]
54 refill_interval: Option<i64>,
55 #[serde(rename = "refillAmount")]
56 refill_amount: Option<i64>,
57 permissions: Option<serde_json::Value>,
58 metadata: Option<serde_json::Value>,
59}
60
61#[derive(Debug, Deserialize, Validate)]
62struct UpdateKeyRequest {
63 #[validate(length(min = 1, message = "Key ID is required"))]
64 id: String,
65 name: Option<String>,
66 enabled: Option<bool>,
67 remaining: Option<i64>,
68 #[serde(rename = "rateLimitEnabled")]
69 rate_limit_enabled: Option<bool>,
70 #[serde(rename = "rateLimitTimeWindow")]
71 rate_limit_time_window: Option<i64>,
72 #[serde(rename = "rateLimitMax")]
73 rate_limit_max: Option<i64>,
74 #[serde(rename = "refillInterval")]
75 refill_interval: Option<i64>,
76 #[serde(rename = "refillAmount")]
77 refill_amount: Option<i64>,
78 permissions: Option<serde_json::Value>,
79 metadata: Option<serde_json::Value>,
80}
81
82#[derive(Debug, Deserialize, Validate)]
83struct DeleteKeyRequest {
84 #[validate(length(min = 1, message = "Key ID is required"))]
85 id: String,
86}
87
88#[derive(Debug, Serialize)]
91struct ApiKeyView {
92 id: String,
93 name: Option<String>,
94 start: Option<String>,
95 prefix: Option<String>,
96 #[serde(rename = "userId")]
97 user_id: String,
98 #[serde(rename = "refillInterval")]
99 refill_interval: Option<i64>,
100 #[serde(rename = "refillAmount")]
101 refill_amount: Option<i64>,
102 #[serde(rename = "lastRefillAt")]
103 last_refill_at: Option<String>,
104 enabled: bool,
105 #[serde(rename = "rateLimitEnabled")]
106 rate_limit_enabled: bool,
107 #[serde(rename = "rateLimitTimeWindow")]
108 rate_limit_time_window: Option<i64>,
109 #[serde(rename = "rateLimitMax")]
110 rate_limit_max: Option<i64>,
111 #[serde(rename = "requestCount")]
112 request_count: Option<i64>,
113 remaining: Option<i64>,
114 #[serde(rename = "lastRequest")]
115 last_request: Option<String>,
116 #[serde(rename = "expiresAt")]
117 expires_at: Option<String>,
118 #[serde(rename = "createdAt")]
119 created_at: String,
120 #[serde(rename = "updatedAt")]
121 updated_at: String,
122 permissions: Option<serde_json::Value>,
123 metadata: Option<serde_json::Value>,
124}
125
126#[derive(Debug, Serialize)]
127struct CreateKeyResponse {
128 key: String,
129 #[serde(flatten)]
130 api_key: ApiKeyView,
131}
132
133impl ApiKeyView {
134 fn from_entity(ak: &impl AuthApiKey) -> Self {
135 Self {
136 id: ak.id().to_string(),
137 name: ak.name().map(|s| s.to_string()),
138 start: ak.start().map(|s| s.to_string()),
139 prefix: ak.prefix().map(|s| s.to_string()),
140 user_id: ak.user_id().to_string(),
141 refill_interval: ak.refill_interval(),
142 refill_amount: ak.refill_amount(),
143 last_refill_at: ak.last_refill_at().map(|s| s.to_string()),
144 enabled: ak.enabled(),
145 rate_limit_enabled: ak.rate_limit_enabled(),
146 rate_limit_time_window: ak.rate_limit_time_window(),
147 rate_limit_max: ak.rate_limit_max(),
148 request_count: ak.request_count(),
149 remaining: ak.remaining(),
150 last_request: ak.last_request().map(|s| s.to_string()),
151 expires_at: ak.expires_at().map(|s| s.to_string()),
152 created_at: ak.created_at().to_string(),
153 updated_at: ak.updated_at().to_string(),
154 permissions: ak.permissions().and_then(|s| serde_json::from_str(s).ok()),
155 metadata: ak.metadata().and_then(|s| serde_json::from_str(s).ok()),
156 }
157 }
158}
159
160impl ApiKeyPlugin {
161 #[allow(clippy::new_without_default)]
162 pub fn new() -> Self {
163 Self {
164 config: ApiKeyConfig::default(),
165 }
166 }
167
168 pub fn with_config(config: ApiKeyConfig) -> Self {
169 Self { config }
170 }
171
172 pub fn key_length(mut self, length: usize) -> Self {
173 self.config.key_length = length;
174 self
175 }
176
177 pub fn prefix(mut self, prefix: impl Into<String>) -> Self {
178 self.config.prefix = Some(prefix.into());
179 self
180 }
181
182 pub fn default_remaining(mut self, remaining: i64) -> Self {
183 self.config.default_remaining = Some(remaining);
184 self
185 }
186
187 fn generate_key(&self, custom_prefix: Option<&str>) -> (String, String, String) {
188 let mut bytes = vec![0u8; self.config.key_length];
189 rand::rngs::OsRng.fill_bytes(&mut bytes);
190 let raw = URL_SAFE_NO_PAD.encode(&bytes);
191
192 let start = raw.chars().take(4).collect::<String>();
193
194 let prefix = custom_prefix
195 .or(self.config.prefix.as_deref())
196 .unwrap_or("");
197 let full_key = format!("{}{}", prefix, raw);
198
199 let mut hasher = Sha256::new();
200 hasher.update(full_key.as_bytes());
201 let hash = format!("{:x}", hasher.finalize());
202
203 (full_key, hash, start)
204 }
205
206 async fn get_authenticated_user<DB: DatabaseAdapter>(
207 req: &AuthRequest,
208 ctx: &AuthContext<DB>,
209 ) -> AuthResult<(DB::User, DB::Session)> {
210 let token = req
211 .headers
212 .get("authorization")
213 .and_then(|v| v.strip_prefix("Bearer "))
214 .ok_or(AuthError::Unauthenticated)?;
215
216 let session = ctx
217 .database
218 .get_session(token)
219 .await?
220 .ok_or(AuthError::Unauthenticated)?;
221
222 if session.expires_at() < chrono::Utc::now() {
223 return Err(AuthError::Unauthenticated);
224 }
225
226 let user = ctx
227 .database
228 .get_user_by_id(session.user_id())
229 .await?
230 .ok_or(AuthError::UserNotFound)?;
231
232 Ok((user, session))
233 }
234
235 async fn handle_create<DB: DatabaseAdapter>(
236 &self,
237 req: &AuthRequest,
238 ctx: &AuthContext<DB>,
239 ) -> AuthResult<AuthResponse> {
240 let (user, _session) = Self::get_authenticated_user(req, ctx).await?;
241
242 let create_req: CreateKeyRequest = match better_auth_core::validate_request_body(req) {
243 Ok(v) => v,
244 Err(resp) => return Ok(resp),
245 };
246
247 let (full_key, hash, start) = self.generate_key(create_req.prefix.as_deref());
248
249 let expires_at = if let Some(ms) = create_req.expires_in {
250 let duration = chrono::Duration::try_milliseconds(ms)
251 .ok_or_else(|| AuthError::bad_request("expiresIn is out of range"))?;
252 let dt = chrono::Utc::now()
253 .checked_add_signed(duration)
254 .ok_or_else(|| AuthError::bad_request("expiresIn is out of range"))?;
255 Some(dt.to_rfc3339())
256 } else {
257 None
258 };
259
260 let remaining = create_req.remaining.or(self.config.default_remaining);
261
262 let input = CreateApiKey {
263 user_id: user.id().to_string(),
264 name: create_req.name,
265 prefix: create_req.prefix.or_else(|| self.config.prefix.clone()),
266 key_hash: hash,
267 start: Some(start),
268 expires_at,
269 remaining,
270 rate_limit_enabled: create_req.rate_limit_enabled.unwrap_or(false),
271 rate_limit_time_window: create_req.rate_limit_time_window,
272 rate_limit_max: create_req.rate_limit_max,
273 refill_interval: create_req.refill_interval,
274 refill_amount: create_req.refill_amount,
275 permissions: create_req
276 .permissions
277 .map(|v| serde_json::to_string(&v).unwrap_or_default()),
278 metadata: create_req
279 .metadata
280 .map(|v| serde_json::to_string(&v).unwrap_or_default()),
281 enabled: true,
282 };
283
284 let api_key = ctx.database.create_api_key(input).await?;
285
286 let response = CreateKeyResponse {
287 key: full_key,
288 api_key: ApiKeyView::from_entity(&api_key),
289 };
290
291 Ok(AuthResponse::json(200, &response)?)
292 }
293
294 async fn handle_get<DB: DatabaseAdapter>(
295 &self,
296 req: &AuthRequest,
297 ctx: &AuthContext<DB>,
298 ) -> AuthResult<AuthResponse> {
299 let (user, _session) = Self::get_authenticated_user(req, ctx).await?;
300
301 let id = req
302 .query
303 .get("id")
304 .ok_or_else(|| AuthError::bad_request("Query parameter 'id' is required"))?;
305
306 let api_key = ctx
307 .database
308 .get_api_key_by_id(id)
309 .await?
310 .ok_or_else(|| AuthError::not_found("API key not found"))?;
311
312 if api_key.user_id() != user.id() {
313 return Err(AuthError::not_found("API key not found"));
314 }
315
316 Ok(AuthResponse::json(200, &ApiKeyView::from_entity(&api_key))?)
317 }
318
319 async fn handle_list<DB: DatabaseAdapter>(
320 &self,
321 req: &AuthRequest,
322 ctx: &AuthContext<DB>,
323 ) -> AuthResult<AuthResponse> {
324 let (user, _session) = Self::get_authenticated_user(req, ctx).await?;
325
326 let keys = ctx.database.list_api_keys_by_user(user.id()).await?;
327
328 let views: Vec<ApiKeyView> = keys.iter().map(ApiKeyView::from_entity).collect();
329
330 Ok(AuthResponse::json(200, &views)?)
331 }
332
333 async fn handle_update<DB: DatabaseAdapter>(
334 &self,
335 req: &AuthRequest,
336 ctx: &AuthContext<DB>,
337 ) -> AuthResult<AuthResponse> {
338 let (user, _session) = Self::get_authenticated_user(req, ctx).await?;
339
340 let update_req: UpdateKeyRequest = match better_auth_core::validate_request_body(req) {
341 Ok(v) => v,
342 Err(resp) => return Ok(resp),
343 };
344
345 let existing = ctx
347 .database
348 .get_api_key_by_id(&update_req.id)
349 .await?
350 .ok_or_else(|| AuthError::not_found("API key not found"))?;
351
352 if existing.user_id() != user.id() {
353 return Err(AuthError::not_found("API key not found"));
354 }
355
356 let update = UpdateApiKey {
357 name: update_req.name,
358 enabled: update_req.enabled,
359 remaining: update_req.remaining,
360 rate_limit_enabled: update_req.rate_limit_enabled,
361 rate_limit_time_window: update_req.rate_limit_time_window,
362 rate_limit_max: update_req.rate_limit_max,
363 refill_interval: update_req.refill_interval,
364 refill_amount: update_req.refill_amount,
365 permissions: update_req
366 .permissions
367 .map(|v| serde_json::to_string(&v).unwrap_or_default()),
368 metadata: update_req
369 .metadata
370 .map(|v| serde_json::to_string(&v).unwrap_or_default()),
371 };
372
373 let updated = ctx.database.update_api_key(&update_req.id, update).await?;
374
375 Ok(AuthResponse::json(200, &ApiKeyView::from_entity(&updated))?)
376 }
377
378 async fn handle_delete<DB: DatabaseAdapter>(
379 &self,
380 req: &AuthRequest,
381 ctx: &AuthContext<DB>,
382 ) -> AuthResult<AuthResponse> {
383 let (user, _session) = Self::get_authenticated_user(req, ctx).await?;
384
385 let delete_req: DeleteKeyRequest = match better_auth_core::validate_request_body(req) {
386 Ok(v) => v,
387 Err(resp) => return Ok(resp),
388 };
389
390 let existing = ctx
392 .database
393 .get_api_key_by_id(&delete_req.id)
394 .await?
395 .ok_or_else(|| AuthError::not_found("API key not found"))?;
396
397 if existing.user_id() != user.id() {
398 return Err(AuthError::not_found("API key not found"));
399 }
400
401 ctx.database.delete_api_key(&delete_req.id).await?;
402
403 Ok(AuthResponse::json(
404 200,
405 &serde_json::json!({ "status": true }),
406 )?)
407 }
408}
409
410#[async_trait]
411impl<DB: DatabaseAdapter> AuthPlugin<DB> for ApiKeyPlugin {
412 fn name(&self) -> &'static str {
413 "api-key"
414 }
415
416 fn routes(&self) -> Vec<AuthRoute> {
417 vec![
418 AuthRoute::post("/api-key/create", "api_key_create"),
419 AuthRoute::get("/api-key/get", "api_key_get"),
420 AuthRoute::post("/api-key/update", "api_key_update"),
421 AuthRoute::post("/api-key/delete", "api_key_delete"),
422 AuthRoute::get("/api-key/list", "api_key_list"),
423 ]
424 }
425
426 async fn on_request(
427 &self,
428 req: &AuthRequest,
429 ctx: &AuthContext<DB>,
430 ) -> AuthResult<Option<AuthResponse>> {
431 match (req.method(), req.path()) {
432 (HttpMethod::Post, "/api-key/create") => Ok(Some(self.handle_create(req, ctx).await?)),
433 (HttpMethod::Get, "/api-key/get") => Ok(Some(self.handle_get(req, ctx).await?)),
434 (HttpMethod::Post, "/api-key/update") => Ok(Some(self.handle_update(req, ctx).await?)),
435 (HttpMethod::Post, "/api-key/delete") => Ok(Some(self.handle_delete(req, ctx).await?)),
436 (HttpMethod::Get, "/api-key/list") => Ok(Some(self.handle_list(req, ctx).await?)),
437 _ => Ok(None),
438 }
439 }
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use better_auth_core::adapters::{ApiKeyOps, MemoryDatabaseAdapter, SessionOps, UserOps};
    use better_auth_core::{CreateSession, CreateUser, Session, User};
    use chrono::{Duration, Utc};
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Builds an in-memory context holding one user with a 24h session.
    async fn create_test_context_with_user() -> (AuthContext<MemoryDatabaseAdapter>, User, Session)
    {
        let config = Arc::new(better_auth_core::AuthConfig::new(
            "test-secret-key-at-least-32-chars-long",
        ));
        let database = Arc::new(MemoryDatabaseAdapter::new());
        let ctx = AuthContext::new(config, database.clone());

        let user = database
            .create_user(
                CreateUser::new()
                    .with_email("test@example.com")
                    .with_name("Test User"),
            )
            .await
            .unwrap();

        let session = database
            .create_session(CreateSession {
                user_id: user.id.clone(),
                expires_at: Utc::now() + Duration::hours(24),
                ip_address: Some("127.0.0.1".to_string()),
                user_agent: Some("test-agent".to_string()),
                impersonated_by: None,
                active_organization_id: None,
            })
            .await
            .unwrap();

        (ctx, user, session)
    }

    /// Adds another user (with a 24h session) to an existing context.
    async fn create_user_with_session(
        ctx: &AuthContext<MemoryDatabaseAdapter>,
        email: &str,
    ) -> (User, Session) {
        let user = ctx
            .database
            .create_user(
                CreateUser::new()
                    .with_email(email.to_string())
                    .with_name("Another User"),
            )
            .await
            .unwrap();

        let session = ctx
            .database
            .create_session(CreateSession {
                user_id: user.id.clone(),
                expires_at: Utc::now() + Duration::hours(24),
                ip_address: None,
                user_agent: None,
                impersonated_by: None,
                active_organization_id: None,
            })
            .await
            .unwrap();

        (user, session)
    }

    /// Builds a request, optionally with a Bearer token, JSON body and query string.
    fn create_auth_request(
        method: HttpMethod,
        path: &str,
        token: Option<&str>,
        body: Option<serde_json::Value>,
        query: Option<HashMap<String, String>>,
    ) -> AuthRequest {
        let mut headers = HashMap::new();
        if let Some(token) = token {
            headers.insert("authorization".to_string(), format!("Bearer {}", token));
        }

        AuthRequest {
            method,
            path: path.to_string(),
            headers,
            body: body.map(|b| serde_json::to_vec(&b).unwrap()),
            query: query.unwrap_or_default(),
        }
    }

    /// Parses a response body as JSON.
    fn json_body(response: &AuthResponse) -> serde_json::Value {
        serde_json::from_slice(&response.body).unwrap()
    }

    /// Creates a key named `name` for the session owner and returns its id.
    async fn create_key_and_get_id(
        plugin: &ApiKeyPlugin,
        ctx: &AuthContext<MemoryDatabaseAdapter>,
        token: &str,
        name: &str,
    ) -> String {
        let req = create_auth_request(
            HttpMethod::Post,
            "/api-key/create",
            Some(token),
            Some(serde_json::json!({ "name": name })),
            None,
        );
        let response = plugin.handle_create(&req, ctx).await.unwrap();
        assert_eq!(response.status, 200);
        json_body(&response)["id"].as_str().unwrap().to_string()
    }

    #[tokio::test]
    async fn test_create_and_get_do_not_expose_hash() {
        let plugin = ApiKeyPlugin::new().prefix("ba_");
        let (ctx, _user, session) = create_test_context_with_user().await;

        let create_req = create_auth_request(
            HttpMethod::Post,
            "/api-key/create",
            Some(&session.token),
            Some(serde_json::json!({ "name": "primary" })),
            None,
        );
        let create_response = plugin.handle_create(&create_req, &ctx).await.unwrap();
        assert_eq!(create_response.status, 200);

        let body = json_body(&create_response);
        assert!(body.get("key").is_some());
        assert!(body.get("key_hash").is_none());
        assert!(body.get("hash").is_none());

        let id = body["id"].as_str().unwrap();
        let mut query = HashMap::new();
        query.insert("id".to_string(), id.to_string());

        let get_req = create_auth_request(
            HttpMethod::Get,
            "/api-key/get",
            Some(&session.token),
            None,
            Some(query),
        );
        let get_response = plugin.handle_get(&get_req, &ctx).await.unwrap();
        assert_eq!(get_response.status, 200);

        let get_body = json_body(&get_response);
        assert!(get_body.get("key").is_none());
        assert!(get_body.get("key_hash").is_none());
    }

    #[tokio::test]
    async fn test_create_rejects_invalid_expires_in() {
        let plugin = ApiKeyPlugin::new();
        let (ctx, _user, session) = create_test_context_with_user().await;

        let req = create_auth_request(
            HttpMethod::Post,
            "/api-key/create",
            Some(&session.token),
            Some(serde_json::json!({ "expiresIn": i64::MIN })),
            None,
        );
        let response = plugin.handle_create(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 422);
    }

    #[tokio::test]
    async fn test_create_rejects_zero_expires_in() {
        let plugin = ApiKeyPlugin::new();
        let (ctx, _user, session) = create_test_context_with_user().await;

        let req = create_auth_request(
            HttpMethod::Post,
            "/api-key/create",
            Some(&session.token),
            Some(serde_json::json!({ "expiresIn": 0 })),
            None,
        );
        let response = plugin.handle_create(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 422);
    }

    #[tokio::test]
    async fn test_create_applies_configured_prefix() {
        let plugin = ApiKeyPlugin::new().prefix("ba_");
        let (ctx, _user, session) = create_test_context_with_user().await;

        let req = create_auth_request(
            HttpMethod::Post,
            "/api-key/create",
            Some(&session.token),
            Some(serde_json::json!({ "name": "prefixed" })),
            None,
        );
        let response = plugin.handle_create(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 200);

        let body = json_body(&response);
        let key = body["key"].as_str().unwrap();
        assert!(key.starts_with("ba_"));
        assert!(key.len() > "ba_".len());
        assert_eq!(body["prefix"].as_str(), Some("ba_"));
    }

    #[tokio::test]
    async fn test_request_prefix_overrides_configured_prefix() {
        let plugin = ApiKeyPlugin::new().prefix("ba_");
        let (ctx, _user, session) = create_test_context_with_user().await;

        let req = create_auth_request(
            HttpMethod::Post,
            "/api-key/create",
            Some(&session.token),
            Some(serde_json::json!({ "prefix": "custom_" })),
            None,
        );
        let response = plugin.handle_create(&req, &ctx).await.unwrap();
        assert_eq!(response.status, 200);

        let body = json_body(&response);
        assert!(body["key"].as_str().unwrap().starts_with("custom_"));
    }

    #[tokio::test]
    async fn test_requests_without_valid_token_are_unauthenticated() {
        let plugin = ApiKeyPlugin::new();
        let (ctx, _user, _session) = create_test_context_with_user().await;

        let no_token = create_auth_request(HttpMethod::Get, "/api-key/list", None, None, None);
        let err = plugin.handle_list(&no_token, &ctx).await.unwrap_err();
        assert!(matches!(err, AuthError::Unauthenticated));

        let bad_token = create_auth_request(
            HttpMethod::Get,
            "/api-key/list",
            Some("not-a-real-session-token"),
            None,
            None,
        );
        let err = plugin.handle_list(&bad_token, &ctx).await.unwrap_err();
        assert!(matches!(err, AuthError::Unauthenticated));
    }

    #[tokio::test]
    async fn test_get_update_delete_return_404_for_non_owner() {
        let plugin = ApiKeyPlugin::new();
        let (ctx, _user1, session1) = create_test_context_with_user().await;
        let (_user2, session2) = create_user_with_session(&ctx, "other@example.com").await;
        let key_id = create_key_and_get_id(&plugin, &ctx, &session1.token, "owner-key").await;

        let mut get_query = HashMap::new();
        get_query.insert("id".to_string(), key_id.clone());
        let get_req = create_auth_request(
            HttpMethod::Get,
            "/api-key/get",
            Some(&session2.token),
            None,
            Some(get_query),
        );
        let get_err = plugin.handle_get(&get_req, &ctx).await.unwrap_err();
        assert_eq!(get_err.status_code(), 404);

        let update_req = create_auth_request(
            HttpMethod::Post,
            "/api-key/update",
            Some(&session2.token),
            Some(serde_json::json!({ "id": key_id, "name": "new-name" })),
            None,
        );
        let update_err = plugin.handle_update(&update_req, &ctx).await.unwrap_err();
        assert_eq!(update_err.status_code(), 404);

        let delete_req = create_auth_request(
            HttpMethod::Post,
            "/api-key/delete",
            Some(&session2.token),
            Some(serde_json::json!({ "id": key_id })),
            None,
        );
        let delete_err = plugin.handle_delete(&delete_req, &ctx).await.unwrap_err();
        assert_eq!(delete_err.status_code(), 404);
    }

    #[tokio::test]
    async fn test_owner_can_update_key() {
        let plugin = ApiKeyPlugin::new();
        let (ctx, _user, session) = create_test_context_with_user().await;
        let key_id = create_key_and_get_id(&plugin, &ctx, &session.token, "before").await;

        let update_req = create_auth_request(
            HttpMethod::Post,
            "/api-key/update",
            Some(&session.token),
            Some(serde_json::json!({ "id": key_id, "name": "after" })),
            None,
        );
        let response = plugin.handle_update(&update_req, &ctx).await.unwrap();
        assert_eq!(response.status, 200);

        let body = json_body(&response);
        assert_eq!(body["id"].as_str(), Some(key_id.as_str()));
        assert_eq!(body["name"].as_str(), Some("after"));
        assert!(body.get("key").is_none());
    }

    #[tokio::test]
    async fn test_list_returns_only_user_keys() {
        let plugin = ApiKeyPlugin::new();
        let (ctx, user1, session1) = create_test_context_with_user().await;
        let (_user2, session2) = create_user_with_session(&ctx, "other@example.com").await;

        let _ = create_key_and_get_id(&plugin, &ctx, &session1.token, "u1-key").await;
        let _ = create_key_and_get_id(&plugin, &ctx, &session2.token, "u2-key").await;

        let list_req = create_auth_request(
            HttpMethod::Get,
            "/api-key/list",
            Some(&session1.token),
            None,
            None,
        );
        let list_response = plugin.handle_list(&list_req, &ctx).await.unwrap();
        assert_eq!(list_response.status, 200);

        let list_body = json_body(&list_response);
        let list = list_body.as_array().unwrap();
        assert_eq!(list.len(), 1);
        assert_eq!(list[0]["userId"].as_str().unwrap(), user1.id);
        assert!(list[0].get("key").is_none());
        assert!(list[0].get("key_hash").is_none());
    }

    #[tokio::test]
    async fn test_list_is_empty_for_user_without_keys() {
        let plugin = ApiKeyPlugin::new();
        let (ctx, _user, session) = create_test_context_with_user().await;

        let list_req = create_auth_request(
            HttpMethod::Get,
            "/api-key/list",
            Some(&session.token),
            None,
            None,
        );
        let list_response = plugin.handle_list(&list_req, &ctx).await.unwrap();
        assert_eq!(list_response.status, 200);
        assert_eq!(json_body(&list_response), serde_json::json!([]));
    }

    #[tokio::test]
    async fn test_owner_can_delete_key() {
        let plugin = ApiKeyPlugin::new();
        let (ctx, _user, session) = create_test_context_with_user().await;
        let key_id = create_key_and_get_id(&plugin, &ctx, &session.token, "to-delete").await;

        let delete_req = create_auth_request(
            HttpMethod::Post,
            "/api-key/delete",
            Some(&session.token),
            Some(serde_json::json!({ "id": key_id })),
            None,
        );
        let delete_response = plugin.handle_delete(&delete_req, &ctx).await.unwrap();
        assert_eq!(delete_response.status, 200);

        let deleted = ctx.database.get_api_key_by_id(&key_id).await.unwrap();
        assert!(deleted.is_none());
    }

    #[tokio::test]
    async fn test_on_request_ignores_unknown_routes() {
        let plugin = ApiKeyPlugin::new();
        let (ctx, _user, session) = create_test_context_with_user().await;

        assert_eq!(
            AuthPlugin::<MemoryDatabaseAdapter>::name(&plugin),
            "api-key"
        );
        assert_eq!(
            AuthPlugin::<MemoryDatabaseAdapter>::routes(&plugin).len(),
            5
        );

        let unknown_path = create_auth_request(
            HttpMethod::Get,
            "/api-key/unknown",
            Some(&session.token),
            None,
            None,
        );
        assert!(plugin.on_request(&unknown_path, &ctx).await.unwrap().is_none());

        // Known path, wrong method.
        let wrong_method = create_auth_request(
            HttpMethod::Get,
            "/api-key/create",
            Some(&session.token),
            None,
            None,
        );
        assert!(plugin.on_request(&wrong_method, &ctx).await.unwrap().is_none());
    }
}