use socket2::Type as SockType;
use std::io;
use std::net::SocketAddr;
use std::time::Duration;
use crate::SocketFamily;
/// Kind of socket a TCP configuration describes.
///
/// `Hash` is derived alongside `Eq` so the type can be used as a set/map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TcpSocketType {
    /// Connection-oriented stream socket (mapped to `SOCK_STREAM`).
    Stream,
    /// Raw socket (mapped to `SOCK_RAW`).
    Raw,
}
impl TcpSocketType {
    /// Reports whether this is the [`TcpSocketType::Stream`] variant.
    pub fn is_stream(&self) -> bool {
        *self == Self::Stream
    }

    /// Reports whether this is the [`TcpSocketType::Raw`] variant.
    pub fn is_raw(&self) -> bool {
        *self == Self::Raw
    }

    /// Translates the variant into the matching `socket2` socket type:
    /// `Stream` becomes `STREAM`, `Raw` becomes `RAW`.
    pub(crate) fn to_sock_type(&self) -> SockType {
        match *self {
            Self::Raw => SockType::RAW,
            Self::Stream => SockType::STREAM,
        }
    }
}
/// Declarative description of a TCP socket to be created.
///
/// Start from [`TcpConfig::new`] or one of the family/type-specific constructors,
/// then chain the `with_*` builder methods. The constructors leave every `Option`
/// field as `None`. How an unset field is applied is decided by the socket-building
/// code, which is not in this module (presumably it leaves the OS default untouched).
///
/// The socket-option names in the field docs below are what the field names
/// indicate. The code that actually applies them lives elsewhere.
///
/// `PartialEq` is derived so configurations can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct TcpConfig {
    /// Address family of the socket.
    pub socket_family: SocketFamily,
    /// Stream or raw socket.
    pub socket_type: TcpSocketType,
    /// Local address to bind to. [`TcpConfig::validate`] requires its family to
    /// match `socket_family`.
    pub bind_addr: Option<SocketAddr>,
    /// Whether the socket should be put into non-blocking mode.
    pub nonblocking: bool,
    /// Address-reuse setting (`SO_REUSEADDR`).
    pub reuseaddr: Option<bool>,
    /// Port-reuse setting (`SO_REUSEPORT`).
    pub reuseport: Option<bool>,
    /// Nagle-algorithm setting (`TCP_NODELAY`).
    pub nodelay: Option<bool>,
    /// Linger duration (`SO_LINGER`).
    pub linger: Option<Duration>,
    /// Time-to-live. Rejected by `validate` for IPv6 configurations.
    pub ttl: Option<u32>,
    /// Hop limit. Rejected by `validate` for IPv4 configurations.
    pub hoplimit: Option<u32>,
    /// Read timeout. `validate` rejects a zero duration.
    pub read_timeout: Option<Duration>,
    /// Write timeout. `validate` rejects a zero duration.
    pub write_timeout: Option<Duration>,
    /// Receive buffer size (`SO_RCVBUF`). `validate` rejects zero.
    pub recv_buffer_size: Option<usize>,
    /// Send buffer size (`SO_SNDBUF`). `validate` rejects zero.
    pub send_buffer_size: Option<usize>,
    /// Type-of-service value (`IP_TOS`).
    pub tos: Option<u32>,
    /// IPv6 traffic class. Rejected by `validate` for IPv4 configurations.
    pub tclass_v6: Option<u32>,
    /// IPv6-only setting (`IPV6_V6ONLY`). Rejected by `validate` for IPv4
    /// configurations.
    pub only_v6: Option<bool>,
    /// Name of the network interface to bind to. `validate` rejects an empty name.
    pub bind_device: Option<String>,
    /// Keep-alive setting (`SO_KEEPALIVE`).
    pub keepalive: Option<bool>,
}
impl TcpConfig {
    /// Creates a stream-socket configuration for `socket_family`.
    ///
    /// Equivalent to [`TcpConfig::v4_stream`] for IPv4 and
    /// [`TcpConfig::v6_stream`] for IPv6.
    #[must_use]
    pub fn new(socket_family: SocketFamily) -> Self {
        match socket_family {
            SocketFamily::IPV4 => Self::v4_stream(),
            SocketFamily::IPV6 => Self::v6_stream(),
        }
    }

    /// IPv4 stream socket, blocking, with no bind address and every optional
    /// setting unset. The other constructors build on top of this one.
    #[must_use]
    pub fn v4_stream() -> Self {
        Self {
            socket_family: SocketFamily::IPV4,
            socket_type: TcpSocketType::Stream,
            bind_addr: None,
            nonblocking: false,
            reuseaddr: None,
            reuseport: None,
            nodelay: None,
            linger: None,
            ttl: None,
            hoplimit: None,
            read_timeout: None,
            write_timeout: None,
            recv_buffer_size: None,
            send_buffer_size: None,
            tos: None,
            tclass_v6: None,
            only_v6: None,
            bind_device: None,
            keepalive: None,
        }
    }

