Skip to main content

warp_openssl/
server.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::task::{Context, Poll};
4
5use crate::certificate::{Certificate, CertificateVerifier};
6use crate::config::{LookupFileFn, LookupHashDirFn, SslConfig, TlsConfigBuilder};
7use crate::stream::{CloneableStream, TlsStream};
8use crate::Result;
9
10use futures_util::{Future, TryFuture};
11
12use hyper_util::rt::{TokioExecutor, TokioIo};
13use hyper_util::server::conn::auto;
14use hyper_util::server::graceful::GracefulShutdown;
15use hyper_util::service::TowerToHyperService;
16
17use tokio::net::TcpListener;
18use warp::{Filter, Reply};
19
20/// Create an `OpensslServer` with the provided `Filter`.
21pub fn serve<F>(filter: F) -> OpensslServer<F> {
22    OpensslServer {
23        filter,
24        tls: TlsConfigBuilder::new(),
25    }
26}
27
/// Settings corresponding to TLS level based on Mozilla's server side TLS recommendations.
/// See its [documentation][docs] for more details on specifics.
///
/// Every variant is a plain unit value, so the type is `Copy` and can be compared and
/// hashed. `OpensslServer::tls_level` documents `MozillaIntermediateV5` as the default.
///
/// [docs]: https://wiki.mozilla.org/Security/Server_Side_TLS
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TlsLevel {
    /// Settings corresponding to the modern configuration of version 4 of Mozilla's
    /// server side TLS recommendations.
    MozillaModern,
    /// Settings corresponding to the modern configuration of version 5 of Mozilla's
    /// server side TLS recommendations.
    MozillaModernV5,
    /// Settings corresponding to the intermediate configuration of version 4 of Mozilla's
    /// server side TLS recommendations.
    MozillaIntermediate,
    /// Settings corresponding to the intermediate configuration of version 5 of Mozilla's
    /// server side TLS recommendations.
    MozillaIntermediateV5,
}
47
48/// Create an openssl based TLS warp server with the provided filter.
49///
50#[derive(Debug)]
51pub struct OpensslServer<F> {
52    filter: F,
53    tls: TlsConfigBuilder,
54}
55
56impl<F> OpensslServer<F>
57where
58    F: Filter + Clone + Send + Sync + 'static,
59    <F::Future as TryFuture>::Ok: Reply,
60{
61    /// Specify the in-memory contents of the private key.
62    ///
63    pub fn key(self, key: impl AsRef<[u8]>) -> Self {
64        self.with_tls(|tls| tls.key(key.as_ref()))
65    }
66
67    /// Specify the tls level based on Mozilla's server side TLS recommendations.
68    /// See its [documentation][docs] for more details on specifics.
69    ///
70    /// Defaults to `TlsLevel::MozillaIntermediateV5`.
71    ///
72    /// [docs]: https://wiki.mozilla.org/Security/Server_Side_TLS
73    pub fn tls_level(self, tls_level: TlsLevel) -> Self {
74        self.with_tls(|tls| tls.tls_level(tls_level))
75    }
76
77    /// Specify the in-memory contents of the certificate.
78    ///
79    pub fn cert(self, cert: impl AsRef<[u8]>) -> Self {
80        self.with_tls(|tls| tls.cert(cert.as_ref()))
81    }
82
83    /// Add file loop callback that loads all the certificates or CRLs present in a file into memory at the time the file is added as a lookup source.
84    /// See [`openssl::x509::X509Lookup::file`] for more details.
85    ///
86    pub fn add_file_lookup(self, lookup: LookupFileFn) -> Self {
87        self.with_tls(|tls| tls.add_file_lookup(lookup))
88    }
89
90    /// Add hash dir lookup callback that loads certificates and CRLs on demand and caches them in memory once they are loaded.
91    /// See [`openssl::x509::X509Lookup::hash_dir`] for more details.
92    ///
93    pub fn add_hash_dir_lookup(self, lookup: LookupHashDirFn) -> Self {
94        self.with_tls(|tls| tls.add_hash_dir_lookup(lookup))
95    }
96
97    /// Specify the in-memory contents of the trust anchor for optional client authentication.
98    ///
99    /// Anonymous clients will be accepted by default
100    /// Non anonymous clients passing CertificateVerifier and having a valid certificate chain will be accepted.
101    ///
102    pub fn client_auth_optional(
103        self,
104        trust_anchor: impl AsRef<[u8]>,
105        certificate_verifier: Arc<dyn CertificateVerifier>,
106    ) -> Self {
107        self.with_tls(|tls| tls.client_auth_optional(trust_anchor.as_ref(), certificate_verifier))
108    }
109
110    /// Specify the in-memory contents of the trust anchor for required client authentication.
111    /// Only clients passing CertificateVerifier and having a valid certificate chain will be accepted.
112    ///
113    pub fn client_auth_required(
114        self,
115        trust_anchor: impl AsRef<[u8]>,
116        certificate_verifier: Arc<dyn CertificateVerifier>,
117    ) -> Self {
118        self.with_tls(|tls| tls.client_auth_required(trust_anchor.as_ref(), certificate_verifier))
119    }
120
121    /// **Not recommended** Disables partial certificate chain verification.
122    ///
123    /// For certificate pinning to work properly its enough to validate that
124    /// the certificate chains to an anchor in the trust store. This is the default behavior.
125    ///
126    pub fn disable_partial_chain_verification(self) -> Self {
127        self.with_tls(|tls| tls.disable_partial_chain_verification())
128    }
129
130    fn with_tls<Func>(self, func: Func) -> Self
131    where
132        Func: FnOnce(TlsConfigBuilder) -> TlsConfigBuilder,
133    {
134        let OpensslServer { filter, tls } = self;
135        let tls = func(tls);
136        OpensslServer { filter, tls }
137    }
138
139    fn build_server(
140        self,
141        addr: impl Into<SocketAddr>,
142    ) -> Result<(SocketAddr, TcpListener, SslConfig, F)> {
143        let ssl_config = self.tls.build()?;
144        let addr = addr.into();
145        let std_listener = std::net::TcpListener::bind(addr)?;
146        std_listener.set_nonblocking(true)?;
147        let listener = TcpListener::from_std(std_listener)?;
148        let local_addr = listener.local_addr()?;
149        Ok((local_addr, listener, ssl_config, self.filter))
150    }
151
152    /// Create a tls server bound to a specific port.
153    ///
154    pub fn bind(
155        self,
156        addr: impl Into<SocketAddr>,
157    ) -> Result<(SocketAddr, impl Future<Output = ()> + 'static)> {
158        let (addr, listener, ssl_config, filter) = self.build_server(addr)?;
159        let ssl_config = Arc::new(ssl_config);
160
161        let srv = async move {
162            let builder = auto::Builder::new(TokioExecutor::new());
163            loop {
164                let (tcp_stream, remote_addr) = match listener.accept().await {
165                    Ok(conn) => conn,
166                    Err(e) => {
167                        tracing::error!("accept error: {}", e);
168                        continue;
169                    }
170                };
171
172                if let Err(e) = tcp_stream.set_nodelay(true) {
173                    tracing::warn!("set_nodelay failed for {}: {}", remote_addr, e);
174                }
175
176                let ssl_config = ssl_config.clone();
177                let filter = filter.clone();
178                let builder = builder.clone();
179
180                tokio::spawn(async move {
181                    if let Err(e) =
182                        serve_connection(tcp_stream, &ssl_config, filter, &builder).await
183                    {
184                        tracing::error!("connection error: {}", e);
185                    }
186                });
187            }
188        };
189
190        Ok((addr, srv))
191    }
192
193    /// Create a tls server bound to a specific port with graceful shutdown signal.
194    ///
195    /// When the signal completes, the server will start the graceful shutdown
196    /// process.
197    ///
198    pub fn bind_with_graceful_shutdown(
199        self,
200        addr: impl Into<SocketAddr>,
201        signal: impl Future<Output = ()> + Send + 'static,
202    ) -> Result<(SocketAddr, impl Future<Output = ()> + 'static)> {
203        let (addr, listener, ssl_config, filter) = self.build_server(addr)?;
204        let ssl_config = Arc::new(ssl_config);
205
206        let srv = async move {
207            let builder = auto::Builder::new(TokioExecutor::new());
208            let graceful = GracefulShutdown::new();
209            let mut signal = std::pin::pin!(signal);
210
211            loop {
212                tokio::select! {
213                    result = listener.accept() => {
214                        let (tcp_stream, remote_addr) = match result {
215                            Ok(conn) => conn,
216                            Err(e) => {
217                                tracing::error!("accept error: {}", e);
218                                continue;
219                            }
220                        };
221
222                        if let Err(e) = tcp_stream.set_nodelay(true) {
223                            tracing::warn!("set_nodelay failed for {}: {}", remote_addr, e);
224                        }
225
226                        let ssl_config = ssl_config.clone();
227                        let filter = filter.clone();
228                        let builder = builder.clone();
229                        let watcher = graceful.watcher();
230
231                        tokio::spawn(async move {
232                            let tls_stream = match TlsStream::new(tcp_stream, &ssl_config) {
233                                Ok(s) => s,
234                                Err(e) => {
235                                    tracing::error!("TLS stream creation error: {}", e);
236                                    return;
237                                }
238                            };
239
240                            let stream_ref = tls_stream.stream();
241                            let svc = CertInjectorService {
242                                inner: warp::service(filter),
243                                stream: stream_ref,
244                            };
245
246                            let conn = builder.serve_connection(
247                                TokioIo::new(tls_stream),
248                                TowerToHyperService::new(svc),
249                            );
250                            let conn = watcher.watch(conn.into_owned());
251
252                            if let Err(e) = conn.await {
253                                tracing::error!("connection error: {}", e);
254                            }
255                        });
256                    }
257                    _ = &mut signal => {
258                        break;
259                    }
260                }
261            }
262
263            graceful.shutdown().await;
264        };
265
266        Ok((addr, srv))
267    }
268}
269
270async fn serve_connection<F>(
271    tcp_stream: tokio::net::TcpStream,
272    ssl_config: &SslConfig,
273    filter: F,
274    builder: &auto::Builder<TokioExecutor>,
275) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>>
276where
277    F: Filter + Clone + Send + Sync + 'static,
278    <F::Future as TryFuture>::Ok: Reply,
279{
280    let tls_stream = TlsStream::new(tcp_stream, ssl_config)?;
281    let stream_ref = tls_stream.stream();
282
283    let svc = CertInjectorService {
284        inner: warp::service(filter),
285        stream: stream_ref,
286    };
287
288    builder
289        .serve_connection(TokioIo::new(tls_stream), TowerToHyperService::new(svc))
290        .await?;
291
292    Ok(())
293}
294
295/// A service wrapper that injects the peer certificate into request extensions.
296#[derive(Clone)]
297struct CertInjectorService<S> {
298    inner: S,
299    stream: CloneableStream,
300}
301
302impl<S, B> tower_service::Service<http::Request<B>> for CertInjectorService<S>
303where
304    S: tower_service::Service<http::Request<B>>,
305{
306    type Response = S::Response;
307    type Error = S::Error;
308    type Future = S::Future;
309
310    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
311        self.inner.poll_ready(cx)
312    }
313
314    fn call(&mut self, mut req: http::Request<B>) -> Self::Future {
315        let certificate: Option<Certificate> = self
316            .stream
317            .lock()
318            .ok()
319            .and_then(|stream| stream.ssl().peer_certificate())
320            .and_then(|peer_certificate| peer_certificate.try_into().ok());
321
322        if let Some(certificate) = certificate {
323            req.extensions_mut().insert(certificate);
324        }
325
326        self.inner.call(req)
327    }
328}
329
330impl<S> std::fmt::Debug for CertInjectorService<S> {
331    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332        f.debug_struct("CertInjectorService").finish()
333    }
334}