Skip to main content

kafkit_client/network/
connection.rs

1use std::collections::HashMap;
2use std::fs::File;
3use std::future::Future;
4use std::io::BufReader;
5use std::net::SocketAddr;
6use std::path::Path;
7use std::pin::Pin;
8use std::sync::{Arc, Once};
9use std::time::{Duration, Instant};
10
11use anyhow::{Context, Result, bail};
12use bytes::{BufMut, Bytes, BytesMut};
13use kafka_protocol::error::ParseResponseErrorCode;
14use kafka_protocol::messages::{
15    ApiVersionsRequest, RequestHeader, ResponseHeader, SaslAuthenticateRequest,
16    SaslHandshakeRequest,
17};
18use kafka_protocol::protocol::{
19    Decodable, HeaderVersion, Message, Request, StrBytes, VersionRange,
20    encode_request_header_into_buffer,
21};
22use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
23use rustls::{ClientConfig as RustlsClientConfig, RootCertStore};
24use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
25use tokio::net::TcpStream;
26use tokio_rustls::TlsConnector;
27use tokio_rustls::client::TlsStream;
28use tracing::{Instrument, debug, trace, trace_span};
29
30use super::scram::ScramClient;
31use super::select_api_version;
32use crate::config::{SaslConfig, SaslMechanism, SecurityProtocol, TlsConfig};
33use crate::constants::{API_VERSIONS_FALLBACK_VERSION, API_VERSIONS_PROBE_VERSION};
34use crate::telemetry;
35
36pub async fn connect_to_any_bootstrap(
37    servers: &[String],
38    client_id: &str,
39    timeout: Duration,
40    security_protocol: SecurityProtocol,
41    tls: &TlsConfig,
42    sasl: &SaslConfig,
43    tcp_connector: &Arc<dyn TcpConnector>,
44) -> Result<BrokerConnection> {
45    if servers.is_empty() {
46        bail!("no bootstrap servers configured");
47    }
48
49    let mut last_error: Option<anyhow::Error> = None;
50    for server in servers {
51        match BrokerConnection::connect_with_transport(
52            server,
53            client_id,
54            timeout,
55            security_protocol,
56            tls,
57            sasl,
58            tcp_connector,
59        )
60        .await
61        {
62            Ok(conn) => return Ok(conn),
63            Err(e) => {
64                debug!(server = %server, error = %e, "bootstrap connection failed, trying next server");
65                last_error = Some(e);
66            }
67        }
68    }
69    Err(last_error.unwrap())
70}
71
72type ConnectFuture<'a> = Pin<Box<dyn Future<Output = Result<ConnectedTcpStream>> + Send + 'a>>;
73
74pub trait BrokerIo: AsyncRead + AsyncWrite + Unpin + Send {}
75
76impl<T> BrokerIo for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
77
78pub enum ConnectedTcpStream {
79    Tokio(TcpStream),
80    Custom(Box<dyn BrokerIo>),
81}
82
83impl ConnectedTcpStream {
84    fn set_nodelay(&self, nodelay: bool) -> Result<()> {
85        match self {
86            Self::Tokio(stream) => stream.set_nodelay(nodelay)?,
87            Self::Custom(_) => {}
88        }
89        Ok(())
90    }
91}
92
93pub trait TcpConnector: std::fmt::Debug + Send + Sync {
94    fn connect<'a>(&'a self, address: &'a str, timeout: Duration) -> ConnectFuture<'a>;
95}
96
97#[derive(Debug, Default)]
98pub struct TokioTcpConnector;
99
100impl TcpConnector for TokioTcpConnector {
101    fn connect<'a>(&'a self, address: &'a str, timeout: Duration) -> ConnectFuture<'a> {
102        Box::pin(async move {
103            let tcp_stream = tokio::time::timeout(timeout, TcpStream::connect(address))
104                .await
105                .with_context(|| format!("timed out connecting to {address}"))?
106                .with_context(|| format!("failed to connect to {address}"))?;
107            Ok(ConnectedTcpStream::Tokio(tcp_stream))
108        })
109    }
110}
111
112pub struct BrokerConnection {
113    stream: BrokerStream,
114    next_correlation_id: i32,
115    api_versions: HashMap<i16, VersionRange>,
116    finalized_features: HashMap<String, i16>,
117}
118
119enum BrokerStream {
120    Plain(Box<dyn BrokerIo>),
121    Tls(Box<TlsStream<TcpStream>>),
122}
123
124impl BrokerStream {
125    async fn write_all(&mut self, frame: &[u8]) -> Result<()> {
126        match self {
127            Self::Plain(stream) => stream.write_all(frame).await?,
128            Self::Tls(stream) => stream.write_all(frame).await?,
129        }
130        Ok(())
131    }
132
133    async fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
134        match self {
135            Self::Plain(stream) => {
136                stream.read_exact(buf).await?;
137            }
138            Self::Tls(stream) => {
139                stream.read_exact(buf).await?;
140            }
141        };
142        Ok(())
143    }
144}
145
146impl BrokerConnection {
147    pub async fn connect_with_transport(
148        address: &str,
149        client_id: &str,
150        timeout: Duration,
151        security_protocol: SecurityProtocol,
152        tls: &TlsConfig,
153        sasl: &SaslConfig,
154        tcp_connector: &Arc<dyn TcpConnector>,
155    ) -> Result<Self> {
156        let started = Instant::now();
157        let result = async {
158            debug!(?security_protocol, "connecting to broker");
159            let stream =
160                connect_stream(address, timeout, security_protocol, tls, tcp_connector).await?;
161            let mut connection = Self {
162                stream,
163                next_correlation_id: 1,
164                api_versions: HashMap::new(),
165                finalized_features: HashMap::new(),
166            };
167            if security_protocol.uses_sasl() {
168                connection.authenticate_sasl(client_id, sasl).await?;
169            }
170            connection.negotiate_versions(client_id).await?;
171            debug!(
172                api_keys = connection.api_versions.len(),
173                finalized_features = connection.finalized_features.len(),
174                ?security_protocol,
175                "connected to broker"
176            );
177            Ok(connection)
178        }
179        .instrument(tracing::debug_span!(
180            "broker_connect",
181            %address,
182            %client_id,
183            timeout_ms = timeout.as_millis()
184        ))
185        .await;
186        telemetry::record_broker_connection(
187            client_id,
188            address,
189            &format!("{security_protocol:?}"),
190            started.elapsed(),
191            result.is_ok(),
192        );
193        result
194    }
195
196    async fn authenticate_sasl(&mut self, client_id: &str, sasl: &SaslConfig) -> Result<()> {
197        let response = self
198            .send_request::<ApiVersionsRequest>(
199                client_id,
200                API_VERSIONS_FALLBACK_VERSION,
201                &ApiVersionsRequest::default(),
202            )
203            .await
204            .context("SASL ApiVersions probe failed")?;
205        if let Some(error) = response.error_code.err() {
206            bail!("SASL ApiVersions probe failed: {error}");
207        }
208
209        let api_versions = response
210            .api_keys
211            .into_iter()
212            .map(|api| {
213                (
214                    api.api_key,
215                    VersionRange {
216                        min: api.min_version,
217                        max: api.max_version,
218                    },
219                )
220            })
221            .collect::<HashMap<_, _>>();
222
223        let handshake_version = api_versions
224            .get(&SaslHandshakeRequest::KEY)
225            .copied()
226            .map(|range| {
227                select_api_version(
228                    SaslHandshakeRequest::KEY,
229                    range,
230                    SaslHandshakeRequest::VERSIONS,
231                    SaslHandshakeRequest::VERSIONS.max,
232                )
233            })
234            .transpose()?
235            .unwrap_or(0);
236        let authenticate_version = api_versions
237            .get(&SaslAuthenticateRequest::KEY)
238            .copied()
239            .map(|range| {
240                select_api_version(
241                    SaslAuthenticateRequest::KEY,
242                    range,
243                    SaslAuthenticateRequest::VERSIONS,
244                    SaslAuthenticateRequest::VERSIONS.max,
245                )
246            })
247            .transpose()?;
248
249        let mechanism = sasl.mechanism.as_str();
250        let handshake =
251            SaslHandshakeRequest::default().with_mechanism(StrBytes::from_static_str(mechanism));
252        let response = self
253            .send_request::<SaslHandshakeRequest>(client_id, handshake_version, &handshake)
254            .await
255            .context("SASL handshake request failed")?;
256        if let Some(error) = response.error_code.err() {
257            let enabled = response
258                .mechanisms
259                .iter()
260                .map(ToString::to_string)
261                .collect::<Vec<_>>()
262                .join(", ");
263            bail!(
264                "SASL handshake failed for mechanism {mechanism}: {error}; enabled mechanisms: [{enabled}]"
265            );
266        }
267
268        match sasl.mechanism {
269            SaslMechanism::Plain => {
270                let token = build_plain_sasl_token(sasl)?;
271                if let Some(version) = authenticate_version {
272                    self.send_sasl_authenticate(client_id, version, mechanism, token)
273                        .await?;
274                } else {
275                    write_raw_sasl_token(&mut self.stream, &token).await?;
276                }
277            }
278            SaslMechanism::ScramSha256 | SaslMechanism::ScramSha512 => {
279                self.authenticate_scram(client_id, sasl, authenticate_version)
280                    .await?;
281            }
282        }
283
284        debug!(mechanism, "completed SASL authentication");
285        Ok(())
286    }
287
288    async fn authenticate_scram(
289        &mut self,
290        client_id: &str,
291        sasl: &SaslConfig,
292        authenticate_version: Option<i16>,
293    ) -> Result<()> {
294        let username = sasl
295            .username
296            .as_ref()
297            .context("SASL/SCRAM requires a username")?
298            .clone();
299        let password = sasl
300            .password
301            .as_ref()
302            .context("SASL/SCRAM requires a password")?
303            .clone();
304        let mechanism = sasl.mechanism.as_str();
305        let mut scram = ScramClient::new(sasl.mechanism, username, password)?;
306        let client_first = scram.client_first_message();
307
308        let server_first = if let Some(version) = authenticate_version {
309            self.send_sasl_authenticate(client_id, version, mechanism, client_first)
310                .await?
311        } else {
312            write_raw_sasl_token(&mut self.stream, &client_first).await?;
313            read_frame(&mut self.stream).await?
314        };
315        let client_final = scram.handle_server_first_message(&server_first)?;
316
317        let server_final = if let Some(version) = authenticate_version {
318            self.send_sasl_authenticate(client_id, version, mechanism, client_final)
319                .await?
320        } else {
321            write_raw_sasl_token(&mut self.stream, &client_final).await?;
322            read_frame(&mut self.stream).await?
323        };
324        scram.handle_server_final_message(&server_final)?;
325        Ok(())
326    }
327
328    async fn send_sasl_authenticate(
329        &mut self,
330        client_id: &str,
331        version: i16,
332        mechanism: &str,
333        token: Vec<u8>,
334    ) -> Result<Vec<u8>> {
335        let request = SaslAuthenticateRequest::default().with_auth_bytes(Bytes::from(token));
336        let response = self
337            .send_request::<SaslAuthenticateRequest>(client_id, version, &request)
338            .await
339            .context("SASL authenticate request failed")?;
340        if let Some(error) = response.error_code.err() {
341            let message = response
342                .error_message
343                .as_ref()
344                .map(ToString::to_string)
345                .filter(|message| !message.is_empty())
346                .unwrap_or_else(|| error.to_string());
347            bail!("SASL authentication failed for mechanism {mechanism}: {message}");
348        }
349        Ok(response.auth_bytes.to_vec())
350    }
351
352    pub fn version_with_cap<Req>(&self, cap: i16) -> Result<i16>
353    where
354        Req: Request,
355    {
356        let broker_range = self
357            .api_versions
358            .get(&Req::KEY)
359            .copied()
360            .with_context(|| format!("broker did not advertise API key {}", Req::KEY))?;
361        select_api_version(Req::KEY, broker_range, Req::VERSIONS, cap)
362    }
363
364    pub fn finalized_feature_level(&self, feature: &str) -> Option<i16> {
365        self.finalized_features.get(feature).copied()
366    }
367
368    pub fn finalized_feature_levels(&self) -> Vec<(String, i16)> {
369        let mut features = self
370            .finalized_features
371            .iter()
372            .map(|(name, level)| (name.clone(), *level))
373            .collect::<Vec<_>>();
374        features.sort_by(|left, right| left.0.cmp(&right.0));
375        features
376    }
377
378    async fn negotiate_versions(&mut self, client_id: &str) -> Result<()> {
379        let modern_request = ApiVersionsRequest::default()
380            .with_client_software_name(StrBytes::from_static_str("kafkit-client"))
381            .with_client_software_version(StrBytes::from_static_str("0.2.0"));
382
383        let response = match self
384            .send_request::<ApiVersionsRequest>(
385                client_id,
386                API_VERSIONS_PROBE_VERSION,
387                &modern_request,
388            )
389            .await
390        {
391            Ok(response) => response,
392            Err(error) => {
393                debug!(
394                    error = %error,
395                    "modern ApiVersions probe failed, retrying with fallback request"
396                );
397                self.send_request::<ApiVersionsRequest>(
398                    client_id,
399                    API_VERSIONS_FALLBACK_VERSION,
400                    &ApiVersionsRequest::default(),
401                )
402                .await?
403            }
404        };
405
406        if let Some(error) = response.error_code.err() {
407            bail!("ApiVersions failed: {error}");
408        }
409
410        self.api_versions = response
411            .api_keys
412            .into_iter()
413            .map(|api| {
414                (
415                    api.api_key,
416                    VersionRange {
417                        min: api.min_version,
418                        max: api.max_version,
419                    },
420                )
421            })
422            .collect();
423        self.finalized_features = response
424            .finalized_features
425            .into_iter()
426            .map(|feature| (feature.name.to_string(), feature.max_version_level))
427            .collect();
428
429        trace!(
430            api_keys = self.api_versions.len(),
431            finalized_features = self.finalized_features.len(),
432            "negotiated broker ApiVersions"
433        );
434        Ok(())
435    }
436
437    pub async fn send_request<Req>(
438        &mut self,
439        client_id: &str,
440        version: i16,
441        request: &Req,
442    ) -> Result<Req::Response>
443    where
444        Req: Request,
445    {
446        let correlation_id = self.next_correlation_id;
447        self.next_correlation_id += 1;
448        let started = Instant::now();
449        let mut request_bytes = 0usize;
450        let mut response_bytes = 0usize;
451        let span = trace_span!(
452            "kafka_request",
453            request = std::any::type_name::<Req>(),
454            api_key = Req::KEY,
455            api_version = version,
456            correlation_id,
457            %client_id
458        );
459
460        let result = async {
461            let mut body = BytesMut::new();
462            let header = RequestHeader::default()
463                .with_request_api_key(Req::KEY)
464                .with_request_api_version(version)
465                .with_correlation_id(correlation_id)
466                .with_client_id(Some(StrBytes::from_string(client_id.to_owned())));
467            encode_request_header_into_buffer(&mut body, &header)?;
468            request.encode(&mut body, version)?;
469            request_bytes = body.len();
470
471            trace!(request_bytes = body.len(), "encoded Kafka request");
472
473            let mut frame = BytesMut::with_capacity(body.len() + 4);
474            frame.put_i32(i32::try_from(body.len()).context("request frame is too large")?);
475            frame.extend_from_slice(&body);
476
477            self.stream.write_all(&frame).await?;
478            trace!(frame_bytes = frame.len(), "wrote Kafka request frame");
479
480            let response_frame = read_frame(&mut self.stream).await?;
481            response_bytes = response_frame.len();
482            trace!(
483                response_bytes = response_frame.len(),
484                "received Kafka response frame"
485            );
486            let mut response_body = Bytes::from(response_frame);
487            let header_version = Req::Response::header_version(version);
488            let response_header = ResponseHeader::decode(&mut response_body, header_version)?;
489            if response_header.correlation_id != correlation_id {
490                bail!(
491                    "response correlation mismatch: expected {}, got {}",
492                    correlation_id,
493                    response_header.correlation_id
494                );
495            }
496
497            let response = Req::Response::decode(&mut response_body, version)?;
498            trace!("completed Kafka request");
499            Ok(response)
500        }
501        .instrument(span)
502        .await;
503        telemetry::record_kafka_request::<Req>(
504            client_id,
505            version,
506            request_bytes,
507            response_bytes,
508            started.elapsed(),
509            result.is_ok(),
510            true,
511        );
512        result
513    }
514
515    pub async fn send_request_without_response<Req>(
516        &mut self,
517        client_id: &str,
518        version: i16,
519        request: &Req,
520    ) -> Result<()>
521    where
522        Req: Request,
523    {
524        let correlation_id = self.next_correlation_id;
525        self.next_correlation_id += 1;
526        let started = Instant::now();
527        let mut request_bytes = 0usize;
528        let span = trace_span!(
529            "kafka_request",
530            request = std::any::type_name::<Req>(),
531            api_key = Req::KEY,
532            api_version = version,
533            correlation_id,
534            expects_response = false,
535            %client_id
536        );
537
538        let result = async {
539            let mut body = BytesMut::new();
540            let header = RequestHeader::default()
541                .with_request_api_key(Req::KEY)
542                .with_request_api_version(version)
543                .with_correlation_id(correlation_id)
544                .with_client_id(Some(StrBytes::from_string(client_id.to_owned())));
545            encode_request_header_into_buffer(&mut body, &header)?;
546            request.encode(&mut body, version)?;
547            request_bytes = body.len();
548
549            trace!(request_bytes = body.len(), "encoded Kafka request");
550
551            let mut frame = BytesMut::with_capacity(body.len() + 4);
552            frame.put_i32(i32::try_from(body.len()).context("request frame is too large")?);
553            frame.extend_from_slice(&body);
554
555            self.stream.write_all(&frame).await?;
556            trace!(frame_bytes = frame.len(), "wrote Kafka request frame");
557            Ok(())
558        }
559        .instrument(span)
560        .await;
561        telemetry::record_kafka_request::<Req>(
562            client_id,
563            version,
564            request_bytes,
565            0,
566            started.elapsed(),
567            result.is_ok(),
568            false,
569        );
570        result
571    }
572}
573
574async fn connect_stream(
575    address: &str,
576    timeout: Duration,
577    security_protocol: SecurityProtocol,
578    tls: &TlsConfig,
579    tcp_connector: &Arc<dyn TcpConnector>,
580) -> Result<BrokerStream> {
581    let tcp_stream = tcp_connector.connect(address, timeout).await?;
582    tcp_stream
583        .set_nodelay(true)
584        .with_context(|| format!("failed to enable TCP_NODELAY for {address}"))?;
585
586    if security_protocol.uses_tls() {
587        let ConnectedTcpStream::Tokio(tcp_stream) = tcp_stream else {
588            bail!("custom TCP connectors do not support TLS broker connections");
589        };
590        let tls_config = build_tls_client_config(tls)?;
591        let connector = TlsConnector::from(tls_config);
592        let server_name = server_name_for_tls(address, tls)?;
593        let stream = tokio::time::timeout(timeout, connector.connect(server_name, tcp_stream))
594            .await
595            .with_context(|| format!("timed out negotiating TLS with {address}"))?
596            .with_context(|| format!("failed TLS handshake with {address}"))?;
597        Ok(BrokerStream::Tls(Box::new(stream)))
598    } else {
599        match tcp_stream {
600            ConnectedTcpStream::Tokio(stream) => Ok(BrokerStream::Plain(Box::new(stream))),
601            ConnectedTcpStream::Custom(stream) => Ok(BrokerStream::Plain(stream)),
602        }
603    }
604}
605
606fn build_plain_sasl_token(sasl: &SaslConfig) -> Result<Vec<u8>> {
607    let username = sasl
608        .username
609        .as_deref()
610        .context("SASL/PLAIN requires a username")?;
611    let password = sasl
612        .password
613        .as_deref()
614        .context("SASL/PLAIN requires a password")?;
615    let authorization_id = sasl.authorization_id.as_deref().unwrap_or_default();
616
617    let mut token =
618        Vec::with_capacity(authorization_id.len() + username.len() + password.len() + 2);
619    token.extend_from_slice(authorization_id.as_bytes());
620    token.push(0);
621    token.extend_from_slice(username.as_bytes());
622    token.push(0);
623    token.extend_from_slice(password.as_bytes());
624    Ok(token)
625}
626
627async fn write_raw_sasl_token(stream: &mut BrokerStream, token: &[u8]) -> Result<()> {
628    let mut frame = BytesMut::with_capacity(token.len() + 4);
629    frame.put_i32(i32::try_from(token.len()).context("SASL token frame is too large")?);
630    frame.extend_from_slice(token);
631    stream.write_all(&frame).await
632}
633
634fn build_tls_client_config(tls: &TlsConfig) -> Result<Arc<RustlsClientConfig>> {
635    ensure_rustls_crypto_provider();
636
637    let mut root_store = RootCertStore::empty();
638    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
639
640    if let Some(ca_cert_path) = tls.ca_cert_path.as_deref() {
641        for cert in load_certificates(ca_cert_path)? {
642            root_store.add(cert)?;
643        }
644    }
645
646    let builder = RustlsClientConfig::builder().with_root_certificates(root_store);
647    let config = match (
648        tls.client_cert_path.as_deref(),
649        tls.client_key_path.as_deref(),
650    ) {
651        (Some(client_cert_path), Some(client_key_path)) => builder.with_client_auth_cert(
652            load_certificates(client_cert_path)?,
653            load_private_key(client_key_path)?,
654        )?,
655        (None, None) => builder.with_no_client_auth(),
656        _ => bail!("TLS client auth requires both client_cert_path and client_key_path"),
657    };
658
659    Ok(Arc::new(config))
660}
661
662fn ensure_rustls_crypto_provider() {
663    static INSTALL_PROVIDER: Once = Once::new();
664
665    INSTALL_PROVIDER.call_once(|| {
666        if rustls::crypto::CryptoProvider::get_default().is_none() {
667            let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
668        }
669    });
670}
671
672fn load_certificates(path: &Path) -> Result<Vec<CertificateDer<'static>>> {
673    let file = File::open(path)
674        .with_context(|| format!("failed to open TLS certificate file '{}'", path.display()))?;
675    let mut reader = BufReader::new(file);
676    let certs = rustls_pemfile::certs(&mut reader)
677        .collect::<std::result::Result<Vec<_>, _>>()
678        .with_context(|| format!("failed to parse TLS certificate PEM '{}'", path.display()))?;
679    if certs.is_empty() {
680        bail!(
681            "TLS certificate file '{}' did not contain any PEM certificates",
682            path.display()
683        );
684    }
685    Ok(certs)
686}
687
688fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>> {
689    let file = File::open(path)
690        .with_context(|| format!("failed to open TLS private key file '{}'", path.display()))?;
691    let mut reader = BufReader::new(file);
692    rustls_pemfile::private_key(&mut reader)
693        .with_context(|| format!("failed to parse TLS private key PEM '{}'", path.display()))?
694        .with_context(|| {
695            format!(
696                "TLS private key file '{}' did not contain a PEM key",
697                path.display()
698            )
699        })
700}
701
702fn server_name_for_tls(address: &str, tls: &TlsConfig) -> Result<ServerName<'static>> {
703    if let Some(server_name) = tls.server_name.as_ref() {
704        return ServerName::try_from(server_name.clone())
705            .with_context(|| format!("invalid TLS server name '{}'", server_name));
706    }
707
708    if let Ok(socket_addr) = address.parse::<SocketAddr>() {
709        return Ok(ServerName::IpAddress(socket_addr.ip().into()));
710    }
711
712    let host = if let Some(stripped) = address.strip_prefix('[') {
713        stripped
714            .split(']')
715            .next()
716            .context("invalid bracketed broker address")?
717            .to_owned()
718    } else {
719        address
720            .rsplit_once(':')
721            .map(|(host, _)| host.to_owned())
722            .unwrap_or_else(|| address.to_owned())
723    };
724
725    ServerName::try_from(host.clone()).with_context(|| {
726        format!("could not derive a valid TLS server name from broker address '{address}'")
727    })
728}
729
730async fn read_frame(stream: &mut BrokerStream) -> Result<Vec<u8>> {
731    let mut header = [0_u8; 4];
732    stream.read_exact(&mut header).await?;
733    let frame_len = i32::from_be_bytes(header);
734    if frame_len < 0 {
735        bail!("broker returned a negative frame length: {frame_len}");
736    }
737
738    let mut payload = vec![0_u8; usize::try_from(frame_len)?];
739    stream.read_exact(&mut payload).await?;
740    Ok(payload)
741}
742
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::PathBuf;
    use std::sync::Arc;

    use tokio::io;

    /// Creates (if needed) `<temp>/<label>-<pid>` and returns its path.
    fn scratch_dir(label: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!("{label}-{}", std::process::id()));
        fs::create_dir_all(&dir).unwrap();
        dir
    }

    #[test]
    fn tls_server_name_defaults_to_host() {
        let server_name = server_name_for_tls("broker.example.com:9093", &TlsConfig::default())
            .expect("host:port address should yield a DNS server name");
        assert_eq!(server_name.to_str(), "broker.example.com");
    }

    #[test]
    fn tls_server_name_respects_explicit_override() {
        let tls = TlsConfig::new().with_server_name("cluster.internal");
        let server_name = server_name_for_tls("127.0.0.1:9093", &tls)
            .expect("explicit override should be used");
        assert_eq!(server_name.to_str(), "cluster.internal");
    }

    #[test]
    fn tls_server_name_handles_ip_and_bracketed_ipv6() {
        for (address, expected) in [("127.0.0.1:9093", "127.0.0.1"), ("[::1]:9093", "::1")] {
            let server_name = server_name_for_tls(address, &TlsConfig::default())
                .expect("IP socket address should yield an IP server name");
            assert_eq!(server_name.to_str(), expected);
        }
    }

    #[test]
    fn tls_server_name_rejects_invalid_override_and_empty_address() {
        let bad_override = TlsConfig::new().with_server_name("not a dns name");
        assert!(server_name_for_tls("127.0.0.1:9093", &bad_override).is_err());
        assert!(server_name_for_tls("", &TlsConfig::default()).is_err());
    }

    #[test]
    fn plain_sasl_token_requires_credentials_and_uses_authorization_id() {
        assert!(build_plain_sasl_token(&SaslConfig::default()).is_err());

        let token = build_plain_sasl_token(
            &SaslConfig::plain("user", "pw").with_authorization_id("authz"),
        )
        .expect("complete PLAIN credentials");
        assert_eq!(token, b"authz\0user\0pw".to_vec());
    }

    #[test]
    fn tls_file_loaders_reject_missing_empty_and_invalid_pem_files() {
        let dir = scratch_dir("kafkit-client-tls-test");
        let cert_path = dir.join("cert.pem");
        let key_path = dir.join("key.pem");
        fs::write(&cert_path, b"not a certificate").unwrap();
        fs::write(&key_path, b"not a key").unwrap();

        assert!(load_certificates(&cert_path).is_err());
        assert!(load_private_key(&key_path).is_err());
        assert!(load_certificates(&dir.join("missing.pem")).is_err());

        let _ = fs::remove_dir_all(dir);
    }

    #[test]
    fn tls_client_config_loads_custom_ca_and_client_auth_pem_files() {
        let dir = scratch_dir("kafkit-client-tls-valid-pem-test");
        let cert_path = dir.join("cert.pem");
        let key_path = dir.join("key.pem");
        fs::write(&cert_path, TEST_CERT_PEM).unwrap();
        fs::write(&key_path, TEST_KEY_PEM).unwrap();

        let full_config = TlsConfig::new()
            .with_ca_cert_path(&cert_path)
            .with_client_cert_path(&cert_path)
            .with_client_key_path(&key_path)
            .with_server_name("cluster.internal");
        build_tls_client_config(&full_config).expect("valid custom CA and client auth config");

        let cert_only = TlsConfig::new().with_client_cert_path(&cert_path);
        assert!(build_tls_client_config(&cert_only).is_err());
        let key_only = TlsConfig::new().with_client_key_path(&key_path);
        assert!(build_tls_client_config(&key_only).is_err());

        let _ = fs::remove_dir_all(dir);
    }

    #[tokio::test]
    async fn tls_rejects_custom_tcp_connectors_before_handshake() {
        let connector: Arc<dyn TcpConnector> = Arc::new(CustomOnlyConnector);
        let outcome = connect_stream(
            "broker.example.com:9093",
            Duration::from_secs(1),
            SecurityProtocol::Ssl,
            &TlsConfig::default(),
            &connector,
        )
        .await;
        let Err(error) = outcome else {
            panic!("TLS over custom stream should be rejected");
        };

        assert!(
            error
                .to_string()
                .contains("custom TCP connectors do not support TLS broker connections")
        );
    }

    /// Connector that always hands back an in-memory duplex pipe as a custom transport.
    #[derive(Debug)]
    struct CustomOnlyConnector;

    impl TcpConnector for CustomOnlyConnector {
        fn connect<'a>(&'a self, _address: &'a str, _timeout: Duration) -> ConnectFuture<'a> {
            Box::pin(async move {
                let (stream, _peer) = io::duplex(64);
                Ok(ConnectedTcpStream::Custom(Box::new(stream)))
            })
        }
    }

    const TEST_CERT_PEM: &[u8] = b"-----BEGIN CERTIFICATE-----
