use authkestra_core::{AuthError, OAuthToken};
use serde::{Deserialize, Serialize};
use std::future::Future;
use std::sync::{Arc, Mutex};
use std::task::{Poll, Waker};
use std::thread::sleep;
use std::time::Duration;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceAuthorizationResponse {
pub device_code: String,
pub user_code: String,
pub verification_uri: String,
pub verification_uri_complete: Option<String>,
pub expires_in: u64,
pub interval: Option<u64>,
}
/// Client for the OAuth 2.0 Device Authorization Grant (RFC 8628).
///
/// Holds the client identifier, the two provider endpoints the flow talks to,
/// and the HTTP client used for every request.
#[derive(Debug, Clone)]
pub struct DeviceFlow {
    /// OAuth client identifier sent with every request.
    client_id: String,
    /// Endpoint that issues device and user codes.
    device_authorization_url: String,
    /// Endpoint polled to exchange the device code for a token.
    token_url: String,
    /// HTTP client shared by all requests made by this instance.
    http_client: reqwest::Client,
}
impl DeviceFlow {
pub fn new(client_id: String, device_authorization_url: String, token_url: String) -> Self {
Self {
client_id,
device_authorization_url,
token_url,
http_client: reqwest::Client::new(),
}
}
pub async fn initiate_device_authorization(
&self,
scopes: &[&str],
) -> Result<DeviceAuthorizationResponse, AuthError> {
let scope_param = scopes.join(" ");
let response = self
.http_client
.post(&self.device_authorization_url)
.header("Accept", "application/json")
.form(&[("client_id", &self.client_id), ("scope", &scope_param)])
.send()
.await
.map_err(|_| AuthError::Network)?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(AuthError::Provider(format!(
"Device authorization request failed: {}",
error_text
)));
}
response
.json::<DeviceAuthorizationResponse>()
.await
.map_err(|e| {
AuthError::Provider(format!(
"Failed to parse device authorization response: {}",
e
))
})
}
pub async fn poll_for_token(
&self,
device_code: &str,
interval: Option<u64>,
) -> Result<OAuthToken, AuthError> {
let mut current_interval = interval.unwrap_or(5);
loop {
let response = self
.http_client
.post(&self.token_url)
.header("Accept", "application/json")
.form(&[
("client_id", &self.client_id),
("device_code", &device_code.to_string()),
(
"grant_type",
&"urn:ietf:params:oauth:grant-type:device_code".to_string(),
),
])
.send()
.await
.map_err(|_| AuthError::Network)?;
let status = response.status();
if status.is_success() {
return response.json::<OAuthToken>().await.map_err(|e| {
AuthError::Provider(format!("Failed to parse token response: {}", e))
});
} else {
let error_resp: serde_json::Value = response
.json()
.await
.map_err(|_| AuthError::Provider("Failed to parse error response".into()))?;
let error = error_resp["error"].as_str().unwrap_or("unknown_error");
match error {
"authorization_pending" => {
}
"slow_down" => {
current_interval += 5;
}
"access_denied" => {
return Err(AuthError::Provider("Access denied by user".into()));
}
"expired_token" => {
return Err(AuthError::Provider("Device code expired".into()));
}
_ => {
return Err(AuthError::Provider(format!(
"Token polling failed: {}",
error
)));
}
}
}
sleep(Duration::from_secs(current_interval));
}
}
}