witchcraft_server/tls/tls_client_authentication.rs

1// Copyright 2022 Palantir Technologies, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use crate::tls::ClientCertificate;
15use conjure_error::{Error, PermissionDenied};
16use conjure_http::server::{
17    AsyncEndpoint, AsyncResponseBody, AsyncService, BoxAsyncEndpoint, ConjureRuntime, Endpoint,
18    EndpointMetadata, PathSegment, ResponseBody, Service,
19};
20use http::{Extensions, Method, Request, Response};
21use refreshable::Refreshable;
22use rustls_pki_types::ServerName;
23use std::collections::HashSet;
24use std::str;
25use std::sync::Arc;
26use webpki::EndEntityCert;
27
28/// A service adapter which validates a client's certificate against a collection of allowed subject names.
29///
30/// Requests will be rejected if the client did not provide a certificate or if the certificate does not have an allowed
31/// subject name.
32pub struct TlsClientAuthenticationService<T> {
33    inner: T,
34    trusted_subject_names: Arc<Refreshable<HashSet<String>, Error>>,
35}
36
37impl<T> TlsClientAuthenticationService<T> {
38    /// Creates a new service which will validate the subject name of a client's certificate for each request.
39    ///
40    /// The inner service can implement either the [`Service`] or [`AsyncService`] trait.
41    pub fn new(inner: T, trusted_subject_names: Arc<Refreshable<HashSet<String>, Error>>) -> Self {
42        TlsClientAuthenticationService {
43            inner,
44            trusted_subject_names,
45        }
46    }
47}
48
49impl<T, I, O> Service<I, O> for TlsClientAuthenticationService<T>
50where
51    T: Service<I, O>,
52    I: 'static,
53    O: 'static,
54{
55    fn endpoints(
56        &self,
57        runtime: &Arc<ConjureRuntime>,
58    ) -> Vec<Box<dyn Endpoint<I, O> + Sync + Send>> {
59        self.inner
60            .endpoints(runtime)
61            .into_iter()
62            .map(|inner| {
63                Box::new(TlsClientAuthenticationEndpoint {
64                    inner,
65                    trusted_subject_names: self.trusted_subject_names.clone(),
66                }) as _
67            })
68            .collect()
69    }
70}
71
72impl<T, I, O> AsyncService<I, O> for TlsClientAuthenticationService<T>
73where
74    T: AsyncService<I, O>,
75    I: 'static + Send,
76    O: 'static,
77{
78    fn endpoints(&self, runtime: &Arc<ConjureRuntime>) -> Vec<BoxAsyncEndpoint<'static, I, O>> {
79        self.inner
80            .endpoints(runtime)
81            .into_iter()
82            .map(|inner| {
83                BoxAsyncEndpoint::new(TlsClientAuthenticationEndpoint {
84                    inner,
85                    trusted_subject_names: self.trusted_subject_names.clone(),
86                }) as _
87            })
88            .collect()
89    }
90}
91
92struct TlsClientAuthenticationEndpoint<T> {
93    inner: T,
94    trusted_subject_names: Arc<Refreshable<HashSet<String>, Error>>,
95}
96
97impl<T> TlsClientAuthenticationEndpoint<T> {
98    fn check_request<I>(&self, req: &Request<I>) -> Result<(), Error> {
99        let client_cert = match req.extensions().get::<ClientCertificate>() {
100            Some(client_cert) => client_cert,
101            None => {
102                return Err(Error::service_safe(
103                    "client did not provide a certificate",
104                    PermissionDenied::new(),
105                ))
106            }
107        };
108
109        let cert = EndEntityCert::try_from(client_cert.cert()).map_err(Error::internal_safe)?;
110        let valid = self
111            .trusted_subject_names
112            .get()
113            .iter()
114            .flat_map(|name| ServerName::try_from(&**name).ok())
115            .any(|name| cert.verify_is_valid_for_subject_name(&name).is_ok());
116
117        if valid {
118            Ok(())
119        } else {
120            Err(Error::internal_safe(
121                "Client certificate is not valid for any trusted subject name",
122            ))
123        }
124    }
125}
126
127impl<T> EndpointMetadata for TlsClientAuthenticationEndpoint<T>
128where
129    T: EndpointMetadata,
130{
131    fn method(&self) -> Method {
132        self.inner.method()
133    }
134
135    fn path(&self) -> &[PathSegment] {
136        self.inner.path()
137    }
138
139    fn template(&self) -> &str {
140        self.inner.template()
141    }
142
143    fn service_name(&self) -> &str {
144        self.inner.service_name()
145    }
146
147    fn name(&self) -> &str {
148        self.inner.name()
149    }
150
151    fn deprecated(&self) -> Option<&str> {
152        self.inner.deprecated()
153    }
154}
155
156impl<T, I, O> Endpoint<I, O> for TlsClientAuthenticationEndpoint<T>
157where
158    T: Endpoint<I, O>,
159{
160    fn handle(
161        &self,
162        req: Request<I>,
163        response_extensions: &mut Extensions,
164    ) -> Result<Response<ResponseBody<O>>, Error> {
165        self.check_request(&req)?;
166        self.inner.handle(req, response_extensions)
167    }
168}
169
170impl<T, I, O> AsyncEndpoint<I, O> for TlsClientAuthenticationEndpoint<T>
171where
172    T: AsyncEndpoint<I, O> + Sync + Send,
173    I: Send,
174{
175    async fn handle(
176        &self,
177        req: Request<I>,
178        response_extensions: &mut Extensions,
179    ) -> Result<Response<AsyncResponseBody<O>>, Error> {
180        self.check_request(&req)?;
181        self.inner.handle(req, response_extensions).await
182    }
183}