distant_net/common/transport/framed/codec/
encryption.rs

1use std::{fmt, io};
2
3use derive_more::Display;
4
5use super::{Codec, Frame};
6use crate::common::{SecretKey, SecretKey32};
7
8/// Represents the type of encryption for a [`EncryptionCodec`]
9#[derive(
10    Copy, Clone, Debug, Display, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize,
11)]
12pub enum EncryptionType {
13    /// ChaCha20Poly1305 variant with an extended 192-bit (24-byte) nonce
14    #[display(fmt = "xchacha20poly1305")]
15    XChaCha20Poly1305,
16
17    /// Indicates an unknown encryption type for use in handshakes
18    #[display(fmt = "unknown")]
19    #[serde(other)]
20    Unknown,
21}
22
23impl EncryptionType {
24    /// Generates bytes for a secret key based on the encryption type
25    pub fn generate_secret_key_bytes(&self) -> io::Result<Vec<u8>> {
26        match self {
27            Self::XChaCha20Poly1305 => Ok(SecretKey::<32>::generate()
28                .unwrap()
29                .into_heap_secret_key()
30                .unprotected_into_bytes()),
31            Self::Unknown => Err(io::Error::new(
32                io::ErrorKind::InvalidInput,
33                "Unknown encryption type",
34            )),
35        }
36    }
37
38    /// Returns a list of all variants of the type *except* unknown.
39    pub const fn known_variants() -> &'static [EncryptionType] {
40        &[EncryptionType::XChaCha20Poly1305]
41    }
42
43    /// Returns true if type is unknown
44    pub fn is_unknown(&self) -> bool {
45        matches!(self, Self::Unknown)
46    }
47
48    /// Creates a new [`EncryptionCodec`] for this type, failing if this type is unknown or the key
49    /// is an invalid length
50    pub fn new_codec(&self, key: &[u8]) -> io::Result<EncryptionCodec> {
51        EncryptionCodec::from_type_and_key(*self, key)
52    }
53}
54
55/// Represents the codec that encodes & decodes frames by encrypting/decrypting them
56#[derive(Clone)]
57pub enum EncryptionCodec {
58    /// ChaCha20Poly1305 variant with an extended 192-bit (24-byte) nonce, using
59    /// [`XChaCha20Poly1305`] underneath
60    XChaCha20Poly1305 {
61        cipher: chacha20poly1305::XChaCha20Poly1305,
62    },
63}
64
65impl EncryptionCodec {
66    /// Makes a new [`EncryptionCodec`] based on the [`EncryptionType`] and `key`, returning an
67    /// error if the key is invalid for the encryption type or the type is unknown
68    pub fn from_type_and_key(ty: EncryptionType, key: &[u8]) -> io::Result<EncryptionCodec> {
69        match ty {
70            EncryptionType::XChaCha20Poly1305 => {
71                use chacha20poly1305::{KeyInit, XChaCha20Poly1305};
72                let cipher = XChaCha20Poly1305::new_from_slice(key)
73                    .map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x))?;
74                Ok(Self::XChaCha20Poly1305 { cipher })
75            }
76            EncryptionType::Unknown => Err(io::Error::new(
77                io::ErrorKind::InvalidInput,
78                "Encryption type is unknown",
79            )),
80        }
81    }
82
83    pub fn new_xchacha20poly1305(secret_key: SecretKey32) -> EncryptionCodec {
84        // NOTE: This should never fail as we are enforcing the key size at compile time
85        Self::from_type_and_key(
86            EncryptionType::XChaCha20Poly1305,
87            secret_key.unprotected_as_bytes(),
88        )
89        .unwrap()
90    }
91
92    /// Returns the encryption type associa ted with the codec
93    pub fn ty(&self) -> EncryptionType {
94        match self {
95            Self::XChaCha20Poly1305 { .. } => EncryptionType::XChaCha20Poly1305,
96        }
97    }
98
99    /// Size of nonce (in bytes) associated with the encryption algorithm
100    pub const fn nonce_size(&self) -> usize {
101        match self {
102            // XChaCha20Poly1305 uses a 192-bit (24-byte) key
103            Self::XChaCha20Poly1305 { .. } => 24,
104        }
105    }
106
107    /// Generates a new nonce for the encryption algorithm
108    fn generate_nonce_bytes(&self) -> Vec<u8> {
109        // NOTE: As seen in orion, with a 24-bit nonce, it's safe to generate instead of
110        //       maintaining a stateful counter due to its size (24-byte secret key generation
111        //       will never panic)
112        match self {
113            Self::XChaCha20Poly1305 { .. } => SecretKey::<24>::generate()
114                .unwrap()
115                .into_heap_secret_key()
116                .unprotected_into_bytes(),
117        }
118    }
119}
120
121impl fmt::Debug for EncryptionCodec {
122    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123        f.debug_struct("EncryptionCodec")
124            .field("cipher", &"**OMITTED**".to_string())
125            .field("nonce_size", &self.nonce_size())
126            .field("ty", &self.ty().to_string())
127            .finish()
128    }
129}
130
131impl Codec for EncryptionCodec {
132    fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
133        let nonce_bytes = self.generate_nonce_bytes();
134
135        Ok(match self {
136            Self::XChaCha20Poly1305 { cipher } => {
137                use chacha20poly1305::aead::Aead;
138                use chacha20poly1305::XNonce;
139                let item = frame.into_item();
140                let nonce = XNonce::from_slice(&nonce_bytes);
141
142                // Encrypt the frame's item as our ciphertext
143                let ciphertext = cipher
144                    .encrypt(nonce, item.as_ref())
145                    .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Encryption failed"))?;
146
147                // Start our frame with the nonce at the beginning
148                let mut frame = Frame::from(nonce_bytes);
149                frame.extend(ciphertext);
150
151                frame
152            }
153        })
154    }
155
156    fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
157        let nonce_size = self.nonce_size();
158        if frame.len() <= nonce_size {
159            return Err(io::Error::new(
160                io::ErrorKind::InvalidData,
161                format!("Frame cannot have length less than {}", nonce_size + 1),
162            ));
163        }
164
165        // Grab the nonce from the front of the frame, and then use it with the remainder
166        // of the frame to tease out the decrypted frame item
167        let item = match self {
168            Self::XChaCha20Poly1305 { cipher } => {
169                use chacha20poly1305::aead::Aead;
170                use chacha20poly1305::XNonce;
171                let nonce = XNonce::from_slice(&frame.as_item()[..nonce_size]);
172                cipher
173                    .decrypt(nonce, &frame.as_item()[nonce_size..])
174                    .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Decryption failed"))?
175            }
176        };
177
178        Ok(Frame::from(item))
179    }
180}
181
#[cfg(test)]
mod tests {
    use test_log::test;

