// samael 0.0.21
//
// A SAML2 library for Rust
use crate::crypto::CertificateDer;
#[cfg(feature = "xmlsec")]
use crate::crypto::{Crypto, CryptoProvider};
use crate::schema::{Conditions, Issuer, NameIdPolicy, RequestedAuthnContext, Subject};
use crate::signature::Signature;
use chrono::prelude::*;
use quick_xml::events::{BytesDecl, BytesEnd, BytesStart, BytesText, Event};
use quick_xml::Writer;
use serde::Deserialize;
use std::io::Cursor;
use std::str::FromStr;
use thiserror::Error;

/// Qualified element name used for the opening and closing tags of the
/// serialized request.
const NAME: &str = "saml2p:AuthnRequest";
/// Attribute (name, value) pair pushed onto the root element to declare the
/// `saml2p` namespace prefix.
const SCHEMA: (&str, &str) = ("xmlns:saml2p", "urn:oasis:names:tc:SAML:2.0:protocol");

/// In-memory representation of a `saml2p:AuthnRequest` document.
///
/// Deserialized through serde (see the `FromStr` impl) and serialized by the
/// `TryFrom<&AuthnRequest> for Event` impl in this file. Fields whose serde
/// name starts with `@` are written as attributes of the root element by that
/// serializer; the remaining fields are written as child elements.
///
/// NOTE: `Ord`/`PartialOrd`/`Hash` are derived, so the declaration order of
/// the fields below is significant and should not be rearranged casually.
#[derive(Clone, Debug, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct AuthnRequest {
    /// Value of the `ID` attribute (required).
    #[serde(rename = "@ID")]
    pub id: String,
    /// Value of the `Version` attribute (required).
    #[serde(rename = "@Version")]
    pub version: String,
    /// Value of the `IssueInstant` attribute, as a UTC timestamp (required).
    #[serde(rename = "@IssueInstant")]
    pub issue_instant: DateTime<Utc>,
    /// Value of the `Destination` attribute, if present.
    #[serde(rename = "@Destination")]
    pub destination: Option<String>,
    /// Value of the `Consent` attribute, if present.
    #[serde(rename = "@Consent")]
    pub consent: Option<String>,
    /// The `Issuer` child element, if present.
    #[serde(rename = "Issuer")]
    pub issuer: Option<Issuer>,
    /// The `Signature` child element, if present.
    #[serde(rename = "Signature")]
    pub signature: Option<Signature>,
    /// The `Subject` child element, if present.
    #[serde(rename = "Subject")]
    pub subject: Option<Subject>,
    /// The `NameIDPolicy` child element, if present.
    #[serde(rename = "NameIDPolicy")]
    pub name_id_policy: Option<NameIdPolicy>,
    /// The `Conditions` child element, if present.
    #[serde(rename = "Conditions")]
    pub conditions: Option<Conditions>,
    /// The `RequestedAuthnContext` child element, if present.
    #[serde(rename = "RequestedAuthnContext")]
    pub requested_authn_context: Option<RequestedAuthnContext>,
    /// Value of the `ForceAuthn` attribute, if present.
    #[serde(rename = "@ForceAuthn")]
    pub force_authn: Option<bool>,
    /// Value of the `IsPassive` attribute, if present.
    #[serde(rename = "@IsPassive")]
    pub is_passive: Option<bool>,
    /// Value of the `AssertionConsumerServiceIndex` attribute, if present.
    #[serde(rename = "@AssertionConsumerServiceIndex")]
    pub assertion_consumer_service_index: Option<usize>,
    /// Value of the `AssertionConsumerServiceURL` attribute, if present.
    #[serde(rename = "@AssertionConsumerServiceURL")]
    pub assertion_consumer_service_url: Option<String>,
    /// Value of the `ProtocolBinding` attribute, if present.
    #[serde(rename = "@ProtocolBinding")]
    pub protocol_binding: Option<String>,
    /// Value of the `AttributeConsumingServiceIndex` attribute, if present.
    #[serde(rename = "@AttributeConsumingServiceIndex")]
    pub attribute_consuming_service_index: Option<usize>,
    /// Value of the `ProviderName` attribute, if present.
    #[serde(rename = "@ProviderName")]
    pub provider_name: Option<String>,
}

impl Default for AuthnRequest {
    /// Builds a blank request: empty `id` and `version`, `issue_instant` set
    /// to the current UTC time, and every optional field left as `None`.
    fn default() -> Self {
        Self {
            // Required attributes.
            id: String::new(),
            version: String::new(),
            issue_instant: Utc::now(),

            // Optional attributes.
            destination: None,
            consent: None,
            force_authn: None,
            is_passive: None,
            assertion_consumer_service_index: None,
            assertion_consumer_service_url: None,
            protocol_binding: None,
            attribute_consuming_service_index: None,
            provider_name: None,

            // Optional child elements.
            issuer: None,
            signature: None,
            subject: None,
            name_id_policy: None,
            conditions: None,
            requested_authn_context: None,
        }
    }
}

