use std::time::Duration;
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::{
io::{self, BufReader, BufWriter},
net::TcpStream,
};
use crate::EncryptionState;
use super::packet::{Packet, PacketReader, PacketResult, PacketWriter};
use super::proxy_protocol::ProtocolResult;
/// Packet-level wrapper around the two owned halves of a split `TcpStream`.
///
/// The read half sits behind a `BufReader` inside a `PacketReader`; the write
/// half sits behind a `BufWriter` inside a `PacketWriter`.
pub struct Connection {
// Inbound side: packet reader over the buffered read half.
reader: PacketReader<BufReader<OwnedReadHalf>>,
// Outbound side: packet writer over the buffered write half.
writer: PacketWriter<BufWriter<OwnedWriteHalf>>,
// Set to 30 seconds by `new` and replaced by `set_timeout`.
// NOTE(review): nothing visible in this file reads this value — confirm
// whether it is meant to bound `read_packet` / `write_packet`.
timeout: Duration,
}
impl Connection {
pub async fn new(stream: TcpStream) -> io::Result<Self> {
stream.set_nodelay(true)?;
let (read_half, write_half) = stream.into_split();
Ok(Self {
reader: PacketReader::new(BufReader::new(read_half)),
writer: PacketWriter::new(BufWriter::new(write_half)),
timeout: Duration::from_secs(30),
})
}
pub async fn connect<A: tokio::net::ToSocketAddrs>(addr: A) -> io::Result<Self> {
let stream = TcpStream::connect(addr).await?;
Self::new(stream).await
}
pub fn set_timeout(&mut self, timeout: Duration) {
self.timeout = timeout;
}
pub fn enable_encryption(&mut self, encryption_state: &EncryptionState) {
if let Some((encrypt, decrypt)) = encryption_state.create_cipher() {
self.reader.enable_encryption(decrypt);
self.writer.enable_encryption(encrypt);
}
}
pub fn enable_compression(&mut self, threshold: i32) {
self.reader.enable_compression(threshold);
self.writer.enable_compression(threshold);
}
pub fn disable_compression(&mut self) {
self.reader.disable_compression();
self.writer.disable_compression();
}
pub async fn read_packet(&mut self) -> ProtocolResult<Packet> {
Ok(self.reader.read_packet().await?)
}
pub async fn write_packet(&mut self, packet: &Packet) -> ProtocolResult<()> {
self.writer.write_packet(packet).await
}
pub async fn close(&mut self) -> PacketResult<()> {
self.writer.close().await
}
pub async fn peer_addr(&self) -> PacketResult<std::net::SocketAddr> {
self.reader
.reader
.get_ref()
.peer_addr()
.map_err(|e| e.into())
}
pub fn into_split(
self,
) -> (
PacketReader<BufReader<tokio::net::tcp::OwnedReadHalf>>,
PacketWriter<BufWriter<tokio::net::tcp::OwnedWriteHalf>>,
) {
let reader = self.reader;
let writer = self.writer;
(reader, writer)
}
pub fn into_split_raw(
self,
) -> (
PacketReader<BufReader<OwnedReadHalf>>,
PacketWriter<BufWriter<OwnedWriteHalf>>,
) {
(self.reader, self.writer)
}
}
/// Thin wrapper around a [`Connection`]; every method on this type forwards
/// to the wrapped value.
// NOTE(review): presumably represents the link to a backend server — confirm
// against callers.
pub struct ServerConnection {
// The wrapped connection that all calls are delegated to.
connection: Connection,
}
impl ServerConnection {
pub async fn new(stream: TcpStream) -> io::Result<Self> {
Ok(Self {
connection: Connection::new(stream).await?,
})
}
pub async fn connect<A: tokio::net::ToSocketAddrs>(addr: A) -> io::Result<Self> {
Ok(Self {
connection: Connection::connect(addr).await?,
})
}
pub fn into_split(
self,
) -> (
PacketReader<BufReader<tokio::net::tcp::OwnedReadHalf>>,
PacketWriter<BufWriter<tokio::net::tcp::OwnedWriteHalf>>,
) {
self.connection.into_split()
}
pub fn into_split_raw(
self,
) -> (
PacketReader<BufReader<tokio::net::tcp::OwnedReadHalf>>,
PacketWriter<BufWriter<tokio::net::tcp::OwnedWriteHalf>>,
) {
self.connection.into_split_raw()
}
pub async fn close(&mut self) -> PacketResult<()> {
self.connection.close().await
}
pub async fn peer_addr(&self) -> PacketResult<std::net::SocketAddr> {
self.connection.peer_addr().await
}
pub fn set_timeout(&mut self, timeout: Duration) {
self.connection.set_timeout(timeout);
}
pub fn enable_encryption(&mut self, encryption_state: &EncryptionState) {
self.connection.enable_encryption(encryption_state);
}
pub fn enable_compression(&mut self, threshold: i32) {
self.connection.enable_compression(threshold);
}
pub fn disable_compression(&mut self) {
self.connection.disable_compression();
}
pub async fn read_packet(&mut self) -> ProtocolResult<Packet> {
self.connection.read_packet().await
}
pub async fn write_packet(&mut self, packet: &Packet) -> ProtocolResult<()> {
self.connection.write_packet(packet).await
}
}