use std::borrow::Borrow;
use tch::nn::{Init, Path};
use tch::Tensor;
/// Configuration for a [`PairwiseBilinear`] layer.
#[derive(Debug, Clone, Copy)]
pub struct PairwiseBilinearConfig {
    /// Size of the first and last axes of the layer's weight tensor.
    ///
    /// Must be greater than zero; `PairwiseBilinear::new` panics otherwise.
    pub in_features: i64,

    /// Size of the middle axis of the layer's weight tensor.
    ///
    /// Must be greater than zero; `PairwiseBilinear::new` panics otherwise.
    pub out_features: i64,

    /// Standard deviation of the zero-mean normal distribution from which
    /// the weight tensor is initialized.
    pub initializer_range: f64,
}
/// Pairwise bilinear layer.
///
/// Holds the single trainable tensor that `forward` combines with its two
/// inputs. Instances are created with `PairwiseBilinear::new`.
#[derive(Debug)]
pub struct PairwiseBilinear {
// Trainable variable registered under the name "weight", created in `new`
// with shape `[in_features, out_features, in_features]`.
weight: Tensor,
}
impl PairwiseBilinear {
    /// Construct a pairwise bilinear layer.
    ///
    /// Registers a variable named `"weight"` in `vs` with shape
    /// `[in_features, out_features, in_features]`, initialized from a normal
    /// distribution with mean `0` and standard deviation
    /// `config.initializer_range`.
    ///
    /// # Panics
    ///
    /// Panics if `config.in_features` or `config.out_features` is not
    /// strictly positive.
    pub fn new<'a>(vs: impl Borrow<Path<'a>>, config: &PairwiseBilinearConfig) -> Self {
        assert!(
            config.in_features > 0,
            "in_features should be > 0, was: {}",
            config.in_features,
        );

        assert!(
            config.out_features > 0,
            "out_features should be > 0, was: {}",
            config.out_features,
        );

        let vs = vs.borrow();

        let weight = vs.var(
            "weight",
            &[config.in_features, config.out_features, config.in_features],
            Init::Randn {
                mean: 0.,
                stdev: config.initializer_range,
            },
        );

        PairwiseBilinear { weight }
    }

    /// Apply the bilinear transformation to every pair of positions.
    ///
    /// Both inputs must have the same three-dimensional shape
    /// `[batch_size, seq_len, in_features]`. The result has shape
    /// `[batch_size, seq_len, seq_len, out_features]`, where
    ///
    /// ```text
    /// output[b, l, m, f] = sum over u, v of
    ///     input1[b, m, u] * weight[u, f, v] * input2[b, l, v]
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if the inputs differ in shape or are not three-dimensional.
    pub fn forward(&self, input1: &Tensor, input2: &Tensor) -> Tensor {
        assert_eq!(
            input1.size(),
            input2.size(),
            "Inputs to Bilinear must have the same shape: {:?} {:?}",
            input1.size(),
            input2.size()
        );

        assert_eq!(
            input1.dim(),
            3,
            "Shape should have 3 dimensions, has: {}",
            input1.dim()
        );

        // The two feature axes of the weight need distinct subscripts (`u`
        // and `v`). Repeating one subscript within a single operand (as in
        // "hfh") makes einsum take the diagonal over those axes, which would
        // discard every off-diagonal weight.
        //
        // Contract the features of `input1` with the first weight axis:
        // [batch_size, seq_len, in] x [in, out, in]
        //     -> [batch_size, seq_len, out, in]
        let intermediate = Tensor::einsum("bmu,ufv->bmfv", &[input1, &self.weight]);

        // Contract the remaining feature axis with `input2`, pairing every
        // position `l` of `input2` with every position `m` of `input1`:
        //     -> [batch_size, seq_len, seq_len, out]
        Tensor::einsum("blv,bmfv->blmf", &[input2, &intermediate])
    }
}
#[cfg(test)]
mod tests {
    use tch::nn::VarStore;
    use tch::{Device, Kind, Tensor};

    use crate::model::pairwise_bilinear::{PairwiseBilinear, PairwiseBilinearConfig};

    /// Two `[64, 10, 200]` inputs must yield a `[64, 10, 10, 5]` output.
    #[test]
    fn bilinear_correct_shapes() {
        let config = PairwiseBilinearConfig {
            in_features: 200,
            out_features: 5,
            initializer_range: 0.02,
        };

        let options = (Kind::Float, Device::Cpu);
        let lhs = Tensor::rand(&[64, 10, 200], options);
        let rhs = Tensor::rand(&[64, 10, 200], options);

        let var_store = VarStore::new(Device::Cpu);
        let layer = PairwiseBilinear::new(var_store.root(), &config);

        let output = layer.forward(&lhs, &rhs);
        assert_eq!(output.size(), &[64, 10, 10, 5]);
    }
}