use crate::errors::{AuthError, Result};
use crate::storage::AuthStorage;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use uuid::Uuid;
/// Input to [`DeviceAuthManager::create_authorization`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DeviceAuthorizationRequest {
    /// Identifier of the requesting client; rejected when empty.
    pub client_id: String,
    /// Requested scope, copied unchanged into the stored authorization.
    pub scope: Option<String>,
}
/// Result of [`DeviceAuthManager::create_authorization`], handed back to the client.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DeviceAuthorizationResponse {
    /// Opaque code the device presents when polling.
    pub device_code: String,
    /// Short code the user types at the verification URI.
    pub user_code: String,
    /// Page where the user enters `user_code`.
    pub verification_uri: String,
    /// Verification URI with the user code pre-filled, when available.
    pub verification_uri_complete: Option<String>,
    /// Minimum polling interval, in whole seconds.
    pub interval: u64,
    /// Lifetime of the issued codes, in whole seconds.
    pub expires_in: u64,
}
/// Body of a device-code token request.
///
/// NOTE(review): not consumed by `DeviceAuthManager` in this file; presumably
/// deserialized by the HTTP layer and forwarded to `poll_authorization` — confirm.
#[derive(Clone, Debug, Deserialize)]
pub struct DeviceTokenRequest {
    /// OAuth grant type string; not validated here.
    pub grant_type: String,
    /// Device code previously issued by `create_authorization`.
    pub device_code: String,
    /// Identifier of the polling client.
    pub client_id: String,
}
/// Persisted state of one device authorization.
///
/// Serialized as JSON and stored under both its device-code and user-code keys.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct StoredDeviceAuthorization {
    /// Code the device polls with (prefixed `dc_`).
    pub device_code: String,
    /// Code the user enters on the verification page.
    pub user_code: String,
    /// Client that requested the authorization.
    pub client_id: String,
    /// Scope requested by the client, if any.
    pub scope: Option<String>,
    /// Current decision state.
    pub status: DeviceAuthorizationStatus,
    /// User who approved the request; set on authorization.
    pub user_id: Option<String>,
    /// When the codes were issued.
    pub created_at: SystemTime,
    /// After this instant the codes are rejected.
    pub expires_at: SystemTime,
    /// Time of the most recent poll, used for rate limiting.
    pub last_poll: Option<SystemTime>,
    /// Number of `slow_down` responses issued so far; defaults to 0 when absent
    /// from the serialized data.
    #[serde(default)]
    pub slow_down_count: u32,
}
/// Lifecycle state of a device authorization.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub enum DeviceAuthorizationStatus {
    /// Waiting for the user to approve or deny.
    Pending,
    /// The user approved; polling returns the record.
    Authorized,
    /// The user denied; polling reports `access_denied`.
    Denied,
    /// Explicitly expired; polling reports `expired_token`.
    Expired,
}
use std::fmt;
/// Issues device/user code pairs and tracks their approval state.
///
/// Records are written to [`AuthStorage`] and mirrored in an in-memory map that
/// polling falls back to when storage has no entry. Clones share both.
#[derive(Clone)]
pub struct DeviceAuthManager {
    /// Backing key-value store for serialized authorization records.
    storage: Arc<dyn AuthStorage>,
    /// In-process copy of records, keyed by device code.
    authorizations: Arc<tokio::sync::RwLock<HashMap<String, StoredDeviceAuthorization>>>,
    /// Lifetime given to newly issued codes.
    default_expiration: Duration,
    /// Base polling interval advertised to and enforced on clients.
    min_interval: Duration,
    /// Page where users enter their user code.
    verification_uri: String,
}
impl fmt::Debug for DeviceAuthManager {
    /// Formats the configuration; the storage backend is shown as a placeholder
    /// and the in-memory authorization map is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("DeviceAuthManager");
        builder.field("storage", &"<dyn AuthStorage>");
        builder.field("default_expiration", &self.default_expiration);
        builder.field("min_interval", &self.min_interval);
        builder.field("verification_uri", &self.verification_uri);
        builder.finish()
    }
}
impl DeviceAuthManager {
    /// Number of code characters in a generated user code, excluding the `-` separator.
    const USER_CODE_LEN: usize = 8;
    /// Index at which [`Self::generate_user_code`] places the `-` separator.
    const USER_CODE_SEPARATOR_AT: usize = 4;
    /// Smallest TTL handed to storage when an existing record is rewritten.
    const MIN_STORAGE_TTL: Duration = Duration::from_secs(1);

