pub use socket2::{TcpKeepalive};
use crate::network::adapter::{
Resource, Remote, Local, Adapter, SendStatus, AcceptedType, ReadStatus, ConnectionInfo,
ListeningInfo, PendingStatus,
};
use crate::network::{RemoteAddr, Readiness, TransportConnect, TransportListen};
use mio::net::{TcpListener, TcpStream};
use mio::event::{Source};
use socket2::{Socket, Domain, Type, Protocol};
use std::net::{SocketAddr};
#[cfg(unix)]
use std::ffi::{CString};
use std::io::{self, ErrorKind, Read, Write};
#[cfg(target_os = "macos")]
use std::num::NonZeroU32;
use std::ops::{Deref};
use std::mem::{forget, MaybeUninit};
use std::os::raw::c_int;
#[cfg(target_os = "windows")]
use std::os::windows::io::{FromRawSocket, AsRawSocket};
#[cfg(not(target_os = "windows"))]
use std::os::{fd::AsRawFd, unix::io::FromRawFd};
/// Size in bytes of the stack buffer that `RemoteResource::receive` reads into.
pub const INPUT_BUFFER_SIZE: usize = u16::MAX as usize;

/// Backlog value passed to `listen` when creating TCP listeners.
pub const LISTENER_BACKLOG: c_int = 1024;
/// Options applied when opening an outgoing TCP connection.
///
/// Every option is unset by `Default`; use the `with_*` builders to fill them in.
#[derive(Debug, Clone, Default)]
pub struct TcpConnectConfig {
    /// Name of the network interface to bind the socket to (only honoured on Unix targets).
    bind_device: Option<String>,
    /// Local address the socket is bound to before connecting.
    source_address: Option<SocketAddr>,
    /// Keepalive settings applied once the connection is established.
    keepalive: Option<TcpKeepalive>,
}
impl TcpConnectConfig {
    /// Returns the config with the socket bound to the interface named `device`.
    pub fn with_bind_device(self, device: String) -> Self {
        Self { bind_device: Some(device), ..self }
    }

    /// Returns the config with `keepalive` applied to the connection once it is ready.
    pub fn with_keepalive(self, keepalive: TcpKeepalive) -> Self {
        Self { keepalive: Some(keepalive), ..self }
    }

    /// Returns the config with the socket bound to `source_address` before connecting.
    pub fn with_source_address(self, source_address: SocketAddr) -> Self {
        Self { source_address: Some(source_address), ..self }
    }
}
/// Options applied when creating a TCP listener.
///
/// Every option is unset by `Default`; use the `with_*` builders to fill them in.
#[derive(Debug, Clone, Default)]
pub struct TcpListenConfig {
    /// Name of the network interface to bind the listener to (only honoured on Unix targets).
    bind_device: Option<String>,
    /// Keepalive settings handed to every connection accepted by the listener.
    keepalive: Option<TcpKeepalive>,
}
impl TcpListenConfig {
    /// Returns the config with the listener bound to the interface named `device`.
    pub fn with_bind_device(self, device: String) -> Self {
        Self { bind_device: Some(device), ..self }
    }

    /// Returns the config with `keepalive` propagated to accepted connections.
    pub fn with_keepalive(self, keepalive: TcpKeepalive) -> Self {
        Self { keepalive: Some(keepalive), ..self }
    }
}
/// TCP transport adapter: connections are `RemoteResource`, listeners are `LocalResource`.
pub(crate) struct TcpAdapter;

impl Adapter for TcpAdapter {
    type Local = LocalResource;
    type Remote = RemoteResource;
}
/// A single TCP connection, either dialed by `connect_with` or produced by `accept`.
pub(crate) struct RemoteResource {
    /// Underlying mio stream.
    stream: TcpStream,
    /// Keepalive to apply once `pending` reports the connection as ready.
    keepalive: Option<TcpKeepalive>,
}

