// god_graph/transformer/graph_transformer/nodes.rs
1//! Node types for graph-structured Transformer
2
3use crate::tensor::DenseTensor;
4use crate::tensor::traits::TensorBase;
5
/// Discriminant identifying which payload a graph node carries.
///
/// Each variant corresponds to exactly one of the `*Node` data structs below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GraphNodeType {
    /// Token embedding node (input layer; carries `TokenEmbeddingNode`)
    TokenEmbedding,
    /// Hidden state node (carries `HiddenStateNode`)
    HiddenState,
    /// Attention output node (carries `AttentionOutputNode`)
    AttentionOutput,
    /// FFN output node (carries `FFNOutputNode`)
    FFNOutput,
}
18
/// Payload of a token-embedding graph node: one token at one sequence
/// position together with its embedding vector.
#[derive(Debug, Clone)]
pub struct TokenEmbeddingNode {
    /// Vocabulary ID of the token
    pub token_id: usize,
    /// Zero-based position in the input sequence
    pub position: usize,
    /// Embedding row vector, shape `[1, hidden_dim]`
    /// (hidden_dim() reads shape()[1], so the tensor must be 2-D)
    pub embedding: DenseTensor,
}
29
30impl TokenEmbeddingNode {
31    /// Create a new token embedding node
32    pub fn new(token_id: usize, position: usize, embedding: DenseTensor) -> Self {
33        Self {
34            token_id,
35            position,
36            embedding,
37        }
38    }
39
40    /// Get the hidden dimension
41    pub fn hidden_dim(&self) -> usize {
42        self.embedding.shape()[1]
43    }
44}
45
/// Payload of a hidden-state graph node: the residual-stream state for one
/// (layer, position) pair.
#[derive(Debug, Clone)]
pub struct HiddenStateNode {
    /// Index of the transformer layer that produced this state
    pub layer: usize,
    /// Zero-based position in the input sequence
    pub position: usize,
    /// Hidden state row vector, shape `[1, hidden_dim]`
    pub state: DenseTensor,
}
56
57impl HiddenStateNode {
58    /// Create a new hidden state node
59    pub fn new(layer: usize, position: usize, state: DenseTensor) -> Self {
60        Self {
61            layer,
62            position,
63            state,
64        }
65    }
66
67    /// Get the hidden dimension
68    pub fn hidden_dim(&self) -> usize {
69        self.state.shape()[1]
70    }
71}
72
/// Payload of an attention-output graph node: one head's output for one
/// query position, plus the attention pattern that produced it.
#[derive(Debug, Clone)]
pub struct AttentionOutputNode {
    /// Index of the transformer layer
    pub layer: usize,
    /// Index of the attention head within the layer
    pub head: usize,
    /// Query position this output was computed for
    pub query_pos: usize,
    /// Key positions this query attended to
    pub attended_positions: Vec<usize>,
    /// Attention weights, parallel to `attended_positions`
    /// (NOTE(review): same-length invariant is assumed, not enforced here —
    /// confirm at construction sites)
    pub weights: Vec<f64>,
    /// Head output row vector, shape `[1, head_dim]`
    pub output: DenseTensor,
}
89
90impl AttentionOutputNode {
91    /// Create a new attention output node
92    pub fn new(
93        layer: usize,
94        head: usize,
95        query_pos: usize,
96        attended_positions: Vec<usize>,
97        weights: Vec<f64>,
98        output: DenseTensor,
99    ) -> Self {
100        Self {
101            layer,
102            head,
103            query_pos,
104            attended_positions,
105            weights,
106            output,
107        }
108    }
109
110    /// Get the head dimension
111    pub fn head_dim(&self) -> usize {
112        self.output.shape()[1]
113    }
114
115    /// Get number of attended positions
116    pub fn num_attended(&self) -> usize {
117        self.attended_positions.len()
118    }
119}
120
/// Payload of a feed-forward-network output graph node: the FFN result for
/// one (layer, position) pair.
#[derive(Debug, Clone)]
pub struct FFNOutputNode {
    /// Index of the transformer layer
    pub layer: usize,
    /// Zero-based position in the input sequence
    pub position: usize,
    /// FFN output row vector, shape `[1, hidden_dim]`
    pub output: DenseTensor,
}
131
132impl FFNOutputNode {
133    /// Create a new FFN output node
134    pub fn new(layer: usize, position: usize, output: DenseTensor) -> Self {
135        Self {
136            layer,
137            position,
138            output,
139        }
140    }
141
142    /// Get the hidden dimension
143    pub fn hidden_dim(&self) -> usize {
144        self.output.shape()[1]
145    }
146}
147
/// Tagged wrapper around one of the four node payloads.
///
/// `node_type` identifies which payload field is populated; the associated
/// constructors (`token_embedding`, `hidden_state`, `attention_output`,
/// `ffn_output`) each set exactly one payload to `Some` and the rest to
/// `None`. NOTE(review): the fields are `pub`, so this "exactly one Some,
/// matching node_type" invariant is by convention only — an enum payload
/// would enforce it, but would break the public field API.
#[derive(Debug, Clone)]
pub struct GraphNode {
    /// Which payload this node carries
    pub node_type: GraphNodeType,
    /// Unique node ID (assigned by the caller; not checked for uniqueness here)
    pub id: usize,
    /// Layer number; 0 for token-embedding nodes (see `token_embedding`)
    pub layer: usize,
    /// Position in sequence (the query position for attention nodes)
    pub position: usize,
    /// Payload when `node_type == TokenEmbedding`
    pub token_embedding: Option<TokenEmbeddingNode>,
    /// Payload when `node_type == HiddenState`
    pub hidden_state: Option<HiddenStateNode>,
    /// Payload when `node_type == AttentionOutput`
    pub attention_output: Option<AttentionOutputNode>,
    /// Payload when `node_type == FFNOutput`
    pub ffn_output: Option<FFNOutputNode>,
}
168
169impl GraphNode {
170    /// Create a token embedding node
171    pub fn token_embedding(id: usize, token_id: usize, position: usize, embedding: DenseTensor) -> Self {
172        Self {
173            node_type: GraphNodeType::TokenEmbedding,
174            id,
175            layer: 0,
176            position,
177            token_embedding: Some(TokenEmbeddingNode::new(token_id, position, embedding)),
178            hidden_state: None,
179            attention_output: None,
180            ffn_output: None,
181        }
182    }
183
184    /// Create a hidden state node
185    pub fn hidden_state(id: usize, layer: usize, position: usize, state: DenseTensor) -> Self {
186        Self {
187            node_type: GraphNodeType::HiddenState,
188            id,
189            layer,
190            position,
191            token_embedding: None,
192            hidden_state: Some(HiddenStateNode::new(layer, position, state)),
193            attention_output: None,
194            ffn_output: None,
195        }
196    }
197
198    /// Create an attention output node
199    pub fn attention_output(
200        id: usize,
201        layer: usize,
202        head: usize,
203        query_pos: usize,
204        attended_positions: Vec<usize>,
205        weights: Vec<f64>,
206        output: DenseTensor,
207    ) -> Self {
208        Self {
209            node_type: GraphNodeType::AttentionOutput,
210            id,
211            layer,
212            position: query_pos,
213            token_embedding: None,
214            hidden_state: None,
215            attention_output: Some(AttentionOutputNode::new(
216                layer,
217                head,
218                query_pos,
219                attended_positions,
220                weights,
221                output,
222            )),
223            ffn_output: None,
224        }
225    }
226
227    /// Create a FFN output node
228    pub fn ffn_output(id: usize, layer: usize, position: usize, output: DenseTensor) -> Self {
229        Self {
230            node_type: GraphNodeType::FFNOutput,
231            id,
232            layer,
233            position,
234            token_embedding: None,
235            hidden_state: None,
236            attention_output: None,
237            ffn_output: Some(FFNOutputNode::new(layer, position, output)),
238        }
239    }
240
241    /// Get the embedding if this is a token embedding node
242    pub fn get_embedding(&self) -> Option<&TokenEmbeddingNode> {
243        self.token_embedding.as_ref()
244    }
245
246    /// Get the hidden state if this is a hidden state node
247    pub fn get_hidden_state(&self) -> Option<&HiddenStateNode> {
248        self.hidden_state.as_ref()
249    }
250
251    /// Get the attention output if this is an attention output node
252    pub fn get_attention_output(&self) -> Option<&AttentionOutputNode> {
253        self.attention_output.as_ref()
254    }
255
256    /// Get the FFN output if this is a FFN output node
257    pub fn get_ffn_output(&self) -> Option<&FFNOutputNode> {
258        self.ffn_output.as_ref()
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Token-embedding nodes carry the token metadata and sit at layer 0.
    #[test]
    fn test_token_embedding_node() {
        let emb_tensor = DenseTensor::new(vec![0.1, 0.2, 0.3, 0.4], vec![1, 4]);
        let node = GraphNode::token_embedding(0, 10, 0, emb_tensor);

        assert_eq!(node.node_type, GraphNodeType::TokenEmbedding);
        assert_eq!(node.id, 0);
        assert_eq!(node.layer, 0);
        assert_eq!(node.position, 0);

        let payload = node.get_embedding().unwrap();
        assert_eq!(payload.token_id, 10);
        assert_eq!(payload.position, 0);
        assert_eq!(payload.hidden_dim(), 4);
    }

    /// Hidden-state nodes record their (layer, position) coordinates.
    #[test]
    fn test_hidden_state_node() {
        let state_tensor = DenseTensor::new(vec![0.1, 0.2, 0.3], vec![1, 3]);
        let node = GraphNode::hidden_state(1, 5, 2, state_tensor);

        assert_eq!(node.node_type, GraphNodeType::HiddenState);
        assert_eq!(node.layer, 5);
        assert_eq!(node.position, 2);

        let payload = node.get_hidden_state().unwrap();
        assert_eq!(payload.layer, 5);
        assert_eq!(payload.position, 2);
        assert_eq!(payload.hidden_dim(), 3);
    }

    /// Attention nodes carry head, query position, and the attention pattern.
    #[test]
    fn test_attention_output_node() {
        let out_tensor = DenseTensor::new(vec![0.1, 0.2], vec![1, 2]);
        let node = GraphNode::attention_output(
            10,
            3,
            2,
            5,
            vec![3, 4, 5],
            vec![0.3, 0.5, 0.2],
            out_tensor,
        );

        assert_eq!(node.node_type, GraphNodeType::AttentionOutput);
        assert_eq!(node.layer, 3);

        let payload = node.get_attention_output().unwrap();
        assert_eq!(payload.layer, 3);
        assert_eq!(payload.head, 2);
        assert_eq!(payload.query_pos, 5);
        assert_eq!(payload.num_attended(), 3);
        assert_eq!(payload.head_dim(), 2);
    }

    /// FFN nodes record their (layer, position) coordinates and output width.
    #[test]
    fn test_ffn_output_node() {
        let out_tensor = DenseTensor::new(vec![0.1, 0.2, 0.3], vec![1, 3]);
        let node = GraphNode::ffn_output(20, 7, 4, out_tensor);

        assert_eq!(node.node_type, GraphNodeType::FFNOutput);
        assert_eq!(node.layer, 7);
        assert_eq!(node.position, 4);

        let payload = node.get_ffn_output().unwrap();
        assert_eq!(payload.layer, 7);
        assert_eq!(payload.position, 4);
        assert_eq!(payload.hidden_dim(), 3);
    }
}