Skip to main content

aetheris_encoder_serde/
serde_encoder.rs

1//! Implementation of the `SerdeEncoder` using `rmp-serde`.
2
3use std::io::Cursor;
4
5use serde::{Deserialize, Serialize};
6
7use aetheris_protocol::error::EncodeError;
8use aetheris_protocol::events::{ComponentUpdate, NetworkEvent, ReplicationEvent, WireEvent};
9use aetheris_protocol::traits::Encoder;
10use aetheris_protocol::types::{ClientId, ComponentKind, NetworkId};
11
12/// Internal header for serialized replication events.
13///
14/// Ensures a stable binary format across different `rmp-serde` configurations.
15#[derive(Debug, Serialize, Deserialize)]
16struct PacketHeader {
17    network_id: NetworkId,
18    component_kind: ComponentKind,
19    tick: u64,
20}
21
/// A `serde`-based encoder that uses `rmp-serde` (`MessagePack`) for binary serialization.
///
/// This implementation targets Phase 1 MVP requirements for rapid iteration.
/// Replication updates are written as a variable-length `MessagePack` header (its size
/// depends on the encoded integer values) followed by the raw, unframed component payload.
///
/// The encoder holds no state, so it is freely copyable.
#[derive(Debug, Default, Clone, Copy)]
pub struct SerdeEncoder;
28
29impl SerdeEncoder {
30    /// Creates a new `SerdeEncoder`.
31    #[must_use]
32    pub fn new() -> Self {
33        Self
34    }
35
36    /// Encodes a `NetworkEvent` into raw bytes for transmission.
37    ///
38    /// # Errors
39    /// Returns `EncodeError` if the event fails to serialize or is a local-only variant.
40    pub fn encode_event(
41        &self,
42        event: &aetheris_protocol::events::NetworkEvent,
43    ) -> Result<Vec<u8>, EncodeError> {
44        let wire_event = match event {
45            NetworkEvent::Ping { tick, .. } if event.is_wire() => WireEvent::Ping { tick: *tick },
46            NetworkEvent::Pong { tick } => WireEvent::Pong { tick: *tick },
47            NetworkEvent::Auth { session_token } => WireEvent::Auth {
48                session_token: session_token.clone(),
49            },
50            NetworkEvent::Fragment { fragment, .. } => WireEvent::Fragment(fragment.clone()),
51            NetworkEvent::StressTest { count, rotate, .. } => WireEvent::StressTest {
52                count: *count,
53                rotate: *rotate,
54            },
55            NetworkEvent::Spawn {
56                entity_type,
57                x,
58                y,
59                rot,
60                ..
61            } => WireEvent::Spawn {
62                entity_type: *entity_type,
63                x: *x,
64                y: *y,
65                rot: *rot,
66            },
67            NetworkEvent::ClearWorld { .. } => WireEvent::ClearWorld,
68            NetworkEvent::GameEvent { event, .. } => WireEvent::GameEvent(event.clone()),
69            NetworkEvent::ClientConnected(_)
70            | NetworkEvent::ClientDisconnected(_)
71            | NetworkEvent::UnreliableMessage { .. }
72            | NetworkEvent::ReliableMessage { .. }
73            | NetworkEvent::Ping { .. }
74            | NetworkEvent::SessionClosed(_)
75            | NetworkEvent::StreamReset(_)
76            | NetworkEvent::Disconnected(_) => {
77                return Err(EncodeError::Io(std::io::Error::other(format!(
78                    "Cannot encode local-only variant as wire event: {event:?}"
79                ))));
80            }
81        };
82        rmp_serde::to_vec(&wire_event)
83            .map_err(|e| EncodeError::Io(std::io::Error::other(e.to_string())))
84    }
85
86    /// Decodes raw bytes into a `NetworkEvent`.
87    ///
88    /// # Errors
89    /// Returns `EncodeError` if the bytes are not a valid `WireEvent`.
90    pub fn decode_event(
91        &self,
92        data: &[u8],
93    ) -> Result<aetheris_protocol::events::NetworkEvent, EncodeError> {
94        let wire_event: WireEvent = rmp_serde::from_slice(data).map_err(|e| {
95            EncodeError::MalformedPayload {
96                offset: 0, // In Phase 1 we don't track exact rmp-serde offset easily
97                message: e.to_string(),
98            }
99        })?;
100
101        Ok(match wire_event {
102            WireEvent::Ping { tick } => NetworkEvent::Ping {
103                client_id: ClientId(0), // Populated by transport/server
104                tick,
105            },
106            WireEvent::Pong { tick } => NetworkEvent::Pong { tick },
107            WireEvent::Auth { session_token } => NetworkEvent::Auth { session_token },
108            WireEvent::Fragment(fragment) => NetworkEvent::Fragment {
109                client_id: ClientId(0),
110                fragment,
111            },
112            WireEvent::StressTest { count, rotate } => NetworkEvent::StressTest {
113                client_id: ClientId(0), // Populated by server
114                count,
115                rotate,
116            },
117            WireEvent::Spawn {
118                entity_type,
119                x,
120                y,
121                rot,
122            } => NetworkEvent::Spawn {
123                client_id: ClientId(0),
124                entity_type,
125                x,
126                y,
127                rot,
128            },
129            WireEvent::ClearWorld => NetworkEvent::ClearWorld {
130                client_id: ClientId(0),
131            },
132            WireEvent::GameEvent(event) => NetworkEvent::GameEvent {
133                client_id: ClientId(0),
134                event,
135            },
136        })
137    }
138}
139
140impl Encoder for SerdeEncoder {
141    fn codec_id(&self) -> u32 {
142        1
143    }
144
145    fn encode_event(&self, event: &NetworkEvent) -> Result<Vec<u8>, EncodeError> {
146        self.encode_event(event)
147    }
148
149    fn decode_event(&self, data: &[u8]) -> Result<NetworkEvent, EncodeError> {
150        self.decode_event(data)
151    }
152
153    fn encode(&self, event: &ReplicationEvent, buffer: &mut [u8]) -> Result<usize, EncodeError> {
154        #[cfg(not(target_arch = "wasm32"))]
155        let start = std::time::Instant::now();
156        let header = PacketHeader {
157            network_id: event.network_id,
158            component_kind: event.component_kind,
159            tick: event.tick,
160        };
161
162        let mut cursor = Cursor::new(buffer);
163        let mut serializer = rmp_serde::Serializer::new(&mut cursor);
164
165        header.serialize(&mut serializer).map_err(|_e| {
166            metrics::counter!("aetheris_encoder_errors_total", "type" => "header_serialize_fail")
167                .increment(1);
168            // If it fails to serialize, it's likely a buffer overflow.
169            EncodeError::BufferOverflow {
170                needed: 32, // PacketHeader is small (~20 bytes)
171                available: cursor.get_ref().len(),
172            }
173        })?;
174
175        let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
176        let payload_len = event.payload.len();
177        let total_needed = header_len + payload_len;
178
179        if total_needed > cursor.get_ref().len() {
180            metrics::counter!("aetheris_encoder_errors_total", "type" => "buffer_overflow")
181                .increment(1);
182            return Err(EncodeError::BufferOverflow {
183                needed: total_needed,
184                available: cursor.get_ref().len(),
185            });
186        }
187
188        // Copy payload manually after the header
189        let slice = cursor.into_inner();
190        slice[header_len..total_needed].copy_from_slice(&event.payload);
191
192        #[allow(clippy::cast_precision_loss)]
193        metrics::histogram!(
194            "aetheris_encoder_payload_size_bytes",
195            "operation" => "encode"
196        )
197        .record(total_needed as f64);
198
199        #[cfg(not(target_arch = "wasm32"))]
200        metrics::histogram!(
201            "aetheris_encoder_encode_duration_seconds",
202            "kind" => event.component_kind.0.to_string()
203        )
204        .record(start.elapsed().as_secs_f64());
205
206        Ok(total_needed)
207    }
208
209    fn decode(&self, buffer: &[u8]) -> Result<ComponentUpdate, EncodeError> {
210        #[cfg(not(target_arch = "wasm32"))]
211        let start = std::time::Instant::now();
212        let mut cursor = Cursor::new(buffer);
213        let mut deserializer = rmp_serde::Deserializer::new(&mut cursor);
214
215        let header = PacketHeader::deserialize(&mut deserializer).map_err(|e| {
216            metrics::counter!("aetheris_encoder_errors_total", "type" => "malformed_payload")
217                .increment(1);
218            EncodeError::MalformedPayload {
219                offset: usize::try_from(cursor.position()).unwrap_or(usize::MAX),
220                message: e.to_string(),
221            }
222        })?;
223
224        let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
225        let payload = buffer
226            .get(header_len..)
227            .ok_or(EncodeError::MalformedPayload {
228                offset: header_len,
229                message: "Payload slice out of bounds".to_string(),
230            })?
231            .to_vec();
232
233        #[allow(clippy::cast_precision_loss)]
234        metrics::histogram!(
235            "aetheris_encoder_payload_size_bytes",
236            "operation" => "decode"
237        )
238        .record(buffer.len() as f64);
239
240        #[cfg(not(target_arch = "wasm32"))]
241        metrics::histogram!(
242            "aetheris_encoder_decode_duration_seconds",
243            "kind" => header.component_kind.0.to_string()
244        )
245        .record(start.elapsed().as_secs_f64());
246
247        Ok(ComponentUpdate {
248            network_id: header.network_id,
249            component_kind: header.component_kind,
250            payload,
251            tick: header.tick,
252        })
253    }
254
255    fn max_encoded_size(&self) -> usize {
256        aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use aetheris_protocol::events::{FragmentedEvent, GameEvent};
    use proptest::prelude::*;

    /// Replication event with small fixed header values and the given payload.
    fn sample_event(payload: Vec<u8>) -> ReplicationEvent {
        ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload,
            tick: 100,
        }
    }

    #[test]
    fn test_roundtrip() {
        let encoder = SerdeEncoder::new();
        let event = sample_event(vec![1, 2, 3, 4]);

        let mut buffer = [0u8; 1200];
        let bytes_written = encoder.encode(&event, &mut buffer).unwrap();
        // Something must precede the payload: the header.
        assert!(bytes_written > event.payload.len());

        let update = encoder.decode(&buffer[..bytes_written]).unwrap();
        assert_eq!(update.network_id, event.network_id);
        assert_eq!(update.component_kind, event.component_kind);
        assert_eq!(update.tick, event.tick);
        assert_eq!(update.payload, event.payload);
    }

    #[test]
    fn test_fragment_roundtrip() {
        let encoder = SerdeEncoder::new();
        let fragment = FragmentedEvent {
            message_id: 123,
            fragment_index: 1,
            total_fragments: 5,
            payload: vec![1, 2, 3],
        };

        let event = NetworkEvent::Fragment {
            client_id: ClientId(0),
            fragment: fragment.clone(),
        };

        let output = encoder.encode_event(&event).unwrap();
        let decoded = encoder.decode_event(&output).unwrap();

        if let NetworkEvent::Fragment {
            client_id: _,
            fragment: decoded_fragment,
        } = decoded
        {
            assert_eq!(decoded_fragment.message_id, fragment.message_id);
            assert_eq!(decoded_fragment.fragment_index, fragment.fragment_index);
            assert_eq!(decoded_fragment.total_fragments, fragment.total_fragments);
            assert_eq!(decoded_fragment.payload, fragment.payload);
        } else {
            panic!("Decoded event is not a Fragment: {decoded:?}");
        }
    }

    /// A buffer too small for even the header must be rejected.
    #[test]
    fn test_buffer_overflow() {
        let encoder = SerdeEncoder::new();
        let event = sample_event(vec![1, 2, 3, 4]);

        let mut small_buffer = [0u8; 1];
        let result = encoder.encode(&event, &mut small_buffer);
        assert!(matches!(result, Err(EncodeError::BufferOverflow { .. })));
    }

    /// A buffer that fits the header but not the payload must be rejected, and the
    /// reported requirement must cover the payload.
    #[test]
    fn test_payload_overflow() {
        let encoder = SerdeEncoder::new();
        let event = sample_event(vec![7; 64]);

        let mut buffer = [0u8; 32];
        match encoder.encode(&event, &mut buffer) {
            Err(EncodeError::BufferOverflow { needed, available }) => {
                assert_eq!(available, 32);
                assert!(needed > 64, "needed ({needed}) must include the payload");
            }
            other => panic!("Expected BufferOverflow, got {other:?}"),
        }
    }

    /// Encoding succeeds into a buffer of exactly the required size and fails when it
    /// is one byte short.
    #[test]
    fn test_exact_fit_and_one_short() {
        let encoder = SerdeEncoder::new();
        let event = sample_event(vec![9; 16]);

        let mut roomy = [0u8; 256];
        let written = encoder.encode(&event, &mut roomy).unwrap();

        let mut exact = vec![0u8; written];
        assert_eq!(encoder.encode(&event, &mut exact).unwrap(), written);
        assert_eq!(&exact[..], &roomy[..written]);

        let mut short = vec![0u8; written - 1];
        assert!(matches!(
            encoder.encode(&event, &mut short),
            Err(EncodeError::BufferOverflow { .. })
        ));
    }

    #[test]
    fn test_malformed_payload() {
        let encoder = SerdeEncoder::new();
        let garbage = [0xff, 0xff, 0xff, 0xff];
        let result = encoder.decode(&garbage);
        if let Err(EncodeError::MalformedPayload { message, .. }) = result {
            assert!(!message.is_empty());
        } else {
            panic!("Expected MalformedPayload error, got {result:?}");
        }
    }

    proptest! {
        #[test]
        fn test_fuzz_decode(ref bytes in any::<Vec<u8>>()) {
            let encoder = SerdeEncoder::new();
            // Arbitrary input may be rejected, but must never panic.
            let _ = encoder.decode(bytes);
        }

        #[test]
        fn test_fuzz_roundtrip(
            nid in any::<u64>(),
            kind in any::<u16>(),
            tick in any::<u64>(),
            ref payload in any::<Vec<u8>>()
        ) {
            let encoder = SerdeEncoder::new();
            let event = ReplicationEvent {
                network_id: NetworkId(nid),
                component_kind: ComponentKind(kind),
                payload: payload.clone(),
                tick,
            };

            // 2048 bytes of headroom dwarfs any header, so encoding must always succeed;
            // a failure here is a regression, not a case to skip.
            let mut buffer = vec![0u8; 2048 + payload.len()];
            let written = encoder
                .encode(&event, &mut buffer)
                .expect("encode failed despite ample buffer headroom");
            let update = encoder
                .decode(&buffer[..written])
                .expect("Round-trip decode failed during fuzzed test");
            assert_eq!(update.network_id, event.network_id);
            assert_eq!(update.component_kind, event.component_kind);
            assert_eq!(update.tick, event.tick);
            assert_eq!(update.payload, event.payload);
        }
    }

    #[test]
    fn test_disconnected_not_serializable() {
        let encoder = SerdeEncoder::new();
        let event = NetworkEvent::Disconnected(ClientId(42));

        // Attempting to encode a local-only event should return an error
        let result = encoder.encode_event(&event);
        assert!(result.is_err());
        if let Err(EncodeError::Io(e)) = result {
            assert!(e.to_string().contains("Cannot encode local-only variant"));
        } else {
            panic!("Expected EncodeError::Io with local-only message, got {result:?}");
        }
    }

    #[test]
    fn test_game_event_roundtrip() {
        let encoder = SerdeEncoder::new();
        let game_event = GameEvent::AsteroidDepleted {
            network_id: NetworkId(123),
        };
        let event = NetworkEvent::GameEvent {
            client_id: ClientId(1), // Not transmitted; decoding yields the ClientId(0) placeholder
            event: game_event,
        };

        let output = encoder.encode_event(&event).unwrap();
        let decoded = encoder.decode_event(&output).unwrap();

        if let NetworkEvent::GameEvent {
            client_id,
            event: decoded_event,
        } = decoded
        {
            assert_eq!(
                client_id,
                ClientId(0),
                "Wire decoding should default client_id to 0"
            );
            // Exhaustive on purpose: adding a GameEvent variant forces this test to be revisited.
            match decoded_event {
                GameEvent::AsteroidDepleted { network_id } => {
                    assert_eq!(network_id, NetworkId(123));
                }
            }
        } else {
            panic!("Decoded event is not a GameEvent: {decoded:?}");
        }
    }
}