use std::{
io::{ErrorKind, Read, Write},
net::SocketAddr,
};
use mio::net::{TcpListener, TcpStream, UdpSocket};
use rustls::{ProtocolVersion, ServerConnection};
use rusty_ulid::Ulid;
use socket2::{Domain, Protocol, Socket, Type};
use sozu_command::{config::MAX_LOOP_ITERATIONS, logging::ansi_palette};
use crate::metrics::names;
#[derive(thiserror::Error, Debug)]
pub enum ServerBindError {
#[error("could not set bind to socket: {0}")]
BindError(std::io::Error),
#[error("could not listen on socket: {0}")]
Listen(std::io::Error),
#[error("could not set socket to nonblocking: {0}")]
SetNonBlocking(std::io::Error),
#[error("could not set reuse address: {0}")]
SetReuseAddress(std::io::Error),
#[error("could not set reuse address: {0}")]
SetReusePort(std::io::Error),
#[error("Could not create socket: {0}")]
SocketCreationError(std::io::Error),
#[error("Invalid socket address '{address}': {error}")]
InvalidSocketAddress { address: String, error: String },
}
/// Outcome of one socket I/O call; always returned together with a byte count.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SocketResult {
    /// Progress was made and the caller may keep going.
    Continue,
    /// The peer closed, reset or aborted the connection.
    Closed,
    /// The operation would block; wait for the next readiness event.
    WouldBlock,
    /// An unrecoverable error occurred.
    Error,
}
/// Transport carried by a socket: plain TCP or a specific SSL/TLS version.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransportProtocol {
    /// Plain TCP; also reported by a TLS session whose version is not known yet.
    Tcp,
    /// SSL 2.0.
    Ssl2,
    /// SSL 3.0.
    Ssl3,
    /// TLS 1.0.
    Tls1_0,
    /// TLS 1.1.
    Tls1_1,
    /// TLS 1.2.
    Tls1_2,
    /// TLS 1.3.
    Tls1_3,
}
/// Common interface over a connected TCP stream, either plain or wrapped in TLS.
///
/// Every I/O method returns `(bytes_transferred, outcome)`.
pub trait SocketHandler {
    /// Reads into `buf`; returns how many bytes were filled and the resulting state.
    fn socket_read(&mut self, buf: &mut [u8]) -> (usize, SocketResult);
    /// Writes from `buf`; returns how many bytes were consumed and the resulting state.
    fn socket_write(&mut self, buf: &[u8]) -> (usize, SocketResult);
    /// Vectored counterpart of [`SocketHandler::socket_write`].
    fn socket_write_vectored(&mut self, _bufs: &[std::io::IoSlice]) -> (usize, SocketResult);
    /// Shared access to the underlying TCP stream.
    fn socket_ref(&self) -> &TcpStream;
    /// Exclusive access to the underlying TCP stream.
    fn socket_mut(&mut self) -> &mut TcpStream;
    /// Transport protocol currently in use.
    fn protocol(&self) -> TransportProtocol;
    /// Records a read error (implementations bump a metrics counter).
    fn read_error(&self);
    /// Records a write error (implementations bump a metrics counter).
    fn write_error(&self);

