// sl_mpc_mate/message.rs
//
// Copyright (c) Silence Laboratories Pte. Ltd. All Rights Reserved.
// This software is licensed under the Silence Laboratories License Agreement.

4use std::{fmt, ops::Deref, time::Duration};
5
6use sha2::{Digest, Sha256};
7use zeroize::Zeroize;
8
/// Size in bytes of a message ID (a SHA-256 digest, see `MsgId::new`).
pub const MESSAGE_ID_SIZE: usize = 32;
/// Size in bytes of a message header: the message ID, then a 2-byte
/// little-endian TTL field, then a 2-byte little-endian flags field.
pub const MESSAGE_HEADER_SIZE: usize = MESSAGE_ID_SIZE + 2 /* ttl */ + 2 /* flags */;
11
12#[derive(Debug, Copy, Clone, PartialEq, Zeroize)]
13pub struct InstanceId([u8; 32]);
14
15impl InstanceId {
16    pub fn new(bytes: [u8; 32]) -> Self {
17        Self(bytes)
18    }
19}
20
21impl From<[u8; 32]> for InstanceId {
22    fn from(bytes: [u8; 32]) -> Self {
23        Self::new(bytes)
24    }
25}
26
/// 8-byte message tag, stored as the little-endian encoding of a `u64`.
///
/// The tag is one of the inputs hashed by `MsgId::new`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct MessageTag([u8; 8]);

impl MessageTag {
    /// Build a tag from a full 64-bit value.
    pub const fn tag(tag: u64) -> Self {
        Self(u64::to_le_bytes(tag))
    }

    /// Define a family of tags indexed by some parameter.
    ///
    /// `tag` fills the low 32 bits and `param` the high 32 bits.
    pub const fn tag1(tag: u32, param: u32) -> Self {
        let high = (param as u64) << 32;
        Self::tag(high | tag as u64)
    }

    /// Define a family of tags indexed by a pair of parameters.
    ///
    /// `tag` fills bits 0..32, `param1` bits 32..48 and `param2`
    /// bits 48..64.
    pub const fn tag2(tag: u32, param1: u16, param2: u16) -> Self {
        // Pack both parameters into the upper 32-bit half used by `tag1`.
        let packed = (param1 as u32) | ((param2 as u32) << 16);
        Self::tag1(tag, packed)
    }

