use thiserror::Error;
use crate::Sealed;
/// Errors reported by the encryption and decryption code in this module.
#[derive(Error, Clone, Debug)]
pub enum Error {
/// Padding or unpadding a block failed.
///
/// Produced by the `From<PadError>` and `From<UnpadError>` conversions in the `aes` module.
#[error("padding error")]
Padding,
/// An ECDH-related step failed.
///
/// Produced from `elliptic_curve::Error` and from failed key-length conversions in the
/// `ecdh` module.
#[error("ECDH error")]
Ecdh,
}
/// Alias of [`Error`]; both error aliases name the same type.
pub type EncryptionError = Error;
/// Alias of [`Error`]; both error aliases name the same type.
pub type DecryptionError = Error;
/// Raw 16-byte symmetric key, used to key the AES-128 cipher in the `aes` module.
pub type EncryptionKey = [u8; 16];
/// Length in bytes of a [`PublicKey`] (a compressed SEC1-encoded P-256 point).
pub const PUBLIC_KEY_LEN: usize = 33;
/// A single request handed to an [`Encryptor`].
#[derive(Debug, Clone, Copy)]
pub(crate) enum EncryptOp<'a> {
/// Bytes to feed into the encryptor.
Input(&'a [u8]),
/// No more input follows: the AES encryptor pads (PKCS#7) and emits whatever it still
/// buffers; the `Option::None` pass-through encryptor treats this as a no-op.
Flush,
}
/// Trait alias for `crate::Sink` specialised to this module's [`Error`] type.
pub(crate) trait Sink = crate::Sink<Error>;
/// Crate-internal interface for streaming encryptors.
///
/// `Sealed` is a supertrait, so every implementor must also implement `Sealed`.
pub(crate) trait Encryptor: Sealed {
/// Processes one [`EncryptOp`], writing any bytes it produces to `sink`.
///
/// # Errors
///
/// Returns `S::Error`; which conditions trigger it depends on the implementor and the sink.
fn encrypt<S>(&mut self, operation: EncryptOp, sink: &mut S) -> Result<(), S::Error>
where
S: Sink;
}
/// Crate-internal interface for streaming decryptors.
///
/// `Sealed` is a supertrait, so every implementor must also implement `Sealed`.
pub(crate) trait Decryptor: Sealed {
/// Processes the next chunk of `input`, writing any bytes it produces to `sink`.
///
/// `reached_to_end` signals that `input` is the last chunk of the stream; the AES
/// implementation in this file strips PKCS#7 padding only when it is `true`.
///
/// # Errors
///
/// Returns `S::Error`; which conditions trigger it depends on the implementor and the sink.
fn decrypt<S>(
&mut self,
input: &[u8],
reached_to_end: bool,
sink: &mut S,
) -> Result<(), S::Error>
where
S: Sink;
}
pub use ecdh::{gen_echd_key_pair, PublicKey, SecretKey};
pub(crate) mod ecdh {
use std::mem;
use p256::{ecdh::diffie_hellman, elliptic_curve};
use rand_core::OsRng;
use crate::encrypt::{EncryptionKey, Error, PUBLIC_KEY_LEN};
/// Raw 32-byte P-256 secret key (parsed with `p256::SecretKey::from_slice`).
pub type SecretKey = [u8; 32];
/// 33-byte public key: a compressed SEC1-encoded P-256 point.
pub type PublicKey = [u8; 33];
/// All-zero public key value.
///
/// NOTE(review): presumably a placeholder meaning "no key"; confirm at the use sites.
pub(crate) const EMPTY_PUBLIC_KEY: PublicKey = [0; PUBLIC_KEY_LEN];
impl From<elliptic_curve::Error> for Error {
#[inline]
fn from(_: elliptic_curve::Error) -> Self {
Self::Ecdh
}
}
/// Generates a fresh random P-256 key pair.
///
/// Returns `(secret_key, public_key)`, where the secret key is the raw 32-byte scalar and the
/// public key is the 33-byte compressed SEC1 encoding of the matching curve point.
///
/// The `echd` spelling in the name is kept as-is for API compatibility.
///
/// # Panics
///
/// Panics only if the compressed encoding is not exactly 33 bytes long, which would be an
/// invariant violation in the underlying curve library rather than a recoverable error.
#[inline]
pub fn gen_echd_key_pair() -> (SecretKey, PublicKey) {
    let secret_key = p256::SecretKey::random(&mut OsRng);
    let encoded_point = p256::EncodedPoint::from(secret_key.public_key()).compress();
    // State the invariant instead of a bare `unwrap()`: a compressed, non-identity P-256 point
    // is one tag byte followed by the 32-byte x-coordinate.
    let public_key: PublicKey = encoded_point
        .as_bytes()
        .try_into()
        .expect("a compressed SEC1 P-256 public key is always 33 bytes long");
    (secret_key.to_bytes().into(), public_key)
}
/// Output of one ephemeral ECDH exchange, as produced by `Keys::new`.
pub(crate) struct Keys {
/// Compressed SEC1 encoding of the freshly generated (ephemeral) public key.
pub(crate) public_key: PublicKey,
/// Symmetric key taken from the leading bytes of the ECDH shared secret.
pub(crate) encryption_key: EncryptionKey,
}
impl Keys {
pub(crate) fn new(public_key: &PublicKey) -> Result<Self, Error> {
let public_key = p256::PublicKey::from_sec1_bytes(public_key.as_ref())?;
let secret_key = p256::SecretKey::random(&mut OsRng);
let encryption_key =
diffie_hellman(secret_key.to_nonzero_scalar(), public_key.as_affine());
let encryption_key = encryption_key.raw_secret_bytes().as_slice()
[..mem::size_of::<EncryptionKey>()]
.try_into()
.map_err(|_| Error::Ecdh)?;
let public_key = p256::EncodedPoint::from(secret_key.public_key()).compress();
let public_key = public_key.as_bytes().try_into().map_err(|_| Error::Ecdh)?;
Ok(Self { public_key, encryption_key })
}
}
#[inline]
pub(crate) fn ecdh_encryption_key(
secret_key: &SecretKey,
public_key: &PublicKey,
) -> Result<EncryptionKey, Error> {
let secret_key = p256::SecretKey::from_slice(secret_key.as_ref())?;
let public_key = p256::PublicKey::from_sec1_bytes(public_key.as_ref())?;
let encryption_key = diffie_hellman(secret_key.to_nonzero_scalar(), public_key.as_affine());
encryption_key.raw_secret_bytes().as_slice()[..mem::size_of::<EncryptionKey>()]
.try_into()
.map_err(|_| Error::Ecdh)
}
}
pub(crate) use aes::{Decryptor as AesDecryptor, Encryptor as AesEncryptor};
pub(crate) mod aes {
use aes::{Aes128Dec, Aes128Enc};
use cipher::{
block_padding::{NoPadding, Pkcs7, UnpadError},
inout::PadError,
BlockDecrypt, BlockEncrypt, KeyInit,
};
use crate::{
common::BytesBuf,
encrypt::{
Decryptor as DecryptorTrait, EncryptOp, EncryptionKey, Encryptor as EncryptorTrait,
Error, Sink,
},
Sealed,
};
/// AES block size in bytes; un-padded buffer contents are processed in multiples of this.
const BLOCK_SIZE: usize = 16;
impl From<PadError> for Error {
#[inline]
fn from(_: PadError) -> Self {
Self::Padding
}
}
impl From<UnpadError> for Error {
#[inline]
fn from(_: UnpadError) -> Self {
Self::Padding
}
}
/// Streaming AES-128 encryptor that stages input in a buffer and emits encrypted blocks.
///
/// NOTE(review): the block cipher is applied to the buffered blocks directly; no chaining mode
/// or IV is visible in this file, so this looks like ECB. Confirm that is intended.
pub(crate) struct Encryptor {
// AES-128 block cipher keyed at construction time.
inner: Aes128Enc,
// Staging buffer for input that has not been encrypted yet.
buffer: BytesBuf,
}
impl Encryptor {
    /// Capacity requested for the staging buffer: sixteen AES blocks.
    const BUFFER_LEN: usize = 16 * BLOCK_SIZE;

