distant_net/common/transport/framed/codec/
encryption.rs1use std::{fmt, io};
2
3use derive_more::Display;
4
5use super::{Codec, Frame};
6use crate::common::{SecretKey, SecretKey32};
7
8#[derive(
10 Copy, Clone, Debug, Display, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize,
11)]
12pub enum EncryptionType {
13 #[display(fmt = "xchacha20poly1305")]
15 XChaCha20Poly1305,
16
17 #[display(fmt = "unknown")]
19 #[serde(other)]
20 Unknown,
21}
22
23impl EncryptionType {
24 pub fn generate_secret_key_bytes(&self) -> io::Result<Vec<u8>> {
26 match self {
27 Self::XChaCha20Poly1305 => Ok(SecretKey::<32>::generate()
28 .unwrap()
29 .into_heap_secret_key()
30 .unprotected_into_bytes()),
31 Self::Unknown => Err(io::Error::new(
32 io::ErrorKind::InvalidInput,
33 "Unknown encryption type",
34 )),
35 }
36 }
37
38 pub const fn known_variants() -> &'static [EncryptionType] {
40 &[EncryptionType::XChaCha20Poly1305]
41 }
42
43 pub fn is_unknown(&self) -> bool {
45 matches!(self, Self::Unknown)
46 }
47
48 pub fn new_codec(&self, key: &[u8]) -> io::Result<EncryptionCodec> {
51 EncryptionCodec::from_type_and_key(*self, key)
52 }
53}
54
55#[derive(Clone)]
57pub enum EncryptionCodec {
58 XChaCha20Poly1305 {
61 cipher: chacha20poly1305::XChaCha20Poly1305,
62 },
63}
64
65impl EncryptionCodec {
66 pub fn from_type_and_key(ty: EncryptionType, key: &[u8]) -> io::Result<EncryptionCodec> {
69 match ty {
70 EncryptionType::XChaCha20Poly1305 => {
71 use chacha20poly1305::{KeyInit, XChaCha20Poly1305};
72 let cipher = XChaCha20Poly1305::new_from_slice(key)
73 .map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x))?;
74 Ok(Self::XChaCha20Poly1305 { cipher })
75 }
76 EncryptionType::Unknown => Err(io::Error::new(
77 io::ErrorKind::InvalidInput,
78 "Encryption type is unknown",
79 )),
80 }
81 }
82
83 pub fn new_xchacha20poly1305(secret_key: SecretKey32) -> EncryptionCodec {
84 Self::from_type_and_key(
86 EncryptionType::XChaCha20Poly1305,
87 secret_key.unprotected_as_bytes(),
88 )
89 .unwrap()
90 }
91
92 pub fn ty(&self) -> EncryptionType {
94 match self {
95 Self::XChaCha20Poly1305 { .. } => EncryptionType::XChaCha20Poly1305,
96 }
97 }
98
99 pub const fn nonce_size(&self) -> usize {
101 match self {
102 Self::XChaCha20Poly1305 { .. } => 24,
104 }
105 }
106
107 fn generate_nonce_bytes(&self) -> Vec<u8> {
109 match self {
113 Self::XChaCha20Poly1305 { .. } => SecretKey::<24>::generate()
114 .unwrap()
115 .into_heap_secret_key()
116 .unprotected_into_bytes(),
117 }
118 }
119}
120
121impl fmt::Debug for EncryptionCodec {
122 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123 f.debug_struct("EncryptionCodec")
124 .field("cipher", &"**OMITTED**".to_string())
125 .field("nonce_size", &self.nonce_size())
126 .field("ty", &self.ty().to_string())
127 .finish()
128 }
129}
130
131impl Codec for EncryptionCodec {
132 fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
133 let nonce_bytes = self.generate_nonce_bytes();
134
135 Ok(match self {
136 Self::XChaCha20Poly1305 { cipher } => {
137 use chacha20poly1305::aead::Aead;
138 use chacha20poly1305::XNonce;
139 let item = frame.into_item();
140 let nonce = XNonce::from_slice(&nonce_bytes);
141
142 let ciphertext = cipher
144 .encrypt(nonce, item.as_ref())
145 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Encryption failed"))?;
146
147 let mut frame = Frame::from(nonce_bytes);
149 frame.extend(ciphertext);
150
151 frame
152 }
153 })
154 }
155
156 fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
157 let nonce_size = self.nonce_size();
158 if frame.len() <= nonce_size {
159 return Err(io::Error::new(
160 io::ErrorKind::InvalidData,
161 format!("Frame cannot have length less than {}", nonce_size + 1),
162 ));
163 }
164
165 let item = match self {
168 Self::XChaCha20Poly1305 { cipher } => {
169 use chacha20poly1305::aead::Aead;
170 use chacha20poly1305::XNonce;
171 let nonce = XNonce::from_slice(&frame.as_item()[..nonce_size]);
172 cipher
173 .decrypt(nonce, &frame.as_item()[nonce_size..])
174 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Decryption failed"))?
175 }
176 };
177
178 Ok(Frame::from(item))
179 }
180}
181
#[cfg(test)]
mod tests {
    use test_log::test;

