use crate::error::{NeuralError, Result};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::Path;
/// Scalar element type of a tensor, mirroring common ONNX data types.
///
/// Serialized by serde in lowercase (e.g. `"f32"`, `"bool"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TensorDtype {
/// 32-bit IEEE float ("float" in ONNX).
F32,
/// 64-bit IEEE float ("double" in ONNX).
F64,
/// 16-bit IEEE float ("float16" in ONNX).
F16,
/// 32-bit signed integer.
I32,
/// 64-bit signed integer.
I64,
/// Boolean, one byte per element.
Bool,
}
impl TensorDtype {
    /// ONNX textual name for this dtype (e.g. `"float"` for `F32`).
    pub fn onnx_str(&self) -> &str {
        match self {
            Self::F32 => "float",
            Self::F64 => "double",
            Self::F16 => "float16",
            Self::I32 => "int32",
            Self::I64 => "int64",
            Self::Bool => "bool",
        }
    }

    /// Parse a dtype from its ONNX name or a Rust-style spelling
    /// (`"F32"`/`"f32"` are accepted alongside `"float"`).
    ///
    /// # Errors
    /// Returns `NeuralError::DeserializationError` for unrecognized names.
    pub fn from_onnx_str(s: &str) -> Result<Self> {
        let dtype = match s {
            "float" | "F32" | "f32" => Self::F32,
            "double" | "F64" | "f64" => Self::F64,
            "float16" | "F16" | "f16" => Self::F16,
            "int32" | "I32" | "i32" => Self::I32,
            "int64" | "I64" | "i64" => Self::I64,
            "bool" | "Bool" => Self::Bool,
            other => {
                return Err(NeuralError::DeserializationError(format!(
                    "Unknown tensor dtype: {other}"
                )))
            }
        };
        Ok(dtype)
    }

    /// Size in bytes of a single element of this dtype.
    pub fn element_size(&self) -> usize {
        // Arms grouped by byte width.
        match self {
            Self::Bool => 1,
            Self::F16 => 2,
            Self::F32 | Self::I32 => 4,
            Self::F64 | Self::I64 => 8,
        }
    }
}
impl Default for TensorDtype {
fn default() -> Self {
TensorDtype::F32
}
}
impl std::fmt::Display for TensorDtype {
    /// Displays the ONNX name, identical to [`TensorDtype::onnx_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.onnx_str())
    }
}
/// Shape and element type of a tensor.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TensorShape {
/// Dimensions in order; `None` marks a dynamic/symbolic dimension
/// (typically the batch dimension).
pub dims: Vec<Option<i64>>,
/// Element data type.
pub dtype: TensorDtype,
}
impl TensorShape {
    /// Fully static shape from concrete dimensions.
    pub fn new(dims: Vec<i64>, dtype: TensorDtype) -> Self {
        let dims = dims.into_iter().map(Some).collect();
        Self { dims, dtype }
    }

    /// Shape whose leading dimension is dynamic (a batch dimension),
    /// followed by the given static spatial dimensions.
    pub fn with_batch_dim(spatial_dims: Vec<i64>, dtype: TensorDtype) -> Self {
        let dims = std::iter::once(None)
            .chain(spatial_dims.into_iter().map(Some))
            .collect();
        Self { dims, dtype }
    }

    /// Total element count, or `None` if any dimension is dynamic or the
    /// product overflows `i64`. An empty shape yields `Some(1)` (scalar).
    pub fn num_elements(&self) -> Option<i64> {
        self.dims
            .iter()
            .try_fold(1i64, |acc, dim| acc.checked_mul((*dim)?))
    }

    /// Number of dimensions, counting dynamic ones.
    pub fn rank(&self) -> usize {
        self.dims.len()
    }
}
/// Per-side 2-D spatial padding, in elements.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PaddingSpec {
pub top: usize,
pub bottom: usize,
pub left: usize,
pub right: usize,
}
impl PaddingSpec {
    /// Uniform padding: the same amount on all four sides.
    pub fn same(value: usize) -> Self {
        let (top, bottom, left, right) = (value, value, value, value);
        Self {
            top,
            bottom,
            left,
            right,
        }
    }

