Skip to main content

datex_core/global/protocol_structures/
routing_header.rs

1use super::serializable::Serializable;
2use crate::{
3    global::protocol_structures::instructions::RawFullPointerAddress,
4    values::core_values::endpoint::Endpoint,
5};
6
7use crate::prelude::*;
8use binrw::{BinRead, BinWrite};
9use core::{fmt::Display, prelude::rust_2024::*};
10use modular_bitfield::prelude::*;
11
12// 2 bit
13#[derive(
14    serde::Serialize, serde::Deserialize, Debug, PartialEq, Clone, Specifier,
15)]
16#[bits = 2]
17pub enum SignatureType {
18    None = 0b00,
19    Unencrypted = 0b10,
20    Encrypted = 0b11,
21}
22
23#[allow(clippy::derivable_impls)]
24impl Default for SignatureType {
25    /// Sets the default signature type based on whether the "crypto_enabled" feature is enabled.
26    /// When "crypto_enabled" is enabled, the default signature type is set to Unencrypted
27    /// Otherwise, it is set to None.
28    fn default() -> Self {
29        #[cfg(not(feature = "crypto_enabled"))]
30        {
31            SignatureType::None
32        }
33        #[cfg(feature = "crypto_enabled")]
34        {
35            SignatureType::Unencrypted
36        }
37    }
38}
39
40// 1 bit
41#[derive(
42    serde::Serialize,
43    serde::Deserialize,
44    Debug,
45    PartialEq,
46    Clone,
47    Default,
48    Specifier,
49)]
50pub enum EncryptionType {
51    #[default]
52    None = 0b0,
53    Encrypted = 0b1,
54}
55
56// 2 bit + 1 bit + 2 bit + 1 bit + 1 bit + 1 bit = 1 byte
57#[bitfield]
58#[derive(BinWrite, BinRead, Clone, Copy, Debug, PartialEq)]
59#[bw(map = |&x| Self::into_bytes(x))]
60#[br(map = Self::from_bytes)]
61pub struct Flags {
62    pub signature_type: SignatureType,   // 2 bit
63    pub encryption_type: EncryptionType, // 1 bit
64    pub receiver_type: ReceiverType,     // 2 bit
65    pub is_bounce_back: bool,            // 1 bit
66    pub has_checksum: bool,              // 1 bit
67
68    #[allow(unused)]
69    unused_2: bool,
70}
71
72impl Default for Flags {
73    fn default() -> Self {
74        Flags::new()
75            .with_signature_type(SignatureType::default())
76            .with_encryption_type(EncryptionType::default())
77            .with_receiver_type(ReceiverType::default())
78            .with_is_bounce_back(false)
79            .with_has_checksum(false)
80    }
81}
82
83mod flags_serde {
84    use super::*;
85    use crate::global::protocol_structures::routing_header::Flags;
86    use serde::{Deserialize, Deserializer, Serialize, Serializer};
87    #[derive(Serialize, Deserialize)]
88    struct FlagsHelper {
89        signature_type: SignatureType,
90        encryption_type: EncryptionType,
91        receiver_type: ReceiverType,
92        is_bounce_back: bool,
93        has_checksum: bool,
94    }
95
96    impl Serialize for Flags {
97        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
98        where
99            S: Serializer,
100        {
101            let helper = FlagsHelper {
102                signature_type: self.signature_type(),
103                encryption_type: self.encryption_type(),
104                receiver_type: self.receiver_type(),
105                is_bounce_back: self.is_bounce_back(),
106                has_checksum: self.has_checksum(),
107            };
108            helper.serialize(serializer)
109        }
110    }
111
112    impl<'de> Deserialize<'de> for Flags {
113        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
114        where
115            D: Deserializer<'de>,
116        {
117            let helper = FlagsHelper::deserialize(deserializer)?;
118            Ok(Flags::new()
119                .with_signature_type(helper.signature_type)
120                .with_encryption_type(helper.encryption_type)
121                .with_receiver_type(helper.receiver_type)
122                .with_is_bounce_back(helper.is_bounce_back)
123                .with_has_checksum(helper.has_checksum))
124        }
125    }
126}
127
128// 2 bit
129#[derive(
130    serde::Serialize,
131    serde::Deserialize,
132    Debug,
133    PartialEq,
134    Clone,
135    Default,
136    Specifier,
137)]
138#[bits = 2]
139pub enum ReceiverType {
140    #[default]
141    None = 0b00,
142    Pointer = 0b01,
143    Receivers = 0b10,
144    ReceiversWithKeys = 0b11,
145}
146
147// <count>: 1 byte + (21 byte * count)
148// min: 2 bytes
149#[derive(
150    serde::Serialize,
151    serde::Deserialize,
152    Debug,
153    Clone,
154    Default,
155    BinWrite,
156    BinRead,
157    PartialEq,
158)]
159pub struct ReceiverEndpoints {
160    #[serde(rename = "number_of_receivers")]
161    pub count: u8,
162    #[br(count = count)]
163    #[serde(rename = "receivers")]
164    pub endpoints: Vec<Endpoint>,
165}
166
167impl ReceiverEndpoints {
168    pub fn new(endpoints: Vec<Endpoint>) -> Self {
169        let count = endpoints.len() as u8;
170        ReceiverEndpoints { count, endpoints }
171    }
172}
173
174// <count>: 1 byte + (21 byte * count) + (512 byte * count)
175// min: 2 bytes
176#[derive(
177    serde::Serialize,
178    serde::Deserialize,
179    Debug,
180    Clone,
181    Default,
182    BinWrite,
183    BinRead,
184    PartialEq,
185)]
186pub struct ReceiverEndpointsWithKeys {
187    #[serde(rename = "number_of_receivers")]
188    count: u8,
189    #[br(count = count)]
190    #[serde(rename = "receivers_with_keys")]
191    pub endpoints_with_keys: Vec<(Endpoint, Key512)>,
192}
193
194#[derive(
195    serde::Serialize,
196    serde::Deserialize,
197    Debug,
198    Clone,
199    BinWrite,
200    BinRead,
201    PartialEq,
202)]
203pub struct Key512(#[serde(with = "serde_big_array::BigArray")] [u8; 512]);
204impl Default for Key512 {
205    fn default() -> Self {
206        Key512([0u8; 512])
207    }
208}
209impl From<[u8; 512]> for Key512 {
210    fn from(arr: [u8; 512]) -> Self {
211        Key512(arr)
212    }
213}
214
215impl ReceiverEndpointsWithKeys {
216    pub fn new<T>(endpoints_with_keys: Vec<(Endpoint, T)>) -> Self
217    where
218        T: Into<Key512>,
219    {
220        let count = endpoints_with_keys.len() as u8;
221        ReceiverEndpointsWithKeys {
222            count,
223            endpoints_with_keys: endpoints_with_keys
224                .into_iter()
225                .map(|(ep, key)| (ep, key.into()))
226                .collect(),
227        }
228    }
229}
230
231// min: 11 byte + 2 byte + 21 byte + 1 byte = 35 bytes
232#[derive(
233    serde::Serialize,
234    serde::Deserialize,
235    Debug,
236    Clone,
237    BinWrite,
238    BinRead,
239    PartialEq,
240)]
241#[brw(little, magic = b"\x01\x64")]
242pub struct RoutingHeader {
243    pub version: u8,
244    pub block_size: u16,
245    pub flags: Flags,
246
247    #[brw(if(flags.has_checksum()))]
248    checksum: Option<u32>,
249
250    pub distance: i8,
251    pub ttl: u8,
252
253    pub sender: Endpoint,
254
255    // TODO #115: add custom match receiver queries
256    #[brw(if(flags.receiver_type() == ReceiverType::Pointer))]
257    receivers_pointer_id: Option<RawFullPointerAddress>,
258    #[brw(if(flags.receiver_type() == ReceiverType::Receivers))]
259    #[serde(flatten)]
260    receivers_endpoints: Option<ReceiverEndpoints>,
261    #[brw(if(flags.receiver_type() == ReceiverType::ReceiversWithKeys))]
262    #[serde(flatten)]
263    receivers_endpoints_with_keys: Option<ReceiverEndpointsWithKeys>,
264}
265
266impl Serializable for RoutingHeader {}
267
268impl RoutingHeader {
269    pub fn unsigned() -> Self {
270        RoutingHeader {
271            flags: Flags::new().with_signature_type(SignatureType::None),
272            ..RoutingHeader::default()
273        }
274    }
275}
276
277impl Default for RoutingHeader {
278    fn default() -> Self {
279        RoutingHeader {
280            version: 1,
281            distance: 0,
282            ttl: 42,
283            flags: Flags::default(),
284            checksum: None,
285            block_size: 0,
286            sender: Endpoint::default(),
287            receivers_pointer_id: None,
288            receivers_endpoints: None,
289            receivers_endpoints_with_keys: None,
290        }
291    }
292}
293
294#[derive(Debug, Clone, PartialEq)]
295pub enum Receivers {
296    None,
297    // TODO #431 rename to PointerAddress
298    PointerId(RawFullPointerAddress),
299    Endpoints(Vec<Endpoint>),
300    EndpointsWithKeys(Vec<(Endpoint, Key512)>),
301}
302impl Display for Receivers {
303    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
304        match self {
305            Receivers::None => core::write!(f, "No receivers"),
306            Receivers::PointerId(pid) => {
307                core::write!(f, "Pointer ID: {:?}", pid)
308            }
309            Receivers::Endpoints(endpoints) => {
310                core::write!(f, "Endpoints: {:?}", endpoints)
311            }
312            Receivers::EndpointsWithKeys(endpoints_with_keys) => {
313                core::write!(
314                    f,
315                    "Endpoints with keys: {:?}",
316                    endpoints_with_keys
317                )
318            }
319        }
320    }
321}
322
323impl<T> From<T> for Receivers
324where
325    T: Into<RawFullPointerAddress>,
326{
327    fn from(pid: T) -> Self {
328        Receivers::PointerId(pid.into())
329    }
330}
331impl From<Vec<Endpoint>> for Receivers {
332    fn from(endpoints: Vec<Endpoint>) -> Self {
333        Receivers::from(endpoints.as_slice())
334    }
335}
336impl From<&Vec<Endpoint>> for Receivers {
337    fn from(endpoints: &Vec<Endpoint>) -> Self {
338        Receivers::from(endpoints.as_slice())
339    }
340}
341impl From<&[Endpoint]> for Receivers {
342    fn from(endpoints: &[Endpoint]) -> Self {
343        if endpoints.is_empty() {
344            Receivers::None
345        } else {
346            Receivers::Endpoints(endpoints.to_vec())
347        }
348    }
349}
350impl<T> From<Vec<(Endpoint, T)>> for Receivers
351where
352    T: Into<Key512>,
353{
354    fn from(endpoints_with_keys: Vec<(Endpoint, T)>) -> Self {
355        if endpoints_with_keys.is_empty() {
356            Receivers::None
357        } else {
358            Receivers::EndpointsWithKeys(
359                endpoints_with_keys
360                    .into_iter()
361                    .map(|(ep, key)| (ep, key.into()))
362                    .collect(),
363            )
364        }
365    }
366}
367
368impl RoutingHeader {
369    pub fn with_sender(&mut self, sender: Endpoint) -> &mut Self {
370        self.sender = sender;
371        self
372    }
373    pub fn with_receivers(&mut self, receivers: Receivers) -> &mut Self {
374        self.set_receivers(receivers);
375        self
376    }
377    pub fn with_ttl(&mut self, ttl: u8) -> &mut Self {
378        self.ttl = ttl;
379        self
380    }
381}
382
383impl RoutingHeader {
384    pub fn new(
385        ttl: u8,
386        flags: Flags,
387        sender: Endpoint,
388        receivers: Receivers,
389    ) -> Self {
390        let mut routing_header = RoutingHeader {
391            sender,
392            ttl,
393            flags,
394            ..RoutingHeader::default()
395        };
396        routing_header.set_receivers(receivers);
397        routing_header
398    }
399
400    pub fn set_size(&mut self, size: u16) {
401        self.block_size = size;
402    }
403
404    pub fn set_receivers(&mut self, receivers: Receivers) {
405        self.receivers_endpoints = None;
406        self.receivers_pointer_id = None;
407        self.receivers_endpoints_with_keys = None;
408        self.flags.set_receiver_type(ReceiverType::None);
409
410        match receivers {
411            Receivers::PointerId(pid) => {
412                self.receivers_pointer_id = Some(pid);
413                self.flags.set_receiver_type(ReceiverType::Pointer);
414            }
415            Receivers::Endpoints(endpoints) => {
416                if !endpoints.is_empty() {
417                    self.receivers_endpoints =
418                        Some(ReceiverEndpoints::new(endpoints));
419                    self.flags.set_receiver_type(ReceiverType::Receivers);
420                }
421            }
422            Receivers::EndpointsWithKeys(endpoints_with_keys) => {
423                if !endpoints_with_keys.is_empty() {
424                    self.receivers_endpoints_with_keys = Some(
425                        ReceiverEndpointsWithKeys::new(endpoints_with_keys),
426                    );
427                    self.flags
428                        .set_receiver_type(ReceiverType::ReceiversWithKeys);
429                }
430            }
431            Receivers::None => {}
432        }
433    }
434
435    /// Get the receivers from the routing header
436    pub fn receivers(&self) -> Receivers {
437        if let Some(pid) = &self.receivers_pointer_id {
438            Receivers::PointerId(pid.clone())
439        } else if let Some(endpoints) = &self.receivers_endpoints
440            && endpoints.count > 0
441        {
442            Receivers::Endpoints(endpoints.endpoints.clone())
443        } else if let Some(endpoints_with_keys) =
444            &self.receivers_endpoints_with_keys
445            && endpoints_with_keys.count > 0
446        {
447            Receivers::EndpointsWithKeys(
448                endpoints_with_keys.endpoints_with_keys.clone(),
449            )
450        } else {
451            Receivers::None
452        }
453    }
454}
455
#[cfg(test)]
mod tests {
    use core::str::FromStr;

