use std::{borrow::Cow, str::FromStr};
use serde::{Deserialize, Deserializer, Serialize};
/// SOAP header of an MS-WSTEP message.
///
/// Element names are serialized with the `a:` prefix (except `ActivityId`,
/// which has none); each prefixed name also carries an unprefixed alias, so
/// deserialization accepts either spelling. Every `Option` field is left out
/// of the serialized output when it is `None`. Field order here is the
/// element order on the wire.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
pub struct Header<'a> {
    /// `<a:Action>`: the action URI of the message.
    #[serde(rename = "a:Action", alias = "Action")]
    pub action: Action,
    /// `<ActivityId>` diagnostics element.
    #[serde(rename = "ActivityId", skip_serializing_if = "Option::is_none")]
    pub activity_id: Option<ActivityId<'a>>,
    /// `<a:RelatesTo>`: identifier of the message this one answers.
    #[serde(
        rename = "a:RelatesTo",
        alias = "RelatesTo",
        skip_serializing_if = "Option::is_none"
    )]
    pub relates_to: Option<Cow<'a, str>>,
    /// `<a:MessageID>`: identifier of this message.
    #[serde(
        rename = "a:MessageID",
        alias = "MessageID",
        skip_serializing_if = "Option::is_none"
    )]
    pub message_id: Option<Cow<'a, str>>,
    /// `<a:ReplyTo>`: where replies should be sent.
    #[serde(
        rename = "a:ReplyTo",
        alias = "ReplyTo",
        skip_serializing_if = "Option::is_none"
    )]
    pub reply_to: Option<ReplyTo<'a>>,
    /// `<a:To>`: destination of this message.
    #[serde(rename = "a:To", alias = "To", skip_serializing_if = "Option::is_none")]
    pub to: Option<To<'a>>,
}
impl<'a> Header<'a> {
#[must_use]
pub fn new_request_header(
action_type: ActionType,
message_id: &'a str,
to: Option<&'a str>,
reply_to: Option<&'a str>,
) -> Self {
Self {
action: Action::new(action_type),
activity_id: None,
relates_to: None,
message_id: Some(message_id.into()),
reply_to: reply_to.map(ReplyTo::new),
to: to.map(To::new),
}
}
#[must_use]
pub fn new_response_header(
action_type: ActionType,
activity_id: ActivityId<'a>,
relates_to: &'a str,
) -> Self {
Self {
action: Action::new(action_type),
activity_id: Some(activity_id),
relates_to: Some(relates_to.into()),
message_id: None,
reply_to: None,
to: None,
}
}
}
/// `<a:To>` header element: a destination value plus a `mustUnderstand` flag.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
pub struct To<'a> {
    /// Serialized as the `s:mustUnderstand` attribute; `mustUnderstand` is
    /// also accepted when deserializing.
    #[serde(rename = "@s:mustUnderstand", alias = "@mustUnderstand")]
    pub must_understand: MustUnderstand,
    /// Text content of the element.
    #[serde(rename = "$text")]
    pub value: Cow<'a, str>,
}
impl<'a> To<'a> {
    /// Creates a `To` element borrowing `value`, with `mustUnderstand` set to true.
    #[must_use]
    pub fn new(value: &'a str) -> Self {
        let must_understand = MustUnderstand(true);
        Self {
            value: Cow::Borrowed(value),
            must_understand,
        }
    }
}
/// `<a:ReplyTo>` header element wrapping a single address.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
pub struct ReplyTo<'a> {
    /// Serialized as `<a:Address>`; `<Address>` is also accepted on input.
    #[serde(rename = "a:Address", alias = "Address")]
    pub address: Address<'a>,
}
impl<'a> ReplyTo<'a> {
    /// Wraps `address` (borrowed) in a `ReplyTo` element.
    fn new(address: &'a str) -> Self {
        let inner = Address {
            value: Cow::Borrowed(address),
        };
        Self { address: inner }
    }
}
/// An address element whose only content is its text.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
pub struct Address<'a> {
    /// Text content of the element.
    #[serde(rename = "$text")]
    pub value: Cow<'a, str>,
}
/// `<ActivityId>` diagnostics element: an id in the text content plus a
/// `CorrelationId` attribute and its own default namespace.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
pub struct ActivityId<'a> {
    /// `CorrelationId` attribute.
    #[serde(rename = "@CorrelationId")]
    pub correlation_id: Cow<'a, str>,
    /// `xmlns` attribute; [`ActivityId::new`] fills in the diagnostics namespace.
    #[serde(rename = "@xmlns")]
    pub xmlns: Cow<'a, str>,
    /// Text content: the activity id itself.
    #[serde(rename = "$text")]
    pub value: Cow<'a, str>,
}
impl<'a> ActivityId<'a> {
    /// Namespace URI written into the `xmlns` attribute.
    const XMLNS: &'static str = "http://schemas.microsoft.com/2004/09/ServiceModel/Diagnostics";