MIIDFzCCAf+gAwIBAgIUU1sGIzptOpATf4S4bW3ljAEYj94wDQYJKoZIhvcNAQEL
BQAwGzEZMBcGA1UEAwwQY2x1c3Rlci5pbnRlcm5hbDAeFw0yNjA1MDUxMzIyNTla
Fw0yNjA1MDYxMzIyNTlaMBsxGTAXBgNVBAMMEGNsdXN0ZXIuaW50ZXJuYWwwggEi
MA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC59uFLczWX0ES7Y2ckovLTPC+r
lhAYhS+KOpIeEjgo+mqQ9fmyqnAq6NTr/tWWVgcfgAoqNo1+gOQa9WIu55NOQzNa
wBreheE8MaL7QD/QFZnvT0Z5Hh3hkXj2HTDQqBIMv1i3bVaDDOkK3xphfQO8QhV9
YtZf2MvxvtCbl0kBqAUN+k+EECu4TENNLQyS+2rZhxqg0/Js3DUu24nMD3ilL4Kf
KU2qE3pNfe6IrPl36LY+GkxprvmwPncocR4piJKGrc20XCsiM9KnAimIwZ6/nZ/C
DJEESK2+NmjDs84GHQFmxh1rlpaSFYJsshxnFH/y0ccyHtLZpsi+R0S7iqPJAgMB
AAGjUzBRMB0GA1UdDgQWBBR86FwGaRa1IxBdu4KK5TWR01asBzAfBgNVHSMEGDAW
gBR86FwGaRa1IxBdu4KK5TWR01asBzAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3
DQEBCwUAA4IBAQBgihO4KChG9VRoY7/Sq5UWjuZT8UWZoyjyejglK/J7enmx0bRX
clEg8gRZfhbFpYIybppIK+UuKUixkFeqW2CAt/odzNDcYiMEhXZ8SWLx12LhKcLi
EITLt0PZ877aNaszz5UWlP6Wj4ec8f1DiD1PSIQqz9gddwwdX8gespmyeW/riuCQ
RMfp9HwJgpcVQMqqSeOwZaDlm1szhpEql+g1/mVGMjXHYO0B7fxzrMUY99vSkOw0
iJQHtjVkkiHfkN1HDmpfwfONwfsyA0UYtzH4kwbVHm7v1FixQ8TS24jjQi19+v3h
M/xsKOBvTns6oAKzm3oerDtSSt/heECbD3rb
-----END CERTIFICATE-----
";

    const TEST_KEY_PEM: &[u8] = b"-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC59uFLczWX0ES7
