Skip to main content

celers_canvas/
chain.rs

1use crate::{CanvasError, Signature};
2use celers_core::{Broker, SerializedTask};
3use serde::{Deserialize, Serialize};
4use uuid::Uuid;
5
6/// Chain: Sequential execution
7///
8/// task1(args1) -> task2(result1) -> task3(result2)
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
10pub struct Chain {
11    /// Tasks in the chain
12    pub tasks: Vec<Signature>,
13}
14
15impl Chain {
16    pub fn new() -> Self {
17        Self { tasks: Vec::new() }
18    }
19
20    pub fn then(mut self, task: &str, args: Vec<serde_json::Value>) -> Self {
21        self.tasks
22            .push(Signature::new(task.to_string()).with_args(args));
23        self
24    }
25
26    pub fn then_signature(mut self, signature: Signature) -> Self {
27        self.tasks.push(signature);
28        self
29    }
30
31    /// Apply the chain by enqueuing the first task with links to subsequent tasks
32    pub async fn apply<B: Broker>(self, broker: &B) -> Result<Uuid, CanvasError> {
33        if self.tasks.is_empty() {
34            return Err(CanvasError::Invalid("Chain cannot be empty".to_string()));
35        }
36
37        // Build chain backwards: last task -> second-to-last -> ... -> first
38        let mut chain_iter = self.tasks.into_iter().rev();
39        let mut next_sig: Option<Signature> = None;
40
41        // Start from the last task (no link)
42        if let Some(last_task) = chain_iter.next() {
43            // Last task has no link
44            next_sig = Some(last_task);
45
46            // Link remaining tasks backwards
47            for mut task in chain_iter {
48                task.options.link = next_sig.map(Box::new);
49                next_sig = Some(task);
50            }
51        }
52
53        // Enqueue the first task (which is now in next_sig)
54        if let Some(first_sig) = next_sig {
55            let task_id = Self::enqueue_signature(broker, &first_sig).await?;
56            Ok(task_id)
57        } else {
58            Err(CanvasError::Invalid("Failed to build chain".to_string()))
59        }
60    }
61
62    async fn enqueue_signature<B: Broker>(
63        broker: &B,
64        sig: &Signature,
65    ) -> Result<Uuid, CanvasError> {
66        let args_json = serde_json::json!({
67            "args": sig.args,
68            "kwargs": sig.kwargs
69        });
70        let args_bytes = serde_json::to_vec(&args_json)
71            .map_err(|e| CanvasError::Serialization(e.to_string()))?;
72
73        let mut task = SerializedTask::new(sig.task.clone(), args_bytes);
74
75        if let Some(priority) = sig.options.priority {
76            task = task.with_priority(priority.into());
77        }
78
79        let task_id = task.metadata.id;
80        broker
81            .enqueue(task)
82            .await
83            .map_err(|e| CanvasError::Broker(e.to_string()))?;
84
85        Ok(task_id)
86    }
87}
88
89impl Default for Chain {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
95impl Chain {
96    /// Check if chain is empty
97    pub fn is_empty(&self) -> bool {
98        self.tasks.is_empty()
99    }
100
101    /// Get number of tasks in chain
102    pub fn len(&self) -> usize {
103        self.tasks.len()
104    }
105
106    /// Get the first task in the chain
107    pub fn first(&self) -> Option<&Signature> {
108        self.tasks.first()
109    }
110
111    /// Get the last task in the chain
112    pub fn last(&self) -> Option<&Signature> {
113        self.tasks.last()
114    }
115
116    /// Get an iterator over the tasks
117    pub fn iter(&self) -> std::slice::Iter<'_, Signature> {
118        self.tasks.iter()
119    }
120
121    /// Get a mutable iterator over the tasks
122    pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, Signature> {
123        self.tasks.iter_mut()
124    }
125
126    /// Get a task by index
127    pub fn get(&self, index: usize) -> Option<&Signature> {
128        self.tasks.get(index)
129    }
130
131    /// Get a mutable task by index
132    pub fn get_mut(&mut self, index: usize) -> Option<&mut Signature> {
133        self.tasks.get_mut(index)
134    }
135
136    /// Create a chain with pre-allocated capacity
137    pub fn with_capacity(capacity: usize) -> Self {
138        Self {
139            tasks: Vec::with_capacity(capacity),
140        }
141    }
142
143    /// Extend the chain with additional tasks
144    pub fn extend(mut self, tasks: impl IntoIterator<Item = Signature>) -> Self {
145        self.tasks.extend(tasks);
146        self
147    }
148
149    /// Reverse the order of tasks in the chain
150    pub fn reverse(mut self) -> Self {
151        self.tasks.reverse();
152        self
153    }
154
155    /// Retain only tasks that satisfy the predicate
156    pub fn retain<F>(mut self, f: F) -> Self
157    where
158        F: FnMut(&Signature) -> bool,
159    {
160        self.tasks.retain(f);
161        self
162    }
163
164    /// Apply the chain with a countdown (delay in seconds)
165    ///
166    /// The first task will be delayed by the countdown amount.
167    /// Subsequent tasks are linked and will execute after the previous completes.
168    ///
169    /// # Example
170    /// ```ignore
171    /// let chain = Chain::new()
172    ///     .then("task1", vec![])
173    ///     .then("task2", vec![]);
174    ///
175    /// // Start chain execution after 60 seconds
176    /// chain.apply_with_countdown(broker, 60).await?;
177    /// ```
178    pub async fn apply_with_countdown<B: Broker>(
179        mut self,
180        broker: &B,
181        countdown: u64,
182    ) -> Result<Uuid, CanvasError> {
183        if self.tasks.is_empty() {
184            return Err(CanvasError::Invalid("Chain cannot be empty".to_string()));
185        }
186
187        // Set countdown on the first task
188        if let Some(first) = self.tasks.first_mut() {
189            first.options.countdown = Some(countdown);
190        }
191
192        // Use regular apply to handle the chain
193        self.apply(broker).await
194    }
195
196    /// Apply the chain with an ETA (execution time as Unix timestamp)
197    ///
198    /// The first task will be scheduled for execution at the specified ETA.
199    /// Subsequent tasks are linked and will execute after the previous completes.
200    ///
201    /// # Example
202    /// ```ignore
203    /// use std::time::{SystemTime, UNIX_EPOCH, Duration};
204    ///
205    /// let chain = Chain::new()
206    ///     .then("task1", vec![])
207    ///     .then("task2", vec![]);
208    ///
209    /// // Schedule chain for 1 hour from now
210    /// let eta = SystemTime::now()
211    ///     .duration_since(UNIX_EPOCH).unwrap().as_secs() + 3600;
212    /// chain.apply_with_eta(broker, eta).await?;
213    /// ```
214    pub async fn apply_with_eta<B: Broker>(
215        mut self,
216        broker: &B,
217        eta: u64,
218    ) -> Result<Uuid, CanvasError> {
219        if self.tasks.is_empty() {
220            return Err(CanvasError::Invalid("Chain cannot be empty".to_string()));
221        }
222
223        // Calculate countdown from ETA
224        let now = std::time::SystemTime::now()
225            .duration_since(std::time::UNIX_EPOCH)
226            .unwrap_or_default()
227            .as_secs();
228
229        let countdown = eta.saturating_sub(now);
230
231        // Set countdown on the first task
232        if let Some(first) = self.tasks.first_mut() {
233            first.options.countdown = Some(countdown);
234        }
235
236        self.apply(broker).await
237    }
238
239    /// Set countdown on all tasks in the chain (staggered execution)
240    ///
241    /// Each task gets a progressively larger countdown.
242    ///
243    /// # Arguments
244    /// * `start` - Initial countdown for first task
245    /// * `step` - Additional delay added for each subsequent task
246    pub fn with_staggered_countdown(mut self, start: u64, step: u64) -> Self {
247        let mut countdown = start;
248        for task in &mut self.tasks {
249            task.options.countdown = Some(countdown);
250            countdown += step;
251        }
252        self
253    }
254
255    /// Append another chain to this chain
256    ///
257    /// # Example
258    /// ```
259    /// use celers_canvas::{Chain, Signature};
260    ///
261    /// let chain1 = Chain::new()
262    ///     .then("task1", vec![])
263    ///     .then("task2", vec![]);
264    ///
265    /// let chain2 = Chain::new()
266    ///     .then("task3", vec![])
267    ///     .then("task4", vec![]);
268    ///
269    /// let combined = chain1.append(chain2);
270    /// assert_eq!(combined.len(), 4);
271    /// ```
272    pub fn append(mut self, other: Chain) -> Self {
273        self.tasks.extend(other.tasks);
274        self
275    }
276
277    /// Prepend another chain to this chain
278    ///
279    /// # Example
280    /// ```
281    /// use celers_canvas::{Chain, Signature};
282    ///
283    /// let chain1 = Chain::new()
284    ///     .then("task1", vec![])
285    ///     .then("task2", vec![]);
286    ///
287    /// let chain2 = Chain::new()
288    ///     .then("task3", vec![])
289    ///     .then("task4", vec![]);
290    ///
291    /// let combined = chain1.prepend(chain2);
292    /// assert_eq!(combined.len(), 4);
293    /// assert_eq!(combined.first().unwrap().task, "task3");
294    /// ```
295    pub fn prepend(mut self, other: Chain) -> Self {
296        let mut new_tasks = other.tasks;
297        new_tasks.extend(self.tasks);
298        self.tasks = new_tasks;
299        self
300    }
301
302    /// Split chain at the specified index
303    ///
304    /// Returns a tuple of (before, after) chains.
305    /// The task at `index` will be the first task in the second chain.
306    ///
307    /// # Example
308    /// ```
309    /// use celers_canvas::Chain;
310    ///
311    /// let chain = Chain::new()
312    ///     .then("task1", vec![])
313    ///     .then("task2", vec![])
314    ///     .then("task3", vec![])
315    ///     .then("task4", vec![]);
316    ///
317    /// let (before, after) = chain.split_at(2);
318    /// assert_eq!(before.len(), 2);
319    /// assert_eq!(after.len(), 2);
320    /// ```
321    pub fn split_at(self, index: usize) -> (Chain, Chain) {
322        let (before, after) = self.tasks.split_at(index.min(self.tasks.len()));
323        (
324            Chain {
325                tasks: before.to_vec(),
326            },
327            Chain {
328                tasks: after.to_vec(),
329            },
330        )
331    }
332
333    /// Concatenate multiple chains into a single chain
334    ///
335    /// # Example
336    /// ```
337    /// use celers_canvas::Chain;
338    ///
339    /// let chains = vec![
340    ///     Chain::new().then("task1", vec![]),
341    ///     Chain::new().then("task2", vec![]),
342    ///     Chain::new().then("task3", vec![]),
343    /// ];
344    ///
345    /// let combined = Chain::concat(chains);
346    /// assert_eq!(combined.len(), 3);
347    /// ```
348    pub fn concat<I>(chains: I) -> Self
349    where
350        I: IntoIterator<Item = Chain>,
351    {
352        let mut result = Chain::new();
353        for chain in chains {
354            result.tasks.extend(chain.tasks);
355        }
356        result
357    }
358
359    /// Clone all tasks in the chain with a new task name prefix
360    ///
361    /// Useful for creating workflow variants.
362    ///
363    /// # Example
364    /// ```
365    /// use celers_canvas::Chain;
366    ///
367    /// let chain = Chain::new()
368    ///     .then("process", vec![])
369    ///     .then("validate", vec![]);
370    ///
371    /// let prefixed = chain.with_task_prefix("batch_");
372    /// assert_eq!(prefixed.first().unwrap().task, "batch_process");
373    /// ```
374    pub fn with_task_prefix(mut self, prefix: &str) -> Self {
375        for task in &mut self.tasks {
376            task.task = format!("{}{}", prefix, task.task);
377        }
378        self
379    }
380
381    /// Clone all tasks in the chain with a new task name suffix
382    ///
383    /// # Example
384    /// ```
385    /// use celers_canvas::Chain;
386    ///
387    /// let chain = Chain::new()
388    ///     .then("process", vec![])
389    ///     .then("validate", vec![]);
390    ///
391    /// let suffixed = chain.with_task_suffix("_v2");
392    /// assert_eq!(suffixed.first().unwrap().task, "process_v2");
393    /// ```
394    pub fn with_task_suffix(mut self, suffix: &str) -> Self {
395        for task in &mut self.tasks {
396            task.task = format!("{}{}", task.task, suffix);
397        }
398        self
399    }
400
401    /// Validate that all tasks in the chain have non-empty names
402    ///
403    /// Returns true if all tasks are valid, false otherwise.
404    ///
405    /// # Example
406    /// ```
407    /// use celers_canvas::Chain;
408    ///
409    /// let valid = Chain::new()
410    ///     .then("task1", vec![])
411    ///     .then("task2", vec![]);
412    /// assert!(valid.is_valid());
413    ///
414    /// let invalid = Chain { tasks: vec![] };
415    /// assert!(!invalid.is_valid());
416    /// ```
417    pub fn is_valid(&self) -> bool {
418        !self.tasks.is_empty() && self.tasks.iter().all(|t| !t.task.is_empty())
419    }
420
421    /// Count tasks that match a predicate
422    ///
423    /// # Example
424    /// ```
425    /// use celers_canvas::{Chain, Signature};
426    ///
427    /// let chain = Chain::new()
428    ///     .then_signature(Signature::new("high".to_string()).with_priority(9))
429    ///     .then_signature(Signature::new("low".to_string()).with_priority(1))
430    ///     .then_signature(Signature::new("urgent".to_string()).with_priority(9));
431    ///
432    /// let high_priority = chain.count_matching(|sig| sig.options.priority.unwrap_or(0) >= 9);
433    /// assert_eq!(high_priority, 2);
434    /// ```
435    pub fn count_matching<F>(&self, predicate: F) -> usize
436    where
437        F: Fn(&Signature) -> bool,
438    {
439        self.tasks.iter().filter(|t| predicate(t)).count()
440    }
441
442    /// Check if any task matches a predicate
443    ///
444    /// # Example
445    /// ```
446    /// use celers_canvas::Chain;
447    ///
448    /// let chain = Chain::new()
449    ///     .then("process", vec![])
450    ///     .then("validate", vec![]);
451    ///
452    /// assert!(chain.any(|sig| sig.task == "validate"));
453    /// assert!(!chain.any(|sig| sig.task == "missing"));
454    /// ```
455    pub fn any<F>(&self, predicate: F) -> bool
456    where
457        F: Fn(&Signature) -> bool,
458    {
459        self.tasks.iter().any(predicate)
460    }
461
462    /// Check if all tasks match a predicate
463    ///
464    /// # Example
465    /// ```
466    /// use celers_canvas::Chain;
467    ///
468    /// let chain = Chain::new()
469    ///     .then("process", vec![])
470    ///     .then("validate", vec![]);
471    ///
472    /// assert!(chain.all(|sig| !sig.task.is_empty()));
473    /// ```
474    pub fn all<F>(&self, predicate: F) -> bool
475    where
476        F: Fn(&Signature) -> bool,
477    {
478        self.tasks.iter().all(predicate)
479    }
480
481    /// Map over all tasks, transforming each signature
482    ///
483    /// # Example
484    /// ```
485    /// use celers_canvas::{Chain, Signature};
486    ///
487    /// let chain = Chain::new()
488    ///     .then("task1", vec![])
489    ///     .then("task2", vec![]);
490    ///
491    /// let modified = chain.map_tasks(|sig| {
492    ///     Signature::new(format!("modified_{}", sig.task))
493    /// });
494    ///
495    /// assert_eq!(modified.first().unwrap().task, "modified_task1");
496    /// ```
497    pub fn map_tasks<F>(mut self, f: F) -> Self
498    where
499        F: FnMut(Signature) -> Signature,
500    {
501        self.tasks = self.tasks.into_iter().map(f).collect();
502        self
503    }
504
505    /// Filter and map tasks in one operation
506    ///
507    /// # Example
508    /// ```
509    /// use celers_canvas::{Chain, Signature};
510    ///
511    /// let chain = Chain::new()
512    ///     .then_signature(Signature::new("high".to_string()).with_priority(9))
513    ///     .then_signature(Signature::new("low".to_string()).with_priority(1))
514    ///     .then_signature(Signature::new("urgent".to_string()).with_priority(9));
515    ///
516    /// let high_priority = chain.filter_map(|sig| {
517    ///     if sig.options.priority.unwrap_or(0) >= 9 {
518    ///         Some(sig)
519    ///     } else {
520    ///         None
521    ///     }
522    /// });
523    ///
524    /// assert_eq!(high_priority.len(), 2);
525    /// ```
526    pub fn filter_map<F>(mut self, f: F) -> Self
527    where
528        F: FnMut(Signature) -> Option<Signature>,
529    {
530        self.tasks = self.tasks.into_iter().filter_map(f).collect();
531        self
532    }
533
534    /// Take the first n tasks from the chain
535    ///
536    /// # Example
537    /// ```
538    /// use celers_canvas::Chain;
539    ///
540    /// let chain = Chain::new()
541    ///     .then("task1", vec![])
542    ///     .then("task2", vec![])
543    ///     .then("task3", vec![])
544    ///     .then("task4", vec![]);
545    ///
546    /// let first_two = chain.take(2);
547    /// assert_eq!(first_two.len(), 2);
548    /// ```
549    pub fn take(mut self, n: usize) -> Self {
550        self.tasks.truncate(n);
551        self
552    }
553
554    /// Skip the first n tasks from the chain
555    ///
556    /// # Example
557    /// ```
558    /// use celers_canvas::Chain;
559    ///
560    /// let chain = Chain::new()
561    ///     .then("task1", vec![])
562    ///     .then("task2", vec![])
563    ///     .then("task3", vec![])
564    ///     .then("task4", vec![]);
565    ///
566    /// let skipped = chain.skip(2);
567    /// assert_eq!(skipped.len(), 2);
568    /// assert_eq!(skipped.first().unwrap().task, "task3");
569    /// ```
570    pub fn skip(mut self, n: usize) -> Self {
571        self.tasks = self.tasks.into_iter().skip(n).collect();
572        self
573    }
574
575    /// Find the index of the first task with the given name
576    ///
577    /// # Example
578    /// ```
579    /// use celers_canvas::Chain;
580    ///
581    /// let chain = Chain::new()
582    ///     .then("task1", vec![])
583    ///     .then("task2", vec![])
584    ///     .then("task1", vec![]);
585    ///
586    /// assert_eq!(chain.find_task("task1"), Some(0));
587    /// assert_eq!(chain.find_task("task2"), Some(1));
588    /// assert_eq!(chain.find_task("task3"), None);
589    /// ```
590    pub fn find_task(&self, task_name: &str) -> Option<usize> {
591        self.tasks.iter().position(|t| t.task == task_name)
592    }
593
594    /// Find all indices of tasks with the given name
595    ///
596    /// # Example
597    /// ```
598    /// use celers_canvas::Chain;
599    ///
600    /// let chain = Chain::new()
601    ///     .then("task1", vec![])
602    ///     .then("task2", vec![])
603    ///     .then("task1", vec![]);
604    ///
605    /// assert_eq!(chain.find_all_tasks("task1"), vec![0, 2]);
606    /// assert_eq!(chain.find_all_tasks("task2"), vec![1]);
607    /// ```
608    pub fn find_all_tasks(&self, task_name: &str) -> Vec<usize> {
609        self.tasks
610            .iter()
611            .enumerate()
612            .filter(|(_, t)| t.task == task_name)
613            .map(|(i, _)| i)
614            .collect()
615    }
616
617    /// Check if the chain contains a task with the given name
618    ///
619    /// # Example
620    /// ```
621    /// use celers_canvas::Chain;
622    ///
623    /// let chain = Chain::new()
624    ///     .then("task1", vec![])
625    ///     .then("task2", vec![]);
626    ///
627    /// assert!(chain.contains_task("task1"));
628    /// assert!(!chain.contains_task("task3"));
629    /// ```
630    pub fn contains_task(&self, task_name: &str) -> bool {
631        self.tasks.iter().any(|t| t.task == task_name)
632    }
633
634    /// Get the total estimated duration in seconds based on task time limits
635    ///
636    /// This sums up all task time limits (or soft_time_limit if time_limit is not set).
637    /// Returns None if no tasks have time limits set.
638    ///
639    /// # Example
640    /// ```
641    /// use celers_canvas::{Chain, Signature};
642    ///
643    /// let chain = Chain::new()
644    ///     .then_signature(Signature::new("task1".to_string()).with_time_limit(10))
645    ///     .then_signature(Signature::new("task2".to_string()).with_time_limit(20));
646    ///
647    /// assert_eq!(chain.estimated_duration(), Some(30));
648    /// ```
649    pub fn estimated_duration(&self) -> Option<u64> {
650        let mut total = 0u64;
651        let mut found_any = false;
652
653        for task in &self.tasks {
654            if let Some(limit) = task.options.time_limit.or(task.options.soft_time_limit) {
655                total += limit;
656                found_any = true;
657            }
658        }
659
660        if found_any {
661            Some(total)
662        } else {
663            None
664        }
665    }
666
667    /// Get a summary of all task names in the chain
668    ///
669    /// # Example
670    /// ```
671    /// use celers_canvas::Chain;
672    ///
673    /// let chain = Chain::new()
674    ///     .then("fetch", vec![])
675    ///     .then("process", vec![])
676    ///     .then("save", vec![]);
677    ///
678    /// assert_eq!(chain.task_names(), vec!["fetch", "process", "save"]);
679    /// ```
680    pub fn task_names(&self) -> Vec<&str> {
681        self.tasks.iter().map(|t| t.task.as_str()).collect()
682    }
683
684    /// Get all unique task names in the chain
685    ///
686    /// # Example
687    /// ```
688    /// use celers_canvas::Chain;
689    ///
690    /// let chain = Chain::new()
691    ///     .then("task1", vec![])
692    ///     .then("task2", vec![])
693    ///     .then("task1", vec![]);
694    ///
695    /// let unique = chain.unique_task_names();
696    /// assert_eq!(unique.len(), 2);
697    /// assert!(unique.contains(&"task1"));
698    /// assert!(unique.contains(&"task2"));
699    /// ```
700    pub fn unique_task_names(&self) -> std::collections::HashSet<&str> {
701        self.tasks.iter().map(|t| t.task.as_str()).collect()
702    }
703
704    /// Clone the chain with a transformation applied to each task
705    ///
706    /// # Example
707    /// ```
708    /// use celers_canvas::Chain;
709    ///
710    /// let chain = Chain::new()
711    ///     .then("task1", vec![])
712    ///     .then("task2", vec![]);
713    ///
714    /// let prioritized = chain.clone_with_transform(|sig| {
715    ///     sig.clone().with_priority(5)
716    /// });
717    ///
718    /// assert!(prioritized.tasks.iter().all(|t| t.options.priority == Some(5)));
719    /// ```
720    pub fn clone_with_transform<F>(&self, mut transform: F) -> Self
721    where
722        F: FnMut(&Signature) -> Signature,
723    {
724        Self {
725            tasks: self.tasks.iter().map(&mut transform).collect(),
726        }
727    }
728}
729
730impl std::fmt::Display for Chain {
731    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
732        write!(f, "Chain[{} tasks]", self.tasks.len())?;
733        if !self.tasks.is_empty() {
734            write!(
735                f,
736                " {} -> ... -> {}",
737                self.tasks.first().unwrap().task,
738                self.tasks.last().unwrap().task
739            )?;
740        }
741        Ok(())
742    }
743}
744
745impl IntoIterator for Chain {
746    type Item = Signature;
747    type IntoIter = std::vec::IntoIter<Signature>;
748
749    fn into_iter(self) -> Self::IntoIter {
750        self.tasks.into_iter()
751    }
752}
753
754impl<'a> IntoIterator for &'a Chain {
755    type Item = &'a Signature;
756    type IntoIter = std::slice::Iter<'a, Signature>;
757
758    fn into_iter(self) -> Self::IntoIter {
759        self.tasks.iter()
760    }
761}
762
763impl From<Vec<Signature>> for Chain {
764    fn from(tasks: Vec<Signature>) -> Self {
765        Self { tasks }
766    }
767}
768
769impl FromIterator<Signature> for Chain {
770    fn from_iter<T: IntoIterator<Item = Signature>>(iter: T) -> Self {
771        Self {
772            tasks: iter.into_iter().collect(),
773        }
774    }
775}