Skip to main content

aetheris_encoder_serde/
serde_encoder.rs

1//! Implementation of the `SerdeEncoder` using `rmp-serde`.
2
3use std::io::Cursor;
4
5use serde::{Deserialize, Serialize};
6
7use aetheris_protocol::error::EncodeError;
8use aetheris_protocol::events::{ComponentUpdate, NetworkEvent, ReplicationEvent, WireEvent};
9use aetheris_protocol::traits::Encoder;
10use aetheris_protocol::types::{ClientId, ComponentKind, NetworkId};
11
/// Internal header prepended to every serialized replication event.
///
/// Encoded with `rmp-serde`'s default `Serializer` (see `Encoder::encode`) and decoded
/// with its default `Deserializer`. NOTE(review): the default configuration is believed
/// to encode structs as MessagePack arrays, which would make the field order below part
/// of the wire format — confirm against the pinned `rmp-serde` version before reordering
/// or adding fields.
///
/// MessagePack integers are variable-width, so the encoded header length depends on the
/// values it carries; both `encode` and `decode` measure it per packet via the cursor
/// position rather than assuming a fixed size.
#[derive(Debug, Serialize, Deserialize)]
struct PacketHeader {
    /// Entity the replicated component belongs to.
    network_id: NetworkId,
    /// Which component type the trailing payload encodes.
    component_kind: ComponentKind,
    /// Simulation tick the update was produced at.
    tick: u64,
}
21
/// A `serde`-based encoder that uses `rmp-serde` (`MessagePack`) for binary serialization.
///
/// This implementation targets Phase 1 MVP requirements for rapid iteration.
///
/// Replication packets are laid out as a MessagePack-encoded `PacketHeader`
/// (network id, component kind, tick) followed immediately by the raw component
/// payload bytes. MessagePack integers are variable-width, so the header length
/// varies with the values it carries; decoding locates the payload by parsing the
/// header rather than assuming a fixed offset.
///
/// The encoder holds no state, so cloning it is free.
#[derive(Debug, Default, Clone)]
pub struct SerdeEncoder;
28
29impl SerdeEncoder {
30    /// Creates a new `SerdeEncoder`.
31    #[must_use]
32    pub fn new() -> Self {
33        Self
34    }
35
36    /// Encodes a `NetworkEvent` into raw bytes for transmission.
37    ///
38    /// # Errors
39    /// Returns `EncodeError` if the event fails to serialize or is a local-only variant.
40    pub fn encode_event(
41        &self,
42        event: &aetheris_protocol::events::NetworkEvent,
43    ) -> Result<Vec<u8>, EncodeError> {
44        let wire_event = Self::to_wire_event(event)?;
45        rmp_serde::to_vec(&wire_event)
46            .map_err(|e| EncodeError::Io(std::io::Error::other(e.to_string())))
47    }
48
49    /// Encodes a `NetworkEvent` into the provided buffer.
50    ///
51    /// # Errors
52    /// Returns `EncodeError::BufferOverflow` if the buffer is too small.
53    pub fn encode_event_into(
54        &self,
55        event: &NetworkEvent,
56        buffer: &mut [u8],
57    ) -> Result<usize, EncodeError> {
58        let wire_event = Self::to_wire_event(event)?;
59        let mut cursor = Cursor::new(&mut *buffer);
60        rmp_serde::encode::write(&mut cursor, &wire_event).map_err(|e| {
61            let err_msg = e.to_string();
62            if err_msg.contains("unexpected end of file") || err_msg.contains("invalid value write")
63            {
64                EncodeError::BufferOverflow {
65                    needed: 256, // Estimate
66                    available: cursor.get_ref().len(),
67                }
68            } else {
69                EncodeError::Io(std::io::Error::other(err_msg))
70            }
71        })?;
72        usize::try_from(cursor.position()).map_err(|_| EncodeError::BufferOverflow {
73            needed: usize::MAX,
74            available: buffer.len(),
75        })
76    }
77
78    fn to_wire_event(event: &NetworkEvent) -> Result<WireEvent, EncodeError> {
79        Ok(match event {
80            NetworkEvent::Ping { tick, .. } if event.is_wire() => WireEvent::Ping { tick: *tick },
81            NetworkEvent::Pong { tick } => WireEvent::Pong { tick: *tick },
82            NetworkEvent::Auth { session_token } => WireEvent::Auth {
83                session_token: session_token.clone(),
84            },
85            NetworkEvent::Fragment { fragment, .. } => WireEvent::Fragment(fragment.clone()),
86            NetworkEvent::StressTest { count, rotate, .. } => WireEvent::StressTest {
87                count: *count,
88                rotate: *rotate,
89            },
90            NetworkEvent::Spawn {
91                entity_type,
92                x,
93                y,
94                rot,
95                ..
96            } => WireEvent::Spawn {
97                entity_type: *entity_type,
98                x: *x,
99                y: *y,
100                rot: *rot,
101            },
102            NetworkEvent::ClearWorld { .. } => WireEvent::ClearWorld,
103            NetworkEvent::StartSession { .. } => WireEvent::StartSession,
104            NetworkEvent::RequestSystemManifest { .. } => WireEvent::RequestSystemManifest,
105            NetworkEvent::GameEvent { event, .. } => WireEvent::GameEvent(event.clone()),
106            NetworkEvent::ReplicationBatch { events, .. } => {
107                WireEvent::ReplicationBatch(events.clone())
108            }
109            _ => {
110                return Err(EncodeError::Io(std::io::Error::other(format!(
111                    "Cannot encode local-only variant as wire event: {event:?}"
112                ))));
113            }
114        })
115    }
116
117    /// Decodes raw bytes into a `NetworkEvent`.
118    ///
119    /// # Errors
120    /// Returns `EncodeError` if the bytes are not a valid `WireEvent`.
121    pub fn decode_event(
122        &self,
123        data: &[u8],
124    ) -> Result<aetheris_protocol::events::NetworkEvent, EncodeError> {
125        let wire_event: WireEvent = rmp_serde::from_slice(data).map_err(|e| {
126            EncodeError::MalformedPayload {
127                offset: 0, // In Phase 1 we don't track exact rmp-serde offset easily
128                message: e.to_string(),
129            }
130        })?;
131
132        Ok(match wire_event {
133            WireEvent::Ping { tick } => NetworkEvent::Ping {
134                client_id: ClientId(0), // Populated by transport/server
135                tick,
136            },
137            WireEvent::Pong { tick } => NetworkEvent::Pong { tick },
138            WireEvent::Auth { session_token } => NetworkEvent::Auth { session_token },
139            WireEvent::Fragment(fragment) => NetworkEvent::Fragment {
140                client_id: ClientId(0),
141                fragment,
142            },
143            WireEvent::StressTest { count, rotate } => NetworkEvent::StressTest {
144                client_id: ClientId(0), // Populated by server
145                count,
146                rotate,
147            },
148            WireEvent::Spawn {
149                entity_type,
150                x,
151                y,
152                rot,
153            } => NetworkEvent::Spawn {
154                client_id: ClientId(0),
155                entity_type,
156                x,
157                y,
158                rot,
159            },
160            WireEvent::ClearWorld => NetworkEvent::ClearWorld {
161                client_id: ClientId(0),
162            },
163            WireEvent::StartSession => NetworkEvent::StartSession {
164                client_id: ClientId(0),
165            },
166            WireEvent::RequestSystemManifest => NetworkEvent::RequestSystemManifest {
167                client_id: ClientId(0),
168            },
169            WireEvent::GameEvent(event) => NetworkEvent::GameEvent {
170                client_id: ClientId(0),
171                event,
172            },
173            WireEvent::ReplicationBatch(events) => NetworkEvent::ReplicationBatch {
174                client_id: ClientId(0),
175                events,
176            },
177        })
178    }
179}
180
181impl Encoder for SerdeEncoder {
182    fn codec_id(&self) -> u32 {
183        1
184    }
185
186    fn encode_event(&self, event: &NetworkEvent) -> Result<Vec<u8>, EncodeError> {
187        self.encode_event(event)
188    }
189
190    fn encode_event_into(
191        &self,
192        event: &NetworkEvent,
193        buffer: &mut [u8],
194    ) -> Result<usize, EncodeError> {
195        self.encode_event_into(event, buffer)
196    }
197
198    fn decode_event(&self, data: &[u8]) -> Result<NetworkEvent, EncodeError> {
199        self.decode_event(data)
200    }
201
202    fn encode(&self, event: &ReplicationEvent, buffer: &mut [u8]) -> Result<usize, EncodeError> {
203        #[cfg(not(target_arch = "wasm32"))]
204        let start = std::time::Instant::now();
205        let header = PacketHeader {
206            network_id: event.network_id,
207            component_kind: event.component_kind,
208            tick: event.tick,
209        };
210
211        let mut cursor = Cursor::new(buffer);
212        let mut serializer = rmp_serde::Serializer::new(&mut cursor);
213
214        header.serialize(&mut serializer).map_err(|_e| {
215            metrics::counter!("aetheris_encoder_errors_total", "type" => "header_serialize_fail")
216                .increment(1);
217            // If it fails to serialize, it's likely a buffer overflow.
218            EncodeError::BufferOverflow {
219                needed: 32, // PacketHeader is small (~20 bytes)
220                available: cursor.get_ref().len(),
221            }
222        })?;
223
224        let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
225        let payload_len = event.payload.len();
226        let total_needed = header_len + payload_len;
227
228        if total_needed > cursor.get_ref().len() {
229            metrics::counter!("aetheris_encoder_errors_total", "type" => "buffer_overflow")
230                .increment(1);
231            return Err(EncodeError::BufferOverflow {
232                needed: total_needed,
233                available: cursor.get_ref().len(),
234            });
235        }
236
237        // Copy payload manually after the header
238        let slice = cursor.into_inner();
239        slice[header_len..total_needed].copy_from_slice(&event.payload);
240
241        #[allow(clippy::cast_precision_loss)]
242        metrics::histogram!(
243            "aetheris_encoder_payload_size_bytes",
244            "operation" => "encode"
245        )
246        .record(total_needed as f64);
247
248        #[cfg(not(target_arch = "wasm32"))]
249        metrics::histogram!(
250            "aetheris_encoder_encode_duration_seconds",
251            "kind" => event.component_kind.0.to_string()
252        )
253        .record(start.elapsed().as_secs_f64());
254
255        Ok(total_needed)
256    }
257
258    fn decode(&self, buffer: &[u8]) -> Result<ComponentUpdate, EncodeError> {
259        #[cfg(not(target_arch = "wasm32"))]
260        let start = std::time::Instant::now();
261        let mut cursor = Cursor::new(buffer);
262        let mut deserializer = rmp_serde::Deserializer::new(&mut cursor);
263
264        let header = PacketHeader::deserialize(&mut deserializer).map_err(|e| {
265            metrics::counter!("aetheris_encoder_errors_total", "type" => "malformed_payload")
266                .increment(1);
267            EncodeError::MalformedPayload {
268                offset: usize::try_from(cursor.position()).unwrap_or(usize::MAX),
269                message: e.to_string(),
270            }
271        })?;
272
273        let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
274        let payload = buffer
275            .get(header_len..)
276            .ok_or(EncodeError::MalformedPayload {
277                offset: header_len,
278                message: "Payload slice out of bounds".to_string(),
279            })?
280            .to_vec();
281
282        #[allow(clippy::cast_precision_loss)]
283        metrics::histogram!(
284            "aetheris_encoder_payload_size_bytes",
285            "operation" => "decode"
286        )
287        .record(buffer.len() as f64);
288
289        #[cfg(not(target_arch = "wasm32"))]
290        metrics::histogram!(
291            "aetheris_encoder_decode_duration_seconds",
292            "kind" => header.component_kind.0.to_string()
293        )
294        .record(start.elapsed().as_secs_f64());
295
296        Ok(ComponentUpdate {
297            network_id: header.network_id,
298            component_kind: header.component_kind,
299            payload,
300            tick: header.tick,
301        })
302    }
303
304    fn max_encoded_size(&self) -> usize {
305        aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE
306    }
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_roundtrip() {
        let encoder = SerdeEncoder::new();
        let event = ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![1, 2, 3, 4],
            tick: 100,
        };

        let mut buffer = [0u8; 1200];
        let bytes_written = encoder.encode(&event, &mut buffer).unwrap();
        assert!(bytes_written > 0);

        let update = encoder.decode(&buffer[..bytes_written]).unwrap();
        assert_eq!(update.network_id, event.network_id);
        assert_eq!(update.component_kind, event.component_kind);
        assert_eq!(update.tick, event.tick);
        assert_eq!(update.payload, event.payload);
    }

    #[test]
    fn test_fragment_roundtrip() {
        let encoder = SerdeEncoder::new();
        let fragment = aetheris_protocol::events::FragmentedEvent {
            message_id: 123,
            fragment_index: 1,
            total_fragments: 5,
            payload: vec![1, 2, 3],
        };

        let event = NetworkEvent::Fragment {
            client_id: ClientId(0),
            fragment: fragment.clone(),
        };

        let output = encoder.encode_event(&event).unwrap();
        let decoded = encoder.decode_event(&output).unwrap();

        if let NetworkEvent::Fragment {
            client_id: _,
            fragment: decoded_fragment,
        } = decoded
        {
            assert_eq!(decoded_fragment.message_id, fragment.message_id);
            assert_eq!(decoded_fragment.fragment_index, fragment.fragment_index);
            assert_eq!(decoded_fragment.total_fragments, fragment.total_fragments);
            assert_eq!(decoded_fragment.payload, fragment.payload);
        } else {
            panic!("Decoded event is not a Fragment: {decoded:?}");
        }
    }

    #[test]
    fn test_buffer_overflow() {
        let encoder = SerdeEncoder::new();
        let event = ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![1, 2, 3, 4],
            tick: 100,
        };

        let mut small_buffer = [0u8; 1];
        let result = encoder.encode(&event, &mut small_buffer);
        match result {
            Err(EncodeError::BufferOverflow { needed, available }) => {
                assert_eq!(available, small_buffer.len());
                assert!(needed > available, "needed ({needed}) must exceed available");
            }
            other => panic!("Expected BufferOverflow, got {other:?}"),
        }
    }

    #[test]
    fn test_encode_event_into_roundtrip() {
        let encoder = SerdeEncoder::new();
        let event = NetworkEvent::Pong { tick: 77 };

        let expected = encoder.encode_event(&event).unwrap();
        let mut buffer = [0u8; 256];
        let written = encoder.encode_event_into(&event, &mut buffer).unwrap();

        // Both encoding entry points must produce byte-identical output.
        assert_eq!(&buffer[..written], expected.as_slice());

        match encoder.decode_event(&buffer[..written]).unwrap() {
            NetworkEvent::Pong { tick } => assert_eq!(tick, 77),
            other => panic!("Decoded event is not a Pong: {other:?}"),
        }
    }

    #[test]
    fn test_encode_event_into_buffer_overflow() {
        let encoder = SerdeEncoder::new();
        let event = NetworkEvent::Pong { tick: 7 };

        // Any wire event needs more than one byte (enum tag + value).
        let mut small_buffer = [0u8; 1];
        let result = encoder.encode_event_into(&event, &mut small_buffer);
        match result {
            Err(EncodeError::BufferOverflow { needed, available }) => {
                assert_eq!(available, small_buffer.len());
                assert!(needed > available, "needed ({needed}) must exceed available");
            }
            other => panic!("Expected BufferOverflow, got {other:?}"),
        }
    }

    #[test]
    fn test_malformed_payload() {
        let encoder = SerdeEncoder::new();
        let garbage = [0xff, 0xff, 0xff, 0xff];
        let result = encoder.decode(&garbage);
        if let Err(EncodeError::MalformedPayload { message, .. }) = result {
            assert!(!message.is_empty());
        } else {
            panic!("Expected MalformedPayload error, got {result:?}");
        }
    }

    proptest! {
        #[test]
        fn test_fuzz_decode(ref bytes in any::<Vec<u8>>()) {
            let encoder = SerdeEncoder::new();
            // Should never panic
            let _ = encoder.decode(bytes);
        }

        #[test]
        fn test_fuzz_roundtrip(
            nid in any::<u64>(),
            kind in any::<u16>(),
            tick in any::<u64>(),
            ref payload in any::<Vec<u8>>()
        ) {
            let encoder = SerdeEncoder::new();
            let event = ReplicationEvent {
                network_id: NetworkId(nid),
                component_kind: ComponentKind(kind),
                payload: payload.clone(),
                tick,
            };

            let mut buffer = vec![0u8; 2048 + payload.len()];
            if let Ok(written) = encoder.encode(&event, &mut buffer) {
                let update = encoder.decode(&buffer[..written])
                    .expect("Round-trip decode failed during fuzzed test");
                assert_eq!(update.network_id, event.network_id);
                assert_eq!(update.component_kind, event.component_kind);
                assert_eq!(update.tick, event.tick);
                assert_eq!(update.payload, event.payload);
            }
        }
    }

    #[test]
    fn test_disconnected_not_serializable() {
        let encoder = SerdeEncoder::new();
        let event = NetworkEvent::Disconnected(ClientId(42));

        // Attempting to encode a local-only event should return an error
        let result = encoder.encode_event(&event);
        assert!(result.is_err());
        if let Err(EncodeError::Io(e)) = result {
            assert!(e.to_string().contains("Cannot encode local-only variant"));
        } else {
            panic!("Expected EncodeError::Io with local-only message, got {result:?}");
        }
    }

    #[test]
    fn test_game_event_roundtrip() {
        use aetheris_protocol::events::GameEvent;
        use std::collections::BTreeMap;

        let encoder = SerdeEncoder::new();

        let test_events = vec![
            GameEvent::AsteroidDepleted {
                network_id: NetworkId(123),
            },
            GameEvent::Possession {
                network_id: NetworkId(456),
            },
            GameEvent::SystemManifest {
                manifest: BTreeMap::new(),
            },
            GameEvent::DamageEvent {
                source: NetworkId(1),
                target: NetworkId(2),
                amount: 10,
            },
            GameEvent::DeathEvent {
                target: NetworkId(3),
            },
            GameEvent::RespawnEvent {
                target: NetworkId(4),
                x: 10.5,
                y: 20.2,
            },
            GameEvent::CargoCollected {
                network_id: NetworkId(5),
                amount: 50,
            },
        ];

        for game_event in test_events {
            let event = NetworkEvent::GameEvent {
                client_id: ClientId(1),
                event: game_event.clone(),
            };

            let output = encoder.encode_event(&event).unwrap();
            let decoded = encoder.decode_event(&output).unwrap();

            if let NetworkEvent::GameEvent {
                client_id,
                event: decoded_event,
            } = decoded
            {
                assert_eq!(
                    client_id,
                    ClientId(0),
                    "Wire decoding should default client_id to 0"
                );
                assert_eq!(
                    decoded_event, game_event,
                    "Roundtrip failed for {game_event:?}"
                );
            } else {
                panic!("Decoded event is not a GameEvent: {decoded:?}");
            }
        }
    }
}