1use std::io::Cursor;
4
5use serde::{Deserialize, Serialize};
6
7use aetheris_protocol::error::EncodeError;
8use aetheris_protocol::events::{ComponentUpdate, NetworkEvent, ReplicationEvent, WireEvent};
9use aetheris_protocol::traits::Encoder;
10use aetheris_protocol::types::{ClientId, ComponentKind, NetworkId};
11
/// Header written in front of every replication payload by `SerdeEncoder`'s
/// `Encoder::encode` and read back by `Encoder::decode`.
///
/// It is serialized with `rmp_serde`'s default serializer, which writes structs
/// positionally. The field order below is therefore part of the wire layout:
/// reordering the fields changes the format.
#[derive(Debug, Serialize, Deserialize)]
struct PacketHeader {
    /// Copied from `ReplicationEvent::network_id` when encoding.
    network_id: NetworkId,
    /// Copied from `ReplicationEvent::component_kind` when encoding.
    component_kind: ComponentKind,
    /// Copied from `ReplicationEvent::tick` when encoding.
    tick: u64,
}
21
/// Stateless MessagePack codec built on `rmp_serde`.
///
/// The type carries no configuration, so `SerdeEncoder::new()` and
/// `SerdeEncoder::default()` produce interchangeable values.
#[derive(Default, Debug)]
pub struct SerdeEncoder;
28
29impl SerdeEncoder {
30 #[must_use]
32 pub fn new() -> Self {
33 Self
34 }
35
36 pub fn encode_event(
41 &self,
42 event: &aetheris_protocol::events::NetworkEvent,
43 ) -> Result<Vec<u8>, EncodeError> {
44 let wire_event = Self::to_wire_event(event)?;
45 rmp_serde::to_vec(&wire_event)
46 .map_err(|e| EncodeError::Io(std::io::Error::other(e.to_string())))
47 }
48
49 pub fn encode_event_into(
54 &self,
55 event: &NetworkEvent,
56 buffer: &mut [u8],
57 ) -> Result<usize, EncodeError> {
58 let wire_event = Self::to_wire_event(event)?;
59 let mut cursor = Cursor::new(&mut *buffer);
60 rmp_serde::encode::write(&mut cursor, &wire_event).map_err(|e| {
61 let err_msg = e.to_string();
62 if err_msg.contains("unexpected end of file") || err_msg.contains("invalid value write")
63 {
64 EncodeError::BufferOverflow {
65 needed: 256, available: cursor.get_ref().len(),
67 }
68 } else {
69 EncodeError::Io(std::io::Error::other(err_msg))
70 }
71 })?;
72 usize::try_from(cursor.position()).map_err(|_| EncodeError::BufferOverflow {
73 needed: usize::MAX,
74 available: buffer.len(),
75 })
76 }
77
78 fn to_wire_event(event: &NetworkEvent) -> Result<WireEvent, EncodeError> {
79 Ok(match event {
80 NetworkEvent::Ping { tick, .. } if event.is_wire() => WireEvent::Ping { tick: *tick },
81 NetworkEvent::Pong { tick } => WireEvent::Pong { tick: *tick },
82 NetworkEvent::Auth { session_token } => WireEvent::Auth {
83 session_token: session_token.clone(),
84 },
85 NetworkEvent::Fragment { fragment, .. } => WireEvent::Fragment(fragment.clone()),
86 NetworkEvent::StressTest { count, rotate, .. } => WireEvent::StressTest {
87 count: *count,
88 rotate: *rotate,
89 },
90 NetworkEvent::Spawn {
91 entity_type,
92 x,
93 y,
94 rot,
95 ..
96 } => WireEvent::Spawn {
97 entity_type: *entity_type,
98 x: *x,
99 y: *y,
100 rot: *rot,
101 },
102 NetworkEvent::ClearWorld { .. } => WireEvent::ClearWorld,
103 NetworkEvent::StartSession { .. } => WireEvent::StartSession,
104 NetworkEvent::RequestSystemManifest { .. } => WireEvent::RequestSystemManifest,
105 NetworkEvent::GameEvent { event, .. } => WireEvent::GameEvent(event.clone()),
106 NetworkEvent::ReplicationBatch { events, .. } => {
107 WireEvent::ReplicationBatch(events.clone())
108 }
109 _ => {
110 return Err(EncodeError::Io(std::io::Error::other(format!(
111 "Cannot encode local-only variant as wire event: {event:?}"
112 ))));
113 }
114 })
115 }
116
117 pub fn decode_event(
122 &self,
123 data: &[u8],
124 ) -> Result<aetheris_protocol::events::NetworkEvent, EncodeError> {
125 let wire_event: WireEvent = rmp_serde::from_slice(data).map_err(|e| {
126 EncodeError::MalformedPayload {
127 offset: 0, message: e.to_string(),
129 }
130 })?;
131
132 Ok(match wire_event {
133 WireEvent::Ping { tick } => NetworkEvent::Ping {
134 client_id: ClientId(0), tick,
136 },
137 WireEvent::Pong { tick } => NetworkEvent::Pong { tick },
138 WireEvent::Auth { session_token } => NetworkEvent::Auth { session_token },
139 WireEvent::Fragment(fragment) => NetworkEvent::Fragment {
140 client_id: ClientId(0),
141 fragment,
142 },
143 WireEvent::StressTest { count, rotate } => NetworkEvent::StressTest {
144 client_id: ClientId(0), count,
146 rotate,
147 },
148 WireEvent::Spawn {
149 entity_type,
150 x,
151 y,
152 rot,
153 } => NetworkEvent::Spawn {
154 client_id: ClientId(0),
155 entity_type,
156 x,
157 y,
158 rot,
159 },
160 WireEvent::ClearWorld => NetworkEvent::ClearWorld {
161 client_id: ClientId(0),
162 },
163 WireEvent::StartSession => NetworkEvent::StartSession {
164 client_id: ClientId(0),
165 },
166 WireEvent::RequestSystemManifest => NetworkEvent::RequestSystemManifest {
167 client_id: ClientId(0),
168 },
169 WireEvent::GameEvent(event) => NetworkEvent::GameEvent {
170 client_id: ClientId(0),
171 event,
172 },
173 WireEvent::ReplicationBatch(events) => NetworkEvent::ReplicationBatch {
174 client_id: ClientId(0),
175 events,
176 },
177 })
178 }
179}
180
181impl Encoder for SerdeEncoder {
182 fn codec_id(&self) -> u32 {
183 1
184 }
185
186 fn encode_event(&self, event: &NetworkEvent) -> Result<Vec<u8>, EncodeError> {
187 self.encode_event(event)
188 }
189
190 fn encode_event_into(
191 &self,
192 event: &NetworkEvent,
193 buffer: &mut [u8],
194 ) -> Result<usize, EncodeError> {
195 self.encode_event_into(event, buffer)
196 }
197
198 fn decode_event(&self, data: &[u8]) -> Result<NetworkEvent, EncodeError> {
199 self.decode_event(data)
200 }
201
202 fn encode(&self, event: &ReplicationEvent, buffer: &mut [u8]) -> Result<usize, EncodeError> {
203 #[cfg(not(target_arch = "wasm32"))]
204 let start = std::time::Instant::now();
205 let header = PacketHeader {
206 network_id: event.network_id,
207 component_kind: event.component_kind,
208 tick: event.tick,
209 };
210
211 let mut cursor = Cursor::new(buffer);
212 let mut serializer = rmp_serde::Serializer::new(&mut cursor);
213
214 header.serialize(&mut serializer).map_err(|_e| {
215 metrics::counter!("aetheris_encoder_errors_total", "type" => "header_serialize_fail")
216 .increment(1);
217 EncodeError::BufferOverflow {
219 needed: 32, available: cursor.get_ref().len(),
221 }
222 })?;
223
224 let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
225 let payload_len = event.payload.len();
226 let total_needed = header_len + payload_len;
227
228 if total_needed > cursor.get_ref().len() {
229 metrics::counter!("aetheris_encoder_errors_total", "type" => "buffer_overflow")
230 .increment(1);
231 return Err(EncodeError::BufferOverflow {
232 needed: total_needed,
233 available: cursor.get_ref().len(),
234 });
235 }
236
237 let slice = cursor.into_inner();
239 slice[header_len..total_needed].copy_from_slice(&event.payload);
240
241 #[allow(clippy::cast_precision_loss)]
242 metrics::histogram!(
243 "aetheris_encoder_payload_size_bytes",
244 "operation" => "encode"
245 )
246 .record(total_needed as f64);
247
248 #[cfg(not(target_arch = "wasm32"))]
249 metrics::histogram!(
250 "aetheris_encoder_encode_duration_seconds",
251 "kind" => event.component_kind.0.to_string()
252 )
253 .record(start.elapsed().as_secs_f64());
254
255 Ok(total_needed)
256 }
257
258 fn decode(&self, buffer: &[u8]) -> Result<ComponentUpdate, EncodeError> {
259 #[cfg(not(target_arch = "wasm32"))]
260 let start = std::time::Instant::now();
261 let mut cursor = Cursor::new(buffer);
262 let mut deserializer = rmp_serde::Deserializer::new(&mut cursor);
263
264 let header = PacketHeader::deserialize(&mut deserializer).map_err(|e| {
265 metrics::counter!("aetheris_encoder_errors_total", "type" => "malformed_payload")
266 .increment(1);
267 EncodeError::MalformedPayload {
268 offset: usize::try_from(cursor.position()).unwrap_or(usize::MAX),
269 message: e.to_string(),
270 }
271 })?;
272
273 let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
274 let payload = buffer
275 .get(header_len..)
276 .ok_or(EncodeError::MalformedPayload {
277 offset: header_len,
278 message: "Payload slice out of bounds".to_string(),
279 })?
280 .to_vec();
281
282 #[allow(clippy::cast_precision_loss)]
283 metrics::histogram!(
284 "aetheris_encoder_payload_size_bytes",
285 "operation" => "decode"
286 )
287 .record(buffer.len() as f64);
288
289 #[cfg(not(target_arch = "wasm32"))]
290 metrics::histogram!(
291 "aetheris_encoder_decode_duration_seconds",
292 "kind" => header.component_kind.0.to_string()
293 )
294 .record(start.elapsed().as_secs_f64());
295
296 Ok(ComponentUpdate {
297 network_id: header.network_id,
298 component_kind: header.component_kind,
299 payload,
300 tick: header.tick,
301 })
302 }
303
304 fn max_encoded_size(&self) -> usize {
305 aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE
306 }
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Replication event shared by the fixed-buffer `encode` tests.
    fn sample_replication_event() -> ReplicationEvent {
        ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![1, 2, 3, 4],
            tick: 100,
        }
    }

    /// Fragment shared by the `encode_event*` tests.
    fn sample_fragment(payload: Vec<u8>) -> aetheris_protocol::events::FragmentedEvent {
        aetheris_protocol::events::FragmentedEvent {
            message_id: 123,
            fragment_index: 1,
            total_fragments: 5,
            payload,
        }
    }

    #[test]
    fn test_roundtrip() {
        let encoder = SerdeEncoder::new();
        let event = sample_replication_event();

        let mut buffer = [0u8; 1200];
        let bytes_written = encoder.encode(&event, &mut buffer).unwrap();
        assert!(bytes_written > 0);

        let update = encoder.decode(&buffer[..bytes_written]).unwrap();
        assert_eq!(update.network_id, event.network_id);
        assert_eq!(update.component_kind, event.component_kind);
        assert_eq!(update.tick, event.tick);
        assert_eq!(update.payload, event.payload);
    }

    #[test]
    fn test_fragment_roundtrip() {
        let encoder = SerdeEncoder::new();
        let fragment = sample_fragment(vec![1, 2, 3]);

        let event = NetworkEvent::Fragment {
            client_id: ClientId(0),
            fragment: fragment.clone(),
        };

        let output = encoder.encode_event(&event).unwrap();
        let decoded = encoder.decode_event(&output).unwrap();

        if let NetworkEvent::Fragment {
            client_id: _,
            fragment: decoded_fragment,
        } = decoded
        {
            assert_eq!(decoded_fragment.message_id, fragment.message_id);
            assert_eq!(decoded_fragment.fragment_index, fragment.fragment_index);
            assert_eq!(decoded_fragment.total_fragments, fragment.total_fragments);
            assert_eq!(decoded_fragment.payload, fragment.payload);
        } else {
            panic!("Decoded event is not a Fragment: {decoded:?}");
        }
    }

    #[test]
    fn test_encode_event_into_matches_encode_event() {
        let encoder = SerdeEncoder::new();
        let event = NetworkEvent::Fragment {
            client_id: ClientId(0),
            fragment: sample_fragment(vec![1, 2, 3]),
        };

        let expected = encoder.encode_event(&event).unwrap();
        let mut buffer = [0u8; 1200];
        let written = encoder.encode_event_into(&event, &mut buffer).unwrap();
        assert_eq!(&buffer[..written], expected.as_slice());
    }

    #[test]
    fn test_encode_event_into_buffer_overflow() {
        let encoder = SerdeEncoder::new();
        let event = NetworkEvent::Fragment {
            client_id: ClientId(0),
            fragment: sample_fragment(vec![7; 64]),
        };

        let mut small_buffer = [0u8; 4];
        match encoder.encode_event_into(&event, &mut small_buffer) {
            Err(EncodeError::BufferOverflow { needed, available }) => {
                assert_eq!(available, small_buffer.len());
                assert!(needed > available);
            }
            other => panic!("Expected BufferOverflow, got {other:?}"),
        }
    }

    #[test]
    fn test_buffer_overflow() {
        let encoder = SerdeEncoder::new();
        let event = sample_replication_event();

        let mut small_buffer = [0u8; 1];
        match encoder.encode(&event, &mut small_buffer) {
            Err(EncodeError::BufferOverflow { needed, available }) => {
                assert_eq!(available, 1);
                assert!(needed > available);
            }
            other => panic!("Expected BufferOverflow, got {other:?}"),
        }
    }

    #[test]
    fn test_malformed_payload() {
        let encoder = SerdeEncoder::new();
        let garbage = [0xff, 0xff, 0xff, 0xff];
        let result = encoder.decode(&garbage);
        if let Err(EncodeError::MalformedPayload { message, .. }) = result {
            assert!(!message.is_empty());
        } else {
            panic!("Expected MalformedPayload error, got {result:?}");
        }
    }

    proptest! {
        #[test]
        fn test_fuzz_decode(ref bytes in any::<Vec<u8>>()) {
            // Only checks that arbitrary input never panics.
            let encoder = SerdeEncoder::new();
            let _ = encoder.decode(bytes);
        }

        #[test]
        fn test_fuzz_roundtrip(
            nid in any::<u64>(),
            kind in any::<u16>(),
            tick in any::<u64>(),
            ref payload in any::<Vec<u8>>()
        ) {
            let encoder = SerdeEncoder::new();
            let event = ReplicationEvent {
                network_id: NetworkId(nid),
                component_kind: ComponentKind(kind),
                payload: payload.clone(),
                tick,
            };

            // The buffer always has room for the header plus payload, so encoding must
            // succeed. Silently skipping failures would hide regressions.
            let mut buffer = vec![0u8; 2048 + payload.len()];
            let written = encoder
                .encode(&event, &mut buffer)
                .expect("buffer is sized to fit header + payload");
            let update = encoder.decode(&buffer[..written])
                .expect("Round-trip decode failed during fuzzed test");
            prop_assert_eq!(update.network_id, event.network_id);
            prop_assert_eq!(update.component_kind, event.component_kind);
            prop_assert_eq!(update.tick, event.tick);
            prop_assert_eq!(&update.payload, &event.payload);
        }
    }

    #[test]
    fn test_disconnected_not_serializable() {
        let encoder = SerdeEncoder::new();
        let event = NetworkEvent::Disconnected(ClientId(42));

        let result = encoder.encode_event(&event);
        assert!(result.is_err());
        if let Err(EncodeError::Io(e)) = result {
            assert!(e.to_string().contains("Cannot encode local-only variant"));
        } else {
            panic!("Expected EncodeError::Io with local-only message, got {result:?}");
        }
    }

    #[test]
    fn test_game_event_roundtrip() {
        use aetheris_protocol::events::GameEvent;
        use std::collections::BTreeMap;

        let encoder = SerdeEncoder::new();

        let test_events = vec![
            GameEvent::AsteroidDepleted {
                network_id: NetworkId(123),
            },
            GameEvent::Possession {
                network_id: NetworkId(456),
            },
            GameEvent::SystemManifest {
                manifest: BTreeMap::new(),
            },
            GameEvent::DamageEvent {
                source: NetworkId(1),
                target: NetworkId(2),
                amount: 10,
            },
            GameEvent::DeathEvent {
                target: NetworkId(3),
            },
            GameEvent::RespawnEvent {
                target: NetworkId(4),
                x: 10.5,
                y: 20.2,
            },
            GameEvent::CargoCollected {
                network_id: NetworkId(5),
                amount: 50,
            },
        ];

        for game_event in test_events {
            let event = NetworkEvent::GameEvent {
                client_id: ClientId(1),
                event: game_event.clone(),
            };

            let output = encoder.encode_event(&event).unwrap();
            let decoded = encoder.decode_event(&output).unwrap();

            if let NetworkEvent::GameEvent {
                client_id,
                event: decoded_event,
            } = decoded
            {
                assert_eq!(
                    client_id,
                    ClientId(0),
                    "Wire decoding should default client_id to 0"
                );
                assert_eq!(
                    decoded_event, game_event,
                    "Roundtrip failed for {game_event:?}"
                );
            } else {
                panic!("Decoded event is not a GameEvent: {decoded:?}");
            }
        }
    }
}