1use std::collections::VecDeque;
2use std::net::ToSocketAddrs;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::time::Duration;
6
7use bytes::BytesMut;
8use socket2::{Domain, Socket, Type};
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10use tokio::net::{TcpListener, TcpStream};
11
12use crate::codec::{AsciiDecoder, AsciiLimits, BinaryDecoder, BinaryLimits, DecodeOutcome};
13use crate::context::{ConnectionInfo, Extensions, RequestContext};
14use crate::error::Error;
15use crate::response::Response;
16use crate::router::Router;
17use crate::types::{Op, Protocol, ReplyMode, Request, RequestMeta};
18
/// Source of per-connection `client_id` values: every accepted connection
/// takes the current value and increments it, so ids start at 1.
static NEXT_CLIENT_ID: AtomicU64 = AtomicU64::new(1);
20
/// Per-server limits, timeouts and socket options.
///
/// Build one with [`ServerConfig::builder`] or start from
/// [`ServerConfig::default`].
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Maximum length of an ASCII command line (handed to the ASCII decoder).
    pub max_line_len: usize,
    /// Maximum size of an ASCII data block (handed to the ASCII decoder).
    pub max_blob_len: usize,
    /// Maximum size of a binary-protocol frame (handed to the binary decoder).
    pub max_frame_len: usize,
    /// Maximum number of decoded but not yet answered requests per
    /// connection; decoding pauses once this many are queued.
    pub max_inflight_requests: usize,
    /// Byte budget for queued requests plus held-back quiet responses; a
    /// request that would exceed it is refused with an `inflight limit` error.
    pub max_inflight_bytes: usize,
    /// Maximum number of quiet binary responses held back at once.
    pub max_quiet_responses: usize,
    /// Maximum total size of quiet binary responses held back at once.
    pub max_quiet_bytes: usize,
    /// Write-buffer size at which pending output is flushed to the socket;
    /// also the buffer's initial capacity.
    pub write_batch_bytes: usize,
    /// Deadline for a socket read; `None` disables it. How it combines with
    /// `idle_timeout` is decided by the connection's read path.
    pub read_timeout: Option<Duration>,
    /// Deadline for writing one batch to the socket; `None` disables it.
    pub write_timeout: Option<Duration>,
    /// Read deadline intended for idle connections; `None` disables it.
    pub idle_timeout: Option<Duration>,
    /// Whether `TCP_NODELAY` is requested on every accepted connection.
    pub tcp_nodelay: bool,
    /// Listen backlog. `Some` makes the server create the listening socket
    /// itself with this backlog; `None` uses `TcpListener::bind`.
    pub backlog: Option<u32>,
}

impl Default for ServerConfig {
    /// 4 KiB command lines, 1 MiB ASCII data blocks, 2 MiB binary frames,
    /// 128 requests / 8 MiB in flight, 256 responses / 2 MiB of quiet
    /// responses, 8 KiB write batches, no timeouts, `TCP_NODELAY` on, and
    /// `TcpListener::bind`'s own backlog.
    fn default() -> Self {
        Self {
            max_line_len: 4 * 1024,
            max_blob_len: 1 << 20,
            max_frame_len: 2 * 1024 * 1024,
            max_inflight_requests: 128,
            max_inflight_bytes: 8 * 1024 * 1024,
            max_quiet_responses: 256,
            max_quiet_bytes: 2 * 1024 * 1024,
            write_batch_bytes: 8 * 1024,
            read_timeout: None,
            write_timeout: None,
            idle_timeout: None,
            tcp_nodelay: true,
            backlog: None,
        }
    }
}

impl ServerConfig {
    /// Returns a builder seeded with [`ServerConfig::default`].
    pub fn builder() -> ServerConfigBuilder {
        ServerConfigBuilder {
            cfg: ServerConfig::default(),
        }
    }
}