impl Resource for RemoteResource {
    /// Exposes the stream as a mio `Source` (presumably registered with the poller by the caller).
    fn source(&mut self) -> &mut dyn Source {
        let source: &mut dyn Source = &mut self.stream;
        source
    }
}
impl Remote for RemoteResource {
fn connect_with(
config: TransportConnect,
remote_addr: RemoteAddr,
) -> io::Result<ConnectionInfo<Self>> {
let config = match config {
TransportConnect::Tcp(config) => config,
_ => panic!("Internal error: Got wrong config"),
};
let peer_addr = *remote_addr.socket_addr();
let socket = Socket::new(
match peer_addr {
SocketAddr::V4 { .. } => Domain::IPV4,
SocketAddr::V6 { .. } => Domain::IPV6,
},
Type::STREAM,
Some(Protocol::TCP),
)?;
socket.set_nonblocking(true)?;
if let Some(source_address) = config.source_address {
socket.bind(&source_address.into())?;
}
#[cfg(unix)]
if let Some(bind_device) = config.bind_device {
let device = CString::new(bind_device)?;
#[cfg(not(target_os = "macos"))]
socket.bind_device(Some(device.as_bytes()))?;
#[cfg(target_os = "macos")]
match NonZeroU32::new(unsafe { libc::if_nametoindex(device.as_ptr()) }) {
Some(index) => socket.bind_device_by_index(Some(index))?,
None => {
return Err(io::Error::new(
ErrorKind::NotFound,
"Bind device interface not found",
))
}
}
}
match socket.connect(&peer_addr.into()) {
#[cfg(unix)]
Err(e) if e.raw_os_error() != Some(libc::EINPROGRESS) => return Err(e),
#[cfg(windows)]
Err(e) if e.kind() != io::ErrorKind::WouldBlock => return Err(e),
_ => {}
}
let stream = TcpStream::from_std(socket.into());
let local_addr = stream.local_addr()?;
Ok(ConnectionInfo {
remote: Self { stream, keepalive: config.keepalive },
local_addr,
peer_addr,
})
}
fn receive(&self, mut process_data: impl FnMut(&[u8])) -> ReadStatus {
let buffer: MaybeUninit<[u8; INPUT_BUFFER_SIZE]> = MaybeUninit::uninit();
let mut input_buffer = unsafe { buffer.assume_init() }; loop {
let stream = &self.stream;
match stream.deref().read(&mut input_buffer) {
Ok(0) => break ReadStatus::Disconnected,
Ok(size) => process_data(&input_buffer[..size]),
Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
Err(ref err) if err.kind() == ErrorKind::WouldBlock => {
break ReadStatus::WaitNextEvent
}
Err(ref err) if err.kind() == ErrorKind::ConnectionReset => {
break ReadStatus::Disconnected
}
Err(err) => {
log::error!("TCP receive error: {}", err);
break ReadStatus::Disconnected }
}
}
}
fn send(&self, data: &[u8]) -> SendStatus {
let mut total_bytes_sent = 0;
loop {
let stream = &self.stream;
match stream.deref().write(&data[total_bytes_sent..]) {
Ok(bytes_sent) => {
total_bytes_sent += bytes_sent;
if total_bytes_sent == data.len() {
break SendStatus::Sent
}
}
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue,
Err(err) => {
log::error!("TCP receive error: {}", err);
break SendStatus::ResourceNotFound }
}
}
}
fn pending(&self, _readiness: Readiness) -> PendingStatus {
let status = check_stream_ready(&self.stream);
if status == PendingStatus::Ready {
if let Some(keepalive) = &self.keepalive {
#[cfg(target_os = "windows")]
let socket = unsafe { Socket::from_raw_socket(self.stream.as_raw_socket()) };
#[cfg(not(target_os = "windows"))]
let socket = unsafe { Socket::from_raw_fd(self.stream.as_raw_fd()) };
if let Err(e) = socket.set_tcp_keepalive(keepalive) {
log::warn!("TCP set keepalive error: {}", e);
}
forget(socket);
}
}
status
}
}
/// Classifies the connection state of a non-blocking `stream`.
///
/// Returns `Disconnected` if the socket holds a pending error (as reported by
/// `take_error`). Otherwise, a successful `peer_addr` lookup means the connection is
/// `Ready`. `NotConnected` and `InvalidInput` failures mean it is still `Incomplete`.
/// Any other failure is `Disconnected`.
pub fn check_stream_ready(stream: &TcpStream) -> PendingStatus {
    if matches!(stream.take_error(), Ok(Some(_))) {
        return PendingStatus::Disconnected;
    }
    match stream.peer_addr() {
        Ok(_) => PendingStatus::Ready,
        // NOTE(review): InvalidInput is presumably how some platforms report a connect
        // still in progress — confirm per target.
        Err(err) => match err.kind() {
            io::ErrorKind::NotConnected | io::ErrorKind::InvalidInput => {
                PendingStatus::Incomplete
            }
            _ => PendingStatus::Disconnected,
        },
    }
}
/// A listening TCP socket that produces `RemoteResource`s on `accept`.
pub(crate) struct LocalResource {
    /// Underlying mio listener.
    listener: TcpListener,
    /// Keepalive handed to every accepted connection.
    keepalive: Option<TcpKeepalive>,
}

impl Resource for LocalResource {
    /// Exposes the listener as a mio `Source` (presumably registered with the poller by the caller).
    fn source(&mut self) -> &mut dyn Source {
        let source: &mut dyn Source = &mut self.listener;
        source
    }
}
impl Local for LocalResource {
type Remote = RemoteResource;
fn listen_with(config: TransportListen, addr: SocketAddr) -> io::Result<ListeningInfo<Self>> {
let config = match config {
TransportListen::Tcp(config) => config,
_ => panic!("Internal error: Got wrong config"),
};
let socket = Socket::new(
match addr {
SocketAddr::V4 { .. } => Domain::IPV4,
SocketAddr::V6 { .. } => Domain::IPV6,
},
Type::STREAM,
Some(Protocol::TCP),
)?;
socket.set_nonblocking(true)?;
socket.set_reuse_address(true)?;
#[cfg(unix)]
if let Some(bind_device) = config.bind_device {
let device = CString::new(bind_device)?;
#[cfg(not(target_os = "macos"))]
socket.bind_device(Some(device.as_bytes()))?;
#[cfg(target_os = "macos")]
match NonZeroU32::new(unsafe { libc::if_nametoindex(device.as_ptr()) }) {
Some(index) => socket.bind_device_by_index(Some(index))?,
None => {
return Err(io::Error::new(
ErrorKind::NotFound,
"Bind device interface not found",
))
}
}
}
socket.bind(&addr.into())?;
socket.listen(LISTENER_BACKLOG)?;
let listener = TcpListener::from_std(socket.into());
let local_addr = listener.local_addr().unwrap();
Ok(ListeningInfo {
local: { LocalResource { listener, keepalive: config.keepalive } },
local_addr,
})
}
fn accept(&self, mut accept_remote: impl FnMut(AcceptedType<'_, Self::Remote>)) {
loop {
match self.listener.accept() {
Ok((stream, addr)) => accept_remote(AcceptedType::Remote(
addr,
RemoteResource { stream, keepalive: self.keepalive.clone() },
)),
Err(ref err) if err.kind() == ErrorKind::WouldBlock => break,
Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
Err(err) => break log::error!("TCP accept error: {}", err), }
}
}
}