use alloc::vec::Vec;
use core::{fmt::Debug, marker::PhantomData};
use crate::codec::{Decode, Encode};
/// A message-authentication-code primitive that [`Mac`] is generic over.
///
/// Implementors provide tag computation ([`MacPrimitive::mac`]) and tag
/// checking ([`MacPrimitive::verify`]) for a given key type. Implementations
/// should compare tags in constant time inside `verify`.
pub trait MacPrimitive {
    /// Key material accepted by [`MacPrimitive::mac`] and [`MacPrimitive::verify`].
    type Key;

    /// Tag produced by [`MacPrimitive::mac`] and checked by [`MacPrimitive::verify`].
    type Tag: Clone + Eq;

    /// Error reported by [`MacPrimitive::verify`] when a tag is rejected.
    type Error: Debug;

    /// Computes the tag for `message` under `key`.
    fn mac(key: &Self::Key, message: &[u8]) -> Self::Tag;

    /// Checks `tag` against `message` under `key`; `Ok(())` means the tag is accepted.
    fn verify(key: &Self::Key, message: &[u8], tag: &Self::Tag) -> Result<(), Self::Error>;
}
/// An encoded payload of type `T` (encoded with codec `C`) paired with a tag
/// from MAC primitive `M`.
///
/// Only the encoded bytes are stored; a `T` is recovered by
/// [`Mac::try_verify`], which checks the tag before decoding.
pub struct Mac<T, M: MacPrimitive, C> {
    /// Tag associated with `encoded_payload`.
    tag: M::Tag,
    /// Encoded payload bytes; these are the bytes the tag is checked against.
    encoded_payload: Vec<u8>,
    // `fn() -> (T, C)` records `T` and `C` without storing them. A function
    // pointer is always `Send + Sync` and covariant in its return type, so the
    // marker keeps `T` and `C` from affecting those properties of `Mac`.
    _marker: PhantomData<fn() -> (T, C)>,
}
// Written by hand rather than derived so that `T` and `C` (which are only
// phantom markers) are not required to implement `Clone`.
impl<T, M: MacPrimitive, C> Clone for Mac<T, M, C>
where
    M::Tag: Clone,
{
    /// Copies the tag and the encoded payload bytes into a new value.
    fn clone(&self) -> Self {
        let tag = self.tag.clone();
        let encoded_payload = self.encoded_payload.clone();
        Self {
            tag,
            encoded_payload,
            _marker: PhantomData,
        }
    }
}
// Written by hand rather than derived so that the phantom parameters `T` and
// `C` need not implement `PartialEq`/`Eq`; only stored data is compared.
impl<T, M: MacPrimitive, C> PartialEq for Mac<T, M, C>
where
    M::Tag: PartialEq,
{
    /// Equal when the tags match and the encoded payload bytes match.
    ///
    /// The tag comparison uses `M::Tag`'s own `PartialEq`, which is not
    /// necessarily constant-time; authenticate data with [`Mac::try_verify`],
    /// not with `==`.
    fn eq(&self, other: &Self) -> bool {
        // Tuple equality compares element by element and short-circuits,
        // so the tag is compared first, as before.
        (&self.tag, &self.encoded_payload) == (&other.tag, &other.encoded_payload)
    }
}

impl<T, M: MacPrimitive, C> Eq for Mac<T, M, C> where M::Tag: Eq {}
// Hand-written so that `T` and `C` need not be `Debug`; prints the stored
// tag and payload bytes under a struct named "Mac".
impl<T, M: MacPrimitive, C> core::fmt::Debug for Mac<T, M, C>
where
    M::Tag: core::fmt::Debug,
{
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("Mac");
        out.field("tag", &self.tag);
        out.field("encoded_payload", &self.encoded_payload);
        out.finish()
    }
}
impl<T, M: MacPrimitive, C> Mac<T, M, C> {
    /// Assembles a `Mac` from a tag and encoded payload exactly as given.
    ///
    /// No check is made that `tag` authenticates `encoded_payload`.
    #[must_use]
    pub(crate) fn new(tag: M::Tag, encoded_payload: Vec<u8>) -> Self {
        Self {
            encoded_payload,
            tag,
            _marker: PhantomData,
        }
    }

    /// Encodes `payload` with codec `C` and computes its tag under `key` with `M`.
    ///
    /// # Panics
    ///
    /// Panics with `"encoding failed"` if `C::encode` returns an error.
    #[must_use]
    // The signature has no error channel, so an encoding failure is surfaced
    // as a documented panic (see `# Panics`).
    #[allow(clippy::expect_used)]
    pub fn tag(key: &M::Key, payload: &T) -> Self
    where
        C: Encode<T>,
    {
        let bytes = C::encode(payload).expect("encoding failed");
        // The first argument only borrows `bytes` and finishes before the
        // second argument moves them.
        Self::new(M::mac(key, &bytes), bytes)
    }