    /// Whether the handler still holds output to flush. Defaults to `false`.
    fn socket_wants_write(&self) -> bool {
        false
    }
    /// Transport-level close hook; the default does nothing.
    fn socket_close(&mut self) {}
    /// ULID of the owning session, used for log context. Defaults to `None`.
    fn session_ulid(&self) -> Option<Ulid> {
        None
    }
}
/// Builds the colored `[ulid - - -] SOCKET Session(...)` log prefix for any `SocketHandler`.
///
/// Includes peer/local addresses, kernel RTT and TCP state (when `TCP_INFO` is available)
/// and the handler's transport protocol. `ansi_palette` is resolved at the call site.
macro_rules! log_socket_context {
    ($self:expr) => {{
        let (open, reset, grey, gray, white) = ansi_palette();
        let ulid = $self
            .session_ulid()
            .map_or_else(|| "-".to_string(), |id| id.to_string());
        let socket = $self.socket_ref();
        let snapshot = crate::socket::stats::socket_snapshot(socket);
        let rtt = snapshot.as_ref().map(|snap| snap.rtt);
        let state = snapshot.as_ref().map(|snap| snap.state);
        let peer = socket.peer_addr().ok();
        let local = socket.local_addr().ok();
        let protocol = $self.protocol();
        format!(
            "[{ulid} - - -]\t{open}SOCKET{reset}\t{grey}Session{reset}({gray}peer{reset}={white}{peer:?}{reset}, {gray}local{reset}={white}{local:?}{reset}, {gray}rtt{reset}={white}{rtt:?}{reset}, {gray}state{reset}={white}{state:?}{reset}, {gray}protocol{reset}={white}{protocol:?}{reset})\t >>>"
        )
    }};
}
/// Log prefix for the free `tcp_socket_*` helpers, which only have a raw `TcpStream`.
///
/// `configured_peer`, when set, is shown instead of `stream.peer_addr()`. The protocol is
/// always printed as `Tcp`.
fn log_socket_module_prefix(
    stream: &TcpStream,
    session_ulid: Option<Ulid>,
    configured_peer: Option<SocketAddr>,
) -> String {
    let (open, reset, grey, gray, white) = ansi_palette();
    let ulid = session_ulid.map_or_else(|| String::from("-"), |id| id.to_string());
    let (rtt, state) = match crate::socket::stats::socket_snapshot(stream) {
        Some(snap) => (Some(snap.rtt), Some(snap.state)),
        None => (None, None),
    };
    let peer = configured_peer.or_else(|| stream.peer_addr().ok());
    let local = stream.local_addr().ok();
    format!(
        "[{ulid} - - -]\t{open}SOCKET{reset}\t{grey}Session{reset}({gray}peer{reset}={white}{peer:?}{reset}, {gray}local{reset}={white}{local:?}{reset}, {gray}rtt{reset}={white}{rtt:?}{reset}, {gray}state{reset}={white}{state:?}{reset}, {gray}protocol{reset}={white}Tcp{reset})\t >>>"
    )
}
/// Reads from `stream` into `buf` until the buffer is full, the socket would block, the peer
/// closes, or an error occurs.
///
/// Returns the number of bytes copied into `buf` together with the outcome:
/// - `Continue` once `buf` is full (immediately if `buf` is empty);
/// - `WouldBlock` when the kernel has no more data for now;
/// - `Closed` on EOF or a reset/aborted/refused/broken connection;
/// - `Error` on any other failure, or when `MAX_LOOP_ITERATIONS` is exceeded.
///
/// `ErrorKind::Interrupted` is retried, since `std::io::Read::read` documents it as
/// non-fatal. Retries still count against `MAX_LOOP_ITERATIONS`, so they cannot spin forever.
fn tcp_socket_read(
    stream: &mut TcpStream,
    buf: &mut [u8],
    session_ulid: Option<Ulid>,
    configured_peer: Option<SocketAddr>,
) -> (usize, SocketResult) {
    let mut size = 0usize;
    let mut counter = 0;
    loop {
        counter += 1;
        if counter > MAX_LOOP_ITERATIONS {
            error!(
                "{} MAX_LOOP_ITERATION reached in TcpStream::socket_read",
                log_socket_module_prefix(stream, session_ulid, configured_peer)
            );
            incr!(names::socket::READ_INFINITE_LOOP_ERROR);
            return (size, SocketResult::Error);
        }
        debug_assert!(
            size <= buf.len(),
            "read cursor {size} overran buffer len {} (would slice out of bounds)",
            buf.len()
        );
        if size == buf.len() {
            return (size, SocketResult::Continue);
        }
        match stream.read(&mut buf[size..]) {
            // The remaining slice is non-empty here, so 0 means EOF from the peer.
            Ok(0) => return (size, SocketResult::Closed),
            Ok(sz) => {
                debug_assert!(
                    sz <= buf.len() - size,
                    "read reported {sz} bytes into a {}-byte remaining slice",
                    buf.len() - size
                );
                size += sz;
            }
            Err(e) => match e.kind() {
                ErrorKind::WouldBlock => return (size, SocketResult::WouldBlock),
                // A signal interrupted the syscall before any data was read: retry.
                ErrorKind::Interrupted => {}
                ErrorKind::ConnectionReset
                | ErrorKind::ConnectionAborted
                | ErrorKind::BrokenPipe
                | ErrorKind::ConnectionRefused => return (size, SocketResult::Closed),
                ErrorKind::HostUnreachable
                | ErrorKind::NetworkUnreachable
                | ErrorKind::TimedOut
                | ErrorKind::NotConnected => {
                    warn!(
                        "{} socket_read error={:?}",
                        log_socket_module_prefix(stream, session_ulid, configured_peer),
                        e
                    );
                    return (size, SocketResult::Error);
                }
                _ => {
                    error!(
                        "{} socket_read error={:?}",
                        log_socket_module_prefix(stream, session_ulid, configured_peer),
                        e
                    );
                    return (size, SocketResult::Error);
                }
            },
        }
    }
}
/// Writes `buf` to `stream` until everything is written, the socket would block, the peer
/// goes away, or an error occurs.
///
/// Returns the number of bytes written together with the outcome:
/// - `Continue` once all of `buf` was written, or if the socket accepted 0 bytes;
/// - `WouldBlock` when the kernel send buffer is full;
/// - `Closed` on a reset/aborted/refused/broken connection;
/// - `Error` on any other failure, or when `MAX_LOOP_ITERATIONS` is exceeded.
///
/// Every `Closed`/`Error` outcome caused by an I/O error bumps `tcp::WRITE_ERROR`.
/// `ErrorKind::Interrupted` is retried, since `std::io::Write::write` documents it as
/// non-fatal. Retries still count against `MAX_LOOP_ITERATIONS`.
fn tcp_socket_write(
    stream: &mut TcpStream,
    buf: &[u8],
    session_ulid: Option<Ulid>,
    configured_peer: Option<SocketAddr>,
) -> (usize, SocketResult) {
    let mut size = 0usize;
    let mut counter = 0;
    loop {
        counter += 1;
        if counter > MAX_LOOP_ITERATIONS {
            error!(
                "{} MAX_LOOP_ITERATION reached in TcpStream::socket_write",
                log_socket_module_prefix(stream, session_ulid, configured_peer)
            );
            incr!(names::socket::WRITE_INFINITE_LOOP_ERROR);
            return (size, SocketResult::Error);
        }
        debug_assert!(
            size <= buf.len(),
            "write cursor {size} overran buffer len {} (would slice out of bounds)",
            buf.len()
        );
        if size == buf.len() {
            return (size, SocketResult::Continue);
        }
        match stream.write(&buf[size..]) {
            Ok(0) => return (size, SocketResult::Continue),
            Ok(sz) => {
                debug_assert!(
                    sz <= buf.len() - size,
                    "write reported {sz} bytes from a {}-byte remaining slice",
                    buf.len() - size
                );
                size += sz;
            }
            Err(e) => match e.kind() {
                ErrorKind::WouldBlock => return (size, SocketResult::WouldBlock),
                // A signal interrupted the syscall before anything was written: retry.
                ErrorKind::Interrupted => {}
                ErrorKind::ConnectionReset
                | ErrorKind::ConnectionAborted
                | ErrorKind::BrokenPipe
                | ErrorKind::ConnectionRefused => {
                    incr!(names::tcp::WRITE_ERROR);
                    return (size, SocketResult::Closed);
                }
                ErrorKind::HostUnreachable
                | ErrorKind::NetworkUnreachable
                | ErrorKind::TimedOut
                | ErrorKind::NotConnected => {
                    warn!(
                        "{} socket_write error={:?}",
                        log_socket_module_prefix(stream, session_ulid, configured_peer),
                        e
                    );
                    incr!(names::tcp::WRITE_ERROR);
                    return (size, SocketResult::Error);
                }
                _ => {
                    error!(
                        "{} socket_write error={:?}",
                        log_socket_module_prefix(stream, session_ulid, configured_peer),
                        e
                    );
                    incr!(names::tcp::WRITE_ERROR);
                    return (size, SocketResult::Error);
                }
            },
        }
    }
}
/// Single `write_vectored` attempt on `stream`; unlike `tcp_socket_write` it does not loop.
///
/// Success returns the byte count with `Continue`, whatever it is. Failures return 0 bytes
/// with one of three outcomes:
/// - `WouldBlock` when the socket would block;
/// - `Closed` on reset/aborted/refused/broken connections;
/// - `Error` otherwise.
///
/// Every `Closed`/`Error` outcome bumps `tcp::WRITE_ERROR`.
fn tcp_socket_write_vectored(
    stream: &mut TcpStream,
    bufs: &[std::io::IoSlice],
    session_ulid: Option<Ulid>,
    configured_peer: Option<SocketAddr>,
) -> (usize, SocketResult) {
    let failure = match stream.write_vectored(bufs) {
        Ok(sz) => {
            debug_assert!(
                sz <= bufs.iter().map(|b| b.len()).sum::<usize>(),
                "write_vectored reported {sz} bytes from {}-byte slices",
                bufs.iter().map(|b| b.len()).sum::<usize>()
            );
            return (sz, SocketResult::Continue);
        }
        Err(failure) => failure,
    };

    match failure.kind() {
        ErrorKind::WouldBlock => (0, SocketResult::WouldBlock),
        ErrorKind::ConnectionReset
        | ErrorKind::ConnectionAborted
        | ErrorKind::BrokenPipe
        | ErrorKind::ConnectionRefused => {
            incr!(names::tcp::WRITE_ERROR);
            (0, SocketResult::Closed)
        }
        ErrorKind::HostUnreachable
        | ErrorKind::NetworkUnreachable
        | ErrorKind::TimedOut
        | ErrorKind::NotConnected => {
            warn!(
                "{} socket_write error={:?}",
                log_socket_module_prefix(stream, session_ulid, configured_peer),
                failure
            );
            incr!(names::tcp::WRITE_ERROR);
            (0, SocketResult::Error)
        }
        _ => {
            error!(
                "{} socket_write error={:?}",
                log_socket_module_prefix(stream, session_ulid, configured_peer),
                failure
            );
            incr!(names::tcp::WRITE_ERROR);
            (0, SocketResult::Error)
        }
    }
}
/// Plain TCP: a bare `TcpStream` is its own socket and delegates I/O to the shared
/// `tcp_socket_*` helpers, with no session ULID or configured peer for log context.
impl SocketHandler for TcpStream {
    fn socket_ref(&self) -> &TcpStream {
        self
    }

