Skip to main content

aetheris_encoder_serde/
serde_encoder.rs

1//! Implementation of the `SerdeEncoder` using `rmp-serde`.
2
3use std::io::Cursor;
4
5use serde::{Deserialize, Serialize};
6
7use aetheris_protocol::error::EncodeError;
8use aetheris_protocol::events::{ComponentUpdate, NetworkEvent, ReplicationEvent, WireEvent};
9use aetheris_protocol::traits::Encoder;
10use aetheris_protocol::types::{ClientId, ComponentKind, NetworkId};
11
12/// Internal header for serialized replication events.
13///
14/// Ensures a stable binary format across different `rmp-serde` configurations.
15#[derive(Debug, Serialize, Deserialize)]
16struct PacketHeader {
17    network_id: NetworkId,
18    component_kind: ComponentKind,
19    tick: u64,
20}
21
/// Stateless `serde`-backed encoder that serializes with `rmp-serde` (`MessagePack`).
///
/// Built for the Phase 1 MVP, where fast iteration is the priority.
/// Replication events are written as a `MessagePack`-encoded header followed
/// directly by the raw component payload bytes; the header length is measured
/// at runtime rather than assumed.
#[derive(Default, Debug)]
pub struct SerdeEncoder;
28
29impl SerdeEncoder {
30    /// Creates a new `SerdeEncoder`.
31    #[must_use]
32    pub fn new() -> Self {
33        Self
34    }
35
36    /// Encodes a `NetworkEvent` into raw bytes for transmission.
37    ///
38    /// # Errors
39    /// Returns `EncodeError` if the event fails to serialize or is a local-only variant.
40    pub fn encode_event(
41        &self,
42        event: &aetheris_protocol::events::NetworkEvent,
43    ) -> Result<Vec<u8>, EncodeError> {
44        let wire_event = Self::to_wire_event(event)?;
45        rmp_serde::to_vec(&wire_event)
46            .map_err(|e| EncodeError::Io(std::io::Error::other(e.to_string())))
47    }
48
49    /// Encodes a `NetworkEvent` into the provided buffer.
50    ///
51    /// # Errors
52    /// Returns `EncodeError::BufferOverflow` if the buffer is too small.
53    pub fn encode_event_into(
54        &self,
55        event: &NetworkEvent,
56        buffer: &mut [u8],
57    ) -> Result<usize, EncodeError> {
58        let wire_event = Self::to_wire_event(event)?;
59        let mut cursor = Cursor::new(&mut *buffer);
60        rmp_serde::encode::write(&mut cursor, &wire_event).map_err(|e| {
61            let err_msg = e.to_string();
62            if err_msg.contains("unexpected end of file") || err_msg.contains("invalid value write")
63            {
64                EncodeError::BufferOverflow {
65                    needed: 256, // Estimate
66                    available: cursor.get_ref().len(),
67                }
68            } else {
69                EncodeError::Io(std::io::Error::other(err_msg))
70            }
71        })?;
72        usize::try_from(cursor.position()).map_err(|_| EncodeError::BufferOverflow {
73            needed: usize::MAX,
74            available: buffer.len(),
75        })
76    }
77
78    fn to_wire_event(event: &NetworkEvent) -> Result<WireEvent, EncodeError> {
79        Ok(match event {
80            NetworkEvent::Ping { tick, .. } => WireEvent::Ping { tick: *tick },
81            NetworkEvent::Pong { tick } => WireEvent::Pong { tick: *tick },
82            NetworkEvent::Auth { session_token } => WireEvent::Auth {
83                session_token: session_token.clone(),
84            },
85            NetworkEvent::Fragment { fragment, .. } => WireEvent::Fragment(fragment.clone()),
86            NetworkEvent::StressTest { count, rotate, .. } => WireEvent::StressTest {
87                count: *count,
88                rotate: *rotate,
89            },
90            NetworkEvent::Spawn {
91                entity_type,
92                x,
93                y,
94                rot,
95                ..
96            } => WireEvent::Spawn {
97                entity_type: *entity_type,
98                x: *x,
99                y: *y,
100                rot: *rot,
101            },
102            NetworkEvent::ClearWorld { .. } => WireEvent::ClearWorld,
103            NetworkEvent::StartSession { .. } => WireEvent::StartSession,
104            NetworkEvent::RequestWorkspaceManifest { .. } => WireEvent::RequestWorkspaceManifest,
105            NetworkEvent::PlatformEvent { event, .. } => WireEvent::PlatformEvent(event.clone()),
106            NetworkEvent::EntitySpawned {
107                network_id, kind, ..
108            } => WireEvent::EntitySpawned {
109                network_id: *network_id,
110                kind: *kind,
111            },
112            NetworkEvent::EntityDespawned { network_id, .. } => WireEvent::EntityDespawned {
113                network_id: *network_id,
114            },
115            NetworkEvent::ReplicationBatch { events, .. } => {
116                WireEvent::ReplicationBatch(events.clone())
117            }
118            // NOTE: When adding new NetworkEvent variants, ensure they are handled here.
119            // Local-only variants should be added to the error arm below.
120            NetworkEvent::ClientConnected(_)
121            | NetworkEvent::ClientDisconnected(_)
122            | NetworkEvent::UnreliableMessage { .. }
123            | NetworkEvent::ReliableMessage { .. }
124            | NetworkEvent::SessionClosed(_)
125            | NetworkEvent::StreamReset(_)
126            | NetworkEvent::Disconnected(_) => {
127                return Err(EncodeError::Io(std::io::Error::other(format!(
128                    "Cannot encode local-only variant as wire event: {event:?}"
129                ))));
130            }
131        })
132    }
133
134    /// Decodes raw bytes into a `NetworkEvent`.
135    ///
136    /// # Errors
137    /// Returns `EncodeError` if the bytes are not a valid `WireEvent`.
138    pub fn decode_event(
139        &self,
140        data: &[u8],
141    ) -> Result<aetheris_protocol::events::NetworkEvent, EncodeError> {
142        let wire_event: WireEvent = rmp_serde::from_slice(data).map_err(|e| {
143            EncodeError::MalformedPayload {
144                offset: 0, // In Phase 1 we don't track exact rmp-serde offset easily
145                message: e.to_string(),
146            }
147        })?;
148
149        Ok(match wire_event {
150            WireEvent::Ping { tick } => NetworkEvent::Ping {
151                client_id: ClientId(0), // Populated by transport/server
152                tick,
153            },
154            WireEvent::Pong { tick } => NetworkEvent::Pong { tick },
155            WireEvent::Auth { session_token } => NetworkEvent::Auth { session_token },
156            WireEvent::Fragment(fragment) => NetworkEvent::Fragment {
157                client_id: ClientId(0),
158                fragment,
159            },
160            WireEvent::StressTest { count, rotate } => NetworkEvent::StressTest {
161                client_id: ClientId(0), // Populated by server
162                count,
163                rotate,
164            },
165            WireEvent::Spawn {
166                entity_type,
167                x,
168                y,
169                rot,
170            } => NetworkEvent::Spawn {
171                client_id: ClientId(0),
172                entity_type,
173                x,
174                y,
175                rot,
176            },
177            WireEvent::ClearWorld => NetworkEvent::ClearWorld {
178                client_id: ClientId(0),
179            },
180            WireEvent::StartSession => NetworkEvent::StartSession {
181                client_id: ClientId(0),
182            },
183            WireEvent::RequestWorkspaceManifest => NetworkEvent::RequestWorkspaceManifest {
184                client_id: ClientId(0),
185            },
186            WireEvent::PlatformEvent(event) => NetworkEvent::PlatformEvent {
187                client_id: ClientId(0),
188                event,
189            },
190            WireEvent::EntitySpawned { network_id, kind } => NetworkEvent::EntitySpawned {
191                client_id: ClientId(0),
192                network_id,
193                kind,
194            },
195            WireEvent::EntityDespawned { network_id } => NetworkEvent::EntityDespawned {
196                client_id: ClientId(0),
197                network_id,
198            },
199            WireEvent::ReplicationBatch(events) => NetworkEvent::ReplicationBatch {
200                client_id: ClientId(0),
201                events,
202            },
203        })
204    }
205}
206
207impl Encoder for SerdeEncoder {
208    fn codec_id(&self) -> u32 {
209        1
210    }
211
212    fn encode_event(&self, event: &NetworkEvent) -> Result<Vec<u8>, EncodeError> {
213        self.encode_event(event)
214    }
215
216    fn encode_event_into(
217        &self,
218        event: &NetworkEvent,
219        buffer: &mut [u8],
220    ) -> Result<usize, EncodeError> {
221        self.encode_event_into(event, buffer)
222    }
223
224    fn decode_event(&self, data: &[u8]) -> Result<NetworkEvent, EncodeError> {
225        self.decode_event(data)
226    }
227
228    fn encode(&self, event: &ReplicationEvent, buffer: &mut [u8]) -> Result<usize, EncodeError> {
229        #[cfg(not(target_arch = "wasm32"))]
230        let start = std::time::Instant::now();
231        let header = PacketHeader {
232            network_id: event.network_id,
233            component_kind: event.component_kind,
234            tick: event.tick,
235        };
236
237        let mut cursor = Cursor::new(buffer);
238        let mut serializer = rmp_serde::Serializer::new(&mut cursor);
239
240        header.serialize(&mut serializer).map_err(|_e| {
241            metrics::counter!("aetheris_encoder_errors_total", "type" => "header_serialize_fail")
242                .increment(1);
243            // If it fails to serialize, it's likely a buffer overflow.
244            EncodeError::BufferOverflow {
245                needed: 32, // PacketHeader is small (~20 bytes)
246                available: cursor.get_ref().len(),
247            }
248        })?;
249
250        let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
251        let payload_len = event.payload.len();
252        let total_needed = header_len + payload_len;
253
254        if total_needed > cursor.get_ref().len() {
255            metrics::counter!("aetheris_encoder_errors_total", "type" => "buffer_overflow")
256                .increment(1);
257            return Err(EncodeError::BufferOverflow {
258                needed: total_needed,
259                available: cursor.get_ref().len(),
260            });
261        }
262
263        // Copy payload manually after the header
264        let slice = cursor.into_inner();
265        slice[header_len..total_needed].copy_from_slice(&event.payload);
266
267        #[allow(clippy::cast_precision_loss)]
268        metrics::histogram!(
269            "aetheris_encoder_payload_size_bytes",
270            "operation" => "encode"
271        )
272        .record(total_needed as f64);
273
274        #[cfg(not(target_arch = "wasm32"))]
275        metrics::histogram!(
276            "aetheris_encoder_encode_duration_seconds",
277            "kind" => event.component_kind.0.to_string()
278        )
279        .record(start.elapsed().as_secs_f64());
280
281        Ok(total_needed)
282    }
283
284    fn decode(&self, buffer: &[u8]) -> Result<ComponentUpdate, EncodeError> {
285        #[cfg(not(target_arch = "wasm32"))]
286        let start = std::time::Instant::now();
287        let mut cursor = Cursor::new(buffer);
288        let mut deserializer = rmp_serde::Deserializer::new(&mut cursor);
289
290        let header = PacketHeader::deserialize(&mut deserializer).map_err(|e| {
291            metrics::counter!("aetheris_encoder_errors_total", "type" => "malformed_payload")
292                .increment(1);
293            EncodeError::MalformedPayload {
294                offset: usize::try_from(cursor.position()).unwrap_or(usize::MAX),
295                message: e.to_string(),
296            }
297        })?;
298
299        let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
300        let payload = buffer
301            .get(header_len..)
302            .ok_or(EncodeError::MalformedPayload {
303                offset: header_len,
304                message: "Payload slice out of bounds".to_string(),
305            })?
306            .to_vec();
307
308        #[allow(clippy::cast_precision_loss)]
309        metrics::histogram!(
310            "aetheris_encoder_payload_size_bytes",
311            "operation" => "decode"
312        )
313        .record(buffer.len() as f64);
314
315        #[cfg(not(target_arch = "wasm32"))]
316        metrics::histogram!(
317            "aetheris_encoder_decode_duration_seconds",
318            "kind" => header.component_kind.0.to_string()
319        )
320        .record(start.elapsed().as_secs_f64());
321
322        Ok(ComponentUpdate {
323            network_id: header.network_id,
324            component_kind: header.component_kind,
325            payload,
326            tick: header.tick,
327        })
328    }
329
330    fn max_encoded_size(&self) -> usize {
331        aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Small replication event shared by the fixed-input tests.
    fn sample_replication_event() -> ReplicationEvent {
        ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![1, 2, 3, 4],
            tick: 100,
        }
    }

    /// A replication event survives `encode` followed by `decode`.
    #[test]
    fn test_roundtrip() {
        let encoder = SerdeEncoder::new();
        let event = sample_replication_event();

        let mut buffer = [0u8; 1200];
        let bytes_written = encoder.encode(&event, &mut buffer).unwrap();
        assert!(bytes_written > 0);

        let encoded = &buffer[..bytes_written];
        let update = encoder.decode(encoded).unwrap();
        assert_eq!(update.network_id, event.network_id);
        assert_eq!(update.component_kind, event.component_kind);
        assert_eq!(update.tick, event.tick);
        assert_eq!(update.payload, event.payload);
    }

    /// A `Fragment` event keeps all fragment fields across the wire format.
    #[test]
    fn test_fragment_roundtrip() {
        let encoder = SerdeEncoder::new();
        let fragment = aetheris_protocol::events::FragmentedEvent {
            message_id: 123,
            fragment_index: 1,
            total_fragments: 5,
            payload: vec![1, 2, 3],
        };
        let event = NetworkEvent::Fragment {
            client_id: aetheris_protocol::types::ClientId(0),
            fragment: fragment.clone(),
        };

        let output = encoder.encode_event(&event).unwrap();
        let decoded = encoder.decode_event(&output).unwrap();

        match &decoded {
            NetworkEvent::Fragment {
                fragment: decoded_fragment,
                ..
            } => {
                assert_eq!(decoded_fragment.message_id, fragment.message_id);
                assert_eq!(decoded_fragment.fragment_index, fragment.fragment_index);
                assert_eq!(decoded_fragment.total_fragments, fragment.total_fragments);
                assert_eq!(decoded_fragment.payload, fragment.payload);
            }
            _ => panic!("Decoded event is not a Fragment: {decoded:?}"),
        }
    }

    /// Encoding into a one-byte buffer reports `BufferOverflow`.
    #[test]
    fn test_buffer_overflow() {
        let encoder = SerdeEncoder::new();
        let event = sample_replication_event();

        let mut small_buffer = [0u8; 1];
        let result = encoder.encode(&event, &mut small_buffer);
        assert!(matches!(result, Err(EncodeError::BufferOverflow { .. })));
    }

    /// Garbage bytes are rejected with a non-empty `MalformedPayload` message.
    #[test]
    fn test_malformed_payload() {
        let encoder = SerdeEncoder::new();
        let garbage = [0xff, 0xff, 0xff, 0xff];

        match encoder.decode(&garbage) {
            Err(EncodeError::MalformedPayload { message, .. }) => {
                assert!(!message.is_empty());
            }
            result => panic!("Expected MalformedPayload error, got {result:?}"),
        }
    }

    proptest! {
        /// Decoding arbitrary bytes may fail but must never panic.
        #[test]
        fn test_fuzz_decode(ref bytes in any::<Vec<u8>>()) {
            let encoder = SerdeEncoder::new();
            // Should never panic
            let _ = encoder.decode(bytes);
        }

        /// Any event that encodes successfully decodes back to the same fields.
        #[test]
        fn test_fuzz_roundtrip(
            nid in any::<u64>(),
            kind in any::<u16>(),
            tick in any::<u64>(),
            ref payload in any::<Vec<u8>>()
        ) {
            let encoder = SerdeEncoder::new();
            let event = ReplicationEvent {
                network_id: NetworkId(nid),
                component_kind: ComponentKind(kind),
                payload: payload.clone(),
                tick,
            };

            // Leave generous headroom for the header on top of the payload.
            let mut buffer = vec![0u8; 2048 + payload.len()];
            if let Ok(written) = encoder.encode(&event, &mut buffer) {
                let update = encoder.decode(&buffer[..written])
                    .expect("Round-trip decode failed during fuzzed test");
                assert_eq!(update.network_id, event.network_id);
                assert_eq!(update.component_kind, event.component_kind);
                assert_eq!(update.tick, event.tick);
                assert_eq!(update.payload, event.payload);
            }
        }
    }

    /// Local-only events are refused by `encode_event` with an `Io` error.
    #[test]
    fn test_disconnected_not_serializable() {
        let encoder = SerdeEncoder::new();
        let event = NetworkEvent::Disconnected(ClientId(42));

        // Attempting to encode a local-only event should return an error
        let result = encoder.encode_event(&event);
        assert!(result.is_err());
        match result {
            Err(EncodeError::Io(e)) => {
                assert!(e.to_string().contains("Cannot encode local-only variant"));
            }
            result => panic!("Expected EncodeError::Io with local-only message, got {result:?}"),
        }
    }

    /// Every listed `PlatformEvent` variant round-trips, with `client_id` reset to 0.
    #[test]
    fn test_platform_event_roundtrip() {
        use aetheris_protocol::events::PlatformEvent;
        use aetheris_protocol::types::NetworkId;
        use std::collections::BTreeMap;

        let encoder = SerdeEncoder::new();

        let test_events = vec![
            PlatformEvent::ResourceExhausted {
                network_id: NetworkId(123),
            },
            PlatformEvent::Possession {
                network_id: NetworkId(456),
            },
            PlatformEvent::WorkspaceManifest {
                manifest: BTreeMap::from([("test_key".to_string(), "test_value".to_string())]),
            },
            PlatformEvent::Interaction {
                source: NetworkId(1),
                target: NetworkId(2),
                amount: 10,
            },
            PlatformEvent::Termination {
                target: NetworkId(3),
            },
            PlatformEvent::Reinitialization {
                target: NetworkId(4),
                x: 10.5,
                y: 20.2,
            },
            PlatformEvent::PayloadCollected {
                network_id: NetworkId(5),
                amount: 50,
            },
        ];

        for platform_event in test_events {
            let event = NetworkEvent::PlatformEvent {
                client_id: ClientId(1),
                event: platform_event.clone(),
            };

            let output = encoder.encode_event(&event).unwrap();
            let decoded = encoder.decode_event(&output).unwrap();

            match &decoded {
                NetworkEvent::PlatformEvent {
                    client_id,
                    event: decoded_event,
                } => {
                    assert_eq!(
                        *client_id,
                        ClientId(0),
                        "Wire decoding should default client_id to 0"
                    );
                    assert_eq!(
                        *decoded_event, platform_event,
                        "Roundtrip failed for {platform_event:?}"
                    );
                }
                _ => panic!("Decoded event is not a PlatformEvent: {decoded:?}"),
            }
        }
    }
}