#![allow(dead_code)]
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use kovan_map::HashMap;
use parking_lot::Mutex;
use super::client::{RespClient, StreamClient};
use super::config::{Compression, ProducerConfig};
use super::error::StreamError;
use super::interceptor::ProducerInterceptor;
use super::types::{
Acks, ConsumerGroupMetadata, OffsetAndMetadata, Record, RecordMetadata, TopicPartition,
};
/// Producer that writes records through a [`StreamClient`], with optional
/// client-side transaction buffering and interceptor hooks.
///
/// Construct one with `Producer::builder()`.
pub struct Producer {
    /// Configuration captured from the builder when the producer was built.
    config: ProducerConfig,
    /// Backend used for record writes (`xadd`) and offset commits (`commit_offset`).
    client: Arc<dyn StreamClient>,
    /// Interceptors applied, in order, to every outgoing record (`on_send`).
    interceptors: Vec<Arc<dyn ProducerInterceptor + Send + Sync>>,
    /// Set by `begin_transaction`, cleared by commit/abort; while set, `send`
    /// buffers records instead of writing them.
    tx_active: AtomicBool,
    /// Records held back during an active transaction, each paired with its topic.
    tx_buffer: Mutex<Vec<(String, Record)>>,
    /// Number of records counted as sent; exposed via `records_sent()`.
    records_sent: AtomicU64,
    /// Sum of `value` lengths of counted records (keys and headers are not
    /// included); exposed via `bytes_sent()`.
    bytes_sent: AtomicU64,
}
/// Fluent builder for [`Producer`]; obtain one via `Producer::builder()`.
///
/// Every setter consumes the builder and returns it, so discarding a setter's
/// result silently drops the configuration — hence `#[must_use]`.
#[derive(Default)]
#[must_use = "builder methods return the updated builder; call `build().await` to create the producer"]
pub struct ProducerBuilder {
    /// Configuration accumulated by the setter methods.
    config: ProducerConfig,
    /// Explicit client; when `None`, `build` creates a `RespClient` from
    /// `config.bootstrap_servers`.
    client: Option<Arc<dyn StreamClient>>,
    /// Interceptors in registration order.
    interceptors: Vec<Arc<dyn ProducerInterceptor + Send + Sync>>,
}
impl ProducerBuilder {
pub fn bootstrap_servers(mut self, servers: impl Into<String>) -> Self {
self.config.bootstrap_servers = servers.into();
self
}
pub fn acks(mut self, acks: Acks) -> Self {
self.config.acks = acks;
self
}
pub fn idempotence(mut self, enabled: bool) -> Self {
self.config.idempotence = enabled;
self
}
pub fn transactional_id(mut self, id: impl Into<String>) -> Self {
self.config.transactional_id = Some(id.into());
self
}
pub fn linger_ms(mut self, ms: u64) -> Self {
self.config.linger = std::time::Duration::from_millis(ms);
self
}
pub fn batch_size(mut self, bytes: usize) -> Self {
self.config.batch_size = bytes;
self
}
pub fn max_in_flight(mut self, n: u32) -> Self {
self.config.max_in_flight = n;
self
}
pub fn request_timeout_ms(mut self, ms: u64) -> Self {
self.config.request_timeout = std::time::Duration::from_millis(ms);
self
}
pub fn compression(mut self, compression: Compression) -> Self {
self.config.compression = compression;
self
}
pub fn retries(mut self, retries: u32) -> Self {
self.config.retries = retries;
self
}
pub fn client(mut self, client: Arc<dyn StreamClient>) -> Self {
self.client = Some(client);
self
}
pub fn interceptor(mut self, interceptor: Arc<dyn ProducerInterceptor + Send + Sync>) -> Self {
self.interceptors.push(interceptor);
self
}
pub async fn build(self) -> Result<Producer, StreamError> {
let client = self
.client
.unwrap_or_else(|| Arc::new(RespClient::new(&self.config.bootstrap_servers)));
Ok(Producer {
config: self.config,
client,
interceptors: self.interceptors,
tx_active: AtomicBool::new(false),
tx_buffer: Mutex::new(Vec::new()),
records_sent: AtomicU64::new(0),
bytes_sent: AtomicU64::new(0),
})
}
}
impl Producer {
    /// Returns a [`ProducerBuilder`] with default configuration.
    pub fn builder() -> ProducerBuilder {
        ProducerBuilder::default()
    }

