Skip to main content

better_auth_api/plugins/api_key/
mod.rs

1use base64::Engine;
2use base64::engine::general_purpose::URL_SAFE_NO_PAD;
3use governor::clock::DefaultClock;
4use governor::state::{InMemoryState, NotKeyed};
5use governor::{Quota, RateLimiter};
6use rand::RngCore;
7use serde::Serialize;
8use sha2::{Digest, Sha256};
9use std::collections::HashMap;
10use std::num::NonZeroU32;
11use std::sync::Mutex;
12
13use better_auth_core::adapters::DatabaseAdapter;
14use better_auth_core::entity::{AuthApiKey, AuthUser};
15use better_auth_core::{AuthContext, AuthError, AuthResult, BeforeRequestAction};
16use better_auth_core::{AuthRequest, AuthResponse, UpdateApiKey};
17
18pub(super) mod handlers;
19pub(super) mod types;
20
21#[cfg(test)]
22mod tests;
23
24use handlers::*;
25use types::*;
26
27// ---------------------------------------------------------------------------
28// Error codes -- mirrors the TypeScript `API_KEY_ERROR_CODES`
29// ---------------------------------------------------------------------------
30
31/// Dedicated API Key error codes aligned with the TypeScript implementation.
32#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
33pub enum ApiKeyErrorCode {
34    #[serde(rename = "INVALID_API_KEY")]
35    InvalidApiKey,
36    #[serde(rename = "KEY_DISABLED")]
37    KeyDisabled,
38    #[serde(rename = "KEY_EXPIRED")]
39    KeyExpired,
40    #[serde(rename = "USAGE_EXCEEDED")]
41    UsageExceeded,
42    #[serde(rename = "KEY_NOT_FOUND")]
43    KeyNotFound,
44    #[serde(rename = "RATE_LIMITED")]
45    RateLimited,
46    #[serde(rename = "UNAUTHORIZED_SESSION")]
47    UnauthorizedSession,
48    #[serde(rename = "INVALID_PREFIX_LENGTH")]
49    InvalidPrefixLength,
50    #[serde(rename = "INVALID_NAME_LENGTH")]
51    InvalidNameLength,
52    #[serde(rename = "METADATA_DISABLED")]
53    MetadataDisabled,
54    #[serde(rename = "NO_VALUES_TO_UPDATE")]
55    NoValuesToUpdate,
56    #[serde(rename = "KEY_DISABLED_EXPIRATION")]
57    KeyDisabledExpiration,
58    #[serde(rename = "EXPIRES_IN_IS_TOO_SMALL")]
59    ExpiresInTooSmall,
60    #[serde(rename = "EXPIRES_IN_IS_TOO_LARGE")]
61    ExpiresInTooLarge,
62    #[serde(rename = "INVALID_REMAINING")]
63    InvalidRemaining,
64    #[serde(rename = "REFILL_AMOUNT_AND_INTERVAL_REQUIRED")]
65    RefillAmountAndIntervalRequired,
66    #[serde(rename = "NAME_REQUIRED")]
67    NameRequired,
68    #[serde(rename = "INVALID_USER_ID_FROM_API_KEY")]
69    InvalidUserIdFromApiKey,
70    #[serde(rename = "SERVER_ONLY_PROPERTY")]
71    ServerOnlyProperty,
72    #[serde(rename = "FAILED_TO_UPDATE_API_KEY")]
73    FailedToUpdateApiKey,
74    #[serde(rename = "INVALID_METADATA_TYPE")]
75    InvalidMetadataType,
76}
77
78impl ApiKeyErrorCode {
79    /// Return the SCREAMING_SNAKE_CASE string for this error code.
80    /// Used by `handle_verify` to produce the structured JSON error response.
81    pub fn as_str(self) -> &'static str {
82        match self {
83            Self::InvalidApiKey => "INVALID_API_KEY",
84            Self::KeyDisabled => "KEY_DISABLED",
85            Self::KeyExpired => "KEY_EXPIRED",
86            Self::UsageExceeded => "USAGE_EXCEEDED",
87            Self::KeyNotFound => "KEY_NOT_FOUND",
88            Self::RateLimited => "RATE_LIMITED",
89            Self::UnauthorizedSession => "UNAUTHORIZED_SESSION",
90            Self::InvalidPrefixLength => "INVALID_PREFIX_LENGTH",
91            Self::InvalidNameLength => "INVALID_NAME_LENGTH",
92            Self::MetadataDisabled => "METADATA_DISABLED",
93            Self::NoValuesToUpdate => "NO_VALUES_TO_UPDATE",
94            Self::KeyDisabledExpiration => "KEY_DISABLED_EXPIRATION",
95            Self::ExpiresInTooSmall => "EXPIRES_IN_IS_TOO_SMALL",
96            Self::ExpiresInTooLarge => "EXPIRES_IN_IS_TOO_LARGE",
97            Self::InvalidRemaining => "INVALID_REMAINING",
98            Self::RefillAmountAndIntervalRequired => "REFILL_AMOUNT_AND_INTERVAL_REQUIRED",
99            Self::NameRequired => "NAME_REQUIRED",
100            Self::InvalidUserIdFromApiKey => "INVALID_USER_ID_FROM_API_KEY",
101            Self::ServerOnlyProperty => "SERVER_ONLY_PROPERTY",
102            Self::FailedToUpdateApiKey => "FAILED_TO_UPDATE_API_KEY",
103            Self::InvalidMetadataType => "INVALID_METADATA_TYPE",
104        }
105    }
106
107    pub fn message(self) -> &'static str {
108        match self {
109            Self::InvalidApiKey => "Invalid API key.",
110            Self::KeyDisabled => "API Key is disabled",
111            Self::KeyExpired => "API Key has expired",
112            Self::UsageExceeded => "API Key has reached its usage limit",
113            Self::KeyNotFound => "API Key not found",
114            Self::RateLimited => "Rate limit exceeded.",
115            Self::UnauthorizedSession => "Unauthorized or invalid session",
116            Self::InvalidPrefixLength => "The prefix length is either too large or too small.",
117            Self::InvalidNameLength => "The name length is either too large or too small.",
118            Self::MetadataDisabled => "Metadata is disabled.",
119            Self::NoValuesToUpdate => "No values to update.",
120            Self::KeyDisabledExpiration => "Custom key expiration values are disabled.",
121            Self::ExpiresInTooSmall => {
122                "The expiresIn is smaller than the predefined minimum value."
123            }
124            Self::ExpiresInTooLarge => "The expiresIn is larger than the predefined maximum value.",
125            Self::InvalidRemaining => "The remaining count is either too large or too small.",
126            Self::RefillAmountAndIntervalRequired => {
127                "refillAmount and refillInterval must both be provided together"
128            }
129            Self::NameRequired => "API Key name is required.",
130            Self::InvalidUserIdFromApiKey => "The user id from the API key is invalid.",
131            Self::ServerOnlyProperty => {
132                "The property you're trying to set can only be set from the server auth instance only."
133            }
134            Self::FailedToUpdateApiKey => "Failed to update API key",
135            Self::InvalidMetadataType => "metadata must be an object or undefined",
136        }
137    }
138}
139
140fn api_key_error(code: ApiKeyErrorCode) -> AuthError {
141    AuthError::bad_request(code.message())
142}
143
144/// Structured error returned by `validate_api_key` so that `handle_verify`
145/// can extract the error code without fragile string matching.
146struct ApiKeyValidationError {
147    code: ApiKeyErrorCode,
148    message: String,
149}
150
151impl ApiKeyValidationError {
152    fn new(code: ApiKeyErrorCode) -> Self {
153        Self {
154            message: code.message().to_string(),
155            code,
156        }
157    }
158}
159
160// ---------------------------------------------------------------------------
161// Configuration
162// ---------------------------------------------------------------------------
163
164/// API Key management plugin.
165pub struct ApiKeyPlugin {
166    pub(super) config: ApiKeyConfig,
167    /// Throttle for `delete_expired_api_keys` -- stores the last check instant.
168    last_expired_check: Mutex<Option<std::time::Instant>>,
169    /// Per-key in-memory rate limiters backed by the `governor` crate.
170    /// Key: API key ID -> governor rate limiter.
171    pub(super) rate_limiters: Mutex<HashMap<String, std::sync::Arc<GovernorLimiter>>>,
172}
173
174/// Type alias for the governor rate limiter we use (not keyed, in-memory, default clock).
175type GovernorLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
176
177/// Configuration for the API Key plugin, aligned with the TypeScript `ApiKeyOptions`.
178#[derive(Debug, Clone)]
179pub struct ApiKeyConfig {
180    // -- key generation --
181    pub key_length: usize,
182    pub prefix: Option<String>,
183    pub default_remaining: Option<i64>,
184
185    // -- header --
186    pub api_key_header: String,
187
188    // -- hashing --
189    pub disable_key_hashing: bool,
190
191    // -- starting characters --
192    pub starting_characters_length: usize,
193    pub store_starting_characters: bool,
194
195    // -- prefix length validation --
196    pub max_prefix_length: usize,
197    pub min_prefix_length: usize,
198
199    // -- name validation --
200    pub max_name_length: usize,
201    pub min_name_length: usize,
202    pub require_name: bool,
203
204    // -- metadata --
205    pub enable_metadata: bool,
206
207    // -- key expiration --
208    pub key_expiration: KeyExpirationConfig,
209
210    // -- rate limit defaults --
211    pub rate_limit: RateLimitDefaults,
212
213    // -- session emulation --
214    pub enable_session_for_api_keys: bool,
215}
216
/// Expiration defaults and bounds for API keys.
#[derive(Debug, Clone)]
pub struct KeyExpirationConfig {
    /// `expiresIn` (milliseconds) used when the request supplies none; `None` = never expires.
    pub default_expires_in: Option<i64>,
    /// When set, a client-supplied `expiresIn` is rejected.
    pub disable_custom_expires_time: bool,
    /// Largest accepted `expiresIn`, expressed in **days**.
    pub max_expires_in: i64,
    /// Smallest accepted `expiresIn`, expressed in **days**.
    pub min_expires_in: i64,
}

