protosocket_rpc/client/
configuration.rs

1use protosocket::Connection;
2use socket2::TcpKeepalive;
3use std::{future::Future, net::SocketAddr, sync::Arc};
4use tokio::{net::TcpStream, sync::mpsc};
5use tokio_rustls::rustls::pki_types::ServerName;
6
7use crate::{
8    client::reactor::completion_reactor::{DoNothingMessageHandler, RpcCompletionReactor},
9    Message,
10};
11
12use super::{reactor::completion_reactor::RpcCompletionConnectionBindings, RpcClient};
13
14pub trait StreamConnector: std::fmt::Debug {
15    type Stream: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin + 'static;
16
17    fn connect_stream(
18        &self,
19        stream: TcpStream,
20    ) -> impl Future<Output = std::io::Result<Self::Stream>> + Send;
21}
22
23/// A `StreamConnector` for bare TCP streams.
24#[derive(Debug)]
25pub struct TcpStreamConnector;
26impl StreamConnector for TcpStreamConnector {
27    type Stream = TcpStream;
28
29    fn connect_stream(
30        &self,
31        stream: TcpStream,
32    ) -> impl Future<Output = std::io::Result<Self::Stream>> + Send {
33        std::future::ready(Ok(stream))
34    }
35}
36
37/// A `StreamConnector` for PKI TLS streams.
38pub struct WebpkiTlsStreamConnector {
39    connector: tokio_rustls::TlsConnector,
40    servername: ServerName<'static>,
41}
42impl WebpkiTlsStreamConnector {
43    /// Create a new `TlsStreamConnector` for a server
44    pub fn new(servername: ServerName<'static>) -> Self {
45        let client_config = Arc::new(
46            tokio_rustls::rustls::ClientConfig::builder_with_protocol_versions(&[
47                &tokio_rustls::rustls::version::TLS13,
48            ])
49            .with_root_certificates(tokio_rustls::rustls::RootCertStore::from_iter(
50                webpki_roots::TLS_SERVER_ROOTS.iter().cloned(),
51            ))
52            .with_no_client_auth(),
53        );
54        let connector = tokio_rustls::TlsConnector::from(client_config);
55        Self {
56            connector,
57            servername,
58        }
59    }
60}
61impl std::fmt::Debug for WebpkiTlsStreamConnector {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        f.debug_struct("TlsStreamConnector").finish_non_exhaustive()
64    }
65}
66impl StreamConnector for WebpkiTlsStreamConnector {
67    type Stream = tokio_rustls::client::TlsStream<TcpStream>;
68
69    fn connect_stream(
70        &self,
71        stream: TcpStream,
72    ) -> impl Future<Output = std::io::Result<Self::Stream>> + Send {
73        self.connector
74            .clone()
75            .connect(self.servername.clone(), stream)
76    }
77}
78
79/// A `StreamConnector` for self-signed server TLS streams. No host certificate validation is performed.
80pub struct UnverifiedTlsStreamConnector {
81    connector: tokio_rustls::TlsConnector,
82    servername: ServerName<'static>,
83}
84impl UnverifiedTlsStreamConnector {
85    /// Create a new `UnverifiedTlsStreamConnector` for a server.
86    /// This connector does not perform any certificate validation.
87    pub fn new(servername: ServerName<'static>) -> Self {
88        let client_config = Arc::new(
89            tokio_rustls::rustls::ClientConfig::builder_with_protocol_versions(&[
90                &tokio_rustls::rustls::version::TLS13,
91            ])
92            .dangerous()
93            .with_custom_certificate_verifier(Arc::new(DoNothingVerifier))
94            .with_no_client_auth(),
95        );
96        let connector = tokio_rustls::TlsConnector::from(client_config);
97        Self {
98            connector,
99            servername,
100        }
101    }
102}
103impl std::fmt::Debug for UnverifiedTlsStreamConnector {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        f.debug_struct("UnverifiedTlsStreamConnector")
106            .finish_non_exhaustive()
107    }
108}
109impl StreamConnector for UnverifiedTlsStreamConnector {
110    type Stream = tokio_rustls::client::TlsStream<TcpStream>;
111
112    fn connect_stream(
113        &self,
114        stream: TcpStream,
115    ) -> impl Future<Output = std::io::Result<Self::Stream>> + Send {
116        self.connector
117            .clone()
118            .connect(self.servername.clone(), stream)
119    }
120}
121
122// You don't need this if you use a real certificate
123#[derive(Debug)]
124struct DoNothingVerifier;
125impl tokio_rustls::rustls::client::danger::ServerCertVerifier for DoNothingVerifier {
126    fn verify_server_cert(
127        &self,
128        _end_entity: &rustls_pki_types::CertificateDer<'_>,
129        _intermediates: &[rustls_pki_types::CertificateDer<'_>],
130        _server_name: &rustls_pki_types::ServerName<'_>,
131        _ocsp_response: &[u8],
132        _now: rustls_pki_types::UnixTime,
133    ) -> Result<tokio_rustls::rustls::client::danger::ServerCertVerified, tokio_rustls::rustls::Error>
134    {
135        Ok(tokio_rustls::rustls::client::danger::ServerCertVerified::assertion())
136    }
137
138    fn verify_tls12_signature(
139        &self,
140        _message: &[u8],
141        _cert: &rustls_pki_types::CertificateDer<'_>,
142        _dss: &tokio_rustls::rustls::DigitallySignedStruct,
143    ) -> Result<
144        tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
145        tokio_rustls::rustls::Error,
146    > {
147        Ok(tokio_rustls::rustls::client::danger::HandshakeSignatureValid::assertion())
148    }
149
150    fn verify_tls13_signature(
151        &self,
152        _message: &[u8],
153        _cert: &rustls_pki_types::CertificateDer<'_>,
154        _dss: &tokio_rustls::rustls::DigitallySignedStruct,
155    ) -> Result<
156        tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
157        tokio_rustls::rustls::Error,
158    > {
159        Ok(tokio_rustls::rustls::client::danger::HandshakeSignatureValid::assertion())
160    }
161
162    fn supported_verify_schemes(&self) -> Vec<tokio_rustls::rustls::SignatureScheme> {
163        tokio_rustls::rustls::crypto::aws_lc_rs::default_provider()
164            .signature_verification_algorithms
165            .supported_schemes()
166    }
167}
168
169/// Configuration for a `protosocket` rpc client.
170#[derive(Debug, Clone)]
171pub struct Configuration<TStreamConnector> {
172    max_buffer_length: usize,
173    buffer_allocation_increment: usize,
174    max_queued_outbound_messages: usize,
175    tcp_keepalive_duration: Option<std::time::Duration>,
176    stream_connector: TStreamConnector,
177}
178
179impl<TStreamConnector> Configuration<TStreamConnector>
180where
181    TStreamConnector: StreamConnector,
182{
183    pub fn new(stream_connector: TStreamConnector) -> Self {
184        log::trace!("new client configuration");
185        Self {
186            max_buffer_length: 4 * (1 << 20), // 4 MiB
187            buffer_allocation_increment: 1 << 20,
188            max_queued_outbound_messages: 256,
189            tcp_keepalive_duration: None,
190            stream_connector,
191        }
192    }
193
194    /// Max buffer length limits the max message size. Try to use a buffer length that is at least 4 times the largest message you want to support.
195    ///
196    /// Default: 4MiB
197    pub fn max_buffer_length(&mut self, max_buffer_length: usize) {
198        self.max_buffer_length = max_buffer_length;
199    }
200
201    /// Max messages that will be queued up waiting for send on the client channel.
202    ///
203    /// Default: 256
204    pub fn max_queued_outbound_messages(&mut self, max_queued_outbound_messages: usize) {
205        self.max_queued_outbound_messages = max_queued_outbound_messages;
206    }
207
208    /// Amount of buffer to allocate at one time when buffer needs extension.
209    ///
210    /// Default: 1MiB
211    pub fn buffer_allocation_increment(&mut self, buffer_allocation_increment: usize) {
212        self.buffer_allocation_increment = buffer_allocation_increment;
213    }
214
215    /// The duration to set for tcp_keepalive on the underlying socket.
216    ///
217    /// Default: None
218    pub fn tcp_keepalive_duration(&mut self, tcp_keepalive_duration: Option<std::time::Duration>) {
219        self.tcp_keepalive_duration = tcp_keepalive_duration;
220    }
221}
222
223/// Connect a new protosocket rpc client to a server
224pub async fn connect<Serializer, Deserializer, TStreamConnector>(
225    address: SocketAddr,
226    configuration: &Configuration<TStreamConnector>,
227) -> Result<
228    (
229        RpcClient<Serializer::Message, Deserializer::Message>,
230        protosocket::Connection<
231            RpcCompletionConnectionBindings<Serializer, Deserializer, TStreamConnector::Stream>,
232        >,
233    ),
234    crate::Error,
235>
236where
237    Deserializer: protosocket::Deserializer + Default + 'static,
238    Serializer: protosocket::Serializer + Default + 'static,
239    Deserializer::Message: Message,
240    Serializer::Message: Message,
241    TStreamConnector: StreamConnector,
242{
243    log::trace!("new client {address}, {configuration:?}");
244
245    let stream = TcpStream::connect(&address).await?;
246
247    // For setting socket configuration options available to socket2
248    let socket = socket2::SockRef::from(&stream);
249
250    let mut tcp_keepalive = TcpKeepalive::new();
251    if let Some(duration) = configuration.tcp_keepalive_duration {
252        tcp_keepalive = tcp_keepalive.with_time(duration);
253    }
254    socket.set_nonblocking(true)?;
255    socket.set_tcp_keepalive(&tcp_keepalive)?;
256    socket.set_tcp_nodelay(true)?;
257    socket.set_reuse_address(true)?;
258
259    let message_reactor: RpcCompletionReactor<
260        Deserializer::Message,
261        DoNothingMessageHandler<Deserializer::Message>,
262    > = RpcCompletionReactor::new(Default::default());
263    let (outbound, outbound_messages) = mpsc::channel(configuration.max_queued_outbound_messages);
264    let rpc_client = RpcClient::new(outbound, &message_reactor);
265    let stream = configuration
266        .stream_connector
267        .connect_stream(stream)
268        .await?;
269
270    // Tie outbound_messages to message_reactor via a protosocket::Connection
271    let connection = Connection::<
272        RpcCompletionConnectionBindings<Serializer, Deserializer, TStreamConnector::Stream>,
273    >::new(
274        stream,
275        address,
276        Deserializer::default(),
277        Serializer::default(),
278        configuration.max_buffer_length,
279        configuration.buffer_allocation_increment,
280        configuration.max_queued_outbound_messages,
281        outbound_messages,
282        message_reactor,
283    );
284
285    Ok((rpc_client, connection))
286}