oauth-device-flows 0.1.0

A specialized Rust library implementing the OAuth 2.0 Device Authorization Grant (RFC 8628).
Documentation
//! Configuration for OAuth device flows

use crate::provider::{GenericProviderConfig, Provider};
use secrecy::Secret;
use std::time::Duration;
use url::Url;

/// Configuration for the device flow
///
/// Build one with [`DeviceFlowConfig::new`] (or `Default`) plus the chained setter methods,
/// then check it with [`DeviceFlowConfig::validate`] for the target provider. Settings left
/// at their defaults are resolved against the provider (or the generic provider
/// configuration) by the `effective_*`, `should_use_pkce` and `requires_client_secret`
/// helpers.
#[derive(Debug, Clone)]
pub struct DeviceFlowConfig {
    /// OAuth client ID
    ///
    /// Starts empty; `validate` rejects an empty ID.
    pub client_id: String,

    /// OAuth client secret (optional, required for some providers)
    ///
    /// `validate` returns an error when the provider (or the generic provider configuration)
    /// requires a secret and this is `None`. It is wrapped in `secrecy::Secret`, presumably
    /// so the derived `Debug` output does not reveal it.
    /// NOTE(review): confirm this against the `secrecy` version in use.
    pub client_secret: Option<Secret<String>>,

    /// Requested scopes
    ///
    /// When empty, `effective_scopes` falls back to the generic provider configuration's
    /// default scopes (if one is set) or else to the provider's default scopes.
    pub scopes: Vec<String>,

    /// Custom redirect URI (optional)
    pub redirect_uri: Option<Url>,

    /// Polling interval
    ///
    /// Defaults to 5 seconds. Caveat: `effective_poll_interval` treats exactly 5 seconds as
    /// "not customised" and substitutes the generic provider configuration's or the
    /// provider's default interval. An explicit 5-second setting is therefore not honoured.
    pub poll_interval: Duration,

    /// Maximum number of polling attempts
    ///
    /// Defaults to 60; `validate` rejects 0.
    pub max_attempts: u32,

    /// Exponential backoff multiplier for polling
    ///
    /// Defaults to 1.1; `validate` rejects values below 1.0.
    pub backoff_multiplier: f64,

    /// Maximum polling interval (for backoff)
    ///
    /// Defaults to 30 seconds. It is not read in this module; presumably it is the cap the
    /// poller applies to the backed-off interval.
    /// TODO: confirm against the polling code.
    pub max_poll_interval: Duration,

    /// Custom user agent for HTTP requests
    ///
    /// Defaults to `Some("oauth-device-flows/<crate version>")`.
    pub user_agent: Option<String>,

    /// Additional headers to send with requests
    ///
    /// Keyed by header name. Setting the same name twice through `header` keeps the last
    /// value.
    pub additional_headers: std::collections::HashMap<String, String>,

    /// Configuration for generic providers
    ///
    /// Required when the provider is `Provider::Generic` (enforced by `validate`). Whenever
    /// it is set, its defaults take precedence over the built-in provider's defaults,
    /// whichever provider is passed to the helpers. This covers scopes, poll interval, PKCE
    /// support and the client-secret requirement.
    pub generic_provider_config: Option<GenericProviderConfig>,

    /// Timeout for HTTP requests
    ///
    /// Defaults to 30 seconds.
    pub request_timeout: Duration,

    /// Whether to use PKCE (Proof Key for Code Exchange)
    ///
    /// `None` (the default) means auto-detect: `should_use_pkce` uses the generic provider
    /// configuration's setting if present, otherwise the provider's. `Some(_)` forces the
    /// choice.
    pub use_pkce: Option<bool>,
}

impl DeviceFlowConfig {
    /// Poll interval installed by [`DeviceFlowConfig::new`].
    ///
    /// [`DeviceFlowConfig::effective_poll_interval`] treats a `poll_interval` equal to this
    /// value as "not customised", so both places must agree. Sharing one constant keeps them
    /// from drifting apart. 5 seconds is also the RFC 8628 default polling interval.
    const DEFAULT_POLL_INTERVAL: Duration = Duration::from_secs(5);

