Skip to main content

aetheris_encoder_serde/
serde_encoder.rs

//! Implementation of the `SerdeEncoder` using `rmp-serde` (`MessagePack`).
2
3use std::io::Cursor;
4
5use serde::{Deserialize, Serialize};
6
7use aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE;
8use aetheris_protocol::error::EncodeError;
9use aetheris_protocol::events::{ComponentUpdate, NetworkEvent, ReplicationEvent, WireEvent};
10use aetheris_protocol::traits::Encoder;
11use aetheris_protocol::types::{ClientId, ComponentKind, NetworkId};
12
13/// Internal header for serialized replication events.
14///
15/// Ensures a stable binary format across different `rmp-serde` configurations.
16#[derive(Debug, Serialize, Deserialize)]
17struct PacketHeader {
18    network_id: NetworkId,
19    component_kind: ComponentKind,
20    tick: u64,
21}
22
/// A `serde`-based encoder that uses `rmp-serde` (`MessagePack`) for binary serialization.
///
/// This implementation targets Phase 1 MVP requirements for rapid iteration.
/// Each replication packet is a `MessagePack`-encoded header (network id,
/// component kind, tick) followed by the raw component payload bytes.
///
/// The header is variable-length, not fixed-size: `MessagePack` writes
/// integers in the smallest width that fits, so small ids and ticks produce
/// a shorter header.
///
/// The type carries no state; `SerdeEncoder::new()` and `SerdeEncoder::default()`
/// produce identical values.
#[derive(Debug, Default)]
pub struct SerdeEncoder;
29
30impl SerdeEncoder {
31    /// Creates a new `SerdeEncoder`.
32    #[must_use]
33    pub fn new() -> Self {
34        Self
35    }
36
37    /// Encodes a `NetworkEvent` into raw bytes for transmission.
38    ///
39    /// # Errors
40    /// Returns `EncodeError` if the event fails to serialize or is a local-only variant.
41    pub fn encode_event(
42        &self,
43        event: &aetheris_protocol::events::NetworkEvent,
44    ) -> Result<Vec<u8>, EncodeError> {
45        let wire_event = match event {
46            NetworkEvent::Ping { tick, .. } if event.is_wire() => WireEvent::Ping { tick: *tick },
47            NetworkEvent::Pong { tick } => WireEvent::Pong { tick: *tick },
48            NetworkEvent::Auth { session_token } => WireEvent::Auth {
49                session_token: session_token.clone(),
50            },
51            NetworkEvent::Fragment { fragment, .. } => WireEvent::Fragment(fragment.clone()),
52            NetworkEvent::StressTest { count, rotate, .. } => WireEvent::StressTest {
53                count: *count,
54                rotate: *rotate,
55            },
56            NetworkEvent::Spawn {
57                entity_type,
58                x,
59                y,
60                rot,
61                ..
62            } => WireEvent::Spawn {
63                entity_type: *entity_type,
64                x: *x,
65                y: *y,
66                rot: *rot,
67            },
68            NetworkEvent::ClearWorld { .. } => WireEvent::ClearWorld,
69            NetworkEvent::StartSession { .. } => WireEvent::StartSession,
70            NetworkEvent::RequestSystemManifest { .. } => WireEvent::RequestSystemManifest,
71            NetworkEvent::GameEvent { event, .. } => WireEvent::GameEvent(event.clone()),
72            NetworkEvent::ReplicationBatch { events, .. } => {
73                let wire_event = WireEvent::ReplicationBatch(events.clone());
74                let encoded = rmp_serde::to_vec(&wire_event)
75                    .map_err(|e| EncodeError::Io(std::io::Error::other(e.to_string())))?;
76
77                if encoded.len() > MAX_SAFE_PAYLOAD_SIZE {
78                    return Err(EncodeError::BufferOverflow {
79                        needed: encoded.len(),
80                        available: MAX_SAFE_PAYLOAD_SIZE,
81                    });
82                }
83                return Ok(encoded);
84            }
85            NetworkEvent::ClientConnected(_)
86            | NetworkEvent::ClientDisconnected(_)
87            | NetworkEvent::UnreliableMessage { .. }
88            | NetworkEvent::ReliableMessage { .. }
89            | NetworkEvent::Ping { .. }
90            | NetworkEvent::SessionClosed(_)
91            | NetworkEvent::StreamReset(_)
92            | NetworkEvent::Disconnected(_) => {
93                return Err(EncodeError::Io(std::io::Error::other(format!(
94                    "Cannot encode local-only variant as wire event: {event:?}"
95                ))));
96            }
97        };
98        rmp_serde::to_vec(&wire_event)
99            .map_err(|e| EncodeError::Io(std::io::Error::other(e.to_string())))
100    }
101
102    /// Decodes raw bytes into a `NetworkEvent`.
103    ///
104    /// # Errors
105    /// Returns `EncodeError` if the bytes are not a valid `WireEvent`.
106    pub fn decode_event(
107        &self,
108        data: &[u8],
109    ) -> Result<aetheris_protocol::events::NetworkEvent, EncodeError> {
110        let wire_event: WireEvent = rmp_serde::from_slice(data).map_err(|e| {
111            EncodeError::MalformedPayload {
112                offset: 0, // In Phase 1 we don't track exact rmp-serde offset easily
113                message: e.to_string(),
114            }
115        })?;
116
117        Ok(match wire_event {
118            WireEvent::Ping { tick } => NetworkEvent::Ping {
119                client_id: ClientId(0), // Populated by transport/server
120                tick,
121            },
122            WireEvent::Pong { tick } => NetworkEvent::Pong { tick },
123            WireEvent::Auth { session_token } => NetworkEvent::Auth { session_token },
124            WireEvent::Fragment(fragment) => NetworkEvent::Fragment {
125                client_id: ClientId(0),
126                fragment,
127            },
128            WireEvent::StressTest { count, rotate } => NetworkEvent::StressTest {
129                client_id: ClientId(0), // Populated by server
130                count,
131                rotate,
132            },
133            WireEvent::Spawn {
134                entity_type,
135                x,
136                y,
137                rot,
138            } => NetworkEvent::Spawn {
139                client_id: ClientId(0),
140                entity_type,
141                x,
142                y,
143                rot,
144            },
145            WireEvent::ClearWorld => NetworkEvent::ClearWorld {
146                client_id: ClientId(0),
147            },
148            WireEvent::StartSession => NetworkEvent::StartSession {
149                client_id: ClientId(0),
150            },
151            WireEvent::RequestSystemManifest => NetworkEvent::RequestSystemManifest {
152                client_id: ClientId(0),
153            },
154            WireEvent::GameEvent(event) => NetworkEvent::GameEvent {
155                client_id: ClientId(0),
156                event,
157            },
158            WireEvent::ReplicationBatch(events) => NetworkEvent::ReplicationBatch {
159                client_id: ClientId(0),
160                events,
161            },
162        })
163    }
164}
165
166impl Encoder for SerdeEncoder {
167    fn codec_id(&self) -> u32 {
168        1
169    }
170
171    fn encode_event(&self, event: &NetworkEvent) -> Result<Vec<u8>, EncodeError> {
172        self.encode_event(event)
173    }
174
175    fn decode_event(&self, data: &[u8]) -> Result<NetworkEvent, EncodeError> {
176        self.decode_event(data)
177    }
178
179    fn encode(&self, event: &ReplicationEvent, buffer: &mut [u8]) -> Result<usize, EncodeError> {
180        #[cfg(not(target_arch = "wasm32"))]
181        let start = std::time::Instant::now();
182        let header = PacketHeader {
183            network_id: event.network_id,
184            component_kind: event.component_kind,
185            tick: event.tick,
186        };
187
188        let mut cursor = Cursor::new(buffer);
189        let mut serializer = rmp_serde::Serializer::new(&mut cursor);
190
191        header.serialize(&mut serializer).map_err(|_e| {
192            metrics::counter!("aetheris_encoder_errors_total", "type" => "header_serialize_fail")
193                .increment(1);
194            // If it fails to serialize, it's likely a buffer overflow.
195            EncodeError::BufferOverflow {
196                needed: 32, // PacketHeader is small (~20 bytes)
197                available: cursor.get_ref().len(),
198            }
199        })?;
200
201        let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
202        let payload_len = event.payload.len();
203        let total_needed = header_len + payload_len;
204
205        if total_needed > cursor.get_ref().len() {
206            metrics::counter!("aetheris_encoder_errors_total", "type" => "buffer_overflow")
207                .increment(1);
208            return Err(EncodeError::BufferOverflow {
209                needed: total_needed,
210                available: cursor.get_ref().len(),
211            });
212        }
213
214        // Copy payload manually after the header
215        let slice = cursor.into_inner();
216        slice[header_len..total_needed].copy_from_slice(&event.payload);
217
218        #[allow(clippy::cast_precision_loss)]
219        metrics::histogram!(
220            "aetheris_encoder_payload_size_bytes",
221            "operation" => "encode"
222        )
223        .record(total_needed as f64);
224
225        #[cfg(not(target_arch = "wasm32"))]
226        metrics::histogram!(
227            "aetheris_encoder_encode_duration_seconds",
228            "kind" => event.component_kind.0.to_string()
229        )
230        .record(start.elapsed().as_secs_f64());
231
232        Ok(total_needed)
233    }
234
235    fn decode(&self, buffer: &[u8]) -> Result<ComponentUpdate, EncodeError> {
236        #[cfg(not(target_arch = "wasm32"))]
237        let start = std::time::Instant::now();
238        let mut cursor = Cursor::new(buffer);
239        let mut deserializer = rmp_serde::Deserializer::new(&mut cursor);
240
241        let header = PacketHeader::deserialize(&mut deserializer).map_err(|e| {
242            metrics::counter!("aetheris_encoder_errors_total", "type" => "malformed_payload")
243                .increment(1);
244            EncodeError::MalformedPayload {
245                offset: usize::try_from(cursor.position()).unwrap_or(usize::MAX),
246                message: e.to_string(),
247            }
248        })?;
249
250        let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
251        let payload = buffer
252            .get(header_len..)
253            .ok_or(EncodeError::MalformedPayload {
254                offset: header_len,
255                message: "Payload slice out of bounds".to_string(),
256            })?
257            .to_vec();
258
259        #[allow(clippy::cast_precision_loss)]
260        metrics::histogram!(
261            "aetheris_encoder_payload_size_bytes",
262            "operation" => "decode"
263        )
264        .record(buffer.len() as f64);
265
266        #[cfg(not(target_arch = "wasm32"))]
267        metrics::histogram!(
268            "aetheris_encoder_decode_duration_seconds",
269            "kind" => header.component_kind.0.to_string()
270        )
271        .record(start.elapsed().as_secs_f64());
272
273        Ok(ComponentUpdate {
274            network_id: header.network_id,
275            component_kind: header.component_kind,
276            payload,
277            tick: header.tick,
278        })
279    }
280
281    fn max_encoded_size(&self) -> usize {
282        aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Small replication event shared by the fixed-input tests.
    fn sample_event() -> ReplicationEvent {
        ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![1, 2, 3, 4],
            tick: 100,
        }
    }

    #[test]
    fn test_roundtrip() {
        let encoder = SerdeEncoder::new();
        let event = sample_event();

        let mut buffer = [0u8; 1200];
        let bytes_written = encoder.encode(&event, &mut buffer).unwrap();
        assert!(bytes_written > 0);

        let update = encoder.decode(&buffer[..bytes_written]).unwrap();
        assert_eq!(update.network_id, event.network_id);
        assert_eq!(update.component_kind, event.component_kind);
        assert_eq!(update.tick, event.tick);
        assert_eq!(update.payload, event.payload);
    }

    #[test]
    fn test_fragment_roundtrip() {
        let encoder = SerdeEncoder::new();
        let fragment = aetheris_protocol::events::FragmentedEvent {
            message_id: 123,
            fragment_index: 1,
            total_fragments: 5,
            payload: vec![1, 2, 3],
        };

        let event = NetworkEvent::Fragment {
            client_id: ClientId(0),
            fragment: fragment.clone(),
        };

        let output = encoder.encode_event(&event).unwrap();
        let decoded = encoder.decode_event(&output).unwrap();

        if let NetworkEvent::Fragment {
            client_id: _,
            fragment: decoded_fragment,
        } = decoded
        {
            assert_eq!(decoded_fragment.message_id, fragment.message_id);
            assert_eq!(decoded_fragment.fragment_index, fragment.fragment_index);
            assert_eq!(decoded_fragment.total_fragments, fragment.total_fragments);
            assert_eq!(decoded_fragment.payload, fragment.payload);
        } else {
            panic!("Decoded event is not a Fragment: {decoded:?}");
        }
    }

    #[test]
    fn test_buffer_overflow() {
        let encoder = SerdeEncoder::new();
        let event = sample_event();

        let mut small_buffer = [0u8; 1];
        let result = encoder.encode(&event, &mut small_buffer);
        match result {
            Err(EncodeError::BufferOverflow { needed, available }) => {
                assert_eq!(available, 1);
                assert!(needed > available, "needed {needed} must exceed available");
            }
            other => panic!("Expected BufferOverflow, got {other:?}"),
        }
    }

    #[test]
    fn test_exact_fit_and_one_byte_short() {
        let encoder = SerdeEncoder::new();
        let event = sample_event();

        let mut scratch = [0u8; 1200];
        let written = encoder.encode(&event, &mut scratch).unwrap();

        // A buffer of exactly the encoded size must succeed and match byte-for-byte.
        let mut exact = vec![0u8; written];
        assert_eq!(encoder.encode(&event, &mut exact).unwrap(), written);
        assert_eq!(&exact[..], &scratch[..written]);

        // One byte short: the header still fits, the payload does not.
        let mut short = vec![0u8; written - 1];
        match encoder.encode(&event, &mut short) {
            Err(EncodeError::BufferOverflow { needed, available }) => {
                assert_eq!(needed, written);
                assert_eq!(available, written - 1);
            }
            other => panic!("Expected BufferOverflow, got {other:?}"),
        }
    }

    #[test]
    fn test_malformed_payload() {
        let encoder = SerdeEncoder::new();
        let garbage = [0xff, 0xff, 0xff, 0xff];
        let result = encoder.decode(&garbage);
        if let Err(EncodeError::MalformedPayload { message, .. }) = result {
            assert!(!message.is_empty());
        } else {
            panic!("Expected MalformedPayload error, got {result:?}");
        }
    }

    #[test]
    fn test_truncated_header_is_malformed() {
        let encoder = SerdeEncoder::new();
        let mut buffer = [0u8; 64];
        let written = encoder.encode(&sample_event(), &mut buffer).unwrap();
        assert!(written > 1);

        for len in [0, 1] {
            let result = encoder.decode(&buffer[..len]);
            assert!(
                matches!(result, Err(EncodeError::MalformedPayload { .. })),
                "truncated to {len} bytes, got {result:?}"
            );
        }
    }

    proptest! {
        #[test]
        fn test_fuzz_decode(ref bytes in any::<Vec<u8>>()) {
            let encoder = SerdeEncoder::new();
            // Should never panic
            let _ = encoder.decode(bytes);
        }

        #[test]
        fn test_fuzz_roundtrip(
            nid in any::<u64>(),
            kind in any::<u16>(),
            tick in any::<u64>(),
            ref payload in any::<Vec<u8>>()
        ) {
            let encoder = SerdeEncoder::new();
            let event = ReplicationEvent {
                network_id: NetworkId(nid),
                component_kind: ComponentKind(kind),
                payload: payload.clone(),
                tick,
            };

            // The buffer always leaves ample room for the header, so encoding
            // must succeed; a failure here is a real regression, not a skip.
            let mut buffer = vec![0u8; 2048 + payload.len()];
            let written = encoder
                .encode(&event, &mut buffer)
                .expect("Encode failed despite an oversized buffer");
            let update = encoder.decode(&buffer[..written])
                .expect("Round-trip decode failed during fuzzed test");
            assert_eq!(update.network_id, event.network_id);
            assert_eq!(update.component_kind, event.component_kind);
            assert_eq!(update.tick, event.tick);
            assert_eq!(update.payload, event.payload);
        }
    }

    #[test]
    fn test_disconnected_not_serializable() {
        let encoder = SerdeEncoder::new();
        let event = NetworkEvent::Disconnected(ClientId(42));

        // Attempting to encode a local-only event should return an error
        let result = encoder.encode_event(&event);
        assert!(result.is_err());
        if let Err(EncodeError::Io(e)) = result {
            assert!(e.to_string().contains("Cannot encode local-only variant"));
        } else {
            panic!("Expected EncodeError::Io with local-only message, got {result:?}");
        }
    }

    #[test]
    fn test_game_event_roundtrip() {
        use aetheris_protocol::events::GameEvent;

        let encoder = SerdeEncoder::new();
        let game_event = GameEvent::AsteroidDepleted {
            network_id: NetworkId(123),
        };
        let event = NetworkEvent::GameEvent {
            client_id: ClientId(1), // Should be masked to 0 on wire and restored on server poll
            event: game_event.clone(),
        };

        let output = encoder.encode_event(&event).unwrap();
        let decoded = encoder.decode_event(&output).unwrap();

        if let NetworkEvent::GameEvent {
            client_id,
            event: decoded_event,
        } = decoded
        {
            assert_eq!(
                client_id,
                ClientId(0),
                "Wire decoding should default client_id to 0"
            );
            match decoded_event {
                GameEvent::AsteroidDepleted { network_id } => {
                    assert_eq!(network_id, NetworkId(123));
                }
                GameEvent::Possession { .. } | GameEvent::SystemManifest { .. } => {
                    panic!("Unexpected event type in roundtrip test");
                }
            }
        } else {
            panic!("Decoded event is not a GameEvent: {decoded:?}");
        }
    }
}