use crate::{AuthError, Result};
use async_trait::async_trait;
use oauth2::basic::{BasicClient, BasicTokenType};
use oauth2::{
AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields,
EndpointSet, RedirectUrl, Scope, StandardErrorResponse, StandardRevocableToken,
StandardTokenIntrospectionResponse, StandardTokenResponse, TokenResponse, TokenUrl,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use url::Url;
/// Fully spelled-out `oauth2::Client` type stored by `GenericOAuth2Provider`.
///
/// The first five parameters select the standard ("basic") error, token,
/// introspection and revocation types. The last five are endpoint typestate
/// markers: two are `EndpointSet` and three are `EndpointNotSet`.
// NOTE(review): presumably the two `EndpointSet` markers are the auth and token
// endpoints, since `GenericOAuth2Provider::new` calls `set_auth_uri` and
// `set_token_uri` — confirm the parameter order against the oauth2 crate docs.
type ConfiguredClient = oauth2::Client<
StandardErrorResponse<oauth2::basic::BasicErrorResponseType>,
StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>,
StandardTokenIntrospectionResponse<EmptyExtraTokenFields, BasicTokenType>,
StandardRevocableToken,
StandardErrorResponse<oauth2::RevocationErrorResponseType>,
EndpointSet,
oauth2::EndpointNotSet,
oauth2::EndpointNotSet,
oauth2::EndpointNotSet,
EndpointSet,
>;
/// Abstraction over an OAuth2 identity provider.
///
/// Implementors must be `Send + Sync`. The async methods are declared through
/// `async_trait`.
#[async_trait]
pub trait OAuth2Provider: Send + Sync {
/// Returns the provider's name.
fn name(&self) -> &str;
/// Builds the URL to send the user to for authorization, together with the
/// CSRF token generated for that request.
// NOTE(review): callers presumably must persist the CSRF token and compare it
// with the `state` value on the callback — nothing in this trait enforces it.
fn authorization_url(&self) -> Result<(Url, CsrfToken)>;
/// Exchanges the authorization `code` for an [`OAuth2Token`].
async fn exchange_code(&self, code: String) -> Result<OAuth2Token>;
/// Fetches the user's profile using the access token in `token`.
async fn get_user_info(&self, token: &OAuth2Token) -> Result<OAuth2UserInfo>;
/// Exchanges `refresh_token` for a new [`OAuth2Token`].
async fn refresh_token(&self, refresh_token: String) -> Result<OAuth2Token>;
}
/// Serializable token set returned by [`OAuth2Provider`] methods.
// NOTE(review): the derived `Debug` prints `access_token`, `refresh_token` and
// `id_token` verbatim — avoid logging this struct with `{:?}`, or consider a
// redacting `Debug` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuth2Token {
/// The access token secret as a plain string.
pub access_token: String,
/// Token type string as reported by the token response.
pub token_type: String,
/// Token lifetime in whole seconds, when the response carried one.
pub expires_in: Option<u64>,
/// The refresh token secret, when one was issued.
pub refresh_token: Option<String>,
/// Granted scopes joined with single spaces, when the response listed them.
pub scope: Option<String>,
/// ID token, if any. The `From<StandardTokenResponse<..>>` conversion in this
/// file always sets this to `None`.
pub id_token: Option<String>, }
impl From<StandardTokenResponse<EmptyExtraTokenFields, oauth2::basic::BasicTokenType>>
    for OAuth2Token
{
    /// Flattens the `oauth2` crate's standard token response into the plain,
    /// serializable [`OAuth2Token`].
    ///
    /// Secrets are copied out as `String`s, the lifetime is reduced to whole
    /// seconds, and the scope list is joined with single spaces. The standard
    /// response has no extra fields, so `id_token` is always `None`.
    fn from(
        token: StandardTokenResponse<EmptyExtraTokenFields, oauth2::basic::BasicTokenType>,
    ) -> Self {
        let access_token = token.access_token().secret().to_owned();
        let token_type = token.token_type().as_ref().to_string();
        let expires_in = token.expires_in().map(|lifetime| lifetime.as_secs());
        let refresh_token = token
            .refresh_token()
            .map(|refresh| refresh.secret().to_owned());
        let scope = token.scopes().map(|granted| {
            let parts: Vec<&str> = granted.iter().map(|entry| entry.as_str()).collect();
            parts.join(" ")
        });

        Self {
            access_token,
            token_type,
            expires_in,
            refresh_token,
            scope,
            id_token: None,
        }
    }
}
/// User profile deserialized from the JSON body of the user-info endpoint.
///
/// `sub` is the only required field; `email` and all other named fields are
/// optional.
// NOTE(review): the field names look like OpenID Connect standard claims —
// confirm against the providers this is used with.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuth2UserInfo {
/// Required `sub` value (first field on the line below; `email` follows it).
pub sub: String, pub email: Option<String>,
/// Optional `name` value.
pub name: Option<String>,
/// Optional `given_name` value.
pub given_name: Option<String>,
/// Optional `family_name` value.
pub family_name: Option<String>,
/// Optional `picture` value.
pub picture: Option<String>,
/// Optional `email_verified` flag.
pub email_verified: Option<bool>,
/// Every other top-level key of the response; `#[serde(flatten)]` routes keys
/// not matched by the named fields into this map.
#[serde(flatten)]
pub additional: HashMap<String, serde_json::Value>,
}
/// Static configuration for one OAuth2 provider.
///
/// `Debug` is implemented by hand so that `client_secret` is never written to
/// logs or panic messages; every other field is printed as usual.
#[derive(Clone)]
pub struct OAuth2Config {
    /// OAuth2 client identifier.
    pub client_id: String,
    /// OAuth2 client secret. Redacted in `Debug` output.
    pub client_secret: String,
    /// Authorization endpoint URL, as a string.
    pub auth_url: String,
    /// Token endpoint URL, as a string.
    pub token_url: String,
    /// Redirect (callback) URL, as a string.
    pub redirect_url: String,
    /// Scopes to request on the authorization URL.
    pub scopes: Vec<String>,
    /// Optional user-info endpoint URL.
    pub user_info_url: Option<String>,
}