    /// Creates an `ActivityId` that borrows `id` and `correlation_id`.
    #[must_use]
    pub fn new(id: &'a str, correlation_id: &'a str) -> Self {
        Self {
            value: Cow::Borrowed(id),
            correlation_id: Cow::Borrowed(correlation_id),
            xmlns: Cow::Borrowed(Self::XMLNS),
        }
    }
}
/// `<a:Action>` header element: an action URI plus a `mustUnderstand` flag.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
pub struct Action {
    /// Serialized as `s:mustUnderstand`; `mustUnderstand` is accepted on input.
    #[serde(rename = "@s:mustUnderstand", alias = "@mustUnderstand")]
    pub must_understand: MustUnderstand,
    /// Text content: the action URI.
    #[serde(rename = "$text")]
    pub action_type: ActionType,
}
impl Action {
    /// Creates an `Action` for `action_type` with `mustUnderstand` set to true.
    #[must_use]
    pub const fn new(action_type: ActionType) -> Self {
        let must_understand = MustUnderstand(true);
        Self {
            action_type,
            must_understand,
        }
    }
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct MustUnderstand(pub bool);
impl Serialize for MustUnderstand {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(if self.0 { "1" } else { "0" })
}
}
impl<'a> Deserialize<'a> for MustUnderstand {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'a>,
{
let s = String::deserialize(deserializer)?;
Ok(Self(s == "1"))
}
}
/// The MS-WSTEP action URIs this module can read and write.
#[non_exhaustive]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ActionType {
    /// `.../enrollment/RST/wstep` enrollment request.
    RequestSecurityToken,
    /// `.../enrollment/RSTRC/wstep` response collection.
    RequestSecurityTokenResponseCollection,
    /// WS-Trust `RST/KET` key exchange token request.
    KeyExchangeToken,
    /// WS-Trust `RSTR/KETFinal` key exchange token response.
    KeyExchangeTokenFinal,
    /// W3C addressing `soap/fault` action.
    SoapFault,
    /// WCF dispatcher `fault` action.
    Fault,
    /// Enrollment web-service detail fault action.
    FaultDetail,
}
impl ActionType {
    // One URI constant per variant; they double as match patterns when parsing.
    const RST_WSTEP: &'static str =
        "http://schemas.microsoft.com/windows/pki/2009/01/enrollment/RST/wstep";
    const RSTRC_WSTEP: &'static str =
        "http://schemas.microsoft.com/windows/pki/2009/01/enrollment/RSTRC/wstep";
    const KET: &'static str = "http://docs.oasis-open.org/ws-sx/ws-trust/200512/RST/KET";
    const KET_FINAL: &'static str =
        "http://docs.oasis-open.org/ws-sx/ws-trust/200512/RSTR/KETFinal";
    const FAULT: &'static str =
        "http://schemas.microsoft.com/net/2005/12/windowscommunicationfoundation/dispatcher/fault";
    const FAULT_DETAIL: &'static str = "http://schemas.microsoft.com/windows/pki/2009/01/enrollment/RequestSecurityTokenCertificateEnrollmentWSDetailFault";
    const SOAP_FAULT: &'static str = "http://www.w3.org/2005/08/addressing/soap/fault";
}
impl Serialize for ActionType {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str((*self).into())
}
}
impl<'de> Deserialize<'de> for ActionType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
s.parse().map_err(serde::de::Error::custom)
}
}
impl From<ActionType> for &'static str {
    /// Maps each action to its URI.
    fn from(value: ActionType) -> Self {
        match value {
            ActionType::RequestSecurityToken => ActionType::RST_WSTEP,
            ActionType::RequestSecurityTokenResponseCollection => ActionType::RSTRC_WSTEP,
            ActionType::KeyExchangeToken => ActionType::KET,
            ActionType::KeyExchangeTokenFinal => ActionType::KET_FINAL,
            ActionType::SoapFault => ActionType::SOAP_FAULT,
            ActionType::Fault => ActionType::FAULT,
            ActionType::FaultDetail => ActionType::FAULT_DETAIL,
        }
    }
}
impl FromStr for ActionType {
    type Err = ActionTypeParseError;

