use std::collections::{HashMap, HashSet};
use crate::graph::Graph;
use crate::graph::traits::{GraphBase, GraphOps, GraphQuery};
use crate::node::NodeIndex;
use crate::tensor::DenseTensor;
use crate::tensor::traits::{TensorOps, TensorBase};
use super::nodes::{GraphNode, GraphNodeType};
use super::edges::{GraphEdge, GraphEdgeType, DataFlowOp, SkipType};
/// Runs a forward pass over a directed graph of [`GraphNode`]s connected by
/// [`GraphEdge`]s, storing one output tensor per executed node.
#[derive(Debug)]
pub struct GraphExecutor {
/// The computation graph; created with `Graph::directed()`.
graph: Graph<GraphNode, GraphEdge>,
/// Output tensor of every node executed so far in the current forward pass.
/// Cleared at the start of each `forward` call and by `clear`.
cache: HashMap<NodeIndex, DenseTensor>,
}
impl GraphExecutor {
/// Creates an executor with an empty directed graph and an empty tensor cache.
pub fn new() -> Self {
    let graph = Graph::directed();
    let cache = HashMap::new();
    Self { graph, cache }
}
/// Inserts `node` into the graph and returns its index.
///
/// If the underlying graph rejects the insertion, `NodeIndex::invalid()` is
/// returned instead of an error.
pub fn add_node(&mut self, node: GraphNode) -> NodeIndex {
    let inserted = self.graph.add_node(node);
    inserted.unwrap_or(NodeIndex::invalid())
}
/// Connects `source` to `target` with `edge`.
///
/// Returns `true` when the underlying graph accepted the edge and `false`
/// otherwise; the concrete error is discarded.
pub fn add_edge(&mut self, source: NodeIndex, target: NodeIndex, edge: GraphEdge) -> bool {
    let outcome = self.graph.add_edge(source, target, edge);
    outcome.is_ok()
}
/// Returns the number of nodes currently stored in the graph.
pub fn num_nodes(&self) -> usize {
    let graph = &self.graph;
    graph.node_count()
}
/// Returns the number of edges currently stored in the graph.
pub fn num_edges(&self) -> usize {
    let graph = &self.graph;
    graph.edge_count()
}
/// Returns the node indices in reverse depth-first post-order.
///
/// For an acyclic graph this is a topological order: every node appears
/// before the nodes reachable through `neighbors`. Cycles are not detected;
/// the `visited` set only guarantees that the traversal terminates and that
/// each node is emitted once.
///
/// The traversal uses an explicit stack rather than recursion, so a long
/// chain of nodes cannot overflow the native call stack. Roots are taken in
/// the order yielded by `Graph::nodes` and successors in the order yielded
/// by `Graph::neighbors`, so the result is the same as that of a recursive
/// DFS over the same iterators.
pub fn topological_sort(&self) -> Vec<NodeIndex> {
    let mut post_order = Vec::new();
    let mut visited = HashSet::new();
    // Each frame is a node together with the successors still to be explored.
    let mut stack: Vec<(NodeIndex, std::vec::IntoIter<NodeIndex>)> = Vec::new();

    for root in self.graph.nodes() {
        let root_idx = root.index();
        // `insert` returns false when the node was already visited.
        if !visited.insert(root_idx) {
            continue;
        }
        let successors: Vec<NodeIndex> = self.graph.neighbors(root_idx).collect();
        stack.push((root_idx, successors.into_iter()));

        while let Some(frame) = stack.last_mut() {
            let current = frame.0;
            let next = frame.1.next();
            match next {
                Some(child) => {
                    if visited.insert(child) {
                        let successors: Vec<NodeIndex> =
                            self.graph.neighbors(child).collect();
                        stack.push((child, successors.into_iter()));
                    }
                }
                None => {
                    // All successors handled: emit the node in post-order.
                    post_order.push(current);
                    stack.pop();
                }
            }
        }
    }

    post_order.reverse();
    post_order
}
/// Executes every node in topological order and returns the output tensor.
///
/// The cache is emptied first, then each node is run through `execute_node`
/// with `input_ids`. The result is a clone of the cached tensor of the last
/// node yielded by `Graph::nodes`. If the graph has no nodes, or that node
/// produced no cache entry, a `1 x 1` zero tensor is returned.
pub fn forward(&mut self, input_ids: &[usize]) -> DenseTensor {
    self.cache.clear();

    for node_idx in self.topological_sort() {
        self.execute_node(node_idx, input_ids);
    }

    self.graph
        .nodes()
        .last()
        .and_then(|final_node| self.cache.get(&final_node.index()))
        .cloned()
        .unwrap_or_else(|| DenseTensor::zeros(vec![1, 1]))
}
/// Computes the output tensor of `node_idx` and stores it in the cache.
///
/// Does nothing if the node cannot be looked up, or if the payload matching
/// the node's type (`token_embedding`, `hidden_state`, ...) is absent.
///
/// Inputs are gathered from every edge whose target is `node_idx` and whose
/// source already has a cached tensor. Each input keeps its own edge's
/// optional message and optional self-attention weight, so a weight or
/// message is only ever applied to the tensor that arrived over that edge.
///
/// Behaviour per node type:
/// * `TokenEmbedding`: if the node's position is inside `input_ids`, a
///   deterministic pseudo-embedding is generated from the token id;
///   otherwise the stored embedding is used.
/// * `HiddenState`: with no inputs the stored state is used. Otherwise the
///   first input (plus its message when shapes match) is the starting value,
///   and for every further input its message, or the input itself when the
///   edge carries no message, is added.
/// * `AttentionOutput`: with no inputs the stored output is used. Otherwise
///   a weighted sum of the inputs, where the weight is the edge's
///   self-attention weight, else `attn.weights[i]`, else `1 / input count`.
/// * `FFNOutput`: with no inputs the stored output is used; otherwise the
///   plain sum of all inputs.
fn execute_node(&mut self, node_idx: NodeIndex, input_ids: &[usize]) {
    /// Everything that arrives over one incoming edge.
    struct Incoming {
        value: DenseTensor,
        message: Option<DenseTensor>,
        weight: Option<f64>,
    }

    let node = match self.graph.get_node(node_idx) {
        Ok(node_ref) => node_ref.clone(),
        Err(_) => return,
    };

    // Scan all edges for the ones that end at this node. Sources without a
    // cached tensor (not executed yet, or produced nothing) are skipped.
    let mut incoming: Vec<Incoming> = Vec::new();
    for edge_ref in self.graph.edges() {
        if edge_ref.target() != node_idx {
            continue;
        }
        if let Some(source_tensor) = self.cache.get(&edge_ref.source()) {
            let message = edge_ref.data().message().map(|msg| msg.clone());
            let weight = edge_ref.data().get_self_attention().map(|sa| sa.weight);
            incoming.push(Incoming {
                value: source_tensor.clone(),
                message,
                weight,
            });
        }
    }

    match node.node_type {
        GraphNodeType::TokenEmbedding => {
            if let Some(emb) = &node.token_embedding {
                let embedding = match input_ids.get(emb.position) {
                    Some(&token_id) => {
                        // Indexes dimension 1, so the stored embedding is
                        // assumed to have at least two dimensions.
                        let hidden_dim = emb.embedding.shape()[1];
                        let emb_data: Vec<f64> = (0..hidden_dim)
                            .map(|i| {
                                // Wrapping arithmetic: a huge token id must
                                // not panic in debug builds.
                                let seed = token_id.wrapping_mul(1000).wrapping_add(i) as f64;
                                (seed.sin() * 1000.0).fract()
                            })
                            .collect();
                        DenseTensor::new(emb_data, vec![1, hidden_dim])
                    }
                    // Position outside the input: fall back to the stored tensor.
                    None => emb.embedding.clone(),
                };
                self.cache.insert(node_idx, embedding);
            }
        }
        GraphNodeType::HiddenState => {
            if let Some(state) = &node.hidden_state {
                let result = match incoming.split_first() {
                    None => state.state.clone(),
                    Some((first, rest)) => {
                        let mut acc = match &first.message {
                            Some(msg) if msg.shape() == first.value.shape() => {
                                first.value.add(msg)
                            }
                            _ => first.value.clone(),
                        };
                        // NOTE(review): only the first message is shape-checked;
                        // later messages are added as-is - confirm this is intended.
                        for input in rest {
                            acc = acc.add(input.message.as_ref().unwrap_or(&input.value));
                        }
                        acc
                    }
                };
                self.cache.insert(node_idx, result);
            }
        }
        GraphNodeType::AttentionOutput => {
            if let Some(attn) = &node.attention_output {
                if incoming.is_empty() {
                    self.cache.insert(node_idx, attn.output.clone());
                } else {
                    let hidden_dim = attn.output.shape()[1];
                    let uniform = 1.0 / incoming.len() as f64;
                    let mut result = DenseTensor::zeros(vec![1, hidden_dim]);
                    for (i, input) in incoming.iter().enumerate() {
                        let weight = match input.weight {
                            Some(edge_weight) => edge_weight,
                            None if i < attn.weights.len() => attn.weights[i],
                            None => uniform,
                        };
                        result = result.add(&input.value.scale(weight));
                    }
                    self.cache.insert(node_idx, result);
                }
            }
        }
        GraphNodeType::FFNOutput => {
            if let Some(ffn) = &node.ffn_output {
                let result = match incoming.split_first() {
                    None => ffn.output.clone(),
                    Some((first, rest)) => rest
                        .iter()
                        .fold(first.value.clone(), |acc, input| acc.add(&input.value)),
                };
                self.cache.insert(node_idx, result);
            }
        }
    }
}
/// Removes self-attention edges whose weight is strictly below `threshold`.
///
/// Only edges whose `edge_type` is `SelfAttention` and that carry a
/// `self_attention` payload are considered. The candidate indices are
/// collected first, then removed one by one. Returns the number of removals
/// the graph reported as successful.
pub fn prune_weak_edges(&mut self, threshold: f64) -> usize {
    let doomed: Vec<_> = self
        .graph
        .edges()
        .filter(|edge_ref| {
            let is_attention =
                matches!(edge_ref.data().edge_type, GraphEdgeType::SelfAttention);
            is_attention
                && match &edge_ref.data().self_attention {
                    Some(sa) => sa.weight < threshold,
                    None => false,
                }
        })
        .map(|edge_ref| edge_ref.index())
        .collect();

    doomed
        .into_iter()
        .map(|edge_idx| self.graph.remove_edge(edge_idx))
        .filter(|outcome| outcome.is_ok())
        .count()
}
/// Renders the graph as a Graphviz `digraph` named `Transformer`.
///
/// Every node becomes `n<index>` with a label derived from its type, layer,
/// position and (for attention nodes) head. Every edge becomes
/// `n<src> -> n<dst>` styled by its edge type: blue solid for self-attention,
/// green solid for data flow, red dashed for residual.
pub fn to_dot(&self) -> String {
    let mut out = String::new();
    out += "digraph Transformer {\n";
    out += " rankdir=TB;\n";
    out += " node [shape=box];\n\n";

    for node in self.graph.nodes() {
        let label = match node.data.node_type {
            GraphNodeType::TokenEmbedding => {
                format!("TokenEmbed[{}]", node.data.position)
            }
            GraphNodeType::HiddenState => {
                format!("Hidden[L{}P{}]", node.data.layer, node.data.position)
            }
            GraphNodeType::AttentionOutput => {
                // Head defaults to 0 when the attention payload is missing.
                let head = node.data.attention_output.as_ref().map_or(0, |a| a.head);
                format!("Attn[L{}H{}]", node.data.layer, head)
            }
            GraphNodeType::FFNOutput => {
                format!("FFN[L{}P{}]", node.data.layer, node.data.position)
            }
        };
        let id = node.index().index();
        out += &format!(" n{} [label=\"{}\"];\n", id, label);
    }

    out.push('\n');

    for edge in self.graph.edges() {
        let attrs = match edge.data().edge_type {
            GraphEdgeType::SelfAttention => "style=solid, color=blue",
            GraphEdgeType::DataFlow => "style=solid, color=green",
            GraphEdgeType::Residual => "style=dashed, color=red",
        };
        let from = edge.source().index();
        let to = edge.target().index();
        out += &format!(" n{} -> n{} [{}];\n", from, to, attrs);
    }

    out.push('}');
    out
}
/// Resets the executor: drops all cached tensors and replaces the graph with
/// a fresh, empty directed graph.
pub fn clear(&mut self) {
    self.cache.clear();
    self.graph = Graph::directed();
}
}
impl Default for GraphExecutor {
fn default() -> Self {
Self::new()
}
}
/// Builds a transformer-shaped graph (embedding, attention and FFN nodes per
/// layer) inside a [`GraphExecutor`] and delegates execution to it.
#[derive(Debug)]
pub struct GraphTransformer {
/// Executor that owns the graph and performs the forward pass.
executor: GraphExecutor,
/// Number of attention + FFN layers created by `build_graph`.
num_layers: usize,
/// Number of attention heads; `build_graph` uses it only to derive the
/// per-head message width `hidden_dim / num_heads`.
num_heads: usize,
/// Width of the `1 x hidden_dim` tensors attached to nodes and edges.
hidden_dim: usize,
}
impl GraphTransformer {
/// Creates a transformer with the given dimensions and an empty executor.
///
/// No graph is built here; call `build_graph` to populate it. The arguments
/// are stored without validation.
pub fn new(num_layers: usize, num_heads: usize, hidden_dim: usize) -> Self {
    let executor = GraphExecutor::new();
    Self {
        executor,
        num_layers,
        num_heads,
        hidden_dim,
    }
}
/// Adds a transformer graph for `input_ids` to the executor.
///
/// Layout, in insertion order:
/// 1. one token-embedding node per input position;
/// 2. for each layer, one attention node per position (head `0`), fed by a
///    self-attention edge from every node of the previous layer (uniform
///    weight `1 / seq_len`) plus a `PreNorm` residual edge from the node at
///    the same position;
/// 3. for each layer, one FFN node per position, fed by a data-flow edge and
///    a `PostNorm` residual edge from the attention node at that position.
///
/// All node and edge tensors are zero-initialised. Nodes and edges are
/// appended to whatever the executor already holds; nothing is cleared here.
/// Results of `add_edge` are ignored.
///
/// # Panics
///
/// Panics with a division by zero if `num_heads` is `0`.
pub fn build_graph(&mut self, input_ids: &[usize]) {
    let seq_len = input_ids.len();
    let hidden_dim = self.hidden_dim;
    let head_dim = hidden_dim / self.num_heads;

    // Layer "-1": the token embeddings.
    let mut prev_layer_nodes: Vec<_> = input_ids
        .iter()
        .enumerate()
        .map(|(position, &token_id)| {
            let embedding = DenseTensor::zeros(vec![1, hidden_dim]);
            let node = GraphNode::token_embedding(position, token_id, position, embedding);
            self.executor.add_node(node)
        })
        .collect();

    for layer in 0..self.num_layers {
        // Every position attends uniformly to every position.
        let uniform_weights = vec![1.0 / seq_len as f64; seq_len];

        let mut attention_nodes = Vec::new();
        for pos in 0..seq_len {
            let attended: Vec<usize> = (0..seq_len).collect();
            let attn_node_idx = self.executor.add_node(GraphNode::attention_output(
                pos,
                layer,
                0,
                pos,
                attended,
                uniform_weights.clone(),
                DenseTensor::zeros(vec![1, hidden_dim]),
            ));
            attention_nodes.push(attn_node_idx);

            for (src_pos, &src_node) in prev_layer_nodes.iter().enumerate() {
                let weight = uniform_weights.get(src_pos).copied().unwrap_or(0.0);
                let attention_edge = GraphEdge::self_attention_with_message(
                    src_node.index(),
                    attn_node_idx.index(),
                    weight,
                    0,
                    layer,
                    DenseTensor::zeros(vec![1, head_dim]),
                );
                self.executor.add_edge(src_node, attn_node_idx, attention_edge);
            }

            if let Some(&prev_node) = prev_layer_nodes.get(pos) {
                let skip_edge = GraphEdge::residual_with_tensor(
                    prev_node.index(),
                    attn_node_idx.index(),
                    layer,
                    SkipType::PreNorm,
                    DenseTensor::zeros(vec![1, hidden_dim]),
                );
                self.executor.add_edge(prev_node, attn_node_idx, skip_edge);
            }
        }

        let mut ffn_nodes = Vec::new();
        for (pos, &attn_node) in attention_nodes.iter().enumerate() {
            let ffn_node_idx = self.executor.add_node(GraphNode::ffn_output(
                pos,
                layer,
                pos,
                DenseTensor::zeros(vec![1, hidden_dim]),
            ));
            ffn_nodes.push(ffn_node_idx);

            let flow_edge = GraphEdge::data_flow_with_message(
                attn_node.index(),
                ffn_node_idx.index(),
                DataFlowOp::AttentionToOutput,
                layer,
                DenseTensor::zeros(vec![1, hidden_dim]),
            );
            self.executor.add_edge(attn_node, ffn_node_idx, flow_edge);

            let skip_edge = GraphEdge::residual_with_tensor(
                attn_node.index(),
                ffn_node_idx.index(),
                layer,
                SkipType::PostNorm,
                DenseTensor::zeros(vec![1, hidden_dim]),
            );
            self.executor.add_edge(attn_node, ffn_node_idx, skip_edge);
        }

        // The FFN outputs feed the next layer's attention.
        prev_layer_nodes = ffn_nodes;
    }
}
/// Runs the executor's forward pass over the current graph with `input_ids`
/// and returns its output tensor.
pub fn forward(&mut self, input_ids: &[usize]) -> DenseTensor {
    let executor = &mut self.executor;
    executor.forward(input_ids)
}
/// Returns the number of nodes in the executor's graph.
pub fn num_nodes(&self) -> usize {
    let executor = &self.executor;
    executor.num_nodes()
}
/// Returns the number of edges in the executor's graph.
pub fn num_edges(&self) -> usize {
    let executor = &self.executor;
    executor.num_edges()
}
/// Delegates to the executor's `prune_weak_edges` with `threshold` and
/// returns the number of edges it removed.
pub fn prune_weak_edges(&mut self, threshold: f64) -> usize {
    let executor = &mut self.executor;
    executor.prune_weak_edges(threshold)
}
/// Returns the executor's Graphviz DOT rendering of the current graph.
pub fn to_dot(&self) -> String {
    let executor = &self.executor;
    executor.to_dot()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder `1 x 4` zero tensor used as node payload in these tests.
    fn blank() -> DenseTensor {
        DenseTensor::zeros(vec![1, 4])
    }

    /// A freshly created executor holds no nodes and no edges.
    #[test]
    fn test_graph_executor_creation() {
        let fresh = GraphExecutor::new();
        assert_eq!(fresh.num_nodes(), 0);
        assert_eq!(fresh.num_edges(), 0);
    }

    /// Adding one node yields a valid index and a node count of one.
    #[test]
    fn test_graph_executor_add_node() {
        let mut exec = GraphExecutor::new();
        let idx = exec.add_node(GraphNode::token_embedding(0, 10, 0, blank()));
        assert_eq!(exec.num_nodes(), 1);
        assert!(idx.is_valid());
    }

    /// An edge between two existing nodes is accepted and counted.
    #[test]
    fn test_graph_executor_add_edge() {
        let mut exec = GraphExecutor::new();
        let first = exec.add_node(GraphNode::token_embedding(0, 10, 0, blank()));
        let second = exec.add_node(GraphNode::token_embedding(1, 20, 1, blank()));

        let link = GraphEdge::self_attention(first.index(), second.index(), 0.5, 0, 0);
        let accepted = exec.add_edge(first, second, link);

        assert!(accepted);
        assert_eq!(exec.num_edges(), 1);
    }

    /// In the chain a -> b -> c the sort places a before b and b before c.
    #[test]
    fn test_topological_sort() {
        let mut exec = GraphExecutor::new();
        let idx_a = exec.add_node(GraphNode::token_embedding(0, 1, 0, blank()));
        let idx_b = exec.add_node(GraphNode::hidden_state(1, 0, 0, blank()));
        let idx_c = exec.add_node(GraphNode::ffn_output(2, 0, 0, blank()));

        let a_to_b = GraphEdge::data_flow(
            idx_a.index(),
            idx_b.index(),
            DataFlowOp::InputToAttention,
            0,
        );
        exec.add_edge(idx_a, idx_b, a_to_b);
        let b_to_c = GraphEdge::data_flow(
            idx_b.index(),
            idx_c.index(),
            DataFlowOp::AttentionToOutput,
            0,
        );
        exec.add_edge(idx_b, idx_c, b_to_c);

        let order = exec.topological_sort();
        let rank = |wanted: NodeIndex| order.iter().position(|&idx| idx == wanted).unwrap();
        assert!(rank(idx_a) < rank(idx_b));
        assert!(rank(idx_b) < rank(idx_c));
    }

    /// The constructor stores its three dimensions unchanged.
    #[test]
    fn test_graph_transformer_creation() {
        let model = GraphTransformer::new(2, 4, 256);
        assert_eq!(model.num_layers, 2);
        assert_eq!(model.num_heads, 4);
        assert_eq!(model.hidden_dim, 256);
    }

    /// Building a graph for a non-empty input creates nodes and edges.
    #[test]
    fn test_graph_transformer_build() {
        let mut model = GraphTransformer::new(2, 4, 256);
        model.build_graph(&[1, 2, 3, 4]);
        assert!(model.num_nodes() > 0);
        assert!(model.num_edges() > 0);
    }

    /// The DOT export names the digraph, both nodes and the connecting edge.
    #[test]
    fn test_to_dot_export() {
        let mut exec = GraphExecutor::new();
        let src = exec.add_node(GraphNode::token_embedding(0, 1, 0, blank()));
        let dst = exec.add_node(GraphNode::hidden_state(1, 0, 0, blank()));
        let link = GraphEdge::data_flow(
            src.index(),
            dst.index(),
            DataFlowOp::InputToAttention,
            0,
        );
        exec.add_edge(src, dst, link);

        let rendered = exec.to_dot();
        for needle in ["digraph Transformer", "n0", "n1", "n0 -> n1"].iter() {
            assert!(rendered.contains(needle));
        }
    }
}