use core::future::Future;
use core::net::SocketAddr;
use std::collections::HashMap;
use std::io::{self, ErrorKind};
use common::*;
use futures::io::{ReadHalf, WriteHalf};
use futures::{
pin_mut,
select,
AsyncRead,
AsyncReadExt,
AsyncWrite,
FutureExt,
};
use tokio::select as race;
use tokio::sync::{mpsc, oneshot};
use crate::{
broadcast,
PeerMessage,
SessionMessage,
StreamReader,
StreamWriter,
};
/// Server-side state of a single client connected over `Stream`.
///
/// Dropping a `Connection` sends `SessionMessage::PeerLeft` through `session_sender`.
pub(crate) struct Connection<Stream> {
/// `true` if this client started the session, `false` if it joined an existing one.
is_host: bool,
/// Id of the peer served by this connection.
peer_id: PeerId,
/// Channel for `SessionMessage`s; used to report that the peer left when this is dropped.
session_sender: mpsc::Sender<SessionMessage>,
/// Id of the session this connection belongs to.
session_id: SessionId,
/// Reply channels of pending session-state requests, keyed by the peer that asked.
state_senders: HashMap<PeerId, mpsc::Sender<SessionResponse>>,
/// Reader over the read half of the client stream.
stream_reader: StreamReader<ReadHalf<Stream>>,
/// Writer over the write half of the client stream.
stream_writer: StreamWriter<WriteHalf<Stream>>,
}
impl<Stream> Connection<Stream>
where
Stream: AsyncRead + AsyncWrite + Send + Unpin,
{
/// Drives the connection after initialization until the client leaves or an error occurs.
///
/// Every iteration races the available event sources with `tokio::select!` (imported as
/// `race!`):
/// - messages broadcast by other peers are forwarded to the client via `handle_peer_msg`;
/// - messages read from the client are filtered by `handle_client_msg` and, unless consumed
///   there, relayed to the other peers through `peer_sender`.
///
/// While a relay is still in flight no further client message is read; only the relay future
/// and `peer_receiver` are polled.
///
/// Returns `Ok(())` once the client sends `ClientMessage::PeerLeft`. Errors from reading the
/// stream, receiving a peer message or writing to the client are propagated.
#[inline]
async fn handle(
&mut self,
peer_sender: &PeerSender,
peer_receiver: &mut PeerReceiver,
) -> Result<(), ConnectionError> {
// The in-flight relay of the most recently read client message, if any.
let mut peer_relay_fut = None;
loop {
match &mut peer_relay_fut {
// A relay is pending: don't read from the client until it has completed, but keep
// forwarding peer messages in the meantime.
Some(future) => race! {
() = future => {
peer_relay_fut = None;
},
peer_msg = peer_receiver.recv() => {
self.handle_peer_msg(peer_msg?).await?;
},
},
// Nothing is being relayed: wait for either side to produce a message.
None => race! {
// NOTE(review): this read future is dropped whenever the `peer_receiver` branch
// completes first; confirm `StreamReader::read` tolerates being cancelled.
client_msg = self.stream_reader.read() => {
let client_msg = client_msg?;
// The client announced that it is leaving, which ends the connection normally.
if let ClientMessage::PeerLeft(_) = &client_msg {
return Ok(());
}
if let Some(client_msg) =
self.handle_client_msg(client_msg).await
{
// Forward everything already queued on `peer_receiver` before starting the relay.
// NOTE(review): presumably this keeps the relay from stalling on this connection's
// own backlog; confirm against the `broadcast` implementation.
while let Some(msg) = peer_receiver.try_recv()? {
self.handle_peer_msg(msg).await?;
};
// The relay is boxed and pinned so that it can be polled across loop iterations.
let future = peer_sender.relay(client_msg);
peer_relay_fut = Some(Box::pin(future));
}
},
peer_msg = peer_receiver.recv() => {
self.handle_peer_msg(peer_msg?).await?;
},
},
};
}
}
/// Intercepts client messages that must not be relayed to the other peers.
///
/// A `ClientMessage::SessionResponse` is forwarded to the reply channel that was registered in
/// `state_senders` for the requesting peer; the entry is removed in the process. A response
/// without a registered channel is dropped.
///
/// Returns `None` if the message was consumed here, otherwise gives the message back unchanged
/// so that the caller can relay it.
#[inline]
async fn handle_client_msg(
    &mut self,
    client_msg: ClientMessage,
) -> Option<ClientMessage> {
    let response = match client_msg {
        ClientMessage::SessionResponse(response) => response,
        other => return Some(other),
    };
    if let Some(sender) = self.state_senders.remove(&response.requested_by()) {
        // Sending only fails once the requesting side has dropped or closed its receiver.
        // That must not bring down this connection, so the failure is only logged.
        if let Err(err) = sender.send(response).await {
            eprintln!("couldn't forward session response to the requesting peer: {err:?}");
        }
    }
    None
}
/// Translates a message broadcast by another peer into `ServerMessage`s for the client.
///
/// - `GetClientState` is only served by the host connection: the reply channel is remembered in
///   `state_senders` and the client is asked for the session state. Other connections ignore
///   the message.
/// - `PeerJoined`, `PeerLeft` and `Relay` are each written as a single server message.
/// - `RelayMany` is written as one `ServerMessage::Relayed` per contained client message.
///
/// # Errors
///
/// Propagates any failure of writing to the client stream.
#[inline]
async fn handle_peer_msg(
    &mut self,
    peer_msg: PeerMessage,
) -> Result<(), ConnectionError> {
    match peer_msg {
        PeerMessage::GetClientState { requested_by, sender } => {
            if !self.is_host {
                return Ok(());
            }
            self.state_senders.insert(requested_by, sender);
            let request = ServerMessage::SessionStateRequest { requested_by };
            self.stream_writer.write(request).await?;
        },
        PeerMessage::PeerJoined(peer_id) => {
            self.stream_writer.write(ServerMessage::PeerJoined(peer_id)).await?;
        },
        PeerMessage::PeerLeft(peer_id) => {
            self.stream_writer.write(ServerMessage::PeerLeft(peer_id)).await?;
        },
        PeerMessage::Relay(client_msg) => {
            self.stream_writer.write(ServerMessage::Relayed(client_msg)).await?;
        },
        PeerMessage::RelayMany(client_msgs) => {
            for client_msg in client_msgs.into_iter() {
                self.stream_writer.write(ServerMessage::Relayed(client_msg)).await?;
            }
        },
    }
    Ok(())
}
/// Reads the client's `InitRequest` and sets up the connection accordingly.
///
/// `InitRequest::Start` is handled by `start_session`, `InitRequest::Join` by `join_session`.
///
/// # Errors
///
/// Propagates failures of reading the request as well as the errors of the chosen handler.
#[inline]
async fn init(
    mut stream: Stream,
    session_sender: mpsc::Sender<SessionMessage>,
) -> Result<(Self, PeerSender, PeerReceiver), ConnectionError> {
    // The stream is only borrowed for reading the request, so that the handlers below can
    // take ownership of it and split it themselves.
    // NOTE(review): the temporary reader is discarded afterwards; confirm that `StreamReader`
    // doesn't buffer bytes beyond the request.
    let init_request = {
        let (read_half, _) = (&mut stream).split();
        StreamReader::new(read_half).read_other::<InitRequest>().await?
    };
    match init_request {
        InitRequest::Join(session_id, peer_id) => {
            Self::join_session(session_sender, stream, session_id, peer_id).await
        },
        InitRequest::Start(peer_id) => {
            Self::start_session(session_sender, stream, peer_id).await
        },
    }
}
/// Completes the handshake for a client that asked to join an existing session.
///
/// Registers the peer through the free function `join_session`, subscribes to the peer
/// broadcast, requests the current session state, sends it to the client as
/// `InitResponse::Joined`, announces the peer with `PeerMessage::peer_joined` and finally
/// forwards every peer message that was buffered in the meantime to the client.
///
/// # Errors
///
/// If joining fails, or if no session state arrives (reported as
/// `InitError::HostDisconnected`), the `InitError` is written to the client and
/// `ConnectionError::ClientDisconnected` is returned. Failures of writing to the client are
/// propagated.
#[inline]
async fn join_session(
    session_sender: mpsc::Sender<SessionMessage>,
    stream: Stream,
    session_id: SessionId,
    peer_id: PeerId,
) -> Result<(Self, PeerSender, PeerReceiver), ConnectionError> {
    let (read_half, write_half) = stream.split();
    let mut stream_writer = StreamWriter::new(write_half);

    let join_result = join_session(&session_sender, peer_id, session_id).await;
    let peer_sender = match join_result {
        Ok(peer_sender) => peer_sender,
        Err(init_err) => {
            stream_writer
                .write_other::<Either<_, ()>>(Either::Left(init_err))
                .await?;
            return Err(ConnectionError::ClientDisconnected);
        },
    };

    let mut peer_receiver = peer_sender.subscribe();
    let mut connection = Self {
        is_host: false,
        peer_id,
        session_id,
        session_sender,
        state_senders: HashMap::new(),
        stream_reader: StreamReader::new(read_half),
        stream_writer,
    };

    // Peer messages that arrive during the handshake are parked here and only delivered to
    // the client once it has received the session state.
    let mut pending_peer_msgs = Vec::new();
    let maybe_state = connection
        .request_session_state(
            &peer_sender,
            &mut peer_receiver,
            &mut pending_peer_msgs,
        )
        .await?;
    let Some(session_state) = maybe_state else {
        connection
            .stream_writer
            .write_other::<Either<_, ()>>(Either::Left(
                InitError::HostDisconnected,
            ))
            .await?;
        return Err(ConnectionError::ClientDisconnected);
    };

    connection
        .stream_writer
        .write_other::<Either<InitError, _>>(Either::Right(
            InitResponse::Joined(session_state),
        ))
        .await?;

    wait_while_receiving_peer_msgs(
        peer_sender.send(PeerMessage::peer_joined(peer_id)),
        &mut peer_receiver,
        &mut pending_peer_msgs,
    )
    .await;

    for peer_msg in pending_peer_msgs {
        connection.handle_peer_msg(peer_msg).await?;
    }
    Ok((connection, peer_sender, peer_receiver))
}
/// Broadcasts a request for the session state and waits for the answer.
///
/// A fresh channel is created for the reply and its sending end travels inside
/// `PeerMessage::get_client_state`. Peer messages that are received while sending the request
/// or waiting for the reply are appended to `buffer`.
///
/// Returns `Ok(None)` if the reply channel yields nothing, otherwise the session obtained from
/// the response via `into_session`.
#[inline]
async fn request_session_state(
    &mut self,
    peer_sender: &PeerSender,
    peer_receiver: &mut PeerReceiver,
    buffer: &mut Vec<PeerMessage>,
) -> Result<Option<Session>, ConnectionError> {
    let (state_sender, mut state_receiver) = mpsc::channel(1);
    let request = PeerMessage::get_client_state(self.peer_id, state_sender);
    wait_while_receiving_peer_msgs(
        peer_sender.send(request),
        peer_receiver,
        buffer,
    )
    .await;

    let maybe_response = wait_while_receiving_peer_msgs(
        state_receiver.recv(),
        peer_receiver,
        buffer,
    )
    .await;
    Ok(maybe_response.map(|response| response.into_session()))
}
/// Runs a connection on `stream` to completion.
///
/// A clean exit and `ConnectionError::ClientDisconnected` are silent; every other error is
/// printed to stderr together with `client_addr`.
#[inline]
pub(crate) async fn run(
    stream: Stream,
    client_addr: SocketAddr,
    session_sender: mpsc::Sender<SessionMessage>,
) {
    let Err(err) = Self::run_inner(stream, session_sender).await else {
        return;
    };
    if matches!(err, ConnectionError::ClientDisconnected) {
        return;
    }
    eprintln!("error while running connection with {client_addr:?}: {err}");
}
/// Initializes the connection via `init` and then drives it with `handle`.
///
/// # Errors
///
/// Propagates the first error of either step.
#[inline]
async fn run_inner(
    stream: Stream,
    session_sender: mpsc::Sender<SessionMessage>,
) -> Result<(), ConnectionError> {
    let initialized = Self::init(stream, session_sender).await?;
    let (mut connection, peer_sender, mut peer_receiver) = initialized;
    connection.handle(&peer_sender, &mut peer_receiver).await
}
/// Completes the handshake for a client that asked to start a new session.
///
/// Sends `SessionMessage::StartSession` together with a oneshot reply channel through
/// `session_sender`, awaits the session id and the peer channel ends, and reports the id to
/// the client as `InitResponse::Started`. The resulting connection is marked as the host.
///
/// # Errors
///
/// Propagates failures of writing the response to the client.
///
/// # Panics
///
/// Panics if the message can't be sent on `session_sender` or if the reply channel is closed
/// without an answer.
#[inline]
async fn start_session(
    session_sender: mpsc::Sender<SessionMessage>,
    stream: Stream,
    peer_id: PeerId,
) -> Result<(Self, PeerSender, PeerReceiver), ConnectionError> {
    let (reply_sender, reply_receiver) = oneshot::channel();
    let request = SessionMessage::StartSession(peer_id, reply_sender);
    session_sender.send(request).await.unwrap();
    let (session_id, peer_sender, peer_receiver) =
        reply_receiver.await.unwrap();

    let (read_half, write_half) = stream.split();
    let mut stream_writer = StreamWriter::new(write_half);
    stream_writer
        .write_other::<Either<InitError, _>>(Either::Right(
            InitResponse::Started(session_id),
        ))
        .await?;

    let connection = Self {
        is_host: true,
        peer_id,
        session_id,
        session_sender,
        state_senders: HashMap::new(),
        stream_reader: StreamReader::new(read_half),
        stream_writer,
    };
    Ok((connection, peer_sender, peer_receiver))
}
}
/// Reports the departure of the peer when its connection goes away.
impl<Stream> Drop for Connection<Stream> {
    /// Sends `SessionMessage::PeerLeft` through `session_sender`.
    ///
    /// `drop` can't await, so the send runs in a task spawned with `tokio::spawn`; its
    /// result is not observed.
    #[inline]
    fn drop(&mut self) {
        let session_sender = self.session_sender.clone();
        let peer_left = SessionMessage::PeerLeft(self.peer_id, self.session_id);
        tokio::spawn(async move { session_sender.send(peer_left).await });
    }
}
/// Receiving end of the peer-message broadcast that skips messages sent by its own peer.
pub(crate) struct PeerReceiver {
/// The underlying broadcast receiver.
inner: broadcast::Receiver<PeerMessage>,
/// Id of the owning peer; messages whose `sent_by()` equals it are filtered out.
peer_id: PeerId,
}
impl PeerReceiver {
    /// Wraps `inner` so that messages sent by `peer_id` itself are skipped.
    #[inline]
    pub(crate) fn new(
        inner: broadcast::Receiver<PeerMessage>,
        peer_id: PeerId,
    ) -> Self {
        Self { peer_id, inner }
    }

