1use crate::errors::{AuthError, Result};
7use crate::protocols::saml_assertions::SamlAssertion;
8use base64::{Engine as _, engine::general_purpose::STANDARD};
9use chrono::{DateTime, Duration, Utc};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// Escapes the five XML special characters (`&`, `<`, `>`, `"`, `'`) so that `s` can be
/// embedded safely in element text or in a single- or double-quoted attribute value.
///
/// `&` is replaced like every other special character, so already-escaped input is escaped
/// again (`"&amp;"` becomes `"&amp;amp;"`); callers must pass raw, unescaped text.
fn xml_escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&apos;"),
            _ => out.push(c),
        }
    }
    out
}
28
29#[derive(Debug, Clone, Default)]
31pub struct WsSecurityHeader {
32 pub username_token: Option<UsernameToken>,
34
35 pub timestamp: Option<Timestamp>,
37
38 pub binary_security_token: Option<BinarySecurityToken>,
40
41 pub saml_assertions: Vec<SamlAssertionRef>,
43
44 pub signature: Option<WsSecuritySignature>,
46
47 pub custom_elements: Vec<String>,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct UsernameToken {
54 pub username: String,
56
57 pub password: Option<UsernamePassword>,
59
60 pub nonce: Option<String>,
62
63 pub created: Option<DateTime<Utc>>,
65
66 pub wsu_id: Option<String>,
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct UsernamePassword {
73 pub value: String,
75
76 pub password_type: PasswordType,
78}
79
80#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
82pub enum PasswordType {
83 PasswordText,
85
86 PasswordDigest,
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct Timestamp {
93 pub created: DateTime<Utc>,
95
96 pub expires: DateTime<Utc>,
98
99 pub wsu_id: Option<String>,
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct BinarySecurityToken {
106 pub value: String,
108
109 pub value_type: String,
111
112 pub encoding_type: String,
114
115 pub wsu_id: Option<String>,
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct SamlAssertionRef {
122 pub assertion: SamlAssertion,
124
125 pub wsu_id: Option<String>,
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct WsSecuritySignature {
132 pub signature_method: String,
134
135 pub canonicalization_method: String,
137
138 pub digest_method: String,
140
141 pub references: Vec<SignatureReference>,
143
144 pub key_info: Option<KeyInfo>,
146
147 pub signature_value: Option<String>,
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize)]
153pub struct SignatureReference {
154 pub uri: String,
156
157 pub digest_value: String,
159
160 pub transforms: Vec<String>,
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct KeyInfo {
167 pub security_token_reference: Option<String>,
169
170 pub key_value: Option<String>,
172
173 pub x509_data: Option<String>,
175}
176
177#[derive(Debug, Clone)]
179pub struct WsSecurityConfig {
180 pub include_timestamp: bool,
182
183 pub timestamp_ttl: Duration,
185
186 pub sign_message: bool,
188
189 pub elements_to_sign: Vec<String>,
191
192 pub signing_certificate: Option<Vec<u8>>,
194
195 pub signing_private_key: Option<Vec<u8>>,
197
198 pub include_certificate: bool,
200
201 pub saml_token_endpoint: Option<String>,
203
204 pub actor: Option<String>,
206}
207
208pub struct WsSecurityClient {
210 config: WsSecurityConfig,
212
213 namespaces: HashMap<String, String>,
215}
216
217impl WsSecurityClient {
218 pub fn new(config: WsSecurityConfig) -> Self {
220 let mut namespaces = HashMap::new();
221 namespaces.insert(
222 "wsse".to_string(),
223 "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-secext-1.0.xsd"
224 .to_string(),
225 );
226 namespaces.insert(
227 "wsu".to_string(),
228 "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-utility-1.0.xsd"
229 .to_string(),
230 );
231 namespaces.insert(
232 "ds".to_string(),
233 "http://www.w3.org/2000/09/xmldsig#".to_string(),
234 );
235 namespaces.insert(
236 "saml".to_string(),
237 "urn:oasis:names:tc:SAML:2.0:assertion".to_string(),
238 );
239
240 Self { config, namespaces }
241 }
242
243 pub fn create_username_token_header(
245 &self,
246 username: &str,
247 password: Option<&str>,
248 password_type: PasswordType,
249 ) -> Result<WsSecurityHeader> {
250 let mut header = WsSecurityHeader::default();
251
252 let (nonce, created) = if password_type == PasswordType::PasswordDigest {
253 (Some(self.generate_nonce()), Some(Utc::now()))
254 } else {
255 (None, None)
256 };
257
258 let password_element = if let Some(pwd) = password {
259 let pwd_value = match password_type {
260 PasswordType::PasswordText => pwd.to_string(),
261 PasswordType::PasswordDigest => self.compute_password_digest(
262 pwd,
263 nonce
264 .as_ref()
265 .expect("nonce is Some for PasswordDigest variant"),
266 &created.expect("created is Some for PasswordDigest variant"),
267 )?,
268 };
269
270 Some(UsernamePassword {
271 value: pwd_value,
272 password_type,
273 })
274 } else {
275 None
276 };
277
278 header.username_token = Some(UsernameToken {
279 username: username.to_string(),
280 password: password_element,
281 nonce,
282 created,
283 wsu_id: Some(format!("UsernameToken-{}", uuid::Uuid::new_v4())),
284 });
285
286 if self.config.include_timestamp {
287 header.timestamp = Some(self.create_timestamp());
288 }
289
290 Ok(header)
291 }
292
293 pub fn create_certificate_header(&self, certificate: &[u8]) -> Result<WsSecurityHeader> {
295 let mut header = WsSecurityHeader::default();
296
297 let cert_b64 = STANDARD.encode(certificate);
299
300 header.binary_security_token = Some(BinarySecurityToken {
301 value: cert_b64,
302 value_type: "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-x509-token-profile-1.0#X509v3".to_string(),
303 encoding_type: "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-soap-message-security-1.0#Base64Binary".to_string(),
304 wsu_id: Some(format!("X509Token-{}", uuid::Uuid::new_v4())),
305 });
306
307 if self.config.include_timestamp {
308 header.timestamp = Some(self.create_timestamp());
309 }
310
311 if self.config.sign_message {
312 header.signature = Some(self.create_signature_template()?);
313 }
314
315 Ok(header)
316 }
317
318 pub fn create_saml_header(&self, assertion: SamlAssertion) -> Result<WsSecurityHeader> {
320 let mut header = WsSecurityHeader::default();
321
322 let assertion_ref = SamlAssertionRef {
323 assertion,
324 wsu_id: Some(format!("SamlAssertion-{}", uuid::Uuid::new_v4())),
325 };
326
327 header.saml_assertions.push(assertion_ref);
328
329 if self.config.include_timestamp {
330 header.timestamp = Some(self.create_timestamp());
331 }
332
333 Ok(header)
334 }
335 pub fn header_to_xml(&self, header: &WsSecurityHeader) -> Result<String> {
337 let mut xml = String::new();
338
339 xml.push_str(&format!(
341 r#"<wsse:Security xmlns:wsse="{}" xmlns:wsu="{}">"#,
342 self.namespaces["wsse"], self.namespaces["wsu"]
343 ));
344
345 if let Some(ref timestamp) = header.timestamp {
347 xml.push_str(&self.timestamp_to_xml(timestamp));
348 }
349
350 if let Some(ref username_token) = header.username_token {
352 xml.push_str(&self.username_token_to_xml(username_token));
353 }
354
355 if let Some(ref bst) = header.binary_security_token {
357 xml.push_str(&self.binary_security_token_to_xml(bst));
358 }
359
360 for assertion_ref in &header.saml_assertions {
362 let assertion_xml = assertion_ref.assertion.to_xml()?;
363 xml.push_str(&assertion_xml);
364 }
365
366 if let Some(ref signature) = header.signature {
368 xml.push_str(&self.signature_to_xml(signature));
369 }
370
371 xml.push_str("</wsse:Security>");
373
374 Ok(xml)
375 }
376
377 fn generate_nonce(&self) -> String {
379 use rand::Rng;
380 let mut rng = rand::rng();
381 let mut nonce = [0u8; 16];
382 rng.fill_bytes(&mut nonce);
383 STANDARD.encode(nonce)
384 }
385
386 fn compute_password_digest(
388 &self,
389 password: &str,
390 nonce: &str,
391 created: &DateTime<Utc>,
392 ) -> Result<String> {
393 use sha1::{Digest, Sha1};
394
395 let nonce_bytes = STANDARD
396 .decode(nonce)
397 .map_err(|_| AuthError::auth_method("ws_security", "Invalid nonce encoding"))?;
398 let created_str = created.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
399
400 let mut hasher = Sha1::new();
401 hasher.update(&nonce_bytes);
402 hasher.update(created_str.as_bytes());
403 hasher.update(password.as_bytes());
404
405 let digest = hasher.finalize();
406 Ok(STANDARD.encode(digest))
407 }
408
409 fn create_timestamp(&self) -> Timestamp {
411 let now = Utc::now();
412 let expires = now + self.config.timestamp_ttl;
413
414 Timestamp {
415 created: now,
416 expires,
417 wsu_id: Some(format!("Timestamp-{}", uuid::Uuid::new_v4())),
418 }
419 }
420
421 fn create_signature_template(&self) -> Result<WsSecuritySignature> {
423 Ok(WsSecuritySignature {
424 signature_method: "http://www.w3.org/2001/04/xmldsig-more#hmac-sha256".to_string(),
429 canonicalization_method: "http://www.w3.org/2001/10/xml-exc-c14n#".to_string(),
430 digest_method: "http://www.w3.org/2001/04/xmlenc#sha256".to_string(),
431 references: self
432 .config
433 .elements_to_sign
434 .iter()
435 .map(|element| {
436 SignatureReference {
437 uri: format!("#{}", element),
438 digest_value: String::new(), transforms: vec!["http://www.w3.org/2001/10/xml-exc-c14n#".to_string()],
440 }
441 })
442 .collect(),
443 key_info: None, signature_value: None, })
446 }
447
448 fn timestamp_to_xml(&self, timestamp: &Timestamp) -> String {
450 let mut xml = String::new();
451
452 if let Some(ref id) = timestamp.wsu_id {
453 xml.push_str(&format!(r#"<wsu:Timestamp wsu:Id="{}">"#, id));
454 } else {
455 xml.push_str("<wsu:Timestamp>");
456 }
457
458 xml.push_str(&format!(
459 "<wsu:Created>{}</wsu:Created>",
460 timestamp.created.format("%Y-%m-%dT%H:%M:%S%.3fZ")
461 ));
462
463 xml.push_str(&format!(
464 "<wsu:Expires>{}</wsu:Expires>",
465 timestamp.expires.format("%Y-%m-%dT%H:%M:%S%.3fZ")
466 ));
467
468 xml.push_str("</wsu:Timestamp>");
469 xml
470 }
471
472 fn username_token_to_xml(&self, token: &UsernameToken) -> String {
474 let mut xml = String::new();
475
476 if let Some(ref id) = token.wsu_id {
477 xml.push_str(&format!(
478 r#"<wsse:UsernameToken wsu:Id="{}">"#,
479 xml_escape(id)
480 ));
481 } else {
482 xml.push_str("<wsse:UsernameToken>");
483 }
484
485 xml.push_str(&format!(
486 "<wsse:Username>{}</wsse:Username>",
487 xml_escape(&token.username)
488 ));
489
490 if let Some(ref password) = token.password {
491 let type_attr = match password.password_type {
492 PasswordType::PasswordText => {
493 "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-username-token-profile-1.0#PasswordText"
494 }
495 PasswordType::PasswordDigest => {
496 "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-username-token-profile-1.0#PasswordDigest"
497 }
498 };
499
500 xml.push_str(&format!(
501 r#"<wsse:Password Type="{}">{}</wsse:Password>"#,
502 type_attr,
503 xml_escape(&password.value)
504 ));
505 }
506
507 if let Some(ref nonce) = token.nonce {
508 xml.push_str(&format!(
509 r#"<wsse:Nonce EncodingType="http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-soap-message-security-1.0#Base64Binary">{}</wsse:Nonce>"#,
510 nonce
511 ));
512 }
513
514 if let Some(ref created) = token.created {
515 xml.push_str(&format!(
516 "<wsu:Created>{}</wsu:Created>",
517 created.format("%Y-%m-%dT%H:%M:%S%.3fZ")
518 ));
519 }
520
521 xml.push_str("</wsse:UsernameToken>");
522 xml
523 }
524
525 fn binary_security_token_to_xml(&self, token: &BinarySecurityToken) -> String {
527 let mut xml = String::new();
528
529 xml.push_str(&format!(
530 r#"<wsse:BinarySecurityToken ValueType="{}" EncodingType="{}""#,
531 token.value_type, token.encoding_type
532 ));
533
534 if let Some(ref id) = token.wsu_id {
535 xml.push_str(&format!(r#" wsu:Id="{}""#, id));
536 }
537
538 xml.push('>');
539 xml.push_str(&token.value);
540 xml.push_str("</wsse:BinarySecurityToken>");
541
542 xml
543 }
544
545 fn signature_to_xml(&self, signature: &WsSecuritySignature) -> String {
551 let references_xml: String = signature
553 .references
554 .iter()
555 .map(|r| {
556 let digest_value = if !r.digest_value.is_empty() {
557 r.digest_value.clone()
558 } else if let Some(ref key_bytes) = self.config.signing_private_key {
559 use ring::hmac;
562 let key = hmac::Key::new(hmac::HMAC_SHA256, key_bytes);
563 let tag = hmac::sign(&key, r.uri.as_bytes());
564 STANDARD.encode(tag.as_ref())
565 } else {
566 String::new()
567 };
568
569 format!(
570 r#"<ds:Reference URI="{}">
571 <ds:Transforms>
572 {}
573 </ds:Transforms>
574 <ds:DigestMethod Algorithm="{}"/>
575 <ds:DigestValue>{}</ds:DigestValue>
576 </ds:Reference>"#,
577 r.uri,
578 r.transforms
579 .iter()
580 .map(|t| format!(r#"<ds:Transform Algorithm="{}"/>"#, t))
581 .collect::<Vec<_>>()
582 .join(""),
583 signature.digest_method,
584 digest_value,
585 )
586 })
587 .collect::<Vec<_>>()
588 .join("");
589
590 let signed_info_xml = format!(
592 r#"<ds:SignedInfo>
593 <ds:CanonicalizationMethod Algorithm="{}"/>
594 <ds:SignatureMethod Algorithm="{}"/>
595 {}
596 </ds:SignedInfo>"#,
597 signature.canonicalization_method, signature.signature_method, references_xml,
598 );
599
600 let signature_value = if let Some(ref sv) = signature.signature_value {
602 sv.clone()
603 } else if let Some(ref key_bytes) = self.config.signing_private_key {
604 use ring::hmac;
605 let key = hmac::Key::new(hmac::HMAC_SHA256, key_bytes);
606 let tag = hmac::sign(&key, signed_info_xml.as_bytes());
607 STANDARD.encode(tag.as_ref())
608 } else {
609 String::new()
610 };
611
612 let key_info_xml = if let Some(ref ki) = signature.key_info {
614 let mut ki_xml = String::new();
615 if let Some(ref x509) = ki.x509_data {
616 ki_xml.push_str(&format!(
617 "<ds:X509Data><ds:X509Certificate>{}</ds:X509Certificate></ds:X509Data>",
618 x509
619 ));
620 }
621 if let Some(ref str_ref) = ki.security_token_reference {
622 ki_xml.push_str(&format!(
623 "<wsse:SecurityTokenReference>{}</wsse:SecurityTokenReference>",
624 str_ref
625 ));
626 }
627 if let Some(ref kv) = ki.key_value {
628 ki_xml.push_str(&format!("<ds:KeyValue>{}</ds:KeyValue>", kv));
629 }
630 ki_xml
631 } else if let Some(ref cert) = self.config.signing_certificate {
632 let cert_b64 = STANDARD.encode(cert);
633 format!(
634 "<ds:X509Data><ds:X509Certificate>{}</ds:X509Certificate></ds:X509Data>",
635 cert_b64
636 )
637 } else {
638 "<ds:KeyName>WS-Security-Signing-Key</ds:KeyName>".to_string()
639 };
640
641 format!(
642 r#"<ds:Signature xmlns:ds="{}">
643 {}
644 <ds:SignatureValue>{}</ds:SignatureValue>
645 <ds:KeyInfo>{}</ds:KeyInfo>
646 </ds:Signature>"#,
647 self.namespaces["ds"], signed_info_xml, signature_value, key_info_xml,
648 )
649 }
650}
651
652impl Default for WsSecurityConfig {
653 fn default() -> Self {
654 Self {
655 include_timestamp: true,
656 timestamp_ttl: Duration::minutes(5),
657 sign_message: false,
658 elements_to_sign: vec!["Body".to_string(), "Timestamp".to_string()],
659 signing_certificate: None,
660 signing_private_key: None,
661 include_certificate: true,
662 saml_token_endpoint: None,
663 actor: None,
664 }
665 }
666}
667
668impl WsSecurityConfig {
669 pub fn builder() -> WsSecurityConfigBuilder {
685 WsSecurityConfigBuilder::default()
686 }
687}
688
689#[derive(Debug, Clone)]
693pub struct WsSecurityConfigBuilder {
694 config: WsSecurityConfig,
695}
696
697impl Default for WsSecurityConfigBuilder {
698 fn default() -> Self {
699 Self {
700 config: WsSecurityConfig::default(),
701 }
702 }
703}
704
705impl WsSecurityConfigBuilder {
706 pub fn include_timestamp(mut self, include: bool) -> Self {
708 self.config.include_timestamp = include;
709 self
710 }
711
712 pub fn timestamp_ttl(mut self, ttl: Duration) -> Self {
714 self.config.timestamp_ttl = ttl;
715 self
716 }
717
718 pub fn sign_message(mut self, sign: bool) -> Self {
720 self.config.sign_message = sign;
721 self
722 }
723
724 pub fn elements_to_sign(mut self, elements: Vec<String>) -> Self {
726 self.config.elements_to_sign = elements;
727 self
728 }
729
730 pub fn signing_certificate(mut self, cert: Vec<u8>) -> Self {
732 self.config.signing_certificate = Some(cert);
733 self
734 }
735
736 pub fn signing_private_key(mut self, key: Vec<u8>) -> Self {
738 self.config.signing_private_key = Some(key);
739 self
740 }
741
742 pub fn include_certificate(mut self, include: bool) -> Self {
744 self.config.include_certificate = include;
745 self
746 }
747
748 pub fn saml_token_endpoint(mut self, endpoint: impl Into<String>) -> Self {
750 self.config.saml_token_endpoint = Some(endpoint.into());
751 self
752 }
753
754 pub fn actor(mut self, actor: impl Into<String>) -> Self {
756 self.config.actor = Some(actor.into());
757 self
758 }
759
760 pub fn build(self) -> WsSecurityConfig {
762 self.config
763 }
764}
765
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_username_token_creation() {
        let config = WsSecurityConfig::default();
        let client = WsSecurityClient::new(config);

        let header = client
            .create_username_token_header("testuser", Some("testpass"), PasswordType::PasswordText)
            .unwrap();

        assert!(header.username_token.is_some());
        let token = header.username_token.unwrap();
        assert_eq!(token.username, "testuser");
        assert!(token.password.is_some());
        // Plain-text tokens carry no freshness inputs.
        assert!(token.nonce.is_none());
        assert!(token.created.is_none());
    }

    #[test]
    fn test_username_token_digest_fields() {
        let client = WsSecurityClient::new(WsSecurityConfig::default());

        let header = client
            .create_username_token_header(
                "testuser",
                Some("testpass"),
                PasswordType::PasswordDigest,
            )
            .unwrap();

        let token = header.username_token.unwrap();
        let nonce = token.nonce.expect("digest tokens carry a nonce");
        assert_eq!(STANDARD.decode(&nonce).unwrap().len(), 16);
        let created = token.created.expect("digest tokens carry a created time");

        let password = token.password.unwrap();
        assert_eq!(password.password_type, PasswordType::PasswordDigest);
        // The transmitted digest must be reproducible from the transmitted nonce/created.
        let expected = client
            .compute_password_digest("testpass", &nonce, &created)
            .unwrap();
        assert_eq!(password.value, expected);
    }

    #[test]
    fn test_password_digest() {
        let config = WsSecurityConfig::default();
        let client = WsSecurityClient::new(config);

        let nonce = "MTIzNDU2Nzg5MDEyMzQ1Ng==";
        let created = DateTime::parse_from_rfc3339("2023-01-01T12:00:00Z")
            .unwrap()
            .with_timezone(&Utc);
        let password = "secret";

        let digest = client
            .compute_password_digest(password, nonce, &created)
            .unwrap();
        assert!(!digest.is_empty());
        // SHA-1 produces 20 bytes.
        assert_eq!(STANDARD.decode(&digest).unwrap().len(), 20);
        // Deterministic for identical inputs.
        let again = client
            .compute_password_digest(password, nonce, &created)
            .unwrap();
        assert_eq!(digest, again);
        // A malformed nonce is rejected rather than hashed.
        assert!(client
            .compute_password_digest(password, "not base64!", &created)
            .is_err());
    }

    #[test]
    fn test_timestamp_creation() {
        let config = WsSecurityConfig::default();
        let client = WsSecurityClient::new(config);

        let timestamp = client.create_timestamp();
        assert!(timestamp.expires > timestamp.created);
        assert_eq!(timestamp.expires - timestamp.created, Duration::minutes(5));
        assert!(timestamp.wsu_id.is_some());
    }

    #[test]
    fn test_xml_generation() {
        let config = WsSecurityConfig::default();
        let client = WsSecurityClient::new(config);

        let header = client
            .create_username_token_header("testuser", Some("testpass"), PasswordType::PasswordText)
            .unwrap();

        let xml = client.header_to_xml(&header).unwrap();
        assert!(xml.contains("<wsse:Security"));
        assert!(xml.contains("<wsse:UsernameToken"));
        assert!(xml.contains("testuser"));
        assert!(xml.contains("</wsse:Security>"));
    }

    #[test]
    fn test_certificate_header() {
        let config = WsSecurityConfig::default();
        let client = WsSecurityClient::new(config);

        let dummy_cert = b"dummy certificate data";
        let header = client.create_certificate_header(dummy_cert).unwrap();

        assert!(header.binary_security_token.is_some());
        let bst = header.binary_security_token.unwrap();
        assert_eq!(bst.value, STANDARD.encode(dummy_cert));
    }

    #[test]
    fn test_xml_escape_special_characters() {
        assert_eq!(xml_escape("hello"), "hello");
        assert_eq!(xml_escape("<script>"), "&lt;script&gt;");
        assert_eq!(xml_escape("a&b"), "a&amp;b");
        assert_eq!(xml_escape("\"quoted\""), "&quot;quoted&quot;");
        assert_eq!(xml_escape("it's"), "it&apos;s");
        assert_eq!(
            xml_escape("<user>&\"name'"),
            "&lt;user&gt;&amp;&quot;name&apos;"
        );
    }

    #[test]
    fn test_xml_escape_empty_and_normal() {
        assert_eq!(xml_escape(""), "");
        assert_eq!(xml_escape("normal_user123"), "normal_user123");
    }

    #[test]
    fn test_username_token_xml_escapes_injection() {
        let config = WsSecurityConfig::default();
        let client = WsSecurityClient::new(config);

        let header = client
            .create_username_token_header(
                "<script>alert(1)</script>",
                Some("pass&word\""),
                PasswordType::PasswordText,
            )
            .unwrap();

        let xml = client.header_to_xml(&header).unwrap();
        // The raw markup must not survive; only its escaped form may appear.
        assert!(!xml.contains("<script>"));
        assert!(xml.contains("&lt;script&gt;"));
        assert!(xml.contains("pass&amp;word&quot;"));
    }
}