use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use url::Url;
/// OAuth 2.0 identity providers with built-in device authorization flow settings.
///
/// The per-provider endpoints, scopes, headers and flags are exposed through the
/// accessor methods on this type; a provider can be parsed from its name via
/// `str::parse` (see the `FromStr` impl).
///
/// NOTE(review): keep the variant order stable — serde's derived impls pass the
/// variant index to the serializer, and index-based formats (e.g. bincode) store it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Provider {
/// Microsoft identity platform, using the multi-tenant `common` endpoints.
Microsoft,
/// Google.
Google,
/// GitHub (github.com).
GitHub,
/// GitLab (gitlab.com).
GitLab,
/// No built-in configuration: its endpoint and URI accessors return empty strings
/// and it has no default scopes.
Generic,
}
impl Provider {
pub fn device_authorization_endpoint(&self) -> &'static str {
match self {
Provider::Microsoft => {
"https://login.microsoftonline.com/common/oauth2/v2.0/devicecode"
}
Provider::Google => "https://oauth2.googleapis.com/device/code",
Provider::GitHub => "https://github.com/login/device/code",
Provider::GitLab => "https://gitlab.com/oauth/device/code",
Provider::Generic => "", }
}
pub fn token_endpoint(&self) -> &'static str {
match self {
Provider::Microsoft => "https://login.microsoftonline.com/common/oauth2/v2.0/token",
Provider::Google => "https://oauth2.googleapis.com/token",
Provider::GitHub => "https://github.com/login/oauth/access_token",
Provider::GitLab => "https://gitlab.com/oauth/token",
Provider::Generic => "", }
}
pub fn default_scopes(&self) -> Vec<&'static str> {
match self {
Provider::Microsoft => vec!["https://graph.microsoft.com/.default"],
Provider::Google => vec!["openid", "email", "profile"],
Provider::GitHub => vec!["user:email"],
Provider::GitLab => vec!["read_user"],
Provider::Generic => vec![], }
}
pub fn default_poll_interval(&self) -> std::time::Duration {
match self {
Provider::Microsoft => std::time::Duration::from_secs(5),
Provider::Google => std::time::Duration::from_secs(5),
Provider::GitHub => std::time::Duration::from_secs(5),
Provider::GitLab => std::time::Duration::from_secs(5),
Provider::Generic => std::time::Duration::from_secs(5),
}
}
pub fn headers(&self) -> HashMap<&'static str, &'static str> {
let mut headers = HashMap::new();
match self {
Provider::GitHub => {
headers.insert("Accept", "application/json");
headers.insert("User-Agent", "oauth-device-flows");
}
Provider::GitLab => {
headers.insert("Accept", "application/json");
}
_ => {
headers.insert("Accept", "application/json");
}
}
headers
}
pub fn display_name(&self) -> &'static str {
match self {
Provider::Microsoft => "Microsoft",
Provider::Google => "Google",
Provider::GitHub => "GitHub",
Provider::GitLab => "GitLab",
Provider::Generic => "Generic",
}
}
pub fn verification_uri_format(&self) -> &'static str {
match self {
Provider::Microsoft => "https://microsoft.com/devicelogin",
Provider::Google => "https://www.google.com/device",
Provider::GitHub => "https://github.com/login/device",
Provider::GitLab => "https://gitlab.com/-/user_settings/applications",
Provider::Generic => "", }
}
pub fn supports_pkce(&self) -> bool {
matches!(self, Provider::Microsoft | Provider::Google)
}
pub fn requires_client_secret(&self) -> bool {
matches!(self, Provider::Google | Provider::GitLab)
}
}
/// Settings for an OAuth provider that has no built-in [`Provider`] variant.
///
/// The fields mirror the accessors on [`Provider`] (endpoints, scopes, poll
/// interval, headers, display name, PKCE / client-secret flags), but are owned
/// and configurable at runtime. Build one with [`GenericProviderConfig::new`]
/// and the `with_*` methods.
///
/// NOTE(review): keep the field order stable — serde's derived impls visit
/// fields in declaration order, which sequence-based formats depend on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenericProviderConfig {
/// Endpoint where the device and user codes are requested.
pub device_authorization_endpoint: Url,
/// Endpoint that issues tokens for the device code.
pub token_endpoint: Url,
/// Default scopes for this provider (`new` starts this empty).
pub default_scopes: Vec<String>,
/// Default interval between token polls (`new` sets 5 seconds).
pub default_poll_interval: std::time::Duration,
/// Extra HTTP headers; keys are compared as plain strings (case-sensitive).
pub headers: HashMap<String, String>,
/// Human-readable provider name.
pub display_name: String,
/// Whether the provider supports PKCE.
pub supports_pkce: bool,
/// Whether a client secret must be supplied.
pub requires_client_secret: bool,
}
impl GenericProviderConfig {
    /// Creates a configuration from the two required endpoints and a display name.
    ///
    /// Everything else starts empty or disabled: no default scopes, no extra
    /// headers, PKCE off, no client secret required. The poll interval is
    /// 5 seconds, the same default every built-in [`Provider`] uses.
    #[must_use]
    pub fn new(
        device_authorization_endpoint: Url,
        token_endpoint: Url,
        display_name: String,
    ) -> Self {
        Self {
            device_authorization_endpoint,
            token_endpoint,
            default_scopes: Vec::new(),
            default_poll_interval: std::time::Duration::from_secs(5),
            headers: HashMap::new(),
            display_name,
            supports_pkce: false,
            requires_client_secret: false,
        }
    }

    /// Replaces the default scope list.
    ///
    /// Consumes and returns `self`, so the result must be kept.
    #[must_use]
    pub fn with_default_scopes(mut self, scopes: Vec<String>) -> Self {
        self.default_scopes = scopes;
        self
    }

    /// Sets the default interval between token polls.
    #[must_use]
    pub fn with_poll_interval(mut self, interval: std::time::Duration) -> Self {
        self.default_poll_interval = interval;
        self
    }

    /// Adds an HTTP header.
    ///
    /// A later call with the same key replaces the earlier value. Keys are
    /// compared exactly, so `"Accept"` and `"accept"` are stored separately.
    #[must_use]
    pub fn with_header(mut self, key: String, value: String) -> Self {
        self.headers.insert(key, value);
        self
    }

    /// Sets whether the provider supports PKCE.
    #[must_use]
    pub fn with_pkce(mut self, supports_pkce: bool) -> Self {
        self.supports_pkce = supports_pkce;
        self
    }

    /// Sets whether a client secret is required.
    #[must_use]
    pub fn with_client_secret_required(mut self, required: bool) -> Self {
        self.requires_client_secret = required;
        self
    }
}
impl std::fmt::Display for Provider {
    /// Writes the provider's human-readable name, as returned by `display_name`.
    ///
    /// Width/alignment flags are not applied (the name is written verbatim).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = self.display_name();
        f.write_str(name)
    }
}
impl std::str::FromStr for Provider {
type Err = crate::error::DeviceFlowError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"microsoft" | "azure" | "azuread" => Ok(Provider::Microsoft),
"google" => Ok(Provider::Google),
"github" => Ok(Provider::GitHub),
"gitlab" => Ok(Provider::GitLab),
"generic" => Ok(Provider::Generic),
_ => Err(crate::error::DeviceFlowError::UnsupportedProvider(
s.to_string(),
)),
}
}
}