use crate::errors::{AuthError, Result};
use crate::security::secure_jwt::{SecureJwtClaims, SecureJwtConfig, SecureJwtValidator};
use crate::server::oidc::oidc_response_modes::ResponseMode;
use crate::server::oidc::oidc_session_management::SessionManager;
use chrono::{DateTime, Duration, Utc};
use jsonwebtoken::{DecodingKey, EncodingKey};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
/// Tunables for [`EnhancedCibaManager`]: accepted delivery modes, request and
/// token lifetimes, optional features, JWT key material and ping/push
/// notification delivery.
///
/// Start from [`EnhancedCibaConfig::builder`] (or `Default::default()`). The
/// hand-written `Debug` impl reports the keys only as present/absent.
#[derive(Clone)]
pub struct EnhancedCibaConfig {
    // ---- Request handling ------------------------------------------------
    /// Delivery modes a client may request; any other mode is rejected when a
    /// backchannel request is initiated.
    pub supported_modes: Vec<AuthenticationMode>,
    /// Lifetime of a backchannel authentication request (`auth_req_id`).
    pub default_auth_req_expiry: Duration,
    /// Polling interval, in seconds, returned to poll-mode clients.
    pub min_polling_interval: u64,
    /// Upper polling-interval bound, in seconds.
    /// NOTE(review): not read by the manager in this module.
    pub max_polling_interval: u64,
    /// Maximum `binding_message` length, compared against its UTF-8 byte length.
    pub max_binding_message_length: usize,
    /// Response modes advertised as supported.
    /// NOTE(review): not read by the manager in this module.
    pub supported_response_modes: Vec<ResponseMode>,

    // ---- Optional features -----------------------------------------------
    /// Attach a pending `ConsentInfo` to each new request.
    pub enable_consent: bool,
    /// Attach a `DeviceBinding` (challenge + fingerprint) to each new request.
    pub enable_device_binding: bool,
    /// NOTE(review): not read by the manager in this module.
    pub enable_advanced_context: bool,

    // ---- Token issuance ----------------------------------------------------
    /// Configuration handed to the `SecureJwtValidator` that checks `id_token_hint`s.
    pub jwt_config: SecureJwtConfig,
    /// Issuer string: the `iss` claim of issued tokens, the notification payload's
    /// `issuer`, and the base endpoint of the notification HTTP client.
    pub issuer: String,
    /// Key used to sign issued tokens; token issuance fails while this is `None`.
    pub encoding_key: Option<EncodingKey>,
    /// Key used to verify `id_token_hint`s.
    pub decoding_key: Option<DecodingKey>,
    /// Access-token lifetime in seconds (also reported as `expires_in`).
    pub access_token_lifetime: u64,
    /// ID-token lifetime in seconds.
    pub id_token_lifetime: u64,
    /// Refresh-token lifetime in seconds.
    pub refresh_token_lifetime: u64,

    // ---- Ping/push notifications -----------------------------------------
    /// Upper bound on notification delivery attempts.
    pub max_notification_retries: u32,
    /// Base back-off between delivery attempts, in seconds (doubled per attempt).
    pub notification_retry_backoff: u64,
    /// Per-attempt HTTP timeout, in seconds.
    pub notification_timeout: u64,
}
impl std::fmt::Debug for EnhancedCibaConfig {
    /// Formats the configuration for diagnostics.
    ///
    /// `encoding_key` / `decoding_key` are shown only as `true`/`false`
    /// (present or not), and `jwt_config` is not printed at all.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to the struct forces a
        // conscious decision about whether it may be printed here.
        let Self {
            supported_modes,
            default_auth_req_expiry,
            max_polling_interval,
            min_polling_interval,
            enable_consent,
            enable_device_binding,
            supported_response_modes,
            max_binding_message_length,
            enable_advanced_context,
            jwt_config: _,
            issuer,
            encoding_key,
            decoding_key,
            access_token_lifetime,
            id_token_lifetime,
            refresh_token_lifetime,
            max_notification_retries,
            notification_retry_backoff,
            notification_timeout,
        } = self;

