acn_protocol/
pdu.rs

1use crate::{error::AcnError, flags::Flags, length::Length, vector::Vector};
2use core::fmt;
3
4pub trait PduCodec {
5    type Error: fmt::Debug + core::error::Error;
6
7    fn flags(&self) -> Flags {
8        let mut flags = Flags::default();
9
10        if self.vector_length() + self.header_length() + self.data_length()
11            >= Length::EXTENDED_LENGTH_THRESHOLD
12        {
13            flags |= Flags::EXTENDED_LENGTH;
14        }
15
16        if self.vector_length() > 0 {
17            flags |= Flags::INCLUDES_VECTOR;
18        }
19
20        if self.header_length() > 0 {
21            flags |= Flags::INCLUDES_HEADER;
22        }
23
24        if self.data_length() > 0 {
25            flags |= Flags::INCLUDES_DATA;
26        }
27
28        flags
29    }
30
31    fn length(&self) -> Length {
32        Length::new(
33            self.vector_length() + self.header_length() + self.data_length(),
34            self.flags(),
35        )
36    }
37
38    fn vector(&self) -> Option<Vector> {
39        None
40    }
41
42    fn vector_length(&self) -> usize {
43        self.vector().map_or(0, |v| v.size())
44    }
45
46    fn header_length(&self) -> usize {
47        0
48    }
49
50    fn encode_header(&self, _buf: &mut [u8]) -> Result<usize, AcnError> {
51        Ok(0)
52    }
53
54    fn data_length(&self) -> usize {
55        0
56    }
57
58    fn encode_data(&self, _buf: &mut [u8]) -> Result<usize, AcnError> {
59        Ok(0)
60    }
61
62    fn encode(&self, buf: &mut [u8]) -> Result<usize, AcnError> {
63        // It's possible the buffer is not zeroed out and it is being overwritten
64        // flags and length OR the first byte so we must zero it first
65        buf[0] = 0;
66        
67        self.flags().encode(buf)?;
68
69        let mut offset = self.length().encode(buf)?;
70
71        if let Some(vector) = self.vector() {
72            offset += vector.encode(&mut buf[offset..])?;
73        }
74
75        offset += self.encode_header(&mut buf[offset..])?;
76
77        offset += self.encode_data(&mut buf[offset..])?;
78
79        Ok(offset)
80    }
81
82    fn decode(buf: &[u8]) -> Result<Self, Self::Error>
83    where
84        Self: Sized;
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal PDU used to exercise the default `PduCodec` methods.
    ///
    /// It has a one-octet vector (`0x01`), a two-octet header (target id,
    /// source id) and a four-octet big-endian data payload.
    #[derive(Debug, PartialEq)]
    struct TestPdu {
        target_id: u8,
        source_id: u8,
        data: u32,
    }

    /// Number of octets following the flags/length field:
    /// vector (1) + header (2) + data (4).
    const TEST_PDU_BODY_LEN: usize = 7;

    impl PduCodec for TestPdu {
        type Error = AcnError;

        fn vector(&self) -> Option<Vector> {
            Some(Vector::U8(0x01))
        }

        fn header_length(&self) -> usize {
            2
        }

        fn encode_header(&self, buf: &mut [u8]) -> Result<usize, AcnError> {
            buf[0] = self.target_id;
            buf[1] = self.source_id;
            Ok(2)
        }

        fn data_length(&self) -> usize {
            4
        }

        fn encode_data(&self, buf: &mut [u8]) -> Result<usize, AcnError> {
            buf[0..4].copy_from_slice(&self.data.to_be_bytes());
            Ok(4)
        }

        fn decode(buf: &[u8]) -> Result<Self, Self::Error> {
            let length = Length::decode(buf)?;

            // The vector starts right after the flags/length octets, whose
            // size depends on the extended-length flag; it is not a fixed index.
            let offset = length.size();

            if (buf.len() as u32) < length.as_u32() {
                return Err(AcnError::InvalidBufferLength(buf.len()));
            }

            // Bounds-checked so a length field that understates the PDU size
            // yields an error rather than an index panic.
            let fields = buf
                .get(offset..offset + TEST_PDU_BODY_LEN)
                .ok_or(AcnError::InvalidBufferLength(buf.len()))?;

            let vector = fields[0];

            if vector != 0x01 {
                return Err(AcnError::InvalidVector(vector.into()));
            }

            let target_id = fields[1];
            let source_id = fields[2];
            let data = u32::from_be_bytes(fields[3..7].try_into()?);

            Ok(TestPdu {
                target_id,
                source_id,
                data,
            })
        }
    }

    /// The PDU used by every test below.
    fn sample_pdu() -> TestPdu {
        TestPdu {
            target_id: 0x01,
            source_id: 0x02,
            data: 0x12345678,
        }
    }

    #[test]
    fn should_encode_pdu() {
        let buf = &mut [0u8; 9];

        let pdu = sample_pdu();

        let flags = pdu.flags();

        let expected_flags = Flags::INCLUDES_VECTOR | Flags::INCLUDES_HEADER | Flags::INCLUDES_DATA;

        assert_eq!(flags, expected_flags);

        let length = pdu.length();
        assert_eq!(length.size(), 2);
        assert_eq!(length.as_u32(), 9);

        let written = pdu.encode(buf).unwrap();
        assert_eq!(written, 9);

        assert_eq!(
            buf,
            &[
                expected_flags.bits(), // Flags |= Length MSB
                0x09,                  // Length LSB
                0x01,                  // Vector
                0x01,                  // Header Byte 1 = Target ID
                0x02,                  // Header Byte 2 = Source ID
                0x12,                  // Data Byte 1
                0x34,                  // Data Byte 2
                0x56,                  // Data Byte 3
                0x78,                  // Data Byte 4
            ]
        );
    }

    #[test]
    fn should_decode_encoded_pdu() {
        let pdu = sample_pdu();
        let mut buf = [0u8; 9];

        let written = pdu.encode(&mut buf).unwrap();
        let decoded = TestPdu::decode(&buf[..written]).unwrap();

        assert_eq!(decoded, pdu);
    }

    #[test]
    fn should_reject_truncated_pdu() {
        let mut buf = [0u8; 9];
        sample_pdu().encode(&mut buf).unwrap();

        assert!(TestPdu::decode(&buf[..8]).is_err());
    }

    #[test]
    fn should_reject_unexpected_vector() {
        let mut buf = [0u8; 9];
        sample_pdu().encode(&mut buf).unwrap();

        // Index 2 is the vector octet (after the 2-octet flags/length field).
        buf[2] = 0x02;

        assert!(matches!(
            TestPdu::decode(&buf),
            Err(AcnError::InvalidVector(_))
        ));
    }
}