use super::super::{
BuildModule, Forward, Module, ModuleExtras, SeqIterative, SeqPacked, SeqSerial,
};
use crate::torch::initializers::Initializer;
use crate::torch::packed::PackedTensor;
use crate::torch::serialize::TensorDef;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use std::iter::{self, Chain, Once};
use std::option;
use tch::{Device, Tensor};
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub struct LinearConfig {
kernel_init: Initializer,
bias_init: Option<Initializer>,
}
impl Default for LinearConfig {
fn default() -> Self {
Self {
kernel_init: Initializer::default(),
bias_init: Some(Initializer::default()),
}
}
}
impl BuildModule for LinearConfig {
type Module = Linear;
fn build_module(&self, in_dim: usize, out_dim: usize, device: Device) -> Self::Module {
Linear::new(in_dim, out_dim, device, self)
}
}
#[serde_as]
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct Linear {
#[serde_as(as = "TensorDef")]
kernel: Tensor,
#[serde_as(as = "Option<TensorDef>")]
bias: Option<Tensor>,
}
impl Linear {
#[must_use]
pub fn new(in_dim: usize, out_dim: usize, device: Device, config: &LinearConfig) -> Self {
let fan_in = in_dim + 1;
Self {
kernel: config
.kernel_init
.tensor(&[out_dim, in_dim])
.device(device)
.fan_in(fan_in)
.build(),
bias: config
.bias_init
.map(|b| b.tensor(&[out_dim]).device(device).fan_in(fan_in).build()),
}
}
}
impl Module for Linear {
fn shallow_clone(&self) -> Self
where
Self: Sized,
{
Self {
kernel: self.kernel.shallow_clone(),
bias: self.bias.as_ref().map(Tensor::shallow_clone),
}
}
fn clone_to_device(&self, device: Device) -> Self
where
Self: Sized,
{
Self {
kernel: self.kernel.to_device(device),
bias: self.bias.as_ref().map(|b| b.to_device(device)),
}
}
#[inline]
fn variables(&self) -> Box<dyn Iterator<Item = &Tensor> + '_> {
Box::new(ModuleExtras::variables(self))
}
#[inline]
fn trainable_variables(&self) -> Box<dyn Iterator<Item = &Tensor> + '_> {
Box::new(ModuleExtras::trainable_variables(self))
}
}
impl<'a> ModuleExtras<'a> for Linear {
type Variables = Chain<Once<&'a Tensor>, option::Iter<'a, Tensor>>;
type TrainableVariables = Self::Variables;
#[inline]
fn variables(&'a self) -> Self::Variables {
iter::once(&self.kernel).chain(self.bias.iter())
}
#[inline]
fn trainable_variables(&'a self) -> Self::TrainableVariables {
ModuleExtras::variables(self)
}
}
impl Forward for Linear {
#[inline]
fn forward(&self, input: &Tensor) -> Tensor {
input.linear(&self.kernel, self.bias.as_ref())
}
}
impl SeqSerial for Linear {
#[inline]
fn seq_serial(&self, inputs: &Tensor, _seq_lengths: &[usize]) -> Tensor {
self.forward(inputs)
}
}
impl SeqPacked for Linear {
#[inline]
fn seq_packed(&self, inputs: &PackedTensor) -> PackedTensor {
inputs.batch_map_ref(|tensor| self.forward(tensor))
}
}
impl SeqIterative for Linear {
type State = ();
#[inline]
fn initial_state(&self) -> Self::State {}
#[inline]
fn step(&self, _: &mut Self::State, input: &Tensor) -> Tensor {
self.forward(input)
}
}
#[cfg(test)]
#[allow(
clippy::needless_pass_by_value,
clippy::used_underscore_binding,
clippy::no_effect_underscore_binding
)]
mod tests {
use super::super::super::testing::{
self, RunForward, RunIterStep, RunModule, RunSeqPacked, RunSeqSerial,
};
use super::*;
use rstest::{fixture, rstest};
use tch::{kind::Kind, Device};
#[fixture]
fn default_module() -> (Linear, usize, usize) {
let in_dim = 3;
let out_dim = 2;
let config = LinearConfig::default();
let module = config.build_module(in_dim, out_dim, Device::Cpu);
(module, in_dim, out_dim)
}
#[fixture]
fn module_no_bias() -> (Linear, usize, usize) {
let in_dim = 3;
let out_dim = 2;
let config = LinearConfig {
bias_init: None,
..LinearConfig::default()
};
let module = config.build_module(in_dim, out_dim, Device::Cpu);
(module, in_dim, out_dim)
}
#[rstest]
fn forward_batch(default_module: (Linear, usize, usize)) {
let (module, in_dim, out_dim) = default_module;
testing::check_forward(&module, in_dim, out_dim, &[4], Kind::Float);
}
#[rstest]
fn seq_serial(default_module: (Linear, usize, usize)) {
let (module, in_dim, out_dim) = default_module;
testing::check_seq_serial(&module, in_dim, out_dim);
}
#[rstest]
fn seq_packed(default_module: (Linear, usize, usize)) {
let (module, in_dim, out_dim) = default_module;
testing::check_seq_packed(&module, in_dim, out_dim);
}
#[rstest]
fn seq_step(default_module: (Linear, usize, usize)) {
let (module, in_dim, out_dim) = default_module;
testing::check_step(&module, in_dim, out_dim);
}
#[rstest]
fn seq_consistent(default_module: (Linear, usize, usize)) {
let (module, in_dim, out_dim) = default_module;
testing::check_seq_packed_matches_iter_steps(&module, in_dim, out_dim);
}
#[rstest]
#[case::forward(RunForward)]
#[case::seq_serial(RunSeqSerial)]
#[case::seq_packed(RunSeqPacked)]
#[case::iter_step(RunIterStep)]
fn gradient_descent<R: RunModule<Linear>>(#[case] _runner: R) {
testing::check_config_gradient_descent::<R, _>(&LinearConfig::default());
}
#[rstest]
#[case::forward(RunForward)]
#[case::seq_serial(RunSeqSerial)]
#[case::seq_packed(RunSeqPacked)]
#[case::iter_step(RunIterStep)]
fn clone_to_new_device<R: RunModule<Linear>>(#[case] _runner: R) {
testing::check_config_clone_to_new_device::<R, _>(&LinearConfig::default());
}
#[test]
fn clone_to_same_device() {
testing::check_config_clone_to_same_device::<RunForward, _>(&LinearConfig::default());
}
#[rstest]
#[case::forward(RunForward)]
#[case::seq_serial(RunSeqSerial)]
#[case::seq_packed(RunSeqPacked)]
#[case::iter_step(RunIterStep)]
fn ser_de_matches<R: RunModule<Linear>>(
#[case] _runner: R,
default_module: (Linear, usize, usize),
) {
let (module, in_dim, _) = default_module;
testing::check_ser_de_matches::<R, _>(&module, in_dim);
}
#[rstest]
fn variables_count_default(default_module: (Linear, usize, usize)) {
let (module, _, _) = default_module;
assert_eq!(Module::variables(&module).count(), 2);
}
#[rstest]
fn variables_count_no_bias(module_no_bias: (Linear, usize, usize)) {
let (module, _, _) = module_no_bias;
assert_eq!(Module::variables(&module).count(), 1);
}
#[rstest]
fn trainable_variables_count_default(default_module: (Linear, usize, usize)) {
let (module, _, _) = default_module;
assert_eq!(Module::trainable_variables(&module).count(), 2);
}
#[rstest]
fn trainable_variables_count_no_bias(module_no_bias: (Linear, usize, usize)) {
let (module, _, _) = module_no_bias;
assert_eq!(Module::trainable_variables(&module).count(), 1);
}
}