    /// Create a new configuration with defaults.
    ///
    /// The client ID starts empty, so it must be set before the configuration can pass
    /// [`DeviceFlowConfig::validate`]. Optional settings left unset (scopes, PKCE, the poll
    /// interval) are resolved per provider by the `effective_*` and `should_use_pkce`
    /// helpers.
    #[must_use]
    pub fn new() -> Self {
        Self {
            client_id: String::new(),
            client_secret: None,
            scopes: Vec::new(),
            redirect_uri: None,
            poll_interval: Self::DEFAULT_POLL_INTERVAL,
            // About 5 minutes at a constant 5-second interval; longer if backoff grows it.
            max_attempts: 60,
            backoff_multiplier: 1.1,
            max_poll_interval: Duration::from_secs(30),
            user_agent: Some(format!("oauth-device-flows/{}", env!("CARGO_PKG_VERSION"))),
            additional_headers: std::collections::HashMap::new(),
            generic_provider_config: None,
            request_timeout: Duration::from_secs(30),
            use_pkce: None, // Auto-detect based on provider
        }
    }

    /// Set the client ID.
    #[must_use]
    pub fn client_id(mut self, client_id: impl Into<String>) -> Self {
        self.client_id = client_id.into();
        self
    }

    /// Set the client secret; it is stored wrapped in `Secret`.
    #[must_use]
    pub fn client_secret(mut self, client_secret: impl Into<String>) -> Self {
        self.client_secret = Some(Secret::new(client_secret.into()));
        self
    }

    /// Set the scopes, replacing any previously configured ones.
    #[must_use]
    pub fn scopes<I, S>(mut self, scopes: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.scopes = scopes.into_iter().map(Into::into).collect();
        self
    }

    /// Append a single scope to those already configured.
    #[must_use]
    pub fn scope(mut self, scope: impl Into<String>) -> Self {
        self.scopes.push(scope.into());
        self
    }

    /// Set the redirect URI.
    ///
    /// # Errors
    ///
    /// Returns the [`url::ParseError`] produced by [`Url::parse`] if `uri` is not a valid
    /// absolute URL. In that case the configuration is consumed.
    pub fn redirect_uri(mut self, uri: impl Into<String>) -> Result<Self, url::ParseError> {
        self.redirect_uri = Some(Url::parse(&uri.into())?);
        Ok(self)
    }

    /// Set the polling interval.
    ///
    /// Note: exactly 5 seconds is indistinguishable from "not set". In that case
    /// [`DeviceFlowConfig::effective_poll_interval`] uses the provider's default instead.
    #[must_use]
    pub fn poll_interval(mut self, interval: Duration) -> Self {
        self.poll_interval = interval;
        self
    }

    /// Set the maximum number of polling attempts (must be > 0 to pass `validate`).
    #[must_use]
    pub fn max_attempts(mut self, attempts: u32) -> Self {
        self.max_attempts = attempts;
        self
    }

    /// Set the exponential backoff multiplier (must be finite and >= 1.0 to pass `validate`).
    #[must_use]
    pub fn backoff_multiplier(mut self, multiplier: f64) -> Self {
        self.backoff_multiplier = multiplier;
        self
    }

    /// Set the maximum polling interval.
    #[must_use]
    pub fn max_poll_interval(mut self, interval: Duration) -> Self {
        self.max_poll_interval = interval;
        self
    }

