use std::collections::HashMap;
use std::fs::File;
use std::future::Future;
use std::io::BufReader;
use std::net::SocketAddr;
use std::path::Path;
use std::pin::Pin;
use std::sync::{Arc, Once};
use std::time::{Duration, Instant};
use anyhow::{Context, Result, bail};
use bytes::{BufMut, Bytes, BytesMut};
use kafka_protocol::error::ParseResponseErrorCode;
use kafka_protocol::messages::{
ApiVersionsRequest, RequestHeader, ResponseHeader, SaslAuthenticateRequest,
SaslHandshakeRequest,
};
use kafka_protocol::protocol::{
Decodable, HeaderVersion, Message, Request, StrBytes, VersionRange,
encode_request_header_into_buffer,
};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use rustls::{ClientConfig as RustlsClientConfig, RootCertStore};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
use tokio_rustls::client::TlsStream;
use tracing::{Instrument, debug, trace, trace_span};
use super::scram::ScramClient;
use super::select_api_version;
use crate::config::{SaslConfig, SaslMechanism, SecurityProtocol, TlsConfig};
use crate::constants::{API_VERSIONS_FALLBACK_VERSION, API_VERSIONS_PROBE_VERSION};
use crate::telemetry;
pub async fn connect_to_any_bootstrap(
servers: &[String],
client_id: &str,
timeout: Duration,
security_protocol: SecurityProtocol,
tls: &TlsConfig,
sasl: &SaslConfig,
tcp_connector: &Arc<dyn TcpConnector>,
) -> Result<BrokerConnection> {
if servers.is_empty() {
bail!("no bootstrap servers configured");
}
let mut last_error: Option<anyhow::Error> = None;
for server in servers {
match BrokerConnection::connect_with_transport(
server,
client_id,
timeout,
security_protocol,
tls,
sasl,
tcp_connector,
)
.await
{
Ok(conn) => return Ok(conn),
Err(e) => {
debug!(server = %server, error = %e, "bootstrap connection failed, trying next server");
last_error = Some(e);
}
}
}
Err(last_error.unwrap())
}
/// Boxed `Send` future produced by [`TcpConnector::connect`]. It may borrow the connector and
/// the address for `'a` and resolves to the connected stream.
type ConnectFuture<'a> = Pin<Box<dyn Future<Output = Result<ConnectedTcpStream>> + Send + 'a>>;
/// Byte stream usable as a plaintext broker transport: anything that is `AsyncRead +
/// AsyncWrite + Unpin + Send`.
pub trait BrokerIo: AsyncRead + AsyncWrite + Unpin + Send {}
// Blanket impl: every type satisfying the bounds is automatically a `BrokerIo`, so it never has
// to be implemented by hand.
impl<T> BrokerIo for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
/// Stream returned by a [`TcpConnector`] once the connection is established.
pub enum ConnectedTcpStream {
    /// A Tokio TCP socket. This is the only variant `connect_stream` can wrap in TLS.
    Tokio(TcpStream),
    /// An arbitrary transport supplied by a custom connector, usable for plaintext only.
    Custom(Box<dyn BrokerIo>),
}
impl ConnectedTcpStream {
    /// Toggles `TCP_NODELAY` on Tokio sockets.
    ///
    /// Custom transports are left untouched and always report success.
    fn set_nodelay(&self, nodelay: bool) -> Result<()> {
        match self {
            Self::Tokio(socket) => Ok(socket.set_nodelay(nodelay)?),
            Self::Custom(_) => Ok(()),
        }
    }
}
/// Pluggable strategy for opening the byte stream to a broker.
///
/// Implementations must be `Debug + Send + Sync` so they can be shared as
/// `Arc<dyn TcpConnector>`.
pub trait TcpConnector: std::fmt::Debug + Send + Sync {
    /// Opens a connection to `address`. `timeout` is the connect timeout requested by the caller.
    ///
    /// Returning [`ConnectedTcpStream::Custom`] only works for plaintext security protocols;
    /// TLS connections are rejected for custom streams.
    fn connect<'a>(&'a self, address: &'a str, timeout: Duration) -> ConnectFuture<'a>;
}
/// Default [`TcpConnector`], which opens a plain Tokio [`TcpStream`].
#[derive(Debug, Default)]
pub struct TokioTcpConnector;
impl TcpConnector for TokioTcpConnector {
    /// Connects with `TcpStream::connect` and gives up once `timeout` elapses.
    fn connect<'a>(&'a self, address: &'a str, timeout: Duration) -> ConnectFuture<'a> {
        Box::pin(async move {
            let pending = TcpStream::connect(address);
            // The outer Result reports the timeout; the inner one reports the connect error.
            let connected = tokio::time::timeout(timeout, pending)
                .await
                .with_context(|| format!("timed out connecting to {address}"))?;
            let socket = connected.with_context(|| format!("failed to connect to {address}"))?;
            Ok(ConnectedTcpStream::Tokio(socket))
        })
    }
}
/// An open connection to a single Kafka broker.
///
/// When the security protocol asks for it, the connection is TLS-wrapped and
/// SASL-authenticated. It also stores the API versions and finalized features the broker
/// advertised.
pub struct BrokerConnection {
    /// Transport that length-prefixed Kafka frames are written to and read from.
    stream: BrokerStream,
    /// Correlation id to stamp on the next request. `connect_with_transport` starts it at 1.
    next_correlation_id: i32,
    /// Broker-advertised version range per API key, filled in by `negotiate_versions`.
    api_versions: HashMap<i16, VersionRange>,
    /// Finalized feature name mapped to its max version level, from the ApiVersions response.
    finalized_features: HashMap<String, i16>,
}
/// Transport underneath a [`BrokerConnection`].
enum BrokerStream {
    /// Plaintext transport: a Tokio socket or a custom [`BrokerIo`] from a connector.
    Plain(Box<dyn BrokerIo>),
    /// rustls client stream over a Tokio TCP socket.
    Tls(Box<TlsStream<TcpStream>>),
}
impl BrokerStream {
    /// Writes the whole `frame`, then flushes it so the bytes actually leave the process.
    ///
    /// `AsyncWrite` implementations may hold written data until `flush` is called. A TLS stream
    /// can keep encrypted records queued when the socket is not writable, and a custom
    /// `BrokerIo` may be a buffered writer. Callers typically wait for a response right after
    /// writing, so an unflushed frame could leave the connection waiting forever.
    async fn write_all(&mut self, frame: &[u8]) -> Result<()> {
        match self {
            Self::Plain(stream) => {
                stream.write_all(frame).await?;
                stream.flush().await?;
            }
            Self::Tls(stream) => {
                stream.write_all(frame).await?;
                stream.flush().await?;
            }
        }
        Ok(())
    }
    /// Fills `buf` completely from the transport, failing if the stream errors or ends first.
    async fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
        match self {
            Self::Plain(stream) => {
                stream.read_exact(buf).await?;
            }
            Self::Tls(stream) => {
                stream.read_exact(buf).await?;
            }
        }
        Ok(())
    }
}
impl BrokerConnection {
    /// Connects to the broker at `address` and prepares the connection for requests.
    ///
    /// The transport is opened with `connect_stream`, going through `tcp_connector` and wrapped
    /// in TLS when `security_protocol` uses TLS. SASL authentication then runs if
    /// `security_protocol` uses SASL, and API versions are negotiated last. The whole attempt
    /// runs inside a `broker_connect` tracing span. Its outcome and duration are reported via
    /// `telemetry::record_broker_connection`.
    ///
    /// # Errors
    ///
    /// Returns the first failure from connecting, authenticating or negotiating versions.
    pub async fn connect_with_transport(
        address: &str,
        client_id: &str,
        timeout: Duration,
        security_protocol: SecurityProtocol,
        tls: &TlsConfig,
        sasl: &SaslConfig,
        tcp_connector: &Arc<dyn TcpConnector>,
    ) -> Result<Self> {
        let started = Instant::now();
        let result = async {
            debug!(?security_protocol, "connecting to broker");
            let stream =
                connect_stream(address, timeout, security_protocol, tls, tcp_connector).await?;
            let mut connection = Self {
                stream,
                next_correlation_id: 1,
                api_versions: HashMap::new(),
                finalized_features: HashMap::new(),
            };
            if security_protocol.uses_sasl() {
                connection.authenticate_sasl(client_id, sasl).await?;
            }
            connection.negotiate_versions(client_id).await?;
            debug!(
                api_keys = connection.api_versions.len(),
                finalized_features = connection.finalized_features.len(),
                ?security_protocol,
                "connected to broker"
            );
            Ok(connection)
        }
        .instrument(tracing::debug_span!(
            "broker_connect",
            %address,
            %client_id,
            timeout_ms = timeout.as_millis()
        ))
        .await;
        telemetry::record_broker_connection(
            client_id,
            address,
            &format!("{security_protocol:?}"),
            started.elapsed(),
            result.is_ok(),
        );
        result
    }

