use std::sync::Arc;
use async_trait::async_trait;
use futures::{StreamExt};
use futures::future::BoxFuture;
use futures::stream::BoxStream;
use log::error;
use tokio::sync::{mpsc, Mutex};
use tokio::task::{JoinSet};
use crate::tokio::step::{AsyncStep, DeciderCallback, StepResult};
use crate::tokio::step::parallel_step_builder::AsyncParallelStepBuilderTrait;
use crate::tokio::step::step_builder::AsyncStepBuilderTrait;
/// Number of processed items a worker buffers before handing them to the writer.
const DEFAULT_CHUNK_SIZE: usize = 1000;
/// Number of worker tasks a freshly created builder starts with.
const DEFAULT_WORKERS_SIZE: usize = 1;
/// Thread-safe async callback: takes one `I` and yields a `'static` future resolving to `O`.
type DynParamAsyncCallback<I, O> = dyn Fn(I) -> BoxFuture<'static, O> + Send + Sync;
/// Boxed per-item processor, turning one read item into one output item.
type ProcessorCallback<I, O> = Box<DynParamAsyncCallback<I, O>>;
/// Boxed async factory producing the stream of items to process.
type ReaderCallback<I> = Box<dyn Fn() -> BoxFuture<'static, BoxStream<'static, I>> + Send + Sync>;
/// Builder operations for a read → process → write step.
///
/// `I` is the item type yielded by the reader and consumed by the processor; `O` is the
/// processor's output, delivered to the writer in `Vec<O>` chunks.
#[async_trait]
pub trait ComplexStepBuilderTrait<I: Sized, O: Sized> {
    /// Sets the callback that produces the input stream.
    fn reader(self, reader: ReaderCallback<I>) -> Self;

    /// Sets the callback applied to every item read.
    fn processor(self, processor: ProcessorCallback<I, O>) -> Self;

    /// Sets the callback that receives chunks of processed items.
    fn writer(self, writer: Box<DynParamAsyncCallback<Vec<O>, ()>>) -> Self;

    /// Sets how many processed items are batched into one writer call.
    fn chunk_size(self, chunk_size: usize) -> Self;
}
#[async_trait]
impl<I: Sized + 'static, O: Sized + 'static> ComplexStepBuilderTrait<I, O> for AsyncComplexStepBuilder<I, O> {
    /// Stores the reader callback, replacing any previously configured one.
    fn reader(mut self, reader: ReaderCallback<I>) -> Self {
        self.reader = Some(reader);
        self
    }

    /// Stores the processor callback, replacing any previously configured one.
    fn processor(mut self, processor: ProcessorCallback<I, O>) -> Self {
        self.processor = Some(processor);
        self
    }

    /// Stores the writer callback, replacing any previously configured one.
    fn writer(mut self, writer: Box<DynParamAsyncCallback<Vec<O>, ()>>) -> Self {
        self.writer = Some(writer);
        self
    }

