use std::collections::VecDeque;
use std::net::ToSocketAddrs;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use bytes::BytesMut;
use socket2::{Domain, Socket, Type};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use crate::codec::{AsciiDecoder, AsciiLimits, BinaryDecoder, BinaryLimits, DecodeOutcome};
use crate::context::{ConnectionInfo, Extensions, RequestContext};
use crate::error::Error;
use crate::response::Response;
use crate::router::Router;
use crate::types::{Op, Protocol, ReplyMode, Request, RequestMeta};
/// Process-wide counter used by the accept loop to assign a `client_id` to each
/// accepted connection. Starts at 1 and is bumped with a relaxed `fetch_add`.
static NEXT_CLIENT_ID: AtomicU64 = AtomicU64::new(1);
/// Limits, timeouts and socket options applied to the listener and to every
/// connection it accepts.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Passed to the ASCII decoder as its `max_line_len` limit.
    pub max_line_len: usize,
    /// Passed to the ASCII decoder as its `max_blob_len` limit.
    pub max_blob_len: usize,
    /// Passed to the binary decoder as its `max_frame_len` limit.
    pub max_frame_len: usize,
    /// Number of decoded requests a connection may queue before decoding pauses.
    pub max_inflight_requests: usize,
    /// Cap on the estimated bytes of queued requests plus buffered quiet replies.
    pub max_inflight_bytes: usize,
    /// Maximum number of replies held in the quiet buffer.
    pub max_quiet_responses: usize,
    /// Maximum total bytes held in the quiet buffer.
    pub max_quiet_bytes: usize,
    /// Write-buffer size at which buffered output is flushed to the socket.
    pub write_batch_bytes: usize,
    /// Timeout for a socket read; only consulted when `idle_timeout` is `None`.
    pub read_timeout: Option<Duration>,
    /// Timeout for writing the buffered output to the socket.
    pub write_timeout: Option<Duration>,
    /// Timeout for a socket read; takes precedence over `read_timeout`.
    pub idle_timeout: Option<Duration>,
    /// Whether `TCP_NODELAY` is enabled on accepted connections.
    pub tcp_nodelay: bool,
    /// Explicit listen backlog; `None` binds without specifying one.
    pub backlog: Option<u32>,
}

impl Default for ServerConfig {
    /// Default limits: 4 KiB lines, 1 MiB blobs, 2 MiB frames, 128 queued
    /// requests / 8 MiB in flight, 256 quiet replies / 2 MiB, 8 KiB write
    /// batches, no timeouts, `TCP_NODELAY` on, no explicit backlog.
    fn default() -> Self {
        const KIB: usize = 1024;
        const MIB: usize = KIB * KIB;

        Self {
            max_line_len: 4 * KIB,
            max_blob_len: MIB,
            max_frame_len: 2 * MIB,
            max_inflight_requests: 128,
            max_inflight_bytes: 8 * MIB,
            max_quiet_responses: 256,
            max_quiet_bytes: 2 * MIB,
            write_batch_bytes: 8 * KIB,
            read_timeout: None,
            write_timeout: None,
            idle_timeout: None,
            tcp_nodelay: true,
            backlog: None,
        }
    }
}

impl ServerConfig {
    /// Starts a builder pre-populated with [`ServerConfig::default`].
    pub fn builder() -> ServerConfigBuilder {
        let cfg = ServerConfig::default();
        ServerConfigBuilder { cfg }
    }
}

/// Fluent builder for [`ServerConfig`]. Every size/count setter raises a zero
/// argument to 1.
pub struct ServerConfigBuilder {
    cfg: ServerConfig,
}

impl ServerConfigBuilder {
    /// Applies one mutation to the configuration under construction.
    fn update(mut self, apply: impl FnOnce(&mut ServerConfig)) -> Self {
        apply(&mut self.cfg);
        self
    }

    /// Raises zero to one; all other values pass through unchanged.
    fn floor_one(value: usize) -> usize {
        if value == 0 {
            1
        } else {
            value
        }
    }

    /// Sets `max_line_len` (minimum 1).
    pub fn max_line_len(self, value: usize) -> Self {
        self.update(|cfg| cfg.max_line_len = Self::floor_one(value))
    }

