1use std::io::Cursor;
4
5use serde::{Deserialize, Serialize};
6
7use aetheris_protocol::error::EncodeError;
8use aetheris_protocol::events::{ComponentUpdate, NetworkEvent, ReplicationEvent, WireEvent};
9use aetheris_protocol::traits::Encoder;
10use aetheris_protocol::types::{ClientId, ComponentKind, NetworkId};
11
/// Header that `SerdeEncoder::encode` writes (via `rmp_serde`) in front of the raw
/// component payload, and that `SerdeEncoder::decode` reads back.
///
/// NOTE(review): with `rmp_serde`'s default serializer, structs are presumably written
/// positionally (as arrays). If so, field order is part of the wire format and reordering
/// or inserting fields breaks compatibility — confirm before changing.
#[derive(Debug, Serialize, Deserialize)]
struct PacketHeader {
    /// Copied from `ReplicationEvent::network_id` on encode.
    network_id: NetworkId,
    /// Copied from `ReplicationEvent::component_kind` on encode.
    component_kind: ComponentKind,
    /// Copied from `ReplicationEvent::tick` on encode.
    tick: u64,
}
21
/// MessagePack codec built on `rmp_serde`, implementing [`Encoder`].
///
/// The encoder carries no state (it is a unit struct), so it is freely copyable and
/// can be handed to any number of owners without synchronization.
#[derive(Debug, Default, Clone, Copy)]
pub struct SerdeEncoder;
28
29impl SerdeEncoder {
30 #[must_use]
32 pub fn new() -> Self {
33 Self
34 }
35
36 pub fn encode_event(
41 &self,
42 event: &aetheris_protocol::events::NetworkEvent,
43 ) -> Result<Vec<u8>, EncodeError> {
44 let wire_event = match event {
45 NetworkEvent::Ping { tick, .. } if event.is_wire() => WireEvent::Ping { tick: *tick },
46 NetworkEvent::Pong { tick } => WireEvent::Pong { tick: *tick },
47 NetworkEvent::Auth { session_token } => WireEvent::Auth {
48 session_token: session_token.clone(),
49 },
50 NetworkEvent::Fragment { fragment, .. } => WireEvent::Fragment(fragment.clone()),
51 NetworkEvent::StressTest { count, rotate, .. } => WireEvent::StressTest {
52 count: *count,
53 rotate: *rotate,
54 },
55 NetworkEvent::Spawn {
56 entity_type,
57 x,
58 y,
59 rot,
60 ..
61 } => WireEvent::Spawn {
62 entity_type: *entity_type,
63 x: *x,
64 y: *y,
65 rot: *rot,
66 },
67 NetworkEvent::ClearWorld { .. } => WireEvent::ClearWorld,
68 NetworkEvent::GameEvent { event, .. } => WireEvent::GameEvent(event.clone()),
69 NetworkEvent::ClientConnected(_)
70 | NetworkEvent::ClientDisconnected(_)
71 | NetworkEvent::UnreliableMessage { .. }
72 | NetworkEvent::ReliableMessage { .. }
73 | NetworkEvent::Ping { .. }
74 | NetworkEvent::SessionClosed(_)
75 | NetworkEvent::StreamReset(_)
76 | NetworkEvent::Disconnected(_) => {
77 return Err(EncodeError::Io(std::io::Error::other(format!(
78 "Cannot encode local-only variant as wire event: {event:?}"
79 ))));
80 }
81 };
82 rmp_serde::to_vec(&wire_event)
83 .map_err(|e| EncodeError::Io(std::io::Error::other(e.to_string())))
84 }
85
86 pub fn decode_event(
91 &self,
92 data: &[u8],
93 ) -> Result<aetheris_protocol::events::NetworkEvent, EncodeError> {
94 let wire_event: WireEvent = rmp_serde::from_slice(data).map_err(|e| {
95 EncodeError::MalformedPayload {
96 offset: 0, message: e.to_string(),
98 }
99 })?;
100
101 Ok(match wire_event {
102 WireEvent::Ping { tick } => NetworkEvent::Ping {
103 client_id: ClientId(0), tick,
105 },
106 WireEvent::Pong { tick } => NetworkEvent::Pong { tick },
107 WireEvent::Auth { session_token } => NetworkEvent::Auth { session_token },
108 WireEvent::Fragment(fragment) => NetworkEvent::Fragment {
109 client_id: ClientId(0),
110 fragment,
111 },
112 WireEvent::StressTest { count, rotate } => NetworkEvent::StressTest {
113 client_id: ClientId(0), count,
115 rotate,
116 },
117 WireEvent::Spawn {
118 entity_type,
119 x,
120 y,
121 rot,
122 } => NetworkEvent::Spawn {
123 client_id: ClientId(0),
124 entity_type,
125 x,
126 y,
127 rot,
128 },
129 WireEvent::ClearWorld => NetworkEvent::ClearWorld {
130 client_id: ClientId(0),
131 },
132 WireEvent::GameEvent(event) => NetworkEvent::GameEvent {
133 client_id: ClientId(0),
134 event,
135 },
136 })
137 }
138}
139
140impl Encoder for SerdeEncoder {
141 fn codec_id(&self) -> u32 {
142 1
143 }
144
145 fn encode_event(&self, event: &NetworkEvent) -> Result<Vec<u8>, EncodeError> {
146 self.encode_event(event)
147 }
148
149 fn decode_event(&self, data: &[u8]) -> Result<NetworkEvent, EncodeError> {
150 self.decode_event(data)
151 }
152
153 fn encode(&self, event: &ReplicationEvent, buffer: &mut [u8]) -> Result<usize, EncodeError> {
154 #[cfg(not(target_arch = "wasm32"))]
155 let start = std::time::Instant::now();
156 let header = PacketHeader {
157 network_id: event.network_id,
158 component_kind: event.component_kind,
159 tick: event.tick,
160 };
161
162 let mut cursor = Cursor::new(buffer);
163 let mut serializer = rmp_serde::Serializer::new(&mut cursor);
164
165 header.serialize(&mut serializer).map_err(|_e| {
166 metrics::counter!("aetheris_encoder_errors_total", "type" => "header_serialize_fail")
167 .increment(1);
168 EncodeError::BufferOverflow {
170 needed: 32, available: cursor.get_ref().len(),
172 }
173 })?;
174
175 let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
176 let payload_len = event.payload.len();
177 let total_needed = header_len + payload_len;
178
179 if total_needed > cursor.get_ref().len() {
180 metrics::counter!("aetheris_encoder_errors_total", "type" => "buffer_overflow")
181 .increment(1);
182 return Err(EncodeError::BufferOverflow {
183 needed: total_needed,
184 available: cursor.get_ref().len(),
185 });
186 }
187
188 let slice = cursor.into_inner();
190 slice[header_len..total_needed].copy_from_slice(&event.payload);
191
192 #[allow(clippy::cast_precision_loss)]
193 metrics::histogram!(
194 "aetheris_encoder_payload_size_bytes",
195 "operation" => "encode"
196 )
197 .record(total_needed as f64);
198
199 #[cfg(not(target_arch = "wasm32"))]
200 metrics::histogram!(
201 "aetheris_encoder_encode_duration_seconds",
202 "kind" => event.component_kind.0.to_string()
203 )
204 .record(start.elapsed().as_secs_f64());
205
206 Ok(total_needed)
207 }
208
209 fn decode(&self, buffer: &[u8]) -> Result<ComponentUpdate, EncodeError> {
210 #[cfg(not(target_arch = "wasm32"))]
211 let start = std::time::Instant::now();
212 let mut cursor = Cursor::new(buffer);
213 let mut deserializer = rmp_serde::Deserializer::new(&mut cursor);
214
215 let header = PacketHeader::deserialize(&mut deserializer).map_err(|e| {
216 metrics::counter!("aetheris_encoder_errors_total", "type" => "malformed_payload")
217 .increment(1);
218 EncodeError::MalformedPayload {
219 offset: usize::try_from(cursor.position()).unwrap_or(usize::MAX),
220 message: e.to_string(),
221 }
222 })?;
223
224 let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
225 let payload = buffer
226 .get(header_len..)
227 .ok_or(EncodeError::MalformedPayload {
228 offset: header_len,
229 message: "Payload slice out of bounds".to_string(),
230 })?
231 .to_vec();
232
233 #[allow(clippy::cast_precision_loss)]
234 metrics::histogram!(
235 "aetheris_encoder_payload_size_bytes",
236 "operation" => "decode"
237 )
238 .record(buffer.len() as f64);
239
240 #[cfg(not(target_arch = "wasm32"))]
241 metrics::histogram!(
242 "aetheris_encoder_decode_duration_seconds",
243 "kind" => header.component_kind.0.to_string()
244 )
245 .record(start.elapsed().as_secs_f64());
246
247 Ok(ComponentUpdate {
248 network_id: header.network_id,
249 component_kind: header.component_kind,
250 payload,
251 tick: header.tick,
252 })
253 }
254
255 fn max_encoded_size(&self) -> usize {
256 aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_roundtrip() {
        let encoder = SerdeEncoder::new();
        let event = ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![1, 2, 3, 4],
            tick: 100,
        };

        let mut buffer = [0u8; 1200];
        let bytes_written = encoder.encode(&event, &mut buffer).unwrap();
        assert!(bytes_written > 0);

        let update = encoder.decode(&buffer[..bytes_written]).unwrap();
        assert_eq!(update.network_id, event.network_id);
        assert_eq!(update.component_kind, event.component_kind);
        assert_eq!(update.tick, event.tick);
        assert_eq!(update.payload, event.payload);
    }

    #[test]
    fn test_fragment_roundtrip() {
        let encoder = SerdeEncoder::new();
        let fragment = aetheris_protocol::events::FragmentedEvent {
            message_id: 123,
            fragment_index: 1,
            total_fragments: 5,
            payload: vec![1, 2, 3],
        };

        let event = NetworkEvent::Fragment {
            client_id: ClientId(0),
            fragment: fragment.clone(),
        };

        let output = encoder.encode_event(&event).unwrap();
        let decoded = encoder.decode_event(&output).unwrap();

        if let NetworkEvent::Fragment {
            client_id: _,
            fragment: decoded_fragment,
        } = decoded
        {
            assert_eq!(decoded_fragment.message_id, fragment.message_id);
            assert_eq!(decoded_fragment.fragment_index, fragment.fragment_index);
            assert_eq!(decoded_fragment.total_fragments, fragment.total_fragments);
            assert_eq!(decoded_fragment.payload, fragment.payload);
        } else {
            panic!("Decoded event is not a Fragment: {decoded:?}");
        }
    }

    #[test]
    fn test_buffer_overflow() {
        let encoder = SerdeEncoder::new();
        let event = ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![1, 2, 3, 4],
            tick: 100,
        };

        let mut small_buffer = [0u8; 1];
        let result = encoder.encode(&event, &mut small_buffer);
        assert!(matches!(result, Err(EncodeError::BufferOverflow { .. })));
    }

    #[test]
    fn test_malformed_payload() {
        let encoder = SerdeEncoder::new();
        let garbage = [0xff, 0xff, 0xff, 0xff];
        let result = encoder.decode(&garbage);
        if let Err(EncodeError::MalformedPayload { message, .. }) = result {
            assert!(!message.is_empty());
        } else {
            panic!("Expected MalformedPayload error, got {result:?}");
        }
    }

    proptest! {
        // Arbitrary bytes must never panic the decoder; returning an error is fine.
        #[test]
        fn test_fuzz_decode(ref bytes in any::<Vec<u8>>()) {
            let encoder = SerdeEncoder::new();
            let _ = encoder.decode(bytes);
        }

        // Every event must survive an encode/decode round trip when the buffer has room.
        #[test]
        fn test_fuzz_roundtrip(
            nid in any::<u64>(),
            kind in any::<u16>(),
            tick in any::<u64>(),
            ref payload in any::<Vec<u8>>()
        ) {
            let encoder = SerdeEncoder::new();
            let event = ReplicationEvent {
                network_id: NetworkId(nid),
                component_kind: ComponentKind(kind),
                payload: payload.clone(),
                tick,
            };

            // 2048 spare bytes dwarf any header, so a failed encode is a regression and
            // must fail the test rather than silently skip the round-trip check.
            let mut buffer = vec![0u8; 2048 + payload.len()];
            let written = encoder
                .encode(&event, &mut buffer)
                .expect("encode must succeed when the buffer has ample headroom");
            let update = encoder
                .decode(&buffer[..written])
                .expect("Round-trip decode failed during fuzzed test");
            assert_eq!(update.network_id, event.network_id);
            assert_eq!(update.component_kind, event.component_kind);
            assert_eq!(update.tick, event.tick);
            assert_eq!(update.payload, event.payload);
        }
    }

    #[test]
    fn test_disconnected_not_serializable() {
        let encoder = SerdeEncoder::new();
        let event = NetworkEvent::Disconnected(ClientId(42));

        let result = encoder.encode_event(&event);
        assert!(result.is_err());
        if let Err(EncodeError::Io(e)) = result {
            assert!(e.to_string().contains("Cannot encode local-only variant"));
        } else {
            panic!("Expected EncodeError::Io with local-only message, got {result:?}");
        }
    }

    #[test]
    fn test_game_event_roundtrip() {
        use aetheris_protocol::events::GameEvent;

        let encoder = SerdeEncoder::new();
        let game_event = GameEvent::AsteroidDepleted {
            network_id: NetworkId(123),
        };
        let event = NetworkEvent::GameEvent {
            client_id: ClientId(1),
            event: game_event.clone(),
        };

        let output = encoder.encode_event(&event).unwrap();
        let decoded = encoder.decode_event(&output).unwrap();

        if let NetworkEvent::GameEvent {
            client_id,
            event: decoded_event,
        } = decoded
        {
            assert_eq!(
                client_id,
                ClientId(0),
                "Wire decoding should default client_id to 0"
            );
            match decoded_event {
                GameEvent::AsteroidDepleted { network_id } => {
                    assert_eq!(network_id, NetworkId(123));
                }
            }
        } else {
            panic!("Decoded event is not a GameEvent: {decoded:?}");
        }
    }
}