use crate::errors::{AuthError, Result};
use crate::protocols::saml_assertions::SamlAssertion;
use base64::{Engine as _, engine::general_purpose::STANDARD};
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Escapes the five XML special characters (`&`, `<`, `>`, `"`, `'`) so that `s` can be
/// embedded safely in element text or in a single- or double-quoted attribute value.
///
/// Every other character, including non-ASCII, is copied through unchanged.
fn xml_escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&apos;"),
            _ => out.push(c),
        }
    }
    out
}
/// In-memory contents of a `<wsse:Security>` SOAP header.
///
/// Produced by the `WsSecurityClient::create_*_header` methods and serialized by
/// `WsSecurityClient::header_to_xml`.
#[derive(Debug, Clone, Default)]
pub struct WsSecurityHeader {
    /// `<wsse:UsernameToken>` element, if any.
    pub username_token: Option<UsernameToken>,
    /// `<wsu:Timestamp>` element, if any; rendered first inside `<wsse:Security>`.
    pub timestamp: Option<Timestamp>,
    /// `<wsse:BinarySecurityToken>` element (e.g. a base64 X.509 certificate), if any.
    pub binary_security_token: Option<BinarySecurityToken>,
    /// SAML assertions, each rendered via `SamlAssertion::to_xml`.
    pub saml_assertions: Vec<SamlAssertionRef>,
    /// `<ds:Signature>` description, if any; rendered last inside `<wsse:Security>`.
    pub signature: Option<WsSecuritySignature>,
    /// Additional header elements.
    ///
    /// NOTE(review): `WsSecurityClient::header_to_xml` does not emit these — confirm
    /// whether another code path consumes them.
    pub custom_elements: Vec<String>,
}
/// A `<wsse:UsernameToken>` element (OASIS WSS UsernameToken Profile 1.0).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsernameToken {
    /// Text of `<wsse:Username>`; XML-escaped when rendered.
    pub username: String,
    /// `<wsse:Password>` element; omitted from the XML when `None`.
    pub password: Option<UsernamePassword>,
    /// Base64 nonce for `<wsse:Nonce>`. `create_username_token_header` sets it only for
    /// `PasswordType::PasswordDigest`.
    pub nonce: Option<String>,
    /// Instant for `<wsu:Created>`; for digest tokens it is also an input to the digest.
    pub created: Option<DateTime<Utc>>,
    /// Optional `wsu:Id` attribute value.
    pub wsu_id: Option<String>,
}
/// The `<wsse:Password>` element of a [`UsernameToken`].
///
/// `Debug` is implemented by hand so the credential in `value` never ends up in logs or
/// panic messages. Serde serialization is unchanged and still includes it.
#[derive(Clone, Serialize, Deserialize)]
pub struct UsernamePassword {
    /// Clear-text password or base64 digest, depending on `password_type`.
    pub value: String,
    /// Selects the `Type` attribute URI written on `<wsse:Password>`.
    pub password_type: PasswordType,
}

