//! Inference pipeline for `yscv_model`: wraps a [`SequentialModel`] with
//! optional pre- and post-processing hooks applied around the forward pass.
use std::borrow::Cow;

use yscv_tensor::Tensor;

use super::error::ModelError;
use super::sequential::SequentialModel;
6/// Builder-style inference pipeline that wraps a [`SequentialModel`] with
7/// optional pre- and post-processing closures.
8pub struct InferencePipeline {
9    preprocess: Option<Box<dyn Fn(&Tensor) -> Result<Tensor, ModelError>>>,
10    model: SequentialModel,
11    postprocess: Option<Box<dyn Fn(&Tensor) -> Result<Tensor, ModelError>>>,
12}
13
14impl InferencePipeline {
15    /// Create a new pipeline wrapping the given model with no pre/post processing.
16    pub fn new(model: SequentialModel) -> Self {
17        Self {
18            preprocess: None,
19            model,
20            postprocess: None,
21        }
22    }
23
24    /// Attach a preprocessing closure that transforms the input tensor before
25    /// it is fed to the model.
26    pub fn with_preprocess<F>(mut self, f: F) -> Self
27    where
28        F: Fn(&Tensor) -> Result<Tensor, ModelError> + 'static,
29    {
30        self.preprocess = Some(Box::new(f));
31        self
32    }
33
34    /// Attach a postprocessing closure that transforms the model output tensor
35    /// before it is returned to the caller.
36    pub fn with_postprocess<F>(mut self, f: F) -> Self
37    where
38        F: Fn(&Tensor) -> Result<Tensor, ModelError> + 'static,
39    {
40        self.postprocess = Some(Box::new(f));
41        self
42    }
43
44    /// Run the full pipeline on a single input tensor:
45    /// preprocess (if set) -> model forward inference -> postprocess (if set).
46    pub fn run(&self, input: &Tensor) -> Result<Tensor, ModelError> {
47        let preprocessed = match &self.preprocess {
48            Some(f) => f(input)?,
49            None => input.clone(),
50        };
51        let output = self.model.forward_inference(&preprocessed)?;
52        match &self.postprocess {
53            Some(f) => f(&output),
54            None => Ok(output),
55        }
56    }
57
58    /// Run the pipeline on a batch of input tensors, collecting results.
59    pub fn run_batch(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>, ModelError> {
60        inputs.iter().map(|input| self.run(input)).collect()
61    }
62}