dynomite/net/proxy.rs
1//! PROXY listener.
2//!
3//! Listens for client connections on the configured `listen:` port
4//! and spawns a CLIENT FSM per accepted socket. [`Proxy`] owns a
5//! [`tokio::net::TcpListener`] and a per-listener [`Dispatcher`]
6//! reference; calling [`Proxy::run`] enters an accept-loop that
7//! drives a fresh `tokio::spawn` for every incoming socket.
8//!
9//! # Examples
10//!
11//! ```no_run
12//! use dynomite::net::{NoopDispatcher, Proxy};
13//! use std::sync::Arc;
14//!
15//! let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();
16//! let proxy = Proxy::bind(addr, Arc::new(NoopDispatcher)).unwrap();
//! let _addr = proxy.local_addr().unwrap();
18//! ```
19
20use std::io;
21use std::net::SocketAddr;
22use std::sync::Arc;
23
24use tokio::net::TcpListener;
25use tokio::sync::mpsc;
26use tokio::task::JoinHandle;
27use tracing::Instrument as _;
28
29use crate::conf::DataStore;
30use crate::io::reactor::{ConnRole, TcpTransport};
31use crate::net::client::{client_loop, ClientHandler};
32use crate::net::conn::Conn;
33use crate::net::dispatcher::Dispatcher;
34use crate::net::listener::{bind_dual_stack, BindOptions};
35use crate::net::NetError;
36
37/// PROXY listener.
38pub struct Proxy {
39 listener: TcpListener,
40 dispatcher: Arc<dyn Dispatcher>,
41 data_store: DataStore,
42 response_capacity: usize,
43}
44
45impl Proxy {
46 /// Bind a proxy listener to the given address.
47 ///
48 /// Uses [`crate::net::listener::bind_dual_stack`] to honor v4 +
49 /// v6 wildcard semantics. The dispatcher is invoked for every
50 /// fully-parsed request from any accepted client.
51 ///
52 /// # Errors
53 /// Forwarded from the underlying socket calls.
54 pub fn bind<A: Into<SocketAddr>>(
55 addr: A,
56 dispatcher: Arc<dyn Dispatcher>,
57 ) -> Result<Self, NetError> {
58 let listener = bind_dual_stack(addr.into(), BindOptions::default())?;
59 Ok(Self {
60 listener,
61 dispatcher,
62 data_store: DataStore::Redis,
63 response_capacity: 64,
64 })
65 }
66
67 /// Override the datastore the per-client FSMs will parse.
68 /// Defaults to [`DataStore::Redis`].
69 #[must_use]
70 pub fn with_data_store(mut self, ds: DataStore) -> Self {
71 self.data_store = ds;
72 self
73 }
74
75 /// Override the response-channel capacity per client.
76 #[must_use]
77 pub fn with_response_capacity(mut self, n: usize) -> Self {
78 self.response_capacity = n.max(1);
79 self
80 }
81
82 /// Local address of the listener.
83 pub fn local_addr(&self) -> io::Result<SocketAddr> {
84 self.listener.local_addr()
85 }
86
87 /// Borrow the bound listener so callers can extract the
88 /// fd-level socket handle when needed.
89 pub fn listener(&self) -> &TcpListener {
90 &self.listener
91 }
92
93 /// Drive the accept loop until the listener returns an error
94 /// or the supplied cancel future resolves.
95 ///
96 /// Each accepted socket is wrapped in a [`Conn`] tagged
97 /// [`ConnRole::Client`] and handed to a per-task client loop.
98 ///
99 /// # Errors
100 /// Forwarded from the listener accept call.
101 #[tracing::instrument(
102 name = "proxy.run",
103 skip_all,
104 fields(
105 local = self.listener.local_addr().map_or_else(|_| String::from("?"), |a| a.to_string()),
106 ),
107 )]
108 pub async fn run(
109 self,
110 cancel: std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>,
111 ) -> Result<(), NetError> {
112 let mut cancel = cancel;
113 let mut clients: Vec<JoinHandle<Result<(), NetError>>> = Vec::new();
114 loop {
115 let accept = self.listener.accept();
116 tokio::select! {
117 () = &mut cancel => break,
118 res = accept => {
119 let (sock, peer) = res?;
120 // Match the latency expectation of the
121 // datastore engines: Redis and memcache both
122 // assume the upstream proxy disables Nagle so
123 // small Redis requests fly without batching.
124 // Errors here are non-fatal: a peer that
125 // disconnected before the option could be
126 // applied is fine.
127 let _ = sock.set_nodelay(true);
128 let role = ConnRole::Client;
129 let transport = Box::new(TcpTransport::new(sock, role));
130 let conn = Conn::new(transport, role);
131 let dispatcher = Arc::clone(&self.dispatcher);
132 let cap = self.response_capacity;
133 let ds = self.data_store;
134 tracing::debug!(?peer, "proxy accepted client");
135 let accept_span = tracing::info_span!(
136 "client.accept",
137 peer = %peer,
138 );
139 let handle = tokio::spawn(
140 async move {
141 let (tx, rx) = mpsc::channel(cap);
142 let handler = ClientHandler::new(dispatcher, tx, ds);
143 client_loop(conn, handler, rx).await
144 }
145 .instrument(accept_span),
146 );
147 clients.push(handle);
148 }
149 }
150 // Drain finished tasks opportunistically.
151 clients.retain(|h| !h.is_finished());
152 }
153 for h in clients {
154 // Give each client a brief window to drain (e.g.
155 // finish writing the last response after the
156 // listener has stopped accepting). After that, abort
157 // so a wedged client_loop cannot keep the proxy from
158 // exiting on shutdown.
159 //
160 // Both arms of this match are intentionally no-ops:
161 // the timeout case can't abort the consumed handle
162 // (tokio will GC it on shutdown), and the
163 // task-completed case has nothing to do.
164 let _ = tokio::time::timeout(std::time::Duration::from_millis(250), h).await;
165 }
166 Ok(())
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Bind a proxy on an ephemeral loopback port with a no-op
    /// dispatcher; must be called from inside a tokio runtime.
    fn loopback_proxy() -> Proxy {
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        Proxy::bind(addr, Arc::new(crate::net::NoopDispatcher)).unwrap()
    }

    #[tokio::test]
    async fn bind_returns_local_addr() {
        let local = loopback_proxy().local_addr().unwrap();
        assert!(local.ip().is_loopback());
    }

    #[tokio::test]
    async fn run_exits_on_cancel() {
        // An already-completed cancel future must make `run` return.
        let already_cancelled = Box::pin(async {});
        loopback_proxy().run(already_cancelled).await.unwrap();
    }
}