use alloc::vec::Vec;
use core::marker::PhantomData;
use hybrid_array::Array;
use crate::{
codec::{Decode, Encode},
encryption::EncryptionPrimitive,
};
/// Ciphertext of a `T` value, paired with the nonce it was encrypted with.
///
/// `Encrypted::encrypt` encodes the plaintext with codec `C` and encrypts the encoding with
/// primitive `E`; `try_decrypt` reverses both steps. `T` and `C` are not stored — they only
/// exist at the type level, so a value can only be decrypted/decoded as the type it was
/// created for.
pub struct Encrypted<T, E: EncryptionPrimitive, C> {
    /// Bytes produced by `E` from the `C`-encoded plaintext.
    ciphertext: Vec<u8>,
    /// Nonce used during encryption; it is passed back to `E::decrypt` by `try_decrypt`.
    nonce: Array<u8, E::NonceSize>,
    /// Ties `T` and `C` to the type without storing them. Using `fn() -> (T, C)` (rather than
    /// `(T, C)`) keeps `Encrypted` covariant and `Send`/`Sync` independently of `T` and `C`.
    _marker: PhantomData<fn() -> (T, C)>,
}
// Implemented by hand: `#[derive(Clone)]` would demand `T: Clone`, `E: Clone` and `C: Clone`,
// even though none of those types is actually stored.
impl<T, E: EncryptionPrimitive, C> Clone for Encrypted<T, E, C>
where
    Array<u8, E::NonceSize>: Clone,
{
    /// Duplicates the ciphertext buffer and the nonce.
    fn clone(&self) -> Self {
        let Self {
            ciphertext, nonce, ..
        } = self;
        Self {
            ciphertext: ciphertext.clone(),
            nonce: nonce.clone(),
            _marker: PhantomData,
        }
    }
}
// Implemented by hand so equality does not require `T`, `E` or `C` to be comparable;
// only the stored bytes take part.
impl<T, E: EncryptionPrimitive, C> PartialEq for Encrypted<T, E, C>
where
    Array<u8, E::NonceSize>: PartialEq,
{
    /// Equal when the ciphertext bytes match and the nonces match (ciphertext compared first).
    fn eq(&self, other: &Self) -> bool {
        (&self.ciphertext, &self.nonce) == (&other.ciphertext, &other.nonce)
    }
}

impl<T, E: EncryptionPrimitive, C> Eq for Encrypted<T, E, C>
where
    Array<u8, E::NonceSize>: Eq,
{
}
// Implemented by hand so that `T`, `E` and `C` need not implement `Debug`.
impl<T, E: EncryptionPrimitive, C> core::fmt::Debug for Encrypted<T, E, C> {
    /// Shows the raw ciphertext bytes and the nonce as a byte slice.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let nonce_bytes: &[u8] = self.nonce.as_slice();
        let mut builder = f.debug_struct("Encrypted");
        builder.field("ciphertext", &self.ciphertext);
        builder.field("nonce", &nonce_bytes);
        builder.finish()
    }
}
impl<T, E: EncryptionPrimitive, C> Encrypted<T, E, C> {
    /// Pairs an already-produced `ciphertext` with its `nonce`.
    ///
    /// Nothing checks that the bytes really are a `C`-encoded `T` encrypted by `E`; crate
    /// callers are trusted to pair them correctly.
    #[must_use]
    pub(crate) fn new(ciphertext: Vec<u8>, nonce: Array<u8, E::NonceSize>) -> Self {
        Self {
            nonce,
            ciphertext,
            _marker: PhantomData,
        }
    }

    /// Encodes `plaintext` with `C`, then encrypts the encoding under `key` with
    /// `E::encrypt`, keeping the nonce that `E::encrypt` hands back.
    ///
    /// # Panics
    ///
    /// Panics with `"encoding failed"` if `C::encode` returns an error.
    #[must_use]
    #[allow(clippy::expect_used)] // an encoding failure is treated as a programming error
    pub fn encrypt(key: &E::Key, plaintext: &T) -> Self
    where
        C: Encode<T>,
    {
        let encoded_bytes = C::encode(plaintext).expect("encoding failed");
        let (ciphertext, nonce) = E::encrypt(key, &encoded_bytes);
        Self {
            ciphertext,
            nonce,
            _marker: PhantomData,
        }
    }

