acn_protocol/
pdu.rs

1use crate::{error::AcnError, flags::Flags, length::Length, vector::Vector};
2use core::error::Error;
3
4pub trait PduCodec {
5    type Error: Error + From<AcnError>;
6
7    fn flags(&self) -> Flags {
8        let mut flags = Flags::default();
9
10        if self.vector_length() + self.header_length() + self.data_length()
11            >= Length::EXTENDED_LENGTH_THRESHOLD
12        {
13            flags |= Flags::EXTENDED_LENGTH;
14        }
15
16        if self.vector_length() > 0 {
17            flags |= Flags::INCLUDES_VECTOR;
18        }
19
20        if self.header_length() > 0 {
21            flags |= Flags::INCLUDES_HEADER;
22        }
23
24        if self.data_length() > 0 {
25            flags |= Flags::INCLUDES_DATA;
26        }
27
28        flags
29    }
30
31    fn length(&self) -> Length {
32        Length::new(
33            self.vector_length() + self.header_length() + self.data_length(),
34            self.flags(),
35        )
36    }
37
38    fn vector(&self) -> Option<Vector> {
39        None
40    }
41
42    fn vector_length(&self) -> usize {
43        self.vector().map_or(0, |v| v.size())
44    }
45
46    fn header_length(&self) -> usize {
47        0
48    }
49
50    fn encode_header(&self, _buf: &mut [u8]) -> Result<usize, Self::Error> {
51        Ok(0)
52    }
53
54    fn data_length(&self) -> usize {
55        0
56    }
57
58    fn encode_data(&self, _buf: &mut [u8]) -> Result<usize, Self::Error> {
59        Ok(0)
60    }
61
62    fn encode(&self, buf: &mut [u8]) -> Result<usize, Self::Error> {
63        // It's possible the buffer is not zeroed out and it is being overwritten
64        // flags and length OR the first byte so we must zero it first
65        buf[0] = 0;
66
67        self.flags().encode(buf)?;
68
69        let mut offset = self.length().encode(buf)?;
70
71        if let Some(vector) = self.vector() {
72            offset += vector.encode(&mut buf[offset..])?;
73        }
74
75        offset += self.encode_header(&mut buf[offset..])?;
76
77        offset += self.encode_data(&mut buf[offset..])?;
78
79        Ok(offset)
80    }
81
82    fn decode(buf: &[u8]) -> Result<Self, Self::Error>
83    where
84        Self: Sized;
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal PDU used to exercise the default `PduCodec` methods.
    ///
    /// It has a one-byte vector (`0x01`), a two-byte header (target id, then
    /// source id) and four bytes of big-endian data.
    #[derive(Debug, PartialEq)]
    struct TestPdu {
        target_id: u8,
        source_id: u8,
        data: u32,
    }

    impl PduCodec for TestPdu {
        type Error = AcnError;

        fn vector(&self) -> Option<Vector> {
            Some(Vector::U8(0x01))
        }

        fn header_length(&self) -> usize {
            2
        }

        fn encode_header(&self, buf: &mut [u8]) -> Result<usize, AcnError> {
            buf[0] = self.target_id;
            buf[1] = self.source_id;
            Ok(2)
        }

        fn data_length(&self) -> usize {
            4
        }

        fn encode_data(&self, buf: &mut [u8]) -> Result<usize, AcnError> {
            buf[0..4].copy_from_slice(&self.data.to_be_bytes());
            Ok(4)
        }

        fn decode(buf: &[u8]) -> Result<Self, Self::Error> {
            // Byte layout asserted in `should_encode_pdu`:
            // - flags + length occupy bytes 0..2 (this PDU does not use the
            //   extended length form);
            // - the vector is at index 2;
            // - the header is at 3..5;
            // - the data is at 5..9.
            const VECTOR_OFFSET: usize = 2;
            const HEADER_OFFSET: usize = VECTOR_OFFSET + 1;
            const DATA_OFFSET: usize = HEADER_OFFSET + 2;

            let length = Length::decode(buf)?;

            let buffer_len = buf.len();
            let expected_len = length.as_usize();

            if buffer_len < expected_len {
                return Err(AcnError::InvalidBufferLength {
                    actual: buffer_len,
                    expected: expected_len,
                });
            }

            let vector = buf[VECTOR_OFFSET];

            if vector != 0x01 {
                return Err(AcnError::InvalidVector(vector.into()));
            }

            let target_id = buf[HEADER_OFFSET];
            let source_id = buf[HEADER_OFFSET + 1];
            let data = u32::from_be_bytes(buf[DATA_OFFSET..DATA_OFFSET + 4].try_into()?);

            Ok(TestPdu {
                target_id,
                source_id,
                data,
            })
        }
    }

    #[test]
    fn should_encode_pdu() {
        let buf = &mut [0u8; 9];

        let pdu = TestPdu {
            target_id: 0x01,
            source_id: 0x02,
            data: 0x12345678,
        };

        let flags = pdu.flags();

        let expected_flags = Flags::INCLUDES_VECTOR | Flags::INCLUDES_HEADER | Flags::INCLUDES_DATA;

        assert_eq!(flags, expected_flags);

        let length = pdu.length();
        assert_eq!(length.size(), 2);
        assert_eq!(length.as_u32(), 9);

        pdu.encode(buf).unwrap();

        assert_eq!(
            buf,
            &[
                expected_flags.bits(), // Flags |= Length MSB
                0x09,                  // Length LSB
                0x01,                  // Vector
                0x01,                  // Header Byte 1 = Target ID
                0x02,                  // Header Byte 2 = Source ID
                0x12,                  // Data Byte 1
                0x34,                  // Data Byte 2
                0x56,                  // Data Byte 3
                0x78,                  // Data Byte 4
            ]
        );
    }

    /// Encoding and then decoding must reproduce the original PDU. This guards
    /// `decode` against drifting out of sync with the encoded byte layout.
    #[test]
    fn should_decode_encoded_pdu() {
        let pdu = TestPdu {
            target_id: 0x01,
            source_id: 0x02,
            data: 0x12345678,
        };

        let mut buf = [0u8; 9];
        let written = pdu.encode(&mut buf).unwrap();
        assert_eq!(written, 9);

        let decoded = TestPdu::decode(&buf).unwrap();
        assert_eq!(decoded, pdu);
    }
}