1use crate::errors::{AuthError, Result};
20use serde::{Deserialize, Serialize};
21use std::collections::HashMap;
22use std::sync::Arc;
23use std::time::{Duration, SystemTime, UNIX_EPOCH};
24use tokio::sync::RwLock;
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
30#[serde(rename_all = "snake_case")]
31pub enum CibaMode {
32 Poll,
33 Ping,
34 Push,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct CibaConfig {
40 pub auth_endpoint: String,
42 pub token_endpoint: String,
44 pub modes_supported: Vec<CibaMode>,
46 #[serde(default = "default_interval")]
48 pub default_interval: u64,
49 #[serde(default = "default_expires_in")]
51 pub expires_in: u64,
52 #[serde(default)]
54 pub user_code_supported: bool,
55}
56
/// Serde fallback for `CibaConfig::default_interval`: 5 seconds.
const fn default_interval() -> u64 {
    5
}

/// Serde fallback for `CibaConfig::expires_in`: 300 seconds.
const fn default_expires_in() -> u64 {
    300
}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
68#[serde(rename_all = "snake_case")]
69pub enum LoginHint {
70 LoginHintToken(String),
72 IdTokenHint(String),
74 LoginHint(String),
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct CibaAuthRequest {
81 pub scope: String,
83 pub hint: LoginHint,
85 #[serde(skip_serializing_if = "Option::is_none")]
87 pub binding_message: Option<String>,
88 #[serde(skip_serializing_if = "Option::is_none")]
90 pub user_code: Option<String>,
91 #[serde(skip_serializing_if = "Option::is_none")]
93 pub requested_expiry: Option<u64>,
94 #[serde(skip_serializing_if = "Option::is_none")]
96 pub acr_values: Option<String>,
97 #[serde(skip_serializing_if = "Option::is_none")]
99 pub client_notification_token: Option<String>,
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct CibaAuthResponse {
105 pub auth_req_id: String,
107 pub expires_in: u64,
109 #[serde(skip_serializing_if = "Option::is_none")]
111 pub interval: Option<u64>,
112}
113
114#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
118#[serde(rename_all = "snake_case")]
119pub enum CibaRequestStatus {
120 Pending,
121 Approved,
122 Denied,
123 Expired,
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct CibaTokenResponse {
129 pub access_token: String,
130 pub token_type: String,
131 #[serde(skip_serializing_if = "Option::is_none")]
132 pub refresh_token: Option<String>,
133 pub expires_in: u64,
134 #[serde(skip_serializing_if = "Option::is_none")]
135 pub id_token: Option<String>,
136}
137
138#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
140#[serde(rename_all = "snake_case")]
141pub enum CibaError {
142 AuthorizationPending,
143 SlowDown,
144 ExpiredToken,
145 AccessDenied,
146 InvalidRequest,
147 UnauthorizedClient,
148 InvalidScope,
149 InvalidBindingMessage,
150}
151
152#[allow(dead_code)]
155#[derive(Debug, Clone)]
156struct PendingAuth {
157 request: CibaAuthRequest,
158 status: CibaRequestStatus,
159 created_at: u64,
160 expires_at: u64,
161 last_polled: Option<u64>,
162 mode: CibaMode,
163 subject: Option<String>,
164 token_response: Option<CibaTokenResponse>,
165}
166
167pub struct CibaProvider {
171 config: CibaConfig,
172 pending: Arc<RwLock<HashMap<String, PendingAuth>>>,
174 token_generator: Arc<dyn Fn(&str, &str, &str) -> CibaTokenResponse + Send + Sync>,
176}
177
178impl CibaProvider {
179 pub fn new(
181 config: CibaConfig,
182 token_generator: impl Fn(&str, &str, &str) -> CibaTokenResponse + Send + Sync + 'static,
183 ) -> Self {
184 Self {
185 config,
186 pending: Arc::new(RwLock::new(HashMap::new())),
187 token_generator: Arc::new(token_generator),
188 }
189 }
190
191 fn now_secs() -> u64 {
192 SystemTime::now()
193 .duration_since(UNIX_EPOCH)
194 .unwrap_or(Duration::ZERO)
195 .as_secs()
196 }
197
198 fn generate_auth_req_id() -> String {
199 uuid::Uuid::new_v4().to_string()
200 }
201
202 pub async fn authenticate(
206 &self,
207 request: CibaAuthRequest,
208 mode: CibaMode,
209 ) -> Result<CibaAuthResponse> {
210 if !self.config.modes_supported.contains(&mode) {
212 return Err(AuthError::validation(&format!(
213 "CIBA mode {:?} not supported",
214 mode
215 )));
216 }
217
218 if let Some(ref msg) = request.binding_message {
220 if msg.is_empty() || msg.len() > 256 {
221 return Err(AuthError::validation(
222 "Binding message must be 1-256 characters",
223 ));
224 }
225 }
226
227 if matches!(mode, CibaMode::Ping | CibaMode::Push)
229 && request.client_notification_token.is_none()
230 {
231 return Err(AuthError::validation(
232 "client_notification_token required for ping/push mode",
233 ));
234 }
235
236 if request.scope.is_empty() {
238 return Err(AuthError::validation("scope is required"));
239 }
240
241 let now = Self::now_secs();
242 let expires_in = request
243 .requested_expiry
244 .unwrap_or(self.config.expires_in)
245 .min(self.config.expires_in);
246
247 let auth_req_id = Self::generate_auth_req_id();
248
249 let pending = PendingAuth {
250 request,
251 status: CibaRequestStatus::Pending,
252 created_at: now,
253 expires_at: now + expires_in,
254 last_polled: None,
255 mode,
256 subject: None,
257 token_response: None,
258 };
259
260 self.pending
261 .write()
262 .await
263 .insert(auth_req_id.clone(), pending);
264
265 Ok(CibaAuthResponse {
266 auth_req_id,
267 expires_in,
268 interval: if matches!(mode, CibaMode::Poll | CibaMode::Ping) {
269 Some(self.config.default_interval)
270 } else {
271 None
272 },
273 })
274 }
275
276 pub async fn approve(&self, auth_req_id: &str, subject: &str) -> Result<()> {
280 let mut pending = self.pending.write().await;
281 let entry = pending
282 .get_mut(auth_req_id)
283 .ok_or_else(|| AuthError::validation("Unknown auth_req_id"))?;
284
285 if entry.status != CibaRequestStatus::Pending {
286 return Err(AuthError::validation(&format!(
287 "Request already {:?}",
288 entry.status
289 )));
290 }
291
292 let now = Self::now_secs();
293 if now > entry.expires_at {
294 entry.status = CibaRequestStatus::Expired;
295 return Err(AuthError::validation("Request has expired"));
296 }
297
298 let token_response = (self.token_generator)(
300 auth_req_id,
301 subject,
302 &entry.request.scope,
303 );
304
305 entry.status = CibaRequestStatus::Approved;
306 entry.subject = Some(subject.to_string());
307 entry.token_response = Some(token_response);
308 Ok(())
309 }
310
311 pub async fn deny(&self, auth_req_id: &str) -> Result<()> {
313 let mut pending = self.pending.write().await;
314 let entry = pending
315 .get_mut(auth_req_id)
316 .ok_or_else(|| AuthError::validation("Unknown auth_req_id"))?;
317
318 if entry.status != CibaRequestStatus::Pending {
319 return Err(AuthError::validation(&format!(
320 "Request already {:?}",
321 entry.status
322 )));
323 }
324
325 entry.status = CibaRequestStatus::Denied;
326 Ok(())
327 }
328
329 pub async fn poll_token(
336 &self,
337 auth_req_id: &str,
338 ) -> std::result::Result<CibaTokenResponse, CibaError> {
339 let mut pending = self.pending.write().await;
340 let entry = pending
341 .get_mut(auth_req_id)
342 .ok_or(CibaError::InvalidRequest)?;
343
344 let now = Self::now_secs();
345
346 if now > entry.expires_at {
348 entry.status = CibaRequestStatus::Expired;
349 return Err(CibaError::ExpiredToken);
350 }
351
352 if let Some(last) = entry.last_polled {
354 if now - last < self.config.default_interval {
355 return Err(CibaError::SlowDown);
356 }
357 }
358 entry.last_polled = Some(now);
359
360 match entry.status {
361 CibaRequestStatus::Pending => Err(CibaError::AuthorizationPending),
362 CibaRequestStatus::Denied => Err(CibaError::AccessDenied),
363 CibaRequestStatus::Expired => Err(CibaError::ExpiredToken),
364 CibaRequestStatus::Approved => entry
365 .token_response
366 .clone()
367 .ok_or(CibaError::InvalidRequest),
368 }
369 }
370
371 pub async fn get_notification(
373 &self,
374 auth_req_id: &str,
375 ) -> Result<(CibaMode, Option<String>, Option<CibaTokenResponse>)> {
376 let pending = self.pending.read().await;
377 let entry = pending
378 .get(auth_req_id)
379 .ok_or_else(|| AuthError::validation("Unknown auth_req_id"))?;
380
381 let client_notification_token = entry.request.client_notification_token.clone();
382 let token_response = entry.token_response.clone();
383 Ok((entry.mode, client_notification_token, token_response))
384 }
385
386 pub async fn cleanup_expired(&self) {
388 let now = Self::now_secs();
389 self.pending.write().await.retain(|_, entry| {
390 now <= entry.expires_at
391 });
392 }
393
394 pub async fn get_status(&self, auth_req_id: &str) -> Option<CibaRequestStatus> {
396 let pending = self.pending.read().await;
397 pending.get(auth_req_id).map(|e| e.status.clone())
398 }
399
400 pub async fn pending_count(&self) -> usize {
402 self.pending.read().await.len()
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// Config enabling every mode, with a 1-second poll interval and a
    /// 120-second request lifetime.
    fn test_config() -> CibaConfig {
        CibaConfig {
            auth_endpoint: String::from("https://op.example.com/ciba"),
            token_endpoint: String::from("https://op.example.com/token"),
            modes_supported: vec![CibaMode::Poll, CibaMode::Ping, CibaMode::Push],
            default_interval: 1,
            expires_in: 120,
            user_code_supported: false,
        }
    }

    /// Deterministic generator that embeds subject and scope in the tokens.
    fn test_token_gen() -> impl Fn(&str, &str, &str) -> CibaTokenResponse {
        |_auth_req_id, sub, scope| CibaTokenResponse {
            access_token: format!("at_{sub}_{scope}"),
            token_type: String::from("Bearer"),
            refresh_token: Some(format!("rt_{sub}")),
            expires_in: 3600,
            id_token: Some(format!("idt_{sub}")),
        }
    }

    /// Provider over an arbitrary config, using the test token generator.
    fn provider_with(config: CibaConfig) -> CibaProvider {
        CibaProvider::new(config, test_token_gen())
    }

    /// Provider over the default test config.
    fn provider() -> CibaProvider {
        provider_with(test_config())
    }

    /// Minimal valid poll-mode request for alice.
    fn poll_request() -> CibaAuthRequest {
        CibaAuthRequest {
            scope: String::from("openid email"),
            hint: LoginHint::LoginHint(String::from("alice@example.com")),
            binding_message: Some(String::from("Confirm login on terminal 42")),
            user_code: None,
            requested_expiry: None,
            acr_values: None,
            client_notification_token: None,
        }
    }

    /// Same as `poll_request`, but carrying a client notification token.
    fn notifying_request(token: &str) -> CibaAuthRequest {
        CibaAuthRequest {
            client_notification_token: Some(token.to_string()),
            ..poll_request()
        }
    }

    #[test]
    fn test_ciba_mode_serde() {
        let encoded = serde_json::to_string(&CibaMode::Poll).unwrap();
        assert_eq!(encoded, r#""poll""#);
        let decoded: CibaMode = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, CibaMode::Poll);
    }

    #[test]
    fn test_config_serde() {
        let original = test_config();
        let encoded = serde_json::to_string(&original).unwrap();
        let round_tripped: CibaConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(round_tripped.auth_endpoint, original.auth_endpoint);
        assert_eq!(round_tripped.modes_supported.len(), 3);
    }

    #[tokio::test]
    async fn test_auth_request_poll_mode() {
        let resp = provider()
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap();
        assert!(!resp.auth_req_id.is_empty());
        assert!(resp.expires_in > 0);
        assert!(resp.interval.is_some());
    }

    #[tokio::test]
    async fn test_auth_request_push_mode_requires_notification_token() {
        let outcome = provider()
            .authenticate(poll_request(), CibaMode::Push)
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_auth_request_push_mode_with_token() {
        let resp = provider()
            .authenticate(notifying_request("cnt_abc123"), CibaMode::Push)
            .await
            .unwrap();
        assert!(!resp.auth_req_id.is_empty());
        // Push mode has no polling, so no interval is advertised.
        assert!(resp.interval.is_none());
    }

    #[tokio::test]
    async fn test_auth_request_empty_scope_rejected() {
        let req = CibaAuthRequest {
            scope: String::new(),
            ..poll_request()
        };
        assert!(provider().authenticate(req, CibaMode::Poll).await.is_err());
    }

    #[tokio::test]
    async fn test_auth_request_invalid_binding_message() {
        let req = CibaAuthRequest {
            binding_message: Some(String::new()),
            ..poll_request()
        };
        assert!(provider().authenticate(req, CibaMode::Poll).await.is_err());
    }

    #[tokio::test]
    async fn test_unsupported_mode_rejected() {
        let poll_only = CibaConfig {
            modes_supported: vec![CibaMode::Poll],
            ..test_config()
        };
        let outcome = provider_with(poll_only)
            .authenticate(notifying_request("token"), CibaMode::Push)
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_approve_and_poll() {
        // A zero interval lets the test poll twice without hitting slow_down.
        let ciba = provider_with(CibaConfig {
            default_interval: 0,
            ..test_config()
        });
        let id = ciba
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap()
            .auth_req_id;

        assert_eq!(ciba.get_status(&id).await, Some(CibaRequestStatus::Pending));
        assert!(matches!(
            ciba.poll_token(&id).await,
            Err(CibaError::AuthorizationPending)
        ));

        ciba.approve(&id, "user:alice").await.unwrap();
        assert_eq!(ciba.get_status(&id).await, Some(CibaRequestStatus::Approved));

        let token = ciba.poll_token(&id).await.unwrap();
        assert!(token.access_token.contains("alice"));
        assert_eq!(token.token_type, "Bearer");
        assert!(token.id_token.is_some());
    }

    #[tokio::test]
    async fn test_deny_and_poll() {
        let ciba = provider();
        let id = ciba
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap()
            .auth_req_id;

        ciba.deny(&id).await.unwrap();

        assert!(matches!(
            ciba.poll_token(&id).await,
            Err(CibaError::AccessDenied)
        ));
    }

    #[tokio::test]
    async fn test_double_approve_rejected() {
        let ciba = provider();
        let id = ciba
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap()
            .auth_req_id;

        ciba.approve(&id, "user:alice").await.unwrap();
        assert!(ciba.approve(&id, "user:bob").await.is_err());
    }

    #[tokio::test]
    async fn test_approve_unknown_id() {
        assert!(provider()
            .approve("nonexistent", "user:alice")
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        let ciba = provider_with(CibaConfig {
            expires_in: 1,
            ..test_config()
        });
        let id = ciba
            .authenticate(poll_request(), CibaMode::Poll)
            .await
            .unwrap()
            .auth_req_id;
        assert_eq!(ciba.pending_count().await, 1);

        // Backdate the deadline so the entry is already overdue; the write guard
        // is released at the end of this statement.
        ciba.pending
            .write()
            .await
            .get_mut(&id)
            .expect("request was just registered")
            .expires_at = 0;

        ciba.cleanup_expired().await;
        assert_eq!(ciba.pending_count().await, 0);
    }

    #[tokio::test]
    async fn test_get_notification_push() {
        let ciba = provider();
        let id = ciba
            .authenticate(notifying_request("cnt_xyz"), CibaMode::Push)
            .await
            .unwrap()
            .auth_req_id;

        ciba.approve(&id, "user:alice").await.unwrap();

        let (mode, notification_token, tokens) = ciba.get_notification(&id).await.unwrap();
        assert_eq!(mode, CibaMode::Push);
        assert_eq!(notification_token.as_deref(), Some("cnt_xyz"));
        assert!(tokens.is_some());
    }

    #[test]
    fn test_login_hint_serde() {
        let encoded = serde_json::to_string(&LoginHint::IdTokenHint("eyJ...".into())).unwrap();
        let decoded: LoginHint = serde_json::from_str(&encoded).unwrap();
        if let LoginHint::IdTokenHint(value) = decoded {
            assert_eq!(value, "eyJ...");
        } else {
            panic!("Wrong hint variant");
        }
    }

    #[test]
    fn test_ciba_error_serde() {
        assert_eq!(
            serde_json::to_string(&CibaError::SlowDown).unwrap(),
            "\"slow_down\""
        );
    }
}