use crate::error::IoError;
/// A decoded protobuf field payload, tagged by the wire type it was read with.
///
/// Only the four wire types handled by `decode_field` are represented.
/// Values compare by content, so decoded fields can be checked with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WireType {
    /// Wire type 0: a base-128 varint widened to 64 bits.
    Varint(u64),
    /// Wire type 2: the bytes that followed a varint length prefix.
    LengthDelimited(Vec<u8>),
    /// Wire type 5: four bytes read as a little-endian `u32`.
    Fixed32(u32),
    /// Wire type 1: eight bytes read as a little-endian `u64`.
    Fixed64(u64),
}
/// Decodes one base-128 varint from `data`, starting at `*pos`.
///
/// Each byte contributes its low seven bits, least-significant group first; a
/// clear high bit marks the final byte. `*pos` is advanced past every byte that
/// was read, including on the overflow error paths.
///
/// # Errors
///
/// Returns [`IoError::ParseError`] when
/// * `data` ends before a terminating byte is seen, or
/// * the value does not fit in 64 bits: the encoding runs past ten bytes, or
///   the tenth byte carries payload bits that would land above bit 63.
pub fn decode_varint(data: &[u8], pos: &mut usize) -> Result<u64, IoError> {
    // A u64 spans at most ten 7-bit groups; the tenth one (shift 63) has room
    // for a single payload bit only.
    const LAST_GROUP_SHIFT: u32 = 63;

    let mut result: u64 = 0;
    let mut shift = 0u32;
    loop {
        let byte = *data.get(*pos).ok_or_else(|| {
            IoError::ParseError("protobuf: unexpected end of data in varint".to_string())
        })?;
        *pos += 1;

        let payload = u64::from(byte & 0x7F);
        // Without this check the shift below would silently drop the excess
        // bits and hand back a wrong value instead of reporting the overflow.
        if shift == LAST_GROUP_SHIFT && payload > 1 {
            return Err(IoError::ParseError(
                "protobuf: varint overflow (>64 bits)".to_string(),
            ));
        }
        result |= payload << shift;

        if byte & 0x80 == 0 {
            return Ok(result);
        }
        shift += 7;
        if shift >= 64 {
            return Err(IoError::ParseError(
                "protobuf: varint overflow (>64 bits)".to_string(),
            ));
        }
    }
}
pub fn decode_field(data: &[u8], pos: &mut usize) -> Result<(u32, WireType), IoError> {
let tag = decode_varint(data, pos)?;
let field_num = (tag >> 3) as u32;
let wire_type = tag & 0x7;
let value = match wire_type {
0 => {
let v = decode_varint(data, pos)?;
WireType::Varint(v)
}
1 => {
if *pos + 8 > data.len() {
return Err(IoError::ParseError(
"protobuf: unexpected end of data in fixed64".to_string(),
));
}
let bytes: [u8; 8] = data[*pos..*pos + 8]
.try_into()
.map_err(|_| IoError::ParseError("protobuf: fixed64 slice error".to_string()))?;
*pos += 8;
WireType::Fixed64(u64::from_le_bytes(bytes))
}
2 => {
let len = decode_varint(data, pos)? as usize;
if *pos + len > data.len() {
return Err(IoError::ParseError(format!(
"protobuf: length-delimited field needs {} bytes but only {} remain",
len,
data.len() - *pos
)));
}
let bytes = data[*pos..*pos + len].to_vec();
*pos += len;
WireType::LengthDelimited(bytes)
}
5 => {
if *pos + 4 > data.len() {
return Err(IoError::ParseError(
"protobuf: unexpected end of data in fixed32".to_string(),
));
}
let bytes: [u8; 4] = data[*pos..*pos + 4]
.try_into()
.map_err(|_| IoError::ParseError("protobuf: fixed32 slice error".to_string()))?;
*pos += 4;
WireType::Fixed32(u32::from_le_bytes(bytes))
}
wt => {
return Err(IoError::ParseError(format!(
"protobuf: unsupported wire type {wt} for field {field_num}"
)));
}
};
Ok((field_num, value))
}
/// Decodes every field in `data`, in order of appearance.
///
/// Decoding stops at the first field that fails, and that error is returned.
/// An empty `data` yields an empty list.
fn parse_message(data: &[u8]) -> Result<Vec<(u32, WireType)>, IoError> {
    let mut cursor = 0usize;
    // Pull fields until the cursor reaches the end; collecting into `Result`
    // short-circuits on the first decode error.
    std::iter::from_fn(|| {
        if cursor < data.len() {
            Some(decode_field(data, &mut cursor))
        } else {
            None
        }
    })
    .collect()
}
/// Interprets a length-delimited payload as a UTF-8 string.
///
/// # Errors
///
/// Returns [`IoError::ParseError`] when `wt` is not
/// [`WireType::LengthDelimited`] or when its bytes are not valid UTF-8.
fn wire_to_string(wt: &WireType) -> Result<String, IoError> {
    if let WireType::LengthDelimited(raw) = wt {
        return String::from_utf8(raw.to_vec())
            .map_err(|e| IoError::ParseError(format!("protobuf: invalid UTF-8 string: {e}")));
    }
    Err(IoError::ParseError(format!(
        "protobuf: expected length-delimited for string, got {:?}",
        std::mem::discriminant(wt)
    )))
}
/// Extracts an unsigned 64-bit integer from a varint or fixed64 payload.
///
/// # Errors
///
/// Returns [`IoError::ParseError`] for any other payload kind.
fn wire_to_u64(wt: &WireType) -> Result<u64, IoError> {
    match *wt {
        WireType::Varint(raw) | WireType::Fixed64(raw) => Ok(raw),
        _ => Err(IoError::ParseError(format!(
            "protobuf: expected varint for integer, got {:?}",
            std::mem::discriminant(wt)
        ))),
    }
}
/// The subset of the ONNX `TensorProto` message that this parser keeps.
///
/// Fields not listed here are skipped while parsing.
#[derive(Debug, Clone, Default)]
pub struct OnnxTensorProto {
    /// Field 1 (repeated): accepted as individual values or as one packed run.
    pub dims: Vec<i64>,
    /// Field 2: stored as read, truncated to 32 bits; not validated here.
    pub data_type: i32,
    /// Field 4 (repeated): accepted as individual fixed32 values or as one
    /// packed run of little-endian 4-byte groups.
    pub float_data: Vec<f32>,
    /// Field 7 (repeated): accepted as individual varints or as one packed run.
    pub int64_data: Vec<i64>,
    /// Field 8: must be valid UTF-8.
    pub name: String,
    /// Field 9: the payload bytes, stored without interpretation.
    pub raw_data: Vec<u8>,
}