    /// Like [`Self::encrypt`], but encrypts with the caller-supplied `nonce`, which is
    /// copied into the result.
    ///
    /// NOTE(review): for typical AEAD primitives, reusing a (key, nonce) pair is
    /// catastrophic; callers must guarantee uniqueness — confirm `E`'s requirements.
    ///
    /// # Panics
    ///
    /// Panics with `"encoding failed"` if `C::encode` returns an error.
    #[must_use]
    #[allow(clippy::expect_used)] // an encoding failure is treated as a programming error
    pub fn encrypt_with_nonce(
        key: &E::Key,
        nonce: &Array<u8, E::NonceSize>,
        plaintext: &T,
    ) -> Self
    where
        C: Encode<T>,
    {
        let encoded_bytes = C::encode(plaintext).expect("encoding failed");
        Self {
            ciphertext: E::encrypt_with_nonce(key, nonce, &encoded_bytes),
            nonce: nonce.clone(),
            _marker: PhantomData,
        }
    }

    /// Decrypts with `key` and the stored nonce, then decodes the plaintext into a `T`.
    ///
    /// # Errors
    ///
    /// * [`DecryptionError::AuthenticationFailed`] if `E::decrypt` returns an error (the
    ///   primitive's error value is discarded).
    /// * [`DecryptionError::DecodeError`] if `C::decode` rejects the decrypted bytes.
    pub fn try_decrypt(&self, key: &E::Key) -> Result<T, DecryptionError>
    where
        C: Decode<T>,
    {
        E::decrypt(key, &self.nonce, &self.ciphertext)
            .map_err(|_| DecryptionError::AuthenticationFailed)
            .and_then(|decrypted| {
                C::decode(&decrypted).map_err(|_| DecryptionError::DecodeError)
            })
    }

    /// Raw ciphertext bytes.
    #[must_use]
    pub fn ciphertext(&self) -> &[u8] {
        self.ciphertext.as_slice()
    }

    /// Nonce this value was encrypted with.
    #[must_use]
    pub const fn nonce(&self) -> &Array<u8, E::NonceSize> {
        &self.nonce
    }
}
/// Reasons `Encrypted::try_decrypt` can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DecryptionError {
    /// The encryption primitive's `decrypt` returned an error. Its error value is discarded,
    /// so the precise cause (wrong key, modified data, mismatched nonce, …) is not reported.
    AuthenticationFailed,
    /// Decryption succeeded, but the codec could not decode the plaintext.
    DecodeError,
}

impl core::fmt::Display for DecryptionError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(match self {
            Self::AuthenticationFailed => "authentication failed",
            Self::DecodeError => "plaintext decode error",
        })
    }
}

