use crate::{
error::{OnnxError, Result},
graph::{Graph, Node, TensorSpec},
model::{Model, ModelMetadata},
proto,
tensor::Tensor,
};
use prost::Message;
use std::collections::HashMap;
use std::fs;
use std::path::Path;
/// Converts a decoded ONNX `ModelProto` into the crate's [`Model`].
///
/// Metadata is assembled from the proto as follows: the model name comes from
/// the graph's name, the version from `model_version`, the description from
/// `doc_string`, the producer from `producer_name`, and the ONNX version is
/// rendered as `IR_VERSION_<ir_version>`. Absent string fields become empty
/// strings and an absent IR version is rendered as `0`.
///
/// # Errors
///
/// Returns a model-load error if the proto has no graph, and propagates any
/// error raised by [`from_graph_proto`].
pub fn from_model_proto(model_proto: &proto::ModelProto) -> Result<Model> {
    let graph_proto = match &model_proto.graph {
        Some(graph_proto) => graph_proto,
        None => return Err(OnnxError::model_load_error("Model proto missing graph")),
    };

    // Optional text fields collapse to an empty string when unset.
    let text_or_empty = |field: &Option<String>| field.clone().unwrap_or_default();

    let version = match model_proto.model_version {
        Some(number) => number.to_string(),
        None => String::new(),
    };
    let ir_version = model_proto.ir_version.unwrap_or(0);

    let metadata = ModelMetadata {
        name: text_or_empty(&graph_proto.name),
        version,
        description: text_or_empty(&model_proto.doc_string),
        producer: text_or_empty(&model_proto.producer_name),
        onnx_version: format!("IR_VERSION_{}", ir_version),
        domain: text_or_empty(&model_proto.domain),
    };

    from_graph_proto(graph_proto).map(|graph| Model::with_metadata(metadata, graph))
}
pub fn from_graph_proto(graph_proto: &proto::GraphProto) -> Result<Graph> {
let mut nodes = Vec::new();
for node_proto in &graph_proto.node {
nodes.push(from_node_proto(node_proto)?);
}
let mut inputs = Vec::new();
for input_proto in &graph_proto.input {
inputs.push(from_value_info_proto(input_proto)?);
}
let mut outputs = Vec::new();
for output_proto in &graph_proto.output {
outputs.push(from_value_info_proto(output_proto)?);
}
let mut initializers = HashMap::new();
for tensor_proto in &graph_proto.initializer {
let name = tensor_proto.name.clone().unwrap_or_default();
let tensor = from_tensor_proto(tensor_proto)?;
initializers.insert(name, tensor);
}
Ok(Graph {
name: graph_proto.name.clone().unwrap_or_default(),
nodes,
inputs,
outputs,
initializers,
})
}
/// Converts an ONNX `NodeProto` into the crate's [`Node`].
///
/// Empty input names are dropped from the node's input list; outputs are
/// copied unchanged. Every attribute is converted with
/// [`from_attribute_proto`] and stored under its name (empty when unset).
///
/// # Errors
///
/// Propagates any error returned by [`from_attribute_proto`].
pub fn from_node_proto(node_proto: &proto::NodeProto) -> Result<Node> {
    // Keep only the inputs that actually carry a name.
    let mut named_inputs: Vec<String> = Vec::new();
    for input in &node_proto.input {
        if !input.is_empty() {
            named_inputs.push(input.clone());
        }
    }

    let node_name = node_proto.name.clone().unwrap_or_default();
    let op_type = node_proto.op_type.clone().unwrap_or_default();
    let mut node = Node::new(node_name, op_type, named_inputs, node_proto.output.clone());

    for attr_proto in &node_proto.attribute {
        let value = from_attribute_proto(attr_proto)?;
        node.add_attribute(attr_proto.name.clone().unwrap_or_default(), value);
    }

    Ok(node)
}
pub fn from_value_info_proto(value_info_proto: &proto::ValueInfoProto) -> Result<TensorSpec> {
let name = value_info_proto.name.clone().unwrap_or_default();
let type_proto = value_info_proto
.r#type
.as_ref()
.ok_or_else(|| OnnxError::model_load_error("ValueInfo missing type information"))?;
let tensor_type = match &type_proto.value {
Some(proto::type_proto::Value::TensorType(tensor)) => tensor,
_ => {
return Err(OnnxError::model_load_error(
"Type proto missing tensor type",
))
}
};
let mut dimensions = Vec::new();
if let Some(shape_proto) = &tensor_type.shape {
for dim in &shape_proto.dim {
match &dim.value {
Some(proto::tensor_shape_proto::dimension::Value::DimValue(dim_value)) => {
dimensions.push(Some(*dim_value as usize));
}
Some(proto::tensor_shape_proto::dimension::Value::DimParam(_)) => {
dimensions.push(None);
}
None => {
dimensions.push(None);
}
}
}
}
let dtype = match proto::tensor_proto::DataType::try_from(tensor_type.elem_type.unwrap_or(0)) {
Ok(proto::tensor_proto::DataType::Float) => "float32",
Ok(proto::tensor_proto::DataType::Double) => "float64",
Ok(proto::tensor_proto::DataType::Int32) => "int32",
Ok(proto::tensor_proto::DataType::Int64) => "int64",
_ => {
return Err(OnnxError::unsupported_operation(format!(
"Unsupported data type: {}",
tensor_type.elem_type.unwrap_or(0)
)))
}
};
Ok(TensorSpec {
name,
dimensions,
dtype: dtype.to_string(),
})
}
pub fn from_tensor_proto(tensor_proto: &proto::TensorProto) -> Result<Tensor> {
let shape: Vec<usize> = tensor_proto.dims.iter().map(|&dim| dim as usize).collect();
let data_type = proto::tensor_proto::DataType::try_from(tensor_proto.data_type.unwrap_or(0))
.map_err(|_| {
OnnxError::unsupported_operation(format!(
"Unsupported tensor data type: {}",
tensor_proto.data_type.unwrap_or(0)
))
})?;
match data_type {
proto::tensor_proto::DataType::Float => {
let data = if !tensor_proto.float_data.is_empty() {
tensor_proto.float_data.clone()
} else if let Some(ref raw_data) = tensor_proto.raw_data {
if !raw_data.is_empty() {
if raw_data.len() % 4 != 0 {
return Err(OnnxError::model_load_error(
"Invalid raw data length for float32",
));
}
let mut floats = Vec::with_capacity(raw_data.len() / 4);
for chunk in raw_data.chunks_exact(4) {
let bytes = [chunk[0], chunk[1], chunk[2], chunk[3]];
floats.push(f32::from_le_bytes(bytes));
}
floats
} else {
return Err(OnnxError::model_load_error("Tensor missing data"));
}
} else {
return Err(OnnxError::model_load_error("Tensor missing data"));
};
Tensor::from_shape_vec(&shape, data)
}
proto::tensor_proto::DataType::Int64 => {
let data = if !tensor_proto.int64_data.is_empty() {
tensor_proto.int64_data.iter().map(|&x| x as f32).collect()
} else if let Some(ref raw_data) = tensor_proto.raw_data {
if !raw_data.is_empty() {
if raw_data.len() % 8 != 0 {
return Err(OnnxError::model_load_error(
"Invalid raw data length for int64",
));
}
let mut floats = Vec::with_capacity(raw_data.len() / 8);
for chunk in raw_data.chunks_exact(8) {
let bytes = [
chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
chunk[7],
];
let int_val = i64::from_le_bytes(bytes);
floats.push(int_val as f32);
}
floats
} else {
return Err(OnnxError::model_load_error("Tensor missing data"));
}
} else {
return Err(OnnxError::model_load_error("Tensor missing data"));
};
Tensor::from_shape_vec(&shape, data)
}
proto::tensor_proto::DataType::Int32 => {
let data = if !tensor_proto.int32_data.is_empty() {
tensor_proto.int32_data.iter().map(|&x| x as f32).collect()
} else if let Some(ref raw_data) = tensor_proto.raw_data {
if !raw_data.is_empty() {
if raw_data.len() % 4 != 0 {
return Err(OnnxError::model_load_error(
"Invalid raw data length for int32",
));
}
let mut floats = Vec::with_capacity(raw_data.len() / 4);
for chunk in raw_data.chunks_exact(4) {
let bytes = [chunk[0], chunk[1], chunk[2], chunk[3]];
let int_val = i32::from_le_bytes(bytes);
floats.push(int_val as f32);
}
floats
} else {
return Err(OnnxError::model_load_error("Tensor missing data"));
}
} else {
return Err(OnnxError::model_load_error("Tensor missing data"));
};
Tensor::from_shape_vec(&shape, data)
}
_ => Err(OnnxError::unsupported_operation(format!(
"Unsupported tensor data type: {data_type:?}"
))),
}
}
/// Renders an ONNX `AttributeProto` value as a string.
///
/// The first populated field wins, checked in this order: `s` (decoded as
/// lossy UTF-8), `i`, `f`, `ints`, `floats`. The list fields are rendered with
/// their `Debug` formatting. When none of these fields is populated the result
/// is an empty string.
///
/// This function currently never returns an error.
pub fn from_attribute_proto(attr_proto: &proto::AttributeProto) -> Result<String> {
    if let Some(bytes) = &attr_proto.s {
        return Ok(String::from_utf8_lossy(bytes).into_owned());
    }
    if let Some(int_value) = attr_proto.i {
        return Ok(int_value.to_string());
    }
    if let Some(float_value) = attr_proto.f {
        return Ok(float_value.to_string());
    }

    let rendered = if !attr_proto.ints.is_empty() {
        format!("{:?}", attr_proto.ints)
    } else if !attr_proto.floats.is_empty() {
        format!("{:?}", attr_proto.floats)
    } else {
        String::new()
    };
    Ok(rendered)
}
/// Reads an ONNX file from `path` and converts it into a [`Model`].
///
/// # Errors
///
/// Returns a model-load error naming the file when it cannot be read or when
/// its contents cannot be decoded as a `ModelProto`, and propagates any error
/// from [`from_model_proto`].
pub fn load_onnx_model<P: AsRef<Path>>(path: P) -> Result<Model> {
    let path = path.as_ref();

    let bytes = match fs::read(path) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(OnnxError::model_load_error(format!(
                "Failed to read ONNX file '{}': {}",
                path.display(),
                e
            )))
        }
    };

    let model_proto = match proto::ModelProto::decode(bytes.as_slice()) {
        Ok(decoded) => decoded,
        Err(e) => {
            return Err(OnnxError::model_load_error(format!(
                "Failed to parse ONNX file '{}': {}",
                path.display(),
                e
            )))
        }
    };

    from_model_proto(&model_proto)
}
/// Converts a [`Model`] into an ONNX `ModelProto`.
///
/// The IR version is always written as `7`. The metadata's producer, version,
/// domain and description are written to `producer_name`, `producer_version`,
/// `domain` and `doc_string` respectively. In addition, when the metadata
/// version parses as an integer it is written to `model_version`, which is the
/// field `from_model_proto` reads the version from, so a numeric version
/// survives a save/load round trip. Opset imports, metadata properties,
/// training info, functions and configuration are left empty.
///
/// # Errors
///
/// Propagates any error from [`to_graph_proto`].
pub fn to_model_proto(model: &Model) -> Result<proto::ModelProto> {
    let metadata = &model.metadata;
    let graph = to_graph_proto(&model.graph)?;

    Ok(proto::ModelProto {
        ir_version: Some(7i64),
        producer_name: Some(metadata.producer.clone()),
        producer_version: Some(metadata.version.clone()),
        domain: Some(metadata.domain.clone()),
        doc_string: Some(metadata.description.clone()),
        graph: Some(graph),
        // Non-numeric versions cannot be represented here and stay unset.
        model_version: metadata.version.parse::<i64>().ok(),
        opset_import: vec![],
        metadata_props: vec![],
        training_info: vec![],
        functions: vec![],
        configuration: vec![],
    })
}
pub fn to_graph_proto(graph: &Graph) -> Result<proto::GraphProto> {
let mut nodes = Vec::new();
for node in &graph.nodes {
nodes.push(to_node_proto(node)?);
}
let mut inputs = Vec::new();
for input in &graph.inputs {
inputs.push(to_value_info_proto(input)?);
}
let mut outputs = Vec::new();
for output in &graph.outputs {
outputs.push(to_value_info_proto(output)?);
}
let mut initializers = Vec::new();
for (name, tensor) in &graph.initializers {
initializers.push(to_tensor_proto(name, tensor)?);
}
let graph_proto = proto::GraphProto {
node: nodes,
name: Some(graph.name.clone()),
initializer: initializers,
sparse_initializer: vec![],
doc_string: None,
input: inputs,
output: outputs,
value_info: vec![],
quantization_annotation: vec![],
metadata_props: vec![],
};
Ok(graph_proto)
}
/// Converts a [`Node`] into an ONNX `NodeProto`.
///
/// Only the node's name, operator type, inputs and outputs are written.
/// The attribute list is always emitted empty, so node attributes are not
/// serialized; every remaining proto field is left unset or empty.
///
/// This function currently never returns an error.
pub fn to_node_proto(node: &Node) -> Result<proto::NodeProto> {
    let name = Some(node.name.clone());
    let op_type = Some(node.op_type.clone());

    Ok(proto::NodeProto {
        input: node.inputs.clone(),
        output: node.outputs.clone(),
        name,
        op_type,
        domain: None,
        attribute: vec![],
        doc_string: None,
        overload: None,
        metadata_props: vec![],
        device_configurations: vec![],
    })
}
/// Converts a [`TensorSpec`] into an ONNX `ValueInfoProto`.
///
/// The dtype strings `"float32"`, `"float64"`, `"int32"` and `"int64"` map to
/// the corresponding ONNX element types. A known dimension is written as a
/// `dim_value`; an unknown dimension is written as the symbolic `dim_param`
/// `"dynamic"`. A shape is always emitted, even when it has no dimensions.
///
/// # Errors
///
/// Returns an unsupported-operation error for any other dtype string.
pub fn to_value_info_proto(spec: &TensorSpec) -> Result<proto::ValueInfoProto> {
    let data_type = match spec.dtype.as_str() {
        "float32" => proto::tensor_proto::DataType::Float,
        "float64" => proto::tensor_proto::DataType::Double,
        "int32" => proto::tensor_proto::DataType::Int32,
        "int64" => proto::tensor_proto::DataType::Int64,
        other => {
            return Err(OnnxError::unsupported_operation(format!(
                "Unsupported data type: {}",
                other
            )))
        }
    };

    let dim: Vec<_> = spec
        .dimensions
        .iter()
        .map(|dim_opt| {
            let value = match dim_opt {
                Some(dim_size) => {
                    proto::tensor_shape_proto::dimension::Value::DimValue(*dim_size as i64)
                }
                None => proto::tensor_shape_proto::dimension::Value::DimParam(
                    "dynamic".to_string(),
                ),
            };
            proto::tensor_shape_proto::Dimension {
                value: Some(value),
                denotation: None,
            }
        })
        .collect();

    let tensor_type = proto::type_proto::Tensor {
        elem_type: Some(data_type as i32),
        shape: Some(proto::TensorShapeProto { dim }),
    };

    Ok(proto::ValueInfoProto {
        name: Some(spec.name.clone()),
        r#type: Some(proto::TypeProto {
            denotation: None,
            value: Some(proto::type_proto::Value::TensorType(tensor_type)),
        }),
        doc_string: None,
        metadata_props: vec![],
    })
}
/// Converts a [`Tensor`] into an ONNX `TensorProto` named `name`.
///
/// The tensor is always written with the `Float` data type: its values go
/// into `float_data` and its shape into `dims`. `raw_data` and every other
/// typed data field are left unset or empty.
///
/// This function currently never returns an error.
pub fn to_tensor_proto(name: &str, tensor: &Tensor) -> Result<proto::TensorProto> {
    let dims = tensor.shape().iter().map(|&extent| extent as i64).collect();
    let float_data = tensor.data().iter().copied().collect();
    let float_type = proto::tensor_proto::DataType::Float as i32;

    Ok(proto::TensorProto {
        dims,
        data_type: Some(float_type),
        segment: None,
        float_data,
        int32_data: vec![],
        string_data: vec![],
        int64_data: vec![],
        name: Some(name.to_string()),
        doc_string: None,
        raw_data: None,
        external_data: vec![],
        data_location: None,
        double_data: vec![],
        uint64_data: vec![],
        metadata_props: vec![],
    })
}
/// Serializes `model` as an ONNX protobuf and writes it to `path`.
///
/// # Errors
///
/// Propagates any error from [`to_model_proto`], and returns an error when
/// the proto cannot be encoded or the file cannot be written (the latter
/// names the offending path).
pub fn save_onnx_model<P: AsRef<Path>>(model: &Model, path: P) -> Result<()> {
    let path = path.as_ref();
    let model_proto = to_model_proto(model)?;

    // Size the buffer up front so encoding does not need to reallocate.
    let mut encoded = Vec::with_capacity(model_proto.encoded_len());
    if let Err(e) = model_proto.encode(&mut encoded) {
        return Err(OnnxError::other(format!("Failed to serialize model: {e}")));
    }

    match fs::write(path, encoded) {
        Ok(()) => Ok(()),
        Err(e) => Err(OnnxError::other(format!(
            "Failed to write ONNX file '{}': {}",
            path.display(),
            e
        ))),
    }
}
#[cfg(test)]
mod converter_tests {
    use super::*;
    use crate::graph::TensorSpec;
    use crate::{Graph, Model, Node};

