use axum::{
extract::{Path, State},
http::HeaderMap,
response::{IntoResponse, Redirect, Response},
Form, Json,
};
use base64::{engine::general_purpose, Engine as _};
use chrono::{DateTime, Utc};
use ring::signature::{
UnparsedPublicKey, VerificationAlgorithm, RSA_PKCS1_2048_8192_SHA256,
RSA_PKCS1_2048_8192_SHA512,
};
use rustls_pemfile;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256, Sha512};
use x509_parser::prelude::*;
use crate::{
error::{ApiError, ApiResult},
middleware::{resolve_org_context, AuthUser},
models::{AuditEventType, Organization, Plan, SSOConfiguration, SSOSession, User},
AppState,
};
/// JSON body accepted by `create_sso_config` to create or replace the caller's
/// organization SSO configuration.
///
/// Cross-field rules are enforced by `create_sso_config`, not by this type:
/// `provider` must parse via `SSOProvider::from_str`, and for SAML the entity ID,
/// SSO URL and X.509 certificate must all be present.
#[derive(Debug, Deserialize)]
pub struct CreateSSOConfigRequest {
// `provider`: SSO protocol name; the handler's error message lists 'saml' and 'oidc'.
// `saml_entity_id`: SAML entity ID stored with the configuration (required for SAML);
// `get_saml_metadata` and `initiate_saml_login` use it as the entity ID when set.
pub provider: String, pub saml_entity_id: Option<String>,
/// IdP single sign-on URL that `initiate_saml_login` redirects to (required for SAML).
pub saml_sso_url: Option<String>,
/// IdP single logout URL that `saml_slo` redirects its logout response to.
pub saml_slo_url: Option<String>,
/// IdP certificate used by `verify_saml_signature` (required for SAML).
pub saml_x509_cert: Option<String>,
/// NameID format advertised in the SP metadata; `get_saml_metadata` falls back to
/// the emailAddress format when this is unset.
pub saml_name_id_format: Option<String>,
/// Object mapping target attribute names to source keys that `parse_saml_response`
/// looks up in incoming SAML responses.
pub attribute_mapping: Option<serde_json::Value>,
/// Defaults to `true` when omitted.
pub require_signed_assertions: Option<bool>,
/// Defaults to `true` when omitted.
pub require_signed_responses: Option<bool>,
/// Defaults to `false` when omitted; stored, but not consulted by the SAML
/// handlers visible in this file.
pub allow_unsolicited_responses: Option<bool>,
}
/// JSON view of an organization's SSO configuration returned by
/// `create_sso_config` and `get_sso_config`.
///
/// Identifiers are rendered with `to_string()` and timestamps as RFC 3339
/// strings. The IdP X.509 certificate is deliberately not part of this view.
#[derive(Debug, Serialize)]
pub struct SSOConfigResponse {
/// Configuration ID.
pub id: String,
/// Owning organization ID.
pub org_id: String,
/// Provider name as stored on the configuration.
pub provider: String,
/// Whether SSO logins are currently accepted (toggled by `enable_sso`/`disable_sso`).
pub enabled: bool,
pub saml_entity_id: Option<String>,
pub saml_sso_url: Option<String>,
pub saml_slo_url: Option<String>,
pub saml_name_id_format: Option<String>,
/// Attribute mapping object as stored.
pub attribute_mapping: serde_json::Value,
pub require_signed_assertions: bool,
pub require_signed_responses: bool,
pub allow_unsolicited_responses: bool,
/// Creation time, RFC 3339.
pub created_at: String,
/// Last update time, RFC 3339.
pub updated_at: String,
}
/// Creates or replaces the SSO configuration of the caller's current organization.
///
/// The caller must own the organization or be a member with the `Admin` or
/// `Owner` role, and the organization must be on the Team plan. A SAML
/// configuration must include an entity ID, an SSO URL and an X.509 certificate.
/// Omitted flags default to requiring signed assertions and signed responses and
/// to rejecting unsolicited responses. A `SettingsUpdated` audit event is
/// recorded after the configuration has been stored.
///
/// # Errors
/// * `InvalidRequest` - missing organization context, unknown organization,
///   non-Team plan, unknown provider, or incomplete SAML fields.
/// * `PermissionDenied` - the caller is not an organization admin.
/// * Store errors are propagated.
pub async fn create_sso_config(
    State(state): State<AppState>,
    AuthUser(user_id): AuthUser,
    headers: HeaderMap,
    Json(request): Json<CreateSSOConfigRequest>,
) -> ApiResult<Json<SSOConfigResponse>> {
    use crate::models::sso::SSOProvider;
    use crate::models::OrgRole;

    let org_ctx = match resolve_org_context(&state, user_id, &headers, None).await {
        Ok(ctx) => ctx,
        Err(_) => {
            return Err(ApiError::InvalidRequest("Organization context required".to_string()))
        }
    };

    // The owner skips the membership lookup; a failed lookup counts as "not admin".
    let is_admin = if org_ctx.org.owner_id == user_id {
        true
    } else {
        match state.store.find_org_member(org_ctx.org_id, user_id).await {
            Ok(Some(member)) => matches!(member.role(), OrgRole::Admin | OrgRole::Owner),
            _ => false,
        }
    };
    if !is_admin {
        return Err(ApiError::PermissionDenied);
    }

    let org = match state.store.find_organization_by_id(org_ctx.org_id).await? {
        Some(org) => org,
        None => return Err(ApiError::InvalidRequest("Organization not found".to_string())),
    };
    if org.plan() != Plan::Team {
        return Err(ApiError::InvalidRequest(
            "SSO is only available for Team plans. Please upgrade to Team plan to enable SSO."
                .to_string(),
        ));
    }

    let provider = match SSOProvider::from_str(&request.provider) {
        Some(parsed) => parsed,
        None => {
            return Err(ApiError::InvalidRequest(
                "Invalid SSO provider. Must be 'saml' or 'oidc'".to_string(),
            ))
        }
    };

    let has_required_saml_fields = request.saml_entity_id.is_some()
        && request.saml_sso_url.is_some()
        && request.saml_x509_cert.is_some();
    if provider == SSOProvider::Saml && !has_required_saml_fields {
        return Err(ApiError::InvalidRequest(
            "SAML configuration requires entity_id, sso_url, and x509_cert".to_string(),
        ));
    }

    let config = state
        .store
        .upsert_sso_config(
            org_ctx.org_id,
            provider,
            request.saml_entity_id.as_deref(),
            request.saml_sso_url.as_deref(),
            request.saml_slo_url.as_deref(),
            request.saml_x509_cert.as_deref(),
            request.saml_name_id_format.as_deref(),
            request.attribute_mapping,
            request.require_signed_assertions.unwrap_or(true),
            request.require_signed_responses.unwrap_or(true),
            request.allow_unsolicited_responses.unwrap_or(false),
        )
        .await?;

    let ip_address = headers
        .get("X-Forwarded-For")
        .or_else(|| headers.get("X-Real-IP"))
        .and_then(|value| value.to_str().ok())
        .map(|raw| {
            // Keep only the first comma-separated entry of the header.
            let first_hop = match raw.split_once(',') {
                Some((first, _)) => first,
                None => raw,
            };
            first_hop.trim().to_owned()
        });
    let user_agent = headers
        .get("User-Agent")
        .and_then(|value| value.to_str().ok())
        .map(str::to_owned);
    let audit_details = serde_json::json!({
        "provider": provider.to_string(),
        "enabled": config.enabled,
    });
    state
        .store
        .record_audit_event(
            org_ctx.org_id,
            Some(user_id),
            AuditEventType::SettingsUpdated,
            "SSO configuration created/updated".to_string(),
            Some(audit_details),
            ip_address.as_deref(),
            user_agent.as_deref(),
        )
        .await;

    let response = SSOConfigResponse {
        id: config.id.to_string(),
        org_id: config.org_id.to_string(),
        provider: config.provider,
        enabled: config.enabled,
        saml_entity_id: config.saml_entity_id,
        saml_sso_url: config.saml_sso_url,
        saml_slo_url: config.saml_slo_url,
        saml_name_id_format: config.saml_name_id_format,
        attribute_mapping: config.attribute_mapping,
        require_signed_assertions: config.require_signed_assertions,
        require_signed_responses: config.require_signed_responses,
        allow_unsolicited_responses: config.allow_unsolicited_responses,
        created_at: config.created_at.to_rfc3339(),
        updated_at: config.updated_at.to_rfc3339(),
    };
    Ok(Json(response))
}
/// Returns the SSO configuration of the caller's current organization, or
/// `None` (serialized as JSON `null`) when none exists.
///
/// Only the organization owner or members with the `Admin`/`Owner` role may read it.
///
/// # Errors
/// `InvalidRequest` without an organization context, `PermissionDenied` for
/// non-admins, and store errors are propagated.
pub async fn get_sso_config(
    State(state): State<AppState>,
    AuthUser(user_id): AuthUser,
    headers: HeaderMap,
) -> ApiResult<Json<Option<SSOConfigResponse>>> {
    use crate::models::OrgRole;

    let org_ctx = match resolve_org_context(&state, user_id, &headers, None).await {
        Ok(ctx) => ctx,
        Err(_) => {
            return Err(ApiError::InvalidRequest("Organization context required".to_string()))
        }
    };

    let is_admin = if org_ctx.org.owner_id == user_id {
        true
    } else {
        match state.store.find_org_member(org_ctx.org_id, user_id).await {
            Ok(Some(member)) => matches!(member.role(), OrgRole::Admin | OrgRole::Owner),
            _ => false,
        }
    };
    if !is_admin {
        return Err(ApiError::PermissionDenied);
    }

    let response = state
        .store
        .find_sso_config_by_org(org_ctx.org_id)
        .await?
        .map(|config| SSOConfigResponse {
            id: config.id.to_string(),
            org_id: config.org_id.to_string(),
            provider: config.provider,
            enabled: config.enabled,
            saml_entity_id: config.saml_entity_id,
            saml_sso_url: config.saml_sso_url,
            saml_slo_url: config.saml_slo_url,
            saml_name_id_format: config.saml_name_id_format,
            attribute_mapping: config.attribute_mapping,
            require_signed_assertions: config.require_signed_assertions,
            require_signed_responses: config.require_signed_responses,
            allow_unsolicited_responses: config.allow_unsolicited_responses,
            created_at: config.created_at.to_rfc3339(),
            updated_at: config.updated_at.to_rfc3339(),
        });
    Ok(Json(response))
}
/// Enables SSO for the caller's current organization and records an audit event.
///
/// Requires an organization admin (owner, or `Admin`/`Owner` role), a Team plan,
/// and an existing SSO configuration.
///
/// # Errors
/// `InvalidRequest` for a missing context/organization, a non-Team plan, or a
/// missing configuration; `PermissionDenied` for non-admins; store errors are
/// propagated.
pub async fn enable_sso(
    State(state): State<AppState>,
    AuthUser(user_id): AuthUser,
    headers: HeaderMap,
) -> ApiResult<Json<serde_json::Value>> {
    use crate::models::OrgRole;

    let org_ctx = match resolve_org_context(&state, user_id, &headers, None).await {
        Ok(ctx) => ctx,
        Err(_) => {
            return Err(ApiError::InvalidRequest("Organization context required".to_string()))
        }
    };

    let is_admin = if org_ctx.org.owner_id == user_id {
        true
    } else {
        match state.store.find_org_member(org_ctx.org_id, user_id).await {
            Ok(Some(member)) => matches!(member.role(), OrgRole::Admin | OrgRole::Owner),
            _ => false,
        }
    };
    if !is_admin {
        return Err(ApiError::PermissionDenied);
    }

    let org = match state.store.find_organization_by_id(org_ctx.org_id).await? {
        Some(org) => org,
        None => return Err(ApiError::InvalidRequest("Organization not found".to_string())),
    };
    if org.plan() != Plan::Team {
        return Err(ApiError::InvalidRequest("SSO is only available for Team plans".to_string()));
    }

    if state.store.find_sso_config_by_org(org_ctx.org_id).await?.is_none() {
        return Err(ApiError::InvalidRequest(
            "SSO not configured. Please configure SSO first.".to_string(),
        ));
    }

    state.store.enable_sso_config(org_ctx.org_id).await?;

    let ip_address = headers
        .get("X-Forwarded-For")
        .or_else(|| headers.get("X-Real-IP"))
        .and_then(|value| value.to_str().ok())
        .map(|raw| {
            let first_hop = match raw.split_once(',') {
                Some((first, _)) => first,
                None => raw,
            };
            first_hop.trim().to_owned()
        });
    let user_agent = headers
        .get("User-Agent")
        .and_then(|value| value.to_str().ok())
        .map(str::to_owned);
    state
        .store
        .record_audit_event(
            org_ctx.org_id,
            Some(user_id),
            AuditEventType::SettingsUpdated,
            "SSO enabled".to_string(),
            None,
            ip_address.as_deref(),
            user_agent.as_deref(),
        )
        .await;

    let body = serde_json::json!({
        "success": true,
        "message": "SSO has been enabled successfully"
    });
    Ok(Json(body))
}
/// Disables SSO for the caller's current organization and records an audit event.
///
/// Only organization admins (owner, or `Admin`/`Owner` role) may call this; no
/// plan check is performed.
///
/// # Errors
/// `InvalidRequest` without an organization context, `PermissionDenied` for
/// non-admins, and store errors are propagated.
pub async fn disable_sso(
    State(state): State<AppState>,
    AuthUser(user_id): AuthUser,
    headers: HeaderMap,
) -> ApiResult<Json<serde_json::Value>> {
    use crate::models::OrgRole;

    let org_ctx = match resolve_org_context(&state, user_id, &headers, None).await {
        Ok(ctx) => ctx,
        Err(_) => {
            return Err(ApiError::InvalidRequest("Organization context required".to_string()))
        }
    };

    let is_admin = if org_ctx.org.owner_id == user_id {
        true
    } else {
        match state.store.find_org_member(org_ctx.org_id, user_id).await {
            Ok(Some(member)) => matches!(member.role(), OrgRole::Admin | OrgRole::Owner),
            _ => false,
        }
    };
    if !is_admin {
        return Err(ApiError::PermissionDenied);
    }

    state.store.disable_sso_config(org_ctx.org_id).await?;

    let ip_address = headers
        .get("X-Forwarded-For")
        .or_else(|| headers.get("X-Real-IP"))
        .and_then(|value| value.to_str().ok())
        .map(|raw| {
            let first_hop = match raw.split_once(',') {
                Some((first, _)) => first,
                None => raw,
            };
            first_hop.trim().to_owned()
        });
    let user_agent = headers
        .get("User-Agent")
        .and_then(|value| value.to_str().ok())
        .map(str::to_owned);
    state
        .store
        .record_audit_event(
            org_ctx.org_id,
            Some(user_id),
            AuditEventType::SettingsUpdated,
            "SSO disabled".to_string(),
            None,
            ip_address.as_deref(),
            user_agent.as_deref(),
        )
        .await;

    let body = serde_json::json!({
        "success": true,
        "message": "SSO has been disabled successfully"
    });
    Ok(Json(body))
}
/// Deletes the SSO configuration of the caller's current organization and
/// records an audit event.
///
/// Only organization admins (owner, or `Admin`/`Owner` role) may call this.
///
/// # Errors
/// `InvalidRequest` without an organization context, `PermissionDenied` for
/// non-admins, and store errors are propagated.
pub async fn delete_sso_config(
    State(state): State<AppState>,
    AuthUser(user_id): AuthUser,
    headers: HeaderMap,
) -> ApiResult<Json<serde_json::Value>> {
    use crate::models::OrgRole;

    let org_ctx = match resolve_org_context(&state, user_id, &headers, None).await {
        Ok(ctx) => ctx,
        Err(_) => {
            return Err(ApiError::InvalidRequest("Organization context required".to_string()))
        }
    };

    let is_admin = if org_ctx.org.owner_id == user_id {
        true
    } else {
        match state.store.find_org_member(org_ctx.org_id, user_id).await {
            Ok(Some(member)) => matches!(member.role(), OrgRole::Admin | OrgRole::Owner),
            _ => false,
        }
    };
    if !is_admin {
        return Err(ApiError::PermissionDenied);
    }

    state.store.delete_sso_config(org_ctx.org_id).await?;

    let ip_address = headers
        .get("X-Forwarded-For")
        .or_else(|| headers.get("X-Real-IP"))
        .and_then(|value| value.to_str().ok())
        .map(|raw| {
            let first_hop = match raw.split_once(',') {
                Some((first, _)) => first,
                None => raw,
            };
            first_hop.trim().to_owned()
        });
    let user_agent = headers
        .get("User-Agent")
        .and_then(|value| value.to_str().ok())
        .map(str::to_owned);
    state
        .store
        .record_audit_event(
            org_ctx.org_id,
            Some(user_id),
            AuditEventType::SettingsUpdated,
            "SSO configuration deleted".to_string(),
            None,
            ip_address.as_deref(),
            user_agent.as_deref(),
        )
        .await;

    let body = serde_json::json!({
        "success": true,
        "message": "SSO configuration has been deleted successfully"
    });
    Ok(Json(body))
}
/// Serves the service-provider SAML metadata XML for the organization `org_slug`.
///
/// The entity ID is the configured `saml_entity_id`, or
/// `{APP_BASE_URL}/saml/metadata/{slug}` when unset. The ACS and SLO endpoints
/// point at this API under `APP_BASE_URL` (default `https://app.mockforge.dev`).
/// The NameID format defaults to the SAML 1.1 emailAddress format. SSO does not
/// need to be enabled for metadata to be served.
///
/// Every interpolated value is XML-escaped, so URLs containing `&` (common in
/// query strings) or any other markup character still produce well-formed XML.
///
/// # Errors
/// `InvalidRequest` when the organization or its SSO configuration does not
/// exist; `Internal` if the HTTP response cannot be built.
pub async fn get_saml_metadata(
    State(state): State<AppState>,
    Path(org_slug): Path<String>,
) -> ApiResult<axum::response::Response> {
    /// Escapes the five XML special characters for use in text and attribute values.
    fn xml_escape(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '&' => out.push_str("&amp;"),
                '<' => out.push_str("&lt;"),
                '>' => out.push_str("&gt;"),
                '"' => out.push_str("&quot;"),
                '\'' => out.push_str("&apos;"),
                other => out.push(other),
            }
        }
        out
    }

    let org = state
        .store
        .find_organization_by_slug(&org_slug)
        .await?
        .ok_or_else(|| ApiError::InvalidRequest("Organization not found".to_string()))?;
    let config = state.store.find_sso_config_by_org(org.id).await?.ok_or_else(|| {
        ApiError::InvalidRequest("SSO not configured for this organization".to_string())
    })?;

    let app_base_url =
        std::env::var("APP_BASE_URL").unwrap_or_else(|_| "https://app.mockforge.dev".to_string());
    let entity_id = config
        .saml_entity_id
        .unwrap_or_else(|| format!("{}/saml/metadata/{}", app_base_url, org_slug));
    let acs_url = format!("{}/api/v1/sso/saml/acs/{}", app_base_url, org_slug);
    let slo_url = format!("{}/api/v1/sso/saml/slo/{}", app_base_url, org_slug);
    let name_id_format = config
        .saml_name_id_format
        .as_deref()
        .unwrap_or("urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress");

    let metadata = format!(
        r#"<?xml version="1.0"?>
<EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="{}">
<SPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<NameIDFormat>{}</NameIDFormat>
<AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" Location="{}" index="0"/>
<SingleLogoutService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" Location="{}"/>
</SPSSODescriptor>
</EntityDescriptor>"#,
        xml_escape(&entity_id),
        xml_escape(name_id_format),
        xml_escape(&acs_url),
        xml_escape(&slo_url)
    );

    axum::response::Response::builder()
        .status(axum::http::StatusCode::OK)
        .header("Content-Type", "application/xml")
        .body(metadata.into())
        .map_err(|e| ApiError::Internal(anyhow::anyhow!("Failed to build response: {}", e)))
}
/// Starts an SP-initiated SAML login by redirecting the browser to the IdP.
///
/// Requires the Team plan and an enabled SSO configuration with an SSO URL. The
/// generated `AuthnRequest` asks the IdP to POST its response to this API's ACS
/// endpoint. It is base64-encoded and appended to the IdP URL as `SAMLRequest`.
/// `&` is used instead of `?` when the IdP URL already carries a query string.
///
/// NOTE(review): the SAML HTTP-Redirect binding requires the request to be
/// DEFLATE-compressed before base64 encoding. This sends it uncompressed, which
/// strict IdPs reject. Fixing it needs a deflate implementation, which is not
/// among this file's dependencies.
///
/// # Errors
/// `InvalidRequest` for an unknown organization, non-Team plan, missing or
/// disabled configuration, or missing SSO URL; store errors are propagated.
pub async fn initiate_saml_login(
    State(state): State<AppState>,
    Path(org_slug): Path<String>,
) -> Result<Response, ApiError> {
    let org = state
        .store
        .find_organization_by_slug(&org_slug)
        .await?
        .ok_or_else(|| ApiError::InvalidRequest("Organization not found".to_string()))?;
    if org.plan() != Plan::Team {
        return Err(ApiError::InvalidRequest("SSO is only available for Team plans".to_string()));
    }
    let config = state.store.find_sso_config_by_org(org.id).await?.ok_or_else(|| {
        ApiError::InvalidRequest("SSO not configured for this organization".to_string())
    })?;
    if !config.enabled {
        return Err(ApiError::InvalidRequest(
            "SSO is not enabled for this organization".to_string(),
        ));
    }
    let sso_url = config
        .saml_sso_url
        .ok_or_else(|| ApiError::InvalidRequest("SAML SSO URL not configured".to_string()))?;

    let app_base_url =
        std::env::var("APP_BASE_URL").unwrap_or_else(|_| "https://app.mockforge.dev".to_string());
    let acs_url = format!("{}/api/v1/sso/saml/acs/{}", app_base_url, org_slug);
    let entity_id = config
        .saml_entity_id
        .unwrap_or_else(|| format!("{}/saml/metadata/{}", app_base_url, org_slug));

    let saml_request = generate_saml_authn_request(&entity_id, &acs_url);
    let encoded_request = general_purpose::STANDARD.encode(saml_request.as_bytes());
    // IdP SSO URLs sometimes already contain query parameters (tenant IDs, etc.).
    let separator = if sso_url.contains('?') { '&' } else { '?' };
    let redirect_url = format!(
        "{}{}SAMLRequest={}",
        sso_url,
        separator,
        urlencoding::encode(&encoded_request)
    );
    Ok(Redirect::to(&redirect_url).into_response())
}
/// Form body POSTed by the IdP to the ACS endpoint (`saml_acs`).
#[derive(Debug, Deserialize)]
// Field names must match the SAML form parameter names exactly, hence the
// non-snake-case fields.
#[allow(non_snake_case)]
pub struct SAMLResponseForm {
/// Base64-encoded SAML `Response` document; `saml_acs` rejects the request when absent.
pub SAMLResponse: Option<String>,
/// Opaque relay state from the IdP; accepted but not read by `saml_acs`.
pub RelayState: Option<String>,
}
pub async fn saml_acs(
State(state): State<AppState>,
Path(org_slug): Path<String>,
Form(form): Form<SAMLResponseForm>,
) -> Result<Response, ApiError> {
let pool = state.db.pool();
let org = state
.store
.find_organization_by_slug(&org_slug)
.await?
.ok_or_else(|| ApiError::InvalidRequest("Organization not found".to_string()))?;
let config = state
.store
.find_sso_config_by_org(org.id)
.await?
.ok_or_else(|| ApiError::InvalidRequest("SSO not configured".to_string()))?;
if !config.enabled {
return Err(ApiError::InvalidRequest("SSO is not enabled".to_string()));
}
let saml_response = form
.SAMLResponse
.ok_or_else(|| ApiError::InvalidRequest("SAMLResponse parameter missing".to_string()))?;
let decoded_response = general_purpose::STANDARD.decode(&saml_response).map_err(|e| {
ApiError::Internal(anyhow::anyhow!("Failed to decode SAML response: {}", e))
})?;
if config.require_signed_responses {
verify_saml_signature(&decoded_response, &config)?;
}
let user_info = parse_saml_response(&decoded_response, &config, &org).await?;
validate_saml_timestamps(&user_info).map_err(|e| {
tracing::error!("SAML timestamp validation failed for org_id={}: {}", org.id, e);
e
})?;
if let Some(assertion_id) = &user_info.assertion_id {
let is_replay =
state.store.is_saml_assertion_used(assertion_id, org.id).await.map_err(|e| {
tracing::error!(
"Database error checking assertion ID for org_id={}: {:?}",
org.id,
e
);
e
})?;
if is_replay {
tracing::warn!(
"Replay attack detected: assertion_id={} already used for org_id={}",
assertion_id,
org.id
);
return Err(ApiError::InvalidRequest(
"This SAML assertion has already been used. Replay attacks are not allowed."
.to_string(),
));
}
}
let user = find_or_create_user_from_saml(&state, &user_info, &org).await?;
if let Some(assertion_id) = &user_info.assertion_id {
let expires_at = user_info
.not_on_or_after
.unwrap_or_else(|| chrono::Utc::now() + chrono::Duration::hours(1));
let issued_at = user_info.issued_at.unwrap_or_else(chrono::Utc::now);
state
.store
.record_saml_assertion_used(
assertion_id,
org.id,
Some(user.id),
user_info.name_id.as_deref(),
issued_at,
expires_at,
)
.await
.map_err(|e| {
tracing::error!("Failed to record assertion ID for org_id={}: {:?}", org.id, e);
e
})?;
tracing::debug!(
"Recorded assertion ID {} for org_id={}, user_id={}",
assertion_id,
org.id,
user.id
);
}
let session_expires = chrono::Utc::now() + chrono::Duration::hours(8); let _session = SSOSession::create(
pool,
org.id,
user.id,
user_info.session_index.as_deref(),
user_info.name_id.as_deref(),
session_expires,
)
.await
.map_err(ApiError::Database)?;
let token = crate::auth::create_token(&user.id.to_string(), &state.config.jwt_secret)
.map_err(ApiError::Internal)?;
let app_base_url =
std::env::var("APP_BASE_URL").unwrap_or_else(|_| "https://app.mockforge.dev".to_string());
let redirect_url =
format!("{}/auth/sso/callback?token={}&org_slug={}", app_base_url, token, org_slug);
Ok(Redirect::to(&redirect_url).into_response())
}
/// Form body POSTed to the single-logout endpoint (`saml_slo`).
#[derive(Debug, Deserialize)]
// Field names must match the SAML form parameter names exactly, hence the
// non-snake-case fields.
#[allow(non_snake_case)]
pub struct SAMLLogoutForm {
/// Base64-encoded IdP `LogoutRequest`; the only field `saml_slo` acts on.
pub SAMLRequest: Option<String>,
/// Base64-encoded `LogoutResponse`; accepted but not read by `saml_slo`.
pub SAMLResponse: Option<String>,
/// Opaque relay state; accepted but not read by `saml_slo`.
pub RelayState: Option<String>,
}
/// SAML single-logout endpoint for the organization `org_slug`.
///
/// For an IdP `LogoutRequest`, the handler deletes the matching `sso_sessions`
/// row (by `SessionIndex`) and redirects back to the configured SLO URL with a
/// base64 `LogoutResponse`. `&` is used as the separator when that URL already
/// has a query string. Any other POST is redirected to `/`.
///
/// NOTE(review): the LogoutRequest's signature is not verified, so any caller who
/// knows or guesses a SessionIndex can end that session. The HTTP-Redirect
/// binding also expects DEFLATE compression, which is not applied here.
///
/// # Errors
/// * `InvalidRequest` for an unknown organization, missing configuration, a
///   malformed (non-base64) request, or a missing SLO URL.
/// * `Database` if the session delete fails.
pub async fn saml_slo(
    State(state): State<AppState>,
    Path(org_slug): Path<String>,
    Form(form): Form<SAMLLogoutForm>,
) -> Result<Response, ApiError> {
    let pool = state.db.pool();
    let org = state
        .store
        .find_organization_by_slug(&org_slug)
        .await?
        .ok_or_else(|| ApiError::InvalidRequest("Organization not found".to_string()))?;
    let config = state
        .store
        .find_sso_config_by_org(org.id)
        .await?
        .ok_or_else(|| ApiError::InvalidRequest("SSO not configured".to_string()))?;

    let saml_request = match form.SAMLRequest {
        Some(request) => request,
        None => return Ok(Redirect::to("/").into_response()),
    };

    // Base64 may be line-wrapped; a malformed payload is a client error, not a 500.
    let compact_request: String = saml_request.chars().filter(|c| !c.is_whitespace()).collect();
    let decoded = general_purpose::STANDARD.decode(compact_request.as_bytes()).map_err(|e| {
        ApiError::InvalidRequest(format!("Failed to decode SAML logout request: {}", e))
    })?;

    if let Some(session_index) = parse_saml_logout_request(&decoded)? {
        sqlx::query("DELETE FROM sso_sessions WHERE org_id = $1 AND session_index = $2")
            .bind(org.id)
            .bind(session_index)
            .execute(pool)
            .await
            .map_err(ApiError::Database)?;
    }

    let slo_url = config
        .saml_slo_url
        .ok_or_else(|| ApiError::InvalidRequest("SAML SLO URL not configured".to_string()))?;
    let logout_response = generate_saml_logout_response(&slo_url);
    let encoded_response = general_purpose::STANDARD.encode(logout_response.as_bytes());
    let separator = if slo_url.contains('?') { '&' } else { '?' };
    let redirect_url = format!(
        "{}{}SAMLResponse={}",
        slo_url,
        separator,
        urlencoding::encode(&encoded_response)
    );
    Ok(Redirect::to(&redirect_url).into_response())
}
/// Identity and validity data extracted from a SAML response by
/// `parse_saml_response`. Every field is optional because extraction is best-effort.
#[derive(Debug, Clone)]
struct SAMLUserInfo {
/// Assertion `ID`; `saml_acs` uses it for replay detection.
assertion_id: Option<String>,
/// Subject `NameID` value.
name_id: Option<String>,
/// Email taken from the NameID or an attribute value containing '@'.
email: Option<String>,
/// Explicit username, or the local part of `email` as a fallback.
username: Option<String>,
first_name: Option<String>,
last_name: Option<String>,
/// IdP session index, stored on the `SSOSession` so `saml_slo` can match it.
session_index: Option<String>,
/// Values gathered according to the configuration's `attribute_mapping`.
attributes: serde_json::Value,
/// Start of the validity window, checked by `validate_saml_timestamps`.
not_before: Option<DateTime<Utc>>,
/// End of the validity window, checked by `validate_saml_timestamps`.
not_on_or_after: Option<DateTime<Utc>>,
/// Issue time; `validate_saml_timestamps` uses it when `not_on_or_after` is missing.
issued_at: Option<DateTime<Utc>>,
}
/// Builds an unsigned SAML 2.0 `AuthnRequest` issued by `entity_id` that asks the
/// IdP to POST its response (HTTP-POST binding) to `acs_url`.
///
/// * `ID` is a random UUID prefixed with `_`, so it is a valid XML NCName.
/// * `IssueInstant` is UTC with whole seconds and a `Z` suffix, the form SAML
///   implementations expect. Plain `to_rfc3339()` would emit `+00:00` and nanoseconds.
/// * `Destination` is omitted. It must name the IdP endpoint the request is sent
///   to, which this function does not receive, and SAML treats it as optional
///   for unsigned requests. Setting it to the ACS URL makes strict IdPs reject
///   the request.
/// * Interpolated values are XML-escaped.
fn generate_saml_authn_request(entity_id: &str, acs_url: &str) -> String {
    /// Escapes the five XML special characters for use in text and attribute values.
    fn xml_escape(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '&' => out.push_str("&amp;"),
                '<' => out.push_str("&lt;"),
                '>' => out.push_str("&gt;"),
                '"' => out.push_str("&quot;"),
                '\'' => out.push_str("&apos;"),
                other => out.push(other),
            }
        }
        out
    }

    let request_id = uuid::Uuid::new_v4();
    let issue_instant = chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Secs, true);
    format!(
        r#"<samlp:AuthnRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
    xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
    ID="_{}"
    Version="2.0"
    IssueInstant="{}"
    AssertionConsumerServiceURL="{}"
    ProtocolBinding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST">
    <saml:Issuer>{}</saml:Issuer>
    <samlp:NameIDPolicy Format="urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress" AllowCreate="true"/>
</samlp:AuthnRequest>"#,
        request_id,
        issue_instant,
        xml_escape(acs_url),
        xml_escape(entity_id)
    )
}
async fn parse_saml_response(
response_xml: &[u8],
config: &SSOConfiguration,
_org: &Organization,
) -> Result<SAMLUserInfo, ApiError> {
let xml_str = std::str::from_utf8(response_xml).map_err(|e| {
ApiError::Internal(anyhow::anyhow!("Invalid UTF-8 in SAML response: {}", e))
})?;
let name_id = extract_xml_value(xml_str, "NameID")
.or_else(|| extract_xml_value(xml_str, "saml:NameID"))
.or_else(|| extract_xml_value(xml_str, "saml2:NameID"));
let email = name_id
.clone()
.filter(|v| v.contains('@'))
.or_else(|| extract_xml_value(xml_str, "AttributeValue").filter(|v| v.contains('@')));
let session_index = extract_xml_value(xml_str, "SessionIndex")
.or_else(|| extract_xml_value(xml_str, "samlp:SessionIndex"));
let assertion_id = extract_xml_value(xml_str, "Assertion")
.and_then(|a| {
regex::Regex::new(r#"ID="([^"]+)""#)
.ok()?
.captures(&a)
.and_then(|cap| cap.get(1))
.map(|m| m.as_str().to_string())
})
.or_else(|| {
regex::Regex::new(r#"<[^:]*:?Assertion[^>]*ID="([^"]+)""#)
.ok()?
.captures(xml_str)
.and_then(|cap| cap.get(1))
.map(|m| m.as_str().to_string())
});
let not_before = extract_xml_value(xml_str, "NotBefore")
.or_else(|| extract_xml_value(xml_str, "saml:NotBefore"))
.or_else(|| extract_xml_value(xml_str, "saml2:NotBefore"))
.and_then(|s| chrono::DateTime::parse_from_rfc3339(&s).ok())
.map(|dt| dt.with_timezone(&chrono::Utc));
let not_on_or_after = extract_xml_value(xml_str, "NotOnOrAfter")
.or_else(|| extract_xml_value(xml_str, "saml:NotOnOrAfter"))
.or_else(|| extract_xml_value(xml_str, "saml2:NotOnOrAfter"))
.and_then(|s| chrono::DateTime::parse_from_rfc3339(&s).ok())
.map(|dt| dt.with_timezone(&chrono::Utc));
let issued_at = extract_xml_value(xml_str, "IssueInstant")
.or_else(|| extract_xml_value(xml_str, "saml:IssueInstant"))
.or_else(|| extract_xml_value(xml_str, "saml2:IssueInstant"))
.and_then(|s| chrono::DateTime::parse_from_rfc3339(&s).ok())
.map(|dt| dt.with_timezone(&chrono::Utc));
let mut attributes = serde_json::json!({});
if let Some(mapping) = config.attribute_mapping.as_object() {
for (target_key, source_key) in mapping {
if let Some(source_key_str) = source_key.as_str() {
if let Some(source_value) = extract_xml_value(xml_str, source_key_str) {
attributes[target_key] = serde_json::Value::String(source_value);
}
}
}
}
let first_name =
extract_xml_value(xml_str, "FirstName").or_else(|| extract_xml_value(xml_str, "givenName"));
let last_name =
extract_xml_value(xml_str, "LastName").or_else(|| extract_xml_value(xml_str, "surname"));
let username = extract_xml_value(xml_str, "Username")
.or_else(|| email.as_ref().map(|e| e.split('@').next().unwrap_or("user").to_string()));
Ok(SAMLUserInfo {
assertion_id,
name_id,
email,
username,
first_name,
last_name,
session_index,
attributes,
not_before,
not_on_or_after,
issued_at,
})
}
fn verify_saml_signature(xml: &[u8], config: &SSOConfiguration) -> Result<(), ApiError> {
tracing::debug!("Verifying SAML signature for org_id={}", config.org_id);
let cert_pem = config.saml_x509_cert.as_ref().ok_or_else(|| {
tracing::error!("X.509 certificate not configured for org_id={}", config.org_id);
ApiError::InvalidRequest("SAML X.509 certificate not configured".to_string())
})?;
let cert_pem_bytes = cert_pem.as_bytes().to_vec();
let mut reader = std::io::Cursor::new(&cert_pem_bytes);
let certs: Vec<Vec<u8>> = rustls_pemfile::certs(&mut reader)
.map(|result| result.map(|cert| cert.to_vec()))
.collect::<Result<Vec<_>, _>>()
.map_err(|e| {
tracing::error!("Failed to parse PEM certificate: {}", e);
ApiError::Internal(anyhow::anyhow!("Invalid PEM certificate format"))
})?;
if certs.is_empty() {
return Err(ApiError::InvalidRequest("No certificate found in PEM data".to_string()));
}
let first_cert = certs[0].clone();
let (_, cert) = X509Certificate::from_der(&first_cert).map_err(|e| {
tracing::error!("Failed to parse X.509 certificate DER: {:?}", e);
ApiError::Internal(anyhow::anyhow!("Invalid X.509 certificate format"))
})?;
cert.validity().time_to_expiration().ok_or_else(|| {
tracing::warn!("SAML certificate expired or invalid for org_id={}", config.org_id);
ApiError::InvalidRequest("SAML certificate has expired or is invalid".to_string())
})?;
let xml_str = std::str::from_utf8(xml).map_err(|e| {
ApiError::Internal(anyhow::anyhow!("Invalid UTF-8 in SAML response: {}", e))
})?;
let has_response_signature = xml_str.contains("<ds:Signature")
|| xml_str.contains("<Signature")
|| xml_str.contains("xmlns:ds=\"http://www.w3.org/2000/09/xmldsig#\"");
if !has_response_signature && config.require_signed_responses {
tracing::error!("SAML response missing signature for org_id={}", config.org_id);
return Err(ApiError::InvalidRequest(
"SAML response is not signed but signature is required".to_string(),
));
}
let public_key = cert.public_key();
if has_response_signature {
verify_xml_signature(xml_str, &first_cert, public_key).map_err(|e| {
tracing::error!(
"SAML response signature verification failed for org_id={}: {}",
config.org_id,
e
);
ApiError::InvalidRequest(format!("SAML response signature verification failed: {}", e))
})?;
}
if config.require_signed_assertions {
let has_assertion_signature = xml_str.contains("<Assertion")
&& (xml_str.contains("<ds:Signature") || xml_str.contains("<Signature"));
if !has_assertion_signature {
tracing::error!("SAML assertion missing signature for org_id={}", config.org_id);
return Err(ApiError::InvalidRequest(
"SAML assertion is not signed but signature is required".to_string(),
));
}
verify_xml_signature(xml_str, &first_cert, public_key).map_err(|e| {
tracing::error!(
"SAML assertion signature verification failed for org_id={}: {}",
config.org_id,
e
);
ApiError::InvalidRequest(format!("SAML assertion signature verification failed: {}", e))
})?;
}
tracing::info!("SAML signature validation passed for org_id={}", config.org_id);
Ok(())
}
/// Checks the RSA signature in the first `SignatureValue` of `xml` over the
/// first `SignedInfo` found in the same text.
///
/// * The key comes from `public_key`. ring's RSA verifiers expect the PKCS#1
///   `RSAPublicKey` DER, which is the payload of the SPKI `subjectPublicKey`
///   bit string, not the whole certificate. `_cert_der` is kept only for call
///   compatibility.
/// * ring hashes the message itself, so the SignedInfo bytes are passed
///   unhashed. Pre-hashing them would make every signature fail.
/// * Only RSA-SHA256 and RSA-SHA512 `SignatureMethod`s are accepted. Any other
///   or missing method is rejected instead of silently being treated as SHA-256.
///
/// NOTE(review): this is not a complete XML-DSig validation. The bytes checked
/// are the raw text between the `SignedInfo` tags: no canonicalization
/// (exc-c14n) is applied and the element's own tags are excluded. The
/// `Reference`/`DigestValue` entries are never compared against the signed
/// content either. Production use needs a real XML-DSig implementation.
///
/// # Errors
/// A human-readable message when an element is missing, the signature is not
/// valid base64, the algorithm is unsupported, or verification fails.
fn verify_xml_signature(
    xml: &str,
    _cert_der: &[u8],
    public_key: &SubjectPublicKeyInfo<'_>,
) -> Result<(), String> {
    let signature_value = extract_signature_value(xml)
        .ok_or_else(|| "Signature value not found in XML".to_string())?;
    let signed_info =
        extract_signed_info(xml).ok_or_else(|| "SignedInfo not found in XML".to_string())?;

    // Base64 inside XML is frequently line-wrapped; the STANDARD engine rejects whitespace.
    let compact_signature: String =
        signature_value.chars().filter(|c| !c.is_whitespace()).collect();
    let signature_bytes = general_purpose::STANDARD
        .decode(compact_signature.as_bytes())
        .map_err(|e| format!("Failed to decode signature: {}", e))?;

    let algorithm = extract_signature_algorithm(xml)
        .ok_or_else(|| "SignatureMethod algorithm not found in XML".to_string())?;
    let verification_alg: &dyn VerificationAlgorithm = match algorithm.as_str() {
        "rsa-sha256"
        | "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256"
        | "http://www.w3.org/2000/09/xmldsig#rsa-sha256" => &RSA_PKCS1_2048_8192_SHA256,
        "rsa-sha512"
        | "http://www.w3.org/2001/04/xmldsig-more#rsa-sha512"
        | "http://www.w3.org/2000/09/xmldsig#rsa-sha512" => &RSA_PKCS1_2048_8192_SHA512,
        other => return Err(format!("Unsupported signature algorithm: {}", other)),
    };

    let key_bytes: &[u8] = &public_key.subject_public_key.data;
    UnparsedPublicKey::new(verification_alg, key_bytes)
        .verify(signed_info.as_bytes(), &signature_bytes)
        .map_err(|e| format!("Signature verification failed: {:?}", e))
}
/// Returns the base64 text of the first `SignatureValue` element (`ds:`-prefixed
/// form preferred, then unprefixed) with all whitespace removed.
///
/// `(?s)` lets the content span lines. IdPs routinely wrap base64 signatures,
/// and without it the lookup fails. Whitespace is stripped because it is never
/// part of the base64 data.
fn extract_signature_value(xml: &str) -> Option<String> {
    const PATTERNS: [&str; 2] = [
        r#"(?s)<ds:SignatureValue[^>]*>(.*?)</ds:SignatureValue>"#,
        r#"(?s)<SignatureValue[^>]*>(.*?)</SignatureValue>"#,
    ];
    PATTERNS.iter().find_map(|pattern| {
        let re = regex::Regex::new(pattern).ok()?;
        let raw = re.captures(xml)?.get(1)?.as_str();
        Some(raw.chars().filter(|c| !c.is_whitespace()).collect())
    })
}
/// Returns the raw inner text of the first `SignedInfo` element (`ds:`-prefixed
/// form preferred, then unprefixed).
///
/// `(?s)` is required because `SignedInfo` always spans several lines in real
/// signatures. Without it this lookup never matches. The text is returned
/// verbatim (no canonicalization).
fn extract_signed_info(xml: &str) -> Option<String> {
    const PATTERNS: [&str; 2] = [
        r#"(?s)<ds:SignedInfo[^>]*>(.*?)</ds:SignedInfo>"#,
        r#"(?s)<SignedInfo[^>]*>(.*?)</SignedInfo>"#,
    ];
    PATTERNS.iter().find_map(|pattern| {
        let re = regex::Regex::new(pattern).ok()?;
        let inner = re.captures(xml)?.get(1)?;
        Some(inner.as_str().to_string())
    })
}
/// Returns the `Algorithm` URI of the first `SignatureMethod` element, trying the
/// `ds:`-prefixed form before the unprefixed one. `None` if neither is present.
fn extract_signature_algorithm(xml: &str) -> Option<String> {
    [
        r#"<ds:SignatureMethod[^>]*Algorithm="([^"]+)""#,
        r#"<SignatureMethod[^>]*Algorithm="([^"]+)""#,
    ]
    .iter()
    .filter_map(|pattern| regex::Regex::new(pattern).ok())
    .find_map(|re| {
        re.captures(xml)
            .and_then(|cap| cap.get(1))
            .map(|algorithm| algorithm.as_str().to_string())
    })
}
/// Returns the raw content of the first `<tag>…</tag>` element. The unprefixed
/// `tag` is tried first, then the `saml:`, `saml2:`, `samlp:` and `ds:` prefixes.
///
/// * `tag` is regex-escaped. Callers pass admin-configured attribute-mapping
///   keys, which must not be interpreted as regex syntax.
/// * The start tag must end or continue with whitespace right after the name,
///   so `NameID` no longer matches `<NameIDPolicy …>` and `Assertion` no longer
///   matches `<AssertionConsumerService …>`.
/// * `(?s)` allows element content that spans multiple lines.
fn extract_xml_value(xml: &str, tag: &str) -> Option<String> {
    let capture = |qualified: &str| -> Option<String> {
        let pattern = format!(
            r#"(?s)<{0}(?:\s[^>]*)?>(.*?)</{0}\s*>"#,
            regex::escape(qualified)
        );
        let re = regex::Regex::new(&pattern).ok()?;
        let caps = re.captures(xml)?;
        caps.get(1).map(|m| m.as_str().to_string())
    };

    capture(tag).or_else(|| {
        ["saml:", "saml2:", "samlp:", "ds:"]
            .iter()
            .find_map(|prefix| capture(&format!("{}{}", prefix, tag)))
    })
}
/// Extracts the `SessionIndex` element from a decoded SAML `LogoutRequest`.
///
/// Invalid UTF-8 is replaced lossily rather than rejected, so this never returns
/// `Err`. `Ok(None)` means no session index was found.
fn parse_saml_logout_request(request_xml: &[u8]) -> Result<Option<String>, ApiError> {
    let text = String::from_utf8_lossy(request_xml);
    let found = ["SessionIndex", "samlp:SessionIndex"]
        .iter()
        .find_map(|tag| extract_xml_value(&text, tag));
    Ok(found)
}
/// Builds an unsigned SAML 2.0 `LogoutResponse` reporting success, addressed to
/// `slo_url`.
///
/// * The success status goes in the schema-required
///   `<samlp:Status><samlp:StatusCode Value="…"/></samlp:Status>` child. A
///   `StatusCode` attribute on `LogoutResponse` is invalid and makes IdPs reject
///   the response.
/// * `IssueInstant` is UTC with whole seconds and a `Z` suffix.
/// * `Destination` is XML-escaped.
///
/// NOTE(review): `InResponseTo` (the request's ID) and `Issuer` are still
/// missing, because this function receives neither value. Strict IdPs may
/// require both.
fn generate_saml_logout_response(slo_url: &str) -> String {
    /// Escapes the five XML special characters for use in attribute values.
    fn xml_escape(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '&' => out.push_str("&amp;"),
                '<' => out.push_str("&lt;"),
                '>' => out.push_str("&gt;"),
                '"' => out.push_str("&quot;"),
                '\'' => out.push_str("&apos;"),
                other => out.push(other),
            }
        }
        out
    }

    let response_id = uuid::Uuid::new_v4();
    let issue_instant = chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Secs, true);
    format!(
        r#"<samlp:LogoutResponse xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
    xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
    ID="_{}"
    Version="2.0"
    IssueInstant="{}"
    Destination="{}">
    <samlp:Status>
        <samlp:StatusCode Value="urn:oasis:names:tc:SAML:2.0:status:Success"/>
    </samlp:Status>
</samlp:LogoutResponse>
"#,
        response_id,
        issue_instant,
        xml_escape(slo_url)
    )
}
async fn find_or_create_user_from_saml(
state: &AppState,
user_info: &SAMLUserInfo,
org: &Organization,
) -> Result<User, ApiError> {
let user = if let Some(email) = &user_info.email {
state.store.find_user_by_email(email).await?
} else {
None
};
let user = if let Some(user) = user {
use crate::models::organization::OrgRole;
if state.store.find_org_member(org.id, user.id).await?.is_none() {
state.store.create_org_member(org.id, user.id, OrgRole::Member).await?;
}
user
} else {
let email = user_info.email.as_ref().ok_or_else(|| {
ApiError::InvalidRequest("Email not found in SAML assertion".to_string())
})?;
let username = user_info.username.as_ref().cloned().unwrap_or_else(|| {
email.split('@').next().unwrap_or("user").to_string()
});
let password_hash = crate::auth::hash_password(&uuid::Uuid::new_v4().to_string())
.map_err(ApiError::Internal)?;
let user = state.store.create_user(&username, email, &password_hash).await?;
state.store.mark_user_verified(user.id).await?;
use crate::models::organization::OrgRole;
state.store.create_org_member(org.id, user.id, OrgRole::Member).await?;
user
};
Ok(user)
}
/// Rejects assertions outside their validity window, allowing 5 minutes of
/// clock skew.
///
/// * `not_before` (if present) must not lie more than 5 minutes in the future.
/// * `not_on_or_after` (if present) must not have passed by more than 5 minutes.
/// * Without `not_on_or_after`, an `issued_at` older than 5 minutes is rejected.
/// * If no timestamps are present at all, the assertion is accepted.
///
/// # Errors
/// `InvalidRequest` describing which bound was violated.
fn validate_saml_timestamps(user_info: &SAMLUserInfo) -> Result<(), ApiError> {
    let now = chrono::Utc::now();
    let window = chrono::Duration::minutes(5);

    match user_info.not_before {
        Some(not_before) if now < not_before - window => {
            tracing::warn!("SAML assertion not yet valid: not_before={}, now={}", not_before, now);
            return Err(ApiError::InvalidRequest(format!(
                "SAML assertion is not yet valid. Valid from: {}",
                not_before
            )));
        }
        _ => {}
    }

    match (user_info.not_on_or_after, user_info.issued_at) {
        (Some(not_on_or_after), _) => {
            if now > not_on_or_after + window {
                tracing::warn!(
                    "SAML assertion expired: not_on_or_after={}, now={}",
                    not_on_or_after,
                    now
                );
                return Err(ApiError::InvalidRequest(format!(
                    "SAML assertion has expired. Expired at: {}",
                    not_on_or_after
                )));
            }
        }
        (None, Some(issued_at)) => {
            // Without an explicit expiry, the assertion lives at most 5 minutes after issue.
            if now > issued_at + window {
                tracing::warn!(
                    "SAML assertion exceeded default validity: issued_at={}, now={}",
                    issued_at,
                    now
                );
                return Err(ApiError::InvalidRequest(
                    "SAML assertion has exceeded maximum validity period (5 minutes)".to_string(),
                ));
            }
        }
        (None, None) => {}
    }

    tracing::debug!("SAML timestamp validation passed");
    Ok(())
}