    /// Checks the tag against the stored bytes and, only if it is accepted,
    /// decodes them into a `T`.
    ///
    /// # Errors
    ///
    /// - [`MacError::InvalidMac`] if `M::verify` rejects the tag (the payload is
    ///   not decoded in that case).
    /// - [`MacError::DecodeError`] if the tag is accepted but `C::decode` fails.
    pub fn try_verify(&self, key: &M::Key) -> Result<T, MacError>
    where
        C: Decode<T>,
    {
        // Authenticate first so the decoder never sees unauthenticated bytes.
        if M::verify(key, &self.encoded_payload, &self.tag).is_err() {
            return Err(MacError::InvalidMac);
        }
        C::decode(&self.encoded_payload).map_err(|_| MacError::DecodeError)
    }

    /// Borrows the stored tag.
    #[must_use]
    pub const fn mac_tag(&self) -> &M::Tag {
        &self.tag
    }

    /// Borrows the encoded payload bytes that the tag is checked against.
    #[must_use]
    pub fn encoded_payload(&self) -> &[u8] {
        self.encoded_payload.as_slice()
    }
}
/// Reasons [`Mac::try_verify`] can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MacError {
    /// The MAC primitive rejected the tag for the stored bytes.
    InvalidMac,
    /// The tag was accepted, but the codec could not decode the bytes.
    DecodeError,
}

