use crate::{device::Device, dtype::DType, error::TensorError, shape::Shape, tensor::Tensor};
use std::collections::HashMap;
/// Identifier of an edge in a [`Graph`]; allocated sequentially from 0 by `Graph::add_edge`.
pub type EdgeId = u64;
/// Identifier of a node in a [`Graph`]; allocated sequentially from 0 by `Graph::add_node`.
pub type NodeId = u64;
/// A single vertex of a [`Graph`].
#[derive(Debug, Clone)]
pub struct GraphNode {
    /// ID assigned by `Graph::add_node`.
    pub id: NodeId,
    /// Name of the node; `Graph::add_node` rejects duplicate names.
    pub name: String,
    /// What kind of node this is (operation, variable, placeholder or constant).
    pub op_type: NodeType,
    /// Device associated with this node (presumably where it is placed/executed — confirm).
    pub device: Device,
    /// IDs of edges whose `to_node` is this node, in insertion order.
    pub inputs: Vec<EdgeId>,
    /// IDs of edges whose `from_node` is this node, in insertion order.
    pub outputs: Vec<EdgeId>,
    /// Named attributes attached to the node.
    pub attributes: HashMap<String, AttributeValue>,
}
/// The kind of a [`GraphNode`].
#[derive(PartialEq, Debug, Clone)]
pub enum NodeType {
    /// An operation, identified by its string name (presumably an op type such as "MatMul" — confirm).
    Operation(String),
    /// A variable with a declared element type and shape.
    Variable {
        dtype: DType,
        shape: Shape,
    },
    /// A placeholder with a declared element type and shape (presumably an externally fed input — confirm).
    Placeholder {
        dtype: DType,
        shape: Shape,
    },
    /// A constant; carries no dtype/shape here (NOTE(review): value presumably lives in the node's attributes — confirm).
    Constant,
}
/// A directed connection between two nodes of a [`Graph`].
#[derive(Debug, Clone)]
pub struct GraphEdge {
    /// ID assigned by `Graph::add_edge`.
    pub id: EdgeId,
    /// Source node; this edge's ID is recorded in that node's `outputs`.
    pub from_node: NodeId,
    /// Destination node; this edge's ID is recorded in that node's `inputs`.
    pub to_node: NodeId,
    /// Index on the source side (presumably the output slot of `from_node` — confirm).
    pub from_output: usize,
    /// Index on the destination side (presumably the input slot of `to_node` — confirm).
    pub to_input: usize,
    /// Element type declared for this edge when it was added.
    pub dtype: DType,
    /// Shape declared for this edge when it was added.
    pub shape: Shape,
    /// Flag supplied to `add_edge`; presumably marks a control (ordering-only) dependency rather than a data edge — confirm.
    pub is_control: bool,
}
/// A value stored in a [`GraphNode`]'s attribute map.
#[derive(Debug, Clone)]
pub enum AttributeValue {
    /// A string attribute.
    String(String),
    /// A 64-bit signed integer attribute.
    Int(i64),
    /// A 64-bit floating-point attribute.
    Float(f64),
    /// A boolean attribute.
    Bool(bool),
    /// A list of 64-bit signed integers.
    IntList(Vec<i64>),
    /// A list of 64-bit floats.
    FloatList(Vec<f64>),
    /// A shape attribute.
    Shape(Shape),
    /// A tensor attribute with `f32` elements.
    Tensor(Tensor<f32>),
}
/// A directed graph of [`GraphNode`]s connected by [`GraphEdge`]s.
#[derive(Clone, Debug)]
pub struct Graph {
    /// All nodes, keyed by their ID.
    pub(crate) nodes: HashMap<NodeId, GraphNode>,
    /// All edges, keyed by their ID.
    pub(crate) edges: HashMap<EdgeId, GraphEdge>,
    /// ID that the next `add_node` call will hand out.
    pub(crate) next_node_id: NodeId,
    /// ID that the next `add_edge` call will hand out.
    pub(crate) next_edge_id: EdgeId,
    /// Index from node name to node ID; `add_node` keeps names unique.
    pub(crate) name_to_node: HashMap<String, NodeId>,
    /// Cached node ordering; reset to `None` by mutating methods such as `add_node`/`add_edge`
    /// (presumably populated by a topological-sort routine elsewhere in the crate — confirm).
    pub(crate) topological_order: Option<Vec<NodeId>>,
    /// Modification counter; incremented by mutating methods such as `add_node`/`add_edge`.
    pub(crate) version: u64,
}
impl Graph {
    /// Creates an empty graph.
    ///
    /// Node and edge IDs are handed out starting from 0, no ordering is cached, and the
    /// version counter starts at 0.
    pub fn new() -> Self {
        Self {
            nodes: HashMap::new(),
            edges: HashMap::new(),
            next_node_id: 0,
            next_edge_id: 0,
            name_to_node: HashMap::new(),
            topological_order: None,
            version: 0,
        }
    }

