use core::{
array::TryFromSliceError,
fmt::{self, Debug, Formatter},
hash,
marker::{PhantomData, PhantomPinned},
ops::Deref,
};
use crate::generic_array::{ArrayLength, GenericArray};
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
use subtle::{Choice, ConstantTimeEq};
use zeroize::Zeroize;
use super::HexRepr;
use crate::{
error::Error,
kdf::{FromKeyDerivation, KeyDerivation},
random::KeyMaterial,
};
/// A fixed-length secret key holding exactly `L` bytes in a `GenericArray`.
///
/// The key bytes are wiped by the `Drop` impl in this module, which calls `zeroize`.
#[derive(Clone)]
#[repr(transparent)]
pub struct ArrayKey<L: ArrayLength<u8>>(
    /// The raw key bytes.
    GenericArray<u8, L>,
    /// Zero-sized marker making the type `!Unpin` (presumably to discourage relocating
    /// secret material — NOTE(review): confirm intent).
    PhantomPinned,
);
impl<L> ArrayKey<L>
where
    L: ArrayLength<u8>,
{
    /// Length of the key in bytes, taken from the type-level length `L`.
    pub const SIZE: usize = L::USIZE;

    /// Create a key whose bytes are read from `rng` via `KeyMaterial::read_okm`.
    #[inline]
    pub fn generate(mut rng: impl KeyMaterial) -> Self {
        Self::new_with(|bytes| rng.read_okm(bytes))
    }

    /// Create a zero-initialized key and hand its bytes to `f` to fill in.
    pub fn new_with(f: impl FnOnce(&mut [u8])) -> Self {
        let mut key = Self::default();
        f(key.0.as_mut_slice());
        key
    }

    /// Create a zero-initialized key and hand its bytes to the fallible `f`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `f` returns. The partially-filled key is dropped in that
    /// case, which zeroizes it.
    pub fn try_new_with<E>(f: impl FnOnce(&mut [u8]) -> Result<(), E>) -> Result<Self, E> {
        let mut key = Self::default();
        f(key.0.as_mut_slice()).map(|()| key)
    }

    /// Run `f` on a zeroed scratch array of the key's length and return its result.
    ///
    /// The scratch array lives inside a temporary `ArrayKey`, so it is zeroized when
    /// that temporary is dropped at the end of this call.
    pub fn temp<R>(f: impl FnOnce(&mut GenericArray<u8, L>) -> R) -> R {
        let mut scratch = Self::default();
        f(&mut scratch.0)
    }

    /// Consume the key and return a copy of its bytes.
    ///
    /// The bytes are copied rather than moved because `ArrayKey` implements `Drop`. The
    /// key's own storage is zeroized when `self` is dropped on return.
    #[inline]
    pub fn extract(self) -> GenericArray<u8, L> {
        Clone::clone(&self.0)
    }

    /// Create a key by copying the bytes of `data`.
    ///
    /// # Panics
    ///
    /// Panics (inside `GenericArray::from_slice`) if `data.len()` differs from
    /// [`Self::SIZE`].
    #[inline]
    pub fn from_slice(data: &[u8]) -> Self {
        let borrowed: &GenericArray<u8, L> = GenericArray::from_slice(data);
        Self::from(borrowed)
    }

    /// Key length in bytes; identical to [`Self::SIZE`].
    #[inline]
    pub fn len() -> usize {
        L::USIZE
    }

    /// Generate a key from `crate::random::default_rng()`.
    ///
    /// Only available with the `getrandom` feature.
    #[cfg(feature = "getrandom")]
    #[inline]
    pub fn random() -> Self {
        Self::generate(crate::random::default_rng())
    }

    /// Wrap a borrow of the key bytes in `HexRepr` (presumably for hex display).
    pub fn as_hex(&self) -> HexRepr<&[u8]> {
        HexRepr(self.0.as_slice())
    }
}
impl<L> AsRef<GenericArray<u8, L>> for ArrayKey<L>
where
    L: ArrayLength<u8>,
{
    /// Borrow the underlying fixed-size byte array.
    #[inline(always)]
    fn as_ref(&self) -> &GenericArray<u8, L> {
        let Self(bytes, _) = self;
        bytes
    }
}
impl<L> Deref for ArrayKey<L>
where
    L: ArrayLength<u8>,
{
    type Target = [u8];

    /// Expose the key bytes as a plain byte slice.
    #[inline(always)]
    fn deref(&self) -> &Self::Target {
        self.0.as_slice()
    }
}
impl<L: ArrayLength<u8>> Default for ArrayKey<L> {
#[inline(always)]
fn default() -> Self {
Self(GenericArray::default(), PhantomPinned)
}
}
impl<L> From<&GenericArray<u8, L>> for ArrayKey<L>
where
    L: ArrayLength<u8>,
{
    /// Create a key by copying the referenced array.
    #[inline(always)]
    fn from(key: &GenericArray<u8, L>) -> Self {
        let copied: GenericArray<u8, L> = Clone::clone(key);
        Self(copied, PhantomPinned)
    }
}
impl<L: ArrayLength<u8>> From<GenericArray<u8, L>> for ArrayKey<L> {
#[inline(always)]
fn from(key: GenericArray<u8, L>) -> Self {
Self(key, PhantomPinned)
}
}
impl<'a, L: ArrayLength<u8>, const N: usize> TryFrom<&'a ArrayKey<L>> for &'a [u8; N] {
type Error = TryFromSliceError;
#[inline(always)]
fn try_from(key: &ArrayKey<L>) -> Result<&[u8; N], TryFromSliceError> {
key.0.as_slice().try_into()
}
}
impl<L> Debug for ArrayKey<L>
where
    L: ArrayLength<u8>,
{
    /// Format as a tuple named `ArrayKey`.
    ///
    /// The real bytes are printed only when compiled with `cfg(test)`. Otherwise a
    /// `"<secret>"` placeholder is printed so keys do not leak into logs.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_tuple("ArrayKey");
        if cfg!(test) {
            out.field(&self.0);
        } else {
            out.field(&"<secret>");
        }
        out.finish()
    }
}
impl<L> ConstantTimeEq for ArrayKey<L>
where
    L: ArrayLength<u8>,
{
    /// Compare the key bytes using `subtle`'s constant-time slice comparison.
    fn ct_eq(&self, other: &Self) -> Choice {
        let (lhs, rhs): (&[u8], &[u8]) = (&self.0, &other.0);
        <[u8] as ConstantTimeEq>::ct_eq(lhs, rhs)
    }
}
impl<L: ArrayLength<u8>> PartialEq for ArrayKey<L> {
#[inline]
fn eq(&self, other: &Self) -> bool {
self.ct_eq(other).into()
}
}
/// Byte-wise equality is reflexive, so `ArrayKey` is a full equivalence relation.
impl<L> Eq for ArrayKey<L> where L: ArrayLength<u8> {}
impl<L> hash::Hash for ArrayKey<L>
where
    L: ArrayLength<u8>,
{
    /// Delegate hashing to the underlying array, consistent with byte-wise `PartialEq`.
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        hash::Hash::hash(&self.0, state)
    }
}
impl<L: ArrayLength<u8>> Serialize for ArrayKey<L> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_bytes(self.as_ref())
}
}
impl<'de, L> Deserialize<'de> for ArrayKey<L>
where
    L: ArrayLength<u8>,
{
    /// Ask the deserializer for bytes and decode them with `KeyVisitor`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let visitor = KeyVisitor::<L> { _pd: PhantomData };
        deserializer.deserialize_bytes(visitor)
    }
}
impl<L> Zeroize for ArrayKey<L>
where
    L: ArrayLength<u8>,
{
    /// Overwrite the key bytes with zeros.
    fn zeroize(&mut self) {
        let Self(bytes, _) = self;
        bytes.zeroize();
    }
}
impl<L: ArrayLength<u8>> Drop for ArrayKey<L> {
fn drop(&mut self) {
self.zeroize();
}
}
/// Serde visitor that rebuilds an `ArrayKey` of length `L`.
struct KeyVisitor<L>
where
    L: ArrayLength<u8>,
{
    /// Carries the key-length type parameter without storing any data.
    _pd: PhantomData<L>,
}
impl<'de, L: ArrayLength<u8>> de::Visitor<'de> for KeyVisitor<L> {
type Value = ArrayKey<L>;
fn expecting(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
formatter.write_str("ArrayKey")
}
fn visit_bytes<E>(self, value: &[u8]) -> Result<Self::Value, E>
where
E: de::Error,
{
if value.len() != L::USIZE {
return Err(E::invalid_length(value.len(), &self));
}
Ok(ArrayKey::from_slice(value))
}
}
impl<L: ArrayLength<u8>> FromKeyDerivation for ArrayKey<L> {
fn from_key_derivation<D: KeyDerivation>(mut derive: D) -> Result<Self, Error>
where
Self: Sized,
{
Self::try_new_with(|buf| derive.derive_key_bytes(buf))
}
}