// nucypher_core/versioning.rs
use alloc::boxed::Box;
use alloc::format;
use alloc::string::String;
use alloc::vec::Vec;
use core::fmt;

use serde::{Deserialize, Serialize};
8
9pub(crate) fn messagepack_serialize<T>(obj: &T) -> Box<[u8]>
10where
11 T: Serialize,
12{
13 rmp_serde::to_vec(obj)
22 .map(|vec| vec.into_boxed_slice())
23 .expect("Error serializing into MessagePack")
24}
25
26pub(crate) fn messagepack_deserialize<'a, T>(bytes: &'a [u8]) -> Result<T, String>
27where
28 T: Deserialize<'a>,
29{
30 rmp_serde::from_slice(bytes).map_err(|err| format!("{err}"))
31}
32
33struct ProtocolObjectHeader {
34 brand: [u8; 4],
35 major_version: u16,
36 minor_version: u16,
37}
38
39impl ProtocolObjectHeader {
40 fn to_bytes(&self) -> [u8; 8] {
41 let mut header = [0u8; 8];
42 header[..4].copy_from_slice(&self.brand);
43 header[4..6].copy_from_slice(&self.major_version.to_be_bytes());
44 header[6..].copy_from_slice(&self.minor_version.to_be_bytes());
45 header
46 }
47
48 fn from_bytes(bytes: &[u8; 8]) -> Self {
49 Self {
50 brand: [bytes[0], bytes[1], bytes[2], bytes[3]],
51 major_version: u16::from_be_bytes([bytes[4], bytes[5]]),
52 minor_version: u16::from_be_bytes([bytes[6], bytes[7]]),
53 }
54 }
55
56 fn from_type<'a, T>() -> Self
57 where
58 T: ProtocolObjectInner<'a>,
59 {
60 let (major, minor) = T::version();
61 Self {
62 brand: T::brand(),
63 major_version: major,
64 minor_version: minor,
65 }
66 }
67}
68
/// The reasons deserializing a versioned protocol object can fail.
///
/// `Clone`, `PartialEq` and `Eq` are derived so that callers and tests can
/// compare and duplicate errors directly instead of matching on their text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeserializationError {
    /// The bytestring is shorter than the number of bytes required.
    TooShort {
        /// Minimum number of bytes required.
        expected: usize,
        /// Number of bytes actually supplied.
        received: usize,
    },
    /// The brand in the header does not match the brand of the target type.
    IncorrectHeader {
        /// Brand of the target type.
        expected: [u8; 4],
        /// Brand found in the bytestring.
        received: [u8; 4],
    },
    /// The major version in the header differs from that of the target type.
    MajorVersionMismatch {
        /// Major version of the target type.
        expected: u16,
        /// Major version found in the bytestring.
        received: u16,
    },
    /// The minor version in the header is greater than the one supported by
    /// the target type.
    UnsupportedMinorVersion {
        /// Highest supported minor version.
        expected: u16,
        /// Minor version found in the bytestring.
        received: u16,
    },
    /// The header was accepted, but the payload could not be deserialized.
    BadPayload {
        /// Message produced by the payload deserializer.
        error_msg: String,
    },
}
91
92impl fmt::Display for DeserializationError {
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 match self {
95 Self::TooShort { expected, received } => write!(
96 f,
97 "bytestring too short: expected {expected} bytes, got {received}"
98 ),
99 Self::IncorrectHeader { expected, received } => write!(
100 f,
101 "incorrect header: expected {expected:?}, got {received:?}"
102 ),
103 Self::MajorVersionMismatch { expected, received } => write!(
104 f,
105 "differing major version: expected {expected}, got {received}"
106 ),
107 Self::UnsupportedMinorVersion { expected, received } => write!(
108 f,
109 "unsupported minor version: expected <={expected}, got {received}"
110 ),
111 Self::BadPayload { error_msg } => {
112 write!(f, "payload deserialization failed: {error_msg}")
113 }
114 }
115 }
116}
117
118pub trait ProtocolObjectInner<'a>: Serialize + Deserialize<'a> {
123 fn version() -> (u16, u16);
128
129 fn brand() -> [u8; 4];
131
132 fn unversioned_to_bytes(&self) -> Box<[u8]>;
133
134 fn unversioned_from_bytes(minor_version: u16, bytes: &'a [u8]) -> Option<Result<Self, String>>;
135}
136
137pub trait ProtocolObject<'a>: ProtocolObjectInner<'a> {
139 fn version() -> (u16, u16) {
142 <Self as ProtocolObjectInner>::version()
146 }
147
148 fn to_bytes(&self) -> Box<[u8]> {
150 let header_bytes = ProtocolObjectHeader::from_type::<Self>().to_bytes();
151 let unversioned_bytes = Self::unversioned_to_bytes(self);
152
153 let mut result = Vec::with_capacity(header_bytes.len() + unversioned_bytes.len());
154 result.extend(header_bytes);
155 result.extend(unversioned_bytes.iter());
156 result.into_boxed_slice()
157 }
158
159 fn from_bytes(bytes: &'a [u8]) -> Result<Self, DeserializationError> {
161 if bytes.len() < 8 {
162 return Err(DeserializationError::TooShort {
163 expected: 8,
164 received: bytes.len(),
165 });
166 }
167 let mut header_bytes = [0u8; 8];
168 header_bytes.copy_from_slice(&bytes[..8]);
169 let header = ProtocolObjectHeader::from_bytes(&header_bytes);
170
171 let reference_header = ProtocolObjectHeader::from_type::<Self>();
172
173 if header.brand != reference_header.brand {
174 return Err(DeserializationError::IncorrectHeader {
175 expected: reference_header.brand,
176 received: header.brand,
177 });
178 }
179
180 if header.major_version != reference_header.major_version {
181 return Err(DeserializationError::MajorVersionMismatch {
182 expected: reference_header.major_version,
183 received: header.major_version,
184 });
185 }
186
187 if header.minor_version > reference_header.minor_version {
188 return Err(DeserializationError::UnsupportedMinorVersion {
189 expected: reference_header.minor_version,
190 received: header.minor_version,
191 });
192 }
193
194 let result = match Self::unversioned_from_bytes(header.minor_version, &bytes[8..]) {
195 Some(result) => result,
196 None => panic!("minor version {} is not supported", header.minor_version),
200 };
201
202 result.map_err(|msg| DeserializationError::BadPayload { error_msg: msg })
203 }
204}