gte/commons/input/tensors.rs

1use ort::session::SessionInputs;
2use crate::util::result::Result;
3use crate::commons::input::encoded::EncodedInput;
4
/// Name of the model input that receives the encoded token ids.
const TENSOR_INPUT_IDS: &str = "input_ids";

/// Name of the model input that receives the attention mask.
const TENSOR_ATTN_MASKS: &str = "attention_mask";
7
8/// Input tensors, ready for inferences
9pub struct InputTensors<'a> {
10    pub inputs: SessionInputs<'a, 'a>,    
11}
12
13
14impl TryFrom<EncodedInput> for InputTensors<'_> {
15    type Error = crate::util::result::Error;
16
17    fn try_from(input: EncodedInput) -> Result<Self> {
18        Ok(Self {
19            inputs: ort::inputs!{
20                TENSOR_INPUT_IDS => input.input_ids,
21                TENSOR_ATTN_MASKS => input.attn_masks,    
22            }?.into()
23        })
24    }
25}
26
27
28impl<'a> TryInto<(SessionInputs<'a, 'a>, ())> for InputTensors<'a> {
29    type Error = crate::util::result::Error;
30
31    fn try_into(self) -> Result<(SessionInputs<'a, 'a>, ())> {
32        Ok((self.inputs, ()))
33    }    
34}