    /// IPv4 raw socket with the same defaults as [`TcpConfig::v4_stream`].
    #[must_use]
    pub fn raw_v4() -> Self {
        Self {
            socket_family: SocketFamily::IPV4,
            socket_type: TcpSocketType::Raw,
            ..Self::v4_stream()
        }
    }

    /// IPv6 stream socket with the same defaults as [`TcpConfig::v4_stream`].
    #[must_use]
    pub fn v6_stream() -> Self {
        Self {
            socket_family: SocketFamily::IPV6,
            socket_type: TcpSocketType::Stream,
            ..Self::v4_stream()
        }
    }

    /// IPv6 raw socket with the same defaults as [`TcpConfig::v4_stream`].
    #[must_use]
    pub fn raw_v6() -> Self {
        Self {
            socket_family: SocketFamily::IPV6,
            socket_type: TcpSocketType::Raw,
            ..Self::v4_stream()
        }
    }

    /// Sets the local address to bind to.
    #[must_use]
    pub fn with_bind(mut self, addr: SocketAddr) -> Self {
        self.bind_addr = Some(addr);
        self
    }

    /// Alias for [`TcpConfig::with_bind`].
    #[must_use]
    pub fn with_bind_addr(self, addr: SocketAddr) -> Self {
        self.with_bind(addr)
    }

    /// Sets whether the socket should be non-blocking.
    #[must_use]
    pub fn with_nonblocking(mut self, flag: bool) -> Self {
        self.nonblocking = flag;
        self
    }

    /// Sets the address-reuse flag.
    #[must_use]
    pub fn with_reuseaddr(mut self, flag: bool) -> Self {
        self.reuseaddr = Some(flag);
        self
    }

    /// Sets the port-reuse flag.
    #[must_use]
    pub fn with_reuseport(mut self, flag: bool) -> Self {
        self.reuseport = Some(flag);
        self
    }

    /// Sets the no-delay flag.
    #[must_use]
    pub fn with_nodelay(mut self, flag: bool) -> Self {
        self.nodelay = Some(flag);
        self
    }

    /// Sets the linger duration.
    #[must_use]
    pub fn with_linger(mut self, dur: Duration) -> Self {
        self.linger = Some(dur);
        self
    }

    /// Sets the time-to-live. [`TcpConfig::validate`] rejects this on IPv6.
    #[must_use]
    pub fn with_ttl(mut self, ttl: u32) -> Self {
        self.ttl = Some(ttl);
        self
    }

    /// Sets the hop limit. [`TcpConfig::validate`] rejects this on IPv4.
    #[must_use]
    pub fn with_hoplimit(mut self, hops: u32) -> Self {
        self.hoplimit = Some(hops);
        self
    }

    /// Alias for [`TcpConfig::with_hoplimit`].
    #[must_use]
    pub fn with_hop_limit(self, hops: u32) -> Self {
        self.with_hoplimit(hops)
    }

    /// Sets the keep-alive flag.
    #[must_use]
    pub fn with_keepalive(mut self, on: bool) -> Self {
        self.keepalive = Some(on);
        self
    }

    /// Sets the read timeout. [`TcpConfig::validate`] rejects a zero duration.
    #[must_use]
    pub fn with_read_timeout(mut self, timeout: Duration) -> Self {
        self.read_timeout = Some(timeout);
        self
    }

    /// Sets the write timeout. [`TcpConfig::validate`] rejects a zero duration.
    #[must_use]
    pub fn with_write_timeout(mut self, timeout: Duration) -> Self {
        self.write_timeout = Some(timeout);
        self
    }

    /// Sets the receive buffer size. [`TcpConfig::validate`] rejects zero.
    #[must_use]
    pub fn with_recv_buffer_size(mut self, size: usize) -> Self {
        self.recv_buffer_size = Some(size);
        self
    }

    /// Sets the send buffer size. [`TcpConfig::validate`] rejects zero.
    #[must_use]
    pub fn with_send_buffer_size(mut self, size: usize) -> Self {
        self.send_buffer_size = Some(size);
        self
    }

    /// Sets the type-of-service value.
    #[must_use]
    pub fn with_tos(mut self, tos: u32) -> Self {
        self.tos = Some(tos);
        self
    }

    /// Sets the IPv6 traffic class. [`TcpConfig::validate`] rejects this on IPv4.
    #[must_use]
    pub fn with_tclass_v6(mut self, tclass: u32) -> Self {
        self.tclass_v6 = Some(tclass);
        self
    }

