deltaflow/
pipeline.rs

1//! Pipeline builder and executor.
2
3use async_trait::async_trait;
4use serde::Serialize;
5use std::sync::Arc;
6use thiserror::Error;
7
8use crate::recorder::{NoopRecorder, Recorder, RunId, RunStatus, StepStatus};
9use crate::retry::RetryPolicy;
10use crate::step::{Step, StepError};
11
/// Type alias for fork predicate functions.
///
/// A shared (`Arc`), `Send + Sync` closure that receives a reference to a
/// pipeline output of type `O` and returns `true` when the fork it belongs
/// to should fire. Stored in [`SpawnRule::Fork`].
type ForkPredicate<O> = Arc<dyn Fn(&O) -> bool + Send + Sync>;
14
/// Type alias for dynamic spawn generator functions.
///
/// A shared (`Arc`), `Send + Sync` closure that receives a reference to a
/// pipeline output of type `O` and returns the JSON values to spawn, one
/// per follow-up task. Stored in [`SpawnRule::Dynamic`].
type SpawnGenerator<O> = Arc<dyn Fn(&O) -> Vec<serde_json::Value> + Send + Sync>;
17
/// A rule for spawning work after pipeline completion.
///
/// `O` is the output type of the pipeline the rule is attached to. Note that
/// `#[derive(Clone)]` only provides `Clone` for `SpawnRule<O>` when
/// `O: Clone`, even though the closures themselves live behind `Arc`s.
#[derive(Clone)]
pub enum SpawnRule<O> {
    /// Conditional fork: spawn to target if predicate returns true.
    Fork {
        /// Name of the pipeline to spawn to.
        target: &'static str,
        /// Evaluated against the output to decide whether the fork fires.
        predicate: ForkPredicate<O>,
        /// Human-readable text describing the fork, used for visualization.
        description: String,
    },
    /// Static fan-out: always spawn to these targets.
    FanOut { targets: Vec<&'static str> },
    /// Dynamic spawn: generate tasks from output.
    Dynamic {
        /// Name of the pipeline that receives the generated inputs.
        target: &'static str,
        /// Produces the JSON inputs to spawn from the output.
        generator: SpawnGenerator<O>,
    },
}
35
/// Serializable representation of a pipeline's structure for visualization.
///
/// Produced by `BuiltPipeline::to_graph`. Spawn rules are split by kind into
/// the three rule vectors; within each vector the declaration order is kept.
#[derive(Debug, Clone, Serialize)]
pub struct PipelineGraph {
    /// Name of the pipeline.
    pub name: String,
    /// Steps in execution order.
    pub steps: Vec<StepNode>,
    /// One entry per `SpawnRule::Fork`.
    pub forks: Vec<ForkNode>,
    /// One entry per `SpawnRule::FanOut`.
    pub fan_outs: Vec<FanOutNode>,
    /// One entry per `SpawnRule::Dynamic`.
    pub dynamic_spawns: Vec<DynamicSpawnNode>,
}
45
/// A step in the pipeline graph.
#[derive(Debug, Clone, Serialize)]
pub struct StepNode {
    /// Name reported by the step.
    pub name: String,
    /// Zero-based position of the step within the pipeline.
    pub index: usize,
}
52
/// A conditional fork declaration.
#[derive(Debug, Clone, Serialize)]
pub struct ForkNode {
    /// Name of the pipeline the fork spawns to.
    pub target_pipeline: String,
    /// Human-readable description of the fork (the predicate itself is a
    /// closure and cannot be serialized).
    pub condition: String,
}
59
/// A static fan-out declaration.
#[derive(Debug, Clone, Serialize)]
pub struct FanOutNode {
    /// Names of all pipelines the fan-out spawns to.
    pub targets: Vec<String>,
}
65
/// A dynamic spawn declaration.
///
/// Only the target is exported; the generator closure is not representable
/// in the graph.
#[derive(Debug, Clone, Serialize)]
pub struct DynamicSpawnNode {
    /// Name of the pipeline that receives the generated inputs.
    pub target_pipeline: String,
}
71
/// Error returned by pipeline execution.
///
/// The `Display` text of the step variants already embeds the underlying
/// error's message, and the same error is also exposed through
/// `std::error::Error::source`.
#[derive(Error, Debug)]
pub enum PipelineError {
    /// A step failed permanently.
    #[error("step '{step}' failed: {source}")]
    StepFailed {
        /// Name of the step that failed.
        step: &'static str,
        /// The error the step reported.
        #[source]
        source: anyhow::Error,
    },

