Skip to main content

auth_framework/protocols/
ws_federation.rs

1//! WS-Federation Passive Requestor Profile
2//!
3//! Implements the WS-Federation 1.2 passive requestor profile for browser-based
4//! SSO. WS-Federation is commonly used with Active Directory Federation Services
5//! (ADFS) and Azure AD in legacy enterprise environments.
6//!
7//! # Protocol Flow (Passive Requestor)
8//!
9//! 1. Application redirects the user to the STS sign-in URL with `wa=wsignin1.0`
//! 2. STS authenticates the user and posts a security token (SAML assertion)
//!    back to the application's reply URL (`wreply`, or the reply address the
//!    STS has registered for the `wtrealm` realm)
12//! 3. Application validates the security token and establishes a session
13//! 4. Logout: redirect to STS with `wa=wsignout1.0`
14//!
15//! # Security Considerations
16//!
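//! - Federation metadata should be fetched over HTTPS and cached.
//! - Security tokens (SAML assertions) must be validated against the STS
//!   signing certificate. This module does not verify XML or JWT signatures
//!   itself (`trusted_cert_thumbprints` is not consulted here), so callers must
//!   perform signature verification before trusting a returned token.
//! - The `wctx` state parameter protects the sign-in response against CSRF. It
//!   does not prevent replay of a captured assertion; that requires tracking
//!   consumed assertion IDs until they expire.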
21
22use crate::errors::{AuthError, Result};
23use ring::rand::SecureRandom;
24use serde::{Deserialize, Serialize};
25use std::collections::HashMap;
26use std::time::{SystemTime, UNIX_EPOCH};
27
28// ─── WS-Federation Constants ─────────────────────────────────────────────────
29
/// Values accepted for the WS-Federation `wa` (action) parameter.
pub mod action {
    /// `wa` value that starts a passive sign-in.
    pub const SIGN_IN: &str = "wsignin1.0";

    /// `wa` value that requests sign-out at the STS.
    pub const SIGN_OUT: &str = "wsignout1.0";

    /// `wa` value for a sign-out cleanup request.
    pub const SIGN_OUT_CLEANUP: &str = "wsignoutcleanup1.0";

    /// `wa` value for an attribute request.
    pub const ATTRIBUTE: &str = "wattr1.0";
}
41
/// Namespace URIs that appear in WS-Federation messages and the tokens they carry.
pub mod ns {
    /// WS-Federation namespace.
    pub const WS_FED: &str = "http://docs.oasis-open.org/wsfed/federation/200706";

    /// WS-Trust namespace (the `RequestSecurityTokenResponse` envelope).
    pub const WS_TRUST: &str = "http://docs.oasis-open.org/ws-sx/ws-trust/200512";

    /// WS-Addressing namespace (endpoint references in metadata).
    pub const WS_ADDRESSING: &str = "http://www.w3.org/2005/08/addressing";

    /// Assertion namespace used by SAML 1.x tokens.
    pub const SAML_11_ASSERTION: &str = "urn:oasis:names:tc:SAML:1.0:assertion";

    /// Assertion namespace used by SAML 2.0 tokens.
    pub const SAML_20_ASSERTION: &str = "urn:oasis:names:tc:SAML:2.0:assertion";
}
50
51// ─── Configuration ───────────────────────────────────────────────────────────
52
/// Relying-party (RP) settings for talking to a WS-Federation STS.
#[derive(Debug, Clone)]
pub struct WsFederationConfig {
    /// Sign-in endpoint of the Security Token Service. Must be non-empty and `https://`.
    pub sts_url: String,

    /// Realm identifying this relying party; sent as `wtrealm` and compared with the token
    /// audience.
    pub realm: String,

    /// Address the STS should post the token back to; sent as `wreply`.
    pub reply_url: String,

    /// Location of the STS federation metadata document, used by `fetch_metadata`.
    pub metadata_url: Option<String>,

    /// Issuer identifiers accepted in tokens. An empty list accepts any issuer.
    pub trusted_issuers: Vec<String>,

    /// SHA-256 fingerprints (hex) of trusted STS signing certificates.
    /// NOTE(review): not read anywhere in this module.
    pub trusted_cert_thumbprints: Vec<String>,

    /// Tolerated clock difference, in seconds, when checking token validity times.
    pub max_clock_skew_secs: u64,

    /// Whether tokens must arrive encrypted.
    /// NOTE(review): not read anywhere in this module.
    pub require_encrypted_tokens: bool,

    /// Timeout, in seconds, applied to outbound HTTP requests.
    pub timeout_secs: u64,

    /// Home realm hint for IdP discovery; sent as `whr` when set.
    pub home_realm: Option<String>,
}
86
87impl Default for WsFederationConfig {
88    fn default() -> Self {
89        Self {
90            sts_url: String::new(),
91            realm: String::new(),
92            reply_url: String::new(),
93            metadata_url: None,
94            trusted_issuers: Vec::new(),
95            trusted_cert_thumbprints: Vec::new(),
96            max_clock_skew_secs: 300,
97            require_encrypted_tokens: false,
98            timeout_secs: 10,
99            home_realm: None,
100        }
101    }
102}
103
104// ─── Data Types ──────────────────────────────────────────────────────────────
105
106/// Parsed WS-Federation sign-in response.
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct WsFedSignInResponse {
109    /// The action (`wa` parameter) — should be `wsignin1.0`.
110    pub action: String,
111
112    /// The `wresult` — contains the `<RequestSecurityTokenResponse>`.
113    pub result_xml: String,
114
115    /// The `wctx` context/state parameter echoed back from the STS.
116    pub context: Option<String>,
117}
118
119/// Validated security token extracted from a WS-Federation response.
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct WsFedSecurityToken {
122    /// Token type (SAML 1.1 or SAML 2.0).
123    pub token_type: WsFedTokenType,
124
125    /// Authenticated subject / name identifier.
126    pub subject: String,
127
128    /// Token issuer.
129    pub issuer: String,
130
131    /// Token audience (appliesTo / realm).
132    pub audience: String,
133
134    /// When the token was issued.
135    pub issued_at: u64,
136
137    /// When the token expires.
138    pub expires_at: u64,
139
140    /// Claims / attributes from the token.
141    pub claims: HashMap<String, Vec<String>>,
142
143    /// Raw assertion XML (for downstream processing).
144    pub raw_assertion: String,
145}
146
147/// Security token type.
148#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
149pub enum WsFedTokenType {
150    /// SAML 1.1 assertion.
151    Saml11,
152    /// SAML 2.0 assertion.
153    Saml20,
154    /// JWT (non-standard but used by Azure AD).
155    Jwt,
156}
157
158/// Federation metadata for an STS.
159#[derive(Debug, Clone, Serialize, Deserialize)]
160pub struct FederationMetadata {
161    /// STS entity ID.
162    pub entity_id: String,
163
164    /// Passive requestor endpoint URL.
165    pub passive_endpoint: Option<String>,
166
167    /// Signing certificate(s) as Base64-encoded DER.
168    pub signing_certificates: Vec<String>,
169
170    /// Token types offered.
171    pub token_types_offered: Vec<String>,
172
173    /// Claim types offered.
174    pub claim_types_offered: Vec<ClaimTypeOffered>,
175}
176
177/// Claim type advertised in federation metadata.
178#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct ClaimTypeOffered {
180    pub uri: String,
181    pub display_name: Option<String>,
182    pub description: Option<String>,
183}
184
185// ─── Client ──────────────────────────────────────────────────────────────────
186
187/// WS-Federation Relying Party client.
188#[derive(Debug)]
189pub struct WsFederationClient {
190    config: WsFederationConfig,
191    http: reqwest::Client,
192}
193
194impl WsFederationClient {
195    /// Create a new WS-Federation client.
196    pub fn new(config: WsFederationConfig) -> Result<Self> {
197        if config.sts_url.is_empty() {
198            return Err(AuthError::config("WS-Federation STS URL must be set"));
199        }
200        if !config.sts_url.starts_with("https://") {
201            return Err(AuthError::config("WS-Federation STS URL must use HTTPS"));
202        }
203        if config.realm.is_empty() {
204            return Err(AuthError::config("WS-Federation realm must be set"));
205        }
206
207        let http = reqwest::Client::builder()
208            .timeout(std::time::Duration::from_secs(config.timeout_secs))
209            .build()
210            .map_err(|e| AuthError::internal(format!("Failed to build HTTP client: {e}")))?;
211
212        Ok(Self { config, http })
213    }
214
215    /// Generate the WS-Federation sign-in redirect URL.
216    ///
217    /// Returns `(url, wctx)` where `wctx` is a random state value for
218    /// CSRF protection.
219    pub fn sign_in_url(&self) -> Result<(String, String)> {
220        let rng = ring::rand::SystemRandom::new();
221        let mut ctx_bytes = [0u8; 16];
222        rng.fill(&mut ctx_bytes)
223            .map_err(|_| AuthError::crypto("Failed to generate wctx nonce"))?;
224        let wctx = hex::encode(ctx_bytes);
225
226        let mut url = format!(
227            "{}?wa={}&wtrealm={}&wreply={}&wctx={}",
228            self.config.sts_url,
229            urlencoding::encode(action::SIGN_IN),
230            urlencoding::encode(&self.config.realm),
231            urlencoding::encode(&self.config.reply_url),
232            urlencoding::encode(&wctx),
233        );
234
235        if let Some(ref whr) = self.config.home_realm {
236            url.push_str(&format!("&whr={}", urlencoding::encode(whr)));
237        }
238
239        Ok((url, wctx))
240    }
241
242    /// Generate the WS-Federation sign-out URL.
243    pub fn sign_out_url(&self) -> String {
244        format!(
245            "{}?wa={}&wtrealm={}",
246            self.config.sts_url,
247            urlencoding::encode(action::SIGN_OUT),
248            urlencoding::encode(&self.config.realm),
249        )
250    }
251
252    /// Process a WS-Federation sign-in response (POST from STS).
253    ///
254    /// The parameters are extracted from the form POST body:
255    /// - `wa`: action (should be `wsignin1.0`)
256    /// - `wresult`: the `<RequestSecurityTokenResponse>` XML
257    /// - `wctx`: echoed state parameter
258    pub fn process_sign_in_response(
259        &self,
260        wa: &str,
261        wresult: &str,
262        wctx: Option<&str>,
263        expected_wctx: &str,
264    ) -> Result<WsFedSecurityToken> {
265        // Validate action
266        if wa != action::SIGN_IN {
267            return Err(AuthError::validation(format!(
268                "Unexpected WS-Fed action: {wa}"
269            )));
270        }
271
272        // Validate wctx for CSRF protection
273        if let Some(ctx) = wctx {
274            if !constant_time_eq(ctx.as_bytes(), expected_wctx.as_bytes()) {
275                return Err(AuthError::validation("WS-Federation wctx mismatch (CSRF)"));
276            }
277        } else {
278            return Err(AuthError::validation("Missing wctx parameter"));
279        }
280
281        // Parse the RSTR envelope
282        let token = self.parse_rstr(wresult)?;
283
284        // Validate token
285        self.validate_token(&token)?;
286
287        Ok(token)
288    }
289
290    /// Fetch and parse federation metadata from the metadata URL.
291    pub async fn fetch_metadata(&self) -> Result<FederationMetadata> {
292        let url = self
293            .config
294            .metadata_url
295            .as_deref()
296            .ok_or_else(|| AuthError::config("Federation metadata URL not configured"))?;
297
298        let resp =
299            self.http.get(url).send().await.map_err(|e| {
300                AuthError::internal(format!("Federation metadata fetch failed: {e}"))
301            })?;
302
303        if !resp.status().is_success() {
304            let status = resp.status();
305            return Err(AuthError::internal(format!(
306                "Federation metadata HTTP error: {status}"
307            )));
308        }
309
310        let body = resp
311            .text()
312            .await
313            .map_err(|e| AuthError::internal(format!("Federation metadata read failed: {e}")))?;
314
315        parse_federation_metadata(&body)
316    }
317
318    /// Parse a RequestSecurityTokenResponse (RSTR) to extract the security token.
319    fn parse_rstr(&self, rstr_xml: &str) -> Result<WsFedSecurityToken> {
320        // Determine token type
321        let token_type = if rstr_xml.contains(ns::SAML_20_ASSERTION) {
322            WsFedTokenType::Saml20
323        } else if rstr_xml.contains(ns::SAML_11_ASSERTION) {
324            WsFedTokenType::Saml11
325        } else if rstr_xml.contains("\"JWT\"") || rstr_xml.contains("jwt") {
326            WsFedTokenType::Jwt
327        } else {
328            WsFedTokenType::Saml20 // Default assumption
329        };
330
331        // Extract the assertion
332        let raw_assertion = extract_assertion(rstr_xml)?;
333
334        // Parse claims from the assertion
335        let (subject, issuer, audience, issued_at, expires_at, claims) = match token_type {
336            WsFedTokenType::Saml20 => parse_saml20_assertion(&raw_assertion)?,
337            WsFedTokenType::Saml11 => parse_saml11_assertion(&raw_assertion)?,
338            WsFedTokenType::Jwt => parse_jwt_token(&raw_assertion)?,
339        };
340
341        Ok(WsFedSecurityToken {
342            token_type,
343            subject,
344            issuer,
345            audience,
346            issued_at,
347            expires_at,
348            claims,
349            raw_assertion,
350        })
351    }
352
353    /// Validate a security token against the configuration.
354    fn validate_token(&self, token: &WsFedSecurityToken) -> Result<()> {
355        // Check issuer trust
356        if !self.config.trusted_issuers.is_empty()
357            && !self.config.trusted_issuers.contains(&token.issuer)
358        {
359            return Err(AuthError::validation(format!(
360                "Token issuer '{}' is not trusted",
361                token.issuer
362            )));
363        }
364
365        // Check audience
366        if !token.audience.is_empty() && token.audience != self.config.realm {
367            return Err(AuthError::validation(format!(
368                "Token audience '{}' does not match realm '{}'",
369                token.audience, self.config.realm
370            )));
371        }
372
373        // Check expiration
374        let now = SystemTime::now()
375            .duration_since(UNIX_EPOCH)
376            .map_err(|e| AuthError::internal(format!("Clock error: {e}")))?
377            .as_secs();
378
379        let skew = self.config.max_clock_skew_secs;
380        if token.expires_at + skew < now {
381            return Err(AuthError::validation("Security token has expired"));
382        }
383
384        if token.issued_at > now + skew {
385            return Err(AuthError::validation("Security token issued in the future"));
386        }
387
388        Ok(())
389    }
390}
391
392// ─── XML Parsing Helpers ─────────────────────────────────────────────────────
393
394/// Extract the SAML assertion from a RequestSecurityTokenResponse.
395fn extract_assertion(rstr: &str) -> Result<String> {
396    // Look for <Assertion> or <saml:Assertion>
397    let assertion_tags = [
398        ("saml:Assertion", "</saml:Assertion>"),
399        ("saml2:Assertion", "</saml2:Assertion>"),
400        ("Assertion", "</Assertion>"),
401    ];
402
403    for (open_tag, close_tag) in &assertion_tags {
404        let open = format!("<{open_tag}");
405        if let Some(start) = rstr.find(&open) {
406            if let Some(end) = rstr[start..].find(close_tag) {
407                return Ok(rstr[start..start + end + close_tag.len()].to_string());
408            }
409        }
410    }
411
412    Err(AuthError::validation(
413        "No SAML assertion found in WS-Federation response",
414    ))
415}
416
417/// Parse a SAML 2.0 assertion.
418fn parse_saml20_assertion(
419    xml: &str,
420) -> Result<(
421    String,
422    String,
423    String,
424    u64,
425    u64,
426    HashMap<String, Vec<String>>,
427)> {
428    let subject = extract_xml_text(xml, "NameID")
429        .or_else(|| extract_xml_text(xml, "saml:NameID"))
430        .unwrap_or_default();
431
432    let issuer = extract_xml_text(xml, "Issuer")
433        .or_else(|| extract_xml_text(xml, "saml:Issuer"))
434        .unwrap_or_default();
435
436    let audience = extract_xml_text(xml, "Audience")
437        .or_else(|| extract_xml_text(xml, "saml:Audience"))
438        .unwrap_or_default();
439
440    let not_before = extract_xml_attr_val(xml, "Conditions", "NotBefore")
441        .or_else(|| extract_xml_attr_val(xml, "saml:Conditions", "NotBefore"))
442        .and_then(|s| parse_iso_timestamp(&s))
443        .unwrap_or(0);
444
445    let not_on_or_after = extract_xml_attr_val(xml, "Conditions", "NotOnOrAfter")
446        .or_else(|| extract_xml_attr_val(xml, "saml:Conditions", "NotOnOrAfter"))
447        .and_then(|s| parse_iso_timestamp(&s))
448        .unwrap_or(u64::MAX);
449
450    let claims = extract_saml_attributes(xml);
451
452    Ok((
453        subject,
454        issuer,
455        audience,
456        not_before,
457        not_on_or_after,
458        claims,
459    ))
460}
461
462/// Parse a SAML 1.1 assertion.
463fn parse_saml11_assertion(
464    xml: &str,
465) -> Result<(
466    String,
467    String,
468    String,
469    u64,
470    u64,
471    HashMap<String, Vec<String>>,
472)> {
473    let subject = extract_xml_text(xml, "NameIdentifier")
474        .or_else(|| extract_xml_text(xml, "saml:NameIdentifier"))
475        .unwrap_or_default();
476
477    let issuer = extract_xml_attr_val(xml, "Assertion", "Issuer")
478        .or_else(|| extract_xml_attr_val(xml, "saml:Assertion", "Issuer"))
479        .unwrap_or_default();
480
481    let audience = extract_xml_text(xml, "Audience")
482        .or_else(|| extract_xml_text(xml, "saml:Audience"))
483        .unwrap_or_default();
484
485    let not_before = extract_xml_attr_val(xml, "Conditions", "NotBefore")
486        .and_then(|s| parse_iso_timestamp(&s))
487        .unwrap_or(0);
488
489    let not_on_or_after = extract_xml_attr_val(xml, "Conditions", "NotOnOrAfter")
490        .and_then(|s| parse_iso_timestamp(&s))
491        .unwrap_or(u64::MAX);
492
493    let claims = extract_saml_attributes(xml);
494
495    Ok((
496        subject,
497        issuer,
498        audience,
499        not_before,
500        not_on_or_after,
501        claims,
502    ))
503}
504
505/// Parse a JWT token extracted from a WS-Fed response (used by Azure AD).
506///
507/// NOTE: This performs payload extraction and validation of standard claims
508/// without cryptographic signature verification. Signature verification
509/// requires the IdP's public keys (from federation metadata JWKS).
510fn parse_jwt_token(
511    jwt_str: &str,
512) -> Result<(
513    String,
514    String,
515    String,
516    u64,
517    u64,
518    HashMap<String, Vec<String>>,
519)> {
520    use base64::Engine as _;
521    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
522
523    // JWT has three base64url-encoded parts separated by dots
524    let parts: Vec<&str> = jwt_str.trim().split('.').collect();
525    if parts.len() != 3 {
526        return Err(AuthError::validation(
527            "Invalid JWT format: expected 3 parts",
528        ));
529    }
530
531    // Decode and parse the payload (second part)
532    let payload_bytes = URL_SAFE_NO_PAD
533        .decode(parts[1])
534        .map_err(|e| AuthError::validation(format!("Invalid JWT payload encoding: {e}")))?;
535
536    let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
537        .map_err(|e| AuthError::validation(format!("Invalid JWT payload JSON: {e}")))?;
538
539    let subject = payload["sub"]
540        .as_str()
541        .or_else(|| payload["upn"].as_str())
542        .or_else(|| payload["email"].as_str())
543        .unwrap_or_default()
544        .to_string();
545
546    let issuer = payload["iss"].as_str().unwrap_or_default().to_string();
547
548    let audience = payload["aud"].as_str().unwrap_or_default().to_string();
549
550    let issued_at = payload["iat"].as_u64().unwrap_or(0);
551    let expires_at = payload["exp"].as_u64().unwrap_or(u64::MAX);
552
553    // Extract all string claims
554    let mut claims = HashMap::new();
555    if let Some(obj) = payload.as_object() {
556        for (key, value) in obj {
557            match value {
558                serde_json::Value::String(s) => {
559                    claims
560                        .entry(key.clone())
561                        .or_insert_with(Vec::new)
562                        .push(s.clone());
563                }
564                serde_json::Value::Array(arr) => {
565                    for item in arr {
566                        if let Some(s) = item.as_str() {
567                            claims
568                                .entry(key.clone())
569                                .or_insert_with(Vec::new)
570                                .push(s.to_string());
571                        }
572                    }
573                }
574                _ => {}
575            }
576        }
577    }
578
579    Ok((subject, issuer, audience, issued_at, expires_at, claims))
580}
581
/// Extract the trimmed text content of the first element named exactly `tag`.
///
/// `tag` is matched literally, including any prefix (e.g. `"saml:Issuer"`). A candidate is
/// accepted only when the name is followed by `>`, `/` or any whitespace. This rejects prefix
/// matches such as `Audience` inside `AudienceRestriction`, while still accepting attributes
/// that start on a new line. Returns `None` when no matching element with a close tag exists.
fn extract_xml_text(xml: &str, tag: &str) -> Option<String> {
    let open = format!("<{tag}");
    let close = format!("</{tag}>");
    let mut search_from = 0;

    loop {
        let start = search_from + xml[search_from..].find(&open)?;
        let after_name = start + open.len();

        let is_boundary = xml[after_name..]
            .chars()
            .next()
            .map_or(true, |c| c == '>' || c == '/' || c.is_whitespace());
        if !is_boundary {
            search_from = after_name;
            continue;
        }

        let content_start = after_name + xml[after_name..].find('>')? + 1;
        let content_len = xml[content_start..].find(&close)?;
        return Some(
            xml[content_start..content_start + content_len]
                .trim()
                .to_string(),
        );
    }
}
608
/// Extract the value of attribute `attr` on the first element named exactly `tag`.
///
/// The element name must be followed by whitespace, `>` or `/`, so `Conditions` does not match
/// `ConditionsX`. The attribute name must be preceded by whitespace, so `Name` does not match
/// inside `FriendlyName`. Single- and double-quoted values are accepted. Only the first
/// matching element is inspected, and only its start tag (up to the first `>`).
fn extract_xml_attr_val(xml: &str, tag: &str, attr: &str) -> Option<String> {
    let open = format!("<{tag}");

    let mut from = 0;
    let start = loop {
        let candidate = from + xml[from..].find(&open)?;
        let after_name = candidate + open.len();
        let is_exact_name = xml[after_name..]
            .chars()
            .next()
            .map_or(false, |c| c == '>' || c == '/' || c.is_whitespace());
        if is_exact_name {
            break candidate;
        }
        from = after_name;
    };

    let tag_len = xml[start..].find('>')?;
    let tag_content = &xml[start..start + tag_len];

    let needle = format!("{attr}=");
    let mut search = 0;
    while let Some(rel) = tag_content[search..].find(&needle) {
        let at = search + rel;
        search = at + needle.len();

        let quote = match tag_content[search..].chars().next() {
            Some(q @ ('"' | '\'')) => q,
            _ => continue,
        };
        if !tag_content[..at].ends_with(|c: char| c.is_whitespace()) {
            continue;
        }

        let value_start = search + 1;
        let value_len = tag_content[value_start..].find(quote)?;
        return Some(tag_content[value_start..value_start + value_len].to_string());
    }

    None
}
621
622/// Extract SAML attribute statements.
623fn extract_saml_attributes(xml: &str) -> HashMap<String, Vec<String>> {
624    let mut attrs = HashMap::new();
625
626    // Find all <Attribute Name="..."> elements
627    let mut search_pos = 0;
628    while let Some(attr_pos) = xml[search_pos..].find("<Attribute ") {
629        let abs_pos = search_pos + attr_pos;
630        let tag_end = match xml[abs_pos..].find('>') {
631            Some(p) => abs_pos + p,
632            None => break,
633        };
634
635        // Extract attribute name
636        let tag_content = &xml[abs_pos..tag_end];
637        let name = if let Some(n) = extract_inline_attr(tag_content, "Name") {
638            n
639        } else {
640            search_pos = tag_end + 1;
641            continue;
642        };
643
644        // Find the closing </Attribute> or </saml:Attribute>
645        let close_tag = "</Attribute>";
646        let alt_close = format!("</saml:Attribute>");
647        let end_pos = xml[tag_end..]
648            .find(close_tag)
649            .or_else(|| xml[tag_end..].find(&alt_close))
650            .map(|p| tag_end + p)
651            .unwrap_or(xml.len());
652
653        // Extract <AttributeValue> elements within this attribute
654        let attr_block = &xml[tag_end + 1..end_pos];
655        let values = extract_attribute_values(attr_block);
656
657        if !values.is_empty() {
658            attrs.entry(name).or_insert_with(Vec::new).extend(values);
659        }
660
661        search_pos = end_pos;
662    }
663
664    attrs
665}
666
/// Return the double-quoted value that follows the first occurrence of `attr="` in `tag`.
///
/// This is a plain substring search, so `attr` also matches as the tail of a longer attribute
/// name (e.g. `Name` inside `FriendlyName="…"` or `AttributeName="…"`). Returns `None` when
/// the key is absent or the value has no closing quote.
fn extract_inline_attr(tag: &str, attr: &str) -> Option<String> {
    let key = format!("{attr}=\"");
    let (_, after_key) = tag.split_once(key.as_str())?;
    let (value, _) = after_key.split_once('"')?;
    Some(value.to_owned())
}
675
/// Extract the trimmed text of every `AttributeValue` element in `block`, in document order.
///
/// Elements under any namespace prefix are matched (`AttributeValue`, `saml:AttributeValue`,
/// `saml2:AttributeValue`, …), each closed by the close tag with the same prefix. Empty and
/// self-closing values (e.g. `<AttributeValue xsi:nil="true"/>`) are skipped. Scanning stops at
/// the first value element whose start or close tag is unterminated.
fn extract_attribute_values(block: &str) -> Vec<String> {
    let mut values = Vec::new();
    let mut pos = 0;

    while let Some(rel) = block[pos..].find('<') {
        let start = pos + rel;
        let name_start = start + 1;
        let name_end = block[name_start..]
            .find(|c: char| c.is_whitespace() || c == '>' || c == '/')
            .map_or(block.len(), |p| name_start + p);
        let qname = &block[name_start..name_end];

        if qname.rsplit(':').next() != Some("AttributeValue") {
            pos = name_start;
            continue;
        }

        let tag_end = match block[name_end..].find('>') {
            Some(p) => name_end + p,
            None => break,
        };

        if block[start..tag_end].ends_with('/') {
            pos = tag_end + 1;
            continue;
        }

        let content_start = tag_end + 1;
        let close = format!("</{qname}>");
        let content_len = match block[content_start..].find(&close) {
            Some(p) => p,
            None => break,
        };

        let value = block[content_start..content_start + content_len].trim();
        if !value.is_empty() {
            values.push(value.to_string());
        }
        pos = content_start + content_len + close.len();
    }

    values
}
710
711/// Parse an ISO 8601 timestamp to Unix epoch seconds.
712fn parse_iso_timestamp(s: &str) -> Option<u64> {
713    chrono::DateTime::parse_from_rfc3339(s)
714        .ok()
715        .map(|dt| dt.timestamp() as u64)
716}
717
718/// Parse federation metadata XML.
719fn parse_federation_metadata(xml: &str) -> Result<FederationMetadata> {
720    let entity_id = extract_xml_attr_val(xml, "EntityDescriptor", "entityID").unwrap_or_default();
721
722    // Extract passive endpoint
723    let passive_endpoint = extract_xml_attr_val(xml, "PassiveRequestorEndpoint", "Location")
724        .or_else(|| {
725            // Also look in Address element
726            extract_xml_text(xml, "Address").or_else(|| extract_xml_text(xml, "wsa:Address"))
727        });
728
729    // Extract signing certificates
730    let signing_certificates = extract_all_x509_certs(xml);
731
732    // Extract token types offered
733    let token_types = extract_xml_list_by_tag(xml, "TokenType");
734
735    // Extract claim types
736    let claim_types = extract_claim_types(xml);
737
738    Ok(FederationMetadata {
739        entity_id,
740        passive_endpoint,
741        signing_certificates,
742        token_types_offered: token_types,
743        claim_types_offered: claim_types,
744    })
745}
746
/// Extract every `X509Certificate` value from metadata, in document order.
///
/// Matches the element under any namespace prefix (`X509Certificate`, `ds:X509Certificate`,
/// …). All whitespace (spaces, tabs, line breaks) is removed from the Base64 content. Empty
/// and self-closing elements are skipped. Scanning stops at an element whose close tag is
/// missing.
fn extract_all_x509_certs(xml: &str) -> Vec<String> {
    let mut certs = Vec::new();
    let mut pos = 0;

    while let Some(rel) = xml[pos..].find('<') {
        let start = pos + rel;
        let name_start = start + 1;
        let name_end = xml[name_start..]
            .find(|c: char| c.is_whitespace() || c == '>' || c == '/')
            .map_or(xml.len(), |p| name_start + p);
        let qname = &xml[name_start..name_end];

        if qname.rsplit(':').next() != Some("X509Certificate") {
            pos = name_start;
            continue;
        }

        let tag_end = match xml[name_end..].find('>') {
            Some(p) => name_end + p,
            None => break,
        };
        if xml[start..tag_end].ends_with('/') {
            pos = tag_end + 1;
            continue;
        }

        let content_start = tag_end + 1;
        let close = format!("</{qname}>");
        let content_len = match xml[content_start..].find(&close) {
            Some(p) => p,
            None => break,
        };

        let cert: String = xml[content_start..content_start + content_len]
            .chars()
            .filter(|c| !c.is_whitespace())
            .collect();
        if !cert.is_empty() {
            certs.push(cert);
        }
        pos = content_start + content_len + close.len();
    }

    certs
}
772
/// Collect the value of every element named `tag`, in document order.
///
/// An unprefixed `tag` matches that local name under any namespace prefix (so `"TokenType"`
/// matches `<fed:TokenType>`); a prefixed `tag` must match exactly. The element name must end
/// at whitespace, `>` or `/`, so `TokenType` does not match `TokenTypesOffered`.
///
/// Each element contributes its trimmed text content. When it has none (including
/// self-closing elements), it contributes its `Uri` attribute instead, because WS-Federation
/// metadata encodes token types as `<fed:TokenType Uri="…"/>`. Elements with neither
/// contribute nothing. Scanning stops at the first non-self-closing element whose close tag
/// is missing.
fn extract_xml_list_by_tag(xml: &str, tag: &str) -> Vec<String> {
    let mut values = Vec::new();
    let mut pos = 0;

    while let Some(rel) = xml[pos..].find('<') {
        let start = pos + rel;
        let name_start = start + 1;
        let name_end = xml[name_start..]
            .find(|c: char| c.is_whitespace() || c == '>' || c == '/')
            .map_or(xml.len(), |p| name_start + p);
        let qname = &xml[name_start..name_end];

        let is_match = if tag.contains(':') {
            qname == tag
        } else {
            qname.rsplit(':').next() == Some(tag)
        };
        if !is_match {
            pos = name_start;
            continue;
        }

        let tag_end = match xml[name_end..].find('>') {
            Some(p) => name_end + p,
            None => break,
        };
        let start_tag = &xml[start..tag_end];

        let (text, resume) = if start_tag.ends_with('/') {
            ("", tag_end + 1)
        } else {
            let close = format!("</{qname}>");
            match xml[tag_end + 1..].find(&close) {
                Some(p) => (
                    xml[tag_end + 1..tag_end + 1 + p].trim(),
                    tag_end + 1 + p + close.len(),
                ),
                None => break,
            }
        };

        if !text.is_empty() {
            values.push(text.to_string());
        } else if let Some(uri) = start_tag
            .split_once(" Uri=\"")
            .and_then(|(_, rest)| rest.split_once('"'))
            .map(|(value, _)| value)
            .filter(|value| !value.is_empty())
        {
            values.push(uri.to_string());
        }

        pos = resume;
    }

    values
}
800
801/// Extract claim types offered from metadata.
802fn extract_claim_types(xml: &str) -> Vec<ClaimTypeOffered> {
803    let mut claims = Vec::new();
804    let mut pos = 0;
805
806    while let Some(start) = xml[pos..].find("<ClaimType ") {
807        let abs = pos + start;
808        let tag_end = match xml[abs..].find('>') {
809            Some(p) => abs + p,
810            None => break,
811        };
812        let tag = &xml[abs..tag_end];
813
814        let uri = extract_inline_attr(tag, "Uri").unwrap_or_default();
815
816        let close = "</ClaimType>";
817        let block_end = xml[tag_end..]
818            .find(close)
819            .map(|p| tag_end + p + close.len())
820            .unwrap_or(xml.len());
821
822        let block = &xml[tag_end + 1..block_end.saturating_sub(close.len())];
823
824        let display_name = extract_xml_text(block, "DisplayName");
825        let description = extract_xml_text(block, "Description");
826
827        claims.push(ClaimTypeOffered {
828            uri,
829            display_name,
830            description,
831        });
832
833        pos = block_end;
834    }
835
836    claims
837}
838
839/// Constant-time byte comparison.
840fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
841    use subtle::ConstantTimeEq;
842    if a.len() != b.len() {
843        return false;
844    }
845    a.ct_eq(b).into()
846}
847
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration for a fictional ADFS deployment with realm and reply URL set.
    fn adfs_config() -> WsFederationConfig {
        WsFederationConfig {
            sts_url: "https://adfs.example.com/adfs/ls".into(),
            realm: "https://app.example.com".into(),
            reply_url: "https://app.example.com/auth/wsfed".into(),
            ..Default::default()
        }
    }

    #[test]
    fn test_config_defaults() {
        let defaults = WsFederationConfig::default();
        assert_eq!(defaults.max_clock_skew_secs, 300);
        assert!(!defaults.require_encrypted_tokens);
    }

    #[test]
    fn test_client_requires_https() {
        let insecure = WsFederationConfig {
            sts_url: "http://adfs.example.com/adfs/ls".into(),
            realm: "https://app.example.com".into(),
            ..Default::default()
        };
        let message = WsFederationClient::new(insecure).unwrap_err().to_string();
        assert!(message.contains("HTTPS"));
    }

    #[test]
    fn test_client_requires_realm() {
        let realmless = WsFederationConfig {
            sts_url: "https://adfs.example.com/adfs/ls".into(),
            ..Default::default()
        };
        let message = WsFederationClient::new(realmless).unwrap_err().to_string();
        assert!(message.contains("realm"));
    }

    #[test]
    fn test_sign_in_url() {
        let client = WsFederationClient::new(adfs_config()).expect("valid config");
        let (url, wctx) = client.sign_in_url().expect("RNG available");
        for fragment in ["wa=wsignin1.0", "wtrealm=", "wreply="].iter() {
            assert!(url.contains(fragment), "missing {fragment} in {url}");
        }
        assert!(url.contains(&wctx));
    }

    #[test]
    fn test_sign_out_url() {
        let client = WsFederationClient::new(adfs_config()).expect("valid config");
        assert!(client.sign_out_url().contains("wa=wsignout1.0"));
    }

    #[test]
    fn test_parse_saml20_assertion() {
        let xml = r#"
        <saml:Assertion xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion" Version="2.0">
            <saml:Issuer>https://idp.example.com</saml:Issuer>
            <saml:Subject>
                <saml:NameID>jdoe@example.com</saml:NameID>
            </saml:Subject>
            <saml:Conditions NotBefore="2024-01-01T00:00:00Z" NotOnOrAfter="2034-01-01T00:00:00Z">
                <saml:AudienceRestriction>
                    <saml:Audience>https://app.example.com</saml:Audience>
                </saml:AudienceRestriction>
            </saml:Conditions>
            <saml:AttributeStatement>
                <Attribute Name="email">
                    <AttributeValue>jdoe@example.com</AttributeValue>
                </Attribute>
            </saml:AttributeStatement>
        </saml:Assertion>
        "#;

        let parsed = parse_saml20_assertion(xml).expect("well-formed assertion");
        let (subject, issuer, audience, _not_before, _not_on_or_after, claims) = parsed;

        assert_eq!(subject, "jdoe@example.com");
        assert_eq!(issuer, "https://idp.example.com");
        assert_eq!(audience, "https://app.example.com");
        assert!(claims.contains_key("email"));
    }

    #[test]
    fn test_extract_assertion() {
        let rstr = r#"
        <RequestSecurityTokenResponse>
            <saml:Assertion ID="_abc123">
                <saml:Issuer>test</saml:Issuer>
            </saml:Assertion>
        </RequestSecurityTokenResponse>
        "#;
        let assertion = extract_assertion(rstr).expect("assertion present");
        assert!(assertion.contains("saml:Assertion"));
        assert!(assertion.contains("saml:Issuer"));
    }

    #[test]
    fn test_action_constants() {
        assert_eq!(action::SIGN_IN, "wsignin1.0");
        assert_eq!(action::SIGN_OUT, "wsignout1.0");
    }
}