1use crate::errors::{AuthError, Result};
23use ring::rand::SecureRandom;
24use serde::{Deserialize, Serialize};
25use std::collections::HashMap;
26use std::time::{SystemTime, UNIX_EPOCH};
27
/// Values of the WS-Federation `wa` (action) request parameter.
pub mod action {
    /// `wa` value for an attribute request.
    pub const ATTRIBUTE: &str = "wattr1.0";
    /// `wa` value for a passive sign-in request and its response.
    pub const SIGN_IN: &str = "wsignin1.0";
    /// `wa` value for a sign-out request.
    pub const SIGN_OUT: &str = "wsignout1.0";
    /// `wa` value for a sign-out clean-up request.
    pub const SIGN_OUT_CLEANUP: &str = "wsignoutcleanup1.0";
}
41
/// XML namespace URIs used by WS-Federation, WS-Trust and SAML messages.
pub mod ns {
    /// SAML 1.x assertion namespace. SAML 1.1 reuses the 1.0 URI, which is why
    /// this constant says "1.0" even though it identifies 1.1 tokens.
    pub const SAML_11_ASSERTION: &str = "urn:oasis:names:tc:SAML:1.0:assertion";
    /// SAML 2.0 assertion namespace.
    pub const SAML_20_ASSERTION: &str = "urn:oasis:names:tc:SAML:2.0:assertion";
    /// WS-Addressing 1.0 namespace.
    pub const WS_ADDRESSING: &str = "http://www.w3.org/2005/08/addressing";
    /// WS-Federation 1.2 namespace.
    pub const WS_FED: &str = "http://docs.oasis-open.org/wsfed/federation/200706";
    /// WS-Trust 1.3 namespace.
    pub const WS_TRUST: &str = "http://docs.oasis-open.org/ws-sx/ws-trust/200512";
}
50
/// Configuration for a WS-Federation passive relying party.
///
/// Build one with struct-update syntax over [`Default`] and pass it to
/// `WsFederationClient::new`, which validates the required fields.
#[derive(Debug, Clone)]
pub struct WsFederationConfig {
    /// Passive STS endpoint that the sign-in and sign-out URLs are built from.
    /// `WsFederationClient::new` rejects it when empty or not `https://`.
    pub sts_url: String,

    /// Relying-party identifier, sent as `wtrealm`. Token audiences are
    /// compared against this value during validation.
    pub realm: String,

    /// Address sent as `wreply` in the sign-in URL.
    pub reply_url: String,

    /// URL of the federation metadata document read by `fetch_metadata`;
    /// that call fails with a configuration error when this is `None`.
    pub metadata_url: Option<String>,

    /// Accepted token issuers. When empty, the issuer check is skipped and
    /// tokens from any issuer pass it.
    pub trusted_issuers: Vec<String>,

    /// Certificate thumbprints, presumably meant to pin the STS signing key.
    /// NOTE(review): not read by any code in this module, and no token
    /// signature is verified here. Confirm verification happens elsewhere.
    pub trusted_cert_thumbprints: Vec<String>,

    /// Tolerance in seconds applied to the expiry and not-before checks.
    /// Defaults to 300.
    pub max_clock_skew_secs: u64,

    /// Whether sign-in tokens are required to arrive encrypted.
    /// NOTE(review): not read by any code in this module, so it is not
    /// enforced here. Confirm enforcement happens elsewhere.
    pub require_encrypted_tokens: bool,

    /// Timeout in seconds for the internal HTTP client (used for metadata
    /// fetches). Defaults to 10.
    pub timeout_secs: u64,

