memberlist_net/stream_layer/tls.rs

1use std::{
2  io,
3  net::SocketAddr,
4  pin::Pin,
5  sync::Arc,
6  task::{Context, Poll},
7};
8
9use agnostic::{
10  Runtime,
11  net::{Net, TcpListener, TcpStream},
12};
13use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, lock::BiLock};
14pub use futures_rustls::{
15  TlsAcceptor, TlsConnector, client, pki_types::ServerName, rustls, server,
16};
17use peekable::future::{AsyncPeekExt, AsyncPeekable};
18use rustls::{SignatureScheme, client::danger::ServerCertVerifier};
19
20use super::{Listener, PromisedStream, StreamLayer};
21
/// Server certificate verifier that performs no verification at all.
///
/// Every server certificate and every handshake signature is reported as
/// valid, which removes server authentication entirely. Only use this in
/// tests; never in production.
#[derive(Debug, Default)]
pub struct NoopCertificateVerifier;

impl NoopCertificateVerifier {
  /// Returns a new `NoopCertificateVerifier` wrapped in an [`Arc`].
  pub fn new() -> Arc<Self> {
    let verifier = Self;
    Arc::new(verifier)
  }
}
33
34impl ServerCertVerifier for NoopCertificateVerifier {
35  fn verify_server_cert(
36    &self,
37    _end_entity: &rustls::pki_types::CertificateDer<'_>,
38    _intermediates: &[rustls::pki_types::CertificateDer<'_>],
39    _server_name: &rustls::pki_types::ServerName<'_>,
40    _ocsp_response: &[u8],
41    _now: rustls::pki_types::UnixTime,
42  ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
43    Ok(rustls::client::danger::ServerCertVerified::assertion())
44  }
45
46  fn verify_tls12_signature(
47    &self,
48    _message: &[u8],
49    _cert: &rustls::pki_types::CertificateDer<'_>,
50    _dss: &rustls::DigitallySignedStruct,
51  ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
52    Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
53  }
54
55  fn verify_tls13_signature(
56    &self,
57    _message: &[u8],
58    _cert: &rustls::pki_types::CertificateDer<'_>,
59    _dss: &rustls::DigitallySignedStruct,
60  ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
61    Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
62  }
63
64  fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
65    vec![
66      SignatureScheme::RSA_PKCS1_SHA1,
67      SignatureScheme::ECDSA_SHA1_Legacy,
68      SignatureScheme::RSA_PKCS1_SHA256,
69      SignatureScheme::ECDSA_NISTP256_SHA256,
70      SignatureScheme::RSA_PKCS1_SHA384,
71      SignatureScheme::ECDSA_NISTP384_SHA384,
72      SignatureScheme::RSA_PKCS1_SHA512,
73      SignatureScheme::ECDSA_NISTP521_SHA512,
74      SignatureScheme::RSA_PSS_SHA256,
75      SignatureScheme::RSA_PSS_SHA384,
76      SignatureScheme::RSA_PSS_SHA512,
77      SignatureScheme::ED25519,
78      SignatureScheme::ED448,
79    ]
80  }
81}
82
83/// The options for the tls stream layer.
84#[viewit::viewit(getters(style = "ref"), setters(prefix = "with"))]
85pub struct TlsOptions {
86  /// The acceptor for the server.
87  #[viewit(
88    getter(const, style = "ref", attrs(doc = "Get the TLS acceptor."),),
89    setter(attrs(doc = "Set the TLS acceptor. (Builder pattern)"),)
90  )]
91  acceptor: TlsAcceptor,
92  /// The connector for the client.
93  #[viewit(
94    getter(const, style = "ref", attrs(doc = "Get the TLS connector."),),
95    setter(attrs(doc = "Set the TLS connector. (Builder pattern)"),)
96  )]
97  connector: TlsConnector,
98  /// The server name
99  #[viewit(
100    getter(const, style = "ref", attrs(doc = "Get the server name."),),
101    setter(attrs(doc = "Set the server name. (Builder pattern)"),)
102  )]
103  server_name: ServerName<'static>,
104}
105
106impl TlsOptions {
107  /// Constructs a new `TlsOptions`.
108  #[inline]
109  pub const fn new(
110    server_name: ServerName<'static>,
111    acceptor: TlsAcceptor,
112    connector: TlsConnector,
113  ) -> Self {
114    Self {
115      acceptor,
116      connector,
117      server_name,
118    }
119  }
120}
121
122/// Tls stream layer
123pub struct Tls<R> {
124  domain: ServerName<'static>,
125  acceptor: Arc<TlsAcceptor>,
126  connector: TlsConnector,
127  _marker: std::marker::PhantomData<R>,
128}
129
130impl<R> Tls<R> {
131  /// Create a new tcp stream layer
132  #[inline]
133  fn new_in(domain: ServerName<'static>, acceptor: TlsAcceptor, connector: TlsConnector) -> Self {
134    Self {
135      domain,
136      acceptor: Arc::new(acceptor),
137      connector,
138      _marker: std::marker::PhantomData,
139    }
140  }
141}
142
143impl<R: Runtime> StreamLayer for Tls<R> {
144  type Listener = TlsListener<R>;
145  type Stream = TlsStream<R>;
146  type Options = TlsOptions;
147  type Runtime = R;
148
149  #[inline]
150  async fn new(options: Self::Options) -> io::Result<Self> {
151    Ok(Self::new_in(
152      options.server_name,
153      options.acceptor,
154      options.connector,
155    ))
156  }
157
158  async fn connect(&self, addr: SocketAddr) -> io::Result<Self::Stream> {
159    let conn = <<R::Net as Net>::TcpStream as TcpStream>::connect(addr).await?;
160    let local_addr = conn.local_addr()?;
161    let stream = self.connector.connect(self.domain.clone(), conn).await?;
162    Ok(TlsStream::client(stream, addr, local_addr))
163  }
164
165  async fn bind(&self, addr: SocketAddr) -> io::Result<Self::Listener> {
166    let acceptor = self.acceptor.clone();
167    <<R::Net as Net>::TcpListener as TcpListener>::bind(addr)
168      .await
169      .and_then(|ln| {
170        ln.local_addr().map(|local_addr| TlsListener {
171          ln,
172          acceptor,
173          local_addr,
174        })
175      })
176  }
177
178  fn is_secure() -> bool {
179    true
180  }
181}
182
183/// [`Listener`] of the TLS stream layer
184pub struct TlsListener<R: Runtime> {
185  ln: <R::Net as Net>::TcpListener,
186  acceptor: Arc<TlsAcceptor>,
187  local_addr: SocketAddr,
188}
189
190impl<R: Runtime> Listener for TlsListener<R> {
191  type Stream = TlsStream<R>;
192
193  async fn accept(&self) -> io::Result<(Self::Stream, std::net::SocketAddr)> {
194    let (conn, addr) = self.ln.accept().await?;
195    let stream = TlsAcceptor::accept(&self.acceptor, conn).await?;
196    Ok((TlsStream::server(stream, addr, self.local_addr), addr))
197  }
198
199  fn local_addr(&self) -> std::net::SocketAddr {
200    self.local_addr
201  }
202
203  async fn shutdown(&self) -> io::Result<()> {
204    Ok(())
205  }
206}
207
208#[pin_project::pin_project]
209enum TlsStreamKind<R: Runtime> {
210  Client {
211    #[pin]
212    stream: AsyncPeekable<client::TlsStream<<R::Net as Net>::TcpStream>>,
213  },
214  Server {
215    #[pin]
216    stream: AsyncPeekable<server::TlsStream<<R::Net as Net>::TcpStream>>,
217  },
218}
219
220impl<R: Runtime> AsyncRead for TlsStreamKind<R> {
221  fn poll_read(
222    self: Pin<&mut Self>,
223    cx: &mut Context<'_>,
224    buf: &mut [u8],
225  ) -> Poll<io::Result<usize>> {
226    match self.get_mut() {
227      Self::Client { stream } => Pin::new(stream).poll_read(cx, buf),
228      Self::Server { stream } => Pin::new(stream).poll_read(cx, buf),
229    }
230  }
231}
232
233impl<R: Runtime> AsyncWrite for TlsStreamKind<R> {
234  fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
235    match self.get_mut() {
236      Self::Client { stream } => Pin::new(stream).poll_write(cx, buf),
237      Self::Server { stream } => Pin::new(stream).poll_write(cx, buf),
238    }
239  }
240
241  fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
242    match self.get_mut() {
243      Self::Client { stream } => Pin::new(stream).poll_flush(cx),
244      Self::Server { stream } => Pin::new(stream).poll_flush(cx),
245    }
246  }
247
248  fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
249    match self.get_mut() {
250      Self::Client { stream } => Pin::new(stream).poll_close(cx),
251      Self::Server { stream } => Pin::new(stream).poll_close(cx),
252    }
253  }
254}
255
256/// [`PromisedStream`] of the TLS stream layer.
257#[pin_project::pin_project]
258pub struct TlsStream<R: Runtime> {
259  #[pin]
260  stream: TlsStreamKind<R>,
261  local_addr: SocketAddr,
262  peer_addr: SocketAddr,
263}
264
265impl<R: Runtime> AsyncRead for TlsStream<R> {
266  fn poll_read(
267    self: Pin<&mut Self>,
268    cx: &mut Context<'_>,
269    buf: &mut [u8],
270  ) -> Poll<io::Result<usize>> {
271    self.project().stream.poll_read(cx, buf)
272  }
273}
274
275impl<R: Runtime> AsyncWrite for TlsStream<R> {
276  fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
277    self.project().stream.poll_write(cx, buf)
278  }
279
280  fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
281    self.project().stream.poll_flush(cx)
282  }
283
284  fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
285    self.project().stream.poll_close(cx)
286  }
287}
288
289/// A [`ProtoReader`](memberlist_core::proto::ProtoReader) for the TLS stream layer.
290pub struct TlsProtoReader<R: Runtime> {
291  reader: BiLock<TlsStream<R>>,
292}
293
294impl<R: Runtime> memberlist_core::proto::ProtoReader for TlsProtoReader<R> {
295  async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
296    let mut reader = self.reader.lock().await;
297    match reader.stream {
298      TlsStreamKind::Client { ref mut stream } => stream.peek(buf).await,
299      TlsStreamKind::Server { ref mut stream } => stream.peek(buf).await,
300    }
301  }
302
303  async fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
304    let mut reader = self.reader.lock().await;
305    AsyncReadExt::read(&mut *reader, buf).await
306  }
307
308  async fn peek_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
309    let mut reader = self.reader.lock().await;
310    match reader.stream {
311      TlsStreamKind::Client { ref mut stream } => stream.peek_exact(buf).await,
312      TlsStreamKind::Server { ref mut stream } => stream.peek_exact(buf).await,
313    }
314  }
315
316  async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
317    let mut reader = self.reader.lock().await;
318    AsyncReadExt::read_exact(&mut *reader, buf).await
319  }
320}
321
322/// A [`ProtoWriter`](memberlist_core::proto::ProtoWriter) for the TLS stream layer.
323pub struct TlsProtoWriter<R: Runtime> {
324  writer: BiLock<TlsStream<R>>,
325}
326
327impl<R: Runtime> memberlist_core::proto::ProtoWriter for TlsProtoWriter<R> {
328  async fn close(&mut self) -> std::io::Result<()> {
329    let mut writer = self.writer.lock().await;
330    AsyncWriteExt::close(&mut *writer).await
331  }
332
333  async fn write_all(&mut self, payload: &[u8]) -> std::io::Result<()> {
334    let mut writer = self.writer.lock().await;
335    AsyncWriteExt::write_all(&mut *writer, payload).await
336  }
337
338  async fn flush(&mut self) -> std::io::Result<()> {
339    let mut writer = self.writer.lock().await;
340    AsyncWriteExt::flush(&mut *writer).await
341  }
342}
343
344impl<R: Runtime> memberlist_core::transport::Connection for TlsStream<R> {
345  type Reader = TlsProtoReader<R>;
346
347  type Writer = TlsProtoWriter<R>;
348
349  fn split(self) -> (Self::Reader, Self::Writer) {
350    let (reader, writer) = BiLock::new(self);
351    (Self::Reader { reader }, Self::Writer { writer })
352  }
353
354  async fn close(&mut self) -> std::io::Result<()> {
355    AsyncWriteExt::close(self).await
356  }
357
358  async fn write_all(&mut self, payload: &[u8]) -> std::io::Result<()> {
359    AsyncWriteExt::write_all(self, payload).await
360  }
361
362  async fn flush(&mut self) -> std::io::Result<()> {
363    AsyncWriteExt::flush(self).await
364  }
365
366  async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
367    match &mut self.stream {
368      TlsStreamKind::Client { stream } => stream.peek(buf).await,
369      TlsStreamKind::Server { stream } => stream.peek(buf).await,
370    }
371  }
372
373  fn consume_peek(&mut self) {
374    let _ = match &mut self.stream {
375      TlsStreamKind::Client { stream } => stream.consume(),
376      TlsStreamKind::Server { stream } => stream.consume(),
377    };
378  }
379
380  async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
381    AsyncReadExt::read_exact(self, buf).await
382  }
383
384  async fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
385    AsyncReadExt::read(self, buf).await
386  }
387
388  async fn peek_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
389    match &mut self.stream {
390      TlsStreamKind::Client { stream } => stream.peek_exact(buf).await,
391      TlsStreamKind::Server { stream } => stream.peek_exact(buf).await,
392    }
393  }
394}
395
396impl<R: Runtime> PromisedStream for TlsStream<R> {
397  type Instant = R::Instant;
398
399  #[inline]
400  fn local_addr(&self) -> SocketAddr {
401    self.local_addr
402  }
403
404  #[inline]
405  fn peer_addr(&self) -> SocketAddr {
406    self.peer_addr
407  }
408}
409
410impl<R: Runtime> TlsStream<R> {
411  #[inline]
412  fn client(
413    stream: client::TlsStream<<R::Net as Net>::TcpStream>,
414    peer_addr: SocketAddr,
415    local_addr: SocketAddr,
416  ) -> Self {
417    Self {
418      stream: TlsStreamKind::Client {
419        stream: stream.peekable(),
420      },
421      local_addr,
422      peer_addr,
423    }
424  }
425
426  #[inline]
427  fn server(
428    stream: server::TlsStream<<R::Net as Net>::TcpStream>,
429    peer_addr: SocketAddr,
430    local_addr: SocketAddr,
431  ) -> Self {
432    Self {
433      stream: TlsStreamKind::Server {
434        stream: stream.peekable(),
435      },
436      local_addr,
437      peer_addr,
438    }
439  }
440}