Skip to main content

fraiseql_auth/
state_encryption.rs

1//! State encryption for PKCE and OAuth state parameter protection.
2//!
3//! Encrypts OAuth `state` (and PKCE) blobs with AEAD ciphers so that the
4//! outbound token sent to the identity provider cannot be deciphered or
5//! tampered with by an attacker who intercepts the redirect.
6//!
//! [`StateEncryptionService`] supports two algorithms selectable at runtime
//! (the lower-level [`StateEncryption`] always uses ChaCha20-Poly1305):
8//! - [`EncryptionAlgorithm::Chacha20Poly1305`] (default, constant-time in software)
9//! - [`EncryptionAlgorithm::Aes256Gcm`] (hardware-accelerated on modern CPUs)
10
11use std::{fmt, sync::Arc};
12
13// aes_gcm and chacha20poly1305 both re-export the same underlying `aead` traits.
14// We import them once from chacha20poly1305 and reuse for both cipher types.
15use aes_gcm::Aes256Gcm;
16use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
17use chacha20poly1305::{
18    ChaCha20Poly1305, Nonce,
19    aead::{Aead, AeadCore, KeyInit, OsRng, Payload},
20};
21use rand::RngCore;
22use serde::{Deserialize, Serialize};
23use zeroize::Zeroizing;
24
25use crate::{AuthError, error::Result};
26
/// An AEAD ciphertext bundled with the nonce required to open it.
#[derive(Debug, Clone)]
pub struct EncryptedState {
    /// Encrypted bytes, with the authentication tag at the end.
    pub ciphertext: Vec<u8>,
    /// The 96-bit (12-byte) nonce the ciphertext was sealed with.
    pub nonce:      [u8; 12],
}
35
36impl EncryptedState {
37    /// Create new encrypted state
38    pub const fn new(ciphertext: Vec<u8>, nonce: [u8; 12]) -> Self {
39        Self { ciphertext, nonce }
40    }
41
42    /// Serialize to bytes for storage
43    /// Format: [12-byte nonce][ciphertext with auth tag]
44    pub fn to_bytes(&self) -> Vec<u8> {
45        let mut bytes = Vec::with_capacity(12 + self.ciphertext.len());
46        bytes.extend_from_slice(&self.nonce);
47        bytes.extend_from_slice(&self.ciphertext);
48        bytes
49    }
50
51    /// Deserialize from bytes.
52    ///
53    /// # Errors
54    ///
55    /// Returns [`AuthError::InvalidState`] if `bytes` is shorter than 12 bytes
56    /// (minimum nonce size).
57    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
58        if bytes.len() < 12 {
59            return Err(AuthError::InvalidState);
60        }
61
62        let mut nonce = [0u8; 12];
63        nonce.copy_from_slice(&bytes[0..12]);
64        let ciphertext = bytes[12..].to_vec();
65
66        Ok(Self::new(ciphertext, nonce))
67    }
68}
69
70/// State encryption using ChaCha20-Poly1305 AEAD
71///
72/// Provides authenticated encryption for OAuth state parameters.
73/// Uses a fixed encryption key for the deployment lifetime.
74/// Each encryption uses a random nonce for security.
75///
76/// # Security Properties
77/// - **Confidentiality**: State values are encrypted with ChaCha20
78/// - **Authenticity**: Authentication tag prevents tampering detection
79/// - **Replay Prevention**: Random nonce in each encryption
80/// - **Key Isolation**: Separate from signing keys, used only for state
81pub struct StateEncryption {
82    cipher: ChaCha20Poly1305,
83}
84
85impl StateEncryption {
86    /// Create a new state encryption instance
87    ///
88    /// # Arguments
89    /// * `key` - 32-byte encryption key (must be cryptographically random)
90    ///
91    /// # Errors
92    /// Returns error if key is invalid
93    pub fn new(key_bytes: &[u8; 32]) -> Result<Self> {
94        let cipher =
95            ChaCha20Poly1305::new_from_slice(key_bytes).map_err(|_| AuthError::ConfigError {
96                message: "Invalid state encryption key".to_string(),
97            })?;
98
99        Ok(Self { cipher })
100    }
101
102    /// Encrypt a state value
103    ///
104    /// Generates a random 96-bit nonce and encrypts the state using ChaCha20-Poly1305.
105    /// The authentication tag is appended to the ciphertext.
106    ///
107    /// # Arguments
108    /// * `state` - The plaintext state value to encrypt
109    ///
110    /// # Returns
111    /// EncryptedState containing ciphertext and nonce
112    ///
113    /// # Errors
114    /// Returns error if encryption fails (should be rare)
115    pub fn encrypt(&self, state: &str) -> Result<EncryptedState> {
116        // Generate random 96-bit nonce
117        let mut nonce_bytes = [0u8; 12];
118        rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
119        let nonce = Nonce::from(nonce_bytes);
120
121        // Encrypt with AEAD (includes authentication tag)
122        let ciphertext =
123            self.cipher.encrypt(&nonce, Payload::from(state.as_bytes())).map_err(|_| {
124                AuthError::Internal {
125                    message: "State encryption failed".to_string(),
126                }
127            })?;
128
129        Ok(EncryptedState::new(ciphertext, nonce_bytes))
130    }
131
132    /// Decrypt and verify a state value
133    ///
134    /// Uses the nonce from EncryptedState to decrypt the ciphertext.
135    /// Authentication tag verification is automatic - tampering is detected.
136    ///
137    /// # Arguments
138    /// * `encrypted` - The encrypted state to decrypt
139    ///
140    /// # Returns
141    /// The decrypted plaintext state value
142    ///
143    /// # Errors
144    /// Returns error if:
145    /// - Authentication tag verification fails (tampering detected)
146    /// - Decryption fails
147    /// - Result is not valid UTF-8
148    pub fn decrypt(&self, encrypted: &EncryptedState) -> Result<String> {
149        let nonce = Nonce::from(encrypted.nonce);
150
151        // Decrypt and verify authentication tag
152        let plaintext = self
153            .cipher
154            .decrypt(&nonce, Payload::from(encrypted.ciphertext.as_slice()))
155            .map_err(|_| AuthError::InvalidState)?;
156
157        // Convert bytes to UTF-8 string
158        String::from_utf8(plaintext).map_err(|_| AuthError::InvalidState)
159    }
160
161    /// Encrypt state and serialize to bytes.
162    ///
163    /// # Errors
164    ///
165    /// Returns [`AuthError::Internal`] if AEAD encryption fails (essentially never).
166    pub fn encrypt_to_bytes(&self, state: &str) -> Result<Vec<u8>> {
167        let encrypted = self.encrypt(state)?;
168        Ok(encrypted.to_bytes())
169    }
170
171    /// Decrypt state from serialized bytes.
172    ///
173    /// # Errors
174    ///
175    /// Returns [`AuthError::InvalidState`] if `bytes` is too short, if AEAD
176    /// authentication fails (tampered or wrong key), or if decrypted bytes are
177    /// not valid UTF-8.
178    pub fn decrypt_from_bytes(&self, bytes: &[u8]) -> Result<String> {
179        let encrypted = EncryptedState::from_bytes(bytes)?;
180        self.decrypt(&encrypted)
181    }
182}
183
184/// Generate a cryptographically random encryption key
185pub fn generate_state_encryption_key() -> Zeroizing<[u8; 32]> {
186    let mut key = [0u8; 32];
187    rand::rngs::OsRng.fill_bytes(&mut key);
188    Zeroizing::new(key)
189}
190
191// ── StateEncryptionService ────────────────────────────────────────────────────
192//
193// A higher-level service that wraps the low-level `StateEncryption` struct.
194// Differences from `StateEncryption`:
195//   - Supports both ChaCha20-Poly1305 AND AES-256-GCM (runtime-selectable)
196//   - Wire format: URL-safe base64 of `[12-byte nonce || ciphertext || tag]`
197//   - Accepts keys as 64-char hex strings or env-var names
198//   - Can be constructed from the compiled schema JSON
199//   - Key never appears in `Debug` output
200//
201// This is the PKCE state encryption service wired into `Server`.
202
203/// Errors that can occur during decryption by `StateEncryptionService`.
204#[derive(Debug, thiserror::Error)]
205#[non_exhaustive]
206pub enum DecryptionError {
207    /// Ciphertext was tampered with or encrypted with a different key.
208    #[error("authentication failed — ciphertext may be tampered or key is wrong")]
209    AuthenticationFailed,
210    /// Input is malformed (empty, too short, bad base64, etc.).
211    #[error("invalid input: {0}")]
212    InvalidInput(String),
213}
214
215/// Errors that can occur when constructing a `StateEncryptionService` key.
216#[derive(Debug, thiserror::Error)]
217#[non_exhaustive]
218pub enum KeyError {
219    /// Hex string was not 64 characters (32 bytes).
220    #[error("hex key must be 64 chars (32 bytes); got {0} chars")]
221    WrongLength(usize),
222    /// Hex string contained a non-hex character.
223    #[error("invalid hex character in key")]
224    InvalidHex,
225}
226
227/// AEAD algorithm selection for `StateEncryptionService`.
228#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq, Eq)]
229#[non_exhaustive]
230pub enum EncryptionAlgorithm {
231    /// ChaCha20-Poly1305 (recommended — constant-time, software-friendly).
232    #[default]
233    #[serde(rename = "chacha20-poly1305")]
234    Chacha20Poly1305,
235    /// AES-256-GCM (hardware-accelerated on modern CPUs).
236    #[serde(rename = "aes-256-gcm")]
237    Aes256Gcm,
238}
239
240impl fmt::Display for EncryptionAlgorithm {
241    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
242        match self {
243            Self::Chacha20Poly1305 => f.write_str("chacha20-poly1305"),
244            Self::Aes256Gcm => f.write_str("aes-256-gcm"),
245        }
246    }
247}
248
249/// Deserialized from `compiled.security.state_encryption`.
250#[derive(Debug, Clone, Deserialize, Serialize)]
251#[serde(default)]
252pub struct StateEncryptionConfig {
253    /// Enable the service; when `false`, `from_compiled_schema` returns `None`.
254    pub enabled:   bool,
255    /// AEAD algorithm to use.
256    pub algorithm: EncryptionAlgorithm,
257    /// Name of the environment variable holding the 64-char hex key.
258    pub key_env:   Option<String>,
259}
260
261impl Default for StateEncryptionConfig {
262    fn default() -> Self {
263        Self {
264            enabled:   false,
265            algorithm: EncryptionAlgorithm::default(),
266            key_env:   Some("STATE_ENCRYPTION_KEY".to_string()),
267        }
268    }
269}
270
271/// AEAD encryption service for OAuth state and PKCE blobs.
272///
273/// Wire format: URL-safe base64 of `[12-byte nonce || ciphertext || 16-byte tag]`.
274///
275/// The 32-byte key is never printed in [`fmt::Debug`] output.
276pub struct StateEncryptionService {
277    algorithm: EncryptionAlgorithm,
278    key:       [u8; 32],
279}
280
281impl fmt::Debug for StateEncryptionService {
282    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
283        f.debug_struct("StateEncryptionService")
284            .field("algorithm", &self.algorithm)
285            .field("key", &"[REDACTED]")
286            .finish()
287    }
288}
289
290impl StateEncryptionService {
291    /// Construct from a raw 32-byte key slice.
292    pub const fn from_raw_key(key: &[u8; 32], algorithm: EncryptionAlgorithm) -> Self {
293        Self {
294            algorithm,
295            key: *key,
296        }
297    }
298
299    /// Construct from a 64-character hex string (= 32 bytes).
300    ///
301    /// # Errors
302    ///
303    /// Returns [`KeyError::WrongLength`] if `hex` is not 64 chars.
304    /// Returns [`KeyError::InvalidHex`] if `hex` contains non-hex chars.
305    pub fn from_hex_key(
306        hex: &str,
307        algorithm: EncryptionAlgorithm,
308    ) -> std::result::Result<Self, KeyError> {
309        if hex.len() != 64 {
310            return Err(KeyError::WrongLength(hex.len()));
311        }
312        let bytes = hex::decode(hex).map_err(|_| KeyError::InvalidHex)?;
313        let mut key = [0u8; 32];
314        key.copy_from_slice(&bytes);
315        Ok(Self { algorithm, key })
316    }
317
318    /// Load the key from an environment variable containing a 64-char hex string.
319    ///
320    /// # Errors
321    ///
322    /// Returns an error if the env var is absent or the value is not valid hex/length.
323    pub fn new_from_env(
324        var: &str,
325        algorithm: EncryptionAlgorithm,
326    ) -> std::result::Result<Self, anyhow::Error> {
327        let hex = std::env::var(var).map_err(|_| anyhow::anyhow!("env var {var} not set"))?;
328        Ok(Self::from_hex_key(&hex, algorithm)?)
329    }
330
331    /// Build from the `security` blob of a compiled schema, if enabled.
332    ///
333    /// Returns `Ok(None)` when the `state_encryption` key is absent or `enabled = false`.
334    ///
335    /// # Errors
336    ///
337    /// Returns `Err` when `enabled = true` but the key environment variable is absent
338    /// or contains an invalid value.  The server must refuse to start in this case.
339    pub fn from_compiled_schema(
340        security_json: &serde_json::Value,
341    ) -> std::result::Result<Option<Arc<Self>>, anyhow::Error> {
342        let cfg: StateEncryptionConfig = match security_json.get("state_encryption") {
343            None => return Ok(None),
344            Some(v) => serde_json::from_value(v.clone())
345                .map_err(|e| anyhow::anyhow!("invalid state_encryption config: {e}"))?,
346        };
347
348        if !cfg.enabled {
349            return Ok(None);
350        }
351
352        let key_env = cfg.key_env.as_deref().unwrap_or("STATE_ENCRYPTION_KEY");
353        Self::new_from_env(key_env, cfg.algorithm)
354            .map(|svc| Some(Arc::new(svc)))
355            .map_err(|e| {
356                anyhow::anyhow!(
357                    "state_encryption enabled but key env var '{}' failed: {e}",
358                    key_env
359                )
360            })
361    }
362
363    /// Encrypt `plaintext` to a URL-safe base64 string.
364    ///
365    /// A fresh random nonce is generated on every call.
366    ///
367    /// # Errors
368    ///
369    /// Returns an error only on internal cipher failure (essentially never).
370    pub fn encrypt(&self, plaintext: &[u8]) -> std::result::Result<String, anyhow::Error> {
371        let combined = match self.algorithm {
372            EncryptionAlgorithm::Chacha20Poly1305 => {
373                let cipher = ChaCha20Poly1305::new_from_slice(&self.key)
374                    .map_err(|_| anyhow::anyhow!("invalid key for ChaCha20-Poly1305"))?;
375                let nonce = ChaCha20Poly1305::generate_nonce(&mut OsRng);
376                let ct = cipher
377                    .encrypt(&nonce, plaintext)
378                    .map_err(|_| anyhow::anyhow!("ChaCha20-Poly1305 encryption failed"))?;
379                let mut out = nonce.to_vec();
380                out.extend_from_slice(&ct);
381                out
382            },
383            EncryptionAlgorithm::Aes256Gcm => {
384                let cipher = Aes256Gcm::new_from_slice(&self.key)
385                    .map_err(|_| anyhow::anyhow!("invalid key for AES-256-GCM"))?;
386                let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
387                let ct = cipher
388                    .encrypt(&nonce, plaintext)
389                    .map_err(|_| anyhow::anyhow!("AES-256-GCM encryption failed"))?;
390                let mut out = nonce.to_vec();
391                out.extend_from_slice(&ct);
392                out
393            },
394        };
395        Ok(URL_SAFE_NO_PAD.encode(&combined))
396    }
397
398    /// Decrypt a URL-safe base64 string produced by [`Self::encrypt`].
399    ///
400    /// # Errors
401    ///
402    /// - [`DecryptionError::InvalidInput`] — empty / too-short / bad base64
403    /// - [`DecryptionError::AuthenticationFailed`] — tampered or wrong-key
404    pub fn decrypt(&self, encoded: &str) -> std::result::Result<Vec<u8>, DecryptionError> {
405        const NONCE_SIZE: usize = 12;
406        if encoded.is_empty() {
407            return Err(DecryptionError::InvalidInput("empty input".into()));
408        }
409        let combined = URL_SAFE_NO_PAD
410            .decode(encoded)
411            .map_err(|_| DecryptionError::InvalidInput("invalid base64".into()))?;
412
413        if combined.len() < NONCE_SIZE {
414            return Err(DecryptionError::InvalidInput(format!(
415                "too short: {} bytes (minimum {NONCE_SIZE})",
416                combined.len()
417            )));
418        }
419        let (nonce_bytes, ct) = combined.split_at(NONCE_SIZE);
420
421        match self.algorithm {
422            EncryptionAlgorithm::Chacha20Poly1305 => {
423                let cipher = ChaCha20Poly1305::new_from_slice(&self.key)
424                    .map_err(|_| DecryptionError::InvalidInput("invalid key".into()))?;
425                let nonce = chacha20poly1305::Nonce::from_slice(nonce_bytes);
426                cipher.decrypt(nonce, ct).map_err(|_| DecryptionError::AuthenticationFailed)
427            },
428            EncryptionAlgorithm::Aes256Gcm => {
429                let cipher = Aes256Gcm::new_from_slice(&self.key)
430                    .map_err(|_| DecryptionError::InvalidInput("invalid key".into()))?;
431                let nonce = aes_gcm::Nonce::from_slice(nonce_bytes);
432                cipher.decrypt(nonce, ct).map_err(|_| DecryptionError::AuthenticationFailed)
433            },
434        }
435    }
436}
437
#[allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable
#[cfg(test)]
mod service_tests {
    #[allow(clippy::wildcard_imports)]
    // Reason: test module — wildcard keeps test boilerplate minimal
    use super::*;

