1use base64::Engine;
2use base64::engine::general_purpose::URL_SAFE_NO_PAD;
3use governor::clock::DefaultClock;
4use governor::state::{InMemoryState, NotKeyed};
5use governor::{Quota, RateLimiter};
6use rand::RngCore;
7use serde::Serialize;
8use sha2::{Digest, Sha256};
9use std::collections::HashMap;
10use std::num::NonZeroU32;
11use std::sync::Mutex;
12
13use better_auth_core::adapters::DatabaseAdapter;
14use better_auth_core::entity::{AuthApiKey, AuthUser};
15use better_auth_core::{AuthContext, AuthError, AuthResult, BeforeRequestAction};
16use better_auth_core::{AuthRequest, AuthResponse, UpdateApiKey};
17
18pub(super) mod handlers;
19pub(super) mod types;
20
21#[cfg(test)]
22mod tests;
23
24use handlers::*;
25use types::*;
26
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
33pub enum ApiKeyErrorCode {
34 #[serde(rename = "INVALID_API_KEY")]
35 InvalidApiKey,
36 #[serde(rename = "KEY_DISABLED")]
37 KeyDisabled,
38 #[serde(rename = "KEY_EXPIRED")]
39 KeyExpired,
40 #[serde(rename = "USAGE_EXCEEDED")]
41 UsageExceeded,
42 #[serde(rename = "KEY_NOT_FOUND")]
43 KeyNotFound,
44 #[serde(rename = "RATE_LIMITED")]
45 RateLimited,
46 #[serde(rename = "UNAUTHORIZED_SESSION")]
47 UnauthorizedSession,
48 #[serde(rename = "INVALID_PREFIX_LENGTH")]
49 InvalidPrefixLength,
50 #[serde(rename = "INVALID_NAME_LENGTH")]
51 InvalidNameLength,
52 #[serde(rename = "METADATA_DISABLED")]
53 MetadataDisabled,
54 #[serde(rename = "NO_VALUES_TO_UPDATE")]
55 NoValuesToUpdate,
56 #[serde(rename = "KEY_DISABLED_EXPIRATION")]
57 KeyDisabledExpiration,
58 #[serde(rename = "EXPIRES_IN_IS_TOO_SMALL")]
59 ExpiresInTooSmall,
60 #[serde(rename = "EXPIRES_IN_IS_TOO_LARGE")]
61 ExpiresInTooLarge,
62 #[serde(rename = "INVALID_REMAINING")]
63 InvalidRemaining,
64 #[serde(rename = "REFILL_AMOUNT_AND_INTERVAL_REQUIRED")]
65 RefillAmountAndIntervalRequired,
66 #[serde(rename = "NAME_REQUIRED")]
67 NameRequired,
68 #[serde(rename = "INVALID_USER_ID_FROM_API_KEY")]
69 InvalidUserIdFromApiKey,
70 #[serde(rename = "SERVER_ONLY_PROPERTY")]
71 ServerOnlyProperty,
72 #[serde(rename = "FAILED_TO_UPDATE_API_KEY")]
73 FailedToUpdateApiKey,
74 #[serde(rename = "INVALID_METADATA_TYPE")]
75 InvalidMetadataType,
76}
77
78impl ApiKeyErrorCode {
79 pub fn as_str(self) -> &'static str {
82 match self {
83 Self::InvalidApiKey => "INVALID_API_KEY",
84 Self::KeyDisabled => "KEY_DISABLED",
85 Self::KeyExpired => "KEY_EXPIRED",
86 Self::UsageExceeded => "USAGE_EXCEEDED",
87 Self::KeyNotFound => "KEY_NOT_FOUND",
88 Self::RateLimited => "RATE_LIMITED",
89 Self::UnauthorizedSession => "UNAUTHORIZED_SESSION",
90 Self::InvalidPrefixLength => "INVALID_PREFIX_LENGTH",
91 Self::InvalidNameLength => "INVALID_NAME_LENGTH",
92 Self::MetadataDisabled => "METADATA_DISABLED",
93 Self::NoValuesToUpdate => "NO_VALUES_TO_UPDATE",
94 Self::KeyDisabledExpiration => "KEY_DISABLED_EXPIRATION",
95 Self::ExpiresInTooSmall => "EXPIRES_IN_IS_TOO_SMALL",
96 Self::ExpiresInTooLarge => "EXPIRES_IN_IS_TOO_LARGE",
97 Self::InvalidRemaining => "INVALID_REMAINING",
98 Self::RefillAmountAndIntervalRequired => "REFILL_AMOUNT_AND_INTERVAL_REQUIRED",
99 Self::NameRequired => "NAME_REQUIRED",
100 Self::InvalidUserIdFromApiKey => "INVALID_USER_ID_FROM_API_KEY",
101 Self::ServerOnlyProperty => "SERVER_ONLY_PROPERTY",
102 Self::FailedToUpdateApiKey => "FAILED_TO_UPDATE_API_KEY",
103 Self::InvalidMetadataType => "INVALID_METADATA_TYPE",
104 }
105 }
106
107 pub fn message(self) -> &'static str {
108 match self {
109 Self::InvalidApiKey => "Invalid API key.",
110 Self::KeyDisabled => "API Key is disabled",
111 Self::KeyExpired => "API Key has expired",
112 Self::UsageExceeded => "API Key has reached its usage limit",
113 Self::KeyNotFound => "API Key not found",
114 Self::RateLimited => "Rate limit exceeded.",
115 Self::UnauthorizedSession => "Unauthorized or invalid session",
116 Self::InvalidPrefixLength => "The prefix length is either too large or too small.",
117 Self::InvalidNameLength => "The name length is either too large or too small.",
118 Self::MetadataDisabled => "Metadata is disabled.",
119 Self::NoValuesToUpdate => "No values to update.",
120 Self::KeyDisabledExpiration => "Custom key expiration values are disabled.",
121 Self::ExpiresInTooSmall => {
122 "The expiresIn is smaller than the predefined minimum value."
123 }
124 Self::ExpiresInTooLarge => "The expiresIn is larger than the predefined maximum value.",
125 Self::InvalidRemaining => "The remaining count is either too large or too small.",
126 Self::RefillAmountAndIntervalRequired => {
127 "refillAmount and refillInterval must both be provided together"
128 }
129 Self::NameRequired => "API Key name is required.",
130 Self::InvalidUserIdFromApiKey => "The user id from the API key is invalid.",
131 Self::ServerOnlyProperty => {
132 "The property you're trying to set can only be set from the server auth instance only."
133 }
134 Self::FailedToUpdateApiKey => "Failed to update API key",
135 Self::InvalidMetadataType => "metadata must be an object or undefined",
136 }
137 }
138}
139
140fn api_key_error(code: ApiKeyErrorCode) -> AuthError {
141 AuthError::bad_request(code.message())
142}
143
144struct ApiKeyValidationError {
147 code: ApiKeyErrorCode,
148 message: String,
149}
150
151impl ApiKeyValidationError {
152 fn new(code: ApiKeyErrorCode) -> Self {
153 Self {
154 message: code.message().to_string(),
155 code,
156 }
157 }
158}
159
160pub struct ApiKeyPlugin {
166 pub(super) config: ApiKeyConfig,
167 last_expired_check: Mutex<Option<std::time::Instant>>,
169 pub(super) rate_limiters: Mutex<HashMap<String, std::sync::Arc<GovernorLimiter>>>,
172}
173
174type GovernorLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
176
177#[derive(Debug, Clone)]
179pub struct ApiKeyConfig {
180 pub key_length: usize,
182 pub prefix: Option<String>,
183 pub default_remaining: Option<i64>,
184
185 pub api_key_header: String,
187
188 pub disable_key_hashing: bool,
190
191 pub starting_characters_length: usize,
193 pub store_starting_characters: bool,
194
195 pub max_prefix_length: usize,
197 pub min_prefix_length: usize,
198
199 pub max_name_length: usize,
201 pub min_name_length: usize,
202 pub require_name: bool,
203
204 pub enable_metadata: bool,
206
207 pub key_expiration: KeyExpirationConfig,
209
210 pub rate_limit: RateLimitDefaults,
212
213 pub enable_session_for_api_keys: bool,
215}
216
/// Rules governing the expiration of API keys.
#[derive(Debug, Clone)]
pub struct KeyExpirationConfig {
    /// Value handed back by `validate_expires_in` when the caller supplies no
    /// expiration. It is returned as-is, without range checks.
    pub default_expires_in: Option<i64>,
    /// When `true`, any caller-supplied expiration is rejected.
    pub disable_custom_expires_time: bool,
    /// Upper bound for a caller-supplied expiration, in days.
    pub max_expires_in: i64,
    /// Lower bound for a caller-supplied expiration, in days.
    pub min_expires_in: i64,
}
229
230impl Default for KeyExpirationConfig {
231 fn default() -> Self {
232 Self {
233 default_expires_in: None,
234 disable_custom_expires_time: false,
235 max_expires_in: 365,
236 min_expires_in: 0,
237 }
238 }
239}
240
/// Plugin-wide rate-limit settings, used for keys that carry none of their own.
#[derive(Debug, Clone)]
pub struct RateLimitDefaults {
    /// Whether rate limiting applies to keys that have not enabled it
    /// themselves and carry no explicit rate-limit values.
    pub enabled: bool,
    /// Length of the rate-limit window in milliseconds.
    pub time_window: i64,
    /// Number of requests allowed per window.
    pub max_requests: i64,
}
250
251impl Default for RateLimitDefaults {
252 fn default() -> Self {
253 Self {
254 enabled: true,
255 time_window: 86_400_000, max_requests: 10,
257 }
258 }
259}
260
261impl Default for ApiKeyConfig {
262 fn default() -> Self {
263 Self {
264 key_length: 32,
265 prefix: None,
266 default_remaining: None,
267 api_key_header: "x-api-key".to_string(),
268 disable_key_hashing: false,
269 starting_characters_length: 6,
270 store_starting_characters: true,
271 max_prefix_length: 32,
272 min_prefix_length: 1,
273 max_name_length: 32,
274 min_name_length: 1,
275 require_name: false,
276 enable_metadata: false,
277 key_expiration: KeyExpirationConfig::default(),
278 rate_limit: RateLimitDefaults::default(),
279 enable_session_for_api_keys: false,
280 }
281 }
282}
283
284#[bon::bon]
300impl ApiKeyPlugin {
301 #[builder]
302 pub fn new(
303 #[builder(default = 32)] key_length: usize,
304 prefix: Option<String>,
305 default_remaining: Option<i64>,
306 #[builder(default = "x-api-key".to_string())] api_key_header: String,
307 #[builder(default = false)] disable_key_hashing: bool,
308 #[builder(default = 6)] starting_characters_length: usize,
309 #[builder(default = true)] store_starting_characters: bool,
310 #[builder(default = 32)] max_prefix_length: usize,
311 #[builder(default = 1)] min_prefix_length: usize,
312 #[builder(default = 32)] max_name_length: usize,
313 #[builder(default = 1)] min_name_length: usize,
314 #[builder(default = false)] require_name: bool,
315 #[builder(default = false)] enable_metadata: bool,
316 #[builder(default)] key_expiration: KeyExpirationConfig,
317 #[builder(default)] rate_limit: RateLimitDefaults,
318 #[builder(default = false)] enable_session_for_api_keys: bool,
319 ) -> Self {
320 Self {
321 config: ApiKeyConfig {
322 key_length,
323 prefix,
324 default_remaining,
325 api_key_header,
326 disable_key_hashing,
327 starting_characters_length,
328 store_starting_characters,
329 max_prefix_length,
330 min_prefix_length,
331 max_name_length,
332 min_name_length,
333 require_name,
334 enable_metadata,
335 key_expiration,
336 rate_limit,
337 enable_session_for_api_keys,
338 },
339 last_expired_check: Mutex::new(None),
340 rate_limiters: Mutex::new(HashMap::new()),
341 }
342 }
343
344 pub fn with_config(config: ApiKeyConfig) -> Self {
345 Self {
346 config,
347 last_expired_check: Mutex::new(None),
348 rate_limiters: Mutex::new(HashMap::new()),
349 }
350 }
351
352 pub(super) fn generate_key(&self, custom_prefix: Option<&str>) -> (String, String, String) {
355 let mut bytes = vec![0u8; self.config.key_length];
356 rand::rngs::OsRng.fill_bytes(&mut bytes);
357 let raw = URL_SAFE_NO_PAD.encode(&bytes);
358
359 let start_len = self.config.starting_characters_length;
360 let start = raw.chars().take(start_len).collect::<String>();
361
362 let prefix = custom_prefix
363 .or(self.config.prefix.as_deref())
364 .unwrap_or("");
365 let full_key = format!("{}{}", prefix, raw);
366
367 let hash = if self.config.disable_key_hashing {
368 full_key.clone()
369 } else {
370 Self::hash_key(&full_key)
371 };
372
373 (full_key, hash, start)
374 }
375
376 fn hash_key(key: &str) -> String {
377 let mut hasher = Sha256::new();
378 hasher.update(key.as_bytes());
379 let digest = hasher.finalize();
380 URL_SAFE_NO_PAD.encode(digest)
381 }
382
383 pub(super) async fn maybe_delete_expired<DB: DatabaseAdapter>(&self, ctx: &AuthContext<DB>) {
385 let should_run = {
386 let mut last = self.last_expired_check.lock().unwrap();
387 let now = std::time::Instant::now();
388 match *last {
389 Some(prev) if now.duration_since(prev).as_secs() < 10 => false,
390 _ => {
391 *last = Some(now);
392 true
393 }
394 }
395 };
396 if should_run {
397 let _ = ctx.database.delete_expired_api_keys().await;
398 }
399 }
400
401 pub(super) fn validate_prefix(&self, prefix: Option<&str>) -> AuthResult<()> {
404 if let Some(p) = prefix {
405 let len = p.len();
406 if len < self.config.min_prefix_length || len > self.config.max_prefix_length {
407 return Err(api_key_error(ApiKeyErrorCode::InvalidPrefixLength));
408 }
409 }
410 Ok(())
411 }
412
413 pub(super) fn validate_name(&self, name: Option<&str>, is_create: bool) -> AuthResult<()> {
419 if is_create && self.config.require_name && name.is_none() {
420 return Err(api_key_error(ApiKeyErrorCode::NameRequired));
421 }
422 if let Some(n) = name {
423 let len = n.len();
424 if len < self.config.min_name_length || len > self.config.max_name_length {
425 return Err(api_key_error(ApiKeyErrorCode::InvalidNameLength));
426 }
427 }
428 Ok(())
429 }
430
431 pub(super) fn validate_expires_in(&self, expires_in: Option<i64>) -> AuthResult<Option<i64>> {
432 let cfg = &self.config.key_expiration;
433 if let Some(ms) = expires_in {
434 if cfg.disable_custom_expires_time {
435 return Err(api_key_error(ApiKeyErrorCode::KeyDisabledExpiration));
436 }
437 let days = ms as f64 / 86_400_000.0;
438 if days < cfg.min_expires_in as f64 {
439 return Err(api_key_error(ApiKeyErrorCode::ExpiresInTooSmall));
440 }
441 if days > cfg.max_expires_in as f64 {
442 return Err(api_key_error(ApiKeyErrorCode::ExpiresInTooLarge));
443 }
444 Ok(Some(ms))
445 } else {
446 Ok(cfg.default_expires_in)
447 }
448 }
449
450 pub(super) fn validate_metadata(&self, metadata: &Option<serde_json::Value>) -> AuthResult<()> {
451 if metadata.is_some() && !self.config.enable_metadata {
452 return Err(api_key_error(ApiKeyErrorCode::MetadataDisabled));
453 }
454 if let Some(v) = metadata
455 && !v.is_object()
456 && !v.is_null()
457 {
458 return Err(api_key_error(ApiKeyErrorCode::InvalidMetadataType));
459 }
460 Ok(())
461 }
462
463 pub(super) fn validate_refill(
464 refill_interval: Option<i64>,
465 refill_amount: Option<i64>,
466 ) -> AuthResult<()> {
467 match (refill_interval, refill_amount) {
468 (Some(_), None) | (None, Some(_)) => Err(api_key_error(
469 ApiKeyErrorCode::RefillAmountAndIntervalRequired,
470 )),
471 _ => Ok(()),
472 }
473 }
474
475 async fn handle_create<DB: DatabaseAdapter>(
480 &self,
481 req: &AuthRequest,
482 ctx: &AuthContext<DB>,
483 ) -> AuthResult<AuthResponse> {
484 let (user, _session) = ctx.require_session(req).await?;
485 let body: CreateKeyRequest = match better_auth_core::validate_request_body(req) {
486 Ok(v) => v,
487 Err(resp) => return Ok(resp),
488 };
489 let response = create_key_core(&body, user.id(), self, ctx).await?;
490 Ok(AuthResponse::json(200, &response)?)
491 }
492
493 async fn handle_get<DB: DatabaseAdapter>(
494 &self,
495 req: &AuthRequest,
496 ctx: &AuthContext<DB>,
497 ) -> AuthResult<AuthResponse> {
498 let (user, _session) = ctx.require_session(req).await?;
499 let id = req
500 .query
501 .get("id")
502 .ok_or_else(|| AuthError::bad_request("Query parameter 'id' is required"))?;
503 let response = get_key_core(id, user.id(), self, ctx).await?;
504 Ok(AuthResponse::json(200, &response)?)
505 }
506
507 async fn handle_list<DB: DatabaseAdapter>(
508 &self,
509 req: &AuthRequest,
510 ctx: &AuthContext<DB>,
511 ) -> AuthResult<AuthResponse> {
512 let (user, _session) = ctx.require_session(req).await?;
513 let response = list_keys_core(user.id(), self, ctx).await?;
514 Ok(AuthResponse::json(200, &response)?)
515 }
516
517 async fn handle_update<DB: DatabaseAdapter>(
518 &self,
519 req: &AuthRequest,
520 ctx: &AuthContext<DB>,
521 ) -> AuthResult<AuthResponse> {
522 let (user, _session) = ctx.require_session(req).await?;
523 let body: UpdateKeyRequest = match better_auth_core::validate_request_body(req) {
524 Ok(v) => v,
525 Err(resp) => return Ok(resp),
526 };
527 let response = update_key_core(&body, user.id(), self, ctx).await?;
528 Ok(AuthResponse::json(200, &response)?)
529 }
530
531 async fn handle_delete<DB: DatabaseAdapter>(
532 &self,
533 req: &AuthRequest,
534 ctx: &AuthContext<DB>,
535 ) -> AuthResult<AuthResponse> {
536 let (user, _session) = ctx.require_session(req).await?;
537 let body: DeleteKeyRequest = match better_auth_core::validate_request_body(req) {
538 Ok(v) => v,
539 Err(resp) => return Ok(resp),
540 };
541 let response = delete_key_core(&body, user.id(), self, ctx).await?;
542 Ok(AuthResponse::json(200, &response)?)
543 }
544
545 async fn handle_verify<DB: DatabaseAdapter>(
550 &self,
551 req: &AuthRequest,
552 ctx: &AuthContext<DB>,
553 ) -> AuthResult<AuthResponse> {
554 let verify_req: VerifyKeyRequest = match better_auth_core::validate_request_body(req) {
555 Ok(v) => v,
556 Err(resp) => return Ok(resp),
557 };
558 let response = verify_key_core(&verify_req, self, ctx).await?;
559 Ok(AuthResponse::json(200, &response)?)
560 }
561
562 async fn validate_api_key<DB: DatabaseAdapter>(
570 &self,
571 ctx: &AuthContext<DB>,
572 raw_key: &str,
573 required_permissions: Option<&serde_json::Value>,
574 ) -> Result<ApiKeyView, ApiKeyValidationError> {
575 let hashed = if self.config.disable_key_hashing {
577 raw_key.to_string()
578 } else {
579 Self::hash_key(raw_key)
580 };
581
582 let api_key = ctx
584 .database
585 .get_api_key_by_hash(&hashed)
586 .await
587 .map_err(|_| ApiKeyValidationError::new(ApiKeyErrorCode::InvalidApiKey))?
588 .ok_or_else(|| ApiKeyValidationError::new(ApiKeyErrorCode::InvalidApiKey))?;
589
590 if !api_key.enabled() {
592 return Err(ApiKeyValidationError::new(ApiKeyErrorCode::KeyDisabled));
593 }
594
595 if let Some(expires_at_str) = api_key.expires_at()
597 && let Ok(expires_at) = chrono::DateTime::parse_from_rfc3339(expires_at_str)
598 && chrono::Utc::now() > expires_at
599 {
600 let _ = ctx.database.delete_api_key(api_key.id()).await;
602 self.rate_limiters
603 .lock()
604 .expect("rate_limiters mutex poisoned")
605 .remove(api_key.id());
606 return Err(ApiKeyValidationError::new(ApiKeyErrorCode::KeyExpired));
607 }
608
609 if let Some(required) = required_permissions {
611 let key_perms_str = api_key.permissions().unwrap_or("");
612 if key_perms_str.is_empty() {
613 return Err(ApiKeyValidationError::new(ApiKeyErrorCode::KeyNotFound));
614 }
615 if !check_permissions(key_perms_str, required) {
616 return Err(ApiKeyValidationError::new(ApiKeyErrorCode::KeyNotFound));
617 }
618 }
619
620 let mut new_remaining = api_key.remaining();
622 let mut new_last_refill_at: Option<String> =
623 api_key.last_refill_at().map(|s| s.to_string());
624
625 if let Some(0) = api_key.remaining()
626 && api_key.refill_amount().is_none()
627 {
628 let _ = ctx.database.delete_api_key(api_key.id()).await;
630 self.rate_limiters
631 .lock()
632 .expect("rate_limiters mutex poisoned")
633 .remove(api_key.id());
634 return Err(ApiKeyValidationError::new(ApiKeyErrorCode::UsageExceeded));
635 }
636
637 if let Some(remaining) = api_key.remaining() {
638 let refill_interval = api_key.refill_interval();
639 let refill_amount = api_key.refill_amount();
640 let mut current_remaining = remaining;
641
642 if let (Some(interval), Some(amount)) = (refill_interval, refill_amount) {
643 let now = chrono::Utc::now();
644 let last_time_str = api_key
645 .last_refill_at()
646 .or_else(|| Some(api_key.created_at()));
647 if let Some(last_str) = last_time_str
648 && let Ok(last_dt) = chrono::DateTime::parse_from_rfc3339(last_str)
649 {
650 let elapsed_ms = (now - last_dt.with_timezone(&chrono::Utc)).num_milliseconds();
651 if elapsed_ms > interval {
652 current_remaining = amount;
653 new_last_refill_at = Some(now.to_rfc3339());
654 }
655 }
656 }
657
658 if current_remaining <= 0 {
659 return Err(ApiKeyValidationError::new(ApiKeyErrorCode::UsageExceeded));
660 }
661
662 new_remaining = Some(current_remaining - 1);
663 }
664
665 self.check_rate_limit_governor(&api_key)?;
667
668 let mut update = UpdateApiKey {
670 remaining: new_remaining,
671 ..Default::default()
672 };
673 if new_last_refill_at != api_key.last_refill_at().map(|s| s.to_string()) {
674 update.last_refill_at = Some(new_last_refill_at);
675 }
676
677 let updated = ctx
678 .database
679 .update_api_key(api_key.id(), update)
680 .await
681 .map_err(|_| ApiKeyValidationError::new(ApiKeyErrorCode::FailedToUpdateApiKey))?;
682
683 self.maybe_delete_expired(ctx).await;
685
686 Ok(ApiKeyView::from_entity(&updated))
687 }
688
689 fn check_rate_limit_governor(
695 &self,
696 api_key: &impl AuthApiKey,
697 ) -> Result<(), ApiKeyValidationError> {
698 let key_has_explicit_setting =
700 api_key.rate_limit_time_window().is_some() || api_key.rate_limit_max().is_some();
701 let key_enabled = api_key.rate_limit_enabled();
702
703 if !key_enabled {
704 if key_has_explicit_setting {
706 return Ok(());
707 }
708 if !self.config.rate_limit.enabled {
710 return Ok(());
711 }
712 }
713
714 let time_window_ms = api_key
715 .rate_limit_time_window()
716 .unwrap_or(self.config.rate_limit.time_window);
717 let max_requests = api_key
718 .rate_limit_max()
719 .unwrap_or(self.config.rate_limit.max_requests);
720
721 if time_window_ms <= 0 || max_requests <= 0 {
722 return Ok(());
723 }
724
725 let key_id = api_key.id().to_string();
726
727 let limiter = {
729 let mut limiters = self
730 .rate_limiters
731 .lock()
732 .expect("rate_limiters mutex poisoned");
733 limiters
734 .entry(key_id)
735 .or_insert_with(|| {
736 let max = NonZeroU32::new(max_requests as u32).unwrap_or(NonZeroU32::MIN);
737 let period_ms = (time_window_ms as u64)
738 .checked_div(max_requests as u64)
739 .unwrap_or(0);
740 let period = std::time::Duration::from_millis(period_ms.max(1));
742 let quota = Quota::with_period(period)
743 .expect("period >= 1ms is always valid")
744 .allow_burst(max);
745 std::sync::Arc::new(RateLimiter::direct(quota))
746 })
747 .clone()
748 };
749
750 match limiter.check() {
751 Ok(()) => Ok(()),
752 Err(_not_until) => Err(ApiKeyValidationError::new(ApiKeyErrorCode::RateLimited)),
753 }
754 }
755
756 async fn handle_delete_all_expired<DB: DatabaseAdapter>(
761 &self,
762 req: &AuthRequest,
763 ctx: &AuthContext<DB>,
764 ) -> AuthResult<AuthResponse> {
765 let (user, _session) = ctx.require_session(req).await?;
766 let response = delete_all_expired_core(user.id(), self, ctx).await?;
767 Ok(AuthResponse::json(200, &response)?)
768 }
769}
770
771better_auth_core::impl_auth_plugin! {
776 ApiKeyPlugin, "api-key";
777 routes {
778 post "/api-key/create" => handle_create, "api_key_create";
779 get "/api-key/get" => handle_get, "api_key_get";
780 post "/api-key/update" => handle_update, "api_key_update";
781 post "/api-key/delete" => handle_delete, "api_key_delete";
782 get "/api-key/list" => handle_list, "api_key_list";
783 post "/api-key/verify" => handle_verify, "api_key_verify";
784 post "/api-key/delete-all-expired-api-keys" => handle_delete_all_expired, "api_key_delete_all_expired";
785 }
786 extra {
787 async fn before_request(
788 &self,
789 req: &AuthRequest,
790 ctx: &AuthContext<DB>,
791 ) -> AuthResult<Option<BeforeRequestAction>> {
792 if !self.config.enable_session_for_api_keys {
793 return Ok(None);
794 }
795
796 let raw_key = match req.headers.get(&self.config.api_key_header) {
798 Some(k) if !k.is_empty() => k.clone(),
799 _ => return Ok(None),
800 };
801
802 if req.path().starts_with("/api-key/") {
806 return Ok(None);
807 }
808
809 let view = self
811 .validate_api_key(ctx, &raw_key, None)
812 .await
813 .map_err(|e| AuthError::bad_request(e.message))?;
814
815 let user = ctx
817 .database
818 .get_user_by_id(&view.user_id)
819 .await?
820 .ok_or_else(|| api_key_error(ApiKeyErrorCode::InvalidUserIdFromApiKey))?;
821
822 if req.path() == "/get-session" {
824 let session_json = serde_json::json!({
825 "user": {
826 "id": user.id(),
827 "email": user.email(),
828 "name": user.name(),
829 },
830 "session": {
831 "id": view.id,
832 "token": raw_key,
833 "userId": view.user_id,
834 }
835 });
836 return Ok(Some(BeforeRequestAction::Respond(AuthResponse::json(
837 200,
838 &session_json,
839 )?)));
840 }
841
842 Ok(Some(BeforeRequestAction::InjectSession {
844 user_id: view.user_id,
845 session_token: raw_key,
846 }))
847 }
848 }
849}
850
/// Axum integration: one extractor-based handler per route plus the
/// `AxumPlugin` implementation that mounts them.
///
/// Each handler receives the shared plugin through an `Extension` layer and
/// delegates to the matching `*_core` function.
#[cfg(feature = "axum")]
mod axum_impl {
    use super::*;
    use std::sync::Arc;