impl Default for KeyExpirationConfig {
    /// No default expiry, custom values allowed, bounded to `0..=365` days.
    fn default() -> Self {
        let (min_days, max_days) = (0, 365);
        KeyExpirationConfig {
            default_expires_in: None,
            disable_custom_expires_time: false,
            max_expires_in: max_days,
            min_expires_in: min_days,
        }
    }
}
240
/// Global rate-limit settings used by keys that carry no settings of their own.
#[derive(Debug, Clone)]
pub struct RateLimitDefaults {
    /// Whether rate limiting applies by default.
    pub enabled: bool,
    /// Window length in milliseconds.
    pub time_window: i64,
    /// Requests allowed per window.
    pub max_requests: i64,
}

impl Default for RateLimitDefaults {
    /// Enabled, 10 requests per 24-hour window.
    fn default() -> Self {
        const ONE_DAY_MS: i64 = 24 * 60 * 60 * 1000;
        RateLimitDefaults {
            enabled: true,
            time_window: ONE_DAY_MS,
            max_requests: 10,
        }
    }
}
260
261impl Default for ApiKeyConfig {
262    fn default() -> Self {
263        Self {
264            key_length: 32,
265            prefix: None,
266            default_remaining: None,
267            api_key_header: "x-api-key".to_string(),
268            disable_key_hashing: false,
269            starting_characters_length: 6,
270            store_starting_characters: true,
271            max_prefix_length: 32,
272            min_prefix_length: 1,
273            max_name_length: 32,
274            min_name_length: 1,
275            require_name: false,
276            enable_metadata: false,
277            key_expiration: KeyExpirationConfig::default(),
278            rate_limit: RateLimitDefaults::default(),
279            enable_session_for_api_keys: false,
280        }
281    }
282}
283
284// ---------------------------------------------------------------------------
285// Plugin implementation
286// ---------------------------------------------------------------------------
287
288/// Builder for [`ApiKeyPlugin`] powered by the `bon` crate.
289///
290/// Usage:
291/// ```ignore
292/// let plugin = ApiKeyPlugin::builder()
293///     .key_length(48)
294///     .prefix("ba_".to_string())
295///     .enable_metadata(true)
296///     .rate_limit(RateLimitDefaults { enabled: true, time_window: 60_000, max_requests: 5 })
297///     .build();
298/// ```
299#[bon::bon]
300impl ApiKeyPlugin {
301    #[builder]
302    pub fn new(
303        #[builder(default = 32)] key_length: usize,
304        prefix: Option<String>,
305        default_remaining: Option<i64>,
306        #[builder(default = "x-api-key".to_string())] api_key_header: String,
307        #[builder(default = false)] disable_key_hashing: bool,
308        #[builder(default = 6)] starting_characters_length: usize,
309        #[builder(default = true)] store_starting_characters: bool,
310        #[builder(default = 32)] max_prefix_length: usize,
311        #[builder(default = 1)] min_prefix_length: usize,
312        #[builder(default = 32)] max_name_length: usize,
313        #[builder(default = 1)] min_name_length: usize,
314        #[builder(default = false)] require_name: bool,
315        #[builder(default = false)] enable_metadata: bool,
316        #[builder(default)] key_expiration: KeyExpirationConfig,
317        #[builder(default)] rate_limit: RateLimitDefaults,
318        #[builder(default = false)] enable_session_for_api_keys: bool,
319    ) -> Self {
320        Self {
321            config: ApiKeyConfig {
322                key_length,
323                prefix,
324                default_remaining,
325                api_key_header,
326                disable_key_hashing,
327                starting_characters_length,
328                store_starting_characters,
329                max_prefix_length,
330                min_prefix_length,
331                max_name_length,
332                min_name_length,
333                require_name,
334                enable_metadata,
335                key_expiration,
336                rate_limit,
337                enable_session_for_api_keys,
338            },
339            last_expired_check: Mutex::new(None),
340            rate_limiters: Mutex::new(HashMap::new()),
341        }
342    }
343
344    pub fn with_config(config: ApiKeyConfig) -> Self {
345        Self {
346            config,
347            last_expired_check: Mutex::new(None),
348            rate_limiters: Mutex::new(HashMap::new()),
349        }
350    }
351
352    // -- internal helpers --
353
354    pub(super) fn generate_key(&self, custom_prefix: Option<&str>) -> (String, String, String) {
355        let mut bytes = vec![0u8; self.config.key_length];
356        rand::rngs::OsRng.fill_bytes(&mut bytes);
357        let raw = URL_SAFE_NO_PAD.encode(&bytes);
358
359        let start_len = self.config.starting_characters_length;
360        let start = raw.chars().take(start_len).collect::<String>();
361
362        let prefix = custom_prefix
363            .or(self.config.prefix.as_deref())
364            .unwrap_or("");
365        let full_key = format!("{}{}", prefix, raw);
366
367        let hash = if self.config.disable_key_hashing {
368            full_key.clone()
369        } else {
370            Self::hash_key(&full_key)
371        };
372
373        (full_key, hash, start)
374    }
375
376    fn hash_key(key: &str) -> String {
377        let mut hasher = Sha256::new();
378        hasher.update(key.as_bytes());
379        let digest = hasher.finalize();
380        URL_SAFE_NO_PAD.encode(digest)
381    }
382
383    /// Throttled cleanup -- at most once per 10 seconds.
384    pub(super) async fn maybe_delete_expired<DB: DatabaseAdapter>(&self, ctx: &AuthContext<DB>) {
385        let should_run = {
386            let mut last = self.last_expired_check.lock().unwrap();
387            let now = std::time::Instant::now();
388            match *last {
389                Some(prev) if now.duration_since(prev).as_secs() < 10 => false,
390                _ => {
391                    *last = Some(now);
392                    true
393                }
394            }
395        };
396        if should_run {
397            let _ = ctx.database.delete_expired_api_keys().await;
398        }
399    }
400
401    // -- Validation helpers --
402
403    pub(super) fn validate_prefix(&self, prefix: Option<&str>) -> AuthResult<()> {
404        if let Some(p) = prefix {
405            let len = p.len();
406            if len < self.config.min_prefix_length || len > self.config.max_prefix_length {
407                return Err(api_key_error(ApiKeyErrorCode::InvalidPrefixLength));
408            }
409        }
410        Ok(())
411    }
412
413    /// Validate the `name` field.
414    ///
415    /// When `is_create` is true, `require_name` is enforced (name must be
416    /// present).  On updates `require_name` is **not** enforced -- the
417    /// caller may be updating unrelated fields without resending the name.
418    pub(super) fn validate_name(&self, name: Option<&str>, is_create: bool) -> AuthResult<()> {
419        if is_create && self.config.require_name && name.is_none() {
420            return Err(api_key_error(ApiKeyErrorCode::NameRequired));
421        }
422        if let Some(n) = name {
423            let len = n.len();
424            if len < self.config.min_name_length || len > self.config.max_name_length {
425                return Err(api_key_error(ApiKeyErrorCode::InvalidNameLength));
426            }
427        }
428        Ok(())
429    }
430
431    pub(super) fn validate_expires_in(&self, expires_in: Option<i64>) -> AuthResult<Option<i64>> {
432        let cfg = &self.config.key_expiration;
433        if let Some(ms) = expires_in {
434            if cfg.disable_custom_expires_time {
435                return Err(api_key_error(ApiKeyErrorCode::KeyDisabledExpiration));
436            }
437            let days = ms as f64 / 86_400_000.0;
438            if days < cfg.min_expires_in as f64 {
439                return Err(api_key_error(ApiKeyErrorCode::ExpiresInTooSmall));
440            }
441            if days > cfg.max_expires_in as f64 {
442                return Err(api_key_error(ApiKeyErrorCode::ExpiresInTooLarge));
443            }
444            Ok(Some(ms))
445        } else {
446            Ok(cfg.default_expires_in)
447        }
448    }
449
450    pub(super) fn validate_metadata(&self, metadata: &Option<serde_json::Value>) -> AuthResult<()> {
451        if metadata.is_some() && !self.config.enable_metadata {
452            return Err(api_key_error(ApiKeyErrorCode::MetadataDisabled));
453        }
454        if let Some(v) = metadata
455            && !v.is_object()
456            && !v.is_null()
457        {
458            return Err(api_key_error(ApiKeyErrorCode::InvalidMetadataType));
459        }
460        Ok(())
461    }
462
463    pub(super) fn validate_refill(
464        refill_interval: Option<i64>,
465        refill_amount: Option<i64>,
466    ) -> AuthResult<()> {
467        match (refill_interval, refill_amount) {
468            (Some(_), None) | (None, Some(_)) => Err(api_key_error(
469                ApiKeyErrorCode::RefillAmountAndIntervalRequired,
470            )),
471            _ => Ok(()),
472        }
473    }
474
475    // -----------------------------------------------------------------------
476    // Route handlers (old -- delegate to core functions)
477    // -----------------------------------------------------------------------
478
479    async fn handle_create<DB: DatabaseAdapter>(
480        &self,
481        req: &AuthRequest,
482        ctx: &AuthContext<DB>,
483    ) -> AuthResult<AuthResponse> {
484        let (user, _session) = ctx.require_session(req).await?;
485        let body: CreateKeyRequest = match better_auth_core::validate_request_body(req) {
486            Ok(v) => v,
487            Err(resp) => return Ok(resp),
488        };
489        let response = create_key_core(&body, user.id(), self, ctx).await?;
490        Ok(AuthResponse::json(200, &response)?)
491    }
492
493    async fn handle_get<DB: DatabaseAdapter>(
494        &self,
495        req: &AuthRequest,
496        ctx: &AuthContext<DB>,
497    ) -> AuthResult<AuthResponse> {
498        let (user, _session) = ctx.require_session(req).await?;
499        let id = req
500            .query
501            .get("id")
502            .ok_or_else(|| AuthError::bad_request("Query parameter 'id' is required"))?;
503        let response = get_key_core(id, user.id(), self, ctx).await?;
504        Ok(AuthResponse::json(200, &response)?)
505    }
506
507    async fn handle_list<DB: DatabaseAdapter>(
508        &self,
509        req: &AuthRequest,
510        ctx: &AuthContext<DB>,
511    ) -> AuthResult<AuthResponse> {
512        let (user, _session) = ctx.require_session(req).await?;
513        let response = list_keys_core(user.id(), self, ctx).await?;
514        Ok(AuthResponse::json(200, &response)?)
515    }
516
517    async fn handle_update<DB: DatabaseAdapter>(
518        &self,
519        req: &AuthRequest,
520        ctx: &AuthContext<DB>,
521    ) -> AuthResult<AuthResponse> {
522        let (user, _session) = ctx.require_session(req).await?;
523        let body: UpdateKeyRequest = match better_auth_core::validate_request_body(req) {
524            Ok(v) => v,
525            Err(resp) => return Ok(resp),
526        };
527        let response = update_key_core(&body, user.id(), self, ctx).await?;
528        Ok(AuthResponse::json(200, &response)?)
529    }
530
531    async fn handle_delete<DB: DatabaseAdapter>(
532        &self,
533        req: &AuthRequest,
534        ctx: &AuthContext<DB>,
535    ) -> AuthResult<AuthResponse> {
536        let (user, _session) = ctx.require_session(req).await?;
537        let body: DeleteKeyRequest = match better_auth_core::validate_request_body(req) {
538            Ok(v) => v,
539            Err(resp) => return Ok(resp),
540        };
541        let response = delete_key_core(&body, user.id(), self, ctx).await?;
542        Ok(AuthResponse::json(200, &response)?)
543    }
544
545    // -----------------------------------------------------------------------
546    // POST /api-key/verify -- core verification endpoint
547    // -----------------------------------------------------------------------
548
549    async fn handle_verify<DB: DatabaseAdapter>(
550        &self,
551        req: &AuthRequest,
552        ctx: &AuthContext<DB>,
553    ) -> AuthResult<AuthResponse> {
554        let verify_req: VerifyKeyRequest = match better_auth_core::validate_request_body(req) {
555            Ok(v) => v,
556            Err(resp) => return Ok(resp),
557        };
558        let response = verify_key_core(&verify_req, self, ctx).await?;
559        Ok(AuthResponse::json(200, &response)?)
560    }
561
562    /// Core validation logic shared by `handle_verify` and `before_request`.
563    ///
564    /// Validation chain: exists -> disabled -> expired -> permissions ->
565    /// remaining/refill -> rate limit.
566    ///
567    /// Returns `Ok(ApiKeyView)` on success, or `Err(ApiKeyValidationError)` with
568    /// a structured error code (no fragile string matching needed).
569    async fn validate_api_key<DB: DatabaseAdapter>(
570        &self,
571        ctx: &AuthContext<DB>,
572        raw_key: &str,
573        required_permissions: Option<&serde_json::Value>,
574    ) -> Result<ApiKeyView, ApiKeyValidationError> {
575        // Hash the key (or use as-is if hashing is disabled)
576        let hashed = if self.config.disable_key_hashing {
577            raw_key.to_string()
578        } else {
579            Self::hash_key(raw_key)
580        };
581
582        // Look up by hash
583        let api_key = ctx
584            .database
585            .get_api_key_by_hash(&hashed)
586            .await
587            .map_err(|_| ApiKeyValidationError::new(ApiKeyErrorCode::InvalidApiKey))?
588            .ok_or_else(|| ApiKeyValidationError::new(ApiKeyErrorCode::InvalidApiKey))?;
589
590        // 1. Disabled?
591        if !api_key.enabled() {
592            return Err(ApiKeyValidationError::new(ApiKeyErrorCode::KeyDisabled));
593        }
594
595        // 2. Expired?
596        if let Some(expires_at_str) = api_key.expires_at()
597            && let Ok(expires_at) = chrono::DateTime::parse_from_rfc3339(expires_at_str)
598            && chrono::Utc::now() > expires_at
599        {
600            // Delete expired key and evict its cached rate limiter
601            let _ = ctx.database.delete_api_key(api_key.id()).await;
602            self.rate_limiters
603                .lock()
604                .expect("rate_limiters mutex poisoned")
605                .remove(api_key.id());
606            return Err(ApiKeyValidationError::new(ApiKeyErrorCode::KeyExpired));
607        }
608
609        // 3. Permissions check
610        if let Some(required) = required_permissions {
611            let key_perms_str = api_key.permissions().unwrap_or("");
612            if key_perms_str.is_empty() {
613                return Err(ApiKeyValidationError::new(ApiKeyErrorCode::KeyNotFound));
614            }
615            if !check_permissions(key_perms_str, required) {
616                return Err(ApiKeyValidationError::new(ApiKeyErrorCode::KeyNotFound));
617            }
618        }
619
620        // 4. Remaining / refill
621        let mut new_remaining = api_key.remaining();
622        let mut new_last_refill_at: Option<String> =
623            api_key.last_refill_at().map(|s| s.to_string());
624
625        if let Some(0) = api_key.remaining()
626            && api_key.refill_amount().is_none()
627        {
628            // Usage exhausted, no refill configured -- delete key and evict cache
629            let _ = ctx.database.delete_api_key(api_key.id()).await;
630            self.rate_limiters
631                .lock()
632                .expect("rate_limiters mutex poisoned")
633                .remove(api_key.id());
634            return Err(ApiKeyValidationError::new(ApiKeyErrorCode::UsageExceeded));
635        }
636
637        if let Some(remaining) = api_key.remaining() {
638            let refill_interval = api_key.refill_interval();
639            let refill_amount = api_key.refill_amount();
640            let mut current_remaining = remaining;
641
642            if let (Some(interval), Some(amount)) = (refill_interval, refill_amount) {
643                let now = chrono::Utc::now();
644                let last_time_str = api_key
645                    .last_refill_at()
646                    .or_else(|| Some(api_key.created_at()));
647                if let Some(last_str) = last_time_str
648                    && let Ok(last_dt) = chrono::DateTime::parse_from_rfc3339(last_str)
649                {
650                    let elapsed_ms = (now - last_dt.with_timezone(&chrono::Utc)).num_milliseconds();
651                    if elapsed_ms > interval {
652                        current_remaining = amount;
653                        new_last_refill_at = Some(now.to_rfc3339());
654                    }
655                }
656            }
657
658            if current_remaining <= 0 {
659                return Err(ApiKeyValidationError::new(ApiKeyErrorCode::UsageExceeded));
660            }
661
662            new_remaining = Some(current_remaining - 1);
663        }
664
665        // 5. Rate limiting via `governor` crate
666        self.check_rate_limit_governor(&api_key)?;
667
668        // 6. Build update
669        let mut update = UpdateApiKey {
670            remaining: new_remaining,
671            ..Default::default()
672        };
673        if new_last_refill_at != api_key.last_refill_at().map(|s| s.to_string()) {
674            update.last_refill_at = Some(new_last_refill_at);
675        }
676
677        let updated = ctx
678            .database
679            .update_api_key(api_key.id(), update)
680            .await
681            .map_err(|_| ApiKeyValidationError::new(ApiKeyErrorCode::FailedToUpdateApiKey))?;
682
683        // Throttled cleanup
684        self.maybe_delete_expired(ctx).await;
685
686        Ok(ApiKeyView::from_entity(&updated))
687    }
688
689    /// Check rate limiting for an API key using the `governor` crate.
690    ///
691    /// Creates or retrieves a per-key in-memory rate limiter backed by GCRA
692    /// (Generic Cell Rate Algorithm), which is thread-safe and lock-free on
693    /// the hot path.
694    fn check_rate_limit_governor(
695        &self,
696        api_key: &impl AuthApiKey,
697    ) -> Result<(), ApiKeyValidationError> {
698        // Determine if rate limiting is enabled for this key.
699        let key_has_explicit_setting =
700            api_key.rate_limit_time_window().is_some() || api_key.rate_limit_max().is_some();
701        let key_enabled = api_key.rate_limit_enabled();
702
703        if !key_enabled {
704            // Key explicitly disabled rate limiting -- skip.
705            if key_has_explicit_setting {
706                return Ok(());
707            }
708            // Key has no explicit setting and global is also off -- skip.
709            if !self.config.rate_limit.enabled {
710                return Ok(());
711            }
712        }
713
714        let time_window_ms = api_key
715            .rate_limit_time_window()
716            .unwrap_or(self.config.rate_limit.time_window);
717        let max_requests = api_key
718            .rate_limit_max()
719            .unwrap_or(self.config.rate_limit.max_requests);
720
721        if time_window_ms <= 0 || max_requests <= 0 {
722            return Ok(());
723        }
724
725        let key_id = api_key.id().to_string();
726
727        // Get or create the rate limiter for this key
728        let limiter = {
729            let mut limiters = self
730                .rate_limiters
731                .lock()
732                .expect("rate_limiters mutex poisoned");
733            limiters
734                .entry(key_id)
735                .or_insert_with(|| {
736                    let max = NonZeroU32::new(max_requests as u32).unwrap_or(NonZeroU32::MIN);
737                    let period_ms = (time_window_ms as u64)
738                        .checked_div(max_requests as u64)
739                        .unwrap_or(0);
740                    // Guard against zero-period panic (e.g. time_window_ms < max_requests)
741                    let period = std::time::Duration::from_millis(period_ms.max(1));
742                    let quota = Quota::with_period(period)
743                        .expect("period >= 1ms is always valid")
744                        .allow_burst(max);
745                    std::sync::Arc::new(RateLimiter::direct(quota))
746                })
747                .clone()
748        };
749
750        match limiter.check() {
751            Ok(()) => Ok(()),
752            Err(_not_until) => Err(ApiKeyValidationError::new(ApiKeyErrorCode::RateLimited)),
753        }
754    }
755
756    // -----------------------------------------------------------------------
757    // POST /api-key/delete-all-expired-api-keys
758    // -----------------------------------------------------------------------
759
760    async fn handle_delete_all_expired<DB: DatabaseAdapter>(
761        &self,
762        req: &AuthRequest,
763        ctx: &AuthContext<DB>,
764    ) -> AuthResult<AuthResponse> {
765        let (user, _session) = ctx.require_session(req).await?;
766        let response = delete_all_expired_core(user.id(), self, ctx).await?;
767        Ok(AuthResponse::json(200, &response)?)
768    }
769}
770
771// ---------------------------------------------------------------------------
772// AuthPlugin trait implementation
773// ---------------------------------------------------------------------------
774
775better_auth_core::impl_auth_plugin! {
776    ApiKeyPlugin, "api-key";
777    routes {
778        post "/api-key/create"                    => handle_create,             "api_key_create";
779        get  "/api-key/get"                       => handle_get,                "api_key_get";
780        post "/api-key/update"                    => handle_update,             "api_key_update";
781        post "/api-key/delete"                    => handle_delete,             "api_key_delete";
782        get  "/api-key/list"                      => handle_list,               "api_key_list";
783        post "/api-key/verify"                    => handle_verify,             "api_key_verify";
784        post "/api-key/delete-all-expired-api-keys" => handle_delete_all_expired, "api_key_delete_all_expired";
785    }
786    extra {
787        async fn before_request(
788            &self,
789            req: &AuthRequest,
790            ctx: &AuthContext<DB>,
791        ) -> AuthResult<Option<BeforeRequestAction>> {
792            if !self.config.enable_session_for_api_keys {
793                return Ok(None);
794            }
795
796            // Check for API key in the configured header
797            let raw_key = match req.headers.get(&self.config.api_key_header) {
798                Some(k) if !k.is_empty() => k.clone(),
799                _ => return Ok(None),
800            };
801
802            // Skip session emulation for API-key management routes to avoid
803            // double-validating the key (before_request + handle_verify both
804            // call validate_api_key, consuming usage/rate-limit budget twice).
805            if req.path().starts_with("/api-key/") {
806                return Ok(None);
807            }
808
809            // Validate the key (reuses the full verify logic)
810            let view = self
811                .validate_api_key(ctx, &raw_key, None)
812                .await
813                .map_err(|e| AuthError::bad_request(e.message))?;
814
815            // Look up the user
816            let user = ctx
817                .database
818                .get_user_by_id(&view.user_id)
819                .await?
820                .ok_or_else(|| api_key_error(ApiKeyErrorCode::InvalidUserIdFromApiKey))?;
821
822            // Build a virtual session response for `/get-session`
823            if req.path() == "/get-session" {
824                let session_json = serde_json::json!({
825                    "user": {
826                        "id": user.id(),
827                        "email": user.email(),
828                        "name": user.name(),
829                    },
830                    "session": {
831                        "id": view.id,
832                        "token": raw_key,
833                        "userId": view.user_id,
834                    }
835                });
836                return Ok(Some(BeforeRequestAction::Respond(AuthResponse::json(
837                    200,
838                    &session_json,
839                )?)));
840            }
841
842            // For all other routes, inject the session
843            Ok(Some(BeforeRequestAction::InjectSession {
844                user_id: view.user_id,
845                session_token: raw_key,
846            }))
847        }
848    }
849}
850
851// ---------------------------------------------------------------------------
852// Axum integration
853// ---------------------------------------------------------------------------
854
#[cfg(feature = "axum")]
mod axum_impl {
    //! Native axum handlers for the API-key endpoints.
    //!
    //! Every handler turns the shared [`AuthState`] into an `AuthContext` and
    //! delegates to the matching framework-agnostic `*_core` function, wrapping
    //! the result in [`Json`].