impl std::fmt::Debug for UsernamePassword {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UsernamePassword")
            // Both clear-text passwords and digests are replayable credentials.
            .field("value", &"<redacted>")
            .field("password_type", &self.password_type)
            .finish()
    }
}
/// How the value of `<wsse:Password>` is encoded.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum PasswordType {
    /// Password carried as-is (`...#PasswordText` type URI).
    PasswordText,
    /// `Base64(SHA-1(decoded nonce || created || password))` (`...#PasswordDigest` type URI).
    PasswordDigest,
}
/// A `<wsu:Timestamp>` element describing the message's validity window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Timestamp {
    /// Instant written to `<wsu:Created>`.
    pub created: DateTime<Utc>,
    /// Instant written to `<wsu:Expires>`; `create_timestamp` sets it to
    /// `created + config.timestamp_ttl`.
    pub expires: DateTime<Utc>,
    /// Optional `wsu:Id` attribute value.
    pub wsu_id: Option<String>,
}
/// A `<wsse:BinarySecurityToken>` element.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinarySecurityToken {
    /// Encoded token payload (base64 for the X.509 tokens built by
    /// `create_certificate_header`).
    pub value: String,
    /// `ValueType` attribute URI (e.g. the X.509v3 token-profile URI).
    pub value_type: String,
    /// `EncodingType` attribute URI (e.g. the `#Base64Binary` URI).
    pub encoding_type: String,
    /// Optional `wsu:Id` attribute value.
    pub wsu_id: Option<String>,
}
/// A SAML assertion to embed in the Security header, paired with an identifier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamlAssertionRef {
    /// The assertion; serialized via `SamlAssertion::to_xml`.
    pub assertion: SamlAssertion,
    /// Identifier generated by `create_saml_header`.
    ///
    /// NOTE(review): `header_to_xml` does not write this id into the output.
    pub wsu_id: Option<String>,
}
/// Description of an XML-DSig `<ds:Signature>` element.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WsSecuritySignature {
    /// `ds:SignatureMethod` algorithm URI.
    pub signature_method: String,
    /// `ds:CanonicalizationMethod` algorithm URI.
    pub canonicalization_method: String,
    /// `ds:DigestMethod` algorithm URI, written on every reference.
    pub digest_method: String,
    /// One `<ds:Reference>` per signed element.
    pub references: Vec<SignatureReference>,
    /// Explicit `<ds:KeyInfo>` content. When `None`, rendering falls back to the configured
    /// signing certificate, or to a fixed `<ds:KeyName>`.
    pub key_info: Option<KeyInfo>,
    /// Precomputed `<ds:SignatureValue>`. When `None`, rendering computes an HMAC with the
    /// configured key, or leaves the value empty.
    pub signature_value: Option<String>,
}
/// One `<ds:Reference>` inside `<ds:SignedInfo>`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SignatureReference {
    /// Reference `URI` attribute (a `#fragment` in templates built by this module).
    pub uri: String,
    /// `<ds:DigestValue>` text; an empty string asks the renderer to fill one in.
    pub digest_value: String,
    /// `<ds:Transform>` algorithm URIs, in order.
    pub transforms: Vec<String>,
}
/// Content for `<ds:KeyInfo>`. Each present field is inserted into the XML verbatim
/// (not escaped), so the values must already be well-formed XML content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyInfo {
    /// Inner content of `<wsse:SecurityTokenReference>`.
    pub security_token_reference: Option<String>,
    /// Inner content of `<ds:KeyValue>`.
    pub key_value: Option<String>,
    /// Text of `<ds:X509Data><ds:X509Certificate>` (presumably base64 DER — confirm).
    pub x509_data: Option<String>,
}
/// Settings that control how `WsSecurityClient` builds and renders headers.
///
/// `Debug` is implemented by hand so that `signing_private_key` is redacted rather than
/// dumped into logs.
#[derive(Clone)]
pub struct WsSecurityConfig {
    /// When true, the `create_*_header` methods attach a `<wsu:Timestamp>`.
    pub include_timestamp: bool,
    /// Added to the creation time to produce `<wsu:Expires>`.
    pub timestamp_ttl: Duration,
    /// When true, `create_certificate_header` attaches a signature template.
    pub sign_message: bool,
    /// Element ids referenced (as `#id`) by the signature template.
    pub elements_to_sign: Vec<String>,
    /// Certificate bytes; base64-encoded into `<ds:KeyInfo>` when a signature has no
    /// explicit key info.
    pub signing_certificate: Option<Vec<u8>>,
    /// Secret bytes used as the HMAC-SHA256 key for signature and digest values.
    pub signing_private_key: Option<Vec<u8>>,
    /// NOTE(review): not consulted by `WsSecurityClient` in this module — confirm its consumer.
    pub include_certificate: bool,
    /// Presumably an STS endpoint for obtaining SAML tokens — not used by `WsSecurityClient`
    /// in this module.
    pub saml_token_endpoint: Option<String>,
    /// Presumably the SOAP actor/role for the Security header — not used by `header_to_xml`.
    pub actor: Option<String>,
}

