1use crate::certs::Cert;
2use crate::server::cert_list::CertList;
3use dashmap::DashMap;
4use std::fmt;
5use std::str::FromStr;
6use std::sync::Arc;
7use taxy_api::cert::CertKind;
8use taxy_api::error::Error;
9use taxy_api::id::ShortId;
10use taxy_api::subject_name::SubjectName;
11use taxy_api::tls::TlsState;
12use tokio_rustls::rustls::server::{ClientHello, ResolvesServerCert};
13use tokio_rustls::rustls::sign::CertifiedKey;
14use tokio_rustls::rustls::ServerConfig;
15use tokio_rustls::TlsAcceptor;
16use tracing::error;
17
18pub struct TlsTermination {
19 pub server_names: Vec<SubjectName>,
20 pub acceptor: Option<TlsAcceptor>,
21 pub alpn_protocols: Vec<Vec<u8>>,
22}
23
24impl fmt::Debug for TlsTermination {
25 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26 f.debug_struct("TlsTermination")
27 .field("server_names", &self.server_names)
28 .finish()
29 }
30}
31
32impl TlsTermination {
33 pub fn new(
34 config: &taxy_api::tls::TlsTermination,
35 alpn_protocols: Vec<Vec<u8>>,
36 ) -> Result<Self, Error> {
37 let mut server_names = Vec::new();
38 for name in &config.server_names {
39 let name = SubjectName::from_str(name)?;
40 server_names.push(name);
41 }
42 Ok(Self {
43 server_names,
44 acceptor: None,
45 alpn_protocols,
46 })
47 }
48
49 pub async fn setup(&mut self, certs: &CertList) -> TlsState {
50 let resolver: Arc<dyn ResolvesServerCert> = Arc::new(CertResolver::new(
51 certs
52 .iter()
53 .filter(|cert| cert.kind == CertKind::Server)
54 .cloned()
55 .collect(),
56 self.server_names.clone(),
57 true,
58 ));
59
60 let mut server_config = ServerConfig::builder()
61 .with_no_client_auth()
62 .with_cert_resolver(resolver);
63 server_config
64 .alpn_protocols
65 .clone_from(&self.alpn_protocols);
66
67 let server_config = Arc::new(server_config);
68 self.acceptor = Some(TlsAcceptor::from(server_config));
69
70 TlsState::Active
71 }
72}
73
74#[derive(Debug, Default)]
75pub struct CertResolver {
76 certs: Vec<Arc<Cert>>,
77 default_names: Vec<SubjectName>,
78 sni: bool,
79 cache: DashMap<ShortId, Arc<CertifiedKey>>,
80}
81
82impl CertResolver {
83 pub fn new(certs: Vec<Arc<Cert>>, default_names: Vec<SubjectName>, sni: bool) -> Self {
84 Self {
85 certs,
86 default_names,
87 sni,
88 cache: DashMap::new(),
89 }
90 }
91}
92
93impl ResolvesServerCert for CertResolver {
94 fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
95 let sni = client_hello
96 .server_name()
97 .filter(|_| self.sni)
98 .map(|sni| SubjectName::DnsName(sni.into()))
99 .into_iter()
100 .collect::<Vec<_>>();
101
102 let names = if sni.is_empty() {
103 &self.default_names
104 } else {
105 &sni
106 };
107
108 let cert = self
109 .certs
110 .iter()
111 .find(|cert| cert.is_valid() && names.iter().all(|name| cert.has_subject_name(name)))?;
112
113 if let Some(cert) = self.cache.get(&cert.id()) {
114 Some(cert.clone())
115 } else {
116 let certified = match cert.certified_key() {
117 Ok(certified) => Arc::new(certified),
118 Err(err) => {
119 error!("failed to load certified key: {}", err);
120 return None;
121 }
122 };
123 self.cache.insert(cert.id(), certified.clone());
124 Some(certified)
125 }
126 }
127}