#![allow(clippy::needless_borrows_for_generic_args)]
#![allow(clippy::needless_borrow)]
use crate::errors::{AuthError, Result};
use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
use quick_xml::{Reader, Writer, events::Event};
use ring::signature;
use std::collections::BTreeMap;
use std::io::Cursor;
use x509_parser::{parse_x509_certificate, public_key::PublicKey};
/// Stateless helper that rewrites an XML document into a normalised textual
/// form (see `canonicalize_xml`). It carries no fields, so every instance
/// behaves identically.
pub struct XmlCanonicalizer;
/// `Default` yields the same unit value that `XmlCanonicalizer::new` returns.
impl Default for XmlCanonicalizer {
    /// Builds the (field-less) canonicalizer.
    fn default() -> Self {
        XmlCanonicalizer
    }
}
impl XmlCanonicalizer {
/// Creates a canonicalizer. The type has no state, so this is just the unit value.
pub fn new() -> Self {
    XmlCanonicalizer
}
/// Rewrites `xml` into the normalised form that is fed to signature verification.
///
/// What the output looks like:
/// * every element's attributes are re-serialised by `canonicalize_element`;
/// * self-closing elements are expanded to a start/end pair (`<a></a>`), as
///   canonical XML requires;
/// * character data is escaped for `&`, `<` and `>` only, so literal quotes in
///   text stay literal;
/// * CDATA sections are replaced by their escaped character content;
/// * entity / character references are written back unchanged instead of
///   being dropped;
/// * comments, processing instructions, the XML declaration and the DOCTYPE
///   are omitted.
///
/// Known limitations (this is not a complete C14N implementation): text is
/// trimmed and whitespace-only text nodes are discarded, references are not
/// resolved to their replacement text, and namespace declarations inherited
/// from outside `xml` are not reconstructed.
///
/// # Errors
/// Returns a validation error when the input cannot be parsed, contains
/// invalid UTF-8, or the output cannot be written.
pub fn canonicalize_xml(&self, xml: &str) -> Result<String> {
    let mut reader = Reader::from_str(xml);
    reader.config_mut().trim_text(true);
    let mut canonical = Vec::new();
    let mut writer = Writer::new(Cursor::new(&mut canonical));
    // One prefix -> URI map per open element; the bottom entry is the empty
    // context that applies outside the root element.
    let mut namespace_stack: Vec<BTreeMap<String, String>> = vec![BTreeMap::new()];
    loop {
        let event = reader
            .read_event()
            .map_err(|e| AuthError::validation(&format!("XML parsing error: {}", e)))?;
        match event {
            Event::Start(e) => {
                // Children inherit the parent's bindings; never panic if the
                // stack is unexpectedly empty.
                let mut ns_ctx = namespace_stack.last().cloned().unwrap_or_default();
                for attr in e.attributes() {
                    let attr = attr.map_err(|err| {
                        AuthError::validation(&format!("XML attribute error: {}", err))
                    })?;
                    let key = std::str::from_utf8(attr.key.as_ref()).map_err(|err| {
                        AuthError::validation(&format!(
                            "Invalid UTF-8 in attribute key: {}",
                            err
                        ))
                    })?;
                    let value = std::str::from_utf8(&attr.value).map_err(|err| {
                        AuthError::validation(&format!(
                            "Invalid UTF-8 in attribute value: {}",
                            err
                        ))
                    })?;
                    // `xmlns` binds the default namespace (empty prefix),
                    // `xmlns:p` binds prefix `p`.
                    let declared_prefix = if key == "xmlns" {
                        Some("")
                    } else {
                        key.strip_prefix("xmlns:")
                    };
                    if let Some(prefix) = declared_prefix {
                        ns_ctx.insert(prefix.to_string(), value.to_string());
                    }
                }
                namespace_stack.push(ns_ctx);
                let element = self.canonicalize_element(&e, &namespace_stack)?;
                writer
                    .write_event(Event::Start(element))
                    .map_err(|err| AuthError::validation(&format!("XML write error: {}", err)))?;
            }
            Event::End(e) => {
                namespace_stack.pop();
                writer
                    .write_event(Event::End(e))
                    .map_err(|err| AuthError::validation(&format!("XML write error: {}", err)))?;
            }
            Event::Text(e) => {
                let text = e.xml_content().map_err(|err| {
                    AuthError::validation(&format!("XML text decode error: {}", err))
                })?;
                if !text.trim().is_empty() {
                    // Canonical text nodes escape only `&`, `<` and `>`; the
                    // full escaper would also rewrite quotes.
                    let escaped = quick_xml::escape::partial_escape(&*text);
                    writer
                        .write_event(Event::Text(quick_xml::events::BytesText::from_escaped(
                            escaped,
                        )))
                        .map_err(|err| {
                            AuthError::validation(&format!("XML write error: {}", err))
                        })?;
                }
            }
            Event::CData(e) => {
                // A CDATA section is plain character data: keep its content
                // (escaped) rather than silently dropping it.
                let raw: &[u8] = &e;
                let text = std::str::from_utf8(raw).map_err(|err| {
                    AuthError::validation(&format!("Invalid UTF-8 in CDATA section: {}", err))
                })?;
                let escaped = quick_xml::escape::partial_escape(text);
                writer
                    .write_event(Event::Text(quick_xml::events::BytesText::from_escaped(
                        escaped,
                    )))
                    .map_err(|err| AuthError::validation(&format!("XML write error: {}", err)))?;
            }
            Event::GeneralRef(e) => {
                // References such as `&amp;` are part of the character
                // content; write them back verbatim so nothing is lost.
                writer
                    .write_event(Event::GeneralRef(e))
                    .map_err(|err| AuthError::validation(&format!("XML write error: {}", err)))?;
            }
            Event::Empty(e) => {
                // Canonical XML has no self-closing form: emit `<a ...></a>`.
                let element = self.canonicalize_element(&e, &namespace_stack)?;
                let closing = element.to_end().into_owned();
                writer
                    .write_event(Event::Start(element))
                    .map_err(|err| AuthError::validation(&format!("XML write error: {}", err)))?;
                writer
                    .write_event(Event::End(closing))
                    .map_err(|err| AuthError::validation(&format!("XML write error: {}", err)))?;
            }
            Event::Eof => break,
            Event::Comment(_) | Event::PI(_) | Event::Decl(_) | Event::DocType(_) => continue,
        }
    }
    String::from_utf8(canonical).map_err(|e| {
        AuthError::validation(&format!("Invalid UTF-8 in canonicalized XML: {}", e))
    })
}
/// Rebuilds a start tag with its attributes in a deterministic order.
///
/// Namespace declarations (`xmlns`, `xmlns:prefix`) are emitted first, then
/// the remaining attributes; each group is sorted by its qualified name.
/// Attribute values are copied byte-for-byte from the parsed element: they
/// are still in their escaped form, so escaping them again would turn
/// `&amp;` into `&amp;amp;`.
///
/// `_namespace_stack` is accepted for the caller's convenience but is not
/// consulted; ordinary attributes are therefore ordered by qualified name
/// rather than by namespace URI.
///
/// # Errors
/// Returns a validation error for malformed attributes or invalid UTF-8 in
/// the element name, an attribute key or an attribute value.
fn canonicalize_element(
    &self,
    element: &quick_xml::events::BytesStart,
    _namespace_stack: &[BTreeMap<String, String>],
) -> Result<quick_xml::events::BytesStart<'static>> {
    // BTreeMap iteration yields keys in ascending order, which gives the
    // deterministic ordering within each group ("xmlns" sorts before "xmlns:*").
    let mut namespace_decls: BTreeMap<String, String> = BTreeMap::new();
    let mut plain_attrs: BTreeMap<String, String> = BTreeMap::new();
    for attr in element.attributes() {
        let attr =
            attr.map_err(|e| AuthError::validation(&format!("XML attribute error: {}", e)))?;
        let key = std::str::from_utf8(attr.key.as_ref()).map_err(|e| {
            AuthError::validation(&format!("Invalid UTF-8 in attribute key: {}", e))
        })?;
        let value = std::str::from_utf8(&attr.value).map_err(|e| {
            AuthError::validation(&format!("Invalid UTF-8 in attribute value: {}", e))
        })?;
        let target = if key == "xmlns" || key.starts_with("xmlns:") {
            &mut namespace_decls
        } else {
            &mut plain_attrs
        };
        target.insert(key.to_string(), value.to_string());
    }