    /// Creates a manager with a 600-second code lifetime and a 5-second polling interval.
    pub fn new(storage: Arc<dyn AuthStorage>, verification_uri: String) -> Self {
        Self::with_settings(
            storage,
            verification_uri,
            Duration::from_secs(600),
            Duration::from_secs(5),
        )
    }

    /// Creates a manager with an explicit code lifetime and minimum polling interval.
    pub fn with_settings(
        storage: Arc<dyn AuthStorage>,
        verification_uri: String,
        expiration: Duration,
        min_interval: Duration,
    ) -> Self {
        Self {
            storage,
            authorizations: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            default_expiration: expiration,
            min_interval,
            verification_uri,
        }
    }

    /// Builder-style override of the lifetime given to newly issued codes.
    #[must_use]
    pub fn expiration(mut self, expiration: Duration) -> Self {
        self.default_expiration = expiration;
        self
    }

    /// Builder-style override of the minimum polling interval.
    #[must_use]
    pub fn interval(mut self, interval: Duration) -> Self {
        self.min_interval = interval;
        self
    }

    /// Issues a new device/user code pair and persists it as `Pending`.
    ///
    /// The record is stored under both `device_code:<code>` and `user_code:<code>`
    /// with a TTL of the configured expiration, and mirrored in memory.
    ///
    /// # Errors
    /// Fails when `client_id` is empty, or when serialization or either storage
    /// write fails.
    pub async fn create_authorization(
        &self,
        request: DeviceAuthorizationRequest,
    ) -> Result<DeviceAuthorizationResponse> {
        self.validate_request(&request)?;
        let device_code = format!("dc_{}", Uuid::new_v4().simple());
        let user_code = self.generate_user_code();
        let now = SystemTime::now();
        let stored = StoredDeviceAuthorization {
            device_code: device_code.clone(),
            user_code: user_code.clone(),
            client_id: request.client_id,
            scope: request.scope,
            status: DeviceAuthorizationStatus::Pending,
            user_id: None,
            created_at: now,
            expires_at: now + self.default_expiration,
            last_poll: None,
            slow_down_count: 0,
        };
        let serialized = Self::serialize(&stored)?;
        let ttl = Some(self.default_expiration);
        self.storage
            .store_kv(&Self::device_key(&device_code), serialized.as_bytes(), ttl)
            .await
            .map_err(|e| {
                AuthError::internal(format!("Failed to store device authorization: {}", e))
            })?;
        self.storage
            .store_kv(&Self::user_key(&user_code), serialized.as_bytes(), ttl)
            .await
            .map_err(|e| {
                AuthError::internal(format!("Failed to store user code mapping: {}", e))
            })?;
        {
            let mut authorizations = self.authorizations.write().await;
            authorizations.insert(device_code.clone(), stored);
            self.cleanup_expired(&mut authorizations, now);
        }
        // Append with `&` when the configured URI already carries a query string.
        let separator = if self.verification_uri.contains('?') { '&' } else { '?' };
        let verification_uri_complete = format!(
            "{}{}user_code={}",
            self.verification_uri, separator, user_code
        );
        Ok(DeviceAuthorizationResponse {
            device_code,
            user_code,
            verification_uri: self.verification_uri.clone(),
            verification_uri_complete: Some(verification_uri_complete),
            interval: self.min_interval.as_secs(),
            expires_in: self.default_expiration.as_secs(),
        })
    }

