use ndarray::ArrayD;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;
/// Identifier of a [`Node`] inside an [`Asg`]'s node map.
pub type NodeId = usize;

/// Identifier of an [`Asg`]; node variants use it to refer to other graphs.
pub type AsgId = usize;

/// Tensor shape: one `usize` extent per dimension.
pub type Shape = Vec<usize>;

/// Shorthand for results whose error type is [`AsgError`].
pub type AsgResult<T> = std::result::Result<T, AsgError>;
#[derive(Error, Debug, Clone, PartialEq)]
pub enum AsgError {
#[error(
"Node with ID {0} not found in graph. \
Verify that the node was added to the graph using add_node() before use."
)]
NodeNotFound(NodeId),
#[error(
"Input with name '{0}' not found in graph. \
Ensure that an Input node was created with this name."
)]
InputNotFound(String),
#[error("Invalid graph structure: {0}")]
InvalidGraph(String),
#[error(
"Cyclic dependency detected in graph. \
ASG must be a directed acyclic graph (DAG)."
)]
CyclicDependency,
}
/// Element data type that may be recorded on a [`Node`] (`Node::dtype`).
///
/// `Eq` and `Hash` are derived so that a `DType` can be used as a `HashMap` /
/// `HashSet` key. This is sound because the enum is fieldless.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DType {
    /// 32-bit floating point (Rust `f32`).
    F32,
    /// 64-bit signed integer (Rust `i64`).
    I64,
    /// Boolean.
    Bool,
}
/// A concrete value, e.g. the payload of [`NodeType::Literal`].
///
/// Variant order is part of the serialized form in index-based serde formats,
/// so it is kept unchanged.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
pub enum Value {
    /// Dynamically-ranked `f32` tensor.
    Tensor(ArrayD<f32>),
    /// `i64` scalar.
    Integer(i64),
    /// `f32` scalar.
    Float(f32),
    /// Boolean scalar.
    Boolean(bool),
    /// Owned string.
    Text(String),
    /// Value carrying no data.
    Unit,
}
/// One vertex of an [`Asg`].
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct Node {
    /// Id of this node; [`Asg::add_node`] sets it to the key the node is stored under.
    pub id: NodeId,
    /// Optional name for the node.
    pub name: Option<String>,
    /// The operation (and its operands) this node represents.
    pub node_type: NodeType,
    /// Output shape, if known. `Asg::add_node` starts it at `None`; presumably a
    /// later pass fills it in — confirm.
    pub shape: Option<Shape>,
    /// Output element type, if known. `Asg::add_node` starts it at `None`.
    pub dtype: Option<DType>,
}
/// The operation (or leaf value) represented by a [`Node`].
///
/// Fields of type `NodeId` name operand nodes (presumably in the same [`Asg`]);
/// fields of type `AsgId` name other graphs. How each variant is evaluated is not
/// defined in this file, so the per-variant notes below are hedged where they go
/// beyond what the field types show.
///
/// NOTE(review): serde's derived encoding identifies variants by index in compact
/// formats (e.g. bincode). Append new variants at the end rather than inserting
/// or reordering them.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum NodeType {
/// Graph input, identified by `name`.
Input {
name: String,
},
/// Named parameter (presumably a learnable tensor — confirm).
Parameter {
name: String,
},
/// Constant value embedded in the graph.
Literal(Value),
/// Reference to node `source_node_id` of graph `source_asg_id`, known here as `name`.
External {
name: String,
source_asg_id: AsgId,
source_node_id: NodeId,
},
/// Addition of two operands (presumably element-wise; broadcasting is up to the evaluator).
Add(NodeId, NodeId),
/// Subtraction of two operands (presumably `lhs - rhs`).
Subtract(NodeId, NodeId),
/// Multiplication of two operands (presumably element-wise).
Multiply(NodeId, NodeId),
/// Division of two operands (presumably `lhs / rhs`).
Divide(NodeId, NodeId),
/// Matrix product of two operands.
MatrixMultiply(NodeId, NodeId),
/// Comparison of two operands (presumably `lhs > rhs`, element-wise).
GreaterThan(NodeId, NodeId),
/// ReLU of the operand.
ReLU(NodeId),
/// Sigmoid of the operand.
Sigmoid(NodeId),
/// Logarithm of the operand (presumably natural log — confirm).
Log(NodeId),
/// Square root of the operand.
Sqrt(NodeId),
/// Exponential of the operand.
Exp(NodeId),
/// Absolute value of the operand.
Abs(NodeId),
/// Negation of the operand.
Neg(NodeId),
// `Power(base, exponent)`: operand roles presumed from the name.
// `Softmax(x)`: no axis is recorded here; confirm which axis the evaluator uses.
Power(NodeId, NodeId), Softmax(NodeId),
/// Hyperbolic tangent of the operand.
Tanh(NodeId),
// `LeakyReLU(x, f32)`: the f32 is presumably the negative slope. `GELU(x)`.
LeakyReLU(NodeId, f32), GELU(NodeId),
// `SiLU(x)`; `ELU(x, f32)` (f32 presumably alpha); `Softplus(x, f32)` (f32 presumably
// beta); `Clamp(x, f32, f32)` (presumably min, max).
SiLU(NodeId), ELU(NodeId, f32), Softplus(NodeId, f32), Clamp(NodeId, f32, f32),
// Reductions with no axis field (presumably over the whole operand); see `MeanAxis`
// and `VarianceAxis` for per-axis variants.
Sum(NodeId), Mean(NodeId), Variance(NodeId),
/// Embedding lookup: presumably selects rows of `weight` by `indices`.
Embedding {
indices: NodeId,
weight: NodeId,
},
/// Gradient for `Embedding` (by name), presumably w.r.t. the weight, with
/// `num_embeddings` rows — confirm.
EmbeddingGrad {
grad_output: NodeId,
indices: NodeId,
num_embeddings: usize,
},
// `Reshape(x, target)`: the second operand presumably supplies the target shape.
// `Transpose(x, a, b)`: presumably swaps axes `a` and `b`.
// `Slice`: presumably takes `start..end` of `input` along `axis`.
Reshape(NodeId, NodeId), Transpose(NodeId, usize, usize), Slice {
input: NodeId,
axis: usize,
start: usize,
end: usize,
},
/// Concatenation of `inputs` along `axis`.
Concat {
inputs: Vec<NodeId>,
axis: usize,
},
/// Backward of `Slice` (by name). Presumably places `grad_output` at offset `start`
/// along `axis` in a tensor whose extent on that axis is `full_size`.
SliceBackward {
grad_output: NodeId,
axis: usize,
start: usize,
full_size: usize,
},
/// Dropout mask shaped like `shape_provider`, with probability `p` (presumably the drop
/// probability — confirm).
DropoutMask {
shape_provider: NodeId,
p: f32,
},
/// Mean along `axis`; `keepdims` presumably keeps the reduced axis with extent 1.
MeanAxis {
input: NodeId,
axis: usize,
keepdims: bool,
},
/// Variance along `axis`; `keepdims` as in `MeanAxis`.
VarianceAxis {
input: NodeId,
axis: usize,
keepdims: bool,
},
/// Batch normalization of `input` with scale `gamma`, shift `beta`, epsilon `eps`,
/// and channels on `channel_axis`.
BatchNorm {
input: NodeId,
gamma: NodeId,
beta: NodeId,
eps: f32,
channel_axis: usize,
},
/// BatchNorm gradient, presumably w.r.t. `input`.
BatchNormBackward {
grad_output: NodeId,
input: NodeId,
gamma: NodeId,
eps: f32,
channel_axis: usize,
},
/// BatchNorm gradient, presumably w.r.t. `gamma`.
BatchNormGradGamma {
grad_output: NodeId,
input: NodeId,
eps: f32,
channel_axis: usize,
},
/// BatchNorm gradient, presumably w.r.t. `beta`.
BatchNormGradBeta {
grad_output: NodeId,
channel_axis: usize,
},
/// `Broadcast(x, target)`: presumably broadcasts `x` to the shape of `target`.
Broadcast(NodeId, NodeId),
/// `ReduceSumTo(x, target)`: presumably sums `x` down to the shape of `target`.
ReduceSumTo(NodeId, NodeId),
/// 2-D convolution with optional `bias`. The stride/padding/dilation/groups semantics
/// are assumed to follow the common (e.g. PyTorch) convention — confirm.
Conv2d {
input: NodeId,
weight: NodeId,
bias: Option<NodeId>,
stride: (usize, usize),
padding: (usize, usize),
dilation: (usize, usize),
groups: usize,
},
/// 2-D transposed convolution; fields as in `Conv2d` plus `output_padding`.
ConvTranspose2d {
input: NodeId,
weight: NodeId,
bias: Option<NodeId>,
stride: (usize, usize),
padding: (usize, usize),
output_padding: (usize, usize),
dilation: (usize, usize),
groups: usize,
},
/// 2-D max pooling (this variant has no padding field).
MaxPool2d {
input: NodeId,
kernel_size: (usize, usize),
stride: (usize, usize),
},
/// Presumably the backward of `MaxPool2d`: routes `input` back to positions of
/// `original_input` — confirm.
MaxUnpool2d {
input: NodeId,
original_input: NodeId,
kernel_size: (usize, usize),
stride: (usize, usize),
},
/// 2-D average pooling.
AvgPool2d {
input: NodeId,
kernel_size: (usize, usize),
stride: (usize, usize),
padding: (usize, usize),
},
/// Presumably the backward of `AvgPool2d`, shaped like `original_input` — confirm.
AvgUnpool2d {
input: NodeId,
original_input: NodeId,
kernel_size: (usize, usize),
stride: (usize, usize),
padding: (usize, usize),
},
/// Adaptive average pooling to `output_size` (presumably (height, width)).
AdaptiveAvgPool2d {
input: NodeId,
output_size: (usize, usize),
},
/// Layer normalization with scale `gamma`, shift `beta`, and epsilon `eps`. The
/// normalized axes are not recorded here — confirm.
LayerNorm {
input: NodeId,
gamma: NodeId,
beta: NodeId,
eps: f32,
},
/// LayerNorm gradient, presumably w.r.t. `input`.
LayerNormBackward {
grad_output: NodeId,
input: NodeId,
gamma: NodeId,
eps: f32,
},
/// LayerNorm gradient, presumably w.r.t. `gamma`.
LayerNormGradGamma {
grad_output: NodeId,
input: NodeId,
eps: f32,
},
/// LayerNorm gradient, presumably w.r.t. `beta`.
LayerNormGradBeta {
grad_output: NodeId,
},
/// Conv2d gradient w.r.t. its input (by name); `input_shape` is presumably (N, C, H, W).
Conv2dBackwardInput {
grad_output: NodeId,
weight: NodeId,
input_shape: (usize, usize, usize, usize),
stride: (usize, usize),
padding: (usize, usize),
dilation: (usize, usize),
groups: usize,
},
/// Conv2d gradient w.r.t. its weight (by name); `weight_shape` is the 4-tuple weight
/// shape (layout not specified here — confirm).
Conv2dBackwardWeight {
grad_output: NodeId,
input: NodeId,
weight_shape: (usize, usize, usize, usize),
stride: (usize, usize),
padding: (usize, usize),
dilation: (usize, usize),
groups: usize,
},
/// Conditional on node `condition`; presumably evaluates graph `then_asg` when true and
/// graph `else_asg` otherwise.
If {
condition: NodeId,
then_asg: AsgId,
else_asg: AsgId,
},
/// Loop over node `iterable` with body graph `loop_body_asg`.
ForLoop {
// `iterable`: node iterated over; `loop_body_asg`: graph presumably run per iteration.
iterable: NodeId, loop_body_asg: AsgId, },
/// Function named `name` whose body is the graph `body_asg`.
FunctionDefinition {
name: String,
body_asg: AsgId,
},
/// Function call. `function_id` presumably refers to a `FunctionDefinition` node.
FunctionCall {
// `function_id`: callee node; `args`: argument nodes, in order.
function_id: NodeId, args: Vec<NodeId>,
},
/// Prints the operand (presumably a side effect performed by the evaluator).
Print(NodeId),
}
/// A graph of [`Node`]s keyed by [`NodeId`], with designated input and output nodes.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct Asg {
    /// Id of this graph (other graphs refer to it through `AsgId` fields).
    pub id: AsgId,
    /// Optional graph name.
    pub name: Option<String>,
    /// All nodes, keyed by their id.
    pub nodes: HashMap<NodeId, Node>,
    /// Ids of the nodes designated as graph inputs.
    pub inputs: Vec<NodeId>,
    /// Ids of the nodes designated as graph outputs.
    pub outputs: Vec<NodeId>,
}
impl Asg {
    /// Creates an empty graph with the given id and optional name.
    pub fn new(id: AsgId, name: Option<String>) -> Self {
        Self {
            id,
            name,
            nodes: HashMap::new(),
            inputs: Vec::new(),
            outputs: Vec::new(),
        }
    }

