// renetcode/token.rs

1use std::{
2    error::Error,
3    fmt,
4    io::{self, Cursor},
5    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
6    time::Duration,
7};
8
9use crate::{
10    crypto::{dencrypted_in_place_xnonce, encrypt_in_place_xnonce, generate_random_bytes},
11    serialize::*,
12    NetcodeError, NETCODE_ADDITIONAL_DATA_SIZE, NETCODE_ADDRESS_IPV4, NETCODE_ADDRESS_IPV6, NETCODE_ADDRESS_NONE,
13    NETCODE_CONNECT_TOKEN_PRIVATE_BYTES, NETCODE_CONNECT_TOKEN_XNONCE_BYTES, NETCODE_KEY_BYTES, NETCODE_USER_DATA_BYTES,
14    NETCODE_VERSION_INFO,
15};
16use chacha20poly1305::aead::Error as CryptoError;
17
/// A public connect token that the client receives to start connecting to the server.
/// How the client receives ConnectToken is up to you, could be from a matchmaking
/// system or from a call to a REST API as an example.
///
/// Serialize it with [`ConnectToken::write`] and parse it back with [`ConnectToken::read`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectToken {
    // NOTE: On the netcode standard the client id is not available in the public part of the
    // ConnectToken. But having it accessible here makes it easier to consume the token, and the
    // server still uses the client_id from the private part.
    /// Client identifier, duplicated from the encrypted private part for convenience.
    pub client_id: u64,
    /// Protocol version bytes; `ConnectToken::read` rejects a token whose value differs from
    /// `NETCODE_VERSION_INFO`.
    pub version_info: [u8; 13],
    /// Application protocol id. It is also part of the additional data used when encrypting
    /// `private_data`, so the same value is needed to decrypt it.
    pub protocol_id: u64,
    /// Creation time in whole seconds, taken from the `current_time` given to `generate`.
    pub create_timestamp: u64,
    /// Expiration time in seconds (`create_timestamp + expire_seconds` when generated). Like
    /// `protocol_id`, it is part of the additional data used to encrypt `private_data`.
    pub expire_timestamp: u64,
    /// Nonce used to encrypt `private_data`.
    pub xnonce: [u8; NETCODE_CONNECT_TOKEN_XNONCE_BYTES],
    /// Server addresses the client may connect to; at most 32, unused slots are `None`.
    pub server_addresses: [Option<SocketAddr>; 32],
    /// Key for client-to-server traffic (a copy of the key stored in `private_data`).
    pub client_to_server_key: [u8; NETCODE_KEY_BYTES],
    /// Key for server-to-client traffic (a copy of the key stored in `private_data`).
    pub server_to_client_key: [u8; NETCODE_KEY_BYTES],
    /// The encrypted private token; decrypting it requires the server's private key.
    pub private_data: [u8; NETCODE_CONNECT_TOKEN_PRIVATE_BYTES],
    /// Connection timeout in seconds; negative values disable the timeout (dev only).
    pub timeout_seconds: i32,
}
38
/// Private part of a connect token. It is serialized and encrypted into
/// `ConnectToken::private_data` by `PrivateConnectToken::encode`, and recovered on the server
/// with `PrivateConnectToken::decode`.
#[derive(Debug, PartialEq, Eq)]
pub(crate) struct PrivateConnectToken {
    /// Globally unique identifier for an authenticated client.
    pub client_id: u64,
    /// Timeout in seconds; negative values disable the timeout (dev only).
    pub timeout_seconds: i32,
    /// Server addresses the token is valid for; at most 32, unused slots are `None`.
    pub server_addresses: [Option<SocketAddr>; 32],
    /// Randomly generated key for client-to-server traffic.
    pub client_to_server_key: [u8; NETCODE_KEY_BYTES],
    /// Randomly generated key for server-to-client traffic.
    pub server_to_client_key: [u8; NETCODE_KEY_BYTES],
    /// User defined data specific to this protocol id (random bytes when none was supplied).
    pub user_data: [u8; NETCODE_USER_DATA_BYTES],
}
48
/// Errors that can occur while generating, encoding or decoding a connect token.
#[derive(Debug)]
pub enum TokenGenerationError {
    /// The maximum number of addresses in the token is 32.
    MaxHostCount,
    /// Encrypting or decrypting the private part of the token failed.
    CryptoError,
    /// Writing or reading the serialized token bytes failed.
    IoError(io::Error),
    /// No server address was provided; a token needs at least one.
    NoServerAddressAvailable,
}
57
58impl From<io::Error> for TokenGenerationError {
59    fn from(inner: io::Error) -> Self {
60        TokenGenerationError::IoError(inner)
61    }
62}
63
64impl From<CryptoError> for TokenGenerationError {
65    fn from(_: CryptoError) -> Self {
66        TokenGenerationError::CryptoError
67    }
68}
69
70impl Error for TokenGenerationError {}
71
72impl fmt::Display for TokenGenerationError {
73    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
74        use TokenGenerationError::*;
75
76        match *self {
77            MaxHostCount => write!(fmt, "connect token can only have 32 server adresses"),
78            CryptoError => write!(fmt, "error while encoding or decoding the connect token"),
79            IoError(ref io_err) => write!(fmt, "{}", io_err),
80            NoServerAddressAvailable => write!(fmt, "connect token must have at least one server address"),
81        }
82    }
83}
84
85impl ConnectToken {
86    /// Generate a token to be sent to an client. The user data is available to the server after an
87    /// successfull conection. The private key and the protocol id must be the same used in server.
88    #[allow(clippy::too_many_arguments)]
89    pub fn generate(
90        current_time: Duration,
91        protocol_id: u64,
92        expire_seconds: u64,
93        client_id: u64,
94        timeout_seconds: i32,
95        server_addresses: Vec<SocketAddr>,
96        user_data: Option<&[u8; NETCODE_USER_DATA_BYTES]>,
97        private_key: &[u8; NETCODE_KEY_BYTES],
98    ) -> Result<Self, TokenGenerationError> {
99        let expire_timestamp = current_time.as_secs() + expire_seconds;
100
101        let private_connect_token = PrivateConnectToken::generate(client_id, timeout_seconds, server_addresses, user_data)?;
102        let mut private_data = [0u8; NETCODE_CONNECT_TOKEN_PRIVATE_BYTES];
103        let xnonce = generate_random_bytes();
104        private_connect_token.encode(&mut private_data, protocol_id, expire_timestamp, &xnonce, private_key)?;
105
106        Ok(Self {
107            client_id,
108            version_info: *NETCODE_VERSION_INFO,
109            protocol_id,
110            private_data,
111            create_timestamp: current_time.as_secs(),
112            expire_timestamp,
113            xnonce,
114            server_addresses: private_connect_token.server_addresses,
115            client_to_server_key: private_connect_token.client_to_server_key,
116            server_to_client_key: private_connect_token.server_to_client_key,
117            timeout_seconds,
118        })
119    }
120
121    pub fn write(&self, writer: &mut impl io::Write) -> Result<(), io::Error> {
122        writer.write_all(&self.client_id.to_le_bytes())?;
123        writer.write_all(&self.version_info)?;
124        writer.write_all(&self.protocol_id.to_le_bytes())?;
125        writer.write_all(&self.create_timestamp.to_le_bytes())?;
126        writer.write_all(&self.expire_timestamp.to_le_bytes())?;
127        writer.write_all(&self.xnonce)?;
128        writer.write_all(&self.private_data)?;
129        writer.write_all(&self.timeout_seconds.to_le_bytes())?;
130        write_server_adresses(writer, &self.server_addresses)?;
131        writer.write_all(&self.client_to_server_key)?;
132        writer.write_all(&self.server_to_client_key)?;
133
134        Ok(())
135    }
136
137    pub fn read(src: &mut impl io::Read) -> Result<Self, NetcodeError> {
138        let client_id = read_u64(src)?;
139        let version_info: [u8; 13] = read_bytes(src)?;
140        if &version_info != NETCODE_VERSION_INFO {
141            return Err(NetcodeError::InvalidVersion);
142        }
143
144        let protocol_id = read_u64(src)?;
145        let create_timestamp = read_u64(src)?;
146        let expire_timestamp = read_u64(src)?;
147        let xnonce = read_bytes(src)?;
148
149        let private_data: [u8; NETCODE_CONNECT_TOKEN_PRIVATE_BYTES] = read_bytes(src)?;
150        let timeout_seconds = read_i32(src)?;
151        let server_addresses = read_server_addresses(src)?;
152        let client_to_server_key: [u8; NETCODE_KEY_BYTES] = read_bytes(src)?;
153        let server_to_client_key: [u8; NETCODE_KEY_BYTES] = read_bytes(src)?;
154
155        Ok(Self {
156            client_id,
157            version_info,
158            protocol_id,
159            create_timestamp,
160            expire_timestamp,
161            xnonce,
162            private_data,
163            server_addresses,
164            client_to_server_key,
165            server_to_client_key,
166            timeout_seconds,
167        })
168    }
169}
170
171impl PrivateConnectToken {
172    fn generate(
173        client_id: u64,
174        timeout_seconds: i32,
175        server_addresses: Vec<SocketAddr>,
176        user_data: Option<&[u8; NETCODE_USER_DATA_BYTES]>,
177    ) -> Result<Self, TokenGenerationError> {
178        if server_addresses.len() > 32 {
179            return Err(TokenGenerationError::MaxHostCount);
180        }
181        if server_addresses.is_empty() {
182            return Err(TokenGenerationError::NoServerAddressAvailable);
183        }
184
185        let mut server_addresses_arr = [None; 32];
186        for (i, addr) in server_addresses.into_iter().enumerate() {
187            server_addresses_arr[i] = Some(addr);
188        }
189
190        let client_to_server_key = generate_random_bytes();
191        let server_to_client_key = generate_random_bytes();
192
193        let user_data = match user_data {
194            Some(data) => *data,
195            None => generate_random_bytes(),
196        };
197
198        Ok(Self {
199            client_id,
200            timeout_seconds,
201            server_addresses: server_addresses_arr,
202            client_to_server_key,
203            server_to_client_key,
204            user_data,
205        })
206    }
207
208    fn write(&self, writer: &mut impl io::Write) -> Result<(), io::Error> {
209        writer.write_all(&self.client_id.to_le_bytes())?;
210        writer.write_all(&self.timeout_seconds.to_le_bytes())?;
211        write_server_adresses(writer, &self.server_addresses)?;
212        writer.write_all(&self.client_to_server_key)?;
213        writer.write_all(&self.server_to_client_key)?;
214        writer.write_all(&self.user_data)?;
215
216        Ok(())
217    }
218
219    fn read(src: &mut impl io::Read) -> Result<Self, io::Error> {
220        let client_id = read_u64(src)?;
221        let timeout_seconds = read_i32(src)?;
222        let server_addresses = read_server_addresses(src)?;
223        let mut client_to_server_key = [0u8; 32];
224        src.read_exact(&mut client_to_server_key)?;
225
226        let mut server_to_client_key = [0u8; 32];
227        src.read_exact(&mut server_to_client_key)?;
228
229        let mut user_data = [0u8; 256];
230        src.read_exact(&mut user_data)?;
231
232        Ok(Self {
233            client_id,
234            timeout_seconds,
235            server_addresses,
236            client_to_server_key,
237            server_to_client_key,
238            user_data,
239        })
240    }
241
242    pub(crate) fn encode(
243        &self,
244        buffer: &mut [u8; NETCODE_CONNECT_TOKEN_PRIVATE_BYTES],
245        protocol_id: u64,
246        expire_timestamp: u64,
247        xnonce: &[u8; NETCODE_CONNECT_TOKEN_XNONCE_BYTES],
248        private_key: &[u8; NETCODE_KEY_BYTES],
249    ) -> Result<(), TokenGenerationError> {
250        let aad = get_additional_data(protocol_id, expire_timestamp);
251        self.write(&mut Cursor::new(&mut buffer[..]))?;
252
253        encrypt_in_place_xnonce(buffer, xnonce, private_key, &aad)?;
254
255        Ok(())
256    }
257
258    pub(crate) fn decode(
259        buffer: &[u8; NETCODE_CONNECT_TOKEN_PRIVATE_BYTES],
260        protocol_id: u64,
261        expire_timestamp: u64,
262        xnonce: &[u8; NETCODE_CONNECT_TOKEN_XNONCE_BYTES],
263        private_key: &[u8; NETCODE_KEY_BYTES],
264    ) -> Result<Self, TokenGenerationError> {
265        let aad = get_additional_data(protocol_id, expire_timestamp);
266
267        let mut temp_buffer = [0u8; NETCODE_CONNECT_TOKEN_PRIVATE_BYTES];
268        temp_buffer.copy_from_slice(buffer);
269
270        dencrypted_in_place_xnonce(&mut temp_buffer, xnonce, private_key, &aad)?;
271
272        let src = &mut io::Cursor::new(&temp_buffer[..]);
273        Ok(Self::read(src)?)
274    }
275}
276
277fn write_server_adresses(writer: &mut impl io::Write, server_addresses: &[Option<SocketAddr>; 32]) -> Result<(), io::Error> {
278    let num_server_addresses: u32 = server_addresses.iter().filter(|a| a.is_some()).count() as u32;
279    writer.write_all(&num_server_addresses.to_le_bytes())?;
280
281    for host in server_addresses.iter().flatten() {
282        match host {
283            SocketAddr::V4(addr) => {
284                writer.write_all(&NETCODE_ADDRESS_IPV4.to_le_bytes())?;
285                for i in addr.ip().octets() {
286                    writer.write_all(&i.to_le_bytes())?;
287                }
288            }
289            SocketAddr::V6(addr) => {
290                writer.write_all(&NETCODE_ADDRESS_IPV6.to_le_bytes())?;
291                for i in addr.ip().octets() {
292                    writer.write_all(&i.to_le_bytes())?;
293                }
294            }
295        }
296        writer.write_all(&host.port().to_le_bytes())?;
297    }
298
299    Ok(())
300}
301
302fn read_server_addresses(src: &mut impl io::Read) -> Result<[Option<SocketAddr>; 32], io::Error> {
303    let mut server_addresses = [None; 32];
304    let num_server_addresses = read_u32(src)? as usize;
305    for server_address in server_addresses.iter_mut().take(num_server_addresses) {
306        let host_type = read_u8(src)?;
307        match host_type {
308            NETCODE_ADDRESS_IPV4 => {
309                let mut ip = [0u8; 4];
310                src.read_exact(&mut ip)?;
311                let port = read_u16(src)?;
312                let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::from(ip)), port);
313                *server_address = Some(addr);
314            }
315            NETCODE_ADDRESS_IPV6 => {
316                let mut ip = [0u8; 16];
317                src.read_exact(&mut ip)?;
318                let port = read_u16(src)?;
319                let addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::from(ip)), port);
320                *server_address = Some(addr);
321            }
322            NETCODE_ADDRESS_NONE => {} // skip
323            _ => return Err(io::Error::new(io::ErrorKind::InvalidData, "Unknown ip address type")),
324        }
325    }
326
327    if server_addresses.is_empty() {
328        return Err(io::Error::new(
329            io::ErrorKind::InvalidData,
330            "ConnectToken does not have a server address",
331        ));
332    }
333
334    Ok(server_addresses)
335}
336
337fn get_additional_data(protocol_id: u64, expire_timestamp: u64) -> [u8; NETCODE_ADDITIONAL_DATA_SIZE] {
338    let mut buffer = [0; NETCODE_ADDITIONAL_DATA_SIZE];
339    buffer[..13].copy_from_slice(NETCODE_VERSION_INFO);
340    buffer[13..21].copy_from_slice(&protocol_id.to_le_bytes());
341    buffer[21..29].copy_from_slice(&expire_timestamp.to_le_bytes());
342
343    buffer
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// 32-byte private key shared by the encryption tests.
    const PRIVATE_KEY: &[u8; NETCODE_KEY_BYTES] = b"an example very very secret key.";

    /// The two IPv4 server addresses used by every test.
    fn test_server_addresses() -> Vec<SocketAddr> {
        vec!["127.0.0.1:8080".parse().unwrap(), "127.0.0.2:3000".parse().unwrap()]
    }

    #[test]
    fn private_connect_token_serialization() {
        let expected = PrivateConnectToken::generate(1, 5, test_server_addresses(), Some(&generate_random_bytes())).unwrap();

        let mut encoded: Vec<u8> = Vec::new();
        expected.write(&mut encoded).unwrap();
        let decoded = PrivateConnectToken::read(&mut encoded.as_slice()).unwrap();

        assert_eq!(expected, decoded);
    }

    #[test]
    fn private_connect_token_encode_decode() {
        let expected = PrivateConnectToken::generate(1, 5, test_server_addresses(), Some(&generate_random_bytes())).unwrap();
        let protocol_id = 12;
        let expire_timestamp = 0;
        let xnonce = generate_random_bytes();

        let mut ciphertext = [0u8; NETCODE_CONNECT_TOKEN_PRIVATE_BYTES];
        expected
            .encode(&mut ciphertext, protocol_id, expire_timestamp, &xnonce, PRIVATE_KEY)
            .unwrap();
        let decoded = PrivateConnectToken::decode(&ciphertext, protocol_id, expire_timestamp, &xnonce, PRIVATE_KEY).unwrap();

        assert_eq!(expected, decoded);
    }

    #[test]
    fn connect_token_serialization() {
        let user_data = generate_random_bytes();
        let protocol_id = 2;
        let expire_seconds = 3;
        let client_id = 4;
        let timeout_seconds = 5;

        let token = ConnectToken::generate(
            Duration::ZERO,
            protocol_id,
            expire_seconds,
            client_id,
            timeout_seconds,
            test_server_addresses(),
            Some(&user_data),
            PRIVATE_KEY,
        )
        .unwrap();

        let mut encoded: Vec<u8> = Vec::new();
        token.write(&mut encoded).unwrap();
        let parsed = ConnectToken::read(&mut encoded.as_slice()).unwrap();
        assert_eq!(token, parsed);

        // The public copies must agree with what the encrypted private part carries.
        let private = PrivateConnectToken::decode(
            &parsed.private_data,
            protocol_id,
            parsed.expire_timestamp,
            &parsed.xnonce,
            PRIVATE_KEY,
        )
        .unwrap();
        assert_eq!(timeout_seconds, private.timeout_seconds);
        assert_eq!(client_id, private.client_id);
        assert_eq!(user_data, private.user_data);
        assert_eq!(token.server_addresses, private.server_addresses);
        assert_eq!(token.client_to_server_key, private.client_to_server_key);
        assert_eq!(token.server_to_client_key, private.server_to_client_key);
    }
}