1use chrono::{DateTime, Utc};
2use rand::{Rng, distributions::Alphanumeric};
3use serde::{Deserialize, Serialize};
4use validator::Validate;
5
6use better_auth_core::adapters::DatabaseAdapter;
7use better_auth_core::entity::{AuthUser, AuthVerification};
8use better_auth_core::{
9 AuthContext, AuthError, AuthRequest, AuthResponse, AuthResult, AuthSession, CreateVerification,
10};
11
12use super::StatusResponse;
13
/// Configuration for the device authorization flow.
#[derive(Debug, Clone, better_auth_core::PluginConfig)]
#[plugin(name = "DeviceAuthorizationPlugin")]
pub struct DeviceAuthorizationConfig {
    /// Master switch. When `false`, every device-flow handler returns a
    /// not-found error (`AuthError::not_found("Not found")`).
    #[config(default = false)]
    pub enabled: bool,

    /// Value returned to the device as `verification_uri` from
    /// `POST /device/code`; the page where the user enters the user code.
    #[config(default = "/device".to_string())]
    pub verification_uri: String,

    /// Minimum spacing between token polls, in seconds. It is copied into each
    /// record at issue time; polling sooner yields `slow_down`.
    #[config(default = 5)]
    pub interval: i64,

    /// Lifetime of an issued device/user code pair, in seconds. It is used to
    /// compute the record's `expires_at` and echoed back as `expires_in`.
    #[config(default = 1800)]
    pub expires_in: i64,
}
29
/// Plugin exposing a device authorization flow (modeled on RFC 8628):
/// a device obtains a device code and a user code, the signed-in user approves
/// or denies the user code, and the device polls until it receives a session token.
pub struct DeviceAuthorizationPlugin {
    /// Settings controlling whether the flow is enabled, polling interval and code lifetime.
    config: DeviceAuthorizationConfig,
}
33
// Registers the plugin (presumably under the name "device-flow") and maps each
// POST route to its `handle_*` method on `DeviceAuthorizationPlugin`. The
// trailing string on each route line looks like an operation/route name —
// NOTE(review): confirm against the `impl_auth_plugin!` definition.
better_auth_core::impl_auth_plugin! {
    DeviceAuthorizationPlugin, "device-flow";
    routes {
        post "/device/code" => handle_code, "device_code";
        post "/device/token" => handle_token, "device_token";
        post "/device/approve" => handle_approve, "device_approve";
        post "/device/deny" => handle_deny, "device_deny";
    }
}
43
/// Persistent state of a single device authorization. It is serialized to JSON
/// (inside `StoredVerification`) as the value of a verification row.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct DeviceAuthorizationRecord {
    /// 40-character random alphanumeric secret held by the device and sent when polling.
    device_code: String,
    /// Short (8-character, upper-cased) code the user enters to approve or deny.
    user_code: String,
    /// Client that requested the authorization.
    client_id: String,
    /// Scope requested by the client, if any; stored but not interpreted in this file.
    scope: Option<String>,
    /// Current decision state.
    status: DeviceStatus,
    /// Id of the approving user; set only when the request is approved.
    user_id: Option<String>,
    /// Absolute expiry; also used as the verification row's `expires_at`.
    expires_at: DateTime<Utc>,
    /// Time of the last accepted poll, used to enforce `interval` (`slow_down`).
    last_polled_at: Option<DateTime<Utc>>,
    /// Minimum poll spacing in seconds, copied from the config when issued.
    interval: i64,
}
56
/// Decision state of a device authorization. Serialized in lowercase
/// (`"pending"`, `"approved"`, `"denied"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
enum DeviceStatus {
    /// Issued; the user has not decided yet.
    Pending,
    /// The user approved; the next poll redeems it for a session.
    Approved,
    /// The user denied; polls return `access_denied`.
    Denied,
}
64
/// JSON envelope written as a verification row's `value`.
#[derive(Debug, Serialize, Deserialize)]
struct StoredVerification {
    /// The user code at issue time (set from the generated user code).
    code: String,
    /// Full authorization state.
    data: DeviceAuthorizationRecord,
}
70
/// Body of `POST /device/code`.
#[derive(Debug, Deserialize, Validate)]
struct DeviceCodeRequest {
    /// Requesting client; must be non-empty.
    #[validate(length(min = 1))]
    client_id: String,
    /// Optional requested scope, stored on the record as-is.
    scope: Option<String>,
}
77
78#[derive(Debug, Deserialize, Validate)]
79struct DeviceTokenRequest {
80 device_code: String,
81}
82
83#[derive(Debug, Deserialize, Validate)]
84struct ApproveRequest {
85 #[serde(rename = "userCode")]
86 user_code: String,
87}
88
/// Response of `POST /device/code`.
#[derive(Debug, Serialize)]
struct DeviceCodeResponse {
    /// Secret the device presents when polling `/device/token`.
    device_code: String,
    /// Code the user enters at `verification_uri`.
    user_code: String,
    /// Where the user should go to enter the code (from config).
    verification_uri: String,
    /// Lifetime of the codes in seconds (from config).
    expires_in: i64,
    /// Minimum polling interval in seconds (from config).
    interval: i64,
}
97
/// Successful response of `POST /device/token`.
#[derive(Debug, Serialize)]
struct DeviceTokenResponse {
    /// Token of the session newly created for the approving user.
    access_token: String,
}
102
/// Builds the verification-row identifier under which a record is looked up
/// by its device code (`"device_code:<code>"`).
fn device_code_identifier(device_code: &str) -> String {
    let mut key = String::from("device_code:");
    key.push_str(device_code);
    key
}
106
/// Builds the verification-row identifier under which a record is looked up
/// by its user code (`"user_code:<CODE>"`).
///
/// The code is normalized first: whitespace and `-` separators are removed
/// and ASCII letters are upper-cased. Generated user codes are already
/// upper-case ASCII alphanumerics, so normalization leaves them unchanged.
/// Codes typed by a person, such as `abcd-1234` or `ABCD 1234`, therefore
/// resolve to the same row as `ABCD1234`. This follows the leniency that
/// RFC 8628 §6.1 recommends for user-code entry.
fn user_code_identifier(user_code: &str) -> String {
    let normalized: String = user_code
        .chars()
        .filter(|c| !c.is_whitespace() && *c != '-')
        .map(|c| c.to_ascii_uppercase())
        .collect();
    format!("user_code:{normalized}")
}
110
111async fn store_device_authorization<DB: DatabaseAdapter>(
112 identifier: String,
113 stored: &StoredVerification,
114 ctx: &AuthContext<DB>,
115) -> AuthResult<()> {
116 ctx.database
117 .create_verification(CreateVerification {
118 identifier,
119 value: serde_json::to_string(stored)?,
120 expires_at: stored.data.expires_at,
121 })
122 .await?;
123
124 Ok(())
125}
126
127async fn create_device_code_core<DB: DatabaseAdapter>(
132 client_id: String,
133 scope: Option<String>,
134 ctx: &AuthContext<DB>,
135 config: &DeviceAuthorizationConfig,
136) -> AuthResult<DeviceCodeResponse> {
137 let device_code: String = rand::thread_rng()
138 .sample_iter(&Alphanumeric)
139 .take(40)
140 .map(char::from)
141 .collect();
142
143 let user_code: String = rand::thread_rng()
144 .sample_iter(&Alphanumeric)
145 .take(8)
146 .map(char::from)
147 .collect::<String>()
148 .to_uppercase();
149
150 let expires_at = Utc::now() + chrono::Duration::seconds(config.expires_in);
151 let stored = StoredVerification {
152 code: user_code.clone(),
153 data: DeviceAuthorizationRecord {
154 device_code: device_code.clone(),
155 user_code: user_code.clone(),
156 client_id,
157 scope,
158 status: DeviceStatus::Pending,
159 user_id: None,
160 expires_at,
161 last_polled_at: None,
162 interval: config.interval,
163 },
164 };
165
166 store_device_authorization(device_code_identifier(&device_code), &stored, ctx).await?;
167 store_device_authorization(user_code_identifier(&user_code), &stored, ctx).await?;
168
169 Ok(DeviceCodeResponse {
170 device_code,
171 user_code,
172 verification_uri: config.verification_uri.clone(),
173 expires_in: config.expires_in,
174 interval: config.interval,
175 })
176}
177
178async fn poll_device_token_core<DB: DatabaseAdapter>(
179 device_code: String,
180 ctx: &AuthContext<DB>,
181) -> AuthResult<DeviceTokenResponse> {
182 let identifier = device_code_identifier(&device_code);
183 let verification = ctx
184 .database
185 .get_verification_by_identifier(&identifier)
186 .await?
187 .ok_or_else(|| AuthError::bad_request("invalid_grant"))?;
188
189 let verification_id = verification.id().to_string();
190 let raw_value = verification.value().to_string();
191 let mut stored: StoredVerification = serde_json::from_str(&raw_value)?;
192
193 match stored.data.status {
194 DeviceStatus::Pending => {
195 let now = Utc::now();
196 if let Some(last) = stored.data.last_polled_at
197 && (now - last).num_seconds() < stored.data.interval
198 {
199 return Err(AuthError::bad_request("slow_down"));
200 }
201
202 stored.data.last_polled_at = Some(now);
203 ctx.database.delete_verification(&verification_id).await?;
204 store_device_authorization(identifier, &stored, ctx).await?;
205
206 Err(AuthError::bad_request("authorization_pending"))
207 }
208 DeviceStatus::Denied => Err(AuthError::bad_request("access_denied")),
209 DeviceStatus::Approved => {
210 if ctx
212 .database
213 .consume_verification(&identifier, &raw_value)
214 .await?
215 .is_none()
216 {
217 return Err(AuthError::bad_request("invalid_grant"));
218 }
219
220 let user_id =
221 stored.data.user_id.clone().ok_or_else(|| {
222 AuthError::internal("Device authorization missing approved user")
223 })?;
224
225 let user = ctx
226 .database
227 .get_user_by_id(&user_id)
228 .await?
229 .ok_or_else(|| AuthError::internal("User not found"))?;
230
231 let session = ctx
232 .session_manager()
233 .create_session(&user, None, None)
234 .await?;
235
236 Ok(DeviceTokenResponse {
237 access_token: session.token().to_string(),
238 })
239 }
240 }
241}
242
243async fn approve_device_core<DB: DatabaseAdapter>(
244 user: &DB::User,
245 user_code: &str,
246 ctx: &AuthContext<DB>,
247) -> AuthResult<StatusResponse> {
248 let verification = ctx
249 .database
250 .get_verification_by_identifier(&user_code_identifier(user_code))
251 .await?
252 .ok_or_else(|| AuthError::not_found("Invalid code"))?;
253
254 let verification_id = verification.id().to_string();
255 let mut stored: StoredVerification = serde_json::from_str(verification.value())?;
256
257 stored.data.status = DeviceStatus::Approved;
258 stored.data.user_id = Some(user.id().to_string());
259
260 let device_identifier = device_code_identifier(&stored.data.device_code);
261 if let Some(device_verification) = ctx
262 .database
263 .get_verification_by_identifier(&device_identifier)
264 .await?
265 {
266 ctx.database
267 .delete_verification(device_verification.id())
268 .await?;
269 }
270
271 ctx.database.delete_verification(&verification_id).await?;
272 store_device_authorization(device_identifier, &stored, ctx).await?;
273
274 Ok(StatusResponse { status: true })
275}
276
277async fn deny_device_core<DB: DatabaseAdapter>(
278 user_code: &str,
279 ctx: &AuthContext<DB>,
280) -> AuthResult<StatusResponse> {
281 let verification = ctx
282 .database
283 .get_verification_by_identifier(&user_code_identifier(user_code))
284 .await?
285 .ok_or_else(|| AuthError::not_found("Invalid code"))?;
286
287 let verification_id = verification.id().to_string();
288 let mut stored: StoredVerification = serde_json::from_str(verification.value())?;
289
290 stored.data.status = DeviceStatus::Denied;
291
292 let device_identifier = device_code_identifier(&stored.data.device_code);
293 if let Some(device_verification) = ctx
294 .database
295 .get_verification_by_identifier(&device_identifier)
296 .await?
297 {
298 ctx.database
299 .delete_verification(device_verification.id())
300 .await?;
301 }
302
303 ctx.database.delete_verification(&verification_id).await?;
304 store_device_authorization(device_identifier, &stored, ctx).await?;
305
306 Ok(StatusResponse { status: true })
307}
308
309impl DeviceAuthorizationPlugin {
314 async fn handle_code<DB: DatabaseAdapter>(
315 &self,
316 req: &AuthRequest,
317 ctx: &AuthContext<DB>,
318 ) -> AuthResult<AuthResponse> {
319 if !self.config.enabled {
320 return Err(AuthError::not_found("Not found"));
321 }
322
323 let body: DeviceCodeRequest = match better_auth_core::validate_request_body(req) {
324 Ok(v) => v,
325 Err(resp) => return Ok(resp),
326 };
327
328 let resp = create_device_code_core(body.client_id, body.scope, ctx, &self.config).await?;
329
330 Ok(AuthResponse::json(200, &resp)?)
331 }
332
333 async fn handle_token<DB: DatabaseAdapter>(
334 &self,
335 req: &AuthRequest,
336 ctx: &AuthContext<DB>,
337 ) -> AuthResult<AuthResponse> {
338 if !self.config.enabled {
339 return Err(AuthError::not_found("Not found"));
340 }
341
342 let body: DeviceTokenRequest = match better_auth_core::validate_request_body(req) {
343 Ok(v) => v,
344 Err(resp) => return Ok(resp),
345 };
346
347 let resp = poll_device_token_core(body.device_code, ctx).await?;
348
349 Ok(AuthResponse::json(200, &resp)?)
350 }
351
352 async fn handle_approve<DB: DatabaseAdapter>(
353 &self,
354 req: &AuthRequest,
355 ctx: &AuthContext<DB>,
356 ) -> AuthResult<AuthResponse> {
357 if !self.config.enabled {
358 return Err(AuthError::not_found("Not found"));
359 }
360
361 let (user, _) = ctx.require_session(req).await?;
362
363 let body: ApproveRequest = match better_auth_core::validate_request_body(req) {
364 Ok(v) => v,
365 Err(resp) => return Ok(resp),
366 };
367
368 let resp = approve_device_core(&user, &body.user_code, ctx).await?;
369
370 Ok(AuthResponse::json(200, &resp)?)
371 }
372
373 async fn handle_deny<DB: DatabaseAdapter>(
374 &self,
375 req: &AuthRequest,
376 ctx: &AuthContext<DB>,
377 ) -> AuthResult<AuthResponse> {
378 if !self.config.enabled {
379 return Err(AuthError::not_found("Not found"));
380 }
381
382 let _ = ctx.require_session(req).await?;
383
384 let body: ApproveRequest = match better_auth_core::validate_request_body(req) {
385 Ok(v) => v,
386 Err(resp) => return Ok(resp),
387 };
388
389 let resp = deny_device_core(&body.user_code, ctx).await?;
390
391 Ok(AuthResponse::json(200, &resp)?)
392 }
393}