use std::{
any::Any,
future::Future,
pin::Pin,
sync::{mpsc::channel, Arc},
};
use crate::prod_utils::validate_produce_input;
use rdkafka::{producer::FutureProducer, ClientConfig};
pub type DynFut<'lt, T> = Pin<Box<dyn 'lt + Send + Future<Output = T>>>;
pub type AsyncWorkerFunction<T> = fn(FutureProducer, &'static str, Vec<Arc<dyn Any + Send + Sync>>) -> DynFut<'static, T>;
#[derive(Clone, Copy)]
pub struct AsyncFunctionWrapper<T> {
inner: AsyncWorkerFunction<T>,
}
impl<T> AsyncFunctionWrapper<T> {
pub fn new(
inner: AsyncWorkerFunction<T>,
) -> Self {
Self { inner }
}
pub fn run(
&self,
producer: FutureProducer,
topic: &'static str,
args: Vec<Arc<dyn Any + Send + Sync>>,
) -> DynFut<'static, T> {
(self.inner)(producer, topic, args)
}
}
pub async fn produce_in_parallel<T: 'static + Clone + Send>(
num_threads: u32,
topic: &'static str,
producer_config: &ClientConfig,
func_wrapper: AsyncFunctionWrapper<T>,
args: Vec<Arc<dyn Any + Send + Sync>>,
) -> Result<Vec<T>, Box<dyn Any + Send>> {
if let Err(e) = validate_produce_input(num_threads, topic, producer_config) {
return Err(e);
}
let mut thread_handles = Vec::new();
let (tx, rx) = channel();
for thread_id in 0..num_threads {
log::debug!("Spawning thread {}", thread_id);
let tx = tx.clone();
match producer_config.create::<FutureProducer>() {
Ok(producer) => {
let topic = topic.clone();
let worker_function = Arc::new(func_wrapper.clone());
let worker_function_args = args.clone();
let handle = tokio::spawn(async move {
let local_res = worker_function
.run(producer, topic, worker_function_args)
.await;
match tx.send(local_res) {
Ok(_) => log::debug!("Sent result from thread {}", thread_id),
Err(e) => {
log::error!("Error sending result from thread {}: {}", thread_id, e);
panic!("Error sending result from thread {}: {}", thread_id, e);
}
}
});
thread_handles.push(handle);
}
Err(e) => {
log::error!("Error creating producer: {}", e);
return Err(Box::new(e));
}
}
}
drop(tx);
let mut results = Vec::new();
for handle in thread_handles {
match handle.await {
Ok(_) => {
log::debug!("Thread joined");
match rx.recv() {
Ok(local_res) => results.push(local_res),
Err(e) => log::error!("Error receiving result from thread: {}", e),
}
}
Err(e) => log::error!("Error joining thread: {}", e),
};
}
Ok(results)
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use rdkafka::producer::FutureRecord;
use super::*;
use crate::prod_utils;
macro_rules! aw {
($e:expr) => {
tokio_test::block_on($e)
};
}
async fn test_async_fn_returns_i32(
_: FutureProducer,
_: &'static str,
_: Vec<Arc<dyn Any + Send + Sync>>,
) -> i32 {
1337
}
#[test]
fn test_produce_in_parallel_returns_correct_values() {
let producer_config = prod_utils::get_default_producer_config("localhost:9092", "10");
let topic = "test-topic";
let args = Vec::new();
let num_threads = 4;
let func_wrapper = AsyncFunctionWrapper::new(|future_producer, topic, args| {
Box::pin(test_async_fn_returns_i32(future_producer, topic, args))
});
let t = aw!(produce_in_parallel(
num_threads,
topic,
&producer_config,
func_wrapper,
args
));
assert!(t.is_ok());
let results = t.unwrap();
assert_eq!(results.len(), num_threads as usize);
for res in results {
assert_eq!(res, 1337);
}
}
#[test]
fn test_produce_in_parallel_returns_error_on_invalid_num_threads() {
let producer_config = prod_utils::get_default_producer_config("localhost:9092", "10");
let topic = "test-topic";
let args = Vec::new();
let num_threads = 0;
let func_wrapper = AsyncFunctionWrapper::new(|future_producer, topic, args| {
Box::pin(test_async_fn_returns_i32(future_producer, topic, args))
});
let t = aw!(produce_in_parallel(
num_threads,
topic,
&producer_config,
func_wrapper,
args
));
assert!(t.is_err());
}
#[test]
fn test_produce_in_parallel_returns_error_on_invalid_topic() {
let producer_config = prod_utils::get_default_producer_config("localhost:9092", "10");
let topic = "";
let args = Vec::new();
let num_threads = 4;
let func_wrapper = AsyncFunctionWrapper::new(|future_producer, topic, args| {
Box::pin(test_async_fn_returns_i32(future_producer, topic, args))
});
let t = aw!(produce_in_parallel(
num_threads,
topic,
&producer_config,
func_wrapper,
args
));
assert!(t.is_err());
}
#[test]
fn test_produce_in_parallel_returns_error_on_invalid_producer_config() {
let producer_config = ClientConfig::new();
let topic = "test-topic";
let args = Vec::new();
let num_threads = 4;
let func_wrapper = AsyncFunctionWrapper::new(|future_producer, topic, args| {
Box::pin(test_async_fn_returns_i32(future_producer, topic, args))
});
let t = aw!(produce_in_parallel(
num_threads,
topic,
&producer_config,
func_wrapper,
args
));
assert!(t.is_err());
}
async fn checks_message_response(
producer: FutureProducer,
topic: &'static str,
_: Vec<Arc<dyn Any + Sync + Send>>,
) -> Result<(i32, i64), (rdkafka::error::KafkaError, rdkafka::message::OwnedMessage)> {
let i = 0;
let delivery_status = producer
.send(
FutureRecord::to(topic)
.payload(&format!("Message {}", i))
.key(&format!("Key {}", i)),
Duration::from_secs(5),
)
.await;
if let Err(ref e) = delivery_status {
log::error!("Error sending message: {:?}", e);
}
delivery_status
}
#[test]
fn test_produce_in_parallel_sends_messages() {
pretty_env_logger::init();
let producer_config = prod_utils::get_default_producer_config("localhost:9092", "5000");
let topic = "test-topic";
let args = Vec::new();
let num_threads = 4;
let func_wrapper = AsyncFunctionWrapper::new(|future_producer, topic, args| {
Box::pin(checks_message_response(future_producer, topic, args))
});
let t = aw!(produce_in_parallel(
num_threads,
topic,
&producer_config,
func_wrapper,
args
));
assert!(t.is_ok());
let results = t.unwrap();
assert_eq!(results.len(), num_threads as usize);
for res in results {
assert!(res.is_ok());
}
}
}