1use std::io::Cursor;
4
5use serde::{Deserialize, Serialize};
6
7use aetheris_protocol::error::EncodeError;
8use aetheris_protocol::events::{ComponentUpdate, NetworkEvent, ReplicationEvent, WireEvent};
9use aetheris_protocol::traits::Encoder;
10use aetheris_protocol::types::{ClientId, ComponentKind, NetworkId};
11
/// Metadata that [`Encoder::encode`] serializes in front of the raw component
/// payload and that [`Encoder::decode`] reads back before slicing the payload.
///
/// The header is written with `rmp_serde`; the payload bytes follow it verbatim.
// NOTE(review): rmp_serde presumably encodes structs positionally by default, so
// reordering these fields would likely change the wire format — confirm before
// touching the field order.
#[derive(Debug, Serialize, Deserialize)]
struct PacketHeader {
    /// Copied from `ReplicationEvent::network_id` on encode.
    network_id: NetworkId,
    /// Copied from `ReplicationEvent::component_kind` on encode.
    component_kind: ComponentKind,
    /// Copied from `ReplicationEvent::tick` on encode.
    tick: u64,
}
21
/// Stateless [`Encoder`] implementation that serializes with `serde` and
/// MessagePack (`rmp_serde`).
///
/// The type has no fields, so [`SerdeEncoder::new`] and `SerdeEncoder::default()`
/// produce the same value.
#[derive(Debug, Default)]
pub struct SerdeEncoder;
28
29impl SerdeEncoder {
30 #[must_use]
32 pub fn new() -> Self {
33 Self
34 }
35
36 pub fn encode_event(
41 &self,
42 event: &aetheris_protocol::events::NetworkEvent,
43 ) -> Result<Vec<u8>, EncodeError> {
44 let wire_event = Self::to_wire_event(event)?;
45 rmp_serde::to_vec(&wire_event)
46 .map_err(|e| EncodeError::Io(std::io::Error::other(e.to_string())))
47 }
48
49 pub fn encode_event_into(
54 &self,
55 event: &NetworkEvent,
56 buffer: &mut [u8],
57 ) -> Result<usize, EncodeError> {
58 let wire_event = Self::to_wire_event(event)?;
59 let mut cursor = Cursor::new(&mut *buffer);
60 rmp_serde::encode::write(&mut cursor, &wire_event).map_err(|e| {
61 if e.to_string().contains("unexpected end of file") {
62 EncodeError::BufferOverflow {
63 needed: 256, available: cursor.get_ref().len(),
65 }
66 } else {
67 EncodeError::Io(std::io::Error::other(e.to_string()))
68 }
69 })?;
70 usize::try_from(cursor.position()).map_err(|_| EncodeError::BufferOverflow {
71 needed: usize::MAX,
72 available: buffer.len(),
73 })
74 }
75
76 fn to_wire_event(event: &NetworkEvent) -> Result<WireEvent, EncodeError> {
77 Ok(match event {
78 NetworkEvent::Ping { tick, .. } if event.is_wire() => WireEvent::Ping { tick: *tick },
79 NetworkEvent::Pong { tick } => WireEvent::Pong { tick: *tick },
80 NetworkEvent::Auth { session_token } => WireEvent::Auth {
81 session_token: session_token.clone(),
82 },
83 NetworkEvent::Fragment { fragment, .. } => WireEvent::Fragment(fragment.clone()),
84 NetworkEvent::StressTest { count, rotate, .. } => WireEvent::StressTest {
85 count: *count,
86 rotate: *rotate,
87 },
88 NetworkEvent::Spawn {
89 entity_type,
90 x,
91 y,
92 rot,
93 ..
94 } => WireEvent::Spawn {
95 entity_type: *entity_type,
96 x: *x,
97 y: *y,
98 rot: *rot,
99 },
100 NetworkEvent::ClearWorld { .. } => WireEvent::ClearWorld,
101 NetworkEvent::StartSession { .. } => WireEvent::StartSession,
102 NetworkEvent::RequestSystemManifest { .. } => WireEvent::RequestSystemManifest,
103 NetworkEvent::GameEvent { event, .. } => WireEvent::GameEvent(event.clone()),
104 NetworkEvent::ReplicationBatch { events, .. } => {
105 WireEvent::ReplicationBatch(events.clone())
106 }
107 _ => {
108 return Err(EncodeError::Io(std::io::Error::other(format!(
109 "Cannot encode local-only variant as wire event: {event:?}"
110 ))));
111 }
112 })
113 }
114
115 pub fn decode_event(
120 &self,
121 data: &[u8],
122 ) -> Result<aetheris_protocol::events::NetworkEvent, EncodeError> {
123 let wire_event: WireEvent = rmp_serde::from_slice(data).map_err(|e| {
124 EncodeError::MalformedPayload {
125 offset: 0, message: e.to_string(),
127 }
128 })?;
129
130 Ok(match wire_event {
131 WireEvent::Ping { tick } => NetworkEvent::Ping {
132 client_id: ClientId(0), tick,
134 },
135 WireEvent::Pong { tick } => NetworkEvent::Pong { tick },
136 WireEvent::Auth { session_token } => NetworkEvent::Auth { session_token },
137 WireEvent::Fragment(fragment) => NetworkEvent::Fragment {
138 client_id: ClientId(0),
139 fragment,
140 },
141 WireEvent::StressTest { count, rotate } => NetworkEvent::StressTest {
142 client_id: ClientId(0), count,
144 rotate,
145 },
146 WireEvent::Spawn {
147 entity_type,
148 x,
149 y,
150 rot,
151 } => NetworkEvent::Spawn {
152 client_id: ClientId(0),
153 entity_type,
154 x,
155 y,
156 rot,
157 },
158 WireEvent::ClearWorld => NetworkEvent::ClearWorld {
159 client_id: ClientId(0),
160 },
161 WireEvent::StartSession => NetworkEvent::StartSession {
162 client_id: ClientId(0),
163 },
164 WireEvent::RequestSystemManifest => NetworkEvent::RequestSystemManifest {
165 client_id: ClientId(0),
166 },
167 WireEvent::GameEvent(event) => NetworkEvent::GameEvent {
168 client_id: ClientId(0),
169 event,
170 },
171 WireEvent::ReplicationBatch(events) => NetworkEvent::ReplicationBatch {
172 client_id: ClientId(0),
173 events,
174 },
175 })
176 }
177}
178
179impl Encoder for SerdeEncoder {
180 fn codec_id(&self) -> u32 {
181 1
182 }
183
184 fn encode_event(&self, event: &NetworkEvent) -> Result<Vec<u8>, EncodeError> {
185 self.encode_event(event)
186 }
187
188 fn encode_event_into(
189 &self,
190 event: &NetworkEvent,
191 buffer: &mut [u8],
192 ) -> Result<usize, EncodeError> {
193 self.encode_event_into(event, buffer)
194 }
195
196 fn decode_event(&self, data: &[u8]) -> Result<NetworkEvent, EncodeError> {
197 self.decode_event(data)
198 }
199
200 fn encode(&self, event: &ReplicationEvent, buffer: &mut [u8]) -> Result<usize, EncodeError> {
201 #[cfg(not(target_arch = "wasm32"))]
202 let start = std::time::Instant::now();
203 let header = PacketHeader {
204 network_id: event.network_id,
205 component_kind: event.component_kind,
206 tick: event.tick,
207 };
208
209 let mut cursor = Cursor::new(buffer);
210 let mut serializer = rmp_serde::Serializer::new(&mut cursor);
211
212 header.serialize(&mut serializer).map_err(|_e| {
213 metrics::counter!("aetheris_encoder_errors_total", "type" => "header_serialize_fail")
214 .increment(1);
215 EncodeError::BufferOverflow {
217 needed: 32, available: cursor.get_ref().len(),
219 }
220 })?;
221
222 let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
223 let payload_len = event.payload.len();
224 let total_needed = header_len + payload_len;
225
226 if total_needed > cursor.get_ref().len() {
227 metrics::counter!("aetheris_encoder_errors_total", "type" => "buffer_overflow")
228 .increment(1);
229 return Err(EncodeError::BufferOverflow {
230 needed: total_needed,
231 available: cursor.get_ref().len(),
232 });
233 }
234
235 let slice = cursor.into_inner();
237 slice[header_len..total_needed].copy_from_slice(&event.payload);
238
239 #[allow(clippy::cast_precision_loss)]
240 metrics::histogram!(
241 "aetheris_encoder_payload_size_bytes",
242 "operation" => "encode"
243 )
244 .record(total_needed as f64);
245
246 #[cfg(not(target_arch = "wasm32"))]
247 metrics::histogram!(
248 "aetheris_encoder_encode_duration_seconds",
249 "kind" => event.component_kind.0.to_string()
250 )
251 .record(start.elapsed().as_secs_f64());
252
253 Ok(total_needed)
254 }
255
256 fn decode(&self, buffer: &[u8]) -> Result<ComponentUpdate, EncodeError> {
257 #[cfg(not(target_arch = "wasm32"))]
258 let start = std::time::Instant::now();
259 let mut cursor = Cursor::new(buffer);
260 let mut deserializer = rmp_serde::Deserializer::new(&mut cursor);
261
262 let header = PacketHeader::deserialize(&mut deserializer).map_err(|e| {
263 metrics::counter!("aetheris_encoder_errors_total", "type" => "malformed_payload")
264 .increment(1);
265 EncodeError::MalformedPayload {
266 offset: usize::try_from(cursor.position()).unwrap_or(usize::MAX),
267 message: e.to_string(),
268 }
269 })?;
270
271 let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
272 let payload = buffer
273 .get(header_len..)
274 .ok_or(EncodeError::MalformedPayload {
275 offset: header_len,
276 message: "Payload slice out of bounds".to_string(),
277 })?
278 .to_vec();
279
280 #[allow(clippy::cast_precision_loss)]
281 metrics::histogram!(
282 "aetheris_encoder_payload_size_bytes",
283 "operation" => "decode"
284 )
285 .record(buffer.len() as f64);
286
287 #[cfg(not(target_arch = "wasm32"))]
288 metrics::histogram!(
289 "aetheris_encoder_decode_duration_seconds",
290 "kind" => header.component_kind.0.to_string()
291 )
292 .record(start.elapsed().as_secs_f64());
293
294 Ok(ComponentUpdate {
295 network_id: header.network_id,
296 component_kind: header.component_kind,
297 payload,
298 tick: header.tick,
299 })
300 }
301
302 fn max_encoded_size(&self) -> usize {
303 aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Replication event shared by the fixed-input tests.
    fn sample_replication_event() -> ReplicationEvent {
        ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![1, 2, 3, 4],
            tick: 100,
        }
    }

    /// Encoding then decoding a replication event yields the same fields.
    #[test]
    fn test_roundtrip() {
        let codec = SerdeEncoder::new();
        let original = sample_replication_event();

        let mut scratch = [0u8; 1200];
        let written = codec.encode(&original, &mut scratch).unwrap();
        assert!(written > 0);

        let decoded = codec.decode(&scratch[..written]).unwrap();
        assert_eq!(decoded.network_id, original.network_id);
        assert_eq!(decoded.component_kind, original.component_kind);
        assert_eq!(decoded.tick, original.tick);
        assert_eq!(decoded.payload, original.payload);
    }

    /// A `Fragment` network event survives an encode/decode cycle.
    #[test]
    fn test_fragment_roundtrip() {
        let codec = SerdeEncoder::new();
        let fragment = aetheris_protocol::events::FragmentedEvent {
            message_id: 123,
            fragment_index: 1,
            total_fragments: 5,
            payload: vec![1, 2, 3],
        };
        let event = NetworkEvent::Fragment {
            client_id: ClientId(0),
            fragment: fragment.clone(),
        };

        let bytes = codec.encode_event(&event).unwrap();
        let decoded = codec.decode_event(&bytes).unwrap();

        match decoded {
            NetworkEvent::Fragment {
                fragment: decoded_fragment,
                ..
            } => {
                assert_eq!(decoded_fragment.message_id, fragment.message_id);
                assert_eq!(decoded_fragment.fragment_index, fragment.fragment_index);
                assert_eq!(decoded_fragment.total_fragments, fragment.total_fragments);
                assert_eq!(decoded_fragment.payload, fragment.payload);
            }
            other => panic!("Decoded event is not a Fragment: {other:?}"),
        }
    }

    /// A one-byte buffer cannot hold the header, so `encode` reports overflow.
    #[test]
    fn test_buffer_overflow() {
        let codec = SerdeEncoder::new();
        let event = sample_replication_event();

        let mut tiny = [0u8; 1];
        let outcome = codec.encode(&event, &mut tiny);
        assert!(matches!(outcome, Err(EncodeError::BufferOverflow { .. })));
    }

    /// Garbage bytes are rejected with a non-empty `MalformedPayload` message.
    #[test]
    fn test_malformed_payload() {
        let codec = SerdeEncoder::new();
        let garbage = [0xff, 0xff, 0xff, 0xff];

        match codec.decode(&garbage) {
            Err(EncodeError::MalformedPayload { message, .. }) => {
                assert!(!message.is_empty());
            }
            other => panic!("Expected MalformedPayload error, got {other:?}"),
        }
    }

    proptest! {
        /// Decoding arbitrary bytes must never panic; the result is irrelevant.
        #[test]
        fn test_fuzz_decode(ref bytes in any::<Vec<u8>>()) {
            let codec = SerdeEncoder::new();
            let _ = codec.decode(bytes);
        }

        /// Whatever `encode` accepts must decode back to the same fields.
        #[test]
        fn test_fuzz_roundtrip(
            nid in any::<u64>(),
            kind in any::<u16>(),
            tick in any::<u64>(),
            ref payload in any::<Vec<u8>>()
        ) {
            let codec = SerdeEncoder::new();
            let event = ReplicationEvent {
                network_id: NetworkId(nid),
                component_kind: ComponentKind(kind),
                payload: payload.clone(),
                tick,
            };

            let mut scratch = vec![0u8; 2048 + payload.len()];
            if let Ok(written) = codec.encode(&event, &mut scratch) {
                let decoded = codec
                    .decode(&scratch[..written])
                    .expect("Round-trip decode failed during fuzzed test");
                assert_eq!(decoded.network_id, event.network_id);
                assert_eq!(decoded.component_kind, event.component_kind);
                assert_eq!(decoded.tick, event.tick);
                assert_eq!(decoded.payload, event.payload);
            }
        }
    }

    /// `Disconnected` has no wire form and must be rejected with an `Io` error.
    #[test]
    fn test_disconnected_not_serializable() {
        let codec = SerdeEncoder::new();
        let local_only = NetworkEvent::Disconnected(ClientId(42));

        let outcome = codec.encode_event(&local_only);
        assert!(outcome.is_err());
        match outcome {
            Err(EncodeError::Io(e)) => {
                assert!(e.to_string().contains("Cannot encode local-only variant"));
            }
            other => panic!("Expected EncodeError::Io with local-only message, got {other:?}"),
        }
    }

    /// A `GameEvent` round-trips, with `client_id` reset to `ClientId(0)`.
    #[test]
    fn test_game_event_roundtrip() {
        use aetheris_protocol::events::GameEvent;

        let codec = SerdeEncoder::new();
        let event = NetworkEvent::GameEvent {
            client_id: ClientId(1),
            event: GameEvent::AsteroidDepleted {
                network_id: NetworkId(123),
            },
        };

        let bytes = codec.encode_event(&event).unwrap();
        let decoded = codec.decode_event(&bytes).unwrap();

        match decoded {
            NetworkEvent::GameEvent {
                client_id,
                event: decoded_event,
            } => {
                assert_eq!(
                    client_id,
                    ClientId(0),
                    "Wire decoding should default client_id to 0"
                );
                match decoded_event {
                    GameEvent::AsteroidDepleted { network_id } => {
                        assert_eq!(network_id, NetworkId(123));
                    }
                    GameEvent::Possession { .. } | GameEvent::SystemManifest { .. } => {
                        panic!("Unexpected event type in roundtrip test");
                    }
                }
            }
            other => panic!("Decoded event is not a GameEvent: {other:?}"),
        }
    }
}