impl OnnxTensorProto {
    /// Appends every varint found in a packed payload to `out`.
    fn extend_with_packed_varints(out: &mut Vec<i64>, packed: &[u8]) -> Result<(), IoError> {
        let mut pos = 0;
        while pos < packed.len() {
            // `as i64` reinterprets the 64-bit pattern, so negative values that
            // travel as large unsigned varints come back negative.
            out.push(decode_varint(packed, &mut pos)? as i64);
        }
        Ok(())
    }

    /// Parses a serialized `TensorProto` message.
    ///
    /// Field numbers follow the ONNX schema: `name` is field 8 and `raw_data`
    /// is field 9 (field 12 is `doc_string`, which is not kept).
    ///
    /// # Errors
    ///
    /// Returns [`IoError::ParseError`] when the wire data is malformed, when an
    /// integer field has an incompatible wire type, or when `name` is not valid
    /// UTF-8.
    fn parse(data: &[u8]) -> Result<Self, IoError> {
        let mut tp = OnnxTensorProto::default();
        for field in parse_message(data)? {
            match field {
                // Repeated scalars may arrive packed; accept both encodings.
                (1, WireType::LengthDelimited(packed)) => {
                    Self::extend_with_packed_varints(&mut tp.dims, &packed)?
                }
                (1, scalar) => tp.dims.push(wire_to_u64(&scalar)? as i64),
                (2, scalar) => tp.data_type = wire_to_u64(&scalar)? as i32,
                (4, WireType::Fixed32(bits)) => tp.float_data.push(f32::from_bits(bits)),
                (4, WireType::LengthDelimited(packed)) => {
                    // A trailing group shorter than four bytes is ignored.
                    tp.float_data.extend(
                        packed
                            .chunks_exact(4)
                            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])),
                    )
                }
                (7, WireType::Varint(v)) => tp.int64_data.push(v as i64),
                (7, WireType::LengthDelimited(packed)) => {
                    Self::extend_with_packed_varints(&mut tp.int64_data, &packed)?
                }
                (8, text) => tp.name = wire_to_string(&text)?,
                (9, WireType::LengthDelimited(bytes)) => tp.raw_data = bytes,
                // Unknown fields and unexpected wire types are skipped.
                _ => {}
            }
        }
        Ok(tp)
    }
}
/// The subset of the ONNX `AttributeProto` message that this parser keeps.
///
/// Fields not listed here are skipped while parsing.
#[derive(Debug, Clone, Default)]
pub struct OnnxAttributeProto {
    /// Field 1: must be valid UTF-8.
    pub name: String,
    /// Field 20: stored as read, truncated to 32 bits; not validated here.
    pub attribute_type: i32,
    /// Field 2: a single float, read from a fixed32 payload.
    pub f: f32,
    /// Field 3: a single integer.
    pub i: i64,
    /// Field 4: a byte string, stored without interpretation.
    pub s: Vec<u8>,
}