    /// Handles one device poll, enforcing the polling interval.
    ///
    /// The effective interval is `min_interval` plus 5 seconds per previous
    /// `slow_down`. Every non-expired poll records `last_poll` (best-effort in
    /// storage, always in memory).
    ///
    /// # Errors
    /// Returns `auth_method("device_auth", …)` errors whose message is one of
    /// `Invalid device_code`, `expired_token`, `slow_down`,
    /// `authorization_pending` or `access_denied`. Storage read, decode and
    /// serialization failures are returned as internal errors.
    pub async fn poll_authorization(&self, device_code: &str) -> Result<StoredDeviceAuthorization> {
        let mut stored = self
            .find_by_device_code(device_code)
            .await?
            .ok_or_else(|| AuthError::auth_method("device_auth", "Invalid device_code"))?;
        let now = SystemTime::now();
        if now > stored.expires_at {
            // Same RFC 8628 error code that the `Expired` status arm below reports.
            return Err(AuthError::auth_method("device_auth", "expired_token"));
        }
        let effective_interval = self.min_interval
            + Duration::from_secs(5 * u64::from(stored.slow_down_count));
        let polled_too_fast = stored.last_poll.is_some_and(|last_poll| {
            now.duration_since(last_poll).unwrap_or(Duration::ZERO) < effective_interval
        });
        if polled_too_fast {
            stored.slow_down_count = stored.slow_down_count.saturating_add(1);
        }
        stored.last_poll = Some(now);
        let serialized = Self::serialize(&stored)?;
        // Best-effort: a failed write only loses rate-limit bookkeeping, so it
        // must not turn a valid poll into an error.
        let _ = self
            .storage
            .store_kv(
                &Self::device_key(device_code),
                serialized.as_bytes(),
                Some(Self::remaining_ttl(&stored, now)),
            )
            .await;
        self.authorizations
            .write()
            .await
            .insert(device_code.to_string(), stored.clone());
        if polled_too_fast {
            return Err(AuthError::auth_method("device_auth", "slow_down"));
        }
        match stored.status {
            DeviceAuthorizationStatus::Pending => Err(AuthError::auth_method(
                "device_auth",
                "authorization_pending",
            )),
            DeviceAuthorizationStatus::Authorized => Ok(stored),
            DeviceAuthorizationStatus::Denied => {
                Err(AuthError::auth_method("device_auth", "access_denied"))
            }
            DeviceAuthorizationStatus::Expired => {
                Err(AuthError::auth_method("device_auth", "expired_token"))
            }
        }
    }

    /// Marks the authorization behind `user_code` as approved by `user_id`.
    ///
    /// The user code is matched case-insensitively, ignoring separators.
    ///
    /// # Errors
    /// Fails with `Invalid user_code` when no record exists, `Device code expired`
    /// after `expires_at`, or when storage or serialization fails.
    pub async fn authorize_device(&self, user_code: &str, user_id: &str) -> Result<()> {
        let mut stored = self.resolve_user_code(user_code).await?;
        let now = SystemTime::now();
        if now > stored.expires_at {
            return Err(AuthError::auth_method("device_auth", "Device code expired"));
        }
        stored.status = DeviceAuthorizationStatus::Authorized;
        stored.user_id = Some(user_id.to_string());
        self.save_decision(stored, now).await
    }

    /// Marks the authorization behind `user_code` as denied.
    ///
    /// # Errors
    /// Fails with `Invalid user_code` when no record exists, or when storage or
    /// serialization fails.
    pub async fn deny_device(&self, user_code: &str) -> Result<()> {
        let mut stored = self.resolve_user_code(user_code).await?;
        stored.status = DeviceAuthorizationStatus::Denied;
        self.save_decision(stored, SystemTime::now()).await
    }

    /// Looks up the authorization for a user code, e.g. to render a consent page.
    ///
    /// # Errors
    /// Fails with `Invalid user_code` when no record exists or `User code expired`
    /// after `expires_at`; storage and decode failures are internal errors.
    pub async fn get_by_user_code(&self, user_code: &str) -> Result<StoredDeviceAuthorization> {
        let stored = self
            .load(&Self::user_key(user_code))
            .await?
            .ok_or_else(|| AuthError::auth_method("device_auth", "Invalid user_code"))?;
        if SystemTime::now() > stored.expires_at {
            return Err(AuthError::auth_method("device_auth", "User code expired"));
        }
        Ok(stored)
    }

    /// Rejects requests without a client id.
    fn validate_request(&self, request: &DeviceAuthorizationRequest) -> Result<()> {
        if request.client_id.is_empty() {
            return Err(AuthError::auth_method("device_auth", "Missing client_id"));
        }
        Ok(())
    }

    /// Generates a user code of the form `XXXX-XXXX` from an alphabet without
    /// the easily confused characters `I`, `O`, `0` and `1`.
    fn generate_user_code(&self) -> String {
        use rand::RngExt;
        const CHARS: &[u8] = b"ABCDEFGHJKLMNPQRSTUVWXYZ23456789";
        let mut rng = rand::rng();
        (0..=Self::USER_CODE_LEN)
            .map(|i| {
                if i == Self::USER_CODE_SEPARATOR_AT {
                    '-'
                } else {
                    CHARS[rng.random_range(0..CHARS.len())] as char
                }
            })
            .collect()
    }

