Skip to main content

dynomite/net/
proxy.rs

1//! PROXY listener.
2//!
3//! Listens for client connections on the configured `listen:` port
4//! and spawns a CLIENT FSM per accepted socket. [`Proxy`] owns a
5//! [`tokio::net::TcpListener`] and a per-listener [`Dispatcher`]
6//! reference; calling [`Proxy::run`] enters an accept-loop that
7//! drives a fresh `tokio::spawn` for every incoming socket.
8//!
9//! # Examples
10//!
11//! ```no_run
12//! use dynomite::net::{NoopDispatcher, Proxy};
13//! use std::sync::Arc;
14//!
15//! let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();
16//! let proxy = Proxy::bind(addr, Arc::new(NoopDispatcher)).unwrap();
17//! let _handle = proxy.local_addr();
18//! ```
19
20use std::io;
21use std::net::SocketAddr;
22use std::sync::Arc;
23
24use tokio::net::TcpListener;
25use tokio::sync::mpsc;
26use tokio::task::JoinHandle;
27use tracing::Instrument as _;
28
29use crate::conf::DataStore;
30use crate::io::reactor::{ConnRole, TcpTransport};
31use crate::net::client::{client_loop, ClientHandler};
32use crate::net::conn::Conn;
33use crate::net::dispatcher::Dispatcher;
34use crate::net::listener::{bind_dual_stack, BindOptions};
35use crate::net::NetError;
36
37/// PROXY listener.
38pub struct Proxy {
39    listener: TcpListener,
40    dispatcher: Arc<dyn Dispatcher>,
41    data_store: DataStore,
42    response_capacity: usize,
43}
44
45impl Proxy {
46    /// Bind a proxy listener to the given address.
47    ///
48    /// Uses [`crate::net::listener::bind_dual_stack`] to honor v4 +
49    /// v6 wildcard semantics. The dispatcher is invoked for every
50    /// fully-parsed request from any accepted client.
51    ///
52    /// # Errors
53    /// Forwarded from the underlying socket calls.
54    pub fn bind<A: Into<SocketAddr>>(
55        addr: A,
56        dispatcher: Arc<dyn Dispatcher>,
57    ) -> Result<Self, NetError> {
58        let listener = bind_dual_stack(addr.into(), BindOptions::default())?;
59        Ok(Self {
60            listener,
61            dispatcher,
62            data_store: DataStore::Redis,
63            response_capacity: 64,
64        })
65    }
66
67    /// Override the datastore the per-client FSMs will parse.
68    /// Defaults to [`DataStore::Redis`].
69    #[must_use]
70    pub fn with_data_store(mut self, ds: DataStore) -> Self {
71        self.data_store = ds;
72        self
73    }
74
75    /// Override the response-channel capacity per client.
76    #[must_use]
77    pub fn with_response_capacity(mut self, n: usize) -> Self {
78        self.response_capacity = n.max(1);
79        self
80    }
81
82    /// Local address of the listener.
83    pub fn local_addr(&self) -> io::Result<SocketAddr> {
84        self.listener.local_addr()
85    }
86
87    /// Borrow the bound listener so callers can extract the
88    /// fd-level socket handle when needed.
89    pub fn listener(&self) -> &TcpListener {
90        &self.listener
91    }
92
93    /// Drive the accept loop until the listener returns an error
94    /// or the supplied cancel future resolves.
95    ///
96    /// Each accepted socket is wrapped in a [`Conn`] tagged
97    /// [`ConnRole::Client`] and handed to a per-task client loop.
98    ///
99    /// # Errors
100    /// Forwarded from the listener accept call.
101    #[tracing::instrument(
102        name = "proxy.run",
103        skip_all,
104        fields(
105            local = self.listener.local_addr().map_or_else(|_| String::from("?"), |a| a.to_string()),
106        ),
107    )]
108    pub async fn run(
109        self,
110        cancel: std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>,
111    ) -> Result<(), NetError> {
112        let mut cancel = cancel;
113        let mut clients: Vec<JoinHandle<Result<(), NetError>>> = Vec::new();
114        loop {
115            let accept = self.listener.accept();
116            tokio::select! {
117                () = &mut cancel => break,
118                res = accept => {
119                    let (sock, peer) = res?;
120                    // Match the latency expectation of the
121                    // datastore engines: Redis and memcache both
122                    // assume the upstream proxy disables Nagle so
123                    // small Redis requests fly without batching.
124                    // Errors here are non-fatal: a peer that
125                    // disconnected before the option could be
126                    // applied is fine.
127                    let _ = sock.set_nodelay(true);
128                    let role = ConnRole::Client;
129                    let transport = Box::new(TcpTransport::new(sock, role));
130                    let conn = Conn::new(transport, role);
131                    let dispatcher = Arc::clone(&self.dispatcher);
132                    let cap = self.response_capacity;
133                    let ds = self.data_store;
134                    tracing::debug!(?peer, "proxy accepted client");
135                    let accept_span = tracing::info_span!(
136                        "client.accept",
137                        peer = %peer,
138                    );
139                    let handle = tokio::spawn(
140                        async move {
141                            let (tx, rx) = mpsc::channel(cap);
142                            let handler = ClientHandler::new(dispatcher, tx, ds);
143                            client_loop(conn, handler, rx).await
144                        }
145                        .instrument(accept_span),
146                    );
147                    clients.push(handle);
148                }
149            }
150            // Drain finished tasks opportunistically.
151            clients.retain(|h| !h.is_finished());
152        }
153        for h in clients {
154            // Give each client a brief window to drain (e.g.
155            // finish writing the last response after the
156            // listener has stopped accepting). After that, abort
157            // so a wedged client_loop cannot keep the proxy from
158            // exiting on shutdown.
159            //
160            // Both arms of this match are intentionally no-ops:
161            // the timeout case can't abort the consumed handle
162            // (tokio will GC it on shutdown), and the
163            // task-completed case has nothing to do.
164            let _ = tokio::time::timeout(std::time::Duration::from_millis(250), h).await;
165        }
166        Ok(())
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Bind a proxy on an ephemeral loopback port with a no-op dispatcher.
    fn loopback_proxy() -> Proxy {
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        Proxy::bind(addr, Arc::new(crate::net::NoopDispatcher)).unwrap()
    }

    #[tokio::test]
    async fn bind_returns_local_addr() {
        let local = loopback_proxy().local_addr().unwrap();
        assert!(local.ip().is_loopback());
    }

    #[tokio::test]
    async fn run_exits_on_cancel() {
        // The cancel future is already resolved and no client connects,
        // so the accept loop should stop and report success.
        let already_cancelled = Box::pin(async {});
        loopback_proxy().run(already_cancelled).await.unwrap();
    }
}