    /// Parses an exact (case-sensitive) action URI.
    ///
    /// # Errors
    ///
    /// Returns [`ActionTypeParseError`] carrying the input when it matches no
    /// known URI.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            Self::RST_WSTEP => Ok(Self::RequestSecurityToken),
            Self::RSTRC_WSTEP => Ok(Self::RequestSecurityTokenResponseCollection),
            Self::KET => Ok(Self::KeyExchangeToken),
            Self::KET_FINAL => Ok(Self::KeyExchangeTokenFinal),
            Self::SOAP_FAULT => Ok(Self::SoapFault),
            Self::FAULT => Ok(Self::Fault),
            Self::FAULT_DETAIL => Ok(Self::FaultDetail),
            unknown => Err(ActionTypeParseError(unknown.to_owned())),
        }
    }
}
/// Error for a string that is not one of the known MS-WSTEP action URIs.
///
/// Holds the rejected input verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ActionTypeParseError(String);
impl std::fmt::Display for ActionTypeParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(input) = self;
        write!(f, "{input} is not a valid MS-WSTEP action type")
    }
}
impl std::error::Error for ActionTypeParseError {}
/// `<BinarySecurityToken>` element: an encoded token in the text content,
/// described by its `EncodingType` and `ValueType` attributes.
///
/// Field order is the attribute order on the wire.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
pub struct BinarySecurityToken<'a> {
    /// `EncodingType` attribute.
    #[serde(rename = "@EncodingType")]
    pub encoding_type: EncodingType,
    /// `ValueType` attribute.
    #[serde(rename = "@ValueType")]
    pub value_type: ValueType,
    /// `xmlns` attribute; the constructors set the WSS secext namespace.
    #[serde(rename = "@xmlns")]
    pub xmlns: Cow<'a, str>,
    /// Text content: the encoded token.
    #[serde(rename = "$text")]
    pub value: Cow<'a, str>,
}
impl<'a> BinarySecurityToken<'a> {
    /// Namespace URI written into the `xmlns` attribute.
    const XMLNS: &'static str =
        "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-secext-1.0.xsd";

    /// Creates a token borrowing `value`, with the given type/encoding
    /// attributes and the default namespace.
    #[must_use]
    pub fn new(value: &'a str, value_type: ValueType, encoding_type: EncodingType) -> Self {
        let xmlns = Cow::Borrowed(Self::XMLNS);
        Self {
            encoding_type,
            value_type,
            xmlns,
            value: Cow::Borrowed(value),
        }
    }