    /// Builds a graph with one declared input, one declared output and a
    /// single node.
    ///
    /// The first entry of `node_inputs` is also declared as the graph input;
    /// `dims` is used for both the input and the output spec.
    fn single_node_graph(
        graph_name: &str,
        node_name: &str,
        op_type: &str,
        node_inputs: &[&str],
        output: &str,
        dims: &[Option<usize>],
    ) -> Graph {
        let mut graph = Graph::new(graph_name.to_string());
        graph.add_input(TensorSpec::new(node_inputs[0].to_string(), dims.to_vec()));
        graph.add_output(TensorSpec::new(output.to_string(), dims.to_vec()));
        graph.add_node(Node::new(
            node_name.to_string(),
            op_type.to_string(),
            node_inputs.iter().map(|name| name.to_string()).collect(),
            vec![output.to_string()],
        ));
        graph
    }

    /// Builds a `TensorProto` with the given name, data type and dims and no
    /// data in any of its data fields.
    fn empty_tensor_proto(
        name: &str,
        data_type: proto::tensor_proto::DataType,
        dims: Vec<i64>,
    ) -> proto::TensorProto {
        proto::TensorProto {
            dims,
            data_type: Some(data_type as i32),
            segment: None,
            float_data: vec![],
            int32_data: vec![],
            string_data: vec![],
            int64_data: vec![],
            name: Some(name.to_string()),
            doc_string: None,
            raw_data: None,
            external_data: vec![],
            data_location: None,
            double_data: vec![],
            uint64_data: vec![],
            metadata_props: vec![],
        }
    }