    /// Overrides the default number of items per writer call.
    fn chunk_size(mut self, chunk_size: usize) -> Self {
        self.chunk_size = Some(chunk_size);
        self
    }
}
/// Builder for an [`AsyncStep`] that reads items, processes them on a pool of worker
/// tasks and writes the results in chunks.
///
/// The reader, processor and writer start out unset; `build` calls `validate`, which
/// panics if any of them is still missing.
pub struct AsyncComplexStepBuilder<I: Sized, O: Sized> {
    /// Produces the stream of input items.
    reader: Option<ReaderCallback<I>>,
    /// Turns each input item into an output item.
    processor: Option<ProcessorCallback<I, O>>,
    /// Receives chunks of output items.
    writer: Option<Box<DynParamAsyncCallback<Vec<O>, ()>>>,
    /// Items per writer call; `None` falls back to the default chunk size.
    chunk_size: Option<usize>,
    /// Number of worker tasks items are distributed across.
    workers: usize,
    /// The step being assembled (name, callback, decider, tolerance flag).
    step: AsyncStep,
}
impl<I: Sized + Send + 'static, O: Sized + Send + 'static> AsyncStepBuilderTrait for AsyncComplexStepBuilder<I, O>
where
Self: Sized,
{
fn decider(self, decider: DeciderCallback) -> Self {
AsyncComplexStepBuilder {
step: AsyncStep {
decider: Some(decider),
..self.step
},
..self
}
}
fn throw_tolerant(self) -> Self {
AsyncComplexStepBuilder {
step: AsyncStep {
throw_tolerant: Some(true),
..self.step
},
..self
}
}
#[inline]
fn get(name: String) -> Self {
AsyncComplexStepBuilder {
reader: None,
processor: None,
writer: None,
chunk_size: None,
workers: DEFAULT_WORKERS_SIZE,
step: AsyncStep {
name,
callback: None,
decider: None,
throw_tolerant: None,
},
}
}
fn validate(self) -> Self {
if self.step.name.is_empty() {
panic!("Name is required");
}
if self.reader.is_none() {
panic!("Reader is required");
}
if self.processor.is_none() {
panic!("Processor is required");
}
if self.writer.is_none() {
panic!("Writer is required");
}
return self;
}
fn build(self) -> AsyncStep {
let mut current_self = self.validate();
let reader = Arc::new(current_self.reader.unwrap());
let processor = Arc::new(current_self.processor.unwrap());
let writer = Arc::new(current_self.writer.unwrap());
let throw_tolerant = current_self.step.throw_tolerant.unwrap_or(false);
let step_name = Arc::new(current_self.step.name.clone());
current_self.step.callback = Some(Box::new(move || {
let reader = Box::pin(reader.clone());
let processor = processor.clone();
let writer = writer.clone();
let chunk_size = current_self.chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE);
let throw_tolerant = throw_tolerant.clone();
let step_name = step_name.clone();
return Box::pin(async move {
let reader = Arc::clone(&reader);
let processor = Arc::clone(&processor);
let writer = Arc::clone(&writer);
let throw_tolerant = throw_tolerant.clone();
let step_name = Arc::clone(&step_name);
let mut join_workers = JoinSet::new();
let mut channels = Vec::new();
let step_result: Arc<Mutex<StepResult>> = Arc::new(Mutex::new(Ok(())));
for _ in 0..current_self.workers {
let (sender, receiver) = mpsc::channel::<I>(16);
let processor = Arc::clone(&processor);
let writer = Arc::clone(&writer);
let mut receiver = receiver;
let throw_tolerant = throw_tolerant.clone();
let step_result = Arc::clone(&step_result);
let step_name = Arc::clone(&step_name);
join_workers.spawn(async move {
let step_result = Arc::clone(&step_result);
let mut vec: Vec<O> = Vec::new();
let step_name = Arc::clone(&step_name);
while let Some(data) = receiver.recv().await {
let output = tokio::spawn(processor(data)).await;
if let Err(err) = output {
let mut step_result = step_result.lock().await;
*step_result = Err(err);
if !throw_tolerant {
panic!("step {}: Error to processing data", step_name);
} else {
error!("step {}: Error to processing data", step_name);
continue;
}
}
let output = output.unwrap();
vec.push(output);
if vec.len() >= chunk_size {
let vec_to_write = std::mem::take(&mut vec);
let writer_result = tokio::spawn(writer(vec_to_write)).await;
vec.clear();
if let Err(err) = writer_result {
if !throw_tolerant {
let mut error = step_result.lock().await;
*error = Err(err);
panic!("step {}: Error to writing data", step_name);
} else {
error!("step {}: Error to writing data", step_name);
}
}
}
}
if !vec.is_empty() {
let vec_to_write = std::mem::take(&mut vec);
let writer_result = tokio::spawn(writer(vec_to_write)).await;
if let Err(err) = writer_result {
if !throw_tolerant {
let mut step_result = step_result.lock().await;
*step_result = Err(err);
panic!("step {}: Error to writing data", step_name);
} else {
error!("step {}: Error to writing data", step_name);
}
}
}
});
channels.push(sender);
}
let mut iterator = reader().await;
let mut current_channel: usize = 0;
while let Some(data) = iterator.next().await {
if !throw_tolerant {
let step_result = Arc::clone(&step_result);
let step_result = step_result.lock().await;
if step_result.is_err() {
join_workers.abort_all();
panic!("step {}: Error to processing data", step_name);
}
}
let sender = &mut channels[current_channel];
sender.send(data).await.unwrap();
if current_channel == current_self.workers - 1 {
current_channel = 0;
} else {
current_channel += 1;
}
}
drop(channels);
while let Some(task_result) = join_workers.join_next().await {
if let Err(err) = task_result {
if !throw_tolerant {
return Err(err);
}
join_workers.abort_all();
}
}
let step_result = Arc::try_unwrap(step_result).unwrap();
let step_result = step_result.into_inner();
return step_result;
});
}));
return current_self.step;
}
}
impl<I: Sized + Send + 'static + Sync, O: Sized + Send + 'static + Sync> AsyncParallelStepBuilderTrait for AsyncComplexStepBuilder<I, O>
where
    Self: Sized,
{
    /// Sets how many worker tasks the built step distributes items across.
    fn workers(mut self, workers: usize) -> Self {
        self.workers = workers;
        self
    }
}
pub fn get<I: Sized + 'static + Send, O: Sized + 'static + Send + Clone + Sync>(name: String) -> AsyncComplexStepBuilder<I, O> {
AsyncComplexStepBuilder::get(name)
}