    use axum::Json;
    use axum::extract::{Extension, Query, State};
    use better_auth_core::{AuthState, CurrentSession, ValidatedJson};
    use serde::Deserialize;

    /// Query string accepted by `GET /api-key/get`.
    #[derive(Debug, Deserialize)]
    struct GetKeyQuery {
        /// Identifier of the key to fetch.
        id: String,
    }

    /// `POST /api-key/create` — requires a session and a validated body.
    async fn handle_create<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(plugin): Extension<Arc<ApiKeyPlugin>>,
        CurrentSession { user, .. }: CurrentSession<DB>,
        ValidatedJson(body): ValidatedJson<CreateKeyRequest>,
    ) -> Result<Json<CreateKeyResponse>, AuthError> {
        let auth_ctx = state.to_context();
        let created = create_key_core(&body, user.id(), &plugin, &auth_ctx).await?;
        Ok(Json(created))
    }

    /// `GET /api-key/get` — requires a session and the `id` query parameter.
    async fn handle_get<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(plugin): Extension<Arc<ApiKeyPlugin>>,
        CurrentSession { user, .. }: CurrentSession<DB>,
        Query(query): Query<GetKeyQuery>,
    ) -> Result<Json<ApiKeyView>, AuthError> {
        let auth_ctx = state.to_context();
        let found = get_key_core(&query.id, user.id(), &plugin, &auth_ctx).await?;
        Ok(Json(found))
    }

    /// `GET /api-key/list` — requires a session.
    async fn handle_list<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(plugin): Extension<Arc<ApiKeyPlugin>>,
        CurrentSession { user, .. }: CurrentSession<DB>,
    ) -> Result<Json<Vec<ApiKeyView>>, AuthError> {
        let auth_ctx = state.to_context();
        let keys = list_keys_core(user.id(), &plugin, &auth_ctx).await?;
        Ok(Json(keys))
    }

    /// `POST /api-key/update` — requires a session and a validated body.
    async fn handle_update<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(plugin): Extension<Arc<ApiKeyPlugin>>,
        CurrentSession { user, .. }: CurrentSession<DB>,
        ValidatedJson(body): ValidatedJson<UpdateKeyRequest>,
    ) -> Result<Json<ApiKeyView>, AuthError> {
        let auth_ctx = state.to_context();
        let updated = update_key_core(&body, user.id(), &plugin, &auth_ctx).await?;
        Ok(Json(updated))
    }

    /// `POST /api-key/delete` — requires a session and a validated body.
    async fn handle_delete<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(plugin): Extension<Arc<ApiKeyPlugin>>,
        CurrentSession { user, .. }: CurrentSession<DB>,
        ValidatedJson(body): ValidatedJson<DeleteKeyRequest>,
    ) -> Result<Json<serde_json::Value>, AuthError> {
        let auth_ctx = state.to_context();
        let outcome = delete_key_core(&body, user.id(), &plugin, &auth_ctx).await?;
        Ok(Json(outcome))
    }

    /// `POST /api-key/verify` — no session extractor; the body is taken with
    /// plain `Json` rather than `ValidatedJson`.
    async fn handle_verify<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(plugin): Extension<Arc<ApiKeyPlugin>>,
        Json(body): Json<VerifyKeyRequest>,
    ) -> Result<Json<VerifyKeyResponse>, AuthError> {
        let auth_ctx = state.to_context();
        let verdict = verify_key_core(&body, &plugin, &auth_ctx).await?;
        Ok(Json(verdict))
    }

    /// `POST /api-key/delete-all-expired-api-keys` — requires a session.
    async fn handle_delete_all_expired<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(plugin): Extension<Arc<ApiKeyPlugin>>,
        CurrentSession { user, .. }: CurrentSession<DB>,
    ) -> Result<Json<serde_json::Value>, AuthError> {
        let auth_ctx = state.to_context();
        let outcome = delete_all_expired_core(user.id(), &plugin, &auth_ctx).await?;
        Ok(Json(outcome))
    }

    impl<DB: DatabaseAdapter> better_auth_core::AxumPlugin<DB> for ApiKeyPlugin {
        /// Name under which the plugin is registered.
        fn name(&self) -> &'static str {
            "api-key"
        }

        /// Builds the router exposing every API-key endpoint.
        ///
        /// The handlers share a *new* plugin instance created from a clone of
        /// this plugin's configuration, so they start with their own empty
        /// rate-limiter map and expired-key sweep timestamp rather than
        /// sharing `self`'s.
        fn router(&self) -> axum::Router<AuthState<DB>> {
            use axum::routing::{get, post};

            let shared_plugin = Arc::new(ApiKeyPlugin::with_config(self.config.clone()));

            let routes = axum::Router::new()
                .route("/api-key/create", post(handle_create::<DB>))
                .route("/api-key/get", get(handle_get::<DB>))
                .route("/api-key/update", post(handle_update::<DB>))
                .route("/api-key/delete", post(handle_delete::<DB>))
                .route("/api-key/list", get(handle_list::<DB>))
                .route("/api-key/verify", post(handle_verify::<DB>))
                .route(
                    "/api-key/delete-all-expired-api-keys",
                    post(handle_delete_all_expired::<DB>),
                );

            routes.layer(Extension(shared_plugin))
        }
    }
}