impl std::fmt::Debug for OAuth2Config {
    /// Formats the configuration like the derived `Debug` would, except that
    /// the client secret is replaced by a fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuth2Config")
            .field("client_id", &self.client_id)
            // Never print the real secret: Debug output routinely ends up in logs.
            .field("client_secret", &"<redacted>")
            .field("auth_url", &self.auth_url)
            .field("token_url", &self.token_url)
            .field("redirect_url", &self.redirect_url)
            .field("scopes", &self.scopes)
            .field("user_info_url", &self.user_info_url)
            .finish()
    }
}
impl OAuth2Config {
    /// Creates a configuration from the five mandatory values.
    ///
    /// The scope list starts empty and no user-info URL is set; use
    /// [`with_scopes`](Self::with_scopes) and
    /// [`with_user_info_url`](Self::with_user_info_url) to fill them in.
    pub fn new(
        client_id: String,
        client_secret: String,
        auth_url: String,
        token_url: String,
        redirect_url: String,
    ) -> Self {
        let scopes: Vec<String> = Vec::new();
        let user_info_url: Option<String> = None;
        Self {
            client_id,
            client_secret,
            auth_url,
            token_url,
            redirect_url,
            scopes,
            user_info_url,
        }
    }

    /// Replaces the scope list with `scopes` and returns the updated config.
    pub fn with_scopes(self, scopes: Vec<String>) -> Self {
        Self { scopes, ..self }
    }