    fn socket_mut(&mut self) -> &mut TcpStream {
        self
    }

    fn protocol(&self) -> TransportProtocol {
        TransportProtocol::Tcp
    }

    fn socket_read(&mut self, buf: &mut [u8]) -> (usize, SocketResult) {
        let stream: &mut TcpStream = self;
        tcp_socket_read(stream, buf, None, None)
    }

    fn socket_write(&mut self, buf: &[u8]) -> (usize, SocketResult) {
        let stream: &mut TcpStream = self;
        tcp_socket_write(stream, buf, None, None)
    }

    fn socket_write_vectored(&mut self, bufs: &[std::io::IoSlice]) -> (usize, SocketResult) {
        let stream: &mut TcpStream = self;
        tcp_socket_write_vectored(stream, bufs, None, None)
    }

    fn read_error(&self) {
        incr!(names::tcp::READ_ERROR);
    }

    fn write_error(&self) {
        incr!(names::tcp::WRITE_ERROR);
    }
}
/// A plain TCP stream tagged with its session ULID and an optional configured peer address,
/// so that socket log lines can be attributed to a session.
#[derive(Debug)]
pub struct SessionTcpStream {
    /// Underlying TCP stream used for all I/O.
    pub stream: TcpStream,
    /// ULID of the owning session, included in log prefixes.
    pub session_ulid: Ulid,
    /// When set, logged as the peer address instead of `stream.peer_addr()`.
    pub configured_peer: Option<SocketAddr>,
}
impl SessionTcpStream {
pub fn new(stream: TcpStream, session_ulid: Ulid, configured_peer: Option<SocketAddr>) -> Self {
Self {
stream,
session_ulid,
configured_peer,
}
}
}
/// Plain TCP with session attribution: same I/O helpers as a bare `TcpStream`, but the
/// session ULID and optional configured peer are forwarded for log context.
impl SocketHandler for SessionTcpStream {
    fn socket_read(&mut self, buf: &mut [u8]) -> (usize, SocketResult) {
        let Self {
            stream,
            session_ulid,
            configured_peer,
        } = self;
        tcp_socket_read(stream, buf, Some(*session_ulid), *configured_peer)
    }

    fn socket_write(&mut self, buf: &[u8]) -> (usize, SocketResult) {
        let Self {
            stream,
            session_ulid,
            configured_peer,
        } = self;
        tcp_socket_write(stream, buf, Some(*session_ulid), *configured_peer)
    }

    fn socket_write_vectored(&mut self, bufs: &[std::io::IoSlice]) -> (usize, SocketResult) {
        let Self {
            stream,
            session_ulid,
            configured_peer,
        } = self;
        tcp_socket_write_vectored(stream, bufs, Some(*session_ulid), *configured_peer)
    }

    fn socket_ref(&self) -> &TcpStream {
        &self.stream
    }

    fn socket_mut(&mut self) -> &mut TcpStream {
        &mut self.stream
    }

    fn protocol(&self) -> TransportProtocol {
        TransportProtocol::Tcp
    }

    fn read_error(&self) {
        incr!(names::tcp::READ_ERROR);
    }

    fn write_error(&self) {
        incr!(names::tcp::WRITE_ERROR);
    }

