1use std::path::Path;
4use orp::pipeline::{Pipeline, PreProcessor, PostProcessor};
5use crate::params::Parameters;
6use crate::tokenizer::Tokenizer;
7use crate::util::result::Result;
8
9pub struct ClassificationPipeline {
11 tokenizer: Tokenizer,
12}
13
14
15impl ClassificationPipeline {
16 pub fn new<P: AsRef<Path>>(tokenizer_path: P, _params: &Parameters) -> Result<Self> {
17 Ok(Self {
18 tokenizer: Tokenizer::new(tokenizer_path, None)?
19 })
20 }
21}
22
23
24impl<'a> Pipeline<'a> for ClassificationPipeline {
25 type Input = super::input::text::TextInput;
26 type Output = super::output::classes::Classes;
27 type Context = super::input::text::Labels;
28 type Parameters = Parameters;
29
30 fn pre_processor(&self, params: &Self::Parameters) -> impl PreProcessor<'a, Self::Input, Self::Context> {
31 composable::composed![
32 super::input::prompt::InputToPrompt::with_params(params),
33 super::input::encoded::PromptsToEncoded::new(&self.tokenizer),
34 super::input::tensors::InputTensors::try_from,
35 super::input::tensors::InputTensors::try_into
36 ]
37 }
38
39 fn post_processor(&self, _params: &Self::Parameters) -> impl PostProcessor<'a, Self::Output, Self::Context> {
40 composable::composed![
41 super::output::tensors::OutputTensors::try_from,
42 |tensors| super::output::classes::Classes::try_from(tensors)
43 ]
44 }
45}