samael 0.0.21

A SAML2 library for Rust
use chrono::prelude::*;
use quick_xml::events::{BytesEnd, BytesStart, BytesText, Event};
use quick_xml::Writer;
use serde::Deserialize;
use std::io::Cursor;

/// The identifier element a SAML 2.0 subject can be serialized with.
///
/// Each variant maps to one prefixed element name (see `saml_element_name`);
/// only `NameId` carries text content.
pub enum SubjectType<'a> {
    /// Serialized as an empty `<saml2:BaseID>` element.
    BaseId,
    /// Serialized as `<saml2:NameID>` with the borrowed string as its text.
    NameId(&'a str),
    /// Serialized as an empty `<saml2:EncryptedID>` element.
    EncryptedId,
}

impl<'a> SubjectType<'a> {
    /// Returns the `saml2:`-prefixed element name for this variant.
    fn saml_element_name(&self) -> &'static str {
        match self {
            Self::NameId(_) => "saml2:NameID",
            Self::BaseId => "saml2:BaseID",
            Self::EncryptedId => "saml2:EncryptedID",
        }
    }
}

impl<'a> TryFrom<SubjectType<'a>> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    fn try_from(value: SubjectType) -> Result<Self, Self::Error> {
        (&value).try_into()
    }
}

impl<'a> TryFrom<&SubjectType<'a>> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    /// Serializes the identifier as a single prefixed element (`saml2:BaseID`,
    /// `saml2:NameID` or `saml2:EncryptedID`).
    ///
    /// The generated markup is returned wrapped in an already-escaped `Event::Text`,
    /// so an enclosing writer copies it through verbatim.
    ///
    /// Only the `NameId` variant has text content. `BaseId` and `EncryptedId` are
    /// written as an empty start/end pair. The NameID string is plain text, so it is
    /// written with `BytesText::new`, which escapes it. Writing it with
    /// `BytesText::from_escaped` would emit characters such as `&` or `<` raw,
    /// producing malformed XML or letting the string inject markup.
    ///
    /// # Errors
    /// Returns a boxed error if writing to the in-memory buffer fails or the
    /// buffer is not valid UTF-8.
    fn try_from(value: &SubjectType) -> Result<Self, Self::Error> {
        let mut write_buf = Vec::new();
        let mut writer = Writer::new(Cursor::new(&mut write_buf));
        let elem_name = value.saml_element_name();
        writer.write_event(Event::Start(BytesStart::new(elem_name)))?;
        if let SubjectType::NameId(content) = value {
            writer.write_event(Event::Text(BytesText::new(content)))?;
        }
        writer.write_event(Event::End(BytesEnd::new(elem_name)))?;
        // The whole fragment is already valid XML, so it is handed back as-is.
        Ok(Event::Text(BytesText::from_escaped(String::from_utf8(
            write_buf,
        )?)))
    }
}

/// Prefixed element name written by the `Subject` serializer.
const NAME: &str = "saml2:Subject";
/// Namespace declaration attached to the `Subject` root element, binding the
/// `saml2` prefix used by every element in this module.
const SCHEMA: (&str, &str) = ("xmlns:saml2", "urn:oasis:names:tc:SAML:2.0:assertion");

/// A SAML 2.0 subject: an optional `NameID` plus any `SubjectConfirmation` children.
///
/// Only the `NameID` identifier form is modelled by this struct.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize)]
pub struct Subject {
    /// The `NameID` child element, if any.
    #[serde(rename = "NameID")]
    pub name_id: Option<SubjectNameID>,
    /// The `SubjectConfirmation` child elements, if any.
    #[serde(rename = "SubjectConfirmation")]
    pub subject_confirmations: Option<Vec<SubjectConfirmation>>,
}

impl TryFrom<Subject> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    fn try_from(value: Subject) -> Result<Self, Self::Error> {
        (&value).try_into()
    }
}

impl TryFrom<&Subject> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    /// Serializes the subject as a `<saml2:Subject>` element.
    ///
    /// The root carries the `xmlns:saml2` declaration. It is followed by the
    /// optional `NameID` and then each `SubjectConfirmation` in order. The
    /// finished markup is returned wrapped in an already-escaped `Event::Text`.
    ///
    /// # Errors
    /// Propagates failures from serializing the children, writing to the
    /// in-memory buffer, or converting the buffer to UTF-8.
    fn try_from(value: &Subject) -> Result<Self, Self::Error> {
        let mut xml = Vec::new();
        let mut writer = Writer::new(Cursor::new(&mut xml));

        let mut start = BytesStart::new(NAME);
        start.push_attribute(SCHEMA);
        writer.write_event(Event::Start(start))?;

        if let Some(name_id) = value.name_id.as_ref() {
            let child = Event::try_from(name_id)?;
            writer.write_event(child)?;
        }
        // Flattening the Option<Vec<_>> makes an absent list behave like an empty one.
        for confirmation in value.subject_confirmations.iter().flatten() {
            let child = Event::try_from(confirmation)?;
            writer.write_event(child)?;
        }

        writer.write_event(Event::End(BytesEnd::new(NAME)))?;
        let markup = String::from_utf8(xml)?;
        Ok(Event::Text(BytesText::from_escaped(markup)))
    }
}

