protosocket_prost/
prost_client_registry.rs

1use std::future::Future;
2
3use protosocket::{Connection, MessageReactor};
4use tokio::{net::TcpStream, sync::mpsc};
5
6use crate::{ProstClientConnectionBindings, ProstSerializer};
7
8/// A factory for creating client connections to a `protosocket` server.
9#[derive(Debug, Clone)]
10pub struct ClientRegistry<TConnector = TcpConnector> {
11    max_buffer_length: usize,
12    buffer_allocation_increment: usize,
13    max_queued_outbound_messages: usize,
14    runtime: tokio::runtime::Handle,
15    stream_connector: TConnector,
16}
17
18pub trait StreamConnector {
19    type Stream: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin + 'static;
20
21    fn connect_stream(
22        &self,
23        stream: TcpStream,
24    ) -> impl Future<Output = std::io::Result<Self::Stream>> + Send;
25}
26
27pub struct TcpConnector;
28impl StreamConnector for TcpConnector {
29    type Stream = TcpStream;
30    fn connect_stream(
31        &self,
32        stream: TcpStream,
33    ) -> impl Future<Output = std::io::Result<TcpStream>> + Send {
34        std::future::ready(Ok(stream))
35    }
36}
37
38impl<TConnector> ClientRegistry<TConnector>
39where
40    TConnector: StreamConnector,
41{
42    /// Construct a new client registry. Connections will be spawned on the provided runtime.
43    pub fn new(runtime: tokio::runtime::Handle, connector: TConnector) -> Self {
44        log::trace!("new client registry");
45        Self {
46            max_buffer_length: 4 * (1 << 20),
47            max_queued_outbound_messages: 256,
48            buffer_allocation_increment: 1 << 20,
49            runtime,
50            stream_connector: connector,
51        }
52    }
53
54    /// Sets the maximum read buffer length for connections created by this registry after
55    /// the setting is applied.
56    pub fn set_max_read_buffer_length(&mut self, max_buffer_length: usize) {
57        self.max_buffer_length = max_buffer_length;
58    }
59
60    /// Sets the maximum queued outbound messages for connections created by this registry after
61    /// the setting is applied.
62    pub fn set_max_queued_outbound_messages(&mut self, max_queued_outbound_messages: usize) {
63        self.max_queued_outbound_messages = max_queued_outbound_messages;
64    }
65
66    /// Get a new connection to a `protosocket` server.
67    pub async fn register_client<Request, Response, Reactor>(
68        &self,
69        address: impl Into<String>,
70        message_reactor: Reactor,
71    ) -> crate::Result<mpsc::Sender<Request>>
72    where
73        Request: prost::Message + Default + Unpin + 'static,
74        Response: prost::Message + Default + Unpin + 'static,
75        Reactor: MessageReactor<Inbound = Response>,
76    {
77        let address = address.into().parse()?;
78        let stream = TcpStream::connect(address)
79            .await
80            .map_err(std::sync::Arc::new)?;
81        stream.set_nodelay(true).map_err(std::sync::Arc::new)?;
82        let stream = self
83            .stream_connector
84            .connect_stream(stream)
85            .await
86            .map_err(std::sync::Arc::new)?;
87        let (outbound, outbound_messages) = mpsc::channel(self.max_queued_outbound_messages);
88        let connection = Connection::<
89            ProstClientConnectionBindings<Request, Response, Reactor, TConnector::Stream>,
90        >::new(
91            stream,
92            address,
93            ProstSerializer::default(),
94            ProstSerializer::default(),
95            self.max_buffer_length,
96            self.buffer_allocation_increment,
97            self.max_queued_outbound_messages,
98            outbound_messages,
99            message_reactor,
100        );
101        self.runtime.spawn(connection);
102        Ok(outbound)
103    }
104}