    /// Runs the SASL exchange configured in `sasl` on the freshly opened stream.
    ///
    /// An ApiVersions probe at `API_VERSIONS_FALLBACK_VERSION` finds out which SaslHandshake and
    /// SaslAuthenticate versions the broker supports. SaslHandshake falls back to version 0 when
    /// it is not advertised. When SaslAuthenticate is not advertised, the mechanism's tokens are
    /// exchanged as raw length-prefixed frames instead of SaslAuthenticate requests.
    ///
    /// # Errors
    ///
    /// Fails on transport errors, on error codes in the probe, handshake or authenticate
    /// responses, on missing credentials, or when the SCRAM client rejects a server message.
    async fn authenticate_sasl(&mut self, client_id: &str, sasl: &SaslConfig) -> Result<()> {
        let response = self
            .send_request::<ApiVersionsRequest>(
                client_id,
                API_VERSIONS_FALLBACK_VERSION,
                &ApiVersionsRequest::default(),
            )
            .await
            .context("SASL ApiVersions probe failed")?;
        if let Some(error) = response.error_code.err() {
            bail!("SASL ApiVersions probe failed: {error}");
        }
        let api_versions = response
            .api_keys
            .into_iter()
            .map(|api| {
                (
                    api.api_key,
                    VersionRange {
                        min: api.min_version,
                        max: api.max_version,
                    },
                )
            })
            .collect::<HashMap<_, _>>();
        let handshake_version = api_versions
            .get(&SaslHandshakeRequest::KEY)
            .copied()
            .map(|range| {
                select_api_version(
                    SaslHandshakeRequest::KEY,
                    range,
                    SaslHandshakeRequest::VERSIONS,
                    SaslHandshakeRequest::VERSIONS.max,
                )
            })
            .transpose()?
            .unwrap_or(0);
        // `None` selects the legacy raw-token exchange below.
        let authenticate_version = api_versions
            .get(&SaslAuthenticateRequest::KEY)
            .copied()
            .map(|range| {
                select_api_version(
                    SaslAuthenticateRequest::KEY,
                    range,
                    SaslAuthenticateRequest::VERSIONS,
                    SaslAuthenticateRequest::VERSIONS.max,
                )
            })
            .transpose()?;
        let mechanism = sasl.mechanism.as_str();
        let handshake =
            SaslHandshakeRequest::default().with_mechanism(StrBytes::from_static_str(mechanism));
        let response = self
            .send_request::<SaslHandshakeRequest>(client_id, handshake_version, &handshake)
            .await
            .context("SASL handshake request failed")?;
        if let Some(error) = response.error_code.err() {
            let enabled = response
                .mechanisms
                .iter()
                .map(ToString::to_string)
                .collect::<Vec<_>>()
                .join(", ");
            bail!(
                "SASL handshake failed for mechanism {mechanism}: {error}; enabled mechanisms: [{enabled}]"
            );
        }
        match sasl.mechanism {
            SaslMechanism::Plain => {
                let token = build_plain_sasl_token(sasl)?;
                if let Some(version) = authenticate_version {
                    self.send_sasl_authenticate(client_id, version, mechanism, token)
                        .await?;
                } else {
                    // Legacy exchange: the broker answers the raw token with a raw
                    // length-prefixed frame (typically empty for PLAIN). It must be consumed
                    // here, as the SCRAM raw path does; otherwise it would be read as the
                    // response to the next Kafka request and break version negotiation.
                    write_raw_sasl_token(&mut self.stream, &token).await?;
                    let reply = read_frame(&mut self.stream)
                        .await
                        .context("SASL/PLAIN authentication failed")?;
                    trace!(reply_bytes = reply.len(), "received raw SASL/PLAIN reply");
                }
            }
            SaslMechanism::ScramSha256 | SaslMechanism::ScramSha512 => {
                self.authenticate_scram(client_id, sasl, authenticate_version)
                    .await?;
            }
        }
        debug!(mechanism, "completed SASL authentication");
        Ok(())
    }

