// border_candle_agent/mlp/mlp3.rs
use super::{mlp_forward, MlpConfig};
2use crate::model::SubModel1;
3use anyhow::Result;
4use candle_core::{Device, Tensor};
5use candle_nn::{init::Init, linear, Linear, VarBuilder};
6
7fn create_linear_layers(prefix: &str, vs: VarBuilder, config: &MlpConfig) -> Result<Vec<Linear>> {
9 let mut in_out_pairs: Vec<(i64, i64)> = (0..config.units.len() - 1)
10 .map(|i| (config.units[i], config.units[i + 1]))
11 .collect();
12 in_out_pairs.insert(0, (config.in_dim, config.units[0]));
13 in_out_pairs.push((*config.units.last().unwrap(), config.out_dim));
14 let vs = vs.pp(prefix);
15
16 Ok(in_out_pairs
17 .iter()
18 .enumerate()
19 .map(|(i, &(in_dim, out_dim))| {
20 linear(in_dim as _, out_dim as _, vs.pp(format!("ln{}", i))).unwrap()
21 })
22 .collect())
23}
24
/// Multilayer perceptron with two outputs: the result of the linear-layer stack
/// and a separate `head2` variable that does not depend on the input.
pub struct Mlp3 {
    /// Configuration this model was built from; stored but not read in this file.
    _config: MlpConfig,
    /// Device taken from the `VarBuilder` at build time; inputs are moved to it
    /// before the forward pass.
    device: Device,
    /// Linear layers created under the `"mlp"` prefix.
    layers: Vec<Linear>,
    /// Variable named `"head2"` with shape `(1, out_dim)`, initialized to zero and
    /// repeated along the first dimension to match the batch size in `forward`.
    head2: Tensor,
}
36
37impl SubModel1 for Mlp3 {
38 type Config = MlpConfig;
39 type Input = Tensor;
40 type Output = (Tensor, Tensor);
41
42 fn forward(&self, xs: &Self::Input) -> Self::Output {
43 let batch_size = xs.dims()[0];
44 let xs = xs.to_device(&self.device).unwrap();
45 let ys = mlp_forward(xs, &self.layers, &crate::Activation::None);
46 let zs = self.head2.repeat((batch_size, 1)).unwrap();
47 (ys, zs)
48 }
49
50 fn build(vs: VarBuilder, config: Self::Config) -> Self {
51 let device = vs.device().clone();
52 let head2 = vs
53 .get_with_hints((1, config.out_dim as usize), "head2", Init::Const(0.))
54 .unwrap()
55 .to_device(&device)
56 .unwrap();
57 let layers = create_linear_layers("mlp", vs, &config).unwrap();
58
59 Mlp3 {
60 _config: config,
61 device,
62 layers,
63 head2,
64 }
65 }
66}