Skip to main content

auth_framework/protocols/
ciba.rs

//! OpenID Connect Client-Initiated Backchannel Authentication (CIBA).
//!
//! Implements the CIBA flow, in which a *consumption device* (e.g. a POS
//! terminal or call-center application) starts authentication of a user who
//! then approves the request on a separate *authentication device* (e.g. the
//! user's phone), without any browser redirect.
//!
//! # Modes
//!
//! - **Poll** — the client repeatedly polls the token endpoint.
//! - **Ping** — the OpenID Provider (OP) sends a notification to the client's
//!   callback URI, then the client fetches the token.
//! - **Push** — the OP pushes the token directly to the client's callback URI.
//!
//! # References
//!
//! - [OpenID Connect CIBA Core 1.0](https://openid.net/specs/openid-client-initiated-backchannel-authentication-core-1_0.html)
//! - [RFC 9449 — DPoP](https://www.rfc-editor.org/rfc/rfc9449) (optional token binding)
18
19use crate::errors::{AuthError, Result};
20use serde::{Deserialize, Serialize};
21use std::collections::HashMap;
22use std::sync::Arc;
23use std::time::{Duration, SystemTime, UNIX_EPOCH};
24use tokio::sync::RwLock;
25
26// ── Configuration ───────────────────────────────────────────────────
27
28/// CIBA token delivery mode.
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
30#[serde(rename_all = "snake_case")]
31pub enum CibaMode {
32    Poll,
33    Ping,
34    Push,
35}
36
37/// Configuration for a CIBA provider.
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct CibaConfig {
40    /// Backchannel authentication endpoint URL.
41    pub auth_endpoint: String,
42    /// Token endpoint URL.
43    pub token_endpoint: String,
44    /// Supported delivery modes.
45    pub modes_supported: Vec<CibaMode>,
46    /// Default polling interval in seconds.
47    #[serde(default = "default_interval")]
48    pub default_interval: u64,
49    /// Maximum auth request lifetime in seconds.
50    #[serde(default = "default_expires_in")]
51    pub expires_in: u64,
52    /// Optional user code support.
53    #[serde(default)]
54    pub user_code_supported: bool,
55}
56
/// Serde fallback for `CibaConfig::default_interval`: 5 seconds.
fn default_interval() -> u64 {
    const DEFAULT_POLL_INTERVAL_SECS: u64 = 5;
    DEFAULT_POLL_INTERVAL_SECS
}

