pub struct Seq2SeqDecoderModelWithPKVs<'a> { /* private fields */ }
ONNX inference session wrapper for conditional generation models.
Validates the model's inputs and outputs and provides a convenient interface to the model.
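The source does not show how validation is done. The following is a minimal sketch of one way to check that a model declares the inputs a wrapper needs, using only the standard library. The names in `REQUIRED_INPUTS` and the helpers `missing_inputs` and `validate_inputs` are hypothetical; the real input names depend on how the ONNX model was exported.

```rust
use std::collections::HashSet;

// Hypothetical input names: the real ones depend on the ONNX export.
const REQUIRED_INPUTS: [&str; 2] = ["input_ids", "encoder_hidden_states"];

/// Returns every required name that the model does not declare.
fn missing_inputs<'a>(declared: &[&str], required: &[&'a str]) -> Vec<&'a str> {
    let declared: HashSet<&str> = declared.iter().copied().collect();
    required
        .iter()
        .copied()
        .filter(|name| !declared.contains(*name))
        .collect()
}

/// Hypothetical validation hook: fails with the list of missing inputs.
fn validate_inputs(declared: &[&str]) -> Result<(), String> {
    let missing = missing_inputs(declared, &REQUIRED_INPUTS);
    if missing.is_empty() {
        Ok(())
    } else {
        Err(format!("model is missing inputs: {missing:?}"))
    }
}

fn main() {
    // A model that declares an extra, optional input still passes.
    let declared = ["input_ids", "encoder_hidden_states", "encoder_attention_mask"];
    assert!(validate_inputs(&declared).is_ok());
    // A model missing a required input is rejected with a descriptive error.
    println!("{:?}", validate_inputs(&["input_ids"]));
}
```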
Implementations
impl<'a> Seq2SeqDecoderModelWithPKVs<'a>
pub fn new_from_memory( env: Arc<Environment>, model_bytes: &'a [u8], model_with_pkvs_bytes: &'a [u8], device: Device, optimization_level: GraphOptimizationLevel ) -> Result<Self, Error>
pub fn new_from_file( env: Arc<Environment>, model_path: PathBuf, model_with_pkvs_path: PathBuf, device: Device, optimization_level: GraphOptimizationLevel ) -> Result<Self, Error>
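Both constructors take two models. The source does not say how they divide the work. A common arrangement, assumed here, is that the plain model serves the first decoding step (when no cache exists yet) and the PKV model serves every later step. The sketch below models only that dispatch decision. The `Session` enum and `select_session` are hypothetical, and `Vec<f32>` stands in for `ArrayD<f32>` so the example needs no external crates.

```rust
use std::collections::HashMap;

/// Which of the two sessions would serve a step (hypothetical).
#[derive(Debug, PartialEq)]
enum Session {
    /// Built from `model_bytes` / `model_path`: takes no cache input.
    WithoutPkvs,
    /// Built from `model_with_pkvs_bytes` / `model_with_pkvs_path`: consumes the cache.
    WithPkvs,
}

/// Assumption: no (or an empty) cache means the first step, which runs on the
/// plain model; any non-empty cache goes to the PKV-aware model.
fn select_session(past_key_values: &Option<HashMap<String, Vec<f32>>>) -> Session {
    match past_key_values {
        Some(pkvs) if !pkvs.is_empty() => Session::WithPkvs,
        _ => Session::WithoutPkvs,
    }
}

fn main() {
    // First step: nothing cached yet.
    assert_eq!(select_session(&None), Session::WithoutPkvs);

    // Later steps: the cache returned by the previous call is passed back in.
    // The key name is hypothetical.
    let mut cache = HashMap::new();
    cache.insert("layer0.key".to_string(), vec![0.0_f32; 4]);
    assert_eq!(select_session(&Some(cache)), Session::WithPkvs);
}
```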
pub fn forward( &self, input_ids: Array2<u32>, encoder_last_hidden_state: Array3<f32>, encoder_attention_mask: Option<Array2<u32>>, past_key_values: Option<HashMap<String, ArrayD<f32>>> ) -> Result<(Array3<f32>, HashMap<String, ArrayD<f32>>), Error>
Runs a forward pass of the decoder and returns the logits together with the past key values.
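`forward` returns the past key values so that the caller can pass them back on the next step. Below is a minimal greedy-decoding sketch of that loop. It assumes the logits have shape `[batch, seq_len, vocab]` and that, once a cache exists, only the newest token is fed. `mock_forward` is a hypothetical stand-in for `forward`: it ignores the encoder state and always prefers the next token id. A flat `Vec<f32>` replaces the `ndarray` types so the example is self-contained.

```rust
use std::collections::HashMap;

type Pkvs = HashMap<String, Vec<f32>>;

/// Row-major stand-in for `Array3<f32>`, assumed shape [batch = 1, seq_len, vocab].
struct Logits {
    data: Vec<f32>,
    seq_len: usize,
    vocab: usize,
}

/// Greedy choice: index of the largest logit at the last position (first wins on ties).
fn argmax_last(logits: &Logits) -> u32 {
    let start = (logits.seq_len - 1) * logits.vocab;
    let row = &logits.data[start..start + logits.vocab];
    let mut best = 0;
    for (i, &v) in row.iter().enumerate() {
        if v > row[best] {
            best = i;
        }
    }
    best as u32
}

/// Hypothetical stand-in for `forward`: favors token (last + 1) % vocab and
/// grows a fake cache by one entry per call.
fn mock_forward(input_ids: &[u32], past: Option<Pkvs>) -> (Logits, Pkvs) {
    let vocab = 5;
    let last = *input_ids.last().unwrap() as usize;
    let mut data = vec![0.0; input_ids.len() * vocab];
    let start = (input_ids.len() - 1) * vocab;
    data[start + (last + 1) % vocab] = 1.0;
    let mut cache = past.unwrap_or_default();
    let step = cache.len();
    cache.insert(format!("step{step}"), vec![last as f32]);
    (Logits { data, seq_len: input_ids.len(), vocab }, cache)
}

/// Greedy decoding that threads the returned cache into the next call.
fn greedy_decode(start_token: u32, eos: u32, max_len: usize) -> Vec<u32> {
    let mut generated = vec![start_token];
    let mut past: Option<Pkvs> = None;
    while generated.len() < max_len {
        // Assumption: with a cache only the newest token is fed;
        // the first step sees the whole prefix.
        let step_input: Vec<u32> = if past.is_some() {
            vec![*generated.last().unwrap()]
        } else {
            generated.clone()
        };
        let (logits, new_past) = mock_forward(&step_input, past.take());
        let next = argmax_last(&logits);
        generated.push(next);
        past = Some(new_past);
        if next == eos {
            break;
        }
    }
    generated
}

fn main() {
    println!("{:?}", greedy_decode(0, 3, 10));
}
```

With a real model, `mock_forward` would be replaced by a call to `forward` that also passes the encoder's last hidden state; the loop structure (take the cache out with `Option::take`, store the returned one) stays the same.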