/// Allows `DecryptionError` to be boxed as `dyn Error`, converted by `?` into error types
/// that wrap `core::error::Error`, and reported as an error source. It has no inner source.
impl core::error::Error for DecryptionError {}
/// Public escape hatch for assembling an encrypted value from raw parts, e.g. ciphertext and
/// nonce received over the wire or loaded from storage.
///
/// "Unchecked" means that nothing verifies the parts belong together: the ciphertext is not
/// authenticated, and there is no check that it was produced for `T`, codec `C`, or this nonce.
/// Any mismatch only shows up later, as an error from decryption or decoding.
pub trait EncryptedUnchecked<T, E: EncryptionPrimitive, C> {
    /// Wraps `ciphertext` and `nonce` as-is, without any validation.
    fn from_unchecked_parts(ciphertext: Vec<u8>, nonce: Array<u8, E::NonceSize>) -> Self;
}
impl<T, E: EncryptionPrimitive, C> EncryptedUnchecked<T, E, C> for Encrypted<T, E, C> {
fn from_unchecked_parts(ciphertext: Vec<u8>, nonce: Array<u8, E::NonceSize>) -> Self {
Self::new(ciphertext, nonce)
}
}
#[cfg(feature = "serde")]
impl<T, E: EncryptionPrimitive, C> serde::Serialize for Encrypted<T, E, C>
where
    Array<u8, E::NonceSize>: serde::Serialize,
{
    /// Writes a two-field struct named `Encrypted` with fields `ciphertext` and `nonce`,
    /// in that order. `T` and `C` are not part of the serialized form.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        use serde::ser::SerializeStruct;

        let Self {
            ciphertext, nonce, ..
        } = self;
        let mut fields = serializer.serialize_struct("Encrypted", 2)?;
        fields.serialize_field("ciphertext", ciphertext)?;
        fields.serialize_field("nonce", nonce)?;
        fields.end()
    }
}
#[cfg(feature = "serde")]
impl<'de, T, E: EncryptionPrimitive, C> serde::Deserialize<'de> for Encrypted<T, E, C>
where
    Array<u8, E::NonceSize>: serde::Deserialize<'de>,
{
    /// Reads the `Encrypted { ciphertext, nonce }` struct written by the `Serialize` impl.
    ///
    /// Both input shapes that `deserialize_struct` can produce are accepted:
    /// * a map, as in self-describing formats (JSON, CBOR, …). Keys may be borrowed, owned or
    ///   transient strings, byte strings, or field indices. Unknown keys are skipped and
    ///   duplicate keys are rejected.
    /// * a sequence `[ciphertext, nonce]`, as handed over by non-self-describing formats
    ///   (e.g. bincode, postcard).
    ///
    /// # Errors
    ///
    /// Returns `D::Error` if a field is missing or duplicated, the sequence is too short, or
    /// either field fails to deserialize.
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        use serde::de::{IgnoredAny, MapAccess, SeqAccess, Visitor};

        const FIELDS: &[&str] = &["ciphertext", "nonce"];

        /// Field identifiers of the serialized struct.
        enum Field {
            Ciphertext,
            Nonce,
            /// Any key this impl does not recognise; its value is skipped.
            Ignore,
        }

        impl<'de> serde::Deserialize<'de> for Field {
            fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
                struct FieldVisitor;

                impl<'de> Visitor<'de> for FieldVisitor {
                    type Value = Field;

                    fn expecting(
                        &self,
                        formatter: &mut core::fmt::Formatter<'_>,
                    ) -> core::fmt::Result {
                        formatter.write_str("`ciphertext` or `nonce`")
                    }

                    // `visit_str` receives borrowed, owned and transient strings alike.
                    // Deserializing a `&str` key only works when the input can lend it.
                    fn visit_str<Er: serde::de::Error>(self, value: &str) -> Result<Field, Er> {
                        Ok(match value {
                            "ciphertext" => Field::Ciphertext,
                            "nonce" => Field::Nonce,
                            _ => Field::Ignore,
                        })
                    }

                    fn visit_bytes<Er: serde::de::Error>(
                        self,
                        value: &[u8],
                    ) -> Result<Field, Er> {
                        Ok(match value {
                            b"ciphertext" => Field::Ciphertext,
                            b"nonce" => Field::Nonce,
                            _ => Field::Ignore,
                        })
                    }

                    // Some formats identify struct fields by their declaration index.
                    fn visit_u64<Er: serde::de::Error>(self, value: u64) -> Result<Field, Er> {
                        Ok(match value {
                            0 => Field::Ciphertext,
                            1 => Field::Nonce,
                            _ => Field::Ignore,
                        })
                    }
                }

                deserializer.deserialize_identifier(FieldVisitor)
            }
        }

        struct EncryptedVisitor<T, E: EncryptionPrimitive, C>(PhantomData<fn() -> (T, E, C)>);

        impl<'de, T, E: EncryptionPrimitive, C> Visitor<'de> for EncryptedVisitor<T, E, C>
        where
            Array<u8, E::NonceSize>: serde::Deserialize<'de>,
        {
            type Value = Encrypted<T, E, C>;

            fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                formatter.write_str("struct Encrypted")
            }

            // Non-self-describing formats deliver the struct as a plain sequence of
            // fields in declaration order.
            fn visit_seq<A: SeqAccess<'de>>(
                self,
                mut seq: A,
            ) -> Result<Encrypted<T, E, C>, A::Error> {
                let ciphertext: Vec<u8> = seq
                    .next_element()?
                    .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
                let nonce: Array<u8, E::NonceSize> = seq
                    .next_element()?
                    .ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
                Ok(Encrypted::new(ciphertext, nonce))
            }

            fn visit_map<V: MapAccess<'de>>(
                self,
                mut map: V,
            ) -> Result<Encrypted<T, E, C>, V::Error> {
                let mut ciphertext: Option<Vec<u8>> = None;
                let mut nonce: Option<Array<u8, E::NonceSize>> = None;
                while let Some(field) = map.next_key::<Field>()? {
                    match field {
                        Field::Ciphertext => {
                            if ciphertext.is_some() {
                                return Err(serde::de::Error::duplicate_field("ciphertext"));
                            }
                            ciphertext = Some(map.next_value()?);
                        }
                        Field::Nonce => {
                            if nonce.is_some() {
                                return Err(serde::de::Error::duplicate_field("nonce"));
                            }
                            nonce = Some(map.next_value()?);
                        }
                        Field::Ignore => {
                            let _: IgnoredAny = map.next_value()?;
                        }
                    }
                }
                let ciphertext =
                    ciphertext.ok_or_else(|| serde::de::Error::missing_field("ciphertext"))?;
                let nonce = nonce.ok_or_else(|| serde::de::Error::missing_field("nonce"))?;
                Ok(Encrypted::new(ciphertext, nonce))
            }
        }

        deserializer.deserialize_struct("Encrypted", FIELDS, EncryptedVisitor(PhantomData))
    }
}
#[cfg(feature = "arbitrary")]
impl<'a, T, E: EncryptionPrimitive, C> arbitrary::Arbitrary<'a> for Encrypted<T, E, C>
where
    Array<u8, E::NonceSize>: arbitrary::Arbitrary<'a>,
{
    /// Builds a value from arbitrary ciphertext bytes followed by an arbitrary nonce.
    ///
    /// The bytes are not produced by `E`, so the result is presumably not decryptable. It is
    /// intended for fuzzing (de)serialization and the failure paths of `try_decrypt`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // `Unstructured::arbitrary` is an inherent method, so unlike `Vec::arbitrary(u)` it
        // does not require the `Arbitrary` trait to be imported. It forwards to the same
        // `Arbitrary::arbitrary` impls, so the generated values are identical.
        let ciphertext: Vec<u8> = u.arbitrary()?;
        let nonce: Array<u8, E::NonceSize> = u.arbitrary()?;
        Ok(Self::new(ciphertext, nonce))
    }
}
#[cfg(feature = "bolero")]
impl<T: 'static, E: EncryptionPrimitive + 'static, C: 'static> bolero_generator::TypeGenerator
    for Encrypted<T, E, C>
