shield-oidc 0.2.2

OpenID Connect method for Shield.
Documentation
use std::sync::Arc;

use async_trait::async_trait;
use chrono::{DateTime, Duration, FixedOffset, Utc};
use openidconnect::{
    AuthorizationCode, EmptyAdditionalClaims, Nonce, OAuth2TokenResponse, PkceCodeVerifier,
    TokenResponse, UserInfoClaims,
    core::{CoreGenderClaim, CoreTokenResponse},
    url::form_urlencoded::parse,
};
use secrecy::SecretString;
use shield::{
    Action, ActionMethod, ConfigurationError, CreateEmailAddress, CreateUser, Form, MethodSession,
    Request, Response, ResponseType, SessionAction, ShieldError, SignInCallbackAction, UpdateUser,
    User, erased_action,
};
use tracing::debug;

use crate::{
    claims::Claims,
    client::async_http_client,
    connection::{CreateOidcConnection, OidcConnection, UpdateOidcConnection},
    options::OidcOptions,
    provider::{OidcProvider, OidcProviderPkceCodeChallenge},
    session::OidcSession,
    storage::OidcStorage,
};

/// Action that finishes an OpenID Connect sign in when the provider redirects back.
///
/// It holds the method's [`OidcOptions`] and the [`OidcStorage`] backend used to look up,
/// create and update users and their OIDC connections.
pub struct OidcSignInCallbackAction<U: User> {
    options: OidcOptions,
    storage: Arc<dyn OidcStorage<U>>,
}

impl<U: User> OidcSignInCallbackAction<U> {
    /// Creates the callback action from the method options and a storage backend.
    pub fn new(options: OidcOptions, storage: Arc<dyn OidcStorage<U>>) -> Self {
        Self { options, storage }
    }

    // TODO: Consider if there is a better location for the functions below.

    /// Extracts the user's name from the claims, using the `get(None)` entry of the
    /// `name` claim, and converts it to an owned `String`.
    fn claimed_name(claims: &Claims) -> Option<String> {
        claims
            .name()
            .and_then(|name| name.get(None).map(|name| name.to_string()))
    }

    /// Creates a user for a first-time sign in, with the claimed email as its primary
    /// (currently unverified) email address.
    ///
    /// # Errors
    ///
    /// - [`ShieldError::Validation`] if the claims carry no email address, or if the email
    ///   address is already used by another account.
    /// - Any error returned by the storage backend.
    async fn create_user(&self, claims: &Claims) -> Result<U, ShieldError> {
        let Some(email) = claims.email() else {
            return Err(ShieldError::Validation(
                "Missing email address in OpenID Connect claims.".to_owned(),
            ));
        };

        if self.storage.user_by_email(email).await?.is_some() {
            // `format!` is required so the address is actually interpolated into the message.
            let email = email.to_string();
            return Err(ShieldError::Validation(format!(
                "Email address `{email}` is already used by another account. \
                 To link a new provider, sign in with your existing account first. \
                 If this is not your account, please contact support for assistance."
            )));
        }

        let user = self
            .storage
            .create_user(
                CreateUser {
                    name: Self::claimed_name(claims),
                },
                CreateEmailAddress {
                    email: email.to_string(),
                    is_primary: true,
                    // TODO: from claim?
                    is_verified: false,
                    // TODO: generate if not verified
                    verification_token: None,
                    verification_token_expired_at: None,
                    verified_at: None,
                },
            )
            .await?;

        Ok(user)
    }

    /// Updates the stored user's name from the claims.
    ///
    /// The name is wrapped in `Some(..)` only when the claims contain one. NOTE(review): a
    /// `None` here presumably means "leave the stored name unchanged" in `UpdateUser` — confirm.
    ///
    /// # Errors
    ///
    /// Storage failures are returned as [`ShieldError::Storage`].
    async fn update_user(&self, user_id: &str, claims: &Claims) -> Result<U, ShieldError> {
        self.storage
            .update_user(UpdateUser {
                id: user_id.to_owned(),
                name: Self::claimed_name(claims).map(Some),
            })
            .await
            .map_err(ShieldError::Storage)
    }

    /// Persists a new OIDC connection linking `user_id` to the provider `identifier`
    /// (the claims subject), storing the tokens from `token_response`.
    ///
    /// # Errors
    ///
    /// Fails if the token response cannot be parsed (see `parse_token_response`), or with
    /// [`ShieldError::Storage`] if persisting fails.
    async fn create_oidc_connection(
        &self,
        provider_id: String,
        user_id: String,
        identifier: String,
        token_response: CoreTokenResponse,
    ) -> Result<OidcConnection, ShieldError> {
        let (token_type, access_token, refresh_token, id_token, expired_at, scopes) =
            parse_token_response(token_response)?;

        self.storage
            .create_oidc_connection(CreateOidcConnection {
                identifier,
                token_type,
                access_token,
                refresh_token,
                id_token,
                expired_at,
                scopes,
                provider_id,
                user_id,
            })
            .await
            .map_err(ShieldError::Storage)
    }

