Skip to main content

auth_framework/methods/enhanced_device/
mod.rs

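//! Enhanced Device Flow Implementation
//!
//! This module implements the OAuth 2.0 Device Authorization Grant (RFC 8628)
//! by calling the authorization server's endpoints directly with `reqwest`.
//! QR-code generation is not implemented yet: the `qr_code` field of
//! `DeviceFlowInstructions` is currently left as `None`. Without the
//! `enhanced-device-flow` feature, a placeholder `EnhancedDeviceFlowMethod` is
//! compiled whose authentication always fails.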
5
6use crate::authentication::credentials::{Credential, CredentialMetadata};
7use crate::errors::{AuthError, Result};
8use crate::methods::{AuthMethod, MethodResult};
9use crate::tokens::AuthToken;
10#[cfg(feature = "enhanced-device-flow")]
11use base64::Engine as _;
12#[cfg(feature = "enhanced-device-flow")]
13use base64::engine::general_purpose::URL_SAFE_NO_PAD;
14use serde::{Deserialize, Serialize};
15
/// Extract the `sub` (subject) claim from a JWT without verifying its signature.
///
/// The token must be in JWS compact serialization (`header.payload.signature`,
/// exactly three dot-separated segments). Anything else — an opaque access
/// token, a five-segment JWE, a two-segment string — yields `None`. The payload
/// segment is decoded as unpadded base64url; trailing `=` padding emitted by
/// lenient issuers is tolerated.
///
/// Returns `None` when:
/// - the token is malformed;
/// - the payload is not JSON with a string `sub` claim;
/// - the claim is empty.
///
/// An empty subject is rejected so it can never become a user identifier.
///
/// # Security
///
/// The signature is NOT checked. Only call this on a token received directly
/// from the configured token endpoint; never on a token supplied by an end
/// user.
#[cfg(feature = "enhanced-device-flow")]
fn extract_sub_from_jwt(jwt: &str) -> Option<String> {
    let mut segments = jwt.split('.');
    let (Some(_header), Some(payload), Some(_signature), None) = (
        segments.next(),
        segments.next(),
        segments.next(),
        segments.next(),
    ) else {
        return None;
    };

    let payload = URL_SAFE_NO_PAD
        .decode(payload.trim_end_matches('='))
        .ok()?;
    let claims: serde_json::Value = serde_json::from_slice(&payload).ok()?;
    claims
        .get("sub")
        .and_then(serde_json::Value::as_str)
        .filter(|sub| !sub.is_empty())
        .map(str::to_owned)
}
30
31/// Instructions for device flow authentication
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct DeviceFlowInstructions {
34    /// URL the user should visit
35    pub verification_uri: String,
36    /// Complete URL with embedded code for faster authentication
37    pub verification_uri_complete: Option<String>,
38    /// Device code to display to the user
39    pub user_code: String,
40    /// QR code as base64 encoded PNG (if feature enabled)
41    pub qr_code: Option<String>,
42    /// How long the user has to complete authentication
43    pub expires_in: u64,
44    /// How often to poll for completion
45    pub interval: u64,
46}
47
/// OAuth 2.0 device authorization grant (RFC 8628) authentication method.
///
/// `Debug` is implemented by hand so that `client_secret` never ends up in
/// logs, panic messages or `{:?}` output.
#[cfg(feature = "enhanced-device-flow")]
pub struct EnhancedDeviceFlowMethod {
    /// OAuth client ID
    pub client_id: String,
    /// OAuth client secret (`None` for public clients)
    pub client_secret: Option<String>,
    /// Authorization URL
    pub auth_url: String,
    /// Token URL (polled with the device code)
    pub token_url: String,
    /// Device authorization URL (where the device/user codes are requested)
    pub device_auth_url: String,
    /// OAuth scopes to request
    pub scopes: Vec<String>,
    /// Custom polling interval set via `with_polling_interval` (optional)
    pub _polling_interval: Option<std::time::Duration>,
    /// Whether QR code generation is requested
    pub enable_qr_code: bool,
}

#[cfg(feature = "enhanced-device-flow")]
impl std::fmt::Debug for EnhancedDeviceFlowMethod {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EnhancedDeviceFlowMethod")
            .field("client_id", &self.client_id)
            // Only reveal whether a secret is configured, never its value.
            .field(
                "client_secret",
                &self.client_secret.as_ref().map(|_| "<redacted>"),
            )
            .field("auth_url", &self.auth_url)
            .field("token_url", &self.token_url)
            .field("device_auth_url", &self.device_auth_url)
            .field("scopes", &self.scopes)
            .field("_polling_interval", &self._polling_interval)
            .field("enable_qr_code", &self.enable_qr_code)
            .finish()
    }
}
69
#[cfg(feature = "enhanced-device-flow")]
impl EnhancedDeviceFlowMethod {
    /// Polling interval (seconds) used when the server omits `interval`
    /// (RFC 8628 §3.2 default).
    const DEFAULT_POLL_INTERVAL_SECS: u64 = 5;

    /// Create a new enhanced device flow method.
    ///
    /// Pass `client_secret: None` for public clients. Scopes start empty, no
    /// custom polling interval is set, and QR-code generation is enabled.
    pub fn new(
        client_id: String,
        client_secret: Option<String>,
        auth_url: String,
        token_url: String,
        device_auth_url: String,
    ) -> Self {
        Self {
            client_id,
            client_secret,
            auth_url,
            token_url,
            device_auth_url,
            scopes: Vec::new(),
            _polling_interval: None,
            enable_qr_code: true,
        }
    }

    /// Set the OAuth scopes (sent space-separated as the `scope` parameter).
    #[must_use]
    pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
        self.scopes = scopes;
        self
    }

    /// Set a custom polling interval.
    ///
    /// The interval reported in [`DeviceFlowInstructions::interval`] is the
    /// larger of this value (rounded up to whole seconds) and the interval
    /// the server asks for. Per RFC 8628 §3.5, a client must never poll faster
    /// than the server allows.
    #[must_use]
    pub fn with_polling_interval(mut self, interval: std::time::Duration) -> Self {
        self._polling_interval = Some(interval);
        self
    }

    /// Enable or disable QR code generation.
    ///
    /// NOTE: QR-code rendering is not implemented yet, so this flag currently
    /// has no effect on the returned instructions.
    #[must_use]
    pub fn with_qr_code(mut self, enable: bool) -> Self {
        self.enable_qr_code = enable;
        self
    }

    /// Initiate the device flow and return the instructions to show the user.
    ///
    /// The secret device code needed to poll the token endpoint is not part of
    /// the instructions. Use [`Self::initiate_device_flow_with_device_code`]
    /// to actually complete the flow.
    ///
    /// # Errors
    ///
    /// Same as [`Self::initiate_device_flow_with_device_code`].
    pub async fn initiate_device_flow(&self) -> Result<DeviceFlowInstructions> {
        let (instructions, _device_code) = self.initiate_device_flow_with_device_code().await?;
        Ok(instructions)
    }

    /// Initiate the device flow and return the user instructions together with
    /// the device code.
    ///
    /// Keep the device code private. It is what `authenticate` expects in
    /// `Credential::EnhancedDeviceFlow { device_code, .. }` when polling the
    /// token endpoint.
    ///
    /// # Errors
    ///
    /// - `AuthError::Network` if the request fails or the response body cannot
    ///   be decoded.
    /// - A configuration error if the server answers with a non-2xx status. The
    ///   message includes the status and, when present, the response body.
    pub async fn initiate_device_flow_with_device_code(
        &self,
    ) -> Result<(DeviceFlowInstructions, String)> {
        let mut params = std::collections::HashMap::new();
        params.insert("client_id", self.client_id.clone());
        // RFC 8628 §3.1: confidential clients authenticate on this endpoint
        // exactly as they do on the token endpoint.
        if let Some(secret) = &self.client_secret {
            params.insert("client_secret", secret.clone());
        }
        if !self.scopes.is_empty() {
            params.insert("scope", self.scopes.join(" "));
        }

        let res = reqwest::Client::new()
            .post(&self.device_auth_url)
            .form(&params)
            .send()
            .await
            .map_err(AuthError::Network)?;

        let status = res.status();
        if !status.is_success() {
            // The body is usually an RFC 6749 error object; keep it for diagnostics.
            let body = res.text().await.unwrap_or_default();
            let message = if body.is_empty() {
                format!("Device auth request failed: {status}")
            } else {
                format!("Device auth request failed: {status}: {body}")
            };
            return Err(AuthError::config(message));
        }

        #[derive(Deserialize)]
        struct DeviceAuthResponse {
            device_code: String,
            user_code: String,
            verification_uri: String,
            verification_uri_complete: Option<String>,
            expires_in: u64,
            interval: Option<u64>,
        }

        let data: DeviceAuthResponse = res.json().await.map_err(AuthError::Network)?;
        let interval = self.effective_poll_interval(
            data.interval.unwrap_or(Self::DEFAULT_POLL_INTERVAL_SECS),
        );

        let instructions = DeviceFlowInstructions {
            verification_uri: data.verification_uri,
            verification_uri_complete: data.verification_uri_complete,
            user_code: data.user_code,
            // QR-code rendering is not implemented yet (see `with_qr_code`).
            qr_code: None,
            expires_in: data.expires_in,
            interval,
        };
        Ok((instructions, data.device_code))
    }

    /// Combine the server's polling interval with the optional custom one.
    ///
    /// Returns the larger of the two, in whole seconds (custom values are
    /// rounded up).
    fn effective_poll_interval(&self, server_secs: u64) -> u64 {
        match self._polling_interval {
            Some(custom) => {
                let custom_secs = custom
                    .as_secs()
                    .saturating_add(u64::from(custom.subsec_nanos() > 0));
                custom_secs.max(server_secs)
            }
            None => server_secs,
        }
    }
}
156
#[cfg(feature = "enhanced-device-flow")]
impl AuthMethod for EnhancedDeviceFlowMethod {
    type MethodResult = MethodResult;
    type AuthToken = AuthToken;

    /// Poll the token endpoint once, exchanging the device code carried by
    /// `credential` for tokens (RFC 8628 device-code grant).
    ///
    /// Returns `Ok(MethodResult::Failure { .. })` when:
    /// - `credential` is not `Credential::EnhancedDeviceFlow`;
    /// - the token endpoint answers with a non-2xx status. This includes the
    ///   non-terminal `authorization_pending` / `slow_down` responses, so a
    ///   polling caller should inspect `reason`. It carries the server's
    ///   `error` code plus `error_description` when present, or the raw body
    ///   if it is not a standard error object;
    /// - no user identity (`sub` claim) can be read from the access token or
    ///   the ID token. Falling back to a shared placeholder identity would merge
    ///   every such user into one account, so this fails closed instead.
    ///
    /// # Errors
    ///
    /// `AuthError::Network` if the HTTP request fails or a successful response
    /// body cannot be decoded.
    async fn authenticate(
        &self,
        credential: Credential,
        _metadata: CredentialMetadata,
    ) -> Result<Self::MethodResult> {
        let Credential::EnhancedDeviceFlow { device_code, .. } = credential else {
            return Ok(MethodResult::Failure {
                reason: "Invalid credential type for enhanced device flow".to_string(),
            });
        };

        let mut params = std::collections::HashMap::new();
        params.insert("client_id", self.client_id.clone());
        params.insert(
            "grant_type",
            "urn:ietf:params:oauth:grant-type:device_code".to_string(),
        );
        params.insert("device_code", device_code);
        if let Some(secret) = &self.client_secret {
            params.insert("client_secret", secret.clone());
        }

        let res = reqwest::Client::new()
            .post(&self.token_url)
            .form(&params)
            .send()
            .await
            .map_err(AuthError::Network)?;

        if !res.status().is_success() {
            /// Standard token-endpoint error body (RFC 6749 §5.2).
            #[derive(Deserialize)]
            struct TokenErrorResponse {
                error: String,
                error_description: Option<String>,
            }

            let body = res.text().await.unwrap_or_default();
            let parsed = serde_json::from_str::<TokenErrorResponse>(&body);
            let detail = match parsed {
                Ok(TokenErrorResponse {
                    error,
                    error_description: Some(description),
                }) => format!("{error}: {description}"),
                Ok(TokenErrorResponse { error, .. }) => error,
                // Not a standard error object: report the raw body.
                Err(_) => body,
            };
            return Ok(MethodResult::Failure {
                reason: format!("Token exchange failed: {detail}"),
            });
        }

        #[derive(Deserialize)]
        struct TokenResponse {
            access_token: String,
            refresh_token: Option<String>,
            expires_in: Option<u64>,
            /// Present when the `openid` scope was granted.
            id_token: Option<String>,
        }

        let token_data: TokenResponse = res.json().await.map_err(AuthError::Network)?;

        // Prefer the access token's `sub` (previous behaviour), then the ID token's.
        let user_id = extract_sub_from_jwt(&token_data.access_token)
            .or_else(|| token_data.id_token.as_deref().and_then(extract_sub_from_jwt));
        let Some(user_id) = user_id else {
            return Ok(MethodResult::Failure {
                reason: "Token response does not identify the user: no `sub` claim in the access token or ID token".to_string(),
            });
        };

        let expires_in = std::time::Duration::from_secs(token_data.expires_in.unwrap_or(3600));

        let mut token = AuthToken::new(user_id, &token_data.access_token, expires_in, self.name());
        token.refresh_token = token_data.refresh_token;

        Ok(MethodResult::Success(Box::new(token)))
    }

    fn name(&self) -> &str {
        "enhanced_device_flow"
    }

    /// Reject configurations with an empty client ID or any empty endpoint URL.
    fn validate_config(&self) -> Result<()> {
        if self.client_id.is_empty() {
            return Err(AuthError::config("Client ID is required"));
        }
        if self.auth_url.is_empty() {
            return Err(AuthError::config("Authorization URL is required"));
        }
        if self.token_url.is_empty() {
            return Err(AuthError::config("Token URL is required"));
        }
        if self.device_auth_url.is_empty() {
            return Err(AuthError::config("Device authorization URL is required"));
        }
        Ok(())
    }
}
249
/// Placeholder compiled when the `enhanced-device-flow` feature is disabled.
///
/// It stores the configuration only so that error and log messages can echo
/// what was configured; it never performs a device flow. `Debug` is
/// implemented by hand so the client secret is never printed.
#[cfg(not(feature = "enhanced-device-flow"))]
pub struct EnhancedDeviceFlowMethod {
    /// OAuth client ID.
    client_id: String,
    /// OAuth client secret (`None` for public clients).
    client_secret: Option<String>,
    /// Authorization URL.
    auth_url: String,
    /// Token URL.
    token_url: String,
    /// Device authorization URL.
    device_auth_url: String,
}

#[cfg(not(feature = "enhanced-device-flow"))]
impl std::fmt::Debug for EnhancedDeviceFlowMethod {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EnhancedDeviceFlowMethod")
            .field("client_id", &self.client_id)
            // Only reveal whether a secret is configured, never its value.
            .field(
                "client_secret",
                &self.client_secret.as_ref().map(|_| "<redacted>"),
            )
            .field("auth_url", &self.auth_url)
            .field("token_url", &self.token_url)
            .field("device_auth_url", &self.device_auth_url)
            .finish()
    }
}
261
262#[cfg(not(feature = "enhanced-device-flow"))]
263impl EnhancedDeviceFlowMethod {
264    pub fn new(
265        client_id: String,
266        client_secret: Option<String>,
267        auth_url: String,
268        token_url: String,
269        device_auth_url: String,
270    ) -> Self {
271        Self {
272            client_id,
273            client_secret,
274            auth_url,
275            token_url,
276            device_auth_url,
277        }
278    }
279}
280
281#[cfg(not(feature = "enhanced-device-flow"))]
282impl AuthMethod for EnhancedDeviceFlowMethod {
283    type MethodResult = MethodResult;
284    type AuthToken = AuthToken;
285
286    async fn authenticate(
287        &self,
288        _credential: Credential,
289        _metadata: CredentialMetadata,
290    ) -> Result<Self::MethodResult> {
291        // Use configuration fields in error message to avoid unused field warnings
292        Err(AuthError::config(format!(
293            "Enhanced device flow requires 'enhanced-device-flow' feature. Configured for client '{}' with auth_url: {}, token_url: {}, device_auth_url: {}",
294            self.client_id, self.auth_url, self.token_url, self.device_auth_url
295        )))
296    }
297
298    fn name(&self) -> &str {
299        "enhanced_device_flow"
300    }
301
302    fn validate_config(&self) -> Result<()> {
303        // Use configuration fields for validation to avoid unused field warnings
304        if self.client_id.is_empty() {
305            return Err(AuthError::config("client_id cannot be empty"));
306        }
307        if self.auth_url.is_empty() {
308            return Err(AuthError::config("auth_url cannot be empty"));
309        }
310        if self.token_url.is_empty() {
311            return Err(AuthError::config("token_url cannot be empty"));
312        }
313        if self.device_auth_url.is_empty() {
314            return Err(AuthError::config("device_auth_url cannot be empty"));
315        }
316
317        // Log configuration for debugging (uses client_secret field)
318        if self.client_secret.is_some() {
319            tracing::info!(
320                "Enhanced device flow configured for confidential client: {}",
321                self.client_id
322            );
323        } else {
324            tracing::info!(
325                "Enhanced device flow configured for public client: {}",
326                self.client_id
327            );
328        }
329
330        Err(AuthError::config(
331            "Enhanced device flow requires 'enhanced-device-flow' feature to be enabled at compile time",
332        ))
333    }
334}
335
/// Enhanced device authentication (legacy struct kept for compatibility).
///
/// Wraps a device identifier; see `EnhancedDevice::authenticate` for the
/// (simulated) checks performed against it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EnhancedDevice {
    /// Device identifier
    pub device_id: String,
}
341
342impl EnhancedDevice {
343    /// Create new enhanced device
344    pub fn new(device_id: String) -> Self {
345        Self { device_id }
346    }
347
348    /// Authenticate using enhanced device
349    pub async fn authenticate(&self, challenge: &str) -> Result<bool> {
350        // Enhanced device authentication with device binding and trust signals
351
352        if challenge.is_empty() {
353            tracing::warn!("Empty challenge provided for device authentication");
354            return Ok(false);
355        }
356
357        tracing::info!(
358            "Starting enhanced device authentication for device: {}",
359            self.device_id
360        );
361
362        // Simulate enhanced device authentication process
363        // In a real implementation, this would:
364
365        // 1. Verify device identity and binding
366        if !self.verify_device_binding().await? {
367            tracing::warn!("Device binding verification failed for: {}", self.device_id);
368            return Ok(false);
369        }
370
371        // 2. Check device trust signals
372        if !self.check_device_trust_signals().await? {
373            tracing::warn!("Device trust signals check failed for: {}", self.device_id);
374            return Ok(false);
375        }
376
377        // 3. Validate challenge-response with device-specific cryptography
378        if !self.validate_device_challenge(challenge).await? {
379            tracing::warn!("Device challenge validation failed for: {}", self.device_id);
380            return Ok(false);
381        }
382
383        tracing::info!(
384            "Enhanced device authentication successful for: {}",
385            self.device_id
386        );
387        Ok(true)
388    }
389
390    /// Verify device binding and identity
391    async fn verify_device_binding(&self) -> Result<bool> {
392        tracing::debug!("Verifying device binding for: {}", self.device_id);
393
394        // In production, this would:
395        // 1. Check device certificate or attestation
396        // 2. Validate device hardware identity
397        // 3. Verify device registration status
398        // 4. Check device compliance status
399
400        // Simulate device binding check
401        if self.device_id.len() < 8 {
402            tracing::warn!("Device ID too short for secure binding");
403            return Ok(false);
404        }
405
406        // Validate device ID format (should be UUID or similar)
407        if !self
408            .device_id
409            .chars()
410            .all(|c| c.is_ascii_alphanumeric() || c == '-')
411        {
412            tracing::warn!("Invalid device ID format");
413            return Ok(false);
414        }
415
416        tracing::debug!("Device binding verified for: {}", self.device_id);
417        Ok(true)
418    }
419
420    /// Check device trust signals
421    async fn check_device_trust_signals(&self) -> Result<bool> {
422        tracing::debug!("Checking device trust signals for: {}", self.device_id);
423
424        // In production, this would check:
425        // 1. Device reputation score
426        // 2. Recent suspicious activity
427        // 3. Device location and behavior patterns
428        // 4. Security posture (OS version, patches, etc.)
429        // 5. Mobile Device Management (MDM) status
430        // 6. Device encryption status
431
432        // Simulate trust signal evaluation
433        let trust_score = self.calculate_trust_score().await;
434
435        if trust_score < 0.7 {
436            tracing::warn!(
437                "Device trust score too low: {} for device: {}",
438                trust_score,
439                self.device_id
440            );
441            return Ok(false);
442        }
443
444        tracing::info!(
445            "Device trust signals validated (score: {}) for: {}",
446            trust_score,
447            self.device_id
448        );
449        Ok(true)
450    }
451
452    /// Calculate device trust score
453    async fn calculate_trust_score(&self) -> f64 {
454        // Simulate trust score calculation based on verifiable device properties.
455        // In production this would query MDM, EDR, and attestation services.
456        let mut score = 1.0_f64;
457
458        // Newly-registered devices start with a lower initial trust score
459        if self.device_id.contains("new") {
460            score -= 0.1;
461        }
462
463        // Test/development devices are considered less trusted
464        if self.device_id.contains("test") {
465            score -= 0.2;
466        }
467
468        // Clamp to [0.0, 1.0] so callers always receive a valid score
469        score.clamp(0.0, 1.0)
470    }
471
472    /// Validate device-specific challenge
473    async fn validate_device_challenge(&self, challenge: &str) -> Result<bool> {
474        tracing::debug!("Validating device challenge for: {}", self.device_id);
475
476        // In production, this would:
477        // 1. Perform cryptographic challenge-response
478        // 2. Validate device attestation
479        // 3. Check challenge freshness and replay protection
480        // 4. Verify device-specific cryptographic proof
481
482        // Minimum length requirement — too-short challenges cannot provide replay protection
483        if challenge.len() < 16 {
484            tracing::warn!(
485                "Device challenge too short ({} chars) for: {}",
486                challenge.len(),
487                self.device_id
488            );
489            return Ok(false);
490        }
491
492        // The challenge must consist only of URL-safe base64 / hex characters
493        // (alphanumeric, '+', '/', '-', '_', '=')
494        let valid_chars = challenge
495            .chars()
496            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '/' | '-' | '_' | '='));
497
498        if !valid_chars {
499            tracing::warn!(
500                "Device challenge contains invalid characters for: {}",
501                self.device_id
502            );
503            return Ok(false);
504        }
505
506        tracing::debug!("Device challenge validation successful");
507        Ok(true)
508    }
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    fn device(id: &str) -> EnhancedDevice {
        EnhancedDevice::new(id.to_string())
    }

    // ── EnhancedDevice::new ───────────────────────────────────────────────────

    #[test]
    fn test_new_stores_device_id() {
        let d = device("my-device-abc123");
        assert_eq!(d.device_id, "my-device-abc123");
    }

    // ── verify_device_binding ────────────────────────────────────────────────

    #[tokio::test]
    async fn test_device_binding_valid_uuid_format() {
        let d = device("550e8400-e29b-41d4-a716-446655440000");
        // UUID-style device_id is ≥ 8 chars and only alphanumeric + dash
        let result = d.verify_device_binding().await.unwrap();
        assert!(result, "UUID-format device ID should pass binding check");
    }

    #[tokio::test]
    async fn test_device_binding_too_short() {
        let d = device("abc123"); // 6 chars — below 8
        assert!(
            !d.verify_device_binding().await.unwrap(),
            "Device IDs shorter than 8 chars must fail"
        );
    }

    #[tokio::test]
    async fn test_device_binding_exactly_min_length() {
        let d = device("abcd1234"); // exactly 8 chars — the boundary
        assert!(
            d.verify_device_binding().await.unwrap(),
            "8-char device IDs are the minimum accepted length"
        );
    }

    #[tokio::test]
    async fn test_device_binding_invalid_chars() {
        let d = device("device@with#special!chars");
        assert!(
            !d.verify_device_binding().await.unwrap(),
            "Device IDs with special chars (not alphanumeric/-) must fail"
        );
    }

    #[tokio::test]
    async fn test_device_binding_rejects_non_ascii() {
        let d = device("device-é-12345");
        assert!(
            !d.verify_device_binding().await.unwrap(),
            "Non-ASCII characters must fail the binding check"
        );
    }

    // ── calculate_trust_score ────────────────────────────────────────────────

    #[tokio::test]
    async fn test_trust_score_clean_device_is_1_0() {
        let d = device("abcd1234efgh5678"); // no "new" or "test" in name
        let score = d.calculate_trust_score().await;
        assert!(
            (score - 1.0).abs() < f64::EPSILON,
            "Clean device should score 1.0, got {score}"
        );
    }

    #[tokio::test]
    async fn test_trust_score_new_device_is_reduced() {
        let d = device("newdevice-abcd1234");
        let score = d.calculate_trust_score().await;
        assert!(
            score < 1.0,
            "Device containing 'new' should have score < 1.0, got {score}"
        );
        assert!(
            (score - 0.9).abs() < f64::EPSILON,
            "Expected 0.9, got {score}"
        );
    }

    #[tokio::test]
    async fn test_trust_score_test_device_is_reduced() {
        let d = device("testdevice-abcd1234");
        let score = d.calculate_trust_score().await;
        assert!(
            (score - 0.8).abs() < f64::EPSILON,
            "Expected 0.8 for 'test' device, got {score}"
        );
    }

    #[tokio::test]
    async fn test_trust_score_new_and_test_device() {
        let d = device("new-testdevice-abcd1234");
        let score = d.calculate_trust_score().await;
        // 1.0 - 0.1 (new) - 0.2 (test) = 0.7
        assert!(
            (score - 0.7).abs() < f64::EPSILON,
            "Expected 0.7 for device containing both 'new' and 'test', got {score}"
        );
    }

    #[tokio::test]
    async fn test_trust_score_always_in_range() {
        // Even extreme inputs should stay in [0.0, 1.0]
        for id in &[
            "new-test-device-id",
            "new-new-new-test-test-test-device",
            "aaaaaaaaaaaaa",
        ] {
            let score = device(id).calculate_trust_score().await;
            assert!(
                (0.0f64..=1.0).contains(&score),
                "Trust score {score} out of range [0,1] for '{id}'"
            );
        }
    }

    // ── validate_device_challenge ────────────────────────────────────────────

    #[tokio::test]
    async fn test_challenge_valid_hex_16_chars() {
        let d = device("abcdefgh-1234");
        let challenge = "0123456789abcdef"; // 16 hex chars
        assert!(d.validate_device_challenge(challenge).await.unwrap());
    }

    #[tokio::test]
    async fn test_challenge_valid_base64url() {
        let d = device("abcdefgh-1234");
        let challenge = "SGVsbG8gV29ybGQh"; // base64url, 16 chars
        assert!(d.validate_device_challenge(challenge).await.unwrap());
    }

    #[tokio::test]
    async fn test_challenge_accepts_standard_base64_with_padding() {
        let d = device("abcdefgh-1234");
        let challenge = "SGVsbG8gV29ybGQ+/w=="; // '+', '/', '=' are allowed
        assert!(d.validate_device_challenge(challenge).await.unwrap());
    }

    #[tokio::test]
    async fn test_challenge_too_short() {
        let d = device("abcdefgh-1234");
        assert!(
            !d.validate_device_challenge("short123").await.unwrap(),
            "Challenge < 16 chars must be rejected"
        );
    }

    #[tokio::test]
    async fn test_challenge_one_below_min_length() {
        let d = device("abcdefgh-1234");
        assert!(
            !d.validate_device_challenge("0123456789abcde").await.unwrap(),
            "15-char challenge is just below the limit and must be rejected"
        );
    }

    #[tokio::test]
    async fn test_challenge_empty() {
        let d = device("abcdefgh-1234");
        assert!(!d.validate_device_challenge("").await.unwrap());
    }

    #[tokio::test]
    async fn test_challenge_invalid_chars() {
        let d = device("abcdefgh-1234");
        // Contains space and exclamation mark — invalid
        let challenge = "Hello World!!!!!";
        assert!(
            !d.validate_device_challenge(challenge).await.unwrap(),
            "Challenge with spaces/exclamation marks must be rejected"
        );
    }

    #[tokio::test]
    async fn test_challenge_multibyte_passes_length_but_fails_charset() {
        let d = device("abcdefgh-1234");
        // 9 chars but 18 bytes: long enough by byte count, yet not ASCII.
        let challenge = "ééééééééé";
        assert!(!d.validate_device_challenge(challenge).await.unwrap());
    }

    // ── authenticate (integration path) ──────────────────────────────────────

    #[tokio::test]
    async fn test_authenticate_empty_challenge_returns_false() {
        let d = device("abcdefgh-1234");
        assert!(!d.authenticate("").await.unwrap());
    }

    #[tokio::test]
    async fn test_authenticate_valid_device_and_challenge() {
        // Device id: valid format (UUID-like), Challenge: valid base64url ≥ 16 chars
        let d = device("550e8400-e29b-41d4-a716-446655440000");
        let challenge = "SGVsbG8gV29ybGQh"; // valid base64url, 16 chars
        // Trust score is 1.0 (no "new" or "test"), binding passes, challenge passes
        assert!(
            d.authenticate(challenge).await.unwrap(),
            "Valid device + valid challenge should authenticate"
        );
    }

    #[tokio::test]
    async fn test_authenticate_short_device_id_fails() {
        let d = device("tiny"); // < 8 chars, fails binding
        let challenge = "SGVsbG8gV29ybGQh";
        assert!(
            !d.authenticate(challenge).await.unwrap(),
            "Short device ID must fail authentication"
        );
    }

    #[tokio::test]
    async fn test_authenticate_invalid_challenge_fails() {
        let d = device("550e8400-e29b-41d4-a716-446655440000");
        assert!(
            !d.authenticate("Hello World!!!!!").await.unwrap(),
            "Valid device with a malformed challenge must fail authentication"
        );
    }

    #[tokio::test]
    async fn test_authenticate_at_minimum_trust_score_passes() {
        // "new" (-0.1) + "test" (-0.2) → score = 0.7, exactly at the threshold.
        // check_device_trust_signals fails only when score < 0.7, so this passes.
        let d = device("new-test-device-abcde"); // score exactly 0.7
        let challenge = "SGVsbG8gV29ybGQh";
        assert!(
            d.authenticate(challenge).await.unwrap(),
            "Device at minimum trust score (0.7) should still authenticate"
        );
    }
}