Skip to main content

solti_tls/
server.rs

1//! Server-side TLS configuration.
2
3use std::path::PathBuf;
4use std::sync::Arc;
5
6use crate::{PemSource, TlsError};
7
8/// Server-side TLS configuration.
9///
10/// _Construct via [`ServerTlsConfig::builder`]_.
11#[derive(Debug, Clone)]
12pub struct ServerTlsConfig {
13    /// Server certificate chain (leaf first).
14    pub cert: PemSource,
15    /// Server private key (PKCS#8, PKCS#1, or SEC1).
16    pub key: PemSource,
17    /// Trusted CA bundle for verifying client certificates (mTLS).
18    /// `None` = standard TLS (no client cert required).
19    pub client_ca: Option<PemSource>,
20    /// ALPN protocol list, in preference order (e.g. `[b"h2"]` for gRPC).
21    /// Empty = no ALPN negotiation requested.
22    pub alpn: Vec<Vec<u8>>,
23}
24
25impl ServerTlsConfig {
26    /// Start a new builder.
27    pub fn builder() -> ServerTlsConfigBuilder {
28        ServerTlsConfigBuilder::default()
29    }
30
31    /// Build a [`rustls::ServerConfig`] from this configuration.
32    ///
33    /// Reads PEM sources from disk (or memory), parses certs and key, optionally constructs a `WebPkiClientVerifier` for mTLS, and applies ALPN settings.
34    /// All I/O and parse errors surface here.
35    ///
36    /// Auto-installs the `ring` `CryptoProvider` if no provider is set process-wide.
37    pub fn into_rustls_config(self) -> Result<rustls::ServerConfig, TlsError> {
38        crate::ensure_default_provider();
39
40        let cert_bytes = self.cert.read()?;
41        let key_bytes = self.key.read()?;
42
43        let certs = crate::load_certs_from_pem(cert_bytes.as_slice())?;
44        let key = crate::load_key_from_pem(key_bytes.as_slice())?;
45
46        let builder = rustls::ServerConfig::builder();
47        let server_builder = match self.client_ca {
48            Some(ca_src) => {
49                let ca_bytes = ca_src.read()?;
50                let ca_certs = crate::load_certs_from_pem(ca_bytes.as_slice())?;
51                let mut roots = rustls::RootCertStore::empty();
52                for ca in ca_certs {
53                    roots.add(ca)?;
54                }
55                let verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(roots))
56                    .build()
57                    .map_err(|e| TlsError::ClientVerifier(e.to_string()))?;
58                builder.with_client_cert_verifier(verifier)
59            }
60            None => builder.with_no_client_auth(),
61        };
62
63        let mut config = server_builder.with_single_cert(certs, key)?;
64        config.alpn_protocols = self.alpn;
65        Ok(config)
66    }
67}
68
69/// Incremental builder for [`ServerTlsConfig`].
70#[derive(Debug, Default, Clone)]
71pub struct ServerTlsConfigBuilder {
72    cert: Option<PemSource>,
73    key: Option<PemSource>,
74    client_ca: Option<PemSource>,
75    alpn: Vec<Vec<u8>>,
76}
77
78impl ServerTlsConfigBuilder {
79    /// Set the server cert chain from any [`PemSource`].
80    pub fn cert(mut self, src: PemSource) -> Self {
81        self.cert = Some(src);
82        self
83    }
84
85    /// Set the server private key from any [`PemSource`].
86    pub fn key(mut self, src: PemSource) -> Self {
87        self.key = Some(src);
88        self
89    }
90
91    /// Set the ALPN protocol list, in preference order.
92    ///
93    /// Pass `["h2"]` for gRPC-only, `["h2", "http/1.1"]` for axum HTTP.
94    /// Default is empty (no ALPN negotiation).
95    pub fn with_alpn<I, S>(mut self, protocols: I) -> Self
96    where
97        I: IntoIterator<Item = S>,
98        S: Into<Vec<u8>>,
99    {
100        self.alpn = protocols.into_iter().map(Into::into).collect();
101        self
102    }
103
104    /// Convenience: set the server cert chain from a file path.
105    pub fn cert_pem_file(self, path: impl Into<PathBuf>) -> Self {
106        self.cert(PemSource::Path(path.into()))
107    }
108
109    /// Convenience: set the server cert chain from in-memory bytes.
110    pub fn cert_pem_bytes(self, bytes: impl Into<Vec<u8>>) -> Self {
111        self.cert(PemSource::Bytes(bytes.into()))
112    }
113
114    /// Convenience: set the server private key from a file path.
115    pub fn key_pem_file(self, path: impl Into<PathBuf>) -> Self {
116        self.key(PemSource::Path(path.into()))
117    }
118
119    /// Convenience: set the server private key from in-memory bytes.
120    pub fn key_pem_bytes(self, bytes: impl Into<Vec<u8>>) -> Self {
121        self.key(PemSource::Bytes(bytes.into()))
122    }
123
124    /// Convenience: enable mTLS with a CA bundle from a file path.
125    pub fn require_client_ca_pem_file(self, path: impl Into<PathBuf>) -> Self {
126        self.require_client_ca(PemSource::Path(path.into()))
127    }
128
129    /// Convenience: enable mTLS with a CA bundle from in-memory bytes.
130    pub fn require_client_ca_pem_bytes(self, bytes: impl Into<Vec<u8>>) -> Self {
131        self.require_client_ca(PemSource::Bytes(bytes.into()))
132    }
133
134    /// Require client certificates signed by this CA bundle (turns on mTLS).
135    pub fn require_client_ca(mut self, src: PemSource) -> Self {
136        self.client_ca = Some(src);
137        self
138    }
139
140    /// Build.
141    pub fn build(self) -> Result<ServerTlsConfig, TlsError> {
142        let cert = self.cert.ok_or(TlsError::MissingField("cert"))?;
143        let key = self.key.ok_or(TlsError::MissingField("key"))?;
144        Ok(ServerTlsConfig {
145            cert,
146            key,
147            client_ca: self.client_ca,
148            alpn: self.alpn,
149        })
150    }
151}
152
#[cfg(test)]
mod tests {
    // `super::*` already re-exposes the parent's `PemSource` / `TlsError` imports.
    use super::*;

