1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
//! Interfaces (traits) implemented by neural network models.
use anyhow::Result;
use std::path::Path;
use tch::{nn, nn::VarStore, Tensor};

/// Base interface shared by every neural network model.
///
/// Implementors own an [`nn::VarStore`] holding their parameters and expose it,
/// together with a training step and parameter persistence.
pub trait ModelBase {
    /// Returns a shared reference to the model's `var_store`.
    fn get_var_store(&self) -> &nn::VarStore;

    /// Returns an exclusive (mutable) reference to the model's `var_store`.
    fn get_var_store_mut(&mut self) -> &mut nn::VarStore;

    /// Trains the network using the given `loss`.
    ///
    /// How the update is performed is left to the implementor.
    fn backward_step(&mut self, loss: &Tensor);

    /// Saves the parameters of the neural network to `path`.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the implementation fails to save the parameters.
    fn save<T>(&self, path: T) -> Result<()>
    where
        T: AsRef<Path>;

    /// Loads the parameters of the neural network from `path`.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the implementation fails to load the parameters.
    fn load<T>(&mut self, path: T) -> Result<()>
    where
        T: AsRef<Path>;
}

/// A neural network that maps a single input to a single output.
pub trait Model1: ModelBase {
    /// Type returned by [`Model1::forward`].
    type Output;
    /// Type accepted by [`Model1::forward`].
    type Input;

    /// Runs the forward computation on `xs` and returns the network output.
    fn forward(&self, xs: &Self::Input) -> Self::Output;

    /// Presumably the shape of the network input, as a list of dimension sizes.
    ///
    /// NOTE(review): whether a batch dimension is included is up to the
    /// implementor — TODO: check places this method is used in code.
    fn in_shape(&self) -> &[usize];

    /// Presumably the size of the network output.
    ///
    /// NOTE(review): exact meaning is defined by the implementor — TODO: check
    /// places this method is used in code.
    fn out_dim(&self) -> usize;
}

/// A neural network that maps a pair of inputs to a single output.
pub trait Model2: ModelBase {
    /// Type returned by [`Model2::forward`].
    type Output;
    /// Type of the first argument of [`Model2::forward`].
    type Input1;
    /// Type of the second argument of [`Model2::forward`].
    type Input2;

    /// Runs the forward computation on the pair (`x1s`, `x2s`) and returns the
    /// network output.
    fn forward(&self, x1s: &Self::Input1, x2s: &Self::Input2) -> Self::Output;
}

/// Neural network model that can be constructed from a [VarStore] and a configuration.
///
/// This trait exists to make neural network models modular.
/// The modules that together form one neural network should share a single [VarStore],
/// so types implementing this trait are built from a caller-supplied [VarStore].
/// The trait also allows cloning a model into a given [VarStore], which is useful
/// for creating a target network, as many recent deep learning algorithms do.
pub trait SubModel {
    /// Input accepted by [`SubModel::forward`].
    type Input;

    /// Output returned by [`SubModel::forward`].
    type Output;

    /// Configuration from which a [SubModel] is built.
    type Config;

    /// Generalized forward computation mapping `input` to an output.
    fn forward(&self, input: &Self::Input) -> Self::Output;

    /// Builds a [SubModel] in `var_store` according to `config`.
    fn build(var_store: &VarStore, config: Self::Config) -> Self;

    /// Clones this [SubModel] into `var_store`.
    ///
    /// How the parameters are copied is left to the implementor.
    fn clone_with_var_store(&self, var_store: &VarStore) -> Self;
}

/// Neural network model that can be constructed from a [VarStore] and a configuration.
///
/// Unlike [SubModel], the forward computation of this trait takes two inputs.
pub trait SubModel2 {
    /// First input accepted by [`SubModel2::forward`].
    type Input1;

    /// Second input accepted by [`SubModel2::forward`].
    type Input2;

    /// Output returned by [`SubModel2::forward`].
    type Output;

    /// Configuration from which a [SubModel2] is built.
    type Config;

    /// Generalized forward computation mapping (`input1`, `input2`) to an output.
    fn forward(&self, input1: &Self::Input1, input2: &Self::Input2) -> Self::Output;

    /// Builds a [SubModel2] in `var_store` according to `config`.
    fn build(var_store: &VarStore, config: Self::Config) -> Self;

    /// Clones this [SubModel2] into `var_store`.
    ///
    /// How the parameters are copied is left to the implementor.
    fn clone_with_var_store(&self, var_store: &VarStore) -> Self;
}