impl std::fmt::Debug for WsSecurityConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WsSecurityConfig")
            .field("include_timestamp", &self.include_timestamp)
            .field("timestamp_ttl", &self.timestamp_ttl)
            .field("sign_message", &self.sign_message)
            .field("elements_to_sign", &self.elements_to_sign)
            .field("signing_certificate", &self.signing_certificate)
            // Only reveal whether a key is configured, never its bytes.
            .field(
                "signing_private_key",
                &self.signing_private_key.as_ref().map(|_| "<redacted>"),
            )
            .field("include_certificate", &self.include_certificate)
            .field("saml_token_endpoint", &self.saml_token_endpoint)
            .field("actor", &self.actor)
            .finish()
    }
}
/// Builds WS-Security headers and renders them to XML according to a [`WsSecurityConfig`].
pub struct WsSecurityClient {
    /// Behavior settings (timestamps, signing material, ...).
    config: WsSecurityConfig,
    /// Namespace prefix -> URI map (`wsse`, `wsu`, `ds`, `saml`), filled in by `new`.
    namespaces: HashMap<String, String>,
}
impl WsSecurityClient {
    /// `strftime` pattern for every timestamp this client writes: UTC, millisecond precision,
    /// literal `Z`. `compute_password_digest` uses the same pattern, so the hashed `Created`
    /// string is byte-identical to the emitted `<wsu:Created>`.
    const DATETIME_FORMAT: &'static str = "%Y-%m-%dT%H:%M:%S%.3fZ";

    /// Creates a client for `config`, registering the `wsse`, `wsu`, `ds` and `saml`
    /// namespace prefixes used when rendering XML.
    pub fn new(config: WsSecurityConfig) -> Self {
        let namespaces = [
            (
                "wsse",
                "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-secext-1.0.xsd",
            ),
            (
                "wsu",
                "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-utility-1.0.xsd",
            ),
            ("ds", "http://www.w3.org/2000/09/xmldsig#"),
            ("saml", "urn:oasis:names:tc:SAML:2.0:assertion"),
        ]
        .iter()
        .map(|&(prefix, uri)| (prefix.to_string(), uri.to_string()))
        .collect();
        Self { config, namespaces }
    }

    /// Builds a header containing a `<wsse:UsernameToken>` for `username`.
    ///
    /// With [`PasswordType::PasswordDigest`], a random nonce and the current time are
    /// generated, stored on the token, and mixed into the password digest. With
    /// [`PasswordType::PasswordText`], the password is carried verbatim and no nonce or
    /// creation time is set. When `password` is `None`, no `<wsse:Password>` is produced.
    /// A `<wsu:Timestamp>` is added when `config.include_timestamp` is set.
    ///
    /// # Errors
    ///
    /// Returns an error if the password digest cannot be computed.
    pub fn create_username_token_header(
        &self,
        username: &str,
        password: Option<&str>,
        password_type: PasswordType,
    ) -> Result<WsSecurityHeader> {
        let mut header = WsSecurityHeader::default();
        // Nonce and Created are generated only for digest tokens, where they are digest inputs.
        let (nonce, created) = if password_type == PasswordType::PasswordDigest {
            (Some(self.generate_nonce()), Some(Utc::now()))
        } else {
            (None, None)
        };
        let password_element = if let Some(pwd) = password {
            let pwd_value = match password_type {
                PasswordType::PasswordText => pwd.to_string(),
                PasswordType::PasswordDigest => self.compute_password_digest(
                    pwd,
                    nonce
                        .as_ref()
                        .expect("nonce is Some for PasswordDigest variant"),
                    &created.expect("created is Some for PasswordDigest variant"),
                )?,
            };
            Some(UsernamePassword {
                value: pwd_value,
                password_type,
            })
        } else {
            None
        };
        header.username_token = Some(UsernameToken {
            username: username.to_string(),
            password: password_element,
            nonce,
            created,
            wsu_id: Some(format!("UsernameToken-{}", uuid::Uuid::new_v4())),
        });
        if self.config.include_timestamp {
            header.timestamp = Some(self.create_timestamp());
        }
        Ok(header)
    }

