1use std::{
2 error::Error,
3 fmt,
4 io::{self, Cursor},
5 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
6 time::Duration,
7};
8
9use crate::{
10 crypto::{dencrypted_in_place_xnonce, encrypt_in_place_xnonce, generate_random_bytes},
11 serialize::*,
12 NetcodeError, NETCODE_ADDITIONAL_DATA_SIZE, NETCODE_ADDRESS_IPV4, NETCODE_ADDRESS_IPV6, NETCODE_ADDRESS_NONE,
13 NETCODE_CONNECT_TOKEN_PRIVATE_BYTES, NETCODE_CONNECT_TOKEN_XNONCE_BYTES, NETCODE_KEY_BYTES, NETCODE_USER_DATA_BYTES,
14 NETCODE_VERSION_INFO,
15};
16use chacha20poly1305::aead::Error as CryptoError;
17
18#[derive(Debug, Clone, PartialEq, Eq)]
22pub struct ConnectToken {
23 pub client_id: u64,
27 pub version_info: [u8; 13],
28 pub protocol_id: u64,
29 pub create_timestamp: u64,
30 pub expire_timestamp: u64,
31 pub xnonce: [u8; NETCODE_CONNECT_TOKEN_XNONCE_BYTES],
32 pub server_addresses: [Option<SocketAddr>; 32],
33 pub client_to_server_key: [u8; NETCODE_KEY_BYTES],
34 pub server_to_client_key: [u8; NETCODE_KEY_BYTES],
35 pub private_data: [u8; NETCODE_CONNECT_TOKEN_PRIVATE_BYTES],
36 pub timeout_seconds: i32,
37}
38
39#[derive(Debug, PartialEq, Eq)]
40pub(crate) struct PrivateConnectToken {
41 pub client_id: u64, pub timeout_seconds: i32, pub server_addresses: [Option<SocketAddr>; 32],
44 pub client_to_server_key: [u8; NETCODE_KEY_BYTES],
45 pub server_to_client_key: [u8; NETCODE_KEY_BYTES],
46 pub user_data: [u8; NETCODE_USER_DATA_BYTES], }
48
/// Errors that can occur while generating or decoding a connect token.
#[derive(Debug)]
pub enum TokenGenerationError {
    /// More than 32 server addresses were supplied.
    MaxHostCount,
    /// Encrypting or decrypting the private token failed.
    CryptoError,
    /// Serializing or deserializing the token failed.
    IoError(io::Error),
    /// No server address was supplied.
    NoServerAddressAvailable,
}
57
58impl From<io::Error> for TokenGenerationError {
59 fn from(inner: io::Error) -> Self {
60 TokenGenerationError::IoError(inner)
61 }
62}
63
64impl From<CryptoError> for TokenGenerationError {
65 fn from(_: CryptoError) -> Self {
66 TokenGenerationError::CryptoError
67 }
68}
69
70impl Error for TokenGenerationError {}
71
72impl fmt::Display for TokenGenerationError {
73 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
74 use TokenGenerationError::*;
75
76 match *self {
77 MaxHostCount => write!(fmt, "connect token can only have 32 server adresses"),
78 CryptoError => write!(fmt, "error while encoding or decoding the connect token"),
79 IoError(ref io_err) => write!(fmt, "{}", io_err),
80 NoServerAddressAvailable => write!(fmt, "connect token must have at least one server address"),
81 }
82 }
83}
84
85impl ConnectToken {
86 #[allow(clippy::too_many_arguments)]
89 pub fn generate(
90 current_time: Duration,
91 protocol_id: u64,
92 expire_seconds: u64,
93 client_id: u64,
94 timeout_seconds: i32,
95 server_addresses: Vec<SocketAddr>,
96 user_data: Option<&[u8; NETCODE_USER_DATA_BYTES]>,
97 private_key: &[u8; NETCODE_KEY_BYTES],
98 ) -> Result<Self, TokenGenerationError> {
99 let expire_timestamp = current_time.as_secs() + expire_seconds;
100
101 let private_connect_token = PrivateConnectToken::generate(client_id, timeout_seconds, server_addresses, user_data)?;
102 let mut private_data = [0u8; NETCODE_CONNECT_TOKEN_PRIVATE_BYTES];
103 let xnonce = generate_random_bytes();
104 private_connect_token.encode(&mut private_data, protocol_id, expire_timestamp, &xnonce, private_key)?;
105
106 Ok(Self {
107 client_id,
108 version_info: *NETCODE_VERSION_INFO,
109 protocol_id,
110 private_data,
111 create_timestamp: current_time.as_secs(),
112 expire_timestamp,
113 xnonce,
114 server_addresses: private_connect_token.server_addresses,
115 client_to_server_key: private_connect_token.client_to_server_key,
116 server_to_client_key: private_connect_token.server_to_client_key,
117 timeout_seconds,
118 })
119 }
120
121 pub fn write(&self, writer: &mut impl io::Write) -> Result<(), io::Error> {
122 writer.write_all(&self.client_id.to_le_bytes())?;
123 writer.write_all(&self.version_info)?;
124 writer.write_all(&self.protocol_id.to_le_bytes())?;
125 writer.write_all(&self.create_timestamp.to_le_bytes())?;
126 writer.write_all(&self.expire_timestamp.to_le_bytes())?;
127 writer.write_all(&self.xnonce)?;
128 writer.write_all(&self.private_data)?;
129 writer.write_all(&self.timeout_seconds.to_le_bytes())?;
130 write_server_adresses(writer, &self.server_addresses)?;
131 writer.write_all(&self.client_to_server_key)?;
132 writer.write_all(&self.server_to_client_key)?;
133
134 Ok(())
135 }
136
137 pub fn read(src: &mut impl io::Read) -> Result<Self, NetcodeError> {
138 let client_id = read_u64(src)?;
139 let version_info: [u8; 13] = read_bytes(src)?;
140 if &version_info != NETCODE_VERSION_INFO {
141 return Err(NetcodeError::InvalidVersion);
142 }
143
144 let protocol_id = read_u64(src)?;
145 let create_timestamp = read_u64(src)?;
146 let expire_timestamp = read_u64(src)?;
147 let xnonce = read_bytes(src)?;
148
149 let private_data: [u8; NETCODE_CONNECT_TOKEN_PRIVATE_BYTES] = read_bytes(src)?;
150 let timeout_seconds = read_i32(src)?;
151 let server_addresses = read_server_addresses(src)?;
152 let client_to_server_key: [u8; NETCODE_KEY_BYTES] = read_bytes(src)?;
153 let server_to_client_key: [u8; NETCODE_KEY_BYTES] = read_bytes(src)?;
154
155 Ok(Self {
156 client_id,
157 version_info,
158 protocol_id,
159 create_timestamp,
160 expire_timestamp,
161 xnonce,
162 private_data,
163 server_addresses,
164 client_to_server_key,
165 server_to_client_key,
166 timeout_seconds,
167 })
168 }
169}
170
171impl PrivateConnectToken {
172 fn generate(
173 client_id: u64,
174 timeout_seconds: i32,
175 server_addresses: Vec<SocketAddr>,
176 user_data: Option<&[u8; NETCODE_USER_DATA_BYTES]>,
177 ) -> Result<Self, TokenGenerationError> {
178 if server_addresses.len() > 32 {
179 return Err(TokenGenerationError::MaxHostCount);
180 }
181 if server_addresses.is_empty() {
182 return Err(TokenGenerationError::NoServerAddressAvailable);
183 }
184
185 let mut server_addresses_arr = [None; 32];
186 for (i, addr) in server_addresses.into_iter().enumerate() {
187 server_addresses_arr[i] = Some(addr);
188 }
189
190 let client_to_server_key = generate_random_bytes();
191 let server_to_client_key = generate_random_bytes();
192
193 let user_data = match user_data {
194 Some(data) => *data,
195 None => generate_random_bytes(),
196 };
197
198 Ok(Self {
199 client_id,
200 timeout_seconds,
201 server_addresses: server_addresses_arr,
202 client_to_server_key,
203 server_to_client_key,
204 user_data,
205 })
206 }
207
208 fn write(&self, writer: &mut impl io::Write) -> Result<(), io::Error> {
209 writer.write_all(&self.client_id.to_le_bytes())?;
210 writer.write_all(&self.timeout_seconds.to_le_bytes())?;
211 write_server_adresses(writer, &self.server_addresses)?;
212 writer.write_all(&self.client_to_server_key)?;
213 writer.write_all(&self.server_to_client_key)?;
214 writer.write_all(&self.user_data)?;
215
216 Ok(())
217 }
218
219 fn read(src: &mut impl io::Read) -> Result<Self, io::Error> {
220 let client_id = read_u64(src)?;
221 let timeout_seconds = read_i32(src)?;
222 let server_addresses = read_server_addresses(src)?;
223 let mut client_to_server_key = [0u8; 32];
224 src.read_exact(&mut client_to_server_key)?;
225
226 let mut server_to_client_key = [0u8; 32];
227 src.read_exact(&mut server_to_client_key)?;
228
229 let mut user_data = [0u8; 256];
230 src.read_exact(&mut user_data)?;
231
232 Ok(Self {
233 client_id,
234 timeout_seconds,
235 server_addresses,
236 client_to_server_key,
237 server_to_client_key,
238 user_data,
239 })
240 }
241
242 pub(crate) fn encode(
243 &self,
244 buffer: &mut [u8; NETCODE_CONNECT_TOKEN_PRIVATE_BYTES],
245 protocol_id: u64,
246 expire_timestamp: u64,
247 xnonce: &[u8; NETCODE_CONNECT_TOKEN_XNONCE_BYTES],
248 private_key: &[u8; NETCODE_KEY_BYTES],
249 ) -> Result<(), TokenGenerationError> {
250 let aad = get_additional_data(protocol_id, expire_timestamp);
251 self.write(&mut Cursor::new(&mut buffer[..]))?;
252
253 encrypt_in_place_xnonce(buffer, xnonce, private_key, &aad)?;
254
255 Ok(())
256 }
257
258 pub(crate) fn decode(
259 buffer: &[u8; NETCODE_CONNECT_TOKEN_PRIVATE_BYTES],
260 protocol_id: u64,
261 expire_timestamp: u64,
262 xnonce: &[u8; NETCODE_CONNECT_TOKEN_XNONCE_BYTES],
263 private_key: &[u8; NETCODE_KEY_BYTES],
264 ) -> Result<Self, TokenGenerationError> {
265 let aad = get_additional_data(protocol_id, expire_timestamp);
266
267 let mut temp_buffer = [0u8; NETCODE_CONNECT_TOKEN_PRIVATE_BYTES];
268 temp_buffer.copy_from_slice(buffer);
269
270 dencrypted_in_place_xnonce(&mut temp_buffer, xnonce, private_key, &aad)?;
271
272 let src = &mut io::Cursor::new(&temp_buffer[..]);
273 Ok(Self::read(src)?)
274 }
275}
276
277fn write_server_adresses(writer: &mut impl io::Write, server_addresses: &[Option<SocketAddr>; 32]) -> Result<(), io::Error> {
278 let num_server_addresses: u32 = server_addresses.iter().filter(|a| a.is_some()).count() as u32;
279 writer.write_all(&num_server_addresses.to_le_bytes())?;
280
281 for host in server_addresses.iter().flatten() {
282 match host {
283 SocketAddr::V4(addr) => {
284 writer.write_all(&NETCODE_ADDRESS_IPV4.to_le_bytes())?;
285 for i in addr.ip().octets() {
286 writer.write_all(&i.to_le_bytes())?;
287 }
288 }
289 SocketAddr::V6(addr) => {
290 writer.write_all(&NETCODE_ADDRESS_IPV6.to_le_bytes())?;
291 for i in addr.ip().octets() {
292 writer.write_all(&i.to_le_bytes())?;
293 }
294 }
295 }
296 writer.write_all(&host.port().to_le_bytes())?;
297 }
298
299 Ok(())
300}
301
302fn read_server_addresses(src: &mut impl io::Read) -> Result<[Option<SocketAddr>; 32], io::Error> {
303 let mut server_addresses = [None; 32];
304 let num_server_addresses = read_u32(src)? as usize;
305 for server_address in server_addresses.iter_mut().take(num_server_addresses) {
306 let host_type = read_u8(src)?;
307 match host_type {
308 NETCODE_ADDRESS_IPV4 => {
309 let mut ip = [0u8; 4];
310 src.read_exact(&mut ip)?;
311 let port = read_u16(src)?;
312 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::from(ip)), port);
313 *server_address = Some(addr);
314 }
315 NETCODE_ADDRESS_IPV6 => {
316 let mut ip = [0u8; 16];
317 src.read_exact(&mut ip)?;
318 let port = read_u16(src)?;
319 let addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::from(ip)), port);
320 *server_address = Some(addr);
321 }
322 NETCODE_ADDRESS_NONE => {} _ => return Err(io::Error::new(io::ErrorKind::InvalidData, "Unknown ip address type")),
324 }
325 }
326
327 if server_addresses.is_empty() {
328 return Err(io::Error::new(
329 io::ErrorKind::InvalidData,
330 "ConnectToken does not have a server address",
331 ));
332 }
333
334 Ok(server_addresses)
335}
336
337fn get_additional_data(protocol_id: u64, expire_timestamp: u64) -> [u8; NETCODE_ADDITIONAL_DATA_SIZE] {
338 let mut buffer = [0; NETCODE_ADDITIONAL_DATA_SIZE];
339 buffer[..13].copy_from_slice(NETCODE_VERSION_INFO);
340 buffer[13..21].copy_from_slice(&protocol_id.to_le_bytes());
341 buffer[21..29].copy_from_slice(&expire_timestamp.to_le_bytes());
342
343 buffer
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a plaintext private token through `write`/`read`.
    #[test]
    fn private_connect_token_serialization() {
        let hosts: Vec<SocketAddr> = vec!["127.0.0.1:8080".parse().unwrap(), "127.0.0.2:3000".parse().unwrap()];
        let token = PrivateConnectToken::generate(1, 5, hosts, Some(&generate_random_bytes())).unwrap();
        let mut buffer: Vec<u8> = vec![];

        token.write(&mut buffer).unwrap();
        let result = PrivateConnectToken::read(&mut buffer.as_slice()).unwrap();

        assert_eq!(token, result);
    }

    /// Exercises the IPv6 branch of the address serialization as well as IPv4.
    #[test]
    fn private_connect_token_serialization_ipv6() {
        let hosts: Vec<SocketAddr> = vec!["[::1]:8080".parse().unwrap(), "127.0.0.1:3000".parse().unwrap()];
        let token = PrivateConnectToken::generate(1, 5, hosts, None).unwrap();
        let mut buffer: Vec<u8> = vec![];

        token.write(&mut buffer).unwrap();
        let result = PrivateConnectToken::read(&mut buffer.as_slice()).unwrap();

        assert_eq!(token, result);
    }

    /// `generate` accepts between 1 and 32 addresses and rejects anything else.
    #[test]
    fn private_connect_token_address_count_validation() {
        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();

        let no_hosts = PrivateConnectToken::generate(1, 5, vec![], None);
        assert!(matches!(no_hosts, Err(TokenGenerationError::NoServerAddressAvailable)));

        let too_many = PrivateConnectToken::generate(1, 5, vec![addr; 33], None);
        assert!(matches!(too_many, Err(TokenGenerationError::MaxHostCount)));

        let max = PrivateConnectToken::generate(1, 5, vec![addr; 32], None).unwrap();
        assert!(max.server_addresses.iter().all(Option::is_some));
    }

    /// Encrypt/decrypt round trip. Decoding with different additional data must fail.
    #[test]
    fn private_connect_token_encode_decode() {
        let hosts: Vec<SocketAddr> = vec!["127.0.0.1:8080".parse().unwrap(), "127.0.0.2:3000".parse().unwrap()];
        let token = PrivateConnectToken::generate(1, 5, hosts, Some(&generate_random_bytes())).unwrap();
        let key = b"an example very very secret key.";
        let protocol_id = 12;
        let expire_timestamp = 0;
        let mut buffer = [0u8; NETCODE_CONNECT_TOKEN_PRIVATE_BYTES];
        let xnonce = generate_random_bytes();
        token.encode(&mut buffer, protocol_id, expire_timestamp, &xnonce, key).unwrap();

        let result = PrivateConnectToken::decode(&buffer, protocol_id, expire_timestamp, &xnonce, key).unwrap();
        assert_eq!(token, result);

        // The protocol id is part of the additional data, so a mismatch must be rejected.
        assert!(PrivateConnectToken::decode(&buffer, protocol_id + 1, expire_timestamp, &xnonce, key).is_err());
    }

    /// Full public token round trip, then decryption of its private part.
    #[test]
    fn connect_token_serialization() {
        let server_addresses: Vec<SocketAddr> = vec!["127.0.0.1:8080".parse().unwrap(), "127.0.0.2:3000".parse().unwrap()];
        let user_data = generate_random_bytes();
        let private_key = b"an example very very secret key.";
        let protocol_id = 2;
        let expire_seconds = 3;
        let client_id = 4;
        let timeout_seconds = 5;
        let token = ConnectToken::generate(
            Duration::ZERO,
            protocol_id,
            expire_seconds,
            client_id,
            timeout_seconds,
            server_addresses,
            Some(&user_data),
            private_key,
        )
        .unwrap();

        let mut buffer: Vec<u8> = vec![];
        token.write(&mut buffer).unwrap();

        let result = ConnectToken::read(&mut buffer.as_slice()).unwrap();
        assert_eq!(token, result);

        let private = PrivateConnectToken::decode(
            &result.private_data,
            protocol_id,
            result.expire_timestamp,
            &result.xnonce,
            private_key,
        )
        .unwrap();
        assert_eq!(timeout_seconds, private.timeout_seconds);
        assert_eq!(client_id, private.client_id);
        assert_eq!(user_data, private.user_data);
        assert_eq!(token.server_addresses, private.server_addresses);
        assert_eq!(token.client_to_server_key, private.client_to_server_key);
        assert_eq!(token.server_to_client_key, private.server_to_client_key);
    }
}