oauth-device-flows 0.1.0

A specialized Rust library implementing the OAuth 2.0 Device Authorization Grant (RFC 8628).
Documentation
//! OAuth provider configurations and implementations

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use url::Url;

/// OAuth providers with built-in device-flow settings.
///
/// Each variant except [`Provider::Generic`] has hard-coded endpoints, scopes and
/// headers exposed through the inherent methods on this type. The variant can be
/// parsed from a name via `FromStr` and is serialized with serde's derived
/// representation (renaming or reordering variants may therefore change the
/// serialized form).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Provider {
    /// Microsoft identity platform (Azure AD / Microsoft Entra), `common` tenant.
    Microsoft,
    /// Google OAuth 2.0.
    Google,
    /// GitHub OAuth apps.
    GitHub,
    /// GitLab (gitlab.com).
    GitLab,
    /// A provider without built-in settings; its endpoint accessors return empty
    /// strings, so endpoints must be supplied elsewhere (e.g. via
    /// `GenericProviderConfig`).
    Generic,
}

impl Provider {
    /// Get the device authorization endpoint for this provider.
    ///
    /// This is the URL the client POSTs to in order to obtain a `device_code` /
    /// `user_code` pair (RFC 8628 §3.1).
    ///
    /// Returns an empty string for [`Provider::Generic`]; in that case the endpoint
    /// must be configured manually.
    pub fn device_authorization_endpoint(&self) -> &'static str {
        match self {
            Provider::Microsoft => {
                "https://login.microsoftonline.com/common/oauth2/v2.0/devicecode"
            }
            Provider::Google => "https://oauth2.googleapis.com/device/code",
            Provider::GitHub => "https://github.com/login/device/code",
            // GitLab publishes its device authorization endpoint at
            // `/oauth/authorize_device` (the `device_authorization_endpoint` in its
            // OpenID discovery document), not at `/oauth/device/code`.
            Provider::GitLab => "https://gitlab.com/oauth/authorize_device",
            Provider::Generic => "", // Must be configured manually
        }
    }

    /// Get the token endpoint for this provider.
    ///
    /// The client polls this URL with the `device_code` until the user has
    /// authorized the request (RFC 8628 §3.4).
    ///
    /// Returns an empty string for [`Provider::Generic`].
    pub fn token_endpoint(&self) -> &'static str {
        match self {
            Provider::Microsoft => "https://login.microsoftonline.com/common/oauth2/v2.0/token",
            Provider::Google => "https://oauth2.googleapis.com/token",
            Provider::GitHub => "https://github.com/login/oauth/access_token",
            Provider::GitLab => "https://gitlab.com/oauth/token",
            Provider::Generic => "", // Must be configured manually
        }
    }

    /// Get the scopes requested when the caller does not specify any.
    ///
    /// Returns an empty list for [`Provider::Generic`].
    pub fn default_scopes(&self) -> Vec<&'static str> {
        match self {
            Provider::Microsoft => vec!["https://graph.microsoft.com/.default"],
            Provider::Google => vec!["openid", "email", "profile"],
            Provider::GitHub => vec!["user:email"],
            Provider::GitLab => vec!["read_user"],
            Provider::Generic => vec![], // Must be configured manually
        }
    }

    /// Get the default polling interval for this provider.
    ///
    /// Every provider currently uses 5 seconds, which is also the interval RFC 8628
    /// §3.2 mandates when the server omits `interval`. The exhaustive match keeps a
    /// compile error in place if a new variant is added, so its interval gets chosen
    /// deliberately.
    pub fn default_poll_interval(&self) -> std::time::Duration {
        match self {
            Provider::Microsoft
            | Provider::Google
            | Provider::GitHub
            | Provider::GitLab
            | Provider::Generic => std::time::Duration::from_secs(5),
        }
    }

    /// Get provider-specific HTTP headers to send with device-flow requests.
    ///
    /// Every provider gets `Accept: application/json`; GitHub additionally gets a
    /// `User-Agent` header.
    pub fn headers(&self) -> HashMap<&'static str, &'static str> {
        let mut headers = HashMap::new();

        // Ask for JSON everywhere: GitHub's token endpoint, for example, answers
        // with a form-encoded body unless JSON is explicitly requested.
        headers.insert("Accept", "application/json");

        match self {
            Provider::GitHub => {
                headers.insert("User-Agent", "oauth-device-flows");
            }
            Provider::Microsoft | Provider::Google | Provider::GitLab | Provider::Generic => {}
        }

        headers
    }

    /// Get the human-readable name for this provider.
    pub fn display_name(&self) -> &'static str {
        match self {
            Provider::Microsoft => "Microsoft",
            Provider::Google => "Google",
            Provider::GitHub => "GitHub",
            Provider::GitLab => "GitLab",
            Provider::Generic => "Generic",
        }
    }

    /// Get the page where the user enters the `user_code`.
    ///
    /// This is a static fallback; when the device authorization response contains
    /// a `verification_uri` (RFC 8628 §3.2), that value should be preferred.
    /// Returns an empty string for [`Provider::Generic`].
    pub fn verification_uri_format(&self) -> &'static str {
        match self {
            Provider::Microsoft => "https://microsoft.com/devicelogin",
            Provider::Google => "https://www.google.com/device",
            Provider::GitHub => "https://github.com/login/device",
            // GitLab's user-code entry page, not the `/-/user_settings/applications`
            // page (which is where OAuth applications are registered).
            Provider::GitLab => "https://gitlab.com/oauth/device",
            Provider::Generic => "", // Must be configured manually
        }
    }

    /// Check whether this library treats the provider as supporting PKCE.
    ///
    /// `true` only for Microsoft and Google.
    pub fn supports_pkce(&self) -> bool {
        matches!(self, Provider::Microsoft | Provider::Google)
    }

    /// Check whether this library sends a `client_secret` for the provider.
    ///
    /// `true` only for Google and GitLab.
    pub fn requires_client_secret(&self) -> bool {
        matches!(self, Provider::Google | Provider::GitLab)
    }
}