    /// Sets `max_blob_len` (minimum 1).
    pub fn max_blob_len(self, value: usize) -> Self {
        self.update(|cfg| cfg.max_blob_len = Self::floor_one(value))
    }

    /// Sets `max_frame_len` (minimum 1).
    pub fn max_frame_len(self, value: usize) -> Self {
        self.update(|cfg| cfg.max_frame_len = Self::floor_one(value))
    }

    /// Sets `max_inflight_requests` (minimum 1).
    pub fn max_inflight_requests(self, value: usize) -> Self {
        self.update(|cfg| cfg.max_inflight_requests = Self::floor_one(value))
    }

    /// Sets `max_inflight_bytes` (minimum 1).
    pub fn max_inflight_bytes(self, value: usize) -> Self {
        self.update(|cfg| cfg.max_inflight_bytes = Self::floor_one(value))
    }

    /// Sets `max_quiet_responses` (minimum 1).
    pub fn max_quiet_responses(self, value: usize) -> Self {
        self.update(|cfg| cfg.max_quiet_responses = Self::floor_one(value))
    }

    /// Sets `max_quiet_bytes` (minimum 1).
    pub fn max_quiet_bytes(self, value: usize) -> Self {
        self.update(|cfg| cfg.max_quiet_bytes = Self::floor_one(value))
    }

    /// Sets `write_batch_bytes` (minimum 1).
    pub fn write_batch_bytes(self, value: usize) -> Self {
        self.update(|cfg| cfg.write_batch_bytes = Self::floor_one(value))
    }

    /// Sets `read_timeout`.
    pub fn read_timeout(self, value: Option<Duration>) -> Self {
        self.update(|cfg| cfg.read_timeout = value)
    }

    /// Sets `write_timeout`.
    pub fn write_timeout(self, value: Option<Duration>) -> Self {
        self.update(|cfg| cfg.write_timeout = value)
    }

    /// Sets `idle_timeout`.
    pub fn idle_timeout(self, value: Option<Duration>) -> Self {
        self.update(|cfg| cfg.idle_timeout = value)
    }

    /// Sets `tcp_nodelay`.
    pub fn tcp_nodelay(self, value: bool) -> Self {
        self.update(|cfg| cfg.tcp_nodelay = value)
    }

    /// Sets `backlog`.
    pub fn backlog(self, value: Option<u32>) -> Self {
        self.update(|cfg| cfg.backlog = value)
    }

    /// Finishes the builder and returns the configuration.
    pub fn build(self) -> ServerConfig {
        let Self { cfg } = self;
        cfg
    }
}
pub struct Server;
impl Server {
pub fn bind<A: ToString>(addr: A) -> ServerBuilder {
ServerBuilder {
addr: addr.to_string(),
cfg: ServerConfig::default(),
shutdown: None,
extensions_factory: Arc::new(|_| Extensions::default()),
}
}
}
/// Builder returned by `Server::bind`; collects the listen address,
/// configuration, optional shutdown signal and per-connection extensions
/// factory before the server is started with `serve`.
pub struct ServerBuilder {
/// Address to listen on, kept as text until `serve` binds it.
addr: String,
/// Limits and socket options handed to the accept loop and each connection.
cfg: ServerConfig,
/// Optional future; once it completes, `serve_with_listener` stops waiting on
/// the accept loop and returns `Ok(())`.
shutdown: Option<BoxFuture>,
/// Invoked once per accepted connection to build that connection's `Extensions`.
extensions_factory: Arc<dyn Fn(ConnectionInfo) -> Extensions + Send + Sync>,
}
/// Pinned, boxed, `Send` future with unit output; the stored form of the shutdown signal.
type BoxFuture = std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>;
impl ServerBuilder {
pub fn with_config(mut self, cfg: ServerConfig) -> Self {
self.cfg = cfg;
self
}
pub fn with_graceful_shutdown<F>(mut self, fut: F) -> Self
where
F: std::future::Future<Output = ()> + Send + 'static,
{
self.shutdown = Some(Box::pin(fut));
self
}
pub fn with_connection_extensions<F>(mut self, factory: F) -> Self
where
F: Fn(ConnectionInfo) -> Extensions + Send + Sync + 'static,
{
self.extensions_factory = Arc::new(factory);
self
}
pub async fn serve<State>(self, app: Router<State>) -> std::io::Result<()>
where
State: Send + Sync + 'static,
{
let listener = if let Some(backlog) = self.cfg.backlog {
bind_with_backlog(&self.addr, backlog)?
} else {
TcpListener::bind(&self.addr).await?
};
self.serve_with_listener(listener, app).await
}
pub async fn serve_with_listener<State>(
self,
listener: TcpListener,
app: Router<State>,
) -> std::io::Result<()>
where
State: Send + Sync + 'static,
{
let app = Arc::new(app);
if let Some(shutdown) = self.shutdown {
tokio::select! {
result = accept_loop(listener, app, self.cfg, self.extensions_factory) => result,
_ = shutdown => Ok(()),
}
} else {
accept_loop(listener, app, self.cfg, self.extensions_factory).await
}
}
}
/// Returns `true` for `accept` failures that concern only the connection that
/// was being accepted (the peer went away), not the listening socket itself.
fn is_per_connection_accept_error(err: &std::io::Error) -> bool {
    matches!(
        err.kind(),
        std::io::ErrorKind::ConnectionRefused
            | std::io::ErrorKind::ConnectionAborted
            | std::io::ErrorKind::ConnectionReset
    )
}