    // NOTE(review): the env-var tests below call `std::env::set_var` /
    // `remove_var`. These are `unsafe` under the 2024 edition and can race with
    // environment reads on other test threads. Each test uses a unique variable
    // name to limit interference.

    fn chacha_svc() -> StateEncryptionService {
        StateEncryptionService::from_raw_key(&[0u8; 32], EncryptionAlgorithm::Chacha20Poly1305)
    }
    fn aes_svc() -> StateEncryptionService {
        StateEncryptionService::from_raw_key(&[0u8; 32], EncryptionAlgorithm::Aes256Gcm)
    }

    #[test]
    fn test_chacha_encrypt_decrypt_roundtrip() {
        let svc = chacha_svc();
        let pt = b"oauth_state_nonce_12345";
        assert_eq!(svc.decrypt(&svc.encrypt(pt).unwrap()).unwrap(), pt);
    }

    #[test]
    fn test_chacha_two_encryptions_differ() {
        let svc = chacha_svc();
        assert_ne!(svc.encrypt(b"hello").unwrap(), svc.encrypt(b"hello").unwrap());
    }

    #[test]
    fn test_chacha_tampered_fails() {
        let svc = chacha_svc();
        let ct = svc.encrypt(b"secret").unwrap();
        let mut bytes = URL_SAFE_NO_PAD.decode(&ct).unwrap();
        bytes[15] ^= 0xFF;
        let tampered = URL_SAFE_NO_PAD.encode(&bytes);
        assert!(matches!(svc.decrypt(&tampered), Err(DecryptionError::AuthenticationFailed)));
    }