    /// Adds a node and returns the ID assigned to it.
    ///
    /// The node starts with no incident edges; connect it with [`Graph::add_edge`].
    /// On success the cached topological order is cleared and the version is bumped.
    ///
    /// # Errors
    ///
    /// Returns an invalid-argument [`TensorError`] if a node called `name` already exists.
    /// The graph is left untouched in that case (no ID is consumed).
    pub fn add_node(
        &mut self,
        name: String,
        op_type: NodeType,
        device: Device,
        attributes: HashMap<String, AttributeValue>,
    ) -> Result<NodeId, TensorError> {
        if self.name_to_node.contains_key(&name) {
            return Err(TensorError::invalid_argument(format!(
                "Node name '{name}' already exists"
            )));
        }
        let node_id = self.next_node_id;
        self.next_node_id += 1;
        let node = GraphNode {
            id: node_id,
            name: name.clone(),
            op_type,
            device,
            inputs: Vec::new(),
            outputs: Vec::new(),
            attributes,
        };
        self.nodes.insert(node_id, node);
        self.name_to_node.insert(name, node_id);
        // Structural change: any cached ordering is now stale.
        self.topological_order = None;
        self.version += 1;
        Ok(node_id)
    }

    /// Connects `from_node` to `to_node` with a new edge and returns the edge's ID.
    ///
    /// The edge ID is appended to the source node's `outputs` and to the destination
    /// node's `inputs`. On success the cached topological order is cleared and the version
    /// is bumped. Only direct self-loops are rejected; longer cycles are not detected here.
    ///
    /// # Errors
    ///
    /// Returns an invalid-argument [`TensorError`] if either endpoint does not exist, or if
    /// `from_node == to_node`. The graph is unchanged on error.
    // Each argument maps 1:1 onto a `GraphEdge` field; collapsing them would change the API.
    #[allow(clippy::too_many_arguments)]
    pub fn add_edge(
        &mut self,
        from_node: NodeId,
        to_node: NodeId,
        from_output: usize,
        to_input: usize,
        dtype: DType,
        shape: Shape,
        is_control: bool,
    ) -> Result<EdgeId, TensorError> {
        if !self.nodes.contains_key(&from_node) {
            return Err(TensorError::invalid_argument(format!(
                "Source node {from_node} not found"
            )));
        }
        if !self.nodes.contains_key(&to_node) {
            return Err(TensorError::invalid_argument(format!(
                "Destination node {to_node} not found"
            )));
        }
        if from_node == to_node {
            return Err(TensorError::invalid_argument(
                "Self-loops are not allowed".to_string(),
            ));
        }
        let edge_id = self.next_edge_id;
        self.next_edge_id += 1;
        let edge = GraphEdge {
            id: edge_id,
            from_node,
            to_node,
            from_output,
            to_input,
            dtype,
            shape,
            is_control,
        };
        self.edges.insert(edge_id, edge);
        // Both endpoints were verified to exist above and nothing has removed them since.
        self.nodes
            .get_mut(&from_node)
            .expect("Source node must exist after validation")
            .outputs
            .push(edge_id);
        self.nodes
            .get_mut(&to_node)
            .expect("Destination node must exist after validation")
            .inputs
            .push(edge_id);
        // Structural change: any cached ordering is now stale.
        self.topological_order = None;
        self.version += 1;
        Ok(edge_id)
    }

    /// Returns the node with ID `node_id`, or `None` if there is no such node.
    pub fn get_node(&self, node_id: NodeId) -> Option<&GraphNode> {
        self.nodes.get(&node_id)
    }

    /// Looks a node up by its unique name.
    pub fn get_node_by_name(&self, name: &str) -> Option<&GraphNode> {
        self.name_to_node
            .get(name)
            .and_then(|&id| self.nodes.get(&id))
    }

    /// Returns the edge with ID `edge_id`, or `None` if there is no such edge.
    pub fn get_edge(&self, edge_id: EdgeId) -> Option<&GraphEdge> {
        self.edges.get(&edge_id)
    }

    /// Iterates over all nodes in unspecified (hash-map) order.
    pub fn nodes(&self) -> impl Iterator<Item = &GraphNode> {
        self.nodes.values()
    }

    /// Iterates over all edges in unspecified (hash-map) order.
    pub fn edges(&self) -> impl Iterator<Item = &GraphEdge> {
        self.edges.values()
    }

    /// Returns mutable access to the node with ID `node_id`, if it exists.
    ///
    /// The caller may rewrite the node's `inputs`/`outputs` or other fields, so when a
    /// reference is handed out the cached topological order is conservatively cleared and
    /// the version is bumped. Renaming the node through this reference does NOT update the
    /// name index used by [`Graph::get_node_by_name`].
    pub fn get_node_mut(&mut self, node_id: NodeId) -> Option<&mut GraphNode> {
        let node = self.nodes.get_mut(&node_id)?;
        self.topological_order = None;
        self.version += 1;
        Some(node)
    }

    /// Returns mutable access to the edge with ID `edge_id`, if it exists.
    ///
    /// The caller may re-point `from_node`/`to_node`, so when a reference is handed out the
    /// cached topological order is conservatively cleared and the version is bumped. Such
    /// rewiring does not update the endpoint nodes' `inputs`/`outputs` lists.
    pub fn get_edge_mut(&mut self, edge_id: EdgeId) -> Option<&mut GraphEdge> {
        let edge = self.edges.get_mut(&edge_id)?;
        self.topological_order = None;
        self.version += 1;
        Some(edge)
    }

    /// Number of nodes in the graph.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Number of edges in the graph.
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }
}
impl Default for Graph {
fn default() -> Self {
Self::new()
}
}