    use super::*;
    use std::sync::Arc;

    use axum::Json;
    use axum::extract::{Extension, Query, State};
    use better_auth_core::{AuthState, CurrentSession, ValidatedJson};
    use serde::Deserialize;

    /// Query string accepted by `GET /api-key/get` (`?id=...`).
    #[derive(Debug, Deserialize)]
    struct GetKeyQuery {
        id: String,
    }

    /// `POST /api-key/create`: delegates to `create_key_core` with the session user's id.
    async fn handle_create<DB: DatabaseAdapter>(
        State(auth_state): State<AuthState<DB>>,
        Extension(api_plugin): Extension<Arc<ApiKeyPlugin>>,
        CurrentSession { user: owner, .. }: CurrentSession<DB>,
        ValidatedJson(request): ValidatedJson<CreateKeyRequest>,
    ) -> Result<Json<CreateKeyResponse>, AuthError> {
        let auth_ctx = auth_state.to_context();
        let created = create_key_core(&request, owner.id(), &api_plugin, &auth_ctx).await?;
        Ok(Json(created))
    }

    /// `GET /api-key/get?id=...`: delegates to `get_key_core` with the session user's id.
    async fn handle_get<DB: DatabaseAdapter>(
        State(auth_state): State<AuthState<DB>>,
        Extension(api_plugin): Extension<Arc<ApiKeyPlugin>>,
        CurrentSession { user: owner, .. }: CurrentSession<DB>,
        Query(params): Query<GetKeyQuery>,
    ) -> Result<Json<ApiKeyView>, AuthError> {
        let auth_ctx = auth_state.to_context();
        let found = get_key_core(&params.id, owner.id(), &api_plugin, &auth_ctx).await?;
        Ok(Json(found))
    }