    /// Performs the two-round SCRAM exchange (client-first/server-first, then
    /// client-final/server-final).
    ///
    /// Uses SaslAuthenticate requests when `authenticate_version` is `Some`, and raw
    /// length-prefixed frames otherwise.
    ///
    /// # Errors
    ///
    /// Fails when the username or password is missing, on transport/broker errors, or when
    /// `ScramClient` rejects a server message.
    async fn authenticate_scram(
        &mut self,
        client_id: &str,
        sasl: &SaslConfig,
        authenticate_version: Option<i16>,
    ) -> Result<()> {
        let username = sasl
            .username
            .as_ref()
            .context("SASL/SCRAM requires a username")?
            .clone();
        let password = sasl
            .password
            .as_ref()
            .context("SASL/SCRAM requires a password")?
            .clone();
        let mechanism = sasl.mechanism.as_str();
        let mut scram = ScramClient::new(sasl.mechanism, username, password)?;
        let client_first = scram.client_first_message();
        let server_first = if let Some(version) = authenticate_version {
            self.send_sasl_authenticate(client_id, version, mechanism, client_first)
                .await?
        } else {
            write_raw_sasl_token(&mut self.stream, &client_first).await?;
            read_frame(&mut self.stream).await?
        };
        let client_final = scram.handle_server_first_message(&server_first)?;
        let server_final = if let Some(version) = authenticate_version {
            self.send_sasl_authenticate(client_id, version, mechanism, client_final)
                .await?
        } else {
            write_raw_sasl_token(&mut self.stream, &client_final).await?;
            read_frame(&mut self.stream).await?
        };
        scram.handle_server_final_message(&server_final)?;
        Ok(())
    }

    /// Sends one SaslAuthenticate request carrying `token` and returns the broker's auth bytes.
    ///
    /// # Errors
    ///
    /// Fails on transport errors or when the response carries an error code. The broker's error
    /// message is used when it is non-empty; otherwise the error code's text is used.
    async fn send_sasl_authenticate(
        &mut self,
        client_id: &str,
        version: i16,
        mechanism: &str,
        token: Vec<u8>,
    ) -> Result<Vec<u8>> {
        let request = SaslAuthenticateRequest::default().with_auth_bytes(Bytes::from(token));
        let response = self
            .send_request::<SaslAuthenticateRequest>(client_id, version, &request)
            .await
            .context("SASL authenticate request failed")?;
        if let Some(error) = response.error_code.err() {
            let message = response
                .error_message
                .as_ref()
                .map(ToString::to_string)
                .filter(|message| !message.is_empty())
                .unwrap_or_else(|| error.to_string());
            bail!("SASL authentication failed for mechanism {mechanism}: {message}");
        }
        Ok(response.auth_bytes.to_vec())
    }

    /// Chooses the version to use for `Req`, capped at `cap`.
    ///
    /// Delegates to `select_api_version` with the broker's advertised range and `Req::VERSIONS`.
    ///
    /// # Errors
    ///
    /// Fails if the broker did not advertise `Req::KEY`, or if `select_api_version` fails.
    pub fn version_with_cap<Req>(&self, cap: i16) -> Result<i16>
    where
        Req: Request,
    {
        let broker_range = self
            .api_versions
            .get(&Req::KEY)
            .copied()
            .with_context(|| format!("broker did not advertise API key {}", Req::KEY))?;
        select_api_version(Req::KEY, broker_range, Req::VERSIONS, cap)
    }

    /// Returns the finalized max version level of `feature`, if the broker reported one.
    pub fn finalized_feature_level(&self, feature: &str) -> Option<i16> {
        self.finalized_features.get(feature).copied()
    }

    /// Returns every finalized feature with its level, sorted by feature name.
    pub fn finalized_feature_levels(&self) -> Vec<(String, i16)> {
        let mut features = self
            .finalized_features
            .iter()
            .map(|(name, level)| (name.clone(), *level))
            .collect::<Vec<_>>();
        // Names are unique map keys, so an unstable sort yields the same order.
        features.sort_unstable_by(|left, right| left.0.cmp(&right.0));
        features
    }

