use crate::stage::{Stage, StageCapabilities, StageCategory};
use nv_core::id::StageId;
/// Selects what should happen with the result of static pipeline validation.
///
/// Nothing in this module consumes this value. The variant names suggest
/// "skip validation" / "log warnings" / "treat warnings as errors".
/// NOTE(review): confirm those semantics against the caller that reads this mode.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum ValidationMode {
    /// The default mode.
    #[default]
    Off,
    /// Presumably: report validation warnings without failing — confirm at the caller.
    Warn,
    /// Presumably: treat validation warnings as fatal — confirm at the caller.
    Error,
}
/// An ordered collection of boxed [`Stage`]s.
///
/// Built with [`StagePipeline::builder`]. Stage order is the order in which
/// stages were added to the builder.
pub struct StagePipeline {
    stages: Vec<Box<dyn Stage>>,
}

impl StagePipeline {
    /// Start assembling a pipeline from an empty [`StagePipelineBuilder`].
    #[must_use]
    pub fn builder() -> StagePipelineBuilder {
        StagePipelineBuilder { stages: vec![] }
    }

    /// Number of stages held by this pipeline.
    #[must_use]
    pub fn len(&self) -> usize {
        self.stages.len()
    }

    /// `true` when the pipeline holds no stages.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.stages.is_empty()
    }

    /// Ids of every stage, in pipeline order, collected into a fresh `Vec`.
    #[must_use]
    pub fn stage_ids(&self) -> Vec<StageId> {
        self.stages.iter().map(|stage| stage.id()).collect()
    }

    /// `(id, category)` for every stage, in pipeline order.
    #[must_use]
    pub fn categories(&self) -> Vec<(StageId, StageCategory)> {
        self.stages
            .iter()
            .map(|stage| (stage.id(), stage.category()))
            .collect()
    }

    /// Consume the pipeline and hand back its boxed stages, order preserved.
    #[must_use]
    pub fn into_stages(self) -> Vec<Box<dyn Stage>> {
        let Self { stages } = self;
        stages
    }

    /// Run [`validate_stages`] over this pipeline's stages.
    ///
    /// An empty result means no problems were detected.
    #[must_use]
    pub fn validate(&self) -> Vec<ValidationWarning> {
        validate_stages(self.stages.as_slice())
    }
}
/// A non-fatal problem detected while statically checking stage ordering and ids.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ValidationWarning {
    /// A stage declares that it consumes a product that no earlier stage
    /// declares it produces.
    UnsatisfiedDependency {
        /// The consuming stage.
        stage_id: StageId,
        /// Name of the unavailable product. The validators in this module
        /// emit `"detections"` or `"tracks"`.
        missing: &'static str,
    },
    /// An id was seen again after its first occurrence.
    DuplicateStageId {
        /// The repeated id.
        stage_id: StageId,
    },
}
/// Statically check a flat, ordered list of stages.
///
/// Two passes are made over `stages`:
/// 1. a dependency pass in declaration order (each stage is handed to
///    `validate_one_stage`), and
/// 2. a duplicate-id pass that reports a [`ValidationWarning::DuplicateStageId`]
///    for every occurrence of an id after its first.
///
/// Dependency warnings therefore always precede duplicate-id warnings.
#[must_use]
pub fn validate_stages(stages: &[Box<dyn Stage>]) -> Vec<ValidationWarning> {
    let mut found = Vec::new();
    let (mut have_detections, mut have_tracks) = (false, false);

    stages.iter().for_each(|boxed| {
        validate_one_stage(&**boxed, &mut have_detections, &mut have_tracks, &mut found)
    });

    // `insert` returns false for an id already present, which flags the repeat.
    let mut ids = std::collections::HashSet::with_capacity(stages.len());
    found.extend(
        stages
            .iter()
            .map(|boxed| boxed.id())
            .filter(|&id| !ids.insert(id))
            .map(|stage_id| ValidationWarning::DuplicateStageId { stage_id }),
    );
    found
}
/// Validate a pipeline that is split around an optional batch processor.
///
/// Dependencies are checked in execution order: every `pre_batch` stage, then
/// the batch processor, then every `post_batch` stage. Products declared by an
/// earlier phase (including the batch processor) satisfy consumers in a later
/// phase.
///
/// The batch processor participates only when `batch_caps` is `Some`. Its
/// declared inputs are checked only when `batch_id` is also `Some`, because a
/// [`ValidationWarning::UnsatisfiedDependency`] needs an id to attach to. Its
/// declared outputs are made available to `post_batch` either way.
///
/// Duplicate ids are checked across `batch_id` (if any), `pre_batch`, and
/// `post_batch`. `batch_id` is registered first, so a stage reusing it is the
/// occurrence that gets reported.
///
/// Returns all warnings found; dependency warnings precede duplicate-id warnings.
#[must_use]
pub fn validate_pipeline_phased(
    pre_batch: &[Box<dyn Stage>],
    batch_caps: Option<&StageCapabilities>,
    batch_id: Option<StageId>,
    post_batch: &[Box<dyn Stage>],
) -> Vec<ValidationWarning> {
    let mut warnings = Vec::new();
    let mut detections_available = false;
    let mut tracks_available = false;

    for stage in pre_batch {
        validate_one_stage(
            &**stage,
            &mut detections_available,
            &mut tracks_available,
            &mut warnings,
        );
    }

    if let Some(caps) = batch_caps {
        // Same rule as ordinary stages, so the batch path cannot drift from it.
        apply_capabilities(
            batch_id,
            caps,
            &mut detections_available,
            &mut tracks_available,
            &mut warnings,
        );
    }

    for stage in post_batch {
        validate_one_stage(
            &**stage,
            &mut detections_available,
            &mut tracks_available,
            &mut warnings,
        );
    }

    let mut seen = std::collections::HashSet::new();
    if let Some(id) = batch_id {
        seen.insert(id);
    }
    for stage in pre_batch.iter().chain(post_batch.iter()) {
        let id = stage.id();
        if !seen.insert(id) {
            warnings.push(ValidationWarning::DuplicateStageId { stage_id: id });
        }
    }
    warnings
}

