samael 0.0.21

A SAML2 library for Rust
use quick_xml::events::{BytesEnd, BytesStart, BytesText, Event};
use quick_xml::Writer;
use serde::Deserialize;
use std::io::Cursor;

use crate::crypto::{Crypto, CryptoProvider};
use crate::key_info::{EncryptedKeyInfo, KeyInfo};
use crate::schema::Assertion;
use crate::service_provider::Error;
use crate::signature::DigestMethod;

const NAME: &str = "saml2:EncryptedAssertion";
const SCHEMA: (&str, &str) = ("xmlns:saml2", "urn:oasis:names:tc:SAML:2.0:assertion");

#[derive(Clone, Debug, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct EncryptedAssertion {
    #[serde(rename = "EncryptedData")]
    pub data: Option<EncryptedData>,
    #[serde(rename = "EncryptedKey")]
    pub encrypted_key: Option<EncryptedKey>,
}

impl EncryptedAssertion {
    pub fn encrypted_key_info(&self) -> Option<(&CipherValue, &String)> {
        self.data.as_ref().and_then(|ed| ed.key_info()).or_else(|| {
            self.encrypted_key
                .as_ref()
                .and_then(|e| e.cipher_data.as_ref().zip(e.encryption_method.as_ref()))
                .and_then(|(cd, em)| cd.cipher_value.as_ref().zip(em.algorithm.as_ref()))
        })
    }

    pub fn encrypted_value_info(&self) -> Option<(&CipherValue, &String)> {
        self.data.as_ref().and_then(|ed| ed.value_info())
    }

    pub fn decrypt(
        &self,
        decryption_key: &<Crypto as CryptoProvider>::PrivateKey,
    ) -> Result<Assertion, Error> {
        let (ekey, method) = self
            .encrypted_key_info()
            .ok_or(Error::MissingEncryptedKeyInfo)?;
        let decrypted_key = Crypto::decrypt_assertion_key_info(ekey, method, decryption_key)?;

        let (evalue, method) = self
            .encrypted_value_info()
            .ok_or(Error::MissingEncryptedValueInfo)?;
        let plaintext = Crypto::decrypt_assertion_value_info(evalue, method, &decrypted_key)?;

        let assertion_string = match String::from_utf8(plaintext) {
            Ok(s) => s,
            Err(e) => {
                let i = e.utf8_error().valid_up_to();
                let mut plaintext = e.into_bytes();
                plaintext.truncate(i);
                let s =
                    String::from_utf8(plaintext).map_err(|_| Error::EncryptedAssertionInvalid)?;
                let fi = s.find("<").unwrap();
                let li = s.rfind(">").unwrap();
                s[fi..li + 1].to_owned()
            }
        };
        quick_xml::de::from_str(&assertion_string).map_err(|_e| Error::FailedToDecryptAssertion)
    }
}

impl TryFrom<EncryptedAssertion> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    fn try_from(value: EncryptedAssertion) -> Result<Self, Self::Error> {
        (&value).try_into()
    }
}

impl TryFrom<&EncryptedAssertion> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    fn try_from(value: &EncryptedAssertion) -> Result<Self, Self::Error> {
        let mut write_buf = Vec::new();
        let mut writer = Writer::new(Cursor::new(&mut write_buf));
        let mut root = BytesStart::from_content(NAME, NAME.len());
        root.push_attribute(SCHEMA);

        writer.write_event(Event::Start(root))?;

        if let Some(encrypted_data) = &value.data {
            let event: Event<'_> = encrypted_data.try_into()?;
            writer.write_event(event)?;
        }

        writer.write_event(Event::End(BytesEnd::new(NAME)))?;
        Ok(Event::Text(BytesText::from_escaped(String::from_utf8(
            write_buf,
        )?)))
    }
}

const ED_NAME: &str = "xenc:EncryptedData";
const ED_SCHEMA: (&str, &str) = ("xmlns:xenc", "http://www.w3.org/2001/04/xmlenc#");

#[derive(Clone, Debug, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct EncryptedData {
    #[serde(rename = "@Id")]
    pub id: Option<String>,
    #[serde(rename = "@Type")]
    pub ty: Option<String>,
    #[serde(rename = "EncryptionMethod")]
    pub encryption_method: Option<EncryptionMethod>,
    #[serde(alias = "KeyInfo", alias = "ds:KeyInfo")]
    pub key_info: Option<EncryptedKeyInfo>,
    #[serde(rename = "CipherData")]
    pub cipher_data: Option<CipherData>,
}

