1use std::collections::HashMap;
2use std::fs::File;
3use std::future::Future;
4use std::io::BufReader;
5use std::net::SocketAddr;
6use std::path::Path;
7use std::pin::Pin;
8use std::sync::{Arc, Once};
9use std::time::{Duration, Instant};
10
11use anyhow::{Context, Result, bail};
12use bytes::{BufMut, Bytes, BytesMut};
13use kafka_protocol::error::ParseResponseErrorCode;
14use kafka_protocol::messages::{
15 ApiVersionsRequest, RequestHeader, ResponseHeader, SaslAuthenticateRequest,
16 SaslHandshakeRequest,
17};
18use kafka_protocol::protocol::{
19 Decodable, HeaderVersion, Message, Request, StrBytes, VersionRange,
20 encode_request_header_into_buffer,
21};
22use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
23use rustls::{ClientConfig as RustlsClientConfig, RootCertStore};
24use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
25use tokio::net::TcpStream;
26use tokio_rustls::TlsConnector;
27use tokio_rustls::client::TlsStream;
28use tracing::{Instrument, debug, trace, trace_span};
29
30use super::scram::ScramClient;
31use super::select_api_version;
32use crate::config::{SaslConfig, SaslMechanism, SecurityProtocol, TlsConfig};
33use crate::constants::{API_VERSIONS_FALLBACK_VERSION, API_VERSIONS_PROBE_VERSION};
34use crate::telemetry;
35
36pub async fn connect_to_any_bootstrap(
37 servers: &[String],
38 client_id: &str,
39 timeout: Duration,
40 security_protocol: SecurityProtocol,
41 tls: &TlsConfig,
42 sasl: &SaslConfig,
43 tcp_connector: &Arc<dyn TcpConnector>,
44) -> Result<BrokerConnection> {
45 if servers.is_empty() {
46 bail!("no bootstrap servers configured");
47 }
48
49 let mut last_error: Option<anyhow::Error> = None;
50 for server in servers {
51 match BrokerConnection::connect_with_transport(
52 server,
53 client_id,
54 timeout,
55 security_protocol,
56 tls,
57 sasl,
58 tcp_connector,
59 )
60 .await
61 {
62 Ok(conn) => return Ok(conn),
63 Err(e) => {
64 debug!(server = %server, error = %e, "bootstrap connection failed, trying next server");
65 last_error = Some(e);
66 }
67 }
68 }
69 Err(last_error.unwrap())
70}
71
72type ConnectFuture<'a> = Pin<Box<dyn Future<Output = Result<ConnectedTcpStream>> + Send + 'a>>;
73
74pub trait BrokerIo: AsyncRead + AsyncWrite + Unpin + Send {}
75
76impl<T> BrokerIo for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
77
78pub enum ConnectedTcpStream {
79 Tokio(TcpStream),
80 Custom(Box<dyn BrokerIo>),
81}
82
83impl ConnectedTcpStream {
84 fn set_nodelay(&self, nodelay: bool) -> Result<()> {
85 match self {
86 Self::Tokio(stream) => stream.set_nodelay(nodelay)?,
87 Self::Custom(_) => {}
88 }
89 Ok(())
90 }
91}
92
93pub trait TcpConnector: std::fmt::Debug + Send + Sync {
94 fn connect<'a>(&'a self, address: &'a str, timeout: Duration) -> ConnectFuture<'a>;
95}
96
97#[derive(Debug, Default)]
98pub struct TokioTcpConnector;
99
100impl TcpConnector for TokioTcpConnector {
101 fn connect<'a>(&'a self, address: &'a str, timeout: Duration) -> ConnectFuture<'a> {
102 Box::pin(async move {
103 let tcp_stream = tokio::time::timeout(timeout, TcpStream::connect(address))
104 .await
105 .with_context(|| format!("timed out connecting to {address}"))?
106 .with_context(|| format!("failed to connect to {address}"))?;
107 Ok(ConnectedTcpStream::Tokio(tcp_stream))
108 })
109 }
110}
111
112pub struct BrokerConnection {
113 stream: BrokerStream,
114 next_correlation_id: i32,
115 api_versions: HashMap<i16, VersionRange>,
116 finalized_features: HashMap<String, i16>,
117}
118
119enum BrokerStream {
120 Plain(Box<dyn BrokerIo>),
121 Tls(Box<TlsStream<TcpStream>>),
122}
123
124impl BrokerStream {
125 async fn write_all(&mut self, frame: &[u8]) -> Result<()> {
126 match self {
127 Self::Plain(stream) => stream.write_all(frame).await?,
128 Self::Tls(stream) => stream.write_all(frame).await?,
129 }
130 Ok(())
131 }
132
133 async fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
134 match self {
135 Self::Plain(stream) => {
136 stream.read_exact(buf).await?;
137 }
138 Self::Tls(stream) => {
139 stream.read_exact(buf).await?;
140 }
141 };
142 Ok(())
143 }
144}
145
146impl BrokerConnection {
147 pub async fn connect_with_transport(
148 address: &str,
149 client_id: &str,
150 timeout: Duration,
151 security_protocol: SecurityProtocol,
152 tls: &TlsConfig,
153 sasl: &SaslConfig,
154 tcp_connector: &Arc<dyn TcpConnector>,
155 ) -> Result<Self> {
156 let started = Instant::now();
157 let result = async {
158 debug!(?security_protocol, "connecting to broker");
159 let stream =
160 connect_stream(address, timeout, security_protocol, tls, tcp_connector).await?;
161 let mut connection = Self {
162 stream,
163 next_correlation_id: 1,
164 api_versions: HashMap::new(),
165 finalized_features: HashMap::new(),
166 };
167 if security_protocol.uses_sasl() {
168 connection.authenticate_sasl(client_id, sasl).await?;
169 }
170 connection.negotiate_versions(client_id).await?;
171 debug!(
172 api_keys = connection.api_versions.len(),
173 finalized_features = connection.finalized_features.len(),
174 ?security_protocol,
175 "connected to broker"
176 );
177 Ok(connection)
178 }
179 .instrument(tracing::debug_span!(
180 "broker_connect",
181 %address,
182 %client_id,
183 timeout_ms = timeout.as_millis()
184 ))
185 .await;
186 telemetry::record_broker_connection(
187 client_id,
188 address,
189 &format!("{security_protocol:?}"),
190 started.elapsed(),
191 result.is_ok(),
192 );
193 result
194 }
195
196 async fn authenticate_sasl(&mut self, client_id: &str, sasl: &SaslConfig) -> Result<()> {
197 let response = self
198 .send_request::<ApiVersionsRequest>(
199 client_id,
200 API_VERSIONS_FALLBACK_VERSION,
201 &ApiVersionsRequest::default(),
202 )
203 .await
204 .context("SASL ApiVersions probe failed")?;
205 if let Some(error) = response.error_code.err() {
206 bail!("SASL ApiVersions probe failed: {error}");
207 }
208
209 let api_versions = response
210 .api_keys
211 .into_iter()
212 .map(|api| {
213 (
214 api.api_key,
215 VersionRange {
216 min: api.min_version,
217 max: api.max_version,
218 },
219 )
220 })
221 .collect::<HashMap<_, _>>();
222
223 let handshake_version = api_versions
224 .get(&SaslHandshakeRequest::KEY)
225 .copied()
226 .map(|range| {
227 select_api_version(
228 SaslHandshakeRequest::KEY,
229 range,
230 SaslHandshakeRequest::VERSIONS,
231 SaslHandshakeRequest::VERSIONS.max,
232 )
233 })
234 .transpose()?
235 .unwrap_or(0);
236 let authenticate_version = api_versions
237 .get(&SaslAuthenticateRequest::KEY)
238 .copied()
239 .map(|range| {
240 select_api_version(
241 SaslAuthenticateRequest::KEY,
242 range,
243 SaslAuthenticateRequest::VERSIONS,
244 SaslAuthenticateRequest::VERSIONS.max,
245 )
246 })
247 .transpose()?;
248
249 let mechanism = sasl.mechanism.as_str();
250 let handshake =
251 SaslHandshakeRequest::default().with_mechanism(StrBytes::from_static_str(mechanism));
252 let response = self
253 .send_request::<SaslHandshakeRequest>(client_id, handshake_version, &handshake)
254 .await
255 .context("SASL handshake request failed")?;
256 if let Some(error) = response.error_code.err() {
257 let enabled = response
258 .mechanisms
259 .iter()
260 .map(ToString::to_string)
261 .collect::<Vec<_>>()
262 .join(", ");
263 bail!(
264 "SASL handshake failed for mechanism {mechanism}: {error}; enabled mechanisms: [{enabled}]"
265 );
266 }
267
268 match sasl.mechanism {
269 SaslMechanism::Plain => {
270 let token = build_plain_sasl_token(sasl)?;
271 if let Some(version) = authenticate_version {
272 self.send_sasl_authenticate(client_id, version, mechanism, token)
273 .await?;
274 } else {
275 write_raw_sasl_token(&mut self.stream, &token).await?;
276 }
277 }
278 SaslMechanism::ScramSha256 | SaslMechanism::ScramSha512 => {
279 self.authenticate_scram(client_id, sasl, authenticate_version)
280 .await?;
281 }
282 }
283
284 debug!(mechanism, "completed SASL authentication");
285 Ok(())
286 }
287
288 async fn authenticate_scram(
289 &mut self,
290 client_id: &str,
291 sasl: &SaslConfig,
292 authenticate_version: Option<i16>,
293 ) -> Result<()> {
294 let username = sasl
295 .username
296 .as_ref()
297 .context("SASL/SCRAM requires a username")?
298 .clone();
299 let password = sasl
300 .password
301 .as_ref()
302 .context("SASL/SCRAM requires a password")?
303 .clone();
304 let mechanism = sasl.mechanism.as_str();
305 let mut scram = ScramClient::new(sasl.mechanism, username, password)?;
306 let client_first = scram.client_first_message();
307
308 let server_first = if let Some(version) = authenticate_version {
309 self.send_sasl_authenticate(client_id, version, mechanism, client_first)
310 .await?
311 } else {
312 write_raw_sasl_token(&mut self.stream, &client_first).await?;
313 read_frame(&mut self.stream).await?
314 };
315 let client_final = scram.handle_server_first_message(&server_first)?;
316
317 let server_final = if let Some(version) = authenticate_version {
318 self.send_sasl_authenticate(client_id, version, mechanism, client_final)
319 .await?
320 } else {
321 write_raw_sasl_token(&mut self.stream, &client_final).await?;
322 read_frame(&mut self.stream).await?
323 };
324 scram.handle_server_final_message(&server_final)?;
325 Ok(())
326 }
327
328 async fn send_sasl_authenticate(
329 &mut self,
330 client_id: &str,
331 version: i16,
332 mechanism: &str,
333 token: Vec<u8>,
334 ) -> Result<Vec<u8>> {
335 let request = SaslAuthenticateRequest::default().with_auth_bytes(Bytes::from(token));
336 let response = self
337 .send_request::<SaslAuthenticateRequest>(client_id, version, &request)
338 .await
339 .context("SASL authenticate request failed")?;
340 if let Some(error) = response.error_code.err() {
341 let message = response
342 .error_message
343 .as_ref()
344 .map(ToString::to_string)
345 .filter(|message| !message.is_empty())
346 .unwrap_or_else(|| error.to_string());
347 bail!("SASL authentication failed for mechanism {mechanism}: {message}");
348 }
349 Ok(response.auth_bytes.to_vec())
350 }
351
352 pub fn version_with_cap<Req>(&self, cap: i16) -> Result<i16>
353 where
354 Req: Request,
355 {
356 let broker_range = self
357 .api_versions
358 .get(&Req::KEY)
359 .copied()
360 .with_context(|| format!("broker did not advertise API key {}", Req::KEY))?;
361 select_api_version(Req::KEY, broker_range, Req::VERSIONS, cap)
362 }
363
364 pub fn finalized_feature_level(&self, feature: &str) -> Option<i16> {
365 self.finalized_features.get(feature).copied()
366 }
367
368 pub fn finalized_feature_levels(&self) -> Vec<(String, i16)> {
369 let mut features = self
370 .finalized_features
371 .iter()
372 .map(|(name, level)| (name.clone(), *level))
373 .collect::<Vec<_>>();
374 features.sort_by(|left, right| left.0.cmp(&right.0));
375 features
376 }
377
378 async fn negotiate_versions(&mut self, client_id: &str) -> Result<()> {
379 let modern_request = ApiVersionsRequest::default()
380 .with_client_software_name(StrBytes::from_static_str("kafkit-client"))
381 .with_client_software_version(StrBytes::from_static_str("0.2.0"));
382
383 let response = match self
384 .send_request::<ApiVersionsRequest>(
385 client_id,
386 API_VERSIONS_PROBE_VERSION,
387 &modern_request,
388 )
389 .await
390 {
391 Ok(response) => response,
392 Err(error) => {
393 debug!(
394 error = %error,
395 "modern ApiVersions probe failed, retrying with fallback request"
396 );
397 self.send_request::<ApiVersionsRequest>(
398 client_id,
399 API_VERSIONS_FALLBACK_VERSION,
400 &ApiVersionsRequest::default(),
401 )
402 .await?
403 }
404 };
405
406 if let Some(error) = response.error_code.err() {
407 bail!("ApiVersions failed: {error}");
408 }
409
410 self.api_versions = response
411 .api_keys
412 .into_iter()
413 .map(|api| {
414 (
415 api.api_key,
416 VersionRange {
417 min: api.min_version,
418 max: api.max_version,
419 },
420 )
421 })
422 .collect();
423 self.finalized_features = response
424 .finalized_features
425 .into_iter()
426 .map(|feature| (feature.name.to_string(), feature.max_version_level))
427 .collect();
428
429 trace!(
430 api_keys = self.api_versions.len(),
431 finalized_features = self.finalized_features.len(),
432 "negotiated broker ApiVersions"
433 );
434 Ok(())
435 }
436
437 pub async fn send_request<Req>(
438 &mut self,
439 client_id: &str,
440 version: i16,
441 request: &Req,
442 ) -> Result<Req::Response>
443 where
444 Req: Request,
445 {
446 let correlation_id = self.next_correlation_id;
447 self.next_correlation_id += 1;
448 let started = Instant::now();
449 let mut request_bytes = 0usize;
450 let mut response_bytes = 0usize;
451 let span = trace_span!(
452 "kafka_request",
453 request = std::any::type_name::<Req>(),
454 api_key = Req::KEY,
455 api_version = version,
456 correlation_id,
457 %client_id
458 );
459
460 let result = async {
461 let mut body = BytesMut::new();
462 let header = RequestHeader::default()
463 .with_request_api_key(Req::KEY)
464 .with_request_api_version(version)
465 .with_correlation_id(correlation_id)
466 .with_client_id(Some(StrBytes::from_string(client_id.to_owned())));
467 encode_request_header_into_buffer(&mut body, &header)?;
468 request.encode(&mut body, version)?;
469 request_bytes = body.len();
470
471 trace!(request_bytes = body.len(), "encoded Kafka request");
472
473 let mut frame = BytesMut::with_capacity(body.len() + 4);
474 frame.put_i32(i32::try_from(body.len()).context("request frame is too large")?);
475 frame.extend_from_slice(&body);
476
477 self.stream.write_all(&frame).await?;
478 trace!(frame_bytes = frame.len(), "wrote Kafka request frame");
479
480 let response_frame = read_frame(&mut self.stream).await?;
481 response_bytes = response_frame.len();
482 trace!(
483 response_bytes = response_frame.len(),
484 "received Kafka response frame"
485 );
486 let mut response_body = Bytes::from(response_frame);
487 let header_version = Req::Response::header_version(version);
488 let response_header = ResponseHeader::decode(&mut response_body, header_version)?;
489 if response_header.correlation_id != correlation_id {
490 bail!(
491 "response correlation mismatch: expected {}, got {}",
492 correlation_id,
493 response_header.correlation_id
494 );
495 }
496
497 let response = Req::Response::decode(&mut response_body, version)?;
498 trace!("completed Kafka request");
499 Ok(response)
500 }
501 .instrument(span)
502 .await;
503 telemetry::record_kafka_request::<Req>(
504 client_id,
505 version,
506 request_bytes,
507 response_bytes,
508 started.elapsed(),
509 result.is_ok(),
510 true,
511 );
512 result
513 }
514
515 pub async fn send_request_without_response<Req>(
516 &mut self,
517 client_id: &str,
518 version: i16,
519 request: &Req,
520 ) -> Result<()>
521 where
522 Req: Request,
523 {
524 let correlation_id = self.next_correlation_id;
525 self.next_correlation_id += 1;
526 let started = Instant::now();
527 let mut request_bytes = 0usize;
528 let span = trace_span!(
529 "kafka_request",
530 request = std::any::type_name::<Req>(),
531 api_key = Req::KEY,
532 api_version = version,
533 correlation_id,
534 expects_response = false,
535 %client_id
536 );
537
538 let result = async {
539 let mut body = BytesMut::new();
540 let header = RequestHeader::default()
541 .with_request_api_key(Req::KEY)
542 .with_request_api_version(version)
543 .with_correlation_id(correlation_id)
544 .with_client_id(Some(StrBytes::from_string(client_id.to_owned())));
545 encode_request_header_into_buffer(&mut body, &header)?;
546 request.encode(&mut body, version)?;
547 request_bytes = body.len();
548
549 trace!(request_bytes = body.len(), "encoded Kafka request");
550
551 let mut frame = BytesMut::with_capacity(body.len() + 4);
552 frame.put_i32(i32::try_from(body.len()).context("request frame is too large")?);
553 frame.extend_from_slice(&body);
554
555 self.stream.write_all(&frame).await?;
556 trace!(frame_bytes = frame.len(), "wrote Kafka request frame");
557 Ok(())
558 }
559 .instrument(span)
560 .await;
561 telemetry::record_kafka_request::<Req>(
562 client_id,
563 version,
564 request_bytes,
565 0,
566 started.elapsed(),
567 result.is_ok(),
568 false,
569 );
570 result
571 }
572}
573
574async fn connect_stream(
575 address: &str,
576 timeout: Duration,
577 security_protocol: SecurityProtocol,
578 tls: &TlsConfig,
579 tcp_connector: &Arc<dyn TcpConnector>,
580) -> Result<BrokerStream> {
581 let tcp_stream = tcp_connector.connect(address, timeout).await?;
582 tcp_stream
583 .set_nodelay(true)
584 .with_context(|| format!("failed to enable TCP_NODELAY for {address}"))?;
585
586 if security_protocol.uses_tls() {
587 let ConnectedTcpStream::Tokio(tcp_stream) = tcp_stream else {
588 bail!("custom TCP connectors do not support TLS broker connections");
589 };
590 let tls_config = build_tls_client_config(tls)?;
591 let connector = TlsConnector::from(tls_config);
592 let server_name = server_name_for_tls(address, tls)?;
593 let stream = tokio::time::timeout(timeout, connector.connect(server_name, tcp_stream))
594 .await
595 .with_context(|| format!("timed out negotiating TLS with {address}"))?
596 .with_context(|| format!("failed TLS handshake with {address}"))?;
597 Ok(BrokerStream::Tls(Box::new(stream)))
598 } else {
599 match tcp_stream {
600 ConnectedTcpStream::Tokio(stream) => Ok(BrokerStream::Plain(Box::new(stream))),
601 ConnectedTcpStream::Custom(stream) => Ok(BrokerStream::Plain(stream)),
602 }
603 }
604}
605
606fn build_plain_sasl_token(sasl: &SaslConfig) -> Result<Vec<u8>> {
607 let username = sasl
608 .username
609 .as_deref()
610 .context("SASL/PLAIN requires a username")?;
611 let password = sasl
612 .password
613 .as_deref()
614 .context("SASL/PLAIN requires a password")?;
615 let authorization_id = sasl.authorization_id.as_deref().unwrap_or_default();
616
617 let mut token =
618 Vec::with_capacity(authorization_id.len() + username.len() + password.len() + 2);
619 token.extend_from_slice(authorization_id.as_bytes());
620 token.push(0);
621 token.extend_from_slice(username.as_bytes());
622 token.push(0);
623 token.extend_from_slice(password.as_bytes());
624 Ok(token)
625}
626
627async fn write_raw_sasl_token(stream: &mut BrokerStream, token: &[u8]) -> Result<()> {
628 let mut frame = BytesMut::with_capacity(token.len() + 4);
629 frame.put_i32(i32::try_from(token.len()).context("SASL token frame is too large")?);
630 frame.extend_from_slice(token);
631 stream.write_all(&frame).await
632}
633
634fn build_tls_client_config(tls: &TlsConfig) -> Result<Arc<RustlsClientConfig>> {
635 ensure_rustls_crypto_provider();
636
637 let mut root_store = RootCertStore::empty();
638 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
639
640 if let Some(ca_cert_path) = tls.ca_cert_path.as_deref() {
641 for cert in load_certificates(ca_cert_path)? {
642 root_store.add(cert)?;
643 }
644 }
645
646 let builder = RustlsClientConfig::builder().with_root_certificates(root_store);
647 let config = match (
648 tls.client_cert_path.as_deref(),
649 tls.client_key_path.as_deref(),
650 ) {
651 (Some(client_cert_path), Some(client_key_path)) => builder.with_client_auth_cert(
652 load_certificates(client_cert_path)?,
653 load_private_key(client_key_path)?,
654 )?,
655 (None, None) => builder.with_no_client_auth(),
656 _ => bail!("TLS client auth requires both client_cert_path and client_key_path"),
657 };
658
659 Ok(Arc::new(config))
660}
661
662fn ensure_rustls_crypto_provider() {
663 static INSTALL_PROVIDER: Once = Once::new();
664
665 INSTALL_PROVIDER.call_once(|| {
666 if rustls::crypto::CryptoProvider::get_default().is_none() {
667 let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
668 }
669 });
670}
671
672fn load_certificates(path: &Path) -> Result<Vec<CertificateDer<'static>>> {
673 let file = File::open(path)
674 .with_context(|| format!("failed to open TLS certificate file '{}'", path.display()))?;
675 let mut reader = BufReader::new(file);
676 let certs = rustls_pemfile::certs(&mut reader)
677 .collect::<std::result::Result<Vec<_>, _>>()
678 .with_context(|| format!("failed to parse TLS certificate PEM '{}'", path.display()))?;
679 if certs.is_empty() {
680 bail!(
681 "TLS certificate file '{}' did not contain any PEM certificates",
682 path.display()
683 );
684 }
685 Ok(certs)
686}
687
688fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>> {
689 let file = File::open(path)
690 .with_context(|| format!("failed to open TLS private key file '{}'", path.display()))?;
691 let mut reader = BufReader::new(file);
692 rustls_pemfile::private_key(&mut reader)
693 .with_context(|| format!("failed to parse TLS private key PEM '{}'", path.display()))?
694 .with_context(|| {
695 format!(
696 "TLS private key file '{}' did not contain a PEM key",
697 path.display()
698 )
699 })
700}
701
702fn server_name_for_tls(address: &str, tls: &TlsConfig) -> Result<ServerName<'static>> {
703 if let Some(server_name) = tls.server_name.as_ref() {
704 return ServerName::try_from(server_name.clone())
705 .with_context(|| format!("invalid TLS server name '{}'", server_name));
706 }
707
708 if let Ok(socket_addr) = address.parse::<SocketAddr>() {
709 return Ok(ServerName::IpAddress(socket_addr.ip().into()));
710 }
711
712 let host = if let Some(stripped) = address.strip_prefix('[') {
713 stripped
714 .split(']')
715 .next()
716 .context("invalid bracketed broker address")?
717 .to_owned()
718 } else {
719 address
720 .rsplit_once(':')
721 .map(|(host, _)| host.to_owned())
722 .unwrap_or_else(|| address.to_owned())
723 };
724
725 ServerName::try_from(host.clone()).with_context(|| {
726 format!("could not derive a valid TLS server name from broker address '{address}'")
727 })
728}
729
730async fn read_frame(stream: &mut BrokerStream) -> Result<Vec<u8>> {
731 let mut header = [0_u8; 4];
732 stream.read_exact(&mut header).await?;
733 let frame_len = i32::from_be_bytes(header);
734 if frame_len < 0 {
735 bail!("broker returned a negative frame length: {frame_len}");
736 }
737
738 let mut payload = vec![0_u8; usize::try_from(frame_len)?];
739 stream.read_exact(&mut payload).await?;
740 Ok(payload)
741}
742
// Unit tests for TLS server-name derivation, SASL/PLAIN token construction,
// PEM loading, TLS client configuration, and the custom-connector TLS guard.
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::sync::Arc;

    use tokio::io;

    // Without an override, the host part of `host:port` becomes the server name.
    #[test]
    fn tls_server_name_defaults_to_host() {
        let server_name =
            server_name_for_tls("broker.example.com:9093", &TlsConfig::default()).unwrap();
        assert_eq!(server_name.to_str(), "broker.example.com");
    }

    // An explicit `server_name` wins even when the address is an IP.
    #[test]
    fn tls_server_name_respects_explicit_override() {
        let tls = TlsConfig::new().with_server_name("cluster.internal");
        let server_name = server_name_for_tls("127.0.0.1:9093", &tls).unwrap();
        assert_eq!(server_name.to_str(), "cluster.internal");
    }

    // IPv4 and bracketed IPv6 socket addresses map to IP server names.
    #[test]
    fn tls_server_name_handles_ip_and_bracketed_ipv6() {
        let server_name = server_name_for_tls("127.0.0.1:9093", &TlsConfig::default()).unwrap();
        assert_eq!(server_name.to_str(), "127.0.0.1");

        let server_name = server_name_for_tls("[::1]:9093", &TlsConfig::default()).unwrap();
        assert_eq!(server_name.to_str(), "::1");
    }

    // Invalid overrides and empty addresses are rejected.
    #[test]
    fn tls_server_name_rejects_invalid_override_and_empty_address() {
        let tls = TlsConfig::new().with_server_name("not a dns name");
        assert!(server_name_for_tls("127.0.0.1:9093", &tls).is_err());
        assert!(server_name_for_tls("", &TlsConfig::default()).is_err());
    }

    // Missing credentials fail; the token is `authzid NUL user NUL password`.
    #[test]
    fn plain_sasl_token_requires_credentials_and_uses_authorization_id() {
        assert!(build_plain_sasl_token(&SaslConfig::default()).is_err());
        assert!(
            build_plain_sasl_token(&SaslConfig::plain("user", "pw").with_authorization_id("authz"))
                .unwrap()
                == b"authz\0user\0pw"
        );
    }

    // Non-PEM content and missing files must be reported as errors.
    #[test]
    fn tls_file_loaders_reject_missing_empty_and_invalid_pem_files() {
        // Per-process directory name so concurrent test runs do not collide.
        let dir =
            std::env::temp_dir().join(format!("kafkit-client-tls-test-{}", std::process::id()));
        fs::create_dir_all(&dir).unwrap();
        let cert_path = dir.join("cert.pem");
        let key_path = dir.join("key.pem");
        fs::write(&cert_path, b"not a certificate").unwrap();
        fs::write(&key_path, b"not a key").unwrap();

        assert!(load_certificates(&cert_path).is_err());
        assert!(load_private_key(&key_path).is_err());
        assert!(load_certificates(&dir.join("missing.pem")).is_err());

        // Best-effort cleanup; a failure here must not fail the test.
        let _ = fs::remove_dir_all(dir);
    }

    // A valid cert/key pair builds a config; a half-configured client auth fails.
    #[test]
    fn tls_client_config_loads_custom_ca_and_client_auth_pem_files() {
        let dir = std::env::temp_dir().join(format!(
            "kafkit-client-tls-valid-pem-test-{}",
            std::process::id()
        ));
        fs::create_dir_all(&dir).unwrap();
        let cert_path = dir.join("cert.pem");
        let key_path = dir.join("key.pem");
        fs::write(&cert_path, TEST_CERT_PEM).unwrap();
        fs::write(&key_path, TEST_KEY_PEM).unwrap();

        // The same certificate serves as both the custom CA and the client cert.
        let tls = TlsConfig::new()
            .with_ca_cert_path(&cert_path)
            .with_client_cert_path(&cert_path)
            .with_client_key_path(&key_path)
            .with_server_name("cluster.internal");

        build_tls_client_config(&tls).expect("valid custom CA and client auth config");
        assert!(
            build_tls_client_config(&TlsConfig::new().with_client_cert_path(&cert_path)).is_err()
        );
        assert!(
            build_tls_client_config(&TlsConfig::new().with_client_key_path(&key_path)).is_err()
        );

        let _ = fs::remove_dir_all(dir);
    }

    // TLS over a custom connector stream must fail fast with a clear message.
    #[tokio::test]
    async fn tls_rejects_custom_tcp_connectors_before_handshake() {
        let connector: Arc<dyn TcpConnector> = Arc::new(CustomOnlyConnector);
        let error = match connect_stream(
            "broker.example.com:9093",
            Duration::from_secs(1),
            SecurityProtocol::Ssl,
            &TlsConfig::default(),
            &connector,
        )
        .await
        {
            Ok(_) => panic!("TLS over custom stream should be rejected"),
            Err(error) => error,
        };

        assert!(
            error
                .to_string()
                .contains("custom TCP connectors do not support TLS broker connections")
        );
    }

    // Test connector that always yields an in-memory duplex stream as a
    // `ConnectedTcpStream::Custom`, ignoring the address and timeout.
    #[derive(Debug)]
    struct CustomOnlyConnector;

    impl TcpConnector for CustomOnlyConnector {
        fn connect<'a>(&'a self, _address: &'a str, _timeout: Duration) -> ConnectFuture<'a> {
            Box::pin(async move {
                let (stream, _peer) = io::duplex(64);
                Ok(ConnectedTcpStream::Custom(Box::new(stream)))
            })
        }
    }

    // Test-only PEM fixtures: a certificate and its PKCS#8 private key (same RSA
    // modulus). The base64 lines must stay byte-exact and unindented inside the
    // literals. NOTE(review): the certificate's validity window appears to be a
    // single day (2026-05-05 to 2026-05-06); the test above only builds a config
    // and does not verify validity, so this does not affect it.
    const TEST_CERT_PEM: &[u8] = b"-----BEGIN CERTIFICATE-----
