use std::{
collections::BTreeMap,
fmt, io,
net::SocketAddr,
sync::{Arc, RwLock},
time::{Duration, Instant},
};
#[cfg(any(feature = "auth_tcp", feature = "auth_tls"))]
use chacha20poly1305::{
aead::AeadCore, aead::KeyInit, ChaCha20Poly1305, ChaChaPoly1305, Key, Nonce,
};
use x25519_dalek::EphemeralSecret;
use crate::{
internal::{
messages::{DeserializedMessage, MessagePartMap},
rt::{Mutex, UdpSocket},
utils::{DurationMonitor, RttCalculator},
MessageChannel,
},
packets::{ClientTickEndPacket, PacketRegistry},
};
#[cfg(any(feature = "auth_tcp", feature = "auth_tls"))]
use crate::{MessagingProperties, ReadHandlerProperties, SentMessagePart, MESSAGE_CHANNEL_SIZE};
use super::*;
#[cfg(feature = "auth_tcp")]
use crate::auth_tcp::AuthTcpClientProperties;
#[cfg(feature = "auth_tls")]
use crate::auth_tls::{AuthTlsClientProperties, TlsConnector};
#[cfg(feature = "auth_tls")]
use crate::auth_tls::rustls;
#[cfg(any(feature = "auth_tcp", feature = "auth_tls"))]
use crate::internal::rt::{AsyncReadExt, AsyncWriteExt, TcpStream};
/// Settings carried by the `NoCryptography`, `RequireTcp` and `RequireTls`
/// [`AuthenticatorMode`] variants.
pub struct AuthenticationProperties {
    /// Authentication message handed back by a successful public-key exchange
    /// (presumably the `message` later passed to `connect_public_key_send_match_arm`
    /// and sent to the server — confirm at the call site).
    pub message: LimitedMessage,
    /// Time budget for the public-key exchange. For `NoCryptography` it bounds the UDP
    /// send/receive retry loop; for the TCP/TLS modes it wraps connecting to the
    /// authenticator and exchanging keys over the stream.
    pub timeout: Duration,
}
/// How the client performs the public-key exchange that precedes authentication.
pub enum AuthenticatorMode {
    /// Exchange public keys directly over the UDP socket; the authentication bytes are
    /// built without encryption.
    NoCryptography(AuthenticationProperties),
    /// Exchange public keys over a plain TCP connection to the authenticator; the
    /// authentication bytes are encrypted with a cipher derived from an X25519 exchange.
    #[cfg(feature = "auth_tcp")]
    RequireTcp(AuthenticationProperties, AuthTcpClientProperties),
    /// Like `RequireTcp`, but the stream to the authenticator is wrapped in TLS.
    #[cfg(feature = "auth_tls")]
    RequireTls(AuthenticationProperties, AuthTlsClientProperties),
    /// Try each mode in order. The first success wins; if all fail, every error is
    /// reported (in attempt order) through [`ConnectError::AllAttemptsFailed`].
    AttemptList(Vec<AuthenticatorMode>),
}
/// The [`AuthenticatorMode`] that actually succeeded, stripped of its configuration.
///
/// It decides how the authentication bytes are built (plain or encrypted) and is
/// stored on the client node once connected. All variants are unit variants, so the
/// type is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectedAuthenticatorMode {
    /// Keys were exchanged over UDP; authentication bytes are sent unencrypted.
    NoCryptography,
    /// Keys were exchanged over TCP; authentication bytes are encrypted and the derived
    /// cipher is kept in the server's `InnerAuth`.
    #[cfg(feature = "auth_tcp")]
    RequireTcp,
    /// Keys were exchanged over TLS; authentication bytes are encrypted and the derived
    /// cipher is kept in the server's `InnerAuth`.
    #[cfg(feature = "auth_tls")]
    RequireTls,
}
/// Successful outcome of a connection attempt.
pub struct ConnectResult {
    /// The connected client, with its background handlers already spawned.
    pub client: Client,
    /// The first message received from the server during the connect ticks.
    pub initial_message: DeserializedMessage,
}
/// Reasons a connection attempt can fail.
#[derive(Debug)]
pub enum ConnectError {
    /// The packet registry lacks packets the connection needs (per the `Display` text;
    /// not produced in this module).
    MissingEssentialPackets,
    /// A handshake step exceeded its time budget.
    Timeout,
    /// The server's reply was empty or shorter than the message-channel header.
    InvalidProtocolCommunication,
    /// The TLS server name could not be parsed as a valid DNS name.
    InvalidDnsName,
    /// I/O error setting up the UDP socket (produced outside this module; the `Display`
    /// text describes it as a bind failure).
    SocketConnectError(io::Error),
    /// Connecting the TCP stream, or performing the TLS handshake, to the authenticator failed.
    AuthenticatorConnectIoError(io::Error),
    /// Sending the client public key failed (over UDP or the authenticator stream).
    AuthenticatorWriteIoError(io::Error),
    /// Receiving the server's public-key reply failed.
    AuthenticatorReadIoError(io::Error),
    /// Sending the authentication bytes over the UDP socket failed.
    AuthenticationBytesSendIoError(io::Error),
    /// The client was disconnected while authenticating.
    Disconnected(DisconnectedConnectError),
    /// Every mode of an `AuthenticatorMode::AttemptList` failed; errors are in attempt order.
    AllAttemptsFailed(Vec<ConnectError>),
}
impl fmt::Display for ConnectError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
ConnectError::MissingEssentialPackets => write!(
f,
"Packet registry has not registered the essential packets."
),
ConnectError::Timeout => write!(f, "Server took a long time to respond."),
ConnectError::InvalidProtocolCommunication => {
write!(f, "Server did not communicate correctly.")
}
ConnectError::InvalidDnsName => write!(f, "Invalid dns name."),
ConnectError::SocketConnectError(e) => write!(f, "Failed to bind UDP socket: {}", e),
ConnectError::AuthenticatorConnectIoError(ref err) => {
write!(f, "Authenticator connect IO error: {}", err)
}
ConnectError::AuthenticatorWriteIoError(ref err) => {
write!(f, "Authenticator write IO error: {}", err)
}
ConnectError::AuthenticatorReadIoError(ref err) => {
write!(f, "Authenticator read IO error: {}", err)
}
ConnectError::AuthenticationBytesSendIoError(ref err) => {
write!(
f,
"IO error sending authentication bytes by socket: {}",
err
)
}
ConnectError::Disconnected(reason) => write!(f, "Client disconnected: {:?}", reason),
Self::AllAttemptsFailed(errors) => write!(f, "All attempts failed: {:?}", errors),
}
}
}
impl std::error::Error for ConnectError {}
/// Details attached to [`ConnectError::Disconnected`].
#[derive(Debug)]
pub struct DisconnectedConnectError {
    /// The disconnect reason taken from the client when a connect tick reported a disconnect.
    pub reason: ServerDisconnectReason,
    /// Unexpected errors collected from the client's store up to the disconnect.
    #[cfg(feature = "store_unexpected")]
    pub unexpected_errors: Vec<UnexpectedError>,
}
pub(super) mod connecting {
use crate::internal::{
messages::{PUBLIC_KEY_SIZE, UDP_BUFFER_SIZE},
node::{ActiveAwaitableHandler, ActiveCancelableHandler, ActiveDisposableHandler},
};
use super::*;
/// Runs the public-key exchange for one [`AuthenticatorMode`].
///
/// Dispatches to the handler for the given variant and returns its result unchanged:
/// the reply length left in `buf`, the authentication message to send next, and the
/// mode that succeeded. `AttemptList` recurses through
/// `connect_attempt_list_match_arm`, so that call is boxed to give the recursive
/// future a finite size.
pub async fn connect_auth_mode_match_arm(
    authenticator_mode: AuthenticatorMode,
    client_properties: &ClientProperties,
    socket: &UdpSocket,
    buf: &mut [u8; UDP_BUFFER_SIZE],
    public_key_sent: &Vec<u8>,
) -> Result<(usize, LimitedMessage, ConnectedAuthenticatorMode), ConnectError> {
    match authenticator_mode {
        AuthenticatorMode::NoCryptography(props) => {
            connect_no_cryptography_match_arm(socket, buf, public_key_sent, props).await
        }
        #[cfg(feature = "auth_tcp")]
        AuthenticatorMode::RequireTcp(props, auth_mode) => {
            connect_require_tcp_match_arm(buf, public_key_sent, auth_mode, props).await
        }
        #[cfg(feature = "auth_tls")]
        AuthenticatorMode::RequireTls(props, auth_mode) => {
            connect_require_tls_match_arm(buf, public_key_sent, auth_mode, props).await
        }
        AuthenticatorMode::AttemptList(modes) => {
            let nested_attempts = Box::pin(connect_attempt_list_match_arm(
                client_properties,
                socket,
                buf,
                public_key_sent,
                modes,
            ));
            nested_attempts.await
        }
    }
}
pub async fn connect_no_cryptography_match_arm(
socket: &UdpSocket,
buf: &mut [u8; UDP_BUFFER_SIZE],
public_key_sent: &Vec<u8>,
props: AuthenticationProperties,
) -> Result<(usize, LimitedMessage, ConnectedAuthenticatorMode), ConnectError> {
let sent_time = Instant::now();
loop {
let now = Instant::now();
if now - sent_time > props.timeout {
return Err(ConnectError::Timeout);
}
if let Err(e) = socket.send(&public_key_sent).await {
return Err(ConnectError::AuthenticatorWriteIoError(e));
}
match crate::internal::rt::timeout(props.timeout, socket.recv(buf)).await {
Ok(len) => match len {
Ok(len) => {
if len < MESSAGE_CHANNEL_SIZE {
return Err(ConnectError::InvalidProtocolCommunication);
}
match buf[0] {
MessageChannel::PUBLIC_KEY_SEND => {
break Ok((
len,
props.message,
ConnectedAuthenticatorMode::NoCryptography,
));
}
_ => (),
}
}
Err(e) => {
return Err(ConnectError::AuthenticatorReadIoError(e));
}
},
_ => (),
}
}
}
/// Sends `public_key_sent` over an already-connected byte `stream` and reads the
/// server's reply into `buf`; used by both the TCP and TLS authenticator modes.
///
/// Returns the number of bytes read.
///
/// NOTE(review): only a single `read` is issued, so a reply split across several
/// segments may arrive partially, and `buf[0]` is not checked against
/// `MessageChannel::PUBLIC_KEY_SEND` here — confirm the caller validates the reply.
///
/// # Errors
/// - [`ConnectError::AuthenticatorWriteIoError`] if writing the key fails.
/// - [`ConnectError::AuthenticatorReadIoError`] if reading fails.
/// - [`ConnectError::InvalidProtocolCommunication`] if the stream yields no bytes or
///   fewer than `MESSAGE_CHANNEL_SIZE` bytes.
#[cfg(any(feature = "auth_tcp", feature = "auth_tls"))]
pub async fn connect_require_tcp_based_match_arm<T>(
    buf: &mut [u8; UDP_BUFFER_SIZE],
    public_key_sent: &Vec<u8>,
    mut stream: T,
) -> Result<usize, ConnectError>
where
    T: AsyncReadExt + AsyncWriteExt + Unpin,
{
    stream
        .write_all(public_key_sent)
        .await
        .map_err(ConnectError::AuthenticatorWriteIoError)?;
    let read_len = stream
        .read(&mut buf[..])
        .await
        .map_err(ConnectError::AuthenticatorReadIoError)?;
    if read_len == 0 || read_len < MESSAGE_CHANNEL_SIZE {
        return Err(ConnectError::InvalidProtocolCommunication);
    }
    Ok(read_len)
}
/// Performs the public-key exchange over a plain TCP connection to
/// `auth_mode.server_addr`.
///
/// Connecting and the exchange together are bounded by `props.timeout`. On success
/// returns the reply length (the reply stays in `buf`), `props.message`, and
/// [`ConnectedAuthenticatorMode::RequireTcp`].
///
/// # Errors
/// - [`ConnectError::Timeout`] if the budget elapses.
/// - [`ConnectError::AuthenticatorConnectIoError`] if the TCP connect fails.
/// - Any error from `connect_require_tcp_based_match_arm`.
#[cfg(feature = "auth_tcp")]
pub async fn connect_require_tcp_match_arm(
    buf: &mut [u8; UDP_BUFFER_SIZE],
    public_key_sent: &Vec<u8>,
    auth_mode: AuthTcpClientProperties,
    props: AuthenticationProperties,
) -> Result<(usize, LimitedMessage, ConnectedAuthenticatorMode), ConnectError> {
    let exchange = async {
        let tcp_stream = TcpStream::connect(auth_mode.server_addr)
            .await
            .map_err(ConnectError::AuthenticatorConnectIoError)?;
        connect_require_tcp_based_match_arm(buf, public_key_sent, tcp_stream).await
    };
    // Outer `?`: the deadline elapsed. Inner `?`: the exchange itself failed.
    let len = crate::internal::rt::timeout(props.timeout, exchange)
        .await
        .map_err(|_| ConnectError::Timeout)??;
    Ok((len, props.message, ConnectedAuthenticatorMode::RequireTcp))
}
/// Performs the public-key exchange over a TLS connection to `auth_mode.server_addr`,
/// verified against `auth_mode.server_name`.
///
/// Name parsing, TCP connect, the TLS handshake and the key exchange are all bounded by
/// `props.timeout`. On success returns the reply length (the reply stays in `buf`),
/// `props.message`, and [`ConnectedAuthenticatorMode::RequireTls`].
///
/// # Errors
/// - [`ConnectError::Timeout`] if the budget elapses.
/// - [`ConnectError::InvalidDnsName`] if the server name cannot be parsed.
/// - [`ConnectError::AuthenticatorConnectIoError`] if the TCP connect or TLS handshake fails.
/// - Any error from `connect_require_tcp_based_match_arm`.
#[cfg(feature = "auth_tls")]
pub async fn connect_require_tls_match_arm(
    buf: &mut [u8; UDP_BUFFER_SIZE],
    public_key_sent: &Vec<u8>,
    auth_mode: AuthTlsClientProperties,
    props: AuthenticationProperties,
) -> Result<(usize, LimitedMessage, ConnectedAuthenticatorMode), ConnectError> {
    let handshake = async {
        let server_name = rustls::pki_types::ServerName::try_from(auth_mode.server_name)
            .map_err(|_| ConnectError::InvalidDnsName)?;
        let connector = TlsConnector::from(Arc::new(auth_mode.new_client_config()));
        let tcp_stream = TcpStream::connect(auth_mode.server_addr)
            .await
            .map_err(ConnectError::AuthenticatorConnectIoError)?;
        let tls_stream = connector
            .connect(server_name, tcp_stream)
            .await
            .map_err(ConnectError::AuthenticatorConnectIoError)?;
        connect_require_tcp_based_match_arm(buf, public_key_sent, tls_stream).await
    };
    match crate::internal::rt::timeout(props.timeout, handshake).await {
        Ok(outcome) => {
            outcome.map(|len| (len, props.message, ConnectedAuthenticatorMode::RequireTls))
        }
        Err(_) => Err(ConnectError::Timeout),
    }
}
/// Tries each mode of an `AuthenticatorMode::AttemptList` in order.
///
/// Returns the first successful result. If every mode fails (or `modes` is empty),
/// returns [`ConnectError::AllAttemptsFailed`] holding each failure in attempt order.
/// All attempts share `buf`, so a later attempt overwrites an earlier reply.
pub async fn connect_attempt_list_match_arm(
    client_properties: &ClientProperties,
    socket: &UdpSocket,
    buf: &mut [u8; UDP_BUFFER_SIZE],
    public_key_sent: &Vec<u8>,
    modes: Vec<AuthenticatorMode>,
) -> Result<(usize, LimitedMessage, ConnectedAuthenticatorMode), ConnectError> {
    let mut failures: Vec<ConnectError> = Vec::with_capacity(modes.len());
    for candidate in modes {
        let attempt = connect_auth_mode_match_arm(
            candidate,
            client_properties,
            socket,
            buf,
            public_key_sent,
        )
        .await;
        match attempt {
            Ok(connected) => return Ok(connected),
            Err(failure) => failures.push(failure),
        }
    }
    Err(ConnectError::AllAttemptsFailed(failures))
}
/// Finishes connecting once the server's `PUBLIC_KEY_SEND` reply is in `buf`.
///
/// In order, this function:
/// 1. Copies the server public key out of `buf` (the `PUBLIC_KEY_SIZE` bytes right after
///    the channel byte) and builds the authentication bytes for `connected_auth_mode`:
///    channel byte + server key + serialized `message` for `NoCryptography`, or the
///    encrypted form produced by `connect_auth_cipher_arm` for the TCP/TLS modes.
/// 2. Creates the channels, the `ConnectedServer` and the [`Client`], queues a
///    `ClientTickEndPacket`, and spawns the per-server background handlers.
/// 3. Repeatedly sends the authentication bytes and feeds incoming datagrams to the
///    client until a tick yields the first message, the client is disconnected, or
///    `messaging_properties.timeout_interpretation` has elapsed.
///
/// On success it also spawns `read_handler_properties.target_tasks_size` read handlers
/// plus the awaitable-task holder, then returns the client and the first message.
///
/// # Errors
/// - [`ConnectError::Timeout`] when `messaging_properties.timeout_interpretation`
///   elapses without a message.
/// - [`ConnectError::AuthenticationBytesSendIoError`] if sending the authentication
///   bytes fails.
/// - [`ConnectError::Disconnected`] if a tick reports that the client was disconnected.
///
/// # Panics
/// Panics if any internal `unwrap` fails, e.g. `ClientTickEndPacket` cannot be
/// serialized by `packet_registry`, or `try_tick_start` / `take_disconnect_reason`
/// return an error/`None`.
pub async fn connect_public_key_send_match_arm(
    socket: Arc<UdpSocket>,
    buf: [u8; UDP_BUFFER_SIZE],
    _client_private_key: EphemeralSecret,
    message: LimitedMessage,
    packet_registry: Arc<PacketRegistry>,
    messaging_properties: Arc<MessagingProperties>,
    read_handler_properties: Arc<ReadHandlerProperties>,
    client_properties: Arc<ClientProperties>,
    task_runner: Arc<TaskRunner>,
    remote_addr: SocketAddr,
    connected_auth_mode: ConnectedAuthenticatorMode,
) -> Result<ConnectResult, ConnectError> {
    let mut server_public_key_bytes: [u8; PUBLIC_KEY_SIZE] = [0; PUBLIC_KEY_SIZE];
    server_public_key_bytes
        .copy_from_slice(&buf[MESSAGE_CHANNEL_SIZE..(MESSAGE_CHANNEL_SIZE + PUBLIC_KEY_SIZE)]);
    let (authentication_bytes, inner_auth) = match &connected_auth_mode {
        &ConnectedAuthenticatorMode::NoCryptography => {
            let mut list_bytes = message.to_list().bytes;
            let mut authentication_bytes =
                Vec::with_capacity(MESSAGE_CHANNEL_SIZE + PUBLIC_KEY_SIZE + list_bytes.len());
            authentication_bytes.push(MessageChannel::AUTH_MESSAGE);
            authentication_bytes.extend_from_slice(&server_public_key_bytes);
            authentication_bytes.append(&mut list_bytes);
            (authentication_bytes, InnerAuth::NoCryptography)
        }
        #[cfg(feature = "auth_tcp")]
        ConnectedAuthenticatorMode::RequireTcp => {
            let (cipher, authentication_bytes) =
                connect_auth_cipher_arm(server_public_key_bytes, _client_private_key, message);
            (
                authentication_bytes,
                InnerAuth::RequireTcp(InnerAuthTcpBased { cipher }),
            )
        }
        #[cfg(feature = "auth_tls")]
        ConnectedAuthenticatorMode::RequireTls => {
            let (cipher, authentication_bytes) =
                connect_auth_cipher_arm(server_public_key_bytes, _client_private_key, message);
            (
                authentication_bytes,
                InnerAuth::RequireTls(InnerAuthTcpBased { cipher }),
            )
        }
    };
    let (reason_to_disconnect_sender, reason_to_disconnect_receiver) =
        async_channel::bounded(1);
    let (receiving_bytes_sender, receiving_bytes_receiver) = async_channel::unbounded();
    let (packets_to_send_sender, packets_to_send_receiver) = async_channel::unbounded();
    let (message_part_confirmation_sender, message_part_confirmation_receiver) =
        async_channel::unbounded();
    let (shared_socket_bytes_send_sender, shared_socket_bytes_send_receiver) =
        async_channel::unbounded();
    let (awaitable_tasks_sender, awaitable_tasks_receiver) = async_channel::unbounded();
    let messaging = Mutex::new(PartnerMessaging {
        pending_confirmation: BTreeMap::new(),
        incoming_messages: MessagePartMap::new(
            messaging_properties.initial_next_message_part_id,
        ),
        tick_bytes_len: 0,
        last_received_message_instant: Instant::now(),
        received_messages: Vec::new(),
        packet_loss_rtt_calculator: RttCalculator::new(messaging_properties.initial_latency),
        average_packet_loss_rtt: messaging_properties.initial_latency,
        latency_monitor: DurationMonitor::try_filled_with(
            messaging_properties.initial_latency,
            16,
        )
        .unwrap(),
    });
    let server = Arc::new(ConnectedServer {
        disposable_handlers_keeper: Mutex::new(Vec::new()),
        receiving_bytes_sender,
        packets_to_send_sender,
        message_part_confirmation_sender,
        shared_socket_bytes_send_sender,
        addr: remote_addr,
        inner_auth,
        messaging,
        last_messaging_write: RwLock::new(Instant::now()),
        average_latency: RwLock::new(messaging_properties.initial_latency),
        incoming_messages_total_size: RwLock::new(0),
    });
    #[cfg(feature = "store_unexpected")]
    let (store_unexpected_errors, store_unexpected_errors_create_list_signal_receiver) =
        StoreUnexpectedErrors::new();
    let client = Client {
        internal: Arc::new(NodeInternal {
            disposable_handlers_keeper: Mutex::new(Vec::new()),
            cancelable_handlers_keeper: Mutex::new(Vec::new()),
            awaitable_tasks_sender,
            socket: Arc::clone(&socket),
            #[cfg(feature = "store_unexpected")]
            store_unexpected_errors,
            packet_registry: Arc::clone(&packet_registry),
            messaging_properties: Arc::clone(&messaging_properties),
            read_handler_properties: Arc::clone(&read_handler_properties),
            task_runner,
            state: AsyncRwLock::new(NodeState::Active),
            node_type: ClientNode {
                reason_to_disconnect_sender,
                reason_to_disconnect_receiver,
                authentication_mode: connected_auth_mode,
                tick_state: RwLock::new(ClientTickState::TickStartPending),
                client_properties: Arc::clone(&client_properties),
                connected_server: Arc::clone(&server),
                disconnect_reason: RwLock::new(None),
            },
        }),
    };
    let internal = &client.internal;
    let tick_packet_serialized = packet_registry.try_serialize(&ClientTickEndPacket).unwrap();
    let connected_server = &internal.node_type.connected_server;
    // The serialized packet is not used again, so it is moved rather than cloned.
    client.send_packet_serialized(tick_packet_serialized);
    // NOTE(review): `None` on this channel presumably signals the packets-to-send handler
    // to flush/close the current tick — confirm in `create_packets_to_send_handler`.
    connected_server
        .packets_to_send_sender
        .try_send(None)
        .unwrap();
    let mut disposable_handlers_keeper = server.disposable_handlers_keeper.lock().await;
    let client_downgraded = Arc::downgrade(&internal);
    let server_downgraded = Arc::downgrade(&server);
    disposable_handlers_keeper.push(ActiveDisposableHandler {
        task: internal
            .task_runner
            .spawn(server::create_receiving_bytes_handler(
                client_downgraded,
                server_downgraded,
                receiving_bytes_receiver,
            )),
    });
    let client_downgraded = Arc::downgrade(&internal);
    let server_downgraded = Arc::downgrade(&server);
    let initial_next_message_part_id =
        internal.messaging_properties.initial_next_message_part_id;
    disposable_handlers_keeper.push(ActiveDisposableHandler {
        task: internal
            .task_runner
            .spawn(server::create_packets_to_send_handler(
                client_downgraded,
                server_downgraded,
                packets_to_send_receiver,
                initial_next_message_part_id,
            )),
    });
    let client_downgraded = Arc::downgrade(&internal);
    disposable_handlers_keeper.push(ActiveDisposableHandler {
        task: internal
            .task_runner
            .spawn(server::create_message_part_confirmation_handler(
                client_downgraded,
                message_part_confirmation_receiver,
            )),
    });
    let client_downgraded = Arc::downgrade(&internal);
    disposable_handlers_keeper.push(ActiveDisposableHandler {
        task: internal
            .task_runner
            .spawn(server::create_shared_socket_bytes_send_handler(
                client_downgraded,
                shared_socket_bytes_send_receiver,
            )),
    });
    #[cfg(feature = "store_unexpected")]
    {
        let client_downgraded = Arc::downgrade(&internal);
        disposable_handlers_keeper.push(ActiveDisposableHandler {
            task: internal.task_runner.spawn(
                init::client::create_store_unexpected_error_list_handler(
                    client_downgraded,
                    store_unexpected_errors_create_list_signal_receiver,
                ),
            ),
        });
    }
    // Every handler is registered. Release the guard now: nothing below needs it. Holding
    // it across the retry loop's awaits would block any other task locking the server's
    // handler list until connecting finishes.
    drop(disposable_handlers_keeper);
    let sent_time = Instant::now();
    // Per-iteration wait for incoming bytes, never longer than the overall timeout.
    let packet_loss_timeout = client_properties
        .auth_packet_loss_interpretation
        .min(messaging_properties.timeout_interpretation);
    loop {
        let now = Instant::now();
        if now - sent_time > messaging_properties.timeout_interpretation {
            return Err(ConnectError::Timeout);
        }
        // Re-send the authentication bytes on every iteration to cover packet loss.
        if let Err(e) = socket.send(&authentication_bytes).await {
            return Err(ConnectError::AuthenticationBytesSendIoError(e));
        }
        let pre_read_next_bytes_result =
            ClientNode::pre_read_next_bytes_timeout(&socket, packet_loss_timeout).await;
        match pre_read_next_bytes_result {
            Ok(result) => {
                let _read_result = ClientNode::read_next_bytes(internal, result).await;
                #[cfg(feature = "store_unexpected")]
                if _read_result.is_unexpected() {
                    let _ = internal
                        .store_unexpected_errors
                        .error_sender
                        .try_send(UnexpectedError::OfReadServerBytes(_read_result));
                }
            }
            // Nothing readable within `packet_loss_timeout`; retry.
            Err(_) => {}
        }
        match client.try_tick_start().unwrap() {
            ClientTickResult::ReceivedMessage(tick_result) => {
                client.try_tick_after_message().unwrap();
                let mut cancelable_handlers_keeper =
                    internal.cancelable_handlers_keeper.lock().await;
                {
                    for _ in 0..internal.read_handler_properties.target_tasks_size {
                        let (cancel_sender, cancel_receiver) = async_channel::bounded(1);
                        cancelable_handlers_keeper.push(ActiveCancelableHandler {
                            cancel_sender,
                            task: internal.task_runner.spawn(NodeType::create_read_handler(
                                Arc::downgrade(&internal),
                                Arc::clone(&internal.socket),
                                cancel_receiver,
                            )),
                        });
                    }
                }
                {
                    let (cancel_sender, cancel_receiver) = async_channel::bounded(1);
                    cancelable_handlers_keeper.push(ActiveCancelableHandler {
                        cancel_sender,
                        task: internal.task_runner.spawn(
                            ActiveAwaitableHandler::create_holder(
                                awaitable_tasks_receiver,
                                cancel_receiver,
                            ),
                        ),
                    });
                }
                drop(cancelable_handlers_keeper);
                return Ok(ConnectResult {
                    client,
                    initial_message: tick_result.message,
                });
            }
            ClientTickResult::PendingMessage => (),
            ClientTickResult::Disconnected => {
                #[cfg(feature = "store_unexpected")]
                let unexpected_errors =
                    NodeType::store_unexpected_error_list_pick(&client.internal).await;
                return Err(ConnectError::Disconnected(DisconnectedConnectError {
                    reason: client.take_disconnect_reason().unwrap(),
                    #[cfg(feature = "store_unexpected")]
                    unexpected_errors,
                }));
            }
            ClientTickResult::WriteLocked => (),
        }
    }
}
/// Derives the session cipher from an X25519 exchange and builds the encrypted
/// authentication datagram.
///
/// The shared secret of `_client_private_key` and the server public key becomes the
/// ChaCha20-Poly1305 key. `message` is serialized and encrypted under a fresh random
/// nonce. The returned bytes are laid out as
/// `AUTH_MESSAGE channel byte | server public key | nonce | ciphertext`.
///
/// Returns the cipher, so the caller can keep it for the connection, together with
/// the authentication bytes.
#[cfg(any(feature = "auth_tcp", feature = "auth_tls"))]
fn connect_auth_cipher_arm(
    server_public_key_bytes: [u8; 32],
    _client_private_key: EphemeralSecret,
    message: LimitedMessage,
) -> (ChaCha20Poly1305, Vec<u8>) {
    let list = message.to_list();
    let server_public_key = x25519_dalek::PublicKey::from(server_public_key_bytes);
    let shared_key = _client_private_key.diffie_hellman(&server_public_key);
    let cipher = ChaChaPoly1305::new(Key::from_slice(shared_key.as_bytes()));
    // Random per-encryption nonce; it is sent in clear alongside the ciphertext below.
    let nonce: Nonce = ChaCha20Poly1305::generate_nonce(&mut rand::rngs::OsRng);
    let mut cipher_bytes =
        SentMessagePart::cryptograph_message_part(&list.bytes, &cipher, &nonce);
    // Size from the ciphertext rather than the plaintext, so any overhead added by the
    // encryption (e.g. an AEAD tag) is included and `append` does not reallocate.
    let mut authentication_bytes = Vec::with_capacity(
        MESSAGE_CHANNEL_SIZE + server_public_key_bytes.len() + nonce.len() + cipher_bytes.len(),
    );
    authentication_bytes.push(MessageChannel::AUTH_MESSAGE);
    authentication_bytes.extend_from_slice(&server_public_key_bytes);
    authentication_bytes.extend_from_slice(&nonce);
    authentication_bytes.append(&mut cipher_bytes);
    (cipher, authentication_bytes)
}
}