impl OnnxAttributeProto {
    /// Parses a serialized `AttributeProto` message.
    ///
    /// Field numbers follow the ONNX schema: `f` is field 2 and `s` is field 4.
    /// Field 5 is the tensor-valued attribute `t`, which is not kept and must
    /// not be stored in `s`.
    ///
    /// # Errors
    ///
    /// Returns [`IoError::ParseError`] when the wire data is malformed, when an
    /// integer field has an incompatible wire type, or when `name` is not valid
    /// UTF-8.
    fn parse(data: &[u8]) -> Result<Self, IoError> {
        let mut attr = OnnxAttributeProto::default();
        for field in parse_message(data)? {
            match field {
                (1, text) => attr.name = wire_to_string(&text)?,
                (2, WireType::Fixed32(bits)) => attr.f = f32::from_bits(bits),
                (3, scalar) => attr.i = wire_to_u64(&scalar)? as i64,
                (4, WireType::LengthDelimited(bytes)) => attr.s = bytes,
                (20, scalar) => attr.attribute_type = wire_to_u64(&scalar)? as i32,
                // Unknown fields and unexpected wire types are skipped.
                _ => {}
            }
        }
        Ok(attr)
    }
}
/// The subset of the ONNX `NodeProto` message that this parser keeps.
#[derive(Debug, Clone, Default)]
pub struct OnnxNodeProto {
    /// Field 1 (repeated), in order of appearance.
    pub input: Vec<String>,
    /// Field 2 (repeated), in order of appearance.
    pub output: Vec<String>,
    /// Field 3.
    pub name: String,
    /// Field 4.
    pub op_type: String,
    /// Field 5 (repeated): each entry parsed as an attribute sub-message.
    pub attribute: Vec<OnnxAttributeProto>,
    /// Field 7.
    pub domain: String,
}

impl OnnxNodeProto {
    /// Parses a serialized `NodeProto` message.
    ///
    /// String fields must be length-delimited, valid UTF-8. An attribute field
    /// with any other wire type is skipped, as are unknown field numbers.
    fn parse(data: &[u8]) -> Result<Self, IoError> {
        let mut node = OnnxNodeProto::default();
        for field in parse_message(data)? {
            match field {
                (1, text) => node.input.push(wire_to_string(&text)?),
                (2, text) => node.output.push(wire_to_string(&text)?),
                (3, text) => node.name = wire_to_string(&text)?,
                (4, text) => node.op_type = wire_to_string(&text)?,
                (5, WireType::LengthDelimited(raw)) => {
                    node.attribute.push(OnnxAttributeProto::parse(&raw)?)
                }
                (7, text) => node.domain = wire_to_string(&text)?,
                _ => {}
            }
        }
        Ok(node)
    }
}
/// The subset of the ONNX `ValueInfoProto` message that this parser keeps.
#[derive(Debug, Clone, Default)]
pub struct OnnxValueInfoProto {
    /// Field 1; if it appears more than once, the last occurrence wins.
    pub name: String,
}

