tonic_native_tls/
lib.rs

mod re_export;

pub use re_export::*;

use async_stream::stream;
use futures_util::{
    future::poll_fn, stream::FuturesUnordered, Stream, StreamExt, TryStream, TryStreamExt,
};
use std::{
    error::Error as StdError,
    fmt::Debug,
    future::ready,
    pin::Pin,
    task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_native_tls::{TlsAcceptor, TlsStream};
16
/// Boxed, thread-safe error type yielded (and then filtered out) by [`incoming`].
pub type Error = Box<dyn std::error::Error + Send + Sync>;
18
19pub fn incoming<S>(
20    mut incoming: S,
21    acceptor: TlsAcceptor,
22) -> impl Stream<Item = Result<TlsStreamWrapper<S::Ok>, Error>>
23where
24    S: TryStream + Unpin,
25    S::Ok: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
26    S::Error: StdError + Send + Sync + 'static,
27{
28    stream! {
29        while let Some(stream) = incoming.try_next().await.transpose() {
30            yield {
31                let acceptor = &acceptor;
32                move || async move {Ok(TlsStreamWrapper(acceptor.accept(stream?).await?))}
33            }().await;
34        }
35    }
36    .filter(|tls_stream| {
37        let ret = if let Err(_error) = tls_stream {
38            #[cfg(feature = "tracing")]
39            tracing::error!("Got error on incoming: `{_error}`.");
40            false
41        } else {
42            true
43        };
44
45        ready(ret)
46    })
47}
48
49#[derive(Debug)]
50pub struct TlsStreamWrapper<S>(TlsStream<S>);
51
#[cfg(feature = "axum")]
impl axum::extract::connect_info::Connected<&TlsStreamWrapper<tokio::net::TcpStream>>
    for std::net::SocketAddr
{
    /// Reports the remote address of the TCP socket underneath the TLS layer.
    fn connect_info(target: &TlsStreamWrapper<tokio::net::TcpStream>) -> Self {
        use std::net::{Ipv4Addr, SocketAddr};

        let TlsStreamWrapper(tls) = target;
        let tcp: &tokio::net::TcpStream = tls.get_ref().get_ref().get_ref();
        match tcp.peer_addr() {
            Ok(peer) => peer,
            // The trait cannot report failure, so fall back to 0.0.0.0:0.
            Err(_) => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
        }
    }
}
68
#[cfg(feature = "tonic")]
impl<S> tonic::transport::server::Connected for TlsStreamWrapper<S>
where
    S: tonic::transport::server::Connected + AsyncRead + AsyncWrite + Unpin,
{
    type ConnectInfo = <S as tonic::transport::server::Connected>::ConnectInfo;

    /// Delegates to the connection info of the transport wrapped by the TLS stream.
    fn connect_info(&self) -> Self::ConnectInfo {
        let TlsStreamWrapper(tls) = self;
        let transport: &S = tls.get_ref().get_ref().get_ref();
        tonic::transport::server::Connected::connect_info(transport)
    }
}
80
81impl<S> AsyncRead for TlsStreamWrapper<S>
82where
83    S: AsyncRead + AsyncWrite + Unpin,
84{
85    fn poll_read(
86        mut self: Pin<&mut Self>,
87        cx: &mut Context<'_>,
88        buf: &mut ReadBuf<'_>,
89    ) -> Poll<std::io::Result<()>> {
90        Pin::new(&mut self.0).poll_read(cx, buf)
91    }
92}
93
94impl<S> AsyncWrite for TlsStreamWrapper<S>
95where
96    S: AsyncRead + AsyncWrite + Unpin,
97{
98    fn poll_write(
99        mut self: Pin<&mut Self>,
100        cx: &mut Context<'_>,
101        buf: &[u8],
102    ) -> Poll<Result<usize, std::io::Error>> {
103        Pin::new(&mut self.0).poll_write(cx, buf)
104    }
105
106    fn poll_flush(
107        mut self: Pin<&mut Self>,
108        cx: &mut Context<'_>,
109    ) -> Poll<Result<(), std::io::Error>> {
110        Pin::new(&mut self.0).poll_flush(cx)
111    }
112
113    fn poll_shutdown(
114        mut self: Pin<&mut Self>,
115        cx: &mut Context<'_>,
116    ) -> Poll<Result<(), std::io::Error>> {
117        Pin::new(&mut self.0).poll_shutdown(cx)
118    }
119}