    /// Sends a record, returning a boxed future with its metadata.
    ///
    /// Interceptor `on_send` hooks run first. While a transaction is active the
    /// record is buffered until [`Producer::commit_transaction`], and placeholder
    /// metadata (partition 0, offset 0, current wall-clock millis) is returned.
    /// Otherwise the record is written immediately via the client's `xadd`.
    ///
    /// # Errors
    /// Returns any error reported by the client when writing directly.
    pub fn send(
        &self,
        record: Record,
    ) -> Pin<Box<dyn Future<Output = Result<RecordMetadata, StreamError>> + Send + '_>> {
        Box::pin(async move { self.dispatch(record) })
    }

    /// Sends a record and hands the outcome to `callback` before returning.
    ///
    /// It follows the same path as [`Producer::send`], so records are buffered
    /// while a transaction is active. The callback runs synchronously on the
    /// calling thread.
    pub fn send_with_callback<F>(&self, record: Record, callback: F)
    where
        F: FnOnce(Result<RecordMetadata, StreamError>) + Send + 'static,
    {
        callback(self.dispatch(record));
    }

    /// Sends a record synchronously. It follows the same path as [`Producer::send`].
    ///
    /// # Errors
    /// Returns any error reported by the client when writing directly.
    pub fn send_sync(&self, record: Record) -> Result<RecordMetadata, StreamError> {
        self.dispatch(record)
    }

    /// Flushes pending records.
    ///
    /// Currently a no-op. It does not write records buffered by an active
    /// transaction; only `commit_transaction` does that.
    pub async fn flush(&self) -> Result<(), StreamError> {
        Ok(())
    }

    /// Checks that a transactional id is configured. No other work is performed.
    ///
    /// # Errors
    /// `StreamError::Config` if `transactional_id` was never set.
    pub async fn init_transactions(&self) -> Result<(), StreamError> {
        if self.config.transactional_id.is_none() {
            return Err(StreamError::Config(
                "transactional_id must be set before calling init_transactions".into(),
            ));
        }
        Ok(())
    }

    /// Starts a client-side transaction and clears any leftover buffered records.
    ///
    /// # Errors
    /// - `StreamError::Config` if `transactional_id` was never set.
    /// - `StreamError::Transaction` if a transaction is already active.
    pub fn begin_transaction(&self) -> Result<(), StreamError> {
        if self.config.transactional_id.is_none() {
            return Err(StreamError::Config(
                "transactional_id must be set before calling begin_transaction".into(),
            ));
        }
        if self.tx_active.swap(true, Ordering::SeqCst) {
            return Err(StreamError::Transaction(
                "transaction already in progress".into(),
            ));
        }
        self.tx_buffer.lock().clear();
        Ok(())
    }

    /// Writes every buffered record, in order, and ends the transaction.
    ///
    /// Each written record is counted in `records_sent`/`bytes_sent` and
    /// reported to the interceptors, exactly like a direct send.
    ///
    /// Writes are issued one at a time and are not atomic. If a write fails,
    /// records already written stay written. The failed record and every later
    /// one are put back at the front of the buffer, and the transaction stays
    /// active. The caller can then retry `commit_transaction` or discard the
    /// remainder with `abort_transaction`.
    ///
    /// # Errors
    /// `StreamError::Transaction` if no transaction is active, or the first
    /// client error encountered while writing.
    pub async fn commit_transaction(&self) -> Result<(), StreamError> {
        if !self.tx_active.load(Ordering::SeqCst) {
            return Err(StreamError::Transaction("no active transaction".into()));
        }
        let records = std::mem::take(&mut *self.tx_buffer.lock());
        let mut pending = records.into_iter();
        while let Some((topic, record)) = pending.next() {
            let outcome = self.client.xadd(
                &topic,
                record.key.as_deref(),
                &record.value,
                &record.headers,
            );
            match outcome {
                Ok(metadata) => self.record_delivery(&metadata, record.value.len()),
                Err(err) => {
                    // Keep the unsent tail instead of silently dropping it.
                    self.restore_unsent(std::iter::once((topic, record)).chain(pending));
                    return Err(err);
                }
            }
        }
        self.tx_active.store(false, Ordering::SeqCst);
        Ok(())
    }

    /// Discards all buffered records and ends the transaction.
    ///
    /// # Errors
    /// `StreamError::Transaction` if no transaction is active.
    pub async fn abort_transaction(&self) -> Result<(), StreamError> {
        if !self.tx_active.load(Ordering::SeqCst) {
            return Err(StreamError::Transaction("no active transaction".into()));
        }
        self.tx_buffer.lock().clear();
        self.tx_active.store(false, Ordering::SeqCst);
        Ok(())
    }

    /// Commits consumer offsets for `group_metadata.group_id` while a transaction is active.
    ///
    /// The offsets are committed immediately through the client's `commit_offset`.
    /// They are not buffered with the transaction, so `abort_transaction` does not
    /// undo them.
    ///
    /// NOTE(review): only `tp.topic` and the offset are forwarded; the partition
    /// is dropped. Entries for several partitions of one topic may therefore
    /// overwrite each other. Confirm against `StreamClient::commit_offset`.
    ///
    /// # Errors
    /// `StreamError::Transaction` if no transaction is active, or the first
    /// client error from `commit_offset`.
    pub async fn send_offsets_to_transaction(
        &self,
        offsets: HashMap<TopicPartition, OffsetAndMetadata>,
        group_metadata: ConsumerGroupMetadata,
    ) -> Result<(), StreamError> {
        if !self.tx_active.load(Ordering::SeqCst) {
            return Err(StreamError::Transaction("no active transaction".into()));
        }
        for (tp, offset_meta) in offsets.iter() {
            self.client
                .commit_offset(&group_metadata.group_id, &tp.topic, offset_meta.offset)?;
        }
        Ok(())
    }

    /// Aborts any active transaction, then flushes.
    pub async fn close(&self) -> Result<(), StreamError> {
        if self.tx_active.load(Ordering::SeqCst) {
            self.abort_transaction().await?;
        }
        self.flush().await?;
        Ok(())
    }

    /// Returns the configuration this producer was built with.
    pub fn config(&self) -> &ProducerConfig {
        &self.config
    }

    /// Number of records written by the client (direct sends plus committed records).
    pub fn records_sent(&self) -> u64 {
        self.records_sent.load(Ordering::Relaxed)
    }

    /// Total `value` bytes of records counted by [`Producer::records_sent`].
    pub fn bytes_sent(&self) -> u64 {
        self.bytes_sent.load(Ordering::Relaxed)
    }

    /// Passes `record` through every interceptor's `on_send`, in registration order.
    fn apply_interceptors(&self, mut record: Record) -> Record {
        for interceptor in &self.interceptors {
            record = interceptor.on_send(record);
        }
        record
    }

    /// Shared send path for `send`, `send_sync` and `send_with_callback`.
    ///
    /// Using one path ensures every entry point respects an active transaction.
    fn dispatch(&self, record: Record) -> Result<RecordMetadata, StreamError> {
        let record = self.apply_interceptors(record);
        // NOTE(review): the flag is read before the buffer lock is taken. A record
        // sent concurrently with `commit_transaction` can land in the buffer after
        // it was drained, and is then cleared by the next `begin_transaction`
        // without ever being written.
        if self.tx_active.load(Ordering::SeqCst) {
            return Ok(self.buffer_in_transaction(record));
        }
        let metadata = self.client.xadd(
            &record.topic,
            record.key.as_deref(),
            &record.value,
            &record.headers,
        )?;
        self.record_delivery(&metadata, record.value.len());
        Ok(metadata)
    }

    /// Buffers `record` for the active transaction and returns placeholder metadata.
    fn buffer_in_transaction(&self, record: Record) -> RecordMetadata {
        let topic = record.topic.clone();
        self.tx_buffer.lock().push((topic.clone(), record));
        RecordMetadata {
            topic,
            partition: 0,
            offset: 0,
            timestamp: Self::now_millis(),
        }
    }

    /// Updates the send counters and notifies interceptors of a successful write.
    fn record_delivery(&self, metadata: &RecordMetadata, value_len: usize) {
        self.records_sent.fetch_add(1, Ordering::Relaxed);
        self.bytes_sent.fetch_add(value_len as u64, Ordering::Relaxed);
        for interceptor in &self.interceptors {
            interceptor.on_acknowledgement(metadata, None);
        }
    }

    /// Puts unsent records back at the front of the transaction buffer, preserving order.
    fn restore_unsent(&self, unsent: impl Iterator<Item = (String, Record)>) {
        let mut restored: Vec<(String, Record)> = unsent.collect();
        let mut buffer = self.tx_buffer.lock();
        // Anything buffered concurrently since the drain goes after the restored tail.
        restored.append(&mut *buffer);
        *buffer = restored;
    }

    /// Milliseconds since the Unix epoch.
    ///
    /// Returns 0 if the clock reads before the epoch, and saturates at `u64::MAX`.
    fn now_millis() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |elapsed| {
                u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Every builder setter must land in the corresponding config field.
    #[test]
    fn test_producer_builder() {
        let builder = Producer::builder()
            .bootstrap_servers("localhost:6379")
            .acks(Acks::All)
            .idempotence(true)
            .transactional_id("tx-1")
            .linger_ms(5)
            .batch_size(16384)
            .max_in_flight(5)
            .request_timeout_ms(30_000)
            .compression(Compression::Lz4)
            .retries(3);

        let cfg = &builder.config;
        assert_eq!(cfg.bootstrap_servers, "localhost:6379");
        assert_eq!(cfg.acks, Acks::All);
        assert!(cfg.idempotence);
        assert_eq!(cfg.transactional_id.as_deref(), Some("tx-1"));
        assert_eq!(cfg.linger, Duration::from_millis(5));
        assert_eq!(cfg.batch_size, 16384);
        assert_eq!(cfg.max_in_flight, 5);
        assert_eq!(cfg.request_timeout, Duration::from_millis(30_000));
        assert!(matches!(cfg.compression, Compression::Lz4));
        assert_eq!(cfg.retries, 3);

        // Nothing was supplied for these, so the defaults must be untouched.
        assert!(builder.client.is_none());
        assert!(builder.interceptors.is_empty());
    }

    #[test]
    fn test_record_builder() {
        let record = Record::new("my-topic")
            .key(b"key1".to_vec())
            .value(b"value1".to_vec())
            .header("h1", b"v1".to_vec())
            .timestamp(1234567890);
        assert_eq!(record.topic, "my-topic");
        assert_eq!(record.key, Some(b"key1".to_vec()));
        assert_eq!(record.value, b"value1");
        assert_eq!(record.headers.len(), 1);
    }
}