use crate::errors::{AuthError, Result};
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
/// In-memory model of a SAML assertion; `to_xml` renders it as a `<saml:Assertion>` element.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamlAssertion {
/// Written to the `ID` attribute of the assertion element.
pub id: String,
/// Written as the content of the `<saml:Issuer>` child element.
pub issuer: String,
/// Written to the `IssueInstant` attribute.
pub issue_instant: DateTime<Utc>,
/// Written to the `Version` attribute.
pub version: String,
/// Optional `<saml:Subject>` content.
pub subject: Option<SamlSubject>,
/// Optional `<saml:Conditions>` content (validity window, audiences, ...).
pub conditions: Option<SamlConditions>,
/// Attribute statements, serialized in order.
pub attribute_statements: Vec<SamlAttributeStatement>,
/// Authentication statements, serialized in order after the attribute statements.
pub authn_statements: Vec<SamlAuthnStatement>,
/// Authorization decision statements, serialized last.
pub authz_decision_statements: Vec<SamlAuthzDecisionStatement>,
}
/// Contents of a `<saml:Subject>` element.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamlSubject {
/// Optional name identifier, serialized first.
pub name_id: Option<SamlNameId>,
/// Subject confirmations, serialized in order after the name identifier.
pub subject_confirmations: Vec<SamlSubjectConfirmation>,
}
/// Contents of a `<saml:NameID>` element.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamlNameId {
/// Identifier written as the element's text content.
pub value: String,
/// Written to the `Format` attribute when present.
pub format: Option<String>,
/// Written to the `NameQualifier` attribute when present.
pub name_qualifier: Option<String>,
/// Written to the `SPNameQualifier` attribute when present.
pub sp_name_qualifier: Option<String>,
}
/// Contents of a `<saml:SubjectConfirmation>` element.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamlSubjectConfirmation {
/// Written to the `Method` attribute.
pub method: String,
/// Optional `<saml:SubjectConfirmationData>` child.
pub subject_confirmation_data: Option<SamlSubjectConfirmationData>,
}
/// Attributes of a `<saml:SubjectConfirmationData>` element; each field is emitted only when set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamlSubjectConfirmationData {
/// Written to the `NotBefore` attribute.
pub not_before: Option<DateTime<Utc>>,
/// Written to the `NotOnOrAfter` attribute.
pub not_on_or_after: Option<DateTime<Utc>>,
/// Written to the `Recipient` attribute.
pub recipient: Option<String>,
/// Written to the `InResponseTo` attribute.
pub in_response_to: Option<String>,
/// Written to the `Address` attribute.
pub address: Option<String>,
}
/// Contents of a `<saml:Conditions>` element.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamlConditions {
/// Start of the validity window; written to `NotBefore` and checked by the validator.
pub not_before: Option<DateTime<Utc>>,
/// Exclusive end of the validity window; written to `NotOnOrAfter` and checked by the validator.
pub not_on_or_after: Option<DateTime<Utc>>,
/// Audience restrictions, serialized in order and checked by the validator.
pub audience_restrictions: Vec<SamlAudienceRestriction>,
/// When `true`, an empty `<saml:OneTimeUse/>` element is emitted.
pub one_time_use: bool,
/// Optional `<saml:ProxyRestriction>` child.
pub proxy_restriction: Option<SamlProxyRestriction>,
}
/// Contents of a `<saml:AudienceRestriction>` element.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamlAudienceRestriction {
/// Audience identifiers, each written as one `<saml:Audience>` child.
pub audiences: Vec<String>,
}
/// Contents of a `<saml:ProxyRestriction>` element.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamlProxyRestriction {
/// Written to the `Count` attribute when present.
pub count: Option<u32>,
/// Audience identifiers, each written as one `<saml:Audience>` child.
pub audiences: Vec<String>,
}
/// Contents of a `<saml:AttributeStatement>` element.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamlAttributeStatement {
/// Plain attributes, serialized in order.
pub attributes: Vec<SamlAttribute>,
/// Strings appended verbatim after the plain attributes
/// (presumably pre-serialized `EncryptedAttribute` XML; confirm with producers).
pub encrypted_attributes: Vec<String>, }
/// Contents of a `<saml:Attribute>` element.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamlAttribute {
/// Written to the `Name` attribute.
pub name: String,
/// Written to the `NameFormat` attribute when present.
pub name_format: Option<String>,
/// Written to the `FriendlyName` attribute when present.
pub friendly_name: Option<String>,
/// Values, each serialized as one `<saml:AttributeValue>` child.
pub values: Vec<SamlAttributeValue>,
}
/// Contents of a `<saml:AttributeValue>` element.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamlAttributeValue {
/// Written as the element's text content.
pub value: String,
/// Written to the `xsi:type` attribute when present.
pub type_info: Option<String>,
}
/// Contents of a `<saml:AuthnStatement>` element.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamlAuthnStatement {
/// Written to the `AuthnInstant` attribute.
pub authn_instant: DateTime<Utc>,
/// Written to the `SessionIndex` attribute when present.
pub session_index: Option<String>,
/// Written to the `SessionNotOnOrAfter` attribute when present.
pub session_not_on_or_after: Option<DateTime<Utc>>,
/// Mandatory `<saml:AuthnContext>` child, serialized after the subject locality.
pub authn_context: SamlAuthnContext,
/// Optional `<saml:SubjectLocality>` child, serialized first.
pub subject_locality: Option<SamlSubjectLocality>,
}
/// Contents of a `<saml:AuthnContext>` element.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamlAuthnContext {
/// Written as the content of `<saml:AuthnContextClassRef>` when present.
pub authn_context_class_ref: Option<String>,
/// Written as the content of `<saml:AuthnContextDecl>` when present.
pub authn_context_decl: Option<String>,
/// Written as the content of `<saml:AuthnContextDeclRef>` when present.
pub authn_context_decl_ref: Option<String>,
/// Each entry is written as one `<saml:AuthenticatingAuthority>` child.
pub authenticating_authorities: Vec<String>,
}
/// Attributes of a `<saml:SubjectLocality>` element.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamlSubjectLocality {
/// Written to the `Address` attribute when present.
pub address: Option<String>,
/// Written to the `DNSName` attribute when present.
pub dns_name: Option<String>,
}
/// Contents of a `<saml:AuthzDecisionStatement>` element.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamlAuthzDecisionStatement {
/// Written to the `Resource` attribute.
pub resource: String,
/// Written to the `Decision` attribute.
pub decision: SamlDecision,
/// Actions, each serialized as one `<saml:Action>` child.
pub actions: Vec<SamlAction>,
/// Optional `<saml:Evidence>` child, serialized after the actions.
pub evidence: Option<SamlEvidence>,
}
/// Value of the `Decision` attribute on an authorization decision statement.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SamlDecision {
/// Serialized as `Permit`.
Permit,
/// Serialized as `Deny`.
Deny,
/// Serialized as `Indeterminate`.
Indeterminate,
}
/// Contents of a `<saml:Action>` element.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamlAction {
/// Written as the element's text content.
pub value: String,
/// Written to the `Namespace` attribute when present.
pub namespace: Option<String>,
}
/// Contents of a `<saml:Evidence>` element.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamlEvidence {
/// Strings appended verbatim as children
/// (presumably pre-serialized assertion XML; confirm with producers).
pub assertions: Vec<String>,
/// Each entry is written as one `<saml:AssertionIDRef>` child.
pub assertion_id_refs: Vec<String>,
/// Each entry is written as one `<saml:AssertionURIRef>` child.
pub assertion_uri_refs: Vec<String>,
}
/// Fluent builder that accumulates the parts of a [`SamlAssertion`].
pub struct SamlAssertionBuilder {
// Assertion under construction; handed out by `build` / serialized by `build_xml`.
assertion: SamlAssertion,
}
/// Checks a [`SamlAssertion`] against issuer, timing and audience expectations.
pub struct SamlAssertionValidator {
// Tolerance applied to every timestamp comparison.
clock_skew: Duration,
// Accepted issuers; an empty list disables the issuer check.
trusted_issuers: Vec<String>,
// Audiences this party answers to; an empty list disables the audience check.
expected_audiences: Vec<String>,
}
impl SamlAssertionBuilder {
    /// Starts a builder for an assertion issued by `issuer`.
    ///
    /// The assertion gets a `_`-prefixed random UUID as its ID, the current UTC time as its
    /// issue instant, version `"2.0"`, and no subject, conditions or statements.
    pub fn new(issuer: &str) -> Self {
        Self {
            assertion: SamlAssertion {
                id: format!("_{}", uuid::Uuid::new_v4()),
                issuer: issuer.to_owned(),
                issue_instant: Utc::now(),
                version: String::from("2.0"),
                subject: None,
                conditions: None,
                attribute_statements: vec![],
                authn_statements: vec![],
                authz_decision_statements: vec![],
            },
        }
    }