    fn session_ulid(&self) -> Option<Ulid> {
        Some(self.session_ulid)
    }
}
/// Client-facing TLS connection: a TCP stream paired with a rustls server-side session.
pub struct FrontRustls {
    /// TCP socket carrying the TLS records.
    pub stream: TcpStream,
    /// rustls server connection state (handshake, record layer, plaintext buffers).
    pub session: ServerConnection,
    /// Set by `socket_read` when the socket reports EOF or a reset/abort/broken pipe.
    pub peer_disconnected: bool,
    /// Set on reset/abort/broken pipe; once true, writes return `Closed` immediately and
    /// `socket_wants_write` reports `false`.
    pub peer_reset: bool,
    /// ULID of the owning session, included in log prefixes.
    pub session_ulid: Ulid,
}
/// Prints only the TCP stream; the rustls session and flags are elided as `..`.
impl std::fmt::Debug for FrontRustls {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("FrontRustls");
        out.field("stream", &self.stream);
        out.finish_non_exhaustive()
    }
}
/// TLS handler for client-facing connections.
///
/// Plaintext flows through `self.session` (a rustls `ServerConnection`); encrypted records
/// flow over `self.stream`.
impl SocketHandler for FrontRustls {
    /// Reads TLS records from the socket, lets rustls process them, and copies the decrypted
    /// plaintext into `buf`.
    ///
    /// The loop stops when any of these happens:
    /// - `buf` is full;
    /// - the socket would block;
    /// - the peer closed or reset;
    /// - an error occurs;
    /// - `MAX_LOOP_ITERATIONS` is exceeded.
    ///
    /// Returns the number of plaintext bytes written into `buf` and a status. The status is
    /// chosen with precedence `Error` > `Closed` > `Continue` (buffer full) > `WouldBlock`
    /// > `Continue`.
    ///
    /// Side effects:
    /// - sets `peer_disconnected` on socket EOF or reset;
    /// - sets `peer_reset` on reset/abort/broken pipe.
    fn socket_read(&mut self, buf: &mut [u8]) -> (usize, SocketResult) {
        // Plaintext bytes copied into `buf` so far.
        let mut size = 0usize;
        // Cleared once the socket reports WouldBlock or EOF.
        let mut can_read = true;
        let mut is_error = false;
        let mut is_closed = false;
        // Guard against spinning forever when neither the socket nor rustls makes progress.
        let mut counter = 0;
        loop {
            counter += 1;
            if counter > MAX_LOOP_ITERATIONS {
                error!(
                    "{} MAX_LOOP_ITERATION reached in FrontRustls::socket_read",
                    log_socket_context!(self)
                );
                incr!(names::rustls::READ_INFINITE_LOOP_ERROR);
                is_error = true;
                break;
            }
            debug_assert!(
                size <= buf.len(),
                "rustls read cursor {size} overran buffer len {} (would slice out of bounds)",
                buf.len()
            );
            // The caller's buffer is full: stop and let it consume the data.
            if size == buf.len() {
                break;
            }
            // Non-short-circuit `|` on plain bools; equivalent to `||` here.
            if !can_read | is_error | is_closed {
                break;
            }
            // Step 1: pull raw TLS bytes from the socket into rustls.
            match self.session.read_tls(&mut self.stream) {
                // EOF on the socket: the peer closed its side.
                Ok(0) => {
                    can_read = false;
                    is_closed = true;
                    self.peer_disconnected = true;
                }
                Ok(_sz) => {}
                Err(e) => match e.kind() {
                    ErrorKind::WouldBlock => {
                        can_read = false;
                    }
                    ErrorKind::ConnectionReset
                    | ErrorKind::ConnectionAborted
                    | ErrorKind::BrokenPipe => {
                        is_closed = true;
                        self.peer_disconnected = true;
                        self.peer_reset = true;
                    }
                    // NOTE(review): `Other` errors from `read_tls` are ignored and the loop
                    // carries on to drain plaintext. Presumably this is rustls refusing more
                    // input while its received-plaintext buffer is full; confirm against the
                    // rustls version in use.
                    ErrorKind::Other => {}
                    _ => {
                        error!(
                            "{} could not read TLS stream from socket: {:?}",
                            log_socket_context!(self),
                            e
                        );
                        is_error = true;
                        break;
                    }
                },
            }
            // Step 2: let rustls process the buffered TLS records. This also runs after EOF or
            // a reset (neither breaks above), so bytes already received are still handled.
            if let Err(e) = self.session.process_new_packets() {
                error!(
                    "{} could not process read TLS packets: {:?}",
                    log_socket_context!(self),
                    e
                );
                is_error = true;
                break;
            }
            // Step 3: copy decrypted plaintext out while rustls reports it does not need more
            // TLS input.
            while !self.session.wants_read() {
                match self.session.reader().read(&mut buf[size..]) {
                    // Nothing more to hand out. This is also the case when `buf[size..]` is
                    // empty.
                    Ok(0) => break,
                    Ok(sz) => {
                        debug_assert!(
                            sz <= buf.len() - size,
                            "rustls reader returned {sz} bytes into a {}-byte remaining slice",
                            buf.len() - size
                        );
                        size += sz;
                    }
                    Err(e) => match e.kind() {
                        ErrorKind::WouldBlock => {
                            break;
                        }
                        // Unlike the socket-level reset above, this path does not set
                        // `peer_disconnected` / `peer_reset`.
                        ErrorKind::ConnectionReset
                        | ErrorKind::ConnectionAborted
                        | ErrorKind::BrokenPipe => {
                            is_closed = true;
                            break;
                        }
                        _ => {
                            error!(
                                "{} could not read data from TLS stream: {:?}",
                                log_socket_context!(self),
                                e
                            );
                            is_error = true;
                            break;
                        }
                    },
                }
            }
        }
        debug_assert!(
            size <= buf.len(),
            "rustls socket_read returned {size} bytes for a {}-byte buffer",
            buf.len()
        );
        // NOTE(review): `is_closed` can be set without breaking (EOF/reset in step 1). If
        // `process_new_packets`, the plaintext reader, or the loop guard then flags an error,
        // both flags are true and this assertion fires in debug builds. Release builds report
        // `Error`. Confirm whether that path is reachable.
        debug_assert!(
            !(is_error && is_closed),
            "rustls socket_read cannot be both Error and Closed"
        );
        if is_error {
            (size, SocketResult::Error)
        } else if is_closed {
            (size, SocketResult::Closed)
        } else if size == buf.len() {
            (size, SocketResult::Continue)
        } else if !can_read {
            (size, SocketResult::WouldBlock)
        } else {
            (size, SocketResult::Continue)
        }
    }

