use std::sync::Arc;
use ppoppo_clock::ArcClock;
use ppoppo_clock::native::WallClock;
use ppoppo_token::id_token::scopes::{
Email, EmailProfile, EmailProfilePhone, EmailProfilePhoneAddress, Openid, Profile,
};
use ppoppo_token::id_token::Nonce;
use url::Url;
use super::{
discovery::{fetch_discovery, Discovery, DiscoveryError},
port::{IdTokenVerifier, IdVerifyError, ScopePiiReader},
refresh_outcome::RefreshOutcome,
state_store::{
AuthorizationRedirect, CallbackParams, Completion, Config, PendingAuthRequest,
RelativePath, State, StateStore, StateStoreError,
},
verifier::PasIdTokenVerifier,
};
use crate::oauth::{AuthClient, OAuthConfig};
use crate::pkce;
use crate::VerifyConfig;
/// Binds a PAS scope marker type to the `scope` value sent on the authorization request.
///
/// Every string produced below begins with `openid`, followed by the space-delimited
/// extra scopes that the marker type represents.
pub trait RequestedScope: ScopePiiReader {
    /// Space-delimited value emitted as the `scope` query parameter by
    /// `RelyingParty::start`.
    const SCOPE: &'static str;
}

/// Expands each `Marker => "scope string"` pair into a `RequestedScope` impl.
macro_rules! impl_requested_scope {
    ($($marker:ty => $scope:literal),+ $(,)?) => {
        $(
            impl RequestedScope for $marker {
                const SCOPE: &'static str = $scope;
            }
        )+
    };
}