/// Errors produced while parsing or querying an [`AuthnRequest`].
#[derive(Debug, Error)]
pub enum Error {
    /// The XML could not be deserialized into an [`AuthnRequest`].
    ///
    /// Created automatically from a `quick_xml::DeError` via `#[from]`.
    #[error("Failed to deserialize AuthnRequest: {:?}", source)]
    ParseError {
        /// The underlying deserialization error.
        #[from]
        source: quick_xml::DeError,
    },

    /// Returned by `AuthnRequest::subject_name_id` when the request has no
    /// subject, or the subject carries no name ID.
    #[error("No subject name ID")]
    NoSubjectNameID,
}

impl FromStr for AuthnRequest {
    type Err = Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(quick_xml::de::from_str(s)?)
    }
}

impl AuthnRequest {
    pub fn subject_name_id(&self) -> Result<String, Error> {
        Ok(self
            .subject
            .clone()
            .and_then(|s| s.name_id)
            .ok_or(Error::NoSubjectNameID)?
            .value)
    }

    pub fn issuer_at(&self) -> &DateTime<Utc> {
        &self.issue_instant
    }

    pub fn issuer_value(&self) -> Option<String> {
        self.issuer.clone().and_then(|iss| iss.value)
    }

    pub fn add_key_info(&mut self, public_cert_der: &CertificateDer) -> &mut Self {
        if let Some(ref mut signature) = self.signature {
            signature.add_key_info(public_cert_der);
        }
        self
    }

    #[cfg(feature = "xmlsec")]
    pub fn to_signed_xml(
        &self,
        private_key_der: &[u8],
    ) -> Result<String, Box<dyn std::error::Error>> {
        use crate::traits::ToXml;

        Crypto::sign_xml(self.to_string()?, private_key_der)
            .map_err(|crypto_error| Box::new(crypto_error) as Box<dyn std::error::Error>)
    }
}

impl TryFrom<AuthnRequest> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    /// Converts an owned request by delegating to the by-reference
    /// conversion.
    fn try_from(value: AuthnRequest) -> Result<Self, Self::Error> {
        <Self as TryFrom<&AuthnRequest>>::try_from(&value)
    }
}

impl TryFrom<&AuthnRequest> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    /// Serializes the request to XML and returns it as a single text event.
    ///
    /// The output consists of an XML declaration, the root element (`NAME`)
    /// carrying the namespace declaration plus every present attribute, the
    /// present child elements, and the closing tag. The resulting string is
    /// wrapped with `BytesText::from_escaped`, i.e. it is treated as already
    /// escaped.
    fn try_from(value: &AuthnRequest) -> Result<Self, Self::Error> {
        let mut xml_bytes = Vec::new();
        let mut writer = Writer::new(Cursor::new(&mut xml_bytes));

        let declaration = BytesDecl::new("1.0", Some("UTF-8"), None);
        writer.write_event(Event::Decl(declaration))?;

        // Required attributes.
        let issue_instant = value
            .issue_instant
            .to_rfc3339_opts(SecondsFormat::Millis, true);

        let mut root = BytesStart::new(NAME);
        root.push_attribute(SCHEMA);
        root.push_attribute(("ID", value.id.as_str()));
        root.push_attribute(("Version", value.version.as_str()));
        root.push_attribute(("IssueInstant", issue_instant.as_str()));

        // Optional attributes, emitted only when set.
        if let Some(destination) = value.destination.as_deref() {
            root.push_attribute(("Destination", destination));
        }
        if let Some(consent) = value.consent.as_deref() {
            root.push_attribute(("Consent", consent));
        }
        if let Some(force_authn) = value.force_authn {
            let text = force_authn.to_string();
            root.push_attribute(("ForceAuthn", text.as_str()));
        }
        if let Some(is_passive) = value.is_passive {
            let text = is_passive.to_string();
            root.push_attribute(("IsPassive", text.as_str()));
        }
        if let Some(protocol_binding) = value.protocol_binding.as_deref() {
            root.push_attribute(("ProtocolBinding", protocol_binding));
        }
        if let Some(index) = value.assertion_consumer_service_index {
            let text = index.to_string();
            root.push_attribute(("AssertionConsumerServiceIndex", text.as_str()));
        }
        if let Some(url) = value.assertion_consumer_service_url.as_deref() {
            root.push_attribute(("AssertionConsumerServiceURL", url));
        }
        if let Some(index) = value.attribute_consuming_service_index {
            let text = index.to_string();
            root.push_attribute(("AttributeConsumingServiceIndex", text.as_str()));
        }
        if let Some(provider_name) = value.provider_name.as_deref() {
            root.push_attribute(("ProviderName", provider_name));
        }
        writer.write_event(Event::Start(root))?;

        // Child elements, in this fixed order.
        if let Some(issuer) = value.issuer.as_ref() {
            let issuer_xml: Event<'_> = issuer.try_into()?;
            writer.write_event(issuer_xml)?;
        }
        if let Some(signature) = value.signature.as_ref() {
            let signature_xml: Event<'_> = signature.try_into()?;
            writer.write_event(signature_xml)?;
        }
        if let Some(subject) = value.subject.as_ref() {
            let subject_xml: Event<'_> = subject.try_into()?;
            writer.write_event(subject_xml)?;
        }
        if let Some(policy) = value.name_id_policy.as_ref() {
            let policy_xml: Event<'_> = policy.try_into()?;
            writer.write_event(policy_xml)?;
        }
        if let Some(conditions) = value.conditions.as_ref() {
            let conditions_xml: Event<'_> = conditions.try_into()?;
            writer.write_event(conditions_xml)?;
        }
        if let Some(context) = value.requested_authn_context.as_ref() {
            let context_xml: Event<'_> = context.try_into()?;
            writer.write_event(context_xml)?;
        }

        writer.write_event(Event::End(BytesEnd::new(NAME)))?;

        let xml = String::from_utf8(xml_bytes)?;
        Ok(Event::Text(BytesText::from_escaped(xml)))
    }
}