    /// Sets the user-info endpoint to `url` and returns the updated config.
    pub fn with_user_info_url(self, url: String) -> Self {
        Self {
            user_info_url: Some(url),
            ..self
        }
    }
}
/// [`OAuth2Provider`] implementation driven entirely by an [`OAuth2Config`].
pub struct GenericOAuth2Provider {
// Value returned by `OAuth2Provider::name`.
name: String,
// `oauth2` client built from `config` in `GenericOAuth2Provider::new`.
client: ConfiguredClient,
// Original configuration; still read for `scopes` and `user_info_url`.
config: OAuth2Config,
}
impl GenericOAuth2Provider {
pub fn new(name: String, config: OAuth2Config) -> Result<Self> {
let client = BasicClient::new(ClientId::new(config.client_id.clone()))
.set_client_secret(ClientSecret::new(config.client_secret.clone()))
.set_auth_uri(
AuthUrl::new(config.auth_url.clone())
.map_err(|e| AuthError::AuthenticationFailed(e.to_string()))?,
)
.set_token_uri(
TokenUrl::new(config.token_url.clone())
.map_err(|e| AuthError::AuthenticationFailed(e.to_string()))?,
)
.set_redirect_uri(
RedirectUrl::new(config.redirect_url.clone())
.map_err(|e| AuthError::AuthenticationFailed(e.to_string()))?,
);
Ok(Self {
name,
client,
config,
})
}
}
#[async_trait]
impl OAuth2Provider for GenericOAuth2Provider {
fn name(&self) -> &str {
&self.name
}
fn authorization_url(&self) -> Result<(Url, CsrfToken)> {
let mut auth_request = self.client.authorize_url(CsrfToken::new_random);
for scope in &self.config.scopes {
auth_request = auth_request.add_scope(Scope::new(scope.clone()));
}
let (url, csrf_token) = auth_request.url();
Ok((url, csrf_token))
}
async fn exchange_code(&self, code: String) -> Result<OAuth2Token> {
let http_client = oauth2::reqwest::Client::new();
let token = self
.client
.exchange_code(AuthorizationCode::new(code))
.request_async(&http_client)
.await
.map_err(|e| {
AuthError::AuthenticationFailed(format!("Token exchange failed: {}", e))
})?;
Ok(token.into())
}
async fn get_user_info(&self, token: &OAuth2Token) -> Result<OAuth2UserInfo> {
let user_info_url =
self.config.user_info_url.as_ref().ok_or_else(|| {
AuthError::AuthenticationFailed("No user info URL configured".into())
})?;
let client = reqwest::Client::new();
let response = client
.get(user_info_url)
.bearer_auth(&token.access_token)
.send()
.await
.map_err(|e| {
AuthError::AuthenticationFailed(format!("User info request failed: {}", e))
})?;
if !response.status().is_success() {
return Err(AuthError::AuthenticationFailed(format!(
"User info request failed with status: {}",
response.status()
)));
}
let user_info: OAuth2UserInfo = response.json().await.map_err(|e| {
AuthError::AuthenticationFailed(format!("Failed to parse user info: {}", e))
})?;
Ok(user_info)
}
async fn refresh_token(&self, refresh_token: String) -> Result<OAuth2Token> {
let http_client = oauth2::reqwest::Client::new();
let token = self
.client
.exchange_refresh_token(&oauth2::RefreshToken::new(refresh_token))
.request_async(&http_client)
.await
.map_err(|e| AuthError::AuthenticationFailed(format!("Token refresh failed: {}", e)))?;
Ok(token.into())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder chain must keep the mandatory values and apply both
    /// optional settings.
    #[test]
    fn test_oauth2_config() {
        let requested_scopes = vec!["openid".to_string(), "profile".to_string()];
        let base = OAuth2Config::new(
            "client_id".to_string(),
            "client_secret".to_string(),
            "https://example.com/auth".to_string(),
            "https://example.com/token".to_string(),
            "https://example.com/callback".to_string(),
        );
        let config = base
            .with_scopes(requested_scopes)
            .with_user_info_url("https://example.com/userinfo".to_string());

        assert_eq!(config.client_id, "client_id");
        assert_eq!(config.scopes.len(), 2);
        assert!(config.user_info_url.is_some());
    }
}