    /// Set a custom user agent.
    #[must_use]
    pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
        self.user_agent = Some(user_agent.into());
        self
    }

    /// Add an additional header; a later call with the same key overwrites the value.
    #[must_use]
    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.additional_headers.insert(key.into(), value.into());
        self
    }

    /// Set the generic provider configuration.
    #[must_use]
    pub fn generic_provider(mut self, config: GenericProviderConfig) -> Self {
        self.generic_provider_config = Some(config);
        self
    }

    /// Set the HTTP request timeout.
    #[must_use]
    pub fn request_timeout(mut self, timeout: Duration) -> Self {
        self.request_timeout = timeout;
        self
    }

    /// Force PKCE on or off, overriding provider auto-detection.
    #[must_use]
    pub fn use_pkce(mut self, use_pkce: bool) -> Self {
        self.use_pkce = Some(use_pkce);
        self
    }

    /// Get the effective scopes for a provider.
    ///
    /// Precedence: explicitly configured scopes, then the generic provider configuration's
    /// `default_scopes` (whenever one is set, regardless of `provider`), then the provider's
    /// built-in defaults.
    pub fn effective_scopes(&self, provider: Provider) -> Vec<String> {
        if !self.scopes.is_empty() {
            self.scopes.clone()
        } else if let Some(ref config) = self.generic_provider_config {
            config.default_scopes.clone()
        } else {
            provider
                .default_scopes()
                .into_iter()
                .map(|s| s.to_string())
                .collect()
        }
    }

    /// Get the effective polling interval for a provider.
    ///
    /// A `poll_interval` other than the 5-second default wins. Otherwise the generic provider
    /// configuration's `default_poll_interval` is used when one is set, and failing that the
    /// provider's default.
    ///
    /// Caveat: "unset" is detected by comparing against the default value. Explicitly
    /// configuring exactly 5 seconds is therefore indistinguishable from not configuring it.
    pub fn effective_poll_interval(&self, provider: Provider) -> Duration {
        if self.poll_interval != Self::DEFAULT_POLL_INTERVAL {
            self.poll_interval
        } else if let Some(ref config) = self.generic_provider_config {
            config.default_poll_interval
        } else {
            provider.default_poll_interval()
        }
    }

    /// Check if PKCE should be used for a provider.
    ///
    /// An explicit `use_pkce` setting wins. Otherwise the generic provider configuration's
    /// `supports_pkce` is used when one is set, else the provider's own support.
    pub fn should_use_pkce(&self, provider: Provider) -> bool {
        self.use_pkce.unwrap_or_else(|| {
            if let Some(config) = &self.generic_provider_config {
                config.supports_pkce
            } else {
                provider.supports_pkce()
            }
        })
    }

    /// Check if a client secret is required for a provider.
    ///
    /// The generic provider configuration's `requires_client_secret` takes precedence when
    /// one is set; otherwise the provider's own requirement is used.
    pub fn requires_client_secret(&self, provider: Provider) -> bool {
        if let Some(config) = &self.generic_provider_config {
            config.requires_client_secret
        } else {
            provider.requires_client_secret()
        }
    }

    /// Validate the configuration for a specific provider.
    ///
    /// # Errors
    ///
    /// Returns `DeviceFlowError::invalid_client` with a descriptive message when:
    /// - the client ID is empty or only whitespace;
    /// - `provider` is `Provider::Generic` but no generic provider configuration is set;
    /// - a client secret is required but not configured;
    /// - `max_attempts` is 0;
    /// - `backoff_multiplier` is not a finite number >= 1.0 (this includes NaN and infinity).
    pub fn validate(&self, provider: Provider) -> Result<(), crate::error::DeviceFlowError> {
        use crate::error::DeviceFlowError;

        // A whitespace-only ID is as unusable as an empty one.
        if self.client_id.trim().is_empty() {
            return Err(DeviceFlowError::invalid_client("Client ID is required"));
        }

        // Checked before the secret requirement: for the generic provider that requirement
        // comes from the generic configuration. Without it, the secret check would be
        // answered from a fallback and could report a misleading error.
        if provider == Provider::Generic && self.generic_provider_config.is_none() {
            return Err(DeviceFlowError::invalid_client(
                "Generic provider configuration is required for generic provider",
            ));
        }

        // Check if client secret is required but not provided
        if self.requires_client_secret(provider) && self.client_secret.is_none() {
            return Err(DeviceFlowError::invalid_client(format!(
                "Client secret is required for {provider}"
            )));
        }

        // Validate polling configuration
        if self.max_attempts == 0 {
            return Err(DeviceFlowError::invalid_client(
                "Max attempts must be greater than 0",
            ));
        }

        // A plain `< 1.0` test lets NaN through, because NaN compares false against
        // everything. Infinity would also pass it. Neither can produce a finite
        // backed-off interval, so both are rejected here.
        if !self.backoff_multiplier.is_finite() || self.backoff_multiplier < 1.0 {
            return Err(DeviceFlowError::invalid_client(
                "Backoff multiplier must be finite and >= 1.0",
            ));
        }

        Ok(())
    }
}

impl Default for DeviceFlowConfig {
    fn default() -> Self {
        Self::new()
    }
}