impl OnnxValueInfoProto {
    /// Parses a serialized `ValueInfoProto` message, keeping only the name.
    ///
    /// The whole message is decoded first, so malformed data in any field is
    /// still reported as an error.
    fn parse(data: &[u8]) -> Result<Self, IoError> {
        let decoded = parse_message(data)?;
        let mut vi = OnnxValueInfoProto::default();
        for (_, payload) in decoded.iter().filter(|entry| entry.0 == 1) {
            vi.name = wire_to_string(payload)?;
        }
        Ok(vi)
    }
}
/// The ONNX `OperatorSetIdProto` message: an operator-set domain and version.
#[derive(Debug, Clone, Default)]
pub struct OnnxOperatorSetIdProto {
    /// Field 1; left empty when the field is absent.
    pub domain: String,
    /// Field 2; left at zero when the field is absent.
    pub version: u64,
}

impl OnnxOperatorSetIdProto {
    /// Parses a serialized `OperatorSetIdProto` message.
    ///
    /// Unknown field numbers are skipped.
    fn parse(data: &[u8]) -> Result<Self, IoError> {
        let mut op = OnnxOperatorSetIdProto::default();
        for (tag, payload) in parse_message(data)? {
            if tag == 1 {
                op.domain = wire_to_string(&payload)?;
            } else if tag == 2 {
                op.version = wire_to_u64(&payload)?;
            }
        }
        Ok(op)
    }
}
/// The subset of the ONNX `GraphProto` message that this parser keeps.
#[derive(Debug, Clone, Default)]
pub struct OnnxGraphProto {
    /// Field 1 (repeated): each entry parsed as a node sub-message.
    pub node: Vec<OnnxNodeProto>,
    /// Field 2.
    pub name: String,
    /// Field 5 (repeated): each entry parsed as a tensor sub-message.
    pub initializer: Vec<OnnxTensorProto>,
    /// Field 11 (repeated): each entry parsed as a value-info sub-message.
    pub input: Vec<OnnxValueInfoProto>,
    /// Field 12 (repeated): each entry parsed as a value-info sub-message.
    pub output: Vec<OnnxValueInfoProto>,
}

impl OnnxGraphProto {
    /// Parses a serialized `GraphProto` message.
    ///
    /// Sub-message fields that are not length-delimited are skipped, as are
    /// unknown field numbers. Errors from nested parsers are propagated.
    fn parse(data: &[u8]) -> Result<Self, IoError> {
        let mut graph = OnnxGraphProto::default();
        for field in parse_message(data)? {
            match field {
                (1, WireType::LengthDelimited(raw)) => {
                    graph.node.push(OnnxNodeProto::parse(&raw)?)
                }
                (2, label) => graph.name = wire_to_string(&label)?,
                (5, WireType::LengthDelimited(raw)) => {
                    graph.initializer.push(OnnxTensorProto::parse(&raw)?)
                }
                (11, WireType::LengthDelimited(raw)) => {
                    graph.input.push(OnnxValueInfoProto::parse(&raw)?)
                }
                (12, WireType::LengthDelimited(raw)) => {
                    graph.output.push(OnnxValueInfoProto::parse(&raw)?)
                }
                _ => {}
            }
        }
        Ok(graph)
    }
}
/// The subset of the ONNX `ModelProto` message that this parser keeps.
///
/// Fields not listed here are skipped while parsing.
#[derive(Debug, Clone, Default)]
pub struct OnnxModelProto {
    /// Field 1.
    pub ir_version: u64,
    /// Field 8 (repeated): each entry parsed as an operator-set sub-message.
    pub opset_import: Vec<OnnxOperatorSetIdProto>,
    /// Field 4.
    pub domain: String,
    /// Field 5.
    pub model_version: u64,
    /// Field 6.
    pub doc_string: String,
    /// Field 7: the graph sub-message; stays at its default when absent.
    pub graph: OnnxGraphProto,
}