    /// Creates an encryptor keyed with `key` and an empty staging buffer.
    #[inline]
    pub(crate) fn new(key: &EncryptionKey) -> Self {
        Self {
            inner: Aes128Enc::new(key.into()),
            buffer: BytesBuf::with_capacity(Self::BUFFER_LEN),
        }
    }
}
impl EncryptorTrait for Encryptor {
    /// Encrypts according to `operation`, writing ciphertext to `sink`.
    ///
    /// * `EncryptOp::Input`: the bytes are staged in the internal buffer; every time bytes are
    ///   staged, all complete blocks currently buffered are encrypted without padding and
    ///   emitted. A trailing partial block stays buffered.
    /// * `EncryptOp::Flush`: whatever is still buffered is PKCS#7-padded, encrypted and emitted.
    ///
    /// # Errors
    ///
    /// Returns `S::Error` when padding fails or when the sink rejects the output.
    fn encrypt<S>(&mut self, operation: EncryptOp, sink: &mut S) -> Result<(), S::Error>
    where
        S: Sink,
    {
        // Borrow the two fields separately so the closures can use the cipher while the
        // staging buffer is mutably borrowed.
        let cipher = &self.inner;
        let staging = &mut self.buffer;

        let mut pending = match operation {
            EncryptOp::Flush => {
                return staging
                    .sink(sink, true, |buf, len| cipher.encrypt_padded::<Pkcs7>(buf, len));
            }
            EncryptOp::Input(input) => input,
        };

        while !pending.is_empty() {
            let taken = staging.buffer(pending);
            debug_assert_ne!(
                taken, 0,
                "the size of buffer needs to be greater than or equal to `BLOCK_SIZE`"
            );
            staging.sink(sink, false, |buf, len| cipher.encrypt_padded::<NoPadding>(buf, len))?;
            pending = &pending[taken..];
        }
        Ok(())
    }
}
// `Sealed` is a supertrait of the crate's `Encryptor` trait, so this marker impl is required.
impl Sealed for Encryptor {}
/// Streaming AES-128 decryptor that stages input in a buffer and emits decrypted blocks.
pub(crate) struct Decryptor {
// AES-128 block cipher keyed at construction time.
inner: Aes128Dec,
// Staging buffer for input that has not been decrypted yet.
buffer: BytesBuf,
}
impl Decryptor {
    /// Capacity requested for the staging buffer: sixty-four AES blocks.
    const BUFFER_LEN: usize = 64 * BLOCK_SIZE;

