gliner/model/input/tensors/
token.rs1use ort::session::SessionInputs;
2use composable::Composable;
3use crate::util::result::Result;
4use super::super::encoded::EncodedInput;
5use super::super::super::pipeline::context::EntityContext;
6
7
// Names under which `TokenTensors::from` registers each tensor via `ort::inputs!`;
// `TokenTensors::inputs` returns these same four names.
// NOTE(review): these presumably must match the input names declared by the exported ONNX
// model exactly — confirm against the model before renaming any of them. In particular the
// word mask is registered as "words_mask" (plural) although the constant name is singular.

/// Session input name for the token id tensor.
const TENSOR_INPUT_IDS: &str = "input_ids";
/// Session input name for the attention mask tensor.
const TENSOR_ATTENTION_MASK: &str = "attention_mask";
/// Session input name for the word mask tensor (note the plural `words_mask`).
const TENSOR_WORD_MASK: &str = "words_mask";
/// Session input name for the per-text length tensor.
const TENSOR_TEXT_LENGTHS: &str = "text_lengths";
12
13
14pub struct TokenTensors<'a> {
16 pub tensors: SessionInputs<'a, 'a>,
17 pub context: EntityContext,
18}
19
20impl TokenTensors<'_> {
21
22 pub fn from(encoded: EncodedInput) -> Result<Self> {
23 let inputs = ort::inputs!{
24 TENSOR_INPUT_IDS => encoded.input_ids,
25 TENSOR_ATTENTION_MASK => encoded.attention_masks,
26 TENSOR_WORD_MASK => encoded.word_masks,
27 TENSOR_TEXT_LENGTHS => encoded.text_lengths,
28 }?;
29 Ok(Self {
30 tensors: inputs.into(),
31 context: EntityContext {
32 texts: encoded.texts,
33 tokens: encoded.tokens,
34 entities: encoded.entities,
35 num_words: encoded.num_words
36 },
37 })
38 }
39
40 pub fn inputs() -> [&'static str; 4] {
41 [TENSOR_INPUT_IDS, TENSOR_ATTENTION_MASK, TENSOR_WORD_MASK, TENSOR_TEXT_LENGTHS]
42 }
43
44}
45
46
/// Stateless pipeline step that converts an `EncodedInput` into `TokenTensors`.
///
/// Declared as a unit struct. `EncodedToTensors {}` is still a valid constructor
/// expression, so existing call sites keep compiling.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct EncodedToTensors;
50
51
52impl<'a> Composable<EncodedInput, TokenTensors<'a>> for EncodedToTensors {
53 fn apply(&self, input: EncodedInput) -> Result<TokenTensors<'a>> {
54 TokenTensors::from(input)
55 }
56}
57
58
/// Stateless pipeline step that splits `TokenTensors` into its session inputs and
/// entity context.
///
/// Declared as a unit struct. `TensorsToSessionInput {}` is still a valid constructor
/// expression, so existing call sites keep compiling.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct TensorsToSessionInput;
62
63
64impl<'a> Composable<TokenTensors<'a>, (SessionInputs<'a, 'a>, EntityContext)> for TensorsToSessionInput {
65 fn apply(&self, input: TokenTensors<'a>) -> Result<(SessionInputs<'a, 'a>, EntityContext)> {
66 Ok((input.tensors, input.context))
67 }
68}