1use std::{
2 io,
3 net::SocketAddr,
4 pin::Pin,
5 sync::Arc,
6 task::{Context, Poll},
7};
8
9use agnostic::{
10 Runtime,
11 net::{Net, TcpListener, TcpStream},
12};
13use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, lock::BiLock};
14pub use futures_rustls::{
15 TlsAcceptor, TlsConnector, client, pki_types::ServerName, rustls, server,
16};
17use peekable::future::{AsyncPeekExt, AsyncPeekable};
18use rustls::{SignatureScheme, client::danger::ServerCertVerifier};
19
20use super::{Listener, PromisedStream, StreamLayer};
21
/// A [`ServerCertVerifier`] that performs no verification at all.
///
/// Every server certificate and every TLS 1.2 / TLS 1.3 handshake signature is
/// accepted unconditionally, so the identity of the remote server is never
/// authenticated. Only use it where that is acceptable (for example in tests,
/// or when the peer is trusted by some other means).
#[derive(Default, Debug)]
pub struct NoopCertificateVerifier;
26
27impl NoopCertificateVerifier {
28 pub fn new() -> Arc<Self> {
30 Arc::new(Self)
31 }
32}
33
34impl ServerCertVerifier for NoopCertificateVerifier {
35 fn verify_server_cert(
36 &self,
37 _end_entity: &rustls::pki_types::CertificateDer<'_>,
38 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
39 _server_name: &rustls::pki_types::ServerName<'_>,
40 _ocsp_response: &[u8],
41 _now: rustls::pki_types::UnixTime,
42 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
43 Ok(rustls::client::danger::ServerCertVerified::assertion())
44 }
45
46 fn verify_tls12_signature(
47 &self,
48 _message: &[u8],
49 _cert: &rustls::pki_types::CertificateDer<'_>,
50 _dss: &rustls::DigitallySignedStruct,
51 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
52 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
53 }
54
55 fn verify_tls13_signature(
56 &self,
57 _message: &[u8],
58 _cert: &rustls::pki_types::CertificateDer<'_>,
59 _dss: &rustls::DigitallySignedStruct,
60 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
61 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
62 }
63
64 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
65 vec![
66 SignatureScheme::RSA_PKCS1_SHA1,
67 SignatureScheme::ECDSA_SHA1_Legacy,
68 SignatureScheme::RSA_PKCS1_SHA256,
69 SignatureScheme::ECDSA_NISTP256_SHA256,
70 SignatureScheme::RSA_PKCS1_SHA384,
71 SignatureScheme::ECDSA_NISTP384_SHA384,
72 SignatureScheme::RSA_PKCS1_SHA512,
73 SignatureScheme::ECDSA_NISTP521_SHA512,
74 SignatureScheme::RSA_PSS_SHA256,
75 SignatureScheme::RSA_PSS_SHA384,
76 SignatureScheme::RSA_PSS_SHA512,
77 SignatureScheme::ED25519,
78 SignatureScheme::ED448,
79 ]
80 }
81}
82
#[viewit::viewit(getters(style = "ref"), setters(prefix = "with"))]
/// Options used to build the [`Tls`] stream layer.
///
/// It holds three things:
/// - the acceptor for inbound connections;
/// - the connector for outbound connections;
/// - the server name handed to the connector whenever a connection is dialed.
pub struct TlsOptions {
  /// Runs the server side of the TLS handshake on accepted TCP connections.
  #[viewit(
    getter(const, style = "ref", attrs(doc = "Get the TLS acceptor."),),
    setter(attrs(doc = "Set the TLS acceptor. (Builder pattern)"),)
  )]
  acceptor: TlsAcceptor,
  /// Runs the client side of the TLS handshake on dialed TCP connections.
  #[viewit(
    getter(const, style = "ref", attrs(doc = "Get the TLS connector."),),
    setter(attrs(doc = "Set the TLS connector. (Builder pattern)"),)
  )]
  connector: TlsConnector,
  /// Server name passed to the connector for every outbound connection.
  #[viewit(
    getter(const, style = "ref", attrs(doc = "Get the server name."),),
    setter(attrs(doc = "Set the server name. (Builder pattern)"),)
  )]
  server_name: ServerName<'static>,
}
105
106impl TlsOptions {
107 #[inline]
109 pub const fn new(
110 server_name: ServerName<'static>,
111 acceptor: TlsAcceptor,
112 connector: TlsConnector,
113 ) -> Self {
114 Self {
115 acceptor,
116 connector,
117 server_name,
118 }
119 }
120}
121
/// TLS stream layer.
///
/// It dials and binds TCP sockets through runtime `R`'s networking and runs a
/// TLS handshake on every connection.
pub struct Tls<R> {
  /// Server name passed to the connector on every outbound connection.
  domain: ServerName<'static>,
  /// Acceptor shared (via `Arc`) with every listener created by `bind`.
  acceptor: Arc<TlsAcceptor>,
  /// Connector used to start client-side handshakes in `connect`.
  connector: TlsConnector,
  /// Ties the layer to runtime `R` without storing a value of it.
  _marker: std::marker::PhantomData<R>,
}
129
130impl<R> Tls<R> {
131 #[inline]
133 fn new_in(domain: ServerName<'static>, acceptor: TlsAcceptor, connector: TlsConnector) -> Self {
134 Self {
135 domain,
136 acceptor: Arc::new(acceptor),
137 connector,
138 _marker: std::marker::PhantomData,
139 }
140 }
141}
142
143impl<R: Runtime> StreamLayer for Tls<R> {
144 type Listener = TlsListener<R>;
145 type Stream = TlsStream<R>;
146 type Options = TlsOptions;
147 type Runtime = R;
148
149 #[inline]
150 async fn new(options: Self::Options) -> io::Result<Self> {
151 Ok(Self::new_in(
152 options.server_name,
153 options.acceptor,
154 options.connector,
155 ))
156 }
157
158 async fn connect(&self, addr: SocketAddr) -> io::Result<Self::Stream> {
159 let conn = <<R::Net as Net>::TcpStream as TcpStream>::connect(addr).await?;
160 let local_addr = conn.local_addr()?;
161 let stream = self.connector.connect(self.domain.clone(), conn).await?;
162 Ok(TlsStream::client(stream, addr, local_addr))
163 }
164
165 async fn bind(&self, addr: SocketAddr) -> io::Result<Self::Listener> {
166 let acceptor = self.acceptor.clone();
167 <<R::Net as Net>::TcpListener as TcpListener>::bind(addr)
168 .await
169 .and_then(|ln| {
170 ln.local_addr().map(|local_addr| TlsListener {
171 ln,
172 acceptor,
173 local_addr,
174 })
175 })
176 }
177
178 fn is_secure() -> bool {
179 true
180 }
181}
182
/// Listener produced by `Tls`'s `bind`.
///
/// It runs a server-side TLS handshake on every TCP connection it accepts.
pub struct TlsListener<R: Runtime> {
  /// Underlying TCP listener from runtime `R`.
  ln: <R::Net as Net>::TcpListener,
  /// Acceptor shared with the [`Tls`] layer that created this listener.
  acceptor: Arc<TlsAcceptor>,
  /// Address the TCP listener was bound to, captured once at bind time.
  local_addr: SocketAddr,
}
189
190impl<R: Runtime> Listener for TlsListener<R> {
191 type Stream = TlsStream<R>;
192
193 async fn accept(&self) -> io::Result<(Self::Stream, std::net::SocketAddr)> {
194 let (conn, addr) = self.ln.accept().await?;
195 let stream = TlsAcceptor::accept(&self.acceptor, conn).await?;
196 Ok((TlsStream::server(stream, addr, self.local_addr), addr))
197 }
198
199 fn local_addr(&self) -> std::net::SocketAddr {
200 self.local_addr
201 }
202
203 async fn shutdown(&self) -> io::Result<()> {
204 Ok(())
205 }
206}
207
#[pin_project::pin_project]
/// Either side of a TLS connection.
///
/// Each side is wrapped in a peekable adapter so incoming bytes can be
/// inspected before they are consumed.
enum TlsStreamKind<R: Runtime> {
  /// Connection established by dialing out (`connect`).
  Client {
    #[pin]
    stream: AsyncPeekable<client::TlsStream<<R::Net as Net>::TcpStream>>,
  },
  /// Connection obtained from a [`TlsListener`].
  Server {
    #[pin]
    stream: AsyncPeekable<server::TlsStream<<R::Net as Net>::TcpStream>>,
  },
}
219
220impl<R: Runtime> AsyncRead for TlsStreamKind<R> {
221 fn poll_read(
222 self: Pin<&mut Self>,
223 cx: &mut Context<'_>,
224 buf: &mut [u8],
225 ) -> Poll<io::Result<usize>> {
226 match self.get_mut() {
227 Self::Client { stream } => Pin::new(stream).poll_read(cx, buf),
228 Self::Server { stream } => Pin::new(stream).poll_read(cx, buf),
229 }
230 }
231}
232
233impl<R: Runtime> AsyncWrite for TlsStreamKind<R> {
234 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
235 match self.get_mut() {
236 Self::Client { stream } => Pin::new(stream).poll_write(cx, buf),
237 Self::Server { stream } => Pin::new(stream).poll_write(cx, buf),
238 }
239 }
240
241 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
242 match self.get_mut() {
243 Self::Client { stream } => Pin::new(stream).poll_flush(cx),
244 Self::Server { stream } => Pin::new(stream).poll_flush(cx),
245 }
246 }
247
248 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
249 match self.get_mut() {
250 Self::Client { stream } => Pin::new(stream).poll_close(cx),
251 Self::Server { stream } => Pin::new(stream).poll_close(cx),
252 }
253 }
254}
255
#[pin_project::pin_project]
/// A TLS connection (client or server side).
///
/// It also stores the socket addresses recorded when the connection was
/// established.
pub struct TlsStream<R: Runtime> {
  /// The wrapped TLS connection; structurally pinned.
  #[pin]
  stream: TlsStreamKind<R>,
  /// Local address of the underlying TCP socket.
  local_addr: SocketAddr,
  /// Remote peer address of the underlying TCP socket.
  peer_addr: SocketAddr,
}
264
265impl<R: Runtime> AsyncRead for TlsStream<R> {
266 fn poll_read(
267 self: Pin<&mut Self>,
268 cx: &mut Context<'_>,
269 buf: &mut [u8],
270 ) -> Poll<io::Result<usize>> {
271 self.project().stream.poll_read(cx, buf)
272 }
273}
274
275impl<R: Runtime> AsyncWrite for TlsStream<R> {
276 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
277 self.project().stream.poll_write(cx, buf)
278 }
279
280 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
281 self.project().stream.poll_flush(cx)
282 }
283
284 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
285 self.project().stream.poll_close(cx)
286 }
287}
288
/// Read half produced by splitting a [`TlsStream`] connection.
pub struct TlsProtoReader<R: Runtime> {
  /// One of the two handles to the shared stream; the other belongs to the
  /// matching [`TlsProtoWriter`].
  reader: BiLock<TlsStream<R>>,
}
293
294impl<R: Runtime> memberlist_core::proto::ProtoReader for TlsProtoReader<R> {
295 async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
296 let mut reader = self.reader.lock().await;
297 match reader.stream {
298 TlsStreamKind::Client { ref mut stream } => stream.peek(buf).await,
299 TlsStreamKind::Server { ref mut stream } => stream.peek(buf).await,
300 }
301 }
302
303 async fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
304 let mut reader = self.reader.lock().await;
305 AsyncReadExt::read(&mut *reader, buf).await
306 }
307
308 async fn peek_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
309 let mut reader = self.reader.lock().await;
310 match reader.stream {
311 TlsStreamKind::Client { ref mut stream } => stream.peek_exact(buf).await,
312 TlsStreamKind::Server { ref mut stream } => stream.peek_exact(buf).await,
313 }
314 }
315
316 async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
317 let mut reader = self.reader.lock().await;
318 AsyncReadExt::read_exact(&mut *reader, buf).await
319 }
320}
321
/// Write half produced by splitting a [`TlsStream`] connection.
pub struct TlsProtoWriter<R: Runtime> {
  /// One of the two handles to the shared stream; the other belongs to the
  /// matching [`TlsProtoReader`].
  writer: BiLock<TlsStream<R>>,
}
326
327impl<R: Runtime> memberlist_core::proto::ProtoWriter for TlsProtoWriter<R> {
328 async fn close(&mut self) -> std::io::Result<()> {
329 let mut writer = self.writer.lock().await;
330 AsyncWriteExt::close(&mut *writer).await
331 }
332
333 async fn write_all(&mut self, payload: &[u8]) -> std::io::Result<()> {
334 let mut writer = self.writer.lock().await;
335 AsyncWriteExt::write_all(&mut *writer, payload).await
336 }
337
338 async fn flush(&mut self) -> std::io::Result<()> {
339 let mut writer = self.writer.lock().await;
340 AsyncWriteExt::flush(&mut *writer).await
341 }
342}
343
344impl<R: Runtime> memberlist_core::transport::Connection for TlsStream<R> {
345 type Reader = TlsProtoReader<R>;
346
347 type Writer = TlsProtoWriter<R>;
348
349 fn split(self) -> (Self::Reader, Self::Writer) {
350 let (reader, writer) = BiLock::new(self);
351 (Self::Reader { reader }, Self::Writer { writer })
352 }
353
354 async fn close(&mut self) -> std::io::Result<()> {
355 AsyncWriteExt::close(self).await
356 }
357
358 async fn write_all(&mut self, payload: &[u8]) -> std::io::Result<()> {
359 AsyncWriteExt::write_all(self, payload).await
360 }
361
362 async fn flush(&mut self) -> std::io::Result<()> {
363 AsyncWriteExt::flush(self).await
364 }
365
366 async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
367 match &mut self.stream {
368 TlsStreamKind::Client { stream } => stream.peek(buf).await,
369 TlsStreamKind::Server { stream } => stream.peek(buf).await,
370 }
371 }
372
373 fn consume_peek(&mut self) {
374 let _ = match &mut self.stream {
375 TlsStreamKind::Client { stream } => stream.consume(),
376 TlsStreamKind::Server { stream } => stream.consume(),
377 };
378 }
379
380 async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
381 AsyncReadExt::read_exact(self, buf).await
382 }
383
384 async fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
385 AsyncReadExt::read(self, buf).await
386 }
387
388 async fn peek_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
389 match &mut self.stream {
390 TlsStreamKind::Client { stream } => stream.peek_exact(buf).await,
391 TlsStreamKind::Server { stream } => stream.peek_exact(buf).await,
392 }
393 }
394}
395
396impl<R: Runtime> PromisedStream for TlsStream<R> {
397 type Instant = R::Instant;
398
399 #[inline]
400 fn local_addr(&self) -> SocketAddr {
401 self.local_addr
402 }
403
404 #[inline]
405 fn peer_addr(&self) -> SocketAddr {
406 self.peer_addr
407 }
408}
409
410impl<R: Runtime> TlsStream<R> {
411 #[inline]
412 fn client(
413 stream: client::TlsStream<<R::Net as Net>::TcpStream>,
414 peer_addr: SocketAddr,
415 local_addr: SocketAddr,
416 ) -> Self {
417 Self {
418 stream: TlsStreamKind::Client {
419 stream: stream.peekable(),
420 },
421 local_addr,
422 peer_addr,
423 }
424 }
425
426 #[inline]
427 fn server(
428 stream: server::TlsStream<<R::Net as Net>::TcpStream>,
429 peer_addr: SocketAddr,
430 local_addr: SocketAddr,
431 ) -> Self {
432 Self {
433 stream: TlsStreamKind::Server {
434 stream: stream.peekable(),
435 },
436 local_addr,
437 peer_addr,
438 }
439 }
440}