use crate::{
wire_format::WireFormat,
Error,
Result,
};
use bytes::Bytes;
use std::time::Duration;
use tokio::{
io::{
AsyncRead,
AsyncReadExt,
AsyncWrite,
AsyncWriteExt,
BufReader,
BufWriter,
},
net::TcpStream,
};
#[cfg(feature = "tls")]
use rustls::ServerName;
#[cfg(feature = "tls")]
use std::sync::Arc;
#[cfg(feature = "tls")]
use tokio_rustls::TlsConnector;
/// Capacity, in bytes, of the `BufReader` placed over each connection's read half.
const DEFAULT_READ_BUFFER_SIZE: usize = 8 * 1024;
/// Capacity, in bytes, of the `BufWriter` placed over each connection's write half.
const DEFAULT_WRITE_BUFFER_SIZE: usize = 8 * 1024;
/// Socket-level settings used when opening a `Connection`.
///
/// Start from `ConnectionOptions::new()` / `Default::default()` and adjust
/// individual fields with the chained setter methods.
#[derive(Clone, Debug)]
pub struct ConnectionOptions {
/// Maximum time to wait for the TCP connect to complete; `Duration::ZERO`
/// disables the timeout and waits indefinitely.
pub connect_timeout: Duration,
/// Receive timeout. NOTE(review): not read by `Connection` in this file —
/// confirm whether anything applies it before relying on it.
pub recv_timeout: Duration,
/// Send timeout. NOTE(review): not read by `Connection` in this file —
/// confirm whether anything applies it before relying on it.
pub send_timeout: Duration,
/// Whether to enable TCP keepalive on the socket after connecting.
pub tcp_keepalive: bool,
/// Idle time before keepalive probes start (passed to socket2 `with_time`).
pub tcp_keepalive_idle: Duration,
/// Interval between keepalive probes (passed to socket2 `with_interval` on the
/// platforms where the connect code supports it).
pub tcp_keepalive_interval: Duration,
/// Presumably the keepalive probe count (TCP_KEEPCNT). NOTE(review): only has
/// an effect where the connect code forwards it to the OS — verify per platform.
pub tcp_keepalive_count: u32,
/// Whether to set `TCP_NODELAY` (disable Nagle's algorithm) after connecting.
pub tcp_nodelay: bool,
}
impl Default for ConnectionOptions {
fn default() -> Self {
Self {
connect_timeout: Duration::from_secs(5),
recv_timeout: Duration::ZERO,
send_timeout: Duration::ZERO,
tcp_keepalive: false,
tcp_keepalive_idle: Duration::from_secs(60),
tcp_keepalive_interval: Duration::from_secs(5),
tcp_keepalive_count: 3,
tcp_nodelay: true,
}
}
}
impl ConnectionOptions {
    /// Creates options populated with the `Default` values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the TCP connect timeout; `Duration::ZERO` disables it.
    #[must_use]
    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = timeout;
        self
    }

    /// Sets the `recv_timeout` field.
    #[must_use]
    pub fn recv_timeout(mut self, timeout: Duration) -> Self {
        self.recv_timeout = timeout;
        self
    }

    /// Sets the `send_timeout` field.
    #[must_use]
    pub fn send_timeout(mut self, timeout: Duration) -> Self {
        self.send_timeout = timeout;
        self
    }

    /// Enables or disables TCP keepalive.
    #[must_use]
    pub fn tcp_keepalive(mut self, enabled: bool) -> Self {
        self.tcp_keepalive = enabled;
        self
    }

    /// Sets the idle time before keepalive probing begins.
    #[must_use]
    pub fn tcp_keepalive_idle(mut self, duration: Duration) -> Self {
        self.tcp_keepalive_idle = duration;
        self
    }

    /// Sets the interval between keepalive probes.
    #[must_use]
    pub fn tcp_keepalive_interval(mut self, duration: Duration) -> Self {
        self.tcp_keepalive_interval = duration;
        self
    }

    /// Sets the keepalive probe count.
    #[must_use]
    pub fn tcp_keepalive_count(mut self, count: u32) -> Self {
        self.tcp_keepalive_count = count;
        self
    }

    /// Enables or disables `TCP_NODELAY`.
    #[must_use]
    pub fn tcp_nodelay(mut self, enabled: bool) -> Self {
        self.tcp_nodelay = enabled;
        self
    }
}
/// A buffered, bidirectional client connection over plain TCP or, with the
/// `tls` feature, a rustls client stream.
///
/// The underlying stream is split into separately buffered read and write
/// halves whose concrete types are erased behind trait objects, so plain and
/// TLS connections share this one type.
pub struct Connection {
/// Buffered read half of the underlying stream.
reader: BufReader<Box<dyn AsyncRead + Unpin + Send>>,
/// Buffered write half; written data sits in the buffer until it fills or
/// the writer is flushed.
writer: BufWriter<Box<dyn AsyncWrite + Unpin + Send>>,
}
impl Connection {
pub fn new(stream: TcpStream) -> Self {
let (read_half, write_half) = tokio::io::split(stream);
Self {
reader: BufReader::with_capacity(
DEFAULT_READ_BUFFER_SIZE,
Box::new(read_half) as Box<dyn AsyncRead + Unpin + Send>,
),
writer: BufWriter::with_capacity(
DEFAULT_WRITE_BUFFER_SIZE,
Box::new(write_half) as Box<dyn AsyncWrite + Unpin + Send>,
),
}
}
#[cfg(feature = "tls")]
pub fn new_tls(
stream: tokio_rustls::client::TlsStream<TcpStream>,
) -> Self {
let (read_half, write_half) = tokio::io::split(stream);
Self {
reader: BufReader::with_capacity(
DEFAULT_READ_BUFFER_SIZE,
Box::new(read_half) as Box<dyn AsyncRead + Unpin + Send>,
),
writer: BufWriter::with_capacity(
DEFAULT_WRITE_BUFFER_SIZE,
Box::new(write_half) as Box<dyn AsyncWrite + Unpin + Send>,
),
}
}
pub async fn connect(host: &str, port: u16) -> Result<Self> {
Self::connect_with_options(host, port, &ConnectionOptions::default())
.await
}
pub async fn connect_with_options(
host: &str,
port: u16,
options: &ConnectionOptions,
) -> Result<Self> {
let addr = format!("{}:{}", host, port);
let stream = if options.connect_timeout > Duration::ZERO {
tokio::time::timeout(
options.connect_timeout,
TcpStream::connect(&addr),
)
.await
.map_err(|_| {
Error::Connection(format!(
"Connection timeout after {:?} to {}",
options.connect_timeout, addr
))
})?
.map_err(|e| {
Error::Connection(format!(
"Failed to connect to {}: {}",
addr, e
))
})?
} else {
TcpStream::connect(&addr).await.map_err(|e| {
Error::Connection(format!(
"Failed to connect to {}: {}",
addr, e
))
})?
};
if options.tcp_nodelay {
stream.set_nodelay(true).map_err(|e| {
Error::Connection(format!("Failed to set TCP_NODELAY: {}", e))
})?;
}
#[cfg(unix)]
if options.tcp_keepalive {
use socket2::{
Socket,
TcpKeepalive,
};
use std::os::unix::io::{
AsRawFd,
FromRawFd,
};
let socket = unsafe { Socket::from_raw_fd(stream.as_raw_fd()) };
let mut keepalive =
TcpKeepalive::new().with_time(options.tcp_keepalive_idle);
#[cfg(any(target_os = "linux", target_os = "macos"))]
{
keepalive =
keepalive.with_interval(options.tcp_keepalive_interval);
}
socket.set_tcp_keepalive(&keepalive).map_err(|e| {
Error::Connection(format!(
"Failed to set TCP keepalive: {}",
e
))
})?;
std::mem::forget(socket);
}
#[cfg(windows)]
if options.tcp_keepalive {
use socket2::{
Socket,
TcpKeepalive,
};
use std::os::windows::io::{
AsRawSocket,
FromRawSocket,
};
let socket =
unsafe { Socket::from_raw_socket(stream.as_raw_socket()) };
let keepalive = TcpKeepalive::new()
.with_time(options.tcp_keepalive_idle)
.with_interval(options.tcp_keepalive_interval);
socket.set_tcp_keepalive(&keepalive).map_err(|e| {
Error::Connection(format!(
"Failed to set TCP keepalive: {}",
e
))
})?;
std::mem::forget(socket);
}
Ok(Self::new(stream))
}
#[cfg(feature = "tls")]
pub async fn connect_with_tls(
host: &str,
port: u16,
options: &ConnectionOptions,
ssl_config: Arc<rustls::ClientConfig>,
server_name: Option<&str>,
) -> Result<Self> {
let addr = format!("{}:{}", host, port);
let stream = if options.connect_timeout > Duration::ZERO {
tokio::time::timeout(
options.connect_timeout,
TcpStream::connect(&addr),
)
.await
.map_err(|_| {
Error::Connection(format!(
"Connection timeout after {:?} to {}",
options.connect_timeout, addr
))
})?
.map_err(|e| {
Error::Connection(format!(
"Failed to connect to {}: {}",
addr, e
))
})?
} else {
TcpStream::connect(&addr).await.map_err(|e| {
Error::Connection(format!(
"Failed to connect to {}: {}",
addr, e
))
})?
};
if options.tcp_nodelay {
stream.set_nodelay(true).map_err(|e| {
Error::Connection(format!("Failed to set TCP_NODELAY: {}", e))
})?;
}
#[cfg(unix)]
if options.tcp_keepalive {
use socket2::{
Socket,
TcpKeepalive,
};
use std::os::unix::io::{
AsRawFd,
FromRawFd,
};
let socket = unsafe { Socket::from_raw_fd(stream.as_raw_fd()) };
let mut keepalive =
TcpKeepalive::new().with_time(options.tcp_keepalive_idle);
#[cfg(any(target_os = "linux", target_os = "macos"))]
{
keepalive =
keepalive.with_interval(options.tcp_keepalive_interval);
}
socket.set_tcp_keepalive(&keepalive).map_err(|e| {
Error::Connection(format!(
"Failed to set TCP keepalive: {}",
e
))
})?;
std::mem::forget(socket);
}
#[cfg(windows)]
if options.tcp_keepalive {
use socket2::{
Socket,
TcpKeepalive,
};
use std::os::windows::io::{
AsRawSocket,
FromRawSocket,
};
let socket =
unsafe { Socket::from_raw_socket(stream.as_raw_socket()) };
let keepalive = TcpKeepalive::new()
.with_time(options.tcp_keepalive_idle)
.with_interval(options.tcp_keepalive_interval);
socket.set_tcp_keepalive(&keepalive).map_err(|e| {
Error::Connection(format!(
"Failed to set TCP keepalive: {}",
e
))
})?;
std::mem::forget(socket);
}
let connector = TlsConnector::from(ssl_config);
let server_name_to_use = server_name.unwrap_or(host);
let domain =
ServerName::try_from(server_name_to_use).map_err(|e| {
Error::Connection(format!(
"Invalid server name '{}': {}",
server_name_to_use, e
))
})?;
let tls_stream =
connector.connect(domain, stream).await.map_err(|e| {
Error::Connection(format!("TLS handshake failed: {}", e))
})?;
Ok(Self::new_tls(tls_stream))
}
pub async fn read_varint(&mut self) -> Result<u64> {
WireFormat::read_varint64(&mut self.reader).await
}
pub async fn write_varint(&mut self, value: u64) -> Result<()> {
WireFormat::write_varint64(&mut self.writer, value).await
}
pub async fn read_u8(&mut self) -> Result<u8> {
Ok(self.reader.read_u8().await?)
}
pub async fn read_u16(&mut self) -> Result<u16> {
Ok(self.reader.read_u16_le().await?)
}
pub async fn read_u32(&mut self) -> Result<u32> {
Ok(self.reader.read_u32_le().await?)
}
pub async fn read_u64(&mut self) -> Result<u64> {
Ok(self.reader.read_u64_le().await?)
}
pub async fn read_i8(&mut self) -> Result<i8> {
Ok(self.reader.read_i8().await?)
}
pub async fn read_i16(&mut self) -> Result<i16> {
Ok(self.reader.read_i16_le().await?)
}
pub async fn read_i32(&mut self) -> Result<i32> {
Ok(self.reader.read_i32_le().await?)
}
pub async fn read_i64(&mut self) -> Result<i64> {
Ok(self.reader.read_i64_le().await?)
}
pub async fn write_u8(&mut self, value: u8) -> Result<()> {
Ok(self.writer.write_u8(value).await?)
}
pub async fn write_u16(&mut self, value: u16) -> Result<()> {
Ok(self.writer.write_u16_le(value).await?)
}
pub async fn write_u32(&mut self, value: u32) -> Result<()> {
Ok(self.writer.write_u32_le(value).await?)
}
pub async fn write_u64(&mut self, value: u64) -> Result<()> {
Ok(self.writer.write_u64_le(value).await?)
}
pub async fn write_u128(&mut self, value: u128) -> Result<()> {
Ok(self.writer.write_u128_le(value).await?)
}
pub async fn write_i8(&mut self, value: i8) -> Result<()> {
Ok(self.writer.write_i8(value).await?)
}
pub async fn write_i16(&mut self, value: i16) -> Result<()> {
Ok(self.writer.write_i16_le(value).await?)
}
pub async fn write_i32(&mut self, value: i32) -> Result<()> {
Ok(self.writer.write_i32_le(value).await?)
}
pub async fn write_i64(&mut self, value: i64) -> Result<()> {
Ok(self.writer.write_i64_le(value).await?)
}
pub async fn read_string(&mut self) -> Result<String> {
WireFormat::read_string(&mut self.reader).await
}
pub async fn write_string(&mut self, s: &str) -> Result<()> {
WireFormat::write_string(&mut self.writer, s).await
}
pub async fn write_quoted_string(&mut self, s: &str) -> Result<()> {
WireFormat::write_quoted_string(&mut self.writer, s).await
}
pub async fn read_bytes(&mut self, len: usize) -> Result<Bytes> {
let mut buf = vec![0u8; len];
self.reader.read_exact(&mut buf).await?;
Ok(Bytes::from(buf))
}
pub async fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
self.reader.read_exact(buf).await?;
Ok(())
}
pub async fn write_bytes(&mut self, data: &[u8]) -> Result<()> {
Ok(self.writer.write_all(data).await?)
}
pub async fn flush(&mut self) -> Result<()> {
Ok(self.writer.flush().await?)
}
pub async fn read_packet(&mut self) -> Result<Bytes> {
let len = self.read_varint().await? as usize;
if len == 0 {
return Ok(Bytes::new());
}
if len > 0x40000000 {
return Err(Error::Protocol(format!("Packet too large: {}", len)));
}
self.read_bytes(len).await
}
pub async fn write_packet(&mut self, data: &[u8]) -> Result<()> {
self.write_varint(data.len() as u64).await?;
self.write_bytes(data).await?;
Ok(())
}
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    /// Pins both default buffer capacities so an accidental change is caught.
    #[test]
    fn test_buffer_sizes() {
        const EXPECTED: usize = 8 * 1024;
        for actual in [DEFAULT_READ_BUFFER_SIZE, DEFAULT_WRITE_BUFFER_SIZE] {
            assert_eq!(actual, EXPECTED);
        }
    }
}