use crate::node::{InputStreams, Node, NodeExecutionError, OutputStreams};
use crate::nodes::common::BaseNode;
use async_trait::async_trait;
use std::any::Any;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio_stream::{StreamExt, wrappers::ReceiverStream};
/// Fallible asynchronous operation over type-erased values.
///
/// `TryCatchNode` calls [`TryFunction::apply`] on each item read from its `"in"` stream.
/// Both the success value and the error payload are `Arc<dyn Any + Send + Sync>`, so
/// implementations and consumers must agree on the concrete types out of band.
#[async_trait]
pub trait TryFunction: Any + Send + Sync {
/// Runs the operation on `value`.
///
/// Returns `Ok` with the produced value, or `Err` with a type-erased error payload.
async fn apply(
&self,
value: Arc<dyn Any + Send + Sync>,
) -> Result<Arc<dyn Any + Send + Sync>, Arc<dyn Any + Send + Sync>>;
}
/// Asynchronous error handler over type-erased values.
///
/// `TryCatchNode` calls [`CatchFunction::apply`] with the error payload returned by the
/// configured `TryFunction` when that function fails.
#[async_trait]
pub trait CatchFunction: Any + Send + Sync {
/// Handles `error`.
///
/// Returns `Ok` with a recovery value, or `Err` with a type-erased error payload when the
/// handler itself fails.
async fn apply(
&self,
error: Arc<dyn Any + Send + Sync>,
) -> Result<Arc<dyn Any + Send + Sync>, Arc<dyn Any + Send + Sync>>;
}
/// Shared, dynamically dispatched handle to a [`TryFunction`] implementation.
pub type TryConfig = Arc<dyn TryFunction>;
/// Shared, dynamically dispatched handle to a [`CatchFunction`] implementation.
pub type CatchConfig = Arc<dyn CatchFunction>;
/// Adapter that exposes a plain async callable through the [`TryFunction`] trait.
struct TryFunctionWrapper<F> {
    /// The wrapped callable; invoked once per call to `apply`.
    function: F,
}

#[async_trait]
impl<F, Fut> TryFunction for TryFunctionWrapper<F>
where
    Fut: std::future::Future<Output = Result<Arc<dyn Any + Send + Sync>, Arc<dyn Any + Send + Sync>>>
        + Send,
    F: Fn(Arc<dyn Any + Send + Sync>) -> Fut + Send + Sync + 'static,
{
    /// Forwards `value` to the wrapped callable and awaits the future it returns.
    async fn apply(
        &self,
        value: Arc<dyn Any + Send + Sync>,
    ) -> Result<Arc<dyn Any + Send + Sync>, Arc<dyn Any + Send + Sync>> {
        let pending = (self.function)(value);
        pending.await
    }
}
/// Adapter that exposes a plain async callable through the [`CatchFunction`] trait.
struct CatchFunctionWrapper<F> {
    /// The wrapped callable; invoked once per call to `apply`.
    function: F,
}

#[async_trait]
impl<F, Fut> CatchFunction for CatchFunctionWrapper<F>
where
    Fut: std::future::Future<Output = Result<Arc<dyn Any + Send + Sync>, Arc<dyn Any + Send + Sync>>>
        + Send,
    F: Fn(Arc<dyn Any + Send + Sync>) -> Fut + Send + Sync + 'static,
{
    /// Forwards `error` to the wrapped callable and awaits the future it returns.
    async fn apply(
        &self,
        error: Arc<dyn Any + Send + Sync>,
    ) -> Result<Arc<dyn Any + Send + Sync>, Arc<dyn Any + Send + Sync>> {
        let pending = (self.function)(error);
        pending.await
    }
}
/// Builds a [`TryConfig`] from an async callable.
///
/// The future returned by `function` is boxed and pinned so that callables with different
/// future types all end up behind the same `Arc<dyn TryFunction>` handle.
pub fn try_config<F, Fut>(function: F) -> TryConfig
where
    F: Fn(Arc<dyn Any + Send + Sync>) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<Arc<dyn Any + Send + Sync>, Arc<dyn Any + Send + Sync>>>
        + Send
        + 'static,
{
    // Local shorthand for the type-erased future handed to the wrapper.
    type BoxedOutcome = Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<Arc<dyn Any + Send + Sync>, Arc<dyn Any + Send + Sync>>,
                > + Send,
        >,
    >;

    let boxed_call =
        move |value: Arc<dyn Any + Send + Sync>| -> BoxedOutcome { Box::pin(function(value)) };

    Arc::new(TryFunctionWrapper {
        function: boxed_call,
    })
}
/// Builds a [`CatchConfig`] from an async callable.
///
/// The future returned by `function` is boxed and pinned so that callables with different
/// future types all end up behind the same `Arc<dyn CatchFunction>` handle.
pub fn catch_config<F, Fut>(function: F) -> CatchConfig
where
    F: Fn(Arc<dyn Any + Send + Sync>) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<Arc<dyn Any + Send + Sync>, Arc<dyn Any + Send + Sync>>>
        + Send
        + 'static,
{
    // Local shorthand for the type-erased future handed to the wrapper.
    type BoxedOutcome = Pin<
        Box<
            dyn std::future::Future<
                    Output = Result<Arc<dyn Any + Send + Sync>, Arc<dyn Any + Send + Sync>>,
                > + Send,
        >,
    >;

    let boxed_call =
        move |error: Arc<dyn Any + Send + Sync>| -> BoxedOutcome { Box::pin(function(error)) };

    Arc::new(CatchFunctionWrapper {
        function: boxed_call,
    })
}
/// Identifies which input port a stream item arrived on.
///
/// `TryCatchNode::execute` pairs every incoming item with one of these tags before
/// consuming it.
#[derive(Debug, PartialEq)]
enum InputPort {
/// Item read from the `"in"` stream.
In,
/// Item read from the `"try"` stream.
Try,
/// Item read from the `"catch"` stream.
Catch,
}
/// Node that runs a configurable try function on each input item and, on failure, an
/// optional catch function.
///
/// Input ports: `"configuration"`, `"in"`, `"try"`, `"catch"`. Output ports: `"out"`, `"error"`.
pub struct TryCatchNode {
    /// Name and port bookkeeping shared with other node types.
    pub(crate) base: BaseNode,
    /// Most recent try function received on the `"try"` port, if any.
    current_try_config: Arc<Mutex<Option<TryConfig>>>,
    /// Most recent catch function received on the `"catch"` port, if any.
    current_catch_config: Arc<Mutex<Option<CatchConfig>>>,
}

