use serde::{Deserialize, Serialize};
use crate::broker::Broker;
use crate::error::TaskResult;
use crate::result_backend::ResultBackend;
use crate::signature::Signature;
use crate::task_id::TaskId;
/// A composable description of work to enqueue, dispatched with [`Canvas::apply`].
///
/// Nesting is only partially supported when a canvas is applied. Within a chain,
/// nested chains are flattened, but nested groups/chords are skipped with a
/// warning. Within a group (or a chord's group), every non-`Single` member is
/// skipped with a warning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Canvas {
/// A single task, enqueued as one message.
Single(Signature),
/// Steps intended to run one after another. Applying it enqueues only the first
/// flattened step; the remaining signatures are serialized as JSON into that
/// message's `kojin.chain_next` header (presumably read downstream to enqueue
/// the next step — confirm against the worker).
Chain(Vec<Canvas>),
/// Members enqueued together. Applying it registers a fresh group id and the
/// member count with the result backend and stamps both on every member message.
Group(Vec<Canvas>),
/// A group plus a callback. Every group member message carries a copy of the
/// callback message in its `chord_callback` field.
Chord {
/// The members enqueued as a group.
group: Vec<Canvas>,
/// The canvas whose message is attached to each group member as its chord callback.
callback: Box<Canvas>,
},
}
/// Handle returned by [`Canvas::apply`] identifying an enqueued workflow.
#[derive(Debug, Clone)]
pub struct WorkflowHandle {
/// Correlation id (a UUIDv7 string) stamped on each message that `apply` enqueues.
pub id: String,
/// Ids of the messages `apply` enqueued directly: the single task, only the head
/// step of a chain, or every group member. A chord's callback id is not included.
pub task_ids: Vec<TaskId>,
}
impl Canvas {
    /// Header under which the serialized remaining steps of a chain travel on the
    /// chain's head message.
    const CHAIN_NEXT_HEADER: &'static str = "kojin.chain_next";

    /// Enqueues this canvas on `broker`, registering any groups with `backend`.
    ///
    /// A fresh UUIDv7 correlation id is generated and stamped on every message
    /// built here, including a chord's callback message. The returned
    /// [`WorkflowHandle`] carries that id plus the ids of the messages enqueued
    /// directly by this call. For a chain that is only the head step; for a chord
    /// it is only the group members, since the callback rides along on each member.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `ResultBackend::init_group` or
    /// `Broker::enqueue`. Messages enqueued before the failure are not withdrawn.
    ///
    /// # Panics
    ///
    /// Panics if the remaining steps of a chain cannot be serialized to JSON, or
    /// if a group has more than `u32::MAX` members.
    pub async fn apply(
        &self,
        broker: &dyn Broker,
        backend: &dyn ResultBackend,
    ) -> TaskResult<WorkflowHandle> {
        let correlation_id = uuid::Uuid::now_v7().to_string();
        let mut task_ids = Vec::new();
        self.apply_inner(broker, backend, &correlation_id, &mut task_ids)
            .await?;
        Ok(WorkflowHandle {
            id: correlation_id,
            task_ids,
        })
    }

