1use thiserror::Error;
4
5use crate::Sealed;
6
7#[derive(Error, Clone, Debug)]
9pub enum Error {
10 #[error("padding error")]
12 Padding,
13 #[error("ECDH error")]
15 Ecdh,
16}
17
18pub type EncryptionError = Error;
20
21pub type DecryptionError = Error;
23
/// Symmetric AES-128 key: 128 bits, i.e. 16 bytes.
pub type EncryptionKey = [u8; 128 / 8];

/// Byte length of a SEC1 *compressed* P-256 public key: a single tag byte followed by the
/// 32-byte x-coordinate.
pub const PUBLIC_KEY_LEN: usize = 1 + 32;
34
/// One step of work handed to an encryptor.
#[derive(Clone, Copy, Debug)]
pub(crate) enum EncryptOp<'a> {
    /// Another chunk of plaintext to encrypt.
    Input(&'a [u8]),
    /// End of plaintext: the encryptor emits whatever it still buffers (the AES encryptor
    /// applies its padding at this point).
    Flush,
}
42
43pub(crate) trait Sink = crate::Sink<Error>;
45
46pub(crate) trait Encryptor: Sealed {
48 fn encrypt<S>(&mut self, operation: EncryptOp, sink: &mut S) -> Result<(), S::Error>
49 where
50 S: Sink;
51}
52
53pub(crate) trait Decryptor: Sealed {
55 fn decrypt<S>(
56 &mut self,
57 input: &[u8],
58 reached_to_end: bool,
59 sink: &mut S,
60 ) -> Result<(), S::Error>
61 where
62 S: Sink;
63}
64
65pub use ecdh::{gen_echd_key_pair, PublicKey, SecretKey};
66
67pub(crate) mod ecdh {
71 use std::mem;
72
73 use p256::{ecdh::diffie_hellman, elliptic_curve};
74 use rand_core::OsRng;
75
76 use crate::encrypt::{EncryptionKey, Error, PUBLIC_KEY_LEN};
77
78 pub type SecretKey = [u8; 32];
82
83 pub type PublicKey = [u8; 33];
88
89 pub(crate) const EMPTY_PUBLIC_KEY: PublicKey = [0; PUBLIC_KEY_LEN];
91
92 impl From<elliptic_curve::Error> for Error {
93 #[inline]
94 fn from(_: elliptic_curve::Error) -> Self {
95 Self::Ecdh
96 }
97 }
98
99 #[inline]
101 pub fn gen_echd_key_pair() -> (SecretKey, PublicKey) {
102 let secret_key = p256::SecretKey::random(&mut OsRng);
103 let public_key = p256::EncodedPoint::from(secret_key.public_key()).compress();
104 (secret_key.to_bytes().into(), public_key.as_bytes().try_into().unwrap())
105 }
106
107 pub(crate) struct Keys {
110 pub(crate) public_key: PublicKey,
112 pub(crate) encryption_key: EncryptionKey,
114 }
115
116 impl Keys {
117 pub(crate) fn new(public_key: &PublicKey) -> Result<Self, Error> {
119 let public_key = p256::PublicKey::from_sec1_bytes(public_key.as_ref())?;
120 let secret_key = p256::SecretKey::random(&mut OsRng);
121
122 let encryption_key =
123 diffie_hellman(secret_key.to_nonzero_scalar(), public_key.as_affine());
124 let encryption_key = encryption_key.raw_secret_bytes().as_slice()
125 [..mem::size_of::<EncryptionKey>()]
126 .try_into()
127 .map_err(|_| Error::Ecdh)?;
128
129 let public_key = p256::EncodedPoint::from(secret_key.public_key()).compress();
130 let public_key = public_key.as_bytes().try_into().map_err(|_| Error::Ecdh)?;
131
132 Ok(Self { public_key, encryption_key })
133 }
134 }
135
136 #[inline]
139 pub(crate) fn ecdh_encryption_key(
140 secret_key: &SecretKey,
141 public_key: &PublicKey,
142 ) -> Result<EncryptionKey, Error> {
143 let secret_key = p256::SecretKey::from_slice(secret_key.as_ref())?;
144 let public_key = p256::PublicKey::from_sec1_bytes(public_key.as_ref())?;
145
146 let encryption_key = diffie_hellman(secret_key.to_nonzero_scalar(), public_key.as_affine());
147 encryption_key.raw_secret_bytes().as_slice()[..mem::size_of::<EncryptionKey>()]
148 .try_into()
149 .map_err(|_| Error::Ecdh)
150 }
151}
152
153pub(crate) use aes::{Decryptor as AesDecryptor, Encryptor as AesEncryptor};
154
155pub(crate) mod aes {
161 use aes::{Aes128Dec, Aes128Enc};
162 use cipher::{
163 block_padding::{NoPadding, Pkcs7, UnpadError},
164 inout::PadError,
165 BlockDecrypt, BlockEncrypt, KeyInit,
166 };
167
168 use crate::{
169 common::BytesBuf,
170 encrypt::{
171 Decryptor as DecryptorTrait, EncryptOp, EncryptionKey, Encryptor as EncryptorTrait,
172 Error, Sink,
173 },
174 Sealed,
175 };
176
177 const BLOCK_SIZE: usize = 16;
179
180 impl From<PadError> for Error {
181 #[inline]
182 fn from(_: PadError) -> Self {
183 Self::Padding
184 }
185 }
186
187 impl From<UnpadError> for Error {
188 #[inline]
189 fn from(_: UnpadError) -> Self {
190 Self::Padding
191 }
192 }
193
194 pub(crate) struct Encryptor {
196 inner: Aes128Enc,
197 buffer: BytesBuf,
198 }
199
200 impl Encryptor {
201 const BUFFER_LEN: usize = 16 * BLOCK_SIZE;
205
206 #[inline]
208 pub(crate) fn new(key: &EncryptionKey) -> Self {
209 let inner = Aes128Enc::new(key.into());
210 let buffer = BytesBuf::with_capacity(Self::BUFFER_LEN);
211 Self { inner, buffer }
212 }
213 }
214
215 impl EncryptorTrait for Encryptor {
216 fn encrypt<S>(&mut self, operation: EncryptOp, sink: &mut S) -> Result<(), S::Error>
217 where
218 S: Sink,
219 {
220 match operation {
221 EncryptOp::Input(mut input) => {
222 while !input.is_empty() {
223 let buffered = self.buffer.buffer(input);
224 debug_assert_ne!(
225 buffered, 0,
226 "the size of buffer needs to be greater than or equal to `BLOCK_SIZE`"
227 );
228
229 self.buffer.sink(sink, false, |buf, len| {
230 self.inner.encrypt_padded::<NoPadding>(buf, len)
231 })?;
232
233 input = &input[buffered..];
235 }
236 Ok(())
237 }
238 EncryptOp::Flush => self
239 .buffer
240 .sink(sink, true, |buf, len| self.inner.encrypt_padded::<Pkcs7>(buf, len)),
241 }
242 }
243 }
244
245 impl Sealed for Encryptor {}
246
247 pub(crate) struct Decryptor {
249 inner: Aes128Dec,
250 buffer: BytesBuf,
251 }
252
253 impl Decryptor {
254 const BUFFER_LEN: usize = 64 * BLOCK_SIZE;
258
259 #[inline]
261 pub(crate) fn new(key: &EncryptionKey) -> Self {
262 let inner = Aes128Dec::new(key.into());
263 let buffer = BytesBuf::with_capacity(Self::BUFFER_LEN);
264 Self { inner, buffer }
265 }
266 }
267
268 impl DecryptorTrait for Decryptor {
269 fn decrypt<S>(
270 &mut self,
271 mut input: &[u8],
272 reached_to_end: bool,
273 sink: &mut S,
274 ) -> Result<(), S::Error>
275 where
276 S: Sink,
277 {
278 while !input.is_empty() {
279 let buffered = self.buffer.buffer(input);
280 debug_assert_ne!(
281 buffered, 0,
282 "the size of buffer needs to be greater than or equal to `BLOCK_SIZE`"
283 );
284
285 let reached_to_end = reached_to_end && buffered == input.len();
286 self.buffer.sink(sink, reached_to_end, |buf, len| {
287 let buf = &mut buf[..len];
288 if reached_to_end {
289 self.inner.decrypt_padded::<Pkcs7>(buf)
290 } else {
291 self.inner.decrypt_padded::<NoPadding>(buf)
292 }
293 })?;
294
295 input = &input[buffered..];
297 }
298 Ok(())
299 }
300 }
301
302 impl Sealed for Decryptor {}
303
304 impl BytesBuf {
305 fn sink<S, E>(
308 &mut self,
309 sink: &mut S,
310 pad: bool,
311 handle: impl FnOnce(&mut [u8], usize) -> Result<&[u8], E>,
312 ) -> Result<(), S::Error>
313 where
314 S: Sink,
315 E: Into<Error>,
316 {
317 let len = if pad { self.len() } else { self.len() / BLOCK_SIZE * BLOCK_SIZE };
318 let buffer = self.as_buffer_mut_slice();
319
320 let bytes = handle(buffer, len).map_err(Into::into)?;
321 if !bytes.is_empty() {
322 sink.sink(bytes)?;
323 }
324 self.drain(len);
325
326 Ok(())
327 }
328 }
329}
330
331impl<T> Encryptor for Option<T>
332where
333 T: Encryptor,
334{
335 #[inline]
336 fn encrypt<S>(&mut self, operation: EncryptOp, sink: &mut S) -> Result<(), S::Error>
337 where
338 S: Sink,
339 {
340 match self {
341 Some(encryptor) => encryptor.encrypt(operation, sink),
342 None => match operation {
344 EncryptOp::Input(bytes) => sink.sink(bytes),
345 _ => Ok(()),
346 },
347 }
348 }
349}
350
351impl<T> Decryptor for Option<T>
352where
353 T: Decryptor,
354{
355 #[inline]
356 fn decrypt<S>(
357 &mut self,
358 input: &[u8],
359 reached_to_end: bool,
360 sink: &mut S,
361 ) -> Result<(), S::Error>
362 where
363 S: Sink,
364 {
365 match self {
366 Some(decryptor) => decryptor.decrypt(input, reached_to_end, sink),
367 None => sink.sink(input),
369 }
370 }
371}
372
#[cfg(test)]
mod tests {
    use std::slice;