    /// Negotiates API versions and records the broker's API ranges and finalized features.
    ///
    /// First tries ApiVersions at `API_VERSIONS_PROBE_VERSION`, reporting the client software
    /// name and version. If that request fails for any reason, it retries once with a default
    /// request at `API_VERSIONS_FALLBACK_VERSION`.
    ///
    /// # Errors
    ///
    /// Fails when the fallback request fails or the response carries an error code.
    async fn negotiate_versions(&mut self, client_id: &str) -> Result<()> {
        let modern_request = ApiVersionsRequest::default()
            .with_client_software_name(StrBytes::from_static_str("kafkit-client"))
            .with_client_software_version(StrBytes::from_static_str("0.2.0"));
        let response = match self
            .send_request::<ApiVersionsRequest>(
                client_id,
                API_VERSIONS_PROBE_VERSION,
                &modern_request,
            )
            .await
        {
            Ok(response) => response,
            Err(error) => {
                debug!(
                    error = %error,
                    "modern ApiVersions probe failed, retrying with fallback request"
                );
                self.send_request::<ApiVersionsRequest>(
                    client_id,
                    API_VERSIONS_FALLBACK_VERSION,
                    &ApiVersionsRequest::default(),
                )
                .await?
            }
        };
        if let Some(error) = response.error_code.err() {
            bail!("ApiVersions failed: {error}");
        }
        self.api_versions = response
            .api_keys
            .into_iter()
            .map(|api| {
                (
                    api.api_key,
                    VersionRange {
                        min: api.min_version,
                        max: api.max_version,
                    },
                )
            })
            .collect();
        self.finalized_features = response
            .finalized_features
            .into_iter()
            .map(|feature| (feature.name.to_string(), feature.max_version_level))
            .collect();
        trace!(
            api_keys = self.api_versions.len(),
            finalized_features = self.finalized_features.len(),
            "negotiated broker ApiVersions"
        );
        Ok(())
    }

    /// Sends `request` at `version` and waits for its response.
    ///
    /// The response header's correlation id must match the request's. The round trip runs in a
    /// `kafka_request` trace span and is reported via `telemetry::record_kafka_request`.
    ///
    /// # Errors
    ///
    /// Fails on encoding, I/O or decoding errors, or when the correlation ids differ.
    pub async fn send_request<Req>(
        &mut self,
        client_id: &str,
        version: i16,
        request: &Req,
    ) -> Result<Req::Response>
    where
        Req: Request,
    {
        let correlation_id = self.allocate_correlation_id();
        let started = Instant::now();
        let mut request_bytes = 0usize;
        let mut response_bytes = 0usize;
        let span = trace_span!(
            "kafka_request",
            request = std::any::type_name::<Req>(),
            api_key = Req::KEY,
            api_version = version,
            correlation_id,
            %client_id
        );
        let result = async {
            let body = Self::encode_request_body(client_id, version, correlation_id, request)?;
            request_bytes = body.len();
            trace!(request_bytes = body.len(), "encoded Kafka request");
            let frame = Self::length_prefixed_frame(&body)?;
            self.stream.write_all(&frame).await?;
            trace!(frame_bytes = frame.len(), "wrote Kafka request frame");
            let response_frame = read_frame(&mut self.stream).await?;
            response_bytes = response_frame.len();
            trace!(
                response_bytes = response_frame.len(),
                "received Kafka response frame"
            );
            let mut response_body = Bytes::from(response_frame);
            let header_version = Req::Response::header_version(version);
            let response_header = ResponseHeader::decode(&mut response_body, header_version)?;
            if response_header.correlation_id != correlation_id {
                bail!(
                    "response correlation mismatch: expected {}, got {}",
                    correlation_id,
                    response_header.correlation_id
                );
            }
            let response = Req::Response::decode(&mut response_body, version)?;
            trace!("completed Kafka request");
            Ok(response)
        }
        .instrument(span)
        .await;
        telemetry::record_kafka_request::<Req>(
            client_id,
            version,
            request_bytes,
            response_bytes,
            started.elapsed(),
            result.is_ok(),
            true,
        );
        result
    }

    /// Sends `request` at `version` without reading any response (fire-and-forget).
    ///
    /// The write is traced and reported through telemetry like [`Self::send_request`], with a
    /// response size of 0.
    ///
    /// # Errors
    ///
    /// Fails on encoding or I/O errors.
    pub async fn send_request_without_response<Req>(
        &mut self,
        client_id: &str,
        version: i16,
        request: &Req,
    ) -> Result<()>
    where
        Req: Request,
    {
        let correlation_id = self.allocate_correlation_id();
        let started = Instant::now();
        let mut request_bytes = 0usize;
        let span = trace_span!(
            "kafka_request",
            request = std::any::type_name::<Req>(),
            api_key = Req::KEY,
            api_version = version,
            correlation_id,
            expects_response = false,
            %client_id
        );
        let result = async {
            let body = Self::encode_request_body(client_id, version, correlation_id, request)?;
            request_bytes = body.len();
            trace!(request_bytes = body.len(), "encoded Kafka request");
            let frame = Self::length_prefixed_frame(&body)?;
            self.stream.write_all(&frame).await?;
            trace!(frame_bytes = frame.len(), "wrote Kafka request frame");
            Ok(())
        }
        .instrument(span)
        .await;
        telemetry::record_kafka_request::<Req>(
            client_id,
            version,
            request_bytes,
            0,
            started.elapsed(),
            result.is_ok(),
            false,
        );
        result
    }

    /// Returns the next correlation id and advances the counter.
    ///
    /// After `i32::MAX` the counter wraps back to 0 instead of overflowing. Overflow would panic
    /// in debug builds and turn negative in release builds on long-lived connections.
    fn allocate_correlation_id(&mut self) -> i32 {
        let correlation_id = self.next_correlation_id;
        self.next_correlation_id = correlation_id.checked_add(1).unwrap_or(0);
        correlation_id
    }

