Skip to main content

actix_security_core/http/security/
saml.rs

//! SAML 2.0 Authentication Module
//!
//! Provides SAML 2.0 Single Sign-On (SSO) authentication support.
//! This module implements the Service Provider (SP) side of SAML authentication.
//!
//! # Features
//!
//! - **SAML AuthnRequest Generation**: Create authentication requests to the IdP
//! - **SAML Response Validation**: Parse and validate IdP responses
//! - **Assertion Processing**: Extract user information from SAML assertions
//! - **Signature Verification**: Verify XML signatures (with appropriate crypto)
//! - **Metadata Support**: Configure from IdP/SP metadata
//!
//! # Example
//!
//! ```rust,ignore
//! use actix_security::http::security::saml::{SamlConfig, SamlAuthenticator};
//!
//! let config = SamlConfig::new()
//!     .entity_id("https://myapp.example.com/saml/metadata")
//!     .idp_sso_url("https://idp.example.com/saml/sso")
//!     .idp_certificate(include_str!("../idp-cert.pem"))
//!     .sp_private_key(include_str!("../sp-key.pem"))
//!     .assertion_consumer_service_url("https://myapp.example.com/saml/acs");
//!
//! let authenticator = SamlAuthenticator::new(config);
//! ```
//!
//! # SAML Flow
//!
//! 1. User accesses a protected resource
//! 2. SP generates an AuthnRequest and redirects to the IdP
//! 3. User authenticates at the IdP
//! 4. IdP sends a SAML Response back to the SP's ACS URL
//! 5. SP validates the response and creates a session
36
37use std::collections::HashMap;
38use std::sync::Arc;
39use std::time::{Duration, SystemTime, UNIX_EPOCH};
40
41use crate::http::security::User;
42
/// Name identifier formats used in `<NameID>` and `<NameIDPolicy>`.
///
/// The seven well-known formats map to fixed OASIS URNs; any other URN is carried
/// verbatim in [`NameIdFormat::Custom`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum NameIdFormat {
    /// `urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified` (the default)
    #[default]
    Unspecified,
    /// `urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress`
    EmailAddress,
    /// `urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName`
    X509SubjectName,
    /// `urn:oasis:names:tc:SAML:1.1:nameid-format:WindowsDomainQualifiedName`
    WindowsDomainQualifiedName,
    /// `urn:oasis:names:tc:SAML:2.0:nameid-format:kerberos`
    Kerberos,
    /// `urn:oasis:names:tc:SAML:2.0:nameid-format:persistent`
    Persistent,
    /// `urn:oasis:names:tc:SAML:2.0:nameid-format:transient`
    Transient,
    /// Any other format, stored as its full URN
    Custom(String),
}

impl NameIdFormat {
    /// Return the URN identifying this format (the stored string for `Custom`).
    pub fn as_urn(&self) -> &str {
        match self {
            Self::Unspecified => "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified",
            Self::EmailAddress => "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
            Self::X509SubjectName => "urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName",
            Self::WindowsDomainQualifiedName => {
                "urn:oasis:names:tc:SAML:1.1:nameid-format:WindowsDomainQualifiedName"
            }
            Self::Kerberos => "urn:oasis:names:tc:SAML:2.0:nameid-format:kerberos",
            Self::Persistent => "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent",
            Self::Transient => "urn:oasis:names:tc:SAML:2.0:nameid-format:transient",
            Self::Custom(urn) => urn.as_str(),
        }
    }

    /// Map a URN back to its format.
    ///
    /// The lookup is an exact, case-sensitive match against the well-known URNs;
    /// anything unrecognized becomes `Custom(urn)`.
    pub fn from_urn(urn: &str) -> Self {
        let known = [
            Self::Unspecified,
            Self::EmailAddress,
            Self::X509SubjectName,
            Self::WindowsDomainQualifiedName,
            Self::Kerberos,
            Self::Persistent,
            Self::Transient,
        ];
        known
            .iter()
            .find(|format| format.as_urn() == urn)
            .cloned()
            .unwrap_or_else(|| Self::Custom(urn.to_owned()))
    }
}
112
113
/// Transport bindings a SAML message can be exchanged over.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SamlBinding {
    /// HTTP-Redirect: the message travels in GET query parameters (the default)
    #[default]
    HttpRedirect,
    /// HTTP-POST: the message travels in an auto-submitted HTML form
    HttpPost,
    /// HTTP-Artifact: a reference to the message is exchanged instead of the message
    HttpArtifact,
    /// SOAP binding
    Soap,
}

impl SamlBinding {
    /// Return the binding's OASIS URN, e.g. `urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST`.
    pub fn as_urn(&self) -> &str {
        // The enum is `Copy`, so match on the value itself rather than the reference.
        match *self {
            Self::HttpRedirect => "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
            Self::HttpPost => "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
            Self::HttpArtifact => "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Artifact",
            Self::Soap => "urn:oasis:names:tc:SAML:2.0:bindings:SOAP",
        }
    }
}
139
/// Authentication context classes an SP can request in `<RequestedAuthnContext>`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum AuthnContextClass {
    /// `urn:oasis:names:tc:SAML:2.0:ac:classes:unspecified` (the default)
    #[default]
    Unspecified,
    /// Password authentication
    Password,
    /// Password sent over a protected transport
    PasswordProtectedTransport,
    /// X.509 certificate authentication
    X509,
    /// Kerberos authentication
    Kerberos,
    /// Multi-factor authentication.
    ///
    /// NOTE(review): `urn:oasis:names:tc:SAML:2.0:ac:classes:MultiFactor` does not appear to
    /// be one of the classes defined by the OASIS SAML 2.0 Authentication Context spec —
    /// confirm the target IdP recognizes it, or use `Custom` with the IdP's own URN.
    MultiFactor,
    /// Any other context class, stored as its full URN
    Custom(String),
}

impl AuthnContextClass {
    /// Return the URN for this context class (the stored string for `Custom`).
    pub fn as_urn(&self) -> &str {
        match self {
            Self::Unspecified => "urn:oasis:names:tc:SAML:2.0:ac:classes:unspecified",
            Self::Password => "urn:oasis:names:tc:SAML:2.0:ac:classes:Password",
            Self::PasswordProtectedTransport => {
                "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport"
            }
            Self::X509 => "urn:oasis:names:tc:SAML:2.0:ac:classes:X509",
            Self::Kerberos => "urn:oasis:names:tc:SAML:2.0:ac:classes:Kerberos",
            Self::MultiFactor => "urn:oasis:names:tc:SAML:2.0:ac:classes:MultiFactor",
            Self::Custom(urn) => urn.as_str(),
        }
    }
}
180
/// SAML protocol `<StatusCode>` values.
///
/// Covers the four top-level codes plus the second-level codes `AuthnFailed`,
/// `InvalidNameIDPolicy` and `NoAuthnContext`. Any other URN is kept verbatim in
/// [`SamlStatusCode::Unknown`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SamlStatusCode {
    /// The request succeeded
    Success,
    /// The request could not be performed due to an error on the part of the requester
    Requester,
    /// The request could not be performed due to an error on the part of the SAML
    /// responder or SAML authority
    Responder,
    /// The responder could not process the request because the version of the request
    /// message was incorrect
    VersionMismatch,
    /// The responding provider was unable to successfully authenticate the principal
    AuthnFailed,
    /// The responding provider cannot or will not support the requested name identifier policy
    InvalidNameIdPolicy,
    /// The specified authentication context requirements cannot be met by the responder
    NoAuthnContext,
    /// A status code URN this module does not recognize, stored verbatim
    Unknown(String),
}

impl SamlStatusCode {
    /// Parse a status code from its URN.
    ///
    /// Matching is exact and case-sensitive; unrecognized values become `Unknown(urn)`.
    pub fn from_urn(urn: &str) -> Self {
        match urn {
            "urn:oasis:names:tc:SAML:2.0:status:Success" => SamlStatusCode::Success,
            "urn:oasis:names:tc:SAML:2.0:status:Requester" => SamlStatusCode::Requester,
            "urn:oasis:names:tc:SAML:2.0:status:Responder" => SamlStatusCode::Responder,
            "urn:oasis:names:tc:SAML:2.0:status:VersionMismatch" => SamlStatusCode::VersionMismatch,
            "urn:oasis:names:tc:SAML:2.0:status:AuthnFailed" => SamlStatusCode::AuthnFailed,
            "urn:oasis:names:tc:SAML:2.0:status:InvalidNameIDPolicy" => {
                SamlStatusCode::InvalidNameIdPolicy
            }
            "urn:oasis:names:tc:SAML:2.0:status:NoAuthnContext" => SamlStatusCode::NoAuthnContext,
            other => SamlStatusCode::Unknown(other.to_string()),
        }
    }