/// Serde fallback for `CibaConfig::expires_in`: 300 seconds (five minutes).
fn default_expires_in() -> u64 {
    const DEFAULT_REQUEST_LIFETIME_SECS: u64 = 5 * 60;
    DEFAULT_REQUEST_LIFETIME_SECS
}
63
64// ── Authentication Request ──────────────────────────────────────────
65
66/// Hint identifying the end-user to authenticate.
67#[derive(Debug, Clone, Serialize, Deserialize)]
68#[serde(rename_all = "snake_case")]
69pub enum LoginHint {
70    /// Subject identifier.
71    LoginHintToken(String),
72    /// An id_token_hint.
73    IdTokenHint(String),
74    /// Login hint (e.g. email or phone).
75    LoginHint(String),
76}
77
78/// A backchannel authentication request sent by the consumption device.
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct CibaAuthRequest {
81    /// The scopes requested.
82    pub scope: String,
83    /// Hint identifying the user.
84    pub hint: LoginHint,
85    /// Human-readable binding message shown on the authentication device.
86    #[serde(skip_serializing_if = "Option::is_none")]
87    pub binding_message: Option<String>,
88    /// User code entered on the consumption device (for user-code mode).
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub user_code: Option<String>,
91    /// Requested expiry (seconds).
92    #[serde(skip_serializing_if = "Option::is_none")]
93    pub requested_expiry: Option<u64>,
94    /// ACR values.
95    #[serde(skip_serializing_if = "Option::is_none")]
96    pub acr_values: Option<String>,
97    /// Client notification token (required for ping/push modes).
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub client_notification_token: Option<String>,
100}
101
102/// Successful response to a backchannel authentication request.
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct CibaAuthResponse {
105    /// Unique identifier for the authentication request.
106    pub auth_req_id: String,
107    /// Expires-in (seconds).
108    pub expires_in: u64,
109    /// Minimum polling interval (seconds) — included for poll / ping modes.
110    #[serde(skip_serializing_if = "Option::is_none")]
111    pub interval: Option<u64>,
112}
113
114// ── Token Request / Response ────────────────────────────────────────
115
116/// Pending request state.
117#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
118#[serde(rename_all = "snake_case")]
119pub enum CibaRequestStatus {
120    Pending,
121    Approved,
122    Denied,
123    Expired,
124}
125
126/// Token response after successful authentication.
127#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct CibaTokenResponse {
129    pub access_token: String,
130    pub token_type: String,
131    #[serde(skip_serializing_if = "Option::is_none")]
132    pub refresh_token: Option<String>,
133    pub expires_in: u64,
134    #[serde(skip_serializing_if = "Option::is_none")]
135    pub id_token: Option<String>,
136}
137
138/// Error response per CIBA spec.
139#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
140#[serde(rename_all = "snake_case")]
141pub enum CibaError {
142    AuthorizationPending,
143    SlowDown,
144    ExpiredToken,
145    AccessDenied,
146    InvalidRequest,
147    UnauthorizedClient,
148    InvalidScope,
149    InvalidBindingMessage,
150}
151
152// ── Internal state ──────────────────────────────────────────────────
153
154#[allow(dead_code)]
155#[derive(Debug, Clone)]
156struct PendingAuth {
157    request: CibaAuthRequest,
158    status: CibaRequestStatus,
159    created_at: u64,
160    expires_at: u64,
161    last_polled: Option<u64>,
162    mode: CibaMode,
163    subject: Option<String>,
164    token_response: Option<CibaTokenResponse>,
165}
166
167// ── CIBA Provider ───────────────────────────────────────────────────
168
169/// In-memory CIBA provider implementing Auth Request → Token lifecycle.
170pub struct CibaProvider {
171    config: CibaConfig,
172    /// `auth_req_id → PendingAuth`
173    pending: Arc<RwLock<HashMap<String, PendingAuth>>>,
174    /// Token generator function (auth_req_id, subject, scope) → CibaTokenResponse.
175    token_generator: Arc<dyn Fn(&str, &str, &str) -> CibaTokenResponse + Send + Sync>,
176}
177
178impl CibaProvider {
179    /// Create a provider with the given config and token generator.
180    pub fn new(
181        config: CibaConfig,
182        token_generator: impl Fn(&str, &str, &str) -> CibaTokenResponse + Send + Sync + 'static,
183    ) -> Self {
184        Self {
185            config,
186            pending: Arc::new(RwLock::new(HashMap::new())),
187            token_generator: Arc::new(token_generator),
188        }
189    }
190
191    fn now_secs() -> u64 {
192        SystemTime::now()
193            .duration_since(UNIX_EPOCH)
194            .unwrap_or(Duration::ZERO)
195            .as_secs()
196    }
197
198    fn generate_auth_req_id() -> String {
199        uuid::Uuid::new_v4().to_string()
200    }
201
202    // ── Phase 1: Authentication Request ─────────────────────────
203
204    /// Process a backchannel authentication request.
205    pub async fn authenticate(
206        &self,
207        request: CibaAuthRequest,
208        mode: CibaMode,
209    ) -> Result<CibaAuthResponse> {
210        // Validate mode is supported
211        if !self.config.modes_supported.contains(&mode) {
212            return Err(AuthError::validation(&format!(
213                "CIBA mode {:?} not supported",
214                mode
215            )));
216        }
217
218        // Validate binding message length
219        if let Some(ref msg) = request.binding_message {
220            if msg.is_empty() || msg.len() > 256 {
221                return Err(AuthError::validation(
222                    "Binding message must be 1-256 characters",
223                ));
224            }
225        }
226
227        // Ping/push requires client_notification_token
228        if matches!(mode, CibaMode::Ping | CibaMode::Push)
229            && request.client_notification_token.is_none()
230        {
231            return Err(AuthError::validation(
232                "client_notification_token required for ping/push mode",
233            ));
234        }
235
236        // Validate scope is non-empty
237        if request.scope.is_empty() {
238            return Err(AuthError::validation("scope is required"));
239        }
240
241        let now = Self::now_secs();
242        let expires_in = request
243            .requested_expiry
244            .unwrap_or(self.config.expires_in)
245            .min(self.config.expires_in);
246
247        let auth_req_id = Self::generate_auth_req_id();
248
249        let pending = PendingAuth {
250            request,
251            status: CibaRequestStatus::Pending,
252            created_at: now,
253            expires_at: now + expires_in,
254            last_polled: None,
255            mode,
256            subject: None,
257            token_response: None,
258        };
259
260        self.pending
261            .write()
262            .await
263            .insert(auth_req_id.clone(), pending);
264
265        Ok(CibaAuthResponse {
266            auth_req_id,
267            expires_in,
268            interval: if matches!(mode, CibaMode::Poll | CibaMode::Ping) {
269                Some(self.config.default_interval)
270            } else {
271                None
272            },
273        })
274    }
275
276    // ── Phase 2: User consent (called by authentication device) ─
277
278    /// Approve an authentication request (called when user consents).
279    pub async fn approve(&self, auth_req_id: &str, subject: &str) -> Result<()> {
280        let mut pending = self.pending.write().await;
281        let entry = pending
282            .get_mut(auth_req_id)
283            .ok_or_else(|| AuthError::validation("Unknown auth_req_id"))?;
284
285        if entry.status != CibaRequestStatus::Pending {
286            return Err(AuthError::validation(&format!(
287                "Request already {:?}",
288                entry.status
289            )));
290        }
291
292        let now = Self::now_secs();
293        if now > entry.expires_at {
294            entry.status = CibaRequestStatus::Expired;
295            return Err(AuthError::validation("Request has expired"));
296        }
297
298        // Generate tokens
299        let token_response = (self.token_generator)(
300            auth_req_id,
301            subject,
302            &entry.request.scope,
303        );
304
305        entry.status = CibaRequestStatus::Approved;
306        entry.subject = Some(subject.to_string());
307        entry.token_response = Some(token_response);
308        Ok(())
309    }
310
311    /// Deny an authentication request.
312    pub async fn deny(&self, auth_req_id: &str) -> Result<()> {
313        let mut pending = self.pending.write().await;
314        let entry = pending
315            .get_mut(auth_req_id)
316            .ok_or_else(|| AuthError::validation("Unknown auth_req_id"))?;
317
318        if entry.status != CibaRequestStatus::Pending {
319            return Err(AuthError::validation(&format!(
320                "Request already {:?}",
321                entry.status
322            )));
323        }
324
325        entry.status = CibaRequestStatus::Denied;
326        Ok(())
327    }
328
329    // ── Phase 3: Token retrieval (poll mode) ────────────────────
330
331    /// Poll for the token (used in poll mode).
332    ///
333    /// Returns `Ok(CibaTokenResponse)` on success,
334    /// `Err` with appropriate CIBA error on pending/denied/expired/slow-down.
335    pub async fn poll_token(
336        &self,
337        auth_req_id: &str,
338    ) -> std::result::Result<CibaTokenResponse, CibaError> {
339        let mut pending = self.pending.write().await;
340        let entry = pending
341            .get_mut(auth_req_id)
342            .ok_or(CibaError::InvalidRequest)?;
343
344        let now = Self::now_secs();
345
346        // Check expiry
347        if now > entry.expires_at {
348            entry.status = CibaRequestStatus::Expired;
349            return Err(CibaError::ExpiredToken);
350        }
351
352        // Slow-down check
353        if let Some(last) = entry.last_polled {
354            if now - last < self.config.default_interval {
355                return Err(CibaError::SlowDown);
356            }
357        }
358        entry.last_polled = Some(now);
359
360        match entry.status {
361            CibaRequestStatus::Pending => Err(CibaError::AuthorizationPending),
362            CibaRequestStatus::Denied => Err(CibaError::AccessDenied),
363            CibaRequestStatus::Expired => Err(CibaError::ExpiredToken),
364            CibaRequestStatus::Approved => entry
365                .token_response
366                .clone()
367                .ok_or(CibaError::InvalidRequest),
368        }
369    }
370
371    /// Get the notification payload for ping/push modes.
372    pub async fn get_notification(
373        &self,
374        auth_req_id: &str,
375    ) -> Result<(CibaMode, Option<String>, Option<CibaTokenResponse>)> {
376        let pending = self.pending.read().await;
377        let entry = pending
378            .get(auth_req_id)
379            .ok_or_else(|| AuthError::validation("Unknown auth_req_id"))?;
380
381        let client_notification_token = entry.request.client_notification_token.clone();
382        let token_response = entry.token_response.clone();
383        Ok((entry.mode, client_notification_token, token_response))
384    }
385
386    /// Clean up expired requests.
387    pub async fn cleanup_expired(&self) {
388        let now = Self::now_secs();
389        self.pending.write().await.retain(|_, entry| {
390            now <= entry.expires_at
391        });
392    }
393
394    /// Get the status of an auth request.
395    pub async fn get_status(&self, auth_req_id: &str) -> Option<CibaRequestStatus> {
396        let pending = self.pending.read().await;
397        pending.get(auth_req_id).map(|e| e.status.clone())
398    }
399
400    /// Get the total number of pending requests.
401    pub async fn pending_count(&self) -> usize {
402        self.pending.read().await.len()
403    }
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with all modes enabled, a 1 s poll interval and a 120 s lifetime.
    fn test_config() -> CibaConfig {
        CibaConfig {
            auth_endpoint: "https://op.example.com/ciba".to_string(),
            token_endpoint: "https://op.example.com/token".to_string(),
            modes_supported: vec![CibaMode::Poll, CibaMode::Ping, CibaMode::Push],
            default_interval: 1,
            expires_in: 120,
            user_code_supported: false,
        }
    }

    /// Deterministic token generator embedding subject and scope in the tokens.
    fn test_token_gen() -> impl Fn(&str, &str, &str) -> CibaTokenResponse {
        |_req_id, subject, scope| CibaTokenResponse {
            access_token: format!("at_{subject}_{scope}"),
            token_type: "Bearer".to_string(),
            refresh_token: Some(format!("rt_{subject}")),
            expires_in: 3600,
            id_token: Some(format!("idt_{subject}")),
        }
    }

    /// A valid poll-mode request (no client notification token).
    fn poll_request() -> CibaAuthRequest {
        CibaAuthRequest {
            scope: "openid email".to_string(),
            hint: LoginHint::LoginHint("alice@example.com".to_string()),
            binding_message: Some("Confirm login on terminal 42".to_string()),
            user_code: None,
            requested_expiry: None,
            acr_values: None,
            client_notification_token: None,
        }
    }

    // ── Config serialization ────────────────────────────────────

    #[test]
    fn test_ciba_mode_serde() {
        let json = serde_json::to_string(&CibaMode::Poll).unwrap();
        assert_eq!(json, "\"poll\"");
        let parsed: CibaMode = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, CibaMode::Poll);
    }

    #[test]
    fn test_config_serde() {
        let config = test_config();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: CibaConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.auth_endpoint, config.auth_endpoint);
        assert_eq!(parsed.modes_supported.len(), 3);
    }

    // ── Authentication request ──────────────────────────────────

    #[tokio::test]
    async fn test_auth_request_poll_mode() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let resp = provider
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap();
        assert!(!resp.auth_req_id.is_empty());
        assert!(resp.expires_in > 0);
        assert!(resp.interval.is_some());
    }

    #[tokio::test]
    async fn test_auth_request_ping_mode_has_interval() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let mut req = poll_request();
        req.client_notification_token = Some("cnt_ping".to_string());
        let resp = provider.authenticate(req, CibaMode::Ping).await.unwrap();
        assert_eq!(resp.interval, Some(1));
    }

    #[tokio::test]
    async fn test_auth_request_push_mode_requires_notification_token() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let result = provider.authenticate(poll_request(), CibaMode::Push).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_auth_request_push_mode_with_token() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let mut req = poll_request();
        req.client_notification_token = Some("cnt_abc123".to_string());
        let resp = provider.authenticate(req, CibaMode::Push).await.unwrap();
        assert!(!resp.auth_req_id.is_empty());
        assert!(resp.interval.is_none()); // Push mode has no polling interval
    }

    #[tokio::test]
    async fn test_auth_request_empty_scope_rejected() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let mut req = poll_request();
        req.scope = String::new();
        assert!(provider.authenticate(req, CibaMode::Poll).await.is_err());
    }

    #[tokio::test]
    async fn test_auth_request_invalid_binding_message() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let mut req = poll_request();
        req.binding_message = Some(String::new());
        assert!(provider.authenticate(req, CibaMode::Poll).await.is_err());
    }

    #[tokio::test]
    async fn test_binding_message_length_bounds() {
        let provider = CibaProvider::new(test_config(), test_token_gen());

        let mut at_limit = poll_request();
        at_limit.binding_message = Some("a".repeat(256));
        assert!(provider.authenticate(at_limit, CibaMode::Poll).await.is_ok());

        let mut over_limit = poll_request();
        over_limit.binding_message = Some("a".repeat(257));
        assert!(provider.authenticate(over_limit, CibaMode::Poll).await.is_err());
    }

    #[tokio::test]
    async fn test_requested_expiry_capped_by_config() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let mut req = poll_request();
        req.requested_expiry = Some(10_000);
        let resp = provider.authenticate(req, CibaMode::Poll).await.unwrap();
        assert_eq!(resp.expires_in, 120);
    }

    #[tokio::test]
    async fn test_requested_expiry_shorter_is_honored() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let mut req = poll_request();
        req.requested_expiry = Some(30);
        let resp = provider.authenticate(req, CibaMode::Poll).await.unwrap();
        assert_eq!(resp.expires_in, 30);
    }

    #[tokio::test]
    async fn test_unsupported_mode_rejected() {
        let config = CibaConfig {
            modes_supported: vec![CibaMode::Poll],
            ..test_config()
        };
        let provider = CibaProvider::new(config, test_token_gen());
        let mut req = poll_request();
        req.client_notification_token = Some("token".to_string());
        assert!(provider.authenticate(req, CibaMode::Push).await.is_err());
    }

    // ── Approve / Deny ──────────────────────────────────────────

    #[tokio::test]
    async fn test_approve_and_poll() {
        // Use interval=0 so rapid successive polls don't trigger SlowDown
        let config = CibaConfig {
            default_interval: 0,
            ..test_config()
        };
        let provider = CibaProvider::new(config, test_token_gen());
        let resp = provider
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap();

        // Initially pending
        assert_eq!(
            provider.get_status(&resp.auth_req_id).await.unwrap(),
            CibaRequestStatus::Pending
        );

        // Polling before approval → authorization_pending
        let poll_result = provider.poll_token(&resp.auth_req_id).await;
        assert_eq!(poll_result.unwrap_err(), CibaError::AuthorizationPending);

        provider
            .approve(&resp.auth_req_id, "user:alice")
            .await
            .unwrap();
        assert_eq!(
            provider.get_status(&resp.auth_req_id).await.unwrap(),
            CibaRequestStatus::Approved
        );

        // With a zero interval the next poll succeeds immediately.
        let token = provider.poll_token(&resp.auth_req_id).await.unwrap();
        assert!(token.access_token.contains("alice"));
        assert_eq!(token.token_type, "Bearer");
        assert!(token.id_token.is_some());
    }

    #[tokio::test]
    async fn test_deny_and_poll() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let resp = provider
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap();

        provider.deny(&resp.auth_req_id).await.unwrap();

        let poll_result = provider.poll_token(&resp.auth_req_id).await;
        assert_eq!(poll_result.unwrap_err(), CibaError::AccessDenied);
    }

    #[tokio::test]
    async fn test_double_approve_rejected() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let resp = provider
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap();
        provider
            .approve(&resp.auth_req_id, "user:alice")
            .await
            .unwrap();
        assert!(provider.approve(&resp.auth_req_id, "user:bob").await.is_err());
    }

    #[tokio::test]
    async fn test_double_deny_rejected() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let resp = provider
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap();
        provider.deny(&resp.auth_req_id).await.unwrap();
        assert!(provider.deny(&resp.auth_req_id).await.is_err());
    }

    #[tokio::test]
    async fn test_approve_after_deny_rejected() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let resp = provider
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap();
        provider.deny(&resp.auth_req_id).await.unwrap();
        assert!(provider.approve(&resp.auth_req_id, "user:alice").await.is_err());
        assert_eq!(
            provider.get_status(&resp.auth_req_id).await.unwrap(),
            CibaRequestStatus::Denied
        );
    }

    #[tokio::test]
    async fn test_approve_unknown_id() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        assert!(provider.approve("nonexistent", "user:alice").await.is_err());
    }

    #[tokio::test]
    async fn test_poll_unknown_id() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        assert_eq!(
            provider.poll_token("nonexistent").await.unwrap_err(),
            CibaError::InvalidRequest
        );
    }

    #[tokio::test]
    async fn test_get_status_unknown_id() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        assert!(provider.get_status("nonexistent").await.is_none());
    }

    // ── Expiry ──────────────────────────────────────────────────

    #[tokio::test]
    async fn test_cleanup_expired() {
        let mut config = test_config();
        config.expires_in = 1; // 1 second
        let provider = CibaProvider::new(config, test_token_gen());
        let resp = provider
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap();
        assert_eq!(provider.pending_count().await, 1);

        // Rather than sleeping past the 1 s lifetime, force the entry into the
        // past by overriding `expires_at`.
        {
            let mut pending = provider.pending.write().await;
            let entry = pending.get_mut(&resp.auth_req_id).unwrap();
            entry.expires_at = 0; // Force expired
        }

        provider.cleanup_expired().await;
        assert_eq!(provider.pending_count().await, 0);
    }

    #[tokio::test]
    async fn test_poll_after_expiry() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let resp = provider
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap();

        {
            let mut pending = provider.pending.write().await;
            pending.get_mut(&resp.auth_req_id).unwrap().expires_at = 0;
        }

        assert_eq!(
            provider.poll_token(&resp.auth_req_id).await.unwrap_err(),
            CibaError::ExpiredToken
        );
        assert_eq!(
            provider.get_status(&resp.auth_req_id).await.unwrap(),
            CibaRequestStatus::Expired
        );
        assert!(provider.approve(&resp.auth_req_id, "user:alice").await.is_err());
    }

    // ── Notification info ───────────────────────────────────────

    #[tokio::test]
    async fn test_get_notification_push() {
        let provider = CibaProvider::new(test_config(), test_token_gen());
        let mut req = poll_request();
        req.client_notification_token = Some("cnt_xyz".to_string());
        let resp = provider.authenticate(req, CibaMode::Push).await.unwrap();

        provider
            .approve(&resp.auth_req_id, "user:alice")
            .await
            .unwrap();

        let (mode, cnt, token) = provider
            .get_notification(&resp.auth_req_id)
            .await
            .unwrap();
        assert_eq!(mode, CibaMode::Push);
        assert_eq!(cnt.unwrap(), "cnt_xyz");
        assert!(token.is_some());
    }

    // ── Login hint variants ─────────────────────────────────────

    #[test]
    fn test_login_hint_serde() {
        let hint = LoginHint::IdTokenHint("eyJ...".to_string());
        let json = serde_json::to_string(&hint).unwrap();
        let parsed: LoginHint = serde_json::from_str(&json).unwrap();
        match parsed {
            LoginHint::IdTokenHint(v) => assert_eq!(v, "eyJ..."),
            _ => panic!("Wrong hint variant"),
        }
    }

    // ── CibaError serialization ─────────────────────────────────

    #[test]
    fn test_ciba_error_serde() {
        let err = CibaError::SlowDown;
        let json = serde_json::to_string(&err).unwrap();
        assert_eq!(json, "\"slow_down\"");
    }
}