    /// Shorthand for a base64-encoded PKCS#7 token.
    #[must_use]
    pub fn new_pkcs7_base64(value: &'a str) -> Self {
        let (kind, encoding) = (ValueType::Pkcs7, EncodingType::Base64Binary);
        Self::new(value, kind, encoding)
    }

    /// Shorthand for a base64-encoded X.509v3 token.
    #[must_use]
    pub fn new_x509v3_base64(value: &'a str) -> Self {
        let (kind, encoding) = (ValueType::X509v3, EncodingType::Base64Binary);
        Self::new(value, kind, encoding)
    }
}
/// Kind of token carried by a `BinarySecurityToken` (its `ValueType` attribute).
#[non_exhaustive]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ValueType {
    /// PKCS#7 message (`...#PKCS7`).
    Pkcs7,
    /// X.509v3 certificate (`...#X509v3`).
    X509v3,
}
impl ValueType {
    // Attribute URIs, one per variant; also used as match patterns when parsing.
    const NS_PKCS7: &'static str =
        "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-secext-1.0.xsd#PKCS7";
    const NS_X509V3: &'static str =
        "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-x509-token-profile-1.0#X509v3";
}
impl From<ValueType> for &'static str {
    /// Maps each value type to its attribute URI.
    fn from(value: ValueType) -> Self {
        match value {
            ValueType::X509v3 => ValueType::NS_X509V3,
            ValueType::Pkcs7 => ValueType::NS_PKCS7,
        }
    }
}
impl FromStr for ValueType {
    type Err = ValueTypeParseError;

    /// Parses an exact (case-sensitive) `ValueType` attribute URI.
    ///
    /// # Errors
    ///
    /// Returns [`ValueTypeParseError`] carrying the input when it is not a
    /// known URI.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == Self::NS_PKCS7 {
            Ok(Self::Pkcs7)
        } else if s == Self::NS_X509V3 {
            Ok(Self::X509v3)
        } else {
            Err(ValueTypeParseError(s.to_owned()))
        }
    }
}
impl Serialize for ValueType {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str((*self).into())
}
}
impl<'de> Deserialize<'de> for ValueType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
s.parse().map_err(serde::de::Error::custom)
}
}
/// Error for a string that is not a recognised `ValueType` attribute URI.
///
/// Holds the rejected input verbatim. Derives the same traits as the sibling
/// parse errors (`ActionTypeParseError`, `EncodingTypeParseError`), so callers
/// can clone it and compare it in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValueTypeParseError(String);
impl std::fmt::Display for ValueTypeParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{} is not a valid MS-WSTEP BinarySecurityToken ValueType attribute",
            self.0
        )
    }
}
impl std::error::Error for ValueTypeParseError {}
/// Encoding of a `BinarySecurityToken` payload (its `EncodingType` attribute).
#[non_exhaustive]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum EncodingType {
    /// Base64 text (`...#base64binary`).
    Base64Binary,
}
impl EncodingType {
    // Attribute URI for the only supported encoding; also a match pattern.
    const NS_BASE64_BINARY: &'static str =
        "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-secext-1.0.xsd#base64binary";
}
impl From<EncodingType> for &'static str {
    /// Maps each encoding to its attribute URI.
    fn from(value: EncodingType) -> Self {
        match value {
            EncodingType::Base64Binary => EncodingType::NS_BASE64_BINARY,
        }
    }
}
impl FromStr for EncodingType {
    type Err = EncodingTypeParseError;

