use protosocket::Codec;
use protosocket::Connection;
use protosocket::Decoder;
use protosocket::Encoder;
use protosocket::MessageReactor;
use protosocket::SocketListener;
use protosocket::SocketResult;
use protosocket::TcpSocketListener;
use std::future::Future;
use std::io::Error;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
/// Supplies the per-connection pieces a [`ProtosocketServer`] needs (codec and reactor) and
/// decides how each accepted [`Connection`] is driven.
///
/// NOTE(review): the `Unpin` supertrait appears to exist because the server stores a connector
/// and is polled through `Pin<&mut Self>`; confirm whether it is still required given the
/// explicit `Unpin` impl on `ProtosocketServer`.
pub trait ServerConnector: Unpin {
    /// Codec used for each connection. It must decode the reactor's `Inbound` message type and
    /// encode the reactor's `Outbound` message type.
    type Codec: Codec
        + Decoder<Message = <Self::Reactor as MessageReactor>::Inbound>
        + Encoder<Message = <Self::Reactor as MessageReactor>::Outbound>;
    /// Message handler attached to each accepted connection.
    type Reactor: MessageReactor;
    /// Listener the server polls for newly accepted streams.
    type SocketListener: SocketListener;
    /// Returns a codec for one new connection; the server calls this once per accepted stream.
    fn codec(&self) -> Self::Codec;
    /// Creates the reactor for one newly accepted stream. The server calls this before it
    /// builds the `Connection`.
    ///
    /// `optional_outbound` is a sender whose paired receiver the server passes to
    /// `Connection::new` for this same connection. Presumably the reactor can push
    /// `LogicalOutbound` messages through it (confirm against `Connection`).
    ///
    /// `_connection` is the accepted stream, borrowed before it is moved into the `Connection`.
    fn new_reactor(
        &self,
        optional_outbound: spillway::Sender<<Self::Reactor as MessageReactor>::LogicalOutbound>,
        _connection: &<Self::SocketListener as SocketListener>::Stream,
    ) -> Self::Reactor;
    /// Takes ownership of a fully built connection.
    ///
    /// The server does not touch the connection after this call, so the implementation is
    /// responsible for driving it (e.g. spawning it onto an executor — TODO confirm the
    /// intended pattern).
    fn spawn_connection(
        &self,
        connection: Connection<
            <Self::SocketListener as SocketListener>::Stream,
            Self::Codec,
            Self::Reactor,
        >,
    );
}
/// A future that accepts connections from `listener` and hands each one, fully assembled, to
/// `connector`.
///
/// Build it with [`ProtosocketServerConfig::bind_tcp`] or
/// [`ProtosocketServerConfig::build_server`], then poll or await it. It resolves to `Ok(())`
/// when the listener reports `SocketResult::Disconnect`.
pub struct ProtosocketServer<Connector: ServerConnector> {
    /// Supplies codecs and reactors, and takes ownership of each accepted connection.
    connector: Connector,
    /// Source of accepted streams.
    listener: Connector::SocketListener,
    /// Passed unchanged to `Connection::new` for every accepted connection.
    max_buffer_length: usize,
    /// Passed unchanged to `Connection::new` for every accepted connection.
    buffer_allocation_increment: usize,
    /// Passed unchanged to `Connection::new` for every accepted connection.
    max_queued_outbound_messages: usize,
}
/// Socket-level settings carried by [`ProtosocketServerConfig`].
///
/// Start from [`Default`] and adjust with the builder methods (`nodelay`, `reuse`,
/// `keepalive_duration`, `listen_backlog`).
///
/// The common traits are derived so the settings can be logged, copied between servers, and
/// compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProtosocketSocketConfig {
    /// Requested `nodelay` setting.
    ///
    /// NOTE(review): `ProtosocketServerConfig::bind_tcp` does not pass this value to
    /// `TcpSocketListener::listen`; confirm where (or whether) it takes effect.
    nodelay: bool,
    /// Requested address-reuse setting.
    ///
    /// NOTE(review): like `nodelay`, this is not passed to `TcpSocketListener::listen` by
    /// `bind_tcp`; confirm where (or whether) it takes effect.
    reuse: bool,
    /// Keepalive duration given to `TcpSocketListener::listen` by `bind_tcp`. `None` unless it
    /// is set through the builder.
    keepalive_duration: Option<std::time::Duration>,
    /// Listen backlog given to `TcpSocketListener::listen` by `bind_tcp`.
    listen_backlog: u32,
}
impl ProtosocketSocketConfig {
    /// Returns this configuration with `nodelay` replaced; every other field is kept.
    pub fn nodelay(self, nodelay: bool) -> Self {
        Self { nodelay, ..self }
    }

