Skip to main content

aetheris_encoder_serde/
serde_encoder.rs

1//! Implementation of the `SerdeEncoder` using `rmp-serde`.
2
3use std::io::Cursor;
4
5use serde::{Deserialize, Serialize};
6
7use aetheris_protocol::error::EncodeError;
8use aetheris_protocol::events::{ComponentUpdate, NetworkEvent, ReplicationEvent, WireEvent};
9use aetheris_protocol::traits::Encoder;
10use aetheris_protocol::types::{ClientId, ComponentKind, NetworkId};
11
12/// Internal header for serialized replication events.
13///
14/// Ensures a stable binary format across different `rmp-serde` configurations.
15#[derive(Debug, Serialize, Deserialize)]
16struct PacketHeader {
17    network_id: NetworkId,
18    component_kind: ComponentKind,
19    tick: u64,
20}
21
/// A `serde`-based encoder that uses `rmp-serde` (`MessagePack`) for binary serialization.
///
/// This implementation targets Phase 1 MVP requirements for rapid iteration.
///
/// Replication updates are framed as a `MessagePack`-encoded header (network id, component kind,
/// tick) followed by the raw component payload bytes. The header is *not* fixed-size. Its
/// encoded length depends on the values being written, and `decode` recovers it from the number
/// of bytes the deserializer consumed.
///
/// The encoder holds no state, so it is `Copy`, and `Default` yields a ready-to-use instance.
#[derive(Debug, Default, Clone, Copy)]
pub struct SerdeEncoder;
28
29impl SerdeEncoder {
30    /// Creates a new `SerdeEncoder`.
31    #[must_use]
32    pub fn new() -> Self {
33        Self
34    }
35
36    /// Encodes a `NetworkEvent` into raw bytes for transmission.
37    ///
38    /// # Errors
39    /// Returns `EncodeError` if the event fails to serialize or is a local-only variant.
40    pub fn encode_event(
41        &self,
42        event: &aetheris_protocol::events::NetworkEvent,
43    ) -> Result<Vec<u8>, EncodeError> {
44        let wire_event = match event {
45            NetworkEvent::Ping { tick, .. } if event.is_wire() => WireEvent::Ping { tick: *tick },
46            NetworkEvent::Pong { tick } => WireEvent::Pong { tick: *tick },
47            NetworkEvent::Auth { session_token } => WireEvent::Auth {
48                session_token: session_token.clone(),
49            },
50            NetworkEvent::Fragment { fragment, .. } => WireEvent::Fragment(fragment.clone()),
51            NetworkEvent::StressTest { count, rotate, .. } => WireEvent::StressTest {
52                count: *count,
53                rotate: *rotate,
54            },
55            NetworkEvent::Spawn {
56                entity_type,
57                x,
58                y,
59                rot,
60                ..
61            } => WireEvent::Spawn {
62                entity_type: *entity_type,
63                x: *x,
64                y: *y,
65                rot: *rot,
66            },
67            NetworkEvent::ClearWorld { .. } => WireEvent::ClearWorld,
68            NetworkEvent::ClientConnected(_)
69            | NetworkEvent::ClientDisconnected(_)
70            | NetworkEvent::UnreliableMessage { .. }
71            | NetworkEvent::ReliableMessage { .. }
72            | NetworkEvent::Ping { .. }
73            | NetworkEvent::SessionClosed(_)
74            | NetworkEvent::StreamReset(_)
75            | NetworkEvent::Disconnected(_) => {
76                return Err(EncodeError::Io(std::io::Error::other(format!(
77                    "Cannot encode local-only variant as wire event: {event:?}"
78                ))));
79            }
80        };
81        rmp_serde::to_vec(&wire_event)
82            .map_err(|e| EncodeError::Io(std::io::Error::other(e.to_string())))
83    }
84
85    /// Decodes raw bytes into a `NetworkEvent`.
86    ///
87    /// # Errors
88    /// Returns `EncodeError` if the bytes are not a valid `WireEvent`.
89    pub fn decode_event(
90        &self,
91        data: &[u8],
92    ) -> Result<aetheris_protocol::events::NetworkEvent, EncodeError> {
93        let wire_event: WireEvent = rmp_serde::from_slice(data).map_err(|e| {
94            EncodeError::MalformedPayload {
95                offset: 0, // In Phase 1 we don't track exact rmp-serde offset easily
96                message: e.to_string(),
97            }
98        })?;
99
100        Ok(match wire_event {
101            WireEvent::Ping { tick } => NetworkEvent::Ping {
102                client_id: ClientId(0), // Populated by transport/server
103                tick,
104            },
105            WireEvent::Pong { tick } => NetworkEvent::Pong { tick },
106            WireEvent::Auth { session_token } => NetworkEvent::Auth { session_token },
107            WireEvent::Fragment(fragment) => NetworkEvent::Fragment {
108                client_id: ClientId(0),
109                fragment,
110            },
111            WireEvent::StressTest { count, rotate } => NetworkEvent::StressTest {
112                client_id: ClientId(0), // Populated by server
113                count,
114                rotate,
115            },
116            WireEvent::Spawn {
117                entity_type,
118                x,
119                y,
120                rot,
121            } => NetworkEvent::Spawn {
122                client_id: ClientId(0),
123                entity_type,
124                x,
125                y,
126                rot,
127            },
128            WireEvent::ClearWorld => NetworkEvent::ClearWorld {
129                client_id: ClientId(0),
130            },
131        })
132    }
133}
134
135impl Encoder for SerdeEncoder {
136    fn codec_id(&self) -> u32 {
137        1
138    }
139
140    fn encode_event(&self, event: &NetworkEvent) -> Result<Vec<u8>, EncodeError> {
141        self.encode_event(event)
142    }
143
144    fn decode_event(&self, data: &[u8]) -> Result<NetworkEvent, EncodeError> {
145        self.decode_event(data)
146    }
147
148    fn encode(&self, event: &ReplicationEvent, buffer: &mut [u8]) -> Result<usize, EncodeError> {
149        #[cfg(not(target_arch = "wasm32"))]
150        let start = std::time::Instant::now();
151        let header = PacketHeader {
152            network_id: event.network_id,
153            component_kind: event.component_kind,
154            tick: event.tick,
155        };
156
157        let mut cursor = Cursor::new(buffer);
158        let mut serializer = rmp_serde::Serializer::new(&mut cursor);
159
160        header.serialize(&mut serializer).map_err(|_e| {
161            metrics::counter!("aetheris_encoder_errors_total", "type" => "header_serialize_fail")
162                .increment(1);
163            // If it fails to serialize, it's likely a buffer overflow.
164            EncodeError::BufferOverflow {
165                needed: 32, // PacketHeader is small (~20 bytes)
166                available: cursor.get_ref().len(),
167            }
168        })?;
169
170        let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
171        let payload_len = event.payload.len();
172        let total_needed = header_len + payload_len;
173
174        if total_needed > cursor.get_ref().len() {
175            metrics::counter!("aetheris_encoder_errors_total", "type" => "buffer_overflow")
176                .increment(1);
177            return Err(EncodeError::BufferOverflow {
178                needed: total_needed,
179                available: cursor.get_ref().len(),
180            });
181        }
182
183        // Copy payload manually after the header
184        let slice = cursor.into_inner();
185        slice[header_len..total_needed].copy_from_slice(&event.payload);
186
187        #[allow(clippy::cast_precision_loss)]
188        metrics::histogram!(
189            "aetheris_encoder_payload_size_bytes",
190            "operation" => "encode"
191        )
192        .record(total_needed as f64);
193
194        #[cfg(not(target_arch = "wasm32"))]
195        metrics::histogram!(
196            "aetheris_encoder_encode_duration_seconds",
197            "kind" => event.component_kind.0.to_string()
198        )
199        .record(start.elapsed().as_secs_f64());
200
201        Ok(total_needed)
202    }
203
204    fn decode(&self, buffer: &[u8]) -> Result<ComponentUpdate, EncodeError> {
205        #[cfg(not(target_arch = "wasm32"))]
206        let start = std::time::Instant::now();
207        let mut cursor = Cursor::new(buffer);
208        let mut deserializer = rmp_serde::Deserializer::new(&mut cursor);
209
210        let header = PacketHeader::deserialize(&mut deserializer).map_err(|e| {
211            metrics::counter!("aetheris_encoder_errors_total", "type" => "malformed_payload")
212                .increment(1);
213            EncodeError::MalformedPayload {
214                offset: usize::try_from(cursor.position()).unwrap_or(usize::MAX),
215                message: e.to_string(),
216            }
217        })?;
218
219        let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
220        let payload = buffer
221            .get(header_len..)
222            .ok_or(EncodeError::MalformedPayload {
223                offset: header_len,
224                message: "Payload slice out of bounds".to_string(),
225            })?
226            .to_vec();
227
228        #[allow(clippy::cast_precision_loss)]
229        metrics::histogram!(
230            "aetheris_encoder_payload_size_bytes",
231            "operation" => "decode"
232        )
233        .record(buffer.len() as f64);
234
235        #[cfg(not(target_arch = "wasm32"))]
236        metrics::histogram!(
237            "aetheris_encoder_decode_duration_seconds",
238            "kind" => header.component_kind.0.to_string()
239        )
240        .record(start.elapsed().as_secs_f64());
241
242        Ok(ComponentUpdate {
243            network_id: header.network_id,
244            component_kind: header.component_kind,
245            payload,
246            tick: header.tick,
247        })
248    }
249
250    fn max_encoded_size(&self) -> usize {
251        aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Encoding then decoding a replication event preserves every field.
    #[test]
    fn test_roundtrip() {
        let encoder = SerdeEncoder::new();
        let event = ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![1, 2, 3, 4],
            tick: 100,
        };

        let mut buffer = [0u8; 1200];
        let bytes_written = encoder.encode(&event, &mut buffer).unwrap();
        // Something must precede the payload: the header is never empty.
        assert!(bytes_written > event.payload.len());

        let update = encoder.decode(&buffer[..bytes_written]).unwrap();
        assert_eq!(update.network_id, event.network_id);
        assert_eq!(update.component_kind, event.component_kind);
        assert_eq!(update.tick, event.tick);
        assert_eq!(update.payload, event.payload);
    }

    /// A fragment survives `encode_event` / `decode_event` unchanged.
    #[test]
    fn test_fragment_roundtrip() {
        let encoder = SerdeEncoder::new();
        let fragment = aetheris_protocol::events::FragmentedEvent {
            message_id: 123,
            fragment_index: 1,
            total_fragments: 5,
            payload: vec![1, 2, 3],
        };

        let event = NetworkEvent::Fragment {
            client_id: ClientId(0),
            fragment: fragment.clone(),
        };

        let output = encoder.encode_event(&event).unwrap();
        let decoded = encoder.decode_event(&output).unwrap();

        if let NetworkEvent::Fragment {
            client_id: _,
            fragment: decoded_fragment,
        } = decoded
        {
            assert_eq!(decoded_fragment.message_id, fragment.message_id);
            assert_eq!(decoded_fragment.fragment_index, fragment.fragment_index);
            assert_eq!(decoded_fragment.total_fragments, fragment.total_fragments);
            assert_eq!(decoded_fragment.payload, fragment.payload);
        } else {
            panic!("Decoded event is not a Fragment: {decoded:?}");
        }
    }

    /// An `Auth` event round-trips with its session token intact.
    #[test]
    fn test_auth_event_roundtrip() {
        let encoder = SerdeEncoder::new();
        let event = NetworkEvent::Auth {
            session_token: "token".to_string(),
        };

        let bytes = encoder.encode_event(&event).unwrap();
        match encoder.decode_event(&bytes).unwrap() {
            NetworkEvent::Auth { session_token } => assert_eq!(session_token, "token"),
            other => panic!("Decoded event is not Auth: {other:?}"),
        }
    }

    /// A buffer too small for header + payload is rejected with `BufferOverflow`.
    #[test]
    fn test_buffer_overflow() {
        let encoder = SerdeEncoder::new();
        let event = ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![1, 2, 3, 4],
            tick: 100,
        };

        let mut small_buffer = [0u8; 1];
        let result = encoder.encode(&event, &mut small_buffer);
        assert!(matches!(result, Err(EncodeError::BufferOverflow { .. })));
    }

    /// Bytes that are not a valid header produce `MalformedPayload` with a message.
    #[test]
    fn test_malformed_payload() {
        let encoder = SerdeEncoder::new();
        let garbage = [0xff, 0xff, 0xff, 0xff];
        let result = encoder.decode(&garbage);
        if let Err(EncodeError::MalformedPayload { message, .. }) = result {
            assert!(!message.is_empty());
        } else {
            panic!("Expected MalformedPayload error, got {result:?}");
        }
    }

    /// An empty buffer cannot contain a header.
    #[test]
    fn test_decode_empty_buffer() {
        let encoder = SerdeEncoder::new();
        assert!(matches!(
            encoder.decode(&[]),
            Err(EncodeError::MalformedPayload { .. })
        ));
    }

    proptest! {
        #[test]
        fn test_fuzz_decode(ref bytes in any::<Vec<u8>>()) {
            let encoder = SerdeEncoder::new();
            // Arbitrary input must yield `Ok` or `Err`, never a panic.
            let _ = encoder.decode(bytes);
        }

        #[test]
        fn test_fuzz_roundtrip(
            nid in any::<u64>(),
            kind in any::<u16>(),
            tick in any::<u64>(),
            ref payload in any::<Vec<u8>>()
        ) {
            let encoder = SerdeEncoder::new();
            let event = ReplicationEvent {
                network_id: NetworkId(nid),
                component_kind: ComponentKind(kind),
                payload: payload.clone(),
                tick,
            };

            // The buffer has far more headroom than any header needs, so encoding must
            // succeed. Silently skipping failures would hide a regression.
            let mut buffer = vec![0u8; 2048 + payload.len()];
            let written = encoder
                .encode(&event, &mut buffer)
                .expect("encode must succeed with ample buffer headroom");
            let update = encoder.decode(&buffer[..written])
                .expect("Round-trip decode failed during fuzzed test");
            prop_assert_eq!(&update.network_id, &event.network_id);
            prop_assert_eq!(&update.component_kind, &event.component_kind);
            prop_assert_eq!(&update.tick, &event.tick);
            prop_assert_eq!(&update.payload, &event.payload);
        }
    }

    /// Local-only events are refused with a descriptive `Io` error.
    #[test]
    fn test_disconnected_not_serializable() {
        let encoder = SerdeEncoder::new();
        let event = NetworkEvent::Disconnected(ClientId(42));

        match encoder.encode_event(&event) {
            Err(EncodeError::Io(e)) => {
                assert!(e.to_string().contains("Cannot encode local-only variant"));
            }
            other => {
                panic!("Expected EncodeError::Io with local-only message, got {other:?}")
            }
        }
    }

    /// Local-only variants are rejected while wire variants encode successfully.
    #[test]
    fn test_wire_event_exclusivity() {
        let encoder = SerdeEncoder::new();

        // `ClientConnected` is produced locally by the transport and never crosses the wire.
        let event = NetworkEvent::ClientConnected(ClientId(1));
        assert!(encoder.encode_event(&event).is_err());

        // `Auth` is a wire event and must encode.
        let auth = NetworkEvent::Auth {
            session_token: "token".to_string(),
        };
        assert!(encoder.encode_event(&auth).is_ok());
    }
}
405}