    /// Parses an exact (case-sensitive) `EncodingType` attribute URI.
    ///
    /// # Errors
    ///
    /// Returns [`EncodingTypeParseError`] carrying the input when it is not a
    /// known URI.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == Self::NS_BASE64_BINARY {
            Ok(Self::Base64Binary)
        } else {
            Err(EncodingTypeParseError(s.to_owned()))
        }
    }
}
impl Serialize for EncodingType {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str((*self).into())
}
}
impl<'de> Deserialize<'de> for EncodingType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
s.parse().map_err(serde::de::Error::custom)
}
}
/// Error for a string that is not a recognised `EncodingType` attribute URI.
///
/// Holds the rejected input verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EncodingTypeParseError(String);
impl std::fmt::Display for EncodingTypeParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(input) = self;
        write!(
            f,
            "{input} is not a valid MS-WSTEP BinarySecurityToken EncodingType attribute"
        )
    }
}
impl std::error::Error for EncodingTypeParseError {}
/// `<RequestID>` element: either an id in the text content, or no content
/// with `xsi:nil` set.
///
/// NOTE(review): this type never emits an `xmlns:xsi` declaration, so the
/// `xsi` prefix presumably has to be declared on an ancestor element; confirm
/// at the call sites.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct RequestId<'a> {
    /// Text content; omitted when `None`.
    #[serde(rename = "$text", skip_serializing_if = "Option::is_none")]
    pub value: Option<Cow<'a, str>>,
    /// `xsi:nil` attribute (`nil` accepted on input); omitted when `None`.
    #[serde(
        rename = "@xsi:nil",
        alias = "@nil",
        skip_serializing_if = "Option::is_none"
    )]
    pub xsi_nil: Option<bool>,
    /// `xmlns` attribute; omitted when `None`.
    #[serde(rename = "@xmlns", skip_serializing_if = "Option::is_none")]
    pub xmlns: Option<Cow<'a, str>>,
}
impl<'a> RequestId<'a> {
    /// Enrollment namespace written into the `xmlns` attribute.
    const XMLNS: &'static str = "http://schemas.microsoft.com/windows/pki/2009/01/enrollment";

    /// Creates a `RequestID` that carries `id` if present, and otherwise
    /// has `xsi:nil="true"`. The namespace is always set.
    #[must_use]
    pub fn new(id: Option<&'a str>) -> Self {
        let xmlns = Some(Cow::Borrowed(Self::XMLNS));
        match id {
            Some(text) => Self {
                value: Some(Cow::Borrowed(text)),
                xsi_nil: None,
                xmlns,
            },
            None => Self {
                value: None,
                xsi_nil: Some(true),
                xmlns,
            },
        }
    }

    /// Shorthand for `RequestId::new(Some(id))`.
    #[must_use]
    pub fn new_with_id(id: &'a str) -> Self {
        Self::new(Some(id))
    }

    /// Shorthand for `RequestId::new(None)`.
    #[must_use]
    pub fn nil() -> Self {
        Self::new(None)
    }
}
/// `<TokenType>` element whose text content is a token-type URI.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
pub struct TokenType<'a> {
    /// Text content: the token-type URI.
    #[serde(rename = "$text")]
    pub value: Cow<'a, str>,
}
impl<'a> TokenType<'a> {
    /// URI for the X.509v3 token type.
    pub const X509V3_TOKEN_TYPE: &'static str =
        "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-x509-token-profile-1.0#X509v3";

    /// A `TokenType` holding [`Self::X509V3_TOKEN_TYPE`].
    #[must_use]
    pub fn x509v3() -> Self {
        Self::other(Self::X509V3_TOKEN_TYPE)
    }

    /// A `TokenType` holding an arbitrary, caller-supplied URI.
    #[must_use]
    pub fn other(value: impl Into<Cow<'a, str>>) -> Self {
        let value = value.into();
        Self { value }
    }
}
#[cfg(test)]
pub(crate) mod common_serde_tests {
    use crate::common::BinarySecurityToken;
    use super::{
        serde_test_utils::{serde_test, serde_test_with_root},
        RequestId, TokenType,
    };

