use log::debug;
use rdkafka::{producer::BaseProducer, ClientConfig};
use std::sync::mpsc::channel;
use std::thread;
use std::{any::Any, boxed::Box, sync::Arc};
use crate::prod_utils::validate_produce_input;
pub type SyncWorkerFunction<T> =
fn(&BaseProducer, &'static str, Vec<Arc<dyn Any + Sync + Send>>) -> T;
pub fn produce_in_parallel<T: 'static + Clone + Send>(
num_threads: u32,
topic: &'static str,
producer_config: &ClientConfig,
worker_function: SyncWorkerFunction<T>,
worker_function_args: Vec<Arc<dyn Any + Sync + Send>>,
) -> Result<Vec<T>, Box<dyn Any + Send>> {
if let Err(e) = validate_produce_input(num_threads, topic, producer_config) {
return Err(e);
}
let (tx, rx) = channel();
let producer_config = Arc::new(producer_config.clone());
let mut threads = Vec::new();
for _ in 0..num_threads {
let tx = tx.clone();
let producer_config = producer_config.clone();
let topic = topic.clone();
let worker_function_args = worker_function_args.clone();
threads.push(thread::spawn(move || {
let producer: BaseProducer = producer_config.create().expect("Producer creation error");
let result = worker_function(&producer, topic, worker_function_args);
tx.send(result).expect("Failed to send result");
}));
}
drop(tx);
let mut results = Vec::new();
for _ in threads {
let result = rx.recv().expect("Failed to receive result");
results.push(result);
}
Ok(results)
}
pub type SyncWorkerFunctionTypeErased =
fn(&BaseProducer, &'static str, Vec<Arc<dyn Any + Sync + Send>>) -> Box<dyn Any + Sync + Send>;
pub fn produce_in_parallel_type_erased(
n_threads: u32,
topic: &'static str,
producer_config: &ClientConfig,
worker_function: SyncWorkerFunctionTypeErased,
worker_function_args: Vec<Arc<dyn Any + Sync + Send>>,
) -> Result<Vec<Box<dyn Any + Send + Sync>>, Box<dyn Any + Send>> {
if let Err(e) = validate_produce_input(n_threads, topic, producer_config) {
return Err(e);
}
let mut thread_handles = Vec::new();
let (tx, rx) = channel();
for thread_id in 0..n_threads {
debug!("Starting thread {}", thread_id);
let tx = tx.clone();
match producer_config.create::<BaseProducer>() {
Ok(producer) => {
let worker_function = Arc::new(worker_function);
let function_args = worker_function_args.clone();
let handle = thread::spawn(move || {
match tx.send(worker_function(&producer, topic, function_args)) {
Ok(_) => {}
Err(e) => {
log::error!("Failed to send result from thread {}: {}", thread_id, e);
panic!("Failed to send result from thread {}: {}", thread_id, e);
}
}
});
thread_handles.push(handle);
}
Err(e) => return Err(Box::new(e)),
}
}
drop(tx);
let mut results = Vec::new();
log::debug!("Joining threads");
for handle in thread_handles {
match handle.join() {
Ok(_) => {
log::debug!("Thread joined");
let local_res = rx.recv().unwrap();
results.push(local_res);
}
Err(e) => return Err(e),
}
}
log::debug!("Threads joined successfully");
return Ok(results);
}
#[cfg(test)]
mod tests {
use super::*;
fn get_default_producer_config() -> ClientConfig {
let mut producer_config = ClientConfig::new();
producer_config.set("message.timeout.ms", "10");
producer_config.set("bootstrap.servers", "localhost:9092");
producer_config
}
fn test_sync_worker_function_returns_some(
_: &BaseProducer,
_: &'static str,
_: Vec<Arc<dyn Any + Sync + Send>>,
) -> Box<dyn Any + Sync + Send> {
Box::new(Some(1337))
}
fn test_sync_worker_function_returns_none(
_: &BaseProducer,
_: &'static str,
_: Vec<Arc<dyn Any + Sync + Send>>,
) -> Box<dyn Any + Sync + Send> {
Box::new(None::<()>)
}
#[test]
fn test_produce_in_parallel_returns_err() {
let producer_config = get_default_producer_config();
match produce_in_parallel_type_erased(
0,
"test-topic",
&producer_config,
test_sync_worker_function_returns_none,
Vec::new(),
) {
Ok(_) => assert!(false),
Err(_) => assert!(true),
}
match produce_in_parallel_type_erased(
2,
"",
&producer_config,
test_sync_worker_function_returns_none,
Vec::new(),
) {
Ok(_) => assert!(false),
Err(_) => assert!(true),
}
let mut producer_config = ClientConfig::new();
producer_config.set("message.timeout.ms", "10");
match produce_in_parallel_type_erased(
2,
"test-topic",
&producer_config,
test_sync_worker_function_returns_none,
Vec::new(),
) {
Ok(_) => assert!(false),
Err(_) => assert!(true),
}
let mut producer_config = ClientConfig::new();
producer_config.set("bootstrap.servers", "localhost:9092");
match produce_in_parallel_type_erased(
2,
"test-topic",
&producer_config,
test_sync_worker_function_returns_none,
Vec::new(),
) {
Ok(_) => assert!(false),
Err(_) => assert!(true),
}
}
#[test]
fn test_produce_in_parallel_returns_none() {
let producer_config = get_default_producer_config();
match produce_in_parallel_type_erased(
2,
"test-topic",
&producer_config,
test_sync_worker_function_returns_none,
Vec::new(),
) {
Ok(res) => {
for val in res {
let val = val.downcast::<Option<()>>().unwrap();
assert!(val.is_none());
}
}
Err(_) => assert!(false),
}
}
#[test]
fn test_produce_in_parallel_returns_correct_values() {
let producer_config = get_default_producer_config();
match produce_in_parallel_type_erased(
2,
"test-topic",
&producer_config,
test_sync_worker_function_returns_some,
Vec::new(),
) {
Ok(res) => {
for val in res {
let val = val.downcast::<Option<i32>>().unwrap();
assert!(val.is_some());
let val = val.unwrap();
assert_eq!(val, 1337);
}
}
Err(_) => assert!(false),
}
}
}