use crate::{
config::DeviceFlowConfig,
error::{DeviceFlowError, Result},
provider::Provider,
token::TokenManager,
types::{
AuthorizationResponse, DeviceAuthorizationRequest, DeviceTokenRequest, ErrorResponse,
TokenResponse,
},
};
use reqwest::Client;
use secrecy::ExposeSecret;
use std::time::Duration;
use time::OffsetDateTime;
use tokio::time::sleep;
use url::Url;
/// Drives the OAuth 2.0 device authorization grant against one provider.
///
/// Holds the validated configuration, the HTTP client built from it, and —
/// once [`DeviceFlow::initialize`] has succeeded — the authorization response
/// whose device code is later used when polling the token endpoint.
#[derive(Clone, Debug)]
pub struct DeviceFlow {
    /// Provider supplying endpoints and extra headers (endpoints may be
    /// overridden by a generic provider config).
    provider: Provider,
    /// Configuration validated against `provider` before being stored.
    config: DeviceFlowConfig,
    /// HTTP client built from `config` (request timeout, optional user agent).
    client: Client,
    /// Device authorization response; `None` until `initialize` succeeds
    /// or after `reset`.
    auth_response: Option<AuthorizationResponse>,
}
impl DeviceFlow {
pub fn new(provider: Provider, config: DeviceFlowConfig) -> Result<Self> {
config.validate(provider)?;
let client = Self::build_client(&config)?;
Ok(Self {
provider,
config,
client,
auth_response: None,
})
}
pub async fn initialize(&mut self) -> Result<&AuthorizationResponse> {
let auth_endpoint = if let Some(ref config) = self.config.generic_provider_config {
config.device_authorization_endpoint.clone()
} else {
Url::parse(self.provider.device_authorization_endpoint()).map_err(|e| {
DeviceFlowError::other(format!("Invalid authorization endpoint: {e}"))
})?
};
let scopes = self.config.effective_scopes(self.provider);
let scope_string = if scopes.is_empty() {
None
} else {
Some(scopes.join(" "))
};
let request = DeviceAuthorizationRequest {
client_id: self.config.client_id.clone(),
scope: scope_string,
};
let mut req_builder = self.client.post(auth_endpoint).form(&request);
if let Some(ref client_secret) = self.config.client_secret {
req_builder = req_builder.form(&[("client_secret", client_secret.expose_secret())]);
}
for (key, value) in self.provider.headers() {
req_builder = req_builder.header(key, value);
}
for (key, value) in &self.config.additional_headers {
req_builder = req_builder.header(key, value);
}
let response = req_builder.send().await?;
if !response.status().is_success() {
let error_response: ErrorResponse = response.json().await?;
return Err(DeviceFlowError::oauth_error(
error_response.error,
error_response.error_description.unwrap_or_default(),
));
}
let auth_response: AuthorizationResponse = response.json().await?;
self.auth_response = Some(auth_response);
Ok(self.auth_response.as_ref().unwrap())
}
pub async fn poll_for_token(&self) -> Result<TokenResponse> {
let auth_response = self
.auth_response
.as_ref()
.ok_or_else(|| DeviceFlowError::other("Must call initialize() first"))?;
let token_endpoint = if let Some(ref config) = self.config.generic_provider_config {
config.token_endpoint.clone()
} else {
Url::parse(self.provider.token_endpoint())
.map_err(|e| DeviceFlowError::other(format!("Invalid token endpoint: {e}")))?
};
let request = DeviceTokenRequest {
grant_type: "urn:ietf:params:oauth:grant-type:device_code".to_string(),
device_code: auth_response.device_code().to_string(),
client_id: self.config.client_id.clone(),
};
let mut poll_interval = self.config.effective_poll_interval(self.provider);
let mut attempts = 0;
loop {
if attempts >= self.config.max_attempts {
return Err(DeviceFlowError::MaxAttemptsExceeded(
self.config.max_attempts,
));
}
attempts += 1;
if attempts > 1 {
sleep(poll_interval).await;
}
let mut req_builder = self.client.post(token_endpoint.clone()).form(&request);
if let Some(ref client_secret) = self.config.client_secret {
req_builder = req_builder.form(&[("client_secret", client_secret.expose_secret())]);
}
for (key, value) in self.provider.headers() {
req_builder = req_builder.header(key, value);
}
for (key, value) in &self.config.additional_headers {
req_builder = req_builder.header(key, value);
}
let response = req_builder.send().await?;
if response.status().is_success() {
let mut token_response: TokenResponse = response.json().await?;
token_response.issued_at = OffsetDateTime::now_utc();
return Ok(token_response);
}
let error_response: ErrorResponse = response.json().await?;
match error_response.error.as_str() {
"authorization_pending" => {
continue;
}
"slow_down" => {
poll_interval = Duration::from_secs(
(poll_interval.as_secs() as f64 * self.config.backoff_multiplier) as u64,
)
.min(self.config.max_poll_interval);
continue;
}
"access_denied" => {
return Err(DeviceFlowError::AuthorizationDenied);
}
"expired_token" => {
return Err(DeviceFlowError::ExpiredToken);
}
_ => {
return Err(DeviceFlowError::oauth_error(
error_response.error,
error_response.error_description.unwrap_or_default(),
));
}
}
}
}
pub async fn run(&mut self) -> Result<TokenManager> {
let _auth_response = self.initialize().await?;
let token_response = self.poll_for_token().await?;
TokenManager::new(token_response, self.provider, self.config.clone())
}
pub fn authorization_response(&self) -> Option<&AuthorizationResponse> {
self.auth_response.as_ref()
}
pub fn provider(&self) -> Provider {
self.provider
}
pub fn config(&self) -> &DeviceFlowConfig {
&self.config
}
pub fn is_initialized(&self) -> bool {
self.auth_response.is_some()
}
pub fn reset(&mut self) {
self.auth_response = None;
}
pub fn with_provider(self, provider: Provider) -> Result<Self> {
Self::new(provider, self.config)
}
pub fn with_config(mut self, config: DeviceFlowConfig) -> Result<Self> {
config.validate(self.provider)?;
self.client = Self::build_client(&config)?;
self.config = config;
Ok(self)
}
fn build_client(config: &DeviceFlowConfig) -> Result<Client> {
let mut client_builder = Client::builder().timeout(config.request_timeout);
if let Some(ref user_agent) = config.user_agent {
client_builder = client_builder.user_agent(user_agent);
}
client_builder.build().map_err(DeviceFlowError::from)
}
}
/// Runs the whole device flow for `provider` with `config` and returns a
/// [`TokenManager`] for the issued token.
///
/// This is shorthand for [`DeviceFlow::new`] followed by [`DeviceFlow::run`].
/// It never hands the authorization response to the caller; use
/// [`run_device_flow_with_callback`] when the caller needs to inspect it
/// before polling starts.
///
/// # Errors
///
/// Propagates any error from [`DeviceFlow::new`] or [`DeviceFlow::run`].
pub async fn run_device_flow(provider: Provider, config: DeviceFlowConfig) -> Result<TokenManager> {
    DeviceFlow::new(provider, config)?.run().await
}
/// Runs the device flow and calls `callback` with the authorization response
/// after initialization but before token polling begins.
///
/// If `callback` returns an error, polling never starts and that error is
/// returned as-is.
///
/// # Errors
///
/// Propagates errors from flow construction, initialization, the callback,
/// token polling, and [`TokenManager::new`].
pub async fn run_device_flow_with_callback<F>(
    provider: Provider,
    config: DeviceFlowConfig,
    callback: F,
) -> Result<TokenManager>
where
    F: FnOnce(&AuthorizationResponse) -> Result<()>,
{
    let mut flow = DeviceFlow::new(provider, config)?;
    // The mutable borrow from `initialize` ends when the callback returns,
    // so polling can borrow the flow again afterwards.
    callback(flow.initialize().await?)?;
    let tokens = flow.poll_for_token().await?;
    TokenManager::new(tokens, provider, flow.config)
}