1use std::io::Cursor;
4
5use serde::{Deserialize, Serialize};
6
7use aetheris_protocol::error::EncodeError;
8use aetheris_protocol::events::{ComponentUpdate, NetworkEvent, ReplicationEvent, WireEvent};
9use aetheris_protocol::traits::Encoder;
10use aetheris_protocol::types::{ClientId, ComponentKind, NetworkId};
11
/// Metadata written in front of every replicated component payload.
///
/// `Encoder::encode` serializes this header with `rmp_serde` and then appends
/// the raw payload bytes; `Encoder::decode` reads it back and treats the rest
/// of the buffer as the payload.
///
/// NOTE(review): with `rmp_serde`'s default configuration structs are
/// presumably encoded positionally, so reordering, adding or removing fields
/// would change the wire format — confirm before editing.
#[derive(Debug, Serialize, Deserialize)]
struct PacketHeader {
    /// Copied from `ReplicationEvent::network_id` on encode and returned as
    /// `ComponentUpdate::network_id` on decode.
    network_id: NetworkId,
    /// Copied from `ReplicationEvent::component_kind` on encode and returned as
    /// `ComponentUpdate::component_kind` on decode.
    component_kind: ComponentKind,
    /// Copied from `ReplicationEvent::tick` on encode and returned as
    /// `ComponentUpdate::tick` on decode.
    tick: u64,
}
21
/// Stateless [`Encoder`] implementation backed by `serde` and `rmp_serde`.
///
/// The type has no fields, so `SerdeEncoder::new()`, `SerdeEncoder::default()`
/// and the bare `SerdeEncoder` value are all equivalent.
#[derive(Debug, Default)]
pub struct SerdeEncoder;
28
29impl SerdeEncoder {
30 #[must_use]
32 pub fn new() -> Self {
33 Self
34 }
35
36 pub fn encode_event(
41 &self,
42 event: &aetheris_protocol::events::NetworkEvent,
43 ) -> Result<Vec<u8>, EncodeError> {
44 let wire_event = match event {
45 NetworkEvent::Ping { tick, .. } if event.is_wire() => WireEvent::Ping { tick: *tick },
46 NetworkEvent::Pong { tick } => WireEvent::Pong { tick: *tick },
47 NetworkEvent::Auth { session_token } => WireEvent::Auth {
48 session_token: session_token.clone(),
49 },
50 NetworkEvent::Fragment { fragment, .. } => WireEvent::Fragment(fragment.clone()),
51 NetworkEvent::StressTest { count, rotate, .. } => WireEvent::StressTest {
52 count: *count,
53 rotate: *rotate,
54 },
55 NetworkEvent::Spawn {
56 entity_type,
57 x,
58 y,
59 rot,
60 ..
61 } => WireEvent::Spawn {
62 entity_type: *entity_type,
63 x: *x,
64 y: *y,
65 rot: *rot,
66 },
67 NetworkEvent::ClearWorld { .. } => WireEvent::ClearWorld,
68 NetworkEvent::ClientConnected(_)
69 | NetworkEvent::ClientDisconnected(_)
70 | NetworkEvent::UnreliableMessage { .. }
71 | NetworkEvent::ReliableMessage { .. }
72 | NetworkEvent::Ping { .. }
73 | NetworkEvent::SessionClosed(_)
74 | NetworkEvent::StreamReset(_)
75 | NetworkEvent::Disconnected(_) => {
76 return Err(EncodeError::Io(std::io::Error::other(format!(
77 "Cannot encode local-only variant as wire event: {event:?}"
78 ))));
79 }
80 };
81 rmp_serde::to_vec(&wire_event)
82 .map_err(|e| EncodeError::Io(std::io::Error::other(e.to_string())))
83 }
84
85 pub fn decode_event(
90 &self,
91 data: &[u8],
92 ) -> Result<aetheris_protocol::events::NetworkEvent, EncodeError> {
93 let wire_event: WireEvent = rmp_serde::from_slice(data).map_err(|e| {
94 EncodeError::MalformedPayload {
95 offset: 0, message: e.to_string(),
97 }
98 })?;
99
100 Ok(match wire_event {
101 WireEvent::Ping { tick } => NetworkEvent::Ping {
102 client_id: ClientId(0), tick,
104 },
105 WireEvent::Pong { tick } => NetworkEvent::Pong { tick },
106 WireEvent::Auth { session_token } => NetworkEvent::Auth { session_token },
107 WireEvent::Fragment(fragment) => NetworkEvent::Fragment {
108 client_id: ClientId(0),
109 fragment,
110 },
111 WireEvent::StressTest { count, rotate } => NetworkEvent::StressTest {
112 client_id: ClientId(0), count,
114 rotate,
115 },
116 WireEvent::Spawn {
117 entity_type,
118 x,
119 y,
120 rot,
121 } => NetworkEvent::Spawn {
122 client_id: ClientId(0),
123 entity_type,
124 x,
125 y,
126 rot,
127 },
128 WireEvent::ClearWorld => NetworkEvent::ClearWorld {
129 client_id: ClientId(0),
130 },
131 })
132 }
133}
134
135impl Encoder for SerdeEncoder {
136 fn codec_id(&self) -> u32 {
137 1
138 }
139
140 fn encode_event(&self, event: &NetworkEvent) -> Result<Vec<u8>, EncodeError> {
141 self.encode_event(event)
142 }
143
144 fn decode_event(&self, data: &[u8]) -> Result<NetworkEvent, EncodeError> {
145 self.decode_event(data)
146 }
147
148 fn encode(&self, event: &ReplicationEvent, buffer: &mut [u8]) -> Result<usize, EncodeError> {
149 #[cfg(not(target_arch = "wasm32"))]
150 let start = std::time::Instant::now();
151 let header = PacketHeader {
152 network_id: event.network_id,
153 component_kind: event.component_kind,
154 tick: event.tick,
155 };
156
157 let mut cursor = Cursor::new(buffer);
158 let mut serializer = rmp_serde::Serializer::new(&mut cursor);
159
160 header.serialize(&mut serializer).map_err(|_e| {
161 metrics::counter!("aetheris_encoder_errors_total", "type" => "header_serialize_fail")
162 .increment(1);
163 EncodeError::BufferOverflow {
165 needed: 32, available: cursor.get_ref().len(),
167 }
168 })?;
169
170 let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
171 let payload_len = event.payload.len();
172 let total_needed = header_len + payload_len;
173
174 if total_needed > cursor.get_ref().len() {
175 metrics::counter!("aetheris_encoder_errors_total", "type" => "buffer_overflow")
176 .increment(1);
177 return Err(EncodeError::BufferOverflow {
178 needed: total_needed,
179 available: cursor.get_ref().len(),
180 });
181 }
182
183 let slice = cursor.into_inner();
185 slice[header_len..total_needed].copy_from_slice(&event.payload);
186
187 #[allow(clippy::cast_precision_loss)]
188 metrics::histogram!(
189 "aetheris_encoder_payload_size_bytes",
190 "operation" => "encode"
191 )
192 .record(total_needed as f64);
193
194 #[cfg(not(target_arch = "wasm32"))]
195 metrics::histogram!(
196 "aetheris_encoder_encode_duration_seconds",
197 "kind" => event.component_kind.0.to_string()
198 )
199 .record(start.elapsed().as_secs_f64());
200
201 Ok(total_needed)
202 }
203
204 fn decode(&self, buffer: &[u8]) -> Result<ComponentUpdate, EncodeError> {
205 #[cfg(not(target_arch = "wasm32"))]
206 let start = std::time::Instant::now();
207 let mut cursor = Cursor::new(buffer);
208 let mut deserializer = rmp_serde::Deserializer::new(&mut cursor);
209
210 let header = PacketHeader::deserialize(&mut deserializer).map_err(|e| {
211 metrics::counter!("aetheris_encoder_errors_total", "type" => "malformed_payload")
212 .increment(1);
213 EncodeError::MalformedPayload {
214 offset: usize::try_from(cursor.position()).unwrap_or(usize::MAX),
215 message: e.to_string(),
216 }
217 })?;
218
219 let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
220 let payload = buffer
221 .get(header_len..)
222 .ok_or(EncodeError::MalformedPayload {
223 offset: header_len,
224 message: "Payload slice out of bounds".to_string(),
225 })?
226 .to_vec();
227
228 #[allow(clippy::cast_precision_loss)]
229 metrics::histogram!(
230 "aetheris_encoder_payload_size_bytes",
231 "operation" => "decode"
232 )
233 .record(buffer.len() as f64);
234
235 #[cfg(not(target_arch = "wasm32"))]
236 metrics::histogram!(
237 "aetheris_encoder_decode_duration_seconds",
238 "kind" => header.component_kind.0.to_string()
239 )
240 .record(start.elapsed().as_secs_f64());
241
242 Ok(ComponentUpdate {
243 network_id: header.network_id,
244 component_kind: header.component_kind,
245 payload,
246 tick: header.tick,
247 })
248 }
249
250 fn max_encoded_size(&self) -> usize {
251 aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Small fixed replication event shared by the non-fuzzed component tests.
    fn sample_event() -> ReplicationEvent {
        ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![1, 2, 3, 4],
            tick: 100,
        }
    }

    /// Encoding then decoding a component update returns every field unchanged.
    #[test]
    fn test_roundtrip() {
        let codec = SerdeEncoder::new();
        let original = sample_event();

        let mut scratch = [0u8; 1200];
        let written = codec.encode(&original, &mut scratch).unwrap();
        assert!(written > 0);

        let decoded = codec.decode(&scratch[..written]).unwrap();
        assert_eq!(decoded.network_id, original.network_id);
        assert_eq!(decoded.component_kind, original.component_kind);
        assert_eq!(decoded.tick, original.tick);
        assert_eq!(decoded.payload, original.payload);
    }

    /// A `Fragment` event survives the wire-event round trip field by field.
    #[test]
    fn test_fragment_roundtrip() {
        let codec = SerdeEncoder::new();
        let sent = aetheris_protocol::events::FragmentedEvent {
            message_id: 123,
            fragment_index: 1,
            total_fragments: 5,
            payload: vec![1, 2, 3],
        };

        let wrapped = NetworkEvent::Fragment {
            client_id: aetheris_protocol::types::ClientId(0),
            fragment: sent.clone(),
        };

        let wire_bytes = codec.encode_event(&wrapped).unwrap();
        let decoded = codec.decode_event(&wire_bytes).unwrap();

        match decoded {
            NetworkEvent::Fragment {
                fragment: received, ..
            } => {
                assert_eq!(received.message_id, sent.message_id);
                assert_eq!(received.fragment_index, sent.fragment_index);
                assert_eq!(received.total_fragments, sent.total_fragments);
                assert_eq!(received.payload, sent.payload);
            }
            decoded => panic!("Decoded event is not a Fragment: {decoded:?}"),
        }
    }

    /// A one-byte buffer cannot hold the header, so encoding must report overflow.
    #[test]
    fn test_buffer_overflow() {
        let codec = SerdeEncoder::new();
        let original = sample_event();

        let mut tiny = [0u8; 1];
        let outcome = codec.encode(&original, &mut tiny);
        assert!(matches!(outcome, Err(EncodeError::BufferOverflow { .. })));
    }

    /// Garbage bytes produce `MalformedPayload` with a non-empty message.
    #[test]
    fn test_malformed_payload() {
        let codec = SerdeEncoder::new();
        let junk = [0xff, 0xff, 0xff, 0xff];

        match codec.decode(&junk) {
            Err(EncodeError::MalformedPayload { message, .. }) => {
                assert!(!message.is_empty());
            }
            result => panic!("Expected MalformedPayload error, got {result:?}"),
        }
    }

    proptest! {
        /// Decoding arbitrary bytes may fail but must never panic.
        #[test]
        fn test_fuzz_decode(ref bytes in any::<Vec<u8>>()) {
            let codec = SerdeEncoder::new();
            let _ = codec.decode(bytes);
        }

        /// Any event that encodes successfully decodes back to the same fields.
        #[test]
        fn test_fuzz_roundtrip(
            nid in any::<u64>(),
            kind in any::<u16>(),
            tick in any::<u64>(),
            ref payload in any::<Vec<u8>>()
        ) {
            let codec = SerdeEncoder::new();
            let original = ReplicationEvent {
                network_id: NetworkId(nid),
                component_kind: ComponentKind(kind),
                payload: payload.clone(),
                tick,
            };

            // Leave generous headroom for the header on top of the payload.
            let mut scratch = vec![0u8; 2048 + payload.len()];
            if let Ok(written) = codec.encode(&original, &mut scratch) {
                let decoded = codec
                    .decode(&scratch[..written])
                    .expect("Round-trip decode failed during fuzzed test");
                assert_eq!(decoded.network_id, original.network_id);
                assert_eq!(decoded.component_kind, original.component_kind);
                assert_eq!(decoded.tick, original.tick);
                assert_eq!(decoded.payload, original.payload);
            }
        }
    }

    /// `Disconnected` is local-only and must be rejected with an `Io` error.
    #[test]
    fn test_disconnected_not_serializable() {
        let codec = SerdeEncoder::new();
        let local_only = NetworkEvent::Disconnected(ClientId(42));

        let outcome = codec.encode_event(&local_only);
        assert!(outcome.is_err());

        match outcome {
            Err(EncodeError::Io(io_err)) => {
                assert!(io_err
                    .to_string()
                    .contains("Cannot encode local-only variant"));
            }
            result => panic!("Expected EncodeError::Io with local-only message, got {result:?}"),
        }
    }

    /// Local-only variants are refused while wire variants are accepted.
    #[test]
    fn test_wire_event_exclusivity() {
        let codec = SerdeEncoder::new();

        let connected = NetworkEvent::ClientConnected(ClientId(1));
        assert!(codec.encode_event(&connected).is_err());

        let auth = NetworkEvent::Auth {
            session_token: "token".to_string(),
        };
        assert!(codec.encode_event(&auth).is_ok());
    }
}