extern crate self as svod_model;
use svod_macros::jit_wrapper;
use crate::gigaam::model::GigaAm;
// Single-step RNN-T predictor wrapper for `GigaAm`, declared through the
// `svod_macros::jit_wrapper!` proc macro.
// The item(s) this expands to are defined by the macro and are not visible here.
// NOTE(review): presumably this exposes one predictor (decoder) step as a separately
// traced/compiled unit — confirm against the macro's expansion.
// Only plain `//` comments are used inside the invocation: `///` doc comments would
// be passed to the macro as `#[doc]` attributes and could break its DSL parsing.
jit_wrapper! {
RnntPredictorStepJit(GigaAm) {
// Step inputs, listed in the same order they are forwarded to `forward_concat`.
// NOTE(review): presumably the previously emitted token id and the incoming
// recurrent hidden/cell state. Shapes and dtypes are not established here; confirm.
prev_token: Tensor,
h_in: Tensor,
c_in: Tensor,
// `model` is not declared in this block. Presumably the macro binds it to the
// wrapped `GigaAm` instance — verify against `jit_wrapper!`.
build(prev_token, h_in, c_in) {
// Fetch the RNN-T head from `model.head`; `?` propagates the failure case.
// Presumably that case is a head that is not RNN-T.
// The wrapper's own name is passed in, presumably for the error message, so keep
// the string in sync if the wrapper is renamed.
// The second element of the returned tuple is not needed for the predictor step.
let (rnnt_head, _) = model.head.expect_rnnt("RnntPredictorStepJit")?;
// One predictor step; this tail expression is the value `build` yields.
rnnt_head.predictor.forward_concat(prev_token, h_in, c_in)
}
}
}
// Single-step RNN-T joint-network wrapper for `GigaAm`, declared through the
// `svod_macros::jit_wrapper!` proc macro.
// The item(s) this expands to are defined by the macro and are not visible here.
// NOTE(review): presumably this exposes one joint evaluation as a separately
// traced/compiled unit — confirm against the macro's expansion.
// Only plain `//` comments are used inside the invocation: `///` doc comments would
// be passed to the macro as `#[doc]` attributes and could break its DSL parsing.
jit_wrapper! {
RnntJointStepJit(GigaAm) {
// Joint inputs, forwarded in this order to `joint.forward`.
// NOTE(review): presumably `enc_t` is one encoder frame and `g` is the predictor
// output. Shapes and dtypes are not established here; confirm.
enc_t: Tensor,
g: Tensor,
// `model` is not declared in this block. Presumably the macro binds it to the
// wrapped `GigaAm` instance — verify against `jit_wrapper!`.
build(enc_t, g) {
// Fetch the RNN-T head from `model.head`; `?` propagates the failure case.
// Presumably that case is a head that is not RNN-T.
// The wrapper's own name is passed in, presumably for the error message, so keep
// the string in sync if the wrapper is renamed.
// The second element of the returned tuple is not needed for the joint step.
let (rnnt_head, _) = model.head.expect_rnnt("RnntJointStepJit")?;
// Apply the joint network; this tail expression is the value `build` yields.
rnnt_head.joint.forward(enc_t, g)
}
}
}