// resp_async/io.rs

1use std::collections::VecDeque;
2use std::net::{SocketAddr, ToSocketAddrs};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
5use std::time::Duration;
6
7use bytes::BytesMut;
8use log::error;
9use socket2::{Domain, Protocol, SockAddr, Socket, Type};
10use tokio::io::{AsyncReadExt, AsyncWriteExt};
11use tokio::net::{TcpListener, TcpStream};
12use tokio::sync::{Semaphore, broadcast, mpsc, oneshot};
13use tokio::task::JoinSet;
14
15use crate::context::{Command, Extensions, PubSubHandle, PushHandle, RequestContext};
16use crate::error::Result;
17use crate::resp::{DecodeLimits, Value, ValueDecoder};
18use crate::response::{IntoResponse, RespError};
19use crate::router::Router;
20
/// Identity of a single accepted connection.
///
/// A copy is handed to every [`ServerHooks`] connection callback and to the per-connection
/// extensions factory.
#[derive(Clone, Copy, Debug)]
pub struct ConnectionInfo {
    /// Client id assigned by the accept loop.
    pub id: u64,
    /// Remote (client) address of the socket.
    pub peer_addr: SocketAddr,
    /// Local address the connection was accepted on.
    pub local_addr: SocketAddr,
}
28
29/// Observability hooks for the server runtime.
30pub trait ServerHooks: Send + Sync + 'static {
31    fn on_accept_error(&self, _err: &std::io::Error) {}
32    fn on_connection_open(&self, _info: ConnectionInfo) {}
33    fn on_connection_close(&self, _info: ConnectionInfo) {}
34    fn on_command(&self, _id: u64, _command: &Command) {}
35    fn on_protocol_error(&self, _err: &RespError) {}
36    fn on_io_error(&self, _err: &std::io::Error) {}
37}
38
39/// Default no-op hooks implementation.
40#[derive(Debug, Default, Clone)]
41pub struct NoopServerHooks;
42
43impl ServerHooks for NoopServerHooks {}
44
45struct ConnectionGuard {
46    hooks: Arc<dyn ServerHooks>,
47    info: ConnectionInfo,
48}
49
50impl Drop for ConnectionGuard {
51    fn drop(&mut self) {
52        self.hooks.on_connection_close(self.info);
53    }
54}
55
56type ExtensionsFactory = Arc<dyn Fn(ConnectionInfo) -> Extensions + Send + Sync + 'static>;
57
58fn default_extensions(_info: ConnectionInfo) -> Extensions {
59    Extensions::default()
60}
61
/// Server configuration for limits and runtime behavior.
///
/// Start from [`ServerConfig::default`] or [`ServerConfig::builder`]. The builder clamps every
/// numeric limit to at least 1; assigning fields directly performs no validation.
#[derive(Clone, Debug)]
pub struct ServerConfig {
    /// Upper bound (bytes) on a connection's unparsed read buffer. Exceeding it replies
    /// `ERR max frame size exceeded` and closes the connection.
    pub max_frame_size: usize,
    /// Maximum length (bytes) of a single bulk string accepted by the decoder.
    pub max_bulk_len: usize,
    /// Maximum number of elements in an array accepted by the decoder.
    pub max_array_len: usize,
    /// Maximum nesting depth of aggregate values accepted by the decoder.
    pub max_depth: usize,
    /// Maximum number of decoded commands queued per connection ahead of execution.
    pub max_inflight_requests: usize,
    /// Maximum number of concurrently served connections. Sockets accepted beyond this limit
    /// are closed immediately.
    pub max_connections: usize,
    /// Timeout for socket reads; `None` waits indefinitely.
    pub read_timeout: Option<Duration>,
    /// Timeout for each flush of buffered output to the socket; `None` waits indefinitely.
    pub write_timeout: Option<Duration>,
    /// Timeout for idle connections (no incoming data); `None` disables it.
    pub idle_timeout: Option<Duration>,
    /// Capacity (messages) of each connection's queue of server-initiated push messages.
    pub push_queue_len: usize,
    /// Capacity (messages) of each connection's queue of replies awaiting the writer.
    pub response_queue_len: usize,
    /// Buffered output size (bytes) at which the writer flushes to the socket mid-batch.
    pub write_batch_bytes: usize,
    /// Whether to set `TCP_NODELAY` on accepted sockets.
    pub tcp_nodelay: bool,
    /// Explicit listen backlog. `None` binds with `TcpListener::bind` and its default backlog.
    pub backlog: Option<u32>,
}

