#[cfg(any(feature = "aes", feature = "aes-std"))]
use aes::cipher::{KeyIvInit, StreamCipher, StreamCipherError};
#[cfg(any(feature = "aes", feature = "aes-std"))]
use aes::{Aes128, Aes192, Aes256};
use kvstructs::bytes::Bytes;
/// Block size in bytes reported by [`block_size`] for AES, and the length of the IV
/// produced by [`random_iv`].
pub const BLOCK_SIZE: usize = 16;
/// Selects which cipher (if any) is applied to data.
///
/// Stored on the wire as an `int32` (field 1 of the `Encryption` message). The `Aes`
/// variant only exists when the `aes` or `aes-std` feature is enabled.
// NOTE(review): confirm that the `prost::Enumeration` derive carries the `#[cfg]` on `Aes`
// into the code it generates; if it does not, builds without the AES features may fail.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
#[repr(i32)]
pub enum EncryptionAlgorithm {
    /// No encryption: the `encrypt*` functions leave the data unchanged.
    None = 0,
    /// AES in CTR mode; the key length (16/24/32 bytes) selects AES-128/192/256.
    #[cfg(any(feature = "aes", feature = "aes-std"))]
    Aes = 1,
}
impl EncryptionAlgorithm {
pub fn as_str_name(&self) -> &'static str {
match self {
EncryptionAlgorithm::None => "None",
#[cfg(any(feature = "aes", feature = "aes-std"))]
EncryptionAlgorithm::Aes => "Aes",
}
}
}
impl EncryptionAlgorithm {
    /// Returns `true` when no cipher is selected ([`EncryptionAlgorithm::None`]).
    #[inline]
    pub const fn is_none(&self) -> bool {
        matches!(self, Self::None)
    }

    /// Returns `true` when a real cipher is selected; always the negation of
    /// [`Self::is_none`].
    #[inline]
    pub const fn is_some(&self) -> bool {
        !self.is_none()
    }
}
/// Encryption settings: the selected algorithm plus its secret bytes.
///
/// Built with `Encryption::new` / `Encryption::aes` and the `set_*` builder methods, and
/// serialized through a hand-written `prost::Message` impl (`algo` = field 1,
/// `secret` = field 2).
#[derive(Clone)]
pub struct Encryption {
    /// Algorithm to apply; `EncryptionAlgorithm::None` means no encryption.
    algo: EncryptionAlgorithm,
    /// Secret bytes (presumably the cipher key — TODO confirm against callers).
    secret: kvstructs::bytes::Bytes,
}

/// Hand-written instead of derived so the secret bytes never appear in logs or panic
/// messages; only their length is shown.
impl core::fmt::Debug for Encryption {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("Encryption")
            .field("algo", &self.algo)
            .field(
                "secret",
                &format_args!("<redacted {} bytes>", self.secret.len()),
            )
            .finish()
    }
}
impl Encryption {
#[inline]
pub const fn new() -> Self {
Self {
algo: EncryptionAlgorithm::None,
secret: Bytes::new(),
}
}
#[inline]
pub fn set_secret(mut self, secret: Bytes) -> Self {
self.secret = secret;
self
}
#[inline]
pub const fn set_algorithm(mut self, algo: EncryptionAlgorithm) -> Self {
self.algo = algo;
self
}
#[cfg(any(feature = "aes", feature = "aes-std"))]
#[inline]
pub fn aes(secret: impl Into<Bytes>) -> Self {
Self {
algo: EncryptionAlgorithm::Aes,
secret: secret.into(),
}
}
#[inline]
pub const fn algorithm(&self) -> EncryptionAlgorithm {
self.algo
}
#[inline]
pub const fn is_none(&self) -> bool {
self.algo.is_none()
}
#[inline]
pub const fn is_some(&self) -> bool {
self.algo.is_some()
}
#[inline]
pub fn secret(&self) -> &[u8] {
self.secret.as_ref()
}
#[inline]
pub fn secret_bytes(&self) -> kvstructs::bytes::Bytes {
self.secret.clone()
}
#[inline]
pub const fn block_size(&self) -> usize {
block_size(self.algo)
}
}
/// Hand-written protobuf encoding for [`Encryption`]:
/// field 1 = `algo` (int32), field 2 = `secret` (bytes).
impl prost::Message for Encryption {
    /// Writes each field only when it differs from its default (non-`None` algorithm,
    /// non-empty secret).
    #[allow(unused_variables)]
    fn encode_raw<B>(&self, buf: &mut B)
    where
        B: prost::bytes::BufMut,
    {
        if self.algo != EncryptionAlgorithm::default() {
            prost::encoding::int32::encode(1u32, &(self.algo as i32), buf);
        }
        if !self.secret.is_empty() {
            prost::encoding::bytes::encode(2u32, &self.secret, buf);
        }
    }

