/// Serialize a feed-forward MLP (Gemm layers with interior ReLUs) into ONNX
/// model bytes.
///
/// Each spec is `(in_dim, out_dim, weights, biases)`; weights are laid out as
/// (in_dim, out_dim) since the Gemm nodes are emitted with `transB = 0`. The
/// graph input is named `"obs"`. Returns an empty `Vec` for an empty spec list.
pub fn build_onnx_mlp_bytes(layer_specs: &[(usize, usize, Vec<f32>, Vec<f32>)]) -> Vec<u8> {
    if layer_specs.is_empty() {
        return Vec::new();
    }
    let last = layer_specs.len() - 1;
    let mut initializers = Vec::new();
    let mut nodes = Vec::new();
    for (i, (rows, cols, weights, biases)) in layer_specs.iter().enumerate() {
        let weight_name = format!("W{i}");
        let bias_name = format!("b{i}");
        initializers.push(build_tensor_proto(
            &weight_name,
            &[*rows as i64, *cols as i64],
            weights,
        ));
        initializers.push(build_tensor_proto(&bias_name, &[*cols as i64], biases));
        // The first layer reads the graph input; later layers consume the
        // previous layer's ReLU output.
        let source = match i {
            0 => "obs".to_string(),
            _ => format!("relu{}", i - 1),
        };
        let sink = format!("gemm{i}");
        nodes.push(build_gemm_node(
            &format!("Gemm_{i}"),
            &source,
            &weight_name,
            &bias_name,
            &sink,
        ));
        // Every layer except the last is followed by a ReLU activation.
        if i != last {
            nodes.push(build_relu_node(&format!("Relu_{i}"), &sink, &format!("relu{i}")));
        }
    }
    let input_info = build_value_info("obs", layer_specs[0].0);
    let output_info = build_value_info(&format!("gemm{last}"), layer_specs[last].1);
    let graph = build_graph_proto("mlp", &nodes, &initializers, &[input_info], &[output_info]);
    build_model_proto(graph)
}
/// Encode `val` as a protobuf base-128 varint: low 7 bits per byte, high bit
/// set on every byte except the last.
fn varint(mut val: u64) -> Vec<u8> {
    let mut encoded = Vec::new();
    while val >= 0x80 {
        encoded.push((val & 0x7F) as u8 | 0x80);
        val >>= 7;
    }
    // Final byte: fewer than 7 bits remain, so the continuation bit is clear.
    encoded.push(val as u8);
    encoded
}
/// Encode a protobuf field holding a varint value (wire type 0).
fn field_varint(field: u32, val: u64) -> Vec<u8> {
    // Tag byte(s): (field number << 3) | wire type; wire type 0 = varint.
    let mut buf = varint((field as u64) << 3);
    buf.extend(varint(val));
    buf
}
/// Encode a length-delimited protobuf field (wire type 2): tag, length, payload.
fn field_bytes(field: u32, data: &[u8]) -> Vec<u8> {
    let tag = ((field as u64) << 3) | 2;
    let mut buf = varint(tag);
    buf.extend(varint(data.len() as u64));
    buf.extend_from_slice(data);
    buf
}
/// Encode a UTF-8 string as a length-delimited protobuf field (wire type 2).
fn field_str(field: u32, s: &str) -> Vec<u8> {
field_bytes(field, s.as_bytes())
}
/// Encode an `f32` as a fixed32 protobuf field (wire type 5, little-endian).
fn field_fixed32(field: u32, val: f32) -> Vec<u8> {
    let tag = ((field as u64) << 3) | 5;
    let mut buf = varint(tag);
    buf.extend(val.to_le_bytes());
    buf
}
/// Encode an embedded message as a length-delimited protobuf field; on the
/// wire this is identical to a bytes field.
fn field_msg(field: u32, msg: &[u8]) -> Vec<u8> {
field_bytes(field, msg)
}
/// Serialize an ONNX `TensorProto` holding float32 data.
///
/// Field numbers per onnx.proto: dims = 1 (repeated int64), data_type = 2
/// (1 = FLOAT), name = 8, raw_data = 9 (little-endian element bytes).
fn build_tensor_proto(name: &str, dims: &[i64], data: &[f32]) -> Vec<u8> {
    let mut msg: Vec<u8> = dims.iter().flat_map(|&d| field_varint(1, d as u64)).collect();
    msg.extend(field_varint(2, 1));
    msg.extend(field_str(8, name));
    let raw: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
    msg.extend(field_bytes(9, &raw));
    msg
}
/// Serialize an ONNX `AttributeProto` carrying a single float.
///
/// Field numbers per onnx.proto: name = 1, type = 20 (AttributeType FLOAT = 1),
/// f = 2 (fixed32).
fn build_attribute_float(name: &str, val: f32) -> Vec<u8> {
    let mut msg = Vec::new();
    msg.extend(field_str(1, name));
    // AttributeType FLOAT = 1.
    msg.extend(field_varint(20, 1));
    // BUG FIX: the float payload belongs in field `f` = 2; the previous code
    // wrote it to field 4, which is `s` (bytes) in AttributeProto, so ONNX
    // parsers saw a type-FLOAT attribute with no `f` value.
    msg.extend(field_fixed32(2, val));
    msg
}
/// Serialize an ONNX `AttributeProto` carrying a single int64.
///
/// Field numbers per onnx.proto: name = 1, type = 20 (AttributeType INT = 2),
/// i = 3 (varint).
fn build_attribute_int(name: &str, val: i64) -> Vec<u8> {
    let mut msg = Vec::new();
    msg.extend(field_str(1, name));
    // AttributeType INT = 2.
    msg.extend(field_varint(20, 2));
    // BUG FIX: the int payload belongs in field `i` = 3; the previous code
    // wrote it to field 4 (`s`, bytes), leaving type-INT attributes valueless.
    msg.extend(field_varint(3, val as u64));
    msg
}
/// Serialize a `NodeProto` for a Gemm op: `output = input x weight + bias`.
///
/// NodeProto field numbers per onnx.proto: input = 1 (repeated), output = 2,
/// name = 3, op_type = 4, attribute = 5 (repeated AttributeProto).
fn build_gemm_node(
    name: &str,
    input: &str,
    weight: &str,
    bias: &str,
    output: &str,
) -> Vec<u8> {
    let mut msg = Vec::new();
    // Gemm inputs in positional order: A, B, C.
    msg.extend(field_str(1, input));
    msg.extend(field_str(1, weight));
    msg.extend(field_str(1, bias));
    msg.extend(field_str(2, output));
    msg.extend(field_str(3, name));
    msg.extend(field_str(4, "Gemm"));
    // BUG FIX: attributes are NodeProto field 5; the previous code wrote them
    // to field 6, which is `doc_string` (a string), so alpha/beta/transB were
    // never seen as attributes by ONNX consumers.
    msg.extend(field_msg(5, &build_attribute_float("alpha", 1.0)));
    msg.extend(field_msg(5, &build_attribute_float("beta", 1.0)));
    // transB = 0: weight initializers are stored as (in_dim, out_dim).
    msg.extend(field_msg(5, &build_attribute_int("transB", 0)));
    msg
}
fn build_relu_node(name: &str, input: &str, output: &str) -> Vec<u8> {
let mut msg = Vec::new();
msg.extend(field_str(1, input));
msg.extend(field_str(2, output));
msg.extend(field_str(3, name));
msg.extend(field_str(4, "Relu"));
msg
}
/// Serialize a `TensorShapeProto.Dimension`: either a concrete size
/// (`dim_value` = field 1) or a symbolic name (`dim_param` = field 2).
fn build_dim(dim_value: Option<i64>, dim_param: Option<&str>) -> Vec<u8> {
    let mut msg = Vec::new();
    msg.extend(dim_value.into_iter().flat_map(|v| field_varint(1, v as u64)));
    msg.extend(dim_param.into_iter().flat_map(|p| field_str(2, p)));
    msg
}
/// Serialize a `TypeProto` describing a rank-2 float tensor shaped
/// (batch_size, feature_dim) with a symbolic batch dimension.
fn build_type_proto_float_tensor(feature_dim: usize) -> Vec<u8> {
    // TensorShapeProto: dim = 1 (repeated), in order.
    let mut shape = field_msg(1, &build_dim(None, Some("batch_size")));
    shape.extend(field_msg(1, &build_dim(Some(feature_dim as i64), None)));
    // TypeProto.Tensor: elem_type = 1 (1 = FLOAT), shape = 2.
    let mut tensor = field_varint(1, 1);
    tensor.extend(field_msg(2, &shape));
    // TypeProto: tensor_type = 1.
    field_msg(1, &tensor)
}
/// Serialize a `ValueInfoProto` (name = 1, type = 2) for a float tensor
/// shaped (batch_size, feature_dim).
fn build_value_info(name: &str, feature_dim: usize) -> Vec<u8> {
    let mut info = field_str(1, name);
    info.extend(field_msg(2, &build_type_proto_float_tensor(feature_dim)));
    info
}
/// Serialize a `GraphProto` from pre-encoded node/initializer/value-info parts.
///
/// GraphProto field numbers per onnx.proto: node = 1, name = 2,
/// initializer = 5, input = 11, output = 12.
fn build_graph_proto(
    name: &str,
    nodes: &[Vec<u8>],
    initializers: &[Vec<u8>],
    inputs: &[Vec<u8>],
    outputs: &[Vec<u8>],
) -> Vec<u8> {
    let mut msg: Vec<u8> = nodes.iter().flat_map(|n| field_msg(1, n)).collect();
    msg.extend(field_str(2, name));
    msg.extend(initializers.iter().flat_map(|t| field_msg(5, t)));
    msg.extend(inputs.iter().flat_map(|v| field_msg(11, v)));
    msg.extend(outputs.iter().flat_map(|v| field_msg(12, v)));
    msg
}
/// Serialize an `OperatorSetIdProto` (domain = 1, version = 2). The empty
/// domain string denotes the default ONNX operator set.
fn build_opset_import(domain: &str, version: i64) -> Vec<u8> {
    let mut msg = field_str(1, domain);
    msg.extend(field_varint(2, version as u64));
    msg
}
/// Wrap an encoded `GraphProto` in a `ModelProto`.
///
/// ModelProto field numbers per onnx.proto: ir_version = 1, graph = 7,
/// opset_import = 8.
fn build_model_proto(graph: Vec<u8>) -> Vec<u8> {
    let mut msg = Vec::new();
    // FIX: opset 17 shipped with ONNX 1.12 / IR version 8; declaring
    // ir_version = 7 alongside opset 17 is inconsistent and rejected by
    // onnx.checker. Use IR version 8.
    msg.extend(field_varint(1, 8));
    // Default ("") domain, opset version 17.
    msg.extend(field_msg(8, &build_opset_import("", 17)));
    msg.extend(field_msg(7, &graph));
    msg
}
#[cfg(test)]
mod tests {
use super::*;
// Single-byte varints: values below 0x80 encode as themselves.
#[test]
fn varint_single_byte() {
assert_eq!(varint(0), vec![0]);
assert_eq!(varint(127), vec![127]);
}
// 128 needs a continuation byte: low 7 bits with MSB set, then 0x01.
#[test]
fn varint_multibyte() {
assert_eq!(varint(128), vec![0x80, 0x01]);
}
// A 2x2 identity layer should still produce a non-empty serialized model.
#[test]
fn build_onnx_mlp_bytes_nonempty_for_single_layer() {
let weights = vec![1.0f32, 0.0, 0.0, 1.0]; let biases = vec![0.0f32, 0.0];
let specs = vec![(2usize, 2usize, weights, biases)];
let bytes = build_onnx_mlp_bytes(&specs);
assert!(!bytes.is_empty(), "ONNX bytes should not be empty");
}
// An empty spec list is the documented degenerate case: no bytes at all.
#[test]
fn build_onnx_mlp_bytes_empty_for_no_layers() {
let bytes = build_onnx_mlp_bytes(&[]);
assert!(bytes.is_empty());
}
// Two layers (4 -> 8 -> 2) exercise the Gemm+ReLU chaining path; the blob
// must at least contain the raw weight bytes, hence the size lower bound.
#[test]
fn build_onnx_mlp_bytes_two_layer_mlp() {
let w1 = vec![0.1f32; 4 * 8]; let b1 = vec![0.0f32; 8];
let w2 = vec![0.2f32; 8 * 2]; let b2 = vec![0.0f32; 2];
let specs = vec![(4, 8, w1, b1), (8, 2, w2, b2)];
let bytes = build_onnx_mlp_bytes(&specs);
assert!(bytes.len() > 100, "expected a non-trivial ONNX blob");
}
}