impl EncryptedData {
    pub fn key_info(&self) -> Option<(&CipherValue, &String)> {
        self.key_info
            .as_ref()
            .and_then(|k| k.encrypted_key.as_ref())
            .and_then(|e| e.cipher_data.as_ref().zip(e.encryption_method.as_ref()))
            .and_then(|(cd, em)| cd.cipher_value.as_ref().zip(em.algorithm.as_ref()))
    }

    pub fn value_info(&self) -> Option<(&CipherValue, &String)> {
        self.cipher_data
            .as_ref()
            .zip(self.encryption_method.as_ref())
            .and_then(|(cd, em)| cd.cipher_value.as_ref().zip(em.algorithm.as_ref()))
    }
}

impl TryFrom<EncryptedData> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    fn try_from(value: EncryptedData) -> Result<Self, Self::Error> {
        (&value).try_into()
    }
}

impl TryFrom<&EncryptedData> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    fn try_from(value: &EncryptedData) -> Result<Self, Self::Error> {
        let mut write_buf = Vec::new();
        let mut writer = Writer::new(Cursor::new(&mut write_buf));
        let mut root = BytesStart::from_content(ED_NAME, ED_NAME.len());
        root.push_attribute(ED_SCHEMA);
        if let Some(id) = &value.id {
            root.push_attribute(("Id", id.as_ref()));
        }
        if let Some(ty) = &value.ty {
            root.push_attribute(("Type", ty.as_ref()));
        }

        writer.write_event(Event::Start(root))?;

        if let Some(encryption_method) = &value.encryption_method {
            let event: Event<'_> = encryption_method.try_into()?;
            writer.write_event(event)?;
        }
        if let Some(key_info) = &value.key_info {
            let event: Event<'_> = key_info.try_into()?;
            writer.write_event(event)?;
        }
        if let Some(cipher_data) = &value.cipher_data {
            let event: Event<'_> = cipher_data.try_into()?;
            writer.write_event(event)?;
        }

        writer.write_event(Event::End(BytesEnd::new(ED_NAME)))?;
        Ok(Event::Text(BytesText::from_escaped(String::from_utf8(
            write_buf,
        )?)))
    }
}

const EM_NAME: &str = "xenc:EncryptionMethod";
const EM_SCHEMA: (&str, &str) = ("xmlns:xenc", "http://www.w3.org/2001/04/xmlenc#");

#[derive(Clone, Debug, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct EncryptionMethod {
    #[serde(rename = "@Algorithm")]
    pub algorithm: Option<String>,
    #[serde(rename = "DigestMethod")]
    pub digest_method: Option<DigestMethod>,
}

impl TryFrom<EncryptionMethod> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    fn try_from(value: EncryptionMethod) -> Result<Self, Self::Error> {
        (&value).try_into()
    }
}

impl TryFrom<&EncryptionMethod> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    fn try_from(value: &EncryptionMethod) -> Result<Self, Self::Error> {
        let mut write_buf = Vec::new();
        let mut writer = Writer::new(Cursor::new(&mut write_buf));
        let mut root = BytesStart::from_content(EM_NAME, EM_NAME.len());
        root.push_attribute(EM_SCHEMA);
        if let Some(algorithm) = &value.algorithm {
            root.push_attribute(("Algorithm", algorithm.as_ref()));
        }

        writer.write_event(Event::Start(root))?;

        if let Some(digest_method) = &value.digest_method {
            let event: Event<'_> = digest_method.try_into()?;
            writer.write_event(event)?;
        }

        writer.write_event(Event::End(BytesEnd::new(EM_NAME)))?;
        Ok(Event::Text(BytesText::from_escaped(String::from_utf8(
            write_buf,
        )?)))
    }
}

const CD_NAME: &str = "xenc:CipherData";
const CD_SCHEMA: (&str, &str) = ("xmlns:xenc", "http://www.w3.org/2001/04/xmlenc#");

#[derive(Clone, Debug, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct CipherData {
    #[serde(rename = "CipherValue")]
    pub cipher_value: Option<CipherValue>,
}

impl TryFrom<CipherData> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    fn try_from(value: CipherData) -> Result<Self, Self::Error> {
        (&value).try_into()
    }
}

impl TryFrom<&CipherData> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    fn try_from(value: &CipherData) -> Result<Self, Self::Error> {
        let mut write_buf = Vec::new();
        let mut writer = Writer::new(Cursor::new(&mut write_buf));
        let mut root = BytesStart::from_content(CD_NAME, CD_NAME.len());
        root.push_attribute(CD_SCHEMA);

