Skip to main content

aetheris_encoder_serde/
serde_encoder.rs

1//! Implementation of the `SerdeEncoder` using `rmp-serde`.
2
3use std::io::Cursor;
4
5use serde::{Deserialize, Serialize};
6
7use aetheris_protocol::error::EncodeError;
8use aetheris_protocol::events::{ComponentUpdate, NetworkEvent, ReplicationEvent, WireEvent};
9use aetheris_protocol::traits::Encoder;
10use aetheris_protocol::types::{ClientId, ComponentKind, NetworkId};
11
/// Internal header for serialized replication events.
///
/// Ensures a stable binary format across different `rmp-serde` configurations.
///
/// `Encoder::encode` serializes this struct at the start of the output buffer and copies the
/// raw payload bytes directly after it; `Encoder::decode` deserializes it from the front of
/// the input and treats every remaining byte as the payload.
#[derive(Debug, Serialize, Deserialize)]
struct PacketHeader {
    /// Copied from `ReplicationEvent::network_id` on encode and into
    /// `ComponentUpdate::network_id` on decode.
    network_id: NetworkId,
    /// Copied from `ReplicationEvent::component_kind` on encode and into
    /// `ComponentUpdate::component_kind` on decode.
    component_kind: ComponentKind,
    /// Copied from `ReplicationEvent::tick` on encode and into `ComponentUpdate::tick` on decode.
    tick: u64,
}
21
/// A `serde`-based encoder that uses `rmp-serde` (`MessagePack`) for binary serialization.
///
/// This implementation targets Phase 1 MVP requirements for rapid iteration.
/// Replication events are written as a `MessagePack`-serialized `PacketHeader` followed by
/// the raw component payload. The header length is not assumed to be constant: both
/// `encode` and `decode` take it from the cursor position once the header has been
/// (de)serialized.
#[derive(Debug, Default)]
pub struct SerdeEncoder;
28
29impl SerdeEncoder {
30    /// Creates a new `SerdeEncoder`.
31    #[must_use]
32    pub fn new() -> Self {
33        Self
34    }
35
36    /// Encodes a `NetworkEvent` into raw bytes for transmission.
37    ///
38    /// # Errors
39    /// Returns `EncodeError` if the event fails to serialize or is a local-only variant.
40    pub fn encode_event(
41        &self,
42        event: &aetheris_protocol::events::NetworkEvent,
43    ) -> Result<Vec<u8>, EncodeError> {
44        let wire_event = match event {
45            NetworkEvent::Ping { tick, .. } if event.is_wire() => WireEvent::Ping { tick: *tick },
46            NetworkEvent::Pong { tick } => WireEvent::Pong { tick: *tick },
47            NetworkEvent::Auth { session_token } => WireEvent::Auth {
48                session_token: session_token.clone(),
49            },
50            NetworkEvent::Fragment { fragment, .. } => WireEvent::Fragment(fragment.clone()),
51            NetworkEvent::StressTest { count, rotate, .. } => WireEvent::StressTest {
52                count: *count,
53                rotate: *rotate,
54            },
55            NetworkEvent::Spawn {
56                entity_type,
57                x,
58                y,
59                rot,
60                ..
61            } => WireEvent::Spawn {
62                entity_type: *entity_type,
63                x: *x,
64                y: *y,
65                rot: *rot,
66            },
67            NetworkEvent::ClearWorld { .. } => WireEvent::ClearWorld,
68            NetworkEvent::StartSession { .. } => WireEvent::StartSession,
69            NetworkEvent::RequestSystemManifest { .. } => WireEvent::RequestSystemManifest,
70            NetworkEvent::GameEvent { event, .. } => WireEvent::GameEvent(event.clone()),
71            NetworkEvent::ClientConnected(_)
72            | NetworkEvent::ClientDisconnected(_)
73            | NetworkEvent::UnreliableMessage { .. }
74            | NetworkEvent::ReliableMessage { .. }
75            | NetworkEvent::Ping { .. }
76            | NetworkEvent::SessionClosed(_)
77            | NetworkEvent::StreamReset(_)
78            | NetworkEvent::Disconnected(_) => {
79                return Err(EncodeError::Io(std::io::Error::other(format!(
80                    "Cannot encode local-only variant as wire event: {event:?}"
81                ))));
82            }
83        };
84        rmp_serde::to_vec(&wire_event)
85            .map_err(|e| EncodeError::Io(std::io::Error::other(e.to_string())))
86    }
87
88    /// Decodes raw bytes into a `NetworkEvent`.
89    ///
90    /// # Errors
91    /// Returns `EncodeError` if the bytes are not a valid `WireEvent`.
92    pub fn decode_event(
93        &self,
94        data: &[u8],
95    ) -> Result<aetheris_protocol::events::NetworkEvent, EncodeError> {
96        let wire_event: WireEvent = rmp_serde::from_slice(data).map_err(|e| {
97            EncodeError::MalformedPayload {
98                offset: 0, // In Phase 1 we don't track exact rmp-serde offset easily
99                message: e.to_string(),
100            }
101        })?;
102
103        Ok(match wire_event {
104            WireEvent::Ping { tick } => NetworkEvent::Ping {
105                client_id: ClientId(0), // Populated by transport/server
106                tick,
107            },
108            WireEvent::Pong { tick } => NetworkEvent::Pong { tick },
109            WireEvent::Auth { session_token } => NetworkEvent::Auth { session_token },
110            WireEvent::Fragment(fragment) => NetworkEvent::Fragment {
111                client_id: ClientId(0),
112                fragment,
113            },
114            WireEvent::StressTest { count, rotate } => NetworkEvent::StressTest {
115                client_id: ClientId(0), // Populated by server
116                count,
117                rotate,
118            },
119            WireEvent::Spawn {
120                entity_type,
121                x,
122                y,
123                rot,
124            } => NetworkEvent::Spawn {
125                client_id: ClientId(0),
126                entity_type,
127                x,
128                y,
129                rot,
130            },
131            WireEvent::ClearWorld => NetworkEvent::ClearWorld {
132                client_id: ClientId(0),
133            },
134            WireEvent::StartSession => NetworkEvent::StartSession {
135                client_id: ClientId(0),
136            },
137            WireEvent::RequestSystemManifest => NetworkEvent::RequestSystemManifest {
138                client_id: ClientId(0),
139            },
140            WireEvent::GameEvent(event) => NetworkEvent::GameEvent {
141                client_id: ClientId(0),
142                event,
143            },
144        })
145    }
146}
147
148impl Encoder for SerdeEncoder {
149    fn codec_id(&self) -> u32 {
150        1
151    }
152
153    fn encode_event(&self, event: &NetworkEvent) -> Result<Vec<u8>, EncodeError> {
154        self.encode_event(event)
155    }
156
157    fn decode_event(&self, data: &[u8]) -> Result<NetworkEvent, EncodeError> {
158        self.decode_event(data)
159    }
160
161    fn encode(&self, event: &ReplicationEvent, buffer: &mut [u8]) -> Result<usize, EncodeError> {
162        #[cfg(not(target_arch = "wasm32"))]
163        let start = std::time::Instant::now();
164        let header = PacketHeader {
165            network_id: event.network_id,
166            component_kind: event.component_kind,
167            tick: event.tick,
168        };
169
170        let mut cursor = Cursor::new(buffer);
171        let mut serializer = rmp_serde::Serializer::new(&mut cursor);
172
173        header.serialize(&mut serializer).map_err(|_e| {
174            metrics::counter!("aetheris_encoder_errors_total", "type" => "header_serialize_fail")
175                .increment(1);
176            // If it fails to serialize, it's likely a buffer overflow.
177            EncodeError::BufferOverflow {
178                needed: 32, // PacketHeader is small (~20 bytes)
179                available: cursor.get_ref().len(),
180            }
181        })?;
182
183        let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
184        let payload_len = event.payload.len();
185        let total_needed = header_len + payload_len;
186
187        if total_needed > cursor.get_ref().len() {
188            metrics::counter!("aetheris_encoder_errors_total", "type" => "buffer_overflow")
189                .increment(1);
190            return Err(EncodeError::BufferOverflow {
191                needed: total_needed,
192                available: cursor.get_ref().len(),
193            });
194        }
195
196        // Copy payload manually after the header
197        let slice = cursor.into_inner();
198        slice[header_len..total_needed].copy_from_slice(&event.payload);
199
200        #[allow(clippy::cast_precision_loss)]
201        metrics::histogram!(
202            "aetheris_encoder_payload_size_bytes",
203            "operation" => "encode"
204        )
205        .record(total_needed as f64);
206
207        #[cfg(not(target_arch = "wasm32"))]
208        metrics::histogram!(
209            "aetheris_encoder_encode_duration_seconds",
210            "kind" => event.component_kind.0.to_string()
211        )
212        .record(start.elapsed().as_secs_f64());
213
214        Ok(total_needed)
215    }
216
217    fn decode(&self, buffer: &[u8]) -> Result<ComponentUpdate, EncodeError> {
218        #[cfg(not(target_arch = "wasm32"))]
219        let start = std::time::Instant::now();
220        let mut cursor = Cursor::new(buffer);
221        let mut deserializer = rmp_serde::Deserializer::new(&mut cursor);
222
223        let header = PacketHeader::deserialize(&mut deserializer).map_err(|e| {
224            metrics::counter!("aetheris_encoder_errors_total", "type" => "malformed_payload")
225                .increment(1);
226            EncodeError::MalformedPayload {
227                offset: usize::try_from(cursor.position()).unwrap_or(usize::MAX),
228                message: e.to_string(),
229            }
230        })?;
231
232        let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
233        let payload = buffer
234            .get(header_len..)
235            .ok_or(EncodeError::MalformedPayload {
236                offset: header_len,
237                message: "Payload slice out of bounds".to_string(),
238            })?
239            .to_vec();
240
241        #[allow(clippy::cast_precision_loss)]
242        metrics::histogram!(
243            "aetheris_encoder_payload_size_bytes",
244            "operation" => "decode"
245        )
246        .record(buffer.len() as f64);
247
248        #[cfg(not(target_arch = "wasm32"))]
249        metrics::histogram!(
250            "aetheris_encoder_decode_duration_seconds",
251            "kind" => header.component_kind.0.to_string()
252        )
253        .record(start.elapsed().as_secs_f64());
254
255        Ok(ComponentUpdate {
256            network_id: header.network_id,
257            component_kind: header.component_kind,
258            payload,
259            tick: header.tick,
260        })
261    }
262
263    fn max_encoded_size(&self) -> usize {
264        aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// A replication event written by `encode` is recovered unchanged by `decode`.
    #[test]
    fn test_roundtrip() {
        let encoder = SerdeEncoder::new();
        let event = ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![1, 2, 3, 4],
            tick: 100,
        };

        let mut buffer = [0u8; 1200];
        let bytes_written = encoder.encode(&event, &mut buffer).unwrap();
        assert!(bytes_written > 0);

        let update = encoder.decode(&buffer[..bytes_written]).unwrap();
        assert_eq!(update.network_id, event.network_id);
        assert_eq!(update.component_kind, event.component_kind);
        assert_eq!(update.tick, event.tick);
        assert_eq!(update.payload, event.payload);
    }

    /// A `Fragment` event survives `encode_event` / `decode_event` with all fields intact.
    #[test]
    fn test_fragment_roundtrip() {
        let encoder = SerdeEncoder::new();
        let fragment = aetheris_protocol::events::FragmentedEvent {
            message_id: 123,
            fragment_index: 1,
            total_fragments: 5,
            payload: vec![1, 2, 3],
        };

        let event = NetworkEvent::Fragment {
            client_id: aetheris_protocol::types::ClientId(0),
            fragment: fragment.clone(),
        };

        let output = encoder.encode_event(&event).unwrap();
        let decoded = encoder.decode_event(&output).unwrap();

        if let NetworkEvent::Fragment {
            client_id: _,
            fragment: decoded_fragment,
        } = decoded
        {
            assert_eq!(decoded_fragment.message_id, fragment.message_id);
            assert_eq!(decoded_fragment.fragment_index, fragment.fragment_index);
            assert_eq!(decoded_fragment.total_fragments, fragment.total_fragments);
            assert_eq!(decoded_fragment.payload, fragment.payload);
        } else {
            panic!("Decoded event is not a Fragment: {decoded:?}");
        }
    }

    /// A buffer too small for even the header yields `BufferOverflow`.
    #[test]
    fn test_buffer_overflow() {
        let encoder = SerdeEncoder::new();
        let event = ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![1, 2, 3, 4],
            tick: 100,
        };

        let mut small_buffer = [0u8; 1];
        let result = encoder.encode(&event, &mut small_buffer);
        assert!(matches!(result, Err(EncodeError::BufferOverflow { .. })));
    }

    /// A buffer that fits the header but not the payload yields `BufferOverflow`, and the
    /// reported `needed` exceeds the reported `available`.
    #[test]
    fn test_payload_overflow() {
        let encoder = SerdeEncoder::new();
        let event = ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![0u8; 128],
            tick: 100,
        };

        let mut buffer = [0u8; 64];
        match encoder.encode(&event, &mut buffer) {
            Err(EncodeError::BufferOverflow { needed, available }) => {
                assert_eq!(available, 64);
                assert!(needed > available);
            }
            other => panic!("Expected BufferOverflow error, got {other:?}"),
        }
    }

    /// Bytes that are not a valid header produce `MalformedPayload` with a non-empty message.
    #[test]
    fn test_malformed_payload() {
        let encoder = SerdeEncoder::new();
        let garbage = [0xff, 0xff, 0xff, 0xff];
        let result = encoder.decode(&garbage);
        if let Err(EncodeError::MalformedPayload { message, .. }) = result {
            assert!(!message.is_empty());
        } else {
            panic!("Expected MalformedPayload error, got {result:?}");
        }
    }

    proptest! {
        #[test]
        fn test_fuzz_decode(ref bytes in any::<Vec<u8>>()) {
            let encoder = SerdeEncoder::new();
            // Should never panic
            let _ = encoder.decode(bytes);
        }

        #[test]
        fn test_fuzz_roundtrip(
            nid in any::<u64>(),
            kind in any::<u16>(),
            tick in any::<u64>(),
            ref payload in any::<Vec<u8>>()
        ) {
            let encoder = SerdeEncoder::new();
            let event = ReplicationEvent {
                network_id: NetworkId(nid),
                component_kind: ComponentKind(kind),
                payload: payload.clone(),
                tick,
            };

            // The buffer has 2048 bytes of headroom beyond the payload, so encoding must
            // succeed. Requiring success keeps an always-failing encoder from passing this
            // test vacuously.
            let mut buffer = vec![0u8; 2048 + payload.len()];
            let written = encoder
                .encode(&event, &mut buffer)
                .expect("Encode failed during fuzzed test despite an oversized buffer");
            let update = encoder.decode(&buffer[..written])
                .expect("Round-trip decode failed during fuzzed test");
            assert_eq!(update.network_id, event.network_id);
            assert_eq!(update.component_kind, event.component_kind);
            assert_eq!(update.tick, event.tick);
            assert_eq!(update.payload, event.payload);
        }
    }

    /// Local-only variants such as `Disconnected` are rejected by `encode_event`.
    #[test]
    fn test_disconnected_not_serializable() {
        let encoder = SerdeEncoder::new();
        let event = NetworkEvent::Disconnected(ClientId(42));

        // Attempting to encode a local-only event should return an error
        let result = encoder.encode_event(&event);
        assert!(result.is_err());
        if let Err(EncodeError::Io(e)) = result {
            assert!(e.to_string().contains("Cannot encode local-only variant"));
        } else {
            panic!("Expected EncodeError::Io with local-only message, got {result:?}");
        }
    }

    /// A `GameEvent` round-trips, and the decoded `client_id` is reset to `ClientId(0)`.
    #[test]
    fn test_game_event_roundtrip() {
        use aetheris_protocol::events::GameEvent;
        use aetheris_protocol::types::NetworkId;

        let encoder = SerdeEncoder::new();
        let game_event = GameEvent::AsteroidDepleted {
            network_id: NetworkId(123),
        };
        let event = NetworkEvent::GameEvent {
            client_id: ClientId(1), // Should be masked to 0 on wire and restored on server poll
            event: game_event.clone(),
        };

        let output = encoder.encode_event(&event).unwrap();
        let decoded = encoder.decode_event(&output).unwrap();

        if let NetworkEvent::GameEvent {
            client_id,
            event: decoded_event,
        } = decoded
        {
            assert_eq!(
                client_id,
                ClientId(0),
                "Wire decoding should default client_id to 0"
            );
            match decoded_event {
                GameEvent::AsteroidDepleted { network_id } => {
                    assert_eq!(network_id, NetworkId(123));
                }
                GameEvent::Possession { .. } | GameEvent::SystemManifest { .. } => {
                    panic!("Unexpected event type in roundtrip test");
                }
            }
        } else {
            panic!("Decoded event is not a GameEvent: {decoded:?}");
        }
    }
}