use std::fmt::Display;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use super::GateExt;
use crate::accounts::Account;
use crate::authz::access_hierarchy::AccessHierarchy;
use crate::codecs::Codec;
use crate::codecs::jwt::{JwtClaims, RegisteredClaims};
use crate::cookie_template::CookieTemplate;
use chrono::Utc;
use cookie::Cookie;
use oauth2::{
AuthUrl, ClientId, ClientSecret, EmptyExtraTokenFields, RedirectUrl, Scope,
StandardTokenResponse, TokenUrl, basic::BasicTokenType,
};
use serde::{Deserialize, Serialize};
use webgates_repositories::account_repository::AccountRepository;
pub mod errors;
use errors::{OAuth2CookieKind, OAuth2Error, Result as OAuth2Result};
/// Converts a configured [`OAuth2Gate`] into an adapter-specific value (`Self::Output`).
///
/// Any type implementing this trait also implements
/// `crate::gate::adapter::GateAdapter<OAuth2Gate<R, G>>` through a blanket impl in this module.
pub trait OAuth2GateAdapter<R, G>
where
    R: AccessHierarchy + Eq + Display + Send + Sync + 'static,
    G: Eq + Clone + Send + Sync + 'static,
{
    /// Value produced from the gate.
    type Output;
    /// Consumes `gate` and produces the adapter's output.
    fn adapt(&self, gate: OAuth2Gate<R, G>) -> Self::Output;
}
/// Bridges every [`OAuth2GateAdapter`] into the crate-wide `GateAdapter` abstraction.
impl<R, G, A> crate::gate::adapter::GateAdapter<OAuth2Gate<R, G>> for A
where
    A: OAuth2GateAdapter<R, G>,
    R: AccessHierarchy + Eq + Display + Send + Sync + 'static,
    G: Eq + Clone + Send + Sync + 'static,
{
    // Name the source trait explicitly so neither the associated type nor the call can be
    // read as referring to `GateAdapter` itself.
    type Output = <A as OAuth2GateAdapter<R, G>>::Output;

    fn adapt(&self, gate: OAuth2Gate<R, G>) -> Self::Output {
        <A as OAuth2GateAdapter<R, G>>::adapt(self, gate)
    }
}
/// Boxed `Send` future returned by [`TokenExchanger::exchange_code`]; resolves to the
/// provider's standard token response or an [`OAuth2Error`].
type OAuth2TokenExchangeFuture = Pin<
    Box<
        dyn Future<
            Output = OAuth2Result<StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>>,
        > + Send,
    >,
>;
/// Shared callback that turns the provider's token response into an application [`Account`].
///
/// The returned future may borrow the token response (lifetime `'a`), so it has to be awaited
/// while that response is still alive.
type AccountMapperFn<R, G> = Arc<
    dyn for<'a> Fn(
            &'a StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>,
        )
            -> Pin<Box<dyn Future<Output = OAuth2Result<Account<R, G>>> + Send + 'a>>
        + Send
        + Sync,
>;
/// Shared async callback that takes the mapped account by value and returns the account to
/// use from then on (e.g. the stored or pre-existing one).
type AccountPersistFn<R, G> = Arc<
    dyn Fn(Account<R, G>) -> Pin<Box<dyn Future<Output = OAuth2Result<Account<R, G>>> + Send>>
        + Send
        + Sync,
>;
/// Shared synchronous callback that encodes an account into the auth cookie value (a JWT when
/// configured through `OAuth2Gate::with_jwt_codec`).
type AccountEncoderFn<R, G> = Arc<dyn Fn(Account<R, G>) -> OAuth2Result<String> + Send + Sync>;
/// Everything a [`TokenExchanger`] needs to redeem an authorization code at the provider's
/// token endpoint.
///
/// `Debug` is implemented by hand: `code`, `pkce_verifier` and `client_secret` are credentials,
/// so they are rendered as `<redacted>` instead of leaking into logs.
#[derive(Clone)]
pub struct TokenRequest {
    /// Authorization code received on the callback (redacted in `Debug`).
    pub code: String,
    /// PKCE code verifier read back from the PKCE cookie (redacted in `Debug`).
    pub pkce_verifier: String,
    /// Provider token endpoint URL.
    pub token_url: String,
    /// OAuth2 client identifier.
    pub client_id: String,
    /// Client secret for confidential clients (redacted in `Debug`).
    pub client_secret: Option<String>,
    /// Redirect URI configured for the login flow.
    pub redirect_url: String,
}