    /// Sets the IPv6-only flag. [`TcpConfig::validate`] rejects this on IPv4.
    #[must_use]
    pub fn with_only_v6(mut self, only_v6: bool) -> Self {
        self.only_v6 = Some(only_v6);
        self
    }

    /// Sets the name of the network interface to bind to.
    #[must_use]
    pub fn with_bind_device(mut self, iface: impl Into<String>) -> Self {
        self.bind_device = Some(iface.into());
        self
    }

    /// Checks the configuration for inconsistent or out-of-range values.
    ///
    /// Only the fields are inspected; no socket is created. The checks run in
    /// the order listed below and the first failure is returned.
    ///
    /// # Errors
    ///
    /// Returns an [`io::Error`] of kind [`io::ErrorKind::InvalidInput`] when:
    /// - `bind_addr` is set and its family differs from `socket_family`;
    /// - an IPv4 config sets `hoplimit`, `tclass_v6` or `only_v6`;
    /// - an IPv6 config sets `ttl`;
    /// - `read_timeout` or `write_timeout` is zero;
    /// - `recv_buffer_size` or `send_buffer_size` is zero;
    /// - `bind_device` is empty or contains a NUL byte.
    pub fn validate(&self) -> io::Result<()> {
        // Every failed check reports the same error kind.
        fn invalid(msg: &'static str) -> io::Error {
            io::Error::new(io::ErrorKind::InvalidInput, msg)
        }

        if let Some(addr) = self.bind_addr {
            if crate::SocketFamily::from_socket_addr(&addr) != self.socket_family {
                return Err(invalid("bind_addr family does not match socket_family"));
            }
        }
        if self.socket_family.is_v4() {
            if self.hoplimit.is_some() {
                return Err(invalid("hoplimit is only supported for IPv6 TCP sockets"));
            }
            if self.tclass_v6.is_some() {
                return Err(invalid("tclass_v6 is only supported for IPv6 TCP sockets"));
            }
            if self.only_v6.is_some() {
                return Err(invalid("only_v6 is only supported for IPv6 TCP sockets"));
            }
        }
        // NOTE(review): `tos` is not rejected for IPv6 configs, even though the
        // symmetric `ttl`/`tclass_v6` options are family-checked. Confirm whether
        // that asymmetry is intended before tightening it, because it could break callers.
        if self.socket_family.is_v6() && self.ttl.is_some() {
            return Err(invalid("ttl is only supported for IPv4 TCP sockets"));
        }
        if matches!(self.read_timeout, Some(timeout) if timeout.is_zero()) {
            return Err(invalid("read_timeout must be greater than zero"));
        }
        if matches!(self.write_timeout, Some(timeout) if timeout.is_zero()) {
            return Err(invalid("write_timeout must be greater than zero"));
        }
        if matches!(self.recv_buffer_size, Some(0)) {
            return Err(invalid("recv_buffer_size must be greater than zero"));
        }
        if matches!(self.send_buffer_size, Some(0)) {
            return Err(invalid("send_buffer_size must be greater than zero"));
        }
        if let Some(device) = self.bind_device.as_deref() {
            if device.is_empty() {
                return Err(invalid("bind_device must not be empty"));
            }
            // Interface names are NUL-terminated at the OS level, so an embedded
            // NUL would silently truncate the name to a different interface.
            if device.contains('\0') {
                return Err(invalid("bind_device must not contain NUL bytes"));
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::ErrorKind;

    /// Asserts that `cfg` fails validation with `InvalidInput` and exactly `msg`.
    fn assert_invalid(cfg: &TcpConfig, msg: &str) {
        let err = cfg.validate().expect_err("config should be rejected");
        assert_eq!(err.kind(), ErrorKind::InvalidInput);
        assert_eq!(err.to_string(), msg);
    }

    #[test]
    fn tcp_config_builders() {
        let addr: SocketAddr = "127.0.0.1:80".parse().unwrap();
        // Builders never validate, so an IPv6-only option on an IPv4 config is
        // stored as-is here; `validate` would reject it.
        let cfg = TcpConfig::new(SocketFamily::IPV4)
            .with_bind_addr(addr)
            .with_nonblocking(true)
            .with_reuseaddr(true)
            .with_reuseport(true)
            .with_nodelay(true)
            .with_ttl(10)
            .with_recv_buffer_size(8192)
            .with_send_buffer_size(8192)
            .with_tos(0x10)
            .with_tclass_v6(0x20);
        assert_eq!(cfg.socket_family, SocketFamily::IPV4);
        assert_eq!(cfg.socket_type, TcpSocketType::Stream);
        assert_eq!(cfg.bind_addr, Some(addr));
        assert!(cfg.nonblocking);
        assert_eq!(cfg.reuseaddr, Some(true));
        assert_eq!(cfg.reuseport, Some(true));
        assert_eq!(cfg.nodelay, Some(true));
        assert_eq!(cfg.ttl, Some(10));
        assert_eq!(cfg.recv_buffer_size, Some(8192));
        assert_eq!(cfg.send_buffer_size, Some(8192));
        assert_eq!(cfg.tos, Some(0x10));
        assert_eq!(cfg.tclass_v6, Some(0x20));
    }

    #[test]
    fn new_with_ipv6_family_creates_v6_stream() {
        let cfg = TcpConfig::new(SocketFamily::IPV6);
        assert_eq!(cfg.socket_family, SocketFamily::IPV6);
        assert_eq!(cfg.socket_type, TcpSocketType::Stream);
    }

    #[test]
    fn raw_constructors_set_family_and_type() {
        let v4 = TcpConfig::raw_v4();
        assert_eq!(v4.socket_family, SocketFamily::IPV4);
        assert_eq!(v4.socket_type, TcpSocketType::Raw);
        assert!(v4.socket_type.is_raw());
        assert!(!v4.socket_type.is_stream());

        let v6 = TcpConfig::raw_v6();
        assert_eq!(v6.socket_family, SocketFamily::IPV6);
        assert_eq!(v6.socket_type, TcpSocketType::Raw);
    }

    #[test]
    fn ipv6_builders_and_aliases() {
        let addr: SocketAddr = "[::1]:8080".parse().unwrap();
        let cfg = TcpConfig::v6_stream()
            .with_bind(addr)
            .with_hop_limit(32)
            .with_only_v6(true)
            .with_tclass_v6(0x20)
            .with_keepalive(true)
            .with_linger(Duration::from_secs(5))
            .with_read_timeout(Duration::from_millis(250))
            .with_write_timeout(Duration::from_millis(500))
            .with_bind_device("eth0");
        assert_eq!(cfg.bind_addr, Some(addr));
        assert_eq!(cfg.hoplimit, Some(32));
        assert_eq!(cfg.only_v6, Some(true));
        assert_eq!(cfg.tclass_v6, Some(0x20));
        assert_eq!(cfg.keepalive, Some(true));
        assert_eq!(cfg.linger, Some(Duration::from_secs(5)));
        assert_eq!(cfg.read_timeout, Some(Duration::from_millis(250)));
        assert_eq!(cfg.write_timeout, Some(Duration::from_millis(500)));
        assert_eq!(cfg.bind_device.as_deref(), Some("eth0"));
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn tcp_config_validate_accepts_defaults() {
        assert!(TcpConfig::v4_stream().validate().is_ok());
        assert!(TcpConfig::v6_stream().validate().is_ok());
        assert!(TcpConfig::raw_v4().validate().is_ok());
        assert!(TcpConfig::raw_v6().validate().is_ok());
    }

    #[test]
    fn tcp_config_validate_rejects_family_mismatch() {
        let cfg = TcpConfig::v4_stream().with_bind("[::1]:0".parse().unwrap());
        assert_invalid(&cfg, "bind_addr family does not match socket_family");
    }

    #[test]
    fn tcp_config_validate_rejects_family_specific_options() {
        assert_invalid(
            &TcpConfig::v4_stream().with_hoplimit(1),
            "hoplimit is only supported for IPv6 TCP sockets",
        );
        assert_invalid(
            &TcpConfig::v4_stream().with_tclass_v6(0x20),
            "tclass_v6 is only supported for IPv6 TCP sockets",
        );
        assert_invalid(
            &TcpConfig::v4_stream().with_only_v6(true),
            "only_v6 is only supported for IPv6 TCP sockets",
        );
        assert_invalid(
            &TcpConfig::v6_stream().with_ttl(64),
            "ttl is only supported for IPv4 TCP sockets",
        );
    }

    #[test]
    fn tcp_config_validate_rejects_zero_and_empty_values() {
        let zero = Duration::from_secs(0);
        assert_invalid(
            &TcpConfig::v4_stream().with_read_timeout(zero),
            "read_timeout must be greater than zero",
        );
        assert_invalid(
            &TcpConfig::v4_stream().with_write_timeout(zero),
            "write_timeout must be greater than zero",
        );
        assert_invalid(
            &TcpConfig::v4_stream().with_recv_buffer_size(0),
            "recv_buffer_size must be greater than zero",
        );
        assert_invalid(
            &TcpConfig::v4_stream().with_send_buffer_size(0),
            "send_buffer_size must be greater than zero",
        );
        assert_invalid(
            &TcpConfig::v4_stream().with_bind_device(""),
            "bind_device must not be empty",
        );
    }
}