impl_requested_scope! {
    Openid => "openid",
    Email => "openid email",
    Profile => "openid profile",
    EmailProfile => "openid email profile",
    EmailProfilePhone => "openid email profile phone",
    EmailProfilePhoneAddress => "openid email profile phone address",
}
/// OpenID Connect relying party that runs the authorization-code + PKCE flow against PAS.
///
/// The type parameter `S` selects the requested scope set (see [`RequestedScope`]). It
/// also types the verified id_token assertion returned on completion.
pub struct RelyingParty<S>
where
    S: RequestedScope,
{
    /// Client settings: issuer, client id, redirect URI and state TTL.
    config: Config,
    /// Holds pending authorization requests keyed by `state` between `start` and `complete`.
    state_store: Arc<dyn StateStore>,
    /// OAuth client used for the code exchange and for refresh-token grants.
    auth_client: AuthClient,
    /// Verifies the id_token returned by the token endpoint.
    verifier: Arc<dyn IdTokenVerifier<S>>,
    /// Discovery document fetched from the issuer at construction time.
    discovery: Discovery,
    /// Time source for `created_at` on pending requests; `new` installs `WallClock`.
    clock: ArcClock,
}
impl<S> std::fmt::Debug for RelyingParty<S>
where
    S: RequestedScope,
{
    /// Prints only `config` and `discovery`. All other fields are left out, and the
    /// output is marked non-exhaustive (`..`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("RelyingParty");
        out.field("config", &self.config);
        out.field("discovery", &self.discovery);
        out.finish_non_exhaustive()
    }
}
#[derive(Debug, thiserror::Error)]
pub enum RelyingPartyInitError {
#[error("OIDC discovery fetch failed: {0}")]
Discovery(#[from] DiscoveryError),
#[error("JWKS fetch failed: {0}")]
Jwks(IdVerifyError),
#[error("OAuth client construction failed: {0}")]
OAuthClient(String),
}
#[derive(Debug, thiserror::Error)]
pub enum StartError {
#[error("state store failure: {0}")]
StateStore(#[from] StateStoreError),
#[error("authorize URL construction failed: {0}")]
UrlBuild(String),
}
#[derive(Debug, thiserror::Error)]
pub enum RefreshError {
#[error("refresh_token rejected by PAS: {0}")]
Rejected(String),
#[error("refresh transient failure: {0}")]
Transient(String),
}
#[derive(Debug, thiserror::Error)]
pub enum CallbackError {
#[error("state not found or already consumed (CSRF defense triggered)")]
StateNotFoundOrConsumed,
#[error("state store failure: {0}")]
StateStore(#[from] StateStoreError),
#[error("token exchange failed: {0}")]
TokenExchange(String),
#[error("id_token verification failed: {0}")]
IdToken(#[from] IdVerifyError),
}
impl<S: RequestedScope> RelyingParty<S> {
pub async fn new(
config: Config,
state_store: Arc<dyn StateStore>,
) -> Result<Self, RelyingPartyInitError> {
let discovery = fetch_discovery(&config.issuer).await?;
let expectations = VerifyConfig::new(
discovery.issuer.as_str(),
config.client_id.clone(),
);
let verifier_concrete: PasIdTokenVerifier<S> =
PasIdTokenVerifier::from_jwks_url(discovery.jwks_uri.to_string(), expectations)
.await
.map_err(RelyingPartyInitError::Jwks)?;
let verifier: Arc<dyn IdTokenVerifier<S>> = Arc::new(verifier_concrete);
let oauth_config = OAuthConfig::new(
config.client_id.clone(),
config.redirect_uri.clone(),
)
.with_auth_url(discovery.authorization_endpoint.clone())
.with_token_url(discovery.token_endpoint.clone());
let auth_client = AuthClient::try_new(oauth_config)
.map_err(|e| RelyingPartyInitError::OAuthClient(e.to_string()))?;
Ok(Self {
config,
state_store,
auth_client,
verifier,
discovery,
clock: Arc::new(WallClock),
})
}
#[must_use]
pub fn with_clock(mut self, clock: ArcClock) -> Self {
self.clock = clock;
self
}
pub async fn start(
&self,
after_login: RelativePath,
) -> Result<AuthorizationRedirect, StartError> {
let state_str = pkce::generate_state();
let code_verifier = pkce::generate_code_verifier();
let code_challenge = pkce::generate_code_challenge(&code_verifier);
let nonce = pkce::generate_state();
let state = State::from_string(state_str.clone());
let pending = PendingAuthRequest {
code_verifier: code_verifier.clone(),
nonce: nonce.clone(),
after_login,
created_at: self.clock.now_utc(),
};
self.state_store
.put(&state, pending, self.config.state_ttl)
.await?;
let url = build_authorize_url(
&self.discovery.authorization_endpoint,
&self.config.client_id,
&self.config.redirect_uri,
&state_str,
&code_challenge,
S::SCOPE,
&nonce,
);
Ok(AuthorizationRedirect { url, state })
}
pub async fn complete(
&self,
params: CallbackParams,
) -> Result<Completion<S>, CallbackError> {
let pending = self
.state_store
.take(¶ms.state)
.await?
.ok_or(CallbackError::StateNotFoundOrConsumed)?;
let tokens = self
.auth_client
.exchange_code(¶ms.code, &pending.code_verifier)
.await
.map_err(|e| CallbackError::TokenExchange(e.to_string()))?;
let id_token = tokens.id_token.as_deref().ok_or_else(|| {
CallbackError::TokenExchange("token response missing id_token".to_owned())
})?;
let nonce = Nonce::new(pending.nonce.as_str())
.map_err(|_| CallbackError::IdToken(IdVerifyError::NonceMismatch))?;
let id_assertion = self.verifier.verify(id_token, &nonce).await?;
Ok(Completion {
id_assertion,
tokens,
redirect_to: pending.after_login,
})
}
pub async fn refresh(
&self,
refresh_token: &str,
) -> Result<RefreshOutcome, RefreshError> {
use crate::pas_port::{PasAuthPort, PasFailure};
match self.auth_client.refresh(refresh_token).await {
Ok(t) => Ok(RefreshOutcome::from(t)),
Err(PasFailure::Rejected { detail, .. }) => Err(RefreshError::Rejected(detail)),
Err(PasFailure::ServerError { detail, .. })
| Err(PasFailure::Transport { detail }) => Err(RefreshError::Transient(detail)),
}
}
}
/// Builds the authorization-request URL for an authorization-code + PKCE (S256) + OIDC login.
///
/// Starts from a copy of `authorization_endpoint`, so any query string it already
/// carries is kept. Appends the request parameters form-urlencoded, in this fixed order:
/// `response_type=code`, `client_id`, `redirect_uri`, `state`, `code_challenge`,
/// `code_challenge_method=S256`, `scope`, `nonce`.
fn build_authorize_url(
    authorization_endpoint: &Url,
    client_id: &str,
    redirect_uri: &Url,
    state: &str,
    code_challenge: &str,
    scope: &str,
    nonce: &str,
) -> Url {
    let params = [
        ("response_type", "code"),
        ("client_id", client_id),
        ("redirect_uri", redirect_uri.as_str()),
        ("state", state),
        ("code_challenge", code_challenge),
        ("code_challenge_method", "S256"),
        ("scope", scope),
        ("nonce", nonce),
    ];

    let mut authorize_url = authorization_endpoint.clone();
    // The serializer returned by `query_pairs_mut` writes into `authorize_url`. It is
    // dropped at the end of this statement, before the URL is returned.
    authorize_url.query_pairs_mut().extend_pairs(params);
    authorize_url
}