gliclass/
pipeline.rs

1//! Orp pipeline for GLiClass classification
2 
3use std::path::Path;
4use orp::pipeline::{Pipeline, PreProcessor, PostProcessor};
5use crate::params::Parameters;
6use crate::tokenizer::Tokenizer;
7use crate::util::result::Result;
8
9/// Pipeline for GLiClass classification, to be executed using [`orp`](https://github.com/fbilhaut/orp).
10pub struct ClassificationPipeline {
11    tokenizer: Tokenizer,
12}
13
14
15impl ClassificationPipeline {
16    pub fn new<P: AsRef<Path>>(tokenizer_path: P, _params: &Parameters) -> Result<Self> {
17        Ok(Self { 
18            tokenizer: Tokenizer::new(tokenizer_path, None)?
19        })
20    }
21}
22
23
24impl<'a> Pipeline<'a> for ClassificationPipeline {
25    type Input = super::input::text::TextInput;
26    type Output = super::output::classes::Classes;
27    type Context = super::input::text::Labels;
28    type Parameters = Parameters;
29
30    fn pre_processor(&self, params: &Self::Parameters) -> impl PreProcessor<'a, Self::Input, Self::Context> {
31        composable::composed![
32            super::input::prompt::InputToPrompt::with_params(params),
33            super::input::encoded::PromptsToEncoded::new(&self.tokenizer),
34            super::input::tensors::InputTensors::try_from,
35            super::input::tensors::InputTensors::try_into
36        ]
37    }
38
39    fn post_processor(&self, _params: &Self::Parameters) -> impl PostProcessor<'a, Self::Output, Self::Context> {
40        composable::composed![
41            super::output::tensors::OutputTensors::try_from,
42            |tensors| super::output::classes::Classes::try_from(tensors)
43        ]
44    }
45}