impl OnnxModelProto {
    /// Parses a serialized ONNX `ModelProto` message.
    ///
    /// Field numbers follow the ONNX schema: `domain` is field 4 (field 2 is
    /// `producer_name`, which is not kept).
    ///
    /// # Errors
    ///
    /// Returns [`IoError::ParseError`] when the wire data is malformed at any
    /// nesting level, when an integer field has an incompatible wire type, or
    /// when a string field is not valid UTF-8.
    pub fn parse(data: &[u8]) -> Result<OnnxModelProto, IoError> {
        let mut model = OnnxModelProto::default();
        for field in parse_message(data)? {
            match field {
                (1, scalar) => model.ir_version = wire_to_u64(&scalar)?,
                (4, text) => model.domain = wire_to_string(&text)?,
                (5, scalar) => model.model_version = wire_to_u64(&scalar)?,
                (6, text) => model.doc_string = wire_to_string(&text)?,
                (7, WireType::LengthDelimited(bytes)) => {
                    model.graph = OnnxGraphProto::parse(&bytes)?
                }
                (8, WireType::LengthDelimited(bytes)) => model
                    .opset_import
                    .push(OnnxOperatorSetIdProto::parse(&bytes)?),
                // Unknown fields and unexpected wire types are skipped.
                _ => {}
            }
        }
        Ok(model)
    }