    /// A step exhausted all retries.
    #[error("step '{step}' exhausted {attempts} retries: {source}")]
    RetriesExhausted {
        /// Name of the step that kept failing.
        step: &'static str,
        /// Total number of execution attempts made, including the first one
        /// (the chain executors store their attempt counter here).
        attempts: u32,
        /// The error reported by the last attempt.
        #[source]
        source: anyhow::Error,
    },

    /// Recording failed.
    ///
    /// Because of `#[from]`, any `anyhow::Error` propagated with `?` in a
    /// function returning `PipelineError` becomes this variant.
    #[error("recorder error: {0}")]
    RecorderError(#[from] anyhow::Error),
}
96
/// Implemented by pipeline inputs that can name the entity they belong to,
/// so that a run can be recorded under that identifier.
pub trait HasEntityId {
    /// Returns the entity identifier for this input as an owned `String`.
    fn entity_id(&self) -> String;
}

// An owned string is its own entity id.
impl HasEntityId for String {
    fn entity_id(&self) -> String {
        String::from(self.as_str())
    }
}

// A string slice is its own entity id, copied into an owned `String`.
impl HasEntityId for &str {
    fn entity_id(&self) -> String {
        String::from(*self)
    }
}
116
/// Internal trait for boxed step execution.
///
/// Exposes the two operations the chain executors need from a step — its
/// name and an async `execute` — with the input and output types as the
/// trait parameters `I` and `O`. Implemented for [`StepWrapper`].
#[doc(hidden)]
#[async_trait]
pub trait BoxedStep<I, O>: Send + Sync {
    /// Static name of the step; used when recording and in error values.
    fn name(&self) -> &'static str;
    /// Runs the step on `input`, yielding its output or a [`StepError`].
    async fn execute(&self, input: I) -> Result<O, StepError>;
}
124
/// Wrapper to make any Step into a BoxedStep.
///
/// A newtype around the step `S`; the `BoxedStep` implementation simply
/// forwards to the wrapped value.
#[doc(hidden)]
pub struct StepWrapper<S>(pub S);
128
129#[async_trait]
130impl<S> BoxedStep<S::Input, S::Output> for StepWrapper<S>
131where
132    S: Step,
133{
134    fn name(&self) -> &'static str {
135        self.0.name()
136    }
137
138    async fn execute(&self, input: S::Input) -> Result<S::Output, StepError> {
139        self.0.execute(input).await
140    }
141}
142
/// A chain of steps that transforms I -> O.
///
/// Implemented by [`Identity`] (no steps), [`ChainedStep`] (a step followed
/// by a chain) and [`ThenChain`] (a chain followed by a step).
#[doc(hidden)]
#[async_trait]
pub trait StepChain<I, O>: Send + Sync {
    /// Runs every step of the chain on `input`.
    ///
    /// * `run_id` - the run the steps are recorded under.
    /// * `recorder` - receives step start/completion notifications.
    /// * `retry_policy` - consulted after each retryable step failure.
    /// * `start_index` - index recorded for the first step of this chain;
    ///   later steps count up from it.
    async fn run(
        &self,
        input: I,
        run_id: RunId,
        recorder: &dyn Recorder,
        retry_policy: &RetryPolicy,
        start_index: u32,
    ) -> Result<O, PipelineError>;

    /// Returns the number of steps in this chain.
    fn step_count(&self) -> u32;

