oauth-device-flows 0.1.0

A specialized Rust library implementing the OAuth 2.0 Device Authorization Grant (RFC 8628).
Documentation
//! Main device flow implementation

use crate::{
    config::DeviceFlowConfig,
    error::{DeviceFlowError, Result},
    provider::Provider,
    token::TokenManager,
    types::{
        AuthorizationResponse, DeviceAuthorizationRequest, DeviceTokenRequest, ErrorResponse,
        TokenResponse,
    },
};
use reqwest::Client;
use secrecy::ExposeSecret;
use std::time::Duration;
use time::OffsetDateTime;
use tokio::time::sleep;
use url::Url;

/// Main device flow implementation.
///
/// Drives one OAuth 2.0 Device Authorization Grant (RFC 8628) against a single [`Provider`]:
/// [`DeviceFlow::initialize`] requests a device code, [`DeviceFlow::poll_for_token`] polls the
/// token endpoint with it, and [`DeviceFlow::run`] performs both steps.
#[derive(Debug, Clone)]
pub struct DeviceFlow {
    /// OAuth provider whose endpoints and headers are used (its endpoints are overridden when
    /// the configuration carries a generic provider configuration).
    provider: Provider,

    /// Flow configuration; validated against `provider` whenever it is installed.
    config: DeviceFlowConfig,

    /// HTTP client built from `config` (request timeout and optional user agent).
    client: Client,

    /// Response of the most recent successful [`DeviceFlow::initialize`] call;
    /// `None` before initialization or after [`DeviceFlow::reset`].
    auth_response: Option<AuthorizationResponse>,
}

impl DeviceFlow {
    /// Create a new device flow instance
    pub fn new(provider: Provider, config: DeviceFlowConfig) -> Result<Self> {
        // Validate configuration
        config.validate(provider)?;

        let client = Self::build_client(&config)?;

        Ok(Self {
            provider,
            config,
            client,
            auth_response: None,
        })
    }

    /// Initialize the device authorization flow
    pub async fn initialize(&mut self) -> Result<&AuthorizationResponse> {
        let auth_endpoint = if let Some(ref config) = self.config.generic_provider_config {
            config.device_authorization_endpoint.clone()
        } else {
            Url::parse(self.provider.device_authorization_endpoint()).map_err(|e| {
                DeviceFlowError::other(format!("Invalid authorization endpoint: {e}"))
            })?
        };

        let scopes = self.config.effective_scopes(self.provider);
        let scope_string = if scopes.is_empty() {
            None
        } else {
            Some(scopes.join(" "))
        };

        let request = DeviceAuthorizationRequest {
            client_id: self.config.client_id.clone(),
            scope: scope_string,
        };

        let mut req_builder = self.client.post(auth_endpoint).form(&request);

        // Add client secret if required
        if let Some(ref client_secret) = self.config.client_secret {
            req_builder = req_builder.form(&[("client_secret", client_secret.expose_secret())]);
        }

        // Add provider-specific headers
        for (key, value) in self.provider.headers() {
            req_builder = req_builder.header(key, value);
        }

        // Add additional headers
        for (key, value) in &self.config.additional_headers {
            req_builder = req_builder.header(key, value);
        }

        let response = req_builder.send().await?;

        if !response.status().is_success() {
            let error_response: ErrorResponse = response.json().await?;
            return Err(DeviceFlowError::oauth_error(
                error_response.error,
                error_response.error_description.unwrap_or_default(),
            ));
        }

        let auth_response: AuthorizationResponse = response.json().await?;
        self.auth_response = Some(auth_response);

        Ok(self.auth_response.as_ref().unwrap())
    }

