use super::flow_router::{ChannelFlowRouter, ScriptFlowRouter};
use super::manager::{ChannelManager, IncomingHandler};
use crate::flow::manager::{Flow, FlowManager, NodeKind};
use crate::flow::session::SessionStore;
use crate::message::Message;
use crate::node::{ChannelOrigin, NodeContext, NodeErr, NodeError, NodeOut, NodeType, Routing};
use async_trait::async_trait;
use channel_plugin::message::{ChannelMessage, MessageContent, MessageDirection};
use dashmap::DashMap;
use schemars::Schema;
use schemars::{JsonSchema, schema_for};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tracing::{error, info, warn};
/// A flow node bound to a named channel.
///
/// As a flow node it sends its input out through the channel; as a registry entry it
/// names the flow and node that incoming messages on the channel are dispatched to.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename = "channel")]
pub struct ChannelNode {
    /// Name of the channel this node sends through and receives from.
    pub channel_name: String,
    /// Name of the flow this node belongs to.
    pub flow_name: String,
    /// Identifier of this node within its flow.
    pub node_id: String,
    // NOTE(review): filled from the node config's `channel_in`; not read in this file.
    pub poll_messages: bool,
    // NOTE(review): filled from the node config's `channel_out`; not read in this file.
    pub send_messages: bool,
    /// Remote flag recorded on the channel origin of flow runs started by this node.
    pub remote: bool,
    /// Router configuration, serialized under the `router` key.
    #[serde(rename = "router")]
    pub router_config: FlowRouterConfig,
}
/// Router used by a [`ChannelNode`], internally tagged by a `type` field whose value is
/// `"channel"` or `"script"`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename = "router")]
pub enum FlowRouterConfig {
    /// Uses a [`ChannelFlowRouter`].
    #[serde(rename = "channel")]
    Channel(ChannelFlowRouter),
    /// Uses a [`ScriptFlowRouter`].
    #[serde(rename = "script")]
    Script(ScriptFlowRouter),
}
pub async fn handle_message(
flow_name: &str,
node_id: &str,
remote: bool,
msg: &ChannelMessage,
fm: &Arc<FlowManager>,
) {
let session_id = msg
.session_id
.clone()
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let input = Message::new(
&msg.id,
serde_json::to_value(msg.content.clone()).unwrap(),
session_id,
);
let channel_origin = ChannelOrigin::new(
msg.channel.clone(),
msg.reply_to_id.clone(),
msg.thread_id.clone(),
msg.from.clone(),
remote,
);
if let Some(report) = fm
.process_message(flow_name, node_id, input, Some(channel_origin))
.await
{
let payload_json = serde_json::to_string(&report).expect("cannot serialize report");
tracing::event!(
target: "request", tracing::Level::INFO,
flow = %flow_name,
node = %node_id,
report = %payload_json,
"flow run completed"
);
}
}
impl ChannelNode {
pub async fn handle_message(&self, remote: bool, msg: &ChannelMessage, fm: &Arc<FlowManager>) {
handle_message(&self.flow_name, &self.node_id, remote, msg, fm).await;
}
}
/// Dispatches incoming channel messages to the flow nodes registered for each channel.
///
/// Cloning is cheap and clones share state: every field is behind an `Arc`.
#[derive(Clone)]
pub struct ChannelsRegistry {
    /// Registered nodes, keyed by channel name.
    map: Arc<DashMap<String, Vec<ChannelNode>>>,
    /// Used to run flows and to look up flow definitions.
    flow_manager: Arc<FlowManager>,
    /// Used to look up channels by name.
    channel_manager: Arc<ChannelManager>,
}
impl ChannelsRegistry {
pub async fn new(
flow_manager: Arc<FlowManager>,
channel_manager: Arc<ChannelManager>,
) -> Arc<Self> {
let me = Arc::new(Self {
map: Arc::new(DashMap::new()),
flow_manager: flow_manager.clone(),
channel_manager,
});
let registry = me.clone();
flow_manager
.subscribe_flow_added(Arc::new(move |flow_id: &str, flow: &Flow| {
let mut incoming_targets = std::collections::HashSet::new();
for (_from, tos) in flow.connections() {
for to in tos {
incoming_targets.insert(to.clone());
}
}
for (node_name, cfg) in flow.nodes().iter() {
if let NodeKind::Channel { cfg } = &cfg.kind {
if !incoming_targets.contains(node_name.as_str()) {
registry.register(ChannelNode {
remote: cfg.channel_remote,
channel_name: cfg.channel_name.clone(),
flow_name: flow_id.to_string(),
node_id: node_name.clone(),
poll_messages: cfg.channel_in.clone(),
send_messages: cfg.channel_out.clone(),
router_config: FlowRouterConfig::Channel(ChannelFlowRouter::new()),
});
}
}
}
}))
.await;
me
}
pub fn find_if_node_in_flow(&self, flow: &str, node: &str) -> bool {
if let Some(flow) = self.flow_manager.flows().get(flow) {
flow.nodes().contains_key(node)
} else {
false
}
}
pub fn subscribe(&self) {}
pub fn register(&self, node: ChannelNode) {
self.map
.entry(node.channel_name.clone())
.or_default()
.push(node);
}
}
#[async_trait]
impl IncomingHandler for ChannelsRegistry {
async fn handle_incoming(&self, mut msg: ChannelMessage, session_store: SessionStore) {
if let Some(nodes) = self.map.get(&msg.channel) {
if nodes.is_empty() {
error!(
channel = %msg.channel,
"received message but channel has no nodes configured"
);
} else {
if let Some(channel_session_id) = msg.session_id.clone() {
let session_id = session_store
.get_or_create_channel(&channel_session_id)
.await;
msg.session_id = Some(session_id.clone());
let state = session_store.get_or_create(&session_id).await;
let session_flows = state.flows().unwrap_or_default();
let session_nodes = state.nodes().unwrap_or_default();
if !session_flows.is_empty() && !session_nodes.is_empty() {
let mut routed = false;
for flow in session_flows.iter() {
for node in session_nodes.iter() {
if self.find_if_node_in_flow(flow, node) {
if let Some(channel) = self.channel_manager.channel(&msg.channel) {
handle_message(flow, node, channel.remote(), &msg, &self.flow_manager).await;
routed = true;
}
}
}
}
if !routed {
info!(
"No matching node found for session flows/nodes: {:?} / {:?}",
session_flows, session_nodes
);
for node in nodes.iter().cloned() {
node.handle_message(node.remote, &msg, &self.flow_manager).await;
}
}
} else {
info!(
"No flows/nodes recorded in session state. Broadcasting to all the starting nodes for {}",
msg.channel
);
for node in nodes.iter().cloned() {
node.handle_message(node.remote, &msg, &self.flow_manager).await;
}
}
} else {
error!(
channel = %msg.channel,
"received message but no session included"
);
}
}
} else {
error!(
channel = %msg.channel,
"received message but no flows bound for this channel"
);
}
}
}
#[async_trait]
#[typetag::serde]
impl NodeType for ChannelNode {
    fn type_name(&self) -> String {
        self.channel_name.clone()
    }
    fn schema(&self) -> Schema {
        schema_for!(ChannelNode)
    }
    /// Sends `input` out through channel `channel_name`, then continues along the graph.
    ///
    /// If the payload deserializes as a [`ChannelMessage`], it is sent with its channel
    /// set to this node's channel and its direction set to outgoing. A missing session id
    /// is filled in from the input. Otherwise the payload's JSON text is wrapped in a
    /// single text message for the input's session. A message with no recipients is
    /// addressed to the participant of the run's channel origin.
    ///
    /// # Errors
    /// Fails if the channel is unknown, or if no recipient is given and the run has no
    /// channel origin. A failed send is only logged; the node still succeeds.
    #[tracing::instrument(name = "channel_node_process", skip(self, ctx))]
    async fn process(&self, input: Message, ctx: &mut NodeContext) -> Result<NodeOut, NodeErr> {
        let mut plugin = ctx
            .channel_manager()
            .channel(&self.channel_name)
            .ok_or_else(|| {
                NodeErr::fail(NodeError::Internal(format!(
                    "no such channel: {}",
                    self.channel_name
                )))
            })?;
        let mut cm = match serde_json::from_value::<ChannelMessage>(input.payload().clone()) {
            Ok(mut cm) => {
                cm.channel = self.channel_name.clone();
                cm.direction = MessageDirection::Outgoing;
                // Match the plain-text path, which always tags the outgoing message
                // with the current session.
                if cm.session_id.is_none() {
                    cm.session_id = Some(input.session_id().clone());
                }
                cm
            }
            Err(_) => ChannelMessage {
                channel: self.channel_name.clone(),
                session_id: Some(input.session_id().clone()),
                direction: MessageDirection::Outgoing,
                content: vec![MessageContent::Text {
                    text: input.payload().to_string(),
                }],
                ..Default::default()
            },
        };
        // Without explicit recipients, reply to whoever started this run on the channel.
        if cm.to.is_empty() {
            if let Some(channel_origin) = ctx.channel_origin() {
                cm.to = vec![channel_origin.participant()];
            } else {
                let error = format!(
                    "No to field was specified so don't know where to send the message to in channel {} with session id {:?}",
                    cm.channel,
                    input.session_id()
                );
                error!(error);
                return Err(NodeErr::fail(NodeError::InvalidInput(error)));
            }
        }
        // Delivery is best-effort: a send failure is logged and the flow continues.
        if let Err(e) = plugin.send_message(cm).await {
            warn!(error = ?e, "failed to send to channel {}", self.channel_name);
        }
        Ok(NodeOut::with_routing(input, Routing::FollowGraph))
    }
    fn clone_box(&self) -> Box<dyn NodeType> {
        Box::new(self.clone())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channel::manager::ChannelManager;
    use crate::config::{ConfigManager, MapConfigManager};
    use crate::flow::session::InMemorySessionStore;
    use crate::logger::{LogConfig, Logger};
    use crate::process::manager::ProcessManager;
    use crate::secret::SecretsManager;
    use crate::{
        executor::Executor, flow::manager::FlowManager, logger::OpenTelemetryLogger,
        secret::TestSecretsManager,
    };
    use channel_plugin::message::{ChannelMessage, MessageDirection};

    /// Incoming messages must be handled without panicking both before and after a
    /// node is bound to the channel.
    #[tokio::test]
    async fn test_registry_dispatches_safely() {
        let store = InMemorySessionStore::new(10);
        let secrets = SecretsManager(TestSecretsManager::new());
        let otel_logger = Logger(Box::new(OpenTelemetryLogger::new()));
        let executor = Executor::new(secrets.clone(), otel_logger);
        let channel_manager = ChannelManager::new(
            ConfigManager(MapConfigManager::new()),
            secrets.clone(),
            "123".to_string(),
            store.clone(),
            LogConfig::default(),
        )
        .await
        .expect("could not create channel manager");
        let process_manager = ProcessManager::dummy();
        let flow_manager = FlowManager::new(
            store.clone(),
            executor,
            channel_manager.clone(),
            process_manager.clone(),
            secrets,
        );
        let registry = ChannelsRegistry::new(flow_manager, channel_manager).await;

        let incoming = ChannelMessage {
            channel: "foo".into(),
            direction: MessageDirection::Incoming,
            ..Default::default()
        };

        // Nothing is bound to "foo" yet.
        registry.handle_incoming(incoming.clone(), store.clone()).await;

        registry.register(ChannelNode {
            channel_name: "foo".into(),
            flow_name: "flow_x".into(),
            node_id: "node_id".into(),
            remote: false,
            poll_messages: true,
            send_messages: false,
            router_config: FlowRouterConfig::Channel(ChannelFlowRouter::default()),
        });

        // Now a node is bound to "foo".
        registry.handle_incoming(incoming, store).await;
    }
}