1use crate::api::{ApiResponse, ApiState};
2use axum::{
3 extract::{Query, State},
4 http::StatusCode,
5 response::{Html, IntoResponse, Json},
6};
7use base64::Engine;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// Escapes the five XML special characters so `input` can be embedded safely
/// in element text or in a single- or double-quoted attribute value.
///
/// `&`, `<`, `>`, `"` and `'` become `&amp;`, `&lt;`, `&gt;`, `&quot;` and
/// `&apos;`; every other character (including non-ASCII) is copied unchanged.
fn xml_escape(input: &str) -> String {
    let mut output = String::with_capacity(input.len());
    for ch in input.chars() {
        match ch {
            '&' => output.push_str("&amp;"),
            '<' => output.push_str("&lt;"),
            '>' => output.push_str("&gt;"),
            '"' => output.push_str("&quot;"),
            '\'' => output.push_str("&apos;"),
            _ => output.push(ch),
        }
    }
    output
}
27
28#[cfg(feature = "saml")]
29use bergshamra::{DsigContext, Key, KeyData, KeysManager, VerifyResult, verify};
30#[cfg(feature = "saml")]
31use quick_xml::Reader;
32#[cfg(feature = "saml")]
33use quick_xml::events::Event;
34#[cfg(feature = "saml")]
35use quick_xml::name::QName;
36
#[cfg(feature = "saml")]
/// Returns the local part of a possibly prefixed XML name
/// (`samlp:Response` -> `Response`).
///
/// Everything after the first `:` is returned; a name without `:` is
/// returned unchanged.
fn xml_local<'a>(name: QName<'a>) -> &'a [u8] {
    let bytes: &'a [u8] = name.0;
    bytes
        .iter()
        .position(|&b| b == b':')
        .map_or(bytes, |colon| &bytes[colon + 1..])
}
47
/// JSON body accepted by [`initiate_saml_sso`] to start SP-initiated SSO.
#[derive(Debug, Serialize, Deserialize)]
pub struct SamlSsoRequest {
    /// Entity ID of the IdP; used to look up `saml_idp:{idp_entity_id}` in storage.
    pub idp_entity_id: String,
    /// Opaque value appended as `RelayState` to the redirect URL and echoed back.
    pub relay_state: Option<String>,
    /// When set, emitted as the `ForceAuthn` attribute of the AuthnRequest.
    pub force_authn: Option<bool>,
    /// When set, emitted as the `IsPassive` attribute of the AuthnRequest.
    pub is_passive: Option<bool>,
}
56
/// Successful payload returned by [`initiate_saml_sso`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SamlSsoResponse {
    /// IdP SSO URL with the `SAMLRequest` (and optional `RelayState`) query
    /// parameters appended.
    pub redirect_url: String,
    /// Base64-encoded AuthnRequest XML (the value carried in `redirect_url`
    /// before URL-encoding).
    pub saml_request: String,
    /// The caller-supplied relay state, echoed back unchanged.
    pub relay_state: Option<String>,
}
64
/// Form fields posted by the IdP to the assertion consumer service
/// ([`handle_saml_acs`]).
///
/// Field names are mapped to the SAML binding parameter names via `serde(rename)`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SamlAcsRequest {
    /// Base64-encoded SAML `Response` XML.
    #[serde(rename = "SAMLResponse")]
    pub saml_response: String,
    /// Relay state returned by the IdP; echoed into the token payload.
    #[serde(rename = "RelayState")]
    pub relay_state: Option<String>,
    /// `SigAlg` binding parameter; not read by `handle_saml_acs`.
    #[serde(rename = "SigAlg")]
    pub sig_alg: Option<String>,
    /// `Signature` binding parameter; not read by `handle_saml_acs`.
    #[serde(rename = "Signature")]
    pub signature: Option<String>,
}
77
/// Structured description of the service-provider metadata values.
///
/// NOTE(review): `get_saml_metadata` renders the metadata XML directly rather
/// than using this type — confirm where (or whether) it is still used.
#[derive(Debug, Serialize, Deserialize)]
pub struct SamlMetadataResponse {
    /// SP entity ID.
    pub entity_id: String,
    /// Assertion consumer service URL.
    pub acs_url: String,
    /// Optional single-logout service URL.
    pub sls_url: Option<String>,
    /// Optional SP certificate; its encoding (PEM vs base64 DER) is not established here.
    pub certificate: Option<String>,
    /// NameID format URI advertised to IdPs.
    pub name_id_format: String,
}
87
/// Service-provider settings deserialized from the `saml_sp:config` storage key.
#[derive(Debug, Deserialize)]
struct SamlSpConfig {
    /// SP entity ID: the `Issuer` of outgoing messages and the audience
    /// checked in incoming assertions.
    entity_id: String,
    /// Assertion consumer service URL, advertised in metadata and AuthnRequests.
    acs_url: String,
    /// Optional single-logout endpoint; `None` when absent from the JSON.
    #[serde(default)]
    slo_url: Option<String>,
}
95
96impl SamlSpConfig {
97 fn validate(self) -> Result<Self, String> {
98 if self.entity_id.trim().is_empty() {
99 return Err("missing entity_id".to_string());
100 }
101 if self.acs_url.trim().is_empty() {
102 return Err("missing acs_url".to_string());
103 }
104 Ok(self)
105 }
106
107 fn slo_url(&self) -> Result<&str, String> {
108 self.slo_url
109 .as_deref()
110 .filter(|value| !value.trim().is_empty())
111 .ok_or_else(|| "missing slo_url".to_string())
112 }
113}
114
115async fn load_saml_sp_config(state: &ApiState) -> Result<SamlSpConfig, String> {
116 let data = state
117 .auth_framework
118 .storage()
119 .get_kv("saml_sp:config")
120 .await
121 .map_err(|_| "failed to load saml_sp:config".to_string())?
122 .ok_or_else(|| "missing saml_sp:config".to_string())?;
123
124 serde_json::from_slice::<SamlSpConfig>(&data)
125 .map_err(|_| "invalid saml_sp:config JSON".to_string())?
126 .validate()
127}
128
/// JSON body accepted by [`initiate_saml_slo`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SamlLogoutRequest {
    /// Subject NameID placed (XML-escaped) in the LogoutRequest.
    pub name_id: String,
    /// Optional session index, emitted as `<samlp:SessionIndex>` when present.
    pub session_index: Option<String>,
    /// IdP entity ID used to look up `saml_idp:{idp_entity_id}` and its `slo_url`.
    pub idp_entity_id: String,
}
136
/// Successful payload returned by [`initiate_saml_slo`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SamlLogoutResponse {
    /// IdP SLO URL with the base64 `SAMLRequest` query parameter appended.
    pub redirect_url: String,
    /// Set to `"logout_initiated"` by `initiate_saml_slo`.
    pub status: String,
}
143
144pub async fn get_saml_metadata(State(state): State<ApiState>) -> impl IntoResponse {
148 let sp_config = match load_saml_sp_config(&state).await {
149 Ok(config) => config,
150 Err(error) => {
151 tracing::error!(error = %error, "SAML metadata requested without valid SP configuration");
152 return (
153 StatusCode::INTERNAL_SERVER_ERROR,
154 Html("SAML service provider configuration is missing or incomplete".to_string()),
155 )
156 .into_response();
157 }
158 };
159 let slo_url = match sp_config.slo_url() {
160 Ok(url) => url,
161 Err(error) => {
162 tracing::error!(error = %error, "SAML metadata requested without SLO URL configured");
163 return (
164 StatusCode::INTERNAL_SERVER_ERROR,
165 Html("SAML service provider configuration is missing or incomplete".to_string()),
166 )
167 .into_response();
168 }
169 };
170
171 let metadata_xml = format!(
172 r#"<?xml version="1.0" encoding="UTF-8"?>
173<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata"
174 entityID="{entity_id}">
175 <md:SPSSODescriptor AuthnRequestsSigned="true" WantAssertionsSigned="true"
176 protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
177 <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress</md:NameIDFormat>
178 <md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
179 Location="{acs_url}"
180 index="0" />
181 <md:SingleLogoutService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
182 Location="{slo_url}" />
183 </md:SPSSODescriptor>
184</md:EntityDescriptor>"#,
185 entity_id = xml_escape(&sp_config.entity_id),
186 acs_url = xml_escape(&sp_config.acs_url),
187 slo_url = xml_escape(slo_url),
188 );
189
190 (
192 StatusCode::OK,
193 [(
194 axum::http::header::CONTENT_TYPE,
195 "application/samlmetadata+xml",
196 )],
197 metadata_xml,
198 )
199 .into_response()
200}
201
202pub async fn initiate_saml_sso(
207 State(state): State<ApiState>,
208 Json(request): Json<SamlSsoRequest>,
209) -> Json<ApiResponse<SamlSsoResponse>> {
210 let idp_key = format!("saml_idp:{}", request.idp_entity_id);
213 let idp_sso_url = match state.auth_framework.storage().get_kv(&idp_key).await {
214 Ok(Some(data)) => {
215 let cfg: serde_json::Value = serde_json::from_slice(&data).unwrap_or_default();
216 match cfg["sso_url"].as_str() {
217 Some(url) => url.to_string(),
218 None => {
219 return Json(ApiResponse::error_typed(
220 "SAML_CONFIG_ERROR",
221 "IdP config is missing required sso_url field",
222 ));
223 }
224 }
225 }
226 Ok(None) => {
227 tracing::warn!(idp = %request.idp_entity_id, "SAML SSO: unknown IdP entity ID");
228 return Json(ApiResponse::error_typed(
229 "SAML_UNKNOWN_IDP",
230 format!("IdP not configured: {}", request.idp_entity_id),
231 ));
232 }
233 Err(e) => {
234 tracing::error!(error = %e, "SAML SSO: storage error looking up IdP");
235 return Json(ApiResponse::error_typed(
236 "server_error",
237 "Failed to look up IdP configuration",
238 ));
239 }
240 };
241
242 let sp_config = match load_saml_sp_config(&state).await {
243 Ok(config) => config,
244 Err(error) => {
245 tracing::error!(error = %error, "SAML SSO requested without valid SP configuration");
246 return Json(ApiResponse::error_typed(
247 "SAML_CONFIG_ERROR",
248 "Service Provider configuration is missing required entity_id and acs_url values",
249 ));
250 }
251 };
252
253 let request_id = format!("saml_{}", uuid::Uuid::new_v4());
255 let issue_instant = chrono::Utc::now().to_rfc3339();
256
257 let saml_request = format!(
258 r#"<samlp:AuthnRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
259 xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
260 ID="{request_id}"
261 Version="2.0"
262 IssueInstant="{issue_instant}"
263 Destination="{idp_sso_url}"
264 {force_authn}
265 {is_passive}
266 AssertionConsumerServiceURL="{sp_acs_url}">
267 <saml:Issuer>{sp_entity_id}</saml:Issuer>
268 <samlp:NameIDPolicy Format="urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"
269 AllowCreate="true" />
270</samlp:AuthnRequest>"#,
271 force_authn = request
272 .force_authn
273 .map_or(String::new(), |fa| format!(r#"ForceAuthn="{}""#, fa)),
274 is_passive = request
275 .is_passive
276 .map_or(String::new(), |ip| format!(r#"IsPassive="{}""#, ip)),
277 sp_entity_id = xml_escape(&sp_config.entity_id),
278 sp_acs_url = xml_escape(&sp_config.acs_url),
279 idp_sso_url = xml_escape(&idp_sso_url),
280 );
281
282 let encoded_request = base64::engine::general_purpose::STANDARD.encode(&saml_request);
284
285 let mut redirect_url = format!(
287 "{}?SAMLRequest={}",
288 idp_sso_url,
289 urlencoding::encode(&encoded_request)
290 );
291
292 if let Some(relay_state) = &request.relay_state {
293 redirect_url.push_str(&format!("&RelayState={}", urlencoding::encode(relay_state)));
294 }
295
296 let request_key = format!("saml_request:{}", request_id);
298 let request_data = serde_json::json!({
299 "request_id": request_id,
300 "idp_entity_id": request.idp_entity_id,
301 "relay_state": request.relay_state,
302 "issued_at": chrono::Utc::now().to_rfc3339(),
303 })
304 .to_string();
305 if let Err(e) = state
306 .auth_framework
307 .storage()
308 .store_kv(
309 &request_key,
310 request_data.as_bytes(),
311 Some(std::time::Duration::from_secs(600)),
312 )
313 .await
314 {
315 tracing::warn!(error = %e, "SAML SSO: failed to persist AuthnRequest — InResponseTo validation will be skipped");
316 }
317
318 Json(ApiResponse::success(SamlSsoResponse {
319 redirect_url,
320 saml_request: encoded_request,
321 relay_state: request.relay_state,
322 }))
323}
324
325#[allow(unreachable_code, unused_variables)]
327pub async fn handle_saml_acs(
328 State(state): State<ApiState>,
329 axum::Form(form_data): axum::Form<SamlAcsRequest>,
330) -> Json<ApiResponse<serde_json::Value>> {
331 let saml_response_xml =
333 match base64::engine::general_purpose::STANDARD.decode(&form_data.saml_response) {
334 Ok(decoded) => match String::from_utf8(decoded) {
335 Ok(xml) => xml,
336 Err(e) => {
337 tracing::warn!(error = %e, "SAML ACS: invalid UTF-8 in decoded response");
338 return Json(ApiResponse::validation_error_typed(
339 "Invalid SAML response encoding",
340 ));
341 }
342 },
343 Err(e) => {
344 tracing::warn!(error = %e, "SAML ACS: base64 decode failed");
345 return Json(ApiResponse::validation_error_typed(
346 "Invalid SAML response encoding",
347 ));
348 }
349 };
350
351 #[cfg(feature = "saml")]
353 {
354 match validate_saml_signature(&state, &saml_response_xml).await {
355 Ok(()) => {
356 tracing::info!("SAML ACS: XML signature validated successfully");
357 }
358 Err(e) => {
359 tracing::error!(error = %e, "SAML ACS: XML signature validation failed");
360 return Json(ApiResponse::error_typed(
361 "SAML_SIGNATURE_INVALID",
362 format!("SAML response signature validation failed: {}", e),
363 ));
364 }
365 }
366 }
367 #[cfg(not(feature = "saml"))]
368 {
369 tracing::error!(
370 "SAML ACS: XML signature validation is not available — \
371 the 'saml' feature is required for secure SAML processing"
372 );
373 return Json(ApiResponse::error_typed(
374 "SAML_SIGNATURE_UNAVAILABLE",
375 "SAML signature validation is not available; the server must be compiled with the 'saml' feature",
376 ));
377 }
378
379 if !saml_response_xml.contains("<saml:Assertion")
380 && !saml_response_xml.contains("<saml2:Assertion")
381 && !saml_response_xml.contains("<Assertion")
382 {
383 return Json(ApiResponse::validation_error_typed(
384 "No SAML assertion found",
385 ));
386 }
387
388 if let Some(irt) = extract_in_response_to(&saml_response_xml) {
391 let request_key = format!("saml_request:{}", irt);
392 match state.auth_framework.storage().get_kv(&request_key).await {
393 Ok(Some(_)) => {
394 let _ = state.auth_framework.storage().delete_kv(&request_key).await;
396 }
397 _ => {
398 tracing::warn!(in_response_to = %irt, "SAML ACS: InResponseTo references unknown or expired request");
399 return Json(ApiResponse::error_typed(
400 "SAML_INVALID_RESPONSE",
401 "SAML response references an unknown or expired authentication request",
402 ));
403 }
404 }
405 } else {
406 tracing::warn!(
408 "SAML ACS: response has no InResponseTo attribute — rejecting unsolicited response"
409 );
410 return Json(ApiResponse::error_typed(
411 "SAML_UNSOLICITED_RESPONSE",
412 "Unsolicited SAML responses are not accepted; initiate SSO via /api/v1/saml/sso first",
413 ));
414 }
415
416 #[cfg(feature = "saml")]
419 {
420 let sp_entity_id = match load_saml_sp_config(&state).await {
421 Ok(config) => config.entity_id,
422 Err(error) => {
423 tracing::error!(error = %error, "SAML ACS requested without valid SP configuration");
424 return Json(ApiResponse::error_typed(
425 "SAML_CONFIG_ERROR",
426 "Service Provider configuration is missing required entity_id and acs_url values",
427 ));
428 }
429 };
430
431 if let Err(e) = validate_saml_conditions(&saml_response_xml, &sp_entity_id) {
432 tracing::warn!(error = %e, "SAML ACS: assertion conditions validation failed");
433 return Json(ApiResponse::error_typed("SAML_CONDITIONS_INVALID", e));
434 }
435 }
436
437 let username = match extract_username_from_saml(&saml_response_xml) {
439 Ok(user) => user,
440 Err(e) => return Json(ApiResponse::error_typed("SAML_PARSE_ERROR", e)),
441 };
442
443 let attributes = match extract_attributes_from_saml(&saml_response_xml) {
444 Ok(attrs) => attrs,
445 Err(e) => return Json(ApiResponse::error_typed("SAML_PARSE_ERROR", e)),
446 };
447
448 let scopes = vec![
450 "openid".to_string(),
451 "profile".to_string(),
452 "email".to_string(),
453 ];
454 let token = match state
455 .auth_framework
456 .token_manager()
457 .create_auth_token(&username, scopes, "saml", None)
458 {
459 Ok(t) => t,
460 Err(e) => {
461 tracing::error!(user = %username, error = %e, "SAML ACS: failed to create auth token");
462 return Json(ApiResponse::error_typed(
463 "server_error",
464 "Failed to create authentication token",
465 ));
466 }
467 };
468
469 let token_data = serde_json::json!({
470 "access_token": token.access_token,
471 "token_type": "Bearer",
472 "expires_in": (token.expires_at - token.issued_at).num_seconds().max(0) as u64,
473 "refresh_token": token.refresh_token,
474 "user_id": username,
475 "authentication_method": "saml",
476 "attributes": attributes,
477 "relay_state": form_data.relay_state
478 });
479
480 tracing::info!(user = %username, "SAML authentication successful");
481 Json(ApiResponse::success_with_message(
482 token_data,
483 "SAML authentication successful",
484 ))
485}
486
487pub async fn initiate_saml_slo(
490 State(state): State<ApiState>,
491 Json(request): Json<SamlLogoutRequest>,
492) -> Json<ApiResponse<SamlLogoutResponse>> {
493 let idp_key = format!("saml_idp:{}", request.idp_entity_id);
495 let idp_slo_url = match state.auth_framework.storage().get_kv(&idp_key).await {
496 Ok(Some(data)) => {
497 let cfg: serde_json::Value = serde_json::from_slice(&data).unwrap_or_default();
498 match cfg["slo_url"].as_str() {
499 Some(url) => url.to_string(),
500 None => {
501 return Json(ApiResponse::error_typed(
502 "SAML_CONFIG_ERROR",
503 "IdP config is missing required slo_url field",
504 ));
505 }
506 }
507 }
508 Ok(None) => {
509 tracing::warn!(idp = %request.idp_entity_id, "SAML SLO: unknown IdP entity ID");
510 return Json(ApiResponse::error_typed(
511 "SAML_UNKNOWN_IDP",
512 format!("IdP not configured: {}", request.idp_entity_id),
513 ));
514 }
515 Err(e) => {
516 tracing::error!(error = %e, "SAML SLO: storage error looking up IdP");
517 return Json(ApiResponse::error_typed(
518 "server_error",
519 "Failed to look up IdP configuration",
520 ));
521 }
522 };
523
524 let sp_config = match load_saml_sp_config(&state).await {
525 Ok(config) => config,
526 Err(error) => {
527 tracing::error!(error = %error, "SAML SLO requested without valid SP configuration");
528 return Json(ApiResponse::error_typed(
529 "SAML_CONFIG_ERROR",
530 "Service Provider configuration is missing required entity_id and acs_url values",
531 ));
532 }
533 };
534
535 let logout_id = format!("logout_{}", uuid::Uuid::new_v4());
536 let issue_instant = chrono::Utc::now().to_rfc3339();
537
538 let saml_logout_request = format!(
540 r#"<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
541 xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
542 ID="{logout_id}"
543 Version="2.0"
544 IssueInstant="{issue_instant}"
545 Destination="{idp_slo_url}">
546 <saml:Issuer>{sp_entity_id}</saml:Issuer>
547 <saml:NameID Format="urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress">{name_id}</saml:NameID>
548 {session_index}
549</samlp:LogoutRequest>"#,
550 name_id = xml_escape(&request.name_id),
551 session_index = request.session_index.map_or(String::new(), |si| format!(
552 r#"<samlp:SessionIndex>{}</samlp:SessionIndex>"#,
553 xml_escape(&si)
554 )),
555 sp_entity_id = xml_escape(&sp_config.entity_id),
556 idp_slo_url = xml_escape(&idp_slo_url),
557 );
558
559 let encoded_request = base64::engine::general_purpose::STANDARD.encode(&saml_logout_request);
560 let redirect_url = format!(
561 "{}?SAMLRequest={}",
562 idp_slo_url,
563 urlencoding::encode(&encoded_request)
564 );
565
566 Json(ApiResponse::success_with_message(
567 SamlLogoutResponse {
568 redirect_url,
569 status: "logout_initiated".to_string(),
570 },
571 "SAML logout initiated",
572 ))
573}
574
575pub async fn handle_saml_slo_response(
577 State(_state): State<ApiState>,
578 Query(params): Query<HashMap<String, String>>,
579) -> Json<ApiResponse<()>> {
580 let saml_response = match params.get("SAMLResponse") {
581 Some(response) => response,
582 None => {
583 return Json(ApiResponse::validation_error(
584 "Missing SAMLResponse parameter",
585 ));
586 }
587 };
588
589 let response_xml = match base64::engine::general_purpose::STANDARD.decode(saml_response) {
591 Ok(decoded) => match String::from_utf8(decoded) {
592 Ok(xml) => xml,
593 Err(e) => {
594 return Json(ApiResponse::validation_error(format!(
595 "Invalid SLO response UTF-8: {}",
596 e
597 )));
598 }
599 },
600 Err(e) => {
601 return Json(ApiResponse::validation_error(format!(
602 "Invalid SLO response encoding: {}",
603 e
604 )));
605 }
606 };
607
608 #[cfg(feature = "saml")]
610 let slo_success = xml_extract_status_code(&response_xml)
611 .map(|code| code == "urn:oasis:names:tc:SAML:2.0:status:Success")
612 .unwrap_or(false);
613 #[cfg(not(feature = "saml"))]
614 let slo_success = false;
615
616 if slo_success {
617 #[cfg(feature = "saml")]
620 {
621 if let Some(name_id) = xml_extract_name_id(&response_xml) {
622 if let Ok(Some(uid_bytes)) = _state
624 .auth_framework
625 .storage()
626 .get_kv(&format!("user:email:{}", name_id))
627 .await
628 {
629 let user_id = String::from_utf8_lossy(&uid_bytes).to_string();
630 let session_key = format!("sessions:user:{}", user_id);
631 let _ = _state
632 .auth_framework
633 .storage()
634 .delete_kv(&session_key)
635 .await;
636 tracing::info!(user_id = %user_id, "SAML SLO: invalidated sessions");
637 }
638 }
639 }
640
641 if let Some(relay_state) = params.get("RelayState") {
643 if !relay_state.is_empty() {
644 tracing::debug!(relay_state = %relay_state, "SAML SLO: RelayState provided");
645 }
646 }
647
648 Json(ApiResponse::<()>::ok_with_message(
649 "SAML logout completed successfully",
650 ))
651 } else {
652 Json(ApiResponse::error(
653 "SAML_LOGOUT_FAILED",
654 "SAML logout failed",
655 ))
656 }
657}
658
659pub async fn create_saml_assertion(
661 State(state): State<ApiState>,
662 Json(request): Json<serde_json::Value>,
663) -> Json<ApiResponse<String>> {
664 let username = match request["username"].as_str() {
665 Some(user) => user,
666 None => return Json(ApiResponse::validation_error_typed("Username required")),
667 };
668
669 let audience = match request["audience"].as_str() {
670 Some(aud) => aud,
671 None => return Json(ApiResponse::validation_error_typed("Audience required")),
672 };
673
674 let sp_config = match load_saml_sp_config(&state).await {
675 Ok(config) => config,
676 Err(error) => {
677 tracing::error!(error = %error, "SAML assertion requested without valid SP configuration");
678 return Json(ApiResponse::error_typed(
679 "SAML_CONFIG_ERROR",
680 "Service Provider configuration is missing required entity_id and acs_url values",
681 ));
682 }
683 };
684
685 let name_id = match request["email"].as_str().map(str::trim) {
686 Some(email) if !email.is_empty() => email.to_string(),
687 _ if username.contains('@') => username.to_string(),
688 _ => {
689 return Json(ApiResponse::validation_error_typed(
690 "Email required when username is not an email address",
691 ));
692 }
693 };
694
695 let assertion_id = uuid::Uuid::new_v4();
700 let response_id = uuid::Uuid::new_v4();
701 let now = chrono::Utc::now();
702 let not_before = (now - chrono::Duration::minutes(1)).to_rfc3339();
703 let not_after = (now + chrono::Duration::hours(1)).to_rfc3339();
704 let now_str = now.to_rfc3339();
705 let issuer = xml_escape(&sp_config.entity_id);
707 let audience_escaped = xml_escape(audience);
708 let name_id_escaped = xml_escape(&name_id);
709 let username_escaped = xml_escape(username);
710
711 let assertion_xml = format!(
712 r#"<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
713 xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
714 ID="response_{response_id}"
715 IssueInstant="{now_str}"
716 Destination="{audience_escaped}"
717 Version="2.0">
718 <saml:Issuer>{issuer}</saml:Issuer>
719 <samlp:Status>
720 <samlp:StatusCode Value="urn:oasis:names:tc:SAML:2.0:status:Success"/>
721 </samlp:Status>
722 <saml:Assertion ID="assertion_{assertion_id}"
723 IssueInstant="{now_str}"
724 Version="2.0">
725 <saml:Issuer>{issuer}</saml:Issuer>
726 <saml:Subject>
727 <saml:NameID Format="urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress">{name_id_escaped}</saml:NameID>
728 <saml:SubjectConfirmation Method="urn:oasis:names:tc:SAML:2.0:cm:bearer">
729 <saml:SubjectConfirmationData NotOnOrAfter="{not_after}" Recipient="{audience_escaped}"/>
730 </saml:SubjectConfirmation>
731 </saml:Subject>
732 <saml:Conditions NotBefore="{not_before}" NotOnOrAfter="{not_after}">
733 <saml:AudienceRestriction>
734 <saml:Audience>{audience_escaped}</saml:Audience>
735 </saml:AudienceRestriction>
736 </saml:Conditions>
737 <saml:AuthnStatement AuthnInstant="{now_str}" SessionIndex="session_{assertion_id}">
738 <saml:AuthnContext>
739 <saml:AuthnContextClassRef>urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport</saml:AuthnContextClassRef>
740 </saml:AuthnContext>
741 </saml:AuthnStatement>
742 <saml:AttributeStatement>
743 <saml:Attribute Name="username">
744 <saml:AttributeValue>{username_escaped}</saml:AttributeValue>
745 </saml:Attribute>
746 <saml:Attribute Name="email">
747 <saml:AttributeValue>{name_id_escaped}</saml:AttributeValue>
748 </saml:Attribute>
749 </saml:AttributeStatement>
750 </saml:Assertion>
751</samlp:Response>"#,
752 );
753
754 Json(ApiResponse::success_with_message(
755 assertion_xml,
756 "SAML assertion created",
757 ))
758}
759
760pub async fn list_saml_idps(
764 State(state): State<ApiState>,
765) -> Json<ApiResponse<Vec<serde_json::Value>>> {
766 let entity_ids: Vec<String> = match state
768 .auth_framework
769 .storage()
770 .get_kv("saml_idps:index")
771 .await
772 {
773 Ok(Some(data)) => serde_json::from_slice(&data).unwrap_or_default(),
774 Ok(None) => vec![],
775 Err(e) => {
776 tracing::error!(error = %e, "Failed to load SAML IdP index");
777 return Json(ApiResponse::error_typed(
778 "server_error",
779 "Failed to load IdP list",
780 ));
781 }
782 };
783
784 let mut idps = Vec::with_capacity(entity_ids.len());
786 for entity_id in &entity_ids {
787 let key = format!("saml_idp:{}", entity_id);
788 if let Ok(Some(data)) = state.auth_framework.storage().get_kv(&key).await
789 && let Ok(cfg) = serde_json::from_slice::<serde_json::Value>(&data)
790 {
791 idps.push(cfg);
792 }
793 }
794
795 Json(ApiResponse::success_with_message(
796 idps,
797 "SAML IdPs retrieved",
798 ))
799}
800
#[cfg(feature = "saml")]
/// Verifies the XML-DSig signature on a SAML response using the signing
/// certificate configured for the response's top-level `Issuer`.
///
/// Steps:
/// 1. Read the `Issuer`.
/// 2. Load `saml_idp:{issuer}` and take its `signing_cert` PEM.
/// 3. Build a key manager holding only that certificate's key and the
///    certificate itself.
/// 4. Run `bergshamra::verify` with trusted-keys-only, strict verification
///    and key verification enabled.
///
/// # Errors
/// Returns a human-readable message when any of these apply:
/// - the issuer is missing;
/// - the IdP is not configured, or its config or certificate is invalid;
/// - verification errors out or reports an invalid signature;
/// - a valid signature covers no references.
///
/// NOTE(review): success only shows that at least one reference is covered by
/// a valid signature. The ACS later reads the NameID and attributes from the
/// first `Assertion` in document order. Confirm that the signed reference(s)
/// are guaranteed to cover that Response/Assertion; otherwise XML signature
/// wrapping may be possible.
async fn validate_saml_signature(state: &ApiState, saml_xml: &str) -> Result<(), String> {
    // The unsigned Issuer only selects which configured key to verify with.
    let issuer = extract_issuer(saml_xml)
        .ok_or_else(|| "SAML response does not contain an Issuer element".to_string())?;

    let idp_key = format!("saml_idp:{}", issuer);
    let idp_cfg_data = state
        .auth_framework
        .storage()
        .get_kv(&idp_key)
        .await
        .map_err(|e| format!("Storage error loading IdP config: {}", e))?
        .ok_or_else(|| format!("IdP not configured: {}", issuer))?;

    let idp_cfg: serde_json::Value = serde_json::from_slice(&idp_cfg_data)
        .map_err(|e| format!("Invalid IdP config JSON: {}", e))?;

    let signing_cert_pem = idp_cfg["signing_cert"]
        .as_str()
        .ok_or_else(|| format!("IdP '{}' has no signing_cert configured", issuer))?;

    let der_bytes = pem_to_der(signing_cert_pem)?;

    let mut keys_manager = KeysManager::new();

    // Register the certificate's public key for verification...
    let key = key_from_x509_der(&der_bytes)?;
    keys_manager.add_key(key);

    // ...and the certificate itself as trusted.
    keys_manager.add_trusted_cert(der_bytes);

    let ctx = DsigContext::new(keys_manager)
        .with_trusted_keys_only(true)
        .with_strict_verification(true)
        .with_verify_keys(true);

    let result =
        verify(&ctx, saml_xml).map_err(|e| format!("XML-DSig verification error: {}", e))?;

    match result {
        VerifyResult::Valid { references, .. } => {
            // A signature over nothing proves nothing about the document.
            if references.is_empty() {
                return Err("Signature is valid but covers no references".to_string());
            }
            Ok(())
        }
        VerifyResult::Invalid { reason } => Err(format!("Signature invalid: {}", reason)),
    }
}
873
#[cfg(feature = "saml")]
/// Decodes a PEM (or bare base64) certificate into DER bytes.
///
/// The base64 body between the first `-----BEGIN ...-----` line and the next
/// `-----END` marker is used. When no BEGIN marker is present, the whole input
/// is treated as base64. All ASCII whitespace is removed before decoding, so
/// these forms are all accepted:
/// - single-line PEM (common when certificates are stored inside JSON);
/// - CRLF line endings;
/// - indented lines.
///
/// # Errors
/// Returns an error for an unterminated BEGIN line, a missing END marker, an
/// empty body, or invalid base64.
fn pem_to_der(pem: &str) -> Result<Vec<u8>, String> {
    const BEGIN: &str = "-----BEGIN";
    const DASHES: &str = "-----";
    const END: &str = "-----END";

    let body = match pem.find(BEGIN) {
        Some(begin) => {
            // Skip the label ("CERTIFICATE") and its closing dashes.
            let after_begin = &pem[begin + BEGIN.len()..];
            let label_end = after_begin
                .find(DASHES)
                .ok_or_else(|| "Malformed PEM certificate: unterminated BEGIN line".to_string())?;
            let after_label = &after_begin[label_end + DASHES.len()..];
            let end = after_label
                .find(END)
                .ok_or_else(|| "Malformed PEM certificate: missing END line".to_string())?;
            &after_label[..end]
        }
        None => pem,
    };

    let b64: String = body.chars().filter(|c| !c.is_ascii_whitespace()).collect();
    if b64.is_empty() {
        return Err("PEM certificate contains no base64 data".to_string());
    }

    base64::engine::general_purpose::STANDARD
        .decode(&b64)
        .map_err(|e| format!("Failed to base64-decode PEM certificate: {}", e))
}
890
#[cfg(feature = "saml")]
/// Builds a verification-only [`Key`] from the public key of a DER-encoded
/// X.509 certificate.
///
/// Key types are tried in order: RSA, ECDSA P-256, ECDSA P-384.
///
/// # Errors
/// Returns an error if the certificate cannot be parsed or its public key is
/// none of the supported algorithms (the message includes the algorithm OID).
fn key_from_x509_der(der: &[u8]) -> Result<Key, String> {
    use rsa::pkcs8::DecodePublicKey;
    use x509_parser::prelude::*;

    let (_, certificate) = X509Certificate::from_der(der)
        .map_err(|e| format!("Failed to parse X.509 certificate: {}", e))?;
    let subject_key_der = certificate.public_key().raw;

    let key_data = if let Ok(public) = rsa::RsaPublicKey::from_public_key_der(subject_key_der) {
        KeyData::Rsa {
            public,
            private: None,
        }
    } else if let Ok(public) = p256::ecdsa::VerifyingKey::from_public_key_der(subject_key_der) {
        KeyData::EcP256 {
            public,
            private: None,
        }
    } else if let Ok(public) = p384::ecdsa::VerifyingKey::from_public_key_der(subject_key_der) {
        KeyData::EcP384 {
            public,
            private: None,
        }
    } else {
        return Err(format!(
            "Unsupported IdP signing key algorithm (OID: {}). RSA, P-256, and P-384 are supported.",
            certificate.public_key().algorithm.oid()
        ));
    };

    Ok(Key::new(key_data, bergshamra::KeyUsage::Verify))
}
944
#[cfg(feature = "saml")]
/// Returns the trimmed text of the `Issuer` element that is a direct child of
/// the first `Response` element (any namespace prefix).
///
/// Returns `None` when:
/// - there is no such `Issuer`, or its text is blank;
/// - parsing fails before it is found.
///
/// Deeper `Issuer` elements (e.g. inside an `Assertion`) are ignored, and
/// scanning stops once that `Response` element closes. Only the exact local
/// name `Response` matches, so a `LogoutResponse` root yields `None`.
///
/// NOTE(review): the text comes from `BytesText::decode`, which presumably
/// does not resolve entity references such as `&amp;`. Confirm issuers never
/// contain them.
fn extract_issuer(saml_xml: &str) -> Option<String> {
    let mut reader = Reader::from_str(saml_xml);
    // Set once the first `Response` start tag is seen.
    let mut in_response = false;
    // True while inside the direct-child `Issuer` of that `Response`.
    let mut in_issuer = false;
    // Open-element depth, with the `Response` element itself at depth 1.
    let mut depth: u32 = 0;

    loop {
        match reader.read_event() {
            Ok(Event::Start(e)) => {
                let local = xml_local(e.name());
                if local == b"Response" && !in_response {
                    in_response = true;
                    depth = 1;
                } else if in_response {
                    depth += 1;
                    // Depth 2 == direct child of Response.
                    if local == b"Issuer" && depth == 2 {
                        in_issuer = true;
                    }
                }
            }
            Ok(Event::End(e)) => {
                let local = xml_local(e.name());
                if in_issuer && local == b"Issuer" {
                    in_issuer = false;
                }
                if in_response {
                    depth -= 1;
                    // The Response element itself has closed: stop scanning.
                    if depth == 0 {
                        break;
                    }
                }
            }
            Ok(Event::Text(t)) if in_issuer => {
                if let Ok(text) = t.decode() {
                    let s = text.trim();
                    if !s.is_empty() {
                        return Some(s.to_string());
                    }
                }
            }
            Ok(Event::Eof) => break,
            // Malformed XML: give up rather than guess.
            Err(_) => break,
            _ => {}
        }
    }
    None
}
996
#[cfg(feature = "saml")]
/// Returns the unprefixed `InResponseTo` attribute of the first `Response`
/// element (start or self-closing tag, any namespace prefix).
///
/// Returns `None` when:
/// - that element has no such attribute, or its value is not UTF-8;
/// - no `Response` element exists;
/// - parsing fails before one is found.
fn extract_in_response_to(saml_xml: &str) -> Option<String> {
    let mut reader = Reader::from_str(saml_xml);
    loop {
        let element = match reader.read_event() {
            Ok(Event::Start(e)) | Ok(Event::Empty(e)) => e,
            Ok(Event::Eof) | Err(_) => return None,
            Ok(_) => continue,
        };
        if xml_local(element.name()) != b"Response" {
            continue;
        }
        // Only the first Response element is consulted.
        return element
            .attributes()
            .flatten()
            .find(|attr| attr.key.as_ref() == b"InResponseTo")
            .and_then(|attr| String::from_utf8(attr.value.to_vec()).ok());
    }
}
1021
#[cfg(not(feature = "saml"))]
/// Fallback used without the `saml` feature. It pulls the `InResponseTo` value
/// out of the opening `<samlp:Response ...>` tag using plain string search.
///
/// Only the literal `samlp` prefix and double-quoted values are recognized.
/// Returns `None` when:
/// - the tag is missing or not terminated by `>`;
/// - the attribute is absent from the opening tag;
/// - the value is unterminated.
fn extract_in_response_to(saml_xml: &str) -> Option<String> {
    const NEEDLE: &str = "InResponseTo=\"";

    let tag_start = saml_xml.find("<samlp:Response")?;
    let (opening_tag, _) = saml_xml[tag_start..].split_once('>')?;
    let (_, after_needle) = opening_tag.split_once(NEEDLE)?;
    let (value, _) = after_needle.split_once('"')?;
    Some(value.to_owned())
}
1035
#[cfg(feature = "saml")]
/// Returns the trimmed text of the first non-blank `NameID` inside the first
/// `Assertion` element (any namespace prefix).
///
/// Parsing stops at the first closing `Assertion` tag. Any `NameID` inside the
/// assertion is accepted, not only the one directly under `Subject`.
///
/// NOTE(review): the NameID is taken from the first `Assertion` in document
/// order, regardless of which element the XML signature covered. Confirm that
/// signature validation guarantees this is the signed assertion.
///
/// # Errors
/// Returns an error if the XML is malformed before a NameID is found, or if no
/// non-blank NameID exists in the first assertion.
fn extract_username_from_saml(saml_xml: &str) -> Result<String, String> {
    let mut reader = Reader::from_str(saml_xml);
    // Set once an `Assertion` start tag has been seen; never reset because
    // scanning ends at that assertion's closing tag.
    let mut in_assertion = false;
    let mut in_name_id = false;

    loop {
        match reader.read_event() {
            Ok(Event::Start(e)) => {
                let local = xml_local(e.name());
                if local == b"Assertion" {
                    in_assertion = true;
                } else if in_assertion && local == b"NameID" {
                    in_name_id = true;
                }
            }
            Ok(Event::End(e)) => {
                let local = xml_local(e.name());
                if in_name_id && local == b"NameID" {
                    in_name_id = false;
                }
                if local == b"Assertion" {
                    break;
                }
            }
            Ok(Event::Text(t)) if in_name_id => {
                if let Ok(text) = t.decode() {
                    let s = text.trim();
                    if !s.is_empty() {
                        return Ok(s.to_string());
                    }
                }
            }
            Ok(Event::Eof) => break,
            Err(e) => return Err(format!("XML parse error extracting NameID: {}", e)),
            _ => {}
        }
    }

    Err("Could not extract username from SAML assertion".to_string())
}
1079
#[cfg(not(feature = "saml"))]
/// Stub compiled without the `saml` feature: always fails, because no XML
/// parser is available in that configuration.
fn extract_username_from_saml(_saml_xml: &str) -> Result<String, String> {
    Err(String::from("SAML parsing requires the 'saml' feature"))
}
1086
1087#[cfg(feature = "saml")]
1089fn extract_attributes_from_saml(saml_xml: &str) -> Result<HashMap<String, Vec<String>>, String> {
1090 let mut attributes = HashMap::new();
1091 attributes.insert("source".to_string(), vec!["saml".to_string()]);
1092 attributes.insert("auth_method".to_string(), vec!["saml_sso".to_string()]);
1093
1094 let mut reader = Reader::from_str(saml_xml);
1095 let mut in_attr_statement = false;
1096 let mut in_attribute = false;
1097 let mut in_attr_value = false;
1098 let mut current_attr_name: Option<String> = None;
1099 let mut current_values: Vec<String> = Vec::new();
1100
1101 loop {
1102 match reader.read_event() {
1103 Ok(Event::Start(e)) => {
1104 let local = xml_local(e.name());
1105 if local == b"AttributeStatement" {
1106 in_attr_statement = true;
1107 } else if in_attr_statement && local == b"Attribute" {
1108 in_attribute = true;
1109 current_values.clear();
1110 current_attr_name = None;
1111 for attr in e.attributes().flatten() {
1112 if xml_local(attr.key) == b"Name" {
1113 current_attr_name = String::from_utf8(attr.value.to_vec()).ok();
1114 }
1115 }
1116 } else if in_attribute && local == b"AttributeValue" {
1117 in_attr_value = true;
1118 }
1119 }
1120 Ok(Event::End(e)) => {
1121 let local = xml_local(e.name());
1122 if local == b"AttributeValue" {
1123 in_attr_value = false;
1124 } else if local == b"Attribute" && in_attribute {
1125 if let Some(name) = current_attr_name.take()
1126 && !current_values.is_empty()
1127 {
1128 attributes.insert(name, std::mem::take(&mut current_values));
1129 }
1130 in_attribute = false;
1131 } else if local == b"AttributeStatement" {
1132 in_attr_statement = false;
1133 }
1134 }
1135 Ok(Event::Text(t)) if in_attr_value => {
1136 if let Ok(text) = t.decode() {
1137 let s = text.trim();
1138 if !s.is_empty() {
1139 current_values.push(s.to_string());
1140 }
1141 }
1142 }
1143 Ok(Event::Eof) => break,
1144 Err(_) => break,
1145 _ => {}
1146 }
1147 }
1148
1149 Ok(attributes)
1150}
1151
#[cfg(not(feature = "saml"))]
/// Fallback compiled when the crate is built without the `saml` feature.
///
/// # Errors
/// Always fails, explaining that SAML attribute parsing is unavailable in
/// this build.
fn extract_attributes_from_saml(_saml_xml: &str) -> Result<HashMap<String, Vec<String>>, String> {
    let reason = String::from("SAML parsing requires the 'saml' feature");
    Err(reason)
}
1157
#[cfg(feature = "saml")]
/// Validates the `<Conditions>` of the first `<Assertion>` in `saml_xml`.
///
/// Checks performed:
/// * the assertion must contain a `Conditions` element. Both a
///   `<Conditions>...</Conditions>` with children and an attribute-only
///   `<Conditions .../>` are accepted.
/// * `NotBefore` / `NotOnOrAfter`, when present, must bracket the current
///   time, allowing 60 seconds of clock skew in each direction.
/// * every `<AudienceRestriction>` must list `sp_entity_id` among its
///   `<Audience>` values. SAML 2.0 Core requires each restriction to be
///   evaluated independently, so the SP must appear in *each* restriction,
///   not merely in one of them. A restriction that names no audience admits
///   nobody. An assertion without any audience restriction is accepted.
///
/// Scanning stops at the end of the first `Conditions` or `Assertion`
/// element.
///
/// # Errors
/// Returns a human-readable message for the first failed check, for an
/// unparseable timestamp, or for an XML parse error hit while scanning.
fn validate_saml_conditions(saml_xml: &str, sp_entity_id: &str) -> Result<(), String> {
    let mut reader = Reader::from_str(saml_xml);
    let mut in_assertion = false;
    let mut in_conditions = false;
    let mut found_conditions = false;
    let mut not_before: Option<String> = None;
    let mut not_on_or_after: Option<String> = None;
    // Audiences of the `<AudienceRestriction>` currently open, if any.
    let mut current_restriction: Option<Vec<String>> = None;
    // Text of the `<Audience>` currently open, accumulated across text events.
    let mut current_audience: Option<String> = None;
    // One entry per `<AudienceRestriction>`; every entry must admit this SP.
    let mut restrictions: Vec<Vec<String>> = Vec::new();

    loop {
        // `<X>` and `<X/>` carry the same attributes. Only `<X>` is followed by
        // children and an End event, so remember which form was seen.
        let (start, self_closing) = match reader.read_event() {
            Ok(Event::Start(e)) => (e, false),
            Ok(Event::Empty(e)) => (e, true),
            Ok(Event::End(e)) => {
                let local = xml_local(e.name());
                if local == b"Audience" {
                    if let (Some(text), Some(restriction)) =
                        (current_audience.take(), current_restriction.as_mut())
                    {
                        let text = text.trim();
                        if !text.is_empty() {
                            restriction.push(text.to_string());
                        }
                    }
                } else if local == b"AudienceRestriction" {
                    if let Some(restriction) = current_restriction.take() {
                        restrictions.push(restriction);
                    }
                } else if local == b"Conditions" || local == b"Assertion" {
                    break;
                }
                continue;
            }
            Ok(Event::Text(t)) => {
                if let Some(buf) = current_audience.as_mut() {
                    if let Ok(text) = t.decode() {
                        buf.push_str(&text);
                    }
                }
                continue;
            }
            Ok(Event::Eof) => break,
            Err(e) => return Err(format!("XML parse error in Conditions: {}", e)),
            _ => continue,
        };

        let local = xml_local(start.name());
        if local == b"Assertion" && !self_closing {
            in_assertion = true;
        } else if in_assertion && local == b"Conditions" {
            found_conditions = true;
            for attr in start.attributes().flatten() {
                let key = attr.key.as_ref();
                if key == b"NotBefore" {
                    not_before = String::from_utf8(attr.value.to_vec()).ok();
                } else if key == b"NotOnOrAfter" {
                    not_on_or_after = String::from_utf8(attr.value.to_vec()).ok();
                }
            }
            if self_closing {
                // `<Conditions .../>` has no children and produces no End event.
                break;
            }
            in_conditions = true;
        } else if in_conditions && local == b"AudienceRestriction" {
            if self_closing {
                // A restriction naming no audience admits nobody.
                restrictions.push(Vec::new());
            } else {
                current_restriction = Some(Vec::new());
            }
        } else if local == b"Audience" && current_restriction.is_some() && !self_closing {
            current_audience = Some(String::new());
        }
    }

    if !found_conditions {
        return Err("Assertion does not contain a Conditions element".to_string());
    }

    let skew = chrono::Duration::seconds(60);
    let now = chrono::Utc::now();

    if let Some(nb) = not_before {
        let ts = parse_saml_instant(&nb)
            .map_err(|e| format!("Invalid NotBefore timestamp '{}': {}", nb, e))?;
        if now < ts - skew {
            return Err(format!("Assertion is not yet valid (NotBefore: {})", nb));
        }
    }

    if let Some(noa) = not_on_or_after {
        let ts = parse_saml_instant(&noa)
            .map_err(|e| format!("Invalid NotOnOrAfter timestamp '{}': {}", noa, e))?;
        if now >= ts + skew {
            return Err(format!("Assertion has expired (NotOnOrAfter: {})", noa));
        }
    }

    for restriction in &restrictions {
        if !restriction.iter().any(|a| a == sp_entity_id) {
            return Err(format!(
                "Assertion audience restriction does not include this SP (expected '{}', got {:?})",
                sp_entity_id, restriction
            ));
        }
    }

    Ok(())
}

#[cfg(feature = "saml")]
/// Parses a SAML `xs:dateTime` instant into UTC.
///
/// Accepts RFC 3339 (`2024-01-01T00:00:00Z`, `2024-01-01T02:00:00.5+02:00`).
/// A value with no zone designator at all (`2024-01-01T00:00:00`) is
/// interpreted as UTC, since SAML requires instants to be expressed in UTC.
/// `DateTime::parse_from_str` cannot serve as that fallback because it
/// always requires an offset in the input. On failure the RFC 3339 error is
/// returned.
fn parse_saml_instant(value: &str) -> Result<chrono::DateTime<chrono::Utc>, chrono::ParseError> {
    let value = value.trim();
    chrono::DateTime::parse_from_rfc3339(value)
        .map(|ts| ts.with_timezone(&chrono::Utc))
        .or_else(|rfc_err| {
            chrono::NaiveDateTime::parse_from_str(value, "%Y-%m-%dT%H:%M:%S%.f")
                .map(|naive| naive.and_utc())
                .map_err(|_| rfc_err)
        })
}
1260
#[cfg(feature = "saml")]
/// Returns the `Value` attribute of the first `<StatusCode>` found after a
/// `<Status>` element opens.
///
/// Returns `None` in any of these cases:
/// * the first such `StatusCode` has no (UTF-8) `Value` attribute;
/// * a `</Status>` end tag is reached first;
/// * the document ends;
/// * the XML fails to parse.
fn xml_extract_status_code(saml_xml: &str) -> Option<String> {
    let mut reader = Reader::from_str(saml_xml);
    let mut inside_status = false;

    // A parse error ends the loop and falls through to `None`.
    while let Ok(event) = reader.read_event() {
        match event {
            Event::Eof => return None,
            Event::End(end) => {
                if xml_local(end.name()) == b"Status" {
                    return None;
                }
            }
            Event::Start(tag) | Event::Empty(tag) => {
                let local = xml_local(tag.name());
                if local == b"Status" {
                    inside_status = true;
                    continue;
                }
                if inside_status && local == b"StatusCode" {
                    // Only the first (outermost) StatusCode is consulted.
                    return tag
                        .attributes()
                        .flatten()
                        .find(|attr| attr.key.as_ref() == b"Value")
                        .and_then(|attr| String::from_utf8(attr.value.to_vec()).ok());
                }
            }
            _ => {}
        }
    }
    None
}
1295
#[cfg(feature = "saml")]
/// Returns the trimmed text of the first `<NameID>` element that has
/// non-empty content.
///
/// Returns `None` if the first opened `NameID` closes without non-empty
/// text, if the document ends, or if the XML fails to parse. Self-closing
/// `<NameID/>` elements are skipped.
fn xml_extract_name_id(saml_xml: &str) -> Option<String> {
    let mut reader = Reader::from_str(saml_xml);
    let mut capturing = false;

    loop {
        let Ok(event) = reader.read_event() else {
            // Malformed XML: give up, same as reaching the end.
            return None;
        };
        match event {
            Event::Eof => return None,
            Event::Start(tag) if xml_local(tag.name()) == b"NameID" => capturing = true,
            Event::Text(text) if capturing => {
                let Ok(decoded) = text.decode() else { continue };
                let trimmed = decoded.trim();
                if !trimmed.is_empty() {
                    return Some(trimmed.to_owned());
                }
            }
            Event::End(tag) if capturing && xml_local(tag.name()) == b"NameID" => return None,
            _ => {}
        }
    }
}
1327
#[cfg(test)]
#[cfg(feature = "saml")]
mod tests {
    //! Unit tests for the quick-xml based SAML helpers in this module.

    use super::*;
    use chrono::{Duration, Utc};

    /// Builds a minimal assertion with these parts:
    /// * a `Conditions` window spanning ten minutes on either side of now;
    /// * a single `AudienceRestriction` naming `audience`.
    fn assertion_for_audience(audience: &str) -> String {
        let now = Utc::now();
        let not_before = (now - Duration::minutes(10)).to_rfc3339();
        let not_on_or_after = (now + Duration::minutes(10)).to_rfc3339();
        format!(
            r#"<saml:Assertion><saml:Conditions NotBefore="{not_before}" NotOnOrAfter="{not_on_or_after}"><saml:AudienceRestriction><saml:Audience>{audience}</saml:Audience></saml:AudienceRestriction></saml:Conditions></saml:Assertion>"#
        )
    }

    #[test]
    fn test_extract_issuer() {
        let response =
            r#"<samlp:Response><saml:Issuer>https://idp.example.com</saml:Issuer></samlp:Response>"#;
        let issuer = extract_issuer(response).unwrap();
        assert_eq!(issuer, "https://idp.example.com");
    }

    #[test]
    fn test_extract_username() {
        let assertion = r#"<saml:Assertion><saml:Subject><saml:NameID>user@example.com</saml:NameID></saml:Subject></saml:Assertion>"#;
        let username = extract_username_from_saml(assertion).unwrap();
        assert_eq!(username, "user@example.com");
    }

    #[test]
    fn test_validate_conditions_time() {
        let matching = assertion_for_audience("test-aud");
        assert!(validate_saml_conditions(&matching, "test-aud").is_ok());

        let mismatched = assertion_for_audience("wrong-aud");
        assert!(validate_saml_conditions(&mismatched, "test-aud").is_err());
    }

    #[test]
    fn test_extract_status() {
        let status = r#"<samlp:Status><samlp:StatusCode Value="urn:oasis:names:tc:SAML:2.0:status:Success"/></samlp:Status>"#;
        let code = xml_extract_status_code(status).unwrap();
        assert_eq!(code, "urn:oasis:names:tc:SAML:2.0:status:Success");
    }
}