use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use channel_plugin::message::ChannelMessage;
use handlebars::Handlebars;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
/// Decides which flows an incoming [`ChannelMessage`] should be dispatched to.
///
/// The trait is annotated with `#[typetag::serde]`, so boxed trait objects
/// (`Box<dyn FlowRouter>`) can be serialized/deserialized with the concrete router
/// type recorded as a tag. `#[async_trait]` is listed first so the async method is
/// desugared before `typetag` processes the trait.
#[async_trait]
#[typetag::serde]
pub trait FlowRouter: Send + Sync {
    /// Returns the `(flow_name, channel)` pairs that `msg` should be routed to.
    ///
    /// An empty vector means no flow matched the message.
    async fn resolve(&self, msg: &ChannelMessage) -> Vec<(String, String)>;
}
/// Routes messages using a static table from channel name to flow names.
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename = "channel")]
pub struct ChannelFlowRouter {
    /// Channel name -> flow names; `resolve` emits the flows in vector order.
    pub map: HashMap<String, Vec<String>>,
}
impl ChannelFlowRouter {
pub fn new() -> Self {
Self {
map: HashMap::new(),
}
}
pub fn add_mapping(&mut self, channel: &str, flow_name: &str) {
self.map
.entry(channel.into())
.or_default()
.push(flow_name.into());
}
}
// NOTE(review): per typetag's docs, the tag used for `dyn FlowRouter` comes from this
// impl (the type name, `ChannelFlowRouter`, unless `#[typetag::serde(name = "...")]` is
// given); the struct's `#[serde(rename = "channel")]` does not change it. Confirm which
// tag existing configs use before relying on "channel".
#[async_trait]
#[typetag::serde]
impl FlowRouter for ChannelFlowRouter {
    /// Emits one `(flow, channel)` pair per flow mapped to `msg.channel`, in table order.
    ///
    /// A channel with no entry in the table yields an empty vector.
    async fn resolve(&self, msg: &ChannelMessage) -> Vec<(String, String)> {
        let channel = &msg.channel;
        match self.map.get(channel) {
            Some(flows) => flows
                .iter()
                .map(|flow| (flow.clone(), channel.clone()))
                .collect(),
            None => Vec::new(),
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename = "script")]
pub struct ScriptFlowRouter {
pub templates: HashMap<String, String>,
#[serde(skip)]
#[schemars(skip)]
#[schemars(default = "default_registry")]
registry: Arc<Handlebars<'static>>,
}
fn default_registry() -> Arc<Handlebars<'static>> {
Arc::new(Handlebars::new())
}
impl ScriptFlowRouter {
    /// Creates a router with no templates and a registry from `default_registry`.
    pub fn new() -> Self {
        let registry = default_registry();
        Self {
            templates: HashMap::new(),
            registry,
        }
    }

    /// Registers the template for `channel`, replacing any previous one.
    pub fn set_template(&mut self, channel: &str, template: &str) {
        let (key, value) = (channel.to_owned(), template.to_owned());
        self.templates.insert(key, value);
    }

    /// Converts the message into the JSON value used as the template context.
    ///
    /// NOTE(review): `json!` on an expression unwraps `serde_json::to_value`, so this
    /// panics if `ChannelMessage` ever fails to serialize to JSON.
    fn to_context(msg: &ChannelMessage) -> serde_json::Value {
        json!(msg)
    }
}
// NOTE(review): as with `ChannelFlowRouter`, the typetag tag comes from this impl (type
// name by default), not from the struct's `#[serde(rename = "script")]` — confirm intent.
#[async_trait]
#[typetag::serde]
impl FlowRouter for ScriptFlowRouter {
    /// Renders the template registered for `msg.channel` and splits the output on `,`.
    ///
    /// Each whitespace-trimmed, non-empty piece becomes a `(flow, channel)` pair. A channel
    /// without a template yields nothing. A render error is logged and also yields nothing.
    async fn resolve(&self, msg: &ChannelMessage) -> Vec<(String, String)> {
        let channel = &msg.channel;
        let mut routes = Vec::new();
        if let Some(template) = self.templates.get(channel) {
            let context = Self::to_context(msg);
            match self.registry.render_template(template, &context) {
                Ok(rendered) => routes.extend(
                    rendered
                        .split(',')
                        .map(str::trim)
                        .filter(|name| !name.is_empty())
                        .map(|name| (name.to_owned(), channel.clone())),
                ),
                Err(e) => {
                    tracing::error!("Flow template error for channel {}: {:?}", msg.channel, e);
                }
            }
        }
        routes
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use channel_plugin::message::ChannelMessage;

    /// Builds a default message addressed to `channel`.
    fn make_msg(channel: &str) -> ChannelMessage {
        ChannelMessage {
            channel: channel.into(),
            ..ChannelMessage::default()
        }
    }

    #[tokio::test]
    async fn channel_router_empty() {
        let router = ChannelFlowRouter::new();
        let out = router.resolve(&make_msg("foo")).await;
        assert!(out.is_empty());
    }

    #[tokio::test]
    async fn channel_router_mapped() {
        let mut router = ChannelFlowRouter::new();
        router.add_mapping("email", "flow1");
        router.add_mapping("email", "flow2");
        let out = router.resolve(&make_msg("email")).await;
        assert_eq!(
            out,
            vec![
                ("flow1".to_string(), "email".to_string()),
                ("flow2".to_string(), "email".to_string()),
            ]
        );
    }

    #[tokio::test]
    async fn channel_router_other_channel_unmapped() {
        let mut router = ChannelFlowRouter::new();
        router.add_mapping("email", "flow1");
        let out = router.resolve(&make_msg("sms")).await;
        assert!(out.is_empty());
    }

    #[tokio::test]
    async fn script_router_empty() {
        let router = ScriptFlowRouter::new();
        let out = router.resolve(&make_msg("sms")).await;
        assert!(out.is_empty());
    }

    #[tokio::test]
    async fn script_router_map_list() {
        let mut router = ScriptFlowRouter::new();
        router.set_template("sms", "a,b , c");
        let out = router.resolve(&make_msg("sms")).await;
        assert_eq!(
            out,
            vec![
                ("a".into(), "sms".into()),
                ("b".into(), "sms".into()),
                ("c".into(), "sms".into()),
            ]
        );
    }

    #[tokio::test]
    async fn script_router_uses_message_fields() {
        // The template context is the serialized message, so its fields are addressable.
        let mut router = ScriptFlowRouter::new();
        router.set_template("sms", "{{channel}}_in, audit");
        let out = router.resolve(&make_msg("sms")).await;
        assert_eq!(
            out,
            vec![
                ("sms_in".to_string(), "sms".to_string()),
                ("audit".to_string(), "sms".to_string()),
            ]
        );
    }

    #[tokio::test]
    async fn script_router_bad_template() {
        let mut router = ScriptFlowRouter::new();
        router.set_template("sms", "{{#if}}");
        let out = router.resolve(&make_msg("sms")).await;
        assert!(out.is_empty());
    }
}