impl Default for ServerConfig {
    fn default() -> Self {
        Self {
            max_frame_size: 1 << 20, // 1 MiB
            max_bulk_len: 1 << 20,   // 1 MiB
            max_array_len: 1024,
            max_depth: 16,
            max_inflight_requests: 128,
            max_connections: 1024,
            read_timeout: None,
            write_timeout: None,
            idle_timeout: None,
            push_queue_len: 1024,
            response_queue_len: 1024,
            write_batch_bytes: 8 * 1024,
            tcp_nodelay: true,
            backlog: None,
        }
    }
}

/// Builder for [`ServerConfig`], starting from [`ServerConfig::default`].
///
/// Every setter consumes and returns the builder. Numeric limits are clamped to at least 1, so
/// a zero can never disable a limit or create a zero-capacity queue.
#[derive(Clone, Debug)]
#[must_use = "a ServerConfigBuilder does nothing until `build` is called"]
pub struct ServerConfigBuilder {
    cfg: ServerConfig,
}

impl ServerConfig {
    /// Create a builder initialised with [`ServerConfig::default`].
    pub fn builder() -> ServerConfigBuilder {
        ServerConfigBuilder {
            cfg: ServerConfig::default(),
        }
    }
}

impl ServerConfigBuilder {
    /// Set the maximum total frame size (bytes). Values below 1 become 1.
    pub fn max_frame_size(mut self, value: usize) -> Self {
        self.cfg.max_frame_size = value.max(1);
        self
    }

    /// Set the maximum bulk string length (bytes). Values below 1 become 1.
    pub fn max_bulk_len(mut self, value: usize) -> Self {
        self.cfg.max_bulk_len = value.max(1);
        self
    }

    /// Set the maximum array length (elements). Values below 1 become 1.
    pub fn max_array_len(mut self, value: usize) -> Self {
        self.cfg.max_array_len = value.max(1);
        self
    }

    /// Set the maximum nesting depth. Values below 1 become 1.
    pub fn max_depth(mut self, value: usize) -> Self {
        self.cfg.max_depth = value.max(1);
        self
    }

    /// Set the maximum number of buffered in-flight requests. Values below 1 become 1.
    pub fn max_inflight_requests(mut self, value: usize) -> Self {
        self.cfg.max_inflight_requests = value.max(1);
        self
    }

    /// Set the maximum number of concurrent connections. Values below 1 become 1.
    pub fn max_connections(mut self, value: usize) -> Self {
        self.cfg.max_connections = value.max(1);
        self
    }

    /// Set the read timeout for socket reads (`None` disables it).
    pub fn read_timeout(mut self, value: Option<Duration>) -> Self {
        self.cfg.read_timeout = value;
        self
    }

    /// Set the write timeout for socket writes (`None` disables it).
    pub fn write_timeout(mut self, value: Option<Duration>) -> Self {
        self.cfg.write_timeout = value;
        self
    }

    /// Set the idle timeout for connections (`None` disables it).
    pub fn idle_timeout(mut self, value: Option<Duration>) -> Self {
        self.cfg.idle_timeout = value;
        self
    }

    /// Set the push queue length (messages). Values below 1 become 1.
    pub fn push_queue_len(mut self, value: usize) -> Self {
        self.cfg.push_queue_len = value.max(1);
        self
    }

    /// Set the response queue length (messages). Values below 1 become 1.
    pub fn response_queue_len(mut self, value: usize) -> Self {
        self.cfg.response_queue_len = value.max(1);
        self
    }

