//! Protocol message framing (`ProtocolMessage`) and the traits used to
//! serialize and deserialize message payloads.
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::io::Write;

/// Two-byte start-of-frame marker that begins every serialized message: ASCII `B`, `R`.
pub const HEADER: [u8; 2] = [b'B', b'R'];

/// One protocol frame: the header fields, the raw payload bytes and the checksum.
///
/// The two start bytes (`HEADER`) are not stored; they are emitted when the
/// message is serialized.
#[derive(Clone, Debug, Default, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ProtocolMessage {
    /// Number of bytes in `payload`.
    pub payload_length: u16,
    /// The message id.
    pub message_id: u16,
    /// The device ID of the device sending the message.
    pub src_device_id: u8,
    /// The device ID of the intended recipient of the message.
    pub dst_device_id: u8,
    /// The message payload.
    pub payload: Vec<u8>,
    /// The message checksum: the sum of all the non-checksum bytes in the message.
    pub checksum: u16,
}

/// Construction, checksum handling and serialization of a protocol frame.
///
/// # Message format
///
/// Each message consists of a header, optional payload, and checksum. Multi-byte
/// fields are written little-endian. The binary format is specified as follows:
///
/// | Byte        | Type | Name           | Description                                             |
/// |-------------|------|----------------|---------------------------------------------------------|
/// | 0           | u8   | start1         | Start frame identifier, ASCII 'B'                       |
/// | 1           | u8   | start2         | Start frame identifier, ASCII 'R'                       |
/// | 2-3         | u16  | payload_length | Number of bytes in payload.                             |
/// | 4-5         | u16  | message_id     | The message id.                                         |
/// | 6           | u8   | src_device_id  | The device ID of the device sending the message.        |
/// | 7           | u8   | dst_device_id  | The device ID of the intended recipient of the message. |
/// | 8-n         | u8[] | payload        | The message payload.                                    |
/// | (n+1)-(n+2) | u16  | checksum       | The message checksum. The checksum is calculated as the sum of all the non-checksum bytes in the message. |
impl ProtocolMessage {
    /// Creates an empty message: every field is zero and the payload is empty.
    pub fn new() -> Self {
        Self::default()
    }

    /// Copies the id and serialized payload of `message` into this frame, then
    /// refreshes `payload_length` and the checksum.
    ///
    /// `payload_length` is a `u16`; if the serialized payload is longer than
    /// `u16::MAX` bytes the stored length is truncated by the cast.
    pub fn set_message(&mut self, message: &impl PingMessage) {
        self.message_id = message.message_id();
        self.payload = message.serialize();
        self.payload_length = self.payload.len() as u16;
        self.update_checksum();
    }

    /// Sets the sender device id and refreshes the checksum.
    #[inline]
    pub fn set_src_device_id(&mut self, src_device_id: u8) {
        self.src_device_id = src_device_id;
        self.update_checksum();
    }

    /// Returns the recipient device id.
    #[inline]
    pub fn dst_device_id(&self) -> u8 {
        self.dst_device_id
    }

    /// Sets the recipient device id and refreshes the checksum.
    #[inline]
    pub fn set_dst_device_id(&mut self, dst_device_id: u8) {
        self.dst_device_id = dst_device_id;
        self.update_checksum();
    }

    /// Returns the raw payload bytes.
    #[inline]
    pub fn payload(&self) -> &[u8] {
        &self.payload
    }

    /// Returns the stored checksum (not recomputed; see [`Self::calculate_crc`]).
    #[inline]
    pub fn checksum(&self) -> u16 {
        self.checksum
    }

    /// Recomputes the checksum from the current fields and stores it.
    #[inline]
    pub fn update_checksum(&mut self) {
        self.checksum = self.calculate_crc();
    }

