border_candle_agent/mlp/
mlp3.rs

1use super::{mlp_forward, MlpConfig};
2use crate::model::SubModel1;
3use anyhow::Result;
4use candle_core::{Device, Tensor};
5use candle_nn::{init::Init, linear, Linear, VarBuilder};
6
7/// Returns vector of linear modules from [`MlpConfig`].
8fn create_linear_layers(prefix: &str, vs: VarBuilder, config: &MlpConfig) -> Result<Vec<Linear>> {
9    let mut in_out_pairs: Vec<(i64, i64)> = (0..config.units.len() - 1)
10        .map(|i| (config.units[i], config.units[i + 1]))
11        .collect();
12    in_out_pairs.insert(0, (config.in_dim, config.units[0]));
13    in_out_pairs.push((*config.units.last().unwrap(), config.out_dim));
14    let vs = vs.pp(prefix);
15
16    Ok(in_out_pairs
17        .iter()
18        .enumerate()
19        .map(|(i, &(in_dim, out_dim))| {
20            linear(in_dim as _, out_dim as _, vs.pp(format!("ln{}", i))).unwrap()
21        })
22        .collect())
23}
24
25/// A module with two heads.
26///
27/// The one is MLP, while the other is just a set of parameters.
28/// This is based on the code of a neural network used in AWAC of CORL:
29/// <https://github.com/tinkoff-ai/CORL/blob/6afec90484bbf47dee05fdf525e26a3ebe028e9b/algorithms/offline/awac.py>
30pub struct Mlp3 {
31    _config: MlpConfig,
32    device: Device,
33    layers: Vec<Linear>,
34    head2: Tensor,
35}
36
37impl SubModel1 for Mlp3 {
38    type Config = MlpConfig;
39    type Input = Tensor;
40    type Output = (Tensor, Tensor);
41
42    fn forward(&self, xs: &Self::Input) -> Self::Output {
43        let batch_size = xs.dims()[0];
44        let xs = xs.to_device(&self.device).unwrap();
45        let ys = mlp_forward(xs, &self.layers, &crate::Activation::None);
46        let zs = self.head2.repeat((batch_size, 1)).unwrap();
47        (ys, zs)
48    }
49
50    fn build(vs: VarBuilder, config: Self::Config) -> Self {
51        let device = vs.device().clone();
52        let head2 = vs
53            .get_with_hints((1, config.out_dim as usize), "head2", Init::Const(0.))
54            .unwrap()
55            .to_device(&device)
56            .unwrap();
57        let layers = create_linear_layers("mlp", vs, &config).unwrap();
58
59        Mlp3 {
60            _config: config,
61            device,
62            layers,
63            head2,
64        }
65    }
66}