use crate::step::Step;
use crate::Context;
use crate::DataResult;
use async_channel::{Receiver, Sender};
use async_trait::async_trait;
use smol::stream::StreamExt;
use serde::Deserialize;
use serde_json::Value;
use std::io;
use uuid::Uuid;
/// Step that either fans out matching incoming contexts or, when nothing
/// arrives on its input channel, creates contexts from scratch.
///
/// Built from configuration: unknown keys are rejected and any missing key
/// takes its value from [`Generator::default`].
#[derive(Debug, Deserialize, Clone)]
#[serde(default, deny_unknown_fields)]
pub struct Generator {
    /// Step name; may also be written as `alias` in configuration.
    #[serde(alias = "alias")]
    pub name: String,
    /// Data type this step acts on (compared with each incoming context's
    /// input); may also be written as `data`.
    #[serde(alias = "data")]
    pub data_type: String,
    /// How many contexts are emitted per matching input, or in total when no
    /// input arrives; may also be written as `batch` or `size`.
    #[serde(alias = "batch", alias = "size")]
    pub record_limit: usize,
    /// Channel incoming contexts are read from; wired at runtime, never
    /// deserialized.
    #[serde(skip)]
    pub receiver: Option<Receiver<Context>>,
    /// Channel produced contexts are written to; wired at runtime, never
    /// deserialized.
    #[serde(skip)]
    pub sender: Option<Sender<Context>>,
}
impl Default for Generator {
    /// A generator named after a freshly generated v4 UUID (simple form),
    /// acting on `DataResult::OK` data, emitting a single record and with no
    /// channels wired yet.
    fn default() -> Self {
        let generated_name = Uuid::new_v4().simple().to_string();
        Self {
            name: generated_name,
            data_type: DataResult::OK.to_string(),
            record_limit: 1,
            receiver: None,
            sender: None,
        }
    }
}
#[async_trait]
impl Step for Generator {
    /// Sets the channel incoming contexts are read from.
    fn set_receiver(&mut self, receiver: Receiver<Context>) {
        self.receiver = Some(receiver);
    }
    /// Returns the input channel, if one has been set.
    fn receiver(&self) -> Option<&Receiver<Context>> {
        self.receiver.as_ref()
    }
    /// Sets the channel produced contexts are written to.
    fn set_sender(&mut self, sender: Sender<Context>) {
        self.sender = Some(sender);
    }
    /// Returns the output channel, if one has been set.
    fn sender(&self) -> Option<&Sender<Context>> {
        self.sender.as_ref()
    }
    /// Generates contexts and pushes them into the output channel.
    ///
    /// For every context read from the input stream:
    /// * if its input is not of type `data_type`, it is forwarded unchanged;
    /// * otherwise it is cloned `record_limit` times, and each clone records
    ///   the received input as this step's result before being sent.
    ///
    /// If the input stream ends without yielding any context, `record_limit`
    /// brand-new contexts carrying `DataResult::Ok(Value::Null)` are sent
    /// instead.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok(())`.
    #[instrument(name = "generator::exec",
        skip(self),
        fields(name=self.name,
        data_type=self.data_type,
        record_limit=self.record_limit,
    ))]
    async fn exec(&self) -> io::Result<()> {
        info!("Start generating data...");
        let mut receiver_stream = self.receive().await;
        let mut has_data_been_received = false;
        let record_limit = self.record_limit;

        while let Some(context_received) = receiver_stream.next().await {
            // Any received context, matching or not, disables the
            // "generate from nothing" fallback below.
            has_data_been_received = true;

            if !context_received.input().is_type(self.data_type.as_ref()) {
                trace!("Handles only this data type");
                self.send(&context_received).await;
                continue;
            }

            for _ in 0..record_limit {
                let mut context = context_received.clone();
                context.insert_step_result(self.name(), context.input());
                self.send(&context).await;
            }
        }

        // Nothing came in: this generator is the source of the pipeline.
        if !has_data_been_received {
            for _ in 0..record_limit {
                let context = Context::new(self.name(), DataResult::Ok(Value::Null));
                self.send(&context).await;
            }
        }

        trace!("Stops generating data and sending context in the channel");
        Ok(())
    }
    /// Returns a copy of the step name.
    fn name(&self) -> String {
        self.name.clone()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use macro_rules_attribute::apply;
    use serde_json::Value;
    use smol_macros::test;
    use std::io::{Error, ErrorKind};
    use std::thread;

    /// Wires `step` to fresh unbounded channels, feeds it `input` from a
    /// separate thread, runs it to completion and returns the first context
    /// it emitted.
    async fn run_with_single_input(mut step: Generator, input: Context) -> Context {
        let (input_tx, input_rx) = async_channel::unbounded();
        let (output_tx, output_rx) = async_channel::unbounded();
        // The sender is moved into the thread and dropped when the thread
        // ends, so the input channel closes and `exec` can finish.
        thread::spawn(move || {
            input_tx.try_send(input).unwrap();
        });
        step.receiver = Some(input_rx);
        step.sender = Some(output_tx);
        step.exec().await.unwrap();
        output_rx.recv().await.unwrap()
    }

    #[apply(test!)]
    async fn exec_with_different_data_result_type() {
        let payload = serde_json::from_str(r#"{"field_1":"value_1"}"#).unwrap();
        let failure = Error::new(ErrorKind::InvalidData, "My error");
        let input = Context::new("before".to_string(), DataResult::Err((payload, failure)));
        let expected = input.clone();

        let emitted = run_with_single_input(Generator::default(), input).await;

        // An Err input is expected to be forwarded untouched by a generator
        // configured with the default data type.
        assert_eq!(expected, emitted);
    }

    #[apply(test!)]
    async fn exec_with_same_data_result_type() {
        let payload: Value = serde_json::from_str(r#"{"field_1":"value_1"}"#).unwrap();
        let input = Context::new("before".to_string(), DataResult::Ok(payload.clone()));
        let mut expected = input.clone();
        expected.insert_step_result("my_step".to_string(), DataResult::Ok(payload));

        let step = Generator {
            name: "my_step".to_string(),
            ..Default::default()
        };
        let emitted = run_with_single_input(step, input).await;

        assert_eq!(expected, emitted);
    }
}