1use std::io::{self, Cursor, Write};
2
3use crate::crypto::{dencrypted_in_place, encrypt_in_place};
4use crate::replay_protection::ReplayProtection;
5use crate::token::ConnectToken;
6use crate::{
7 serialize::*, NetcodeError, NETCODE_CHALLENGE_TOKEN_BYTES, NETCODE_CONNECT_TOKEN_PRIVATE_BYTES, NETCODE_CONNECT_TOKEN_XNONCE_BYTES,
8 NETCODE_KEY_BYTES, NETCODE_MAC_BYTES,
9};
10use crate::{NETCODE_USER_DATA_BYTES, NETCODE_VERSION_INFO};
11
/// Wire identifier of each packet kind.
///
/// The discriminant is the value returned by [`Packet::id`]; it is stored in the low four
/// bits of an encoded packet's prefix byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum PacketType {
    ConnectionRequest = 0,
    ConnectionDenied = 1,
    Challenge = 2,
    Response = 3,
    KeepAlive = 4,
    Payload = 5,
    Disconnect = 6,
}
23
/// A netcode packet, either about to be encoded or freshly decoded.
///
/// `Packet::encode` writes `ConnectionRequest` unencrypted. Every other variant is written as
/// prefix byte + sequence number + encrypted body + MAC.
#[derive(Debug, PartialEq, Eq)]
// `ConnectionRequest` stores the connect token's private data inline as a fixed-size array,
// which makes it much larger than every other variant.
// NOTE(review): presumably kept unboxed to avoid a heap allocation per packet — confirm
// before boxing it.
#[allow(clippy::large_enum_variant)]
pub enum Packet<'a> {
    /// Client request to connect; sent in plaintext.
    ConnectionRequest {
        /// Protocol version string; `connection_request_from_token` fills it from `NETCODE_VERSION_INFO`.
        version_info: [u8; 13],
        protocol_id: u64,
        expire_timestamp: u64,
        /// Nonce copied from the connect token.
        xnonce: [u8; NETCODE_CONNECT_TOKEN_XNONCE_BYTES],
        /// The connect token's private data, copied verbatim.
        data: [u8; NETCODE_CONNECT_TOKEN_PRIVATE_BYTES],
    },
    ConnectionDenied,
    /// Carries an encrypted [`ChallengeToken`] (built by `Packet::generate_challenge`).
    Challenge {
        token_sequence: u64,
        token_data: [u8; NETCODE_CHALLENGE_TOKEN_BYTES],
    },
    /// Same layout as `Challenge`; presumably the client echoes the challenge token back.
    Response {
        token_sequence: u64,
        token_data: [u8; NETCODE_CHALLENGE_TOKEN_BYTES],
    },
    KeepAlive {
        client_index: u32,
        max_clients: u32,
    },
    /// Application data. When produced by `Packet::decode`, it borrows the decrypted bytes of
    /// the input buffer.
    Payload(&'a [u8]),
    Disconnect,
}
50
/// Plaintext contents of the token carried (encrypted) in [`Packet::Challenge`].
///
/// Serialized as `client_id` (8 bytes, little-endian) followed by `user_data`.
#[derive(Debug, PartialEq, Eq)]
pub struct ChallengeToken {
    pub client_id: u64,
    // Must stay the same length as `NETCODE_USER_DATA_BYTES`: `ChallengeToken::new` copies a
    // `[u8; NETCODE_USER_DATA_BYTES]` directly into this field.
    pub user_data: [u8; 256],
}
56
57impl PacketType {
58 fn from_u8(value: u8) -> Result<Self, NetcodeError> {
59 use PacketType::*;
60
61 let packet_type = match value {
62 0 => ConnectionRequest,
63 1 => ConnectionDenied,
64 2 => Challenge,
65 3 => Response,
66 4 => KeepAlive,
67 5 => Payload,
68 6 => Disconnect,
69 _ => return Err(NetcodeError::InvalidPacketType),
70 };
71 Ok(packet_type)
72 }
73
74 fn apply_replay_protection(&self) -> bool {
75 use PacketType::*;
76
77 matches!(self, KeepAlive | Payload | Disconnect)
78 }
79}
80
81impl<'a> Packet<'a> {
82 pub fn packet_type(&self) -> PacketType {
83 match self {
84 Packet::ConnectionRequest { .. } => PacketType::ConnectionRequest,
85 Packet::ConnectionDenied => PacketType::ConnectionDenied,
86 Packet::Challenge { .. } => PacketType::Challenge,
87 Packet::Response { .. } => PacketType::Response,
88 Packet::KeepAlive { .. } => PacketType::KeepAlive,
89 Packet::Payload { .. } => PacketType::Payload,
90 Packet::Disconnect => PacketType::Disconnect,
91 }
92 }
93
94 pub fn id(&self) -> u8 {
95 self.packet_type() as u8
96 }
97
98 pub fn connection_request_from_token(connect_token: &ConnectToken) -> Self {
99 Packet::ConnectionRequest {
100 xnonce: connect_token.xnonce,
101 version_info: *NETCODE_VERSION_INFO,
102 protocol_id: connect_token.protocol_id,
103 expire_timestamp: connect_token.expire_timestamp,
104 data: connect_token.private_data,
105 }
106 }
107
108 pub fn generate_challenge(
109 client_id: u64,
110 user_data: &[u8; NETCODE_USER_DATA_BYTES],
111 challenge_sequence: u64,
112 challenge_key: &[u8; NETCODE_KEY_BYTES],
113 ) -> Result<Self, NetcodeError> {
114 let token = ChallengeToken::new(client_id, user_data);
115 let mut buffer = [0u8; NETCODE_CHALLENGE_TOKEN_BYTES];
116 token.write(&mut Cursor::new(&mut buffer[..]))?;
117 encrypt_in_place(&mut buffer, challenge_sequence, challenge_key, b"")?;
118
119 Ok(Packet::Challenge {
120 token_sequence: challenge_sequence,
121 token_data: buffer,
122 })
123 }
124
125 fn write(&self, writer: &mut impl io::Write) -> Result<(), io::Error> {
126 match self {
127 Packet::ConnectionRequest {
128 version_info,
129 protocol_id,
130 expire_timestamp,
131 xnonce,
132 data,
133 } => {
134 writer.write_all(version_info)?;
135 writer.write_all(&protocol_id.to_le_bytes())?;
136 writer.write_all(&expire_timestamp.to_le_bytes())?;
137 writer.write_all(xnonce)?;
138 writer.write_all(data)?;
139 }
140 Packet::Challenge {
141 token_data,
142 token_sequence,
143 }
144 | Packet::Response {
145 token_data,
146 token_sequence,
147 } => {
148 writer.write_all(&token_sequence.to_le_bytes())?;
149 writer.write_all(token_data)?;
150 }
151 Packet::KeepAlive { max_clients, client_index } => {
152 writer.write_all(&client_index.to_le_bytes())?;
153 writer.write_all(&max_clients.to_le_bytes())?;
154 }
155 Packet::Payload(p) => {
156 writer.write_all(p)?;
157 }
158 Packet::ConnectionDenied | Packet::Disconnect => {}
159 }
160
161 Ok(())
162 }
163
164 fn read(packet_type: PacketType, src: &'a [u8]) -> Result<Self, io::Error> {
165 if matches!(packet_type, PacketType::Payload) {
166 return Ok(Packet::Payload(src));
167 }
168
169 let src = &mut Cursor::new(src);
170
171 match packet_type {
172 PacketType::ConnectionRequest => {
173 let version_info = read_bytes(src)?;
174 let protocol_id = read_u64(src)?;
175 let expire_timestamp = read_u64(src)?;
176 let xnonce = read_bytes(src)?;
177 let token_data = read_bytes(src)?;
178
179 Ok(Packet::ConnectionRequest {
180 version_info,
181 protocol_id,
182 expire_timestamp,
183 xnonce,
184 data: token_data,
185 })
186 }
187 PacketType::Challenge => {
188 let token_sequence = read_u64(src)?;
189 let token_data = read_bytes(src)?;
190
191 Ok(Packet::Challenge {
192 token_data,
193 token_sequence,
194 })
195 }
196 PacketType::Response => {
197 let token_sequence = read_u64(src)?;
198 let token_data = read_bytes(src)?;
199
200 Ok(Packet::Response {
201 token_data,
202 token_sequence,
203 })
204 }
205 PacketType::KeepAlive => {
206 let client_index = read_u32(src)?;
207 let max_clients = read_u32(src)?;
208
209 Ok(Packet::KeepAlive { client_index, max_clients })
210 }
211 PacketType::ConnectionDenied => Ok(Packet::ConnectionDenied),
212 PacketType::Disconnect => Ok(Packet::Disconnect),
213 PacketType::Payload => unreachable!(),
214 }
215 }
216
217 pub fn encode(&self, buffer: &mut [u8], protocol_id: u64, crypto_info: Option<(u64, &[u8; 32])>) -> Result<usize, NetcodeError> {
218 if matches!(self, Packet::ConnectionRequest { .. }) {
219 let mut writer = io::Cursor::new(buffer);
220 let prefix_byte = encode_prefix(self.id(), 0);
221 writer.write_all(&prefix_byte.to_le_bytes())?;
222
223 self.write(&mut writer)?;
224 Ok(writer.position() as usize)
225 } else if let Some((sequence, private_key)) = crypto_info {
226 let (start, end, aad) = {
227 let mut writer = io::Cursor::new(&mut *buffer);
228 let prefix_byte = {
229 let prefix_byte = encode_prefix(self.id(), sequence);
230 writer.write_all(&prefix_byte.to_le_bytes())?;
231 write_sequence(&mut writer, sequence)?;
232 prefix_byte
233 };
234
235 let start = writer.position() as usize;
236 self.write(&mut writer)?;
237
238 let additional_data = get_additional_data(prefix_byte, protocol_id);
239 (start, writer.position() as usize, additional_data)
240 };
241 if buffer.len() < end + NETCODE_MAC_BYTES {
242 return Err(NetcodeError::IoError(io::Error::new(
243 io::ErrorKind::WriteZero,
244 "buffer too small to encode with encryption tag",
245 )));
246 }
247
248 encrypt_in_place(&mut buffer[start..end + NETCODE_MAC_BYTES], sequence, private_key, &aad)?;
249 Ok(end + NETCODE_MAC_BYTES)
250 } else {
251 Err(NetcodeError::UnavailablePrivateKey)
252 }
253 }
254
255 pub fn decode(
256 mut buffer: &'a mut [u8],
257 protocol_id: u64,
258 private_key: Option<&[u8; 32]>,
259 replay_protection: Option<&mut ReplayProtection>,
260 ) -> Result<(u64, Self), NetcodeError> {
261 if buffer.len() < 2 + NETCODE_MAC_BYTES {
262 return Err(NetcodeError::PacketTooSmall);
263 }
264
265 let prefix_byte = buffer[0];
266 let (packet_type, sequence_len) = decode_prefix(prefix_byte);
267 let packet_type = PacketType::from_u8(packet_type)?;
268
269 if matches!(packet_type, PacketType::ConnectionRequest) {
270 Ok((0, Packet::read(PacketType::ConnectionRequest, &buffer[1..])?))
271 } else if let Some(private_key) = private_key {
272 let (sequence, aad, read_pos) = {
273 let src = &mut io::Cursor::new(&mut buffer);
274 src.set_position(1);
275 let sequence = read_sequence(src, sequence_len)?;
276 let additional_data = get_additional_data(prefix_byte, protocol_id);
277 (sequence, additional_data, src.position() as usize)
278 };
279
280 if let Some(ref replay_protection) = replay_protection {
281 if packet_type.apply_replay_protection() && replay_protection.already_received(sequence) {
282 return Err(NetcodeError::DuplicatedSequence);
283 }
284 }
285
286 dencrypted_in_place(&mut buffer[read_pos..], sequence, private_key, &aad)?;
287
288 if let Some(replay_protection) = replay_protection {
289 if packet_type.apply_replay_protection() {
290 replay_protection.advance_sequence(sequence);
291 }
292 }
293
294 let packet = Packet::read(packet_type, &buffer[read_pos..buffer.len() - NETCODE_MAC_BYTES])?;
295 Ok((sequence, packet))
296 } else {
297 Err(NetcodeError::UnavailablePrivateKey)
298 }
299 }
300}
301
302impl ChallengeToken {
303 pub fn new(client_id: u64, user_data: &[u8; NETCODE_USER_DATA_BYTES]) -> Self {
304 Self {
305 client_id,
306 user_data: *user_data,
307 }
308 }
309
310 fn read(src: &mut impl io::Read) -> Result<Self, io::Error> {
311 let client_id = read_u64(src)?;
312 let user_data: [u8; NETCODE_USER_DATA_BYTES] = read_bytes(src)?;
313
314 Ok(Self { client_id, user_data })
315 }
316
317 fn write(&self, out: &mut impl io::Write) -> Result<(), io::Error> {
318 out.write_all(&self.client_id.to_le_bytes())?;
319 out.write_all(&self.user_data)?;
320
321 Ok(())
322 }
323
324 pub fn decode(
325 token_data: [u8; NETCODE_CHALLENGE_TOKEN_BYTES],
326 token_sequence: u64,
327 challenge_key: &[u8; NETCODE_KEY_BYTES],
328 ) -> Result<ChallengeToken, NetcodeError> {
329 let mut decoded = [0u8; NETCODE_CHALLENGE_TOKEN_BYTES];
330 decoded.copy_from_slice(&token_data);
331 dencrypted_in_place(&mut decoded, token_sequence, challenge_key, b"")?;
332
333 Ok(ChallengeToken::read(&mut Cursor::new(&mut decoded))?)
334 }
335}
336
337fn get_additional_data(prefix: u8, protocol_id: u64) -> [u8; 13 + 8 + 1] {
338 let mut buffer = [0; 13 + 8 + 1];
339 buffer[..13].copy_from_slice(NETCODE_VERSION_INFO);
340 buffer[13..21].copy_from_slice(&protocol_id.to_le_bytes());
341 buffer[21] = prefix;
342
343 buffer
344}
345
/// Splits a prefix byte into `(packet_type_id, sequence_len)`.
///
/// The low nibble is the packet-type id. The high nibble is the number of sequence-number
/// bytes that follow the prefix.
fn decode_prefix(value: u8) -> (u8, usize) {
    let packet_type = value & 0x0F;
    let sequence_len = usize::from(value >> 4);
    (packet_type, sequence_len)
}
349
350fn encode_prefix(value: u8, sequence: u64) -> u8 {
351 value | ((sequence_bytes_required(sequence) as u8) << 4)
352}
353
/// Number of low-order bytes needed to hold `sequence`.
///
/// Returns 0 for a sequence of 0 and at most 8.
fn sequence_bytes_required(sequence: u64) -> usize {
    // Bit length of the value, rounded up to whole bytes.
    let significant_bits = 64 - sequence.leading_zeros();
    ((significant_bits + 7) / 8) as usize
}
366
367fn write_sequence(out: &mut impl io::Write, seq: u64) -> Result<usize, io::Error> {
368 let len = sequence_bytes_required(seq);
369 let sequence_scratch = seq.to_le_bytes();
370 out.write(&sequence_scratch[..len])
371}
372
/// Reads a little-endian sequence number stored in its `len` low-order bytes.
///
/// # Errors
/// - `InvalidData` if `len` exceeds 8, since a `u64` has only 8 bytes.
/// - The reader's error (e.g. `UnexpectedEof`) if fewer than `len` bytes are available.
fn read_sequence(source: &mut impl io::Read, len: usize) -> Result<u64, io::Error> {
    let mut seq_scratch = [0u8; 8];
    // `Packet::decode` passes the high nibble of an untrusted prefix byte (0..=15) as `len`.
    // Reject out-of-range lengths instead of panicking on the slice index.
    let significant = seq_scratch.get_mut(..len).ok_or_else(|| {
        io::Error::new(io::ErrorKind::InvalidData, "sequence number longer than 8 bytes")
    })?;
    source.read_exact(significant)?;
    Ok(u64::from_le_bytes(seq_scratch))
}
378
#[cfg(test)]
mod tests {
    // Round-trip tests for body serialization, prefix/sequence encoding, packet encryption
    // and challenge-token encryption.
    use crate::{crypto::generate_random_bytes, NETCODE_MAX_PACKET_BYTES, NETCODE_MAX_PAYLOAD_BYTES};

