userp 0.0.4

Userp is a user account system inspired by Next Auth, with OAuth, password and email support, a ready-made Axum router with Askama templates, and more on the way!
Documentation
use std::{fmt::Display, future::Future};

use super::{ExchangeResult, OAuthProvider};
use crate::{
    config::Allow,
    oauth::{OAuthProviderUser, UnmatchedOAuthToken},
};
use anyhow::Context;
use async_trait::async_trait;
use oauth2::{
    basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
    ClientSecret, CsrfToken, RedirectUrl, RefreshToken, Scope, TokenResponse, TokenUrl,
};
use url::Url;

/// Result type produced by the user-lookup callback of [`OAuthCustomProvider`]:
/// the provider-side user on success, or an [`anyhow::Error`] on failure.
pub type OAuthProviderUserCallbackResult = anyhow::Result<OAuthProviderUser>;

/// An OAuth provider assembled from an `oauth2` [`BasicClient`] plus a user-supplied
/// asynchronous callback that resolves an access token into an [`OAuthProviderUser`].
///
/// `F` is the callback type and `Fut` the future it returns. `Fut` is only mentioned in
/// the bound on `F`, so the `where` clause has to stay on the struct definition itself.
pub struct OAuthCustomProvider<F, Fut>
where
    Fut: Future<Output = OAuthProviderUserCallbackResult> + Send + Sync + 'static,
    F: Fn(String) -> Fut + Send + Sync + 'static,
{
    /// OAuth2 client used to build authorization URLs and to exchange codes / refresh tokens.
    client: BasicClient,
    /// Value returned by `OAuthProvider::name`.
    name: String,
    /// Value returned by `OAuthProvider::display_name`.
    display_name: String,
    /// Scopes returned by `OAuthProvider::scopes`.
    scopes: Vec<Scope>,
    /// Callback invoked with the access token secret after a successful token exchange.
    get_user: Box<F>,
    /// Per-provider signup setting; `None` until set via `with_allow_signup`.
    /// NOTE(review): presumably `None` defers to a global default — confirm against `config`.
    allow_signup: Option<Allow>,
    /// Per-provider login setting; `None` until set via `with_allow_login`.
    /// NOTE(review): presumably `None` defers to a global default — confirm against `config`.
    allow_login: Option<Allow>,
    /// Per-provider account-linking setting; `None` until set via `with_allow_linking`.
    /// NOTE(review): presumably `None` defers to a global default — confirm against `config`.
    allow_linking: Option<bool>,
}

/// Manual `Debug` implementation: the `get_user` callback is an arbitrary closure type
/// without a `Debug` bound, so it cannot be derived.
impl<F, Fut> std::fmt::Debug for OAuthCustomProvider<F, Fut>
where
    Fut: Future<Output = OAuthProviderUserCallbackResult> + Send + Sync + 'static,
    F: Fn(String) -> Fut + Send + Sync + 'static,
{
    /// Formats every field of the provider; the callback is represented by its type name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Label the output with the actual type name so logs can be traced back to this struct.
        f.debug_struct("OAuthCustomProvider")
            .field("client", &self.client)
            .field("name", &self.name)
            .field("display_name", &self.display_name)
            .field("scopes", &self.scopes)
            // The callback itself is opaque; its type name is the only diagnostic available.
            .field("get_user", &std::any::type_name::<F>())
            .field("allow_signup", &self.allow_signup)
            .field("allow_login", &self.allow_login)
            .field("allow_linking", &self.allow_linking)
            .finish()
    }
}

impl<F, Fut> OAuthCustomProvider<F, Fut>
where
    Fut: Future<Output = OAuthProviderUserCallbackResult> + Send + Sync + 'static,
    F: Fn(String) -> Fut + Send + Sync + 'static,
{
    /// Builds a provider from an already configured [`BasicClient`].
    ///
    /// * `name` / `display_name` – stored and later returned by the matching
    ///   `OAuthProvider` getters.
    /// * `scopes` – each item is turned into a [`Scope`] through its `Display` output.
    /// * `client` – the OAuth2 client used for authorization URLs and token exchanges.
    /// * `get_user` – callback invoked with the access token secret after a token exchange.
    ///
    /// The `allow_signup`, `allow_login` and `allow_linking` settings all start as `None`.
    ///
    /// # Errors
    ///
    /// This constructor performs no fallible work and always returns `Ok`.
    pub fn new_with_client_and_callback(
        name: impl Into<String>,
        display_name: impl Into<String>,
        scopes: &[impl Display],
        client: BasicClient,
        get_user: F,
    ) -> Result<OAuthCustomProvider<F, Fut>, anyhow::Error> {
        Ok(Self {
            allow_login: None,
            allow_signup: None,
            allow_linking: None,
            client,
            display_name: display_name.into(),
            scopes: scopes.iter().map(|s| Scope::new(s.to_string())).collect(),
            name: name.into(),
            get_user: Box::new(get_user),
        })
    }

    /// Builds a provider from raw client credentials and endpoint URLs.
    ///
    /// A [`BasicClient`] is created from `client_id`, `client_secret`, `auth_url` and
    /// `token_url`, then handed to [`Self::new_with_client_and_callback`] together with the
    /// remaining arguments.
    ///
    /// # Errors
    ///
    /// Returns an error if `auth_url` or `token_url` cannot be parsed as a URL; the error
    /// context states which of the two was rejected.
    // Every argument is an independent piece of provider configuration, hence the long list.
    #[allow(clippy::too_many_arguments)]
    pub fn new_with_callback(
        name: impl Into<String>,
        display_name: impl Into<String>,
        client_id: impl Into<String>,
        client_secret: impl Into<String>,
        auth_url: impl Into<String>,
        token_url: impl Into<String>,
        scopes: &[impl Display],
        get_user: F,
    ) -> Result<OAuthCustomProvider<F, Fut>, anyhow::Error> {
        // A bare `url::ParseError` does not say which URL was malformed, so attach context.
        let auth_url = Url::parse(&auth_url.into()).context("Parsing authorization URL")?;
        let token_url = Url::parse(&token_url.into()).context("Parsing token URL")?;

        let client = BasicClient::new(
            ClientId::new(client_id.into()),
            Some(ClientSecret::new(client_secret.into())),
            AuthUrl::from_url(auth_url),
            Some(TokenUrl::from_url(token_url)),
        );

        Self::new_with_client_and_callback(name, display_name, scopes, client, get_user)
    }

    /// Replaces the provider's `allow_signup` setting and returns the updated provider.
    pub fn with_allow_signup(mut self, allow_signup: Option<Allow>) -> Self {
        self.allow_signup = allow_signup;
        self
    }

    /// Replaces the provider's `allow_login` setting and returns the updated provider.
    pub fn with_allow_login(mut self, allow_login: Option<Allow>) -> Self {
        self.allow_login = allow_login;
        self
    }

    /// Replaces the provider's `allow_linking` setting and returns the updated provider.
    pub fn with_allow_linking(mut self, allow_linking: Option<bool>) -> Self {
        self.allow_linking = allow_linking;
        self
    }
}