where
    Array<u8, E::NonceSize>: bolero_generator::TypeGenerator,
{
    /// Generates random ciphertext bytes followed by a random nonce. Returns `None` when the
    /// driver cannot produce either part.
    fn generate<D: bolero_generator::Driver>(driver: &mut D) -> Option<Self> {
        // Fully-qualified calls: the `TypeGenerator` trait is not imported in this file, so
        // `Vec::generate` / `Array::generate` would not resolve on concrete types.
        let ciphertext = <Vec<u8> as bolero_generator::TypeGenerator>::generate(driver)?;
        let nonce =
            <Array<u8, E::NonceSize> as bolero_generator::TypeGenerator>::generate(driver)?;
        Some(Self::new(ciphertext, nonce))
    }
}
#[cfg(feature = "proptest")]
impl<T: 'static, E: EncryptionPrimitive + 'static, C: 'static> proptest::arbitrary::Arbitrary
    for Encrypted<T, E, C>
where
    Array<u8, E::NonceSize>: core::fmt::Debug,
{
    type Parameters = ();
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Strategy producing 0..256 random ciphertext bytes together with a random nonce of
    /// exactly `E::NonceSize` bytes.
    #[allow(clippy::expect_used)] // cannot fail: the nonce strategy yields exactly NonceSize bytes
    fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
        use hybrid_array::typenum::Unsigned;
        use proptest::prelude::*;

        let nonce_len = E::NonceSize::USIZE;
        let ciphertext_strategy = proptest::collection::vec(any::<u8>(), 0..256);
        let nonce_strategy = proptest::collection::vec(any::<u8>(), nonce_len);

        (ciphertext_strategy, nonce_strategy)
            .prop_map(|(ciphertext, nonce_bytes)| {
                let nonce = Array::try_from(&nonce_bytes[..]).expect("correct length");
                Self::new(ciphertext, nonce)
            })
            .boxed()
    }
}
#[cfg(feature = "rkyv")]
pub mod archive {
    //! rkyv support for [`Encrypted`].
    //!
    //! `Encrypted<T, E, C>` is archived through the non-generic [`EncryptedBytes`] helper.
    //! Its archived form ([`ArchivedEncryptedBytes`]) is therefore two byte vectors,
    //! whatever `T`, `E` and `C` are. As a consequence, the archive format does not fix the
    //! nonce length; it is validated when deserializing back into an `Encrypted`.

