1use std::io::Cursor;
4
5use serde::{Deserialize, Serialize};
6
7use aetheris_protocol::error::EncodeError;
8use aetheris_protocol::events::{ComponentUpdate, NetworkEvent, ReplicationEvent, WireEvent};
9use aetheris_protocol::traits::Encoder;
10use aetheris_protocol::types::{ClientId, ComponentKind, NetworkId};
11
12#[derive(Debug, Serialize, Deserialize)]
16struct PacketHeader {
17 network_id: NetworkId,
18 component_kind: ComponentKind,
19 tick: u64,
20}
21
/// Stateless encoder that uses MessagePack (`rmp_serde`) for replication
/// headers and wire events.
#[derive(Default, Debug)]
pub struct SerdeEncoder;
28
29impl SerdeEncoder {
30 #[must_use]
32 pub fn new() -> Self {
33 Self
34 }
35
36 pub fn encode_event(
41 &self,
42 event: &aetheris_protocol::events::NetworkEvent,
43 ) -> Result<Vec<u8>, EncodeError> {
44 let wire_event = match event {
45 NetworkEvent::Ping { tick, .. } => WireEvent::Ping { tick: *tick },
46 NetworkEvent::Pong { tick } => WireEvent::Pong { tick: *tick },
47 NetworkEvent::Auth { session_token } => WireEvent::Auth {
48 session_token: session_token.clone(),
49 },
50 NetworkEvent::Fragment { fragment, .. } => WireEvent::Fragment(fragment.clone()),
51 NetworkEvent::StressTest { count, rotate, .. } => WireEvent::StressTest {
52 count: *count,
53 rotate: *rotate,
54 },
55 NetworkEvent::Spawn {
56 entity_type,
57 x,
58 y,
59 rot,
60 ..
61 } => WireEvent::Spawn {
62 entity_type: *entity_type,
63 x: *x,
64 y: *y,
65 rot: *rot,
66 },
67 NetworkEvent::ClearWorld { .. } => WireEvent::ClearWorld,
68 NetworkEvent::ClientConnected(_)
69 | NetworkEvent::ClientDisconnected(_)
70 | NetworkEvent::UnreliableMessage { .. }
71 | NetworkEvent::ReliableMessage { .. }
72 | NetworkEvent::SessionClosed(_)
73 | NetworkEvent::StreamReset(_) => {
74 return Err(EncodeError::Io(std::io::Error::other(format!(
75 "Cannot encode local-only variant as wire event: {event:?}"
76 ))));
77 }
78 };
79 rmp_serde::to_vec(&wire_event)
80 .map_err(|e| EncodeError::Io(std::io::Error::other(e.to_string())))
81 }
82
83 pub fn decode_event(
88 &self,
89 data: &[u8],
90 ) -> Result<aetheris_protocol::events::NetworkEvent, EncodeError> {
91 let wire_event: WireEvent = rmp_serde::from_slice(data).map_err(|e| {
92 EncodeError::MalformedPayload {
93 offset: 0, message: e.to_string(),
95 }
96 })?;
97
98 Ok(match wire_event {
99 WireEvent::Ping { tick } => NetworkEvent::Ping {
100 client_id: ClientId(0), tick,
102 },
103 WireEvent::Pong { tick } => NetworkEvent::Pong { tick },
104 WireEvent::Auth { session_token } => NetworkEvent::Auth { session_token },
105 WireEvent::Fragment(fragment) => NetworkEvent::Fragment {
106 client_id: ClientId(0),
107 fragment,
108 },
109 WireEvent::StressTest { count, rotate } => NetworkEvent::StressTest {
110 client_id: ClientId(0), count,
112 rotate,
113 },
114 WireEvent::Spawn {
115 entity_type,
116 x,
117 y,
118 rot,
119 } => NetworkEvent::Spawn {
120 client_id: ClientId(0),
121 entity_type,
122 x,
123 y,
124 rot,
125 },
126 WireEvent::ClearWorld => NetworkEvent::ClearWorld {
127 client_id: ClientId(0),
128 },
129 })
130 }
131}
132
133impl Encoder for SerdeEncoder {
134 fn encode_event(&self, event: &NetworkEvent) -> Result<Vec<u8>, EncodeError> {
135 self.encode_event(event)
136 }
137
138 fn decode_event(&self, data: &[u8]) -> Result<NetworkEvent, EncodeError> {
139 self.decode_event(data)
140 }
141
142 fn encode(&self, event: &ReplicationEvent, buffer: &mut [u8]) -> Result<usize, EncodeError> {
143 #[cfg(not(target_arch = "wasm32"))]
144 let start = std::time::Instant::now();
145 let header = PacketHeader {
146 network_id: event.network_id,
147 component_kind: event.component_kind,
148 tick: event.tick,
149 };
150
151 let mut cursor = Cursor::new(buffer);
152 let mut serializer = rmp_serde::Serializer::new(&mut cursor);
153
154 header.serialize(&mut serializer).map_err(|_e| {
155 metrics::counter!("aetheris_encoder_errors_total", "type" => "header_serialize_fail")
156 .increment(1);
157 EncodeError::BufferOverflow {
159 needed: 32, available: cursor.get_ref().len(),
161 }
162 })?;
163
164 let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
165 let payload_len = event.payload.len();
166 let total_needed = header_len + payload_len;
167
168 if total_needed > self.max_encoded_size() {
169 return Err(EncodeError::BufferOverflow {
170 needed: total_needed,
171 available: self.max_encoded_size(),
172 });
173 }
174
175 if total_needed > cursor.get_ref().len() {
176 metrics::counter!("aetheris_encoder_errors_total", "type" => "buffer_overflow")
177 .increment(1);
178 return Err(EncodeError::BufferOverflow {
179 needed: total_needed,
180 available: cursor.get_ref().len(),
181 });
182 }
183
184 let slice = cursor.into_inner();
186 slice[header_len..total_needed].copy_from_slice(&event.payload);
187
188 #[allow(clippy::cast_precision_loss)]
189 metrics::histogram!(
190 "aetheris_encoder_payload_size_bytes",
191 "operation" => "encode"
192 )
193 .record(total_needed as f64);
194
195 #[cfg(not(target_arch = "wasm32"))]
196 metrics::histogram!(
197 "aetheris_encoder_encode_duration_seconds",
198 "kind" => event.component_kind.0.to_string()
199 )
200 .record(start.elapsed().as_secs_f64());
201
202 Ok(total_needed)
203 }
204
205 fn decode(&self, buffer: &[u8]) -> Result<ComponentUpdate, EncodeError> {
206 #[cfg(not(target_arch = "wasm32"))]
207 let start = std::time::Instant::now();
208 let mut cursor = Cursor::new(buffer);
209 let mut deserializer = rmp_serde::Deserializer::new(&mut cursor);
210
211 let header = PacketHeader::deserialize(&mut deserializer).map_err(|e| {
212 metrics::counter!("aetheris_encoder_errors_total", "type" => "malformed_payload")
213 .increment(1);
214 EncodeError::MalformedPayload {
215 offset: usize::try_from(cursor.position()).unwrap_or(usize::MAX),
216 message: e.to_string(),
217 }
218 })?;
219
220 let header_len = usize::try_from(cursor.position()).unwrap_or(usize::MAX);
221 let payload = buffer
222 .get(header_len..)
223 .ok_or(EncodeError::MalformedPayload {
224 offset: header_len,
225 message: "Payload slice out of bounds".to_string(),
226 })?
227 .to_vec();
228
229 #[allow(clippy::cast_precision_loss)]
230 metrics::histogram!(
231 "aetheris_encoder_payload_size_bytes",
232 "operation" => "decode"
233 )
234 .record(buffer.len() as f64);
235
236 #[cfg(not(target_arch = "wasm32"))]
237 metrics::histogram!(
238 "aetheris_encoder_decode_duration_seconds",
239 "kind" => header.component_kind.0.to_string()
240 )
241 .record(start.elapsed().as_secs_f64());
242
243 Ok(ComponentUpdate {
244 network_id: header.network_id,
245 component_kind: header.component_kind,
246 payload,
247 tick: header.tick,
248 })
249 }
250
251 fn max_encoded_size(&self) -> usize {
252 aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_roundtrip() {
        let encoder = SerdeEncoder::new();
        let event = ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![1, 2, 3, 4],
            tick: 100,
        };

        let mut buffer = [0u8; 1200];
        let bytes_written = encoder.encode(&event, &mut buffer).unwrap();
        assert!(bytes_written > 0);

        let update = encoder.decode(&buffer[..bytes_written]).unwrap();
        assert_eq!(update.network_id, event.network_id);
        assert_eq!(update.component_kind, event.component_kind);
        assert_eq!(update.tick, event.tick);
        assert_eq!(update.payload, event.payload);
    }

    #[test]
    fn test_empty_payload_roundtrip() {
        let encoder = SerdeEncoder::new();
        let event = ReplicationEvent {
            network_id: NetworkId(7),
            component_kind: ComponentKind(3),
            payload: vec![],
            tick: 1,
        };

        let mut buffer = [0u8; 64];
        let bytes_written = encoder.encode(&event, &mut buffer).unwrap();
        let update = encoder.decode(&buffer[..bytes_written]).unwrap();
        assert_eq!(update.network_id, event.network_id);
        assert_eq!(update.tick, event.tick);
        assert!(update.payload.is_empty());
    }

    #[test]
    fn test_fragment_roundtrip() {
        let encoder = SerdeEncoder::new();
        let fragment = aetheris_protocol::events::FragmentedEvent {
            message_id: 123,
            fragment_index: 1,
            total_fragments: 5,
            payload: vec![1, 2, 3],
        };

        let event = NetworkEvent::Fragment {
            client_id: aetheris_protocol::types::ClientId(0),
            fragment: fragment.clone(),
        };

        let output = encoder.encode_event(&event).unwrap();
        let decoded = encoder.decode_event(&output).unwrap();

        if let NetworkEvent::Fragment {
            client_id: _,
            fragment: decoded_fragment,
        } = decoded
        {
            assert_eq!(decoded_fragment.message_id, fragment.message_id);
            assert_eq!(decoded_fragment.fragment_index, fragment.fragment_index);
            assert_eq!(decoded_fragment.total_fragments, fragment.total_fragments);
            assert_eq!(decoded_fragment.payload, fragment.payload);
        } else {
            panic!("Decoded event is not a Fragment: {decoded:?}");
        }
    }

    #[test]
    fn test_pong_roundtrip() {
        let encoder = SerdeEncoder::new();
        let output = encoder.encode_event(&NetworkEvent::Pong { tick: 7 }).unwrap();
        match encoder.decode_event(&output).unwrap() {
            NetworkEvent::Pong { tick } => assert_eq!(tick, 7),
            other => panic!("Decoded event is not a Pong: {other:?}"),
        }
    }

    #[test]
    fn test_clear_world_roundtrip() {
        let encoder = SerdeEncoder::new();
        let event = NetworkEvent::ClearWorld {
            client_id: ClientId(0),
        };
        let output = encoder.encode_event(&event).unwrap();
        let decoded = encoder.decode_event(&output).unwrap();
        assert!(
            matches!(decoded, NetworkEvent::ClearWorld { .. }),
            "Decoded event is not ClearWorld: {decoded:?}"
        );
    }

    #[test]
    fn test_buffer_overflow() {
        let encoder = SerdeEncoder::new();
        let event = ReplicationEvent {
            network_id: NetworkId(42),
            component_kind: ComponentKind(1),
            payload: vec![1, 2, 3, 4],
            tick: 100,
        };

        let mut small_buffer = [0u8; 1];
        let result = encoder.encode(&event, &mut small_buffer);
        assert!(matches!(result, Err(EncodeError::BufferOverflow { .. })));
    }

    #[test]
    fn test_payload_exceeding_max_encoded_size_is_rejected() {
        let encoder = SerdeEncoder::new();
        let max = encoder.max_encoded_size();
        let event = ReplicationEvent {
            network_id: NetworkId(1),
            component_kind: ComponentKind(1),
            payload: vec![0u8; max + 1],
            tick: 1,
        };

        // The caller's buffer is large enough, so only the max-size limit can fail.
        let mut buffer = vec![0u8; max + 64];
        let result = encoder.encode(&event, &mut buffer);
        assert!(matches!(result, Err(EncodeError::BufferOverflow { .. })));
    }

    #[test]
    fn test_malformed_payload() {
        let encoder = SerdeEncoder::new();
        let garbage = [0xff, 0xff, 0xff, 0xff];
        let result = encoder.decode(&garbage);
        if let Err(EncodeError::MalformedPayload { message, .. }) = result {
            assert!(!message.is_empty());
        } else {
            panic!("Expected MalformedPayload error, got {result:?}");
        }
    }

    proptest! {
        #[test]
        fn test_fuzz_decode(ref bytes in any::<Vec<u8>>()) {
            let encoder = SerdeEncoder::new();
            let _ = encoder.decode(bytes);
        }

        #[test]
        fn test_fuzz_roundtrip(
            nid in any::<u64>(),
            kind in any::<u16>(),
            tick in any::<u64>(),
            ref payload in any::<Vec<u8>>()
        ) {
            let encoder = SerdeEncoder::new();
            let event = ReplicationEvent {
                network_id: NetworkId(nid),
                component_kind: ComponentKind(kind),
                payload: payload.clone(),
                tick,
            };

            let mut buffer = vec![0u8; 2048 + payload.len()];
            if let Ok(written) = encoder.encode(&event, &mut buffer) {
                let update = encoder.decode(&buffer[..written])
                    .expect("Round-trip decode failed during fuzzed test");
                assert_eq!(update.network_id, event.network_id);
                assert_eq!(update.component_kind, event.component_kind);
                assert_eq!(update.tick, event.tick);
                assert_eq!(update.payload, event.payload);
            }
        }
    }
}