    /// Drops in-memory records whose `expires_at` has passed.
    fn cleanup_expired(
        &self,
        authorizations: &mut HashMap<String, StoredDeviceAuthorization>,
        now: SystemTime,
    ) {
        authorizations.retain(|_, auth| now <= auth.expires_at);
    }

    /// Storage key of the record indexed by device code.
    fn device_key(device_code: &str) -> String {
        format!("device_code:{}", device_code)
    }

    /// Storage key of the record indexed by (normalized) user code.
    fn user_key(user_code: &str) -> String {
        format!("user_code:{}", Self::normalize_user_code(user_code))
    }

    /// Canonicalizes user input to the generated `XXXX-XXXX` form.
    ///
    /// Case is ignored, as are dashes, spaces and other non-alphanumerics (as
    /// RFC 8628 §6.1 recommends). Input that does not contain exactly
    /// [`Self::USER_CODE_LEN`] alphanumerics is returned unchanged.
    fn normalize_user_code(user_code: &str) -> String {
        let compact: String = user_code
            .chars()
            .filter(char::is_ascii_alphanumeric)
            .map(|c| c.to_ascii_uppercase())
            .collect();
        if compact.len() == Self::USER_CODE_LEN {
            // `compact` is pure ASCII, so byte-index slicing is on char boundaries.
            let (head, tail) = compact.split_at(Self::USER_CODE_SEPARATOR_AT);
            format!("{}-{}", head, tail)
        } else {
            user_code.to_string()
        }
    }

    /// TTL for rewriting `stored`.
    ///
    /// This is the time left until `expires_at`, so a rewrite does not extend the
    /// record's storage lifetime. It is floored at [`Self::MIN_STORAGE_TTL`] so
    /// that a nearly expired record is never written with a zero TTL.
    fn remaining_ttl(stored: &StoredDeviceAuthorization, now: SystemTime) -> Duration {
        stored
            .expires_at
            .duration_since(now)
            .unwrap_or(Duration::ZERO)
            .max(Self::MIN_STORAGE_TTL)
    }

    /// Serializes a record to JSON.
    fn serialize(stored: &StoredDeviceAuthorization) -> Result<String> {
        serde_json::to_string(stored)
            .map_err(|e| AuthError::internal(format!("Failed to serialize device auth: {}", e)))
    }

    /// Reads and decodes the record stored under `key`; `Ok(None)` if the key is absent.
    async fn load(&self, key: &str) -> Result<Option<StoredDeviceAuthorization>> {
        let Some(data) = self.storage.get_kv(key).await? else {
            return Ok(None);
        };
        let serialized = String::from_utf8(data)
            .map_err(|_| AuthError::internal("Invalid UTF-8 in stored device auth data"))?;
        let stored = serde_json::from_str::<StoredDeviceAuthorization>(&serialized).map_err(|e| {
            AuthError::internal(format!("Failed to deserialize device auth: {}", e))
        })?;
        Ok(Some(stored))
    }

    /// Finds a record by device code: storage first, then the in-memory map.
    async fn find_by_device_code(
        &self,
        device_code: &str,
    ) -> Result<Option<StoredDeviceAuthorization>> {
        if let Some(stored) = self.load(&Self::device_key(device_code)).await? {
            return Ok(Some(stored));
        }
        Ok(self.authorizations.read().await.get(device_code).cloned())
    }

    /// Resolves a user code to the up-to-date record for its device code.
    ///
    /// The user-code entry is only rewritten on authorize/deny, so its
    /// `last_poll`/`slow_down_count` are stale. Writing it back verbatim would
    /// erase the rate-limit state that polling records under the device-code key.
    async fn resolve_user_code(&self, user_code: &str) -> Result<StoredDeviceAuthorization> {
        let snapshot = self
            .load(&Self::user_key(user_code))
            .await?
            .ok_or_else(|| AuthError::auth_method("device_auth", "Invalid user_code"))?;
        Ok(self
            .find_by_device_code(&snapshot.device_code)
            .await?
            .unwrap_or(snapshot))
    }

