use oxionnx::Session;
use oxionnx_core::graph::Dim;
/// Encodes `val` as a base-128 varint.
///
/// The value is emitted seven bits at a time, least-significant group first.
/// Every byte except the last has its high bit (`0x80`) set to signal that
/// more bytes follow. Zero is encoded as the single byte `0x00`.
fn encode_varint(mut val: u64) -> Vec<u8> {
    let mut out = Vec::new();
    // While more than seven significant bits remain, emit a group with the
    // continuation flag set.
    while val >= 0x80 {
        out.push((val & 0x7F) as u8 | 0x80);
        val >>= 7;
    }
    // What is left fits into seven bits, so it is the terminating byte.
    out.push(val as u8);
    out
}
/// Encodes a protobuf field carrying a varint payload (wire type 0).
///
/// The output is the field key (`field << 3`, low three bits left at zero)
/// as a varint, immediately followed by `val` as a varint.
fn encode_varint_field(field: u32, val: u64) -> Vec<u8> {
    // The shift is done in `u32` before widening, so the key is computed
    // exactly as a 32-bit value.
    let key = u64::from(field << 3);
    [encode_varint(key), encode_varint(val)].concat()
}
/// Encodes a protobuf length-delimited field (wire type 2).
///
/// The output is the field key (`(field << 3) | 2`) as a varint, then the
/// length of `data` as a varint, then the bytes of `data` unchanged.
fn encode_bytes_field(field: u32, data: &[u8]) -> Vec<u8> {
    const LENGTH_DELIMITED: u32 = 2;
    let key = (field << 3) | LENGTH_DELIMITED;

    encode_varint(u64::from(key))
        .into_iter()
        .chain(encode_varint(data.len() as u64))
        .chain(data.iter().copied())
        .collect()
}
/// Builds the serialized bytes of a graph message containing a single input.
///
/// The nesting, from the outside in, is:
/// field 11 (the input entry) → field 1 (name) and field 2 (type) →
/// field 1 (tensor type) → field 1 (`elem_type`) and field 2 (shape) →
/// one field-1 entry per dimension.
/// These numbers follow the ONNX protobuf schema (`GraphProto.input`,
/// `ValueInfoProto`, `TypeProto.Tensor`, `TensorShapeProto.Dimension`).
///
/// Each element of `dims` is `(dim_value, dim_param)`: a `Some` value is
/// written as varint field 1, a `Some` param as string field 2. If both are
/// `Some`, both are written (value first); if both are `None`, the dimension
/// message is empty.
fn build_graph_with_input(
    input_name: &str,
    elem_type: u32,
    dims: &[(Option<u64>, Option<&str>)],
) -> Vec<u8> {
    // One length-delimited field-1 record per dimension, concatenated.
    let shape: Vec<u8> = dims
        .iter()
        .flat_map(|&(dim_val, dim_param)| {
            let value_part = dim_val
                .map(|v| encode_varint_field(1, v))
                .unwrap_or_default();
            let param_part = dim_param
                .map(|p| encode_bytes_field(2, p.as_bytes()))
                .unwrap_or_default();
            encode_bytes_field(1, &[value_part, param_part].concat())
        })
        .collect();

    let tensor = [
        encode_varint_field(1, u64::from(elem_type)),
        encode_bytes_field(2, &shape),
    ]
    .concat();
    let type_msg = encode_bytes_field(1, &tensor);

    let value_info = [
        encode_bytes_field(1, input_name.as_bytes()),
        encode_bytes_field(2, &type_msg),
    ]
    .concat();

    encode_bytes_field(11, &value_info)
}
/// Assembles the serialized bytes of a minimal model message.
///
/// The parts are written in this order:
/// 1. varint field 1 holding `ir_version`,
/// 2. `extra_fields` verbatim (already-encoded fields supplied by the caller),
/// 3. `graph_bytes` wrapped as length-delimited field 7, if present,
/// 4. length-delimited field 8 containing varint field 2 = 11
///    (an opset import with version 11 in the ONNX schema).
fn build_model(ir_version: u64, extra_fields: Vec<u8>, graph_bytes: Option<Vec<u8>>) -> Vec<u8> {
    let version_field = encode_varint_field(1, ir_version);
    // An absent graph contributes no bytes at all.
    let graph_field = graph_bytes
        .map(|graph| encode_bytes_field(7, &graph))
        .unwrap_or_default();
    let opset_field = encode_bytes_field(8, &encode_varint_field(2, 11));

    [version_field, extra_fields, graph_field, opset_field].concat()
}
/// A model carrying string fields 2 and 3 plus IR version 8 must expose them
/// through the session metadata as producer name, producer version and IR
/// version respectively.
#[test]
fn test_parse_producer_name() {
    let producer_fields = [
        encode_bytes_field(2, b"test_framework"),
        encode_bytes_field(3, b"1.2.3"),
    ]
    .concat();
    // No graph is attached; only the top-level model fields are exercised.
    let bytes = build_model(8, producer_fields, None);

    let session = Session::from_bytes(&bytes).expect("should parse model");
    let metadata = session.metadata();

    assert_eq!(metadata.producer_name, "test_framework");
    assert_eq!(metadata.producer_version, "1.2.3");
    assert_eq!(metadata.ir_version, 8);
}
/// An input whose first dimension is given only by name and whose second
/// dimension is given only by value must be reported as
/// `[Dim::Symbol("batch_size"), Dim::Static(768)]`.
#[test]
fn test_dim_param_symbolic() {
    // (dim_value, dim_param) pairs: a named dimension followed by a fixed one.
    let dims: [(Option<u64>, Option<&str>); 2] =
        [(None, Some("batch_size")), (Some(768), None)];
    let graph = build_graph_with_input("input_ids", 1, &dims);
    let bytes = build_model(7, Vec::new(), Some(graph));

    let session = Session::from_bytes(&bytes).expect("should parse model with dim_param");
    let inputs = session.input_info();
    assert!(!inputs.is_empty(), "expected at least one input_info entry");

    let input = inputs
        .iter()
        .find(|entry| entry.name == "input_ids")
        .expect("input_ids should be in input_infos");
    assert_eq!(input.shape.len(), 2, "expected 2 dimensions");

    let symbolic = input.symbolic_shape();
    assert_eq!(symbolic.len(), 2);
    assert_eq!(
        symbolic[0],
        Dim::Symbol(String::from("batch_size")),
        "first dim should be Dim::Symbol(batch_size)"
    );
    assert_eq!(
        symbolic[1],
        Dim::Static(768),
        "second dim should be Dim::Static(768)"
    );
}
/// Two key/value entries written as repeated field 14 on the model must both
/// be retrievable from the session's custom metadata map.
#[test]
fn test_metadata_props() {
    // Each entry is a nested message: field 1 = key, field 2 = value.
    let license_entry = [
        encode_bytes_field(1, b"license"),
        encode_bytes_field(2, b"Apache-2.0"),
    ]
    .concat();
    let framework_entry = [
        encode_bytes_field(1, b"framework"),
        encode_bytes_field(2, b"onnx-test"),
    ]
    .concat();
    let props = [
        encode_bytes_field(14, &license_entry),
        encode_bytes_field(14, &framework_entry),
    ]
    .concat();

    let bytes = build_model(7, props, None);
    let session = Session::from_bytes(&bytes).expect("should parse model with metadata_props");
    let metadata = session.metadata();

    let expectations = [
        (
            "license",
            "Apache-2.0",
            "license metadata_prop should be Apache-2.0",
        ),
        (
            "framework",
            "onnx-test",
            "framework metadata_prop should be onnx-test",
        ),
    ];
    for (key, expected, message) in expectations {
        assert_eq!(
            metadata.custom_metadata.get(key).map(String::as_str),
            Some(expected),
            "{}",
            message
        );
    }
}