    /// Collect step names in order.
    ///
    /// Names are appended to `names`; existing entries are left in place.
    fn collect_step_names(&self, names: &mut Vec<&'static str>);
}
162
/// Terminal chain - identity transform.
///
/// Contains no steps; used as the initial chain of a new pipeline and as the
/// tail of a [`ChainedStep`].
#[doc(hidden)]
pub struct Identity;
166
167#[async_trait]
168impl<T: Send + 'static> StepChain<T, T> for Identity {
169    async fn run(
170        &self,
171        input: T,
172        _run_id: RunId,
173        _recorder: &dyn Recorder,
174        _retry_policy: &RetryPolicy,
175        _start_index: u32,
176    ) -> Result<T, PipelineError> {
177        Ok(input)
178    }
179
180    fn step_count(&self) -> u32 {
181        0
182    }
183
184    fn collect_step_names(&self, _names: &mut Vec<&'static str>) {}
185}
186
/// Chain that runs a step then continues with the rest.
///
/// `I` is the step's input, `M` the value it hands to `next`, and `O` the
/// final output of `next`.
#[doc(hidden)]
pub struct ChainedStep<S, Next, I, M, O>
where
    S: BoxedStep<I, M>,
    Next: StepChain<M, O>,
{
    /// The step executed first.
    pub step: S,
    /// The chain that receives the step's output.
    pub next: Next,
    /// Ties the otherwise-unused type parameters to the struct.
    pub _phantom: std::marker::PhantomData<(I, M, O)>,
}
198
199#[async_trait]
200impl<S, Next, I, M, O> StepChain<I, O> for ChainedStep<S, Next, I, M, O>
201where
202    I: Send + Sync + Clone + 'static,
203    M: Send + Sync + 'static,
204    O: Send + Sync + 'static,
205    S: BoxedStep<I, M> + Send + Sync,
206    Next: StepChain<M, O> + Send + Sync,
207{
208    async fn run(
209        &self,
210        input: I,
211        run_id: RunId,
212        recorder: &dyn Recorder,
213        retry_policy: &RetryPolicy,
214        start_index: u32,
215    ) -> Result<O, PipelineError> {
216        let step_name = self.step.name();
217        let step_id = recorder.start_step(run_id, step_name, start_index).await?;
218
219        // Execute with retry
220        let mut attempt = 0u32;
221        let output = loop {
222            attempt += 1;
223            match self.step.execute(input.clone()).await {
224                Ok(output) => break output,
225                Err(StepError::Permanent(e)) => {
226                    recorder
227                        .complete_step(
228                            step_id,
229                            StepStatus::Failed {
230                                error: e.to_string(),
231                                attempt,
232                            },
233                        )
234                        .await?;
235                    return Err(PipelineError::StepFailed {
236                        step: step_name,
237                        source: e,
238                    });
239                }
240                Err(StepError::Retryable(e)) => {
241                    if let Some(delay) = retry_policy.delay_for_attempt(attempt) {
242                        tokio::time::sleep(delay).await;
243                    } else {
244                        recorder
245                            .complete_step(
246                                step_id,
247                                StepStatus::Failed {
248                                    error: e.to_string(),
249                                    attempt,
250                                },
251                            )
252                            .await?;
253                        return Err(PipelineError::RetriesExhausted {
254                            step: step_name,
255                            attempts: attempt,
256                            source: e,
257                        });
258                    }
259                }
260            }
261        };
262
263        recorder
264            .complete_step(step_id, StepStatus::Completed)
265            .await?;
266
267        // Continue with next steps
268        self.next
269            .run(output, run_id, recorder, retry_policy, start_index + 1)
270            .await
271    }
272
273    fn step_count(&self) -> u32 {
274        1 + self.next.step_count()
275    }
276
277    fn collect_step_names(&self, names: &mut Vec<&'static str>) {
278        names.push(self.step.name());
279        self.next.collect_step_names(names);
280    }
281}
282
/// Builder for constructing pipelines.
///
/// `I` is the pipeline input type, `O` its current output type, and `Chain`
/// the step chain assembled so far.
pub struct Pipeline<I, O, Chain>
where
    Chain: StepChain<I, O>,
{
    /// Pipeline name.
    name: &'static str,
    /// Steps added so far.
    chain: Chain,
    /// Retry policy handed to the chain when the built pipeline runs.
    retry_policy: RetryPolicy,
    /// Recorder handed to the chain when the built pipeline runs.
    recorder: Arc<dyn Recorder>,
    /// Spawn rules declared against the current output type `O`.
    spawn_rules: Vec<SpawnRule<O>>,
    /// Ties `I` and `O` to the struct.
    _phantom: std::marker::PhantomData<(I, O)>,
}
295
296impl Pipeline<(), (), Identity> {
297    /// Create a new pipeline builder with the given name.
298    pub fn new(name: &'static str) -> Self {
299        Self {
300            name,
301            chain: Identity,
302            retry_policy: RetryPolicy::default(),
303            recorder: Arc::new(NoopRecorder),
304            spawn_rules: Vec::new(),
305            _phantom: std::marker::PhantomData,
306        }
307    }
308}
309
310impl<O, Chain> Pipeline<(), O, Chain>
311where
312    Chain: StepChain<(), O> + Send + Sync + 'static,
313    O: Send + 'static,
314{
315    /// Add the first step to the pipeline.
316    #[allow(clippy::type_complexity)]
317    pub fn start_with<S>(
318        self,
319        step: S,
320    ) -> Pipeline<
321        S::Input,
322        S::Output,
323        ChainedStep<StepWrapper<S>, Identity, S::Input, S::Output, S::Output>,
324    >
325    where
326        S: Step + 'static,
327    {
328        Pipeline {
329            name: self.name,
330            chain: ChainedStep {
331                step: StepWrapper(step),
332                next: Identity,
333                _phantom: std::marker::PhantomData,
334            },
335            retry_policy: self.retry_policy,
336            recorder: self.recorder,
337            spawn_rules: Vec::new(),
338            _phantom: std::marker::PhantomData,
339        }
340    }
341}
342
343impl<I, O, Chain> Pipeline<I, O, Chain>
344where
345    I: Send + Sync + Clone + 'static,
346    O: Send + Sync + Clone + 'static,
347    Chain: StepChain<I, O> + Send + Sync + 'static,
348{
349    /// Add a step to the pipeline.
350    pub fn then<S>(self, step: S) -> Pipeline<I, S::Output, impl StepChain<I, S::Output>>
351    where
352        S: Step<Input = O> + 'static,
353    {
354        Pipeline {
355            name: self.name,
356            chain: ThenChain {
357                first: self.chain,
358                step: StepWrapper(step),
359                _phantom: std::marker::PhantomData,
360            },
361            retry_policy: self.retry_policy,
362            recorder: self.recorder,
363            spawn_rules: Vec::new(),
364            _phantom: std::marker::PhantomData,
365        }
366    }
367
368    /// Set the retry policy for this pipeline.
369    pub fn with_retry(mut self, policy: RetryPolicy) -> Self {
370        self.retry_policy = policy;
371        self
372    }
373
374    /// Set the recorder for this pipeline.
375    pub fn with_recorder<R: Recorder + 'static>(mut self, recorder: R) -> Self {
376        self.recorder = Arc::new(recorder);
377        self
378    }
379
380    /// Declare follow-up tasks to spawn on successful completion.
381    ///
382    /// The generator function receives the pipeline output and returns
383    /// a list of inputs to enqueue to the target pipeline.
384    pub fn spawn_from<T, F>(mut self, target: &'static str, f: F) -> Self
385    where
386        T: Serialize + 'static,
387        F: Fn(&O) -> Vec<T> + Send + Sync + 'static,
388    {
389        self.spawn_rules.push(SpawnRule::Dynamic {
390            target,
391            generator: Arc::new(move |output| {
392                f(output)
393                    .into_iter()
394                    .filter_map(|item| serde_json::to_value(item).ok())
395                    .collect()
396            }),
397        });
398        self
399    }
400
401    /// Keep old name as alias for backward compatibility.
402    #[deprecated(since = "0.4.0", note = "Use spawn_from instead")]
403    pub fn spawns<T, F>(self, target: &'static str, f: F) -> Self
404    where
405        T: Serialize + 'static,
406        F: Fn(&O) -> Vec<T> + Send + Sync + 'static,
407    {
408        self.spawn_from(target, f)
409    }
410
411    /// Conditionally fork to a target pipeline when predicate returns true.
412    ///
413    /// The output is serialized and sent to the target pipeline.
414    /// Multiple forks can match - they are not mutually exclusive.
415    pub fn fork_when<F>(mut self, predicate: F, target: &'static str) -> Self
416    where
417        F: Fn(&O) -> bool + Send + Sync + 'static,
418    {
419        self.spawn_rules.push(SpawnRule::Fork {
420            target,
421            predicate: Arc::new(predicate),
422            description: format!("fork to {}", target),
423        });
424        self
425    }
426
427    /// Conditionally fork with a custom description for visualization.
428    pub fn fork_when_desc<F>(
429        mut self,
430        predicate: F,
431        target: &'static str,
432        description: &str,
433    ) -> Self
434    where
435        F: Fn(&O) -> bool + Send + Sync + 'static,
436    {
437        self.spawn_rules.push(SpawnRule::Fork {
438            target,
439            predicate: Arc::new(predicate),
440            description: description.to_string(),
441        });
442        self
443    }
444
445    /// Fan out to multiple target pipelines unconditionally.
446    ///
447    /// The output is serialized and sent to ALL specified targets.
448    pub fn fan_out(mut self, targets: &[&'static str]) -> Self {
449        self.spawn_rules.push(SpawnRule::FanOut {
450            targets: targets.to_vec(),
451        });
452        self
453    }
454
455    /// Build the pipeline, ready for execution.
456    pub fn build(self) -> BuiltPipeline<I, O, Chain> {
457        BuiltPipeline {
458            name: self.name,
459            chain: self.chain,
460            retry_policy: self.retry_policy,
461            recorder: self.recorder,
462            spawn_rules: self.spawn_rules,
463            _phantom: std::marker::PhantomData,
464        }
465    }
466}
467
/// Chain that runs first chain then a step.
///
/// `I` is the input of `first`, `M` the value `first` hands to `step`, and
/// `O` the step's output.
#[doc(hidden)]
pub struct ThenChain<First, S, I, M, O>
where
    First: StepChain<I, M>,
    S: BoxedStep<M, O>,
{
    /// The chain executed first.
    pub first: First,
    /// The step that receives the chain's output.
    pub step: S,
    /// Ties the otherwise-unused type parameters to the struct.
    pub _phantom: std::marker::PhantomData<(I, M, O)>,
}
479
480#[async_trait]
481impl<First, S, I, M, O> StepChain<I, O> for ThenChain<First, S, I, M, O>
482where
483    I: Send + Sync + Clone + 'static,
484    M: Send + Sync + Clone + 'static,
485    O: Send + Sync + 'static,
486    First: StepChain<I, M> + Send + Sync,
487    S: BoxedStep<M, O> + Send + Sync,
488{
489    async fn run(
490        &self,
491        input: I,
492        run_id: RunId,
493        recorder: &dyn Recorder,
494        retry_policy: &RetryPolicy,
495        start_index: u32,
496    ) -> Result<O, PipelineError> {
497        // Run first chain
498        let mid = self
499            .first
500            .run(input, run_id, recorder, retry_policy, start_index)
501            .await?;
502
503        let next_index = start_index + self.first.step_count();
504
505        let step_name = self.step.name();
506        let step_id = recorder.start_step(run_id, step_name, next_index).await?;
507
508        // Execute with retry
509        let mut attempt = 0u32;
510        let output = loop {
511            attempt += 1;
512            match self.step.execute(mid.clone()).await {
513                Ok(output) => break output,
514                Err(StepError::Permanent(e)) => {
515                    recorder
516                        .complete_step(
517                            step_id,
518                            StepStatus::Failed {
519                                error: e.to_string(),
520                                attempt,
521                            },
522                        )
523                        .await?;
524                    return Err(PipelineError::StepFailed {
525                        step: step_name,
526                        source: e,
527                    });
528                }
529                Err(StepError::Retryable(e)) => {
530                    if let Some(delay) = retry_policy.delay_for_attempt(attempt) {
531                        tokio::time::sleep(delay).await;
532                    } else {
533                        recorder
534                            .complete_step(
535                                step_id,
536                                StepStatus::Failed {
537                                    error: e.to_string(),
538                                    attempt,
539                                },
540                            )
541                            .await?;
542                        return Err(PipelineError::RetriesExhausted {
543                            step: step_name,
544                            attempts: attempt,
545                            source: e,
546                        });
547                    }
548                }
549            }
550        };
551
552        recorder
553            .complete_step(step_id, StepStatus::Completed)
554            .await?;
555        Ok(output)
556    }
557
558    fn step_count(&self) -> u32 {
559        self.first.step_count() + 1
560    }
561
562    fn collect_step_names(&self, names: &mut Vec<&'static str>) {
563        self.first.collect_step_names(names);
564        names.push(self.step.name());
565    }
566}
567
/// A built pipeline ready for execution.
///
/// Created by `Pipeline::build`, which moves the builder's fields over
/// unchanged.
pub struct BuiltPipeline<I, O, Chain>
where
    Chain: StepChain<I, O>,
{
    /// Pipeline name; passed to the recorder when a run starts.
    name: &'static str,
    /// The steps to execute.
    chain: Chain,
    /// Retry policy handed to the chain on every run.
    retry_policy: RetryPolicy,
    /// Recorder notified of run and step progress.
    recorder: Arc<dyn Recorder>,
    /// Rules evaluated against an output to decide what to spawn.
    pub(crate) spawn_rules: Vec<SpawnRule<O>>,
    /// Ties `I` and `O` to the struct.
    _phantom: std::marker::PhantomData<(I, O)>,
}
580
581impl<I, O, Chain> BuiltPipeline<I, O, Chain>
582where
583    I: Send + Clone + HasEntityId + 'static,
584    O: Send + Serialize + 'static,
585    Chain: StepChain<I, O> + Send + Sync,
586{
587    /// Execute the pipeline with the given input.
588    pub async fn run(&self, input: I) -> Result<O, PipelineError> {
589        let entity_id = input.entity_id();
590        let run_id = self.recorder.start_run(self.name, &entity_id).await?;
591
592        match self
593            .chain
594            .run(input, run_id, self.recorder.as_ref(), &self.retry_policy, 0)
595            .await
596        {
597            Ok(output) => {
598                self.recorder
599                    .complete_run(run_id, RunStatus::Completed)
600                    .await?;
601                Ok(output)
602            }
603            Err(e) => {
604                self.recorder
605                    .complete_run(
606                        run_id,
607                        RunStatus::Failed {
608                            error: e.to_string(),
609                        },
610                    )
611                    .await?;
612                Err(e)
613            }
614        }
615    }
616
617    /// Get the pipeline name.
618    pub fn name(&self) -> &'static str {
619        self.name
620    }
621
622    /// Get spawned tasks for the given output.
623    pub fn get_spawned(&self, output: &O) -> Vec<(&'static str, serde_json::Value)> {
624        let mut spawned = Vec::new();
625
626        for rule in &self.spawn_rules {
627            match rule {
628                SpawnRule::Fork {
629                    target, predicate, ..
630                } => {
631                    if predicate(output) {
632                        if let Ok(value) = serde_json::to_value(output) {
633                            spawned.push((*target, value));
634                        }
635                    }
636                }
637                SpawnRule::FanOut { targets } => {
638                    if let Ok(value) = serde_json::to_value(output) {
639                        for target in targets {
640                            spawned.push((*target, value.clone()));
641                        }
642                    }
643                }
644                SpawnRule::Dynamic { target, generator } => {
645                    for input in generator(output) {
646                        spawned.push((*target, input));
647                    }
648                }
649            }
650        }
651
652        spawned
653    }
654
655    /// Export the pipeline structure as a graph for visualization.
656    pub fn to_graph(&self) -> PipelineGraph {
657        let mut step_names = Vec::new();
658        self.chain.collect_step_names(&mut step_names);
659
660        let steps: Vec<StepNode> = step_names
661            .into_iter()
662            .enumerate()
663            .map(|(index, name)| StepNode {
664                name: name.to_string(),
665                index,
666            })
667            .collect();
668
669        let mut forks = Vec::new();
670        let mut fan_outs = Vec::new();
671        let mut dynamic_spawns = Vec::new();
672
673        for rule in &self.spawn_rules {
674            match rule {
675                SpawnRule::Fork {
676                    target,
677                    description,
678                    ..
679                } => {
680                    forks.push(ForkNode {
681                        target_pipeline: target.to_string(),
682                        condition: description.clone(),
683                    });
684                }
685                SpawnRule::FanOut { targets } => {
686                    fan_outs.push(FanOutNode {
687                        targets: targets.iter().map(|s| s.to_string()).collect(),
688                    });
689                }
690                SpawnRule::Dynamic { target, .. } => {
691                    dynamic_spawns.push(DynamicSpawnNode {
692                        target_pipeline: target.to_string(),
693                    });
694                }
695            }
696        }
697
698        PipelineGraph {
699            name: self.name.to_string(),
700            steps,
701            forks,
702            fan_outs,
703            dynamic_spawns,
704        }
705    }
706}