    /// Returns this configuration with `reuse` replaced; every other field is kept.
    pub fn reuse(self, reuse: bool) -> Self {
        Self { reuse, ..self }
    }

    /// Returns this configuration with the keepalive duration set to `Some(keepalive_duration)`.
    pub fn keepalive_duration(self, keepalive_duration: std::time::Duration) -> Self {
        Self {
            keepalive_duration: Some(keepalive_duration),
            ..self
        }
    }

    /// Returns this configuration with the listen backlog replaced by `backlog`.
    pub fn listen_backlog(self, backlog: u32) -> Self {
        Self {
            listen_backlog: backlog,
            ..self
        }
    }
}
impl Default for ProtosocketSocketConfig {
    /// Default settings: `nodelay` and `reuse` on, no keepalive duration, listen backlog 65536.
    fn default() -> Self {
        const DEFAULT_LISTEN_BACKLOG: u32 = 1 << 16;
        Self {
            listen_backlog: DEFAULT_LISTEN_BACKLOG,
            keepalive_duration: None,
            reuse: true,
            nodelay: true,
        }
    }
}
/// Settings for a [`ProtosocketServer`]: per-connection buffer and queue limits plus socket
/// settings.
///
/// Start from [`Default`], adjust with the builder methods, then call `bind_tcp` or
/// `build_server`.
pub struct ProtosocketServerConfig {
    /// Copied into the server and passed to `Connection::new` for each connection.
    /// Defaults to 33_554_432.
    max_buffer_length: usize,
    /// Copied into the server and passed to `Connection::new` for each connection.
    /// Defaults to 128.
    max_queued_outbound_messages: usize,
    /// Copied into the server and passed to `Connection::new` for each connection.
    /// Defaults to `1 << 20`.
    buffer_allocation_increment: usize,
    /// Socket settings. `bind_tcp` reads `listen_backlog` and `keepalive_duration` from it.
    socket_config: ProtosocketSocketConfig,
}
impl ProtosocketServerConfig {
    /// Returns this config with `max_buffer_length` replaced. The value is forwarded to
    /// `Connection::new` for every accepted connection.
    pub fn max_buffer_length(self, max_buffer_length: usize) -> Self {
        Self {
            max_buffer_length,
            ..self
        }
    }

    /// Returns this config with `max_queued_outbound_messages` replaced. The value is forwarded
    /// to `Connection::new` for every accepted connection.
    pub fn max_queued_outbound_messages(self, max_queued_outbound_messages: usize) -> Self {
        Self {
            max_queued_outbound_messages,
            ..self
        }
    }

    /// Returns this config with `buffer_allocation_increment` replaced. The value is forwarded
    /// to `Connection::new` for every accepted connection.
    pub fn buffer_allocation_increment(self, buffer_allocation_increment: usize) -> Self {
        Self {
            buffer_allocation_increment,
            ..self
        }
    }

    /// Returns this config with the socket settings replaced by `config`.
    pub fn socket_config(self, config: ProtosocketSocketConfig) -> Self {
        Self {
            socket_config: config,
            ..self
        }
    }

