datex_core/global/protocol_structures/
routing_header.rs

1use std::fmt::Display;
2
3use super::serializable::Serializable;
4use crate::values::core_values::endpoint::Endpoint;
5use binrw::{BinRead, BinWrite};
6use modular_bitfield::prelude::*;
7
8// 2 bit
9#[cfg_attr(feature = "debug", derive(serde::Serialize, serde::Deserialize))]
10#[derive(Debug, PartialEq, Clone, Default, Specifier)]
11#[bits = 2]
12pub enum SignatureType {
13    #[default]
14    None = 0b00,
15    Unencrypted = 0b10,
16    Encrypted = 0b11,
17}
18
19// 1 bit
20#[cfg_attr(feature = "debug", derive(serde::Serialize, serde::Deserialize))]
21#[derive(Debug, PartialEq, Clone, Default, Specifier)]
22pub enum EncryptionType {
23    #[default]
24    None = 0b0,
25    Encrypted = 0b1,
26}
27
28// 2 bit + 1 bit + 2 bit + 1 bit + 1 bit + 1 bit = 1 byte
29#[bitfield]
30#[derive(BinWrite, BinRead, Clone, Default, Copy, Debug, PartialEq)]
31#[bw(map = |&x| Self::into_bytes(x))]
32#[br(map = Self::from_bytes)]
33pub struct Flags {
34    pub signature_type: SignatureType,   // 2 bit
35    pub encryption_type: EncryptionType, // 1 bit
36    pub receiver_type: ReceiverType,     // 2 bit
37    pub is_bounce_back: bool,            // 1 bit
38    pub has_checksum: bool,              // 1 bit
39
40    #[allow(unused)]
41    unused_2: bool,
42}
43
#[cfg(feature = "debug")]
mod flags_serde {
    //! Serde support for [`Flags`] in debug builds.
    //!
    //! Values are routed through a plain struct that names each flag
    //! individually; the reserved padding bit is not part of that form.

    use super::*;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    /// Named-field mirror of the [`Flags`] accessors (field names and order
    /// define the serialized form).
    #[derive(Serialize, Deserialize)]
    struct FlagsHelper {
        signature_type: SignatureType,
        encryption_type: EncryptionType,
        receiver_type: ReceiverType,
        is_bounce_back: bool,
        has_checksum: bool,
    }

    impl Serialize for Flags {
        fn serialize<S: Serializer>(
            &self,
            serializer: S,
        ) -> Result<S::Ok, S::Error> {
            let mirror = FlagsHelper {
                signature_type: self.signature_type(),
                encryption_type: self.encryption_type(),
                receiver_type: self.receiver_type(),
                is_bounce_back: self.is_bounce_back(),
                has_checksum: self.has_checksum(),
            };
            mirror.serialize(serializer)
        }
    }

    impl<'de> Deserialize<'de> for Flags {
        fn deserialize<D: Deserializer<'de>>(
            deserializer: D,
        ) -> Result<Self, D::Error> {
            let FlagsHelper {
                signature_type,
                encryption_type,
                receiver_type,
                is_bounce_back,
                has_checksum,
            } = FlagsHelper::deserialize(deserializer)?;

            // Start from all-zero bits (reserved bit stays clear).
            let mut flags = Flags::new();
            flags.set_signature_type(signature_type);
            flags.set_encryption_type(encryption_type);
            flags.set_receiver_type(receiver_type);
            flags.set_is_bounce_back(is_bounce_back);
            flags.set_has_checksum(has_checksum);
            Ok(flags)
        }
    }
}