    /// Sets the subject, replacing any previous one.
    pub fn with_subject(self, subject: SamlSubject) -> Self {
        let mut assertion = self.assertion;
        assertion.subject = Some(subject);
        Self { assertion }
    }

    /// Sets the conditions, replacing any previous ones (including those created implicitly
    /// by `with_validity_period` / `with_audience`).
    pub fn with_conditions(self, conditions: SamlConditions) -> Self {
        let mut assertion = self.assertion;
        assertion.conditions = Some(conditions);
        Self { assertion }
    }

    /// Appends an attribute statement.
    pub fn with_attribute_statement(self, statement: SamlAttributeStatement) -> Self {
        let mut assertion = self.assertion;
        assertion.attribute_statements.push(statement);
        Self { assertion }
    }

    /// Appends an authentication statement.
    pub fn with_authn_statement(self, statement: SamlAuthnStatement) -> Self {
        let mut assertion = self.assertion;
        assertion.authn_statements.push(statement);
        Self { assertion }
    }

    /// Appends an authorization decision statement.
    pub fn with_authz_decision_statement(self, statement: SamlAuthzDecisionStatement) -> Self {
        let mut assertion = self.assertion;
        assertion.authz_decision_statements.push(statement);
        Self { assertion }
    }

    /// Adds a single-valued attribute using the SAML "basic" name format.
    ///
    /// The attribute goes into the first attribute statement; one is created if none exists.
    pub fn with_attribute(mut self, name: &str, value: &str) -> Self {
        let single_value = SamlAttributeValue {
            value: value.to_owned(),
            type_info: None,
        };
        let attribute = SamlAttribute {
            name: name.to_owned(),
            name_format: Some(String::from(
                "urn:oasis:names:tc:SAML:2.0:attrname-format:basic",
            )),
            friendly_name: None,
            values: vec![single_value],
        };

        let statements = &mut self.assertion.attribute_statements;
        match statements.first_mut() {
            Some(first) => first.attributes.push(attribute),
            None => statements.push(SamlAttributeStatement {
                attributes: vec![attribute],
                encrypted_attributes: vec![],
            }),
        }
        self
    }

