1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::task::{Context, Poll};
4
5use crate::certificate::{Certificate, CertificateVerifier};
6use crate::config::{LookupFileFn, LookupHashDirFn, SslConfig, TlsConfigBuilder};
7use crate::stream::{CloneableStream, TlsStream};
8use crate::Result;
9
10use futures_util::{Future, TryFuture};
11
12use hyper_util::rt::{TokioExecutor, TokioIo};
13use hyper_util::server::conn::auto;
14use hyper_util::server::graceful::GracefulShutdown;
15use hyper_util::service::TowerToHyperService;
16
17use tokio::net::TcpListener;
18use warp::{Filter, Reply};
19
20pub fn serve<F>(filter: F) -> OpensslServer<F> {
22 OpensslServer {
23 filter,
24 tls: TlsConfigBuilder::new(),
25 }
26}
27
/// TLS profile used by the server, named after Mozilla's server-side TLS
/// recommendations.
///
/// Selected with `OpensslServer::tls_level`. The enum is a plain, fieldless
/// value, so it is `Copy` and can be compared, hashed, and stored freely.
// NOTE(review): how each variant maps onto concrete OpenSSL settings is decided
// by `TlsConfigBuilder::tls_level` (config module, not visible here); presumably
// the `openssl` crate's `SslAcceptor::mozilla_*` constructors — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TlsLevel {
    /// Mozilla "Modern" profile.
    MozillaModern,
    /// Mozilla "Modern" profile, version 5 of the recommendations.
    MozillaModernV5,
    /// Mozilla "Intermediate" profile.
    MozillaIntermediate,
    /// Mozilla "Intermediate" profile, version 5 of the recommendations.
    MozillaIntermediateV5,
}
47
/// Builder for a warp server whose connections are terminated with OpenSSL.
///
/// Created by [`serve`]. TLS is configured through the builder methods
/// (`key`, `cert`, `tls_level`, client-auth options, ...); the server is started
/// with `bind` or `bind_with_graceful_shutdown`.
#[derive(Debug)]
pub struct OpensslServer<F> {
    /// The warp filter that handles every request.
    filter: F,
    /// TLS settings accumulated by the builder methods; turned into an
    /// `SslConfig` when the server is bound.
    tls: TlsConfigBuilder,
}
55
56impl<F> OpensslServer<F>
57where
58 F: Filter + Clone + Send + Sync + 'static,
59 <F::Future as TryFuture>::Ok: Reply,
60{
61 pub fn key(self, key: impl AsRef<[u8]>) -> Self {
64 self.with_tls(|tls| tls.key(key.as_ref()))
65 }
66
67 pub fn tls_level(self, tls_level: TlsLevel) -> Self {
74 self.with_tls(|tls| tls.tls_level(tls_level))
75 }
76
77 pub fn cert(self, cert: impl AsRef<[u8]>) -> Self {
80 self.with_tls(|tls| tls.cert(cert.as_ref()))
81 }
82
83 pub fn add_file_lookup(self, lookup: LookupFileFn) -> Self {
87 self.with_tls(|tls| tls.add_file_lookup(lookup))
88 }
89
90 pub fn add_hash_dir_lookup(self, lookup: LookupHashDirFn) -> Self {
94 self.with_tls(|tls| tls.add_hash_dir_lookup(lookup))
95 }
96
97 pub fn client_auth_optional(
103 self,
104 trust_anchor: impl AsRef<[u8]>,
105 certificate_verifier: Arc<dyn CertificateVerifier>,
106 ) -> Self {
107 self.with_tls(|tls| tls.client_auth_optional(trust_anchor.as_ref(), certificate_verifier))
108 }
109
110 pub fn client_auth_required(
114 self,
115 trust_anchor: impl AsRef<[u8]>,
116 certificate_verifier: Arc<dyn CertificateVerifier>,
117 ) -> Self {
118 self.with_tls(|tls| tls.client_auth_required(trust_anchor.as_ref(), certificate_verifier))
119 }
120
121 pub fn disable_partial_chain_verification(self) -> Self {
127 self.with_tls(|tls| tls.disable_partial_chain_verification())
128 }
129
130 fn with_tls<Func>(self, func: Func) -> Self
131 where
132 Func: FnOnce(TlsConfigBuilder) -> TlsConfigBuilder,
133 {
134 let OpensslServer { filter, tls } = self;
135 let tls = func(tls);
136 OpensslServer { filter, tls }
137 }
138
139 fn build_server(
140 self,
141 addr: impl Into<SocketAddr>,
142 ) -> Result<(SocketAddr, TcpListener, SslConfig, F)> {
143 let ssl_config = self.tls.build()?;
144 let addr = addr.into();
145 let std_listener = std::net::TcpListener::bind(addr)?;
146 std_listener.set_nonblocking(true)?;
147 let listener = TcpListener::from_std(std_listener)?;
148 let local_addr = listener.local_addr()?;
149 Ok((local_addr, listener, ssl_config, self.filter))
150 }
151
152 pub fn bind(
155 self,
156 addr: impl Into<SocketAddr>,
157 ) -> Result<(SocketAddr, impl Future<Output = ()> + 'static)> {
158 let (addr, listener, ssl_config, filter) = self.build_server(addr)?;
159 let ssl_config = Arc::new(ssl_config);
160
161 let srv = async move {
162 let builder = auto::Builder::new(TokioExecutor::new());
163 loop {
164 let (tcp_stream, remote_addr) = match listener.accept().await {
165 Ok(conn) => conn,
166 Err(e) => {
167 tracing::error!("accept error: {}", e);
168 continue;
169 }
170 };
171
172 if let Err(e) = tcp_stream.set_nodelay(true) {
173 tracing::warn!("set_nodelay failed for {}: {}", remote_addr, e);
174 }
175
176 let ssl_config = ssl_config.clone();
177 let filter = filter.clone();
178 let builder = builder.clone();
179
180 tokio::spawn(async move {
181 if let Err(e) =
182 serve_connection(tcp_stream, &ssl_config, filter, &builder).await
183 {
184 tracing::error!("connection error: {}", e);
185 }
186 });
187 }
188 };
189
190 Ok((addr, srv))
191 }
192
193 pub fn bind_with_graceful_shutdown(
199 self,
200 addr: impl Into<SocketAddr>,
201 signal: impl Future<Output = ()> + Send + 'static,
202 ) -> Result<(SocketAddr, impl Future<Output = ()> + 'static)> {
203 let (addr, listener, ssl_config, filter) = self.build_server(addr)?;
204 let ssl_config = Arc::new(ssl_config);
205
206 let srv = async move {
207 let builder = auto::Builder::new(TokioExecutor::new());
208 let graceful = GracefulShutdown::new();
209 let mut signal = std::pin::pin!(signal);
210
211 loop {
212 tokio::select! {
213 result = listener.accept() => {
214 let (tcp_stream, remote_addr) = match result {
215 Ok(conn) => conn,
216 Err(e) => {
217 tracing::error!("accept error: {}", e);
218 continue;
219 }
220 };
221
222 if let Err(e) = tcp_stream.set_nodelay(true) {
223 tracing::warn!("set_nodelay failed for {}: {}", remote_addr, e);
224 }
225
226 let ssl_config = ssl_config.clone();
227 let filter = filter.clone();
228 let builder = builder.clone();
229 let watcher = graceful.watcher();
230
231 tokio::spawn(async move {
232 let tls_stream = match TlsStream::new(tcp_stream, &ssl_config) {
233 Ok(s) => s,
234 Err(e) => {
235 tracing::error!("TLS stream creation error: {}", e);
236 return;
237 }
238 };
239
240 let stream_ref = tls_stream.stream();
241 let svc = CertInjectorService {
242 inner: warp::service(filter),
243 stream: stream_ref,
244 };
245
246 let conn = builder.serve_connection(
247 TokioIo::new(tls_stream),
248 TowerToHyperService::new(svc),
249 );
250 let conn = watcher.watch(conn.into_owned());
251
252 if let Err(e) = conn.await {
253 tracing::error!("connection error: {}", e);
254 }
255 });
256 }
257 _ = &mut signal => {
258 break;
259 }
260 }
261 }
262
263 graceful.shutdown().await;
264 };
265
266 Ok((addr, srv))
267 }
268}
269
270async fn serve_connection<F>(
271 tcp_stream: tokio::net::TcpStream,
272 ssl_config: &SslConfig,
273 filter: F,
274 builder: &auto::Builder<TokioExecutor>,
275) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>>
276where
277 F: Filter + Clone + Send + Sync + 'static,
278 <F::Future as TryFuture>::Ok: Reply,
279{
280 let tls_stream = TlsStream::new(tcp_stream, ssl_config)?;
281 let stream_ref = tls_stream.stream();
282
283 let svc = CertInjectorService {
284 inner: warp::service(filter),
285 stream: stream_ref,
286 };
287
288 builder
289 .serve_connection(TokioIo::new(tls_stream), TowerToHyperService::new(svc))
290 .await?;
291
292 Ok(())
293}
294
/// Tower middleware that exposes the TLS peer certificate to the wrapped service.
///
/// On every request it reads the peer certificate from `stream`. When one is
/// present and converts into a [`Certificate`], it is inserted into the request
/// extensions before the request is forwarded to `inner`.
#[derive(Clone)]
struct CertInjectorService<S> {
    /// The wrapped service (built with `warp::service` in this module).
    inner: S,
    /// Handle to the connection's TLS stream, used to read the peer certificate.
    stream: CloneableStream,
}
301
302impl<S, B> tower_service::Service<http::Request<B>> for CertInjectorService<S>
303where
304 S: tower_service::Service<http::Request<B>>,
305{
306 type Response = S::Response;
307 type Error = S::Error;
308 type Future = S::Future;
309
310 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
311 self.inner.poll_ready(cx)
312 }
313
314 fn call(&mut self, mut req: http::Request<B>) -> Self::Future {
315 let certificate: Option<Certificate> = self
316 .stream
317 .lock()
318 .ok()
319 .and_then(|stream| stream.ssl().peer_certificate())
320 .and_then(|peer_certificate| peer_certificate.try_into().ok());
321
322 if let Some(certificate) = certificate {
323 req.extensions_mut().insert(certificate);
324 }
325
326 self.inner.call(req)
327 }
328}
329
330impl<S> std::fmt::Debug for CertInjectorService<S> {
331 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332 f.debug_struct("CertInjectorService").finish()
333 }
334}