rustauth-saml 0.2.0

SAML 2.0 service-provider support for RustAuth.
Documentation
use opensaml::constants::Binding;
use opensaml::entity::{now_iso8601, CustomTagReplacement};
use opensaml::template::replace_tags_by_value;
use url::Url;

use crate::bridge::{
    create_identity_provider, create_service_provider, opensaml_error_code, SpBuildOptions,
};
use crate::options::SamlConfig;

/// Outcome of building an HTTP-Redirect SAML `AuthnRequest`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SamlAuthnRequest {
    /// Identifier written into the request's `ID` attribute.
    pub id: String,
    /// URL produced by the service provider's login-request builder; the encoded request
    /// travels in its `SAMLRequest` query parameter.
    pub redirect_url: String,
}

/// Builds a SAML 2.0 `AuthnRequest` for the HTTP-Redirect binding.
///
/// The request XML comes from the `opensaml` login-request template. These tags are
/// substituted:
/// - `ID`: `request_id`.
/// - `IssueInstant`: the current time.
/// - `Destination`: the IdP's Redirect SSO endpoint from its metadata, falling back to
///   `config.entry_point`.
/// - `AssertionConsumerServiceURL`: see [`assertion_consumer_service_url`].
/// - `Issuer`: `config.sp_metadata.entity_id`, falling back to `config.issuer`.
/// - `NameIDFormat`: `config.identifier_format`, defaulting to the transient format.
/// - `AllowCreate`: `true`.
///
/// * `provider_id` / `base_url` – used to build the service provider and derive the ACS URL.
/// * `config` – provider configuration.
/// * `request_id` – the request `ID`; also returned as [`SamlAuthnRequest::id`].
/// * `relay_state` – passed to the service-provider builder as the RelayState.
///
/// # Errors
///
/// * [`SamlAuthnRequestError::PrivateKeyRequired`] when `config.authn_requests_signed` is set
///   but neither `config.private_key` nor `config.sp_metadata.private_key` is present.
/// * Any failure building the SP/IdP or the login request, translated by
///   `map_authn_request_error`.
pub fn build_authn_request_redirect(
    provider_id: &str,
    base_url: &str,
    config: &SamlConfig,
    request_id: String,
    relay_state: String,
) -> Result<SamlAuthnRequest, SamlAuthnRequestError> {
    // Fail fast: signing was requested but no SP key material is configured anywhere.
    if config.authn_requests_signed
        && config.private_key.is_none()
        && config.sp_metadata.private_key.is_none()
    {
        return Err(SamlAuthnRequestError::PrivateKeyRequired);
    }

    let sp = create_service_provider(
        config,
        base_url,
        provider_id,
        &SpBuildOptions {
            relay_state: Some(relay_state),
            ..Default::default()
        },
    )
    .map_err(map_authn_request_error)?;
    let idp = create_identity_provider(config).map_err(map_authn_request_error)?;

    // Prefer the IdP metadata's Redirect-binding SSO endpoint; otherwise use the configured
    // entry point.
    let destination = idp
        .metadata
        .get_single_sign_on_service(Binding::Redirect)
        .unwrap_or_else(|| config.entry_point.clone());

    // These values do not change between template invocations, so resolve them once. The
    // callback below only borrows locals, so there is no need to deep-clone `config` for it.
    let acs = assertion_consumer_service_url(provider_id, base_url, config);
    let issuer = config
        .sp_metadata
        .entity_id
        .as_deref()
        .unwrap_or(config.issuer.as_str())
        .to_owned();
    let name_id_format = config
        .identifier_format
        .clone()
        .unwrap_or_else(|| "urn:oasis:names:tc:SAML:2.0:nameid-format:transient".to_owned());

    // Separate copy so the callback's borrow never conflicts with moving `request_id` into
    // the result below.
    let request_id_for_custom = request_id.clone();
    let custom: CustomTagReplacement = &|template| {
        let xml = replace_tags_by_value(
            template,
            &[
                ("ID", request_id_for_custom.clone()),
                // Computed per invocation so the timestamp reflects when the XML is rendered.
                ("IssueInstant", now_iso8601()),
                ("Destination", destination.clone()),
                ("AssertionConsumerServiceURL", acs.clone()),
                ("Issuer", issuer.clone()),
                ("NameIDFormat", name_id_format.clone()),
                ("AllowCreate", "true".to_string()),
            ],
        );
        (request_id_for_custom.clone(), xml)
    };

    let context = sp
        .create_login_request(&idp, Binding::Redirect, Some(custom))
        .map_err(map_authn_request_error)?;

    Ok(SamlAuthnRequest {
        id: request_id,
        redirect_url: context.context,
    })
}