    /// Optional home-realm hint, appended as `whr` to the sign-in URL.
    pub home_realm: Option<String>,
}
86
87impl Default for WsFederationConfig {
88 fn default() -> Self {
89 Self {
90 sts_url: String::new(),
91 realm: String::new(),
92 reply_url: String::new(),
93 metadata_url: None,
94 trusted_issuers: Vec::new(),
95 trusted_cert_thumbprints: Vec::new(),
96 max_clock_skew_secs: 300,
97 require_encrypted_tokens: false,
98 timeout_secs: 10,
99 home_realm: None,
100 }
101 }
102}
103
/// A WS-Federation sign-in response as posted back by the STS.
///
/// NOTE(review): not used by the code in this module;
/// `WsFederationClient::process_sign_in_response` takes these values as
/// separate arguments instead.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WsFedSignInResponse {
    /// Presumably the `wa` action parameter (expected `wsignin1.0`).
    pub action: String,

    /// Presumably the `wresult` parameter: the RSTR XML carrying the token.
    pub result_xml: String,

    /// Presumably the `wctx` context value echoed back by the STS.
    pub context: Option<String>,
}
118
/// A security token extracted from a WS-Federation sign-in response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WsFedSecurityToken {
    /// Format the fields below were parsed from.
    pub token_type: WsFedTokenType,

    /// Subject identifier: SAML `NameID`/`NameIdentifier`, or the JWT
    /// `sub`/`upn`/`email` claim. Empty when none was found.
    pub subject: String,

    /// Token issuer; empty when none was found.
    pub issuer: String,

    /// Audience the token was issued for; empty when none was found.
    pub audience: String,

    /// Start of the validity window in Unix seconds (for SAML, the
    /// `Conditions/@NotBefore` value). Parsers use 0 when it is absent.
    pub issued_at: u64,

    /// End of the validity window in Unix seconds. Parsers use `u64::MAX`
    /// when the token carries no expiry.
    pub expires_at: u64,

    /// Claims keyed by attribute/claim name; each may have several values.
    pub claims: HashMap<String, Vec<String>>,

    /// Token text the fields were parsed from (for SAML, the `<Assertion>`
    /// element exactly as it appeared in the response).
    pub raw_assertion: String,
}
146
/// Token formats a WS-Federation response may carry.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum WsFedTokenType {
    /// SAML 1.1 assertion (namespace `urn:oasis:names:tc:SAML:1.0:assertion`).
    Saml11,
    /// SAML 2.0 assertion (namespace `urn:oasis:names:tc:SAML:2.0:assertion`).
    Saml20,
    /// JSON Web Token.
    Jwt,
}
157
/// The subset of a WS-Federation metadata document that this module extracts
/// (see `WsFederationClient::fetch_metadata`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederationMetadata {
    /// `entityID` attribute of the `EntityDescriptor`; empty when not found.
    pub entity_id: String,

    /// Passive requestor (sign-in) endpoint, if one was found.
    pub passive_endpoint: Option<String>,

    /// Contents of the `X509Certificate` elements in the document, with
    /// whitespace removed. The extractor does not filter by key use, so this
    /// may include non-signing (e.g. encryption) certificates.
    pub signing_certificates: Vec<String>,

    /// Token type identifiers advertised by the metadata.
    pub token_types_offered: Vec<String>,

    /// Claim types advertised by the metadata.
    pub claim_types_offered: Vec<ClaimTypeOffered>,
}
176
/// A claim type advertised in federation metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClaimTypeOffered {
    /// Claim type URI, taken from the element's `Uri` attribute.
    pub uri: String,
    /// Human-readable name from the `DisplayName` child element, if present.
    pub display_name: Option<String>,
    /// Description from the `Description` child element, if present.
    pub description: Option<String>,
}
184
/// WS-Federation passive relying-party client.
///
/// Builds sign-in and sign-out URLs, validates sign-in responses and fetches
/// federation metadata.
#[derive(Debug)]
pub struct WsFederationClient {
    /// Configuration checked by `WsFederationClient::new`.
    config: WsFederationConfig,
    /// HTTP client for metadata fetches, built with `config.timeout_secs`.
    http: reqwest::Client,
}
193
194impl WsFederationClient {
195 pub fn new(config: WsFederationConfig) -> Result<Self> {
197 if config.sts_url.is_empty() {
198 return Err(AuthError::config("WS-Federation STS URL must be set"));
199 }
200 if !config.sts_url.starts_with("https://") {
201 return Err(AuthError::config("WS-Federation STS URL must use HTTPS"));
202 }
203 if config.realm.is_empty() {
204 return Err(AuthError::config("WS-Federation realm must be set"));
205 }
206
207 let http = reqwest::Client::builder()
208 .timeout(std::time::Duration::from_secs(config.timeout_secs))
209 .build()
210 .map_err(|e| AuthError::internal(format!("Failed to build HTTP client: {e}")))?;
211
212 Ok(Self { config, http })
213 }
214
215 pub fn sign_in_url(&self) -> Result<(String, String)> {
220 let rng = ring::rand::SystemRandom::new();
221 let mut ctx_bytes = [0u8; 16];
222 rng.fill(&mut ctx_bytes)
223 .map_err(|_| AuthError::crypto("Failed to generate wctx nonce"))?;
224 let wctx = hex::encode(ctx_bytes);
225
226 let mut url = format!(
227 "{}?wa={}&wtrealm={}&wreply={}&wctx={}",
228 self.config.sts_url,
229 urlencoding::encode(action::SIGN_IN),
230 urlencoding::encode(&self.config.realm),
231 urlencoding::encode(&self.config.reply_url),
232 urlencoding::encode(&wctx),
233 );
234
235 if let Some(ref whr) = self.config.home_realm {
236 url.push_str(&format!("&whr={}", urlencoding::encode(whr)));
237 }
238
239 Ok((url, wctx))
240 }
241
242 pub fn sign_out_url(&self) -> String {
244 format!(
245 "{}?wa={}&wtrealm={}",
246 self.config.sts_url,
247 urlencoding::encode(action::SIGN_OUT),
248 urlencoding::encode(&self.config.realm),
249 )
250 }
251
252 pub fn process_sign_in_response(
259 &self,
260 wa: &str,
261 wresult: &str,
262 wctx: Option<&str>,
263 expected_wctx: &str,
264 ) -> Result<WsFedSecurityToken> {
265 if wa != action::SIGN_IN {
267 return Err(AuthError::validation(format!(
268 "Unexpected WS-Fed action: {wa}"
269 )));
270 }
271
272 if let Some(ctx) = wctx {
274 if !constant_time_eq(ctx.as_bytes(), expected_wctx.as_bytes()) {
275 return Err(AuthError::validation("WS-Federation wctx mismatch (CSRF)"));
276 }
277 } else {
278 return Err(AuthError::validation("Missing wctx parameter"));
279 }
280
281 let token = self.parse_rstr(wresult)?;
283
284 self.validate_token(&token)?;
286
287 Ok(token)
288 }
289
290 pub async fn fetch_metadata(&self) -> Result<FederationMetadata> {
292 let url = self
293 .config
294 .metadata_url
295 .as_deref()
296 .ok_or_else(|| AuthError::config("Federation metadata URL not configured"))?;
297
298 let resp =
299 self.http.get(url).send().await.map_err(|e| {
300 AuthError::internal(format!("Federation metadata fetch failed: {e}"))
301 })?;
302
303 if !resp.status().is_success() {
304 let status = resp.status();
305 return Err(AuthError::internal(format!(
306 "Federation metadata HTTP error: {status}"
307 )));
308 }
309
310 let body = resp
311 .text()
312 .await
313 .map_err(|e| AuthError::internal(format!("Federation metadata read failed: {e}")))?;
314
315 parse_federation_metadata(&body)
316 }
317
318 fn parse_rstr(&self, rstr_xml: &str) -> Result<WsFedSecurityToken> {
320 let token_type = if rstr_xml.contains(ns::SAML_20_ASSERTION) {
322 WsFedTokenType::Saml20
323 } else if rstr_xml.contains(ns::SAML_11_ASSERTION) {
324 WsFedTokenType::Saml11
325 } else if rstr_xml.contains("\"JWT\"") || rstr_xml.contains("jwt") {
326 WsFedTokenType::Jwt
327 } else {
328 WsFedTokenType::Saml20 };
330
331 let raw_assertion = extract_assertion(rstr_xml)?;
333
334 let (subject, issuer, audience, issued_at, expires_at, claims) = match token_type {
336 WsFedTokenType::Saml20 => parse_saml20_assertion(&raw_assertion)?,
337 WsFedTokenType::Saml11 => parse_saml11_assertion(&raw_assertion)?,
338 WsFedTokenType::Jwt => parse_jwt_token(&raw_assertion)?,
339 };
340
341 Ok(WsFedSecurityToken {
342 token_type,
343 subject,
344 issuer,
345 audience,
346 issued_at,
347 expires_at,
348 claims,
349 raw_assertion,
350 })
351 }
352
353 fn validate_token(&self, token: &WsFedSecurityToken) -> Result<()> {
355 if !self.config.trusted_issuers.is_empty()
357 && !self.config.trusted_issuers.contains(&token.issuer)
358 {
359 return Err(AuthError::validation(format!(
360 "Token issuer '{}' is not trusted",
361 token.issuer
362 )));
363 }
364
365 if !token.audience.is_empty() && token.audience != self.config.realm {
367 return Err(AuthError::validation(format!(
368 "Token audience '{}' does not match realm '{}'",
369 token.audience, self.config.realm
370 )));
371 }
372
373 let now = SystemTime::now()
375 .duration_since(UNIX_EPOCH)
376 .map_err(|e| AuthError::internal(format!("Clock error: {e}")))?
377 .as_secs();
378
379 let skew = self.config.max_clock_skew_secs;
380 if token.expires_at + skew < now {
381 return Err(AuthError::validation("Security token has expired"));
382 }
383
384 if token.issued_at > now + skew {
385 return Err(AuthError::validation("Security token issued in the future"));
386 }
387
388 Ok(())
389 }
390}
391
392fn extract_assertion(rstr: &str) -> Result<String> {
396 let assertion_tags = [
398 ("saml:Assertion", "</saml:Assertion>"),
399 ("saml2:Assertion", "</saml2:Assertion>"),
400 ("Assertion", "</Assertion>"),
401 ];
402
403 for (open_tag, close_tag) in &assertion_tags {
404 let open = format!("<{open_tag}");
405 if let Some(start) = rstr.find(&open) {
406 if let Some(end) = rstr[start..].find(close_tag) {
407 return Ok(rstr[start..start + end + close_tag.len()].to_string());
408 }
409 }
410 }
411
412 Err(AuthError::validation(
413 "No SAML assertion found in WS-Federation response",
414 ))
415}
416
417fn parse_saml20_assertion(
419 xml: &str,
420) -> Result<(
421 String,
422 String,
423 String,
424 u64,
425 u64,
426 HashMap<String, Vec<String>>,
427)> {
428 let subject = extract_xml_text(xml, "NameID")
429 .or_else(|| extract_xml_text(xml, "saml:NameID"))
430 .unwrap_or_default();
431
432 let issuer = extract_xml_text(xml, "Issuer")
433 .or_else(|| extract_xml_text(xml, "saml:Issuer"))
434 .unwrap_or_default();
435
436 let audience = extract_xml_text(xml, "Audience")
437 .or_else(|| extract_xml_text(xml, "saml:Audience"))
438 .unwrap_or_default();
439
440 let not_before = extract_xml_attr_val(xml, "Conditions", "NotBefore")
441 .or_else(|| extract_xml_attr_val(xml, "saml:Conditions", "NotBefore"))
442 .and_then(|s| parse_iso_timestamp(&s))
443 .unwrap_or(0);
444
445 let not_on_or_after = extract_xml_attr_val(xml, "Conditions", "NotOnOrAfter")
446 .or_else(|| extract_xml_attr_val(xml, "saml:Conditions", "NotOnOrAfter"))
447 .and_then(|s| parse_iso_timestamp(&s))
448 .unwrap_or(u64::MAX);
449
450 let claims = extract_saml_attributes(xml);
451
452 Ok((
453 subject,
454 issuer,
455 audience,
456 not_before,
457 not_on_or_after,
458 claims,
459 ))
460}
461
462fn parse_saml11_assertion(
464 xml: &str,
465) -> Result<(
466 String,
467 String,
468 String,
469 u64,
470 u64,
471 HashMap<String, Vec<String>>,
472)> {
473 let subject = extract_xml_text(xml, "NameIdentifier")
474 .or_else(|| extract_xml_text(xml, "saml:NameIdentifier"))
475 .unwrap_or_default();
476
477 let issuer = extract_xml_attr_val(xml, "Assertion", "Issuer")
478 .or_else(|| extract_xml_attr_val(xml, "saml:Assertion", "Issuer"))
479 .unwrap_or_default();
480
481 let audience = extract_xml_text(xml, "Audience")
482 .or_else(|| extract_xml_text(xml, "saml:Audience"))
483 .unwrap_or_default();
484
485 let not_before = extract_xml_attr_val(xml, "Conditions", "NotBefore")
486 .and_then(|s| parse_iso_timestamp(&s))
487 .unwrap_or(0);
488
489 let not_on_or_after = extract_xml_attr_val(xml, "Conditions", "NotOnOrAfter")
490 .and_then(|s| parse_iso_timestamp(&s))
491 .unwrap_or(u64::MAX);
492
493 let claims = extract_saml_attributes(xml);
494
495 Ok((
496 subject,
497 issuer,
498 audience,
499 not_before,
500 not_on_or_after,
501 claims,
502 ))
503}
504
505fn parse_jwt_token(
511 jwt_str: &str,
512) -> Result<(
513 String,
514 String,
515 String,
516 u64,
517 u64,
518 HashMap<String, Vec<String>>,
519)> {
520 use base64::Engine as _;
521 use base64::engine::general_purpose::URL_SAFE_NO_PAD;
522
523 let parts: Vec<&str> = jwt_str.trim().split('.').collect();
525 if parts.len() != 3 {
526 return Err(AuthError::validation(
527 "Invalid JWT format: expected 3 parts",
528 ));
529 }
530
531 let payload_bytes = URL_SAFE_NO_PAD
533 .decode(parts[1])
534 .map_err(|e| AuthError::validation(format!("Invalid JWT payload encoding: {e}")))?;
535
536 let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
537 .map_err(|e| AuthError::validation(format!("Invalid JWT payload JSON: {e}")))?;
538
539 let subject = payload["sub"]
540 .as_str()
541 .or_else(|| payload["upn"].as_str())
542 .or_else(|| payload["email"].as_str())
543 .unwrap_or_default()
544 .to_string();
545
546 let issuer = payload["iss"].as_str().unwrap_or_default().to_string();
547
548 let audience = payload["aud"].as_str().unwrap_or_default().to_string();
549
550 let issued_at = payload["iat"].as_u64().unwrap_or(0);
551 let expires_at = payload["exp"].as_u64().unwrap_or(u64::MAX);
552
553 let mut claims = HashMap::new();
555 if let Some(obj) = payload.as_object() {
556 for (key, value) in obj {
557 match value {
558 serde_json::Value::String(s) => {
559 claims
560 .entry(key.clone())
561 .or_insert_with(Vec::new)
562 .push(s.clone());
563 }
564 serde_json::Value::Array(arr) => {
565 for item in arr {
566 if let Some(s) = item.as_str() {
567 claims
568 .entry(key.clone())
569 .or_insert_with(Vec::new)
570 .push(s.to_string());
571 }
572 }
573 }
574 _ => {}
575 }
576 }
577 }
578
579 Ok((subject, issuer, audience, issued_at, expires_at, claims))
580}
581
/// Returns the trimmed text between the first `<tag ...>` start tag in `xml`
/// and the next `</tag>`.
///
/// The element name must end right after `tag`: it must be followed by any
/// whitespace (including newline or tab), `>` or `/`. This stops `<Audience`
/// from matching `<AudienceRestriction>`. A self-closing `<tag/>` yields an
/// empty string instead of running on to an unrelated later `</tag>`.
///
/// Returns `None` when:
/// - no such element exists;
/// - the start tag is unterminated;
/// - the closing tag is missing.
///
/// Entities are not decoded.
fn extract_xml_text(xml: &str, tag: &str) -> Option<String> {
    let open = format!("<{tag}");
    let close = format!("</{tag}>");
    let mut search_from = 0;

    loop {
        let start = search_from + xml[search_from..].find(&open)?;
        let after_name = start + open.len();
        match xml.as_bytes().get(after_name) {
            Some(b) if b.is_ascii_whitespace() || *b == b'>' || *b == b'/' => {}
            Some(_) => {
                search_from = after_name;
                continue;
            }
            None => return None,
        }

        let tag_close = after_name + xml[after_name..].find('>')?;
        if xml[..tag_close].ends_with('/') {
            return Some(String::new());
        }

        let content_start = tag_close + 1;
        let content_len = xml[content_start..].find(&close)?;
        return Some(
            xml[content_start..content_start + content_len]
                .trim()
                .to_string(),
        );
    }
}
608
/// Returns the value of attribute `attr` on the first `<tag ...>` start tag
/// in `xml`.
///
/// Matching rules:
/// - The tag name must match exactly: it must be followed by whitespace, `>`
///   or `/`, so `Assertion` does not match `<AssertionIDRef>`.
/// - The attribute name must also match exactly, so `NotBefore` does not
///   match `OldNotBefore`.
/// - Values may be double- or single-quoted, with optional whitespace around
///   `=`.
///
/// Only the first matching start tag is inspected; the result is `None` if
/// that tag lacks the attribute. Entities in the value are not decoded.
fn extract_xml_attr_val(xml: &str, tag: &str, attr: &str) -> Option<String> {
    /// Finds `attr="value"` or `attr='value'` as a whole attribute name
    /// within the text of one start tag.
    fn attribute_value(tag_text: &str, attr: &str) -> Option<String> {
        if attr.is_empty() {
            return None;
        }
        let mut from = 0;
        while let Some(rel) = tag_text[from..].find(attr) {
            let start = from + rel;
            let end = start + attr.len();
            from = end;
            if start > 0 && !tag_text.as_bytes()[start - 1].is_ascii_whitespace() {
                continue;
            }
            let rest = match tag_text[end..].trim_start().strip_prefix('=') {
                Some(rest) => rest.trim_start(),
                None => continue,
            };
            let quote = match rest.chars().next() {
                Some(q @ ('"' | '\'')) => q,
                _ => continue,
            };
            let value = &rest[1..];
            return value.find(quote).map(|len| value[..len].to_string());
        }
        None
    }

    let open = format!("<{tag}");
    let mut from = 0;
    loop {
        let start = from + xml[from..].find(&open)?;
        let after_name = start + open.len();
        match xml.as_bytes().get(after_name) {
            Some(b) if b.is_ascii_whitespace() || *b == b'>' || *b == b'/' => {
                let tag_end = after_name + xml[after_name..].find('>')?;
                return attribute_value(&xml[start..tag_end], attr);
            }
            Some(_) => from = after_name,
            None => return None,
        }
    }
}
621
/// Namespace prefixes under which SAML `Attribute` / `AttributeValue`
/// elements are recognised. `""` is the unprefixed (default-namespace) form.
const SAML_ATTRIBUTE_PREFIXES: [&str; 3] = ["", "saml:", "saml2:"];

