openauth-core 0.0.6

Core types and primitives for OpenAuth.
Documentation
#[cfg(feature = "oauth")]
use http::StatusCode;
use serde::{Deserialize, Serialize};
#[cfg(feature = "oauth")]
use serde_json::{json, Value};
use time::OffsetDateTime;

#[cfg(feature = "oauth")]
use super::super::shared::{error_response, json_openapi_response};
#[cfg(feature = "oauth")]
use crate::api::ApiResponse;
use crate::api::{BodyField, BodySchema, JsonSchemaType};
#[cfg(feature = "oauth")]
use crate::auth::oauth::{decrypt_oauth_token, set_token_util};
use crate::db::Account;
#[cfg(feature = "oauth")]
use crate::error::OpenAuthError;
#[cfg(feature = "oauth")]
use crate::user::{DbUserStore, UpdateAccountInput};
#[cfg(feature = "oauth")]
use openauth_oauth::oauth2::{OAuth2Tokens, OAuth2UserInfo};

/// Request body identifying which linked account should be unlinked.
///
/// Deserialized from camelCase JSON keys (`providerId`, `accountId`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(super) struct UnlinkAccountBody {
    /// Provider ID of the account to unlink (required).
    pub(super) provider_id: String,
    /// Account ID to unlink; may be omitted.
    pub(super) account_id: Option<String>,
}

/// Request body for the OAuth token helpers (only compiled with the `oauth` feature).
///
/// Deserialized from camelCase JSON keys (`providerId`, `accountId`, `userId`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
#[cfg(feature = "oauth")]
pub(super) struct TokenBody {
    /// Provider ID for the OAuth provider (required).
    pub(super) provider_id: String,
    /// Account ID to narrow the lookup; may be omitted.
    pub(super) account_id: Option<String>,
    /// User ID associated with the account; may be omitted.
    pub(super) user_id: Option<String>,
}

/// Serializable view of an [`Account`], produced through `From<Account>`.
///
/// Serialized with camelCase keys. It carries identifiers, scopes and
/// timestamps only; no token fields are part of this struct.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub(super) struct AccountResponse {
    /// Identifier of the account record.
    id: String,
    /// Provider the account belongs to.
    provider_id: String,
    /// Account identifier stored on the record.
    account_id: String,
    /// Identifier of the owning user.
    user_id: String,
    /// Scopes parsed from the account's comma-separated `scope` string.
    scopes: Vec<String>,
    /// Creation timestamp copied from the account record.
    created_at: OffsetDateTime,
    /// Last-update timestamp copied from the account record.
    updated_at: OffsetDateTime,
}

/// Minimal response body carrying a single boolean `status` field.
#[derive(Debug, Serialize)]
pub(super) struct StatusBody {
    /// Outcome flag serialized as `status`.
    pub(super) status: bool,
}

/// Response body describing an access token (only compiled with the `oauth` feature).
///
/// Serialized with camelCase keys; built by `access_token_response_from_tokens`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
#[cfg(feature = "oauth")]
pub(super) struct AccessTokenResponse {
    /// Access token, if one is available.
    pub(super) access_token: Option<String>,
    /// Expiry of the access token, if known.
    pub(super) access_token_expires_at: Option<OffsetDateTime>,
    /// Scopes associated with the token.
    pub(super) scopes: Vec<String>,
    /// ID token, if one is available.
    pub(super) id_token: Option<String>,
    /// Token type as reported alongside the tokens, if any.
    pub(super) token_type: Option<String>,
}

/// Response body describing a refreshed token set (only compiled with the `oauth` feature).
///
/// Serialized with camelCase keys. Unlike [`AccessTokenResponse`], both the
/// access and refresh tokens are mandatory and the scope is a single string.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
#[cfg(feature = "oauth")]
pub(super) struct RefreshTokenResponse {
    /// Access token.
    pub(super) access_token: String,
    /// Refresh token.
    pub(super) refresh_token: String,
    /// Expiry of the access token, if known.
    pub(super) access_token_expires_at: Option<OffsetDateTime>,
    /// Expiry of the refresh token, if known.
    pub(super) refresh_token_expires_at: Option<OffsetDateTime>,
    /// Scope string, if any.
    pub(super) scope: Option<String>,
    /// ID token, if one is available.
    pub(super) id_token: Option<String>,
    /// Provider the account belongs to.
    pub(super) provider_id: String,
    /// Account identifier the tokens belong to.
    pub(super) account_id: String,
    /// Token type, if any.
    pub(super) token_type: Option<String>,
}

