use crate::{exportable::SafeSerialize, private::ControlledPrivate, Controlled, Protected};
use core::num::NonZeroU16;
use serde::{Serialize, Serializer};
use subtle::ConstantTimeEq as SubtleCtEq;
use zeroize::Zeroize;
/// Wrapper that gives a controlled value equality semantics.
///
/// Its `PartialEq` and `ConstantTimeEq` implementations compare the inner values through
/// this module's `ConstantTimeEq` trait rather than through the inner type's own `PartialEq`.
/// The derived `Debug` simply delegates to the wrapped value's `Debug`.
#[derive(Zeroize, Debug)]
pub struct Equatable<T>(pub(crate) T);
impl<T> Equatable<T> {
    /// Builds an `Equatable` from the raw inner value.
    ///
    /// Available whenever `Equatable<T>` itself implements `Controlled`. Construction is
    /// delegated to that implementation's `init_from_inner`.
    pub fn new(x: <Equatable<T> as Controlled>::Inner) -> Self
    where
        Self: Controlled,
    {
        <Self as Controlled>::init_from_inner(x)
    }
}
impl<T> From<T> for Equatable<T>
where
T: ControlledPrivate,
{
fn from(x: T) -> Self {
Self(x)
}
}
impl<T: Controlled> Equatable<T>
where
T::Inner: ConstantTimeEq,
{
pub fn constant_time_eq(&self, other: &Self) -> bool {
self.risky_ref().constant_time_eq(other.risky_ref())
}
}
// `Equatable<T>` is `ControlledPrivate` exactly when the value it wraps is.
impl<T> ControlledPrivate for Equatable<T>
where
    T: ControlledPrivate,
{
}
impl<T> Controlled for Equatable<T>
where
T: Controlled,
{
type Inner = T::Inner;
fn init_from_inner(x: Self::Inner) -> Self {
Self(T::init_from_inner(x))
}
fn risky_ref(&self) -> &Self::Inner {
self.0.risky_ref()
}
fn inner_mut(&mut self) -> &mut Self::Inner {
self.0.inner_mut()
}
fn risky_unwrap(self) -> Self::Inner {
self.0.risky_unwrap()
}
}
/// Extending an `Equatable` extends the wrapped value.
///
/// The whole iterator is handed to `T`'s own `Extend` impl in a single call.
impl<T, A> Extend<A> for Equatable<T>
where
    T: Extend<A>,
{
    fn extend<I: IntoIterator<Item = A>>(&mut self, iter: I) {
        <T as Extend<A>>::extend(&mut self.0, iter)
    }
}
impl<T> From<T> for Equatable<Protected<T>>
where
T: Into<Protected<T>> + Zeroize,
{
fn from(x: T) -> Self {
Self(Protected::init_from_inner(x))
}
}
/// `==` on an `Equatable` is defined in terms of `ConstantTimeEq`.
///
/// The other side may be any `Controlled` value whose inner type the left inner type can be
/// compared against.
impl<T, O> PartialEq<O> for Equatable<T>
where
    T: Controlled,
    O: Controlled,
    <T as Controlled>::Inner: ConstantTimeEq<O::Inner>,
{
    fn eq(&self, other: &O) -> bool {
        let lhs = self.risky_ref();
        let rhs = other.risky_ref();
        <T::Inner as ConstantTimeEq<O::Inner>>::constant_time_eq(lhs, rhs)
    }
}
impl<T, O> ConstantTimeEq<O> for Equatable<T>
where
T: Controlled,
O: Controlled,
<T as Controlled>::Inner: ConstantTimeEq<O::Inner>,
{
fn constant_time_eq(&self, other: &O) -> bool {
self.risky_ref().constant_time_eq(other.risky_ref())
}
}
/// Equality comparison intended not to branch on the contents being compared.
///
/// The trait is sealed: its supertrait lives in a private module, so only the types listed
/// there can implement it. Lengths are not treated as secret; for example, the `[u8]`
/// implementation answers `false` immediately when the lengths differ.
pub trait ConstantTimeEq<Rhs = Self>: private::SupportsConstantTimeEq
where
    Rhs: ?Sized,
{
    /// Returns `true` when `self` and `other` are equal.
    fn constant_time_eq(&self, other: &Rhs) -> bool;
}
/// Element-wise comparison of fixed-size arrays.
///
/// Both operands are `[T; N]`, so the lengths always match and no length check is needed.
impl<const N: usize, T> ConstantTimeEq<Self> for [T; N]
where
    T: ConstantTimeEq,
{
    fn constant_time_eq(&self, other: &Self) -> bool {
        // Non-short-circuiting `&` so every pair is compared, even after a mismatch.
        self.iter()
            .zip(other)
            .fold(true, |all_equal, (lhs, rhs)| all_equal & lhs.constant_time_eq(rhs))
    }
}
macro_rules! impl_constany_time_eq {
($($type:ty),+) => {
$(
impl ConstantTimeEq for $type {
fn constant_time_eq(&self, other: &Self) -> bool {
self.ct_eq(other).into()
}
}
)+
};
}
impl_constany_time_eq!(u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128);
/// Compares the underlying `u16` values with the `u16` implementation.
///
/// The local copies are zeroized afterwards, presumably so secret copies do not outlive the
/// comparison.
impl ConstantTimeEq for NonZeroU16 {
    #[inline]
    fn constant_time_eq(&self, other: &Self) -> bool {
        let (mut lhs, mut rhs) = (self.get(), other.get());
        let equal = lhs.constant_time_eq(&rhs);
        lhs.zeroize();
        rhs.zeroize();
        equal
    }
}
impl ConstantTimeEq for [u8] {
fn constant_time_eq(&self, other: &Self) -> bool {
if self.len() != other.len() {
return false;
}
let mut x = true;
for (ai, bi) in self.iter().zip(other.iter()) {
x &= ai.constant_time_eq(bi);
}
x
}
}
impl ConstantTimeEq for str {
#[inline]
fn constant_time_eq(&self, other: &Self) -> bool {
self.as_bytes().constant_time_eq(other.as_bytes())
}
}
impl ConstantTimeEq for String {
fn constant_time_eq(&self, other: &Self) -> bool {
self.as_bytes().constant_time_eq(other.as_bytes())
}
}
impl<T> Serialize for Equatable<T>
where
T: Controlled,
T::Inner: SafeSerialize,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
self.risky_ref().safe_serialize(serializer)
}
}
mod private {
    use super::Equatable;
    use core::num::NonZeroU16;

