1use std::collections::{HashMap, HashSet};
91use crate::graph::Graph;
92use crate::graph::traits::{GraphBase, GraphOps, GraphQuery};
93use crate::node::NodeIndex;
94use crate::tensor::DenseTensor;
95use crate::tensor::traits::{TensorOps, TensorBase};
96use super::nodes::{GraphNode, GraphNodeType};
97use super::edges::{GraphEdge, GraphEdgeType, DataFlowOp, SkipType};
98
99#[derive(Debug)]
101pub struct GraphExecutor {
102 graph: Graph<GraphNode, GraphEdge>,
104 cache: HashMap<NodeIndex, DenseTensor>,
106}
107
108impl GraphExecutor {
109 pub fn new() -> Self {
111 Self {
112 graph: Graph::directed(),
113 cache: HashMap::new(),
114 }
115 }
116
117 pub fn add_node(&mut self, node: GraphNode) -> NodeIndex {
119 self.graph.add_node(node).unwrap_or(NodeIndex::invalid())
120 }
121
122 pub fn add_edge(&mut self, source: NodeIndex, target: NodeIndex, edge: GraphEdge) -> bool {
124 self.graph.add_edge(source, target, edge).is_ok()
125 }
126
127 pub fn num_nodes(&self) -> usize {
129 self.graph.node_count()
130 }
131
132 pub fn num_edges(&self) -> usize {
134 self.graph.edge_count()
135 }
136
137 pub fn topological_sort(&self) -> Vec<NodeIndex> {
139 let mut result = Vec::new();
140 let mut visited = HashSet::new();
141
142 fn visit(
143 node_idx: NodeIndex,
144 graph: &Graph<GraphNode, GraphEdge>,
145 visited: &mut HashSet<NodeIndex>,
146 result: &mut Vec<NodeIndex>,
147 ) {
148 if visited.contains(&node_idx) {
149 return;
150 }
151 visited.insert(node_idx);
152
153 for neighbor in graph.neighbors(node_idx) {
155 visit(neighbor, graph, visited, result);
156 }
157
158 result.push(node_idx);
159 }
160
161 for node in self.graph.nodes() {
162 visit(node.index(), &self.graph, &mut visited, &mut result);
163 }
164
165 result.reverse();
166 result
167 }
168
169 pub fn forward(&mut self, input_ids: &[usize]) -> DenseTensor {
177 self.cache.clear();
179
180 let order = self.topological_sort();
182
183 for node_idx in order {
185 self.execute_node(node_idx, input_ids);
186 }
187
188 if let Some(last_node) = self.graph.nodes().last() {
190 if let Some(output) = self.cache.get(&last_node.index()) {
191 return output.clone();
192 }
193 }
194
195 DenseTensor::zeros(vec![1, 1])
197 }
198
199 fn execute_node(&mut self, node_idx: NodeIndex, input_ids: &[usize]) {
201 let node = if let Ok(node_ref) = self.graph.get_node(node_idx) {
203 node_ref.clone()
204 } else {
205 return;
206 };
207
208 let mut inputs: Vec<DenseTensor> = Vec::new();
210 let mut edge_messages: Vec<DenseTensor> = Vec::new();
211 let mut edge_weights: Vec<f64> = Vec::new();
212
213 for edge_ref in self.graph.edges() {
214 if edge_ref.target() == node_idx {
215 if let Some(source_tensor) = self.cache.get(&edge_ref.source()) {
217 inputs.push(source_tensor.clone());
218
219 if let Some(msg) = edge_ref.data().message() {
221 edge_messages.push(msg.clone());
222 }
223
224 if let Some(sa) = edge_ref.data().get_self_attention() {
226 edge_weights.push(sa.weight);
227 }
228 }
229 }
230 }
231
232 match node.node_type {
234 GraphNodeType::TokenEmbedding => {
235 if let Some(emb) = &node.token_embedding {
237 let position = emb.position;
239 if position < input_ids.len() {
240 let token_id = input_ids.get(position).copied().unwrap_or(0);
242 let hidden_dim = emb.embedding.shape()[1];
243
244 let emb_data: Vec<f64> = (0..hidden_dim)
246 .map(|i| {
247 let seed = (token_id * 1000 + i) as f64;
248 (seed.sin() * 1000.0).fract()
249 })
250 .collect();
251
252 let embedding = DenseTensor::new(emb_data, vec![1, hidden_dim]);
253 self.cache.insert(node_idx, embedding);
254 } else {
255 self.cache.insert(node_idx, emb.embedding.clone());
256 }
257 }
258 }
259 GraphNodeType::HiddenState => {
260 if let Some(state) = &node.hidden_state {
262 if inputs.is_empty() {
263 self.cache.insert(node_idx, state.state.clone());
264 } else {
265 let mut result = if edge_messages.is_empty() {
267 inputs[0].clone()
268 } else {
269 let qkv = &edge_messages[0];
271 if qkv.shape() == inputs[0].shape() {
272 inputs[0].add(qkv)
273 } else {
274 inputs[0].clone()
275 }
276 };
277
278 for (i, input) in inputs.iter().enumerate().skip(1) {
279 let tensor_to_add = if i < edge_messages.len() {
280 &edge_messages[i]
281 } else {
282 input
283 };
284 result = result.add(tensor_to_add);
285 }
286 self.cache.insert(node_idx, result);
287 }
288 }
289 }
290 GraphNodeType::AttentionOutput => {
291 if let Some(attn) = &node.attention_output {
293 if inputs.is_empty() {
294 self.cache.insert(node_idx, attn.output.clone());
295 } else {
296 let hidden_dim = attn.output.shape()[1];
298 let mut result = DenseTensor::zeros(vec![1, hidden_dim]);
299
300 for (i, input) in inputs.iter().enumerate() {
301 let weight = if i < edge_weights.len() {
303 edge_weights[i]
304 } else if i < attn.weights.len() {
305 attn.weights[i]
306 } else {
307 1.0 / inputs.len() as f64
308 };
309
310 let weighted = input.scale(weight);
312 result = result.add(&weighted);
313 }
314 self.cache.insert(node_idx, result);
315 }
316 }
317 }
318 GraphNodeType::FFNOutput => {
319 if let Some(ffn) = &node.ffn_output {
321 if inputs.is_empty() {
322 self.cache.insert(node_idx, ffn.output.clone());
323 } else {
324 let aggregated = if inputs.len() > 1 {
326 let mut result = inputs[0].clone();
327 for input in inputs.iter().skip(1) {
328 result = result.add(input);
329 }
330 result
331 } else {
332 inputs[0].clone()
333 };
334
335 self.cache.insert(node_idx, aggregated);
338 }
339 }
340 }
341 }
342 }
343
344 pub fn prune_weak_edges(&mut self, threshold: f64) -> usize {
349 let mut pruned_count = 0;
350
351 let edges_to_prune: Vec<_> = self.graph.edges()
353 .filter(|edge_ref| {
354 if let GraphEdgeType::SelfAttention = edge_ref.data().edge_type {
355 if let Some(sa) = &edge_ref.data().self_attention {
356 return sa.weight < threshold;
357 }
358 }
359 false
360 })
361 .map(|edge_ref| edge_ref.index())
362 .collect();
363
364 for edge_idx in edges_to_prune {
366 if self.graph.remove_edge(edge_idx).is_ok() {
367 pruned_count += 1;
368 }
369 }
370
371 pruned_count
372 }
373
374 pub fn to_dot(&self) -> String {
376 let mut dot = String::from("digraph Transformer {\n");
377 dot.push_str(" rankdir=TB;\n");
378 dot.push_str(" node [shape=box];\n\n");
379
380 for node in self.graph.nodes() {
382 let label = match node.data.node_type {
383 GraphNodeType::TokenEmbedding => format!("TokenEmbed[{}]", node.data.position),
384 GraphNodeType::HiddenState => format!("Hidden[L{}P{}]", node.data.layer, node.data.position),
385 GraphNodeType::AttentionOutput => format!("Attn[L{}H{}]", node.data.layer,
386 node.data.attention_output.as_ref().map(|a| a.head).unwrap_or(0)),
387 GraphNodeType::FFNOutput => format!("FFN[L{}P{}]", node.data.layer, node.data.position),
388 };
389 dot.push_str(&format!(" n{} [label=\"{}\"];\n", node.index().index(), label));
390 }
391
392 dot.push('\n');
393
394 for edge in self.graph.edges() {
396 let style = match edge.data().edge_type {
397 GraphEdgeType::SelfAttention => "style=solid, color=blue",
398 GraphEdgeType::DataFlow => "style=solid, color=green",
399 GraphEdgeType::Residual => "style=dashed, color=red",
400 };
401 dot.push_str(&format!(" n{} -> n{} [{}];\n",
402 edge.source().index(), edge.target().index(), style));
403 }
404
405 dot.push('}');
406 dot
407 }
408
409 pub fn clear(&mut self) {
411 self.graph = Graph::directed();
412 self.cache.clear();
413 }
414}
415
416impl Default for GraphExecutor {
417 fn default() -> Self {
418 Self::new()
419 }
420}
421
422#[derive(Debug)]
424pub struct GraphTransformer {
425 executor: GraphExecutor,
427 num_layers: usize,
429 num_heads: usize,
431 hidden_dim: usize,
433}
434
435impl GraphTransformer {
436 pub fn new(num_layers: usize, num_heads: usize, hidden_dim: usize) -> Self {
438 Self {
439 executor: GraphExecutor::new(),
440 num_layers,
441 num_heads,
442 hidden_dim,
443 }
444 }
445
446 pub fn build_graph(&mut self, input_ids: &[usize]) {
451 let seq_len = input_ids.len();
452 let head_dim = self.hidden_dim / self.num_heads;
453
454 let mut embedding_nodes = Vec::new();
456 for (i, &token_id) in input_ids.iter().enumerate() {
457 let embedding = DenseTensor::zeros(vec![1, self.hidden_dim]);
458 let node = GraphNode::token_embedding(i, token_id, i, embedding);
459 let node_idx = self.executor.add_node(node);
460 embedding_nodes.push(node_idx);
461 }
462
463 let mut prev_layer_nodes = embedding_nodes;
465
466 for layer in 0..self.num_layers {
467 let mut current_layer_nodes = Vec::new();
468
469 for pos in 0..seq_len {
471 let attended_positions: Vec<usize> = (0..seq_len).collect();
473 let weights = vec![1.0 / seq_len as f64; seq_len];
474 let output = DenseTensor::zeros(vec![1, self.hidden_dim]);
475
476 let attn_node = GraphNode::attention_output(
477 pos,
478 layer,
479 0,
480 pos,
481 attended_positions.clone(),
482 weights.clone(),
483 output,
484 );
485 let attn_node_idx = self.executor.add_node(attn_node);
486 current_layer_nodes.push(attn_node_idx);
487
488 for (src_pos, &src_node) in prev_layer_nodes.iter().enumerate() {
490 let weight = weights.get(src_pos).copied().unwrap_or(0.0);
491 let message = DenseTensor::zeros(vec![1, head_dim]);
493 let edge = GraphEdge::self_attention_with_message(
494 src_node.index(),
495 attn_node_idx.index(),
496 weight,
497 0,
498 layer,
499 message,
500 );
501 self.executor.add_edge(src_node, attn_node_idx, edge);
502 }
503
504 if let Some(&prev_node) = prev_layer_nodes.get(pos) {
506 let residual_tensor = DenseTensor::zeros(vec![1, self.hidden_dim]);
507 let residual_edge = GraphEdge::residual_with_tensor(
508 prev_node.index(),
509 attn_node_idx.index(),
510 layer,
511 SkipType::PreNorm,
512 residual_tensor,
513 );
514 self.executor.add_edge(prev_node, attn_node_idx, residual_edge);
515 }
516 }
517
518 let mut ffn_nodes = Vec::new();
520 for (pos, &attn_node) in current_layer_nodes.iter().enumerate() {
521 let output = DenseTensor::zeros(vec![1, self.hidden_dim]);
522 let ffn_node = GraphNode::ffn_output(pos, layer, pos, output);
523 let ffn_node_idx = self.executor.add_node(ffn_node);
524 ffn_nodes.push(ffn_node_idx);
525
526 let message = DenseTensor::zeros(vec![1, self.hidden_dim]);
528 let edge = GraphEdge::data_flow_with_message(
529 attn_node.index(),
530 ffn_node_idx.index(),
531 DataFlowOp::AttentionToOutput,
532 layer,
533 message,
534 );
535 self.executor.add_edge(attn_node, ffn_node_idx, edge);
536
537 let residual_tensor = DenseTensor::zeros(vec![1, self.hidden_dim]);
539 let residual_edge = GraphEdge::residual_with_tensor(
540 attn_node.index(),
541 ffn_node_idx.index(),
542 layer,
543 SkipType::PostNorm,
544 residual_tensor,
545 );
546 self.executor.add_edge(attn_node, ffn_node_idx, residual_edge);
547 }
548
549 prev_layer_nodes = ffn_nodes;
550 }
551 }
552
553 pub fn forward(&mut self, input_ids: &[usize]) -> DenseTensor {
555 self.executor.forward(input_ids)
556 }
557
558 pub fn num_nodes(&self) -> usize {
560 self.executor.num_nodes()
561 }
562
563 pub fn num_edges(&self) -> usize {
565 self.executor.num_edges()
566 }
567
568 pub fn prune_weak_edges(&mut self, threshold: f64) -> usize {
570 self.executor.prune_weak_edges(threshold)
571 }
572
573 pub fn to_dot(&self) -> String {
575 self.executor.to_dot()
576 }
577}
578
#[cfg(test)]
mod tests {
    use super::*;

