use core::convert::Infallible;
use core::fmt::Debug;
#[cfg(any(feature = "std", test))]
use std::error::Error;
use displaydoc::Display;
/// Lower-level error type; `ProtocolError::LibraryError` wraps it.
// `T` is a caller-chosen custom error type carried only by `Custom`. With the
// default `T = Infallible`, the `Custom` variant cannot be constructed.
//
// `Display` is derived by `displaydoc`, which uses each variant's `///` line as
// that variant's `Display` message (field names written in braces are
// interpolated), so those doc lines are kept to a single line each.
#[derive(Clone, Copy, Display, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum InternalError<T = Infallible> {
    /// Custom error
    Custom(T),
    /// Invalid byte sequence
    InvalidByteSequence,
    /// Invalid length for {name}: expected {len}, got {actual_len}
    // Produced by `utils::check_slice_size` (exact length required) and
    // `utils::check_slice_size_atleast` (minimum length required).
    SizeError {
        /// Name of the argument whose length was checked
        name: &'static str,
        /// Required length: exact or minimum, depending on which check failed
        len: usize,
        /// Length that was actually supplied
        actual_len: usize,
    },
    /// Point error
    PointError,
    /// Hash-to-scalar error
    HashToScalar,
    /// HKDF error
    HkdfError,
    /// HMAC error
    HmacError,
    /// KSF error
    // NOTE(review): "KSF" presumably stands for key stretching function — confirm.
    KsfError,
    /// Seal/open HMAC error
    SealOpenHmacError,
    /// Incompatible envelope mode
    IncompatibleEnvelopeModeError,
    /// OPRF error
    // Wraps a `voprf::Error`; the `From<voprf::Error>` impls construct it.
    OprfError(voprf::Error),
    /// OPRF internal error
    // Wraps a `voprf::InternalError`.
    OprfInternalError(voprf::InternalError),
}
/// `Debug` output in the same shape a derive would produce: unit variants print
/// as `Name`, tuple variants as `Name(payload)`, and `SizeError` as a struct.
impl<T: Debug> Debug for InternalError<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            // Must be labelled "Custom": reusing another variant's name would make a
            // caller's custom error indistinguishable from that variant in logs.
            Self::Custom(custom) => f.debug_tuple("Custom").field(custom).finish(),
            Self::InvalidByteSequence => f.debug_tuple("InvalidByteSequence").finish(),
            Self::SizeError {
                name,
                len,
                actual_len,
            } => f
                .debug_struct("SizeError")
                .field("name", name)
                .field("len", len)
                .field("actual_len", actual_len)
                .finish(),
            Self::PointError => f.debug_tuple("PointError").finish(),
            Self::HashToScalar => f.debug_tuple("HashToScalar").finish(),
            Self::HkdfError => f.debug_tuple("HkdfError").finish(),
            Self::HmacError => f.debug_tuple("HmacError").finish(),
            Self::KsfError => f.debug_tuple("KsfError").finish(),
            Self::SealOpenHmacError => f.debug_tuple("SealOpenHmacError").finish(),
            Self::IncompatibleEnvelopeModeError => {
                f.debug_tuple("IncompatibleEnvelopeModeError").finish()
            }
            Self::OprfError(error) => f.debug_tuple("OprfError").field(error).finish(),
            Self::OprfInternalError(error) => {
                f.debug_tuple("OprfInternalError").field(error).finish()
            }
        }
    }
}
// Every `Error` method has a default body, so an empty impl suffices; it relies
// on the existing `Debug` and `Display` impls for `InternalError<T>`.
#[cfg(any(feature = "std", test))]
impl<T> Error for InternalError<T> where T: Error {}
impl InternalError {
    /// Re-types an `InternalError` (custom type `Infallible`) as an
    /// `InternalError<T>` for any custom error type `T`.
    ///
    /// Every variant is carried over with its payload unchanged. No error with
    /// the default type parameter can hold a `Custom` value, so the conversion
    /// is total and never panics.
    pub fn into_custom<T>(self) -> InternalError<T> {
        match self {
            // The payload is `Infallible`, which has no values: the empty match
            // lets the compiler prove this arm is dead instead of panicking at runtime.
            Self::Custom(never) => match never {},
            Self::InvalidByteSequence => InternalError::InvalidByteSequence,
            Self::SizeError {
                name,
                len,
                actual_len,
            } => InternalError::SizeError {
                name,
                len,
                actual_len,
            },
            Self::PointError => InternalError::PointError,
            Self::HashToScalar => InternalError::HashToScalar,
            Self::HkdfError => InternalError::HkdfError,
            Self::HmacError => InternalError::HmacError,
            Self::KsfError => InternalError::KsfError,
            Self::SealOpenHmacError => InternalError::SealOpenHmacError,
            Self::IncompatibleEnvelopeModeError => InternalError::IncompatibleEnvelopeModeError,
            Self::OprfError(error) => InternalError::OprfError(error),
            Self::OprfInternalError(error) => InternalError::OprfInternalError(error),
        }
    }
}
impl From<voprf::Error> for InternalError {
fn from(voprf_error: voprf::Error) -> Self {
Self::OprfError(voprf_error)
}
}
impl From<voprf::Error> for ProtocolError {
fn from(voprf_error: voprf::Error) -> Self {
Self::LibraryError(InternalError::OprfError(voprf_error))
}
}
impl From<voprf::InternalError> for ProtocolError {
fn from(voprf_error: voprf::InternalError) -> Self {
Self::LibraryError(InternalError::OprfInternalError(voprf_error))
}
}
/// Protocol-level error type; library failures are wrapped in `LibraryError`.
// `T` is a caller-chosen custom error type that is only threaded through to
// `InternalError<T>` inside `LibraryError`; it defaults to `Infallible`.
//
// `Display` is derived by `displaydoc`, which uses each variant's `///` line as
// that variant's `Display` message, so those doc lines are kept to one line each.
#[derive(Clone, Copy, Display, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum ProtocolError<T = Infallible> {
    /// Internal library error
    LibraryError(InternalError<T>),
    /// Invalid login
    // NOTE(review): presumably signals failed credential validation — confirm at call sites.
    InvalidLoginError,
    /// Serialization error
    SerializationError,
    /// Reflected value error
    // NOTE(review): presumably raised when a peer echoes back a value it was sent — confirm.
    ReflectedValueError,
    /// Identity group element error
    // NOTE(review): presumably raised when an identity group element is encountered — confirm.
    IdentityGroupElementError,
}
/// `Debug` output: `LibraryError(<inner>)` for the wrapping variant and the bare
/// variant name for every field-less variant.
impl<T: Debug> Debug for ProtocolError<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let name = match self {
            Self::LibraryError(inner) => {
                return f.debug_tuple("LibraryError").field(inner).finish();
            }
            Self::InvalidLoginError => "InvalidLoginError",
            Self::SerializationError => "SerializationError",
            Self::ReflectedValueError => "ReflectedValueError",
            Self::IdentityGroupElementError => "IdentityGroupElementError",
        };
        // A tuple with no fields renders as just its name, so writing the name
        // directly yields exactly the same output.
        f.write_str(name)
    }
}
// Every `Error` method has a default body, so an empty impl suffices; it relies
// on the existing `Debug` and `Display` impls for `ProtocolError<T>`.
#[cfg(any(feature = "std", test))]
impl<T> Error for ProtocolError<T> where T: Error {}
impl<T> From<InternalError<T>> for ProtocolError<T> {
fn from(e: InternalError<T>) -> ProtocolError<T> {
Self::LibraryError(e)
}
}
/// Allows `?` on a `Result<_, Infallible>` inside functions that return
/// `ProtocolError<T>`.
impl<T> From<::core::convert::Infallible> for ProtocolError<T> {
    fn from(never: ::core::convert::Infallible) -> Self {
        // `Infallible` has no values, so this conversion can never actually run.
        // The empty match lets the compiler prove that, instead of keeping a
        // runtime `unreachable!()` panic path.
        match never {}
    }
}
impl ProtocolError {
    /// Re-types a `ProtocolError` (custom type `Infallible`) as a
    /// `ProtocolError<T>` for any custom error type `T`, keeping the variant
    /// and its payload.
    pub fn into_custom<T>(self) -> ProtocolError<T> {
        match self {
            Self::InvalidLoginError => ProtocolError::InvalidLoginError,
            Self::SerializationError => ProtocolError::SerializationError,
            Self::ReflectedValueError => ProtocolError::ReflectedValueError,
            Self::IdentityGroupElementError => ProtocolError::IdentityGroupElementError,
            // Only the wrapped library error mentions the custom type parameter,
            // so it is the one arm whose payload needs converting.
            Self::LibraryError(inner) => ProtocolError::LibraryError(inner.into_custom()),
        }
    }
}
pub(crate) mod utils {
    use super::*;

