1use std::collections::HashMap;
38use std::sync::Arc;
39use std::time::{Duration, SystemTime, UNIX_EPOCH};
40
41use crate::http::security::User;
42
/// SAML `NameID` format identifiers (SAML 1.1 and 2.0 name-identifier format URNs).
///
/// Any URN not covered by a dedicated variant is preserved verbatim in
/// [`NameIdFormat::Custom`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum NameIdFormat {
    /// `urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified` (the default).
    #[default]
    Unspecified,
    /// `urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress`.
    EmailAddress,
    /// `urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName`.
    X509SubjectName,
    /// `urn:oasis:names:tc:SAML:1.1:nameid-format:WindowsDomainQualifiedName`.
    WindowsDomainQualifiedName,
    /// `urn:oasis:names:tc:SAML:2.0:nameid-format:kerberos`.
    Kerberos,
    /// `urn:oasis:names:tc:SAML:2.0:nameid-format:persistent`.
    Persistent,
    /// `urn:oasis:names:tc:SAML:2.0:nameid-format:transient`.
    Transient,
    /// Any other format URN, stored exactly as received.
    Custom(String),
}

impl NameIdFormat {
    /// Returns the URN that identifies this format on the wire.
    ///
    /// `Custom` returns its stored string unchanged.
    pub fn as_urn(&self) -> &str {
        let urn: &str = match self {
            Self::Custom(custom) => custom.as_str(),
            Self::Unspecified => "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified",
            Self::EmailAddress => "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
            Self::X509SubjectName => "urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName",
            Self::WindowsDomainQualifiedName => {
                "urn:oasis:names:tc:SAML:1.1:nameid-format:WindowsDomainQualifiedName"
            }
            Self::Kerberos => "urn:oasis:names:tc:SAML:2.0:nameid-format:kerberos",
            Self::Persistent => "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent",
            Self::Transient => "urn:oasis:names:tc:SAML:2.0:nameid-format:transient",
        };
        urn
    }

    /// Maps a format URN back to its variant.
    ///
    /// Unrecognised URNs become `Custom(urn)`. The lookup compares against
    /// [`NameIdFormat::as_urn`], so the two directions can never drift apart.
    pub fn from_urn(urn: &str) -> Self {
        const STANDARD_FORMATS: [NameIdFormat; 7] = [
            NameIdFormat::Unspecified,
            NameIdFormat::EmailAddress,
            NameIdFormat::X509SubjectName,
            NameIdFormat::WindowsDomainQualifiedName,
            NameIdFormat::Kerberos,
            NameIdFormat::Persistent,
            NameIdFormat::Transient,
        ];
        STANDARD_FORMATS
            .iter()
            .find(|candidate| candidate.as_urn() == urn)
            .cloned()
            .unwrap_or_else(|| NameIdFormat::Custom(urn.to_string()))
    }
}
102
/// SAML 2.0 protocol bindings, i.e. how a protocol message is transported.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SamlBinding {
    /// HTTP-Redirect binding (message carried in the URL query); the default.
    #[default]
    HttpRedirect,
    /// HTTP-POST binding (message carried in an HTML form field).
    HttpPost,
    /// HTTP-Artifact binding.
    HttpArtifact,
    /// SOAP binding.
    Soap,
}

impl SamlBinding {
    /// Returns the binding's URN, e.g.
    /// `urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST`.
    pub fn as_urn(&self) -> &str {
        let urn: &'static str = match *self {
            Self::HttpRedirect => "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
            Self::HttpPost => "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
            Self::HttpArtifact => "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Artifact",
            Self::Soap => "urn:oasis:names:tc:SAML:2.0:bindings:SOAP",
        };
        urn
    }
}
128
/// Authentication context classes an SP can request from the IdP.
///
/// Any class URN not covered by a dedicated variant can be supplied through
/// [`AuthnContextClass::Custom`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum AuthnContextClass {
    /// `urn:oasis:names:tc:SAML:2.0:ac:classes:unspecified` (the default).
    #[default]
    Unspecified,
    /// `...:ac:classes:Password`.
    Password,
    /// `...:ac:classes:PasswordProtectedTransport`.
    PasswordProtectedTransport,
    /// `...:ac:classes:X509`.
    X509,
    /// `...:ac:classes:Kerberos`.
    Kerberos,
    /// `...:ac:classes:MultiFactor`.
    ///
    /// NOTE(review): this URN does not appear among the classes defined by the
    /// OASIS SAML 2.0 Authentication Context specification. IdPs that validate
    /// the requested class may reject it. Confirm with the target IdP, or use
    /// `Custom` with the IdP's own MFA URI.
    MultiFactor,
    /// Any other class URN, used verbatim.
    Custom(String),
}

impl AuthnContextClass {
    /// Returns the class reference URN placed in `AuthnContextClassRef`.
    pub fn as_urn(&self) -> &str {
        match self {
            AuthnContextClass::Custom(urn) => urn.as_str(),
            AuthnContextClass::Unspecified => "urn:oasis:names:tc:SAML:2.0:ac:classes:unspecified",
            AuthnContextClass::Password => "urn:oasis:names:tc:SAML:2.0:ac:classes:Password",
            AuthnContextClass::PasswordProtectedTransport => {
                "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport"
            }
            AuthnContextClass::X509 => "urn:oasis:names:tc:SAML:2.0:ac:classes:X509",
            AuthnContextClass::Kerberos => "urn:oasis:names:tc:SAML:2.0:ac:classes:Kerberos",
            AuthnContextClass::MultiFactor => "urn:oasis:names:tc:SAML:2.0:ac:classes:MultiFactor",
        }
    }
}
165
/// Status codes an IdP can return in a SAML `<samlp:Status>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SamlStatusCode {
    /// The request succeeded.
    Success,
    /// The request failed due to an error on the requester (SP) side.
    Requester,
    /// The request failed due to an error on the responder (IdP) side.
    Responder,
    /// The SAML version of the request was not supported.
    VersionMismatch,
    /// The IdP could not authenticate the principal.
    AuthnFailed,
    /// The requested NameID policy could not be satisfied.
    InvalidNameIdPolicy,
    /// The requested authentication context could not be satisfied.
    NoAuthnContext,
    /// Any other status URN, kept verbatim (empty when no status was found).
    Unknown(String),
}

impl SamlStatusCode {
    /// Parses a full status URN such as
    /// `urn:oasis:names:tc:SAML:2.0:status:Success`.
    ///
    /// Anything that is not one of the recognised URNs (including a bare
    /// suffix like `"Success"`) becomes `Unknown` holding the input unchanged.
    pub fn from_urn(urn: &str) -> Self {
        const STATUS_PREFIX: &str = "urn:oasis:names:tc:SAML:2.0:status:";

        let recognised = urn
            .strip_prefix(STATUS_PREFIX)
            .and_then(|suffix| match suffix {
                "Success" => Some(SamlStatusCode::Success),
                "Requester" => Some(SamlStatusCode::Requester),
                "Responder" => Some(SamlStatusCode::Responder),
                "VersionMismatch" => Some(SamlStatusCode::VersionMismatch),
                "AuthnFailed" => Some(SamlStatusCode::AuthnFailed),
                "InvalidNameIDPolicy" => Some(SamlStatusCode::InvalidNameIdPolicy),
                "NoAuthnContext" => Some(SamlStatusCode::NoAuthnContext),
                _ => None,
            });

        recognised.unwrap_or_else(|| SamlStatusCode::Unknown(urn.to_string()))
    }

