wae_authentication/saml/
service.rs1use crate::saml::{
4 AuthnContextComparison, SamlAuthnRequest, SamlConfig, SamlLogoutRequest, SamlLogoutResponse, SamlNameIdPolicy,
5 SamlRequestedAuthnContext, SamlResponse, SpMetadataBuilder,
6};
7use base64::{Engine, engine::general_purpose::STANDARD as BASE64_STANDARD};
8use chrono::Utc;
9use flate2::{Compression, read::DeflateDecoder, write::DeflateEncoder};
10use std::io::{Read, Write};
11use uuid::Uuid;
12use wae_types::{WaeError, WaeErrorKind};
13
14pub type SamlResult<T> = Result<T, WaeError>;
16
17#[derive(Debug, Clone)]
19pub struct SamlService {
20 config: SamlConfig,
21}
22
23impl SamlService {
24 pub fn new(config: SamlConfig) -> Self {
26 Self { config }
27 }
28
29 pub fn create_authn_request_url(&self) -> SamlResult<String> {
31 let request = self.create_authn_request()?;
32 let xml = self.serialize_authn_request(&request)?;
33 let encoded = self.encode_redirect_request(&xml)?;
34
35 let mut url = self.config.idp.sso_url.clone();
36 url.push_str("?SAMLRequest=");
37 url.push_str(&encoded);
38
39 Ok(url)
40 }
41
42 pub fn create_authn_request(&self) -> SamlResult<SamlAuthnRequest> {
44 let id = format!("id{}", Uuid::new_v4().simple());
45 let name_id_policy = SamlNameIdPolicy::new().with_format(self.config.sp.name_id_format.clone()).with_allow_create(true);
46
47 let authn_context = SamlRequestedAuthnContext::new(vec![
48 "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport".to_string(),
49 ])
50 .with_comparison(AuthnContextComparison::Minimum);
51
52 Ok(SamlAuthnRequest::new(id, &self.config.sp.entity_id)
53 .with_destination(&self.config.idp.sso_url)
54 .with_protocol_binding(self.config.idp.sso_binding)
55 .with_acs_url(&self.config.sp.acs_url)
56 .with_name_id_policy(name_id_policy)
57 .with_authn_context(authn_context))
58 }
59
60 pub fn process_authn_response(&self, saml_response: &str) -> SamlResult<SamlResponse> {
62 let decoded = BASE64_STANDARD
63 .decode(saml_response)
64 .map_err(|e| WaeError::new(WaeErrorKind::Base64DecodeError { reason: e.to_string() }))?;
65
66 let xml =
67 String::from_utf8(decoded).map_err(|e| WaeError::new(WaeErrorKind::XmlParsingError { reason: e.to_string() }))?;
68
69 let response: SamlResponse = quick_xml::de::from_str(&xml)
70 .map_err(|e| WaeError::new(WaeErrorKind::XmlParsingError { reason: e.to_string() }))?;
71
72 self.validate_response(&response)?;
73
74 Ok(response)
75 }
76
77 fn validate_response(&self, response: &SamlResponse) -> SamlResult<()> {
79 if !response.is_success() {
80 return Err(WaeError::new(WaeErrorKind::InvalidSamlResponse {
81 reason: format!("SAML response status: {:?}", response.status.code()),
82 }));
83 }
84
85 if self.config.validate_issuer {
86 if response.issuer != self.config.idp.entity_id {
87 return Err(WaeError::new(WaeErrorKind::SamlIssuerValidationFailed {
88 expected: self.config.idp.entity_id.clone(),
89 actual: response.issuer.clone(),
90 }));
91 }
92 }
93
94 if let Some(ref assertion) = response.assertion {
95 self.validate_assertion(assertion)?;
96 }
97
98 Ok(())
99 }
100
101 fn validate_assertion(&self, assertion: &crate::saml::SamlAssertion) -> SamlResult<()> {
103 let now = Utc::now().timestamp();
104
105 if let Some(ref conditions) = assertion.conditions {
106 if let Some(not_before) = conditions.not_before {
107 if now < not_before.timestamp() - self.config.clock_skew_seconds {
108 return Err(WaeError::new(WaeErrorKind::AssertionNotYetValid));
109 }
110 }
111
112 if let Some(not_on_or_after) = conditions.not_on_or_after {
113 if now >= not_on_or_after.timestamp() + self.config.clock_skew_seconds {
114 return Err(WaeError::new(WaeErrorKind::AssertionExpired));
115 }
116 }
117
118 if self.config.validate_audience {
119 if let Some(ref restriction) = conditions.audience_restriction {
120 if !restriction.audience.contains(&self.config.sp.entity_id) {
121 return Err(WaeError::new(WaeErrorKind::SamlAudienceValidationFailed {
122 expected: self.config.sp.entity_id.clone(),
123 actual: restriction.audience.first().cloned().unwrap_or_default(),
124 }));
125 }
126 }
127 }
128 }
129
130 Ok(())
131 }
132
133 pub fn create_logout_request_url(&self, name_id: &str, session_index: Option<&str>) -> SamlResult<String> {
135 let slo_url =
136 self.config.idp.slo_url.as_ref().ok_or_else(|| WaeError::config_invalid("slo_url", "SLO URL not configured"))?;
137
138 let request = self.create_logout_request(name_id, session_index)?;
139 let xml = self.serialize_logout_request(&request)?;
140 let encoded = self.encode_redirect_request(&xml)?;
141
142 let mut url = slo_url.clone();
143 url.push_str("?SAMLRequest=");
144 url.push_str(&encoded);
145
146 Ok(url)
147 }
148
149 pub fn create_logout_request(&self, name_id: &str, session_index: Option<&str>) -> SamlResult<SamlLogoutRequest> {
151 let id = format!("id{}", Uuid::new_v4().simple());
152 let slo_url =
153 self.config.idp.slo_url.as_ref().ok_or_else(|| WaeError::config_invalid("slo_url", "SLO URL not configured"))?;
154
155 let mut request = SamlLogoutRequest::new(id, &self.config.sp.entity_id)
156 .with_destination(slo_url)
157 .with_name_id(crate::saml::SamlNameId::new(name_id));
158
159 if let Some(idx) = session_index {
160 request = request.with_session_index(idx);
161 }
162
163 Ok(request)
164 }
165
166 pub fn process_logout_response(&self, saml_response: &str) -> SamlResult<SamlLogoutResponse> {
168 let decoded = BASE64_STANDARD
169 .decode(saml_response)
170 .map_err(|e| WaeError::new(WaeErrorKind::Base64DecodeError { reason: e.to_string() }))?;
171
172 let xml =
173 String::from_utf8(decoded).map_err(|e| WaeError::new(WaeErrorKind::XmlParsingError { reason: e.to_string() }))?;
174
175 let response: SamlLogoutResponse = quick_xml::de::from_str(&xml)
176 .map_err(|e| WaeError::new(WaeErrorKind::XmlParsingError { reason: e.to_string() }))?;
177
178 Ok(response)
179 }
180
181 pub fn generate_sp_metadata(&self) -> String {
183 let builder = SpMetadataBuilder::new(&self.config.sp.entity_id, &self.config.sp.acs_url)
184 .with_want_assertions_signed(self.config.sp.want_assertions_signed)
185 .with_authn_requests_signed(self.config.sp.want_response_signed);
186
187 let metadata = if let Some(ref slo_url) = self.config.sp.slo_url { builder.with_slo_url(slo_url) } else { builder };
188
189 let entity = metadata.build();
190
191 quick_xml::se::to_string(&entity).unwrap_or_default()
192 }
193
194 fn serialize_authn_request(&self, request: &SamlAuthnRequest) -> SamlResult<String> {
196 quick_xml::se::to_string(request).map_err(|e| WaeError::new(WaeErrorKind::XmlParsingError { reason: e.to_string() }))
197 }
198
199 fn serialize_logout_request(&self, request: &SamlLogoutRequest) -> SamlResult<String> {
201 quick_xml::se::to_string(request).map_err(|e| WaeError::new(WaeErrorKind::XmlParsingError { reason: e.to_string() }))
202 }
203
204 fn encode_redirect_request(&self, xml: &str) -> SamlResult<String> {
206 let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
207 encoder
208 .write_all(xml.as_bytes())
209 .map_err(|e| WaeError::new(WaeErrorKind::CompressionError { reason: e.to_string() }))?;
210 let compressed =
211 encoder.finish().map_err(|e| WaeError::new(WaeErrorKind::CompressionError { reason: e.to_string() }))?;
212
213 Ok(BASE64_STANDARD.encode(&compressed))
214 }
215
216 pub fn decode_redirect_request(&self, encoded: &str) -> SamlResult<String> {
218 let decoded = BASE64_STANDARD
219 .decode(encoded)
220 .map_err(|e| WaeError::new(WaeErrorKind::Base64DecodeError { reason: e.to_string() }))?;
221
222 let mut decoder = DeflateDecoder::new(&decoded[..]);
223 let mut decompressed = String::new();
224 decoder
225 .read_to_string(&mut decompressed)
226 .map_err(|e| WaeError::new(WaeErrorKind::CompressionError { reason: e.to_string() }))?;
227
228 Ok(decompressed)
229 }
230
231 pub fn config(&self) -> &SamlConfig {
233 &self.config
234 }
235}
236
237pub fn create_saml_service(
239 sp_entity_id: impl Into<String>,
240 sp_acs_url: impl Into<String>,
241 idp_entity_id: impl Into<String>,
242 idp_sso_url: impl Into<String>,
243 idp_certificate: impl Into<String>,
244) -> SamlService {
245 use crate::saml::{IdentityProviderConfig, ServiceProviderConfig};
246
247 let sp = ServiceProviderConfig::new(sp_entity_id, sp_acs_url);
248 let idp = IdentityProviderConfig::new(idp_entity_id, idp_sso_url, idp_certificate);
249 let config = SamlConfig::new(sp, idp);
250
251 SamlService::new(config)
252}