// noise_protocol/cipherstate.rs

1use crate::traits::{Cipher, U8Array};
2
3#[cfg(feature = "use_alloc")]
4use alloc::vec::Vec;
5
6/// A `CipherState` can encrypt and decrypt data.
7///
8/// Mostly like `CipherState` in the spec, but must be created with a key.
9///
10/// # Panics
11///
12/// Encryption and decryption methods will panic if nonce reaches maximum u64, i.e., 2 ^ 64 - 1.
13pub struct CipherState<C: Cipher> {
14    key: C::Key,
15    n: u64,
16}
17
18impl<C> Clone for CipherState<C>
19where
20    C: Cipher,
21{
22    fn clone(&self) -> Self {
23        Self {
24            key: self.key.clone(),
25            n: self.n,
26        }
27    }
28}
29
30impl<C> CipherState<C>
31where
32    C: Cipher,
33{
34    /// Name of cipher, e.g. “ChaChaPoly”.
35    pub fn name() -> &'static str {
36        C::name()
37    }
38
39    /// Create a new `CipherState` with a `key` and a nonce `n`.
40    pub fn new(key: &[u8], n: u64) -> Self {
41        CipherState {
42            key: C::Key::from_slice(key),
43            n,
44        }
45    }
46
47    /// Rekey. Set our key to `REKEY(old key)`.
48    pub fn rekey(&mut self) {
49        self.key = C::rekey(&self.key);
50    }
51
52    /// AEAD encryption.
53    pub fn encrypt_ad(&mut self, authtext: &[u8], plaintext: &[u8], out: &mut [u8]) {
54        C::encrypt(&self.key, self.n, authtext, plaintext, out);
55        #[cfg(feature = "use_std")]
56        if option_env!("NOISE_RUST_TEST_IN_PLACE").is_some() {
57            let mut inout = plaintext.to_vec();
58            inout.extend_from_slice(&[0; 16]);
59            let l = C::encrypt_in_place(&self.key, self.n, authtext, &mut inout, plaintext.len());
60            assert_eq!(inout, out);
61            assert_eq!(l, out.len());
62        }
63        // This will fail when n == 2 ^ 64 - 1, complying to the spec.
64        self.n = self.n.checked_add(1).unwrap();
65    }
66
67    /// AEAD encryption in place.
68    pub fn encrypt_ad_in_place(
69        &mut self,
70        authtext: &[u8],
71        in_out: &mut [u8],
72        plaintext_len: usize,
73    ) -> usize {
74        let size = C::encrypt_in_place(&self.key, self.n, authtext, in_out, plaintext_len);
75        // This will fail when n == 2 ^ 64 - 1, complying to the spec.
76        self.n = self.n.checked_add(1).unwrap();
77        size
78    }
79
80    /// AEAD decryption.
81    pub fn decrypt_ad(
82        &mut self,
83        authtext: &[u8],
84        ciphertext: &[u8],
85        out: &mut [u8],
86    ) -> Result<(), ()> {
87        let r = C::decrypt(&self.key, self.n, authtext, ciphertext, out);
88        #[cfg(feature = "use_std")]
89        if option_env!("NOISE_RUST_TEST_IN_PLACE").is_some() {
90            let mut inout = ciphertext.to_vec();
91            let r2 = C::decrypt_in_place(&self.key, self.n, authtext, &mut inout, ciphertext.len());
92            assert_eq!(r.map(|_| out.len()), r2);
93            if r.is_ok() {
94                assert_eq!(&inout[..out.len()], out);
95            }
96        }
97        r?;
98        self.n = self.n.checked_add(1).unwrap();
99        Ok(())
100    }
101
102    /// AEAD decryption in place.
103    pub fn decrypt_ad_in_place(
104        &mut self,
105        authtext: &[u8],
106        in_out: &mut [u8],
107        ciphertext_len: usize,
108    ) -> Result<usize, ()> {
109        let size = C::decrypt_in_place(&self.key, self.n, authtext, in_out, ciphertext_len)?;
110        self.n = self.n.checked_add(1).unwrap();
111        Ok(size)
112    }
113
114    /// Encryption.
115    pub fn encrypt(&mut self, plaintext: &[u8], out: &mut [u8]) {
116        self.encrypt_ad(&[0u8; 0], plaintext, out)
117    }
118
119    /// Encryption in place.
120    pub fn encrypt_in_place(&mut self, in_out: &mut [u8], plaintext_len: usize) -> usize {
121        self.encrypt_ad_in_place(&[0u8; 0], in_out, plaintext_len)
122    }
123
124    /// Encryption, returns ciphertext as `Vec<u8>`.
125    #[cfg(any(feature = "use_std", feature = "use_alloc"))]
126    pub fn encrypt_vec(&mut self, plaintext: &[u8]) -> Vec<u8> {
127        let mut out = vec![0u8; plaintext.len() + 16];
128        self.encrypt(plaintext, &mut out);
129        out
130    }
131
132    /// Decryption.
133    pub fn decrypt(&mut self, ciphertext: &[u8], out: &mut [u8]) -> Result<(), ()> {
134        self.decrypt_ad(&[0u8; 0], ciphertext, out)
135    }
136
137    /// Decryption in place.
138    pub fn decrypt_in_place(
139        &mut self,
140        in_out: &mut [u8],
141        ciphertext_len: usize,
142    ) -> Result<usize, ()> {
143        self.decrypt_ad_in_place(&[0u8; 0], in_out, ciphertext_len)
144    }
145
146    /// Decryption, returns plaintext as `Vec<u8>`.
147    #[cfg(any(feature = "use_std", feature = "use_alloc"))]
148    pub fn decrypt_vec(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>, ()> {
149        if ciphertext.len() < 16 {
150            return Err(());
151        }
152        let mut out = vec![0u8; ciphertext.len() - 16];
153        self.decrypt(ciphertext, &mut out)?;
154        Ok(out)
155    }
156
157    /// Get the next value of `n`. Could be used to decide on whether to re-key, etc.
158    pub fn get_next_n(&self) -> u64 {
159        self.n
160    }
161
162    /// Get underlying cipher and nonce.
163    ///
164    /// This is useful for e.g. WireGuard. Because packets may be lost or arrive out of order,
165    /// they would likely want to deal with nonces themselves.
166    pub fn extract(self) -> (C::Key, u64) {
167        (self.key, self.n)
168    }
169}