#![allow(unused_variables)]
use super::{ExportConfig, ExportFormat, ExportPrecision, ModelExporter};
use crate::traits::Model;
use anyhow::{anyhow, Result};
use std::collections::HashMap;
use std::fs::File;
use std::io::Write;
#[derive(Debug)]
pub struct ONNXModel {
pub graph: ONNXGraph,
pub ir_version: i64,
pub opset_imports: Vec<ONNXOpsetImport>,
pub producer_name: String,
pub producer_version: String,
pub model_version: i64,
}
#[derive(Debug)]
pub struct ONNXGraph {
pub nodes: Vec<ONNXNode>,
pub inputs: Vec<ONNXValueInfo>,
pub outputs: Vec<ONNXValueInfo>,
pub initializers: Vec<ONNXTensor>,
pub name: String,
}
#[derive(Debug)]
pub struct ONNXNode {
pub op_type: String,
pub inputs: Vec<String>,
pub outputs: Vec<String>,
pub attributes: HashMap<String, ONNXAttribute>,
pub name: String,
}
#[derive(Debug)]
pub struct ONNXValueInfo {
pub name: String,
pub type_info: ONNXTypeInfo,
}
#[derive(Debug)]
pub struct ONNXTypeInfo {
pub tensor_type: ONNXTensorType,
}
#[derive(Debug)]
pub struct ONNXTensorType {
pub elem_type: ONNXDataType,
pub shape: ONNXTensorShape,
}
#[derive(Debug)]
pub struct ONNXTensorShape {
pub dims: Vec<ONNXDimension>,
}
#[derive(Debug)]
pub enum ONNXDimension {
Value(i64),
Parameter(String),
}
#[derive(Debug, Clone, Copy)]
pub enum ONNXDataType {
Float = 1,
UInt8 = 2,
Int8 = 3,
UInt16 = 4,
Int16 = 5,
Int32 = 6,
Int64 = 7,
String = 8,
Bool = 9,
Float16 = 10,
Double = 11,
UInt32 = 12,
UInt64 = 13,
Complex64 = 14,
Complex128 = 15,
BFloat16 = 16,
}
#[derive(Debug)]
pub struct ONNXTensor {
pub name: String,
pub data_type: ONNXDataType,
pub dims: Vec<i64>,
pub raw_data: Vec<u8>,
}
#[derive(Debug)]
pub struct ONNXOpsetImport {
pub domain: String,
pub version: i64,
}
#[derive(Debug)]
pub enum ONNXAttribute {
Int(i64),
Float(f32),
String(String),
Tensor(ONNXTensor),
Ints(Vec<i64>),
Floats(Vec<f32>),
Strings(Vec<String>),
}
#[derive(Debug)]
pub struct ONNXOperatorRegistry {
pub operators: HashMap<String, ONNXOperatorInfo>,
}
#[derive(Debug, Clone)]
pub struct ONNXOperatorInfo {
pub name: String,
pub opset_version: i64,
pub inputs: Vec<ONNXOperatorInput>,
pub outputs: Vec<ONNXOperatorOutput>,
pub attributes: Vec<ONNXOperatorAttribute>,
pub description: String,
}
#[derive(Debug, Clone)]
pub struct ONNXOperatorInput {
pub name: String,
pub types: Vec<ONNXDataType>,
pub optional: bool,
pub description: String,
}
#[derive(Debug, Clone)]
pub struct ONNXOperatorOutput {
pub name: String,
pub types: Vec<ONNXDataType>,
pub description: String,
}
#[derive(Debug, Clone)]
pub struct ONNXOperatorAttribute {
pub name: String,
pub required: bool,
pub attribute_type: String,
pub description: String,
}
#[derive(Clone)]
pub struct ONNXExporter {
opset_version: i64,
}
impl Default for ONNXExporter {
fn default() -> Self {
Self::new()
}
}
impl ONNXExporter {
pub fn new() -> Self {
Self { opset_version: 14 }
}
pub fn with_opset_version(mut self, version: i64) -> Self {
self.opset_version = version;
self
}
fn create_onnx_model<M: Model>(&self, model: &M, config: &ExportConfig) -> Result<ONNXModel> {
let mut graph = ONNXGraph {
nodes: Vec::new(),
inputs: Vec::new(),
outputs: Vec::new(),
initializers: Vec::new(),
name: "trustformers_model".to_string(),
};
let batch_size = config.batch_size.unwrap_or(1);
let seq_len = config.sequence_length.unwrap_or(512);
let input_ids = ONNXValueInfo {
name: "input_ids".to_string(),
type_info: ONNXTypeInfo {
tensor_type: ONNXTensorType {
elem_type: ONNXDataType::Int64,
shape: ONNXTensorShape {
dims: vec![
ONNXDimension::Parameter("batch_size".to_string()),
ONNXDimension::Parameter("sequence_length".to_string()),
],
},
},
},
};
graph.inputs.push(input_ids);
let attention_mask = ONNXValueInfo {
name: "attention_mask".to_string(),
type_info: ONNXTypeInfo {
tensor_type: ONNXTensorType {
elem_type: ONNXDataType::Int64,
shape: ONNXTensorShape {
dims: vec![
ONNXDimension::Parameter("batch_size".to_string()),
ONNXDimension::Parameter("sequence_length".to_string()),
],
},
},
},
};
graph.inputs.push(attention_mask);
let output = ONNXValueInfo {
name: "logits".to_string(),
type_info: ONNXTypeInfo {
tensor_type: ONNXTensorType {
elem_type: match config.precision {
ExportPrecision::FP32 => ONNXDataType::Float,
ExportPrecision::FP16 => ONNXDataType::Float16,
ExportPrecision::INT8 => ONNXDataType::Int8,
ExportPrecision::INT4 => ONNXDataType::Int8, },
shape: ONNXTensorShape {
dims: vec![
ONNXDimension::Parameter("batch_size".to_string()),
ONNXDimension::Parameter("sequence_length".to_string()),
ONNXDimension::Value(50257), ],
},
},
},
};
graph.outputs.push(output);
self.convert_model_to_nodes(model, &mut graph, config)?;
let onnx_model = ONNXModel {
graph,
ir_version: 8,
opset_imports: vec![ONNXOpsetImport {
domain: "".to_string(),
version: config.opset_version.unwrap_or(self.opset_version),
}],
producer_name: "TrustformeRS".to_string(),
producer_version: "0.1.0".to_string(),
model_version: 1,
};
Ok(onnx_model)
}
fn convert_model_to_nodes<M: Model>(
&self,
model: &M,
graph: &mut ONNXGraph,
config: &ExportConfig,
) -> Result<()> {
let embedding_node = ONNXNode {
op_type: "Gather".to_string(),
inputs: vec!["embedding_weight".to_string(), "input_ids".to_string()],
outputs: vec!["embeddings".to_string()],
attributes: HashMap::new(),
name: "embedding".to_string(),
};
graph.nodes.push(embedding_node);
for layer_idx in 0..12 {
self.add_transformer_layer(graph, layer_idx, config)?;
}
let final_norm_node = ONNXNode {
op_type: "LayerNormalization".to_string(),
inputs: vec![
format!("layer_{}_output", 11),
"final_layer_norm_weight".to_string(),
"final_layer_norm_bias".to_string(),
],
outputs: vec!["final_hidden_states".to_string()],
attributes: HashMap::new(),
name: "final_layer_norm".to_string(),
};
graph.nodes.push(final_norm_node);
let lm_head_node = ONNXNode {
op_type: "MatMul".to_string(),
inputs: vec![
"final_hidden_states".to_string(),
"lm_head_weight".to_string(),
],
outputs: vec!["logits".to_string()],
attributes: HashMap::new(),
name: "lm_head".to_string(),
};
graph.nodes.push(lm_head_node);
Ok(())
}
fn add_transformer_layer(
&self,
graph: &mut ONNXGraph,
layer_idx: usize,
_config: &ExportConfig,
) -> Result<()> {
let layer_prefix = format!("layer_{}", layer_idx);
let input_name = if layer_idx == 0 {
"embeddings".to_string()
} else {
format!("layer_{}_output", layer_idx - 1)
};
let attention_node = ONNXNode {
op_type: "MultiHeadAttention".to_string(),
inputs: vec![
input_name.clone(),
input_name.clone(),
input_name.clone(),
format!("{}_attention_mask", layer_prefix),
],
outputs: vec![format!("{}_attention_output", layer_prefix)],
attributes: {
let mut attrs = HashMap::new();
attrs.insert("num_heads".to_string(), ONNXAttribute::Int(12));
attrs
},
name: format!("{}_attention", layer_prefix),
};
graph.nodes.push(attention_node);
let add_node = ONNXNode {
op_type: "Add".to_string(),
inputs: vec![
input_name.clone(),
format!("{}_attention_output", layer_prefix),
],
outputs: vec![format!("{}_attention_residual", layer_prefix)],
attributes: HashMap::new(),
name: format!("{}_attention_add", layer_prefix),
};
graph.nodes.push(add_node);
let norm_node = ONNXNode {
op_type: "LayerNormalization".to_string(),
inputs: vec![
format!("{}_attention_residual", layer_prefix),
format!("{}_norm_weight", layer_prefix),
format!("{}_norm_bias", layer_prefix),
],
outputs: vec![format!("{}_norm_output", layer_prefix)],
attributes: HashMap::new(),
name: format!("{}_norm", layer_prefix),
};
graph.nodes.push(norm_node);
let ff_node = ONNXNode {
op_type: "MatMul".to_string(),
inputs: vec![
format!("{}_norm_output", layer_prefix),
format!("{}_ff_weight", layer_prefix),
],
outputs: vec![format!("{}_ff_output", layer_prefix)],
attributes: HashMap::new(),
name: format!("{}_feedforward", layer_prefix),
};
graph.nodes.push(ff_node);
let final_add_node = ONNXNode {
op_type: "Add".to_string(),
inputs: vec![
format!("{}_norm_output", layer_prefix),
format!("{}_ff_output", layer_prefix),
],
outputs: vec![format!("{}_output", layer_prefix)],
attributes: HashMap::new(),
name: format!("{}_final_add", layer_prefix),
};
graph.nodes.push(final_add_node);
Ok(())
}
fn serialize_onnx_model(&self, model: &ONNXModel, output_path: &str) -> Result<()> {
let serialized = self.serialize_to_text(model)?;
let mut file = File::create(format!("{}.onnx", output_path))?;
file.write_all(serialized.as_bytes())?;
Ok(())
}
fn serialize_to_text(&self, model: &ONNXModel) -> Result<String> {
let mut output = String::new();
output.push_str(&format!("IR Version: {}\n", model.ir_version));
output.push_str(&format!(
"Producer: {} {}\n",
model.producer_name, model.producer_version
));
output.push_str(&format!("Model Version: {}\n", model.model_version));
output.push('\n');
output.push_str("Opset Imports:\n");
for opset in &model.opset_imports {
output.push_str(&format!(
" Domain: '{}', Version: {}\n",
opset.domain, opset.version
));
}
output.push('\n');
output.push_str("Graph:\n");
output.push_str(&format!(" Name: {}\n", model.graph.name));
output.push_str(" Inputs:\n");
for input in &model.graph.inputs {
output.push_str(&format!(" {}: {:?}\n", input.name, input.type_info));
}
output.push_str(" Outputs:\n");
for output_info in &model.graph.outputs {
output.push_str(&format!(
" {}: {:?}\n",
output_info.name, output_info.type_info
));
}
output.push_str(" Nodes:\n");
for node in &model.graph.nodes {
output.push_str(&format!(
" {} ({}): {} -> {}\n",
node.name,
node.op_type,
node.inputs.join(", "),
node.outputs.join(", ")
));
}
Ok(output)
}
}
impl Default for ONNXOperatorRegistry {
fn default() -> Self {
Self::new()
}
}
impl ONNXOperatorRegistry {
pub fn new() -> Self {
Self {
operators: HashMap::new(),
}
}
pub fn register_all_operators(&mut self) {
self.register_core_operators();
self.register_math_operators();
self.register_neural_network_operators();
self.register_tensor_operators();
self.register_control_flow_operators();
self.register_quantization_operators();
}
fn register_core_operators(&mut self) {
self.register_operator(ONNXOperatorInfo {
name: "Add".to_string(),
opset_version: 7,
inputs: vec![
ONNXOperatorInput {
name: "A".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Int32,
ONNXDataType::Int64,
],
optional: false,
description: "First operand".to_string(),
},
ONNXOperatorInput {
name: "B".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Int32,
ONNXDataType::Int64,
],
optional: false,
description: "Second operand".to_string(),
},
],
outputs: vec![ONNXOperatorOutput {
name: "C".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Int32,
ONNXDataType::Int64,
],
description: "Result of addition".to_string(),
}],
attributes: vec![],
description: "Element-wise addition".to_string(),
});
self.register_operator(ONNXOperatorInfo {
name: "MatMul".to_string(),
opset_version: 1,
inputs: vec![
ONNXOperatorInput {
name: "A".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Float16,
ONNXDataType::BFloat16,
],
optional: false,
description: "N-dimensional matrix A".to_string(),
},
ONNXOperatorInput {
name: "B".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Float16,
ONNXDataType::BFloat16,
],
optional: false,
description: "N-dimensional matrix B".to_string(),
},
],
outputs: vec![ONNXOperatorOutput {
name: "Y".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Float16,
ONNXDataType::BFloat16,
],
description: "Matrix multiplication result".to_string(),
}],
attributes: vec![],
description: "Matrix multiplication".to_string(),
});
}
fn register_math_operators(&mut self) {
let activations = vec![
("Relu", "Rectified Linear Unit activation"),
("Sigmoid", "Sigmoid activation function"),
("Tanh", "Hyperbolic tangent activation"),
("Gelu", "Gaussian Error Linear Unit activation"),
("LeakyRelu", "Leaky ReLU activation"),
("Elu", "Exponential Linear Unit activation"),
("Selu", "Scaled Exponential Linear Unit activation"),
("Swish", "Swish activation function"),
];
for (name, desc) in activations {
self.register_operator(ONNXOperatorInfo {
name: name.to_string(),
opset_version: 6,
inputs: vec![ONNXOperatorInput {
name: "X".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Float16,
],
optional: false,
description: "Input tensor".to_string(),
}],
outputs: vec![ONNXOperatorOutput {
name: "Y".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Float16,
],
description: "Output tensor".to_string(),
}],
attributes: if name == "LeakyRelu" {
vec![ONNXOperatorAttribute {
name: "alpha".to_string(),
attribute_type: "float".to_string(),
required: false,
description: "Coefficient of leakage".to_string(),
}]
} else {
vec![]
},
description: desc.to_string(),
});
}
let math_ops = vec![
"Abs",
"Acos",
"Asin",
"Atan",
"Ceil",
"Cos",
"Cosh",
"Exp",
"Floor",
"Log",
"Neg",
"Reciprocal",
"Round",
"Sign",
"Sin",
"Sinh",
"Sqrt",
"Tan",
"Erf",
];
for op in math_ops {
self.register_operator(ONNXOperatorInfo {
name: op.to_string(),
opset_version: 6,
inputs: vec![ONNXOperatorInput {
name: "input".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Float16,
],
optional: false,
description: "Input tensor".to_string(),
}],
outputs: vec![ONNXOperatorOutput {
name: "output".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Float16,
],
description: "Output tensor".to_string(),
}],
attributes: vec![],
description: format!("{} mathematical function", op),
});
}
}
fn register_neural_network_operators(&mut self) {
self.register_operator(ONNXOperatorInfo {
name: "Conv".to_string(),
opset_version: 11,
inputs: vec![
ONNXOperatorInput {
name: "X".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Float16,
],
optional: false,
description: "Input data tensor".to_string(),
},
ONNXOperatorInput {
name: "W".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Float16,
],
optional: false,
description: "Weight tensor".to_string(),
},
ONNXOperatorInput {
name: "B".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Float16,
],
optional: true,
description: "Optional bias tensor".to_string(),
},
],
outputs: vec![ONNXOperatorOutput {
name: "Y".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Float16,
],
description: "Output data tensor".to_string(),
}],
attributes: vec![
ONNXOperatorAttribute {
name: "strides".to_string(),
attribute_type: "ints".to_string(),
required: false,
description: "Stride along each spatial axis".to_string(),
},
ONNXOperatorAttribute {
name: "pads".to_string(),
attribute_type: "ints".to_string(),
required: false,
description: "Padding for the beginning and ending".to_string(),
},
ONNXOperatorAttribute {
name: "dilations".to_string(),
attribute_type: "ints".to_string(),
required: false,
description: "Dilation value along each spatial axis".to_string(),
},
ONNXOperatorAttribute {
name: "group".to_string(),
attribute_type: "int".to_string(),
required: false,
description:
"Number of groups input channels and output channels are divided into"
.to_string(),
},
],
description: "Convolution operator".to_string(),
});
self.register_operator(ONNXOperatorInfo {
name: "BatchNormalization".to_string(),
opset_version: 15,
inputs: vec![
ONNXOperatorInput {
name: "X".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Float16,
],
optional: false,
description: "Input tensor".to_string(),
},
ONNXOperatorInput {
name: "scale".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Float16,
],
optional: false,
description: "Scale tensor".to_string(),
},
ONNXOperatorInput {
name: "B".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Float16,
],
optional: false,
description: "Bias tensor".to_string(),
},
ONNXOperatorInput {
name: "input_mean".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Float16,
],
optional: false,
description: "Running mean".to_string(),
},
ONNXOperatorInput {
name: "input_var".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Float16,
],
optional: false,
description: "Running variance".to_string(),
},
],
outputs: vec![ONNXOperatorOutput {
name: "Y".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Float16,
],
description: "Output tensor".to_string(),
}],
attributes: vec![ONNXOperatorAttribute {
name: "epsilon".to_string(),
attribute_type: "float".to_string(),
required: false,
description: "Small constant to avoid division by zero".to_string(),
}],
description: "Batch normalization operator".to_string(),
});
}
fn register_tensor_operators(&mut self) {
let shape_ops = vec![
("Reshape", "Reshape the input tensor"),
("Transpose", "Transpose the input tensor"),
("Squeeze", "Remove single-dimensional entries"),
("Unsqueeze", "Insert single-dimensional entries"),
("Flatten", "Flatten the input tensor"),
];
for (name, desc) in shape_ops {
self.register_operator(ONNXOperatorInfo {
name: name.to_string(),
opset_version: 13,
inputs: if name == "Reshape" {
vec![
ONNXOperatorInput {
name: "data".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Int32,
ONNXDataType::Int64,
],
optional: false,
description: "Input tensor".to_string(),
},
ONNXOperatorInput {
name: "shape".to_string(),
types: vec![ONNXDataType::Int64],
optional: false,
description: "New shape".to_string(),
},
]
} else {
vec![ONNXOperatorInput {
name: "data".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Int32,
ONNXDataType::Int64,
],
optional: false,
description: "Input tensor".to_string(),
}]
},
outputs: vec![ONNXOperatorOutput {
name: "reshaped".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Int32,
ONNXDataType::Int64,
],
description: "Reshaped tensor".to_string(),
}],
attributes: match name {
"Transpose" => vec![ONNXOperatorAttribute {
name: "perm".to_string(),
attribute_type: "ints".to_string(),
required: false,
description: "A list of integers. By default, reverse the dimensions"
.to_string(),
}],
"Flatten" => vec![ONNXOperatorAttribute {
name: "axis".to_string(),
attribute_type: "int".to_string(),
required: false,
description: "Indicate which axis to flatten".to_string(),
}],
_ => vec![],
},
description: desc.to_string(),
});
}
let reduce_ops = vec![
"ReduceSum",
"ReduceMean",
"ReduceMax",
"ReduceMin",
"ReduceProd",
"ReduceL1",
"ReduceL2",
"ReduceLogSum",
"ReduceLogSumExp",
"ReduceSumSquare",
];
for op in reduce_ops {
self.register_operator(ONNXOperatorInfo {
name: op.to_string(),
opset_version: 13,
inputs: vec![ONNXOperatorInput {
name: "data".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Int32,
ONNXDataType::Int64,
],
optional: false,
description: "Input tensor".to_string(),
}],
outputs: vec![ONNXOperatorOutput {
name: "reduced".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Int32,
ONNXDataType::Int64,
],
description: "Reduced tensor".to_string(),
}],
attributes: vec![
ONNXOperatorAttribute {
name: "axes".to_string(),
attribute_type: "ints".to_string(),
required: false,
description: "A list of integers, along which to reduce".to_string(),
},
ONNXOperatorAttribute {
name: "keepdims".to_string(),
attribute_type: "int".to_string(),
required: false,
description: "Keep the reduced dimension or not".to_string(),
},
],
description: format!("{} reduction operation", op),
});
}
}
fn register_control_flow_operators(&mut self) {
self.register_operator(ONNXOperatorInfo {
name: "If".to_string(),
opset_version: 11,
inputs: vec![ONNXOperatorInput {
name: "cond".to_string(),
types: vec![ONNXDataType::Bool],
optional: false,
description: "Condition tensor".to_string(),
}],
outputs: vec![ONNXOperatorOutput {
name: "outputs".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Int32,
ONNXDataType::Int64,
],
description: "Output values".to_string(),
}],
attributes: vec![
ONNXOperatorAttribute {
name: "then_branch".to_string(),
attribute_type: "string".to_string(),
required: true,
description: "Graph to run if condition is true".to_string(),
},
ONNXOperatorAttribute {
name: "else_branch".to_string(),
attribute_type: "string".to_string(),
required: true,
description: "Graph to run if condition is false".to_string(),
},
],
description: "Conditional execution".to_string(),
});
self.register_operator(ONNXOperatorInfo {
name: "Loop".to_string(),
opset_version: 13,
inputs: vec![
ONNXOperatorInput {
name: "M".to_string(),
types: vec![ONNXDataType::Int64],
optional: true,
description: "Maximum trip count".to_string(),
},
ONNXOperatorInput {
name: "cond".to_string(),
types: vec![ONNXDataType::Bool],
optional: true,
description: "Loop termination condition".to_string(),
},
],
outputs: vec![ONNXOperatorOutput {
name: "v_final".to_string(),
types: vec![
ONNXDataType::Float,
ONNXDataType::Double,
ONNXDataType::Int32,
ONNXDataType::Int64,
],
description: "Final loop carried values".to_string(),
}],
attributes: vec![ONNXOperatorAttribute {
name: "body".to_string(),
attribute_type: "string".to_string(),
required: true,
description: "Graph to execute in the loop".to_string(),
}],
description: "Loop execution".to_string(),
});
}
fn register_quantization_operators(&mut self) {
self.register_operator(ONNXOperatorInfo {
name: "QuantizeLinear".to_string(),
opset_version: 13,
inputs: vec![
ONNXOperatorInput {
name: "x".to_string(),
types: vec![ONNXDataType::Float, ONNXDataType::Int32],
optional: false,
description: "Input tensor".to_string(),
},
ONNXOperatorInput {
name: "y_scale".to_string(),
types: vec![ONNXDataType::Float],
optional: false,
description: "Scale for doing quantization".to_string(),
},
ONNXOperatorInput {
name: "y_zero_point".to_string(),
types: vec![ONNXDataType::UInt8, ONNXDataType::Int8],
optional: true,
description: "Zero point for quantization".to_string(),
},
],
outputs: vec![ONNXOperatorOutput {
name: "y".to_string(),
types: vec![ONNXDataType::UInt8, ONNXDataType::Int8],
description: "Quantized output tensor".to_string(),
}],
attributes: vec![],
description: "Linear quantization operator".to_string(),
});
self.register_operator(ONNXOperatorInfo {
name: "DequantizeLinear".to_string(),
opset_version: 13,
inputs: vec![
ONNXOperatorInput {
name: "x".to_string(),
types: vec![ONNXDataType::UInt8, ONNXDataType::Int8],
optional: false,
description: "Quantized input tensor".to_string(),
},
ONNXOperatorInput {
name: "x_scale".to_string(),
types: vec![ONNXDataType::Float],
optional: false,
description: "Scale for doing dequantization".to_string(),
},
ONNXOperatorInput {
name: "x_zero_point".to_string(),
types: vec![ONNXDataType::UInt8, ONNXDataType::Int8],
optional: true,
description: "Zero point for dequantization".to_string(),
},
],
outputs: vec![ONNXOperatorOutput {
name: "y".to_string(),
types: vec![ONNXDataType::Float],
description: "Dequantized output tensor".to_string(),
}],
attributes: vec![],
description: "Linear dequantization operator".to_string(),
});
}
fn register_operator(&mut self, operator_info: ONNXOperatorInfo) {
self.operators.insert(operator_info.name.clone(), operator_info);
}
pub fn get_operator_names(&self) -> Vec<String> {
self.operators.keys().cloned().collect()
}
pub fn has_operator(&self, name: &str) -> bool {
self.operators.contains_key(name)
}
pub fn get_operator(&self, name: &str) -> Option<&ONNXOperatorInfo> {
self.operators.get(name)
}
pub fn get_operators_by_opset(&self, opset_version: i64) -> Vec<&ONNXOperatorInfo> {
self.operators.values().filter(|op| op.opset_version <= opset_version).collect()
}
pub fn validate_operator_usage(
&self,
op_name: &str,
inputs: &[String],
attributes: &HashMap<String, ONNXAttribute>,
) -> Result<()> {
let op_info = self
.get_operator(op_name)
.ok_or_else(|| anyhow!("Unsupported operator: {}", op_name))?;
let required_inputs = op_info.inputs.iter().filter(|input| !input.optional).count();
if inputs.len() < required_inputs {
return Err(anyhow!(
"Operator {} requires at least {} inputs, got {}",
op_name,
required_inputs,
inputs.len()
));
}
for attr in &op_info.attributes {
if attr.required && !attributes.contains_key(&attr.name) {
return Err(anyhow!(
"Operator {} requires attribute '{}'",
op_name,
attr.name
));
}
}
Ok(())
}
}
impl ModelExporter for ONNXExporter {
fn export<M: Model>(&self, model: &M, config: &ExportConfig) -> Result<()> {
if config.format != ExportFormat::ONNX {
return Err(anyhow!("ONNXExporter only supports ONNX format"));
}
let onnx_model = self.create_onnx_model(model, config)?;
self.serialize_onnx_model(&onnx_model, &config.output_path)?;
println!("Model exported to {}.onnx", config.output_path);
Ok(())
}
fn supported_formats(&self) -> Vec<ExportFormat> {
vec![ExportFormat::ONNX]
}
fn validate_model<M: Model>(&self, _model: &M, format: ExportFormat) -> Result<()> {
if format != ExportFormat::ONNX {
return Err(anyhow!("ONNXExporter only supports ONNX format"));
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_onnx_exporter_creation() {
let exporter = ONNXExporter::new();
assert_eq!(exporter.opset_version, 14);
let exporter_v13 = exporter.with_opset_version(13);
assert_eq!(exporter_v13.opset_version, 13);
}
#[test]
fn test_onnx_data_types() {
assert_eq!(ONNXDataType::Float as i64, 1);
assert_eq!(ONNXDataType::Float16 as i64, 10);
assert_eq!(ONNXDataType::Int64 as i64, 7);
}
#[test]
fn test_supported_formats() {
let exporter = ONNXExporter::new();
let formats = exporter.supported_formats();
assert_eq!(formats.len(), 1);
assert_eq!(formats[0], ExportFormat::ONNX);
}
#[test]
fn test_onnx_dimension_types() {
let dim_value = ONNXDimension::Value(512);
let dim_param = ONNXDimension::Parameter("batch_size".to_string());
match dim_value {
ONNXDimension::Value(v) => assert_eq!(v, 512),
_ => panic!("Expected Value dimension"),
}
match dim_param {
ONNXDimension::Parameter(p) => assert_eq!(p, "batch_size"),
_ => panic!("Expected Parameter dimension"),
}
}
#[test]
fn test_onnx_attribute_types() {
let int_attr = ONNXAttribute::Int(42);
let float_attr = ONNXAttribute::Float(std::f32::consts::PI);
let string_attr = ONNXAttribute::String("test".to_string());
match int_attr {
ONNXAttribute::Int(v) => assert_eq!(v, 42),
_ => panic!("Expected Int attribute"),
}
match float_attr {
ONNXAttribute::Float(v) => assert!((v - std::f32::consts::PI).abs() < 1e-6),
_ => panic!("Expected Float attribute"),
}
match string_attr {
ONNXAttribute::String(s) => assert_eq!(s, "test"),
_ => panic!("Expected String attribute"),
}
}
}