gliclass/input/
tensors.rs1use ort::session::SessionInputs;
2use crate::util::result::Result;
3use super::{encoded::EncodedInput, text::Labels};
4
5
/// Session input name under which the encoded `input_ids` tensor is passed to the model.
const TENSOR_INPUT_IDS: &str = "input_ids";
/// Session input name under which the encoded attention-mask tensor is passed to the model.
const TENSOR_ATTN_MASKS: &str = "attention_mask";
8
9
10pub struct InputTensors<'a> {
12 pub inputs: SessionInputs<'a, 'a>,
13 pub labels: Labels,
14}
15
16
17impl TryFrom<EncodedInput> for InputTensors<'_> {
18 type Error = crate::util::result::Error;
19
20 fn try_from(input: EncodedInput) -> Result<Self> {
21 Ok(Self {
22 labels: input.labels,
23 inputs: ort::inputs!{
24 TENSOR_INPUT_IDS => input.input_ids,
25 TENSOR_ATTN_MASKS => input.attention_masks,
26 }?.into(),
27 })
28 }
29}
30
31
32impl<'a> TryInto<(SessionInputs<'a, 'a>, Labels)> for InputTensors<'a> {
33 type Error = crate::util::result::Error;
34
35 fn try_into(self) -> Result<(SessionInputs<'a, 'a>, Labels)> {
36 Ok((self.inputs, self.labels))
37 }
38}