pub mod backend;
pub mod client;
pub mod forwarding;
use std::io;
use std::time::Duration;
use aes::Aes128;
use bytes::{Buf, BytesMut};
use cfb8::cipher::generic_array::GenericArray;
use cfb8::cipher::{BlockDecryptMut, BlockEncryptMut, KeyIvInit};
use deepslate_protocol::codec;
use deepslate_protocol::packet::Packet;
use deepslate_protocol::packet::login::SetCompressionPacket;
use deepslate_protocol::types::ProtocolError;
use deepslate_protocol::varint;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufWriter};
use tokio::net::TcpStream;
use crate::auth::AuthError;
/// Largest number of received-but-unframed bytes a reader may hold: 2.5 MiB.
///
/// Both `MinecraftConnection::read_frame` and `MinecraftReader::read_raw_frame`
/// fail with `ProtocolError::ReadBufferOverflow` once their read buffer grows past
/// this after a read.
const MAX_READ_BUF_SIZE: usize = 5 * 512 * 1024;
/// Errors produced while driving a Minecraft protocol connection.
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
    /// Socket or stream I/O failure; displayed exactly as the underlying error.
    #[error(transparent)]
    Io(#[from] io::Error),
    /// Framing, compression or buffer-limit violation reported via `ProtocolError`.
    #[error("protocol error: {0}")]
    Protocol(#[from] ProtocolError),
    /// Authentication failure (see `crate::auth::AuthError`).
    #[error("authentication failed: {0}")]
    Auth(#[from] AuthError),
    /// Failure involving the backend server connection.
    #[error("backend connection failed: {reason}")]
    BackendFailed {
        /// Free-form, human-readable cause included in the error message.
        reason: String,
    },
    /// A read did not complete within the configured read timeout. In this module it
    /// is returned by `read_frame_timeout` and `read_raw_frame_timeout`.
    #[error("connection timed out")]
    Timeout,
}
impl ConnectionError {
    /// Returns `true` for errors that typically just mean the peer went away.
    ///
    /// Only `Io` errors whose kind is `ConnectionReset`, `BrokenPipe`,
    /// `UnexpectedEof` or `ConnectionAborted` qualify. Every other variant, including
    /// `Timeout`, returns `false`.
    #[must_use]
    pub fn is_expected(&self) -> bool {
        const DISCONNECT_KINDS: [io::ErrorKind; 4] = [
            io::ErrorKind::ConnectionReset,
            io::ErrorKind::BrokenPipe,
            io::ErrorKind::UnexpectedEof,
            io::ErrorKind::ConnectionAborted,
        ];
        // Exhaustive on purpose: a new variant should force a decision here.
        match self {
            Self::Io(err) => DISCONNECT_KINDS.contains(&err.kind()),
            Self::Protocol(_) | Self::Auth(_) | Self::BackendFailed { .. } | Self::Timeout => false,
        }
    }
}
/// AES-128 in CFB8 mode, encrypting direction.
type Aes128Cfb8Enc = cfb8::Encryptor<Aes128>;
/// AES-128 in CFB8 mode, decrypting direction.
type Aes128Cfb8Dec = cfb8::Decryptor<Aes128>;
/// Independent CFB8 cipher states for the two directions of one connection.
///
/// Each direction needs its own state because CFB8 is stateful: every processed
/// byte feeds into the processing of the next one.
struct CipherPair {
    // Applied to outbound bytes just before they are written.
    encryptor: Aes128Cfb8Enc,
    // Applied to inbound bytes as they are read.
    decryptor: Aes128Cfb8Dec,
}
impl CipherPair {
    /// Creates both cipher states, using `shared_secret` as the key and also as the IV.
    ///
    /// Panics if `shared_secret` is not exactly 16 bytes, because the slice-to-array
    /// conversion asserts the AES-128 key/IV length.
    fn new(shared_secret: &[u8]) -> Self {
        let key_and_iv = GenericArray::from_slice(shared_secret);
        Self {
            encryptor: Aes128Cfb8Enc::new(key_and_iv, key_and_iv),
            decryptor: Aes128Cfb8Dec::new(key_and_iv, key_and_iv),
        }
    }

    /// Encrypts `data` in place, one byte (one CFB8 block) at a time.
    fn encrypt(&mut self, data: &mut [u8]) {
        let mut scratch = GenericArray::default();
        data.iter_mut().for_each(|plain| {
            scratch[0] = *plain;
            self.encryptor.encrypt_block_mut(&mut scratch);
            *plain = scratch[0];
        });
    }

    /// Decrypts `data` in place, one byte (one CFB8 block) at a time.
    fn decrypt(&mut self, data: &mut [u8]) {
        let mut scratch = GenericArray::default();
        data.iter_mut().for_each(|sealed| {
            scratch[0] = *sealed;
            self.decryptor.decrypt_block_mut(&mut scratch);
            *sealed = scratch[0];
        });
    }
}
/// A framed Minecraft protocol connection with optional AES/CFB8 encryption and
/// zlib compression.
///
/// Generic over the transport; defaults to [`TcpStream`], which is also the
/// transport required by `into_split`.
pub struct MinecraftConnection<S = TcpStream>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // Transport. Writes are buffered until flushed; reads go to the inner stream.
    stream: BufWriter<S>,
    // Received bytes not yet consumed as frames (decrypted on arrival once
    // encryption is enabled).
    read_buf: BytesMut,
    // Cipher state for both directions; `None` until `enable_encryption` is called.
    cipher: Option<CipherPair>,
    // Compression threshold; negative disables compressed framing (starts at -1).
    compression_threshold: i32,
    // Reused zlib inflater for inbound frames.
    decompressor: libdeflater::Decompressor,
    // zlib deflater for outbound frames.
    compressor: libdeflater::Compressor,
    // Scratch output buffer for inflated inbound payloads.
    decompress_buf: Vec<u8>,
    // Scratch output buffer for deflated outbound payloads.
    compress_buf: Vec<u8>,
    // Packet ID + body of the packet being written, before framing.
    encode_buf: Vec<u8>,
    // Fully framed (and, if enabled, encrypted) bytes for the next write.
    write_buf: Vec<u8>,
    // Deadline applied by `read_frame_timeout`.
    read_timeout: Duration,
}
impl<S: AsyncRead + AsyncWrite + Unpin> MinecraftConnection<S> {
    /// Creates a connection over `stream` with encryption and compression disabled.
    ///
    /// `compression_level` is used for outbound packets once a non-negative
    /// compression threshold is set. `read_timeout` is the deadline applied by
    /// [`Self::read_frame_timeout`].
    #[must_use]
    pub fn new(
        stream: S,
        compression_level: libdeflater::CompressionLvl,
        read_timeout: Duration,
    ) -> Self {
        Self {
            stream: BufWriter::new(stream),
            read_buf: BytesMut::with_capacity(32768),
            cipher: None,
            compression_threshold: -1,
            decompressor: libdeflater::Decompressor::new(),
            compressor: libdeflater::Compressor::new(compression_level),
            decompress_buf: Vec::new(),
            compress_buf: Vec::new(),
            encode_buf: Vec::new(),
            write_buf: Vec::new(),
            read_timeout,
        }
    }

    /// Turns on AES-128/CFB8 encryption in both directions, using `shared_secret` as
    /// both key and IV.
    ///
    /// Only bytes read from or written to the stream after this call go through the
    /// cipher. Bytes already held in the internal read or write buffers are left
    /// untouched.
    ///
    /// # Panics
    ///
    /// Panics if `shared_secret` is not exactly 16 bytes. `CipherPair::new` converts it
    /// into a fixed-size AES-128 key/IV array, and that conversion asserts the length.
    pub fn enable_encryption(&mut self, shared_secret: &[u8]) {
        self.cipher = Some(CipherPair::new(shared_secret));
    }

    /// Sets the compression threshold locally without telling the peer.
    ///
    /// A negative `threshold` disables compressed framing. Use
    /// [`Self::set_compression`] to also send the `SetCompression` packet.
    pub const fn enable_compression(&mut self, threshold: i32) {
        self.compression_threshold = threshold;
    }

    /// Reads the next frame, pulling more bytes from the stream as needed.
    ///
    /// The returned buffer holds the packet ID followed by the packet body, already
    /// decrypted and decompressed as configured. Returns `Ok(None)` if the stream
    /// reaches EOF before another complete frame is buffered.
    ///
    /// Dropping the future mid-read (e.g. in [`Self::read_frame_timeout`]) does not
    /// lose received data: each chunk is appended to the internal buffer before the
    /// next await point.
    ///
    /// # Errors
    ///
    /// Fails on I/O errors, on the decoding errors of [`Self::try_read_frame`], and
    /// with `ProtocolError::ReadBufferOverflow` once more than `MAX_READ_BUF_SIZE`
    /// bytes are buffered after a read without yielding a frame.
    #[allow(clippy::large_stack_arrays)]
    pub async fn read_frame(&mut self) -> Result<Option<BytesMut>, ConnectionError> {
        loop {
            if let Some(frame) = self.try_read_frame()? {
                return Ok(Some(frame));
            }
            if let Some(cipher) = &mut self.cipher {
                // Decrypt in a scratch array so only freshly received bytes advance the
                // stateful cipher before they are appended to `read_buf`.
                let mut tmp = [0u8; 32768];
                let n = self.stream.read(&mut tmp).await?;
                if n == 0 {
                    return Ok(None);
                }
                cipher.decrypt(&mut tmp[..n]);
                self.read_buf.extend_from_slice(&tmp[..n]);
            } else {
                let n = self.stream.read_buf(&mut self.read_buf).await?;
                if n == 0 {
                    return Ok(None);
                }
            }
            // Bound how much unframed input a peer can make us hold in memory.
            if self.read_buf.len() > MAX_READ_BUF_SIZE {
                return Err(ConnectionError::Protocol(
                    ProtocolError::ReadBufferOverflow {
                        size: self.read_buf.len(),
                        max: MAX_READ_BUF_SIZE,
                    },
                ));
            }
        }
    }

    /// Same as [`Self::read_frame`], but fails with [`ConnectionError::Timeout`] if no
    /// frame arrives within the connection's read timeout.
    ///
    /// # Errors
    ///
    /// Returns `ConnectionError::Timeout` on expiry, otherwise any error of
    /// [`Self::read_frame`].
    #[expect(
        clippy::large_futures,
        reason = "MinecraftConnection carries large framing buffers through read timeouts"
    )]
    pub async fn read_frame_timeout(&mut self) -> Result<Option<BytesMut>, ConnectionError> {
        tokio::time::timeout(self.read_timeout, self.read_frame())
            .await
            .map_err(|_| ConnectionError::Timeout)?
    }

    /// Decodes one frame from already-buffered bytes without touching the stream.
    ///
    /// Returns `Ok(None)` if the buffer does not yet hold a complete frame. Otherwise
    /// returns the packet ID + body, decompressed when compression is enabled.
    ///
    /// # Errors
    ///
    /// Fails with a protocol error on malformed framing, on corrupt zlib data, or
    /// when a compressed payload inflates to a size other than its declared
    /// `Data Length`.
    pub fn try_read_frame(&mut self) -> Result<Option<BytesMut>, ConnectionError> {
        let Some((varint_size, frame_len)) = codec::try_read_frame(&self.read_buf)? else {
            return Ok(None);
        };
        self.read_buf.advance(varint_size);
        let mut frame = self.read_buf.split_to(frame_len);
        if self.compression_threshold < 0 {
            return Ok(Some(frame));
        }
        let (uncompressed_size, payload) = codec::read_compressed_frame(&frame)?;
        if uncompressed_size == 0 {
            // Sent uncompressed: strip the `Data Length` varint. Its width is derived from
            // the payload rather than assuming the canonical 1-byte encoding of 0.
            let header_len = frame.len() - payload.len();
            frame.advance(header_len);
            return Ok(Some(frame));
        }
        // NOTE(review): `uncompressed_size` is peer-controlled. Confirm that
        // `codec::read_compressed_frame` bounds it; otherwise this resize lets a peer
        // force a very large allocation.
        self.decompress_buf.resize(uncompressed_size, 0);
        let inflated = self
            .decompressor
            .zlib_decompress(payload, &mut self.decompress_buf)
            .map_err(ProtocolError::from)?;
        // A short inflate would leave zero fill or stale bytes from an earlier packet
        // at the tail of `decompress_buf`. Treat the mismatch as bad compressed data.
        if inflated != uncompressed_size {
            return Err(ConnectionError::Protocol(ProtocolError::from(
                libdeflater::DecompressionError::BadData,
            )));
        }
        Ok(Some(BytesMut::from(&self.decompress_buf[..])))
    }

    /// Flushes bytes buffered by earlier writes to the underlying stream.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from the flush.
    pub async fn flush(&mut self) -> Result<(), ConnectionError> {
        self.stream.flush().await?;
        Ok(())
    }

    /// Encodes `packet` (ID + body), frames it, writes it and flushes the stream.
    ///
    /// # Errors
    ///
    /// Propagates compression and I/O errors.
    #[allow(clippy::future_not_send)]
    pub async fn write_packet<P: Packet>(&mut self, packet: &P) -> Result<(), ConnectionError> {
        codec::encode_packet_data(&mut self.encode_buf, P::PACKET_ID, |buf| packet.encode(buf));
        self.send_encode_buf().await
    }

    /// Encodes a packet with `packet_id` whose body is produced by `encode_fn`, then
    /// frames, writes and flushes it.
    ///
    /// # Errors
    ///
    /// Propagates compression and I/O errors.
    #[allow(clippy::future_not_send)]
    pub async fn encode_and_write_packet(
        &mut self,
        packet_id: i32,
        encode_fn: impl FnOnce(&mut Vec<u8>),
    ) -> Result<(), ConnectionError> {
        codec::encode_packet_data(&mut self.encode_buf, packet_id, encode_fn);
        self.send_encode_buf().await
    }

    /// Writes the packet currently held in `encode_buf` and flushes the stream.
    async fn send_encode_buf(&mut self) -> Result<(), ConnectionError> {
        // Move the buffer out so `write_raw_packet` can borrow `self` mutably, then put
        // it back (even on error) to reuse its allocation for the next packet.
        let encode_buf = std::mem::take(&mut self.encode_buf);
        let result = self.write_raw_packet(&encode_buf).await;
        self.encode_buf = encode_buf;
        result?;
        self.stream.flush().await?;
        Ok(())
    }

    /// Frames `packet_data` (packet ID + body) and queues it on the buffered stream.
    ///
    /// With compression enabled, payloads of at least `threshold` bytes are deflated
    /// and prefixed with their uncompressed length. Smaller payloads get a
    /// `Data Length` of 0. The framed bytes are encrypted if encryption is enabled.
    /// This does **not** flush; call [`Self::flush`] afterwards.
    ///
    /// # Errors
    ///
    /// Propagates compression and I/O errors.
    // The threshold is checked non-negative before the `as usize` cast. Payload lengths
    // are assumed to fit in `i32`, as the wire format requires.
    #[allow(
        clippy::cast_sign_loss,
        clippy::cast_possible_truncation,
        clippy::cast_possible_wrap
    )]
    pub async fn write_raw_packet(&mut self, packet_data: &[u8]) -> Result<(), ConnectionError> {
        self.write_buf.clear();
        if self.compression_threshold >= 0 {
            if packet_data.len() >= self.compression_threshold as usize {
                let max_size = self.compressor.zlib_compress_bound(packet_data.len());
                self.compress_buf.resize(max_size, 0);
                let actual_size = self
                    .compressor
                    .zlib_compress(packet_data, &mut self.compress_buf)
                    .map_err(ProtocolError::from)?;
                codec::write_compressed_frame(
                    &mut self.write_buf,
                    packet_data.len() as i32,
                    &self.compress_buf[..actual_size],
                );
            } else {
                codec::write_compressed_frame(&mut self.write_buf, 0, packet_data);
            }
        } else {
            codec::write_frame(&mut self.write_buf, packet_data);
        }
        if let Some(cipher) = &mut self.cipher {
            cipher.encrypt(&mut self.write_buf);
        }
        self.stream.write_all(&self.write_buf).await?;
        Ok(())
    }

    /// Sends a `SetCompression` packet, then switches this side to `threshold`.
    ///
    /// The packet itself is framed under the previous threshold, which is why the
    /// local threshold changes only after it has been written.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::write_packet`]. On error the threshold is left
    /// unchanged.
    pub async fn set_compression(&mut self, threshold: i32) -> Result<(), ConnectionError> {
        self.write_packet(&SetCompressionPacket { threshold })
            .await?;
        self.compression_threshold = threshold;
        Ok(())
    }

    /// Flushes pending output and shuts down the write side. It then discards incoming
    /// bytes until EOF, a read error, or 5 seconds have passed.
    ///
    /// # Errors
    ///
    /// Propagates errors from the flush and shutdown. Errors and timeouts during the
    /// drain phase are ignored.
    pub async fn shutdown(&mut self) -> Result<(), ConnectionError> {
        self.stream.flush().await?;
        self.stream.shutdown().await?;
        let drain = async {
            let mut buf = [0u8; 1024];
            loop {
                match self.stream.read(&mut buf).await {
                    Ok(0) | Err(_) => break,
                    Ok(_) => {}
                }
            }
        };
        let _ = tokio::time::timeout(Duration::from_secs(5), drain).await;
        Ok(())
    }
}
impl MinecraftConnection<TcpStream> {
pub fn into_split(self) -> (MinecraftReader, MinecraftWriter) {
let inner = self.stream.into_inner();
let (read_half, write_half) = inner.into_split();
let (enc_cipher, dec_cipher) = if let Some(cipher) = self.cipher {
(
Some(EncryptCipher(cipher.encryptor)),
Some(DecryptCipher(cipher.decryptor)),
)
} else {
(None, None)
};
(
MinecraftReader {
stream: read_half,
read_buf: self.read_buf,
cipher: dec_cipher,
read_timeout: self.read_timeout,
},
MinecraftWriter {
stream: write_half,
cipher: enc_cipher,
},
)
}
}
/// Outbound cipher state taken from a `CipherPair` when a connection is split.
struct EncryptCipher(Aes128Cfb8Enc);
impl EncryptCipher {
    /// Encrypts `data` in place, one byte (one CFB8 block) at a time.
    fn encrypt(&mut self, data: &mut [u8]) {
        let mut scratch = GenericArray::default();
        data.iter_mut().for_each(|plain| {
            scratch[0] = *plain;
            self.0.encrypt_block_mut(&mut scratch);
            *plain = scratch[0];
        });
    }
}
/// Inbound cipher state taken from a `CipherPair` when a connection is split.
struct DecryptCipher(Aes128Cfb8Dec);
impl DecryptCipher {
    /// Decrypts `data` in place, one byte (one CFB8 block) at a time.
    fn decrypt(&mut self, data: &mut [u8]) {
        let mut scratch = GenericArray::default();
        data.iter_mut().for_each(|sealed| {
            scratch[0] = *sealed;
            self.0.decrypt_block_mut(&mut scratch);
            *sealed = scratch[0];
        });
    }
}
/// Read half produced by `MinecraftConnection::into_split`; yields raw frames.
pub struct MinecraftReader {
    // Owned read half of the TCP socket.
    stream: tokio::net::tcp::OwnedReadHalf,
    // Bytes not yet returned as frames (carried over from the connection and
    // decrypted on arrival when a cipher is present).
    read_buf: BytesMut,
    // Inbound cipher state, present if encryption was enabled before the split.
    cipher: Option<DecryptCipher>,
    // Deadline applied by `read_raw_frame_timeout`.
    read_timeout: Duration,
}
impl MinecraftReader {
    /// Reads the next frame and returns it re-framed: a `VarInt` length prefix
    /// followed by the (decrypted) frame body.
    ///
    /// The body is returned as received, so compressed frames are not inflated.
    /// Returns `Ok(None)` if the stream reaches EOF before another frame completes.
    ///
    /// # Errors
    ///
    /// Propagates I/O and framing errors. Fails with
    /// `ProtocolError::ReadBufferOverflow` when more than `MAX_READ_BUF_SIZE` bytes
    /// are buffered after a read without completing a frame.
    #[allow(
        clippy::cast_possible_truncation,
        clippy::cast_possible_wrap,
        clippy::large_stack_arrays
    )]
    pub async fn read_raw_frame(&mut self) -> Result<Option<Vec<u8>>, ConnectionError> {
        loop {
            if let Some((prefix_len, body_len)) = codec::try_read_frame(&self.read_buf)? {
                self.read_buf.advance(prefix_len);
                let body = self.read_buf.split_to(body_len);
                // The length prefix is re-encoded rather than copied from the input.
                let mut framed = Vec::with_capacity(prefix_len + body_len);
                varint::write_var_int(&mut framed, body_len as i32);
                framed.extend_from_slice(&body);
                return Ok(Some(framed));
            }

            let received = match self.cipher.as_mut() {
                Some(cipher) => {
                    // Scratch space keeps decryption limited to the freshly read bytes.
                    let mut chunk = [0u8; 32768];
                    let count = self.stream.read(&mut chunk).await?;
                    cipher.decrypt(&mut chunk[..count]);
                    self.read_buf.extend_from_slice(&chunk[..count]);
                    count
                }
                None => self.stream.read_buf(&mut self.read_buf).await?,
            };
            if received == 0 {
                return Ok(None);
            }

            let buffered = self.read_buf.len();
            if buffered > MAX_READ_BUF_SIZE {
                return Err(ConnectionError::Protocol(
                    ProtocolError::ReadBufferOverflow {
                        size: buffered,
                        max: MAX_READ_BUF_SIZE,
                    },
                ));
            }
        }
    }

    /// Same as [`Self::read_raw_frame`], but fails with `ConnectionError::Timeout`
    /// if no frame arrives within the reader's timeout.
    ///
    /// # Errors
    ///
    /// Returns `ConnectionError::Timeout` on expiry, otherwise any error of
    /// [`Self::read_raw_frame`].
    #[expect(
        clippy::large_futures,
        reason = "MinecraftReader carries large framing buffers through read timeouts"
    )]
    pub async fn read_raw_frame_timeout(&mut self) -> Result<Option<Vec<u8>>, ConnectionError> {
        let limit = self.read_timeout;
        match tokio::time::timeout(limit, self.read_raw_frame()).await {
            Ok(outcome) => outcome,
            Err(_elapsed) => Err(ConnectionError::Timeout),
        }
    }
}
/// Write half produced by `MinecraftConnection::into_split`; writes raw frames.
pub struct MinecraftWriter {
    // Owned write half of the TCP socket (not wrapped in a `BufWriter`).
    // NOTE(review): crate-visible, presumably so other modules can operate on the
    // socket directly; confirm against callers.
    pub(crate) stream: tokio::net::tcp::OwnedWriteHalf,
    // Outbound cipher state, present if encryption was enabled before the split.
    cipher: Option<EncryptCipher>,
}
impl MinecraftWriter {
pub async fn write_raw_frame(&mut self, mut data: Vec<u8>) -> Result<(), ConnectionError> {
if let Some(cipher) = &mut self.cipher {
cipher.encrypt(&mut data);
}
self.stream.write_all(&data).await?;
Ok(())
}
}
// Round-trip tests for `MinecraftConnection` over in-memory duplex streams, plus a
// read-buffer overflow test driven by a synthetic flooding stream.
#[cfg(test)]
#[expect(
    clippy::large_futures,
    reason = "MinecraftConnection carries large framing buffers; acceptable in tests"
)]
mod tests {
    use std::time::Duration;
    use deepslate_protocol::codec;
    use deepslate_protocol::packet::Packet;
    use deepslate_protocol::packet::login::SetCompressionPacket;
    use super::MinecraftConnection;
    /// Builds two connections joined by an in-memory duplex pipe of `buf_size` bytes,
    /// both with compression level 1 and a 5-second read timeout.
    fn duplex_pair(
        buf_size: usize,
    ) -> (
        MinecraftConnection<tokio::io::DuplexStream>,
        MinecraftConnection<tokio::io::DuplexStream>,
    ) {
        let (a, b) = tokio::io::duplex(buf_size);
        let compression = libdeflater::CompressionLvl::new(1).expect("valid level");
        let timeout = Duration::from_secs(5);
        (
            MinecraftConnection::new(a, compression, timeout),
            MinecraftConnection::new(b, compression, timeout),
        )
    }
    /// Parses a frame returned by `read_frame` as packet `P`, asserting its packet ID.
    fn decode_frame<P: Packet>(frame: &[u8]) -> P {
        let mut cursor = frame;
        let packet_id = codec::read_packet_id(&mut cursor).expect("valid packet id");
        assert_eq!(packet_id, P::PACKET_ID, "unexpected packet id");
        P::decode(&mut cursor).expect("valid packet")
    }
    // Plain framing: one packet round-trips, then EOF is reported as `None`.
    #[tokio::test]
    async fn write_then_read_frame_round_trips() {
        let (mut writer, mut reader) = duplex_pair(8192);
        let packet = SetCompressionPacket { threshold: 256 };
        writer.write_packet(&packet).await.unwrap();
        // Dropping the writer closes its end so the reader eventually sees EOF.
        drop(writer);
        let frame = reader
            .read_frame()
            .await
            .unwrap()
            .expect("expected a frame");
        let decoded: SetCompressionPacket = decode_frame(&frame);
        assert_eq!(decoded, packet);
        assert!(reader.read_frame().await.unwrap().is_none());
    }
    // Threshold 0 on both sides forces every packet through the zlib path.
    #[tokio::test]
    async fn round_trip_with_compression() {
        let (mut writer, mut reader) = duplex_pair(8192);
        writer.enable_compression(0);
        reader.enable_compression(0);
        let packet = SetCompressionPacket { threshold: 512 };
        writer.write_packet(&packet).await.unwrap();
        drop(writer);
        let frame = reader
            .read_frame()
            .await
            .unwrap()
            .expect("expected a frame");
        let decoded: SetCompressionPacket = decode_frame(&frame);
        assert_eq!(decoded, packet);
    }
    // Both sides share the same secret, so CFB8 streams stay in sync.
    #[tokio::test]
    async fn round_trip_with_encryption() {
        let (mut writer, mut reader) = duplex_pair(8192);
        let shared_secret = b"0123456789abcdef"; // 16 bytes: AES-128 key and IV
        writer.enable_encryption(shared_secret);
        reader.enable_encryption(shared_secret);
        let packet = SetCompressionPacket { threshold: 128 };
        writer.write_packet(&packet).await.unwrap();
        drop(writer);
        let frame = reader
            .read_frame()
            .await
            .unwrap()
            .expect("expected a frame");
        let decoded: SetCompressionPacket = decode_frame(&frame);
        assert_eq!(decoded, packet);
    }
    // Compression and encryption layered together (compress, then encrypt).
    #[tokio::test]
    async fn round_trip_with_compression_and_encryption() {
        let (mut writer, mut reader) = duplex_pair(8192);
        let shared_secret = b"fedcba9876543210";
        writer.enable_encryption(shared_secret);
        reader.enable_encryption(shared_secret);
        writer.enable_compression(0);
        reader.enable_compression(0);
        let packet = SetCompressionPacket { threshold: 1024 };
        writer.write_packet(&packet).await.unwrap();
        drop(writer);
        let frame = reader
            .read_frame()
            .await
            .unwrap()
            .expect("expected a frame");
        let decoded: SetCompressionPacket = decode_frame(&frame);
        assert_eq!(decoded, packet);
    }
    #[tokio::test]
    async fn read_frame_returns_none_on_eof() {
        let (writer, mut reader) = duplex_pair(8192);
        drop(writer);
        let result = reader.read_frame().await.unwrap();
        assert!(result.is_none());
    }
    #[tokio::test]
    async fn read_frame_timeout_returns_error_when_idle() {
        let timeout = Duration::from_millis(50);
        let (a, b) = tokio::io::duplex(8192);
        let compression = libdeflater::CompressionLvl::new(1).expect("valid level");
        let mut reader = MinecraftConnection::new(a, compression, timeout);
        // Keep the peer end alive: dropping it would yield EOF (`Ok(None)`) instead of
        // a timeout.
        let _writer = b;
        let result = reader.read_frame_timeout().await;
        assert!(
            matches!(result, Err(super::ConnectionError::Timeout)),
            "expected Timeout, got {result:?}"
        );
    }
    // Several frames written back-to-back are read out in order.
    #[tokio::test]
    async fn multiple_packets_round_trip() {
        let (mut writer, mut reader) = duplex_pair(8192);
        let packets = [
            SetCompressionPacket { threshold: 0 },
            SetCompressionPacket { threshold: 256 },
            SetCompressionPacket { threshold: -1 },
        ];
        for packet in &packets {
            writer.write_packet(packet).await.unwrap();
        }
        drop(writer);
        for expected in &packets {
            let frame = reader
                .read_frame()
                .await
                .unwrap()
                .expect("expected a frame");
            let decoded: SetCompressionPacket = decode_frame(&frame);
            assert_eq!(&decoded, expected);
        }
        assert!(reader.read_frame().await.unwrap().is_none());
    }
    #[tokio::test]
    async fn read_frame_rejects_oversized_read_buffer() {
        use std::io;
        use std::pin::Pin;
        use std::task::{Context, Poll};
        use deepslate_protocol::types::ProtocolError;
        use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
        // In-memory stream that serves `data` as fast as the reader accepts it and
        // silently accepts all writes.
        struct FloodStream {
            data: Vec<u8>,
            pos: usize,
        }
        impl AsyncRead for FloodStream {
            fn poll_read(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &mut ReadBuf<'_>,
            ) -> Poll<io::Result<()>> {
                let remaining = &self.data[self.pos..];
                if remaining.is_empty() {
                    // No bytes filled signals EOF.
                    return Poll::Ready(Ok(()));
                }
                let n = remaining.len().min(buf.remaining());
                buf.put_slice(&remaining[..n]);
                self.pos += n;
                Poll::Ready(Ok(()))
            }
        }
        impl AsyncWrite for FloodStream {
            fn poll_write(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<io::Result<usize>> {
                Poll::Ready(Ok(buf.len()))
            }
            fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                Poll::Ready(Ok(()))
            }
            fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                Poll::Ready(Ok(()))
            }
        }
        // Declare a 2 MiB frame, then pad the stream with filler to 64 KiB past the cap.
        let frame_len: i32 = 2 * 1024 * 1024;
        let mut data = Vec::new();
        deepslate_protocol::varint::write_var_int(&mut data, frame_len);
        // NOTE(review): the declared frame is shorter than the flood. This test relies on
        // a single read pushing `read_buf` past `MAX_READ_BUF_SIZE` before the frame is
        // decoded, which seems to depend on how `BytesMut` grows between reads. Confirm
        // it is not fragile.
        data.resize(super::MAX_READ_BUF_SIZE + 64 * 1024, 0xAB);
        let stream = FloodStream { data, pos: 0 };
        let compression = libdeflater::CompressionLvl::new(1).expect("valid level");
        let timeout = Duration::from_secs(5);
        let mut conn = MinecraftConnection::new(stream, compression, timeout);
        let result = conn.read_frame().await;
        assert!(
            matches!(
                result,
                Err(super::ConnectionError::Protocol(
                    ProtocolError::ReadBufferOverflow { .. }
                ))
            ),
            "expected ReadBufferOverflow, got {result:?}"
        );
    }
}