1#![doc(html_root_url = "https://docs.rs/tonic-openssl/0.3.0")]
7#![warn(
8 missing_debug_implementations,
9 missing_docs,
10 rust_2018_idioms,
11 unreachable_pub
12)]
13
14mod client;
15pub use client::{connector, new_endpoint};
16
17use futures::{Stream, TryStreamExt};
18use openssl::{
19 ssl::{Ssl, SslAcceptor},
20 x509::X509,
21};
22use std::{
23 fmt::Debug,
24 io,
25 ops::ControlFlow,
26 pin::{pin, Pin},
27 sync::Arc,
28 task::{Context, Poll},
29};
30use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
31use tonic::transport::server::Connected;
32
/// Boxed, thread-safe error type used throughout this crate.
///
/// (`'static` is the default object lifetime for `Box<dyn Trait>`, so it is implied here.)
pub type Error = Box<dyn std::error::Error + Send + Sync>;
35
/// ALPN protocol identifier for HTTP/2 (`h2`) in the length-prefixed wire format
/// used by OpenSSL's ALPN APIs: one length byte (`2`) followed by the protocol name.
pub const ALPN_H2_WIRE: &[u8] = &[2, b'h', b'2'];
39
40pub fn incoming<IO, IE>(
44 incoming: impl Stream<Item = Result<IO, IE>>,
45 acceptor: SslAcceptor,
46) -> impl Stream<Item = Result<SslStream<IO>, Error>>
47where
48 IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
49 IE: Into<crate::Error>,
50{
51 async_stream::try_stream! {
52 let mut incoming = pin!(incoming);
53
54 let mut tasks = tokio::task::JoinSet::new();
55
56 loop {
57 match select(&mut incoming, &mut tasks).await {
58 SelectOutput::Incoming(stream) => {
59 let ssl = Ssl::new(acceptor.context())?;
60 let mut tls = tokio_openssl::SslStream::new(ssl, stream)?;
61 tasks.spawn(async move {
62 Pin::new(&mut tls).accept().await?;
63 Ok(SslStream {
64 inner: tls
65 })
66 });
67 }
68
69 SelectOutput::Io(io) => {
70 yield io;
71 }
72
73 SelectOutput::TcpErr(e) => match handle_tcp_accept_error(e) {
74 ControlFlow::Continue(()) => continue,
75 ControlFlow::Break(e) => Err(e)?,
76 }
77
78 SelectOutput::TlsErr(e) => {
79 tracing::debug!(error = %e, "tls accept error");
80 continue;
81 }
82
83 SelectOutput::Done => {
84 break;
85 }
86 }
87 }
88 }
89}
90
91async fn select<IO: 'static, IE>(
93 incoming: &mut (impl Stream<Item = Result<IO, IE>> + Unpin),
94 tasks: &mut tokio::task::JoinSet<Result<SslStream<IO>, crate::Error>>,
95) -> SelectOutput<IO>
96where
97 IE: Into<crate::Error>,
98{
99 let incoming_stream_future = async {
100 match incoming.try_next().await {
101 Ok(Some(stream)) => SelectOutput::Incoming(stream),
102 Ok(None) => SelectOutput::Done,
103 Err(e) => SelectOutput::TcpErr(e.into()),
104 }
105 };
106
107 if tasks.is_empty() {
108 return incoming_stream_future.await;
109 }
110
111 tokio::select! {
112 stream = incoming_stream_future => stream,
113 accept = tasks.join_next() => {
114 match accept.expect("JoinSet should never end") {
115 Ok(Ok(io)) => SelectOutput::Io(io),
116 Ok(Err(e)) => SelectOutput::TlsErr(e),
117 Err(e) => SelectOutput::TlsErr(e.into()),
118 }
119 }
120 }
121}
122
123enum SelectOutput<A> {
124 Incoming(A),
125 Io(SslStream<A>),
126 TcpErr(crate::Error),
127 TlsErr(crate::Error),
128 Done,
129}
130
131fn handle_tcp_accept_error(e: impl Into<crate::Error>) -> ControlFlow<crate::Error> {
136 let e = e.into();
137 tracing::debug!(error = %e, "accept loop error");
138 if let Some(e) = e.downcast_ref::<io::Error>() {
139 if matches!(
140 e.kind(),
141 io::ErrorKind::ConnectionAborted
142 | io::ErrorKind::ConnectionReset
143 | io::ErrorKind::BrokenPipe
144 | io::ErrorKind::Interrupted
145 | io::ErrorKind::WouldBlock
146 | io::ErrorKind::TimedOut
147 ) {
148 return ControlFlow::Continue(());
149 }
150 }
151
152 ControlFlow::Break(e)
153}
154
/// A TLS-encrypted connection over the transport `S`.
///
/// [`incoming`] yields values of this type once the server-side TLS handshake has
/// completed. Reads and writes are forwarded to the wrapped
/// `tokio_openssl::SslStream`. When `S` implements [`Connected`], the connect info
/// also carries the peer's verified certificate chain (see [`SslConnectInfo`]).
#[derive(Debug)]
pub struct SslStream<S> {
    // The TLS session layered over the underlying transport `S`.
    inner: tokio_openssl::SslStream<S>,
}
161
162impl<S: Connected> Connected for SslStream<S> {
163 type ConnectInfo = SslConnectInfo<S::ConnectInfo>;
164
165 fn connect_info(&self) -> Self::ConnectInfo {
166 let inner = self.inner.get_ref().connect_info();
167
168 let ssl = self.inner.ssl();
172 let certs = ssl
173 .verified_chain()
174 .map(|certs| {
175 certs
176 .iter()
177 .filter_map(|c| c.to_pem().ok())
178 .filter_map(|p| X509::from_pem(&p).ok())
179 .collect()
180 })
181 .map(Arc::new);
182
183 SslConnectInfo { inner, certs }
184 }
185}
186
187impl<S> AsyncRead for SslStream<S>
188where
189 S: AsyncRead + AsyncWrite + Unpin,
190{
191 fn poll_read(
192 mut self: Pin<&mut Self>,
193 cx: &mut Context<'_>,
194 buf: &mut ReadBuf<'_>,
195 ) -> Poll<std::io::Result<()>> {
196 Pin::new(&mut self.inner).poll_read(cx, buf)
197 }
198}
199
200impl<S> AsyncWrite for SslStream<S>
201where
202 S: AsyncRead + AsyncWrite + Unpin,
203{
204 fn poll_write(
205 mut self: Pin<&mut Self>,
206 cx: &mut Context<'_>,
207 buf: &[u8],
208 ) -> Poll<std::io::Result<usize>> {
209 Pin::new(&mut self.inner).poll_write(cx, buf)
210 }
211
212 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
213 Pin::new(&mut self.inner).poll_flush(cx)
214 }
215
216 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
217 Pin::new(&mut self.inner).poll_shutdown(cx)
218 }
219}
220
/// Connection information for an [`SslStream`], as produced by its
/// [`Connected::connect_info`] implementation.
///
/// Combines the underlying transport's connect info with the peer's certificate
/// chain, when one is available.
#[derive(Debug, Clone)]
pub struct SslConnectInfo<T> {
    // Connect info of the underlying (pre-TLS) transport.
    inner: T,
    // Peer certificate chain. It is wrapped in `Arc`, so cloning this struct does
    // not copy the certificates.
    certs: Option<Arc<Vec<X509>>>,
}
231
232impl<T> SslConnectInfo<T> {
233 pub fn get_ref(&self) -> &T {
235 &self.inner
236 }
237
238 pub fn get_mut(&mut self) -> &mut T {
240 &mut self.inner
241 }
242
243 pub fn peer_certs(&self) -> Option<Arc<Vec<X509>>> {
245 self.certs.clone()
246 }
247}