1use std::{fmt, ops::Deref, time::Duration};
5
6use sha2::{Digest, Sha256};
7use zeroize::Zeroize;
8
/// Length in bytes of a [`MsgId`]; matches the SHA-256 digest that
/// [`MsgId::new`] produces.
pub const MESSAGE_ID_SIZE: usize = 32;
/// Length in bytes of an encoded [`MsgHdr`]: the id, followed by a
/// little-endian `u16` TTL and a little-endian `u16` flags word.
pub const MESSAGE_HEADER_SIZE: usize = MESSAGE_ID_SIZE + 2 * std::mem::size_of::<u16>();
11
12#[derive(Debug, Copy, Clone, PartialEq, Zeroize)]
13pub struct InstanceId([u8; 32]);
14
15impl InstanceId {
16 pub fn new(bytes: [u8; 32]) -> Self {
17 Self(bytes)
18 }
19}
20
21impl From<[u8; 32]> for InstanceId {
22 fn from(bytes: [u8; 32]) -> Self {
23 Self::new(bytes)
24 }
25}
26
/// 8-byte tag hashed into a [`MsgId`]. It is stored as the little-endian
/// encoding of a `u64`.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct MessageTag([u8; 8]);

impl MessageTag {
    /// Builds a tag from a full 64-bit value.
    pub const fn tag(tag: u64) -> Self {
        MessageTag(u64::to_le_bytes(tag))
    }

    /// Builds a tag with `tag` in the low 32 bits and `param` in the high
    /// 32 bits.
    pub const fn tag1(tag: u32, param: u32) -> Self {
        let high = (param as u64) << 32;
        Self::tag(high | tag as u64)
    }

    /// Builds a tag with `tag` in bits 0..32, `param1` in bits 32..48 and
    /// `param2` in bits 48..64.
    pub const fn tag2(tag: u32, param1: u16, param2: u16) -> Self {
        // Pack both 16-bit params into the 32-bit slot that `tag1` places
        // in the upper half.
        let packed = (param1 as u32) | ((param2 as u32) << 16);
        Self::tag1(tag, packed)
    }