    use super::*;

    #[test]
    fn encryption_type_should_display_as_lowercase_identifier() {
        assert_eq!(
            EncryptionType::XChaCha20Poly1305.to_string(),
            "xchacha20poly1305"
        );
        assert_eq!(EncryptionType::Unknown.to_string(), "unknown");
    }

    #[test]
    fn known_variants_should_exclude_unknown() {
        assert!(EncryptionType::known_variants()
            .iter()
            .all(|ty| !ty.is_unknown()));
    }

    #[test]
    fn generate_secret_key_bytes_should_produce_32_bytes_for_xchacha20poly1305() {
        let key = EncryptionType::XChaCha20Poly1305
            .generate_secret_key_bytes()
            .unwrap();
        assert_eq!(key.len(), 32);
    }

    #[test]
    fn generate_secret_key_bytes_should_fail_if_type_is_unknown() {
        match EncryptionType::Unknown.generate_secret_key_bytes() {
            Err(x) if x.kind() == io::ErrorKind::InvalidInput => {}
            x => panic!("Unexpected result: {:?}", x),
        }
    }

    #[test]
    fn from_type_and_key_should_fail_if_type_is_unknown() {
        let key = EncryptionType::XChaCha20Poly1305
            .generate_secret_key_bytes()
            .unwrap();
        match EncryptionCodec::from_type_and_key(EncryptionType::Unknown, &key) {
            Err(x) if x.kind() == io::ErrorKind::InvalidInput => {}
            x => panic!("Unexpected result: {:?}", x),
        }
    }

    #[test]
    fn from_type_and_key_should_fail_if_key_has_invalid_length() {
        match EncryptionCodec::from_type_and_key(EncryptionType::XChaCha20Poly1305, &[0u8; 16]) {
            Err(x) if x.kind() == io::ErrorKind::InvalidInput => {}
            x => panic!("Unexpected result: {:?}", x),
        }
    }

    #[test]
    fn debug_should_not_expose_cipher() {
        let ty = EncryptionType::XChaCha20Poly1305;
        let key = ty.generate_secret_key_bytes().unwrap();
        let codec = EncryptionCodec::from_type_and_key(ty, &key).unwrap();

        let text = format!("{:?}", codec);
        assert!(text.contains("**OMITTED**"));
        assert!(text.contains("xchacha20poly1305"));
    }