    use super::*;

    // Plain body write/read round trip (no prefix, no encryption).
    #[test]
    fn connection_request_serialization() {
        let connection_request = Packet::ConnectionRequest {
            xnonce: generate_random_bytes(),
            version_info: [0; 13],
            protocol_id: 1,
            expire_timestamp: 3,
            data: [5; 1024],
        };
        let mut buffer = Vec::new();
        connection_request.write(&mut buffer).unwrap();
        let deserialized = Packet::read(PacketType::ConnectionRequest, &buffer).unwrap();

        assert_eq!(deserialized, connection_request);
    }

    #[test]
    fn connection_challenge_serialization() {
        let connection_challenge = Packet::Challenge {
            token_sequence: 0,
            token_data: [1u8; 300],
        };

        let mut buffer = Vec::new();
        connection_challenge.write(&mut buffer).unwrap();
        let deserialized = Packet::read(PacketType::Challenge, buffer.as_slice()).unwrap();

        assert_eq!(deserialized, connection_challenge);
    }

    #[test]
    fn connection_keep_alive_serialization() {
        let connection_keep_alive = Packet::KeepAlive {
            max_clients: 2,
            client_index: 1,
        };

        let mut buffer = Vec::new();
        connection_keep_alive.write(&mut buffer).unwrap();
        let deserialized = Packet::read(PacketType::KeepAlive, buffer.as_slice()).unwrap();

        assert_eq!(deserialized, connection_keep_alive);
    }

