use nv_perception::Stage;
use crate::batch::BatchHandle;
/// Errors that can occur while assembling a [`FeedPipeline`].
#[derive(Debug, thiserror::Error)]
pub enum PipelineError {
/// Returned by [`FeedPipelineBuilder::batch`] when the builder already
/// holds a batch handle.
#[error("only one batch point is allowed per pipeline")]
DuplicateBatchPoint,
}
/// The components yielded by [`FeedPipeline::into_parts`], in order:
/// the pre-batch stages, the optional batch handle, and the post-batch stages.
pub(crate) type PipelineParts = (
Vec<Box<dyn Stage>>,
Option<BatchHandle>,
Vec<Box<dyn Stage>>,
);
/// A list of stages split into two segments around an optional batch point.
///
/// Instances are assembled through [`FeedPipeline::builder`].
pub struct FeedPipeline {
/// Stages that were added to the builder before a batch point was set
/// (or all stages, when no batch point was set).
pub(crate) pre_batch: Vec<Box<dyn Stage>>,
/// The batch handle, if a batch point was set on the builder.
pub(crate) batch: Option<BatchHandle>,
/// Stages that were added to the builder after the batch point was set.
pub(crate) post_batch: Vec<Box<dyn Stage>>,
}
impl FeedPipeline {
    /// Creates an empty [`FeedPipelineBuilder`]: no stages and no batch point.
    #[must_use]
    pub fn builder() -> FeedPipelineBuilder {
        FeedPipelineBuilder {
            after_batch: false,
            batch: None,
            pre_batch: Vec::new(),
            post_batch: Vec::new(),
        }
    }

    /// Returns the number of stages in the pipeline.
    ///
    /// This is the sum of the pre-batch and post-batch stages; the batch
    /// point itself is not included in the count.
    #[must_use]
    pub fn stage_count(&self) -> usize {
        let Self {
            pre_batch,
            post_batch,
            ..
        } = self;
        pre_batch.len() + post_batch.len()
    }

    /// Returns `true` when the pipeline holds a batch handle.
    #[must_use]
    pub fn has_batch(&self) -> bool {
        self.batch.is_some()
    }

    /// Consumes the pipeline and hands back its components as
    /// `(pre_batch, batch, post_batch)`.
    pub(crate) fn into_parts(self) -> PipelineParts {
        let Self {
            pre_batch,
            batch,
            post_batch,
        } = self;
        (pre_batch, batch, post_batch)
    }
}
/// Builder for [`FeedPipeline`], obtained from [`FeedPipeline::builder`].
///
/// Stages are routed to the pre-batch or post-batch segment depending on
/// whether [`FeedPipelineBuilder::batch`] has already succeeded.
pub struct FeedPipelineBuilder {
/// Stages added while no batch point has been set.
pre_batch: Vec<Box<dyn Stage>>,
/// The batch handle, once `batch()` has been called successfully.
batch: Option<BatchHandle>,
/// Stages added after the batch point was set.
post_batch: Vec<Box<dyn Stage>>,
/// `false` initially; set to `true` by a successful `batch()` call so that
/// subsequently added stages go to `post_batch`.
after_batch: bool,
}
impl FeedPipelineBuilder {
    /// Boxes `stage` and appends it to the pipeline.
    ///
    /// The stage lands in the pre-batch segment until a batch point has been
    /// set, and in the post-batch segment afterwards.
    #[must_use]
    pub fn add_stage(self, stage: impl Stage) -> Self {
        self.add_boxed_stage(Box::new(stage))
    }

    /// Appends an already boxed stage to the pipeline.
    ///
    /// The stage lands in the pre-batch segment until a batch point has been
    /// set, and in the post-batch segment afterwards.
    #[must_use]
    pub fn add_boxed_stage(mut self, stage: Box<dyn Stage>) -> Self {
        // Pick the segment that is currently being filled.
        let segment = if self.after_batch {
            &mut self.post_batch
        } else {
            &mut self.pre_batch
        };
        segment.push(stage);
        self
    }