/// Accepts connections forever, spawning one task per connection.
///
/// Each connection gets a fresh `client_id`, a `ConnectionInfo`, and the
/// `Extensions` produced by `extensions_factory`; it is then driven by
/// `handle_connection` on its own task, whose result is intentionally ignored.
///
/// # Errors
/// Returns the first `accept` error that is not specific to a single
/// connection. Errors caused by one peer (reset/aborted/refused while being
/// accepted, or a socket whose local address can no longer be queried) only
/// skip that connection so one misbehaving client cannot stop the server.
async fn accept_loop<State>(
    listener: TcpListener,
    app: Arc<Router<State>>,
    cfg: ServerConfig,
    extensions_factory: Arc<dyn Fn(ConnectionInfo) -> Extensions + Send + Sync>,
) -> std::io::Result<()>
where
    State: Send + Sync + 'static,
{
    loop {
        let (stream, peer_addr) = match listener.accept().await {
            Ok(accepted) => accepted,
            // The failed connection has already been consumed from the queue,
            // so retrying immediately cannot spin on the same error.
            Err(err) if is_per_connection_accept_error(&err) => continue,
            Err(err) => return Err(err),
        };
        // A socket that died between accept and here is dropped, not fatal.
        let Ok(local_addr) = stream.local_addr() else {
            continue;
        };
        let client_id = NEXT_CLIENT_ID.fetch_add(1, Ordering::Relaxed);
        let info = ConnectionInfo {
            peer_addr,
            local_addr,
            client_id,
        };
        let extensions = (extensions_factory)(info);
        let app = Arc::clone(&app);
        let cfg = cfg.clone();
        tokio::spawn(async move {
            // Per-connection I/O errors end only that connection.
            let _ = handle_connection(stream, app, cfg, info, extensions).await;
        });
    }
}
async fn handle_connection<State>(
mut stream: TcpStream,
app: Arc<Router<State>>,
cfg: ServerConfig,
info: ConnectionInfo,
base_extensions: Extensions,
) -> std::io::Result<()>
where
State: Send + Sync + 'static,
{
if cfg.tcp_nodelay {
let _ = stream.set_nodelay(true);
}
let mut read_buf = BytesMut::with_capacity(4096);
let mut write_buf = BytesMut::with_capacity(cfg.write_batch_bytes);
let mut ascii = AsciiDecoder::new();
let mut binary = BinaryDecoder::new();
let mut protocol: Option<Protocol> = None;
let mut pending: VecDeque<(Request, RequestMeta, usize)> = VecDeque::new();
let mut pending_bytes: usize = 0;
let mut quiet = QuietBuffer::new();
loop {
loop {
if pending.len() >= cfg.max_inflight_requests {
break;
}
if protocol.is_none() {
if read_buf.is_empty() {
break;
}
protocol = Some(if read_buf[0] == 0x80 {
Protocol::Binary
} else {
Protocol::Ascii
});
}
let outcome = match protocol.unwrap() {
Protocol::Binary => binary.decode(
&mut read_buf,
BinaryLimits {
max_frame_len: cfg.max_frame_len,
},
),
Protocol::Ascii | Protocol::Meta => ascii.decode(
&mut read_buf,
AsciiLimits {
max_line_len: cfg.max_line_len,
max_blob_len: cfg.max_blob_len,
},
),
};
let Some(outcome) = outcome else {
break;
};
match outcome {
DecodeOutcome::Request(req, meta) => {
let est = estimate_request_bytes(&req);
pending_bytes = pending_bytes.saturating_add(est);
if pending_bytes + quiet.bytes > cfg.max_inflight_bytes {
let err = Response::Error(Error::server("inflight limit"));
let _ = send_response(
&mut stream,
&mut write_buf,
&mut quiet,
&cfg,
None,
meta,
err,
)
.await?;
return Ok(());
}
pending.push_back((req, meta, est));
}
DecodeOutcome::Response(meta, response) => {
let close = response_close(&response);
let extra_close = send_response(
&mut stream,
&mut write_buf,
&mut quiet,
&cfg,
None,
meta,
response,
)
.await?;
if close || extra_close {
return Ok(());
}
}
}
}
if let Some((req, meta, est)) = pending.pop_front() {
pending_bytes = pending_bytes.saturating_sub(est);
let ctx = RequestContext {
request: req.clone(),
meta,
peer_addr: info.peer_addr,
local_addr: info.local_addr,
client_id: info.client_id,
extensions: base_extensions.clone(),
};
let response = app.call(ctx).await;
let close = matches!(req.op, Op::Quit) || response_close(&response);
let extra_close = send_response(
&mut stream,
&mut write_buf,
&mut quiet,
&cfg,
Some(&req),
meta,
response,
)
.await?;
if close || extra_close {
flush_quiet(&mut stream, &mut write_buf, &mut quiet, &cfg).await?;
flush_write_buf(&mut stream, &mut write_buf, &cfg).await?;
return Ok(());
}
continue;
}
if !write_buf.is_empty() {
flush_write_buf(&mut stream, &mut write_buf, &cfg).await?;
}
let read = read_more(&mut stream, &mut read_buf, &cfg).await?;
if read == 0 {
return Ok(());
}
}
}
async fn send_response(
stream: &mut TcpStream,
write_buf: &mut BytesMut,
quiet: &mut QuietBuffer,
cfg: &ServerConfig,
req: Option<&Request>,
meta: RequestMeta,
response: Response,
) -> std::io::Result<bool> {
let dummy_req = Request::new(Op::Unknown);
let req = req.unwrap_or(&dummy_req);
match meta.protocol {
Protocol::Ascii | Protocol::Meta => {
if meta.protocol == Protocol::Ascii {
if crate::codec::ascii::should_suppress_ascii(meta, &response) {
return Ok(false);
}
match response {
Response::ValuesStream(mut stream_vals) => {
let include_cas = matches!(req.op, Op::Gets | Op::Gats);
while let Some(entry) = stream_vals.next() {
crate::codec::ascii::encode_value_entry(&entry, include_cas, write_buf);
if write_buf.len() >= cfg.write_batch_bytes {
flush_write_buf(stream, write_buf, cfg).await?;
}
}
write_buf.extend_from_slice(b"END\r\n");
}
Response::StatsStream(mut stream_lines) => {
while let Some(line) = stream_lines.next() {
crate::codec::ascii::encode_stat_line(&line, write_buf);
if write_buf.len() >= cfg.write_batch_bytes {
flush_write_buf(stream, write_buf, cfg).await?;
}
}
write_buf.extend_from_slice(b"END\r\n");
}
other => {
crate::codec::ascii::encode_ascii_response(req, meta, &other, write_buf);
}
}
} else if let Response::Meta(meta_resp) = &response {
crate::codec::ascii::encode_meta_response(req, meta, meta_resp, write_buf);
} else if let Response::Stats(lines) = response {
if meta.reply != ReplyMode::SuppressSuccess {
crate::codec::ascii::encode_meta_debug(req, lines, write_buf);
}
} else if let Response::StatsStream(mut stream_lines) = response {
if meta.reply != ReplyMode::SuppressSuccess {
let mut lines = Vec::new();
while let Some(line) = stream_lines.next() {
lines.push(line);
}
crate::codec::ascii::encode_meta_debug(req, lines, write_buf);
}
} else {
crate::codec::ascii::encode_ascii_response(req, meta, &response, write_buf);
}
if write_buf.len() >= cfg.write_batch_bytes {
flush_write_buf(stream, write_buf, cfg).await?;
}
}
Protocol::Binary => {
let quiet_mode = meta.reply == ReplyMode::QuietBuffered;
if !quiet_mode {
flush_quiet(stream, write_buf, quiet, cfg).await?;
}
match response {
Response::ValuesStream(mut stream_vals) => {
if let Some(entry) = stream_vals.next() {
let mut tmp = BytesMut::new();
let (status, _) = crate::codec::binary::encode_binary_response(
meta,
&Response::Value(entry),
&mut tmp,
meta.return_key,
);
let extra_close = handle_quiet_response(
QuietContext {
stream,
write_buf,
quiet,
cfg,
req,
meta,
},
status,
tmp,
)
.await?;
if extra_close {
return Ok(true);
}
} else {
let mut tmp = BytesMut::new();
let (status, _) = crate::codec::binary::encode_binary_response(
meta,
&Response::NotFound,
&mut tmp,
meta.return_key,
);
let extra_close = handle_quiet_response(
QuietContext {
stream,
write_buf,
quiet,
cfg,
req,
meta,
},
status,
tmp,
)
.await?;
if extra_close {
return Ok(true);
}
}
}
Response::Stats(lines) => {
if quiet_mode {
let mut tmp = BytesMut::new();
crate::codec::binary::encode_binary_response(
meta,
&Response::Stats(lines),
&mut tmp,
false,
);
if quiet.would_overflow(cfg, tmp.len()) {
flush_quiet(stream, write_buf, quiet, cfg).await?;
}
if quiet.would_overflow(cfg, tmp.len()) {
let mut err = BytesMut::new();
let meta = RequestMeta {
protocol: Protocol::Binary,
reply: ReplyMode::Always,
opaque: meta.opaque,
return_key: false,
opcode: meta.opcode,
};
crate::codec::binary::encode_binary_response(
meta,
&Response::Error(Error::server("quiet overflow")),
&mut err,
false,
);
write_buf.extend_from_slice(&err);
flush_write_buf(stream, write_buf, cfg).await?;
return Ok(true);
}
quiet.push(tmp.freeze());
} else {
let mut tmp = BytesMut::new();
crate::codec::binary::encode_binary_response(
meta,
&Response::Stats(lines),
&mut tmp,
false,
);
write_buf.extend_from_slice(&tmp);
}
}
Response::StatsStream(mut stream_lines) => {
let mut lines = Vec::new();
while let Some(line) = stream_lines.next() {
lines.push(line);
}
let mut tmp = BytesMut::new();
crate::codec::binary::encode_binary_response(
meta,
&Response::Stats(lines),
&mut tmp,
false,
);
if quiet_mode {
if quiet.would_overflow(cfg, tmp.len()) {
flush_quiet(stream, write_buf, quiet, cfg).await?;
}
if quiet.would_overflow(cfg, tmp.len()) {
let mut err = BytesMut::new();
let meta = RequestMeta {
protocol: Protocol::Binary,
reply: ReplyMode::Always,
opaque: meta.opaque,
return_key: false,
opcode: meta.opcode,
};
crate::codec::binary::encode_binary_response(
meta,
&Response::Error(Error::server("quiet overflow")),
&mut err,
false,
);
write_buf.extend_from_slice(&err);
flush_write_buf(stream, write_buf, cfg).await?;
return Ok(true);
}
quiet.push(tmp.freeze());
} else {
write_buf.extend_from_slice(&tmp);
}
}
other => {
let mut tmp = BytesMut::new();
let (status, _) = crate::codec::binary::encode_binary_response(
meta,
&other,
&mut tmp,
meta.return_key,
);
let extra_close = handle_quiet_response(
QuietContext {
stream,
write_buf,
quiet,
cfg,
req,
meta,
},
status,
tmp,
)
.await?;
if extra_close {
return Ok(true);
}
}
}
if write_buf.len() >= cfg.write_batch_bytes {
flush_write_buf(stream, write_buf, cfg).await?;
}
}
}
Ok(false)
}
struct QuietContext<'a> {
stream: &'a mut TcpStream,
write_buf: &'a mut BytesMut,
quiet: &'a mut QuietBuffer,
cfg: &'a ServerConfig,
req: &'a Request,
meta: RequestMeta,
}
async fn handle_quiet_response(
ctx: QuietContext<'_>,
status: u16,
tmp: BytesMut,
) -> std::io::Result<bool> {
let QuietContext {
stream,
write_buf,
quiet,
cfg,
req,
meta,
} = ctx;
if meta.reply != ReplyMode::QuietBuffered {
write_buf.extend_from_slice(&tmp);
return Ok(false);
}
let suppress = match req.op {
Op::Get => status == crate::codec::binary::STATUS_KEY_NOT_FOUND,
_ => status == crate::codec::binary::STATUS_SUCCESS,
};
if suppress {
return Ok(false);
}
if quiet.would_overflow(cfg, tmp.len()) {
flush_quiet(stream, write_buf, quiet, cfg).await?;
}
if quiet.would_overflow(cfg, tmp.len()) {
let mut err = BytesMut::new();
let meta = RequestMeta {
protocol: Protocol::Binary,
reply: ReplyMode::Always,
opaque: meta.opaque,
return_key: false,
opcode: meta.opcode,
};
crate::codec::binary::encode_binary_response(
meta,
&Response::Error(Error::server("quiet overflow")),
&mut err,
false,
);
write_buf.extend_from_slice(&err);
flush_write_buf(stream, write_buf, cfg).await?;
return Ok(true);
}
quiet.push(tmp.freeze());
Ok(false)
}
fn response_close(response: &Response) -> bool {
match response {
Response::Error(err) => err.close,
_ => false,
}
}
/// Estimates the payload size of `req` as the combined length of its single
/// key, its key list, its value, and every token attached to its meta flags.
fn estimate_request_bytes(req: &Request) -> usize {
    let single_key: usize = req.key.iter().map(|key| key.len()).sum();
    let key_list: usize = req.keys.iter().map(|key| key.len()).sum();
    let value: usize = req.value.iter().map(|value| value.len()).sum();
    let flag_tokens: usize = req
        .meta
        .iter()
        .flat_map(|meta| meta.ordered.iter())
        .filter_map(|flag| flag.token.as_ref())
        .map(|token| token.len())
        .sum();

    single_key + key_list + value + flag_tokens
}
/// Holds encoded replies to quiet requests until they are flushed.
struct QuietBuffer {
    /// Encoded replies in the order they were queued.
    entries: Vec<bytes::Bytes>,
    /// Running total of the lengths of `entries`.
    bytes: usize,
}

