Skip to main content

better_auth_api/plugins/
device_authorization.rs

1use chrono::{DateTime, Utc};
2use rand::{Rng, distributions::Alphanumeric};
3use serde::{Deserialize, Serialize};
4use validator::Validate;
5
6use better_auth_core::adapters::DatabaseAdapter;
7use better_auth_core::entity::{AuthUser, AuthVerification};
8use better_auth_core::{
9    AuthContext, AuthError, AuthRequest, AuthResponse, AuthResult, AuthSession, CreateVerification,
10};
11
12use super::StatusResponse;
13
/// Configuration for `DeviceAuthorizationPlugin`.
///
/// Default values are given by the `#[config(default = ...)]` attribute on each field.
#[derive(Debug, Clone, better_auth_core::PluginConfig)]
#[plugin(name = "DeviceAuthorizationPlugin")]
pub struct DeviceAuthorizationConfig {
    /// Whether the device routes are active. While `false` (the default), every
    /// route handler responds with a "Not found" error.
    #[config(default = false)]
    pub enabled: bool,

    /// Verification URI returned to the device in the `/device/code` response.
    #[config(default = "/device".to_string())]
    pub verification_uri: String,

    /// Minimum number of seconds between token polls for a pending authorization.
    /// It is copied into each record at creation; faster polls get `slow_down`.
    #[config(default = 5)]
    pub interval: i64,

    /// Lifetime, in seconds from creation, of a device-code/user-code pair.
    #[config(default = 1800)]
    pub expires_in: i64,
}
29
/// Plugin implementing an RFC 8628-style device authorization flow.
///
/// Pairs of device codes and user codes are persisted as verification rows.
pub struct DeviceAuthorizationPlugin {
    /// Plugin settings; every route handler reads `config.enabled`.
    config: DeviceAuthorizationConfig,
}

// Generates the plugin glue via `impl_auth_plugin!` with the identifier "device-flow",
// mapping each POST path to a handler method on `DeviceAuthorizationPlugin`.
// NOTE(review): the trailing string on each route (e.g. "device_code") is interpreted
// by the macro — presumably an operation/route id; confirm against `impl_auth_plugin!`.
better_auth_core::impl_auth_plugin! {
    DeviceAuthorizationPlugin, "device-flow";
    routes {
        post "/device/code" => handle_code, "device_code";
        post "/device/token" => handle_token, "device_token";
        post "/device/approve" => handle_approve, "device_approve";
        post "/device/deny" => handle_deny, "device_deny";
    }
}
43
/// State of one device authorization.
///
/// It is persisted as JSON (inside a `StoredVerification`) in verification rows keyed
/// by both the device code and the user code.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct DeviceAuthorizationRecord {
    /// Long random code held by the device and presented when polling `/device/token`.
    device_code: String,
    /// Short code the user enters to approve or deny the request.
    user_code: String,
    /// Client that requested the authorization.
    client_id: String,
    /// Requested scope, if any.
    // NOTE(review): stored but not otherwise consulted in this file.
    scope: Option<String>,
    /// Current decision on the request.
    status: DeviceStatus,
    /// Id of the approving user; filled in on approval.
    user_id: Option<String>,
    /// Expiry instant; also used as the verification row's `expires_at`.
    expires_at: DateTime<Utc>,
    /// When the device last polled while pending; used to enforce `interval`.
    last_polled_at: Option<DateTime<Utc>>,
    /// Minimum seconds between polls, copied from the plugin config at creation.
    interval: i64,
}
56
57#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
58#[serde(rename_all = "lowercase")]
59enum DeviceStatus {
60    Pending,
61    Approved,
62    Denied,
63}
64
/// JSON envelope stored as a verification row's value.
///
/// It holds the user code alongside the full authorization record.
#[derive(Debug, Serialize, Deserialize)]
struct StoredVerification {
    /// The user code of the authorization (set from `user_code` at creation).
    code: String,
    /// The authorization record itself.
    data: DeviceAuthorizationRecord,
}
70
71#[derive(Debug, Deserialize, Validate)]
72struct DeviceCodeRequest {
73    #[validate(length(min = 1))]
74    client_id: String,
75    scope: Option<String>,
76}
77
78#[derive(Debug, Deserialize, Validate)]
79struct DeviceTokenRequest {
80    device_code: String,
81}
82
83#[derive(Debug, Deserialize, Validate)]
84struct ApproveRequest {
85    #[serde(rename = "userCode")]
86    user_code: String,
87}
88
/// Body returned by `POST /device/code`.
#[derive(Debug, Serialize)]
struct DeviceCodeResponse {
    /// Code the device presents when polling `POST /device/token`.
    device_code: String,
    /// Code the user enters to approve or deny the request.
    user_code: String,
    /// Verification URI from the plugin config.
    verification_uri: String,
    /// Seconds until the codes expire (from the plugin config).
    expires_in: i64,
    /// Minimum seconds the device should wait between polls (from the plugin config).
    interval: i64,
}

