1use std::pin::Pin;
2
3use async_native_tls::{TlsAcceptor, TlsConnector, TlsStream};
4use async_trait::async_trait;
5use futures_lite::{AsyncRead, AsyncWrite};
6
7use sillad::{dialer::Dialer, listener::Listener, Pipe};
8
9pub struct TlsPipe<T: AsyncRead + AsyncWrite + Unpin + Send + 'static> {
11 inner: TlsStream<T>,
12 remote_addr: Option<String>,
13}
14
15impl<T: AsyncRead + AsyncWrite + Unpin + Send> AsyncRead for TlsPipe<T> {
16 fn poll_read(
17 self: Pin<&mut Self>,
18 cx: &mut std::task::Context<'_>,
19 buf: &mut [u8],
20 ) -> std::task::Poll<std::io::Result<usize>> {
21 Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
22 }
23}
24
25impl<T: AsyncRead + AsyncWrite + Unpin + Send> AsyncWrite for TlsPipe<T> {
26 fn poll_write(
27 self: Pin<&mut Self>,
28 cx: &mut std::task::Context<'_>,
29 buf: &[u8],
30 ) -> std::task::Poll<std::io::Result<usize>> {
31 Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
32 }
33
34 fn poll_flush(
35 self: Pin<&mut Self>,
36 cx: &mut std::task::Context<'_>,
37 ) -> std::task::Poll<std::io::Result<()>> {
38 Pin::new(&mut self.get_mut().inner).poll_flush(cx)
39 }
40
41 fn poll_close(
42 self: Pin<&mut Self>,
43 cx: &mut std::task::Context<'_>,
44 ) -> std::task::Poll<std::io::Result<()>> {
45 Pin::new(&mut self.get_mut().inner).poll_close(cx)
46 }
47}
48
49impl<T: AsyncRead + AsyncWrite + Unpin + Send> Pipe for TlsPipe<T> {
50 fn protocol(&self) -> &str {
51 "tls"
52 }
53
54 fn remote_addr(&self) -> Option<&str> {
55 self.remote_addr.as_deref()
56 }
57}
58
59pub struct TlsDialer<D: Dialer> {
61 inner: D,
62 connector: TlsConnector,
63 domain: String,
64}
65
66impl<D: Dialer> TlsDialer<D> {
67 pub fn new(inner: D, connector: TlsConnector, domain: String) -> Self {
68 Self {
69 inner,
70 connector,
71 domain,
72 }
73 }
74}
75
76#[async_trait]
77impl<D: Dialer> Dialer for TlsDialer<D>
78where
79 D::P: AsyncRead + AsyncWrite + Unpin + Send,
80{
81 type P = TlsPipe<D::P>;
82
83 async fn dial(&self) -> std::io::Result<Self::P> {
84 let stream = self.inner.dial().await?;
85 let remote_addr = stream.remote_addr().map(|s| s.to_string());
86 let tls_stream = self
87 .connector
88 .connect(&self.domain, stream)
89 .await
90 .inspect_err(|e| {
91 tracing::warn!(
92 err = display(e),
93 addr = debug(&remote_addr),
94 "TLS connection failed"
95 )
96 })
97 .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))?;
98 tracing::warn!(addr = debug(&remote_addr), "TLS connection SUCCESS");
99 Ok(TlsPipe {
100 inner: tls_stream,
101 remote_addr,
102 })
103 }
104}
105
106pub struct TlsListener<L: Listener> {
108 incoming: tachyonix::Receiver<TlsPipe<L::P>>,
110 _accept_task: async_task::Task<()>,
112}
113
114impl<L: Listener> TlsListener<L>
115where
116 L::P: AsyncRead + AsyncWrite + Unpin + Send + 'static,
117{
118 pub fn new(mut inner: L, acceptor: TlsAcceptor) -> Self {
119 let (tx, rx) = tachyonix::channel(1);
121
122 let acceptor_clone = acceptor.clone();
123 let accept_task = smolscale::spawn(async move {
124 loop {
125 let raw_conn = match inner.accept().await {
127 Ok(conn) => conn,
128 Err(err) => {
129 eprintln!("Underlying listener error: {:?}", err);
131 break;
132 }
133 };
134
135 let tx2 = tx.clone();
137 let acceptor2 = acceptor_clone.clone();
138 let remote_addr = raw_conn.remote_addr().map(|s| s.to_string());
139 smolscale::spawn(async move {
140 match acceptor2.accept(raw_conn).await {
141 Ok(tls_stream) => {
142 let pipe = TlsPipe {
143 inner: tls_stream,
144 remote_addr,
145 };
146 let _ = tx2.send(pipe).await;
147 }
148 Err(e) => {
149 eprintln!("TLS handshake error (ignored): {:?}", e);
151 }
152 }
153 })
154 .detach();
155 }
156 });
157
158 TlsListener {
159 incoming: rx,
160 _accept_task: accept_task,
161 }
162 }
163}
164
165#[async_trait]
166impl<L: Listener> Listener for TlsListener<L>
167where
168 L::P: AsyncRead + AsyncWrite + Unpin + Send + 'static,
169{
170 type P = TlsPipe<L::P>;
171
172 async fn accept(&mut self) -> std::io::Result<Self::P> {
173 match self.incoming.recv().await {
176 Ok(pipe) => Ok(pipe),
177 Err(_) => Err(std::io::Error::new(
178 std::io::ErrorKind::Other,
179 "Underlying listener failure",
180 )),
181 }
182 }
183}