use std::collections::HashMap;
use prost::Message;
use crate::nn::Module;
use crate::onnx::proto::onnx as pb;
use crate::onnx::tracer::{self, TracedAttribute, TracedGraph, TracedNode, ValueInfo};
use crate::tensor::{DType, Tensor};
/// Errors produced while exporting a traced model to an ONNX file.
#[derive(Debug)]
pub enum OnnxError {
/// Underlying I/O failure (e.g. writing the serialized model to disk).
Io(std::io::Error),
/// Tensor dtype with no ONNX mapping (constructed by callers outside this file).
UnsupportedDType(DType),
/// Inconsistent or invalid tensor shape encountered during export.
ShapeError(String),
}
impl From<std::io::Error> for OnnxError {
fn from(e: std::io::Error) -> Self {
OnnxError::Io(e)
}
}
impl std::fmt::Display for OnnxError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
OnnxError::Io(e) => write!(f, "IO error: {}", e),
OnnxError::UnsupportedDType(dt) => write!(f, "Unsupported dtype: {:?}", dt),
OnnxError::ShapeError(msg) => write!(f, "Shape error: {}", msg),
}
}
}
// Marker impl: the `Debug` derive and `Display` impl above satisfy the trait's defaults.
impl std::error::Error for OnnxError {}
/// ONNX `TensorProto.DataType` code for 32-bit floats.
const ONNX_FLOAT: i32 = 1;
/// ONNX `TensorProto.DataType` code for 16-bit floats.
const ONNX_FLOAT16: i32 = 10;

