protosocket_rpc/server/
socket_server.rs

1use protosocket::Connection;
2use socket2::TcpKeepalive;
3use std::ffi::c_int;
4use std::future::Future;
5use std::io::Error;
6use std::pin::Pin;
7use std::task::Context;
8use std::task::Poll;
9use std::time::Duration;
10use tokio::sync::mpsc;
11
12use super::connection_server::RpcConnectionServer;
13use super::rpc_submitter::RpcSubmitter;
14use super::server_traits::SocketService;
15
16/// A `SocketRpcServer` is a server future. It listens on a socket and spawns new connections,
17/// with a ConnectionService to handle each connection.
18///
19/// Protosockets use monomorphic messages: You can only have 1 kind of message per service.
20/// The expected way to work with this is to use prost and protocol buffers to encode messages.
21///
22/// The socket server hosts your SocketService.
23/// Your SocketService creates a ConnectionService for each new connection.
24/// Your ConnectionService manages one connection. It is Dropped when the connection is closed.
25pub struct SocketRpcServer<TSocketService>
26where
27    TSocketService: SocketService,
28{
29    socket_server: TSocketService,
30    listener: tokio::net::TcpListener,
31    max_buffer_length: usize,
32    buffer_allocation_increment: usize,
33    max_queued_outbound_messages: usize,
34}
35
36impl<TSocketService> SocketRpcServer<TSocketService>
37where
38    TSocketService: SocketService,
39{
40    /// Construct a new `SocketRpcServer` listening on the provided address.
41    pub async fn new(
42        address: std::net::SocketAddr,
43        socket_server: TSocketService,
44        max_buffer_length: usize,
45        buffer_allocation_increment: usize,
46        max_queued_outbound_messages: usize,
47        listen_backlog: u32,
48        tcp_keepalive_duration: Option<Duration>,
49    ) -> crate::Result<Self> {
50        let socket = socket2::Socket::new(
51            match address {
52                std::net::SocketAddr::V4(_) => socket2::Domain::IPV4,
53                std::net::SocketAddr::V6(_) => socket2::Domain::IPV6,
54            },
55            socket2::Type::STREAM,
56            None,
57        )?;
58
59        let mut tcp_keepalive = TcpKeepalive::new();
60        if let Some(duration) = tcp_keepalive_duration {
61            tcp_keepalive = tcp_keepalive.with_time(duration);
62        }
63
64        socket.set_nonblocking(true)?;
65        socket.set_tcp_nodelay(true)?;
66        socket.set_tcp_keepalive(&tcp_keepalive)?;
67        socket.set_reuse_port(true)?;
68        socket.set_reuse_address(true)?;
69
70        socket.bind(&address.into())?;
71        socket.listen(listen_backlog as c_int)?;
72
73        let listener = tokio::net::TcpListener::from_std(socket.into())?;
74        Ok(Self {
75            socket_server,
76            listener,
77            max_buffer_length,
78            buffer_allocation_increment,
79            max_queued_outbound_messages,
80        })
81    }
82
83    /// Set the maximum buffer length for connections created by this server after the setting is applied.
84    pub fn set_max_buffer_length(&mut self, max_buffer_length: usize) {
85        self.max_buffer_length = max_buffer_length;
86    }
87
88    /// Set the maximum queued outbound messages for connections created by this server after the setting is applied.
89    pub fn set_max_queued_outbound_messages(&mut self, max_queued_outbound_messages: usize) {
90        self.max_queued_outbound_messages = max_queued_outbound_messages;
91    }
92}
93
94impl<TSocketService> Future for SocketRpcServer<TSocketService>
95where
96    TSocketService: SocketService,
97{
98    type Output = Result<(), Error>;
99
100    fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
101        loop {
102            break match self.listener.poll_accept(context) {
103                Poll::Ready(result) => match result {
104                    Ok((stream, address)) => {
105                        stream.set_nodelay(true)?;
106                        let (submitter, inbound_messages) = RpcSubmitter::new();
107                        let (outbound_messages, outbound_messages_receiver) =
108                            mpsc::channel(self.max_queued_outbound_messages);
109                        let connection_service = self.socket_server.new_connection_service(address);
110                        let connection_rpc_server = RpcConnectionServer::new(
111                            connection_service,
112                            inbound_messages,
113                            outbound_messages,
114                        );
115                        let deserializer = self.socket_server.deserializer();
116                        let serializer = self.socket_server.serializer();
117                        let max_buffer_length = self.max_buffer_length;
118                        let max_queued_outbound_messages = self.max_queued_outbound_messages;
119                        let buffer_allocation_increment = self.buffer_allocation_increment;
120
121                        let stream_future = self.socket_server.accept_stream(stream);
122
123                        tokio::spawn(async move {
124                            match stream_future.await {
125                                Ok(stream) => {
126                                    let connection: Connection<RpcSubmitter<TSocketService>> =
127                                        Connection::new(
128                                            stream,
129                                            address,
130                                            deserializer,
131                                            serializer,
132                                            max_buffer_length,
133                                            buffer_allocation_increment,
134                                            max_queued_outbound_messages,
135                                            outbound_messages_receiver,
136                                            submitter,
137                                        );
138                                    tokio::spawn(connection);
139                                    tokio::spawn(connection_rpc_server);
140                                }
141                                Err(e) => {
142                                    log::error!("failed to connect stream: {e:?}");
143                                }
144                            }
145                        });
146
147                        continue;
148                    }
149                    Err(e) => {
150                        log::error!("failed to accept connection: {e:?}");
151                        continue;
152                    }
153                },
154                Poll::Pending => Poll::Pending,
155            };
156        }
157    }
158}