1use std::collections::HashMap;
2
3use aes_gcm::aead::{Aead, KeyInit, Payload};
4use aes_gcm::{Aes128Gcm, Aes256Gcm};
5
6use crate::error::SFrameError;
7use crate::header::SFrameHeader;
8use crate::kdf::{CipherSuite, ParticipantKeys, derive_participant, make_nonce};
9use crate::replay::ReplayWindow;
10
11enum AeadCipher {
14 Aes128(Aes128Gcm),
15 Aes256(Aes256Gcm),
16}
17
18impl AeadCipher {
19 fn new(key: &[u8], suite: CipherSuite) -> Self {
20 match suite {
21 CipherSuite::Aes128Gcm => Self::Aes128(
22 Aes128Gcm::new_from_slice(key).expect("key length matches suite"),
23 ),
24 CipherSuite::Aes256Gcm => Self::Aes256(
25 Aes256Gcm::new_from_slice(key).expect("key length matches suite"),
26 ),
27 }
28 }
29
30 fn encrypt(&self, nonce: &[u8; 12], plaintext: &[u8], aad: &[u8]) -> Result<Vec<u8>, SFrameError> {
31 let n = aes_gcm::Nonce::from_slice(nonce);
32 let payload = Payload { msg: plaintext, aad };
33 match self {
34 Self::Aes128(c) => c.encrypt(n, payload).map_err(|_| SFrameError::Encrypt),
35 Self::Aes256(c) => c.encrypt(n, payload).map_err(|_| SFrameError::Encrypt),
36 }
37 }
38
39 fn decrypt(&self, nonce: &[u8; 12], ciphertext: &[u8], aad: &[u8]) -> Result<Vec<u8>, SFrameError> {
40 let n = aes_gcm::Nonce::from_slice(nonce);
41 let payload = Payload { msg: ciphertext, aad };
42 match self {
43 Self::Aes128(c) => c.decrypt(n, payload).map_err(|_| SFrameError::Decrypt),
44 Self::Aes256(c) => c.decrypt(n, payload).map_err(|_| SFrameError::Decrypt),
45 }
46 }
47}
48
49pub struct SFrameEncryptor {
58 cipher: AeadCipher,
59 salt: [u8; 12],
60 kid: u64,
61 ctr: u64,
62}
63
64impl SFrameEncryptor {
65 pub(crate) fn new(keys: ParticipantKeys, kid: u64, suite: CipherSuite) -> Self {
66 Self {
67 cipher: AeadCipher::new(&keys.key, suite),
68 salt: keys.salt,
69 kid,
70 ctr: 0,
71 }
72 }
73
74 pub fn encrypt(&mut self, plaintext: &[u8], extra_aad: &[u8]) -> Result<Vec<u8>, SFrameError> {
80 let header = SFrameHeader { kid: self.kid, ctr: self.ctr };
81 let header_bytes = header.encode();
82
83 let mut aad = Vec::with_capacity(header_bytes.len() + extra_aad.len());
84 aad.extend_from_slice(&header_bytes);
85 aad.extend_from_slice(extra_aad);
86
87 let nonce = make_nonce(&self.salt, self.ctr);
88 let ciphertext = self.cipher.encrypt(&nonce, plaintext, &aad)?;
89
90 self.ctr = self.ctr.wrapping_add(1);
91
92 let mut out = Vec::with_capacity(header_bytes.len() + ciphertext.len());
93 out.extend_from_slice(&header_bytes);
94 out.extend_from_slice(&ciphertext);
95 Ok(out)
96 }
97
98 pub fn counter(&self) -> u64 {
100 self.ctr
101 }
102
103 pub fn kid(&self) -> u64 {
105 self.kid
106 }
107}
108
109struct SenderState {
113 cipher: AeadCipher,
114 salt: [u8; 12],
115 window: ReplayWindow,
116}
117
118pub struct SFrameDecryptor {
125 base_key: [u8; 32],
126 epoch: u64,
127 suite: CipherSuite,
128 senders: HashMap<u32, SenderState>,
130}
131
132impl SFrameDecryptor {
133 pub(crate) fn new(base_key: [u8; 32], epoch: u64, suite: CipherSuite) -> Self {
134 Self {
135 base_key,
136 epoch,
137 suite,
138 senders: HashMap::new(),
139 }
140 }
141
142 pub fn decrypt(
146 &mut self,
147 payload: &[u8],
148 extra_aad: &[u8],
149 ) -> Result<(Vec<u8>, u32), SFrameError> {
150 let (header, header_len) = SFrameHeader::decode(payload)?;
151
152 let frame_epoch = SFrameHeader::epoch_from_kid(header.kid);
153 if frame_epoch != self.epoch {
154 return Err(SFrameError::UnknownKid(header.kid));
155 }
156 let leaf = SFrameHeader::leaf_from_kid(header.kid);
157
158 let state = self.senders.entry(leaf).or_insert_with(|| {
160 let keys = derive_participant(&self.base_key, leaf, self.suite);
161 SenderState {
162 cipher: AeadCipher::new(&keys.key, self.suite),
163 salt: keys.salt,
164 window: ReplayWindow::new(),
165 }
166 });
167
168 state
170 .window
171 .check_and_mark(header.ctr)
172 .map_err(|_| SFrameError::Replay { kid: header.kid, ctr: header.ctr })?;
173
174 let header_bytes = &payload[..header_len];
175 let ciphertext = &payload[header_len..];
176
177 let mut aad = Vec::with_capacity(header_bytes.len() + extra_aad.len());
178 aad.extend_from_slice(header_bytes);
179 aad.extend_from_slice(extra_aad);
180
181 let nonce = make_nonce(&state.salt, header.ctr);
182 let plaintext = state.cipher.decrypt(&nonce, ciphertext, &aad)?;
183
184 Ok((plaintext, leaf))
185 }
186
187 pub fn reset(&mut self) {
189 self.senders.values_mut().for_each(|s| s.window.reset());
190 }
191}