#[cfg(feature = "saml")]
use async_trait::async_trait;
#[cfg(feature = "saml")]
use base64::{engine::general_purpose, Engine as _};
#[cfg(feature = "saml")]
use chrono::{DateTime, Utc};
#[cfg(feature = "saml")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "saml")]
use std::{
collections::HashMap,
sync::Arc,
time::{Duration, SystemTime, UNIX_EPOCH},
};
#[cfg(feature = "saml")]
use tokio::sync::RwLock;
#[cfg(feature = "saml")]
use url::Url;
#[cfg(feature = "saml")]
use uuid::Uuid;
use crate::{
auth::{AuthResult, User},
error::{FusekiError, FusekiResult},
};
/// Complete SAML configuration for one service provider / identity provider pairing.
#[derive(Debug, Clone)]
pub struct SamlConfig {
/// Settings describing this application as a SAML service provider.
pub sp: ServiceProviderConfig,
/// Settings describing the remote identity provider.
pub idp: IdentityProviderConfig,
/// Which assertion attribute names feed which `User` fields.
pub attribute_mapping: AttributeMapping,
/// Session lifetime and AuthnRequest behaviour flags.
pub session: SessionConfig,
}
/// Alias kept for callers that refer to the SP settings by this name.
pub type SamlSpConfig = ServiceProviderConfig;
/// Alias kept for callers that refer to the attribute mapping by this name.
pub type SamlAttributeMappings = AttributeMapping;
/// Configuration of this application acting as a SAML service provider (SP).
///
/// `Debug` is implemented by hand so that `private_key` is never written to logs;
/// enclosing configs that derive `Debug` (e.g. `SamlConfig`) inherit the redaction.
#[derive(Clone)]
pub struct ServiceProviderConfig {
    /// SP entity ID; sent as `<saml:Issuer>` and published as the metadata `entityID`.
    pub entity_id: String,
    /// Assertion Consumer Service URL the IdP posts responses to.
    pub acs_url: Url,
    /// Optional Single Logout Service URL advertised in SP metadata.
    pub sls_url: Option<Url>,
    /// SP certificate; not read by the code shown here (presumably PEM — confirm).
    pub certificate: Option<String>,
    /// SP private key. Secret: redacted from `Debug` output.
    pub private_key: Option<String>,
}

impl std::fmt::Debug for ServiceProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ServiceProviderConfig")
            .field("entity_id", &self.entity_id)
            .field("acs_url", &self.acs_url)
            .field("sls_url", &self.sls_url)
            .field("certificate", &self.certificate)
            // Only reveal whether a key is configured, never its contents.
            .field(
                "private_key",
                &self.private_key.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
/// Configuration of the remote SAML identity provider (IdP).
#[derive(Debug, Clone)]
pub struct IdentityProviderConfig {
/// IdP entity ID; not read by the code shown here.
pub entity_id: String,
/// Single-sign-on endpoint that login requests are sent to.
pub sso_url: Url,
/// Optional single-logout endpoint; logout request generation fails without it.
pub slo_url: Option<Url>,
/// IdP signing certificate. NOTE(review): not read by the code shown here,
/// so response signatures are not verified against it — confirm before production use.
pub certificate: String,
/// Optional IdP metadata location; not read by the code shown here.
pub metadata_url: Option<Url>,
}
/// Maps SAML assertion attribute names (claim URIs) onto fields of the local user.
#[derive(Debug, Clone)]
pub struct AttributeMapping {
    /// Attribute whose first value replaces the subject NameID as username.
    pub username: String,
    /// Attribute whose first value becomes the user's email, if configured.
    pub email: Option<String>,
    /// Attribute whose first value becomes the user's full name, if configured.
    pub display_name: Option<String>,
    /// Attribute whose values are translated into roles, if configured.
    pub groups: Option<String>,
    /// Additional attribute-name mappings supplied by the deployment.
    pub custom: HashMap<String, String>,
}

impl Default for AttributeMapping {
    /// Defaults use the WS-Federation claim URIs commonly emitted by ADFS / Azure AD.
    fn default() -> Self {
        let claim = |uri: &str| Some(String::from(uri));
        let username =
            String::from("http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name");
        AttributeMapping {
            username,
            email: claim("http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress"),
            display_name: claim(
                "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/givenname",
            ),
            groups: claim("http://schemas.xmlsoap.org/claims/Group"),
            custom: HashMap::default(),
        }
    }
}
/// Session and login-request behaviour for the SAML provider.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// How long a session created from a SAML response stays valid.
    pub timeout: Duration,
    /// Flag for accepting IdP-initiated (unsolicited) logins; enforcement is up to the
    /// consumer of this config.
    pub allow_idp_initiated: bool,
    /// Value emitted as `ForceAuthn` on outgoing AuthnRequests.
    pub force_authn: bool,
    /// Flag for keeping the IdP `SessionIndex` alongside each session.
    pub track_session_index: bool,
}