    #[test]
    fn test_chacha_wrong_key_fails() {
        let a =
            StateEncryptionService::from_raw_key(&[0u8; 32], EncryptionAlgorithm::Chacha20Poly1305);
        let b =
            StateEncryptionService::from_raw_key(&[1u8; 32], EncryptionAlgorithm::Chacha20Poly1305);
        let ct = a.encrypt(b"secret").unwrap();
        assert!(matches!(b.decrypt(&ct), Err(DecryptionError::AuthenticationFailed)));
    }

    #[test]
    fn test_aes_encrypt_decrypt_roundtrip() {
        let svc = aes_svc();
        let pt = b"pkce_code_challenge";
        assert_eq!(svc.decrypt(&svc.encrypt(pt).unwrap()).unwrap(), pt);
    }

    #[test]
    fn test_aes_two_encryptions_differ() {
        let svc = aes_svc();
        assert_ne!(svc.encrypt(b"hello").unwrap(), svc.encrypt(b"hello").unwrap());
    }

    #[test]
    fn test_aes_tampered_fails() {
        let svc = aes_svc();
        let ct = svc.encrypt(b"secret").unwrap();
        let mut bytes = URL_SAFE_NO_PAD.decode(&ct).unwrap();
        bytes[15] ^= 0xFF;
        let tampered = URL_SAFE_NO_PAD.encode(&bytes);
        assert!(matches!(svc.decrypt(&tampered), Err(DecryptionError::AuthenticationFailed)));
    }