/// Maps an internal dtype to its ONNX element-type code.
/// Q8 shares the FLOAT code (its payload is exported as float data).
fn dtype_to_onnx(dtype: DType) -> i32 {
    match dtype {
        DType::F16 => ONNX_FLOAT16,
        DType::F32 | DType::Q8 { .. } => ONNX_FLOAT,
    }
}
/// Serializes a tensor into an ONNX `TensorProto`, payload in `raw_data`.
///
/// `shape_override`, when present, replaces the tensor's own shape in the
/// emitted `dims` (the traced graph may record a different logical shape
/// for an initializer).
fn tensor_to_proto(name: &str, tensor: &Tensor, shape_override: Option<&[usize]>) -> pb::TensorProto {
    let shape = shape_override.unwrap_or(tensor.shape());
    match tensor.dtype() {
        DType::F16 => {
            // GPU builds fetch the bytes from device memory; CPU builds can
            // cast the host buffer directly.
            #[cfg(feature = "gpu")]
            let raw_bytes = tensor.storage.download_raw_bytes();
            #[cfg(not(feature = "gpu"))]
            let raw_bytes = {
                let guard = tensor.storage.data();
                bytemuck::cast_slice(&*guard).to_vec()
            };
            pb::TensorProto {
                dims: shape.iter().map(|&d| d as i64).collect(),
                data_type: ONNX_FLOAT16,
                name: name.to_string(),
                raw_data: raw_bytes,
                ..Default::default()
            }
        }
        // Exhaustive match (DType has exactly F32/F16/Q8) instead of `_`:
        // a new variant now forces a decision here rather than silently
        // being exported as FLOAT.
        DType::F32 | DType::Q8 { .. } => {
            // NOTE(review): for Q8 this copies the storage bytes while
            // declaring FLOAT — confirm Q8 storage holds f32 data by this
            // point (e.g. already dequantized), otherwise consumers will
            // misread the payload.
            let guard = tensor.storage.data();
            let raw_bytes: Vec<u8> = bytemuck::cast_slice(&*guard).to_vec();
            drop(guard);
            pb::TensorProto {
                dims: shape.iter().map(|&d| d as i64).collect(),
                data_type: ONNX_FLOAT,
                name: name.to_string(),
                raw_data: raw_bytes,
                ..Default::default()
            }
        }
    }
}
/// Builds a `ValueInfoProto` (name + tensor element type + static shape).
fn value_info_to_proto(name: &str, info: &ValueInfo) -> pb::ValueInfoProto {
    // All dims are concrete values; no symbolic/dynamic dimensions are emitted.
    let dims: Vec<pb::tensor_shape_proto::Dimension> = info
        .shape
        .iter()
        .map(|&d| pb::tensor_shape_proto::Dimension {
            denotation: String::new(),
            value: Some(pb::tensor_shape_proto::dimension::Value::DimValue(d as i64)),
        })
        .collect();
    let tensor_type = pb::type_proto::Tensor {
        elem_type: dtype_to_onnx(info.dtype),
        shape: Some(pb::TensorShapeProto { dim: dims }),
    };
    pb::ValueInfoProto {
        name: name.to_string(),
        r#type: Some(pb::TypeProto {
            denotation: String::new(),
            value: Some(pb::type_proto::Value::TensorType(tensor_type)),
        }),
        doc_string: String::new(),
    }
}
/// Converts a traced attribute into an ONNX `AttributeProto`.
///
/// The `r#type` codes follow the ONNX `AttributeProto.AttributeType` enum:
/// FLOAT=1, INT=2, STRING=3, FLOATS=6, INTS=7.
fn attr_to_proto(attr: &TracedAttribute) -> pb::AttributeProto {
    let mut proto = pb::AttributeProto::default();
    match attr {
        TracedAttribute::Int(name, val) => {
            proto.name = name.clone();
            proto.r#type = 2;
            proto.i = *val;
        }
        TracedAttribute::Float(name, val) => {
            proto.name = name.clone();
            proto.r#type = 1;
            proto.f = *val;
        }
        TracedAttribute::Ints(name, vals) => {
            proto.name = name.clone();
            proto.r#type = 7;
            proto.ints = vals.clone();
        }
        TracedAttribute::Floats(name, vals) => {
            proto.name = name.clone();
            proto.r#type = 6;
            proto.floats = vals.clone();
        }
        TracedAttribute::String(name, val) => {
            proto.name = name.clone();
            proto.r#type = 3;
            proto.s = val.clone();
        }
    }
    proto
}
/// Builds a `NodeProto`; node names are made unique by suffixing the
/// node's position in the traced graph.
fn node_to_proto(node: &TracedNode, idx: usize) -> pb::NodeProto {
    let attribute: Vec<pb::AttributeProto> =
        node.attributes.iter().map(attr_to_proto).collect();
    pb::NodeProto {
        input: node.inputs.clone(),
        output: node.outputs.clone(),
        name: format!("{}_{}", node.op_type, idx),
        op_type: node.op_type.clone(),
        domain: String::new(), // default ONNX operator domain
        attribute,
        doc_string: String::new(),
    }
}
/// Assembles the `GraphProto`: nodes, inputs/outputs, and weight initializers.
fn graph_to_proto(graph: &TracedGraph, value_shapes: &HashMap<String, ValueInfo>) -> pb::GraphProto {
    pb::GraphProto {
        node: graph
            .nodes
            .iter()
            .enumerate()
            .map(|(i, n)| node_to_proto(n, i))
            .collect(),
        name: "rumus_graph".to_string(),
        initializer: graph
            .initializers
            .iter()
            .map(|(name, tensor)| {
                // Prefer the logical shape recorded during tracing, if any.
                let shape_override = value_shapes.get(name).map(|vi| vi.shape.as_slice());
                tensor_to_proto(name, tensor, shape_override)
            })
            .collect(),
        doc_string: String::new(),
        input: graph
            .inputs
            .iter()
            .map(|(name, info)| value_info_to_proto(name, info))
            .collect(),
        output: graph
            .outputs
            .iter()
            .map(|(name, info)| value_info_to_proto(name, info))
            .collect(),
    }
}
/// Wraps the traced graph in a top-level `ModelProto` with producer metadata.
fn build_model_proto(graph: &TracedGraph, value_shapes: &HashMap<String, ValueInfo>, opset: u32) -> pb::ModelProto {
    // Single opset import for the default (empty) ONNX operator domain.
    let opset_import = vec![pb::OperatorSetIdProto {
        domain: String::new(),
        version: i64::from(opset),
    }];
    pb::ModelProto {
        ir_version: 9,
        opset_import,
        producer_name: "rumus".to_string(),
        producer_version: env!("CARGO_PKG_VERSION").to_string(),
        domain: String::new(),
        model_version: 1,
        doc_string: String::new(),
        graph: Some(graph_to_proto(graph, value_shapes)),
    }
}
/// Exports `model` to an ONNX file at `path` using the default opset.
///
/// `input_specs` gives `(name, shape, dtype)` for each graph input;
/// `forward_fn` runs the model once so the tracer can record the graph.
///
/// # Errors
/// Returns [`OnnxError::Io`] if writing the file fails.
pub fn export_onnx<M, F>(
    model: &M,
    input_specs: &[(&str, Vec<usize>, DType)],
    path: &str,
    forward_fn: F,
) -> Result<(), OnnxError>
where
    M: Module,
    F: FnOnce(&[Tensor]) -> Tensor,
{
    const DEFAULT_OPSET: u32 = 17;
    export_onnx_with_opset(model, input_specs, path, forward_fn, DEFAULT_OPSET)
}
pub fn export_onnx_with_opset<M, F>(
model: &M,
input_specs: &[(&str, Vec<usize>, DType)],
path: &str,
forward_fn: F,
opset: u32,
) -> Result<(), OnnxError>
where
M: Module,
F: FnOnce(&[Tensor]) -> Tensor,
{
let state_dict = model.state_dict("");
let graph = tracer::trace(&state_dict, input_specs, forward_fn);
let mut value_shapes = HashMap::new();
for (name, tensor) in &graph.initializers {
value_shapes.insert(
name.clone(),
ValueInfo {
shape: tensor.shape().to_vec(),
dtype: tensor.dtype(),
},
);
}
let model_proto = build_model_proto(&graph, &value_shapes, opset);
let bytes = model_proto.encode_to_vec();
std::fs::write(path, bytes)?;
Ok(())
}