Skip to main content

gbp_sframe/
cipher.rs

1use std::collections::HashMap;
2
3use aes_gcm::aead::{Aead, KeyInit, Payload};
4use aes_gcm::{Aes128Gcm, Aes256Gcm};
5
6use crate::error::SFrameError;
7use crate::header::SFrameHeader;
8use crate::kdf::{CipherSuite, ParticipantKeys, derive_participant, make_nonce};
9use crate::replay::ReplayWindow;
10
11// ─── Internal AEAD helper ────────────────────────────────────────────────────
12
13enum AeadCipher {
14    Aes128(Aes128Gcm),
15    Aes256(Aes256Gcm),
16}
17
18impl AeadCipher {
19    fn new(key: &[u8], suite: CipherSuite) -> Self {
20        match suite {
21            CipherSuite::Aes128Gcm => {
22                Self::Aes128(Aes128Gcm::new_from_slice(key).expect("key length matches suite"))
23            }
24            CipherSuite::Aes256Gcm => {
25                Self::Aes256(Aes256Gcm::new_from_slice(key).expect("key length matches suite"))
26            }
27        }
28    }
29
30    fn encrypt(
31        &self,
32        nonce: &[u8; 12],
33        plaintext: &[u8],
34        aad: &[u8],
35    ) -> Result<Vec<u8>, SFrameError> {
36        let n = aes_gcm::Nonce::from_slice(nonce);
37        let payload = Payload {
38            msg: plaintext,
39            aad,
40        };
41        match self {
42            Self::Aes128(c) => c.encrypt(n, payload).map_err(|_| SFrameError::Encrypt),
43            Self::Aes256(c) => c.encrypt(n, payload).map_err(|_| SFrameError::Encrypt),
44        }
45    }
46
47    fn decrypt(
48        &self,
49        nonce: &[u8; 12],
50        ciphertext: &[u8],
51        aad: &[u8],
52    ) -> Result<Vec<u8>, SFrameError> {
53        let n = aes_gcm::Nonce::from_slice(nonce);
54        let payload = Payload {
55            msg: ciphertext,
56            aad,
57        };
58        match self {
59            Self::Aes128(c) => c.decrypt(n, payload).map_err(|_| SFrameError::Decrypt),
60            Self::Aes256(c) => c.decrypt(n, payload).map_err(|_| SFrameError::Decrypt),
61        }
62    }
63}
64
65// ─── SFrameEncryptor ─────────────────────────────────────────────────────────
66
67/// Stateful per-sender SFrame encryptor.
68///
69/// Holds the derived key+salt for one `(epoch, leaf_index)` pair and an
70/// internal counter that increments on every call to [`encrypt`].
71///
72/// Obtain via [`crate::SFrameSession::encryptor`].
73pub struct SFrameEncryptor {
74    cipher: AeadCipher,
75    base_nonce: [u8; 12],
76    kid: u64,
77    ctr: u64,
78}
79
80impl SFrameEncryptor {
81    pub(crate) fn new(keys: ParticipantKeys, kid: u64, suite: CipherSuite) -> Self {
82        Self {
83            cipher: AeadCipher::new(&keys.key, suite),
84            base_nonce: keys.base_nonce,
85            kid,
86            ctr: 0,
87        }
88    }
89
90    /// Encrypts `plaintext` and returns the complete SFrame payload:
91    /// `header ‖ ciphertext ‖ GCM-tag`.
92    ///
93    /// `extra_aad` is appended to the SFrame header to form the full AAD
94    /// (e.g. pass an RTP header or an empty slice).
95    pub fn encrypt(&mut self, plaintext: &[u8], extra_aad: &[u8]) -> Result<Vec<u8>, SFrameError> {
96        let header = SFrameHeader {
97            kid: self.kid,
98            ctr: self.ctr,
99        };
100        let header_bytes = header.encode();
101
102        let mut aad = Vec::with_capacity(header_bytes.len() + extra_aad.len());
103        aad.extend_from_slice(&header_bytes);
104        aad.extend_from_slice(extra_aad);
105
106        let nonce = make_nonce(&self.base_nonce, self.ctr);
107        let ciphertext = self.cipher.encrypt(&nonce, plaintext, &aad)?;
108
109        self.ctr = self.ctr.wrapping_add(1);
110
111        let mut out = Vec::with_capacity(header_bytes.len() + ciphertext.len());
112        out.extend_from_slice(&header_bytes);
113        out.extend_from_slice(&ciphertext);
114        Ok(out)
115    }
116
117    /// Current counter value (number of frames encrypted so far).
118    pub fn counter(&self) -> u64 {
119        self.ctr
120    }
121
122    /// KID this encryptor was created for.
123    pub fn kid(&self) -> u64 {
124        self.kid
125    }
126}
127
128// ─── SFrameDecryptor ─────────────────────────────────────────────────────────
129
130/// Per-sender decryption state maintained inside [`SFrameDecryptor`].
131struct SenderState {
132    cipher: AeadCipher,
133    base_nonce: [u8; 12],
134    window: ReplayWindow,
135}
136
137/// Multi-sender SFrame decryptor for one epoch.
138///
139/// Lazily derives per-sender key material from the epoch's base key as new
140/// `KID`s are encountered.  Maintains an independent replay window per sender.
141///
142/// Obtain via [`crate::SFrameSession::decryptor`].
143pub struct SFrameDecryptor {
144    base_key: [u8; 32],
145    epoch: u64,
146    suite: CipherSuite,
147    /// Keyed by `leaf_index`.
148    senders: HashMap<u32, SenderState>,
149}
150
151impl SFrameDecryptor {
152    pub(crate) fn new(base_key: [u8; 32], epoch: u64, suite: CipherSuite) -> Self {
153        Self {
154            base_key,
155            epoch,
156            suite,
157            senders: HashMap::new(),
158        }
159    }
160
161    /// Decrypts an SFrame `payload` and returns `(plaintext, sender_leaf)`.
162    ///
163    /// `extra_aad` must be the same slice passed on the encrypting side.
164    pub fn decrypt(
165        &mut self,
166        payload: &[u8],
167        extra_aad: &[u8],
168    ) -> Result<(Vec<u8>, u32), SFrameError> {
169        let (header, header_len) = SFrameHeader::decode(payload)?;
170
171        let frame_epoch = SFrameHeader::epoch_from_kid(header.kid);
172        if frame_epoch != self.epoch {
173            return Err(SFrameError::UnknownKid(header.kid));
174        }
175        let leaf = SFrameHeader::leaf_from_kid(header.kid);
176
177        // Lazily derive key material for this sender.
178        let state = self.senders.entry(leaf).or_insert_with(|| {
179            let keys = derive_participant(&self.base_key, leaf, self.suite);
180            SenderState {
181                cipher: AeadCipher::new(&keys.key, self.suite),
182                base_nonce: keys.base_nonce,
183                window: ReplayWindow::new(),
184            }
185        });
186
187        // Replay check before decryption (fast path for replays).
188        state
189            .window
190            .check_and_mark(header.ctr)
191            .map_err(|_| SFrameError::Replay {
192                kid: header.kid,
193                ctr: header.ctr,
194            })?;
195
196        let header_bytes = &payload[..header_len];
197        let ciphertext = &payload[header_len..];
198
199        let mut aad = Vec::with_capacity(header_bytes.len() + extra_aad.len());
200        aad.extend_from_slice(header_bytes);
201        aad.extend_from_slice(extra_aad);
202
203        let nonce = make_nonce(&state.base_nonce, header.ctr);
204        let plaintext = state.cipher.decrypt(&nonce, ciphertext, &aad)?;
205
206        Ok((plaintext, leaf))
207    }
208
209    /// Resets all per-sender replay windows (call on epoch change).
210    pub fn reset(&mut self) {
211        self.senders.values_mut().for_each(|s| s.window.reset());
212    }
213}