    /// Builds a compact overview of the model.
    ///
    /// `opset_version` is taken from the first operator set whose domain is
    /// empty, or `0` when there is none. `op_types` is sorted and de-duplicated;
    /// input and output names keep their graph order.
    pub fn to_summary(&self) -> OnnxModelSummary {
        let opset_version = self
            .opset_import
            .iter()
            .find(|op| op.domain.is_empty())
            .map_or(0, |op| op.version);

        let mut op_types: Vec<String> =
            self.graph.node.iter().map(|n| n.op_type.clone()).collect();
        op_types.sort_unstable();
        op_types.dedup();

        OnnxModelSummary {
            n_nodes: self.graph.node.len(),
            n_initializers: self.graph.initializer.len(),
            opset_version,
            op_types,
            input_names: self.graph.input.iter().map(|v| v.name.clone()).collect(),
            output_names: self.graph.output.iter().map(|v| v.name.clone()).collect(),
        }
    }
}
/// A compact overview of a parsed model, as produced by
/// `OnnxModelProto::to_summary`.
///
/// Summaries compare by value, so two models can be checked for an identical
/// overview with `==`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OnnxModelSummary {
    /// Number of nodes in the graph.
    pub n_nodes: usize,
    /// Number of initializer tensors in the graph.
    pub n_initializers: usize,
    /// Version of the first operator set with an empty domain, or `0` if none.
    pub opset_version: u64,
    /// Distinct operator type names, sorted.
    pub op_types: Vec<String>,
    /// Graph input names, in graph order.
    pub input_names: Vec<String>,
    /// Graph output names, in graph order.
    pub output_names: Vec<String>,
}
/// Encodes `value` as a base-128 varint, least-significant group first.
///
/// Every byte except the last has its high bit set. Zero encodes as a single
/// `0x00` byte.
pub fn encode_varint(mut value: u64) -> Vec<u8> {
    let mut encoded = Vec::new();
    // Emit continuation bytes while more than seven bits remain.
    while value > 0x7F {
        encoded.push((value & 0x7F) as u8 | 0x80);
        value >>= 7;
    }
    // What is left fits in seven bits, so the cast is lossless.
    encoded.push(value as u8);
    encoded
}
/// Encodes a field tag: the field number shifted left by three bits, OR-ed
/// with the wire type, written as a varint.
///
/// `wire_type` is not range-checked.
pub fn write_field_tag(field_num: u32, wire_type: u8) -> Vec<u8> {
    let number_bits = u64::from(field_num) << 3;
    let key = number_bits | u64::from(wire_type);
    encode_varint(key)
}
/// Encodes a length-delimited field: tag (wire type 2), varint byte count,
/// then the payload bytes.
fn write_length_delimited(field_num: u32, data: &[u8]) -> Vec<u8> {
    let tag = write_field_tag(field_num, 2);
    let length = encode_varint(data.len() as u64);
    [tag.as_slice(), length.as_slice(), data].concat()
}
/// Encodes a string as a length-delimited field holding its UTF-8 bytes.
fn write_string_field(field_num: u32, s: &str) -> Vec<u8> {
    let utf8 = s.as_bytes();
    write_length_delimited(field_num, utf8)
}
/// Encodes a varint field: tag (wire type 0) followed by the varint value.
fn write_varint_field(field_num: u32, value: u64) -> Vec<u8> {
    write_field_tag(field_num, 0)
        .into_iter()
        .chain(encode_varint(value))
        .collect()
}
/// Serializes the inputs, outputs, name and operator type of `node`.
///
/// Inputs go to field 1, outputs to field 2, the name to field 3 (omitted
/// when empty) and the operator type to field 4 (always written). The node's
/// `attribute` and `domain` are not written.
fn encode_node(node: &OnnxNodeProto) -> Vec<u8> {
    let inputs = node.input.iter().map(|name| write_string_field(1, name));
    let outputs = node.output.iter().map(|name| write_string_field(2, name));
    let name = if node.name.is_empty() {
        None
    } else {
        Some(write_string_field(3, &node.name))
    };
    let op_type = std::iter::once(write_string_field(4, &node.op_type));

    inputs
        .chain(outputs)
        .chain(name)
        .chain(op_type)
        .flatten()
        .collect()
}
/// Serializes a value-info message that carries only `name`, in field 1.
fn encode_value_info(name: &str) -> Vec<u8> {
    const NAME_FIELD: u32 = 1;
    write_string_field(NAME_FIELD, name)
}
/// Serializes a small model made of `nodes` plus named graph inputs and
/// outputs.
///
/// The graph is named `"main_graph"`. The model is written with field 1
/// (read back as `ir_version`) set to 7 and a single operator set with an
/// empty domain and version 17.
pub fn create_minimal_onnx(
    nodes: &[OnnxNodeProto],
    inputs: &[String],
    outputs: &[String],
) -> Vec<u8> {
    // Graph payload: nodes (field 1), name (field 2), inputs (11), outputs (12).
    let mut graph_bytes: Vec<u8> = nodes
        .iter()
        .flat_map(|node| write_length_delimited(1, &encode_node(node)))
        .collect();
    graph_bytes.extend(write_string_field(2, "main_graph"));
    graph_bytes.extend(
        inputs
            .iter()
            .flat_map(|name| write_length_delimited(11, &encode_value_info(name))),
    );
    graph_bytes.extend(
        outputs
            .iter()
            .flat_map(|name| write_length_delimited(12, &encode_value_info(name))),
    );

    // Operator set payload: empty domain (field 1), version 17 (field 2).
    let opset_bytes = [write_string_field(1, ""), write_varint_field(2, 17)].concat();

    // Model payload: field 1 = 7, graph in field 7, operator set in field 8.
    [
        write_varint_field(1, 7),
        write_length_delimited(7, &graph_bytes),
        write_length_delimited(8, &opset_bytes),
    ]
    .concat()
}
// Unit tests: varint and field decoding on hand-written bytes, plus
// encode-then-parse round trips through `create_minimal_onnx`.
#[cfg(test)]
mod tests {
use super::*;
// Multi-byte varint: 0x96 0x01 decodes to 0x16 | (0x01 << 7) = 150, and the
// cursor must end up past both bytes.
#[test]
fn test_decode_varint() {
let data = [0x96, 0x01]; let mut pos = 0;
let v = decode_varint(&data, &mut pos).expect("decode varint 150");
assert_eq!(v, 150);
assert_eq!(pos, 2);
}
// A byte with the high bit clear is a complete varint on its own.
#[test]
fn test_decode_varint_single_byte() {
let data = [0x05];
let mut pos = 0;
let v = decode_varint(&data, &mut pos).expect("single byte varint");
assert_eq!(v, 5);
}
// encode_varint followed by decode_varint must return the original value;
// the samples sit on both sides of the 7-bit and 14-bit group boundaries.
#[test]
fn test_encode_decode_varint_roundtrip() {
for val in [
0u64,
1,
127,
128,
255,
300,
16383,
16384,
1_000_000,
u32::MAX as u64,
] {
let encoded = encode_varint(val);
let mut pos = 0;
let decoded = decode_varint(&encoded, &mut pos).expect("roundtrip");
assert_eq!(decoded, val, "varint roundtrip failed for {val}");
}
}
// Round trip of a one-node model: the parsed model must expose the node's
// operator type and the graph input/output names that were encoded.
#[test]
fn test_create_and_parse_minimal_onnx() {
let node = OnnxNodeProto {
input: vec!["x".to_string()],
output: vec!["y".to_string()],
op_type: "Relu".to_string(),
name: "relu0".to_string(),
..Default::default()
};
let bytes = create_minimal_onnx(&[node], &["x".to_string()], &["y".to_string()]);
let model = OnnxModelProto::parse(&bytes).expect("parse minimal onnx");
assert_eq!(model.ir_version, 7);
assert_eq!(model.graph.node.len(), 1);
assert_eq!(model.graph.node[0].op_type, "Relu");
assert_eq!(model.graph.input[0].name, "x");
assert_eq!(model.graph.output[0].name, "y");
}
// create_minimal_onnx always writes one operator set with an empty domain
// and version 17; it must survive parsing.
#[test]
fn test_opset_import_parsing() {
let bytes = create_minimal_onnx(&[], &[], &[]);
let model = OnnxModelProto::parse(&bytes).expect("parse opset test");
assert!(!model.opset_import.is_empty());
let default_opset = model
.opset_import
.iter()
.find(|op| op.domain.is_empty())
.expect("default opset");
assert_eq!(default_opset.version, 17);
}
// Summary of a two-node model: node count, opset version, operator types
// and input/output names.
#[test]
fn test_model_summary() {
let nodes = vec![
OnnxNodeProto {
input: vec!["x".to_string()],
output: vec!["h".to_string()],
op_type: "Gemm".to_string(),
name: "gemm0".to_string(),
..Default::default()
},
OnnxNodeProto {
input: vec!["h".to_string()],
output: vec!["y".to_string()],
op_type: "Relu".to_string(),
name: "relu1".to_string(),
..Default::default()
},
];
let bytes = create_minimal_onnx(&nodes, &["x".to_string()], &["y".to_string()]);
let model = OnnxModelProto::parse(&bytes).expect("parse");
let summary = model.to_summary();
assert_eq!(summary.n_nodes, 2);
assert_eq!(summary.opset_version, 17);
assert!(summary.op_types.contains(&"Gemm".to_string()));
assert!(summary.op_types.contains(&"Relu".to_string()));
assert_eq!(summary.input_names, vec!["x"]);
assert_eq!(summary.output_names, vec!["y"]);
}
// Bytes: 0x12 is the tag for field 2 with wire type 2 ((2 << 3) | 2),
// 0x03 is the payload length, and 0x61 0x62 0x63 is the payload.
#[test]
fn test_decode_field_length_delimited() {
let data = [0x12u8, 0x03, 0x61, 0x62, 0x63];
let mut pos = 0;
let (field_num, wt) = decode_field(&data, &mut pos).expect("decode ld field");
assert_eq!(field_num, 2);
match wt {
WireType::LengthDelimited(bytes) => assert_eq!(bytes, vec![0x61, 0x62, 0x63]),
_ => panic!("expected LengthDelimited"),
}
assert_eq!(pos, 5);
}
// Tags are (field_num << 3) | wire_type: field 1 with wire type 0 gives
// 0x08, field 2 with wire type 2 gives 0x12.
#[test]
fn test_write_field_tag() {
let tag = write_field_tag(1, 0);
assert_eq!(tag, vec![0x08]); let tag2 = write_field_tag(2, 2);
assert_eq!(tag2, vec![0x12]); }
// A model built from no nodes, inputs or outputs must still parse and
// summarize to empty collections.
#[test]
fn test_empty_model() {
let bytes = create_minimal_onnx(&[], &[], &[]);
let model = OnnxModelProto::parse(&bytes).expect("parse empty model");
let summary = model.to_summary();
assert_eq!(summary.n_nodes, 0);
assert_eq!(summary.n_initializers, 0);
assert!(summary.input_names.is_empty());
assert!(summary.output_names.is_empty());
}
}