Skip to main content

gbp_sframe/
cipher.rs

1use std::collections::HashMap;
2
3use aes_gcm::aead::{Aead, KeyInit, Payload};
4use aes_gcm::{Aes128Gcm, Aes256Gcm};
5
6use crate::error::SFrameError;
7use crate::header::SFrameHeader;
8use crate::kdf::{CipherSuite, ParticipantKeys, derive_participant, make_nonce};
9use crate::replay::ReplayWindow;
10
11// ─── Internal AEAD helper ────────────────────────────────────────────────────
12
13enum AeadCipher {
14    Aes128(Aes128Gcm),
15    Aes256(Aes256Gcm),
16}
17
18impl AeadCipher {
19    fn new(key: &[u8], suite: CipherSuite) -> Self {
20        match suite {
21            CipherSuite::Aes128Gcm => Self::Aes128(
22                Aes128Gcm::new_from_slice(key).expect("key length matches suite"),
23            ),
24            CipherSuite::Aes256Gcm => Self::Aes256(
25                Aes256Gcm::new_from_slice(key).expect("key length matches suite"),
26            ),
27        }
28    }
29
30    fn encrypt(&self, nonce: &[u8; 12], plaintext: &[u8], aad: &[u8]) -> Result<Vec<u8>, SFrameError> {
31        let n = aes_gcm::Nonce::from_slice(nonce);
32        let payload = Payload { msg: plaintext, aad };
33        match self {
34            Self::Aes128(c) => c.encrypt(n, payload).map_err(|_| SFrameError::Encrypt),
35            Self::Aes256(c) => c.encrypt(n, payload).map_err(|_| SFrameError::Encrypt),
36        }
37    }
38
39    fn decrypt(&self, nonce: &[u8; 12], ciphertext: &[u8], aad: &[u8]) -> Result<Vec<u8>, SFrameError> {
40        let n = aes_gcm::Nonce::from_slice(nonce);
41        let payload = Payload { msg: ciphertext, aad };
42        match self {
43            Self::Aes128(c) => c.decrypt(n, payload).map_err(|_| SFrameError::Decrypt),
44            Self::Aes256(c) => c.decrypt(n, payload).map_err(|_| SFrameError::Decrypt),
45        }
46    }
47}
48
49// ─── SFrameEncryptor ─────────────────────────────────────────────────────────
50
51/// Stateful per-sender SFrame encryptor.
52///
53/// Holds the derived key+salt for one `(epoch, leaf_index)` pair and an
54/// internal counter that increments on every call to [`encrypt`].
55///
56/// Obtain via [`crate::SFrameSession::encryptor`].
57pub struct SFrameEncryptor {
58    cipher: AeadCipher,
59    salt: [u8; 12],
60    kid: u64,
61    ctr: u64,
62}
63
64impl SFrameEncryptor {
65    pub(crate) fn new(keys: ParticipantKeys, kid: u64, suite: CipherSuite) -> Self {
66        Self {
67            cipher: AeadCipher::new(&keys.key, suite),
68            salt: keys.salt,
69            kid,
70            ctr: 0,
71        }
72    }
73
74    /// Encrypts `plaintext` and returns the complete SFrame payload:
75    /// `header ‖ ciphertext ‖ GCM-tag`.
76    ///
77    /// `extra_aad` is appended to the SFrame header to form the full AAD
78    /// (e.g. pass an RTP header or an empty slice).
79    pub fn encrypt(&mut self, plaintext: &[u8], extra_aad: &[u8]) -> Result<Vec<u8>, SFrameError> {
80        let header = SFrameHeader { kid: self.kid, ctr: self.ctr };
81        let header_bytes = header.encode();
82
83        let mut aad = Vec::with_capacity(header_bytes.len() + extra_aad.len());
84        aad.extend_from_slice(&header_bytes);
85        aad.extend_from_slice(extra_aad);
86
87        let nonce = make_nonce(&self.salt, self.ctr);
88        let ciphertext = self.cipher.encrypt(&nonce, plaintext, &aad)?;
89
90        self.ctr = self.ctr.wrapping_add(1);
91
92        let mut out = Vec::with_capacity(header_bytes.len() + ciphertext.len());
93        out.extend_from_slice(&header_bytes);
94        out.extend_from_slice(&ciphertext);
95        Ok(out)
96    }
97
98    /// Current counter value (number of frames encrypted so far).
99    pub fn counter(&self) -> u64 {
100        self.ctr
101    }
102
103    /// KID this encryptor was created for.
104    pub fn kid(&self) -> u64 {
105        self.kid
106    }
107}
108
109// ─── SFrameDecryptor ─────────────────────────────────────────────────────────
110
111/// Per-sender decryption state maintained inside [`SFrameDecryptor`].
112struct SenderState {
113    cipher: AeadCipher,
114    salt: [u8; 12],
115    window: ReplayWindow,
116}
117
118/// Multi-sender SFrame decryptor for one epoch.
119///
120/// Lazily derives per-sender key material from the epoch's base key as new
121/// `KID`s are encountered.  Maintains an independent replay window per sender.
122///
123/// Obtain via [`crate::SFrameSession::decryptor`].
124pub struct SFrameDecryptor {
125    base_key: [u8; 32],
126    epoch: u64,
127    suite: CipherSuite,
128    /// Keyed by `leaf_index`.
129    senders: HashMap<u32, SenderState>,
130}
131
132impl SFrameDecryptor {
133    pub(crate) fn new(base_key: [u8; 32], epoch: u64, suite: CipherSuite) -> Self {
134        Self {
135            base_key,
136            epoch,
137            suite,
138            senders: HashMap::new(),
139        }
140    }
141
142    /// Decrypts an SFrame `payload` and returns `(plaintext, sender_leaf)`.
143    ///
144    /// `extra_aad` must be the same slice passed on the encrypting side.
145    pub fn decrypt(
146        &mut self,
147        payload: &[u8],
148        extra_aad: &[u8],
149    ) -> Result<(Vec<u8>, u32), SFrameError> {
150        let (header, header_len) = SFrameHeader::decode(payload)?;
151
152        let frame_epoch = SFrameHeader::epoch_from_kid(header.kid);
153        if frame_epoch != self.epoch {
154            return Err(SFrameError::UnknownKid(header.kid));
155        }
156        let leaf = SFrameHeader::leaf_from_kid(header.kid);
157
158        // Lazily derive key material for this sender.
159        let state = self.senders.entry(leaf).or_insert_with(|| {
160            let keys = derive_participant(&self.base_key, leaf, self.suite);
161            SenderState {
162                cipher: AeadCipher::new(&keys.key, self.suite),
163                salt: keys.salt,
164                window: ReplayWindow::new(),
165            }
166        });
167
168        // Replay check before decryption (fast path for replays).
169        state
170            .window
171            .check_and_mark(header.ctr)
172            .map_err(|_| SFrameError::Replay { kid: header.kid, ctr: header.ctr })?;
173
174        let header_bytes = &payload[..header_len];
175        let ciphertext = &payload[header_len..];
176
177        let mut aad = Vec::with_capacity(header_bytes.len() + extra_aad.len());
178        aad.extend_from_slice(header_bytes);
179        aad.extend_from_slice(extra_aad);
180
181        let nonce = make_nonce(&state.salt, header.ctr);
182        let plaintext = state.cipher.decrypt(&nonce, ciphertext, &aad)?;
183
184        Ok((plaintext, leaf))
185    }
186
187    /// Resets all per-sender replay windows (call on epoch change).
188    pub fn reset(&mut self) {
189        self.senders.values_mut().for_each(|s| s.window.reset());
190    }
191}