    /// Return the little-endian byte representation of the tag.
    pub const fn to_bytes(&self) -> [u8; 8] {
        self.0
    }
}
50
51#[derive(
52    PartialEq,
53    Clone,
54    Copy,
55    Hash,
56    PartialOrd,
57    Eq,
58    bytemuck::AnyBitPattern,
59    bytemuck::NoUninit,
60)]
61#[repr(C)]
62pub struct MsgId([u8; MESSAGE_ID_SIZE]);
63
64impl Deref for MsgId {
65    type Target = [u8];
66
67    fn deref(&self) -> &Self::Target {
68        &self.0
69    }
70}
71
72impl fmt::Debug for MsgId {
73    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74        write!(f, "MsgId({self:X})")
75    }
76}
77
78impl fmt::UpperHex for MsgId {
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        for b in &self.0 {
81            write!(f, "{:02X}", b)?;
82        }
83        Ok(())
84    }
85}
86
87impl fmt::LowerHex for MsgId {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        for b in &self.0 {
90            write!(f, "{:02x}", b)?;
91        }
92        Ok(())
93    }
94}
95
96impl MsgId {
97    pub const ZERO_ID: MsgId = MsgId([0; MESSAGE_ID_SIZE]);
98
99    /// Create message ID for given instance id, sender, receiver and
100    /// message tag.
101    pub fn new(
102        instance: &InstanceId,
103        sender: &[u8],
104        receiver: Option<&[u8]>,
105        tag: MessageTag,
106    ) -> Self {
107        Self(
108            Sha256::default()
109                .chain_update(tag.to_bytes())
110                .chain_update(sender)
111                .chain_update(receiver.unwrap_or(&[]))
112                .chain_update(instance.0)
113                .finalize()
114                .into(),
115        )
116    }
117
118    /// Create message ID for a broadcast message, without a designated receiver.
119    pub fn broadcast(
120        instance: &InstanceId,
121        sender: &[u8],
122        tag: MessageTag,
123    ) -> Self {
124        Self::new(instance, sender, None, tag)
125    }
126
127    /// Return as slice of bytes
128    pub fn as_slice(&self) -> &[u8] {
129        &self.0
130    }
131}
132
133impl From<[u8; MESSAGE_ID_SIZE]> for MsgId {
134    fn from(id: [u8; MESSAGE_ID_SIZE]) -> Self {
135        Self(id)
136    }
137}
138
139// Try to convert a byte slice into a reference to MsgId. It will
140// succeed if passed slice is at least MESSAGE_ID_SIZE bytes.
141impl<'a> TryFrom<&'a [u8]> for &'a MsgId {
142    type Error = ();
143
144    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
145        value
146            .first_chunk::<MESSAGE_ID_SIZE>()
147            .and_then(|id| bytemuck::try_cast_ref(id).ok())
148            .ok_or(())
149    }
150}
151
152// The same as above but return MsgId value.
153impl<'a> TryFrom<&'a [u8]> for MsgId {
154    type Error = ();
155
156    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
157        let msg_id: &MsgId = value.try_into()?;
158        Ok(*msg_id)
159    }
160}
161
162// It is always possible to get MsgId from &MsgHdr
163impl<'a> From<&'a MsgHdr> for MsgId {
164    fn from(value: &MsgHdr) -> Self {
165        *value.id()
166    }
167}
168
/// Kind of a message.
///
/// NOTE(review): `Kind` is not used in this file. `Ask` presumably marks a
/// request for a message (cf. `AskMsg`) and `Pub` a published message;
/// confirm against callers.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Kind {
    Ask,
    Pub,
}
174
175#[derive(Clone, Copy, bytemuck::AnyBitPattern, bytemuck::NoUninit)]
176#[repr(C)]
177pub struct MsgHdr {
178    data: [u8; MESSAGE_HEADER_SIZE],
179}
180
181impl fmt::Debug for MsgHdr {
182    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
183        write!(
184            f,
185            "MsgHdr(id: {:X}, flags: {:04X}, ttl: {})",
186            self.id(),
187            self.flags(),
188            self.ttl().as_secs(),
189        )
190    }
191}
192
193// Try convert a byte slice into a reference to MsgHdr. It will
194// succeed is given slice is at least MESSAGE_HEADER_SIZE bytes.
195impl<'a> TryFrom<&'a [u8]> for &'a MsgHdr {
196    type Error = ();
197
198    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
199        value
200            .first_chunk::<MESSAGE_HEADER_SIZE>()
201            .and_then(|hdr| bytemuck::try_cast_ref(hdr).ok())
202            .ok_or(())
203    }
204}
205
206// The same above but tries to convert into MsgHdr value.
207impl<'a> TryFrom<&'a [u8]> for MsgHdr {
208    type Error = ();
209
210    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
211        let hdr: &MsgHdr = value.try_into()?;
212        Ok(*hdr)
213    }
214}
215
216impl MsgHdr {
217    /// Decode message id field.
218    pub fn id(&self) -> &MsgId {
219        self.data[..MESSAGE_ID_SIZE].try_into().unwrap()
220    }
221
222    /// Decode flags field.
223    pub fn flags(&self) -> u16 {
224        u16::from_le_bytes(
225            self.data[MESSAGE_ID_SIZE..][2..].try_into().unwrap(),
226        )
227    }
228
229    /// Decode TTL field.
230    pub fn ttl(&self) -> Duration {
231        let secs: u16 = u16::from_le_bytes(
232            self.data[MESSAGE_ID_SIZE..][..2].try_into().unwrap(),
233        );
234
235        Duration::from_secs(secs as u64)
236    }
237
238    /// Encode header parts into given buffer.
239    pub fn encode(
240        hdr: &mut [u8; MESSAGE_HEADER_SIZE],
241        id: &MsgId,
242        ttl: u32,
243        flags: u16,
244    ) {
245        let data: u32 = (ttl & 0xffff) | (flags as u32) << 16;
246
247        hdr[..MESSAGE_ID_SIZE].copy_from_slice(&id.0);
248        hdr[MESSAGE_ID_SIZE..].copy_from_slice(&data.to_le_bytes());
249    }
250}
251
252/// Allocate message and initalize it from given parts.
253///
254/// This is mostly debug/test support function.
255pub fn allocate_message(
256    id: &MsgId,
257    ttl: u32,
258    flags: u16,
259    payload: &[u8],
260) -> Vec<u8> {
261    let mut buffer = Vec::with_capacity(MESSAGE_HEADER_SIZE + payload.len());
262
263    buffer.resize(MESSAGE_HEADER_SIZE, 0);
264
265    MsgHdr::encode(buffer.as_mut_slice().try_into().unwrap(), id, ttl, flags);
266
267    buffer.extend_from_slice(payload);
268
269    buffer
270}
271
272pub struct AskMsg;
273
274impl AskMsg {
275    pub fn allocate(id: &MsgId, ttl: u32) -> Vec<u8> {
276        allocate_message(id, ttl, 0, &[])
277    }
278}
279
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;

    #[test]
    fn msg_hdr() {
        let data = [0u8; MESSAGE_HEADER_SIZE + 1];

        assert!(<&MsgHdr>::try_from(&data[..]).is_ok());

        assert!(<&MsgHdr>::try_from(&data[..MESSAGE_HEADER_SIZE]).is_ok());

        assert!(
            <&MsgHdr>::try_from(&data[..MESSAGE_HEADER_SIZE - 1]).is_err()
        );
    }

    /// Header fields written by `allocate_message` decode back unchanged,
    /// and the payload follows the header.
    #[test]
    fn msg_hdr_roundtrip() {
        let id = MsgId::from([0xA5; MESSAGE_ID_SIZE]);
        let msg = allocate_message(&id, 300, 0xBEEF, b"payload");

        assert_eq!(msg.len(), MESSAGE_HEADER_SIZE + 7);

        let hdr = <&MsgHdr>::try_from(&msg[..]).unwrap();
        assert_eq!(hdr.id(), &id);
        assert_eq!(hdr.ttl(), Duration::from_secs(300));
        assert_eq!(hdr.flags(), 0xBEEF);
        assert_eq!(MsgId::from(hdr), id);
        assert_eq!(&msg[MESSAGE_HEADER_SIZE..], b"payload");

        let ask = AskMsg::allocate(&id, 10);
        assert_eq!(ask.len(), MESSAGE_HEADER_SIZE);
        let ask_hdr = MsgHdr::try_from(&ask[..]).unwrap();
        assert_eq!(ask_hdr.flags(), 0);
        assert_eq!(ask_hdr.ttl(), Duration::from_secs(10));
    }

    #[test]
    fn msg_id_try_from() {
        let bytes = [0x11u8; MESSAGE_ID_SIZE + 1];

        let id = MsgId::try_from(&bytes[..]).unwrap();
        assert_eq!(id.as_slice(), &bytes[..MESSAGE_ID_SIZE]);

        assert!(MsgId::try_from(&bytes[..MESSAGE_ID_SIZE]).is_ok());
        assert!(MsgId::try_from(&bytes[..MESSAGE_ID_SIZE - 1]).is_err());
    }

    #[test]
    fn msg_id_formatting() {
        let id = MsgId::from([0xAB; MESSAGE_ID_SIZE]);

        assert_eq!(format!("{id:x}"), "ab".repeat(MESSAGE_ID_SIZE));
        assert_eq!(format!("{id:X}"), "AB".repeat(MESSAGE_ID_SIZE));
        assert_eq!(
            format!("{id:?}"),
            format!("MsgId({})", "AB".repeat(MESSAGE_ID_SIZE))
        );
    }

    #[test]
    fn broadcast_is_new_without_receiver() {
        let instance = InstanceId::from([3u8; 32]);
        let tag = MessageTag::tag(42);

        assert_eq!(
            MsgId::broadcast(&instance, b"alice", tag),
            MsgId::new(&instance, b"alice", None, tag)
        );
        assert_ne!(
            MsgId::new(&instance, b"alice", Some(&b"bob"[..]), tag),
            MsgId::broadcast(&instance, b"alice", tag)
        );
    }

    #[test]
    fn msg_tags() {
        let t1 = MessageTag::tag(0x1020304050607080);

        assert_eq!(
            t1.to_bytes(),
            [0x80, 0x70, 0x60, 0x50, 0x40, 0x30, 0x20, 0x10]
        );

        let t2 = MessageTag::tag1(0x10203040, 0xAABBCCDD);

        assert_eq!(
            t2.to_bytes(),
            [0x40, 0x30, 0x20, 0x10, 0xDD, 0xCC, 0xBB, 0xAA]
        );

        let t3 = MessageTag::tag2(0x10203040, 0xEEFF, 0xDEAD);

        assert_eq!(
            t3.to_bytes(),
            [0x40, 0x30, 0x20, 0x10, 0xFF, 0xEE, 0xAD, 0xDE]
        );
    }
}
318        );
319    }
320}