    /// Builds a header that carries `certificate` as a base64 X.509v3
    /// `<wsse:BinarySecurityToken>`.
    ///
    /// A timestamp is added when `include_timestamp` is set, and a signature template when
    /// `sign_message` is set.
    ///
    /// # Errors
    ///
    /// Propagates any error from building the signature template.
    pub fn create_certificate_header(&self, certificate: &[u8]) -> Result<WsSecurityHeader> {
        let mut header = WsSecurityHeader::default();
        let cert_b64 = STANDARD.encode(certificate);
        header.binary_security_token = Some(BinarySecurityToken {
            value: cert_b64,
            value_type: "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-x509-token-profile-1.0#X509v3".to_string(),
            encoding_type: "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-soap-message-security-1.0#Base64Binary".to_string(),
            wsu_id: Some(format!("X509Token-{}", uuid::Uuid::new_v4())),
        });
        if self.config.include_timestamp {
            header.timestamp = Some(self.create_timestamp());
        }
        if self.config.sign_message {
            header.signature = Some(self.create_signature_template()?);
        }
        Ok(header)
    }

    /// Builds a header embedding `assertion`, plus a timestamp when `include_timestamp` is set.
    ///
    /// # Errors
    ///
    /// Never returns an error in the current implementation.
    pub fn create_saml_header(&self, assertion: SamlAssertion) -> Result<WsSecurityHeader> {
        let mut header = WsSecurityHeader::default();
        header.saml_assertions.push(SamlAssertionRef {
            assertion,
            wsu_id: Some(format!("SamlAssertion-{}", uuid::Uuid::new_v4())),
        });
        if self.config.include_timestamp {
            header.timestamp = Some(self.create_timestamp());
        }
        Ok(header)
    }

    /// Serializes `header` as a `<wsse:Security>` element.
    ///
    /// Children are written in a fixed order: Timestamp, UsernameToken,
    /// BinarySecurityToken, SAML assertions, Signature.
    ///
    /// # Errors
    ///
    /// Propagates failures from `SamlAssertion::to_xml`.
    pub fn header_to_xml(&self, header: &WsSecurityHeader) -> Result<String> {
        let mut xml = format!(
            r#"<wsse:Security xmlns:wsse="{}" xmlns:wsu="{}">"#,
            self.namespaces["wsse"], self.namespaces["wsu"]
        );
        if let Some(ref timestamp) = header.timestamp {
            xml.push_str(&self.timestamp_to_xml(timestamp));
        }
        if let Some(ref username_token) = header.username_token {
            xml.push_str(&self.username_token_to_xml(username_token));
        }
        if let Some(ref bst) = header.binary_security_token {
            xml.push_str(&self.binary_security_token_to_xml(bst));
        }
        // NOTE(review): each `SamlAssertionRef::wsu_id` and `header.custom_elements` are not
        // emitted — confirm whether that is intentional.
        for assertion_ref in &header.saml_assertions {
            xml.push_str(&assertion_ref.assertion.to_xml()?);
        }
        if let Some(ref signature) = header.signature {
            xml.push_str(&self.signature_to_xml(signature));
        }
        xml.push_str("</wsse:Security>");
        Ok(xml)
    }

    /// Returns 16 random bytes, base64-encoded, for use as a `<wsse:Nonce>`.
    fn generate_nonce(&self) -> String {
        use rand::Rng;
        let mut rng = rand::rng();
        let mut nonce = [0u8; 16];
        rng.fill_bytes(&mut nonce);
        STANDARD.encode(nonce)
    }

    /// Computes `Base64(SHA-1(decoded nonce || created || password))`, where `created` is
    /// formatted with [`Self::DATETIME_FORMAT`].
    ///
    /// # Errors
    ///
    /// Returns an error if `nonce` is not valid standard base64.
    fn compute_password_digest(
        &self,
        password: &str,
        nonce: &str,
        created: &DateTime<Utc>,
    ) -> Result<String> {
        use sha1::{Digest, Sha1};
        let nonce_bytes = STANDARD
            .decode(nonce)
            .map_err(|_| AuthError::auth_method("ws_security", "Invalid nonce encoding"))?;
        let created_str = Self::format_datetime(created);
        let mut hasher = Sha1::new();
        hasher.update(&nonce_bytes);
        hasher.update(created_str.as_bytes());
        hasher.update(password.as_bytes());
        Ok(STANDARD.encode(hasher.finalize()))
    }

    /// Returns a timestamp created now that expires after `config.timestamp_ttl`.
    fn create_timestamp(&self) -> Timestamp {
        let now = Utc::now();
        Timestamp {
            created: now,
            expires: now + self.config.timestamp_ttl,
            wsu_id: Some(format!("Timestamp-{}", uuid::Uuid::new_v4())),
        }
    }