    /// `GET /api-key/list`: delegates to `list_keys_core` with the session user's id.
    async fn handle_list<DB: DatabaseAdapter>(
        State(auth_state): State<AuthState<DB>>,
        Extension(api_plugin): Extension<Arc<ApiKeyPlugin>>,
        CurrentSession { user: owner, .. }: CurrentSession<DB>,
    ) -> Result<Json<Vec<ApiKeyView>>, AuthError> {
        let auth_ctx = auth_state.to_context();
        let keys = list_keys_core(owner.id(), &api_plugin, &auth_ctx).await?;
        Ok(Json(keys))
    }

    /// `POST /api-key/update`: delegates to `update_key_core` with the session user's id.
    async fn handle_update<DB: DatabaseAdapter>(
        State(auth_state): State<AuthState<DB>>,
        Extension(api_plugin): Extension<Arc<ApiKeyPlugin>>,
        CurrentSession { user: owner, .. }: CurrentSession<DB>,
        ValidatedJson(request): ValidatedJson<UpdateKeyRequest>,
    ) -> Result<Json<ApiKeyView>, AuthError> {
        let auth_ctx = auth_state.to_context();
        let updated = update_key_core(&request, owner.id(), &api_plugin, &auth_ctx).await?;
        Ok(Json(updated))
    }