    /// No padding on any side.
    pub fn zero() -> Self {
        Self::same(0)
    }
}
/// Attributes of a fully-connected (Dense / Linear) layer.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DenseAttrs {
pub in_features: usize,
pub out_features: usize,
pub use_bias: bool,
}
/// Attributes of a 2-D convolution (also reused for depthwise conv).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Conv2DAttrs {
pub in_channels: usize,
pub out_channels: usize,
pub kernel_h: usize,
pub kernel_w: usize,
pub stride_h: usize,
pub stride_w: usize,
pub padding: PaddingSpec,
// Single dilation value applied to both spatial axes.
pub dilation: usize,
// Number of groups; in_channels for a depthwise convolution.
pub groups: usize,
pub use_bias: bool,
}
/// Attributes of a batch-normalization layer.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BatchNormAttrs {
pub num_features: usize,
pub eps: f64,
// NOTE(review): presumably the PyTorch-style running-stats momentum
// (the ONNX export converts it via `1.0 - momentum`) — confirm.
pub momentum: f64,
// Whether learnable scale/shift parameters are present.
pub affine: bool,
}
/// Attributes of a layer-normalization layer.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LayerNormAttrs {
pub normalized_shape: Vec<usize>,
pub eps: f64,
pub elementwise_affine: bool,
}
/// Attributes shared by max/avg 2-D pooling layers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Pool2DAttrs {
pub kernel_h: usize,
pub kernel_w: usize,
// `None` means "stride defaults to kernel size" (see onnx_node_attributes).
pub stride_h: Option<usize>,
pub stride_w: Option<usize>,
pub padding: PaddingSpec,
}
/// Attributes of a dropout layer (`p` is the drop probability).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DropoutAttrs {
pub p: f64,
}
/// Attributes of a reshape op; dims follow the stored target shape.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReshapeAttrs {
pub shape: Vec<i64>,
}
/// Attributes of a multi-head attention layer.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AttentionAttrs {
pub embed_dim: usize,
pub num_heads: usize,
pub attn_dropout: f64,
pub use_bias: bool,
}
/// Attributes of an embedding-table lookup.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EmbeddingAttrs {
pub num_embeddings: usize,
pub embedding_dim: usize,
pub padding_idx: Option<i64>,
}
/// Attributes of an activation node; `function` is a lowercase name such
/// as "relu" (see `GraphNode::onnx_op_type` for the recognized set).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ActivationAttrs {
pub function: String,
// Optional parameter for alpha-bearing activations (LeakyRelu, Elu).
pub alpha: Option<f64>,
}
/// Attributes of an element-wise Add node.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AddAttrs {
pub learnable_scale: bool,
}
/// A single operation (or graph I/O marker) in a [`ModelGraph`].
///
/// Serialized with an internal `node_type` tag in snake_case, e.g.
/// `{"node_type": "dense", ...}`. Every variant carries a unique `name`;
/// operation variants also carry the shape of their output tensor, while
/// `Input`/`Output` carry their declared `shape`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "node_type", rename_all = "snake_case")]
pub enum GraphNode {
/// Graph input placeholder.
Input {
name: String,
shape: TensorShape,
},
/// Graph output marker.
Output {
name: String,
shape: TensorShape,
},
/// Fully-connected layer (exported as ONNX "Gemm").
Dense {
name: String,
attrs: DenseAttrs,
output_shape: TensorShape,
},
/// Standard 2-D convolution.
Conv2D {
name: String,
attrs: Conv2DAttrs,
output_shape: TensorShape,
},
/// Depthwise 2-D convolution (shares Conv2DAttrs; groups distinguish it).
DepthwiseConv2D {
name: String,
attrs: Conv2DAttrs,
output_shape: TensorShape,
},
/// Batch normalization.
BatchNorm {
name: String,
attrs: BatchNormAttrs,
output_shape: TensorShape,
},
/// Layer normalization.
LayerNorm {
name: String,
attrs: LayerNormAttrs,
output_shape: TensorShape,
},
/// Element-wise activation function.
Activation {
name: String,
attrs: ActivationAttrs,
output_shape: TensorShape,
},
/// 2-D max pooling.
MaxPool2D {
name: String,
attrs: Pool2DAttrs,
output_shape: TensorShape,
},
/// 2-D average pooling.
AvgPool2D {
name: String,
attrs: Pool2DAttrs,
output_shape: TensorShape,
},
/// Global average pooling (no attributes beyond the output shape).
GlobalAvgPool {
name: String,
output_shape: TensorShape,
},
/// Dropout.
Dropout {
name: String,
attrs: DropoutAttrs,
output_shape: TensorShape,
},
/// Reshape to a fixed target shape.
Reshape {
name: String,
attrs: ReshapeAttrs,
output_shape: TensorShape,
},
/// Flatten to 2-D.
Flatten {
name: String,
output_shape: TensorShape,
},
/// Element-wise addition (e.g. residual connection).
Add {
name: String,
attrs: AddAttrs,
output_shape: TensorShape,
},
/// Concatenation along `axis`.
Concat {
name: String,
axis: i64,
output_shape: TensorShape,
},
/// Multi-head attention.
Attention {
name: String,
attrs: AttentionAttrs,
output_shape: TensorShape,
},
/// Embedding-table lookup (exported as ONNX "Gather").
Embedding {
name: String,
attrs: EmbeddingAttrs,
output_shape: TensorShape,
},
/// Softmax along `axis`.
Softmax {
name: String,
axis: i64,
output_shape: TensorShape,
},
/// Escape hatch for ops not modeled above; `attributes` is free-form JSON.
Custom {
name: String,
op_type: String,
attributes: HashMap<String, serde_json::Value>,
output_shape: TensorShape,
},
}
impl GraphNode {
/// Unique node name; used as edge endpoints and as the ONNX value name.
pub fn name(&self) -> &str {
match self {
GraphNode::Input { name, .. }
| GraphNode::Output { name, .. }
| GraphNode::Dense { name, .. }
| GraphNode::Conv2D { name, .. }
| GraphNode::DepthwiseConv2D { name, .. }
| GraphNode::BatchNorm { name, .. }
| GraphNode::LayerNorm { name, .. }
| GraphNode::Activation { name, .. }
| GraphNode::MaxPool2D { name, .. }
| GraphNode::AvgPool2D { name, .. }
| GraphNode::GlobalAvgPool { name, .. }
| GraphNode::Dropout { name, .. }
| GraphNode::Reshape { name, .. }
| GraphNode::Flatten { name, .. }
| GraphNode::Add { name, .. }
| GraphNode::Concat { name, .. }
| GraphNode::Attention { name, .. }
| GraphNode::Embedding { name, .. }
| GraphNode::Softmax { name, .. }
| GraphNode::Custom { name, .. } => name,
}
}
/// Shape of this node's output tensor; for `Input`/`Output` this is
/// their declared `shape` field.
pub fn output_shape(&self) -> &TensorShape {
match self {
GraphNode::Input { shape, .. } | GraphNode::Output { shape, .. } => shape,
GraphNode::Dense { output_shape, .. }
| GraphNode::Conv2D { output_shape, .. }
| GraphNode::DepthwiseConv2D { output_shape, .. }
| GraphNode::BatchNorm { output_shape, .. }
| GraphNode::LayerNorm { output_shape, .. }
| GraphNode::Activation { output_shape, .. }
| GraphNode::MaxPool2D { output_shape, .. }
| GraphNode::AvgPool2D { output_shape, .. }
| GraphNode::GlobalAvgPool { output_shape, .. }
| GraphNode::Dropout { output_shape, .. }
| GraphNode::Reshape { output_shape, .. }
| GraphNode::Flatten { output_shape, .. }
| GraphNode::Add { output_shape, .. }
| GraphNode::Concat { output_shape, .. }
| GraphNode::Attention { output_shape, .. }
| GraphNode::Embedding { output_shape, .. }
| GraphNode::Softmax { output_shape, .. }
| GraphNode::Custom { output_shape, .. } => output_shape,
}
}
/// Operator name used in the ONNX-like export.
///
/// NOTE(review): "Input", "Output", "Swish", "MultiHeadAttention", and
/// the fallback "Activation" are not standard ONNX operator names —
/// confirm downstream consumers of the export accept them.
pub fn onnx_op_type(&self) -> &str {
match self {
GraphNode::Input { .. } => "Input",
GraphNode::Output { .. } => "Output",
GraphNode::Dense { .. } => "Gemm",
GraphNode::Conv2D { .. } => "Conv",
GraphNode::DepthwiseConv2D { .. } => "Conv",
GraphNode::BatchNorm { .. } => "BatchNormalization",
GraphNode::LayerNorm { .. } => "LayerNormalization",
GraphNode::Activation { attrs, .. } => match attrs.function.as_str() {
"relu" => "Relu",
"sigmoid" => "Sigmoid",
"tanh" => "Tanh",
"gelu" => "Gelu",
"leaky_relu" => "LeakyRelu",
"elu" => "Elu",
"swish" | "silu" => "Swish",
// Unrecognized functions fall through to a generic name.
_ => "Activation",
},
GraphNode::MaxPool2D { .. } => "MaxPool",
GraphNode::AvgPool2D { .. } => "AveragePool",
GraphNode::GlobalAvgPool { .. } => "GlobalAveragePool",
GraphNode::Dropout { .. } => "Dropout",
GraphNode::Reshape { .. } => "Reshape",
GraphNode::Flatten { .. } => "Flatten",
GraphNode::Add { .. } => "Add",
GraphNode::Concat { .. } => "Concat",
GraphNode::Attention { .. } => "MultiHeadAttention",
GraphNode::Embedding { .. } => "Gather",
GraphNode::Softmax { .. } => "Softmax",
GraphNode::Custom { op_type, .. } => op_type,
}
}
/// Number of weight tensors this node carries (0 for weight-free ops).
///
/// Conv/Dense: kernel (+ bias). BatchNorm: running mean/var (+ scale,
/// shift when affine). LayerNorm: scale + shift when affine.
/// NOTE(review): Attention is counted as 1, which treats its projection
/// weights as a single packed tensor — confirm that convention.
pub fn num_weight_tensors(&self) -> usize {
match self {
GraphNode::Dense { attrs, .. } => {
if attrs.use_bias {
2
} else {
1
}
}
GraphNode::Conv2D { attrs, .. } | GraphNode::DepthwiseConv2D { attrs, .. } => {
if attrs.use_bias {
2
} else {
1
}
}
GraphNode::BatchNorm { attrs, .. } => {
if attrs.affine {
4
} else {
2
}
}
GraphNode::LayerNorm { attrs, .. } => {
if attrs.elementwise_affine {
2
} else {
0
}
}
GraphNode::Embedding { .. } | GraphNode::Attention { .. } => 1,
_ => 0,
}
}
}
/// Directed edge between two named nodes, with output/input slot indices
/// for multi-output / multi-input operations.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct GraphEdge {
pub from_node: String,
/// Output slot on the source node (0 for single-output ops).
pub from_slot: usize,
pub to_node: String,
/// Input slot on the destination node (0 for single-input ops).
pub to_slot: usize,
}
impl GraphEdge {
    /// Edge between the default slots (output slot 0 -> input slot 0).
    pub fn simple(from: impl Into<String>, to: impl Into<String>) -> Self {
        Self::with_slots(from, 0, to, 0)
    }

