crabka_client_core/
security.rs1use std::path::PathBuf;
7use std::sync::Arc;
8
9use crabka_security::ListenerProtocol;
10use rustls_pki_types::pem::PemObject;
11use tokio_rustls::TlsConnector;
12
13pub use crate::sasl::SaslCredentials;
14
/// File-based settings from which a rustls client configuration is built.
///
/// All paths point at PEM files; nothing is read until [`TlsConnectorConfig::build`] is
/// called.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TlsConnectorConfig {
    /// PEM file whose certificates become the trust anchors for server verification.
    ///
    /// When `None`, `build` starts from an empty root store (no system roots are added).
    pub trust_roots_pem: Option<PathBuf>,
    /// Host name associated with the TLS peer.
    ///
    /// `build` does not read it; presumably it is used for SNI / certificate name checks
    /// at connect time (confirm at the call site). `ClientSecurity::sasl_handshake_host`
    /// also falls back to it when no explicit SASL host is set.
    pub server_name: String,
    /// Optional mutual-TLS identity as `(certificate_chain_pem, private_key_pem)`.
    ///
    /// `None` means one-way TLS (no client certificate is offered).
    pub client_identity: Option<(PathBuf, PathBuf)>,
}
35
36impl TlsConnectorConfig {
37 pub fn build(&self) -> Result<Arc<rustls::ClientConfig>, String> {
46 let mut roots = rustls::RootCertStore::empty();
47 if let Some(path) = &self.trust_roots_pem {
48 for cert in rustls::pki_types::CertificateDer::pem_file_iter(path)
49 .map_err(|e| format!("trust roots load {}: {e}", path.display()))?
50 {
51 let cert = cert.map_err(|e| format!("trust roots parse: {e}"))?;
52 roots
53 .add(cert)
54 .map_err(|e| format!("trust roots add: {e}"))?;
55 }
56 }
57 let cfg = if let Some((cert_pem, key_pem)) = &self.client_identity {
58 let certs: Vec<rustls::pki_types::CertificateDer<'static>> =
59 rustls::pki_types::CertificateDer::pem_file_iter(cert_pem)
60 .map_err(|e| format!("client cert load {}: {e}", cert_pem.display()))?
61 .collect::<Result<Vec<_>, _>>()
62 .map_err(|e| format!("client cert parse: {e}"))?;
63 let key = rustls::pki_types::PrivateKeyDer::from_pem_file(key_pem)
64 .map_err(|e| format!("client key load {}: {e}", key_pem.display()))?;
65 rustls::ClientConfig::builder()
66 .with_root_certificates(roots)
67 .with_client_auth_cert(certs, key)
68 .map_err(|e| format!("client auth cert: {e}"))?
69 } else {
70 rustls::ClientConfig::builder()
71 .with_root_certificates(roots)
72 .with_no_client_auth()
73 };
74 Ok(Arc::new(cfg))
75 }
76
77 pub fn connector(&self) -> Result<TlsConnector, String> {
82 Ok(TlsConnector::from(self.build()?))
83 }
84}
85
86#[derive(Debug, Clone)]
90pub struct ClientSecurity {
91 pub protocol: ListenerProtocol,
92 pub tls: Option<TlsConnectorConfig>,
93 pub sasl: Option<SaslCredentials>,
94 pub sasl_host: Option<String>,
104}
105
106impl ClientSecurity {
107 #[must_use]
114 pub fn sasl_handshake_host<'a>(&'a self, target_host: Option<&'a str>) -> &'a str {
115 if let Some(h) = self.sasl_host.as_deref() {
116 h
117 } else if let Some(tls) = self.tls.as_ref() {
118 tls.server_name.as_str()
119 } else if let Some(h) = target_host {
120 h
121 } else {
122 "localhost"
123 }
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use assert2::assert;
    use crabka_security::ListenerProtocol;

    /// Installs ring as the process-wide rustls provider. The result is ignored on purpose:
    /// a second install (from another test) just reports "already installed".
    fn install_ring_provider() {
        let _ = rustls::crypto::ring::default_provider().install_default();
    }

    /// TLS config for host "broker" with no custom trust roots.
    fn broker_tls(
        client_identity: Option<(std::path::PathBuf, std::path::PathBuf)>,
    ) -> TlsConnectorConfig {
        TlsConnectorConfig {
            trust_roots_pem: None,
            server_name: "broker".into(),
            client_identity,
        }
    }

    /// `ClientSecurity` without SASL credentials.
    fn security_without_creds(
        protocol: ListenerProtocol,
        tls: Option<TlsConnectorConfig>,
        sasl_host: Option<&str>,
    ) -> ClientSecurity {
        ClientSecurity {
            protocol,
            tls,
            sasl: None,
            sasl_host: sasl_host.map(str::to_owned),
        }
    }

    #[test]
    fn plaintext_security_has_no_tls_or_sasl() {
        let plain = security_without_creds(ListenerProtocol::Plaintext, None, None);
        assert!(!plain.protocol.requires_tls());
        assert!(!plain.protocol.requires_sasl());
    }

    #[test]
    fn sasl_plaintext_carries_creds() {
        let with_creds = ClientSecurity {
            sasl: Some(SaslCredentials::Plain {
                username: "u".into(),
                password: "p".into(),
            }),
            ..security_without_creds(ListenerProtocol::SaslPlaintext, None, None)
        };
        assert!(with_creds.protocol.requires_sasl());
        assert!(matches!(with_creds.sasl, Some(SaslCredentials::Plain { .. })));
    }

    #[test]
    fn sasl_handshake_host_prefers_explicit_field() {
        let explicit = security_without_creds(
            ListenerProtocol::SaslPlaintext,
            None,
            Some("kdc-broker.example.com"),
        );
        assert!(explicit.sasl_handshake_host(Some("10.0.0.5")) == "kdc-broker.example.com");
    }

    #[test]
    fn sasl_handshake_host_falls_back_to_tls_then_target_then_localhost() {
        let tls_only = TlsConnectorConfig {
            server_name: "tls-host".into(),
            ..broker_tls(None)
        };
        let with_tls =
            security_without_creds(ListenerProtocol::SaslSsl, Some(tls_only), None);
        assert!(with_tls.sasl_handshake_host(Some("10.0.0.5")) == "tls-host");

        let no_tls = security_without_creds(ListenerProtocol::SaslPlaintext, None, None);
        assert!(no_tls.sasl_handshake_host(Some("10.0.0.5")) == "10.0.0.5");
        assert!(no_tls.sasl_handshake_host(None) == "localhost");
    }

    #[test]
    fn tls_connector_config_builds_client_config() {
        install_ring_provider();
        broker_tls(None)
            .build()
            .expect("client config builds with empty roots");
    }

    #[test]
    fn tls_connector_config_client_identity_none_builds_and_bogus_path_errors() {
        install_ring_provider();

        broker_tls(None)
            .build()
            .expect("one-way TLS builds with client_identity=None");

        let bogus = broker_tls(Some((
            "/nonexistent/cert.pem".into(),
            "/nonexistent/key.pem".into(),
        )));
        assert!(
            bogus.build().is_err(),
            "bogus client-identity path returns Err"
        );
    }
}