1use std::io::Cursor;
4
5use serde::{Deserialize, Serialize};
6
7use aetheris_protocol::error::EncodeError;
8use aetheris_protocol::events::{ComponentUpdate, NetworkEvent, ReplicationEvent, WireEvent};
9use aetheris_protocol::traits::Encoder;
10use aetheris_protocol::types::{ClientId, ComponentKind, NetworkId};
11
/// Header written in front of the raw component payload by [`Encoder::encode`]
/// and read back by [`Encoder::decode`].
///
/// NOTE(review): the header is serialized through `rmp_serde`'s default
/// serializer, so the field order below is presumably part of the wire
/// format -- confirm before reordering fields.
#[derive(Debug, Serialize, Deserialize)]
struct PacketHeader {
    /// Copied from `ReplicationEvent::network_id` on encode and into
    /// `ComponentUpdate::network_id` on decode.
    network_id: NetworkId,
    /// Copied from `ReplicationEvent::component_kind` on encode and into
    /// `ComponentUpdate::component_kind` on decode.
    component_kind: ComponentKind,
    /// Copied from `ReplicationEvent::tick` on encode and into
    /// `ComponentUpdate::tick` on decode.
    tick: u64,
}
21
/// Stateless [`Encoder`] implementation backed by `serde` and `rmp_serde`.
///
/// The type has no fields, so constructing it (via [`SerdeEncoder::new`] or
/// `Default`) is free.
#[derive(Debug, Default)]
pub struct SerdeEncoder;
28
29impl SerdeEncoder {
30 #[must_use]
32 pub fn new() -> Self {
33 Self
34 }
35
36 pub fn encode_event(
41 &self,
42 event: &aetheris_protocol::events::NetworkEvent,
43 ) -> Result<Vec<u8>, EncodeError> {
44 let wire_event = Self::to_wire_event(event)?;
45 rmp_serde::to_vec(&wire_event)
46 .map_err(|e| EncodeError::Io(std::io::Error::other(e.to_string())))
47 }
48
49 pub fn encode_event_into(
54 &self,
55 event: &NetworkEvent,
56 buffer: &mut [u8],
57 ) -> Result<usize, EncodeError> {
58 let wire_event = Self::to_wire_event(event)?;
59 let mut cursor = Cursor::new(&mut *buffer);
60 rmp_serde::encode::write(&mut cursor, &wire_event).map_err(|e| {
61 let err_msg = e.to_string();
62 if err_msg.contains("unexpected end of file") || err_msg.contains("invalid value write")
63 {
64 EncodeError::BufferOverflow {
65 needed: 256, available: cursor.get_ref().len(),
67 }
68 } else {
69 EncodeError::Io(std::io::Error::other(err_msg))
70 }
71 })?;
72 usize::try_from(cursor.position()).map_err(|_| EncodeError::BufferOverflow {
73 needed: usize::MAX,
74 available: buffer.len(),
75 })
76 }
77
78 fn to_wire_event(event: &NetworkEvent) -> Result<WireEvent, EncodeError> {
79 Ok(match event {
80 NetworkEvent::Ping { tick, .. } => WireEvent::Ping { tick: *tick },
81 NetworkEvent::Pong { tick } => WireEvent::Pong { tick: *tick },
82 NetworkEvent::Auth { session_token } => WireEvent::Auth {
83 session_token: session_token.clone(),
84 },
85 NetworkEvent::Fragment { fragment, .. } => WireEvent::Fragment(fragment.clone()),
86 NetworkEvent::StressTest { count, rotate, .. } => WireEvent::StressTest {
87 count: *count,
88 rotate: *rotate,
89 },
90 NetworkEvent::Spawn {
91 entity_type,
92 x,
93 y,
94 rot,
95 ..
96 } => WireEvent::Spawn {
97 entity_type: *entity_type,
98 x: *x,
99 y: *y,
100 rot: *rot,
101 },
102 NetworkEvent::ClearWorld { .. } => WireEvent::ClearWorld,
103 NetworkEvent::StartSession { .. } => WireEvent::StartSession,
104 NetworkEvent::RequestWorkspaceManifest { .. } => WireEvent::RequestWorkspaceManifest,
105 NetworkEvent::PlatformEvent { event, .. } => WireEvent::PlatformEvent(event.clone()),
106 NetworkEvent::EntitySpawned {
107 network_id, kind, ..
108 } => WireEvent::EntitySpawned {
109 network_id: *network_id,
110 kind: *kind,
111 },
112 NetworkEvent::EntityDespawned { network_id, .. } => WireEvent::EntityDespawned {
113 network_id: *network_id,
114 },
115 NetworkEvent::ReplicationBatch { events, .. } => {
116 WireEvent::ReplicationBatch(events.clone())
117 }
118 NetworkEvent::ClientConnected(_)
121 | NetworkEvent::ClientDisconnected(_)
122 | NetworkEvent::UnreliableMessage { .. }
123 | NetworkEvent::ReliableMessage { .. }
124 | NetworkEvent::SessionClosed(_)
125 | NetworkEvent::StreamReset(_)
126 | NetworkEvent::Disconnected(_) => {
127 return Err(EncodeError::Io(std::io::Error::other(format!(
128 "Cannot encode local-only variant as wire event: {event:?}"
129 ))));
130 }
131 })
132 }
133
134 pub fn decode_event(
139 &self,
140 data: &[u8],
141 ) -> Result<aetheris_protocol::events::NetworkEvent, EncodeError> {
142 let wire_event: WireEvent = rmp_serde::from_slice(data).map_err(|e| {
143 EncodeError::MalformedPayload {
144 offset: 0, message: e.to_string(),
146 }
147 })?;
148
149 Ok(match wire_event {
150 WireEvent::Ping { tick } => NetworkEvent::Ping {
151 client_id: ClientId(0), tick,
153 },
154 WireEvent::Pong { tick } => NetworkEvent::Pong { tick },
155 WireEvent::Auth { session_token } => NetworkEvent::Auth { session_token },
156 WireEvent::Fragment(fragment) => NetworkEvent::Fragment {
157 client_id: ClientId(0),
158 fragment,
159 },
160 WireEvent::StressTest { count, rotate } => NetworkEvent::StressTest {
161 client_id: ClientId(0), count,
163 rotate,
164 },
165 WireEvent::Spawn {
166 entity_type,
167 x,
168 y,
169 rot,
170 } => NetworkEvent::Spawn {
171 client_id: ClientId(0),
172 entity_type,
173 x,
174 y,
175 rot,
176 },
177 WireEvent::ClearWorld => NetworkEvent::ClearWorld {
178 client_id: ClientId(0),
179 },
180 WireEvent::StartSession => NetworkEvent::StartSession {
181 client_id: ClientId(0),
182 },
183 WireEvent::RequestWorkspaceManifest => NetworkEvent::RequestWorkspaceManifest {
184 client_id: ClientId(0),
185 },
186 WireEvent::PlatformEvent(event) => NetworkEvent::PlatformEvent {
187 client_id: ClientId(0),
188 event,
189 },
190 WireEvent::EntitySpawned { network_id, kind } => NetworkEvent::EntitySpawned {
191 client_id: ClientId(0),
192 network_id,
193 kind,
194 },
195 WireEvent::EntityDespawned { network_id } => NetworkEvent::EntityDespawned {
196 client_id: ClientId(0),
197 network_id,
198 },
199 WireEvent::ReplicationBatch(events) => NetworkEvent::ReplicationBatch {
200 client_id: ClientId(0),
201 events,
202 },
203 })
204 }
205}
206
207impl Encoder for SerdeEncoder {
208 fn codec_id(&self) -> u32 {
209 1
210 }
211
212 fn encode_event(&self, event: &NetworkEvent) -> Result<Vec<u8>, EncodeError> {
213 self.encode_event(event)
214 }
215
216 fn encode_event_into(
217 &self,
218 event: &NetworkEvent,
219 buffer: &mut [u8],
220 ) -> Result<usize, EncodeError> {
221 self.encode_event_into(event, buffer)
222 }
223
224 fn decode_event(&self, data: &[u8]) -> Result<NetworkEvent, EncodeError> {
225 self.decode_event(data)
226 }
227
228 fn encode(&self, event: &ReplicationEvent, buffer: &mut [u8]) -> Result<usize, EncodeError> {
229 #[cfg(not(target_arch = "wasm32"))]
230 let start = std::time::Instant::now();
231 let header = PacketHeader {
232 network_id: event.network_id,
233 component_kind: event.component_kind,
234 tick: event.tick,
235 };
236
237 let mut cursor = Cursor::new(buffer);
238 let mut serializer = rmp_serde::Serializer::new(&mut cursor);
239
240 header.serialize(&mut serializer).map_err(|_e| {
241 metrics::counter!("aetheris_encoder_errors_total", "type" => "header_serialize_fail")
242 .increment(1);
243 EncodeError::BufferOverflow {
245 needed: 32, available: cursor.get_ref().len(),
247 }
248 })?;
249
250 let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
251 let payload_len = event.payload.len();
252 let total_needed = header_len + payload_len;
253
254 if total_needed > cursor.get_ref().len() {
255 metrics::counter!("aetheris_encoder_errors_total", "type" => "buffer_overflow")
256 .increment(1);
257 return Err(EncodeError::BufferOverflow {
258 needed: total_needed,
259 available: cursor.get_ref().len(),
260 });
261 }
262
263 let slice = cursor.into_inner();
265 slice[header_len..total_needed].copy_from_slice(&event.payload);
266
267 #[allow(clippy::cast_precision_loss)]
268 metrics::histogram!(
269 "aetheris_encoder_payload_size_bytes",
270 "operation" => "encode"
271 )
272 .record(total_needed as f64);
273
274 #[cfg(not(target_arch = "wasm32"))]
275 metrics::histogram!(
276 "aetheris_encoder_encode_duration_seconds",
277 "kind" => event.component_kind.0.to_string()
278 )
279 .record(start.elapsed().as_secs_f64());
280
281 Ok(total_needed)
282 }
283
284 fn decode(&self, buffer: &[u8]) -> Result<ComponentUpdate, EncodeError> {
285 #[cfg(not(target_arch = "wasm32"))]
286 let start = std::time::Instant::now();
287 let mut cursor = Cursor::new(buffer);
288 let mut deserializer = rmp_serde::Deserializer::new(&mut cursor);
289
290 let header = PacketHeader::deserialize(&mut deserializer).map_err(|e| {
291 metrics::counter!("aetheris_encoder_errors_total", "type" => "malformed_payload")
292 .increment(1);
293 EncodeError::MalformedPayload {
294 offset: usize::try_from(cursor.position()).unwrap_or(usize::MAX),
295 message: e.to_string(),
296 }
297 })?;
298
299 let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
300 let payload = buffer
301 .get(header_len..)
302 .ok_or(EncodeError::MalformedPayload {
303 offset: header_len,
304 message: "Payload slice out of bounds".to_string(),
305 })?
306 .to_vec();
307
308 #[allow(clippy::cast_precision_loss)]
309 metrics::histogram!(
310 "aetheris_encoder_payload_size_bytes",
311 "operation" => "decode"
312 )
313 .record(buffer.len() as f64);
314
315 #[cfg(not(target_arch = "wasm32"))]
316 metrics::histogram!(
317 "aetheris_encoder_decode_duration_seconds",
318 "kind" => header.component_kind.0.to_string()
319 )
320 .record(start.elapsed().as_secs_f64());
321
322 Ok(ComponentUpdate {
323 network_id: header.network_id,
324 component_kind: header.component_kind,
325 payload,
326 tick: header.tick,
327 })
328 }
329
330 fn max_encoded_size(&self) -> usize {
331 aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Small fixed replication event shared by the non-fuzzed tests.
    fn sample_event() -> ReplicationEvent {
        ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![1, 2, 3, 4],
            tick: 100,
        }
    }

    /// A replication event survives `encode` followed by `decode`.
    #[test]
    fn test_roundtrip() {
        let codec = SerdeEncoder::new();
        let original = sample_event();

        let mut scratch = [0u8; 1200];
        let written = codec.encode(&original, &mut scratch).unwrap();
        assert!(written > 0);

        let decoded = codec.decode(&scratch[..written]).unwrap();
        assert_eq!(decoded.network_id, original.network_id);
        assert_eq!(decoded.component_kind, original.component_kind);
        assert_eq!(decoded.tick, original.tick);
        assert_eq!(decoded.payload, original.payload);
    }

    /// A fragment event survives `encode_event` followed by `decode_event`.
    #[test]
    fn test_fragment_roundtrip() {
        let codec = SerdeEncoder::new();
        let fragment = aetheris_protocol::events::FragmentedEvent {
            message_id: 123,
            fragment_index: 1,
            total_fragments: 5,
            payload: vec![1, 2, 3],
        };
        let event = NetworkEvent::Fragment {
            client_id: aetheris_protocol::types::ClientId(0),
            fragment: fragment.clone(),
        };

        let bytes = codec.encode_event(&event).unwrap();
        let decoded = codec.decode_event(&bytes).unwrap();

        match decoded {
            NetworkEvent::Fragment {
                fragment: decoded_fragment,
                ..
            } => {
                assert_eq!(decoded_fragment.message_id, fragment.message_id);
                assert_eq!(decoded_fragment.fragment_index, fragment.fragment_index);
                assert_eq!(decoded_fragment.total_fragments, fragment.total_fragments);
                assert_eq!(decoded_fragment.payload, fragment.payload);
            }
            _ => panic!("Decoded event is not a Fragment: {decoded:?}"),
        }
    }

    /// Encoding into a one-byte buffer reports `BufferOverflow`.
    #[test]
    fn test_buffer_overflow() {
        let codec = SerdeEncoder::new();
        let mut tiny = [0u8; 1];

        let outcome = codec.encode(&sample_event(), &mut tiny);
        assert!(matches!(outcome, Err(EncodeError::BufferOverflow { .. })));
    }

    /// Garbage input is rejected with a non-empty `MalformedPayload` message.
    #[test]
    fn test_malformed_payload() {
        let codec = SerdeEncoder::new();
        let garbage = [0xff, 0xff, 0xff, 0xff];

        match codec.decode(&garbage) {
            Err(EncodeError::MalformedPayload { message, .. }) => assert!(!message.is_empty()),
            result => panic!("Expected MalformedPayload error, got {result:?}"),
        }
    }

    proptest! {
        /// Decoding arbitrary bytes may fail but must never panic.
        #[test]
        fn test_fuzz_decode(bytes in any::<Vec<u8>>()) {
            let codec = SerdeEncoder::new();
            let _ = codec.decode(&bytes);
        }

        /// Whatever `encode` accepts, `decode` must give back unchanged.
        #[test]
        fn test_fuzz_roundtrip(
            nid in any::<u64>(),
            kind in any::<u16>(),
            tick in any::<u64>(),
            payload in any::<Vec<u8>>()
        ) {
            let codec = SerdeEncoder::new();
            // Sized so the header plus any generated payload always fits.
            let mut buffer = vec![0u8; 2048 + payload.len()];
            let event = ReplicationEvent {
                network_id: NetworkId(nid),
                component_kind: ComponentKind(kind),
                payload,
                tick,
            };

            if let Ok(written) = codec.encode(&event, &mut buffer) {
                let update = codec
                    .decode(&buffer[..written])
                    .expect("Round-trip decode failed during fuzzed test");
                assert_eq!(update.network_id, event.network_id);
                assert_eq!(update.component_kind, event.component_kind);
                assert_eq!(update.tick, event.tick);
                assert_eq!(update.payload, event.payload);
            }
        }
    }

    /// Local-only variants are refused by `encode_event` with an `Io` error.
    #[test]
    fn test_disconnected_not_serializable() {
        let codec = SerdeEncoder::new();
        let event = NetworkEvent::Disconnected(ClientId(42));

        let result = codec.encode_event(&event);
        assert!(result.is_err());
        match result {
            Err(EncodeError::Io(e)) => {
                assert!(e.to_string().contains("Cannot encode local-only variant"));
            }
            _ => panic!("Expected EncodeError::Io with local-only message, got {result:?}"),
        }
    }

    /// Each listed platform event round-trips, and `client_id` comes back as 0.
    #[test]
    fn test_platform_event_roundtrip() {
        use aetheris_protocol::events::PlatformEvent;
        use std::collections::BTreeMap;

        let codec = SerdeEncoder::new();

        let test_events = [
            PlatformEvent::ResourceExhausted {
                network_id: NetworkId(123),
            },
            PlatformEvent::Possession {
                network_id: NetworkId(456),
            },
            PlatformEvent::WorkspaceManifest {
                manifest: BTreeMap::from([("test_key".to_string(), "test_value".to_string())]),
            },
            PlatformEvent::Interaction {
                source: NetworkId(1),
                target: NetworkId(2),
                amount: 10,
            },
            PlatformEvent::Termination {
                target: NetworkId(3),
            },
            PlatformEvent::Reinitialization {
                target: NetworkId(4),
                x: 10.5,
                y: 20.2,
            },
            PlatformEvent::PayloadCollected {
                network_id: NetworkId(5),
                amount: 50,
            },
        ];

        for platform_event in test_events {
            // Encoded with a non-zero client id to show it is not carried over.
            let event = NetworkEvent::PlatformEvent {
                client_id: ClientId(1),
                event: platform_event.clone(),
            };

            let bytes = codec.encode_event(&event).unwrap();
            let decoded = codec.decode_event(&bytes).unwrap();

            match decoded {
                NetworkEvent::PlatformEvent {
                    client_id,
                    event: decoded_event,
                } => {
                    assert_eq!(
                        client_id,
                        ClientId(0),
                        "Wire decoding should default client_id to 0"
                    );
                    assert_eq!(
                        decoded_event, platform_event,
                        "Roundtrip failed for {platform_event:?}"
                    );
                }
                _ => panic!("Decoded event is not a PlatformEvent: {decoded:?}"),
            }
        }
    }
}