    /// Return the URN for this status code; `Unknown` returns its stored string unchanged.
    ///
    /// This is the inverse of [`SamlStatusCode::from_urn`], mirroring `as_urn` on the other
    /// URN-backed enums in this module.
    pub fn as_urn(&self) -> &str {
        match self {
            SamlStatusCode::Success => "urn:oasis:names:tc:SAML:2.0:status:Success",
            SamlStatusCode::Requester => "urn:oasis:names:tc:SAML:2.0:status:Requester",
            SamlStatusCode::Responder => "urn:oasis:names:tc:SAML:2.0:status:Responder",
            SamlStatusCode::VersionMismatch => "urn:oasis:names:tc:SAML:2.0:status:VersionMismatch",
            SamlStatusCode::AuthnFailed => "urn:oasis:names:tc:SAML:2.0:status:AuthnFailed",
            SamlStatusCode::InvalidNameIdPolicy => {
                "urn:oasis:names:tc:SAML:2.0:status:InvalidNameIDPolicy"
            }
            SamlStatusCode::NoAuthnContext => "urn:oasis:names:tc:SAML:2.0:status:NoAuthnContext",
            SamlStatusCode::Unknown(urn) => urn.as_str(),
        }
    }

    /// Return `true` only for [`SamlStatusCode::Success`].
    pub fn is_success(&self) -> bool {
        matches!(self, SamlStatusCode::Success)
    }
}
224
225/// SAML Service Provider configuration
226#[derive(Debug, Clone)]
227pub struct SamlConfig {
228    /// SP Entity ID (unique identifier)
229    pub entity_id: String,
230    /// IdP SSO URL (where to send AuthnRequest)
231    pub idp_sso_url: String,
232    /// IdP Single Logout URL (optional)
233    pub idp_slo_url: Option<String>,
234    /// IdP Entity ID
235    pub idp_entity_id: Option<String>,
236    /// IdP certificate for signature verification (PEM format)
237    pub idp_certificate: Option<String>,
238    /// SP private key for signing requests (PEM format)
239    pub sp_private_key: Option<String>,
240    /// SP certificate (PEM format)
241    pub sp_certificate: Option<String>,
242    /// Assertion Consumer Service URL
243    pub acs_url: String,
244    /// Single Logout Service URL
245    pub sls_url: Option<String>,
246    /// Preferred binding for SSO
247    pub sso_binding: SamlBinding,
248    /// Preferred binding for SLO
249    pub slo_binding: SamlBinding,
250    /// Name ID format to request
251    pub name_id_format: NameIdFormat,
252    /// Authentication context class to request
253    pub authn_context_class: Option<AuthnContextClass>,
254    /// Whether to sign AuthnRequest
255    pub sign_authn_request: bool,
256    /// Whether to require signed assertions
257    pub want_assertions_signed: bool,
258    /// Whether to require encrypted assertions
259    pub want_assertions_encrypted: bool,
260    /// Maximum allowed clock skew
261    pub max_clock_skew: Duration,
262    /// Attribute mapping (SAML attribute name -> User field)
263    pub attribute_mapping: HashMap<String, String>,
264    /// Role attribute name
265    pub role_attribute: Option<String>,
266    /// Authority attribute name
267    pub authority_attribute: Option<String>,
268    /// Default roles for authenticated users
269    pub default_roles: Vec<String>,
270    /// Allow unsolicited responses (IdP-initiated SSO)
271    pub allow_unsolicited_responses: bool,
272    /// Session timeout
273    pub session_timeout: Duration,
274}
275
276impl SamlConfig {
277    /// Create a new SAML configuration with minimal required fields
278    pub fn new() -> Self {
279        Self {
280            entity_id: String::new(),
281            idp_sso_url: String::new(),
282            idp_slo_url: None,
283            idp_entity_id: None,
284            idp_certificate: None,
285            sp_private_key: None,
286            sp_certificate: None,
287            acs_url: String::new(),
288            sls_url: None,
289            sso_binding: SamlBinding::HttpRedirect,
290            slo_binding: SamlBinding::HttpRedirect,
291            name_id_format: NameIdFormat::Unspecified,
292            authn_context_class: None,
293            sign_authn_request: false,
294            want_assertions_signed: true,
295            want_assertions_encrypted: false,
296            max_clock_skew: Duration::from_secs(120),
297            attribute_mapping: HashMap::new(),
298            role_attribute: None,
299            authority_attribute: None,
300            default_roles: vec!["USER".to_string()],
301            allow_unsolicited_responses: false,
302            session_timeout: Duration::from_secs(3600),
303        }
304    }
305
306    /// Set the SP entity ID
307    pub fn entity_id(mut self, entity_id: impl Into<String>) -> Self {
308        self.entity_id = entity_id.into();
309        self
310    }
311
312    /// Set the IdP SSO URL
313    pub fn idp_sso_url(mut self, url: impl Into<String>) -> Self {
314        self.idp_sso_url = url.into();
315        self
316    }
317
318    /// Set the IdP SLO URL
319    pub fn idp_slo_url(mut self, url: impl Into<String>) -> Self {
320        self.idp_slo_url = Some(url.into());
321        self
322    }
323
324    /// Set the IdP entity ID
325    pub fn idp_entity_id(mut self, entity_id: impl Into<String>) -> Self {
326        self.idp_entity_id = Some(entity_id.into());
327        self
328    }
329
330    /// Set the IdP certificate (PEM format)
331    pub fn idp_certificate(mut self, cert: impl Into<String>) -> Self {
332        self.idp_certificate = Some(cert.into());
333        self
334    }
335
336    /// Set the SP private key (PEM format)
337    pub fn sp_private_key(mut self, key: impl Into<String>) -> Self {
338        self.sp_private_key = Some(key.into());
339        self
340    }
341
342    /// Set the SP certificate (PEM format)
343    pub fn sp_certificate(mut self, cert: impl Into<String>) -> Self {
344        self.sp_certificate = Some(cert.into());
345        self
346    }
347
348    /// Set the Assertion Consumer Service URL
349    pub fn acs_url(mut self, url: impl Into<String>) -> Self {
350        self.acs_url = url.into();
351        self
352    }
353
354    /// Alias for acs_url
355    pub fn assertion_consumer_service_url(self, url: impl Into<String>) -> Self {
356        self.acs_url(url)
357    }
358
359    /// Set the Single Logout Service URL
360    pub fn sls_url(mut self, url: impl Into<String>) -> Self {
361        self.sls_url = Some(url.into());
362        self
363    }
364
365    /// Set the SSO binding
366    pub fn sso_binding(mut self, binding: SamlBinding) -> Self {
367        self.sso_binding = binding;
368        self
369    }
370
371    /// Set the SLO binding
372    pub fn slo_binding(mut self, binding: SamlBinding) -> Self {
373        self.slo_binding = binding;
374        self
375    }
376
377    /// Set the Name ID format
378    pub fn name_id_format(mut self, format: NameIdFormat) -> Self {
379        self.name_id_format = format;
380        self
381    }
382
383    /// Set the authentication context class
384    pub fn authn_context_class(mut self, class: AuthnContextClass) -> Self {
385        self.authn_context_class = Some(class);
386        self
387    }
388
389    /// Set whether to sign AuthnRequest
390    pub fn sign_authn_request(mut self, sign: bool) -> Self {
391        self.sign_authn_request = sign;
392        self
393    }
394
395    /// Set whether assertions must be signed
396    pub fn want_assertions_signed(mut self, signed: bool) -> Self {
397        self.want_assertions_signed = signed;
398        self
399    }
400
401    /// Set whether assertions must be encrypted
402    pub fn want_assertions_encrypted(mut self, encrypted: bool) -> Self {
403        self.want_assertions_encrypted = encrypted;
404        self
405    }
406
407    /// Set maximum clock skew tolerance
408    pub fn max_clock_skew(mut self, skew: Duration) -> Self {
409        self.max_clock_skew = skew;
410        self
411    }
412
413    /// Add an attribute mapping
414    pub fn map_attribute(
415        mut self,
416        saml_attribute: impl Into<String>,
417        user_field: impl Into<String>,
418    ) -> Self {
419        self.attribute_mapping
420            .insert(saml_attribute.into(), user_field.into());
421        self
422    }
423
424    /// Set the role attribute name
425    pub fn role_attribute(mut self, attr: impl Into<String>) -> Self {
426        self.role_attribute = Some(attr.into());
427        self
428    }
429
430    /// Set the authority attribute name
431    pub fn authority_attribute(mut self, attr: impl Into<String>) -> Self {
432        self.authority_attribute = Some(attr.into());
433        self
434    }
435
436    /// Set default roles for authenticated users
437    pub fn default_roles(mut self, roles: Vec<String>) -> Self {
438        self.default_roles = roles;
439        self
440    }
441
442    /// Set whether to allow unsolicited responses
443    pub fn allow_unsolicited_responses(mut self, allow: bool) -> Self {
444        self.allow_unsolicited_responses = allow;
445        self
446    }
447
448    /// Set session timeout
449    pub fn session_timeout(mut self, timeout: Duration) -> Self {
450        self.session_timeout = timeout;
451        self
452    }
453
454    /// Create configuration preset for Okta
455    pub fn okta(
456        okta_domain: impl Into<String>,
457        app_id: impl Into<String>,
458        sp_entity_id: impl Into<String>,
459    ) -> Self {
460        let domain = okta_domain.into();
461        let app = app_id.into();
462        Self::new()
463            .entity_id(sp_entity_id)
464            .idp_sso_url(format!("https://{}/app/{}/sso/saml", domain, app))
465            .idp_entity_id(format!("http://www.okta.com/{}", app))
466            .name_id_format(NameIdFormat::EmailAddress)
467            .sso_binding(SamlBinding::HttpPost)
468    }
469
470    /// Create configuration preset for Azure AD
471    pub fn azure_ad(
472        tenant_id: impl Into<String>,
473        _app_id: impl Into<String>,
474        sp_entity_id: impl Into<String>,
475    ) -> Self {
476        let tenant = tenant_id.into();
477        Self::new()
478            .entity_id(sp_entity_id)
479            .idp_sso_url(format!(
480                "https://login.microsoftonline.com/{}/saml2",
481                tenant
482            ))
483            .idp_entity_id(format!(
484                "https://sts.windows.net/{}/",
485                tenant
486            ))
487            .name_id_format(NameIdFormat::EmailAddress)
488            .sso_binding(SamlBinding::HttpRedirect)
489            .map_attribute(
490                "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress",
491                "email",
492            )
493            .map_attribute(
494                "http://schemas.microsoft.com/ws/2008/06/identity/claims/groups",
495                "groups",
496            )
497    }
498
499    /// Create configuration preset for Google Workspace
500    pub fn google_workspace(sp_entity_id: impl Into<String>, acs_url: impl Into<String>) -> Self {
501        Self::new()
502            .entity_id(sp_entity_id)
503            .idp_sso_url("https://accounts.google.com/o/saml2/idp")
504            .acs_url(acs_url)
505            .name_id_format(NameIdFormat::EmailAddress)
506            .sso_binding(SamlBinding::HttpRedirect)
507    }
508
509    /// Create configuration preset for ADFS
510    pub fn adfs(adfs_host: impl Into<String>, sp_entity_id: impl Into<String>) -> Self {
511        let host = adfs_host.into();
512        Self::new()
513            .entity_id(sp_entity_id)
514            .idp_sso_url(format!("https://{}/adfs/ls/", host))
515            .idp_entity_id(format!("http://{}/adfs/services/trust", host))
516            .name_id_format(NameIdFormat::Unspecified)
517            .sso_binding(SamlBinding::HttpPost)
518    }
519
520    /// Validate the configuration
521    pub fn validate(&self) -> Result<(), SamlError> {
522        if self.entity_id.is_empty() {
523            return Err(SamlError::Configuration("entity_id is required".into()));
524        }
525        if self.idp_sso_url.is_empty() {
526            return Err(SamlError::Configuration("idp_sso_url is required".into()));
527        }
528        if self.acs_url.is_empty() {
529            return Err(SamlError::Configuration("acs_url is required".into()));
530        }
531        if self.sign_authn_request && self.sp_private_key.is_none() {
532            return Err(SamlError::Configuration(
533                "sp_private_key is required when sign_authn_request is true".into(),
534            ));
535        }
536        if self.want_assertions_signed && self.idp_certificate.is_none() {
537            return Err(SamlError::Configuration(
538                "idp_certificate is required when want_assertions_signed is true".into(),
539            ));
540        }
541        Ok(())
542    }
543}
544
545impl Default for SamlConfig {
546    fn default() -> Self {
547        Self::new()
548    }
549}
550
551/// SAML authentication error
552#[derive(Debug, Clone)]
553pub enum SamlError {
554    /// Configuration error
555    Configuration(String),
556    /// Invalid SAML response
557    InvalidResponse(String),
558    /// Signature verification failed
559    SignatureVerificationFailed(String),
560    /// Assertion validation failed
561    AssertionValidationFailed(String),
562    /// Time condition not met
563    TimeConditionNotMet(String),
564    /// Audience restriction not met
565    AudienceRestrictionNotMet(String),
566    /// Required attribute missing
567    MissingAttribute(String),
568    /// IdP returned an error status
569    IdpError(SamlStatusCode, Option<String>),
570    /// Decryption failed
571    DecryptionFailed(String),
572    /// XML parsing error
573    XmlParsingError(String),
574    /// Network error
575    NetworkError(String),
576}
577
578impl std::fmt::Display for SamlError {
579    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
580        match self {
581            SamlError::Configuration(msg) => write!(f, "SAML configuration error: {}", msg),
582            SamlError::InvalidResponse(msg) => write!(f, "Invalid SAML response: {}", msg),
583            SamlError::SignatureVerificationFailed(msg) => {
584                write!(f, "SAML signature verification failed: {}", msg)
585            }
586            SamlError::AssertionValidationFailed(msg) => {
587                write!(f, "SAML assertion validation failed: {}", msg)
588            }
589            SamlError::TimeConditionNotMet(msg) => {
590                write!(f, "SAML time condition not met: {}", msg)
591            }
592            SamlError::AudienceRestrictionNotMet(msg) => {
593                write!(f, "SAML audience restriction not met: {}", msg)
594            }
595            SamlError::MissingAttribute(attr) => {
596                write!(f, "Required SAML attribute missing: {}", attr)
597            }
598            SamlError::IdpError(code, msg) => {
599                write!(f, "IdP returned error {:?}: {:?}", code, msg)
600            }
601            SamlError::DecryptionFailed(msg) => write!(f, "SAML decryption failed: {}", msg),
602            SamlError::XmlParsingError(msg) => write!(f, "XML parsing error: {}", msg),
603            SamlError::NetworkError(msg) => write!(f, "Network error: {}", msg),
604        }
605    }
606}
607
608impl std::error::Error for SamlError {}
609
610/// SAML AuthnRequest
611#[derive(Debug, Clone)]
612pub struct AuthnRequest {
613    /// Request ID
614    pub id: String,
615    /// Issue instant (ISO 8601)
616    pub issue_instant: String,
617    /// SP Entity ID
618    pub issuer: String,
619    /// Destination (IdP SSO URL)
620    pub destination: String,
621    /// Assertion Consumer Service URL
622    pub acs_url: String,
623    /// Protocol binding for response
624    pub protocol_binding: SamlBinding,
625    /// Name ID policy format
626    pub name_id_format: NameIdFormat,
627    /// Requested authentication context
628    pub authn_context: Option<AuthnContextClass>,
629    /// Force re-authentication
630    pub force_authn: bool,
631    /// Passive authentication (no user interaction)
632    pub is_passive: bool,
633}
634
635impl AuthnRequest {
636    /// Create a new AuthnRequest with generated ID
637    pub fn new(config: &SamlConfig) -> Self {
638        let id = format!("_{}_{}", generate_id(), timestamp_millis());
639
640        Self {
641            id,
642            issue_instant: iso8601_now(),
643            issuer: config.entity_id.clone(),
644            destination: config.idp_sso_url.clone(),
645            acs_url: config.acs_url.clone(),
646            protocol_binding: SamlBinding::HttpPost,
647            name_id_format: config.name_id_format.clone(),
648            authn_context: config.authn_context_class.clone(),
649            force_authn: false,
650            is_passive: false,
651        }
652    }
653
654    /// Set force re-authentication
655    pub fn force_authn(mut self, force: bool) -> Self {
656        self.force_authn = force;
657        self
658    }
659
660    /// Set passive authentication
661    pub fn is_passive(mut self, passive: bool) -> Self {
662        self.is_passive = passive;
663        self
664    }
665
666    /// Generate XML for this AuthnRequest
667    pub fn to_xml(&self) -> String {
668        let mut xml = String::new();
669        xml.push_str(r#"<?xml version="1.0" encoding="UTF-8"?>"#);
670        xml.push_str(&format!(
671            r#"<samlp:AuthnRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion" ID="{}" Version="2.0" IssueInstant="{}" Destination="{}" AssertionConsumerServiceURL="{}" ProtocolBinding="{}""#,
672            self.id,
673            self.issue_instant,
674            self.destination,
675            self.acs_url,
676            self.protocol_binding.as_urn()
677        ));
678
679        if self.force_authn {
680            xml.push_str(r#" ForceAuthn="true""#);
681        }
682        if self.is_passive {
683            xml.push_str(r#" IsPassive="true""#);
684        }
685
686        xml.push('>');
687
688        // Issuer
689        xml.push_str(&format!(
690            r#"<saml:Issuer>{}</saml:Issuer>"#,
691            self.issuer
692        ));
693
694        // NameIDPolicy
695        xml.push_str(&format!(
696            r#"<samlp:NameIDPolicy Format="{}" AllowCreate="true"/>"#,
697            self.name_id_format.as_urn()
698        ));
699
700        // RequestedAuthnContext
701        if let Some(ref authn_context) = self.authn_context {
702            xml.push_str(r#"<samlp:RequestedAuthnContext Comparison="exact">"#);
703            xml.push_str(&format!(
704                r#"<saml:AuthnContextClassRef>{}</saml:AuthnContextClassRef>"#,
705                authn_context.as_urn()
706            ));
707            xml.push_str(r#"</samlp:RequestedAuthnContext>"#);
708        }
709
710        xml.push_str(r#"</samlp:AuthnRequest>"#);
711        xml
712    }
713
714    /// Get the URL for HTTP Redirect binding (deflated and base64-encoded)
715    pub fn to_redirect_url(&self, relay_state: Option<&str>) -> String {
716        let xml = self.to_xml();
717        let deflated = deflate_and_encode(&xml);
718
719        let mut url = format!(
720            "{}?SAMLRequest={}",
721            self.destination,
722            urlencoding::encode(&deflated)
723        );
724
725        if let Some(state) = relay_state {
726            url.push_str(&format!("&RelayState={}", urlencoding::encode(state)));
727        }
728
729        url
730    }
731}
732
733/// SAML Assertion
734#[derive(Debug, Clone)]
735pub struct SamlAssertion {
736    /// Assertion ID
737    pub id: String,
738    /// Issue instant
739    pub issue_instant: String,
740    /// Issuer (IdP Entity ID)
741    pub issuer: String,
742    /// Subject NameID
743    pub name_id: String,
744    /// Subject NameID format
745    pub name_id_format: NameIdFormat,
746    /// Session index
747    pub session_index: Option<String>,
748    /// Session not on or after
749    pub session_not_on_or_after: Option<String>,
750    /// Not before condition
751    pub not_before: Option<String>,
752    /// Not on or after condition
753    pub not_on_or_after: Option<String>,
754    /// Audience restrictions
755    pub audiences: Vec<String>,
756    /// Authentication context class
757    pub authn_context_class: Option<String>,
758    /// Attributes
759    pub attributes: HashMap<String, Vec<String>>,
760}
761
762impl SamlAssertion {
763    /// Get a single-valued attribute
764    pub fn get_attribute(&self, name: &str) -> Option<&str> {
765        self.attributes
766            .get(name)
767            .and_then(|values| values.first())
768            .map(|s| s.as_str())
769    }
770
771    /// Get a multi-valued attribute
772    pub fn get_attribute_values(&self, name: &str) -> Option<&Vec<String>> {
773        self.attributes.get(name)
774    }
775
776    /// Validate the assertion against configuration
777    pub fn validate(&self, config: &SamlConfig) -> Result<(), SamlError> {
778        let now = SystemTime::now()
779            .duration_since(UNIX_EPOCH)
780            .unwrap()
781            .as_secs();
782
783        // Check NotBefore
784        if let Some(ref not_before) = self.not_before {
785            if let Ok(nb_time) = parse_iso8601(not_before) {
786                let skew = config.max_clock_skew.as_secs();
787                if now + skew < nb_time {
788                    return Err(SamlError::TimeConditionNotMet(format!(
789                        "Assertion not valid before {}",
790                        not_before
791                    )));
792                }
793            }
794        }
795
796        // Check NotOnOrAfter
797        if let Some(ref not_on_or_after) = self.not_on_or_after {
798            if let Ok(noa_time) = parse_iso8601(not_on_or_after) {
799                let skew = config.max_clock_skew.as_secs();
800                if now > noa_time + skew {
801                    return Err(SamlError::TimeConditionNotMet(format!(
802                        "Assertion expired at {}",
803                        not_on_or_after
804                    )));
805                }
806            }
807        }
808
809        // Check Audience
810        if !self.audiences.is_empty() && !self.audiences.contains(&config.entity_id) {
811            return Err(SamlError::AudienceRestrictionNotMet(format!(
812                "SP entity ID {} not in audiences: {:?}",
813                config.entity_id, self.audiences
814            )));
815        }
816
817        // Check Issuer
818        if let Some(ref expected_issuer) = config.idp_entity_id {
819            if &self.issuer != expected_issuer {
820                return Err(SamlError::AssertionValidationFailed(format!(
821                    "Issuer mismatch: expected {}, got {}",
822                    expected_issuer, self.issuer
823                )));
824            }
825        }
826
827        Ok(())
828    }
829}
830
831/// SAML Response
832#[derive(Debug, Clone)]
833pub struct SamlResponse {
834    /// Response ID
835    pub id: String,
836    /// In response to (AuthnRequest ID)
837    pub in_response_to: Option<String>,
838    /// Issue instant
839    pub issue_instant: String,
840    /// Destination
841    pub destination: Option<String>,
842    /// Issuer (IdP Entity ID)
843    pub issuer: String,
844    /// Status code
845    pub status_code: SamlStatusCode,
846    /// Status message
847    pub status_message: Option<String>,
848    /// Assertion(s)
849    pub assertions: Vec<SamlAssertion>,
850}
851
852impl SamlResponse {
853    /// Parse a SAML Response from base64-encoded XML
854    ///
855    /// Note: In production, you should use a proper XML/SAML library
856    /// like `samael` or `saml2` for full parsing and signature verification.
857    pub fn from_base64(encoded: &str) -> Result<Self, SamlError> {
858        use base64::{engine::general_purpose::STANDARD, Engine as _};
859
860        let decoded = STANDARD
861            .decode(encoded)
862            .map_err(|e| SamlError::InvalidResponse(format!("Base64 decode error: {}", e)))?;
863
864        let xml = String::from_utf8(decoded)
865            .map_err(|e| SamlError::InvalidResponse(format!("UTF-8 decode error: {}", e)))?;
866
867        Self::from_xml(&xml)
868    }
869
870    /// Parse a SAML Response from XML string
871    ///
872    /// Note: This is a simplified parser. In production, use a proper
873    /// XML library with namespace support and signature verification.
874    pub fn from_xml(xml: &str) -> Result<Self, SamlError> {
875        // This is a simplified parser for demonstration.
876        // Production code should use a proper SAML library.
877
878        let id = extract_attribute(xml, "Response", "ID")
879            .ok_or_else(|| SamlError::XmlParsingError("Missing Response ID".into()))?;
880
881        let in_response_to = extract_attribute(xml, "Response", "InResponseTo");
882        let issue_instant = extract_attribute(xml, "Response", "IssueInstant")
883            .ok_or_else(|| SamlError::XmlParsingError("Missing IssueInstant".into()))?;
884        let destination = extract_attribute(xml, "Response", "Destination");
885        let issuer = extract_element_text(xml, "Issuer")
886            .ok_or_else(|| SamlError::XmlParsingError("Missing Issuer".into()))?;
887
888        // Parse status
889        let status_code = extract_status_code(xml).unwrap_or(SamlStatusCode::Unknown(String::new()));
890        let status_message = extract_element_text(xml, "StatusMessage");
891
892        // Parse assertions (simplified)
893        let assertions = parse_assertions(xml)?;
894
895        Ok(Self {
896            id,
897            in_response_to,
898            issue_instant,
899            destination,
900            issuer,
901            status_code,
902            status_message,
903            assertions,
904        })
905    }
906
907    /// Check if the response indicates success
908    pub fn is_success(&self) -> bool {
909        self.status_code.is_success()
910    }
911
912    /// Get the first assertion (most common case)
913    pub fn assertion(&self) -> Option<&SamlAssertion> {
914        self.assertions.first()
915    }
916}
917
/// SAML 2.0 Service Provider authenticator for actix-web.
///
/// Holds the configuration (validated by `SamlAuthenticator::new`) plus an in-memory table of
/// outstanding AuthnRequests keyed by request ID. The table is used to match `InResponseTo` on
/// incoming responses.
///
/// Both fields are `Arc`s, so cloning is cheap and every clone shares the same configuration
/// and pending-request table.
///
/// NOTE(review): the table lives in process memory only. Several SP instances behind a load
/// balancer would not see each other's pending requests.
#[derive(Clone)]
pub struct SamlAuthenticator {
    /// SP/IdP configuration shared by all clones
    config: Arc<SamlConfig>,
    /// Outstanding AuthnRequests keyed by request ID, behind an `RwLock` for concurrent access
    pending_requests: Arc<std::sync::RwLock<HashMap<String, PendingRequest>>>,
}
924
/// An AuthnRequest that has been issued and is awaiting the IdP's response.
///
/// Stored in `SamlAuthenticator::pending_requests` under the same ID as `id`.
#[derive(Debug, Clone)]
#[allow(dead_code)] // `id` and `relay_state` are stored but not read back by the code in this section
struct PendingRequest {
    /// The AuthnRequest `ID` (duplicates the map key)
    id: String,
    /// Creation time in whole seconds since the Unix epoch; used to expire stale entries
    created_at: u64,
    /// RelayState supplied at login, presumably the post-login redirect target.
    /// NOTE(review): nothing visible here hands it back to callers.
    relay_state: Option<String>,
}
936
937impl SamlAuthenticator {
938    /// Create a new SAML authenticator
939    pub fn new(config: SamlConfig) -> Result<Self, SamlError> {
940        config.validate()?;
941        Ok(Self {
942            config: Arc::new(config),
943            pending_requests: Arc::new(std::sync::RwLock::new(HashMap::new())),
944        })
945    }
946
947    /// Get the configuration
948    pub fn config(&self) -> &SamlConfig {
949        &self.config
950    }
951
952    /// Create a new AuthnRequest
953    pub fn create_authn_request(&self) -> AuthnRequest {
954        AuthnRequest::new(&self.config)
955    }
956
957    /// Store a pending request
958    pub fn store_pending_request(&self, request: &AuthnRequest, relay_state: Option<String>) {
959        let mut pending = self.pending_requests.write().unwrap();
960        pending.insert(
961            request.id.clone(),
962            PendingRequest {
963                id: request.id.clone(),
964                created_at: timestamp_millis() / 1000,
965                relay_state,
966            },
967        );
968
969        // Clean up old requests (older than 10 minutes)
970        let now = timestamp_millis() / 1000;
971        pending.retain(|_, req| now - req.created_at < 600);
972    }
973
974    /// Initiate SAML login (returns redirect URL)
975    pub fn initiate_login(&self, relay_state: Option<&str>) -> String {
976        let request = self.create_authn_request();
977        self.store_pending_request(&request, relay_state.map(|s| s.to_string()));
978        request.to_redirect_url(relay_state)
979    }
980
981    /// Process SAML Response and extract user
982    pub fn process_response(&self, encoded_response: &str) -> Result<SamlAuthResult, SamlError> {
983        let response = SamlResponse::from_base64(encoded_response)?;
984
985        // Validate response
986        self.validate_response(&response)?;
987
988        // Extract user from assertion
989        let assertion = response
990            .assertion()
991            .ok_or_else(|| SamlError::InvalidResponse("No assertion in response".into()))?;
992
993        // Validate assertion
994        assertion.validate(&self.config)?;
995
996        // Map to User
997        let user = self.map_assertion_to_user(assertion)?;
998
999        // Clean up pending request
1000        if let Some(ref in_response_to) = response.in_response_to {
1001            let mut pending = self.pending_requests.write().unwrap();
1002            pending.remove(in_response_to);
1003        }
1004
1005        Ok(SamlAuthResult {
1006            user,
1007            session_index: assertion.session_index.clone(),
1008            name_id: assertion.name_id.clone(),
1009            name_id_format: assertion.name_id_format.clone(),
1010            attributes: assertion.attributes.clone(),
1011        })
1012    }
1013
1014    /// Validate a SAML Response
1015    fn validate_response(&self, response: &SamlResponse) -> Result<(), SamlError> {
1016        // Check status
1017        if !response.is_success() {
1018            return Err(SamlError::IdpError(
1019                response.status_code.clone(),
1020                response.status_message.clone(),
1021            ));
1022        }
1023
1024        // Check InResponseTo if not allowing unsolicited responses
1025        if !self.config.allow_unsolicited_responses {
1026            if let Some(ref in_response_to) = response.in_response_to {
1027                let pending = self.pending_requests.read().unwrap();
1028                if !pending.contains_key(in_response_to) {
1029                    return Err(SamlError::InvalidResponse(
1030                        "InResponseTo does not match any pending request".into(),
1031                    ));
1032                }
1033            } else {
1034                return Err(SamlError::InvalidResponse(
1035                    "Unsolicited responses are not allowed".into(),
1036                ));
1037            }
1038        }
1039
1040        // Check destination
1041        if let Some(ref destination) = response.destination {
1042            if destination != &self.config.acs_url {
1043                return Err(SamlError::InvalidResponse(format!(
1044                    "Destination mismatch: expected {}, got {}",
1045                    self.config.acs_url, destination
1046                )));
1047            }
1048        }
1049
1050        Ok(())
1051    }
1052
1053    /// Map SAML assertion to User
1054    fn map_assertion_to_user(&self, assertion: &SamlAssertion) -> Result<User, SamlError> {
1055        let username = assertion.name_id.clone();
1056
1057        // Use with_encoded_password since SAML users don't have local passwords
1058        let mut user = User::with_encoded_password(&username, "{saml}external".to_string());
1059
1060        // Map attributes
1061        for (saml_attr, user_field) in &self.config.attribute_mapping {
1062            if let Some(values) = assertion.attributes.get(saml_attr) {
1063                if let Some(value) = values.first() {
1064                    match user_field.as_str() {
1065                        "email" => {
1066                            // Store email in attributes (User doesn't have email field directly)
1067                            user = user.authorities(&[format!("email:{}", value)]);
1068                        }
1069                        "display_name" | "name" => {
1070                            // Could extend User model for this
1071                        }
1072                        _ => {}
1073                    }
1074                }
1075            }
1076        }
1077
1078        // Extract roles
1079        let mut roles: Vec<String> = self.config.default_roles.clone();
1080        if let Some(ref role_attr) = self.config.role_attribute {
1081            if let Some(values) = assertion.attributes.get(role_attr) {
1082                roles.extend(values.iter().map(|r| r.to_uppercase()));
1083            }
1084        }
1085        user = user.roles(&roles);
1086
1087        // Extract authorities
1088        if let Some(ref auth_attr) = self.config.authority_attribute {
1089            if let Some(values) = assertion.attributes.get(auth_attr) {
1090                user = user.authorities(values);
1091            }
1092        }
1093
1094        Ok(user)
1095    }
1096
1097    /// Generate SP metadata XML
1098    pub fn generate_metadata(&self) -> String {
1099        let mut xml = String::new();
1100        xml.push_str(r#"<?xml version="1.0" encoding="UTF-8"?>"#);
1101        xml.push_str(&format!(
1102            r#"<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" entityID="{}">"#,
1103            self.config.entity_id
1104        ));
1105
1106        xml.push_str(r#"<md:SPSSODescriptor AuthnRequestsSigned=""#);
1107        xml.push_str(if self.config.sign_authn_request {
1108            "true"
1109        } else {
1110            "false"
1111        });
1112        xml.push_str(r#"" WantAssertionsSigned=""#);
1113        xml.push_str(if self.config.want_assertions_signed {
1114            "true"
1115        } else {
1116            "false"
1117        });
1118        xml.push_str(r#"" protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">"#);
1119
1120        // NameIDFormat
1121        xml.push_str(&format!(
1122            r#"<md:NameIDFormat>{}</md:NameIDFormat>"#,
1123            self.config.name_id_format.as_urn()
1124        ));
1125
1126        // ACS
1127        xml.push_str(&format!(
1128            r#"<md:AssertionConsumerService Binding="{}" Location="{}" index="0"/>"#,
1129            SamlBinding::HttpPost.as_urn(),
1130            self.config.acs_url
1131        ));
1132
1133        // SLS (if configured)
1134        if let Some(ref sls_url) = self.config.sls_url {
1135            xml.push_str(&format!(
1136                r#"<md:SingleLogoutService Binding="{}" Location="{}"/>"#,
1137                self.config.slo_binding.as_urn(),
1138                sls_url
1139            ));
1140        }
1141
1142        xml.push_str(r#"</md:SPSSODescriptor></md:EntityDescriptor>"#);
1143        xml
1144    }
1145
1146    /// Create logout request URL
1147    pub fn create_logout_request(
1148        &self,
1149        name_id: &str,
1150        session_index: Option<&str>,
1151    ) -> Option<String> {
1152        let slo_url = self.config.idp_slo_url.as_ref()?;
1153
1154        let id = format!("_{}_{}", generate_id(), timestamp_millis());
1155        let issue_instant = iso8601_now();
1156
1157        let mut xml = String::new();
1158        xml.push_str(r#"<?xml version="1.0" encoding="UTF-8"?>"#);
1159        xml.push_str(&format!(
1160            r#"<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion" ID="{}" Version="2.0" IssueInstant="{}" Destination="{}">"#,
1161            id, issue_instant, slo_url
1162        ));
1163
1164        xml.push_str(&format!(
1165            r#"<saml:Issuer>{}</saml:Issuer>"#,
1166            self.config.entity_id
1167        ));
1168
1169        xml.push_str(&format!(
1170            r#"<saml:NameID Format="{}">{}</saml:NameID>"#,
1171            self.config.name_id_format.as_urn(),
1172            name_id
1173        ));
1174
1175        if let Some(session_idx) = session_index {
1176            xml.push_str(&format!(
1177                r#"<samlp:SessionIndex>{}</samlp:SessionIndex>"#,
1178                session_idx
1179            ));
1180        }
1181
1182        xml.push_str(r#"</samlp:LogoutRequest>"#);
1183
1184        let deflated = deflate_and_encode(&xml);
1185        Some(format!(
1186            "{}?SAMLRequest={}",
1187            slo_url,
1188            urlencoding::encode(&deflated)
1189        ))
1190    }
1191}
1192
/// Result of a successful SAML authentication, as returned by
/// `SamlAuthenticator::process_response`.
///
/// Besides the mapped [`User`], it carries the raw identity data from the assertion. Keep
/// `name_id` and `session_index` if a logout request will be issued later
/// (`SamlAuthenticator::create_logout_request` takes both).
#[derive(Debug, Clone)]
pub struct SamlAuthResult {
    /// Authenticated user built from the assertion's NameID, roles and authorities
    pub user: User,
    /// `SessionIndex` from the assertion's `AuthnStatement`, if the IdP sent one
    pub session_index: Option<String>,
    /// NameID value from the assertion
    pub name_id: String,
    /// Format of `name_id`
    pub name_id_format: NameIdFormat,
    /// All attributes from the assertion, keyed by attribute `Name`, values in document order
    pub attributes: HashMap<String, Vec<String>>,
}
1207
1208// ============================================================================
1209// Helper functions
1210// ============================================================================
1211
1212/// Generate a random ID
1213fn generate_id() -> String {
1214    use std::collections::hash_map::RandomState;
1215    use std::hash::{BuildHasher, Hasher};
1216
1217    let hasher = RandomState::new();
1218    let mut h = hasher.build_hasher();
1219    h.write_u64(timestamp_millis());
1220    format!("{:016x}", h.finish())
1221}
1222
/// Milliseconds elapsed since the Unix epoch, according to the system clock.
///
/// # Panics
///
/// Panics if the system clock reports a time before 1970-01-01.
fn timestamp_millis() -> u64 {
    let now = SystemTime::now();
    let since_epoch = now.duration_since(UNIX_EPOCH).unwrap();
    let millis: u128 = since_epoch.as_millis();
    // Narrowing to u64 only truncates ~584 million years from now.
    millis as u64
}
1230
1231/// Get current time in ISO 8601 format
1232fn iso8601_now() -> String {
1233    let now = SystemTime::now()
1234        .duration_since(UNIX_EPOCH)
1235        .unwrap()
1236        .as_secs();
1237
1238    // Simple ISO 8601 formatting (in production, use chrono or time crate)
1239    let secs_per_minute = 60;
1240    let secs_per_hour = 3600;
1241    let secs_per_day = 86400;
1242
1243    let days_since_1970 = now / secs_per_day;
1244    let time_of_day = now % secs_per_day;
1245
1246    let hours = time_of_day / secs_per_hour;
1247    let minutes = (time_of_day % secs_per_hour) / secs_per_minute;
1248    let seconds = time_of_day % secs_per_minute;
1249
1250    // Simple year/month/day calculation (not accounting for leap seconds perfectly)
1251    let (year, month, day) = days_to_ymd(days_since_1970);
1252
1253    format!(
1254        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
1255        year, month, day, hours, minutes, seconds
1256    )
1257}
1258
1259/// Convert days since 1970 to year/month/day
1260fn days_to_ymd(days: u64) -> (u64, u64, u64) {
1261    // Simplified calculation
1262    let mut remaining = days;
1263    let mut year = 1970u64;
1264
1265    loop {
1266        let days_in_year = if is_leap_year(year) { 366 } else { 365 };
1267        if remaining < days_in_year {
1268            break;
1269        }
1270        remaining -= days_in_year;
1271        year += 1;
1272    }
1273
1274    let months = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
1275    let mut month = 1u64;
1276
1277    for (i, &days_in_month) in months.iter().enumerate() {
1278        let days_in_month = if i == 1 && is_leap_year(year) {
1279            29
1280        } else {
1281            days_in_month
1282        };
1283        if remaining < days_in_month {
1284            break;
1285        }
1286        remaining -= days_in_month;
1287        month += 1;
1288    }
1289
1290    (year, month, remaining + 1)
1291}
1292
/// Gregorian leap-year rule: every 4th year, except centuries not divisible by 400.
fn is_leap_year(year: u64) -> bool {
    match (year % 4, year % 100, year % 400) {
        (_, _, 0) => true,
        (_, 0, _) => false,
        (0, _, _) => true,
        _ => false,
    }
}
1296
1297/// Parse ISO 8601 date to Unix timestamp (simplified)
1298fn parse_iso8601(s: &str) -> Result<u64, ()> {
1299    // Expected format: 2024-01-15T10:30:00Z
1300    if s.len() < 19 {
1301        return Err(());
1302    }
1303
1304    let year: u64 = s[0..4].parse().map_err(|_| ())?;
1305    let month: u64 = s[5..7].parse().map_err(|_| ())?;
1306    let day: u64 = s[8..10].parse().map_err(|_| ())?;
1307    let hour: u64 = s[11..13].parse().map_err(|_| ())?;
1308    let minute: u64 = s[14..16].parse().map_err(|_| ())?;
1309    let second: u64 = s[17..19].parse().map_err(|_| ())?;
1310
1311    // Convert to Unix timestamp
1312    let mut days = 0u64;
1313    for y in 1970..year {
1314        days += if is_leap_year(y) { 366 } else { 365 };
1315    }
1316
1317    let months = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
1318    for (i, &d) in months.iter().take((month - 1) as usize).enumerate() {
1319        days += if i == 1 && is_leap_year(year) {
1320            29
1321        } else {
1322            d
1323        };
1324    }
1325    days += day - 1;
1326
1327    Ok(days * 86400 + hour * 3600 + minute * 60 + second)
1328}
1329
1330/// Deflate and base64-encode XML for HTTP Redirect binding
1331fn deflate_and_encode(xml: &str) -> String {
1332    use base64::{engine::general_purpose::STANDARD, Engine as _};
1333
1334    // In production, use flate2 crate for proper DEFLATE compression
1335    // For now, just base64 encode (many IdPs accept this)
1336    STANDARD.encode(xml)
1337}
1338
/// Extract attribute `attr` from the first start tag whose local name is `element`
/// (simplified parser).
///
/// Matching rules:
/// - The element may carry any namespace prefix (`<Response …>` and `<samlp:Response …>`
///   both match).
/// - The local name must match exactly, so `Attribute` does not match `<AttributeStatement>`.
/// - Attribute names are compared as whole tokens, so `Name` does not match `FriendlyName`.
/// - Values may use either quote style, and a `>` inside a quoted value does not end the tag.
///
/// Only the first matching element is inspected. `None` is returned if it lacks `attr` or its
/// tag is malformed. Entity references in the value are returned undecoded.
fn extract_attribute(xml: &str, element: &str, attr: &str) -> Option<String> {
    /// Walk the `name="value"` pairs that follow the element name, up to the end of the tag.
    fn find_in_tag(mut rest: &str, attr: &str) -> Option<String> {
        loop {
            rest = rest.trim_start();
            if rest.is_empty() || rest.starts_with('>') || rest.starts_with('/') {
                return None;
            }
            let name_len = rest
                .find(|c: char| c == '=' || c.is_whitespace() || c == '>' || c == '/')
                .unwrap_or(rest.len());
            let name = &rest[..name_len];
            rest = rest[name_len..].trim_start().strip_prefix('=')?.trim_start();
            let quote = rest.chars().next().filter(|&c| c == '"' || c == '\'')?;
            let value_len = rest[1..].find(quote)?;
            if name == attr {
                return Some(rest[1..1 + value_len].to_string());
            }
            // Skip opening quote, value and closing quote (both quotes are one byte).
            rest = &rest[value_len + 2..];
        }
    }

    let mut cursor = 0;
    while let Some(rel) = xml[cursor..].find('<') {
        let name_start = cursor + rel + 1;
        let after_lt = &xml[name_start..];
        let name_len = after_lt
            .find(|c: char| c.is_whitespace() || c == '>' || c == '/')
            .unwrap_or(after_lt.len());
        let qualified = &after_lt[..name_len];
        // End tags (`</…`) yield an empty name and never match.
        let local = qualified.rsplit(':').next().unwrap_or(qualified);
        if !qualified.is_empty() && local == element {
            return find_in_tag(&after_lt[name_len..], attr);
        }
        cursor = name_start;
    }
    None
}
1352
/// Extract the trimmed text content of the first element whose local name is `element`
/// (simplified parser).
///
/// Matching rules:
/// - Namespace prefixes are handled: `<saml:Issuer>…</saml:Issuer>` and `<Issuer>…</Issuer>`
///   both match, and the end tag must use the same qualified name as the start tag.
/// - The local name must match exactly, so `Audience` does not match `<AudienceRestriction>`.
/// - Empty-element tags (`<Issuer/>`) carry no text and are skipped in favour of a later
///   match.
///
/// Intended for simple text-only elements. Entity references are returned undecoded.
fn extract_element_text(xml: &str, element: &str) -> Option<String> {
    let mut cursor = 0;
    while let Some(rel) = xml[cursor..].find('<') {
        let name_start = cursor + rel + 1;
        cursor = name_start;

        let after_lt = &xml[name_start..];
        let name_len = after_lt
            .find(|c: char| c.is_whitespace() || c == '>' || c == '/')
            .unwrap_or(after_lt.len());
        let qualified = &after_lt[..name_len];
        let local = qualified.rsplit(':').next().unwrap_or(qualified);
        if qualified.is_empty() || local != element {
            continue;
        }

        let tag_end = name_start + after_lt.find('>')?;
        if xml[..tag_end].ends_with('/') {
            // `<element/>`: no text content here; keep looking.
            continue;
        }

        let content_start = tag_end + 1;
        let closing = format!("</{}>", qualified);
        let content_len = xml[content_start..].find(&closing)?;
        return Some(xml[content_start..content_start + content_len].trim().to_string());
    }
    None
}
1388
1389/// Extract SAML status code (simplified parser)
1390fn extract_status_code(xml: &str) -> Option<SamlStatusCode> {
1391    let pattern = "StatusCode";
1392    let start = xml.find(pattern)?;
1393    let value_start = xml[start..].find("Value=\"")? + start + 7;
1394    let value_end = xml[value_start..].find('"')? + value_start;
1395    let value = &xml[value_start..value_end];
1396
1397    Some(SamlStatusCode::from_urn(value))
1398}
1399
1400/// Parse assertions from SAML response (simplified parser)
1401fn parse_assertions(xml: &str) -> Result<Vec<SamlAssertion>, SamlError> {
1402    let mut assertions = Vec::new();
1403
1404    // Find Assertion element
1405    let assertion_pattern = "<saml:Assertion";
1406    if let Some(start) = xml.find(assertion_pattern) {
1407        let assertion_xml = &xml[start..];
1408
1409        let id = extract_attribute(assertion_xml, "Assertion", "ID")
1410            .unwrap_or_else(|| format!("_generated_{}", timestamp_millis()));
1411        let issue_instant = extract_attribute(assertion_xml, "Assertion", "IssueInstant")
1412            .unwrap_or_default();
1413        let issuer = extract_element_text(assertion_xml, "Issuer").unwrap_or_default();
1414
1415        // Parse NameID
1416        let name_id = extract_element_text(assertion_xml, "NameID").unwrap_or_default();
1417        let name_id_format = extract_attribute(assertion_xml, "NameID", "Format")
1418            .map(|f| NameIdFormat::from_urn(&f))
1419            .unwrap_or_default();
1420
1421        // Parse conditions
1422        let not_before = extract_attribute(assertion_xml, "Conditions", "NotBefore");
1423        let not_on_or_after = extract_attribute(assertion_xml, "Conditions", "NotOnOrAfter");
1424
1425        // Parse session index
1426        let session_index = extract_attribute(assertion_xml, "AuthnStatement", "SessionIndex");
1427        let session_not_on_or_after =
1428            extract_attribute(assertion_xml, "AuthnStatement", "SessionNotOnOrAfter");
1429
1430        // Parse audience
1431        let audiences = extract_element_text(assertion_xml, "Audience")
1432            .map(|a| vec![a])
1433            .unwrap_or_default();
1434
1435        // Parse authn context
1436        let authn_context_class =
1437            extract_element_text(assertion_xml, "AuthnContextClassRef");
1438
1439        // Parse attributes
1440        let attributes = parse_attributes(assertion_xml);
1441
1442        assertions.push(SamlAssertion {
1443            id,
1444            issue_instant,
1445            issuer,
1446            name_id,
1447            name_id_format,
1448            session_index,
1449            session_not_on_or_after,
1450            not_before,
1451            not_on_or_after,
1452            audiences,
1453            authn_context_class,
1454            attributes,
1455        });
1456    }
1457
1458    Ok(assertions)
1459}
1460
/// Parse SAML attributes from the first `AttributeStatement` (simplified parser).
///
/// Element names may carry any namespace prefix (`saml:Attribute`, `saml2:AttributeValue`,
/// or none). Each `Attribute` is keyed by its `Name` attribute; `FriendlyName` is never
/// mistaken for it.
///
/// Values are collected in document order:
/// - the raw, trimmed content of each `AttributeValue` is kept, including any nested markup;
/// - empty-element values (`<AttributeValue/>`) are skipped;
/// - attributes that end up with no values are omitted.
///
/// A later attribute with the same `Name` replaces an earlier one. Attributes outside the
/// first `AttributeStatement` are ignored. Entity references are returned undecoded.
fn parse_attributes(xml: &str) -> HashMap<String, Vec<String>> {
    /// One start or end tag found by `next_tag`.
    struct Tag<'a> {
        /// Byte offset just past the tag's closing `>`
        end: usize,
        /// Qualified name, e.g. `saml:AttributeValue`
        qualified: &'a str,
        /// Local name with any namespace prefix removed
        local: &'a str,
        /// Text after the name: the tag's attributes, possibly ending in `/`
        attrs: &'a str,
        /// `true` for `</…>` end tags
        is_end: bool,
        /// `true` for `<…/>` empty-element tags
        is_empty: bool,
    }

    /// Find the next tag at or after byte offset `from`.
    fn next_tag(xml: &str, from: usize) -> Option<Tag<'_>> {
        let start = from + xml[from..].find('<')?;
        let end = start + xml[start..].find('>')? + 1;
        let inner = &xml[start + 1..end - 1];
        let is_end = inner.starts_with('/');
        let body = if is_end { &inner[1..] } else { inner };
        let name_len = body
            .find(|c: char| c.is_whitespace() || c == '/')
            .unwrap_or(body.len());
        let qualified = &body[..name_len];
        Some(Tag {
            end,
            qualified,
            local: qualified.rsplit(':').next().unwrap_or(qualified),
            attrs: &body[name_len..],
            is_end,
            is_empty: !is_end && inner.ends_with('/'),
        })
    }

    /// Read attribute `wanted` from a tag's attribute text, comparing whole names.
    fn attribute_value(attrs: &str, wanted: &str) -> Option<String> {
        let mut rest = attrs;
        loop {
            rest = rest.trim_start();
            let name_len = rest.find(|c: char| c == '=' || c.is_whitespace())?;
            let name = &rest[..name_len];
            rest = rest[name_len..].trim_start().strip_prefix('=')?.trim_start();
            let quote = rest.chars().next().filter(|&c| c == '"' || c == '\'')?;
            let value_len = rest[1..].find(quote)?;
            if name == wanted {
                return Some(rest[1..1 + value_len].to_string());
            }
            rest = &rest[value_len + 2..];
        }
    }

    /// Move a finished attribute into the result map, dropping it if it has no values.
    fn flush(
        current: &mut Option<(String, Vec<String>)>,
        into: &mut HashMap<String, Vec<String>>,
    ) {
        if let Some((name, values)) = current.take() {
            if !values.is_empty() {
                into.insert(name, values);
            }
        }
    }

    let mut attributes = HashMap::new();

    // Skip ahead to the first AttributeStatement start tag.
    let mut pos = 0;
    let statement = loop {
        let tag = match next_tag(xml, pos) {
            Some(tag) => tag,
            None => return attributes,
        };
        pos = tag.end;
        if !tag.is_end && tag.local == "AttributeStatement" {
            break tag;
        }
    };
    if statement.is_empty {
        return attributes;
    }

    let mut current: Option<(String, Vec<String>)> = None;
    while let Some(tag) = next_tag(xml, pos) {
        pos = tag.end;
        match (tag.local, tag.is_end) {
            ("AttributeStatement", true) => break,
            ("Attribute", false) => {
                flush(&mut current, &mut attributes);
                if !tag.is_empty {
                    current = attribute_value(tag.attrs, "Name").map(|name| (name, Vec::new()));
                }
            }
            ("Attribute", true) => flush(&mut current, &mut attributes),
            ("AttributeValue", false) if !tag.is_empty => {
                let closing = format!("</{}>", tag.qualified);
                if let Some(len) = xml[tag.end..].find(&closing) {
                    if let Some((_, values)) = current.as_mut() {
                        values.push(xml[tag.end..tag.end + len].trim().to_string());
                    }
                    // Jump past the value so nested markup is not re-scanned as tags.
                    pos = tag.end + len + closing.len();
                }
            }
            _ => {}
        }
    }
    // Tolerate a missing `</Attribute>` before the statement (or input) ends.
    flush(&mut current, &mut attributes);

    attributes
}
1513
1514// ============================================================================
1515// Tests
1516// ============================================================================
1517
1518#[cfg(test)]
1519mod tests {
1520    use super::*;
1521
1522    #[test]
1523    fn test_name_id_format() {
1524        assert_eq!(
1525            NameIdFormat::EmailAddress.as_urn(),
1526            "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"
1527        );
1528
1529        let parsed = NameIdFormat::from_urn("urn:oasis:names:tc:SAML:2.0:nameid-format:persistent");
1530        assert_eq!(parsed, NameIdFormat::Persistent);
1531    }
1532
1533    #[test]
1534    fn test_saml_binding() {
1535        assert_eq!(
1536            SamlBinding::HttpPost.as_urn(),
1537            "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
1538        );
1539    }
1540
1541    #[test]
1542    fn test_status_code() {
1543        let success = SamlStatusCode::from_urn("urn:oasis:names:tc:SAML:2.0:status:Success");
1544        assert!(success.is_success());
1545
1546        let failure = SamlStatusCode::from_urn("urn:oasis:names:tc:SAML:2.0:status:AuthnFailed");
1547        assert!(!failure.is_success());
1548    }
1549
1550    #[test]
1551    fn test_config_builder() {
1552        let config = SamlConfig::new()
1553            .entity_id("https://sp.example.com/saml/metadata")
1554            .idp_sso_url("https://idp.example.com/saml/sso")
1555            .acs_url("https://sp.example.com/saml/acs")
1556            .name_id_format(NameIdFormat::EmailAddress)
1557            .want_assertions_signed(true);
1558
1559        assert_eq!(config.entity_id, "https://sp.example.com/saml/metadata");
1560        assert_eq!(config.name_id_format, NameIdFormat::EmailAddress);
1561    }
1562
1563    #[test]
1564    fn test_config_validation() {
1565        let incomplete = SamlConfig::new();
1566        assert!(incomplete.validate().is_err());
1567
1568        let valid = SamlConfig::new()
1569            .entity_id("https://sp.example.com")
1570            .idp_sso_url("https://idp.example.com/sso")
1571            .acs_url("https://sp.example.com/acs")
1572            .want_assertions_signed(false);
1573
1574        assert!(valid.validate().is_ok());
1575    }
1576
1577    #[test]
1578    fn test_config_presets() {
1579        let okta = SamlConfig::okta("myorg.okta.com", "app123", "https://myapp.com");
1580        assert!(okta.idp_sso_url.contains("okta.com"));
1581
1582        let azure = SamlConfig::azure_ad("tenant-id", "app-id", "https://myapp.com");
1583        assert!(azure.idp_sso_url.contains("microsoftonline.com"));
1584
1585        let adfs = SamlConfig::adfs("adfs.company.com", "https://myapp.com");
1586        assert!(adfs.idp_sso_url.contains("adfs"));
1587    }
1588
1589    #[test]
1590    fn test_authn_request_generation() {
1591        let config = SamlConfig::new()
1592            .entity_id("https://sp.example.com")
1593            .idp_sso_url("https://idp.example.com/sso")
1594            .acs_url("https://sp.example.com/acs")
1595            .name_id_format(NameIdFormat::EmailAddress);
1596
1597        let request = AuthnRequest::new(&config);
1598        let xml = request.to_xml();
1599
1600        assert!(xml.contains("AuthnRequest"));
1601        assert!(xml.contains("https://sp.example.com"));
1602        assert!(xml.contains("emailAddress"));
1603    }
1604
1605    #[test]
1606    fn test_authn_request_url() {
1607        let config = SamlConfig::new()
1608            .entity_id("https://sp.example.com")
1609            .idp_sso_url("https://idp.example.com/sso")
1610            .acs_url("https://sp.example.com/acs");
1611
1612        let request = AuthnRequest::new(&config);
1613        let url = request.to_redirect_url(Some("/dashboard"));
1614
1615        assert!(url.starts_with("https://idp.example.com/sso?"));
1616        assert!(url.contains("SAMLRequest="));
1617        assert!(url.contains("RelayState="));
1618    }
1619
1620    #[test]
1621    fn test_assertion_validation() {
1622        let config = SamlConfig::new()
1623            .entity_id("https://sp.example.com")
1624            .idp_sso_url("https://idp.example.com/sso")
1625            .acs_url("https://sp.example.com/acs")
1626            .idp_entity_id("https://idp.example.com")
1627            .max_clock_skew(Duration::from_secs(300));
1628
1629        let assertion = SamlAssertion {
1630            id: "_test".to_string(),
1631            issue_instant: iso8601_now(),
1632            issuer: "https://idp.example.com".to_string(),
1633            name_id: "user@example.com".to_string(),
1634            name_id_format: NameIdFormat::EmailAddress,
1635            session_index: Some("_session123".to_string()),
1636            session_not_on_or_after: None,
1637            not_before: None,
1638            not_on_or_after: None,
1639            audiences: vec!["https://sp.example.com".to_string()],
1640            authn_context_class: None,
1641            attributes: HashMap::new(),
1642        };
1643
1644        assert!(assertion.validate(&config).is_ok());
1645    }
1646
1647    #[test]
1648    fn test_assertion_audience_validation() {
1649        let config = SamlConfig::new()
1650            .entity_id("https://sp.example.com")
1651            .idp_sso_url("https://idp.example.com/sso")
1652            .acs_url("https://sp.example.com/acs");
1653
1654        let assertion = SamlAssertion {
1655            id: "_test".to_string(),
1656            issue_instant: iso8601_now(),
1657            issuer: "https://idp.example.com".to_string(),
1658            name_id: "user@example.com".to_string(),
1659            name_id_format: NameIdFormat::EmailAddress,
1660            session_index: None,
1661            session_not_on_or_after: None,
1662            not_before: None,
1663            not_on_or_after: None,
1664            audiences: vec!["https://other-sp.example.com".to_string()],
1665            authn_context_class: None,
1666            attributes: HashMap::new(),
1667        };
1668
1669        let result = assertion.validate(&config);
1670        assert!(matches!(result, Err(SamlError::AudienceRestrictionNotMet(_))));
1671    }
1672
1673    #[test]
1674    fn test_authenticator_creation() {
1675        let config = SamlConfig::new()
1676            .entity_id("https://sp.example.com")
1677            .idp_sso_url("https://idp.example.com/sso")
1678            .acs_url("https://sp.example.com/acs")
1679            .want_assertions_signed(false);
1680
1681        let authenticator = SamlAuthenticator::new(config);
1682        assert!(authenticator.is_ok());
1683    }
1684
1685    #[test]
1686    fn test_metadata_generation() {
1687        let config = SamlConfig::new()
1688            .entity_id("https://sp.example.com")
1689            .idp_sso_url("https://idp.example.com/sso")
1690            .acs_url("https://sp.example.com/acs")
1691            .sls_url("https://sp.example.com/sls")
1692            .want_assertions_signed(false);
1693
1694        let authenticator = SamlAuthenticator::new(config).unwrap();
1695        let metadata = authenticator.generate_metadata();
1696
1697        assert!(metadata.contains("EntityDescriptor"));
1698        assert!(metadata.contains("https://sp.example.com"));
1699        assert!(metadata.contains("AssertionConsumerService"));
1700        assert!(metadata.contains("SingleLogoutService"));
1701    }
1702
1703    #[test]
1704    fn test_iso8601_generation() {
1705        let now = iso8601_now();
1706        assert!(now.contains("T"));
1707        assert!(now.ends_with("Z"));
1708        assert_eq!(now.len(), 20);
1709    }
1710
1711    #[test]
1712    fn test_iso8601_parsing() {
1713        let timestamp = parse_iso8601("2024-01-15T10:30:00Z");
1714        assert!(timestamp.is_ok());
1715
1716        let invalid = parse_iso8601("invalid");
1717        assert!(invalid.is_err());
1718    }
1719
1720    #[test]
1721    fn test_xml_attribute_extraction() {
1722        let xml = r#"<Response ID="resp123" Version="2.0">"#;
1723        assert_eq!(
1724            extract_attribute(xml, "Response", "ID"),
1725            Some("resp123".to_string())
1726        );
1727    }
1728
1729    #[test]
1730    fn test_attribute_parsing() {
1731        let xml = r#"
1732        <AttributeStatement>
1733            <Attribute Name="email">
1734                <AttributeValue>user@example.com</AttributeValue>
1735            </Attribute>
1736            <Attribute Name="roles">
1737                <AttributeValue>admin</AttributeValue>
1738                <AttributeValue>user</AttributeValue>
1739            </Attribute>
1740        </AttributeStatement>
1741        "#;
1742
1743        let attrs = parse_attributes(xml);
1744        assert_eq!(
1745            attrs.get("email"),
1746            Some(&vec!["user@example.com".to_string()])
1747        );
1748        assert_eq!(
1749            attrs.get("roles"),
1750            Some(&vec!["admin".to_string(), "user".to_string()])
1751        );
1752    }
1753
1754    #[test]
1755    fn test_saml_error_display() {
1756        let err = SamlError::Configuration("test error".to_string());
1757        let display = format!("{}", err);
1758        assert!(display.contains("test error"));
1759    }
1760
1761    #[test]
1762    fn test_logout_request() {
1763        let config = SamlConfig::new()
1764            .entity_id("https://sp.example.com")
1765            .idp_sso_url("https://idp.example.com/sso")
1766            .idp_slo_url("https://idp.example.com/slo")
1767            .acs_url("https://sp.example.com/acs")
1768            .want_assertions_signed(false);
1769
1770        let authenticator = SamlAuthenticator::new(config).unwrap();
1771        let logout_url =
1772            authenticator.create_logout_request("user@example.com", Some("_session123"));
1773
1774        assert!(logout_url.is_some());
1775        let url = logout_url.unwrap();
1776        assert!(url.contains("SAMLRequest="));
1777    }
1778}