    /// Creates a decryptor keyed with `key` and an empty staging buffer.
    #[inline]
    pub(crate) fn new(key: &EncryptionKey) -> Self {
        Self {
            inner: Aes128Dec::new(key.into()),
            buffer: BytesBuf::with_capacity(Self::BUFFER_LEN),
        }
    }
}
impl DecryptorTrait for Decryptor {
    /// Decrypts `input`, writing plaintext to `sink`.
    ///
    /// Input is staged in the internal buffer and decrypted in whole blocks. PKCS#7 padding is
    /// stripped only on the step that consumes the last bytes of `input` while
    /// `reached_to_end` is `true`; every earlier step decrypts without touching padding.
    ///
    /// NOTE(review): when `input` is empty this is a no-op even if `reached_to_end` is `true`,
    /// so bytes still sitting in the buffer are neither emitted nor reported. Confirm callers
    /// always pass the end flag together with the final non-empty chunk.
    ///
    /// # Errors
    ///
    /// Returns `S::Error` when unpadding fails or when the sink rejects the output.
    fn decrypt<S>(
        &mut self,
        input: &[u8],
        reached_to_end: bool,
        sink: &mut S,
    ) -> Result<(), S::Error>
    where
        S: Sink,
    {
        // Borrow the two fields separately so the closure can use the cipher while the
        // staging buffer is mutably borrowed.
        let cipher = &self.inner;
        let staging = &mut self.buffer;

        let mut pending = input;
        while !pending.is_empty() {
            let taken = staging.buffer(pending);
            debug_assert_ne!(
                taken, 0,
                "the size of buffer needs to be greater than or equal to `BLOCK_SIZE`"
            );
            // Only the step that swallows the tail of the last chunk may strip padding.
            let is_final = reached_to_end && taken == pending.len();
            staging.sink(sink, is_final, |buf, len| {
                let ciphertext = &mut buf[..len];
                if is_final {
                    cipher.decrypt_padded::<Pkcs7>(ciphertext)
                } else {
                    cipher.decrypt_padded::<NoPadding>(ciphertext)
                }
            })?;
            pending = &pending[taken..];
        }
        Ok(())
    }
}
// `Sealed` is a supertrait of the crate's `Decryptor` trait, so this marker impl is required.
impl Sealed for Decryptor {}
impl BytesBuf {
    /// Runs `handle` over the buffered bytes, forwards its output to `sink`, then drains the
    /// processed bytes from the buffer.
    ///
    /// * `pad == true`: all buffered bytes are handed to `handle`.
    /// * `pad == false`: only the largest prefix that is a multiple of `BLOCK_SIZE` is handed
    ///   over; the remainder stays buffered.
    ///
    /// `handle` receives the mutable buffer slice and the number of bytes to process, and
    /// returns the bytes to emit. Empty output is not forwarded to the sink.
    ///
    /// # Errors
    ///
    /// Returns `S::Error` when `handle` fails (converted through [`Error`]) or when the sink
    /// fails. In both cases nothing is drained from the buffer.
    fn sink<S, E>(
        &mut self,
        sink: &mut S,
        pad: bool,
        handle: impl FnOnce(&mut [u8], usize) -> Result<&[u8], E>,
    ) -> Result<(), S::Error>
    where
        S: Sink,
        E: Into<Error>,
    {
        let buffered = self.len();
        // Without padding, only complete blocks may be processed.
        let take = if pad { buffered } else { buffered - buffered % BLOCK_SIZE };

        let output = handle(self.as_buffer_mut_slice(), take).map_err(Into::into)?;
        if !output.is_empty() {
            sink.sink(output)?;
        }

        self.drain(take);
        Ok(())
    }
}
}
/// Optional encryption: `Some` delegates to the wrapped encryptor, `None` is a pass-through.
impl<T> Encryptor for Option<T>
where
    T: Encryptor,
{
    /// Delegates to the inner encryptor when present. Otherwise input bytes are forwarded to
    /// `sink` unchanged and a flush does nothing.
    #[inline]
    fn encrypt<S>(&mut self, operation: EncryptOp, sink: &mut S) -> Result<(), S::Error>
    where
        S: Sink,
    {
        match (self.as_mut(), operation) {
            (Some(inner), op) => inner.encrypt(op, sink),
            (None, EncryptOp::Input(bytes)) => sink.sink(bytes),
            (None, EncryptOp::Flush) => Ok(()),
        }
    }
}
/// Optional decryption: `Some` delegates to the wrapped decryptor, `None` is a pass-through.
impl<T> Decryptor for Option<T>
where
    T: Decryptor,
{
    /// Delegates to the inner decryptor when present; otherwise forwards `input` to `sink`
    /// unchanged, ignoring `reached_to_end`.
    #[inline]
    fn decrypt<S>(
        &mut self,
        input: &[u8],
        reached_to_end: bool,
        sink: &mut S,
    ) -> Result<(), S::Error>
    where
        S: Sink,
    {
        if let Some(inner) = self {
            inner.decrypt(input, reached_to_end, sink)
        } else {
            sink.sink(input)
        }
    }
}
#[cfg(test)]
mod tests {
use std::slice;
use crate::encrypt::{
AesDecryptor, AesEncryptor, Decryptor, EncryptOp, EncryptionKey, Encryptor,
};
// Fixed key shared by the encrypt and decrypt helpers below.
const KEY: EncryptionKey = [0x23; 16];
/// Encrypts `input` twice with the same encryptor, once in a single call and once byte by
/// byte, asserts that both ciphertexts match, and returns the ciphertext.
fn aes_encrypt(input: &[u8]) -> Vec<u8> {
    let mut encryptor = AesEncryptor::new(&KEY);

    // Pass 1: the whole input in one call, then flush.
    let mut one_shot = Vec::new();
    encryptor.encrypt(EncryptOp::Input(input), &mut one_shot).unwrap();
    encryptor.encrypt(EncryptOp::Flush, &mut one_shot).unwrap();

    // Pass 2: one byte per call, then flush.
    let mut byte_wise = Vec::new();
    for byte in input.iter() {
        let single = slice::from_ref(byte);
        encryptor.encrypt(EncryptOp::Input(single), &mut byte_wise).unwrap();
    }
    encryptor.encrypt(EncryptOp::Flush, &mut byte_wise).unwrap();

    assert_eq!(one_shot, byte_wise);
    one_shot
}
/// Decrypts `input` twice with the same decryptor, once in a single call and once byte by
/// byte (flagging only the last byte as the end), asserts that both plaintexts match, and
/// returns the plaintext.
fn aes_decrypt(input: &[u8]) -> Vec<u8> {
    let mut decryptor = AesDecryptor::new(&KEY);

    // Pass 1: the whole input in one final call.
    let mut one_shot = Vec::new();
    decryptor.decrypt(input, true, &mut one_shot).unwrap();

    // Pass 2: one byte per call; only the very last byte carries the end flag.
    let mut byte_wise = Vec::new();
    if let Some((last, head)) = input.split_last() {
        for byte in head {
            decryptor.decrypt(slice::from_ref(byte), false, &mut byte_wise).unwrap();
        }
        decryptor.decrypt(slice::from_ref(last), true, &mut byte_wise).unwrap();
    }

    assert_eq!(one_shot, byte_wise);
    one_shot
}
/// Round-trips inputs shorter than a block, exactly one block long, and spanning several
/// blocks.
#[test]
fn test_aes() {
    let samples: [&[u8]; 3] = [
        b"Hello World",
        b"123456789ABCDEFG",
        b"Hello, I'm Tangent, nice to meet you.",
    ];
    for sample in samples {
        assert_eq!(aes_decrypt(&aes_encrypt(sample)), sample);
    }
}
}