protosocket_rpc/server/socket_server.rs

1use std::future::Future;
2use std::io::Error;
3use std::pin::Pin;
4use std::task::Context;
5use std::task::Poll;
6
7use protosocket::Connection;
8use tokio::sync::mpsc;
9
10use super::connection_server::RpcConnectionServer;
11use super::rpc_submitter::RpcSubmitter;
12use super::server_traits::SocketService;
13
14/// A `SocketRpcServer` is a server future. It listens on a socket and spawns new connections,
15/// with a ConnectionService to handle each connection.
16///
17/// Protosockets use monomorphic messages: You can only have 1 kind of message per service.
18/// The expected way to work with this is to use prost and protocol buffers to encode messages.
19///
20/// The socket server hosts your SocketService.
21/// Your SocketService creates a ConnectionService for each new connection.
22/// Your ConnectionService manages one connection. It is Dropped when the connection is closed.
23pub struct SocketRpcServer<TSocketService>
24where
25    TSocketService: SocketService,
26{
27    socket_server: TSocketService,
28    listener: tokio::net::TcpListener,
29    max_buffer_length: usize,
30    max_queued_outbound_messages: usize,
31}
32
33impl<TSocketService> SocketRpcServer<TSocketService>
34where
35    TSocketService: SocketService,
36{
37    /// Construct a new `SocketRpcServer` listening on the provided address.
38    pub async fn new(
39        address: std::net::SocketAddr,
40        socket_server: TSocketService,
41    ) -> crate::Result<Self> {
42        let listener = tokio::net::TcpListener::bind(address).await?;
43        Ok(Self {
44            socket_server,
45            listener,
46            max_buffer_length: 16 * (2 << 20),
47            max_queued_outbound_messages: 128,
48        })
49    }
50
51    /// Set the maximum buffer length for connections created by this server after the setting is applied.
52    pub fn set_max_buffer_length(&mut self, max_buffer_length: usize) {
53        self.max_buffer_length = max_buffer_length;
54    }
55
56    /// Set the maximum queued outbound messages for connections created by this server after the setting is applied.
57    pub fn set_max_queued_outbound_messages(&mut self, max_queued_outbound_messages: usize) {
58        self.max_queued_outbound_messages = max_queued_outbound_messages;
59    }
60}
61
62impl<TSocketService> Future for SocketRpcServer<TSocketService>
63where
64    TSocketService: SocketService,
65{
66    type Output = Result<(), Error>;
67
68    fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
69        loop {
70            break match self.listener.poll_accept(context) {
71                Poll::Ready(result) => match result {
72                    Ok((stream, address)) => {
73                        stream.set_nodelay(true)?;
74                        let (submitter, inbound_messages) = RpcSubmitter::new();
75                        let (outbound_messages, outbound_messages_receiver) =
76                            mpsc::channel(self.max_queued_outbound_messages);
77                        let connection_service = self.socket_server.new_connection_service(address);
78                        let connection_rpc_server = RpcConnectionServer::new(
79                            connection_service,
80                            inbound_messages,
81                            outbound_messages,
82                        );
83
84                        let connection: Connection<RpcSubmitter<TSocketService>> = Connection::new(
85                            stream,
86                            address,
87                            self.socket_server.deserializer(),
88                            self.socket_server.serializer(),
89                            self.max_buffer_length,
90                            self.max_queued_outbound_messages,
91                            outbound_messages_receiver,
92                            submitter,
93                        );
94
95                        tokio::spawn(connection);
96                        tokio::spawn(connection_rpc_server);
97
98                        continue;
99                    }
100                    Err(e) => {
101                        log::error!("failed to accept connection: {e:?}");
102                        continue;
103                    }
104                },
105                Poll::Pending => Poll::Pending,
106            };
107        }
108    }
109}