    /// Builds and enqueues the messages for `self`.
    ///
    /// Tags each message with `correlation_id` and appends the ids of directly
    /// enqueued messages to `task_ids`.
    async fn apply_inner(
        &self,
        broker: &dyn Broker,
        backend: &dyn ResultBackend,
        correlation_id: &str,
        task_ids: &mut Vec<TaskId>,
    ) -> TaskResult<()> {
        match self {
            Canvas::Single(sig) => {
                let mut msg = sig.clone().into_message();
                msg.correlation_id = Some(correlation_id.to_string());
                task_ids.push(msg.id);
                broker.enqueue(msg).await?;
            }
            Canvas::Chain(steps) => {
                let mut sigs = Self::flatten_chain(steps).into_iter();
                // Only the head is enqueued now. The remaining steps ride along in a
                // header (presumably consumed downstream to continue the chain).
                if let Some(head) = sigs.next() {
                    let mut head_msg = head.into_message();
                    head_msg.correlation_id = Some(correlation_id.to_string());
                    if let Some(next) = Self::encode_chain_tail(sigs.as_slice()) {
                        head_msg
                            .headers
                            .insert(Self::CHAIN_NEXT_HEADER.to_string(), next);
                    }
                    task_ids.push(head_msg.id);
                    broker.enqueue(head_msg).await?;
                }
            }
            Canvas::Group(members) => {
                let group_id = uuid::Uuid::now_v7().to_string();
                let sigs = Self::flatten_group(members);
                let total = Self::group_total(sigs.len());
                backend.init_group(&group_id, total).await?;
                for sig in sigs {
                    let mut msg = sig.into_message();
                    msg.correlation_id = Some(correlation_id.to_string());
                    msg.group_id = Some(group_id.clone());
                    msg.group_total = Some(total);
                    task_ids.push(msg.id);
                    broker.enqueue(msg).await?;
                }
            }
            Canvas::Chord { group, callback } => {
                // Resolve the callback before touching the backend. If it yields no
                // task there is nothing to fire, and registering a group would only
                // leave an orphaned record behind.
                let mut callback_sigs =
                    Self::flatten_chain(std::slice::from_ref(&**callback)).into_iter();
                let callback_head = match callback_sigs.next() {
                    Some(head) => head,
                    None => {
                        tracing::warn!("Chord callback resolved to no tasks, skipping chord");
                        return Ok(());
                    }
                };
                let mut callback_msg = callback_head.into_message();
                // The callback belongs to the same workflow as every other message.
                callback_msg.correlation_id = Some(correlation_id.to_string());
                // A chained callback keeps its remaining steps instead of dropping them.
                if let Some(next) = Self::encode_chain_tail(callback_sigs.as_slice()) {
                    callback_msg
                        .headers
                        .insert(Self::CHAIN_NEXT_HEADER.to_string(), next);
                }

                let group_id = uuid::Uuid::now_v7().to_string();
                let sigs = Self::flatten_group(group);
                if sigs.is_empty() {
                    // The callback only travels on member messages, so with no members
                    // nothing will ever carry it.
                    tracing::warn!(
                        "Chord group resolved to no tasks; its callback will never be dispatched"
                    );
                }
                let total = Self::group_total(sigs.len());
                backend.init_group(&group_id, total).await?;
                for sig in sigs {
                    let mut msg = sig.into_message();
                    msg.correlation_id = Some(correlation_id.to_string());
                    msg.group_id = Some(group_id.clone());
                    msg.group_total = Some(total);
                    msg.chord_callback = Some(Box::new(callback_msg.clone()));
                    task_ids.push(msg.id);
                    broker.enqueue(msg).await?;
                }
            }
        }
        Ok(())
    }

    /// Serializes the steps that follow a chain's head into the JSON stored under
    /// the `kojin.chain_next` header. Returns `None` when there are no further steps.
    ///
    /// # Panics
    ///
    /// Panics if the signatures cannot be serialized to JSON.
    fn encode_chain_tail(tail: &[Signature]) -> Option<String> {
        if tail.is_empty() {
            None
        } else {
            Some(serde_json::to_string(tail).expect("failed to serialize chain steps"))
        }
    }

    /// Converts a group's member count to the `u32` used for the backend and the messages.
    ///
    /// It panics instead of silently truncating. A wrapped total would disagree with
    /// the number of members actually enqueued.
    fn group_total(len: usize) -> u32 {
        assert!(
            len <= u32::MAX as usize,
            "group has {} members, more than u32::MAX",
            len
        );
        len as u32
    }

    /// Flattens chain steps into signatures in order.
    ///
    /// Nested chains are inlined recursively. Nested groups/chords are skipped
    /// with a warning.
    fn flatten_chain(steps: &[Canvas]) -> Vec<Signature> {
        let mut result = Vec::with_capacity(steps.len());
        for step in steps {
            match step {
                Canvas::Single(sig) => result.push(sig.clone()),
                Canvas::Chain(inner) => result.extend(Self::flatten_chain(inner)),
                Canvas::Group(_) | Canvas::Chord { .. } => {
                    tracing::warn!("Nested group/chord in chain is not yet supported, skipping");
                }
            }
        }
        result
    }

