use super::config::KafkaConfig;
use crate::transport::error::{TransportError, TransportResult};
use rdkafka::config::ClientConfig;
use rdkafka::producer::{BaseRecord, Producer, ThreadedProducer};
use rdkafka::util::Timeout;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
/// Tuning preset applied when building a [`KafkaProducer`].
///
/// Each variant selects one of the librdkafka setting tables in `super::config`;
/// those settings are applied on top of the connection settings and can themselves
/// be overridden by `KafkaConfig::librdkafka_overrides`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ProducerProfile {
    /// Uses `super::config::PRODUCER_HIGH_THROUGHPUT`. This is the default profile.
    #[default]
    HighThroughput,
    /// Uses `super::config::PRODUCER_EXACTLY_ONCE`.
    ExactlyOnce,
    /// Uses `super::config::PRODUCER_LOW_LATENCY`.
    LowLatency,
    /// Uses `super::config::PRODUCER_DEVTEST`.
    DevTest,
}
impl std::fmt::Display for ProducerProfile {
    /// Writes the profile's snake_case name (e.g. `high_throughput`).
    ///
    /// Goes through [`Formatter::pad`](std::fmt::Formatter::pad) so that width, fill,
    /// alignment and precision specifiers such as `{:>16}` are honoured; a bare
    /// `write!` of a literal silently ignores them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::HighThroughput => "high_throughput",
            Self::ExactlyOnce => "exactly_once",
            Self::LowLatency => "low_latency",
            Self::DevTest => "devtest",
        };
        f.pad(name)
    }
}
/// Kafka producer wrapping an rdkafka [`ThreadedProducer`].
///
/// Besides forwarding messages, it keeps relaxed atomic counters of enqueue
/// outcomes, which are exposed through `KafkaProducer::metrics`.
pub struct KafkaProducer {
    /// Underlying rdkafka producer; delivery reports are routed to [`ProducerContext`].
    producer: ThreadedProducer<ProducerContext>,
    /// Number of messages accepted into the local librdkafka queue.
    messages_sent: AtomicU64,
    /// Sum of payload sizes, in bytes, of the messages counted in `messages_sent`.
    bytes_sent: AtomicU64,
    /// Number of enqueue attempts that librdkafka rejected.
    errors: AtomicU64,
    /// Tuning preset this producer was built with.
    profile: ProducerProfile,
}
/// Stateless rdkafka context used by [`KafkaProducer`].
///
/// It is a unit struct; the braced form `ProducerContext {}` still constructs it.
/// Delivery reports reach its `ProducerContext::delivery` callback.
#[derive(Clone, Debug, Default)]
pub struct ProducerContext;
/// Uses the default implementations of every `ClientContext` callback.
impl rdkafka::client::ClientContext for ProducerContext {}
/// Producer-side callbacks for [`ProducerContext`]; no per-message opaque data is attached.
impl rdkafka::producer::ProducerContext for ProducerContext {
    type DeliveryOpaque = ();