    /// Set the write batch size (bytes). Values below 1 become 1.
    pub fn write_batch_bytes(mut self, value: usize) -> Self {
        self.cfg.write_batch_bytes = value.max(1);
        self
    }

    /// Enable or disable `TCP_NODELAY` on accepted sockets.
    pub fn tcp_nodelay(mut self, value: bool) -> Self {
        self.cfg.tcp_nodelay = value;
        self
    }

    /// Set the listener backlog (may be ignored or capped depending on platform).
    pub fn backlog(mut self, value: Option<u32>) -> Self {
        self.cfg.backlog = value;
        self
    }

    /// Finalize the configuration.
    pub fn build(self) -> ServerConfig {
        self.cfg
    }
}
206
207/// Server builder for configuring the runtime.
208pub struct ServerBuilder {
209    addr: String,
210    cfg: ServerConfig,
211    shutdown: Option<BoxFuture>,
212    hooks: Arc<dyn ServerHooks>,
213    extensions_factory: ExtensionsFactory,
214}
215
216type BoxFuture = std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>;
217
218impl ServerBuilder {
219    /// Apply a custom server configuration.
220    pub fn with_config(mut self, cfg: ServerConfig) -> Self {
221        self.cfg = cfg;
222        self
223    }
224
225    /// Provide a future to trigger graceful shutdown.
226    pub fn with_graceful_shutdown<F>(mut self, fut: F) -> Self
227    where
228        F: std::future::Future<Output = ()> + Send + 'static,
229    {
230        self.shutdown = Some(Box::pin(fut));
231        self
232    }
233
234    /// Provide observability hooks for the server runtime.
235    pub fn with_hooks<H>(mut self, hooks: H) -> Self
236    where
237        H: ServerHooks,
238    {
239        self.hooks = Arc::new(hooks);
240        self
241    }
242
243    /// Provide per-connection extensions for request contexts.
244    ///
245    /// The factory is invoked once per connection and its `Extensions` are
246    /// cloned into each `RequestContext` for that connection.
247    pub fn with_connection_extensions<F>(mut self, factory: F) -> Self
248    where
249        F: Fn(ConnectionInfo) -> Extensions + Send + Sync + 'static,
250    {
251        self.extensions_factory = Arc::new(factory);
252        self
253    }
254
255    pub async fn serve<State>(self, app: Router<State>) -> Result<()>
256    where
257        State: Send + Sync + 'static,
258    {
259        let listener = if let Some(backlog) = self.cfg.backlog {
260            bind_with_backlog(&self.addr, backlog)?
261        } else {
262            TcpListener::bind(&self.addr).await?
263        };
264        self.serve_with_listener(listener, app).await
265    }
266
267    /// Serve using a pre-bound listener (useful for tests).
268    pub async fn serve_with_listener<State>(
269        self,
270        listener: TcpListener,
271        app: Router<State>,
272    ) -> Result<()>
273    where
274        State: Send + Sync + 'static,
275    {
276        run_server(
277            listener,
278            app,
279            self.cfg,
280            self.shutdown,
281            self.hooks,
282            self.extensions_factory,
283        )
284        .await
285    }
286}
287
288/// A RESP server.
289pub struct Server;
290
291impl Server {
292    /// Bind a server to the provided address.
293    pub fn bind(addr: impl Into<String>) -> ServerBuilder {
294        ServerBuilder {
295            addr: addr.into(),
296            cfg: ServerConfig::default(),
297            shutdown: None,
298            hooks: Arc::new(NoopServerHooks),
299            extensions_factory: Arc::new(default_extensions),
300        }
301    }
302}
303
304fn bind_with_backlog(addr: &str, backlog: u32) -> Result<TcpListener> {
305    let mut addrs = addr.to_socket_addrs()?;
306    let addr = addrs.next().ok_or_else(|| {
307        std::io::Error::new(std::io::ErrorKind::InvalidInput, "empty bind address")
308    })?;
309    let domain = Domain::for_address(addr);
310    let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
311    socket.set_nonblocking(true)?;
312    socket.bind(&SockAddr::from(addr))?;
313    let backlog = backlog.max(1).min(i32::MAX as u32) as i32;
314    socket.listen(backlog)?;
315    let listener: std::net::TcpListener = socket.into();
316    TcpListener::from_std(listener)
317}
318
319async fn run_server<State>(
320    listener: TcpListener,
321    app: Router<State>,
322    cfg: ServerConfig,
323    shutdown: Option<BoxFuture>,
324    hooks: Arc<dyn ServerHooks>,
325    extensions_factory: ExtensionsFactory,
326) -> Result<()>
327where
328    State: Send + Sync + 'static,
329{
330    let (shutdown_tx, _) = broadcast::channel(1);
331    let mut shutdown_rx = shutdown_tx.subscribe();
332    let mut join_set = JoinSet::new();
333    let semaphore = Arc::new(Semaphore::new(cfg.max_connections));
334    let client_id = Arc::new(AtomicU64::new(1));
335
336    let shutdown_fut = async move {
337        if let Some(fut) = shutdown {
338            fut.await;
339        } else {
340            std::future::pending::<()>().await;
341        }
342    };
343    let mut shutdown_fut = Box::pin(shutdown_fut);
344
345    loop {
346        tokio::select! {
347            _ = &mut shutdown_fut => {
348                break;
349            }
350            _ = shutdown_rx.recv() => {
351                break;
352            }
353            accept = listener.accept() => {
354                let (socket, _) = match accept {
355                    Ok(value) => value,
356                    Err(err) => {
357                        error!("accept error: {:?}", err);
358                        hooks.on_accept_error(&err);
359                        continue;
360                    }
361                };
362
363                let permit = match semaphore.clone().try_acquire_owned() {
364                    Ok(permit) => permit,
365                    Err(_) => {
366                        continue;
367                    }
368                };
369
370                if cfg.tcp_nodelay {
371                    let _ = socket.set_nodelay(true);
372                }
373
374                let handler = app.clone();
375                let cfg = cfg.clone();
376                let shutdown_rx = shutdown_tx.subscribe();
377                let id = client_id.fetch_add(1, Ordering::AcqRel);
378                let hooks = Arc::clone(&hooks);
379                let extensions_factory = Arc::clone(&extensions_factory);
380                join_set.spawn(async move {
381                    let _permit = permit;
382                    if let Err(err) =
383                        run_connection(
384                            id,
385                            socket,
386                            handler,
387                            cfg,
388                            shutdown_rx,
389                            hooks.clone(),
390                            extensions_factory,
391                        )
392                        .await
393                    {
394                        error!("connection error: {:?}", err);
395                        hooks.on_io_error(&err);
396                    }
397                });
398            }
399        }
400    }
401
402    let _ = shutdown_tx.send(());
403
404    while let Some(res) = join_set.join_next().await {
405        if let Err(err) = res {
406            error!("connection task error: {:?}", err);
407        }
408    }
409
410    Ok(())
411}
412
413async fn run_connection<State>(
414    id: u64,
415    socket: TcpStream,
416    app: Router<State>,
417    cfg: ServerConfig,
418    mut shutdown: broadcast::Receiver<()>,
419    hooks: Arc<dyn ServerHooks>,
420    extensions_factory: ExtensionsFactory,
421) -> Result<()>
422where
423    State: Send + Sync + 'static,
424{
425    let peer_addr = socket.peer_addr()?;
426    let local_addr = socket.local_addr()?;
427    let info = ConnectionInfo {
428        id,
429        peer_addr,
430        local_addr,
431    };
432    hooks.on_connection_open(info);
433    let _guard = ConnectionGuard {
434        hooks: Arc::clone(&hooks),
435        info,
436    };
437    let (mut reader, writer) = socket.into_split();
438
439    let (resp_tx, resp_rx) = mpsc::channel(cfg.response_queue_len);
440    let (push_tx, push_rx) = mpsc::channel(cfg.push_queue_len);
441    let (close_tx, mut close_rx) = mpsc::channel(1);
442    let (writer_close_tx, writer_close_rx) = oneshot::channel();
443    let mut writer_close_tx = Some(writer_close_tx);
444
445    let push_handle = PushHandle::new(push_tx, close_tx);
446    let pubsub_count = Arc::new(AtomicUsize::new(0));
447    let pubsub_handle = PubSubHandle::new(pubsub_count.clone());
448    let extensions = (extensions_factory)(info);
449
450    let writer_cfg = cfg.clone();
451    let writer_task = tokio::spawn(async move {
452        writer_loop(writer, resp_rx, push_rx, writer_close_rx, writer_cfg).await
453    });
454
455    let mut rd = BytesMut::with_capacity(4096);
456    let mut decoder = ValueDecoder::new(DecodeLimits {
457        max_bulk_len: cfg.max_bulk_len,
458        max_array_len: cfg.max_array_len,
459        max_depth: cfg.max_depth,
460    });
461    let mut inflight = VecDeque::new();
462
463    loop {
464        while inflight.len() < cfg.max_inflight_requests {
465            match decoder.try_decode(&mut rd) {
466                Ok(Some(value)) => {
467                    let command = match Command::from_value(value) {
468                        Ok(cmd) => cmd,
469                        Err(err) => {
470                            hooks.on_protocol_error(&err);
471                            let _ = resp_tx.send(err.into_response()).await;
472                            signal_writer_close(&mut writer_close_tx);
473                            return Ok(());
474                        }
475                    };
476
477                    if pubsub_count.load(Ordering::Acquire) > 0
478                        && !is_pubsub_allowed(&command.name_upper)
479                    {
480                        let err = RespError::invalid_data(
481                            "ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context",
482                        );
483                        hooks.on_protocol_error(&err);
484                        let _ = resp_tx.send(err.into_response()).await;
485                        continue;
486                    }
487
488                    hooks.on_command(id, &command);
489                    inflight.push_back(command);
490                }
491                Ok(None) => break,
492                Err(err) => {
493                    hooks.on_protocol_error(&err);
494                    let _ = resp_tx.send(err.into_response()).await;
495                    signal_writer_close(&mut writer_close_tx);
496                    return Ok(());
497                }
498            }
499        }
500
501        if let Some(command) = inflight.pop_front() {
502            let close_after = command.name_upper.as_ref() == b"QUIT";
503            let ctx = RequestContext {
504                command,
505                peer_addr,
506                local_addr,
507                client_id: id,
508                extensions: extensions.clone(),
509                push: push_handle.clone(),
510                pubsub: pubsub_handle.clone(),
511            };
512            let response = app.call(ctx).await;
513            if resp_tx.send(response).await.is_err() {
514                break;
515            }
516            if close_after {
517                signal_writer_close(&mut writer_close_tx);
518                break;
519            }
520            continue;
521        }
522
523        tokio::select! {
524            _ = shutdown.recv() => {
525                signal_writer_close(&mut writer_close_tx);
526                break;
527            }
528            _ = close_rx.recv() => {
529                let err = RespError::invalid_data("ERR client output buffer limit reached");
530                hooks.on_protocol_error(&err);
531                let _ = resp_tx.try_send(err.into_response());
532                signal_writer_close(&mut writer_close_tx);
533                break;
534            }
535            read = read_more(&mut reader, &mut rd, cfg.read_timeout, cfg.idle_timeout) => {
536                let read = read?;
537                if read == 0 {
538                    if rd.is_empty() {
539                        signal_writer_close(&mut writer_close_tx);
540                        break;
541                    }
542                    let err = RespError::invalid_data("ERR unexpected EOF");
543                    hooks.on_protocol_error(&err);
544                    let _ = resp_tx.send(err.into_response()).await;
545                    signal_writer_close(&mut writer_close_tx);
546                    break;
547                }
548                if rd.len() > cfg.max_frame_size {
549                    let err = RespError::invalid_data("ERR max frame size exceeded");
550                    hooks.on_protocol_error(&err);
551                    let _ = resp_tx.send(err.into_response()).await;
552                    signal_writer_close(&mut writer_close_tx);
553                    break;
554                }
555            }
556        }
557    }
558
559    signal_writer_close(&mut writer_close_tx);
560    drop(resp_tx);
561    drop(push_handle);
562
563    let _ = writer_task.await;
564    Ok(())
565}
566
567async fn read_more(
568    reader: &mut tokio::net::tcp::OwnedReadHalf,
569    buf: &mut BytesMut,
570    read_timeout: Option<Duration>,
571    idle_timeout: Option<Duration>,
572) -> Result<usize> {
573    let fut = reader.read_buf(buf);
574    let timeout = idle_timeout.or(read_timeout);
575    if let Some(timeout) = timeout {
576        match tokio::time::timeout(timeout, fut).await {
577            Ok(res) => res,
578            Err(_) => Err(std::io::ErrorKind::TimedOut.into()),
579        }
580    } else {
581        fut.await
582    }
583}
584
585async fn writer_loop(
586    mut writer: tokio::net::tcp::OwnedWriteHalf,
587    mut resp_rx: mpsc::Receiver<Value>,
588    mut push_rx: mpsc::Receiver<Value>,
589    mut close_rx: oneshot::Receiver<()>,
590    cfg: ServerConfig,
591) -> Result<()> {
592    let mut buf = BytesMut::with_capacity(cfg.write_batch_bytes);
593    let mut resp_closed = false;
594    let mut push_closed = false;
595    let mut closing = false;
596
597    loop {
598        let mut drained_response = false;
599        while let Ok(value) = resp_rx.try_recv() {
600            drained_response = true;
601            value.encode(&mut buf);
602            if buf.len() >= cfg.write_batch_bytes {
603                flush_buffer(&mut writer, &mut buf, cfg.write_timeout).await?;
604            }
605        }
606        if !drained_response && !closing {
607            while let Ok(value) = push_rx.try_recv() {
608                value.encode(&mut buf);
609                if buf.len() >= cfg.write_batch_bytes {
610                    break;
611                }
612            }
613        }
614        if !buf.is_empty() {
615            flush_buffer(&mut writer, &mut buf, cfg.write_timeout).await?;
616        }
617
618        if closing {
619            break;
620        }
621
622        if resp_closed && push_closed {
623            break;
624        }
625
626        tokio::select! {
627            biased;
628            _ = &mut close_rx => {
629                closing = true;
630            }
631            res = resp_rx.recv() => {
632                match res {
633                    Some(value) => value.encode(&mut buf),
634                    None => resp_closed = true,
635                }
636            }
637            res = push_rx.recv(), if !closing => {
638                match res {
639                    Some(value) => value.encode(&mut buf),
640                    None => push_closed = true,
641                }
642            }
643        }
644
645        if !buf.is_empty() {
646            flush_buffer(&mut writer, &mut buf, cfg.write_timeout).await?;
647        }
648    }
649
650    Ok(())
651}
652
653fn signal_writer_close(tx: &mut Option<oneshot::Sender<()>>) {
654    if let Some(tx) = tx.take() {
655        let _ = tx.send(());
656    }
657}
658
659async fn flush_buffer(
660    writer: &mut tokio::net::tcp::OwnedWriteHalf,
661    buf: &mut BytesMut,
662    timeout: Option<Duration>,
663) -> Result<()> {
664    if buf.is_empty() {
665        return Ok(());
666    }
667    let write = writer.write_all(buf);
668    if let Some(timeout) = timeout {
669        match tokio::time::timeout(timeout, write).await {
670            Ok(res) => res?,
671            Err(_) => return Err(std::io::ErrorKind::TimedOut.into()),
672        }
673    } else {
674        write.await?;
675    }
676    buf.clear();
677    Ok(())
678}
679
680fn is_pubsub_allowed(cmd: &bytes::Bytes) -> bool {
681    matches!(
682        cmd.as_ref(),
683        b"SUBSCRIBE" | b"PSUBSCRIBE" | b"UNSUBSCRIBE" | b"PUNSUBSCRIBE" | b"PING" | b"QUIT"
684    )
685}