    /// Builds the `SizeError` reported by both length checks in this module.
    fn size_mismatch<T>(arg_name: &'static str, wanted: usize, got: usize) -> InternalError<T> {
        InternalError::SizeError {
            name: arg_name,
            len: wanted,
            actual_len: got,
        }
    }

    /// Returns `slice` unchanged when its length is exactly `expected_len`.
    ///
    /// # Errors
    ///
    /// Returns [`InternalError::SizeError`] carrying `arg_name`, `expected_len`
    /// and the slice's actual length when the lengths differ, whether the slice
    /// is too short or too long.
    pub fn check_slice_size<'a, T>(
        slice: &'a [u8],
        expected_len: usize,
        arg_name: &'static str,
    ) -> Result<&'a [u8], InternalError<T>> {
        match slice.len() {
            n if n == expected_len => Ok(slice),
            n => Err(size_mismatch(arg_name, expected_len, n)),
        }
    }

    /// Returns `slice` unchanged when it holds at least `expected_len` bytes.
    ///
    /// # Errors
    ///
    /// Returns [`InternalError::SizeError`] carrying `arg_name`, `expected_len`
    /// (the required minimum) and the slice's actual length when the slice is
    /// shorter than `expected_len`.
    pub fn check_slice_size_atleast<'a>(
        slice: &'a [u8],
        expected_len: usize,
        arg_name: &'static str,
    ) -> Result<&'a [u8], InternalError> {
        let actual_len = slice.len();
        if actual_len >= expected_len {
            Ok(slice)
        } else {
            Err(size_mismatch(arg_name, expected_len, actual_len))
        }
    }
}