Skip to main content

nv_perception/
pipeline.rs

1//! Stage pipeline composition — ordered collections with optional validation.
2//!
3//! [`StagePipeline`] provides a builder for composing stages into an
4//! ordered pipeline. The resulting pipeline can be passed directly to
5//! `FeedConfigBuilder::pipeline` or destructured into a `Vec<Box<dyn Stage>>`.
6//!
7//! # Example
8//!
9//! ```rust,no_run
10//! use nv_perception::StagePipeline;
11//! # use nv_perception::{Stage, StageContext, StageOutput};
12//! # use nv_core::{StageId, StageError};
13//! # struct MyDetector;
14//! # impl Stage for MyDetector {
15//! #     fn id(&self) -> StageId { StageId("det") }
16//! #     fn process(&mut self, _: &StageContext<'_>) -> Result<StageOutput, StageError> {
17//! #         Ok(StageOutput::empty())
18//! #     }
19//! # }
20//! # struct MyTracker;
21//! # impl Stage for MyTracker {
22//! #     fn id(&self) -> StageId { StageId("trk") }
23//! #     fn process(&mut self, _: &StageContext<'_>) -> Result<StageOutput, StageError> {
24//! #         Ok(StageOutput::empty())
25//! #     }
26//! # }
27//!
28//! let pipeline = StagePipeline::builder()
29//!     .add(MyDetector)
30//!     .add(MyTracker)
31//!     .build();
32//!
33//! assert_eq!(pipeline.len(), 2);
34//! let stages = pipeline.into_stages();
35//! ```
36
37use crate::stage::{Stage, StageCapabilities, StageCategory};
38use nv_core::id::StageId;
39
/// How strictly pipeline validation ([`StagePipeline::validate`] /
/// [`validate_stages`]) is applied: skipped, reported, or promoted to a
/// hard error.
///
/// `FeedConfigBuilder` consults this so validation can happen as part of the
/// normal build path, without callers having to run `validate()` themselves.
/// The default is [`ValidationMode::Off`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum ValidationMode {
    /// Do not validate at all. This is the default.
    #[default]
    Off,
    /// Validate and hand back any warnings, but still allow the pipeline to
    /// be constructed.
    Warn,
    /// Validate and treat any warning as a hard error.
    Error,
}
57
58/// An ordered, validated collection of perception stages.
59///
60/// Built via [`StagePipeline::builder()`]. The pipeline can be inspected
61/// for stage IDs and categories before being consumed as a `Vec<Box<dyn Stage>>`.
62pub struct StagePipeline {
63    stages: Vec<Box<dyn Stage>>,
64}
65
66impl StagePipeline {
67    /// Create a new builder.
68    #[must_use]
69    pub fn builder() -> StagePipelineBuilder {
70        StagePipelineBuilder { stages: Vec::new() }
71    }
72
73    /// Number of stages in the pipeline.
74    #[must_use]
75    pub fn len(&self) -> usize {
76        self.stages.len()
77    }
78
79    /// Whether the pipeline is empty.
80    #[must_use]
81    pub fn is_empty(&self) -> bool {
82        self.stages.is_empty()
83    }
84
85    /// Get the stage IDs in execution order.
86    #[must_use]
87    pub fn stage_ids(&self) -> Vec<StageId> {
88        self.stages.iter().map(|s| s.id()).collect()
89    }
90
91    /// Get `(StageId, StageCategory)` pairs in execution order.
92    #[must_use]
93    pub fn categories(&self) -> Vec<(StageId, StageCategory)> {
94        self.stages.iter().map(|s| (s.id(), s.category())).collect()
95    }
96
97    /// Consume the pipeline and return the ordered stage list.
98    ///
99    /// Suitable for passing to `FeedConfigBuilder::stages` or
100    /// `FeedConfigBuilder::pipeline`.
101    #[must_use]
102    pub fn into_stages(self) -> Vec<Box<dyn Stage>> {
103        self.stages
104    }
105
106    /// Validate stage ordering based on declared [`StageCapabilities`].
107    ///
108    /// Returns a (possibly empty) list of warnings. Stages that return
109    /// `None` from [`Stage::capabilities()`] are silently skipped.
110    ///
111    /// Warnings are advisory — the pipeline will still execute regardless.
112    /// This allows pipeline builders to catch common ordering mistakes
113    /// (e.g., placing a tracker before a detector) at construction time.
114    #[must_use]
115    pub fn validate(&self) -> Vec<ValidationWarning> {
116        validate_stages(&self.stages)
117    }
118}
119
120/// Advisory warning from [`StagePipeline::validate()`].
121///
122/// These do **not** prevent pipeline execution. They flag likely
123/// composition mistakes that the builder may want to address.
124#[derive(Debug, Clone, PartialEq, Eq)]
125pub enum ValidationWarning {
126    /// A stage declares that it consumes an artifact type that no
127    /// earlier stage produces.
128    UnsatisfiedDependency {
129        /// The stage with the unsatisfied dependency.
130        stage_id: StageId,
131        /// Human-readable name of the missing artifact type.
132        missing: &'static str,
133    },
134    /// Two or more stages share the same [`StageId`].
135    DuplicateStageId {
136        /// The duplicated stage ID.
137        stage_id: StageId,
138    },
139}
140
141/// Validate an ordered stage slice and return advisory warnings.
142///
143/// This is the same logic as [`StagePipeline::validate`] but operates
144/// on a borrowed slice, making it usable by `FeedConfigBuilder`
145/// without requiring a `StagePipeline`.
146#[must_use]
147pub fn validate_stages(stages: &[Box<dyn Stage>]) -> Vec<ValidationWarning> {
148    let mut warnings = Vec::new();
149    let mut detections_available = false;
150    let mut tracks_available = false;
151
152    for stage in stages {
153        validate_one_stage(
154            &**stage,
155            &mut detections_available,
156            &mut tracks_available,
157            &mut warnings,
158        );
159    }
160
161    // Check for duplicate stage IDs.
162    let mut seen = std::collections::HashSet::new();
163    for stage in stages {
164        let id = stage.id();
165        if !seen.insert(id) {
166            warnings.push(ValidationWarning::DuplicateStageId { stage_id: id });
167        }
168    }
169
170    warnings
171}
172
173/// Validate a batch-aware pipeline: pre-batch stages → optional batch
174/// processor → post-batch stages.
175///
176/// Availability state (which artifact types have been produced) flows
177/// through all three phases in order.  The batch processor's
178/// `consumes_*` requirements are validated against pre-batch
179/// availability, and its `produces_*` fields update availability for
180/// the post-batch stages.  Duplicate stage IDs across all phases
181/// (including the batch processor) are also reported.
182///
183/// This is the recommended entry-point for validating a pipeline that
184/// includes a cross-feed batch processor.  For simple linear pipelines
185/// without a batch processor, [`validate_stages`] is sufficient.
186#[must_use]
187pub fn validate_pipeline_phased(
188    pre_batch: &[Box<dyn Stage>],
189    batch_caps: Option<&StageCapabilities>,
190    batch_id: Option<StageId>,
191    post_batch: &[Box<dyn Stage>],
192) -> Vec<ValidationWarning> {
193    let mut warnings = Vec::new();
194    let mut detections_available = false;
195    let mut tracks_available = false;
196
197    // Phase 1: pre-batch stages.
198    for stage in pre_batch {
199        validate_one_stage(
200            &**stage,
201            &mut detections_available,
202            &mut tracks_available,
203            &mut warnings,
204        );
205    }
206
207    // Phase 2: batch processor capabilities (if declared).
208    if let Some(caps) = batch_caps {
209        if let Some(id) = batch_id {
210            if caps.consumes_detections && !detections_available {
211                warnings.push(ValidationWarning::UnsatisfiedDependency {
212                    stage_id: id,
213                    missing: "detections",
214                });
215            }
216            if caps.consumes_tracks && !tracks_available {
217                warnings.push(ValidationWarning::UnsatisfiedDependency {
218                    stage_id: id,
219                    missing: "tracks",
220                });
221            }
222        }
223
224        if caps.produces_detections {
225            detections_available = true;
226        }
227        if caps.produces_tracks {
228            tracks_available = true;
229        }
230    }
231
232    // Phase 3: post-batch stages.
233    for stage in post_batch {
234        validate_one_stage(
235            &**stage,
236            &mut detections_available,
237            &mut tracks_available,
238            &mut warnings,
239        );
240    }
241
242    // Duplicate ID check across all per-feed stages + batch processor.
243    let mut seen = std::collections::HashSet::new();
244    if let Some(id) = batch_id {
245        seen.insert(id);
246    }
247    for stage in pre_batch.iter().chain(post_batch.iter()) {
248        let id = stage.id();
249        if !seen.insert(id) {
250            warnings.push(ValidationWarning::DuplicateStageId { stage_id: id });
251        }
252    }
253
254    warnings
255}
256
257/// Check a single stage's capabilities against current availability and
258/// update availability from its outputs.
259pub(crate) fn validate_one_stage(
260    stage: &dyn Stage,
261    detections_available: &mut bool,
262    tracks_available: &mut bool,
263    warnings: &mut Vec<ValidationWarning>,
264) {
265    let caps = match stage.capabilities() {
266        Some(c) => c,
267        None => return,
268    };
269    let id = stage.id();
270
271    if caps.consumes_detections && !*detections_available {
272        warnings.push(ValidationWarning::UnsatisfiedDependency {
273            stage_id: id,
274            missing: "detections",
275        });
276    }
277    if caps.consumes_tracks && !*tracks_available {
278        warnings.push(ValidationWarning::UnsatisfiedDependency {
279            stage_id: id,
280            missing: "tracks",
281        });
282    }
283    if caps.produces_detections {
284        *detections_available = true;
285    }
286    if caps.produces_tracks {
287        *tracks_available = true;
288    }
289}
290
291/// Builder for [`StagePipeline`].
292pub struct StagePipelineBuilder {
293    stages: Vec<Box<dyn Stage>>,
294}
295
296impl StagePipelineBuilder {
297    /// Append a stage to the pipeline.
298    #[must_use]
299    #[allow(clippy::should_implement_trait)]
300    pub fn add(mut self, stage: impl Stage) -> Self {
301        self.stages.push(Box::new(stage));
302        self
303    }
304
305    /// Append a boxed stage to the pipeline.
306    #[must_use]
307    pub fn add_boxed(mut self, stage: Box<dyn Stage>) -> Self {
308        self.stages.push(stage);
309        self
310    }
311
312    /// Build the pipeline.
313    #[must_use]
314    pub fn build(self) -> StagePipeline {
315        StagePipeline {
316            stages: self.stages,
317        }
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stage::StageCapabilities;
    use crate::{StageContext, StageOutput};
    use nv_core::error::StageError;

    /// Stage stub with a fixed ID and category and no declared capabilities.
    struct TestStage {
        name: &'static str,
        cat: StageCategory,
    }

    impl Stage for TestStage {
        fn id(&self) -> StageId {
            StageId(self.name)
        }
        fn process(&mut self, _ctx: &StageContext<'_>) -> Result<StageOutput, StageError> {
            Ok(StageOutput::empty())
        }
        fn category(&self) -> StageCategory {
            self.cat
        }
    }

    /// Stage stub with a fixed ID and configurable (optional) capabilities.
    struct CapStage {
        name: &'static str,
        caps: Option<StageCapabilities>,
    }

    impl Stage for CapStage {
        fn id(&self) -> StageId {
            StageId(self.name)
        }
        fn process(&mut self, _ctx: &StageContext<'_>) -> Result<StageOutput, StageError> {
            Ok(StageOutput::empty())
        }
        fn capabilities(&self) -> Option<StageCapabilities> {
            self.caps
        }
    }

    /// Shorthand constructor for a [`TestStage`].
    fn cat_stage(name: &'static str, cat: StageCategory) -> TestStage {
        TestStage { name, cat }
    }

    /// Shorthand constructor for a [`CapStage`].
    fn cap_stage(name: &'static str, caps: Option<StageCapabilities>) -> CapStage {
        CapStage { name, caps }
    }

    // --- builder and accessors ---

    #[test]
    fn builder_preserves_order() {
        let pipeline = StagePipeline::builder()
            .add(cat_stage("det", StageCategory::FrameAnalysis))
            .add(cat_stage("trk", StageCategory::Association))
            .add(cat_stage("temporal", StageCategory::TemporalAnalysis))
            .add(cat_stage("sink", StageCategory::Sink))
            .build();

        let ids: Vec<&str> = pipeline.stage_ids().iter().map(|s| s.as_str()).collect();
        assert_eq!(ids, vec!["det", "trk", "temporal", "sink"]);
    }

    #[test]
    fn add_boxed_interleaves_with_add() {
        let pipeline = StagePipeline::builder()
            .add_boxed(Box::new(cap_stage("a", None)))
            .add(cap_stage("b", None))
            .build();

        assert_eq!(pipeline.stage_ids(), vec![StageId("a"), StageId("b")]);
    }

    #[test]
    fn categories_reported_correctly() {
        let pipeline = StagePipeline::builder()
            .add(cat_stage("det", StageCategory::FrameAnalysis))
            .add(cat_stage("trk", StageCategory::Association))
            .build();

        assert_eq!(
            pipeline.categories(),
            vec![
                (StageId("det"), StageCategory::FrameAnalysis),
                (StageId("trk"), StageCategory::Association),
            ]
        );
    }

    #[test]
    fn into_stages_returns_owned_vec() {
        let pipeline = StagePipeline::builder()
            .add(cat_stage("a", StageCategory::Custom))
            .add(cat_stage("b", StageCategory::Custom))
            .build();

        let stages = pipeline.into_stages();
        assert_eq!(stages.len(), 2);
        assert_eq!(stages[0].id(), StageId("a"));
        assert_eq!(stages[1].id(), StageId("b"));
    }

    #[test]
    fn empty_pipeline() {
        let pipeline = StagePipeline::builder().build();
        assert!(pipeline.is_empty());
        assert_eq!(pipeline.len(), 0);
        assert!(pipeline.validate().is_empty());
    }

    #[test]
    fn validation_mode_defaults_to_off() {
        assert_eq!(ValidationMode::default(), ValidationMode::Off);
    }

    // --- StagePipeline::validate / validate_stages ---

    #[test]
    fn validate_happy_path() {
        let pipeline = StagePipeline::builder()
            .add(cap_stage(
                "det",
                Some(StageCapabilities::new().produces_detections()),
            ))
            .add(cap_stage(
                "trk",
                Some(
                    StageCapabilities::new()
                        .consumes_detections()
                        .produces_tracks(),
                ),
            ))
            .build();

        assert!(pipeline.validate().is_empty());
    }

    #[test]
    fn validate_unsatisfied_detections() {
        let pipeline = StagePipeline::builder()
            .add(cap_stage(
                "trk",
                Some(StageCapabilities::new().consumes_detections()),
            ))
            .add(cap_stage(
                "det",
                Some(StageCapabilities::new().produces_detections()),
            ))
            .build();

        assert_eq!(
            pipeline.validate(),
            vec![ValidationWarning::UnsatisfiedDependency {
                stage_id: StageId("trk"),
                missing: "detections",
            }]
        );
    }

    #[test]
    fn validate_unsatisfied_tracks() {
        let pipeline = StagePipeline::builder()
            .add(cap_stage(
                "temporal",
                Some(StageCapabilities::new().consumes_tracks()),
            ))
            .build();

        let warnings = pipeline.validate();
        assert_eq!(warnings.len(), 1);
        assert!(matches!(
            &warnings[0],
            ValidationWarning::UnsatisfiedDependency {
                missing: "tracks",
                ..
            }
        ));
    }

    #[test]
    fn validate_checks_consumption_before_production() {
        // A stage cannot satisfy its own dependency.
        let pipeline = StagePipeline::builder()
            .add(cap_stage(
                "refine",
                Some(
                    StageCapabilities::new()
                        .consumes_detections()
                        .produces_detections(),
                ),
            ))
            .build();

        assert_eq!(
            pipeline.validate(),
            vec![ValidationWarning::UnsatisfiedDependency {
                stage_id: StageId("refine"),
                missing: "detections",
            }]
        );
    }

    #[test]
    fn validate_skips_stages_without_capabilities() {
        let pipeline = StagePipeline::builder()
            .add(cap_stage("noop", None))
            .add(cap_stage(
                "det",
                Some(StageCapabilities::new().produces_detections()),
            ))
            .build();

        assert!(pipeline.validate().is_empty());
    }

    #[test]
    fn validate_duplicate_stage_ids() {
        let pipeline = StagePipeline::builder()
            .add(cap_stage("det", None))
            .add(cap_stage("det", None))
            .build();

        let warnings = pipeline.validate();
        assert_eq!(warnings.len(), 1);
        assert!(matches!(
            &warnings[0],
            ValidationWarning::DuplicateStageId { stage_id } if *stage_id == StageId("det")
        ));
    }

    #[test]
    fn validate_reports_each_repeated_id() {
        let pipeline = StagePipeline::builder()
            .add(cap_stage("det", None))
            .add(cap_stage("det", None))
            .add(cap_stage("det", None))
            .build();

        let dup = ValidationWarning::DuplicateStageId {
            stage_id: StageId("det"),
        };
        assert_eq!(pipeline.validate(), vec![dup.clone(), dup]);
    }

    #[test]
    fn validate_stages_fn_matches_pipeline_validate() {
        let stages: Vec<Box<dyn Stage>> = vec![
            Box::new(cap_stage(
                "trk",
                Some(StageCapabilities::new().consumes_detections()),
            )),
            Box::new(cap_stage(
                "det",
                Some(StageCapabilities::new().produces_detections()),
            )),
        ];

        let warnings = validate_stages(&stages);
        assert_eq!(
            warnings,
            vec![ValidationWarning::UnsatisfiedDependency {
                stage_id: StageId("trk"),
                missing: "detections",
            }]
        );

        // The free function and the method must agree on the same stages.
        let pipeline = stages
            .into_iter()
            .fold(StagePipeline::builder(), StagePipelineBuilder::add_boxed)
            .build();
        assert_eq!(pipeline.validate(), warnings);
    }

    #[test]
    fn validate_stages_fn_happy_path() {
        let stages: Vec<Box<dyn Stage>> = vec![
            Box::new(cap_stage(
                "det",
                Some(StageCapabilities::new().produces_detections()),
            )),
            Box::new(cap_stage(
                "trk",
                Some(
                    StageCapabilities::new()
                        .consumes_detections()
                        .produces_tracks(),
                ),
            )),
        ];

        assert!(validate_stages(&stages).is_empty());
    }

    // --- validate_pipeline_phased ---

    #[test]
    fn phased_happy_path() {
        let pre: Vec<Box<dyn Stage>> = vec![Box::new(cap_stage(
            "det",
            Some(StageCapabilities::new().produces_detections()),
        ))];
        let batch_caps = StageCapabilities::new()
            .consumes_detections()
            .produces_tracks();
        let post: Vec<Box<dyn Stage>> = vec![Box::new(cap_stage(
            "temporal",
            Some(StageCapabilities::new().consumes_tracks()),
        ))];

        let warnings =
            validate_pipeline_phased(&pre, Some(&batch_caps), Some(StageId("batch")), &post);
        assert!(warnings.is_empty());
    }

    #[test]
    fn phased_batch_unsatisfied_dependency() {
        let pre: Vec<Box<dyn Stage>> = vec![];
        let batch_caps = StageCapabilities::new().consumes_detections();

        let warnings =
            validate_pipeline_phased(&pre, Some(&batch_caps), Some(StageId("batch")), &[]);
        assert_eq!(
            warnings,
            vec![ValidationWarning::UnsatisfiedDependency {
                stage_id: StageId("batch"),
                missing: "detections",
            }]
        );
    }

    #[test]
    fn phased_batch_caps_without_id_skip_consumption_checks() {
        // Without an ID there is no stage to blame, so the batch processor's
        // consumption is not checked, but its outputs still reach post-batch.
        let batch_caps = StageCapabilities::new()
            .consumes_detections()
            .produces_tracks();
        let post: Vec<Box<dyn Stage>> = vec![Box::new(cap_stage(
            "temporal",
            Some(StageCapabilities::new().consumes_tracks()),
        ))];

        let warnings = validate_pipeline_phased(&[], Some(&batch_caps), None, &post);
        assert!(warnings.is_empty());
    }

    #[test]
    fn phased_post_batch_sees_batch_outputs() {
        let pre: Vec<Box<dyn Stage>> = vec![];
        let batch_caps = StageCapabilities::new().produces_detections();
        let post: Vec<Box<dyn Stage>> = vec![Box::new(cap_stage(
            "trk",
            Some(StageCapabilities::new().consumes_detections()),
        ))];

        let warnings =
            validate_pipeline_phased(&pre, Some(&batch_caps), Some(StageId("batch")), &post);
        assert!(warnings.is_empty());
    }

    #[test]
    fn phased_pre_batch_does_not_see_batch_outputs() {
        let pre: Vec<Box<dyn Stage>> = vec![Box::new(cap_stage(
            "temporal",
            Some(StageCapabilities::new().consumes_tracks()),
        ))];
        let batch_caps = StageCapabilities::new().produces_tracks();

        let warnings =
            validate_pipeline_phased(&pre, Some(&batch_caps), Some(StageId("batch")), &[]);
        assert_eq!(
            warnings,
            vec![ValidationWarning::UnsatisfiedDependency {
                stage_id: StageId("temporal"),
                missing: "tracks",
            }]
        );
    }

    #[test]
    fn phased_duplicate_across_phases() {
        let pre: Vec<Box<dyn Stage>> = vec![Box::new(cap_stage("dup", None))];
        let post: Vec<Box<dyn Stage>> = vec![Box::new(cap_stage("dup", None))];

        let warnings = validate_pipeline_phased(&pre, None, None, &post);
        assert_eq!(warnings.len(), 1);
        assert!(matches!(
            &warnings[0],
            ValidationWarning::DuplicateStageId { stage_id } if *stage_id == StageId("dup")
        ));
    }

    #[test]
    fn phased_duplicate_with_batch_id() {
        let pre: Vec<Box<dyn Stage>> = vec![Box::new(cap_stage("batch", None))];

        let warnings = validate_pipeline_phased(&pre, None, Some(StageId("batch")), &[]);
        assert_eq!(warnings.len(), 1);
        assert!(matches!(
            &warnings[0],
            ValidationWarning::DuplicateStageId { stage_id } if *stage_id == StageId("batch")
        ));
    }

    #[test]
    fn phased_no_batch_processor() {
        let pre: Vec<Box<dyn Stage>> = vec![Box::new(cap_stage(
            "det",
            Some(StageCapabilities::new().produces_detections()),
        ))];
        let post: Vec<Box<dyn Stage>> = vec![Box::new(cap_stage(
            "trk",
            Some(StageCapabilities::new().consumes_detections()),
        ))];

        let warnings = validate_pipeline_phased(&pre, None, None, &post);
        assert!(warnings.is_empty());
    }
}