1use std::{
2 fmt::Debug,
3 io,
4 pin::Pin,
5 sync::Arc,
6 task::{Context, Poll},
7};
8
9use pin_project::pin_project;
10use tokio::{
11 io::{AsyncRead, AsyncWrite, ReadBuf},
12 net::{TcpStream, ToSocketAddrs},
13};
14use tokio_rustls::{
15 client::TlsStream,
16 rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName},
17 TlsConnector,
18};
19
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-stream")))]
/// A client-side connection that is either a plain TCP stream or a TLS
/// session layered on top of a TCP stream.
///
/// `#[pin_project(project = StreamProj)]` generates the `StreamProj`
/// projection enum; both payloads are marked `#[pin]`, so projecting a
/// `Pin<&mut Stream>` yields a pinned reference to whichever transport is
/// active.
#[pin_project(project = StreamProj)]
#[derive(Debug)]
pub enum Stream {
    /// Unencrypted connection: the raw [`TcpStream`].
    Tcp(#[pin] TcpStream),
    /// Encrypted connection: a TLS client stream wrapping a [`TcpStream`].
    ///
    /// NOTE(review): the payload is boxed, presumably to keep the enum small
    /// compared with the `Tcp` variant — confirm.
    SecureTcp(#[pin] Box<TlsStream<TcpStream>>),
}
29
30impl Stream {
31 #[cfg_attr(docsrs, doc(cfg(feature = "tokio-stream")))]
34 pub async fn connect(
35 addr: impl ToSocketAddrs,
36 domain: Option<impl AsRef<str>>,
37 ) -> io::Result<Self> {
38 match domain {
39 Some(domain) => {
40 let mut root_cert_store = RootCertStore::empty();
41 root_cert_store.add_server_trust_anchors(
42 webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|anchor| {
43 OwnedTrustAnchor::from_subject_spki_name_constraints(
44 anchor.subject,
45 anchor.spki,
46 anchor.name_constraints,
47 )
48 }),
49 );
50
51 let config = ClientConfig::builder()
52 .with_safe_defaults()
53 .with_root_certificates(root_cert_store)
54 .with_no_client_auth();
55
56 let server_name = ServerName::try_from(domain.as_ref())
57 .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, domain.as_ref()))?;
58
59 let stream = TcpStream::connect(addr).await?;
60
61 Ok(Stream::SecureTcp(Box::new(
62 TlsConnector::from(Arc::new(config))
63 .connect(server_name, stream)
64 .await?,
65 )))
66 }
67 None => Ok(Stream::Tcp(TcpStream::connect(addr).await?)),
68 }
69 }
70}
71
72impl AsyncRead for Stream {
73 fn poll_read(
74 self: Pin<&mut Self>,
75 cx: &mut Context<'_>,
76 buf: &mut ReadBuf<'_>,
77 ) -> Poll<io::Result<()>> {
78 match self.project() {
79 StreamProj::Tcp(tcp_stream) => tcp_stream.poll_read(cx, buf),
80 StreamProj::SecureTcp(tls_stream) => tls_stream.poll_read(cx, buf),
81 }
82 }
83}
84
85impl AsyncWrite for Stream {
86 fn poll_write(
87 self: Pin<&mut Self>,
88 cx: &mut Context<'_>,
89 buf: &[u8],
90 ) -> Poll<io::Result<usize>> {
91 match self.project() {
92 StreamProj::Tcp(tcp_stream) => tcp_stream.poll_write(cx, buf),
93 StreamProj::SecureTcp(tls_stream) => tls_stream.poll_write(cx, buf),
94 }
95 }
96
97 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
98 match self.project() {
99 StreamProj::Tcp(tcp_stream) => tcp_stream.poll_flush(cx),
100 StreamProj::SecureTcp(tls_stream) => tls_stream.poll_flush(cx),
101 }
102 }
103
104 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
105 match self.project() {
106 StreamProj::Tcp(tcp_stream) => tcp_stream.poll_shutdown(cx),
107 StreamProj::SecureTcp(tls_stream) => tls_stream.poll_shutdown(cx),
108 }
109 }
110}