// protosocket_rpc/server/socket_server.rs
use std::future::Future;
use std::io::Error;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use protosocket::Connection;
use tokio::sync::mpsc;
use super::connection_server::RpcConnectionServer;
use super::rpc_submitter::RpcSubmitter;
use super::server_traits::SocketService;
/// A TCP RPC server built around a [`SocketService`].
///
/// The server owns a bound TCP listener. It is driven as a [`Future`]: polling it accepts
/// connections and spawns the tasks that serve each one.
pub struct SocketRpcServer<TSocketService: SocketService> {
    /// Builds the per-connection service and supplies the serializer/deserializer pair.
    socket_server: TSocketService,
    /// Socket from which new connections are accepted.
    listener: tokio::net::TcpListener,
    /// Buffer-length limit handed to each `Connection`.
    max_buffer_length: usize,
    /// Capacity of each connection's outbound message channel (also handed to `Connection`).
    max_queued_outbound_messages: usize,
}
impl<TSocketService> SocketRpcServer<TSocketService>
where
    TSocketService: SocketService,
{
    /// Binds a TCP listener on `address` and creates a server that uses `socket_server`
    /// to build the service for each accepted connection.
    ///
    /// Defaults: `max_buffer_length` is `16 * (2 << 20)` bytes (32 MiB) and
    /// `max_queued_outbound_messages` is 128.
    ///
    /// # Errors
    ///
    /// Returns an error if the listener cannot be bound to `address`.
    pub async fn new(
        address: std::net::SocketAddr,
        socket_server: TSocketService,
    ) -> crate::Result<Self> {
        let listener = tokio::net::TcpListener::bind(address).await?;
        Ok(Self {
            socket_server,
            listener,
            // NOTE(review): `2 << 20` is 2 MiB, so this default is 32 MiB. If 16 MiB was
            // intended it should read `16 * (1 << 20)`. Kept unchanged to preserve behavior.
            max_buffer_length: 16 * (2 << 20),
            max_queued_outbound_messages: 128,
        })
    }

    /// Sets the buffer-length limit passed to each connection's `Connection`.
    pub fn set_max_buffer_length(&mut self, max_buffer_length: usize) {
        self.max_buffer_length = max_buffer_length;
    }

    /// Sets the capacity of each connection's outbound message queue.
    ///
    /// The value is used as the capacity of a bounded `tokio::sync::mpsc` channel.
    /// Tokio panics when such a channel is created with capacity 0, and that panic
    /// would happen while accepting a connection. Values below 1 are therefore
    /// raised to 1.
    pub fn set_max_queued_outbound_messages(&mut self, max_queued_outbound_messages: usize) {
        self.max_queued_outbound_messages = max_queued_outbound_messages.max(1);
    }
}
/// Polling the server accepts connections until the listener has none ready.
///
/// For each accepted connection this:
/// - enables `TCP_NODELAY`;
/// - builds the per-connection service;
/// - spawns two tasks on the current tokio runtime: the `Connection` that drives socket
///   I/O and the `RpcConnectionServer` that handles RPCs.
///
/// The future never completes on its own. Errors that affect a single connection, or a
/// single accept attempt, are logged and do not stop the server.
impl<TSocketService> Future for SocketRpcServer<TSocketService>
where
    TSocketService: SocketService,
{
    type Output = Result<(), Error>;

    fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            let (stream, address) = match self.listener.poll_accept(context) {
                Poll::Ready(Ok(accepted)) => accepted,
                Poll::Ready(Err(e)) => {
                    log::error!("failed to accept connection: {e:?}");
                    // Accept errors such as EMFILE tend to recur immediately. Retrying inside
                    // this loop would spin without ever yielding the executor thread, so
                    // request a re-poll and yield instead.
                    context.waker().wake_by_ref();
                    return Poll::Pending;
                }
                Poll::Pending => return Poll::Pending,
            };

            // This failure concerns only this socket (for example, the peer may have already
            // disconnected). Drop the connection rather than terminating the whole server.
            if let Err(e) = stream.set_nodelay(true) {
                log::warn!("failed to set TCP_NODELAY for {address}: {e:?}; dropping connection");
                continue;
            }

            // Inbound: the Connection hands decoded messages to `submitter`, and the RPC
            // server reads them from `inbound_messages`.
            let (submitter, inbound_messages) = RpcSubmitter::new();
            // Outbound: the RPC server writes to `outbound_messages`, and the Connection
            // drains `outbound_messages_receiver`.
            let (outbound_messages, outbound_messages_receiver) =
                mpsc::channel(self.max_queued_outbound_messages);

            let connection_service = self.socket_server.new_connection_service(address);
            let connection_rpc_server =
                RpcConnectionServer::new(connection_service, inbound_messages, outbound_messages);

            let connection: Connection<RpcSubmitter<TSocketService>> = Connection::new(
                stream,
                address,
                self.socket_server.deserializer(),
                self.socket_server.serializer(),
                self.max_buffer_length,
                self.max_queued_outbound_messages,
                outbound_messages_receiver,
                submitter,
            );

            tokio::spawn(connection);
            tokio::spawn(connection_rpc_server);
        }
    }
}