    #[test]
    fn test_aes_wrong_key_fails() {
        let a = StateEncryptionService::from_raw_key(&[0u8; 32], EncryptionAlgorithm::Aes256Gcm);
        let b = StateEncryptionService::from_raw_key(&[1u8; 32], EncryptionAlgorithm::Aes256Gcm);
        let ct = a.encrypt(b"secret").unwrap();
        assert!(matches!(b.decrypt(&ct), Err(DecryptionError::AuthenticationFailed)));
    }

    #[test]
    fn test_empty_ciphertext_invalid_input() {
        assert!(matches!(chacha_svc().decrypt(""), Err(DecryptionError::InvalidInput(_))));
    }

    #[test]
    fn test_too_short_invalid_input() {
        let short = URL_SAFE_NO_PAD.encode([0u8; 11]);
        assert!(matches!(chacha_svc().decrypt(&short), Err(DecryptionError::InvalidInput(_))));
    }

    #[test]
    fn test_bad_base64_invalid_input() {
        assert!(matches!(
            chacha_svc().decrypt("not!valid@base64#"),
            Err(DecryptionError::InvalidInput(_))
        ));
    }

    #[test]
    fn test_from_hex_key_valid() {
        let hex = "00".repeat(32);
        StateEncryptionService::from_hex_key(&hex, EncryptionAlgorithm::Chacha20Poly1305)
            .unwrap_or_else(|e| panic!("expected Ok for valid 64-char hex key: {e}"));
    }

