1use std::sync::Arc;
2
3use async_trait::async_trait;
4use rustls::server::ResolvesServerCert;
5use tokio::io::{AsyncRead, AsyncWrite};
6use tokio::net::TcpListener;
7
8use crate::{Backend, CustomServer, Error};
9
10impl<B: Backend> CustomServer<B> {
11 pub async fn listen_for_tcp_on<S: TcpService, T: tokio::net::ToSocketAddrs + Send + Sync>(
14 &self,
15 addr: T,
16 service: S,
17 ) -> Result<(), Error> {
18 let listener = TcpListener::bind(&addr).await?;
19 let mut shutdown_watcher = self
20 .data
21 .shutdown
22 .watcher()
23 .await
24 .expect("server already shutdown");
25
26 loop {
27 tokio::select! {
28 _ = shutdown_watcher.wait_for_shutdown() => {
29 break;
30 }
31 incoming = listener.accept() => {
32 if incoming.is_err() {
33 continue;
34 }
35 let (connection, remote_addr) = incoming.unwrap();
36
37 let peer = Peer {
38 address: remote_addr,
39 protocol: service.available_protocols()[0].clone(),
40 secure: false,
41 };
42
43 let task_self = self.clone();
44 let task_service = service.clone();
45 tokio::spawn(async move {
46 if let Err(err) = task_self.handle_tcp_connection(connection, peer, &task_service).await {
47 log::error!("[server] closing connection {}: {:?}", remote_addr, err);
48 }
49 });
50 }
51 }
52 }
53
54 Ok(())
55 }
56
57 #[cfg_attr(not(feature = "websockets"), allow(unused_variables))]
63 #[cfg_attr(not(feature = "acme"), allow(unused_mut))]
64 pub async fn listen_for_secure_tcp_on<
65 S: TcpService,
66 T: tokio::net::ToSocketAddrs + Send + Sync,
67 >(
68 &self,
69 addr: T,
70 service: S,
71 ) -> Result<(), Error> {
72 drop(self.refresh_certified_key().await);
74
75 #[cfg(feature = "acme")]
76 {
77 let task_self = self.clone();
78 tokio::task::spawn(async move {
79 if let Err(err) = task_self.update_acme_certificates().await {
80 log::error!("[server] acme task error: {0}", err);
81 }
82 });
83 }
84
85 let mut config = rustls::ServerConfig::builder()
86 .with_safe_defaults()
87 .with_no_client_auth()
88 .with_cert_resolver(Arc::new(self.clone()));
89 config.alpn_protocols = service
90 .available_protocols()
91 .iter()
92 .map(|proto| proto.alpn_name().to_vec())
93 .collect();
94
95 let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));
96 let listener = TcpListener::bind(&addr).await?;
97 loop {
98 let (stream, peer_addr) = listener.accept().await?;
99 let acceptor = acceptor.clone();
100
101 let task_self = self.clone();
102 let task_service = service.clone();
103 tokio::task::spawn(async move {
104 let stream = match acceptor.accept(stream).await {
105 Ok(stream) => stream,
106 Err(err) => {
107 log::error!("[server] error during tls handshake: {:?}", err);
108 return;
109 }
110 };
111
112 let available_protocols = task_service.available_protocols();
113 let protocol = stream
114 .get_ref()
115 .1
116 .alpn_protocol()
117 .and_then(|protocol| {
118 available_protocols
119 .iter()
120 .find(|p| p.alpn_name() == protocol)
121 .cloned()
122 })
123 .unwrap_or_else(|| available_protocols[0].clone());
124 let peer = Peer {
125 address: peer_addr,
126 secure: true,
127 protocol,
128 };
129 if let Err(err) = task_self
130 .handle_tcp_connection(stream, peer, &task_service)
131 .await
132 {
133 log::error!("[server] error for client {}: {:?}", peer_addr, err);
134 }
135 });
136 }
137 }
138
139 #[cfg_attr(not(feature = "websockets"), allow(unused_variables))]
140 async fn handle_tcp_connection<
141 S: TcpService,
142 C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
143 >(
144 &self,
145 connection: C,
146 peer: Peer<S::ApplicationProtocols>,
147 service: &S,
148 ) -> Result<(), Error> {
149 #[cfg(feature = "acme")]
151 if peer.protocol.alpn_name() == async_acme::acme::ACME_TLS_ALPN_NAME {
152 log::info!("received acme challenge connection");
153 return Ok(());
154 }
155
156 if let Err(connection) = service.handle_connection(connection, &peer).await {
157 #[cfg(feature = "websockets")]
158 if let Err(err) = self
159 .handle_raw_websocket_connection(connection, peer.address)
160 .await
161 {
162 log::error!(
163 "[server] error on websocket for {}: {:?}",
164 peer.address,
165 err
166 );
167 }
168 }
169
170 Ok(())
171 }
172}
173
174impl<B: Backend> ResolvesServerCert for CustomServer<B> {
175 #[cfg_attr(not(feature = "acme"), allow(unused_variables))]
176 fn resolve(
177 &self,
178 client_hello: rustls::server::ClientHello<'_>,
179 ) -> Option<Arc<rustls::sign::CertifiedKey>> {
180 #[cfg(feature = "acme")]
181 if client_hello
182 .alpn()
183 .map(|mut iter| iter.any(|n| n == async_acme::acme::ACME_TLS_ALPN_NAME))
184 .unwrap_or_default()
185 {
186 let server_name = client_hello.server_name()?.to_owned();
187 let keys = self.data.alpn_keys.lock();
188 if let Some(key) = keys.get(AsRef::<str>::as_ref(&server_name)) {
189 log::info!("returning acme challenge");
190 return Some(key.clone());
191 }
192
193 log::error!(
194 "acme alpn challenge received with no key for {}",
195 server_name
196 );
197 return None;
198 }
199
200 let cached_key = self.data.primary_tls_key.lock();
201 if let Some(key) = cached_key.as_ref() {
202 Some(key.clone())
203 } else {
204 log::error!("[server] inbound tls connection with no certificate installed");
205 None
206 }
207 }
208}
209
210#[async_trait]
212pub trait TcpService: Clone + Send + Sync + 'static {
213 type ApplicationProtocols: ApplicationProtocols;
215
216 fn available_protocols(&self) -> &[Self::ApplicationProtocols];
220
221 async fn handle_connection<
224 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
225 >(
226 &self,
227 connection: S,
228 peer: &Peer<Self::ApplicationProtocols>,
229 ) -> Result<(), S>;
230}
231
232#[async_trait]
236pub trait HttpService: Clone + Send + Sync + 'static {
237 async fn handle_connection<
240 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
241 >(
242 &self,
243 connection: S,
244 peer: &Peer,
245 ) -> Result<(), S>;
246}
247
248#[async_trait]
249impl<T> TcpService for T
250where
251 T: HttpService,
252{
253 type ApplicationProtocols = StandardTcpProtocols;
254
255 fn available_protocols(&self) -> &[Self::ApplicationProtocols] {
256 StandardTcpProtocols::all()
257 }
258
259 async fn handle_connection<
260 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
261 >(
262 &self,
263 connection: S,
264 peer: &Peer<Self::ApplicationProtocols>,
265 ) -> Result<(), S> {
266 HttpService::handle_connection(self, connection, peer).await
267 }
268}
269
270#[async_trait]
271impl HttpService for () {
272 async fn handle_connection<
273 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
274 >(
275 &self,
276 connection: S,
277 _peer: &Peer<StandardTcpProtocols>,
278 ) -> Result<(), S> {
279 Err(connection)
280 }
281}
282
/// A set of application protocols a [`TcpService`] can speak.
pub trait ApplicationProtocols: std::fmt::Debug + Clone + Send + Sync {
    /// The protocol's ALPN identifier, as advertised during a TLS handshake.
    fn alpn_name(&self) -> &'static [u8];
}
288
289#[derive(Debug, Clone)]
291pub struct Peer<P: ApplicationProtocols = StandardTcpProtocols> {
292 pub address: std::net::SocketAddr,
294 pub secure: bool,
296 pub protocol: P,
298}
299
/// The protocols spoken by an [`HttpService`] through its [`TcpService`]
/// implementation.
///
/// Every variant is a fieldless marker, so the type is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StandardTcpProtocols {
    /// HTTP/1.1, advertised via ALPN as `http/1.1`.
    Http1,
    /// The ACME TLS-ALPN challenge protocol
    /// (`async_acme::acme::ACME_TLS_ALPN_NAME`).
    #[cfg(feature = "acme")]
    Acme,
    /// A protocol not covered by the other variants; it has no ALPN name.
    Other,
}
309
310impl StandardTcpProtocols {
311 #[cfg(feature = "acme")]
312 const fn all() -> &'static [Self] {
313 &[Self::Http1, Self::Acme]
314 }
315
316 #[cfg(not(feature = "acme"))]
317 const fn all() -> &'static [Self] {
318 &[Self::Http1]
319 }
320}
321
322impl Default for StandardTcpProtocols {
323 fn default() -> Self {
324 Self::Http1
325 }
326}
327
328impl ApplicationProtocols for StandardTcpProtocols {
329 fn alpn_name(&self) -> &'static [u8] {
330 match self {
331 Self::Http1 => b"http/1.1",
332 #[cfg(feature = "acme")]
333 Self::Acme => async_acme::acme::ACME_TLS_ALPN_NAME,
334 Self::Other => unreachable!(),
335 }
336 }
337}