1use std::{sync::Arc, time::Duration};
24
25use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
26use dashmap::DashMap;
27use rand::{RngCore, rngs::OsRng};
28use sha2::{Digest, Sha256};
29use thiserror::Error;
30
31use crate::state_encryption::StateEncryptionService;
32
33#[derive(Debug, Error)]
39#[non_exhaustive]
40pub enum PkceError {
41 #[error(
47 "state not found — the authorization flow may have already been completed or the state is invalid"
48 )]
49 StateNotFound,
50
51 #[error("state expired — please restart the authorization flow")]
57 StateExpired,
58}
59
/// Data recovered when a PKCE state token is successfully consumed.
///
/// `Debug` is implemented by hand rather than derived. The code verifier is a secret:
/// together with an authorization code, it can be exchanged for tokens. It must
/// therefore never end up in logs through `{:?}` formatting.
pub struct ConsumedPkceState {
    /// The PKCE code verifier generated when the state was created.
    pub verifier: String,
    /// The redirect URI recorded when the state was created.
    pub redirect_uri: String,
}

impl std::fmt::Debug for ConsumedPkceState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ConsumedPkceState")
            // Redacted: the verifier is a bearer-style secret.
            .field("verifier", &"[REDACTED]")
            .field("redirect_uri", &self.redirect_uri)
            .finish()
    }
}
73
74struct PkceEntry {
79 verifier: String,
80 redirect_uri: String,
81 created_at: tokio::time::Instant,
84 ttl: Duration,
85}
86
87pub struct InMemoryPkceStateStore {
93 state_ttl_secs: u64,
94 entries: DashMap<String, PkceEntry>,
95 encryptor: Option<Arc<StateEncryptionService>>,
96}
97
98impl InMemoryPkceStateStore {
99 fn new(state_ttl_secs: u64, encryptor: Option<Arc<StateEncryptionService>>) -> Self {
100 Self {
101 state_ttl_secs,
102 entries: DashMap::new(),
103 encryptor,
104 }
105 }
106
107 fn create_state_sync(&self, redirect_uri: &str) -> Result<(String, String), anyhow::Error> {
108 let mut verifier_bytes = [0u8; 32];
110 OsRng.fill_bytes(&mut verifier_bytes);
111 let verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
112
113 let mut key_bytes = [0u8; 32];
115 OsRng.fill_bytes(&mut key_bytes);
116 let internal_key = URL_SAFE_NO_PAD.encode(key_bytes);
117
118 self.entries.insert(
119 internal_key.clone(),
120 PkceEntry {
121 verifier: verifier.clone(),
122 redirect_uri: redirect_uri.to_owned(),
123 created_at: tokio::time::Instant::now(),
124 ttl: Duration::from_secs(self.state_ttl_secs),
125 },
126 );
127
128 let outbound_token = match &self.encryptor {
129 Some(enc) => enc.encrypt(internal_key.as_bytes())?,
130 None => internal_key,
131 };
132
133 Ok((outbound_token, verifier))
134 }
135
136 fn consume_state_sync(&self, outbound_token: &str) -> Result<ConsumedPkceState, PkceError> {
137 let internal_key = match &self.encryptor {
138 Some(enc) => {
139 let bytes = enc.decrypt(outbound_token).map_err(|_| PkceError::StateNotFound)?;
140 String::from_utf8(bytes).map_err(|_| PkceError::StateNotFound)?
141 },
142 None => outbound_token.to_owned(),
143 };
144
145 let (_, entry) = self.entries.remove(&internal_key).ok_or(PkceError::StateNotFound)?;
146
147 if entry.created_at.elapsed() > entry.ttl {
148 return Err(PkceError::StateExpired);
149 }
150
151 Ok(ConsumedPkceState {
152 verifier: entry.verifier,
153 redirect_uri: entry.redirect_uri,
154 })
155 }
156
157 fn cleanup_expired_sync(&self) {
158 self.entries.retain(|_, e| e.created_at.elapsed() <= e.ttl);
159 }
160
161 fn len_sync(&self) -> usize {
162 self.entries.len()
163 }
164}
165
/// Process-wide counter of errors in the Redis-backed PKCE store.
///
/// Read it with [`redis_pkce_error_count_total`]. The `_total` suffix suggests it is
/// exported as a monotonic metric; confirm against the metrics exporter.
#[cfg(feature = "redis-pkce")]
pub static REDIS_PKCE_ERRORS: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);

