Skip to main content

better_auth_api/plugins/
api_key.rs

1use async_trait::async_trait;
2use base64::Engine;
3use base64::engine::general_purpose::URL_SAFE_NO_PAD;
4use rand::RngCore;
5use serde::{Deserialize, Serialize};
6use sha2::{Digest, Sha256};
7use validator::Validate;
8
9use better_auth_core::adapters::DatabaseAdapter;
10use better_auth_core::entity::{AuthApiKey, AuthSession, AuthUser};
11use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
12use better_auth_core::{AuthError, AuthResult};
13use better_auth_core::{AuthRequest, AuthResponse, CreateApiKey, HttpMethod, UpdateApiKey};
14
15/// API Key management plugin.
16pub struct ApiKeyPlugin {
17    config: ApiKeyConfig,
18}
19
/// Tunable settings for the API key plugin.
#[derive(Debug, Clone)]
pub struct ApiKeyConfig {
    /// Number of random bytes drawn for each key before it is encoded.
    pub key_length: usize,
    /// Prefix prepended to generated keys when the request supplies none.
    pub prefix: Option<String>,
    /// `remaining` value applied when a create request omits the field.
    pub default_remaining: Option<i64>,
}

impl Default for ApiKeyConfig {
    /// 32 random bytes per key, no prefix and no default `remaining` value.
    fn default() -> Self {
        let (prefix, default_remaining) = (None, None);
        Self {
            key_length: 32,
            prefix,
            default_remaining,
        }
    }
}
36
37// -- Request types --
38
39#[derive(Debug, Deserialize, Validate)]
40struct CreateKeyRequest {
41    name: Option<String>,
42    prefix: Option<String>,
43    #[serde(rename = "expiresIn")]
44    #[validate(range(min = 1, message = "expiresIn must be greater than 0"))]
45    expires_in: Option<i64>,
46    remaining: Option<i64>,
47    #[serde(rename = "rateLimitEnabled")]
48    rate_limit_enabled: Option<bool>,
49    #[serde(rename = "rateLimitTimeWindow")]
50    rate_limit_time_window: Option<i64>,
51    #[serde(rename = "rateLimitMax")]
52    rate_limit_max: Option<i64>,
53    #[serde(rename = "refillInterval")]
54    refill_interval: Option<i64>,
55    #[serde(rename = "refillAmount")]
56    refill_amount: Option<i64>,
57    permissions: Option<serde_json::Value>,
58    metadata: Option<serde_json::Value>,
59}
60
61#[derive(Debug, Deserialize, Validate)]
62struct UpdateKeyRequest {
63    #[validate(length(min = 1, message = "Key ID is required"))]
64    id: String,
65    name: Option<String>,
66    enabled: Option<bool>,
67    remaining: Option<i64>,
68    #[serde(rename = "rateLimitEnabled")]
69    rate_limit_enabled: Option<bool>,
70    #[serde(rename = "rateLimitTimeWindow")]
71    rate_limit_time_window: Option<i64>,
72    #[serde(rename = "rateLimitMax")]
73    rate_limit_max: Option<i64>,
74    #[serde(rename = "refillInterval")]
75    refill_interval: Option<i64>,
76    #[serde(rename = "refillAmount")]
77    refill_amount: Option<i64>,
78    permissions: Option<serde_json::Value>,
79    metadata: Option<serde_json::Value>,
80}
81
82#[derive(Debug, Deserialize, Validate)]
83struct DeleteKeyRequest {
84    #[validate(length(min = 1, message = "Key ID is required"))]
85    id: String,
86}
87
88// -- Response types --
89
90#[derive(Debug, Serialize)]
91struct ApiKeyView {
92    id: String,
93    name: Option<String>,
94    start: Option<String>,
95    prefix: Option<String>,
96    #[serde(rename = "userId")]
97    user_id: String,
98    #[serde(rename = "refillInterval")]
99    refill_interval: Option<i64>,
100    #[serde(rename = "refillAmount")]
101    refill_amount: Option<i64>,
102    #[serde(rename = "lastRefillAt")]
103    last_refill_at: Option<String>,
104    enabled: bool,
105    #[serde(rename = "rateLimitEnabled")]
106    rate_limit_enabled: bool,
107    #[serde(rename = "rateLimitTimeWindow")]
108    rate_limit_time_window: Option<i64>,
109    #[serde(rename = "rateLimitMax")]
110    rate_limit_max: Option<i64>,
111    #[serde(rename = "requestCount")]
112    request_count: Option<i64>,
113    remaining: Option<i64>,
114    #[serde(rename = "lastRequest")]
115    last_request: Option<String>,
116    #[serde(rename = "expiresAt")]
117    expires_at: Option<String>,
118    #[serde(rename = "createdAt")]
119    created_at: String,
120    #[serde(rename = "updatedAt")]
121    updated_at: String,
122    permissions: Option<serde_json::Value>,
123    metadata: Option<serde_json::Value>,
124}
125
126#[derive(Debug, Serialize)]
127struct CreateKeyResponse {
128    key: String,
129    #[serde(flatten)]
130    api_key: ApiKeyView,
131}
132
133impl ApiKeyView {
134    fn from_entity(ak: &impl AuthApiKey) -> Self {
135        Self {
136            id: ak.id().to_string(),
137            name: ak.name().map(|s| s.to_string()),
138            start: ak.start().map(|s| s.to_string()),
139            prefix: ak.prefix().map(|s| s.to_string()),
140            user_id: ak.user_id().to_string(),
141            refill_interval: ak.refill_interval(),
142            refill_amount: ak.refill_amount(),
143            last_refill_at: ak.last_refill_at().map(|s| s.to_string()),
144            enabled: ak.enabled(),
145            rate_limit_enabled: ak.rate_limit_enabled(),
146            rate_limit_time_window: ak.rate_limit_time_window(),
147            rate_limit_max: ak.rate_limit_max(),
148            request_count: ak.request_count(),
149            remaining: ak.remaining(),
150            last_request: ak.last_request().map(|s| s.to_string()),
151            expires_at: ak.expires_at().map(|s| s.to_string()),
152            created_at: ak.created_at().to_string(),
153            updated_at: ak.updated_at().to_string(),
154            permissions: ak.permissions().and_then(|s| serde_json::from_str(s).ok()),
155            metadata: ak.metadata().and_then(|s| serde_json::from_str(s).ok()),
156        }
157    }
158}
159
160impl ApiKeyPlugin {
161    #[allow(clippy::new_without_default)]
162    pub fn new() -> Self {
163        Self {
164            config: ApiKeyConfig::default(),
165        }
166    }
167
168    pub fn with_config(config: ApiKeyConfig) -> Self {
169        Self { config }
170    }
171
172    pub fn key_length(mut self, length: usize) -> Self {
173        self.config.key_length = length;
174        self
175    }
176
177    pub fn prefix(mut self, prefix: impl Into<String>) -> Self {
178        self.config.prefix = Some(prefix.into());
179        self
180    }
181
182    pub fn default_remaining(mut self, remaining: i64) -> Self {
183        self.config.default_remaining = Some(remaining);
184        self
185    }
186
187    fn generate_key(&self, custom_prefix: Option<&str>) -> (String, String, String) {
188        let mut bytes = vec![0u8; self.config.key_length];
189        rand::rngs::OsRng.fill_bytes(&mut bytes);
190        let raw = URL_SAFE_NO_PAD.encode(&bytes);
191
192        let start = raw.chars().take(4).collect::<String>();
193
194        let prefix = custom_prefix
195            .or(self.config.prefix.as_deref())
196            .unwrap_or("");
197        let full_key = format!("{}{}", prefix, raw);
198
199        let mut hasher = Sha256::new();
200        hasher.update(full_key.as_bytes());
201        let hash = format!("{:x}", hasher.finalize());
202
203        (full_key, hash, start)
204    }
205
206    async fn get_authenticated_user<DB: DatabaseAdapter>(
207        req: &AuthRequest,
208        ctx: &AuthContext<DB>,
209    ) -> AuthResult<(DB::User, DB::Session)> {
210        let token = req
211            .headers
212            .get("authorization")
213            .and_then(|v| v.strip_prefix("Bearer "))
214            .ok_or(AuthError::Unauthenticated)?;
215
216        let session = ctx
217            .database
218            .get_session(token)
219            .await?
220            .ok_or(AuthError::Unauthenticated)?;
221
222        if session.expires_at() < chrono::Utc::now() {
223            return Err(AuthError::Unauthenticated);
224        }
225
226        let user = ctx
227            .database
228            .get_user_by_id(session.user_id())
229            .await?
230            .ok_or(AuthError::UserNotFound)?;
231
232        Ok((user, session))
233    }
234
235    async fn handle_create<DB: DatabaseAdapter>(
236        &self,
237        req: &AuthRequest,
238        ctx: &AuthContext<DB>,
239    ) -> AuthResult<AuthResponse> {
240        let (user, _session) = Self::get_authenticated_user(req, ctx).await?;
241
242        let create_req: CreateKeyRequest = match better_auth_core::validate_request_body(req) {
243            Ok(v) => v,
244            Err(resp) => return Ok(resp),
245        };
246
247        let (full_key, hash, start) = self.generate_key(create_req.prefix.as_deref());
248
249        let expires_at = if let Some(ms) = create_req.expires_in {
250            let duration = chrono::Duration::try_milliseconds(ms)
251                .ok_or_else(|| AuthError::bad_request("expiresIn is out of range"))?;
252            let dt = chrono::Utc::now()
253                .checked_add_signed(duration)
254                .ok_or_else(|| AuthError::bad_request("expiresIn is out of range"))?;
255            Some(dt.to_rfc3339())
256        } else {
257            None
258        };
259
260        let remaining = create_req.remaining.or(self.config.default_remaining);
261
262        let input = CreateApiKey {
263            user_id: user.id().to_string(),
264            name: create_req.name,
265            prefix: create_req.prefix.or_else(|| self.config.prefix.clone()),
266            key_hash: hash,
267            start: Some(start),
268            expires_at,
269            remaining,
270            rate_limit_enabled: create_req.rate_limit_enabled.unwrap_or(false),
271            rate_limit_time_window: create_req.rate_limit_time_window,
272            rate_limit_max: create_req.rate_limit_max,
273            refill_interval: create_req.refill_interval,
274            refill_amount: create_req.refill_amount,
275            permissions: create_req
276                .permissions
277                .map(|v| serde_json::to_string(&v).unwrap_or_default()),
278            metadata: create_req
279                .metadata
280                .map(|v| serde_json::to_string(&v).unwrap_or_default()),
281            enabled: true,
282        };
283
284        let api_key = ctx.database.create_api_key(input).await?;
285
286        let response = CreateKeyResponse {
287            key: full_key,
288            api_key: ApiKeyView::from_entity(&api_key),
289        };
290
291        Ok(AuthResponse::json(200, &response)?)
292    }
293
294    async fn handle_get<DB: DatabaseAdapter>(
295        &self,
296        req: &AuthRequest,
297        ctx: &AuthContext<DB>,
298    ) -> AuthResult<AuthResponse> {
299        let (user, _session) = Self::get_authenticated_user(req, ctx).await?;
300
301        let id = req
302            .query
303            .get("id")
304            .ok_or_else(|| AuthError::bad_request("Query parameter 'id' is required"))?;
305
306        let api_key = ctx
307            .database
308            .get_api_key_by_id(id)
309            .await?
310            .ok_or_else(|| AuthError::not_found("API key not found"))?;
311
312        if api_key.user_id() != user.id() {
313            return Err(AuthError::not_found("API key not found"));
314        }
315
316        Ok(AuthResponse::json(200, &ApiKeyView::from_entity(&api_key))?)
317    }
318
319    async fn handle_list<DB: DatabaseAdapter>(
320        &self,
321        req: &AuthRequest,
322        ctx: &AuthContext<DB>,
323    ) -> AuthResult<AuthResponse> {
324        let (user, _session) = Self::get_authenticated_user(req, ctx).await?;
325
326        let keys = ctx.database.list_api_keys_by_user(user.id()).await?;
327
328        let views: Vec<ApiKeyView> = keys.iter().map(ApiKeyView::from_entity).collect();
329
330        Ok(AuthResponse::json(200, &views)?)
331    }
332
333    async fn handle_update<DB: DatabaseAdapter>(
334        &self,
335        req: &AuthRequest,
336        ctx: &AuthContext<DB>,
337    ) -> AuthResult<AuthResponse> {
338        let (user, _session) = Self::get_authenticated_user(req, ctx).await?;
339
340        let update_req: UpdateKeyRequest = match better_auth_core::validate_request_body(req) {
341            Ok(v) => v,
342            Err(resp) => return Ok(resp),
343        };
344
345        // Ownership check
346        let existing = ctx
347            .database
348            .get_api_key_by_id(&update_req.id)
349            .await?
350            .ok_or_else(|| AuthError::not_found("API key not found"))?;
351
352        if existing.user_id() != user.id() {
353            return Err(AuthError::not_found("API key not found"));
354        }
355
356        let update = UpdateApiKey {
357            name: update_req.name,
358            enabled: update_req.enabled,
359            remaining: update_req.remaining,
360            rate_limit_enabled: update_req.rate_limit_enabled,
361            rate_limit_time_window: update_req.rate_limit_time_window,
362            rate_limit_max: update_req.rate_limit_max,
363            refill_interval: update_req.refill_interval,
364            refill_amount: update_req.refill_amount,
365            permissions: update_req
366                .permissions
367                .map(|v| serde_json::to_string(&v).unwrap_or_default()),
368            metadata: update_req
369                .metadata
370                .map(|v| serde_json::to_string(&v).unwrap_or_default()),
371        };
372
373        let updated = ctx.database.update_api_key(&update_req.id, update).await?;
374
375        Ok(AuthResponse::json(200, &ApiKeyView::from_entity(&updated))?)
376    }
377
378    async fn handle_delete<DB: DatabaseAdapter>(
379        &self,
380        req: &AuthRequest,
381        ctx: &AuthContext<DB>,
382    ) -> AuthResult<AuthResponse> {
383        let (user, _session) = Self::get_authenticated_user(req, ctx).await?;
384
385        let delete_req: DeleteKeyRequest = match better_auth_core::validate_request_body(req) {
386            Ok(v) => v,
387            Err(resp) => return Ok(resp),
388        };
389
390        // Ownership check
391        let existing = ctx
392            .database
393            .get_api_key_by_id(&delete_req.id)
394            .await?
395            .ok_or_else(|| AuthError::not_found("API key not found"))?;
396
397        if existing.user_id() != user.id() {
398            return Err(AuthError::not_found("API key not found"));
399        }
400
401        ctx.database.delete_api_key(&delete_req.id).await?;
402
403        Ok(AuthResponse::json(
404            200,
405            &serde_json::json!({ "status": true }),
406        )?)
407    }
408}
409
410#[async_trait]
411impl<DB: DatabaseAdapter> AuthPlugin<DB> for ApiKeyPlugin {
412    fn name(&self) -> &'static str {
413        "api-key"
414    }
415
416    fn routes(&self) -> Vec<AuthRoute> {
417        vec![
418            AuthRoute::post("/api-key/create", "api_key_create"),
419            AuthRoute::get("/api-key/get", "api_key_get"),
420            AuthRoute::post("/api-key/update", "api_key_update"),
421            AuthRoute::post("/api-key/delete", "api_key_delete"),
422            AuthRoute::get("/api-key/list", "api_key_list"),
423        ]
424    }
425
426    async fn on_request(
427        &self,
428        req: &AuthRequest,
429        ctx: &AuthContext<DB>,
430    ) -> AuthResult<Option<AuthResponse>> {
431        match (req.method(), req.path()) {
432            (HttpMethod::Post, "/api-key/create") => Ok(Some(self.handle_create(req, ctx).await?)),
433            (HttpMethod::Get, "/api-key/get") => Ok(Some(self.handle_get(req, ctx).await?)),
434            (HttpMethod::Post, "/api-key/update") => Ok(Some(self.handle_update(req, ctx).await?)),
435            (HttpMethod::Post, "/api-key/delete") => Ok(Some(self.handle_delete(req, ctx).await?)),
436            (HttpMethod::Get, "/api-key/list") => Ok(Some(self.handle_list(req, ctx).await?)),
437            _ => Ok(None),
438        }
439    }
440}
441
442#[cfg(test)]
443mod tests {
444    use super::*;
445    use better_auth_core::adapters::{ApiKeyOps, MemoryDatabaseAdapter, SessionOps, UserOps};
446    use better_auth_core::{CreateSession, CreateUser, Session, User};
447    use chrono::{Duration, Utc};
448    use std::collections::HashMap;
449    use std::sync::Arc;
450
451    async fn create_test_context_with_user() -> (AuthContext<MemoryDatabaseAdapter>, User, Session)
452    {
453        let config = Arc::new(better_auth_core::AuthConfig::new(
454            "test-secret-key-at-least-32-chars-long",
455        ));
456        let database = Arc::new(MemoryDatabaseAdapter::new());
457        let ctx = AuthContext::new(config, database.clone());
458
459        let user = database
460            .create_user(
461                CreateUser::new()
462                    .with_email("test@example.com")
463                    .with_name("Test User"),
464            )
465            .await
466            .unwrap();
467
468        let session = database
469            .create_session(CreateSession {
470                user_id: user.id.clone(),
471                expires_at: Utc::now() + Duration::hours(24),
472                ip_address: Some("127.0.0.1".to_string()),
473                user_agent: Some("test-agent".to_string()),
474                impersonated_by: None,
475                active_organization_id: None,
476            })
477            .await
478            .unwrap();
479
480        (ctx, user, session)
481    }
482
483    async fn create_user_with_session(
484        ctx: &AuthContext<MemoryDatabaseAdapter>,
485        email: &str,
486    ) -> (User, Session) {
487        let user = ctx
488            .database
489            .create_user(
490                CreateUser::new()
491                    .with_email(email.to_string())
492                    .with_name("Another User"),
493            )
494            .await
495            .unwrap();
496
497        let session = ctx
498            .database
499            .create_session(CreateSession {
500                user_id: user.id.clone(),
501                expires_at: Utc::now() + Duration::hours(24),
502                ip_address: None,
503                user_agent: None,
504                impersonated_by: None,
505                active_organization_id: None,
506            })
507            .await
508            .unwrap();
509
510        (user, session)
511    }
512
513    fn create_auth_request(
514        method: HttpMethod,
515        path: &str,
516        token: Option<&str>,
517        body: Option<serde_json::Value>,
518        query: Option<HashMap<String, String>>,
519    ) -> AuthRequest {
520        let mut headers = HashMap::new();
521        if let Some(token) = token {
522            headers.insert("authorization".to_string(), format!("Bearer {}", token));
523        }
524
525        AuthRequest {
526            method,
527            path: path.to_string(),
528            headers,
529            body: body.map(|b| serde_json::to_vec(&b).unwrap()),
530            query: query.unwrap_or_default(),
531        }
532    }
533
534    fn json_body(response: &AuthResponse) -> serde_json::Value {
535        serde_json::from_slice(&response.body).unwrap()
536    }
537
538    async fn create_key_and_get_id(
539        plugin: &ApiKeyPlugin,
540        ctx: &AuthContext<MemoryDatabaseAdapter>,
541        token: &str,
542        name: &str,
543    ) -> String {
544        let req = create_auth_request(
545            HttpMethod::Post,
546            "/api-key/create",
547            Some(token),
548            Some(serde_json::json!({ "name": name })),
549            None,
550        );
551        let response = plugin.handle_create(&req, ctx).await.unwrap();
552        assert_eq!(response.status, 200);
553        json_body(&response)["id"].as_str().unwrap().to_string()
554    }
555
556    #[tokio::test]
557    async fn test_create_and_get_do_not_expose_hash() {
558        let plugin = ApiKeyPlugin::new().prefix("ba_");
559        let (ctx, _user, session) = create_test_context_with_user().await;
560
561        let create_req = create_auth_request(
562            HttpMethod::Post,
563            "/api-key/create",
564            Some(&session.token),
565            Some(serde_json::json!({ "name": "primary" })),
566            None,
567        );
568        let create_response = plugin.handle_create(&create_req, &ctx).await.unwrap();
569        assert_eq!(create_response.status, 200);
570
571        let body = json_body(&create_response);
572        assert!(body.get("key").is_some());
573        assert!(body.get("key_hash").is_none());
574        assert!(body.get("hash").is_none());
575
576        let id = body["id"].as_str().unwrap();
577        let mut query = HashMap::new();
578        query.insert("id".to_string(), id.to_string());
579
580        let get_req = create_auth_request(
581            HttpMethod::Get,
582            "/api-key/get",
583            Some(&session.token),
584            None,
585            Some(query),
586        );
587        let get_response = plugin.handle_get(&get_req, &ctx).await.unwrap();
588        assert_eq!(get_response.status, 200);
589
590        let get_body = json_body(&get_response);
591        assert!(get_body.get("key").is_none());
592        assert!(get_body.get("key_hash").is_none());
593    }
594
595    #[tokio::test]
596    async fn test_create_rejects_invalid_expires_in() {
597        let plugin = ApiKeyPlugin::new();
598        let (ctx, _user, session) = create_test_context_with_user().await;
599
600        let req = create_auth_request(
601            HttpMethod::Post,
602            "/api-key/create",
603            Some(&session.token),
604            Some(serde_json::json!({ "expiresIn": i64::MIN })),
605            None,
606        );
607        let response = plugin.handle_create(&req, &ctx).await.unwrap();
608        assert_eq!(response.status, 422);
609    }
610
611    #[tokio::test]
612    async fn test_get_update_delete_return_404_for_non_owner() {
613        let plugin = ApiKeyPlugin::new();
614        let (ctx, _user1, session1) = create_test_context_with_user().await;
615        let (_user2, session2) = create_user_with_session(&ctx, "other@example.com").await;
616        let key_id = create_key_and_get_id(&plugin, &ctx, &session1.token, "owner-key").await;
617
618        let mut get_query = HashMap::new();
619        get_query.insert("id".to_string(), key_id.clone());
620        let get_req = create_auth_request(
621            HttpMethod::Get,
622            "/api-key/get",
623            Some(&session2.token),
624            None,
625            Some(get_query),
626        );
627        let get_err = plugin.handle_get(&get_req, &ctx).await.unwrap_err();
628        assert_eq!(get_err.status_code(), 404);
629
630        let update_req = create_auth_request(
631            HttpMethod::Post,
632            "/api-key/update",
633            Some(&session2.token),
634            Some(serde_json::json!({ "id": key_id, "name": "new-name" })),
635            None,
636        );
637        let update_err = plugin.handle_update(&update_req, &ctx).await.unwrap_err();
638        assert_eq!(update_err.status_code(), 404);
639
640        let delete_req = create_auth_request(
641            HttpMethod::Post,
642            "/api-key/delete",
643            Some(&session2.token),
644            Some(serde_json::json!({ "id": key_id })),
645            None,
646        );
647        let delete_err = plugin.handle_delete(&delete_req, &ctx).await.unwrap_err();
648        assert_eq!(delete_err.status_code(), 404);
649    }
650
651    #[tokio::test]
652    async fn test_list_returns_only_user_keys() {
653        let plugin = ApiKeyPlugin::new();
654        let (ctx, user1, session1) = create_test_context_with_user().await;
655        let (_user2, session2) = create_user_with_session(&ctx, "other@example.com").await;
656
657        let _ = create_key_and_get_id(&plugin, &ctx, &session1.token, "u1-key").await;
658        let _ = create_key_and_get_id(&plugin, &ctx, &session2.token, "u2-key").await;
659
660        let list_req = create_auth_request(
661            HttpMethod::Get,
662            "/api-key/list",
663            Some(&session1.token),
664            None,
665            None,
666        );
667        let list_response = plugin.handle_list(&list_req, &ctx).await.unwrap();
668        assert_eq!(list_response.status, 200);
669
670        let list_body = json_body(&list_response);
671        let list = list_body.as_array().unwrap();
672        assert_eq!(list.len(), 1);
673        assert_eq!(list[0]["userId"].as_str().unwrap(), user1.id);
674        assert!(list[0].get("key").is_none());
675        assert!(list[0].get("key_hash").is_none());
676    }
677
678    #[tokio::test]
679    async fn test_owner_can_delete_key() {
680        let plugin = ApiKeyPlugin::new();
681        let (ctx, _user, session) = create_test_context_with_user().await;
682        let key_id = create_key_and_get_id(&plugin, &ctx, &session.token, "to-delete").await;
683
684        let delete_req = create_auth_request(
685            HttpMethod::Post,
686            "/api-key/delete",
687            Some(&session.token),
688            Some(serde_json::json!({ "id": key_id })),
689            None,
690        );
691        let delete_response = plugin.handle_delete(&delete_req, &ctx).await.unwrap();
692        assert_eq!(delete_response.status, 200);
693
694        let deleted = ctx.database.get_api_key_by_id(&key_id).await.unwrap();
695        assert!(deleted.is_none());
696    }
697}