// commonware_cryptography/handshake/cipher.rs
cipher.rs1use super::error::Error;
2use crate::Secret;
3use rand_core::CryptoRngCore;
4use std::vec::Vec;
5use zeroize::Zeroizing;
6
/// Length in bytes of the authentication tag that accompanies every ciphertext (128 bits).
pub const TAG_SIZE: usize = 128 / 8;

/// Length in bytes of the per-message nonce (96 bits).
const NONCE_SIZE_BYTES: usize = 96 / 8;

/// Length in bytes of the symmetric key (256 bits).
const KEY_SIZE_BYTES: usize = 256 / 8;
21struct CounterNonce {
22 inner: u128,
23}
24
25impl CounterNonce {
26 pub const fn new() -> Self {
28 Self { inner: 0 }
29 }
30
31 pub fn inc(&mut self) -> Result<[u8; NONCE_SIZE_BYTES], Error> {
34 if self.inner >= 1 << (8 * NONCE_SIZE_BYTES) {
35 return Err(Error::MessageLimitReached);
36 }
37 let out = self.inner.to_le_bytes();
38 self.inner += 1;
39
40 let mut nonce = [0u8; NONCE_SIZE_BYTES];
42 nonce.copy_from_slice(&out[..NONCE_SIZE_BYTES]);
43 Ok(nonce)
44 }
45}
46
47cfg_if::cfg_if! {
48 if #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] {
49 use aws_lc_rs::aead::{self, LessSafeKey, UnboundKey, CHACHA20_POLY1305};
50
51 struct Cipher(LessSafeKey);
52
53 impl Cipher {
54 fn from_key(key: &[u8; KEY_SIZE_BYTES]) -> Self {
55 let unbound_key = UnboundKey::new(&CHACHA20_POLY1305, key)
56 .expect("key size should match algorithm");
57 Self(LessSafeKey::new(unbound_key))
58 }
59
60 fn encrypt_in_place(
61 &self,
62 nonce: &[u8; NONCE_SIZE_BYTES],
63 data: &mut [u8],
64 ) -> Result<[u8; TAG_SIZE], Error> {
65 let nonce = aead::Nonce::assume_unique_for_key(*nonce);
66 let tag = self
67 .0
68 .seal_in_place_separate_tag(nonce, aead::Aad::empty(), data)
69 .map_err(|_| Error::EncryptionFailed)?;
70 Ok(tag.as_ref().try_into().expect("tag size mismatch"))
71 }
72
73 fn decrypt_in_place(
74 &self,
75 nonce: &[u8; NONCE_SIZE_BYTES],
76 data: &mut [u8],
77 ) -> Result<usize, Error> {
78 let nonce = aead::Nonce::assume_unique_for_key(*nonce);
79 self.0
80 .open_in_place(nonce, aead::Aad::empty(), data)
81 .map_err(|_| Error::DecryptionFailed)?;
82 Ok(data.len() - TAG_SIZE)
83 }
84 }
85 } else {
86 use chacha20poly1305::{aead::AeadInPlace, ChaCha20Poly1305, KeyInit as _};
87
88 struct Cipher(ChaCha20Poly1305);
89
90 impl Cipher {
91 fn from_key(key: &[u8; KEY_SIZE_BYTES]) -> Self {
92 Self(ChaCha20Poly1305::new(key.into()))
93 }
94
95 fn encrypt_in_place(
96 &self,
97 nonce: &[u8; NONCE_SIZE_BYTES],
98 data: &mut [u8],
99 ) -> Result<[u8; TAG_SIZE], Error> {
100 let tag = self
101 .0
102 .encrypt_in_place_detached(nonce.into(), &[], data)
103 .map_err(|_| Error::EncryptionFailed)?;
104 Ok(tag.into())
105 }
106
107 fn decrypt_in_place(
108 &self,
109 nonce: &[u8; NONCE_SIZE_BYTES],
110 data: &mut [u8],
111 ) -> Result<usize, Error> {
112 let plaintext_len = data.len() - TAG_SIZE;
113 let tag: [u8; TAG_SIZE] = data[plaintext_len..]
114 .try_into()
115 .map_err(|_| Error::DecryptionFailed)?;
116 self.0
117 .decrypt_in_place_detached(
118 nonce.into(),
119 &[],
120 &mut data[..plaintext_len],
121 &tag.into(),
122 )
123 .map_err(|_| Error::DecryptionFailed)?;
124 Ok(plaintext_len)
125 }
126 }
127 }
128}
129
130pub struct SendCipher {
132 nonce: CounterNonce,
133 inner: Secret<Cipher>,
134}
135
136impl SendCipher {
137 pub fn new(mut rng: impl CryptoRngCore) -> Self {
139 let mut key_bytes = Zeroizing::new([0u8; KEY_SIZE_BYTES]);
140 rng.fill_bytes(key_bytes.as_mut());
141 Self {
142 nonce: CounterNonce::new(),
143 inner: Secret::new(Cipher::from_key(&key_bytes)),
144 }
145 }
146
147 #[inline]
151 pub fn send_in_place(&mut self, data: &mut [u8]) -> Result<[u8; TAG_SIZE], Error> {
152 let nonce = self.nonce.inc()?;
153 self.inner
154 .expose(|cipher| cipher.encrypt_in_place(&nonce, data))
155 }
156
157 pub fn send(&mut self, data: &[u8]) -> Result<Vec<u8>, Error> {
159 let mut buf = vec![0u8; data.len() + TAG_SIZE];
160 buf[..data.len()].copy_from_slice(data);
161 let tag = self.send_in_place(&mut buf[..data.len()])?;
162 buf[data.len()..].copy_from_slice(&tag);
163 Ok(buf)
164 }
165}
166
167pub struct RecvCipher {
169 nonce: CounterNonce,
170 inner: Secret<Cipher>,
171}
172
173impl RecvCipher {
174 pub fn new(mut rng: impl CryptoRngCore) -> Self {
176 let mut key_bytes = Zeroizing::new([0u8; KEY_SIZE_BYTES]);
177 rng.fill_bytes(key_bytes.as_mut());
178 Self {
179 nonce: CounterNonce::new(),
180 inner: Secret::new(Cipher::from_key(&key_bytes)),
181 }
182 }
183
184 #[inline]
202 pub fn recv_in_place(&mut self, encrypted_data: &mut [u8]) -> Result<usize, Error> {
203 if encrypted_data.len() < TAG_SIZE {
204 return Err(Error::DecryptionFailed);
205 }
206 let nonce = self.nonce.inc()?;
207 self.inner
208 .expose(|cipher| cipher.decrypt_in_place(&nonce, encrypted_data))
209 }
210
211 pub fn recv(&mut self, encrypted_data: &[u8]) -> Result<Vec<u8>, Error> {
225 let mut buf = encrypted_data.to_vec();
226 let plaintext_len = self.recv_in_place(&mut buf)?;
227 buf.truncate(plaintext_len);
228 Ok(buf)
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_utils::{test_rng, test_rng_seeded};

    #[test]
    fn test_send_recv_roundtrip() {
        let mut send = SendCipher::new(&mut test_rng());
        let mut recv = RecvCipher::new(&mut test_rng());

        let plaintext = b"hello world";
        let ciphertext = send.send(plaintext).unwrap();
        assert_eq!(ciphertext.len(), plaintext.len() + TAG_SIZE);

        let decrypted = recv.recv(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_send_recv_empty_plaintext() {
        let mut send = SendCipher::new(&mut test_rng());
        let mut recv = RecvCipher::new(&mut test_rng());

        // An empty message still carries a full tag and must round-trip.
        let ciphertext = send.send(&[]).unwrap();
        assert_eq!(ciphertext.len(), TAG_SIZE);
        assert!(recv.recv(&ciphertext).unwrap().is_empty());
    }

    #[test]
    fn test_recv_wrong_key_fails() {
        let mut send = SendCipher::new(&mut test_rng_seeded(0));
        let mut recv = RecvCipher::new(&mut test_rng_seeded(1));

        let ciphertext = send.send(b"hello").unwrap();
        assert!(matches!(
            recv.recv(&ciphertext),
            Err(Error::DecryptionFailed)
        ));
    }

    #[test]
    fn test_recv_ciphertext_too_short() {
        let mut rng = test_rng();
        let mut recv = RecvCipher::new(&mut rng);
        let short_data = vec![0u8; TAG_SIZE - 1];
        assert!(matches!(
            recv.recv(&short_data),
            Err(Error::DecryptionFailed)
        ));
    }

    #[test]
    fn test_recv_ciphertext_exactly_overhead() {
        let mut rng = test_rng();
        let mut recv = RecvCipher::new(&mut rng);
        let tag_only = vec![0u8; TAG_SIZE];
        assert!(matches!(recv.recv(&tag_only), Err(Error::DecryptionFailed)));
    }

    #[test]
    fn test_recv_tampered_ciphertext_fails() {
        let mut send = SendCipher::new(&mut test_rng());
        let mut recv = RecvCipher::new(&mut test_rng());

        let mut ciphertext = send.send(b"hello").unwrap();
        ciphertext[0] ^= 1;
        assert!(matches!(
            recv.recv(&ciphertext),
            Err(Error::DecryptionFailed)
        ));
    }

    #[test]
    fn test_recv_tampered_tag_fails() {
        let mut send = SendCipher::new(&mut test_rng());
        let mut recv = RecvCipher::new(&mut test_rng());

        let mut ciphertext = send.send(b"hello").unwrap();
        let last = ciphertext.len() - 1;
        ciphertext[last] ^= 1;
        assert!(matches!(
            recv.recv(&ciphertext),
            Err(Error::DecryptionFailed)
        ));
    }

    #[test]
    fn test_nonce_advances_per_message() {
        let mut send = SendCipher::new(&mut test_rng());
        let mut recv = RecvCipher::new(&mut test_rng());

        let plaintext = b"repeat";
        let first = send.send(plaintext).unwrap();
        let second = send.send(plaintext).unwrap();
        // Same key and plaintext: only a fresh nonce can make the ciphertexts differ.
        assert_ne!(first, second);

        assert_eq!(recv.recv(&first).unwrap(), plaintext);
        assert_eq!(recv.recv(&second).unwrap(), plaintext);
    }

    #[test]
    fn test_recv_out_of_order_fails() {
        let mut send = SendCipher::new(&mut test_rng());
        let mut recv = RecvCipher::new(&mut test_rng());

        let _first = send.send(b"first").unwrap();
        let second = send.send(b"second").unwrap();
        // The receiver expects nonce 0, but `second` was sealed with nonce 1.
        assert!(matches!(recv.recv(&second), Err(Error::DecryptionFailed)));
    }

    #[test]
    fn test_counter_nonce_sequence() {
        let mut nonce = CounterNonce::new();
        assert!(matches!(nonce.inc(), Ok(n) if n == [0u8; NONCE_SIZE_BYTES]));

        // Counter values are encoded little-endian.
        let mut expected = [0u8; NONCE_SIZE_BYTES];
        expected[0] = 1;
        assert!(matches!(nonce.inc(), Ok(n) if n == expected));
    }

    #[test]
    fn test_counter_nonce_limit() {
        let limit = 1u128 << (8 * NONCE_SIZE_BYTES);
        let mut nonce = CounterNonce { inner: limit - 1 };

        // The last representable nonce is still handed out...
        assert!(matches!(nonce.inc(), Ok(n) if n == [0xFF; NONCE_SIZE_BYTES]));
        // ...after which the counter refuses to wrap around, permanently.
        assert!(matches!(nonce.inc(), Err(Error::MessageLimitReached)));
        assert!(matches!(nonce.inc(), Err(Error::MessageLimitReached)));
    }

    #[test]
    fn test_send_fails_after_message_limit() {
        let mut send = SendCipher::new(&mut test_rng());
        send.nonce = CounterNonce {
            inner: 1u128 << (8 * NONCE_SIZE_BYTES),
        };
        assert!(matches!(
            send.send(b"hello"),
            Err(Error::MessageLimitReached)
        ));
    }

    #[test]
    fn test_send_recv_in_place_roundtrip() {
        let mut send = SendCipher::new(&mut test_rng());
        let mut recv = RecvCipher::new(&mut test_rng());

        let plaintext = b"hello world";
        let mut buf = vec![0u8; plaintext.len() + TAG_SIZE];
        buf[..plaintext.len()].copy_from_slice(plaintext);

        let tag = send.send_in_place(&mut buf[..plaintext.len()]).unwrap();
        buf[plaintext.len()..].copy_from_slice(&tag);

        let plaintext_len = recv.recv_in_place(&mut buf).unwrap();

        assert_eq!(plaintext_len, plaintext.len());
        assert_eq!(&buf[..plaintext_len], plaintext);
    }

    #[test]
    fn test_recv_in_place_ciphertext_too_short() {
        let mut recv = RecvCipher::new(&mut test_rng());

        let mut buf = vec![0u8; TAG_SIZE - 1];
        assert!(matches!(
            recv.recv_in_place(&mut buf),
            Err(Error::DecryptionFailed)
        ));
    }

    #[test]
    fn test_send_in_place_recv_compatibility() {
        let mut send = SendCipher::new(&mut test_rng());
        let mut recv = RecvCipher::new(&mut test_rng());

        let plaintext = b"cross-api test";
        let mut buf = vec![0u8; plaintext.len() + TAG_SIZE];
        buf[..plaintext.len()].copy_from_slice(plaintext);

        let tag = send.send_in_place(&mut buf[..plaintext.len()]).unwrap();
        buf[plaintext.len()..].copy_from_slice(&tag);

        let decrypted = recv.recv(&buf).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_send_recv_in_place_compatibility() {
        let mut send = SendCipher::new(&mut test_rng());
        let mut recv = RecvCipher::new(&mut test_rng());

        let plaintext = b"cross-api test";
        let mut ciphertext = send.send(plaintext).unwrap();

        let plaintext_len = recv.recv_in_place(&mut ciphertext).unwrap();
        assert_eq!(&ciphertext[..plaintext_len], plaintext);
    }
}