1use std::collections::HashMap;
2
3use aes_gcm::aead::{Aead, KeyInit, Payload};
4use aes_gcm::{Aes128Gcm, Aes256Gcm};
5
6use crate::error::SFrameError;
7use crate::header::SFrameHeader;
8use crate::kdf::{CipherSuite, ParticipantKeys, derive_participant, make_nonce};
9use crate::replay::ReplayWindow;
10
11enum AeadCipher {
14 Aes128(Aes128Gcm),
15 Aes256(Aes256Gcm),
16}
17
18impl AeadCipher {
19 fn new(key: &[u8], suite: CipherSuite) -> Self {
20 match suite {
21 CipherSuite::Aes128Gcm => {
22 Self::Aes128(Aes128Gcm::new_from_slice(key).expect("key length matches suite"))
23 }
24 CipherSuite::Aes256Gcm => {
25 Self::Aes256(Aes256Gcm::new_from_slice(key).expect("key length matches suite"))
26 }
27 }
28 }
29
30 fn encrypt(
31 &self,
32 nonce: &[u8; 12],
33 plaintext: &[u8],
34 aad: &[u8],
35 ) -> Result<Vec<u8>, SFrameError> {
36 let n = aes_gcm::Nonce::from_slice(nonce);
37 let payload = Payload {
38 msg: plaintext,
39 aad,
40 };
41 match self {
42 Self::Aes128(c) => c.encrypt(n, payload).map_err(|_| SFrameError::Encrypt),
43 Self::Aes256(c) => c.encrypt(n, payload).map_err(|_| SFrameError::Encrypt),
44 }
45 }
46
47 fn decrypt(
48 &self,
49 nonce: &[u8; 12],
50 ciphertext: &[u8],
51 aad: &[u8],
52 ) -> Result<Vec<u8>, SFrameError> {
53 let n = aes_gcm::Nonce::from_slice(nonce);
54 let payload = Payload {
55 msg: ciphertext,
56 aad,
57 };
58 match self {
59 Self::Aes128(c) => c.decrypt(n, payload).map_err(|_| SFrameError::Decrypt),
60 Self::Aes256(c) => c.decrypt(n, payload).map_err(|_| SFrameError::Decrypt),
61 }
62 }
63}
64
65pub struct SFrameEncryptor {
74 cipher: AeadCipher,
75 base_nonce: [u8; 12],
76 kid: u64,
77 ctr: u64,
78}
79
80impl SFrameEncryptor {
81 pub(crate) fn new(keys: ParticipantKeys, kid: u64, suite: CipherSuite) -> Self {
82 Self {
83 cipher: AeadCipher::new(&keys.key, suite),
84 base_nonce: keys.base_nonce,
85 kid,
86 ctr: 0,
87 }
88 }
89
90 pub fn encrypt(&mut self, plaintext: &[u8], extra_aad: &[u8]) -> Result<Vec<u8>, SFrameError> {
96 let header = SFrameHeader {
97 kid: self.kid,
98 ctr: self.ctr,
99 };
100 let header_bytes = header.encode();
101
102 let mut aad = Vec::with_capacity(header_bytes.len() + extra_aad.len());
103 aad.extend_from_slice(&header_bytes);
104 aad.extend_from_slice(extra_aad);
105
106 let nonce = make_nonce(&self.base_nonce, self.ctr);
107 let ciphertext = self.cipher.encrypt(&nonce, plaintext, &aad)?;
108
109 self.ctr = self.ctr.wrapping_add(1);
110
111 let mut out = Vec::with_capacity(header_bytes.len() + ciphertext.len());
112 out.extend_from_slice(&header_bytes);
113 out.extend_from_slice(&ciphertext);
114 Ok(out)
115 }
116
117 pub fn counter(&self) -> u64 {
119 self.ctr
120 }
121
122 pub fn kid(&self) -> u64 {
124 self.kid
125 }
126}
127
128struct SenderState {
132 cipher: AeadCipher,
133 base_nonce: [u8; 12],
134 window: ReplayWindow,
135}
136
137pub struct SFrameDecryptor {
144 base_key: [u8; 32],
145 epoch: u64,
146 suite: CipherSuite,
147 senders: HashMap<u32, SenderState>,
149}
150
151impl SFrameDecryptor {
152 pub(crate) fn new(base_key: [u8; 32], epoch: u64, suite: CipherSuite) -> Self {
153 Self {
154 base_key,
155 epoch,
156 suite,
157 senders: HashMap::new(),
158 }
159 }
160
161 pub fn decrypt(
165 &mut self,
166 payload: &[u8],
167 extra_aad: &[u8],
168 ) -> Result<(Vec<u8>, u32), SFrameError> {
169 let (header, header_len) = SFrameHeader::decode(payload)?;
170
171 let frame_epoch = SFrameHeader::epoch_from_kid(header.kid);
172 if frame_epoch != self.epoch {
173 return Err(SFrameError::UnknownKid(header.kid));
174 }
175 let leaf = SFrameHeader::leaf_from_kid(header.kid);
176
177 let state = self.senders.entry(leaf).or_insert_with(|| {
179 let keys = derive_participant(&self.base_key, leaf, self.suite);
180 SenderState {
181 cipher: AeadCipher::new(&keys.key, self.suite),
182 base_nonce: keys.base_nonce,
183 window: ReplayWindow::new(),
184 }
185 });
186
187 state
189 .window
190 .check_and_mark(header.ctr)
191 .map_err(|_| SFrameError::Replay {
192 kid: header.kid,
193 ctr: header.ctr,
194 })?;
195
196 let header_bytes = &payload[..header_len];
197 let ciphertext = &payload[header_len..];
198
199 let mut aad = Vec::with_capacity(header_bytes.len() + extra_aad.len());
200 aad.extend_from_slice(header_bytes);
201 aad.extend_from_slice(extra_aad);
202
203 let nonce = make_nonce(&state.base_nonce, header.ctr);
204 let plaintext = state.cipher.decrypt(&nonce, ciphertext, &aad)?;
205
206 Ok((plaintext, leaf))
207 }
208
209 pub fn reset(&mut self) {
211 self.senders.values_mut().for_each(|s| s.window.reset());
212 }
213}