// gliclass/input/encoded.rs
use composable::Composable;
use crate::tokenizer::Tokenizer;
use super::{prompt::PromptInput, text::Labels};
use ndarray::Array2;
5
6pub struct EncodedInput {
8 pub labels: Labels,
9 pub input_ids: Array2<i64>,
10 pub attention_masks: Array2<i64>,
11}
12
13
14pub struct PromptsToEncoded<'a> {
15 tokenizer: &'a Tokenizer,
16}
17
18impl<'a> PromptsToEncoded<'a> {
19 pub fn new(tokenizer: &'a Tokenizer) -> Self {
20 Self { tokenizer }
21 }
22}
23
24impl Composable<PromptInput, EncodedInput> for PromptsToEncoded<'_> {
26 fn apply(&self, input: PromptInput) -> Result<EncodedInput, Box<dyn std::error::Error + Send + Sync>> {
27 let (input_ids, attention_masks) = self.tokenizer.tokenize(input.prompts)?;
28 Ok(EncodedInput {
29 labels: input.labels,
30 input_ids,
31 attention_masks,
32 })
33 }
34}