    #[test]
    fn builder_returns_config_when_cert_and_key_provided() {
        let cfg = ServerTlsConfig::builder()
            .cert_pem_bytes(b"--FAKE CERT--".to_vec())
            .key_pem_bytes(b"--FAKE KEY--".to_vec())
            .build()
            .unwrap();
        assert!(matches!(cfg.cert, PemSource::Bytes(_)));
        assert!(matches!(cfg.key, PemSource::Bytes(_)));
    }

    #[test]
    fn builder_errors_when_cert_is_missing() {
        let err = ServerTlsConfig::builder()
            .key_pem_bytes(vec![1])
            .build()
            .unwrap_err();
        assert!(matches!(err, TlsError::MissingField("cert")));
    }

    #[test]
    fn builder_errors_when_key_is_missing() {
        let err = ServerTlsConfig::builder()
            .cert_pem_bytes(vec![1])
            .build()
            .unwrap_err();
        assert!(matches!(err, TlsError::MissingField("key")));
    }

    #[test]
    fn cert_pem_file_creates_path_source() {
        let cfg = ServerTlsConfig::builder()
            .cert_pem_file("/etc/server.crt")
            .key_pem_bytes(vec![1])
            .build()
            .unwrap();
        assert!(matches!(cfg.cert, PemSource::Path(_)));
    }

    #[test]
    fn client_ca_defaults_to_none() {
        let cfg = ServerTlsConfig::builder()
            .cert_pem_bytes(vec![1])
            .key_pem_bytes(vec![2])
            .build()
            .unwrap();
        assert!(cfg.client_ca.is_none());
    }

    #[test]
    fn require_client_ca_pem_bytes_enables_mtls() {
        let cfg = ServerTlsConfig::builder()
            .cert_pem_bytes(vec![1])
            .key_pem_bytes(vec![2])
            .require_client_ca_pem_bytes(b"--FAKE CA--".to_vec())
            .build()
            .unwrap();
        assert!(matches!(cfg.client_ca, Some(PemSource::Bytes(_))));
    }

    #[test]
    fn require_client_ca_pem_file_enables_mtls() {
        let cfg = ServerTlsConfig::builder()
            .cert_pem_bytes(vec![1])
            .key_pem_bytes(vec![2])
            .require_client_ca_pem_file("/etc/ca.crt")
            .build()
            .unwrap();
        assert!(matches!(cfg.client_ca, Some(PemSource::Path(_))));
    }

    #[test]
    fn alpn_defaults_to_empty() {
        let cfg = ServerTlsConfig::builder()
            .cert_pem_bytes(vec![1])
            .key_pem_bytes(vec![2])
            .build()
            .unwrap();
        assert!(cfg.alpn.is_empty());
    }

