Skip to main content

aetheris_encoder_serde/
serde_encoder.rs

1//! Implementation of the `SerdeEncoder` using `rmp-serde`.
2
3use std::io::Cursor;
4
5use serde::{Deserialize, Serialize};
6
7use aetheris_protocol::error::EncodeError;
8use aetheris_protocol::events::{ComponentUpdate, NetworkEvent, ReplicationEvent, WireEvent};
9use aetheris_protocol::traits::Encoder;
10use aetheris_protocol::types::{ClientId, ComponentKind, NetworkId};
11
12/// Internal header for serialized replication events.
13///
14/// Ensures a stable binary format across different `rmp-serde` configurations.
15#[derive(Debug, Serialize, Deserialize)]
16struct PacketHeader {
17    network_id: NetworkId,
18    component_kind: ComponentKind,
19    tick: u64,
20}
21
/// A `serde`-based encoder that uses `rmp-serde` (`MessagePack`) for binary serialization.
///
/// This implementation targets Phase 1 MVP requirements for rapid iteration.
///
/// Replication events are written as a `MessagePack`-encoded header (network id, component
/// kind, tick) followed by the raw component payload bytes, which are copied verbatim. The
/// header is variable-length because `MessagePack` uses compact integer encodings, so
/// decoding locates the payload by parsing the header first.
///
/// Network events are converted to their wire representation and encoded as a single
/// `MessagePack` value; local-only event variants are rejected.
#[derive(Debug, Default)]
pub struct SerdeEncoder;
28
29impl SerdeEncoder {
30    /// Creates a new `SerdeEncoder`.
31    #[must_use]
32    pub fn new() -> Self {
33        Self
34    }
35
36    /// Encodes a `NetworkEvent` into raw bytes for transmission.
37    ///
38    /// # Errors
39    /// Returns `EncodeError` if the event fails to serialize or is a local-only variant.
40    pub fn encode_event(
41        &self,
42        event: &aetheris_protocol::events::NetworkEvent,
43    ) -> Result<Vec<u8>, EncodeError> {
44        let wire_event = match event {
45            NetworkEvent::Ping { tick, .. } => WireEvent::Ping { tick: *tick },
46            NetworkEvent::Pong { tick } => WireEvent::Pong { tick: *tick },
47            NetworkEvent::Auth { session_token } => WireEvent::Auth {
48                session_token: session_token.clone(),
49            },
50            NetworkEvent::Fragment { fragment, .. } => WireEvent::Fragment(fragment.clone()),
51            NetworkEvent::StressTest { count, rotate, .. } => WireEvent::StressTest {
52                count: *count,
53                rotate: *rotate,
54            },
55            NetworkEvent::Spawn {
56                entity_type,
57                x,
58                y,
59                rot,
60                ..
61            } => WireEvent::Spawn {
62                entity_type: *entity_type,
63                x: *x,
64                y: *y,
65                rot: *rot,
66            },
67            NetworkEvent::ClearWorld { .. } => WireEvent::ClearWorld,
68            NetworkEvent::ClientConnected(_)
69            | NetworkEvent::ClientDisconnected(_)
70            | NetworkEvent::UnreliableMessage { .. }
71            | NetworkEvent::ReliableMessage { .. }
72            | NetworkEvent::SessionClosed(_)
73            | NetworkEvent::StreamReset(_) => {
74                return Err(EncodeError::Io(std::io::Error::other(format!(
75                    "Cannot encode local-only variant as wire event: {event:?}"
76                ))));
77            }
78        };
79        rmp_serde::to_vec(&wire_event)
80            .map_err(|e| EncodeError::Io(std::io::Error::other(e.to_string())))
81    }
82
83    /// Decodes raw bytes into a `NetworkEvent`.
84    ///
85    /// # Errors
86    /// Returns `EncodeError` if the bytes are not a valid `WireEvent`.
87    pub fn decode_event(
88        &self,
89        data: &[u8],
90    ) -> Result<aetheris_protocol::events::NetworkEvent, EncodeError> {
91        let wire_event: WireEvent = rmp_serde::from_slice(data).map_err(|e| {
92            EncodeError::MalformedPayload {
93                offset: 0, // In Phase 1 we don't track exact rmp-serde offset easily
94                message: e.to_string(),
95            }
96        })?;
97
98        Ok(match wire_event {
99            WireEvent::Ping { tick } => NetworkEvent::Ping {
100                client_id: ClientId(0), // Populated by transport/server
101                tick,
102            },
103            WireEvent::Pong { tick } => NetworkEvent::Pong { tick },
104            WireEvent::Auth { session_token } => NetworkEvent::Auth { session_token },
105            WireEvent::Fragment(fragment) => NetworkEvent::Fragment {
106                client_id: ClientId(0),
107                fragment,
108            },
109            WireEvent::StressTest { count, rotate } => NetworkEvent::StressTest {
110                client_id: ClientId(0), // Populated by server
111                count,
112                rotate,
113            },
114            WireEvent::Spawn {
115                entity_type,
116                x,
117                y,
118                rot,
119            } => NetworkEvent::Spawn {
120                client_id: ClientId(0),
121                entity_type,
122                x,
123                y,
124                rot,
125            },
126            WireEvent::ClearWorld => NetworkEvent::ClearWorld {
127                client_id: ClientId(0),
128            },
129        })
130    }
131}
132
133impl Encoder for SerdeEncoder {
134    fn encode_event(&self, event: &NetworkEvent) -> Result<Vec<u8>, EncodeError> {
135        self.encode_event(event)
136    }
137
138    fn decode_event(&self, data: &[u8]) -> Result<NetworkEvent, EncodeError> {
139        self.decode_event(data)
140    }
141
142    fn encode(&self, event: &ReplicationEvent, buffer: &mut [u8]) -> Result<usize, EncodeError> {
143        #[cfg(not(target_arch = "wasm32"))]
144        let start = std::time::Instant::now();
145        let header = PacketHeader {
146            network_id: event.network_id,
147            component_kind: event.component_kind,
148            tick: event.tick,
149        };
150
151        let mut cursor = Cursor::new(buffer);
152        let mut serializer = rmp_serde::Serializer::new(&mut cursor);
153
154        header.serialize(&mut serializer).map_err(|_e| {
155            metrics::counter!("aetheris_encoder_errors_total", "type" => "header_serialize_fail")
156                .increment(1);
157            // If it fails to serialize, it's likely a buffer overflow.
158            EncodeError::BufferOverflow {
159                needed: 32, // PacketHeader is small (~20 bytes)
160                available: cursor.get_ref().len(),
161            }
162        })?;
163
164        let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
165        let payload_len = event.payload.len();
166        let total_needed = header_len + payload_len;
167
168        if total_needed > self.max_encoded_size() {
169            return Err(EncodeError::BufferOverflow {
170                needed: total_needed,
171                available: self.max_encoded_size(),
172            });
173        }
174
175        if total_needed > cursor.get_ref().len() {
176            metrics::counter!("aetheris_encoder_errors_total", "type" => "buffer_overflow")
177                .increment(1);
178            return Err(EncodeError::BufferOverflow {
179                needed: total_needed,
180                available: cursor.get_ref().len(),
181            });
182        }
183
184        // Copy payload manually after the header
185        let slice = cursor.into_inner();
186        slice[header_len..total_needed].copy_from_slice(&event.payload);
187
188        #[allow(clippy::cast_precision_loss)]
189        metrics::histogram!(
190            "aetheris_encoder_payload_size_bytes",
191            "operation" => "encode"
192        )
193        .record(total_needed as f64);
194
195        #[cfg(not(target_arch = "wasm32"))]
196        metrics::histogram!(
197            "aetheris_encoder_encode_duration_seconds",
198            "kind" => event.component_kind.0.to_string()
199        )
200        .record(start.elapsed().as_secs_f64());
201
202        Ok(total_needed)
203    }
204
205    fn decode(&self, buffer: &[u8]) -> Result<ComponentUpdate, EncodeError> {
206        #[cfg(not(target_arch = "wasm32"))]
207        let start = std::time::Instant::now();
208        let mut cursor = Cursor::new(buffer);
209        let mut deserializer = rmp_serde::Deserializer::new(&mut cursor);
210
211        let header = PacketHeader::deserialize(&mut deserializer).map_err(|e| {
212            metrics::counter!("aetheris_encoder_errors_total", "type" => "malformed_payload")
213                .increment(1);
214            EncodeError::MalformedPayload {
215                offset: usize::try_from(cursor.position()).unwrap_or(usize::MAX),
216                message: e.to_string(),
217            }
218        })?;
219
220        let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
221        let payload = buffer
222            .get(header_len..)
223            .ok_or(EncodeError::MalformedPayload {
224                offset: header_len,
225                message: "Payload slice out of bounds".to_string(),
226            })?
227            .to_vec();
228
229        #[allow(clippy::cast_precision_loss)]
230        metrics::histogram!(
231            "aetheris_encoder_payload_size_bytes",
232            "operation" => "decode"
233        )
234        .record(buffer.len() as f64);
235
236        #[cfg(not(target_arch = "wasm32"))]
237        metrics::histogram!(
238            "aetheris_encoder_decode_duration_seconds",
239            "kind" => header.component_kind.0.to_string()
240        )
241        .record(start.elapsed().as_secs_f64());
242
243        Ok(ComponentUpdate {
244            network_id: header.network_id,
245            component_kind: header.component_kind,
246            payload,
247            tick: header.tick,
248        })
249    }
250
251    fn max_encoded_size(&self) -> usize {
252        aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Replication event shared by the fixed-input tests.
    fn sample_event() -> ReplicationEvent {
        ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![1, 2, 3, 4],
            tick: 100,
        }
    }

    /// Asserts that a decoded update carries exactly the fields of `expected`.
    fn assert_matches_event(update: &ComponentUpdate, expected: &ReplicationEvent) {
        assert_eq!(update.network_id, expected.network_id);
        assert_eq!(update.component_kind, expected.component_kind);
        assert_eq!(update.tick, expected.tick);
        assert_eq!(update.payload, expected.payload);
    }

    #[test]
    fn test_roundtrip() {
        let encoder = SerdeEncoder::new();
        let event = sample_event();
        let mut buffer = [0u8; 1200];

        let written = encoder.encode(&event, &mut buffer).unwrap();
        assert!(written > 0);

        let update = encoder.decode(&buffer[..written]).unwrap();
        assert_matches_event(&update, &event);
    }

    #[test]
    fn test_fragment_roundtrip() {
        let encoder = SerdeEncoder::new();
        let original = aetheris_protocol::events::FragmentedEvent {
            message_id: 123,
            fragment_index: 1,
            total_fragments: 5,
            payload: vec![1, 2, 3],
        };
        let event = NetworkEvent::Fragment {
            client_id: ClientId(0),
            fragment: original.clone(),
        };

        let bytes = encoder.encode_event(&event).unwrap();
        let received = match encoder.decode_event(&bytes).unwrap() {
            NetworkEvent::Fragment {
                fragment: received, ..
            } => received,
            decoded => panic!("Decoded event is not a Fragment: {decoded:?}"),
        };

        assert_eq!(received.message_id, original.message_id);
        assert_eq!(received.fragment_index, original.fragment_index);
        assert_eq!(received.total_fragments, original.total_fragments);
        assert_eq!(received.payload, original.payload);
    }

    #[test]
    fn test_buffer_overflow() {
        let mut tiny = [0u8; 1];
        let outcome = SerdeEncoder::new().encode(&sample_event(), &mut tiny);
        assert!(matches!(outcome, Err(EncodeError::BufferOverflow { .. })));
    }

    #[test]
    fn test_malformed_payload() {
        let garbage = [0xff; 4];
        match SerdeEncoder::new().decode(&garbage) {
            Err(EncodeError::MalformedPayload { message, .. }) => assert!(!message.is_empty()),
            result => panic!("Expected MalformedPayload error, got {result:?}"),
        }
    }

    proptest! {
        #[test]
        fn test_fuzz_decode(ref bytes in any::<Vec<u8>>()) {
            // Arbitrary input must yield Ok or Err, never a panic.
            let _ = SerdeEncoder::new().decode(bytes);
        }

        #[test]
        fn test_fuzz_roundtrip(
            nid in any::<u64>(),
            kind in any::<u16>(),
            tick in any::<u64>(),
            ref payload in any::<Vec<u8>>()
        ) {
            let encoder = SerdeEncoder::new();
            let event = ReplicationEvent {
                network_id: NetworkId(nid),
                component_kind: ComponentKind(kind),
                payload: payload.clone(),
                tick,
            };
            let mut scratch = vec![0u8; payload.len() + 2048];

            // Encoding may legitimately be rejected; only successful encodes are round-tripped.
            if let Ok(written) = encoder.encode(&event, &mut scratch) {
                let update = encoder
                    .decode(&scratch[..written])
                    .expect("Round-trip decode failed during fuzzed test");
                assert_matches_event(&update, &event);
            }
        }
    }
}