    use crate::encrypt::{
        AesDecryptor, AesEncryptor, Decryptor, EncryptOp, EncryptionKey, Encryptor,
    };

    const KEY: EncryptionKey = [0x23; 16];

    /// Encrypts `input` twice with one encryptor: once as a single chunk, once byte by
    /// byte. Asserts that both ciphertexts agree and returns it.
    fn aes_encrypt(input: &[u8]) -> Vec<u8> {
        let mut encryptor = AesEncryptor::new(&KEY);
        let mut sink = Vec::new();
        let mut sink_mul = Vec::new();

        encryptor.encrypt(EncryptOp::Input(input), &mut sink).unwrap();
        encryptor.encrypt(EncryptOp::Flush, &mut sink).unwrap();

        for byte in input {
            encryptor.encrypt(EncryptOp::Input(slice::from_ref(byte)), &mut sink_mul).unwrap();
        }
        encryptor.encrypt(EncryptOp::Flush, &mut sink_mul).unwrap();

        assert_eq!(sink, sink_mul);
        sink
    }

    /// Decrypts `input` twice with one decryptor: once as a single final chunk, once byte
    /// by byte with only the last byte marked as the end. Asserts that both plaintexts
    /// agree and returns it.
    fn aes_decrypt(input: &[u8]) -> Vec<u8> {
        let mut decryptor = AesDecryptor::new(&KEY);
        let mut sink = Vec::new();
        let mut sink_mul = Vec::new();

        decryptor.decrypt(input, true, &mut sink).unwrap();

        for (idx, byte) in input.iter().enumerate() {
            decryptor
                .decrypt(slice::from_ref(byte), idx == input.len() - 1, &mut sink_mul)
                .unwrap();
        }

        assert_eq!(sink, sink_mul);
        sink
    }