    /// Hands `buf` to rustls for encryption and flushes TLS records to the socket.
    ///
    /// Returns `(0, Closed)` immediately once `peer_reset` is set. Otherwise returns the
    /// number of plaintext bytes accepted by rustls; some of the resulting TLS data may still
    /// be buffered when the socket blocks. The status has precedence `Error` > `Closed` >
    /// `WouldBlock` > `Continue`.
    ///
    /// Reset/abort/broken pipe sets `peer_reset`. Every write failure bumps
    /// `rustls::WRITE_ERROR`.
    fn socket_write(&mut self, buf: &[u8]) -> (usize, SocketResult) {
        if self.peer_reset {
            return (0, SocketResult::Closed);
        }
        // Plaintext bytes of `buf` accepted by rustls so far.
        let mut buffered_size = 0usize;
        // Cleared once the socket reports WouldBlock.
        let mut can_write = true;
        let mut is_error = false;
        let mut is_closed = false;
        let mut counter = 0;
        loop {
            counter += 1;
            if counter > MAX_LOOP_ITERATIONS {
                error!(
                    "{} MAX_LOOP_ITERATION reached in FrontRustls::socket_write",
                    log_socket_context!(self)
                );
                incr!(names::rustls::WRITE_INFINITE_LOOP_ERROR);
                is_error = true;
                break;
            }
            debug_assert!(
                buffered_size <= buf.len(),
                "rustls write cursor {buffered_size} overran buffer len {} (would slice out of bounds)",
                buf.len()
            );
            if buffered_size == buf.len() {
                break;
            }
            if !can_write | is_error | is_closed {
                break;
            }
            // Step 1: offer the remaining plaintext to rustls.
            match self.session.writer().write(&buf[buffered_size..]) {
                // NOTE(review): a 0-byte accept makes no progress. If `write_tls` below also
                // has nothing to send, the loop repeats until the `MAX_LOOP_ITERATIONS` guard
                // trips.
                Ok(0) => {}
                Ok(sz) => {
                    debug_assert!(
                        sz <= buf.len() - buffered_size,
                        "rustls writer absorbed {sz} bytes from a {}-byte remaining slice",
                        buf.len() - buffered_size
                    );
                    buffered_size += sz;
                }
                Err(e) => match e.kind() {
                    ErrorKind::WouldBlock => {
                        // Treated like a 0-byte accept: fall through to the flush below.
                    }
                    ErrorKind::ConnectionReset
                    | ErrorKind::ConnectionAborted
                    | ErrorKind::BrokenPipe => {
                        incr!(names::rustls::WRITE_ERROR);
                        is_closed = true;
                        self.peer_reset = true;
                        break;
                    }
                    _ => {
                        error!(
                            "{} could not write data to TLS stream: {:?}",
                            log_socket_context!(self),
                            e
                        );
                        incr!(names::rustls::WRITE_ERROR);
                        is_error = true;
                        break;
                    }
                },
            }
            // Step 2: push encrypted records to the socket until rustls has nothing left to
            // send (`Ok(0)`) or the socket blocks or fails.
            loop {
                match self.session.write_tls(&mut self.stream) {
                    Ok(0) => {
                        break;
                    }
                    Ok(_sz) => {}
                    Err(e) => match e.kind() {
                        ErrorKind::WouldBlock => {
                            can_write = false;
                            break;
                        }
                        ErrorKind::ConnectionReset
                        | ErrorKind::ConnectionAborted
                        | ErrorKind::BrokenPipe => {
                            incr!(names::rustls::WRITE_ERROR);
                            is_closed = true;
                            self.peer_reset = true;
                            break;
                        }
                        _ => {
                            error!(
                                "{} could not write TLS stream to socket: {:?}",
                                log_socket_context!(self),
                                e
                            );
                            incr!(names::rustls::WRITE_ERROR);
                            is_error = true;
                            break;
                        }
                    },
                }
            }
        }
        // If nothing failed and the socket has not blocked, flush whatever rustls still
        // reports as pending output.
        if !is_error && !is_closed && can_write && self.session.wants_write() {
            loop {
                match self.session.write_tls(&mut self.stream) {
                    Ok(0) => break,
                    Ok(_) => {}
                    Err(e) => match e.kind() {
                        ErrorKind::WouldBlock => {
                            can_write = false;
                            break;
                        }
                        ErrorKind::ConnectionReset
                        | ErrorKind::ConnectionAborted
                        | ErrorKind::BrokenPipe => {
                            incr!(names::rustls::WRITE_ERROR);
                            is_closed = true;
                            self.peer_reset = true;
                            break;
                        }
                        _ => {
                            error!(
                                "{} could not flush TLS stream to socket: {:?}",
                                log_socket_context!(self),
                                e
                            );
                            incr!(names::rustls::WRITE_ERROR);
                            is_error = true;
                            break;
                        }
                    },
                }
            }
        }
        debug_assert!(
            buffered_size <= buf.len(),
            "rustls socket_write reported {buffered_size} bytes for a {}-byte buffer",
            buf.len()
        );
        debug_assert!(
            !(is_error && is_closed),
            "rustls socket_write cannot be both Error and Closed"
        );
        if is_error {
            (buffered_size, SocketResult::Error)
        } else if is_closed {
            (buffered_size, SocketResult::Closed)
        } else if !can_write {
            (buffered_size, SocketResult::WouldBlock)
        } else {
            (buffered_size, SocketResult::Continue)
        }
    }

    /// Vectored variant of `socket_write`.
    ///
    /// All of `bufs` is offered to rustls in one `write_vectored` call, and only while nothing
    /// has been accepted yet. On a partial accept, TLS data is flushed once and the partial
    /// count is returned; the remaining slices are not retried here. Status precedence and
    /// side effects match `socket_write`.
    fn socket_write_vectored(&mut self, bufs: &[std::io::IoSlice]) -> (usize, SocketResult) {
        if self.peer_reset {
            return (0, SocketResult::Closed);
        }
        // Total plaintext length across all slices.
        let total_len: usize = bufs.iter().map(|b| b.len()).sum();
        let mut buffered_size = 0usize;
        let mut can_write = true;
        let mut is_error = false;
        let mut is_closed = false;
        let mut counter = 0;
        loop {
            counter += 1;
            if counter > MAX_LOOP_ITERATIONS {
                error!(
                    "{} MAX_LOOP_ITERATION reached in FrontRustls::socket_write_vectored",
                    log_socket_context!(self)
                );
                incr!(names::rustls::WRITE_INFINITE_LOOP_ERROR);
                is_error = true;
                break;
            }
            debug_assert!(
                buffered_size <= total_len,
                "rustls vectored write cursor {buffered_size} overran total slice len {total_len}"
            );
            if buffered_size == total_len {
                break;
            }
            if !can_write | is_error | is_closed {
                break;
            }
            // Offer the slices to rustls only while nothing has been accepted. A partial
            // accept is handled just below by flushing and leaving the loop.
            if buffered_size == 0 {
                match self.session.writer().write_vectored(bufs) {
                    // NOTE(review): as in `socket_write`, a 0-byte accept with nothing to
                    // flush repeats until the `MAX_LOOP_ITERATIONS` guard trips.
                    Ok(0) => {}
                    Ok(sz) => {
                        debug_assert!(
                            sz <= total_len,
                            "rustls writer absorbed {sz} bytes from {total_len}-byte slices"
                        );
                        buffered_size += sz;
                    }
                    Err(e) => match e.kind() {
                        ErrorKind::WouldBlock => {}
                        ErrorKind::ConnectionReset
                        | ErrorKind::ConnectionAborted
                        | ErrorKind::BrokenPipe => {
                            incr!(names::rustls::WRITE_ERROR);
                            is_closed = true;
                            self.peer_reset = true;
                            break;
                        }
                        _ => {
                            error!(
                                "{} could not write data to TLS stream: {:?}",
                                log_socket_context!(self),
                                e
                            );
                            incr!(names::rustls::WRITE_ERROR);
                            is_error = true;
                            break;
                        }
                    },
                }
            }
            // Partial accept: flush what rustls holds, then stop and report the partial count.
            if buffered_size > 0 && buffered_size < total_len {
                loop {
                    match self.session.write_tls(&mut self.stream) {
                        Ok(0) => break,
                        Ok(_) => {}
                        Err(e) => match e.kind() {
                            ErrorKind::WouldBlock => {
                                can_write = false;
                                break;
                            }
                            ErrorKind::ConnectionReset
                            | ErrorKind::ConnectionAborted
                            | ErrorKind::BrokenPipe => {
                                incr!(names::rustls::WRITE_ERROR);
                                is_closed = true;
                                self.peer_reset = true;
                                break;
                            }
                            _ => {
                                error!(
                                    "{} could not write TLS stream to socket: {:?}",
                                    log_socket_context!(self),
                                    e
                                );
                                incr!(names::rustls::WRITE_ERROR);
                                is_error = true;
                                break;
                            }
                        },
                    }
                }
                // Leaves the outer loop: the remaining slices are left for the caller.
                break;
            }
            // Nothing or everything accepted: flush encrypted records to the socket.
            loop {
                match self.session.write_tls(&mut self.stream) {
                    Ok(0) => {
                        break;
                    }
                    Ok(_sz) => {}
                    Err(e) => match e.kind() {
                        ErrorKind::WouldBlock => {
                            can_write = false;
                            break;
                        }
                        ErrorKind::ConnectionReset
                        | ErrorKind::ConnectionAborted
                        | ErrorKind::BrokenPipe => {
                            incr!(names::rustls::WRITE_ERROR);
                            is_closed = true;
                            self.peer_reset = true;
                            break;
                        }
                        _ => {
                            error!(
                                "{} could not write TLS stream to socket: {:?}",
                                log_socket_context!(self),
                                e
                            );
                            incr!(names::rustls::WRITE_ERROR);
                            is_error = true;
                            break;
                        }
                    },
                }
            }
        }
        // If nothing failed and the socket has not blocked, flush whatever rustls still
        // reports as pending output.
        if !is_error && !is_closed && can_write && self.session.wants_write() {
            loop {
                match self.session.write_tls(&mut self.stream) {
                    Ok(0) => break,
                    Ok(_) => {}
                    Err(e) => match e.kind() {
                        ErrorKind::WouldBlock => {
                            can_write = false;
                            break;
                        }
                        ErrorKind::ConnectionReset
                        | ErrorKind::ConnectionAborted
                        | ErrorKind::BrokenPipe => {
                            incr!(names::rustls::WRITE_ERROR);
                            is_closed = true;
                            self.peer_reset = true;
                            break;
                        }
                        _ => {
                            error!(
                                "{} could not flush TLS stream to socket: {:?}",
                                log_socket_context!(self),
                                e
                            );
                            incr!(names::rustls::WRITE_ERROR);
                            is_error = true;
                            break;
                        }
                    },
                }
            }
        }
        debug_assert!(
            buffered_size <= total_len,
            "rustls socket_write_vectored reported {buffered_size} bytes for {total_len}-byte slices"
        );
        debug_assert!(
            !(is_error && is_closed),
            "rustls socket_write_vectored cannot be both Error and Closed"
        );
        if is_error {
            (buffered_size, SocketResult::Error)
        } else if is_closed {
            (buffered_size, SocketResult::Closed)
        } else if !can_write {
            (buffered_size, SocketResult::WouldBlock)
        } else {
            (buffered_size, SocketResult::Continue)
        }
    }

