border_tch_agent/model/base.rs
1//! Definition of interfaces of neural networks.
2use anyhow::Result;
3use std::path::Path;
4use tch::{nn, nn::VarStore, Tensor};
5
6/// Base interface.
7pub trait ModelBase {
8 /// Trains the network given a loss.
9 fn backward_step(&mut self, loss: &Tensor);
10
11 /// Returns `var_store` as mutable reference.
12 fn get_var_store_mut(&mut self) -> &mut nn::VarStore;
13
14 /// Returns `var_store`.
15 fn get_var_store(&self) -> &nn::VarStore;
16
17 /// Save parameters of the neural network.
18 fn save<T: AsRef<Path>>(&self, path: T) -> Result<()>;
19
20 /// Load parameters of the neural network.
21 fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()>;
22}
23
24/// Neural networks with a single input and a single output.
25#[allow(dead_code)]
26pub trait Model1: ModelBase {
27 /// The input of the neural network.
28 type Input;
29 /// The output of the neural network.
30 type Output;
31
32 /// Performs forward computation given an input.
33 fn forward(&self, xs: &Self::Input) -> Self::Output;
34
35 /// TODO: check places this method is used in code.
36 fn in_shape(&self) -> &[usize];
37
38 /// TODO: check places this method is used in code.
39 fn out_dim(&self) -> usize;
40}
41
42/// Neural networks with double inputs and a single output.
43#[allow(dead_code)]
44pub trait Model2: ModelBase {
45 /// An input of the neural network.
46 type Input1;
47 /// The other input of the neural network.
48 type Input2;
49 /// The output of the neural network.
50 type Output;
51
52 /// Performs forward computation given a pair of inputs.
53 fn forward(&self, x1s: &Self::Input1, x2s: &Self::Input2) -> Self::Output;
54}
55
56/// Neural network model that can be initialized with [`VarStore`] and configuration.
57///
58/// The purpose of this trait is for modularity of neural network models.
59/// Modules, which consists a neural network, should share [`VarStore`].
60/// To do this, structs implementing this trait can be initialized with a given [`VarStore`].
61/// This trait also provide the ability to clone with a given [`VarStore`].
62/// The ability is useful when creating a target network, used in recent deep learning algorithms in common.
63///
64/// [`VarStore`]: https://docs.rs/tch/0.16.0/tch/nn/struct.VarStore.html
65pub trait SubModel {
66 /// Configuration from which [`SubModel`] is constructed.
67 type Config;
68
69 /// Input of the [`SubModel`].
70 type Input;
71
72 /// Output of the [`SubModel`].
73 type Output;
74
75 /// Builds [`SubModel`] with [`VarStore`] and [`SubModel::Config`].
76 fn build(var_store: &VarStore, config: Self::Config) -> Self;
77
78 /// Clones [`SubModel`] with [`VarStore`].
79 ///
80 /// [`VarStore`]: https://docs.rs/tch/0.16.0/tch/nn/struct.VarStore.html
81 fn clone_with_var_store(&self, var_store: &VarStore) -> Self;
82
83 /// A generalized forward function.
84 fn forward(&self, input: &Self::Input) -> Self::Output;
85}
86
87/// Neural network model that can be initialized with [`VarStore`] and configuration.
88///
89/// The difference from [`SubModel`] is that this trait takes two inputs.
90///
91/// [`VarStore`]: https://docs.rs/tch/0.16.0/tch/nn/struct.VarStore.html
92pub trait SubModel2 {
93 /// Configuration from which [`SubModel2`] is constructed.
94 type Config;
95
96 /// Input of the [`SubModel2`].
97 type Input1;
98
99 /// Input of the [`SubModel2`].
100 type Input2;
101
102 /// Output of the [`SubModel2`].
103 type Output;
104
105 /// Builds [`SubModel2`] with [VarStore] and [SubModel2::Config].
106 fn build(var_store: &VarStore, config: Self::Config) -> Self;
107
108 /// Clones [`SubModel2`] with [`VarStore`].
109 ///
110 /// [`VarStore`]: https://docs.rs/tch/0.16.0/tch/nn/struct.VarStore.html
111 fn clone_with_var_store(&self, var_store: &VarStore) -> Self;
112
113 /// A generalized forward function.
114 fn forward(&self, input1: &Self::Input1, input2: &Self::Input2) -> Self::Output;
115}