impl std::fmt::Debug for TokenRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        f.debug_struct("TokenRequest")
            .field("code", &REDACTED)
            .field("pkce_verifier", &REDACTED)
            .field("token_url", &self.token_url)
            .field("client_id", &self.client_id)
            // Keep the Some/None distinction visible; only the secret itself is hidden.
            .field("client_secret", &self.client_secret.as_ref().map(|_| REDACTED))
            .field("redirect_url", &self.redirect_url)
            .finish()
    }
}
/// Pluggable transport that redeems an authorization code for tokens.
///
/// `OAuth2Runtime::evaluate_callback` calls it once per callback after the state check has
/// passed; tests can substitute a stub implementation.
pub trait TokenExchanger: Send + Sync {
    /// Sends `request` to the provider's token endpoint and resolves to its token response.
    fn exchange_code(&self, request: TokenRequest) -> OAuth2TokenExchangeFuture;
}
/// Result of `OAuth2Runtime::prepare_login`: where to send the user plus the cookies to set.
///
/// NOTE(review): the derived `Debug` includes the cookies, which carry the CSRF state and the
/// PKCE verifier — avoid logging this value.
#[derive(Debug, Clone)]
pub struct LoginPreparation {
    /// Provider authorization URL (with state, PKCE challenge and scopes) to redirect to.
    pub redirect_url: String,
    /// Cookie holding the CSRF `state` value.
    pub state_cookie: Cookie<'static>,
    /// Cookie holding the PKCE code verifier.
    pub pkce_cookie: Cookie<'static>,
}
/// Raw data extracted from the provider's redirect back to the application.
#[derive(Debug, Clone)]
pub struct CallbackInput {
    /// `code` query parameter (authorization code).
    pub code: Option<String>,
    /// `state` query parameter, compared against the state cookie.
    pub state: Option<String>,
    /// `error` query parameter reported by the provider, if any.
    pub error: Option<String>,
    /// `error_description` query parameter accompanying `error`, if any.
    pub error_description: Option<String>,
    /// State cookie sent back by the browser.
    pub state_cookie: Option<Cookie<'static>>,
    /// PKCE verifier cookie sent back by the browser.
    pub pkce_cookie: Option<Cookie<'static>>,
}
/// Result of `OAuth2Runtime::evaluate_callback`.
///
/// Both variants carry cookies that must be applied to the response; they always include the
/// removal cookies for the state and PKCE cookies. NOTE(review): on success they may also
/// contain the auth cookie (a JWT), so avoid logging this value via the derived `Debug`.
#[derive(Debug, Clone)]
pub enum CallbackOutcome {
    /// Login completed.
    Success {
        /// Cookies to set (removals, plus the auth cookie when one was minted).
        cookies: Vec<Cookie<'static>>,
        /// Configured post-login redirect target, if any.
        redirect_to: Option<String>,
        /// Informational message.
        message: Option<String>,
    },
    /// Login rejected or failed.
    Failure {
        /// Cookies to set (state/PKCE removals).
        cookies: Vec<Cookie<'static>>,
        /// Reason for the failure.
        error: OAuth2Error,
    },
}
/// Builder for an OAuth2 authorization-code login with PKCE.
///
/// Configure endpoints, credentials, scopes, cookie templates and the account pipeline
/// (mapper → inserter → JWT encoder), then call `build()` for an [`OAuth2Runtime`] or
/// `into_config()` for a plain [`OAuth2Config`].
#[derive(Clone)]
#[must_use]
pub struct OAuth2Gate<R, G>
where
    R: AccessHierarchy + Eq + Display + Send + Sync + 'static,
    G: Eq + Clone + Send + Sync + 'static,
{
    /// Provider authorization endpoint; required.
    auth_url: Option<String>,
    /// Provider token endpoint; required.
    token_url: Option<String>,
    /// OAuth2 client identifier; required.
    client_id: Option<String>,
    /// Optional client secret.
    client_secret: Option<String>,
    /// Redirect URI registered with the provider; required.
    redirect_url: Option<String>,
    /// Scopes requested during authorization, in insertion order.
    scopes: Vec<String>,
    /// Template for the CSRF state cookie.
    state_cookie_template: CookieTemplate,
    /// Template for the PKCE verifier cookie.
    pkce_cookie_template: CookieTemplate,
    /// Template for the auth cookie minted after a successful callback.
    auth_cookie_template: CookieTemplate,
    /// Where to send the user after a successful login, if set.
    post_login_redirect: Option<String>,
    /// Maps the token response to an account.
    mapper: Option<AccountMapperFn<R, G>>,
    /// Persists or looks up the mapped account.
    account_inserter: Option<AccountPersistFn<R, G>>,
    /// Encodes the final account into the auth cookie value.
    jwt_encoder: Option<AccountEncoderFn<R, G>>,
    /// Marker for the account type parameters `R` and `G`.
    _phantom: PhantomData<(R, G)>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use crate::codecs::jsonwebtoken::crypto::rust_crypto::DEFAULT_PROVIDER as JWT_CRYPTO_PROVIDER;
    use webgates_codecs::jwt::{JsonWebToken, JwtClaims};
    use webgates_core::groups::Group;
    use webgates_core::roles::Role;

    /// Installs the JWT crypto provider; the result is ignored (presumably an error when a
    /// provider was already installed by another test).
    fn install_jwt_crypto_provider() {
        let _ = JWT_CRYPTO_PROVIDER.install_default();
    }

    /// Exchanger that always fails, used to check how the callback handles exchange errors.
    struct StubTokenExchanger;

    impl TokenExchanger for StubTokenExchanger {
        fn exchange_code(&self, _request: TokenRequest) -> OAuth2TokenExchangeFuture {
            Box::pin(async move {
                Err(OAuth2Error::token_exchange(
                    "stub exchanger always fails",
                ))
            })
        }
    }

    #[tokio::test]
    async fn oauth2_with_jwt_codec_mints_jti_and_leaves_sid_absent() {
        install_jwt_crypto_provider();
        let codec = Arc::new(JsonWebToken::<JwtClaims<Account<Role, Group>>>::default());
        let runtime = OAuth2Gate::<Role, Group>::new()
            .auth_url("https://provider.example/authorize")
            .token_url("https://provider.example/token")
            .client_id("client-id")
            .redirect_url("https://app.example/callback")
            .with_account_mapper(|_token| Box::pin(async { Ok(Account::<Role, Group>::new("u")) }))
            .with_jwt_codec("issuer", Arc::clone(&codec), 900)
            .build();
        let runtime = match runtime {
            Ok(runtime) => runtime,
            Err(error) => panic!("oauth2 runtime should build: {}", error),
        };

        let input = CallbackInput {
            code: Some("auth-code".to_string()),
            state: Some("state-1".to_string()),
            error: None,
            error_description: None,
            state_cookie: Some(Cookie::new("oauth2-state", "state-1")),
            pkce_cookie: Some(Cookie::new("oauth2-pkce", "pkce-1")),
        };
        // State and code are valid, so the flow reaches the stub exchanger; its error must
        // surface as a failure that still clears the state and PKCE cookies.
        match runtime.evaluate_callback(input, &StubTokenExchanger).await {
            CallbackOutcome::Failure { cookies, .. } => assert_eq!(cookies.len(), 2),
            CallbackOutcome::Success { .. } => {
                panic!("callback must fail when the token exchange fails")
            }
        }

        let encoder = match runtime.jwt_encoder.as_ref() {
            Some(encoder) => encoder.clone(),
            None => panic!("jwt encoder should be configured"),
        };
        let token = match encoder(Account::<Role, Group>::new("user@example.com")) {
            Ok(token) => token,
            Err(error) => panic!("token should be minted: {}", error),
        };
        let decoded = match codec.decode(token.as_bytes()) {
            Ok(claims) => claims,
            Err(error) => panic!("minted token should decode: {}", error),
        };
        assert!(decoded.registered_claims.jwt_id.is_some());
        assert!(decoded.registered_claims.session_id.is_none());
    }
}
/// Validated OAuth2 settings with public fields, produced by `OAuth2Gate::into_config`.
#[derive(Clone)]
pub struct OAuth2Config<R, G>
where
    R: AccessHierarchy + Eq + Display + Send + Sync + 'static,
    G: Eq + Clone + Send + Sync + 'static,
{
    /// Provider authorization endpoint.
    pub auth_url: String,
    /// Provider token endpoint.
    pub token_url: String,
    /// OAuth2 client identifier.
    pub client_id: String,
    /// Optional client secret.
    pub client_secret: Option<String>,
    /// Redirect URI registered with the provider.
    pub redirect_url: String,
    /// Requested scopes.
    pub scopes: Vec<String>,
    /// Template for the CSRF state cookie (already validated).
    pub state_cookie_template: CookieTemplate,
    /// Template for the PKCE verifier cookie (already validated).
    pub pkce_cookie_template: CookieTemplate,
    /// Template for the auth cookie (already validated).
    pub auth_cookie_template: CookieTemplate,
    /// Post-login redirect target, if any.
    pub post_login_redirect: Option<String>,
    /// Token-response-to-account mapper, if any.
    pub mapper: Option<AccountMapperFn<R, G>>,
    /// Account persistence callback, if any.
    pub account_inserter: Option<AccountPersistFn<R, G>>,
    /// Auth cookie value encoder, if any.
    pub jwt_encoder: Option<AccountEncoderFn<R, G>>,
}
impl<R, G> OAuth2Gate<R, G>
where
    R: AccessHierarchy + Eq + Display + Send + Sync + 'static,
    G: Eq + Clone + Send + Sync + 'static,
{
    /// Consumes the builder and returns its settings as a public [`OAuth2Config`].
    ///
    /// # Errors
    /// - `OAuth2Error::missing(..)` for the first absent required field, checked in the order
    ///   `auth_url`, `token_url`, `client_id`, `redirect_url`;
    /// - a cookie-invalid error tagged `State`, `Pkce` or `Auth` (checked in that order) when
    ///   the corresponding template fails `validate()`.
    pub fn into_config(self) -> Result<OAuth2Config<R, G>, OAuth2Error> {
        let Self {
            auth_url,
            token_url,
            client_id,
            client_secret,
            redirect_url,
            scopes,
            state_cookie_template,
            pkce_cookie_template,
            auth_cookie_template,
            post_login_redirect,
            mapper,
            account_inserter,
            jwt_encoder,
            _phantom,
        } = self;

        let auth_url = auth_url.ok_or_else(|| OAuth2Error::missing("auth_url"))?;
        let token_url = token_url.ok_or_else(|| OAuth2Error::missing("token_url"))?;
        let client_id = client_id.ok_or_else(|| OAuth2Error::missing("client_id"))?;
        let redirect_url = redirect_url.ok_or_else(|| OAuth2Error::missing("redirect_url"))?;

        // Same order as the field list: state, then PKCE, then auth.
        for (template, kind) in [
            (&state_cookie_template, OAuth2CookieKind::State),
            (&pkce_cookie_template, OAuth2CookieKind::Pkce),
            (&auth_cookie_template, OAuth2CookieKind::Auth),
        ] {
            template
                .validate()
                .map_err(|e| OAuth2Error::cookie_invalid(kind, e.to_string()))?;
        }

        Ok(OAuth2Config {
            auth_url,
            token_url,
            client_id,
            client_secret,
            redirect_url,
            scopes,
            state_cookie_template,
            pkce_cookie_template,
            auth_cookie_template,
            post_login_redirect,
            mapper,
            account_inserter,
            jwt_encoder,
        })
    }
}
impl<R, G> Default for OAuth2Gate<R, G>
where
    R: AccessHierarchy + Eq + Display + Send + Sync + 'static,
    G: Eq + Clone + Send + Sync + 'static,
{
    /// Returns a builder with no endpoints, credentials, scopes or account pipeline configured.
    ///
    /// The state and PKCE cookies start from the recommended template but get distinct names
    /// (`oauth2-state` / `oauth2-pkce`). `prepare_login` issues both cookies at once. If they
    /// shared one name (and therefore one domain/path), the browser would keep only the
    /// second `Set-Cookie`, and the callback could never match the state. The auth cookie
    /// keeps the recommended template unchanged.
    fn default() -> Self {
        Self {
            auth_url: None,
            token_url: None,
            client_id: None,
            client_secret: None,
            redirect_url: None,
            scopes: Vec::new(),
            state_cookie_template: CookieTemplate::recommended().name(String::from("oauth2-state")),
            pkce_cookie_template: CookieTemplate::recommended().name(String::from("oauth2-pkce")),
            auth_cookie_template: CookieTemplate::recommended(),
            post_login_redirect: None,
            mapper: None,
            account_inserter: None,
            jwt_encoder: None,
            _phantom: PhantomData,
        }
    }
}
impl<R, G> OAuth2Gate<R, G>
where
    R: AccessHierarchy + Eq + Display + Send + Sync + 'static,
    G: Eq + Clone + Send + Sync + 'static,
{
    /// Creates a builder with default settings; equivalent to `OAuth2Gate::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the provider's authorization endpoint (required by [`Self::build`]).
    pub fn auth_url(mut self, url: impl Into<String>) -> Self {
        self.auth_url = Some(url.into());
        self
    }

    /// Sets the provider's token endpoint (required by [`Self::build`]).
    pub fn token_url(mut self, url: impl Into<String>) -> Self {
        self.token_url = Some(url.into());
        self
    }

    /// Sets the OAuth2 client identifier (required by [`Self::build`]).
    pub fn client_id(mut self, id: impl Into<String>) -> Self {
        self.client_id = Some(id.into());
        self
    }

    /// Sets the optional client secret.
    pub fn client_secret(mut self, secret: impl Into<String>) -> Self {
        self.client_secret = Some(secret.into());
        self
    }

    /// Sets the redirect URI registered with the provider (required by [`Self::build`]).
    pub fn redirect_url(mut self, url: impl Into<String>) -> Self {
        self.redirect_url = Some(url.into());
        self
    }

    /// Appends one scope to the authorization request; call repeatedly for more scopes.
    pub fn add_scope(mut self, scope: impl Into<String>) -> Self {
        self.scopes.push(scope.into());
        self
    }

    /// Renames the state and PKCE cookies, keeping every other attribute of their current
    /// templates.
    pub fn with_cookie_names(
        mut self,
        state_cookie: impl Into<String>,
        pkce_cookie: impl Into<String>,
    ) -> Self {
        // Explicit `String` annotations pin the conversion target passed to `name`.
        let state_name: String = state_cookie.into();
        let pkce_name: String = pkce_cookie.into();
        self.state_cookie_template = self.state_cookie_template.name(state_name);
        self.pkce_cookie_template = self.pkce_cookie_template.name(pkce_name);
        self
    }

    /// Replaces the state cookie template as-is; it is validated later by [`Self::build`] or
    /// [`Self::into_config`].
    pub fn with_state_cookie_template(mut self, template: CookieTemplate) -> Self {
        self.state_cookie_template = template;
        self
    }

    /// Sets the state cookie template to `f(CookieTemplate::recommended())` and validates it
    /// immediately.
    ///
    /// `f` starts from the recommended template, not from the current one, so earlier
    /// customizations (e.g. [`Self::with_cookie_names`]) are not carried over.
    ///
    /// # Errors
    /// Returns a cookie-invalid error tagged [`OAuth2CookieKind::State`] if validation fails.
    pub fn configure_state_cookie_template<F>(mut self, f: F) -> OAuth2Result<Self>
    where
        F: FnOnce(CookieTemplate) -> CookieTemplate,
    {
        self.state_cookie_template = Self::validated_cookie_template(OAuth2CookieKind::State, f)?;
        Ok(self)
    }

    /// Replaces the PKCE cookie template as-is; it is validated later by [`Self::build`] or
    /// [`Self::into_config`].
    pub fn with_pkce_cookie_template(mut self, template: CookieTemplate) -> Self {
        self.pkce_cookie_template = template;
        self
    }

    /// Sets the PKCE cookie template to `f(CookieTemplate::recommended())` and validates it
    /// immediately.
    ///
    /// # Errors
    /// Returns a cookie-invalid error tagged [`OAuth2CookieKind::Pkce`] if validation fails.
    pub fn configure_pkce_cookie_template<F>(mut self, f: F) -> OAuth2Result<Self>
    where
        F: FnOnce(CookieTemplate) -> CookieTemplate,
    {
        self.pkce_cookie_template = Self::validated_cookie_template(OAuth2CookieKind::Pkce, f)?;
        Ok(self)
    }

    /// Replaces the auth cookie template (used for the cookie minted after a successful
    /// callback); it is validated later by [`Self::build`] or [`Self::into_config`].
    pub fn with_cookie_template(mut self, template: CookieTemplate) -> Self {
        self.auth_cookie_template = template;
        self
    }

    /// Sets the auth cookie template to `f(CookieTemplate::recommended())` and validates it
    /// immediately.
    ///
    /// # Errors
    /// Returns a cookie-invalid error tagged [`OAuth2CookieKind::Auth`] if validation fails.
    pub fn configure_cookie_template<F>(mut self, f: F) -> OAuth2Result<Self>
    where
        F: FnOnce(CookieTemplate) -> CookieTemplate,
    {
        self.auth_cookie_template = Self::validated_cookie_template(OAuth2CookieKind::Auth, f)?;
        Ok(self)
    }

    /// Sets the target returned as `redirect_to` in a successful [`CallbackOutcome`].
    pub fn with_post_login_redirect(mut self, url: impl Into<String>) -> Self {
        self.post_login_redirect = Some(url.into());
        self
    }

    /// Registers the callback that turns the provider token response into an [`Account`].
    ///
    /// Without a mapper, a successful callback issues no auth cookie.
    pub fn with_account_mapper<F>(mut self, f: F) -> Self
    where
        F: Send + Sync + 'static,
        for<'a> F: Fn(
            &'a StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>,
        )
            -> Pin<Box<dyn Future<Output = OAuth2Result<Account<R, G>>> + Send + 'a>>,
    {
        // `F` already has exactly the `AccountMapperFn` signature, so it is stored behind a
        // single `Arc` rather than an `Arc` wrapped in a second forwarding closure.
        self.mapper = Some(Arc::new(f));
        self
    }

    /// Registers an async callback that receives the mapped account and returns the account
    /// that is subsequently encoded into the auth cookie.
    pub fn with_account_inserter<F, Fut>(mut self, f: F) -> Self
    where
        F: Fn(Account<R, G>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = OAuth2Result<Account<R, G>>> + Send + 'static,
    {
        self.account_inserter = Some(Arc::new(move |account: Account<R, G>| Box::pin(f(account))));
        self
    }

    /// Persists mapped accounts through `account_repository`.
    ///
    /// An existing account with the same `user_id` is returned as-is; otherwise the mapped
    /// account is stored and the stored value is returned. Repository errors, and a store that
    /// returns `None`, become account-persistence errors.
    ///
    /// NOTE(review): lookup and store are two separate repository calls, so concurrent first
    /// logins of the same user may both reach `store_account`.
    pub fn with_account_repository<AccRepo>(mut self, account_repository: Arc<AccRepo>) -> Self
    where
        AccRepo: AccountRepository<R, G> + Send + Sync + 'static,
    {
        self.account_inserter = Some(Arc::new(move |account: Account<R, G>| {
            let repo = Arc::clone(&account_repository);
            Box::pin(async move {
                match repo.query_account_by_user_id(&account.user_id).await {
                    Ok(Some(existing)) => Ok(existing),
                    Ok(None) => match repo.store_account(account).await {
                        Ok(Some(stored)) => Ok(stored),
                        Ok(None) => Err(OAuth2Error::account_persistence(
                            "account repo returned None on store",
                        )),
                        Err(e) => Err(OAuth2Error::account_persistence(e.to_string())),
                    },
                    Err(e) => Err(OAuth2Error::account_persistence(e.to_string())),
                }
            })
        }));
        self
    }

    /// Encodes the final account as a JWT via `codec` to form the auth cookie value.
    ///
    /// Each token carries `issuer` and an expiry of now + `ttl_secs` seconds (Unix time);
    /// `session_id` is explicitly cleared.
    pub fn with_jwt_codec<C>(mut self, issuer: &str, codec: Arc<C>, ttl_secs: u64) -> Self
    where
        C: Codec<Payload = JwtClaims<Account<R, G>>> + Send + Sync + 'static,
    {
        let issuer = issuer.to_string();
        self.jwt_encoder = Some(Arc::new(move |account: Account<R, G>| {
            // Clamp a pre-1970 clock to 0 (the cast is then lossless). Saturate the addition
            // so a huge `ttl_secs` neither panics (debug) nor wraps to an expired time (release).
            let now = Utc::now().timestamp().max(0) as u64;
            let exp = now.saturating_add(ttl_secs);
            let mut registered = RegisteredClaims::new(&issuer, exp);
            registered.session_id = None;
            let claims = JwtClaims::new(account, registered);
            let bytes = codec
                .encode(&claims)
                .map_err(|e| OAuth2Error::jwt_encoding(e.to_string()))?;
            let token = String::from_utf8(bytes).map_err(|_| OAuth2Error::JwtNotUtf8)?;
            Ok(token)
        }));
        self
    }

    /// Validates the configuration and produces the [`OAuth2Runtime`].
    ///
    /// # Errors
    /// Identical to [`Self::into_config`]: a missing `auth_url`, `token_url`, `client_id` or
    /// `redirect_url` (in that order), then an invalid state, PKCE or auth cookie template.
    pub fn build(self) -> OAuth2Result<OAuth2Runtime<R, G>> {
        // Delegate so `build` and `into_config` can never drift apart in what they validate.
        let OAuth2Config {
            auth_url,
            token_url,
            client_id,
            client_secret,
            redirect_url,
            scopes,
            state_cookie_template,
            pkce_cookie_template,
            auth_cookie_template,
            post_login_redirect,
            mapper,
            account_inserter,
            jwt_encoder,
        } = self.into_config()?;
        Ok(OAuth2Runtime {
            auth_url,
            token_url,
            client_id,
            client_secret,
            redirect_url,
            scopes,
            state_cookie_template,
            pkce_cookie_template,
            auth_cookie_template,
            post_login_redirect,
            mapper,
            account_inserter,
            jwt_encoder,
            _phantom: PhantomData,
        })
    }

    /// Applies `f` to the recommended cookie template and validates the result, tagging a
    /// validation failure with `kind`.
    fn validated_cookie_template<F>(kind: OAuth2CookieKind, f: F) -> OAuth2Result<CookieTemplate>
    where
        F: FnOnce(CookieTemplate) -> CookieTemplate,
    {
        let template = f(CookieTemplate::recommended());
        template
            .validate()
            .map_err(|e| OAuth2Error::cookie_invalid(kind, e.to_string()))?;
        Ok(template)
    }
}
/// Validated OAuth2 settings produced by `OAuth2Gate::build`.
///
/// Drives the two halves of the login flow: `prepare_login` (redirect to the provider) and
/// `evaluate_callback` (handle the provider's redirect back).
#[derive(Clone)]
pub struct OAuth2Runtime<R, G>
where
    R: AccessHierarchy + Eq + Display + Send + Sync + 'static,
    G: Eq + Clone + Send + Sync + 'static,
{
    /// Provider authorization endpoint (parsed on each `prepare_login`).
    auth_url: String,
    /// Provider token endpoint.
    token_url: String,
    /// OAuth2 client identifier.
    client_id: String,
    /// Optional client secret.
    client_secret: Option<String>,
    /// Redirect URI registered with the provider.
    redirect_url: String,
    /// Requested scopes.
    scopes: Vec<String>,
    /// Template for the CSRF state cookie.
    state_cookie_template: CookieTemplate,
    /// Template for the PKCE verifier cookie.
    pkce_cookie_template: CookieTemplate,
    /// Template for the auth cookie minted on success.
    auth_cookie_template: CookieTemplate,
    /// Post-login redirect target, if any.
    post_login_redirect: Option<String>,
    /// Token-response-to-account mapper, if any.
    mapper: Option<AccountMapperFn<R, G>>,
    /// Account persistence callback, if any.
    account_inserter: Option<AccountPersistFn<R, G>>,
    /// Auth cookie value encoder, if any.
    jwt_encoder: Option<AccountEncoderFn<R, G>>,
    /// Marker for the account type parameters `R` and `G`.
    _phantom: PhantomData<(R, G)>,
}
impl<R, G> OAuth2Runtime<R, G>
where
R: AccessHierarchy + Eq + Display + Send + Sync + 'static,
G: Eq + Clone + Send + Sync + 'static,
{
pub fn prepare_login(&self) -> OAuth2Result<LoginPreparation> {
let auth_url = AuthUrl::new(self.auth_url.clone())
.map_err(|e| OAuth2Error::invalid_url("auth_url", e.to_string()))?;
let token_url = TokenUrl::new(self.token_url.clone())
.map_err(|e| OAuth2Error::invalid_url("token_url", e.to_string()))?;
let redirect_url = RedirectUrl::new(self.redirect_url.clone())
.map_err(|e| OAuth2Error::invalid_url("redirect_url", e.to_string()))?;
let mut client = oauth2::basic::BasicClient::new(ClientId::new(self.client_id.clone()))
.set_auth_uri(auth_url)
.set_token_uri(token_url)
.set_redirect_uri(redirect_url);
if let Some(secret) = &self.client_secret {
client = client.set_client_secret(ClientSecret::new(secret.clone()));
}
let csrf = oauth2::CsrfToken::new_random();
let (pkce_challenge, pkce_verifier) = oauth2::PkceCodeChallenge::new_random_sha256();
let mut req = client
.authorize_url(|| csrf.clone())
.set_pkce_challenge(pkce_challenge);
for s in &self.scopes {
req = req.add_scope(Scope::new(s.clone()));
}
let (auth_url, csrf_token) = req.url();
let state_cookie = self
.state_cookie_template
.build_with_value(csrf_token.secret());
let pkce_cookie = self
.pkce_cookie_template
.build_with_value(pkce_verifier.secret());
let prep = LoginPreparation {
redirect_url: auth_url.to_string(),
state_cookie,
pkce_cookie,
};
Ok(prep)
}
pub async fn evaluate_callback(
&self,
input: CallbackInput,
exchanger: &dyn TokenExchanger,
) -> CallbackOutcome {
let mut cookies: Vec<Cookie<'static>> = Vec::new();
cookies.push(self.state_cookie_template.build_removal());
cookies.push(self.pkce_cookie_template.build_removal());
if let Some(err) = input.error.as_ref() {
let oe = OAuth2Error::provider_error(err.clone(), input.error_description.clone());
return CallbackOutcome::Failure { cookies, error: oe };
}
let state_cookie = match input.state_cookie {
Some(c) => c,
None => {
return CallbackOutcome::Failure {
cookies,
error: OAuth2Error::MissingStateCookie,
};
}
};
let pkce_cookie = match input.pkce_cookie {
Some(c) => c,
None => {
return CallbackOutcome::Failure {
cookies,
error: OAuth2Error::MissingPkceCookie,
};
}
};
let state_valid = input.state.as_deref().is_some_and(|state| {
use subtle::ConstantTimeEq;
bool::from(state_cookie.value().as_bytes().ct_eq(state.as_bytes()))
});
if !state_valid {
return CallbackOutcome::Failure {
cookies,
error: OAuth2Error::StateMismatch,
};
}
let code_str = match input.code {
Some(c) => c,
None => {
return CallbackOutcome::Failure {
cookies,
error: OAuth2Error::MissingAuthorizationCode,
};
}
};
let token_req = TokenRequest {
code: code_str,
pkce_verifier: pkce_cookie.value().to_string(),
token_url: self.token_url.clone(),
client_id: self.client_id.clone(),
client_secret: self.client_secret.clone(),
redirect_url: self.redirect_url.clone(),
};
let token_resp = match exchanger.exchange_code(token_req).await {
Ok(resp) => resp,
Err(e) => {
return CallbackOutcome::Failure {
cookies,
error: OAuth2Error::token_exchange(e.to_string()),
};
}
};
let mapped_account = if let Some(mapper) = &self.mapper {
match (mapper)(&token_resp).await {
Ok(acc) => Some(acc),
Err(e) => {
return CallbackOutcome::Failure { cookies, error: e };
}
}
} else {
None
};
let persisted_account = if let Some(account) = mapped_account {
if let Some(inserter) = &self.account_inserter {
match (inserter)(account).await {
Ok(acc) => Some(acc),
Err(e) => {
return CallbackOutcome::Failure { cookies, error: e };
}
}
} else {
Some(account)
}
} else {
None
};
if let (Some(account), Some(encoder)) = (persisted_account, &self.jwt_encoder) {
match encoder(account) {
Ok(token) => {
let auth_cookie = self.auth_cookie_template.build_with_value(&token);
cookies.push(auth_cookie);
}
Err(e) => {
return CallbackOutcome::Failure { cookies, error: e };
}
}
}
CallbackOutcome::Success {
cookies,
redirect_to: self.post_login_redirect.clone(),
message: Some("OAuth2 callback OK".into()),
}
}
}
/// Token payload in the shape returned by a provider's token endpoint.
///
/// `Debug` is implemented by hand so the access, refresh and ID tokens are rendered as
/// `<redacted>` rather than leaking into logs; serde (de)serialization is unchanged.
#[derive(Serialize, Deserialize)]
pub struct ProviderTokenResponse {
    /// Access token (redacted in `Debug`).
    pub access_token: String,
    /// Refresh token, if issued (redacted in `Debug`).
    pub refresh_token: Option<String>,
    /// ID token, if issued (redacted in `Debug`).
    pub id_token: Option<String>,
    /// Granted scopes, if reported.
    pub scopes: Option<Vec<String>>,
}

impl std::fmt::Debug for ProviderTokenResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        f.debug_struct("ProviderTokenResponse")
            .field("access_token", &REDACTED)
            // Keep the Some/None distinction visible; only the token values are hidden.
            .field("refresh_token", &self.refresh_token.as_ref().map(|_| REDACTED))
            .field("id_token", &self.id_token.as_ref().map(|_| REDACTED))
            .field("scopes", &self.scopes)
            .finish()
    }
}
/// Opts [`OAuth2Gate`] into `GateExt`, relying entirely on the trait's provided methods
/// (this impl body is empty).
impl<R, G> GateExt for OAuth2Gate<R, G>
where
    R: AccessHierarchy + Eq + Display + Send + Sync + 'static,
    G: Eq + Clone + Send + Sync + 'static,
{
}