    use super::{Array, Encrypted, EncryptionPrimitive};
    use alloc::vec::Vec;
    use hybrid_array::typenum::Unsigned;
    use rkyv::{Archive, Deserialize, Serialize, rancor::Fallible};

    /// Non-generic mirror of [`Encrypted`]; its derived rkyv impls define the archived layout.
    #[derive(Debug, Archive, Serialize, Deserialize)]
    #[rkyv(derive(Debug))]
    pub struct EncryptedBytes {
        ciphertext: Vec<u8>,
        nonce: Vec<u8>,
    }

    impl EncryptedBytes {
        /// Copies the parts of `value` into a helper. `resolve` and `serialize` both call
        /// this, so they always describe the same archived data.
        fn from_encrypted<T, E: EncryptionPrimitive, C>(value: &Encrypted<T, E, C>) -> Self {
            Self {
                ciphertext: value.ciphertext.clone(),
                nonce: value.nonce.as_slice().to_vec(),
            }
        }
    }

    impl ArchivedEncryptedBytes {
        /// Archived ciphertext bytes.
        #[must_use]
        pub fn ciphertext(&self) -> &[u8] {
            &self.ciphertext
        }

        /// Archived nonce bytes. The archive format does not constrain their length.
        #[must_use]
        pub fn nonce(&self) -> &[u8] {
            &self.nonce
        }
    }

    /// Error reported when an archived nonce does not have the length required by
    /// `E::NonceSize`.
    #[derive(Debug)]
    struct InvalidNonceLength {
        expected: usize,
        actual: usize,
    }

    impl core::fmt::Display for InvalidNonceLength {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            write!(
                f,
                "invalid nonce length: expected {} bytes, found {}",
                self.expected, self.actual
            )
        }
    }

    impl core::error::Error for InvalidNonceLength {}

    impl<T, E: EncryptionPrimitive, C> Archive for Encrypted<T, E, C> {
        type Archived = ArchivedEncryptedBytes;
        type Resolver = <EncryptedBytes as Archive>::Resolver;

        fn resolve(&self, resolver: Self::Resolver, out: rkyv::Place<Self::Archived>) {
            EncryptedBytes::from_encrypted(self).resolve(resolver, out);
        }
    }

    impl<T, E: EncryptionPrimitive, C, S> Serialize<S> for Encrypted<T, E, C>
    where
        S: Fallible + rkyv::ser::Allocator + rkyv::ser::Writer + ?Sized,
    {
        fn serialize(&self, serializer: &mut S) -> Result<Self::Resolver, S::Error> {
            EncryptedBytes::from_encrypted(self).serialize(serializer)
        }
    }

    impl<T, E: EncryptionPrimitive, C, D> Deserialize<Encrypted<T, E, C>, D>
        for ArchivedEncryptedBytes
    where
        D: Fallible + ?Sized,
        D::Error: rkyv::rancor::Source,
    {
        /// Rebuilds an `Encrypted` from its archived bytes.
        ///
        /// # Errors
        ///
        /// Returns `D::Error` if deserializing the byte vectors fails, or if the archived
        /// nonce is not exactly `E::NonceSize` bytes long.
        fn deserialize(&self, deserializer: &mut D) -> Result<Encrypted<T, E, C>, D::Error> {
            let helper: EncryptedBytes =
                <ArchivedEncryptedBytes as Deserialize<EncryptedBytes, D>>::deserialize(
                    self,
                    deserializer,
                )?;
            // Archives may come from untrusted bytes, and the nonce is an arbitrary-length
            // vector there, so a length mismatch becomes an error rather than a panic.
            let actual = helper.nonce.len();
            let nonce: Array<u8, E::NonceSize> = Array::try_from(helper.nonce.as_slice())
                .map_err(|_| {
                    <D::Error as rkyv::rancor::Source>::new(InvalidNonceLength {
                        expected: <E::NonceSize as Unsigned>::USIZE,
                        actual,
                    })
                })?;
            Ok(Encrypted::new(helper.ciphertext, nonce))
        }
    }
}