gliclass/input/tensors.rs

1use ort::session::SessionInputs;
2use crate::util::result::Result;
3use super::{encoded::EncodedInput, text::Labels};
4
5
/// Session input name under which the attention masks are passed.
const TENSOR_ATTN_MASKS: &str = "attention_mask";
/// Session input name under which the token ids are passed.
const TENSOR_INPUT_IDS: &str = "input_ids";
8
9
10/// Input tensors, ready for inferences
11pub struct InputTensors<'a> {
12    pub inputs: SessionInputs<'a, 'a>,
13    pub labels: Labels,
14}
15
16
17impl TryFrom<EncodedInput> for InputTensors<'_> {
18    type Error = crate::util::result::Error;
19
20    fn try_from(input: EncodedInput) -> Result<Self> {
21        Ok(Self {
22            labels: input.labels,
23            inputs: ort::inputs!{
24                TENSOR_INPUT_IDS => input.input_ids,
25                TENSOR_ATTN_MASKS => input.attention_masks,    
26            }?.into(),            
27        })
28    }
29}
30
31
32impl<'a> TryInto<(SessionInputs<'a, 'a>, Labels)> for InputTensors<'a> {
33    type Error = crate::util::result::Error;
34
35    fn try_into(self) -> Result<(SessionInputs<'a, 'a>, Labels)> {
36        Ok((self.inputs, self.labels))
37    }    
38}