use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::any::Any;
use async_trait::async_trait;
use futures::future;
use serde_json::Value;
use log::warn;
use crate::base::{BaseNode, Node, SharedState, Action};
use crate::flow::{Flow, BatchFlow};
use crate::async_node::AsyncNodeTrait;
use crate::error::{Error, Result};
#[derive(Clone)]
pub struct AsyncFlow {
flow: Flow,
base: BaseNode,
}
impl AsyncFlow {
pub fn new(start: Arc<dyn Node>) -> Self {
Self {
flow: Flow::new(start),
base: BaseNode::new(),
}
}
fn is_async(&self, node: &Arc<dyn Node>) -> bool {
let type_id = node.type_id();
let async_node_ids = [
std::any::TypeId::of::<dyn AsyncNodeTrait>(),
];
async_node_ids.contains(&type_id)
}
pub async fn _orch_async(&self, shared: &mut SharedState, params: Option<HashMap<String, Value>>) -> Result<()> {
let mut curr = self.flow.start.clone();
let params = params.unwrap_or_else(|| {
self.base.params().read().unwrap().clone()
});
curr.set_params(params);
while let Some(node) = curr.clone().into() {
let action = if self.is_async(&node) {
Err(Error::InvalidOperation("Dynamic dispatch for async nodes not implemented".into()))?
} else {
node._run(shared)?
};
curr = match self.flow.get_next_node(node, action) {
Some(next) => next,
None => break,
};
}
Ok(())
}
}
impl Node for AsyncFlow {
fn params(&self) -> Arc<RwLock<HashMap<String, Value>>> {
self.base.params()
}
fn successors(&self) -> Arc<RwLock<HashMap<String, Arc<dyn Node>>>> {
self.base.successors()
}
fn set_params(&self, params: HashMap<String, Value>) {
let params_lock = self.params();
let mut p = params_lock.write().unwrap();
*p = params;
}
fn add_successor(&self, node: Arc<dyn Node>, action: &str) -> Result<Arc<dyn Node>> {
let successors_lock = self.successors();
let mut successors = successors_lock.write().unwrap();
if successors.contains_key(action) {
warn!("Overwriting successor for action '{}'", action);
}
successors.insert(action.to_string(), node.clone());
Ok(node)
}
fn prep(&self, _shared: &mut SharedState) -> Result<Value> {
Err(Error::InvalidOperation("Use prep_async".into()))
}
fn exec(&self, _prep_res: Value) -> Result<Value> {
Err(Error::InvalidOperation("AsyncFlow can't exec".into()))
}
fn post(&self, _shared: &mut SharedState, _prep_res: Value, _exec_res: Value) -> Result<Action> {
Err(Error::InvalidOperation("Use post_async".into()))
}
fn _run(&self, _shared: &mut SharedState) -> Result<Action> {
Err(Error::InvalidOperation("Use run_async".into()))
}
}
#[async_trait]
impl AsyncNodeTrait for AsyncFlow {
async fn _exec_async(&self, _prep_res: Value) -> Result<Value> {
Err(Error::InvalidOperation("AsyncFlow can't exec".into()))
}
async fn _run_async(&self, shared: &mut SharedState) -> Result<Action> {
let prep_res = self.prep_async(shared).await?;
self._orch_async(shared, None).await?;
self.post_async(shared, prep_res, Value::Null).await
}
}
#[derive(Clone)]
pub struct AsyncBatchFlow {
flow: AsyncFlow,
batch_flow: BatchFlow,
}
impl AsyncBatchFlow {
pub fn new(start: Arc<dyn Node>) -> Self {
Self {
flow: AsyncFlow::new(start.clone()),
batch_flow: BatchFlow::new(start),
}
}
}
impl Node for AsyncBatchFlow {
fn params(&self) -> Arc<RwLock<HashMap<String, Value>>> {
self.flow.params()
}
fn successors(&self) -> Arc<RwLock<HashMap<String, Arc<dyn Node>>>> {
self.flow.successors()
}
fn set_params(&self, params: HashMap<String, Value>) {
self.flow.set_params(params);
}
fn add_successor(&self, node: Arc<dyn Node>, action: &str) -> Result<Arc<dyn Node>> {
self.flow.add_successor(node, action)
}
fn prep(&self, _shared: &mut SharedState) -> Result<Value> {
Err(Error::InvalidOperation("Use prep_async".into()))
}
fn exec(&self, _prep_res: Value) -> Result<Value> {
Err(Error::InvalidOperation("AsyncBatchFlow can't exec".into()))
}
fn post(&self, _shared: &mut SharedState, _prep_res: Value, _exec_res: Value) -> Result<Action> {
Err(Error::InvalidOperation("Use post_async".into()))
}
fn _run(&self, _shared: &mut SharedState) -> Result<Action> {
Err(Error::InvalidOperation("Use run_async".into()))
}
}
#[async_trait]
impl AsyncNodeTrait for AsyncBatchFlow {
async fn _exec_async(&self, _prep_res: Value) -> Result<Value> {
Err(Error::InvalidOperation("AsyncBatchFlow can't exec".into()))
}
async fn _run_async(&self, shared: &mut SharedState) -> Result<Action> {
let prep_res = self.prep_async(shared).await?;
let batch_params = match &prep_res {
Value::Array(items) => items
.iter()
.map(|v| {
if let Value::Object(map) = v {
let map: HashMap<String, Value> = map
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
Ok(map)
} else {
Err(Error::NodeExecution("AsyncBatchFlow prep should return array of objects".into()))
}
})
.collect::<Result<Vec<_>>>()?,
Value::Null => vec![],
_ => return Err(Error::NodeExecution("AsyncBatchFlow prep should return array or null".into())),
};
let flow_params = self.flow.params().read().unwrap().clone();
for mut bp in batch_params {
for (k, v) in flow_params.clone() {
bp.entry(k).or_insert(v);
}
self.flow._orch_async(shared, Some(bp)).await?;
}
self.post_async(shared, prep_res, Value::Null).await
}
}
#[derive(Clone)]
pub struct AsyncParallelBatchFlow {
batch_flow: AsyncBatchFlow,
}
impl AsyncParallelBatchFlow {
pub fn new(start: Arc<dyn Node>) -> Self {
Self {
batch_flow: AsyncBatchFlow::new(start),
}
}
}
impl Node for AsyncParallelBatchFlow {
fn params(&self) -> Arc<RwLock<HashMap<String, Value>>> {
self.batch_flow.params()
}
fn successors(&self) -> Arc<RwLock<HashMap<String, Arc<dyn Node>>>> {
self.batch_flow.successors()
}
fn set_params(&self, params: HashMap<String, Value>) {
self.batch_flow.set_params(params);
}
fn add_successor(&self, node: Arc<dyn Node>, action: &str) -> Result<Arc<dyn Node>> {
self.batch_flow.add_successor(node, action)
}
fn prep(&self, _shared: &mut SharedState) -> Result<Value> {
Err(Error::InvalidOperation("Use prep_async".into()))
}
fn exec(&self, _prep_res: Value) -> Result<Value> {
Err(Error::InvalidOperation("AsyncParallelBatchFlow can't exec".into()))
}
fn post(&self, _shared: &mut SharedState, _prep_res: Value, _exec_res: Value) -> Result<Action> {
Err(Error::InvalidOperation("Use post_async".into()))
}
fn _run(&self, _shared: &mut SharedState) -> Result<Action> {
Err(Error::InvalidOperation("Use run_async".into()))
}
}
#[async_trait]
impl AsyncNodeTrait for AsyncParallelBatchFlow {
async fn prep_async(&self, shared: &mut SharedState) -> Result<Value> {
self.batch_flow.prep_async(shared).await
}
async fn post_async(&self, shared: &mut SharedState, prep_res: Value, exec_res: Value) -> Result<Action> {
self.batch_flow.post_async(shared, prep_res, exec_res).await
}
async fn _exec_async(&self, _prep_res: Value) -> Result<Value> {
Err(Error::InvalidOperation("AsyncParallelBatchFlow can't exec".into()))
}
async fn _run_async(&self, shared: &mut SharedState) -> Result<Action> {
let prep_res = self.prep_async(shared).await?;
let batch_params = match &prep_res {
Value::Array(items) => items
.iter()
.map(|v| {
if let Value::Object(map) = v {
let map: HashMap<String, Value> = map
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
Ok(map)
} else {
Err(Error::NodeExecution("AsyncParallelBatchFlow prep should return array of objects".into()))
}
})
.collect::<Result<Vec<_>>>()?,
Value::Null => vec![],
_ => return Err(Error::NodeExecution("AsyncParallelBatchFlow prep should return array or null".into())),
};
if batch_params.is_empty() {
return self.post_async(shared, prep_res, Value::Null).await;
}
let flow_params = self.batch_flow.params().read().unwrap().clone();
let futures = batch_params
.into_iter()
.map(|mut bp| {
let flow = self.batch_flow.flow.clone();
let mut shared_clone = shared.clone();
let flow_params = flow_params.clone();
for (k, v) in flow_params {
bp.entry(k).or_insert(v);
}
async move { flow._orch_async(&mut shared_clone, Some(bp)).await }
})
.collect::<Vec<_>>();
let results = future::join_all(futures).await;
for result in results {
result?;
}
self.post_async(shared, prep_res, Value::Null).await
}
}