Skip to main content

taxy/proxy/
tls.rs

1use crate::certs::Cert;
2use crate::server::cert_list::CertList;
3use dashmap::DashMap;
4use std::fmt;
5use std::str::FromStr;
6use std::sync::Arc;
7use taxy_api::cert::CertKind;
8use taxy_api::error::Error;
9use taxy_api::id::ShortId;
10use taxy_api::subject_name::SubjectName;
11use taxy_api::tls::TlsState;
12use tokio_rustls::rustls::server::{ClientHello, ResolvesServerCert};
13use tokio_rustls::rustls::sign::CertifiedKey;
14use tokio_rustls::rustls::ServerConfig;
15use tokio_rustls::TlsAcceptor;
16use tracing::error;
17
18pub struct TlsTermination {
19    pub server_names: Vec<SubjectName>,
20    pub acceptor: Option<TlsAcceptor>,
21    pub alpn_protocols: Vec<Vec<u8>>,
22}
23
24impl fmt::Debug for TlsTermination {
25    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26        f.debug_struct("TlsTermination")
27            .field("server_names", &self.server_names)
28            .finish()
29    }
30}
31
32impl TlsTermination {
33    pub fn new(
34        config: &taxy_api::tls::TlsTermination,
35        alpn_protocols: Vec<Vec<u8>>,
36    ) -> Result<Self, Error> {
37        let mut server_names = Vec::new();
38        for name in &config.server_names {
39            let name = SubjectName::from_str(name)?;
40            server_names.push(name);
41        }
42        Ok(Self {
43            server_names,
44            acceptor: None,
45            alpn_protocols,
46        })
47    }
48
49    pub async fn setup(&mut self, certs: &CertList) -> TlsState {
50        let resolver: Arc<dyn ResolvesServerCert> = Arc::new(CertResolver::new(
51            certs
52                .iter()
53                .filter(|cert| cert.kind == CertKind::Server)
54                .cloned()
55                .collect(),
56            self.server_names.clone(),
57            true,
58        ));
59
60        let mut server_config = ServerConfig::builder()
61            .with_no_client_auth()
62            .with_cert_resolver(resolver);
63        server_config
64            .alpn_protocols
65            .clone_from(&self.alpn_protocols);
66
67        let server_config = Arc::new(server_config);
68        self.acceptor = Some(TlsAcceptor::from(server_config));
69
70        TlsState::Active
71    }
72}
73
74#[derive(Debug, Default)]
75pub struct CertResolver {
76    certs: Vec<Arc<Cert>>,
77    default_names: Vec<SubjectName>,
78    sni: bool,
79    cache: DashMap<ShortId, Arc<CertifiedKey>>,
80}
81
82impl CertResolver {
83    pub fn new(certs: Vec<Arc<Cert>>, default_names: Vec<SubjectName>, sni: bool) -> Self {
84        Self {
85            certs,
86            default_names,
87            sni,
88            cache: DashMap::new(),
89        }
90    }
91}
92
93impl ResolvesServerCert for CertResolver {
94    fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
95        let sni = client_hello
96            .server_name()
97            .filter(|_| self.sni)
98            .map(|sni| SubjectName::DnsName(sni.into()))
99            .into_iter()
100            .collect::<Vec<_>>();
101
102        let names = if sni.is_empty() {
103            &self.default_names
104        } else {
105            &sni
106        };
107
108        let cert = self
109            .certs
110            .iter()
111            .find(|cert| cert.is_valid() && names.iter().all(|name| cert.has_subject_name(name)))?;
112
113        if let Some(cert) = self.cache.get(&cert.id()) {
114            Some(cert.clone())
115        } else {
116            let certified = match cert.certified_key() {
117                Ok(certified) => Arc::new(certified),
118                Err(err) => {
119                    error!("failed to load certified key: {}", err);
120                    return None;
121                }
122            };
123            self.cache.insert(cert.id(), certified.clone());
124            Some(certified)
125        }
126    }
127}