1use std::{error::Error, fs::File, io::BufReader, path::PathBuf, str::FromStr};
9
10use rustls::{
11 pki_types::{
12 pem::{PemObject, SectionKind},
13 CertificateDer, PrivateKeyDer,
14 },
15 server::ServerSessionMemoryCache,
16 ClientConfig, RootCertStore, ServerConfig,
17};
18use rustls_pemfile::certs;
19
20use crate::get_crate_root;
21
22pub fn load_tls_server_config(
28 cert_path: &str,
29 key_path: &str,
30) -> Result<ServerConfig, Box<dyn Error>> {
31 let (cert_chain, key) = load_chain_and_key(cert_path, key_path)?;
33 let mut config = ServerConfig::builder()
34 .with_no_client_auth()
35 .with_single_cert(cert_chain, key)?;
36
37 config.session_storage = ServerSessionMemoryCache::new(256);
40 Ok(config)
41}
42
43pub fn load_tls_client_config() -> Result<ClientConfig, Box<dyn Error>> {
44 let cert_chain = load_default_ca_cert()?;
45 let mut root_store: RootCertStore = RootCertStore::empty();
46 root_store.add_parsable_certificates(cert_chain);
47
48 let config = ClientConfig::builder()
49 .with_root_certificates(root_store)
50 .with_no_client_auth();
51
52 Ok(config)
53}
54
55pub fn load_tls_client_config_cert(
56 cert_path: &str,
57 key_path: &str,
58) -> Result<ClientConfig, Box<dyn Error>> {
59 let cert_chain = load_default_ca_cert()?;
60 let mut root_store: RootCertStore = RootCertStore::empty();
61 root_store.add_parsable_certificates(cert_chain);
62
63 let (cert_chain, key) = load_chain_and_key(cert_path, key_path)?;
65 let config = ClientConfig::builder()
66 .with_root_certificates(root_store)
67 .with_client_auth_cert(cert_chain, key)?;
68
69 Ok(config)
70}
71
72fn load_chain_and_key(
73 cert_path: &str,
74 key_path: &str,
75) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>), Box<dyn Error>> {
76 let cert_file = &mut BufReader::new(File::open(cert_path)?);
78 let cert_chain = certs(cert_file)?
79 .into_iter()
80 .map(|der| CertificateDer::from_pem(SectionKind::Certificate, der).unwrap())
81 .collect::<Vec<_>>();
82
83 let key = PrivateKeyDer::from_pem_file(key_path)?;
85
86 Ok((cert_chain, key))
87}
88
89pub fn ca_cert_path() -> Result<PathBuf, Box<dyn Error>> {
92 let crate_root = get_crate_root().unwrap_or(PathBuf::from_str(".")?);
93 let ca_path = crate_root.join("certs").join("ca.cert.pem");
94 Ok(ca_path)
95}
96
97pub fn load_default_ca_cert() -> Result<Vec<CertificateDer<'static>>, Box<dyn Error>> {
99 let cert_chain = load_ca_cert(
100 ca_cert_path()?
101 .to_str()
102 .ok_or("Failed to get CA certificate path")?,
103 )?;
104 Ok(cert_chain)
105}
106
107fn load_ca_cert(cert_path: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn Error>> {
109 let cert_file = &mut BufReader::new(File::open(cert_path)?);
110 let cert_chain = certs(cert_file)?
111 .into_iter()
112 .map(|der| CertificateDer::from_pem(SectionKind::Certificate, der).unwrap())
113 .collect::<Vec<_>>();
114 Ok(cert_chain)
115}
116
#[cfg(test)]
mod tests {
    use std::path::Path;

    use tempfile::TempDir;

    use super::*;

    /// Installs aws-lc-rs as the process-wide rustls crypto provider.
    ///
    /// The result is discarded, so calling this from several tests is harmless.
    fn install_crypto_provider() {
        let provider = rustls::crypto::aws_lc_rs::default_provider();
        let _ = provider.install_default();
    }

    /// Writes a freshly generated self-signed `localhost` certificate and its private key
    /// (both PEM) into `dir`, returning `(cert_path, key_path)`.
    fn generate_test_certs(dir: &Path) -> (PathBuf, PathBuf) {
        let names = vec!["localhost".to_string()];
        let rcgen::CertifiedKey { cert, key_pair } =
            rcgen::generate_simple_self_signed(names).unwrap();

        let cert_file = dir.join("test.cert.pem");
        let key_file = dir.join("test.key.pem");
        std::fs::write(&cert_file, cert.pem()).unwrap();
        std::fs::write(&key_file, key_pair.serialize_pem()).unwrap();
        (cert_file, key_file)
    }

    /// Borrows a path as `&str`, panicking if it is not valid UTF-8.
    fn as_str(path: &Path) -> &str {
        path.to_str().unwrap()
    }

    #[test]
    fn test_load_tls_server_config() {
        install_crypto_provider();
        let dir = TempDir::new().unwrap();
        let (cert, key) = generate_test_certs(dir.path());

        assert!(load_tls_server_config(as_str(&cert), as_str(&key)).is_ok());
    }

    #[test]
    fn test_load_tls_server_config_invalid_cert() {
        let outcome = load_tls_server_config("/nonexistent/cert.pem", "/nonexistent/key.pem");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_load_tls_server_config_invalid_key() {
        let dir = TempDir::new().unwrap();
        let (cert, _unused_key) = generate_test_certs(dir.path());

        let outcome = load_tls_server_config(as_str(&cert), "/nonexistent/key.pem");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_load_chain_and_key() {
        let dir = TempDir::new().unwrap();
        let (cert, key) = generate_test_certs(dir.path());

        let (chain, _private_key) = load_chain_and_key(as_str(&cert), as_str(&key)).unwrap();
        assert_eq!(chain.len(), 1);
    }

    #[test]
    fn test_load_chain_and_key_invalid_path() {
        assert!(load_chain_and_key("/nonexistent/cert.pem", "/nonexistent/key.pem").is_err());
    }

    #[test]
    fn test_ca_cert_path() {
        let path = ca_cert_path().unwrap();
        // Accept either separator so the check holds on Unix and Windows alike.
        let has_expected_suffix = ["certs/ca.cert.pem", "certs\\ca.cert.pem"]
            .iter()
            .any(|suffix| path.ends_with(suffix));
        assert!(has_expected_suffix);
    }

    #[test]
    fn test_load_ca_cert() {
        let dir = TempDir::new().unwrap();
        let (cert, _unused_key) = generate_test_certs(dir.path());

        let loaded = load_ca_cert(as_str(&cert)).unwrap();
        assert_eq!(loaded.len(), 1);
    }

    #[test]
    fn test_load_ca_cert_invalid_path() {
        assert!(load_ca_cert("/nonexistent/ca.cert.pem").is_err());
    }
}