1use byteorder::{LittleEndian, WriteBytesExt};
2use chacha20poly1305::{aead::OsRng, AeadCore, XChaCha20Poly1305, XNonce};
3use thiserror::Error;
4
5use crate::{
6 bytes::Bytes,
7 crypto::{self, Key},
8 error::Error,
9 free_list::{FreeList, FreeListIter},
10 CONNECTION_TIMEOUT_SEC, CONNECT_TOKEN_BYTES, NETCODE_VERSION, PRIVATE_KEY_BYTES,
11 USER_DATA_BYTES,
12};
13
14use std::{
15 io::{self, Write},
16 mem::size_of,
17 net::{Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs},
18};
19
/// Capacity of an [`AddressList`]: the largest number of server addresses it stores
/// and the upper bound accepted when decoding one.
const MAX_SERVERS_PER_CONNECT: usize = 32;
/// Token lifetime, in seconds, that [`ConnectTokenBuilder`] starts with unless
/// `expire_seconds` overrides it.
const TOKEN_EXPIRE_SEC: i32 = 30;
22
23#[derive(Error, Debug)]
25pub enum InvalidTokenError {
26 #[error("address list length is out of range 1-32: {0}")]
27 AddressListLength(u32),
28 #[error("invalid ip address type (must be 1 for ipv4 or 2 for ipv6): {0}")]
29 InvalidIpAddressType(u8),
30 #[error("create timestamp is greater than expire timestamp")]
31 InvalidTimestamp,
32 #[error("invalid version")]
33 InvalidVersion,
34 #[error("io error: {0}")]
35 Io(#[from] io::Error),
36}
37
38#[derive(Debug, Clone, Copy)]
39pub struct AddressList {
40 addrs: FreeList<SocketAddr, MAX_SERVERS_PER_CONNECT>,
41}
42
43impl AddressList {
44 const IPV4: u8 = 1;
45 const IPV6: u8 = 2;
46 pub fn new(addrs: impl ToSocketAddrs) -> Result<Self, Error> {
47 let mut server_addresses = FreeList::new();
48
49 for (i, addr) in addrs.to_socket_addrs()?.enumerate() {
50 if i >= MAX_SERVERS_PER_CONNECT {
51 break;
52 }
53
54 server_addresses.insert(addr);
55 }
56
57 Ok(AddressList {
58 addrs: server_addresses,
59 })
60 }
61 pub fn len(&self) -> usize {
62 self.addrs.len()
63 }
64 pub fn iter(&self) -> FreeListIter<SocketAddr, MAX_SERVERS_PER_CONNECT> {
65 FreeListIter {
66 free_list: &self.addrs,
67 index: 0,
68 }
69 }
70}
71
72impl std::ops::Index<usize> for AddressList {
73 type Output = SocketAddr;
74
75 fn index(&self, index: usize) -> &Self::Output {
76 self.addrs.get(index).expect("index out of bounds")
77 }
78}
79
80impl Bytes for AddressList {
81 const SIZE: usize = size_of::<u32>() + MAX_SERVERS_PER_CONNECT * (1 + size_of::<u16>() + 16);
82 type Error = InvalidTokenError;
83 fn write_to(&self, buf: &mut impl io::Write) -> Result<(), InvalidTokenError> {
84 buf.write_u32::<LittleEndian>(self.len() as u32)?;
85 for (_, addr) in self.iter() {
86 match addr {
87 SocketAddr::V4(addr_v4) => {
88 buf.write_u8(Self::IPV4)?;
89 buf.write_all(&addr_v4.ip().octets())?;
90 buf.write_u16::<LittleEndian>(addr_v4.port())?;
91 }
92 SocketAddr::V6(addr_v6) => {
93 buf.write_u8(Self::IPV6)?;
94 buf.write_all(&addr_v6.ip().octets())?;
95 buf.write_u16::<LittleEndian>(addr_v6.port())?;
96 }
97 }
98 }
99 Ok(())
100 }
101
102 fn read_from(reader: &mut impl byteorder::ReadBytesExt) -> Result<Self, InvalidTokenError> {
103 let len = reader.read_u32::<LittleEndian>()?;
104
105 if !(1..=MAX_SERVERS_PER_CONNECT as u32).contains(&len) {
106 return Err(InvalidTokenError::AddressListLength(len));
107 }
108
109 let mut addrs = FreeList::new();
110
111 for _ in 0..len {
112 let addr_type = reader.read_u8()?;
113 let addr = match addr_type {
114 Self::IPV4 => {
115 let mut octets = [0; 4];
116 reader.read_exact(&mut octets)?;
117 let port = reader.read_u16::<LittleEndian>()?;
118 SocketAddr::from((Ipv4Addr::from(octets), port))
119 }
120 Self::IPV6 => {
121 let mut octets = [0; 16];
122 reader.read_exact(&mut octets)?;
123 let port = reader.read_u16::<LittleEndian>()?;
124 SocketAddr::from((Ipv6Addr::from(octets), port))
125 }
126 t => return Err(InvalidTokenError::InvalidIpAddressType(t)),
127 };
128 addrs.insert(addr);
129 }
130
131 Ok(Self { addrs })
132 }
133}
134
135pub struct ConnectTokenPrivate {
136 pub client_id: u64,
137 pub timeout_seconds: i32,
138 pub server_addresses: AddressList,
139 pub client_to_server_key: Key,
140 pub server_to_client_key: Key,
141 pub user_data: [u8; USER_DATA_BYTES],
142}
143
144impl ConnectTokenPrivate {
145 fn aead(
146 protocol_id: u64,
147 expire_timestamp: u64,
148 ) -> Result<[u8; NETCODE_VERSION.len() + std::mem::size_of::<u64>() * 2], Error> {
149 let mut aead = [0; NETCODE_VERSION.len() + std::mem::size_of::<u64>() * 2];
150 let mut cursor = io::Cursor::new(&mut aead[..]);
151 cursor.write_all(NETCODE_VERSION)?;
152 cursor.write_u64::<LittleEndian>(protocol_id)?;
153 cursor.write_u64::<LittleEndian>(expire_timestamp)?;
154 Ok(aead)
155 }
156
157 pub fn encrypt(
158 &self,
159 protocol_id: u64,
160 expire_timestamp: u64,
161 nonce: XNonce,
162 private_key: &Key,
163 ) -> Result<[u8; Self::SIZE], Error> {
164 let aead = Self::aead(protocol_id, expire_timestamp)?;
165 let mut buf = [0u8; Self::SIZE]; let mut cursor = io::Cursor::new(&mut buf[..]);
167 self.write_to(&mut cursor)?;
168 crypto::xchacha_encrypt(&mut buf, Some(&aead), nonce, private_key)?;
169 Ok(buf)
170 }
171
172 pub fn decrypt(
173 encrypted: &mut [u8],
174 protocol_id: u64,
175 expire_timestamp: u64,
176 nonce: XNonce,
177 private_key: &Key,
178 ) -> Result<Self, Error> {
179 let aead = Self::aead(protocol_id, expire_timestamp)?;
180 crypto::xchacha_decrypt(encrypted, Some(&aead), nonce, private_key)?;
181 let mut cursor = io::Cursor::new(encrypted);
182 Ok(Self::read_from(&mut cursor)?)
183 }
184}
185
186impl Bytes for ConnectTokenPrivate {
187 const SIZE: usize = 1024; type Error = io::Error;
189 fn write_to(&self, buf: &mut impl io::Write) -> Result<(), io::Error> {
190 buf.write_u64::<LittleEndian>(self.client_id)?;
191 buf.write_i32::<LittleEndian>(self.timeout_seconds)?;
192 self.server_addresses
193 .write_to(buf)
194 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
195 buf.write_all(&self.client_to_server_key)?;
196 buf.write_all(&self.server_to_client_key)?;
197 buf.write_all(&self.user_data)?;
198 Ok(())
199 }
200
201 fn read_from(reader: &mut impl byteorder::ReadBytesExt) -> Result<Self, io::Error> {
202 let client_id = reader.read_u64::<LittleEndian>()?;
203 let timeout_seconds = reader.read_i32::<LittleEndian>()?;
204 let server_addresses =
205 AddressList::read_from(reader).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
206
207 let mut client_to_server_key = [0; PRIVATE_KEY_BYTES];
208 reader.read_exact(&mut client_to_server_key)?;
209
210 let mut server_to_client_key = [0; PRIVATE_KEY_BYTES];
211 reader.read_exact(&mut server_to_client_key)?;
212
213 let mut user_data = [0; USER_DATA_BYTES];
214 reader.read_exact(&mut user_data)?;
215
216 Ok(Self {
217 client_id,
218 timeout_seconds,
219 server_addresses,
220 client_to_server_key,
221 server_to_client_key,
222 user_data,
223 })
224 }
225}
226
227pub struct ChallengeToken {
228 pub client_id: u64,
229 pub user_data: [u8; USER_DATA_BYTES],
230}
231
232impl ChallengeToken {
233 pub const SIZE: usize = 300;
234 pub fn encrypt(&self, sequence: u64, private_key: &Key) -> Result<[u8; Self::SIZE], Error> {
235 let mut buf = [0u8; Self::SIZE]; let mut cursor = io::Cursor::new(&mut buf[..]);
237 self.write_to(&mut cursor)?;
238 crypto::chacha_encrypt(&mut buf, None, sequence, private_key)?;
239 Ok(buf)
240 }
241
242 pub fn decrypt(
243 encrypted: &mut [u8; Self::SIZE],
244 sequence: u64,
245 private_key: &Key,
246 ) -> Result<Self, Error> {
247 crypto::chacha_decrypt(encrypted, None, sequence, private_key)?;
248 let mut cursor = io::Cursor::new(&encrypted[..]);
249 Ok(Self::read_from(&mut cursor)?)
250 }
251}
252
253impl Bytes for ChallengeToken {
254 const SIZE: usize = size_of::<u64>() + USER_DATA_BYTES;
255 type Error = io::Error;
256 fn write_to(&self, buf: &mut impl io::Write) -> Result<(), io::Error> {
257 buf.write_u64::<LittleEndian>(self.client_id)?;
258 buf.write_all(&self.user_data)?;
259 Ok(())
260 }
261
262 fn read_from(reader: &mut impl byteorder::ReadBytesExt) -> Result<Self, io::Error> {
263 let client_id = reader.read_u64::<LittleEndian>()?;
264 let mut user_data = [0; USER_DATA_BYTES];
265 reader.read_exact(&mut user_data)?;
266 Ok(Self {
267 client_id,
268 user_data,
269 })
270 }
271}
272
273pub struct ConnectToken {
307 pub(crate) version_info: [u8; NETCODE_VERSION.len()],
308 pub(crate) protocol_id: u64,
309 pub(crate) create_timestamp: u64,
310 pub(crate) expire_timestamp: u64,
311 pub(crate) nonce: XNonce,
312 pub(crate) private_data: [u8; ConnectTokenPrivate::SIZE],
313 pub(crate) timeout_seconds: i32,
314 pub(crate) server_addresses: AddressList,
315 pub(crate) client_to_server_key: Key,
316 pub(crate) server_to_client_key: Key,
317}
318
319pub struct ConnectTokenBuilder<A: ToSocketAddrs> {
321 protocol_id: u64,
322 client_id: u64,
323 expire_seconds: i32,
324 private_key: Key,
325 timeout_seconds: i32,
326 public_server_addresses: A,
327 internal_server_addresses: Option<AddressList>,
328 user_data: [u8; USER_DATA_BYTES],
329}
330
331impl<A: ToSocketAddrs> ConnectTokenBuilder<A> {
332 fn new(server_addresses: A, protocol_id: u64, client_id: u64, private_key: Key) -> Self {
333 Self {
334 protocol_id,
335 client_id,
336 expire_seconds: TOKEN_EXPIRE_SEC,
337 private_key,
338 timeout_seconds: CONNECTION_TIMEOUT_SEC,
339 public_server_addresses: server_addresses,
340 internal_server_addresses: None,
341 user_data: [0; USER_DATA_BYTES],
342 }
343 }
344 pub fn expire_seconds(mut self, expire_seconds: i32) -> Self {
348 self.expire_seconds = expire_seconds;
349 self
350 }
351 pub fn timeout_seconds(mut self, timeout_seconds: i32) -> Self {
355 self.timeout_seconds = timeout_seconds;
356 self
357 }
358 pub fn user_data(mut self, user_data: [u8; USER_DATA_BYTES]) -> Self {
360 self.user_data = user_data;
361 self
362 }
363 pub fn internal_addresses(mut self, internal_addresses: A) -> Result<Self, Error> {
372 self.internal_server_addresses = Some(AddressList::new(internal_addresses)?);
373 Ok(self)
374 }
375 pub fn generate(self) -> Result<ConnectToken, Error> {
377 let now = std::time::SystemTime::now()
378 .duration_since(std::time::UNIX_EPOCH)?
379 .as_secs();
380 let expire_timestamp = if self.expire_seconds < 0 {
381 u64::MAX
382 } else {
383 now + self.expire_seconds as u64
384 };
385 let public_server_addresses = AddressList::new(self.public_server_addresses)?;
386 let internal_server_addresses = match self.internal_server_addresses {
387 Some(addresses) => addresses,
388 None => public_server_addresses,
389 };
390 let client_to_server_key = crypto::generate_key();
391 let server_to_client_key = crypto::generate_key();
392 let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
393
394 let private_data = ConnectTokenPrivate {
395 client_id: self.client_id,
396 timeout_seconds: self.timeout_seconds,
397 server_addresses: internal_server_addresses,
398 client_to_server_key,
399 server_to_client_key,
400 user_data: self.user_data,
401 }
402 .encrypt(self.protocol_id, expire_timestamp, nonce, &self.private_key)?;
403
404 Ok(ConnectToken {
405 version_info: *NETCODE_VERSION,
406 protocol_id: self.protocol_id,
407 create_timestamp: now,
408 expire_timestamp,
409 nonce,
410 private_data,
411 timeout_seconds: self.timeout_seconds,
412 server_addresses: public_server_addresses,
413 client_to_server_key,
414 server_to_client_key,
415 })
416 }
417}
418
419impl ConnectToken {
420 pub fn build<A: ToSocketAddrs>(
422 server_addresses: A,
423 protocol_id: u64,
424 client_id: u64,
425 private_key: Key,
426 ) -> ConnectTokenBuilder<A> {
427 ConnectTokenBuilder::new(server_addresses, protocol_id, client_id, private_key)
428 }
429
430 pub fn try_into_bytes(self) -> Result<[u8; CONNECT_TOKEN_BYTES], io::Error> {
432 let mut buf = [0u8; CONNECT_TOKEN_BYTES];
433 let mut cursor = io::Cursor::new(&mut buf[..]);
434 self.write_to(&mut cursor).map_err(|e| {
435 io::Error::new(
436 io::ErrorKind::Other,
437 format!("failed to write token to buffer: {}", e),
438 )
439 })?;
440 Ok(buf)
441 }
442}
443
444impl Bytes for ConnectToken {
445 const SIZE: usize = 2048; type Error = InvalidTokenError;
447 fn write_to(&self, buf: &mut impl io::Write) -> Result<(), Self::Error> {
448 buf.write_all(&self.version_info)?;
449 buf.write_u64::<LittleEndian>(self.protocol_id)?;
450 buf.write_u64::<LittleEndian>(self.create_timestamp)?;
451 buf.write_u64::<LittleEndian>(self.expire_timestamp)?;
452 buf.write_all(&self.nonce)?;
453 buf.write_all(&self.private_data)?;
454 buf.write_i32::<LittleEndian>(self.timeout_seconds)?;
455 self.server_addresses.write_to(buf)?;
456 buf.write_all(&self.client_to_server_key)?;
457 buf.write_all(&self.server_to_client_key)?;
458 Ok(())
459 }
460
461 fn read_from(reader: &mut impl byteorder::ReadBytesExt) -> Result<Self, Self::Error> {
462 let mut version_info = [0; NETCODE_VERSION.len()];
463 reader.read_exact(&mut version_info)?;
464
465 if version_info != *NETCODE_VERSION {
466 return Err(InvalidTokenError::InvalidVersion);
467 }
468
469 let protocol_id = reader.read_u64::<LittleEndian>()?;
470
471 let create_timestamp = reader.read_u64::<LittleEndian>()?;
472 let expire_timestamp = reader.read_u64::<LittleEndian>()?;
473
474 if create_timestamp > expire_timestamp {
475 return Err(InvalidTokenError::InvalidTimestamp);
476 }
477
478 let mut nonce = [0; size_of::<XNonce>()];
479 reader.read_exact(&mut nonce)?;
480 let nonce = XNonce::from_slice(&nonce).to_owned();
481
482 let mut private_data = [0; ConnectTokenPrivate::SIZE];
483 reader.read_exact(&mut private_data)?;
484
485 let timeout_seconds = reader.read_i32::<LittleEndian>()?;
486
487 let server_addresses = AddressList::read_from(reader)?;
488
489 let mut client_to_server_key = [0; PRIVATE_KEY_BYTES];
490 reader.read_exact(&mut client_to_server_key)?;
491
492 let mut server_to_client_key = [0; PRIVATE_KEY_BYTES];
493 reader.read_exact(&mut server_to_client_key)?;
494
495 Ok(Self {
496 version_info,
497 protocol_id,
498 create_timestamp,
499 expire_timestamp,
500 nonce,
501 private_data,
502 timeout_seconds,
503 server_addresses,
504 client_to_server_key,
505 server_to_client_key,
506 })
507 }
508}
#[cfg(test)]
mod tests {
    use super::*;

    /// A private token must survive an encrypt/decrypt round trip unchanged.
    #[test]
    fn encrypt_decrypt_private_token() {
        let private_key = crypto::generate_key();
        let protocol_id = 1;
        let expire_timestamp = 2;
        let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
        let client_id = 4;
        let timeout_seconds = 5;
        let server_addresses = AddressList::new(
            &[
                SocketAddr::from(([127, 0, 0, 1], 1)),
                SocketAddr::from(([127, 0, 0, 1], 2)),
                SocketAddr::from(([127, 0, 0, 1], 3)),
                SocketAddr::from(([127, 0, 0, 1], 4)),
            ][..],
        )
        .unwrap();
        let user_data = [0x11; USER_DATA_BYTES];
        // Keep the keys in locals so the decrypted values can be compared against the
        // originals rather than against themselves.
        let client_to_server_key = crypto::generate_key();
        let server_to_client_key = crypto::generate_key();

        let private_token = ConnectTokenPrivate {
            client_id,
            timeout_seconds,
            server_addresses,
            user_data,
            client_to_server_key,
            server_to_client_key,
        };

        let mut encrypted = private_token
            .encrypt(protocol_id, expire_timestamp, nonce, &private_key)
            .unwrap();

        let private_token = ConnectTokenPrivate::decrypt(
            &mut encrypted,
            protocol_id,
            expire_timestamp,
            nonce,
            &private_key,
        )
        .unwrap();

        assert_eq!(private_token.client_id, client_id);
        assert_eq!(private_token.timeout_seconds, timeout_seconds);
        // `zip` stops at the shorter side, so check the lengths explicitly.
        assert_eq!(private_token.server_addresses.len(), server_addresses.len());
        private_token
            .server_addresses
            .iter()
            .zip(server_addresses.iter())
            .for_each(|(have, expected)| {
                assert_eq!(have, expected);
            });
        assert_eq!(private_token.user_data, user_data);
        assert_eq!(private_token.server_to_client_key, server_to_client_key);
        assert_eq!(private_token.client_to_server_key, client_to_server_key);
    }

    /// A challenge token must survive an encrypt/decrypt round trip unchanged.
    #[test]
    fn encrypt_decrypt_challenge_token() {
        let private_key = crypto::generate_key();
        let sequence = 1;
        let client_id = 2;
        let user_data = [0x11; USER_DATA_BYTES];

        let challenge_token = ChallengeToken {
            client_id,
            user_data,
        };

        let mut encrypted = challenge_token.encrypt(sequence, &private_key).unwrap();

        let challenge_token =
            ChallengeToken::decrypt(&mut encrypted, sequence, &private_key).unwrap();

        assert_eq!(challenge_token.client_id, client_id);
        assert_eq!(challenge_token.user_data, user_data);
    }

    /// A connect token must survive a write/read round trip unchanged.
    #[test]
    fn connect_token_read_write() {
        let private_key = crypto::generate_key();
        let protocol_id = 1;
        let expire_timestamp = 2;
        let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
        let client_id = 4;
        let timeout_seconds = 5;
        let server_addresses = AddressList::new(
            &[
                SocketAddr::from(([127, 0, 0, 1], 1)),
                SocketAddr::from(([127, 0, 0, 1], 2)),
                SocketAddr::from(([127, 0, 0, 1], 3)),
                SocketAddr::from(([127, 0, 0, 1], 4)),
            ][..],
        )
        .unwrap();
        let user_data = [0x11; USER_DATA_BYTES];

        let private_token = ConnectTokenPrivate {
            client_id,
            timeout_seconds,
            server_addresses,
            user_data,
            client_to_server_key: crypto::generate_key(),
            server_to_client_key: crypto::generate_key(),
        };

        let mut encrypted = private_token
            .encrypt(protocol_id, expire_timestamp, nonce, &private_key)
            .unwrap();

        let private_token = ConnectTokenPrivate::decrypt(
            &mut encrypted,
            protocol_id,
            expire_timestamp,
            nonce,
            &private_key,
        )
        .unwrap();

        let mut private_data = [0; ConnectTokenPrivate::SIZE];
        let mut cursor = io::Cursor::new(&mut private_data[..]);
        private_token.write_to(&mut cursor).unwrap();

        let connect_token = ConnectToken {
            version_info: *NETCODE_VERSION,
            protocol_id,
            create_timestamp: 0,
            expire_timestamp,
            nonce,
            private_data,
            timeout_seconds,
            server_addresses,
            client_to_server_key: private_token.client_to_server_key,
            server_to_client_key: private_token.server_to_client_key,
        };

        let mut buf = Vec::new();
        connect_token.write_to(&mut buf).unwrap();

        let connect_token = ConnectToken::read_from(&mut buf.as_slice()).unwrap();

        assert_eq!(connect_token.version_info, *NETCODE_VERSION);
        assert_eq!(connect_token.protocol_id, protocol_id);
        assert_eq!(connect_token.create_timestamp, 0);
        assert_eq!(connect_token.expire_timestamp, expire_timestamp);
        assert_eq!(connect_token.nonce, nonce);
        assert_eq!(connect_token.private_data, private_data);
        assert_eq!(connect_token.timeout_seconds, timeout_seconds);
        // `zip` stops at the shorter side, so check the lengths explicitly.
        assert_eq!(connect_token.server_addresses.len(), server_addresses.len());
        connect_token
            .server_addresses
            .iter()
            .zip(server_addresses.iter())
            .for_each(|(have, expected)| {
                assert_eq!(have, expected);
            });
        // The keys are the last fields in the layout; make sure they round-trip too.
        assert_eq!(
            connect_token.client_to_server_key,
            private_token.client_to_server_key
        );
        assert_eq!(
            connect_token.server_to_client_key,
            private_token.server_to_client_key
        );
    }

    /// The builder must carry its settings into the generated token.
    #[test]
    fn connect_token_builder() {
        let protocol_id = 1;
        let client_id = 4;
        let server_addresses = "127.0.0.1:12345";

        let connect_token = ConnectToken::build(
            server_addresses,
            protocol_id,
            client_id,
            [0x42; PRIVATE_KEY_BYTES],
        )
        .user_data([0x11; USER_DATA_BYTES])
        .timeout_seconds(5)
        .expire_seconds(6)
        .internal_addresses("0.0.0.0:0")
        .expect("failed to parse address")
        .generate()
        .unwrap();

        assert_eq!(connect_token.version_info, *NETCODE_VERSION);
        assert_eq!(connect_token.protocol_id, protocol_id);
        assert_eq!(connect_token.timeout_seconds, 5);
        assert_eq!(
            connect_token.expire_timestamp,
            connect_token.create_timestamp + 6
        );
        // A single literal address resolves to exactly one entry; without this check
        // the zipped comparison below would pass on an empty list.
        assert_eq!(connect_token.server_addresses.len(), 1);
        connect_token
            .server_addresses
            .iter()
            .zip(server_addresses.to_socket_addrs().into_iter().flatten())
            .for_each(|((_, have), expected)| {
                assert_eq!(have, expected);
            });
    }
}