    /// Decodes one field identified by `tag` into `self`; unknown tags are skipped.
    ///
    /// # Errors
    /// Returns the `prost::DecodeError` from the underlying decoder, annotated with the
    /// struct and field name.
    #[allow(unused_variables)]
    fn merge_field<B>(
        &mut self,
        tag: u32,
        wire_type: prost::encoding::WireType,
        buf: &mut B,
        ctx: prost::encoding::DecodeContext,
    ) -> ::core::result::Result<(), prost::DecodeError>
    where
        B: prost::bytes::Buf,
    {
        const STRUCT_NAME: &str = "Encryption";
        match tag {
            1u32 => {
                // Decode into a real local and store it back. Merging into
                // `&mut (self.algo as i32)` would write into a temporary and silently
                // discard the decoded algorithm.
                let mut raw = self.algo as i32;
                prost::encoding::int32::merge(wire_type, &mut raw, buf, ctx).map_err(
                    |mut error| {
                        error.push(STRUCT_NAME, "algo");
                        error
                    },
                )?;
                // Uses the `From<i32>` conversion defined in this module: `1` becomes `Aes`
                // when an AES feature is enabled, any other value becomes `None`.
                self.algo = EncryptionAlgorithm::from(raw);
                Ok(())
            }
            2u32 => {
                let value = &mut self.secret;
                prost::encoding::bytes::merge(wire_type, value, buf, ctx).map_err(|mut error| {
                    error.push(STRUCT_NAME, "secret");
                    error
                })
            }
            _ => prost::encoding::skip_field(wire_type, tag, buf, ctx),
        }
    }

    /// Byte length of what `encode_raw` would write.
    #[inline]
    fn encoded_len(&self) -> usize {
        let algo_len = if self.algo != EncryptionAlgorithm::default() {
            prost::encoding::int32::encoded_len(1u32, &(self.algo as i32))
        } else {
            0
        };
        let secret_len = if !self.secret.is_empty() {
            prost::encoding::bytes::encoded_len(2u32, &self.secret)
        } else {
            0
        };
        algo_len + secret_len
    }