    // The length nibble written by `encode_prefix` must match the bytes `write_sequence`
    // emits, and `read_sequence` must recover the original value.
    #[test]
    fn prefix_sequence() {
        let packet_type = Packet::Disconnect.id();
        let sequence = 99999;

        let mut buffer = vec![];
        write_sequence(&mut buffer, sequence).unwrap();

        let prefix = encode_prefix(packet_type, sequence);
        let (d_packet_type, sequence_len) = decode_prefix(prefix);
        assert_eq!(packet_type, d_packet_type);
        assert_eq!(buffer.len(), sequence_len);

        let d_sequence = read_sequence(&mut buffer.as_slice(), sequence_len).unwrap();

        assert_eq!(sequence, d_sequence);
    }

    // Full encode -> decode round trips through encryption for body-less and payload packets.
    #[test]
    fn encrypt_decrypt_disconnect_packet() {
        let mut buffer = [0u8; NETCODE_MAX_PACKET_BYTES];
        let key = b"an example very very secret key.";
        let packet = Packet::Disconnect;
        let protocol_id = 12;
        let sequence = 1;
        let len = packet.encode(&mut buffer, protocol_id, Some((sequence, key))).unwrap();
        let (d_sequence, d_packet) = Packet::decode(&mut buffer[..len], protocol_id, Some(key), None).unwrap();
        assert_eq!(sequence, d_sequence);
        assert_eq!(packet, d_packet);
    }

