// protosocket_server/connection_server.rs

1use protosocket::Connection;
2use protosocket::ConnectionBindings;
3use protosocket::Serializer;
4use socket2::TcpKeepalive;
5use std::ffi::c_int;
6use std::future::Future;
7use std::io::Error;
8use std::net::SocketAddr;
9use std::pin::Pin;
10use std::task::Context;
11use std::task::Poll;
12use tokio::sync::mpsc;
13
14pub trait ServerConnector: Unpin {
15    type Bindings: ConnectionBindings;
16
17    fn serializer(&self) -> <Self::Bindings as ConnectionBindings>::Serializer;
18    fn deserializer(&self) -> <Self::Bindings as ConnectionBindings>::Deserializer;
19
20    fn new_reactor(
21        &self,
22        optional_outbound: mpsc::Sender<
23            <<Self::Bindings as ConnectionBindings>::Serializer as Serializer>::Message,
24        >,
25        address: SocketAddr,
26    ) -> <Self::Bindings as ConnectionBindings>::Reactor;
27
28    fn connect(
29        &self,
30        stream: tokio::net::TcpStream,
31    ) -> <Self::Bindings as ConnectionBindings>::Stream;
32}
33
34/// A `protosocket::Connection` is an IO driver. It directly uses tokio's io wrapper of mio to poll
35/// the OS's io primitives, manages read and write buffers, and vends messages to & from connections.
36/// Connections send messages to the ConnectionServer through an mpsc channel, and they receive
37/// inbound messages via a reactor callback.
38///
39/// Protosockets are monomorphic messages: You can only have 1 kind of message per service.
40/// The expected way to work with this is to use prost and protocol buffers to encode messages.
41/// Of course you can do whatever you want, as the telnet example shows.
42///
43/// Protosocket messages are not opinionated about request & reply. If you are, you will need
44/// to implement such a thing. This allows you freely choose whether you want to send
45/// fire-&-forget messages sometimes; however it requires you to write your protocol's rules.
46/// You get an inbound iterable of <MessageIn> batches and an outbound stream of <MessageOut> per
47/// connection - you decide what those mean for you!
48///
49/// A ProtosocketServer is a future: You spawn it and it runs forever.
50///
51/// Construct a new ProtosocketServer by creating a ProtosocketServerConfig and calling the {{bind_tcp}} method.
52pub struct ProtosocketServer<Connector: ServerConnector> {
53    connector: Connector,
54    listener: tokio::net::TcpListener,
55    max_buffer_length: usize,
56    buffer_allocation_increment: usize,
57    max_queued_outbound_messages: usize,
58    runtime: tokio::runtime::Handle,
59}
60
/// Socket configuration options for a ProtosocketServer.
///
/// Start from `ProtosocketSocketConfig::default()` and chain the builder methods to override
/// individual options.
#[derive(Debug, Clone)]
pub struct ProtosocketSocketConfig {
    /// Value applied to the listening socket's tcp nodelay option.
    nodelay: bool,
    /// Value applied to both the reuse-address and reuse-port options of the listening socket.
    reuse: bool,
    /// Optional keepalive time; `None` leaves the keepalive parameters at their defaults.
    keepalive_duration: Option<std::time::Duration>,
    /// Backlog handed to `listen` on the listening socket.
    listen_backlog: u32,
}

impl ProtosocketSocketConfig {
    /// Whether nodelay should be set on the socket.
    pub fn nodelay(mut self, nodelay: bool) -> Self {
        self.nodelay = nodelay;
        self
    }

    /// Whether reuseaddr and reuseport should be set on the socket.
    pub fn reuse(mut self, reuse: bool) -> Self {
        self.reuse = reuse;
        self
    }

    /// The keepalive window to be set on the socket.
    pub fn keepalive_duration(mut self, keepalive_duration: std::time::Duration) -> Self {
        self.keepalive_duration = Some(keepalive_duration);
        self
    }

    /// The backlog to be set on the socket when invoking `listen`.
    pub fn listen_backlog(mut self, backlog: u32) -> Self {
        self.listen_backlog = backlog;
        self
    }
}

