use std::fmt;
use std::io::{Read, Write};
use std::net::{Shutdown, TcpStream, ToSocketAddrs};
use std::time::Duration;
use tracing::debug;
use crate::error::Result;
#[cfg(feature = "security")]
use crate::tls::{RustlsConnector, TlsConfig, TlsStream};
#[cfg(feature = "security")]
pub struct SecurityConfig {
pub(crate) tls_config: TlsConfig,
}
#[cfg(feature = "security")]
impl SecurityConfig {
#[must_use]
pub fn new() -> Self {
SecurityConfig {
tls_config: TlsConfig::new(),
}
}
#[must_use]
pub fn from_tls_config(tls_config: TlsConfig) -> SecurityConfig {
SecurityConfig { tls_config }
}
#[must_use]
pub fn with_hostname_verification(mut self, verify_hostname: bool) -> SecurityConfig {
self.tls_config.verify_hostname = verify_hostname;
self
}
#[must_use]
pub fn with_ca_cert(mut self, path: String) -> SecurityConfig {
self.tls_config.ca_cert_path = Some(path);
self
}
#[must_use]
pub fn with_client_cert(mut self, cert_path: String, key_path: String) -> SecurityConfig {
self.tls_config.client_cert_path = Some(cert_path);
self.tls_config.client_key_path = Some(key_path);
self
}
}
#[cfg(feature = "security")]
impl Default for SecurityConfig {
fn default() -> Self {
Self::new()
}
}
#[cfg(feature = "security")]
impl fmt::Debug for SecurityConfig {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"SecurityConfig {{ verify_hostname: {} }}",
self.tls_config.verify_hostname
)
}
}
#[cfg(not(feature = "security"))]
pub(crate) type KafkaStream = TcpStream;
#[cfg(feature = "security")]
pub(crate) enum KafkaStream {
Plain(TcpStream),
Tls(Box<dyn TlsStream>),
}
pub(crate) trait StreamOps {
fn is_secured(&self) -> bool;
fn set_read_timeout(&mut self, dur: Option<Duration>) -> std::io::Result<()>;
fn set_write_timeout(&mut self, dur: Option<Duration>) -> std::io::Result<()>;
fn shutdown(&mut self, how: Shutdown) -> std::io::Result<()>;
}
#[cfg(not(feature = "security"))]
impl StreamOps for KafkaStream {
fn is_secured(&self) -> bool {
false
}
fn set_read_timeout(&mut self, dur: Option<Duration>) -> std::io::Result<()> {
TcpStream::set_read_timeout(self, dur)
}
fn set_write_timeout(&mut self, dur: Option<Duration>) -> std::io::Result<()> {
TcpStream::set_write_timeout(self, dur)
}
fn shutdown(&mut self, how: Shutdown) -> std::io::Result<()> {
TcpStream::shutdown(self, how)
}
}
#[cfg(feature = "security")]
impl StreamOps for KafkaStream {
fn is_secured(&self) -> bool {
match self {
KafkaStream::Plain(_) => false,
KafkaStream::Tls(_) => true,
}
}
fn set_read_timeout(&mut self, dur: Option<Duration>) -> std::io::Result<()> {
match self {
KafkaStream::Plain(s) => s.set_read_timeout(dur),
KafkaStream::Tls(s) => s.set_read_timeout(dur),
}
}
fn set_write_timeout(&mut self, dur: Option<Duration>) -> std::io::Result<()> {
match self {
KafkaStream::Plain(s) => s.set_write_timeout(dur),
KafkaStream::Tls(s) => s.set_write_timeout(dur),
}
}
fn shutdown(&mut self, how: Shutdown) -> std::io::Result<()> {
match self {
KafkaStream::Plain(s) => s.shutdown(how),
KafkaStream::Tls(s) => s.shutdown(),
}
}
}
#[cfg(feature = "security")]
impl Read for KafkaStream {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
match self {
KafkaStream::Plain(s) => s.read(buf),
KafkaStream::Tls(s) => s.read(buf),
}
}
}
#[cfg(feature = "security")]
impl Write for KafkaStream {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
match self {
KafkaStream::Plain(s) => s.write(buf),
KafkaStream::Tls(s) => s.write(buf),
}
}
fn flush(&mut self) -> std::io::Result<()> {
match self {
KafkaStream::Plain(s) => s.flush(),
KafkaStream::Tls(s) => s.flush(),
}
}
}
pub struct KafkaConnection {
id: u32,
host: String,
stream: KafkaStream,
state: ConnectionState,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ConnectionState {
Connected,
Terminated,
}
impl fmt::Debug for KafkaConnection {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"KafkaConnection {{ id: {}, secured: {}, state: {:?}, host: \"{}\" }}",
self.id,
self.stream.is_secured(),
self.state,
self.host
)
}
}
fn configure_tcp_socket(socket: &socket2::Socket) -> std::io::Result<()> {
use socket2::TcpKeepalive;
let keepalive = TcpKeepalive::new()
.with_time(Duration::from_secs(10))
.with_interval(Duration::from_secs(20));
socket.set_tcp_keepalive(&keepalive)?;
socket.set_tcp_nodelay(true)?;
Ok(())
}
impl KafkaConnection {
pub fn send(&mut self, msg: &[u8]) -> Result<usize> {
self.stream.write(msg).map_err(|e| {
self.state = ConnectionState::Terminated;
From::from(e)
})
}
pub(crate) fn is_terminated(&self) -> bool {
self.state == ConnectionState::Terminated
}
pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
self.stream.read_exact(buf).map_err(|e| {
self.state = ConnectionState::Terminated;
From::from(e)
})
}
pub fn read_exact_alloc(&mut self, size: u64) -> Result<bytes::Bytes> {
let len = usize::try_from(size).expect("response size exceeds usize");
let mut buf = bytes::BytesMut::with_capacity(len);
buf.resize(len, 0);
self.read_exact(&mut buf)?;
Ok(buf.freeze())
}
pub(crate) fn shutdown(&mut self) -> Result<()> {
self.state = ConnectionState::Terminated;
let r = StreamOps::shutdown(&mut self.stream, Shutdown::Both);
debug!("Shut down: {:?} => {:?}", self, r);
r.map_err(From::from)
}
fn from_stream(
mut stream: KafkaStream,
id: u32,
host: &str,
rw_timeout: Option<Duration>,
) -> Result<KafkaConnection> {
StreamOps::set_read_timeout(&mut stream, rw_timeout)?;
StreamOps::set_write_timeout(&mut stream, rw_timeout)?;
Ok(KafkaConnection {
id,
host: host.to_owned(),
stream,
state: ConnectionState::Connected,
})
}
fn new_tcp_stream(host: &str) -> std::io::Result<TcpStream> {
let mut last_err: Option<std::io::Error> = None;
for addr in host.to_socket_addrs()? {
let domain = match addr {
std::net::SocketAddr::V4(_) => socket2::Domain::IPV4,
std::net::SocketAddr::V6(_) => socket2::Domain::IPV6,
};
let socket =
socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;
match socket.connect(&socket2::SockAddr::from(addr)) {
Ok(()) => {
configure_tcp_socket(&socket)?;
return Ok(socket.into());
}
Err(e) => last_err = Some(e),
}
}
Err(last_err.unwrap_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::AddrNotAvailable,
format!("unable to resolve broker address: {host}"),
)
}))
}
#[cfg(not(feature = "security"))]
pub(crate) fn new(
id: u32,
host: &str,
rw_timeout: Option<Duration>,
) -> Result<KafkaConnection> {
KafkaConnection::from_stream(Self::new_tcp_stream(host)?, id, host, rw_timeout)
}
#[cfg(feature = "security")]
pub(crate) fn new(
id: u32,
host: &str,
rw_timeout: Option<Duration>,
tls_config: Option<&TlsConfig>,
) -> Result<KafkaConnection> {
let tcp_stream = Self::new_tcp_stream(host)?;
let stream = match tls_config {
Some(config) => {
let domain = match host.rfind(':') {
None => host,
Some(i) => &host[..i],
};
let connector = RustlsConnector::new(config)?;
let tls_stream = connector.connect(domain, tcp_stream)?;
KafkaStream::Tls(tls_stream)
}
None => KafkaStream::Plain(tcp_stream),
};
KafkaConnection::from_stream(stream, id, host, rw_timeout)
}
}