    /// Encodes the request header (API key, `version`, `correlation_id`, `client_id`) followed
    /// by `request` encoded at `version`.
    fn encode_request_body<Req>(
        client_id: &str,
        version: i16,
        correlation_id: i32,
        request: &Req,
    ) -> Result<BytesMut>
    where
        Req: Request,
    {
        let mut body = BytesMut::new();
        let header = RequestHeader::default()
            .with_request_api_key(Req::KEY)
            .with_request_api_version(version)
            .with_correlation_id(correlation_id)
            .with_client_id(Some(StrBytes::from_string(client_id.to_owned())));
        encode_request_header_into_buffer(&mut body, &header)?;
        request.encode(&mut body, version)?;
        Ok(body)
    }

    /// Prefixes `body` with its length as a big-endian `i32`, the framing `read_frame` expects.
    fn length_prefixed_frame(body: &[u8]) -> Result<BytesMut> {
        let declared_len = i32::try_from(body.len()).context("request frame is too large")?;
        let mut frame = BytesMut::with_capacity(body.len() + 4);
        frame.put_i32(declared_len);
        frame.extend_from_slice(body);
        Ok(frame)
    }
}
async fn connect_stream(
address: &str,
timeout: Duration,
security_protocol: SecurityProtocol,
tls: &TlsConfig,
tcp_connector: &Arc<dyn TcpConnector>,
) -> Result<BrokerStream> {
let tcp_stream = tcp_connector.connect(address, timeout).await?;
tcp_stream
.set_nodelay(true)
.with_context(|| format!("failed to enable TCP_NODELAY for {address}"))?;
if security_protocol.uses_tls() {
let ConnectedTcpStream::Tokio(tcp_stream) = tcp_stream else {
bail!("custom TCP connectors do not support TLS broker connections");
};
let tls_config = build_tls_client_config(tls)?;
let connector = TlsConnector::from(tls_config);
let server_name = server_name_for_tls(address, tls)?;
let stream = tokio::time::timeout(timeout, connector.connect(server_name, tcp_stream))
.await
.with_context(|| format!("timed out negotiating TLS with {address}"))?
.with_context(|| format!("failed TLS handshake with {address}"))?;
Ok(BrokerStream::Tls(Box::new(stream)))
} else {
match tcp_stream {
ConnectedTcpStream::Tokio(stream) => Ok(BrokerStream::Plain(Box::new(stream))),
ConnectedTcpStream::Custom(stream) => Ok(BrokerStream::Plain(stream)),
}
}
}
fn build_plain_sasl_token(sasl: &SaslConfig) -> Result<Vec<u8>> {
let username = sasl
.username
.as_deref()
.context("SASL/PLAIN requires a username")?;
let password = sasl
.password
.as_deref()
.context("SASL/PLAIN requires a password")?;
let authorization_id = sasl.authorization_id.as_deref().unwrap_or_default();
let mut token =
Vec::with_capacity(authorization_id.len() + username.len() + password.len() + 2);
token.extend_from_slice(authorization_id.as_bytes());
token.push(0);
token.extend_from_slice(username.as_bytes());
token.push(0);
token.extend_from_slice(password.as_bytes());
Ok(token)
}
/// Sends `token` as a raw SASL frame: a big-endian `i32` length followed by the token bytes,
/// with no Kafka request header.
///
/// Used for the token exchange when the broker does not advertise SaslAuthenticate.
async fn write_raw_sasl_token(stream: &mut BrokerStream, token: &[u8]) -> Result<()> {
    let mut frame = BytesMut::with_capacity(4 + token.len());
    let declared_len = i32::try_from(token.len()).context("SASL token frame is too large")?;
    frame.put_i32(declared_len);
    frame.put_slice(token);
    stream.write_all(&frame).await
}
/// Builds the rustls client configuration for broker connections.
///
/// The trust store always contains the bundled `webpki_roots` and, when `tls.ca_cert_path` is
/// set, every certificate from that PEM file. Client authentication is enabled only when both
/// `client_cert_path` and `client_key_path` are set.
///
/// # Errors
///
/// Fails when a PEM file cannot be read or parsed, when rustls rejects a CA certificate or the
/// client certificate/key pair, or when only one of the client-auth paths is configured. rustls
/// errors are annotated with the offending file paths.
fn build_tls_client_config(tls: &TlsConfig) -> Result<Arc<RustlsClientConfig>> {
    ensure_rustls_crypto_provider();
    let mut root_store = RootCertStore::empty();
    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    if let Some(ca_cert_path) = tls.ca_cert_path.as_deref() {
        for cert in load_certificates(ca_cert_path)? {
            root_store.add(cert).with_context(|| {
                format!(
                    "failed to add CA certificate from '{}' to the TLS trust store",
                    ca_cert_path.display()
                )
            })?;
        }
    }
    let builder = RustlsClientConfig::builder().with_root_certificates(root_store);
    let config = match (
        tls.client_cert_path.as_deref(),
        tls.client_key_path.as_deref(),
    ) {
        (Some(client_cert_path), Some(client_key_path)) => builder
            .with_client_auth_cert(
                load_certificates(client_cert_path)?,
                load_private_key(client_key_path)?,
            )
            .with_context(|| {
                format!(
                    "invalid TLS client certificate '{}' or private key '{}'",
                    client_cert_path.display(),
                    client_key_path.display()
                )
            })?,
        (None, None) => builder.with_no_client_auth(),
        _ => bail!("TLS client auth requires both client_cert_path and client_key_path"),
    };
    Ok(Arc::new(config))
}
/// Installs aws-lc-rs as the process-wide rustls crypto provider unless one is already set.
///
/// The check runs at most once per process.
fn ensure_rustls_crypto_provider() {
    static PROVIDER_GUARD: Once = Once::new();
    PROVIDER_GUARD.call_once(|| {
        let already_installed = rustls::crypto::CryptoProvider::get_default().is_some();
        if !already_installed {
            // Ignore the result: if another component installed a provider concurrently,
            // that provider is used instead.
            let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        }
    });
}
/// Reads every PEM certificate from `path`.
///
/// # Errors
///
/// Fails when the file cannot be opened, when a PEM block cannot be parsed, or when the file
/// holds no certificates at all.
fn load_certificates(path: &Path) -> Result<Vec<CertificateDer<'static>>> {
    let shown = path.display();
    let pem_file = File::open(path)
        .with_context(|| format!("failed to open TLS certificate file '{shown}'"))?;
    let mut pem_reader = BufReader::new(pem_file);
    let certificates = rustls_pemfile::certs(&mut pem_reader)
        .collect::<std::result::Result<Vec<_>, _>>()
        .with_context(|| format!("failed to parse TLS certificate PEM '{shown}'"))?;
    if certificates.is_empty() {
        bail!(
            "TLS certificate file '{}' did not contain any PEM certificates",
            shown
        );
    }
    Ok(certificates)
}
/// Reads the first PEM private key found in `path`.
///
/// # Errors
///
/// Fails when the file cannot be opened, when PEM parsing fails, or when no private key is
/// present.
fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>> {
    let shown = path.display();
    let key_file = File::open(path)
        .with_context(|| format!("failed to open TLS private key file '{shown}'"))?;
    let mut pem_reader = BufReader::new(key_file);
    let maybe_key = rustls_pemfile::private_key(&mut pem_reader)
        .with_context(|| format!("failed to parse TLS private key PEM '{shown}'"))?;
    maybe_key.with_context(|| format!("TLS private key file '{shown}' did not contain a PEM key"))
}
/// Picks the TLS server name that `connect_stream` uses for the handshake with `address`.
///
/// Resolution order:
/// 1. `tls.server_name`, if configured.
/// 2. The IP of a `ip:port` socket address (e.g. `127.0.0.1:9093`, `[::1]:9093`).
/// 3. A bare IP literal without a port (e.g. `::1`). This must not go through the `host:port`
///    split, which would cut an IPv6 literal at its last colon.
/// 4. The host inside `[...]` for bracketed addresses.
/// 5. The part before the last `:`, or the whole address when there is no colon.
///
/// # Errors
///
/// Fails when the override is not a valid server name, when a bracketed address has no closing
/// `]`, or when the derived host is not a valid DNS name or IP address (e.g. an empty address).
fn server_name_for_tls(address: &str, tls: &TlsConfig) -> Result<ServerName<'static>> {
    if let Some(server_name) = tls.server_name.as_ref() {
        return ServerName::try_from(server_name.clone())
            .with_context(|| format!("invalid TLS server name '{}'", server_name));
    }
    if let Ok(socket_addr) = address.parse::<SocketAddr>() {
        return Ok(ServerName::IpAddress(socket_addr.ip().into()));
    }
    if let Ok(ip) = address.parse::<std::net::IpAddr>() {
        return Ok(ServerName::IpAddress(ip.into()));
    }
    let host = match address.strip_prefix('[') {
        // `split_once` returns None when the closing bracket is missing, so malformed
        // input such as `[::1` is rejected instead of being silently accepted.
        Some(bracketed) => bracketed
            .split_once(']')
            .map(|(host, _)| host)
            .context("invalid bracketed broker address")?,
        None => address.rsplit_once(':').map_or(address, |(host, _)| host),
    };
    ServerName::try_from(host.to_owned()).with_context(|| {
        format!("could not derive a valid TLS server name from broker address '{address}'")
    })
}
/// Reads one length-prefixed frame: a big-endian `i32` size followed by that many bytes.
///
/// Returns the payload without the 4-byte prefix.
///
/// # Errors
///
/// Fails on I/O errors (including the stream ending mid-frame) and on a negative length.
async fn read_frame(stream: &mut BrokerStream) -> Result<Vec<u8>> {
    let mut length_prefix = [0_u8; 4];
    stream.read_exact(&mut length_prefix).await?;
    let declared_len = i32::from_be_bytes(length_prefix);
    if declared_len < 0 {
        bail!("broker returned a negative frame length: {declared_len}");
    }
    let payload_len: usize = declared_len.try_into()?;
    let mut payload = vec![0_u8; payload_len];
    stream.read_exact(&mut payload).await?;
    Ok(payload)
}
// Unit tests covering TLS server-name derivation, SASL/PLAIN token encoding, TLS PEM loading,
// TLS client-config construction, and the custom-connector TLS guard.
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::sync::Arc;
    use tokio::io;
    // A `host:port` address without an override yields the host part.
    #[test]
    fn tls_server_name_defaults_to_host() {
        let server_name =
            server_name_for_tls("broker.example.com:9093", &TlsConfig::default()).unwrap();
        assert_eq!(server_name.to_str(), "broker.example.com");
    }
    // An explicit `server_name` wins over whatever the address contains.
    #[test]
    fn tls_server_name_respects_explicit_override() {
        let tls = TlsConfig::new().with_server_name("cluster.internal");
        let server_name = server_name_for_tls("127.0.0.1:9093", &tls).unwrap();
        assert_eq!(server_name.to_str(), "cluster.internal");
    }
    // IPv4 and bracketed IPv6 socket addresses resolve to their IP literal.
    #[test]
    fn tls_server_name_handles_ip_and_bracketed_ipv6() {
        let server_name = server_name_for_tls("127.0.0.1:9093", &TlsConfig::default()).unwrap();
        assert_eq!(server_name.to_str(), "127.0.0.1");
        let server_name = server_name_for_tls("[::1]:9093", &TlsConfig::default()).unwrap();
        assert_eq!(server_name.to_str(), "::1");
    }
    // Invalid overrides and empty addresses must be rejected rather than producing a name.
    #[test]
    fn tls_server_name_rejects_invalid_override_and_empty_address() {
        let tls = TlsConfig::new().with_server_name("not a dns name");
        assert!(server_name_for_tls("127.0.0.1:9093", &tls).is_err());
        assert!(server_name_for_tls("", &TlsConfig::default()).is_err());
    }
    // Missing credentials fail; the token layout is `authzid NUL user NUL password`.
    #[test]
    fn plain_sasl_token_requires_credentials_and_uses_authorization_id() {
        assert!(build_plain_sasl_token(&SaslConfig::default()).is_err());
        assert!(
            build_plain_sasl_token(&SaslConfig::plain("user", "pw").with_authorization_id("authz"))
                .unwrap()
                == b"authz\0user\0pw"
        );
    }
    // Non-PEM content and missing files must surface as errors from both loaders.
    #[test]
    fn tls_file_loaders_reject_missing_empty_and_invalid_pem_files() {
        // Per-process directory name keeps concurrent test runs from colliding.
        let dir =
            std::env::temp_dir().join(format!("kafkit-client-tls-test-{}", std::process::id()));
        fs::create_dir_all(&dir).unwrap();
        let cert_path = dir.join("cert.pem");
        let key_path = dir.join("key.pem");
        fs::write(&cert_path, b"not a certificate").unwrap();
        fs::write(&key_path, b"not a key").unwrap();
        assert!(load_certificates(&cert_path).is_err());
        assert!(load_private_key(&key_path).is_err());
        assert!(load_certificates(&dir.join("missing.pem")).is_err());
        let _ = fs::remove_dir_all(dir);
    }
    // A valid cert/key pair builds a config with a custom CA and client auth, while supplying
    // only one half of the client-auth pair is rejected.
    #[test]
    fn tls_client_config_loads_custom_ca_and_client_auth_pem_files() {
        let dir = std::env::temp_dir().join(format!(
            "kafkit-client-tls-valid-pem-test-{}",
            std::process::id()
        ));
        fs::create_dir_all(&dir).unwrap();
        let cert_path = dir.join("cert.pem");
        let key_path = dir.join("key.pem");
        fs::write(&cert_path, TEST_CERT_PEM).unwrap();
        fs::write(&key_path, TEST_KEY_PEM).unwrap();
        // The same self-signed certificate serves as both the trusted CA and the client cert.
        let tls = TlsConfig::new()
            .with_ca_cert_path(&cert_path)
            .with_client_cert_path(&cert_path)
            .with_client_key_path(&key_path)
            .with_server_name("cluster.internal");
        build_tls_client_config(&tls).expect("valid custom CA and client auth config");
        assert!(
            build_tls_client_config(&TlsConfig::new().with_client_cert_path(&cert_path)).is_err()
        );
        assert!(
            build_tls_client_config(&TlsConfig::new().with_client_key_path(&key_path)).is_err()
        );
        let _ = fs::remove_dir_all(dir);
    }
    // TLS over a custom connector's stream must fail with the dedicated error, before any
    // handshake is attempted.
    #[tokio::test]
    async fn tls_rejects_custom_tcp_connectors_before_handshake() {
        let connector: Arc<dyn TcpConnector> = Arc::new(CustomOnlyConnector);
        let error = match connect_stream(
            "broker.example.com:9093",
            Duration::from_secs(1),
            SecurityProtocol::Ssl,
            &TlsConfig::default(),
            &connector,
        )
        .await
        {
            Ok(_) => panic!("TLS over custom stream should be rejected"),
            Err(error) => error,
        };
        assert!(
            error
                .to_string()
                .contains("custom TCP connectors do not support TLS broker connections")
        );
    }
    // Test connector that always yields an in-memory duplex pipe as a `Custom` stream.
    #[derive(Debug)]
    struct CustomOnlyConnector;
    impl TcpConnector for CustomOnlyConnector {
        fn connect<'a>(&'a self, _address: &'a str, _timeout: Duration) -> ConnectFuture<'a> {
            Box::pin(async move {
                let (stream, _peer) = io::duplex(64);
                Ok(ConnectedTcpStream::Custom(Box::new(stream)))
            })
        }
    }
    // Self-signed certificate for `cluster.internal`, paired with TEST_KEY_PEM below.
    // NOTE(review): this fixture appears to have a very short validity window (about one day).
    // The tests above only parse it and build a client config; confirm they do not start
    // depending on its validity period.
    const TEST_CERT_PEM: &[u8] = b"-----BEGIN CERTIFICATE-----
