use alloc::vec::Vec;
use core::{
fmt::Debug,
ops::{Deref, DerefMut},
};
use subtle::ConstantTimeEq;
#[cfg(feature = "zeroize")]
use zeroize::Zeroize;
/// A fixed-length collection of `N` values of type `T`, backed by a heap allocation.
///
/// The length lives in the const parameter `N` and is exposed as [`SafeArray::LEN`]. The
/// elements themselves are kept in a private `Vec<T>`, so code outside this module cannot
/// construct or resize one directly.
///
/// Note: the derived `Debug` impl prints the element values verbatim.
#[derive(Debug, Clone)]
pub struct SafeArray<T, const N: usize>(Vec<T>);

impl<T, const N: usize> SafeArray<T, N> {
    /// The number of elements this array type is declared to hold (the const parameter `N`).
    pub const LEN: usize = N;
}
impl<T, const N: usize> AsRef<[T]> for SafeArray<T, N> {
    /// Borrows the stored elements as a shared slice.
    fn as_ref(&self) -> &[T] {
        self.0.as_slice()
    }
}
impl<T, const N: usize> AsMut<[T]> for SafeArray<T, N> {
    /// Borrows the stored elements as a mutable slice.
    ///
    /// A slice view allows in-place writes but cannot change the number of elements.
    fn as_mut(&mut self) -> &mut [T] {
        self.0.as_mut_slice()
    }
}
impl<T, const N: usize> Deref for SafeArray<T, N> {
    type Target = [T];

    /// Dereferences to the backing elements, so slice methods (`len`, `iter`, indexing, ...)
    /// are available directly on a `SafeArray`.
    fn deref(&self) -> &[T] {
        self.0.as_slice()
    }
}
impl<T, const N: usize> DerefMut for SafeArray<T, N> {
    /// Mutably dereferences to the backing elements.
    ///
    /// Only a `&mut [T]` is handed out, so callers can overwrite elements but never push,
    /// pop or otherwise resize the underlying vector.
    fn deref_mut(&mut self) -> &mut [T] {
        self.0.as_mut_slice()
    }
}
#[cfg(feature = "zeroize")]
impl<T, const N: usize> Zeroize for SafeArray<T, N>
where T: Zeroize
{
    /// Zeroizes every element in place, keeping the array's length at `N`.
    ///
    /// This intentionally does NOT delegate to `Vec::zeroize`. That impl zeroes the
    /// elements and then clears the vector, which would leave a `SafeArray<T, N>` with
    /// `len() == 0`. The fixed-length invariant (`len() == Self::LEN`) would break, and any
    /// later indexing would panic. Zeroizing element-by-element matches the behaviour of
    /// zeroizing a plain `[T; N]`.
    fn zeroize(&mut self) {
        for element in self.0.iter_mut() {
            element.zeroize();
        }
    }
}
impl<T: Clone + Default, const N: usize> Default for SafeArray<T, N> {
    /// Builds an array whose `N` slots all hold a copy of `T::default()`.
    ///
    /// One default value is created and cloned into the remaining slots.
    fn default() -> Self {
        let mut elements = Vec::with_capacity(N);
        elements.resize(N, T::default());
        Self(elements)
    }
}
impl<T: ConstantTimeEq, const N: usize> ConstantTimeEq for SafeArray<T, N> {
    /// Compares two arrays by delegating to `subtle`'s `ConstantTimeEq` implementation for
    /// slices of `T`, returning a `Choice` rather than a `bool`.
    fn ct_eq(&self, other: &Self) -> subtle::Choice {
        let (lhs, rhs) = (self.0.as_slice(), other.0.as_slice());
        lhs.ct_eq(rhs)
    }
}
/// Marker impl: equality on `SafeArray` is decided entirely by the constant-time comparison.
impl<T: ConstantTimeEq, const N: usize> Eq for SafeArray<T, N> {}
impl<T: ConstantTimeEq, const N: usize> PartialEq for SafeArray<T, N> {
    /// Routes `==` through `ConstantTimeEq::ct_eq`, so the element comparison itself is
    /// done by the constant-time implementation. Only the final `Choice` is turned into a
    /// `bool`.
    fn eq(&self, other: &Self) -> bool {
        let verdict = self.ct_eq(other);
        1u8 == verdict.unwrap_u8()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Wraps `SafeArray` in a `hidden_type!` newtype and checks equality, the all-zero default,
    // and that filling from the OS RNG changes the contents.
    #[test]
    fn reference() {
        use rand::{rngs::OsRng, RngCore};
        // NOTE(review): `Zeroize` and `Hidden` are not referenced directly in this function;
        // presumably the `hidden_type!` expansion names them unqualified, so they must stay
        // in scope at the macro call site. Confirm before removing either import.
        use zeroize::Zeroize;
        use crate::{hidden::Hidden, hidden_type};
        hidden_type!(CipherKey, SafeArray<u8, 32>);
        let key_a = CipherKey::from(SafeArray::<u8, 32>::default());
        let key_b = CipherKey::from(SafeArray::<u8, 32>::default());
        // Two default-constructed keys compare equal; presumably `reveal()` yields a reference
        // to the inner `SafeArray`, so this goes through its `PartialEq` impl.
        assert_eq!(key_a.reveal(), key_b.reveal());
        // `Default` fills every byte with `u8::default()`, i.e. zero.
        assert_eq!(key_a.reveal().as_ref(), &[0u8; 32]);
        let mut key_c = CipherKey::from(SafeArray::<u8, 32>::default());
        let mut rng = OsRng;
        // `fill_bytes` takes `&mut [u8]`; the mutable reveal is presumably coerced to a slice
        // via `SafeArray`'s `DerefMut` impl.
        rng.fill_bytes(key_c.reveal_mut());
        // Probabilistic check: it fails only if the RNG produced 32 zero bytes.
        assert_ne!(key_c.reveal().as_ref(), &[0u8; 32]);
    }

    // The slice length (reached through `Deref`) and the `LEN` associated constant both equal `N`.
    #[test]
    fn len() {
        const N: usize = 64;
        assert_eq!(SafeArray::<u8, N>::default().len(), N);
        assert_eq!(SafeArray::<u8, 64>::LEN, N);
    }
}