pub fn authn_request_xml(
    provider_id: &str,
    base_url: &str,
    config: &SamlConfig,
    request_id: &str,
) -> Result<String, SamlAuthnRequestError> {
    let redirect = build_authn_request_redirect(
        provider_id,
        base_url,
        config,
        request_id.to_owned(),
        String::new(),
    )?;
    let url = Url::parse(&redirect.redirect_url)
        .map_err(|source| SamlAuthnRequestError::InvalidEntryPoint(source.to_string()))?;
    let encoded = url
        .query_pairs()
        .find(|(key, _)| key == "SAMLRequest")
        .map(|(_, value)| value.into_owned())
        .ok_or_else(|| SamlAuthnRequestError::Encode("missing SAMLRequest".to_owned()))?;
    decode_redirect_authn_request(&encoded)
}

/// Resolves the Assertion Consumer Service (ACS) URL advertised in AuthnRequests.
///
/// Precedence:
/// 1. `config.acs_url`, when set and not blank.
/// 2. `config.callback_url`, when not blank.
/// 3. `{base_url}/sso/saml2/sp/acs/{provider_id}`, with trailing `/` stripped from `base_url`.
///
/// "Blank" means empty or whitespace-only. Both configured values are checked the same
/// way, so a whitespace-only setting never becomes the ACS URL.
pub fn assertion_consumer_service_url(
    provider_id: &str,
    base_url: &str,
    config: &SamlConfig,
) -> String {
    fn is_blank(value: &str) -> bool {
        value.trim().is_empty()
    }

    if let Some(acs_url) = config.acs_url.as_deref().filter(|value| !is_blank(value)) {
        return acs_url.to_owned();
    }

    // Previously only `is_empty()` was checked here, so a whitespace-only callback URL
    // leaked through as the ACS URL; treat it like a blank `acs_url` instead.
    if !is_blank(&config.callback_url) {
        return config.callback_url.clone();
    }

    format!(
        "{}/sso/saml2/sp/acs/{}",
        base_url.trim_end_matches('/'),
        provider_id
    )
}

fn decode_redirect_authn_request(encoded: &str) -> Result<String, SamlAuthnRequestError> {
    let bytes = base64::Engine::decode(
        &base64::engine::general_purpose::STANDARD,
        encoded.as_bytes(),
    )
    .map_err(|source| SamlAuthnRequestError::Encode(source.to_string()))?;
    let mut decoder = flate2::read::DeflateDecoder::new(bytes.as_slice());
    let mut xml = String::new();
    std::io::Read::read_to_string(&mut decoder, &mut xml)
        .map_err(|source| SamlAuthnRequestError::Encode(source.to_string()))?;
    Ok(xml)
}

fn map_authn_request_error(error: opensaml::error::OpenSamlError) -> SamlAuthnRequestError {
    match &error {
        opensaml::error::OpenSamlError::MissingKey(_) => SamlAuthnRequestError::PrivateKeyRequired,
        opensaml::error::OpenSamlError::Unsupported(_) => {
            SamlAuthnRequestError::SigningNotSupported
        }
        opensaml::error::OpenSamlError::Invalid(message) if message.contains("ENTRY_POINT") => {
            SamlAuthnRequestError::InvalidEntryPoint(message.clone())
        }
        other => SamlAuthnRequestError::Sign(format!("{other} ({})", opensaml_error_code(other))),
    }
}

#[derive(Debug, thiserror::Error)]
pub enum SamlAuthnRequestError {
    #[error("invalid SAML entry point: {0}")]
    InvalidEntryPoint(String),
    #[error("failed to encode SAML AuthnRequest: {0}")]
    Encode(String),
    #[error("signed SAML AuthnRequests require SP private key support")]
    SigningNotSupported,
    #[error("signed SAML AuthnRequests require SP private key material")]
    PrivateKeyRequired,
    #[error("invalid SAML AuthnRequest private key: {0}")]
    InvalidPrivateKey(String),
    #[error("{0}")]
    Sign(String),
}