gliclass/input/encoded.rs

1use composable::Composable;
2use crate::tokenizer::Tokenizer;
3use super::{prompt::PromptInput, text::Labels};
4use ndarray::Array2;
5
6/// Encoded sequences
7pub struct EncodedInput {
8    pub labels: Labels,    
9    pub input_ids: Array2<i64>,
10    pub attention_masks: Array2<i64>,    
11}
12
13
14pub struct PromptsToEncoded<'a> {
15    tokenizer: &'a Tokenizer,
16}
17
18impl<'a> PromptsToEncoded<'a> {
19    pub fn new(tokenizer: &'a Tokenizer) -> Self {
20        Self { tokenizer }
21    }
22}
23
24/// Transformation from prompts to encoded sequences
25impl Composable<PromptInput, EncodedInput> for PromptsToEncoded<'_> {
26    fn apply(&self, input: PromptInput) -> Result<EncodedInput, Box<dyn std::error::Error + Send + Sync>> {
27        let (input_ids, attention_masks) = self.tokenizer.tokenize(input.prompts)?;
28        Ok(EncodedInput {
29            labels: input.labels,
30            input_ids,
31            attention_masks,
32        })
33    }
34}