use std::io::Write;
use tempfile::NamedTempFile;
/// Encodes `val` as a protobuf base-128 varint.
///
/// Groups of 7 bits are emitted least-significant first. Every byte except the
/// last has its high bit (`0x80`) set to signal that more bytes follow. The result
/// is 1 byte for values below 128 and at most 10 bytes for `u64::MAX`.
pub fn pb_varint(val: u64) -> Vec<u8> {
    let mut encoded = Vec::new();
    let mut remaining = val;
    // While more than 7 significant bits remain, emit a continuation byte.
    while remaining >= 0x80 {
        encoded.push(((remaining & 0x7F) as u8) | 0x80);
        remaining >>= 7;
    }
    // `remaining` now fits in 7 bits, so the cast is lossless and the high bit is clear.
    encoded.push(remaining as u8);
    encoded
}
/// Encodes a protobuf field key as a varint.
///
/// The key is the field number shifted left by three bits, with the wire type in
/// the low three bits.
pub fn pb_tag(field: u32, wire_type: u32) -> Vec<u8> {
    let key = (u64::from(field) << 3) | u64::from(wire_type);
    pb_varint(key)
}
/// Encodes a length-delimited (wire type 2) protobuf field.
///
/// The output is the field key, then the payload's byte length as a varint, then
/// the payload bytes themselves.
pub fn pb_length_delimited(field: u32, data: &[u8]) -> Vec<u8> {
    let key = pb_tag(field, 2);
    let length_prefix = pb_varint(data.len() as u64);
    let mut encoded = Vec::with_capacity(key.len() + length_prefix.len() + data.len());
    encoded.extend_from_slice(&key);
    encoded.extend_from_slice(&length_prefix);
    encoded.extend_from_slice(data);
    encoded
}
/// Encodes a varint (wire type 0) protobuf field: the field key followed by `val`
/// encoded as a varint.
pub fn pb_varint_field(field: u32, val: u64) -> Vec<u8> {
    [pb_tag(field, 0), pb_varint(val)].concat()
}
/// Serializes a tensor message whose payload is carried as opaque `raw_data` bytes.
///
/// Fields are written in this order:
/// - 1: `dims`, as a packed run of varints (always written, even when `dims` is empty)
/// - 2: `data_type`, as a varint
/// - 8: `name`, as UTF-8 bytes
/// - 9: `raw_data`, copied verbatim
pub fn build_tensor_proto(name: &str, data_type: u64, dims: &[i64], raw_data: &[u8]) -> Vec<u8> {
    // Casting i64 -> u64 keeps the two's-complement bits, so a negative dim becomes
    // a 10-byte varint.
    let dims_payload: Vec<u8> = dims.iter().flat_map(|&dim| pb_varint(dim as u64)).collect();
    [
        pb_length_delimited(1, &dims_payload),
        pb_varint_field(2, data_type),
        pb_length_delimited(8, name.as_bytes()),
        pb_length_delimited(9, raw_data),
    ]
    .concat()
}
/// Serializes a float tensor message whose values are stored in the packed
/// `float_data` field rather than in `raw_data`.
///
/// Fields are written in this order:
/// - 1: `dims`, as a packed run of varints (always written, even when `dims` is empty)
/// - 2: data type, fixed to `1` (FLOAT in ONNX's `TensorProto.DataType` enum)
/// - 8: `name`, as UTF-8 bytes
/// - 4: `float_data`, each value as 4 little-endian bytes
pub fn build_tensor_proto_float_data(name: &str, dims: &[i64], float_data: &[f32]) -> Vec<u8> {
    const FLOAT_DATA_TYPE: u64 = 1;

    // Casting i64 -> u64 keeps the two's-complement bits, so a negative dim becomes
    // a 10-byte varint.
    let dims_payload: Vec<u8> = dims.iter().flat_map(|&dim| pb_varint(dim as u64)).collect();

    let mut floats_payload = Vec::with_capacity(float_data.len() * 4);
    for value in float_data {
        floats_payload.extend_from_slice(&value.to_le_bytes());
    }

    [
        pb_length_delimited(1, &dims_payload),
        pb_varint_field(2, FLOAT_DATA_TYPE),
        pb_length_delimited(8, name.as_bytes()),
        pb_length_delimited(4, &floats_payload),
    ]
    .concat()
}
/// Writes a minimal `.onnx` model that contains only graph initializers to a
/// temporary file and returns the handle.
///
/// Each element of `tensors` must already be a serialized tensor message (for
/// example, the output of [`build_tensor_proto`]). Each one is wrapped as field 5
/// of the graph message, and the graph is stored as field 7 of the model message.
/// Nothing else is written: no nodes, inputs, outputs, IR version, or opset imports.
///
/// The file lives only as long as the returned [`NamedTempFile`]. It is removed when
/// the handle is dropped, so keep the handle alive while the path is in use.
///
/// # Panics
///
/// Panics if the temporary file cannot be created, written, or flushed.
pub fn build_onnx_file(tensors: Vec<Vec<u8>>) -> NamedTempFile {
    let graph: Vec<u8> = tensors
        .iter()
        .flat_map(|tensor| pb_length_delimited(5, tensor))
        .collect();
    let model = pb_length_delimited(7, &graph);

    let mut file =
        NamedTempFile::with_suffix(".onnx").expect("failed to create temporary .onnx file");
    file.write_all(&model)
        .expect("failed to write ONNX model to temporary file");
    file.flush().expect("failed to flush temporary .onnx file");
    file
}