// border_candle_agent/mlp/mlp2.rs
use super::{mlp_forward, MlpConfig};
use crate::model::SubModel1;
use anyhow::Result;
use candle_core::{Device, Module, Tensor};
use candle_nn::{linear, Linear, VarBuilder};

7fn create_linear_layers(prefix: &str, vs: VarBuilder, config: &MlpConfig) -> Result<Vec<Linear>> {
9 let mut in_out_pairs: Vec<(i64, i64)> = (0..config.units.len() - 1)
10 .map(|i| (config.units[i], config.units[i + 1]))
11 .collect();
12 in_out_pairs.insert(0, (config.in_dim, config.units[0]));
13 let vs = vs.pp(prefix);
14
15 Ok(in_out_pairs
16 .iter()
17 .enumerate()
18 .map(|(i, &(in_dim, out_dim))| {
19 linear(in_dim as _, out_dim as _, vs.pp(format!("ln{}", i))).unwrap()
20 })
21 .collect())
22}
23
24pub struct Mlp2 {
26 _config: MlpConfig,
27 device: Device,
28 head1: Linear,
29 head2: Linear,
30 layers: Vec<Linear>,
31}
32
33impl SubModel1 for Mlp2 {
34 type Config = MlpConfig;
35 type Input = Tensor;
36 type Output = (Tensor, Tensor);
37
38 fn forward(&self, xs: &Self::Input) -> Self::Output {
39 let xs = xs.to_device(&self.device).unwrap();
40 let xs = mlp_forward(xs, &self.layers, &crate::Activation::ReLU);
41 let mean = self.head1.forward(&xs).unwrap();
42 let std = self.head2.forward(&xs).unwrap().exp().unwrap();
43 (mean, std)
44 }
45
46 fn build(vs: VarBuilder, config: Self::Config) -> Self {
47 let device = vs.device().clone();
48 let layers = create_linear_layers("mlp", vs.clone(), &config).unwrap();
49 let (head1, head2) = {
50 let in_dim = *config.units.last().unwrap();
51 let out_dim = config.out_dim;
52 let head1 = linear(in_dim as _, out_dim as _, vs.pp(format!("mean"))).unwrap();
53 let head2 = linear(in_dim as _, out_dim as _, vs.pp(format!("std"))).unwrap();
54 (head1, head2)
55 };
56
57 Self {
58 _config: config,
59 device,
60 head1,
61 head2,
62 layers,
63 }
64 }
65}