use rand::prelude::*;
use std::fmt::Display;
use std::future::Future;
use std::ops::ControlFlow;
use std::sync::Arc;
use thiserror::Error;
use tokio::{io::BufStream, sync::Mutex};
use tracing::{debug, error, info, warn};
use crate::backoff::ErrorOrThrottle;
use crate::client::metadata_cache::MetadataCacheGeneration;
use crate::connection::topology::{Broker, BrokerTopology};
use crate::connection::transport::Transport;
use crate::messenger::{Messenger, RequestError};
use crate::protocol::messages::{MetadataRequest, MetadataRequestTopic, MetadataResponse};
use crate::protocol::primitives::String_;
use crate::throttle::maybe_throttle;
use crate::{
backoff::{Backoff, BackoffConfig, BackoffError},
client::metadata_cache::MetadataCache,
};
pub use self::transport::TlsConfig;
pub use self::transport::{Credentials, OauthBearerCredentials, OauthCallback, SaslConfig};
mod topology;
mod transport;
/// Shared, reference-counted handle to a [`MessengerTransport`].
pub type BrokerConnection = Arc<MessengerTransport>;
/// A [`Messenger`] running over a buffered [`transport::Transport`] stream.
pub type MessengerTransport = Messenger<BufStream<transport::Transport>>;
/// Errors produced by this module while connecting to brokers or requesting
/// cluster metadata.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum Error {
/// A metadata request failed; wraps the underlying [`RequestError`].
#[error("error getting cluster metadata: {0}")]
Metadata(#[from] RequestError),
/// Establishing the transport to `broker` failed with `error`.
#[error("error connecting to broker \"{broker}\": {error}")]
Transport {
broker: String,
error: transport::Error,
},
/// Version synchronisation on a fresh connection failed.
#[error("cannot sync versions: {0}")]
SyncVersions(#[from] crate::messenger::SyncVersionsError),
/// The retry/backoff loop ended with a [`BackoffError`].
#[error("all retries failed: {0}")]
RetryFailed(BackoffError),
/// The SASL step performed on a fresh connection failed.
#[error("Sasl handshake failed: {0}")]
SaslFailed(#[from] crate::messenger::SaslError),
}
/// Result alias for this module; the error type defaults to [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// A list of boxed errors carried around as one error value.
///
/// `connect_to_a_broker_with_retry` uses it to report the failures collected
/// from every broker tried during a single connection round.
#[derive(Debug, Error)]
pub struct MultiError(Vec<Box<dyn std::error::Error + Send + Sync>>);
impl Display for MultiError {
    /// Writes every contained error, separated by `", "`.
    ///
    /// When more than one error is present the list is prefixed with
    /// `"Multiple errors occurred: "`. An empty list produces no output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.0.len() > 1 {
            write!(f, "Multiple errors occurred: ")?;
        }
        for (position, err) in self.0.iter().enumerate() {
            // A separator goes before every entry except the first one.
            if position > 0 {
                write!(f, ", ")?;
            }
            write!(f, "{err}")?;
        }
        Ok(())
    }
}
/// Something that can open a connection and hand back a shared request handler.
///
/// Implemented by `BrokerRepresentation` for real connections; the trait lets
/// `connect_to_a_broker_with_retry` be generic over the broker type.
trait ConnectionHandler {
/// Handler type produced by a successful [`connect`](Self::connect).
type R: RequestHandler + Send + Sync;
/// Opens a connection using the given client id, TLS / SOCKS5 / SASL settings
/// and maximum message size, returning the handler wrapped in an [`Arc`].
fn connect(
&self,
client_id: Arc<str>,
tls_config: TlsConfig,
socks5_proxy: Option<String>,
sasl_config: Option<SaslConfig>,
max_message_size: usize,
) -> impl Future<Output = Result<Arc<Self::R>>> + Send;
}
/// Selects which broker answers a metadata request.
#[derive(Debug)]
pub enum MetadataLookupMode<B = BrokerConnection> {
/// Use whichever broker the arbitrary-broker cache provides.
ArbitraryBroker,
/// Use exactly the given broker; the arbitrary-broker cache is not consulted.
SpecificBroker(B),
/// Serve from the metadata cache when `BrokerConnector::request_metadata`
/// finds an entry; otherwise behave like `ArbitraryBroker`.
CachedArbitrary,
}
impl<B> std::fmt::Display for MetadataLookupMode<B> {
    /// Prints the variant name. The payload of `SpecificBroker` is never
    /// shown; it is rendered as the placeholder `"..."`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let plain_name = match self {
            Self::SpecificBroker(_) => {
                return f.debug_tuple("SpecificBroker").field(&"...").finish();
            }
            Self::ArbitraryBroker => "ArbitraryBroker",
            Self::CachedArbitrary => "CachedArbitrary",
        };
        f.write_str(plain_name)
    }
}
/// A broker we may connect to, known either only by its bootstrap address or
/// through a topology entry.
enum BrokerRepresentation {
/// Address string taken from the configured bootstrap list; carries no broker id.
Bootstrap(String),
/// Broker taken from the known topology.
Topology(Broker),
}
impl BrokerRepresentation {
    /// Broker id for topology entries; `None` for bootstrap addresses.
    fn id(&self) -> Option<i32> {
        if let Self::Topology(broker) = self {
            Some(broker.id)
        } else {
            None
        }
    }

