1use crate::tensor::DenseTensor;
4use crate::tensor::traits::TensorBase;
5
/// Discriminates which kind of payload a [`GraphNode`] carries.
///
/// Each variant corresponds to one of the `GraphNode` constructors defined in
/// this file, which stamp the matching variant into `GraphNode::node_type`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GraphNodeType {
    /// Set by `GraphNode::token_embedding`.
    TokenEmbedding,
    /// Set by `GraphNode::hidden_state`.
    HiddenState,
    /// Set by `GraphNode::attention_output`.
    AttentionOutput,
    /// Set by `GraphNode::ffn_output`.
    FFNOutput,
}
18
19#[derive(Debug, Clone)]
21pub struct TokenEmbeddingNode {
22 pub token_id: usize,
24 pub position: usize,
26 pub embedding: DenseTensor,
28}
29
30impl TokenEmbeddingNode {
31 pub fn new(token_id: usize, position: usize, embedding: DenseTensor) -> Self {
33 Self {
34 token_id,
35 position,
36 embedding,
37 }
38 }
39
40 pub fn hidden_dim(&self) -> usize {
42 self.embedding.shape()[1]
43 }
44}
45
46#[derive(Debug, Clone)]
48pub struct HiddenStateNode {
49 pub layer: usize,
51 pub position: usize,
53 pub state: DenseTensor,
55}
56
57impl HiddenStateNode {
58 pub fn new(layer: usize, position: usize, state: DenseTensor) -> Self {
60 Self {
61 layer,
62 position,
63 state,
64 }
65 }
66
67 pub fn hidden_dim(&self) -> usize {
69 self.state.shape()[1]
70 }
71}
72
73#[derive(Debug, Clone)]
75pub struct AttentionOutputNode {
76 pub layer: usize,
78 pub head: usize,
80 pub query_pos: usize,
82 pub attended_positions: Vec<usize>,
84 pub weights: Vec<f64>,
86 pub output: DenseTensor,
88}
89
90impl AttentionOutputNode {
91 pub fn new(
93 layer: usize,
94 head: usize,
95 query_pos: usize,
96 attended_positions: Vec<usize>,
97 weights: Vec<f64>,
98 output: DenseTensor,
99 ) -> Self {
100 Self {
101 layer,
102 head,
103 query_pos,
104 attended_positions,
105 weights,
106 output,
107 }
108 }
109
110 pub fn head_dim(&self) -> usize {
112 self.output.shape()[1]
113 }
114
115 pub fn num_attended(&self) -> usize {
117 self.attended_positions.len()
118 }
119}
120
121#[derive(Debug, Clone)]
123pub struct FFNOutputNode {
124 pub layer: usize,
126 pub position: usize,
128 pub output: DenseTensor,
130}
131
132impl FFNOutputNode {
133 pub fn new(layer: usize, position: usize, output: DenseTensor) -> Self {
135 Self {
136 layer,
137 position,
138 output,
139 }
140 }
141
142 pub fn hidden_dim(&self) -> usize {
144 self.output.shape()[1]
145 }
146}
147
148#[derive(Debug, Clone)]
150pub struct GraphNode {
151 pub node_type: GraphNodeType,
153 pub id: usize,
155 pub layer: usize,
157 pub position: usize,
159 pub token_embedding: Option<TokenEmbeddingNode>,
161 pub hidden_state: Option<HiddenStateNode>,
163 pub attention_output: Option<AttentionOutputNode>,
165 pub ffn_output: Option<FFNOutputNode>,
167}
168
169impl GraphNode {
170 pub fn token_embedding(id: usize, token_id: usize, position: usize, embedding: DenseTensor) -> Self {
172 Self {
173 node_type: GraphNodeType::TokenEmbedding,
174 id,
175 layer: 0,
176 position,
177 token_embedding: Some(TokenEmbeddingNode::new(token_id, position, embedding)),
178 hidden_state: None,
179 attention_output: None,
180 ffn_output: None,
181 }
182 }
183
184 pub fn hidden_state(id: usize, layer: usize, position: usize, state: DenseTensor) -> Self {
186 Self {
187 node_type: GraphNodeType::HiddenState,
188 id,
189 layer,
190 position,
191 token_embedding: None,
192 hidden_state: Some(HiddenStateNode::new(layer, position, state)),
193 attention_output: None,
194 ffn_output: None,
195 }
196 }
197
198 pub fn attention_output(
200 id: usize,
201 layer: usize,
202 head: usize,
203 query_pos: usize,
204 attended_positions: Vec<usize>,
205 weights: Vec<f64>,
206 output: DenseTensor,
207 ) -> Self {
208 Self {
209 node_type: GraphNodeType::AttentionOutput,
210 id,
211 layer,
212 position: query_pos,
213 token_embedding: None,
214 hidden_state: None,
215 attention_output: Some(AttentionOutputNode::new(
216 layer,
217 head,
218 query_pos,
219 attended_positions,
220 weights,
221 output,
222 )),
223 ffn_output: None,
224 }
225 }
226
227 pub fn ffn_output(id: usize, layer: usize, position: usize, output: DenseTensor) -> Self {
229 Self {
230 node_type: GraphNodeType::FFNOutput,
231 id,
232 layer,
233 position,
234 token_embedding: None,
235 hidden_state: None,
236 attention_output: None,
237 ffn_output: Some(FFNOutputNode::new(layer, position, output)),
238 }
239 }
240
241 pub fn get_embedding(&self) -> Option<&TokenEmbeddingNode> {
243 self.token_embedding.as_ref()
244 }
245
246 pub fn get_hidden_state(&self) -> Option<&HiddenStateNode> {
248 self.hidden_state.as_ref()
249 }
250
251 pub fn get_attention_output(&self) -> Option<&AttentionOutputNode> {
253 self.attention_output.as_ref()
254 }
255
256 pub fn get_ffn_output(&self) -> Option<&FFNOutputNode> {
258 self.ffn_output.as_ref()
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Reports which payload slots are populated, in the order
    /// `[token_embedding, hidden_state, attention_output, ffn_output]`.
    ///
    /// Lets each test assert that a constructor filled its own slot and
    /// nothing else.
    fn populated_payloads(node: &GraphNode) -> [bool; 4] {
        [
            node.get_embedding().is_some(),
            node.get_hidden_state().is_some(),
            node.get_attention_output().is_some(),
            node.get_ffn_output().is_some(),
        ]
    }

    #[test]
    fn test_token_embedding_node() {
        let embedding = DenseTensor::new(vec![0.1, 0.2, 0.3, 0.4], vec![1, 4]);
        let node = GraphNode::token_embedding(0, 10, 0, embedding);

        assert_eq!(node.node_type, GraphNodeType::TokenEmbedding);
        assert_eq!(node.id, 0);
        assert_eq!(node.layer, 0);
        assert_eq!(node.position, 0);
        assert_eq!(populated_payloads(&node), [true, false, false, false]);

        let emb = node.get_embedding().unwrap();
        assert_eq!(emb.token_id, 10);
        assert_eq!(emb.position, 0);
        assert_eq!(emb.hidden_dim(), 4);
    }

    #[test]
    fn test_hidden_state_node() {
        let state = DenseTensor::new(vec![0.1, 0.2, 0.3], vec![1, 3]);
        let node = GraphNode::hidden_state(1, 5, 2, state);

        assert_eq!(node.node_type, GraphNodeType::HiddenState);
        assert_eq!(node.id, 1);
        assert_eq!(node.layer, 5);
        assert_eq!(node.position, 2);
        assert_eq!(populated_payloads(&node), [false, true, false, false]);

        let hidden = node.get_hidden_state().unwrap();
        assert_eq!(hidden.layer, 5);
        assert_eq!(hidden.position, 2);
        assert_eq!(hidden.hidden_dim(), 3);
    }

    #[test]
    fn test_attention_output_node() {
        let output = DenseTensor::new(vec![0.1, 0.2], vec![1, 2]);
        let node = GraphNode::attention_output(
            10,
            3,
            2,
            5,
            vec![3, 4, 5],
            vec![0.3, 0.5, 0.2],
            output,
        );

        assert_eq!(node.node_type, GraphNodeType::AttentionOutput);
        assert_eq!(node.id, 10);
        assert_eq!(node.layer, 3);
        // This is the only constructor whose node-level `position` comes from a
        // differently named argument (`query_pos`), so pin that mapping.
        assert_eq!(node.position, 5);
        assert_eq!(populated_payloads(&node), [false, false, true, false]);

        let attn = node.get_attention_output().unwrap();
        assert_eq!(attn.layer, 3);
        assert_eq!(attn.head, 2);
        assert_eq!(attn.query_pos, 5);
        assert_eq!(attn.num_attended(), 3);
        assert_eq!(attn.head_dim(), 2);
        // The vectors are moved into the payload untouched, so exact equality
        // (including for the f64 weights) is the right check.
        assert_eq!(attn.attended_positions, vec![3, 4, 5]);
        assert_eq!(attn.weights, vec![0.3, 0.5, 0.2]);
    }

    #[test]
    fn test_ffn_output_node() {
        let output = DenseTensor::new(vec![0.1, 0.2, 0.3], vec![1, 3]);
        let node = GraphNode::ffn_output(20, 7, 4, output);

        assert_eq!(node.node_type, GraphNodeType::FFNOutput);
        assert_eq!(node.id, 20);
        assert_eq!(node.layer, 7);
        assert_eq!(node.position, 4);
        assert_eq!(populated_payloads(&node), [false, false, false, true]);

        let ffn = node.get_ffn_output().unwrap();
        assert_eq!(ffn.layer, 7);
        assert_eq!(ffn.position, 4);
        assert_eq!(ffn.hidden_dim(), 3);
    }
}