    /// Adds a node of type `node_type` and returns its newly allocated id.
    ///
    /// The new node's `shape` and `dtype` start out as `None`.
    ///
    /// Ids come out as `0, 1, 2, …` while nodes are only added through this method.
    /// However, `nodes` is public and the graph can be deserialized, so the key set
    /// may have gaps. In that case the first unused id at or after `nodes.len()` is
    /// used, so an existing node is never silently overwritten.
    pub fn add_node(&mut self, name: Option<String>, node_type: NodeType) -> NodeId {
        // When the keys are exactly `0..len`, `len` is free and the loop does not run.
        // Otherwise at most `len` of the `len + 1` candidates in `len..=2 * len` can be
        // taken, so the probe always terminates.
        let mut new_id = self.nodes.len();
        while self.nodes.contains_key(&new_id) {
            new_id += 1;
        }
        let node = Node {
            id: new_id,
            name,
            node_type,
            shape: None,
            dtype: None,
        };
        self.nodes.insert(new_id, node);
        new_id
    }

    /// Replaces the list of graph input node ids.
    pub fn set_inputs(&mut self, inputs: Vec<NodeId>) {
        self.inputs = inputs;
    }

    /// Replaces the list of graph output node ids.
    pub fn set_outputs(&mut self, outputs: Vec<NodeId>) {
        self.outputs = outputs;
    }

    /// Makes `output` the graph's only output.
    pub fn set_output(&mut self, output: NodeId) {
        self.outputs = vec![output];
    }

    /// Returns the node stored under `id`.
    ///
    /// # Errors
    ///
    /// Returns [`AsgError::NodeNotFound`] if no node has that id.
    pub fn get_node(&self, id: NodeId) -> AsgResult<&Node> {
        self.nodes.get(&id).ok_or(AsgError::NodeNotFound(id))
    }

    /// Returns a mutable reference to the node stored under `id`.
    ///
    /// # Errors
    ///
    /// Returns [`AsgError::NodeNotFound`] if no node has that id.
    pub fn get_node_mut(&mut self, id: NodeId) -> AsgResult<&mut Node> {
        self.nodes.get_mut(&id).ok_or(AsgError::NodeNotFound(id))
    }
}