    /// Binds a TCP listener on `address` and wraps it, with `connector`, in a server future.
    ///
    /// The listen backlog and keepalive duration are taken from `socket_config`.
    /// NOTE(review): `socket_config.nodelay` and `socket_config.reuse` are not passed to
    /// `TcpSocketListener::listen` here; confirm whether they are expected to take effect.
    ///
    /// # Errors
    ///
    /// Propagates (via `?`) any error returned by `TcpSocketListener::listen`.
    pub fn bind_tcp<Connector: ServerConnector<SocketListener = TcpSocketListener>>(
        self,
        address: SocketAddr,
        connector: Connector,
    ) -> crate::Result<ProtosocketServer<Connector>> {
        let backlog = self.socket_config.listen_backlog;
        let keepalive = self.socket_config.keepalive_duration;
        let listener = TcpSocketListener::listen(address, backlog, keepalive)?;
        Ok(ProtosocketServer::new(listener, connector, self))
    }

    /// Wraps an already-constructed `listener` and `connector` in a server future.
    ///
    /// This path never fails; it returns `crate::Result` for symmetry with `bind_tcp`.
    pub fn build_server<
        Connector: ServerConnector<SocketListener = Listener>,
        Listener: SocketListener,
    >(
        self,
        listener: Listener,
        connector: Connector,
    ) -> crate::Result<ProtosocketServer<Connector>> {
        let server = ProtosocketServer::new(listener, connector, self);
        Ok(server)
    }
}
impl Default for ProtosocketServerConfig {
    /// Default limits: `max_buffer_length` 32 MiB, `buffer_allocation_increment` 1 MiB,
    /// 128 queued outbound messages, and the default socket settings.
    fn default() -> Self {
        const MIB: usize = 1 << 20;
        Self {
            socket_config: Default::default(),
            buffer_allocation_increment: MIB,
            max_queued_outbound_messages: 128,
            // Same value as the historical `16 * (2 << 20)`, i.e. 32 MiB.
            // NOTE(review): `2 << 20` is 2 MiB, not 1 MiB; if 16 MiB was intended, this
            // default is double that. Kept as-is to preserve behavior.
            max_buffer_length: 32 * MIB,
        }
    }
}
impl<Connector: ServerConnector> ProtosocketServer<Connector> {
fn new(
listener: Connector::SocketListener,
connector: Connector,
config: ProtosocketServerConfig,
) -> Self {
Self {
connector,
listener,
max_buffer_length: config.max_buffer_length,
max_queued_outbound_messages: config.max_queued_outbound_messages,
buffer_allocation_increment: config.buffer_allocation_increment,
}
}
}
// `poll` receives `Pin<&mut Self>` and mutates fields through it (`self.listener.poll_accept`,
// `self.connector...`), which requires `Self: Unpin`. The auto-derived `Unpin` would also
// require `Connector::SocketListener: Unpin`, which is not bounded here (unless the
// `SocketListener` trait itself requires it — not visible in this file). Implementing `Unpin`
// is safe: no field is ever accessed through a pinned projection in this file.
impl<Connector: ServerConnector> Unpin for ProtosocketServer<Connector> {}
impl<Connector: ServerConnector> Future for ProtosocketServer<Connector> {
    type Output = Result<(), Error>;

    /// Accepts every connection the listener has ready.
    ///
    /// For each accepted stream, in order:
    /// - create a fresh outbound channel;
    /// - ask the connector for a reactor, then a codec;
    /// - combine them into a [`Connection`];
    /// - hand the connection to the connector.
    ///
    /// Returns `Poll::Pending` once the listener has nothing ready, and `Poll::Ready(Ok(()))`
    /// when it reports [`SocketResult::Disconnect`]. This implementation never yields `Err`.
    fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            let stream = match self.listener.poll_accept(context) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(SocketResult::Disconnect) => return Poll::Ready(Ok(())),
                Poll::Ready(SocketResult::Stream(stream)) => stream,
            };

            let (sender, receiver) = spillway::channel();
            // The reactor gets its own clone of the sender; this local copy is dropped at the
            // end of the iteration, after the connection has been handed off.
            let reactor = self.connector.new_reactor(sender.clone(), &stream);
            let codec = self.connector.codec();
            let connection = Connection::new(
                stream,
                codec,
                self.max_buffer_length,
                self.buffer_allocation_increment,
                self.max_queued_outbound_messages,
                receiver,
                reactor,
            );
            self.connector.spawn_connection(connection);
        }
    }
}