Skip to main content

commonware_cryptography/handshake/
cipher.rs

1use super::error::Error;
2use crate::Secret;
3use rand_core::CryptoRngCore;
4use std::vec::Vec;
5use zeroize::Zeroizing;
6
/// Length in bytes of the Poly1305 authentication tag (128 bits).
///
/// Every ChaCha20-Poly1305 ciphertext carries this many extra bytes: the
/// receiver needs the tag next to the ciphertext to verify its integrity and
/// authenticity.
pub const TAG_SIZE: usize = 128 / 8;

/// Length in bytes of a ChaCha20-Poly1305 nonce (96 bits).
const NONCE_SIZE_BYTES: usize = 96 / 8;

/// Length in bytes of a ChaCha20-Poly1305 key (256 bits).
const KEY_SIZE_BYTES: usize = 256 / 8;
20
21struct CounterNonce {
22    inner: u128,
23}
24
25impl CounterNonce {
26    /// Creates a new counter nonce starting at zero.
27    pub const fn new() -> Self {
28        Self { inner: 0 }
29    }
30
31    /// Increments the counter and returns the current value as bytes.
32    /// Returns an error if the counter would overflow.
33    pub fn inc(&mut self) -> Result<[u8; NONCE_SIZE_BYTES], Error> {
34        if self.inner >= 1 << (8 * NONCE_SIZE_BYTES) {
35            return Err(Error::MessageLimitReached);
36        }
37        let out = self.inner.to_le_bytes();
38        self.inner += 1;
39
40        // Extract only the lower 96 bits (12 bytes) for the nonce
41        let mut nonce = [0u8; NONCE_SIZE_BYTES];
42        nonce.copy_from_slice(&out[..NONCE_SIZE_BYTES]);
43        Ok(nonce)
44    }
45}
46
47cfg_if::cfg_if! {
48    if #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] {
49        use aws_lc_rs::aead::{self, LessSafeKey, UnboundKey, CHACHA20_POLY1305};
50
51        struct Cipher(LessSafeKey);
52
53        impl Cipher {
54            fn from_key(key: &[u8; KEY_SIZE_BYTES]) -> Self {
55                let unbound_key = UnboundKey::new(&CHACHA20_POLY1305, key)
56                    .expect("key size should match algorithm");
57                Self(LessSafeKey::new(unbound_key))
58            }
59
60            fn encrypt_in_place(
61                &self,
62                nonce: &[u8; NONCE_SIZE_BYTES],
63                data: &mut [u8],
64            ) -> Result<[u8; TAG_SIZE], Error> {
65                let nonce = aead::Nonce::assume_unique_for_key(*nonce);
66                let tag = self
67                    .0
68                    .seal_in_place_separate_tag(nonce, aead::Aad::empty(), data)
69                    .map_err(|_| Error::EncryptionFailed)?;
70                Ok(tag.as_ref().try_into().expect("tag size mismatch"))
71            }
72
73            fn decrypt_in_place(
74                &self,
75                nonce: &[u8; NONCE_SIZE_BYTES],
76                data: &mut [u8],
77            ) -> Result<usize, Error> {
78                let nonce = aead::Nonce::assume_unique_for_key(*nonce);
79                self.0
80                    .open_in_place(nonce, aead::Aad::empty(), data)
81                    .map_err(|_| Error::DecryptionFailed)?;
82                Ok(data.len() - TAG_SIZE)
83            }
84        }
85    } else {
86        use chacha20poly1305::{aead::AeadInPlace, ChaCha20Poly1305, KeyInit as _};
87
88        struct Cipher(ChaCha20Poly1305);
89
90        impl Cipher {
91            fn from_key(key: &[u8; KEY_SIZE_BYTES]) -> Self {
92                Self(ChaCha20Poly1305::new(key.into()))
93            }
94
95            fn encrypt_in_place(
96                &self,
97                nonce: &[u8; NONCE_SIZE_BYTES],
98                data: &mut [u8],
99            ) -> Result<[u8; TAG_SIZE], Error> {
100                let tag = self
101                    .0
102                    .encrypt_in_place_detached(nonce.into(), &[], data)
103                    .map_err(|_| Error::EncryptionFailed)?;
104                Ok(tag.into())
105            }
106
107            fn decrypt_in_place(
108                &self,
109                nonce: &[u8; NONCE_SIZE_BYTES],
110                data: &mut [u8],
111            ) -> Result<usize, Error> {
112                let plaintext_len = data.len() - TAG_SIZE;
113                let tag: [u8; TAG_SIZE] = data[plaintext_len..]
114                    .try_into()
115                    .map_err(|_| Error::DecryptionFailed)?;
116                self.0
117                    .decrypt_in_place_detached(
118                        nonce.into(),
119                        &[],
120                        &mut data[..plaintext_len],
121                        &tag.into(),
122                    )
123                    .map_err(|_| Error::DecryptionFailed)?;
124                Ok(plaintext_len)
125            }
126        }
127    }
128}
129
130/// Encrypts outgoing messages with an auto-incrementing nonce.
131pub struct SendCipher {
132    nonce: CounterNonce,
133    inner: Secret<Cipher>,
134}
135
136impl SendCipher {
137    /// Creates a new sending cipher with a random key.
138    pub fn new(mut rng: impl CryptoRngCore) -> Self {
139        let mut key_bytes = Zeroizing::new([0u8; KEY_SIZE_BYTES]);
140        rng.fill_bytes(key_bytes.as_mut());
141        Self {
142            nonce: CounterNonce::new(),
143            inner: Secret::new(Cipher::from_key(&key_bytes)),
144        }
145    }
146
147    /// Encrypts `data` in-place and returns the authentication tag.
148    ///
149    /// The caller is responsible for appending the returned tag to the buffer.
150    #[inline]
151    pub fn send_in_place(&mut self, data: &mut [u8]) -> Result<[u8; TAG_SIZE], Error> {
152        let nonce = self.nonce.inc()?;
153        self.inner
154            .expose(|cipher| cipher.encrypt_in_place(&nonce, data))
155    }
156
157    /// Encrypts data and returns the ciphertext.
158    pub fn send(&mut self, data: &[u8]) -> Result<Vec<u8>, Error> {
159        let mut buf = vec![0u8; data.len() + TAG_SIZE];
160        buf[..data.len()].copy_from_slice(data);
161        let tag = self.send_in_place(&mut buf[..data.len()])?;
162        buf[data.len()..].copy_from_slice(&tag);
163        Ok(buf)
164    }
165}
166
167/// Decrypts incoming messages with an auto-incrementing nonce.
168pub struct RecvCipher {
169    nonce: CounterNonce,
170    inner: Secret<Cipher>,
171}
172
173impl RecvCipher {
174    /// Creates a new receiving cipher with a random key.
175    pub fn new(mut rng: impl CryptoRngCore) -> Self {
176        let mut key_bytes = Zeroizing::new([0u8; KEY_SIZE_BYTES]);
177        rng.fill_bytes(key_bytes.as_mut());
178        Self {
179            nonce: CounterNonce::new(),
180            inner: Secret::new(Cipher::from_key(&key_bytes)),
181        }
182    }
183
184    /// Decrypts `encrypted_data` in-place and returns the plaintext length.
185    ///
186    /// The buffer must contain ciphertext with the authentication tag appended
187    /// (last `TAG_SIZE` bytes). After decryption, the plaintext is in
188    /// `encrypted_data[..returned_len]`.
189    ///
190    /// # Errors
191    ///
192    /// Returns an error if:
193    /// - `encrypted_data.len() < TAG_SIZE`
194    /// - Too many messages have been received with this cipher
195    /// - The ciphertext was corrupted or tampered with
196    ///
197    /// In the last two cases, the `RecvCipher` will no longer be able to return
198    /// valid ciphertexts, and will always return an error on subsequent calls
199    /// to [`Self::recv`]. Terminating (and optionally reestablishing) the connection
200    /// is a simple (and safe) way to handle this scenario.
201    #[inline]
202    pub fn recv_in_place(&mut self, encrypted_data: &mut [u8]) -> Result<usize, Error> {
203        if encrypted_data.len() < TAG_SIZE {
204            return Err(Error::DecryptionFailed);
205        }
206        let nonce = self.nonce.inc()?;
207        self.inner
208            .expose(|cipher| cipher.decrypt_in_place(&nonce, encrypted_data))
209    }
210
211    /// Decrypts ciphertext and returns the original data.
212    ///
213    /// # Errors
214    ///
215    /// This function will return an error in the following situations:
216    ///
217    /// - Too many messages have been received with this cipher.
218    /// - The ciphertext was corrupted in some way.
219    ///
220    /// In *both* cases, the `RecvCipher` will no longer be able to return
221    /// valid ciphertexts, and will always return an error on subsequent calls
222    /// to [`Self::recv`]. Terminating (and optionally reestablishing) the connection
223    /// is a simple (and safe) way to handle this scenario.
224    pub fn recv(&mut self, encrypted_data: &[u8]) -> Result<Vec<u8>, Error> {
225        let mut buf = encrypted_data.to_vec();
226        let plaintext_len = self.recv_in_place(&mut buf)?;
227        buf.truncate(plaintext_len);
228        Ok(buf)
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_utils::{test_rng, test_rng_seeded};

    #[test]
    fn test_send_recv_roundtrip() {
        let mut send = SendCipher::new(&mut test_rng());
        let mut recv = RecvCipher::new(&mut test_rng());

        let plaintext = b"hello world";
        let ciphertext = send.send(plaintext).unwrap();
        assert_eq!(ciphertext.len(), plaintext.len() + TAG_SIZE);

        let decrypted = recv.recv(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_recv_wrong_key_fails() {
        let mut send = SendCipher::new(&mut test_rng_seeded(0));
        let mut recv = RecvCipher::new(&mut test_rng_seeded(1));

        let ciphertext = send.send(b"hello").unwrap();
        assert!(matches!(
            recv.recv(&ciphertext),
            Err(Error::DecryptionFailed)
        ));
    }

    #[test]
    fn test_recv_ciphertext_too_short() {
        let mut rng = test_rng();
        let mut recv = RecvCipher::new(&mut rng);
        let short_data = vec![0u8; TAG_SIZE - 1];
        assert!(matches!(
            recv.recv(&short_data),
            Err(Error::DecryptionFailed)
        ));
    }

    #[test]
    fn test_recv_ciphertext_exactly_overhead() {
        let mut rng = test_rng();
        let mut recv = RecvCipher::new(&mut rng);
        let tag_only = vec![0u8; TAG_SIZE];
        assert!(matches!(recv.recv(&tag_only), Err(Error::DecryptionFailed)));
    }

    #[test]
    fn test_send_recv_in_place_roundtrip() {
        let mut send = SendCipher::new(&mut test_rng());
        let mut recv = RecvCipher::new(&mut test_rng());

        let plaintext = b"hello world";
        let mut buf = vec![0u8; plaintext.len() + TAG_SIZE];
        buf[..plaintext.len()].copy_from_slice(plaintext);

        // Encrypt plaintext in place, get tag back
        let tag = send.send_in_place(&mut buf[..plaintext.len()]).unwrap();
        // Append tag to buffer
        buf[plaintext.len()..].copy_from_slice(&tag);

        // Decrypt ciphertext+tag in place, get plaintext length back
        let plaintext_len = recv.recv_in_place(&mut buf).unwrap();

        assert_eq!(plaintext_len, plaintext.len());
        assert_eq!(&buf[..plaintext_len], plaintext);
    }

    #[test]
    fn test_recv_in_place_ciphertext_too_short() {
        let mut recv = RecvCipher::new(&mut test_rng());

        // Buffer smaller than tag size
        let mut buf = vec![0u8; TAG_SIZE - 1];
        assert!(matches!(
            recv.recv_in_place(&mut buf),
            Err(Error::DecryptionFailed)
        ));
    }

    #[test]
    fn test_send_in_place_recv_compatibility() {
        let mut send = SendCipher::new(&mut test_rng());
        let mut recv = RecvCipher::new(&mut test_rng());

        let plaintext = b"cross-api test";
        let mut buf = vec![0u8; plaintext.len() + TAG_SIZE];
        buf[..plaintext.len()].copy_from_slice(plaintext);

        let tag = send.send_in_place(&mut buf[..plaintext.len()]).unwrap();
        buf[plaintext.len()..].copy_from_slice(&tag);

        // Use allocating recv on in-place encrypted data
        let decrypted = recv.recv(&buf).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_send_recv_in_place_compatibility() {
        let mut send = SendCipher::new(&mut test_rng());
        let mut recv = RecvCipher::new(&mut test_rng());

        let plaintext = b"cross-api test";
        let mut ciphertext = send.send(plaintext).unwrap();

        // Use in-place recv on allocating send data
        let plaintext_len = recv.recv_in_place(&mut ciphertext).unwrap();
        assert_eq!(&ciphertext[..plaintext_len], plaintext);
    }

    #[test]
    fn test_counter_nonce_starts_at_zero_little_endian() {
        let mut nonce = CounterNonce::new();
        assert_eq!(nonce.inc().unwrap(), [0u8; NONCE_SIZE_BYTES]);

        // The counter is encoded little-endian, so the low byte comes first.
        let mut expected = [0u8; NONCE_SIZE_BYTES];
        expected[0] = 1;
        assert_eq!(nonce.inc().unwrap(), expected);
    }

    #[test]
    fn test_counter_nonce_limit() {
        let last = (1u128 << (8 * NONCE_SIZE_BYTES)) - 1;
        let mut nonce = CounterNonce { inner: last };
        assert_eq!(nonce.inc().unwrap(), [0xff; NONCE_SIZE_BYTES]);

        // Once exhausted, the counter stays exhausted.
        assert!(matches!(nonce.inc(), Err(Error::MessageLimitReached)));
        assert!(matches!(nonce.inc(), Err(Error::MessageLimitReached)));
    }

    #[test]
    fn test_send_message_limit() {
        let mut send = SendCipher::new(&mut test_rng());
        send.nonce = CounterNonce {
            inner: (1u128 << (8 * NONCE_SIZE_BYTES)) - 1,
        };
        assert!(send.send(b"last").is_ok());
        assert!(matches!(
            send.send(b"one too many"),
            Err(Error::MessageLimitReached)
        ));
    }

    #[test]
    fn test_multiple_messages_in_order() {
        let mut send = SendCipher::new(&mut test_rng());
        let mut recv = RecvCipher::new(&mut test_rng());

        // Includes an empty plaintext (i == 0), whose ciphertext is tag-only.
        for i in 0u8..8 {
            let msg = vec![i; usize::from(i)];
            let ciphertext = send.send(&msg).unwrap();
            assert_eq!(recv.recv(&ciphertext).unwrap(), msg);
        }
    }

    #[test]
    fn test_recv_tampered_ciphertext_fails() {
        let mut send = SendCipher::new(&mut test_rng());
        let mut recv = RecvCipher::new(&mut test_rng());

        let mut ciphertext = send.send(b"hello").unwrap();
        ciphertext[0] ^= 1;
        assert!(matches!(
            recv.recv(&ciphertext),
            Err(Error::DecryptionFailed)
        ));
    }

    #[test]
    fn test_recv_out_of_order_fails() {
        let mut send = SendCipher::new(&mut test_rng());
        let mut recv = RecvCipher::new(&mut test_rng());

        let _first = send.send(b"first").unwrap();
        let second = send.send(b"second").unwrap();

        // The receiver expects the first nonce, so the second message is rejected.
        assert!(matches!(recv.recv(&second), Err(Error::DecryptionFailed)));
    }

    #[test]
    fn test_recv_too_short_does_not_consume_nonce() {
        let mut send = SendCipher::new(&mut test_rng());
        let mut recv = RecvCipher::new(&mut test_rng());

        assert!(matches!(
            recv.recv(&[0u8; TAG_SIZE - 1]),
            Err(Error::DecryptionFailed)
        ));

        // The rejected short input did not advance the receive nonce.
        let ciphertext = send.send(b"still in sync").unwrap();
        assert_eq!(recv.recv(&ciphertext).unwrap(), b"still in sync");
    }
}