    /// Asserts that two nodes agree on name, operator type, inputs and outputs.
    fn assert_nodes_match(expected: &Node, actual: &Node) {
        assert_eq!(expected.name, actual.name);
        assert_eq!(expected.op_type, actual.op_type);
        assert_eq!(expected.inputs, actual.inputs);
        assert_eq!(expected.outputs, actual.outputs);
    }

    /// A model converted to a proto and back keeps its graph structure.
    #[test]
    fn test_converter_round_trip() {
        let graph = single_node_graph(
            "test_converter",
            "relu",
            "Relu",
            &["input"],
            "output",
            &[Some(1), Some(3)],
        );
        let original_model = Model::new(graph);

        let proto = to_model_proto(&original_model).expect("Failed to convert to proto");
        let converted_model = from_model_proto(&proto).expect("Failed to convert from proto");

        assert_eq!(original_model.graph.name, converted_model.graph.name);
        assert_eq!(
            original_model.graph.nodes.len(),
            converted_model.graph.nodes.len()
        );
        assert_eq!(
            original_model.graph.inputs.len(),
            converted_model.graph.inputs.len()
        );
        assert_eq!(
            original_model.graph.outputs.len(),
            converted_model.graph.outputs.len()
        );
        assert_nodes_match(
            &original_model.graph.nodes[0],
            &converted_model.graph.nodes[0],
        );
    }