MIIDFzCCAf+gAwIBAgIUU1sGIzptOpATf4S4bW3ljAEYj94wDQYJKoZIhvcNAQEL
BQAwGzEZMBcGA1UEAwwQY2x1c3Rlci5pbnRlcm5hbDAeFw0yNjA1MDUxMzIyNTla
Fw0yNjA1MDYxMzIyNTlaMBsxGTAXBgNVBAMMEGNsdXN0ZXIuaW50ZXJuYWwwggEi
MA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC59uFLczWX0ES7Y2ckovLTPC+r
lhAYhS+KOpIeEjgo+mqQ9fmyqnAq6NTr/tWWVgcfgAoqNo1+gOQa9WIu55NOQzNa
wBreheE8MaL7QD/QFZnvT0Z5Hh3hkXj2HTDQqBIMv1i3bVaDDOkK3xphfQO8QhV9
YtZf2MvxvtCbl0kBqAUN+k+EECu4TENNLQyS+2rZhxqg0/Js3DUu24nMD3ilL4Kf
KU2qE3pNfe6IrPl36LY+GkxprvmwPncocR4piJKGrc20XCsiM9KnAimIwZ6/nZ/C
DJEESK2+NmjDs84GHQFmxh1rlpaSFYJsshxnFH/y0ccyHtLZpsi+R0S7iqPJAgMB
AAGjUzBRMB0GA1UdDgQWBBR86FwGaRa1IxBdu4KK5TWR01asBzAfBgNVHSMEGDAW
gBR86FwGaRa1IxBdu4KK5TWR01asBzAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3
DQEBCwUAA4IBAQBgihO4KChG9VRoY7/Sq5UWjuZT8UWZoyjyejglK/J7enmx0bRX
clEg8gRZfhbFpYIybppIK+UuKUixkFeqW2CAt/odzNDcYiMEhXZ8SWLx12LhKcLi
EITLt0PZ877aNaszz5UWlP6Wj4ec8f1DiD1PSIQqz9gddwwdX8gespmyeW/riuCQ
RMfp9HwJgpcVQMqqSeOwZaDlm1szhpEql+g1/mVGMjXHYO0B7fxzrMUY99vSkOw0
iJQHtjVkkiHfkN1HDmpfwfONwfsyA0UYtzH4kwbVHm7v1FixQ8TS24jjQi19+v3h
M/xsKOBvTns6oAKzm3oerDtSSt/heECbD3rb
-----END CERTIFICATE-----
";

    const TEST_KEY_PEM: &[u8] = b"-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC59uFLczWX0ES7