    /// Sets the batch point of the pipeline.
    ///
    /// Stages added after a successful call go to the post-batch segment.
    ///
    /// # Errors
    ///
    /// Returns [`PipelineError::DuplicateBatchPoint`] if a batch point has
    /// already been set on this builder. The builder is consumed either way.
    pub fn batch(self, handle: BatchHandle) -> Result<Self, PipelineError> {
        match self.batch {
            Some(_) => Err(PipelineError::DuplicateBatchPoint),
            None => Ok(Self {
                batch: Some(handle),
                after_batch: true,
                ..self
            }),
        }
    }

    /// Finishes the builder and produces the [`FeedPipeline`].
    #[must_use]
    pub fn build(self) -> FeedPipeline {
        // `after_batch` is builder-only bookkeeping and is dropped here.
        let Self {
            pre_batch,
            batch,
            post_batch,
            after_batch: _,
        } = self;
        FeedPipeline {
            pre_batch,
            batch,
            post_batch,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::batch::{BatchConfig, BatchCoordinator};
    use nv_core::error::StageError;
    use nv_core::health::HealthEvent;
    use nv_core::id::StageId;
    use nv_perception::batch::{BatchEntry, BatchProcessor};
    use nv_perception::{StageContext, StageOutput};
    use std::time::Duration;

    /// Stage that does nothing; only its identity matters for builder tests.
    struct DummyStage(&'static str);

    impl Stage for DummyStage {
        fn id(&self) -> StageId {
            StageId(self.0)
        }

        fn process(&mut self, _ctx: &StageContext<'_>) -> Result<StageOutput, StageError> {
            Ok(StageOutput::empty())
        }
    }

    /// Batch processor that accepts every batch and leaves it untouched.
    struct NoopProcessor;

    impl BatchProcessor for NoopProcessor {
        fn id(&self) -> StageId {
            StageId("noop")
        }

        fn process(&mut self, _: &mut [BatchEntry]) -> Result<(), StageError> {
            Ok(())
        }
    }

    /// Starts a coordinator around [`NoopProcessor`] and passes one of its
    /// handles to `body`.
    ///
    /// The coordinator is kept alive until `body` returns, so the handle is
    /// backed by a live coordinator for the whole test.
    fn with_noop_handle(body: impl FnOnce(BatchHandle)) {
        let (health_tx, _) = tokio::sync::broadcast::channel::<HealthEvent>(4);
        let coord = BatchCoordinator::start(
            Box::new(NoopProcessor),
            BatchConfig {
                max_batch_size: 1,
                max_latency: Duration::from_millis(10),
                queue_capacity: None,
                response_timeout: None,
                max_in_flight_per_feed: 1,
                startup_timeout: None,
            },
            health_tx,
        )
        .unwrap();
        body(coord.handle());
    }

    #[test]
    fn pipeline_without_batch() {
        let p = FeedPipeline::builder()
            .add_stage(DummyStage("a"))
            .add_stage(DummyStage("b"))
            .build();
        assert_eq!(p.stage_count(), 2);
        assert!(!p.has_batch());
        let (pre, batch, post) = p.into_parts();
        assert_eq!(pre.len(), 2);
        assert!(batch.is_none());
        assert!(post.is_empty());
    }

    #[test]
    fn pipeline_stage_count_with_batch() {
        with_noop_handle(|handle| {
            let p = FeedPipeline::builder()
                .add_stage(DummyStage("pre1"))
                .add_stage(DummyStage("pre2"))
                .batch(handle)
                .expect("first batch() should succeed")
                .add_boxed_stage(Box::new(DummyStage("post1")))
                .build();
            // The batch point itself is not counted as a stage.
            assert_eq!(p.stage_count(), 3);
            assert!(p.has_batch());
            let (pre, batch, post) = p.into_parts();
            assert_eq!(pre.len(), 2);
            assert!(batch.is_some());
            assert_eq!(post.len(), 1);
        });
    }

    #[test]
    fn double_batch_returns_error() {
        with_noop_handle(|handle| {
            let result = FeedPipeline::builder()
                .batch(handle.clone())
                .expect("first batch() should succeed")
                .batch(handle);
            assert!(
                matches!(result, Err(PipelineError::DuplicateBatchPoint)),
                "duplicate batch point should return error"
            );
        });
    }
}