    /// Tensor specs with static and dynamic dimensions survive a round trip.
    #[test]
    fn test_tensor_spec_conversion() {
        let spec = TensorSpec::new("test".to_string(), vec![Some(2), Some(4)]);
        let value_info = to_value_info_proto(&spec).expect("Failed to convert TensorSpec");
        let converted_spec = from_value_info_proto(&value_info).expect("Failed to convert back");
        assert_eq!(spec.name, converted_spec.name);
        assert_eq!(spec.dimensions, converted_spec.dimensions);
        assert_eq!(spec.dtype, converted_spec.dtype);

        let dynamic_spec = TensorSpec::new("dynamic".to_string(), vec![None, Some(3), None]);
        let dynamic_value_info =
            to_value_info_proto(&dynamic_spec).expect("Failed to convert dynamic TensorSpec");
        let converted_dynamic =
            from_value_info_proto(&dynamic_value_info).expect("Failed to convert back dynamic");
        assert_eq!(dynamic_spec.name, converted_dynamic.name);
        assert_eq!(dynamic_spec.dimensions, converted_dynamic.dimensions);
        assert_eq!(dynamic_spec.dtype, converted_dynamic.dtype);
    }

    /// A model saved to disk can be loaded back with the same structure.
    #[test]
    fn test_load_save_onnx_model() {
        use std::env;
        use std::fs;

        let graph = single_node_graph(
            "test_save_load",
            "sigmoid",
            "Sigmoid",
            &["x"],
            "y",
            &[Some(1), Some(2)],
        );
        let model = Model::new(graph);

        // Include the process id so concurrent test processes sharing the
        // same temp directory do not overwrite each other's file.
        let test_path = env::temp_dir().join(format!(
            "test_converter_{}.onnx",
            std::process::id()
        ));

        save_onnx_model(&model, &test_path).expect("Failed to save ONNX model");
        let load_result = load_onnx_model(&test_path);
        // Clean up before asserting so a failure does not leave the file behind.
        let _ = fs::remove_file(&test_path);
        let loaded_model = load_result.expect("Failed to load ONNX model");

        assert_eq!(model.graph.name, loaded_model.graph.name);
        assert_eq!(model.graph.nodes.len(), loaded_model.graph.nodes.len());
        assert_eq!(model.graph.inputs.len(), loaded_model.graph.inputs.len());
        assert_eq!(model.graph.outputs.len(), loaded_model.graph.outputs.len());
    }

