gliner/model/input/tensors/token.rs

1use ort::session::SessionInputs;
2use composable::Composable;
3use crate::util::result::Result;
4use super::super::encoded::EncodedInput;
5use super::super::super::pipeline::context::EntityContext;
6
7
/// Name under which the token-id tensor is passed to the ONNX session.
const TENSOR_INPUT_IDS: &str = "input_ids";

/// Name under which the attention-mask tensor is passed to the ONNX session.
const TENSOR_ATTENTION_MASK: &str = "attention_mask";

/// Name under which the word-mask tensor is passed to the ONNX session.
/// Note the plural "words" in the runtime name, unlike the constant's identifier.
const TENSOR_WORD_MASK: &str = "words_mask";

/// Name under which the per-text length tensor is passed to the ONNX session.
const TENSOR_TEXT_LENGTHS: &str = "text_lengths";
12
13
14/// Ready-for-inference tensors (token mode)
15pub struct TokenTensors<'a> {
16    pub tensors: SessionInputs<'a, 'a>,
17    pub context: EntityContext,    
18}
19
20impl TokenTensors<'_> {
21
22    pub fn from(encoded: EncodedInput) -> Result<Self> {
23        let inputs = ort::inputs!{
24            TENSOR_INPUT_IDS => encoded.input_ids,
25            TENSOR_ATTENTION_MASK => encoded.attention_masks,
26            TENSOR_WORD_MASK => encoded.word_masks,
27            TENSOR_TEXT_LENGTHS => encoded.text_lengths,
28        }?;
29        Ok(Self {
30            tensors: inputs.into(),
31            context: EntityContext { 
32                texts: encoded.texts, 
33                tokens: encoded.tokens, 
34                entities: encoded.entities, 
35                num_words: encoded.num_words 
36            },            
37        })
38    }
39
40    pub fn inputs() -> [&'static str; 4] {
41        [TENSOR_INPUT_IDS, TENSOR_ATTENTION_MASK, TENSOR_WORD_MASK, TENSOR_TEXT_LENGTHS]
42    }
43
44}
45
46
47/// Composable: Encoded => TokenTensors
48#[derive(Default)]
49pub struct EncodedToTensors { }
50
51
52impl<'a> Composable<EncodedInput, TokenTensors<'a>> for EncodedToTensors {
53    fn apply(&self, input: EncodedInput) -> Result<TokenTensors<'a>> {
54        TokenTensors::from(input)
55    }
56}
57
58
59/// Composable: TokenTensors => (SessionInput, TensorsMeta) 
60#[derive(Default)]
61pub struct TensorsToSessionInput { }
62
63
64impl<'a> Composable<TokenTensors<'a>, (SessionInputs<'a, 'a>, EntityContext)> for TensorsToSessionInput {
65    fn apply(&self, input: TokenTensors<'a>) -> Result<(SessionInputs<'a, 'a>, EntityContext)> {
66        Ok((input.tensors, input.context))
67    }
68}