gte/commons/input/
tensors.rs1use ort::session::SessionInputs;
2use crate::util::result::Result;
3use crate::commons::input::encoded::EncodedInput;
4
/// Input name under which the encoded token ids are passed to `ort::inputs!`.
/// Presumably this must match the input name declared by the ONNX graph — confirm.
const TENSOR_INPUT_IDS: &str = "input_ids";

/// Input name under which the attention masks are passed to `ort::inputs!`.
/// Presumably this must match the input name declared by the ONNX graph — confirm.
const TENSOR_ATTN_MASKS: &str = "attention_mask";
7
8pub struct InputTensors<'a> {
10 pub inputs: SessionInputs<'a, 'a>,
11}
12
13
14impl TryFrom<EncodedInput> for InputTensors<'_> {
15 type Error = crate::util::result::Error;
16
17 fn try_from(input: EncodedInput) -> Result<Self> {
18 Ok(Self {
19 inputs: ort::inputs!{
20 TENSOR_INPUT_IDS => input.input_ids,
21 TENSOR_ATTN_MASKS => input.attn_masks,
22 }?.into()
23 })
24 }
25}
26
27
28impl<'a> TryInto<(SessionInputs<'a, 'a>, ())> for InputTensors<'a> {
29 type Error = crate::util::result::Error;
30
31 fn try_into(self) -> Result<(SessionInputs<'a, 'a>, ())> {
32 Ok((self.inputs, ()))
33 }
34}