    #[test]
    fn test_from_hex_key_wrong_length() {
        assert!(matches!(
            StateEncryptionService::from_hex_key("deadbeef", EncryptionAlgorithm::Chacha20Poly1305),
            Err(KeyError::WrongLength(_))
        ));
    }

    #[test]
    fn test_from_hex_key_invalid_hex() {
        let bad = "zz".repeat(32);
        assert!(matches!(
            StateEncryptionService::from_hex_key(&bad, EncryptionAlgorithm::Chacha20Poly1305),
            Err(KeyError::InvalidHex)
        ));
    }

    #[test]
    fn test_debug_redacts_key() {
        // Use a distinctive byte value. A leaked `[u8; 32]` would render as
        // `[171, 171, …]`, which a zero key's `[0, 0, …]` could never reveal
        // through a substring check.
        let svc = StateEncryptionService::from_raw_key(
            &[0xAB; 32],
            EncryptionAlgorithm::Chacha20Poly1305,
        );
        let s = format!("{svc:?}");
        assert!(!s.contains("171"), "key bytes must not appear in debug output");
        assert!(s.contains("REDACTED"));
    }

    #[test]
    fn test_from_compiled_schema_enabled_missing_key_returns_error() {
        // Use a unique env var name that is guaranteed absent
        std::env::remove_var("FRAISEQL_TEST_MISSING_ENC_KEY_B1");
        let json = serde_json::json!({
            "state_encryption": {
                "enabled": true,
                "algorithm": "chacha20-poly1305",
                "key_env": "FRAISEQL_TEST_MISSING_ENC_KEY_B1"
            }
        });
        let result = StateEncryptionService::from_compiled_schema(&json);
        assert!(result.is_err(), "should error when enabled=true but env var absent");
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("FRAISEQL_TEST_MISSING_ENC_KEY_B1"));
    }

    #[test]
    fn test_from_compiled_schema_enabled() {
        let key_hex = "aa".repeat(32);
        std::env::set_var("TEST_SVC_ENC_KEY_P3", &key_hex);
        let json = serde_json::json!({
            "state_encryption": {
                "enabled": true,
                "algorithm": "chacha20-poly1305",
                "key_env": "TEST_SVC_ENC_KEY_P3"
            }
        });
        let result = StateEncryptionService::from_compiled_schema(&json);
        // Clean up before asserting so a failing assertion does not leak the
        // variable into other tests in the same process.
        std::env::remove_var("TEST_SVC_ENC_KEY_P3");
        let svc = result.expect("should succeed when env var is set");
        assert!(svc.is_some());
    }

    #[test]
    fn test_from_compiled_schema_disabled() {
        let json = serde_json::json!({"state_encryption": {"enabled": false}});
        assert!(
            StateEncryptionService::from_compiled_schema(&json)
                .expect("disabled should be ok")
                .is_none()
        );
    }

    #[test]
    fn test_from_compiled_schema_missing() {
        assert!(
            StateEncryptionService::from_compiled_schema(&serde_json::json!({}))
                .expect("missing should be ok")
                .is_none()
        );
    }

    #[test]
    fn test_cross_algorithm_fails() {
        let chacha =
            StateEncryptionService::from_raw_key(&[0u8; 32], EncryptionAlgorithm::Chacha20Poly1305);
        let aes = StateEncryptionService::from_raw_key(&[0u8; 32], EncryptionAlgorithm::Aes256Gcm);
        let ct = chacha.encrypt(b"cross").unwrap();
        assert!(matches!(aes.decrypt(&ct), Err(DecryptionError::AuthenticationFailed)));
    }
}
629
#[allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable
#[cfg(test)]
mod tests {
    #[allow(clippy::wildcard_imports)]
    // Reason: test module — wildcard keeps test boilerplate minimal
    use super::*;

