border_tch_agent/model/
base.rs

//! Definitions of neural network interfaces.
2use anyhow::Result;
3use std::path::Path;
4use tch::{nn, nn::VarStore, Tensor};
5
6/// Base interface.
7pub trait ModelBase {
8    /// Trains the network given a loss.
9    fn backward_step(&mut self, loss: &Tensor);
10
11    /// Returns `var_store` as mutable reference.
12    fn get_var_store_mut(&mut self) -> &mut nn::VarStore;
13
14    /// Returns `var_store`.
15    fn get_var_store(&self) -> &nn::VarStore;
16
17    /// Save parameters of the neural network.
18    fn save<T: AsRef<Path>>(&self, path: T) -> Result<()>;
19
20    /// Load parameters of the neural network.
21    fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()>;
22}
23
24/// Neural networks with a single input and a single output.
25#[allow(dead_code)]
26pub trait Model1: ModelBase {
27    /// The input of the neural network.
28    type Input;
29    /// The output of the neural network.
30    type Output;
31
32    /// Performs forward computation given an input.
33    fn forward(&self, xs: &Self::Input) -> Self::Output;
34
35    /// TODO: check places this method is used in code.
36    fn in_shape(&self) -> &[usize];
37
38    /// TODO: check places this method is used in code.
39    fn out_dim(&self) -> usize;
40}
41
42/// Neural networks with double inputs and a single output.
43#[allow(dead_code)]
44pub trait Model2: ModelBase {
45    /// An input of the neural network.
46    type Input1;
47    /// The other input of the neural network.
48    type Input2;
49    /// The output of the neural network.
50    type Output;
51
52    /// Performs forward computation given a pair of inputs.
53    fn forward(&self, x1s: &Self::Input1, x2s: &Self::Input2) -> Self::Output;
54}
55
56/// Neural network model that can be initialized with [`VarStore`] and configuration.
57///
58/// The purpose of this trait is for modularity of neural network models.
59/// Modules, which consists a neural network, should share [`VarStore`].
60/// To do this, structs implementing this trait can be initialized with a given [`VarStore`].
61/// This trait also provide the ability to clone with a given [`VarStore`].
62/// The ability is useful when creating a target network, used in recent deep learning algorithms in common.
63///
64/// [`VarStore`]: https://docs.rs/tch/0.16.0/tch/nn/struct.VarStore.html
65pub trait SubModel {
66    /// Configuration from which [`SubModel`] is constructed.
67    type Config;
68
69    /// Input of the [`SubModel`].
70    type Input;
71
72    /// Output of the [`SubModel`].
73    type Output;
74
75    /// Builds [`SubModel`] with [`VarStore`] and [`SubModel::Config`].
76    fn build(var_store: &VarStore, config: Self::Config) -> Self;
77
78    /// Clones [`SubModel`] with [`VarStore`].
79    ///
80    /// [`VarStore`]: https://docs.rs/tch/0.16.0/tch/nn/struct.VarStore.html
81    fn clone_with_var_store(&self, var_store: &VarStore) -> Self;
82
83    /// A generalized forward function.
84    fn forward(&self, input: &Self::Input) -> Self::Output;
85}
86
87/// Neural network model that can be initialized with [`VarStore`] and configuration.
88///
89/// The difference from [`SubModel`] is that this trait takes two inputs.
90///
91/// [`VarStore`]: https://docs.rs/tch/0.16.0/tch/nn/struct.VarStore.html
92pub trait SubModel2 {
93    /// Configuration from which [`SubModel2`] is constructed.
94    type Config;
95
96    /// Input of the [`SubModel2`].
97    type Input1;
98
99    /// Input of the [`SubModel2`].
100    type Input2;
101
102    /// Output of the [`SubModel2`].
103    type Output;
104
105    /// Builds [`SubModel2`] with [VarStore] and [SubModel2::Config].
106    fn build(var_store: &VarStore, config: Self::Config) -> Self;
107
108    /// Clones [`SubModel2`] with [`VarStore`].
109    ///
110    /// [`VarStore`]: https://docs.rs/tch/0.16.0/tch/nn/struct.VarStore.html
111    fn clone_with_var_store(&self, var_store: &VarStore) -> Self;
112
113    /// A generalized forward function.
114    fn forward(&self, input1: &Self::Input1, input2: &Self::Input2) -> Self::Output;
115}