/// Returns the current value of [`REDIS_PKCE_ERRORS`].
#[cfg(feature = "redis-pkce")]
pub fn redis_pkce_error_count_total() -> u64 {
    use std::sync::atomic::Ordering;

    // A standalone counter does not order any other memory, so Relaxed is sufficient.
    REDIS_PKCE_ERRORS.load(Ordering::Relaxed)
}
181
/// PKCE state store backed by Redis.
///
/// Every instance that points at the same Redis server can consume state created by
/// any other instance. Entries are written under `fraiseql:pkce:<internal key>`, and
/// Redis expires them itself.
#[cfg(feature = "redis-pkce")]
pub struct RedisPkceStateStore {
    /// Redis connection manager, cloned for each command that is issued.
    pool: redis::aio::ConnectionManager,
    /// Expiry (Redis `EX`) applied to every entry, in seconds.
    state_ttl_secs: u64,
    /// When present, internal keys are encrypted before being handed to clients.
    encryptor: Option<Arc<StateEncryptionService>>,
}
196
#[cfg(feature = "redis-pkce")]
impl RedisPkceStateStore {
    /// Connects to Redis at `url` and builds a store whose entries expire after
    /// `state_ttl_secs` seconds.
    ///
    /// # Errors
    /// Returns a [`redis::RedisError`] if the URL is invalid or the initial connection fails.
    pub async fn new(
        url: &str,
        state_ttl_secs: u64,
        encryptor: Option<Arc<StateEncryptionService>>,
    ) -> Result<Self, redis::RedisError> {
        let client = redis::Client::open(url)?;
        let pool = redis::aio::ConnectionManager::new(client).await?;
        Ok(Self {
            pool,
            state_ttl_secs,
            encryptor,
        })
    }

    /// Generates a PKCE verifier and stores a pending entry for `redirect_uri` in Redis.
    ///
    /// Returns `(outbound_token, verifier)`. `outbound_token` is encrypted when an
    /// encryptor is configured.
    ///
    /// # Errors
    /// Returns an error if encryption fails, in which case nothing is written. Also
    /// returns an error if the Redis `SET` fails.
    async fn create_state_impl(
        &self,
        redirect_uri: &str,
    ) -> Result<(String, String), anyhow::Error> {
        let mut verifier_bytes = [0u8; 32];
        OsRng.fill_bytes(&mut verifier_bytes);
        let verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);

        let mut key_bytes = [0u8; 32];
        OsRng.fill_bytes(&mut key_bytes);
        let internal_key = URL_SAFE_NO_PAD.encode(key_bytes);

        // Encrypt before writing. If encryption failed after the write, a key would be
        // left in Redis until its TTL, and no client would hold a token for it.
        let outbound_token = match &self.encryptor {
            Some(enc) => enc.encrypt(internal_key.as_bytes())?,
            None => internal_key.clone(),
        };

        let redis_key = format!("fraiseql:pkce:{internal_key}");
        let value = serde_json::json!({
            "verifier": verifier,
            "redirect_uri": redirect_uri,
        })
        .to_string();

        let mut conn = self.pool.clone();
        // `EX` lets Redis enforce the TTL, so no cleanup sweep is needed.
        redis::cmd("SET")
            .arg(&redis_key)
            .arg(&value)
            .arg("EX")
            .arg(self.state_ttl_secs)
            .query_async::<()>(&mut conn)
            .await?;