Y2ckovLTPC+rlhAYhS+KOpIeEjgo+mqQ9fmyqnAq6NTr/tWWVgcfgAoqNo1+gOQa
9WIu55NOQzNawBreheE8MaL7QD/QFZnvT0Z5Hh3hkXj2HTDQqBIMv1i3bVaDDOkK
3xphfQO8QhV9YtZf2MvxvtCbl0kBqAUN+k+EECu4TENNLQyS+2rZhxqg0/Js3DUu
24nMD3ilL4KfKU2qE3pNfe6IrPl36LY+GkxprvmwPncocR4piJKGrc20XCsiM9Kn
AimIwZ6/nZ/CDJEESK2+NmjDs84GHQFmxh1rlpaSFYJsshxnFH/y0ccyHtLZpsi+
R0S7iqPJAgMBAAECggEAAJRn8TCSDX/NNXMix0b1kDoDGtS6oFDxLBjXPsSNknch
YOobYqnl9Dd9ZNTxCbYJiwwYbzd0Hnci/ubrICLoElmvepkLT5lF1/mxoxKsTQ11
yUl+enJhFnegU5tIsF9twWA3ukhBeXwcHkTbk+U4+NvER5VIyzJL6txOhMWmdemO
Tvk7vm1gUzr84k+mYdEoIaS5Bb8zgSNWcLVvZTAvd5VQuV5/SNHVrbpCy6q1dC++
7FdAhgSJ+CdRk/aAIXZ7zKrhe0pbCWDkmQLIdESLbv1onb9Sj/CLw8MEogMbT7T+
0FvjagYsmKsIq6Jyhd/Ve+zoLXOgOszVDYVvW14GNwKBgQDwgdT/lQEHPCoRfd1U
dz77OMIpawZtC7UAf+ab6HEmbSRaoIIa7kx5fjZeMTy6wQamN4xIcqCxt8KWozVH
M8VnChAidj3yX15AWiKT9kIBuk4dJOLwVh0Hsho+ml034M7txhBNPNIWdfxpFIti
0xncG9hkfCj5qxkUesHnuS29fwKBgQDF8ZdRnLU7iGW3YyE9OYKb/GzlF/NMkRex
7mRyTueOR5p/OiWQkQYo1F4XArnmIQSCcllOb0VukwBJLItKqc8fBHjkiyJyvCft
ZJSR3/BjFgx2w9Vo93bTpiHvevz2nTbebhV0kYXydgeiF7jCcpOQuAjreK5yhhCV
HvJoKJrStwKBgCk0VSWkhZSTvjFY+v5pn6SyyLEH4QX1p4D6aKv1Ws1WjY/pR+EN
SpTWBsKEdP8Z6uW3RpVy7g0EipX8SDh2qi9JDhKZZ2uK4z7rMllfK1fYb2GW3DqI
xlh3Lv/ium3EWi9qa4iQDv5CIIhwOKEpwZhwPNaaXvrHUXisv2PP2gJJAoGBAL1x
yjQWujFfCpKoclCJcSJfRc1Azd9S4g2uLj5knCNFDm2Dth4VXoLHNcHqHwdMRGeg
jy6NOjNox5ZA5pMv0AZMnnOFYhPTVpdScwrl+8ipeoZUSTSr2vMXhlUQLXjN4Iyj
aS9mc38pTYbqEy8uv2J7cDYFC1iaTNabhr7/VaYjAoGBAOApTlkgCYa7eUk9YYJs
zdrPUZcgT8cGTL6f04cLleaAW9gICh+25yDBQbay4uLTSKXMwb5Kygu8RYDk2NDz
GEdMJjFtDUbjt1eAlAarBIdsBs7A7jk1nGfu5g8Ervnm1X8Gs9FbUABmPQadNGJR
20YddOzMXpjdAMlrtmhRp4z1
-----END PRIVATE KEY-----
";
}