    /// `true` only for [`SamlStatusCode::Success`].
    pub fn is_success(&self) -> bool {
        *self == SamlStatusCode::Success
    }
}
209
/// Service-provider configuration for SAML 2.0 Web Browser SSO.
///
/// Build one with `SamlConfig::new` plus the chained setters, or start from a
/// provider preset such as `SamlConfig::okta`. Then check it with `validate`.
#[derive(Debug, Clone)]
pub struct SamlConfig {
    /// Entity ID of this service provider. It is the `Issuer` of outgoing
    /// requests, the metadata `entityID`, and the expected assertion audience.
    pub entity_id: String,
    /// IdP endpoint that receives AuthnRequests.
    pub idp_sso_url: String,
    /// IdP single-logout endpoint. Logout requests cannot be built without it.
    pub idp_slo_url: Option<String>,
    /// Expected `Issuer` of assertions. The issuer check is skipped when `None`.
    pub idp_entity_id: Option<String>,
    /// IdP signing certificate. Its encoding is not established by this code
    /// (presumably PEM; confirm). `validate` requires it when
    /// `want_assertions_signed` is set.
    pub idp_certificate: Option<String>,
    /// SP private key. `validate` requires it when `sign_authn_request` is set.
    pub sp_private_key: Option<String>,
    /// SP certificate. NOTE(review): not read by the request, response or
    /// metadata code reviewed in this file; confirm where it is used.
    pub sp_certificate: Option<String>,
    /// Assertion Consumer Service URL where the IdP posts responses. Responses
    /// whose `Destination` differs from it are rejected.
    pub acs_url: String,
    /// SP single-logout URL advertised in metadata, if any.
    pub sls_url: Option<String>,
    /// Configured SSO binding. NOTE(review): login always builds an
    /// HTTP-Redirect URL in the reviewed code, and this field does not appear
    /// to change that. Confirm.
    pub sso_binding: SamlBinding,
    /// Binding advertised for the SP `SingleLogoutService` in metadata.
    pub slo_binding: SamlBinding,
    /// NameID format. It is requested in AuthnRequests, advertised in
    /// metadata, and used for the NameID in logout requests.
    pub name_id_format: NameIdFormat,
    /// Authentication context class to request, if any (sent with
    /// `Comparison="exact"`).
    pub authn_context_class: Option<AuthnContextClass>,
    /// Advertised as `AuthnRequestsSigned` in metadata. NOTE(review): no
    /// request signing is visible in the reviewed code.
    pub sign_authn_request: bool,
    /// Advertised as `WantAssertionsSigned` in metadata. NOTE(review): no
    /// signature verification is visible in the reviewed response handling.
    pub want_assertions_signed: bool,
    /// Whether assertions are expected to be encrypted. NOTE(review): not read
    /// by the reviewed code.
    pub want_assertions_encrypted: bool,
    /// Tolerance applied to assertion `NotBefore`/`NotOnOrAfter` checks.
    pub max_clock_skew: Duration,
    /// Maps a SAML attribute name to a user field name. Only the `email`
    /// target has any effect in the reviewed user mapping.
    pub attribute_mapping: HashMap<String, String>,
    /// SAML attribute whose values are upper-cased and added as roles.
    pub role_attribute: Option<String>,
    /// SAML attribute whose values are passed to the user as authorities.
    pub authority_attribute: Option<String>,
    /// Roles every SAML-authenticated user receives.
    pub default_roles: Vec<String>,
    /// Accept responses that lack an `InResponseTo` matching a pending request.
    pub allow_unsolicited_responses: bool,
    /// Session lifetime. NOTE(review): not read by the reviewed code.
    pub session_timeout: Duration,
}
260
261impl SamlConfig {
262 pub fn new() -> Self {
264 Self {
265 entity_id: String::new(),
266 idp_sso_url: String::new(),
267 idp_slo_url: None,
268 idp_entity_id: None,
269 idp_certificate: None,
270 sp_private_key: None,
271 sp_certificate: None,
272 acs_url: String::new(),
273 sls_url: None,
274 sso_binding: SamlBinding::HttpRedirect,
275 slo_binding: SamlBinding::HttpRedirect,
276 name_id_format: NameIdFormat::Unspecified,
277 authn_context_class: None,
278 sign_authn_request: false,
279 want_assertions_signed: true,
280 want_assertions_encrypted: false,
281 max_clock_skew: Duration::from_secs(120),
282 attribute_mapping: HashMap::new(),
283 role_attribute: None,
284 authority_attribute: None,
285 default_roles: vec!["USER".to_string()],
286 allow_unsolicited_responses: false,
287 session_timeout: Duration::from_secs(3600),
288 }
289 }
290
291 pub fn entity_id(mut self, entity_id: impl Into<String>) -> Self {
293 self.entity_id = entity_id.into();
294 self
295 }
296
297 pub fn idp_sso_url(mut self, url: impl Into<String>) -> Self {
299 self.idp_sso_url = url.into();
300 self
301 }
302
303 pub fn idp_slo_url(mut self, url: impl Into<String>) -> Self {
305 self.idp_slo_url = Some(url.into());
306 self
307 }
308
309 pub fn idp_entity_id(mut self, entity_id: impl Into<String>) -> Self {
311 self.idp_entity_id = Some(entity_id.into());
312 self
313 }
314
315 pub fn idp_certificate(mut self, cert: impl Into<String>) -> Self {
317 self.idp_certificate = Some(cert.into());
318 self
319 }
320
321 pub fn sp_private_key(mut self, key: impl Into<String>) -> Self {
323 self.sp_private_key = Some(key.into());
324 self
325 }
326
327 pub fn sp_certificate(mut self, cert: impl Into<String>) -> Self {
329 self.sp_certificate = Some(cert.into());
330 self
331 }
332
333 pub fn acs_url(mut self, url: impl Into<String>) -> Self {
335 self.acs_url = url.into();
336 self
337 }
338
339 pub fn assertion_consumer_service_url(self, url: impl Into<String>) -> Self {
341 self.acs_url(url)
342 }
343
344 pub fn sls_url(mut self, url: impl Into<String>) -> Self {
346 self.sls_url = Some(url.into());
347 self
348 }
349
350 pub fn sso_binding(mut self, binding: SamlBinding) -> Self {
352 self.sso_binding = binding;
353 self
354 }
355
356 pub fn slo_binding(mut self, binding: SamlBinding) -> Self {
358 self.slo_binding = binding;
359 self
360 }
361
362 pub fn name_id_format(mut self, format: NameIdFormat) -> Self {
364 self.name_id_format = format;
365 self
366 }
367
368 pub fn authn_context_class(mut self, class: AuthnContextClass) -> Self {
370 self.authn_context_class = Some(class);
371 self
372 }
373
374 pub fn sign_authn_request(mut self, sign: bool) -> Self {
376 self.sign_authn_request = sign;
377 self
378 }
379
380 pub fn want_assertions_signed(mut self, signed: bool) -> Self {
382 self.want_assertions_signed = signed;
383 self
384 }
385
386 pub fn want_assertions_encrypted(mut self, encrypted: bool) -> Self {
388 self.want_assertions_encrypted = encrypted;
389 self
390 }
391
392 pub fn max_clock_skew(mut self, skew: Duration) -> Self {
394 self.max_clock_skew = skew;
395 self
396 }
397
398 pub fn map_attribute(
400 mut self,
401 saml_attribute: impl Into<String>,
402 user_field: impl Into<String>,
403 ) -> Self {
404 self.attribute_mapping
405 .insert(saml_attribute.into(), user_field.into());
406 self
407 }
408
409 pub fn role_attribute(mut self, attr: impl Into<String>) -> Self {
411 self.role_attribute = Some(attr.into());
412 self
413 }
414
415 pub fn authority_attribute(mut self, attr: impl Into<String>) -> Self {
417 self.authority_attribute = Some(attr.into());
418 self
419 }
420
421 pub fn default_roles(mut self, roles: Vec<String>) -> Self {
423 self.default_roles = roles;
424 self
425 }
426
427 pub fn allow_unsolicited_responses(mut self, allow: bool) -> Self {
429 self.allow_unsolicited_responses = allow;
430 self
431 }
432
433 pub fn session_timeout(mut self, timeout: Duration) -> Self {
435 self.session_timeout = timeout;
436 self
437 }
438
439 pub fn okta(
441 okta_domain: impl Into<String>,
442 app_id: impl Into<String>,
443 sp_entity_id: impl Into<String>,
444 ) -> Self {
445 let domain = okta_domain.into();
446 let app = app_id.into();
447 Self::new()
448 .entity_id(sp_entity_id)
449 .idp_sso_url(format!("https://{}/app/{}/sso/saml", domain, app))
450 .idp_entity_id(format!("http://www.okta.com/{}", app))
451 .name_id_format(NameIdFormat::EmailAddress)
452 .sso_binding(SamlBinding::HttpPost)
453 }
454
455 pub fn azure_ad(
457 tenant_id: impl Into<String>,
458 _app_id: impl Into<String>,
459 sp_entity_id: impl Into<String>,
460 ) -> Self {
461 let tenant = tenant_id.into();
462 Self::new()
463 .entity_id(sp_entity_id)
464 .idp_sso_url(format!(
465 "https://login.microsoftonline.com/{}/saml2",
466 tenant
467 ))
468 .idp_entity_id(format!("https://sts.windows.net/{}/", tenant))
469 .name_id_format(NameIdFormat::EmailAddress)
470 .sso_binding(SamlBinding::HttpRedirect)
471 .map_attribute(
472 "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress",
473 "email",
474 )
475 .map_attribute(
476 "http://schemas.microsoft.com/ws/2008/06/identity/claims/groups",
477 "groups",
478 )
479 }
480
481 pub fn google_workspace(sp_entity_id: impl Into<String>, acs_url: impl Into<String>) -> Self {
483 Self::new()
484 .entity_id(sp_entity_id)
485 .idp_sso_url("https://accounts.google.com/o/saml2/idp")
486 .acs_url(acs_url)
487 .name_id_format(NameIdFormat::EmailAddress)
488 .sso_binding(SamlBinding::HttpRedirect)
489 }
490
491 pub fn adfs(adfs_host: impl Into<String>, sp_entity_id: impl Into<String>) -> Self {
493 let host = adfs_host.into();
494 Self::new()
495 .entity_id(sp_entity_id)
496 .idp_sso_url(format!("https://{}/adfs/ls/", host))
497 .idp_entity_id(format!("http://{}/adfs/services/trust", host))
498 .name_id_format(NameIdFormat::Unspecified)
499 .sso_binding(SamlBinding::HttpPost)
500 }
501
502 pub fn validate(&self) -> Result<(), SamlError> {
504 if self.entity_id.is_empty() {
505 return Err(SamlError::Configuration("entity_id is required".into()));
506 }
507 if self.idp_sso_url.is_empty() {
508 return Err(SamlError::Configuration("idp_sso_url is required".into()));
509 }
510 if self.acs_url.is_empty() {
511 return Err(SamlError::Configuration("acs_url is required".into()));
512 }
513 if self.sign_authn_request && self.sp_private_key.is_none() {
514 return Err(SamlError::Configuration(
515 "sp_private_key is required when sign_authn_request is true".into(),
516 ));
517 }
518 if self.want_assertions_signed && self.idp_certificate.is_none() {
519 return Err(SamlError::Configuration(
520 "idp_certificate is required when want_assertions_signed is true".into(),
521 ));
522 }
523 Ok(())
524 }
525}
526
527impl Default for SamlConfig {
528 fn default() -> Self {
529 Self::new()
530 }
531}
532
533#[derive(Debug, Clone)]
535pub enum SamlError {
536 Configuration(String),
538 InvalidResponse(String),
540 SignatureVerificationFailed(String),
542 AssertionValidationFailed(String),
544 TimeConditionNotMet(String),
546 AudienceRestrictionNotMet(String),
548 MissingAttribute(String),
550 IdpError(SamlStatusCode, Option<String>),
552 DecryptionFailed(String),
554 XmlParsingError(String),
556 NetworkError(String),
558}
559
560impl std::fmt::Display for SamlError {
561 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
562 match self {
563 SamlError::Configuration(msg) => write!(f, "SAML configuration error: {}", msg),
564 SamlError::InvalidResponse(msg) => write!(f, "Invalid SAML response: {}", msg),
565 SamlError::SignatureVerificationFailed(msg) => {
566 write!(f, "SAML signature verification failed: {}", msg)
567 }
568 SamlError::AssertionValidationFailed(msg) => {
569 write!(f, "SAML assertion validation failed: {}", msg)
570 }
571 SamlError::TimeConditionNotMet(msg) => {
572 write!(f, "SAML time condition not met: {}", msg)
573 }
574 SamlError::AudienceRestrictionNotMet(msg) => {
575 write!(f, "SAML audience restriction not met: {}", msg)
576 }
577 SamlError::MissingAttribute(attr) => {
578 write!(f, "Required SAML attribute missing: {}", attr)
579 }
580 SamlError::IdpError(code, msg) => {
581 write!(f, "IdP returned error {:?}: {:?}", code, msg)
582 }
583 SamlError::DecryptionFailed(msg) => write!(f, "SAML decryption failed: {}", msg),
584 SamlError::XmlParsingError(msg) => write!(f, "XML parsing error: {}", msg),
585 SamlError::NetworkError(msg) => write!(f, "Network error: {}", msg),
586 }
587 }
588}
589
590impl std::error::Error for SamlError {}
591
/// An outgoing SAML `<samlp:AuthnRequest>` and the values serialized into it.
#[derive(Debug, Clone)]
pub struct AuthnRequest {
    /// Request `ID`. It is echoed back by the IdP as the response's `InResponseTo`.
    pub id: String,
    /// `IssueInstant`, formatted as `YYYY-MM-DDThh:mm:ssZ`.
    pub issue_instant: String,
    /// SP entity ID, emitted as `<saml:Issuer>`.
    pub issuer: String,
    /// IdP SSO endpoint. It is used both as the `Destination` attribute and as
    /// the base of the redirect URL.
    pub destination: String,
    /// `AssertionConsumerServiceURL` where the response should be delivered.
    pub acs_url: String,
    /// Binding requested for the response (`ProtocolBinding`).
    pub protocol_binding: SamlBinding,
    /// Requested NameID format (`NameIDPolicy/@Format`).
    pub name_id_format: NameIdFormat,
    /// Requested authentication context class, if any.
    pub authn_context: Option<AuthnContextClass>,
    /// Emit `ForceAuthn="true"` (ask the IdP to re-authenticate the user).
    pub force_authn: bool,
    /// Emit `IsPassive="true"` (ask the IdP not to interact with the user).
    pub is_passive: bool,
}
616
617impl AuthnRequest {
618 pub fn new(config: &SamlConfig) -> Self {
620 let id = format!("_{}_{}", generate_id(), timestamp_millis());
621
622 Self {
623 id,
624 issue_instant: iso8601_now(),
625 issuer: config.entity_id.clone(),
626 destination: config.idp_sso_url.clone(),
627 acs_url: config.acs_url.clone(),
628 protocol_binding: SamlBinding::HttpPost,
629 name_id_format: config.name_id_format.clone(),
630 authn_context: config.authn_context_class.clone(),
631 force_authn: false,
632 is_passive: false,
633 }
634 }
635
636 pub fn force_authn(mut self, force: bool) -> Self {
638 self.force_authn = force;
639 self
640 }
641
642 pub fn is_passive(mut self, passive: bool) -> Self {
644 self.is_passive = passive;
645 self
646 }
647
648 pub fn to_xml(&self) -> String {
650 let mut xml = String::new();
651 xml.push_str(r#"<?xml version="1.0" encoding="UTF-8"?>"#);
652 xml.push_str(&format!(
653 r#"<samlp:AuthnRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion" ID="{}" Version="2.0" IssueInstant="{}" Destination="{}" AssertionConsumerServiceURL="{}" ProtocolBinding="{}""#,
654 self.id,
655 self.issue_instant,
656 self.destination,
657 self.acs_url,
658 self.protocol_binding.as_urn()
659 ));
660
661 if self.force_authn {
662 xml.push_str(r#" ForceAuthn="true""#);
663 }
664 if self.is_passive {
665 xml.push_str(r#" IsPassive="true""#);
666 }
667
668 xml.push('>');
669
670 xml.push_str(&format!(r#"<saml:Issuer>{}</saml:Issuer>"#, self.issuer));
672
673 xml.push_str(&format!(
675 r#"<samlp:NameIDPolicy Format="{}" AllowCreate="true"/>"#,
676 self.name_id_format.as_urn()
677 ));
678
679 if let Some(ref authn_context) = self.authn_context {
681 xml.push_str(r#"<samlp:RequestedAuthnContext Comparison="exact">"#);
682 xml.push_str(&format!(
683 r#"<saml:AuthnContextClassRef>{}</saml:AuthnContextClassRef>"#,
684 authn_context.as_urn()
685 ));
686 xml.push_str(r#"</samlp:RequestedAuthnContext>"#);
687 }
688
689 xml.push_str(r#"</samlp:AuthnRequest>"#);
690 xml
691 }
692
693 pub fn to_redirect_url(&self, relay_state: Option<&str>) -> String {
695 let xml = self.to_xml();
696 let deflated = deflate_and_encode(&xml);
697
698 let mut url = format!(
699 "{}?SAMLRequest={}",
700 self.destination,
701 urlencoding::encode(&deflated)
702 );
703
704 if let Some(state) = relay_state {
705 url.push_str(&format!("&RelayState={}", urlencoding::encode(state)));
706 }
707
708 url
709 }
710}
711
/// A SAML assertion extracted from a response. Timestamps are kept as the raw
/// strings found in the XML.
#[derive(Debug, Clone)]
pub struct SamlAssertion {
    /// Assertion `ID`.
    pub id: String,
    /// Assertion `IssueInstant`, as a raw string.
    pub issue_instant: String,
    /// Assertion `<Issuer>`. It is compared with `idp_entity_id` during validation.
    pub issuer: String,
    /// Subject NameID. It becomes the username when mapping to a `User`.
    pub name_id: String,
    /// Format of `name_id`.
    pub name_id_format: NameIdFormat,
    /// IdP session index, needed for single logout.
    pub session_index: Option<String>,
    /// Session expiry asserted by the IdP. NOTE(review): not checked by
    /// `SamlAssertion::validate`.
    pub session_not_on_or_after: Option<String>,
    /// Conditions `NotBefore`, as a raw timestamp string.
    pub not_before: Option<String>,
    /// Conditions `NotOnOrAfter`, as a raw timestamp string.
    pub not_on_or_after: Option<String>,
    /// Audiences from the audience restriction. An empty list means no restriction.
    pub audiences: Vec<String>,
    /// Authentication context class reference, as a raw string.
    pub authn_context_class: Option<String>,
    /// Attribute values keyed by attribute name; attributes may be multi-valued.
    pub attributes: HashMap<String, Vec<String>>,
}
740
741impl SamlAssertion {
742 pub fn get_attribute(&self, name: &str) -> Option<&str> {
744 self.attributes
745 .get(name)
746 .and_then(|values| values.first())
747 .map(|s| s.as_str())
748 }
749
750 pub fn get_attribute_values(&self, name: &str) -> Option<&Vec<String>> {
752 self.attributes.get(name)
753 }
754
755 pub fn validate(&self, config: &SamlConfig) -> Result<(), SamlError> {
757 let now = SystemTime::now()
758 .duration_since(UNIX_EPOCH)
759 .unwrap()
760 .as_secs();
761
762 if let Some(ref not_before) = self.not_before {
764 if let Ok(nb_time) = parse_iso8601(not_before) {
765 let skew = config.max_clock_skew.as_secs();
766 if now + skew < nb_time {
767 return Err(SamlError::TimeConditionNotMet(format!(
768 "Assertion not valid before {}",
769 not_before
770 )));
771 }
772 }
773 }
774
775 if let Some(ref not_on_or_after) = self.not_on_or_after {
777 if let Ok(noa_time) = parse_iso8601(not_on_or_after) {
778 let skew = config.max_clock_skew.as_secs();
779 if now > noa_time + skew {
780 return Err(SamlError::TimeConditionNotMet(format!(
781 "Assertion expired at {}",
782 not_on_or_after
783 )));
784 }
785 }
786 }
787
788 if !self.audiences.is_empty() && !self.audiences.contains(&config.entity_id) {
790 return Err(SamlError::AudienceRestrictionNotMet(format!(
791 "SP entity ID {} not in audiences: {:?}",
792 config.entity_id, self.audiences
793 )));
794 }
795
796 if let Some(ref expected_issuer) = config.idp_entity_id {
798 if &self.issuer != expected_issuer {
799 return Err(SamlError::AssertionValidationFailed(format!(
800 "Issuer mismatch: expected {}, got {}",
801 expected_issuer, self.issuer
802 )));
803 }
804 }
805
806 Ok(())
807 }
808}
809
/// A parsed `<samlp:Response>` from the IdP.
#[derive(Debug, Clone)]
pub struct SamlResponse {
    /// Response `ID`.
    pub id: String,
    /// ID of the AuthnRequest this answers. `None` for IdP-initiated responses.
    pub in_response_to: Option<String>,
    /// Response `IssueInstant`, as a raw string.
    pub issue_instant: String,
    /// `Destination` attribute. When present it must equal the configured ACS URL.
    pub destination: Option<String>,
    /// Response `<Issuer>`.
    pub issuer: String,
    /// Status code. It is `Unknown("")` when none was found.
    pub status_code: SamlStatusCode,
    /// Optional `<StatusMessage>` text.
    pub status_message: Option<String>,
    /// Assertions found in the response. Only the first is used by the authenticator.
    pub assertions: Vec<SamlAssertion>,
}
830
831impl SamlResponse {
832 pub fn from_base64(encoded: &str) -> Result<Self, SamlError> {
837 use base64::{engine::general_purpose::STANDARD, Engine as _};
838
839 let decoded = STANDARD
840 .decode(encoded)
841 .map_err(|e| SamlError::InvalidResponse(format!("Base64 decode error: {}", e)))?;
842
843 let xml = String::from_utf8(decoded)
844 .map_err(|e| SamlError::InvalidResponse(format!("UTF-8 decode error: {}", e)))?;
845
846 Self::from_xml(&xml)
847 }
848
849 pub fn from_xml(xml: &str) -> Result<Self, SamlError> {
854 let id = extract_attribute(xml, "Response", "ID")
858 .ok_or_else(|| SamlError::XmlParsingError("Missing Response ID".into()))?;
859
860 let in_response_to = extract_attribute(xml, "Response", "InResponseTo");
861 let issue_instant = extract_attribute(xml, "Response", "IssueInstant")
862 .ok_or_else(|| SamlError::XmlParsingError("Missing IssueInstant".into()))?;
863 let destination = extract_attribute(xml, "Response", "Destination");
864 let issuer = extract_element_text(xml, "Issuer")
865 .ok_or_else(|| SamlError::XmlParsingError("Missing Issuer".into()))?;
866
867 let status_code =
869 extract_status_code(xml).unwrap_or(SamlStatusCode::Unknown(String::new()));
870 let status_message = extract_element_text(xml, "StatusMessage");
871
872 let assertions = parse_assertions(xml)?;
874
875 Ok(Self {
876 id,
877 in_response_to,
878 issue_instant,
879 destination,
880 issuer,
881 status_code,
882 status_message,
883 assertions,
884 })
885 }
886
887 pub fn is_success(&self) -> bool {
889 self.status_code.is_success()
890 }
891
892 pub fn assertion(&self) -> Option<&SamlAssertion> {
894 self.assertions.first()
895 }
896}
897
/// Service-provider side of SAML Web Browser SSO.
///
/// It builds AuthnRequests, tracks their IDs, and turns IdP responses into
/// `User`s.
///
/// Cloning is cheap. Clones share the same configuration and pending-request
/// table, both of which are held behind `Arc`.
#[derive(Clone)]
pub struct SamlAuthenticator {
    /// Configuration, validated at construction.
    config: Arc<SamlConfig>,
    /// Outstanding AuthnRequests, keyed by request ID.
    pending_requests: Arc<std::sync::RwLock<HashMap<String, PendingRequest>>>,
}
904
/// Bookkeeping for an AuthnRequest that is awaiting the IdP's response.
#[derive(Debug, Clone)]
// Only `created_at` is read (for expiry). `id` and `relay_state` are stored
// but never read in the reviewed code, which is presumably why `dead_code` is
// allowed here.
#[allow(dead_code)] struct PendingRequest {
    /// The request `ID`, duplicated from the map key.
    id: String,
    /// Creation time in seconds since the Unix epoch.
    created_at: u64,
    /// RelayState supplied when the login was initiated.
    relay_state: Option<String>,
}
916
917impl SamlAuthenticator {
918 pub fn new(config: SamlConfig) -> Result<Self, SamlError> {
920 config.validate()?;
921 Ok(Self {
922 config: Arc::new(config),
923 pending_requests: Arc::new(std::sync::RwLock::new(HashMap::new())),
924 })
925 }
926
927 pub fn config(&self) -> &SamlConfig {
929 &self.config
930 }
931
932 pub fn create_authn_request(&self) -> AuthnRequest {
934 AuthnRequest::new(&self.config)
935 }
936
937 pub fn store_pending_request(&self, request: &AuthnRequest, relay_state: Option<String>) {
939 let mut pending = self.pending_requests.write().unwrap();
940 pending.insert(
941 request.id.clone(),
942 PendingRequest {
943 id: request.id.clone(),
944 created_at: timestamp_millis() / 1000,
945 relay_state,
946 },
947 );
948
949 let now = timestamp_millis() / 1000;
951 pending.retain(|_, req| now - req.created_at < 600);
952 }
953
954 pub fn initiate_login(&self, relay_state: Option<&str>) -> String {
956 let request = self.create_authn_request();
957 self.store_pending_request(&request, relay_state.map(|s| s.to_string()));
958 request.to_redirect_url(relay_state)
959 }
960
961 pub fn process_response(&self, encoded_response: &str) -> Result<SamlAuthResult, SamlError> {
963 let response = SamlResponse::from_base64(encoded_response)?;
964
965 self.validate_response(&response)?;
967
968 let assertion = response
970 .assertion()
971 .ok_or_else(|| SamlError::InvalidResponse("No assertion in response".into()))?;
972
973 assertion.validate(&self.config)?;
975
976 let user = self.map_assertion_to_user(assertion)?;
978
979 if let Some(ref in_response_to) = response.in_response_to {
981 let mut pending = self.pending_requests.write().unwrap();
982 pending.remove(in_response_to);
983 }
984
985 Ok(SamlAuthResult {
986 user,
987 session_index: assertion.session_index.clone(),
988 name_id: assertion.name_id.clone(),
989 name_id_format: assertion.name_id_format.clone(),
990 attributes: assertion.attributes.clone(),
991 })
992 }
993
994 fn validate_response(&self, response: &SamlResponse) -> Result<(), SamlError> {
996 if !response.is_success() {
998 return Err(SamlError::IdpError(
999 response.status_code.clone(),
1000 response.status_message.clone(),
1001 ));
1002 }
1003
1004 if !self.config.allow_unsolicited_responses {
1006 if let Some(ref in_response_to) = response.in_response_to {
1007 let pending = self.pending_requests.read().unwrap();
1008 if !pending.contains_key(in_response_to) {
1009 return Err(SamlError::InvalidResponse(
1010 "InResponseTo does not match any pending request".into(),
1011 ));
1012 }
1013 } else {
1014 return Err(SamlError::InvalidResponse(
1015 "Unsolicited responses are not allowed".into(),
1016 ));
1017 }
1018 }
1019
1020 if let Some(ref destination) = response.destination {
1022 if destination != &self.config.acs_url {
1023 return Err(SamlError::InvalidResponse(format!(
1024 "Destination mismatch: expected {}, got {}",
1025 self.config.acs_url, destination
1026 )));
1027 }
1028 }
1029
1030 Ok(())
1031 }
1032
1033 fn map_assertion_to_user(&self, assertion: &SamlAssertion) -> Result<User, SamlError> {
1035 let username = assertion.name_id.clone();
1036
1037 let mut user = User::with_encoded_password(&username, "{saml}external".to_string());
1039
1040 for (saml_attr, user_field) in &self.config.attribute_mapping {
1042 if let Some(values) = assertion.attributes.get(saml_attr) {
1043 if let Some(value) = values.first() {
1044 match user_field.as_str() {
1045 "email" => {
1046 user = user.authorities(&[format!("email:{}", value)]);
1048 }
1049 "display_name" | "name" => {
1050 }
1052 _ => {}
1053 }
1054 }
1055 }
1056 }
1057
1058 let mut roles: Vec<String> = self.config.default_roles.clone();
1060 if let Some(ref role_attr) = self.config.role_attribute {
1061 if let Some(values) = assertion.attributes.get(role_attr) {
1062 roles.extend(values.iter().map(|r| r.to_uppercase()));
1063 }
1064 }
1065 user = user.roles(&roles);
1066
1067 if let Some(ref auth_attr) = self.config.authority_attribute {
1069 if let Some(values) = assertion.attributes.get(auth_attr) {
1070 user = user.authorities(values);
1071 }
1072 }
1073
1074 Ok(user)
1075 }
1076
1077 pub fn generate_metadata(&self) -> String {
1079 let mut xml = String::new();
1080 xml.push_str(r#"<?xml version="1.0" encoding="UTF-8"?>"#);
1081 xml.push_str(&format!(
1082 r#"<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" entityID="{}">"#,
1083 self.config.entity_id
1084 ));
1085
1086 xml.push_str(r#"<md:SPSSODescriptor AuthnRequestsSigned=""#);
1087 xml.push_str(if self.config.sign_authn_request {
1088 "true"
1089 } else {
1090 "false"
1091 });
1092 xml.push_str(r#"" WantAssertionsSigned=""#);
1093 xml.push_str(if self.config.want_assertions_signed {
1094 "true"
1095 } else {
1096 "false"
1097 });
1098 xml.push_str(r#"" protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">"#);
1099
1100 xml.push_str(&format!(
1102 r#"<md:NameIDFormat>{}</md:NameIDFormat>"#,
1103 self.config.name_id_format.as_urn()
1104 ));
1105
1106 xml.push_str(&format!(
1108 r#"<md:AssertionConsumerService Binding="{}" Location="{}" index="0"/>"#,
1109 SamlBinding::HttpPost.as_urn(),
1110 self.config.acs_url
1111 ));
1112
1113 if let Some(ref sls_url) = self.config.sls_url {
1115 xml.push_str(&format!(
1116 r#"<md:SingleLogoutService Binding="{}" Location="{}"/>"#,
1117 self.config.slo_binding.as_urn(),
1118 sls_url
1119 ));
1120 }
1121
1122 xml.push_str(r#"</md:SPSSODescriptor></md:EntityDescriptor>"#);
1123 xml
1124 }
1125
1126 pub fn create_logout_request(
1128 &self,
1129 name_id: &str,
1130 session_index: Option<&str>,
1131 ) -> Option<String> {
1132 let slo_url = self.config.idp_slo_url.as_ref()?;
1133
1134 let id = format!("_{}_{}", generate_id(), timestamp_millis());
1135 let issue_instant = iso8601_now();
1136
1137 let mut xml = String::new();
1138 xml.push_str(r#"<?xml version="1.0" encoding="UTF-8"?>"#);
1139 xml.push_str(&format!(
1140 r#"<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion" ID="{}" Version="2.0" IssueInstant="{}" Destination="{}">"#,
1141 id, issue_instant, slo_url
1142 ));
1143
1144 xml.push_str(&format!(
1145 r#"<saml:Issuer>{}</saml:Issuer>"#,
1146 self.config.entity_id
1147 ));
1148
1149 xml.push_str(&format!(
1150 r#"<saml:NameID Format="{}">{}</saml:NameID>"#,
1151 self.config.name_id_format.as_urn(),
1152 name_id
1153 ));
1154
1155 if let Some(session_idx) = session_index {
1156 xml.push_str(&format!(
1157 r#"<samlp:SessionIndex>{}</samlp:SessionIndex>"#,
1158 session_idx
1159 ));
1160 }
1161
1162 xml.push_str(r#"</samlp:LogoutRequest>"#);
1163
1164 let deflated = deflate_and_encode(&xml);
1165 Some(format!(
1166 "{}?SAMLRequest={}",
1167 slo_url,
1168 urlencoding::encode(&deflated)
1169 ))
1170 }
1171}
1172
/// Outcome of a successfully processed SAML response.
#[derive(Debug, Clone)]
pub struct SamlAuthResult {
    /// The mapped application user.
    pub user: User,
    /// IdP session index from the assertion, needed for single logout.
    pub session_index: Option<String>,
    /// Subject NameID from the assertion.
    pub name_id: String,
    /// Format of `name_id`.
    pub name_id_format: NameIdFormat,
    /// All assertion attributes, keyed by name.
    pub attributes: HashMap<String, Vec<String>>,
}
1187
/// Generates a random hexadecimal identifier for SAML protocol message IDs.
///
/// SAML Core §1.3.4 requires that randomly generated identifiers collide with
/// probability at most 2^-128. The previous single 64-bit hash fell short of
/// that, so three independent 64-bit outputs are concatenated instead
/// (192 bits, 48 hex characters).
///
/// Each output comes from a freshly constructed, randomly keyed `RandomState`
/// hasher. It is fed a process-wide counter and the current time, so hasher
/// inputs never repeat within a process.
///
/// NOTE(review): `RandomState` exists for HashDoS resistance and is not a
/// documented CSPRNG. Prefer `getrandom`/`rand` if the crate ever depends on one.
fn generate_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    // Truncating the nanosecond count to 64 bits is fine: it is only hash input.
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_nanos() as u64)
        .unwrap_or(0);

    (0..3)
        .map(|_| {
            let mut hasher = RandomState::new().build_hasher();
            hasher.write_u64(COUNTER.fetch_add(1, Ordering::Relaxed));
            hasher.write_u64(nanos);
            format!("{:016x}", hasher.finish())
        })
        .collect()
}
1202
/// Milliseconds since the Unix epoch, according to the system clock.
///
/// # Panics
/// If the system clock is set before the Unix epoch.
fn timestamp_millis() -> u64 {
    let elapsed = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
    // u128 -> u64: millisecond counts fit in u64 for hundreds of millions of years.
    elapsed.as_millis() as u64
}
1210
/// Formats the current system time in UTC with whole-second precision, e.g.
/// `2024-01-01T12:00:00Z`.
///
/// # Panics
/// If the system clock is set before the Unix epoch.
fn iso8601_now() -> String {
    const SECS_PER_DAY: u64 = 86_400;

    let epoch_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs();

    let (days, secs_today) = (epoch_secs / SECS_PER_DAY, epoch_secs % SECS_PER_DAY);
    let (year, month, day) = days_to_ymd(days);

    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
        year,
        month,
        day,
        secs_today / 3_600,
        secs_today / 60 % 60,
        secs_today % 60
    )
}

/// Converts a day count since 1970-01-01 into a proleptic Gregorian
/// `(year, month, day)` with 1-based month and day.
fn days_to_ymd(days: u64) -> (u64, u64, u64) {
    let year_length = |y: u64| if is_leap_year(y) { 366 } else { 365 };

    let mut year = 1970u64;
    let mut left = days;
    while left >= year_length(year) {
        left -= year_length(year);
        year += 1;
    }

    let february = if is_leap_year(year) { 29 } else { 28 };
    let month_lengths = [31, february, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];

    let mut month = 1u64;
    for &length in month_lengths.iter() {
        if left < length {
            break;
        }
        left -= length;
        month += 1;
    }

    (year, month, left + 1)
}

/// Gregorian leap-year rule: century years are leap years only when divisible by 400.
fn is_leap_year(year: u64) -> bool {
    if year % 100 == 0 {
        year % 400 == 0
    } else {
        year % 4 == 0
    }
}
1276
/// Parses a timestamp of the form `YYYY-MM-DDThh:mm:ss` into seconds since the
/// Unix epoch.
///
/// Anything after the seconds field is ignored: fractional seconds, `Z`, and
/// numeric offsets alike. The input is therefore treated as UTC, which SAML
/// requires for all time values.
///
/// The input comes from IdP-supplied XML, so every field is validated instead
/// of trusted. The function returns `Err(())`, and never panics, when any of
/// these hold:
/// - the input is shorter than 19 bytes;
/// - a separator is wrong, or a field contains a non-digit or non-ASCII byte;
/// - a field is out of range (month 1-12, day within the month, hour ≤ 23,
///   minute ≤ 59, second ≤ 60 to admit leap seconds);
/// - the year is before 1970.
fn parse_iso8601(s: &str) -> Result<u64, ()> {
    // Local copy of the Gregorian leap-year rule, so this parser is self-contained.
    fn is_leap(year: u64) -> bool {
        (year % 4 == 0 && year % 100 != 0) || year % 400 == 0
    }

    // Parses bytes[start..end] as an unsigned decimal; any non-digit byte is an error.
    fn digits(bytes: &[u8], start: usize, end: usize) -> Result<u64, ()> {
        bytes[start..end].iter().try_fold(0u64, |acc, &b| {
            if b.is_ascii_digit() {
                Ok(acc * 10 + u64::from(b - b'0'))
            } else {
                Err(())
            }
        })
    }

    // Work on bytes: slicing the &str by index would panic on multi-byte UTF-8.
    let bytes = s.as_bytes();
    if bytes.len() < 19 {
        return Err(());
    }
    let separators_ok = bytes[4] == b'-'
        && bytes[7] == b'-'
        && bytes[10] == b'T'
        && bytes[13] == b':'
        && bytes[16] == b':';
    if !separators_ok {
        return Err(());
    }

    let year = digits(bytes, 0, 4)?;
    let month = digits(bytes, 5, 7)?;
    let day = digits(bytes, 8, 10)?;
    let hour = digits(bytes, 11, 13)?;
    let minute = digits(bytes, 14, 16)?;
    let second = digits(bytes, 17, 19)?;

    if year < 1970 || !(1..=12).contains(&month) || hour > 23 || minute > 59 || second > 60 {
        return Err(());
    }

    let mut month_lengths: [u64; 12] = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
    if is_leap(year) {
        month_lengths[1] = 29;
    }
    let month_index = (month - 1) as usize;
    if day == 0 || day > month_lengths[month_index] {
        return Err(());
    }

    let days_before_year: u64 = (1970..year)
        .map(|y| if is_leap(y) { 366u64 } else { 365 })
        .sum();
    let days_before_month: u64 = month_lengths[..month_index].iter().sum();
    let days = days_before_year + days_before_month + (day - 1);

    Ok(days * 86_400 + hour * 3_600 + minute * 60 + second)
}
1305
1306fn deflate_and_encode(xml: &str) -> String {
1308 use base64::{engine::general_purpose::STANDARD, Engine as _};
1309
1310 STANDARD.encode(xml)
1313}
1314
/// Returns the value of attribute `attr` on the first start tag whose local
/// name is `element`.
///
/// The namespace prefix is ignored, so `Response` matches both
/// `<Response ...>` and `<samlp:Response ...>`. Attributes are walked as
/// `name="value"` pairs and names must match exactly, so `ID` never matches
/// `InResponseTo` or `AssertionID`. Values may use double or single quotes.
/// Entity references in the value are returned verbatim, not decoded.
///
/// This is a lightweight scanner, not an XML parser. CDATA sections are not
/// understood, and a `>` inside an attribute value ends the tag early.
fn extract_attribute(xml: &str, element: &str, attr: &str) -> Option<String> {
    // Returns the inside of the first start tag (between `<` and `>`) whose
    // local name equals `element`. Closing tags have an empty name and never match.
    fn find_start_tag<'a>(xml: &'a str, element: &str) -> Option<&'a str> {
        let mut rest = xml;
        while let Some(open) = rest.find('<') {
            let after = &rest[open + 1..];
            let close = after.find('>')?;
            let tag = &after[..close];
            let name_end = tag
                .find(|c: char| c.is_whitespace() || c == '/')
                .unwrap_or(tag.len());
            let name = &tag[..name_end];
            // Strip any namespace prefix: `samlp:Response` -> `Response`.
            let local = name.rsplit(':').next().unwrap_or(name);
            if !name.is_empty() && local == element {
                return Some(tag);
            }
            rest = &after[close + 1..];
        }
        None
    }

    // Walks the `name="value"` pairs after the element name and returns the
    // value of `attr`.
    fn find_attribute<'a>(tag: &'a str, attr: &str) -> Option<&'a str> {
        let mut rest = &tag[tag.find(char::is_whitespace)?..];
        loop {
            rest = rest.trim_start();
            let eq = rest.find('=')?;
            let name = rest[..eq].trim_end();
            let value_part = rest[eq + 1..].trim_start();
            let quote = value_part.chars().next()?;
            if quote != '"' && quote != '\'' {
                return None;
            }
            let value_and_rest = &value_part[1..];
            let close = value_and_rest.find(quote)?;
            if name == attr {
                return Some(&value_and_rest[..close]);
            }
            rest = &value_and_rest[close + 1..];
        }
    }

    let tag = find_start_tag(xml, element)?;
    find_attribute(tag, attr).map(str::to_string)
}
1328
/// Returns the trimmed text content of the first element named `element`.
///
/// How `element` is matched:
/// - It is compared with the tag's full qualified name.
/// - When `element` has no prefix of its own, it is also compared with the local part of
///   a prefixed name.
/// So `"Issuer"` matches `<Issuer>` and `<saml:Issuer>`, but not `<IssuerX>`.
///
/// The start tag may carry attributes, and quoted values may contain `>`. Everything
/// between the start tag and the first end tag with the same qualified name is returned,
/// trimmed but otherwise verbatim; child markup and entities are not interpreted.
/// A self-closing element yields an empty string.
///
/// Returns `None` when no element matches or when the first match has no end tag.
fn extract_element_text(xml: &str, element: &str) -> Option<String> {
    let mut cursor = 0;
    while let Some(offset) = xml[cursor..].find('<') {
        cursor += offset + 1;
        let rest = &xml[cursor..];
        let name_len = rest
            .find(|c: char| c.is_ascii_whitespace() || c == '>' || c == '/')
            .unwrap_or(rest.len());
        let qname = &rest[..name_len];
        let is_match = !qname.is_empty()
            && (qname == element
                || (!element.contains(':')
                    && matches!(qname.split_once(':'), Some((_, local)) if local == element)));
        if !is_match {
            continue;
        }

        // Locate the `>` that closes the start tag, ignoring any inside quoted values.
        let bytes = rest.as_bytes();
        let mut quote: Option<u8> = None;
        let gt = name_len
            + bytes[name_len..].iter().position(|&b| match quote {
                Some(q) => {
                    if b == q {
                        quote = None;
                    }
                    false
                }
                None => {
                    if b == b'"' || b == b'\'' {
                        quote = Some(b);
                    }
                    b == b'>'
                }
            })?;
        if gt > name_len && bytes[gt - 1] == b'/' {
            return Some(String::new());
        }

        // The end tag repeats the qualified name and may have whitespace before `>`.
        let body = &rest[gt + 1..];
        let end_tag = format!("</{}", qname);
        let mut search = 0;
        while let Some(hit) = body[search..].find(&end_tag) {
            let hit = search + hit;
            search = hit + end_tag.len();
            if body[search..].trim_start().starts_with('>') {
                return Some(body[..hit].trim().to_string());
            }
        }
        return None;
    }
    None
}
1361
1362fn extract_status_code(xml: &str) -> Option<SamlStatusCode> {
1364 let pattern = "StatusCode";
1365 let start = xml.find(pattern)?;
1366 let value_start = xml[start..].find("Value=\"")? + start + 7;
1367 let value_end = xml[value_start..].find('"')? + value_start;
1368 let value = &xml[value_start..value_end];
1369
1370 Some(SamlStatusCode::from_urn(value))
1371}
1372
1373fn parse_assertions(xml: &str) -> Result<Vec<SamlAssertion>, SamlError> {
1375 let mut assertions = Vec::new();
1376
1377 let assertion_pattern = "<saml:Assertion";
1379 if let Some(start) = xml.find(assertion_pattern) {
1380 let assertion_xml = &xml[start..];
1381
1382 let id = extract_attribute(assertion_xml, "Assertion", "ID")
1383 .unwrap_or_else(|| format!("_generated_{}", timestamp_millis()));
1384 let issue_instant =
1385 extract_attribute(assertion_xml, "Assertion", "IssueInstant").unwrap_or_default();
1386 let issuer = extract_element_text(assertion_xml, "Issuer").unwrap_or_default();
1387
1388 let name_id = extract_element_text(assertion_xml, "NameID").unwrap_or_default();
1390 let name_id_format = extract_attribute(assertion_xml, "NameID", "Format")
1391 .map(|f| NameIdFormat::from_urn(&f))
1392 .unwrap_or_default();
1393
1394 let not_before = extract_attribute(assertion_xml, "Conditions", "NotBefore");
1396 let not_on_or_after = extract_attribute(assertion_xml, "Conditions", "NotOnOrAfter");
1397
1398 let session_index = extract_attribute(assertion_xml, "AuthnStatement", "SessionIndex");
1400 let session_not_on_or_after =
1401 extract_attribute(assertion_xml, "AuthnStatement", "SessionNotOnOrAfter");
1402
1403 let audiences = extract_element_text(assertion_xml, "Audience")
1405 .map(|a| vec![a])
1406 .unwrap_or_default();
1407
1408 let authn_context_class = extract_element_text(assertion_xml, "AuthnContextClassRef");
1410
1411 let attributes = parse_attributes(assertion_xml);
1413
1414 assertions.push(SamlAssertion {
1415 id,
1416 issue_instant,
1417 issuer,
1418 name_id,
1419 name_id_format,
1420 session_index,
1421 session_not_on_or_after,
1422 not_before,
1423 not_on_or_after,
1424 audiences,
1425 authn_context_class,
1426 attributes,
1427 });
1428 }
1429
1430 Ok(assertions)
1431}
1432
/// Collects `Name` → values pairs from the first `AttributeStatement` in `xml`.
///
/// Tags are matched on their local name, so both unprefixed (`<Attribute>`) and prefixed
/// (`<saml:Attribute>`, `<saml2:Attribute>`) markup is handled.
///
/// For each `Attribute` carrying a `Name`, the trimmed raw text of every `AttributeValue`
/// inside it is collected, and a self-closing `AttributeValue` contributes an empty
/// string. The name lookup is exact, so `FriendlyName` is never mistaken for `Name`.
///
/// - Attributes without a `Name`, or without any values, are skipped.
/// - A later attribute with the same `Name` replaces an earlier one.
/// - Value text is returned verbatim: entities, CDATA sections and nested markup are not
///   interpreted.
fn parse_attributes(xml: &str) -> HashMap<String, Vec<String>> {
    /// A start or end tag located by `next_tag`.
    struct Tag<'a> {
        /// Qualified name as written, e.g. `saml:Attribute`.
        qname: &'a str,
        /// Everything between the name and the terminating `>`.
        attrs: &'a str,
        /// True for `</name>`.
        closing: bool,
        /// True for `<name … />`.
        self_closing: bool,
        /// Byte offset just past the terminating `>`.
        end: usize,
    }

    impl Tag<'_> {
        /// The name without any namespace prefix.
        fn local_name(&self) -> &str {
            self.qname.rsplit(':').next().unwrap_or(self.qname)
        }
    }

    /// Finds the first tag that starts at or after byte offset `from`.
    fn next_tag(xml: &str, from: usize) -> Option<Tag<'_>> {
        let open = from + xml[from..].find('<')?;
        let bytes = xml.as_bytes();
        let closing = bytes.get(open + 1) == Some(&b'/');
        let name_start = open + if closing { 2 } else { 1 };
        let name_end = name_start
            + xml[name_start..]
                .find(|c: char| c.is_ascii_whitespace() || c == '>' || c == '/')
                .unwrap_or(xml.len() - name_start);
        // The tag ends at the first `>` that is not inside a quoted attribute value.
        let mut quote: Option<u8> = None;
        let gt = name_end
            + bytes[name_end..].iter().position(|&b| match quote {
                Some(q) => {
                    if b == q {
                        quote = None;
                    }
                    false
                }
                None => {
                    if b == b'"' || b == b'\'' {
                        quote = Some(b);
                    }
                    b == b'>'
                }
            })?;
        Some(Tag {
            qname: &xml[name_start..name_end],
            attrs: &xml[name_end..gt],
            closing,
            self_closing: gt > name_end && bytes[gt - 1] == b'/',
            end: gt + 1,
        })
    }

    /// Returns the value of attribute `wanted` from a tag's attribute text.
    fn attribute_value<'a>(attrs: &'a str, wanted: &str) -> Option<&'a str> {
        let mut rest = attrs;
        loop {
            rest = rest.trim_start();
            if rest.is_empty() || rest.starts_with('>') || rest.starts_with('/') {
                return None;
            }
            let name_end =
                rest.find(|c: char| c == '=' || c == '>' || c.is_ascii_whitespace())?;
            let name = &rest[..name_end];
            rest = rest[name_end..].trim_start().strip_prefix('=')?.trim_start();
            let quote = rest.chars().next().filter(|&c| c == '"' || c == '\'')?;
            rest = &rest[quote.len_utf8()..];
            let value_end = rest.find(quote)?;
            if name == wanted {
                return Some(&rest[..value_end]);
            }
            rest = &rest[value_end + quote.len_utf8()..];
        }
    }

    /// Records a finished `<Attribute>` if it had a `Name` and at least one value.
    fn finish(
        attributes: &mut HashMap<String, Vec<String>>,
        open: Option<(Option<String>, Vec<String>)>,
    ) {
        if let Some((Some(name), values)) = open {
            if !values.is_empty() {
                attributes.insert(name, values);
            }
        }
    }

    let mut attributes = HashMap::new();

    // Skip ahead to just past the first AttributeStatement start tag.
    let mut pos = 0;
    loop {
        match next_tag(xml, pos) {
            None => return attributes,
            Some(tag) => {
                pos = tag.end;
                if !tag.closing && tag.local_name() == "AttributeStatement" {
                    if tag.self_closing {
                        return attributes;
                    }
                    break;
                }
            }
        }
    }

    // The <Attribute> currently open: its Name (if any) and the values seen so far.
    let mut open: Option<(Option<String>, Vec<String>)> = None;

    while let Some(tag) = next_tag(xml, pos) {
        pos = tag.end;
        match (tag.closing, tag.local_name()) {
            (true, "AttributeStatement") => break,
            (false, "Attribute") => {
                finish(&mut attributes, open.take());
                if !tag.self_closing {
                    let name = attribute_value(tag.attrs, "Name").map(str::to_string);
                    open = Some((name, Vec::new()));
                }
            }
            (true, "Attribute") => finish(&mut attributes, open.take()),
            (false, "AttributeValue") => {
                let text = if tag.self_closing {
                    Some(String::new())
                } else {
                    // Take the raw text up to the matching end tag and resume scanning there.
                    let end_tag = format!("</{}", tag.qname);
                    xml[pos..].find(&end_tag).map(|len| {
                        let start = pos;
                        pos += len;
                        xml[start..pos].trim().to_string()
                    })
                };
                if let (Some(text), Some((_, values))) = (text, open.as_mut()) {
                    values.push(text);
                }
            }
            _ => {}
        }
    }
    finish(&mut attributes, open);

    attributes
}
1486
#[cfg(test)]
mod tests {
    //! Unit tests covering:
    //! - URN mappings;
    //! - configuration builders and presets;
    //! - request and metadata generation;
    //! - assertion validation;
    //! - the lightweight XML and timestamp helpers.

    use super::*;

    /// Baseline SP configuration shared by most tests. Entity ID, IdP SSO URL and ACS URL
    /// are set in that order.
    fn sp_config() -> SamlConfig {
        SamlConfig::new()
            .entity_id("https://sp.example.com")
            .idp_sso_url("https://idp.example.com/sso")
            .acs_url("https://sp.example.com/acs")
    }

    /// An assertion about `user@example.com`, issued now by `https://idp.example.com`
    /// and restricted to the single `audience`.
    fn sample_assertion(audience: &str, session_index: Option<&str>) -> SamlAssertion {
        SamlAssertion {
            id: "_test".to_string(),
            issue_instant: iso8601_now(),
            issuer: "https://idp.example.com".to_string(),
            name_id: "user@example.com".to_string(),
            name_id_format: NameIdFormat::EmailAddress,
            session_index: session_index.map(str::to_string),
            session_not_on_or_after: None,
            not_before: None,
            not_on_or_after: None,
            audiences: vec![audience.to_string()],
            authn_context_class: None,
            attributes: HashMap::new(),
        }
    }

    #[test]
    fn test_name_id_format() {
        let email_urn = "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress";
        assert_eq!(NameIdFormat::EmailAddress.as_urn(), email_urn);

        let persistent_urn = "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent";
        assert_eq!(NameIdFormat::from_urn(persistent_urn), NameIdFormat::Persistent);
    }

    #[test]
    fn test_saml_binding() {
        let expected = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST";
        assert_eq!(SamlBinding::HttpPost.as_urn(), expected);
    }

    #[test]
    fn test_status_code() {
        let cases = [
            ("urn:oasis:names:tc:SAML:2.0:status:Success", true),
            ("urn:oasis:names:tc:SAML:2.0:status:AuthnFailed", false),
        ];
        for (urn, should_succeed) in cases {
            assert_eq!(SamlStatusCode::from_urn(urn).is_success(), should_succeed, "{}", urn);
        }
    }

    #[test]
    fn test_config_builder() {
        let metadata_url = "https://sp.example.com/saml/metadata";
        let config = SamlConfig::new()
            .entity_id(metadata_url)
            .idp_sso_url("https://idp.example.com/saml/sso")
            .acs_url("https://sp.example.com/saml/acs")
            .name_id_format(NameIdFormat::EmailAddress)
            .want_assertions_signed(true);

        assert_eq!(config.entity_id, metadata_url);
        assert_eq!(config.name_id_format, NameIdFormat::EmailAddress);
    }

    #[test]
    fn test_config_validation() {
        assert!(SamlConfig::new().validate().is_err());
        assert!(sp_config().want_assertions_signed(false).validate().is_ok());
    }

    #[test]
    fn test_config_presets() {
        let presets = [
            (
                SamlConfig::okta("myorg.okta.com", "app123", "https://myapp.com"),
                "okta.com",
            ),
            (
                SamlConfig::azure_ad("tenant-id", "app-id", "https://myapp.com"),
                "microsoftonline.com",
            ),
            (
                SamlConfig::adfs("adfs.company.com", "https://myapp.com"),
                "adfs",
            ),
        ];
        for (config, host_fragment) in &presets {
            assert!(
                config.idp_sso_url.contains(*host_fragment),
                "SSO URL should mention {}",
                host_fragment
            );
        }
    }

    #[test]
    fn test_authn_request_generation() {
        let config = sp_config().name_id_format(NameIdFormat::EmailAddress);
        let request = AuthnRequest::new(&config);
        let xml = request.to_xml();

        for needle in ["AuthnRequest", "https://sp.example.com", "emailAddress"] {
            assert!(xml.contains(needle), "request XML should contain {}", needle);
        }
    }

    #[test]
    fn test_authn_request_url() {
        let config = sp_config();
        let request = AuthnRequest::new(&config);
        let url = request.to_redirect_url(Some("/dashboard"));

        assert!(url.starts_with("https://idp.example.com/sso?"));
        for param in ["SAMLRequest=", "RelayState="] {
            assert!(url.contains(param), "redirect URL should contain {}", param);
        }
    }

    #[test]
    fn test_assertion_validation() {
        let config = sp_config()
            .idp_entity_id("https://idp.example.com")
            .max_clock_skew(Duration::from_secs(300));
        let assertion = sample_assertion("https://sp.example.com", Some("_session123"));

        assert!(assertion.validate(&config).is_ok());
    }

    #[test]
    fn test_assertion_audience_validation() {
        let config = sp_config();
        let assertion = sample_assertion("https://other-sp.example.com", None);

        let outcome = assertion.validate(&config);
        assert!(matches!(
            outcome,
            Err(SamlError::AudienceRestrictionNotMet(_))
        ));
    }

    #[test]
    fn test_authenticator_creation() {
        let config = sp_config().want_assertions_signed(false);
        assert!(SamlAuthenticator::new(config).is_ok());
    }

    #[test]
    fn test_metadata_generation() {
        let config = sp_config()
            .sls_url("https://sp.example.com/sls")
            .want_assertions_signed(false);
        let authenticator = SamlAuthenticator::new(config).expect("valid configuration");
        let metadata = authenticator.generate_metadata();

        for needle in [
            "EntityDescriptor",
            "https://sp.example.com",
            "AssertionConsumerService",
            "SingleLogoutService",
        ] {
            assert!(metadata.contains(needle), "metadata should contain {}", needle);
        }
    }

    #[test]
    fn test_iso8601_generation() {
        let stamp = iso8601_now();
        assert_eq!(stamp.len(), 20);
        assert!(stamp.contains('T'));
        assert!(stamp.ends_with('Z'));
    }

    #[test]
    fn test_iso8601_parsing() {
        assert!(parse_iso8601("2024-01-15T10:30:00Z").is_ok());
        assert!(parse_iso8601("invalid").is_err());
    }

    #[test]
    fn test_xml_attribute_extraction() {
        let tag = r#"<Response ID="resp123" Version="2.0">"#;
        let id = extract_attribute(tag, "Response", "ID");
        assert_eq!(id.as_deref(), Some("resp123"));
    }

    #[test]
    fn test_attribute_parsing() {
        let xml = r#"
            <AttributeStatement>
                <Attribute Name="email">
                    <AttributeValue>user@example.com</AttributeValue>
                </Attribute>
                <Attribute Name="roles">
                    <AttributeValue>admin</AttributeValue>
                    <AttributeValue>user</AttributeValue>
                </Attribute>
            </AttributeStatement>
        "#;

        let attrs = parse_attributes(xml);
        let owned = |values: &[&str]| values.iter().map(|v| v.to_string()).collect::<Vec<_>>();
        assert_eq!(attrs.get("email"), Some(&owned(&["user@example.com"])));
        assert_eq!(attrs.get("roles"), Some(&owned(&["admin", "user"])));
    }

    #[test]
    fn test_saml_error_display() {
        let rendered = SamlError::Configuration("test error".to_string()).to_string();
        assert!(rendered.contains("test error"));
    }

    #[test]
    fn test_logout_request() {
        let config = SamlConfig::new()
            .entity_id("https://sp.example.com")
            .idp_sso_url("https://idp.example.com/sso")
            .idp_slo_url("https://idp.example.com/slo")
            .acs_url("https://sp.example.com/acs")
            .want_assertions_signed(false);

        let authenticator = SamlAuthenticator::new(config).expect("valid configuration");
        let url = authenticator
            .create_logout_request("user@example.com", Some("_session123"))
            .expect("an IdP SLO URL is configured");
        assert!(url.contains("SAMLRequest="));
    }
}