use std::time::Duration;
use bytes::{Bytes, BytesMut};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::timeout;
use crate::config::{Config, DEFAULT_SDU};
use crate::constants::PACKET_HEADER_SIZE;
use crate::error::{Error, Result};
use crate::packet::{Packet, PacketHeader};
use super::Transport;
pub struct TcpTransport {
stream: Option<TcpStream>,
read_buf: BytesMut,
sdu: u32,
large_sdu: bool,
connect_timeout: Duration,
}
impl TcpTransport {
pub fn new() -> Self {
Self {
stream: None,
read_buf: BytesMut::with_capacity(DEFAULT_SDU as usize),
sdu: DEFAULT_SDU,
large_sdu: false,
connect_timeout: Duration::from_secs(10),
}
}
pub fn with_sdu(sdu: u32) -> Self {
Self {
stream: None,
read_buf: BytesMut::with_capacity(sdu as usize),
sdu,
large_sdu: false,
connect_timeout: Duration::from_secs(10),
}
}
pub fn connect_timeout(mut self, timeout: Duration) -> Self {
self.connect_timeout = timeout;
self
}
pub async fn connect(&mut self, addr: &str) -> Result<()> {
let stream = timeout(self.connect_timeout, TcpStream::connect(addr))
.await
.map_err(|_| Error::ConnectionTimeout(self.connect_timeout))?
.map_err(Error::Io)?;
stream.set_nodelay(true).map_err(Error::Io)?;
self.stream = Some(stream);
Ok(())
}
pub async fn connect_with_config(&mut self, config: &Config) -> Result<()> {
self.sdu = config.sdu;
self.connect_timeout = config.connect_timeout;
self.read_buf = BytesMut::with_capacity(self.sdu as usize);
self.connect(&config.socket_addr()).await
}
fn stream_mut(&mut self) -> Result<&mut TcpStream> {
self.stream
.as_mut()
.ok_or(Error::ConnectionClosed)
}
async fn read_exact(&mut self, n: usize) -> Result<Bytes> {
let stream = self.stream_mut()?;
let mut buf = vec![0u8; n];
stream.read_exact(&mut buf).await.map_err(|e| {
if e.kind() == std::io::ErrorKind::UnexpectedEof {
Error::ConnectionClosed
} else {
Error::Io(e)
}
})?;
Ok(Bytes::from(buf))
}
async fn read_packet_header(&mut self) -> Result<(PacketHeader, usize)> {
let header_bytes = self.read_exact(PACKET_HEADER_SIZE).await?;
let header = if self.large_sdu {
PacketHeader::parse_large_sdu(&header_bytes)?
} else {
PacketHeader::parse(&header_bytes)?
};
let total_length = header.length as usize;
Ok((header, total_length))
}
}
impl Default for TcpTransport {
fn default() -> Self {
Self::new()
}
}
#[async_trait::async_trait]
impl Transport for TcpTransport {
async fn send(&mut self, data: &[u8]) -> Result<()> {
let stream = self.stream_mut()?;
stream.write_all(data).await.map_err(Error::Io)?;
stream.flush().await.map_err(Error::Io)?;
Ok(())
}
async fn send_packet(&mut self, packet: Bytes) -> Result<()> {
self.send(&packet).await
}
async fn receive_packet(&mut self) -> Result<Packet> {
let (header, total_length) = self.read_packet_header().await?;
let payload_length = total_length.saturating_sub(PACKET_HEADER_SIZE);
let payload = if payload_length > 0 {
self.read_exact(payload_length).await?
} else {
Bytes::new()
};
Ok(Packet::new(header, payload))
}
fn is_connected(&self) -> bool {
self.stream.is_some()
}
async fn close(&mut self) -> Result<()> {
if let Some(mut stream) = self.stream.take() {
stream.shutdown().await.map_err(Error::Io)?;
}
Ok(())
}
fn sdu(&self) -> u32 {
self.sdu
}
fn set_sdu(&mut self, sdu: u32) {
self.sdu = sdu;
if self.read_buf.capacity() < sdu as usize {
self.read_buf = BytesMut::with_capacity(sdu as usize);
}
}
fn uses_large_sdu(&self) -> bool {
self.large_sdu
}
fn set_large_sdu(&mut self, large_sdu: bool) {
self.large_sdu = large_sdu;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tcp_transport_new() {
let transport = TcpTransport::new();
assert!(!transport.is_connected());
assert_eq!(transport.sdu(), DEFAULT_SDU);
assert!(!transport.uses_large_sdu());
}
#[test]
fn test_tcp_transport_with_sdu() {
let transport = TcpTransport::with_sdu(16384);
assert_eq!(transport.sdu(), 16384);
}
#[test]
fn test_tcp_transport_set_sdu() {
let mut transport = TcpTransport::new();
transport.set_sdu(32768);
assert_eq!(transport.sdu(), 32768);
}
#[test]
fn test_tcp_transport_set_large_sdu() {
let mut transport = TcpTransport::new();
assert!(!transport.uses_large_sdu());
transport.set_large_sdu(true);
assert!(transport.uses_large_sdu());
}
}