1use yscv_tensor::Tensor;
2
3use super::error::ModelError;
4use super::sequential::SequentialModel;
5
6pub struct InferencePipeline {
9 preprocess: Option<Box<dyn Fn(&Tensor) -> Result<Tensor, ModelError>>>,
10 model: SequentialModel,
11 postprocess: Option<Box<dyn Fn(&Tensor) -> Result<Tensor, ModelError>>>,
12}
13
14impl InferencePipeline {
15 pub fn new(model: SequentialModel) -> Self {
17 Self {
18 preprocess: None,
19 model,
20 postprocess: None,
21 }
22 }
23
24 pub fn with_preprocess<F>(mut self, f: F) -> Self
27 where
28 F: Fn(&Tensor) -> Result<Tensor, ModelError> + 'static,
29 {
30 self.preprocess = Some(Box::new(f));
31 self
32 }
33
34 pub fn with_postprocess<F>(mut self, f: F) -> Self
37 where
38 F: Fn(&Tensor) -> Result<Tensor, ModelError> + 'static,
39 {
40 self.postprocess = Some(Box::new(f));
41 self
42 }
43
44 pub fn run(&self, input: &Tensor) -> Result<Tensor, ModelError> {
47 let preprocessed = match &self.preprocess {
48 Some(f) => f(input)?,
49 None => input.clone(),
50 };
51 let output = self.model.forward_inference(&preprocessed)?;
52 match &self.postprocess {
53 Some(f) => f(&output),
54 None => Ok(output),
55 }
56 }
57
58 pub fn run_batch(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>, ModelError> {
60 inputs.iter().map(|input| self.run(input)).collect()
61 }
62}