1use crate::errors::{AuthError, Result};
8use crate::storage::AuthStorage;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::{Duration, SystemTime};
13use uuid::Uuid;
14
/// Body of a device authorization request, the first step of the device flow.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceAuthorizationRequest {
    /// Identifier of the requesting client; must be non-empty (checked by
    /// `DeviceAuthManager::validate_request`).
    pub client_id: String,

    /// Optional requested scope string, copied verbatim onto the stored record.
    pub scope: Option<String>,
}
24
/// Values returned to the device by `DeviceAuthManager::create_authorization`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceAuthorizationResponse {
    /// Opaque code the device polls with (`dc_` followed by a UUID in simple form).
    pub device_code: String,

    /// Short code the user enters at `verification_uri` (format `XXXX-XXXX`).
    pub user_code: String,

    /// Page where the user enters `user_code`.
    pub verification_uri: String,

    /// `verification_uri` with `?user_code=<user_code>` appended so the user does not
    /// have to type the code.
    pub verification_uri_complete: Option<String>,

    /// Minimum number of seconds the device must wait between polls; polling faster
    /// yields a `slow_down` error.
    pub interval: u64,

    /// Number of seconds until the codes expire.
    pub expires_in: u64,
}
46
/// Token request presumably sent by the device while polling for the outcome of the flow.
///
/// NOTE(review): this type is not consumed in this module; confirm against the token
/// endpoint that uses it.
#[derive(Debug, Clone, Deserialize)]
pub struct DeviceTokenRequest {
    /// Grant type of the request; presumably expected to be the RFC 8628 device-code grant
    /// (`urn:ietf:params:oauth:grant-type:device_code`). Not validated here.
    pub grant_type: String,

    /// Device code being polled (assumed to be one issued by `create_authorization`).
    pub device_code: String,

    /// Identifier of the client performing the poll.
    pub client_id: String,
}
59
/// Persisted state of one device authorization flow.
///
/// The same JSON is stored under both `device_code:<device_code>` and
/// `user_code:<user_code>` keys.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredDeviceAuthorization {
    /// Code the device polls with.
    pub device_code: String,

    /// Code the user enters to approve or deny the request.
    pub user_code: String,

    /// Client that started the flow.
    pub client_id: String,

    /// Scope requested by the client, if any.
    pub scope: Option<String>,

    /// Current lifecycle state of the flow.
    pub status: DeviceAuthorizationStatus,

    /// User who approved the request; set by `authorize_device`, `None` until then.
    pub user_id: Option<String>,

    /// When the flow was created.
    pub created_at: SystemTime,

    /// Instant after which the codes are rejected (`created_at` plus the manager's
    /// configured expiration).
    pub expires_at: SystemTime,

    /// Time of the last accepted poll, used to enforce the minimum polling interval.
    pub last_poll: Option<SystemTime>,
}
90
91#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
93pub enum DeviceAuthorizationStatus {
94 Pending,
96 Authorized,
98 Denied,
100 Expired,
102}
103
104use std::fmt;
106
/// Coordinates device authorization flows: issuing codes, recording user decisions and
/// answering device polls.
///
/// Cloning is cheap; clones share the same storage handle and in-memory map.
#[derive(Clone)]
pub struct DeviceAuthManager {
    /// Key-value backend holding the serialized records.
    storage: Arc<dyn AuthStorage>,

    /// In-memory copy of records keyed by device code; `poll_authorization` falls back to
    /// it when storage has no entry.
    authorizations: Arc<tokio::sync::RwLock<HashMap<String, StoredDeviceAuthorization>>>,

    /// Lifetime of newly issued codes, also used as the storage TTL for each write.
    default_expiration: Duration,

    /// Minimum time a device must wait between polls.
    min_interval: Duration,

