// auth_framework/methods/enhanced_device/mod.rs
use std::fmt;

#[cfg(feature = "enhanced-device-flow")]
use base64::Engine as _;
#[cfg(feature = "enhanced-device-flow")]
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use serde::{Deserialize, Serialize};

use crate::authentication::credentials::{Credential, CredentialMetadata};
use crate::errors::{AuthError, Result};
use crate::methods::{AuthMethod, MethodResult};
use crate::tokens::AuthToken;

/// Best-effort extraction of the `sub` claim from a JWT-formatted access token.
///
/// The second `.`-separated segment is base64url-decoded and parsed as JSON, and the
/// `sub` string claim is returned. Returns `None` in these cases:
/// - the token has no payload segment (opaque token);
/// - the payload does not decode or parse;
/// - `sub` is missing, not a string, or empty.
///
/// NOTE(review): the signature is NOT verified. That is only tolerable because the token
/// was just received from the configured token endpoint. Do not use this on tokens from
/// untrusted sources.
#[cfg(feature = "enhanced-device-flow")]
fn extract_sub_from_jwt(jwt: &str) -> Option<String> {
    let payload_segment = jwt.split('.').nth(1)?;
    // JWT segments are unpadded base64url, but tolerate issuers that append `=` padding.
    let payload = URL_SAFE_NO_PAD
        .decode(payload_segment.trim_end_matches('='))
        .ok()?;
    let claims: serde_json::Value = serde_json::from_slice(&payload).ok()?;
    claims
        .get("sub")
        .and_then(serde_json::Value::as_str)
        // An empty subject is not a usable identity; let the caller fall back instead.
        .filter(|sub| !sub.is_empty())
        .map(str::to_owned)
}

31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct DeviceFlowInstructions {
34 pub verification_uri: String,
36 pub verification_uri_complete: Option<String>,
38 pub user_code: String,
40 pub qr_code: Option<String>,
42 pub expires_in: u64,
44 pub interval: u64,
46}
47
/// OAuth 2.0 Device Authorization Grant (RFC 8628) client configuration.
#[cfg(feature = "enhanced-device-flow")]
pub struct EnhancedDeviceFlowMethod {
    /// OAuth client identifier.
    pub client_id: String,
    /// Client secret for confidential clients; `None` for public clients.
    pub client_secret: Option<String>,
    /// Authorization endpoint URL.
    pub auth_url: String,
    /// Token endpoint that is polled with the device code.
    pub token_url: String,
    /// Device authorization endpoint that issues device and user codes.
    pub device_auth_url: String,
    /// Scopes requested during device authorization (sent space-joined).
    pub scopes: Vec<String>,
    /// Optional polling interval configured via `with_polling_interval`.
    pub _polling_interval: Option<std::time::Duration>,
    /// Flag configured via `with_qr_code`.
    pub enable_qr_code: bool,
}

/// Hand-written so that `client_secret` is redacted: a derived `Debug` would print the
/// secret verbatim into any log line that formats this value with `{:?}`.
#[cfg(feature = "enhanced-device-flow")]
impl fmt::Debug for EnhancedDeviceFlowMethod {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("EnhancedDeviceFlowMethod")
            .field("client_id", &self.client_id)
            .field(
                "client_secret",
                &self.client_secret.as_ref().map(|_| "<redacted>"),
            )
            .field("auth_url", &self.auth_url)
            .field("token_url", &self.token_url)
            .field("device_auth_url", &self.device_auth_url)
            .field("scopes", &self.scopes)
            .field("_polling_interval", &self._polling_interval)
            .field("enable_qr_code", &self.enable_qr_code)
            .finish()
    }
}

