// datex_core/global/protocol_structures/routing_header.rs

1use super::serializable::Serializable;
2use crate::global::protocol_structures::instructions::RawFullPointerAddress;
3use crate::stdlib::vec::Vec;
4use crate::values::core_values::endpoint::Endpoint;
5use binrw::{BinRead, BinWrite};
6use core::fmt::Display;
7use core::prelude::rust_2024::*;
8use modular_bitfield::prelude::*;
9
/// Signature mode of a block, stored in 2 bits of [`Flags`].
///
/// Only three of the four possible 2-bit patterns are assigned; `0b01` has no
/// variant.
// NOTE(review): because `0b01` is unassigned, bytes read from the wire can hold
// a value with no matching variant. The plain `Flags::signature_type()` getter
// generated by modular_bitfield is expected to panic on it; prefer
// `signature_type_or_err()` when handling untrusted input — confirm.
#[cfg_attr(feature = "debug", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, PartialEq, Clone, Default, Specifier)]
#[bits = 2]
pub enum SignatureType {
    /// No signature (the default).
    #[default]
    None = 0b00,
    Unencrypted = 0b10,
    Encrypted = 0b11,
}
20
/// Encryption mode of a block, stored in 1 bit of [`Flags`].
///
/// There is no `#[bits]` attribute: the `Specifier` derive infers the 1-bit
/// width from the two variants.
#[cfg_attr(feature = "debug", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, PartialEq, Clone, Default, Specifier)]
pub enum EncryptionType {
    /// Not encrypted (the default).
    #[default]
    None = 0b0,
    Encrypted = 0b1,
}
29
/// One-byte bitfield stored in the `flags` field of every [`RoutingHeader`].
///
/// Bit budget: signature (2) + encryption (1) + receiver type (2) +
/// bounce-back (1) + checksum (1) + reserved (1) = 8 bits = 1 byte.
/// binrw reads and writes it as its raw byte array via `from_bytes` /
/// `into_bytes`. Field order determines bit positions, so do not reorder.
#[bitfield]
#[derive(BinWrite, BinRead, Clone, Default, Copy, Debug, PartialEq)]
#[bw(map = |&x| Self::into_bytes(x))]
#[br(map = Self::from_bytes)]
pub struct Flags {
    pub signature_type: SignatureType,   // 2 bit
    pub encryption_type: EncryptionType, // 1 bit
    // 2 bit; selects which receiver field of `RoutingHeader` is on the wire
    pub receiver_type: ReceiverType,
    pub is_bounce_back: bool, // 1 bit
    // 1 bit; gates the optional `RoutingHeader::checksum` field
    pub has_checksum: bool,

    // Reserved padding bit so the bitfield fills exactly one byte.
    #[allow(unused)]
    unused_2: bool,
}
45
/// Serde support for [`Flags`] in debug builds.
///
/// The bitfield is (de)serialized through a plain struct with one named entry
/// per flag. The reserved padding bit is not exposed, so it is always zero
/// after deserialization.
#[cfg(feature = "debug")]
mod flags_serde {
    use super::*;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    /// Field-by-field mirror of [`Flags`] (without the reserved bit).
    #[derive(Serialize, Deserialize)]
    struct FlagsHelper {
        signature_type: SignatureType,
        encryption_type: EncryptionType,
        receiver_type: ReceiverType,
        is_bounce_back: bool,
        has_checksum: bool,
    }

    impl Serialize for Flags {
        /// Serializes the decoded flag values.
        ///
        /// # Errors
        ///
        /// Returns a serializer error if the signature-type bits hold the
        /// unassigned pattern `0b01`. Such bytes can come from the wire;
        /// previously this case panicked inside the generated getter.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            // `SignatureType` defines only 3 of the 4 possible 2-bit values, so
            // use the fallible getter instead of the panicking one.
            let signature_type =
                self.signature_type_or_err().map_err(|err| {
                    <S::Error as serde::ser::Error>::custom(format_args!(
                        "invalid signature type bits: {err:?}"
                    ))
                })?;

