channels_serdes/
crc.rs

1//! Middleware that verifies data with a CRC checksum.
2
3use core::fmt;
4
5use alloc::boxed::Box;
6use alloc::vec::Vec;
7
8use crate::{Deserializer, Serializer};
9
10/// Middleware that verifies data with a CRC checksum.
11///
12/// When working as a [`Serializer`], it simply computes an 8 byte CRC checksum
13/// of the data it was given and returns the original data with the checksum
14/// appended to the end in big-endian format. When working as a [`Deserializer`],
15/// it reads the checksum of the data (the last 8 bytes), computes the checksum
16/// again from the read data and then compares the 2 checksums. If don't match,
17/// the [`Deserializer::deserialize()`] returns with [`Err(DeserializeError::InvalidChecksum)`].
18/// If the 2 checksums match, the data is then given to the next deserialize in
19/// the chain. Any errors from the next deserializer are returned via [`Err(DeserializeError::Next)`].
20///
21/// [`Err(DeserializeError::InvalidChecksum)`]: DeserializeError::InvalidChecksum
22/// [`Err(DeserializeError::Next)`]: DeserializeError::Next
23pub struct Crc<U> {
24	next: U,
25	crc: Box<crc::Crc<u64>>,
26}
27
28impl<U> fmt::Debug for Crc<U>
29where
30	U: fmt::Debug,
31{
32	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33		f.debug_struct("Crc")
34			.field("next", &self.next)
35			.finish_non_exhaustive()
36	}
37}
38
39impl<U> Clone for Crc<U>
40where
41	U: Clone,
42{
43	fn clone(&self) -> Self {
44		Self::new(self.next.clone(), self.crc.algorithm)
45	}
46}
47
48impl<U> Default for Crc<U>
49where
50	U: Default,
51{
52	fn default() -> Self {
53		Self::new(Default::default(), Self::DEFAULT_ALGORITHM)
54	}
55}
56
57impl<U> Crc<U> {
58	const DEFAULT_ALGORITHM: &'static crc::Algorithm<u64> =
59		&crc::CRC_64_XZ;
60
61	/// Create a new [`Crc`] middleware.
62	pub fn new(
63		next: U,
64		algorithm: &'static crc::Algorithm<u64>,
65	) -> Self {
66		Self { next, crc: Box::new(crc::Crc::<u64>::new(algorithm)) }
67	}
68
69	/// Get a reference to the next serializer in the chain.
70	pub fn next_ref(&self) -> &U {
71		&self.next
72	}
73
74	/// Get a reference to the next serializer in the chain.
75	pub fn next_mut(&mut self) -> &mut U {
76		&mut self.next
77	}
78
79	/// Consume `self` and return the next serializer in the chain.
80	pub fn into_next(self) -> U {
81		self.next
82	}
83}
84
85impl<T, U> Serializer<T> for Crc<U>
86where
87	U: Serializer<T>,
88{
89	type Error = U::Error;
90
91	fn serialize(&mut self, t: &T) -> Result<Vec<u8>, Self::Error> {
92		let data = self.next.serialize(t)?;
93		let checksum = self.crc.checksum(&data);
94
95		let checksum_bytes = checksum.to_be_bytes();
96		let out =
97			[data.as_slice(), checksum_bytes.as_slice()].concat();
98		Ok(out)
99	}
100}
101
/// Errors that can occur while deserializing through the CRC middleware.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum DeserializeError<T> {
	/// The checksum stored in the data does not match the checksum computed
	/// from the data, so the data could not be verified.
	InvalidChecksum,
	/// The data is too short (fewer than 8 bytes) to contain a checksum.
	NoChecksum,
	/// An error returned by the next deserializer in the chain.
	Next(T),
}

impl<T: fmt::Display> fmt::Display for DeserializeError<T> {
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		let message = match self {
			Self::InvalidChecksum => "invalid checksum",
			Self::NoChecksum => "no checksum",
			// Errors from the next deserializer are displayed transparently.
			Self::Next(inner) => return fmt::Display::fmt(inner, f),
		};
		f.write_str(message)
	}
}

#[cfg(feature = "std")]
impl<T: std::error::Error> std::error::Error for DeserializeError<T> {}
128
129impl<T, U> Deserializer<T> for Crc<U>
130where
131	U: Deserializer<T>,
132{
133	type Error = DeserializeError<U::Error>;
134
135	fn deserialize(
136		&mut self,
137		buf: &mut [u8],
138	) -> Result<T, Self::Error> {
139		let inner_len = buf
140			.len()
141			.checked_sub(8)
142			.ok_or(DeserializeError::NoChecksum)?;
143
144		let (inner, checksum) = buf.split_at_mut(inner_len);
145
146		let unverified = u64::from_be_bytes(checksum.try_into().expect(
147			"remaining part of payload should have been at least 8 bytes",
148		));
149
150		let calculated = self.crc.checksum(inner);
151
152		if unverified != calculated {
153			return Err(DeserializeError::InvalidChecksum);
154		}
155
156		self.next.deserialize(inner).map_err(DeserializeError::Next)
157	}
158}