use std::io::{self, Read, Write};
use std::net::{TcpListener, TcpStream};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;
use super::error::NetbatError;
use super::frame::{decode_line, dispatch_frame, encode_response, ResponseFrame};
use super::limits::{IoTimeouts, Limits};
/// Default cap on the total number of connections [`serve_tcp_listener`] accepts
/// before it returns.
pub const DEFAULT_MAX_CONNECTIONS: usize = 1_024;
/// Default number of requests served on one connection before it is closed.
pub const DEFAULT_MAX_REQUESTS_PER_CONNECTION: usize = 1;
/// Configuration for [`serve_tcp_listener`].
///
/// Start from [`TcpServerConfig::new`] (or `Default`) and adjust it with the
/// `with_*` methods. The struct is `#[non_exhaustive]`, so code outside this
/// crate cannot build it with a struct literal.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct TcpServerConfig {
    /// Request limits (including `max_line_bytes`) used when reading, decoding,
    /// and dispatching every request.
    pub limits: Limits,
    /// Read/write timeouts applied to each accepted socket.
    pub timeouts: IoTimeouts,
    /// Total number of connections accepted before the accept loop returns.
    /// This is a lifetime cap, not a concurrency limit: connections are served
    /// one at a time.
    pub max_connections: usize,
    /// Maximum number of requests served on one connection before it is closed.
    pub max_requests_per_connection: usize,
    /// How long the accept loop sleeps when no connection is pending.
    pub idle_sleep: Duration,
}
/// Default pause between polls of a listener that has no pending connection.
pub const DEFAULT_IDLE_SLEEP: Duration = Duration::from_millis(10);

impl Default for TcpServerConfig {
    /// Builds the default configuration: default [`Limits`] and [`IoTimeouts`],
    /// [`DEFAULT_MAX_CONNECTIONS`] total connections,
    /// [`DEFAULT_MAX_REQUESTS_PER_CONNECTION`] requests per connection, and a
    /// [`DEFAULT_IDLE_SLEEP`] pause between idle accept polls.
    fn default() -> Self {
        Self {
            limits: Limits::default(),
            timeouts: IoTimeouts::default(),
            max_connections: DEFAULT_MAX_CONNECTIONS,
            max_requests_per_connection: DEFAULT_MAX_REQUESTS_PER_CONNECTION,
            idle_sleep: DEFAULT_IDLE_SLEEP,
        }
    }
}
impl TcpServerConfig {
    /// Returns the default configuration (identical to `Default::default`).
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns `self` with the per-request [`Limits`] replaced.
    #[must_use]
    pub const fn with_limits(self, limits: Limits) -> Self {
        Self { limits, ..self }
    }

    /// Returns `self` with the per-socket read/write timeouts replaced.
    #[must_use]
    pub const fn with_timeouts(self, timeouts: IoTimeouts) -> Self {
        Self { timeouts, ..self }
    }

    /// Returns `self` with a new cap on the total number of accepted connections.
    #[must_use]
    pub const fn with_max_connections(self, value: usize) -> Self {
        Self {
            max_connections: value,
            ..self
        }
    }

    /// Returns `self` with a new cap on requests served per connection.
    #[must_use]
    pub const fn with_max_requests_per_connection(self, value: usize) -> Self {
        Self {
            max_requests_per_connection: value,
            ..self
        }
    }