    use super::*;

    /// Parses an endpoint literal; panics if the literal is rejected.
    fn endpoint(name: &str) -> Endpoint {
        Endpoint::from_str(name).unwrap()
    }

    #[test]
    fn single_receiver() {
        let mut routing_header = RoutingHeader::default();
        routing_header
            .with_sender(endpoint("@jonas"))
            .with_ttl(64)
            .with_receivers(Receivers::Endpoints(vec![endpoint("@alice")]));

        assert_eq!(routing_header.sender, endpoint("@jonas"));
        assert_eq!(routing_header.ttl, 64);
        assert_eq!(
            routing_header.receivers(),
            Receivers::Endpoints(vec![endpoint("@alice")])
        );
        assert_eq!(
            routing_header.flags.receiver_type(),
            ReceiverType::Receivers
        );
    }

    #[test]
    fn multiple_receivers() {
        let expected = vec![endpoint("@alice"), endpoint("@bob")];

        let mut routing_header = RoutingHeader::default();
        routing_header.with_receivers(Receivers::Endpoints(expected.clone()));

        assert_eq!(routing_header.receivers(), Receivers::Endpoints(expected));
        assert_eq!(
            routing_header.flags.receiver_type(),
            ReceiverType::Receivers
        );
    }

    #[test]
    fn no_receivers() {
        // Explicit `None`.
        let mut routing_header = RoutingHeader::default();
        routing_header.with_receivers(Receivers::None);
        assert_eq!(routing_header.receivers(), Receivers::None);

        // An empty endpoint list behaves like `None`.
        let mut routing_header = RoutingHeader::default();
        routing_header.with_receivers(Receivers::Endpoints(vec![]));
        assert_eq!(routing_header.receivers(), Receivers::None);
        assert_eq!(routing_header.flags.receiver_type(), ReceiverType::None);
    }
}