    /// Returns the 8 tag bytes (little-endian).
    pub const fn to_bytes(&self) -> [u8; 8] {
        let MessageTag(bytes) = *self;
        bytes
    }
}
50
51#[derive(
52 PartialEq,
53 Clone,
54 Copy,
55 Hash,
56 PartialOrd,
57 Eq,
58 bytemuck::AnyBitPattern,
59 bytemuck::NoUninit,
60)]
61#[repr(C)]
62pub struct MsgId([u8; MESSAGE_ID_SIZE]);
63
64impl Deref for MsgId {
65 type Target = [u8];
66
67 fn deref(&self) -> &Self::Target {
68 &self.0
69 }
70}
71
72impl fmt::Debug for MsgId {
73 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74 write!(f, "MsgId({self:X})")
75 }
76}
77
78impl fmt::UpperHex for MsgId {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 for b in &self.0 {
81 write!(f, "{:02X}", b)?;
82 }
83 Ok(())
84 }
85}
86
87impl fmt::LowerHex for MsgId {
88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 for b in &self.0 {
90 write!(f, "{:02x}", b)?;
91 }
92 Ok(())
93 }
94}
95
96impl MsgId {
97 pub const ZERO_ID: MsgId = MsgId([0; MESSAGE_ID_SIZE]);
98
99 pub fn new(
102 instance: &InstanceId,
103 sender: &[u8],
104 receiver: Option<&[u8]>,
105 tag: MessageTag,
106 ) -> Self {
107 Self(
108 Sha256::default()
109 .chain_update(tag.to_bytes())
110 .chain_update(sender)
111 .chain_update(receiver.unwrap_or(&[]))
112 .chain_update(instance.0)
113 .finalize()
114 .into(),
115 )
116 }
117
118 pub fn broadcast(
120 instance: &InstanceId,
121 sender: &[u8],
122 tag: MessageTag,
123 ) -> Self {
124 Self::new(instance, sender, None, tag)
125 }
126
127 pub fn as_slice(&self) -> &[u8] {
129 &self.0
130 }
131}
132
133impl From<[u8; MESSAGE_ID_SIZE]> for MsgId {
134 fn from(id: [u8; MESSAGE_ID_SIZE]) -> Self {
135 Self(id)
136 }
137}
138
139impl<'a> TryFrom<&'a [u8]> for &'a MsgId {
142 type Error = ();
143
144 fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
145 value
146 .first_chunk::<MESSAGE_ID_SIZE>()
147 .and_then(|id| bytemuck::try_cast_ref(id).ok())
148 .ok_or(())
149 }
150}
151
152impl<'a> TryFrom<&'a [u8]> for MsgId {
154 type Error = ();
155
156 fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
157 let msg_id: &MsgId = value.try_into()?;
158 Ok(*msg_id)
159 }
160}
161
162impl<'a> From<&'a MsgHdr> for MsgId {
164 fn from(value: &MsgHdr) -> Self {
165 *value.id()
166 }
167}
168
/// Message kind.
///
/// NOTE(review): neither variant is referenced in this part of the file.
/// Presumably `Ask` pairs with `AskMsg` and `Pub` with published
/// messages — confirm with callers.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Kind {
    Ask,
    Pub,
}
174
175#[derive(Clone, Copy, bytemuck::AnyBitPattern, bytemuck::NoUninit)]
176#[repr(C)]
177pub struct MsgHdr {
178 data: [u8; MESSAGE_HEADER_SIZE],
179}
180
181impl fmt::Debug for MsgHdr {
182 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
183 write!(
184 f,
185 "MsgHdr(id: {:X}, flags: {:04X}, ttl: {})",
186 self.id(),
187 self.flags(),
188 self.ttl().as_secs(),
189 )
190 }
191}
192
193impl<'a> TryFrom<&'a [u8]> for &'a MsgHdr {
196 type Error = ();
197
198 fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
199 value
200 .first_chunk::<MESSAGE_HEADER_SIZE>()
201 .and_then(|hdr| bytemuck::try_cast_ref(hdr).ok())
202 .ok_or(())
203 }
204}
205
206impl<'a> TryFrom<&'a [u8]> for MsgHdr {
208 type Error = ();
209
210 fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
211 let hdr: &MsgHdr = value.try_into()?;
212 Ok(*hdr)
213 }
214}
215
216impl MsgHdr {
217 pub fn id(&self) -> &MsgId {
219 self.data[..MESSAGE_ID_SIZE].try_into().unwrap()
220 }
221
222 pub fn flags(&self) -> u16 {
224 u16::from_le_bytes(
225 self.data[MESSAGE_ID_SIZE..][2..].try_into().unwrap(),
226 )
227 }
228
229 pub fn ttl(&self) -> Duration {
231 let secs: u16 = u16::from_le_bytes(
232 self.data[MESSAGE_ID_SIZE..][..2].try_into().unwrap(),
233 );
234
235 Duration::from_secs(secs as u64)
236 }
237
238 pub fn encode(
240 hdr: &mut [u8; MESSAGE_HEADER_SIZE],
241 id: &MsgId,
242 ttl: u32,
243 flags: u16,
244 ) {
245 let data: u32 = (ttl & 0xffff) | (flags as u32) << 16;
246
247 hdr[..MESSAGE_ID_SIZE].copy_from_slice(&id.0);
248 hdr[MESSAGE_ID_SIZE..].copy_from_slice(&data.to_le_bytes());
249 }
250}
251
252pub fn allocate_message(
256 id: &MsgId,
257 ttl: u32,
258 flags: u16,
259 payload: &[u8],
260) -> Vec<u8> {
261 let mut buffer = Vec::with_capacity(MESSAGE_HEADER_SIZE + payload.len());
262
263 buffer.resize(MESSAGE_HEADER_SIZE, 0);
264
265 MsgHdr::encode(buffer.as_mut_slice().try_into().unwrap(), id, ttl, flags);
266
267 buffer.extend_from_slice(payload);
268
269 buffer
270}
271
272pub struct AskMsg;
273
274impl AskMsg {
275 pub fn allocate(id: &MsgId, ttl: u32) -> Vec<u8> {
276 allocate_message(id, ttl, 0, &[])
277 }
278}
279
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn msg_hdr() {
        let data = [0u8; MESSAGE_HEADER_SIZE + 1];

        assert!(<&MsgHdr>::try_from(&data[..]).is_ok());

        assert!(<&MsgHdr>::try_from(&data[..MESSAGE_HEADER_SIZE]).is_ok());

        assert!(
            <&MsgHdr>::try_from(&data[..MESSAGE_HEADER_SIZE - 1]).is_err()
        );
    }

    /// Encoding then decoding a header must return the same id, TTL and flags.
    #[test]
    fn msg_hdr_round_trip() {
        let id = MsgId::from([0xA5; MESSAGE_ID_SIZE]);
        let mut buf = [0u8; MESSAGE_HEADER_SIZE];

        MsgHdr::encode(&mut buf, &id, 300, 0xBEEF);

        let hdr = <&MsgHdr>::try_from(&buf[..]).unwrap();
        assert_eq!(*hdr.id(), id);
        assert_eq!(hdr.ttl().as_secs(), 300);
        assert_eq!(hdr.flags(), 0xBEEF);
        assert_eq!(MsgId::from(hdr), id);
        assert_eq!(MsgHdr::try_from(&buf[..]).unwrap().flags(), 0xBEEF);
    }

    /// `allocate_message` lays out the header followed by the payload.
    #[test]
    fn allocate_message_layout() {
        let id = MsgId::from([7; MESSAGE_ID_SIZE]);
        let msg = allocate_message(&id, 10, 3, b"payload");

        assert_eq!(msg.len(), MESSAGE_HEADER_SIZE + 7);
        assert_eq!(&msg[MESSAGE_HEADER_SIZE..], b"payload");

        let hdr = <&MsgHdr>::try_from(&msg[..]).unwrap();
        assert_eq!(*hdr.id(), id);
        assert_eq!(hdr.ttl().as_secs(), 10);
        assert_eq!(hdr.flags(), 3);

        let ask = AskMsg::allocate(&id, 5);
        assert_eq!(ask.len(), MESSAGE_HEADER_SIZE);
        assert_eq!(<&MsgHdr>::try_from(&ask[..]).unwrap().flags(), 0);
    }

    /// Every input of `MsgId::new` must influence the resulting id.
    #[test]
    fn msg_id_inputs() {
        let inst = InstanceId::new([1; 32]);
        let tag = MessageTag::tag(42);
        let id = MsgId::broadcast(&inst, b"alice", tag);

        assert_eq!(id, MsgId::new(&inst, b"alice", None, tag));
        assert_ne!(id, MsgId::new(&inst, b"alice", Some(&b"bob"[..]), tag));
        assert_ne!(id, MsgId::broadcast(&inst, b"alice", MessageTag::tag(43)));
        assert_ne!(id, MsgId::broadcast(&InstanceId::new([2; 32]), b"alice", tag));
        assert_ne!(id, MsgId::ZERO_ID);
    }

    #[test]
    fn msg_id_from_slice() {
        let bytes = [9u8; MESSAGE_ID_SIZE + 4];

        assert_eq!(
            MsgId::try_from(&bytes[..]),
            Ok(MsgId::from([9; MESSAGE_ID_SIZE]))
        );
        assert!(MsgId::try_from(&bytes[..MESSAGE_ID_SIZE - 1]).is_err());
    }

    #[test]
    fn msg_id_hex() {
        let id = MsgId::from([0xAB; MESSAGE_ID_SIZE]);

        assert_eq!(format!("{:x}", id), "ab".repeat(MESSAGE_ID_SIZE));
        assert_eq!(format!("{:X}", id), "AB".repeat(MESSAGE_ID_SIZE));
        assert_eq!(
            format!("{:?}", MsgId::ZERO_ID),
            format!("MsgId({})", "00".repeat(MESSAGE_ID_SIZE))
        );
    }

    #[test]
    fn msg_tags() {
        let t1 = MessageTag::tag(0x1020304050607080);

        assert_eq!(
            t1.to_bytes(),
            [0x80, 0x70, 0x60, 0x50, 0x40, 0x30, 0x20, 0x10]
        );

        let t2 = MessageTag::tag1(0x10203040, 0xAABBCCDD);

        assert_eq!(
            t2.to_bytes(),
            [0x40, 0x30, 0x20, 0x10, 0xDD, 0xCC, 0xBB, 0xAA]
        );

        let t3 = MessageTag::tag2(0x10203040, 0xEEFF, 0xDEAD);

        assert_eq!(
            t3.to_bytes(),
            [0x40, 0x30, 0x20, 0x10, 0xFF, 0xEE, 0xAD, 0xDE]
        );
    }
}