    /// Persists an authorize/deny decision under both keys and updates the memory copy.
    async fn save_decision(&self, stored: StoredDeviceAuthorization, now: SystemTime) -> Result<()> {
        let serialized = Self::serialize(&stored)?;
        let ttl = Some(Self::remaining_ttl(&stored, now));
        self.storage
            .store_kv(&Self::device_key(&stored.device_code), serialized.as_bytes(), ttl)
            .await?;
        self.storage
            .store_kv(&Self::user_key(&stored.user_code), serialized.as_bytes(), ttl)
            .await?;
        self.authorizations
            .write()
            .await
            .insert(stored.device_code.clone(), stored);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryStorage;
    use tokio::time::sleep;

    /// Verification URI shared by every manager built in these tests.
    const VERIFICATION_URI: &str = "https://example.com/device";

    /// Builds a manager with default settings over a fresh in-memory store.
    fn create_test_manager() -> DeviceAuthManager {
        DeviceAuthManager::new(Arc::new(MemoryStorage::new()), VERIFICATION_URI.to_string())
    }

    /// Builds a request from the shared test client with the given scope.
    fn test_request(scope: Option<&str>) -> DeviceAuthorizationRequest {
        DeviceAuthorizationRequest {
            client_id: "test_client".to_string(),
            scope: scope.map(String::from),
        }
    }

    /// Fails the test unless `result` is an error whose message contains `needle`.
    fn assert_err_contains<T>(result: Result<T>, needle: &str) {
        match result {
            Ok(_) => panic!("expected an error containing {:?}, got Ok", needle),
            Err(err) => {
                let message = err.to_string();
                assert!(
                    message.contains(needle),
                    "error {:?} does not contain {:?}",
                    message,
                    needle
                );
            }
        }
    }

    #[tokio::test]
    async fn test_create_authorization() {
        let manager = create_test_manager();
        let response = manager
            .create_authorization(test_request(Some("openid profile")))
            .await
            .unwrap();
        assert!(response.device_code.starts_with("dc_"));
        assert_eq!(response.user_code.len(), 9);
        assert!(response.user_code.contains('-'));
        assert_eq!(response.verification_uri, VERIFICATION_URI);
        assert!(response.verification_uri_complete.is_some());
        assert_eq!(response.interval, 5);
        assert_eq!(response.expires_in, 600);
    }

    #[tokio::test]
    async fn test_poll_pending() {
        let manager = create_test_manager();
        let response = manager.create_authorization(test_request(None)).await.unwrap();
        let outcome = manager.poll_authorization(&response.device_code).await;
        assert_err_contains(outcome, "authorization_pending");
    }

    #[tokio::test]
    async fn test_authorize_and_poll() {
        let manager = create_test_manager();
        let response = manager
            .create_authorization(test_request(Some("openid")))
            .await
            .unwrap();
        manager
            .authorize_device(&response.user_code, "user_123")
            .await
            .unwrap();
        let record = manager
            .poll_authorization(&response.device_code)
            .await
            .unwrap();
        assert_eq!(record.status, DeviceAuthorizationStatus::Authorized);
        assert_eq!(record.user_id, Some("user_123".to_string()));
    }

    #[tokio::test]
    async fn test_deny_device() {
        let manager = create_test_manager();
        let response = manager.create_authorization(test_request(None)).await.unwrap();
        manager.deny_device(&response.user_code).await.unwrap();
        let outcome = manager.poll_authorization(&response.device_code).await;
        assert_err_contains(outcome, "access_denied");
    }

    #[tokio::test]
    async fn test_slow_down() {
        let manager = create_test_manager();
        let response = manager.create_authorization(test_request(None)).await.unwrap();
        // The first poll only records `last_poll`; the immediate second one is too fast.
        let _ = manager.poll_authorization(&response.device_code).await;
        let outcome = manager.poll_authorization(&response.device_code).await;
        assert_err_contains(outcome, "slow_down");
    }

    #[tokio::test]
    async fn test_expiration() {
        let manager = DeviceAuthManager::with_settings(
            Arc::new(MemoryStorage::new()),
            VERIFICATION_URI.to_string(),
            Duration::from_millis(100),
            Duration::from_secs(1),
        );
        let response = manager.create_authorization(test_request(None)).await.unwrap();
        sleep(Duration::from_millis(150)).await;
        let outcome = manager.poll_authorization(&response.device_code).await;
        assert_err_contains(outcome, "expired");
    }

    #[tokio::test]
    async fn test_chainable_expiration_and_interval() {
        let manager = create_test_manager()
            .expiration(Duration::from_secs(300))
            .interval(Duration::from_secs(10));
        let response = manager.create_authorization(test_request(None)).await.unwrap();
        assert_eq!(response.expires_in, 300);
        assert_eq!(response.interval, 10);
    }
}