Skip to main content

wae_authentication/saml/
service.rs

1//! SAML 服务实现
2
3use crate::saml::{
4    AuthnContextComparison, SamlAuthnRequest, SamlConfig, SamlLogoutRequest, SamlLogoutResponse, SamlNameIdPolicy,
5    SamlRequestedAuthnContext, SamlResponse, SpMetadataBuilder,
6};
7use base64::{Engine, engine::general_purpose::STANDARD as BASE64_STANDARD};
8use chrono::Utc;
9use flate2::{Compression, read::DeflateDecoder, write::DeflateEncoder};
10use std::io::{Read, Write};
11use uuid::Uuid;
12use wae_types::{WaeError, WaeErrorKind};
13
14/// SAML 结果类型
15pub type SamlResult<T> = Result<T, WaeError>;
16
17/// SAML 服务
18#[derive(Debug, Clone)]
19pub struct SamlService {
20    config: SamlConfig,
21}
22
23impl SamlService {
24    /// 创建新的 SAML 服务
25    pub fn new(config: SamlConfig) -> Self {
26        Self { config }
27    }
28
29    /// 创建认证请求 URL (HTTP Redirect 绑定)
30    pub fn create_authn_request_url(&self) -> SamlResult<String> {
31        let request = self.create_authn_request()?;
32        let xml = self.serialize_authn_request(&request)?;
33        let encoded = self.encode_redirect_request(&xml)?;
34
35        let mut url = self.config.idp.sso_url.clone();
36        url.push_str("?SAMLRequest=");
37        url.push_str(&encoded);
38
39        Ok(url)
40    }
41
42    /// 创建认证请求
43    pub fn create_authn_request(&self) -> SamlResult<SamlAuthnRequest> {
44        let id = format!("id{}", Uuid::new_v4().simple());
45        let name_id_policy = SamlNameIdPolicy::new().with_format(self.config.sp.name_id_format.clone()).with_allow_create(true);
46
47        let authn_context = SamlRequestedAuthnContext::new(vec![
48            "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport".to_string(),
49        ])
50        .with_comparison(AuthnContextComparison::Minimum);
51
52        Ok(SamlAuthnRequest::new(id, &self.config.sp.entity_id)
53            .with_destination(&self.config.idp.sso_url)
54            .with_protocol_binding(self.config.idp.sso_binding)
55            .with_acs_url(&self.config.sp.acs_url)
56            .with_name_id_policy(name_id_policy)
57            .with_authn_context(authn_context))
58    }
59
60    /// 处理认证响应
61    pub fn process_authn_response(&self, saml_response: &str) -> SamlResult<SamlResponse> {
62        let decoded = BASE64_STANDARD
63            .decode(saml_response)
64            .map_err(|e| WaeError::new(WaeErrorKind::Base64DecodeError { reason: e.to_string() }))?;
65
66        let xml =
67            String::from_utf8(decoded).map_err(|e| WaeError::new(WaeErrorKind::XmlParsingError { reason: e.to_string() }))?;
68
69        let response: SamlResponse = quick_xml::de::from_str(&xml)
70            .map_err(|e| WaeError::new(WaeErrorKind::XmlParsingError { reason: e.to_string() }))?;
71
72        self.validate_response(&response)?;
73
74        Ok(response)
75    }
76
77    /// 验证响应
78    fn validate_response(&self, response: &SamlResponse) -> SamlResult<()> {
79        if !response.is_success() {
80            return Err(WaeError::new(WaeErrorKind::InvalidSamlResponse {
81                reason: format!("SAML response status: {:?}", response.status.code()),
82            }));
83        }
84
85        if self.config.validate_issuer {
86            if response.issuer != self.config.idp.entity_id {
87                return Err(WaeError::new(WaeErrorKind::SamlIssuerValidationFailed {
88                    expected: self.config.idp.entity_id.clone(),
89                    actual: response.issuer.clone(),
90                }));
91            }
92        }
93
94        if let Some(ref assertion) = response.assertion {
95            self.validate_assertion(assertion)?;
96        }
97
98        Ok(())
99    }
100
101    /// 验证断言
102    fn validate_assertion(&self, assertion: &crate::saml::SamlAssertion) -> SamlResult<()> {
103        let now = Utc::now().timestamp();
104
105        if let Some(ref conditions) = assertion.conditions {
106            if let Some(not_before) = conditions.not_before {
107                if now < not_before.timestamp() - self.config.clock_skew_seconds {
108                    return Err(WaeError::new(WaeErrorKind::AssertionNotYetValid));
109                }
110            }
111
112            if let Some(not_on_or_after) = conditions.not_on_or_after {
113                if now >= not_on_or_after.timestamp() + self.config.clock_skew_seconds {
114                    return Err(WaeError::new(WaeErrorKind::AssertionExpired));
115                }
116            }
117
118            if self.config.validate_audience {
119                if let Some(ref restriction) = conditions.audience_restriction {
120                    if !restriction.audience.contains(&self.config.sp.entity_id) {
121                        return Err(WaeError::new(WaeErrorKind::SamlAudienceValidationFailed {
122                            expected: self.config.sp.entity_id.clone(),
123                            actual: restriction.audience.first().cloned().unwrap_or_default(),
124                        }));
125                    }
126                }
127            }
128        }
129
130        Ok(())
131    }
132
133    /// 创建登出请求 URL
134    pub fn create_logout_request_url(&self, name_id: &str, session_index: Option<&str>) -> SamlResult<String> {
135        let slo_url =
136            self.config.idp.slo_url.as_ref().ok_or_else(|| WaeError::config_invalid("slo_url", "SLO URL not configured"))?;
137
138        let request = self.create_logout_request(name_id, session_index)?;
139        let xml = self.serialize_logout_request(&request)?;
140        let encoded = self.encode_redirect_request(&xml)?;
141
142        let mut url = slo_url.clone();
143        url.push_str("?SAMLRequest=");
144        url.push_str(&encoded);
145
146        Ok(url)
147    }
148
149    /// 创建登出请求
150    pub fn create_logout_request(&self, name_id: &str, session_index: Option<&str>) -> SamlResult<SamlLogoutRequest> {
151        let id = format!("id{}", Uuid::new_v4().simple());
152        let slo_url =
153            self.config.idp.slo_url.as_ref().ok_or_else(|| WaeError::config_invalid("slo_url", "SLO URL not configured"))?;
154
155        let mut request = SamlLogoutRequest::new(id, &self.config.sp.entity_id)
156            .with_destination(slo_url)
157            .with_name_id(crate::saml::SamlNameId::new(name_id));
158
159        if let Some(idx) = session_index {
160            request = request.with_session_index(idx);
161        }
162
163        Ok(request)
164    }
165
166    /// 处理登出响应
167    pub fn process_logout_response(&self, saml_response: &str) -> SamlResult<SamlLogoutResponse> {
168        let decoded = BASE64_STANDARD
169            .decode(saml_response)
170            .map_err(|e| WaeError::new(WaeErrorKind::Base64DecodeError { reason: e.to_string() }))?;
171
172        let xml =
173            String::from_utf8(decoded).map_err(|e| WaeError::new(WaeErrorKind::XmlParsingError { reason: e.to_string() }))?;
174
175        let response: SamlLogoutResponse = quick_xml::de::from_str(&xml)
176            .map_err(|e| WaeError::new(WaeErrorKind::XmlParsingError { reason: e.to_string() }))?;
177
178        Ok(response)
179    }
180
181    /// 生成 SP 元数据
182    pub fn generate_sp_metadata(&self) -> String {
183        let builder = SpMetadataBuilder::new(&self.config.sp.entity_id, &self.config.sp.acs_url)
184            .with_want_assertions_signed(self.config.sp.want_assertions_signed)
185            .with_authn_requests_signed(self.config.sp.want_response_signed);
186
187        let metadata = if let Some(ref slo_url) = self.config.sp.slo_url { builder.with_slo_url(slo_url) } else { builder };
188
189        let entity = metadata.build();
190
191        quick_xml::se::to_string(&entity).unwrap_or_default()
192    }
193
194    /// 序列化认证请求
195    fn serialize_authn_request(&self, request: &SamlAuthnRequest) -> SamlResult<String> {
196        quick_xml::se::to_string(request).map_err(|e| WaeError::new(WaeErrorKind::XmlParsingError { reason: e.to_string() }))
197    }
198
199    /// 序列化登出请求
200    fn serialize_logout_request(&self, request: &SamlLogoutRequest) -> SamlResult<String> {
201        quick_xml::se::to_string(request).map_err(|e| WaeError::new(WaeErrorKind::XmlParsingError { reason: e.to_string() }))
202    }
203
204    /// 编码重定向请求 (Deflate + Base64)
205    fn encode_redirect_request(&self, xml: &str) -> SamlResult<String> {
206        let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
207        encoder
208            .write_all(xml.as_bytes())
209            .map_err(|e| WaeError::new(WaeErrorKind::CompressionError { reason: e.to_string() }))?;
210        let compressed =
211            encoder.finish().map_err(|e| WaeError::new(WaeErrorKind::CompressionError { reason: e.to_string() }))?;
212
213        Ok(BASE64_STANDARD.encode(&compressed))
214    }
215
216    /// 解码重定向请求
217    pub fn decode_redirect_request(&self, encoded: &str) -> SamlResult<String> {
218        let decoded = BASE64_STANDARD
219            .decode(encoded)
220            .map_err(|e| WaeError::new(WaeErrorKind::Base64DecodeError { reason: e.to_string() }))?;
221
222        let mut decoder = DeflateDecoder::new(&decoded[..]);
223        let mut decompressed = String::new();
224        decoder
225            .read_to_string(&mut decompressed)
226            .map_err(|e| WaeError::new(WaeErrorKind::CompressionError { reason: e.to_string() }))?;
227
228        Ok(decompressed)
229    }
230
231    /// 获取配置
232    pub fn config(&self) -> &SamlConfig {
233        &self.config
234    }
235}
236
237/// 便捷函数:创建 SAML 服务
238pub fn create_saml_service(
239    sp_entity_id: impl Into<String>,
240    sp_acs_url: impl Into<String>,
241    idp_entity_id: impl Into<String>,
242    idp_sso_url: impl Into<String>,
243    idp_certificate: impl Into<String>,
244) -> SamlService {
245    use crate::saml::{IdentityProviderConfig, ServiceProviderConfig};
246
247    let sp = ServiceProviderConfig::new(sp_entity_id, sp_acs_url);
248    let idp = IdentityProviderConfig::new(idp_entity_id, idp_sso_url, idp_certificate);
249    let config = SamlConfig::new(sp, idp);
250
251    SamlService::new(config)
252}