impl Default for ProtosocketSocketConfig {
    /// Defaults: nodelay on, address/port reuse on, no explicit keepalive time, and a listen
    /// backlog of 65536.
    fn default() -> Self {
        Self {
            nodelay: true,
            reuse: true,
            keepalive_duration: None,
            listen_backlog: 65536,
        }
    }
}
102
103pub struct ProtosocketServerConfig {
104    max_buffer_length: usize,
105    max_queued_outbound_messages: usize,
106    buffer_allocation_increment: usize,
107    socket_config: ProtosocketSocketConfig,
108}
109
110impl ProtosocketServerConfig {
111    /// The maximum buffer length per connection on this server.
112    pub fn max_buffer_length(mut self, max_buffer_length: usize) -> Self {
113        self.max_buffer_length = max_buffer_length;
114        self
115    }
116    /// The maximum number of queued outbound messages per connection on this server.
117    pub fn max_queued_outbound_messages(mut self, max_queued_outbound_messages: usize) -> Self {
118        self.max_queued_outbound_messages = max_queued_outbound_messages;
119        self
120    }
121    /// The step size for allocating additional memory for connection buffers on this server.
122    pub fn buffer_allocation_increment(mut self, buffer_allocation_increment: usize) -> Self {
123        self.buffer_allocation_increment = buffer_allocation_increment;
124        self
125    }
126    /// The tcp socket configuration options for this server.
127    pub fn socket_config(mut self, config: ProtosocketSocketConfig) -> Self {
128        self.socket_config = config;
129        self
130    }
131
132    /// Binds a tcp listener to the given address and returns a ProtosocketServer with this configuration.
133    /// After binding, you must await the returned server future to process requests.
134    pub async fn bind_tcp<Connector: ServerConnector>(
135        self,
136        address: SocketAddr,
137        connector: Connector,
138        runtime: tokio::runtime::Handle,
139    ) -> crate::Result<ProtosocketServer<Connector>> {
140        ProtosocketServer::new(address, runtime, connector, self).await
141    }
142}
143
144impl Default for ProtosocketServerConfig {
145    fn default() -> Self {
146        Self {
147            max_buffer_length: 16 * (2 << 20),
148            max_queued_outbound_messages: 128,
149            buffer_allocation_increment: 1 << 20,
150            socket_config: Default::default(),
151        }
152    }
153}
154
155impl<Connector: ServerConnector> ProtosocketServer<Connector> {
156    /// Construct a new `ProtosocketServer` listening on the provided address.
157    /// The address will be bound and listened upon with `SO_REUSEADDR` set.
158    /// The server will use the provided runtime to spawn new tcp connections as `protosocket::Connection`s.
159    async fn new(
160        address: SocketAddr,
161        runtime: tokio::runtime::Handle,
162        connector: Connector,
163        config: ProtosocketServerConfig,
164    ) -> crate::Result<Self> {
165        let socket = socket2::Socket::new(
166            match address {
167                SocketAddr::V4(_) => socket2::Domain::IPV4,
168                SocketAddr::V6(_) => socket2::Domain::IPV6,
169            },
170            socket2::Type::STREAM,
171            None,
172        )?;
173
174        let mut tcp_keepalive = TcpKeepalive::new();
175        if let Some(duration) = config.socket_config.keepalive_duration {
176            tcp_keepalive = tcp_keepalive.with_time(duration);
177        }
178
179        socket.set_nonblocking(true)?;
180        socket.set_tcp_nodelay(config.socket_config.nodelay)?;
181        socket.set_tcp_keepalive(&tcp_keepalive)?;
182        socket.set_reuse_port(config.socket_config.reuse)?;
183        socket.set_reuse_address(config.socket_config.reuse)?;
184
185        socket.bind(&address.into())?;
186        socket.listen(config.socket_config.listen_backlog as c_int)?;
187
188        let listener = tokio::net::TcpListener::from_std(socket.into())?;
189        Ok(Self {
190            connector,
191            listener,
192            max_buffer_length: config.max_buffer_length,
193            max_queued_outbound_messages: config.max_queued_outbound_messages,
194            buffer_allocation_increment: config.buffer_allocation_increment,
195            runtime,
196        })
197    }
198}
199
200impl<Connector: ServerConnector> Future for ProtosocketServer<Connector> {
201    type Output = Result<(), Error>;
202
203    fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
204        loop {
205            break match self.listener.poll_accept(context) {
206                Poll::Ready(result) => match result {
207                    Ok((stream, address)) => {
208                        stream.set_nodelay(true)?;
209                        let (outbound_submission_queue, outbound_messages) =
210                            mpsc::channel(self.max_queued_outbound_messages);
211                        let reactor = self
212                            .connector
213                            .new_reactor(outbound_submission_queue.clone(), address);
214                        let stream = self.connector.connect(stream);
215                        let connection: Connection<Connector::Bindings> = Connection::new(
216                            stream,
217                            address,
218                            self.connector.deserializer(),
219                            self.connector.serializer(),
220                            self.max_buffer_length,
221                            self.buffer_allocation_increment,
222                            self.max_queued_outbound_messages,
223                            outbound_messages,
224                            reactor,
225                        );
226                        self.runtime.spawn(connection);
227                        continue;
228                    }
229                    Err(e) => {
230                        log::error!("failed to accept connection: {e:?}");
231                        continue;
232                    }
233                },
234                Poll::Pending => Poll::Pending,
235            };
236        }
237    }
238}