multicast_discovery_socket/
protocol.rs

1use std::borrow::Cow;
2use log::debug;
3use sha2::Digest;
4use crate::AdvertisementData;
5
/// Message kind for `DiscoveryMessage` type.
///
/// A payload-free tag with one variant per `DiscoveryMessage` variant. Being a plain
/// field-less enum it is cheap to copy, compare, hash and print.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DiscoveryMessageKind {
    /// Kind of the `DiscoveryMessage::Discovery` packet.
    Discovery,
    /// Kind of the `DiscoveryMessage::Announce` packet.
    Announce,
    /// Kind of the `DiscoveryMessage::ExtendAnnouncements` packet.
    ExtendAnnouncements,
}
13
14impl DiscoveryMessageKind {
15    fn header(&self) -> &'static [u8] {
16        match self {
17            DiscoveryMessageKind::Discovery => b"discovery",
18            DiscoveryMessageKind::Announce => b"announce",
19            DiscoveryMessageKind::ExtendAnnouncements => b"extend-announcements",
20        }
21    }
22}
23
24#[derive(Clone)]
25pub enum DiscoveryMessage<'a, D: AdvertisementData> {
26    /// Ping packet used to trigger other endpoints to send Announce packet back
27    Discovery,
28    /// Tell other endpoints that we are running and available for making connection
29    Announce {
30        service_port: u16,
31        discover_id: u32,
32        disconnected: bool,
33        adv_data: Cow<'a, D>,
34    },
35    /// Request for endpoints on Primary and Backup ports to extend their announcements scope to Backup ports as well
36    ExtendAnnouncements,
37}
38
39impl<D: AdvertisementData> DiscoveryMessage<'_, D> {
40    fn msg_type(&self) -> DiscoveryMessageKind {
41        match self {
42            DiscoveryMessage::Discovery => DiscoveryMessageKind::Discovery,
43            DiscoveryMessage::Announce { .. } => DiscoveryMessageKind::Announce,
44            DiscoveryMessage::ExtendAnnouncements => DiscoveryMessageKind::ExtendAnnouncements,
45        }
46    }
47    pub(crate) fn try_parse(msg: &[u8]) -> Option<Self> {
48        if msg.len() < 32 + 1 {
49            debug!("Packet is too small, ignoring...");
50            return None;
51        }
52        let msg_len = msg.len() - 32;
53        let sha = sha2::Sha256::digest(&msg[..msg_len]);
54        if !msg.ends_with(&sha[..32]) {
55            debug!("Incorrect sha2, ignoring message...");
56            return None;
57        }
58        if msg.starts_with(DiscoveryMessageKind::Discovery.header())
59            && msg_len == DiscoveryMessageKind::Discovery.header().len() {
60            Some(DiscoveryMessage::Discovery)
61        } else if msg.starts_with(DiscoveryMessageKind::Announce.header())
62            && msg_len >= DiscoveryMessageKind::Announce.header().len() + 2 + 4 + 1 {
63            let msg_body = &msg[DiscoveryMessageKind::Announce.header().len()..msg_len];
64            let service_port = u16::from_be_bytes(msg_body[0..2].try_into().unwrap());
65            let discover_id = u32::from_be_bytes(msg_body[2..6].try_into().unwrap());
66            let disconnected = msg_body[6] != 0;
67            let adv_data_body = &msg_body[7..];
68            let adv_data = D::try_decode(adv_data_body)?;
69
70            Some(DiscoveryMessage::Announce {
71                adv_data: Cow::Owned(adv_data),
72                service_port,
73                disconnected,
74                discover_id
75            })
76        } else if msg.starts_with(DiscoveryMessageKind::ExtendAnnouncements.header())
77            && msg_len == DiscoveryMessageKind::ExtendAnnouncements.header().len() {
78            Some(DiscoveryMessage::ExtendAnnouncements)
79        } else {
80            None
81        }
82    }
83    pub(crate) fn gen_message(&self) -> Vec<u8> {
84        let header = self.msg_type().header();
85        let mut message = match self {
86            DiscoveryMessage::Discovery => header.to_vec(),
87            DiscoveryMessage::Announce { service_port, discover_id, disconnected, adv_data } => {
88                let mut hello_msg = header.to_vec();
89                hello_msg.extend_from_slice(service_port.to_be_bytes().as_ref());
90                hello_msg.extend_from_slice(discover_id.to_be_bytes().as_ref());
91                hello_msg.push(*disconnected as u8);
92                hello_msg.extend_from_slice(&adv_data.encode_to_bytes());
93                hello_msg
94            }
95            DiscoveryMessage::ExtendAnnouncements => header.to_vec(),
96        };
97
98        let sha = sha2::Sha256::digest(&message);
99        message.extend_from_slice(&sha[..32]);
100        message
101    }
102}