    /// Edge between explicit output/input slots.
    pub fn with_slots(
        from: impl Into<String>,
        from_slot: usize,
        to: impl Into<String>,
        to_slot: usize,
    ) -> Self {
        Self {
            from_node: from.into(),
            from_slot,
            to_node: to.into(),
            to_slot,
        }
    }
}
/// Descriptive metadata attached to a [`ModelGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphMetadata {
/// Architecture name (also used as the ONNX graph name on export).
pub architecture: String,
/// Version of this crate that produced the graph.
pub framework_version: String,
/// Version of the on-disk graph schema.
pub graph_format_version: String,
pub description: Option<String>,
/// Free-form key/value pairs for anything not modeled above.
pub extra: HashMap<String, String>,
}
impl GraphMetadata {
    /// Metadata for the named architecture, stamped with the current crate
    /// version and graph format version "1.0".
    pub fn new(architecture: impl Into<String>) -> Self {
        Self {
            architecture: architecture.into(),
            framework_version: env!("CARGO_PKG_VERSION").to_owned(),
            graph_format_version: String::from("1.0"),
            description: None,
            extra: HashMap::new(),
        }
    }

    /// Attach a human-readable description (builder style).
    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
        self.description = Some(desc.into());
        self
    }

    /// Insert one free-form key/value metadata pair (builder style).
    pub fn with_extra(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.extra.insert(key.into(), value.into());
        self
    }
}
/// A neural-network computation graph: named nodes plus directed edges.
///
/// Node and edge lists are private; mutate via `add_node`/`add_edge`,
/// which enforce name uniqueness and endpoint existence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelGraph {
pub metadata: GraphMetadata,
nodes: Vec<GraphNode>,
edges: Vec<GraphEdge>,
}
impl ModelGraph {
/// Create an empty graph carrying the given metadata.
pub fn new(metadata: GraphMetadata) -> Self {
    let nodes = Vec::new();
    let edges = Vec::new();
    Self {
        metadata,
        nodes,
        edges,
    }
}
/// Add a node to the graph.
///
/// # Errors
/// `NeuralError::InvalidArgument` if a node with the same name exists.
pub fn add_node(&mut self, node: GraphNode) -> Result<()> {
    if self.find_node(node.name()).is_some() {
        return Err(NeuralError::InvalidArgument(format!(
            "Node '{}' already exists in the graph",
            node.name()
        )));
    }
    self.nodes.push(node);
    Ok(())
}
/// Add an edge; both endpoints must already be present in the graph.
///
/// # Errors
/// `NeuralError::InvalidArgument` when either endpoint is unknown.
pub fn add_edge(&mut self, edge: GraphEdge) -> Result<()> {
    if self.find_node(&edge.from_node).is_none() {
        return Err(NeuralError::InvalidArgument(format!(
            "Source node '{}' not found in graph",
            edge.from_node
        )));
    }
    if self.find_node(&edge.to_node).is_none() {
        return Err(NeuralError::InvalidArgument(format!(
            "Destination node '{}' not found in graph",
            edge.to_node
        )));
    }
    self.edges.push(edge);
    Ok(())
}
/// All nodes, in insertion order.
pub fn nodes(&self) -> &[GraphNode] {
&self.nodes
}
/// All edges, in insertion order.
pub fn edges(&self) -> &[GraphEdge] {
&self.edges
}
/// Look up a node by its unique name.
pub fn find_node(&self, name: &str) -> Option<&GraphNode> {
self.nodes.iter().find(|n| n.name() == name)
}
/// Edges whose source is `node_name`, in insertion order.
pub fn outgoing_edges(&self, node_name: &str) -> Vec<&GraphEdge> {
self.edges
.iter()
.filter(|e| e.from_node == node_name)
.collect()
}
/// Edges whose destination is `node_name`, in insertion order.
pub fn incoming_edges(&self, node_name: &str) -> Vec<&GraphEdge> {
self.edges
.iter()
.filter(|e| e.to_node == node_name)
.collect()
}
/// Rough total parameter count: for each weighted node, its output
/// element count times its weight-tensor count.
///
/// Returns `None` when any contributing node has a dynamic output shape
/// or any intermediate product/sum would overflow `i64`.
///
/// NOTE(review): this estimates from *output* shapes, not the actual
/// weight-tensor shapes — confirm that is the intended metric.
pub fn total_parameters(&self) -> Option<i64> {
    let mut total = 0i64;
    for node in &self.nodes {
        let w = node.num_weight_tensors();
        if w == 0 {
            continue;
        }
        let elems = node.output_shape().num_elements()?;
        // FIX: the original computed `elems * w as i64` unchecked, which
        // could overflow (panic in debug, wrap in release) even though
        // the surrounding addition was checked. Use checked_mul too.
        let contribution = elems.checked_mul(w as i64)?;
        total = total.checked_add(contribution)?;
    }
    Some(total)
}
/// Check structural invariants: unique node names, edges that reference
/// existing nodes, and at least one `Input` and one `Output` node.
///
/// These checks overlap with `add_node`/`add_edge`, but are still needed
/// because a graph can also be deserialized from arbitrary JSON, which
/// bypasses those constructors.
///
/// # Errors
/// `NeuralError::ValidationError` describing the first violation found.
pub fn validate(&self) -> Result<()> {
let mut seen_names: HashSet<&str> = HashSet::new();
for node in &self.nodes {
if !seen_names.insert(node.name()) {
return Err(NeuralError::ValidationError(format!(
"Duplicate node name: '{}'",
node.name()
)));
}
}
// `seen_names` now holds every node name; reuse it for edge checks.
for edge in &self.edges {
if !seen_names.contains(edge.from_node.as_str()) {
return Err(NeuralError::ValidationError(format!(
"Edge references unknown source node: '{}'",
edge.from_node
)));
}
if !seen_names.contains(edge.to_node.as_str()) {
return Err(NeuralError::ValidationError(format!(
"Edge references unknown destination node: '{}'",
edge.to_node
)));
}
}
let has_input = self
.nodes
.iter()
.any(|n| matches!(n, GraphNode::Input { .. }));
let has_output = self
.nodes
.iter()
.any(|n| matches!(n, GraphNode::Output { .. }));
if !has_input {
return Err(NeuralError::ValidationError(
"Graph must have at least one Input node".to_string(),
));
}
if !has_output {
return Err(NeuralError::ValidationError(
"Graph must have at least one Output node".to_string(),
));
}
Ok(())
}
/// Serialize the whole graph (metadata, nodes, edges) to pretty JSON.
///
/// # Errors
/// `NeuralError::SerializationError` if serialization fails.
pub fn export_to_json(&self) -> Result<String> {
serde_json::to_string_pretty(self)
.map_err(|e| NeuralError::SerializationError(e.to_string()))
}
/// Reconstruct a graph from JSON produced by [`ModelGraph::export_to_json`].
/// Call `validate()` afterwards if the JSON comes from an untrusted source.
///
/// # Errors
/// `NeuralError::DeserializationError` on malformed input.
pub fn import_from_json(json: &str) -> Result<Self> {
serde_json::from_str(json)
.map_err(|e| NeuralError::DeserializationError(e.to_string()))
}
/// Serialize the graph as pretty JSON and write it to `path`, creating
/// any missing parent directories.
///
/// # Errors
/// `NeuralError::IOError` on filesystem failure, or a serialization error
/// from [`ModelGraph::export_to_json`].
pub fn save_to_file(&self, path: &Path) -> Result<()> {
    let json = self.export_to_json()?;
    if let Some(parent) = path.parent() {
        fs::create_dir_all(parent).map_err(|e| NeuralError::IOError(e.to_string()))?;
    }
    fs::write(path, json.as_bytes()).map_err(|e| NeuralError::IOError(e.to_string()))
}
/// Load a graph previously written by [`ModelGraph::save_to_file`].
///
/// # Errors
/// `NeuralError::IOError` on read/UTF-8 failure, or a deserialization
/// error from [`ModelGraph::import_from_json`].
pub fn load_from_file(path: &Path) -> Result<Self> {
    // `read_to_string` reads and validates UTF-8 in one call, replacing
    // the original manual `fs::read` + `from_utf8` two-step (invalid
    // UTF-8 still surfaces as an IOError, via the io::Error message).
    let json = fs::read_to_string(path).map_err(|e| NeuralError::IOError(e.to_string()))?;
    Self::import_from_json(&json)
}
/// Serialize the graph into an ONNX-flavoured JSON document.
///
/// This is *not* a binary ONNX protobuf; it mirrors the ONNX ModelProto
/// layout (`ir_version`, `opset_import`, and a `graph` object holding
/// `input`, `output`, `node`, and `value_info` lists) as plain JSON.
///
/// # Errors
/// `NeuralError::SerializationError` if JSON serialization fails.
pub fn export_onnx_like(&self) -> Result<String> {
    let mut inputs = Vec::new();
    let mut outputs = Vec::new();
    let mut onnx_nodes: Vec<serde_json::Value> = Vec::new();
    let mut value_info: Vec<serde_json::Value> = Vec::new();
    for node in &self.nodes {
        // Static dims become `dim_value`; dynamic dims are exported as a
        // symbolic `dim_param` named "batch_size".
        let shape_dims: Vec<serde_json::Value> = node
            .output_shape()
            .dims
            .iter()
            .map(|d| match d {
                Some(v) => serde_json::json!({ "dim_value": v }),
                None => serde_json::json!({ "dim_param": "batch_size" }),
            })
            .collect();
        let type_proto = serde_json::json!({
            "tensor_type": {
                "elem_type": onnx_elem_type(&node.output_shape().dtype),
                "shape": { "dim": shape_dims }
            }
        });
        match node {
            GraphNode::Input { name, .. } => {
                inputs.push(serde_json::json!({
                    "name": name,
                    "type": type_proto
                }));
            }
            GraphNode::Output { name, .. } => {
                outputs.push(serde_json::json!({
                    "name": name,
                    "type": type_proto
                }));
            }
            other => {
                // FIX: order inputs by destination slot so multi-input ops
                // (Add, Concat, ...) get a deterministic, slot-correct input
                // list; the original used raw edge-insertion order.
                let mut incoming = self.incoming_edges(other.name());
                incoming.sort_by_key(|e| e.to_slot);
                let input_names: Vec<String> =
                    incoming.into_iter().map(|e| e.from_node.clone()).collect();
                // Single-output convention: the value is named after the node.
                let output_names = vec![other.name().to_string()];
                let attrs = onnx_node_attributes(other);
                onnx_nodes.push(serde_json::json!({
                    "op_type": other.onnx_op_type(),
                    "name": other.name(),
                    "input": input_names,
                    "output": output_names,
                    "attribute": attrs,
                }));
                value_info.push(serde_json::json!({
                    "name": other.name(),
                    "type": type_proto
                }));
            }
        }
    }
    let doc_string = self
        .metadata
        .description
        .clone()
        .unwrap_or_else(|| format!("{} model graph", self.metadata.architecture));
    let onnx_model = serde_json::json!({
        "ir_version": 8,
        "opset_import": [{ "domain": "", "version": 17 }],
        "model_version": 1,
        "producer_name": "scirs2-neural",
        "producer_version": env!("CARGO_PKG_VERSION"),
        "domain": "ai.onnx",
        "doc_string": doc_string,
        "graph": {
            "name": self.metadata.architecture,
            "doc_string": doc_string,
            "input": inputs,
            "output": outputs,
            "node": onnx_nodes,
            "value_info": value_info,
        }
    });
    serde_json::to_string_pretty(&onnx_model)
        .map_err(|e| NeuralError::SerializationError(e.to_string()))
}
}
/// ONNX `TensorProto.DataType` code for a dtype
/// (FLOAT=1, INT32=6, INT64=7, BOOL=9, FLOAT16=10, DOUBLE=11).
fn onnx_elem_type(dtype: &TensorDtype) -> u32 {
    match dtype {
        TensorDtype::F32 => 1,
        TensorDtype::I32 => 6,
        TensorDtype::I64 => 7,
        TensorDtype::Bool => 9,
        TensorDtype::F16 => 10,
        TensorDtype::F64 => 11,
    }
}
/// Build the ONNX-style attribute list for a node; ops without exported
/// attributes yield an empty vector.
fn onnx_node_attributes(node: &GraphNode) -> Vec<serde_json::Value> {
let mut attrs = Vec::new();
match node {
GraphNode::Conv2D { attrs: a, .. } | GraphNode::DepthwiseConv2D { attrs: a, .. } => {
attrs.push(serde_json::json!({
"name": "kernel_shape",
"type": "INTS",
"ints": [a.kernel_h, a.kernel_w]
}));
attrs.push(serde_json::json!({
"name": "strides",
"type": "INTS",
"ints": [a.stride_h, a.stride_w]
}));
// ONNX 2-D pads order is [x1_begin, x2_begin, x1_end, x2_end],
// i.e. [top, left, bottom, right].
attrs.push(serde_json::json!({
"name": "pads",
"type": "INTS",
"ints": [a.padding.top, a.padding.left, a.padding.bottom, a.padding.right]
}));
// Single dilation value is applied to both spatial axes.
attrs.push(serde_json::json!({
"name": "dilations",
"type": "INTS",
"ints": [a.dilation, a.dilation]
}));
attrs.push(serde_json::json!({
"name": "group",
"type": "INT",
"i": a.groups
}));
}
GraphNode::MaxPool2D { attrs: a, .. } | GraphNode::AvgPool2D { attrs: a, .. } => {
attrs.push(serde_json::json!({
"name": "kernel_shape",
"type": "INTS",
"ints": [a.kernel_h, a.kernel_w]
}));
// Missing strides default to the kernel size (non-overlapping pooling).
let sh = a.stride_h.unwrap_or(a.kernel_h);
let sw = a.stride_w.unwrap_or(a.kernel_w);
attrs.push(serde_json::json!({
"name": "strides",
"type": "INTS",
"ints": [sh, sw]
}));
attrs.push(serde_json::json!({
"name": "pads",
"type": "INTS",
"ints": [a.padding.top, a.padding.left, a.padding.bottom, a.padding.right]
}));
}
GraphNode::BatchNorm { attrs: a, .. } => {
attrs.push(serde_json::json!({
"name": "epsilon",
"type": "FLOAT",
"f": a.eps
}));
// NOTE(review): `1.0 - momentum` converts a PyTorch-style momentum
// to the ONNX convention — confirm `a.momentum` follows PyTorch.
attrs.push(serde_json::json!({
"name": "momentum",
"type": "FLOAT",
"f": 1.0 - a.momentum
}));
}
GraphNode::Dropout { attrs: a, .. } => {
// Seed fixed at 0 — presumably for reproducible exports; verify.
attrs.push(serde_json::json!({
"name": "seed",
"type": "INT",
"i": 0
}));
attrs.push(serde_json::json!({
"name": "ratio",
"type": "FLOAT",
"f": a.p
}));
}
GraphNode::Activation { attrs: a, .. } => {
// Only alpha-bearing activations (LeakyRelu, Elu) carry an attribute.
if let Some(alpha) = a.alpha {
attrs.push(serde_json::json!({
"name": "alpha",
"type": "FLOAT",
"f": alpha
}));
}
}
GraphNode::Concat { axis, .. } => {
attrs.push(serde_json::json!({
"name": "axis",
"type": "INT",
"i": axis
}));
}
GraphNode::Softmax { axis, .. } => {
attrs.push(serde_json::json!({
"name": "axis",
"type": "INT",
"i": axis
}));
}
GraphNode::Reshape { attrs: a, .. } => {
// Exported as an attribute for readability; standard ONNX Reshape
// takes the target shape as a second input tensor instead.
attrs.push(serde_json::json!({
"name": "shape",
"type": "INTS",
"ints": a.shape
}));
}
_ => {}
}
attrs
}
pub struct ModelGraphBuilder {
graph: ModelGraph,
pending_chains: Vec<(String, String)>,
}
impl ModelGraphBuilder {
pub fn new(architecture: impl Into<String>) -> Self {
Self {
graph: ModelGraph::new(GraphMetadata::new(architecture)),
pending_chains: Vec::new(),
}
}
pub fn input(mut self, name: impl Into<String>, shape: TensorShape) -> Self {
let n = name.into();
let _ = self.graph.add_node(GraphNode::Input {
name: n,
shape,
});
self
}
pub fn output(mut self, name: impl Into<String>, shape: TensorShape) -> Self {
let n = name.into();
let _ = self.graph.add_node(GraphNode::Output {
name: n,
shape,
});
self
}
pub fn dense(
mut self,
name: impl Into<String>,
in_features: usize,
out_features: usize,
use_bias: bool,
output_shape: TensorShape,
) -> Self {
let n = name.into();
let _ = self.graph.add_node(GraphNode::Dense {
name: n,
attrs: DenseAttrs {
in_features,
out_features,
use_bias,
},
output_shape,
});
self
}
pub fn activation(
mut self,
name: impl Into<String>,
function: impl Into<String>,
alpha: Option<f64>,
output_shape: TensorShape,
) -> Self {
let n = name.into();
let _ = self.graph.add_node(GraphNode::Activation {
name: n,
attrs: ActivationAttrs {
function: function.into(),
alpha,
},
output_shape,
});
self
}
pub fn conv2d(
mut self,
name: impl Into<String>,
attrs: Conv2DAttrs,
output_shape: TensorShape,
) -> Self {
let n = name.into();
let _ = self.graph.add_node(GraphNode::Conv2D {
name: n,
attrs,
output_shape,
});
self
}
pub fn batch_norm(
mut self,
name: impl Into<String>,
attrs: BatchNormAttrs,
output_shape: TensorShape,
) -> Self {
let n = name.into();
let _ = self.graph.add_node(GraphNode::BatchNorm {
name: n,
attrs,
output_shape,
});
self
}
pub fn chain(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
self.pending_chains.push((from.into(), to.into()));
self
}
pub fn edge(mut self, edge: GraphEdge) -> Self {
let _ = self.graph.add_edge(edge);
self
}
pub fn build(mut self) -> Result<ModelGraph> {
let chains = std::mem::take(&mut self.pending_chains);
for (from, to) in chains {
self.graph.add_edge(GraphEdge::simple(from, to))?;
}
self.graph.validate()?;
Ok(self.graph)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Five-node MLP (input -> fc1 -> relu1 -> fc2 -> output) used as a
    /// shared fixture across tests.
    fn make_mlp_graph() -> ModelGraph {
        ModelGraphBuilder::new("MLP")
            .input(
                "x",
                TensorShape::with_batch_dim(vec![784], TensorDtype::F32),
            )
            .dense(
                "fc1",
                784,
                256,
                true,
                TensorShape::with_batch_dim(vec![256], TensorDtype::F32),
            )
            .activation(
                "relu1",
                "relu",
                None,
                TensorShape::with_batch_dim(vec![256], TensorDtype::F32),
            )
            .dense(
                "fc2",
                256,
                10,
                true,
                TensorShape::with_batch_dim(vec![10], TensorDtype::F32),
            )
            .output(
                "y",
                TensorShape::with_batch_dim(vec![10], TensorDtype::F32),
            )
            .chain("x", "fc1")
            .chain("fc1", "relu1")
            .chain("relu1", "fc2")
            .chain("fc2", "y")
            .build()
            .expect("build failed")
    }
    #[test]
    fn test_model_graph_builder_basic() {
        let graph = make_mlp_graph();
        assert_eq!(graph.nodes().len(), 5);
        assert_eq!(graph.edges().len(), 4);
    }
    #[test]
    fn test_graph_json_roundtrip() {
        let graph = make_mlp_graph();
        let json = graph.export_to_json().expect("export failed");
        assert!(!json.is_empty());
        let restored = ModelGraph::import_from_json(&json).expect("import failed");
        assert_eq!(restored.nodes().len(), graph.nodes().len());
        assert_eq!(restored.edges().len(), graph.edges().len());
        assert_eq!(restored.metadata.architecture, "MLP");
    }
    #[test]
    fn test_graph_file_roundtrip() {
        let graph = make_mlp_graph();
        // FIX: suffix the temp dir with the process id so concurrent test
        // runs (e.g. parallel CI jobs) don't clobber each other's files.
        let dir = std::env::temp_dir().join(format!(
            "scirs2_model_graph_test_{}",
            std::process::id()
        ));
        let path = dir.join("mlp_graph.json");
        graph.save_to_file(&path).expect("save failed");
        let loaded = ModelGraph::load_from_file(&path).expect("load failed");
        assert_eq!(loaded.nodes().len(), graph.nodes().len());
        let _ = std::fs::remove_dir_all(&dir);
    }
    #[test]
    fn test_graph_onnx_like_export() {
        let graph = make_mlp_graph();
        let onnx_json = graph.export_onnx_like().expect("onnx export failed");
        let parsed: serde_json::Value =
            serde_json::from_str(&onnx_json).expect("onnx json should be valid json");
        assert_eq!(parsed["ir_version"], 8);
        assert!(parsed["graph"]["node"].is_array());
        assert!(parsed["graph"]["input"].is_array());
        assert!(parsed["graph"]["output"].is_array());
    }
    #[test]
    fn test_graph_validate_missing_input() {
        let mut graph = ModelGraph::new(GraphMetadata::new("Test"));
        graph
            .add_node(GraphNode::Output {
                name: "out".to_string(),
                shape: TensorShape::new(vec![10], TensorDtype::F32),
            })
            .expect("add node");
        assert!(graph.validate().is_err());
    }
    #[test]
    fn test_graph_validate_missing_output() {
        let mut graph = ModelGraph::new(GraphMetadata::new("Test"));
        graph
            .add_node(GraphNode::Input {
                name: "in".to_string(),
                shape: TensorShape::new(vec![10], TensorDtype::F32),
            })
            .expect("add node");
        assert!(graph.validate().is_err());
    }
    #[test]
    fn test_graph_validate_bad_edge() {
        let mut graph = ModelGraph::new(GraphMetadata::new("Test"));
        graph
            .add_node(GraphNode::Input {
                name: "in".to_string(),
                shape: TensorShape::new(vec![10], TensorDtype::F32),
            })
            .expect("add node");
        graph
            .add_node(GraphNode::Output {
                name: "out".to_string(),
                shape: TensorShape::new(vec![10], TensorDtype::F32),
            })
            .expect("add node");
        let result = graph.add_edge(GraphEdge::simple("nonexistent", "out"));
        assert!(result.is_err());
    }
    #[test]
    fn test_duplicate_node_rejected() {
        let mut graph = ModelGraph::new(GraphMetadata::new("Test"));
        graph
            .add_node(GraphNode::Input {
                name: "x".to_string(),
                shape: TensorShape::new(vec![10], TensorDtype::F32),
            })
            .expect("first add");
        let result = graph.add_node(GraphNode::Input {
            name: "x".to_string(),
            shape: TensorShape::new(vec![10], TensorDtype::F32),
        });
        assert!(result.is_err());
    }
    #[test]
    fn test_tensor_shape_static_num_elements() {
        let shape = TensorShape::new(vec![3, 4, 5], TensorDtype::F32);
        assert_eq!(shape.num_elements(), Some(60));
    }
    #[test]
    fn test_tensor_shape_dynamic_num_elements() {
        // A dynamic (batch) dimension makes the element count unknown.
        let shape = TensorShape::with_batch_dim(vec![3, 4], TensorDtype::F32);
        assert_eq!(shape.num_elements(), None);
    }
    #[test]
    fn test_conv2d_onnx_attributes() {
        let graph = ModelGraphBuilder::new("ConvNet")
            .input(
                "x",
                TensorShape::with_batch_dim(vec![3, 224, 224], TensorDtype::F32),
            )
            .conv2d(
                "conv1",
                Conv2DAttrs {
                    in_channels: 3,
                    out_channels: 64,
                    kernel_h: 3,
                    kernel_w: 3,
                    stride_h: 1,
                    stride_w: 1,
                    padding: PaddingSpec::same(1),
                    dilation: 1,
                    groups: 1,
                    use_bias: false,
                },
                TensorShape::with_batch_dim(vec![64, 224, 224], TensorDtype::F32),
            )
            .output(
                "y",
                TensorShape::with_batch_dim(vec![64, 224, 224], TensorDtype::F32),
            )
            .chain("x", "conv1")
            .chain("conv1", "y")
            .build()
            .expect("build");
        let onnx = graph.export_onnx_like().expect("onnx export");
        let val: serde_json::Value = serde_json::from_str(&onnx).expect("parse");
        let node = &val["graph"]["node"][0];
        assert_eq!(node["op_type"], "Conv");
        let attrs: Vec<serde_json::Value> =
            serde_json::from_value(node["attribute"].clone()).expect("parse attrs");
        let kernel_attr = attrs
            .iter()
            .find(|a| a["name"] == "kernel_shape")
            .expect("kernel_shape attr");
        assert_eq!(kernel_attr["ints"], serde_json::json!([3, 3]));
    }
    #[test]
    fn test_dtype_roundtrip() {
        // Every dtype must survive onnx_str -> from_onnx_str.
        let dtypes = [
            TensorDtype::F32,
            TensorDtype::F64,
            TensorDtype::F16,
            TensorDtype::I32,
            TensorDtype::I64,
            TensorDtype::Bool,
        ];
        for dt in &dtypes {
            let s = dt.onnx_str();
            let restored = TensorDtype::from_onnx_str(s).expect("roundtrip");
            assert_eq!(dt, &restored);
        }
    }
    #[test]
    fn test_graph_node_name_accessor() {
        let node = GraphNode::Dense {
            name: "my_layer".to_string(),
            attrs: DenseAttrs {
                in_features: 10,
                out_features: 5,
                use_bias: true,
            },
            output_shape: TensorShape::new(vec![5], TensorDtype::F32),
        };
        assert_eq!(node.name(), "my_layer");
        assert_eq!(node.onnx_op_type(), "Gemm");
        assert_eq!(node.num_weight_tensors(), 2);
    }
}