use rustauth_core::crypto::random::generate_random_string;
use rustauth_core::db::{
Create, DbAdapter, DbRecord, DbValue, Delete, FindMany, FindOne, Update, Where,
};
use rustauth_core::error::RustAuthError;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
#[cfg(feature = "oidc")]
use crate::oidc_impl::flow::oidc_redirect_uri;
use crate::options::OidcConfig;
#[cfg(feature = "saml")]
use crate::options::SamlConfig;
use crate::schema::SSO_PROVIDER_MODEL;
#[cfg(feature = "saml")]
use crate::utils::certificate_metadata;
use crate::utils::client_id_last_four;
/// Columns selected when reading SSO provider rows while domain verification is
/// disabled. `domain_verified` is deliberately absent, so decoded records carry
/// `domain_verified: None` in that mode.
const SSO_PROVIDER_FIELDS: [&str; 9] = [
    "id",
    "issuer",
    "oidc_config",
    "saml_config",
    "user_id",
    "provider_id",
    "organization_id",
    "domain",
    "created_at",
];
/// Columns selected when reading SSO provider rows while domain verification is
/// enabled: the same list as `SSO_PROVIDER_FIELDS` plus `domain_verified`.
/// Keep the two lists in sync when adding columns.
const SSO_PROVIDER_FIELDS_WITH_DOMAIN_VERIFIED: [&str; 10] = [
    "id",
    "issuer",
    "oidc_config",
    "saml_config",
    "user_id",
    "provider_id",
    "organization_id",
    "domain",
    "domain_verified",
    "created_at",
];
/// An SSO provider row as read from (or written to) the database.
///
/// Built from a raw `DbRecord` by `record_from_db`; serializes with camelCase keys.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SsoProviderRecord {
    /// Primary key; `SsoProviderStore::create` fills it with `generate_random_string(32)`.
    pub id: String,
    /// Value of the `issuer` column.
    pub issuer: String,
    /// OIDC configuration as JSON text, if the provider has one. JSON-typed
    /// column values are re-serialized to text when decoded.
    pub oidc_config: Option<String>,
    /// SAML configuration as JSON text, if the provider has one.
    pub saml_config: Option<String>,
    /// Value of the `user_id` column; `SsoProviderStore::list_by_user` filters on it.
    pub user_id: String,
    /// Identifier the store uses to find, update and delete the provider.
    pub provider_id: String,
    /// Optional `organization_id` column.
    pub organization_id: Option<String>,
    /// Value of the `domain` column.
    pub domain: String,
    /// `None` when the column is NULL or was not selected (the store only selects
    /// it when built with domain verification enabled).
    pub domain_verified: Option<bool>,
    /// `None` when the column is NULL or absent from the row.
    pub created_at: Option<OffsetDateTime>,
}
/// Client-facing view of an SSO provider, produced by
/// `SsoProviderRecord::sanitized_with_options`. It contains no raw configuration text.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SanitizedSsoProvider {
    pub provider_id: String,
    /// `"saml"` or `"oidc"`; serialized as `providerType`.
    pub provider_type: String,
    /// Same value as `provider_type`, additionally exposed under the `type` key.
    #[serde(rename = "type")]
    pub upstream_type: String,
    pub issuer: String,
    pub domain: String,
    pub organization_id: Option<String>,
    /// A missing/unselected `domain_verified` value is reported as `false`.
    pub domain_verified: bool,
    pub oidc_config: Option<SanitizedOidcConfig>,
    pub saml_config: Option<SanitizedSamlConfig>,
    /// OIDC callback URI; serialized as `redirectURI` and omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none", rename = "redirectURI")]
    pub redirect_uri: Option<String>,
    /// `{base_url}/sso/saml2/sp/metadata?providerId=...`; it is built for every
    /// provider, regardless of type.
    pub sp_metadata_url: String,
}
/// Non-secret subset of an `OidcConfig`, safe to return to clients.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SanitizedOidcConfig {
    pub discovery_endpoint: String,
    /// Output of `client_id_last_four` applied to the client id; the full id is not exposed.
    pub client_id_last_four: String,
    pub pkce: bool,
    pub authorization_endpoint: Option<String>,
    pub token_endpoint: Option<String>,
    pub user_info_endpoint: Option<String>,
    pub jwks_endpoint: Option<String>,
    pub revocation_endpoint: Option<String>,
    pub end_session_endpoint: Option<String>,
    pub introspection_endpoint: Option<String>,
    pub token_endpoint_authentication: Option<crate::options::TokenEndpointAuthentication>,
    pub scopes: Option<Vec<String>>,
}
/// Non-secret subset of a `SamlConfig`, safe to return to clients. The raw
/// certificate is not included; only the summary from `certificate_metadata` is.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SanitizedSamlConfig {
    pub entry_point: String,
    pub callback_url: String,
    pub acs_url: Option<String>,
    pub audience: Option<String>,
    pub want_assertions_signed: bool,
    pub authn_requests_signed: bool,
    pub identifier_format: Option<String>,
    pub signature_algorithm: Option<String>,
    pub digest_algorithm: Option<String>,
    pub certificate_sha256_fingerprint: String,
    pub certificate_not_before: Option<String>,
    pub certificate_not_after: Option<String>,
    pub certificate_public_key_algorithm: Option<String>,
    /// Certificate parse error reported by `certificate_metadata`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub certificate_error: Option<String>,
}
/// Lightweight, copyable handle for CRUD operations on SSO provider rows.
#[derive(Clone, Copy)]
pub struct SsoProviderStore<'a> {
    /// Database adapter all queries are sent to.
    adapter: &'a dyn DbAdapter,
    /// Name of the model (table) holding SSO providers.
    model_name: &'a str,
    /// When `true`, reads select `domain_verified` as well.
    include_domain_verified: bool,
}
impl<'a> SsoProviderStore<'a> {
pub fn new(adapter: &'a dyn DbAdapter) -> Self {
Self::new_with_model(adapter, SSO_PROVIDER_MODEL)
}
pub fn new_with_model(adapter: &'a dyn DbAdapter, model_name: &'a str) -> Self {
Self {
adapter,
model_name,
include_domain_verified: false,
}
}
pub fn new_with_options(
adapter: &'a dyn DbAdapter,
options: &'a crate::options::SsoOptions,
) -> Self {
Self::new_with_model_and_domain_verification(
adapter,
&options.model_name,
options.domain_verification.enabled,
)
}
pub fn new_with_model_and_domain_verification(
adapter: &'a dyn DbAdapter,
model_name: &'a str,
include_domain_verified: bool,
) -> Self {
Self {
adapter,
model_name,
include_domain_verified,
}
}
pub async fn list(&self) -> Result<Vec<SsoProviderRecord>, RustAuthError> {
let query = self.select_find_many(FindMany::new(self.model_name));
self.adapter
.find_many(query)
.await?
.into_iter()
.map(record_from_db)
.collect()
}
pub async fn list_by_user(
&self,
user_id: &str,
) -> Result<Vec<SsoProviderRecord>, RustAuthError> {
let query = FindMany::new(self.model_name)
.where_clause(Where::new("user_id", DbValue::String(user_id.to_owned())));
self.adapter
.find_many(self.select_find_many(query))
.await?
.into_iter()
.map(record_from_db)
.collect()
}
pub async fn find_by_provider_id(
&self,
provider_id: &str,
) -> Result<Option<SsoProviderRecord>, RustAuthError> {
let query = FindOne::new(self.model_name).where_clause(provider_id_where(provider_id));
self.adapter
.find_one(self.select_find_one(query))
.await?
.map(record_from_db)
.transpose()
}
pub async fn find_by_organization_id(
&self,
organization_id: &str,
) -> Result<Option<SsoProviderRecord>, RustAuthError> {
let query = FindOne::new(self.model_name).where_clause(Where::new(
"organization_id",
DbValue::String(organization_id.to_owned()),
));
self.adapter
.find_one(self.select_find_one(query))
.await?
.map(record_from_db)
.transpose()
}
pub async fn create(
&self,
input: CreateSsoProviderInput,
) -> Result<SsoProviderRecord, RustAuthError> {
let now = OffsetDateTime::now_utc();
let mut query = Create::new(self.model_name)
.data("id", DbValue::String(generate_random_string(32)))
.data("issuer", DbValue::String(input.issuer))
.data("oidc_config", optional_string(input.oidc_config))
.data("saml_config", optional_string(input.saml_config))
.data("user_id", DbValue::String(input.user_id))
.data("provider_id", DbValue::String(input.provider_id))
.data("organization_id", optional_string(input.organization_id))
.data("domain", DbValue::String(input.domain))
.data("created_at", DbValue::Timestamp(now))
.data("updated_at", DbValue::Timestamp(now))
.force_allow_id();
query = self.select_create(query);
if let Some(domain_verified) = input.domain_verified {
query = query.data("domain_verified", DbValue::Boolean(domain_verified));
}
record_from_db(self.adapter.create(query).await?)
}
pub async fn update_domain_verified(
&self,
provider_id: &str,
verified: bool,
) -> Result<Option<SsoProviderRecord>, RustAuthError> {
self.adapter
.update(
Update::new(self.model_name)
.where_clause(provider_id_where(provider_id))
.data("domain_verified", DbValue::Boolean(verified))
.data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc())),
)
.await?
.map(record_from_db)
.transpose()
}
pub async fn update(
&self,
provider_id: &str,
input: UpdateSsoProviderInput,
) -> Result<Option<SsoProviderRecord>, RustAuthError> {
let mut query = Update::new(self.model_name)
.where_clause(provider_id_where(provider_id))
.data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc()));
if let Some(issuer) = input.issuer {
query = query.data("issuer", DbValue::String(issuer));
}
if let Some(domain) = input.domain {
query = query.data("domain", DbValue::String(domain));
}
if let Some(organization_id) = input.organization_id {
query = query.data("organization_id", DbValue::String(organization_id));
}
if let Some(oidc_config) = input.oidc_config {
query = query.data("oidc_config", optional_string(oidc_config));
}
if let Some(saml_config) = input.saml_config {
query = query.data("saml_config", optional_string(saml_config));
}
if let Some(domain_verified) = input.domain_verified {
query = query.data("domain_verified", DbValue::Boolean(domain_verified));
}
self.adapter
.update(query)
.await?
.map(record_from_db)
.transpose()
}
pub async fn delete(&self, provider_id: &str) -> Result<(), RustAuthError> {
self.adapter
.delete(Delete::new(self.model_name).where_clause(provider_id_where(provider_id)))
.await
}
fn select_create(&self, query: Create) -> Create {
if self.include_domain_verified {
query.select(SSO_PROVIDER_FIELDS_WITH_DOMAIN_VERIFIED)
} else {
query.select(SSO_PROVIDER_FIELDS)
}
}
fn select_find_one(&self, query: FindOne) -> FindOne {
if self.include_domain_verified {
query.select(SSO_PROVIDER_FIELDS_WITH_DOMAIN_VERIFIED)
} else {
query.select(SSO_PROVIDER_FIELDS)
}
}
fn select_find_many(&self, query: FindMany) -> FindMany {
if self.include_domain_verified {
query.select(SSO_PROVIDER_FIELDS_WITH_DOMAIN_VERIFIED)
} else {
query.select(SSO_PROVIDER_FIELDS)
}
}
}
/// Input for `SsoProviderStore::create`. The row id and timestamps are generated
/// by the store, not supplied here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CreateSsoProviderInput {
    pub provider_id: String,
    pub issuer: String,
    pub domain: String,
    pub user_id: String,
    /// Written as NULL when `None`.
    pub organization_id: Option<String>,
    /// OIDC configuration as JSON text; written as NULL when `None`.
    pub oidc_config: Option<String>,
    /// SAML configuration as JSON text; written as NULL when `None`.
    pub saml_config: Option<String>,
    /// When `None` the `domain_verified` column is not written at all.
    pub domain_verified: Option<bool>,
}
/// Partial update for `SsoProviderStore::update`; `None` fields leave the
/// corresponding column unchanged.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct UpdateSsoProviderInput {
    pub issuer: Option<String>,
    pub domain: Option<String>,
    /// `Some` replaces the value; this input has no way to clear the column.
    pub organization_id: Option<String>,
    /// `None` leaves the column unchanged, `Some(None)` writes NULL and
    /// `Some(Some(json))` stores the new JSON text.
    pub oidc_config: Option<Option<String>>,
    /// Same three-state semantics as `oidc_config`.
    pub saml_config: Option<Option<String>>,
    pub domain_verified: Option<bool>,
}
impl SsoProviderRecord {
    /// Builds the client-facing view of this provider.
    ///
    /// The stored `oidc_config` / `saml_config` JSON text is parsed on the fly.
    /// A configuration that is missing or fails to parse is reported as `None`
    /// rather than as an error. Only non-secret settings are copied over: the
    /// OIDC client id goes through `client_id_last_four`, and the SAML
    /// certificate is summarised by `certificate_metadata` instead of being
    /// returned verbatim.
    ///
    /// * `base_url` - public base URL of the auth server. Trailing `/`
    ///   characters are trimmed before `sp_metadata_url` is built.
    /// * `options` - plugin options. With the `oidc` feature enabled, they are
    ///   used to compute `redirect_uri` for providers whose OIDC config parsed.
    ///   They are otherwise unused.
    pub fn sanitized_with_options(
        &self,
        base_url: &str,
        options: Option<&crate::options::SsoOptions>,
    ) -> SanitizedSsoProvider {
        let oidc_config = self
            .oidc_config
            .as_deref()
            .and_then(|value| serde_json::from_str::<OidcConfig>(value).ok())
            .map(|config| SanitizedOidcConfig {
                discovery_endpoint: config.discovery_endpoint,
                client_id_last_four: client_id_last_four(&config.client_id),
                pkce: config.pkce,
                authorization_endpoint: config.authorization_endpoint,
                token_endpoint: config.token_endpoint,
                user_info_endpoint: config.user_info_endpoint,
                jwks_endpoint: config.jwks_endpoint,
                revocation_endpoint: config.revocation_endpoint,
                end_session_endpoint: config.end_session_endpoint,
                introspection_endpoint: config.introspection_endpoint,
                token_endpoint_authentication: config.token_endpoint_authentication,
                scopes: config.scopes,
            });
        #[cfg(feature = "saml")]
        let saml_config = self
            .saml_config
            .as_deref()
            .and_then(|value| serde_json::from_str::<SamlConfig>(value).ok())
            .map(|config| {
                let certificate = certificate_metadata(&config.cert);
                SanitizedSamlConfig {
                    entry_point: config.entry_point,
                    callback_url: config.callback_url,
                    acs_url: config.acs_url,
                    audience: config.audience,
                    want_assertions_signed: config.want_assertions_signed,
                    authn_requests_signed: config.authn_requests_signed,
                    identifier_format: config.identifier_format,
                    signature_algorithm: config.signature_algorithm,
                    digest_algorithm: config.digest_algorithm,
                    certificate_sha256_fingerprint: certificate.sha256_fingerprint,
                    certificate_not_before: certificate.not_before,
                    certificate_not_after: certificate.not_after,
                    certificate_public_key_algorithm: certificate.public_key_algorithm,
                    certificate_error: certificate.parse_error,
                }
            });
        #[cfg(not(feature = "saml"))]
        let saml_config = None;
        // A provider is reported as SAML only when its SAML config parsed;
        // everything else (including unparsable configs) is reported as OIDC.
        let provider_type = if saml_config.is_some() {
            "saml"
        } else {
            "oidc"
        }
        .to_owned();
        #[cfg(feature = "oidc")]
        let redirect_uri = oidc_config.as_ref().and_then(|_| {
            options.map(|options| oidc_redirect_uri(base_url, &self.provider_id, options))
        });
        #[cfg(not(feature = "oidc"))]
        let redirect_uri = {
            // `options` only feeds the OIDC redirect URI. Consume it explicitly
            // so that builds without the `oidc` feature do not emit an
            // unused-parameter warning (fatal under `-D warnings`).
            let _ = options;
            None
        };
        SanitizedSsoProvider {
            provider_id: self.provider_id.clone(),
            provider_type: provider_type.clone(),
            upstream_type: provider_type,
            issuer: self.issuer.clone(),
            domain: self.domain.clone(),
            organization_id: self.organization_id.clone(),
            // Unselected or NULL verification state is reported as unverified.
            domain_verified: self.domain_verified.unwrap_or(false),
            oidc_config,
            saml_config,
            redirect_uri,
            sp_metadata_url: format!(
                "{}/sso/saml2/sp/metadata?providerId={}",
                base_url.trim_end_matches('/'),
                url::form_urlencoded::byte_serialize(self.provider_id.as_bytes())
                    .collect::<String>()
            ),
        }
    }