    #[test]
    fn encode_should_build_a_frame_containing_a_nonce_and_ciphertext() {
        let ty = EncryptionType::XChaCha20Poly1305;
        let key = ty.generate_secret_key_bytes().unwrap();
        let mut codec = EncryptionCodec::from_type_and_key(ty, &key).unwrap();

        let frame = codec
            .encode(Frame::new(b"hello world"))
            .expect("Failed to encode");

        let nonce = &frame.as_item()[..codec.nonce_size()];
        let ciphertext = &frame.as_item()[codec.nonce_size()..];

        // Manually build our key & cipher so we can decrypt the frame manually to ensure it is
        // correct
        let item = {
            use chacha20poly1305::aead::Aead;
            use chacha20poly1305::{KeyInit, XChaCha20Poly1305, XNonce};
            let cipher = XChaCha20Poly1305::new_from_slice(&key).unwrap();
            cipher
                .decrypt(XNonce::from_slice(nonce), ciphertext)
                .expect("Failed to decrypt")
        };
        assert_eq!(item, b"hello world");
    }

    #[test]
    fn decode_should_fail_if_frame_length_is_smaller_than_nonce_plus_data() {
        let ty = EncryptionType::XChaCha20Poly1305;
        let key = ty.generate_secret_key_bytes().unwrap();
        let mut codec = EncryptionCodec::from_type_and_key(ty, &key).unwrap();

        // NONCE_SIZE + 1 is minimum for frame length
        let frame = Frame::from(b"a".repeat(codec.nonce_size()));

        let result = codec.decode(frame);
        match result {
            Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
            x => panic!("Unexpected result: {:?}", x),
        }
    }

    #[test]
    fn decode_should_fail_if_unable_to_decrypt_frame_item() {
        let ty = EncryptionType::XChaCha20Poly1305;
        let key = ty.generate_secret_key_bytes().unwrap();
        let mut codec = EncryptionCodec::from_type_and_key(ty, &key).unwrap();

        // Long enough to pass the length check, but not valid ciphertext
        let frame = Frame::from(b"a".repeat(codec.nonce_size() + 1));

        let result = codec.decode(frame);
        match result {
            Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
            x => panic!("Unexpected result: {:?}", x),
        }
    }

    #[test]
    fn decode_should_fail_if_ciphertext_has_been_tampered_with() {
        let ty = EncryptionType::XChaCha20Poly1305;
        let key = ty.generate_secret_key_bytes().unwrap();
        let mut codec = EncryptionCodec::from_type_and_key(ty, &key).unwrap();

        let frame = codec
            .encode(Frame::new(b"hello, world"))
            .expect("Failed to encode");

        // Flip bits in the last byte (part of the authentication tag)
        let mut bytes = frame.as_item().to_vec();
        let last = bytes.len() - 1;
        bytes[last] ^= 0xFF;

        match codec.decode(Frame::from(bytes)) {
            Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
            x => panic!("Unexpected result: {:?}", x),
        }
    }

    #[test]
    fn decode_should_fail_if_key_differs_from_encoding_key() {
        let ty = EncryptionType::XChaCha20Poly1305;
        let encode_key = ty.generate_secret_key_bytes().unwrap();
        let decode_key = ty.generate_secret_key_bytes().unwrap();
        let mut encoder = EncryptionCodec::from_type_and_key(ty, &encode_key).unwrap();
        let mut decoder = EncryptionCodec::from_type_and_key(ty, &decode_key).unwrap();

        let frame = encoder
            .encode(Frame::new(b"hello, world"))
            .expect("Failed to encode");

        match decoder.decode(frame) {
            Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
            x => panic!("Unexpected result: {:?}", x),
        }
    }

    #[test]
    fn decode_should_round_trip_an_empty_frame() {
        let ty = EncryptionType::XChaCha20Poly1305;
        let key = ty.generate_secret_key_bytes().unwrap();
        let mut codec = EncryptionCodec::from_type_and_key(ty, &key).unwrap();

        let frame = codec.encode(Frame::new(b"")).expect("Failed to encode");
        let frame = codec.decode(frame).expect("Failed to decode");
        assert!(frame.as_item().is_empty());
    }

    #[test]
    fn decode_should_return_decrypted_frame_when_successful() {
        let ty = EncryptionType::XChaCha20Poly1305;
        let key = ty.generate_secret_key_bytes().unwrap();
        let mut codec = EncryptionCodec::from_type_and_key(ty, &key).unwrap();

        let frame = codec
            .encode(Frame::new(b"hello, world"))
            .expect("Failed to encode");

        let frame = codec.decode(frame).expect("Failed to decode");
        assert_eq!(frame, b"hello, world");
    }
}
259}