protosocket_server/
connection_server.rs

1use std::future::Future;
2use std::io::Error;
3use std::net::SocketAddr;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::Context;
7use std::task::Poll;
8
9use protosocket::Connection;
10use protosocket::ConnectionBindings;
11use protosocket::Serializer;
12use tokio::sync::mpsc;
13
14pub trait ServerConnector: Unpin {
15    type Bindings: ConnectionBindings;
16
17    fn serializer(&self) -> <Self::Bindings as ConnectionBindings>::Serializer;
18    fn deserializer(&self) -> <Self::Bindings as ConnectionBindings>::Deserializer;
19
20    fn new_reactor(
21        &self,
22        optional_outbound: mpsc::Sender<
23            <<Self::Bindings as ConnectionBindings>::Serializer as Serializer>::Message,
24        >,
25        address: SocketAddr,
26    ) -> <Self::Bindings as ConnectionBindings>::Reactor;
27
28    fn maximum_message_length(&self) -> usize {
29        4 * (2 << 20)
30    }
31
32    fn max_queued_outbound_messages(&self) -> usize {
33        256
34    }
35
36    fn connect(
37        &self,
38        stream: tokio::net::TcpStream,
39    ) -> <Self::Bindings as ConnectionBindings>::Stream;
40}
41
42/// A `protosocket::Connection` is an IO driver. It directly uses tokio's io wrapper of mio to poll
43/// the OS's io primitives, manages read and write buffers, and vends messages to & from connections.
44/// Connections send messages to the ConnectionServer through an mpsc channel, and they receive
45/// inbound messages via a reactor callback.
46///
47/// Protosockets are monomorphic messages: You can only have 1 kind of message per service.
48/// The expected way to work with this is to use prost and protocol buffers to encode messages.
49/// Of course you can do whatever you want, as the telnet example shows.
50///
51/// Protosocket messages are not opinionated about request & reply. If you are, you will need
52/// to implement such a thing. This allows you freely choose whether you want to send
53/// fire-&-forget messages sometimes; however it requires you to write your protocol's rules.
54/// You get an inbound iterable of <MessageIn> batches and an outbound stream of <MessageOut> per
55/// connection - you decide what those mean for you!
56///
57/// A ProtosocketServer is a future: You spawn it and it runs forever.
58pub struct ProtosocketServer<Connector: ServerConnector> {
59    connector: Connector,
60    listener: tokio::net::TcpListener,
61    max_buffer_length: usize,
62    buffer_allocation_increment: usize,
63    max_queued_outbound_messages: usize,
64    runtime: tokio::runtime::Handle,
65}
66
67impl<Connector: ServerConnector> ProtosocketServer<Connector> {
68    /// Construct a new `ProtosocketServer` listening on the provided address.
69    /// The address will be bound and listened upon with `SO_REUSEADDR` set.
70    /// The server will use the provided runtime to spawn new tcp connections as `protosocket::Connection`s.
71    pub async fn new(
72        address: std::net::SocketAddr,
73        runtime: tokio::runtime::Handle,
74        connector: Connector,
75    ) -> crate::Result<Self> {
76        let listener = tokio::net::TcpListener::bind(address)
77            .await
78            .map_err(Arc::new)?;
79        Ok(Self {
80            connector,
81            listener,
82            max_buffer_length: 16 * (2 << 20),
83            max_queued_outbound_messages: 128,
84            buffer_allocation_increment: 1 << 20,
85            runtime,
86        })
87    }
88
89    /// Set the maximum buffer length for connections created by this server after the setting is applied.
90    pub fn set_max_buffer_length(&mut self, max_buffer_length: usize) {
91        self.max_buffer_length = max_buffer_length;
92    }
93
94    /// Set the maximum queued outbound messages for connections created by this server after the setting is applied.
95    pub fn set_max_queued_outbound_messages(&mut self, max_queued_outbound_messages: usize) {
96        self.max_queued_outbound_messages = max_queued_outbound_messages;
97    }
98}
99
100impl<Connector: ServerConnector> Future for ProtosocketServer<Connector> {
101    type Output = Result<(), Error>;
102
103    fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
104        loop {
105            break match self.listener.poll_accept(context) {
106                Poll::Ready(result) => match result {
107                    Ok((stream, address)) => {
108                        stream.set_nodelay(true)?;
109                        let (outbound_submission_queue, outbound_messages) =
110                            mpsc::channel(self.max_queued_outbound_messages);
111                        let reactor = self
112                            .connector
113                            .new_reactor(outbound_submission_queue.clone(), address);
114                        let stream = self.connector.connect(stream);
115                        let connection: Connection<Connector::Bindings> = Connection::new(
116                            stream,
117                            address,
118                            self.connector.deserializer(),
119                            self.connector.serializer(),
120                            self.max_buffer_length,
121                            self.buffer_allocation_increment,
122                            self.max_queued_outbound_messages,
123                            outbound_messages,
124                            reactor,
125                        );
126                        self.runtime.spawn(connection);
127                        continue;
128                    }
129                    Err(e) => {
130                        log::error!("failed to accept connection: {e:?}");
131                        continue;
132                    }
133                },
134                Poll::Pending => Poll::Pending,
135            };
136        }
137    }
138}