MIIDFzCCAf+gAwIBAgIUU1sGIzptOpATf4S4bW3ljAEYj94wDQYJKoZIhvcNAQEL
BQAwGzEZMBcGA1UEAwwQY2x1c3Rlci5pbnRlcm5hbDAeFw0yNjA1MDUxMzIyNTla
Fw0yNjA1MDYxMzIyNTlaMBsxGTAXBgNVBAMMEGNsdXN0ZXIuaW50ZXJuYWwwggEi
MA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC59uFLczWX0ES7Y2ckovLTPC+r
lhAYhS+KOpIeEjgo+mqQ9fmyqnAq6NTr/tWWVgcfgAoqNo1+gOQa9WIu55NOQzNa
wBreheE8MaL7QD/QFZnvT0Z5Hh3hkXj2HTDQqBIMv1i3bVaDDOkK3xphfQO8QhV9
YtZf2MvxvtCbl0kBqAUN+k+EECu4TENNLQyS+2rZhxqg0/Js3DUu24nMD3ilL4Kf
KU2qE3pNfe6IrPl36LY+GkxprvmwPncocR4piJKGrc20XCsiM9KnAimIwZ6/nZ/C
DJEESK2+NmjDs84GHQFmxh1rlpaSFYJsshxnFH/y0ccyHtLZpsi+R0S7iqPJAgMB
AAGjUzBRMB0GA1UdDgQWBBR86FwGaRa1IxBdu4KK5TWR01asBzAfBgNVHSMEGDAW
gBR86FwGaRa1IxBdu4KK5TWR01asBzAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3
DQEBCwUAA4IBAQBgihO4KChG9VRoY7/Sq5UWjuZT8UWZoyjyejglK/J7enmx0bRX
clEg8gRZfhbFpYIybppIK+UuKUixkFeqW2CAt/odzNDcYiMEhXZ8SWLx12LhKcLi
EITLt0PZ877aNaszz5UWlP6Wj4ec8f1DiD1PSIQqz9gddwwdX8gespmyeW/riuCQ
RMfp9HwJgpcVQMqqSeOwZaDlm1szhpEql+g1/mVGMjXHYO0B7fxzrMUY99vSkOw0
iJQHtjVkkiHfkN1HDmpfwfONwfsyA0UYtzH4kwbVHm7v1FixQ8TS24jjQi19+v3h
M/xsKOBvTns6oAKzm3oerDtSSt/heECbD3rb
-----END CERTIFICATE-----
";
    // PKCS#8 private key matching TEST_CERT_PEM (test-only material, not a real secret).
    const TEST_KEY_PEM: &[u8] = b"-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC59uFLczWX0ES7
