// open_dis_rust/pdu_macro.rs

1//     open-dis-rust - Rust implementation of the IEEE 1278.1-2012 Distributed Interactive
2//                     Simulation (DIS) application protocol
3//     Copyright (C) 2025 Cameron Howell
4//
5//     Licensed under the BSD 2-Clause License
6
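//! A `macro_rules!`-based system for generating DIS PDU types whose serialization,
//! length calculation and deserialization are driven by per-field traits.
//! Place this file at the crate root, declare it with `pub mod pdu_macro;` in `lib.rs`,
//! and import the macro in PDU modules with `use crate::define_pdu;`.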
10
11use bytes::{Buf, BufMut, BytesMut};
12
13/// Serialize a single field into the buffer.
14pub trait FieldSerialize {
15    fn serialize_field(&self, buf: &mut BytesMut);
16}
17
18/// Deserialize a single field from the buffer.
19pub trait FieldDeserialize: Sized {
20    fn deserialize_field<B: Buf>(buf: &mut B) -> Self;
21}
22
/// Report how many bytes a value occupies once serialized.
///
/// `define_pdu!` adds this up over every body field (plus the header's fixed length)
/// to obtain the PDU length.
pub trait FieldLen {
    /// Serialized size of `self`, in bytes.
    fn field_len(&self) -> usize;
}
27
28/// Trait for types that can be deserialized given an externally-provided length.
29/// Used by the macro when a field is annotated with `#[len = length_field_name]`.
30pub trait FieldDeserializeWithLen: Sized {
31    fn deserialize_with_len<B: Buf>(buf: &mut B, len: usize) -> Self;
32}
33
34// Blanket impl so `Option<T>` can be deserialized with an externally-provided length
35impl<T> FieldDeserializeWithLen for Option<T>
36where
37    T: FieldDeserializeWithLen,
38{
39    fn deserialize_with_len<B: Buf>(buf: &mut B, len: usize) -> Self {
40        if len == 0 {
41            None
42        } else {
43            Some(<T as FieldDeserializeWithLen>::deserialize_with_len(
44                buf, len,
45            ))
46        }
47    }
48}
49
// Helper macros for the code generated by `define_pdu!`. They are internal
// (`__`-prefixed, not meant to be invoked directly) but must be `#[macro_export]`ed so
// the `define_pdu!` expansion can reach them through `$crate::...` paths.
/// Internal helper for `define_pdu!`: before serialization, store the byte length of a
/// `#[len = ...]`-annotated field into its companion length field.
///
/// Forms:
/// * `( len = len_field ; obj, field, Type )` sets `obj.len_field` to
///   `FieldLen::field_len(&obj.field)` (0 for a `None` option).
/// * `( ; obj, field, Type )` does nothing (un-annotated field).
///
/// The length is converted into the length field's own integer type, which is inferred
/// from the assignment, so `u8`, `u16`, `u32`, ... length fields all work. Both arms
/// handle a length that does not fit that type the same way: it is stored as 0 (the
/// type's default) instead of being truncated. Size the length field so this cannot
/// happen.
#[macro_export]
macro_rules! __pdu_prep_serialize_field {
    // Option<T> with a length attribute: absent values have length 0.
    ( len = $len_field:ident ; $self:ident, $field:ident, Option<$inner:ty> ) => {
        $self.$len_field = ::core::convert::TryInto::try_into(
            $self
                .$field
                .as_ref()
                .map_or(0usize, |v| <$inner as $crate::pdu_macro::FieldLen>::field_len(v)),
        )
        .unwrap_or_default();
    };

    // Any other type with a length attribute.
    ( len = $len_field:ident ; $self:ident, $field:ident, $t:ty ) => {
        $self.$len_field = ::core::convert::TryInto::try_into(
            <$t as $crate::pdu_macro::FieldLen>::field_len(&$self.$field),
        )
        .unwrap_or_default();
    };

    // No length attribute: nothing to prepare.
    ( ; $self:ident, $field:ident, $t:ty ) => {};
}
74
/// Internal helper for `define_pdu!`: bind one body field as a local variable by
/// decoding it from `$buf`.
///
/// * `( ; field, Type, buf )` expands to
///   `let field: Type = FieldDeserialize::deserialize_field(buf);`
/// * `( len = n ; field, Type, buf )` expands to
///   `let field: Type = FieldDeserializeWithLen::deserialize_with_len(buf, n as usize);`
///   where `n` is a local bound by an earlier invocation. An `Option<Inner>` type is
///   decoded through the blanket `FieldDeserializeWithLen for Option<T>` impl, so a
///   length of 0 yields `None` without reading from the buffer.
#[macro_export]
macro_rules! __pdu_deserialize_field {
    // Length-annotated field: the length comes from an already-bound local variable.
    ( len = $len_field:ident ; $field:ident, $t:ty, $buf:ident ) => {
        let $field: $t =
            <$t as $crate::pdu_macro::FieldDeserializeWithLen>::deserialize_with_len(
                $buf,
                $len_field as usize,
            );
    };

    // Un-annotated field: decoded purely from its type.
    ( ; $field:ident, $t:ty, $buf:ident ) => {
        let $field: $t = <$t as $crate::pdu_macro::FieldDeserialize>::deserialize_field($buf);
    };
}
106
107// ------ Implementations for primitive types ------
108
/// Implement `FieldLen`, `FieldSerialize` and `FieldDeserialize` for a fixed-width
/// primitive.
///
/// `$write` / `$read` name the `BufMut` / `Buf` methods that encode and decode the
/// value (the un-suffixed `bytes` methods use big-endian byte order), and `$size` is
/// the constant number of bytes the value occupies on the wire.
macro_rules! impl_primitive {
    ($ty:ty, $write:ident, $read:ident, $size:expr) => {
        impl FieldLen for $ty {
            fn field_len(&self) -> usize {
                $size
            }
        }

        impl FieldSerialize for $ty {
            fn serialize_field(&self, out: &mut BytesMut) {
                out.$write(*self);
            }
        }

        impl FieldDeserialize for $ty {
            fn deserialize_field<B: Buf>(src: &mut B) -> Self {
                src.$read()
            }
        }
    };
}
130
131impl_primitive!(u8, put_u8, get_u8, 1usize);
132impl_primitive!(i8, put_i8, get_i8, 1usize);
133impl_primitive!(u16, put_u16, get_u16, 2usize);
134impl_primitive!(i16, put_i16, get_i16, 2usize);
135impl_primitive!(u32, put_u32, get_u32, 4usize);
136impl_primitive!(i32, put_i32, get_i32, 4usize);
137impl_primitive!(u64, put_u64, get_u64, 8usize);
138impl_primitive!(i64, put_i64, get_i64, 8usize);
139impl_primitive!(f32, put_f32, get_f32, 4usize);
140impl_primitive!(f64, put_f64, get_f64, 8usize);
141
142// String: serialized_len = bytes in UTF-8 (no extra length prefix; adapt if your PDU requires a length prefix)
143impl FieldSerialize for String {
144    fn serialize_field(&self, buf: &mut BytesMut) {
145        buf.put_slice(self.as_bytes());
146    }
147}
148impl FieldDeserialize for String {
149    fn deserialize_field<B: Buf>(_buf: &mut B) -> Self {
150        Self::new()
151    }
152}
153impl FieldLen for String {
154    fn field_len(&self) -> usize {
155        self.len()
156    }
157}
158
159// Vec<T>
160impl<T> FieldSerialize for Vec<T>
161where
162    T: FieldSerialize,
163{
164    fn serialize_field(&self, buf: &mut BytesMut) {
165        for item in self {
166            item.serialize_field(buf);
167        }
168    }
169}
170impl<T> FieldDeserialize for Vec<T>
171where
172    T: FieldDeserialize,
173{
174    fn deserialize_field<B: Buf>(_buf: &mut B) -> Self {
175        Self::new()
176    }
177}
178impl<T> FieldLen for Vec<T>
179where
180    T: FieldLen,
181{
182    fn field_len(&self) -> usize {
183        self.iter().map(FieldLen::field_len).sum()
184    }
185}
186
187// Option<T>
188impl<T> FieldSerialize for Option<T>
189where
190    T: FieldSerialize,
191{
192    fn serialize_field(&self, buf: &mut BytesMut) {
193        if let Some(v) = self.as_ref() {
194            v.serialize_field(buf);
195        }
196    }
197}
198impl<T> FieldDeserialize for Option<T>
199where
200    T: FieldDeserialize,
201{
202    fn deserialize_field<B: Buf>(_buf: &mut B) -> Self {
203        None
204    }
205}
206impl<T> FieldLen for Option<T>
207where
208    T: FieldLen,
209{
210    fn field_len(&self) -> usize {
211        self.as_ref().map_or(0, FieldLen::field_len)
212    }
213}
214
/// Define a DIS PDU struct and generate its boilerplate.
///
/// The invocation names the header type, the PDU type and protocol family values that
/// are stamped into that header, and the body fields in wire order (all names below are
/// placeholders):
///
/// ```ignore
/// define_pdu! {
///     #[derive(Debug, Clone)]
///     pub struct ExamplePdu {
///         header: ExampleHeader,
///         pdu_type: EXAMPLE_PDU_TYPE,
///         protocol_family: EXAMPLE_FAMILY,
///         fields: {
///             pub payload_len: u8,
///             #[len = payload_len] pub payload: Option<ExamplePayload>,
///         }
///     }
/// }
/// ```
///
/// Generated items:
/// - the struct, with a private `header` field followed by the listed fields;
/// - `impl Default`, defaulting the header and every field;
/// - a private `deserialize_body` reading the fields in declaration order;
/// - `impl $crate::common::pdu::Pdu` (`calculate_length`, `header`, `header_mut`,
///   `serialize`, `deserialize`, `deserialize_without_header`, `as_any`);
/// - `pub fn new()`, which stamps the PDU type / protocol family into the header and
///   then calls `finalize()`. `finalize` is not generated here and must be available on
///   the type (presumably a `Pdu` trait method — confirm).
///
/// Requirements imposed by the generated code:
/// - the header type implements `Default` and `FieldDeserialize`, has an associated
///   `LENGTH: usize`, and provides `set_pdu_type`, `set_protocol_family`,
///   `set_length(u16)`, `serialize(buf)` and `pdu_type()`. The value returned by
///   `pdu_type()` must be comparable with `pdu_type` and `Debug`-formattable;
/// - every field type implements `Default`, `FieldSerialize` and `FieldLen`.
///   Un-annotated fields must also implement `FieldDeserialize`, and fields annotated
///   `#[len = other]` must implement `FieldDeserializeWithLen`.
///
/// `#[len = other]` semantics: `serialize` first overwrites `other` with the annotated
/// field's byte length. `deserialize` passes the already-decoded value of `other` as the
/// length, so `other` must be declared *before* the annotated field. No field attribute
/// other than a single `#[len = ...]` is accepted. This includes `///` doc comments,
/// which desugar to `#[doc = ...]`.
#[macro_export]
macro_rules! define_pdu {
    (
        $(#[$meta:meta])*
        $vis:vis struct $name:ident {
            header: $header:ty,
            pdu_type: $pdu_type:expr,
            protocol_family: $protocol_family:expr,
            fields: {
                $(
                    $(#[len = $len_field:ident])? $fvis:vis $field:ident : $ftype:ty,
                )*
            }
        }
    ) => {
        $(#[$meta])*
        $vis struct $name {
            header: $header,
            $(
                $fvis $field : $ftype,
            )*
        }

        // Header and every field start from their own `Default`.
        impl Default for $name {
            fn default() -> Self {
                Self {
                    header: <$header>::default(),
                    $(
                        $field: <$ftype>::default(),
                    )*
                }
            }
        }

        impl $name {
            /// Read only the body fields, in declaration order, leaving the header defaulted.
            ///
            /// Un-annotated fields use `FieldDeserialize`. Fields annotated `#[len = other]`
            /// use `FieldDeserializeWithLen` with the value of `other` decoded earlier in
            /// this function. The `FieldDeserialize` impls for `String`, `Vec<T>` and
            /// `Option<T>` in `pdu_macro` consume no bytes, so variable-length fields need
            /// a `#[len = ...]` annotation or hand-written decoding.
            fn deserialize_body<B: bytes::Buf>(buf: &mut B) -> Self {
                $(
                    $crate::__pdu_deserialize_field!( $( len = $len_field )? ; $field, $ftype, buf );
                )*

                Self {
                    header: <$header>::default(),
                    $(
                        $field,
                    )*
                }
            }
        }

        impl $crate::common::pdu::Pdu for $name {
            type Header = $header;

            fn calculate_length(&self) -> Result<u16, $crate::common::dis_error::DISError> {
                // Fixed header size plus the wire size of every body field. Written as a
                // single sum so a PDU without body fields needs no `mut` binding.
                let len: usize = <$header>::LENGTH
                    $( + <$ftype as $crate::pdu_macro::FieldLen>::field_len(&self.$field) )*;

                // The PDU length is reported as a `u16`; sizes that do not fit are an error.
                // NOTE(review): this only rejects lengths above `u16::MAX`; it does not
                // itself compare against `MAX_PDU_SIZE_OCTETS`.
                u16::try_from(len).map_err(|_| $crate::common::dis_error::DISError::PduSizeExceeded {
                    size: len,
                    max_size: $crate::common::constants::MAX_PDU_SIZE_OCTETS,
                })
            }

            fn header(&self) -> &Self::Header {
                &self.header
            }

            fn header_mut(&mut self) -> &mut Self::Header {
                &mut self.header
            }

            fn serialize(&mut self, buf: &mut bytes::BytesMut) -> Result<(), $crate::common::dis_error::DISError> {
                // Stamp the identifying header fields.
                self.header.set_pdu_type($pdu_type);
                self.header.set_protocol_family($protocol_family);

                // Refresh every `#[len = ...]` length field so the PDU length computed next,
                // and the bytes written below, agree with the current field contents.
                $( $crate::__pdu_prep_serialize_field!( $( len = $len_field )? ; self, $field, $ftype ); )*

                let len = self.calculate_length()?;
                self.header.set_length(len);

                self.header.serialize(buf);
                $(
                    <$ftype as $crate::pdu_macro::FieldSerialize>::serialize_field(&self.$field, buf);
                )*

                Ok(())
            }

            fn deserialize<B: bytes::Buf>(buf: &mut B) -> Result<Self, $crate::common::dis_error::DISError>
            where Self: Sized {
                // Reject a buffer that cannot even hold a header up front, instead of letting
                // the header decoder run past the end (`bytes::Buf` getters panic on underflow).
                let available = buf.remaining();
                if available < <$header>::LENGTH {
                    return Err($crate::common::dis_error::DISError::invalid_header(
                        format!(
                            "Buffer too short for PDU header: need {} bytes, have {}",
                            <$header>::LENGTH,
                            available
                        ),
                        None,
                    ));
                }

                let header: Self::Header = <Self::Header as $crate::pdu_macro::FieldDeserialize>::deserialize_field(buf);

                // Refuse to decode a body of a different PDU type.
                if header.pdu_type() != $pdu_type {
                    return Err($crate::common::dis_error::DISError::invalid_header(
                        format!("Expected PDU type {:?}, got {:?}", $pdu_type, header.pdu_type()),
                        None,
                    ));
                }

                let mut body = Self::deserialize_body(buf);
                body.header = header;
                Ok(body)
            }

            fn deserialize_without_header<B: bytes::Buf>(buf: &mut B, header: Self::Header) -> Result<Self, $crate::common::dis_error::DISError>
            where Self: Sized {
                // The caller already consumed and validated the header.
                let mut body = Self::deserialize_body(buf);
                body.header = header;
                Ok(body)
            }

            fn as_any(&self) -> &dyn std::any::Any {
                self
            }
        }

        impl $name {
            /// Create a PDU with default body fields and a header stamped with this PDU's
            /// type and protocol family, then call `finalize()` on it.
            #[must_use]
            pub fn new() -> Self {
                let mut pdu = Self::default();
                pdu.header.set_pdu_type($pdu_type);
                pdu.header.set_protocol_family($protocol_family);
                pdu.finalize();
                pdu
            }
        }
    };
}