    /// A value info without type information is rejected.
    #[test]
    fn test_invalid_tensor_type() {
        let value_info = proto::ValueInfoProto {
            name: Some("invalid".to_string()),
            r#type: None,
            doc_string: None,
            metadata_props: vec![],
        };
        let result = from_value_info_proto(&value_info);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("missing type information"));
    }

    /// An empty graph proto converts into an empty graph.
    #[test]
    fn test_protobuf_conversion_edge_cases() {
        let graph_proto = proto::GraphProto {
            node: vec![],
            name: Some("empty_graph".to_string()),
            initializer: vec![],
            sparse_initializer: vec![],
            doc_string: None,
            input: vec![],
            output: vec![],
            value_info: vec![],
            quantization_annotation: vec![],
            metadata_props: vec![],
        };
        let result = from_graph_proto(&graph_proto).unwrap();
        assert_eq!(result.name, "empty_graph");
        assert!(result.nodes.is_empty());
        assert!(result.inputs.is_empty());
        assert!(result.outputs.is_empty());
        assert!(result.initializers.is_empty());
    }

    /// A scalar tensor (no dims, one value) is converted correctly.
    #[test]
    fn test_tensor_proto_edge_cases() {
        let tensor_proto = proto::TensorProto {
            float_data: vec![42.0],
            ..empty_tensor_proto("scalar", proto::tensor_proto::DataType::Float, vec![])
        };
        let result = from_tensor_proto(&tensor_proto).unwrap();
        assert_eq!(result.shape(), &[] as &[usize]);
        assert_eq!(result.data().as_slice().unwrap(), &[42.0]);
    }

    /// Double tensors are rejected with an error naming the data type.
    #[test]
    fn test_unsupported_data_types() {
        let tensor_proto = proto::TensorProto {
            double_data: vec![1.0, 2.0],
            ..empty_tensor_proto(
                "double_tensor",
                proto::tensor_proto::DataType::Double,
                vec![2],
            )
        };
        let result = from_tensor_proto(&tensor_proto);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Double"));
    }

    /// Loading a file that does not exist reports a read error.
    #[test]
    fn test_file_io_errors() {
        let result = load_onnx_model("nonexistent_file.onnx");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Failed to read ONNX file"));
    }

    /// Initializers and node wiring survive a proto round trip.
    #[test]
    fn test_round_trip_with_initializers() {
        let mut graph = single_node_graph(
            "test_graph",
            "add_node",
            "Add",
            &["input", "weights"],
            "output",
            &[Some(2), Some(2)],
        );
        let weights = Tensor::from_shape_vec(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        graph.add_initializer("weights".to_string(), weights);
        let original_model = Model::new(graph);

        let proto = to_model_proto(&original_model).unwrap();
        let converted_model = from_model_proto(&proto).unwrap();

        assert_eq!(
            original_model.graph.initializers.len(),
            converted_model.graph.initializers.len()
        );
        assert!(converted_model.graph.initializers.contains_key("weights"));

        for (orig_node, conv_node) in original_model
            .graph
            .nodes
            .iter()
            .zip(converted_model.graph.nodes.iter())
        {
            assert_nodes_match(orig_node, conv_node);
            assert_eq!(conv_node.attributes.len(), 0);
        }
    }
}