            let helper = FlagsHelper {
                signature_type,
                encryption_type: self.encryption_type(),
                receiver_type: self.receiver_type(),
                is_bounce_back: self.is_bounce_back(),
                has_checksum: self.has_checksum(),
            };
            helper.serialize(serializer)
        }
    }

    impl<'de> Deserialize<'de> for Flags {
        /// Rebuilds the bitfield from the named flag values; the reserved bit
        /// stays zero.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let helper = FlagsHelper::deserialize(deserializer)?;
            Ok(Flags::new()
                .with_signature_type(helper.signature_type)
                .with_encryption_type(helper.encryption_type)
                .with_receiver_type(helper.receiver_type)
                .with_is_bounce_back(helper.is_bounce_back)
                .with_has_checksum(helper.has_checksum))
        }
    }
}
91
/// Receiver addressing mode, stored in 2 bits of [`Flags`].
///
/// All four 2-bit patterns are assigned. The value selects which of the
/// mutually exclusive receiver fields of [`RoutingHeader`] is read and written.
#[cfg_attr(feature = "debug", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, PartialEq, Clone, Default, Specifier)]
#[bits = 2]
pub enum ReceiverType {
    /// No receiver field present (the default).
    #[default]
    None = 0b00,
    /// `receivers_pointer_id` is present.
    Pointer = 0b01,
    /// `receivers_endpoints` is present.
    Receivers = 0b10,
    /// `receivers_endpoints_with_keys` is present.
    ReceiversWithKeys = 0b11,
}
103
/// Explicit list of receiver endpoints, used when the receiver type is
/// [`ReceiverType::Receivers`].
///
/// Wire layout: `<count>: 1 byte + (21 byte * count)`.
/// `count` must equal `endpoints.len()`: `BinRead` uses it to size
/// `endpoints`, while `BinWrite` writes both fields exactly as stored. Prefer
/// [`ReceiverEndpoints::new`], which derives `count` from the list.
// NOTE(review): the original layout note said "min: 2 bytes", but a list with
// count 0 would encode as just the 1-byte count. Confirm the intended minimum.
#[cfg_attr(feature = "debug", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Default, BinWrite, BinRead, PartialEq)]
pub struct ReceiverEndpoints {
    /// Number of entries in `endpoints` (one byte on the wire).
    #[cfg_attr(feature = "debug", serde(rename = "number_of_receivers"))]
    pub count: u8,
    /// The receiver endpoints; `count` of them are read.
    #[br(count = count)]
    #[cfg_attr(feature = "debug", serde(rename = "receivers"))]
    pub endpoints: Vec<Endpoint>,
}
115
116impl ReceiverEndpoints {
117    pub fn new(endpoints: Vec<Endpoint>) -> Self {
118        let count = endpoints.len() as u8;
119        ReceiverEndpoints { count, endpoints }
120    }
121}
122
/// Receiver endpoints each paired with a [`Key512`], used when the receiver
/// type is [`ReceiverType::ReceiversWithKeys`].
///
/// Wire layout: `<count>: 1 byte + (21 byte * count) + (512 byte * count)`.
/// `count` must equal `endpoints_with_keys.len()`, because `BinRead` uses it
/// to size the list. It is private here, so it is set only through
/// [`ReceiverEndpointsWithKeys::new`] or parsing.
// NOTE(review): the original layout note said "min: 2 bytes", but count 0
// would encode as just the 1-byte count. Confirm the intended minimum.
#[cfg_attr(feature = "debug", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Default, BinWrite, BinRead, PartialEq)]
pub struct ReceiverEndpointsWithKeys {
    /// Number of `(endpoint, key)` pairs (one byte on the wire).
    #[cfg_attr(feature = "debug", serde(rename = "number_of_receivers"))]
    count: u8,
    /// The `(endpoint, key)` pairs; `count` of them are read.
    #[br(count = count)]
    #[cfg_attr(feature = "debug", serde(rename = "receivers_with_keys"))]
    pub endpoints_with_keys: Vec<(Endpoint, Key512)>,
}
134
/// Fixed-size 512-byte key attached to a receiver in
/// [`ReceiverEndpointsWithKeys`].
#[cfg_attr(feature = "debug", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, BinWrite, BinRead, PartialEq)]
pub struct Key512(
    // serde has no built-in (de)serialization for arrays this large, hence the
    // `BigArray` adapter.
    #[cfg_attr(feature = "debug", serde(with = "serde_big_array::BigArray"))]
    [u8; 512],
);
141impl Default for Key512 {
142    fn default() -> Self {
143        Key512([0u8; 512])
144    }
145}
146impl From<[u8; 512]> for Key512 {
147    fn from(arr: [u8; 512]) -> Self {
148        Key512(arr)
149    }
150}
151
152impl ReceiverEndpointsWithKeys {
153    pub fn new<T>(endpoints_with_keys: Vec<(Endpoint, T)>) -> Self
154    where
155        T: Into<Key512>,
156    {
157        let count = endpoints_with_keys.len() as u8;
158        ReceiverEndpointsWithKeys {
159            count,
160            endpoints_with_keys: endpoints_with_keys
161                .into_iter()
162                .map(|(ep, key)| (ep, key.into()))
163                .collect(),
164        }
165    }
166}
167
/// Routing header of a block, (de)serialized with binrw.
///
/// Encoded little-endian and prefixed with the magic bytes `0x01 0x64`.
/// Optional sections are present on the wire only when `flags` says so:
/// `checksum` when `flags.has_checksum()`, and at most one of the three
/// receiver fields, selected by `flags.receiver_type()`. Use
/// [`RoutingHeader::set_receivers`] and [`RoutingHeader::receivers`] instead of
/// touching those fields directly, so flags and fields stay in sync.
// min: 11 byte + 2 byte + 21 byte + 1 byte = 35 bytes
// NOTE(review): the fixed fields declared below (2-byte magic, version,
// block_size, flags, distance, ttl) add up to 8 bytes before `sender`.
// Confirm how the 35-byte minimum above was derived.
#[cfg_attr(feature = "debug", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, BinWrite, BinRead, PartialEq)]
#[brw(little, magic = b"\x01\x64")]
pub struct RoutingHeader {
    /// Protocol version (`1` in `RoutingHeader::default`).
    pub version: u8,
    /// Block size value, set via [`RoutingHeader::set_size`].
    pub block_size: u16,
    /// Bitfield controlling which optional fields follow.
    pub flags: Flags,

