1use crate::stage::{Stage, StageCapabilities, StageCategory};
38use nv_core::id::StageId;
39
/// Policy selector for pipeline validation.
///
/// `Off` is the default variant (see the `#[default]` attribute).
///
/// NOTE(review): nothing in this file reads this enum; how each variant is
/// acted upon is decided by callers elsewhere. The per-variant notes below are
/// inferred from the variant names only — confirm against the consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ValidationMode {
    /// Default mode. Presumably: pipeline validation is not performed.
    #[default]
    Off,
    /// Presumably: validation warnings are reported but are not fatal.
    Warn,
    /// Presumably: any validation warning is treated as a hard error.
    Error,
}
57
/// An ordered sequence of boxed [`Stage`]s.
///
/// Construct one with [`StagePipeline::builder`]; stages keep the order in
/// which they were added to the builder.
pub struct StagePipeline {
    // Stages in insertion (execution) order.
    stages: Vec<Box<dyn Stage>>,
}
65
66impl StagePipeline {
67 #[must_use]
69 pub fn builder() -> StagePipelineBuilder {
70 StagePipelineBuilder { stages: Vec::new() }
71 }
72
73 #[must_use]
75 pub fn len(&self) -> usize {
76 self.stages.len()
77 }
78
79 #[must_use]
81 pub fn is_empty(&self) -> bool {
82 self.stages.is_empty()
83 }
84
85 #[must_use]
87 pub fn stage_ids(&self) -> Vec<StageId> {
88 self.stages.iter().map(|s| s.id()).collect()
89 }
90
91 #[must_use]
93 pub fn categories(&self) -> Vec<(StageId, StageCategory)> {
94 self.stages.iter().map(|s| (s.id(), s.category())).collect()
95 }
96
97 #[must_use]
102 pub fn into_stages(self) -> Vec<Box<dyn Stage>> {
103 self.stages
104 }
105
106 #[must_use]
115 pub fn validate(&self) -> Vec<ValidationWarning> {
116 validate_stages(&self.stages)
117 }
118}
119
/// A problem found while validating a pipeline's stage list.
///
/// Produced by [`validate_stages`], [`validate_pipeline_phased`], and
/// [`StagePipeline::validate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationWarning {
    /// A stage declares that it consumes an artifact that no earlier stage
    /// (in execution order) declares it produces.
    UnsatisfiedDependency {
        /// The stage whose input requirement is not met.
        stage_id: StageId,
        /// Name of the missing artifact; the validators in this module emit
        /// `"detections"` or `"tracks"`.
        missing: &'static str,
    },
    /// A stage ID occurs more than once. One warning is emitted per repeated
    /// occurrence after the first.
    DuplicateStageId {
        /// The repeated ID.
        stage_id: StageId,
    },
}
140
141#[must_use]
147pub fn validate_stages(stages: &[Box<dyn Stage>]) -> Vec<ValidationWarning> {
148 let mut warnings = Vec::new();
149 let mut detections_available = false;
150 let mut tracks_available = false;
151
152 for stage in stages {
153 validate_one_stage(
154 &**stage,
155 &mut detections_available,
156 &mut tracks_available,
157 &mut warnings,
158 );
159 }
160
161 let mut seen = std::collections::HashSet::new();
163 for stage in stages {
164 let id = stage.id();
165 if !seen.insert(id) {
166 warnings.push(ValidationWarning::DuplicateStageId { stage_id: id });
167 }
168 }
169
170 warnings
171}
172
173#[must_use]
187pub fn validate_pipeline_phased(
188 pre_batch: &[Box<dyn Stage>],
189 batch_caps: Option<&StageCapabilities>,
190 batch_id: Option<StageId>,
191 post_batch: &[Box<dyn Stage>],
192) -> Vec<ValidationWarning> {
193 let mut warnings = Vec::new();
194 let mut detections_available = false;
195 let mut tracks_available = false;
196
197 for stage in pre_batch {
199 validate_one_stage(
200 &**stage,
201 &mut detections_available,
202 &mut tracks_available,
203 &mut warnings,
204 );
205 }
206
207 if let Some(caps) = batch_caps {
209 if let Some(id) = batch_id {
210 if caps.consumes_detections && !detections_available {
211 warnings.push(ValidationWarning::UnsatisfiedDependency {
212 stage_id: id,
213 missing: "detections",
214 });
215 }
216 if caps.consumes_tracks && !tracks_available {
217 warnings.push(ValidationWarning::UnsatisfiedDependency {
218 stage_id: id,
219 missing: "tracks",
220 });
221 }
222 }
223
224 if caps.produces_detections {
225 detections_available = true;
226 }
227 if caps.produces_tracks {
228 tracks_available = true;
229 }
230 }
231
232 for stage in post_batch {
234 validate_one_stage(
235 &**stage,
236 &mut detections_available,
237 &mut tracks_available,
238 &mut warnings,
239 );
240 }
241
242 let mut seen = std::collections::HashSet::new();
244 if let Some(id) = batch_id {
245 seen.insert(id);
246 }
247 for stage in pre_batch.iter().chain(post_batch.iter()) {
248 let id = stage.id();
249 if !seen.insert(id) {
250 warnings.push(ValidationWarning::DuplicateStageId { stage_id: id });
251 }
252 }
253
254 warnings
255}
256
257pub(crate) fn validate_one_stage(
260 stage: &dyn Stage,
261 detections_available: &mut bool,
262 tracks_available: &mut bool,
263 warnings: &mut Vec<ValidationWarning>,
264) {
265 let caps = match stage.capabilities() {
266 Some(c) => c,
267 None => return,
268 };
269 let id = stage.id();
270
271 if caps.consumes_detections && !*detections_available {
272 warnings.push(ValidationWarning::UnsatisfiedDependency {
273 stage_id: id,
274 missing: "detections",
275 });
276 }
277 if caps.consumes_tracks && !*tracks_available {
278 warnings.push(ValidationWarning::UnsatisfiedDependency {
279 stage_id: id,
280 missing: "tracks",
281 });
282 }
283 if caps.produces_detections {
284 *detections_available = true;
285 }
286 if caps.produces_tracks {
287 *tracks_available = true;
288 }
289}
290
/// Incrementally assembles a [`StagePipeline`].
///
/// Obtain one from [`StagePipeline::builder`], chain `add`/`add_boxed`
/// calls, then finish with `build`.
pub struct StagePipelineBuilder {
    // Stages collected so far, in the order they were added.
    stages: Vec<Box<dyn Stage>>,
}
295
296impl StagePipelineBuilder {
297 #[must_use]
299 #[allow(clippy::should_implement_trait)]
300 pub fn add(mut self, stage: impl Stage) -> Self {
301 self.stages.push(Box::new(stage));
302 self
303 }
304
305 #[must_use]
307 pub fn add_boxed(mut self, stage: Box<dyn Stage>) -> Self {
308 self.stages.push(stage);
309 self
310 }
311
312 #[must_use]
314 pub fn build(self) -> StagePipeline {
315 StagePipeline {
316 stages: self.stages,
317 }
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stage::StageCapabilities;
    use crate::{StageContext, StageOutput};
    use nv_core::error::StageError;

    /// Stage with a fixed ID and category that declares no capabilities.
    struct TestStage {
        name: &'static str,
        cat: StageCategory,
    }

    impl Stage for TestStage {
        fn id(&self) -> StageId {
            StageId(self.name)
        }
        fn process(&mut self, _ctx: &StageContext<'_>) -> Result<StageOutput, StageError> {
            Ok(StageOutput::empty())
        }
        fn category(&self) -> StageCategory {
            self.cat
        }
    }

    /// Stage with a fixed ID whose declared capabilities are chosen per test.
    struct CapStage {
        name: &'static str,
        caps: Option<StageCapabilities>,
    }

    impl Stage for CapStage {
        fn id(&self) -> StageId {
            StageId(self.name)
        }
        fn process(&mut self, _ctx: &StageContext<'_>) -> Result<StageOutput, StageError> {
            Ok(StageOutput::empty())
        }
        fn capabilities(&self) -> Option<StageCapabilities> {
            self.caps
        }
    }

    fn plain(name: &'static str, cat: StageCategory) -> TestStage {
        TestStage { name, cat }
    }

    fn with_caps(name: &'static str, caps: Option<StageCapabilities>) -> CapStage {
        CapStage { name, caps }
    }

    fn boxed(name: &'static str, caps: Option<StageCapabilities>) -> Box<dyn Stage> {
        Box::new(with_caps(name, caps))
    }

    /// Produces detections only.
    fn detector() -> Option<StageCapabilities> {
        Some(StageCapabilities::new().produces_detections())
    }

    /// Consumes detections only.
    fn det_consumer() -> Option<StageCapabilities> {
        Some(StageCapabilities::new().consumes_detections())
    }

    /// Consumes detections and produces tracks.
    fn tracker() -> Option<StageCapabilities> {
        Some(
            StageCapabilities::new()
                .consumes_detections()
                .produces_tracks(),
        )
    }

    #[test]
    fn builder_preserves_order() {
        let pipeline = StagePipeline::builder()
            .add(plain("det", StageCategory::FrameAnalysis))
            .add(plain("trk", StageCategory::Association))
            .add(plain("temporal", StageCategory::TemporalAnalysis))
            .add(plain("sink", StageCategory::Sink))
            .build();

        let ids: Vec<&str> = pipeline.stage_ids().iter().map(|s| s.as_str()).collect();
        assert_eq!(ids, vec!["det", "trk", "temporal", "sink"]);
    }

    #[test]
    fn categories_reported_correctly() {
        let pipeline = StagePipeline::builder()
            .add(plain("det", StageCategory::FrameAnalysis))
            .add(plain("trk", StageCategory::Association))
            .build();

        let cats = pipeline.categories();
        assert_eq!(cats[0].1, StageCategory::FrameAnalysis);
        assert_eq!(cats[1].1, StageCategory::Association);
    }

    #[test]
    fn into_stages_returns_owned_vec() {
        let pipeline = StagePipeline::builder()
            .add(plain("a", StageCategory::Custom))
            .add(plain("b", StageCategory::Custom))
            .build();

        let stages = pipeline.into_stages();
        assert_eq!(stages.len(), 2);
        assert_eq!(stages[0].id(), StageId("a"));
        assert_eq!(stages[1].id(), StageId("b"));
    }

    #[test]
    fn empty_pipeline() {
        let pipeline = StagePipeline::builder().build();
        assert!(pipeline.is_empty());
        assert_eq!(pipeline.len(), 0);
    }

    #[test]
    fn validate_happy_path() {
        let pipeline = StagePipeline::builder()
            .add(with_caps("det", detector()))
            .add(with_caps("trk", tracker()))
            .build();

        let warnings = pipeline.validate();
        assert!(warnings.is_empty());
    }

    #[test]
    fn validate_unsatisfied_detections() {
        let pipeline = StagePipeline::builder()
            .add(with_caps("trk", det_consumer()))
            .add(with_caps("det", detector()))
            .build();

        let warnings = pipeline.validate();
        assert_eq!(warnings.len(), 1);
        assert_eq!(
            warnings[0],
            ValidationWarning::UnsatisfiedDependency {
                stage_id: StageId("trk"),
                missing: "detections",
            }
        );
    }

    #[test]
    fn validate_unsatisfied_tracks() {
        let pipeline = StagePipeline::builder()
            .add(with_caps(
                "temporal",
                Some(StageCapabilities::new().consumes_tracks()),
            ))
            .build();

        let warnings = pipeline.validate();
        assert_eq!(warnings.len(), 1);
        assert!(matches!(
            &warnings[0],
            ValidationWarning::UnsatisfiedDependency {
                missing: "tracks",
                ..
            }
        ));
    }

    #[test]
    fn validate_skips_stages_without_capabilities() {
        let pipeline = StagePipeline::builder()
            .add(with_caps("noop", None))
            .add(with_caps("det", detector()))
            .build();

        let warnings = pipeline.validate();
        assert!(warnings.is_empty());
    }

    #[test]
    fn validate_duplicate_stage_ids() {
        let pipeline = StagePipeline::builder()
            .add(with_caps("det", None))
            .add(with_caps("det", None))
            .build();

        let warnings = pipeline.validate();
        assert_eq!(warnings.len(), 1);
        assert!(matches!(
            &warnings[0],
            ValidationWarning::DuplicateStageId { stage_id } if *stage_id == StageId("det")
        ));
    }

    #[test]
    fn validate_stages_fn_matches_pipeline_validate() {
        let stages = vec![boxed("trk", det_consumer()), boxed("det", detector())];

        let warnings = validate_stages(&stages);
        assert_eq!(warnings.len(), 1);
        assert_eq!(
            warnings[0],
            ValidationWarning::UnsatisfiedDependency {
                stage_id: StageId("trk"),
                missing: "detections",
            }
        );
    }

    #[test]
    fn validate_stages_fn_happy_path() {
        let stages = vec![boxed("det", detector()), boxed("trk", tracker())];

        let warnings = validate_stages(&stages);
        assert!(warnings.is_empty());
    }

    #[test]
    fn phased_happy_path() {
        let pre = vec![boxed("det", detector())];
        let batch_caps = StageCapabilities::new()
            .consumes_detections()
            .produces_tracks();
        let post = vec![boxed(
            "temporal",
            Some(StageCapabilities::new().consumes_tracks()),
        )];

        let warnings =
            validate_pipeline_phased(&pre, Some(&batch_caps), Some(StageId("batch")), &post);
        assert!(warnings.is_empty());
    }

    #[test]
    fn phased_batch_unsatisfied_dependency() {
        let batch_caps = StageCapabilities::new().consumes_detections();

        let warnings =
            validate_pipeline_phased(&[], Some(&batch_caps), Some(StageId("batch")), &[]);
        assert_eq!(warnings.len(), 1);
        assert_eq!(
            warnings[0],
            ValidationWarning::UnsatisfiedDependency {
                stage_id: StageId("batch"),
                missing: "detections",
            }
        );
    }

    #[test]
    fn phased_post_batch_sees_batch_outputs() {
        let batch_caps = StageCapabilities::new().produces_detections();
        let post = vec![boxed("trk", det_consumer())];

        let warnings =
            validate_pipeline_phased(&[], Some(&batch_caps), Some(StageId("batch")), &post);
        assert!(warnings.is_empty());
    }

    #[test]
    fn phased_duplicate_across_phases() {
        let pre = vec![boxed("dup", None)];
        let post = vec![boxed("dup", None)];

        let warnings = validate_pipeline_phased(&pre, None, None, &post);
        assert_eq!(warnings.len(), 1);
        assert!(matches!(
            &warnings[0],
            ValidationWarning::DuplicateStageId { stage_id } if *stage_id == StageId("dup")
        ));
    }

    #[test]
    fn phased_duplicate_with_batch_id() {
        let pre = vec![boxed("batch", None)];

        let warnings = validate_pipeline_phased(&pre, None, Some(StageId("batch")), &[]);
        assert_eq!(warnings.len(), 1);
        assert!(matches!(
            &warnings[0],
            ValidationWarning::DuplicateStageId { stage_id } if *stage_id == StageId("batch")
        ));
    }

    #[test]
    fn phased_no_batch_processor() {
        let pre = vec![boxed("det", detector())];
        let post = vec![boxed("trk", det_consumer())];

        let warnings = validate_pipeline_phased(&pre, None, None, &post);
        assert!(warnings.is_empty());
    }
}