/// Response body pairing a normalized user with a free-form JSON payload
/// (only compiled with the `oauth` feature). Serialized with camelCase keys.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
#[cfg(feature = "oauth")]
pub(super) struct AccountInfoResponse {
    /// Normalized user information.
    pub(super) user: AccountInfoUser,
    /// Arbitrary JSON value.
    /// NOTE(review): presumably the raw provider profile payload — confirm against the caller.
    pub(super) data: Value,
}

/// Serializable user summary, produced from `OAuth2UserInfo` through `From`
/// (only compiled with the `oauth` feature). Serialized with camelCase keys.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
#[cfg(feature = "oauth")]
pub(super) struct AccountInfoUser {
    /// User identifier copied from the provider user info.
    id: String,
    /// Display name, if present.
    name: Option<String>,
    /// Email address, if present.
    email: Option<String>,
    /// Image value, if present.
    image: Option<String>,
    /// Email-verified flag copied from the provider user info.
    email_verified: bool,
}

impl From<Account> for AccountResponse {
    fn from(account: Account) -> Self {
        let scopes = account_scopes(&account);
        Self {
            id: account.id,
            provider_id: account.provider_id,
            account_id: account.account_id,
            user_id: account.user_id,
            scopes,
            created_at: account.created_at,
            updated_at: account.updated_at,
        }
    }
}

/// Converts provider user info into the serializable [`AccountInfoUser`],
/// copying each field across unchanged.
#[cfg(feature = "oauth")]
impl From<OAuth2UserInfo> for AccountInfoUser {
    fn from(user: OAuth2UserInfo) -> Self {
        let OAuth2UserInfo {
            id,
            name,
            email,
            image,
            email_verified,
            ..
        } = user;

        Self {
            id,
            name,
            email,
            image,
            email_verified,
        }
    }
}

/// Describes the JSON body accepted when unlinking an account: a required
/// `providerId` string and an optional `accountId` string.
pub(super) fn unlink_account_body_schema() -> BodySchema {
    let provider_id = BodyField::new("providerId", JsonSchemaType::String)
        .description("The provider ID of the account to unlink");
    let account_id = BodyField::optional("accountId", JsonSchemaType::String)
        .description("The account ID to unlink");

    BodySchema::object([provider_id, account_id])
}

/// Describes the JSON body accepted by the token helpers: a required
/// `providerId` string plus optional `accountId` and `userId` strings.
#[cfg(feature = "oauth")]
pub(super) fn token_body_schema() -> BodySchema {
    let provider_id = BodyField::new("providerId", JsonSchemaType::String)
        .description("The provider ID for the OAuth provider");
    let account_id = BodyField::optional("accountId", JsonSchemaType::String)
        .description("The account ID associated with the refresh token");
    let user_id = BodyField::optional("userId", JsonSchemaType::String)
        .description("The user ID associated with the account");

    BodySchema::object([provider_id, account_id, user_id])
}

/// Returns the OpenAPI object schema for a serialized account: string
/// identifiers, a string-array `scopes`, and two date-time timestamps, all
/// of which are listed as required.
pub(super) fn account_openapi_schema() -> serde_json::Value {
    // Property definitions, keyed by the camelCase names used on the wire.
    let properties = serde_json::json!({
        "id": { "type": "string" },
        "providerId": { "type": "string" },
        "accountId": { "type": "string" },
        "userId": { "type": "string" },
        "scopes": {
            "type": "array",
            "items": { "type": "string" },
        },
        "createdAt": { "type": "string", "format": "date-time" },
        "updatedAt": { "type": "string", "format": "date-time" },
    });
    // Every property above is mandatory.
    let required = serde_json::json!([
        "id",
        "providerId",
        "accountId",
        "userId",
        "scopes",
        "createdAt",
        "updatedAt"
    ]);

    serde_json::json!({
        "type": "object",
        "properties": properties,
        "required": required,
    })
}