    /// Address to connect to: the bootstrap string itself, or the topology
    /// broker rendered through its `Display` implementation.
    fn url(&self) -> String {
        match self {
            Self::Topology(broker) => broker.to_string(),
            Self::Bootstrap(address) => address.to_owned(),
        }
    }
}
impl ConnectionHandler for BrokerRepresentation {
type R = MessengerTransport;
async fn connect(
&self,
client_id: Arc<str>,
tls_config: TlsConfig,
socks5_proxy: Option<String>,
sasl_config: Option<SaslConfig>,
max_message_size: usize,
) -> Result<Arc<Self::R>> {
let url = self.url();
info!(
broker = self.id(),
url = url.as_str(),
"Establishing new connection",
);
let transport = Transport::connect(&url, tls_config, socks5_proxy)
.await
.map_err(|error| Error::Transport {
broker: url.to_string(),
error,
})?;
let mut messenger = Messenger::new(BufStream::new(transport), max_message_size, client_id);
messenger.sync_versions().await?;
if let Some(sasl_config) = sasl_config {
messenger.do_sasl(sasl_config).await?;
}
Ok(Arc::new(messenger))
}
}
/// Holds everything needed to connect to brokers and to look up cluster
/// metadata: connection settings, the known topology and two caches.
pub struct BrokerConnector {
// Addresses used by `brokers()` while the topology is still empty.
bootstrap_brokers: Vec<String>,
// Client id forwarded to every new connection.
client_id: Arc<str>,
// Known brokers; updated from the `brokers` field of each fresh metadata response.
topology: BrokerTopology,
// Cached connection to an arbitrary broker plus the generation of that cache slot.
cached_arbitrary_broker: Mutex<(Option<BrokerConnection>, BrokerCacheGeneration)>,
// Cache of metadata responses, consulted in `CachedArbitrary` mode.
cached_metadata: MetadataCache,
// Backoff settings used for connection and metadata retries.
backoff_config: Arc<BackoffConfig>,
// Connection settings below are cloned into every connect call.
tls_config: TlsConfig,
socks5_proxy: Option<String>,
sasl_config: Option<SaslConfig>,
max_message_size: usize,
}
impl BrokerConnector {
    /// Builds a connector with default-initialised topology and metadata
    /// cache, and no cached broker connection (generation
    /// [`BrokerCacheGeneration::START`]).
    pub fn new(
        bootstrap_brokers: Vec<String>,
        client_id: Arc<str>,
        tls_config: TlsConfig,
        socks5_proxy: Option<String>,
        sasl_config: Option<SaslConfig>,
        max_message_size: usize,
        backoff_config: Arc<BackoffConfig>,
    ) -> Self {
        let cached_arbitrary_broker = Mutex::new((None, BrokerCacheGeneration::START));

        Self {
            bootstrap_brokers,
            client_id,
            topology: Default::default(),
            cached_arbitrary_broker,
            cached_metadata: Default::default(),
            backoff_config,
            tls_config,
            socks5_proxy,
            sasl_config,
            max_message_size,
        }
    }

