1use std::collections::HashMap;
38use std::sync::Arc;
39use std::time::{Duration, SystemTime, UNIX_EPOCH};
40
41use crate::http::security::User;
42
/// A SAML `NameID` format identifier.
///
/// Every variant except [`NameIdFormat::Custom`] maps to one fixed format URN;
/// any other URN is carried verbatim by `Custom`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum NameIdFormat {
    /// `urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified` (the default).
    #[default]
    Unspecified,
    /// `urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress`.
    EmailAddress,
    /// `urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName`.
    X509SubjectName,
    /// `urn:oasis:names:tc:SAML:1.1:nameid-format:WindowsDomainQualifiedName`.
    WindowsDomainQualifiedName,
    /// `urn:oasis:names:tc:SAML:2.0:nameid-format:kerberos`.
    Kerberos,
    /// `urn:oasis:names:tc:SAML:2.0:nameid-format:persistent`.
    Persistent,
    /// `urn:oasis:names:tc:SAML:2.0:nameid-format:transient`.
    Transient,
    /// Any format URN not listed above, stored as given.
    Custom(String),
}

impl NameIdFormat {
    /// Returns the format URN for this variant (the stored string for `Custom`).
    pub fn as_urn(&self) -> &str {
        match self {
            Self::Unspecified => "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified",
            Self::EmailAddress => "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
            Self::X509SubjectName => "urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName",
            Self::WindowsDomainQualifiedName => {
                "urn:oasis:names:tc:SAML:1.1:nameid-format:WindowsDomainQualifiedName"
            }
            Self::Kerberos => "urn:oasis:names:tc:SAML:2.0:nameid-format:kerberos",
            Self::Persistent => "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent",
            Self::Transient => "urn:oasis:names:tc:SAML:2.0:nameid-format:transient",
            Self::Custom(urn) => urn.as_str(),
        }
    }

    /// Maps a format URN back to its variant.
    ///
    /// Comparison is exact (case-sensitive); unrecognised URNs become `Custom`.
    pub fn from_urn(urn: &str) -> Self {
        // Every variant with a fixed URN; `as_urn` is the single source of truth.
        let fixed = [
            Self::Unspecified,
            Self::EmailAddress,
            Self::X509SubjectName,
            Self::WindowsDomainQualifiedName,
            Self::Kerberos,
            Self::Persistent,
            Self::Transient,
        ];
        fixed
            .iter()
            .find(|candidate| candidate.as_urn() == urn)
            .cloned()
            .unwrap_or_else(|| Self::Custom(urn.to_string()))
    }
}
112
113
/// A SAML 2.0 protocol binding (how messages are transported).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SamlBinding {
    /// `HTTP-Redirect`: message carried in the query string (the default).
    #[default]
    HttpRedirect,
    /// `HTTP-POST`: message carried in an HTML form post.
    HttpPost,
    /// `HTTP-Artifact`: a reference is sent and resolved out of band.
    HttpArtifact,
    /// `SOAP`: back-channel SOAP messaging.
    Soap,
}

impl SamlBinding {
    /// Returns the binding's `urn:oasis:names:tc:SAML:2.0:bindings:*` URN.
    pub fn as_urn(&self) -> &str {
        let urn = match *self {
            Self::HttpRedirect => "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
            Self::HttpPost => "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
            Self::HttpArtifact => "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Artifact",
            Self::Soap => "urn:oasis:names:tc:SAML:2.0:bindings:SOAP",
        };
        urn
    }
}
139
/// A requested SAML 2.0 authentication context class.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum AuthnContextClass {
    /// `...:ac:classes:unspecified` (the default).
    #[default]
    Unspecified,
    /// `...:ac:classes:Password`.
    Password,
    /// `...:ac:classes:PasswordProtectedTransport`.
    PasswordProtectedTransport,
    /// `...:ac:classes:X509`.
    X509,
    /// `...:ac:classes:Kerberos`.
    Kerberos,
    /// `...:ac:classes:MultiFactor`.
    MultiFactor,
    /// Any other class reference URN, stored as given.
    Custom(String),
}

impl AuthnContextClass {
    /// Returns the `AuthnContextClassRef` URN for this class.
    ///
    /// For `Custom`, this is the stored string.
    pub fn as_urn(&self) -> &str {
        use AuthnContextClass as C;
        match self {
            C::Custom(urn) => urn.as_str(),
            C::Unspecified => "urn:oasis:names:tc:SAML:2.0:ac:classes:unspecified",
            C::Password => "urn:oasis:names:tc:SAML:2.0:ac:classes:Password",
            C::PasswordProtectedTransport => {
                "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport"
            }
            C::X509 => "urn:oasis:names:tc:SAML:2.0:ac:classes:X509",
            C::Kerberos => "urn:oasis:names:tc:SAML:2.0:ac:classes:Kerberos",
            C::MultiFactor => "urn:oasis:names:tc:SAML:2.0:ac:classes:MultiFactor",
        }
    }
}
180
/// Top-level status code carried in a SAML `<samlp:StatusCode>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SamlStatusCode {
    /// `...:status:Success`.
    Success,
    /// `...:status:Requester`.
    Requester,
    /// `...:status:Responder`.
    Responder,
    /// `...:status:VersionMismatch`.
    VersionMismatch,
    /// `...:status:AuthnFailed`.
    AuthnFailed,
    /// `...:status:InvalidNameIDPolicy`.
    InvalidNameIdPolicy,
    /// `...:status:NoAuthnContext`.
    NoAuthnContext,
    /// Any other value, kept verbatim (the full string, not just a suffix).
    Unknown(String),
}

impl SamlStatusCode {
    /// Parses a status-code URN; unrecognised values become `Unknown(urn)`.
    pub fn from_urn(urn: &str) -> Self {
        const STATUS_PREFIX: &str = "urn:oasis:names:tc:SAML:2.0:status:";

        let recognised = match urn.strip_prefix(STATUS_PREFIX) {
            Some("Success") => Some(Self::Success),
            Some("Requester") => Some(Self::Requester),
            Some("Responder") => Some(Self::Responder),
            Some("VersionMismatch") => Some(Self::VersionMismatch),
            Some("AuthnFailed") => Some(Self::AuthnFailed),
            Some("InvalidNameIDPolicy") => Some(Self::InvalidNameIdPolicy),
            Some("NoAuthnContext") => Some(Self::NoAuthnContext),
            _ => None,
        };
        recognised.unwrap_or_else(|| Self::Unknown(urn.to_string()))
    }

