1use std::collections::VecDeque;
2use std::net::{SocketAddr, ToSocketAddrs};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
5use std::time::Duration;
6
7use bytes::BytesMut;
8use log::error;
9use socket2::{Domain, Protocol, SockAddr, Socket, Type};
10use tokio::io::{AsyncReadExt, AsyncWriteExt};
11use tokio::net::{TcpListener, TcpStream};
12use tokio::sync::{Semaphore, broadcast, mpsc, oneshot};
13use tokio::task::JoinSet;
14
15use crate::context::{Command, Extensions, PubSubHandle, PushHandle, RequestContext};
16use crate::error::Result;
17use crate::resp::{DecodeLimits, Value, ValueDecoder};
18use crate::response::{IntoResponse, RespError};
19use crate::router::Router;
20
/// Identity and socket endpoints of one accepted client connection.
///
/// A copy is passed to the [`ServerHooks`] open/close callbacks and to the
/// per-connection extensions factory.
#[derive(Clone, Copy, Debug)]
pub struct ConnectionInfo {
    /// Identifier the accept loop assigned to this connection.
    pub id: u64,
    /// Remote address reported by the client socket.
    pub peer_addr: SocketAddr,
    /// Local address reported by the client socket.
    pub local_addr: SocketAddr,
}
28
29pub trait ServerHooks: Send + Sync + 'static {
31 fn on_accept_error(&self, _err: &std::io::Error) {}
32 fn on_connection_open(&self, _info: ConnectionInfo) {}
33 fn on_connection_close(&self, _info: ConnectionInfo) {}
34 fn on_command(&self, _id: u64, _command: &Command) {}
35 fn on_protocol_error(&self, _err: &RespError) {}
36 fn on_io_error(&self, _err: &std::io::Error) {}
37}
38
39#[derive(Debug, Default, Clone)]
41pub struct NoopServerHooks;
42
43impl ServerHooks for NoopServerHooks {}
44
45struct ConnectionGuard {
46 hooks: Arc<dyn ServerHooks>,
47 info: ConnectionInfo,
48}
49
50impl Drop for ConnectionGuard {
51 fn drop(&mut self) {
52 self.hooks.on_connection_close(self.info);
53 }
54}
55
56type ExtensionsFactory = Arc<dyn Fn(ConnectionInfo) -> Extensions + Send + Sync + 'static>;
57
58fn default_extensions(_info: ConnectionInfo) -> Extensions {
59 Extensions::default()
60}
61
/// Tunable limits, timeouts and socket options for the server.
///
/// Build one with [`ServerConfig::builder`] or start from
/// [`ServerConfig::default`].
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Largest amount of buffered, not-yet-decoded request bytes tolerated
    /// after a read; exceeding it ends the connection with an error reply.
    pub max_frame_size: usize,
    /// Forwarded to the decoder as `DecodeLimits::max_bulk_len`.
    pub max_bulk_len: usize,
    /// Forwarded to the decoder as `DecodeLimits::max_array_len`.
    pub max_array_len: usize,
    /// Forwarded to the decoder as `DecodeLimits::max_depth`.
    pub max_depth: usize,
    /// Cap on commands decoded from the read buffer and queued ahead of
    /// execution on one connection.
    pub max_inflight_requests: usize,
    /// Number of connection permits; a socket accepted while no permit is
    /// free is dropped.
    pub max_connections: usize,
    /// Timeout for one socket read. Only consulted when `idle_timeout` is
    /// `None`.
    pub read_timeout: Option<Duration>,
    /// Timeout for one flush of the write buffer to the socket.
    pub write_timeout: Option<Duration>,
    /// Timeout for one socket read; takes precedence over `read_timeout`.
    pub idle_timeout: Option<Duration>,
    /// Capacity of the per-connection push message channel.
    pub push_queue_len: usize,
    /// Capacity of the per-connection response channel.
    pub response_queue_len: usize,
    /// Initial capacity of the write buffer and the size at which the writer
    /// flushes while draining queued messages.
    pub write_batch_bytes: usize,
    /// Whether `set_nodelay(true)` is attempted on accepted sockets.
    pub tcp_nodelay: bool,
    /// When set, the listener is created with this listen backlog; otherwise
    /// `TcpListener::bind` is used.
    pub backlog: Option<u32>,
}
80
81impl Default for ServerConfig {
82 fn default() -> Self {
83 Self {
84 max_frame_size: 1 << 20,
85 max_bulk_len: 1 << 20,
86 max_array_len: 1024,
87 max_depth: 16,
88 max_inflight_requests: 128,
89 max_connections: 1024,
90 read_timeout: None,
91 write_timeout: None,
92 idle_timeout: None,
93 push_queue_len: 1024,
94 response_queue_len: 1024,
95 write_batch_bytes: 8 * 1024,
96 tcp_nodelay: true,
97 backlog: None,
98 }
99 }
100}
101
102#[derive(Clone, Debug)]
104pub struct ServerConfigBuilder {
105 cfg: ServerConfig,
106}
107
108impl ServerConfig {
109 pub fn builder() -> ServerConfigBuilder {
110 ServerConfigBuilder {
111 cfg: ServerConfig::default(),
112 }
113 }
114}
115
116impl ServerConfigBuilder {
117 pub fn max_frame_size(mut self, value: usize) -> Self {
119 self.cfg.max_frame_size = value.max(1);
120 self
121 }
122
123 pub fn max_bulk_len(mut self, value: usize) -> Self {
125 self.cfg.max_bulk_len = value.max(1);
126 self
127 }
128
129 pub fn max_array_len(mut self, value: usize) -> Self {
131 self.cfg.max_array_len = value.max(1);
132 self
133 }
134
135 pub fn max_depth(mut self, value: usize) -> Self {
137 self.cfg.max_depth = value.max(1);
138 self
139 }
140
141 pub fn max_inflight_requests(mut self, value: usize) -> Self {
143 self.cfg.max_inflight_requests = value.max(1);
144 self
145 }
146
147 pub fn max_connections(mut self, value: usize) -> Self {
149 self.cfg.max_connections = value.max(1);
150 self
151 }
152
153 pub fn read_timeout(mut self, value: Option<Duration>) -> Self {
155 self.cfg.read_timeout = value;
156 self
157 }
158
159 pub fn write_timeout(mut self, value: Option<Duration>) -> Self {
161 self.cfg.write_timeout = value;
162 self
163 }
164
165 pub fn idle_timeout(mut self, value: Option<Duration>) -> Self {
167 self.cfg.idle_timeout = value;
168 self
169 }
170
171 pub fn push_queue_len(mut self, value: usize) -> Self {
173 self.cfg.push_queue_len = value.max(1);
174 self
175 }
176
177 pub fn response_queue_len(mut self, value: usize) -> Self {
179 self.cfg.response_queue_len = value.max(1);
180 self
181 }
182
183 pub fn write_batch_bytes(mut self, value: usize) -> Self {
185 self.cfg.write_batch_bytes = value.max(1);
186 self
187 }
188
189 pub fn tcp_nodelay(mut self, value: bool) -> Self {
191 self.cfg.tcp_nodelay = value;
192 self
193 }
194
195 pub fn backlog(mut self, value: Option<u32>) -> Self {
197 self.cfg.backlog = value;
198 self
199 }
200
201 pub fn build(self) -> ServerConfig {
203 self.cfg
204 }
205}
206
207pub struct ServerBuilder {
209 addr: String,
210 cfg: ServerConfig,
211 shutdown: Option<BoxFuture>,
212 hooks: Arc<dyn ServerHooks>,
213 extensions_factory: ExtensionsFactory,
214}
215
216type BoxFuture = std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>;
217
218impl ServerBuilder {
219 pub fn with_config(mut self, cfg: ServerConfig) -> Self {
221 self.cfg = cfg;
222 self
223 }
224
225 pub fn with_graceful_shutdown<F>(mut self, fut: F) -> Self
227 where
228 F: std::future::Future<Output = ()> + Send + 'static,
229 {
230 self.shutdown = Some(Box::pin(fut));
231 self
232 }
233
234 pub fn with_hooks<H>(mut self, hooks: H) -> Self
236 where
237 H: ServerHooks,
238 {
239 self.hooks = Arc::new(hooks);
240 self
241 }
242
243 pub fn with_connection_extensions<F>(mut self, factory: F) -> Self
248 where
249 F: Fn(ConnectionInfo) -> Extensions + Send + Sync + 'static,
250 {
251 self.extensions_factory = Arc::new(factory);
252 self
253 }
254
255 pub async fn serve<State>(self, app: Router<State>) -> Result<()>
256 where
257 State: Send + Sync + 'static,
258 {
259 let listener = if let Some(backlog) = self.cfg.backlog {
260 bind_with_backlog(&self.addr, backlog)?
261 } else {
262 TcpListener::bind(&self.addr).await?
263 };
264 self.serve_with_listener(listener, app).await
265 }
266
267 pub async fn serve_with_listener<State>(
269 self,
270 listener: TcpListener,
271 app: Router<State>,
272 ) -> Result<()>
273 where
274 State: Send + Sync + 'static,
275 {
276 run_server(
277 listener,
278 app,
279 self.cfg,
280 self.shutdown,
281 self.hooks,
282 self.extensions_factory,
283 )
284 .await
285 }
286}
287
288pub struct Server;
290
291impl Server {
292 pub fn bind(addr: impl Into<String>) -> ServerBuilder {
294 ServerBuilder {
295 addr: addr.into(),
296 cfg: ServerConfig::default(),
297 shutdown: None,
298 hooks: Arc::new(NoopServerHooks),
299 extensions_factory: Arc::new(default_extensions),
300 }
301 }
302}
303
304fn bind_with_backlog(addr: &str, backlog: u32) -> Result<TcpListener> {
305 let mut addrs = addr.to_socket_addrs()?;
306 let addr = addrs.next().ok_or_else(|| {
307 std::io::Error::new(std::io::ErrorKind::InvalidInput, "empty bind address")
308 })?;
309 let domain = Domain::for_address(addr);
310 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
311 socket.set_nonblocking(true)?;
312 socket.bind(&SockAddr::from(addr))?;
313 let backlog = backlog.max(1).min(i32::MAX as u32) as i32;
314 socket.listen(backlog)?;
315 let listener: std::net::TcpListener = socket.into();
316 TcpListener::from_std(listener)
317}
318
319async fn run_server<State>(
320 listener: TcpListener,
321 app: Router<State>,
322 cfg: ServerConfig,
323 shutdown: Option<BoxFuture>,
324 hooks: Arc<dyn ServerHooks>,
325 extensions_factory: ExtensionsFactory,
326) -> Result<()>
327where
328 State: Send + Sync + 'static,
329{
330 let (shutdown_tx, _) = broadcast::channel(1);
331 let mut shutdown_rx = shutdown_tx.subscribe();
332 let mut join_set = JoinSet::new();
333 let semaphore = Arc::new(Semaphore::new(cfg.max_connections));
334 let client_id = Arc::new(AtomicU64::new(1));
335
336 let shutdown_fut = async move {
337 if let Some(fut) = shutdown {
338 fut.await;
339 } else {
340 std::future::pending::<()>().await;
341 }
342 };
343 let mut shutdown_fut = Box::pin(shutdown_fut);
344
345 loop {
346 tokio::select! {
347 _ = &mut shutdown_fut => {
348 break;
349 }
350 _ = shutdown_rx.recv() => {
351 break;
352 }
353 accept = listener.accept() => {
354 let (socket, _) = match accept {
355 Ok(value) => value,
356 Err(err) => {
357 error!("accept error: {:?}", err);
358 hooks.on_accept_error(&err);
359 continue;
360 }
361 };
362
363 let permit = match semaphore.clone().try_acquire_owned() {
364 Ok(permit) => permit,
365 Err(_) => {
366 continue;
367 }
368 };
369
370 if cfg.tcp_nodelay {
371 let _ = socket.set_nodelay(true);
372 }
373
374 let handler = app.clone();
375 let cfg = cfg.clone();
376 let shutdown_rx = shutdown_tx.subscribe();
377 let id = client_id.fetch_add(1, Ordering::AcqRel);
378 let hooks = Arc::clone(&hooks);
379 let extensions_factory = Arc::clone(&extensions_factory);
380 join_set.spawn(async move {
381 let _permit = permit;
382 if let Err(err) =
383 run_connection(
384 id,
385 socket,
386 handler,
387 cfg,
388 shutdown_rx,
389 hooks.clone(),
390 extensions_factory,
391 )
392 .await
393 {
394 error!("connection error: {:?}", err);
395 hooks.on_io_error(&err);
396 }
397 });
398 }
399 }
400 }
401
402 let _ = shutdown_tx.send(());
403
404 while let Some(res) = join_set.join_next().await {
405 if let Err(err) = res {
406 error!("connection task error: {:?}", err);
407 }
408 }
409
410 Ok(())
411}
412
413async fn run_connection<State>(
414 id: u64,
415 socket: TcpStream,
416 app: Router<State>,
417 cfg: ServerConfig,
418 mut shutdown: broadcast::Receiver<()>,
419 hooks: Arc<dyn ServerHooks>,
420 extensions_factory: ExtensionsFactory,
421) -> Result<()>
422where
423 State: Send + Sync + 'static,
424{
425 let peer_addr = socket.peer_addr()?;
426 let local_addr = socket.local_addr()?;
427 let info = ConnectionInfo {
428 id,
429 peer_addr,
430 local_addr,
431 };
432 hooks.on_connection_open(info);
433 let _guard = ConnectionGuard {
434 hooks: Arc::clone(&hooks),
435 info,
436 };
437 let (mut reader, writer) = socket.into_split();
438
439 let (resp_tx, resp_rx) = mpsc::channel(cfg.response_queue_len);
440 let (push_tx, push_rx) = mpsc::channel(cfg.push_queue_len);
441 let (close_tx, mut close_rx) = mpsc::channel(1);
442 let (writer_close_tx, writer_close_rx) = oneshot::channel();
443 let mut writer_close_tx = Some(writer_close_tx);
444
445 let push_handle = PushHandle::new(push_tx, close_tx);
446 let pubsub_count = Arc::new(AtomicUsize::new(0));
447 let pubsub_handle = PubSubHandle::new(pubsub_count.clone());
448 let extensions = (extensions_factory)(info);
449
450 let writer_cfg = cfg.clone();
451 let writer_task = tokio::spawn(async move {
452 writer_loop(writer, resp_rx, push_rx, writer_close_rx, writer_cfg).await
453 });
454
455 let mut rd = BytesMut::with_capacity(4096);
456 let mut decoder = ValueDecoder::new(DecodeLimits {
457 max_bulk_len: cfg.max_bulk_len,
458 max_array_len: cfg.max_array_len,
459 max_depth: cfg.max_depth,
460 });
461 let mut inflight = VecDeque::new();
462
463 loop {
464 while inflight.len() < cfg.max_inflight_requests {
465 match decoder.try_decode(&mut rd) {
466 Ok(Some(value)) => {
467 let command = match Command::from_value(value) {
468 Ok(cmd) => cmd,
469 Err(err) => {
470 hooks.on_protocol_error(&err);
471 let _ = resp_tx.send(err.into_response()).await;
472 signal_writer_close(&mut writer_close_tx);
473 return Ok(());
474 }
475 };
476
477 if pubsub_count.load(Ordering::Acquire) > 0
478 && !is_pubsub_allowed(&command.name_upper)
479 {
480 let err = RespError::invalid_data(
481 "ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context",
482 );
483 hooks.on_protocol_error(&err);
484 let _ = resp_tx.send(err.into_response()).await;
485 continue;
486 }
487
488 hooks.on_command(id, &command);
489 inflight.push_back(command);
490 }
491 Ok(None) => break,
492 Err(err) => {
493 hooks.on_protocol_error(&err);
494 let _ = resp_tx.send(err.into_response()).await;
495 signal_writer_close(&mut writer_close_tx);
496 return Ok(());
497 }
498 }
499 }
500
501 if let Some(command) = inflight.pop_front() {
502 let close_after = command.name_upper.as_ref() == b"QUIT";
503 let ctx = RequestContext {
504 command,
505 peer_addr,
506 local_addr,
507 client_id: id,
508 extensions: extensions.clone(),
509 push: push_handle.clone(),
510 pubsub: pubsub_handle.clone(),
511 };
512 let response = app.call(ctx).await;
513 if resp_tx.send(response).await.is_err() {
514 break;
515 }
516 if close_after {
517 signal_writer_close(&mut writer_close_tx);
518 break;
519 }
520 continue;
521 }
522
523 tokio::select! {
524 _ = shutdown.recv() => {
525 signal_writer_close(&mut writer_close_tx);
526 break;
527 }
528 _ = close_rx.recv() => {
529 let err = RespError::invalid_data("ERR client output buffer limit reached");
530 hooks.on_protocol_error(&err);
531 let _ = resp_tx.try_send(err.into_response());
532 signal_writer_close(&mut writer_close_tx);
533 break;
534 }
535 read = read_more(&mut reader, &mut rd, cfg.read_timeout, cfg.idle_timeout) => {
536 let read = read?;
537 if read == 0 {
538 if rd.is_empty() {
539 signal_writer_close(&mut writer_close_tx);
540 break;
541 }
542 let err = RespError::invalid_data("ERR unexpected EOF");
543 hooks.on_protocol_error(&err);
544 let _ = resp_tx.send(err.into_response()).await;
545 signal_writer_close(&mut writer_close_tx);
546 break;
547 }
548 if rd.len() > cfg.max_frame_size {
549 let err = RespError::invalid_data("ERR max frame size exceeded");
550 hooks.on_protocol_error(&err);
551 let _ = resp_tx.send(err.into_response()).await;
552 signal_writer_close(&mut writer_close_tx);
553 break;
554 }
555 }
556 }
557 }
558
559 signal_writer_close(&mut writer_close_tx);
560 drop(resp_tx);
561 drop(push_handle);
562
563 let _ = writer_task.await;
564 Ok(())
565}
566
567async fn read_more(
568 reader: &mut tokio::net::tcp::OwnedReadHalf,
569 buf: &mut BytesMut,
570 read_timeout: Option<Duration>,
571 idle_timeout: Option<Duration>,
572) -> Result<usize> {
573 let fut = reader.read_buf(buf);
574 let timeout = idle_timeout.or(read_timeout);
575 if let Some(timeout) = timeout {
576 match tokio::time::timeout(timeout, fut).await {
577 Ok(res) => res,
578 Err(_) => Err(std::io::ErrorKind::TimedOut.into()),
579 }
580 } else {
581 fut.await
582 }
583}
584
585async fn writer_loop(
586 mut writer: tokio::net::tcp::OwnedWriteHalf,
587 mut resp_rx: mpsc::Receiver<Value>,
588 mut push_rx: mpsc::Receiver<Value>,
589 mut close_rx: oneshot::Receiver<()>,
590 cfg: ServerConfig,
591) -> Result<()> {
592 let mut buf = BytesMut::with_capacity(cfg.write_batch_bytes);
593 let mut resp_closed = false;
594 let mut push_closed = false;
595 let mut closing = false;
596
597 loop {
598 let mut drained_response = false;
599 while let Ok(value) = resp_rx.try_recv() {
600 drained_response = true;
601 value.encode(&mut buf);
602 if buf.len() >= cfg.write_batch_bytes {
603 flush_buffer(&mut writer, &mut buf, cfg.write_timeout).await?;
604 }
605 }
606 if !drained_response && !closing {
607 while let Ok(value) = push_rx.try_recv() {
608 value.encode(&mut buf);
609 if buf.len() >= cfg.write_batch_bytes {
610 break;
611 }
612 }
613 }
614 if !buf.is_empty() {
615 flush_buffer(&mut writer, &mut buf, cfg.write_timeout).await?;
616 }
617
618 if closing {
619 break;
620 }
621
622 if resp_closed && push_closed {
623 break;
624 }
625
626 tokio::select! {
627 biased;
628 _ = &mut close_rx => {
629 closing = true;
630 }
631 res = resp_rx.recv() => {
632 match res {
633 Some(value) => value.encode(&mut buf),
634 None => resp_closed = true,
635 }
636 }
637 res = push_rx.recv(), if !closing => {
638 match res {
639 Some(value) => value.encode(&mut buf),
640 None => push_closed = true,
641 }
642 }
643 }
644
645 if !buf.is_empty() {
646 flush_buffer(&mut writer, &mut buf, cfg.write_timeout).await?;
647 }
648 }
649
650 Ok(())
651}
652
653fn signal_writer_close(tx: &mut Option<oneshot::Sender<()>>) {
654 if let Some(tx) = tx.take() {
655 let _ = tx.send(());
656 }
657}
658
659async fn flush_buffer(
660 writer: &mut tokio::net::tcp::OwnedWriteHalf,
661 buf: &mut BytesMut,
662 timeout: Option<Duration>,
663) -> Result<()> {
664 if buf.is_empty() {
665 return Ok(());
666 }
667 let write = writer.write_all(buf);
668 if let Some(timeout) = timeout {
669 match tokio::time::timeout(timeout, write).await {
670 Ok(res) => res?,
671 Err(_) => return Err(std::io::ErrorKind::TimedOut.into()),
672 }
673 } else {
674 write.await?;
675 }
676 buf.clear();
677 Ok(())
678}
679
680fn is_pubsub_allowed(cmd: &bytes::Bytes) -> bool {
681 matches!(
682 cmd.as_ref(),
683 b"SUBSCRIBE" | b"PSUBSCRIBE" | b"UNSUBSCRIBE" | b"PUNSUBSCRIBE" | b"PING" | b"QUIT"
684 )
685}