use std::io::Write;
use std::path::Path;
use crate::ir::{DType, Graph, Op};
use crate::weight_loader::ModelWeights;
/// Errors that can be returned while exporting a graph to an ONNX file.
#[derive(Debug, thiserror::Error)]
pub enum OnnxExportError {
/// An I/O operation on the output file failed; wraps the underlying
/// `std::io::Error` (converted automatically through `#[from]`).
#[error("I/O error writing ONNX file: {0}")]
Io(#[from] std::io::Error),
}
/// Appends `v` to `buf` as a protobuf base-128 varint.
///
/// The value is emitted seven bits at a time, least-significant group first.
/// Every byte except the last has its high bit set to signal continuation.
fn encode_varint(buf: &mut Vec<u8>, mut v: u64) {
    // While more than seven significant bits remain, emit a continuation byte.
    while v >= 0x80 {
        buf.push((v as u8 & 0x7F) | 0x80);
        v >>= 7;
    }
    // The final group fits in seven bits, so the continuation bit stays clear.
    buf.push(v as u8);
}
/// Appends a protobuf field key to `buf`.
///
/// The key is `(field << 3) | wire_type`, written as a varint.
fn encode_field_tag(buf: &mut Vec<u8>, field: u32, wire_type: u8) {
    let key = (u64::from(field) << 3) | u64::from(wire_type);
    encode_varint(buf, key);
}
/// Appends `s` to `buf` as a length-delimited field with number `field`.
///
/// The string is written as its UTF-8 bytes via [`encode_bytes`].
fn encode_string(buf: &mut Vec<u8>, field: u32, s: &str) {
    let utf8 = s.as_bytes();
    encode_bytes(buf, field, utf8);
}
/// Appends `v` to `buf` as a varint field with number `field`.
///
/// The value is reinterpreted as `u64` before encoding, so negative inputs
/// are written in their two's-complement form.
fn encode_int64(buf: &mut Vec<u8>, field: u32, v: i64) {
    /// Wire type used for varint-encoded scalars.
    const WIRE_VARINT: u8 = 0;

    encode_field_tag(buf, field, WIRE_VARINT);
    encode_varint(buf, v as u64);
}
/// Appends `data` to `buf` as a length-delimited field with number `field`.
///
/// Layout: field key (wire type 2), payload length as a varint, then the
/// payload bytes verbatim.
fn encode_bytes(buf: &mut Vec<u8>, field: u32, data: &[u8]) {
    /// Wire type used for length-delimited payloads.
    const WIRE_LEN_DELIMITED: u8 = 2;

    encode_field_tag(buf, field, WIRE_LEN_DELIMITED);
    encode_varint(buf, data.len() as u64);
    buf.extend(data);
}
/// Appends an already-serialized nested message to `buf` under `field`.
///
/// Embedded messages share the length-delimited encoding, so this simply
/// forwards to [`encode_bytes`]; the separate name documents intent at call
/// sites.
fn encode_message(buf: &mut Vec<u8>, field: u32, data: &[u8]) {
    let serialized = data;
    encode_bytes(buf, field, serialized);
}
/// Maps an IR operation to the ONNX operator that represents it.
///
/// Returns `Some((op_type, domain))` for operations that become a graph node.
/// An empty `domain` means the default ONNX operator set; `"com.microsoft"`
/// and `"forge-llm"` select the additional operator sets.
///
/// Returns `None` for `Op::LoadWeight` and `Op::Input`, which do not produce
/// a node.
fn op_to_onnx(op: &Op) -> Option<(&'static str, &'static str)> {
    match op {
        Op::MatMul | Op::BatchMatMul => Some(("MatMul", "")),
        Op::Add | Op::Residual => Some(("Add", "")),
        Op::Mul => Some(("Mul", "")),
        Op::Softmax => Some(("Softmax", "")),
        Op::ReLU => Some(("Relu", "")),
        // The default ONNX domain only defines `Gelu` from opset 20 onwards,
        // while the exported model imports default opset 17. Use the ONNX
        // Runtime contrib `Gelu` so the node resolves against an imported
        // operator set.
        Op::GeLU => Some(("Gelu", "com.microsoft")),
        Op::SiLU => Some(("SiLU", "forge-llm")),
        Op::RMSNorm { .. } => Some(("SimplifiedLayerNormalization", "com.microsoft")),
        Op::LayerNorm { .. } => Some(("LayerNormalization", "")),
        Op::RoPE { .. } => Some(("RotaryEmbedding", "com.microsoft")),
        Op::Attention { .. } => Some(("Attention", "com.microsoft")),
        Op::Embedding { .. } => Some(("Gather", "")),
        Op::LogitsProjection { .. } => Some(("MatMul", "")),
        Op::Reshape { .. } => Some(("Reshape", "")),
        Op::Transpose { .. } => Some(("Transpose", "")),
        Op::Contiguous => Some(("Contiguous", "forge-llm")),
        Op::Cast { .. } => Some(("Cast", "")),
        Op::LoadWeight { .. } | Op::Input { .. } => None,
    }
}
// Element-type codes written into the `data_type` / `elem_type` fields of the
// exported tensors. The numeric values follow ONNX `TensorProto.DataType`.

/// 32-bit IEEE float.
const ONNX_FLOAT: i32 = 1;
/// 32-bit signed integer.
const ONNX_INT32: i32 = 6;
/// 64-bit signed integer.
const ONNX_INT64: i32 = 7;
/// 16-bit IEEE float.
const ONNX_FLOAT16: i32 = 10;
/// Returns the ONNX element-type code used when exporting tensors of `dtype`.
///
/// Only four ONNX types are emitted: integer types map one-to-one, the
/// reduced-precision float formats are all reported as `FLOAT16`, and `F32`
/// together with the quantized formats is reported as `FLOAT`.
///
/// NOTE(review): ONNX has a dedicated bfloat16 element type; reporting `BF16`
/// (and the FP8 formats) as `FLOAT16` is presumably a deliberate fallback.
/// Confirm with the consumers of the exported file.
fn dtype_to_onnx(dtype: DType) -> i32 {
    match dtype {
        DType::I32 => ONNX_INT32,
        DType::I64 => ONNX_INT64,
        DType::F16 | DType::BF16 | DType::F8E4M3 | DType::F8E5M2 => ONNX_FLOAT16,
        DType::F32 | DType::Q8_0 | DType::Q4_0 | DType::Q4_1 | DType::Q2 | DType::NF4 => {
            ONNX_FLOAT
        }
    }
}
/// Converts an `f32` to the bit pattern of the nearest IEEE 754 binary16
/// value (round-to-nearest, ties-to-even).
///
/// Values too large for binary16 become infinity, values too small become a
/// signed zero, and NaN stays NaN.
fn f32_to_f16_bits(value: f32) -> u16 {
    let bits = value.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exp = ((bits >> 23) & 0xFF) as i32;
    let mantissa = bits & 0x007F_FFFF;

    if exp == 0xFF {
        // Infinity or NaN. Force a mantissa bit for NaN so that truncating
        // the payload can never turn it into infinity.
        let nan_bit: u16 = if mantissa != 0 { 0x0200 } else { 0 };
        return sign | 0x7C00 | nan_bit | (mantissa >> 13) as u16;
    }

    // Re-bias the exponent from binary32 (127) to binary16 (15).
    let half_exp = exp - 127 + 15;
    if half_exp >= 0x1F {
        return sign | 0x7C00;
    }

    if half_exp <= 0 {
        // Result is subnormal in binary16 (or rounds to zero).
        if half_exp < -10 {
            return sign;
        }
        let full = mantissa | 0x0080_0000; // restore the implicit leading 1
        let shift = (14 - half_exp) as u32; // 14..=24
        let kept = full >> shift;
        let dropped = full & ((1u32 << shift) - 1);
        let halfway = 1u32 << (shift - 1);
        let round_up = dropped > halfway || (dropped == halfway && kept & 1 == 1);
        return sign | (kept as u16 + round_up as u16);
    }

    let kept = mantissa >> 13;
    let dropped = mantissa & 0x1FFF;
    let base = ((half_exp as u32) << 10) | kept;
    let round_up = dropped > 0x1000 || (dropped == 0x1000 && kept & 1 == 1);
    // A carry out of the mantissa correctly bumps the exponent (up to infinity).
    sign | (base + round_up as u32) as u16
}

/// Serializes a tensor in the ONNX `TensorProto` layout.
///
/// Writes one `dims` entry (field 1) per element of `shape`, the element type
/// derived from `dtype` (field 2), the tensor `name` (field 8) and, when
/// `data` is present, the values as little-endian `raw_data` (field 9).
///
/// The raw bytes are encoded with the width of the *declared* element type so
/// that the payload length always equals `element_count * element_size`:
/// `FLOAT16` tensors are converted to binary16, `INT32`/`INT64` tensors use
/// Rust's saturating float-to-int cast, and everything else is written as
/// `f32`.
fn build_tensor_proto(name: &str, shape: &[usize], dtype: DType, data: Option<&[f32]>) -> Vec<u8> {
    let onnx_type = dtype_to_onnx(dtype);

    let mut buf = Vec::new();
    for &dim in shape {
        encode_int64(&mut buf, 1, dim as i64);
    }
    encode_field_tag(&mut buf, 2, 0);
    encode_varint(&mut buf, onnx_type as u64);
    encode_string(&mut buf, 8, name);

    if let Some(values) = data {
        let raw: Vec<u8> = match onnx_type {
            ONNX_FLOAT16 => values
                .iter()
                .flat_map(|&v| f32_to_f16_bits(v).to_le_bytes())
                .collect(),
            ONNX_INT32 => values
                .iter()
                .flat_map(|&v| (v as i32).to_le_bytes())
                .collect(),
            ONNX_INT64 => values
                .iter()
                .flat_map(|&v| (v as i64).to_le_bytes())
                .collect(),
            _ => values.iter().flat_map(|&v| v.to_le_bytes()).collect(),
        };
        encode_bytes(&mut buf, 9, &raw);
    }
    buf
}
/// Serializes a `ValueInfoProto` describing a tensor called `name` whose
/// element type is derived from `dtype`.
///
/// Only the element type is recorded; no shape information is written.
fn build_value_info(name: &str, dtype: DType) -> Vec<u8> {
    // Innermost message: the tensor type, carrying only the element type
    // (field 1, varint).
    let mut tensor_type = Vec::new();
    encode_field_tag(&mut tensor_type, 1, 0);
    encode_varint(&mut tensor_type, dtype_to_onnx(dtype) as u64);

    // Wrap it in a type message (field 1 = tensor type).
    let mut type_proto = Vec::new();
    encode_message(&mut type_proto, 1, &tensor_type);

    // Outer message: name (field 1) followed by the type (field 2).
    let mut value_info = Vec::new();
    encode_string(&mut value_info, 1, name);
    encode_message(&mut value_info, 2, &type_proto);
    value_info
}
/// Serializes a single graph node as a `NodeProto`.
///
/// Fields are written in this order: every input name (field 1), every
/// output name (field 2), the node name (field 3), the operator type
/// (field 4) and, only when `domain` is non-empty, the operator domain
/// (field 7).
fn build_node_proto(
    node_name: &str,
    op_type: &str,
    domain: &str,
    inputs: &[String],
    outputs: &[String],
) -> Vec<u8> {
    let mut node = Vec::new();

    inputs
        .iter()
        .for_each(|input| encode_string(&mut node, 1, input));
    outputs
        .iter()
        .for_each(|output| encode_string(&mut node, 2, output));

    for (field, text) in [(3, node_name), (4, op_type)] {
        encode_string(&mut node, field, text);
    }

    // An empty domain means the default operator set and is left out.
    if !domain.is_empty() {
        encode_string(&mut node, 7, domain);
    }
    node
}
/// Serializes `graph` (and any matching tensors from `weights`) as an ONNX
/// `GraphProto`.
///
/// The output contains, in order:
/// 1. the graph name (field 2),
/// 2. one initializer (field 5) per entry of `graph.weights`, with data
///    attached when `weights` has a tensor of the same name,
/// 3. one graph input (field 11) per `Op::Input` node,
/// 4. one graph output (field 12) for the last node, if the graph has nodes,
/// 5. one node (field 1) per IR node that `op_to_onnx` maps to an operator.
///
/// # Panics
///
/// Panics if a node lists an input id that is not a valid index into
/// `graph.nodes`.
fn build_graph_proto(graph: &Graph, weights: &ModelWeights) -> Vec<u8> {
    let mut buf = Vec::new();

    // `GraphProto.name` is field 2. Field 3 is reserved in the ONNX schema,
    // so a name written there is ignored by readers and the graph ends up
    // unnamed.
    encode_string(&mut buf, 2, &graph.name);

    // Initializers: every declared weight, with raw data when available.
    for (weight_name, tensor_info) in &graph.weights {
        let data = weights.get(weight_name);
        let tp = build_tensor_proto(
            weight_name,
            &tensor_info.shape.0.to_vec(),
            tensor_info.dtype,
            data,
        );
        encode_message(&mut buf, 5, &tp);
    }

    // Graph inputs: one value-info entry per `Input` node.
    for node in &graph.nodes {
        if let Op::Input { name } = &node.op {
            let vi = build_value_info(name, node.output.dtype);
            encode_message(&mut buf, 11, &vi);
        }
    }

    // Graph output: the output of the final node in the list.
    if let Some(last_node) = graph.nodes.last() {
        let vi = build_value_info(&last_node.output.name, last_node.output.dtype);
        encode_message(&mut buf, 12, &vi);
    }

    // Nodes: skip IR ops that have no ONNX operator (weights and inputs).
    for node in &graph.nodes {
        let Some((op_type, domain)) = op_to_onnx(&node.op) else {
            continue;
        };
        // Inputs are referenced by the output name of the producing node.
        let input_names: Vec<String> = node
            .inputs
            .iter()
            .map(|&id| graph.nodes[id].output.name.clone())
            .collect();
        let output_names = vec![node.output.name.clone()];
        // The node is named after its (single) output tensor.
        let np = build_node_proto(
            &node.output.name,
            op_type,
            domain,
            &input_names,
            &output_names,
        );
        encode_message(&mut buf, 1, &np);
    }
    buf
}
fn build_opset_import(domain: &str, version: i64) -> Vec<u8> {
let mut buf = Vec::new();
encode_string(&mut buf, 1, domain);
encode_int64(&mut buf, 2, version);
buf
}
pub fn export_onnx(
graph: &Graph,
weights: &ModelWeights,
output_path: &Path,
) -> Result<(), OnnxExportError> {
let model_bytes = build_model_proto(graph, weights);
let mut file = std::fs::File::create(output_path)?;
file.write_all(&model_bytes)?;
Ok(())
}
/// Serializes the top-level ONNX `ModelProto` for `graph` and `weights`.
///
/// Fields are written in this order: IR version `8` (field 1), three
/// operator-set imports (field 8), the producer name (field 2), the producer
/// version taken from the crate version (field 3), and the graph (field 7).
pub(crate) fn build_model_proto(graph: &Graph, weights: &ModelWeights) -> Vec<u8> {
    /// Operator sets the exported nodes may refer to, as `(domain, version)`.
    const OPSET_IMPORTS: [(&str, i64); 3] = [("", 17), ("com.microsoft", 1), ("forge-llm", 1)];

    let mut model = Vec::new();
    encode_int64(&mut model, 1, 8);

    for (domain, version) in OPSET_IMPORTS {
        encode_message(&mut model, 8, &build_opset_import(domain, version));
    }

    encode_string(&mut model, 2, "forge-llm");
    encode_string(&mut model, 3, env!("CARGO_PKG_VERSION"));

    encode_message(&mut model, 7, &build_graph_proto(graph, weights));
    model
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::graph_builder::build_graph;
    use crate::ir::{Architecture, DType, HiddenActivation, ModelConfig};

    /// A deliberately small single-layer configuration used by every test
    /// that needs a real graph.
    fn tiny_config() -> ModelConfig {
        ModelConfig {
            architecture: Architecture::Llama,
            hidden_size: 64,
            intermediate_size: 128,
            num_layers: 1,
            num_attention_heads: 4,
            num_kv_heads: 2,
            head_dim: 16,
            vocab_size: 256,
            max_seq_len: 64,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            dtype: DType::F16,
            sliding_window_size: None,
            qkv_bias: false,
            hidden_activation: HiddenActivation::SiLU,
        }
    }

    /// A weight store that contains no tensors at all.
    fn empty_weights() -> ModelWeights {
        ModelWeights {
            tensors: HashMap::new(),
        }
    }

    /// Exporting must produce a non-empty file at the requested path.
    #[test]
    fn export_creates_file() {
        let dir = tempfile::tempdir().unwrap();
        let graph = build_graph(&tiny_config()).unwrap();
        let weights = empty_weights();
        let out = dir.path().join("model.onnx");

        export_onnx(&graph, &weights, &out).unwrap();

        assert!(out.exists(), "ONNX file should exist");
        let written = out.metadata().unwrap().len();
        assert!(written > 0, "ONNX file should not be empty");
    }

    /// The serialized model must start with the key of field 1 (varint).
    #[test]
    fn export_writes_valid_onnx_header() {
        let graph = build_graph(&tiny_config()).unwrap();
        let bytes = build_model_proto(&graph, &empty_weights());

        assert!(!bytes.is_empty(), "serialized model should not be empty");
        assert_eq!(bytes[0], 0x08, "first byte should be ir_version field tag");
    }

    /// Spot-checks the varint encoder around the one-byte/two-byte boundary.
    #[test]
    fn varint_encoding() {
        let cases: [(u64, &[u8]); 5] = [
            (0, &[0x00]),
            (1, &[0x01]),
            (127, &[0x7F]),
            (128, &[0x80, 0x01]),
            (300, &[0xAC, 0x02]),
        ];

        let mut buf = Vec::new();
        for (value, expected) in cases {
            buf.clear();
            encode_varint(&mut buf, value);
            assert_eq!(buf, expected);
        }
    }

    /// A string field is key, length, then the UTF-8 bytes.
    #[test]
    fn string_field_encoding() {
        let mut encoded = Vec::new();
        encode_string(&mut encoded, 2, "hi");
        assert_eq!(encoded, &[0x12, 0x02, 0x68, 0x69]);
    }

    /// A model exported with tensor data must be larger than the same model
    /// exported without any.
    #[test]
    fn export_with_weights() {
        let dir = tempfile::tempdir().unwrap();
        let graph = build_graph(&tiny_config()).unwrap();

        // Provide an all-zero tensor of the right size for every weight.
        let mut tensors = HashMap::new();
        for (name, info) in &graph.weights {
            let numel: usize = info.shape.0.iter().product();
            tensors.insert(name.clone(), vec![0.0f32; numel]);
        }
        let weights = ModelWeights { tensors };

        let out = dir.path().join("model_with_weights.onnx");
        export_onnx(&graph, &weights, &out).unwrap();
        assert!(out.exists());
        let weights_size = out.metadata().unwrap().len();

        let empty_out = dir.path().join("model_empty.onnx");
        export_onnx(&graph, &empty_weights(), &empty_out).unwrap();
        let empty_size = empty_out.metadata().unwrap().len();

        assert!(
            weights_size > empty_size,
            "file with weights ({weights_size} bytes) should be larger than without ({empty_size} bytes)"
        );
    }

    /// Every listed compute op must map to an ONNX operator, while weight
    /// loads and graph inputs must not.
    #[test]
    fn op_mapping_completeness() {
        let mapped = [
            Op::MatMul,
            Op::BatchMatMul,
            Op::Add,
            Op::Mul,
            Op::SiLU,
            Op::GeLU,
            Op::ReLU,
            Op::RMSNorm { eps: 1e-5 },
            Op::LayerNorm { eps: 1e-5 },
            Op::RoPE {
                max_seq_len: 64,
                rope_theta: 10000.0,
                head_dim: 16,
            },
            Op::Attention {
                num_heads: 4,
                num_kv_heads: 2,
                head_dim: 16,
            },
            Op::Softmax,
            Op::Embedding {
                vocab_size: 256,
                embed_dim: 64,
            },
            Op::LogitsProjection { vocab_size: 256 },
            Op::Residual,
        ];
        for op in &mapped {
            assert!(op_to_onnx(op).is_some(), "op {op} should map to an ONNX op");
        }

        let unmapped = [
            Op::LoadWeight { name: "w".into() },
            Op::Input { name: "x".into() },
        ];
        for op in &unmapped {
            assert!(op_to_onnx(op).is_none());
        }
    }
}