Skip to main content

kojin_core/
canvas.rs

1use serde::{Deserialize, Serialize};
2
3use crate::broker::Broker;
4use crate::error::TaskResult;
5use crate::result_backend::ResultBackend;
6use crate::signature::Signature;
7use crate::task_id::TaskId;
8
9/// A composable workflow description.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub enum Canvas {
12    /// A single task invocation.
13    Single(Signature),
14    /// Execute tasks sequentially; each task's result is passed to the next.
15    Chain(Vec<Canvas>),
16    /// Execute tasks in parallel.
17    Group(Vec<Canvas>),
18    /// A group with a callback that fires after all members complete.
19    Chord {
20        group: Vec<Canvas>,
21        callback: Box<Canvas>,
22    },
23}
24
25/// Handle returned after submitting a workflow.
26#[derive(Debug, Clone)]
27pub struct WorkflowHandle {
28    /// Workflow identifier (correlation ID).
29    pub id: String,
30    /// Task IDs of all submitted tasks.
31    pub task_ids: Vec<TaskId>,
32}
33
34impl Canvas {
35    /// Submit this workflow for execution.
36    ///
37    /// - **Single**: enqueues immediately.
38    /// - **Chain**: enqueues first task with remaining steps stored in headers.
39    /// - **Group**: enqueues all members in parallel with group metadata.
40    /// - **Chord**: like Group but attaches a chord callback.
41    pub async fn apply(
42        &self,
43        broker: &dyn Broker,
44        backend: &dyn ResultBackend,
45    ) -> TaskResult<WorkflowHandle> {
46        let correlation_id = uuid::Uuid::now_v7().to_string();
47        let mut task_ids = Vec::new();
48        self.apply_inner(broker, backend, &correlation_id, &mut task_ids)
49            .await?;
50        Ok(WorkflowHandle {
51            id: correlation_id,
52            task_ids,
53        })
54    }
55
56    async fn apply_inner(
57        &self,
58        broker: &dyn Broker,
59        backend: &dyn ResultBackend,
60        correlation_id: &str,
61        task_ids: &mut Vec<TaskId>,
62    ) -> TaskResult<()> {
63        match self {
64            Canvas::Single(sig) => {
65                let mut msg = sig.clone().into_message();
66                msg.correlation_id = Some(correlation_id.to_string());
67                task_ids.push(msg.id);
68                broker.enqueue(msg).await?;
69            }
70            Canvas::Chain(steps) => {
71                if steps.is_empty() {
72                    return Ok(());
73                }
74                // Flatten chain into signatures
75                let sigs = Self::flatten_chain(steps);
76                if sigs.is_empty() {
77                    return Ok(());
78                }
79                // Enqueue only the first; store remaining as chain_next header
80                let mut first_msg = sigs[0].clone().into_message();
81                first_msg.correlation_id = Some(correlation_id.to_string());
82                if sigs.len() > 1 {
83                    let remaining: Vec<Signature> = sigs[1..].to_vec();
84                    let remaining_json =
85                        serde_json::to_string(&remaining).expect("failed to serialize chain steps");
86                    first_msg
87                        .headers
88                        .insert("kojin.chain_next".to_string(), remaining_json);
89                }
90                task_ids.push(first_msg.id);
91                broker.enqueue(first_msg).await?;
92            }
93            Canvas::Group(members) => {
94                let group_id = uuid::Uuid::now_v7().to_string();
95                let sigs = Self::flatten_group(members);
96                let total = sigs.len() as u32;
97                backend.init_group(&group_id, total).await?;
98
99                for sig in &sigs {
100                    let mut msg = sig.clone().into_message();
101                    msg.correlation_id = Some(correlation_id.to_string());
102                    msg.group_id = Some(group_id.clone());
103                    msg.group_total = Some(total);
104                    task_ids.push(msg.id);
105                    broker.enqueue(msg).await?;
106                }
107            }
108            Canvas::Chord { group, callback } => {
109                let group_id = uuid::Uuid::now_v7().to_string();
110                let sigs = Self::flatten_group(group);
111                let total = sigs.len() as u32;
112                backend.init_group(&group_id, total).await?;
113
114                // Build the callback message
115                let callback_sigs = Self::flatten_chain(&[*callback.clone()]);
116                let callback_msg = if !callback_sigs.is_empty() {
117                    callback_sigs[0].clone().into_message()
118                } else {
119                    return Ok(());
120                };
121
122                for sig in &sigs {
123                    let mut msg = sig.clone().into_message();
124                    msg.correlation_id = Some(correlation_id.to_string());
125                    msg.group_id = Some(group_id.clone());
126                    msg.group_total = Some(total);
127                    msg.chord_callback = Some(Box::new(callback_msg.clone()));
128                    task_ids.push(msg.id);
129                    broker.enqueue(msg).await?;
130                }
131            }
132        }
133        Ok(())
134    }
135
136    /// Flatten a chain of canvases into a list of signatures.
137    fn flatten_chain(steps: &[Canvas]) -> Vec<Signature> {
138        let mut result = Vec::new();
139        for step in steps {
140            match step {
141                Canvas::Single(sig) => result.push(sig.clone()),
142                Canvas::Chain(inner) => result.extend(Self::flatten_chain(inner)),
143                _ => {
144                    // For nested group/chord in a chain, we'd need more complex handling.
145                    // For now, skip non-single items.
146                    tracing::warn!("Nested group/chord in chain is not yet supported, skipping");
147                }
148            }
149        }
150        result
151    }
152
153    /// Flatten a group of canvases into a list of signatures.
154    fn flatten_group(members: &[Canvas]) -> Vec<Signature> {
155        let mut result = Vec::new();
156        for member in members {
157            match member {
158                Canvas::Single(sig) => result.push(sig.clone()),
159                _ => {
160                    tracing::warn!("Nested canvas in group is not yet supported, skipping");
161                }
162            }
163        }
164        result
165    }
166}
167
/// Build a [`Canvas::Chain`] that runs the given signatures in order.
///
/// Every argument must be a `Signature`; a trailing comma is allowed.
///
/// ```ignore
/// let workflow = chain![sig_a, sig_b, sig_c];
/// ```
#[macro_export]
macro_rules! chain {
    ($head:expr $(, $tail:expr)* $(,)?) => {
        $crate::canvas::Canvas::Chain(vec![
            $crate::canvas::Canvas::Single($head)
            $(, $crate::canvas::Canvas::Single($tail))*
        ])
    };
}
181
/// Build a [`Canvas::Group`] whose signatures all run in parallel.
///
/// Every argument must be a `Signature`; a trailing comma is allowed.
///
/// ```ignore
/// let workflow = group![sig_a, sig_b, sig_c];
/// ```
#[macro_export]
macro_rules! group {
    ($head:expr $(, $tail:expr)* $(,)?) => {
        $crate::canvas::Canvas::Group(vec![
            $crate::canvas::Canvas::Single($head)
            $(, $crate::canvas::Canvas::Single($tail))*
        ])
    };
}
195
196/// Create a chord: a group with a callback that fires when all members complete.
197pub fn chord(group_items: Vec<Signature>, callback: Signature) -> Canvas {
198    Canvas::Chord {
199        group: group_items.into_iter().map(Canvas::Single).collect(),
200        callback: Box::new(Canvas::Single(callback)),
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory_broker::MemoryBroker;
    use crate::memory_result_backend::MemoryResultBackend;

    /// Queue every test signature is routed to.
    const QUEUE: &str = "default";

    /// Signature named `name` on [`QUEUE`] with empty JSON arguments.
    fn sig(name: &str) -> Signature {
        let args = serde_json::json!({});
        Signature::new(name, QUEUE, args)
    }

    #[test]
    fn chain_macro() {
        let Canvas::Chain(steps) = chain![sig("a"), sig("b"), sig("c")] else {
            panic!("expected Chain");
        };
        assert_eq!(steps.len(), 3);
    }

    #[test]
    fn group_macro() {
        let Canvas::Group(members) = group![sig("a"), sig("b")] else {
            panic!("expected Group");
        };
        assert_eq!(members.len(), 2);
    }

    #[test]
    fn chord_constructor() {
        let built = chord(vec![sig("a"), sig("b")], sig("callback"));
        let Canvas::Chord { group, callback } = built else {
            panic!("expected Chord");
        };
        assert_eq!(group.len(), 2);
        assert!(matches!(*callback, Canvas::Single(_)));
    }

    #[tokio::test]
    async fn apply_single() {
        let broker = MemoryBroker::new();
        let backend = MemoryResultBackend::new();

        let handle = Canvas::Single(sig("task_a"))
            .apply(&broker, &backend)
            .await
            .unwrap();

        assert_eq!(handle.task_ids.len(), 1);
        assert_eq!(broker.queue_len(QUEUE).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn apply_chain() {
        let broker = MemoryBroker::new();
        let backend = MemoryResultBackend::new();
        let workflow = chain![sig("a"), sig("b"), sig("c")];

        let handle = workflow.apply(&broker, &backend).await.unwrap();
        // Only the head of the chain is enqueued up front.
        assert_eq!(handle.task_ids.len(), 1);
        assert_eq!(broker.queue_len(QUEUE).await.unwrap(), 1);

        // The tail must travel in the chain_next header.
        let queues = [QUEUE.to_string()];
        let timeout = std::time::Duration::from_millis(100);
        let head = broker.dequeue(&queues, timeout).await.unwrap().unwrap();

        let encoded = head.headers.get("kojin.chain_next");
        assert!(encoded.is_some());
        let rest: Vec<Signature> = serde_json::from_str(encoded.unwrap()).unwrap();
        assert_eq!(rest.len(), 2);
        assert_eq!(rest[0].task_name, "b");
        assert_eq!(rest[1].task_name, "c");
    }

    #[tokio::test]
    async fn apply_group() {
        let broker = MemoryBroker::new();
        let backend = MemoryResultBackend::new();
        let workflow = group![sig("a"), sig("b"), sig("c")];

        let handle = workflow.apply(&broker, &backend).await.unwrap();

        assert_eq!(handle.task_ids.len(), 3);
        assert_eq!(broker.queue_len(QUEUE).await.unwrap(), 3);
    }

    #[tokio::test]
    async fn apply_chord() {
        let broker = MemoryBroker::new();
        let backend = MemoryResultBackend::new();
        let workflow = chord(vec![sig("a"), sig("b")], sig("callback"));

        let handle = workflow.apply(&broker, &backend).await.unwrap();
        assert_eq!(handle.task_ids.len(), 2);
        assert_eq!(broker.queue_len(QUEUE).await.unwrap(), 2);

        // Every member carries the chord callback.
        let queues = [QUEUE.to_string()];
        let timeout = std::time::Duration::from_millis(100);
        let member = broker.dequeue(&queues, timeout).await.unwrap().unwrap();

        let callback = member.chord_callback.as_deref();
        assert!(callback.is_some());
        assert_eq!(callback.unwrap().task_name, "callback");
    }
}