    /// Sealing supertrait for `ConstantTimeEq`.
    ///
    /// The trait is `pub` but sits in this private module, so code outside this module cannot
    /// name it to add implementations.
    pub trait SupportsConstantTimeEq {}

    impl<T> SupportsConstantTimeEq for Equatable<T> {}
    impl<const N: usize, T> SupportsConstantTimeEq for [T; N] {}
    impl SupportsConstantTimeEq for [u8] {}
    impl SupportsConstantTimeEq for str {}
    impl SupportsConstantTimeEq for String {}

    // Emits an empty marker impl for each listed type.
    macro_rules! seal {
        ($($ty:ty),+ $(,)?) => {
            $(impl SupportsConstantTimeEq for $ty {})+
        };
    }

    seal!(u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize, NonZeroU16);
}
#[cfg(test)]
mod tests {
    use super::ConstantTimeEq;
    use crate::{Equatable, Protected};
    use core::num::NonZeroU16;

    #[test]
    fn test_opaque_debug() {
        let x: Equatable<Protected<[u8; 32]>> = Equatable::new([0u8; 32]);
        assert_eq!(
            format!("{x:?}"),
            "Equatable(vitaminc_protected::protected::Protected<[u8; 32]>(\"***\"))"
        );
    }

    #[test]
    fn test_safe_eq_arr() {
        let x: Equatable<Protected<[u8; 16]>> = Equatable::from([0u8; 16]);
        let y: Equatable<Protected<[u8; 16]>> = Equatable::new([0u8; 16]);
        assert_eq!(x, y);
        assert!(x.constant_time_eq(&y));
    }

    #[test]
    fn test_equality_u8() {
        let x: Equatable<Protected<u8>> = Equatable::new(27);
        let y: Equatable<Protected<u8>> = Equatable::new(27);
        assert_eq!(x, y);
        assert!(x.constant_time_eq(&y));
    }

    #[test]
    fn test_inequality_u8() {
        let x: Equatable<Protected<u8>> = Equatable::new(27);
        let y: Equatable<Protected<u8>> = Equatable::new(0);
        assert_ne!(x, y);
        assert!(!x.constant_time_eq(&y));
    }

    // Direct coverage of the trait impls that the `Equatable` tests above do not reach.

    #[test]
    fn test_byte_slice_equality_and_length_mismatch() {
        let a: &[u8] = b"secret";
        assert!(a.constant_time_eq(&b"secret"[..]));
        assert!(!a.constant_time_eq(&b"secreT"[..]));
        // Different lengths must compare unequal, whether longer, shorter or empty.
        assert!(!a.constant_time_eq(&b"secret!"[..]));
        assert!(!a.constant_time_eq(&b"secre"[..]));
        assert!(!a.constant_time_eq(&b""[..]));
    }

    #[test]
    fn test_str_and_string_equality() {
        assert!("hunter2".constant_time_eq("hunter2"));
        assert!(!"hunter2".constant_time_eq("hunter3"));
        let s = String::from("hunter2");
        assert!(s.constant_time_eq(&String::from("hunter2")));
        assert!(!s.constant_time_eq(&String::from("hunter")));
    }

    #[test]
    fn test_array_equality() {
        assert!([1u8, 2, 3].constant_time_eq(&[1u8, 2, 3]));
        assert!(![1u8, 2, 3].constant_time_eq(&[1u8, 2, 4]));
    }

    #[test]
    fn test_non_zero_u16_equality() {
        let a = NonZeroU16::new(42).unwrap();
        let b = NonZeroU16::new(42).unwrap();
        let c = NonZeroU16::new(7).unwrap();
        assert!(a.constant_time_eq(&b));
        assert!(!a.constant_time_eq(&c));
    }
}