1use std::io::Cursor;
4
5use serde::{Deserialize, Serialize};
6
7use aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE;
8use aetheris_protocol::error::EncodeError;
9use aetheris_protocol::events::{ComponentUpdate, NetworkEvent, ReplicationEvent, WireEvent};
10use aetheris_protocol::traits::Encoder;
11use aetheris_protocol::types::{ClientId, ComponentKind, NetworkId};
12
/// Header written by `SerdeEncoder`'s `Encoder::encode` in front of each raw
/// replication payload, and read back by `Encoder::decode`.
///
/// It is serialized with `rmp_serde::Serializer::new`, which encodes structs as
/// positional arrays. The field order below is therefore part of the wire
/// format: reordering fields breaks decoding of data produced by other builds.
#[derive(Debug, Serialize, Deserialize)]
struct PacketHeader {
    // Copied from `ReplicationEvent::network_id`; restored into `ComponentUpdate::network_id`.
    network_id: NetworkId,
    // Copied from `ReplicationEvent::component_kind`; restored into `ComponentUpdate::component_kind`.
    component_kind: ComponentKind,
    // Copied from `ReplicationEvent::tick`; restored into `ComponentUpdate::tick`.
    tick: u64,
}
22
/// Stateless [`Encoder`] implementation backed by MessagePack (`rmp_serde`).
///
/// Replication updates are encoded as a compact header followed by the raw
/// component payload. Higher-level [`NetworkEvent`]s are first converted to a
/// [`WireEvent`] and then serialized whole. The type holds no state, so
/// `SerdeEncoder::new()` and `SerdeEncoder::default()` are interchangeable.
#[derive(Debug, Default)]
pub struct SerdeEncoder;
29
30impl SerdeEncoder {
31 #[must_use]
33 pub fn new() -> Self {
34 Self
35 }
36
37 pub fn encode_event(
42 &self,
43 event: &aetheris_protocol::events::NetworkEvent,
44 ) -> Result<Vec<u8>, EncodeError> {
45 let wire_event = match event {
46 NetworkEvent::Ping { tick, .. } if event.is_wire() => WireEvent::Ping { tick: *tick },
47 NetworkEvent::Pong { tick } => WireEvent::Pong { tick: *tick },
48 NetworkEvent::Auth { session_token } => WireEvent::Auth {
49 session_token: session_token.clone(),
50 },
51 NetworkEvent::Fragment { fragment, .. } => WireEvent::Fragment(fragment.clone()),
52 NetworkEvent::StressTest { count, rotate, .. } => WireEvent::StressTest {
53 count: *count,
54 rotate: *rotate,
55 },
56 NetworkEvent::Spawn {
57 entity_type,
58 x,
59 y,
60 rot,
61 ..
62 } => WireEvent::Spawn {
63 entity_type: *entity_type,
64 x: *x,
65 y: *y,
66 rot: *rot,
67 },
68 NetworkEvent::ClearWorld { .. } => WireEvent::ClearWorld,
69 NetworkEvent::StartSession { .. } => WireEvent::StartSession,
70 NetworkEvent::RequestSystemManifest { .. } => WireEvent::RequestSystemManifest,
71 NetworkEvent::GameEvent { event, .. } => WireEvent::GameEvent(event.clone()),
72 NetworkEvent::ReplicationBatch { events, .. } => {
73 let wire_event = WireEvent::ReplicationBatch(events.clone());
74 let encoded = rmp_serde::to_vec(&wire_event)
75 .map_err(|e| EncodeError::Io(std::io::Error::other(e.to_string())))?;
76
77 if encoded.len() > MAX_SAFE_PAYLOAD_SIZE {
78 return Err(EncodeError::BufferOverflow {
79 needed: encoded.len(),
80 available: MAX_SAFE_PAYLOAD_SIZE,
81 });
82 }
83 return Ok(encoded);
84 }
85 NetworkEvent::ClientConnected(_)
86 | NetworkEvent::ClientDisconnected(_)
87 | NetworkEvent::UnreliableMessage { .. }
88 | NetworkEvent::ReliableMessage { .. }
89 | NetworkEvent::Ping { .. }
90 | NetworkEvent::SessionClosed(_)
91 | NetworkEvent::StreamReset(_)
92 | NetworkEvent::Disconnected(_) => {
93 return Err(EncodeError::Io(std::io::Error::other(format!(
94 "Cannot encode local-only variant as wire event: {event:?}"
95 ))));
96 }
97 };
98 rmp_serde::to_vec(&wire_event)
99 .map_err(|e| EncodeError::Io(std::io::Error::other(e.to_string())))
100 }
101
102 pub fn decode_event(
107 &self,
108 data: &[u8],
109 ) -> Result<aetheris_protocol::events::NetworkEvent, EncodeError> {
110 let wire_event: WireEvent = rmp_serde::from_slice(data).map_err(|e| {
111 EncodeError::MalformedPayload {
112 offset: 0, message: e.to_string(),
114 }
115 })?;
116
117 Ok(match wire_event {
118 WireEvent::Ping { tick } => NetworkEvent::Ping {
119 client_id: ClientId(0), tick,
121 },
122 WireEvent::Pong { tick } => NetworkEvent::Pong { tick },
123 WireEvent::Auth { session_token } => NetworkEvent::Auth { session_token },
124 WireEvent::Fragment(fragment) => NetworkEvent::Fragment {
125 client_id: ClientId(0),
126 fragment,
127 },
128 WireEvent::StressTest { count, rotate } => NetworkEvent::StressTest {
129 client_id: ClientId(0), count,
131 rotate,
132 },
133 WireEvent::Spawn {
134 entity_type,
135 x,
136 y,
137 rot,
138 } => NetworkEvent::Spawn {
139 client_id: ClientId(0),
140 entity_type,
141 x,
142 y,
143 rot,
144 },
145 WireEvent::ClearWorld => NetworkEvent::ClearWorld {
146 client_id: ClientId(0),
147 },
148 WireEvent::StartSession => NetworkEvent::StartSession {
149 client_id: ClientId(0),
150 },
151 WireEvent::RequestSystemManifest => NetworkEvent::RequestSystemManifest {
152 client_id: ClientId(0),
153 },
154 WireEvent::GameEvent(event) => NetworkEvent::GameEvent {
155 client_id: ClientId(0),
156 event,
157 },
158 WireEvent::ReplicationBatch(events) => NetworkEvent::ReplicationBatch {
159 client_id: ClientId(0),
160 events,
161 },
162 })
163 }
164}
165
166impl Encoder for SerdeEncoder {
167 fn codec_id(&self) -> u32 {
168 1
169 }
170
171 fn encode_event(&self, event: &NetworkEvent) -> Result<Vec<u8>, EncodeError> {
172 self.encode_event(event)
173 }
174
175 fn decode_event(&self, data: &[u8]) -> Result<NetworkEvent, EncodeError> {
176 self.decode_event(data)
177 }
178
179 fn encode(&self, event: &ReplicationEvent, buffer: &mut [u8]) -> Result<usize, EncodeError> {
180 #[cfg(not(target_arch = "wasm32"))]
181 let start = std::time::Instant::now();
182 let header = PacketHeader {
183 network_id: event.network_id,
184 component_kind: event.component_kind,
185 tick: event.tick,
186 };
187
188 let mut cursor = Cursor::new(buffer);
189 let mut serializer = rmp_serde::Serializer::new(&mut cursor);
190
191 header.serialize(&mut serializer).map_err(|_e| {
192 metrics::counter!("aetheris_encoder_errors_total", "type" => "header_serialize_fail")
193 .increment(1);
194 EncodeError::BufferOverflow {
196 needed: 32, available: cursor.get_ref().len(),
198 }
199 })?;
200
201 let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
202 let payload_len = event.payload.len();
203 let total_needed = header_len + payload_len;
204
205 if total_needed > cursor.get_ref().len() {
206 metrics::counter!("aetheris_encoder_errors_total", "type" => "buffer_overflow")
207 .increment(1);
208 return Err(EncodeError::BufferOverflow {
209 needed: total_needed,
210 available: cursor.get_ref().len(),
211 });
212 }
213
214 let slice = cursor.into_inner();
216 slice[header_len..total_needed].copy_from_slice(&event.payload);
217
218 #[allow(clippy::cast_precision_loss)]
219 metrics::histogram!(
220 "aetheris_encoder_payload_size_bytes",
221 "operation" => "encode"
222 )
223 .record(total_needed as f64);
224
225 #[cfg(not(target_arch = "wasm32"))]
226 metrics::histogram!(
227 "aetheris_encoder_encode_duration_seconds",
228 "kind" => event.component_kind.0.to_string()
229 )
230 .record(start.elapsed().as_secs_f64());
231
232 Ok(total_needed)
233 }
234
235 fn decode(&self, buffer: &[u8]) -> Result<ComponentUpdate, EncodeError> {
236 #[cfg(not(target_arch = "wasm32"))]
237 let start = std::time::Instant::now();
238 let mut cursor = Cursor::new(buffer);
239 let mut deserializer = rmp_serde::Deserializer::new(&mut cursor);
240
241 let header = PacketHeader::deserialize(&mut deserializer).map_err(|e| {
242 metrics::counter!("aetheris_encoder_errors_total", "type" => "malformed_payload")
243 .increment(1);
244 EncodeError::MalformedPayload {
245 offset: usize::try_from(cursor.position()).unwrap_or(usize::MAX),
246 message: e.to_string(),
247 }
248 })?;
249
250 let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
251 let payload = buffer
252 .get(header_len..)
253 .ok_or(EncodeError::MalformedPayload {
254 offset: header_len,
255 message: "Payload slice out of bounds".to_string(),
256 })?
257 .to_vec();
258
259 #[allow(clippy::cast_precision_loss)]
260 metrics::histogram!(
261 "aetheris_encoder_payload_size_bytes",
262 "operation" => "decode"
263 )
264 .record(buffer.len() as f64);
265
266 #[cfg(not(target_arch = "wasm32"))]
267 metrics::histogram!(
268 "aetheris_encoder_decode_duration_seconds",
269 "kind" => header.component_kind.0.to_string()
270 )
271 .record(start.elapsed().as_secs_f64());
272
273 Ok(ComponentUpdate {
274 network_id: header.network_id,
275 component_kind: header.component_kind,
276 payload,
277 tick: header.tick,
278 })
279 }
280
281 fn max_encoded_size(&self) -> usize {
282 aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_roundtrip() {
        let encoder = SerdeEncoder::new();
        let event = ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![1, 2, 3, 4],
            tick: 100,
        };

        let mut buffer = [0u8; 1200];
        let bytes_written = encoder.encode(&event, &mut buffer).unwrap();
        assert!(bytes_written > 0);

        let update = encoder.decode(&buffer[..bytes_written]).unwrap();
        assert_eq!(update.network_id, event.network_id);
        assert_eq!(update.component_kind, event.component_kind);
        assert_eq!(update.tick, event.tick);
        assert_eq!(update.payload, event.payload);
    }

    #[test]
    fn test_fragment_roundtrip() {
        let encoder = SerdeEncoder::new();
        let fragment = aetheris_protocol::events::FragmentedEvent {
            message_id: 123,
            fragment_index: 1,
            total_fragments: 5,
            payload: vec![1, 2, 3],
        };

        let event = NetworkEvent::Fragment {
            client_id: ClientId(0),
            fragment: fragment.clone(),
        };

        let output = encoder.encode_event(&event).unwrap();
        let decoded = encoder.decode_event(&output).unwrap();

        if let NetworkEvent::Fragment {
            client_id: _,
            fragment: decoded_fragment,
        } = decoded
        {
            assert_eq!(decoded_fragment.message_id, fragment.message_id);
            assert_eq!(decoded_fragment.fragment_index, fragment.fragment_index);
            assert_eq!(decoded_fragment.total_fragments, fragment.total_fragments);
            assert_eq!(decoded_fragment.payload, fragment.payload);
        } else {
            panic!("Decoded event is not a Fragment: {decoded:?}");
        }
    }

    /// A 1-byte buffer cannot even hold the header, exercising the
    /// header-serialization failure path.
    #[test]
    fn test_buffer_overflow() {
        let encoder = SerdeEncoder::new();
        let event = ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![1, 2, 3, 4],
            tick: 100,
        };

        let mut small_buffer = [0u8; 1];
        let result = encoder.encode(&event, &mut small_buffer);
        assert!(matches!(result, Err(EncodeError::BufferOverflow { .. })));
    }

    /// The header fits in 64 bytes but header + 128-byte payload does not,
    /// exercising the post-header capacity check.
    #[test]
    fn test_payload_overflow() {
        let encoder = SerdeEncoder::new();
        let event = ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![7; 128],
            tick: 100,
        };

        let mut buffer = [0u8; 64];
        match encoder.encode(&event, &mut buffer) {
            Err(EncodeError::BufferOverflow { needed, available }) => {
                assert_eq!(available, 64);
                assert!(
                    needed > 128,
                    "needed should cover header + payload, got {needed}"
                );
            }
            other => panic!("Expected BufferOverflow, got {other:?}"),
        }
    }

    #[test]
    fn test_malformed_payload() {
        let encoder = SerdeEncoder::new();
        let garbage = [0xff, 0xff, 0xff, 0xff];
        let result = encoder.decode(&garbage);
        if let Err(EncodeError::MalformedPayload { message, .. }) = result {
            assert!(!message.is_empty());
        } else {
            panic!("Expected MalformedPayload error, got {result:?}");
        }
    }

    proptest! {
        #[test]
        fn test_fuzz_decode(ref bytes in any::<Vec<u8>>()) {
            // Must never panic on arbitrary input; the result itself is irrelevant.
            let encoder = SerdeEncoder::new();
            let _ = encoder.decode(bytes);
        }

        #[test]
        fn test_fuzz_roundtrip(
            nid in any::<u64>(),
            kind in any::<u16>(),
            tick in any::<u64>(),
            ref payload in any::<Vec<u8>>()
        ) {
            let encoder = SerdeEncoder::new();
            let event = ReplicationEvent {
                network_id: NetworkId(nid),
                component_kind: ComponentKind(kind),
                payload: payload.clone(),
                tick,
            };

            // 2048 bytes of headroom dwarfs any MessagePack header, so encode
            // must succeed; a failure here is a regression, not a skip.
            let mut buffer = vec![0u8; 2048 + payload.len()];
            let written = encoder
                .encode(&event, &mut buffer)
                .expect("Encode failed despite ample buffer headroom");
            let update = encoder.decode(&buffer[..written])
                .expect("Round-trip decode failed during fuzzed test");
            assert_eq!(update.network_id, event.network_id);
            assert_eq!(update.component_kind, event.component_kind);
            assert_eq!(update.tick, event.tick);
            assert_eq!(update.payload, event.payload);
        }
    }

    #[test]
    fn test_disconnected_not_serializable() {
        let encoder = SerdeEncoder::new();
        let event = NetworkEvent::Disconnected(ClientId(42));

        let result = encoder.encode_event(&event);
        assert!(result.is_err());
        if let Err(EncodeError::Io(e)) = result {
            assert!(e.to_string().contains("Cannot encode local-only variant"));
        } else {
            panic!("Expected EncodeError::Io with local-only message, got {result:?}");
        }
    }

    #[test]
    fn test_game_event_roundtrip() {
        use aetheris_protocol::events::GameEvent;

        let encoder = SerdeEncoder::new();
        let game_event = GameEvent::AsteroidDepleted {
            network_id: NetworkId(123),
        };
        let event = NetworkEvent::GameEvent {
            client_id: ClientId(1),
            event: game_event.clone(),
        };

        let output = encoder.encode_event(&event).unwrap();
        let decoded = encoder.decode_event(&output).unwrap();

        if let NetworkEvent::GameEvent {
            client_id,
            event: decoded_event,
        } = decoded
        {
            assert_eq!(
                client_id,
                ClientId(0),
                "Wire decoding should default client_id to 0"
            );
            match decoded_event {
                GameEvent::AsteroidDepleted { network_id } => {
                    assert_eq!(network_id, NetworkId(123));
                }
                GameEvent::Possession { .. } | GameEvent::SystemManifest { .. } => {
                    panic!("Unexpected event type in roundtrip test");
                }
            }
        } else {
            panic!("Decoded event is not a GameEvent: {decoded:?}");
        }
    }
}