/// Finds the earliest start tag `<{prefix}{local}` at or after byte `from`,
/// trying every prefix in `prefixes`.
///
/// The name must be followed by whitespace, `>` or `/`, so `<Attribute`
/// does not match `<AttributeValue>` or `<AttributeStatement>`.
/// Returns the tag's byte offset and the prefix that matched.
fn find_prefixed_start_tag<'p>(
    xml: &str,
    from: usize,
    prefixes: &[&'p str],
    local: &str,
) -> Option<(usize, &'p str)> {
    prefixes
        .iter()
        .filter_map(|&prefix| {
            let open = format!("<{prefix}{local}");
            let mut cursor = from;
            while let Some(rel) = xml[cursor..].find(&open) {
                let start = cursor + rel;
                let after_name = start + open.len();
                match xml.as_bytes().get(after_name) {
                    Some(b) if b.is_ascii_whitespace() || *b == b'>' || *b == b'/' => {
                        return Some((start, prefix));
                    }
                    Some(_) => cursor = after_name,
                    None => return None,
                }
            }
            None
        })
        .min_by_key(|&(start, _)| start)
}

/// Collects SAML attribute values into a map from attribute name to values.
///
/// Handles `Attribute` elements that are unprefixed, `saml:`- or
/// `saml2:`-prefixed. Each element is closed by the end tag with the same
/// prefix.
///
/// The name comes from `Name` (SAML 2.0), falling back to `AttributeName`
/// (SAML 1.1). Elements without either, and self-closing elements, are
/// skipped. Repeated names accumulate their values, and attributes with no
/// non-empty values are omitted.
fn extract_saml_attributes(xml: &str) -> HashMap<String, Vec<String>> {
    let mut attrs: HashMap<String, Vec<String>> = HashMap::new();
    let mut search_pos = 0;

    while let Some((abs_pos, prefix)) =
        find_prefixed_start_tag(xml, search_pos, &SAML_ATTRIBUTE_PREFIXES, "Attribute")
    {
        let tag_end = match xml[abs_pos..].find('>') {
            Some(p) => abs_pos + p,
            None => break,
        };
        let tag_content = &xml[abs_pos..tag_end];

        let name = extract_inline_attr(tag_content, "Name")
            .or_else(|| extract_inline_attr(tag_content, "AttributeName"));

        // A self-closing `<Attribute .../>` has no values. Reading on to the
        // next closing tag would steal the following attribute's values.
        let name = match (tag_content.ends_with('/'), name) {
            (false, Some(name)) => name,
            _ => {
                search_pos = tag_end + 1;
                continue;
            }
        };

        let close_tag = format!("</{prefix}Attribute>");
        let end_pos = xml[tag_end..]
            .find(&close_tag)
            .map_or(xml.len(), |p| tag_end + p);

        let values = extract_attribute_values(&xml[tag_end + 1..end_pos]);
        if !values.is_empty() {
            attrs.entry(name).or_default().extend(values);
        }

        search_pos = end_pos;
    }

    attrs
}

