1use core::fmt;
4
5use alloc::vec::Vec;
6
7use ring::hmac;
8
9use crate::{Deserializer, Serializer};
10
11pub mod algorithms {
15 pub use super::hmac::{
16 HMAC_SHA1_FOR_LEGACY_USE_ONLY, HMAC_SHA256, HMAC_SHA384,
17 HMAC_SHA512,
18 };
19}
20
21pub use self::hmac::Key;
22
/// Wrapper that authenticates the byte output of an inner stage with an HMAC.
///
/// When serializing, the output of `next` is followed by an HMAC tag computed
/// with `key`. When deserializing, the trailing tag is verified before the
/// remaining bytes are passed to `next`.
#[derive(Debug, Clone)]
pub struct Hmac<U> {
    /// The wrapped (inner) serializer/deserializer.
    next: U,
    /// Key used both to compute and to verify HMAC tags.
    key: Key,
}
29
30impl<U> Hmac<U> {
31 pub fn new(next: U, key: Key) -> Self {
33 Self { next, key }
34 }
35
36 pub fn next_ref(&self) -> &U {
38 &self.next
39 }
40
41 pub fn next_mut(&mut self) -> &mut U {
43 &mut self.next
44 }
45
46 pub fn into_next(self) -> U {
48 self.next
49 }
50}
51
52impl<T, U> Serializer<T> for Hmac<U>
53where
54 U: Serializer<T>,
55{
56 type Error = U::Error;
57
58 fn serialize(&mut self, t: &T) -> Result<Vec<u8>, Self::Error> {
59 let data = self.next.serialize(t)?;
60 let tag = hmac::sign(&self.key, &data);
61
62 let out = [data.as_slice(), tag.as_ref()].concat();
63 Ok(out)
64 }
65}
66
/// Failure modes of HMAC-authenticated deserialization.
///
/// `T` is the error type of the wrapped (inner) deserializer.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DeserializeError<T> {
    /// The input was too short to contain an HMAC tag.
    NoTag,
    /// The HMAC tag did not authenticate the payload.
    VerifyFail,
    /// The tag was valid, but the wrapped deserializer failed.
    Next(T),
}
77
78impl<T> fmt::Display for DeserializeError<T>
79where
80 T: fmt::Display,
81{
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 match self {
84 Self::Next(e) => e.fmt(f),
85 Self::NoTag => f.write_str("no tag"),
86 Self::VerifyFail => f.write_str("verification failure"),
87 }
88 }
89}
90
/// `DeserializeError<T>` is a standard error whenever the wrapped error is.
///
/// `Display` renders `Next(e)` exactly as `e`, so the wrapper is transparent.
/// To match that, `source()` forwards to `e.source()` instead of returning `e`
/// itself. Returning `e` would make error reporters print the same message
/// twice, and returning `None` (the previous default) dropped `e`'s cause
/// chain. `NoTag` and `VerifyFail` are leaf errors with no source.
#[cfg(feature = "std")]
impl<T: std::error::Error> std::error::Error for DeserializeError<T> {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Next(e) => std::error::Error::source(e),
            Self::NoTag | Self::VerifyFail => None,
        }
    }
}
93
94impl<T, U> Deserializer<T> for Hmac<U>
95where
96 U: Deserializer<T>,
97{
98 type Error = DeserializeError<U::Error>;
99
100 fn deserialize(
101 &mut self,
102 buf: &mut [u8],
103 ) -> Result<T, Self::Error> {
104 let tag_len =
105 self.key.algorithm().digest_algorithm().output_len();
106
107 let tag_start = buf
108 .len()
109 .checked_sub(tag_len)
110 .ok_or(DeserializeError::NoTag)?;
111 let (data, tag) = buf.split_at_mut(tag_start);
112
113 hmac::verify(&self.key, data, tag)
114 .map_err(|_| DeserializeError::VerifyFail)?;
115
116 self.next.deserialize(data).map_err(DeserializeError::Next)
117 }
118}