#[cfg(feature = "enhanced-device-flow")]
impl EnhancedDeviceFlowMethod {
    /// Creates a device-flow method for the given client and endpoints.
    ///
    /// Defaults: no scopes, no polling-interval override, QR codes enabled.
    pub fn new(
        client_id: String,
        client_secret: Option<String>,
        auth_url: String,
        token_url: String,
        device_auth_url: String,
    ) -> Self {
        Self {
            client_id,
            client_secret,
            auth_url,
            token_url,
            device_auth_url,
            scopes: Vec::new(),
            _polling_interval: None,
            enable_qr_code: true,
        }
    }

    /// Sets the scopes requested during device authorization.
    #[must_use]
    pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
        self.scopes = scopes;
        self
    }

    /// Sets a client-side polling interval.
    ///
    /// The interval reported by `initiate_device_flow` is the larger of this value (whole
    /// seconds) and the server's requested interval. The server's interval defaults to
    /// 5 s when omitted. A configured value can therefore lengthen polling, never shorten it.
    #[must_use]
    pub fn with_polling_interval(mut self, interval: std::time::Duration) -> Self {
        self._polling_interval = Some(interval);
        self
    }

    /// Sets the QR-code flag.
    #[must_use]
    pub fn with_qr_code(mut self, enable: bool) -> Self {
        self.enable_qr_code = enable;
        self
    }

    /// Requests a device code and user code from `device_auth_url` (RFC 8628 §3.1).
    ///
    /// # Errors
    ///
    /// - `AuthError::Network` if the request fails, or if the success body is not the
    ///   expected JSON.
    /// - A configuration error if the server returns a non-success status. The message
    ///   includes the status and, when non-empty, the response body.
    pub async fn initiate_device_flow(&self) -> Result<DeviceFlowInstructions> {
        let mut params: Vec<(&str, String)> = vec![("client_id", self.client_id.clone())];
        if !self.scopes.is_empty() {
            params.push(("scope", self.scopes.join(" ")));
        }
        // Confidential clients must authenticate at the device authorization endpoint as
        // well (RFC 8628 §3.1), not only at the token endpoint.
        if let Some(secret) = &self.client_secret {
            params.push(("client_secret", secret.clone()));
        }

        let res = reqwest::Client::new()
            .post(&self.device_auth_url)
            .form(&params)
            .send()
            .await
            .map_err(AuthError::Network)?;

        let status = res.status();
        if !status.is_success() {
            // Keep the body: OAuth servers put the actionable `error` code there.
            let body = res.text().await.unwrap_or_default();
            let message = if body.is_empty() {
                format!("Device auth request failed: {}", status)
            } else {
                format!("Device auth request failed: {}: {}", status, body)
            };
            return Err(AuthError::config(message));
        }

        #[derive(Deserialize)]
        struct DeviceAuthResponse {
            // NOTE(review): the device code is parsed but not returned. Callers of this
            // method therefore cannot build the `Credential::EnhancedDeviceFlow` that
            // `authenticate` needs. Exposing it requires a new `DeviceFlowInstructions`
            // field.
            #[allow(dead_code)]
            device_code: String,
            user_code: String,
            verification_uri: String,
            verification_uri_complete: Option<String>,
            expires_in: u64,
            interval: Option<u64>,
        }

        let data: DeviceAuthResponse = res.json().await.map_err(AuthError::Network)?;

        // RFC 8628 §3.2: clients MUST use 5 seconds when the server omits `interval`.
        let server_interval = data.interval.unwrap_or(5);
        let interval = self
            ._polling_interval
            .map_or(server_interval, |configured| {
                configured.as_secs().max(server_interval)
            });

        Ok(DeviceFlowInstructions {
            verification_uri: data.verification_uri,
            verification_uri_complete: data.verification_uri_complete,
            user_code: data.user_code,
            // NOTE(review): `enable_qr_code` is not consulted; no QR payload is generated.
            qr_code: None,
            expires_in: data.expires_in,
            interval,
        })
    }
}

#[cfg(feature = "enhanced-device-flow")]
impl AuthMethod for EnhancedDeviceFlowMethod {
    type MethodResult = MethodResult;
    type AuthToken = AuthToken;

