1use protosocket::Connection;
2use socket2::TcpKeepalive;
3use std::{future::Future, net::SocketAddr, sync::Arc};
4use tokio::{net::TcpStream, sync::mpsc};
5use tokio_rustls::rustls::pki_types::ServerName;
6
7use crate::{
8 client::reactor::completion_reactor::{DoNothingMessageHandler, RpcCompletionReactor},
9 Message,
10};
11
12use super::{reactor::completion_reactor::RpcCompletionConnectionBindings, RpcClient};
13
14pub trait StreamConnector: std::fmt::Debug {
15 type Stream: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin + 'static;
16
17 fn connect_stream(
18 &self,
19 stream: TcpStream,
20 ) -> impl Future<Output = std::io::Result<Self::Stream>> + Send;
21}
22
23#[derive(Debug)]
25pub struct TcpStreamConnector;
26impl StreamConnector for TcpStreamConnector {
27 type Stream = TcpStream;
28
29 fn connect_stream(
30 &self,
31 stream: TcpStream,
32 ) -> impl Future<Output = std::io::Result<Self::Stream>> + Send {
33 std::future::ready(Ok(stream))
34 }
35}
36
37pub struct WebpkiTlsStreamConnector {
39 connector: tokio_rustls::TlsConnector,
40 servername: ServerName<'static>,
41}
42impl WebpkiTlsStreamConnector {
43 pub fn new(servername: ServerName<'static>) -> Self {
45 let client_config = Arc::new(
46 tokio_rustls::rustls::ClientConfig::builder_with_protocol_versions(&[
47 &tokio_rustls::rustls::version::TLS13,
48 ])
49 .with_root_certificates(tokio_rustls::rustls::RootCertStore::from_iter(
50 webpki_roots::TLS_SERVER_ROOTS.iter().cloned(),
51 ))
52 .with_no_client_auth(),
53 );
54 let connector = tokio_rustls::TlsConnector::from(client_config);
55 Self {
56 connector,
57 servername,
58 }
59 }
60}
61impl std::fmt::Debug for WebpkiTlsStreamConnector {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 f.debug_struct("TlsStreamConnector").finish_non_exhaustive()
64 }
65}
66impl StreamConnector for WebpkiTlsStreamConnector {
67 type Stream = tokio_rustls::client::TlsStream<TcpStream>;
68
69 fn connect_stream(
70 &self,
71 stream: TcpStream,
72 ) -> impl Future<Output = std::io::Result<Self::Stream>> + Send {
73 self.connector
74 .clone()
75 .connect(self.servername.clone(), stream)
76 }
77}
78
79pub struct UnverifiedTlsStreamConnector {
81 connector: tokio_rustls::TlsConnector,
82 servername: ServerName<'static>,
83}
84impl UnverifiedTlsStreamConnector {
85 pub fn new(servername: ServerName<'static>) -> Self {
88 let client_config = Arc::new(
89 tokio_rustls::rustls::ClientConfig::builder_with_protocol_versions(&[
90 &tokio_rustls::rustls::version::TLS13,
91 ])
92 .dangerous()
93 .with_custom_certificate_verifier(Arc::new(DoNothingVerifier))
94 .with_no_client_auth(),
95 );
96 let connector = tokio_rustls::TlsConnector::from(client_config);
97 Self {
98 connector,
99 servername,
100 }
101 }
102}
103impl std::fmt::Debug for UnverifiedTlsStreamConnector {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 f.debug_struct("UnverifiedTlsStreamConnector")
106 .finish_non_exhaustive()
107 }
108}
109impl StreamConnector for UnverifiedTlsStreamConnector {
110 type Stream = tokio_rustls::client::TlsStream<TcpStream>;
111
112 fn connect_stream(
113 &self,
114 stream: TcpStream,
115 ) -> impl Future<Output = std::io::Result<Self::Stream>> + Send {
116 self.connector
117 .clone()
118 .connect(self.servername.clone(), stream)
119 }
120}
121
122#[derive(Debug)]
124struct DoNothingVerifier;
125impl tokio_rustls::rustls::client::danger::ServerCertVerifier for DoNothingVerifier {
126 fn verify_server_cert(
127 &self,
128 _end_entity: &rustls_pki_types::CertificateDer<'_>,
129 _intermediates: &[rustls_pki_types::CertificateDer<'_>],
130 _server_name: &rustls_pki_types::ServerName<'_>,
131 _ocsp_response: &[u8],
132 _now: rustls_pki_types::UnixTime,
133 ) -> Result<tokio_rustls::rustls::client::danger::ServerCertVerified, tokio_rustls::rustls::Error>
134 {
135 Ok(tokio_rustls::rustls::client::danger::ServerCertVerified::assertion())
136 }
137
138 fn verify_tls12_signature(
139 &self,
140 _message: &[u8],
141 _cert: &rustls_pki_types::CertificateDer<'_>,
142 _dss: &tokio_rustls::rustls::DigitallySignedStruct,
143 ) -> Result<
144 tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
145 tokio_rustls::rustls::Error,
146 > {
147 Ok(tokio_rustls::rustls::client::danger::HandshakeSignatureValid::assertion())
148 }
149
150 fn verify_tls13_signature(
151 &self,
152 _message: &[u8],
153 _cert: &rustls_pki_types::CertificateDer<'_>,
154 _dss: &tokio_rustls::rustls::DigitallySignedStruct,
155 ) -> Result<
156 tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
157 tokio_rustls::rustls::Error,
158 > {
159 Ok(tokio_rustls::rustls::client::danger::HandshakeSignatureValid::assertion())
160 }
161
162 fn supported_verify_schemes(&self) -> Vec<tokio_rustls::rustls::SignatureScheme> {
163 tokio_rustls::rustls::crypto::aws_lc_rs::default_provider()
164 .signature_verification_algorithms
165 .supported_schemes()
166 }
167}
168
169#[derive(Debug, Clone)]
171pub struct Configuration<TStreamConnector> {
172 max_buffer_length: usize,
173 buffer_allocation_increment: usize,
174 max_queued_outbound_messages: usize,
175 tcp_keepalive_duration: Option<std::time::Duration>,
176 stream_connector: TStreamConnector,
177}
178
179impl<TStreamConnector> Configuration<TStreamConnector>
180where
181 TStreamConnector: StreamConnector,
182{
183 pub fn new(stream_connector: TStreamConnector) -> Self {
184 log::trace!("new client configuration");
185 Self {
186 max_buffer_length: 4 * (1 << 20), buffer_allocation_increment: 1 << 20,
188 max_queued_outbound_messages: 256,
189 tcp_keepalive_duration: None,
190 stream_connector,
191 }
192 }
193
194 pub fn max_buffer_length(&mut self, max_buffer_length: usize) {
198 self.max_buffer_length = max_buffer_length;
199 }
200
201 pub fn max_queued_outbound_messages(&mut self, max_queued_outbound_messages: usize) {
205 self.max_queued_outbound_messages = max_queued_outbound_messages;
206 }
207
208 pub fn buffer_allocation_increment(&mut self, buffer_allocation_increment: usize) {
212 self.buffer_allocation_increment = buffer_allocation_increment;
213 }
214
215 pub fn tcp_keepalive_duration(&mut self, tcp_keepalive_duration: Option<std::time::Duration>) {
219 self.tcp_keepalive_duration = tcp_keepalive_duration;
220 }
221}
222
223pub async fn connect<Serializer, Deserializer, TStreamConnector>(
225 address: SocketAddr,
226 configuration: &Configuration<TStreamConnector>,
227) -> Result<
228 (
229 RpcClient<Serializer::Message, Deserializer::Message>,
230 protosocket::Connection<
231 RpcCompletionConnectionBindings<Serializer, Deserializer, TStreamConnector::Stream>,
232 >,
233 ),
234 crate::Error,
235>
236where
237 Deserializer: protosocket::Deserializer + Default + 'static,
238 Serializer: protosocket::Serializer + Default + 'static,
239 Deserializer::Message: Message,
240 Serializer::Message: Message,
241 TStreamConnector: StreamConnector,
242{
243 log::trace!("new client {address}, {configuration:?}");
244
245 let stream = TcpStream::connect(&address).await?;
246
247 let socket = socket2::SockRef::from(&stream);
249
250 let mut tcp_keepalive = TcpKeepalive::new();
251 if let Some(duration) = configuration.tcp_keepalive_duration {
252 tcp_keepalive = tcp_keepalive.with_time(duration);
253 }
254 socket.set_nonblocking(true)?;
255 socket.set_tcp_keepalive(&tcp_keepalive)?;
256 socket.set_tcp_nodelay(true)?;
257 socket.set_reuse_address(true)?;
258
259 let message_reactor: RpcCompletionReactor<
260 Deserializer::Message,
261 DoNothingMessageHandler<Deserializer::Message>,
262 > = RpcCompletionReactor::new(Default::default());
263 let (outbound, outbound_messages) = mpsc::channel(configuration.max_queued_outbound_messages);
264 let rpc_client = RpcClient::new(outbound, &message_reactor);
265 let stream = configuration
266 .stream_connector
267 .connect_stream(stream)
268 .await?;
269
270 let connection = Connection::<
272 RpcCompletionConnectionBindings<Serializer, Deserializer, TStreamConnector::Stream>,
273 >::new(
274 stream,
275 address,
276 Deserializer::default(),
277 Serializer::default(),
278 configuration.max_buffer_length,
279 configuration.buffer_allocation_increment,
280 configuration.max_queued_outbound_messages,
281 outbound_messages,
282 message_reactor,
283 );
284
285 Ok((rpc_client, connection))
286}