Y2ckovLTPC+rlhAYhS+KOpIeEjgo+mqQ9fmyqnAq6NTr/tWWVgcfgAoqNo1+gOQa
9WIu55NOQzNawBreheE8MaL7QD/QFZnvT0Z5Hh3hkXj2HTDQqBIMv1i3bVaDDOkK
3xphfQO8QhV9YtZf2MvxvtCbl0kBqAUN+k+EECu4TENNLQyS+2rZhxqg0/Js3DUu
24nMD3ilL4KfKU2qE3pNfe6IrPl36LY+GkxprvmwPncocR4piJKGrc20XCsiM9Kn
AimIwZ6/nZ/CDJEESK2+NmjDs84GHQFmxh1rlpaSFYJsshxnFH/y0ccyHtLZpsi+
R0S7iqPJAgMBAAECggEAAJRn8TCSDX/NNXMix0b1kDoDGtS6oFDxLBjXPsSNknch
YOobYqnl9Dd9ZNTxCbYJiwwYbzd0Hnci/ubrICLoElmvepkLT5lF1/mxoxKsTQ11
yUl+enJhFnegU5tIsF9twWA3ukhBeXwcHkTbk+U4+NvER5VIyzJL6txOhMWmdemO
Tvk7vm1gUzr84k+mYdEoIaS5Bb8zgSNWcLVvZTAvd5VQuV5/SNHVrbpCy6q1dC++
7FdAhgSJ+CdRk/aAIXZ7zKrhe0pbCWDkmQLIdESLbv1onb9Sj/CLw8MEogMbT7T+
0FvjagYsmKsIq6Jyhd/Ve+zoLXOgOszVDYVvW14GNwKBgQDwgdT/lQEHPCoRfd1U
dz77OMIpawZtC7UAf+ab6HEmbSRaoIIa7kx5fjZeMTy6wQamN4xIcqCxt8KWozVH
M8VnChAidj3yX15AWiKT9kIBuk4dJOLwVh0Hsho+ml034M7txhBNPNIWdfxpFIti
0xncG9hkfCj5qxkUesHnuS29fwKBgQDF8ZdRnLU7iGW3YyE9OYKb/GzlF/NMkRex
7mRyTueOR5p/OiWQkQYo1F4XArnmIQSCcllOb0VukwBJLItKqc8fBHjkiyJyvCft
ZJSR3/BjFgx2w9Vo93bTpiHvevz2nTbebhV0kYXydgeiF7jCcpOQuAjreK5yhhCV
HvJoKJrStwKBgCk0VSWkhZSTvjFY+v5pn6SyyLEH4QX1p4D6aKv1Ws1WjY/pR+EN
SpTWBsKEdP8Z6uW3RpVy7g0EipX8SDh2qi9JDhKZZ2uK4z7rMllfK1fYb2GW3DqI
xlh3Lv/ium3EWi9qa4iQDv5CIIhwOKEpwZhwPNaaXvrHUXisv2PP2gJJAoGBAL1x
yjQWujFfCpKoclCJcSJfRc1Azd9S4g2uLj5knCNFDm2Dth4VXoLHNcHqHwdMRGeg
jy6NOjNox5ZA5pMv0AZMnnOFYhPTVpdScwrl+8ipeoZUSTSr2vMXhlUQLXjN4Iyj
aS9mc38pTYbqEy8uv2J7cDYFC1iaTNabhr7/VaYjAoGBAOApTlkgCYa7eUk9YYJs
zdrPUZcgT8cGTL6f04cLleaAW9gICh+25yDBQbay4uLTSKXMwb5Kygu8RYDk2NDz
GEdMJjFtDUbjt1eAlAarBIdsBs7A7jk1nGfu5g8Ervnm1X8Gs9FbUABmPQadNGJR
20YddOzMXpjdAMlrtmhRp4z1
-----END PRIVATE KEY-----
";
}