witchcraft_server/tls/
tls_client_authentication.rs1use crate::tls::ClientCertificate;
15use conjure_error::{Error, PermissionDenied};
16use conjure_http::server::{
17 AsyncEndpoint, AsyncResponseBody, AsyncService, BoxAsyncEndpoint, ConjureRuntime, Endpoint,
18 EndpointMetadata, PathSegment, ResponseBody, Service,
19};
20use http::{Extensions, Method, Request, Response};
21use refreshable::Refreshable;
22use rustls_pki_types::ServerName;
23use std::collections::HashSet;
24use std::str;
25use std::sync::Arc;
26use webpki::EndEntityCert;
27
/// A [`Service`]/[`AsyncService`] wrapper that requires callers to authenticate with a TLS client
/// certificate.
///
/// Every endpoint exposed by the wrapped service is itself wrapped so that, before the inner
/// handler runs, the request must carry a client certificate that is valid for at least one of
/// the trusted subject names.
pub struct TlsClientAuthenticationService<T> {
    // The service whose endpoints are being protected.
    inner: T,
    // Subject names a client certificate may be valid for. The `Arc` is cloned into every wrapped
    // endpoint, and the current value is read (via `get()`) on each request rather than captured
    // once when the endpoints are built.
    trusted_subject_names: Arc<Refreshable<HashSet<String>, Error>>,
}
36
37impl<T> TlsClientAuthenticationService<T> {
38 pub fn new(inner: T, trusted_subject_names: Arc<Refreshable<HashSet<String>, Error>>) -> Self {
42 TlsClientAuthenticationService {
43 inner,
44 trusted_subject_names,
45 }
46 }
47}
48
49impl<T, I, O> Service<I, O> for TlsClientAuthenticationService<T>
50where
51 T: Service<I, O>,
52 I: 'static,
53 O: 'static,
54{
55 fn endpoints(
56 &self,
57 runtime: &Arc<ConjureRuntime>,
58 ) -> Vec<Box<dyn Endpoint<I, O> + Sync + Send>> {
59 self.inner
60 .endpoints(runtime)
61 .into_iter()
62 .map(|inner| {
63 Box::new(TlsClientAuthenticationEndpoint {
64 inner,
65 trusted_subject_names: self.trusted_subject_names.clone(),
66 }) as _
67 })
68 .collect()
69 }
70}
71
72impl<T, I, O> AsyncService<I, O> for TlsClientAuthenticationService<T>
73where
74 T: AsyncService<I, O>,
75 I: 'static + Send,
76 O: 'static,
77{
78 fn endpoints(&self, runtime: &Arc<ConjureRuntime>) -> Vec<BoxAsyncEndpoint<'static, I, O>> {
79 self.inner
80 .endpoints(runtime)
81 .into_iter()
82 .map(|inner| {
83 BoxAsyncEndpoint::new(TlsClientAuthenticationEndpoint {
84 inner,
85 trusted_subject_names: self.trusted_subject_names.clone(),
86 }) as _
87 })
88 .collect()
89 }
90}
91
/// An endpoint wrapper that rejects requests whose TLS client certificate is missing or is not
/// valid for any trusted subject name, and otherwise delegates to `inner`.
struct TlsClientAuthenticationEndpoint<T> {
    // The wrapped endpoint that handles requests which pass the certificate check.
    inner: T,
    // Shared set of trusted subject names; its current value is read on every request.
    trusted_subject_names: Arc<Refreshable<HashSet<String>, Error>>,
}
96
97impl<T> TlsClientAuthenticationEndpoint<T> {
98 fn check_request<I>(&self, req: &Request<I>) -> Result<(), Error> {
99 let client_cert = match req.extensions().get::<ClientCertificate>() {
100 Some(client_cert) => client_cert,
101 None => {
102 return Err(Error::service_safe(
103 "client did not provide a certificate",
104 PermissionDenied::new(),
105 ))
106 }
107 };
108
109 let cert = EndEntityCert::try_from(client_cert.cert()).map_err(Error::internal_safe)?;
110 let valid = self
111 .trusted_subject_names
112 .get()
113 .iter()
114 .flat_map(|name| ServerName::try_from(&**name).ok())
115 .any(|name| cert.verify_is_valid_for_subject_name(&name).is_ok());
116
117 if valid {
118 Ok(())
119 } else {
120 Err(Error::internal_safe(
121 "Client certificate is not valid for any trusted subject name",
122 ))
123 }
124 }
125}
126
127impl<T> EndpointMetadata for TlsClientAuthenticationEndpoint<T>
128where
129 T: EndpointMetadata,
130{
131 fn method(&self) -> Method {
132 self.inner.method()
133 }
134
135 fn path(&self) -> &[PathSegment] {
136 self.inner.path()
137 }
138
139 fn template(&self) -> &str {
140 self.inner.template()
141 }
142
143 fn service_name(&self) -> &str {
144 self.inner.service_name()
145 }
146
147 fn name(&self) -> &str {
148 self.inner.name()
149 }
150
151 fn deprecated(&self) -> Option<&str> {
152 self.inner.deprecated()
153 }
154}
155
156impl<T, I, O> Endpoint<I, O> for TlsClientAuthenticationEndpoint<T>
157where
158 T: Endpoint<I, O>,
159{
160 fn handle(
161 &self,
162 req: Request<I>,
163 response_extensions: &mut Extensions,
164 ) -> Result<Response<ResponseBody<O>>, Error> {
165 self.check_request(&req)?;
166 self.inner.handle(req, response_extensions)
167 }
168}
169
170impl<T, I, O> AsyncEndpoint<I, O> for TlsClientAuthenticationEndpoint<T>
171where
172 T: AsyncEndpoint<I, O> + Sync + Send,
173 I: Send,
174{
175 async fn handle(
176 &self,
177 req: Request<I>,
178 response_extensions: &mut Extensions,
179 ) -> Result<Response<AsyncResponseBody<O>>, Error> {
180 self.check_request(&req)?;
181 self.inner.handle(req, response_extensions).await
182 }
183}