90// 2 bit
91#[cfg_attr(feature = "debug", derive(serde::Serialize, serde::Deserialize))]
92#[derive(Debug, PartialEq, Clone, Default, Specifier)]
93#[bits = 2]
94pub enum ReceiverType {
95    #[default]
96    None = 0b00,
97    Pointer = 0b01,
98    Receivers = 0b10,
99    ReceiversWithKeys = 0b11,
100}
101
102// TODO #430 directly use bytes / pointer address instead of whole struct here
103// 1 byte + 18 byte + 2 byte + 4 byte + 1 byte = 26 bytes
104#[cfg_attr(feature = "debug", derive(serde::Serialize, serde::Deserialize))]
105#[derive(Debug, Clone, Default, BinWrite, BinRead, PartialEq)]
106pub struct PointerAddress {
107    pub pointer_type: u8,
108    pub identifier: [u8; 18],
109    pub instance: u16,
110    pub timestamp: u32,
111    pub counter: u8,
112}
113
114// <count>: 1 byte + (21 byte * count)
115// min: 2 bytes
116#[cfg_attr(feature = "debug", derive(serde::Serialize, serde::Deserialize))]
117#[derive(Debug, Clone, Default, BinWrite, BinRead, PartialEq)]
118pub struct ReceiverEndpoints {
119    #[cfg_attr(feature = "debug", serde(rename = "number_of_receivers"))]
120    pub count: u8,
121    #[br(count = count)]
122    #[cfg_attr(feature = "debug", serde(rename = "receivers"))]
123    pub endpoints: Vec<Endpoint>,
124}
125
126impl ReceiverEndpoints {
127    pub fn new(endpoints: Vec<Endpoint>) -> Self {
128        let count = endpoints.len() as u8;
129        ReceiverEndpoints { count, endpoints }
130    }
131}
132
133// <count>: 1 byte + (21 byte * count) + (512 byte * count)
134// min: 2 bytes
135#[cfg_attr(feature = "debug", derive(serde::Serialize, serde::Deserialize))]
136#[derive(Debug, Clone, Default, BinWrite, BinRead, PartialEq)]
137pub struct ReceiverEndpointsWithKeys {
138    #[cfg_attr(feature = "debug", serde(rename = "number_of_receivers"))]
139    count: u8,
140    #[br(count = count)]
141    #[cfg_attr(feature = "debug", serde(rename = "receivers_with_keys"))]
142    pub endpoints_with_keys: Vec<(Endpoint, Key512)>,
143}
144
145#[cfg_attr(feature = "debug", derive(serde::Serialize, serde::Deserialize))]
146#[derive(Debug, Clone, BinWrite, BinRead, PartialEq)]
147pub struct Key512(
148    #[cfg_attr(feature = "debug", serde(with = "serde_big_array::BigArray"))]
149    [u8; 512],
150);
151impl Default for Key512 {
152    fn default() -> Self {
153        Key512([0u8; 512])
154    }
155}
156impl From<[u8; 512]> for Key512 {
157    fn from(arr: [u8; 512]) -> Self {
158        Key512(arr)
159    }
160}
161
162impl ReceiverEndpointsWithKeys {
163    pub fn new<T>(endpoints_with_keys: Vec<(Endpoint, T)>) -> Self
164    where
165        T: Into<Key512>,
166    {
167        let count = endpoints_with_keys.len() as u8;
168        ReceiverEndpointsWithKeys {
169            count,
170            endpoints_with_keys: endpoints_with_keys
171                .into_iter()
172                .map(|(ep, key)| (ep, key.into()))
173                .collect(),
174        }
175    }
176}
177
178// min: 11 byte + 2 byte + 21 byte + 1 byte = 35 bytes
179#[cfg_attr(feature = "debug", derive(serde::Serialize, serde::Deserialize))]
180#[derive(Debug, Clone, BinWrite, BinRead, PartialEq)]
181#[brw(little, magic = b"\x01\x64")]
182pub struct RoutingHeader {
183    pub version: u8,
184    pub block_size: u16,
185    pub flags: Flags,
186
187    #[brw(if(flags.has_checksum()))]
188    checksum: Option<u32>,
189
190    pub distance: i8,
191    pub ttl: u8,
192
193    pub sender: Endpoint,
194
195    // TODO #115: add custom match receiver queries
196    #[brw(if(flags.receiver_type() == ReceiverType::Pointer))]
197    receivers_pointer_id: Option<PointerAddress>,
198    #[brw(if(flags.receiver_type() == ReceiverType::Receivers))]
199    #[cfg_attr(feature = "debug", serde(flatten))]
200    receivers_endpoints: Option<ReceiverEndpoints>,
201    #[brw(if(flags.receiver_type() == ReceiverType::ReceiversWithKeys))]
202    #[cfg_attr(feature = "debug", serde(flatten))]
203    receivers_endpoints_with_keys: Option<ReceiverEndpointsWithKeys>,
204}
205
206impl Serializable for RoutingHeader {}
207
208impl Default for RoutingHeader {
209    fn default() -> Self {
210        RoutingHeader {
211            version: 1,
212            distance: 0,
213            ttl: 42,
214            flags: Flags::new(),
215            checksum: None,
216            block_size: 0,
217            sender: Endpoint::default(),
218            receivers_pointer_id: None,
219            receivers_endpoints: None,
220            receivers_endpoints_with_keys: None,
221        }
222    }
223}
224
225#[derive(Debug, Clone, PartialEq)]
226pub enum Receivers {
227    None,
228    // TODO #431 rename to PointerAddress
229    PointerId(PointerAddress),
230    Endpoints(Vec<Endpoint>),
231    EndpointsWithKeys(Vec<(Endpoint, Key512)>),
232}
233impl Display for Receivers {
234    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
235        match self {
236            Receivers::None => write!(f, "No receivers"),
237            Receivers::PointerId(pid) => write!(f, "Pointer ID: {:?}", pid),
238            Receivers::Endpoints(endpoints) => {
239                write!(f, "Endpoints: {:?}", endpoints)
240            }
241            Receivers::EndpointsWithKeys(endpoints_with_keys) => {
242                write!(f, "Endpoints with keys: {:?}", endpoints_with_keys)
243            }
244        }
245    }
246}
247
248impl From<PointerAddress> for Receivers {
249    fn from(pid: PointerAddress) -> Self {
250        Receivers::PointerId(pid)
251    }
252}
253impl From<Vec<Endpoint>> for Receivers {
254    fn from(endpoints: Vec<Endpoint>) -> Self {
255        Receivers::from(endpoints.as_slice())
256    }
257}
258impl From<&Vec<Endpoint>> for Receivers {
259    fn from(endpoints: &Vec<Endpoint>) -> Self {
260        Receivers::from(endpoints.as_slice())
261    }
262}
263impl From<&[Endpoint]> for Receivers {
264    fn from(endpoints: &[Endpoint]) -> Self {
265        if endpoints.is_empty() {
266            Receivers::None
267        } else {
268            Receivers::Endpoints(endpoints.to_vec())
269        }
270    }
271}
272impl<T> From<Vec<(Endpoint, T)>> for Receivers
273where
274    T: Into<Key512>,
275{
276    fn from(endpoints_with_keys: Vec<(Endpoint, T)>) -> Self {
277        if endpoints_with_keys.is_empty() {
278            Receivers::None
279        } else {
280            Receivers::EndpointsWithKeys(
281                endpoints_with_keys
282                    .into_iter()
283                    .map(|(ep, key)| (ep, key.into()))
284                    .collect(),
285            )
286        }
287    }
288}
289
290impl RoutingHeader {
291    pub fn with_sender(&mut self, sender: Endpoint) -> &mut Self {
292        self.sender = sender;
293        self
294    }
295    pub fn with_receivers(&mut self, receivers: Receivers) -> &mut Self {
296        self.set_receivers(receivers);
297        self
298    }
299    pub fn with_ttl(&mut self, ttl: u8) -> &mut Self {
300        self.ttl = ttl;
301        self
302    }
303}
304
305impl RoutingHeader {
306    pub fn new(
307        ttl: u8,
308        flags: Flags,
309        sender: Endpoint,
310        receivers: Receivers,
311    ) -> Self {
312        let mut routing_header = RoutingHeader {
313            sender,
314            ttl,
315            flags,
316            ..RoutingHeader::default()
317        };
318        routing_header.set_receivers(receivers);
319        routing_header
320    }
321
322    pub fn set_size(&mut self, size: u16) {
323        self.block_size = size;
324    }
325
326    pub fn set_receivers(&mut self, receivers: Receivers) {
327        self.receivers_endpoints = None;
328        self.receivers_pointer_id = None;
329        self.receivers_endpoints_with_keys = None;
330        self.flags.set_receiver_type(ReceiverType::None);
331
332        match receivers {
333            Receivers::PointerId(pid) => self.receivers_pointer_id = Some(pid),
334            Receivers::Endpoints(endpoints) => {
335                if !endpoints.is_empty() {
336                    self.receivers_endpoints =
337                        Some(ReceiverEndpoints::new(endpoints));
338                    self.flags.set_receiver_type(ReceiverType::Receivers);
339                }
340            }
341            Receivers::EndpointsWithKeys(endpoints_with_keys) => {
342                if !endpoints_with_keys.is_empty() {
343                    self.receivers_endpoints_with_keys = Some(
344                        ReceiverEndpointsWithKeys::new(endpoints_with_keys),
345                    );
346                    self.flags
347                        .set_receiver_type(ReceiverType::ReceiversWithKeys);
348                }
349            }
350            Receivers::None => {}
351        }
352    }
353
354    /// Get the receivers from the routing header
355    pub fn receivers(&self) -> Receivers {
356        if let Some(pid) = &self.receivers_pointer_id {
357            Receivers::PointerId(pid.clone())
358        } else if let Some(endpoints) = &self.receivers_endpoints
359            && endpoints.count > 0
360        {
361            Receivers::Endpoints(endpoints.endpoints.clone())
362        } else if let Some(endpoints_with_keys) =
363            &self.receivers_endpoints_with_keys
364            && endpoints_with_keys.count > 0
365        {
366            Receivers::EndpointsWithKeys(
367                endpoints_with_keys.endpoints_with_keys.clone(),
368            )
369        } else {
370            Receivers::None
371        }
372    }
373}
374
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::*;

    /// Parses an endpoint literal for tests; panics on invalid input.
    fn ep(name: &str) -> Endpoint {
        Endpoint::from_str(name).unwrap()
    }

    #[test]
    fn single_receiver() {
        let mut header = RoutingHeader::default();
        header
            .with_sender(ep("@jonas"))
            .with_ttl(64)
            .with_receivers(Receivers::Endpoints(vec![ep("@alice")]));

        assert_eq!(header.sender, ep("@jonas"));
        assert_eq!(header.ttl, 64);
        assert_eq!(
            header.receivers(),
            Receivers::Endpoints(vec![ep("@alice")])
        );
        assert_eq!(header.flags.receiver_type(), ReceiverType::Receivers);
    }

    #[test]
    fn multiple_receivers() {
        let expected = vec![ep("@alice"), ep("@bob")];

        let mut header = RoutingHeader::default();
        header.with_receivers(Receivers::Endpoints(expected.clone()));

        assert_eq!(header.receivers(), Receivers::Endpoints(expected));
        assert_eq!(header.flags.receiver_type(), ReceiverType::Receivers);
    }

    #[test]
    fn no_receivers() {
        let mut explicit_none = RoutingHeader::default();
        explicit_none.with_receivers(Receivers::None);
        assert_eq!(explicit_none.receivers(), Receivers::None);

        let mut empty_list = RoutingHeader::default();
        empty_list.with_receivers(Receivers::Endpoints(vec![]));
        assert_eq!(empty_list.receivers(), Receivers::None);
        assert_eq!(empty_list.flags.receiver_type(), ReceiverType::None);
    }
}