use std::convert::TryFrom;
use std::net::SocketAddr;
use {
    bytes::{Bytes, BytesMut},
    crossbeam_channel::Sender,
    futures::StreamExt,
    quinn::{
        Connecting, Connection, ConnectionDriver, Datagrams, IncomingUniStreams, NewConnection,
        RecvStream, SendStream,
    },
    tokio::sync::{
        mpsc::{channel, Receiver as TokioReceiver, Sender as TokioSender},
        oneshot::channel as oneshot_channel,
    },
};
use crate::{
    worker::{
        handshake::{verify_handshake_client, verify_handshake_server},
        WorkerEvent,
    },
    MessageOrder, PeerId,
};
/// A request queued on a connection's command channel for its send loop to transmit.
#[derive(Debug)]
pub enum ConnectionCommand {
    /// Transmit `bytes` as a single datagram.
    SendDatagram {
        /// Datagram payload.
        bytes: Bytes,
    },
    /// Transmit `bytes` as a length-prefixed message on a stream chosen by `order`.
    ///
    /// The payload length must fit in the 16-bit length prefix used for framing.
    SendMessage {
        /// Message payload.
        bytes: Bytes,
        /// Selects the shared ordered stream or a dedicated one-shot stream.
        order: MessageOrder,
    },
}
/// Runs one QUIC connection from setup through its whole lifetime.
///
/// `server_peer_id == None` selects the server half of the protocol handshake;
/// `Some(peer_id)` selects the client half and makes a setup failure reportable.
///
/// Steps:
/// 1. Await `connecting`. On failure, send `WorkerEvent::ConnectionFailed` if a peer id
///    was supplied, then return.
/// 2. Register the connection with the worker (`WorkerEvent::ConnectionStarted`) and
///    receive its `PeerId`.
/// 3. Spawn the connection driver, then run the handshake with `protocol_checksum`.
/// 4. Ask the application to accept the connection (`WorkerEvent::ConnectionRequested`).
/// 5. Spawn the receive listeners and run the send loop until it finishes.
///
/// The connection is closed if the handshake fails, the application denies it, or the
/// send loop stops with an error.
///
/// # Panics
///
/// Panics if the worker's event receiver has been dropped (the `send(..).unwrap()` calls).
pub async fn handle_connecting(
    connecting: Connecting,
    sender: Sender<WorkerEvent>,
    server_peer_id: Option<PeerId>,
    protocol_checksum: u32,
) {
    let new_connection = match connecting.await {
        Ok(value) => value,
        Err(_error) => {
            if let Some(peer_id) = server_peer_id {
                sender
                    .send(WorkerEvent::ConnectionFailed { peer_id })
                    .unwrap();
            }
            return;
        }
    };
    // NOTE(review): the braces move `new_connection` into a temporary; presumably so any
    // fields not bound here are dropped immediately rather than at the end of the function.
    let NewConnection {
        driver,
        connection,
        datagrams,
        uni_streams,
        mut bi_streams,
        ..
    } = { new_connection };
    let (connection_sender, connection_receiver) = channel(256);
    let peer_id = initialize_connection_data(
        &sender,
        connection.clone(),
        connection_sender,
        server_peer_id,
    )
    .await;
    tokio::spawn(drive_connection(driver, sender.clone(), peer_id));
    // Without a known server peer id we run the server half of the handshake.
    let result = if server_peer_id.is_none() {
        verify_handshake_server(&connection, protocol_checksum).await
    } else {
        verify_handshake_client(&mut bi_streams, protocol_checksum).await
    };
    let (send, recv) = match result {
        Some(streams) => streams,
        None => {
            connection.close(0u32.into(), b"Handshake Failed");
            return;
        }
    };
    let confirmed = verify_connection(&sender, connection.remote_address(), peer_id).await;
    if !confirmed {
        connection.close(0u32.into(), b"Denied by Application");
        return;
    }
    tokio::spawn(listen_datagrams(datagrams, sender.clone(), peer_id));
    tokio::spawn(listen_ordered_stream(recv, sender.clone(), peer_id));
    tokio::spawn(listen_oneshot_streams(uni_streams, sender, peer_id));
    // Once the send loop has failed, nothing can transmit on this connection any more.
    // Close it so the driver reports the disconnect, rather than leaving it half-alive.
    if connection_worker(connection.clone(), send, connection_receiver)
        .await
        .is_err()
    {
        connection.close(0u32.into(), b"Send Failed");
    }
}
/// Announces a newly established connection to the worker and waits for its `PeerId`.
///
/// Sends `WorkerEvent::ConnectionStarted` with the connection handle, the command sender
/// for its send loop, `server_peer_id`, and a oneshot sender. It then awaits the `PeerId`
/// returned through that oneshot.
///
/// # Panics
///
/// Panics if the worker's event receiver is gone, or if the oneshot sender is dropped
/// without a reply.
async fn initialize_connection_data(
    sender: &Sender<WorkerEvent>,
    connection: Connection,
    connection_sender: TokioSender<ConnectionCommand>,
    server_peer_id: Option<PeerId>,
) -> PeerId {
    let (peer_id_sender, peer_id_receiver) = oneshot_channel();
    let event = WorkerEvent::ConnectionStarted {
        connection,
        connection_sender,
        server_peer_id,
        peer_id_sender,
    };
    sender.send(event).unwrap();
    peer_id_receiver.await.unwrap()
}
/// Asks the application whether to accept the connection from `socket_addr`.
///
/// Sends `WorkerEvent::ConnectionRequested` and awaits the answer. If the confirmation
/// sender is dropped without replying, the connection is treated as denied (`false`).
///
/// # Panics
///
/// Panics if the worker's event receiver has been dropped.
async fn verify_connection(
    sender: &Sender<WorkerEvent>,
    socket_addr: SocketAddr,
    peer_id: PeerId,
) -> bool {
    let (confirm_sender, confirm_receiver) = oneshot_channel();
    let request = WorkerEvent::ConnectionRequested {
        peer_id,
        socket_addr,
        confirm_sender,
    };
    sender.send(request).unwrap();
    match confirm_receiver.await {
        Ok(accepted) => accepted,
        Err(_) => false,
    }
}
async fn connection_worker(
connection: Connection,
mut ordered_send: SendStream,
mut receiver: TokioReceiver<ConnectionCommand>,
) -> Result<(), Box<dyn std::error::Error>> {
while let Some(command) = receiver.recv().await {
match command {
ConnectionCommand::SendDatagram { bytes } => {
connection.send_datagram(bytes).await?;
}
ConnectionCommand::SendMessage { bytes, order } => {
let length = bytes.len() as u16;
let length_bytes = length.to_be_bytes();
match order {
MessageOrder::Unordered => {
let mut oneshot_send = connection.open_uni().await?;
oneshot_send.write_all(&length_bytes).await?;
oneshot_send.write_all(&bytes).await?;
}
MessageOrder::Ordered => {
ordered_send.write_all(&length_bytes).await?;
ordered_send.write_all(&bytes).await?;
}
}
}
}
}
Ok(())
}
/// Polls the connection driver to completion, then reports the disconnect to the worker.
///
/// Sends `WorkerEvent::Disconnected` with `reason: None` when the driver finishes cleanly.
/// Otherwise `reason` holds the driver error's `Display` text.
///
/// # Panics
///
/// Panics if the worker's event receiver has been dropped.
async fn drive_connection(driver: ConnectionDriver, sender: Sender<WorkerEvent>, peer_id: PeerId) {
    let outcome = driver.await;
    let reason = outcome.err().map(|error| error.to_string());
    sender
        .send(WorkerEvent::Disconnected { peer_id, reason })
        .unwrap();
}
/// Forwards every received datagram to the worker as `WorkerEvent::ReceivedDatagram`.
///
/// Stops when the datagram stream ends or yields its first error.
///
/// # Panics
///
/// Panics if the worker's event receiver has been dropped.
async fn listen_datagrams(mut datagrams: Datagrams, sender: Sender<WorkerEvent>, peer_id: PeerId) {
    while let Some(Ok(bytes)) = datagrams.next().await {
        sender
            .send(WorkerEvent::ReceivedDatagram { peer_id, bytes })
            .unwrap();
    }
}
/// Reads length-prefixed messages from the shared ordered stream, in arrival order.
///
/// Each message is forwarded to the worker as `WorkerEvent::ReceivedMessage`. Reading
/// stops the first time `read_stream_message` cannot produce a complete message.
///
/// # Panics
///
/// Panics if the worker's event receiver has been dropped.
async fn listen_ordered_stream(mut recv: RecvStream, sender: Sender<WorkerEvent>, peer_id: PeerId) {
    while let Some(bytes) = read_stream_message(&mut recv).await {
        let event = WorkerEvent::ReceivedMessage { peer_id, bytes };
        sender.send(event).unwrap();
    }
}
/// Reads one length-prefixed message from each incoming unidirectional stream.
///
/// This is the receiving side of `MessageOrder::Unordered`: `connection_worker` opens one
/// such stream per message. Each complete message is forwarded to the worker as
/// `WorkerEvent::ReceivedMessage`. A stream that does not yield a complete message is
/// skipped. The loop ends when `uni_streams` ends or yields an error.
///
/// NOTE(review): streams are read one at a time, so a stalled stream delays every stream
/// opened after it.
///
/// # Panics
///
/// Panics if the worker's event receiver has been dropped.
async fn listen_oneshot_streams(
    mut uni_streams: IncomingUniStreams,
    sender: Sender<WorkerEvent>,
    peer_id: PeerId,
) {
    while let Some(Ok(mut recv)) = uni_streams.next().await {
        // A single reset or truncated stream must not stop delivery on later streams.
        // Connection-level failures are expected to surface through `uni_streams` itself.
        let bytes = match read_stream_message(&mut recv).await {
            Some(bytes) => bytes,
            None => continue,
        };
        sender
            .send(WorkerEvent::ReceivedMessage { peer_id, bytes })
            .unwrap();
    }
}
/// Reads one frame from `recv`: a 2-byte big-endian length, then that many payload bytes.
///
/// Returns the payload, or `None` if either read fails.
async fn read_stream_message(recv: &mut RecvStream) -> Option<Bytes> {
    let mut header = [0u8; 2];
    if recv.read_exact(&mut header).await.is_err() {
        return None;
    }
    let length = usize::from(u16::from_be_bytes(header));
    let mut payload = BytesMut::with_capacity(length);
    payload.resize(length, 0);
    match recv.read_exact(&mut payload).await {
        Ok(_) => Some(payload.freeze()),
        Err(_) => None,
    }
}