1use async_trait::async_trait;
4use serde::Serialize;
5use std::sync::Arc;
6use thiserror::Error;
7
8use crate::recorder::{NoopRecorder, Recorder, RunId, RunStatus, StepStatus};
9use crate::retry::RetryPolicy;
10use crate::step::{Step, StepError};
11
/// Shared, thread-safe predicate deciding whether a fork rule fires for an output `O`.
type ForkPredicate<O> = Arc<dyn Fn(&O) -> bool + Sync + Send>;
14
15type SpawnGenerator<O> = Arc<dyn Fn(&O) -> Vec<serde_json::Value> + Send + Sync>;
17
18#[derive(Clone)]
20pub enum SpawnRule<O> {
21 Fork {
23 target: &'static str,
24 predicate: ForkPredicate<O>,
25 description: String,
26 },
27 FanOut { targets: Vec<&'static str> },
29 Dynamic {
31 target: &'static str,
32 generator: SpawnGenerator<O>,
33 },
34}
35
36#[derive(Debug, Clone, Serialize)]
38pub struct PipelineGraph {
39 pub name: String,
40 pub steps: Vec<StepNode>,
41 pub forks: Vec<ForkNode>,
42 pub fan_outs: Vec<FanOutNode>,
43 pub dynamic_spawns: Vec<DynamicSpawnNode>,
44}
45
46#[derive(Debug, Clone, Serialize)]
48pub struct StepNode {
49 pub name: String,
50 pub index: usize,
51}
52
53#[derive(Debug, Clone, Serialize)]
55pub struct ForkNode {
56 pub target_pipeline: String,
57 pub condition: String,
58}
59
60#[derive(Debug, Clone, Serialize)]
62pub struct FanOutNode {
63 pub targets: Vec<String>,
64}
65
66#[derive(Debug, Clone, Serialize)]
68pub struct DynamicSpawnNode {
69 pub target_pipeline: String,
70}
71
72#[derive(Error, Debug)]
74pub enum PipelineError {
75 #[error("step '{step}' failed: {source}")]
77 StepFailed {
78 step: &'static str,
79 #[source]
80 source: anyhow::Error,
81 },
82
83 #[error("step '{step}' exhausted {attempts} retries: {source}")]
85 RetriesExhausted {
86 step: &'static str,
87 attempts: u32,
88 #[source]
89 source: anyhow::Error,
90 },
91
92 #[error("recorder error: {0}")]
94 RecorderError(#[from] anyhow::Error),
95}
96
/// Inputs that can name the entity a pipeline run is about.
///
/// `BuiltPipeline::run` calls this on its input and passes the result to the
/// recorder's `start_run`.
pub trait HasEntityId {
    /// Returns the identifier the run is recorded against.
    fn entity_id(&self) -> String;
}
102
103impl HasEntityId for String {
105 fn entity_id(&self) -> String {
106 self.clone()
107 }
108}
109
110impl HasEntityId for &str {
112 fn entity_id(&self) -> String {
113 self.to_string()
114 }
115}
116
117#[doc(hidden)]
119#[async_trait]
120pub trait BoxedStep<I, O>: Send + Sync {
121 fn name(&self) -> &'static str;
122 async fn execute(&self, input: I) -> Result<O, StepError>;
123}
124
/// Newtype that adapts any [`Step`] into a [`BoxedStep`].
#[doc(hidden)]
pub struct StepWrapper<S>(pub S);
128
129#[async_trait]
130impl<S> BoxedStep<S::Input, S::Output> for StepWrapper<S>
131where
132 S: Step,
133{
134 fn name(&self) -> &'static str {
135 self.0.name()
136 }
137
138 async fn execute(&self, input: S::Input) -> Result<S::Output, StepError> {
139 self.0.execute(input).await
140 }
141}
142
143#[doc(hidden)]
145#[async_trait]
146pub trait StepChain<I, O>: Send + Sync {
147 async fn run(
148 &self,
149 input: I,
150 run_id: RunId,
151 recorder: &dyn Recorder,
152 retry_policy: &RetryPolicy,
153 start_index: u32,
154 ) -> Result<O, PipelineError>;
155
156 fn step_count(&self) -> u32;
158
159 fn collect_step_names(&self, names: &mut Vec<&'static str>);
161}
162
/// The empty step chain: passes its input through unchanged.
///
/// It is the chain of a freshly created pipeline and the tail of a [`ChainedStep`].
#[doc(hidden)]
pub struct Identity;
166
167#[async_trait]
168impl<T: Send + 'static> StepChain<T, T> for Identity {
169 async fn run(
170 &self,
171 input: T,
172 _run_id: RunId,
173 _recorder: &dyn Recorder,
174 _retry_policy: &RetryPolicy,
175 _start_index: u32,
176 ) -> Result<T, PipelineError> {
177 Ok(input)
178 }
179
180 fn step_count(&self) -> u32 {
181 0
182 }
183
184 fn collect_step_names(&self, _names: &mut Vec<&'static str>) {}
185}
186
187#[doc(hidden)]
189pub struct ChainedStep<S, Next, I, M, O>
190where
191 S: BoxedStep<I, M>,
192 Next: StepChain<M, O>,
193{
194 pub step: S,
195 pub next: Next,
196 pub _phantom: std::marker::PhantomData<(I, M, O)>,
197}
198
199#[async_trait]
200impl<S, Next, I, M, O> StepChain<I, O> for ChainedStep<S, Next, I, M, O>
201where
202 I: Send + Sync + Clone + 'static,
203 M: Send + Sync + 'static,
204 O: Send + Sync + 'static,
205 S: BoxedStep<I, M> + Send + Sync,
206 Next: StepChain<M, O> + Send + Sync,
207{
208 async fn run(
209 &self,
210 input: I,
211 run_id: RunId,
212 recorder: &dyn Recorder,
213 retry_policy: &RetryPolicy,
214 start_index: u32,
215 ) -> Result<O, PipelineError> {
216 let step_name = self.step.name();
217 let step_id = recorder.start_step(run_id, step_name, start_index).await?;
218
219 let mut attempt = 0u32;
221 let output = loop {
222 attempt += 1;
223 match self.step.execute(input.clone()).await {
224 Ok(output) => break output,
225 Err(StepError::Permanent(e)) => {
226 recorder
227 .complete_step(
228 step_id,
229 StepStatus::Failed {
230 error: e.to_string(),
231 attempt,
232 },
233 )
234 .await?;
235 return Err(PipelineError::StepFailed {
236 step: step_name,
237 source: e,
238 });
239 }
240 Err(StepError::Retryable(e)) => {
241 if let Some(delay) = retry_policy.delay_for_attempt(attempt) {
242 tokio::time::sleep(delay).await;
243 } else {
244 recorder
245 .complete_step(
246 step_id,
247 StepStatus::Failed {
248 error: e.to_string(),
249 attempt,
250 },
251 )
252 .await?;
253 return Err(PipelineError::RetriesExhausted {
254 step: step_name,
255 attempts: attempt,
256 source: e,
257 });
258 }
259 }
260 }
261 };
262
263 recorder
264 .complete_step(step_id, StepStatus::Completed)
265 .await?;
266
267 self.next
269 .run(output, run_id, recorder, retry_policy, start_index + 1)
270 .await
271 }
272
273 fn step_count(&self) -> u32 {
274 1 + self.next.step_count()
275 }
276
277 fn collect_step_names(&self, names: &mut Vec<&'static str>) {
278 names.push(self.step.name());
279 self.next.collect_step_names(names);
280 }
281}
282
283pub struct Pipeline<I, O, Chain>
285where
286 Chain: StepChain<I, O>,
287{
288 name: &'static str,
289 chain: Chain,
290 retry_policy: RetryPolicy,
291 recorder: Arc<dyn Recorder>,
292 spawn_rules: Vec<SpawnRule<O>>,
293 _phantom: std::marker::PhantomData<(I, O)>,
294}
295
296impl Pipeline<(), (), Identity> {
297 pub fn new(name: &'static str) -> Self {
299 Self {
300 name,
301 chain: Identity,
302 retry_policy: RetryPolicy::default(),
303 recorder: Arc::new(NoopRecorder),
304 spawn_rules: Vec::new(),
305 _phantom: std::marker::PhantomData,
306 }
307 }
308}
309
310impl<O, Chain> Pipeline<(), O, Chain>
311where
312 Chain: StepChain<(), O> + Send + Sync + 'static,
313 O: Send + 'static,
314{
315 #[allow(clippy::type_complexity)]
317 pub fn start_with<S>(
318 self,
319 step: S,
320 ) -> Pipeline<
321 S::Input,
322 S::Output,
323 ChainedStep<StepWrapper<S>, Identity, S::Input, S::Output, S::Output>,
324 >
325 where
326 S: Step + 'static,
327 {
328 Pipeline {
329 name: self.name,
330 chain: ChainedStep {
331 step: StepWrapper(step),
332 next: Identity,
333 _phantom: std::marker::PhantomData,
334 },
335 retry_policy: self.retry_policy,
336 recorder: self.recorder,
337 spawn_rules: Vec::new(),
338 _phantom: std::marker::PhantomData,
339 }
340 }
341}
342
343impl<I, O, Chain> Pipeline<I, O, Chain>
344where
345 I: Send + Sync + Clone + 'static,
346 O: Send + Sync + Clone + 'static,
347 Chain: StepChain<I, O> + Send + Sync + 'static,
348{
349 pub fn then<S>(self, step: S) -> Pipeline<I, S::Output, impl StepChain<I, S::Output>>
351 where
352 S: Step<Input = O> + 'static,
353 {
354 Pipeline {
355 name: self.name,
356 chain: ThenChain {
357 first: self.chain,
358 step: StepWrapper(step),
359 _phantom: std::marker::PhantomData,
360 },
361 retry_policy: self.retry_policy,
362 recorder: self.recorder,
363 spawn_rules: Vec::new(),
364 _phantom: std::marker::PhantomData,
365 }
366 }
367
368 pub fn with_retry(mut self, policy: RetryPolicy) -> Self {
370 self.retry_policy = policy;
371 self
372 }
373
374 pub fn with_recorder<R: Recorder + 'static>(mut self, recorder: R) -> Self {
376 self.recorder = Arc::new(recorder);
377 self
378 }
379
380 pub fn spawn_from<T, F>(mut self, target: &'static str, f: F) -> Self
385 where
386 T: Serialize + 'static,
387 F: Fn(&O) -> Vec<T> + Send + Sync + 'static,
388 {
389 self.spawn_rules.push(SpawnRule::Dynamic {
390 target,
391 generator: Arc::new(move |output| {
392 f(output)
393 .into_iter()
394 .filter_map(|item| serde_json::to_value(item).ok())
395 .collect()
396 }),
397 });
398 self
399 }
400
401 #[deprecated(since = "0.4.0", note = "Use spawn_from instead")]
403 pub fn spawns<T, F>(self, target: &'static str, f: F) -> Self
404 where
405 T: Serialize + 'static,
406 F: Fn(&O) -> Vec<T> + Send + Sync + 'static,
407 {
408 self.spawn_from(target, f)
409 }
410
411 pub fn fork_when<F>(mut self, predicate: F, target: &'static str) -> Self
416 where
417 F: Fn(&O) -> bool + Send + Sync + 'static,
418 {
419 self.spawn_rules.push(SpawnRule::Fork {
420 target,
421 predicate: Arc::new(predicate),
422 description: format!("fork to {}", target),
423 });
424 self
425 }
426
427 pub fn fork_when_desc<F>(
429 mut self,
430 predicate: F,
431 target: &'static str,
432 description: &str,
433 ) -> Self
434 where
435 F: Fn(&O) -> bool + Send + Sync + 'static,
436 {
437 self.spawn_rules.push(SpawnRule::Fork {
438 target,
439 predicate: Arc::new(predicate),
440 description: description.to_string(),
441 });
442 self
443 }
444
445 pub fn fan_out(mut self, targets: &[&'static str]) -> Self {
449 self.spawn_rules.push(SpawnRule::FanOut {
450 targets: targets.to_vec(),
451 });
452 self
453 }
454
455 pub fn build(self) -> BuiltPipeline<I, O, Chain> {
457 BuiltPipeline {
458 name: self.name,
459 chain: self.chain,
460 retry_policy: self.retry_policy,
461 recorder: self.recorder,
462 spawn_rules: self.spawn_rules,
463 _phantom: std::marker::PhantomData,
464 }
465 }
466}
467
468#[doc(hidden)]
470pub struct ThenChain<First, S, I, M, O>
471where
472 First: StepChain<I, M>,
473 S: BoxedStep<M, O>,
474{
475 pub first: First,
476 pub step: S,
477 pub _phantom: std::marker::PhantomData<(I, M, O)>,
478}
479
480#[async_trait]
481impl<First, S, I, M, O> StepChain<I, O> for ThenChain<First, S, I, M, O>
482where
483 I: Send + Sync + Clone + 'static,
484 M: Send + Sync + Clone + 'static,
485 O: Send + Sync + 'static,
486 First: StepChain<I, M> + Send + Sync,
487 S: BoxedStep<M, O> + Send + Sync,
488{
489 async fn run(
490 &self,
491 input: I,
492 run_id: RunId,
493 recorder: &dyn Recorder,
494 retry_policy: &RetryPolicy,
495 start_index: u32,
496 ) -> Result<O, PipelineError> {
497 let mid = self
499 .first
500 .run(input, run_id, recorder, retry_policy, start_index)
501 .await?;
502
503 let next_index = start_index + self.first.step_count();
504
505 let step_name = self.step.name();
506 let step_id = recorder.start_step(run_id, step_name, next_index).await?;
507
508 let mut attempt = 0u32;
510 let output = loop {
511 attempt += 1;
512 match self.step.execute(mid.clone()).await {
513 Ok(output) => break output,
514 Err(StepError::Permanent(e)) => {
515 recorder
516 .complete_step(
517 step_id,
518 StepStatus::Failed {
519 error: e.to_string(),
520 attempt,
521 },
522 )
523 .await?;
524 return Err(PipelineError::StepFailed {
525 step: step_name,
526 source: e,
527 });
528 }
529 Err(StepError::Retryable(e)) => {
530 if let Some(delay) = retry_policy.delay_for_attempt(attempt) {
531 tokio::time::sleep(delay).await;
532 } else {
533 recorder
534 .complete_step(
535 step_id,
536 StepStatus::Failed {
537 error: e.to_string(),
538 attempt,
539 },
540 )
541 .await?;
542 return Err(PipelineError::RetriesExhausted {
543 step: step_name,
544 attempts: attempt,
545 source: e,
546 });
547 }
548 }
549 }
550 };
551
552 recorder
553 .complete_step(step_id, StepStatus::Completed)
554 .await?;
555 Ok(output)
556 }
557
558 fn step_count(&self) -> u32 {
559 self.first.step_count() + 1
560 }
561
562 fn collect_step_names(&self, names: &mut Vec<&'static str>) {
563 self.first.collect_step_names(names);
564 names.push(self.step.name());
565 }
566}
567
568pub struct BuiltPipeline<I, O, Chain>
570where
571 Chain: StepChain<I, O>,
572{
573 name: &'static str,
574 chain: Chain,
575 retry_policy: RetryPolicy,
576 recorder: Arc<dyn Recorder>,
577 pub(crate) spawn_rules: Vec<SpawnRule<O>>,
578 _phantom: std::marker::PhantomData<(I, O)>,
579}
580
581impl<I, O, Chain> BuiltPipeline<I, O, Chain>
582where
583 I: Send + Clone + HasEntityId + 'static,
584 O: Send + Serialize + 'static,
585 Chain: StepChain<I, O> + Send + Sync,
586{
587 pub async fn run(&self, input: I) -> Result<O, PipelineError> {
589 let entity_id = input.entity_id();
590 let run_id = self.recorder.start_run(self.name, &entity_id).await?;
591
592 match self
593 .chain
594 .run(input, run_id, self.recorder.as_ref(), &self.retry_policy, 0)
595 .await
596 {
597 Ok(output) => {
598 self.recorder
599 .complete_run(run_id, RunStatus::Completed)
600 .await?;
601 Ok(output)
602 }
603 Err(e) => {
604 self.recorder
605 .complete_run(
606 run_id,
607 RunStatus::Failed {
608 error: e.to_string(),
609 },
610 )
611 .await?;
612 Err(e)
613 }
614 }
615 }
616
617 pub fn name(&self) -> &'static str {
619 self.name
620 }
621
622 pub fn get_spawned(&self, output: &O) -> Vec<(&'static str, serde_json::Value)> {
624 let mut spawned = Vec::new();
625
626 for rule in &self.spawn_rules {
627 match rule {
628 SpawnRule::Fork {
629 target, predicate, ..
630 } => {
631 if predicate(output) {
632 if let Ok(value) = serde_json::to_value(output) {
633 spawned.push((*target, value));
634 }
635 }
636 }
637 SpawnRule::FanOut { targets } => {
638 if let Ok(value) = serde_json::to_value(output) {
639 for target in targets {
640 spawned.push((*target, value.clone()));
641 }
642 }
643 }
644 SpawnRule::Dynamic { target, generator } => {
645 for input in generator(output) {
646 spawned.push((*target, input));
647 }
648 }
649 }
650 }
651
652 spawned
653 }
654
655 pub fn to_graph(&self) -> PipelineGraph {
657 let mut step_names = Vec::new();
658 self.chain.collect_step_names(&mut step_names);
659
660 let steps: Vec<StepNode> = step_names
661 .into_iter()
662 .enumerate()
663 .map(|(index, name)| StepNode {
664 name: name.to_string(),
665 index,
666 })
667 .collect();
668
669 let mut forks = Vec::new();
670 let mut fan_outs = Vec::new();
671 let mut dynamic_spawns = Vec::new();
672
673 for rule in &self.spawn_rules {
674 match rule {
675 SpawnRule::Fork {
676 target,
677 description,
678 ..
679 } => {
680 forks.push(ForkNode {
681 target_pipeline: target.to_string(),
682 condition: description.clone(),
683 });
684 }
685 SpawnRule::FanOut { targets } => {
686 fan_outs.push(FanOutNode {
687 targets: targets.iter().map(|s| s.to_string()).collect(),
688 });
689 }
690 SpawnRule::Dynamic { target, .. } => {
691 dynamic_spawns.push(DynamicSpawnNode {
692 target_pipeline: target.to_string(),
693 });
694 }
695 }
696 }
697
698 PipelineGraph {
699 name: self.name.to_string(),
700 steps,
701 forks,
702 fan_outs,
703 dynamic_spawns,
704 }
705 }
706}