    /// `POST /api-key/delete`: delegates to `delete_key_core` with the session user's id.
    async fn handle_delete<DB: DatabaseAdapter>(
        State(auth_state): State<AuthState<DB>>,
        Extension(api_plugin): Extension<Arc<ApiKeyPlugin>>,
        CurrentSession { user: owner, .. }: CurrentSession<DB>,
        ValidatedJson(request): ValidatedJson<DeleteKeyRequest>,
    ) -> Result<Json<serde_json::Value>, AuthError> {
        let auth_ctx = auth_state.to_context();
        let outcome = delete_key_core(&request, owner.id(), &api_plugin, &auth_ctx).await?;
        Ok(Json(outcome))
    }

    /// `POST /api-key/verify`: delegates to `verify_key_core`.
    ///
    /// Unlike the other handlers, this one does not extract a `CurrentSession`.
    async fn handle_verify<DB: DatabaseAdapter>(
        State(auth_state): State<AuthState<DB>>,
        Extension(api_plugin): Extension<Arc<ApiKeyPlugin>>,
        Json(request): Json<VerifyKeyRequest>,
    ) -> Result<Json<VerifyKeyResponse>, AuthError> {
        let auth_ctx = auth_state.to_context();
        let verdict = verify_key_core(&request, &api_plugin, &auth_ctx).await?;
        Ok(Json(verdict))
    }