    /// Receives the delivery report for one produced message.
    ///
    /// The report is deliberately ignored. As a result, `KafkaProducer`'s counters
    /// only reflect whether messages were *enqueued*, not whether brokers acknowledged them.
    fn delivery(
        &self,
        _delivery_report: &rdkafka::producer::DeliveryResult<'_>,
        _unused_opaque: Self::DeliveryOpaque,
    ) {
        // Intentionally a no-op.
    }
}
impl KafkaProducer {
pub fn new(config: &KafkaConfig, profile: ProducerProfile) -> TransportResult<Self> {
let mut client_config = ClientConfig::new();
client_config.set("bootstrap.servers", config.brokers.join(","));
client_config.set("client.id", &config.client_id);
client_config.set("security.protocol", &config.security_protocol);
if let Some(ref mechanism) = config.sasl_mechanism {
client_config.set("sasl.mechanism", mechanism);
}
if let Some(ref username) = config.sasl_username {
client_config.set("sasl.username", username);
}
if let Some(ref password) = config.sasl_password {
client_config.set("sasl.password", password.expose());
}
if let Some(ref ca) = config.ssl_ca_location {
client_config.set("ssl.ca.location", ca);
}
if let Some(ref cert) = config.ssl_certificate_location {
client_config.set("ssl.certificate.location", cert);
}
if let Some(ref key) = config.ssl_key_location {
client_config.set("ssl.key.location", key);
}
if config.ssl_skip_verify {
client_config.set("enable.ssl.certificate.verification", "false");
}
let profile_settings = match profile {
ProducerProfile::HighThroughput => super::config::PRODUCER_HIGH_THROUGHPUT,
ProducerProfile::ExactlyOnce => super::config::PRODUCER_EXACTLY_ONCE,
ProducerProfile::LowLatency => super::config::PRODUCER_LOW_LATENCY,
ProducerProfile::DevTest => super::config::PRODUCER_DEVTEST,
};
for (key, value) in profile_settings {
client_config.set(*key, *value);
}
for (key, value) in &config.librdkafka_overrides {
client_config.set(key, value);
}
let context = ProducerContext {};
let producer: ThreadedProducer<ProducerContext> = client_config
.create_with_context(context)
.map_err(|e| TransportError::Connection(format!("Failed to create producer: {e}")))?;
Ok(Self {
producer,
profile,
messages_sent: AtomicU64::new(0),
bytes_sent: AtomicU64::new(0),
errors: AtomicU64::new(0),
})
}
pub fn high_throughput(config: &KafkaConfig) -> TransportResult<Self> {
Self::new(config, ProducerProfile::HighThroughput)
}
pub fn exactly_once(config: &KafkaConfig) -> TransportResult<Self> {
Self::new(config, ProducerProfile::ExactlyOnce)
}
pub fn low_latency(config: &KafkaConfig) -> TransportResult<Self> {
Self::new(config, ProducerProfile::LowLatency)
}
pub fn send(&self, topic: &str, key: Option<&[u8]>, payload: &[u8]) -> TransportResult<()> {
let mut record = BaseRecord::to(topic).payload(payload);
if let Some(k) = key {
record = record.key(k);
}
match self.producer.send(record) {
Ok(()) => {
self.messages_sent.fetch_add(1, Ordering::Relaxed);
self.bytes_sent
.fetch_add(payload.len() as u64, Ordering::Relaxed);
Ok(())
}
Err((err, _)) => {
self.errors.fetch_add(1, Ordering::Relaxed);
let err_str = err.to_string();
if err_str.contains("queue full") || err_str.contains("Local: Queue full") {
Err(TransportError::Backpressure)
} else {
Err(TransportError::Send(err_str))
}
}
}
}
pub fn send_keyed(&self, topic: &str, key: &str, payload: &[u8]) -> TransportResult<()> {
self.send(topic, Some(key.as_bytes()), payload)
}
pub fn send_batch(&self, topic: &str, messages: &[(Option<&[u8]>, &[u8])]) -> usize {
let mut sent = 0;
for (key, payload) in messages {
let mut record = BaseRecord::to(topic).payload(*payload);
if let Some(k) = key {
record = record.key(*k);
}
if self.producer.send(record).is_ok() {
self.messages_sent.fetch_add(1, Ordering::Relaxed);
self.bytes_sent
.fetch_add(payload.len() as u64, Ordering::Relaxed);
sent += 1;
} else {
self.errors.fetch_add(1, Ordering::Relaxed);
break; }
}
sent
}
pub fn poll(&self, timeout: Duration) {
self.producer.poll(Timeout::After(timeout));
}
#[allow(clippy::cast_sign_loss)]
pub fn flush(&self, timeout: Duration) -> usize {
let _ = self.producer.flush(Timeout::After(timeout));
self.producer.in_flight_count().max(0) as usize
}
#[allow(clippy::cast_sign_loss)]
pub fn in_flight_count(&self) -> usize {
self.producer.in_flight_count().max(0) as usize
}
#[allow(clippy::cast_sign_loss)]
pub fn metrics(&self) -> ProducerMetrics {
ProducerMetrics {
messages_sent: self.messages_sent.load(Ordering::Relaxed),
bytes_sent: self.bytes_sent.load(Ordering::Relaxed),
errors: self.errors.load(Ordering::Relaxed),
in_flight: self.producer.in_flight_count().max(0) as u64,
profile: self.profile,
}
}
}
/// Point-in-time view of a [`KafkaProducer`]'s counters, as returned by its `metrics` method.
#[derive(Clone, Debug)]
pub struct ProducerMetrics {
    /// Messages accepted into the local librdkafka queue (not necessarily acknowledged).
    pub messages_sent: u64,
    /// Total payload bytes of the messages counted in `messages_sent`.
    pub bytes_sent: u64,
    /// Enqueue attempts rejected by librdkafka.
    pub errors: u64,
    /// Messages still queued or awaiting acknowledgement when the snapshot was taken.
    pub in_flight: u64,
    /// Tuning preset the producer was built with.
    pub profile: ProducerProfile,
}
impl std::fmt::Debug for KafkaProducer {
    /// Formats the profile and counters; the rdkafka handle itself is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("KafkaProducer")
            .field("profile", &self.profile)
            .field("messages_sent", &self.messages_sent.load(Ordering::Relaxed))
            .field("bytes_sent", &self.bytes_sent.load(Ordering::Relaxed))
            .field("errors", &self.errors.load(Ordering::Relaxed))
            // Same clamped, non-negative value the public `in_flight_count()` reports.
            .field("in_flight", &self.in_flight_count())
            // `..` marks the omitted `producer` field.
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every profile renders as its documented snake_case name.
    #[test]
    fn test_producer_profile_display() {
        let cases = [
            (ProducerProfile::HighThroughput, "high_throughput"),
            (ProducerProfile::ExactlyOnce, "exactly_once"),
            (ProducerProfile::LowLatency, "low_latency"),
            (ProducerProfile::DevTest, "devtest"),
        ];
        for (profile, expected) in cases {
            assert_eq!(profile.to_string(), expected, "display name for {profile:?}");
        }
    }

    /// `HighThroughput` is the profile chosen when none is specified.
    #[test]
    fn test_producer_profile_default() {
        let fallback: ProducerProfile = Default::default();
        assert_eq!(fallback, ProducerProfile::HighThroughput);
    }
}