auth_framework/server/oauth/
device.rs

1//! Device Authorization Grant Implementation - RFC 8628
2//!
3//! This module implements RFC 8628 - OAuth 2.0 Device Authorization Grant
4//! which allows devices with limited input capability (smart TVs, printers, etc.)
5//! to obtain user authorization.
6
7use crate::errors::{AuthError, Result};
8use crate::storage::AuthStorage;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::{Duration, SystemTime};
13use uuid::Uuid;
14
15/// Device authorization request
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct DeviceAuthorizationRequest {
18    /// Client identifier
19    pub client_id: String,
20
21    /// Requested scopes
22    pub scope: Option<String>,
23}
24
25/// Device authorization response
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct DeviceAuthorizationResponse {
28    /// Device verification code (for storage)
29    pub device_code: String,
30
31    /// User-friendly code to enter
32    pub user_code: String,
33
34    /// URL where user should authorize
35    pub verification_uri: String,
36
37    /// Complete verification URL with user_code (optional)
38    pub verification_uri_complete: Option<String>,
39
40    /// Polling interval in seconds
41    pub interval: u64,
42
43    /// Device code expires in seconds
44    pub expires_in: u64,
45}
46
47/// Token request for device code grant
48#[derive(Debug, Clone, Deserialize)]
49pub struct DeviceTokenRequest {
50    /// Grant type (must be "urn:ietf:params:oauth:grant-type:device_code")
51    pub grant_type: String,
52
53    /// Device code received from device authorization
54    pub device_code: String,
55
56    /// Client identifier
57    pub client_id: String,
58}
59
60/// Stored device authorization data
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct StoredDeviceAuthorization {
63    /// Device code
64    pub device_code: String,
65
66    /// User code
67    pub user_code: String,
68
69    /// Client ID
70    pub client_id: String,
71
72    /// Requested scopes
73    pub scope: Option<String>,
74
75    /// Authorization status
76    pub status: DeviceAuthorizationStatus,
77
78    /// User ID (once authorized)
79    pub user_id: Option<String>,
80
81    /// When the request was created
82    pub created_at: SystemTime,
83
84    /// When the request expires
85    pub expires_at: SystemTime,
86
87    /// Last poll time (for slow_down error)
88    pub last_poll: Option<SystemTime>,
89}
90
91/// Device authorization status
92#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
93pub enum DeviceAuthorizationStatus {
94    /// Awaiting user authorization
95    Pending,
96    /// User has authorized
97    Authorized,
98    /// User has denied
99    Denied,
100    /// Authorization expired
101    Expired,
102}
103
104/// Device authorization manager with persistent storage
105use std::fmt;
106
107#[derive(Clone)]
108pub struct DeviceAuthManager {
109    /// Persistent storage backend
110    storage: Arc<dyn AuthStorage>,
111
112    /// Memory cache for fast access
113    authorizations: Arc<tokio::sync::RwLock<HashMap<String, StoredDeviceAuthorization>>>,
114
115    /// Default expiration time for device codes
116    default_expiration: Duration,
117
118    /// Minimum polling interval
119    min_interval: Duration,
120
121    /// Base verification URI
122    verification_uri: String,
123}
124
125impl fmt::Debug for DeviceAuthManager {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        f.debug_struct("DeviceAuthManager")
128            .field("storage", &"<dyn AuthStorage>")
129            .field("default_expiration", &self.default_expiration)
130            .field("min_interval", &self.min_interval)
131            .field("verification_uri", &self.verification_uri)
132            .finish()
133    }
134}
135
136impl DeviceAuthManager {
137    /// Create a new device authorization manager
138    pub fn new(storage: Arc<dyn AuthStorage>, verification_uri: String) -> Self {
139        Self {
140            storage,
141            authorizations: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
142            default_expiration: Duration::from_secs(600), // 10 minutes (RFC 8628 recommendation)
143            min_interval: Duration::from_secs(5),         // 5 seconds minimum
144            verification_uri,
145        }
146    }
147
148    /// Create a new device authorization manager with custom settings
149    pub fn with_settings(
150        storage: Arc<dyn AuthStorage>,
151        verification_uri: String,
152        expiration: Duration,
153        min_interval: Duration,
154    ) -> Self {
155        Self {
156            storage,
157            authorizations: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
158            default_expiration: expiration,
159            min_interval,
160            verification_uri,
161        }
162    }
163
164    /// Initiate device authorization flow
165    pub async fn create_authorization(
166        &self,
167        request: DeviceAuthorizationRequest,
168    ) -> Result<DeviceAuthorizationResponse> {
169        // Validate the request
170        self.validate_request(&request)?;
171
172        // Generate device code and user code
173        let device_code = format!("dc_{}", Uuid::new_v4().simple());
174        let user_code = self.generate_user_code();
175
176        // Calculate expiration
177        let now = SystemTime::now();
178        let expires_at = now + self.default_expiration;
179
180        // Create stored authorization
181        let stored = StoredDeviceAuthorization {
182            device_code: device_code.clone(),
183            user_code: user_code.clone(),
184            client_id: request.client_id.clone(),
185            scope: request.scope.clone(),
186            status: DeviceAuthorizationStatus::Pending,
187            user_id: None,
188            created_at: now,
189            expires_at,
190            last_poll: None,
191        };
192
193        // Store in persistent backend with TTL
194        let device_key = format!("device_code:{}", device_code);
195        let user_key = format!("user_code:{}", user_code);
196
197        let serialized = serde_json::to_string(&stored)
198            .map_err(|e| AuthError::internal(format!("Failed to serialize device auth: {}", e)))?;
199
200        self.storage
201            .store_kv(
202                &device_key,
203                serialized.as_bytes(),
204                Some(self.default_expiration),
205            )
206            .await
207            .map_err(|e| {
208                AuthError::internal(format!("Failed to store device authorization: {}", e))
209            })?;
210
211        // Also store under user_code for verification page
212        self.storage
213            .store_kv(
214                &user_key,
215                serialized.as_bytes(),
216                Some(self.default_expiration),
217            )
218            .await
219            .map_err(|e| {
220                AuthError::internal(format!("Failed to store user code mapping: {}", e))
221            })?;
222
223        // Cache in memory
224        let mut authorizations = self.authorizations.write().await;
225        authorizations.insert(device_code.clone(), stored);
226
227        // Cleanup expired entries
228        self.cleanup_expired(&mut authorizations, now);
229
230        // Create response
231        let verification_uri_complete =
232            format!("{}?user_code={}", self.verification_uri, user_code);
233
234        Ok(DeviceAuthorizationResponse {
235            device_code,
236            user_code,
237            verification_uri: self.verification_uri.clone(),
238            verification_uri_complete: Some(verification_uri_complete),
239            interval: self.min_interval.as_secs(),
240            expires_in: self.default_expiration.as_secs(),
241        })
242    }
243
244    /// Poll for authorization status (used during token endpoint polling)
245    pub async fn poll_authorization(&self, device_code: &str) -> Result<StoredDeviceAuthorization> {
246        let device_key = format!("device_code:{}", device_code);
247
248        // Try to load from persistent storage first
249        let mut stored = if let Some(data) = self.storage.get_kv(&device_key).await? {
250            let serialized = String::from_utf8(data)
251                .map_err(|_| AuthError::internal("Invalid UTF-8 in stored device auth data"))?;
252
253            serde_json::from_str::<StoredDeviceAuthorization>(&serialized).map_err(|e| {
254                AuthError::internal(format!("Failed to deserialize device auth: {}", e))
255            })?
256        } else {
257            // Fallback to memory cache
258            let authorizations = self.authorizations.read().await;
259            authorizations
260                .get(device_code)
261                .cloned()
262                .ok_or_else(|| AuthError::auth_method("device_auth", "Invalid device_code"))?
263        };
264
265        // Check expiration
266        let now = SystemTime::now();
267        if now > stored.expires_at {
268            stored.status = DeviceAuthorizationStatus::Expired;
269            return Err(AuthError::auth_method("device_auth", "Device code expired"));
270        }
271
272        // Check for slow_down (polling too frequently)
273        if let Some(last_poll) = stored.last_poll {
274            let elapsed = now.duration_since(last_poll).unwrap_or(Duration::ZERO);
275            if elapsed < self.min_interval {
276                return Err(AuthError::auth_method("device_auth", "slow_down"));
277            }
278        }
279
280        // Update last poll time
281        stored.last_poll = Some(now);
282
283        // Persist updated state
284        let serialized = serde_json::to_string(&stored)
285            .map_err(|e| AuthError::internal(format!("Failed to serialize device auth: {}", e)))?;
286
287        self.storage
288            .store_kv(
289                &device_key,
290                serialized.as_bytes(),
291                Some(self.default_expiration),
292            )
293            .await
294            .ok(); // Ignore errors for poll time update
295
296        // Update memory cache
297        let mut authorizations = self.authorizations.write().await;
298        authorizations.insert(device_code.to_string(), stored.clone());
299
300        // Return current status
301        match stored.status {
302            DeviceAuthorizationStatus::Pending => Err(AuthError::auth_method(
303                "device_auth",
304                "authorization_pending",
305            )),
306            DeviceAuthorizationStatus::Authorized => Ok(stored),
307            DeviceAuthorizationStatus::Denied => {
308                Err(AuthError::auth_method("device_auth", "access_denied"))
309            }
310            DeviceAuthorizationStatus::Expired => {
311                Err(AuthError::auth_method("device_auth", "expired_token"))
312            }
313        }
314    }
315
316    /// Authorize a device (called when user approves on verification page)
317    pub async fn authorize_device(&self, user_code: &str, user_id: &str) -> Result<()> {
318        let user_key = format!("user_code:{}", user_code);
319
320        // Load from storage
321        let mut stored = if let Some(data) = self.storage.get_kv(&user_key).await? {
322            let serialized = String::from_utf8(data)
323                .map_err(|_| AuthError::internal("Invalid UTF-8 in stored device auth data"))?;
324
325            serde_json::from_str::<StoredDeviceAuthorization>(&serialized).map_err(|e| {
326                AuthError::internal(format!("Failed to deserialize device auth: {}", e))
327            })?
328        } else {
329            return Err(AuthError::auth_method("device_auth", "Invalid user_code"));
330        };
331
332        // Check if expired
333        let now = SystemTime::now();
334        if now > stored.expires_at {
335            return Err(AuthError::auth_method("device_auth", "Device code expired"));
336        }
337
338        // Update status
339        stored.status = DeviceAuthorizationStatus::Authorized;
340        stored.user_id = Some(user_id.to_string());
341
342        // Persist updated state
343        let serialized = serde_json::to_string(&stored)
344            .map_err(|e| AuthError::internal(format!("Failed to serialize device auth: {}", e)))?;
345
346        let device_key = format!("device_code:{}", stored.device_code);
347
348        self.storage
349            .store_kv(
350                &device_key,
351                serialized.as_bytes(),
352                Some(self.default_expiration),
353            )
354            .await?;
355
356        self.storage
357            .store_kv(
358                &user_key,
359                serialized.as_bytes(),
360                Some(self.default_expiration),
361            )
362            .await?;
363
364        // Update memory cache
365        let mut authorizations = self.authorizations.write().await;
366        authorizations.insert(stored.device_code.clone(), stored);
367
368        Ok(())
369    }
370
371    /// Deny a device authorization
372    pub async fn deny_device(&self, user_code: &str) -> Result<()> {
373        let user_key = format!("user_code:{}", user_code);
374
375        // Load from storage
376        let mut stored = if let Some(data) = self.storage.get_kv(&user_key).await? {
377            let serialized = String::from_utf8(data)
378                .map_err(|_| AuthError::internal("Invalid UTF-8 in stored device auth data"))?;
379
380            serde_json::from_str::<StoredDeviceAuthorization>(&serialized).map_err(|e| {
381                AuthError::internal(format!("Failed to deserialize device auth: {}", e))
382            })?
383        } else {
384            return Err(AuthError::auth_method("device_auth", "Invalid user_code"));
385        };
386
387        // Update status
388        stored.status = DeviceAuthorizationStatus::Denied;
389
390        // Persist updated state
391        let serialized = serde_json::to_string(&stored)
392            .map_err(|e| AuthError::internal(format!("Failed to serialize device auth: {}", e)))?;
393
394        let device_key = format!("device_code:{}", stored.device_code);
395
396        self.storage
397            .store_kv(
398                &device_key,
399                serialized.as_bytes(),
400                Some(self.default_expiration),
401            )
402            .await?;
403
404        self.storage
405            .store_kv(
406                &user_key,
407                serialized.as_bytes(),
408                Some(self.default_expiration),
409            )
410            .await?;
411
412        // Update memory cache
413        let mut authorizations = self.authorizations.write().await;
414        authorizations.insert(stored.device_code.clone(), stored);
415
416        Ok(())
417    }
418
419    /// Get device authorization by user code (for verification page)
420    pub async fn get_by_user_code(&self, user_code: &str) -> Result<StoredDeviceAuthorization> {
421        let user_key = format!("user_code:{}", user_code);
422
423        if let Some(data) = self.storage.get_kv(&user_key).await? {
424            let serialized = String::from_utf8(data)
425                .map_err(|_| AuthError::internal("Invalid UTF-8 in stored device auth data"))?;
426
427            let stored: StoredDeviceAuthorization =
428                serde_json::from_str(&serialized).map_err(|e| {
429                    AuthError::internal(format!("Failed to deserialize device auth: {}", e))
430                })?;
431
432            // Check expiration
433            let now = SystemTime::now();
434            if now > stored.expires_at {
435                return Err(AuthError::auth_method("device_auth", "User code expired"));
436            }
437
438            Ok(stored)
439        } else {
440            Err(AuthError::auth_method("device_auth", "Invalid user_code"))
441        }
442    }
443
444    /// Validate device authorization request
445    fn validate_request(&self, request: &DeviceAuthorizationRequest) -> Result<()> {
446        if request.client_id.is_empty() {
447            return Err(AuthError::auth_method("device_auth", "Missing client_id"));
448        }
449
450        // In production, validate client_id against registered clients
451
452        Ok(())
453    }
454
455    /// Generate a user-friendly code (uppercase, no ambiguous characters)
456    fn generate_user_code(&self) -> String {
457        use rand::Rng;
458        const CHARS: &[u8] = b"ABCDEFGHJKLMNPQRSTUVWXYZ23456789"; // No ambiguous: 0,O,I,1
459        let mut rng = rand::rng();
460
461        // Generate 9-character code with dash for readability: XXXX-XXXX
462        let code: String = (0..9)
463            .map(|i| {
464                if i == 4 {
465                    '-'
466                } else {
467                    let idx = rng.random_range(0..CHARS.len());
468                    CHARS[idx] as char
469                }
470            })
471            .collect();
472
473        code
474    }
475
476    /// Clean up expired entries from memory cache
477    fn cleanup_expired(
478        &self,
479        authorizations: &mut HashMap<String, StoredDeviceAuthorization>,
480        now: SystemTime,
481    ) {
482        authorizations.retain(|_, auth| now <= auth.expires_at);
483    }
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryStorage;
    use tokio::time::sleep;

    const VERIFICATION_URI: &str = "https://example.com/device";

    fn create_test_manager() -> DeviceAuthManager {
        let storage = Arc::new(MemoryStorage::new());
        DeviceAuthManager::new(storage, VERIFICATION_URI.to_string())
    }

    /// Build a request for the fixed test client with an optional scope.
    fn request(scope: Option<&str>) -> DeviceAuthorizationRequest {
        DeviceAuthorizationRequest {
            client_id: "test_client".to_string(),
            scope: scope.map(str::to_string),
        }
    }

    /// Assert that `result` failed with an error whose message contains `needle`.
    fn assert_err_contains<T: std::fmt::Debug>(result: Result<T>, needle: &str) {
        match result {
            Ok(value) => panic!("expected error containing {:?}, got Ok({:?})", needle, value),
            Err(err) => {
                let message = err.to_string();
                assert!(
                    message.contains(needle),
                    "expected {:?} in error message {:?}",
                    needle,
                    message
                );
            }
        }
    }

    #[tokio::test]
    async fn test_create_authorization() {
        let manager = create_test_manager();

        let response = manager
            .create_authorization(request(Some("openid profile")))
            .await
            .unwrap();

        assert!(response.device_code.starts_with("dc_"));
        assert_eq!(response.user_code.len(), 9); // XXXX-XXXX
        assert!(response.user_code.contains('-'));
        assert_eq!(response.verification_uri, VERIFICATION_URI);
        assert_eq!(
            response.verification_uri_complete,
            Some(format!(
                "{}?user_code={}",
                VERIFICATION_URI, response.user_code
            ))
        );
        assert_eq!(response.interval, 5);
        assert_eq!(response.expires_in, 600);
    }

    #[tokio::test]
    async fn test_get_by_user_code() {
        let manager = create_test_manager();

        let response = manager
            .create_authorization(request(Some("openid")))
            .await
            .unwrap();

        let stored = manager.get_by_user_code(&response.user_code).await.unwrap();
        assert_eq!(stored.device_code, response.device_code);
        assert_eq!(stored.client_id, "test_client");
        assert_eq!(stored.scope.as_deref(), Some("openid"));
        assert_eq!(stored.status, DeviceAuthorizationStatus::Pending);
        assert_eq!(stored.user_id, None);
    }

    #[tokio::test]
    async fn test_poll_pending() {
        let manager = create_test_manager();
        let response = manager.create_authorization(request(None)).await.unwrap();

        // Poll should return authorization_pending
        let result = manager.poll_authorization(&response.device_code).await;
        assert_err_contains(result, "authorization_pending");
    }

    #[tokio::test]
    async fn test_poll_unknown_device_code() {
        let manager = create_test_manager();

        let result = manager.poll_authorization("dc_does_not_exist").await;
        assert_err_contains(result, "Invalid device_code");
    }

    #[tokio::test]
    async fn test_authorize_unknown_user_code() {
        let manager = create_test_manager();

        let result = manager.authorize_device("ZZZZ-ZZZZ", "user_123").await;
        assert_err_contains(result, "Invalid user_code");
    }

    #[tokio::test]
    async fn test_authorize_and_poll() {
        let manager = create_test_manager();
        let response = manager
            .create_authorization(request(Some("openid")))
            .await
            .unwrap();

        manager
            .authorize_device(&response.user_code, "user_123")
            .await
            .unwrap();

        // Poll should now succeed
        let stored = manager
            .poll_authorization(&response.device_code)
            .await
            .unwrap();
        assert_eq!(stored.status, DeviceAuthorizationStatus::Authorized);
        assert_eq!(stored.user_id, Some("user_123".to_string()));
    }

    #[tokio::test]
    async fn test_deny_device() {
        let manager = create_test_manager();
        let response = manager.create_authorization(request(None)).await.unwrap();

        manager.deny_device(&response.user_code).await.unwrap();

        // Poll should return access_denied
        let result = manager.poll_authorization(&response.device_code).await;
        assert_err_contains(result, "access_denied");
    }

    #[tokio::test]
    async fn test_slow_down() {
        let manager = create_test_manager();
        let response = manager.create_authorization(request(None)).await.unwrap();

        // First poll records the poll time; its pending error is irrelevant here.
        let _ = manager.poll_authorization(&response.device_code).await;

        // Immediate second poll should return slow_down
        let result = manager.poll_authorization(&response.device_code).await;
        assert_err_contains(result, "slow_down");
    }

    #[tokio::test]
    async fn test_expiration() {
        let storage = Arc::new(MemoryStorage::new());
        // Very short expiration so the test only needs a brief sleep.
        let manager = DeviceAuthManager::with_settings(
            storage,
            VERIFICATION_URI.to_string(),
            Duration::from_millis(100),
            Duration::from_secs(1),
        );

        let response = manager.create_authorization(request(None)).await.unwrap();

        sleep(Duration::from_millis(150)).await;

        let result = manager.poll_authorization(&response.device_code).await;
        assert_err_contains(result, "expired");
    }
}