/// Device-flow settings for an OAuth provider that has no built-in [`Provider`]
/// variant.
///
/// Build one with [`GenericProviderConfig::new`] and the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenericProviderConfig {
    /// URL the client POSTs to in order to start the device flow.
    pub device_authorization_endpoint: Url,

    /// URL the client polls for tokens.
    pub token_endpoint: Url,

    /// Scopes requested when the caller does not specify any.
    pub default_scopes: Vec<String>,

    /// Polling interval used when the server does not provide one.
    pub default_poll_interval: std::time::Duration,

    /// Extra HTTP headers (name -> value) to send with requests.
    pub headers: HashMap<String, String>,

    /// Human-readable provider name.
    pub display_name: String,

    /// Whether the provider is treated as supporting PKCE.
    pub supports_pkce: bool,

    /// Whether a `client_secret` must be sent.
    pub requires_client_secret: bool,
}

impl GenericProviderConfig {
    /// Create a configuration from the two endpoints and a display name.
    ///
    /// Starts with no default scopes, no extra headers, a 5-second poll interval,
    /// and both `supports_pkce` and `requires_client_secret` set to `false`.
    pub fn new(
        device_authorization_endpoint: Url,
        token_endpoint: Url,
        display_name: String,
    ) -> Self {
        let default_poll_interval = std::time::Duration::from_secs(5);

        Self {
            display_name,
            device_authorization_endpoint,
            token_endpoint,
            default_poll_interval,
            default_scopes: Vec::new(),
            headers: HashMap::new(),
            supports_pkce: false,
            requires_client_secret: false,
        }
    }

    /// Replace the default scopes.
    pub fn with_default_scopes(self, scopes: Vec<String>) -> Self {
        Self {
            default_scopes: scopes,
            ..self
        }
    }

    /// Replace the default polling interval.
    pub fn with_poll_interval(self, interval: std::time::Duration) -> Self {
        Self {
            default_poll_interval: interval,
            ..self
        }
    }

    /// Add (or overwrite) a single extra header.
    pub fn with_header(mut self, key: String, value: String) -> Self {
        self.headers.insert(key, value);
        self
    }

    /// Set whether the provider supports PKCE.
    pub fn with_pkce(self, supports_pkce: bool) -> Self {
        Self {
            supports_pkce,
            ..self
        }
    }

    /// Set whether a `client_secret` is required.
    pub fn with_client_secret_required(self, required: bool) -> Self {
        Self {
            requires_client_secret: required,
            ..self
        }
    }
}

impl std::fmt::Display for Provider {
    /// Formats the provider as its human-readable display name.
    ///
    /// Uses `Formatter::pad` so width, fill, alignment and precision flags are
    /// honoured. For example, `format!("{:<8}|", Provider::GitHub)` yields
    /// `"GitHub  |"`, whereas `write!(f, "{}", ..)` would silently ignore them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.display_name())
    }
}

impl std::str::FromStr for Provider {
    type Err = crate::error::DeviceFlowError;

    /// Parses a provider name, ignoring case and surrounding whitespace.
    ///
    /// Accepted names:
    /// - `microsoft`, with aliases `azure` and `azuread`;
    /// - `google`;
    /// - `github`;
    /// - `gitlab`;
    /// - `generic`.
    ///
    /// # Errors
    ///
    /// Returns `DeviceFlowError::UnsupportedProvider` carrying the original
    /// (untrimmed) input when the name is not recognised.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Trim so values read from CLI args, env vars or config files with stray
        // whitespace still parse. ASCII case folding is sufficient because every
        // accepted name is ASCII.
        let normalized = s.trim().to_ascii_lowercase();

        match normalized.as_str() {
            "microsoft" | "azure" | "azuread" => Ok(Provider::Microsoft),
            "google" => Ok(Provider::Google),
            "github" => Ok(Provider::GitHub),
            "gitlab" => Ok(Provider::GitLab),
            "generic" => Ok(Provider::Generic),
            _ => Err(crate::error::DeviceFlowError::UnsupportedProvider(
                s.to_string(),
            )),
        }
    }
}