    /// Sets the `NotBefore` / `NotOnOrAfter` window, creating empty conditions when needed.
    pub fn with_validity_period(
        mut self,
        not_before: DateTime<Utc>,
        not_on_or_after: DateTime<Utc>,
    ) -> Self {
        let conditions = self
            .assertion
            .conditions
            .get_or_insert_with(|| SamlConditions {
                not_before: None,
                not_on_or_after: None,
                audience_restrictions: vec![],
                one_time_use: false,
                proxy_restriction: None,
            });
        conditions.not_before = Some(not_before);
        conditions.not_on_or_after = Some(not_on_or_after);
        self
    }

    /// Adds `audience` to the first audience restriction, creating the conditions and/or the
    /// restriction when they do not exist yet.
    pub fn with_audience(mut self, audience: &str) -> Self {
        let conditions = self
            .assertion
            .conditions
            .get_or_insert_with(|| SamlConditions {
                not_before: None,
                not_on_or_after: None,
                audience_restrictions: vec![],
                one_time_use: false,
                proxy_restriction: None,
            });

        match conditions.audience_restrictions.first_mut() {
            Some(first) => first.audiences.push(audience.to_owned()),
            None => conditions
                .audience_restrictions
                .push(SamlAudienceRestriction {
                    audiences: vec![audience.to_owned()],
                }),
        }
        self
    }

    /// Finishes the builder and returns the assertion.
    pub fn build(self) -> SamlAssertion {
        let Self { assertion } = self;
        assertion
    }

    /// Finishes the builder and returns the assertion serialized by `SamlAssertion::to_xml`.
    pub fn build_xml(self) -> Result<String> {
        self.assertion.to_xml()
    }
}
impl SamlAssertionValidator {
    /// Creates a validator with a five-minute clock skew and no issuer or audience
    /// expectations (both checks are skipped until configured).
    pub fn new() -> Self {
        Self {
            clock_skew: Duration::minutes(5),
            trusted_issuers: Vec::new(),
            expected_audiences: Vec::new(),
        }
    }

    /// Replaces the tolerance applied to timestamp comparisons.
    pub fn with_clock_skew(mut self, skew: Duration) -> Self {
        self.clock_skew = skew;
        self
    }

    /// Adds an issuer to the trusted list; once the list is non-empty, any other issuer is
    /// rejected.
    pub fn with_trusted_issuer(mut self, issuer: &str) -> Self {
        self.trusted_issuers.push(issuer.to_string());
        self
    }

    /// Adds an audience this party identifies as; once the list is non-empty, audience
    /// restrictions in assertions are enforced.
    pub fn with_expected_audience(mut self, audience: &str) -> Self {
        self.expected_audiences.push(audience.to_string());
        self
    }

    /// Validates issuer, timing, audience and (if a subject is present) subject confirmation.
    ///
    /// # Errors
    /// Returns `AuthError::auth_method("saml", ...)` describing the first failed check.
    pub fn validate(&self, assertion: &SamlAssertion) -> Result<()> {
        if !self.trusted_issuers.is_empty() && !self.trusted_issuers.contains(&assertion.issuer) {
            return Err(AuthError::auth_method("saml", "Untrusted issuer"));
        }
        self.validate_timing(assertion)?;
        self.validate_audience(assertion)?;
        if let Some(subject) = &assertion.subject {
            self.validate_subject_confirmation(subject)?;
        }
        Ok(())
    }

    /// Checks the issue instant and the conditions' validity window against the current time,
    /// widening every bound by the configured clock skew.
    fn validate_timing(&self, assertion: &SamlAssertion) -> Result<()> {
        let now = Utc::now();

        // Timestamps come from the assertion, so `timestamp ± skew` can leave chrono's
        // representable range; the plain operators would panic there. With checked
        // arithmetic, an unrepresentable bound behaves like ±infinity: for a non-negative
        // skew no check can fail, for a negative skew every check fails.
        let when_out_of_range = self.clock_skew < Duration::zero();

        let issued_in_future = now
            .checked_add_signed(self.clock_skew)
            .map_or(when_out_of_range, |latest| assertion.issue_instant > latest);
        if issued_in_future {
            return Err(AuthError::auth_method(
                "saml",
                "Assertion issued in the future",
            ));
        }

        if let Some(conditions) = &assertion.conditions {
            if let Some(not_before) = conditions.not_before {
                let not_yet_valid = not_before
                    .checked_sub_signed(self.clock_skew)
                    .map_or(when_out_of_range, |earliest| now < earliest);
                if not_yet_valid {
                    return Err(AuthError::auth_method("saml", "Assertion not yet valid"));
                }
            }
            if let Some(not_on_or_after) = conditions.not_on_or_after {
                let expired = not_on_or_after
                    .checked_add_signed(self.clock_skew)
                    .map_or(when_out_of_range, |cutoff| now >= cutoff);
                if expired {
                    return Err(AuthError::auth_method("saml", "Assertion has expired"));
                }
            }
        }
        Ok(())
    }

