use super::error::{OAuthError, OAuthResult};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
#[derive(Debug, Clone, Deserialize)]
pub struct DeviceCodeResponse {
pub device_code: String,
pub user_code: String,
pub verification_uri: String,
pub expires_in: u64,
pub interval: u64,
}
/// Token handed back to the caller once the device flow has completed.
///
/// Note that the derived `Debug` output contains `access_token` verbatim.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DeviceTokenResponse {
    /// The access token issued by the provider.
    pub access_token: String,
    /// Token type as reported by the provider; `poll_for_token` stores an
    /// empty string when the provider omits it.
    pub token_type: String,
    /// Scope string as reported by the provider; empty when absent.
    #[serde(default)]
    pub scope: String,
}
/// Snapshot of a device flow: either still waiting on the user or finished.
///
/// Nothing in this file constructs the enum; it is exposed for callers.
#[derive(Clone, Debug)]
pub enum DeviceFlowState {
    /// Authorisation has not completed yet; carries what must be shown to
    /// the user.
    Pending {
        /// Code the user has to enter.
        user_code: String,
        /// Location at which the user enters the code.
        verification_uri: String,
    },
    /// The flow finished and yielded a token.
    Completed(DeviceTokenResponse),
}
/// Wire-level shape of a token endpoint poll response.
///
/// Success and error documents share one struct: every field is optional
/// (or defaulted), so whichever document the provider sends deserializes and
/// the caller decides which case it is by inspecting `error` first.
#[derive(Deserialize, Debug)]
struct PollRaw {
    /// Present on a successful poll.
    access_token: Option<String>,
    /// Token type accompanying `access_token`, if the provider sent one.
    token_type: Option<String>,
    /// Granted scope; empty when the provider omits it.
    #[serde(default)]
    scope: String,
    /// Machine-readable error code such as `authorization_pending`.
    error: Option<String>,
    /// Human-readable detail accompanying `error`, if any.
    error_description: Option<String>,
}
/// Client side of the OAuth 2.0 device authorization flow for one OAuth app.
///
/// Holds the app identity, the requested scopes, the two provider endpoints
/// and the HTTP client used to talk to them.
pub struct DeviceFlow {
    /// OAuth client identifier sent as `client_id` with every request.
    client_id: String,
    /// Requested scopes; joined with a single space when sent.
    scopes: Vec<String>,
    /// Endpoint that issues device and user codes.
    device_code_url: String,
    /// Endpoint that is polled for the access token.
    token_url: String,
    /// HTTP client shared by all requests of this flow.
    client: Client,
}
impl DeviceFlow {
    /// Creates a device flow for the given OAuth client and provider endpoints.
    ///
    /// * `client_id` – OAuth client identifier sent with every request.
    /// * `scopes` – scopes to request; joined with spaces when sent.
    /// * `device_code_url` – endpoint that issues device/user codes.
    /// * `token_url` – endpoint that is polled for the access token.
    ///
    /// # Errors
    ///
    /// Returns `OAuthError::token_exchange_failed` wrapping the underlying
    /// error when `crate::tls_client::create_tls_client` cannot build the
    /// HTTP client from the default `TlsClientConfig`.
    pub fn new(
        client_id: impl Into<String>,
        scopes: Vec<String>,
        device_code_url: impl Into<String>,
        token_url: impl Into<String>,
    ) -> OAuthResult<Self> {
        let client =
            crate::tls_client::create_tls_client(crate::tls_client::TlsClientConfig::default())
                .map_err(OAuthError::token_exchange_failed)?;
        Ok(Self {
            client_id: client_id.into(),
            scopes,
            device_code_url: device_code_url.into(),
            token_url: token_url.into(),
            client,
        })
    }

    /// Requests a fresh device code / user code pair from the provider.
    ///
    /// Sends a form-encoded POST with `client_id` and the space-joined
    /// `scope` to the device code endpoint, asking for a JSON response.
    ///
    /// # Errors
    ///
    /// * the transport error converted by `?` when the request cannot be sent;
    /// * `token_exchange_failed` carrying the HTTP status and body when the
    ///   provider answers with a non-success status;
    /// * `token_exchange_failed` when the body is not a valid
    ///   `DeviceCodeResponse`.
    pub async fn request_device_code(&self) -> OAuthResult<DeviceCodeResponse> {
        let scope = self.scopes.join(" ");
        let response = self
            .client
            .post(&self.device_code_url)
            .header("Accept", "application/json")
            .form(&[
                ("client_id", self.client_id.as_str()),
                ("scope", scope.as_str()),
            ])
            .send()
            .await?;

        let status = response.status();
        if !status.is_success() {
            // Best effort: an unreadable body is reported as an empty string.
            let body = response.text().await.unwrap_or_default();
            return Err(OAuthError::token_exchange_failed(format!(
                "Device code request failed: HTTP {} — {}",
                status, body
            )));
        }

        response.json::<DeviceCodeResponse>().await.map_err(|e| {
            OAuthError::token_exchange_failed(format!(
                "Failed to parse device code response: {}",
                e
            ))
        })
    }

