protosocket_server/
connection_server.rs

1use std::future::Future;
2use std::io::Error;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::Context;
6use std::task::Poll;
7
8use protosocket::Connection;
9use protosocket::ConnectionBindings;
10use protosocket::Serializer;
11use tokio::sync::mpsc;
12
13pub trait ServerConnector: Unpin {
14    type Bindings: ConnectionBindings;
15
16    fn serializer(&self) -> <Self::Bindings as ConnectionBindings>::Serializer;
17    fn deserializer(&self) -> <Self::Bindings as ConnectionBindings>::Deserializer;
18
19    fn new_reactor(
20        &self,
21        optional_outbound: mpsc::Sender<
22            <<Self::Bindings as ConnectionBindings>::Serializer as Serializer>::Message,
23        >,
24    ) -> <Self::Bindings as ConnectionBindings>::Reactor;
25
26    fn maximum_message_length(&self) -> usize {
27        4 * (2 << 20)
28    }
29
30    fn max_queued_outbound_messages(&self) -> usize {
31        256
32    }
33
34    fn connect(
35        &self,
36        stream: tokio::net::TcpStream,
37    ) -> <Self::Bindings as ConnectionBindings>::Stream;
38}
39
40/// A `protosocket::Connection` is an IO driver. It directly uses tokio's io wrapper of mio to poll
41/// the OS's io primitives, manages read and write buffers, and vends messages to & from connections.
42/// Connections send messages to the ConnectionServer through an mpsc channel, and they receive
43/// inbound messages via a reactor callback.
44///
45/// Protosockets are monomorphic messages: You can only have 1 kind of message per service.
46/// The expected way to work with this is to use prost and protocol buffers to encode messages.
47/// Of course you can do whatever you want, as the telnet example shows.
48///
49/// Protosocket messages are not opinionated about request & reply. If you are, you will need
50/// to implement such a thing. This allows you freely choose whether you want to send
51/// fire-&-forget messages sometimes; however it requires you to write your protocol's rules.
52/// You get an inbound iterable of <MessageIn> batches and an outbound stream of <MessageOut> per
53/// connection - you decide what those mean for you!
54///
55/// A ProtosocketServer is a future: You spawn it and it runs forever.
56pub struct ProtosocketServer<Connector: ServerConnector> {
57    connector: Connector,
58    listener: tokio::net::TcpListener,
59    max_buffer_length: usize,
60    buffer_allocation_increment: usize,
61    max_queued_outbound_messages: usize,
62    runtime: tokio::runtime::Handle,
63}
64
65impl<Connector: ServerConnector> ProtosocketServer<Connector> {
66    /// Construct a new `ProtosocketServer` listening on the provided address.
67    /// The address will be bound and listened upon with `SO_REUSEADDR` set.
68    /// The server will use the provided runtime to spawn new tcp connections as `protosocket::Connection`s.
69    pub async fn new(
70        address: std::net::SocketAddr,
71        runtime: tokio::runtime::Handle,
72        connector: Connector,
73    ) -> crate::Result<Self> {
74        let listener = tokio::net::TcpListener::bind(address)
75            .await
76            .map_err(Arc::new)?;
77        Ok(Self {
78            connector,
79            listener,
80            max_buffer_length: 16 * (2 << 20),
81            max_queued_outbound_messages: 128,
82            buffer_allocation_increment: 1 << 20,
83            runtime,
84        })
85    }
86
87    /// Set the maximum buffer length for connections created by this server after the setting is applied.
88    pub fn set_max_buffer_length(&mut self, max_buffer_length: usize) {
89        self.max_buffer_length = max_buffer_length;
90    }
91
92    /// Set the maximum queued outbound messages for connections created by this server after the setting is applied.
93    pub fn set_max_queued_outbound_messages(&mut self, max_queued_outbound_messages: usize) {
94        self.max_queued_outbound_messages = max_queued_outbound_messages;
95    }
96}
97
98impl<Connector: ServerConnector> Future for ProtosocketServer<Connector> {
99    type Output = Result<(), Error>;
100
101    fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
102        loop {
103            break match self.listener.poll_accept(context) {
104                Poll::Ready(result) => match result {
105                    Ok((stream, address)) => {
106                        stream.set_nodelay(true)?;
107                        let (outbound_submission_queue, outbound_messages) =
108                            mpsc::channel(self.max_queued_outbound_messages);
109                        let reactor = self
110                            .connector
111                            .new_reactor(outbound_submission_queue.clone());
112                        let stream = self.connector.connect(stream);
113                        let connection: Connection<Connector::Bindings> = Connection::new(
114                            stream,
115                            address,
116                            self.connector.deserializer(),
117                            self.connector.serializer(),
118                            self.max_buffer_length,
119                            self.buffer_allocation_increment,
120                            self.max_queued_outbound_messages,
121                            outbound_messages,
122                            reactor,
123                        );
124                        self.runtime.spawn(connection);
125                        continue;
126                    }
127                    Err(e) => {
128                        log::error!("failed to accept connection: {e:?}");
129                        continue;
130                    }
131                },
132                Poll::Pending => Poll::Pending,
133            };
134        }
135    }
136}