Y2ckovLTPC+rlhAYhS+KOpIeEjgo+mqQ9fmyqnAq6NTr/tWWVgcfgAoqNo1+gOQa
9WIu55NOQzNawBreheE8MaL7QD/QFZnvT0Z5Hh3hkXj2HTDQqBIMv1i3bVaDDOkK
3xphfQO8QhV9YtZf2MvxvtCbl0kBqAUN+k+EECu4TENNLQyS+2rZhxqg0/Js3DUu
24nMD3ilL4KfKU2qE3pNfe6IrPl36LY+GkxprvmwPncocR4piJKGrc20XCsiM9Kn
AimIwZ6/nZ/CDJEESK2+NmjDs84GHQFmxh1rlpaSFYJsshxnFH/y0ccyHtLZpsi+
R0S7iqPJAgMBAAECggEAAJRn8TCSDX/NNXMix0b1kDoDGtS6oFDxLBjXPsSNknch
YOobYqnl9Dd9ZNTxCbYJiwwYbzd0Hnci/ubrICLoElmvepkLT5lF1/mxoxKsTQ11
yUl+enJhFnegU5tIsF9twWA3ukhBeXwcHkTbk+U4+NvER5VIyzJL6txOhMWmdemO
Tvk7vm1gUzr84k+mYdEoIaS5Bb8zgSNWcLVvZTAvd5VQuV5/SNHVrbpCy6q1dC++
7FdAhgSJ+CdRk/aAIXZ7zKrhe0pbCWDkmQLIdESLbv1onb9Sj/CLw8MEogMbT7T+
0FvjagYsmKsIq6Jyhd/Ve+zoLXOgOszVDYVvW14GNwKBgQDwgdT/lQEHPCoRfd1U
dz77OMIpawZtC7UAf+ab6HEmbSRaoIIa7kx5fjZeMTy6wQamN4xIcqCxt8KWozVH
M8VnChAidj3yX15AWiKT9kIBuk4dJOLwVh0Hsho+ml034M7txhBNPNIWdfxpFIti
0xncG9hkfCj5qxkUesHnuS29fwKBgQDF8ZdRnLU7iGW3YyE9OYKb/GzlF/NMkRex
7mRyTueOR5p/OiWQkQYo1F4XArnmIQSCcllOb0VukwBJLItKqc8fBHjkiyJyvCft
ZJSR3/BjFgx2w9Vo93bTpiHvevz2nTbebhV0kYXydgeiF7jCcpOQuAjreK5yhhCV
HvJoKJrStwKBgCk0VSWkhZSTvjFY+v5pn6SyyLEH4QX1p4D6aKv1Ws1WjY/pR+EN
SpTWBsKEdP8Z6uW3RpVy7g0EipX8SDh2qi9JDhKZZ2uK4z7rMllfK1fYb2GW3DqI
xlh3Lv/ium3EWi9qa4iQDv5CIIhwOKEpwZhwPNaaXvrHUXisv2PP2gJJAoGBAL1x
yjQWujFfCpKoclCJcSJfRc1Azd9S4g2uLj5knCNFDm2Dth4VXoLHNcHqHwdMRGeg
jy6NOjNox5ZA5pMv0AZMnnOFYhPTVpdScwrl+8ipeoZUSTSr2vMXhlUQLXjN4Iyj
aS9mc38pTYbqEy8uv2J7cDYFC1iaTNabhr7/VaYjAoGBAOApTlkgCYa7eUk9YYJs
zdrPUZcgT8cGTL6f04cLleaAW9gICh+25yDBQbay4uLTSKXMwb5Kygu8RYDk2NDz
GEdMJjFtDUbjt1eAlAarBIdsBs7A7jk1nGfu5g8Ervnm1X8Gs9FbUABmPQadNGJR
20YddOzMXpjdAMlrtmhRp4z1
-----END PRIVATE KEY-----
";
}