Skip to main content

auth_framework/server/oauth/
device.rs

1//! Device Authorization Grant Implementation - RFC 8628
2//!
3//! This module implements RFC 8628 - OAuth 2.0 Device Authorization Grant
4//! which allows devices with limited input capability (smart TVs, printers, etc.)
5//! to obtain user authorization.
6
7use crate::errors::{AuthError, Result};
8use crate::storage::AuthStorage;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::{Duration, SystemTime};
13use uuid::Uuid;
14
15/// Device authorization request
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct DeviceAuthorizationRequest {
18    /// Client identifier
19    pub client_id: String,
20
21    /// Requested scopes
22    pub scope: Option<String>,
23}
24
25/// Device authorization response
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct DeviceAuthorizationResponse {
28    /// Device verification code (for storage)
29    pub device_code: String,
30
31    /// User-friendly code to enter
32    pub user_code: String,
33
34    /// URL where user should authorize
35    pub verification_uri: String,
36
37    /// Complete verification URL with user_code (optional)
38    pub verification_uri_complete: Option<String>,
39
40    /// Polling interval in seconds
41    pub interval: u64,
42
43    /// Device code expires in seconds
44    pub expires_in: u64,
45}
46
47/// Token request for device code grant
48#[derive(Debug, Clone, Deserialize)]
49pub struct DeviceTokenRequest {
50    /// Grant type (must be "urn:ietf:params:oauth:grant-type:device_code")
51    pub grant_type: String,
52
53    /// Device code received from device authorization
54    pub device_code: String,
55
56    /// Client identifier
57    pub client_id: String,
58}
59
60/// Stored device authorization data
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct StoredDeviceAuthorization {
63    /// Device code
64    pub device_code: String,
65
66    /// User code
67    pub user_code: String,
68
69    /// Client ID
70    pub client_id: String,
71
72    /// Requested scopes
73    pub scope: Option<String>,
74
75    /// Authorization status
76    pub status: DeviceAuthorizationStatus,
77
78    /// User ID (once authorized)
79    pub user_id: Option<String>,
80
81    /// When the request was created
82    pub created_at: SystemTime,
83
84    /// When the request expires
85    pub expires_at: SystemTime,
86
87    /// Last poll time (for slow_down error)
88    pub last_poll: Option<SystemTime>,
89
90    /// Number of times the client has been told to slow down (RFC 8628 §3.5).
91    /// Each slow_down increases the required interval by 5 seconds.
92    #[serde(default)]
93    pub slow_down_count: u32,
94}
95
96/// Device authorization status
97#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
98pub enum DeviceAuthorizationStatus {
99    /// Awaiting user authorization
100    Pending,
101    /// User has authorized
102    Authorized,
103    /// User has denied
104    Denied,
105    /// Authorization expired
106    Expired,
107}
108
109/// Device authorization manager with persistent storage
110use std::fmt;
111
112#[derive(Clone)]
113pub struct DeviceAuthManager {
114    /// Persistent storage backend
115    storage: Arc<dyn AuthStorage>,
116
117    /// Memory cache for fast access
118    authorizations: Arc<tokio::sync::RwLock<HashMap<String, StoredDeviceAuthorization>>>,
119
120    /// Default expiration time for device codes
121    default_expiration: Duration,
122
123    /// Minimum polling interval
124    min_interval: Duration,
125
126    /// Base verification URI
127    verification_uri: String,
128}
129
130impl fmt::Debug for DeviceAuthManager {
131    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132        f.debug_struct("DeviceAuthManager")
133            .field("storage", &"<dyn AuthStorage>")
134            .field("default_expiration", &self.default_expiration)
135            .field("min_interval", &self.min_interval)
136            .field("verification_uri", &self.verification_uri)
137            .finish()
138    }
139}
140
141impl DeviceAuthManager {
142    /// Create a new device authorization manager
143    pub fn new(storage: Arc<dyn AuthStorage>, verification_uri: String) -> Self {
144        Self {
145            storage,
146            authorizations: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
147            default_expiration: Duration::from_secs(600), // 10 minutes (RFC 8628 recommendation)
148            min_interval: Duration::from_secs(5),         // 5 seconds minimum
149            verification_uri,
150        }
151    }
152
153    /// Create a new device authorization manager with custom settings
154    pub fn with_settings(
155        storage: Arc<dyn AuthStorage>,
156        verification_uri: String,
157        expiration: Duration,
158        min_interval: Duration,
159    ) -> Self {
160        Self {
161            storage,
162            authorizations: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
163            default_expiration: expiration,
164            min_interval,
165            verification_uri,
166        }
167    }
168
169    /// Set the default expiration time for device codes (chainable).
170    ///
171    /// Default: 600 seconds (10 minutes, per RFC 8628 recommendation).
172    pub fn expiration(mut self, expiration: Duration) -> Self {
173        self.default_expiration = expiration;
174        self
175    }
176
177    /// Set the minimum polling interval (chainable).
178    ///
179    /// Default: 5 seconds.
180    pub fn interval(mut self, interval: Duration) -> Self {
181        self.min_interval = interval;
182        self
183    }
184
185    /// Initiate device authorization flow
186    pub async fn create_authorization(
187        &self,
188        request: DeviceAuthorizationRequest,
189    ) -> Result<DeviceAuthorizationResponse> {
190        // Validate the request
191        self.validate_request(&request)?;
192
193        // Generate device code and user code
194        let device_code = format!("dc_{}", Uuid::new_v4().simple());
195        let user_code = self.generate_user_code();
196
197        // Calculate expiration
198        let now = SystemTime::now();
199        let expires_at = now + self.default_expiration;
200
201        // Create stored authorization
202        let stored = StoredDeviceAuthorization {
203            device_code: device_code.clone(),
204            user_code: user_code.clone(),
205            client_id: request.client_id.clone(),
206            scope: request.scope.clone(),
207            status: DeviceAuthorizationStatus::Pending,
208            user_id: None,
209            created_at: now,
210            expires_at,
211            last_poll: None,
212            slow_down_count: 0,
213        };
214
215        // Store in persistent backend with TTL
216        let device_key = format!("device_code:{}", device_code);
217        let user_key = format!("user_code:{}", user_code);
218
219        let serialized = serde_json::to_string(&stored)
220            .map_err(|e| AuthError::internal(format!("Failed to serialize device auth: {}", e)))?;
221
222        self.storage
223            .store_kv(
224                &device_key,
225                serialized.as_bytes(),
226                Some(self.default_expiration),
227            )
228            .await
229            .map_err(|e| {
230                AuthError::internal(format!("Failed to store device authorization: {}", e))
231            })?;
232
233        // Also store under user_code for verification page
234        self.storage
235            .store_kv(
236                &user_key,
237                serialized.as_bytes(),
238                Some(self.default_expiration),
239            )
240            .await
241            .map_err(|e| {
242                AuthError::internal(format!("Failed to store user code mapping: {}", e))
243            })?;
244
245        // Cache in memory
246        let mut authorizations = self.authorizations.write().await;
247        authorizations.insert(device_code.clone(), stored);
248
249        // Cleanup expired entries
250        self.cleanup_expired(&mut authorizations, now);
251
252        // Create response
253        let verification_uri_complete =
254            format!("{}?user_code={}", self.verification_uri, user_code);
255
256        Ok(DeviceAuthorizationResponse {
257            device_code,
258            user_code,
259            verification_uri: self.verification_uri.clone(),
260            verification_uri_complete: Some(verification_uri_complete),
261            interval: self.min_interval.as_secs(),
262            expires_in: self.default_expiration.as_secs(),
263        })
264    }
265
266    /// Poll for authorization status (used during token endpoint polling)
267    pub async fn poll_authorization(&self, device_code: &str) -> Result<StoredDeviceAuthorization> {
268        let device_key = format!("device_code:{}", device_code);
269
270        // Try to load from persistent storage first
271        let mut stored = if let Some(data) = self.storage.get_kv(&device_key).await? {
272            let serialized = String::from_utf8(data)
273                .map_err(|_| AuthError::internal("Invalid UTF-8 in stored device auth data"))?;
274
275            serde_json::from_str::<StoredDeviceAuthorization>(&serialized).map_err(|e| {
276                AuthError::internal(format!("Failed to deserialize device auth: {}", e))
277            })?
278        } else {
279            // Fallback to memory cache
280            let authorizations = self.authorizations.read().await;
281            authorizations
282                .get(device_code)
283                .cloned()
284                .ok_or_else(|| AuthError::auth_method("device_auth", "Invalid device_code"))?
285        };
286
287        // Check expiration
288        let now = SystemTime::now();
289        if now > stored.expires_at {
290            stored.status = DeviceAuthorizationStatus::Expired;
291            return Err(AuthError::auth_method("device_auth", "Device code expired"));
292        }
293
294        // Check for slow_down (polling too frequently) — RFC 8628 §3.5
295        // The required interval increases by 5 seconds each time the client
296        // polls too fast, implementing exponential backoff.
297        let effective_interval = self.min_interval
298            + Duration::from_secs(5 * u64::from(stored.slow_down_count));
299        if let Some(last_poll) = stored.last_poll {
300            let elapsed = now.duration_since(last_poll).unwrap_or(Duration::ZERO);
301            if elapsed < effective_interval {
302                stored.slow_down_count += 1;
303                // Persist the updated slow_down_count
304                stored.last_poll = Some(now);
305                let serialized = serde_json::to_string(&stored).map_err(|e| {
306                    AuthError::internal(format!("Failed to serialize device auth: {}", e))
307                })?;
308                self.storage
309                    .store_kv(
310                        &device_key,
311                        serialized.as_bytes(),
312                        Some(self.default_expiration),
313                    )
314                    .await
315                    .ok();
316                let mut authorizations = self.authorizations.write().await;
317                authorizations.insert(device_code.to_string(), stored);
318                return Err(AuthError::auth_method("device_auth", "slow_down"));
319            }
320        }
321
322        // Update last poll time
323        stored.last_poll = Some(now);
324
325        // Persist updated state
326        let serialized = serde_json::to_string(&stored)
327            .map_err(|e| AuthError::internal(format!("Failed to serialize device auth: {}", e)))?;
328
329        self.storage
330            .store_kv(
331                &device_key,
332                serialized.as_bytes(),
333                Some(self.default_expiration),
334            )
335            .await
336            .ok(); // Ignore errors for poll time update
337
338        // Update memory cache
339        let mut authorizations = self.authorizations.write().await;
340        authorizations.insert(device_code.to_string(), stored.clone());
341
342        // Return current status
343        match stored.status {
344            DeviceAuthorizationStatus::Pending => Err(AuthError::auth_method(
345                "device_auth",
346                "authorization_pending",
347            )),
348            DeviceAuthorizationStatus::Authorized => Ok(stored),
349            DeviceAuthorizationStatus::Denied => {
350                Err(AuthError::auth_method("device_auth", "access_denied"))
351            }
352            DeviceAuthorizationStatus::Expired => {
353                Err(AuthError::auth_method("device_auth", "expired_token"))
354            }
355        }
356    }
357
358    /// Authorize a device (called when user approves on verification page)
359    pub async fn authorize_device(&self, user_code: &str, user_id: &str) -> Result<()> {
360        let user_key = format!("user_code:{}", user_code);
361
362        // Load from storage
363        let mut stored = if let Some(data) = self.storage.get_kv(&user_key).await? {
364            let serialized = String::from_utf8(data)
365                .map_err(|_| AuthError::internal("Invalid UTF-8 in stored device auth data"))?;
366
367            serde_json::from_str::<StoredDeviceAuthorization>(&serialized).map_err(|e| {
368                AuthError::internal(format!("Failed to deserialize device auth: {}", e))
369            })?
370        } else {
371            return Err(AuthError::auth_method("device_auth", "Invalid user_code"));
372        };
373
374        // Check if expired
375        let now = SystemTime::now();
376        if now > stored.expires_at {
377            return Err(AuthError::auth_method("device_auth", "Device code expired"));
378        }
379
380        // Update status
381        stored.status = DeviceAuthorizationStatus::Authorized;
382        stored.user_id = Some(user_id.to_string());
383
384        // Persist updated state
385        let serialized = serde_json::to_string(&stored)
386            .map_err(|e| AuthError::internal(format!("Failed to serialize device auth: {}", e)))?;
387
388        let device_key = format!("device_code:{}", stored.device_code);
389
390        self.storage
391            .store_kv(
392                &device_key,
393                serialized.as_bytes(),
394                Some(self.default_expiration),
395            )
396            .await?;
397
398        self.storage
399            .store_kv(
400                &user_key,
401                serialized.as_bytes(),
402                Some(self.default_expiration),
403            )
404            .await?;
405
406        // Update memory cache
407        let mut authorizations = self.authorizations.write().await;
408        authorizations.insert(stored.device_code.clone(), stored);
409
410        Ok(())
411    }
412
413    /// Deny a device authorization
414    pub async fn deny_device(&self, user_code: &str) -> Result<()> {
415        let user_key = format!("user_code:{}", user_code);
416
417        // Load from storage
418        let mut stored = if let Some(data) = self.storage.get_kv(&user_key).await? {
419            let serialized = String::from_utf8(data)
420                .map_err(|_| AuthError::internal("Invalid UTF-8 in stored device auth data"))?;
421
422            serde_json::from_str::<StoredDeviceAuthorization>(&serialized).map_err(|e| {
423                AuthError::internal(format!("Failed to deserialize device auth: {}", e))
424            })?
425        } else {
426            return Err(AuthError::auth_method("device_auth", "Invalid user_code"));
427        };
428
429        // Update status
430        stored.status = DeviceAuthorizationStatus::Denied;
431
432        // Persist updated state
433        let serialized = serde_json::to_string(&stored)
434            .map_err(|e| AuthError::internal(format!("Failed to serialize device auth: {}", e)))?;
435
436        let device_key = format!("device_code:{}", stored.device_code);
437
438        self.storage
439            .store_kv(
440                &device_key,
441                serialized.as_bytes(),
442                Some(self.default_expiration),
443            )
444            .await?;
445
446        self.storage
447            .store_kv(
448                &user_key,
449                serialized.as_bytes(),
450                Some(self.default_expiration),
451            )
452            .await?;
453
454        // Update memory cache
455        let mut authorizations = self.authorizations.write().await;
456        authorizations.insert(stored.device_code.clone(), stored);
457
458        Ok(())
459    }
460
461    /// Get device authorization by user code (for verification page)
462    pub async fn get_by_user_code(&self, user_code: &str) -> Result<StoredDeviceAuthorization> {
463        let user_key = format!("user_code:{}", user_code);
464
465        if let Some(data) = self.storage.get_kv(&user_key).await? {
466            let serialized = String::from_utf8(data)
467                .map_err(|_| AuthError::internal("Invalid UTF-8 in stored device auth data"))?;
468
469            let stored: StoredDeviceAuthorization =
470                serde_json::from_str(&serialized).map_err(|e| {
471                    AuthError::internal(format!("Failed to deserialize device auth: {}", e))
472                })?;
473
474            // Check expiration
475            let now = SystemTime::now();
476            if now > stored.expires_at {
477                return Err(AuthError::auth_method("device_auth", "User code expired"));
478            }
479
480            Ok(stored)
481        } else {
482            Err(AuthError::auth_method("device_auth", "Invalid user_code"))
483        }
484    }
485
486    /// Validate device authorization request
487    fn validate_request(&self, request: &DeviceAuthorizationRequest) -> Result<()> {
488        if request.client_id.is_empty() {
489            return Err(AuthError::auth_method("device_auth", "Missing client_id"));
490        }
491
492        // In production, validate client_id against registered clients
493
494        Ok(())
495    }
496
497    /// Generate a user-friendly code (uppercase, no ambiguous characters)
498    fn generate_user_code(&self) -> String {
499        use rand::RngExt;
500        const CHARS: &[u8] = b"ABCDEFGHJKLMNPQRSTUVWXYZ23456789"; // No ambiguous: 0,O,I,1
501        let mut rng = rand::rng();
502
503        // Generate 9-character code with dash for readability: XXXX-XXXX
504        let code: String = (0..9)
505            .map(|i| {
506                if i == 4 {
507                    '-'
508                } else {
509                    let idx = rng.random_range(0..CHARS.len());
510                    CHARS[idx] as char
511                }
512            })
513            .collect();
514
515        code
516    }
517
518    /// Clean up expired entries from memory cache
519    fn cleanup_expired(
520        &self,
521        authorizations: &mut HashMap<String, StoredDeviceAuthorization>,
522        now: SystemTime,
523    ) {
524        authorizations.retain(|_, auth| now <= auth.expires_at);
525    }
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryStorage;
    use tokio::time::sleep;

    /// Manager backed by in-memory storage with the default settings.
    fn create_test_manager() -> DeviceAuthManager {
        DeviceAuthManager::new(
            Arc::new(MemoryStorage::new()),
            "https://example.com/device".to_string(),
        )
    }

    /// Request for the shared test client, optionally carrying a scope string.
    fn test_request(scope: Option<&str>) -> DeviceAuthorizationRequest {
        DeviceAuthorizationRequest {
            client_id: "test_client".to_string(),
            scope: scope.map(str::to_string),
        }
    }

    /// A new authorization exposes well-formed codes and the default timings.
    #[tokio::test]
    async fn test_create_authorization() {
        let manager = create_test_manager();

        let response = manager
            .create_authorization(test_request(Some("openid profile")))
            .await
            .unwrap();

        assert!(response.device_code.starts_with("dc_"));
        assert_eq!(response.user_code.len(), 9); // XXXX-XXXX
        assert!(response.user_code.contains('-'));
        assert_eq!(response.verification_uri, "https://example.com/device");
        assert!(response.verification_uri_complete.is_some());
        assert_eq!(response.interval, 5);
        assert_eq!(response.expires_in, 600);
    }

    /// Polling before the user decides reports `authorization_pending`.
    #[tokio::test]
    async fn test_poll_pending() {
        let manager = create_test_manager();
        let response = manager
            .create_authorization(test_request(None))
            .await
            .unwrap();

        let result = manager.poll_authorization(&response.device_code).await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("authorization_pending"));
    }

    /// Once the user approves, polling yields the authorized record.
    #[tokio::test]
    async fn test_authorize_and_poll() {
        let manager = create_test_manager();
        let response = manager
            .create_authorization(test_request(Some("openid")))
            .await
            .unwrap();

        manager
            .authorize_device(&response.user_code, "user_123")
            .await
            .unwrap();

        let stored = manager
            .poll_authorization(&response.device_code)
            .await
            .unwrap();
        assert_eq!(stored.status, DeviceAuthorizationStatus::Authorized);
        assert_eq!(stored.user_id, Some("user_123".to_string()));
    }

    /// Once the user rejects, polling reports `access_denied`.
    #[tokio::test]
    async fn test_deny_device() {
        let manager = create_test_manager();
        let response = manager
            .create_authorization(test_request(None))
            .await
            .unwrap();

        manager.deny_device(&response.user_code).await.unwrap();

        let result = manager.poll_authorization(&response.device_code).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("access_denied"));
    }

    /// Two polls in immediate succession trigger `slow_down`.
    #[tokio::test]
    async fn test_slow_down() {
        let manager = create_test_manager();
        let response = manager
            .create_authorization(test_request(None))
            .await
            .unwrap();

        // The outcome of the first poll is irrelevant; it only records a poll time.
        let _ = manager.poll_authorization(&response.device_code).await;

        let result = manager.poll_authorization(&response.device_code).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("slow_down"));
    }

    /// A device code polled after its lifetime is reported as expired.
    #[tokio::test]
    async fn test_expiration() {
        // Very short expiration so the test only has to wait briefly.
        let manager = DeviceAuthManager::with_settings(
            Arc::new(MemoryStorage::new()),
            "https://example.com/device".to_string(),
            Duration::from_millis(100),
            Duration::from_secs(1),
        );
        let response = manager
            .create_authorization(test_request(None))
            .await
            .unwrap();

        sleep(Duration::from_millis(150)).await;

        let result = manager.poll_authorization(&response.device_code).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("expired"));
    }

    /// The chainable setters are reflected in the authorization response.
    #[tokio::test]
    async fn test_chainable_expiration_and_interval() {
        let manager = DeviceAuthManager::new(
            Arc::new(MemoryStorage::new()),
            "https://example.com/device".to_string(),
        )
        .expiration(Duration::from_secs(300))
        .interval(Duration::from_secs(10));

        let response = manager
            .create_authorization(test_request(None))
            .await
            .unwrap();

        assert_eq!(response.expires_in, 300);
        assert_eq!(response.interval, 10);
    }
}