        f.debug_struct("EnhancedCibaConfig")
            .field("supported_modes", supported_modes)
            .field("default_auth_req_expiry", default_auth_req_expiry)
            .field("max_polling_interval", max_polling_interval)
            .field("min_polling_interval", min_polling_interval)
            .field("enable_consent", enable_consent)
            .field("enable_device_binding", enable_device_binding)
            .field("supported_response_modes", supported_response_modes)
            .field("max_binding_message_length", max_binding_message_length)
            .field("enable_advanced_context", enable_advanced_context)
            .field("issuer", issuer)
            .field("encoding_key", &encoding_key.is_some())
            .field("decoding_key", &decoding_key.is_some())
            .field("access_token_lifetime", access_token_lifetime)
            .field("id_token_lifetime", id_token_lifetime)
            .field("refresh_token_lifetime", refresh_token_lifetime)
            .field("max_notification_retries", max_notification_retries)
            .field("notification_retry_backoff", notification_retry_backoff)
            .field("notification_timeout", notification_timeout)
            .finish()
    }
}
impl Default for EnhancedCibaConfig {
    /// Defaults:
    /// - all three delivery modes, with 10-minute requests and a 2 s polling
    ///   interval;
    /// - consent and device binding enabled;
    /// - 1 h access/ID tokens and 24 h refresh tokens;
    /// - 3 notification attempts with a 5 s base back-off and a 30 s timeout.
    ///
    /// No keys are set, so token issuance fails until `encoding_key` is
    /// configured.
    fn default() -> Self {
        // Add the `id` and `ciba` token types to the validator's allowed set.
        let mut jwt_config = SecureJwtConfig::default();
        for token_type in ["id", "ciba"] {
            jwt_config.allowed_token_types.insert(token_type.to_string());
        }

        Self {
            supported_modes: vec![
                AuthenticationMode::Poll,
                AuthenticationMode::Ping,
                AuthenticationMode::Push,
            ],
            default_auth_req_expiry: Duration::minutes(10),
            max_polling_interval: 60,
            min_polling_interval: 2,
            enable_consent: true,
            enable_device_binding: true,
            supported_response_modes: vec![
                ResponseMode::Query,
                ResponseMode::Fragment,
                ResponseMode::FormPost,
                ResponseMode::JwtQuery,
            ],
            max_binding_message_length: 1024,
            enable_advanced_context: true,
            jwt_config,
            issuer: String::from("auth-framework-ciba"),
            encoding_key: None,
            decoding_key: None,
            access_token_lifetime: 60 * 60,
            id_token_lifetime: 60 * 60,
            refresh_token_lifetime: 24 * 60 * 60,
            max_notification_retries: 3,
            notification_retry_backoff: 5,
            notification_timeout: 30,
        }
    }
}
impl EnhancedCibaConfig {
    /// Starts a fluent builder pre-populated with the `Default` configuration.
    pub fn builder() -> EnhancedCibaConfigBuilder {
        let inner = Self::default();
        EnhancedCibaConfigBuilder { inner }
    }
}
/// Fluent builder for [`EnhancedCibaConfig`].
///
/// Obtained from `EnhancedCibaConfig::builder()`, which seeds it with the
/// `Default` configuration; each setter consumes and returns the builder, and
/// `build()` yields the finished configuration.
pub struct EnhancedCibaConfigBuilder {
// The configuration being assembled; setters overwrite individual fields.
inner: EnhancedCibaConfig,
}
impl EnhancedCibaConfigBuilder {
pub fn supported_modes(mut self, modes: Vec<AuthenticationMode>) -> Self {
self.inner.supported_modes = modes;
self
}
pub fn default_auth_req_expiry(mut self, expiry: Duration) -> Self {
self.inner.default_auth_req_expiry = expiry;
self
}
pub fn max_polling_interval(mut self, secs: u64) -> Self {
self.inner.max_polling_interval = secs;
self
}
pub fn min_polling_interval(mut self, secs: u64) -> Self {
self.inner.min_polling_interval = secs;
self
}
pub fn enable_consent(mut self, enable: bool) -> Self {
self.inner.enable_consent = enable;
self
}
pub fn enable_device_binding(mut self, enable: bool) -> Self {
self.inner.enable_device_binding = enable;
self
}
pub fn max_binding_message_length(mut self, max_len: usize) -> Self {
self.inner.max_binding_message_length = max_len;
self
}
pub fn build(self) -> EnhancedCibaConfig {
self.inner
}
pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
self.inner.issuer = issuer.into();
self
}
pub fn encoding_key(mut self, key: EncodingKey) -> Self {
self.inner.encoding_key = Some(key);
self
}
pub fn decoding_key(mut self, key: DecodingKey) -> Self {
self.inner.decoding_key = Some(key);
self
}
pub fn token_lifetimes(
mut self,
access_token_secs: u64,
id_token_secs: u64,
refresh_token_secs: u64,
) -> Self {
self.inner.access_token_lifetime = access_token_secs;
self.inner.id_token_lifetime = id_token_secs;
self.inner.refresh_token_lifetime = refresh_token_secs;
self
}
}
/// How a client learns the outcome of a CIBA request.
///
/// `Poll` clients receive a polling interval at initiation. `Ping` and `Push`
/// clients must supply `client_notification_endpoint` and
/// `client_notification_token`, and are sent an HTTP notification when the
/// request completes. This module handles `Ping` and `Push` identically.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum AuthenticationMode {
/// Client polls with its `auth_req_id` until tokens are available.
Poll,
/// Client is notified at its notification endpoint on completion.
Ping,
/// Client is notified at its notification endpoint on completion.
Push,
}
/// Server-side state of one backchannel authentication request, stored by
/// `EnhancedCibaManager` under its `auth_req_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnhancedCibaAuthRequest {
/// Random UUID v4 string handed to the client.
pub auth_req_id: String,
/// Client that initiated the request (also the `aud` of issued tokens).
pub client_id: String,
/// Identifies the end user to authenticate.
pub user_hint: UserIdentifierHint,
/// Optional message shown to the user; length-limited at initiation.
pub binding_message: Option<String>,
/// Optional transaction/risk context; parts of it feed the token `auth_ctx_hash`.
pub auth_context: Option<AuthenticationContext>,
/// Requested scopes; `openid` causes an ID token to be issued.
pub scopes: Vec<String>,
/// Delivery mode chosen by the client.
pub mode: AuthenticationMode,
/// Callback URL; required for ping/push modes.
pub client_notification_endpoint: Option<String>,
/// Bearer token sent with the callback; required for ping/push modes.
pub client_notification_token: Option<String>,
/// Instant after which the request is treated as expired.
pub expires_at: DateTime<Utc>,
/// Instant the request was created.
pub created_at: DateTime<Utc>,
/// Lifecycle state.
pub status: CibaRequestStatus,
/// Session attached when the user successfully completes authentication.
pub session_id: Option<String>,
/// Present when device binding is enabled in the configuration.
pub device_binding: Option<DeviceBinding>,
/// Present when consent tracking is enabled in the configuration.
pub consent: Option<ConsentInfo>,
}
/// Response returned to the client when a backchannel request is accepted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnhancedCibaAuthResponse {
/// Identifier the client uses to poll for or correlate the result.
pub auth_req_id: String,
/// Polling interval in seconds; only set for poll mode.
pub interval: Option<u64>,
/// Seconds until the request expires.
pub expires_in: u64,
/// Extension data; left empty by this module.
pub additional_data: HashMap<String, serde_json::Value>,
}
/// How the client identifies the end user to authenticate.
///
/// The wrapped string is used as the token subject. For `IdTokenHint`, the
/// subject is the verified token's `sub` claim instead.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum UserIdentifierHint {
/// Opaque login identifier; must be non-empty.
LoginHint(String),
/// A previously issued ID token (three-segment JWT).
IdTokenHint(String),
/// User code; must be at least 4 bytes long.
UserCode(String),
/// Phone number; must be at least 10 bytes long (no further format check).
PhoneNumber(String),
/// Email address; must contain `@`.
Email(String),
}
/// Optional context describing what the user is being asked to authorize.
///
/// In this module, `transaction_amount`, `transaction_currency` and
/// `risk_score` feed the tokens' `auth_ctx_hash`. `device_info` feeds the
/// device fingerprint. The other fields are carried but not read here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthenticationContext {
// NOTE(review): currency units/precision are not defined here — confirm with callers.
pub transaction_amount: Option<f64>,
// Presumably an ISO 4217 code (e.g. "USD") — not validated here.
pub transaction_currency: Option<String>,
pub merchant_info: Option<String>,
// NOTE(review): range/scale not defined here (tests use 0.2) — confirm.
pub risk_score: Option<f64>,
pub location: Option<GeoLocation>,
pub device_info: Option<CibaDeviceInfo>,
pub custom_attributes: HashMap<String, serde_json::Value>,
}
/// Geographic location attached to an authentication context.
/// Not read anywhere in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeoLocation {
// Presumably WGS84 decimal degrees — TODO confirm.
pub latitude: f64,
// Presumably WGS84 decimal degrees — TODO confirm.
pub longitude: f64,
// Units not defined here (meters?) — TODO confirm.
pub accuracy: Option<f64>,
pub location_name: Option<String>,
}
/// Description of the device involved in a request.
///
/// Every field that is set is hashed into the request's device fingerprint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CibaDeviceInfo {
pub device_id: String,
pub device_type: String,
pub os: Option<String>,
pub browser: Option<String>,
pub ip_address: Option<String>,
}
/// Binding between a CIBA request and a device, created at initiation when
/// device binding is enabled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceBinding {
/// Random UUID v4 string.
pub binding_id: String,
/// Not populated by this module (created as `None`).
pub device_public_key: Option<String>,
/// This module always creates `Platform` bindings.
pub binding_method: DeviceBindingMethod,
pub created_at: DateTime<Utc>,
/// Set to the owning request's expiry.
pub expires_at: Option<DateTime<Utc>>,
/// `device_fp_<sha256 hex>` computed from the request at initiation.
pub device_fingerprint: Option<String>,
/// Random UUID v4 challenge.
pub challenge: Option<String>,
/// Not populated or verified by this module.
pub challenge_response: Option<String>,
}
/// Mechanism by which a device is bound to a request.
/// Only `Platform` is produced by this module; the others are for callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DeviceBindingMethod {
PublicKey,
Certificate,
Attestation,
Biometric,
Platform,
Implicit,
}
/// Consent record attached to a request when consent tracking is enabled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsentInfo {
/// Random UUID v4 string.
pub consent_id: String,
/// Starts `Pending`; set to `Granted` or `Denied` when the request is completed.
pub status: ConsentStatus,
/// Initialized to the request's scopes.
pub consented_scopes: Vec<String>,
/// Set to the owning request's expiry.
pub expires_at: Option<DateTime<Utc>>,
pub created_at: DateTime<Utc>,
}
/// State of a `ConsentInfo`. `Expired` is never set by this module.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ConsentStatus {
Pending,
Granted,
Denied,
Expired,
}
/// Lifecycle state of an `EnhancedCibaAuthRequest`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CibaRequestStatus {
/// Initial state after initiation; polls answer `authorization_pending`.
Pending,
/// Treated like `Pending` by polling; never set by this module.
InProgress,
/// User authenticated; polling can redeem tokens.
Completed,
/// User did not authenticate; polls answer `access_denied`.
Failed,
/// Set by polling once `expires_at` has passed.
Expired,
/// Cancelled via `cancel_auth_request`; polls answer `access_denied`.
Cancelled,
}
/// Input to `EnhancedCibaManager::initiate_backchannel_auth`: a client's
/// backchannel authentication request.
#[derive(Debug)]
pub struct BackchannelAuthParams<'a> {
pub client_id: &'a str,
pub user_hint: UserIdentifierHint,
/// Rejected if longer (in bytes) than `max_binding_message_length`.
pub binding_message: Option<String>,
pub auth_context: Option<AuthenticationContext>,
pub scopes: Vec<String>,
/// Must be one of the configured `supported_modes`.
pub mode: AuthenticationMode,
/// Required for `Ping` and `Push`.
pub client_notification_endpoint: Option<String>,
/// Required for `Ping` and `Push`; sent as a Bearer token on the callback.
pub client_notification_token: Option<String>,
}
/// Coordinates CIBA requests: initiation, completion by the user, polling for
/// tokens, and ping/push notifications. Requests are kept in memory.
#[derive(Debug)]
pub struct EnhancedCibaManager {
config: EnhancedCibaConfig,
// In-memory request store keyed by `auth_req_id`, behind a tokio RwLock.
auth_requests: Arc<RwLock<HashMap<String, EnhancedCibaAuthRequest>>>,
// Queried when validating and listing sessions.
session_manager: Arc<SessionManager>,
// HTTP client used for ping/push notifications.
notification_client: crate::server::core::common_http::HttpClient,
// Verifies `id_token_hint`s against `config.decoding_key`.
jwt_validator: Arc<SecureJwtValidator>,
}
impl EnhancedCibaManager {
pub fn new(config: EnhancedCibaConfig) -> Self {
use crate::server::core::common_config::EndpointConfig;
let jwt_validator = Arc::new(
SecureJwtValidator::new(config.jwt_config.clone())
.expect("CIBA JWT config validation failed — check key material"),
);
let endpoint_config = EndpointConfig::new(&config.issuer);
let notification_client = crate::server::core::common_http::HttpClient::new(
endpoint_config,
)
.unwrap_or_else(|_| {
let fallback_config = EndpointConfig::new("https://localhost");
crate::server::core::common_http::HttpClient::new(fallback_config)
.expect("localhost fallback endpoint config is valid")
});
Self {
config,
auth_requests: Arc::new(RwLock::new(HashMap::new())),
session_manager: Arc::new(SessionManager::new(Default::default())),
notification_client,
jwt_validator,
}
}
/// Creates a manager that shares an existing [`SessionManager`].
///
/// The notification HTTP client is built for `config.issuer`. If that fails,
/// it silently falls back to `https://localhost`.
///
/// # Panics
///
/// Panics if `config.jwt_config` is rejected by `SecureJwtValidator::new`, or
/// if even the `https://localhost` fallback client cannot be built.
pub fn new_with_session_manager(
    config: EnhancedCibaConfig,
    session_manager: Arc<SessionManager>,
) -> Self {
    use crate::server::core::common_config::EndpointConfig;
    use crate::server::core::common_http::HttpClient;

    let validator = SecureJwtValidator::new(config.jwt_config.clone())
        .expect("CIBA JWT config validation failed — check key material");

    let notification_client = match HttpClient::new(EndpointConfig::new(&config.issuer)) {
        Ok(client) => client,
        Err(_) => HttpClient::new(EndpointConfig::new("https://localhost"))
            .expect("localhost fallback endpoint config is valid"),
    };

    Self {
        auth_requests: Arc::new(RwLock::new(HashMap::new())),
        jwt_validator: Arc::new(validator),
        notification_client,
        session_manager,
        config,
    }
}
pub fn configure_keys(&mut self, encoding_key: EncodingKey, decoding_key: DecodingKey) {
self.config.encoding_key = Some(encoding_key);
self.config.decoding_key = Some(decoding_key);
}
/// Test-only constructor: default configuration plus a shared HMAC secret
/// for both signing and verification.
#[cfg(test)]
pub fn new_for_testing() -> Self {
    const TEST_SECRET: &[u8] = b"test-secret-key";
    Self::new(
        EnhancedCibaConfig::builder()
            .encoding_key(EncodingKey::from_secret(TEST_SECRET))
            .decoding_key(DecodingKey::from_secret(TEST_SECRET))
            .build(),
    )
}
/// Starts a CIBA backchannel authentication request.
///
/// Validates `params` against the configuration, stores a new
/// [`EnhancedCibaAuthRequest`] in the `Pending` state under a random UUID
/// `auth_req_id`, and returns that id with its remaining lifetime. For poll
/// mode only, the response also carries the polling interval.
///
/// When enabled in the configuration, a [`DeviceBinding`] (fresh challenge and
/// device fingerprint) and a pending [`ConsentInfo`] covering the requested
/// scopes are attached to the stored request.
///
/// # Errors
///
/// Returns a validation error when:
/// - the binding message is longer than `max_binding_message_length` bytes;
/// - `params.mode` is not in `supported_modes`;
/// - a ping/push request lacks `client_notification_endpoint` or
///   `client_notification_token`.
pub async fn initiate_backchannel_auth(
    &self,
    params: BackchannelAuthParams<'_>,
) -> Result<EnhancedCibaAuthResponse> {
    if let Some(ref message) = params.binding_message
        && message.len() > self.config.max_binding_message_length
    {
        return Err(AuthError::validation(format!(
            "Binding message too long: {} > {}",
            message.len(),
            self.config.max_binding_message_length
        )));
    }
    if !self.config.supported_modes.contains(&params.mode) {
        return Err(AuthError::validation(format!(
            "Unsupported authentication mode: {:?}",
            params.mode
        )));
    }
    // Ping and push both call the client back, so both need an endpoint and
    // the bearer token that authenticates the callback.
    let uses_callback = matches!(
        params.mode,
        AuthenticationMode::Ping | AuthenticationMode::Push
    );
    if uses_callback && params.client_notification_endpoint.is_none() {
        return Err(AuthError::validation(
            "Notification endpoint required for ping/push modes".to_string(),
        ));
    }
    if uses_callback && params.client_notification_token.is_none() {
        return Err(AuthError::validation(
            "client_notification_token is required for ping and push notification modes (CIBA spec §7.1)".to_string(),
        ));
    }

    let auth_req_id = Uuid::new_v4().to_string();
    let now = Utc::now();
    let expires_at = now + self.config.default_auth_req_expiry;

    let device_binding = if self.config.enable_device_binding {
        let challenge = Uuid::new_v4().to_string();
        let device_fingerprint = self.generate_device_fingerprint(&params)?;
        Some(DeviceBinding {
            binding_id: Uuid::new_v4().to_string(),
            device_public_key: None,
            binding_method: DeviceBindingMethod::Platform,
            created_at: now,
            expires_at: Some(expires_at),
            device_fingerprint: Some(device_fingerprint),
            challenge: Some(challenge),
            challenge_response: None,
        })
    } else {
        None
    };

    let consent = if self.config.enable_consent {
        Some(ConsentInfo {
            consent_id: Uuid::new_v4().to_string(),
            status: ConsentStatus::Pending,
            consented_scopes: params.scopes.clone(),
            expires_at: Some(expires_at),
            created_at: now,
        })
    } else {
        None
    };

    // Only poll-mode clients are told how often they may poll.
    let interval = matches!(params.mode, AuthenticationMode::Poll)
        .then_some(self.config.min_polling_interval);

    let auth_request = EnhancedCibaAuthRequest {
        auth_req_id: auth_req_id.clone(),
        client_id: params.client_id.to_string(),
        user_hint: params.user_hint,
        binding_message: params.binding_message,
        auth_context: params.auth_context,
        scopes: params.scopes,
        mode: params.mode,
        client_notification_endpoint: params.client_notification_endpoint,
        client_notification_token: params.client_notification_token,
        expires_at,
        created_at: now,
        status: CibaRequestStatus::Pending,
        session_id: None,
        device_binding,
        consent,
    };
    self.auth_requests
        .write()
        .await
        .insert(auth_req_id.clone(), auth_request);

    // A zero or negative configured expiry must not wrap around to a huge u64.
    let expires_in = u64::try_from((expires_at - now).num_seconds()).unwrap_or(0);
    Ok(EnhancedCibaAuthResponse {
        auth_req_id,
        interval,
        expires_in,
        additional_data: HashMap::new(),
    })
}
/// Handles a poll for the outcome of `auth_req_id`.
///
/// Once the request is `Completed` and its session validates, fresh tokens
/// are minted and returned. A successful redemption removes the request, so
/// each `auth_req_id` yields tokens at most once. Later polls get
/// "Authentication request not found".
///
/// # Errors
///
/// `AuthError::auth_method("ciba", code)`, where `code` is a CIBA/OAuth error
/// code:
/// - `authorization_pending`: the request is pending or in progress;
/// - `access_denied`: the request failed or was cancelled;
/// - `expired_token`: `expires_at` has passed. The request is also marked
///   `Expired`.
///
/// Also errors when the request is unknown, the session check fails, or
/// token generation fails.
pub async fn poll_auth_request(&self, auth_req_id: &str) -> Result<CibaTokenResponse> {
    let mut requests = self.auth_requests.write().await;
    let request = requests
        .get_mut(auth_req_id)
        .ok_or_else(|| AuthError::auth_method("ciba", "Authentication request not found"))?;

    if Utc::now() > request.expires_at {
        request.status = CibaRequestStatus::Expired;
        // Same code as the `Expired` arm below.
        return Err(AuthError::auth_method("ciba", "expired_token"));
    }

    let status = request.status.clone();
    match status {
        CibaRequestStatus::Pending | CibaRequestStatus::InProgress => {
            Err(AuthError::auth_method("ciba", "authorization_pending"))
        }
        CibaRequestStatus::Completed => {
            let session_valid = self
                .validate_session_for_request(request)
                .await
                .unwrap_or(false);
            if !session_valid {
                return Err(AuthError::auth_method("ciba", "Invalid or expired session"));
            }
            let tokens = self.generate_tokens_for_request(request).await?;
            // An auth_req_id is single-use: drop it once tokens are issued so a
            // replayed poll cannot mint another token set.
            requests.remove(auth_req_id);
            Ok(tokens)
        }
        CibaRequestStatus::Failed | CibaRequestStatus::Cancelled => {
            Err(AuthError::auth_method("ciba", "access_denied"))
        }
        CibaRequestStatus::Expired => Err(AuthError::auth_method("ciba", "expired_token")),
    }
}
pub async fn complete_auth_request(
&self,
auth_req_id: &str,
user_authenticated: bool,
session_id: Option<String>,
) -> Result<()> {
let mut requests = self.auth_requests.write().await;
let request = requests
.get_mut(auth_req_id)
.ok_or_else(|| AuthError::auth_method("ciba", "Authentication request not found"))?;
if user_authenticated {
request.status = CibaRequestStatus::Completed;
let mut session_metadata = std::collections::HashMap::new();
session_metadata.insert("auth_req_id".to_string(), auth_req_id.to_string());
session_metadata.insert("ciba_mode".to_string(), format!("{:?}", request.mode));
if let Some(ref auth_context) = request.auth_context {
if let Some(amount) = auth_context.transaction_amount {
session_metadata.insert("transaction_amount".to_string(), amount.to_string());
}
if let Some(ref currency) = auth_context.transaction_currency {
session_metadata.insert("transaction_currency".to_string(), currency.clone());
}
if let Some(risk_score) = auth_context.risk_score {
session_metadata.insert("risk_score".to_string(), risk_score.to_string());
}
}
let user_subject = match &request.user_hint {
UserIdentifierHint::LoginHint(hint) => {
if hint.is_empty() {
return Err(AuthError::InvalidRequest(
"Login hint cannot be empty".to_string(),
));
}
hint.clone()
}
UserIdentifierHint::Email(email) => {
if !email.contains('@') {
return Err(AuthError::InvalidRequest(
"Invalid email format in user hint".to_string(),
));
}
email.clone()
}
UserIdentifierHint::PhoneNumber(phone) => {
if phone.len() < 10 {
return Err(AuthError::InvalidRequest(
"Invalid phone number format".to_string(),
));
}
phone.clone()
}
UserIdentifierHint::UserCode(code) => {
if code.len() < 4 {
return Err(AuthError::InvalidRequest("User code too short".to_string()));
}
code.clone()
}
UserIdentifierHint::IdTokenHint(token) => {
if token.split('.').count() != 3 {
return Err(AuthError::token(
"Invalid JWT format in id_token_hint".to_string(),
));
}
match self.validate_id_token_hint(token) {
Ok(claims) => claims.sub,
Err(e) => {
return Err(AuthError::token(format!(
"id_token_hint validation failed: {}",
e
)));
}
}
}
};
session_metadata.insert("validated_subject".to_string(), user_subject.clone());
let _session_manager = &self.session_manager;
let new_session_id =
session_id.unwrap_or_else(|| format!("ciba_session_{}", Uuid::new_v4()));
let mut metadata = std::collections::HashMap::new();
metadata.insert("auth_req_id".to_string(), auth_req_id.to_string());
metadata.insert("ciba_mode".to_string(), format!("{:?}", request.mode));
metadata.insert(
"session_info".to_string(),
serde_json::to_string(&session_metadata).unwrap_or_default(),
);
metadata.insert("created_by".to_string(), "CIBA".to_string());
metadata.insert("ciba_enabled".to_string(), "true".to_string());
let final_session_id = new_session_id.clone();
request.session_id = Some(final_session_id.clone());
tracing::info!(
"CIBA session configured: {} for user: {} in mode: {:?}",
final_session_id,
user_subject,
request.mode
);
if let Some(ref mut consent) = request.consent {
consent.status = ConsentStatus::Granted;
}
if matches!(
request.mode,
AuthenticationMode::Ping | AuthenticationMode::Push
) && let Some(ref endpoint) = request.client_notification_endpoint
{
let token = request
.client_notification_token
.as_deref()
.ok_or_else(|| {
AuthError::internal(
"client_notification_token missing for ping/push notification; \
this should have been validated at request initiation",
)
})?;
self.send_notification(endpoint.as_str(), auth_req_id, token)
.await?;
}
} else {
request.status = CibaRequestStatus::Failed;
if let Some(ref mut consent) = request.consent {
consent.status = ConsentStatus::Denied;
}
}
Ok(())
}
/// Delivers a ping/push notification for `auth_req_id` to `endpoint`.
///
/// POSTs the JSON `{auth_req_id, timestamp, issuer}` with
/// `Authorization: Bearer <client_notification_token>`.
///
/// Retries use exponential back-off: `notification_retry_backoff * 2^attempt`
/// seconds before each retry. A 4xx response stops retrying immediately.
///
/// NOTE(review): the payload does not depend on the delivery mode. CIBA push
/// mode is specified to deliver the token response itself — confirm this is
/// intended.
///
/// # Errors
///
/// `AuthError::internal` describing the last failed attempt.
async fn send_notification(
    &self,
    endpoint: &str,
    auth_req_id: &str,
    client_notification_token: &str,
) -> Result<()> {
    let notification_data = serde_json::json!({
        "auth_req_id": auth_req_id,
        "timestamp": Utc::now(),
        "issuer": self.config.issuer,
    });

    // `max_notification_retries` bounds the number of attempts. Always make at
    // least one, so a zero setting cannot skip delivery yet report failure.
    let attempts = self.config.max_notification_retries.max(1);
    let mut last_error = None;
    for attempt in 0..attempts {
        if attempt > 0 {
            // Saturating arithmetic: large retry counts must not overflow/panic.
            let backoff_delay = self
                .config
                .notification_retry_backoff
                .saturating_mul(2_u64.saturating_pow(attempt));
            tokio::time::sleep(tokio::time::Duration::from_secs(backoff_delay)).await;
        }
        let request = self
            .notification_client
            .post(endpoint)
            .timeout(tokio::time::Duration::from_secs(
                self.config.notification_timeout,
            ))
            .header("Content-Type", "application/json")
            .header("User-Agent", "AuthFramework-CIBA/1.0")
            .header(
                "Authorization",
                format!("Bearer {}", client_notification_token),
            )
            .json(&notification_data);
        match request.send().await {
            Ok(response) => {
                let status = response.status();
                if status.is_success() {
                    tracing::info!(
                        "CIBA notification sent successfully to {} for request {}",
                        endpoint,
                        auth_req_id
                    );
                    return Ok(());
                }
                let error_text = response.text().await.unwrap_or_default();
                tracing::warn!(
                    "CIBA notification attempt {} to {} returned status {}",
                    attempt + 1,
                    endpoint,
                    status
                );
                last_error = Some(AuthError::internal(format!(
                    "Notification failed with status {}: {}",
                    status, error_text
                )));
                // Client errors (4xx) are treated as permanent: stop retrying.
                if status.is_client_error() {
                    break;
                }
            }
            Err(e) => {
                let error_msg = format!("Network error sending notification: {}", e);
                last_error = Some(AuthError::internal(error_msg));
                tracing::warn!(
                    "CIBA notification attempt {} failed for {}: {}",
                    attempt + 1,
                    endpoint,
                    e
                );
            }
        }
    }
    Err(last_error.unwrap_or_else(|| AuthError::internal("All notification attempts failed")))
}
async fn generate_tokens_for_request(
&self,
request: &EnhancedCibaAuthRequest,
) -> Result<CibaTokenResponse> {
let now = chrono::Utc::now();
let jti_access = Uuid::new_v4().to_string();
let jti_id = Uuid::new_v4().to_string();
let jti_refresh = Uuid::new_v4().to_string();
let subject = self.extract_subject_from_hint(&request.user_hint)?;
let access_claims = SecureJwtClaims {
sub: subject.clone(),
iss: self.config.issuer.clone(),
aud: request.client_id.clone(),
exp: (now.timestamp() + self.config.access_token_lifetime as i64),
nbf: now.timestamp(),
iat: now.timestamp(),
jti: jti_access.clone(),
scope: request.scopes.join(" "),
typ: "access".to_string(),
sid: request.session_id.clone(),
client_id: Some(request.client_id.clone()),
auth_ctx_hash: self.compute_auth_context_hash(&request.auth_context),
};
let access_token = if let Some(ref encoding_key) = self.config.encoding_key {
self.create_jwt_token(&access_claims, encoding_key)?
} else {
return Err(AuthError::internal(
"No encoding key configured for JWT generation",
));
};
let id_token = if request.scopes.contains(&"openid".to_string()) {
let id_claims = SecureJwtClaims {
sub: subject.clone(),
iss: self.config.issuer.clone(),
aud: request.client_id.clone(),
exp: (now.timestamp() + self.config.id_token_lifetime as i64),
nbf: now.timestamp(),
iat: now.timestamp(),
jti: jti_id.clone(),
scope: "openid".to_string(),
typ: "id".to_string(),
sid: request.session_id.clone(),
client_id: Some(request.client_id.clone()),
auth_ctx_hash: self.compute_auth_context_hash(&request.auth_context),
};
if let Some(ref encoding_key) = self.config.encoding_key {
Some(self.create_jwt_token(&id_claims, encoding_key)?)
} else {
None
}
} else {
None
};
let refresh_token = {
let refresh_claims = SecureJwtClaims {
sub: subject,
iss: self.config.issuer.clone(),
aud: request.client_id.clone(),
exp: (now.timestamp() + self.config.refresh_token_lifetime as i64),
nbf: now.timestamp(),
iat: now.timestamp(),
jti: jti_refresh.clone(),
scope: request.scopes.join(" "),
typ: "refresh".to_string(),
sid: request.session_id.clone(),
client_id: Some(request.client_id.clone()),
auth_ctx_hash: self.compute_auth_context_hash(&request.auth_context),
};
if let Some(ref encoding_key) = self.config.encoding_key {
Some(self.create_jwt_token(&refresh_claims, encoding_key)?)
} else {
None
}
};
Ok(CibaTokenResponse {
access_token,
token_type: "Bearer".to_string(),
refresh_token,
expires_in: self.config.access_token_lifetime,
id_token,
scope: Some(request.scopes.join(" ")),
})
}
/// Signs `claims` into a compact JWT using `encoding_key`.
///
/// The header is `Header::new(Algorithm::HS256)`.
/// NOTE(review): because the algorithm is fixed, `encoding_key` presumably
/// must be an HMAC secret (e.g. `EncodingKey::from_secret`) — confirm before
/// configuring RSA/EC keys.
///
/// # Errors
///
/// `AuthError::internal` wrapping the `jsonwebtoken` encoding error.
fn create_jwt_token(
    &self,
    claims: &SecureJwtClaims,
    encoding_key: &EncodingKey,
) -> Result<String> {
    let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::HS256);
    match jsonwebtoken::encode(&header, claims, encoding_key) {
        Ok(token) => Ok(token),
        Err(err) => Err(AuthError::internal(format!(
            "Failed to create JWT token: {}",
            err
        ))),
    }
}
/// Derives the token subject from a user hint.
///
/// Login hints, emails, phone numbers and user codes are used verbatim after
/// basic sanity checks. An `IdTokenHint` is verified and its `sub` returned.
///
/// # Errors
///
/// `AuthError::InvalidRequest` when:
/// - a login hint is empty;
/// - an email lacks `@` or is shorter than 3 bytes;
/// - a phone number is shorter than 10 bytes;
/// - a user code is shorter than 4 bytes.
///
/// `IdTokenHint` errors come from `extract_subject_from_id_token`.
fn extract_subject_from_hint(&self, hint: &UserIdentifierHint) -> Result<String> {
    // Pair each candidate subject with the reason it is rejected, if any.
    let (candidate, rejection) = match hint {
        UserIdentifierHint::IdTokenHint(token) => {
            return self.extract_subject_from_id_token(token);
        }
        UserIdentifierHint::LoginHint(login) => {
            (login, login.is_empty().then_some("Empty login hint"))
        }
        UserIdentifierHint::Email(email) => (
            email,
            (!email.contains('@') || email.len() < 3).then_some("Invalid email format"),
        ),
        UserIdentifierHint::PhoneNumber(phone) => {
            (phone, (phone.len() < 10).then_some("Invalid phone number"))
        }
        UserIdentifierHint::UserCode(code) => {
            (code, (code.len() < 4).then_some("User code too short"))
        }
    };
    match rejection {
        Some(reason) => Err(AuthError::InvalidRequest(reason.to_string())),
        None => Ok(candidate.clone()),
    }
}
/// Verifies `token` with the configured decoding key via the shared
/// `SecureJwtValidator` and returns its `sub` claim.
///
/// # Errors
///
/// - `AuthError::internal` when no decoding key is configured;
/// - `AuthError::token` when validation fails.
fn extract_subject_from_id_token(&self, token: &str) -> Result<String> {
    let Some(decoding_key) = self.config.decoding_key.as_ref() else {
        return Err(AuthError::internal(
            "No JWT decoding key configured; cannot validate id_token_hint",
        ));
    };
    self.jwt_validator
        .validate_token(token, decoding_key)
        .map(|claims| claims.sub)
        .map_err(|e| AuthError::token(format!("Invalid ID token hint: {}", e)))
}
fn compute_auth_context_hash(
&self,
auth_context: &Option<AuthenticationContext>,
) -> Option<String> {
use sha2::{Digest, Sha256};
auth_context.as_ref().map(|ctx| {
let mut hasher = Sha256::new();
if let Some(amount) = ctx.transaction_amount {
hasher.update(amount.to_bits().to_le_bytes());
}
if let Some(ref currency) = ctx.transaction_currency {
hasher.update(currency.as_bytes());
}
if let Some(risk) = ctx.risk_score {
hasher.update(risk.to_bits().to_le_bytes());
}
format!("ctx_{}", hex::encode(hasher.finalize()))
})
}
fn generate_device_fingerprint(&self, params: &BackchannelAuthParams) -> Result<String> {
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(params.client_id.as_bytes());
if let Some(ref auth_context) = params.auth_context
&& let Some(ref device_info) = auth_context.device_info
{
hasher.update(device_info.device_id.as_bytes());
hasher.update(device_info.device_type.as_bytes());
if let Some(ref os) = device_info.os {
hasher.update(os.as_bytes());
}
if let Some(ref browser) = device_info.browser {
hasher.update(browser.as_bytes());
}
if let Some(ref ip) = device_info.ip_address {
hasher.update(ip.as_bytes());
}
}
let hour_timestamp = chrono::Utc::now().timestamp() / 3600;
hasher.update(hour_timestamp.to_le_bytes());
Ok(format!("device_fp_{}", hex::encode(hasher.finalize())))
}
/// Returns a snapshot (clone) of the stored request for `auth_req_id`.
///
/// # Errors
///
/// `AuthError::auth_method("ciba", "Authentication request not found")` when
/// the id is unknown.
pub async fn get_auth_request(&self, auth_req_id: &str) -> Result<EnhancedCibaAuthRequest> {
    let snapshot = self.auth_requests.read().await.get(auth_req_id).cloned();
    snapshot.ok_or_else(|| AuthError::auth_method("ciba", "Authentication request not found"))
}
/// Decides whether a completed request's session allows token issuance.
///
/// - No `session_id`: `false`.
/// - Session known to the session manager: `true` only if it is valid and its
///   metadata is either empty or contains `ciba_enabled`.
/// - Session unknown: `true` iff the id contains the substring `"session"`.
///
/// Never returns `Err`.
async fn validate_session_for_request(
    &self,
    request: &EnhancedCibaAuthRequest,
) -> Result<bool> {
    let Some(session_id) = request.session_id.as_ref() else {
        tracing::debug!("CIBA request without session_id - allowing for user-initiated flows");
        return Ok(false);
    };

    let Some(session) = self.session_manager.get_session(session_id) else {
        // NOTE(review): this accepts ANY unknown id containing "session",
        // including every generated `ciba_session_<uuid>`. complete_auth_request
        // never registers sessions with the session manager, so in practice
        // this substring check is what lets completed requests redeem tokens.
        // It looks like a test shortcut — confirm before relying on it in
        // production. (Ids containing "custom_session" already contain
        // "session".)
        if session_id.contains("session") {
            tracing::debug!(
                "CIBA test session {} not found in session manager - allowing for test environment",
                session_id
            );
            return Ok(true);
        }
        tracing::warn!("CIBA session {} not found", session_id);
        return Ok(false);
    };

    if !self.session_manager.is_session_valid(session_id) {
        tracing::warn!("CIBA session {} has expired or is invalid", session_id);
        return Ok(false);
    }
    tracing::debug!(
        "CIBA session validation successful for session: {}",
        session_id
    );

    let supports_ciba =
        session.metadata.is_empty() || session.metadata.contains_key("ciba_enabled");
    if !supports_ciba {
        tracing::debug!("Session {} does not support CIBA", session_id);
    }
    Ok(supports_ciba)
}
/// Lists the ids of the sessions the session manager holds for `subject`.
pub async fn get_user_sessions(&self, subject: &str) -> Vec<String> {
    let sessions = self.session_manager.get_sessions_for_subject(subject);
    sessions
        .iter()
        .map(|entry| entry.session_id.clone())
        .collect()
}
/// Reports (via logging) the session associated with a CIBA request.
///
/// NOTE(review): despite the name, nothing is revoked. The session manager is
/// only queried and the outcome logged. The log text itself says the session
/// "will expire naturally".
///
/// Always returns `Ok(())`, including for unknown `auth_req_id`s.
pub async fn revoke_ciba_session(&self, auth_req_id: &str) -> Result<()> {
    let requests = self.auth_requests.read().await;
    let Some(request) = requests.get(auth_req_id) else {
        return Ok(());
    };
    match request.session_id.as_ref() {
        None => {
            tracing::debug!("No session associated with CIBA request {}", auth_req_id);
        }
        Some(session_id) if self.session_manager.get_session(session_id).is_some() => {
            tracing::info!(
                "Marking CIBA session {} for revocation (request: {}). Session will expire naturally or be cleaned up by session manager.",
                session_id,
                auth_req_id
            );
        }
        Some(session_id) => {
            tracing::debug!("CIBA session {} already expired or removed", session_id);
        }
    }
    Ok(())
}
/// Marks the request `Cancelled`, whatever its current state.
///
/// Unknown ids are ignored. Always returns `Ok(())`.
pub async fn cancel_auth_request(&self, auth_req_id: &str) -> Result<()> {
    if let Some(entry) = self.auth_requests.write().await.get_mut(auth_req_id) {
        entry.status = CibaRequestStatus::Cancelled;
    }
    Ok(())
}
/// Drops every request whose `expires_at` is not after the current time and
/// returns how many were removed.
pub async fn cleanup_expired_requests(&self) -> Result<usize> {
    let mut store = self.auth_requests.write().await;
    let cutoff = Utc::now();
    let before = store.len();
    store.retain(|_, entry| entry.expires_at > cutoff);
    let removed = before - store.len();
    Ok(removed)
}
/// Returns this manager's current configuration, including any keys installed
/// later through `configure_keys`.
pub fn config(&self) -> &EnhancedCibaConfig {
&self.config
}
/// Validates an `id_token_hint` and returns its decoded claims.
///
/// Steps:
/// 1. Verify the signature with the configured `decoding_key` through the
///    shared `SecureJwtValidator` — the same check
///    `extract_subject_from_id_token` applies at token issuance.
/// 2. Decode the payload into [`IdTokenHintClaims`].
/// 3. Require a non-empty `sub` and, when present, an `exp` that has not
///    passed.
///
/// # Errors
///
/// - `AuthError::token` for a malformed or unverifiable token, a missing
///   subject, or an expired token;
/// - `AuthError::internal` when no decoding key is configured.
fn validate_id_token_hint(&self, token: &str) -> Result<IdTokenHintClaims> {
    use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};

    let parts: Vec<&str> = token.split('.').collect();
    if parts.len() != 3 {
        return Err(AuthError::token("Invalid JWT structure".to_string()));
    }

    // Verify the signature before trusting any claim. A bare base64 decode of
    // the payload would accept forged hints.
    let decoding_key = self.config.decoding_key.as_ref().ok_or_else(|| {
        AuthError::internal("No JWT decoding key configured; cannot validate id_token_hint")
    })?;
    self.jwt_validator
        .validate_token(token, decoding_key)
        .map_err(|e| AuthError::token(format!("Invalid ID token hint: {}", e)))?;

    let payload = URL_SAFE_NO_PAD
        .decode(parts[1])
        .map_err(|_| AuthError::token("Invalid JWT payload encoding".to_string()))?;
    let payload_str = String::from_utf8(payload)
        .map_err(|_| AuthError::token("Invalid JWT payload UTF-8".to_string()))?;
    let claims: IdTokenHintClaims = serde_json::from_str(&payload_str)
        .map_err(|e| AuthError::token(format!("Invalid JWT claims: {}", e)))?;

    if claims.sub.is_empty() {
        return Err(AuthError::token("Missing subject in ID token".to_string()));
    }
    if let Some(exp) = claims.exp {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        if exp < now {
            return Err(AuthError::token("ID token has expired".to_string()));
        }
    }
    tracing::debug!(
        "Successfully validated ID token hint for subject: {}",
        claims.sub
    );
    Ok(claims)
}
}
/// Claims decoded from an `id_token_hint` payload.
///
/// Only `sub` is required. `exp`, when present, is checked against the current
/// time; the remaining claims are decoded but not inspected in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct IdTokenHintClaims {
/// Subject identifier; must be non-empty.
pub sub: String,
// Presumably seconds since the Unix epoch — not checked here.
pub iat: Option<u64>,
/// Expiry, compared against seconds since the Unix epoch.
pub exp: Option<u64>,
pub iss: Option<String>,
// Kept as raw JSON: `aud` may be a string or an array.
pub aud: Option<serde_json::Value>,
// Presumably seconds since the Unix epoch — not checked here.
pub nbf: Option<u64>,
}
/// Token set returned when a completed CIBA request is redeemed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CibaTokenResponse {
/// Signed JWT access token.
pub access_token: String,
/// Always `"Bearer"`.
pub token_type: String,
/// Signed JWT refresh token.
pub refresh_token: Option<String>,
/// Access-token lifetime in seconds.
pub expires_in: u64,
/// Present when the request's scopes include `openid`.
pub id_token: Option<String>,
/// Space-separated granted scopes.
pub scope: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Poll-mode request with no binding message, context or notification
    /// settings; individual tests override fields with struct-update syntax.
    fn poll_params(
        client_id: &'static str,
        user_hint: UserIdentifierHint,
        scopes: &[&str],
    ) -> BackchannelAuthParams<'static> {
        BackchannelAuthParams {
            client_id,
            user_hint,
            binding_message: None,
            auth_context: None,
            scopes: scopes.iter().map(|scope| scope.to_string()).collect(),
            mode: AuthenticationMode::Poll,
            client_notification_endpoint: None,
            client_notification_token: None,
        }
    }

    #[tokio::test]
    async fn test_ciba_request_initiation() {
        let manager = EnhancedCibaManager::new_for_testing();
        let params = BackchannelAuthParams {
            binding_message: Some("Please authenticate payment of $100".to_string()),
            ..poll_params(
                "test_client",
                UserIdentifierHint::LoginHint("user@example.com".to_string()),
                &["openid", "profile"],
            )
        };

        let response = manager.initiate_backchannel_auth(params).await.unwrap();

        assert!(!response.auth_req_id.is_empty());
        assert!(response.interval.is_some());
        assert!(response.expires_in > 0);
    }

    #[tokio::test]
    async fn test_ciba_polling_pending() {
        let manager = EnhancedCibaManager::new_for_testing();
        let params = poll_params(
            "test_client",
            UserIdentifierHint::Email("user@example.com".to_string()),
            &["openid"],
        );
        let response = manager.initiate_backchannel_auth(params).await.unwrap();

        let result = manager.poll_auth_request(&response.auth_req_id).await;

        assert!(result.is_err());
        if let Err(AuthError::AuthMethod {
            method, message, ..
        }) = result
        {
            assert_eq!(method, "ciba");
            assert_eq!(message, "authorization_pending");
        }
    }

    #[tokio::test]
    async fn test_ciba_completion_flow() {
        let manager = EnhancedCibaManager::new_for_testing();
        let params = poll_params(
            "test_client",
            UserIdentifierHint::UserCode("ABC123".to_string()),
            &["openid", "profile"],
        );
        let response = manager.initiate_backchannel_auth(params).await.unwrap();

        manager
            .complete_auth_request(&response.auth_req_id, true, Some("session123".to_string()))
            .await
            .unwrap();
        let token_response = manager
            .poll_auth_request(&response.auth_req_id)
            .await
            .unwrap();

        assert!(!token_response.access_token.is_empty());
        assert!(token_response.id_token.is_some());
        assert_eq!(token_response.token_type, "Bearer");
    }

    #[tokio::test]
    async fn test_binding_message_validation() {
        let config = EnhancedCibaConfig::builder()
            .max_binding_message_length(10)
            .encoding_key(jsonwebtoken::EncodingKey::from_secret(b"test-key"))
            .decoding_key(jsonwebtoken::DecodingKey::from_secret(b"test-key"))
            .build();
        let manager = EnhancedCibaManager::new(config);
        let params = BackchannelAuthParams {
            binding_message: Some("This message is too long".to_string()),
            ..poll_params(
                "test_client",
                UserIdentifierHint::LoginHint("user".to_string()),
                &["openid"],
            )
        };

        let result = manager.initiate_backchannel_auth(params).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_session_manager_integration() {
        let manager = EnhancedCibaManager::new_for_testing();
        let auth_context = AuthenticationContext {
            transaction_amount: Some(100.50),
            transaction_currency: Some("USD".to_string()),
            merchant_info: Some("Test Store".to_string()),
            risk_score: Some(0.2),
            location: None,
            device_info: None,
            custom_attributes: std::collections::HashMap::new(),
        };
        let params = BackchannelAuthParams {
            binding_message: Some("Authorize payment of $100.50".to_string()),
            auth_context: Some(auth_context),
            ..poll_params(
                "payment_client",
                UserIdentifierHint::Email("customer@example.com".to_string()),
                &["openid", "payment"],
            )
        };
        let response = manager.initiate_backchannel_auth(params).await.unwrap();
        let auth_req_id = &response.auth_req_id;

        manager
            .complete_auth_request(auth_req_id, true, Some("custom_session_123".to_string()))
            .await
            .unwrap();
        let auth_request = manager.get_auth_request(auth_req_id).await.unwrap();
        assert!(auth_request.session_id.is_some());
        assert_eq!(auth_request.status, CibaRequestStatus::Completed);

        let token_response = manager.poll_auth_request(auth_req_id).await.unwrap();
        assert!(!token_response.access_token.is_empty());
        assert!(token_response.access_token.contains("eyJ"));
        assert!(token_response.id_token.is_some());
        let id_token = token_response.id_token.unwrap();
        assert!(id_token.contains("eyJ"));

        let user_sessions = manager.get_user_sessions("customer@example.com").await;
        assert_eq!(user_sessions.len(), 0);
        let revoke_result = manager.revoke_ciba_session(auth_req_id).await;
        assert!(revoke_result.is_ok());
    }
}