    /// Queues a TLS close_notify alert in rustls.
    ///
    /// This method does not write to the socket itself; the alert goes out with a later
    /// flush of TLS data.
    fn socket_close(&mut self) {
        self.session.send_close_notify();
    }

    /// True while rustls has pending TLS output and the peer has not reset the connection.
    fn socket_wants_write(&self) -> bool {
        !self.peer_reset && self.session.wants_write()
    }

    fn socket_ref(&self) -> &TcpStream {
        &self.stream
    }

    fn socket_mut(&mut self) -> &mut TcpStream {
        &mut self.stream
    }

    /// Maps the rustls-reported protocol version onto `TransportProtocol`.
    ///
    /// Unrecognized versions map to `Tls1_3`. `Tcp` is returned when rustls reports no
    /// version (presumably before the handshake has settled; confirm).
    fn protocol(&self) -> TransportProtocol {
        self.session
            .protocol_version()
            .map(|version| match version {
                ProtocolVersion::SSLv2 => TransportProtocol::Ssl2,
                ProtocolVersion::SSLv3 => TransportProtocol::Ssl3,
                ProtocolVersion::TLSv1_0 => TransportProtocol::Tls1_0,
                ProtocolVersion::TLSv1_1 => TransportProtocol::Tls1_1,
                ProtocolVersion::TLSv1_2 => TransportProtocol::Tls1_2,
                ProtocolVersion::TLSv1_3 => TransportProtocol::Tls1_3,
                _ => TransportProtocol::Tls1_3,
            })
            .unwrap_or(TransportProtocol::Tcp)
    }

    fn read_error(&self) {
        incr!(names::rustls::READ_ERROR);
    }

    fn write_error(&self) {
        incr!(names::rustls::WRITE_ERROR);
    }