/// Check one stage's declared inputs, then record its declared outputs.
///
/// Stages whose `capabilities()` returns `None` are skipped entirely: they
/// raise no warnings and make nothing available.
pub(crate) fn validate_one_stage(
    stage: &dyn Stage,
    detections_available: &mut bool,
    tracks_available: &mut bool,
    warnings: &mut Vec<ValidationWarning>,
) {
    if let Some(caps) = stage.capabilities() {
        apply_capabilities(
            Some(stage.id()),
            &caps,
            detections_available,
            tracks_available,
            warnings,
        );
    }
}

/// Dependency rule shared by ordinary stages and the batch processor.
///
/// When `stage_id` is `Some`, an [`ValidationWarning::UnsatisfiedDependency`] is
/// pushed for each consumed product that is not yet available. Detections are
/// checked before tracks. Every product `caps` declares is then marked
/// available.
fn apply_capabilities(
    stage_id: Option<StageId>,
    caps: &StageCapabilities,
    detections_available: &mut bool,
    tracks_available: &mut bool,
    warnings: &mut Vec<ValidationWarning>,
) {
    if let Some(id) = stage_id {
        if caps.consumes_detections && !*detections_available {
            warnings.push(ValidationWarning::UnsatisfiedDependency {
                stage_id: id,
                missing: "detections",
            });
        }
        if caps.consumes_tracks && !*tracks_available {
            warnings.push(ValidationWarning::UnsatisfiedDependency {
                stage_id: id,
                missing: "tracks",
            });
        }
    }
    // Outputs are recorded only after the inputs were checked, so a stage can
    // never satisfy its own declared inputs.
    if caps.produces_detections {
        *detections_available = true;
    }
    if caps.produces_tracks {
        *tracks_available = true;
    }
}
/// Incrementally assembles a [`StagePipeline`], keeping stages in insertion order.
pub struct StagePipelineBuilder {
    stages: Vec<Box<dyn Stage>>,
}

impl StagePipelineBuilder {
    /// Append a concrete stage, boxing it.
    // `add` is a by-value builder step, not arithmetic, so implementing
    // `std::ops::Add` (which the lint suggests) would be misleading.
    #[must_use]
    #[allow(clippy::should_implement_trait)]
    pub fn add(self, stage: impl Stage) -> Self {
        self.add_boxed(Box::new(stage))
    }

    /// Append an already-boxed stage.
    #[must_use]
    pub fn add_boxed(mut self, stage: Box<dyn Stage>) -> Self {
        self.stages.push(stage);
        self
    }

