protosocket_prost/
prost_client_registry.rs1use std::future::Future;
2
3use protosocket::{Connection, MessageReactor};
4use tokio::{net::TcpStream, sync::mpsc};
5
6use crate::{ProstClientConnectionBindings, ProstSerializer};
7
8#[derive(Debug, Clone)]
10pub struct ClientRegistry<TConnector = TcpConnector> {
11 max_buffer_length: usize,
12 buffer_allocation_increment: usize,
13 max_queued_outbound_messages: usize,
14 runtime: tokio::runtime::Handle,
15 stream_connector: TConnector,
16}
17
18pub trait StreamConnector {
19 type Stream: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin + 'static;
20
21 fn connect_stream(
22 &self,
23 stream: TcpStream,
24 ) -> impl Future<Output = std::io::Result<Self::Stream>> + Send;
25}
26
27pub struct TcpConnector;
28impl StreamConnector for TcpConnector {
29 type Stream = TcpStream;
30 fn connect_stream(
31 &self,
32 stream: TcpStream,
33 ) -> impl Future<Output = std::io::Result<TcpStream>> + Send {
34 std::future::ready(Ok(stream))
35 }
36}
37
38impl<TConnector> ClientRegistry<TConnector>
39where
40 TConnector: StreamConnector,
41{
42 pub fn new(runtime: tokio::runtime::Handle, connector: TConnector) -> Self {
44 log::trace!("new client registry");
45 Self {
46 max_buffer_length: 4 * (1 << 20),
47 max_queued_outbound_messages: 256,
48 buffer_allocation_increment: 1 << 20,
49 runtime,
50 stream_connector: connector,
51 }
52 }
53
54 pub fn set_max_read_buffer_length(&mut self, max_buffer_length: usize) {
57 self.max_buffer_length = max_buffer_length;
58 }
59
60 pub fn set_max_queued_outbound_messages(&mut self, max_queued_outbound_messages: usize) {
63 self.max_queued_outbound_messages = max_queued_outbound_messages;
64 }
65
66 pub async fn register_client<Request, Response, Reactor>(
68 &self,
69 address: impl Into<String>,
70 message_reactor: Reactor,
71 ) -> crate::Result<mpsc::Sender<Request>>
72 where
73 Request: prost::Message + Default + Unpin + 'static,
74 Response: prost::Message + Default + Unpin + 'static,
75 Reactor: MessageReactor<Inbound = Response>,
76 {
77 let address = address.into().parse()?;
78 let stream = TcpStream::connect(address)
79 .await
80 .map_err(std::sync::Arc::new)?;
81 stream.set_nodelay(true).map_err(std::sync::Arc::new)?;
82 let stream = self
83 .stream_connector
84 .connect_stream(stream)
85 .await
86 .map_err(std::sync::Arc::new)?;
87 let (outbound, outbound_messages) = mpsc::channel(self.max_queued_outbound_messages);
88 let connection = Connection::<
89 ProstClientConnectionBindings<Request, Response, Reactor, TConnector::Stream>,
90 >::new(
91 stream,
92 address,
93 ProstSerializer::default(),
94 ProstSerializer::default(),
95 self.max_buffer_length,
96 self.buffer_allocation_increment,
97 self.max_queued_outbound_messages,
98 outbound_messages,
99 message_reactor,
100 );
101 self.runtime.spawn(connection);
102 Ok(outbound)
103 }
104}