rustauth-scim 0.3.0

SCIM support for RustAuth.
Documentation
//! Adapter-backed SCIM storage helpers.

use rustauth_core::crypto::random::generate_random_string;
use rustauth_core::db::{
    Create, DbAdapter, DbRecord, DbValue, Delete, FindMany, FindOne, Update, Where,
};
use rustauth_core::error::RustAuthError;
use serde::{Deserialize, Serialize};

use crate::schema::SCIM_PROVIDER_MODEL;

const SCIM_PROVIDER_FIELDS: [&str; 5] = [
    "id",
    "provider_id",
    "scim_token",
    "organization_id",
    "user_id",
];

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ScimProviderRecord {
    pub id: String,
    pub provider_id: String,
    pub scim_token: String,
    pub organization_id: Option<String>,
    pub user_id: Option<String>,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CreateScimProviderInput {
    pub provider_id: String,
    pub scim_token: String,
    pub organization_id: Option<String>,
    pub user_id: Option<String>,
}

#[derive(Clone, Copy)]
pub struct ScimProviderStore<'a> {
    adapter: &'a dyn DbAdapter,
}

impl<'a> ScimProviderStore<'a> {
    pub fn new(adapter: &'a dyn DbAdapter) -> Self {
        Self { adapter }
    }

    pub async fn create(
        &self,
        input: CreateScimProviderInput,
    ) -> Result<ScimProviderRecord, RustAuthError> {
        let record = self
            .adapter
            .create(
                Create::new(SCIM_PROVIDER_MODEL)
                    .data("id", DbValue::String(generate_random_string(32)))
                    .data("provider_id", DbValue::String(input.provider_id))
                    .data("scim_token", DbValue::String(input.scim_token))
                    .data("organization_id", optional_string(input.organization_id))
                    .data("user_id", optional_string(input.user_id))
                    .select(SCIM_PROVIDER_FIELDS)
                    .force_allow_id(),
            )
            .await?;

        record_from_db(record)
    }

    pub async fn list(&self) -> Result<Vec<ScimProviderRecord>, RustAuthError> {
        self.adapter
            .find_many(FindMany::new(SCIM_PROVIDER_MODEL).select(SCIM_PROVIDER_FIELDS))
            .await?
            .into_iter()
            .map(record_from_db)
            .collect()
    }

    #[cfg_attr(not(feature = "test-util"), allow(dead_code))]
    pub async fn list_by_user(
        &self,
        user_id: &str,
    ) -> Result<Vec<ScimProviderRecord>, RustAuthError> {
        self.adapter
            .find_many(
                FindMany::new(SCIM_PROVIDER_MODEL)
                    .where_clause(Where::new("user_id", DbValue::String(user_id.to_owned())))
                    .select(SCIM_PROVIDER_FIELDS),
            )
            .await?
            .into_iter()
            .map(record_from_db)
            .collect()
    }

    pub async fn find_by_provider_id(
        &self,
        provider_id: &str,
    ) -> Result<Option<ScimProviderRecord>, RustAuthError> {
        self.adapter
            .find_one(
                FindOne::new(SCIM_PROVIDER_MODEL)
                    .where_clause(provider_id_where(provider_id))
                    .select(SCIM_PROVIDER_FIELDS),
            )
            .await?
            .map(record_from_db)
            .transpose()
    }

    #[cfg_attr(not(feature = "test-util"), allow(dead_code))]
    pub async fn find_by_organization_id(
        &self,
        organization_id: &str,
    ) -> Result<Option<ScimProviderRecord>, RustAuthError> {
        self.adapter
            .find_one(
                FindOne::new(SCIM_PROVIDER_MODEL)
                    .where_clause(Where::new(
                        "organization_id",
                        DbValue::String(organization_id.to_owned()),
                    ))
                    .select(SCIM_PROVIDER_FIELDS),
            )
            .await?
            .map(record_from_db)
            .transpose()
    }

    pub async fn delete(&self, provider_id: &str) -> Result<(), RustAuthError> {
        self.adapter
            .delete(Delete::new(SCIM_PROVIDER_MODEL).where_clause(provider_id_where(provider_id)))
            .await
    }

    pub async fn upsert(
        &self,
        input: CreateScimProviderInput,
    ) -> Result<ScimProviderRecord, RustAuthError> {
        if self
            .find_by_provider_id(&input.provider_id)
            .await?
            .is_some()
        {
            self.update(input).await
        } else {
            self.create(input).await
        }
    }

    async fn update(
        &self,
        input: CreateScimProviderInput,
    ) -> Result<ScimProviderRecord, RustAuthError> {
        self.adapter
            .update(
                Update::new(SCIM_PROVIDER_MODEL)
                    .where_clause(provider_id_where(&input.provider_id))
                    .data("scim_token", DbValue::String(input.scim_token))
                    .data("organization_id", optional_string(input.organization_id))
                    .data("user_id", optional_string(input.user_id)),
            )
            .await?;
        self.find_by_provider_id(&input.provider_id)
            .await?
            .ok_or_else(|| {
                RustAuthError::Adapter(format!(
                    "SCIM provider `{}` was not found after update",
                    input.provider_id
                ))
            })
    }
}

fn provider_id_where(provider_id: &str) -> Where {
    Where::new("provider_id", DbValue::String(provider_id.to_owned()))
}

fn optional_string(value: Option<String>) -> DbValue {
    value.map(DbValue::String).unwrap_or(DbValue::Null)
}

fn record_from_db(record: DbRecord) -> Result<ScimProviderRecord, RustAuthError> {
    Ok(ScimProviderRecord {
        id: required_string(&record, "id")?.to_owned(),
        provider_id: required_string(&record, "provider_id")?.to_owned(),
        scim_token: required_string(&record, "scim_token")?.to_owned(),
        organization_id: optional_string_field(&record, "organization_id")?,
        user_id: optional_string_field(&record, "user_id")?,
    })
}

fn required_string<'a>(record: &'a DbRecord, field: &str) -> Result<&'a str, RustAuthError> {
    match record.get(field) {
        Some(DbValue::String(value)) => Ok(value),
        Some(_) => Err(invalid_field(field, "string")),
        None => Err(missing_field(field)),
    }
}

fn optional_string_field(record: &DbRecord, field: &str) -> Result<Option<String>, RustAuthError> {
    match record.get(field) {
        Some(DbValue::String(value)) => Ok(Some(value.to_owned())),
        Some(DbValue::Null) | None => Ok(None),
        Some(_) => Err(invalid_field(field, "string or null")),
    }
}

fn missing_field(field: &str) -> RustAuthError {
    RustAuthError::Adapter(format!("SCIM provider record is missing `{field}`"))
}

fn invalid_field(field: &str, expected: &str) -> RustAuthError {
    RustAuthError::Adapter(format!(
        "SCIM provider record field `{field}` must be {expected}"
    ))
}