    // Only read/written when `flags.has_checksum()` is set.
    #[brw(if(flags.has_checksum()))]
    checksum: Option<u32>,

    pub distance: i8,
    pub ttl: u8,

    pub sender: Endpoint,

    // TODO #115: add custom match receiver queries
    // The three receiver fields below are mutually exclusive. Which one is
    // read/written is decided by `flags.receiver_type()`.
    #[brw(if(flags.receiver_type() == ReceiverType::Pointer))]
    receivers_pointer_id: Option<RawFullPointerAddress>,
    #[brw(if(flags.receiver_type() == ReceiverType::Receivers))]
    #[cfg_attr(feature = "debug", serde(flatten))]
    receivers_endpoints: Option<ReceiverEndpoints>,
    #[brw(if(flags.receiver_type() == ReceiverType::ReceiversWithKeys))]
    #[cfg_attr(feature = "debug", serde(flatten))]
    receivers_endpoints_with_keys: Option<ReceiverEndpointsWithKeys>,
}
195
// Empty impl: `RoutingHeader` uses only the provided methods of
// `Serializable` (presumably built on its binrw derives; confirm in
// `serializable.rs`).
impl Serializable for RoutingHeader {}
197
198impl Default for RoutingHeader {
199    fn default() -> Self {
200        RoutingHeader {
201            version: 1,
202            distance: 0,
203            ttl: 42,
204            flags: Flags::new(),
205            checksum: None,
206            block_size: 0,
207            sender: Endpoint::default(),
208            receivers_pointer_id: None,
209            receivers_endpoints: None,
210            receivers_endpoints_with_keys: None,
211        }
212    }
213}
214
/// Describes whom a [`RoutingHeader`] addresses.
///
/// Stored into the header's wire fields by `RoutingHeader::set_receivers` and
/// read back by `RoutingHeader::receivers`.
#[derive(Debug, Clone, PartialEq)]
pub enum Receivers {
    /// No explicit receivers.
    None,
    // TODO #431 rename to PointerAddress
    /// Receivers given by a pointer address.
    PointerId(RawFullPointerAddress),
    /// An explicit list of endpoints.
    Endpoints(Vec<Endpoint>),
    /// An explicit list of endpoints, each paired with a 512-byte key.
    EndpointsWithKeys(Vec<(Endpoint, Key512)>),
}
223impl Display for Receivers {
224    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
225        match self {
226            Receivers::None => core::write!(f, "No receivers"),
227            Receivers::PointerId(pid) => {
228                core::write!(f, "Pointer ID: {:?}", pid)
229            }
230            Receivers::Endpoints(endpoints) => {
231                core::write!(f, "Endpoints: {:?}", endpoints)
232            }
233            Receivers::EndpointsWithKeys(endpoints_with_keys) => {
234                core::write!(
235                    f,
236                    "Endpoints with keys: {:?}",
237                    endpoints_with_keys
238                )
239            }
240        }
241    }
242}
243
244impl<T> From<T> for Receivers
245where
246    T: Into<RawFullPointerAddress>,
247{
248    fn from(pid: T) -> Self {
249        Receivers::PointerId(pid.into())
250    }
251}
252impl From<Vec<Endpoint>> for Receivers {
253    fn from(endpoints: Vec<Endpoint>) -> Self {
254        Receivers::from(endpoints.as_slice())
255    }
256}
257impl From<&Vec<Endpoint>> for Receivers {
258    fn from(endpoints: &Vec<Endpoint>) -> Self {
259        Receivers::from(endpoints.as_slice())
260    }
261}
262impl From<&[Endpoint]> for Receivers {
263    fn from(endpoints: &[Endpoint]) -> Self {
264        if endpoints.is_empty() {
265            Receivers::None
266        } else {
267            Receivers::Endpoints(endpoints.to_vec())
268        }
269    }
270}
271impl<T> From<Vec<(Endpoint, T)>> for Receivers
272where
273    T: Into<Key512>,
274{
275    fn from(endpoints_with_keys: Vec<(Endpoint, T)>) -> Self {
276        if endpoints_with_keys.is_empty() {
277            Receivers::None
278        } else {
279            Receivers::EndpointsWithKeys(
280                endpoints_with_keys
281                    .into_iter()
282                    .map(|(ep, key)| (ep, key.into()))
283                    .collect(),
284            )
285        }
286    }
287}
288
289impl RoutingHeader {
290    pub fn with_sender(&mut self, sender: Endpoint) -> &mut Self {
291        self.sender = sender;
292        self
293    }
294    pub fn with_receivers(&mut self, receivers: Receivers) -> &mut Self {
295        self.set_receivers(receivers);
296        self
297    }
298    pub fn with_ttl(&mut self, ttl: u8) -> &mut Self {
299        self.ttl = ttl;
300        self
301    }
302}
303
304impl RoutingHeader {
305    pub fn new(
306        ttl: u8,
307        flags: Flags,
308        sender: Endpoint,
309        receivers: Receivers,
310    ) -> Self {
311        let mut routing_header = RoutingHeader {
312            sender,
313            ttl,
314            flags,
315            ..RoutingHeader::default()
316        };
317        routing_header.set_receivers(receivers);
318        routing_header
319    }
320
321    pub fn set_size(&mut self, size: u16) {
322        self.block_size = size;
323    }
324
325    pub fn set_receivers(&mut self, receivers: Receivers) {
326        self.receivers_endpoints = None;
327        self.receivers_pointer_id = None;
328        self.receivers_endpoints_with_keys = None;
329        self.flags.set_receiver_type(ReceiverType::None);
330
331        match receivers {
332            Receivers::PointerId(pid) => {
333                self.receivers_pointer_id = Some(pid);
334                self.flags.set_receiver_type(ReceiverType::Pointer);
335            }
336            Receivers::Endpoints(endpoints) => {
337                if !endpoints.is_empty() {
338                    self.receivers_endpoints =
339                        Some(ReceiverEndpoints::new(endpoints));
340                    self.flags.set_receiver_type(ReceiverType::Receivers);
341                }
342            }
343            Receivers::EndpointsWithKeys(endpoints_with_keys) => {
344                if !endpoints_with_keys.is_empty() {
345                    self.receivers_endpoints_with_keys = Some(
346                        ReceiverEndpointsWithKeys::new(endpoints_with_keys),
347                    );
348                    self.flags
349                        .set_receiver_type(ReceiverType::ReceiversWithKeys);
350                }
351            }
352            Receivers::None => {}
353        }
354    }
355
356    /// Get the receivers from the routing header
357    pub fn receivers(&self) -> Receivers {
358        if let Some(pid) = &self.receivers_pointer_id {
359            Receivers::PointerId(pid.clone())
360        } else if let Some(endpoints) = &self.receivers_endpoints
361            && endpoints.count > 0
362        {
363            Receivers::Endpoints(endpoints.endpoints.clone())
364        } else if let Some(endpoints_with_keys) =
365            &self.receivers_endpoints_with_keys
366            && endpoints_with_keys.count > 0
367        {
368            Receivers::EndpointsWithKeys(
369                endpoints_with_keys.endpoints_with_keys.clone(),
370            )
371        } else {
372            Receivers::None
373        }
374    }
375}
376
#[cfg(test)]
mod tests {
    use core::str::FromStr;