impl QuietBuffer {
    /// Creates an empty buffer.
    fn new() -> Self {
        Self {
            entries: Vec::new(),
            bytes: 0,
        }
    }

    /// Appends `value` and adds its length to the running byte total.
    fn push(&mut self, value: bytes::Bytes) {
        self.bytes = self.bytes.saturating_add(value.len());
        self.entries.push(value);
    }

    /// Drops every queued reply and resets the byte total.
    fn clear(&mut self) {
        self.entries.clear();
        self.bytes = 0;
    }

    /// Returns `true` if queuing one more reply of `add` bytes would exceed
    /// either `cfg.max_quiet_responses` or `cfg.max_quiet_bytes`.
    ///
    /// Written so neither comparison can overflow, matching the saturating
    /// accounting used by `push`.
    fn would_overflow(&self, cfg: &ServerConfig, add: usize) -> bool {
        let too_many = self.entries.len() >= cfg.max_quiet_responses;
        let too_large = self.bytes.saturating_add(add) > cfg.max_quiet_bytes;
        too_many || too_large
    }
}
/// Moves every queued quiet reply into `write_buf`, in order, and empties the
/// quiet buffer.
///
/// The write buffer is flushed to the socket each time it reaches
/// `cfg.write_batch_bytes`; whatever remains below that threshold stays in
/// `write_buf` for the caller to flush.
///
/// # Errors
/// Propagates socket write errors from the intermediate flushes.
async fn flush_quiet(
    stream: &mut TcpStream,
    write_buf: &mut BytesMut,
    quiet: &mut QuietBuffer,
    cfg: &ServerConfig,
) -> std::io::Result<()> {
    if !quiet.entries.is_empty() {
        for reply in quiet.entries.drain(..) {
            write_buf.extend_from_slice(&reply);
            let batch_full = write_buf.len() >= cfg.write_batch_bytes;
            if batch_full {
                flush_write_buf(stream, write_buf, cfg).await?;
            }
        }
        // `drain` emptied the entries; this also resets the byte total.
        quiet.clear();
    }
    Ok(())
}
/// Reads more bytes from `stream` into `buf` and returns how many were read.
///
/// The read is bounded by `cfg.idle_timeout` when set, otherwise by
/// `cfg.read_timeout`; with neither set it waits indefinitely.
///
/// # Errors
/// Returns the socket error, or the timeout converted into an I/O error.
async fn read_more(
    stream: &mut TcpStream,
    buf: &mut BytesMut,
    cfg: &ServerConfig,
) -> std::io::Result<usize> {
    match cfg.idle_timeout.or(cfg.read_timeout) {
        // The outer `?` turns an elapsed timeout into an I/O error; the inner
        // result is the read outcome itself.
        Some(limit) => tokio::time::timeout(limit, stream.read_buf(buf)).await?,
        None => stream.read_buf(buf).await,
    }
}
/// Writes the whole of `buf` to `stream` and then clears it; does nothing when
/// `buf` is empty.
///
/// The write is bounded by `cfg.write_timeout` when one is configured.
///
/// # Errors
/// Returns the socket error, or the timeout converted into an I/O error. `buf`
/// is left untouched in that case.
async fn flush_write_buf(
    stream: &mut TcpStream,
    buf: &mut BytesMut,
    cfg: &ServerConfig,
) -> std::io::Result<()> {
    if !buf.is_empty() {
        match cfg.write_timeout {
            Some(limit) => tokio::time::timeout(limit, stream.write_all(buf)).await??,
            None => stream.write_all(buf).await?,
        }
        buf.clear();
    }
    Ok(())
}
/// Creates a non-blocking listening socket on one concrete address with the
/// given backlog and hands it to tokio.
fn listen_on(addr: std::net::SocketAddr, backlog: i32) -> std::io::Result<TcpListener> {
    let socket = Socket::new(Domain::for_address(addr), Type::STREAM, None)?;
    socket.set_reuse_address(true)?;
    socket.bind(&addr.into())?;
    socket.listen(backlog)?;
    let listener: std::net::TcpListener = socket.into();
    // tokio requires the std listener to already be in non-blocking mode.
    listener.set_nonblocking(true)?;
    TcpListener::from_std(listener)
}

/// Binds a TCP listener on `addr` with an explicit listen `backlog`.
///
/// `addr` is resolved with `ToSocketAddrs`; each resolved address is tried in
/// turn until one binds, mirroring what `TcpListener::bind` does on the
/// no-backlog path. Backlog values above `i32::MAX` are clamped rather than
/// wrapping to a negative number.
///
/// # Errors
/// Returns the resolution error, the error from the last address tried, or
/// `InvalidInput` ("invalid addr") when resolution yields no addresses.
fn bind_with_backlog(addr: &str, backlog: u32) -> std::io::Result<TcpListener> {
    let backlog = backlog.min(i32::MAX as u32) as i32;
    let mut last_err = None;
    for candidate in addr.to_socket_addrs()? {
        match listen_on(candidate, backlog) {
            Ok(listener) => return Ok(listener),
            Err(err) => last_err = Some(err),
        }
    }
    Err(last_err.unwrap_or_else(|| {
        std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid addr")
    }))
}