/// Builds the OpenAPI "Success" response for the token helpers.
///
/// The object schema always lists `tokenType`, `idToken`, `accessToken` and
/// `accessTokenExpiresAt`; when `include_refresh` is true it additionally
/// lists `refreshToken` and `refreshTokenExpiresAt`. The schema is wrapped
/// by `json_openapi_response`.
///
/// NOTE(review): if this documents `AccessTokenResponse`, its `scopes` field
/// is not listed here — confirm whether that omission is intentional.
#[cfg(feature = "oauth")]
pub(super) fn token_openapi_response(include_refresh: bool) -> serde_json::Value {
    let string_schema = || json!({ "type": "string" });
    let date_time_schema = || json!({ "type": "string", "format": "date-time" });

    // Entries are kept in the order they should be inserted into the map.
    let mut fields = vec![
        ("tokenType", string_schema()),
        ("idToken", string_schema()),
        ("accessToken", string_schema()),
        ("accessTokenExpiresAt", date_time_schema()),
    ];
    if include_refresh {
        fields.push(("refreshToken", string_schema()));
        fields.push(("refreshTokenExpiresAt", date_time_schema()));
    }
    let properties: serde_json::Map<String, Value> = fields
        .into_iter()
        .map(|(key, schema)| (key.to_owned(), schema))
        .collect();

    let mut schema = serde_json::Map::new();
    schema.insert("type".to_owned(), json!("object"));
    schema.insert("properties".to_owned(), Value::Object(properties));

    json_openapi_response("Success", Value::Object(schema))
}

/// Looks up one of a user's linked accounts.
///
/// Loads every account for `user_id` and returns the first whose provider
/// equals `provider_id`. When `account_id` is given, the account ID must
/// match as well; when it is `None`, any account of that provider qualifies.
///
/// # Errors
///
/// Propagates any error from `list_accounts_for_user`.
#[cfg(feature = "oauth")]
pub(super) async fn find_user_account(
    users: &DbUserStore<'_>,
    user_id: &str,
    provider_id: &str,
    account_id: Option<&str>,
) -> Result<Option<Account>, OpenAuthError> {
    let linked = users.list_accounts_for_user(user_id).await?;
    let found = linked.into_iter().find(|candidate| {
        let same_provider = candidate.provider_id == provider_id;
        let same_account = match account_id {
            Some(wanted) => candidate.account_id == wanted,
            // No account ID requested: the provider match alone decides.
            None => true,
        };
        same_provider && same_account
    });
    Ok(found)
}

/// Reports whether the account's access token should be refreshed.
///
/// Returns `true` when the recorded expiry is less than five seconds away
/// (which includes expiries already in the past). Accounts without a
/// recorded expiry never request a refresh.
#[cfg(feature = "oauth")]
pub(super) fn should_refresh(account: &Account) -> bool {
    match account.access_token_expires_at {
        Some(expires_at) => {
            let remaining = expires_at - OffsetDateTime::now_utc();
            remaining < time::Duration::seconds(5)
        }
        None => false,
    }
}

/// Merges freshly obtained `tokens` into `account` and stores the result.
///
/// For every field the new value wins when present; otherwise the value
/// already stored on `account` is kept:
///
/// * access token — the new token is passed through `set_token_util`
///   before storage; the existing stored value is reused as-is otherwise.
/// * refresh token — the new token, or else `fallback_refresh_token`, is
///   passed through `set_token_util`; the existing stored value is reused
///   when neither is available.
/// * ID token and both expiry timestamps — new value, else the stored one.
/// * scope — the new scopes joined with `,`, or the stored scope string
///   when the new scope list is empty.
///
/// Returns the account as reported back by `update_account`.
///
/// # Errors
///
/// Propagates errors from `set_token_util` and `update_account`, and
/// returns `OpenAuthError::Adapter` when the update yields no account.
#[cfg(feature = "oauth")]
pub(super) async fn persist_refreshed_tokens(
    context: &crate::context::AuthContext,
    users: &DbUserStore<'_>,
    account: Account,
    tokens: OAuth2Tokens,
    fallback_refresh_token: Option<&str>,
) -> Result<Account, OpenAuthError> {
    // `account` and `tokens` are owned and consumed here, so fallback values
    // are moved out of them rather than cloned.
    let access_token = match tokens.access_token.as_deref() {
        Some(token) => set_token_util(Some(token), context)?,
        None => account.access_token,
    };
    let refresh_token = match tokens.refresh_token.as_deref().or(fallback_refresh_token) {
        Some(token) => set_token_util(Some(token), context)?,
        None => account.refresh_token,
    };
    let id_token = tokens.id_token.or(account.id_token);
    let access_token_expires_at = tokens
        .access_token_expires_at
        .or(account.access_token_expires_at);
    let refresh_token_expires_at = tokens
        .refresh_token_expires_at
        .or(account.refresh_token_expires_at);
    let scope = if tokens.scopes.is_empty() {
        account.scope
    } else {
        // Stored as a comma-separated list, matching `account_scopes`.
        Some(tokens.scopes.join(","))
    };
    users
        .update_account(
            &account.id,
            UpdateAccountInput {
                access_token: Some(access_token),
                refresh_token: Some(refresh_token),
                id_token: Some(id_token),
                access_token_expires_at: Some(access_token_expires_at),
                refresh_token_expires_at: Some(refresh_token_expires_at),
                scope: Some(scope),
            },
        )
        .await?
        .ok_or_else(|| OpenAuthError::Adapter("failed to update account".to_owned()))
}

