use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use dataflow_rs::engine::error::DataflowError;
use dataflow_rs::engine::functions::AsyncFunctionHandler;
use dataflow_rs::engine::functions::config::FunctionConfig;
use dataflow_rs::engine::message::{Change, Message};
use datalogic_rs::DataLogic;
use serde::Deserialize;
use serde_json::Value;
use tokio::sync::RwLock;
use super::http_common::{get_nested, set_nested};
const ORION_META_PREFIX: &str = "_orion_";
const META_CALL_DEPTH: &str = "_orion_call_depth";
const META_CALL_CHAIN: &str = "_orion_call_chain";
#[derive(Debug, Deserialize)]
pub struct ChannelCallInput {
pub channel: String,
#[serde(default)]
pub channel_logic: Option<Value>,
#[serde(default)]
pub response_path: Option<String>,
#[serde(default)]
pub data: Option<Value>,
#[serde(default)]
pub data_logic: Option<Value>,
#[serde(default)]
pub timeout_ms: Option<u64>,
}
pub struct ChannelCallHandler {
pub engine: Arc<RwLock<Arc<dataflow_rs::Engine>>>,
pub max_call_depth: u32,
pub default_timeout_ms: u64,
}
#[async_trait]
impl AsyncFunctionHandler for ChannelCallHandler {
async fn execute(
&self,
message: &mut Message,
config: &FunctionConfig,
datalogic: Arc<DataLogic>,
) -> dataflow_rs::Result<(usize, Vec<Change>)> {
let input_value = match config {
FunctionConfig::Custom { input, .. } => input,
_ => {
return Err(DataflowError::Validation(
"Expected Custom config for channel_call".into(),
));
}
};
let input: ChannelCallInput =
serde_json::from_value(input_value.clone()).map_err(|e: serde_json::Error| {
DataflowError::Validation(format!("Invalid channel_call config: {e}"))
})?;
let target_channel = if let Some(ref logic) = input.channel_logic {
let context = message.get_context_arc();
let compiled = datalogic
.compile(logic)
.map_err(|e| DataflowError::LogicEvaluation(e.to_string()))?;
let result = datalogic
.evaluate(&compiled, context)
.map_err(|e| DataflowError::LogicEvaluation(e.to_string()))?;
result.as_str().map(|s| s.to_string()).ok_or_else(|| {
DataflowError::Validation("channel_logic must evaluate to a string".to_string())
})?
} else {
input.channel.clone()
};
if target_channel.is_empty() {
return Err(DataflowError::Validation(
"channel_call: target channel name must not be empty".into(),
));
}
let parent_depth = message
.metadata()
.get(META_CALL_DEPTH)
.and_then(|v| v.as_u64())
.unwrap_or(0);
let parent_chain: Vec<String> = message
.metadata()
.get(META_CALL_CHAIN)
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect()
})
.unwrap_or_default();
if parent_depth >= self.max_call_depth as u64 {
return Err(DataflowError::Validation(format!(
"channel_call: max call depth {} exceeded (chain: {})",
self.max_call_depth,
format_chain(&parent_chain, &target_channel),
)));
}
if parent_chain.contains(&target_channel) {
return Err(DataflowError::Validation(format!(
"channel_call: cycle detected: {}",
format_chain(&parent_chain, &target_channel),
)));
}
let call_data = if let Some(ref logic) = input.data_logic {
let context = message.get_context_arc();
let compiled = datalogic
.compile(logic)
.map_err(|e| DataflowError::LogicEvaluation(e.to_string()))?;
datalogic
.evaluate(&compiled, context)
.map_err(|e| DataflowError::LogicEvaluation(e.to_string()))?
} else if let Some(ref data) = input.data {
data.clone()
} else {
(*message.payload).clone()
};
let mut child_message = Message::from_value(&call_data);
if let Some(parent_meta) = message.metadata().as_object()
&& let Some(child_meta) = child_message.metadata_mut().as_object_mut()
{
for (k, v) in parent_meta {
if k != "channel" {
child_meta.insert(k.clone(), v.clone());
}
}
}
let child_depth = parent_depth + 1;
let mut child_chain = parent_chain;
child_chain.push(target_channel.clone());
child_message.metadata_mut()[META_CALL_DEPTH] = serde_json::json!(child_depth);
child_message.metadata_mut()[META_CALL_CHAIN] = serde_json::json!(child_chain);
let engine = crate::engine::acquire_engine_read(&self.engine).await;
let timeout = Duration::from_millis(input.timeout_ms.unwrap_or(self.default_timeout_ms));
let process_result = tokio::time::timeout(
timeout,
engine.process_message_for_channel(&target_channel, &mut child_message),
)
.await;
match process_result {
Ok(inner) => inner.map_err(|e| {
DataflowError::function_execution(
format!("channel_call to '{}' failed: {}", target_channel, e),
None,
)
})?,
Err(_) => {
return Err(DataflowError::Timeout(format!(
"channel_call to '{}' timed out after {}ms",
target_channel,
timeout.as_millis()
)));
}
}
if let Some(meta) = child_message.metadata_mut().as_object_mut() {
meta.retain(|k, _| !k.starts_with(ORION_META_PREFIX));
}
let mut changes = Vec::new();
let result_data = child_message.data().clone();
if let Some(ref response_path) = input.response_path {
let old_value = get_nested(&message.context, response_path);
set_nested(&mut message.context, response_path, result_data.clone());
message.invalidate_context_cache();
changes.push(Change {
path: Arc::from(response_path.as_str()),
old_value: Arc::new(old_value),
new_value: Arc::new(result_data),
});
} else {
let old_value = message.data().clone();
*message.data_mut() = result_data.clone();
message.invalidate_context_cache();
changes.push(Change {
path: Arc::from("data"),
old_value: Arc::new(old_value),
new_value: Arc::new(result_data),
});
}
Ok((200, changes))
}
}
fn format_chain(chain: &[String], target: &str) -> String {
let mut parts: Vec<&str> = chain.iter().map(|s| s.as_str()).collect();
parts.push(target);
parts.join(" -> ")
}