    fn session_ulid(&self) -> Option<Ulid> {
        Some(self.session_ulid)
    }
}
/// Creates a non-blocking TCP listener bound to `addr`.
///
/// Before `bind`, `SO_REUSEADDR` is enabled on unix and `SO_REUSEPORT` is enabled
/// unconditionally. The listen backlog is 1024.
///
/// The socket options are read back after setup. The reads happen in every build; only
/// the `debug_assert!`s on the results are debug-only.
///
/// # Errors
///
/// Each failing step maps to its own [`ServerBindError`] variant.
pub fn server_bind(addr: SocketAddr) -> Result<TcpListener, ServerBindError> {
    let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))
        .map_err(ServerBindError::SocketCreationError)?;

    // Reuse options must be in place before `bind`.
    if cfg!(unix) {
        socket
            .set_reuse_address(true)
            .map_err(ServerBindError::SetReuseAddress)?;
    }
    socket
        .set_reuse_port(true)
        .map_err(ServerBindError::SetReusePort)?;
    socket
        .bind(&addr.into())
        .map_err(ServerBindError::BindError)?;
    socket
        .set_nonblocking(true)
        .map_err(ServerBindError::SetNonBlocking)?;
    socket.listen(1024).map_err(ServerBindError::Listen)?;

    if let Ok(is_nonblocking) = socket.nonblocking() {
        debug_assert!(
            is_nonblocking,
            "server_bind must return a non-blocking socket (the worker event loop is edge-triggered)"
        );
    }
    #[cfg(unix)]
    if let Ok(has_reuse_port) = socket.reuse_port() {
        debug_assert!(
            has_reuse_port,
            "server_bind must set SO_REUSEPORT so the listener survives a hot-upgrade re-bind"
        );
    }
    #[cfg(unix)]
    if let Ok(has_reuse_address) = socket.reuse_address() {
        debug_assert!(
            has_reuse_address,
            "server_bind must set SO_REUSEADDR on unix (mirrors libstd)"
        );
    }
    if let Ok(bound) = socket.local_addr() {
        debug_assert_eq!(
            bound.is_ipv4(),
            addr.is_ipv4(),
            "bound socket family must match the requested address family"
        );
        debug_assert_eq!(
            bound.is_ipv6(),
            addr.is_ipv6(),
            "bound socket family must match the requested address family"
        );
    }

    let std_listener: std::net::TcpListener = socket.into();
    Ok(TcpListener::from_std(std_listener))
}
/// Creates a non-blocking UDP socket bound to `addr`.
///
/// Before `bind`, `SO_REUSEADDR` is enabled on unix and `SO_REUSEPORT` is enabled
/// unconditionally, mirroring [`server_bind`].
///
/// The socket options are read back after setup. The reads happen in every build; only
/// the `debug_assert!`s on the results are debug-only.
///
/// # Errors
///
/// Each failing step maps to its own [`ServerBindError`] variant.
pub fn udp_bind(addr: SocketAddr) -> Result<UdpSocket, ServerBindError> {
    let socket = Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))
        .map_err(ServerBindError::SocketCreationError)?;

    // Reuse options must be in place before `bind`.
    if cfg!(unix) {
        socket
            .set_reuse_address(true)
            .map_err(ServerBindError::SetReuseAddress)?;
    }
    socket
        .set_reuse_port(true)
        .map_err(ServerBindError::SetReusePort)?;
    socket
        .bind(&addr.into())
        .map_err(ServerBindError::BindError)?;
    socket
        .set_nonblocking(true)
        .map_err(ServerBindError::SetNonBlocking)?;

    if let Ok(is_nonblocking) = socket.nonblocking() {
        debug_assert!(
            is_nonblocking,
            "udp_bind must return a non-blocking socket (the worker event loop is edge-triggered)"
        );
    }
    #[cfg(unix)]
    if let Ok(has_reuse_port) = socket.reuse_port() {
        debug_assert!(
            has_reuse_port,
            "udp_bind must set SO_REUSEPORT so the listener survives a hot-upgrade re-bind"
        );
    }
    #[cfg(unix)]
    if let Ok(has_reuse_address) = socket.reuse_address() {
        debug_assert!(
            has_reuse_address,
            "udp_bind must set SO_REUSEADDR on unix (mirrors libstd / server_bind)"
        );
    }
    if let Ok(bound) = socket.local_addr() {
        debug_assert_eq!(
            bound.is_ipv4(),
            addr.is_ipv4(),
            "bound UDP socket family must match the requested address family"
        );
        debug_assert_eq!(
            bound.is_ipv6(),
            addr.is_ipv6(),
            "bound UDP socket family must match the requested address family"
        );
    }

    let std_socket: std::net::UdpSocket = socket.into();
    Ok(UdpSocket::from_std(std_socket))
}
/// Opens a non-blocking UDP socket connected to `backend`.
///
/// The socket is first bound to the unspecified address of `backend`'s family on port 0, so
/// the kernel picks the source port; it is then connected to `backend`.
///
/// # Errors
///
/// Creation, bind and non-blocking failures map to their [`ServerBindError`] variants.
/// A failing `connect` is reported as [`ServerBindError::BindError`].
pub fn udp_connect(backend: SocketAddr) -> Result<UdpSocket, ServerBindError> {
    let unspecified = if backend.is_ipv4() {
        SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, 0))
    } else {
        SocketAddr::from((std::net::Ipv6Addr::UNSPECIFIED, 0))
    };
    debug_assert_eq!(
        unspecified.is_ipv4(),
        backend.is_ipv4(),
        "ephemeral bind family must match the backend family"
    );
    debug_assert_eq!(
        unspecified.port(),
        0,
        "ephemeral bind must use port 0 so the kernel picks the source port"
    );

    let socket = Socket::new(
        Domain::for_address(backend),
        Type::DGRAM,
        Some(Protocol::UDP),
    )
    .map_err(ServerBindError::SocketCreationError)?;
    socket
        .bind(&unspecified.into())
        .map_err(ServerBindError::BindError)?;
    socket
        .set_nonblocking(true)
        .map_err(ServerBindError::SetNonBlocking)?;
    socket
        .connect(&backend.into())
        .map_err(ServerBindError::BindError)?;

    if let Ok(is_nonblocking) = socket.nonblocking() {
        debug_assert!(
            is_nonblocking,
            "udp_connect must return a non-blocking socket (the worker event loop is edge-triggered)"
        );
    }
    if let Ok(bound) = socket.local_addr() {
        debug_assert_eq!(
            bound.is_ipv4(),
            backend.is_ipv4(),
            "connected UDP socket family must match the backend family"
        );
        if let Some(bound_inet) = bound.as_socket() {
            debug_assert_ne!(
                bound_inet.port(),
                0,
                "connect must bind a concrete ephemeral source port (the return-demux key)"
            );
        }
    }
    if let Some(connected_peer) = socket.peer_addr().ok().and_then(|p| p.as_socket()) {
        debug_assert_eq!(
            connected_peer, backend,
            "connect must pin the peer to the requested backend (symmetric-NAT return-demux key)"
        );
    }

    let std_socket: std::net::UdpSocket = socket.into();
    Ok(UdpSocket::from_std(std_socket))
}
/// Kernel TCP statistics (`TCP_INFO` on Linux/BSD, `TCP_CONNECTION_INFO` on Apple
/// platforms), used to enrich socket log lines.
pub mod stats {
    use std::{os::fd::AsRawFd, time::Duration};

    #[cfg(unix)]
    use internal::{OPT_LEVEL, OPT_NAME, TcpInfo};

    /// Point-in-time view of a TCP connection as reported by the kernel.
    #[derive(Clone, Debug)]
    pub struct TcpSnapshot {
        /// Kernel round-trip-time estimate.
        pub rtt: Duration,
        /// TCP state name such as `"ESTABLISHED"`, or `"UNKNOWN"` for unmapped values.
        pub state: &'static str,
    }

    /// Kernel RTT estimate for `socket`, or `None` if the TCP info query fails.
    pub fn socket_rtt<A: AsRawFd>(socket: &A) -> Option<Duration> {
        socket_info(socket.as_raw_fd()).map(|info| Duration::from_micros(u64::from(info.rtt())))
    }

    /// RTT and state for `socket` in one query, or `None` if the TCP info query fails.
    pub fn socket_snapshot<A: AsRawFd>(socket: &A) -> Option<TcpSnapshot> {
        socket_info(socket.as_raw_fd()).map(|info| TcpSnapshot {
            rtt: Duration::from_micros(u64::from(info.rtt())),
            state: info.state(),
        })
    }

    /// Queries the platform TCP info structure for the socket `fd`.
    ///
    /// Returns `None` when `getsockopt` fails, e.g. when `fd` is not a TCP socket.
    #[cfg(unix)]
    pub fn socket_info(fd: libc::c_int) -> Option<TcpInfo> {
        // SAFETY: `TcpInfo` is `#[repr(C)]` and made only of integer fields, so the all-zero
        // bit pattern is a valid value.
        let mut tcp_info: TcpInfo = unsafe { std::mem::zeroed() };
        let struct_len = std::mem::size_of::<TcpInfo>() as libc::socklen_t;
        let mut len = struct_len;
        // SAFETY: `tcp_info` is a live, writable buffer of `struct_len` bytes and `len` is a
        // valid `socklen_t` holding that size. `getsockopt` writes at most `len` bytes and
        // stores the written size back into `len`.
        let status = unsafe {
            libc::getsockopt(
                fd,
                OPT_LEVEL,
                OPT_NAME,
                &mut tcp_info as *mut _ as *mut _,
                &mut len,
            )
        };
        if status != 0 {
            None
        } else {
            debug_assert!(
                len <= struct_len,
                "getsockopt(TCP_INFO) wrote back len {len} > struct size {struct_len} (buffer overrun)"
            );
            Some(tcp_info)
        }
    }