    #[test]
    fn test_aes() {
        let data = b"Hello World";
        assert_eq!(aes_decrypt(&aes_encrypt(data)), data);

        let data = b"123456789ABCDEFG";
        assert_eq!(aes_decrypt(&aes_encrypt(data)), data);

        let data = b"Hello, I'm Tangent, nice to meet you.";
        assert_eq!(aes_decrypt(&aes_encrypt(data)), data);
    }

    #[test]
    fn test_aes_empty_plaintext() {
        // PKCS#7 always adds padding, so even empty plaintext yields one full block.
        let ciphertext = aes_encrypt(b"");
        assert_eq!(ciphertext.len(), 16);
        assert!(aes_decrypt(&ciphertext).is_empty());
    }

    #[test]
    fn test_aes_larger_than_internal_buffers() {
        // The encryptor buffers 16 blocks and the decryptor 64 blocks. 2000 bytes forces
        // several refills of both and is block-aligned, so a whole padding block is added.
        let data: Vec<u8> = (0..2000u32).map(|i| (i % 251) as u8).collect();
        let ciphertext = aes_encrypt(&data);
        assert_eq!(ciphertext.len(), data.len() + 16);
        assert_eq!(aes_decrypt(&ciphertext), data);
    }

    #[test]
    fn test_aes_truncated_ciphertext_is_rejected() {
        let ciphertext = aes_encrypt(b"Hello World");
        let mut decryptor = AesDecryptor::new(&KEY);
        let mut sink = Vec::new();
        let truncated = &ciphertext[..ciphertext.len() - 1];
        assert!(decryptor.decrypt(truncated, true, &mut sink).is_err());
    }
}