use bytes::Bytes;
use rustfs_kafka::client::RequiredAcks;
use rustfs_kafka::error::Result;
use rustfs_kafka::producer::{AsBytes, Producer, Record};
use tokio::sync::{mpsc, oneshot};
use tracing::{debug, info};
use crate::AsyncKafkaClient;
/// One-shot channel on which the worker reports the outcome of a command.
type Reply = oneshot::Sender<Result<()>>;

/// A request sent from an [`AsyncProducer`] handle to its background worker.
enum ProducerCommand {
    /// Produce one record to `topic`/`partition`; the outcome is sent on `response`.
    Send {
        topic: String,
        key: Bytes,
        value: Bytes,
        partition: i32,
        response: Reply,
    },
    /// Reply on `response` once the worker reaches this point in its queue.
    Flush { response: Reply },
    /// Stop the worker loop; commands queued after this one are not processed.
    Shutdown,
}
/// Async handle to a Kafka producer that runs on a background worker.
///
/// Every operation enqueues a command on a bounded channel and awaits the
/// worker's reply. Call [`AsyncProducer::close`] for an orderly shutdown;
/// dropping the handle instead requests cancellation through
/// `JoinHandle::abort`.
pub struct AsyncProducer {
    /// Command queue feeding the background worker.
    sender: mpsc::Sender<ProducerCommand>,
    /// Worker task handle; `None` once `close` or `Drop` has taken it.
    handle: Option<tokio::task::JoinHandle<()>>,
}
impl AsyncProducer {
pub async fn new(client: AsyncKafkaClient) -> Result<Self> {
let hosts = client.bootstrap_hosts().to_vec();
let client_id = client.client_id().to_owned();
let sync_producer = Producer::from_hosts(hosts)
.with_client_id(client_id)
.with_required_acks(RequiredAcks::One)
.create()?;
let (sender, mut receiver) = mpsc::channel::<ProducerCommand>(256);
let handle = tokio::spawn(async move {
let mut producer = sync_producer;
while let Some(cmd) = receiver.recv().await {
match cmd {
ProducerCommand::Send {
topic,
key,
value,
partition,
response,
} => {
let result = Self::do_send(&mut producer, &topic, &key, &value, partition);
let _ = response.send(result);
}
ProducerCommand::Flush { response } => {
let _ = response.send(Ok(()));
}
ProducerCommand::Shutdown => {
debug!("Async producer shutting down");
break;
}
}
}
info!("Async producer background task exited");
});
Ok(Self {
sender,
handle: Some(handle),
})
}
pub async fn from_hosts(hosts: Vec<String>) -> Result<Self> {
let client = AsyncKafkaClient::new(hosts).await?;
Self::new(client).await
}
pub async fn send<K, V>(&self, record: &Record<'_, K, V>) -> Result<()>
where
K: AsBytes,
V: AsBytes,
{
let (tx, rx) = oneshot::channel();
self.sender
.send(ProducerCommand::Send {
topic: record.topic.to_owned(),
key: Bytes::copy_from_slice(record.key.as_bytes()),
value: Bytes::copy_from_slice(record.value.as_bytes()),
partition: record.partition,
response: tx,
})
.await
.map_err(|_| {
rustfs_kafka::error::Error::Connection(
rustfs_kafka::error::ConnectionError::NoHostReachable,
)
})?;
rx.await.map_err(|_| {
rustfs_kafka::error::Error::Connection(
rustfs_kafka::error::ConnectionError::NoHostReachable,
)
})?
}
pub async fn flush(&self) -> Result<()> {
let (tx, rx) = oneshot::channel();
self.sender
.send(ProducerCommand::Flush { response: tx })
.await
.map_err(|_| {
rustfs_kafka::error::Error::Connection(
rustfs_kafka::error::ConnectionError::NoHostReachable,
)
})?;
rx.await.map_err(|_| {
rustfs_kafka::error::Error::Connection(
rustfs_kafka::error::ConnectionError::NoHostReachable,
)
})?
}
pub async fn close(mut self) -> Result<()> {
if let Some(handle) = self.handle.take() {
let _ = self.sender.send(ProducerCommand::Shutdown).await;
let _ = handle.await;
}
Ok(())
}
fn do_send(
producer: &mut Producer,
topic: &str,
key: &Bytes,
value: &Bytes,
partition: i32,
) -> Result<()> {
let record =
Record::from_key_value(topic, key.as_ref(), value.as_ref()).with_partition(partition);
producer.send(&record)
}
}
impl Drop for AsyncProducer {
    /// Requests cancellation of the background worker, unless `close` has
    /// already taken its handle.
    ///
    /// `JoinHandle::abort` only cancels a task at its next `.await` point. It
    /// has no effect on a blocking-pool task that is already running. In every
    /// case `sender` is dropped right after this method returns, which closes
    /// the command channel, so a worker that was not cancelled stops once its
    /// queue is empty.
    fn drop(&mut self) {
        let Some(worker) = self.handle.take() else {
            return;
        };
        worker.abort();
    }
}
#[cfg(test)]
mod tests {
    use rustfs_kafka::error::{ConnectionError, Error};
    use super::*;

    /// Port 1 on localhost is assumed to have no listener, so building a
    /// producer against it must fail with `NoHostReachable`.
    #[tokio::test]
    async fn from_hosts_fails_with_unreachable_hosts() {
        let result = AsyncProducer::from_hosts(vec!["127.0.0.1:1".to_owned()]).await;
        assert!(matches!(
            result,
            Err(Error::Connection(ConnectionError::NoHostReachable))
        ));
    }

    /// A client with an empty bootstrap list can be built, but a producer
    /// cannot be created from it.
    #[tokio::test]
    async fn new_fails_with_no_bootstrap_hosts() {
        let client = AsyncKafkaClient::new(vec![])
            .await
            .expect("an empty host list should still yield a client");
        let result = AsyncProducer::new(client).await;
        assert!(
            result.is_err(),
            "creating a producer without bootstrap hosts must fail"
        );
    }
}