use snafu::ResultExt;
use svod_dtype::DType;
use svod_tensor::Tensor;
use crate::init::fan_in_uniform;
use crate::state::{self, HasStateDict, StateDict};
use crate::{load_state_field, state_field};
use crate::gigaam::Result;
use crate::gigaam::error::TensorSnafu;
/// Learnable parameters of the joint network.
///
/// `RnntJoint::forward` combines an encoder-side input and a `pred`-side input as
/// `log_softmax(out(relu(enc(enc_t) + pred(g))))` over the last axis.
///
/// The shapes listed on each field are the ones `RnntJoint::empty` allocates (all
/// `Float32`). Tensors restored through `load_state_dict` are assumed to match them.
/// NOTE(review): confirm whether `load_state_field!` actually validates shapes.
#[derive(Clone)]
pub struct RnntJoint {
/// Encoder-side projection weight, `[joint_hidden, enc_hidden]`.
pub enc_w: Tensor,
/// Encoder-side projection bias, `[joint_hidden]`.
pub enc_b: Tensor,
/// `pred`-side projection weight (presumably the prediction network), `[joint_hidden, pred_hidden]`.
pub pred_w: Tensor,
/// `pred`-side projection bias, `[joint_hidden]`.
pub pred_b: Tensor,
/// Output projection weight producing class scores, `[num_classes, joint_hidden]`.
pub out_w: Tensor,
/// Output projection bias, `[num_classes]`.
pub out_b: Tensor,
}
impl RnntJoint {
    /// Allocates a joint network whose six parameters are produced by `fan_in_uniform`.
    ///
    /// Despite the name, the tensors are initialised, not left blank. The exact
    /// distribution is whatever `crate::init::fan_in_uniform` implements; presumably it
    /// is bounded by `1/sqrt(fan_in)`, as in PyTorch's `nn.Linear` (confirm in `crate::init`).
    ///
    /// Resulting shapes (all `Float32`):
    /// * `enc_w`: `[joint_hidden, enc_hidden]`, `enc_b`: `[joint_hidden]`
    /// * `pred_w`: `[joint_hidden, pred_hidden]`, `pred_b`: `[joint_hidden]`
    /// * `out_w`: `[num_classes, joint_hidden]`, `out_b`: `[num_classes]`
    pub fn empty(enc_hidden: usize, pred_hidden: usize, joint_hidden: usize, num_classes: usize) -> Self {
        // Keep this order: the initialiser may consume shared random state, so the
        // call sequence determines which values end up in which tensor.
        let enc_w = fan_in_uniform(&[joint_hidden, enc_hidden], enc_hidden, DType::Float32);
        let enc_b = fan_in_uniform(&[joint_hidden], enc_hidden, DType::Float32);
        let pred_w = fan_in_uniform(&[joint_hidden, pred_hidden], pred_hidden, DType::Float32);
        let pred_b = fan_in_uniform(&[joint_hidden], pred_hidden, DType::Float32);
        let out_w = fan_in_uniform(&[num_classes, joint_hidden], joint_hidden, DType::Float32);
        let out_b = fan_in_uniform(&[num_classes], joint_hidden, DType::Float32);

        Self { enc_w, enc_b, pred_w, pred_b, out_w, out_b }
    }

    /// Runs the joint over one encoder-side input and one `pred`-side input.
    ///
    /// Computes `log_softmax(out(relu(enc(enc_t) + pred(g))))`, normalised over the last axis.
    /// The two projections are combined with `try_add`, so their shapes must be
    /// compatible for that op. NOTE(review): confirm the broadcast layout the decoding
    /// loop relies on.
    ///
    /// # Errors
    /// Any failing tensor operation is wrapped with the `TensorSnafu` context.
    pub fn forward(&self, enc_t: &Tensor, g: &Tensor) -> Result<Tensor> {
        let enc_side = Self::dense(enc_t, &self.enc_w, &self.enc_b)?;
        let pred_side = Self::dense(g, &self.pred_w, &self.pred_b)?;

        let joined = enc_side.try_add(&pred_side).context(TensorSnafu)?;
        let hidden = joined.relu().context(TensorSnafu)?;

        let scores = Self::dense(&hidden, &self.out_w, &self.out_b)?;
        scores.log_softmax(-1isize).context(TensorSnafu)
    }

    /// Applies the tensor `linear` op with the given weight and bias, tagging any failure
    /// with the `TensorSnafu` context.
    fn dense(input: &Tensor, weight: &Tensor, bias: &Tensor) -> Result<Tensor> {
        input.linear().weight(weight).bias(bias).call().context(TensorSnafu)
    }
}
impl HasStateDict for RnntJoint {
    /// Collects the six joint parameters into a fresh `StateDict`.
    ///
    /// Entries are added by `state_field!` in the order `enc_w, enc_b, pred_w, pred_b,
    /// out_w, out_b`. Keys are presumably built from `prefix` plus the field name;
    /// see the macro for the exact format.
    fn state_dict(&self, prefix: &str) -> StateDict {
        let mut dict = StateDict::new();
        state_field!(dict, prefix, self, [enc_w, enc_b, pred_w, pred_b, out_w, out_b]);
        dict
    }

    /// Overwrites the six joint parameters with the entries of `sd` under `prefix`.
    ///
    /// # Errors
    /// Any `state::Error` must originate inside `load_state_field!`, because this body
    /// otherwise only returns `Ok(())`. The kinds of failure (missing key, shape or dtype
    /// mismatch) depend on that macro. TODO confirm.
    fn load_state_dict(&mut self, sd: &StateDict, prefix: &str) -> std::result::Result<(), state::Error> {
        load_state_field!(self, sd, prefix, [enc_w, enc_b, pred_w, pred_b, out_w, out_b]);
        Ok(())
    }
}