    /// Waits for the next message that was sent by another peer.
    ///
    /// # Errors
    ///
    /// Returns `ConnectionError::ReceiverLagged` if the underlying receiver reports lost
    /// messages.
    ///
    /// # Panics
    ///
    /// Panics if the underlying channel reports that it is closed.
    #[inline]
    async fn recv(&mut self) -> Result<PeerMessage, ConnectionError> {
        loop {
            let peer_msg = match self.inner.recv().await {
                Ok(peer_msg) => peer_msg,
                Err(broadcast::RecvError::Closed) => {
                    unreachable!("the `Session` has a `Sender`")
                },
                Err(broadcast::RecvError::Lagged(num_lost)) => {
                    let peer_id = self.peer_id;
                    return Err(ConnectionError::ReceiverLagged {
                        num_lost,
                        peer_id,
                    });
                },
            };
            // Messages of the own peer are dropped; keep waiting for a foreign one.
            if peer_msg.sent_by() != self.peer_id {
                return Ok(peer_msg);
            }
        }
    }

    /// Returns the next already-queued message of another peer without waiting.
    ///
    /// Yields `Ok(None)` once the queue holds no further message.
    ///
    /// # Errors
    ///
    /// Returns `ConnectionError::ReceiverLagged` if the underlying receiver reports lost
    /// messages.
    ///
    /// # Panics
    ///
    /// Panics if the underlying channel reports that it is closed.
    #[inline]
    fn try_recv(&mut self) -> Result<Option<PeerMessage>, ConnectionError> {
        loop {
            let peer_msg = match self.inner.try_recv() {
                Ok(peer_msg) => peer_msg,
                Err(broadcast::TryRecvError::Empty) => return Ok(None),
                Err(broadcast::TryRecvError::Closed) => {
                    unreachable!("the `Session` has a `Sender`")
                },
                Err(broadcast::TryRecvError::Lagged(num_lost)) => {
                    let peer_id = self.peer_id;
                    return Err(ConnectionError::ReceiverLagged {
                        num_lost,
                        peer_id,
                    });
                },
            };
            // Messages of the own peer are dropped; look at the next queued one.
            if peer_msg.sent_by() != self.peer_id {
                return Ok(Some(peer_msg));
            }
        }
    }
}
/// Sending end of the peer-message broadcast, tagged with the id of the peer it belongs to.
///
/// Dropping a `PeerSender` broadcasts `PeerMessage::peer_left` for its `peer_id`.
#[derive(Clone)]
pub(crate) struct PeerSender {
/// The underlying broadcast sender.
inner: broadcast::Sender<PeerMessage>,
/// Id of the owning peer; passed on to receivers created by `subscribe`.
peer_id: PeerId,
}
impl PeerSender {
    /// Creates a sender on the same broadcast channel that belongs to `peer_id`.
    #[inline]
    pub(crate) fn fork(&self, peer_id: PeerId) -> Self {
        Self::new(self.inner.clone(), peer_id)
    }

