use tokio::{
io::{AsyncRead, AsyncWriteExt},
net::TcpStream,
};
use tokio_stream::StreamExt;
use tokio_util::{
bytes::BytesMut,
codec::{FramedRead, LengthDelimitedCodec},
};
/// Version number of this wire protocol.
///
/// NOTE(review): presumably exchanged through `ClientServerPacket::ProtocolVersion` during
/// the handshake and bumped on incompatible wire changes — confirm against callers.
pub const PROTOCOL_VERSION: u32 = 2;
/// Frame reader over `T` that splits the byte stream with a `LengthDelimitedCodec`.
/// Construct it with `new_framed_reader`, which fixes the exact length-prefix format.
pub type FramedReader<T> = FramedRead<T, LengthDelimitedCodec>;
/// Packets serialized with bincode through [`ClientServerPacket::into_vec`] and
/// [`ClientServerPacket::from_slice`].
///
/// The bincode derive encodes each value as its variant index followed by the variant's
/// fields. The declaration order below is therefore part of the wire format. Reordering
/// or inserting variants breaks compatibility with peers built from older code, so append
/// new variants at the end.
#[derive(Clone, Debug, bincode::Encode, bincode::Decode)]
pub enum ClientServerPacket {
    Ping,
    ProtocolVersion(u32),
    PubKey(Vec<u8>),
    ClientId(u64),
    Challenge(Vec<u8>),
    ChallengeResponse(Vec<u8>),
}
impl ClientServerPacket {
    /// Serializes the packet using bincode's `standard()` configuration.
    ///
    /// # Errors
    ///
    /// Returns bincode's `EncodeError` if serialization fails.
    pub fn into_vec(self) -> Result<Vec<u8>, bincode::error::EncodeError> {
        bincode::encode_to_vec(self, bincode::config::standard())
    }

    /// Deserializes a packet from `data` using bincode's `standard()` configuration.
    ///
    /// Bytes following the encoded packet are ignored: the consumed-length value that
    /// bincode returns alongside the packet is discarded.
    ///
    /// NOTE(review): no decode limit is configured. If `data` can come from an untrusted
    /// peer, consider `standard().with_limit::<N>()`. That would stop a forged length
    /// inside a `Vec<u8>` variant from requesting an oversized allocation. Confirm against
    /// bincode's configuration docs before changing, because it alters what decodes
    /// successfully.
    ///
    /// # Errors
    ///
    /// Returns bincode's `DecodeError` if `data` is not a valid encoding.
    pub fn from_slice(data: &[u8]) -> Result<Self, bincode::error::DecodeError> {
        // `data` is already a slice; borrowing it again only produced `&&[u8]`.
        bincode::decode_from_slice(data, bincode::config::standard()).map(|(packet, _)| packet)
    }
}
/// A packet associated with a specific client, identified by `client_id`.
///
/// Binary layout produced by [`TaggedPacket::into_vec`]:
///
/// | bytes | content                              |
/// |-------|--------------------------------------|
/// | 0..8  | `client_id`, little-endian `u64`     |
/// | 8     | tag (see below)                      |
/// | 9..   | variant payload (may be empty)       |
///
/// | tag    | variant        | payload                 |
/// |--------|----------------|-------------------------|
/// | `0x00` | `Data`         | the raw `data` bytes    |
/// | `0x01` | `Failure`      | UTF-8 bytes of `error`  |
/// | `0x02` | `Kick`         | none                    |
/// | `0x03` | `Reconnection` | none                    |
#[derive(Clone, Debug)]
pub enum TaggedPacket {
    /// Arbitrary payload bytes for the client.
    Data { client_id: u64, data: Vec<u8> },
    /// Error text for the client, encoded as UTF-8 on the wire.
    Failure { client_id: u64, error: String },
    /// Header only; carries no payload.
    Kick { client_id: u64 },
    /// Header only; carries no payload.
    Reconnection { client_id: u64 },
}

impl TaggedPacket {
    /// Returns the client id carried by any variant.
    pub fn client_id(&self) -> u64 {
        match *self {
            TaggedPacket::Data { client_id, .. }
            | TaggedPacket::Failure { client_id, .. }
            | TaggedPacket::Kick { client_id }
            | TaggedPacket::Reconnection { client_id } => client_id,
        }
    }

