use crate::errors::{AuthError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
/// How a client is told that a backchannel authentication request has completed.
///
/// Serialized in snake_case: `"poll"`, `"ping"`, `"push"`.
///
/// Within [`CibaProvider::authenticate`]:
/// - `Poll` and `Ping` responses carry an `interval` for token polling.
/// - `Ping` and `Push` requests must include a `client_notification_token`.
/// - `Push` responses carry no `interval`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CibaMode {
    /// The client polls the token endpoint until the request is decided.
    Poll,
    /// Notification-based mode that also receives a polling `interval`.
    Ping,
    /// Notification-based mode with no polling `interval`.
    Push,
}
/// Static configuration for a [`CibaProvider`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CibaConfig {
    /// Backchannel authentication endpoint URL.
    /// Not read by `CibaProvider` itself; presumably published elsewhere (e.g. discovery
    /// metadata) — confirm with callers.
    pub auth_endpoint: String,
    /// Token endpoint URL. Not read by `CibaProvider` itself.
    pub token_endpoint: String,
    /// Delivery modes accepted by `CibaProvider::authenticate`; any other mode is rejected.
    pub modes_supported: Vec<CibaMode>,
    /// Minimum number of seconds between two accepted token polls for the same request.
    /// It is also returned as `interval` to poll and ping clients.
    /// Defaults to 5 when missing from the serialized config.
    #[serde(default = "default_interval")]
    pub default_interval: u64,
    /// Default and maximum lifetime, in seconds, of an authentication request.
    /// A client's `requested_expiry` can shorten it but never extend it.
    /// Defaults to 300 when missing from the serialized config.
    #[serde(default = "default_expires_in")]
    pub expires_in: u64,
    /// Whether the `user_code` request parameter is supported. Defaults to `false`.
    /// NOTE(review): `CibaProvider` never consults this flag or verifies `user_code`.
    #[serde(default)]
    pub user_code_supported: bool,
}
/// Serde fallback for `CibaConfig::default_interval`.
///
/// Returns the minimum number of seconds a client must wait between token polls.
fn default_interval() -> u64 {
    const POLL_INTERVAL_SECS: u64 = 5;
    POLL_INTERVAL_SECS
}
/// Serde fallback for `CibaConfig::expires_in`.
///
/// Returns the default (and maximum) request lifetime in seconds: five minutes.
fn default_expires_in() -> u64 {
    const FIVE_MINUTES_SECS: u64 = 5 * 60;
    FIVE_MINUTES_SECS
}
/// The single hint a client supplies to identify the end-user to authenticate.
///
/// Uses serde's default (externally tagged) enum representation with snake_case tags,
/// e.g. `{"login_hint":"alice@example.com"}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LoginHint {
    /// Hint carried in a token; serialized under the `login_hint_token` tag.
    LoginHintToken(String),
    /// ID token used as the hint; serialized under the `id_token_hint` tag.
    IdTokenHint(String),
    /// Free-form identifier (the tests use an email address); serialized under `login_hint`.
    LoginHint(String),
}
/// A backchannel authentication request as submitted by a client.
///
/// `None` optional fields are omitted from serialized output.
/// `hint` is serialized as a nested object (there is no `#[serde(flatten)]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CibaAuthRequest {
    /// Requested scopes; `CibaProvider::authenticate` rejects an empty value.
    pub scope: String,
    /// Identifies the end-user to authenticate.
    pub hint: LoginHint,
    /// Message for the end-user; when present, `authenticate` rejects it if empty or
    /// longer than its 256 limit.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub binding_message: Option<String>,
    /// Carried with the request but not interpreted by `CibaProvider`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_code: Option<String>,
    /// Requested lifetime in seconds; capped at `CibaConfig::expires_in`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub requested_expiry: Option<u64>,
    /// Carried with the request but not interpreted by `CibaProvider`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub acr_values: Option<String>,
    /// Required for ping and push modes; handed back by `CibaProvider::get_notification`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_notification_token: Option<String>,
}
/// Successful response to a backchannel authentication request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CibaAuthResponse {
    /// Identifier of the stored request (a random UUID string); pass it to `poll_token`.
    pub auth_req_id: String,
    /// Seconds until the request expires.
    pub expires_in: u64,
    /// Minimum polling interval in seconds: `Some` for poll and ping modes, `None` for push.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub interval: Option<u64>,
}
/// Lifecycle state of a stored authentication request.
///
/// Every request starts as `Pending`. `approve` and `deny` act only on `Pending`
/// requests. A request is marked `Expired` when `approve` or `poll_token` observes
/// that its lifetime has passed.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CibaRequestStatus {
    /// Awaiting the end-user's decision.
    Pending,
    /// Approved; tokens have been generated and stored.
    Approved,
    /// Denied by the end-user.
    Denied,
    /// Lifetime elapsed before the request was completed.
    Expired,
}
/// Tokens issued for an approved request, as produced by the provider's token generator.
///
/// `None` optional tokens are omitted from serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CibaTokenResponse {
    pub access_token: String,
    pub token_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<String>,
    /// Access-token lifetime as chosen by the token generator (presumably seconds — confirm).
    pub expires_in: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id_token: Option<String>,
}
/// Token-endpoint error codes, serialized in snake_case (e.g. `"slow_down"`).
///
/// `CibaProvider::poll_token` returns `AuthorizationPending`, `SlowDown`,
/// `ExpiredToken`, `AccessDenied` and `InvalidRequest`. The remaining variants are
/// not produced by `CibaProvider`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CibaError {
    /// The end-user has not approved or denied the request yet.
    AuthorizationPending,
    /// Polled again before `default_interval` seconds passed since the last accepted poll.
    SlowDown,
    /// The request's lifetime has passed.
    ExpiredToken,
    /// The end-user denied the request.
    AccessDenied,
    /// Unknown `auth_req_id` (or an approved request with no stored token).
    InvalidRequest,
    UnauthorizedClient,
    InvalidScope,
    InvalidBindingMessage,
}
/// Server-side record of one stored authentication request, keyed by `auth_req_id`.
// `created_at` and `subject` are written but never read in this module; the allow
// silences the dead_code lint for them until something consumes them.
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct PendingAuth {
    /// The original request (its scope and notification token are read later).
    request: CibaAuthRequest,
    status: CibaRequestStatus,
    /// Unix time in seconds when the request was registered.
    created_at: u64,
    /// Unix time in seconds; the request counts as expired once the clock passes it.
    expires_at: u64,
    /// Unix time in seconds of the last poll that passed the rate limit.
    last_polled: Option<u64>,
    mode: CibaMode,
    /// Subject recorded by `approve`.
    subject: Option<String>,
    /// Tokens generated by `approve`.
    token_response: Option<CibaTokenResponse>,
}
/// In-memory CIBA request store: validates requests, tracks their approval status, and
/// hands out tokens through polling or notification lookups.
///
/// Entries are only removed by `cleanup_expired`.
pub struct CibaProvider {
    config: CibaConfig,
    /// Stored requests keyed by `auth_req_id`, behind a `tokio::sync::RwLock`.
    pending: Arc<RwLock<HashMap<String, PendingAuth>>>,
    /// Called on approval with `(auth_req_id, subject, scope)` to build the tokens.
    token_generator: Arc<dyn Fn(&str, &str, &str) -> CibaTokenResponse + Send + Sync>,
}
impl CibaProvider {
pub fn new(
config: CibaConfig,
token_generator: impl Fn(&str, &str, &str) -> CibaTokenResponse + Send + Sync + 'static,
) -> Self {
Self {
config,
pending: Arc::new(RwLock::new(HashMap::new())),
token_generator: Arc::new(token_generator),
}
}
fn now_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or(Duration::ZERO)
.as_secs()
}
fn generate_auth_req_id() -> String {
uuid::Uuid::new_v4().to_string()
}
pub async fn authenticate(
&self,
request: CibaAuthRequest,
mode: CibaMode,
) -> Result<CibaAuthResponse> {
if !self.config.modes_supported.contains(&mode) {
return Err(AuthError::validation(&format!(
"CIBA mode {:?} not supported",
mode
)));
}
if let Some(ref msg) = request.binding_message {
if msg.is_empty() || msg.len() > 256 {
return Err(AuthError::validation(
"Binding message must be 1-256 characters",
));
}
}
if matches!(mode, CibaMode::Ping | CibaMode::Push)
&& request.client_notification_token.is_none()
{
return Err(AuthError::validation(
"client_notification_token required for ping/push mode",
));
}
if request.scope.is_empty() {
return Err(AuthError::validation("scope is required"));
}
let now = Self::now_secs();
let expires_in = request
.requested_expiry
.unwrap_or(self.config.expires_in)
.min(self.config.expires_in);
let auth_req_id = Self::generate_auth_req_id();
let pending = PendingAuth {
request,
status: CibaRequestStatus::Pending,
created_at: now,
expires_at: now + expires_in,
last_polled: None,
mode,
subject: None,
token_response: None,
};
self.pending
.write()
.await
.insert(auth_req_id.clone(), pending);
Ok(CibaAuthResponse {
auth_req_id,
expires_in,
interval: if matches!(mode, CibaMode::Poll | CibaMode::Ping) {
Some(self.config.default_interval)
} else {
None
},
})
}
pub async fn approve(&self, auth_req_id: &str, subject: &str) -> Result<()> {
let mut pending = self.pending.write().await;
let entry = pending
.get_mut(auth_req_id)
.ok_or_else(|| AuthError::validation("Unknown auth_req_id"))?;
if entry.status != CibaRequestStatus::Pending {
return Err(AuthError::validation(&format!(
"Request already {:?}",
entry.status
)));
}
let now = Self::now_secs();
if now > entry.expires_at {
entry.status = CibaRequestStatus::Expired;
return Err(AuthError::validation("Request has expired"));
}
let token_response = (self.token_generator)(
auth_req_id,
subject,
&entry.request.scope,
);
entry.status = CibaRequestStatus::Approved;
entry.subject = Some(subject.to_string());
entry.token_response = Some(token_response);
Ok(())
}
pub async fn deny(&self, auth_req_id: &str) -> Result<()> {
let mut pending = self.pending.write().await;
let entry = pending
.get_mut(auth_req_id)
.ok_or_else(|| AuthError::validation("Unknown auth_req_id"))?;
if entry.status != CibaRequestStatus::Pending {
return Err(AuthError::validation(&format!(
"Request already {:?}",
entry.status
)));
}
entry.status = CibaRequestStatus::Denied;
Ok(())
}
pub async fn poll_token(
&self,
auth_req_id: &str,
) -> std::result::Result<CibaTokenResponse, CibaError> {
let mut pending = self.pending.write().await;
let entry = pending
.get_mut(auth_req_id)
.ok_or(CibaError::InvalidRequest)?;
let now = Self::now_secs();
if now > entry.expires_at {
entry.status = CibaRequestStatus::Expired;
return Err(CibaError::ExpiredToken);
}
if let Some(last) = entry.last_polled {
if now - last < self.config.default_interval {
return Err(CibaError::SlowDown);
}
}
entry.last_polled = Some(now);
match entry.status {
CibaRequestStatus::Pending => Err(CibaError::AuthorizationPending),
CibaRequestStatus::Denied => Err(CibaError::AccessDenied),
CibaRequestStatus::Expired => Err(CibaError::ExpiredToken),
CibaRequestStatus::Approved => entry
.token_response
.clone()
.ok_or(CibaError::InvalidRequest),
}
}
pub async fn get_notification(
&self,
auth_req_id: &str,
) -> Result<(CibaMode, Option<String>, Option<CibaTokenResponse>)> {
let pending = self.pending.read().await;
let entry = pending
.get(auth_req_id)
.ok_or_else(|| AuthError::validation("Unknown auth_req_id"))?;
let client_notification_token = entry.request.client_notification_token.clone();
let token_response = entry.token_response.clone();
Ok((entry.mode, client_notification_token, token_response))
}
pub async fn cleanup_expired(&self) {
let now = Self::now_secs();
self.pending.write().await.retain(|_, entry| {
now <= entry.expires_at
});
}
pub async fn get_status(&self, auth_req_id: &str) -> Option<CibaRequestStatus> {
let pending = self.pending.read().await;
pending.get(auth_req_id).map(|e| e.status.clone())
}
pub async fn pending_count(&self) -> usize {
self.pending.read().await.len()
}
}
#[cfg(test)]
mod tests {
    // Unit tests for the CIBA wire types and the `CibaProvider` request lifecycle.
    use super::*;