impl TryCatchNode {
    /// Creates a node called `name` with no try or catch function configured.
    pub fn new(name: String) -> Self {
        let input_ports: Vec<String> = ["configuration", "in", "try", "catch"]
            .iter()
            .map(|port| String::from(*port))
            .collect();
        let output_ports: Vec<String> = ["out", "error"]
            .iter()
            .map(|port| String::from(*port))
            .collect();

        Self {
            base: BaseNode::new(name, input_ports, output_ports),
            current_try_config: Arc::new(Mutex::new(None)),
            current_catch_config: Arc::new(Mutex::new(None)),
        }
    }

    /// Reports whether a try function is currently stored.
    ///
    /// The check is non-blocking: if the lock cannot be acquired immediately the answer is
    /// `false`, even when a function is in fact stored.
    pub fn has_try_config(&self) -> bool {
        match self.current_try_config.try_lock() {
            Ok(slot) => slot.is_some(),
            Err(_) => false,
        }
    }

    /// Reports whether a catch function is currently stored.
    ///
    /// The check is non-blocking: if the lock cannot be acquired immediately the answer is
    /// `false`, even when a function is in fact stored.
    pub fn has_catch_config(&self) -> bool {
        match self.current_catch_config.try_lock() {
            Ok(slot) => slot.is_some(),
            Err(_) => false,
        }
    }
}
#[async_trait]
#[allow(clippy::type_complexity)]
impl Node for TryCatchNode {
    /// Returns the node's name.
    fn name(&self) -> &str {
        self.base.name()
    }

    /// Renames the node.
    fn set_name(&mut self, name: &str) {
        self.base.set_name(name);
    }

    /// Names of the input ports, as registered in `TryCatchNode::new`.
    fn input_port_names(&self) -> &[String] {
        self.base.input_port_names()
    }

    /// Names of the output ports, as registered in `TryCatchNode::new`.
    fn output_port_names(&self) -> &[String] {
        self.base.output_port_names()
    }

    /// Whether `name` is one of this node's input ports.
    fn has_input_port(&self, name: &str) -> bool {
        self.base.has_input_port(name)
    }

    /// Whether `name` is one of this node's output ports.
    fn has_output_port(&self, name: &str) -> bool {
        self.base.has_output_port(name)
    }