        Ok((outbound_token, verifier))
    }

    /// Atomically fetches and deletes the entry identified by `outbound_token`.
    ///
    /// # Errors
    /// Returns [`PkceError::StateNotFound`] in all of these cases:
    /// - the token cannot be decrypted or decoded;
    /// - the key is absent (never issued, already consumed, or expired by Redis);
    /// - the stored value cannot be parsed;
    /// - the Redis command fails. This case is also counted in [`REDIS_PKCE_ERRORS`].
    async fn consume_state_impl(
        &self,
        outbound_token: &str,
    ) -> Result<ConsumedPkceState, PkceError> {
        #[derive(serde::Deserialize)]
        struct StoredEntry {
            verifier: String,
            redirect_uri: String,
        }

        let internal_key = match &self.encryptor {
            Some(enc) => {
                let bytes = enc.decrypt(outbound_token).map_err(|_| PkceError::StateNotFound)?;
                String::from_utf8(bytes).map_err(|_| PkceError::StateNotFound)?
            },
            None => outbound_token.to_owned(),
        };

        let redis_key = format!("fraiseql:pkce:{internal_key}");
        let mut conn = self.pool.clone();

        // GETDEL (Redis >= 6.2) reads and deletes in one atomic step. This makes
        // consumption one-shot even when several replicas race on the same token.
        let raw: Option<String> = redis::cmd("GETDEL")
            .arg(&redis_key)
            .query_async(&mut conn)
            .await
            .map_err(|_| {
                // Callers see this as an ordinary missing state. Count it so that Redis
                // outages remain visible.
                REDIS_PKCE_ERRORS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                PkceError::StateNotFound
            })?;

        let json = raw.ok_or(PkceError::StateNotFound)?;

        let entry: StoredEntry =
            serde_json::from_str(&json).map_err(|_| PkceError::StateNotFound)?;

        Ok(ConsumedPkceState {
            verifier: entry.verifier,
            redirect_uri: entry.redirect_uri,
        })
    }
}
302
303#[non_exhaustive]
324pub enum PkceStateStore {
325 InMemory(InMemoryPkceStateStore),
327 #[cfg(feature = "redis-pkce")]
329 Redis(RedisPkceStateStore),
330}
331
332impl PkceStateStore {
333 pub fn new(state_ttl_secs: u64, encryptor: Option<Arc<StateEncryptionService>>) -> Self {
335 Self::InMemory(InMemoryPkceStateStore::new(state_ttl_secs, encryptor))
336 }
337
338 #[cfg(feature = "redis-pkce")]
344 pub async fn new_redis(
345 url: &str,
346 state_ttl_secs: u64,
347 encryptor: Option<Arc<StateEncryptionService>>,
348 ) -> Result<Self, redis::RedisError> {
349 let inner = RedisPkceStateStore::new(url, state_ttl_secs, encryptor).await?;
350 Ok(Self::Redis(inner))
351 }
352
353 pub const fn is_in_memory(&self) -> bool {
357 matches!(self, Self::InMemory(_))
358 }
359
360 pub async fn create_state(
372 &self,
373 redirect_uri: &str,
374 ) -> Result<(String, String), anyhow::Error> {
375 match self {
376 Self::InMemory(s) => s.create_state_sync(redirect_uri),
377 #[cfg(feature = "redis-pkce")]
378 Self::Redis(s) => s.create_state_impl(redirect_uri).await,
379 }
380 }
381
382 pub async fn consume_state(
398 &self,
399 outbound_token: &str,
400 ) -> Result<ConsumedPkceState, PkceError> {
401 match self {
402 Self::InMemory(s) => s.consume_state_sync(outbound_token),
403 #[cfg(feature = "redis-pkce")]
404 Self::Redis(s) => s.consume_state_impl(outbound_token).await,
405 }
406 }
407
408 pub fn s256_challenge(verifier: &str) -> String {
414 URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()))
415 }
416
417 pub async fn cleanup_expired(&self) {
422 match self {
423 Self::InMemory(s) => s.cleanup_expired_sync(),
424 #[cfg(feature = "redis-pkce")]
425 Self::Redis(_) => {}, }
427 }
428
429 pub fn len(&self) -> usize {
433 match self {
434 Self::InMemory(s) => s.len_sync(),
435 #[cfg(feature = "redis-pkce")]
436 Self::Redis(_) => 0,
437 }
438 }
439
440 pub fn is_empty(&self) -> bool {
444 self.len() == 0
445 }
446}
447
// Tests unwrap freely: panicking on an unexpected `Err` is the desired failure mode here.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod tests {
    use std::time::Duration;

    #[allow(clippy::wildcard_imports)]
    use super::*;
    use crate::state_encryption::{EncryptionAlgorithm, StateEncryptionService};

    /// In-memory store without token encryption.
    fn store_no_enc(ttl_secs: u64) -> PkceStateStore {
        PkceStateStore::new(ttl_secs, None)
    }

    /// ChaCha20-Poly1305 encryptor with a fixed all-zero key (tests only).
    fn enc_service() -> Arc<StateEncryptionService> {
        Arc::new(StateEncryptionService::from_raw_key(
            &[0u8; 32],
            EncryptionAlgorithm::Chacha20Poly1305,
        ))
    }

    /// Redis URL for the ignored integration tests. Override it with `REDIS_URL`.
    #[cfg(feature = "redis-pkce")]
    fn redis_url() -> String {
        std::env::var("REDIS_URL").unwrap_or_else(|_| "redis://localhost:6379".to_string())
    }

    #[tokio::test]
    async fn test_create_and_consume_roundtrip() {
        let store = store_no_enc(600);
        let (token, verifier) = store.create_state("https://app.example.com/cb").await.unwrap();
        let result = store.consume_state(&token).await.unwrap();
        assert_eq!(result.verifier, verifier);
        assert_eq!(result.redirect_uri, "https://app.example.com/cb");
    }

    #[tokio::test]
    async fn test_consume_removes_entry_cannot_reuse() {
        let store = store_no_enc(600);
        let (token, _) = store.create_state("https://app.example.com/cb").await.unwrap();
        store.consume_state(&token).await.unwrap();
        assert!(
            matches!(store.consume_state(&token).await, Err(PkceError::StateNotFound)),
            "second consume must return StateNotFound"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn test_expired_state_returns_state_expired_not_not_found() {
        let store = store_no_enc(1);
        let (token, _) = store.create_state("https://example.com").await.unwrap();
        tokio::time::advance(Duration::from_millis(1100)).await;
        assert!(
            matches!(store.consume_state(&token).await, Err(PkceError::StateExpired)),
            "expired state must be StateExpired, not StateNotFound"
        );
    }

    #[tokio::test]
    async fn test_unknown_token_returns_not_found() {
        let store = store_no_enc(600);
        assert!(matches!(
            store.consume_state("completely-unknown-token").await,
            Err(PkceError::StateNotFound)
        ));
    }

    #[tokio::test]
    async fn test_two_distinct_states_dont_interfere() {
        let store = store_no_enc(600);
        let (t1, v1) = store.create_state("https://a.example.com/cb").await.unwrap();
        let (t2, v2) = store.create_state("https://b.example.com/cb").await.unwrap();
        let r2 = store.consume_state(&t2).await.unwrap();
        let r1 = store.consume_state(&t1).await.unwrap();
        assert_eq!(r1.verifier, v1);
        assert_eq!(r2.verifier, v2);
    }

    #[test]
    fn test_s256_challenge_matches_rfc7636_appendix_a() {
        let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
        let expected = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";
        assert_eq!(PkceStateStore::s256_challenge(verifier), expected);
    }

    #[tokio::test]
    async fn test_verifier_length_and_charset_are_rfc7636_compliant() {
        let store = store_no_enc(600);
        let (_, verifier) = store.create_state("https://example.com").await.unwrap();
        assert!(
            verifier.len() >= 43 && verifier.len() <= 128,
            "verifier length {} is outside the 43–128 char range",
            verifier.len()
        );
        assert!(!verifier.contains('='), "verifier must not contain padding characters");
    }

    #[tokio::test]
    async fn test_encrypted_token_is_longer_than_raw_internal_key() {
        let store = PkceStateStore::new(600, Some(enc_service()));
        let (token, _) = store.create_state("https://app.example.com/cb").await.unwrap();
        assert!(
            token.len() > 43,
            "encrypted token (len={}) must be longer than a raw 32-byte key (43 chars)",
            token.len()
        );
    }

    #[tokio::test]
    async fn test_encrypted_roundtrip_works_end_to_end() {
        let store = PkceStateStore::new(600, Some(enc_service()));
        let (token, verifier) = store.create_state("https://app.example.com/cb").await.unwrap();
        let result = store.consume_state(&token).await.unwrap();
        assert_eq!(result.verifier, verifier);
    }

    #[tokio::test]
    async fn test_tampered_encrypted_token_returns_not_found() {
        let store = PkceStateStore::new(600, Some(enc_service()));
        store.create_state("https://app.example.com/cb").await.unwrap();
        let result = store.consume_state("aGVsbG8gd29ybGQ").await;
        assert!(
            matches!(result, Err(PkceError::StateNotFound)),
            "tampered token must yield StateNotFound, not an internal error"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn test_consume_at_exact_ttl_boundary_succeeds() {
        let store = store_no_enc(2);
        let (token, verifier) = store.create_state("https://example.com").await.unwrap();
        tokio::time::advance(Duration::from_secs(2)).await;
        let result = store.consume_state(&token).await.unwrap();
        assert_eq!(result.verifier, verifier, "state at exact TTL boundary must still be valid");
    }

    #[test]
    fn test_is_in_memory_returns_true_for_in_memory_store() {
        let store = PkceStateStore::new(600, None);
        assert!(store.is_in_memory());
    }

    #[tokio::test]
    async fn test_is_empty_true_for_fresh_store() {
        let store = store_no_enc(600);
        assert!(store.is_empty(), "fresh store must be empty");
        assert_eq!(store.len(), 0);
    }

    #[tokio::test]
    async fn test_is_empty_false_after_create() {
        let store = store_no_enc(600);
        store.create_state("https://example.com").await.unwrap();
        assert!(!store.is_empty(), "store with one entry must not be empty");
        assert_eq!(store.len(), 1);
    }

    #[tokio::test]
    async fn test_is_empty_true_after_consume() {
        let store = store_no_enc(600);
        let (token, _) = store.create_state("https://example.com").await.unwrap();
        store.consume_state(&token).await.unwrap();
        assert!(store.is_empty(), "store must be empty after consuming the only entry");
    }

    // Uses a paused clock, like the other TTL tests, so the test does not wait in real time.
    #[tokio::test(start_paused = true)]
    async fn test_cleanup_removes_expired_leaves_valid() {
        let store = store_no_enc(1);
        store.create_state("https://a.example.com").await.unwrap();
        tokio::time::advance(Duration::from_millis(1100)).await;
        store.cleanup_expired().await;
        assert_eq!(store.len(), 0, "expired entry must be removed by cleanup");

        let store2 = store_no_enc(600);
        store2.create_state("https://b.example.com").await.unwrap();
        store2.cleanup_expired().await;
        assert_eq!(store2.len(), 1, "unexpired entry must survive cleanup");
    }

    #[cfg(feature = "redis-pkce")]
    #[tokio::test]
    #[ignore = "requires Redis — set REDIS_URL=redis://localhost:6379"]
    async fn test_redis_pkce_create_and_consume_roundtrip() {
        let store = PkceStateStore::new_redis(&redis_url(), 300, None)
            .await
            .expect("Redis connection failed");

        let (token, verifier) = store.create_state("https://example.com/cb").await.unwrap();
        let consumed = store.consume_state(&token).await.unwrap();
        assert_eq!(consumed.verifier, verifier);
        assert_eq!(consumed.redirect_uri, "https://example.com/cb");
    }

    #[cfg(feature = "redis-pkce")]
    #[tokio::test]
    #[ignore = "requires Redis — set REDIS_URL=redis://localhost:6379"]
    async fn test_redis_pkce_one_shot_consumption() {
        let store = PkceStateStore::new_redis(&redis_url(), 300, None)
            .await
            .expect("Redis connection failed");

        let (token, _) = store.create_state("https://example.com/cb").await.unwrap();
        store.consume_state(&token).await.unwrap();

        let second = store.consume_state(&token).await;
        assert!(
            matches!(second, Err(PkceError::StateNotFound)),
            "second consume must return StateNotFound — GETDEL guarantees one-shot"
        );
    }

    #[cfg(feature = "redis-pkce")]
    #[tokio::test]
    #[ignore = "requires Redis — set REDIS_URL=redis://localhost:6379"]
    async fn test_redis_pkce_two_instances_share_state() {
        let url = redis_url();

        let store_a = PkceStateStore::new_redis(&url, 300, None)
            .await
            .expect("Redis connection failed");
        let store_b = PkceStateStore::new_redis(&url, 300, None)
            .await
            .expect("Redis connection failed");

        let (token, verifier) = store_a.create_state("https://example.com/cb").await.unwrap();

        let consumed = store_b.consume_state(&token).await.unwrap();
        assert_eq!(
            consumed.verifier, verifier,
            "cross-replica consumption must succeed with shared Redis"
        );
    }

    #[cfg(feature = "redis-pkce")]
    #[tokio::test]
    #[ignore = "requires Redis — set REDIS_URL=redis://localhost:6379"]
    async fn test_redis_pkce_tampered_token_rejected() {
        let store = PkceStateStore::new_redis(&redis_url(), 300, Some(enc_service()))
            .await
            .expect("Redis connection failed");

        store.create_state("https://example.com/cb").await.unwrap();

        let result = store.consume_state("completely-fabricated-token").await;
        assert!(
            matches!(result, Err(PkceError::StateNotFound)),
            "tampered token must be rejected"
        );
    }
}