use super::OAuthClient;
use crate::error::{Error, Result};
use crate::token::{ErrorResponse, Token, TokenResponse};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
/// Response from a device authorization endpoint (RFC 8628, section 3.2).
///
/// Produced by [`DeviceFlow::request_device_authorization`]. Derives both
/// `Deserialize` (to parse the provider response) and `Serialize`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct DeviceAuthorization {
    /// Device verification code, sent back as `device_code` when polling the
    /// token endpoint.
    pub device_code: String,
    /// Code the end user enters at the verification URI.
    pub user_code: String,
    /// URI the end user visits to approve the request.
    ///
    /// Some providers (e.g. Google) name this field `verification_url`. That
    /// spelling is accepted as an alias when deserializing. Serialization
    /// always emits `verification_uri`. A response containing both keys is
    /// rejected by serde as a duplicate field.
    #[serde(alias = "verification_url")]
    pub verification_uri: String,
    /// Optional URI with the user code already embedded. It is omitted from
    /// serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub verification_uri_complete: Option<String>,
    /// Lifetime of `device_code` and `user_code`, in seconds.
    pub expires_in: u32,
    /// Minimum number of seconds to wait between token polls. Falls back to
    /// `default_interval()` when the provider omits the field.
    #[serde(default = "default_interval")]
    pub interval: u32,
}
/// Polling interval, in seconds, used when a device authorization response
/// has no `interval` field.
///
/// RFC 8628 section 3.2 requires clients to fall back to 5 seconds.
const fn default_interval() -> u32 {
    const RFC8628_DEFAULT_INTERVAL_SECS: u32 = 5;
    RFC8628_DEFAULT_INTERVAL_SECS
}
/// Drives the OAuth 2.0 Device Authorization Grant (RFC 8628) on behalf of a
/// single [`OAuthClient`].
#[derive(Debug)]
pub struct DeviceFlow {
/// Client whose provider endpoints, `client_id` and HTTP client are used for
/// every request this flow makes.
client: OAuthClient,
}
impl DeviceFlow {
    /// Creates a device-flow driver that sends every request through `client`.
    #[must_use]
    pub const fn new(client: OAuthClient) -> Self {
        Self { client }
    }