impl core::fmt::Display for MacError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match *self {
            Self::InvalidMac => "invalid MAC",
            Self::DecodeError => "payload decode error",
        };
        f.write_str(message)
    }
}
pub trait MacUnchecked<T, M: MacPrimitive, C> {
fn from_unchecked_parts(tag: M::Tag, encoded_payload: Vec<u8>) -> Self;
}
impl<T, M: MacPrimitive, C> MacUnchecked<T, M, C> for Mac<T, M, C> {
fn from_unchecked_parts(tag: M::Tag, encoded_payload: Vec<u8>) -> Self {
Self::new(tag, encoded_payload)
}
}
#[cfg(feature = "serde")]
impl<T, M: MacPrimitive, C> serde::Serialize for Mac<T, M, C>
where
    M::Tag: serde::Serialize,
{
    /// Writes a two-field struct named `"Mac"`: `tag`, then `encoded_payload`.
    ///
    /// The payload goes through `Vec<u8>`'s own `Serialize` impl (a sequence of
    /// `u8`). Field order is part of the wire format for formats that encode
    /// structs positionally.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        use serde::ser::SerializeStruct as _;

        let mut st = serializer.serialize_struct("Mac", 2)?;
        st.serialize_field("tag", &self.tag)?;
        st.serialize_field("encoded_payload", &self.encoded_payload)?;
        st.end()
    }
}
#[cfg(feature = "serde")]
impl<'de, T, M: MacPrimitive, C> serde::Deserialize<'de> for Mac<T, M, C>
where
    M::Tag: serde::Deserialize<'de>,
{
    /// Reads a `Mac` written as a struct with fields `tag` and `encoded_payload`.
    ///
    /// Both representations serde may hand to a struct visitor are accepted:
    /// - a map (self-describing formats such as JSON), with keys matched by
    ///   name; unknown keys are skipped and duplicate keys are rejected;
    /// - a sequence (positional formats), read as `tag` then `encoded_payload`.
    ///
    /// The tag is **not** verified here; use [`Mac::try_verify`].
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        use serde::de::{Error, IgnoredAny, MapAccess, SeqAccess, Visitor};

        const FIELDS: &[&str] = &["tag", "encoded_payload"];

        // Field identifiers. Decoding keys through `deserialize_identifier`
        // accepts borrowed, owned (e.g. escaped or streamed) and byte-string keys,
        // unlike reading a borrowed `&str`.
        enum Field {
            Tag,
            EncodedPayload,
            Ignore,
        }

        impl<'de> serde::Deserialize<'de> for Field {
            fn deserialize<FD: serde::Deserializer<'de>>(
                deserializer: FD,
            ) -> Result<Self, FD::Error> {
                struct FieldVisitor;

                impl<'de> Visitor<'de> for FieldVisitor {
                    type Value = Field;

                    fn expecting(
                        &self,
                        formatter: &mut core::fmt::Formatter<'_>,
                    ) -> core::fmt::Result {
                        formatter.write_str("field identifier")
                    }

                    fn visit_u64<E: Error>(self, value: u64) -> Result<Field, E> {
                        Ok(match value {
                            0 => Field::Tag,
                            1 => Field::EncodedPayload,
                            _ => Field::Ignore,
                        })
                    }

                    fn visit_str<E: Error>(self, value: &str) -> Result<Field, E> {
                        Ok(match value {
                            "tag" => Field::Tag,
                            "encoded_payload" => Field::EncodedPayload,
                            _ => Field::Ignore,
                        })
                    }

                    fn visit_bytes<E: Error>(self, value: &[u8]) -> Result<Field, E> {
                        Ok(match value {
                            b"tag" => Field::Tag,
                            b"encoded_payload" => Field::EncodedPayload,
                            _ => Field::Ignore,
                        })
                    }
                }

                deserializer.deserialize_identifier(FieldVisitor)
            }
        }

        struct MacVisitor<T, M: MacPrimitive, C>(PhantomData<(T, M, C)>);

        impl<'de, T, M: MacPrimitive, C> Visitor<'de> for MacVisitor<T, M, C>
        where
            M::Tag: serde::Deserialize<'de>,
        {
            type Value = Mac<T, M, C>;

            fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                formatter.write_str("struct Mac")
            }

            // Positional formats (bincode, postcard, ...) drive
            // `deserialize_struct` through `visit_seq`.
            fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Mac<T, M, C>, A::Error> {
                let tag = match seq.next_element::<M::Tag>()? {
                    Some(tag) => tag,
                    None => return Err(Error::invalid_length(0, &"struct Mac with 2 elements")),
                };
                let encoded_payload = match seq.next_element::<Vec<u8>>()? {
                    Some(bytes) => bytes,
                    None => return Err(Error::invalid_length(1, &"struct Mac with 2 elements")),
                };
                Ok(Mac::new(tag, encoded_payload))
            }

            fn visit_map<V: MapAccess<'de>>(self, mut map: V) -> Result<Mac<T, M, C>, V::Error> {
                let mut tag: Option<M::Tag> = None;
                let mut encoded_payload: Option<Vec<u8>> = None;
                while let Some(key) = map.next_key::<Field>()? {
                    match key {
                        Field::Tag => {
                            // Reject ambiguous input instead of letting the last key win.
                            if tag.is_some() {
                                return Err(Error::duplicate_field("tag"));
                            }
                            tag = Some(map.next_value()?);
                        }
                        Field::EncodedPayload => {
                            if encoded_payload.is_some() {
                                return Err(Error::duplicate_field("encoded_payload"));
                            }
                            encoded_payload = Some(map.next_value()?);
                        }
                        Field::Ignore => {
                            let _: IgnoredAny = map.next_value()?;
                        }
                    }
                }
                let tag = match tag {
                    Some(tag) => tag,
                    None => return Err(Error::missing_field("tag")),
                };
                let encoded_payload = match encoded_payload {
                    Some(bytes) => bytes,
                    None => return Err(Error::missing_field("encoded_payload")),
                };
                Ok(Mac::new(tag, encoded_payload))
            }
        }

        deserializer.deserialize_struct("Mac", FIELDS, MacVisitor(PhantomData))
    }
}
#[cfg(feature = "arbitrary")]
impl<'a, T, M: MacPrimitive, C> arbitrary::Arbitrary<'a> for Mac<T, M, C>
where
    M::Tag: arbitrary::Arbitrary<'a>,
{
    /// Draws an arbitrary tag, then arbitrary payload bytes.
    ///
    /// The two are drawn independently and no MAC is computed, so the pair is
    /// generally not a valid tag/payload combination.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let tag = <M::Tag as arbitrary::Arbitrary<'a>>::arbitrary(u)?;
        let bytes = <Vec<u8> as arbitrary::Arbitrary<'a>>::arbitrary(u)?;
        Ok(Self::new(tag, bytes))
    }
}
#[cfg(feature = "bolero")]
impl<T: 'static, M: MacPrimitive + 'static, C: 'static> bolero_generator::TypeGenerator
    for Mac<T, M, C>
where
    M::Tag: bolero_generator::TypeGenerator,
{
    /// Generates a tag, then payload bytes, independently of each other.
    ///
    /// Returns `None` if the driver cannot produce either part.
    fn generate<D: bolero_generator::Driver>(driver: &mut D) -> Option<Self> {
        let tag = <M::Tag as bolero_generator::TypeGenerator>::generate(driver)?;
        let bytes = <Vec<u8> as bolero_generator::TypeGenerator>::generate(driver)?;
        Some(Self::new(tag, bytes))
    }
}
#[cfg(feature = "proptest")]
impl<T: 'static, M: MacPrimitive + 'static, C: 'static> proptest::arbitrary::Arbitrary
    for Mac<T, M, C>