    /// Exchanges a device code for tokens at `token_url` (one RFC 8628 §3.4 poll).
    ///
    /// Returns `MethodResult::Failure` in two cases:
    /// - the credential is not a device-flow credential;
    /// - the token endpoint answers with a non-success status. The reason then carries
    ///   the raw body, which is where RFC 8628 §3.5 servers report `authorization_pending`
    ///   and `slow_down`.
    ///
    /// On success:
    /// - the user id is the unverified `sub` claim of the access token, or
    ///   `"unknown_device_user"` when no `sub` can be extracted;
    /// - the lifetime defaults to 3600 s when `expires_in` is absent.
    ///
    /// # Errors
    ///
    /// `AuthError::Network` if the HTTP request fails or the success body is not valid
    /// token-response JSON.
    async fn authenticate(
        &self,
        credential: Credential,
        _metadata: CredentialMetadata,
    ) -> Result<Self::MethodResult> {
        // Only the device code is needed; the credential's polling interval is unused.
        let Credential::EnhancedDeviceFlow { device_code, .. } = credential else {
            return Ok(MethodResult::Failure {
                reason: "Invalid credential type for enhanced device flow".to_string(),
            });
        };

        let mut form: Vec<(&str, String)> = vec![
            ("client_id", self.client_id.clone()),
            (
                "grant_type",
                "urn:ietf:params:oauth:grant-type:device_code".to_string(),
            ),
            ("device_code", device_code),
        ];
        if let Some(secret) = self.client_secret.as_ref() {
            form.push(("client_secret", secret.clone()));
        }

        let response = reqwest::Client::new()
            .post(&self.token_url)
            .form(&form)
            .send()
            .await
            .map_err(AuthError::Network)?;

        if !response.status().is_success() {
            let body = response.text().await.unwrap_or_default();
            return Ok(MethodResult::Failure {
                reason: format!("Token exchange failed: {}", body),
            });
        }

        #[derive(Deserialize)]
        struct TokenResponse {
            access_token: String,
            refresh_token: Option<String>,
            expires_in: Option<u64>,
        }

        let TokenResponse {
            access_token,
            refresh_token,
            expires_in,
        } = response
            .json::<TokenResponse>()
            .await
            .map_err(AuthError::Network)?;

        let subject = extract_sub_from_jwt(&access_token)
            .unwrap_or_else(|| "unknown_device_user".to_string());
        let lifetime = std::time::Duration::from_secs(expires_in.unwrap_or(3600));

        let mut token = AuthToken::new(subject, &access_token, lifetime, self.name());
        token.refresh_token = refresh_token;
        Ok(MethodResult::Success(Box::new(token)))
    }

    /// Stable identifier of this authentication method.
    fn name(&self) -> &str {
        "enhanced_device_flow"
    }

    /// Fails with a configuration error naming the first empty required setting.
    fn validate_config(&self) -> Result<()> {
        let required = [
            (&self.client_id, "Client ID is required"),
            (&self.auth_url, "Authorization URL is required"),
            (&self.token_url, "Token URL is required"),
            (&self.device_auth_url, "Device authorization URL is required"),
        ];
        match required.iter().find(|(value, _)| value.is_empty()) {
            Some((_, message)) => Err(AuthError::config(*message)),
            None => Ok(()),
        }
    }
}

/// Placeholder used when the `enhanced-device-flow` feature is disabled.
///
/// It accepts the same constructor arguments as the real implementation, so configuration
/// code still compiles. Authentication through it always fails with a configuration error.
#[cfg(not(feature = "enhanced-device-flow"))]
pub struct EnhancedDeviceFlowMethod {
    client_id: String,
    client_secret: Option<String>,
    auth_url: String,
    token_url: String,
    device_auth_url: String,
}