    /// Length of the Poly1305 authentication tag appended by ChaCha20-Poly1305.
    const TAG_LEN: usize = 16;

    fn test_key() -> [u8; 32] {
        // Use deterministic test key
        [42u8; 32]
    }

    #[test]
    fn test_encrypt_decrypt() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "oauth_state_test_value";

        let encrypted = encryption.encrypt(state).expect("Encryption failed");
        let decrypted = encryption.decrypt(&encrypted).expect("Decryption failed");

        assert_eq!(decrypted, state);
    }

    #[test]
    fn test_encrypt_produces_ciphertext() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "test_state";

        let encrypted = encryption.encrypt(state).expect("Encryption failed");

        // Ciphertext should be different from plaintext (due to ChaCha20 encryption)
        assert_ne!(encrypted.ciphertext, state.as_bytes());
        // Ciphertext is plaintext-length plus the appended authentication tag
        assert_eq!(encrypted.ciphertext.len(), state.len() + TAG_LEN);
    }

    #[test]
    fn test_empty_state() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "";

        let encrypted = encryption.encrypt(state).expect("Encryption failed");
        let decrypted = encryption.decrypt(&encrypted).expect("Decryption failed");

        assert_eq!(decrypted, state);
    }

    #[test]
    fn test_different_keys_fail_decryption() {
        let key1 = [42u8; 32];
        let key2 = [99u8; 32];
        let state = "secret_state";

        let encryption1 = StateEncryption::new(&key1).expect("Init 1 failed");
        let encrypted = encryption1.encrypt(state).expect("Encryption failed");

        let encryption2 = StateEncryption::new(&key2).expect("Init 2 failed");
        let result = encryption2.decrypt(&encrypted);

        // Different key should fail due to authentication tag mismatch
        assert!(
            matches!(result, Err(AuthError::InvalidState)),
            "expected InvalidState for wrong-key decryption, got: {result:?}"
        );
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "tamper_test";

        let mut encrypted = encryption.encrypt(state).expect("Encryption failed");

        // Tamper with ciphertext
        if !encrypted.ciphertext.is_empty() {
            encrypted.ciphertext[0] ^= 0xFF;
        }

        // Should fail due to authentication tag verification
        let result = encryption.decrypt(&encrypted);
        assert!(
            matches!(result, Err(AuthError::InvalidState)),
            "expected InvalidState for tampered ciphertext, got: {result:?}"
        );
    }

    #[test]
    fn test_tampered_nonce_fails() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "nonce_tamper";

        let mut encrypted = encryption.encrypt(state).expect("Encryption failed");

        // Tamper with nonce
        encrypted.nonce[0] ^= 0xFF;

        // Should fail due to authentication tag verification
        let result = encryption.decrypt(&encrypted);
        assert!(
            matches!(result, Err(AuthError::InvalidState)),
            "expected InvalidState for tampered nonce, got: {result:?}"
        );
    }

    #[test]
    fn test_truncated_ciphertext_fails() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "truncation_test";

        let mut encrypted = encryption.encrypt(state).expect("Encryption failed");

        // Truncate (removes auth tag)
        if encrypted.ciphertext.len() > 1 {
            encrypted.ciphertext.truncate(encrypted.ciphertext.len() - 1);
        }

        // Should fail
        let result = encryption.decrypt(&encrypted);
        assert!(
            matches!(result, Err(AuthError::InvalidState)),
            "expected InvalidState for truncated ciphertext, got: {result:?}"
        );
    }

    #[test]
    fn test_serialization() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "serialization_test";

        // Encrypt and serialize
        let bytes = encryption.encrypt_to_bytes(state).expect("Encryption failed");

        // Layout is [12-byte nonce][ciphertext][tag]
        assert_eq!(bytes.len(), 12 + state.len() + TAG_LEN);

        // Deserialize and decrypt
        let decrypted = encryption.decrypt_from_bytes(&bytes).expect("Decryption failed");

        assert_eq!(decrypted, state);
    }

    #[test]
    fn test_random_nonces() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "random_nonce_test";

        let encrypted1 = encryption.encrypt(state).expect("Encryption 1 failed");
        let encrypted2 = encryption.encrypt(state).expect("Encryption 2 failed");

        // Nonces should be different (extremely unlikely to collide)
        assert_ne!(encrypted1.nonce, encrypted2.nonce);

        // Both should decrypt correctly
        let decrypted1 = encryption.decrypt(&encrypted1).expect("Decryption 1 failed");
        let decrypted2 = encryption.decrypt(&encrypted2).expect("Decryption 2 failed");

        assert_eq!(decrypted1, state);
        assert_eq!(decrypted2, state);
    }

    #[test]
    fn test_long_state() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "a".repeat(10_000);

        let encrypted = encryption.encrypt(&state).expect("Encryption failed");
        let decrypted = encryption.decrypt(&encrypted).expect("Decryption failed");

        assert_eq!(decrypted, state);
    }

    #[test]
    fn test_special_characters() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "state:with-special_chars.and/symbols!@#$%^&*()";

        let encrypted = encryption.encrypt(state).expect("Encryption failed");
        let decrypted = encryption.decrypt(&encrypted).expect("Decryption failed");

        assert_eq!(decrypted, state);
    }

    #[test]
    fn test_unicode_state() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "state_with_emoji_🔐_🔒_🔓_and_emoji";

        let encrypted = encryption.encrypt(state).expect("Encryption failed");
        let decrypted = encryption.decrypt(&encrypted).expect("Decryption failed");

        assert_eq!(decrypted, state);
    }

    #[test]
    fn test_null_bytes_in_state() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "state_with\x00null\x00bytes\x00";

        let encrypted = encryption.encrypt(state).expect("Encryption failed");
        let decrypted = encryption.decrypt(&encrypted).expect("Decryption failed");

        assert_eq!(decrypted, state);
    }

    #[test]
    fn test_key_generation() {
        let key1 = generate_state_encryption_key();
        let key2 = generate_state_encryption_key();

        // Keys should be different
        assert_ne!(key1, key2);

        // Both should be valid 32-byte keys
        assert_eq!(key1.len(), 32);
        assert_eq!(key2.len(), 32);

        // Both should work
        let enc1 = StateEncryption::new(&key1).expect("Init 1 failed");
        let enc2 = StateEncryption::new(&key2).expect("Init 2 failed");

        let state = "test";
        let encrypted1 = enc1.encrypt(state).expect("Encryption 1 failed");
        let encrypted2 = enc2.encrypt(state).expect("Encryption 2 failed");

        assert_eq!(enc1.decrypt(&encrypted1).expect("Decryption 1 failed"), state);
        assert_eq!(enc2.decrypt(&encrypted2).expect("Decryption 2 failed"), state);
    }

    #[test]
    fn test_large_ciphertext() {
        let encryption = StateEncryption::new(&test_key()).expect("Init failed");
        let state = "x".repeat(100_000);

        let encrypted = encryption.encrypt(&state).expect("Encryption failed");
        let decrypted = encryption.decrypt(&encrypted).expect("Decryption failed");

        assert_eq!(decrypted, state);
    }
}