    /// Requests metadata for all topics from an arbitrary broker and discards
    /// the response; the caches are updated as a side effect of
    /// [`request_metadata`](Self::request_metadata).
    pub async fn refresh_metadata(&self) -> Result<()> {
        self.request_metadata(&MetadataLookupMode::ArbitraryBroker, None)
            .await
            .map(|_| ())
    }

    /// Returns cluster metadata, optionally restricted to `topics`.
    ///
    /// In `CachedArbitrary` mode a hit in the metadata cache is returned
    /// directly together with its cache generation. In every other case a
    /// request is sent (with retries) and the generation is `None`.
    ///
    /// A fresh response always updates the topology; it is stored in the
    /// metadata cache only when no topic filter was given.
    pub async fn request_metadata(
        &self,
        metadata_mode: &MetadataLookupMode,
        topics: Option<Vec<String>>,
    ) -> Result<(MetadataResponse, Option<MetadataCacheGeneration>)> {
        if let MetadataLookupMode::CachedArbitrary = metadata_mode {
            if let Some((cached, r#gen)) = self.cached_metadata.get(&topics) {
                return Ok((cached, Some(r#gen)));
            }
        }

        let backoff = Backoff::new(&self.backoff_config);

        let requested_topics = topics.map(|names| {
            names
                .into_iter()
                .map(|name| MetadataRequestTopic {
                    name: String_(name),
                })
                .collect()
        });
        let request = MetadataRequest {
            topics: requested_topics,
            allow_auto_topic_creation: None,
        };

        let response = metadata_request_with_retry(metadata_mode, &request, backoff, self).await?;

        // Only unfiltered responses describe the whole cluster, so only those are cached.
        let covers_all_topics = request.topics.is_none();
        if covers_all_topics {
            self.cached_metadata.update(response.clone());
        }

        self.topology.update(&response.brokers);
        Ok((response, None))
    }

    /// Forwards an invalidation request for generation `gen` to the metadata cache.
    pub(crate) fn invalidate_metadata_cache(
        &self,
        reason: &'static str,
        r#gen: MetadataCacheGeneration,
    ) {
        self.cached_metadata.invalidate(reason, r#gen);
    }

    /// Opens a new connection to the broker with the given id.
    ///
    /// Returns `Ok(None)` when the topology has no entry for `broker_id`.
    pub async fn connect(&self, broker_id: i32) -> Result<Option<BrokerConnection>> {
        let Some(broker) = self.topology.get_broker(broker_id).await else {
            return Ok(None);
        };

        let target = BrokerRepresentation::Topology(broker);
        let connection = target
            .connect(
                Arc::clone(&self.client_id),
                self.tls_config.clone(),
                self.socks5_proxy.clone(),
                self.sasl_config.clone(),
                self.max_message_size,
            )
            .await?;

        Ok(Some(connection))
    }

    /// Candidate brokers: the known topology when it is non-empty, otherwise
    /// the configured bootstrap addresses.
    fn brokers(&self) -> Vec<BrokerRepresentation> {
        if !self.topology.is_empty() {
            return self
                .topology
                .get_brokers()
                .iter()
                .cloned()
                .map(BrokerRepresentation::Topology)
                .collect();
        }

        self.bootstrap_brokers
            .iter()
            .map(|address| BrokerRepresentation::Bootstrap(address.clone()))
            .collect()
    }
}
impl std::fmt::Debug for BrokerConnector {
    /// Debug output covering a subset of fields. `tls_config` is shown only
    /// as the placeholder `"..."`; `client_id`, `cached_metadata`,
    /// `socks5_proxy` and `sasl_config` are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("BrokerConnector");
        out.field("bootstrap_brokers", &self.bootstrap_brokers);
        out.field("topology", &self.topology);
        out.field("cached_arbitrary_broker", &self.cached_arbitrary_broker);
        out.field("backoff_config", &self.backoff_config);
        out.field("tls_config", &"...");
        out.field("max_message_size", &self.max_message_size);
        out.finish()
    }
}
/// Abstraction over "something that can answer a metadata request", so the
/// retry logic can be exercised with fakes in tests.
trait RequestHandler {
/// Sends `request_params` and returns the metadata response or the request error.
fn metadata_request(
&self,
request_params: &MetadataRequest,
) -> impl Future<Output = Result<MetadataResponse, RequestError>> + Send;
}
impl RequestHandler for MessengerTransport {
/// Delegates directly to the messenger's `request` method.
async fn metadata_request(
&self,
request_params: &MetadataRequest,
) -> Result<MetadataResponse, RequestError> {
self.request(request_params).await
}
}
/// Counter identifying one "generation" of a broker cache slot.
///
/// The `&BrokerConnector` cache bumps it whenever a new connection is stored,
/// which lets it recognise invalidation requests that refer to an older
/// connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BrokerCacheGeneration(usize);

impl BrokerCacheGeneration {
    /// The initial generation (zero).
    pub const START: Self = Self(0);

    /// Increments the generation in place and returns the new value.
    pub fn bump(&mut self) -> Self {
        let Self(counter) = self;
        *counter += 1;
        *self
    }

    /// Raw counter value.
    pub fn get(&self) -> usize {
        let Self(counter) = *self;
        counter
    }
}
/// Cache for a connection to an arbitrary broker, tagged with a generation.
pub trait BrokerCache: Send + Sync {
/// Connection / handler type handed out by the cache.
type R: Send + Sync;
/// Error returned when no connection can be provided.
type E: std::error::Error + Send + Sync;
/// Returns a shared handle together with the generation it belongs to.
fn get(
&self,
) -> impl Future<Output = Result<(Arc<Self::R>, BrokerCacheGeneration), Self::E>> + Send;
/// Asks the cache to drop its entry. `gen` is the generation the caller
/// received from [`get`](Self::get); the `&BrokerConnector` implementation
/// ignores the request when it no longer matches the current generation.
fn invalidate(
&self,
reason: &'static str,
r#gen: BrokerCacheGeneration,
) -> impl Future<Output = ()> + Send;
}
impl BrokerCache for &BrokerConnector {
    type R = MessengerTransport;
    type E = Error;

    /// Returns the cached arbitrary-broker connection, establishing one first
    /// when the cache slot is empty.
    ///
    /// A newly established connection is stored and the generation is bumped,
    /// so the returned generation identifies exactly this connection.
    ///
    /// # Errors
    ///
    /// Propagates the error from `connect_to_a_broker_with_retry` when no
    /// broker could be reached.
    async fn get(&self) -> Result<(Arc<Self::R>, BrokerCacheGeneration), Self::E> {
        // The guard stays alive across the connect attempt below, so the slot
        // cannot be modified by another task while we are connecting.
        let mut current_broker = self.cached_arbitrary_broker.lock().await;

        if let Some(broker) = &current_broker.0 {
            return Ok((Arc::clone(broker), current_broker.1));
        }

        let connection = connect_to_a_broker_with_retry(
            self.brokers(),
            Arc::clone(&self.client_id),
            &self.backoff_config,
            self.tls_config.clone(),
            self.socks5_proxy.clone(),
            self.sasl_config.clone(),
            self.max_message_size,
        )
        .await?;

        current_broker.0 = Some(Arc::clone(&connection));
        current_broker.1.bump();

        Ok((connection, current_broker.1))
    }

    /// Drops the cached connection if `gen` still matches the current
    /// generation; otherwise the request is stale and only logged.
    async fn invalidate(&self, reason: &'static str, r#gen: BrokerCacheGeneration) {
        let mut guard = self.cached_arbitrary_broker.lock().await;

        if guard.1 != r#gen {
            // The slot already holds a different generation than the caller
            // observed, so the current entry is left untouched.
            debug!(
                reason,
                current_gen = guard.1.0,
                request_gen = r#gen.0,
                "stale invalidation request for arbitrary broker cache",
            );
            return;
        }

        info!(reason, "Invalidating cached arbitrary broker",);
        guard.0.take();
    }
}
async fn connect_to_a_broker_with_retry<B>(
mut brokers: Vec<B>,
client_id: Arc<str>,
backoff_config: &BackoffConfig,
tls_config: TlsConfig,
socks5_proxy: Option<String>,
sasl_config: Option<SaslConfig>,
max_message_size: usize,
) -> Result<Arc<B::R>>
where
B: ConnectionHandler + Send + Sync,
{
brokers.shuffle(&mut thread_rng());
let mut backoff = Backoff::new(backoff_config);
backoff
.retry_with_backoff("broker_connect", || async {
let mut errors = Vec::<Box<dyn std::error::Error + Send + Sync>>::new();
for broker in &brokers {
let conn = broker
.connect(
Arc::clone(&client_id),
tls_config.clone(),
socks5_proxy.clone(),
sasl_config.clone(),
max_message_size,
)
.await;
let connection = match conn {
Ok(transport) => transport,
Err(e) => {
warn!(%e, "Failed to connect to broker");
errors.push(Box::new(e));
continue;
}
};
return ControlFlow::Break(connection);
}
let err = Box::<dyn std::error::Error + Send + Sync>::from(MultiError(errors));
let err: Arc<dyn std::error::Error + Send + Sync> = err.into();
ControlFlow::Continue(ErrorOrThrottle::Error(err))
})
.await
.map_err(Error::RetryFailed)
}
/// Sends `request_params` to a broker chosen by `metadata_mode`, retrying
/// with `backoff`.
///
/// Broker selection per attempt:
/// * `SpecificBroker(b)`: always `b`; the cache is neither consulted nor
///   invalidated.
/// * `ArbitraryBroker` / `CachedArbitrary`: whatever
///   `arbitrary_broker_cache.get()` returns. A failing `get` ends the whole
///   operation immediately with that error.
///
/// Outcome per attempt:
/// * Success: returned, unless `maybe_throttle` reports an error for the
///   response's `throttle_time_ms`, in which case the request is retried.
/// * `RequestError::Poisoned` / `RequestError::IO` on a cache-provided
///   broker: the cache entry is invalidated and the request is retried.
/// * Any other error (including `Poisoned` / `IO` for a specific broker):
///   logged and returned without retrying.
///
/// # Errors
///
/// [`Error::RetryFailed`] when the backoff logic gives up, otherwise the
/// error produced by the cache or the request.
async fn metadata_request_with_retry<A>(
    metadata_mode: &MetadataLookupMode<Arc<A::R>>,
    request_params: &MetadataRequest,
    mut backoff: Backoff,
    arbitrary_broker_cache: A,
) -> Result<MetadataResponse>
where
    A: BrokerCache,
    A::R: RequestHandler,
    Error: From<A::E>,
{
    backoff
        .retry_with_backoff("metadata", || async {
            // `cache_gen` is `Some` only when the broker came from the cache;
            // it is needed to invalidate exactly that cache entry later.
            let (broker, cache_gen) = match metadata_mode {
                MetadataLookupMode::SpecificBroker(b) => (Arc::clone(b), None),
                MetadataLookupMode::ArbitraryBroker | MetadataLookupMode::CachedArbitrary => {
                    match arbitrary_broker_cache.get().await {
                        Ok((broker, cache_gen)) => (broker, Some(cache_gen)),
                        Err(e) => return ControlFlow::Break(Err(e.into())),
                    }
                }
            };

            match broker.metadata_request(request_params).await {
                Ok(response) => {
                    if let Err(e) = maybe_throttle(response.throttle_time_ms) {
                        return ControlFlow::Continue(e);
                    }
                    ControlFlow::Break(Ok(response))
                }
                // Connection-level failures are retried only when the caller
                // did not pin a specific broker.
                Err(e @ RequestError::Poisoned(_) | e @ RequestError::IO(_))
                    if !matches!(metadata_mode, MetadataLookupMode::SpecificBroker(_)) =>
                {
                    if let Some(r#gen) = cache_gen {
                        arbitrary_broker_cache
                            .invalidate(
                                "metadata request: arbitrary/cached broker connection is broken",
                                r#gen,
                            )
                            .await;
                    }
                    ControlFlow::Continue(ErrorOrThrottle::Error(e))
                }
                Err(error) => {
                    error!(
                        e=%error,
                        "metadata request encountered fatal error",
                    );
                    ControlFlow::Break(Err(error.into()))
                }
            }
        })
        .await
        // The outer `Result` is the backoff outcome; the inner one (returned
        // as-is) carries the request outcome.
        .map_err(Error::RetryFailed)?
}
// Unit tests for the retry helpers, driven entirely by in-memory fakes.
#[cfg(test)]
mod tests {
use super::*;
use crate::{build_info::DEFAULT_CLIENT_ID, protocol::api_key::ApiKey};
use std::sync::atomic::{AtomicBool, Ordering};
// Fake request handler: the boxed closure decides the outcome of every metadata request.
struct FakeBroker(Box<dyn Fn() -> Result<MetadataResponse, RequestError> + Send + Sync>);
impl FakeBroker {
// Always answers with `arbitrary_metadata_response()`.
fn success() -> Self {
Self(Box::new(|| Ok(arbitrary_metadata_response())))
}
// Always fails with `arbitrary_fatal_error()`.
fn fatal_error() -> Self {
Self(Box::new(|| Err(arbitrary_fatal_error())))
}
// Always fails with `arbitrary_recoverable_error()`.
fn recoverable() -> Self {
Self(Box::new(|| Err(arbitrary_recoverable_error())))
}
}
impl RequestHandler for FakeBroker {
async fn metadata_request(
&self,
_request_params: &MetadataRequest,
) -> Result<MetadataResponse, RequestError> {
self.0()
}
}
// Fake broker cache whose `get` / `invalidate` behaviour is supplied by closures.
struct FakeBrokerCache {
get: Box<dyn Fn() -> Result<Arc<FakeBroker>> + Send + Sync>,
invalidate: Box<dyn Fn() + Send + Sync>,
}
impl BrokerCache for FakeBrokerCache {
type R = FakeBroker;
type E = Error;
// Every broker is reported with generation `START`.
async fn get(&self) -> Result<(Arc<Self::R>, BrokerCacheGeneration)> {
(self.get)().map(|b| (b, BrokerCacheGeneration::START))
}
async fn invalidate(&self, _reason: &'static str, _gen: BrokerCacheGeneration) {
(self.invalidate)()
}
}
// A cached broker that succeeds yields the response unchanged.
#[tokio::test]
async fn happy_cached_broker() {
let metadata_request = arbitrary_metadata_request();
let success_response = arbitrary_metadata_response();
let broker_cache = FakeBrokerCache {
get: Box::new(|| Ok(Arc::new(FakeBroker::success()))),
invalidate: Box::new(|| {}),
};
let result = metadata_request_with_retry(
&MetadataLookupMode::ArbitraryBroker,
&metadata_request,
Backoff::new(&Default::default()),
broker_cache,
)
.await
.unwrap();
assert_eq!(success_response, result)
}
// A fatal error from the cached broker is returned as `Error::Metadata`.
#[tokio::test]
async fn fatal_error_cached_broker() {
let metadata_request = arbitrary_metadata_request();
let broker_cache = FakeBrokerCache {
get: Box::new(|| Ok(Arc::new(FakeBroker::fatal_error()))),
invalidate: Box::new(|| {}),
};
let result = metadata_request_with_retry(
&MetadataLookupMode::ArbitraryBroker,
&metadata_request,
Backoff::new(&Default::default()),
broker_cache,
)
.await
.unwrap_err();
assert!(matches!(
result,
Error::Metadata(RequestError::NoVersionMatch { .. })
));
}
// The cache hands out a failing broker until `invalidate` is called, which
// flips the flag so that the next `get` returns a succeeding broker.
#[tokio::test]
async fn sad_cached_broker() {
let succeed = Arc::new(AtomicBool::new(false));
let metadata_request = arbitrary_metadata_request();
let success_response = arbitrary_metadata_response();
let broker_cache = FakeBrokerCache {
get: Box::new({
let succeed = Arc::clone(&succeed);
move || {
Ok(Arc::new(if succeed.load(Ordering::SeqCst) {
FakeBroker::success()
} else {
FakeBroker::recoverable()
}))
}
}),
invalidate: Box::new({
let succeed = Arc::clone(&succeed);
move || succeed.store(true, Ordering::SeqCst)
}),
};
let result = metadata_request_with_retry(
&MetadataLookupMode::ArbitraryBroker,
&metadata_request,
Backoff::new(&Default::default()),
broker_cache,
)
.await
.unwrap();
assert_eq!(success_response, result)
}
// With a specific broker the cache must stay untouched: both cache
// callbacks panic via `unreachable!()` if they are ever invoked.
#[tokio::test]
async fn happy_broker_override() {
let broker_override = Arc::new(FakeBroker::success());
let metadata_request = arbitrary_metadata_request();
let success_response = arbitrary_metadata_response();
let broker_cache = FakeBrokerCache {
get: Box::new(|| unreachable!()),
invalidate: Box::new(|| unreachable!()),
};
let result = metadata_request_with_retry(
&MetadataLookupMode::SpecificBroker(broker_override),
&metadata_request,
Backoff::new(&Default::default()),
broker_cache,
)
.await
.unwrap();
assert_eq!(success_response, result)
}
// An I/O error from a specific broker is returned as `Error::Metadata`
// and never touches the cache (callbacks are `unreachable!()`).
#[tokio::test]
async fn sad_broker_override() {
let broker_override = Arc::new(FakeBroker::recoverable());
let metadata_request = arbitrary_metadata_request();
let broker_cache = FakeBrokerCache {
get: Box::new(|| unreachable!()),
invalidate: Box::new(|| unreachable!()),
};
let result = metadata_request_with_retry(
&MetadataLookupMode::SpecificBroker(broker_override),
&metadata_request,
Backoff::new(&Default::default()),
broker_cache,
)
.await
.unwrap_err();
assert!(matches!(result, Error::Metadata(RequestError::IO { .. })));
}
// Fixture: metadata request with all fields at their defaults.
fn arbitrary_metadata_request() -> MetadataRequest {
MetadataRequest {
topics: Default::default(),
allow_auto_topic_creation: Default::default(),
}
}
// Fixture: metadata response with all fields at their defaults.
fn arbitrary_metadata_response() -> MetadataResponse {
MetadataResponse {
throttle_time_ms: Default::default(),
brokers: Default::default(),
cluster_id: Default::default(),
controller_id: Default::default(),
topics: Default::default(),
}
}
// Fixture: an error kind the retry loop treats as fatal (not retried).
fn arbitrary_fatal_error() -> RequestError {
RequestError::NoVersionMatch {
api_key: ApiKey::Metadata,
}
}
// Fixture: an I/O error, which the retry loop retries for cache-provided brokers.
fn arbitrary_recoverable_error() -> RequestError {
RequestError::IO(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))
}
// Fake connection target: the closure decides whether `connect` succeeds.
struct FakeBrokerRepresentation {
conn: Box<dyn Fn() -> Result<Arc<FakeConn>> + Send + Sync>,
}
// Marker connection type; it is never expected to serve requests.
#[derive(Debug, PartialEq)]
struct FakeConn;
impl RequestHandler for FakeConn {
async fn metadata_request(
&self,
_request_params: &MetadataRequest,
) -> Result<MetadataResponse, RequestError> {
unreachable!();
}
}
impl ConnectionHandler for FakeBrokerRepresentation {
type R = FakeConn;
// All connection settings are ignored; the outcome comes from `conn`.
async fn connect(
&self,
_client_id: Arc<str>,
_tls_config: TlsConfig,
_socks5_proxy: Option<String>,
_sasl_config: Option<SaslConfig>,
_max_message_size: usize,
) -> Result<Arc<Self::R>> {
(self.conn)()
}
}
// One broker connects, the other fails. The list is shuffled inside
// `connect_to_a_broker_with_retry`, so either may be tried first; the call
// must still end with the successful connection.
#[tokio::test]
async fn connect_picks_successful_broker() {
let brokers = vec![
FakeBrokerRepresentation {
conn: Box::new(|| Ok(Arc::new(FakeConn))),
},
FakeBrokerRepresentation {
conn: Box::new(|| Err(Error::Metadata(arbitrary_recoverable_error()))),
},
];
let conn = connect_to_a_broker_with_retry(
brokers,
Arc::from(DEFAULT_CLIENT_ID),
&Default::default(),
Default::default(),
Default::default(),
Default::default(),
Default::default(),
)
.await
.unwrap();
assert_eq!(*conn, FakeConn);
}
}