use oauth2::basic::BasicClient;
use oauth2::{EndpointNotSet, EndpointSet};
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use std::time::{Duration, SystemTime};
/// Alias for the `BasicClient` type with its five endpoint type-state
/// parameters fixed.
///
/// The first and last parameters are `EndpointSet`; the three in between are
/// `EndpointNotSet`.
// NOTE(review): going by the `oauth2` crate's parameter order this presumably
// means "authorization URL and token URL set; device-authorization,
// introspection and revocation URLs unset" — confirm against the crate
// version in use.
pub type ConfiguredClient = BasicClient<
EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet, >;
/// The OAuth2 providers this module knows about.
///
/// Because of `rename_all = "lowercase"`, serde represents the variants as
/// `"google"`, `"github"` and `"oidc"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum OAuthProvider {
/// Google.
Google,
/// GitHub.
GitHub,
/// NOTE(review): presumably a generic OpenID Connect provider — confirm.
Oidc,
}
impl OAuthProvider {
#[must_use]
pub const fn as_str(&self) -> &'static str {
match self {
Self::Google => "google",
Self::GitHub => "github",
Self::Oidc => "oidc",
}
}
}
impl FromStr for OAuthProvider {
type Err = OAuthError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"google" => Ok(Self::Google),
"github" => Ok(Self::GitHub),
"oidc" => Ok(Self::Oidc),
_ => Err(OAuthError::UnknownProvider(s.to_string())),
}
}
}
/// Settings for a single OAuth2 provider.
///
/// `Debug` is implemented by hand (instead of derived) so that
/// `client_secret` is never written to logs or panic messages. Serialization
/// is unaffected and still includes the secret.
#[derive(Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// OAuth2 client identifier.
    pub client_id: String,
    /// OAuth2 client secret. Redacted in `Debug` output.
    pub client_secret: String,
    /// Redirect URI for this client.
    pub redirect_uri: String,
    /// Scopes associated with this provider configuration.
    pub scopes: Vec<String>,
    /// Authorization endpoint URL, if given; omitted from serialized output
    /// when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auth_url: Option<String>,
    /// Token endpoint URL, if given; omitted from serialized output when
    /// `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_url: Option<String>,
    /// User-info endpoint URL, if given; omitted from serialized output when
    /// `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub userinfo_url: Option<String>,
}

impl std::fmt::Debug for ProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProviderConfig")
            .field("client_id", &self.client_id)
            // Never print the secret itself.
            .field("client_secret", &"[REDACTED]")
            .field("redirect_uri", &self.redirect_uri)
            .field("scopes", &self.scopes)
            .field("auth_url", &self.auth_url)
            .field("token_url", &self.token_url)
            .field("userinfo_url", &self.userinfo_url)
            .finish()
    }
}
/// Per-provider OAuth2 settings; each provider is optional.
///
/// A provider whose field is `None` is left out of the serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthConfig {
/// Settings for `OAuthProvider::Google`, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub google: Option<ProviderConfig>,
/// Settings for `OAuthProvider::GitHub`, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub github: Option<ProviderConfig>,
/// Settings for `OAuthProvider::Oidc`, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub oidc: Option<ProviderConfig>,
}
impl OAuthConfig {
    /// Creates a configuration in which no provider is set.
    #[must_use]
    pub const fn new() -> Self {
        Self { google: None, github: None, oidc: None }
    }

    /// Looks up the settings slot that belongs to `provider`.
    const fn provider_config(&self, provider: OAuthProvider) -> Option<&ProviderConfig> {
        let slot = match provider {
            OAuthProvider::Google => &self.google,
            OAuthProvider::GitHub => &self.github,
            OAuthProvider::Oidc => &self.oidc,
        };
        slot.as_ref()
    }

    /// Returns the settings for `provider`.
    ///
    /// # Errors
    ///
    /// Returns [`OAuthError::ProviderNotConfigured`] when no settings are
    /// present for `provider`.
    pub fn get_provider(&self, provider: OAuthProvider) -> Result<&ProviderConfig, OAuthError> {
        match self.provider_config(provider) {
            Some(settings) => Ok(settings),
            None => Err(OAuthError::ProviderNotConfigured(provider)),
        }
    }