    #[test]
    fn with_alpn_sets_protocols() {
        let cfg = ServerTlsConfig::builder()
            .cert_pem_bytes(vec![1])
            .key_pem_bytes(vec![2])
            .with_alpn(["h2", "http/1.1"])
            .build()
            .unwrap();
        assert_eq!(cfg.alpn, vec![b"h2".to_vec(), b"http/1.1".to_vec()]);
    }

    #[test]
    fn with_alpn_replaces_previous_list() {
        let cfg = ServerTlsConfig::builder()
            .cert_pem_bytes(vec![1])
            .key_pem_bytes(vec![2])
            .with_alpn(["http/1.1"])
            .with_alpn(["h2"])
            .build()
            .unwrap();
        assert_eq!(cfg.alpn, vec![b"h2".to_vec()]);
    }

    /// Generate a self-signed certificate for `example.com`; returns `(cert_pem, key_pem)`.
    fn rcgen_self_signed() -> (Vec<u8>, Vec<u8>) {
        let b = rcgen::generate_simple_self_signed(vec!["example.com".into()]).unwrap();
        (
            b.cert.pem().into_bytes(),
            b.signing_key.serialize_pem().into_bytes(),
        )
    }

    #[test]
    fn into_rustls_config_succeeds_with_real_cert_and_key() {
        let (cert, key) = rcgen_self_signed();
        let cfg = ServerTlsConfig::builder()
            .cert_pem_bytes(cert)
            .key_pem_bytes(key)
            .build()
            .unwrap();

        let _rustls = cfg.into_rustls_config().unwrap();
    }

    #[test]
    fn into_rustls_config_succeeds_with_mtls_client_ca() {
        let (cert, key) = rcgen_self_signed();
        let (ca, _) = rcgen_self_signed();
        let cfg = ServerTlsConfig::builder()
            .cert_pem_bytes(cert)
            .key_pem_bytes(key)
            .require_client_ca_pem_bytes(ca)
            .build()
            .unwrap();

        let _rustls = cfg.into_rustls_config().unwrap();
    }

    #[test]
    fn into_rustls_config_propagates_alpn_to_rustls() {
        let (cert, key) = rcgen_self_signed();
        let cfg = ServerTlsConfig::builder()
            .cert_pem_bytes(cert)
            .key_pem_bytes(key)
            .with_alpn(["h2"])
            .build()
            .unwrap();

        let rustls = cfg.into_rustls_config().unwrap();
        assert_eq!(rustls.alpn_protocols, vec![b"h2".to_vec()]);
    }

    #[test]
    fn into_rustls_config_preserves_alpn_preference_order() {
        let (cert, key) = rcgen_self_signed();
        let cfg = ServerTlsConfig::builder()
            .cert_pem_bytes(cert)
            .key_pem_bytes(key)
            .with_alpn(["h2", "http/1.1"])
            .build()
            .unwrap();

        let rustls = cfg.into_rustls_config().unwrap();
        assert_eq!(
            rustls.alpn_protocols,
            vec![b"h2".to_vec(), b"http/1.1".to_vec()]
        );
    }

    #[test]
    fn into_rustls_config_errors_on_unparseable_cert_and_key() {
        let cfg = ServerTlsConfig::builder()
            .cert_pem_bytes(b"--FAKE CERT--".to_vec())
            .key_pem_bytes(b"--FAKE KEY--".to_vec())
            .build()
            .unwrap();

        assert!(cfg.into_rustls_config().is_err());
    }

    #[test]
    fn into_rustls_config_errors_when_cert_file_is_missing() {
        let (_, key) = rcgen_self_signed();
        let cfg = ServerTlsConfig::builder()
            .cert_pem_file("/nonexistent/solti_tls/server.crt")
            .key_pem_bytes(key)
            .build()
            .unwrap();

        assert!(cfg.into_rustls_config().is_err());
    }

    #[test]
    fn into_rustls_config_errors_on_empty_client_ca_bundle() {
        // An empty CA bundle must fail loudly rather than silently yield a server
        // that trusts no one (or, worse, skips client auth).
        let (cert, key) = rcgen_self_signed();
        let cfg = ServerTlsConfig::builder()
            .cert_pem_bytes(cert)
            .key_pem_bytes(key)
            .require_client_ca_pem_bytes(Vec::<u8>::new())
            .build()
            .unwrap();

        assert!(cfg.into_rustls_config().is_err());
    }
}