    /// Computes the checksum of the current fields without storing it.
    ///
    /// The result is the sum of every non-checksum byte of the frame (start
    /// bytes, little-endian length and id, both device ids, payload), reduced
    /// modulo 2^16 so that it fits the 16-bit checksum field.
    pub fn calculate_crc(&self) -> u16 {
        let length_bytes = self.payload_length.to_le_bytes();
        let id_bytes = self.message_id.to_le_bytes();
        let device_ids = [self.src_device_id, self.dst_device_id];

        HEADER
            .iter()
            .chain(&length_bytes)
            .chain(&id_bytes)
            .chain(&device_ids)
            .chain(&self.payload)
            // Wrapping add: a plain `+=` panics in debug builds as soon as the
            // byte sum exceeds u16::MAX, which large payloads easily reach.
            .fold(0u16, |sum, &byte| sum.wrapping_add(u16::from(byte)))
    }

    /// Returns `true` when the stored checksum matches the one computed from
    /// the current fields.
    pub fn has_valid_crc(&self) -> bool {
        self.calculate_crc() == self.checksum
    }

    /// Total size in bytes of the serialized frame as described by
    /// `payload_length`: start bytes + length + id + two device ids + payload
    /// + checksum.
    pub fn length(&self) -> usize {
        const LENGTH_FIELD: usize = 2;
        const ID_FIELD: usize = 2;
        const DEVICE_ID_FIELDS: usize = 1 + 1;
        const CHECKSUM_FIELD: usize = 2;

        HEADER.len()
            + LENGTH_FIELD
            + ID_FIELD
            + DEVICE_ID_FIELDS
            + self.payload_length as usize
            + CHECKSUM_FIELD
    }

    /// Serializes the frame and writes all of it to `writer`.
    ///
    /// Returns the number of bytes written, or the I/O error reported by
    /// `writer`.
    pub fn write(&self, writer: &mut dyn Write) -> std::io::Result<usize> {
        let frame = self.serialized();
        writer.write_all(&frame)?;
        Ok(frame.len())
    }

    /// Returns the frame as bytes, in wire order, using the stored checksum.
    ///
    /// The stored checksum is emitted as-is; call [`Self::update_checksum`]
    /// first if fields were modified directly.
    pub fn serialized(&self) -> Vec<u8> {
        let mut frame = Vec::with_capacity(self.length());
        frame.extend_from_slice(&HEADER);
        frame.extend_from_slice(&self.payload_length.to_le_bytes());
        frame.extend_from_slice(&self.message_id.to_le_bytes());
        frame.extend_from_slice(&[self.src_device_id, self.dst_device_id]);
        frame.extend_from_slice(&self.payload);
        frame.extend_from_slice(&self.checksum.to_le_bytes());
        frame
    }
}

// This information is only related to the message itself,
// not the entire package with header, dst, src and etc.
pub trait PingMessage
where
    Self: Sized + SerializePayload + SerializePayload,
{
    fn message_id(&self) -> u16;
    fn message_name(&self) -> &'static str;

    fn message_id_from_name(name: &str) -> Result<u16, String>;
}

/// Conversion of a message into its payload bytes.
pub trait SerializePayload {
    /// Returns the payload bytes for this message (stored as the frame's
    /// `payload` by `ProtocolMessage::set_message`).
    fn serialize(&self) -> Vec<u8>;
}

/// Construction of a message from payload bytes.
pub trait DeserializePayload {
    /// Builds `Self` from `payload`.
    ///
    /// The signature has no error channel, so a malformed payload cannot be
    /// reported through the return value.
    // NOTE(review): how implementers handle short or invalid payloads is not
    // visible here — confirm against the implementations.
    fn deserialize(payload: &[u8]) -> Self;
}

/// Static, type-level information about a message.
pub trait MessageInfo {
    /// Returns the id associated with the implementing type.
    // NOTE(review): presumably the same value as the frame's `message_id` —
    // confirm against implementers.
    fn id() -> u16;
}

/// Construction of a value from a message id plus raw payload bytes.
pub trait DeserializeGenericMessage
where
    Self: Sized,
{
    /// Builds `Self` from `message_id` and `payload`, or returns a static
    /// error description on failure.
    // NOTE(review): presumably `message_id` selects which concrete message the
    // payload is decoded as — confirm against implementers.
    fn deserialize(message_id: u16, payload: &[u8]) -> Result<Self, &'static str>;
}