    /// Reports whether settings are present for `provider`.
    #[must_use]
    pub const fn is_provider_configured(&self, provider: OAuthProvider) -> bool {
        matches!(self.provider_config(provider), Some(_))
    }
}
impl Default for OAuthConfig {
fn default() -> Self {
Self::new()
}
}
/// A state token issued for one provider, together with its expiry time.
// NOTE(review): presumably this is the OAuth2 `state` parameter used for CSRF
// protection (see `OAuthError::StateMismatch`) — confirm against the callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthState {
/// The token string.
pub token: String,
/// Provider this state was created for.
pub provider: OAuthProvider,
/// Instant after which `is_expired` reports `true`.
pub expires_at: SystemTime,
}
impl OAuthState {
    /// Creates a fresh state for `provider`.
    ///
    /// The token is the hex encoding of 32 random bytes (64 characters) and
    /// the state expires 600 seconds after creation.
    #[must_use]
    pub fn generate(provider: OAuthProvider) -> Self {
        use rand::Rng;

        /// How long a generated state stays valid (ten minutes).
        const LIFETIME: Duration = Duration::from_secs(600);

        let mut rng = rand::rng();
        let entropy: [u8; 32] = rng.random();
        let token = hex::encode(entropy);
        let expires_at = SystemTime::now() + LIFETIME;

        Self { token, provider, expires_at }
    }

    /// Returns `true` once the current system time is strictly later than
    /// `expires_at`.
    #[must_use]
    pub fn is_expired(&self) -> bool {
        let now = SystemTime::now();
        self.expires_at < now
    }
}
/// Token data held for an OAuth2 session.
///
/// `Debug` is implemented by hand (instead of derived) so that the access and
/// refresh tokens are never written to logs or panic messages. Serialization
/// is unaffected and still includes both tokens.
#[derive(Clone, Serialize, Deserialize)]
pub struct OAuthToken {
    /// The access token. Redacted in `Debug` output.
    pub access_token: String,
    /// The refresh token, if any. Redacted in `Debug` output; omitted from
    /// serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<String>,
    /// Token type string (the tests in this file use `"Bearer"`).
    pub token_type: String,
    /// Expiry instant; a token with `None` here is never reported as expired
    /// by `is_expired`. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires_at: Option<SystemTime>,
    /// Scopes attached to the token, if known. Omitted from serialized output
    /// when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scopes: Option<Vec<String>>,
}

