use crate::state::{self, State};
use gday_contact_exchange_protocol::{ClientMsg, ServerMsg, read_from_async, write_to_async};
use log::{error, info, warn};
use std::net::SocketAddr;
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
net::TcpStream,
};
use tokio_rustls::TlsAcceptor;
/// Serves a single accepted client connection from start to finish.
///
/// When `tls_acceptor` is `Some`, a TLS handshake is performed over `tcp_stream` first.
/// If the handshake fails, a warning is logged and the connection is dropped. Without an
/// acceptor, requests are served directly over the plain TCP stream. In both cases the
/// stream is shut down once [`handle_requests`] returns.
///
/// The results of `handle_requests` and of the shutdown are deliberately discarded.
/// `handle_requests` logs the reason a session ended in most cases, and shutdown is
/// best-effort.
pub async fn handle_connection(
    mut tcp_stream: TcpStream,
    origin: SocketAddr,
    tls_acceptor: Option<TlsAcceptor>,
    state: State,
) {
    match tls_acceptor {
        None => {
            let _ = handle_requests(&mut tcp_stream, state, origin).await;
            let _ = tcp_stream.shutdown().await;
        }
        Some(acceptor) => match acceptor.accept(tcp_stream).await {
            Ok(mut tls_stream) => {
                let _ = handle_requests(&mut tls_stream, state, origin).await;
                let _ = tls_stream.shutdown().await;
            }
            Err(err) => {
                warn!("Error establishing TLS connection with '{origin}': {err}");
            }
        },
    }
}
async fn handle_requests(
stream: &mut (impl AsyncRead + AsyncWrite + Unpin),
mut state: State,
origin: SocketAddr,
) -> Result<(), HandleMessageError> {
loop {
let result = handle_message(stream, &mut state, origin).await;
match result {
Ok(()) => (),
Err(HandleMessageError::State(state::Error::NoSuchRoomCode)) => {
warn!("Replying with ServerMsg::ErrorNoSuchRoomCode.");
write_to_async(ServerMsg::ErrorNoSuchRoomCode, stream).await?;
}
Err(HandleMessageError::Receiver(_)) => {
warn!("Replying with ServerMsg::ErrorPeerTimedOut.");
write_to_async(ServerMsg::ErrorPeerTimedOut, stream).await?;
}
Err(HandleMessageError::State(state::Error::RoomCodeTaken)) => {
warn!("Replying with ServerMsg::ErrorRoomTaken.");
write_to_async(ServerMsg::ErrorRoomTaken, stream).await?;
}
Err(HandleMessageError::State(state::Error::TooManyRequests)) => {
warn!("Replying with ServerMsg::ErrorTooManyRequests and disconnecting.");
write_to_async(ServerMsg::ErrorTooManyRequests, stream).await?;
return result;
}
Err(HandleMessageError::State(state::Error::CantUpdateDoneClient)) => {
warn!("Replying with ServerMsg::ErrorUnexpectedMsg.");
write_to_async(ServerMsg::ErrorUnexpectedMsg, stream).await?;
}
Err(HandleMessageError::Protocol(ref err)) => {
warn!("Replying with ServerMsg::ErrorSyntax and disconnecting, because: {err}");
write_to_async(ServerMsg::ErrorSyntax, stream).await?;
return result;
}
Err(HandleMessageError::IO(_)) => {
info!("'{origin}' disconnected.");
return result;
}
}
}
}
/// Reads one [`ClientMsg`] from `stream`, applies it to `state`, and writes the reply.
///
/// * `CreateRoom`: creates the room via `State::create_room` and replies
///   `ServerMsg::RoomCreated`.
/// * `RecordPublicAddr`: records the connection's own address (`origin`) for the client
///   and replies `ServerMsg::ReceivedAddr`.
/// * `ReadyToShare`: handled in these steps:
///   1. Records whichever local addresses (v4 and/or v6) the client supplied.
///   2. Marks the client done.
///   3. Replies with `ServerMsg::ClientContact`.
///   4. Waits for the peer's contact and forwards it as `ServerMsg::PeerContact`.
///
/// `origin.ip()` is passed to every `State` call. Presumably this is for per-IP
/// accounting (see `state::Error::TooManyRequests`); confirm in `State`.
///
/// # Errors
/// Returns a [`HandleMessageError`] in any of these cases:
/// - reading or writing a protocol message fails;
/// - `state` rejects the operation;
/// - the receiver for the peer's contact fails before a contact arrives.
async fn handle_message(
    stream: &mut (impl AsyncRead + AsyncWrite + Unpin),
    state: &mut State,
    origin: SocketAddr,
) -> Result<(), HandleMessageError> {
    let msg: ClientMsg = read_from_async(stream).await?;
    match msg {
        ClientMsg::CreateRoom { room_code } => {
            state.create_room(room_code, origin.ip())?;
            write_to_async(ServerMsg::RoomCreated, stream).await?;
        }
        ClientMsg::RecordPublicAddr {
            room_code,
            is_creator,
        } => {
            // `true` presumably marks this as the client's public address: it is the
            // address this server observes. Verify against `State::update_client`.
            state.update_client(&room_code, is_creator, origin, true, origin.ip())?;
            write_to_async(ServerMsg::ReceivedAddr, stream).await?;
        }
        ClientMsg::ReadyToShare {
            room_code,
            is_creator,
            local_contact,
        } => {
            // The client reports its own local addresses, hence `false` below. The
            // v4 address is recorded before the v6 one, and missing ones are skipped.
            let local_addrs = [
                local_contact.v4.map(SocketAddr::from),
                local_contact.v6.map(SocketAddr::from),
            ];
            for local_addr in local_addrs.into_iter().flatten() {
                state.update_client(&room_code, is_creator, local_addr, false, origin.ip())?;
            }

            let (client_contact, rx) =
                state.set_client_done(&room_code, is_creator, origin.ip())?;
            write_to_async(ServerMsg::ClientContact(client_contact), stream).await?;
            info!("Sent client '{origin}' their contact of '{client_contact}'.");

            // Waits until the peer's contact is sent on the channel (presumably once the
            // peer is also done). Fails with `RecvError` if the sender is dropped first.
            let peer_contact = rx.await?;
            write_to_async(ServerMsg::PeerContact(peer_contact), stream).await?;
            info!("Sent client '{origin}' their peer's contact of '{peer_contact}'.");
        }
    }
    Ok(())
}
/// Everything that can go wrong while handling a single client message.
///
/// Every variant wraps its underlying error and converts from it via `#[from]`, so `?`
/// can be used directly inside [`handle_message`].
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
enum HandleMessageError {
    /// `State` rejected the requested operation (unknown or taken room code, etc.).
    #[error("Error updating server state: {0}")]
    State(#[from] state::Error),

    /// The oneshot receiver awaiting the peer's contact failed. Tokio reports this
    /// when the sender is dropped without sending a value.
    #[error("Timed out while waiting for other peer to share contact: {0}")]
    Receiver(#[from] tokio::sync::oneshot::error::RecvError),

    /// Reading, writing, or (de)serializing a protocol message failed.
    #[error("Serialization/deserialization error: {0}")]
    Protocol(#[from] gday_contact_exchange_protocol::Error),

    /// A raw I/O error. `handle_requests` treats this as the client having disconnected.
    #[error("IO Error: {0}")]
    IO(#[from] std::io::Error),
}