    /// Same as [`Self::sanitized_with_options`] without plugin options, so
    /// `redirect_uri` is always `None`.
    pub fn sanitized(&self, base_url: &str) -> SanitizedSsoProvider {
        self.sanitized_with_options(base_url, None)
    }
}
/// Builds the `provider_id = <value>` filter shared by lookups, updates and deletes.
fn provider_id_where(provider_id: &str) -> Where {
    let value = DbValue::String(String::from(provider_id));
    Where::new("provider_id", value)
}
/// Maps an optional string onto a column value: `Some` becomes a string and
/// `None` becomes SQL NULL.
fn optional_string(value: Option<String>) -> DbValue {
    match value {
        Some(text) => DbValue::String(text),
        None => DbValue::Null,
    }
}
/// Decodes a raw adapter row into an [`SsoProviderRecord`].
///
/// Fields are decoded in declaration order, and the first missing or
/// mistyped field aborts decoding with an adapter error.
fn record_from_db(record: DbRecord) -> Result<SsoProviderRecord, RustAuthError> {
    let owned = |field: &str| required_string(&record, field).map(str::to_owned);
    let id = owned("id")?;
    let issuer = owned("issuer")?;
    let oidc_config = optional_string_field(&record, "oidc_config")?;
    let saml_config = optional_string_field(&record, "saml_config")?;
    let user_id = owned("user_id")?;
    let provider_id = owned("provider_id")?;
    let organization_id = optional_string_field(&record, "organization_id")?;
    let domain = owned("domain")?;
    let domain_verified = optional_bool_field(&record, "domain_verified")?;
    let created_at = optional_timestamp_field(&record, "created_at")?;
    Ok(SsoProviderRecord {
        id,
        issuer,
        oidc_config,
        saml_config,
        user_id,
        provider_id,
        organization_id,
        domain,
        domain_verified,
        created_at,
    })
}
/// Borrows a mandatory string column from `record`.
///
/// # Errors
/// A "missing" adapter error when the column is absent, and an "invalid"
/// error when it holds a non-string value.
fn required_string<'a>(record: &'a DbRecord, field: &str) -> Result<&'a str, RustAuthError> {
    let value = record.get(field).ok_or_else(|| missing_field(field))?;
    match value {
        DbValue::String(text) => Ok(text.as_str()),
        _ => Err(invalid_field(field, "string")),
    }
}
fn optional_string_field(record: &DbRecord, field: &str) -> Result<Option<String>, RustAuthError> {
match record.get(field) {
Some(DbValue::String(value)) => Ok(Some(value.to_owned())),
Some(DbValue::Json(value)) => serde_json::to_string(value)
.map(Some)
.map_err(|error| RustAuthError::Adapter(format!("invalid JSON in `{field}`: {error}"))),
Some(DbValue::Null) | None => Ok(None),
Some(_) => Err(invalid_field(field, "string, JSON, or null")),
}
}
/// Reads an optional boolean column; absent or NULL values yield `None`.
///
/// # Errors
/// An adapter error when the column holds a non-boolean value.
fn optional_bool_field(record: &DbRecord, field: &str) -> Result<Option<bool>, RustAuthError> {
    record.get(field).map_or(Ok(None), |value| match value {
        DbValue::Boolean(flag) => Ok(Some(*flag)),
        DbValue::Null => Ok(None),
        _ => Err(invalid_field(field, "boolean or null")),
    })
}
/// Reads an optional timestamp column; absent or NULL values yield `None`.
///
/// # Errors
/// An adapter error when the column holds a non-timestamp value.
fn optional_timestamp_field(
    record: &DbRecord,
    field: &str,
) -> Result<Option<OffsetDateTime>, RustAuthError> {
    record
        .get(field)
        .and_then(|value| match value {
            DbValue::Timestamp(moment) => Some(Ok(*moment)),
            DbValue::Null => None,
            _ => Some(Err(invalid_field(field, "timestamp or null"))),
        })
        .transpose()
}
/// Adapter error for a required column that is absent from the row.
fn missing_field(field: &str) -> RustAuthError {
    let message = format!("sso provider record is missing `{field}`");
    RustAuthError::Adapter(message)
}
/// Adapter error for a column whose value has an unexpected type.
/// `expected` is a human-readable description of the accepted types.
fn invalid_field(field: &str, expected: &str) -> RustAuthError {
    let message = format!("sso provider record field `{field}` must be {expected}");
    RustAuthError::Adapter(message)
}