border_candle_agent/model.rs
1//! Interface of neural networks used in RL agents.
2// use anyhow::Result;
3// use candle_core::Tensor;
4use candle_nn::VarBuilder;
5// use std::path::Path;
6// use tch::{nn, nn::VarStore, Tensor};
7
8/// Neural network model not owing its [`VarMap`] internally.
9///
10/// [`VarMap`]: https://docs.rs/candle-nn/0.4.1/candle_nn/var_map/struct.VarMap.html
11pub trait SubModel1 {
12 /// Configuration from which [`SubModel1`] is constructed.
13 type Config;
14
15 /// Input of the [`SubModel1`].
16 type Input;
17
18 /// Output of the [`SubModel1`].
19 type Output;
20
21 /// Builds [`SubModel1`] with [`VarBuilder`] and [`SubModel1::Config`].
22 ///
23 /// [`VarBuilder`]: https://docs.rs/candle-nn/0.4.1/candle_nn/var_builder/type.VarBuilder.html
24 fn build(vb: VarBuilder, config: Self::Config) -> Self;
25
26 /// A generalized forward function.
27 fn forward(&self, input: &Self::Input) -> Self::Output;
28}
29
30/// Neural network model not owing its [`VarMap`] internally.
31///
32/// The difference from [`SubModel1`] is that this trait takes two inputs.
33///
34/// [`VarMap`]: https://docs.rs/candle-nn/0.4.1/candle_nn/var_map/struct.VarMap.html
35pub trait SubModel2 {
36 /// Configuration from which [`SubModel2`] is constructed.
37 type Config;
38
39 /// Input of the [`SubModel2`].
40 type Input1;
41
42 /// Input of the [`SubModel2`].
43 type Input2;
44
45 /// Output of the [`SubModel2`].
46 type Output;
47
48 /// Builds [`SubModel2`].
49 fn build(vb: VarBuilder, config: Self::Config) -> Self;
50
51 /// A generalized forward function.
52 fn forward(&self, input1: &Self::Input1, input2: &Self::Input2) -> Self::Output;
53}
54
// /// Base interface of a neural network model owning its [`VarMap`].
56// ///
57// /// [`VarMap`]: candle_nn::VarMap
58// pub trait ModelBase {
59// /// Trains the network given a loss.
60// fn backward_step(&mut self, loss: &Tensor);
61
62// /// Returns `var_store` as mutable reference.
63// fn get_var_store_mut(&mut self) -> &mut nn::VarStore;
64
65// /// Returns `var_store`.
66// fn get_var_store(&self) -> &nn::VarStore;
67
68// /// Save parameters of the neural network.
69// fn save<T: AsRef<Path>>(&self, path: T) -> Result<()>;
70
71// /// Load parameters of the neural network.
72// fn load<T: AsRef<Path>>(&mut self, path: T) -> Result<()>;
73// }
74
75// /// Neural networks with a single input and a single output.
76// pub trait Model1: ModelBase {
77// /// The input of the neural network.
78// type Input;
79// /// The output of the neural network.
80// type Output;
81
82// /// Performs forward computation given an input.
83// fn forward(&self, xs: &Self::Input) -> Self::Output;
84
85// // /// TODO: check places this method is used in code.
86// // fn in_shape(&self) -> &[usize];
87
88// // /// TODO: check places this method is used in code.
89// // fn out_dim(&self) -> usize;
90// }
91
92// /// Neural networks with double inputs and a single output.
93// pub trait Model2: ModelBase {
94// /// An input of the neural network.
95// type Input1;
96// /// The other input of the neural network.
97// type Input2;
98// /// The output of the neural network.
99// type Output;
100
101// /// Performs forward computation given a pair of inputs.
102// fn forward(&self, x1s: &Self::Input1, x2s: &Self::Input2) -> Self::Output;
103// }
104
105// /// Neural network model that can be initialized with [VarStore] and configuration.
106// ///
107// /// The purpose of this trait is for modularity of neural network models.
// /// Modules that make up a neural network should share a [VarStore].
109// /// To do this, structs implementing this trait can be initialized with a given [VarStore].
110// /// This trait also provide the ability to clone with a given [VarStore].
111// /// The ability is useful when creating a target network, used in recent deep learning algorithms in common.
112// pub trait SubModel {
113// /// Configuration from which [SubModel] is constructed.
114// type Config;
115
116// /// Input of the [SubModel].
117// type Input;
118
119// /// Output of the [SubModel].
120// type Output;
121
122// /// Builds [SubModel] with [VarStore] and [SubModel::Config].
123// fn build(var_store: &VarStore, config: Self::Config) -> Self;
124
125// /// Clones [SubModel] with [VarStore].
126// fn clone_with_var_store(&self, var_store: &VarStore) -> Self;
127
128// /// A generalized forward function.
129// fn forward(&self, input: &Self::Input) -> Self::Output;
130// }
131
132// /// Neural network model that can be initialized with [VarStore] and configuration.
133// ///
134// /// The difference from [SubModel] is that this trait takes two inputs.
135// pub trait SubModel2 {
136// /// Configuration from which [SubModel2] is constructed.
137// type Config;
138
139// /// Input of the [SubModel2].
140// type Input1;
141
142// /// Input of the [SubModel2].
143// type Input2;
144
145// /// Output of the [SubModel2].
146// type Output;
147
148// /// Builds [SubModel2] with [VarStore] and [SubModel2::Config].
149// fn build(var_store: &VarStore, config: Self::Config) -> Self;
150
151// /// Clones [SubModel2] with [VarStore].
152// fn clone_with_var_store(&self, var_store: &VarStore) -> Self;
153
154// /// A generalized forward function.
155// fn forward(&self, input1: &Self::Input1, input2: &Self::Input2) -> Self::Output;
156// }