    /// `POST /api-key/delete-all-expired-api-keys`: delegates to
    /// `delete_all_expired_core` with the session user's id.
    async fn handle_delete_all_expired<DB: DatabaseAdapter>(
        State(auth_state): State<AuthState<DB>>,
        Extension(api_plugin): Extension<Arc<ApiKeyPlugin>>,
        CurrentSession { user: owner, .. }: CurrentSession<DB>,
    ) -> Result<Json<serde_json::Value>, AuthError> {
        let auth_ctx = auth_state.to_context();
        let outcome = delete_all_expired_core(owner.id(), &api_plugin, &auth_ctx).await?;
        Ok(Json(outcome))
    }

    impl<DB: DatabaseAdapter> better_auth_core::AxumPlugin<DB> for ApiKeyPlugin {
        fn name(&self) -> &'static str {
            "api-key"
        }

        /// Builds the axum router for all API-key endpoints.
        ///
        /// The plugin instance the handlers use is attached as an `Extension` layer.
        fn router(&self) -> axum::Router<AuthState<DB>> {
            use axum::routing::{get, post};

            // NOTE(review): this builds a *new* `ApiKeyPlugin` from a clone of the
            // config rather than sharing `self`. If the plugin keeps in-memory
            // state (e.g. rate limiters), the axum handlers will not share it with
            // `self` — confirm this is intended.
            let shared_plugin = Arc::new(ApiKeyPlugin::with_config(self.config.clone()));

            let routes = axum::Router::new()
                .route("/api-key/create", post(handle_create::<DB>))
                .route("/api-key/get", get(handle_get::<DB>))
                .route("/api-key/update", post(handle_update::<DB>))
                .route("/api-key/delete", post(handle_delete::<DB>))
                .route("/api-key/list", get(handle_list::<DB>))
                .route("/api-key/verify", post(handle_verify::<DB>))
                .route(
                    "/api-key/delete-all-expired-api-keys",
                    post(handle_delete_all_expired::<DB>),
                );

            routes.layer(Extension(shared_plugin))
        }
    }
}