// god_graph/transformer/graph_transformer/execution.rs
//! Graph execution engine for graph-structured Transformer
//!
//! ## 🎯 GraphTransformer 定位说明
//!
//! **GraphTransformer 主要用于**:
//! 1. **可视化注意力拓扑**:导出 DOT/Graphviz 格式,直观理解注意力模式
//! 2. **动态剪枝弱边**:运行时剪除弱注意力连接,减少冗余计算
//! 3. **添加自定义连接**:实验长程连接、稀疏注意力等架构变体
//! 4. **拓扑缺陷检测**:发现孤立节点、梯度阻断、缺失残差连接
//! 5. **执行前向传播**:基于拓扑排序的张量计算,支持边上传递张量消息
//!
//! **GraphTransformer 不用于**:
//! - ❌ **高性能推理**:对于生产环境推理,建议转换为标准 LlamaModel(使用 `llama.cpp` 或 `vllm`)
//! - ❌ **大规模训练**:对于训练任务,使用 PyTorch/JAX 等成熟框架
//!
//! ## 核心优势
//!
//! - **显式表示注意力**:每条注意力边可单独访问/修改(黑盒推理引擎做不到)
//! - **动态拓扑编辑**:支持运行时修改图结构(传统静态图做不到)
//! - **可视化支持**:导出 DOT 格式,用 Graphviz 渲染
//! - **张量传递语义**:边上携带 Q/K/V 投影张量,实现消息传递计算
//!
//! ## 使用示例
//!
//! ```rust,no_run
//! use god_graph::transformer::graph_transformer::GraphTransformer;
//! use god_graph::tensor::traits::TensorBase;
//!
//! // 1. 创建 GraphTransformer
//! let mut transformer = GraphTransformer::new(2, 4, 256);
//! transformer.build_graph(&[1, 2, 3, 4]);
//!
//! // 2. 可视化注意力拓扑
//! let dot = transformer.to_dot();
//! std::fs::write("attention_graph.dot", dot).unwrap();
//! // 用 Graphviz 渲染:dot -Tpng attention_graph.dot -o attention_graph.png
//!
//! // 3. 剪枝弱注意力边(阈值=0.01)
//! let pruned_count = transformer.prune_weak_edges(0.01);
//! println!("剪枝了 {} 条边", pruned_count);
//!
//! // 4. 执行前向传播
//! let output = transformer.forward(&[1, 2, 3, 4]);
//! println!("Output shape: {:?}", output.shape());
//!
//! // 5. 添加自定义长程连接
//! // transformer.add_skip_connection(layer_0, layer_11);
//! ```
//!
//! ## 与 DifferentiableGraph 的关系
//!
//! - **GraphTransformer**: 用于分析和编辑**已有**的 Transformer 结构
//! - **DifferentiableGraph**: 用于**优化**图结构(梯度下降学习最优架构)
//!
//! 典型工作流:
//! 1. 用 GraphTransformer 可视化和分析初始结构
//! 2. 用 DifferentiableGraph 优化结构(剪枝、架构搜索)
//! 3. 用 GraphTransformer 验证优化结果
//!
//! ## 性能说明
//!
//! GraphTransformer 包含图遍历和动态编辑开销,不适合高性能推理场景。
//! 对于生产环境,建议:
//! 1. 用 GraphTransformer 分析/优化结构
//! 2. 导出为静态图(Safetensors 格式)
//! 3. 用 `llama.cpp` 或 `vllm` 进行推理
//!
//! ## GraphTransformer forward() 实现详解
//!
//! ### 执行流程
//!
//! 1. **拓扑排序**:确定计算顺序,确保依赖先计算
//! 2. **节点执行**:按拓扑序执行每个节点的操作
//! 3. **边上传递**:通过边上的张量消息传递信息
//! 4. **缓存中间结果**:避免重复计算
//!
//! ### 张量传递语义
//!
//! - **SelfAttention 边**:携带 Q/K/V 投影张量
//! - **DataFlow 边**:携带数据流张量(激活值)
//! - **Residual 边**:携带残差连接张量(恒等映射)
//!
//! ### 节点类型与执行
//!
//! - **TokenEmbedding**:提供 token 嵌入向量
//! - **HiddenState**:聚合输入和边消息
//! - **AttentionOutput**:加权求和注意力输出
//! - **FFNOutput**:应用 FFN 变换
90use std::collections::{HashMap, HashSet};
91use crate::graph::Graph;
92use crate::graph::traits::{GraphBase, GraphOps, GraphQuery};
93use crate::node::NodeIndex;
94use crate::tensor::DenseTensor;
95use crate::tensor::traits::{TensorOps, TensorBase};
96use super::nodes::{GraphNode, GraphNodeType};
97use super::edges::{GraphEdge, GraphEdgeType, DataFlowOp, SkipType};
98
99/// Graph executor for executing Transformer computation graphs
100#[derive(Debug)]
101pub struct GraphExecutor {
102    /// Computation graph
103    graph: Graph<GraphNode, GraphEdge>,
104    /// Cached intermediate results
105    cache: HashMap<NodeIndex, DenseTensor>,
106}
107
108impl GraphExecutor {
109    /// Create a new graph executor
110    pub fn new() -> Self {
111        Self {
112            graph: Graph::directed(),
113            cache: HashMap::new(),
114        }
115    }
116
117    /// Add a node to the graph
118    pub fn add_node(&mut self, node: GraphNode) -> NodeIndex {
119        self.graph.add_node(node).unwrap_or(NodeIndex::invalid())
120    }
121
122    /// Add an edge to the graph
123    pub fn add_edge(&mut self, source: NodeIndex, target: NodeIndex, edge: GraphEdge) -> bool {
124        self.graph.add_edge(source, target, edge).is_ok()
125    }
126
127    /// Get number of nodes
128    pub fn num_nodes(&self) -> usize {
129        self.graph.node_count()
130    }
131
132    /// Get number of edges
133    pub fn num_edges(&self) -> usize {
134        self.graph.edge_count()
135    }
136
137    /// Perform topological sort of the graph
138    pub fn topological_sort(&self) -> Vec<NodeIndex> {
139        let mut result = Vec::new();
140        let mut visited = HashSet::new();
141
142        fn visit(
143            node_idx: NodeIndex,
144            graph: &Graph<GraphNode, GraphEdge>,
145            visited: &mut HashSet<NodeIndex>,
146            result: &mut Vec<NodeIndex>,
147        ) {
148            if visited.contains(&node_idx) {
149                return;
150            }
151            visited.insert(node_idx);
152
153            // Visit successors first
154            for neighbor in graph.neighbors(node_idx) {
155                visit(neighbor, graph, visited, result);
156            }
157
158            result.push(node_idx);
159        }
160
161        for node in self.graph.nodes() {
162            visit(node.index(), &self.graph, &mut visited, &mut result);
163        }
164
165        result.reverse();
166        result
167    }
168
169    /// Execute forward pass through the graph
170    ///
171    /// # Arguments
172    /// * `input_ids` - Input token IDs
173    ///
174    /// # Returns
175    /// Output tensor with shape [seq_len, hidden_dim]
176    pub fn forward(&mut self, input_ids: &[usize]) -> DenseTensor {
177        // Clear cache
178        self.cache.clear();
179
180        // Get topological order
181        let order = self.topological_sort();
182
183        // Execute nodes in topological order
184        for node_idx in order {
185            self.execute_node(node_idx, input_ids);
186        }
187
188        // Return output from final node (last layer's FFN output)
189        if let Some(last_node) = self.graph.nodes().last() {
190            if let Some(output) = self.cache.get(&last_node.index()) {
191                return output.clone();
192            }
193        }
194
195        // Fallback: return zeros
196        DenseTensor::zeros(vec![1, 1])
197    }
198
199    /// Execute a single node with input_ids for embedding lookup
200    fn execute_node(&mut self, node_idx: NodeIndex, input_ids: &[usize]) {
201        // Get node data
202        let node = if let Ok(node_ref) = self.graph.get_node(node_idx) {
203            node_ref.clone()
204        } else {
205            return;
206        };
207
208        // Collect input tensors and edge messages from predecessors
209        let mut inputs: Vec<DenseTensor> = Vec::new();
210        let mut edge_messages: Vec<DenseTensor> = Vec::new();
211        let mut edge_weights: Vec<f64> = Vec::new();
212
213        for edge_ref in self.graph.edges() {
214            if edge_ref.target() == node_idx {
215                // Get cached tensor from source node
216                if let Some(source_tensor) = self.cache.get(&edge_ref.source()) {
217                    inputs.push(source_tensor.clone());
218                    
219                    // Get message tensor from edge (Q/K/V projections)
220                    if let Some(msg) = edge_ref.data().message() {
221                        edge_messages.push(msg.clone());
222                    }
223                    
224                    // Get attention weight if available
225                    if let Some(sa) = edge_ref.data().get_self_attention() {
226                        edge_weights.push(sa.weight);
227                    }
228                }
229            }
230        }
231
232        // Execute based on node type
233        match node.node_type {
234            GraphNodeType::TokenEmbedding => {
235                // Token embedding nodes: lookup embedding for token_id
236                if let Some(emb) = &node.token_embedding {
237                    // Use input_ids to get actual embedding values
238                    let position = emb.position;
239                    if position < input_ids.len() {
240                        // Create embedding based on token_id (simplified: use position as index)
241                        let token_id = input_ids.get(position).copied().unwrap_or(0);
242                        let hidden_dim = emb.embedding.shape()[1];
243                        
244                        // Generate embedding: simple hash-based initialization
245                        let emb_data: Vec<f64> = (0..hidden_dim)
246                            .map(|i| {
247                                let seed = (token_id * 1000 + i) as f64;
248                                (seed.sin() * 1000.0).fract()
249                            })
250                            .collect();
251                        
252                        let embedding = DenseTensor::new(emb_data, vec![1, hidden_dim]);
253                        self.cache.insert(node_idx, embedding);
254                    } else {
255                        self.cache.insert(node_idx, emb.embedding.clone());
256                    }
257                }
258            }
259            GraphNodeType::HiddenState => {
260                // Hidden state nodes: aggregate inputs with edge messages
261                if let Some(state) = &node.hidden_state {
262                    if inputs.is_empty() {
263                        self.cache.insert(node_idx, state.state.clone());
264                    } else {
265                        // Sum all inputs, incorporating edge messages (Q/K/V) if available
266                        let mut result = if edge_messages.is_empty() {
267                            inputs[0].clone()
268                        } else {
269                            // Apply Q/K/V projection via matrix multiplication
270                            let qkv = &edge_messages[0];
271                            if qkv.shape() == inputs[0].shape() {
272                                inputs[0].add(qkv)
273                            } else {
274                                inputs[0].clone()
275                            }
276                        };
277
278                        for (i, input) in inputs.iter().enumerate().skip(1) {
279                            let tensor_to_add = if i < edge_messages.len() {
280                                &edge_messages[i]
281                            } else {
282                                input
283                            };
284                            result = result.add(tensor_to_add);
285                        }
286                        self.cache.insert(node_idx, result);
287                    }
288                }
289            }
290            GraphNodeType::AttentionOutput => {
291                // Attention output nodes: weighted sum using attention weights
292                if let Some(attn) = &node.attention_output {
293                    if inputs.is_empty() {
294                        self.cache.insert(node_idx, attn.output.clone());
295                    } else {
296                        // Weighted sum using edge weights or attention node weights
297                        let hidden_dim = attn.output.shape()[1];
298                        let mut result = DenseTensor::zeros(vec![1, hidden_dim]);
299                        
300                        for (i, input) in inputs.iter().enumerate() {
301                            // Get weight from edge or node
302                            let weight = if i < edge_weights.len() {
303                                edge_weights[i]
304                            } else if i < attn.weights.len() {
305                                attn.weights[i]
306                            } else {
307                                1.0 / inputs.len() as f64
308                            };
309                            
310                            // Apply attention weight
311                            let weighted = input.scale(weight);
312                            result = result.add(&weighted);
313                        }
314                        self.cache.insert(node_idx, result);
315                    }
316                }
317            }
318            GraphNodeType::FFNOutput => {
319                // FFN output nodes: apply FFN transformation (simplified: linear + GELU)
320                if let Some(ffn) = &node.ffn_output {
321                    if inputs.is_empty() {
322                        self.cache.insert(node_idx, ffn.output.clone());
323                    } else {
324                        // Aggregate inputs first
325                        let aggregated = if inputs.len() > 1 {
326                            let mut result = inputs[0].clone();
327                            for input in inputs.iter().skip(1) {
328                                result = result.add(input);
329                            }
330                            result
331                        } else {
332                            inputs[0].clone()
333                        };
334                        
335                        // Apply simplified FFN: just pass through with residual
336                        // In full implementation, this would be: GeLU(x @ W1) @ W2
337                        self.cache.insert(node_idx, aggregated);
338                    }
339                }
340            }
341        }
342    }
343
344    /// Prune weak attention edges based on threshold
345    ///
346    /// # Arguments
347    /// * `threshold` - Attention weight threshold for pruning
348    pub fn prune_weak_edges(&mut self, threshold: f64) -> usize {
349        let mut pruned_count = 0;
350
351        // Collect edges to prune
352        let edges_to_prune: Vec<_> = self.graph.edges()
353            .filter(|edge_ref| {
354                if let GraphEdgeType::SelfAttention = edge_ref.data().edge_type {
355                    if let Some(sa) = &edge_ref.data().self_attention {
356                        return sa.weight < threshold;
357                    }
358                }
359                false
360            })
361            .map(|edge_ref| edge_ref.index())
362            .collect();
363
364        // Prune edges
365        for edge_idx in edges_to_prune {
366            if self.graph.remove_edge(edge_idx).is_ok() {
367                pruned_count += 1;
368            }
369        }
370
371        pruned_count
372    }
373
374    /// Export graph to DOT format for visualization
375    pub fn to_dot(&self) -> String {
376        let mut dot = String::from("digraph Transformer {\n");
377        dot.push_str("    rankdir=TB;\n");
378        dot.push_str("    node [shape=box];\n\n");
379
380        // Add nodes
381        for node in self.graph.nodes() {
382            let label = match node.data.node_type {
383                GraphNodeType::TokenEmbedding => format!("TokenEmbed[{}]", node.data.position),
384                GraphNodeType::HiddenState => format!("Hidden[L{}P{}]", node.data.layer, node.data.position),
385                GraphNodeType::AttentionOutput => format!("Attn[L{}H{}]", node.data.layer, 
386                    node.data.attention_output.as_ref().map(|a| a.head).unwrap_or(0)),
387                GraphNodeType::FFNOutput => format!("FFN[L{}P{}]", node.data.layer, node.data.position),
388            };
389            dot.push_str(&format!("    n{} [label=\"{}\"];\n", node.index().index(), label));
390        }
391
392        dot.push('\n');
393
394        // Add edges
395        for edge in self.graph.edges() {
396            let style = match edge.data().edge_type {
397                GraphEdgeType::SelfAttention => "style=solid, color=blue",
398                GraphEdgeType::DataFlow => "style=solid, color=green",
399                GraphEdgeType::Residual => "style=dashed, color=red",
400            };
401            dot.push_str(&format!("    n{} -> n{} [{}];\n", 
402                edge.source().index(), edge.target().index(), style));
403        }
404
405        dot.push('}');
406        dot
407    }
408
409    /// Clear the graph and cache
410    pub fn clear(&mut self) {
411        self.graph = Graph::directed();
412        self.cache.clear();
413    }
414}
415
416impl Default for GraphExecutor {
417    fn default() -> Self {
418        Self::new()
419    }
420}
421
422/// Graph-structured Transformer wrapper
423#[derive(Debug)]
424pub struct GraphTransformer {
425    /// Graph executor
426    executor: GraphExecutor,
427    /// Number of layers
428    num_layers: usize,
429    /// Number of attention heads
430    num_heads: usize,
431    /// Hidden dimension
432    hidden_dim: usize,
433}
434
435impl GraphTransformer {
436    /// Create a new graph transformer
437    pub fn new(num_layers: usize, num_heads: usize, hidden_dim: usize) -> Self {
438        Self {
439            executor: GraphExecutor::new(),
440            num_layers,
441            num_heads,
442            hidden_dim,
443        }
444    }
445
446    /// Build graph structure from input
447    ///
448    /// # Arguments
449    /// * `input_ids` - Input token IDs
450    pub fn build_graph(&mut self, input_ids: &[usize]) {
451        let seq_len = input_ids.len();
452        let head_dim = self.hidden_dim / self.num_heads;
453
454        // Create token embedding nodes
455        let mut embedding_nodes = Vec::new();
456        for (i, &token_id) in input_ids.iter().enumerate() {
457            let embedding = DenseTensor::zeros(vec![1, self.hidden_dim]);
458            let node = GraphNode::token_embedding(i, token_id, i, embedding);
459            let node_idx = self.executor.add_node(node);
460            embedding_nodes.push(node_idx);
461        }
462
463        // Create layer-wise graph structure
464        let mut prev_layer_nodes = embedding_nodes;
465
466        for layer in 0..self.num_layers {
467            let mut current_layer_nodes = Vec::new();
468
469            // Create attention nodes for each position
470            for pos in 0..seq_len {
471                // Create attention output node
472                let attended_positions: Vec<usize> = (0..seq_len).collect();
473                let weights = vec![1.0 / seq_len as f64; seq_len];
474                let output = DenseTensor::zeros(vec![1, self.hidden_dim]);
475
476                let attn_node = GraphNode::attention_output(
477                    pos,
478                    layer,
479                    0,
480                    pos,
481                    attended_positions.clone(),
482                    weights.clone(),
483                    output,
484                );
485                let attn_node_idx = self.executor.add_node(attn_node);
486                current_layer_nodes.push(attn_node_idx);
487
488                // Add self-attention edges with tensor messages from previous positions
489                for (src_pos, &src_node) in prev_layer_nodes.iter().enumerate() {
490                    let weight = weights.get(src_pos).copied().unwrap_or(0.0);
491                    // Create message tensor (Q/K/V projection placeholder)
492                    let message = DenseTensor::zeros(vec![1, head_dim]);
493                    let edge = GraphEdge::self_attention_with_message(
494                        src_node.index(),
495                        attn_node_idx.index(),
496                        weight,
497                        0,
498                        layer,
499                        message,
500                    );
501                    self.executor.add_edge(src_node, attn_node_idx, edge);
502                }
503
504                // Add residual connection with tensor
505                if let Some(&prev_node) = prev_layer_nodes.get(pos) {
506                    let residual_tensor = DenseTensor::zeros(vec![1, self.hidden_dim]);
507                    let residual_edge = GraphEdge::residual_with_tensor(
508                        prev_node.index(),
509                        attn_node_idx.index(),
510                        layer,
511                        SkipType::PreNorm,
512                        residual_tensor,
513                    );
514                    self.executor.add_edge(prev_node, attn_node_idx, residual_edge);
515                }
516            }
517
518            // Create FFN nodes
519            let mut ffn_nodes = Vec::new();
520            for (pos, &attn_node) in current_layer_nodes.iter().enumerate() {
521                let output = DenseTensor::zeros(vec![1, self.hidden_dim]);
522                let ffn_node = GraphNode::ffn_output(pos, layer, pos, output);
523                let ffn_node_idx = self.executor.add_node(ffn_node);
524                ffn_nodes.push(ffn_node_idx);
525
526                // Add data flow edge with message tensor from attention to FFN
527                let message = DenseTensor::zeros(vec![1, self.hidden_dim]);
528                let edge = GraphEdge::data_flow_with_message(
529                    attn_node.index(),
530                    ffn_node_idx.index(),
531                    DataFlowOp::AttentionToOutput,
532                    layer,
533                    message,
534                );
535                self.executor.add_edge(attn_node, ffn_node_idx, edge);
536
537                // Add residual connection with tensor
538                let residual_tensor = DenseTensor::zeros(vec![1, self.hidden_dim]);
539                let residual_edge = GraphEdge::residual_with_tensor(
540                    attn_node.index(),
541                    ffn_node_idx.index(),
542                    layer,
543                    SkipType::PostNorm,
544                    residual_tensor,
545                );
546                self.executor.add_edge(attn_node, ffn_node_idx, residual_edge);
547            }
548
549            prev_layer_nodes = ffn_nodes;
550        }
551    }
552
553    /// Run forward pass
554    pub fn forward(&mut self, input_ids: &[usize]) -> DenseTensor {
555        self.executor.forward(input_ids)
556    }
557
558    /// Get number of nodes in graph
559    pub fn num_nodes(&self) -> usize {
560        self.executor.num_nodes()
561    }
562
563    /// Get number of edges in graph
564    pub fn num_edges(&self) -> usize {
565        self.executor.num_edges()
566    }
567
568    /// Prune weak attention edges
569    pub fn prune_weak_edges(&mut self, threshold: f64) -> usize {
570        self.executor.prune_weak_edges(threshold)
571    }
572
573    /// Export to DOT format
574    pub fn to_dot(&self) -> String {
575        self.executor.to_dot()
576    }
577}
578
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh executor starts empty.
    #[test]
    fn test_graph_executor_creation() {
        let executor = GraphExecutor::new();
        assert_eq!(executor.num_nodes(), 0);
        assert_eq!(executor.num_edges(), 0);
    }

    /// Adding a node increments the count and yields a valid index.
    #[test]
    fn test_graph_executor_add_node() {
        let mut executor = GraphExecutor::new();
        let embedding = DenseTensor::zeros(vec![1, 4]);
        let node = GraphNode::token_embedding(0, 10, 0, embedding);
        let node_idx = executor.add_node(node);

        assert_eq!(executor.num_nodes(), 1);
        assert!(node_idx.is_valid());
    }

    /// Adding an edge between two existing nodes succeeds.
    #[test]
    fn test_graph_executor_add_edge() {
        let mut executor = GraphExecutor::new();

        let embedding1 = DenseTensor::zeros(vec![1, 4]);
        let node1 = GraphNode::token_embedding(0, 10, 0, embedding1);
        let node1_idx = executor.add_node(node1);

        let embedding2 = DenseTensor::zeros(vec![1, 4]);
        let node2 = GraphNode::token_embedding(1, 20, 1, embedding2);
        let node2_idx = executor.add_node(node2);

        let edge = GraphEdge::self_attention(node1_idx.index(), node2_idx.index(), 0.5, 0, 0);
        let result = executor.add_edge(node1_idx, node2_idx, edge);

        assert!(result);
        assert_eq!(executor.num_edges(), 1);
    }

    /// Topological sort respects edge direction on a simple chain.
    #[test]
    fn test_topological_sort() {
        let mut executor = GraphExecutor::new();

        // Create a simple chain: A -> B -> C
        let node_a = GraphNode::token_embedding(0, 1, 0, DenseTensor::zeros(vec![1, 4]));
        let node_b = GraphNode::hidden_state(1, 0, 0, DenseTensor::zeros(vec![1, 4]));
        let node_c = GraphNode::ffn_output(2, 0, 0, DenseTensor::zeros(vec![1, 4]));

        let idx_a = executor.add_node(node_a);
        let idx_b = executor.add_node(node_b);
        let idx_c = executor.add_node(node_c);

        executor.add_edge(idx_a, idx_b, GraphEdge::data_flow(idx_a.index(), idx_b.index(), DataFlowOp::InputToAttention, 0));
        executor.add_edge(idx_b, idx_c, GraphEdge::data_flow(idx_b.index(), idx_c.index(), DataFlowOp::AttentionToOutput, 0));

        let order = executor.topological_sort();

        // A must precede B, and B must precede C.
        assert!(order.iter().position(|&x| x == idx_a).unwrap() < order.iter().position(|&x| x == idx_b).unwrap());
        assert!(order.iter().position(|&x| x == idx_b).unwrap() < order.iter().position(|&x| x == idx_c).unwrap());
    }

    /// Constructor stores the hyperparameters verbatim.
    #[test]
    fn test_graph_transformer_creation() {
        let transformer = GraphTransformer::new(2, 4, 256);

        assert_eq!(transformer.num_layers, 2);
        assert_eq!(transformer.num_heads, 4);
        assert_eq!(transformer.hidden_dim, 256);
    }

    /// Building the graph produces a non-empty node and edge set.
    #[test]
    fn test_graph_transformer_build() {
        let mut transformer = GraphTransformer::new(2, 4, 256);
        let input_ids = vec![1, 2, 3, 4];

        transformer.build_graph(&input_ids);

        assert!(transformer.num_nodes() > 0);
        assert!(transformer.num_edges() > 0);
    }

    /// DOT export contains the digraph header, both nodes, and the edge.
    #[test]
    fn test_to_dot_export() {
        let mut executor = GraphExecutor::new();

        let node1 = GraphNode::token_embedding(0, 1, 0, DenseTensor::zeros(vec![1, 4]));
        let node2 = GraphNode::hidden_state(1, 0, 0, DenseTensor::zeros(vec![1, 4]));

        let idx1 = executor.add_node(node1);
        let idx2 = executor.add_node(node2);
        executor.add_edge(idx1, idx2, GraphEdge::data_flow(idx1.index(), idx2.index(), DataFlowOp::InputToAttention, 0));

        let dot = executor.to_dot();

        assert!(dot.contains("digraph Transformer"));
        assert!(dot.contains("n0"));
        assert!(dot.contains("n1"));
        assert!(dot.contains("n0 -> n1"));
    }
}