//! Definition of interfaces of neural networks.
use anyhow::Result;
use std::path::Path;
use tch::{nn, nn::VarStore, Tensor};
/// Base interface shared by the neural network model traits.
///
/// It is a supertrait of both [Model1] and [Model2] and groups the operations
/// that do not depend on the number of inputs: a training step, access to the
/// underlying [nn::VarStore], and saving/loading of parameters.
pub trait ModelBase {
/// Trains the network given a loss.
///
/// How the update is carried out (optimizer, gradient handling) is left to
/// the implementor; this trait only fixes the signature.
fn backward_step(&mut self, loss: &Tensor);
/// Returns `var_store` as a mutable reference.
fn get_var_store_mut(&mut self) -> &mut nn::VarStore;
/// Returns `var_store` as a shared reference.
fn get_var_store(&self) -> &nn::VarStore;
/// Saves the parameters of the neural network to `path`.
///
/// # Errors
///
/// Returns an error when saving fails; the concrete failure conditions
/// depend on the implementor.
fn save<T: AsRef<Path>>(&self, path: T) -> Result<()>;
/// Loads the parameters of the neural network from `path`.
///
/// # Errors
///
/// Returns an error when loading fails; the concrete failure conditions
/// depend on the implementor.
fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()>;
}
/// Neural networks with a single input and a single output.
///
/// Implementors must also implement [ModelBase].
pub trait Model1: ModelBase {
/// The input type of the neural network.
type Input;
/// The output type of the neural network.
type Output;
/// Performs forward computation given an input and returns the output.
fn forward(&self, xs: &Self::Input) -> Self::Output;
/// Returns a slice of `usize` values.
///
/// NOTE(review): judging by the name, this is presumably the shape of the
/// network input; whether a batch dimension is included is not specified
/// here — confirm against implementors.
///
/// TODO: check places this method is used in code.
fn in_shape(&self) -> &[usize];
/// Returns a single `usize` value.
///
/// NOTE(review): judging by the name, this is presumably the dimension of
/// the network output — confirm against implementors.
///
/// TODO: check places this method is used in code.
fn out_dim(&self) -> usize;
}
/// Neural networks with two inputs and a single output.
///
/// Implementors must also implement [ModelBase].
pub trait Model2: ModelBase {
/// The type of the first input of the neural network.
type Input1;
/// The type of the second input of the neural network.
type Input2;
/// The output type of the neural network.
type Output;
/// Performs forward computation given a pair of inputs and returns the output.
fn forward(&self, x1s: &Self::Input1, x2s: &Self::Input2) -> Self::Output;
}
/// Neural network model that can be initialized with [VarStore] and configuration.
///
/// The purpose of this trait is modularity of neural network models.
/// Modules that make up a neural network should share a [VarStore].
/// To allow this, structs implementing this trait can be initialized with a given [VarStore].
/// This trait also provides the ability to clone a model with a given [VarStore].
/// That ability is useful when creating a target network, which is commonly used
/// in recent deep learning algorithms.
pub trait SubModel {
/// Configuration from which [SubModel] is constructed.
type Config;
/// Input of the [SubModel].
type Input;
/// Output of the [SubModel].
type Output;
/// Builds [SubModel] with [VarStore] and [SubModel::Config].
///
/// The configuration is taken by value; the [VarStore] is only borrowed.
fn build(var_store: &VarStore, config: Self::Config) -> Self;
/// Clones [SubModel] with the given [VarStore].
///
/// Returns a new value of the implementing type; `self` is only borrowed.
fn clone_with_var_store(&self, var_store: &VarStore) -> Self;
/// A generalized forward function mapping an input to an output.
fn forward(&self, input: &Self::Input) -> Self::Output;
}
/// Neural network model that can be initialized with [VarStore] and configuration.
///
/// The difference from [SubModel] is that this trait's forward function takes two inputs.
pub trait SubModel2 {
/// Configuration from which [SubModel2] is constructed.
type Config;
/// First input of the [SubModel2].
type Input1;
/// Second input of the [SubModel2].
type Input2;
/// Output of the [SubModel2].
type Output;
/// Builds [SubModel2] with [VarStore] and [SubModel2::Config].
///
/// The configuration is taken by value; the [VarStore] is only borrowed.
fn build(var_store: &VarStore, config: Self::Config) -> Self;
/// Clones [SubModel2] with the given [VarStore].
///
/// Returns a new value of the implementing type; `self` is only borrowed.
fn clone_with_var_store(&self, var_store: &VarStore) -> Self;
/// A generalized forward function mapping a pair of inputs to an output.
fn forward(&self, input1: &Self::Input1, input2: &Self::Input2) -> Self::Output;
}