/// A `NameID` element: an optional `Format` attribute and its text content.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize)]
pub struct SubjectNameID {
    /// The `Format` attribute, if present.
    #[serde(rename = "@Format")]
    pub format: Option<String>,

    /// The element's text content.
    #[serde(rename = "$value")]
    pub value: String,
}

impl SubjectNameID {
    /// Prefixed element name used for both the start and end tag of a NameID.
    fn name() -> &'static str {
        const ELEMENT_NAME: &str = "saml2:NameID";
        ELEMENT_NAME
    }
}

impl TryFrom<SubjectNameID> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    fn try_from(value: SubjectNameID) -> Result<Self, Self::Error> {
        (&value).try_into()
    }
}

impl TryFrom<&SubjectNameID> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    /// Serializes the NameID as `<saml2:NameID Format="…">value</saml2:NameID>`.
    ///
    /// `Format` is emitted only when set. The markup is returned wrapped in an
    /// already-escaped `Event::Text`, so an enclosing writer copies it through
    /// verbatim.
    ///
    /// The text content is escaped with `BytesText::new`. When the struct comes
    /// from deserialization, `value` holds unescaped text. Writing it with
    /// `BytesText::from_escaped` would therefore emit `&`/`<` raw, producing
    /// malformed XML or allowing a crafted value to inject markup.
    ///
    /// # Errors
    /// Returns a boxed error if writing to the in-memory buffer fails or the
    /// buffer is not valid UTF-8.
    fn try_from(value: &SubjectNameID) -> Result<Self, Self::Error> {
        let mut write_buf = Vec::new();
        let mut writer = Writer::new(Cursor::new(&mut write_buf));
        let mut root = BytesStart::new(SubjectNameID::name());

        if let Some(format) = &value.format {
            // quick_xml escapes attribute values supplied as a (&str, &str) pair.
            root.push_attribute(("Format", format.as_str()));
        }

        writer.write_event(Event::Start(root))?;
        writer.write_event(Event::Text(BytesText::new(&value.value)))?;
        writer.write_event(Event::End(BytesEnd::new(SubjectNameID::name())))?;
        // The whole fragment is already valid XML, so it is handed back as-is.
        Ok(Event::Text(BytesText::from_escaped(String::from_utf8(
            write_buf,
        )?)))
    }
}

/// Prefixed element name written by the `SubjectConfirmation` serializer.
const SUBJECT_CONFIRMATION_NAME: &str = "saml2:SubjectConfirmation";

/// A `SubjectConfirmation` element.
///
/// It carries an optional `Method` attribute, an optional `NameID` and optional
/// `SubjectConfirmationData`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize)]
pub struct SubjectConfirmation {
    /// The `Method` attribute, if present.
    #[serde(rename = "@Method")]
    pub method: Option<String>,
    /// The `NameID` child element, if any.
    #[serde(rename = "NameID")]
    pub name_id: Option<SubjectNameID>,
    /// The `SubjectConfirmationData` child element, if any.
    #[serde(rename = "SubjectConfirmationData")]
    pub subject_confirmation_data: Option<SubjectConfirmationData>,
}

impl TryFrom<SubjectConfirmation> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    fn try_from(value: SubjectConfirmation) -> Result<Self, Self::Error> {
        (&value).try_into()
    }
}