    /// Resets both fields to their defaults.
    fn clear(&mut self) {
        self.algo = EncryptionAlgorithm::default();
        self.secret.clear();
    }
}
/// AES-128 in CTR mode (`ctr::Ctr64BE`); used by `aes_encrypt_in` for 16-byte keys.
#[cfg(any(feature = "aes", feature = "aes-std"))]
pub type Aes128Ctr = ctr::Ctr64BE<Aes128>;
/// AES-192 in CTR mode (`ctr::Ctr64BE`); used by `aes_encrypt_in` for 24-byte keys.
#[cfg(any(feature = "aes", feature = "aes-std"))]
pub type Aes192Ctr = ctr::Ctr64BE<Aes192>;
/// AES-256 in CTR mode (`ctr::Ctr64BE`); used by `aes_encrypt_in` for 32-byte keys.
#[cfg(any(feature = "aes", feature = "aes-std"))]
pub type Aes256Ctr = ctr::Ctr64BE<Aes256>;
/// Errors produced by the AES-CTR path (`aes_encrypt_in`).
#[cfg(any(feature = "aes", feature = "aes-std"))]
#[derive(Debug, Copy, Clone)]
pub enum AesError {
    /// Returned by `new_from_slices` when the key or IV slice length is rejected.
    InvalidLength(aes::cipher::InvalidLength),
    /// The key was not 16, 24 or 32 bytes long; carries the length that was supplied.
    KeySizeError(usize),
    /// Returned by `try_apply_keystream` when the keystream cannot be applied.
    StreamCipherError(StreamCipherError),
}
/// Human-readable messages for [`AesError`]; the two cipher-originated variants are
/// prefixed differently to preserve the existing wording.
#[cfg(any(feature = "aes", feature = "aes-std"))]
impl core::fmt::Display for AesError {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        match *self {
            AesError::InvalidLength(inner) => write!(f, "{}", inner),
            AesError::StreamCipherError(inner) => write!(f, "aes: {}", inner),
            AesError::KeySizeError(len) => write!(
                f,
                "aes: invalid key size {}, only supports 16, 24, 32 length for key",
                len
            ),
        }
    }
}
// `AesError` is a `std::error::Error` only when both `std` and an AES feature are enabled.
#[cfg(all(any(feature = "aes", feature = "aes-std"), feature = "std"))]
impl std::error::Error for AesError {}
/// Error returned by [`encrypt`], [`encrypt_to`], [`encrypt_to_vec`] and the
/// [`Encryptor`] trait methods.
#[derive(Debug, Copy, Clone)]
pub enum EncryptError {
    /// Failure reported by the AES backend (only with an AES feature enabled).
    #[cfg(any(feature = "aes", feature = "aes-std"))]
    Aes(AesError),
    /// `encrypt_to` received source and destination slices of different lengths.
    LengthMismatch {
        /// Length of the source slice.
        src: usize,
        /// Length of the destination slice.
        dst: usize,
    },
}
/// Human-readable messages for [`EncryptError`]; AES errors are forwarded to their own
/// `Display` output.
impl core::fmt::Display for EncryptError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match *self {
            #[cfg(any(feature = "aes", feature = "aes-std"))]
            EncryptError::Aes(inner) => write!(f, "{}", inner),
            EncryptError::LengthMismatch { src, dst } => write!(
                f,
                "aes length mismatch: the length of source is {} and the length of destination {}",
                src, dst
            ),
        }
    }
}
// With the `std` feature, `EncryptError` can be used wherever a `std::error::Error` is expected.
#[cfg(feature = "std")]
impl std::error::Error for EncryptError {}
pub trait Encryptor {
fn encrypt(
&mut self,
key: &[u8],
iv: &[u8],
algo: EncryptionAlgorithm,
) -> Result<(), EncryptError>
where
Self: AsMut<[u8]>,
{
let data = self.as_mut();
encrypt(data, key, iv, algo)
}
fn encrypt_to_vec(
&self,
key: &[u8],
iv: &[u8],
algo: EncryptionAlgorithm,
) -> Result<Vec<u8>, EncryptError>
where
Self: AsRef<[u8]>,
{
let src = self.as_ref();
encrypt_to_vec(src, key, iv, algo)
}
fn encrypt_to(
&self,
dst: &mut [u8],
key: &[u8],
iv: &[u8],
algo: EncryptionAlgorithm,
) -> Result<(), EncryptError>
where
Self: AsRef<[u8]>,
{
let src = self.as_ref();
encrypt_to(dst, src, key, iv, algo)
}
}
impl<T> Encryptor for T {}
/// Returns the cipher block size in bytes for `algo`: `0` when no encryption is selected,
/// [`BLOCK_SIZE`] for AES.
#[inline(always)]
pub const fn block_size(algo: EncryptionAlgorithm) -> usize {
    match algo {
        #[cfg(any(feature = "aes", feature = "aes-std"))]
        EncryptionAlgorithm::Aes => BLOCK_SIZE,
        EncryptionAlgorithm::None => 0usize,
    }
}
#[inline]
pub fn encrypt(
_data: &mut [u8],
_key: &[u8],
_iv: &[u8],
algo: EncryptionAlgorithm,
) -> Result<(), EncryptError> {
match algo {
#[cfg(any(feature = "aes", feature = "aes-std"))]
EncryptionAlgorithm::Aes => aes_encrypt_in(_data, _key, _iv),
_ => Ok(()),
}
}
#[inline]
pub fn encrypt_to_vec(
src: &[u8],
_key: &[u8],
_iv: &[u8],
algo: EncryptionAlgorithm,
) -> Result<Vec<u8>, EncryptError> {
let mut dst = src.to_vec();
match algo {
#[cfg(any(feature = "aes", feature = "aes-std"))]
EncryptionAlgorithm::Aes => aes_encrypt_in(dst.as_mut(), _key, _iv).map(|_| dst),
_ => Ok(dst),
}
}
#[inline]
pub fn encrypt_to(
dst: &mut [u8],
src: &[u8],
_key: &[u8],
_iv: &[u8],
algo: EncryptionAlgorithm,
) -> Result<(), EncryptError> {
if dst.len() != src.len() {
return Err(EncryptError::LengthMismatch {
src: src.len(),
dst: dst.len(),
});
}
dst.copy_from_slice(src);
match algo {
#[cfg(any(feature = "aes", feature = "aes-std"))]
EncryptionAlgorithm::Aes => aes_encrypt_in(dst, _key, _iv),
_ => Ok(()),
}
}
/// Applies the AES-CTR keystream for `key`/`iv` to `dst` in place.
///
/// The AES variant is chosen by key length: 16 → AES-128, 24 → AES-192, 32 → AES-256.
/// Any other length yields `AesError::KeySizeError`. IV-length problems surface as
/// `AesError::InvalidLength`.
#[cfg(any(feature = "aes", feature = "aes-std"))]
#[inline(always)]
fn aes_encrypt_in(dst: &mut [u8], key: &[u8], iv: &[u8]) -> Result<(), EncryptError> {
    // Builds cipher `C` from the key/IV slices and applies its keystream to `buf`.
    fn apply_keystream<C: KeyIvInit + StreamCipher>(
        buf: &mut [u8],
        key: &[u8],
        iv: &[u8],
    ) -> Result<(), EncryptError> {
        let mut cipher = C::new_from_slices(key, iv)
            .map_err(|e| EncryptError::Aes(AesError::InvalidLength(e)))?;
        cipher
            .try_apply_keystream(buf)
            .map_err(|e| EncryptError::Aes(AesError::StreamCipherError(e)))
    }

    match key.len() {
        16 => apply_keystream::<Aes128Ctr>(dst, key, iv),
        24 => apply_keystream::<Aes192Ctr>(dst, key, iv),
        32 => apply_keystream::<Aes256Ctr>(dst, key, iv),
        other => Err(EncryptError::Aes(AesError::KeySizeError(other))),
    }
}
/// Returns a freshly generated random IV of [`BLOCK_SIZE`] bytes.
///
/// With the `std` feature the thread-local RNG (`rand::thread_rng`) is used. Otherwise the
/// operating-system RNG (`rand::rngs::OsRng`) is used.
pub fn random_iv() -> [u8; BLOCK_SIZE] {
    use rand::RngCore;

    let mut iv = [0u8; BLOCK_SIZE];
    // One `fill_bytes` call writes the whole buffer. This avoids sampling the array
    // element by element through `Rng::gen`, and both cfg paths now share the same shape.
    #[cfg(feature = "std")]
    rand::thread_rng().fill_bytes(&mut iv);
    #[cfg(not(feature = "std"))]
    rand::rngs::OsRng.fill_bytes(&mut iv);
    iv
}
macro_rules! impl_encryption_algo_converter {
($($ty:ty),+ $(,)?) => {
$(
impl From<$ty> for EncryptionAlgorithm {
fn from(val: $ty) -> EncryptionAlgorithm {
match val {
#[cfg(any(feature = "aes", feature = "aes-std"))]
1 => EncryptionAlgorithm::Aes,
_ => EncryptionAlgorithm::None,
}
}
}
)*
};
}
impl_encryption_algo_converter!(i8, i16, i32, i64, isize, i128, u8, u16, u32, u64, usize, u128);
#[cfg(test)]
mod test {
    use super::*;
    use rand::{thread_rng, Rng};

