//! border_candle_agent/mlp/mlp2.rs

1use super::{mlp_forward, MlpConfig};
2use crate::model::SubModel1;
3use anyhow::Result;
4use candle_core::{Device, Module, Tensor};
5use candle_nn::{linear, Linear, VarBuilder};
6
7/// Returns vector of linear modules from [`MlpConfig`].
8fn create_linear_layers(prefix: &str, vs: VarBuilder, config: &MlpConfig) -> Result<Vec<Linear>> {
9    let mut in_out_pairs: Vec<(i64, i64)> = (0..config.units.len() - 1)
10        .map(|i| (config.units[i], config.units[i + 1]))
11        .collect();
12    in_out_pairs.insert(0, (config.in_dim, config.units[0]));
13    let vs = vs.pp(prefix);
14
15    Ok(in_out_pairs
16        .iter()
17        .enumerate()
18        .map(|(i, &(in_dim, out_dim))| {
19            linear(in_dim as _, out_dim as _, vs.pp(format!("ln{}", i))).unwrap()
20        })
21        .collect())
22}
23
/// Multilayer perceptron that outputs two tensors of the same size.
///
/// A shared trunk of linear layers feeds two independent linear heads;
/// `forward` returns the first head's output directly and the
/// exponential of the second head's output (named `mean` and `std`
/// there — presumably parameters of a distribution; confirm at callers).
pub struct Mlp2 {
    // Kept only to retain the configuration; not read after construction.
    _config: MlpConfig,
    // Device that inputs are moved to before the forward pass.
    device: Device,
    // First output head (the `mean` output of `forward`).
    head1: Linear,
    // Second output head; its output is exponentiated in `forward`.
    head2: Linear,
    // Shared trunk of linear layers applied via `mlp_forward` with ReLU.
    layers: Vec<Linear>,
}
32
33impl SubModel1 for Mlp2 {
34    type Config = MlpConfig;
35    type Input = Tensor;
36    type Output = (Tensor, Tensor);
37
38    fn forward(&self, xs: &Self::Input) -> Self::Output {
39        let xs = xs.to_device(&self.device).unwrap();
40        let xs = mlp_forward(xs, &self.layers, &crate::Activation::ReLU);
41        let mean = self.head1.forward(&xs).unwrap();
42        let std = self.head2.forward(&xs).unwrap().exp().unwrap();
43        (mean, std)
44    }
45
46    fn build(vs: VarBuilder, config: Self::Config) -> Self {
47        let device = vs.device().clone();
48        let layers = create_linear_layers("mlp", vs.clone(), &config).unwrap();
49        let (head1, head2) = {
50            let in_dim = *config.units.last().unwrap();
51            let out_dim = config.out_dim;
52            let head1 = linear(in_dim as _, out_dim as _, vs.pp(format!("mean"))).unwrap();
53            let head2 = linear(in_dim as _, out_dim as _, vs.pp(format!("std"))).unwrap();
54            (head1, head2)
55        };
56
57        Self {
58            _config: config,
59            device,
60            head1,
61            head2,
62            layers,
63        }
64    }
65}