    /// Round-trips both a populated and a nil `RequestID`.
    #[test]
    fn test_serde_request_id() {
        let cases = [
            (
                r#"<RequestID xmlns="http://schemas.microsoft.com/windows/pki/2009/01/enrollment">61</RequestID>"#,
                RequestId::new(Some("61")),
            ),
            (
                r#"<RequestID xsi:nil="true" xmlns="http://schemas.microsoft.com/windows/pki/2009/01/enrollment"/>"#,
                RequestId::new(None),
            ),
        ];
        for (xml, expected) in cases {
            serde_test_with_root(xml, expected, "RequestID");
        }
    }

    /// Round-trips the X.509v3 `TokenType`.
    #[test]
    fn test_serde_token_type() {
        serde_test(
            "<TokenType>http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-x509-token-profile-1.0#X509v3</TokenType>",
            TokenType::x509v3(),
        );
    }

    /// Round-trips a base64 PKCS#7 `BinarySecurityToken` built from a fixture.
    #[test]
    fn test_serde_binary_security_token() {
        let cms = include_str!("../tests/data/standard_certificate_client_request.cms");
        let expected_xml = format!(
            r#"<BinarySecurityToken EncodingType="http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-secext-1.0.xsd#base64binary" ValueType="http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-secext-1.0.xsd#PKCS7" xmlns="http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-secext-1.0.xsd">{cms}</BinarySecurityToken>"#
        );
        let token = BinarySecurityToken::new_pkcs7_base64(cms);
        serde_test(&expected_xml, token);
    }
}
#[cfg(test)]
pub(crate) mod serde_test_utils {
    use std::fmt::Debug;
    use pretty_assertions::assert_eq;
    use serde::{Deserialize, Serialize};

    /// Asserts that serializing `value` (no explicit root tag) produces
    /// exactly `expected_serialized`.
    pub fn se_test<T>(expected_serialized: &str, value: &T)
    where
        T: Serialize + PartialEq + Eq + Debug,
    {
        se_test_inner(value, expected_serialized, None);
    }

    /// Asserts that `value` serializes to `serialized` and that `serialized`
    /// deserializes back to `value`.
    pub fn serde_test<'de, T>(serialized: &'de str, value: T)
    where
        T: Serialize + Deserialize<'de> + PartialEq + Eq + Debug,
    {
        serde_test_inner(serialized, value, None);
    }

    /// Same as [`serde_test`], but serializes under the element name `root_tag`.
    pub fn serde_test_with_root<'de, T>(serialized: &'de str, value: T, root_tag: &'de str)
    where
        T: Serialize + Deserialize<'de> + PartialEq + Eq + Debug,
    {
        serde_test_inner(serialized, value, Some(root_tag));
    }

    /// Runs the serialization check, then the deserialization check.
    fn serde_test_inner<'de, T>(serialized: &'de str, value: T, se_with_root: Option<&'de str>)
    where
        T: Serialize + Deserialize<'de> + PartialEq + Eq + Debug,
    {
        se_test_inner(&value, serialized, se_with_root);
        de_test_inner(serialized, value);
    }

    /// Serializes `value` (with `se_with_root` as the root element when given)
    /// and asserts it equals `expected_serialized`. Panics if serialization fails.
    fn se_test_inner<T>(value: &T, expected_serialized: &str, se_with_root: Option<&str>)
    where
        T: Serialize + PartialEq + Eq + Debug,
    {
        let rendered = match se_with_root {
            Some(root_tag) => quick_xml::se::to_string_with_root(root_tag, &value),
            None => quick_xml::se::to_string(&value),
        };
        let actual_serialized = rendered.unwrap();
        assert_eq!(expected_serialized, actual_serialized);
    }

    /// Deserializes `serialized` and asserts it equals `value`. Panics if
    /// deserialization fails.
    fn de_test_inner<'de, T>(serialized: &'de str, value: T)
    where
        T: Serialize + Deserialize<'de> + PartialEq + Eq + Debug,
    {
        let parsed: T = quick_xml::de::from_str(serialized).unwrap();
        assert_eq!(value, parsed);
    }
}