use std::{
io,
net::SocketAddr,
sync::{Arc, RwLock},
time::{Duration, Instant},
};
use rand::rngs::OsRng;
use x25519_dalek::{EphemeralSecret, PublicKey};
use crate::{
internal::{
auth::InnerAuth,
messages::{DeserializedMessage, PUBLIC_KEY_SIZE, UDP_BUFFER_SIZE},
node::{NodeInternal, NodeState, NodeType, PartnerMessaging},
rt::{try_lock, AsyncRwLock, TaskHandle, TaskRunner, UdpSocket},
MessageChannel,
},
packets::{ClientTickEndPacket, Packet, PacketRegistry, SerializedPacket},
LimitedMessage, MessagingProperties, ReadHandlerProperties, MESSAGE_CHANNEL_SIZE,
};
#[cfg(feature = "store_unexpected")]
use crate::internal::node::StoreUnexpectedErrors;
#[cfg(any(feature = "auth_tcp", feature = "auth_tls"))]
use crate::internal::auth::InnerAuthTcpBased;
pub use crate::internal::node::Partner as ConnectedServer;
pub use auth::*;
use init::*;
mod auth;
mod init;
/// Outcome of accounting one datagram received from the server against the
/// per-tick byte budget.
#[derive(Debug)]
pub enum ReadServerBytesResult {
    /// The datagram fit in the tick budget and was offered (best-effort) to the
    /// server's receiving channel.
    ServerReceivedBytes,
    /// Accepting the datagram pushed the tick's byte count past
    /// `max_tick_bytes_len`; the datagram was not forwarded.
    ServerMaxTickByteLenOverflow,
}
impl ReadServerBytesResult {
    /// Returns `true` when this outcome should be reported as an unexpected error.
    pub fn is_unexpected(&self) -> bool {
        matches!(self, Self::ServerMaxTickByteLenOverflow)
    }
}
/// Why the client considers its connection to the server ended.
///
/// Recorded by `Client::try_tick_start` and retrievable once via
/// `Client::take_disconnect_reason`.
#[derive(Debug)]
pub enum ServerDisconnectReason {
    /// A sent message stayed unconfirmed for longer than `timeout_interpretation`.
    PendingMessageConfirmationTimeout,
    /// No message was received from the server within `timeout_interpretation`.
    MessageReceiveTimeout,
    /// The messaging lock could not be acquired and the last successful write
    /// is older than `timeout_interpretation`.
    WriteUnlockTimeout,
    /// NOTE(review): not produced in this file; presumably sent through the
    /// node's reason channel by another task on malformed traffic.
    InvalidProtocolCommunication,
    /// NOTE(review): not produced in this file; presumably a socket send failure
    /// reported by another task.
    ByteSendError(io::Error),
    /// NOTE(review): not produced in this file; presumably a locally requested disconnect.
    ManualDisconnect,
    /// NOTE(review): not produced in this file; presumably the server asked to
    /// disconnect, carrying its message.
    DisconnectRequest(DeserializedMessage),
}
/// Client-side tuning options, handed to `Client::connect` and kept for the
/// lifetime of the client.
///
/// Derives `Debug` so the configuration can be logged, and `Clone` so callers can
/// derive variants from a base configuration.
#[derive(Debug, Clone)]
pub struct ClientProperties {
    /// The whole properties struct is passed to the connect-time authentication
    /// helper. NOTE(review): judging by the name, this is how long an unanswered
    /// authentication datagram is waited on before being treated as lost —
    /// confirm against the `init` module.
    pub auth_packet_loss_interpretation: Duration,
}
impl Default for ClientProperties {
    /// Uses 3 seconds for `auth_packet_loss_interpretation`.
    fn default() -> Self {
        Self {
            auth_packet_loss_interpretation: Duration::from_secs(3),
        }
    }
}
/// Tick state machine enforcing the `tick_start` / `tick_after_message` call order.
#[derive(Debug, PartialEq, Eq)]
enum ClientTickState {
    /// The next valid call is `Client::try_tick_start`.
    TickStartPending,
    /// `try_tick_start` returned a message; `Client::try_tick_after_message`
    /// must be called before the next tick start.
    TickAfterMessagePending,
}
/// Unexpected (but non-fatal) conditions collected while the client runs,
/// delivered alongside received messages when `store_unexpected` is enabled.
#[cfg(feature = "store_unexpected")]
#[derive(Debug)]
pub enum UnexpectedError {
    /// A received datagram produced an unexpected `ReadServerBytesResult`
    /// (currently: the per-tick byte budget overflowed).
    OfReadServerBytes(ReadServerBytesResult),
    /// NOTE(review): not produced in this file; presumably reported by the
    /// message-processing code on malformed traffic.
    InvalidProtocolCommunication,
}
/// Payload of `ClientTickResult::ReceivedMessage`.
#[derive(Debug)]
pub struct ReceivedMessageClientTickResult {
    /// The oldest queued message received from the server.
    pub message: DeserializedMessage,
    /// Unexpected errors gathered since the previous list was taken; empty when
    /// no list was ready.
    #[cfg(feature = "store_unexpected")]
    pub unexpected_errors: Vec<UnexpectedError>,
}
/// Result of a successful `Client::try_tick_start` call.
#[derive(Debug)]
pub enum ClientTickResult {
    /// A message was dequeued; `try_tick_after_message` must be called next.
    ReceivedMessage(ReceivedMessageClientTickResult),
    /// Nothing to process yet and no timeout has been hit.
    PendingMessage,
    /// The client is disconnected; the reason (if any) is available through
    /// `Client::take_disconnect_reason`.
    Disconnected,
    /// The messaging state is locked by another task; try again on a later tick.
    WriteLocked,
}
/// Options for a graceful `Client::disconnect`.
pub struct GracefullyDisconnection {
    /// Message embedded in the rejection sent to the server.
    pub message: LimitedMessage,
    /// Total time to wait for the server's rejection confirmation. It also caps
    /// the wait for a reply after each individual send.
    pub timeout: Duration,
}
/// Final outcome of `Client::disconnect`.
#[derive(Debug)]
pub enum ClientDisconnectState {
    /// The server answered with the rejection-confirm byte.
    Confirmed,
    /// The graceful timeout elapsed before a confirmation arrived.
    ConfirmationTimeout,
    /// No graceful options were given, so nothing was sent to the server.
    WithoutReason,
    /// Sending the rejection failed.
    SendIoError(io::Error),
    /// Receiving failed with an error other than a timeout.
    ReceiveIoError(io::Error),
    /// The client was already disconnected; carries the reason if it had not
    /// been taken yet.
    AlreadyDisconnected(Option<ServerDisconnectReason>),
}
/// Client-specific state stored inside the shared `NodeInternal`.
struct ClientNode {
    /// Sending half of the disconnect-reason channel.
    /// NOTE(review): presumably used by background tasks; not used in this file.
    reason_to_disconnect_sender: async_channel::Sender<ServerDisconnectReason>,
    /// Polled once per `try_tick_start`; a received reason disconnects the client.
    reason_to_disconnect_receiver: async_channel::Receiver<ServerDisconnectReason>,
    /// Exposed through `Client::auth_mode`.
    authentication_mode: ConnectedAuthenticatorMode,
    /// Enforces the tick_start / tick_after_message ordering.
    tick_state: RwLock<ClientTickState>,
    client_properties: Arc<ClientProperties>,
    /// The single server peer this client talks to.
    connected_server: Arc<ConnectedServer>,
    /// `None`: still connected. `Some(Some(r))`: disconnected, reason not yet taken.
    /// `Some(None)`: disconnected and the reason was already taken.
    disconnect_reason: RwLock<Option<Option<ServerDisconnectReason>>>,
}
impl ClientNode {
    /// Charges a received datagram against the current tick's byte budget and,
    /// when it fits, offers it to the server's receiving channel.
    ///
    /// Every datagram adds its length plus a fixed 48-byte overhead (8 + 40)
    /// to the counter, even when it overflows. NOTE(review): the overhead looks
    /// like UDP + IPv6 header sizes — confirm. A full or closed receiving
    /// channel is silently ignored (best-effort).
    async fn read_next_bytes(node: &NodeInternal<Self>, bytes: Vec<u8>) -> ReadServerBytesResult {
        let server = &node.node_type.connected_server;
        let mut messaging = server.messaging.lock().await;
        messaging.tick_bytes_len += bytes.len() + 8 + 40;

        let budget = node.messaging_properties.max_tick_bytes_len;
        if messaging.tick_bytes_len > budget {
            return ReadServerBytesResult::ServerMaxTickByteLenOverflow;
        }

        let _ = server.receiving_bytes_sender.try_send(bytes);
        ReadServerBytesResult::ServerReceivedBytes
    }
}
impl NodeType for ClientNode {
    /// A client reads raw datagrams from its connected socket.
    type Skt = Vec<u8>;
    #[cfg(feature = "store_unexpected")]
    type UnEr = UnexpectedError;