/// Hand-written so that `client_secret` is redacted: a derived `Debug` would print the
/// secret verbatim into any log line that formats this value with `{:?}`.
#[cfg(not(feature = "enhanced-device-flow"))]
impl fmt::Debug for EnhancedDeviceFlowMethod {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("EnhancedDeviceFlowMethod")
            .field("client_id", &self.client_id)
            .field(
                "client_secret",
                &self.client_secret.as_ref().map(|_| "<redacted>"),
            )
            .field("auth_url", &self.auth_url)
            .field("token_url", &self.token_url)
            .field("device_auth_url", &self.device_auth_url)
            .finish()
    }
}

262#[cfg(not(feature = "enhanced-device-flow"))]
263impl EnhancedDeviceFlowMethod {
264 pub fn new(
265 client_id: String,
266 client_secret: Option<String>,
267 auth_url: String,
268 token_url: String,
269 device_auth_url: String,
270 ) -> Self {
271 Self {
272 client_id,
273 client_secret,
274 auth_url,
275 token_url,
276 device_auth_url,
277 }
278 }
279}
280
281#[cfg(not(feature = "enhanced-device-flow"))]
282impl AuthMethod for EnhancedDeviceFlowMethod {
283 type MethodResult = MethodResult;
284 type AuthToken = AuthToken;
285
286 async fn authenticate(
287 &self,
288 _credential: Credential,
289 _metadata: CredentialMetadata,
290 ) -> Result<Self::MethodResult> {
291 Err(AuthError::config(format!(
293 "Enhanced device flow requires 'enhanced-device-flow' feature. Configured for client '{}' with auth_url: {}, token_url: {}, device_auth_url: {}",
294 self.client_id, self.auth_url, self.token_url, self.device_auth_url
295 )))
296 }
297
298 fn name(&self) -> &str {
299 "enhanced_device_flow"
300 }
301
302 fn validate_config(&self) -> Result<()> {
303 if self.client_id.is_empty() {
305 return Err(AuthError::config("client_id cannot be empty"));
306 }
307 if self.auth_url.is_empty() {
308 return Err(AuthError::config("auth_url cannot be empty"));
309 }
310 if self.token_url.is_empty() {
311 return Err(AuthError::config("token_url cannot be empty"));
312 }
313 if self.device_auth_url.is_empty() {
314 return Err(AuthError::config("device_auth_url cannot be empty"));
315 }
316
317 if self.client_secret.is_some() {
319 tracing::info!(
320 "Enhanced device flow configured for confidential client: {}",
321 self.client_id
322 );
323 } else {
324 tracing::info!(
325 "Enhanced device flow configured for public client: {}",
326 self.client_id
327 );
328 }
329
330 Err(AuthError::config(
331 "Enhanced device flow requires 'enhanced-device-flow' feature to be enabled at compile time",
332 ))
333 }
334}
335
/// A device identity checked by [`EnhancedDevice::authenticate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EnhancedDevice {
    /// Caller-supplied device identifier. Its length and character set are checked
    /// during authentication.
    pub device_id: String,
}