    /// Poll for the token
    pub async fn poll_for_token(&self) -> Result<TokenResponse> {
        let auth_response = self
            .auth_response
            .as_ref()
            .ok_or_else(|| DeviceFlowError::other("Must call initialize() first"))?;

        let token_endpoint = if let Some(ref config) = self.config.generic_provider_config {
            config.token_endpoint.clone()
        } else {
            Url::parse(self.provider.token_endpoint())
                .map_err(|e| DeviceFlowError::other(format!("Invalid token endpoint: {e}")))?
        };

        let request = DeviceTokenRequest {
            grant_type: "urn:ietf:params:oauth:grant-type:device_code".to_string(),
            device_code: auth_response.device_code().to_string(),
            client_id: self.config.client_id.clone(),
        };

        let mut poll_interval = self.config.effective_poll_interval(self.provider);
        let mut attempts = 0;

        loop {
            if attempts >= self.config.max_attempts {
                return Err(DeviceFlowError::MaxAttemptsExceeded(
                    self.config.max_attempts,
                ));
            }

            attempts += 1;

            // Wait before making the request (except for the first attempt)
            if attempts > 1 {
                sleep(poll_interval).await;
            }

            let mut req_builder = self.client.post(token_endpoint.clone()).form(&request);

            // Add client secret if required
            if let Some(ref client_secret) = self.config.client_secret {
                req_builder = req_builder.form(&[("client_secret", client_secret.expose_secret())]);
            }

            // Add provider-specific headers
            for (key, value) in self.provider.headers() {
                req_builder = req_builder.header(key, value);
            }

            // Add additional headers
            for (key, value) in &self.config.additional_headers {
                req_builder = req_builder.header(key, value);
            }

            let response = req_builder.send().await?;

            if response.status().is_success() {
                let mut token_response: TokenResponse = response.json().await?;

                // Set the issued_at timestamp
                token_response.issued_at = OffsetDateTime::now_utc();

                return Ok(token_response);
            }

            // Handle error responses
            let error_response: ErrorResponse = response.json().await?;

            match error_response.error.as_str() {
                "authorization_pending" => {
                    // Continue polling
                    continue;
                }
                "slow_down" => {
                    // Increase polling interval
                    poll_interval = Duration::from_secs(
                        (poll_interval.as_secs() as f64 * self.config.backoff_multiplier) as u64,
                    )
                    .min(self.config.max_poll_interval);
                    continue;
                }
                "access_denied" => {
                    return Err(DeviceFlowError::AuthorizationDenied);
                }
                "expired_token" => {
                    return Err(DeviceFlowError::ExpiredToken);
                }
                _ => {
                    return Err(DeviceFlowError::oauth_error(
                        error_response.error,
                        error_response.error_description.unwrap_or_default(),
                    ));
                }
            }
        }
    }

    /// Run the complete device flow and return a token manager
    pub async fn run(&mut self) -> Result<TokenManager> {
        let _auth_response = self.initialize().await?;
        let token_response = self.poll_for_token().await?;

        TokenManager::new(token_response, self.provider, self.config.clone())
    }

    /// Get the current authorization response
    pub fn authorization_response(&self) -> Option<&AuthorizationResponse> {
        self.auth_response.as_ref()
    }

    /// Get the provider
    pub fn provider(&self) -> Provider {
        self.provider
    }

    /// Get the configuration
    pub fn config(&self) -> &DeviceFlowConfig {
        &self.config
    }

    /// Check if the device flow has been initialized
    pub fn is_initialized(&self) -> bool {
        self.auth_response.is_some()
    }

    /// Reset the device flow (clear authorization response)
    pub fn reset(&mut self) {
        self.auth_response = None;
    }

    /// Create a new device flow for a different provider with the same config
    pub fn with_provider(self, provider: Provider) -> Result<Self> {
        Self::new(provider, self.config)
    }

    /// Update the configuration
    pub fn with_config(mut self, config: DeviceFlowConfig) -> Result<Self> {
        config.validate(self.provider)?;
        self.client = Self::build_client(&config)?;
        self.config = config;
        Ok(self)
    }

    /// Build HTTP client with configuration
    fn build_client(config: &DeviceFlowConfig) -> Result<Client> {
        let mut client_builder = Client::builder().timeout(config.request_timeout);

        if let Some(ref user_agent) = config.user_agent {
            client_builder = client_builder.user_agent(user_agent);
        }

        client_builder.build().map_err(DeviceFlowError::from)
    }
}

/// Convenience function to run a complete device flow.
///
/// Builds a [`DeviceFlow`] for `provider` and drives it to completion; equivalent to
/// [`DeviceFlow::new`] followed by [`DeviceFlow::run`].
///
/// # Errors
///
/// Propagates any error from configuration validation, client construction, device
/// authorization, token polling, or token-manager construction.
pub async fn run_device_flow(provider: Provider, config: DeviceFlowConfig) -> Result<TokenManager> {
    DeviceFlow::new(provider, config)?.run().await
}

/// Convenience function to run a device flow with a callback for user interaction.
///
/// After the device authorization request succeeds, `callback` receives the authorization
/// response before token polling starts. If the callback returns an error, the flow stops
/// and that error is returned.
///
/// # Errors
///
/// Propagates errors from flow construction, initialization, the callback, token polling,
/// and token-manager construction.
pub async fn run_device_flow_with_callback<F>(
    provider: Provider,
    config: DeviceFlowConfig,
    callback: F,
) -> Result<TokenManager>
where
    F: FnOnce(&AuthorizationResponse) -> Result<()>,
{
    let mut flow = DeviceFlow::new(provider, config)?;

    // Hand the authorization response to the caller before any polling happens.
    callback(flow.initialize().await?)?;

    let tokens = flow.poll_for_token().await?;
    TokenManager::new(tokens, provider, flow.config)
}