    /// Stub for platforms without a supported TCP info query: always `None`.
    ///
    /// NOTE: the rest of this module (`std::os::fd`, `internal`) is still unix-only.
    #[cfg(not(unix))]
    pub fn socket_info(_fd: libc::c_int) -> Option<TcpInfo> {
        None
    }

    #[cfg(unix)]
    #[cfg(not(any(target_os = "macos", target_os = "ios")))]
    mod internal {
        // Android runs the Linux kernel, so it shares Linux's option level and layout.
        #[cfg(any(target_os = "linux", target_os = "android"))]
        pub const OPT_LEVEL: libc::c_int = libc::SOL_TCP;
        #[cfg(any(
            target_os = "freebsd",
            target_os = "dragonfly",
            target_os = "openbsd",
            target_os = "netbsd"
        ))]
        pub const OPT_LEVEL: libc::c_int = libc::IPPROTO_TCP;
        pub const OPT_NAME: libc::c_int = libc::TCP_INFO;

        /// Leading part of the kernel's `struct tcp_info`, up to `tcpi_reordering`.
        ///
        /// Field order and widths must match the C layout exactly.
        #[derive(Clone, Debug)]
        #[repr(C)]
        // Fields mirror the kernel struct layout; only a few are read, the rest are
        // kept for layout and `Debug` output.
        #[allow(dead_code)]
        pub struct TcpInfo {
            tcpi_state: u8,
            tcpi_ca_state: u8,
            tcpi_retransmits: u8,
            tcpi_probes: u8,
            tcpi_backoff: u8,
            tcpi_options: u8,
            tcpi_snd_rcv_wscale: u8,
            tcpi_rto: u32,
            tcpi_ato: u32,
            tcpi_snd_mss: u32,
            tcpi_rcv_mss: u32,
            tcpi_unacked: u32,
            tcpi_sacked: u32,
            tcpi_lost: u32,
            tcpi_retrans: u32,
            tcpi_fackets: u32,
            tcpi_last_data_sent: u32,
            tcpi_last_ack_sent: u32,
            tcpi_last_data_recv: u32,
            tcpi_last_ack_recv: u32,
            tcpi_pmtu: u32,
            tcpi_rcv_ssthresh: u32,
            tcpi_rtt: u32,
            tcpi_rttvar: u32,
            tcpi_snd_ssthresh: u32,
            tcpi_snd_cwnd: u32,
            tcpi_advmss: u32,
            tcpi_reordering: u32,
        }

        impl TcpInfo {
            /// RTT in microseconds, as reported in `tcpi_rtt`.
            pub fn rtt(&self) -> u32 {
                self.tcpi_rtt
            }

            /// Name of the TCP state.
            ///
            /// NOTE(review): this table uses Linux `TCP_*` state numbers. BSD `TCPS_*`
            /// numbering differs; confirm on BSD targets.
            pub fn state(&self) -> &'static str {
                match self.tcpi_state {
                    1 => "ESTABLISHED",
                    2 => "SYN_SENT",
                    3 => "SYN_RECV",
                    4 => "FIN_WAIT1",
                    5 => "FIN_WAIT2",
                    6 => "TIME_WAIT",
                    7 => "CLOSE",
                    8 => "CLOSE_WAIT",
                    9 => "LAST_ACK",
                    10 => "LISTEN",
                    11 => "CLOSING",
                    12 => "NEW_SYN_RECV",
                    _ => "UNKNOWN",
                }
            }
        }
    }

    #[cfg(unix)]
    #[cfg(any(target_os = "macos", target_os = "ios"))]
    mod internal {
        pub const OPT_LEVEL: libc::c_int = libc::IPPROTO_TCP;
        /// `TCP_CONNECTION_INFO`.
        pub const OPT_NAME: libc::c_int = 0x106;

        /// Apple `struct tcp_connection_info`; field order and widths must match the C layout.
        #[derive(Clone, Debug)]
        #[repr(C)]
        // Fields mirror the kernel struct layout; only a few are read, the rest are
        // kept for layout and `Debug` output.
        #[allow(dead_code)]
        pub struct TcpInfo {
            tcpi_state: u8,
            tcpi_snd_wscale: u8,
            tcpi_rcv_wscale: u8,
            __pad1: u8,
            tcpi_options: u32,
            tcpi_flags: u32,
            tcpi_rto: u32,
            tcpi_maxseg: u32,
            tcpi_snd_ssthresh: u32,
            tcpi_snd_cwnd: u32,
            tcpi_snd_wnd: u32,
            tcpi_snd_sbbytes: u32,
            tcpi_rcv_wnd: u32,
            tcpi_rttcur: u32,
            tcpi_srtt: u32,
            tcpi_rttvar: u32,
            tcpi_tfo: u32,
            tcpi_txpackets: u64,
            tcpi_txbytes: u64,
            tcpi_txretransmitbytes: u64,
            tcpi_rxpackets: u64,
            tcpi_rxbytes: u64,
            tcpi_rxoutoforderbytes: u64,
            tcpi_txretransmitpackets: u64,
        }

        impl TcpInfo {
            /// RTT in microseconds: `tcpi_srtt` (milliseconds) scaled by 1000.
            ///
            /// Saturating, so an absurd kernel value cannot overflow-panic in debug builds.
            pub fn rtt(&self) -> u32 {
                self.tcpi_srtt.saturating_mul(1000)
            }

            /// Name of the TCP state, using Apple `TCPS_*` numbering.
            pub fn state(&self) -> &'static str {
                match self.tcpi_state {
                    0 => "CLOSED",
                    1 => "LISTEN",
                    2 => "SYN_SENT",
                    3 => "SYN_RECEIVED",
                    4 => "ESTABLISHED",
                    5 => "CLOSE_WAIT",
                    6 => "FIN_WAIT_1",
                    7 => "CLOSING",
                    8 => "LAST_ACK",
                    9 => "FIN_WAIT_2",
                    10 => "TIME_WAIT",
                    _ => "UNKNOWN",
                }
            }
        }
    }

    /// Placeholder type for platforms without a supported TCP info query.
    #[cfg(not(unix))]
    #[derive(Clone, Debug)]
    pub struct TcpInfo {}

    #[test]
    #[serial_test::serial]
    fn test_rtt() {
        // Loopback connection: the test needs no network access. `listener` stays alive so
        // the connection remains established while it is inspected.
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let sock = std::net::TcpStream::connect(listener.local_addr().unwrap()).unwrap();
        let fd = sock.as_raw_fd();
        let info = socket_info(fd);
        assert!(info.is_some());
        println!("{info:#?}");
        println!(
            "rtt: {}",
            sozu_command::logging::LogDuration(socket_rtt(&sock))
        );
    }
}