342impl EnhancedDevice {
343 pub fn new(device_id: String) -> Self {
345 Self { device_id }
346 }
347
348 pub async fn authenticate(&self, challenge: &str) -> Result<bool> {
350 if challenge.is_empty() {
353 tracing::warn!("Empty challenge provided for device authentication");
354 return Ok(false);
355 }
356
357 tracing::info!(
358 "Starting enhanced device authentication for device: {}",
359 self.device_id
360 );
361
362 if !self.verify_device_binding().await? {
367 tracing::warn!("Device binding verification failed for: {}", self.device_id);
368 return Ok(false);
369 }
370
371 if !self.check_device_trust_signals().await? {
373 tracing::warn!("Device trust signals check failed for: {}", self.device_id);
374 return Ok(false);
375 }
376
377 if !self.validate_device_challenge(challenge).await? {
379 tracing::warn!("Device challenge validation failed for: {}", self.device_id);
380 return Ok(false);
381 }
382
383 tracing::info!(
384 "Enhanced device authentication successful for: {}",
385 self.device_id
386 );
387 Ok(true)
388 }
389
390 async fn verify_device_binding(&self) -> Result<bool> {
392 tracing::debug!("Verifying device binding for: {}", self.device_id);
393
394 if self.device_id.len() < 8 {
402 tracing::warn!("Device ID too short for secure binding");
403 return Ok(false);
404 }
405
406 if !self
408 .device_id
409 .chars()
410 .all(|c| c.is_ascii_alphanumeric() || c == '-')
411 {
412 tracing::warn!("Invalid device ID format");
413 return Ok(false);
414 }
415
416 tracing::debug!("Device binding verified for: {}", self.device_id);
417 Ok(true)
418 }
419
420 async fn check_device_trust_signals(&self) -> Result<bool> {
422 tracing::debug!("Checking device trust signals for: {}", self.device_id);
423
424 let trust_score = self.calculate_trust_score().await;
434
435 if trust_score < 0.7 {
436 tracing::warn!(
437 "Device trust score too low: {} for device: {}",
438 trust_score,
439 self.device_id
440 );
441 return Ok(false);
442 }
443
444 tracing::info!(
445 "Device trust signals validated (score: {}) for: {}",
446 trust_score,
447 self.device_id
448 );
449 Ok(true)
450 }
451
452 async fn calculate_trust_score(&self) -> f64 {
454 let mut score = 1.0_f64;
457
458 if self.device_id.contains("new") {
460 score -= 0.1;
461 }
462
463 if self.device_id.contains("test") {
465 score -= 0.2;
466 }
467
468 score.clamp(0.0, 1.0)
470 }
471
472 async fn validate_device_challenge(&self, challenge: &str) -> Result<bool> {
474 tracing::debug!("Validating device challenge for: {}", self.device_id);
475
476 if challenge.len() < 16 {
484 tracing::warn!(
485 "Device challenge too short ({} chars) for: {}",
486 challenge.len(),
487 self.device_id
488 );
489 return Ok(false);
490 }
491
492 let valid_chars = challenge
495 .chars()
496 .all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '/' | '-' | '_' | '='));
497
498 if !valid_chars {
499 tracing::warn!(
500 "Device challenge contains invalid characters for: {}",
501 self.device_id
502 );
503 return Ok(false);
504 }
505
506 tracing::debug!("Device challenge validation successful");
507 Ok(true)
508 }
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// Base64 of "Hello World!": 16 characters, all within the accepted alphabet.
    const GOOD_CHALLENGE: &str = "SGVsbG8gV29ybGQh";

    fn device(id: &str) -> EnhancedDevice {
        EnhancedDevice::new(id.to_owned())
    }

    // ---- construction ----------------------------------------------------------

    #[test]
    fn test_new_stores_device_id() {
        assert_eq!(device("my-device-abc123").device_id, "my-device-abc123");
    }

    // ---- device binding --------------------------------------------------------

    #[tokio::test]
    async fn test_device_binding_valid_uuid_format() {
        let bound = device("550e8400-e29b-41d4-a716-446655440000")
            .verify_device_binding()
            .await
            .unwrap();
        assert!(bound, "UUID-format device ID should pass binding check");
    }

    #[tokio::test]
    async fn test_device_binding_too_short() {
        let bound = device("abc123").verify_device_binding().await.unwrap();
        assert!(!bound, "Device IDs shorter than 8 chars must fail");
    }

    #[tokio::test]
    async fn test_device_binding_invalid_chars() {
        let bound = device("device@with#special!chars")
            .verify_device_binding()
            .await
            .unwrap();
        assert!(
            !bound,
            "Device IDs with special chars (not alphanumeric/-) must fail"
        );
    }

    // ---- trust score -----------------------------------------------------------

    #[tokio::test]
    async fn test_trust_score_clean_device_is_1_0() {
        let score = device("abcd1234efgh5678").calculate_trust_score().await;
        assert!(
            (score - 1.0).abs() < f64::EPSILON,
            "Clean device should score 1.0, got {score}"
        );
    }

    #[tokio::test]
    async fn test_trust_score_new_device_is_reduced() {
        let score = device("newdevice-abcd1234").calculate_trust_score().await;
        assert!(
            score < 1.0,
            "Device containing 'new' should have score < 1.0, got {score}"
        );
        assert!(
            (score - 0.9).abs() < f64::EPSILON,
            "Expected 0.9, got {score}"
        );
    }

    #[tokio::test]
    async fn test_trust_score_test_device_is_reduced() {
        let score = device("testdevice-abcd1234").calculate_trust_score().await;
        assert!(
            (score - 0.8).abs() < f64::EPSILON,
            "Expected 0.8 for 'test' device, got {score}"
        );
    }

    #[tokio::test]
    async fn test_trust_score_new_and_test_device() {
        let score = device("new-testdevice-abcd1234")
            .calculate_trust_score()
            .await;
        assert!(
            (score - 0.7).abs() < f64::EPSILON,
            "Expected 0.7 for device containing both 'new' and 'test', got {score}"
        );
    }

    #[tokio::test]
    async fn test_trust_score_always_in_range() {
        let ids = [
            "new-test-device-id",
            "new-new-new-test-test-test-device",
            "aaaaaaaaaaaaa",
        ];
        for id in ids {
            let score = device(id).calculate_trust_score().await;
            assert!(
                (0.0f64..=1.0).contains(&score),
                "Trust score {score} out of range [0,1] for '{id}'"
            );
        }
    }

    // ---- challenge validation --------------------------------------------------

    #[tokio::test]
    async fn test_challenge_valid_hex_16_chars() {
        let accepted = device("abcdefgh-1234")
            .validate_device_challenge("0123456789abcdef")
            .await
            .unwrap();
        assert!(accepted);
    }

    #[tokio::test]
    async fn test_challenge_valid_base64url() {
        let accepted = device("abcdefgh-1234")
            .validate_device_challenge(GOOD_CHALLENGE)
            .await
            .unwrap();
        assert!(accepted);
    }

    #[tokio::test]
    async fn test_challenge_too_short() {
        let accepted = device("abcdefgh-1234")
            .validate_device_challenge("short123")
            .await
            .unwrap();
        assert!(!accepted, "Challenge < 16 chars must be rejected");
    }

    #[tokio::test]
    async fn test_challenge_empty() {
        let accepted = device("abcdefgh-1234")
            .validate_device_challenge("")
            .await
            .unwrap();
        assert!(!accepted);
    }

    #[tokio::test]
    async fn test_challenge_invalid_chars() {
        let accepted = device("abcdefgh-1234")
            .validate_device_challenge("Hello World!!!!!")
            .await
            .unwrap();
        assert!(
            !accepted,
            "Challenge with spaces/exclamation marks must be rejected"
        );
    }

    // ---- end-to-end authenticate -----------------------------------------------

    #[tokio::test]
    async fn test_authenticate_empty_challenge_returns_false() {
        let outcome = device("abcdefgh-1234").authenticate("").await.unwrap();
        assert!(!outcome);
    }

    #[tokio::test]
    async fn test_authenticate_valid_device_and_challenge() {
        let outcome = device("550e8400-e29b-41d4-a716-446655440000")
            .authenticate(GOOD_CHALLENGE)
            .await
            .unwrap();
        assert!(
            outcome,
            "Valid device + valid challenge should authenticate"
        );
    }

    #[tokio::test]
    async fn test_authenticate_short_device_id_fails() {
        let outcome = device("tiny").authenticate(GOOD_CHALLENGE).await.unwrap();
        assert!(!outcome, "Short device ID must fail authentication");
    }

    #[tokio::test]
    async fn test_authenticate_at_minimum_trust_score_passes() {
        let outcome = device("new-test-device-abcde")
            .authenticate(GOOD_CHALLENGE)
            .await
            .unwrap();
        assert!(
            outcome,
            "Device at minimum trust score (0.7) should still authenticate"
        );
    }
}