/// Builds an [`AccessTokenResponse`] from `tokens`, filling gaps from `account`.
///
/// The access token and token type come from `tokens` only. The expiry and
/// ID token fall back to the values on `account` when `tokens` has none, and
/// an empty scope list falls back to the scopes parsed from `account`.
#[cfg(feature = "oauth")]
pub(super) fn access_token_response_from_tokens(
    tokens: OAuth2Tokens,
    account: &Account,
) -> AccessTokenResponse {
    let OAuth2Tokens {
        access_token,
        access_token_expires_at,
        scopes,
        id_token,
        token_type,
        ..
    } = tokens;

    let expires_at = access_token_expires_at.or(account.access_token_expires_at);
    let scopes = if scopes.is_empty() {
        account_scopes(account)
    } else {
        scopes
    };
    let id_token = id_token.or_else(|| account.id_token.clone());

    AccessTokenResponse {
        access_token,
        access_token_expires_at: expires_at,
        scopes,
        id_token,
        token_type,
    }
}

/// Reconstructs an `OAuth2Tokens` value from what is stored on `account`.
///
/// The stored access and refresh tokens are passed through
/// `decrypt_oauth_token`; expiries, scopes and the ID token are copied as
/// stored. Every remaining field takes its `Default` value.
///
/// # Errors
///
/// Propagates the first failure from `decrypt_oauth_token` (the access
/// token is handled before the refresh token).
#[cfg(feature = "oauth")]
pub(super) fn tokens_from_account(
    context: &crate::context::AuthContext,
    account: &Account,
) -> Result<OAuth2Tokens, OpenAuthError> {
    let access_token = match account.access_token.as_deref() {
        Some(stored) => Some(decrypt_oauth_token(stored, context)?),
        None => None,
    };
    let refresh_token = match account.refresh_token.as_deref() {
        Some(stored) => Some(decrypt_oauth_token(stored, context)?),
        None => None,
    };

    let tokens = OAuth2Tokens {
        access_token,
        refresh_token,
        access_token_expires_at: account.access_token_expires_at,
        refresh_token_expires_at: account.refresh_token_expires_at,
        scopes: account_scopes(account),
        id_token: account.id_token.clone(),
        ..OAuth2Tokens::default()
    };
    Ok(tokens)
}

/// Reports whether `error` indicates that the provider cannot refresh tokens.
///
/// Detection is textual: the error's display output must contain the phrase
/// "does not support refresh tokens".
#[cfg(feature = "oauth")]
pub(super) fn is_refresh_unsupported(error: &openauth_oauth::oauth2::OAuthError) -> bool {
    let message = error.to_string();
    message.contains("does not support refresh tokens")
}

/// Builds the `400 Bad Request` error response with code
/// `PROVIDER_NOT_SUPPORTED`, naming the offending `provider_id` in the message.
#[cfg(feature = "oauth")]
pub(super) fn provider_not_supported(provider_id: &str) -> Result<ApiResponse, OpenAuthError> {
    let message = format!("Provider {provider_id} is not supported.");
    error_response(StatusCode::BAD_REQUEST, "PROVIDER_NOT_SUPPORTED", message)
}

/// Builds the `400 Bad Request` error response with code `ACCOUNT_NOT_FOUND`.
#[cfg(feature = "oauth")]
pub(super) fn account_not_found() -> Result<ApiResponse, OpenAuthError> {
    const CODE: &str = "ACCOUNT_NOT_FOUND";
    const MESSAGE: &str = "Account not found";

    error_response(StatusCode::BAD_REQUEST, CODE, MESSAGE)
}

/// Splits the account's comma-separated `scope` string into individual scopes.
///
/// Empty segments (for example from a trailing or doubled comma) are
/// skipped. An account without a scope string yields an empty list.
pub(super) fn account_scopes(account: &Account) -> Vec<String> {
    let mut scopes = Vec::new();
    if let Some(raw) = account.scope.as_deref() {
        for segment in raw.split(',') {
            if !segment.is_empty() {
                scopes.push(segment.to_owned());
            }
        }
    }
    scopes
}