/// Fluent builder for [`ServerConfig`].
///
/// Every size/count setter clamps `0` up to `1`, so a configuration produced
/// by the builder never carries a zero limit. The builder is `Clone`, so a
/// partially configured builder can serve as a template.
#[derive(Debug, Clone)]
#[must_use = "a ServerConfigBuilder does nothing until `build` is called"]
pub struct ServerConfigBuilder {
    cfg: ServerConfig,
}

impl ServerConfigBuilder {
    /// Sets [`ServerConfig::max_line_len`] (minimum 1).
    pub fn max_line_len(mut self, value: usize) -> Self {
        self.cfg.max_line_len = value.max(1);
        self
    }

    /// Sets [`ServerConfig::max_blob_len`] (minimum 1).
    pub fn max_blob_len(mut self, value: usize) -> Self {
        self.cfg.max_blob_len = value.max(1);
        self
    }

    /// Sets [`ServerConfig::max_frame_len`] (minimum 1).
    pub fn max_frame_len(mut self, value: usize) -> Self {
        self.cfg.max_frame_len = value.max(1);
        self
    }

    /// Sets [`ServerConfig::max_inflight_requests`] (minimum 1).
    pub fn max_inflight_requests(mut self, value: usize) -> Self {
        self.cfg.max_inflight_requests = value.max(1);
        self
    }

    /// Sets [`ServerConfig::max_inflight_bytes`] (minimum 1).
    pub fn max_inflight_bytes(mut self, value: usize) -> Self {
        self.cfg.max_inflight_bytes = value.max(1);
        self
    }

    /// Sets [`ServerConfig::max_quiet_responses`] (minimum 1).
    pub fn max_quiet_responses(mut self, value: usize) -> Self {
        self.cfg.max_quiet_responses = value.max(1);
        self
    }

    /// Sets [`ServerConfig::max_quiet_bytes`] (minimum 1).
    pub fn max_quiet_bytes(mut self, value: usize) -> Self {
        self.cfg.max_quiet_bytes = value.max(1);
        self
    }

    /// Sets [`ServerConfig::write_batch_bytes`] (minimum 1).
    pub fn write_batch_bytes(mut self, value: usize) -> Self {
        self.cfg.write_batch_bytes = value.max(1);
        self
    }

    /// Sets [`ServerConfig::read_timeout`]; `None` disables it.
    pub fn read_timeout(mut self, value: Option<Duration>) -> Self {
        self.cfg.read_timeout = value;
        self
    }

    /// Sets [`ServerConfig::write_timeout`]; `None` disables it.
    pub fn write_timeout(mut self, value: Option<Duration>) -> Self {
        self.cfg.write_timeout = value;
        self
    }

    /// Sets [`ServerConfig::idle_timeout`]; `None` disables it.
    pub fn idle_timeout(mut self, value: Option<Duration>) -> Self {
        self.cfg.idle_timeout = value;
        self
    }

    /// Sets [`ServerConfig::tcp_nodelay`].
    pub fn tcp_nodelay(mut self, value: bool) -> Self {
        self.cfg.tcp_nodelay = value;
        self
    }

    /// Sets [`ServerConfig::backlog`]; `None` keeps the default listener.
    pub fn backlog(mut self, value: Option<u32>) -> Self {
        self.cfg.backlog = value;
        self
    }