impl Default for SessionConfig {
    /// One-hour sessions, SP-initiated login only, no forced re-authentication.
    fn default() -> Self {
        const ONE_HOUR: Duration = Duration::from_secs(60 * 60);
        SessionConfig {
            timeout: ONE_HOUR,
            allow_idp_initiated: false,
            force_authn: false,
            track_session_index: true,
        }
    }
}
/// SAML service-provider state: configuration plus in-memory session and request stores.
pub struct SamlProvider {
/// Static configuration for this SP/IdP pairing.
pub config: SamlConfig,
/// Live sessions keyed by a generated session ID (a UUID string); `validate_token`
/// looks tokens up in this map.
sessions: Arc<RwLock<HashMap<String, SamlSession>>>,
/// AuthnRequests issued by `generate_login_url`, keyed by request ID, so responses can
/// be correlated; entries are removed when a response references them or by
/// `cleanup_sessions` once they age out.
pending_requests: Arc<RwLock<HashMap<String, PendingRequest>>>,
}
/// One authenticated session established from a SAML response.
#[derive(Debug, Clone)]
struct SamlSession {
/// User built from the assertion.
user: User,
/// IdP `SessionIndex` from the first assertion's AuthnStatement, if any; used to find
/// the session again by index.
session_index: Option<String>,
/// Creation time; currently only stored.
created_at: SystemTime,
/// Instant after which `validate_token` reports the session as expired.
expires_at: SystemTime,
/// All assertion attributes by name; currently only stored.
attributes: HashMap<String, Vec<String>>,
}
/// Bookkeeping for an AuthnRequest that has been sent but not yet answered.
#[derive(Debug, Clone)]
struct PendingRequest {
/// The AuthnRequest `ID` (also the map key).
id: String,
/// RelayState sent along with the request, if any.
relay_state: Option<String>,
/// When the request was issued; used to age out stale entries.
timestamp: SystemTime,
}
/// A SAML 2.0 `<samlp:AuthnRequest>` sent to the IdP to start SP-initiated login.
#[derive(Debug, Serialize)]
pub struct AuthnRequest {
    /// Request ID. `new` prefixes the UUID with `_` so it is a valid XML `ID`
    /// (an NCName cannot start with a digit).
    pub id: String,
    /// Creation time of the request.
    pub issue_instant: DateTime<Utc>,
    /// IdP SSO endpoint the request is addressed to.
    pub destination: Url,
    /// SP entity ID emitted as `<saml:Issuer>`.
    pub issuer: String,
    /// Where the IdP should deliver its response.
    pub acs_url: Url,
    /// Binding the IdP should use for the response (HTTP-POST).
    pub protocol_binding: String,
    /// Whether the IdP must re-authenticate the user even with an existing IdP session.
    pub force_authn: bool,
}

impl AuthnRequest {
    /// Builds a fresh request from the SP/IdP configuration, stamped with the current time.
    pub fn new(config: &SamlConfig) -> Self {
        Self {
            id: format!("_{}", Uuid::new_v4()),
            issue_instant: Utc::now(),
            destination: config.idp.sso_url.clone(),
            issuer: config.sp.entity_id.clone(),
            acs_url: config.sp.acs_url.clone(),
            protocol_binding: "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST".to_string(),
            force_authn: config.session.force_authn,
        }
    }

    /// Renders the request as an XML document.
    ///
    /// Every interpolated value is XML-escaped. Without escaping, a URL such as
    /// `https://sp/acs?a=1&b=2` would produce malformed XML. `IssueInstant` is written in
    /// UTC with a `Z` suffix and second precision, as SAML Core §1.3.3 requires for
    /// `xs:dateTime` values.
    pub fn to_xml(&self) -> String {
        format!(
            r#"<samlp:AuthnRequest
xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
ID="{}"
Version="2.0"
IssueInstant="{}"
Destination="{}"
ProtocolBinding="{}"
AssertionConsumerServiceURL="{}"
ForceAuthn="{}">
<saml:Issuer>{}</saml:Issuer>
</samlp:AuthnRequest>"#,
            Self::escape_xml(&self.id),
            self.issue_instant.format("%Y-%m-%dT%H:%M:%SZ"),
            Self::escape_xml(self.destination.as_str()),
            Self::escape_xml(&self.protocol_binding),
            Self::escape_xml(self.acs_url.as_str()),
            self.force_authn,
            Self::escape_xml(&self.issuer)
        )
    }