where
    M::Tag: proptest::arbitrary::Arbitrary + 'static,
{
    type Parameters = ();
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Strategy pairing an arbitrary tag with independently drawn payload bytes.
    ///
    /// Payload lengths are drawn from `0..256`, i.e. at most 255 bytes.
    fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        let tags = any::<M::Tag>();
        let payloads = proptest::collection::vec(any::<u8>(), 0..256);
        (tags, payloads)
            .prop_map(|(tag, bytes)| Self::new(tag, bytes))
            .boxed()
    }
}
#[cfg(feature = "rkyv")]
pub mod archive {
    //! rkyv support for [`Mac`].
    //!
    //! Every impl here delegates to [`MacHelper`], a concrete struct that holds
    //! only the tag and the encoded payload. Its `Archive`/`Serialize`/
    //! `Deserialize` impls come from the rkyv derives. As a result, the archived
    //! form of `Mac<T, M, C>` is `ArchivedMac<M::Tag>` and depends only on the tag type.
    use super::{Mac, MacPrimitive};
    use alloc::vec::Vec;
    use rkyv::{Archive, Archived, Deserialize, Serialize, rancor::Fallible};

    // Archives a `Mac` with the same layout and resolver as `MacHelper<M::Tag>`.
    impl<T, M: MacPrimitive, C> Archive for Mac<T, M, C>
    where
        M::Tag: Archive,
    {
        type Archived = ArchivedMac<M::Tag>;
        type Resolver = MacResolver<M::Tag>;
        fn resolve(&self, resolver: Self::Resolver, out: rkyv::Place<Self::Archived>) {
            // NOTE(review): this clones the tag and the whole payload `Vec` on
            // every call, only to hand them to `MacHelper::resolve`; the copies are
            // dropped immediately afterwards. A borrowing helper would avoid the
            // allocation — worth profiling before changing.
            let helper = MacHelper {
                tag: self.tag.clone(),
                encoded_payload: self.encoded_payload.clone(),
            };
            helper.resolve(resolver, out);
        }
    }

    // Serializes a `Mac` by serializing an equivalent `MacHelper`; the resolver
    // returned here is the one `resolve` above consumes.
    impl<T, M: MacPrimitive, C, S> Serialize<S> for Mac<T, M, C>
    where
        M::Tag: Serialize<S>,
        S: Fallible + rkyv::ser::Allocator + rkyv::ser::Writer + ?Sized,
    {
        fn serialize(&self, serializer: &mut S) -> Result<Self::Resolver, S::Error> {
            // NOTE(review): same per-call clone of tag and payload as in `resolve`.
            let helper = MacHelper {
                tag: self.tag.clone(),
                encoded_payload: self.encoded_payload.clone(),
            };
            helper.serialize(serializer)
        }
    }

    // Rebuilds a `Mac` from its archived form by first deserializing a
    // `MacHelper`, then wrapping its parts with `Mac::new`. The tag is not
    // verified here; callers still need `Mac::try_verify`.
    impl<T, M: MacPrimitive, C, D> Deserialize<Mac<T, M, C>, D> for ArchivedMac<M::Tag>
    where
        M::Tag: Archive,
        Archived<M::Tag>: Deserialize<M::Tag, D>,
        D: Fallible + ?Sized,
        D::Error: rkyv::rancor::Source,
    {
        fn deserialize(&self, deserializer: &mut D) -> Result<Mac<T, M, C>, D::Error> {
            // Fully-qualified call to pick the derived `MacHelper` impl rather than
            // this impl (both are `Deserialize` impls on the same archived type).
            let helper: MacHelper<M::Tag> = <ArchivedMac<M::Tag> as Deserialize<
                MacHelper<M::Tag>,
                D,
            >>::deserialize(self, deserializer)?;
            Ok(Mac::new(helper.tag, helper.encoded_payload))
        }
    }

    /// Concrete mirror of the data stored in a [`Mac`], generic only over the
    /// tag type. The rkyv derives on it generate `ArchivedMacHelper` and
    /// `MacHelperResolver`, which the aliases below re-export under `Mac` names.
    // NOTE(review): the fields are private; assuming the rkyv derive mirrors
    // field visibility onto `ArchivedMacHelper`, code outside this module cannot
    // read archived fields without deserializing — confirm this is intended.
    #[derive(Debug, Archive, Serialize, Deserialize)]
    pub struct MacHelper<Tag> {
        tag: Tag,
        encoded_payload: Vec<u8>,
    }

    /// Archived form of a [`Mac`] whose tag type is `Tag`.
    pub type ArchivedMac<Tag> = ArchivedMacHelper<Tag>;
    /// rkyv resolver for a [`Mac`] whose tag type is `Tag`.
    pub type MacResolver<Tag> = MacHelperResolver<Tag>;
}
#[cfg(feature = "hmac")]
pub mod hmac;