/// Returns the value of attribute `attr` inside a start tag's text, for
/// example `extract_inline_attr("<Attribute Name=\"x\"", "Name")`.
///
/// The attribute name must be preceded by whitespace (or start the text), so
/// `Name` does not match inside `FriendlyName` or `AttributeName`. Values may
/// be double- or single-quoted, with optional whitespace around `=`. Entities
/// are not decoded.
fn extract_inline_attr(tag: &str, attr: &str) -> Option<String> {
    if attr.is_empty() {
        return None;
    }
    let mut from = 0;
    while let Some(rel) = tag[from..].find(attr) {
        let start = from + rel;
        let end = start + attr.len();
        from = end;
        if start > 0 && !tag.as_bytes()[start - 1].is_ascii_whitespace() {
            continue;
        }
        let rest = match tag[end..].trim_start().strip_prefix('=') {
            Some(rest) => rest.trim_start(),
            None => continue,
        };
        let quote = match rest.chars().next() {
            Some(q @ ('"' | '\'')) => q,
            _ => continue,
        };
        let value = &rest[1..];
        return value.find(quote).map(|len| value[..len].to_string());
    }
    None
}

/// Collects the trimmed, non-empty text of every `AttributeValue` element in
/// `block`, in document order.
///
/// Accepts unprefixed, `saml:` and `saml2:` spellings, each closed by its
/// own end tag. Self-closing values are skipped. Scanning stops at the first
/// value without an end tag.
fn extract_attribute_values(block: &str) -> Vec<String> {
    let mut values = Vec::new();
    let mut pos = 0;

    while let Some((start, prefix)) =
        find_prefixed_start_tag(block, pos, &SAML_ATTRIBUTE_PREFIXES, "AttributeValue")
    {
        let tag_end = match block[start..].find('>') {
            Some(p) => start + p,
            None => break,
        };
        if block[..tag_end].ends_with('/') {
            pos = tag_end + 1;
            continue;
        }

        let content_start = tag_end + 1;
        let close = format!("</{prefix}AttributeValue>");
        let content_len = match block[content_start..].find(&close) {
            Some(len) => len,
            None => break,
        };

        let value = block[content_start..content_start + content_len].trim();
        if !value.is_empty() {
            values.push(value.to_string());
        }
        pos = content_start + content_len + close.len();
    }

    values
}
710
711fn parse_iso_timestamp(s: &str) -> Option<u64> {
713 chrono::DateTime::parse_from_rfc3339(s)
714 .ok()
715 .map(|dt| dt.timestamp() as u64)
716}
717
718fn parse_federation_metadata(xml: &str) -> Result<FederationMetadata> {
720 let entity_id = extract_xml_attr_val(xml, "EntityDescriptor", "entityID").unwrap_or_default();
721
722 let passive_endpoint = extract_xml_attr_val(xml, "PassiveRequestorEndpoint", "Location")
724 .or_else(|| {
725 extract_xml_text(xml, "Address").or_else(|| extract_xml_text(xml, "wsa:Address"))
727 });
728
729 let signing_certificates = extract_all_x509_certs(xml);
731
732 let token_types = extract_xml_list_by_tag(xml, "TokenType");
734
735 let claim_types = extract_claim_types(xml);
737
738 Ok(FederationMetadata {
739 entity_id,
740 passive_endpoint,
741 signing_certificates,
742 token_types_offered: token_types,
743 claim_types_offered: claim_types,
744 })
745}
746
/// Returns the contents of every `X509Certificate` element in `xml`, in
/// document order, with all whitespace removed. Base64 bodies are often
/// wrapped with newlines, spaces or tabs.
///
/// Both the unprefixed and the `ds:`-prefixed spelling are recognised.
/// Empty elements are skipped, and scanning stops at the first element
/// without an end tag.
fn extract_all_x509_certs(xml: &str) -> Vec<String> {
    const ELEMENTS: [(&str, &str); 2] = [
        ("<X509Certificate>", "</X509Certificate>"),
        ("<ds:X509Certificate>", "</ds:X509Certificate>"),
    ];

    let mut certs = Vec::new();
    let mut pos = 0;

    loop {
        let next = ELEMENTS
            .iter()
            .filter_map(|&(open, close)| {
                xml[pos..].find(open).map(|rel| (pos + rel, open, close))
            })
            .min_by_key(|&(start, _, _)| start);
        let (start, open, close) = match next {
            Some(found) => found,
            None => break,
        };

        let content_start = start + open.len();
        let content_len = match xml[content_start..].find(close) {
            Some(len) => len,
            None => break,
        };

        let cert: String = xml[content_start..content_start + content_len]
            .chars()
            .filter(|c| !c.is_whitespace())
            .collect();
        if !cert.is_empty() {
            certs.push(cert);
        }

        pos = content_start + content_len + close.len();
    }

    certs
}
772
/// Returns the trimmed, non-empty text content of every `<tag>` element in
/// `xml`, in document order.
///
/// The element name must end right after `tag`: it must be followed by
/// whitespace, `>` or `/`, so `TokenType` does not match
/// `<TokenTypesOffered>`. Self-closing elements contribute nothing. Scanning
/// stops at the first element without an end tag.
fn extract_xml_list_by_tag(xml: &str, tag: &str) -> Vec<String> {
    let open = format!("<{tag}");
    let close = format!("</{tag}>");
    let mut values = Vec::new();
    let mut pos = 0;

    while let Some(rel) = xml[pos..].find(&open) {
        let start = pos + rel;
        let after_name = start + open.len();
        match xml.as_bytes().get(after_name) {
            Some(b) if b.is_ascii_whitespace() || *b == b'>' || *b == b'/' => {}
            Some(_) => {
                pos = after_name;
                continue;
            }
            None => break,
        }

        let tag_close = match xml[after_name..].find('>') {
            Some(p) => after_name + p,
            None => break,
        };
        if xml[..tag_close].ends_with('/') {
            pos = tag_close + 1;
            continue;
        }

        let content_start = tag_close + 1;
        let content_len = match xml[content_start..].find(&close) {
            Some(len) => len,
            None => break,
        };

        let val = xml[content_start..content_start + content_len].trim();
        if !val.is_empty() {
            values.push(val.to_string());
        }
        pos = content_start + content_len + close.len();
    }

    values
}
800
801fn extract_claim_types(xml: &str) -> Vec<ClaimTypeOffered> {
803 let mut claims = Vec::new();
804 let mut pos = 0;
805
806 while let Some(start) = xml[pos..].find("<ClaimType ") {
807 let abs = pos + start;
808 let tag_end = match xml[abs..].find('>') {
809 Some(p) => abs + p,
810 None => break,
811 };
812 let tag = &xml[abs..tag_end];
813
814 let uri = extract_inline_attr(tag, "Uri").unwrap_or_default();
815
816 let close = "</ClaimType>";
817 let block_end = xml[tag_end..]
818 .find(close)
819 .map(|p| tag_end + p + close.len())
820 .unwrap_or(xml.len());
821
822 let block = &xml[tag_end + 1..block_end.saturating_sub(close.len())];
823
824 let display_name = extract_xml_text(block, "DisplayName");
825 let description = extract_xml_text(block, "Description");
826
827 claims.push(ClaimTypeOffered {
828 uri,
829 display_name,
830 description,
831 });
832
833 pos = block_end;
834 }
835
836 claims
837}
838
839fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
841 use subtle::ConstantTimeEq;
842 if a.len() != b.len() {
843 return false;
844 }
845 a.ct_eq(b).into()
846}
847
#[cfg(test)]
mod tests {
    use super::*;