    /// Receives one datagram from the connected socket and returns a copy of
    /// its bytes.
    async fn pre_read_next_bytes(socket: &Arc<UdpSocket>) -> io::Result<Self::Skt> {
        let mut datagram = [0u8; UDP_BUFFER_SIZE];
        let received = socket.recv(&mut datagram).await?;
        Ok(Vec::from(&datagram[..received]))
    }

    /// Processes a received datagram. With `store_unexpected` enabled, an
    /// unexpected outcome is also pushed (best-effort) to the error channel.
    async fn consume_read_bytes_result(node: &Arc<NodeInternal<Self>>, result: Self::Skt) {
        let _outcome = Self::read_next_bytes(node, result).await;
        #[cfg(feature = "store_unexpected")]
        if _outcome.is_unexpected() {
            let _ = node
                .store_unexpected_errors
                .error_sender
                .try_send(UnexpectedError::OfReadServerBytes(_outcome));
        }
    }

    /// Spawns a task that runs the partner-disposal hook for the connected server.
    fn on_inactivated(node: &Arc<NodeInternal<Self>>) -> TaskHandle<()> {
        let owned_node = Arc::clone(node);
        node.task_runner.spawn(async move {
            let server = &owned_node.node_type.connected_server;
            NodeInternal::on_partner_disposed(&owned_node, server).await;
        })
    }
}
/// Handle to a client node connected to a single server.
///
/// Obtained from the task returned by `Client::connect`. Dropping the handle
/// calls `NodeInternal::on_holder_drop`.
pub struct Client {
    /// Shared node state, also referenced by the node's background tasks.
    internal: Arc<NodeInternal<ClientNode>>,
}
impl Client {
pub fn connect(
remote_addr: SocketAddr,
packet_registry: Arc<PacketRegistry>,
messaging_properties: Arc<MessagingProperties>,
read_handler_properties: Arc<ReadHandlerProperties>,
client_properties: Arc<ClientProperties>,
authenticator_mode: AuthenticatorMode,
#[cfg(any(feature = "rt_tokio", feature = "rt_async_executor"))]
runtime: crate::internal::rt::Runtime,
) -> TaskHandle<Result<ConnectResult, ConnectError>> {
#[cfg(any(feature = "rt_tokio", feature = "rt_async_executor"))]
let task_runner = Arc::new(TaskRunner { runtime });
#[cfg(not(any(feature = "rt_tokio", feature = "rt_async_executor")))]
let task_runner = Arc::new(TaskRunner {});
let task_runner_exit = Arc::clone(&task_runner);
let bind_result_body = async move {
if !packet_registry.check_essential() {
return Err(ConnectError::MissingEssentialPackets);
}
let client_private_key = EphemeralSecret::random_from_rng(OsRng);
let client_public_key = PublicKey::from(&client_private_key);
let client_public_key_bytes = client_public_key.as_bytes();
let mut public_key_sent = Vec::with_capacity(1 + client_public_key_bytes.len());
public_key_sent.push(MessageChannel::PUBLIC_KEY_SEND);
public_key_sent.extend_from_slice(client_public_key_bytes);
let mut buf = [0u8; UDP_BUFFER_SIZE];
let socket = match UdpSocket::bind("0.0.0.0:0").await {
Ok(socket) => Arc::new(socket),
Err(e) => {
return Err(ConnectError::SocketConnectError(e));
}
};
if let Err(e) = socket.connect(remote_addr).await {
return Err(ConnectError::SocketConnectError(e));
}
let (len, auth_message, connected_authentication_mode) =
connecting::connect_auth_mode_match_arm(
authenticator_mode,
&client_properties,
&socket,
&mut buf,
&public_key_sent,
)
.await?;
let bytes = &buf[..len];
match bytes[0] {
MessageChannel::PUBLIC_KEY_SEND => {
if len != (MESSAGE_CHANNEL_SIZE + PUBLIC_KEY_SIZE) {
return Err(ConnectError::InvalidProtocolCommunication);
}
return connecting::connect_public_key_send_match_arm(
socket,
buf,
client_private_key,
auth_message,
packet_registry,
messaging_properties,
read_handler_properties,
client_properties,
task_runner,
remote_addr,
connected_authentication_mode,
)
.await;
}
_ => Err(ConnectError::InvalidProtocolCommunication),
}
};
task_runner_exit.spawn(bind_result_body)
}
pub fn packet_registry(&self) -> &PacketRegistry {
&self.internal.packet_registry
}
pub fn messaging_properties(&self) -> &MessagingProperties {
&self.internal.messaging_properties
}
pub fn read_handler_properties(&self) -> &ReadHandlerProperties {
&self.internal.read_handler_properties
}
pub fn client_properties(&self) -> &ClientProperties {
&self.internal.node_type.client_properties
}
pub fn connected_server(&self) -> &ConnectedServer {
&self.internal.node_type.connected_server
}
pub fn auth_mode(&self) -> &ConnectedAuthenticatorMode {
&self.internal.node_type.authentication_mode
}
pub fn local_addr(&self) -> SocketAddr {
self.internal.socket.local_addr().unwrap()
}
pub fn try_tick_start(&self) -> Result<ClientTickResult, ()> {
let internal = &self.internal;
let node_type = &internal.node_type;
{
let tick_state = node_type.tick_state.read().unwrap();
if *tick_state != ClientTickState::TickStartPending {
return Err(());
}
}
if self.is_disconnected() {
return Ok(ClientTickResult::Disconnected);
}
if let Ok(reason) = node_type.reason_to_disconnect_receiver.try_recv() {
*node_type.disconnect_reason.write().unwrap() = Some(Some(reason));
return Ok(ClientTickResult::Disconnected);
}
let now = Instant::now();
let server = &node_type.connected_server;
if let Some(mut messaging) = try_lock(&server.messaging) {
*server.last_messaging_write.write().unwrap() = now;
*server.average_latency.write().unwrap() = messaging.latency_monitor.average_value();
let average_packet_loss_rtt = messaging.average_packet_loss_rtt;
let mut messages_to_resend: Vec<Arc<Vec<u8>>> = Vec::new();
for (sent_instant, pending_part_id_map) in messaging.pending_confirmation.values_mut() {
if now - *sent_instant > internal.messaging_properties.timeout_interpretation {
*node_type.disconnect_reason.write().unwrap() = Some(Some(
ServerDisconnectReason::PendingMessageConfirmationTimeout,
));
return Ok(ClientTickResult::Disconnected);
}
for sent_part in pending_part_id_map.values_mut() {
if now - sent_part.last_sent_time > average_packet_loss_rtt {
sent_part.last_sent_time = now;
messages_to_resend.push(Arc::clone(&sent_part.finished_bytes));
}
}
}
for finished_bytes in messages_to_resend {
server
.shared_socket_bytes_send_sender
.try_send(finished_bytes)
.unwrap();
}
if !messaging.received_messages.is_empty() {
let message = messaging.received_messages.remove(0);
{
let mut tick_state = node_type.tick_state.write().unwrap();
*tick_state = ClientTickState::TickAfterMessagePending;
}
messaging.tick_bytes_len = 0;
#[cfg(feature = "store_unexpected")]
let unexpected_errors = match internal
.store_unexpected_errors
.error_list_receiver
.try_recv()
{
Ok(list) => list,
Err(_) => Vec::new(),
};
#[cfg(feature = "store_unexpected")]
internal
.store_unexpected_errors
.create_list_signal_sender
.try_send(())
.unwrap();
return Ok(ClientTickResult::ReceivedMessage(
ReceivedMessageClientTickResult {
message,
#[cfg(feature = "store_unexpected")]
unexpected_errors,
},
));
} else if now - messaging.last_received_message_instant
>= internal.messaging_properties.timeout_interpretation
{
*node_type.disconnect_reason.write().unwrap() =
Some(Some(ServerDisconnectReason::MessageReceiveTimeout));
return Ok(ClientTickResult::Disconnected);
} else {
return Ok(ClientTickResult::PendingMessage);
}
} else if now - *server.last_messaging_write.read().unwrap()
>= internal.messaging_properties.timeout_interpretation
{
*node_type.disconnect_reason.write().unwrap() =
Some(Some(ServerDisconnectReason::WriteUnlockTimeout));
return Ok(ClientTickResult::Disconnected);
} else {
return Ok(ClientTickResult::WriteLocked);
}
}
#[cfg(not(feature = "no_panics"))]
pub fn tick_start(&self) -> ClientTickResult {
self.try_tick_start().expect("Invalid client tick state.")
}
pub fn try_tick_after_message(&self) -> Result<(), ()> {
let internal = &self.internal;
let node_type = &internal.node_type;
{
let mut tick_state = node_type.tick_state.write().unwrap();
if *tick_state != ClientTickState::TickAfterMessagePending {
return Err(());
} else {
*tick_state = ClientTickState::TickStartPending;
}
}
let tick_packet_serialized = internal
.packet_registry
.try_serialize(&ClientTickEndPacket)
.unwrap();
let connected_server = &node_type.connected_server;
self.send_packet_serialized(tick_packet_serialized.clone());
connected_server
.packets_to_send_sender
.try_send(None)
.unwrap();
Ok(())
}
#[cfg(not(feature = "no_panics"))]
pub fn tick_after_message(&self) {
self.try_tick_after_message()
.expect("Invalid client tick state.")
}
pub fn disconnect(
self,
disconnection: Option<GracefullyDisconnection>,
) -> TaskHandle<ClientDisconnectState> {
let tasks_keeper_exit = Arc::clone(&self.internal.task_runner);
tasks_keeper_exit.spawn(async move {
NodeInternal::set_state_inactive(&self.internal).await;
if self.is_disconnected() {
return ClientDisconnectState::AlreadyDisconnected(self.take_disconnect_reason());
}
if let Some(disconnection) = disconnection {
let socket = Arc::clone(&self.internal.socket);
let timeout_interpretation = disconnection.timeout;
let packet_loss_timeout = self
.internal
.node_type
.connected_server
.messaging
.lock()
.await
.average_packet_loss_rtt
.min(timeout_interpretation);
let rejection_context = self
.internal
.node_type
.connected_server
.inner_auth
.rejection_of(Instant::now(), disconnection.message);
let rejection_confirm_bytes = &vec![MessageChannel::REJECTION_CONFIRM];
let disconnect_state = loop {
let now = Instant::now();
if now - rejection_context.rejection_instant > timeout_interpretation {
break ClientDisconnectState::ConfirmationTimeout;
}
if let Err(e) = socket.send(&rejection_context.finished_bytes).await {
break ClientDisconnectState::SendIoError(e);
}
let pre_read_next_bytes_result =
ClientNode::pre_read_next_bytes_timeout(&socket, packet_loss_timeout).await;
match pre_read_next_bytes_result {
Ok(result) => {
if &result == rejection_confirm_bytes {
break ClientDisconnectState::Confirmed;
}
}
Err(e) if e.kind() == io::ErrorKind::TimedOut => {}
Err(e) => break ClientDisconnectState::ReceiveIoError(e),
}
};
disconnect_state
} else {
ClientDisconnectState::WithoutReason
}
})
}
pub fn try_send_packet<P: Packet>(&self, packet: &P) -> Result<(), io::Error> {
let internal = &self.internal;
let serialized = internal.packet_registry.try_serialize(packet)?;
self.send_packet_serialized(serialized);
Ok(())
}
#[cfg(not(feature = "no_panics"))]
pub fn send_packet<P: Packet>(&self, packet: &P) {
self.try_send_packet(packet)
.expect("Failed to send packet.");
}
pub fn send_packet_serialized(&self, packet_serialized: SerializedPacket) {
let internal = &self.internal;
internal
.node_type
.connected_server
.packets_to_send_sender
.try_send(Some(packet_serialized))
.unwrap();
}
pub fn is_disconnected(&self) -> bool {
let internal = &self.internal;
let disconnect_reason = internal.node_type.disconnect_reason.read().unwrap();
disconnect_reason.is_some()
}
pub fn take_disconnect_reason(&self) -> Option<ServerDisconnectReason> {
let internal = &self.internal;
let mut disconnect_reason = internal.node_type.disconnect_reason.write().unwrap();
if let Some(ref mut is_disconnected) = *disconnect_reason {
if let Some(reason_was_not_taken) = is_disconnected.take() {
Some(reason_was_not_taken)
} else {
None
}
} else {
None
}
}
}
impl Drop for Client {
    /// Tells the shared node internals that this holder handle is going away.
    fn drop(&mut self) {
        let shared = &self.internal;
        NodeInternal::on_holder_drop(shared);
    }
}