    /// Starts the flow by requesting a device code and user code from the
    /// provider's device authorization endpoint (RFC 8628, section 3.1).
    ///
    /// `scopes` overrides the provider's default scopes. The chosen scopes are
    /// joined with single spaces. The `scope` parameter is left out entirely
    /// when that string is empty.
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidConfig`] if the provider has no device authorization URL.
    /// - The error derived from the provider's [`ErrorResponse`] body when the
    ///   HTTP status is not a success.
    /// - HTTP transport and JSON decoding errors, converted into [`Error`].
    pub async fn request_device_authorization(
        &self,
        scopes: Option<&[String]>,
    ) -> Result<DeviceAuthorization> {
        let device_auth_url = self
            .client
            .provider
            .device_auth_url
            .as_ref()
            .ok_or_else(|| {
                Error::InvalidConfig(format!(
                    "Provider {} does not support device flow",
                    self.client.provider.name
                ))
            })?;

        let scope_str = scopes.map_or_else(
            || self.client.provider.default_scopes.join(" "),
            |s| s.join(" "),
        );

        let mut params = HashMap::new();
        params.insert("client_id", self.client.client_id.as_str());
        if !scope_str.is_empty() {
            params.insert("scope", scope_str.as_str());
        }

        let response = self
            .client
            .http_client
            .post(device_auth_url.clone())
            .form(&params)
            .send()
            .await?;

        if !response.status().is_success() {
            let error: ErrorResponse = response.json().await?;
            return Err(error.into_error());
        }

        response.json().await.map_err(Into::into)
    }

    /// Waits `interval`, then makes a single token request for `device_code`
    /// (RFC 8628, section 3.4).
    ///
    /// The sleep comes *before* the request, so even the first poll is delayed
    /// by `interval`.
    ///
    /// # Errors
    ///
    /// - An OAuth error built with `Error::oauth_error` and the code
    ///   `authorization_pending` or `slow_down` while the user has not yet
    ///   approved. [`Self::poll_until_authorized`] treats these as retryable.
    /// - [`Error::AccessDenied`] when the provider answers `access_denied`.
    /// - [`Error::TokenExpired`] when the provider answers `expired_token`.
    /// - Any other provider error, plus transport and decoding errors.
    pub async fn poll_for_token(&self, device_code: &str, interval: Duration) -> Result<Token> {
        tokio::time::sleep(interval).await;

        let mut params = HashMap::new();
        params.insert("grant_type", "urn:ietf:params:oauth:grant-type:device_code");
        params.insert("device_code", device_code);
        params.insert("client_id", self.client.client_id.as_str());

        let response = self
            .client
            .http_client
            .post(self.client.provider.token_url.clone())
            .form(&params)
            .send()
            .await?;

        if !response.status().is_success() {
            let error: ErrorResponse = response.json().await?;
            return match error.error.as_str() {
                "authorization_pending" => Err(Error::oauth_error(
                    "authorization_pending",
                    "User has not yet authorized",
                )),
                "slow_down" => Err(Error::oauth_error(
                    "slow_down",
                    "Polling too frequently, slow down",
                )),
                "access_denied" => Err(Error::AccessDenied),
                "expired_token" => Err(Error::TokenExpired),
                _ => Err(error.into_error()),
            };
        }

        let token_response: TokenResponse = response.json().await?;
        Token::from_response(token_response)
    }

    /// Polls the token endpoint for `auth.device_code` until the user approves
    /// or the flow fails.
    ///
    /// Call this after [`Self::request_device_authorization`], once
    /// `auth.user_code` and `auth.verification_uri` have been shown to the user.
    ///
    /// Polling starts at `auth.interval` seconds. Each `slow_down` response
    /// adds 5 seconds to the interval for all later polls.
    ///
    /// `max_attempts` caps the number of polls. A value of `0` means there is no
    /// client-side cap: the loop then ends only when the provider returns a
    /// token or a non-retryable error (such as `expired_token`).
    ///
    /// # Errors
    ///
    /// - [`Error::Timeout`], built from `auth.expires_in`, once `max_attempts`
    ///   (non-zero) polls have all returned `authorization_pending` or `slow_down`.
    /// - Any other error returned by [`Self::poll_for_token`].
    pub async fn poll_until_authorized(
        &self,
        auth: &DeviceAuthorization,
        max_attempts: usize,
    ) -> Result<Token> {
        let mut interval = Duration::from_secs(u64::from(auth.interval));
        let mut attempts = 0;
        loop {
            if max_attempts > 0 && attempts >= max_attempts {
                return Err(Error::Timeout(auth.expires_in.into()));
            }
            match self.poll_for_token(&auth.device_code, interval).await {
                Ok(token) => return Ok(token),
                Err(Error::OAuth { ref error, .. }) if error == "authorization_pending" => {
                    attempts += 1;
                }
                Err(Error::OAuth { ref error, .. }) if error == "slow_down" => {
                    // RFC 8628 section 3.5: the interval MUST grow by 5 seconds
                    // for this and all subsequent requests.
                    interval += Duration::from_secs(5);
                    attempts += 1;
                }
                Err(e) => return Err(e),
            }
        }
    }

    /// Runs the whole device flow: requests authorization, then polls until
    /// the user approves or the flow fails.
    ///
    /// The returned [`DeviceAuthorization`], which holds the user code the end
    /// user must enter, only becomes available after polling finishes. For
    /// interactive use, call [`Self::request_device_authorization`], display the
    /// code, then call [`Self::poll_until_authorized`].
    ///
    /// `max_attempts` has the same meaning as in [`Self::poll_until_authorized`].
    ///
    /// # Errors
    ///
    /// Any error from [`Self::request_device_authorization`] or
    /// [`Self::poll_until_authorized`].
    pub async fn authorize(
        &self,
        scopes: Option<&[String]>,
        max_attempts: usize,
    ) -> Result<(DeviceAuthorization, Token)> {
        let auth = self.request_device_authorization(scopes).await?;
        let token = self.poll_until_authorized(&auth, max_attempts).await?;
        Ok((auth, token))
    }
}
#[cfg(test)]
// Test code relaxes several pedantic/nursery Clippy lints (e.g. `unwrap_used`)
// that guard library code but only add noise to fixtures and assertions.
#[allow(clippy::unwrap_used, clippy::redundant_clone, clippy::manual_string_new, clippy::needless_collect, clippy::unreadable_literal, clippy::used_underscore_items, clippy::similar_names)]
mod tests {
    use super::*;
    use crate::provider::Provider;

    /// Builds a populated authorization value for serialization tests.
    fn sample_auth(complete: Option<&str>) -> DeviceAuthorization {
        DeviceAuthorization {
            device_code: "dev123".to_string(),
            user_code: "USER-CODE".to_string(),
            verification_uri: "https://example.com/device".to_string(),
            verification_uri_complete: complete.map(str::to_string),
            expires_in: 1800,
            interval: 7,
        }
    }

    #[test]
    fn test_device_flow_creation() {
        let provider = Provider::google().unwrap();
        let client = OAuthClient::new("test_client", provider);
        let _flow = DeviceFlow::new(client);
    }

    #[test]
    fn test_default_interval() {
        assert_eq!(default_interval(), 5);
    }

    #[test]
    fn test_device_auth_deserialization() {
        let json = r#"{
            "device_code": "dev123",
            "user_code": "USER-CODE",
            "verification_uri": "https://example.com/device",
            "expires_in": 1800,
            "interval": 5
        }"#;
        let auth: DeviceAuthorization = serde_json::from_str(json).unwrap();
        assert_eq!(auth.device_code, "dev123");
        assert_eq!(auth.user_code, "USER-CODE");
        assert_eq!(auth.verification_uri, "https://example.com/device");
        assert_eq!(auth.expires_in, 1800);
        assert_eq!(auth.interval, 5);
        assert!(auth.verification_uri_complete.is_none());
    }

    #[test]
    fn test_device_auth_explicit_interval_is_parsed() {
        // A non-default value proves the field is read rather than defaulted.
        let json = r#"{
            "device_code": "dev123",
            "user_code": "USER-CODE",
            "verification_uri": "https://example.com/device",
            "expires_in": 600,
            "interval": 10
        }"#;
        let auth: DeviceAuthorization = serde_json::from_str(json).unwrap();
        assert_eq!(auth.interval, 10);
        assert_eq!(auth.expires_in, 600);
    }

    #[test]
    fn test_device_auth_interval_defaults_when_missing() {
        let json = r#"{
            "device_code": "dev123",
            "user_code": "USER-CODE",
            "verification_uri": "https://example.com/device",
            "expires_in": 1800
        }"#;
        let auth: DeviceAuthorization = serde_json::from_str(json).unwrap();
        assert_eq!(auth.interval, default_interval());
    }

    #[test]
    fn test_device_auth_serialization_omits_missing_complete_uri() {
        let value = serde_json::to_value(sample_auth(None)).unwrap();
        assert!(value.get("verification_uri_complete").is_none());
        assert_eq!(
            value.get("interval").and_then(serde_json::Value::as_u64),
            Some(7)
        );
    }

    #[test]
    fn test_device_auth_round_trip_with_complete_uri() {
        let complete = "https://example.com/device?code=USER-CODE";
        let json = serde_json::to_string(&sample_auth(Some(complete))).unwrap();
        let back: DeviceAuthorization = serde_json::from_str(&json).unwrap();
        assert_eq!(back.verification_uri_complete.as_deref(), Some(complete));
        assert_eq!(back.device_code, "dev123");
        assert_eq!(back.interval, 7);
    }
}