Skip to main content

protosocket_rpc/client/
stream_connector.rs

1use std::{future::Future, sync::Arc};
2use tokio::net::TcpStream;
3use tokio_rustls::rustls::pki_types::ServerName;
4
5/// An async handshake that provides an `AsyndRead`/`AsyncWrite` stream.
6///
7/// You could consider wrapping these if you need to hook stream connection, or of course
8/// you can implement your own connector for your own stream type.
9pub trait StreamConnector: std::fmt::Debug {
10    /// The type of stream this connector will produce
11    type Stream: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin + 'static;
12
13    /// Take a TCP stream and turn it into a new Stream type.
14    fn connect_stream(
15        &self,
16        stream: TcpStream,
17    ) -> impl Future<Output = std::io::Result<Self::Stream>> + Send;
18}
19
20/// A `StreamConnector` for bare TCP streams.
21#[derive(Debug)]
22pub struct TcpStreamConnector;
23impl StreamConnector for TcpStreamConnector {
24    type Stream = TcpStream;
25
26    fn connect_stream(
27        &self,
28        stream: TcpStream,
29    ) -> impl Future<Output = std::io::Result<Self::Stream>> + Send {
30        std::future::ready(Ok(stream))
31    }
32}
33
34/// A `StreamConnector` for PKI TLS streams.
35pub struct WebpkiTlsStreamConnector {
36    connector: tokio_rustls::TlsConnector,
37    servername: ServerName<'static>,
38}
39impl WebpkiTlsStreamConnector {
40    /// Create a new `TlsStreamConnector` for a server
41    pub fn new(servername: ServerName<'static>) -> Self {
42        let client_config = Arc::new(
43            tokio_rustls::rustls::ClientConfig::builder_with_protocol_versions(&[
44                &tokio_rustls::rustls::version::TLS13,
45            ])
46            .with_root_certificates(tokio_rustls::rustls::RootCertStore::from_iter(
47                webpki_roots::TLS_SERVER_ROOTS.iter().cloned(),
48            ))
49            .with_no_client_auth(),
50        );
51        let connector = tokio_rustls::TlsConnector::from(client_config);
52        Self {
53            connector,
54            servername,
55        }
56    }
57}
58impl std::fmt::Debug for WebpkiTlsStreamConnector {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        f.debug_struct("TlsStreamConnector").finish_non_exhaustive()
61    }
62}
63impl StreamConnector for WebpkiTlsStreamConnector {
64    type Stream = tokio_rustls::client::TlsStream<TcpStream>;
65
66    fn connect_stream(
67        &self,
68        stream: TcpStream,
69    ) -> impl Future<Output = std::io::Result<Self::Stream>> + Send {
70        self.connector
71            .clone()
72            .connect(self.servername.clone(), stream)
73    }
74}
75
76/// A `StreamConnector` for self-signed server TLS streams. No host certificate validation is performed.
77pub struct UnverifiedTlsStreamConnector {
78    connector: tokio_rustls::TlsConnector,
79    servername: ServerName<'static>,
80}
81impl UnverifiedTlsStreamConnector {
82    /// Create a new `UnverifiedTlsStreamConnector` for a server.
83    /// This connector does not perform any certificate validation.
84    pub fn new(servername: ServerName<'static>) -> Self {
85        let client_config = Arc::new(
86            tokio_rustls::rustls::ClientConfig::builder_with_protocol_versions(&[
87                &tokio_rustls::rustls::version::TLS13,
88            ])
89            .dangerous()
90            .with_custom_certificate_verifier(Arc::new(DoNothingVerifier))
91            .with_no_client_auth(),
92        );
93        let connector = tokio_rustls::TlsConnector::from(client_config);
94        Self {
95            connector,
96            servername,
97        }
98    }
99}
100impl std::fmt::Debug for UnverifiedTlsStreamConnector {
101    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102        f.debug_struct("UnverifiedTlsStreamConnector")
103            .finish_non_exhaustive()
104    }
105}
106impl StreamConnector for UnverifiedTlsStreamConnector {
107    type Stream = tokio_rustls::client::TlsStream<TcpStream>;
108
109    fn connect_stream(
110        &self,
111        stream: TcpStream,
112    ) -> impl Future<Output = std::io::Result<Self::Stream>> + Send {
113        self.connector
114            .clone()
115            .connect(self.servername.clone(), stream)
116    }
117}
118
119// You don't need this if you use a real certificate
120#[derive(Debug)]
121struct DoNothingVerifier;
122impl tokio_rustls::rustls::client::danger::ServerCertVerifier for DoNothingVerifier {
123    fn verify_server_cert(
124        &self,
125        _end_entity: &rustls_pki_types::CertificateDer<'_>,
126        _intermediates: &[rustls_pki_types::CertificateDer<'_>],
127        _server_name: &rustls_pki_types::ServerName<'_>,
128        _ocsp_response: &[u8],
129        _now: rustls_pki_types::UnixTime,
130    ) -> Result<tokio_rustls::rustls::client::danger::ServerCertVerified, tokio_rustls::rustls::Error>
131    {
132        Ok(tokio_rustls::rustls::client::danger::ServerCertVerified::assertion())
133    }
134
135    fn verify_tls12_signature(
136        &self,
137        _message: &[u8],
138        _cert: &rustls_pki_types::CertificateDer<'_>,
139        _dss: &tokio_rustls::rustls::DigitallySignedStruct,
140    ) -> Result<
141        tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
142        tokio_rustls::rustls::Error,
143    > {
144        Ok(tokio_rustls::rustls::client::danger::HandshakeSignatureValid::assertion())
145    }
146
147    fn verify_tls13_signature(
148        &self,
149        _message: &[u8],
150        _cert: &rustls_pki_types::CertificateDer<'_>,
151        _dss: &tokio_rustls::rustls::DigitallySignedStruct,
152    ) -> Result<
153        tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
154        tokio_rustls::rustls::Error,
155    > {
156        Ok(tokio_rustls::rustls::client::danger::HandshakeSignatureValid::assertion())
157    }
158
159    fn supported_verify_schemes(&self) -> Vec<tokio_rustls::rustls::SignatureScheme> {
160        tokio_rustls::rustls::crypto::aws_lc_rs::default_provider()
161            .signature_verification_algorithms
162            .supported_schemes()
163    }
164}