    #[test]
    fn encrypt_decrypt_denied_packet() {
        let mut buffer = [0u8; NETCODE_MAX_PACKET_BYTES];
        let key = b"an example very very secret key.";
        let packet = Packet::ConnectionDenied;
        let protocol_id = 12;
        let sequence = 2;
        let len = packet.encode(&mut buffer, protocol_id, Some((sequence, key))).unwrap();
        let (d_sequence, d_packet) = Packet::decode(&mut buffer[..len], protocol_id, Some(key), None).unwrap();
        assert_eq!(sequence, d_sequence);
        assert_eq!(packet, d_packet);
    }

    // Uses the largest payload to check it still fits in a maximum-size packet with the MAC.
    #[test]
    fn encrypt_decrypt_payload_packet() {
        let mut buffer = [0u8; NETCODE_MAX_PACKET_BYTES];
        let payload = vec![7u8; NETCODE_MAX_PAYLOAD_BYTES];
        let key = b"an example very very secret key.";
        let packet = Packet::Payload(&payload);
        let protocol_id = 12;
        let sequence = 2;
        let len = packet.encode(&mut buffer, protocol_id, Some((sequence, key))).unwrap();
        let (d_sequence, d_packet) = Packet::decode(&mut buffer[..len], protocol_id, Some(key), None).unwrap();
        assert_eq!(sequence, d_sequence);
        match d_packet {
            Packet::Payload(ref p) => assert_eq!(&payload, p),
            _ => unreachable!(),
        }
        assert_eq!(packet, d_packet);
    }

    // A challenge generated by the server must decrypt back to the same token.
    #[test]
    fn encrypt_decrypt_challenge_token() {
        let client_id = 0;
        let user_data = generate_random_bytes();
        let challenge_key = generate_random_bytes();
        let challenge_sequence = 1;
        let token = ChallengeToken::new(client_id, &user_data);
        let packet = Packet::generate_challenge(client_id, &user_data, challenge_sequence, &challenge_key).unwrap();

        match packet {
            Packet::Challenge {
                token_data,
                token_sequence,
            } => {
                let decoded = ChallengeToken::decode(token_data, token_sequence, &challenge_key).unwrap();
                assert_eq!(decoded, token);
            }
            _ => unreachable!(),
        }
    }
}