1use crate::errors::{AuthError, Result};
8use crate::storage::AuthStorage;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::{Duration, SystemTime};
13use uuid::Uuid;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct DeviceAuthorizationRequest {
18 pub client_id: String,
20
21 pub scope: Option<String>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct DeviceAuthorizationResponse {
28 pub device_code: String,
30
31 pub user_code: String,
33
34 pub verification_uri: String,
36
37 pub verification_uri_complete: Option<String>,
39
40 pub interval: u64,
42
43 pub expires_in: u64,
45}
46
47#[derive(Debug, Clone, Deserialize)]
49pub struct DeviceTokenRequest {
50 pub grant_type: String,
52
53 pub device_code: String,
55
56 pub client_id: String,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct StoredDeviceAuthorization {
63 pub device_code: String,
65
66 pub user_code: String,
68
69 pub client_id: String,
71
72 pub scope: Option<String>,
74
75 pub status: DeviceAuthorizationStatus,
77
78 pub user_id: Option<String>,
80
81 pub created_at: SystemTime,
83
84 pub expires_at: SystemTime,
86
87 pub last_poll: Option<SystemTime>,
89
90 #[serde(default)]
93 pub slow_down_count: u32,
94}
95
96#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
98pub enum DeviceAuthorizationStatus {
99 Pending,
101 Authorized,
103 Denied,
105 Expired,
107}
108
109use std::fmt;
111
112#[derive(Clone)]
113pub struct DeviceAuthManager {
114 storage: Arc<dyn AuthStorage>,
116
117 authorizations: Arc<tokio::sync::RwLock<HashMap<String, StoredDeviceAuthorization>>>,
119
120 default_expiration: Duration,
122
123 min_interval: Duration,
125
126 verification_uri: String,
128}
129
130impl fmt::Debug for DeviceAuthManager {
131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 f.debug_struct("DeviceAuthManager")
133 .field("storage", &"<dyn AuthStorage>")
134 .field("default_expiration", &self.default_expiration)
135 .field("min_interval", &self.min_interval)
136 .field("verification_uri", &self.verification_uri)
137 .finish()
138 }
139}
140
141impl DeviceAuthManager {
142 pub fn new(storage: Arc<dyn AuthStorage>, verification_uri: String) -> Self {
144 Self {
145 storage,
146 authorizations: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
147 default_expiration: Duration::from_secs(600), min_interval: Duration::from_secs(5), verification_uri,
150 }
151 }
152
153 pub fn with_settings(
155 storage: Arc<dyn AuthStorage>,
156 verification_uri: String,
157 expiration: Duration,
158 min_interval: Duration,
159 ) -> Self {
160 Self {
161 storage,
162 authorizations: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
163 default_expiration: expiration,
164 min_interval,
165 verification_uri,
166 }
167 }
168
169 pub fn expiration(mut self, expiration: Duration) -> Self {
173 self.default_expiration = expiration;
174 self
175 }
176
177 pub fn interval(mut self, interval: Duration) -> Self {
181 self.min_interval = interval;
182 self
183 }
184
185 pub async fn create_authorization(
187 &self,
188 request: DeviceAuthorizationRequest,
189 ) -> Result<DeviceAuthorizationResponse> {
190 self.validate_request(&request)?;
192
193 let device_code = format!("dc_{}", Uuid::new_v4().simple());
195 let user_code = self.generate_user_code();
196
197 let now = SystemTime::now();
199 let expires_at = now + self.default_expiration;
200
201 let stored = StoredDeviceAuthorization {
203 device_code: device_code.clone(),
204 user_code: user_code.clone(),
205 client_id: request.client_id.clone(),
206 scope: request.scope.clone(),
207 status: DeviceAuthorizationStatus::Pending,
208 user_id: None,
209 created_at: now,
210 expires_at,
211 last_poll: None,
212 slow_down_count: 0,
213 };
214
215 let device_key = format!("device_code:{}", device_code);
217 let user_key = format!("user_code:{}", user_code);
218
219 let serialized = serde_json::to_string(&stored)
220 .map_err(|e| AuthError::internal(format!("Failed to serialize device auth: {}", e)))?;
221
222 self.storage
223 .store_kv(
224 &device_key,
225 serialized.as_bytes(),
226 Some(self.default_expiration),
227 )
228 .await
229 .map_err(|e| {
230 AuthError::internal(format!("Failed to store device authorization: {}", e))
231 })?;
232
233 self.storage
235 .store_kv(
236 &user_key,
237 serialized.as_bytes(),
238 Some(self.default_expiration),
239 )
240 .await
241 .map_err(|e| {
242 AuthError::internal(format!("Failed to store user code mapping: {}", e))
243 })?;
244
245 let mut authorizations = self.authorizations.write().await;
247 authorizations.insert(device_code.clone(), stored);
248
249 self.cleanup_expired(&mut authorizations, now);
251
252 let verification_uri_complete =
254 format!("{}?user_code={}", self.verification_uri, user_code);
255
256 Ok(DeviceAuthorizationResponse {
257 device_code,
258 user_code,
259 verification_uri: self.verification_uri.clone(),
260 verification_uri_complete: Some(verification_uri_complete),
261 interval: self.min_interval.as_secs(),
262 expires_in: self.default_expiration.as_secs(),
263 })
264 }
265
266 pub async fn poll_authorization(&self, device_code: &str) -> Result<StoredDeviceAuthorization> {
268 let device_key = format!("device_code:{}", device_code);
269
270 let mut stored = if let Some(data) = self.storage.get_kv(&device_key).await? {
272 let serialized = String::from_utf8(data)
273 .map_err(|_| AuthError::internal("Invalid UTF-8 in stored device auth data"))?;
274
275 serde_json::from_str::<StoredDeviceAuthorization>(&serialized).map_err(|e| {
276 AuthError::internal(format!("Failed to deserialize device auth: {}", e))
277 })?
278 } else {
279 let authorizations = self.authorizations.read().await;
281 authorizations
282 .get(device_code)
283 .cloned()
284 .ok_or_else(|| AuthError::auth_method("device_auth", "Invalid device_code"))?
285 };
286
287 let now = SystemTime::now();
289 if now > stored.expires_at {
290 stored.status = DeviceAuthorizationStatus::Expired;
291 return Err(AuthError::auth_method("device_auth", "Device code expired"));
292 }
293
294 let effective_interval = self.min_interval
298 + Duration::from_secs(5 * u64::from(stored.slow_down_count));
299 if let Some(last_poll) = stored.last_poll {
300 let elapsed = now.duration_since(last_poll).unwrap_or(Duration::ZERO);
301 if elapsed < effective_interval {
302 stored.slow_down_count += 1;
303 stored.last_poll = Some(now);
305 let serialized = serde_json::to_string(&stored).map_err(|e| {
306 AuthError::internal(format!("Failed to serialize device auth: {}", e))
307 })?;
308 self.storage
309 .store_kv(
310 &device_key,
311 serialized.as_bytes(),
312 Some(self.default_expiration),
313 )
314 .await
315 .ok();
316 let mut authorizations = self.authorizations.write().await;
317 authorizations.insert(device_code.to_string(), stored);
318 return Err(AuthError::auth_method("device_auth", "slow_down"));
319 }
320 }
321
322 stored.last_poll = Some(now);
324
325 let serialized = serde_json::to_string(&stored)
327 .map_err(|e| AuthError::internal(format!("Failed to serialize device auth: {}", e)))?;
328
329 self.storage
330 .store_kv(
331 &device_key,
332 serialized.as_bytes(),
333 Some(self.default_expiration),
334 )
335 .await
336 .ok(); let mut authorizations = self.authorizations.write().await;
340 authorizations.insert(device_code.to_string(), stored.clone());
341
342 match stored.status {
344 DeviceAuthorizationStatus::Pending => Err(AuthError::auth_method(
345 "device_auth",
346 "authorization_pending",
347 )),
348 DeviceAuthorizationStatus::Authorized => Ok(stored),
349 DeviceAuthorizationStatus::Denied => {
350 Err(AuthError::auth_method("device_auth", "access_denied"))
351 }
352 DeviceAuthorizationStatus::Expired => {
353 Err(AuthError::auth_method("device_auth", "expired_token"))
354 }
355 }
356 }
357
358 pub async fn authorize_device(&self, user_code: &str, user_id: &str) -> Result<()> {
360 let user_key = format!("user_code:{}", user_code);
361
362 let mut stored = if let Some(data) = self.storage.get_kv(&user_key).await? {
364 let serialized = String::from_utf8(data)
365 .map_err(|_| AuthError::internal("Invalid UTF-8 in stored device auth data"))?;
366
367 serde_json::from_str::<StoredDeviceAuthorization>(&serialized).map_err(|e| {
368 AuthError::internal(format!("Failed to deserialize device auth: {}", e))
369 })?
370 } else {
371 return Err(AuthError::auth_method("device_auth", "Invalid user_code"));
372 };
373
374 let now = SystemTime::now();
376 if now > stored.expires_at {
377 return Err(AuthError::auth_method("device_auth", "Device code expired"));
378 }
379
380 stored.status = DeviceAuthorizationStatus::Authorized;
382 stored.user_id = Some(user_id.to_string());
383
384 let serialized = serde_json::to_string(&stored)
386 .map_err(|e| AuthError::internal(format!("Failed to serialize device auth: {}", e)))?;
387
388 let device_key = format!("device_code:{}", stored.device_code);
389
390 self.storage
391 .store_kv(
392 &device_key,
393 serialized.as_bytes(),
394 Some(self.default_expiration),
395 )
396 .await?;
397
398 self.storage
399 .store_kv(
400 &user_key,
401 serialized.as_bytes(),
402 Some(self.default_expiration),
403 )
404 .await?;
405
406 let mut authorizations = self.authorizations.write().await;
408 authorizations.insert(stored.device_code.clone(), stored);
409
410 Ok(())
411 }
412
413 pub async fn deny_device(&self, user_code: &str) -> Result<()> {
415 let user_key = format!("user_code:{}", user_code);
416
417 let mut stored = if let Some(data) = self.storage.get_kv(&user_key).await? {
419 let serialized = String::from_utf8(data)
420 .map_err(|_| AuthError::internal("Invalid UTF-8 in stored device auth data"))?;
421
422 serde_json::from_str::<StoredDeviceAuthorization>(&serialized).map_err(|e| {
423 AuthError::internal(format!("Failed to deserialize device auth: {}", e))
424 })?
425 } else {
426 return Err(AuthError::auth_method("device_auth", "Invalid user_code"));
427 };
428
429 stored.status = DeviceAuthorizationStatus::Denied;
431
432 let serialized = serde_json::to_string(&stored)
434 .map_err(|e| AuthError::internal(format!("Failed to serialize device auth: {}", e)))?;
435
436 let device_key = format!("device_code:{}", stored.device_code);
437
438 self.storage
439 .store_kv(
440 &device_key,
441 serialized.as_bytes(),
442 Some(self.default_expiration),
443 )
444 .await?;
445
446 self.storage
447 .store_kv(
448 &user_key,
449 serialized.as_bytes(),
450 Some(self.default_expiration),
451 )
452 .await?;
453
454 let mut authorizations = self.authorizations.write().await;
456 authorizations.insert(stored.device_code.clone(), stored);
457
458 Ok(())
459 }
460
461 pub async fn get_by_user_code(&self, user_code: &str) -> Result<StoredDeviceAuthorization> {
463 let user_key = format!("user_code:{}", user_code);
464
465 if let Some(data) = self.storage.get_kv(&user_key).await? {
466 let serialized = String::from_utf8(data)
467 .map_err(|_| AuthError::internal("Invalid UTF-8 in stored device auth data"))?;
468
469 let stored: StoredDeviceAuthorization =
470 serde_json::from_str(&serialized).map_err(|e| {
471 AuthError::internal(format!("Failed to deserialize device auth: {}", e))
472 })?;
473
474 let now = SystemTime::now();
476 if now > stored.expires_at {
477 return Err(AuthError::auth_method("device_auth", "User code expired"));
478 }
479
480 Ok(stored)
481 } else {
482 Err(AuthError::auth_method("device_auth", "Invalid user_code"))
483 }
484 }
485
486 fn validate_request(&self, request: &DeviceAuthorizationRequest) -> Result<()> {
488 if request.client_id.is_empty() {
489 return Err(AuthError::auth_method("device_auth", "Missing client_id"));
490 }
491
492 Ok(())
495 }
496
497 fn generate_user_code(&self) -> String {
499 use rand::RngExt;
500 const CHARS: &[u8] = b"ABCDEFGHJKLMNPQRSTUVWXYZ23456789"; let mut rng = rand::rng();
502
503 let code: String = (0..9)
505 .map(|i| {
506 if i == 4 {
507 '-'
508 } else {
509 let idx = rng.random_range(0..CHARS.len());
510 CHARS[idx] as char
511 }
512 })
513 .collect();
514
515 code
516 }
517
518 fn cleanup_expired(
520 &self,
521 authorizations: &mut HashMap<String, StoredDeviceAuthorization>,
522 now: SystemTime,
523 ) {
524 authorizations.retain(|_, auth| now <= auth.expires_at);
525 }
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryStorage;
    use tokio::time::sleep;

    const VERIFICATION_URI: &str = "https://example.com/device";

    /// Manager with default settings over fresh in-memory storage.
    fn create_test_manager() -> DeviceAuthManager {
        DeviceAuthManager::new(Arc::new(MemoryStorage::new()), VERIFICATION_URI.to_string())
    }

    /// Request from the shared test client with an optional scope.
    fn test_request(scope: Option<&str>) -> DeviceAuthorizationRequest {
        DeviceAuthorizationRequest {
            client_id: "test_client".to_string(),
            scope: scope.map(str::to_string),
        }
    }

    /// Starts a flow for the test client and returns the response.
    async fn start_flow(
        manager: &DeviceAuthManager,
        scope: Option<&str>,
    ) -> DeviceAuthorizationResponse {
        manager
            .create_authorization(test_request(scope))
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_create_authorization() {
        let manager = create_test_manager();
        let response = start_flow(&manager, Some("openid profile")).await;

        assert!(response.device_code.starts_with("dc_"));
        // Eight symbols split 4-4 by a hyphen.
        assert_eq!(response.user_code.len(), 9);
        assert!(response.user_code.contains('-'));
        assert_eq!(response.verification_uri, VERIFICATION_URI);
        assert!(response.verification_uri_complete.is_some());
        assert_eq!(response.interval, 5);
        assert_eq!(response.expires_in, 600);
    }

    #[tokio::test]
    async fn test_poll_pending() {
        let manager = create_test_manager();
        let response = start_flow(&manager, None).await;

        let err = manager
            .poll_authorization(&response.device_code)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("authorization_pending"));
    }

    #[tokio::test]
    async fn test_authorize_and_poll() {
        let manager = create_test_manager();
        let response = start_flow(&manager, Some("openid")).await;

        manager
            .authorize_device(&response.user_code, "user_123")
            .await
            .unwrap();

        let stored = manager
            .poll_authorization(&response.device_code)
            .await
            .unwrap();
        assert_eq!(stored.status, DeviceAuthorizationStatus::Authorized);
        assert_eq!(stored.user_id.as_deref(), Some("user_123"));
    }

    #[tokio::test]
    async fn test_deny_device() {
        let manager = create_test_manager();
        let response = start_flow(&manager, None).await;

        manager.deny_device(&response.user_code).await.unwrap();

        let err = manager
            .poll_authorization(&response.device_code)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("access_denied"));
    }

    #[tokio::test]
    async fn test_slow_down() {
        let manager = create_test_manager();
        let response = start_flow(&manager, None).await;

        // The first poll only records the poll time; the immediate second one is too fast.
        let _ = manager.poll_authorization(&response.device_code).await;

        let err = manager
            .poll_authorization(&response.device_code)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("slow_down"));
    }

    #[tokio::test]
    async fn test_expiration() {
        let manager = DeviceAuthManager::with_settings(
            Arc::new(MemoryStorage::new()),
            VERIFICATION_URI.to_string(),
            Duration::from_millis(100),
            Duration::from_secs(1),
        );
        let response = start_flow(&manager, None).await;

        sleep(Duration::from_millis(150)).await;

        let err = manager
            .poll_authorization(&response.device_code)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("expired"));
    }

    #[tokio::test]
    async fn test_chainable_expiration_and_interval() {
        let manager =
            DeviceAuthManager::new(Arc::new(MemoryStorage::new()), VERIFICATION_URI.to_string())
                .expiration(Duration::from_secs(300))
                .interval(Duration::from_secs(10));

        let response = start_flow(&manager, None).await;
        assert_eq!(response.expires_in, 300);
        assert_eq!(response.interval, 10);
    }
}