    use super::*;

    /// Builds a fresh XChaCha20Poly1305 codec, returning the raw key alongside it.
    fn new_codec() -> (Vec<u8>, EncryptionCodec) {
        let ty = EncryptionType::XChaCha20Poly1305;
        let key = ty.generate_secret_key_bytes().unwrap();
        let codec = EncryptionCodec::from_type_and_key(ty, &key).unwrap();
        (key, codec)
    }

    #[test]
    fn generate_secret_key_bytes_should_produce_a_32_byte_key_for_xchacha20poly1305() {
        let key = EncryptionType::XChaCha20Poly1305
            .generate_secret_key_bytes()
            .unwrap();
        assert_eq!(key.len(), 32);
    }

    #[test]
    fn generate_secret_key_bytes_should_fail_for_unknown_type() {
        let err = EncryptionType::Unknown
            .generate_secret_key_bytes()
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[test]
    fn from_type_and_key_should_fail_if_type_is_unknown() {
        let key = EncryptionType::XChaCha20Poly1305
            .generate_secret_key_bytes()
            .unwrap();
        let err = EncryptionCodec::from_type_and_key(EncryptionType::Unknown, &key).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[test]
    fn from_type_and_key_should_fail_if_key_has_wrong_length() {
        let err = EncryptionCodec::from_type_and_key(EncryptionType::XChaCha20Poly1305, &[0; 16])
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[test]
    fn encode_should_build_a_frame_containing_a_nonce_and_ciphertext() {
        let (key, mut codec) = new_codec();

        let frame = codec
            .encode(Frame::new(b"hello world"))
            .expect("Failed to encode");

        let (nonce, ciphertext) = frame.as_item().split_at(codec.nonce_size());

        let item = {
            use chacha20poly1305::aead::Aead;
            use chacha20poly1305::{KeyInit, XChaCha20Poly1305, XNonce};
            let cipher = XChaCha20Poly1305::new_from_slice(&key).unwrap();
            cipher
                .decrypt(XNonce::from_slice(nonce), ciphertext)
                .expect("Failed to decrypt")
        };
        assert_eq!(item, b"hello world");
    }

    #[test]
    fn encode_should_use_a_different_nonce_for_each_frame() {
        let (_, mut codec) = new_codec();

        let first = codec
            .encode(Frame::new(b"hello world"))
            .expect("Failed to encode");
        let second = codec
            .encode(Frame::new(b"hello world"))
            .expect("Failed to encode");

        let nonce_size = codec.nonce_size();
        assert_ne!(
            &first.as_item()[..nonce_size],
            &second.as_item()[..nonce_size]
        );
    }

    #[test]
    fn decode_should_fail_if_frame_length_is_smaller_than_nonce_plus_data() {
        let (_, mut codec) = new_codec();

        let frame = Frame::from(b"a".repeat(codec.nonce_size()));

        let result = codec.decode(frame);
        match result {
            Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
            x => panic!("Unexpected result: {:?}", x),
        }
    }

    #[test]
    fn decode_should_fail_if_unable_to_decrypt_frame_item() {
        let (_, mut codec) = new_codec();

        let frame = Frame::from(b"a".repeat(codec.nonce_size() + 1));

        let result = codec.decode(frame);
        match result {
            Err(x) if x.kind() == io::ErrorKind::InvalidData => {}
            x => panic!("Unexpected result: {:?}", x),
        }
    }

    #[test]
    fn decode_should_return_decrypted_frame_when_successful() {
        let (_, mut codec) = new_codec();

        let frame = codec
            .encode(Frame::new(b"hello, world"))
            .expect("Failed to encode");

        let frame = codec.decode(frame).expect("Failed to decode");
        assert_eq!(frame, b"hello, world");
    }

    #[test]
    fn decode_should_round_trip_an_empty_frame() {
        let (_, mut codec) = new_codec();

        let frame = codec.encode(Frame::new(b"")).expect("Failed to encode");
        let frame = codec.decode(frame).expect("Failed to decode");
        assert!(frame.as_item().is_empty());
    }
}