    /// Polls the token endpoint until the user has authorised the device,
    /// the provider reports a terminal error, or the device code expires.
    ///
    /// The first poll is sent immediately; after every non-terminal answer
    /// the task sleeps for the current interval (initially
    /// `device_code.interval` seconds). A `slow_down` answer raises the
    /// interval by five seconds; `authorization_pending` keeps it unchanged.
    ///
    /// Providers differ in how they deliver error codes: some answer with
    /// HTTP 200 and an `error` field, while RFC 8628 §3.5 / RFC 6749 §5.2
    /// specify an HTTP 4xx error document. Both forms are understood.
    ///
    /// # Errors
    ///
    /// * the transport error converted by `?` when a request cannot be sent;
    /// * `token_exchange_failed` when the local deadline derived from
    ///   `device_code.expires_in` has passed, when the provider reports any
    ///   error code other than `authorization_pending` / `slow_down`, when
    ///   the provider answers with a status that carries no OAuth error
    ///   code, or when a response cannot be parsed.
    pub async fn poll_for_token(
        &self,
        device_code: &DeviceCodeResponse,
    ) -> OAuthResult<DeviceTokenResponse> {
        let mut interval_secs = device_code.interval;

        // `Instant + Duration` panics on overflow. `expires_in` is supplied
        // by the server, so an absurd value must not crash the client: on
        // overflow there is simply no local deadline and the provider's own
        // `expired_token` answer ends the loop instead.
        let expires_at =
            std::time::Instant::now().checked_add(Duration::from_secs(device_code.expires_in));

        loop {
            if let Some(deadline) = expires_at {
                if std::time::Instant::now() >= deadline {
                    return Err(OAuthError::token_exchange_failed(
                        "Device code expired before the user completed authorisation",
                    ));
                }
            }

            let response = self
                .client
                .post(&self.token_url)
                .header("Accept", "application/json")
                .form(&[
                    ("client_id", self.client_id.as_str()),
                    ("device_code", device_code.device_code.as_str()),
                    ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
                ])
                .send()
                .await?;

            let status = response.status();
            let poll_raw: PollRaw = if status.is_success() {
                response.json().await.map_err(|e| {
                    OAuthError::token_exchange_failed(format!(
                        "Failed to parse token poll response: {}",
                        e
                    ))
                })?
            } else if status.is_client_error() {
                // Standards-compliant providers report `authorization_pending`
                // and friends as a 4xx error document, so the body has to be
                // inspected before giving up. Decoding consumes the body,
                // hence the raw text is not available for the fallback message.
                match response.json::<PollRaw>().await {
                    Ok(raw) if raw.error.is_some() => raw,
                    _ => {
                        return Err(OAuthError::token_exchange_failed(format!(
                            "Token polling failed: HTTP {} — response did not contain an OAuth error code",
                            status
                        )));
                    }
                }
            } else {
                // Best effort: an unreadable body is reported as an empty string.
                let body = response.text().await.unwrap_or_default();
                return Err(OAuthError::token_exchange_failed(format!(
                    "Token polling failed: HTTP {} — {}",
                    status, body
                )));
            };

            if let Some(ref err) = poll_raw.error {
                match err.as_str() {
                    // The user has not finished yet: keep polling.
                    "authorization_pending" => {}
                    // Polling too fast: back off by five seconds from now on.
                    "slow_down" => {
                        interval_secs = interval_secs.saturating_add(5);
                    }
                    "access_denied" => {
                        return Err(OAuthError::token_exchange_failed(
                            "User denied the authorisation request",
                        ));
                    }
                    "expired_token" | "token_expired" => {
                        return Err(OAuthError::token_exchange_failed("Device code expired"));
                    }
                    "unsupported_grant_type" => {
                        return Err(OAuthError::token_exchange_failed(
                            "Unsupported grant type — grant_type must be \
                             urn:ietf:params:oauth:grant-type:device_code",
                        ));
                    }
                    "incorrect_client_credentials" => {
                        return Err(OAuthError::token_exchange_failed(
                            "Incorrect client credentials — check the client_id",
                        ));
                    }
                    "incorrect_device_code" => {
                        return Err(OAuthError::token_exchange_failed(
                            "The device_code provided is not valid",
                        ));
                    }
                    "device_flow_disabled" => {
                        return Err(OAuthError::token_exchange_failed(
                            "Device flow is not enabled for this OAuth app",
                        ));
                    }
                    other => {
                        return Err(OAuthError::token_exchange_failed(format!(
                            "Unexpected error from provider: {} — {}",
                            other,
                            poll_raw.error_description.as_deref().unwrap_or("")
                        )));
                    }
                }
            } else if let Some(access_token) = poll_raw.access_token {
                let token_type = poll_raw.token_type.unwrap_or_default();
                return Ok(DeviceTokenResponse {
                    access_token,
                    token_type,
                    scope: poll_raw.scope,
                });
            } else {
                return Err(OAuthError::token_exchange_failed(
                    "Token poll response contained neither an error nor an access token",
                ));
            }

            // Only the two non-terminal answers reach this point.
            tokio::time::sleep(Duration::from_secs(interval_secs)).await;
        }
    }
}