        writer.write_event(Event::Start(root))?;

        if let Some(cipher_value) = &value.cipher_value {
            let event: Event<'_> = cipher_value.try_into()?;
            writer.write_event(event)?;
        }

        writer.write_event(Event::End(BytesEnd::new(CD_NAME)))?;
        Ok(Event::Text(BytesText::from_escaped(String::from_utf8(
            write_buf,
        )?)))
    }
}

const CV_NAME: &str = "xenc:CipherValue";

#[derive(Clone, Debug, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct CipherValue {
    #[serde(rename = "$value")]
    pub value: String,
}

impl TryFrom<CipherValue> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    fn try_from(value: CipherValue) -> Result<Self, Self::Error> {
        (&value).try_into()
    }
}

impl TryFrom<&CipherValue> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    fn try_from(value: &CipherValue) -> Result<Self, Self::Error> {
        let mut write_buf = Vec::new();
        let mut writer = Writer::new(Cursor::new(&mut write_buf));
        let root = BytesStart::from_content(CV_NAME, CV_NAME.len());

        writer.write_event(Event::Start(root))?;

        writer.write_event(Event::Text(BytesText::new(&value.value)))?;

        writer.write_event(Event::End(BytesEnd::new(CV_NAME)))?;
        Ok(Event::Text(BytesText::from_escaped(String::from_utf8(
            write_buf,
        )?)))
    }
}

const EK_NAME: &str = "xenc:EncryptedKey";
const EK_SCHEMA: (&str, &str) = ("xmlns:xenc", "http://www.w3.org/2001/04/xmlenc#");

#[derive(Clone, Debug, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct EncryptedKey {
    #[serde(rename = "@Id")]
    pub id: Option<String>,
    #[serde(rename = "@Recipient")]
    pub recipient: Option<String>,
    #[serde(rename = "EncryptionMethod")]
    pub encryption_method: Option<EncryptionMethod>,
    #[serde(rename = "KeyInfo")]
    pub key_info: Option<KeyInfo>,
    #[serde(rename = "CipherData")]
    pub cipher_data: Option<CipherData>,
}

impl TryFrom<EncryptedKey> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    fn try_from(value: EncryptedKey) -> Result<Self, Self::Error> {
        (&value).try_into()
    }
}

impl TryFrom<&EncryptedKey> for Event<'_> {
    type Error = Box<dyn std::error::Error>;

    fn try_from(value: &EncryptedKey) -> Result<Self, Self::Error> {
        let mut write_buf = Vec::new();
        let mut writer = Writer::new(Cursor::new(&mut write_buf));
        let mut root = BytesStart::from_content(EK_NAME, EK_NAME.len());
        root.push_attribute(EK_SCHEMA);
        if let Some(id) = &value.id {
            root.push_attribute(("Id", id.as_ref()));
        }
        if let Some(recipient) = &value.recipient {
            root.push_attribute(("Recipient", recipient.as_ref()));
        }

        writer.write_event(Event::Start(root))?;

        if let Some(encryption_method) = &value.encryption_method {
            let event: Event<'_> = encryption_method.try_into()?;
            writer.write_event(event)?;
        }

        if let Some(key_info) = &value.key_info {
            let event: Event<'_> = key_info.try_into()?;
            writer.write_event(event)?;
        }

        if let Some(cipher_data) = &value.cipher_data {
            let event: Event<'_> = cipher_data.try_into()?;
            writer.write_event(event)?;
        }

        writer.write_event(Event::End(BytesEnd::new(EK_NAME)))?;
        Ok(Event::Text(BytesText::from_escaped(String::from_utf8(
            write_buf,
        )?)))
    }
}

#[cfg(test)]
mod test {
    use crate::schema::Response;

    #[test]
    fn test_encrypted_assertion_key_info() {
        let response_xml = include_str!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/test_vectors/response_encrypted.xml",
        ));
        let response: Response = response_xml
            .parse()
            .expect("failed to parse response_encrypted.xml");

        let encrypted_assertion = response
            .encrypted_assertion
            .expect("EncryptedAssertion missing");
        let key_info = encrypted_assertion.encrypted_key_info();

        let key_info_exists = key_info.is_some();

        assert!(key_info_exists, "KeyInfo missing on EncryptedAssertion");
    }
}