1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
//! Trait definitions for the interfaces of neural network models.
use anyhow::Result;
use std::path::Path;
use tch::{nn, nn::VarStore, Tensor};

/// Base interface shared by neural network models.
///
/// Exposes the model's [nn::VarStore], performs a training step from a loss tensor,
/// and saves/loads the network's parameters.
pub trait ModelBase {
    /// Returns a shared reference to the model's `var_store`.
    fn get_var_store(&self) -> &nn::VarStore;

    /// Returns a mutable reference to the model's `var_store`.
    fn get_var_store_mut(&mut self) -> &mut nn::VarStore;

    /// Trains the network using the given `loss` tensor.
    fn backward_step(&mut self, loss: &Tensor);

    /// Writes the parameters of the neural network to `path`.
    fn save<P>(&self, path: P) -> Result<()>
    where
        P: AsRef<Path>;

    /// Reads the parameters of the neural network from `path`.
    fn load<P>(&mut self, path: P) -> Result<()>
    where
        P: AsRef<Path>;
}

/// Neural networks with a single input and a single output.
pub trait Model1: ModelBase {
    /// The input of the neural network.
    type Input;
    /// The output of the neural network.
    type Output;

    /// Performs forward computation given an input.
    fn forward(&self, xs: &Self::Input) -> Self::Output;

    /// Returns the input shape of the network as a slice of dimension sizes.
    ///
    /// NOTE(review): this description is inferred from the method name; whether the
    /// shape includes a batch dimension is not specified by this trait.
    /// TODO: audit the call sites of this method.
    fn in_shape(&self) -> &[usize];

    /// Returns the output dimension of the network.
    ///
    /// NOTE(review): this description is inferred from the method name; which output
    /// axis it refers to is not specified by this trait.
    /// TODO: audit the call sites of this method.
    fn out_dim(&self) -> usize;
}

/// Neural networks that take two inputs and produce a single output.
pub trait Model2: ModelBase {
    /// First input of the neural network (the `x1s` argument of [Model2::forward]).
    type Input1;
    /// Second input of the neural network (the `x2s` argument of [Model2::forward]).
    type Input2;
    /// Output of the neural network.
    type Output;

    /// Runs the forward computation on the pair of inputs `(x1s, x2s)`.
    fn forward(&self, x1s: &Self::Input1, x2s: &Self::Input2) -> Self::Output;
}

/// Neural network model that can be initialized with a [VarStore] and a configuration.
///
/// This trait exists to make neural network models modular.
/// The modules that make up a neural network should share a [VarStore],
/// so types implementing this trait are constructed from a given [VarStore].
/// The trait can also clone a model using a given [VarStore]. This is useful for
/// creating target networks, a technique commonly used in recent deep learning algorithms.
pub trait SubModel {
    /// Configuration from which a [SubModel] is constructed.
    type Config;

    /// Input of the [SubModel].
    type Input;

    /// Output of the [SubModel].
    type Output;

    /// Builds a [SubModel] from the given [VarStore] and [SubModel::Config].
    fn build(var_store: &VarStore, config: Self::Config) -> Self;

    /// Clones this [SubModel], using the given [VarStore] for the copy.
    fn clone_with_var_store(&self, var_store: &VarStore) -> Self;

    /// Performs the forward computation on `input`.
    fn forward(&self, input: &Self::Input) -> Self::Output;
}

/// Neural network model that can be initialized with a [VarStore] and a configuration.
///
/// Unlike [SubModel], the forward computation of this trait takes two inputs.
pub trait SubModel2 {
    /// Configuration from which a [SubModel2] is constructed.
    type Config;

    /// First input of the [SubModel2] (the `input1` argument of [SubModel2::forward]).
    type Input1;

    /// Second input of the [SubModel2] (the `input2` argument of [SubModel2::forward]).
    type Input2;

    /// Output of the [SubModel2].
    type Output;

    /// Builds a [SubModel2] from the given [VarStore] and [SubModel2::Config].
    fn build(var_store: &VarStore, config: Self::Config) -> Self;

    /// Clones this [SubModel2], using the given [VarStore] for the copy.
    fn clone_with_var_store(&self, var_store: &VarStore) -> Self;

    /// Performs the forward computation on the pair `(input1, input2)`.
    fn forward(&self, input1: &Self::Input1, input2: &Self::Input2) -> Self::Output;
}