    /// Replaces the tokens stored on an existing OIDC connection with those from
    /// `token_response`.
    ///
    /// Optional token fields are only wrapped in `Some(..)` when present in the response.
    ///
    /// # Errors
    ///
    /// Fails if the token response cannot be parsed (see `parse_token_response`), or with
    /// [`ShieldError::Storage`] if persisting fails.
    async fn update_oidc_connection(
        &self,
        connection_id: String,
        token_response: CoreTokenResponse,
    ) -> Result<OidcConnection, ShieldError> {
        let (token_type, access_token, refresh_token, id_token, expired_at, scopes) =
            parse_token_response(token_response)?;

        self.storage
            .update_oidc_connection(UpdateOidcConnection {
                id: connection_id,
                token_type: Some(token_type),
                access_token: Some(access_token),
                refresh_token: refresh_token.map(Some),
                id_token: id_token.map(Some),
                expired_at: expired_at.map(Some),
                scopes: scopes.map(Some),
            })
            .await
            .map_err(ShieldError::Storage)
    }
}

#[async_trait]
impl<U: User + 'static> Action<OidcProvider, OidcSession> for OidcSignInCallbackAction<U> {
    fn id(&self) -> String {
        SignInCallbackAction::id()
    }

    fn name(&self) -> String {
        SignInCallbackAction::name()
    }

    fn openapi_summary(&self) -> &'static str {
        "Sign in callback for OpenID Connect"
    }

    fn openapi_description(&self) -> &'static str {
        "Sign in callback for OpenID Connect."
    }

    fn method(&self) -> ActionMethod {
        ActionMethod::Get
    }

    /// Delegates to the shared [`SignInCallbackAction::condition`].
    fn condition(
        &self,
        provider: &OidcProvider,
        session: &MethodSession<OidcSession>,
    ) -> Result<bool, ShieldError> {
        SignInCallbackAction::condition(provider, session)
    }

    /// This action exposes no forms.
    async fn forms(&self, _provider: OidcProvider) -> Result<Vec<Form>, ShieldError> {
        Ok(vec![])
    }

    /// Completes the authorization code flow.
    ///
    /// 1. Checks the `state` query parameter against the CSRF token stored in the session and
    ///    reads the `code` query parameter.
    /// 2. Exchanges the code for tokens, attaching the session's PKCE verifier (required unless
    ///    the provider's PKCE mode is `None`) and any configured extra token URL parameters.
    /// 3. Takes claims from the ID token (verified with the session nonce), or, when the
    ///    response has no ID token, from the UserInfo endpoint.
    /// 4. Updates the existing connection and user for the claims subject, or creates both.
    /// 5. Redirects to the session's `redirect_url` (falling back to
    ///    `options.sign_in_redirect`), authenticates the user, and clears the one-time flow
    ///    data from the session while recording the connection id.
    ///
    /// # Errors
    ///
    /// [`ShieldError::Validation`] for missing/invalid state, code, PKCE verifier, nonce or
    /// ID token; [`ShieldError::Request`] for failed HTTP exchanges; configuration and storage
    /// errors are propagated.
    async fn call(
        &self,
        provider: OidcProvider,
        session: &MethodSession<OidcSession>,
        request: Request,
    ) -> Result<Response, ShieldError> {
        let OidcSession {
            csrf,
            nonce,
            pkce_verifier,
            ..
        } = &session.method;

        let state = request
            .query
            .get("state")
            .and_then(|state| state.as_str())
            .ok_or_else(|| ShieldError::Validation("Missing state.".to_owned()))?;

        // Reject the callback unless it echoes the CSRF token stored when the flow started.
        if csrf.as_ref().is_none_or(|csrf| csrf != state) {
            return Err(ShieldError::Validation("Invalid state.".to_owned()));
        }

        let authorization_code = request
            .query
            .get("code")
            .and_then(|code| code.as_str())
            .ok_or_else(|| ShieldError::Validation("Missing authorization code.".to_owned()))?;

        let client = provider.oidc_client().await?;

        let mut token_request = client
            .exchange_code(AuthorizationCode::new(authorization_code.to_owned()))
            .map_err(|err| {
                ShieldError::Configuration(ConfigurationError::Missing(err.to_string()))
            })?;

        if let Some(pkce_verifier) = pkce_verifier {
            token_request =
                token_request.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier.to_owned()));
        } else if provider.pkce_code_challenge != OidcProviderPkceCodeChallenge::None {
            return Err(ShieldError::Validation("Missing PKCE verifier.".to_owned()));
        }

        // Extra parameters are configured as a query string (an optional leading `?` is ignored).
        if let Some(token_url_params) = &provider.token_url_params {
            let params = parse(token_url_params.trim_start_matches('?').as_bytes());

            for (name, value) in params {
                token_request =
                    token_request.add_extra_param(name.into_owned(), value.into_owned());
            }
        }

        let async_http_client = async_http_client()?;

        let token_response = token_request
            .request_async(&async_http_client)
            .await
            .map_err(|err| ShieldError::Request(err.to_string()))?;

        let claims = if let Some(id_token) = token_response.id_token() {
            let claims = id_token
                .claims(
                    &client.id_token_verifier(),
                    &Nonce::new(
                        nonce
                            .as_ref()
                            .ok_or_else(|| ShieldError::Validation("Missing nonce.".to_owned()))?
                            .to_owned(),
                    ),
                )
                .map_err(|err| ShieldError::Validation(err.to_string()))?;

            Claims::from(claims.clone())
        } else {
            let claims: UserInfoClaims<EmptyAdditionalClaims, CoreGenderClaim> = client
                .user_info(token_response.access_token().to_owned(), None)
                .map_err(|err| ConfigurationError::Missing(err.to_string()))?
                .request_async(&async_http_client)
                .await
                .map_err(|err| ShieldError::Request(err.to_string()))?;

            Claims::from(claims)
        };

        // Log only the subject: the full claims contain personal data (email, name, ...).
        debug!("OpenID Connect claims for subject {:?}", claims.subject());

        let (connection, user) = match self
            .storage
            .oidc_connection_by_identifier(&provider.id, claims.subject())
            .await?
        {
            Some(connection) => {
                let connection = self
                    .update_oidc_connection(connection.id, token_response)
                    .await?;

                let user = self.update_user(&connection.user_id, &claims).await?;

                (connection, user)
            }
            None => {
                let user = self.create_user(&claims).await?;

                let connection = self
                    .create_oidc_connection(
                        provider.id.clone(),
                        user.id(),
                        claims.subject().to_string(),
                        token_response,
                    )
                    .await?;

                (connection, user)
            }
        };

        let redirect_url = session
            .method
            .redirect_url
            .as_ref()
            .map(ToString::to_string)
            .unwrap_or_else(|| self.options.sign_in_redirect.clone());

        Ok(Response::new(ResponseType::Redirect(redirect_url))
            .session_action(SessionAction::authenticate(user))
            // Clear the one-time flow data and remember which connection was used.
            .session_action(SessionAction::data(OidcSession {
                redirect_url: None,
                csrf: None,
                nonce: None,
                pkce_verifier: None,
                oidc_connection_id: Some(connection.id),
            })?))
    }
}