    /// A `1 x 4` zero tensor, the payload used by every node in these tests.
    fn zeros4() -> DenseTensor {
        DenseTensor::zeros(vec![1, 4])
    }

    /// A fresh executor holds no nodes and no edges.
    #[test]
    fn test_graph_executor_creation() {
        let executor = GraphExecutor::new();
        assert_eq!(executor.num_nodes(), 0);
        assert_eq!(executor.num_edges(), 0);
    }

    /// Adding one node yields a valid index and a node count of one.
    #[test]
    fn test_graph_executor_add_node() {
        let mut executor = GraphExecutor::new();
        let node_idx = executor.add_node(GraphNode::token_embedding(0, 10, 0, zeros4()));

        assert_eq!(executor.num_nodes(), 1);
        assert!(node_idx.is_valid());
    }

    /// An edge between two existing nodes is accepted and counted.
    #[test]
    fn test_graph_executor_add_edge() {
        let mut executor = GraphExecutor::new();

        let first = executor.add_node(GraphNode::token_embedding(0, 10, 0, zeros4()));
        let second = executor.add_node(GraphNode::token_embedding(1, 20, 1, zeros4()));

        let edge = GraphEdge::self_attention(first.index(), second.index(), 0.5, 0, 0);
        let accepted = executor.add_edge(first, second, edge);

        assert!(accepted);
        assert_eq!(executor.num_edges(), 1);
    }