    // Config with every mode enabled, a 1 s polling interval and a 120 s lifetime.
    fn test_config() -> CibaConfig {
        CibaConfig {
            auth_endpoint: "https://op.example.com/ciba".to_string(),
            token_endpoint: "https://op.example.com/token".to_string(),
            modes_supported: vec![CibaMode::Poll, CibaMode::Ping, CibaMode::Push],
            default_interval: 1,
            expires_in: 120,
            user_code_supported: false,
        }
    }

    // Deterministic token generator that embeds the subject and scope in its tokens.
    fn test_token_gen() -> impl Fn(&str, &str, &str) -> CibaTokenResponse {
        |_req_id, subject, scope| CibaTokenResponse {
            access_token: format!("at_{subject}_{scope}"),
            token_type: "Bearer".to_string(),
            refresh_token: Some(format!("rt_{subject}")),
            expires_in: 3600,
            id_token: Some(format!("idt_{subject}")),
        }
    }

    // Minimal valid poll-mode request.
    fn poll_request() -> CibaAuthRequest {
        CibaAuthRequest {
            scope: "openid email".to_string(),
            hint: LoginHint::LoginHint("alice@example.com".to_string()),
            binding_message: Some("Confirm login on terminal 42".to_string()),
            user_code: None,
            requested_expiry: None,
            acr_values: None,
            client_notification_token: None,
        }
    }