    /// Wraps `inner` as the sender of `peer_id`.
    #[inline]
    pub(crate) fn new(
        inner: broadcast::Sender<PeerMessage>,
        peer_id: PeerId,
    ) -> Self {
        Self { peer_id, inner }
    }

    /// Broadcasts `client_msg` wrapped by `PeerMessage::relay`.
    #[inline]
    async fn relay(&self, client_msg: ClientMessage) {
        let peer_msg = PeerMessage::relay(client_msg);
        self.send(peer_msg).await;
    }

    /// Broadcasts `peer_msg`; the outcome of the send is deliberately ignored.
    #[inline]
    async fn send(&self, peer_msg: PeerMessage) {
        let _ = self.inner.send(peer_msg).await;
    }

    /// Creates a receiver on the broadcast channel that skips this peer's own messages.
    #[inline]
    fn subscribe(&self) -> PeerReceiver {
        let receiver = self.inner.subscribe();
        PeerReceiver::new(receiver, self.peer_id)
    }
}
/// Announces the departure of the peer to the other peers.
impl Drop for PeerSender {
    /// Broadcasts `PeerMessage::peer_left` for this sender's `peer_id`.
    ///
    /// `drop` can't await, so the send runs in a task spawned with `tokio::spawn`; its
    /// result is not observed. Only the inner broadcast sender is cloned for the task, not
    /// the `PeerSender` itself.
    #[inline]
    fn drop(&mut self) {
        let broadcaster = self.inner.clone();
        let farewell = PeerMessage::peer_left(self.peer_id);
        tokio::spawn(async move { broadcaster.send(farewell).await });
    }
}
/// Drives `future` to completion while collecting incoming peer messages.
///
/// Every message received from `peer_receiver` before `future` finishes is appended to
/// `buffer`, leaving it to the caller to handle them afterwards. Returns the output of
/// `future`.
///
/// # Panics
///
/// Hits `todo!()` if `peer_receiver.recv()` returns an error.
// NOTE(review): the signature has no way to report that error; handling it needs a fallible
// return type and matching changes at the call sites.
#[inline]
async fn wait_while_receiving_peer_msgs<T, F>(
    future: F,
    peer_receiver: &mut PeerReceiver,
    buffer: &mut Vec<PeerMessage>,
) -> T
where
    F: Future<Output = T>,
{
    let awaited = future.fuse();
    pin_mut!(awaited);
    loop {
        // A fresh receive future is needed per iteration, as a completed one can't be reused.
        let next_peer_msg = peer_receiver.recv().fuse();
        pin_mut!(next_peer_msg);
        select! {
            output = awaited => return output,
            received = next_peer_msg => match received {
                Ok(peer_msg) => buffer.push(peer_msg),
                Err(_) => todo!(),
            },
        }
    }
}
/// Asks to add `peer_id` to the session `session_id`.
///
/// Sends `SessionMessage::JoinSession` together with a oneshot reply channel through
/// `session_sender` and returns whatever is answered on that channel.
///
/// # Panics
///
/// Panics if the message can't be sent on `session_sender` or if the reply channel is closed
/// without an answer.
#[inline]
async fn join_session(
    session_sender: &mpsc::Sender<SessionMessage>,
    peer_id: PeerId,
    session_id: SessionId,
) -> Result<PeerSender, InitError> {
    let (reply_sender, reply_receiver) = oneshot::channel();
    session_sender
        .send(SessionMessage::JoinSession(peer_id, session_id, reply_sender))
        .await
        .unwrap();
    reply_receiver.await.unwrap()
}
/// Reasons for which a `Connection` stops running.
#[derive(Debug, thiserror::Error)]
pub(crate) enum ConnectionError {
/// The client is gone. `Connection::join_session` also returns this after it has reported an
/// `InitError` to the client, and `Connection::run` doesn't log it.
#[error("the client disconnected")]
ClientDisconnected,
/// A message from the client couldn't be decoded.
#[error("couldn't decode message from client: {0:?}")]
Decode(#[from] encode::DecodeError),
/// A message to the client couldn't be encoded.
#[error("couldn't encode message to client: {0:?}")]
Encode(#[from] encode::EncodeError),
/// The broadcast receiver of `peer_id` reported that `num_lost` messages were lost.
#[error("receiver on {peer_id:?} lagged behind by {num_lost}")]
ReceiverLagged {
num_lost: u64,
peer_id: PeerId,
},
/// An I/O error while reading from the stream.
#[error("couldn't read from stream: {0:?}")]
StreamRead(io::Error),
/// An I/O error while writing to the stream.
#[error("couldn't write to stream: {0:?}")]
StreamWrite(io::Error),
}
impl ConnectionError {
    /// Tells whether `io_err` is of a kind that is treated as the client having disconnected.
    ///
    /// The kinds are `BrokenPipe`, `ConnectionReset` and `UnexpectedEof`.
    #[inline]
    pub(crate) fn is_client_disconnected(io_err: &io::Error) -> bool {
        const DISCONNECT_KINDS: [ErrorKind; 3] = [
            ErrorKind::BrokenPipe,
            ErrorKind::ConnectionReset,
            ErrorKind::UnexpectedEof,
        ];
        DISCONNECT_KINDS.contains(&io_err.kind())
    }
}