    /// Builds an unsigned HMAC-SHA256 / exclusive-C14N signature template with one
    /// `#element` reference per entry in `config.elements_to_sign`.
    fn create_signature_template(&self) -> Result<WsSecuritySignature> {
        let references = self
            .config
            .elements_to_sign
            .iter()
            .map(|element| SignatureReference {
                uri: format!("#{}", element),
                digest_value: String::new(),
                transforms: vec!["http://www.w3.org/2001/10/xml-exc-c14n#".to_string()],
            })
            .collect();
        Ok(WsSecuritySignature {
            signature_method: "http://www.w3.org/2001/04/xmldsig-more#hmac-sha256".to_string(),
            canonicalization_method: "http://www.w3.org/2001/10/xml-exc-c14n#".to_string(),
            digest_method: "http://www.w3.org/2001/04/xmlenc#sha256".to_string(),
            references,
            key_info: None,
            signature_value: None,
        })
    }

    /// Formats `dt` with [`Self::DATETIME_FORMAT`].
    fn format_datetime(dt: &DateTime<Utc>) -> String {
        dt.format(Self::DATETIME_FORMAT).to_string()
    }

    /// Base64 HMAC-SHA256 of `data`, keyed with `config.signing_private_key`.
    /// Returns `None` when no key is configured.
    fn hmac_sha256_b64(&self, data: &[u8]) -> Option<String> {
        use ring::hmac;
        self.config.signing_private_key.as_ref().map(|key_bytes| {
            let key = hmac::Key::new(hmac::HMAC_SHA256, key_bytes);
            STANDARD.encode(hmac::sign(&key, data).as_ref())
        })
    }