    let element_name = std::str::from_utf8(element.name().as_ref())
        .map_err(|e| AuthError::validation(&format!("Invalid UTF-8 in element name: {}", e)))?
        .to_string();
    let element_name_len = element_name.len();
    let mut rebuilt =
        quick_xml::events::BytesStart::from_content(element_name, element_name_len);
    for (key, value) in namespace_decls.iter().chain(plain_attrs.iter()) {
        // The byte-slice form stores the value untouched; the `&str` form
        // would escape the already-escaped value a second time.
        rebuilt.push_attribute((key.as_bytes(), value.as_bytes()));
    }
    Ok(rebuilt)
}
}
/// Stateless helper that extracts the `SignedInfo`, `SignatureValue` and
/// `X509Certificate` parts of an XML signature and verifies the signature
/// value against a certificate's public key.
pub struct SamlSignatureValidator;
impl SamlSignatureValidator {
/// Returns the part of a qualified XML name after its last `:`; a name
/// without a colon is returned unchanged.
fn local_name<'a>(&self, name: &'a [u8]) -> &'a [u8] {
    match name.iter().rposition(|&byte| byte == b':') {
        Some(colon) => &name[colon + 1..],
        None => name,
    }
}
/// Verifies the `SignatureValue` found in `xml` over the canonicalised
/// `SignedInfo` element, using the public key of the DER certificate
/// `cert_der`.
///
/// The verification algorithm is chosen from the certificate's key type:
/// RSA keys use PKCS#1 v1.5 with SHA-256, EC keys use ECDSA P-256 with
/// SHA-256. For ECDSA the XML-DSig wire format is the fixed-width `r || s`
/// concatenation, which is tried first; the ASN.1/DER encoding is still
/// accepted so input that verified before continues to verify.
///
/// Returns `Ok(true)` when the signature verifies and `Ok(false)` when it
/// does not.
///
/// NOTE(review): this function only checks the signature over `SignedInfo`.
/// It does not compare any `Reference` digest with the referenced content and
/// it does not read the document's `SignatureMethod`; confirm that callers
/// perform those checks before trusting the signed data.
///
/// # Errors
/// Returns a validation error when the certificate cannot be parsed, the
/// signature parts are missing or malformed, the key does not match its
/// declared type, or the key algorithm is not RSA/EC.
pub fn validate_xml_signature(&self, xml: &str, cert_der: &[u8]) -> Result<bool> {
    // Public-key algorithm identifiers from the certificate's SubjectPublicKeyInfo.
    const OID_RSA_ENCRYPTION: &str = "1.2.840.113549.1.1.1";
    const OID_EC_PUBLIC_KEY: &str = "1.2.840.10045.2.1";

    let (_, cert) = parse_x509_certificate(cert_der)
        .map_err(|e| AuthError::validation(&format!("Certificate parsing error: {}", e)))?;
    let public_key_info = cert.public_key();

    let signed_info = self.extract_signed_info(xml)?;
    let canonical_signed_info = XmlCanonicalizer::new().canonicalize_xml(&signed_info)?;
    let signature_value = self.extract_signature_value(xml)?;
    let signature_bytes = BASE64
        .decode(&signature_value)
        .map_err(|e| AuthError::validation(&format!("Invalid base64 signature: {}", e)))?;

    // Render the OID once instead of once per match arm.
    let algorithm_oid = public_key_info.algorithm.algorithm.to_string();
    match algorithm_oid.as_str() {
        OID_RSA_ENCRYPTION => {
            let public_key_bytes = match &public_key_info.parsed() {
                Ok(PublicKey::RSA(rsa_key)) => self.construct_rsa_public_key(rsa_key)?,
                _ => {
                    return Err(AuthError::validation("Invalid RSA public key"));
                }
            };
            let public_key = signature::UnparsedPublicKey::new(
                &signature::RSA_PKCS1_2048_8192_SHA256,
                &public_key_bytes,
            );
            Ok(public_key
                .verify(canonical_signed_info.as_bytes(), &signature_bytes)
                .is_ok())
        }
        OID_EC_PUBLIC_KEY => {
            let public_key_bytes = match &public_key_info.parsed() {
                Ok(PublicKey::EC(ec_key)) => self.construct_ecdsa_public_key(ec_key)?,
                _ => {
                    return Err(AuthError::validation("Invalid ECDSA public key"));
                }
            };
            let verifies_with = |algorithm: &'static dyn signature::VerificationAlgorithm| {
                signature::UnparsedPublicKey::new(algorithm, &public_key_bytes)
                    .verify(canonical_signed_info.as_bytes(), &signature_bytes)
                    .is_ok()
            };
            // XML-DSig carries ECDSA signatures as raw `r || s`; DER is kept
            // as a fallback for backward compatibility.
            Ok(verifies_with(&signature::ECDSA_P256_SHA256_FIXED)
                || verifies_with(&signature::ECDSA_P256_SHA256_ASN1))
        }
        _ => Err(AuthError::validation(&format!(
            "Unsupported signature algorithm: {}",
            algorithm_oid
        ))),
    }
}
/// Returns the first `SignedInfo` element of `xml` (matched by local name,
/// any prefix) re-serialised as a standalone XML fragment.
///
/// Element names, attributes, text, CDATA sections and entity references
/// inside the element are reproduced; comments and processing instructions
/// are not. Attribute values are always written in double quotes, so a
/// literal `"` coming from a single-quoted value is written as `&quot;`.
/// Namespace declarations made on ancestors of `SignedInfo` are not copied
/// into the fragment.
///
/// # Errors
/// Returns a validation error when the document cannot be parsed, contains
/// invalid UTF-8, or has no `SignedInfo` start tag.
fn extract_signed_info(&self, xml: &str) -> Result<String> {
    /// Appends `<name key="value" ...` followed by `terminator` to `out`.
    fn push_open_tag(
        out: &mut String,
        element: &quick_xml::events::BytesStart,
        terminator: &str,
    ) -> Result<()> {
        let qname = element.name();
        let name = std::str::from_utf8(qname.as_ref()).map_err(|e| {
            AuthError::validation(&format!("Invalid UTF-8 in element name: {}", e))
        })?;
        out.push('<');
        out.push_str(name);
        for attr in element.attributes() {
            let attr = attr
                .map_err(|e| AuthError::validation(&format!("XML attribute error: {}", e)))?;
            let key = std::str::from_utf8(attr.key.as_ref()).map_err(|e| {
                AuthError::validation(&format!("Invalid UTF-8 in attribute key: {}", e))
            })?;
            let value = std::str::from_utf8(&attr.value).map_err(|e| {
                AuthError::validation(&format!("Invalid UTF-8 in attribute value: {}", e))
            })?;
            out.push(' ');
            out.push_str(key);
            out.push_str("=\"");
            // The raw value is already escaped except for a `"` that was
            // legal inside single quotes in the source document.
            out.push_str(&value.replace('"', "&quot;"));
            out.push('"');
        }
        out.push_str(terminator);
        Ok(())
    }

    let mut reader = Reader::from_str(xml);
    let mut signed_info = String::new();
    // Number of currently open elements inside (and including) SignedInfo;
    // zero means the element has not been entered yet.
    let mut depth: usize = 0;
    loop {
        match reader.read_event() {
            // Either the SignedInfo start tag itself or any element nested in
            // it. A nested element that is also called SignedInfo just
            // deepens the nesting instead of restarting the count.
            Ok(Event::Start(ref e))
                if depth > 0 || self.local_name(e.name().as_ref()) == b"SignedInfo" =>
            {
                depth += 1;
                push_open_tag(&mut signed_info, e, ">")?;
            }
            Ok(Event::End(ref e)) if depth > 0 => {
                depth -= 1;
                let qname = e.name();
                let name = std::str::from_utf8(qname.as_ref()).map_err(|e| {
                    AuthError::validation(&format!("Invalid UTF-8 in element name: {}", e))
                })?;
                signed_info.push_str("</");
                signed_info.push_str(name);
                signed_info.push('>');
                if depth == 0 {
                    break;
                }
            }
            Ok(Event::Text(ref e)) if depth > 0 => {
                let text = e.xml_content().map_err(|e| {
                    AuthError::validation(&format!("XML text decode error: {}", e))
                })?;
                signed_info.push_str(&text);
            }
            Ok(Event::GeneralRef(ref e)) if depth > 0 => {
                // Keep `&name;` / `&#NN;` references instead of losing them.
                let raw: &[u8] = e;
                let name = std::str::from_utf8(raw).map_err(|e| {
                    AuthError::validation(&format!("Invalid UTF-8 in entity reference: {}", e))
                })?;
                signed_info.push('&');
                signed_info.push_str(name);
                signed_info.push(';');
            }
            Ok(Event::CData(ref e)) if depth > 0 => {
                let raw: &[u8] = e;
                let content = std::str::from_utf8(raw).map_err(|e| {
                    AuthError::validation(&format!("Invalid UTF-8 in CDATA section: {}", e))
                })?;
                signed_info.push_str("<![CDATA[");
                signed_info.push_str(content);
                signed_info.push_str("]]>");
            }
            Ok(Event::Empty(ref e)) if depth > 0 => {
                push_open_tag(&mut signed_info, e, " />")?;
            }
            Ok(Event::Eof) => break,
            Err(e) => return Err(AuthError::validation(&format!("XML parsing error: {}", e))),
            _ => continue,
        }
    }
    if signed_info.is_empty() {
        return Err(AuthError::validation("SignedInfo element not found"));
    }
    Ok(signed_info)
}
/// Returns the text of the first `SignatureValue` element (matched by local
/// name) with all whitespace removed, ready for base64 decoding.
///
/// # Errors
/// Returns a validation error when the document cannot be parsed, or when no
/// `SignatureValue` text is present. An element that contains only whitespace
/// is treated as missing, matching how `extract_embedded_certificate`
/// handles a blank certificate.
fn extract_signature_value(&self, xml: &str) -> Result<String> {
    let mut reader = Reader::from_str(xml);
    let mut inside_signature_value = false;
    let mut signature_value = String::new();
    loop {
        match reader.read_event() {
            Ok(Event::Start(ref e))
                if self.local_name(e.name().as_ref()) == b"SignatureValue" =>
            {
                inside_signature_value = true;
            }
            Ok(Event::Text(ref e)) if inside_signature_value => {
                let text = e.xml_content().map_err(|e| {
                    AuthError::validation(&format!("XML text decode error: {}", e))
                })?;
                signature_value.push_str(&text);
            }
            Ok(Event::End(ref e))
                if self.local_name(e.name().as_ref()) == b"SignatureValue" =>
            {
                break;
            }
            Ok(Event::Eof) => break,
            Err(e) => return Err(AuthError::validation(&format!("XML parsing error: {}", e))),
            _ => continue,
        }
    }
    // Strip line breaks/indentation first so that a whitespace-only element
    // is reported as missing rather than returned as an empty signature.
    let compact: String = signature_value
        .chars()
        .filter(|c| !c.is_whitespace())
        .collect();
    if compact.is_empty() {
        return Err(AuthError::validation("SignatureValue element not found"));
    }
    Ok(compact)
}
/// Decodes the first `X509Certificate` element (matched by local name) of
/// `xml` from base64 and returns the resulting bytes.
///
/// # Errors
/// Returns a validation error when the document cannot be parsed, when no
/// non-blank certificate text is found, or when the text is not valid base64.
pub fn extract_embedded_certificate(&self, xml: &str) -> Result<Vec<u8>> {
    const CERT_TAG: &[u8] = b"X509Certificate";

    let mut reader = Reader::from_str(xml);
    let mut collecting = false;
    let mut encoded = String::new();
    loop {
        let event = reader
            .read_event()
            .map_err(|e| AuthError::validation(&format!("XML parsing error: {}", e)))?;
        match event {
            Event::Start(ref start) if self.local_name(start.name().as_ref()) == CERT_TAG => {
                collecting = true;
            }
            Event::Text(ref chunk) if collecting => {
                let text = chunk.xml_content().map_err(|e| {
                    AuthError::validation(&format!("XML text decode error: {}", e))
                })?;
                encoded.push_str(&text);
            }
            // Stop at the closing tag, or at end of input if it never comes.
            Event::End(ref end) if self.local_name(end.name().as_ref()) == CERT_TAG => break,
            Event::Eof => break,
            _ => {}
        }
    }

    if encoded.trim().is_empty() {
        return Err(AuthError::validation(
            "No embedded X509Certificate element found in SAML assertion",
        ));
    }

    // Base64 in XML is usually wrapped and indented; drop all whitespace first.
    let compact: String = encoded.chars().filter(|c| !c.is_whitespace()).collect();
    BASE64
        .decode(compact)
        .map_err(|e| AuthError::validation(&format!("Invalid embedded certificate: {}", e)))
}
/// Serialises an RSA key as a DER `SEQUENCE { INTEGER modulus, INTEGER exponent }`.
///
/// Both integers are written minimally: leading zero bytes are removed and a
/// single `0x00` is prepended when the top bit is set. Lengths are encoded in
/// short form below 0x80, otherwise with one (`0x81`) or two (`0x82`) length
/// octets.
fn construct_rsa_public_key(
    &self,
    rsa_key: &x509_parser::public_key::RSAPublicKey,
) -> Result<Vec<u8>> {
    const ZERO: &[u8] = &[0];

    /// DER length octets for a content length of `len`.
    fn length_octets(len: usize) -> Vec<u8> {
        match len {
            0..=0x7F => vec![len as u8],
            0x80..=0xFF => vec![0x81, len as u8],
            _ => vec![0x82, (len >> 8) as u8, len as u8],
        }
    }

    /// Tag byte, length octets, then the concatenation of `parts`.
    fn tagged(tag: u8, parts: &[&[u8]]) -> Vec<u8> {
        let content_len: usize = parts.iter().map(|part| part.len()).sum();
        let mut encoded = vec![tag];
        encoded.extend(length_octets(content_len));
        for part in parts {
            encoded.extend_from_slice(part);
        }
        encoded
    }

    /// DER INTEGER (tag 0x02) for the big-endian magnitude `raw`.
    fn integer(raw: &[u8]) -> Vec<u8> {
        // An empty or all-zero input encodes as the single byte 0x00.
        let magnitude = raw
            .iter()
            .position(|&byte| byte != 0)
            .map_or(ZERO, |start| &raw[start..]);
        if magnitude[0] >= 0x80 {
            // Top bit set: prefix a zero byte so the value is not read as negative.
            tagged(0x02, &[ZERO, magnitude])
        } else {
            tagged(0x02, &[magnitude])
        }
    }

    let modulus = integer(rsa_key.modulus);
    let exponent = integer(rsa_key.exponent);
    Ok(tagged(0x30, &[modulus.as_slice(), exponent.as_slice()]))
}
/// Copies the raw bytes of the certificate's EC point into an owned buffer.
fn construct_ecdsa_public_key(
    &self,
    ec_key: &x509_parser::public_key::ECPoint,
) -> Result<Vec<u8>> {
    let point_bytes: &[u8] = ec_key.data();
    Ok(Vec::from(point_bytes))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Canonicalisation succeeds and keeps both attributes of the root element.
    #[test]
    fn test_xml_canonicalization() {
        let canonicalizer = XmlCanonicalizer::new();
        let xml = r#"<test xmlns:ns="http://example.com" attr2="value2" attr1="value1">
<child>content</child>
</test>"#;
        let result = canonicalizer.canonicalize_xml(xml);
        assert!(result.is_ok());
        let canonical = result.unwrap();
        for attribute in ["attr1", "attr2"] {
            assert!(canonical.contains(attribute));
        }
    }

    /// The extracted fragment mentions SignedInfo and each of its children.
    #[test]
    fn test_signed_info_extraction() {
        let validator = SamlSignatureValidator;
        let xml = r#"
<Assertion>
<Signature>
<SignedInfo>
<CanonicalizationMethod Algorithm="http://www.w3.org/2001/10/xml-exc-c14n#" />
<SignatureMethod Algorithm="http://www.w3.org/2001/04/xmldsig-more#rsa-sha256" />
<Reference URI="">
<DigestMethod Algorithm="http://www.w3.org/2001/04/xmlenc#sha256" />
<DigestValue>base64digest</DigestValue>
</Reference>
</SignedInfo>
<SignatureValue>base64signature</SignatureValue>
</Signature>
</Assertion>"#;
        let result = validator.extract_signed_info(xml);
        assert!(result.is_ok());
        let signed_info = result.unwrap();
        for expected in [
            "SignedInfo",
            "CanonicalizationMethod",
            "SignatureMethod",
            "Reference",
        ] {
            assert!(signed_info.contains(expected));
        }
    }

    /// Surrounding whitespace is stripped from the extracted signature value.
    #[test]
    fn test_signature_value_extraction() {
        let validator = SamlSignatureValidator;
        let xml = r#"
<Signature>
<SignatureValue>
YmFzZTY0c2lnbmF0dXJl
</SignatureValue>
</Signature>"#;
        let result = validator.extract_signature_value(xml);
        assert!(result.is_ok());
        let signature_value = result.unwrap();
        assert_eq!(signature_value, "YmFzZTY0c2lnbmF0dXJl");
    }
}