Skip to main content

kafkit_client/network/
connection.rs

1use std::collections::HashMap;
2use std::fs::File;
3use std::future::Future;
4use std::io::BufReader;
5use std::net::SocketAddr;
6use std::path::Path;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::time::Duration;
10
11use anyhow::{Context, Result, bail};
12use bytes::{BufMut, Bytes, BytesMut};
13use kafka_protocol::error::ParseResponseErrorCode;
14use kafka_protocol::messages::{
15    ApiVersionsRequest, RequestHeader, ResponseHeader, SaslAuthenticateRequest,
16    SaslHandshakeRequest,
17};
18use kafka_protocol::protocol::{
19    Decodable, HeaderVersion, Message, Request, StrBytes, VersionRange,
20    encode_request_header_into_buffer,
21};
22use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
23use rustls::{ClientConfig as RustlsClientConfig, RootCertStore};
24use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
25use tokio::net::TcpStream;
26use tokio_rustls::TlsConnector;
27use tokio_rustls::client::TlsStream;
28use tracing::{Instrument, debug, trace, trace_span};
29
30use super::scram::ScramClient;
31use super::select_api_version;
32use crate::config::{SaslConfig, SaslMechanism, SecurityProtocol, TlsConfig};
33use crate::constants::{API_VERSIONS_FALLBACK_VERSION, API_VERSIONS_PROBE_VERSION};
34
35pub async fn connect_to_any_bootstrap(
36    servers: &[String],
37    client_id: &str,
38    timeout: Duration,
39    security_protocol: SecurityProtocol,
40    tls: &TlsConfig,
41    sasl: &SaslConfig,
42    tcp_connector: &Arc<dyn TcpConnector>,
43) -> Result<BrokerConnection> {
44    if servers.is_empty() {
45        bail!("no bootstrap servers configured");
46    }
47
48    let mut last_error: Option<anyhow::Error> = None;
49    for server in servers {
50        match BrokerConnection::connect_with_transport(
51            server,
52            client_id,
53            timeout,
54            security_protocol,
55            tls,
56            sasl,
57            tcp_connector,
58        )
59        .await
60        {
61            Ok(conn) => return Ok(conn),
62            Err(e) => {
63                debug!(server = %server, error = %e, "bootstrap connection failed, trying next server");
64                last_error = Some(e);
65            }
66        }
67    }
68    Err(last_error.unwrap())
69}
70
71type ConnectFuture<'a> = Pin<Box<dyn Future<Output = Result<ConnectedTcpStream>> + Send + 'a>>;
72
73pub trait BrokerIo: AsyncRead + AsyncWrite + Unpin + Send {}
74
75impl<T> BrokerIo for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
76
77pub enum ConnectedTcpStream {
78    Tokio(TcpStream),
79    Custom(Box<dyn BrokerIo>),
80}
81
82impl ConnectedTcpStream {
83    fn set_nodelay(&self, nodelay: bool) -> Result<()> {
84        match self {
85            Self::Tokio(stream) => stream.set_nodelay(nodelay)?,
86            Self::Custom(_) => {}
87        }
88        Ok(())
89    }
90}
91
92pub trait TcpConnector: std::fmt::Debug + Send + Sync {
93    fn connect<'a>(&'a self, address: &'a str, timeout: Duration) -> ConnectFuture<'a>;
94}
95
96#[derive(Debug, Default)]
97pub struct TokioTcpConnector;
98
99impl TcpConnector for TokioTcpConnector {
100    fn connect<'a>(&'a self, address: &'a str, timeout: Duration) -> ConnectFuture<'a> {
101        Box::pin(async move {
102            let tcp_stream = tokio::time::timeout(timeout, TcpStream::connect(address))
103                .await
104                .with_context(|| format!("timed out connecting to {address}"))?
105                .with_context(|| format!("failed to connect to {address}"))?;
106            Ok(ConnectedTcpStream::Tokio(tcp_stream))
107        })
108    }
109}
110
111pub struct BrokerConnection {
112    stream: BrokerStream,
113    next_correlation_id: i32,
114    api_versions: HashMap<i16, VersionRange>,
115    finalized_features: HashMap<String, i16>,
116}
117
118enum BrokerStream {
119    Plain(Box<dyn BrokerIo>),
120    Tls(Box<TlsStream<TcpStream>>),
121}
122
123impl BrokerStream {
124    async fn write_all(&mut self, frame: &[u8]) -> Result<()> {
125        match self {
126            Self::Plain(stream) => stream.write_all(frame).await?,
127            Self::Tls(stream) => stream.write_all(frame).await?,
128        }
129        Ok(())
130    }
131
132    async fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
133        match self {
134            Self::Plain(stream) => {
135                stream.read_exact(buf).await?;
136            }
137            Self::Tls(stream) => {
138                stream.read_exact(buf).await?;
139            }
140        };
141        Ok(())
142    }
143}
144
145impl BrokerConnection {
146    pub async fn connect_with_transport(
147        address: &str,
148        client_id: &str,
149        timeout: Duration,
150        security_protocol: SecurityProtocol,
151        tls: &TlsConfig,
152        sasl: &SaslConfig,
153        tcp_connector: &Arc<dyn TcpConnector>,
154    ) -> Result<Self> {
155        async {
156            debug!(?security_protocol, "connecting to broker");
157            let stream =
158                connect_stream(address, timeout, security_protocol, tls, tcp_connector).await?;
159            let mut connection = Self {
160                stream,
161                next_correlation_id: 1,
162                api_versions: HashMap::new(),
163                finalized_features: HashMap::new(),
164            };
165            if security_protocol.uses_sasl() {
166                connection.authenticate_sasl(client_id, sasl).await?;
167            }
168            connection.negotiate_versions(client_id).await?;
169            debug!(
170                api_keys = connection.api_versions.len(),
171                finalized_features = connection.finalized_features.len(),
172                ?security_protocol,
173                "connected to broker"
174            );
175            Ok(connection)
176        }
177        .instrument(tracing::debug_span!(
178            "broker_connect",
179            %address,
180            %client_id,
181            timeout_ms = timeout.as_millis()
182        ))
183        .await
184    }
185
186    async fn authenticate_sasl(&mut self, client_id: &str, sasl: &SaslConfig) -> Result<()> {
187        let response = self
188            .send_request::<ApiVersionsRequest>(
189                client_id,
190                API_VERSIONS_FALLBACK_VERSION,
191                &ApiVersionsRequest::default(),
192            )
193            .await
194            .context("SASL ApiVersions probe failed")?;
195        if let Some(error) = response.error_code.err() {
196            bail!("SASL ApiVersions probe failed: {error}");
197        }
198
199        let api_versions = response
200            .api_keys
201            .into_iter()
202            .map(|api| {
203                (
204                    api.api_key,
205                    VersionRange {
206                        min: api.min_version,
207                        max: api.max_version,
208                    },
209                )
210            })
211            .collect::<HashMap<_, _>>();
212
213        let handshake_version = api_versions
214            .get(&SaslHandshakeRequest::KEY)
215            .copied()
216            .map(|range| {
217                select_api_version(
218                    SaslHandshakeRequest::KEY,
219                    range,
220                    SaslHandshakeRequest::VERSIONS,
221                    SaslHandshakeRequest::VERSIONS.max,
222                )
223            })
224            .transpose()?
225            .unwrap_or(0);
226        let authenticate_version = api_versions
227            .get(&SaslAuthenticateRequest::KEY)
228            .copied()
229            .map(|range| {
230                select_api_version(
231                    SaslAuthenticateRequest::KEY,
232                    range,
233                    SaslAuthenticateRequest::VERSIONS,
234                    SaslAuthenticateRequest::VERSIONS.max,
235                )
236            })
237            .transpose()?;
238
239        let mechanism = sasl.mechanism.as_str();
240        let handshake =
241            SaslHandshakeRequest::default().with_mechanism(StrBytes::from_static_str(mechanism));
242        let response = self
243            .send_request::<SaslHandshakeRequest>(client_id, handshake_version, &handshake)
244            .await
245            .context("SASL handshake request failed")?;
246        if let Some(error) = response.error_code.err() {
247            let enabled = response
248                .mechanisms
249                .iter()
250                .map(ToString::to_string)
251                .collect::<Vec<_>>()
252                .join(", ");
253            bail!(
254                "SASL handshake failed for mechanism {mechanism}: {error}; enabled mechanisms: [{enabled}]"
255            );
256        }
257
258        match sasl.mechanism {
259            SaslMechanism::Plain => {
260                let token = build_plain_sasl_token(sasl)?;
261                if let Some(version) = authenticate_version {
262                    self.send_sasl_authenticate(client_id, version, mechanism, token)
263                        .await?;
264                } else {
265                    write_raw_sasl_token(&mut self.stream, &token).await?;
266                }
267            }
268            SaslMechanism::ScramSha256 | SaslMechanism::ScramSha512 => {
269                self.authenticate_scram(client_id, sasl, authenticate_version)
270                    .await?;
271            }
272        }
273
274        debug!(mechanism, "completed SASL authentication");
275        Ok(())
276    }
277
278    async fn authenticate_scram(
279        &mut self,
280        client_id: &str,
281        sasl: &SaslConfig,
282        authenticate_version: Option<i16>,
283    ) -> Result<()> {
284        let username = sasl
285            .username
286            .as_ref()
287            .context("SASL/SCRAM requires a username")?
288            .clone();
289        let password = sasl
290            .password
291            .as_ref()
292            .context("SASL/SCRAM requires a password")?
293            .clone();
294        let mechanism = sasl.mechanism.as_str();
295        let mut scram = ScramClient::new(sasl.mechanism, username, password)?;
296        let client_first = scram.client_first_message();
297
298        let server_first = if let Some(version) = authenticate_version {
299            self.send_sasl_authenticate(client_id, version, mechanism, client_first)
300                .await?
301        } else {
302            write_raw_sasl_token(&mut self.stream, &client_first).await?;
303            read_frame(&mut self.stream).await?
304        };
305        let client_final = scram.handle_server_first_message(&server_first)?;
306
307        let server_final = if let Some(version) = authenticate_version {
308            self.send_sasl_authenticate(client_id, version, mechanism, client_final)
309                .await?
310        } else {
311            write_raw_sasl_token(&mut self.stream, &client_final).await?;
312            read_frame(&mut self.stream).await?
313        };
314        scram.handle_server_final_message(&server_final)?;
315        Ok(())
316    }
317
318    async fn send_sasl_authenticate(
319        &mut self,
320        client_id: &str,
321        version: i16,
322        mechanism: &str,
323        token: Vec<u8>,
324    ) -> Result<Vec<u8>> {
325        let request = SaslAuthenticateRequest::default().with_auth_bytes(Bytes::from(token));
326        let response = self
327            .send_request::<SaslAuthenticateRequest>(client_id, version, &request)
328            .await
329            .context("SASL authenticate request failed")?;
330        if let Some(error) = response.error_code.err() {
331            let message = response
332                .error_message
333                .as_ref()
334                .map(ToString::to_string)
335                .filter(|message| !message.is_empty())
336                .unwrap_or_else(|| error.to_string());
337            bail!("SASL authentication failed for mechanism {mechanism}: {message}");
338        }
339        Ok(response.auth_bytes.to_vec())
340    }
341
342    pub fn version_with_cap<Req>(&self, cap: i16) -> Result<i16>
343    where
344        Req: Request,
345    {
346        let broker_range = self
347            .api_versions
348            .get(&Req::KEY)
349            .copied()
350            .with_context(|| format!("broker did not advertise API key {}", Req::KEY))?;
351        select_api_version(Req::KEY, broker_range, Req::VERSIONS, cap)
352    }
353
354    pub fn finalized_feature_level(&self, feature: &str) -> Option<i16> {
355        self.finalized_features.get(feature).copied()
356    }
357
358    pub fn finalized_feature_levels(&self) -> Vec<(String, i16)> {
359        let mut features = self
360            .finalized_features
361            .iter()
362            .map(|(name, level)| (name.clone(), *level))
363            .collect::<Vec<_>>();
364        features.sort_by(|left, right| left.0.cmp(&right.0));
365        features
366    }
367
368    async fn negotiate_versions(&mut self, client_id: &str) -> Result<()> {
369        let modern_request = ApiVersionsRequest::default()
370            .with_client_software_name(StrBytes::from_static_str("kafkit-client"))
371            .with_client_software_version(StrBytes::from_static_str("0.2.0"));
372
373        let response = match self
374            .send_request::<ApiVersionsRequest>(
375                client_id,
376                API_VERSIONS_PROBE_VERSION,
377                &modern_request,
378            )
379            .await
380        {
381            Ok(response) => response,
382            Err(error) => {
383                debug!(
384                    error = %error,
385                    "modern ApiVersions probe failed, retrying with fallback request"
386                );
387                self.send_request::<ApiVersionsRequest>(
388                    client_id,
389                    API_VERSIONS_FALLBACK_VERSION,
390                    &ApiVersionsRequest::default(),
391                )
392                .await?
393            }
394        };
395
396        if let Some(error) = response.error_code.err() {
397            bail!("ApiVersions failed: {error}");
398        }
399
400        self.api_versions = response
401            .api_keys
402            .into_iter()
403            .map(|api| {
404                (
405                    api.api_key,
406                    VersionRange {
407                        min: api.min_version,
408                        max: api.max_version,
409                    },
410                )
411            })
412            .collect();
413        self.finalized_features = response
414            .finalized_features
415            .into_iter()
416            .map(|feature| (feature.name.to_string(), feature.max_version_level))
417            .collect();
418
419        trace!(
420            api_keys = self.api_versions.len(),
421            finalized_features = self.finalized_features.len(),
422            "negotiated broker ApiVersions"
423        );
424        Ok(())
425    }
426
427    pub async fn send_request<Req>(
428        &mut self,
429        client_id: &str,
430        version: i16,
431        request: &Req,
432    ) -> Result<Req::Response>
433    where
434        Req: Request,
435    {
436        let correlation_id = self.next_correlation_id;
437        self.next_correlation_id += 1;
438        let span = trace_span!(
439            "kafka_request",
440            request = std::any::type_name::<Req>(),
441            api_key = Req::KEY,
442            api_version = version,
443            correlation_id,
444            %client_id
445        );
446
447        async {
448            let mut body = BytesMut::new();
449            let header = RequestHeader::default()
450                .with_request_api_key(Req::KEY)
451                .with_request_api_version(version)
452                .with_correlation_id(correlation_id)
453                .with_client_id(Some(StrBytes::from_string(client_id.to_owned())));
454            encode_request_header_into_buffer(&mut body, &header)?;
455            request.encode(&mut body, version)?;
456
457            trace!(request_bytes = body.len(), "encoded Kafka request");
458
459            let mut frame = BytesMut::with_capacity(body.len() + 4);
460            frame.put_i32(i32::try_from(body.len()).context("request frame is too large")?);
461            frame.extend_from_slice(&body);
462
463            self.stream.write_all(&frame).await?;
464            trace!(frame_bytes = frame.len(), "wrote Kafka request frame");
465
466            let response_frame = read_frame(&mut self.stream).await?;
467            trace!(
468                response_bytes = response_frame.len(),
469                "received Kafka response frame"
470            );
471            let mut response_body = Bytes::from(response_frame);
472            let header_version = Req::Response::header_version(version);
473            let response_header = ResponseHeader::decode(&mut response_body, header_version)?;
474            if response_header.correlation_id != correlation_id {
475                bail!(
476                    "response correlation mismatch: expected {}, got {}",
477                    correlation_id,
478                    response_header.correlation_id
479                );
480            }
481
482            let response = Req::Response::decode(&mut response_body, version)?;
483            debug!("completed Kafka request");
484            Ok(response)
485        }
486        .instrument(span)
487        .await
488    }
489}
490
491async fn connect_stream(
492    address: &str,
493    timeout: Duration,
494    security_protocol: SecurityProtocol,
495    tls: &TlsConfig,
496    tcp_connector: &Arc<dyn TcpConnector>,
497) -> Result<BrokerStream> {
498    let tcp_stream = tcp_connector.connect(address, timeout).await?;
499    tcp_stream
500        .set_nodelay(true)
501        .with_context(|| format!("failed to enable TCP_NODELAY for {address}"))?;
502
503    if security_protocol.uses_tls() {
504        let ConnectedTcpStream::Tokio(tcp_stream) = tcp_stream else {
505            bail!("custom TCP connectors do not support TLS broker connections");
506        };
507        let tls_config = build_tls_client_config(tls)?;
508        let connector = TlsConnector::from(tls_config);
509        let server_name = server_name_for_tls(address, tls)?;
510        let stream = tokio::time::timeout(timeout, connector.connect(server_name, tcp_stream))
511            .await
512            .with_context(|| format!("timed out negotiating TLS with {address}"))?
513            .with_context(|| format!("failed TLS handshake with {address}"))?;
514        Ok(BrokerStream::Tls(Box::new(stream)))
515    } else {
516        match tcp_stream {
517            ConnectedTcpStream::Tokio(stream) => Ok(BrokerStream::Plain(Box::new(stream))),
518            ConnectedTcpStream::Custom(stream) => Ok(BrokerStream::Plain(stream)),
519        }
520    }
521}
522
523fn build_plain_sasl_token(sasl: &SaslConfig) -> Result<Vec<u8>> {
524    let username = sasl
525        .username
526        .as_deref()
527        .context("SASL/PLAIN requires a username")?;
528    let password = sasl
529        .password
530        .as_deref()
531        .context("SASL/PLAIN requires a password")?;
532    let authorization_id = sasl.authorization_id.as_deref().unwrap_or_default();
533
534    let mut token =
535        Vec::with_capacity(authorization_id.len() + username.len() + password.len() + 2);
536    token.extend_from_slice(authorization_id.as_bytes());
537    token.push(0);
538    token.extend_from_slice(username.as_bytes());
539    token.push(0);
540    token.extend_from_slice(password.as_bytes());
541    Ok(token)
542}
543
544async fn write_raw_sasl_token(stream: &mut BrokerStream, token: &[u8]) -> Result<()> {
545    let mut frame = BytesMut::with_capacity(token.len() + 4);
546    frame.put_i32(i32::try_from(token.len()).context("SASL token frame is too large")?);
547    frame.extend_from_slice(token);
548    stream.write_all(&frame).await
549}
550
551fn build_tls_client_config(tls: &TlsConfig) -> Result<Arc<RustlsClientConfig>> {
552    let mut root_store = RootCertStore::empty();
553    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
554
555    if let Some(ca_cert_path) = tls.ca_cert_path.as_deref() {
556        for cert in load_certificates(ca_cert_path)? {
557            root_store.add(cert)?;
558        }
559    }
560
561    let builder = RustlsClientConfig::builder().with_root_certificates(root_store);
562    let config = match (
563        tls.client_cert_path.as_deref(),
564        tls.client_key_path.as_deref(),
565    ) {
566        (Some(client_cert_path), Some(client_key_path)) => builder.with_client_auth_cert(
567            load_certificates(client_cert_path)?,
568            load_private_key(client_key_path)?,
569        )?,
570        (None, None) => builder.with_no_client_auth(),
571        _ => bail!("TLS client auth requires both client_cert_path and client_key_path"),
572    };
573
574    Ok(Arc::new(config))
575}
576
577fn load_certificates(path: &Path) -> Result<Vec<CertificateDer<'static>>> {
578    let file = File::open(path)
579        .with_context(|| format!("failed to open TLS certificate file '{}'", path.display()))?;
580    let mut reader = BufReader::new(file);
581    let certs = rustls_pemfile::certs(&mut reader)
582        .collect::<std::result::Result<Vec<_>, _>>()
583        .with_context(|| format!("failed to parse TLS certificate PEM '{}'", path.display()))?;
584    if certs.is_empty() {
585        bail!(
586            "TLS certificate file '{}' did not contain any PEM certificates",
587            path.display()
588        );
589    }
590    Ok(certs)
591}
592
593fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>> {
594    let file = File::open(path)
595        .with_context(|| format!("failed to open TLS private key file '{}'", path.display()))?;
596    let mut reader = BufReader::new(file);
597    rustls_pemfile::private_key(&mut reader)
598        .with_context(|| format!("failed to parse TLS private key PEM '{}'", path.display()))?
599        .with_context(|| {
600            format!(
601                "TLS private key file '{}' did not contain a PEM key",
602                path.display()
603            )
604        })
605}
606
607fn server_name_for_tls(address: &str, tls: &TlsConfig) -> Result<ServerName<'static>> {
608    if let Some(server_name) = tls.server_name.as_ref() {
609        return ServerName::try_from(server_name.clone())
610            .with_context(|| format!("invalid TLS server name '{}'", server_name));
611    }
612
613    if let Ok(socket_addr) = address.parse::<SocketAddr>() {
614        return Ok(ServerName::IpAddress(socket_addr.ip().into()));
615    }
616
617    let host = if let Some(stripped) = address.strip_prefix('[') {
618        stripped
619            .split(']')
620            .next()
621            .context("invalid bracketed broker address")?
622            .to_owned()
623    } else {
624        address
625            .rsplit_once(':')
626            .map(|(host, _)| host.to_owned())
627            .unwrap_or_else(|| address.to_owned())
628    };
629
630    ServerName::try_from(host.clone()).with_context(|| {
631        format!("could not derive a valid TLS server name from broker address '{address}'")
632    })
633}
634
635async fn read_frame(stream: &mut BrokerStream) -> Result<Vec<u8>> {
636    let mut header = [0_u8; 4];
637    stream.read_exact(&mut header).await?;
638    let frame_len = i32::from_be_bytes(header);
639    if frame_len < 0 {
640        bail!("broker returned a negative frame length: {frame_len}");
641    }
642
643    let mut payload = vec![0_u8; usize::try_from(frame_len)?];
644    stream.read_exact(&mut payload).await?;
645    Ok(payload)
646}
647
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    #[test]
    fn tls_server_name_defaults_to_host() {
        let server_name =
            server_name_for_tls("broker.example.com:9093", &TlsConfig::default()).unwrap();
        assert_eq!(server_name.to_str(), "broker.example.com");
    }

    #[test]
    fn tls_server_name_respects_explicit_override() {
        let tls = TlsConfig::new().with_server_name("cluster.internal");
        let server_name = server_name_for_tls("127.0.0.1:9093", &tls).unwrap();
        assert_eq!(server_name.to_str(), "cluster.internal");
    }

    #[test]
    fn tls_server_name_handles_ip_and_bracketed_ipv6() {
        let server_name = server_name_for_tls("127.0.0.1:9093", &TlsConfig::default()).unwrap();
        assert_eq!(server_name.to_str(), "127.0.0.1");

        let server_name = server_name_for_tls("[::1]:9093", &TlsConfig::default()).unwrap();
        assert_eq!(server_name.to_str(), "::1");
    }

    #[test]
    fn tls_server_name_rejects_invalid_override_and_empty_address() {
        let tls = TlsConfig::new().with_server_name("not a dns name");
        assert!(server_name_for_tls("127.0.0.1:9093", &tls).is_err());
        assert!(server_name_for_tls("", &TlsConfig::default()).is_err());
    }

    #[test]
    fn plain_sasl_token_requires_credentials_and_uses_authorization_id() {
        assert!(build_plain_sasl_token(&SaslConfig::default()).is_err());
        // `assert_eq!` prints both byte strings on failure, unlike `assert!(a == b)`.
        assert_eq!(
            build_plain_sasl_token(&SaslConfig::plain("user", "pw").with_authorization_id("authz"))
                .unwrap(),
            b"authz\0user\0pw"
        );
    }

    #[test]
    fn tls_file_loaders_reject_missing_empty_and_invalid_pem_files() {
        let dir =
            std::env::temp_dir().join(format!("kafkit-client-tls-test-{}", std::process::id()));
        fs::create_dir_all(&dir).unwrap();
        let cert_path = dir.join("cert.pem");
        let key_path = dir.join("key.pem");
        let empty_path = dir.join("empty.pem");
        let missing_path = dir.join("missing.pem");
        fs::write(&cert_path, b"not a certificate").unwrap();
        fs::write(&key_path, b"not a key").unwrap();
        fs::write(&empty_path, b"").unwrap();

        // Non-PEM contents.
        assert!(load_certificates(&cert_path).is_err());
        assert!(load_private_key(&key_path).is_err());
        // Empty files contain no PEM sections at all.
        assert!(load_certificates(&empty_path).is_err());
        assert!(load_private_key(&empty_path).is_err());
        // Files that do not exist.
        assert!(load_certificates(&missing_path).is_err());
        assert!(load_private_key(&missing_path).is_err());

        let _ = fs::remove_dir_all(dir);
    }
}