    /// Collects the `Single` members of a group.
    ///
    /// Any nested canvas is skipped with a warning.
    fn flatten_group(members: &[Canvas]) -> Vec<Signature> {
        members
            .iter()
            .filter_map(|member| match member {
                Canvas::Single(sig) => Some(sig.clone()),
                Canvas::Chain(_) | Canvas::Group(_) | Canvas::Chord { .. } => {
                    tracing::warn!("Nested canvas in group is not yet supported, skipping");
                    None
                }
            })
            .collect()
    }
}
/// Builds a `Canvas::Chain` from one or more task signatures, preserving order.
///
/// Each argument is wrapped in `Canvas::Single`. A trailing comma is accepted.
#[macro_export]
macro_rules! chain {
    ($($step:expr),+ $(,)?) => {
        $crate::canvas::Canvas::Chain(vec![$($crate::canvas::Canvas::Single($step),)+])
    };
}
/// Builds a `Canvas::Group` from one or more task signatures.
///
/// Each argument is wrapped in `Canvas::Single`. A trailing comma is accepted.
#[macro_export]
macro_rules! group {
    ($($member:expr),+ $(,)?) => {
        $crate::canvas::Canvas::Group(vec![$($crate::canvas::Canvas::Single($member),)+])
    };
}
/// Builds a `Canvas::Chord` from plain signatures.
///
/// Each group signature becomes a `Canvas::Single` member, in the given order.
/// `callback` becomes the chord's `Canvas::Single` callback.
pub fn chord(group_items: Vec<Signature>, callback: Signature) -> Canvas {
    let mut group = Vec::with_capacity(group_items.len());
    group.extend(group_items.into_iter().map(Canvas::Single));
    let callback = Box::new(Canvas::Single(callback));
    Canvas::Chord { group, callback }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory_broker::MemoryBroker;
    use crate::memory_result_backend::MemoryResultBackend;

    /// Queue that every test signature is routed to.
    const QUEUE: &str = "default";
    /// How long a test waits for a message to become available.
    const POLL: std::time::Duration = std::time::Duration::from_millis(100);

    /// Builds a signature named `name` on the test queue with empty JSON args.
    fn sig(name: &str) -> Signature {
        Signature::new(name, QUEUE, serde_json::json!({}))
    }

    #[test]
    fn chain_macro() {
        if let Canvas::Chain(steps) = chain![sig("a"), sig("b"), sig("c")] {
            assert_eq!(steps.len(), 3);
        } else {
            panic!("expected Chain");
        }
    }

    #[test]
    fn group_macro() {
        if let Canvas::Group(members) = group![sig("a"), sig("b")] {
            assert_eq!(members.len(), 2);
        } else {
            panic!("expected Group");
        }
    }

    #[test]
    fn chord_constructor() {
        if let Canvas::Chord { group, callback } = chord(vec![sig("a"), sig("b")], sig("callback"))
        {
            assert_eq!(group.len(), 2);
            assert!(matches!(*callback, Canvas::Single(_)));
        } else {
            panic!("expected Chord");
        }
    }

    #[tokio::test]
    async fn apply_single() {
        let (broker, backend) = (MemoryBroker::new(), MemoryResultBackend::new());
        let workflow = Canvas::Single(sig("task_a"));
        let handle = workflow.apply(&broker, &backend).await.unwrap();

        assert_eq!(handle.task_ids.len(), 1);
        assert_eq!(broker.queue_len(QUEUE).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn apply_chain() {
        let (broker, backend) = (MemoryBroker::new(), MemoryResultBackend::new());
        let workflow = chain![sig("a"), sig("b"), sig("c")];
        let handle = workflow.apply(&broker, &backend).await.unwrap();

        // Only the head of the chain is enqueued up front.
        assert_eq!(handle.task_ids.len(), 1);
        assert_eq!(broker.queue_len(QUEUE).await.unwrap(), 1);

        let head = broker
            .dequeue(&[QUEUE.to_string()], POLL)
            .await
            .unwrap()
            .unwrap();
        let encoded = head.headers.get("kojin.chain_next");
        assert!(encoded.is_some());

        let remaining: Vec<Signature> = serde_json::from_str(encoded.unwrap()).unwrap();
        assert_eq!(remaining.len(), 2);
        assert_eq!(remaining[0].task_name, "b");
        assert_eq!(remaining[1].task_name, "c");
    }

    #[tokio::test]
    async fn apply_group() {
        let (broker, backend) = (MemoryBroker::new(), MemoryResultBackend::new());
        let workflow = group![sig("a"), sig("b"), sig("c")];
        let handle = workflow.apply(&broker, &backend).await.unwrap();

        // Every member is enqueued immediately.
        assert_eq!(handle.task_ids.len(), 3);
        assert_eq!(broker.queue_len(QUEUE).await.unwrap(), 3);
    }

    #[tokio::test]
    async fn apply_chord() {
        let (broker, backend) = (MemoryBroker::new(), MemoryResultBackend::new());
        let workflow = chord(vec![sig("a"), sig("b")], sig("callback"));
        let handle = workflow.apply(&broker, &backend).await.unwrap();

        // Only the group members are enqueued; the callback rides along on each.
        assert_eq!(handle.task_ids.len(), 2);
        assert_eq!(broker.queue_len(QUEUE).await.unwrap(), 2);

        let member = broker
            .dequeue(&[QUEUE.to_string()], POLL)
            .await
            .unwrap()
            .unwrap();
        let callback = member.chord_callback;
        assert!(callback.is_some());
        assert_eq!(callback.unwrap().task_name, "callback");
    }
}