impl std::fmt::Debug for OAuthToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthToken")
            // Never print credential material; only show whether it is present.
            .field("access_token", &"[REDACTED]")
            .field(
                "refresh_token",
                &self.refresh_token.as_ref().map(|_| "[REDACTED]"),
            )
            .field("token_type", &self.token_type)
            .field("expires_at", &self.expires_at)
            .field("scopes", &self.scopes)
            .finish()
    }
}
impl OAuthToken {
    /// Returns `true` when an expiry time is recorded and the current system
    /// time is strictly later than it; a token without an expiry time is
    /// never considered expired.
    #[must_use]
    pub fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(deadline) => deadline < SystemTime::now(),
            None => false,
        }
    }
}
/// User profile fields obtained through an OAuth2 provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthUserInfo {
/// User identifier. NOTE(review): presumably the ID assigned by the
/// provider, not a local one — confirm.
pub provider_user_id: String,
/// Email address of the user.
pub email: String,
/// Display name, if available; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Avatar image URL, if available; omitted from serialized output when
/// `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub avatar_url: Option<String>,
/// Whether `email` is verified. NOTE(review): presumably as reported by the
/// provider — confirm.
pub email_verified: bool,
}
/// Errors produced by the OAuth2 types in this module.
///
/// The `Display` text of each variant comes from its `#[error(...)]`
/// attribute.
#[derive(Debug, thiserror::Error)]
pub enum OAuthError {
/// The given string does not name a known provider; carries that string.
#[error("Unknown OAuth2 provider: {0}")]
UnknownProvider(String),
/// No settings are present for the given provider. The provider is
/// rendered with its `Debug` form in the message.
#[error("OAuth2 provider not configured: {0:?}")]
ProviderNotConfigured(OAuthProvider),
/// The state token is invalid or has expired.
#[error("Invalid or expired OAuth2 state token")]
InvalidState,
/// The state token does not match the expected one.
#[error("OAuth2 state token mismatch (potential CSRF attack)")]
StateMismatch,
/// Exchanging the authorization code for a token failed; carries a
/// description of the failure.
#[error("Failed to exchange authorization code for token: {0}")]
TokenExchangeFailed(String),
/// Fetching user information failed; carries a description of the failure.
#[error("Failed to fetch user information: {0}")]
UserInfoFailed(String),
/// The OAuth2 token has expired.
#[error("OAuth2 token has expired")]
TokenExpired,
/// Any other OAuth2 error; carries a free-form message.
#[error("OAuth2 error: {0}")]
Generic(String),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a token with the given expiry and placeholder values elsewhere.
    fn token_expiring_at(expires_at: Option<SystemTime>) -> OAuthToken {
        OAuthToken {
            access_token: "test".to_string(),
            refresh_token: None,
            token_type: "Bearer".to_string(),
            expires_at,
            scopes: None,
        }
    }

    #[test]
    fn test_provider_as_str() {
        let expected = [
            (OAuthProvider::Google, "google"),
            (OAuthProvider::GitHub, "github"),
            (OAuthProvider::Oidc, "oidc"),
        ];
        for (provider, name) in expected {
            assert_eq!(provider.as_str(), name);
        }
    }

    #[test]
    fn test_provider_from_str() {
        let accepted = [
            ("google", OAuthProvider::Google),
            ("GOOGLE", OAuthProvider::Google),
            ("github", OAuthProvider::GitHub),
            ("oidc", OAuthProvider::Oidc),
        ];
        for (input, provider) in accepted {
            assert_eq!(input.parse::<OAuthProvider>().unwrap(), provider);
        }

        assert!("invalid".parse::<OAuthProvider>().is_err());
    }

    #[test]
    fn test_oauth_config_default() {
        let OAuthConfig { google, github, oidc } = OAuthConfig::default();
        assert!(google.is_none());
        assert!(github.is_none());
        assert!(oidc.is_none());
    }

    #[test]
    fn test_oauth_config_is_provider_configured() {
        let empty = OAuthConfig::default();
        assert!(!empty.is_provider_configured(OAuthProvider::Google));

        // Only the Google slot is filled in; the others stay unset.
        let google_only = OAuthConfig {
            google: Some(ProviderConfig {
                client_id: "test".to_string(),
                client_secret: "test".to_string(),
                redirect_uri: "http://localhost/callback".to_string(),
                scopes: vec!["email".to_string()],
                auth_url: None,
                token_url: None,
                userinfo_url: None,
            }),
            ..empty
        };
        assert!(google_only.is_provider_configured(OAuthProvider::Google));
        assert!(!google_only.is_provider_configured(OAuthProvider::GitHub));
    }

    #[test]
    fn test_oauth_state_generation() {
        let fresh = OAuthState::generate(OAuthProvider::Google);
        assert_eq!(fresh.provider, OAuthProvider::Google);
        assert!(!fresh.is_expired());
        // 32 random bytes, hex-encoded, give 64 characters.
        assert_eq!(fresh.token.len(), 64);
    }

    #[test]
    fn test_oauth_token_is_expired() {
        // Without an expiry time the token never counts as expired.
        assert!(!token_expiring_at(None).is_expired());

        let an_hour_ago = SystemTime::now() - Duration::from_secs(3600);
        assert!(token_expiring_at(Some(an_hour_ago)).is_expired());
    }
}