    // Pushes a stored request past its expiry without sleeping.
    async fn force_expire(provider: &CibaProvider, auth_req_id: &str) {
        let mut pending = provider.pending.write().await;
        pending.get_mut(auth_req_id).unwrap().expires_at = 0;
    }

    #[test]
    fn test_ciba_mode_serde() {
        let json = serde_json::to_string(&CibaMode::Poll).unwrap();
        assert_eq!(json, "\"poll\"");
        let parsed: CibaMode = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, CibaMode::Poll);
    }

    #[test]
    fn test_config_serde() {
        let config = test_config();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: CibaConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.auth_endpoint, config.auth_endpoint);
        assert_eq!(parsed.modes_supported.len(), 3);
    }

    #[tokio::test]
    async fn test_auth_request_poll_mode() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let resp = provider
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap();
        assert!(!resp.auth_req_id.is_empty());
        assert!(resp.expires_in > 0);
        assert!(resp.interval.is_some());
    }

    #[tokio::test]
    async fn test_auth_request_push_mode_requires_notification_token() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let result = provider
            .authenticate(poll_request(), CibaMode::Push)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_auth_request_push_mode_with_token() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let mut req = poll_request();
        req.client_notification_token = Some("cnt_abc123".to_string());
        let resp = provider
            .authenticate(req, CibaMode::Push)
            .await
            .unwrap();
        assert!(!resp.auth_req_id.is_empty());
        // Push-mode clients never poll, so no interval is advertised.
        assert!(resp.interval.is_none());
    }

    #[tokio::test]
    async fn test_auth_request_empty_scope_rejected() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let mut req = poll_request();
        req.scope = String::new();
        assert!(provider.authenticate(req, CibaMode::Poll).await.is_err());
    }

    #[tokio::test]
    async fn test_auth_request_invalid_binding_message() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let mut req = poll_request();
        req.binding_message = Some(String::new());
        assert!(provider.authenticate(req, CibaMode::Poll).await.is_err());
    }

    #[tokio::test]
    async fn test_binding_message_length_limit() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let mut at_limit = poll_request();
        at_limit.binding_message = Some("a".repeat(256));
        assert!(provider.authenticate(at_limit, CibaMode::Poll).await.is_ok());
        let mut over_limit = poll_request();
        over_limit.binding_message = Some("a".repeat(257));
        assert!(provider
            .authenticate(over_limit, CibaMode::Poll)
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_unsupported_mode_rejected() {
        let config = CibaConfig {
            modes_supported: vec![CibaMode::Poll],
            ..test_config()
        };
        let provider = CibaProvider::new(config, test_token_gen());
        let mut req = poll_request();
        req.client_notification_token = Some("token".to_string());
        assert!(provider.authenticate(req, CibaMode::Push).await.is_err());
    }

    #[tokio::test]
    async fn test_requested_expiry_is_capped() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let mut longer = poll_request();
        longer.requested_expiry = Some(10_000);
        let resp = provider
            .authenticate(longer, CibaMode::Poll)
            .await
            .unwrap();
        assert_eq!(resp.expires_in, 120);
        let mut shorter = poll_request();
        shorter.requested_expiry = Some(30);
        let resp = provider
            .authenticate(shorter, CibaMode::Poll)
            .await
            .unwrap();
        assert_eq!(resp.expires_in, 30);
    }

    #[tokio::test]
    async fn test_approve_and_poll() {
        let config = CibaConfig {
            default_interval: 0,
            ..test_config()
        };
        let provider = CibaProvider::new(config, test_token_gen());
        let resp = provider
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap();
        assert_eq!(
            provider.get_status(&resp.auth_req_id).await.unwrap(),
            CibaRequestStatus::Pending
        );
        let poll_result = provider.poll_token(&resp.auth_req_id).await;
        assert_eq!(poll_result.unwrap_err(), CibaError::AuthorizationPending);
        provider
            .approve(&resp.auth_req_id, "user:alice")
            .await
            .unwrap();
        assert_eq!(
            provider.get_status(&resp.auth_req_id).await.unwrap(),
            CibaRequestStatus::Approved
        );
        let token = provider.poll_token(&resp.auth_req_id).await.unwrap();
        assert!(token.access_token.contains("alice"));
        assert_eq!(token.token_type, "Bearer");
        assert!(token.id_token.is_some());
    }

    #[tokio::test]
    async fn test_poll_too_soon_returns_slow_down() {
        // A long interval guarantees the second poll lands inside it.
        let config = CibaConfig {
            default_interval: 60,
            ..test_config()
        };
        let provider = CibaProvider::new(config, test_token_gen());
        let resp = provider
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap();
        assert_eq!(
            provider.poll_token(&resp.auth_req_id).await.unwrap_err(),
            CibaError::AuthorizationPending
        );
        assert_eq!(
            provider.poll_token(&resp.auth_req_id).await.unwrap_err(),
            CibaError::SlowDown
        );
    }

    #[tokio::test]
    async fn test_poll_unknown_id() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        assert_eq!(
            provider.poll_token("nonexistent").await.unwrap_err(),
            CibaError::InvalidRequest
        );
    }

    #[tokio::test]
    async fn test_poll_expired_request() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let resp = provider
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap();
        force_expire(&provider, &resp.auth_req_id).await;
        assert_eq!(
            provider.poll_token(&resp.auth_req_id).await.unwrap_err(),
            CibaError::ExpiredToken
        );
        assert_eq!(
            provider.get_status(&resp.auth_req_id).await,
            Some(CibaRequestStatus::Expired)
        );
    }

    #[tokio::test]
    async fn test_deny_and_poll() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let resp = provider
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap();
        provider.deny(&resp.auth_req_id).await.unwrap();
        let poll_result = provider.poll_token(&resp.auth_req_id).await;
        assert_eq!(poll_result.unwrap_err(), CibaError::AccessDenied);
    }

    #[tokio::test]
    async fn test_double_approve_rejected() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let resp = provider
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap();
        provider
            .approve(&resp.auth_req_id, "user:alice")
            .await
            .unwrap();
        assert!(provider.approve(&resp.auth_req_id, "user:bob").await.is_err());
    }

    #[tokio::test]
    async fn test_deny_after_approve_rejected() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let resp = provider
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap();
        provider
            .approve(&resp.auth_req_id, "user:alice")
            .await
            .unwrap();
        assert!(provider.deny(&resp.auth_req_id).await.is_err());
        assert_eq!(
            provider.get_status(&resp.auth_req_id).await,
            Some(CibaRequestStatus::Approved)
        );
    }

    #[tokio::test]
    async fn test_approve_unknown_id() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        assert!(provider.approve("nonexistent", "user:alice").await.is_err());
    }

    #[tokio::test]
    async fn test_approve_expired_request() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let resp = provider
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap();
        force_expire(&provider, &resp.auth_req_id).await;
        assert!(provider
            .approve(&resp.auth_req_id, "user:alice")
            .await
            .is_err());
        assert_eq!(
            provider.get_status(&resp.auth_req_id).await,
            Some(CibaRequestStatus::Expired)
        );
    }

    #[tokio::test]
    async fn test_get_status_unknown_id() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        assert!(provider.get_status("nonexistent").await.is_none());
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        let mut config = test_config();
        config.expires_in = 1;
        let provider = CibaProvider::new(config, test_token_gen());
        let resp = provider
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap();
        assert_eq!(provider.pending_count().await, 1);
        force_expire(&provider, &resp.auth_req_id).await;
        provider.cleanup_expired().await;
        assert_eq!(provider.pending_count().await, 0);
    }

    #[tokio::test]
    async fn test_get_notification_push() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let mut req = poll_request();
        req.client_notification_token = Some("cnt_xyz".to_string());
        let resp = provider
            .authenticate(req, CibaMode::Push)
            .await
            .unwrap();
        provider
            .approve(&resp.auth_req_id, "user:alice")
            .await
            .unwrap();
        let (mode, cnt, token) = provider
            .get_notification(&resp.auth_req_id)
            .await
            .unwrap();
        assert_eq!(mode, CibaMode::Push);
        assert_eq!(cnt.unwrap(), "cnt_xyz");
        assert!(token.is_some());
    }

    #[test]
    fn test_login_hint_serde() {
        let hint = LoginHint::IdTokenHint("eyJ...".to_string());
        let json = serde_json::to_string(&hint).unwrap();
        let parsed: LoginHint = serde_json::from_str(&json).unwrap();
        match parsed {
            LoginHint::IdTokenHint(v) => assert_eq!(v, "eyJ..."),
            _ => panic!("Wrong hint variant"),
        }
    }

    #[test]
    fn test_ciba_error_serde() {
        let err = CibaError::SlowDown;
        let json = serde_json::to_string(&err).unwrap();
        assert_eq!(json, "\"slow_down\"");
    }
}