    const STS: &str = "https://adfs.example.com/adfs/ls";
    const REALM: &str = "https://app.example.com";
    const REPLY: &str = "https://app.example.com/auth/wsfed";

    /// A client for the example ADFS realm with a reply URL configured.
    fn example_client() -> WsFederationClient {
        let config = WsFederationConfig {
            sts_url: STS.into(),
            realm: REALM.into(),
            reply_url: REPLY.into(),
            ..Default::default()
        };
        WsFederationClient::new(config).expect("example configuration is valid")
    }

    #[test]
    fn test_config_defaults() {
        let defaults = WsFederationConfig::default();
        assert_eq!(defaults.max_clock_skew_secs, 300);
        assert!(!defaults.require_encrypted_tokens);
    }

    #[test]
    fn test_client_requires_https() {
        let insecure = WsFederationConfig {
            sts_url: "http://adfs.example.com/adfs/ls".into(),
            realm: REALM.into(),
            ..Default::default()
        };
        let message = WsFederationClient::new(insecure).unwrap_err().to_string();
        assert!(message.contains("HTTPS"));
    }

    #[test]
    fn test_client_requires_realm() {
        let realmless = WsFederationConfig {
            sts_url: STS.into(),
            ..Default::default()
        };
        let message = WsFederationClient::new(realmless).unwrap_err().to_string();
        assert!(message.contains("realm"));
    }