// The whole module is gated on the `xmlsec` feature, so the individual tests
// need no feature attribute of their own.
#[cfg(all(test, feature = "xmlsec"))]
mod test {
    use super::*;
    use crate::crypto::UrlVerifier;

    /// DER-encoded private key fixture used for signing in every test.
    const PRIVATE_KEY_DER: &[u8] = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/test_vectors/private.der"
    ));

    /// `AuthnRequest` XML fixture parsed by every test.
    const AUTHN_REQUEST_SIGN_TEMPLATE: &str = include_str!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/test_vectors/authn_request_sign_template.xml"
    ));

    /// Loads `PRIVATE_KEY_DER` as an RSA key and wraps it in a `PKey`.
    fn signing_key(
    ) -> Result<openssl::pkey::PKey<openssl::pkey::Private>, openssl::error::ErrorStack> {
        let rsa = openssl::rsa::Rsa::private_key_from_der(PRIVATE_KEY_DER)?;
        openssl::pkey::PKey::from_rsa(rsa)
    }

    /// Signs the template as XML and checks the signature verifies against
    /// the matching public certificate.
    #[test]
    fn test_signed_authn() -> Result<(), Box<dyn std::error::Error>> {
        let public_cert = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/test_vectors/public.der"
        ))
        .to_vec()
        .into();

        let signed_authn_request = AUTHN_REQUEST_SIGN_TEMPLATE
            .parse::<AuthnRequest>()?
            .add_key_info(&public_cert)
            .to_signed_xml(PRIVATE_KEY_DER)?;

        assert!(Crypto::verify_signed_xml(signed_authn_request, &public_cert, Some("ID")).is_ok());

        Ok(())
    }

    /// Signed redirect URL with an empty relay state, verified with a PEM
    /// public key.
    #[test]
    fn test_redirect_signature() -> Result<(), Box<dyn std::error::Error>> {
        let public_key = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/test_vectors/public_key.pem"
        ));

        let private_key = signing_key()?;

        let signed_authn_redirect_url = AUTHN_REQUEST_SIGN_TEMPLATE
            .parse::<AuthnRequest>()?
            .signed_redirect("", &private_key)?
            .expect("signed_redirect should produce a URL");

        let url_verifier = UrlVerifier::from_rsa_pem(public_key)?;
        assert!(url_verifier.verify_signed_request_url(&signed_authn_redirect_url)?);

        Ok(())
    }

    /// Signed redirect URL with a relay state, verified with a DER public
    /// key.
    #[test]
    fn test_redirect_signature_with_relaystate() -> Result<(), Box<dyn std::error::Error>> {
        let public_key = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/test_vectors/public_key.der"
        ));

        let private_key = signing_key()?;

        let signed_authn_redirect_url = AUTHN_REQUEST_SIGN_TEMPLATE
            .parse::<AuthnRequest>()?
            .signed_redirect("some_relay_state_here", &private_key)?
            .expect("signed_redirect should produce a URL");

        let url_verifier = UrlVerifier::from_rsa_der(public_key)?;
        assert!(url_verifier.verify_signed_request_url(&signed_authn_redirect_url)?);

        Ok(())
    }

    /// Signed redirect URL with a relay state, verified with a PEM X.509
    /// certificate.
    #[test]
    fn test_redirect_signature_with_relaystate_using_x509_cert(
    ) -> Result<(), Box<dyn std::error::Error>> {
        let public_cert = include_str!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/test_vectors/idp_2_metadata_public.pem"
        ));

        let private_key = signing_key()?;

        let signed_authn_redirect_url = AUTHN_REQUEST_SIGN_TEMPLATE
            .parse::<AuthnRequest>()?
            .signed_redirect("some_relay_state_here", &private_key)?
            .expect("signed_redirect should produce a URL");

        let url_verifier = UrlVerifier::from_x509_cert_pem(public_cert)?;
        assert!(url_verifier.verify_signed_request_url(&signed_authn_redirect_url)?);

        Ok(())
    }
}