    /// AES-CTR round trip: applying the cipher a second time with the same key/IV restores
    /// the plaintext. Gated because it names `EncryptionAlgorithm::Aes`, which only exists
    /// with an AES feature.
    #[cfg(any(feature = "aes", feature = "aes-std"))]
    #[test]
    fn test_encrypt() {
        let mut rng = thread_rng();
        let key = rng.gen::<[u8; 32]>();
        let iv = random_iv();
        let mut src = [0u8; 1024];
        rng.fill(&mut src);

        let mut dst = vec![0u8; 1024];
        encrypt_to(
            dst.as_mut_slice(),
            &src,
            &key,
            &iv,
            EncryptionAlgorithm::Aes,
        )
        .unwrap();
        // For 1 KiB of random plaintext, identical output would mean no keystream was applied.
        assert_ne!(dst.as_slice(), &src[..]);

        let act = encrypt_to_vec(dst.as_slice(), &key, &iv, EncryptionAlgorithm::Aes).unwrap();
        assert_eq!(src.to_vec(), act);
    }

    /// `EncryptionAlgorithm::None` must pass data through unchanged.
    #[test]
    fn test_encrypt_none() {
        let mut rng = thread_rng();
        let key = rng.gen::<[u8; 32]>();
        let iv = random_iv();
        let mut src = [0u8; 1024];
        rng.fill(&mut src);

        let mut dst = vec![0u8; 1024];
        encrypt_to(
            dst.as_mut_slice(),
            &src,
            &key,
            &iv,
            EncryptionAlgorithm::None,
        )
        .unwrap();
        assert_eq!(dst.as_slice(), &src[..]);

        let act = encrypt_to_vec(dst.as_slice(), &key, &iv, EncryptionAlgorithm::None).unwrap();
        assert_eq!(src.to_vec(), act);
    }

    /// `encrypt_to` rejects destination slices whose length differs from the source.
    #[test]
    fn test_encrypt_to_length_mismatch() {
        let mut dst = [0u8; 4];
        let err = encrypt_to(&mut dst, &[0u8; 8], &[], &[], EncryptionAlgorithm::None)
            .unwrap_err();
        assert!(matches!(
            err,
            EncryptError::LengthMismatch { src: 8, dst: 4 }
        ));
    }
}