    /// Renders a `<wsu:Timestamp>` element; the `wsu:Id` attribute is XML-escaped.
    fn timestamp_to_xml(&self, timestamp: &Timestamp) -> String {
        let mut xml = match timestamp.wsu_id {
            Some(ref id) => format!(r#"<wsu:Timestamp wsu:Id="{}">"#, xml_escape(id)),
            None => "<wsu:Timestamp>".to_string(),
        };
        xml.push_str(&format!(
            "<wsu:Created>{}</wsu:Created>",
            Self::format_datetime(&timestamp.created)
        ));
        xml.push_str(&format!(
            "<wsu:Expires>{}</wsu:Expires>",
            Self::format_datetime(&timestamp.expires)
        ));
        xml.push_str("</wsu:Timestamp>");
        xml
    }

    /// Renders a `<wsse:UsernameToken>` element; all user-supplied values are XML-escaped.
    fn username_token_to_xml(&self, token: &UsernameToken) -> String {
        let mut xml = match token.wsu_id {
            Some(ref id) => format!(r#"<wsse:UsernameToken wsu:Id="{}">"#, xml_escape(id)),
            None => "<wsse:UsernameToken>".to_string(),
        };
        xml.push_str(&format!(
            "<wsse:Username>{}</wsse:Username>",
            xml_escape(&token.username)
        ));
        if let Some(ref password) = token.password {
            let type_attr = match password.password_type {
                PasswordType::PasswordText => {
                    "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-username-token-profile-1.0#PasswordText"
                }
                PasswordType::PasswordDigest => {
                    "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-username-token-profile-1.0#PasswordDigest"
                }
            };
            xml.push_str(&format!(
                r#"<wsse:Password Type="{}">{}</wsse:Password>"#,
                type_attr,
                xml_escape(&password.value)
            ));
        }
        if let Some(ref nonce) = token.nonce {
            xml.push_str(&format!(
                r#"<wsse:Nonce EncodingType="http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-soap-message-security-1.0#Base64Binary">{}</wsse:Nonce>"#,
                xml_escape(nonce)
            ));
        }
        if let Some(ref created) = token.created {
            xml.push_str(&format!(
                "<wsu:Created>{}</wsu:Created>",
                Self::format_datetime(created)
            ));
        }
        xml.push_str("</wsse:UsernameToken>");
        xml
    }

    /// Renders a `<wsse:BinarySecurityToken>` element; attributes and content are
    /// XML-escaped (a no-op for base64 payloads and standard URIs).
    fn binary_security_token_to_xml(&self, token: &BinarySecurityToken) -> String {
        let mut xml = format!(
            r#"<wsse:BinarySecurityToken ValueType="{}" EncodingType="{}""#,
            xml_escape(&token.value_type),
            xml_escape(&token.encoding_type)
        );
        if let Some(ref id) = token.wsu_id {
            xml.push_str(&format!(r#" wsu:Id="{}""#, xml_escape(id)));
        }
        xml.push('>');
        xml.push_str(&xml_escape(&token.value));
        xml.push_str("</wsse:BinarySecurityToken>");
        xml
    }

    /// Renders a `<ds:Signature>` element.
    ///
    /// Missing digest and signature values are filled with HMAC-SHA256 over the reference URI
    /// and the rendered `<ds:SignedInfo>` text respectively, when a signing key is
    /// configured; otherwise they are left empty. `KeyInfo` falls back to the configured
    /// certificate, then to a fixed `<ds:KeyName>`.
    fn signature_to_xml(&self, signature: &WsSecuritySignature) -> String {
        let references_xml: String = signature
            .references
            .iter()
            .map(|r| {
                // NOTE(review): the fallback digest HMACs the URI text, not the canonicalized
                // referenced element, so it is not a conforming XML-DSig DigestValue.
                let digest_value = if !r.digest_value.is_empty() {
                    r.digest_value.clone()
                } else {
                    self.hmac_sha256_b64(r.uri.as_bytes()).unwrap_or_default()
                };
                let transforms_xml: String = r
                    .transforms
                    .iter()
                    .map(|t| format!(r#"<ds:Transform Algorithm="{}"/>"#, xml_escape(t)))
                    .collect();
                format!(
                    r#"<ds:Reference URI="{}">
<ds:Transforms>
{}
</ds:Transforms>
<ds:DigestMethod Algorithm="{}"/>
<ds:DigestValue>{}</ds:DigestValue>
</ds:Reference>"#,
                    xml_escape(&r.uri),
                    transforms_xml,
                    xml_escape(&signature.digest_method),
                    xml_escape(&digest_value),
                )
            })
            .collect();
        let signed_info_xml = format!(
            r#"<ds:SignedInfo>
<ds:CanonicalizationMethod Algorithm="{}"/>
<ds:SignatureMethod Algorithm="{}"/>
{}
</ds:SignedInfo>"#,
            xml_escape(&signature.canonicalization_method),
            xml_escape(&signature.signature_method),
            references_xml,
        );
        let signature_value = match signature.signature_value {
            Some(ref sv) => sv.clone(),
            None => self
                .hmac_sha256_b64(signed_info_xml.as_bytes())
                .unwrap_or_default(),
        };
        // Explicit KeyInfo fields are inserted verbatim: they may legitimately hold XML
        // fragments (e.g. a `<wsse:Reference .../>` inside the token reference).
        let key_info_xml = if let Some(ref ki) = signature.key_info {
            let mut ki_xml = String::new();
            if let Some(ref x509) = ki.x509_data {
                ki_xml.push_str(&format!(
                    "<ds:X509Data><ds:X509Certificate>{}</ds:X509Certificate></ds:X509Data>",
                    x509
                ));
            }
            if let Some(ref str_ref) = ki.security_token_reference {
                ki_xml.push_str(&format!(
                    "<wsse:SecurityTokenReference>{}</wsse:SecurityTokenReference>",
                    str_ref
                ));
            }
            if let Some(ref kv) = ki.key_value {
                ki_xml.push_str(&format!("<ds:KeyValue>{}</ds:KeyValue>", kv));
            }
            ki_xml
        } else if let Some(ref cert) = self.config.signing_certificate {
            format!(
                "<ds:X509Data><ds:X509Certificate>{}</ds:X509Certificate></ds:X509Data>",
                STANDARD.encode(cert)
            )
        } else {
            "<ds:KeyName>WS-Security-Signing-Key</ds:KeyName>".to_string()
        };
        format!(
            r#"<ds:Signature xmlns:ds="{}">
{}
<ds:SignatureValue>{}</ds:SignatureValue>
<ds:KeyInfo>{}</ds:KeyInfo>
</ds:Signature>"#,
            self.namespaces["ds"],
            signed_info_xml,
            xml_escape(&signature_value),
            key_info_xml,
        )
    }
}
impl Default for WsSecurityConfig {
    /// Baseline configuration:
    /// - a five-minute timestamp is included;
    /// - messages are not signed, and `Body` and `Timestamp` are the elements named for signing;
    /// - no keys or certificate are set, and `include_certificate` is `true`;
    /// - there is no SAML endpoint and no actor.
    fn default() -> Self {
        let elements_to_sign = ["Body", "Timestamp"]
            .iter()
            .map(|name| name.to_string())
            .collect();
        Self {
            actor: None,
            saml_token_endpoint: None,
            include_certificate: true,
            signing_private_key: None,
            signing_certificate: None,
            elements_to_sign,
            sign_message: false,
            timestamp_ttl: Duration::minutes(5),
            include_timestamp: true,
        }
    }
}
impl WsSecurityConfig {
pub fn builder() -> WsSecurityConfigBuilder {
WsSecurityConfigBuilder::default()
}
}
/// Fluent builder for [`WsSecurityConfig`]; starts from `WsSecurityConfig::default()`.
#[derive(Debug, Clone)]
pub struct WsSecurityConfigBuilder {
    /// Configuration being assembled; returned by `build`.
    config: WsSecurityConfig,
}
impl Default for WsSecurityConfigBuilder {
fn default() -> Self {
Self {
config: WsSecurityConfig::default(),
}
}
}
/// Setters take and return the builder by value. Each one is `#[must_use]` because
/// discarding its result silently drops the setting.
impl WsSecurityConfigBuilder {
    /// Sets whether generated headers include a `<wsu:Timestamp>`.
    #[must_use]
    pub fn include_timestamp(mut self, include: bool) -> Self {
        self.config.include_timestamp = include;
        self
    }

    /// Sets the duration added to a timestamp's creation time to produce `Expires`.
    #[must_use]
    pub fn timestamp_ttl(mut self, ttl: Duration) -> Self {
        self.config.timestamp_ttl = ttl;
        self
    }

    /// Sets whether certificate headers get a signature template.
    #[must_use]
    pub fn sign_message(mut self, sign: bool) -> Self {
        self.config.sign_message = sign;
        self
    }

    /// Replaces the element ids referenced (as `#id`) by signature templates.
    #[must_use]
    pub fn elements_to_sign(mut self, elements: Vec<String>) -> Self {
        self.config.elements_to_sign = elements;
        self
    }

    /// Sets the signing certificate bytes.
    #[must_use]
    pub fn signing_certificate(mut self, cert: Vec<u8>) -> Self {
        self.config.signing_certificate = Some(cert);
        self
    }

    /// Sets the secret key bytes used for HMAC-SHA256 signature and digest values.
    #[must_use]
    pub fn signing_private_key(mut self, key: Vec<u8>) -> Self {
        self.config.signing_private_key = Some(key);
        self
    }

    /// Sets the `include_certificate` flag.
    #[must_use]
    pub fn include_certificate(mut self, include: bool) -> Self {
        self.config.include_certificate = include;
        self
    }

    /// Sets the SAML token endpoint.
    #[must_use]
    pub fn saml_token_endpoint(mut self, endpoint: impl Into<String>) -> Self {
        self.config.saml_token_endpoint = Some(endpoint.into());
        self
    }

    /// Sets the actor.
    #[must_use]
    pub fn actor(mut self, actor: impl Into<String>) -> Self {
        self.config.actor = Some(actor.into());
        self
    }

    /// Consumes the builder and returns the assembled configuration.
    #[must_use]
    pub fn build(self) -> WsSecurityConfig {
        self.config
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_username_token_creation() {
        let config = WsSecurityConfig::default();
        let client = WsSecurityClient::new(config);
        let header = client
            .create_username_token_header("testuser", Some("testpass"), PasswordType::PasswordText)
            .unwrap();
        assert!(header.username_token.is_some());
        let token = header.username_token.unwrap();
        assert_eq!(token.username, "testuser");
        assert!(token.password.is_some());
    }

    #[test]
    fn test_password_digest() {
        let config = WsSecurityConfig::default();
        let client = WsSecurityClient::new(config);
        let nonce = "MTIzNDU2Nzg5MDEyMzQ1Ng==";
        let created = DateTime::parse_from_rfc3339("2023-01-01T12:00:00Z")
            .unwrap()
            .with_timezone(&Utc);
        let password = "secret";
        let digest = client
            .compute_password_digest(password, nonce, &created)
            .unwrap();
        assert!(!digest.is_empty());
        // SHA-1 output is always 20 bytes.
        assert_eq!(STANDARD.decode(&digest).unwrap().len(), 20);
        // Deterministic for identical inputs, sensitive to the password.
        assert_eq!(
            digest,
            client.compute_password_digest(password, nonce, &created).unwrap()
        );
        assert_ne!(
            digest,
            client.compute_password_digest("other", nonce, &created).unwrap()
        );
    }

    #[test]
    fn test_password_digest_rejects_bad_nonce() {
        let client = WsSecurityClient::new(WsSecurityConfig::default());
        let created = Utc::now();
        assert!(client
            .compute_password_digest("secret", "not base64!!", &created)
            .is_err());
    }

    #[test]
    fn test_password_digest_header_matches_manual_digest() {
        let client = WsSecurityClient::new(WsSecurityConfig::default());
        let header = client
            .create_username_token_header("u", Some("secret"), PasswordType::PasswordDigest)
            .unwrap();
        let token = header.username_token.unwrap();
        let nonce = token.nonce.expect("digest tokens carry a nonce");
        let created = token.created.expect("digest tokens carry a created time");
        let password = token.password.expect("a password was supplied");
        assert_eq!(password.password_type, PasswordType::PasswordDigest);
        assert_eq!(
            password.value,
            client
                .compute_password_digest("secret", &nonce, &created)
                .unwrap()
        );
    }

    #[test]
    fn test_timestamp_creation() {
        let config = WsSecurityConfig::default();
        let client = WsSecurityClient::new(config);
        let timestamp = client.create_timestamp();
        assert!(timestamp.expires > timestamp.created);
        assert!(timestamp.wsu_id.is_some());
    }

    #[test]
    fn test_xml_generation() {
        let config = WsSecurityConfig::default();
        let client = WsSecurityClient::new(config);
        let header = client
            .create_username_token_header("testuser", Some("testpass"), PasswordType::PasswordText)
            .unwrap();
        let xml = client.header_to_xml(&header).unwrap();
        assert!(xml.contains("<wsse:Security"));
        assert!(xml.contains("<wsse:UsernameToken"));
        assert!(xml.contains("testuser"));
        assert!(xml.contains("</wsse:Security>"));
    }

    #[test]
    fn test_certificate_header() {
        let config = WsSecurityConfig::default();
        let client = WsSecurityClient::new(config);
        let dummy_cert = b"dummy certificate data";
        let header = client.create_certificate_header(dummy_cert).unwrap();
        assert!(header.binary_security_token.is_some());
        let bst = header.binary_security_token.unwrap();
        assert_eq!(bst.value, STANDARD.encode(dummy_cert));
    }

    #[test]
    fn test_xml_escape_special_characters() {
        assert_eq!(xml_escape("hello"), "hello");
        assert_eq!(xml_escape("<script>"), "&lt;script&gt;");
        assert_eq!(xml_escape("a&b"), "a&amp;b");
        assert_eq!(xml_escape("\"quoted\""), "&quot;quoted&quot;");
        assert_eq!(xml_escape("it's"), "it&apos;s");
        assert_eq!(
            xml_escape("<user>&\"name'"),
            "&lt;user&gt;&amp;&quot;name&apos;"
        );
    }

    #[test]
    fn test_xml_escape_empty_and_normal() {
        assert_eq!(xml_escape(""), "");
        assert_eq!(xml_escape("normal_user123"), "normal_user123");
    }

    #[test]
    fn test_username_token_xml_escapes_injection() {
        let config = WsSecurityConfig::default();
        let client = WsSecurityClient::new(config);
        let header = client
            .create_username_token_header(
                "<script>alert(1)</script>",
                Some("pass&word\""),
                PasswordType::PasswordText,
            )
            .unwrap();
        let xml = client.header_to_xml(&header).unwrap();
        assert!(!xml.contains("<script>"));
        assert!(xml.contains("&lt;script&gt;"));
        assert!(xml.contains("pass&amp;word&quot;"));
    }
}