    #[test]
    fn test_sign_in_url() {
        let (url, wctx) = example_client().sign_in_url().unwrap();
        for expected in ["wa=wsignin1.0", "wtrealm=", "wreply="] {
            assert!(url.contains(expected), "missing {expected} in {url}");
        }
        assert!(url.contains(&wctx));
    }

    #[test]
    fn test_sign_out_url() {
        let url = example_client().sign_out_url();
        assert!(url.contains("wa=wsignout1.0"));
    }

    #[test]
    fn test_parse_saml20_assertion() {
        let xml = r#"
            <saml:Assertion xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion" Version="2.0">
                <saml:Issuer>https://idp.example.com</saml:Issuer>
                <saml:Subject>
                    <saml:NameID>jdoe@example.com</saml:NameID>
                </saml:Subject>
                <saml:Conditions NotBefore="2024-01-01T00:00:00Z" NotOnOrAfter="2034-01-01T00:00:00Z">
                    <saml:AudienceRestriction>
                        <saml:Audience>https://app.example.com</saml:Audience>
                    </saml:AudienceRestriction>
                </saml:Conditions>
                <saml:AttributeStatement>
                    <Attribute Name="email">
                        <AttributeValue>jdoe@example.com</AttributeValue>
                    </Attribute>
                </saml:AttributeStatement>
            </saml:Assertion>
        "#;

        let (subject, issuer, audience, _not_before, _not_after, claims) =
            parse_saml20_assertion(xml).unwrap();

        assert_eq!(subject, "jdoe@example.com");
        assert_eq!(issuer, "https://idp.example.com");
        assert_eq!(audience, "https://app.example.com");
        assert!(claims.contains_key("email"));
    }

    #[test]
    fn test_extract_assertion() {
        let rstr = r#"
            <RequestSecurityTokenResponse>
                <saml:Assertion ID="_abc123">
                    <saml:Issuer>test</saml:Issuer>
                </saml:Assertion>
            </RequestSecurityTokenResponse>
        "#;
        let assertion = extract_assertion(rstr).unwrap();
        assert!(assertion.contains("saml:Assertion"));
        assert!(assertion.contains("saml:Issuer"));
    }

    #[test]
    fn test_action_constants() {
        assert_eq!(action::SIGN_IN, "wsignin1.0");
        assert_eq!(action::SIGN_OUT, "wsignout1.0");
    }
}