Skip to main content

protosocket_prost/
prost_client_registry.rs

1use std::future::Future;
2
3use protosocket::{Connection, MessageReactor, PooledEncoder};
4use tokio::net::TcpStream;
5
6use crate::{prost_serializer::ProstDecoder, ProstSerializer};
7
8/// A factory for creating client connections to a `protosocket` server.
9#[derive(Debug, Clone)]
10pub struct ClientRegistry<TConnector = TcpConnector> {
11    max_buffer_length: usize,
12    buffer_allocation_increment: usize,
13    max_queued_outbound_messages: usize,
14    runtime: tokio::runtime::Handle,
15    stream_connector: TConnector,
16}
17
18pub trait StreamConnector {
19    type Stream: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin + 'static;
20
21    fn connect_stream(
22        &self,
23        stream: TcpStream,
24    ) -> impl Future<Output = std::io::Result<Self::Stream>> + Send;
25}
26
27pub struct TcpConnector;
28impl StreamConnector for TcpConnector {
29    type Stream = TcpStream;
30    fn connect_stream(
31        &self,
32        stream: TcpStream,
33    ) -> impl Future<Output = std::io::Result<TcpStream>> + Send {
34        std::future::ready(Ok(stream))
35    }
36}
37
38impl<TConnector> ClientRegistry<TConnector>
39where
40    TConnector: StreamConnector,
41{
42    /// Construct a new client registry. Connections will be spawned on the provided runtime.
43    pub fn new(runtime: tokio::runtime::Handle, connector: TConnector) -> Self {
44        log::trace!("new client registry");
45        Self {
46            max_buffer_length: 4 * (1 << 20),
47            max_queued_outbound_messages: 256,
48            buffer_allocation_increment: 1 << 20,
49            runtime,
50            stream_connector: connector,
51        }
52    }
53
54    /// Sets the maximum read buffer length for connections created by this registry after
55    /// the setting is applied.
56    pub fn set_max_read_buffer_length(&mut self, max_buffer_length: usize) {
57        self.max_buffer_length = max_buffer_length;
58    }
59
60    /// Sets the maximum queued outbound messages for connections created by this registry after
61    /// the setting is applied.
62    pub fn set_max_queued_outbound_messages(&mut self, max_queued_outbound_messages: usize) {
63        self.max_queued_outbound_messages = max_queued_outbound_messages;
64    }
65
66    /// Get a new connection to a `protosocket` server.
67    pub async fn register_client<Request, Response, Reactor>(
68        &self,
69        address: impl Into<String>,
70        message_reactor: Reactor,
71    ) -> crate::Result<spillway::Sender<Reactor::LogicalOutbound>>
72    where
73        Request: prost::Message + Default + Unpin + std::fmt::Debug + 'static,
74        Response: prost::Message + Default + Unpin + std::fmt::Debug + 'static,
75        Reactor: MessageReactor<Inbound = Response, Outbound = Request> + Send,
76        Reactor::LogicalOutbound: Send,
77    {
78        let address: std::net::SocketAddr = address.into().parse()?;
79        let stream = TcpStream::connect(address)
80            .await
81            .map_err(std::sync::Arc::new)?;
82        stream.set_nodelay(true).map_err(std::sync::Arc::new)?;
83        let stream = self
84            .stream_connector
85            .connect_stream(stream)
86            .await
87            .map_err(std::sync::Arc::new)?;
88        let (outbound, outbound_messages) = spillway::channel();
89        let codec = (
90            PooledEncoder::<ProstSerializer<Request>>::default(),
91            ProstDecoder::<Response>::default(),
92        );
93        let connection = Connection::new(
94            stream,
95            codec,
96            self.max_buffer_length,
97            self.buffer_allocation_increment,
98            self.max_queued_outbound_messages,
99            outbound_messages,
100            message_reactor,
101        );
102        self.runtime.spawn(connection);
103        Ok(outbound)
104    }
105}