impl TryFrom<&SubjectConfirmation> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    /// Serializes the confirmation as a `<saml2:SubjectConfirmation>` element.
    ///
    /// The element carries an optional `Method` attribute. Its children are the
    /// optional `NameID` followed by the optional `SubjectConfirmationData`. The
    /// markup is returned wrapped in an already-escaped `Event::Text`.
    ///
    /// # Errors
    /// Propagates failures from serializing the children, writing to the
    /// in-memory buffer, or converting the buffer to UTF-8.
    fn try_from(value: &SubjectConfirmation) -> Result<Self, Self::Error> {
        let mut xml = Vec::new();
        let mut writer = Writer::new(Cursor::new(&mut xml));

        let mut start = BytesStart::new(SUBJECT_CONFIRMATION_NAME);
        if let Some(method) = value.method.as_deref() {
            start.push_attribute(("Method", method));
        }
        writer.write_event(Event::Start(start))?;

        if let Some(name_id) = value.name_id.as_ref() {
            let child = Event::try_from(name_id)?;
            writer.write_event(child)?;
        }
        if let Some(data) = value.subject_confirmation_data.as_ref() {
            let child = Event::try_from(data)?;
            writer.write_event(child)?;
        }

        writer.write_event(Event::End(BytesEnd::new(SUBJECT_CONFIRMATION_NAME)))?;
        let markup = String::from_utf8(xml)?;
        Ok(Event::Text(BytesText::from_escaped(markup)))
    }
}

/// Prefixed element name written by the `SubjectConfirmationData` serializer.
const SUBJECT_CONFIRMATION_DATA_NAME: &str = "saml2:SubjectConfirmationData";

/// A `SubjectConfirmationData` element: optional validity window, recipient,
/// request correlation and address attributes, plus optional content.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize)]
pub struct SubjectConfirmationData {
    /// The `NotBefore` attribute, if present.
    #[serde(rename = "@NotBefore")]
    pub not_before: Option<chrono::DateTime<Utc>>,
    /// The `NotOnOrAfter` attribute, if present.
    #[serde(rename = "@NotOnOrAfter")]
    pub not_on_or_after: Option<chrono::DateTime<Utc>>,
    /// The `Recipient` attribute, if present.
    #[serde(rename = "@Recipient")]
    pub recipient: Option<String>,
    /// The `InResponseTo` attribute, if present.
    #[serde(rename = "@InResponseTo")]
    pub in_response_to: Option<String>,
    /// The `Address` attribute, if present.
    #[serde(rename = "@Address")]
    pub address: Option<String>,
    /// The element's content, if any.
    #[serde(rename = "$value")]
    pub content: Option<String>,
}

impl TryFrom<SubjectConfirmationData> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    fn try_from(value: SubjectConfirmationData) -> Result<Self, Self::Error> {
        (&value).try_into()
    }
}

impl TryFrom<&SubjectConfirmationData> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    /// Serializes the data as a `<saml2:SubjectConfirmationData>` element and
    /// returns the markup wrapped in an already-escaped `Event::Text`.
    ///
    /// Attributes are emitted only when present, in the order `NotBefore`,
    /// `NotOnOrAfter`, `Recipient`, `InResponseTo`, `Address`. Timestamps are
    /// rendered as RFC 3339 with millisecond precision and a `Z` suffix.
    ///
    /// `content` is written via `BytesText::from_escaped`, i.e. verbatim and
    /// without escaping. NOTE(review): presumably this lets callers embed raw
    /// child XML; plain text containing `&`/`<` would need escaping first. Confirm.
    ///
    /// # Errors
    /// Returns a boxed error if writing to the in-memory buffer fails or the
    /// buffer is not valid UTF-8.
    fn try_from(value: &SubjectConfirmationData) -> Result<Self, Self::Error> {
        let mut xml = Vec::new();
        let mut writer = Writer::new(Cursor::new(&mut xml));

        let not_before = value
            .not_before
            .as_ref()
            .map(|instant| instant.to_rfc3339_opts(SecondsFormat::Millis, true));
        let not_on_or_after = value
            .not_on_or_after
            .as_ref()
            .map(|instant| instant.to_rfc3339_opts(SecondsFormat::Millis, true));
        // Table order is the attribute emission order.
        let attributes = [
            ("NotBefore", not_before.as_deref()),
            ("NotOnOrAfter", not_on_or_after.as_deref()),
            ("Recipient", value.recipient.as_deref()),
            ("InResponseTo", value.in_response_to.as_deref()),
            ("Address", value.address.as_deref()),
        ];

        let mut start = BytesStart::new(SUBJECT_CONFIRMATION_DATA_NAME);
        for (key, maybe_attr) in attributes {
            if let Some(attr) = maybe_attr {
                start.push_attribute((key, attr));
            }
        }
        writer.write_event(Event::Start(start))?;

        if let Some(body) = value.content.as_deref() {
            writer.write_event(Event::Text(BytesText::from_escaped(body)))?;
        }

        writer.write_event(Event::End(BytesEnd::new(SUBJECT_CONFIRMATION_DATA_NAME)))?;
        let markup = String::from_utf8(xml)?;
        Ok(Event::Text(BytesText::from_escaped(markup)))
    }
}