protosocket_server/connection_server.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
use std::future::Future;
use std::io::Error;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use protosocket::Connection;
use protosocket::ConnectionBindings;
use protosocket::Serializer;
use tokio::sync::mpsc;
/// Supplies the per-connection pieces a [`ProtosocketServer`] needs.
///
/// For every TCP connection it accepts, `ProtosocketServer` asks its connector for a
/// serializer, a deserializer and a reactor, and wires them into a `protosocket::Connection`.
pub trait ServerConnector: Unpin {
    /// The protosocket type bindings (serializer, deserializer and reactor) used by connections.
    type Bindings: ConnectionBindings;

    /// Creates the serializer for a newly accepted connection.
    fn serializer(&self) -> <Self::Bindings as ConnectionBindings>::Serializer;

    /// Creates the deserializer for a newly accepted connection.
    fn deserializer(&self) -> <Self::Bindings as ConnectionBindings>::Deserializer;

    /// Creates the reactor for a newly accepted connection.
    ///
    /// `optional_outbound` is the sending half of that connection's bounded outbound queue;
    /// the connection itself holds the receiving half. The reactor may keep the sender to
    /// queue outbound messages, or drop it.
    fn new_reactor(
        &self,
        optional_outbound: mpsc::Sender<
            <<Self::Bindings as ConnectionBindings>::Serializer as Serializer>::Message,
        >,
    ) -> <Self::Bindings as ConnectionBindings>::Reactor;

    /// Maximum message length (presumably in bytes). Defaults to 8 MiB.
    ///
    /// NOTE(review): `ProtosocketServer` in this file does not call this method.
    fn maximum_message_length(&self) -> usize {
        8 << 20
    }

    /// Maximum number of outbound messages that may be queued per connection. Defaults to 256.
    ///
    /// NOTE(review): `ProtosocketServer` in this file does not call this method; it uses its
    /// own setting (initially 128) instead.
    fn max_queued_outbound_messages(&self) -> usize {
        256
    }
}
/// A TCP server that accepts connections and runs each one as a `protosocket::Connection`.
///
/// A `protosocket::Connection` is an IO driver. It directly uses tokio's io wrapper of mio to poll
/// the OS's io primitives, manages read and write buffers, and vends messages to & from connections.
/// For each accepted connection the server creates a bounded mpsc channel: the sending half is
/// given to the connection's reactor, which may use it to queue outbound messages, and the
/// receiving half is given to the `Connection`. Inbound messages are delivered to the reactor
/// via a callback.
///
/// Protosocket messages are monomorphic: you can only have one message type per direction per
/// service. The expected way to work with this is to use prost and protocol buffers to encode
/// messages. Of course you can do whatever you want, as the telnet example shows.
///
/// Protosocket messages are not opinionated about request & reply. If you need that, you will
/// have to implement it yourself. This allows you to freely choose whether you want to send
/// fire-&-forget messages sometimes; however, it requires you to write your protocol's rules.
/// You get an inbound iterable of `MessageIn` batches and an outbound stream of `MessageOut` per
/// connection - you decide what those mean for you!
///
/// A `ProtosocketServer` is a future: you spawn it and it keeps accepting connections.
pub struct ProtosocketServer<Connector: ServerConnector> {
    /// Supplies the serializer, deserializer and reactor for each accepted connection.
    connector: Connector,
    /// Listener that new TCP connections are accepted from.
    listener: tokio::net::TcpListener,
    /// Buffer limit passed to `Connection::new` for each newly accepted connection.
    max_buffer_length: usize,
    /// Capacity of each connection's outbound mpsc channel; also passed to `Connection::new`.
    max_queued_outbound_messages: usize,
    /// Runtime onto which accepted connections are spawned.
    runtime: tokio::runtime::Handle,
}
impl<Connector: ServerConnector> ProtosocketServer<Connector> {
    /// Construct a new `ProtosocketServer` listening on the provided address.
    ///
    /// The address will be bound and listened upon with `SO_REUSEADDR` set.
    /// The server will use the provided runtime to spawn new tcp connections as
    /// `protosocket::Connection`s.
    ///
    /// New connections start with a maximum buffer length of 32 MiB and an outbound queue of
    /// 128 messages; use [`Self::set_max_buffer_length`] and
    /// [`Self::set_max_queued_outbound_messages`] to change them.
    ///
    /// # Errors
    ///
    /// Returns an error if binding the listener to `address` fails.
    pub async fn new(
        address: std::net::SocketAddr,
        runtime: tokio::runtime::Handle,
        connector: Connector,
    ) -> crate::Result<Self> {
        let listener = tokio::net::TcpListener::bind(address)
            .await
            .map_err(Arc::new)?;
        Ok(Self {
            connector,
            listener,
            // 16 * (2 MiB) = 32 MiB.
            max_buffer_length: 16 * (2 << 20),
            // NOTE(review): `connector.max_queued_outbound_messages()` is not consulted here.
            max_queued_outbound_messages: 128,
            runtime,
        })
    }

    /// Set the maximum buffer length for connections created by this server after the setting
    /// is applied. Connections that were already accepted are unaffected.
    pub fn set_max_buffer_length(&mut self, max_buffer_length: usize) {
        self.max_buffer_length = max_buffer_length;
    }

    /// Set the maximum queued outbound messages for connections created by this server after
    /// the setting is applied. Connections that were already accepted are unaffected.
    ///
    /// This value is the capacity of each new connection's outbound mpsc channel.
    ///
    /// # Panics
    ///
    /// Panics if `max_queued_outbound_messages` is 0. A bounded tokio mpsc channel cannot have
    /// zero capacity, so without this check the server future would panic on the next accept.
    pub fn set_max_queued_outbound_messages(&mut self, max_queued_outbound_messages: usize) {
        assert!(
            max_queued_outbound_messages > 0,
            "max_queued_outbound_messages must be greater than 0"
        );
        self.max_queued_outbound_messages = max_queued_outbound_messages;
    }
}
impl<Connector: ServerConnector> Future for ProtosocketServer<Connector> {
type Output = Result<(), Error>;
fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
loop {
break match self.listener.poll_accept(context) {
Poll::Ready(result) => match result {
Ok((stream, address)) => {
stream.set_nodelay(true)?;
let (outbound_submission_queue, outbound_messages) =
mpsc::channel(self.max_queued_outbound_messages);
let reactor = self
.connector
.new_reactor(outbound_submission_queue.clone());
let connection: Connection<Connector::Bindings> = Connection::new(
stream,
address,
self.connector.deserializer(),
self.connector.serializer(),
self.max_buffer_length,
self.max_queued_outbound_messages,
outbound_messages,
reactor,
);
self.runtime.spawn(connection);
continue;
}
Err(e) => {
log::error!("failed to accept connection: {e:?}");
continue;
}
},
Poll::Pending => Poll::Pending,
};
}
}
}