use crabka_security::ListenerProtocol;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
use std::sync::Arc;
use crate::config::InterBrokerCredentials;
pub(crate) fn to_client_creds(c: &InterBrokerCredentials) -> crabka_client_core::SaslCredentials {
match c {
InterBrokerCredentials::Plain { username, password } => {
crabka_client_core::SaslCredentials::Plain {
username: username.clone(),
password: password.clone(),
}
}
InterBrokerCredentials::Scram {
mechanism,
username,
password,
} => crabka_client_core::SaslCredentials::Scram {
mechanism: *mechanism,
username: username.clone(),
password: password.clone(),
},
InterBrokerCredentials::Gssapi {
keytab_path,
client_principal,
service_name,
kdc_url,
} => crabka_client_core::SaslCredentials::Gssapi {
keytab_path: keytab_path.clone(),
client_principal: client_principal.clone(),
service_name: service_name.clone(),
kdc_url: kdc_url.clone(),
},
}
}
/// Failures that can occur while establishing an outbound inter-broker
/// connection.
#[derive(Debug, Error)]
pub enum InterBrokerError {
    /// Underlying I/O failure (e.g. the TCP connect). Converted automatically
    /// from [`std::io::Error`] via `?` thanks to `#[from]`.
    #[error("io: {0}")]
    Io(#[from] std::io::Error),
    /// The configured server name is not a valid TLS server name, or the TLS
    /// handshake failed.
    #[error("tls: {0}")]
    Tls(String),
    /// The outbound SASL authentication exchange failed.
    #[error("sasl: {0}")]
    Sasl(String),
    /// The client lacks configuration needed for the listener (no TLS
    /// connector / no SASL credentials); also used when
    /// `Connection::from_stream` fails.
    #[error("config: {0}")]
    Config(String),
    /// Not constructed in this part of the file; presumably wire
    /// encode/decode failures — NOTE(review): confirm against callers.
    #[error("codec: {0}")]
    Codec(String),
}
/// Bidirectional async byte stream usable as a trait object, so that both a
/// plain TCP stream and a TLS-wrapped one can be returned as
/// `Box<dyn DuplexStream>`.
pub trait DuplexStream: AsyncRead + AsyncWrite + Unpin + Send {}
// Blanket impl: any type meeting the bounds (including unsized types, hence
// `?Sized`) is automatically a `DuplexStream`.
impl<T: AsyncRead + AsyncWrite + Unpin + Send + ?Sized> DuplexStream for T {}
pub struct InterBrokerClient {
tls_connector: Option<TlsConnector>,
creds: Option<InterBrokerCredentials>,
}
impl InterBrokerClient {
#[must_use]
pub fn new(tls_connector: Option<TlsConnector>, creds: Option<InterBrokerCredentials>) -> Self {
Self {
tls_connector,
creds,
}
}
pub async fn connect(
&self,
host: &str,
port: u16,
listener_protocol: ListenerProtocol,
server_name: &str,
) -> Result<Box<dyn DuplexStream>, InterBrokerError> {
let tcp = TcpStream::connect((host, port)).await?;
let mut stream: Box<dyn DuplexStream> = if listener_protocol.requires_tls() {
let connector = self.tls_connector.clone().ok_or_else(|| {
InterBrokerError::Config("TLS listener without TlsConnector".into())
})?;
let sni =
tokio_rustls::rustls::pki_types::ServerName::try_from(server_name.to_string())
.map_err(|e| InterBrokerError::Tls(format!("invalid server name: {e}")))?;
let tls = connector
.connect(sni, tcp)
.await
.map_err(|e| InterBrokerError::Tls(e.to_string()))?;
Box::new(tls)
} else {
Box::new(tcp)
};
if listener_protocol.requires_sasl() {
let creds = self.creds.clone().ok_or_else(|| {
InterBrokerError::Config("SASL listener without inter_broker_credentials".into())
})?;
crabka_client_core::outbound_sasl(&mut *stream, &to_client_creds(&creds), server_name)
.await
.map_err(|e| InterBrokerError::Sasl(e.to_string()))?;
}
Ok(stream)
}
pub async fn connect_as_connection(
&self,
host: &str,
port: u16,
listener_protocol: ListenerProtocol,
server_name: &str,
options: crabka_client_core::ConnectionOptions,
) -> Result<crabka_client_core::Connection, InterBrokerError> {
let tcp = TcpStream::connect((host, port)).await?;
let mut stream: Box<dyn crabka_client_core::ClientDuplex> =
if listener_protocol.requires_tls() {
let connector = self.tls_connector.clone().ok_or_else(|| {
InterBrokerError::Config("TLS listener without TlsConnector".into())
})?;
let sni =
tokio_rustls::rustls::pki_types::ServerName::try_from(server_name.to_string())
.map_err(|e| InterBrokerError::Tls(format!("invalid server name: {e}")))?;
let tls = connector
.connect(sni, tcp)
.await
.map_err(|e| InterBrokerError::Tls(e.to_string()))?;
Box::new(tls)
} else {
Box::new(tcp)
};
if listener_protocol.requires_sasl() {
let creds = self.creds.clone().ok_or_else(|| {
InterBrokerError::Config("SASL listener without inter_broker_credentials".into())
})?;
crabka_client_core::outbound_sasl(&mut *stream, &to_client_creds(&creds), server_name)
.await
.map_err(|e| InterBrokerError::Sasl(e.to_string()))?;
}
crabka_client_core::Connection::from_stream(stream, options)
.await
.map_err(|e| InterBrokerError::Config(format!("Connection::from_stream: {e}")))
}
}
/// Raft outbound dialer that reaches peers through a shared
/// [`InterBrokerClient`], always using one fixed listener protocol and server
/// name.
pub struct InterBrokerDialer {
    /// Shared client that performs the TCP/TLS/SASL connection setup.
    client: Arc<InterBrokerClient>,
    /// Security protocol of the peer listener being dialed.
    listener_protocol: ListenerProtocol,
    /// Server name handed to the client on every dial.
    server_name: String,
}

impl InterBrokerDialer {
    /// Builds a dialer that delegates to `client` using the given listener
    /// protocol and server name.
    #[must_use]
    pub fn new(
        client: Arc<InterBrokerClient>,
        listener_protocol: ListenerProtocol,
        server_name: String,
    ) -> Self {
        InterBrokerDialer { server_name, client, listener_protocol }
    }
}
/// Splits a raft peer address of the form `host:port` into host and port.
///
/// The split happens at the *last* `:`, so unbracketed IPv6 literals such as
/// `::1:9093` still parse. For bracketed IPv6 literals (`[::1]:9093`), the
/// brackets are removed. They only delimit the port and are not part of the
/// address, so keeping them would make host resolution fail.
///
/// # Errors
///
/// Returns an [`std::io::ErrorKind::InvalidInput`] error if `addr` has no `:`
/// or the port is not a valid `u16`.
fn split_peer_addr(addr: &str) -> std::io::Result<(&str, u16)> {
    let Some((host, port)) = addr.rsplit_once(':') else {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("raft peer address missing port: {addr:?}"),
        ));
    };
    let port: u16 = port.parse().map_err(|e: std::num::ParseIntError| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("invalid raft peer port in {addr:?}: {e}"),
        )
    })?;
    let host = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host);
    Ok((host, port))
}

#[async_trait::async_trait]
impl crabka_raft::OutboundDialer for InterBrokerDialer {
    /// Dials the raft peer at `addr` (`host:port`, IPv6 literals may be
    /// bracketed) through the shared [`InterBrokerClient`], using this
    /// dialer's listener protocol and server name. `_target` is unused.
    ///
    /// I/O errors are passed through unchanged. Every other
    /// [`InterBrokerError`] is reported as an `Other` I/O error whose message
    /// is prefixed with `InterBrokerClient dial:`.
    async fn dial(
        &self,
        _target: crabka_raft::NodeId,
        addr: &str,
        options: crabka_client_core::ConnectionOptions,
    ) -> Result<crabka_client_core::Connection, crabka_client_core::ClientError> {
        let (host, port) = split_peer_addr(addr).map_err(crabka_client_core::ClientError::Io)?;
        self.client
            .connect_as_connection(host, port, self.listener_protocol, &self.server_name, options)
            .await
            .map_err(|e| match e {
                InterBrokerError::Io(io) => crabka_client_core::ClientError::Io(io),
                other => crabka_client_core::ClientError::Io(std::io::Error::other(format!(
                    "InterBrokerClient dial: {other}"
                ))),
            })
    }
}