pub struct Seq2SeqGenerationModel<'a> { /* private fields */ }
ONNX inference session wrapper for sequence-to-sequence (Seq2Seq) generation models.
Implementations
impl<'a> Seq2SeqGenerationModel<'a>
impl<'a> Seq2SeqGenerationModel<'a>
pub fn new_from_memory( env: Arc<Environment>, model_bytes: &'a [u8], device: Device, optimization_level: GraphOptimizationLevel ) -> Result<Self, Error>
pub fn new_from_file( env: Arc<Environment>, model_path: PathBuf, device: Device, optimization_level: GraphOptimizationLevel ) -> Result<Self, Error>
pub fn get_token_type_support(&self) -> bool
pub fn get_decoder_token_type_support(&self) -> bool
pub fn forward(
&self,
input_ids: Array2<u32>,
attention_mask: Option<Array2<u32>>,
decoder_input_ids: Array2<u32>,
decoder_attention_mask: Option<Array2<u32>>,
token_type_ids: Option<Array2<u32>>,
decoder_token_type_ids: Option<Array2<u32>>
) -> Result<Array3<f32>, Error>
pub fn forward( &self, input_ids: Array2<u32>, attention_mask: Option<Array2<u32>>, decoder_input_ids: Array2<u32>, decoder_attention_mask: Option<Array2<u32>>, token_type_ids: Option<Array2<u32>>, decoder_token_type_ids: Option<Array2<u32>> ) -> Result<Array3<f32>, Error>
Runs inference and returns the logits as an `Array3<f32>` (wrapped in a `Result`); per the signature above, no past key values are returned.