// Invokes `shield::erased_action!` for this action, generic over `U: User`.
// NOTE(review): the macro's expansion is not visible here; presumably it generates a
// type-erased wrapper so this action can be stored alongside other actions — confirm.
erased_action!(OidcSignInCallbackAction, <U: User>);

/// Token data extracted from a `CoreTokenResponse` by `parse_token_response`.
type ParsedTokenResponse = (
    // Token type, as a string.
    String,
    // Access token.
    SecretString,
    // Refresh token, if the response included one.
    Option<SecretString>,
    // Serialized ID token, if the response included one.
    Option<SecretString>,
    // Absolute expiry time computed from `expires_in`, if the response included it.
    Option<DateTime<FixedOffset>>,
    // Granted scopes, if the response listed them.
    Option<Vec<String>>,
);

fn parse_token_response(
    token_response: CoreTokenResponse,
) -> Result<ParsedTokenResponse, ShieldError> {
    Ok((
        token_response.token_type().as_ref().to_string(),
        token_response.access_token().secret().as_str().into(),
        token_response
            .refresh_token()
            .map(|refresh_token| refresh_token.secret().as_str().into()),
        token_response
            .id_token()
            .map(|id_token| id_token.to_string().into()),
        match token_response.expires_in() {
            Some(expires_in) => Some(
                (Utc::now()
                    + Duration::from_std(expires_in)
                        .map_err(|err| ShieldError::Validation(err.to_string()))?)
                .into(),
            ),
            None => None,
        },
        token_response
            .scopes()
            .map(|scopes| scopes.iter().map(|scope| scope.to_string()).collect()),
    ))
}