    /// Enforces audience restrictions when expected audiences are configured.
    ///
    /// Every `AudienceRestriction` must name at least one expected audience: SAML combines
    /// multiple restrictions with AND and the audiences inside one restriction with OR.
    /// Assertions without conditions or without any restriction are accepted, because an
    /// absent restriction places no limit on the audience.
    fn validate_audience(&self, assertion: &SamlAssertion) -> Result<()> {
        if self.expected_audiences.is_empty() {
            return Ok(());
        }
        let Some(conditions) = assertion.conditions.as_ref() else {
            return Ok(());
        };

        let every_restriction_satisfied =
            conditions.audience_restrictions.iter().all(|restriction| {
                restriction
                    .audiences
                    .iter()
                    .any(|audience| self.expected_audiences.contains(audience))
            });

        if every_restriction_satisfied {
            Ok(())
        } else {
            Err(AuthError::auth_method("saml", "No matching audience found"))
        }
    }

    /// Placeholder: currently accepts every subject without inspecting it.
    ///
    /// NOTE(review): confirmation method, `Recipient`, `InResponseTo` and the confirmation
    /// data's time window are not checked here; callers needing those guarantees must enforce
    /// them separately.
    fn validate_subject_confirmation(&self, _subject: &SamlSubject) -> Result<()> {
        Ok(())
    }
}
impl SamlAssertion {
    /// Serializes the assertion as a `<saml:Assertion>` element.
    ///
    /// Children are emitted in this order: issuer, subject, conditions, attribute statements,
    /// authentication statements, authorization decision statements. `id`, `version` and
    /// `issuer` are XML-escaped so that markup characters in them cannot alter the document
    /// structure.
    ///
    /// # Errors
    /// Propagates any error returned by a nested `to_xml` call.
    pub fn to_xml(&self) -> Result<String> {
        /// Replaces the five XML-reserved characters with their predefined entities.
        fn escape(raw: &str) -> String {
            // `&` must be handled first so the entities produced below are not re-escaped.
            raw.replace('&', "&amp;")
                .replace('<', "&lt;")
                .replace('>', "&gt;")
                .replace('"', "&quot;")
                .replace('\'', "&apos;")
        }

        let mut xml = format!(
            r#"<saml:Assertion xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion" ID="{}" IssueInstant="{}" Version="{}">"#,
            escape(&self.id),
            self.issue_instant.format("%Y-%m-%dT%H:%M:%S%.3fZ"),
            escape(&self.version)
        );
        xml.push_str(&format!(
            "<saml:Issuer>{}</saml:Issuer>",
            escape(&self.issuer)
        ));

        if let Some(subject) = &self.subject {
            xml.push_str(&subject.to_xml()?);
        }
        if let Some(conditions) = &self.conditions {
            xml.push_str(&conditions.to_xml()?);
        }
        for statement in &self.attribute_statements {
            xml.push_str(&statement.to_xml()?);
        }
        for statement in &self.authn_statements {
            xml.push_str(&statement.to_xml()?);
        }
        for statement in &self.authz_decision_statements {
            xml.push_str(&statement.to_xml()?);
        }

        xml.push_str("</saml:Assertion>");
        Ok(xml)
    }
}
impl SamlSubject {
    /// Renders a `<saml:Subject>` element: the name identifier (if any) followed by every
    /// subject confirmation in order.
    ///
    /// # Errors
    /// Propagates any error returned by a nested `to_xml` call.
    pub fn to_xml(&self) -> Result<String> {
        let mut out = String::from("<saml:Subject>");
        if let Some(identifier) = self.name_id.as_ref() {
            out += &identifier.to_xml()?;
        }
        for entry in self.subject_confirmations.iter() {
            out += &entry.to_xml()?;
        }
        out += "</saml:Subject>";
        Ok(out)
    }
}
impl SamlNameId {
    /// Serializes the identifier as a `<saml:NameID>` element.
    ///
    /// `Format`, `NameQualifier` and `SPNameQualifier` are emitted only when set. The
    /// attribute values and the identifier text are XML-escaped, so a name such as
    /// `a<b>&c` cannot inject markup into the assertion.
    ///
    /// # Errors
    /// Never fails at present; the `Result` keeps the signature uniform with the other
    /// serializers.
    pub fn to_xml(&self) -> Result<String> {
        /// Replaces the five XML-reserved characters with their predefined entities.
        fn escape(raw: &str) -> String {
            // `&` must be handled first so the entities produced below are not re-escaped.
            raw.replace('&', "&amp;")
                .replace('<', "&lt;")
                .replace('>', "&gt;")
                .replace('"', "&quot;")
                .replace('\'', "&apos;")
        }

        let mut xml = String::from("<saml:NameID");
        if let Some(format) = &self.format {
            xml.push_str(&format!(" Format=\"{}\"", escape(format)));
        }
        if let Some(name_qualifier) = &self.name_qualifier {
            xml.push_str(&format!(" NameQualifier=\"{}\"", escape(name_qualifier)));
        }
        if let Some(sp_name_qualifier) = &self.sp_name_qualifier {
            xml.push_str(&format!(
                " SPNameQualifier=\"{}\"",
                escape(sp_name_qualifier)
            ));
        }
        xml.push_str(&format!(">{}</saml:NameID>", escape(&self.value)));
        Ok(xml)
    }
}
impl SamlSubjectConfirmation {
    /// Serializes this confirmation as a `<saml:SubjectConfirmation>` element.
    ///
    /// `method` is XML-escaped before it is written to the `Method` attribute so that a
    /// quote or markup character cannot break out of the attribute. The optional
    /// confirmation data follows as a child element.
    ///
    /// # Errors
    /// Propagates any error returned by the nested `to_xml` call.
    pub fn to_xml(&self) -> Result<String> {
        /// Replaces the five XML-reserved characters with their predefined entities.
        fn escape(raw: &str) -> String {
            // `&` must be handled first so the entities produced below are not re-escaped.
            raw.replace('&', "&amp;")
                .replace('<', "&lt;")
                .replace('>', "&gt;")
                .replace('"', "&quot;")
                .replace('\'', "&apos;")
        }

        let mut xml = format!(
            "<saml:SubjectConfirmation Method=\"{}\">",
            escape(&self.method)
        );
        if let Some(data) = &self.subject_confirmation_data {
            xml.push_str(&data.to_xml()?);
        }
        xml.push_str("</saml:SubjectConfirmation>");
        Ok(xml)
    }
}
impl SamlSubjectConfirmationData {
    /// Serializes the data as a self-closing `<saml:SubjectConfirmationData/>` element.
    ///
    /// Each attribute is emitted only when its field is set. Timestamps use millisecond
    /// precision with a literal `Z` suffix. `Recipient`, `InResponseTo` and `Address` are
    /// XML-escaped; recipient URLs in particular may legitimately contain `&`.
    ///
    /// # Errors
    /// Never fails at present; the `Result` keeps the signature uniform with the other
    /// serializers.
    pub fn to_xml(&self) -> Result<String> {
        /// Replaces the five XML-reserved characters with their predefined entities.
        fn escape(raw: &str) -> String {
            // `&` must be handled first so the entities produced below are not re-escaped.
            raw.replace('&', "&amp;")
                .replace('<', "&lt;")
                .replace('>', "&gt;")
                .replace('"', "&quot;")
                .replace('\'', "&apos;")
        }

        let mut xml = String::from("<saml:SubjectConfirmationData");
        if let Some(not_before) = self.not_before {
            xml.push_str(&format!(
                " NotBefore=\"{}\"",
                not_before.format("%Y-%m-%dT%H:%M:%S%.3fZ")
            ));
        }
        if let Some(not_on_or_after) = self.not_on_or_after {
            xml.push_str(&format!(
                " NotOnOrAfter=\"{}\"",
                not_on_or_after.format("%Y-%m-%dT%H:%M:%S%.3fZ")
            ));
        }
        if let Some(recipient) = &self.recipient {
            xml.push_str(&format!(" Recipient=\"{}\"", escape(recipient)));
        }
        if let Some(in_response_to) = &self.in_response_to {
            xml.push_str(&format!(" InResponseTo=\"{}\"", escape(in_response_to)));
        }
        if let Some(address) = &self.address {
            xml.push_str(&format!(" Address=\"{}\"", escape(address)));
        }
        xml.push_str("/>");
        Ok(xml)
    }
}
impl SamlConditions {
    /// Renders a `<saml:Conditions>` element.
    ///
    /// The optional `NotBefore` / `NotOnOrAfter` attributes come first, followed by the
    /// audience restrictions, an empty `<saml:OneTimeUse/>` marker when `one_time_use` is
    /// set, and finally the proxy restriction if present.
    ///
    /// # Errors
    /// Propagates any error returned by a nested `to_xml` call.
    pub fn to_xml(&self) -> Result<String> {
        // Millisecond-precision timestamp layout shared by both attributes.
        const TIMESTAMP: &str = "%Y-%m-%dT%H:%M:%S%.3fZ";

        let mut out = String::from("<saml:Conditions");
        if let Some(start) = self.not_before {
            out += &format!(" NotBefore=\"{}\"", start.format(TIMESTAMP));
        }
        if let Some(end) = self.not_on_or_after {
            out += &format!(" NotOnOrAfter=\"{}\"", end.format(TIMESTAMP));
        }
        out.push('>');

        for item in self.audience_restrictions.iter() {
            out += &item.to_xml()?;
        }
        if self.one_time_use {
            out += "<saml:OneTimeUse/>";
        }
        if let Some(proxy) = self.proxy_restriction.as_ref() {
            out += &proxy.to_xml()?;
        }

        out += "</saml:Conditions>";
        Ok(out)
    }
}
impl SamlAudienceRestriction {
    /// Serializes the restriction as `<saml:AudienceRestriction>` with one `<saml:Audience>`
    /// child per entry.
    ///
    /// Audience identifiers are XML-escaped; they are usually URIs and may contain `&`,
    /// which would otherwise make the document malformed.
    ///
    /// # Errors
    /// Never fails at present; the `Result` keeps the signature uniform with the other
    /// serializers.
    pub fn to_xml(&self) -> Result<String> {
        /// Replaces the five XML-reserved characters with their predefined entities.
        fn escape(raw: &str) -> String {
            // `&` must be handled first so the entities produced below are not re-escaped.
            raw.replace('&', "&amp;")
                .replace('<', "&lt;")
                .replace('>', "&gt;")
                .replace('"', "&quot;")
                .replace('\'', "&apos;")
        }

        let mut xml = String::from("<saml:AudienceRestriction>");
        for audience in &self.audiences {
            xml.push_str(&format!(
                "<saml:Audience>{}</saml:Audience>",
                escape(audience)
            ));
        }
        xml.push_str("</saml:AudienceRestriction>");
        Ok(xml)
    }
}
impl SamlProxyRestriction {
    /// Serializes the restriction as `<saml:ProxyRestriction>`, with the optional `Count`
    /// attribute and one `<saml:Audience>` child per entry.
    ///
    /// Audience identifiers are XML-escaped so that `&`, `<` and similar characters cannot
    /// corrupt the document.
    ///
    /// # Errors
    /// Never fails at present; the `Result` keeps the signature uniform with the other
    /// serializers.
    pub fn to_xml(&self) -> Result<String> {
        /// Replaces the five XML-reserved characters with their predefined entities.
        fn escape(raw: &str) -> String {
            // `&` must be handled first so the entities produced below are not re-escaped.
            raw.replace('&', "&amp;")
                .replace('<', "&lt;")
                .replace('>', "&gt;")
                .replace('"', "&quot;")
                .replace('\'', "&apos;")
        }

        let mut xml = String::from("<saml:ProxyRestriction");
        if let Some(count) = self.count {
            xml.push_str(&format!(" Count=\"{}\"", count));
        }
        xml.push('>');
        for audience in &self.audiences {
            xml.push_str(&format!(
                "<saml:Audience>{}</saml:Audience>",
                escape(audience)
            ));
        }
        xml.push_str("</saml:ProxyRestriction>");
        Ok(xml)
    }
}
impl SamlAttributeStatement {
    /// Renders a `<saml:AttributeStatement>` element: every plain attribute in order, then
    /// every `encrypted_attributes` string appended exactly as stored.
    ///
    /// # Errors
    /// Propagates any error returned by a nested `to_xml` call.
    pub fn to_xml(&self) -> Result<String> {
        let mut out = String::from("<saml:AttributeStatement>");
        for plain in self.attributes.iter() {
            out += &plain.to_xml()?;
        }
        // Encrypted entries are inserted verbatim, with no wrapping or transformation.
        for raw in self.encrypted_attributes.iter() {
            out += raw;
        }
        out += "</saml:AttributeStatement>";
        Ok(out)
    }
}
impl SamlAttribute {
    /// Serializes the attribute as a `<saml:Attribute>` element with one child per value.
    ///
    /// The start tag is assembled attribute by attribute (`Name`, then the optional
    /// `NameFormat` and `FriendlyName`) and closed once. The tag is never patched by
    /// searching for `>`, because that would also rewrite any `>` inside an attribute
    /// value. All attribute values are XML-escaped.
    ///
    /// # Errors
    /// Propagates any error returned by a value's `to_xml`.
    pub fn to_xml(&self) -> Result<String> {
        /// Replaces the five XML-reserved characters with their predefined entities.
        fn escape(raw: &str) -> String {
            // `&` must be handled first so the entities produced below are not re-escaped.
            raw.replace('&', "&amp;")
                .replace('<', "&lt;")
                .replace('>', "&gt;")
                .replace('"', "&quot;")
                .replace('\'', "&apos;")
        }

        let mut xml = format!("<saml:Attribute Name=\"{}\"", escape(&self.name));
        if let Some(name_format) = &self.name_format {
            xml.push_str(&format!(" NameFormat=\"{}\"", escape(name_format)));
        }
        if let Some(friendly_name) = &self.friendly_name {
            xml.push_str(&format!(" FriendlyName=\"{}\"", escape(friendly_name)));
        }
        xml.push('>');

        for value in &self.values {
            xml.push_str(&value.to_xml()?);
        }
        xml.push_str("</saml:Attribute>");
        Ok(xml)
    }
}
impl SamlAttributeValue {
    /// Serializes the value as a `<saml:AttributeValue>` element.
    ///
    /// The text content is XML-escaped, since attribute values commonly carry user-supplied
    /// data. When `type_info` is set it is written to `xsi:type`, and the `xsi` namespace is
    /// declared on the same element so the prefix is bound; without the declaration the
    /// output is not namespace-well-formed.
    ///
    /// NOTE(review): a prefix used inside `type_info` itself (for example `xs:` in
    /// `xs:string`) is not declared here.
    ///
    /// # Errors
    /// Never fails at present; the `Result` keeps the signature uniform with the other
    /// serializers.
    pub fn to_xml(&self) -> Result<String> {
        /// Replaces the five XML-reserved characters with their predefined entities.
        fn escape(raw: &str) -> String {
            // `&` must be handled first so the entities produced below are not re-escaped.
            raw.replace('&', "&amp;")
                .replace('<', "&lt;")
                .replace('>', "&gt;")
                .replace('"', "&quot;")
                .replace('\'', "&apos;")
        }

        let mut xml = String::from("<saml:AttributeValue");
        if let Some(type_info) = &self.type_info {
            xml.push_str(&format!(
                " xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\" xsi:type=\"{}\"",
                escape(type_info)
            ));
        }
        xml.push_str(&format!(">{}</saml:AttributeValue>", escape(&self.value)));
        Ok(xml)
    }
}
impl SamlAuthnStatement {
    /// Serializes the statement as a `<saml:AuthnStatement>` element.
    ///
    /// `AuthnInstant` is always written; `SessionIndex` and `SessionNotOnOrAfter` only when
    /// set. The subject locality (if any) precedes the authentication context. The session
    /// index is XML-escaped so that it cannot break out of its attribute.
    ///
    /// # Errors
    /// Propagates any error returned by a nested `to_xml` call.
    pub fn to_xml(&self) -> Result<String> {
        /// Replaces the five XML-reserved characters with their predefined entities.
        fn escape(raw: &str) -> String {
            // `&` must be handled first so the entities produced below are not re-escaped.
            raw.replace('&', "&amp;")
                .replace('<', "&lt;")
                .replace('>', "&gt;")
                .replace('"', "&quot;")
                .replace('\'', "&apos;")
        }

        let mut xml = format!(
            "<saml:AuthnStatement AuthnInstant=\"{}\"",
            self.authn_instant.format("%Y-%m-%dT%H:%M:%S%.3fZ")
        );
        if let Some(session_index) = &self.session_index {
            xml.push_str(&format!(" SessionIndex=\"{}\"", escape(session_index)));
        }
        if let Some(session_not_on_or_after) = self.session_not_on_or_after {
            xml.push_str(&format!(
                " SessionNotOnOrAfter=\"{}\"",
                session_not_on_or_after.format("%Y-%m-%dT%H:%M:%S%.3fZ")
            ));
        }
        xml.push('>');

        if let Some(locality) = &self.subject_locality {
            xml.push_str(&locality.to_xml()?);
        }
        xml.push_str(&self.authn_context.to_xml()?);
        xml.push_str("</saml:AuthnStatement>");
        Ok(xml)
    }
}
impl SamlAuthnContext {
    /// Serializes the context as a `<saml:AuthnContext>` element.
    ///
    /// Children are emitted in this order, each only when present: class reference,
    /// declaration, declaration reference, authenticating authorities. The class reference,
    /// declaration reference and authorities are XML-escaped text.
    ///
    /// NOTE(review): `authn_context_decl` is written verbatim, on the assumption that it
    /// holds pre-serialized XML content; confirm with producers of this field.
    ///
    /// # Errors
    /// Never fails at present; the `Result` keeps the signature uniform with the other
    /// serializers.
    pub fn to_xml(&self) -> Result<String> {
        /// Replaces the five XML-reserved characters with their predefined entities.
        fn escape(raw: &str) -> String {
            // `&` must be handled first so the entities produced below are not re-escaped.
            raw.replace('&', "&amp;")
                .replace('<', "&lt;")
                .replace('>', "&gt;")
                .replace('"', "&quot;")
                .replace('\'', "&apos;")
        }

        let mut xml = String::from("<saml:AuthnContext>");
        if let Some(class_ref) = &self.authn_context_class_ref {
            xml.push_str(&format!(
                "<saml:AuthnContextClassRef>{}</saml:AuthnContextClassRef>",
                escape(class_ref)
            ));
        }
        if let Some(decl) = &self.authn_context_decl {
            // Deliberately not escaped: see the NOTE in the method documentation.
            xml.push_str(&format!(
                "<saml:AuthnContextDecl>{}</saml:AuthnContextDecl>",
                decl
            ));
        }
        if let Some(decl_ref) = &self.authn_context_decl_ref {
            xml.push_str(&format!(
                "<saml:AuthnContextDeclRef>{}</saml:AuthnContextDeclRef>",
                escape(decl_ref)
            ));
        }
        for authority in &self.authenticating_authorities {
            xml.push_str(&format!(
                "<saml:AuthenticatingAuthority>{}</saml:AuthenticatingAuthority>",
                escape(authority)
            ));
        }
        xml.push_str("</saml:AuthnContext>");
        Ok(xml)
    }
}
impl SamlSubjectLocality {
    /// Serializes the locality as a self-closing `<saml:SubjectLocality/>` element.
    ///
    /// `Address` and `DNSName` are emitted only when set, and both are XML-escaped so that
    /// a quote or markup character cannot break out of the attribute.
    ///
    /// # Errors
    /// Never fails at present; the `Result` keeps the signature uniform with the other
    /// serializers.
    pub fn to_xml(&self) -> Result<String> {
        /// Replaces the five XML-reserved characters with their predefined entities.
        fn escape(raw: &str) -> String {
            // `&` must be handled first so the entities produced below are not re-escaped.
            raw.replace('&', "&amp;")
                .replace('<', "&lt;")
                .replace('>', "&gt;")
                .replace('"', "&quot;")
                .replace('\'', "&apos;")
        }

        let mut xml = String::from("<saml:SubjectLocality");
        if let Some(address) = &self.address {
            xml.push_str(&format!(" Address=\"{}\"", escape(address)));
        }
        if let Some(dns_name) = &self.dns_name {
            xml.push_str(&format!(" DNSName=\"{}\"", escape(dns_name)));
        }
        xml.push_str("/>");
        Ok(xml)
    }
}
impl SamlAuthzDecisionStatement {
    /// Serializes the statement as a `<saml:AuthzDecisionStatement>` element.
    ///
    /// `Decision` is the variant name and `Resource` is the XML-escaped resource identifier
    /// (resource URIs may contain `&` or quotes). Actions follow in order, then the optional
    /// evidence.
    ///
    /// # Errors
    /// Propagates any error returned by a nested `to_xml` call.
    pub fn to_xml(&self) -> Result<String> {
        /// Replaces the five XML-reserved characters with their predefined entities.
        fn escape(raw: &str) -> String {
            // `&` must be handled first so the entities produced below are not re-escaped.
            raw.replace('&', "&amp;")
                .replace('<', "&lt;")
                .replace('>', "&gt;")
                .replace('"', "&quot;")
                .replace('\'', "&apos;")
        }

        let decision_str = match self.decision {
            SamlDecision::Permit => "Permit",
            SamlDecision::Deny => "Deny",
            SamlDecision::Indeterminate => "Indeterminate",
        };

        let mut xml = format!(
            "<saml:AuthzDecisionStatement Decision=\"{}\" Resource=\"{}\">",
            decision_str,
            escape(&self.resource)
        );
        for action in &self.actions {
            xml.push_str(&action.to_xml()?);
        }
        if let Some(evidence) = &self.evidence {
            xml.push_str(&evidence.to_xml()?);
        }
        xml.push_str("</saml:AuthzDecisionStatement>");
        Ok(xml)
    }
}
impl SamlAction {
    /// Serializes the action as a `<saml:Action>` element.
    ///
    /// The optional `Namespace` attribute and the action text are XML-escaped so that
    /// markup characters in either cannot alter the document structure.
    ///
    /// # Errors
    /// Never fails at present; the `Result` keeps the signature uniform with the other
    /// serializers.
    pub fn to_xml(&self) -> Result<String> {
        /// Replaces the five XML-reserved characters with their predefined entities.
        fn escape(raw: &str) -> String {
            // `&` must be handled first so the entities produced below are not re-escaped.
            raw.replace('&', "&amp;")
                .replace('<', "&lt;")
                .replace('>', "&gt;")
                .replace('"', "&quot;")
                .replace('\'', "&apos;")
        }

        let mut xml = String::from("<saml:Action");
        if let Some(namespace) = &self.namespace {
            xml.push_str(&format!(" Namespace=\"{}\"", escape(namespace)));
        }
        xml.push_str(&format!(">{}</saml:Action>", escape(&self.value)));
        Ok(xml)
    }
}
impl SamlEvidence {
    /// Serializes the evidence as a `<saml:Evidence>` element.
    ///
    /// Embedded assertions come first, then ID references, then URI references. The ID and
    /// URI references are XML-escaped text.
    ///
    /// NOTE(review): entries of `assertions` are appended verbatim, on the assumption that
    /// they are already-serialized assertion XML; confirm with producers of this field.
    ///
    /// # Errors
    /// Never fails at present; the `Result` keeps the signature uniform with the other
    /// serializers.
    pub fn to_xml(&self) -> Result<String> {
        /// Replaces the five XML-reserved characters with their predefined entities.
        fn escape(raw: &str) -> String {
            // `&` must be handled first so the entities produced below are not re-escaped.
            raw.replace('&', "&amp;")
                .replace('<', "&lt;")
                .replace('>', "&gt;")
                .replace('"', "&quot;")
                .replace('\'', "&apos;")
        }

        let mut xml = String::from("<saml:Evidence>");
        for assertion in &self.assertions {
            // Deliberately not escaped: see the NOTE in the method documentation.
            xml.push_str(assertion);
        }
        for id_ref in &self.assertion_id_refs {
            xml.push_str(&format!(
                "<saml:AssertionIDRef>{}</saml:AssertionIDRef>",
                escape(id_ref)
            ));
        }
        for uri_ref in &self.assertion_uri_refs {
            xml.push_str(&format!(
                "<saml:AssertionURIRef>{}</saml:AssertionURIRef>",
                escape(uri_ref)
            ));
        }
        xml.push_str("</saml:Evidence>");
        Ok(xml)
    }
}
impl Default for SamlAssertionValidator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder records issuer and version and creates statements/conditions on demand.
    #[test]
    fn test_saml_assertion_builder() {
        let builder = SamlAssertionBuilder::new("https://idp.example.com")
            .with_attribute("username", "testuser")
            .with_attribute("email", "test@example.com")
            .with_audience("https://sp.example.com");
        let built = builder.build();

        assert_eq!(built.issuer, "https://idp.example.com");
        assert_eq!(built.version, "2.0");
        assert!(!built.attribute_statements.is_empty());
        assert!(built.conditions.is_some());
    }

    /// Serialization produces the assertion envelope and includes issuer and attribute text.
    #[test]
    fn test_saml_assertion_xml() {
        let built = SamlAssertionBuilder::new("https://idp.example.com")
            .with_attribute("username", "testuser")
            .build();
        let rendered = built.to_xml().unwrap();

        for fragment in [
            "<saml:Assertion",
            "https://idp.example.com",
            "testuser",
            "</saml:Assertion>",
        ] {
            assert!(rendered.contains(fragment));
        }
    }

    /// A trusted issuer, matching audience and current validity window pass validation.
    #[test]
    fn test_saml_assertion_validation() {
        let checker = SamlAssertionValidator::new()
            .with_trusted_issuer("https://idp.example.com")
            .with_expected_audience("https://sp.example.com");
        let built = SamlAssertionBuilder::new("https://idp.example.com")
            .with_audience("https://sp.example.com")
            .with_validity_period(
                Utc::now() - Duration::minutes(1),
                Utc::now() + Duration::hours(1),
            )
            .build();

        let outcome = checker.validate(&built);
        assert!(outcome.is_ok());
    }

    /// A validity window that ended an hour ago is rejected despite the default clock skew.
    #[test]
    fn test_expired_assertion_validation() {
        let checker = SamlAssertionValidator::new();
        let built = SamlAssertionBuilder::new("https://idp.example.com")
            .with_validity_period(
                Utc::now() - Duration::hours(2),
                Utc::now() - Duration::hours(1),
            )
            .build();

        let outcome = checker.validate(&built);
        assert!(outcome.is_err());
    }
}