bolt_client/
stream.rs

1use std::{
2    fmt::Debug,
3    io,
4    pin::Pin,
5    sync::Arc,
6    task::{Context, Poll},
7};
8
9use pin_project::pin_project;
10use tokio::{
11    io::{AsyncRead, AsyncWrite, ReadBuf},
12    net::{TcpStream, ToSocketAddrs},
13};
14use tokio_rustls::{
15    client::TlsStream,
16    rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName},
17    TlsConnector,
18};
19
20/// A convenient wrapper around a [`TcpStream`](tokio::net::TcpStream) or a
21/// [`TlsStream`](tokio_rustls::client::TlsStream).
22#[cfg_attr(docsrs, doc(cfg(feature = "tokio-stream")))]
23#[pin_project(project = StreamProj)]
24#[derive(Debug)]
25pub enum Stream {
26    Tcp(#[pin] TcpStream),
27    SecureTcp(#[pin] Box<TlsStream<TcpStream>>),
28}
29
30impl Stream {
31    /// Establish a connection with a remote socket. If a domain is provided, TLS negotiation will
32    /// be attempted.
33    #[cfg_attr(docsrs, doc(cfg(feature = "tokio-stream")))]
34    pub async fn connect(
35        addr: impl ToSocketAddrs,
36        domain: Option<impl AsRef<str>>,
37    ) -> io::Result<Self> {
38        match domain {
39            Some(domain) => {
40                let mut root_cert_store = RootCertStore::empty();
41                root_cert_store.add_server_trust_anchors(
42                    webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|anchor| {
43                        OwnedTrustAnchor::from_subject_spki_name_constraints(
44                            anchor.subject,
45                            anchor.spki,
46                            anchor.name_constraints,
47                        )
48                    }),
49                );
50
51                let config = ClientConfig::builder()
52                    .with_safe_defaults()
53                    .with_root_certificates(root_cert_store)
54                    .with_no_client_auth();
55
56                let server_name = ServerName::try_from(domain.as_ref())
57                    .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, domain.as_ref()))?;
58
59                let stream = TcpStream::connect(addr).await?;
60
61                Ok(Stream::SecureTcp(Box::new(
62                    TlsConnector::from(Arc::new(config))
63                        .connect(server_name, stream)
64                        .await?,
65                )))
66            }
67            None => Ok(Stream::Tcp(TcpStream::connect(addr).await?)),
68        }
69    }
70}
71
72impl AsyncRead for Stream {
73    fn poll_read(
74        self: Pin<&mut Self>,
75        cx: &mut Context<'_>,
76        buf: &mut ReadBuf<'_>,
77    ) -> Poll<io::Result<()>> {
78        match self.project() {
79            StreamProj::Tcp(tcp_stream) => tcp_stream.poll_read(cx, buf),
80            StreamProj::SecureTcp(tls_stream) => tls_stream.poll_read(cx, buf),
81        }
82    }
83}
84
85impl AsyncWrite for Stream {
86    fn poll_write(
87        self: Pin<&mut Self>,
88        cx: &mut Context<'_>,
89        buf: &[u8],
90    ) -> Poll<io::Result<usize>> {
91        match self.project() {
92            StreamProj::Tcp(tcp_stream) => tcp_stream.poll_write(cx, buf),
93            StreamProj::SecureTcp(tls_stream) => tls_stream.poll_write(cx, buf),
94        }
95    }
96
97    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
98        match self.project() {
99            StreamProj::Tcp(tcp_stream) => tcp_stream.poll_flush(cx),
100            StreamProj::SecureTcp(tls_stream) => tls_stream.poll_flush(cx),
101        }
102    }
103
104    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
105        match self.project() {
106            StreamProj::Tcp(tcp_stream) => tcp_stream.poll_shutdown(cx),
107            StreamProj::SecureTcp(tls_stream) => tls_stream.poll_shutdown(cx),
108        }
109    }
110}