Skip to main content

pim_crypto/
session.rs

1//! Session-level symmetric encryption for transport payloads.
2
3use aes_gcm::aead::{Aead, KeyInit};
4use aes_gcm::{Aes256Gcm, Nonce};
5use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
6
/// A single encrypted frame as emitted by `SessionCipher`.
#[derive(Debug, Clone)]
pub struct EncryptedFrame {
    /// The 12-byte nonce this frame was encrypted under.
    pub nonce: [u8; 12],
    /// Ciphertext together with the AES-GCM authentication tag.
    pub ciphertext: Vec<u8>,
}
15
16/// Symmetric cipher for encrypting/decrypting frames within a session.
17///
18/// Uses AES-256-GCM with an incrementing nonce counter to prevent reuse.
19/// The nonce is constructed as: 8-byte random session prefix || 4-byte counter.
20///
21/// Replay protection: `decrypt` tracks the highest accepted counter and rejects
22/// any frame whose counter is ≤ the last accepted value.
23pub struct SessionCipher {
24    cipher: Aes256Gcm,
25    nonce_prefix: [u8; 8],
26    counter: AtomicU32,
27    /// Highest counter value accepted during decryption.
28    /// Initialised to `u64::MAX` (sentinel meaning "no frame received yet").
29    last_recv_counter: AtomicU64,
30}
31
/// Limit on the 4-byte send counter: a frame is only encrypted while the
/// counter is strictly below this value (`u32::MAX - 1`).
const MAX_NONCE_COUNTER: u32 = u32::MAX - 1;
34
35impl SessionCipher {
36    /// Create a new SessionCipher from a 32-byte key and 8-byte nonce prefix.
37    pub fn new(key: &[u8; 32], nonce_prefix: [u8; 8]) -> Self {
38        let cipher = Aes256Gcm::new_from_slice(key).expect("32-byte key is valid for AES-256");
39        Self {
40            cipher,
41            nonce_prefix,
42            counter: AtomicU32::new(0),
43            last_recv_counter: AtomicU64::new(u64::MAX),
44        }
45    }
46
47    /// Encrypt plaintext, returning the nonce and ciphertext.
48    pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedFrame, SessionError> {
49        let count = self.counter.fetch_add(1, Ordering::SeqCst);
50        if count >= MAX_NONCE_COUNTER {
51            return Err(SessionError::NonceExhausted);
52        }
53
54        let nonce_bytes = self.build_nonce(count);
55        let nonce = Nonce::from_slice(&nonce_bytes);
56
57        let ciphertext = self
58            .cipher
59            .encrypt(nonce, plaintext)
60            .map_err(|_| SessionError::EncryptionFailed)?;
61
62        Ok(EncryptedFrame {
63            nonce: nonce_bytes,
64            ciphertext,
65        })
66    }
67
68    /// Decrypt an encrypted frame.
69    ///
70    /// Rejects replayed frames: the counter embedded in `frame.nonce[8..12]` must
71    /// be strictly greater than the last accepted counter.
72    pub fn decrypt(&self, frame: &EncryptedFrame) -> Result<Vec<u8>, SessionError> {
73        let counter = u32::from_be_bytes(
74            frame
75                .nonce
76                .get(8..12)
77                .ok_or(SessionError::InvalidNonce)?
78                .try_into()
79                .map_err(|_| SessionError::InvalidNonce)?,
80        ) as u64;
81        let last = self.last_recv_counter.load(Ordering::SeqCst);
82        if last != u64::MAX && counter <= last {
83            return Err(SessionError::ReplayedNonce);
84        }
85
86        let nonce = Nonce::from_slice(&frame.nonce);
87        let plaintext = self
88            .cipher
89            .decrypt(nonce, frame.ciphertext.as_ref())
90            .map_err(|_| SessionError::DecryptionFailed)?;
91
92        self.last_recv_counter.store(counter, Ordering::SeqCst);
93        Ok(plaintext)
94    }
95
96    /// Encrypts plaintext in-place, returning the generated nonce and tag.
97    /// Returns an error if the nonce counter is exhausted or encryption fails.
98    pub fn encrypt_in_place_detached(
99        &self,
100        payload: &mut [u8],
101    ) -> Result<([u8; 12], [u8; 16]), SessionError> {
102        let count = self.counter.fetch_add(1, Ordering::SeqCst);
103        if count >= MAX_NONCE_COUNTER {
104            return Err(SessionError::NonceExhausted);
105        }
106
107        let nonce_bytes = self.build_nonce(count);
108        let nonce = Nonce::from_slice(&nonce_bytes);
109
110        let tag = aes_gcm::aead::AeadInPlace::encrypt_in_place_detached(
111            &self.cipher,
112            nonce,
113            b"",
114            payload,
115        )
116        .map_err(|_| SessionError::EncryptionFailed)?;
117
118        let mut tag_bytes = [0u8; 16];
119        tag_bytes.copy_from_slice(&tag);
120
121        Ok((nonce_bytes, tag_bytes))
122    }
123
124    /// Decrypts a frame in-place, avoiding an allocation.
125    /// Returns an error if the nonce is replayed, or if decryption fails.
126    pub fn decrypt_in_place_detached(
127        &self,
128        nonce_bytes: &[u8; 12],
129        payload: &mut [u8],
130        tag_bytes: &[u8; 16],
131    ) -> Result<(), SessionError> {
132        let counter = u32::from_be_bytes(
133            nonce_bytes
134                .get(8..12)
135                .ok_or(SessionError::InvalidNonce)?
136                .try_into()
137                .map_err(|_| SessionError::InvalidNonce)?,
138        ) as u64;
139        let last = self.last_recv_counter.load(Ordering::SeqCst);
140        if last != u64::MAX && counter <= last {
141            return Err(SessionError::ReplayedNonce);
142        }
143
144        let nonce = Nonce::from_slice(nonce_bytes);
145        let tag = aes_gcm::aead::Tag::<aes_gcm::Aes256Gcm>::from_slice(tag_bytes);
146        aes_gcm::aead::AeadInPlace::decrypt_in_place_detached(
147            &self.cipher,
148            nonce,
149            b"",
150            payload,
151            tag,
152        )
153        .map_err(|_| SessionError::DecryptionFailed)?;
154
155        self.last_recv_counter.store(counter, Ordering::SeqCst);
156        Ok(())
157    }
158
159    /// Build a 12-byte nonce from the prefix and counter.
160    fn build_nonce(&self, counter: u32) -> [u8; 12] {
161        let mut nonce = [0u8; 12];
162        nonce[..8].copy_from_slice(&self.nonce_prefix);
163        nonce[8..12].copy_from_slice(&counter.to_be_bytes());
164        nonce
165    }
166}
167
168#[derive(Debug, thiserror::Error)]
169/// Errors returned by [`SessionCipher`].
170pub enum SessionError {
171    /// The nonce length or format is invalid.
172    #[error("invalid nonce format")]
173    InvalidNonce,
174    /// The nonce counter reached its maximum and the session must be replaced.
175    #[error("nonce counter exhausted — session must be rekeyed")]
176    NonceExhausted,
177    /// Encrypting the payload failed.
178    #[error("encryption failed")]
179    EncryptionFailed,
180    /// Decrypting the payload failed.
181    #[error("decryption failed (invalid ciphertext or wrong key)")]
182    DecryptionFailed,
183    /// A frame reused or regressed the receive nonce counter.
184    #[error("replayed nonce: frame counter has already been accepted")]
185    ReplayedNonce,
186}
187
188#[cfg(test)]
189mod tests;