protosocket_server/
connection_server.rs1use protosocket::Connection;
2use protosocket::ConnectionBindings;
3use protosocket::Serializer;
4use socket2::TcpKeepalive;
5use std::ffi::c_int;
6use std::future::Future;
7use std::io::Error;
8use std::net::SocketAddr;
9use std::pin::Pin;
10use std::task::Context;
11use std::task::Poll;
12use tokio::sync::mpsc;
13
14pub trait ServerConnector: Unpin {
15 type Bindings: ConnectionBindings;
16
17 fn serializer(&self) -> <Self::Bindings as ConnectionBindings>::Serializer;
18 fn deserializer(&self) -> <Self::Bindings as ConnectionBindings>::Deserializer;
19
20 fn new_reactor(
21 &self,
22 optional_outbound: mpsc::Sender<
23 <<Self::Bindings as ConnectionBindings>::Serializer as Serializer>::Message,
24 >,
25 address: SocketAddr,
26 ) -> <Self::Bindings as ConnectionBindings>::Reactor;
27
28 fn connect(
29 &self,
30 stream: tokio::net::TcpStream,
31 ) -> <Self::Bindings as ConnectionBindings>::Stream;
32}
33
34pub struct ProtosocketServer<Connector: ServerConnector> {
53 connector: Connector,
54 listener: tokio::net::TcpListener,
55 max_buffer_length: usize,
56 buffer_allocation_increment: usize,
57 max_queued_outbound_messages: usize,
58 runtime: tokio::runtime::Handle,
59}
60
61pub struct ProtosocketSocketConfig {
63 nodelay: bool,
64 reuse: bool,
65 keepalive_duration: Option<std::time::Duration>,
66 listen_backlog: u32,
67}
68
69impl ProtosocketSocketConfig {
70 pub fn nodelay(mut self, nodelay: bool) -> Self {
72 self.nodelay = nodelay;
73 self
74 }
75 pub fn reuse(mut self, reuse: bool) -> Self {
77 self.reuse = reuse;
78 self
79 }
80 pub fn keepalive_duration(mut self, keepalive_duration: std::time::Duration) -> Self {
82 self.keepalive_duration = Some(keepalive_duration);
83 self
84 }
85 pub fn listen_backlog(mut self, backlog: u32) -> Self {
87 self.listen_backlog = backlog;
88 self
89 }
90}
91
92impl Default for ProtosocketSocketConfig {
93 fn default() -> Self {
94 Self {
95 nodelay: true,
96 reuse: true,
97 keepalive_duration: None,
98 listen_backlog: 65536,
99 }
100 }
101}
102
103pub struct ProtosocketServerConfig {
104 max_buffer_length: usize,
105 max_queued_outbound_messages: usize,
106 buffer_allocation_increment: usize,
107 socket_config: ProtosocketSocketConfig,
108}
109
110impl ProtosocketServerConfig {
111 pub fn max_buffer_length(mut self, max_buffer_length: usize) -> Self {
113 self.max_buffer_length = max_buffer_length;
114 self
115 }
116 pub fn max_queued_outbound_messages(mut self, max_queued_outbound_messages: usize) -> Self {
118 self.max_queued_outbound_messages = max_queued_outbound_messages;
119 self
120 }
121 pub fn buffer_allocation_increment(mut self, buffer_allocation_increment: usize) -> Self {
123 self.buffer_allocation_increment = buffer_allocation_increment;
124 self
125 }
126 pub fn socket_config(mut self, config: ProtosocketSocketConfig) -> Self {
128 self.socket_config = config;
129 self
130 }
131
132 pub async fn bind_tcp<Connector: ServerConnector>(
135 self,
136 address: SocketAddr,
137 connector: Connector,
138 runtime: tokio::runtime::Handle,
139 ) -> crate::Result<ProtosocketServer<Connector>> {
140 ProtosocketServer::new(address, runtime, connector, self).await
141 }
142}
143
144impl Default for ProtosocketServerConfig {
145 fn default() -> Self {
146 Self {
147 max_buffer_length: 16 * (2 << 20),
148 max_queued_outbound_messages: 128,
149 buffer_allocation_increment: 1 << 20,
150 socket_config: Default::default(),
151 }
152 }
153}
154
155impl<Connector: ServerConnector> ProtosocketServer<Connector> {
156 async fn new(
160 address: SocketAddr,
161 runtime: tokio::runtime::Handle,
162 connector: Connector,
163 config: ProtosocketServerConfig,
164 ) -> crate::Result<Self> {
165 let socket = socket2::Socket::new(
166 match address {
167 SocketAddr::V4(_) => socket2::Domain::IPV4,
168 SocketAddr::V6(_) => socket2::Domain::IPV6,
169 },
170 socket2::Type::STREAM,
171 None,
172 )?;
173
174 let mut tcp_keepalive = TcpKeepalive::new();
175 if let Some(duration) = config.socket_config.keepalive_duration {
176 tcp_keepalive = tcp_keepalive.with_time(duration);
177 }
178
179 socket.set_nonblocking(true)?;
180 socket.set_tcp_nodelay(config.socket_config.nodelay)?;
181 socket.set_tcp_keepalive(&tcp_keepalive)?;
182 socket.set_reuse_port(config.socket_config.reuse)?;
183 socket.set_reuse_address(config.socket_config.reuse)?;
184
185 socket.bind(&address.into())?;
186 socket.listen(config.socket_config.listen_backlog as c_int)?;
187
188 let listener = tokio::net::TcpListener::from_std(socket.into())?;
189 Ok(Self {
190 connector,
191 listener,
192 max_buffer_length: config.max_buffer_length,
193 max_queued_outbound_messages: config.max_queued_outbound_messages,
194 buffer_allocation_increment: config.buffer_allocation_increment,
195 runtime,
196 })
197 }
198}
199
200impl<Connector: ServerConnector> Future for ProtosocketServer<Connector> {
201 type Output = Result<(), Error>;
202
203 fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
204 loop {
205 break match self.listener.poll_accept(context) {
206 Poll::Ready(result) => match result {
207 Ok((stream, address)) => {
208 stream.set_nodelay(true)?;
209 let (outbound_submission_queue, outbound_messages) =
210 mpsc::channel(self.max_queued_outbound_messages);
211 let reactor = self
212 .connector
213 .new_reactor(outbound_submission_queue.clone(), address);
214 let stream = self.connector.connect(stream);
215 let connection: Connection<Connector::Bindings> = Connection::new(
216 stream,
217 address,
218 self.connector.deserializer(),
219 self.connector.serializer(),
220 self.max_buffer_length,
221 self.buffer_allocation_increment,
222 self.max_queued_outbound_messages,
223 outbound_messages,
224 reactor,
225 );
226 self.runtime.spawn(connection);
227 continue;
228 }
229 Err(e) => {
230 log::error!("failed to accept connection: {e:?}");
231 continue;
232 }
233 },
234 Poll::Pending => Poll::Pending,
235 };
236 }
237 }
238}