/// [`OAuthProvider`] implementation backed by the stored [`BasicClient`] and the
/// user-lookup callback.
#[async_trait]
impl<F, Fut> OAuthProvider for OAuthCustomProvider<F, Fut>
where
    Fut: Future<Output = OAuthProviderUserCallbackResult> + Send + Sync + 'static,
    F: Fn(String) -> Fut + Send + Sync + 'static,
{
    /// Returns the provider name given at construction.
    fn name(&self) -> &str {
        &self.name
    }

    /// Returns the display name given at construction.
    fn display_name(&self) -> &str {
        &self.display_name
    }

    /// Returns the stored `allow_signup` setting.
    fn allow_signup(&self) -> Option<Allow> {
        self.allow_signup
    }

    /// Returns the stored `allow_login` setting.
    fn allow_login(&self) -> Option<Allow> {
        self.allow_login
    }

    /// Returns the stored `allow_linking` setting.
    fn allow_linking(&self) -> Option<bool> {
        self.allow_linking
    }

    /// Returns the scopes configured for this provider.
    fn scopes(&self) -> &[Scope] {
        self.scopes.as_slice()
    }

    /// Builds the authorization URL for `base_redirect_url` with the given `scopes`,
    /// together with a freshly generated random CSRF token.
    fn get_authorization_url_and_state(
        &self,
        base_redirect_url: &RedirectUrl,
        scopes: &[Scope],
    ) -> (Url, CsrfToken) {
        // `set_redirect_uri` consumes the client, so work on a clone of the stored one.
        let client = self
            .client
            .clone()
            .set_redirect_uri(base_redirect_url.clone());
        let request = client
            .authorize_url(CsrfToken::new_random)
            .add_scopes(scopes.iter().cloned());
        request.url()
    }

    /// Exchanges an authorization `code` for tokens, resolves the provider user through
    /// the callback, and wraps both into an [`UnmatchedOAuthToken`].
    ///
    /// Fails if the token request fails or if the callback returns an error.
    async fn exchange_authorization_code(
        &self,
        provider_name: &str,
        redirect_url: &RedirectUrl,
        code: &AuthorizationCode,
    ) -> ExchangeResult {
        let client = self.client.clone().set_redirect_uri(redirect_url.clone());
        let token_response = client
            .exchange_code(code.clone())
            .request_async(async_http_client)
            .await
            .context("Requesting authorization code exchange")?;

        // The callback receives the raw access token secret as an owned string.
        let access_token = token_response.access_token().secret().to_owned();
        let provider_user = (self.get_user)(access_token).await?;

        Ok(UnmatchedOAuthToken::from_standard_token_response(
            &token_response,
            provider_name,
            provider_user,
        ))
    }

    /// Exchanges a `refresh_token` for new tokens, resolves the provider user through
    /// the callback, and wraps both into an [`UnmatchedOAuthToken`].
    ///
    /// Fails if the token request fails or if the callback returns an error.
    async fn exchange_refresh_token(
        &self,
        provider_name: &str,
        redirect_url: &RedirectUrl,
        refresh_token: &RefreshToken,
    ) -> ExchangeResult {
        let client = self.client.clone().set_redirect_uri(redirect_url.clone());
        let token_response = client
            .exchange_refresh_token(refresh_token)
            .request_async(async_http_client)
            .await
            .context("Requesting refresh token exchange")?;

        // The callback receives the raw access token secret as an owned string.
        let access_token = token_response.access_token().secret().to_owned();
        let provider_user = (self.get_user)(access_token).await?;

        Ok(UnmatchedOAuthToken::from_standard_token_response(
            &token_response,
            provider_name,
            provider_user,
        ))
    }
}