Skip to main content

aetheris_encoder_serde/
serde_encoder.rs

1//! Implementation of the `SerdeEncoder` using `rmp-serde`.
2
3use std::io::Cursor;
4
5use serde::{Deserialize, Serialize};
6
7use aetheris_protocol::error::EncodeError;
8use aetheris_protocol::events::{ComponentUpdate, NetworkEvent, ReplicationEvent, WireEvent};
9use aetheris_protocol::traits::Encoder;
10use aetheris_protocol::types::{ClientId, ComponentKind, NetworkId};
11
12/// Internal header for serialized replication events.
13///
14/// Ensures a stable binary format across different `rmp-serde` configurations.
15#[derive(Debug, Serialize, Deserialize)]
16struct PacketHeader {
17    network_id: NetworkId,
18    component_kind: ComponentKind,
19    tick: u64,
20}
21
/// A `serde`-based encoder that uses `rmp-serde` (`MessagePack`) for binary serialization.
///
/// This implementation targets Phase 1 MVP requirements for rapid iteration:
///
/// - Replication updates (`Encoder::encode`) are written as a MessagePack-encoded
///   header (network id, component kind, tick) immediately followed by the raw
///   component payload, with no length prefix. MessagePack integers are
///   variable-length, so the header is *not* fixed-size.
/// - Network events (`encode_event` / `encode_event_into`) are first converted to a
///   wire-only representation and then serialized as a single MessagePack value.
///
/// The type is stateless; its `codec_id` is `1`.
#[derive(Debug, Default)]
pub struct SerdeEncoder;
28
29impl SerdeEncoder {
30    /// Creates a new `SerdeEncoder`.
31    #[must_use]
32    pub fn new() -> Self {
33        Self
34    }
35
36    /// Encodes a `NetworkEvent` into raw bytes for transmission.
37    ///
38    /// # Errors
39    /// Returns `EncodeError` if the event fails to serialize or is a local-only variant.
40    pub fn encode_event(
41        &self,
42        event: &aetheris_protocol::events::NetworkEvent,
43    ) -> Result<Vec<u8>, EncodeError> {
44        let wire_event = Self::to_wire_event(event)?;
45        rmp_serde::to_vec(&wire_event)
46            .map_err(|e| EncodeError::Io(std::io::Error::other(e.to_string())))
47    }
48
49    /// Encodes a `NetworkEvent` into the provided buffer.
50    ///
51    /// # Errors
52    /// Returns `EncodeError::BufferOverflow` if the buffer is too small.
53    pub fn encode_event_into(
54        &self,
55        event: &NetworkEvent,
56        buffer: &mut [u8],
57    ) -> Result<usize, EncodeError> {
58        let wire_event = Self::to_wire_event(event)?;
59        let mut cursor = Cursor::new(&mut *buffer);
60        rmp_serde::encode::write(&mut cursor, &wire_event).map_err(|e| {
61            if e.to_string().contains("unexpected end of file") {
62                EncodeError::BufferOverflow {
63                    needed: 256, // Estimate
64                    available: cursor.get_ref().len(),
65                }
66            } else {
67                EncodeError::Io(std::io::Error::other(e.to_string()))
68            }
69        })?;
70        usize::try_from(cursor.position()).map_err(|_| EncodeError::BufferOverflow {
71            needed: usize::MAX,
72            available: buffer.len(),
73        })
74    }
75
76    fn to_wire_event(event: &NetworkEvent) -> Result<WireEvent, EncodeError> {
77        Ok(match event {
78            NetworkEvent::Ping { tick, .. } if event.is_wire() => WireEvent::Ping { tick: *tick },
79            NetworkEvent::Pong { tick } => WireEvent::Pong { tick: *tick },
80            NetworkEvent::Auth { session_token } => WireEvent::Auth {
81                session_token: session_token.clone(),
82            },
83            NetworkEvent::Fragment { fragment, .. } => WireEvent::Fragment(fragment.clone()),
84            NetworkEvent::StressTest { count, rotate, .. } => WireEvent::StressTest {
85                count: *count,
86                rotate: *rotate,
87            },
88            NetworkEvent::Spawn {
89                entity_type,
90                x,
91                y,
92                rot,
93                ..
94            } => WireEvent::Spawn {
95                entity_type: *entity_type,
96                x: *x,
97                y: *y,
98                rot: *rot,
99            },
100            NetworkEvent::ClearWorld { .. } => WireEvent::ClearWorld,
101            NetworkEvent::StartSession { .. } => WireEvent::StartSession,
102            NetworkEvent::RequestSystemManifest { .. } => WireEvent::RequestSystemManifest,
103            NetworkEvent::GameEvent { event, .. } => WireEvent::GameEvent(event.clone()),
104            NetworkEvent::ReplicationBatch { events, .. } => {
105                WireEvent::ReplicationBatch(events.clone())
106            }
107            _ => {
108                return Err(EncodeError::Io(std::io::Error::other(format!(
109                    "Cannot encode local-only variant as wire event: {event:?}"
110                ))));
111            }
112        })
113    }
114
115    /// Decodes raw bytes into a `NetworkEvent`.
116    ///
117    /// # Errors
118    /// Returns `EncodeError` if the bytes are not a valid `WireEvent`.
119    pub fn decode_event(
120        &self,
121        data: &[u8],
122    ) -> Result<aetheris_protocol::events::NetworkEvent, EncodeError> {
123        let wire_event: WireEvent = rmp_serde::from_slice(data).map_err(|e| {
124            EncodeError::MalformedPayload {
125                offset: 0, // In Phase 1 we don't track exact rmp-serde offset easily
126                message: e.to_string(),
127            }
128        })?;
129
130        Ok(match wire_event {
131            WireEvent::Ping { tick } => NetworkEvent::Ping {
132                client_id: ClientId(0), // Populated by transport/server
133                tick,
134            },
135            WireEvent::Pong { tick } => NetworkEvent::Pong { tick },
136            WireEvent::Auth { session_token } => NetworkEvent::Auth { session_token },
137            WireEvent::Fragment(fragment) => NetworkEvent::Fragment {
138                client_id: ClientId(0),
139                fragment,
140            },
141            WireEvent::StressTest { count, rotate } => NetworkEvent::StressTest {
142                client_id: ClientId(0), // Populated by server
143                count,
144                rotate,
145            },
146            WireEvent::Spawn {
147                entity_type,
148                x,
149                y,
150                rot,
151            } => NetworkEvent::Spawn {
152                client_id: ClientId(0),
153                entity_type,
154                x,
155                y,
156                rot,
157            },
158            WireEvent::ClearWorld => NetworkEvent::ClearWorld {
159                client_id: ClientId(0),
160            },
161            WireEvent::StartSession => NetworkEvent::StartSession {
162                client_id: ClientId(0),
163            },
164            WireEvent::RequestSystemManifest => NetworkEvent::RequestSystemManifest {
165                client_id: ClientId(0),
166            },
167            WireEvent::GameEvent(event) => NetworkEvent::GameEvent {
168                client_id: ClientId(0),
169                event,
170            },
171            WireEvent::ReplicationBatch(events) => NetworkEvent::ReplicationBatch {
172                client_id: ClientId(0),
173                events,
174            },
175        })
176    }
177}
178
179impl Encoder for SerdeEncoder {
180    fn codec_id(&self) -> u32 {
181        1
182    }
183
184    fn encode_event(&self, event: &NetworkEvent) -> Result<Vec<u8>, EncodeError> {
185        self.encode_event(event)
186    }
187
188    fn encode_event_into(
189        &self,
190        event: &NetworkEvent,
191        buffer: &mut [u8],
192    ) -> Result<usize, EncodeError> {
193        self.encode_event_into(event, buffer)
194    }
195
196    fn decode_event(&self, data: &[u8]) -> Result<NetworkEvent, EncodeError> {
197        self.decode_event(data)
198    }
199
200    fn encode(&self, event: &ReplicationEvent, buffer: &mut [u8]) -> Result<usize, EncodeError> {
201        #[cfg(not(target_arch = "wasm32"))]
202        let start = std::time::Instant::now();
203        let header = PacketHeader {
204            network_id: event.network_id,
205            component_kind: event.component_kind,
206            tick: event.tick,
207        };
208
209        let mut cursor = Cursor::new(buffer);
210        let mut serializer = rmp_serde::Serializer::new(&mut cursor);
211
212        header.serialize(&mut serializer).map_err(|_e| {
213            metrics::counter!("aetheris_encoder_errors_total", "type" => "header_serialize_fail")
214                .increment(1);
215            // If it fails to serialize, it's likely a buffer overflow.
216            EncodeError::BufferOverflow {
217                needed: 32, // PacketHeader is small (~20 bytes)
218                available: cursor.get_ref().len(),
219            }
220        })?;
221
222        let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
223        let payload_len = event.payload.len();
224        let total_needed = header_len + payload_len;
225
226        if total_needed > cursor.get_ref().len() {
227            metrics::counter!("aetheris_encoder_errors_total", "type" => "buffer_overflow")
228                .increment(1);
229            return Err(EncodeError::BufferOverflow {
230                needed: total_needed,
231                available: cursor.get_ref().len(),
232            });
233        }
234
235        // Copy payload manually after the header
236        let slice = cursor.into_inner();
237        slice[header_len..total_needed].copy_from_slice(&event.payload);
238
239        #[allow(clippy::cast_precision_loss)]
240        metrics::histogram!(
241            "aetheris_encoder_payload_size_bytes",
242            "operation" => "encode"
243        )
244        .record(total_needed as f64);
245
246        #[cfg(not(target_arch = "wasm32"))]
247        metrics::histogram!(
248            "aetheris_encoder_encode_duration_seconds",
249            "kind" => event.component_kind.0.to_string()
250        )
251        .record(start.elapsed().as_secs_f64());
252
253        Ok(total_needed)
254    }
255
256    fn decode(&self, buffer: &[u8]) -> Result<ComponentUpdate, EncodeError> {
257        #[cfg(not(target_arch = "wasm32"))]
258        let start = std::time::Instant::now();
259        let mut cursor = Cursor::new(buffer);
260        let mut deserializer = rmp_serde::Deserializer::new(&mut cursor);
261
262        let header = PacketHeader::deserialize(&mut deserializer).map_err(|e| {
263            metrics::counter!("aetheris_encoder_errors_total", "type" => "malformed_payload")
264                .increment(1);
265            EncodeError::MalformedPayload {
266                offset: usize::try_from(cursor.position()).unwrap_or(usize::MAX),
267                message: e.to_string(),
268            }
269        })?;
270
271        let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
272        let payload = buffer
273            .get(header_len..)
274            .ok_or(EncodeError::MalformedPayload {
275                offset: header_len,
276                message: "Payload slice out of bounds".to_string(),
277            })?
278            .to_vec();
279
280        #[allow(clippy::cast_precision_loss)]
281        metrics::histogram!(
282            "aetheris_encoder_payload_size_bytes",
283            "operation" => "decode"
284        )
285        .record(buffer.len() as f64);
286
287        #[cfg(not(target_arch = "wasm32"))]
288        metrics::histogram!(
289            "aetheris_encoder_decode_duration_seconds",
290            "kind" => header.component_kind.0.to_string()
291        )
292        .record(start.elapsed().as_secs_f64());
293
294        Ok(ComponentUpdate {
295            network_id: header.network_id,
296            component_kind: header.component_kind,
297            payload,
298            tick: header.tick,
299        })
300    }
301
302    fn max_encoded_size(&self) -> usize {
303        aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Fixed replication event used by the deterministic encode/decode tests.
    fn sample_replication_event() -> ReplicationEvent {
        ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![1, 2, 3, 4],
            tick: 100,
        }
    }

    #[test]
    fn test_roundtrip() {
        let codec = SerdeEncoder::new();
        let original = sample_replication_event();

        let mut scratch = [0u8; 1200];
        let written = codec.encode(&original, &mut scratch).unwrap();
        assert!(written > 0);

        let restored = codec.decode(&scratch[..written]).unwrap();
        assert_eq!(restored.network_id, original.network_id);
        assert_eq!(restored.component_kind, original.component_kind);
        assert_eq!(restored.tick, original.tick);
        assert_eq!(restored.payload, original.payload);
    }

    #[test]
    fn test_fragment_roundtrip() {
        let codec = SerdeEncoder::new();
        let sent = aetheris_protocol::events::FragmentedEvent {
            message_id: 123,
            fragment_index: 1,
            total_fragments: 5,
            payload: vec![1, 2, 3],
        };
        let outgoing = NetworkEvent::Fragment {
            client_id: ClientId(0),
            fragment: sent.clone(),
        };

        let bytes = codec.encode_event(&outgoing).unwrap();
        let decoded = codec.decode_event(&bytes).unwrap();

        let received = match decoded {
            NetworkEvent::Fragment { fragment, .. } => fragment,
            decoded => panic!("Decoded event is not a Fragment: {decoded:?}"),
        };
        assert_eq!(received.message_id, sent.message_id);
        assert_eq!(received.fragment_index, sent.fragment_index);
        assert_eq!(received.total_fragments, sent.total_fragments);
        assert_eq!(received.payload, sent.payload);
    }

    #[test]
    fn test_buffer_overflow() {
        let codec = SerdeEncoder::new();
        let mut tiny = [0u8; 1];
        let outcome = codec.encode(&sample_replication_event(), &mut tiny);
        assert!(matches!(outcome, Err(EncodeError::BufferOverflow { .. })));
    }

    #[test]
    fn test_malformed_payload() {
        let codec = SerdeEncoder::new();
        let garbage = [0xffu8; 4];
        let result = codec.decode(&garbage);
        match result {
            Err(EncodeError::MalformedPayload { message, .. }) => assert!(!message.is_empty()),
            result => panic!("Expected MalformedPayload error, got {result:?}"),
        }
    }

    proptest! {
        #[test]
        fn test_fuzz_decode(ref bytes in any::<Vec<u8>>()) {
            // Arbitrary input must yield Ok or Err; it must never panic.
            let _ = SerdeEncoder::new().decode(bytes);
        }

        #[test]
        fn test_fuzz_roundtrip(
            nid in any::<u64>(),
            kind in any::<u16>(),
            tick in any::<u64>(),
            ref payload in any::<Vec<u8>>()
        ) {
            let codec = SerdeEncoder::new();
            let original = ReplicationEvent {
                network_id: NetworkId(nid),
                component_kind: ComponentKind(kind),
                payload: payload.clone(),
                tick,
            };

            let mut scratch = vec![0u8; 2048 + payload.len()];
            if let Ok(len) = codec.encode(&original, &mut scratch) {
                let restored = codec.decode(&scratch[..len])
                    .expect("Round-trip decode failed during fuzzed test");
                assert_eq!(restored.network_id, original.network_id);
                assert_eq!(restored.component_kind, original.component_kind);
                assert_eq!(restored.tick, original.tick);
                assert_eq!(restored.payload, original.payload);
            }
        }
    }

    #[test]
    fn test_disconnected_not_serializable() {
        let codec = SerdeEncoder::new();
        // `Disconnected` is local-only, so encoding it must be rejected.
        let result = codec.encode_event(&NetworkEvent::Disconnected(ClientId(42)));
        assert!(result.is_err());
        match result {
            Err(EncodeError::Io(err)) => {
                assert!(err.to_string().contains("Cannot encode local-only variant"));
            }
            result => panic!("Expected EncodeError::Io with local-only message, got {result:?}"),
        }
    }

    #[test]
    fn test_game_event_roundtrip() {
        use aetheris_protocol::events::GameEvent;

        let codec = SerdeEncoder::new();
        let outgoing = NetworkEvent::GameEvent {
            // Deliberately non-zero: the client id is not transmitted, so the decoder
            // must hand back the `ClientId(0)` placeholder.
            client_id: ClientId(1),
            event: GameEvent::AsteroidDepleted {
                network_id: NetworkId(123),
            },
        };

        let bytes = codec.encode_event(&outgoing).unwrap();
        let decoded = codec.decode_event(&bytes).unwrap();

        let (client_id, inner) = match decoded {
            NetworkEvent::GameEvent { client_id, event } => (client_id, event),
            decoded => panic!("Decoded event is not a GameEvent: {decoded:?}"),
        };
        assert_eq!(
            client_id,
            ClientId(0),
            "Wire decoding should default client_id to 0"
        );
        // Exhaustive on purpose, so that a new `GameEvent` variant forces a review here.
        match inner {
            GameEvent::AsteroidDepleted { network_id } => {
                assert_eq!(network_id, NetworkId(123));
            }
            GameEvent::Possession { .. } | GameEvent::SystemManifest { .. } => {
                panic!("Unexpected event type in roundtrip test");
            }
        }
    }
}