    /// Encodes the packet into the layout documented on [`TaggedPacket`]:
    /// 8-byte little-endian client id, 1-byte tag, then the payload.
    pub fn into_vec(self) -> Vec<u8> {
        // Fixed header: 8-byte client id plus 1-byte tag.
        const HEADER_LEN: usize = 8 + 1;

        // Reduce every variant to (id, tag, payload) so the header is written in one place.
        let (client_id, tag, payload): (u64, u8, Vec<u8>) = match self {
            TaggedPacket::Data { client_id, data } => (client_id, 0x00, data),
            TaggedPacket::Failure { client_id, error } => (client_id, 0x01, error.into_bytes()),
            TaggedPacket::Kick { client_id } => (client_id, 0x02, Vec::new()),
            TaggedPacket::Reconnection { client_id } => (client_id, 0x03, Vec::new()),
        };

        let mut buf = Vec::with_capacity(HEADER_LEN + payload.len());
        buf.extend_from_slice(&client_id.to_le_bytes());
        buf.push(tag);
        buf.extend_from_slice(&payload);
        buf
    }
}
/// Applies latency-oriented socket options to `stream`.
///
/// First enables `TCP_NODELAY`, then sets `SO_LINGER` to 5 seconds. The first failing
/// call's error is returned, and the linger option is not attempted if `set_nodelay` fails.
///
/// NOTE(review): recent tokio releases deprecate `TcpStream::set_linger`. A non-zero
/// linger timeout can make closing the socket block the calling thread. Confirm against
/// the tokio version in use and consider dropping that call.
pub fn configure_performance_tcp_socket(stream: &mut TcpStream) -> std::io::Result<()> {
    use std::time::Duration;

    const LINGER_TIMEOUT: Duration = Duration::from_secs(5);

    stream
        .set_nodelay(true)
        .and_then(|()| stream.set_linger(Some(LINGER_TIMEOUT)))
}
/// Wraps `stream` in a [`FramedReader`] whose frames are prefixed by a 4-byte
/// little-endian `u32` length, the format written by `send_size_prefixed`.
///
/// All other codec settings are left at `LengthDelimitedCodec`'s defaults. Per the
/// tokio_util docs, that includes a maximum frame length of 8 MiB.
pub fn new_framed_reader<T: AsyncRead + Unpin>(stream: T) -> FramedReader<T> {
    let mut codec_builder = LengthDelimitedCodec::builder();
    codec_builder.length_field_type::<u32>().little_endian();
    codec_builder.new_read(stream)
}
/// Waits for the next complete frame from `read` and returns its payload bytes.
///
/// # Errors
///
/// Returns an error if the stream ends before another frame arrives, or if the
/// underlying codec/read reports an I/O error.
pub async fn recv_size_prefixed<T: AsyncRead + Unpin>(
    read: &mut FramedReader<T>,
) -> anyhow::Result<BytesMut> {
    match read.next().await {
        Some(frame) => Ok(frame?),
        None => Err(anyhow::format_err!("Connection closed or Eof")),
    }
}
/// Writes `message` to `stream`, preceded by its length as a 4-byte little-endian `u32`.
///
/// This is the framing read back by `new_framed_reader`. The prefix and payload are
/// assembled into one buffer and handed to a single `write_all`. The stream is not flushed.
///
/// # Errors
///
/// Returns an error if `message` is longer than `u32::MAX` bytes, in which case nothing
/// is written. Also returns an error if the write fails.
pub async fn send_size_prefixed<T: AsyncWriteExt + Unpin>(
    stream: &mut T,
    message: &[u8],
) -> anyhow::Result<()> {
    // A plain `as u32` cast would silently truncate the length and desynchronize the peer.
    let size = u32::try_from(message.len()).map_err(|_| {
        anyhow::format_err!("Message too large to frame: {} bytes", message.len())
    })?;
    let mut combined_message = Vec::with_capacity(4 + message.len());
    combined_message.extend_from_slice(&size.to_le_bytes());
    combined_message.extend_from_slice(message);
    stream.write_all(&combined_message).await?;
    Ok(())
}
/// Receives one length-prefixed frame from `read` and decodes it as a [`TaggedPacket`].
///
/// Expected frame layout (the inverse of `TaggedPacket::into_vec`): 8-byte little-endian
/// client id, 1-byte tag, then the variant payload. A `Failure` payload that is not
/// valid UTF-8 is decoded lossily rather than rejected. Any payload bytes after a `Kick`
/// or `Reconnection` tag are ignored.
///
/// # Errors
///
/// Returns an error in three cases:
/// - the stream ends or fails (see `recv_size_prefixed`);
/// - the frame is shorter than the 9-byte header;
/// - the tag is not one of `0x00..=0x03`.
pub async fn recv_tagged_packet<T: AsyncRead + Unpin>(
    read: &mut FramedReader<T>,
) -> anyhow::Result<TaggedPacket> {
    let buffer = recv_size_prefixed(read).await?;
    let bytes: &[u8] = &buffer;

    // Require the 8-byte id AND the tag byte. Otherwise reading the tag of a short,
    // peer-supplied frame would panic instead of returning an error.
    if bytes.len() < 9 {
        return Err(anyhow::format_err!("Packet too small"));
    }

    let (id_bytes, rest) = bytes.split_at(8);
    let client_id = u64::from_le_bytes(
        id_bytes
            .try_into()
            .expect("split_at(8) yields exactly 8 bytes"),
    );
    let (tag, payload) = (rest[0], &rest[1..]);

    match tag {
        0x00 => Ok(TaggedPacket::Data {
            client_id,
            data: payload.to_vec(),
        }),
        0x01 => Ok(TaggedPacket::Failure {
            client_id,
            error: String::from_utf8_lossy(payload).into_owned(),
        }),
        0x02 => Ok(TaggedPacket::Kick { client_id }),
        0x03 => Ok(TaggedPacket::Reconnection { client_id }),
        _ => Err(anyhow::format_err!("Unknown packet type")),
    }
}
/// Encodes `packet` with `TaggedPacket::into_vec` and writes it to `stream` as one frame.
///
/// The frame is a 4-byte little-endian `u32` length prefix followed by the encoded
/// packet. Both parts are written with a single `write_all`, and the stream is not flushed.
///
/// # Errors
///
/// Returns an error if the encoded packet is longer than `u32::MAX` bytes, in which case
/// nothing is written. Also returns an error if the write fails.
pub async fn send_tagged_packet<T: AsyncWriteExt + Unpin>(
    stream: &mut T,
    packet: TaggedPacket,
) -> anyhow::Result<()> {
    let data = packet.into_vec();
    // A plain `as u32` cast would silently truncate the length and desynchronize the peer.
    let size = u32::try_from(data.len()).map_err(|_| {
        anyhow::format_err!("Tagged packet too large to frame: {} bytes", data.len())
    })?;
    let mut combined_message = Vec::with_capacity(4 + data.len());
    combined_message.extend_from_slice(&size.to_le_bytes());
    combined_message.extend_from_slice(&data);
    stream.write_all(&combined_message).await?;
    Ok(())
}