    /// Returns the finished configuration.
    pub fn build(self) -> ServerConfig {
        self.cfg
    }
}
142
143pub struct Server;
145
146impl Server {
147 pub fn bind<A: ToString>(addr: A) -> ServerBuilder {
148 ServerBuilder {
149 addr: addr.to_string(),
150 cfg: ServerConfig::default(),
151 shutdown: None,
152 extensions_factory: Arc::new(|_| Extensions::default()),
153 }
154 }
155}
156
157pub struct ServerBuilder {
159 addr: String,
160 cfg: ServerConfig,
161 shutdown: Option<BoxFuture>,
162 extensions_factory: Arc<dyn Fn(ConnectionInfo) -> Extensions + Send + Sync>,
163}
164
165type BoxFuture = std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>;
166
167impl ServerBuilder {
168 pub fn with_config(mut self, cfg: ServerConfig) -> Self {
169 self.cfg = cfg;
170 self
171 }
172
173 pub fn with_graceful_shutdown<F>(mut self, fut: F) -> Self
174 where
175 F: std::future::Future<Output = ()> + Send + 'static,
176 {
177 self.shutdown = Some(Box::pin(fut));
178 self
179 }
180
181 pub fn with_connection_extensions<F>(mut self, factory: F) -> Self
182 where
183 F: Fn(ConnectionInfo) -> Extensions + Send + Sync + 'static,
184 {
185 self.extensions_factory = Arc::new(factory);
186 self
187 }
188
189 pub async fn serve<State>(self, app: Router<State>) -> std::io::Result<()>
190 where
191 State: Send + Sync + 'static,
192 {
193 let listener = if let Some(backlog) = self.cfg.backlog {
194 bind_with_backlog(&self.addr, backlog)?
195 } else {
196 TcpListener::bind(&self.addr).await?
197 };
198 self.serve_with_listener(listener, app).await
199 }
200
201 pub async fn serve_with_listener<State>(
202 self,
203 listener: TcpListener,
204 app: Router<State>,
205 ) -> std::io::Result<()>
206 where
207 State: Send + Sync + 'static,
208 {
209 let app = Arc::new(app);
210 if let Some(shutdown) = self.shutdown {
211 tokio::select! {
212 result = accept_loop(listener, app, self.cfg, self.extensions_factory) => result,
213 _ = shutdown => Ok(()),
214 }
215 } else {
216 accept_loop(listener, app, self.cfg, self.extensions_factory).await
217 }
218 }
219}
220
221async fn accept_loop<State>(
222 listener: TcpListener,
223 app: Arc<Router<State>>,
224 cfg: ServerConfig,
225 extensions_factory: Arc<dyn Fn(ConnectionInfo) -> Extensions + Send + Sync>,
226) -> std::io::Result<()>
227where
228 State: Send + Sync + 'static,
229{
230 loop {
231 let (stream, peer_addr) = listener.accept().await?;
232 let local_addr = stream.local_addr()?;
233 let client_id = NEXT_CLIENT_ID.fetch_add(1, Ordering::Relaxed);
234 let info = ConnectionInfo {
235 peer_addr,
236 local_addr,
237 client_id,
238 };
239 let extensions = (extensions_factory)(info);
240 let app = Arc::clone(&app);
241 let cfg = cfg.clone();
242 tokio::spawn(async move {
243 let _ = handle_connection(stream, app, cfg, info, extensions).await;
244 });
245 }
246}
247
248async fn handle_connection<State>(
249 mut stream: TcpStream,
250 app: Arc<Router<State>>,
251 cfg: ServerConfig,
252 info: ConnectionInfo,
253 base_extensions: Extensions,
254) -> std::io::Result<()>
255where
256 State: Send + Sync + 'static,
257{
258 if cfg.tcp_nodelay {
259 let _ = stream.set_nodelay(true);
260 }
261
262 let mut read_buf = BytesMut::with_capacity(4096);
263 let mut write_buf = BytesMut::with_capacity(cfg.write_batch_bytes);
264 let mut ascii = AsciiDecoder::new();
265 let mut binary = BinaryDecoder::new();
266 let mut protocol: Option<Protocol> = None;
267 let mut pending: VecDeque<(Request, RequestMeta, usize)> = VecDeque::new();
268 let mut pending_bytes: usize = 0;
269 let mut quiet = QuietBuffer::new();
270
271 loop {
272 loop {
274 if pending.len() >= cfg.max_inflight_requests {
275 break;
276 }
277 if protocol.is_none() {
278 if read_buf.is_empty() {
279 break;
280 }
281 protocol = Some(if read_buf[0] == 0x80 {
282 Protocol::Binary
283 } else {
284 Protocol::Ascii
285 });
286 }
287 let outcome = match protocol.unwrap() {
288 Protocol::Binary => binary.decode(
289 &mut read_buf,
290 BinaryLimits {
291 max_frame_len: cfg.max_frame_len,
292 },
293 ),
294 Protocol::Ascii | Protocol::Meta => ascii.decode(
295 &mut read_buf,
296 AsciiLimits {
297 max_line_len: cfg.max_line_len,
298 max_blob_len: cfg.max_blob_len,
299 },
300 ),
301 };
302 let Some(outcome) = outcome else {
303 break;
304 };
305 match outcome {
306 DecodeOutcome::Request(req, meta) => {
307 let est = estimate_request_bytes(&req);
308 pending_bytes = pending_bytes.saturating_add(est);
309 if pending_bytes + quiet.bytes > cfg.max_inflight_bytes {
310 let err = Response::Error(Error::server("inflight limit"));
311 let _ = send_response(
312 &mut stream,
313 &mut write_buf,
314 &mut quiet,
315 &cfg,
316 None,
317 meta,
318 err,
319 )
320 .await?;
321 return Ok(());
322 }
323 pending.push_back((req, meta, est));
324 }
325 DecodeOutcome::Response(meta, response) => {
326 let close = response_close(&response);
327 let extra_close = send_response(
328 &mut stream,
329 &mut write_buf,
330 &mut quiet,
331 &cfg,
332 None,
333 meta,
334 response,
335 )
336 .await?;
337 if close || extra_close {
338 return Ok(());
339 }
340 }
341 }
342 }
343
344 if let Some((req, meta, est)) = pending.pop_front() {
345 pending_bytes = pending_bytes.saturating_sub(est);
346 let ctx = RequestContext {
347 request: req.clone(),
348 meta,
349 peer_addr: info.peer_addr,
350 local_addr: info.local_addr,
351 client_id: info.client_id,
352 extensions: base_extensions.clone(),
353 };
354 let response = app.call(ctx).await;
355 let close = matches!(req.op, Op::Quit) || response_close(&response);
356 let extra_close = send_response(
357 &mut stream,
358 &mut write_buf,
359 &mut quiet,
360 &cfg,
361 Some(&req),
362 meta,
363 response,
364 )
365 .await?;
366 if close || extra_close {
367 flush_quiet(&mut stream, &mut write_buf, &mut quiet, &cfg).await?;
368 flush_write_buf(&mut stream, &mut write_buf, &cfg).await?;
369 return Ok(());
370 }
371 continue;
372 }
373
374 if !write_buf.is_empty() {
375 flush_write_buf(&mut stream, &mut write_buf, &cfg).await?;
376 }
377
378 let read = read_more(&mut stream, &mut read_buf, &cfg).await?;
379 if read == 0 {
380 return Ok(());
381 }
382 }
383}
384
385async fn send_response(
386 stream: &mut TcpStream,
387 write_buf: &mut BytesMut,
388 quiet: &mut QuietBuffer,
389 cfg: &ServerConfig,
390 req: Option<&Request>,
391 meta: RequestMeta,
392 response: Response,
393) -> std::io::Result<bool> {
394 let dummy_req = Request::new(Op::Unknown);
395 let req = req.unwrap_or(&dummy_req);
396 match meta.protocol {
397 Protocol::Ascii | Protocol::Meta => {
398 if meta.protocol == Protocol::Ascii {
399 if crate::codec::ascii::should_suppress_ascii(meta, &response) {
400 return Ok(false);
401 }
402 match response {
403 Response::ValuesStream(mut stream_vals) => {
404 let include_cas = matches!(req.op, Op::Gets | Op::Gats);
405 while let Some(entry) = stream_vals.next() {
406 crate::codec::ascii::encode_value_entry(&entry, include_cas, write_buf);
407 if write_buf.len() >= cfg.write_batch_bytes {
408 flush_write_buf(stream, write_buf, cfg).await?;
409 }
410 }
411 write_buf.extend_from_slice(b"END\r\n");
412 }
413 Response::StatsStream(mut stream_lines) => {
414 while let Some(line) = stream_lines.next() {
415 crate::codec::ascii::encode_stat_line(&line, write_buf);
416 if write_buf.len() >= cfg.write_batch_bytes {
417 flush_write_buf(stream, write_buf, cfg).await?;
418 }
419 }
420 write_buf.extend_from_slice(b"END\r\n");
421 }
422 other => {
423 crate::codec::ascii::encode_ascii_response(req, meta, &other, write_buf);
424 }
425 }
426 } else if let Response::Meta(meta_resp) = &response {
427 crate::codec::ascii::encode_meta_response(req, meta, meta_resp, write_buf);
428 } else if let Response::Stats(lines) = response {
429 if meta.reply != ReplyMode::SuppressSuccess {
430 crate::codec::ascii::encode_meta_debug(req, lines, write_buf);
431 }
432 } else if let Response::StatsStream(mut stream_lines) = response {
433 if meta.reply != ReplyMode::SuppressSuccess {
434 let mut lines = Vec::new();
435 while let Some(line) = stream_lines.next() {
436 lines.push(line);
437 }
438 crate::codec::ascii::encode_meta_debug(req, lines, write_buf);
439 }
440 } else {
441 crate::codec::ascii::encode_ascii_response(req, meta, &response, write_buf);
442 }
443
444 if write_buf.len() >= cfg.write_batch_bytes {
445 flush_write_buf(stream, write_buf, cfg).await?;
446 }
447 }
448 Protocol::Binary => {
449 let quiet_mode = meta.reply == ReplyMode::QuietBuffered;
450 if !quiet_mode {
451 flush_quiet(stream, write_buf, quiet, cfg).await?;
452 }
453
454 match response {
455 Response::ValuesStream(mut stream_vals) => {
456 if let Some(entry) = stream_vals.next() {
457 let mut tmp = BytesMut::new();
458 let (status, _) = crate::codec::binary::encode_binary_response(
459 meta,
460 &Response::Value(entry),
461 &mut tmp,
462 meta.return_key,
463 );
464 let extra_close = handle_quiet_response(
465 QuietContext {
466 stream,
467 write_buf,
468 quiet,
469 cfg,
470 req,
471 meta,
472 },
473 status,
474 tmp,
475 )
476 .await?;
477 if extra_close {
478 return Ok(true);
479 }
480 } else {
481 let mut tmp = BytesMut::new();
482 let (status, _) = crate::codec::binary::encode_binary_response(
483 meta,
484 &Response::NotFound,
485 &mut tmp,
486 meta.return_key,
487 );
488 let extra_close = handle_quiet_response(
489 QuietContext {
490 stream,
491 write_buf,
492 quiet,
493 cfg,
494 req,
495 meta,
496 },
497 status,
498 tmp,
499 )
500 .await?;
501 if extra_close {
502 return Ok(true);
503 }
504 }
505 }
506 Response::Stats(lines) => {
507 if quiet_mode {
508 let mut tmp = BytesMut::new();
509 crate::codec::binary::encode_binary_response(
510 meta,
511 &Response::Stats(lines),
512 &mut tmp,
513 false,
514 );
515 if quiet.would_overflow(cfg, tmp.len()) {
516 flush_quiet(stream, write_buf, quiet, cfg).await?;
517 }
518 if quiet.would_overflow(cfg, tmp.len()) {
519 let mut err = BytesMut::new();
520 let meta = RequestMeta {
521 protocol: Protocol::Binary,
522 reply: ReplyMode::Always,
523 opaque: meta.opaque,
524 return_key: false,
525 opcode: meta.opcode,
526 };
527 crate::codec::binary::encode_binary_response(
528 meta,
529 &Response::Error(Error::server("quiet overflow")),
530 &mut err,
531 false,
532 );
533 write_buf.extend_from_slice(&err);
534 flush_write_buf(stream, write_buf, cfg).await?;
535 return Ok(true);
536 }
537 quiet.push(tmp.freeze());
538 } else {
539 let mut tmp = BytesMut::new();
540 crate::codec::binary::encode_binary_response(
541 meta,
542 &Response::Stats(lines),
543 &mut tmp,
544 false,
545 );
546 write_buf.extend_from_slice(&tmp);
547 }
548 }
549 Response::StatsStream(mut stream_lines) => {
550 let mut lines = Vec::new();
551 while let Some(line) = stream_lines.next() {
552 lines.push(line);
553 }
554 let mut tmp = BytesMut::new();
555 crate::codec::binary::encode_binary_response(
556 meta,
557 &Response::Stats(lines),
558 &mut tmp,
559 false,
560 );
561 if quiet_mode {
562 if quiet.would_overflow(cfg, tmp.len()) {
563 flush_quiet(stream, write_buf, quiet, cfg).await?;
564 }
565 if quiet.would_overflow(cfg, tmp.len()) {
566 let mut err = BytesMut::new();
567 let meta = RequestMeta {
568 protocol: Protocol::Binary,
569 reply: ReplyMode::Always,
570 opaque: meta.opaque,
571 return_key: false,
572 opcode: meta.opcode,
573 };
574 crate::codec::binary::encode_binary_response(
575 meta,
576 &Response::Error(Error::server("quiet overflow")),
577 &mut err,
578 false,
579 );
580 write_buf.extend_from_slice(&err);
581 flush_write_buf(stream, write_buf, cfg).await?;
582 return Ok(true);
583 }
584 quiet.push(tmp.freeze());
585 } else {
586 write_buf.extend_from_slice(&tmp);
587 }
588 }
589 other => {
590 let mut tmp = BytesMut::new();
591 let (status, _) = crate::codec::binary::encode_binary_response(
592 meta,
593 &other,
594 &mut tmp,
595 meta.return_key,
596 );
597 let extra_close = handle_quiet_response(
598 QuietContext {
599 stream,
600 write_buf,
601 quiet,
602 cfg,
603 req,
604 meta,
605 },
606 status,
607 tmp,
608 )
609 .await?;
610 if extra_close {
611 return Ok(true);
612 }
613 }
614 }
615
616 if write_buf.len() >= cfg.write_batch_bytes {
617 flush_write_buf(stream, write_buf, cfg).await?;
618 }
619 }
620 }
621 Ok(false)
622}
623
624struct QuietContext<'a> {
625 stream: &'a mut TcpStream,
626 write_buf: &'a mut BytesMut,
627 quiet: &'a mut QuietBuffer,
628 cfg: &'a ServerConfig,
629 req: &'a Request,
630 meta: RequestMeta,
631}
632
633async fn handle_quiet_response(
634 ctx: QuietContext<'_>,
635 status: u16,
636 tmp: BytesMut,
637) -> std::io::Result<bool> {
638 let QuietContext {
639 stream,
640 write_buf,
641 quiet,
642 cfg,
643 req,
644 meta,
645 } = ctx;
646 if meta.reply != ReplyMode::QuietBuffered {
647 write_buf.extend_from_slice(&tmp);
648 return Ok(false);
649 }
650
651 let suppress = match req.op {
652 Op::Get => status == crate::codec::binary::STATUS_KEY_NOT_FOUND,
653 _ => status == crate::codec::binary::STATUS_SUCCESS,
654 };
655
656 if suppress {
657 return Ok(false);
658 }
659
660 if quiet.would_overflow(cfg, tmp.len()) {
661 flush_quiet(stream, write_buf, quiet, cfg).await?;
662 }
663 if quiet.would_overflow(cfg, tmp.len()) {
664 let mut err = BytesMut::new();
665 let meta = RequestMeta {
666 protocol: Protocol::Binary,
667 reply: ReplyMode::Always,
668 opaque: meta.opaque,
669 return_key: false,
670 opcode: meta.opcode,
671 };
672 crate::codec::binary::encode_binary_response(
673 meta,
674 &Response::Error(Error::server("quiet overflow")),
675 &mut err,
676 false,
677 );
678 write_buf.extend_from_slice(&err);
679 flush_write_buf(stream, write_buf, cfg).await?;
680 return Ok(true);
681 }
682 quiet.push(tmp.freeze());
683 Ok(false)
684}
685
686fn response_close(response: &Response) -> bool {
687 match response {
688 Response::Error(err) => err.close,
689 _ => false,
690 }
691}
692
693fn estimate_request_bytes(req: &Request) -> usize {
694 let mut total = 0usize;
695 if let Some(key) = &req.key {
696 total += key.len();
697 }
698 for key in &req.keys {
699 total += key.len();
700 }
701 if let Some(value) = &req.value {
702 total += value.len();
703 }
704 if let Some(meta) = &req.meta {
705 for flag in &meta.ordered {
706 if let Some(token) = &flag.token {
707 total += token.len();
708 }
709 }
710 }
711 total
712}
713
714struct QuietBuffer {
715 entries: Vec<bytes::Bytes>,
716 bytes: usize,
717}
718
719impl QuietBuffer {
720 fn new() -> Self {
721 Self {
722 entries: Vec::new(),
723 bytes: 0,
724 }
725 }
726
727 fn push(&mut self, value: bytes::Bytes) {
728 self.bytes = self.bytes.saturating_add(value.len());
729 self.entries.push(value);
730 }
731
732 fn clear(&mut self) {
733 self.entries.clear();
734 self.bytes = 0;
735 }
736
737 fn would_overflow(&self, cfg: &ServerConfig, add: usize) -> bool {
738 self.entries.len() + 1 > cfg.max_quiet_responses || self.bytes + add > cfg.max_quiet_bytes
739 }
740}
741
742async fn flush_quiet(
743 stream: &mut TcpStream,
744 write_buf: &mut BytesMut,
745 quiet: &mut QuietBuffer,
746 cfg: &ServerConfig,
747) -> std::io::Result<()> {
748 if quiet.entries.is_empty() {
749 return Ok(());
750 }
751 for entry in quiet.entries.drain(..) {
752 write_buf.extend_from_slice(&entry);
753 if write_buf.len() >= cfg.write_batch_bytes {
754 flush_write_buf(stream, write_buf, cfg).await?;
755 }
756 }
757 quiet.clear();
758 Ok(())
759}
760
761async fn read_more(
762 stream: &mut TcpStream,
763 buf: &mut BytesMut,
764 cfg: &ServerConfig,
765) -> std::io::Result<usize> {
766 let timeout = cfg.idle_timeout.or(cfg.read_timeout);
767 if let Some(timeout) = timeout {
768 Ok(tokio::time::timeout(timeout, stream.read_buf(buf)).await??)
769 } else {
770 stream.read_buf(buf).await
771 }
772}
773
774async fn flush_write_buf(
775 stream: &mut TcpStream,
776 buf: &mut BytesMut,
777 cfg: &ServerConfig,
778) -> std::io::Result<()> {
779 if buf.is_empty() {
780 return Ok(());
781 }
782 if let Some(timeout) = cfg.write_timeout {
783 tokio::time::timeout(timeout, stream.write_all(buf)).await??;
784 } else {
785 stream.write_all(buf).await?;
786 }
787 buf.clear();
788 Ok(())
789}
790
791fn bind_with_backlog(addr: &str, backlog: u32) -> std::io::Result<TcpListener> {
792 let addr = addr
793 .to_socket_addrs()?
794 .next()
795 .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid addr"))?;
796 let socket = Socket::new(Domain::for_address(addr), Type::STREAM, None)?;
797 socket.set_reuse_address(true)?;
798 socket.bind(&addr.into())?;
799 socket.listen(backlog as i32)?;
800 let listener: std::net::TcpListener = socket.into();
801 listener.set_nonblocking(true)?;
802 TcpListener::from_std(listener)
803}