use super::Controlled;
use crate::private::ControlledPrivate;
use crate::OpaqueDebug;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// A newtype wrapper around a value of type `T`.
///
/// - The wrapped value sits in a `pub(crate)` field, so code outside this crate cannot
///   reach it by field access; it has to go through the crate's accessor APIs.
/// - `Debug` is derived via `OpaqueDebug`, which redacts the contents: the value is
///   rendered as `"***"` next to the full type name instead of its real `Debug` output.
/// - `Zeroize` is derived so the wrapper itself can be wiped in place. The derive
///   presumably forwards to `T: Zeroize` — confirm against the zeroize derive docs.
#[derive(Zeroize, OpaqueDebug)]
pub struct Protected<T>(pub(crate) T);
impl<T> Protected<T> {
pub const fn new(x: T) -> Self
where
T: Zeroize,
{
Self(x)
}
}
impl<T> Protected<Protected<T>> {
#[inline]
pub fn flatten(self) -> Protected<T> {
self.0
}
}
impl<T> Protected<Option<T>> {
#[inline]
pub fn transpose(self) -> Option<Protected<T>> {
self.0.map(Protected)
}
}
/// Declares `Protected<T>` as `ZeroizeOnDrop` whenever `T: Zeroize` (empty marker impl).
///
/// Per the zeroize crate docs, this marker signals that the type zeroizes itself on drop.
// NOTE(review): no `Drop` impl for `Protected` is visible in this file. One could not be
// added either: the conditional `impl Copy for Protected<T> where T: Copy` makes a `Drop`
// impl illegal (Rust forbids `Drop` on any type that may be `Copy`, error E0184). So this
// marker appears to promise wipe-on-drop without anything performing it. Confirm whether
// callers rely on the marker, and whether wiping must be done explicitly via `zeroize()`.
impl<T: Zeroize> ZeroizeOnDrop for Protected<T> {}
// Empty impl of the crate-internal `ControlledPrivate` trait for every `Protected<T>`,
// with no bound on `T`. Presumably this is the "sealed trait" supertrait that limits who
// can implement `Controlled` — TODO confirm against `crate::private`.
impl<T> ControlledPrivate for Protected<T> {}
impl<T> Controlled for Protected<T>
where
T: Zeroize,
{
fn risky_unwrap(self) -> Self::Inner {
self.0
}
type Inner = T;
fn init_from_inner(x: Self::Inner) -> Self {
Self(x)
}
fn risky_ref(&self) -> &T {
&self.0
}
fn inner_mut(&mut self) -> &mut Self::Inner {
&mut self.0
}
}
/// `Protected<T>` is `Copy` whenever the wrapped `T` is `Copy`.
// NOTE(review): each copy duplicates the wrapped value bitwise. Because a `Copy` type may
// never implement `Drop`, none of those copies can be wiped automatically when it goes
// out of scope.
impl<T> Copy for Protected<T> where T: Copy {}
impl<T> Clone for Protected<T>
where
T: Clone,
{
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl<T, A> Extend<A> for Protected<T>
where
T: Extend<A>,
{
fn extend<I>(&mut self, iter: I)
where
I: IntoIterator<Item = A>,
{
self.0.extend(iter);
}
}
#[cfg(feature = "arbitrary")]
impl<T> quickcheck::Arbitrary for Protected<T>
where
    T: quickcheck::Arbitrary + Zeroize,
{
    /// Draws a random `T` from `source` and wraps it with `Protected::new`.
    fn arbitrary(source: &mut quickcheck::Gen) -> Self {
        Protected::new(<T as quickcheck::Arbitrary>::arbitrary(source))
    }
}
/// Converts an array of individually protected values into one protected array.
///
/// Elements keep their order: `[Protected(a), Protected(b), ...]` becomes
/// `Protected([a, b, ...])`.
///
/// Only `T: Zeroize` is required. Each element is moved into the result rather than
/// copied into a pre-filled scratch array, so non-`Copy` types without a `Default`
/// (e.g. `Vec<u8>`) are accepted as well.
pub fn flatten_array<const N: usize, T>(array: [Protected<T>; N]) -> Protected<[T; N]>
where
    T: Zeroize,
{
    // `[T; N]::map` consumes the input array and moves each element out by value.
    // No default-initialised placeholder array is needed.
    Protected::new(array.map(|item| item.risky_unwrap()))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_array() {
        let expected = [0u8; 32];
        let wrapped = Protected::new(expected);
        assert_eq!(wrapped.0, expected);
    }

    #[test]
    fn test_opaque_debug() {
        let wrapped = Protected::new([0u8; 32]);
        let rendered = format!("{wrapped:?}");
        // The contents must never appear; only the type name and a placeholder.
        assert_eq!(
            rendered,
            "vitaminc_protected::protected::Protected<[u8; 32]>(\"***\")"
        );
    }

    #[test]
    fn test_flatten() {
        let nested = Protected::new(Protected::new([0u8; 32]));
        let flat = nested.flatten();
        assert_eq!(flat.risky_unwrap(), [0u8; 32]);
    }

    #[test]
    fn test_flatten_array() {
        let parts: [Protected<u8>; 3] = [1, 2, 3].map(Protected::new);
        let combined = flatten_array(parts);
        assert!(matches!(combined, Protected(_)));
        assert_eq!(combined.risky_unwrap(), [1, 2, 3]);
    }
}