1use std::collections::HashMap;
2use std::sync::Arc;
3
4use ndarray::{concatenate, s, Array2, ArrayView2, Axis};
5
6use crate::config::RogueNetConfig;
7use crate::fun::{gelu, softmax};
8use crate::layer_norm::LayerNorm;
9use crate::linear::Linear;
10use crate::msgpack::TensorDict;
11use crate::relpos_encoding::{RelposEncoding, RelposIndices};
12use crate::state::State;
13
#[derive(Debug, Clone)]
pub struct Transformer {
    /// Optional relative-position encoding, shared (via `Arc`) with every
    /// attention layer; `None` disables the relpos terms entirely.
    relpos_encoding: Option<Arc<RelposEncoding>>,
    /// Transformer blocks applied to the input in sequence.
    blocks: Vec<TransformerBlock>,
}
19
20impl Transformer {
21 pub fn forward(
22 &self,
23 mut x: Array2<f32>,
24 entities: &HashMap<String, Array2<f32>>,
25 ) -> Array2<f32> {
26 let relpos_indices = self
27 .relpos_encoding
28 .as_ref()
29 .map(|rp| rp.relpos_indices(entities));
30 log::debug!("relpos_indices: {:?}", relpos_indices);
31
32 for block in &self.blocks {
33 x = block.forward(x, &relpos_indices);
34 }
35 x
36 }
37
38 pub fn new(state_dict: &TensorDict, config: &RogueNetConfig, state: &State) -> Self {
39 let dict = state_dict.as_dict();
40
41 let relpos_encoding = config.relpos_encoding.clone().map(|config| {
42 Arc::new(RelposEncoding::new(
43 &dict["relpos_encoding"],
44 &config,
45 &state.obs_space,
46 ))
47 });
48
49 let mut blocks = Vec::new();
50 for value in dict["blocks"].as_dict().values() {
51 let block = TransformerBlock::new(value, config.n_head, &relpos_encoding);
52 blocks.push(block);
53 }
54
55 Transformer {
56 blocks,
57 relpos_encoding,
58 }
59 }
60}
61
#[derive(Debug, Clone)]
pub struct TransformerBlock {
    // Layer norm applied before the attention sub-layer (pre-norm).
    ln1: LayerNorm,
    attention: MultiHeadAttention,
    // Layer norm applied before the MLP sub-layer.
    ln2: LayerNorm,
    mlp: Mlp,
}
69
70impl TransformerBlock {
71 pub fn forward(&self, x: Array2<f32>, relpos_indices: &Option<RelposIndices>) -> Array2<f32> {
72 let x0 = x.view();
73 let x = self.ln1.forward(x.view());
74 let x = self.attention.forward(x.view(), relpos_indices);
75 let x = x + x0;
76 log::debug!("ATTN + RESIDUAL {:?}", x);
77 let x1 = x.view();
78 let x = self.ln2.forward(x.view());
79 let x = self.mlp.forward(x);
80 log::debug!("MLP {:?}", x);
81 let x = x + x1;
82 log::debug!("MLP + RESIDUAL {:?}", x);
83 x
84 }
85
86 fn new(
87 state_dict: &TensorDict,
88 n_head: u32,
89 relpos_encoding: &Option<Arc<RelposEncoding>>,
90 ) -> Self {
91 let dict = state_dict.as_dict();
92 let ln1 = LayerNorm::from(&dict["ln1"]);
93 let mlp = Mlp::from(&dict["mlp"]);
94 let ln2 = LayerNorm::from(&dict["ln2"]);
95 let attention = MultiHeadAttention::new(&dict["attn"], n_head, relpos_encoding.clone());
96
97 TransformerBlock {
98 ln1,
99 mlp,
100 ln2,
101 attention,
102 }
103 }
104}
105
#[derive(Debug, Clone)]
pub struct MultiHeadAttention {
    // Number of attention heads; the channel dim is split evenly across them.
    n_head: u32,
    // Optional relative-position encoding shared with the parent Transformer.
    relpos_encoding: Option<Arc<RelposEncoding>>,
    key: Linear,
    value: Linear,
    query: Linear,
    // Output projection applied to the concatenated per-head results.
    proj: Linear,
}
115
116impl MultiHeadAttention {
117 pub fn forward(
118 &self,
119 x: ArrayView2<f32>,
120 relpos_indices: &Option<RelposIndices>,
121 ) -> Array2<f32> {
122 let (_, c) = x.dim();
123 let d_head = c / self.n_head as usize;
124 let k = self.key.forward(x);
125 let q = self.query.forward(x);
126 let v = self.value.forward(x);
127 let scale = 1.0 / (d_head as f32).sqrt();
128 let mut ys = vec![];
129 for head in 0..self.n_head as usize {
130 let slice = s![.., head * d_head..(head + 1) * d_head];
131 let q = q.slice(slice);
132 let k = k.slice(slice);
133 let mut logits = q.dot(&k.t());
134 logits.mapv_inplace(|x| x * scale);
135 if let Some(re) = &self.relpos_encoding {
136 let relattn_logits = &re.relattn_logits(relpos_indices.as_ref().unwrap(), q.view());
137 logits += relattn_logits;
138 }
139 let attn = softmax(&logits);
140 let v = v.slice(slice);
141 let mut y = attn.dot(&v);
142 if let Some(re) = &self.relpos_encoding {
143 let relpos_values = &re.relpos_values(relpos_indices.as_ref().unwrap(), &attn, x);
144 log::debug!("RELPOS VALUES {:?}", relpos_values);
145 y += relpos_values;
146 }
147 ys.push(y);
148 }
149 let y = concatenate(Axis(1), &ys.iter().map(|x| x.view()).collect::<Vec<_>>()).unwrap();
150 self.proj.forward(y.view())
151 }
152 fn new(
153 state_dict: &TensorDict,
154 n_head: u32,
155 relpos_encoding: Option<Arc<RelposEncoding>>,
156 ) -> Self {
157 let dict = state_dict.as_dict();
158 let key = Linear::from(&dict["key"]);
159 let value = Linear::from(&dict["value"]);
160 let query = Linear::from(&dict["query"]);
161 let proj = Linear::from(&dict["proj"]);
162
163 MultiHeadAttention {
164 relpos_encoding,
165 n_head,
166 key,
167 value,
168 query,
169 proj,
170 }
171 }
172}
173
#[derive(Debug, Clone)]
pub struct Mlp {
    // First linear layer (expansion); output is passed through GELU.
    layer1: Linear,
    // Second linear layer (projection back).
    layer2: Linear,
}
179
180impl Mlp {
181 pub fn forward(&self, x: Array2<f32>) -> Array2<f32> {
182 let x = self.layer1.forward(x.view());
183 let x = gelu(x.view());
184 self.layer2.forward(x.view())
185 }
186}
187
188impl<'a> From<&'a TensorDict> for Mlp {
189 fn from(state_dict: &TensorDict) -> Self {
190 let dict = state_dict.as_dict();
191 let layer1 = Linear::from(&dict["0"]);
192 let layer2 = Linear::from(&dict["2"]);
193
194 Mlp { layer1, layer2 }
195 }
196}