1use std::io::Cursor;
4
5use serde::{Deserialize, Serialize};
6
7use aetheris_protocol::error::EncodeError;
8use aetheris_protocol::events::{ComponentUpdate, NetworkEvent, ReplicationEvent, WireEvent};
9use aetheris_protocol::traits::Encoder;
10use aetheris_protocol::types::{ClientId, ComponentKind, NetworkId};
11
12#[derive(Debug, Serialize, Deserialize)]
16struct PacketHeader {
17 network_id: NetworkId,
18 component_kind: ComponentKind,
19 tick: u64,
20}
21
/// Codec built on `serde` and MessagePack (`rmp_serde`).
///
/// The type is a field-less unit struct: it holds no state, so every value is
/// interchangeable with every other. It therefore derives the cheap common
/// traits (`Clone`, `Copy`, `PartialEq`, `Eq`) in addition to `Debug` and
/// `Default`, which lets callers pass it by value or store it in containers
/// that require those bounds.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct SerdeEncoder;
28
29impl SerdeEncoder {
30 #[must_use]
32 pub fn new() -> Self {
33 Self
34 }
35
36 pub fn encode_event(
41 &self,
42 event: &aetheris_protocol::events::NetworkEvent,
43 ) -> Result<Vec<u8>, EncodeError> {
44 let wire_event = match event {
45 NetworkEvent::Ping { tick, .. } if event.is_wire() => WireEvent::Ping { tick: *tick },
46 NetworkEvent::Pong { tick } => WireEvent::Pong { tick: *tick },
47 NetworkEvent::Auth { session_token } => WireEvent::Auth {
48 session_token: session_token.clone(),
49 },
50 NetworkEvent::Fragment { fragment, .. } => WireEvent::Fragment(fragment.clone()),
51 NetworkEvent::StressTest { count, rotate, .. } => WireEvent::StressTest {
52 count: *count,
53 rotate: *rotate,
54 },
55 NetworkEvent::Spawn {
56 entity_type,
57 x,
58 y,
59 rot,
60 ..
61 } => WireEvent::Spawn {
62 entity_type: *entity_type,
63 x: *x,
64 y: *y,
65 rot: *rot,
66 },
67 NetworkEvent::ClearWorld { .. } => WireEvent::ClearWorld,
68 NetworkEvent::StartSession { .. } => WireEvent::StartSession,
69 NetworkEvent::RequestSystemManifest { .. } => WireEvent::RequestSystemManifest,
70 NetworkEvent::GameEvent { event, .. } => WireEvent::GameEvent(event.clone()),
71 NetworkEvent::ClientConnected(_)
72 | NetworkEvent::ClientDisconnected(_)
73 | NetworkEvent::UnreliableMessage { .. }
74 | NetworkEvent::ReliableMessage { .. }
75 | NetworkEvent::Ping { .. }
76 | NetworkEvent::SessionClosed(_)
77 | NetworkEvent::StreamReset(_)
78 | NetworkEvent::Disconnected(_) => {
79 return Err(EncodeError::Io(std::io::Error::other(format!(
80 "Cannot encode local-only variant as wire event: {event:?}"
81 ))));
82 }
83 };
84 rmp_serde::to_vec(&wire_event)
85 .map_err(|e| EncodeError::Io(std::io::Error::other(e.to_string())))
86 }
87
88 pub fn decode_event(
93 &self,
94 data: &[u8],
95 ) -> Result<aetheris_protocol::events::NetworkEvent, EncodeError> {
96 let wire_event: WireEvent = rmp_serde::from_slice(data).map_err(|e| {
97 EncodeError::MalformedPayload {
98 offset: 0, message: e.to_string(),
100 }
101 })?;
102
103 Ok(match wire_event {
104 WireEvent::Ping { tick } => NetworkEvent::Ping {
105 client_id: ClientId(0), tick,
107 },
108 WireEvent::Pong { tick } => NetworkEvent::Pong { tick },
109 WireEvent::Auth { session_token } => NetworkEvent::Auth { session_token },
110 WireEvent::Fragment(fragment) => NetworkEvent::Fragment {
111 client_id: ClientId(0),
112 fragment,
113 },
114 WireEvent::StressTest { count, rotate } => NetworkEvent::StressTest {
115 client_id: ClientId(0), count,
117 rotate,
118 },
119 WireEvent::Spawn {
120 entity_type,
121 x,
122 y,
123 rot,
124 } => NetworkEvent::Spawn {
125 client_id: ClientId(0),
126 entity_type,
127 x,
128 y,
129 rot,
130 },
131 WireEvent::ClearWorld => NetworkEvent::ClearWorld {
132 client_id: ClientId(0),
133 },
134 WireEvent::StartSession => NetworkEvent::StartSession {
135 client_id: ClientId(0),
136 },
137 WireEvent::RequestSystemManifest => NetworkEvent::RequestSystemManifest {
138 client_id: ClientId(0),
139 },
140 WireEvent::GameEvent(event) => NetworkEvent::GameEvent {
141 client_id: ClientId(0),
142 event,
143 },
144 })
145 }
146}
147
148impl Encoder for SerdeEncoder {
149 fn codec_id(&self) -> u32 {
150 1
151 }
152
153 fn encode_event(&self, event: &NetworkEvent) -> Result<Vec<u8>, EncodeError> {
154 self.encode_event(event)
155 }
156
157 fn decode_event(&self, data: &[u8]) -> Result<NetworkEvent, EncodeError> {
158 self.decode_event(data)
159 }
160
161 fn encode(&self, event: &ReplicationEvent, buffer: &mut [u8]) -> Result<usize, EncodeError> {
162 #[cfg(not(target_arch = "wasm32"))]
163 let start = std::time::Instant::now();
164 let header = PacketHeader {
165 network_id: event.network_id,
166 component_kind: event.component_kind,
167 tick: event.tick,
168 };
169
170 let mut cursor = Cursor::new(buffer);
171 let mut serializer = rmp_serde::Serializer::new(&mut cursor);
172
173 header.serialize(&mut serializer).map_err(|_e| {
174 metrics::counter!("aetheris_encoder_errors_total", "type" => "header_serialize_fail")
175 .increment(1);
176 EncodeError::BufferOverflow {
178 needed: 32, available: cursor.get_ref().len(),
180 }
181 })?;
182
183 let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
184 let payload_len = event.payload.len();
185 let total_needed = header_len + payload_len;
186
187 if total_needed > cursor.get_ref().len() {
188 metrics::counter!("aetheris_encoder_errors_total", "type" => "buffer_overflow")
189 .increment(1);
190 return Err(EncodeError::BufferOverflow {
191 needed: total_needed,
192 available: cursor.get_ref().len(),
193 });
194 }
195
196 let slice = cursor.into_inner();
198 slice[header_len..total_needed].copy_from_slice(&event.payload);
199
200 #[allow(clippy::cast_precision_loss)]
201 metrics::histogram!(
202 "aetheris_encoder_payload_size_bytes",
203 "operation" => "encode"
204 )
205 .record(total_needed as f64);
206
207 #[cfg(not(target_arch = "wasm32"))]
208 metrics::histogram!(
209 "aetheris_encoder_encode_duration_seconds",
210 "kind" => event.component_kind.0.to_string()
211 )
212 .record(start.elapsed().as_secs_f64());
213
214 Ok(total_needed)
215 }
216
217 fn decode(&self, buffer: &[u8]) -> Result<ComponentUpdate, EncodeError> {
218 #[cfg(not(target_arch = "wasm32"))]
219 let start = std::time::Instant::now();
220 let mut cursor = Cursor::new(buffer);
221 let mut deserializer = rmp_serde::Deserializer::new(&mut cursor);
222
223 let header = PacketHeader::deserialize(&mut deserializer).map_err(|e| {
224 metrics::counter!("aetheris_encoder_errors_total", "type" => "malformed_payload")
225 .increment(1);
226 EncodeError::MalformedPayload {
227 offset: usize::try_from(cursor.position()).unwrap_or(usize::MAX),
228 message: e.to_string(),
229 }
230 })?;
231
232 let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
233 let payload = buffer
234 .get(header_len..)
235 .ok_or(EncodeError::MalformedPayload {
236 offset: header_len,
237 message: "Payload slice out of bounds".to_string(),
238 })?
239 .to_vec();
240
241 #[allow(clippy::cast_precision_loss)]
242 metrics::histogram!(
243 "aetheris_encoder_payload_size_bytes",
244 "operation" => "decode"
245 )
246 .record(buffer.len() as f64);
247
248 #[cfg(not(target_arch = "wasm32"))]
249 metrics::histogram!(
250 "aetheris_encoder_decode_duration_seconds",
251 "kind" => header.component_kind.0.to_string()
252 )
253 .record(start.elapsed().as_secs_f64());
254
255 Ok(ComponentUpdate {
256 network_id: header.network_id,
257 component_kind: header.component_kind,
258 payload,
259 tick: header.tick,
260 })
261 }
262
263 fn max_encoded_size(&self) -> usize {
264 aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Small fixed replication event shared by the non-fuzzed tests.
    fn sample_replication_event() -> ReplicationEvent {
        ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![1, 2, 3, 4],
            tick: 100,
        }
    }

    /// Encoding then decoding a replication event yields the same fields.
    #[test]
    fn test_roundtrip() {
        let codec = SerdeEncoder::new();
        let event = sample_replication_event();

        let mut scratch = [0u8; 1200];
        let written = codec.encode(&event, &mut scratch).unwrap();
        assert!(written > 0);

        let update = codec.decode(&scratch[..written]).unwrap();
        assert_eq!(update.network_id, event.network_id);
        assert_eq!(update.component_kind, event.component_kind);
        assert_eq!(update.tick, event.tick);
        assert_eq!(update.payload, event.payload);
    }

    /// A `Fragment` event survives the wire encoding unchanged.
    #[test]
    fn test_fragment_roundtrip() {
        let codec = SerdeEncoder::new();
        let fragment = aetheris_protocol::events::FragmentedEvent {
            message_id: 123,
            fragment_index: 1,
            total_fragments: 5,
            payload: vec![1, 2, 3],
        };
        let event = NetworkEvent::Fragment {
            client_id: ClientId(0),
            fragment: fragment.clone(),
        };

        let bytes = codec.encode_event(&event).unwrap();
        let decoded = codec.decode_event(&bytes).unwrap();

        match decoded {
            NetworkEvent::Fragment {
                fragment: received, ..
            } => {
                assert_eq!(received.message_id, fragment.message_id);
                assert_eq!(received.fragment_index, fragment.fragment_index);
                assert_eq!(received.total_fragments, fragment.total_fragments);
                assert_eq!(received.payload, fragment.payload);
            }
            other => panic!("Decoded event is not a Fragment: {other:?}"),
        }
    }

    /// Encoding into a buffer that is too small reports `BufferOverflow`.
    #[test]
    fn test_buffer_overflow() {
        let codec = SerdeEncoder::new();
        let event = sample_replication_event();

        let mut tiny = [0u8; 1];
        let outcome = codec.encode(&event, &mut tiny);
        assert!(matches!(outcome, Err(EncodeError::BufferOverflow { .. })));
    }

    /// Garbage input is rejected with a non-empty `MalformedPayload` message.
    #[test]
    fn test_malformed_payload() {
        let codec = SerdeEncoder::new();
        let garbage = [0xff, 0xff, 0xff, 0xff];

        match codec.decode(&garbage) {
            Err(EncodeError::MalformedPayload { message, .. }) => {
                assert!(!message.is_empty());
            }
            other => panic!("Expected MalformedPayload error, got {other:?}"),
        }
    }

    proptest! {
        /// Decoding arbitrary bytes must never panic; the result is ignored.
        #[test]
        fn test_fuzz_decode(bytes in any::<Vec<u8>>()) {
            let codec = SerdeEncoder::new();
            let _ = codec.decode(&bytes);
        }

        /// Arbitrary header values and payloads round-trip whenever encoding succeeds.
        #[test]
        fn test_fuzz_roundtrip(
            nid in any::<u64>(),
            kind in any::<u16>(),
            tick in any::<u64>(),
            payload in any::<Vec<u8>>()
        ) {
            let codec = SerdeEncoder::new();
            let capacity = 2048 + payload.len();
            let event = ReplicationEvent {
                network_id: NetworkId(nid),
                component_kind: ComponentKind(kind),
                payload,
                tick,
            };

            let mut scratch = vec![0u8; capacity];
            if let Ok(written) = codec.encode(&event, &mut scratch) {
                let update = codec
                    .decode(&scratch[..written])
                    .expect("Round-trip decode failed during fuzzed test");
                assert_eq!(update.network_id, event.network_id);
                assert_eq!(update.component_kind, event.component_kind);
                assert_eq!(update.tick, event.tick);
                assert_eq!(update.payload, event.payload);
            }
        }
    }

    /// A local-only variant is refused with the dedicated `Io` error message.
    #[test]
    fn test_disconnected_not_serializable() {
        let codec = SerdeEncoder::new();
        let event = NetworkEvent::Disconnected(ClientId(42));

        let outcome = codec.encode_event(&event);
        assert!(outcome.is_err());

        match outcome {
            Err(EncodeError::Io(e)) => {
                assert!(e.to_string().contains("Cannot encode local-only variant"));
            }
            other => panic!("Expected EncodeError::Io with local-only message, got {other:?}"),
        }
    }

    /// A `GameEvent` round-trips, with `client_id` reset to `ClientId(0)` on decode.
    #[test]
    fn test_game_event_roundtrip() {
        use aetheris_protocol::events::GameEvent;

        let codec = SerdeEncoder::new();
        let game_event = GameEvent::AsteroidDepleted {
            network_id: NetworkId(123),
        };
        let event = NetworkEvent::GameEvent {
            // Deliberately non-zero so the reset to 0 on decode is observable.
            client_id: ClientId(1),
            event: game_event.clone(),
        };

        let bytes = codec.encode_event(&event).unwrap();
        let decoded = codec.decode_event(&bytes).unwrap();

        match decoded {
            NetworkEvent::GameEvent {
                client_id,
                event: decoded_event,
            } => {
                assert_eq!(
                    client_id,
                    ClientId(0),
                    "Wire decoding should default client_id to 0"
                );
                match decoded_event {
                    GameEvent::AsteroidDepleted { network_id } => {
                        assert_eq!(network_id, NetworkId(123));
                    }
                    GameEvent::Possession { .. } | GameEvent::SystemManifest { .. } => {
                        panic!("Unexpected event type in roundtrip test");
                    }
                }
            }
            other => panic!("Decoded event is not a GameEvent: {other:?}"),
        }
    }
}