1use std::collections::HashMap;
2use std::fs::File;
3use std::future::Future;
4use std::io::BufReader;
5use std::net::SocketAddr;
6use std::path::Path;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::time::Duration;
10
11use anyhow::{Context, Result, bail};
12use bytes::{BufMut, Bytes, BytesMut};
13use kafka_protocol::error::ParseResponseErrorCode;
14use kafka_protocol::messages::{
15 ApiVersionsRequest, RequestHeader, ResponseHeader, SaslAuthenticateRequest,
16 SaslHandshakeRequest,
17};
18use kafka_protocol::protocol::{
19 Decodable, HeaderVersion, Message, Request, StrBytes, VersionRange,
20 encode_request_header_into_buffer,
21};
22use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
23use rustls::{ClientConfig as RustlsClientConfig, RootCertStore};
24use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
25use tokio::net::TcpStream;
26use tokio_rustls::TlsConnector;
27use tokio_rustls::client::TlsStream;
28use tracing::{Instrument, debug, trace, trace_span};
29
30use super::scram::ScramClient;
31use super::select_api_version;
32use crate::config::{SaslConfig, SaslMechanism, SecurityProtocol, TlsConfig};
33use crate::constants::{API_VERSIONS_FALLBACK_VERSION, API_VERSIONS_PROBE_VERSION};
34
35pub async fn connect_to_any_bootstrap(
36 servers: &[String],
37 client_id: &str,
38 timeout: Duration,
39 security_protocol: SecurityProtocol,
40 tls: &TlsConfig,
41 sasl: &SaslConfig,
42 tcp_connector: &Arc<dyn TcpConnector>,
43) -> Result<BrokerConnection> {
44 if servers.is_empty() {
45 bail!("no bootstrap servers configured");
46 }
47
48 let mut last_error: Option<anyhow::Error> = None;
49 for server in servers {
50 match BrokerConnection::connect_with_transport(
51 server,
52 client_id,
53 timeout,
54 security_protocol,
55 tls,
56 sasl,
57 tcp_connector,
58 )
59 .await
60 {
61 Ok(conn) => return Ok(conn),
62 Err(e) => {
63 debug!(server = %server, error = %e, "bootstrap connection failed, trying next server");
64 last_error = Some(e);
65 }
66 }
67 }
68 Err(last_error.unwrap())
69}
70
71type ConnectFuture<'a> = Pin<Box<dyn Future<Output = Result<ConnectedTcpStream>> + Send + 'a>>;
72
73pub trait BrokerIo: AsyncRead + AsyncWrite + Unpin + Send {}
74
75impl<T> BrokerIo for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
76
77pub enum ConnectedTcpStream {
78 Tokio(TcpStream),
79 Custom(Box<dyn BrokerIo>),
80}
81
82impl ConnectedTcpStream {
83 fn set_nodelay(&self, nodelay: bool) -> Result<()> {
84 match self {
85 Self::Tokio(stream) => stream.set_nodelay(nodelay)?,
86 Self::Custom(_) => {}
87 }
88 Ok(())
89 }
90}
91
92pub trait TcpConnector: std::fmt::Debug + Send + Sync {
93 fn connect<'a>(&'a self, address: &'a str, timeout: Duration) -> ConnectFuture<'a>;
94}
95
96#[derive(Debug, Default)]
97pub struct TokioTcpConnector;
98
99impl TcpConnector for TokioTcpConnector {
100 fn connect<'a>(&'a self, address: &'a str, timeout: Duration) -> ConnectFuture<'a> {
101 Box::pin(async move {
102 let tcp_stream = tokio::time::timeout(timeout, TcpStream::connect(address))
103 .await
104 .with_context(|| format!("timed out connecting to {address}"))?
105 .with_context(|| format!("failed to connect to {address}"))?;
106 Ok(ConnectedTcpStream::Tokio(tcp_stream))
107 })
108 }
109}
110
111pub struct BrokerConnection {
112 stream: BrokerStream,
113 next_correlation_id: i32,
114 api_versions: HashMap<i16, VersionRange>,
115 finalized_features: HashMap<String, i16>,
116}
117
118enum BrokerStream {
119 Plain(Box<dyn BrokerIo>),
120 Tls(Box<TlsStream<TcpStream>>),
121}
122
123impl BrokerStream {
124 async fn write_all(&mut self, frame: &[u8]) -> Result<()> {
125 match self {
126 Self::Plain(stream) => stream.write_all(frame).await?,
127 Self::Tls(stream) => stream.write_all(frame).await?,
128 }
129 Ok(())
130 }
131
132 async fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
133 match self {
134 Self::Plain(stream) => {
135 stream.read_exact(buf).await?;
136 }
137 Self::Tls(stream) => {
138 stream.read_exact(buf).await?;
139 }
140 };
141 Ok(())
142 }
143}
144
145impl BrokerConnection {
146 pub async fn connect_with_transport(
147 address: &str,
148 client_id: &str,
149 timeout: Duration,
150 security_protocol: SecurityProtocol,
151 tls: &TlsConfig,
152 sasl: &SaslConfig,
153 tcp_connector: &Arc<dyn TcpConnector>,
154 ) -> Result<Self> {
155 async {
156 debug!(?security_protocol, "connecting to broker");
157 let stream =
158 connect_stream(address, timeout, security_protocol, tls, tcp_connector).await?;
159 let mut connection = Self {
160 stream,
161 next_correlation_id: 1,
162 api_versions: HashMap::new(),
163 finalized_features: HashMap::new(),
164 };
165 if security_protocol.uses_sasl() {
166 connection.authenticate_sasl(client_id, sasl).await?;
167 }
168 connection.negotiate_versions(client_id).await?;
169 debug!(
170 api_keys = connection.api_versions.len(),
171 finalized_features = connection.finalized_features.len(),
172 ?security_protocol,
173 "connected to broker"
174 );
175 Ok(connection)
176 }
177 .instrument(tracing::debug_span!(
178 "broker_connect",
179 %address,
180 %client_id,
181 timeout_ms = timeout.as_millis()
182 ))
183 .await
184 }
185
186 async fn authenticate_sasl(&mut self, client_id: &str, sasl: &SaslConfig) -> Result<()> {
187 let response = self
188 .send_request::<ApiVersionsRequest>(
189 client_id,
190 API_VERSIONS_FALLBACK_VERSION,
191 &ApiVersionsRequest::default(),
192 )
193 .await
194 .context("SASL ApiVersions probe failed")?;
195 if let Some(error) = response.error_code.err() {
196 bail!("SASL ApiVersions probe failed: {error}");
197 }
198
199 let api_versions = response
200 .api_keys
201 .into_iter()
202 .map(|api| {
203 (
204 api.api_key,
205 VersionRange {
206 min: api.min_version,
207 max: api.max_version,
208 },
209 )
210 })
211 .collect::<HashMap<_, _>>();
212
213 let handshake_version = api_versions
214 .get(&SaslHandshakeRequest::KEY)
215 .copied()
216 .map(|range| {
217 select_api_version(
218 SaslHandshakeRequest::KEY,
219 range,
220 SaslHandshakeRequest::VERSIONS,
221 SaslHandshakeRequest::VERSIONS.max,
222 )
223 })
224 .transpose()?
225 .unwrap_or(0);
226 let authenticate_version = api_versions
227 .get(&SaslAuthenticateRequest::KEY)
228 .copied()
229 .map(|range| {
230 select_api_version(
231 SaslAuthenticateRequest::KEY,
232 range,
233 SaslAuthenticateRequest::VERSIONS,
234 SaslAuthenticateRequest::VERSIONS.max,
235 )
236 })
237 .transpose()?;
238
239 let mechanism = sasl.mechanism.as_str();
240 let handshake =
241 SaslHandshakeRequest::default().with_mechanism(StrBytes::from_static_str(mechanism));
242 let response = self
243 .send_request::<SaslHandshakeRequest>(client_id, handshake_version, &handshake)
244 .await
245 .context("SASL handshake request failed")?;
246 if let Some(error) = response.error_code.err() {
247 let enabled = response
248 .mechanisms
249 .iter()
250 .map(ToString::to_string)
251 .collect::<Vec<_>>()
252 .join(", ");
253 bail!(
254 "SASL handshake failed for mechanism {mechanism}: {error}; enabled mechanisms: [{enabled}]"
255 );
256 }
257
258 match sasl.mechanism {
259 SaslMechanism::Plain => {
260 let token = build_plain_sasl_token(sasl)?;
261 if let Some(version) = authenticate_version {
262 self.send_sasl_authenticate(client_id, version, mechanism, token)
263 .await?;
264 } else {
265 write_raw_sasl_token(&mut self.stream, &token).await?;
266 }
267 }
268 SaslMechanism::ScramSha256 | SaslMechanism::ScramSha512 => {
269 self.authenticate_scram(client_id, sasl, authenticate_version)
270 .await?;
271 }
272 }
273
274 debug!(mechanism, "completed SASL authentication");
275 Ok(())
276 }
277
278 async fn authenticate_scram(
279 &mut self,
280 client_id: &str,
281 sasl: &SaslConfig,
282 authenticate_version: Option<i16>,
283 ) -> Result<()> {
284 let username = sasl
285 .username
286 .as_ref()
287 .context("SASL/SCRAM requires a username")?
288 .clone();
289 let password = sasl
290 .password
291 .as_ref()
292 .context("SASL/SCRAM requires a password")?
293 .clone();
294 let mechanism = sasl.mechanism.as_str();
295 let mut scram = ScramClient::new(sasl.mechanism, username, password)?;
296 let client_first = scram.client_first_message();
297
298 let server_first = if let Some(version) = authenticate_version {
299 self.send_sasl_authenticate(client_id, version, mechanism, client_first)
300 .await?
301 } else {
302 write_raw_sasl_token(&mut self.stream, &client_first).await?;
303 read_frame(&mut self.stream).await?
304 };
305 let client_final = scram.handle_server_first_message(&server_first)?;
306
307 let server_final = if let Some(version) = authenticate_version {
308 self.send_sasl_authenticate(client_id, version, mechanism, client_final)
309 .await?
310 } else {
311 write_raw_sasl_token(&mut self.stream, &client_final).await?;
312 read_frame(&mut self.stream).await?
313 };
314 scram.handle_server_final_message(&server_final)?;
315 Ok(())
316 }
317
318 async fn send_sasl_authenticate(
319 &mut self,
320 client_id: &str,
321 version: i16,
322 mechanism: &str,
323 token: Vec<u8>,
324 ) -> Result<Vec<u8>> {
325 let request = SaslAuthenticateRequest::default().with_auth_bytes(Bytes::from(token));
326 let response = self
327 .send_request::<SaslAuthenticateRequest>(client_id, version, &request)
328 .await
329 .context("SASL authenticate request failed")?;
330 if let Some(error) = response.error_code.err() {
331 let message = response
332 .error_message
333 .as_ref()
334 .map(ToString::to_string)
335 .filter(|message| !message.is_empty())
336 .unwrap_or_else(|| error.to_string());
337 bail!("SASL authentication failed for mechanism {mechanism}: {message}");
338 }
339 Ok(response.auth_bytes.to_vec())
340 }
341
342 pub fn version_with_cap<Req>(&self, cap: i16) -> Result<i16>
343 where
344 Req: Request,
345 {
346 let broker_range = self
347 .api_versions
348 .get(&Req::KEY)
349 .copied()
350 .with_context(|| format!("broker did not advertise API key {}", Req::KEY))?;
351 select_api_version(Req::KEY, broker_range, Req::VERSIONS, cap)
352 }
353
354 pub fn finalized_feature_level(&self, feature: &str) -> Option<i16> {
355 self.finalized_features.get(feature).copied()
356 }
357
358 pub fn finalized_feature_levels(&self) -> Vec<(String, i16)> {
359 let mut features = self
360 .finalized_features
361 .iter()
362 .map(|(name, level)| (name.clone(), *level))
363 .collect::<Vec<_>>();
364 features.sort_by(|left, right| left.0.cmp(&right.0));
365 features
366 }
367
368 async fn negotiate_versions(&mut self, client_id: &str) -> Result<()> {
369 let modern_request = ApiVersionsRequest::default()
370 .with_client_software_name(StrBytes::from_static_str("kafkit-client"))
371 .with_client_software_version(StrBytes::from_static_str("0.2.0"));
372
373 let response = match self
374 .send_request::<ApiVersionsRequest>(
375 client_id,
376 API_VERSIONS_PROBE_VERSION,
377 &modern_request,
378 )
379 .await
380 {
381 Ok(response) => response,
382 Err(error) => {
383 debug!(
384 error = %error,
385 "modern ApiVersions probe failed, retrying with fallback request"
386 );
387 self.send_request::<ApiVersionsRequest>(
388 client_id,
389 API_VERSIONS_FALLBACK_VERSION,
390 &ApiVersionsRequest::default(),
391 )
392 .await?
393 }
394 };
395
396 if let Some(error) = response.error_code.err() {
397 bail!("ApiVersions failed: {error}");
398 }
399
400 self.api_versions = response
401 .api_keys
402 .into_iter()
403 .map(|api| {
404 (
405 api.api_key,
406 VersionRange {
407 min: api.min_version,
408 max: api.max_version,
409 },
410 )
411 })
412 .collect();
413 self.finalized_features = response
414 .finalized_features
415 .into_iter()
416 .map(|feature| (feature.name.to_string(), feature.max_version_level))
417 .collect();
418
419 trace!(
420 api_keys = self.api_versions.len(),
421 finalized_features = self.finalized_features.len(),
422 "negotiated broker ApiVersions"
423 );
424 Ok(())
425 }
426
427 pub async fn send_request<Req>(
428 &mut self,
429 client_id: &str,
430 version: i16,
431 request: &Req,
432 ) -> Result<Req::Response>
433 where
434 Req: Request,
435 {
436 let correlation_id = self.next_correlation_id;
437 self.next_correlation_id += 1;
438 let span = trace_span!(
439 "kafka_request",
440 request = std::any::type_name::<Req>(),
441 api_key = Req::KEY,
442 api_version = version,
443 correlation_id,
444 %client_id
445 );
446
447 async {
448 let mut body = BytesMut::new();
449 let header = RequestHeader::default()
450 .with_request_api_key(Req::KEY)
451 .with_request_api_version(version)
452 .with_correlation_id(correlation_id)
453 .with_client_id(Some(StrBytes::from_string(client_id.to_owned())));
454 encode_request_header_into_buffer(&mut body, &header)?;
455 request.encode(&mut body, version)?;
456
457 trace!(request_bytes = body.len(), "encoded Kafka request");
458
459 let mut frame = BytesMut::with_capacity(body.len() + 4);
460 frame.put_i32(i32::try_from(body.len()).context("request frame is too large")?);
461 frame.extend_from_slice(&body);
462
463 self.stream.write_all(&frame).await?;
464 trace!(frame_bytes = frame.len(), "wrote Kafka request frame");
465
466 let response_frame = read_frame(&mut self.stream).await?;
467 trace!(
468 response_bytes = response_frame.len(),
469 "received Kafka response frame"
470 );
471 let mut response_body = Bytes::from(response_frame);
472 let header_version = Req::Response::header_version(version);
473 let response_header = ResponseHeader::decode(&mut response_body, header_version)?;
474 if response_header.correlation_id != correlation_id {
475 bail!(
476 "response correlation mismatch: expected {}, got {}",
477 correlation_id,
478 response_header.correlation_id
479 );
480 }
481
482 let response = Req::Response::decode(&mut response_body, version)?;
483 debug!("completed Kafka request");
484 Ok(response)
485 }
486 .instrument(span)
487 .await
488 }
489}
490
491async fn connect_stream(
492 address: &str,
493 timeout: Duration,
494 security_protocol: SecurityProtocol,
495 tls: &TlsConfig,
496 tcp_connector: &Arc<dyn TcpConnector>,
497) -> Result<BrokerStream> {
498 let tcp_stream = tcp_connector.connect(address, timeout).await?;
499 tcp_stream
500 .set_nodelay(true)
501 .with_context(|| format!("failed to enable TCP_NODELAY for {address}"))?;
502
503 if security_protocol.uses_tls() {
504 let ConnectedTcpStream::Tokio(tcp_stream) = tcp_stream else {
505 bail!("custom TCP connectors do not support TLS broker connections");
506 };
507 let tls_config = build_tls_client_config(tls)?;
508 let connector = TlsConnector::from(tls_config);
509 let server_name = server_name_for_tls(address, tls)?;
510 let stream = tokio::time::timeout(timeout, connector.connect(server_name, tcp_stream))
511 .await
512 .with_context(|| format!("timed out negotiating TLS with {address}"))?
513 .with_context(|| format!("failed TLS handshake with {address}"))?;
514 Ok(BrokerStream::Tls(Box::new(stream)))
515 } else {
516 match tcp_stream {
517 ConnectedTcpStream::Tokio(stream) => Ok(BrokerStream::Plain(Box::new(stream))),
518 ConnectedTcpStream::Custom(stream) => Ok(BrokerStream::Plain(stream)),
519 }
520 }
521}
522
523fn build_plain_sasl_token(sasl: &SaslConfig) -> Result<Vec<u8>> {
524 let username = sasl
525 .username
526 .as_deref()
527 .context("SASL/PLAIN requires a username")?;
528 let password = sasl
529 .password
530 .as_deref()
531 .context("SASL/PLAIN requires a password")?;
532 let authorization_id = sasl.authorization_id.as_deref().unwrap_or_default();
533
534 let mut token =
535 Vec::with_capacity(authorization_id.len() + username.len() + password.len() + 2);
536 token.extend_from_slice(authorization_id.as_bytes());
537 token.push(0);
538 token.extend_from_slice(username.as_bytes());
539 token.push(0);
540 token.extend_from_slice(password.as_bytes());
541 Ok(token)
542}
543
544async fn write_raw_sasl_token(stream: &mut BrokerStream, token: &[u8]) -> Result<()> {
545 let mut frame = BytesMut::with_capacity(token.len() + 4);
546 frame.put_i32(i32::try_from(token.len()).context("SASL token frame is too large")?);
547 frame.extend_from_slice(token);
548 stream.write_all(&frame).await
549}
550
551fn build_tls_client_config(tls: &TlsConfig) -> Result<Arc<RustlsClientConfig>> {
552 let mut root_store = RootCertStore::empty();
553 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
554
555 if let Some(ca_cert_path) = tls.ca_cert_path.as_deref() {
556 for cert in load_certificates(ca_cert_path)? {
557 root_store.add(cert)?;
558 }
559 }
560
561 let builder = RustlsClientConfig::builder().with_root_certificates(root_store);
562 let config = match (
563 tls.client_cert_path.as_deref(),
564 tls.client_key_path.as_deref(),
565 ) {
566 (Some(client_cert_path), Some(client_key_path)) => builder.with_client_auth_cert(
567 load_certificates(client_cert_path)?,
568 load_private_key(client_key_path)?,
569 )?,
570 (None, None) => builder.with_no_client_auth(),
571 _ => bail!("TLS client auth requires both client_cert_path and client_key_path"),
572 };
573
574 Ok(Arc::new(config))
575}
576
577fn load_certificates(path: &Path) -> Result<Vec<CertificateDer<'static>>> {
578 let file = File::open(path)
579 .with_context(|| format!("failed to open TLS certificate file '{}'", path.display()))?;
580 let mut reader = BufReader::new(file);
581 let certs = rustls_pemfile::certs(&mut reader)
582 .collect::<std::result::Result<Vec<_>, _>>()
583 .with_context(|| format!("failed to parse TLS certificate PEM '{}'", path.display()))?;
584 if certs.is_empty() {
585 bail!(
586 "TLS certificate file '{}' did not contain any PEM certificates",
587 path.display()
588 );
589 }
590 Ok(certs)
591}
592
593fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>> {
594 let file = File::open(path)
595 .with_context(|| format!("failed to open TLS private key file '{}'", path.display()))?;
596 let mut reader = BufReader::new(file);
597 rustls_pemfile::private_key(&mut reader)
598 .with_context(|| format!("failed to parse TLS private key PEM '{}'", path.display()))?
599 .with_context(|| {
600 format!(
601 "TLS private key file '{}' did not contain a PEM key",
602 path.display()
603 )
604 })
605}
606
607fn server_name_for_tls(address: &str, tls: &TlsConfig) -> Result<ServerName<'static>> {
608 if let Some(server_name) = tls.server_name.as_ref() {
609 return ServerName::try_from(server_name.clone())
610 .with_context(|| format!("invalid TLS server name '{}'", server_name));
611 }
612
613 if let Ok(socket_addr) = address.parse::<SocketAddr>() {
614 return Ok(ServerName::IpAddress(socket_addr.ip().into()));
615 }
616
617 let host = if let Some(stripped) = address.strip_prefix('[') {
618 stripped
619 .split(']')
620 .next()
621 .context("invalid bracketed broker address")?
622 .to_owned()
623 } else {
624 address
625 .rsplit_once(':')
626 .map(|(host, _)| host.to_owned())
627 .unwrap_or_else(|| address.to_owned())
628 };
629
630 ServerName::try_from(host.clone()).with_context(|| {
631 format!("could not derive a valid TLS server name from broker address '{address}'")
632 })
633}
634
635async fn read_frame(stream: &mut BrokerStream) -> Result<Vec<u8>> {
636 let mut header = [0_u8; 4];
637 stream.read_exact(&mut header).await?;
638 let frame_len = i32::from_be_bytes(header);
639 if frame_len < 0 {
640 bail!("broker returned a negative frame length: {frame_len}");
641 }
642
643 let mut payload = vec![0_u8; usize::try_from(frame_len)?];
644 stream.read_exact(&mut payload).await?;
645 Ok(payload)
646}
647
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    #[test]
    fn tls_server_name_defaults_to_host() {
        let defaults = TlsConfig::default();
        let name = server_name_for_tls("broker.example.com:9093", &defaults).unwrap();
        assert_eq!(name.to_str(), "broker.example.com");
    }

    #[test]
    fn tls_server_name_respects_explicit_override() {
        let overridden = TlsConfig::new().with_server_name("cluster.internal");
        let name = server_name_for_tls("127.0.0.1:9093", &overridden).unwrap();
        assert_eq!(name.to_str(), "cluster.internal");
    }

    #[test]
    fn tls_server_name_handles_ip_and_bracketed_ipv6() {
        let defaults = TlsConfig::default();
        for (address, expected) in [("127.0.0.1:9093", "127.0.0.1"), ("[::1]:9093", "::1")] {
            let name = server_name_for_tls(address, &defaults).unwrap();
            assert_eq!(name.to_str(), expected);
        }
    }

    #[test]
    fn tls_server_name_rejects_invalid_override_and_empty_address() {
        let invalid_override = TlsConfig::new().with_server_name("not a dns name");
        assert!(server_name_for_tls("127.0.0.1:9093", &invalid_override).is_err());
        assert!(server_name_for_tls("", &TlsConfig::default()).is_err());
    }

    #[test]
    fn plain_sasl_token_requires_credentials_and_uses_authorization_id() {
        assert!(build_plain_sasl_token(&SaslConfig::default()).is_err());

        let config = SaslConfig::plain("user", "pw").with_authorization_id("authz");
        let token = build_plain_sasl_token(&config).unwrap();
        assert_eq!(token, b"authz\0user\0pw");
    }

    #[test]
    fn tls_file_loaders_reject_missing_empty_and_invalid_pem_files() {
        let scratch =
            std::env::temp_dir().join(format!("kafkit-client-tls-test-{}", std::process::id()));
        fs::create_dir_all(&scratch).unwrap();
        let cert_path = scratch.join("cert.pem");
        let key_path = scratch.join("key.pem");
        fs::write(&cert_path, b"not a certificate").unwrap();
        fs::write(&key_path, b"not a key").unwrap();

        assert!(load_certificates(&cert_path).is_err());
        assert!(load_private_key(&key_path).is_err());
        assert!(load_certificates(&scratch.join("missing.pem")).is_err());

        // Best-effort cleanup; a leftover temp dir must not fail the test.
        let _ = fs::remove_dir_all(scratch);
    }
}