    /// Page where users enter their user code.
    verification_uri: String,
}
124
125impl fmt::Debug for DeviceAuthManager {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("DeviceAuthManager")
128 .field("storage", &"<dyn AuthStorage>")
129 .field("default_expiration", &self.default_expiration)
130 .field("min_interval", &self.min_interval)
131 .field("verification_uri", &self.verification_uri)
132 .finish()
133 }
134}
135
136impl DeviceAuthManager {
137 pub fn new(storage: Arc<dyn AuthStorage>, verification_uri: String) -> Self {
139 Self {
140 storage,
141 authorizations: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
142 default_expiration: Duration::from_secs(600), min_interval: Duration::from_secs(5), verification_uri,
145 }
146 }
147
148 pub fn with_settings(
150 storage: Arc<dyn AuthStorage>,
151 verification_uri: String,
152 expiration: Duration,
153 min_interval: Duration,
154 ) -> Self {
155 Self {
156 storage,
157 authorizations: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
158 default_expiration: expiration,
159 min_interval,
160 verification_uri,
161 }
162 }
163
164 pub async fn create_authorization(
166 &self,
167 request: DeviceAuthorizationRequest,
168 ) -> Result<DeviceAuthorizationResponse> {
169 self.validate_request(&request)?;
171
172 let device_code = format!("dc_{}", Uuid::new_v4().simple());
174 let user_code = self.generate_user_code();
175
176 let now = SystemTime::now();
178 let expires_at = now + self.default_expiration;
179
180 let stored = StoredDeviceAuthorization {
182 device_code: device_code.clone(),
183 user_code: user_code.clone(),
184 client_id: request.client_id.clone(),
185 scope: request.scope.clone(),
186 status: DeviceAuthorizationStatus::Pending,
187 user_id: None,
188 created_at: now,
189 expires_at,
190 last_poll: None,
191 };
192
193 let device_key = format!("device_code:{}", device_code);
195 let user_key = format!("user_code:{}", user_code);
196
197 let serialized = serde_json::to_string(&stored)
198 .map_err(|e| AuthError::internal(format!("Failed to serialize device auth: {}", e)))?;
199
200 self.storage
201 .store_kv(
202 &device_key,
203 serialized.as_bytes(),
204 Some(self.default_expiration),
205 )
206 .await
207 .map_err(|e| {
208 AuthError::internal(format!("Failed to store device authorization: {}", e))
209 })?;
210
211 self.storage
213 .store_kv(
214 &user_key,
215 serialized.as_bytes(),
216 Some(self.default_expiration),
217 )
218 .await
219 .map_err(|e| {
220 AuthError::internal(format!("Failed to store user code mapping: {}", e))
221 })?;
222
223 let mut authorizations = self.authorizations.write().await;
225 authorizations.insert(device_code.clone(), stored);
226
227 self.cleanup_expired(&mut authorizations, now);
229
230 let verification_uri_complete =
232 format!("{}?user_code={}", self.verification_uri, user_code);
233
234 Ok(DeviceAuthorizationResponse {
235 device_code,
236 user_code,
237 verification_uri: self.verification_uri.clone(),
238 verification_uri_complete: Some(verification_uri_complete),
239 interval: self.min_interval.as_secs(),
240 expires_in: self.default_expiration.as_secs(),
241 })
242 }
243
244 pub async fn poll_authorization(&self, device_code: &str) -> Result<StoredDeviceAuthorization> {
246 let device_key = format!("device_code:{}", device_code);
247
248 let mut stored = if let Some(data) = self.storage.get_kv(&device_key).await? {
250 let serialized = String::from_utf8(data)
251 .map_err(|_| AuthError::internal("Invalid UTF-8 in stored device auth data"))?;
252
253 serde_json::from_str::<StoredDeviceAuthorization>(&serialized).map_err(|e| {
254 AuthError::internal(format!("Failed to deserialize device auth: {}", e))
255 })?
256 } else {
257 let authorizations = self.authorizations.read().await;
259 authorizations
260 .get(device_code)
261 .cloned()
262 .ok_or_else(|| AuthError::auth_method("device_auth", "Invalid device_code"))?
263 };
264
265 let now = SystemTime::now();
267 if now > stored.expires_at {
268 stored.status = DeviceAuthorizationStatus::Expired;
269 return Err(AuthError::auth_method("device_auth", "Device code expired"));
270 }
271
272 if let Some(last_poll) = stored.last_poll {
274 let elapsed = now.duration_since(last_poll).unwrap_or(Duration::ZERO);
275 if elapsed < self.min_interval {
276 return Err(AuthError::auth_method("device_auth", "slow_down"));
277 }
278 }
279
280 stored.last_poll = Some(now);
282
283 let serialized = serde_json::to_string(&stored)
285 .map_err(|e| AuthError::internal(format!("Failed to serialize device auth: {}", e)))?;
286
287 self.storage
288 .store_kv(
289 &device_key,
290 serialized.as_bytes(),
291 Some(self.default_expiration),
292 )
293 .await
294 .ok(); let mut authorizations = self.authorizations.write().await;
298 authorizations.insert(device_code.to_string(), stored.clone());
299
300 match stored.status {
302 DeviceAuthorizationStatus::Pending => Err(AuthError::auth_method(
303 "device_auth",
304 "authorization_pending",
305 )),
306 DeviceAuthorizationStatus::Authorized => Ok(stored),
307 DeviceAuthorizationStatus::Denied => {
308 Err(AuthError::auth_method("device_auth", "access_denied"))
309 }
310 DeviceAuthorizationStatus::Expired => {
311 Err(AuthError::auth_method("device_auth", "expired_token"))
312 }
313 }
314 }
315
316 pub async fn authorize_device(&self, user_code: &str, user_id: &str) -> Result<()> {
318 let user_key = format!("user_code:{}", user_code);
319
320 let mut stored = if let Some(data) = self.storage.get_kv(&user_key).await? {
322 let serialized = String::from_utf8(data)
323 .map_err(|_| AuthError::internal("Invalid UTF-8 in stored device auth data"))?;
324
325 serde_json::from_str::<StoredDeviceAuthorization>(&serialized).map_err(|e| {
326 AuthError::internal(format!("Failed to deserialize device auth: {}", e))
327 })?
328 } else {
329 return Err(AuthError::auth_method("device_auth", "Invalid user_code"));
330 };
331
332 let now = SystemTime::now();
334 if now > stored.expires_at {
335 return Err(AuthError::auth_method("device_auth", "Device code expired"));
336 }
337
338 stored.status = DeviceAuthorizationStatus::Authorized;
340 stored.user_id = Some(user_id.to_string());
341
342 let serialized = serde_json::to_string(&stored)
344 .map_err(|e| AuthError::internal(format!("Failed to serialize device auth: {}", e)))?;
345
346 let device_key = format!("device_code:{}", stored.device_code);
347
348 self.storage
349 .store_kv(
350 &device_key,
351 serialized.as_bytes(),
352 Some(self.default_expiration),
353 )
354 .await?;
355
356 self.storage
357 .store_kv(
358 &user_key,
359 serialized.as_bytes(),
360 Some(self.default_expiration),
361 )
362 .await?;
363
364 let mut authorizations = self.authorizations.write().await;
366 authorizations.insert(stored.device_code.clone(), stored);
367
368 Ok(())
369 }
370
371 pub async fn deny_device(&self, user_code: &str) -> Result<()> {
373 let user_key = format!("user_code:{}", user_code);
374
375 let mut stored = if let Some(data) = self.storage.get_kv(&user_key).await? {
377 let serialized = String::from_utf8(data)
378 .map_err(|_| AuthError::internal("Invalid UTF-8 in stored device auth data"))?;
379
380 serde_json::from_str::<StoredDeviceAuthorization>(&serialized).map_err(|e| {
381 AuthError::internal(format!("Failed to deserialize device auth: {}", e))
382 })?
383 } else {
384 return Err(AuthError::auth_method("device_auth", "Invalid user_code"));
385 };
386
387 stored.status = DeviceAuthorizationStatus::Denied;
389
390 let serialized = serde_json::to_string(&stored)
392 .map_err(|e| AuthError::internal(format!("Failed to serialize device auth: {}", e)))?;
393
394 let device_key = format!("device_code:{}", stored.device_code);
395
396 self.storage
397 .store_kv(
398 &device_key,
399 serialized.as_bytes(),
400 Some(self.default_expiration),
401 )
402 .await?;
403
404 self.storage
405 .store_kv(
406 &user_key,
407 serialized.as_bytes(),
408 Some(self.default_expiration),
409 )
410 .await?;
411
412 let mut authorizations = self.authorizations.write().await;
414 authorizations.insert(stored.device_code.clone(), stored);
415
416 Ok(())
417 }
418
419 pub async fn get_by_user_code(&self, user_code: &str) -> Result<StoredDeviceAuthorization> {
421 let user_key = format!("user_code:{}", user_code);
422
423 if let Some(data) = self.storage.get_kv(&user_key).await? {
424 let serialized = String::from_utf8(data)
425 .map_err(|_| AuthError::internal("Invalid UTF-8 in stored device auth data"))?;
426
427 let stored: StoredDeviceAuthorization =
428 serde_json::from_str(&serialized).map_err(|e| {
429 AuthError::internal(format!("Failed to deserialize device auth: {}", e))
430 })?;
431
432 let now = SystemTime::now();
434 if now > stored.expires_at {
435 return Err(AuthError::auth_method("device_auth", "User code expired"));
436 }
437
438 Ok(stored)
439 } else {
440 Err(AuthError::auth_method("device_auth", "Invalid user_code"))
441 }
442 }
443
444 fn validate_request(&self, request: &DeviceAuthorizationRequest) -> Result<()> {
446 if request.client_id.is_empty() {
447 return Err(AuthError::auth_method("device_auth", "Missing client_id"));
448 }
449
450 Ok(())
453 }
454
455 fn generate_user_code(&self) -> String {
457 use rand::Rng;
458 const CHARS: &[u8] = b"ABCDEFGHJKLMNPQRSTUVWXYZ23456789"; let mut rng = rand::rng();
460
461 let code: String = (0..9)
463 .map(|i| {
464 if i == 4 {
465 '-'
466 } else {
467 let idx = rng.random_range(0..CHARS.len());
468 CHARS[idx] as char
469 }
470 })
471 .collect();
472
473 code
474 }
475
476 fn cleanup_expired(
478 &self,
479 authorizations: &mut HashMap<String, StoredDeviceAuthorization>,
480 now: SystemTime,
481 ) {
482 authorizations.retain(|_, auth| now <= auth.expires_at);
483 }
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryStorage;
    use tokio::time::sleep;

    const VERIFICATION_URI: &str = "https://example.com/device";

    fn create_test_manager() -> DeviceAuthManager {
        let storage = Arc::new(MemoryStorage::new());
        DeviceAuthManager::new(storage, VERIFICATION_URI.to_string())
    }

    /// Starts a flow for `test_client` with the given scope, panicking if creation fails.
    async fn start_flow(
        manager: &DeviceAuthManager,
        scope: Option<&str>,
    ) -> DeviceAuthorizationResponse {
        let request = DeviceAuthorizationRequest {
            client_id: "test_client".to_string(),
            scope: scope.map(str::to_string),
        };
        manager
            .create_authorization(request)
            .await
            .expect("create_authorization should succeed")
    }

    /// Asserts that a poll failed with an error whose message contains `needle`.
    fn assert_err_contains(result: Result<StoredDeviceAuthorization>, needle: &str) {
        let err = result.expect_err("expected the poll to fail");
        let message = err.to_string();
        assert!(
            message.contains(needle),
            "expected error containing {:?}, got {:?}",
            needle,
            message
        );
    }

    #[tokio::test]
    async fn test_create_authorization() {
        let manager = create_test_manager();

        let response = start_flow(&manager, Some("openid profile")).await;

        assert!(response.device_code.starts_with("dc_"));
        assert_eq!(response.user_code.len(), 9);
        assert!(response.user_code.contains('-'));
        assert_eq!(response.verification_uri, VERIFICATION_URI);
        assert!(response.verification_uri_complete.is_some());
        assert_eq!(response.interval, 5);
        assert_eq!(response.expires_in, 600);
    }

    #[tokio::test]
    async fn test_poll_pending() {
        let manager = create_test_manager();
        let response = start_flow(&manager, None).await;

        let result = manager.poll_authorization(&response.device_code).await;

        assert_err_contains(result, "authorization_pending");
    }

    #[tokio::test]
    async fn test_authorize_and_poll() {
        let manager = create_test_manager();
        let response = start_flow(&manager, Some("openid")).await;

        manager
            .authorize_device(&response.user_code, "user_123")
            .await
            .expect("authorize_device should succeed");

        let stored = manager
            .poll_authorization(&response.device_code)
            .await
            .expect("poll after authorization should succeed");
        assert_eq!(stored.status, DeviceAuthorizationStatus::Authorized);
        assert_eq!(stored.user_id.as_deref(), Some("user_123"));
    }

    #[tokio::test]
    async fn test_deny_device() {
        let manager = create_test_manager();
        let response = start_flow(&manager, None).await;

        manager
            .deny_device(&response.user_code)
            .await
            .expect("deny_device should succeed");

        let result = manager.poll_authorization(&response.device_code).await;

        assert_err_contains(result, "access_denied");
    }

    #[tokio::test]
    async fn test_slow_down() {
        let manager = create_test_manager();
        let response = start_flow(&manager, None).await;

        // The first poll is accepted (and reports pending); an immediate second poll is throttled.
        let _ = manager.poll_authorization(&response.device_code).await;
        let result = manager.poll_authorization(&response.device_code).await;

        assert_err_contains(result, "slow_down");
    }

    #[tokio::test]
    async fn test_expiration() {
        let storage = Arc::new(MemoryStorage::new());
        let manager = DeviceAuthManager::with_settings(
            storage,
            VERIFICATION_URI.to_string(),
            Duration::from_millis(100),
            Duration::from_secs(1),
        );
        let response = start_flow(&manager, None).await;

        sleep(Duration::from_millis(150)).await;

        let result = manager.poll_authorization(&response.device_code).await;

        assert_err_contains(result, "expired");
    }
}