    use super::*;

    /// Parses an endpoint literal used in these tests.
    fn endpoint(name: &str) -> Endpoint {
        Endpoint::from_str(name).unwrap()
    }

    #[test]
    fn single_receiver() {
        let mut routing_header = RoutingHeader::default();
        routing_header
            .with_sender(endpoint("@jonas"))
            .with_ttl(64)
            .with_receivers(Receivers::Endpoints(vec![endpoint("@alice")]));

        assert_eq!(routing_header.sender, endpoint("@jonas"));
        assert_eq!(routing_header.ttl, 64);
        assert_eq!(
            routing_header.receivers(),
            Receivers::Endpoints(vec![endpoint("@alice")])
        );
        assert_eq!(
            routing_header.flags.receiver_type(),
            ReceiverType::Receivers
        );
    }

    #[test]
    fn multiple_receivers() {
        let expected = vec![endpoint("@alice"), endpoint("@bob")];
        let mut routing_header = RoutingHeader::default();
        routing_header.with_receivers(Receivers::Endpoints(expected.clone()));

        assert_eq!(routing_header.receivers(), Receivers::Endpoints(expected));
        assert_eq!(
            routing_header.flags.receiver_type(),
            ReceiverType::Receivers
        );
    }

    #[test]
    fn no_receivers() {
        let mut explicit_none = RoutingHeader::default();
        explicit_none.with_receivers(Receivers::None);
        assert_eq!(explicit_none.receivers(), Receivers::None);

        // An empty endpoint list is normalised to "no receivers".
        let mut empty_list = RoutingHeader::default();
        empty_list.with_receivers(Receivers::Endpoints(vec![]));
        assert_eq!(empty_list.receivers(), Receivers::None);
        assert_eq!(empty_list.flags.receiver_type(), ReceiverType::None);
    }
}