channels_serdes/
hmac.rs

1//! Middleware that verifies data with HMAC.
2
3use core::fmt;
4
5use alloc::vec::Vec;
6
7use ring::hmac;
8
9use crate::{Deserializer, Serializer};
10
11/// Algorithms usable with [`Key`].
12///
13/// This module reexports the algorithms from [`ring::hmac`].
14pub mod algorithms {
15	pub use super::hmac::{
16		HMAC_SHA1_FOR_LEGACY_USE_ONLY, HMAC_SHA256, HMAC_SHA384,
17		HMAC_SHA512,
18	};
19}
20
21pub use self::hmac::Key;
22
23/// Middleware that verifies data with HMAC.
24#[derive(Debug, Clone)]
25pub struct Hmac<U> {
26	next: U,
27	key: Key,
28}
29
30impl<U> Hmac<U> {
31	/// Create a new [`Hmac`] middleware that uses `key`.
32	pub fn new(next: U, key: Key) -> Self {
33		Self { next, key }
34	}
35
36	/// Get a reference to the next serializer/deserializer in the chain.
37	pub fn next_ref(&self) -> &U {
38		&self.next
39	}
40
41	/// Get a reference to the next serializer/deserializer in the chain.
42	pub fn next_mut(&mut self) -> &mut U {
43		&mut self.next
44	}
45
46	/// Consume `self` and return the next serializer/deserializer in the chain.
47	pub fn into_next(self) -> U {
48		self.next
49	}
50}
51
52impl<T, U> Serializer<T> for Hmac<U>
53where
54	U: Serializer<T>,
55{
56	type Error = U::Error;
57
58	fn serialize(&mut self, t: &T) -> Result<Vec<u8>, Self::Error> {
59		let data = self.next.serialize(t)?;
60		let tag = hmac::sign(&self.key, &data);
61
62		let out = [data.as_slice(), tag.as_ref()].concat();
63		Ok(out)
64	}
65}
66
/// Errors that can be returned while deserializing through [`Hmac`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeserializeError<T> {
	/// The trailing tag does not match the HMAC of the data.
	VerifyFail,
	/// The input is too short to contain a tag.
	NoTag,
	/// The next deserializer in the chain reported an error.
	Next(T),
}

impl<T> fmt::Display for DeserializeError<T>
where
	T: fmt::Display,
{
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		let msg = match self {
			Self::VerifyFail => "verification failure",
			Self::NoTag => "no tag",
			// Errors from the inner deserializer are shown unchanged.
			Self::Next(inner) => return fmt::Display::fmt(inner, f),
		};
		f.write_str(msg)
	}
}

#[cfg(feature = "std")]
impl<T> std::error::Error for DeserializeError<T> where T: std::error::Error {}
93
94impl<T, U> Deserializer<T> for Hmac<U>
95where
96	U: Deserializer<T>,
97{
98	type Error = DeserializeError<U::Error>;
99
100	fn deserialize(
101		&mut self,
102		buf: &mut [u8],
103	) -> Result<T, Self::Error> {
104		let tag_len =
105			self.key.algorithm().digest_algorithm().output_len();
106
107		let tag_start = buf
108			.len()
109			.checked_sub(tag_len)
110			.ok_or(DeserializeError::NoTag)?;
111		let (data, tag) = buf.split_at_mut(tag_start);
112
113		hmac::verify(&self.key, data, tag)
114			.map_err(|_| DeserializeError::VerifyFail)?;
115
116		self.next.deserialize(data).map_err(DeserializeError::Next)
117	}
118}