    /// `true` only for [`SamlStatusCode::Success`].
    pub fn is_success(&self) -> bool {
        *self == Self::Success
    }
}
224
225#[derive(Debug, Clone)]
227pub struct SamlConfig {
228 pub entity_id: String,
230 pub idp_sso_url: String,
232 pub idp_slo_url: Option<String>,
234 pub idp_entity_id: Option<String>,
236 pub idp_certificate: Option<String>,
238 pub sp_private_key: Option<String>,
240 pub sp_certificate: Option<String>,
242 pub acs_url: String,
244 pub sls_url: Option<String>,
246 pub sso_binding: SamlBinding,
248 pub slo_binding: SamlBinding,
250 pub name_id_format: NameIdFormat,
252 pub authn_context_class: Option<AuthnContextClass>,
254 pub sign_authn_request: bool,
256 pub want_assertions_signed: bool,
258 pub want_assertions_encrypted: bool,
260 pub max_clock_skew: Duration,
262 pub attribute_mapping: HashMap<String, String>,
264 pub role_attribute: Option<String>,
266 pub authority_attribute: Option<String>,
268 pub default_roles: Vec<String>,
270 pub allow_unsolicited_responses: bool,
272 pub session_timeout: Duration,
274}
275
276impl SamlConfig {
277 pub fn new() -> Self {
279 Self {
280 entity_id: String::new(),
281 idp_sso_url: String::new(),
282 idp_slo_url: None,
283 idp_entity_id: None,
284 idp_certificate: None,
285 sp_private_key: None,
286 sp_certificate: None,
287 acs_url: String::new(),
288 sls_url: None,
289 sso_binding: SamlBinding::HttpRedirect,
290 slo_binding: SamlBinding::HttpRedirect,
291 name_id_format: NameIdFormat::Unspecified,
292 authn_context_class: None,
293 sign_authn_request: false,
294 want_assertions_signed: true,
295 want_assertions_encrypted: false,
296 max_clock_skew: Duration::from_secs(120),
297 attribute_mapping: HashMap::new(),
298 role_attribute: None,
299 authority_attribute: None,
300 default_roles: vec!["USER".to_string()],
301 allow_unsolicited_responses: false,
302 session_timeout: Duration::from_secs(3600),
303 }
304 }
305
306 pub fn entity_id(mut self, entity_id: impl Into<String>) -> Self {
308 self.entity_id = entity_id.into();
309 self
310 }
311
312 pub fn idp_sso_url(mut self, url: impl Into<String>) -> Self {
314 self.idp_sso_url = url.into();
315 self
316 }
317
318 pub fn idp_slo_url(mut self, url: impl Into<String>) -> Self {
320 self.idp_slo_url = Some(url.into());
321 self
322 }
323
324 pub fn idp_entity_id(mut self, entity_id: impl Into<String>) -> Self {
326 self.idp_entity_id = Some(entity_id.into());
327 self
328 }
329
330 pub fn idp_certificate(mut self, cert: impl Into<String>) -> Self {
332 self.idp_certificate = Some(cert.into());
333 self
334 }
335
336 pub fn sp_private_key(mut self, key: impl Into<String>) -> Self {
338 self.sp_private_key = Some(key.into());
339 self
340 }
341
342 pub fn sp_certificate(mut self, cert: impl Into<String>) -> Self {
344 self.sp_certificate = Some(cert.into());
345 self
346 }
347
348 pub fn acs_url(mut self, url: impl Into<String>) -> Self {
350 self.acs_url = url.into();
351 self
352 }
353
354 pub fn assertion_consumer_service_url(self, url: impl Into<String>) -> Self {
356 self.acs_url(url)
357 }
358
359 pub fn sls_url(mut self, url: impl Into<String>) -> Self {
361 self.sls_url = Some(url.into());
362 self
363 }
364
365 pub fn sso_binding(mut self, binding: SamlBinding) -> Self {
367 self.sso_binding = binding;
368 self
369 }
370
371 pub fn slo_binding(mut self, binding: SamlBinding) -> Self {
373 self.slo_binding = binding;
374 self
375 }
376
377 pub fn name_id_format(mut self, format: NameIdFormat) -> Self {
379 self.name_id_format = format;
380 self
381 }
382
383 pub fn authn_context_class(mut self, class: AuthnContextClass) -> Self {
385 self.authn_context_class = Some(class);
386 self
387 }
388
389 pub fn sign_authn_request(mut self, sign: bool) -> Self {
391 self.sign_authn_request = sign;
392 self
393 }
394
395 pub fn want_assertions_signed(mut self, signed: bool) -> Self {
397 self.want_assertions_signed = signed;
398 self
399 }
400
401 pub fn want_assertions_encrypted(mut self, encrypted: bool) -> Self {
403 self.want_assertions_encrypted = encrypted;
404 self
405 }
406
407 pub fn max_clock_skew(mut self, skew: Duration) -> Self {
409 self.max_clock_skew = skew;
410 self
411 }
412
413 pub fn map_attribute(
415 mut self,
416 saml_attribute: impl Into<String>,
417 user_field: impl Into<String>,
418 ) -> Self {
419 self.attribute_mapping
420 .insert(saml_attribute.into(), user_field.into());
421 self
422 }
423
424 pub fn role_attribute(mut self, attr: impl Into<String>) -> Self {
426 self.role_attribute = Some(attr.into());
427 self
428 }
429
430 pub fn authority_attribute(mut self, attr: impl Into<String>) -> Self {
432 self.authority_attribute = Some(attr.into());
433 self
434 }
435
436 pub fn default_roles(mut self, roles: Vec<String>) -> Self {
438 self.default_roles = roles;
439 self
440 }
441
442 pub fn allow_unsolicited_responses(mut self, allow: bool) -> Self {
444 self.allow_unsolicited_responses = allow;
445 self
446 }
447
448 pub fn session_timeout(mut self, timeout: Duration) -> Self {
450 self.session_timeout = timeout;
451 self
452 }
453
454 pub fn okta(
456 okta_domain: impl Into<String>,
457 app_id: impl Into<String>,
458 sp_entity_id: impl Into<String>,
459 ) -> Self {
460 let domain = okta_domain.into();
461 let app = app_id.into();
462 Self::new()
463 .entity_id(sp_entity_id)
464 .idp_sso_url(format!("https://{}/app/{}/sso/saml", domain, app))
465 .idp_entity_id(format!("http://www.okta.com/{}", app))
466 .name_id_format(NameIdFormat::EmailAddress)
467 .sso_binding(SamlBinding::HttpPost)
468 }
469
470 pub fn azure_ad(
472 tenant_id: impl Into<String>,
473 _app_id: impl Into<String>,
474 sp_entity_id: impl Into<String>,
475 ) -> Self {
476 let tenant = tenant_id.into();
477 Self::new()
478 .entity_id(sp_entity_id)
479 .idp_sso_url(format!(
480 "https://login.microsoftonline.com/{}/saml2",
481 tenant
482 ))
483 .idp_entity_id(format!(
484 "https://sts.windows.net/{}/",
485 tenant
486 ))
487 .name_id_format(NameIdFormat::EmailAddress)
488 .sso_binding(SamlBinding::HttpRedirect)
489 .map_attribute(
490 "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress",
491 "email",
492 )
493 .map_attribute(
494 "http://schemas.microsoft.com/ws/2008/06/identity/claims/groups",
495 "groups",
496 )
497 }
498
499 pub fn google_workspace(sp_entity_id: impl Into<String>, acs_url: impl Into<String>) -> Self {
501 Self::new()
502 .entity_id(sp_entity_id)
503 .idp_sso_url("https://accounts.google.com/o/saml2/idp")
504 .acs_url(acs_url)
505 .name_id_format(NameIdFormat::EmailAddress)
506 .sso_binding(SamlBinding::HttpRedirect)
507 }
508
509 pub fn adfs(adfs_host: impl Into<String>, sp_entity_id: impl Into<String>) -> Self {
511 let host = adfs_host.into();
512 Self::new()
513 .entity_id(sp_entity_id)
514 .idp_sso_url(format!("https://{}/adfs/ls/", host))
515 .idp_entity_id(format!("http://{}/adfs/services/trust", host))
516 .name_id_format(NameIdFormat::Unspecified)
517 .sso_binding(SamlBinding::HttpPost)
518 }
519
520 pub fn validate(&self) -> Result<(), SamlError> {
522 if self.entity_id.is_empty() {
523 return Err(SamlError::Configuration("entity_id is required".into()));
524 }
525 if self.idp_sso_url.is_empty() {
526 return Err(SamlError::Configuration("idp_sso_url is required".into()));
527 }
528 if self.acs_url.is_empty() {
529 return Err(SamlError::Configuration("acs_url is required".into()));
530 }
531 if self.sign_authn_request && self.sp_private_key.is_none() {
532 return Err(SamlError::Configuration(
533 "sp_private_key is required when sign_authn_request is true".into(),
534 ));
535 }
536 if self.want_assertions_signed && self.idp_certificate.is_none() {
537 return Err(SamlError::Configuration(
538 "idp_certificate is required when want_assertions_signed is true".into(),
539 ));
540 }
541 Ok(())
542 }
543}
544
545impl Default for SamlConfig {
546 fn default() -> Self {
547 Self::new()
548 }
549}
550
551#[derive(Debug, Clone)]
553pub enum SamlError {
554 Configuration(String),
556 InvalidResponse(String),
558 SignatureVerificationFailed(String),
560 AssertionValidationFailed(String),
562 TimeConditionNotMet(String),
564 AudienceRestrictionNotMet(String),
566 MissingAttribute(String),
568 IdpError(SamlStatusCode, Option<String>),
570 DecryptionFailed(String),
572 XmlParsingError(String),
574 NetworkError(String),
576}
577
578impl std::fmt::Display for SamlError {
579 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
580 match self {
581 SamlError::Configuration(msg) => write!(f, "SAML configuration error: {}", msg),
582 SamlError::InvalidResponse(msg) => write!(f, "Invalid SAML response: {}", msg),
583 SamlError::SignatureVerificationFailed(msg) => {
584 write!(f, "SAML signature verification failed: {}", msg)
585 }
586 SamlError::AssertionValidationFailed(msg) => {
587 write!(f, "SAML assertion validation failed: {}", msg)
588 }
589 SamlError::TimeConditionNotMet(msg) => {
590 write!(f, "SAML time condition not met: {}", msg)
591 }
592 SamlError::AudienceRestrictionNotMet(msg) => {
593 write!(f, "SAML audience restriction not met: {}", msg)
594 }
595 SamlError::MissingAttribute(attr) => {
596 write!(f, "Required SAML attribute missing: {}", attr)
597 }
598 SamlError::IdpError(code, msg) => {
599 write!(f, "IdP returned error {:?}: {:?}", code, msg)
600 }
601 SamlError::DecryptionFailed(msg) => write!(f, "SAML decryption failed: {}", msg),
602 SamlError::XmlParsingError(msg) => write!(f, "XML parsing error: {}", msg),
603 SamlError::NetworkError(msg) => write!(f, "Network error: {}", msg),
604 }
605 }
606}
607
608impl std::error::Error for SamlError {}
609
610#[derive(Debug, Clone)]
612pub struct AuthnRequest {
613 pub id: String,
615 pub issue_instant: String,
617 pub issuer: String,
619 pub destination: String,
621 pub acs_url: String,
623 pub protocol_binding: SamlBinding,
625 pub name_id_format: NameIdFormat,
627 pub authn_context: Option<AuthnContextClass>,
629 pub force_authn: bool,
631 pub is_passive: bool,
633}
634
635impl AuthnRequest {
636 pub fn new(config: &SamlConfig) -> Self {
638 let id = format!("_{}_{}", generate_id(), timestamp_millis());
639
640 Self {
641 id,
642 issue_instant: iso8601_now(),
643 issuer: config.entity_id.clone(),
644 destination: config.idp_sso_url.clone(),
645 acs_url: config.acs_url.clone(),
646 protocol_binding: SamlBinding::HttpPost,
647 name_id_format: config.name_id_format.clone(),
648 authn_context: config.authn_context_class.clone(),
649 force_authn: false,
650 is_passive: false,
651 }
652 }
653
654 pub fn force_authn(mut self, force: bool) -> Self {
656 self.force_authn = force;
657 self
658 }
659
660 pub fn is_passive(mut self, passive: bool) -> Self {
662 self.is_passive = passive;
663 self
664 }
665
666 pub fn to_xml(&self) -> String {
668 let mut xml = String::new();
669 xml.push_str(r#"<?xml version="1.0" encoding="UTF-8"?>"#);
670 xml.push_str(&format!(
671 r#"<samlp:AuthnRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion" ID="{}" Version="2.0" IssueInstant="{}" Destination="{}" AssertionConsumerServiceURL="{}" ProtocolBinding="{}""#,
672 self.id,
673 self.issue_instant,
674 self.destination,
675 self.acs_url,
676 self.protocol_binding.as_urn()
677 ));
678
679 if self.force_authn {
680 xml.push_str(r#" ForceAuthn="true""#);
681 }
682 if self.is_passive {
683 xml.push_str(r#" IsPassive="true""#);
684 }
685
686 xml.push('>');
687
688 xml.push_str(&format!(
690 r#"<saml:Issuer>{}</saml:Issuer>"#,
691 self.issuer
692 ));
693
694 xml.push_str(&format!(
696 r#"<samlp:NameIDPolicy Format="{}" AllowCreate="true"/>"#,
697 self.name_id_format.as_urn()
698 ));
699
700 if let Some(ref authn_context) = self.authn_context {
702 xml.push_str(r#"<samlp:RequestedAuthnContext Comparison="exact">"#);
703 xml.push_str(&format!(
704 r#"<saml:AuthnContextClassRef>{}</saml:AuthnContextClassRef>"#,
705 authn_context.as_urn()
706 ));
707 xml.push_str(r#"</samlp:RequestedAuthnContext>"#);
708 }
709
710 xml.push_str(r#"</samlp:AuthnRequest>"#);
711 xml
712 }
713
714 pub fn to_redirect_url(&self, relay_state: Option<&str>) -> String {
716 let xml = self.to_xml();
717 let deflated = deflate_and_encode(&xml);
718
719 let mut url = format!(
720 "{}?SAMLRequest={}",
721 self.destination,
722 urlencoding::encode(&deflated)
723 );
724
725 if let Some(state) = relay_state {
726 url.push_str(&format!("&RelayState={}", urlencoding::encode(state)));
727 }
728
729 url
730 }
731}
732
733#[derive(Debug, Clone)]
735pub struct SamlAssertion {
736 pub id: String,
738 pub issue_instant: String,
740 pub issuer: String,
742 pub name_id: String,
744 pub name_id_format: NameIdFormat,
746 pub session_index: Option<String>,
748 pub session_not_on_or_after: Option<String>,
750 pub not_before: Option<String>,
752 pub not_on_or_after: Option<String>,
754 pub audiences: Vec<String>,
756 pub authn_context_class: Option<String>,
758 pub attributes: HashMap<String, Vec<String>>,
760}
761
762impl SamlAssertion {
763 pub fn get_attribute(&self, name: &str) -> Option<&str> {
765 self.attributes
766 .get(name)
767 .and_then(|values| values.first())
768 .map(|s| s.as_str())
769 }
770
771 pub fn get_attribute_values(&self, name: &str) -> Option<&Vec<String>> {
773 self.attributes.get(name)
774 }
775
776 pub fn validate(&self, config: &SamlConfig) -> Result<(), SamlError> {
778 let now = SystemTime::now()
779 .duration_since(UNIX_EPOCH)
780 .unwrap()
781 .as_secs();
782
783 if let Some(ref not_before) = self.not_before {
785 if let Ok(nb_time) = parse_iso8601(not_before) {
786 let skew = config.max_clock_skew.as_secs();
787 if now + skew < nb_time {
788 return Err(SamlError::TimeConditionNotMet(format!(
789 "Assertion not valid before {}",
790 not_before
791 )));
792 }
793 }
794 }
795
796 if let Some(ref not_on_or_after) = self.not_on_or_after {
798 if let Ok(noa_time) = parse_iso8601(not_on_or_after) {
799 let skew = config.max_clock_skew.as_secs();
800 if now > noa_time + skew {
801 return Err(SamlError::TimeConditionNotMet(format!(
802 "Assertion expired at {}",
803 not_on_or_after
804 )));
805 }
806 }
807 }
808
809 if !self.audiences.is_empty() && !self.audiences.contains(&config.entity_id) {
811 return Err(SamlError::AudienceRestrictionNotMet(format!(
812 "SP entity ID {} not in audiences: {:?}",
813 config.entity_id, self.audiences
814 )));
815 }
816
817 if let Some(ref expected_issuer) = config.idp_entity_id {
819 if &self.issuer != expected_issuer {
820 return Err(SamlError::AssertionValidationFailed(format!(
821 "Issuer mismatch: expected {}, got {}",
822 expected_issuer, self.issuer
823 )));
824 }
825 }
826
827 Ok(())
828 }
829}
830
831#[derive(Debug, Clone)]
833pub struct SamlResponse {
834 pub id: String,
836 pub in_response_to: Option<String>,
838 pub issue_instant: String,
840 pub destination: Option<String>,
842 pub issuer: String,
844 pub status_code: SamlStatusCode,
846 pub status_message: Option<String>,
848 pub assertions: Vec<SamlAssertion>,
850}
851
852impl SamlResponse {
853 pub fn from_base64(encoded: &str) -> Result<Self, SamlError> {
858 use base64::{engine::general_purpose::STANDARD, Engine as _};
859
860 let decoded = STANDARD
861 .decode(encoded)
862 .map_err(|e| SamlError::InvalidResponse(format!("Base64 decode error: {}", e)))?;
863
864 let xml = String::from_utf8(decoded)
865 .map_err(|e| SamlError::InvalidResponse(format!("UTF-8 decode error: {}", e)))?;
866
867 Self::from_xml(&xml)
868 }
869
870 pub fn from_xml(xml: &str) -> Result<Self, SamlError> {
875 let id = extract_attribute(xml, "Response", "ID")
879 .ok_or_else(|| SamlError::XmlParsingError("Missing Response ID".into()))?;
880
881 let in_response_to = extract_attribute(xml, "Response", "InResponseTo");
882 let issue_instant = extract_attribute(xml, "Response", "IssueInstant")
883 .ok_or_else(|| SamlError::XmlParsingError("Missing IssueInstant".into()))?;
884 let destination = extract_attribute(xml, "Response", "Destination");
885 let issuer = extract_element_text(xml, "Issuer")
886 .ok_or_else(|| SamlError::XmlParsingError("Missing Issuer".into()))?;
887
888 let status_code = extract_status_code(xml).unwrap_or(SamlStatusCode::Unknown(String::new()));
890 let status_message = extract_element_text(xml, "StatusMessage");
891
892 let assertions = parse_assertions(xml)?;
894
895 Ok(Self {
896 id,
897 in_response_to,
898 issue_instant,
899 destination,
900 issuer,
901 status_code,
902 status_message,
903 assertions,
904 })
905 }
906
907 pub fn is_success(&self) -> bool {
909 self.status_code.is_success()
910 }
911
912 pub fn assertion(&self) -> Option<&SamlAssertion> {
914 self.assertions.first()
915 }
916}
917
918#[derive(Clone)]
920pub struct SamlAuthenticator {
921 config: Arc<SamlConfig>,
922 pending_requests: Arc<std::sync::RwLock<HashMap<String, PendingRequest>>>,
923}
924
/// Bookkeeping for an AuthnRequest that is waiting for the IdP's answer.
#[derive(Debug, Clone)]
// Some fields are recorded for diagnostics but never read, hence the allow.
#[allow(dead_code)]
struct PendingRequest {
    /// The AuthnRequest ID (also the map key).
    id: String,
    /// Creation time, in whole seconds since the Unix epoch.
    created_at: u64,
    /// `RelayState` supplied when the request was stored, if any.
    relay_state: Option<String>,
}
936
937impl SamlAuthenticator {
938 pub fn new(config: SamlConfig) -> Result<Self, SamlError> {
940 config.validate()?;
941 Ok(Self {
942 config: Arc::new(config),
943 pending_requests: Arc::new(std::sync::RwLock::new(HashMap::new())),
944 })
945 }
946
947 pub fn config(&self) -> &SamlConfig {
949 &self.config
950 }
951
952 pub fn create_authn_request(&self) -> AuthnRequest {
954 AuthnRequest::new(&self.config)
955 }
956
957 pub fn store_pending_request(&self, request: &AuthnRequest, relay_state: Option<String>) {
959 let mut pending = self.pending_requests.write().unwrap();
960 pending.insert(
961 request.id.clone(),
962 PendingRequest {
963 id: request.id.clone(),
964 created_at: timestamp_millis() / 1000,
965 relay_state,
966 },
967 );
968
969 let now = timestamp_millis() / 1000;
971 pending.retain(|_, req| now - req.created_at < 600);
972 }
973
974 pub fn initiate_login(&self, relay_state: Option<&str>) -> String {
976 let request = self.create_authn_request();
977 self.store_pending_request(&request, relay_state.map(|s| s.to_string()));
978 request.to_redirect_url(relay_state)
979 }
980
981 pub fn process_response(&self, encoded_response: &str) -> Result<SamlAuthResult, SamlError> {
983 let response = SamlResponse::from_base64(encoded_response)?;
984
985 self.validate_response(&response)?;
987
988 let assertion = response
990 .assertion()
991 .ok_or_else(|| SamlError::InvalidResponse("No assertion in response".into()))?;
992
993 assertion.validate(&self.config)?;
995
996 let user = self.map_assertion_to_user(assertion)?;
998
999 if let Some(ref in_response_to) = response.in_response_to {
1001 let mut pending = self.pending_requests.write().unwrap();
1002 pending.remove(in_response_to);
1003 }
1004
1005 Ok(SamlAuthResult {
1006 user,
1007 session_index: assertion.session_index.clone(),
1008 name_id: assertion.name_id.clone(),
1009 name_id_format: assertion.name_id_format.clone(),
1010 attributes: assertion.attributes.clone(),
1011 })
1012 }
1013
1014 fn validate_response(&self, response: &SamlResponse) -> Result<(), SamlError> {
1016 if !response.is_success() {
1018 return Err(SamlError::IdpError(
1019 response.status_code.clone(),
1020 response.status_message.clone(),
1021 ));
1022 }
1023
1024 if !self.config.allow_unsolicited_responses {
1026 if let Some(ref in_response_to) = response.in_response_to {
1027 let pending = self.pending_requests.read().unwrap();
1028 if !pending.contains_key(in_response_to) {
1029 return Err(SamlError::InvalidResponse(
1030 "InResponseTo does not match any pending request".into(),
1031 ));
1032 }
1033 } else {
1034 return Err(SamlError::InvalidResponse(
1035 "Unsolicited responses are not allowed".into(),
1036 ));
1037 }
1038 }
1039
1040 if let Some(ref destination) = response.destination {
1042 if destination != &self.config.acs_url {
1043 return Err(SamlError::InvalidResponse(format!(
1044 "Destination mismatch: expected {}, got {}",
1045 self.config.acs_url, destination
1046 )));
1047 }
1048 }
1049
1050 Ok(())
1051 }
1052
1053 fn map_assertion_to_user(&self, assertion: &SamlAssertion) -> Result<User, SamlError> {
1055 let username = assertion.name_id.clone();
1056
1057 let mut user = User::with_encoded_password(&username, "{saml}external".to_string());
1059
1060 for (saml_attr, user_field) in &self.config.attribute_mapping {
1062 if let Some(values) = assertion.attributes.get(saml_attr) {
1063 if let Some(value) = values.first() {
1064 match user_field.as_str() {
1065 "email" => {
1066 user = user.authorities(&[format!("email:{}", value)]);
1068 }
1069 "display_name" | "name" => {
1070 }
1072 _ => {}
1073 }
1074 }
1075 }
1076 }
1077
1078 let mut roles: Vec<String> = self.config.default_roles.clone();
1080 if let Some(ref role_attr) = self.config.role_attribute {
1081 if let Some(values) = assertion.attributes.get(role_attr) {
1082 roles.extend(values.iter().map(|r| r.to_uppercase()));
1083 }
1084 }
1085 user = user.roles(&roles);
1086
1087 if let Some(ref auth_attr) = self.config.authority_attribute {
1089 if let Some(values) = assertion.attributes.get(auth_attr) {
1090 user = user.authorities(values);
1091 }
1092 }
1093
1094 Ok(user)
1095 }
1096
1097 pub fn generate_metadata(&self) -> String {
1099 let mut xml = String::new();
1100 xml.push_str(r#"<?xml version="1.0" encoding="UTF-8"?>"#);
1101 xml.push_str(&format!(
1102 r#"<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" entityID="{}">"#,
1103 self.config.entity_id
1104 ));
1105
1106 xml.push_str(r#"<md:SPSSODescriptor AuthnRequestsSigned=""#);
1107 xml.push_str(if self.config.sign_authn_request {
1108 "true"
1109 } else {
1110 "false"
1111 });
1112 xml.push_str(r#"" WantAssertionsSigned=""#);
1113 xml.push_str(if self.config.want_assertions_signed {
1114 "true"
1115 } else {
1116 "false"
1117 });
1118 xml.push_str(r#"" protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">"#);
1119
1120 xml.push_str(&format!(
1122 r#"<md:NameIDFormat>{}</md:NameIDFormat>"#,
1123 self.config.name_id_format.as_urn()
1124 ));
1125
1126 xml.push_str(&format!(
1128 r#"<md:AssertionConsumerService Binding="{}" Location="{}" index="0"/>"#,
1129 SamlBinding::HttpPost.as_urn(),
1130 self.config.acs_url
1131 ));
1132
1133 if let Some(ref sls_url) = self.config.sls_url {
1135 xml.push_str(&format!(
1136 r#"<md:SingleLogoutService Binding="{}" Location="{}"/>"#,
1137 self.config.slo_binding.as_urn(),
1138 sls_url
1139 ));
1140 }
1141
1142 xml.push_str(r#"</md:SPSSODescriptor></md:EntityDescriptor>"#);
1143 xml
1144 }
1145
1146 pub fn create_logout_request(
1148 &self,
1149 name_id: &str,
1150 session_index: Option<&str>,
1151 ) -> Option<String> {
1152 let slo_url = self.config.idp_slo_url.as_ref()?;
1153
1154 let id = format!("_{}_{}", generate_id(), timestamp_millis());
1155 let issue_instant = iso8601_now();
1156
1157 let mut xml = String::new();
1158 xml.push_str(r#"<?xml version="1.0" encoding="UTF-8"?>"#);
1159 xml.push_str(&format!(
1160 r#"<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion" ID="{}" Version="2.0" IssueInstant="{}" Destination="{}">"#,
1161 id, issue_instant, slo_url
1162 ));
1163
1164 xml.push_str(&format!(
1165 r#"<saml:Issuer>{}</saml:Issuer>"#,
1166 self.config.entity_id
1167 ));
1168
1169 xml.push_str(&format!(
1170 r#"<saml:NameID Format="{}">{}</saml:NameID>"#,
1171 self.config.name_id_format.as_urn(),
1172 name_id
1173 ));
1174
1175 if let Some(session_idx) = session_index {
1176 xml.push_str(&format!(
1177 r#"<samlp:SessionIndex>{}</samlp:SessionIndex>"#,
1178 session_idx
1179 ));
1180 }
1181
1182 xml.push_str(r#"</samlp:LogoutRequest>"#);
1183
1184 let deflated = deflate_and_encode(&xml);
1185 Some(format!(
1186 "{}?SAMLRequest={}",
1187 slo_url,
1188 urlencoding::encode(&deflated)
1189 ))
1190 }
1191}
1192
1193#[derive(Debug, Clone)]
1195pub struct SamlAuthResult {
1196 pub user: User,
1198 pub session_index: Option<String>,
1200 pub name_id: String,
1202 pub name_id_format: NameIdFormat,
1204 pub attributes: HashMap<String, Vec<String>>,
1206}
1207
/// Returns 32 lowercase hex digits (128 bits) for building SAML message IDs.
///
/// SAML 2.0 Core §1.3.4 requires that two generated identifiers collide with
/// probability no greater than 2^-128, so two 64-bit SipHash outputs are
/// concatenated. Each comes from a fresh `RandomState`, which std documents as
/// randomly keyed. Both hash the current time and a process-wide sequence
/// number, so no two calls in one process ever hash the same input.
///
/// NOTE(review): `RandomState` is not a cryptographically secure RNG. Prefer a
/// CSPRNG here if one becomes available to this crate.
fn generate_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::sync::atomic::{AtomicUsize, Ordering};

    static SEQUENCE: AtomicUsize = AtomicUsize::new(0);

    let sequence = SEQUENCE.fetch_add(1, Ordering::Relaxed);
    let millis = timestamp_millis();
    let keyed_hash = |lane: u64| {
        let mut hasher = RandomState::new().build_hasher();
        hasher.write_u64(millis);
        hasher.write_usize(sequence);
        hasher.write_u64(lane);
        hasher.finish()
    };

    format!("{:016x}{:016x}", keyed_hash(0), keyed_hash(1))
}

/// Milliseconds since the Unix epoch.
///
/// # Panics
/// If the system clock reads earlier than the Unix epoch.
fn timestamp_millis() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock is before the Unix epoch")
        .as_millis() as u64
}
1230
/// Current UTC time as `YYYY-MM-DDThh:mm:ssZ` (whole seconds).
fn iso8601_now() -> String {
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs();

    let secs_per_minute = 60;
    let secs_per_hour = 3600;
    let secs_per_day = 86400;

    let days_since_1970 = now / secs_per_day;
    let time_of_day = now % secs_per_day;

    let hours = time_of_day / secs_per_hour;
    let minutes = (time_of_day % secs_per_hour) / secs_per_minute;
    let seconds = time_of_day % secs_per_minute;

    let (year, month, day) = days_to_ymd(days_since_1970);

    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
        year, month, day, hours, minutes, seconds
    )
}

/// Converts a day count since 1970-01-01 into a Gregorian `(year, month, day)`
/// with 1-based month and day.
fn days_to_ymd(days: u64) -> (u64, u64, u64) {
    let mut remaining = days;
    let mut year = 1970u64;

    loop {
        let days_in_year = if is_leap_year(year) { 366 } else { 365 };
        if remaining < days_in_year {
            break;
        }
        remaining -= days_in_year;
        year += 1;
    }

    let months = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
    let mut month = 1u64;

    for (i, &days_in_month) in months.iter().enumerate() {
        let days_in_month = if i == 1 && is_leap_year(year) {
            29
        } else {
            days_in_month
        };
        if remaining < days_in_month {
            break;
        }
        remaining -= days_in_month;
        month += 1;
    }

    (year, month, remaining + 1)
}

/// Gregorian leap-year rule.
fn is_leap_year(year: u64) -> bool {
    (year % 4 == 0 && year % 100 != 0) || (year % 400 == 0)
}

/// Parses the leading `YYYY-MM-DDThh:mm:ss` of an `xs:dateTime` into seconds
/// since the Unix epoch.
///
/// Anything after the first 19 bytes (fractional seconds, `Z`) is ignored.
/// Time-zone offsets are not applied, since SAML requires UTC.
///
/// Returns `Err(())` for any of:
/// - input shorter than 19 bytes;
/// - wrong separators or non-digit fields;
/// - out-of-range month, day, hour, minute or second;
/// - years before 1970.
///
/// It never panics, even on non-ASCII input.
fn parse_iso8601(s: &str) -> Result<u64, ()> {
    // Work on bytes: slicing the `&str` at fixed offsets panics when a
    // multi-byte character straddles a field boundary.
    let bytes = s.as_bytes();
    if bytes.len() < 19 {
        return Err(());
    }

    let separators_ok = bytes[4] == b'-'
        && bytes[7] == b'-'
        && matches!(bytes[10], b'T' | b't')
        && bytes[13] == b':'
        && bytes[16] == b':';
    if !separators_ok {
        return Err(());
    }

    // Decimal value of an all-digit field; rejects signs and other characters.
    let field = |range: std::ops::Range<usize>| -> Result<u64, ()> {
        let digits = &bytes[range];
        if digits.iter().all(u8::is_ascii_digit) {
            Ok(digits
                .iter()
                .fold(0u64, |acc, d| acc * 10 + u64::from(d - b'0')))
        } else {
            Err(())
        }
    };

    let year = field(0..4)?;
    let month = field(5..7)?;
    let day = field(8..10)?;
    let hour = field(11..13)?;
    let minute = field(14..16)?;
    let second = field(17..19)?;

    // Range checks also guard the `month - 1` / `day - 1` subtractions below.
    if year < 1970 || !(1..=12).contains(&month) || hour > 23 || minute > 59 || second > 59 {
        return Err(());
    }

    let month_lengths: [u64; 12] = [
        31,
        28 + u64::from(is_leap_year(year)),
        31,
        30,
        31,
        30,
        31,
        31,
        30,
        31,
        30,
        31,
    ];
    let month_index = (month - 1) as usize;
    if day == 0 || day > month_lengths[month_index] {
        return Err(());
    }

    let year_days: u64 = (1970..year)
        .map(|y| if is_leap_year(y) { 366 } else { 365 })
        .sum();
    let month_days: u64 = month_lengths[..month_index].iter().sum();
    let days = year_days + month_days + (day - 1);

    Ok(days * 86_400 + hour * 3_600 + minute * 60 + second)
}
1329
1330fn deflate_and_encode(xml: &str) -> String {
1332 use base64::{engine::general_purpose::STANDARD, Engine as _};
1333
1334 STANDARD.encode(xml)
1337}
1338
/// Value of attribute `attr` on the first start tag named `element`.
///
/// `element` matches either the full tag name or its local part after a
/// namespace prefix, so `"Response"` finds `<samlp:Response ...>`. The
/// attribute name must match exactly: `ID` does not match `xID`. Both quote
/// styles are accepted. Entity references in the value are not decoded.
fn extract_attribute(xml: &str, element: &str, attr: &str) -> Option<String> {
    let (tag_start, tag_end) = xml_start_tag_span(xml, element)?;
    let tag = &xml[tag_start + 1..tag_end];

    // Skip the element name; what follows is `name="value"` pairs and `>`/`/>`.
    let mut rest =
        tag.trim_start_matches(|c: char| !(c.is_ascii_whitespace() || c == '>' || c == '/'));
    loop {
        rest = rest.trim_start();
        let eq = rest.find('=')?;
        let name = rest[..eq].trim();
        let after_eq = rest[eq + 1..].trim_start();
        let quote = after_eq.chars().next().filter(|&c| c == '"' || c == '\'')?;
        let value_and_rest = &after_eq[1..];
        let value_len = value_and_rest.find(quote)?;
        if name == attr {
            return Some(value_and_rest[..value_len].to_string());
        }
        rest = &value_and_rest[value_len + 1..];
    }
}

/// Trimmed text between the first start tag named `element` and its matching
/// end tag.
///
/// Tag names match with or without a namespace prefix, as in
/// [`extract_attribute`]. A self-closing tag yields an empty string. Inner
/// markup is returned as-is and entities are not decoded.
fn extract_element_text(xml: &str, element: &str) -> Option<String> {
    let (tag_start, tag_end) = xml_start_tag_span(xml, element)?;
    let open_tag = &xml[tag_start..tag_end];
    if open_tag.ends_with("/>") {
        return Some(String::new());
    }

    // The end tag must use the same qualified name as the start tag.
    let qname = open_tag[1..]
        .split(|c: char| c.is_ascii_whitespace() || c == '>' || c == '/')
        .next()
        .unwrap_or("");
    let close = format!("</{}", qname);

    let mut search_from = tag_end;
    loop {
        let close_start = search_from + xml[search_from..].find(&close)?;
        let after = close_start + close.len();
        // Guard against `</saml:IssuerX>` when looking for `</saml:Issuer>`.
        if xml[after..].starts_with(|c: char| c == '>' || c.is_ascii_whitespace()) {
            return Some(xml[tag_end..close_start].trim().to_string());
        }
        search_from = after;
    }
}

/// Finds the first start tag whose name is `element`, either exactly or as
/// the local part after a `prefix:`.
///
/// Returns the byte offsets `(start, end)`: `start` is at `<` and `end` is just
/// past the closing `>`.
fn xml_start_tag_span(xml: &str, element: &str) -> Option<(usize, usize)> {
    let mut search_from = 0;
    while let Some(offset) = xml[search_from..].find('<') {
        let name_start = search_from + offset + 1;
        let name_len = xml[name_start..]
            .find(|c: char| c.is_ascii_whitespace() || c == '>' || c == '/')
            .unwrap_or(xml.len() - name_start);
        let name_end = name_start + name_len;
        let qname = &xml[name_start..name_end];
        let local = qname.rsplit(':').next().unwrap_or(qname);
        if !qname.is_empty() && (qname == element || local == element) {
            return xml_tag_end(xml, name_end).map(|end| (name_start - 1, end));
        }
        search_from = name_start;
    }
    None
}

/// Byte offset just past the `>` that closes the tag, scanning from `from`.
///
/// A `>` inside a quoted attribute value does not end the tag.
fn xml_tag_end(xml: &str, from: usize) -> Option<usize> {
    let mut open_quote: Option<u8> = None;
    for (i, &byte) in xml.as_bytes()[from..].iter().enumerate() {
        match open_quote {
            Some(q) if byte == q => open_quote = None,
            Some(_) => {}
            None if byte == b'"' || byte == b'\'' => open_quote = Some(byte),
            None if byte == b'>' => return Some(from + i + 1),
            None => {}
        }
    }
    None
}
1388
1389fn extract_status_code(xml: &str) -> Option<SamlStatusCode> {
1391 let pattern = "StatusCode";
1392 let start = xml.find(pattern)?;
1393 let value_start = xml[start..].find("Value=\"")? + start + 7;
1394 let value_end = xml[value_start..].find('"')? + value_start;
1395 let value = &xml[value_start..value_end];
1396
1397 Some(SamlStatusCode::from_urn(value))
1398}
1399
1400fn parse_assertions(xml: &str) -> Result<Vec<SamlAssertion>, SamlError> {
1402 let mut assertions = Vec::new();
1403
1404 let assertion_pattern = "<saml:Assertion";
1406 if let Some(start) = xml.find(assertion_pattern) {
1407 let assertion_xml = &xml[start..];
1408
1409 let id = extract_attribute(assertion_xml, "Assertion", "ID")
1410 .unwrap_or_else(|| format!("_generated_{}", timestamp_millis()));
1411 let issue_instant = extract_attribute(assertion_xml, "Assertion", "IssueInstant")
1412 .unwrap_or_default();
1413 let issuer = extract_element_text(assertion_xml, "Issuer").unwrap_or_default();
1414
1415 let name_id = extract_element_text(assertion_xml, "NameID").unwrap_or_default();
1417 let name_id_format = extract_attribute(assertion_xml, "NameID", "Format")
1418 .map(|f| NameIdFormat::from_urn(&f))
1419 .unwrap_or_default();
1420
1421 let not_before = extract_attribute(assertion_xml, "Conditions", "NotBefore");
1423 let not_on_or_after = extract_attribute(assertion_xml, "Conditions", "NotOnOrAfter");
1424
1425 let session_index = extract_attribute(assertion_xml, "AuthnStatement", "SessionIndex");
1427 let session_not_on_or_after =
1428 extract_attribute(assertion_xml, "AuthnStatement", "SessionNotOnOrAfter");
1429
1430 let audiences = extract_element_text(assertion_xml, "Audience")
1432 .map(|a| vec![a])
1433 .unwrap_or_default();
1434
1435 let authn_context_class =
1437 extract_element_text(assertion_xml, "AuthnContextClassRef");
1438
1439 let attributes = parse_attributes(assertion_xml);
1441
1442 assertions.push(SamlAssertion {
1443 id,
1444 issue_instant,
1445 issuer,
1446 name_id,
1447 name_id_format,
1448 session_index,
1449 session_not_on_or_after,
1450 not_before,
1451 not_on_or_after,
1452 audiences,
1453 authn_context_class,
1454 attributes,
1455 });
1456 }
1457
1458 Ok(assertions)
1459}
1460
/// Collects `Attribute` values from the first `AttributeStatement` in `xml`.
///
/// Elements are matched by local name, with or without a namespace prefix
/// (`AttributeStatement`, `saml:AttributeStatement`, `saml2:Attribute`,
/// ...). Each attribute's `Name` is read as a whole attribute name, so a
/// preceding `FriendlyName` is not mistaken for it; single or double quotes
/// are accepted.
///
/// Values are handled as follows:
/// - Every `AttributeValue` contributes its trimmed raw inner text, with no
///   entity decoding.
/// - An empty-element `<AttributeValue/>` contributes `""`.
/// - Attributes without a `Name` or without any value are skipped.
/// - A repeated `Name` replaces the earlier entry.
///
/// Only content inside the statement is considered.
fn parse_attributes(xml: &str) -> HashMap<String, Vec<String>> {
    /// A start tag located by `next_tag`.
    struct Tag<'a> {
        /// Qualified element name as written, e.g. `saml2:Attribute`.
        qname: &'a str,
        /// Raw text after the element name up to (not including) `>`.
        attrs: &'a str,
        /// Byte offset just past the tag's closing `>`.
        end: usize,
        /// `true` for an empty-element tag such as `<AttributeValue/>`.
        empty: bool,
    }

    /// First start tag at or after byte `from` whose local name (the part
    /// after any `prefix:`) equals `local`. End tags never match.
    fn next_tag<'a>(xml: &'a str, from: usize, local: &str) -> Option<Tag<'a>> {
        let mut pos = from;
        while let Some(rel) = xml[pos..].find('<') {
            let open = pos + rel;
            let inner_len = xml[open + 1..].find('>')?;
            let inner = &xml[open + 1..open + 1 + inner_len];
            let name_len = inner
                .find(|c: char| c.is_whitespace() || c == '/')
                .unwrap_or(inner.len());
            let qname = &inner[..name_len];
            if qname.rsplit(':').next() == Some(local) {
                return Some(Tag {
                    qname,
                    attrs: &inner[name_len..],
                    end: open + inner_len + 2,
                    empty: inner.ends_with('/'),
                });
            }
            pos = open + 1;
        }
        None
    }

    /// Offset of the first end tag `</qname>` (whitespace allowed before `>`)
    /// at or after `from`.
    fn end_tag(xml: &str, from: usize, qname: &str) -> Option<usize> {
        let needle = format!("</{}", qname);
        let mut pos = from;
        while let Some(rel) = xml[pos..].find(&needle) {
            let at = pos + rel;
            let after = at + needle.len();
            // Skip longer names sharing the prefix, e.g. `</Attribute` in `</AttributeValue>`.
            if xml[after..].trim_start().starts_with('>') {
                return Some(at);
            }
            pos = after;
        }
        None
    }

    /// Value of the attribute named exactly `name`, reading `attrs` as a
    /// sequence of `key = "value"` / `key = 'value'` pairs.
    fn attr_value<'a>(attrs: &'a str, name: &str) -> Option<&'a str> {
        let mut rest = attrs;
        loop {
            rest = rest.trim_start();
            let key_end = rest.find(|c: char| c == '=' || c.is_whitespace())?;
            let key = &rest[..key_end];
            let after_eq = rest[key_end..].trim_start().strip_prefix('=')?.trim_start();
            let quote = after_eq.chars().next().filter(|&c| c == '"' || c == '\'')?;
            let tail = &after_eq[1..];
            let value_end = tail.find(quote)?;
            if key == name {
                return Some(&tail[..value_end]);
            }
            rest = &tail[value_end + 1..];
        }
    }

    let mut attributes = HashMap::new();

    let statement = match next_tag(xml, 0, "AttributeStatement") {
        Some(tag) => tag,
        None => return attributes,
    };
    let statement_end = if statement.empty {
        statement.end
    } else {
        end_tag(xml, statement.end, statement.qname).unwrap_or(xml.len())
    };
    let region = &xml[statement.end..statement_end];

    let mut pos = 0;
    while let Some(attr) = next_tag(region, pos, "Attribute") {
        // Bound the value search to this attribute so a self-closing or
        // value-less <Attribute> never captures a later attribute's values.
        let body_end = if attr.empty {
            attr.end
        } else {
            end_tag(region, attr.end, attr.qname).unwrap_or(region.len())
        };
        let body = &region[..body_end];

        if let Some(name) = attr_value(attr.attrs, "Name") {
            let mut values = Vec::new();
            let mut value_pos = attr.end;
            while let Some(value) = next_tag(body, value_pos, "AttributeValue") {
                if value.empty {
                    values.push(String::new());
                    value_pos = value.end;
                } else if let Some(close) = end_tag(body, value.end, value.qname) {
                    values.push(body[value.end..close].trim().to_string());
                    value_pos = close;
                } else {
                    break;
                }
            }
            if !values.is_empty() {
                attributes.insert(name.to_string(), values);
            }
        }

        pos = body_end;
    }

    attributes
}
1513
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder with the SP entity ID, IdP SSO URL and ACS URL shared by most tests.
    fn sp_config() -> SamlConfig {
        SamlConfig::new()
            .entity_id("https://sp.example.com")
            .idp_sso_url("https://idp.example.com/sso")
            .acs_url("https://sp.example.com/acs")
    }

    /// Freshly issued assertion for `user@example.com` restricted to `audience`.
    fn assertion_for(audience: &str, session_index: Option<&str>) -> SamlAssertion {
        SamlAssertion {
            id: "_test".to_string(),
            issue_instant: iso8601_now(),
            issuer: "https://idp.example.com".to_string(),
            name_id: "user@example.com".to_string(),
            name_id_format: NameIdFormat::EmailAddress,
            session_index: session_index.map(String::from),
            session_not_on_or_after: None,
            not_before: None,
            not_on_or_after: None,
            audiences: vec![audience.to_string()],
            authn_context_class: None,
            attributes: HashMap::new(),
        }
    }

    #[test]
    fn test_name_id_format() {
        let email_urn = NameIdFormat::EmailAddress.as_urn();
        assert_eq!(email_urn, "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress");

        assert_eq!(
            NameIdFormat::from_urn("urn:oasis:names:tc:SAML:2.0:nameid-format:persistent"),
            NameIdFormat::Persistent
        );
    }

    #[test]
    fn test_saml_binding() {
        let urn = SamlBinding::HttpPost.as_urn();
        assert_eq!(urn, "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST");
    }

    #[test]
    fn test_status_code() {
        let outcomes = [
            ("urn:oasis:names:tc:SAML:2.0:status:Success", true),
            ("urn:oasis:names:tc:SAML:2.0:status:AuthnFailed", false),
        ];
        for &(urn, expected) in &outcomes {
            assert_eq!(SamlStatusCode::from_urn(urn).is_success(), expected, "{}", urn);
        }
    }

    #[test]
    fn test_config_builder() {
        let config = SamlConfig::new()
            .entity_id("https://sp.example.com/saml/metadata")
            .idp_sso_url("https://idp.example.com/saml/sso")
            .acs_url("https://sp.example.com/saml/acs")
            .name_id_format(NameIdFormat::EmailAddress)
            .want_assertions_signed(true);

        assert_eq!(config.name_id_format, NameIdFormat::EmailAddress);
        assert_eq!(config.entity_id, "https://sp.example.com/saml/metadata");
    }

    #[test]
    fn test_config_validation() {
        assert!(SamlConfig::new().validate().is_err());

        let complete = sp_config().want_assertions_signed(false);
        assert!(complete.validate().is_ok());
    }

    #[test]
    fn test_config_presets() {
        let okta = SamlConfig::okta("myorg.okta.com", "app123", "https://myapp.com");
        let azure = SamlConfig::azure_ad("tenant-id", "app-id", "https://myapp.com");
        let adfs = SamlConfig::adfs("adfs.company.com", "https://myapp.com");

        assert!(okta.idp_sso_url.contains("okta.com"));
        assert!(azure.idp_sso_url.contains("microsoftonline.com"));
        assert!(adfs.idp_sso_url.contains("adfs"));
    }

    #[test]
    fn test_authn_request_generation() {
        let config = sp_config().name_id_format(NameIdFormat::EmailAddress);
        let request = AuthnRequest::new(&config);
        let xml = request.to_xml();

        for needle in &["AuthnRequest", "https://sp.example.com", "emailAddress"] {
            assert!(xml.contains(needle), "missing {}", needle);
        }
    }

    #[test]
    fn test_authn_request_url() {
        let config = sp_config();
        let request = AuthnRequest::new(&config);
        let url = request.to_redirect_url(Some("/dashboard"));

        assert!(url.starts_with("https://idp.example.com/sso?"));
        for param in &["SAMLRequest=", "RelayState="] {
            assert!(url.contains(param), "missing {}", param);
        }
    }

    #[test]
    fn test_assertion_validation() {
        let config = sp_config()
            .idp_entity_id("https://idp.example.com")
            .max_clock_skew(Duration::from_secs(300));
        let assertion = assertion_for("https://sp.example.com", Some("_session123"));

        assert!(assertion.validate(&config).is_ok());
    }

    #[test]
    fn test_assertion_audience_validation() {
        let config = sp_config();
        let assertion = assertion_for("https://other-sp.example.com", None);

        let outcome = assertion.validate(&config);
        assert!(matches!(outcome, Err(SamlError::AudienceRestrictionNotMet(_))));
    }

    #[test]
    fn test_authenticator_creation() {
        let config = sp_config().want_assertions_signed(false);
        assert!(SamlAuthenticator::new(config).is_ok());
    }

    #[test]
    fn test_metadata_generation() {
        let config = sp_config()
            .sls_url("https://sp.example.com/sls")
            .want_assertions_signed(false);
        let authenticator = SamlAuthenticator::new(config).unwrap();
        let metadata = authenticator.generate_metadata();

        for needle in &[
            "EntityDescriptor",
            "https://sp.example.com",
            "AssertionConsumerService",
            "SingleLogoutService",
        ] {
            assert!(metadata.contains(needle), "missing {}", needle);
        }
    }

    #[test]
    fn test_iso8601_generation() {
        let now = iso8601_now();
        assert_eq!(now.len(), 20);
        assert!(now.contains('T'));
        assert!(now.ends_with('Z'));
    }

    #[test]
    fn test_iso8601_parsing() {
        assert!(parse_iso8601("2024-01-15T10:30:00Z").is_ok());
        assert!(parse_iso8601("invalid").is_err());
    }

    #[test]
    fn test_xml_attribute_extraction() {
        let xml = r#"<Response ID="resp123" Version="2.0">"#;
        let id = extract_attribute(xml, "Response", "ID");
        assert_eq!(id.as_deref(), Some("resp123"));
    }

    #[test]
    fn test_attribute_parsing() {
        let xml = r#"
            <AttributeStatement>
                <Attribute Name="email">
                    <AttributeValue>user@example.com</AttributeValue>
                </Attribute>
                <Attribute Name="roles">
                    <AttributeValue>admin</AttributeValue>
                    <AttributeValue>user</AttributeValue>
                </Attribute>
            </AttributeStatement>
        "#;

        let parsed = parse_attributes(xml);
        let values_of = |name: &str| parsed.get(name).cloned();

        assert_eq!(values_of("email"), Some(vec!["user@example.com".to_string()]));
        assert_eq!(
            values_of("roles"),
            Some(vec!["admin".to_string(), "user".to_string()])
        );
    }

    #[test]
    fn test_saml_error_display() {
        let rendered = SamlError::Configuration("test error".to_string()).to_string();
        assert!(rendered.contains("test error"));
    }

    #[test]
    fn test_logout_request() {
        let config = SamlConfig::new()
            .entity_id("https://sp.example.com")
            .idp_sso_url("https://idp.example.com/sso")
            .idp_slo_url("https://idp.example.com/slo")
            .acs_url("https://sp.example.com/acs")
            .want_assertions_signed(false);
        let authenticator = SamlAuthenticator::new(config).unwrap();

        let url = authenticator
            .create_logout_request("user@example.com", Some("_session123"))
            .expect("an IdP SLO URL is configured, so a logout URL must be produced");
        assert!(url.contains("SAMLRequest="));
    }
}