    /// Finish building and take ownership of the collected stages.
    #[must_use]
    pub fn build(self) -> StagePipeline {
        let Self { stages } = self;
        StagePipeline { stages }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stage::StageCapabilities;
    use crate::{StageContext, StageOutput};
    use nv_core::error::StageError;

    /// Stage with a fixed id and category; used by the ordering/category tests.
    struct TestStage {
        name: &'static str,
        cat: StageCategory,
    }

    impl Stage for TestStage {
        fn id(&self) -> StageId {
            StageId(self.name)
        }
        fn process(&mut self, _ctx: &StageContext<'_>) -> Result<StageOutput, StageError> {
            Ok(StageOutput::empty())
        }
        fn category(&self) -> StageCategory {
            self.cat
        }
    }

    #[test]
    fn builder_preserves_order() {
        let pipeline = StagePipeline::builder()
            .add(TestStage {
                name: "det",
                cat: StageCategory::FrameAnalysis,
            })
            .add(TestStage {
                name: "trk",
                cat: StageCategory::Association,
            })
            .add(TestStage {
                name: "temporal",
                cat: StageCategory::TemporalAnalysis,
            })
            .add(TestStage {
                name: "sink",
                cat: StageCategory::Sink,
            })
            .build();
        let ids: Vec<&str> = pipeline.stage_ids().iter().map(|s| s.as_str()).collect();
        assert_eq!(ids, vec!["det", "trk", "temporal", "sink"]);
    }

    #[test]
    fn categories_reported_correctly() {
        let pipeline = StagePipeline::builder()
            .add(TestStage {
                name: "det",
                cat: StageCategory::FrameAnalysis,
            })
            .add(TestStage {
                name: "trk",
                cat: StageCategory::Association,
            })
            .build();
        let cats = pipeline.categories();
        assert_eq!(cats[0].1, StageCategory::FrameAnalysis);
        assert_eq!(cats[1].1, StageCategory::Association);
    }

    #[test]
    fn into_stages_returns_owned_vec() {
        let pipeline = StagePipeline::builder()
            .add(TestStage {
                name: "a",
                cat: StageCategory::Custom,
            })
            .add(TestStage {
                name: "b",
                cat: StageCategory::Custom,
            })
            .build();
        let stages = pipeline.into_stages();
        assert_eq!(stages.len(), 2);
        assert_eq!(stages[0].id(), StageId("a"));
        assert_eq!(stages[1].id(), StageId("b"));
    }

    #[test]
    fn empty_pipeline() {
        let pipeline = StagePipeline::builder().build();
        assert!(pipeline.is_empty());
        assert_eq!(pipeline.len(), 0);
    }

    /// Stage whose `capabilities()` result is configurable; used by the validation tests.
    struct CapStage {
        name: &'static str,
        caps: Option<StageCapabilities>,
    }

    impl Stage for CapStage {
        fn id(&self) -> StageId {
            StageId(self.name)
        }
        fn process(&mut self, _ctx: &StageContext<'_>) -> Result<StageOutput, StageError> {
            Ok(StageOutput::empty())
        }
        fn capabilities(&self) -> Option<StageCapabilities> {
            self.caps
        }
    }

    #[test]
    fn validate_happy_path() {
        let pipeline = StagePipeline::builder()
            .add(CapStage {
                name: "det",
                caps: Some(StageCapabilities::new().produces_detections()),
            })
            .add(CapStage {
                name: "trk",
                caps: Some(
                    StageCapabilities::new()
                        .consumes_detections()
                        .produces_tracks(),
                ),
            })
            .build();
        let warnings = pipeline.validate();
        assert!(warnings.is_empty());
    }

    #[test]
    fn validate_unsatisfied_detections() {
        let pipeline = StagePipeline::builder()
            .add(CapStage {
                name: "trk",
                caps: Some(StageCapabilities::new().consumes_detections()),
            })
            .add(CapStage {
                name: "det",
                caps: Some(StageCapabilities::new().produces_detections()),
            })
            .build();
        let warnings = pipeline.validate();
        assert_eq!(warnings.len(), 1);
        assert_eq!(
            warnings[0],
            ValidationWarning::UnsatisfiedDependency {
                stage_id: StageId("trk"),
                missing: "detections",
            }
        );
    }

    #[test]
    fn validate_unsatisfied_tracks() {
        let pipeline = StagePipeline::builder()
            .add(CapStage {
                name: "temporal",
                caps: Some(StageCapabilities::new().consumes_tracks()),
            })
            .build();
        let warnings = pipeline.validate();
        assert_eq!(warnings.len(), 1);
        assert!(matches!(
            &warnings[0],
            ValidationWarning::UnsatisfiedDependency {
                missing: "tracks",
                ..
            }
        ));
    }

    #[test]
    fn validate_skips_stages_without_capabilities() {
        let pipeline = StagePipeline::builder()
            .add(CapStage {
                name: "noop",
                caps: None,
            })
            .add(CapStage {
                name: "det",
                caps: Some(StageCapabilities::new().produces_detections()),
            })
            .build();
        let warnings = pipeline.validate();
        assert!(warnings.is_empty());
    }

    #[test]
    fn validate_duplicate_stage_ids() {
        let pipeline = StagePipeline::builder()
            .add(CapStage {
                name: "det",
                caps: None,
            })
            .add(CapStage {
                name: "det",
                caps: None,
            })
            .build();
        let warnings = pipeline.validate();
        assert_eq!(warnings.len(), 1);
        assert!(matches!(
            &warnings[0],
            ValidationWarning::DuplicateStageId { stage_id } if *stage_id == StageId("det")
        ));
    }

    #[test]
    fn validate_stages_fn_matches_pipeline_validate() {
        let stages: Vec<Box<dyn Stage>> = vec![
            Box::new(CapStage {
                name: "trk",
                caps: Some(StageCapabilities::new().consumes_detections()),
            }),
            Box::new(CapStage {
                name: "det",
                caps: Some(StageCapabilities::new().produces_detections()),
            }),
        ];
        let warnings = validate_stages(&stages);
        assert_eq!(warnings.len(), 1);
        assert_eq!(
            warnings[0],
            ValidationWarning::UnsatisfiedDependency {
                stage_id: StageId("trk"),
                missing: "detections",
            }
        );
    }

    #[test]
    fn validate_stages_fn_happy_path() {
        let stages: Vec<Box<dyn Stage>> = vec![
            Box::new(CapStage {
                name: "det",
                caps: Some(StageCapabilities::new().produces_detections()),
            }),
            Box::new(CapStage {
                name: "trk",
                caps: Some(
                    StageCapabilities::new()
                        .consumes_detections()
                        .produces_tracks(),
                ),
            }),
        ];
        let warnings = validate_stages(&stages);
        assert!(warnings.is_empty());
    }

    #[test]
    fn phased_happy_path() {
        let pre: Vec<Box<dyn Stage>> = vec![Box::new(CapStage {
            name: "det",
            caps: Some(StageCapabilities::new().produces_detections()),
        })];
        let batch_caps = StageCapabilities::new()
            .consumes_detections()
            .produces_tracks();
        let post: Vec<Box<dyn Stage>> = vec![Box::new(CapStage {
            name: "temporal",
            caps: Some(StageCapabilities::new().consumes_tracks()),
        })];
        let warnings =
            validate_pipeline_phased(&pre, Some(&batch_caps), Some(StageId("batch")), &post);
        assert!(warnings.is_empty());
    }

    #[test]
    fn phased_batch_unsatisfied_dependency() {
        let pre: Vec<Box<dyn Stage>> = vec![];
        let batch_caps = StageCapabilities::new().consumes_detections();
        let warnings =
            validate_pipeline_phased(&pre, Some(&batch_caps), Some(StageId("batch")), &[]);
        assert_eq!(warnings.len(), 1);
        assert_eq!(
            warnings[0],
            ValidationWarning::UnsatisfiedDependency {
                stage_id: StageId("batch"),
                missing: "detections",
            }
        );
    }

    #[test]
    fn phased_post_batch_sees_batch_outputs() {
        let pre: Vec<Box<dyn Stage>> = vec![];
        let batch_caps = StageCapabilities::new().produces_detections();
        let post: Vec<Box<dyn Stage>> = vec![Box::new(CapStage {
            name: "trk",
            caps: Some(StageCapabilities::new().consumes_detections()),
        })];
        let warnings =
            validate_pipeline_phased(&pre, Some(&batch_caps), Some(StageId("batch")), &post);
        assert!(warnings.is_empty());
    }

    #[test]
    fn phased_duplicate_across_phases() {
        let pre: Vec<Box<dyn Stage>> = vec![Box::new(CapStage {
            name: "dup",
            caps: None,
        })];
        let post: Vec<Box<dyn Stage>> = vec![Box::new(CapStage {
            name: "dup",
            caps: None,
        })];
        let warnings = validate_pipeline_phased(&pre, None, None, &post);
        assert_eq!(warnings.len(), 1);
        assert!(matches!(
            &warnings[0],
            ValidationWarning::DuplicateStageId { stage_id } if *stage_id == StageId("dup")
        ));
    }

    #[test]
    fn phased_duplicate_with_batch_id() {
        let pre: Vec<Box<dyn Stage>> = vec![Box::new(CapStage {
            name: "batch",
            caps: None,
        })];
        let warnings = validate_pipeline_phased(&pre, None, Some(StageId("batch")), &[]);
        assert_eq!(warnings.len(), 1);
        assert!(matches!(
            &warnings[0],
            ValidationWarning::DuplicateStageId { stage_id } if *stage_id == StageId("batch")
        ));
    }

    #[test]
    fn phased_no_batch_processor() {
        let pre: Vec<Box<dyn Stage>> = vec![Box::new(CapStage {
            name: "det",
            caps: Some(StageCapabilities::new().produces_detections()),
        })];
        let post: Vec<Box<dyn Stage>> = vec![Box::new(CapStage {
            name: "trk",
            caps: Some(StageCapabilities::new().consumes_detections()),
        })];
        let warnings = validate_pipeline_phased(&pre, None, None, &post);
        assert!(warnings.is_empty());
    }

    #[test]
    fn validate_reports_each_repeat_of_duplicate_id() {
        // Every occurrence after the first is reported, not just one per id.
        let pipeline = StagePipeline::builder()
            .add(CapStage {
                name: "det",
                caps: None,
            })
            .add(CapStage {
                name: "det",
                caps: None,
            })
            .add(CapStage {
                name: "det",
                caps: None,
            })
            .build();
        assert_eq!(
            pipeline.validate(),
            vec![
                ValidationWarning::DuplicateStageId {
                    stage_id: StageId("det"),
                },
                ValidationWarning::DuplicateStageId {
                    stage_id: StageId("det"),
                },
            ]
        );
    }

    #[test]
    fn validate_stage_does_not_satisfy_its_own_input() {
        // Inputs are checked before the stage's own outputs are recorded.
        let pipeline = StagePipeline::builder()
            .add(CapStage {
                name: "loop",
                caps: Some(
                    StageCapabilities::new()
                        .consumes_detections()
                        .produces_detections(),
                ),
            })
            .build();
        assert_eq!(
            pipeline.validate(),
            vec![ValidationWarning::UnsatisfiedDependency {
                stage_id: StageId("loop"),
                missing: "detections",
            }]
        );
    }

    #[test]
    fn phased_batch_missing_both_inputs_reports_in_order() {
        let batch_caps = StageCapabilities::new()
            .consumes_detections()
            .consumes_tracks();
        let warnings =
            validate_pipeline_phased(&[], Some(&batch_caps), Some(StageId("batch")), &[]);
        assert_eq!(
            warnings,
            vec![
                ValidationWarning::UnsatisfiedDependency {
                    stage_id: StageId("batch"),
                    missing: "detections",
                },
                ValidationWarning::UnsatisfiedDependency {
                    stage_id: StageId("batch"),
                    missing: "tracks",
                },
            ]
        );
    }

    #[test]
    fn phased_batch_without_id_still_provides_outputs() {
        // Without a batch id there is nothing to attach a warning to, so the
        // batch's inputs go unchecked, but its outputs still count downstream.
        let batch_caps = StageCapabilities::new()
            .consumes_detections()
            .produces_tracks();
        let post: Vec<Box<dyn Stage>> = vec![Box::new(CapStage {
            name: "temporal",
            caps: Some(StageCapabilities::new().consumes_tracks()),
        })];
        let warnings = validate_pipeline_phased(&[], Some(&batch_caps), None, &post);
        assert!(warnings.is_empty());
    }

    #[test]
    fn phased_duplicate_of_batch_id_in_post_batch() {
        let post: Vec<Box<dyn Stage>> = vec![Box::new(CapStage {
            name: "batch",
            caps: None,
        })];
        let warnings = validate_pipeline_phased(&[], None, Some(StageId("batch")), &post);
        assert_eq!(
            warnings,
            vec![ValidationWarning::DuplicateStageId {
                stage_id: StageId("batch"),
            }]
        );
    }
}