    /// In the chain a -> b -> c the sort places a before b and b before c.
    #[test]
    fn test_topological_sort() {
        let mut executor = GraphExecutor::new();

        let idx_a = executor.add_node(GraphNode::token_embedding(0, 1, 0, zeros4()));
        let idx_b = executor.add_node(GraphNode::hidden_state(1, 0, 0, zeros4()));
        let idx_c = executor.add_node(GraphNode::ffn_output(2, 0, 0, zeros4()));

        let a_to_b =
            GraphEdge::data_flow(idx_a.index(), idx_b.index(), DataFlowOp::InputToAttention, 0);
        executor.add_edge(idx_a, idx_b, a_to_b);
        let b_to_c =
            GraphEdge::data_flow(idx_b.index(), idx_c.index(), DataFlowOp::AttentionToOutput, 0);
        executor.add_edge(idx_b, idx_c, b_to_c);

        let order = executor.topological_sort();
        let rank_of = |wanted: NodeIndex| order.iter().position(|&seen| seen == wanted).unwrap();

        assert!(rank_of(idx_a) < rank_of(idx_b));
        assert!(rank_of(idx_b) < rank_of(idx_c));
    }

    /// The constructor stores its three hyper-parameters unchanged.
    #[test]
    fn test_graph_transformer_creation() {
        let transformer = GraphTransformer::new(2, 4, 256);

        assert_eq!(transformer.num_layers, 2);
        assert_eq!(transformer.num_heads, 4);
        assert_eq!(transformer.hidden_dim, 256);
    }

    /// Building from a non-empty input produces both nodes and edges.
    #[test]
    fn test_graph_transformer_build() {
        let mut transformer = GraphTransformer::new(2, 4, 256);
        let input_ids = vec![1, 2, 3, 4];

        transformer.build_graph(&input_ids);

        assert!(transformer.num_nodes() > 0);
        assert!(transformer.num_edges() > 0);
    }

    /// The DOT export names the graph, both nodes and the edge between them.
    #[test]
    fn test_to_dot_export() {
        let mut executor = GraphExecutor::new();

        let idx1 = executor.add_node(GraphNode::token_embedding(0, 1, 0, zeros4()));
        let idx2 = executor.add_node(GraphNode::hidden_state(1, 0, 0, zeros4()));
        let edge =
            GraphEdge::data_flow(idx1.index(), idx2.index(), DataFlowOp::InputToAttention, 0);
        executor.add_edge(idx1, idx2, edge);

        let dot = executor.to_dot();

        assert!(dot.contains("digraph Transformer"));
        assert!(dot.contains("n0"));
        assert!(dot.contains("n1"));
        assert!(dot.contains("n0 -> n1"));
    }
}