//! rogue_net — `transformer.rs`

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use ndarray::{concatenate, s, Array2, ArrayView2, Axis};
5
6use crate::config::RogueNetConfig;
7use crate::fun::{gelu, softmax};
8use crate::layer_norm::LayerNorm;
9use crate::linear::Linear;
10use crate::msgpack::TensorDict;
11use crate::relpos_encoding::{RelposEncoding, RelposIndices};
12use crate::state::State;
13
/// A stack of [`TransformerBlock`]s with an optional relative-position
/// encoding that is shared by the attention layer of every block.
#[derive(Debug, Clone)]
pub struct Transformer {
    // Shared via `Arc` so each block's attention can hold a handle to the
    // same encoding tables; `None` when the config disables relpos encoding.
    relpos_encoding: Option<Arc<RelposEncoding>>,
    blocks: Vec<TransformerBlock>,
}
19
20impl Transformer {
21    pub fn forward(
22        &self,
23        mut x: Array2<f32>,
24        entities: &HashMap<String, Array2<f32>>,
25    ) -> Array2<f32> {
26        let relpos_indices = self
27            .relpos_encoding
28            .as_ref()
29            .map(|rp| rp.relpos_indices(entities));
30        log::debug!("relpos_indices: {:?}", relpos_indices);
31
32        for block in &self.blocks {
33            x = block.forward(x, &relpos_indices);
34        }
35        x
36    }
37
38    pub fn new(state_dict: &TensorDict, config: &RogueNetConfig, state: &State) -> Self {
39        let dict = state_dict.as_dict();
40
41        let relpos_encoding = config.relpos_encoding.clone().map(|config| {
42            Arc::new(RelposEncoding::new(
43                &dict["relpos_encoding"],
44                &config,
45                &state.obs_space,
46            ))
47        });
48
49        let mut blocks = Vec::new();
50        for value in dict["blocks"].as_dict().values() {
51            let block = TransformerBlock::new(value, config.n_head, &relpos_encoding);
52            blocks.push(block);
53        }
54
55        Transformer {
56            blocks,
57            relpos_encoding,
58        }
59    }
60}
61
/// One pre-norm transformer block: layer norm → attention → residual,
/// then layer norm → MLP → residual.
#[derive(Debug, Clone)]
pub struct TransformerBlock {
    ln1: LayerNorm,
    attention: MultiHeadAttention,
    ln2: LayerNorm,
    mlp: Mlp,
}
69
70impl TransformerBlock {
71    pub fn forward(&self, x: Array2<f32>, relpos_indices: &Option<RelposIndices>) -> Array2<f32> {
72        let x0 = x.view();
73        let x = self.ln1.forward(x.view());
74        let x = self.attention.forward(x.view(), relpos_indices);
75        let x = x + x0;
76        log::debug!("ATTN + RESIDUAL {:?}", x);
77        let x1 = x.view();
78        let x = self.ln2.forward(x.view());
79        let x = self.mlp.forward(x);
80        log::debug!("MLP {:?}", x);
81        let x = x + x1;
82        log::debug!("MLP + RESIDUAL {:?}", x);
83        x
84    }
85
86    fn new(
87        state_dict: &TensorDict,
88        n_head: u32,
89        relpos_encoding: &Option<Arc<RelposEncoding>>,
90    ) -> Self {
91        let dict = state_dict.as_dict();
92        let ln1 = LayerNorm::from(&dict["ln1"]);
93        let mlp = Mlp::from(&dict["mlp"]);
94        let ln2 = LayerNorm::from(&dict["ln2"]);
95        let attention = MultiHeadAttention::new(&dict["attn"], n_head, relpos_encoding.clone());
96
97        TransformerBlock {
98            ln1,
99            mlp,
100            ln2,
101            attention,
102        }
103    }
104}
105
/// Multi-head scaled dot-product self-attention with optional
/// relative-position terms added to both logits and values.
#[derive(Debug, Clone)]
pub struct MultiHeadAttention {
    n_head: u32,
    relpos_encoding: Option<Arc<RelposEncoding>>,
    // Combined projections for all heads; sliced per head in `forward`.
    key: Linear,
    value: Linear,
    query: Linear,
    // Final output projection applied after head concatenation.
    proj: Linear,
}
115
116impl MultiHeadAttention {
117    pub fn forward(
118        &self,
119        x: ArrayView2<f32>,
120        relpos_indices: &Option<RelposIndices>,
121    ) -> Array2<f32> {
122        let (_, c) = x.dim();
123        let d_head = c / self.n_head as usize;
124        let k = self.key.forward(x);
125        let q = self.query.forward(x);
126        let v = self.value.forward(x);
127        let scale = 1.0 / (d_head as f32).sqrt();
128        let mut ys = vec![];
129        for head in 0..self.n_head as usize {
130            let slice = s![.., head * d_head..(head + 1) * d_head];
131            let q = q.slice(slice);
132            let k = k.slice(slice);
133            let mut logits = q.dot(&k.t());
134            logits.mapv_inplace(|x| x * scale);
135            if let Some(re) = &self.relpos_encoding {
136                let relattn_logits = &re.relattn_logits(relpos_indices.as_ref().unwrap(), q.view());
137                logits += relattn_logits;
138            }
139            let attn = softmax(&logits);
140            let v = v.slice(slice);
141            let mut y = attn.dot(&v);
142            if let Some(re) = &self.relpos_encoding {
143                let relpos_values = &re.relpos_values(relpos_indices.as_ref().unwrap(), &attn, x);
144                log::debug!("RELPOS VALUES {:?}", relpos_values);
145                y += relpos_values;
146            }
147            ys.push(y);
148        }
149        let y = concatenate(Axis(1), &ys.iter().map(|x| x.view()).collect::<Vec<_>>()).unwrap();
150        self.proj.forward(y.view())
151    }
152    fn new(
153        state_dict: &TensorDict,
154        n_head: u32,
155        relpos_encoding: Option<Arc<RelposEncoding>>,
156    ) -> Self {
157        let dict = state_dict.as_dict();
158        let key = Linear::from(&dict["key"]);
159        let value = Linear::from(&dict["value"]);
160        let query = Linear::from(&dict["query"]);
161        let proj = Linear::from(&dict["proj"]);
162
163        MultiHeadAttention {
164            relpos_encoding,
165            n_head,
166            key,
167            value,
168            query,
169            proj,
170        }
171    }
172}
173
/// Two-layer feed-forward network (Linear → GELU → Linear) used as the
/// position-wise MLP inside each transformer block.
#[derive(Debug, Clone)]
pub struct Mlp {
    layer1: Linear,
    layer2: Linear,
}
179
180impl Mlp {
181    pub fn forward(&self, x: Array2<f32>) -> Array2<f32> {
182        let x = self.layer1.forward(x.view());
183        let x = gelu(x.view());
184        self.layer2.forward(x.view())
185    }
186}
187
188impl<'a> From<&'a TensorDict> for Mlp {
189    fn from(state_dict: &TensorDict) -> Self {
190        let dict = state_dict.as_dict();
191        let layer1 = Linear::from(&dict["0"]);
192        let layer2 = Linear::from(&dict["2"]);
193
194        Mlp { layer1, layer2 }
195    }
196}