    /// Returns `self` with a new sleep interval for idle accept polls.
    #[must_use]
    pub const fn with_idle_sleep(self, value: Duration) -> Self {
        Self {
            idle_sleep: value,
            ..self
        }
    }
}
/// Cloneable flag used to ask [`serve_tcp_listener`] to stop accepting.
///
/// All clones share one flag, so a clone can be moved to another thread and
/// used to stop a server running elsewhere. The accept loop polls the flag
/// between accepts, so a connection that is already being served is finished
/// first.
#[derive(Clone, Debug, Default)]
pub struct ShutdownHandle {
    // Shared by every clone; becomes `true` once shutdown has been requested.
    inner: Arc<AtomicBool>,
}
impl ShutdownHandle {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn shutdown(&self) {
self.inner.store(true, Ordering::Release);
}
#[must_use]
pub fn is_shutdown(&self) -> bool {
self.inner.load(Ordering::Acquire)
}
}
/// Counters accumulated by [`serve_tcp_listener`] while it runs.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub struct TcpServeStats {
    /// Connections accepted from the listener.
    pub accepted_connections: usize,
    /// Requests whose success response was written to the peer.
    pub served_requests: usize,
    /// Requests that failed with a protocol, limit, or runtime error. Transport
    /// (I/O) failures are counted in `connection_io_failures` instead.
    pub failed_requests: usize,
    /// Subset of `failed_requests`: `MalformedRequest` or
    /// `UnsupportedProtocolVersion` errors.
    pub malformed_requests: usize,
    /// Subset of `failed_requests`: `LineTooLong`, `OperationNameTooLong`,
    /// `InputTooLarge`, or `OutputTooLarge` errors.
    pub limit_failures: usize,
    /// Subset of `failed_requests`: `NetbatError::Runtime` errors.
    pub runtime_failures: usize,
    /// Connections abandoned because of an I/O error on the socket.
    pub connection_io_failures: usize,
    /// Whether the shutdown handle was set when the accept loop exited.
    pub shutdown_requested: bool,
}
#[tracing::instrument(name = "netbat.serve_stream", skip_all)]
pub fn serve_stream<S>(
stream: &mut S,
core: &mut syncbat::Core,
limits: &Limits,
) -> Result<ResponseFrame, NetbatError>
where
S: Read + Write,
{
let line = match read_line(stream, limits.max_line_bytes) {
Ok(line) => line,
Err(NetbatError::EmptyStream) => {
tracing::debug!("client closed before sending request");
return Err(NetbatError::EmptyStream);
}
Err(error) => {
let encoded = encode_response(Err(&error));
let _ = stream.write_all(&encoded);
return Err(error);
}
};
let frame = decode_line(&line, limits);
let response = match frame {
Ok(frame) => match dispatch_frame(core, frame, limits) {
Ok(response) => {
let encoded = encode_response(Ok(response.output()));
stream.write_all(&encoded)?;
return Ok(response);
}
Err(error) => {
let encoded = encode_response(Err(&error));
stream.write_all(&encoded)?;
Err(error)
}
},
Err(error) => {
let encoded = encode_response(Err(&error));
stream.write_all(&encoded)?;
Err(error)
}
};
response
}
#[tracing::instrument(name = "netbat.serve_tcp_listener", skip_all, fields(
addr = %listener.local_addr().map(|a| a.to_string()).unwrap_or_default(),
max_connections = config.max_connections,
))]
pub fn serve_tcp_listener(
listener: TcpListener,
core: &mut syncbat::Core,
config: &TcpServerConfig,
shutdown: &ShutdownHandle,
) -> Result<TcpServeStats, NetbatError> {
listener.set_nonblocking(true)?;
let mut stats = TcpServeStats::default();
tracing::info!("accept loop started");
while !shutdown.is_shutdown() && stats.accepted_connections < config.max_connections {
match listener.accept() {
Ok((stream, addr)) => {
stats.accepted_connections += 1;
tracing::debug!(peer = %addr, "connection accepted");
stream.set_nonblocking(false)?;
apply_timeouts(&stream, config.timeouts)?;
serve_tcp_connection(stream, core, config, &mut stats)?;
}
Err(error) if error.kind() == io::ErrorKind::WouldBlock => {
thread::sleep(config.idle_sleep);
}
Err(error) if error.kind() == io::ErrorKind::Interrupted => {}
Err(error) => return Err(error.into()),
}
}
stats.shutdown_requested = shutdown.is_shutdown();
tracing::info!(
accepted = stats.accepted_connections,
served = stats.served_requests,
failed = stats.failed_requests,
shutdown = stats.shutdown_requested,
"accept loop exiting",
);
drop(listener);
Ok(stats)
}
/// Serves every request on one accepted TCP connection, then closes it.
///
/// Takes ownership of `stream`, so the socket is closed (dropped) as soon as
/// the connection loop finishes.
fn serve_tcp_connection(
    mut stream: TcpStream,
    core: &mut syncbat::Core,
    config: &TcpServerConfig,
    stats: &mut TcpServeStats,
) -> Result<(), NetbatError> {
    let outcome = serve_connection_loop(&mut stream, core, config, stats);
    drop(stream);
    outcome
}
/// Serves up to `config.max_requests_per_connection` requests on one stream,
/// tallying each outcome into `stats`.
///
/// The loop stops early in three cases:
/// - the peer closes the stream;
/// - any I/O error occurs;
/// - after [`NetbatError::LineTooLong`], because the rest of the oversized line
///   is still unread and the framing can no longer be trusted.
///
/// Other request failures are counted, and the next request is read.
/// Per-connection problems never escalate: every path returns `Ok(())`.
fn serve_connection_loop<S: Read + Write>(
    stream: &mut S,
    core: &mut syncbat::Core,
    config: &TcpServerConfig,
    stats: &mut TcpServeStats,
) -> Result<(), NetbatError> {
    let mut remaining = config.max_requests_per_connection;
    while remaining > 0 {
        remaining -= 1;
        let failure = match serve_stream(stream, core, &config.limits) {
            Ok(_) => {
                stats.served_requests += 1;
                continue;
            }
            Err(failure) => failure,
        };
        match failure {
            NetbatError::EmptyStream => break,
            NetbatError::Io { .. } => {
                stats.connection_io_failures += 1;
                tracing::debug!("connection torn down by peer IO error");
                break;
            }
            other => {
                stats.failed_requests += 1;
                record_request_failure(stats, &other);
                if matches!(other, NetbatError::LineTooLong { .. }) {
                    tracing::debug!("closing connection after LineTooLong to resync framing");
                    break;
                }
            }
        }
    }
    Ok(())
}
fn apply_timeouts(stream: &TcpStream, timeouts: IoTimeouts) -> Result<(), NetbatError> {
stream.set_read_timeout(timeouts.read)?;
stream.set_write_timeout(timeouts.write)?;
Ok(())
}
/// Reads one request line from `reader`, one byte per `read` call.
///
/// Because only a single byte is requested at a time, nothing after the newline
/// is consumed from `reader`. The next call therefore starts exactly at the
/// following line.
///
/// Returns the bytes read, including the trailing `\n`. If the stream ends
/// after at least one byte but before a newline, those bytes are returned
/// unterminated.
///
/// # Errors
///
/// - [`NetbatError::EmptyStream`] if the stream ends before any byte is read.
/// - [`NetbatError::LineTooLong`] as soon as more than `max_line_bytes` bytes
///   (the newline counts) have been buffered.
/// - Any I/O error other than `Interrupted`, which is retried.
fn read_line<R: Read>(reader: &mut R, max_line_bytes: usize) -> Result<Vec<u8>, NetbatError> {
    let mut buffer = Vec::new();
    let mut slot = [0_u8; 1];
    loop {
        let count = match reader.read(&mut slot) {
            Ok(count) => count,
            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
            Err(err) => return Err(err.into()),
        };
        if count == 0 {
            return if buffer.is_empty() {
                Err(NetbatError::EmptyStream)
            } else {
                Ok(buffer)
            };
        }
        let [byte] = slot;
        buffer.push(byte);
        if buffer.len() > max_line_bytes {
            return Err(NetbatError::LineTooLong {
                max: max_line_bytes,
            });
        }
        if byte == b'\n' {
            return Ok(buffer);
        }
    }
}
/// Bumps the breakdown counter in `stats` that matches a failed request's error.
///
/// This does not touch `failed_requests`; callers increment that themselves.
/// Transport-level outcomes (`Io`, `EmptyStream`) are not request failures and
/// are ignored here.
fn record_request_failure(stats: &mut TcpServeStats, error: &NetbatError) {
    let counter = match error {
        NetbatError::MalformedRequest { .. } | NetbatError::UnsupportedProtocolVersion { .. } => {
            &mut stats.malformed_requests
        }
        NetbatError::LineTooLong { .. }
        | NetbatError::OperationNameTooLong { .. }
        | NetbatError::InputTooLarge { .. }
        | NetbatError::OutputTooLarge { .. } => &mut stats.limit_failures,
        NetbatError::Runtime(_) => &mut stats.runtime_failures,
        NetbatError::Io { .. } | NetbatError::EmptyStream => return,
    };
    *counter += 1;
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use syncbat::{Core, EffectClass, Handler, HandlerResult, OperationDescriptor};

    const PING: OperationDescriptor = OperationDescriptor::new(
        "ping",
        EffectClass::Inspect,
        "schema.ping.input.v1",
        "schema.ping.output.v1",
        "receipt.ping.v1",
    );

    /// Echo handler: returns its input unchanged.
    struct PingHandler;

    impl Handler for PingHandler {
        fn handle(&mut self, input: &[u8], _cx: &mut syncbat::Ctx<'_>) -> HandlerResult {
            Ok(input.to_vec())
        }
    }

    /// Builds a core with only the `ping` operation registered.
    fn core_with_ping() -> Core {
        let mut builder = Core::builder();
        builder.register(PING, PingHandler).expect("register");
        builder.build().expect("build")
    }

    /// Stream double: reads replay `request`; every write or flush fails with `BrokenPipe`.
    struct WriteFailsAfterRead {
        request: Cursor<Vec<u8>>,
    }

    impl WriteFailsAfterRead {
        fn new(request: &[u8]) -> Self {
            Self {
                request: Cursor::new(request.to_vec()),
            }
        }
    }

    impl Read for WriteFailsAfterRead {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            self.request.read(buf)
        }
    }

    impl Write for WriteFailsAfterRead {
        fn write(&mut self, _buf: &[u8]) -> io::Result<usize> {
            Err(io::Error::from(io::ErrorKind::BrokenPipe))
        }
        fn flush(&mut self) -> io::Result<()> {
            Err(io::Error::from(io::ErrorKind::BrokenPipe))
        }
    }

    #[test]
    fn peer_io_failure_does_not_propagate_from_connection() {
        let mut stream = WriteFailsAfterRead::new(b"NETBAT/1 CALL ping 6869\n");
        let mut core = core_with_ping();
        let config = TcpServerConfig::default();
        let mut stats = TcpServeStats::default();
        let outcome = serve_connection_loop(&mut stream, &mut core, &config, &mut stats);
        assert!(
            outcome.is_ok(),
            "per-connection IO failure must not escalate; got {outcome:?}"
        );
        assert_eq!(stats.connection_io_failures, 1);
        assert_eq!(stats.served_requests, 0);
        assert_eq!(stats.failed_requests, 0);
    }

    #[test]
    fn clean_eof_closes_connection_without_counting_anything() {
        // No bytes at all: the peer connected and hung up. Nothing is written,
        // so the failing writer is never touched.
        let mut stream = WriteFailsAfterRead::new(b"");
        let mut core = core_with_ping();
        let config = TcpServerConfig::default();
        let mut stats = TcpServeStats::default();
        let outcome = serve_connection_loop(&mut stream, &mut core, &config, &mut stats);
        assert!(outcome.is_ok(), "clean EOF must not be an error; got {outcome:?}");
        assert_eq!(stats, TcpServeStats::default());
    }

    #[test]
    fn read_line_stops_right_after_newline() {
        let mut reader = Cursor::new(b"first\nsecond\n".to_vec());
        assert_eq!(read_line(&mut reader, 64).expect("first line"), b"first\n");
        assert_eq!(read_line(&mut reader, 64).expect("second line"), b"second\n");
        assert!(matches!(
            read_line(&mut reader, 64),
            Err(NetbatError::EmptyStream)
        ));
    }

    #[test]
    fn read_line_returns_unterminated_tail_at_eof() {
        let mut reader = Cursor::new(b"tail".to_vec());
        assert_eq!(read_line(&mut reader, 64).expect("tail"), b"tail");
    }

    #[test]
    fn read_line_counts_newline_toward_limit() {
        let mut reader = Cursor::new(b"abc\n".to_vec());
        assert!(matches!(
            read_line(&mut reader, 3),
            Err(NetbatError::LineTooLong { max: 3 })
        ));
        let mut reader = Cursor::new(b"abc\n".to_vec());
        assert_eq!(read_line(&mut reader, 4).expect("fits exactly"), b"abc\n");
    }
}