border_candle_agent/mlp/
base.rs

1use super::{mlp_forward, MlpConfig};
2use crate::model::{SubModel1, SubModel2};
3use anyhow::Result;
4use candle_core::{Device, Tensor, D};
5use candle_nn::{linear, Linear, VarBuilder};
6
7/// Returns vector of linear modules from [`MlpConfig`].
8fn create_linear_layers(prefix: &str, vs: VarBuilder, config: &MlpConfig) -> Result<Vec<Linear>> {
9    let mut in_out_pairs: Vec<(i64, i64)> = (0..config.units.len() - 1)
10        .map(|i| (config.units[i], config.units[i + 1]))
11        .collect();
12    in_out_pairs.insert(0, (config.in_dim, config.units[0]));
13    in_out_pairs.push((*config.units.last().unwrap(), config.out_dim));
14    let vs = vs.pp(prefix);
15
16    Ok(in_out_pairs
17        .iter()
18        .enumerate()
19        .map(|(i, &(in_dim, out_dim))| {
20            linear(in_dim as _, out_dim as _, vs.pp(format!("ln{}", i))).unwrap()
21        })
22        .collect())
23}
24
25/// Multilayer perceptron with ReLU activation function.
26pub struct Mlp {
27    config: MlpConfig,
28    device: Device,
29    layers: Vec<Linear>,
30}
31
32fn _build(vs: VarBuilder, config: MlpConfig) -> Mlp {
33    let device = vs.device().clone();
34    let layers = create_linear_layers("mlp", vs, &config).unwrap();
35
36    Mlp {
37        config,
38        device,
39        layers,
40    }
41}
42
43impl SubModel1 for Mlp {
44    type Config = MlpConfig;
45    type Input = Tensor;
46    type Output = Tensor;
47
48    fn forward(&self, xs: &Self::Input) -> Tensor {
49        let xs = xs.to_device(&self.device).unwrap();
50        mlp_forward(xs, &self.layers, &self.config.activation_out)
51    }
52
53    fn build(vs: VarBuilder, config: Self::Config) -> Self {
54        _build(vs, config)
55    }
56}
57
58impl SubModel2 for Mlp {
59    type Config = MlpConfig;
60    type Input1 = Tensor;
61    type Input2 = Tensor;
62    type Output = Tensor;
63
64    fn forward(&self, input1: &Self::Input1, input2: &Self::Input2) -> Self::Output {
65        let input1: Tensor = input1.to_device(&self.device).unwrap();
66        let input2: Tensor = input2.to_device(&self.device).unwrap();
67
68        let input = Tensor::cat(&[input1, input2], D::Minus1)
69            .unwrap()
70            .to_device(&self.device)
71            .unwrap();
72        mlp_forward(input, &self.layers, &self.config.activation_out)
73    }
74
75    fn build(vs: VarBuilder, config: Self::Config) -> Self {
76        _build(vs, config)
77    }
78}