use std::fmt::Display;
use std::net::SocketAddr;
use log::{debug, trace};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use crate::network::network_error::NetworkError;
use crate::protocol::packets::Packet;
use crate::protocol::packets::packet_definer::{PacketDirection, PacketState};
use crate::protocol::serialization::{McDeserializer, McSerialize, McSerializer, StateBasedDeserializer};
use crate::protocol::serialization::serializer_error::SerializingErr;
use crate::protocol_types::datatypes::var_types::VarInt;
use crate::protocol_types::protocol_verison::ProtocolVerison;
pub mod client_handlers;
/// Largest frame length accepted from a peer: 2^21 - 1, the biggest value a
/// 3-byte (3 x 7 payload bits) VarInt length prefix can carry.
const PACKET_MAX_SIZE: usize = (1 << 21) - 1;
/// High bit of a VarInt byte; when set, another byte of the same VarInt follows.
const CONTINUE_BIT: u8 = 1 << 7;
/// A connected client: the TCP stream plus the per-connection protocol state used
/// when framing and decoding its packets.
#[derive(Debug)]
// NOTE(review): presumably silences dead-code warnings for fields that are written but not
// read in this module (e.g. `socket_addr`, which none of the code shown here reads); confirm
// before removing.
#[allow(dead_code)]
pub struct CraftClient {
    /// The TCP connection to the client.
    pub(crate) tcp_stream: TcpStream,
    /// Peer address recorded when the connection was wrapped.
    pub(crate) socket_addr: SocketAddr,
    /// Current protocol state; incoming frames are decoded against this state.
    pub packet_state: PacketState,
    /// Compression threshold, if one has been set.
    pub compression_threshold: Option<i32>,
    /// Raw protocol version number for this client, if known; see `get_client_version`.
    pub client_version: Option<VarInt>,
}
impl CraftClient {
pub fn from_connection(tcp_stream: TcpStream) -> Result<Self, NetworkError> {
tcp_stream.set_nodelay(true)?;
Ok(Self {
socket_addr: tcp_stream.peer_addr()?,
tcp_stream,
packet_state: PacketState::HANDSHAKING,
compression_threshold: None,
client_version: None
})
}
pub async fn send_packet(&mut self, packet: Packet) -> Result<(), NetworkError> {
let mut serializer = McSerializer::new();
packet.mc_serialize(&mut serializer)?;
let output = &serializer.output;
trace!("Sending to {} : {:?}", self, output);
self.tcp_stream.write_all(output).await?;
Ok(())
}
pub async fn receive_packet(&mut self) -> Result<Packet, NetworkError> {
let mut vec = Vec::with_capacity(3);
loop {
let b = self.tcp_stream.read_u8().await?;
vec.push(b);
if b & CONTINUE_BIT == 0 {
break;
} else if vec.len() > 3 {
return Err(SerializingErr::VarTypeTooLong("Packet length VarInt max bytes is 3".to_string()).into());
}
}
let vari = VarInt::from_slice(&vec)?;
if vari.0 > PACKET_MAX_SIZE as i32 { return Err(NetworkError::PacketTooLarge);
}
let length = vari.0 as usize + vec.len();
let mut buffer = vec![0; length];
let mut i = 0;
for b in &vec {
buffer[i] = *b;
i += 1;
}
let length = self.tcp_stream.read(&mut buffer[vec.len()..]).await;
let length = match length {
Ok(length) => {length}
Err(e) => {
if e.to_string().contains("An established connection was aborted by the software in your host machine") {
debug!("OS Error detected in packet receive, closing the connection: {}", e);
self.close().await;
return Err(NetworkError::ConnectionAbortedLocally);
}
return Err(NetworkError::IOError(e));
}
};
trace!("Received from {} : {:?}", self, &buffer);
if length == 0 { self.close().await;
return Err(NetworkError::NoDataReceived);
} else if length == PACKET_MAX_SIZE {
return Err(NetworkError::PacketTooLarge);
}
let mut deserializer = McDeserializer::new(&buffer);
let packet = Packet::deserialize_state(&mut deserializer, self.packet_state, PacketDirection::SERVER)?;
Ok(packet)
}
pub fn try_receive_packet(&mut self) -> Result<Packet, NetworkError> {
let mut vec = vec![];
loop {
let var_buffer = &mut [0u8; 1];
let len = self.tcp_stream.try_read(var_buffer)?;
if len == 0 {
return Err(NetworkError::NoDataReceived);
}
let b = var_buffer[0];
if b & CONTINUE_BIT == 0 {
vec.push(b);
break;
} else {
vec.push(b);
if vec.len() > 3 {
return Err(SerializingErr::VarTypeTooLong("Packet length VarInt max bytes is 3".to_string()).into());
}
}
}
let vari = VarInt::from_slice(&vec)?;
let varbytes = vari.to_bytes();
if vari.0 > PACKET_MAX_SIZE as i32 { return Err(NetworkError::PacketTooLarge);
}
let length = vari.0 as usize + varbytes.len();
let mut buffer = vec![0; length];
let mut i = 0;
for b in &varbytes {
buffer[i] = *b;
i += 1;
}
let length = self.tcp_stream.try_read(&mut buffer[varbytes.len()..]);
if let Err(e) = length {
return Err(NetworkError::IOError(e));
}
let length = length.unwrap();
trace!("Received from {} : {:?}", self, &buffer);
if length == 0 { return Err(NetworkError::NoDataReceived);
} else if length == PACKET_MAX_SIZE {
return Err(NetworkError::PacketTooLarge);
}
let mut deserializer = McDeserializer::new(&buffer);
let packet = Packet::deserialize_state(&mut deserializer, self.packet_state, PacketDirection::SERVER)?;
Ok(packet)
}
pub async fn peek_packet(&mut self) -> Result<Packet, NetworkError> {
let mut i = 1usize;
let vari: VarInt;
loop {
let mut b = vec![0; i];
if self.tcp_stream.peek(&mut b).await? == 0 {
return Err(NetworkError::NoDataReceived);
}
if b[i - 1] & CONTINUE_BIT == 0 {
vari = VarInt::from_slice(&b)?;
break;
} else {
if i > 3 { return Err(SerializingErr::VarTypeTooLong("Packet length VarInt max bytes is 3".to_string()).into());
}
}
i += 1;
}
let varbytes = vari.to_bytes();
if vari.0 > PACKET_MAX_SIZE as i32 { return Err(NetworkError::PacketTooLarge);
}
let length = vari.0 as usize + varbytes.len();
let mut buffer = vec![0; length];
let mut i = 0;
for b in &varbytes {
buffer[i] = *b;
i += 1;
}
let length = self.tcp_stream.peek(&mut buffer[varbytes.len()..]).await;
if let Err(e) = length {
if e.to_string().contains("An established connection was aborted by the software in your host machine") {
debug!("OS Error detected in packet receive, closing the connection: {}", e);
self.close().await;
return Err(NetworkError::ConnectionAbortedLocally);
}
return Err(NetworkError::IOError(e));
}
let length = length.unwrap();
trace!("Peeked from {} : {:?}", self, &buffer);
if length == 0 { self.close().await;
return Err(NetworkError::NoDataReceived);
} else if length == PACKET_MAX_SIZE {
return Err(NetworkError::PacketTooLarge);
}
let mut deserializer = McDeserializer::new(&buffer);
let packet = Packet::deserialize_state(&mut deserializer, self.packet_state, PacketDirection::SERVER)?;
Ok(packet)
}
pub fn change_state(&mut self, state: PacketState) {
self.packet_state = state;
}
pub fn enable_compression(&mut self, threshold: Option<i32>) {
self.compression_threshold = threshold;
}
pub async fn close(&mut self) -> bool {
debug!("Closing connection to {}", self);
self.tcp_stream.shutdown().await.is_ok()
}
pub fn get_client_version(&self) -> Option<ProtocolVerison> {
Some(ProtocolVerison::from(self.client_version?.0 as i16)?)
}
}
impl Display for CraftClient {
    /// Formats the client as `CraftConnection: <peer address>`, or `CraftConnection: Unknown`
    /// when the socket can no longer report its peer address.
    ///
    /// The text is written straight into the formatter. The previous version built it with a
    /// nested `format!` call, which allocated an intermediate `String` for no benefit.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.tcp_stream.peer_addr() {
            Ok(addr) => write!(f, "CraftConnection: {}", addr),
            Err(_) => write!(f, "CraftConnection: Unknown"),
        }
    }
}