    /// Escapes the five XML special characters so `raw` is safe in text and attributes.
    fn escape_xml(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '&' => out.push_str("&amp;"),
                '<' => out.push_str("&lt;"),
                '>' => out.push_str("&gt;"),
                '"' => out.push_str("&quot;"),
                '\'' => out.push_str("&apos;"),
                other => out.push(other),
            }
        }
        out
    }
}
/// Structured form of a SAML `<samlp:Response>`, as produced by the provider's parser.
#[derive(Debug, Deserialize)]
pub struct SamlResponse {
/// Top-level status; anything other than the SAML Success URI is treated as failure.
pub status: ResponseStatus,
/// Assertions carried by the response; the first one supplies the user identity.
pub assertions: Vec<Assertion>,
/// ID of the AuthnRequest this response answers; absent for IdP-initiated responses.
pub in_response_to: Option<String>,
}
/// `<samlp:Status>` of a response.
#[derive(Debug, Deserialize)]
pub struct ResponseStatus {
/// Status code URI (e.g. `urn:oasis:names:tc:SAML:2.0:status:Success`).
pub code: String,
/// Optional human-readable status message from the IdP.
pub message: Option<String>,
}
/// One `<saml:Assertion>`.
#[derive(Debug, Deserialize)]
pub struct Assertion {
/// Subject the assertion is about.
pub subject: Subject,
/// Attribute statements flattened into name/values pairs.
pub attributes: Vec<Attribute>,
/// Optional validity window.
pub conditions: Option<Conditions>,
/// Optional authentication statement (carries the IdP session index).
pub authn_statement: Option<AuthnStatement>,
}
/// `<saml:Subject>` reduced to its NameID.
#[derive(Debug, Deserialize)]
pub struct Subject {
/// NameID value; the default username unless an attribute mapping overrides it.
pub name_id: String,
/// NameID `Format` URI, if present.
pub format: Option<String>,
}
/// A single SAML attribute with all of its values.
#[derive(Debug, Deserialize)]
pub struct Attribute {
/// Attribute `Name` (usually a claim URI).
pub name: String,
/// All `<saml:AttributeValue>` texts in document order.
pub values: Vec<String>,
}
/// `<saml:Conditions>` validity bounds.
#[derive(Debug, Deserialize)]
pub struct Conditions {
/// Assertion is invalid before this instant.
pub not_before: Option<DateTime<Utc>>,
/// Assertion is invalid at or after this instant.
pub not_on_or_after: Option<DateTime<Utc>>,
}
/// `<saml:AuthnStatement>` details.
#[derive(Debug, Deserialize)]
pub struct AuthnStatement {
/// IdP session index, used to correlate single-logout.
pub session_index: Option<String>,
/// When the user authenticated at the IdP.
pub authn_instant: DateTime<Utc>,
}
impl SamlProvider {
pub fn new(config: SamlConfig) -> Self {
Self {
config,
sessions: Arc::new(RwLock::new(HashMap::new())),
pending_requests: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn generate_login_url(&self, relay_state: Option<String>) -> FusekiResult<Url> {
let request = AuthnRequest::new(&self.config);
let request_xml = request.to_xml();
let pending = PendingRequest {
id: request.id.clone(),
relay_state,
timestamp: SystemTime::now(),
};
let relay_state_clone = pending.relay_state.clone();
let mut pending_requests = self.pending_requests.write().await;
pending_requests.insert(request.id.clone(), pending);
let encoded = general_purpose::STANDARD.encode(request_xml.as_bytes());
let mut url = self.config.idp.sso_url.clone();
url.query_pairs_mut().append_pair("SAMLRequest", &encoded);
if let Some(relay) = &relay_state_clone {
url.query_pairs_mut().append_pair("RelayState", relay);
}
Ok(url)
}
pub async fn process_response(
&self,
saml_response: &str,
relay_state: Option<&str>,
) -> FusekiResult<User> {
let decoded = general_purpose::STANDARD
.decode(saml_response)
.map_err(|e| {
FusekiError::authentication(format!("Failed to decode SAML response: {}", e))
})?;
let response_xml = String::from_utf8(decoded).map_err(|e| {
FusekiError::authentication(format!("Invalid UTF-8 in SAML response: {}", e))
})?;
let response = self.parse_response(&response_xml)?;
self.validate_response(&response)?;
let user = self.extract_user_info(&response)?;
let session = SamlSession {
user: user.clone(),
session_index: response
.assertions
.first()
.and_then(|a| a.authn_statement.as_ref())
.and_then(|s| s.session_index.clone()),
created_at: SystemTime::now(),
expires_at: SystemTime::now() + self.config.session.timeout,
attributes: self.extract_attributes(&response),
};
let session_id = Uuid::new_v4().to_string();
let mut sessions = self.sessions.write().await;
sessions.insert(session_id, session);
if let Some(in_response_to) = &response.in_response_to {
let mut pending = self.pending_requests.write().await;
pending.remove(in_response_to);
}
Ok(user)
}
fn parse_response(&self, _xml: &str) -> FusekiResult<SamlResponse> {
Err(FusekiError::service_unavailable(
"SAML response parsing not yet implemented",
))
}
fn validate_response(&self, response: &SamlResponse) -> FusekiResult<()> {
if response.status.code != "urn:oasis:names:tc:SAML:2.0:status:Success" {
return Err(FusekiError::authentication(format!(
"SAML authentication failed: {}",
response
.status
.message
.as_ref()
.unwrap_or(&"Unknown error".to_string())
)));
}
if response.assertions.is_empty() {
return Err(FusekiError::authentication(
"No assertions in SAML response",
));
}
for assertion in &response.assertions {
if let Some(conditions) = &assertion.conditions {
let now = Utc::now();
if let Some(not_before) = &conditions.not_before {
if now < *not_before {
return Err(FusekiError::authentication("SAML assertion not yet valid"));
}
}
if let Some(not_after) = &conditions.not_on_or_after {
if now >= *not_after {
return Err(FusekiError::authentication("SAML assertion expired"));
}
}
}
}
Ok(())
}
fn extract_user_info(&self, response: &SamlResponse) -> FusekiResult<User> {
let assertion = response
.assertions
.first()
.ok_or_else(|| FusekiError::authentication("No assertion found"))?;
let mut user = User {
username: assertion.subject.name_id.clone(),
email: None,
full_name: None,
roles: vec!["user".to_string()],
last_login: Some(chrono::Utc::now()),
permissions: vec![],
};
for attr in &assertion.attributes {
if attr.name == self.config.attribute_mapping.username && !attr.values.is_empty() {
user.username = attr.values[0].clone();
}
if let Some(email_attr) = &self.config.attribute_mapping.email {
if attr.name == *email_attr && !attr.values.is_empty() {
user.email = Some(attr.values[0].clone());
}
}
if let Some(display_attr) = &self.config.attribute_mapping.display_name {
if attr.name == *display_attr && !attr.values.is_empty() {
user.full_name = Some(attr.values[0].clone());
}
}
if let Some(groups_attr) = &self.config.attribute_mapping.groups {
if attr.name == *groups_attr {
for group in &attr.values {
match group.as_str() {
"admin" | "administrators" => user.roles.push("admin".to_string()),
"editor" | "editors" => user.roles.push("editor".to_string()),
_ => {}
}
}
}
}
}
Ok(user)
}
fn extract_attributes(&self, response: &SamlResponse) -> HashMap<String, Vec<String>> {
let mut attributes = HashMap::new();
for assertion in &response.assertions {
for attr in &assertion.attributes {
attributes.insert(attr.name.clone(), attr.values.clone());
}
}
attributes
}
pub async fn generate_logout_url(&self, user: &User) -> FusekiResult<Option<Url>> {
if let Some(slo_url) = &self.config.idp.slo_url {
Ok(Some(slo_url.clone()))
} else {
Ok(None)
}
}
pub async fn cleanup_sessions(&self) {
let mut sessions = self.sessions.write().await;
let now = SystemTime::now();
sessions.retain(|_, session| session.expires_at > now);
let mut pending = self.pending_requests.write().await;
let timeout = Duration::from_secs(300);
pending.retain(|_, request| {
now.duration_since(request.timestamp)
.unwrap_or(Duration::MAX)
< timeout
});
}
pub async fn get_session_by_index(&self, session_index: &str) -> FusekiResult<Option<String>> {
let sessions = self.sessions.read().await;
for (session_id, session) in sessions.iter() {
if let Some(index) = &session.session_index {
if index == session_index {
return Ok(Some(session_id.clone()));
}
}
}
Ok(None)
}
pub async fn generate_logout_request(
&self,
session_index: &str,
name_id: &str,
) -> FusekiResult<String> {
let slo_url = self
.config
.idp
.slo_url
.as_ref()
.ok_or_else(|| FusekiError::configuration("SAML SLO not configured"))?;
let request_id = format!("_{}", Uuid::new_v4());
let logout_request = format!(
r#"<?xml version="1.0" encoding="UTF-8"?>
<samlp:LogoutRequest
xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
ID="{}"
Version="2.0"
IssueInstant="{}"
Destination="{}">
<saml:Issuer>{}</saml:Issuer>
<saml:NameID>{}</saml:NameID>
<samlp:SessionIndex>{}</samlp:SessionIndex>
</samlp:LogoutRequest>"#,
request_id,
chrono::Utc::now().to_rfc3339(),
slo_url,
self.config.sp.entity_id,
name_id,
session_index
);
let encoded = general_purpose::STANDARD.encode(logout_request.as_bytes());
let mut logout_url = slo_url.clone();
logout_url
.query_pairs_mut()
.append_pair("SAMLRequest", &encoded);
Ok(logout_url.to_string())
}
pub fn get_metadata(&self) -> String {
format!(
r#"<?xml version="1.0" encoding="UTF-8"?>
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata"
entityID="{}">
<md:SPSSODescriptor AuthnRequestsSigned="false" WantAssertionsSigned="true"
protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress</md:NameIDFormat>
<md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
Location="{}" index="1" isDefault="true"/>
{}
</md:SPSSODescriptor>
</md:EntityDescriptor>"#,
self.config.sp.entity_id,
self.config.sp.acs_url,
self.config.sp.sls_url.as_ref()
.map(|url| format!(r#"<md:SingleLogoutService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="{}"/>"#, url))
.unwrap_or_default()
)
}
}
impl SamlProvider {
    /// Password login is not part of SAML: credentials are entered at the IdP, so this
    /// always reports `AuthResult::Invalid`.
    pub async fn authenticate(&self, _username: &str, _password: &str) -> FusekiResult<AuthResult> {
        Ok(AuthResult::Invalid)
    }

    /// Looks `token` up in the session store.
    ///
    /// Returns `Authenticated` with the session's user while the session is live,
    /// `Expired` once `expires_at` has been reached, and `Invalid` for unknown tokens.
    pub async fn validate_token(&self, token: &str) -> FusekiResult<AuthResult> {
        let store = self.sessions.read().await;
        let outcome = match store.get(token) {
            None => AuthResult::Invalid,
            Some(session) if session.expires_at <= SystemTime::now() => AuthResult::Expired,
            Some(session) => AuthResult::Authenticated(session.user.clone()),
        };
        Ok(outcome)
    }

    /// SAML sessions cannot be refreshed; callers must re-run the SSO flow.
    ///
    /// # Errors
    /// Always returns a bad-request error.
    pub async fn refresh_token(&self, _token: &str) -> FusekiResult<String> {
        let reason = "SAML does not support token refresh";
        Err(FusekiError::bad_request(reason))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal SP/IdP configuration shared by the tests below.
    fn sample_config() -> SamlConfig {
        let sp = ServiceProviderConfig {
            entity_id: "http://sp.example.com".to_string(),
            acs_url: Url::parse("http://sp.example.com/acs").unwrap(),
            sls_url: None,
            certificate: None,
            private_key: None,
        };
        let idp = IdentityProviderConfig {
            entity_id: "http://idp.example.com".to_string(),
            sso_url: Url::parse("http://idp.example.com/sso").unwrap(),
            slo_url: None,
            certificate: "dummy-cert".to_string(),
            metadata_url: None,
        };
        SamlConfig {
            sp,
            idp,
            attribute_mapping: AttributeMapping::default(),
            session: SessionConfig::default(),
        }
    }

    #[test]
    fn test_authn_request_generation() {
        let config = sample_config();
        let xml = AuthnRequest::new(&config).to_xml();
        let acs = config.sp.acs_url.to_string();
        assert!(xml.contains(config.sp.entity_id.as_str()));
        assert!(xml.contains(acs.as_str()));
    }

    #[test]
    fn test_attribute_mapping() {
        let AttributeMapping { username, email, .. } = AttributeMapping::default();
        assert_eq!(
            username,
            "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name"
        );
        assert!(email.is_some());
    }
}