/// Body returned by `POST /device/token` once the authorization is approved.
#[derive(Debug, Serialize)]
struct DeviceTokenResponse {
    /// Token of the session created for the approving user.
    // NOTE(review): RFC 6749 §5.1 (referenced by RFC 8628 §3.5) makes `token_type`
    // a required member of token responses; consider adding it.
    access_token: String,
}
102
/// Builds the verification identifier under which an authorization is stored,
/// keyed by its device code.
fn device_code_identifier(device_code: &str) -> String {
    const PREFIX: &str = "device_code:";
    let mut key = String::with_capacity(PREFIX.len() + device_code.len());
    key.push_str(PREFIX);
    key.push_str(device_code);
    key
}
106
/// Builds the verification identifier under which an authorization is stored,
/// keyed by its user code.
///
/// The code is normalized first: every character that is not an ASCII letter or
/// digit is dropped, and letters are upper-cased. Users can therefore type the code
/// in any case and with separators or stray whitespace (`"abcd-1234"` matches
/// `"ABCD1234"`), as RFC 8628 §6.1 recommends for user-entered codes.
///
/// Generated user codes already consist only of upper-case ASCII alphanumerics, so
/// normalization leaves their stored identifiers unchanged.
fn user_code_identifier(user_code: &str) -> String {
    let normalized: String = user_code
        .chars()
        .filter(char::is_ascii_alphanumeric)
        .map(|c| c.to_ascii_uppercase())
        .collect();
    format!("user_code:{normalized}")
}
110
111async fn store_device_authorization<DB: DatabaseAdapter>(
112    identifier: String,
113    stored: &StoredVerification,
114    ctx: &AuthContext<DB>,
115) -> AuthResult<()> {
116    ctx.database
117        .create_verification(CreateVerification {
118            identifier,
119            value: serde_json::to_string(stored)?,
120            expires_at: stored.data.expires_at,
121        })
122        .await?;
123
124    Ok(())
125}
126
127// ---------------------------------------------------------------------------
128// Core functions — framework-agnostic business logic
129// ---------------------------------------------------------------------------
130
131async fn create_device_code_core<DB: DatabaseAdapter>(
132    client_id: String,
133    scope: Option<String>,
134    ctx: &AuthContext<DB>,
135    config: &DeviceAuthorizationConfig,
136) -> AuthResult<DeviceCodeResponse> {
137    let device_code: String = rand::thread_rng()
138        .sample_iter(&Alphanumeric)
139        .take(40)
140        .map(char::from)
141        .collect();
142
143    let user_code: String = rand::thread_rng()
144        .sample_iter(&Alphanumeric)
145        .take(8)
146        .map(char::from)
147        .collect::<String>()
148        .to_uppercase();
149
150    let expires_at = Utc::now() + chrono::Duration::seconds(config.expires_in);
151    let stored = StoredVerification {
152        code: user_code.clone(),
153        data: DeviceAuthorizationRecord {
154            device_code: device_code.clone(),
155            user_code: user_code.clone(),
156            client_id,
157            scope,
158            status: DeviceStatus::Pending,
159            user_id: None,
160            expires_at,
161            last_polled_at: None,
162            interval: config.interval,
163        },
164    };
165
166    store_device_authorization(device_code_identifier(&device_code), &stored, ctx).await?;
167    store_device_authorization(user_code_identifier(&user_code), &stored, ctx).await?;
168
169    Ok(DeviceCodeResponse {
170        device_code,
171        user_code,
172        verification_uri: config.verification_uri.clone(),
173        expires_in: config.expires_in,
174        interval: config.interval,
175    })
176}
177
178async fn poll_device_token_core<DB: DatabaseAdapter>(
179    device_code: String,
180    ctx: &AuthContext<DB>,
181) -> AuthResult<DeviceTokenResponse> {
182    let identifier = device_code_identifier(&device_code);
183    let verification = ctx
184        .database
185        .get_verification_by_identifier(&identifier)
186        .await?
187        .ok_or_else(|| AuthError::bad_request("invalid_grant"))?;
188
189    let verification_id = verification.id().to_string();
190    let raw_value = verification.value().to_string();
191    let mut stored: StoredVerification = serde_json::from_str(&raw_value)?;
192
193    match stored.data.status {
194        DeviceStatus::Pending => {
195            let now = Utc::now();
196            if let Some(last) = stored.data.last_polled_at
197                && (now - last).num_seconds() < stored.data.interval
198            {
199                return Err(AuthError::bad_request("slow_down"));
200            }
201
202            stored.data.last_polled_at = Some(now);
203            ctx.database.delete_verification(&verification_id).await?;
204            store_device_authorization(identifier, &stored, ctx).await?;
205
206            Err(AuthError::bad_request("authorization_pending"))
207        }
208        DeviceStatus::Denied => Err(AuthError::bad_request("access_denied")),
209        DeviceStatus::Approved => {
210            // Atomically consume to prevent double-redemption from concurrent polls
211            if ctx
212                .database
213                .consume_verification(&identifier, &raw_value)
214                .await?
215                .is_none()
216            {
217                return Err(AuthError::bad_request("invalid_grant"));
218            }
219
220            let user_id =
221                stored.data.user_id.clone().ok_or_else(|| {
222                    AuthError::internal("Device authorization missing approved user")
223                })?;
224
225            let user = ctx
226                .database
227                .get_user_by_id(&user_id)
228                .await?
229                .ok_or_else(|| AuthError::internal("User not found"))?;
230
231            let session = ctx
232                .session_manager()
233                .create_session(&user, None, None)
234                .await?;
235
236            Ok(DeviceTokenResponse {
237                access_token: session.token().to_string(),
238            })
239        }
240    }
241}
242
243async fn approve_device_core<DB: DatabaseAdapter>(
244    user: &DB::User,
245    user_code: &str,
246    ctx: &AuthContext<DB>,
247) -> AuthResult<StatusResponse> {
248    let verification = ctx
249        .database
250        .get_verification_by_identifier(&user_code_identifier(user_code))
251        .await?
252        .ok_or_else(|| AuthError::not_found("Invalid code"))?;
253
254    let verification_id = verification.id().to_string();
255    let mut stored: StoredVerification = serde_json::from_str(verification.value())?;
256
257    stored.data.status = DeviceStatus::Approved;
258    stored.data.user_id = Some(user.id().to_string());
259
260    let device_identifier = device_code_identifier(&stored.data.device_code);
261    if let Some(device_verification) = ctx
262        .database
263        .get_verification_by_identifier(&device_identifier)
264        .await?
265    {
266        ctx.database
267            .delete_verification(device_verification.id())
268            .await?;
269    }
270
271    ctx.database.delete_verification(&verification_id).await?;
272    store_device_authorization(device_identifier, &stored, ctx).await?;
273
274    Ok(StatusResponse { status: true })
275}
276
277async fn deny_device_core<DB: DatabaseAdapter>(
278    user_code: &str,
279    ctx: &AuthContext<DB>,
280) -> AuthResult<StatusResponse> {
281    let verification = ctx
282        .database
283        .get_verification_by_identifier(&user_code_identifier(user_code))
284        .await?
285        .ok_or_else(|| AuthError::not_found("Invalid code"))?;
286
287    let verification_id = verification.id().to_string();
288    let mut stored: StoredVerification = serde_json::from_str(verification.value())?;
289
290    stored.data.status = DeviceStatus::Denied;
291
292    let device_identifier = device_code_identifier(&stored.data.device_code);
293    if let Some(device_verification) = ctx
294        .database
295        .get_verification_by_identifier(&device_identifier)
296        .await?
297    {
298        ctx.database
299            .delete_verification(device_verification.id())
300            .await?;
301    }
302
303    ctx.database.delete_verification(&verification_id).await?;
304    store_device_authorization(device_identifier, &stored, ctx).await?;
305
306    Ok(StatusResponse { status: true })
307}
308
309// ---------------------------------------------------------------------------
// Route handlers — check the plugin is enabled, parse the request, delegate to core functions
311// ---------------------------------------------------------------------------
312
313impl DeviceAuthorizationPlugin {
314    async fn handle_code<DB: DatabaseAdapter>(
315        &self,
316        req: &AuthRequest,
317        ctx: &AuthContext<DB>,
318    ) -> AuthResult<AuthResponse> {
319        if !self.config.enabled {
320            return Err(AuthError::not_found("Not found"));
321        }
322
323        let body: DeviceCodeRequest = match better_auth_core::validate_request_body(req) {
324            Ok(v) => v,
325            Err(resp) => return Ok(resp),
326        };
327
328        let resp = create_device_code_core(body.client_id, body.scope, ctx, &self.config).await?;
329
330        Ok(AuthResponse::json(200, &resp)?)
331    }
332
333    async fn handle_token<DB: DatabaseAdapter>(
334        &self,
335        req: &AuthRequest,
336        ctx: &AuthContext<DB>,
337    ) -> AuthResult<AuthResponse> {
338        if !self.config.enabled {
339            return Err(AuthError::not_found("Not found"));
340        }
341
342        let body: DeviceTokenRequest = match better_auth_core::validate_request_body(req) {
343            Ok(v) => v,
344            Err(resp) => return Ok(resp),
345        };
346
347        let resp = poll_device_token_core(body.device_code, ctx).await?;
348
349        Ok(AuthResponse::json(200, &resp)?)
350    }
351
352    async fn handle_approve<DB: DatabaseAdapter>(
353        &self,
354        req: &AuthRequest,
355        ctx: &AuthContext<DB>,
356    ) -> AuthResult<AuthResponse> {
357        if !self.config.enabled {
358            return Err(AuthError::not_found("Not found"));
359        }
360
361        let (user, _) = ctx.require_session(req).await?;
362
363        let body: ApproveRequest = match better_auth_core::validate_request_body(req) {
364            Ok(v) => v,
365            Err(resp) => return Ok(resp),
366        };
367
368        let resp = approve_device_core(&user, &body.user_code, ctx).await?;
369
370        Ok(AuthResponse::json(200, &resp)?)
371    }
372
373    async fn handle_deny<DB: DatabaseAdapter>(
374        &self,
375        req: &AuthRequest,
376        ctx: &AuthContext<DB>,
377    ) -> AuthResult<AuthResponse> {
378        if !self.config.enabled {
379            return Err(AuthError::not_found("Not found"));
380        }
381
382        let _ = ctx.require_session(req).await?;
383
384        let body: ApproveRequest = match better_auth_core::validate_request_body(req) {
385            Ok(v) => v,
386            Err(resp) => return Ok(resp),
387        };
388
389        let resp = deny_device_core(&body.user_code, ctx).await?;
390
391        Ok(AuthResponse::json(200, &resp)?)
392    }
393}