1use serde::{Deserialize, Serialize};
2
3use crate::broker::Broker;
4use crate::error::TaskResult;
5use crate::result_backend::ResultBackend;
6use crate::signature::Signature;
7use crate::task_id::TaskId;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub enum Canvas {
12 Single(Signature),
14 Chain(Vec<Canvas>),
16 Group(Vec<Canvas>),
18 Chord {
20 group: Vec<Canvas>,
21 callback: Box<Canvas>,
22 },
23}
24
25#[derive(Debug, Clone)]
27pub struct WorkflowHandle {
28 pub id: String,
30 pub task_ids: Vec<TaskId>,
32}
33
34impl Canvas {
35 pub async fn apply(
42 &self,
43 broker: &dyn Broker,
44 backend: &dyn ResultBackend,
45 ) -> TaskResult<WorkflowHandle> {
46 let correlation_id = uuid::Uuid::now_v7().to_string();
47 let mut task_ids = Vec::new();
48 self.apply_inner(broker, backend, &correlation_id, &mut task_ids)
49 .await?;
50 Ok(WorkflowHandle {
51 id: correlation_id,
52 task_ids,
53 })
54 }
55
56 async fn apply_inner(
57 &self,
58 broker: &dyn Broker,
59 backend: &dyn ResultBackend,
60 correlation_id: &str,
61 task_ids: &mut Vec<TaskId>,
62 ) -> TaskResult<()> {
63 match self {
64 Canvas::Single(sig) => {
65 let mut msg = sig.clone().into_message();
66 msg.correlation_id = Some(correlation_id.to_string());
67 task_ids.push(msg.id);
68 broker.enqueue(msg).await?;
69 }
70 Canvas::Chain(steps) => {
71 if steps.is_empty() {
72 return Ok(());
73 }
74 let sigs = Self::flatten_chain(steps);
76 if sigs.is_empty() {
77 return Ok(());
78 }
79 let mut first_msg = sigs[0].clone().into_message();
81 first_msg.correlation_id = Some(correlation_id.to_string());
82 if sigs.len() > 1 {
83 let remaining: Vec<Signature> = sigs[1..].to_vec();
84 let remaining_json =
85 serde_json::to_string(&remaining).expect("failed to serialize chain steps");
86 first_msg
87 .headers
88 .insert("kojin.chain_next".to_string(), remaining_json);
89 }
90 task_ids.push(first_msg.id);
91 broker.enqueue(first_msg).await?;
92 }
93 Canvas::Group(members) => {
94 let group_id = uuid::Uuid::now_v7().to_string();
95 let sigs = Self::flatten_group(members);
96 let total = sigs.len() as u32;
97 backend.init_group(&group_id, total).await?;
98
99 for sig in &sigs {
100 let mut msg = sig.clone().into_message();
101 msg.correlation_id = Some(correlation_id.to_string());
102 msg.group_id = Some(group_id.clone());
103 msg.group_total = Some(total);
104 task_ids.push(msg.id);
105 broker.enqueue(msg).await?;
106 }
107 }
108 Canvas::Chord { group, callback } => {
109 let group_id = uuid::Uuid::now_v7().to_string();
110 let sigs = Self::flatten_group(group);
111 let total = sigs.len() as u32;
112 backend.init_group(&group_id, total).await?;
113
114 let callback_sigs = Self::flatten_chain(&[*callback.clone()]);
116 let callback_msg = if !callback_sigs.is_empty() {
117 callback_sigs[0].clone().into_message()
118 } else {
119 return Ok(());
120 };
121
122 for sig in &sigs {
123 let mut msg = sig.clone().into_message();
124 msg.correlation_id = Some(correlation_id.to_string());
125 msg.group_id = Some(group_id.clone());
126 msg.group_total = Some(total);
127 msg.chord_callback = Some(Box::new(callback_msg.clone()));
128 task_ids.push(msg.id);
129 broker.enqueue(msg).await?;
130 }
131 }
132 }
133 Ok(())
134 }
135
136 fn flatten_chain(steps: &[Canvas]) -> Vec<Signature> {
138 let mut result = Vec::new();
139 for step in steps {
140 match step {
141 Canvas::Single(sig) => result.push(sig.clone()),
142 Canvas::Chain(inner) => result.extend(Self::flatten_chain(inner)),
143 _ => {
144 tracing::warn!("Nested group/chord in chain is not yet supported, skipping");
147 }
148 }
149 }
150 result
151 }
152
153 fn flatten_group(members: &[Canvas]) -> Vec<Signature> {
155 let mut result = Vec::new();
156 for member in members {
157 match member {
158 Canvas::Single(sig) => result.push(sig.clone()),
159 _ => {
160 tracing::warn!("Nested canvas in group is not yet supported, skipping");
161 }
162 }
163 }
164 result
165 }
166}
167
/// Build a `Canvas::Chain` from one or more task signatures.
///
/// Each argument is wrapped in `Canvas::Single`, and argument order is kept
/// as step order. A trailing comma is accepted. The vector is built with
/// `::std::vec!`, so the expansion cannot be redirected by a `vec` macro
/// defined at the call site.
#[macro_export]
macro_rules! chain {
    ($($sig:expr),+ $(,)?) => {
        $crate::canvas::Canvas::Chain(::std::vec![
            $($crate::canvas::Canvas::Single($sig)),+
        ])
    };
}
181
/// Build a `Canvas::Group` from one or more task signatures.
///
/// Each argument is wrapped in `Canvas::Single`, and argument order is kept.
/// A trailing comma is accepted. The vector is built with `::std::vec!`, so
/// the expansion cannot be redirected by a `vec` macro defined at the call
/// site.
#[macro_export]
macro_rules! group {
    ($($sig:expr),+ $(,)?) => {
        $crate::canvas::Canvas::Group(::std::vec![
            $($crate::canvas::Canvas::Single($sig)),+
        ])
    };
}
195
196pub fn chord(group_items: Vec<Signature>, callback: Signature) -> Canvas {
198 Canvas::Chord {
199 group: group_items.into_iter().map(Canvas::Single).collect(),
200 callback: Box::new(Canvas::Single(callback)),
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory_broker::MemoryBroker;
    use crate::memory_result_backend::MemoryResultBackend;
    use std::time::Duration;

    /// Queue targeted by every signature built with `sig`.
    const QUEUE: &str = "default";
    /// Upper bound on how long a test waits for a message to come back.
    const POLL: Duration = Duration::from_millis(100);

    /// A signature for `name` on [`QUEUE`] with empty JSON arguments.
    fn sig(name: &str) -> Signature {
        Signature::new(name, QUEUE, serde_json::json!({}))
    }

    #[test]
    fn chain_macro() {
        if let Canvas::Chain(steps) = chain![sig("a"), sig("b"), sig("c")] {
            assert_eq!(steps.len(), 3);
        } else {
            panic!("expected Chain");
        }
    }

    #[test]
    fn group_macro() {
        if let Canvas::Group(members) = group![sig("a"), sig("b")] {
            assert_eq!(members.len(), 2);
        } else {
            panic!("expected Group");
        }
    }

    #[test]
    fn chord_constructor() {
        if let Canvas::Chord { group, callback } =
            chord(vec![sig("a"), sig("b")], sig("callback"))
        {
            assert_eq!(group.len(), 2);
            assert!(matches!(*callback, Canvas::Single(_)));
        } else {
            panic!("expected Chord");
        }
    }

    #[tokio::test]
    async fn apply_single() {
        let broker = MemoryBroker::new();
        let backend = MemoryResultBackend::new();

        let handle = Canvas::Single(sig("task_a"))
            .apply(&broker, &backend)
            .await
            .unwrap();

        assert_eq!(handle.task_ids.len(), 1);
        assert_eq!(broker.queue_len(QUEUE).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn apply_chain() {
        let broker = MemoryBroker::new();
        let backend = MemoryResultBackend::new();

        let handle = chain![sig("a"), sig("b"), sig("c")]
            .apply(&broker, &backend)
            .await
            .unwrap();
        // Only the head of the chain is enqueued up front.
        assert_eq!(handle.task_ids.len(), 1);
        assert_eq!(broker.queue_len(QUEUE).await.unwrap(), 1);

        let head = broker
            .dequeue(&[QUEUE.to_string()], POLL)
            .await
            .unwrap()
            .unwrap();
        let encoded = head
            .headers
            .get("kojin.chain_next")
            .expect("head message should carry kojin.chain_next");
        let remaining: Vec<Signature> = serde_json::from_str(encoded).unwrap();
        assert_eq!(remaining.len(), 2);
        for (step, expected) in remaining.iter().zip(["b", "c"]) {
            assert_eq!(step.task_name, expected);
        }
    }

    #[tokio::test]
    async fn apply_group() {
        let broker = MemoryBroker::new();
        let backend = MemoryResultBackend::new();

        let handle = group![sig("a"), sig("b"), sig("c")]
            .apply(&broker, &backend)
            .await
            .unwrap();

        assert_eq!(handle.task_ids.len(), 3);
        assert_eq!(broker.queue_len(QUEUE).await.unwrap(), 3);
    }

    #[tokio::test]
    async fn apply_chord() {
        let broker = MemoryBroker::new();
        let backend = MemoryResultBackend::new();

        let handle = chord(vec![sig("a"), sig("b")], sig("callback"))
            .apply(&broker, &backend)
            .await
            .unwrap();
        // Only the group members are enqueued; the callback rides along.
        assert_eq!(handle.task_ids.len(), 2);
        assert_eq!(broker.queue_len(QUEUE).await.unwrap(), 2);

        let member = broker
            .dequeue(&[QUEUE.to_string()], POLL)
            .await
            .unwrap()
            .unwrap();
        let callback = member
            .chord_callback
            .expect("chord member should carry its callback");
        assert_eq!(callback.task_name, "callback");
    }
}