border_candle_agent/mlp/
base.rs1use super::{mlp_forward, MlpConfig};
2use crate::model::{SubModel1, SubModel2};
3use anyhow::Result;
4use candle_core::{Device, Tensor, D};
5use candle_nn::{linear, Linear, VarBuilder};
6
7fn create_linear_layers(prefix: &str, vs: VarBuilder, config: &MlpConfig) -> Result<Vec<Linear>> {
9 let mut in_out_pairs: Vec<(i64, i64)> = (0..config.units.len() - 1)
10 .map(|i| (config.units[i], config.units[i + 1]))
11 .collect();
12 in_out_pairs.insert(0, (config.in_dim, config.units[0]));
13 in_out_pairs.push((*config.units.last().unwrap(), config.out_dim));
14 let vs = vs.pp(prefix);
15
16 Ok(in_out_pairs
17 .iter()
18 .enumerate()
19 .map(|(i, &(in_dim, out_dim))| {
20 linear(in_dim as _, out_dim as _, vs.pp(format!("ln{}", i))).unwrap()
21 })
22 .collect())
23}
24
25pub struct Mlp {
27 config: MlpConfig,
28 device: Device,
29 layers: Vec<Linear>,
30}
31
32fn _build(vs: VarBuilder, config: MlpConfig) -> Mlp {
33 let device = vs.device().clone();
34 let layers = create_linear_layers("mlp", vs, &config).unwrap();
35
36 Mlp {
37 config,
38 device,
39 layers,
40 }
41}
42
43impl SubModel1 for Mlp {
44 type Config = MlpConfig;
45 type Input = Tensor;
46 type Output = Tensor;
47
48 fn forward(&self, xs: &Self::Input) -> Tensor {
49 let xs = xs.to_device(&self.device).unwrap();
50 mlp_forward(xs, &self.layers, &self.config.activation_out)
51 }
52
53 fn build(vs: VarBuilder, config: Self::Config) -> Self {
54 _build(vs, config)
55 }
56}
57
58impl SubModel2 for Mlp {
59 type Config = MlpConfig;
60 type Input1 = Tensor;
61 type Input2 = Tensor;
62 type Output = Tensor;
63
64 fn forward(&self, input1: &Self::Input1, input2: &Self::Input2) -> Self::Output {
65 let input1: Tensor = input1.to_device(&self.device).unwrap();
66 let input2: Tensor = input2.to_device(&self.device).unwrap();
67
68 let input = Tensor::cat(&[input1, input2], D::Minus1)
69 .unwrap()
70 .to_device(&self.device)
71 .unwrap();
72 mlp_forward(input, &self.layers, &self.config.activation_out)
73 }
74
75 fn build(vs: VarBuilder, config: Self::Config) -> Self {
76 _build(vs, config)
77 }
78}