    /// Starts processing and returns the `"out"` and `"error"` output streams.
    ///
    /// A background task is spawned that multiplexes the three required inputs:
    ///
    /// * `"try"` / `"catch"` items replace the active try / catch function. An item is
    ///   accepted when its concrete type is `Arc<Arc<dyn _>>` or `Arc<dyn _>` of the matching
    ///   trait; anything else is ignored. Accepted functions are also stored on the node so
    ///   `has_try_config` / `has_catch_config` can observe them.
    /// * Each `"in"` item is passed to the active try function. `Ok` goes to `"out"`. `Err`
    ///   is passed to the active catch function if there is one (its `Ok` goes to `"out"`,
    ///   its `Err` to `"error"`), otherwise straight to `"error"`. With no try function
    ///   configured, a `String` error is emitted on `"error"`.
    ///
    /// The task ends, closing both outputs, once all three input streams are exhausted.
    /// The `"configuration"` input is removed from `inputs` but never read.
    ///
    /// # Errors
    ///
    /// Fails if the `"in"`, `"try"` or `"catch"` input stream is missing.
    fn execute(
        &self,
        mut inputs: InputStreams,
    ) -> Pin<
        Box<dyn std::future::Future<Output = Result<OutputStreams, NodeExecutionError>> + Send + '_>,
    > {
        // Pulls an `Arc<T>` handle out of a type-erased item. The doubly wrapped
        // `Arc<Arc<T>>` payload is tried first, then the plain `Arc<T>` payload.
        fn downcast_config<T>(item: &Arc<dyn Any + Send + Sync>) -> Option<Arc<T>>
        where
            T: ?Sized + Send + Sync + 'static,
        {
            if let Ok(nested) = Arc::clone(item).downcast::<Arc<Arc<T>>>() {
                return Some(Arc::clone(&**nested));
            }
            match Arc::clone(item).downcast::<Arc<T>>() {
                Ok(direct) => Some(Arc::clone(&*direct)),
                Err(_) => None,
            }
        }

        let try_config_state = Arc::clone(&self.current_try_config);
        let catch_config_state = Arc::clone(&self.current_catch_config);

        Box::pin(async move {
            let _config_stream = inputs.remove("configuration");
            let in_stream = inputs.remove("in").ok_or("Missing 'in' input")?;
            let try_stream = inputs.remove("try").ok_or("Missing 'try' input")?;
            let catch_stream = inputs.remove("catch").ok_or("Missing 'catch' input")?;

            let (out_tx, out_rx) = tokio::sync::mpsc::channel(10);
            let (error_tx, error_rx) = tokio::sync::mpsc::channel(10);

            tokio::spawn(async move {
                let mut in_stream = in_stream.map(|item| (InputPort::In, item));
                let mut try_stream = try_stream.map(|item| (InputPort::Try, item));
                let mut catch_stream = catch_stream.map(|item| (InputPort::Catch, item));

                let mut current_try_config: Option<TryConfig> = None;
                let mut current_catch_config: Option<CatchConfig> = None;

                // A stream that has yielded `None` is never polled again. Its branch is
                // disabled instead, so an exhausted stream cannot make the loop spin and no
                // item is ever read only to be thrown away.
                let mut in_open = true;
                let mut try_open = true;
                let mut catch_open = true;

                loop {
                    tokio::select! {
                        try_result = try_stream.next(), if try_open => {
                            match try_result {
                                Some((_, item)) => {
                                    if let Some(config) = downcast_config::<dyn TryFunction>(&item) {
                                        *try_config_state.lock().await = Some(Arc::clone(&config));
                                        current_try_config = Some(config);
                                    }
                                }
                                None => try_open = false,
                            }
                        }
                        catch_result = catch_stream.next(), if catch_open => {
                            match catch_result {
                                Some((_, item)) => {
                                    if let Some(config) = downcast_config::<dyn CatchFunction>(&item) {
                                        *catch_config_state.lock().await = Some(Arc::clone(&config));
                                        current_catch_config = Some(config);
                                    }
                                }
                                None => catch_open = false,
                            }
                        }
                        in_result = in_stream.next(), if in_open => {
                            match in_result {
                                Some((_, item)) => {
                                    if let Some(try_fn) = &current_try_config {
                                        match try_fn.apply(item).await {
                                            Ok(result) => {
                                                // Send failures only mean the consumer went away.
                                                let _ = out_tx.send(result).await;
                                            }
                                            Err(error) => {
                                                // Without a catch function the error passes through unchanged.
                                                let handled = match &current_catch_config {
                                                    Some(catch_fn) => catch_fn.apply(error).await,
                                                    None => Err(error),
                                                };
                                                match handled {
                                                    Ok(result) => {
                                                        let _ = out_tx.send(result).await;
                                                    }
                                                    Err(unhandled) => {
                                                        let _ = error_tx.send(unhandled).await;
                                                    }
                                                }
                                            }
                                        }
                                    } else {
                                        let error_msg = Arc::new("No try function configured".to_string())
                                            as Arc<dyn Any + Send + Sync>;
                                        let _ = error_tx.send(error_msg).await;
                                    }
                                }
                                None => in_open = false,
                            }
                        }
                        // Every branch is disabled: all three inputs are exhausted.
                        else => break,
                    }
                }
            });

            let mut outputs = HashMap::new();
            outputs.insert(
                "out".to_string(),
                Box::pin(ReceiverStream::new(out_rx))
                    as Pin<Box<dyn tokio_stream::Stream<Item = Arc<dyn Any + Send + Sync>> + Send>>,
            );
            outputs.insert(
                "error".to_string(),
                Box::pin(ReceiverStream::new(error_rx))
                    as Pin<Box<dyn tokio_stream::Stream<Item = Arc<dyn Any + Send + Sync>> + Send>>,
            );
            Ok(outputs)
        })
    }
}