use crate::{ModelError, ModelResult};
use std::collections::HashMap;
pub mod proto {
    //! Hand-rolled encoders for the protobuf wire format (varint,
    //! length-delimited and 32-bit fixed fields) — just enough to emit
    //! ONNX messages without a protobuf dependency.

    /// Encode `value` as a base-128 varint: 7 payload bits per byte,
    /// least-significant group first, high bit set on every byte but the last.
    pub fn encode_varint(value: u64) -> Vec<u8> {
        // A u64 varint is at most 10 bytes long.
        let mut out = Vec::with_capacity(10);
        let mut rest = value;
        while rest >= 0x80 {
            out.push((rest & 0x7F) as u8 | 0x80);
            rest >>= 7;
        }
        out.push(rest as u8);
        out
    }

    /// Encode a field key: `(field_number << 3) | wire_type`.
    pub fn field_tag(field: u32, wire_type: u32) -> Vec<u8> {
        encode_varint((u64::from(field) << 3) | u64::from(wire_type))
    }

    /// Length-delimited (wire type 2) UTF-8 string field.
    pub fn encode_string(field: u32, s: &str) -> Vec<u8> {
        encode_bytes(field, s.as_bytes())
    }

    /// Length-delimited (wire type 2) bytes field.
    pub fn encode_bytes(field: u32, b: &[u8]) -> Vec<u8> {
        let mut out = field_tag(field, 2);
        out.extend(encode_varint(b.len() as u64));
        out.extend_from_slice(b);
        out
    }

    /// Varint (wire type 0) int32 field. Zero is elided, matching proto3
    /// default-value semantics; negative values are sign-extended to 64 bits
    /// (standard protobuf int32 encoding, 10 bytes on the wire).
    pub fn encode_i32(field: u32, v: i32) -> Vec<u8> {
        if v == 0 {
            return Vec::new();
        }
        let mut out = field_tag(field, 0);
        out.extend(encode_varint(i64::from(v) as u64));
        out
    }

    /// Varint (wire type 0) int64 field; zero is elided (proto3 default).
    pub fn encode_i64(field: u32, v: i64) -> Vec<u8> {
        if v == 0 {
            return Vec::new();
        }
        let mut out = field_tag(field, 0);
        out.extend(encode_varint(v as u64));
        out
    }

    /// Fixed 32-bit (wire type 5) float field, little-endian. Note that
    /// zero is NOT elided here, unlike the integer encoders.
    pub fn encode_f32(field: u32, v: f32) -> Vec<u8> {
        let mut out = field_tag(field, 5);
        out.extend_from_slice(&v.to_le_bytes());
        out
    }

    /// Length-delimited embedded message; an empty body emits nothing.
    pub fn encode_submessage(field: u32, msg: &[u8]) -> Vec<u8> {
        if msg.is_empty() {
            Vec::new()
        } else {
            encode_bytes(field, msg)
        }
    }

    /// f32 slice serialized as raw little-endian bytes
    /// (e.g. `TensorProto.raw_data`).
    pub fn encode_float_slice_as_raw(field: u32, floats: &[f32]) -> Vec<u8> {
        let raw: Vec<u8> = floats.iter().flat_map(|f| f.to_le_bytes()).collect();
        encode_bytes(field, &raw)
    }

    /// Packed repeated int64: all varints inside one length-delimited
    /// payload. Empty input emits nothing.
    pub fn encode_packed_i64(field: u32, vals: &[i64]) -> Vec<u8> {
        if vals.is_empty() {
            return Vec::new();
        }
        let payload: Vec<u8> = vals.iter().flat_map(|&v| encode_varint(v as u64)).collect();
        encode_bytes(field, &payload)
    }

    /// Packed repeated float: little-endian 4-byte values inside one
    /// length-delimited payload. Empty input emits nothing.
    pub fn encode_packed_f32(field: u32, vals: &[f32]) -> Vec<u8> {
        if vals.is_empty() {
            return Vec::new();
        }
        let payload: Vec<u8> = vals.iter().flat_map(|v| v.to_le_bytes()).collect();
        encode_bytes(field, &payload)
    }
}
/// Element-type discriminants mirroring `onnx.TensorProto.DataType`
/// (FLOAT = 1, INT8 = 3, INT32 = 6, INT64 = 7, FLOAT16 = 10, DOUBLE = 11).
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(i32)]
pub enum OnnxDataType {
    Float = 1,
    Int8 = 3,
    Int32 = 6,
    Int64 = 7,
    Float16 = 10,
    Double = 11,
}
impl OnnxDataType {
    /// The raw protobuf enum value used when serializing `data_type` fields.
    pub fn as_i32(self) -> i32 {
        self as i32
    }
}
/// A constant tensor, serialized as an ONNX `TensorProto`.
#[derive(Debug, Clone)]
pub struct OnnxTensor {
    // Tensor name, unique within the graph.
    pub name: String,
    // Dimensions, row-major.
    pub dims: Vec<i64>,
    // Element type tag written to TensorProto.data_type.
    pub data_type: OnnxDataType,
    // Element values; serialized as little-endian raw bytes when non-empty.
    pub float_data: Vec<f32>,
}
impl OnnxTensor {
    /// Serialize this tensor as an ONNX `TensorProto` message body.
    ///
    /// Field layout per onnx.proto: dims (1, packed), data_type (2),
    /// name (8) and — only when float data is present — raw_data (9) as
    /// little-endian IEEE-754 bytes.
    fn to_proto_bytes(&self) -> Vec<u8> {
        let raw = if self.float_data.is_empty() {
            Vec::new()
        } else {
            proto::encode_float_slice_as_raw(9, &self.float_data)
        };
        [
            proto::encode_packed_i64(1, &self.dims),
            proto::encode_i32(2, self.data_type.as_i32()),
            proto::encode_string(8, &self.name),
            raw,
        ]
        .concat()
    }
}
/// Graph input/output declaration, serialized as an ONNX `ValueInfoProto`.
#[derive(Debug, Clone)]
pub struct OnnxValueInfo {
    // Tensor name referenced by node inputs/outputs.
    pub name: String,
    // Element type of the tensor.
    pub data_type: OnnxDataType,
    // Per-axis extents; `None` marks a symbolic (unknown) dimension.
    pub shape: Vec<Option<i64>>,
}
impl OnnxValueInfo {
    /// Serialize as an ONNX `ValueInfoProto` body: name (field 1) plus a
    /// nested TypeProto (field 2) holding tensor_type { elem_type, shape }.
    fn to_proto_bytes(&self) -> Vec<u8> {
        // TensorShapeProto: one `dim` submessage (field 1) per axis, carrying
        // either a concrete dim_value (1) or the symbolic dim_param "?" (2).
        let mut shape_body = Vec::new();
        for axis in &self.shape {
            let dim = match axis {
                Some(extent) => proto::encode_i64(1, *extent),
                None => proto::encode_string(2, "?"),
            };
            // NOTE(review): a Some(0) axis encodes to nothing (encode_i64
            // elides zero), so that dim is dropped — confirm zero-sized axes
            // are never needed here.
            shape_body.extend(proto::encode_submessage(1, &dim));
        }
        let mut tensor_type = proto::encode_i32(1, self.data_type.as_i32());
        if !shape_body.is_empty() {
            // An empty shape omits the field entirely, which ONNX reads as
            // "unknown rank" rather than "rank-0 scalar".
            tensor_type.extend(proto::encode_submessage(2, &shape_body));
        }
        let type_proto = proto::encode_submessage(1, &tensor_type);
        let mut body = proto::encode_string(1, &self.name);
        body.extend(proto::encode_submessage(2, &type_proto));
        body
    }
}
/// Discriminants for `AttributeProto.type` (field 20), mirroring
/// `onnx.AttributeProto.AttributeType`.
///
/// Fix: onnx.proto defines FLOAT = 1 and INT = 2; this enum previously had
/// `Int = 1` and `Float = 4` (4 is TENSOR), so exported attributes carried
/// the wrong type tag. String/Floats/Ints (3/6/7) were already correct.
#[repr(i32)]
enum AttributeType {
    Float = 1,
    Int = 2,
    String = 3,
    Floats = 6,
    Ints = 7,
}
/// A node attribute (name + typed payload), serialized as an ONNX
/// `AttributeProto`. Each variant carries the attribute name first.
#[derive(Debug, Clone)]
pub enum OnnxAttribute {
    Int(String, i64),
    Float(String, f32),
    String(String, Vec<u8>),
    Ints(String, Vec<i64>),
    Floats(String, Vec<f32>),
}
impl OnnxAttribute {
    /// Serialize as an ONNX `AttributeProto` message body.
    ///
    /// Field numbers per onnx.proto: name = 1, f = 2, i = 3, s = 4,
    /// floats = 7, ints = 8, type = 20.
    ///
    /// Fix: the scalar float, bytes and repeated payloads were previously
    /// written to the wrong fields (f -> 4 which is `t`, s -> 8 which is
    /// `ints`, ints -> 7 which is `floats`, floats -> 6 which is `g`),
    /// producing attributes ONNX readers cannot interpret.
    fn to_proto_bytes(&self) -> Vec<u8> {
        let mut buf = Vec::new();
        match self {
            OnnxAttribute::Int(name, v) => {
                buf.extend(proto::encode_string(1, name));
                buf.extend(proto::encode_i64(3, *v));
                buf.extend(proto::encode_i32(20, AttributeType::Int as i32));
            }
            OnnxAttribute::Float(name, v) => {
                buf.extend(proto::encode_string(1, name));
                // AttributeProto.f is field 2.
                buf.extend(proto::encode_f32(2, *v));
                buf.extend(proto::encode_i32(20, AttributeType::Float as i32));
            }
            OnnxAttribute::String(name, v) => {
                buf.extend(proto::encode_string(1, name));
                // AttributeProto.s (bytes) is field 4.
                buf.extend(proto::encode_bytes(4, v));
                buf.extend(proto::encode_i32(20, AttributeType::String as i32));
            }
            OnnxAttribute::Ints(name, vals) => {
                buf.extend(proto::encode_string(1, name));
                // AttributeProto.ints is field 8 (packed varints).
                buf.extend(proto::encode_packed_i64(8, vals));
                buf.extend(proto::encode_i32(20, AttributeType::Ints as i32));
            }
            OnnxAttribute::Floats(name, vals) => {
                buf.extend(proto::encode_string(1, name));
                // AttributeProto.floats is field 7 (packed fixed32).
                buf.extend(proto::encode_packed_f32(7, vals));
                buf.extend(proto::encode_i32(20, AttributeType::Floats as i32));
            }
        }
        buf
    }
}
/// A single operator invocation, serialized as an ONNX `NodeProto`.
#[derive(Debug, Clone)]
pub struct OnnxNode {
    // Operator name, e.g. "MatMul" or "Add".
    pub op_type: String,
    // Unique node name for diagnostics.
    pub name: String,
    // Names of input tensors (graph inputs, initializers, or other outputs).
    pub inputs: Vec<String>,
    // Names of output tensors produced by this node.
    pub outputs: Vec<String>,
    // Operator attributes.
    pub attributes: Vec<OnnxAttribute>,
}
impl OnnxNode {
    /// Serialize as an ONNX `NodeProto` body.
    ///
    /// Field numbers per onnx.proto: input = 1 (repeated), output = 2
    /// (repeated), name = 3, op_type = 4, attribute = 5 (repeated).
    fn to_proto_bytes(&self) -> Vec<u8> {
        let mut body: Vec<u8> = self
            .inputs
            .iter()
            .flat_map(|input| proto::encode_string(1, input))
            .collect();
        body.extend(
            self.outputs
                .iter()
                .flat_map(|output| proto::encode_string(2, output)),
        );
        body.extend(proto::encode_string(3, &self.name));
        body.extend(proto::encode_string(4, &self.op_type));
        body.extend(
            self.attributes
                .iter()
                .flat_map(|attr| proto::encode_submessage(5, &attr.to_proto_bytes())),
        );
        body
    }
}
/// A computation graph, serialized as an ONNX `GraphProto`.
#[derive(Debug, Clone)]
pub struct OnnxGraph {
    // Graph name.
    pub name: String,
    // Operator nodes in topological order.
    pub nodes: Vec<OnnxNode>,
    // Declared graph inputs.
    pub inputs: Vec<OnnxValueInfo>,
    // Declared graph outputs.
    pub outputs: Vec<OnnxValueInfo>,
    // Constant tensors (weights) referenced by node inputs.
    pub initializers: Vec<OnnxTensor>,
}
impl OnnxGraph {
    /// Serialize as an ONNX `GraphProto` body.
    ///
    /// Field numbers per onnx.proto: node = 1, name = 2, initializer = 5,
    /// input = 11, output = 12.
    fn to_proto_bytes(&self) -> Vec<u8> {
        let mut buf = Vec::new();
        for node in &self.nodes {
            let node_bytes = node.to_proto_bytes();
            buf.extend(proto::encode_submessage(1, &node_bytes));
        }
        buf.extend(proto::encode_string(2, &self.name));
        for init in &self.initializers {
            // Fix: GraphProto.initializer is field 5; field 6 (previously
            // used here) is reserved in onnx.proto, so every initializer was
            // emitted as an unknown field and dropped by ONNX readers.
            let init_bytes = init.to_proto_bytes();
            buf.extend(proto::encode_submessage(5, &init_bytes));
        }
        for inp in &self.inputs {
            let inp_bytes = inp.to_proto_bytes();
            buf.extend(proto::encode_submessage(11, &inp_bytes));
        }
        for out in &self.outputs {
            let out_bytes = out.to_proto_bytes();
            buf.extend(proto::encode_submessage(12, &out_bytes));
        }
        buf
    }
}
/// Top-level model container, serialized as an ONNX `ModelProto`.
#[derive(Debug, Clone)]
pub struct OnnxModel {
    // ONNX IR format version (ModelProto.ir_version).
    pub ir_version: i64,
    // Operator-set version declared in the single opset_import entry.
    pub opset_version: i64,
    // Operator-set domain; empty string means the default ONNX domain.
    pub domain: String,
    // The model's computation graph.
    pub graph: OnnxGraph,
    // Human-readable description; omitted from output when empty.
    pub doc_string: String,
}
impl OnnxModel {
    /// Wrap a graph with this crate's default versioning: IR version 8 and
    /// opset 17 on the default (empty) domain.
    pub fn new(graph: OnnxGraph) -> Self {
        Self {
            ir_version: 8,
            opset_version: 17,
            domain: String::new(),
            graph,
            doc_string: "Generated by Kizzasi".to_string(),
        }
    }
    /// Serialize the model as `ModelProto` wire bytes.
    ///
    /// Field numbers per onnx.proto: ir_version = 1, doc_string = 6,
    /// graph = 7, opset_import = 8.
    pub fn to_bytes(&self) -> ModelResult<Vec<u8>> {
        let mut buf = Vec::new();
        buf.extend(proto::encode_i64(1, self.ir_version));
        let graph_bytes = self.graph.to_proto_bytes();
        buf.extend(proto::encode_submessage(7, &graph_bytes));
        // Single opset_import entry: OperatorSetIdProto { domain = 1, version = 2 }.
        let mut opset_buf = Vec::new();
        opset_buf.extend(proto::encode_string(1, &self.domain));
        opset_buf.extend(proto::encode_i64(2, self.opset_version));
        buf.extend(proto::encode_submessage(8, &opset_buf));
        if !self.doc_string.is_empty() {
            // Fix: ModelProto.doc_string is field 6; field 12 (previously
            // used here) is not a ModelProto field, so the doc string was
            // emitted as an unknown field.
            buf.extend(proto::encode_string(6, &self.doc_string));
        }
        Ok(buf)
    }
    /// Serialize and write the model to `path`.
    pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> ModelResult<()> {
        let bytes = self.to_bytes()?;
        std::fs::write(path, bytes).map_err(ModelError::IoError)
    }
}
/// Build a node-less ONNX model whose graph carries every weight tensor as
/// an initializer — a weight dump, not a runnable graph.
///
/// Each entry in `weights` must have a matching entry in `shapes` whose
/// volume equals the data length; otherwise an invalid-config error is
/// returned.
pub fn export_weights_to_onnx(
    weights: &HashMap<String, Vec<f32>>,
    shapes: &HashMap<String, Vec<usize>>,
    model_name: &str,
) -> ModelResult<OnnxModel> {
    let mut initializers = Vec::with_capacity(weights.len());
    // Sort tensor names so the serialized output is deterministic
    // (HashMap iteration order is not).
    let mut names: Vec<&String> = weights.keys().collect();
    names.sort();
    for name in names {
        let data = &weights[name];
        let shape = shapes.get(name).ok_or_else(|| {
            ModelError::invalid_config(format!(
                "export_weights_to_onnx: shape missing for tensor '{name}'"
            ))
        })?;
        let volume: usize = shape.iter().product();
        if volume != data.len() {
            // Consistency fix: use the invalid_config constructor like every
            // other error path in this module instead of the struct literal.
            return Err(ModelError::invalid_config(format!(
                "export_weights_to_onnx: tensor '{name}' shape {:?} has volume {volume} \
                 but data length is {}",
                shape,
                data.len()
            )));
        }
        let dims: Vec<i64> = shape.iter().map(|&d| d as i64).collect();
        initializers.push(OnnxTensor {
            name: name.clone(),
            dims,
            data_type: OnnxDataType::Float,
            float_data: data.clone(),
        });
    }
    let graph = OnnxGraph {
        name: model_name.to_string(),
        nodes: vec![],
        inputs: vec![],
        outputs: vec![],
        initializers,
    };
    Ok(OnnxModel::new(graph))
}
/// Build an ONNX graph computing `output = input x W (+ bias)` as one
/// MatMul node followed by an optional Add.
///
/// `weight` is row-major with shape `[out_features, in_features]` (the usual
/// linear-layer layout) and is transposed to `[in, out]` so ONNX MatMul can
/// consume it directly. Errors if the shape is not 2-D, the weight length
/// disagrees with the shape's volume, or the bias length differs from
/// `out_features`.
pub fn export_linear_layer(
    weight: &[f32],
    weight_shape: &[usize],
    bias: Option<&[f32]>,
    input_name: &str,
    output_name: &str,
    layer_name: &str,
) -> ModelResult<OnnxGraph> {
    let (out_features, in_features) = match *weight_shape {
        [o, i] => (o, i),
        _ => {
            return Err(ModelError::invalid_config(format!(
                "export_linear_layer: weight_shape must have exactly 2 elements, got {}",
                weight_shape.len()
            )))
        }
    };
    let expected_len = out_features * in_features;
    if weight.len() != expected_len {
        return Err(ModelError::invalid_config(format!(
            "export_linear_layer: weight length {} does not match shape {:?} (volume {})",
            weight.len(),
            weight_shape,
            expected_len
        )));
    }
    // Transpose [out, in] -> [in, out] so MatMul(input, W) yields [batch, out].
    let mut transposed = vec![0.0f32; expected_len];
    for (flat, &value) in weight.iter().enumerate() {
        let (o, i) = (flat / in_features, flat % in_features);
        transposed[i * out_features + o] = value;
    }
    let weight_name = format!("{layer_name}.weight");
    let matmul_out_name = format!("{layer_name}.matmul_out");
    let mut initializers = vec![OnnxTensor {
        name: weight_name.clone(),
        dims: vec![in_features as i64, out_features as i64],
        data_type: OnnxDataType::Float,
        float_data: transposed,
    }];
    // Without a bias the MatMul writes straight to the graph output; with one
    // it feeds an intermediate tensor consumed by the Add node.
    let matmul_output = if bias.is_some() {
        matmul_out_name.clone()
    } else {
        output_name.to_string()
    };
    let mut nodes = vec![OnnxNode {
        op_type: "MatMul".to_string(),
        name: format!("{layer_name}/MatMul"),
        inputs: vec![input_name.to_string(), weight_name],
        outputs: vec![matmul_output],
        attributes: vec![],
    }];
    if let Some(bias_data) = bias {
        if bias_data.len() != out_features {
            return Err(ModelError::invalid_config(format!(
                "export_linear_layer: bias length {} does not match out_features {}",
                bias_data.len(),
                out_features
            )));
        }
        let bias_name = format!("{layer_name}.bias");
        initializers.push(OnnxTensor {
            name: bias_name.clone(),
            dims: vec![out_features as i64],
            data_type: OnnxDataType::Float,
            float_data: bias_data.to_vec(),
        });
        nodes.push(OnnxNode {
            op_type: "Add".to_string(),
            name: format!("{layer_name}/Add"),
            inputs: vec![matmul_out_name, bias_name],
            outputs: vec![output_name.to_string()],
            attributes: vec![],
        });
    }
    // The batch dimension is left symbolic (None -> dim_param "?").
    Ok(OnnxGraph {
        name: layer_name.to_string(),
        nodes,
        inputs: vec![OnnxValueInfo {
            name: input_name.to_string(),
            data_type: OnnxDataType::Float,
            shape: vec![None, Some(in_features as i64)],
        }],
        outputs: vec![OnnxValueInfo {
            name: output_name.to_string(),
            data_type: OnnxDataType::Float,
            shape: vec![None, Some(out_features as i64)],
        }],
        initializers,
    })
}
/// Fluent builder for assembling an `OnnxGraph` piece by piece.
#[derive(Debug, Default)]
pub struct OnnxGraphBuilder {
    // Graph name passed to `new`.
    name: String,
    // Accumulated operator nodes.
    nodes: Vec<OnnxNode>,
    // Accumulated graph inputs.
    inputs: Vec<OnnxValueInfo>,
    // Accumulated graph outputs.
    outputs: Vec<OnnxValueInfo>,
    // Accumulated constant tensors.
    initializers: Vec<OnnxTensor>,
}
impl OnnxGraphBuilder {
    /// Start an empty builder for a graph with the given name.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            ..Self::default()
        }
    }
    /// Append an operator node; returns the builder for chaining.
    pub fn add_node(mut self, node: OnnxNode) -> Self {
        self.nodes.push(node);
        self
    }
    /// Declare a graph input; returns the builder for chaining.
    pub fn add_input(mut self, value_info: OnnxValueInfo) -> Self {
        self.inputs.push(value_info);
        self
    }
    /// Declare a graph output; returns the builder for chaining.
    pub fn add_output(mut self, value_info: OnnxValueInfo) -> Self {
        self.outputs.push(value_info);
        self
    }
    /// Register a constant tensor; returns the builder for chaining.
    pub fn add_initializer(mut self, init: OnnxTensor) -> Self {
        self.initializers.push(init);
        self
    }
    /// Consume the builder and produce the finished graph.
    pub fn build(self) -> OnnxGraph {
        let Self {
            name,
            nodes,
            inputs,
            outputs,
            initializers,
        } = self;
        OnnxGraph {
            name,
            nodes,
            inputs,
            outputs,
            initializers,
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests covering the varint encoder, model serialization,
    //! file round-trips, and both export entry points.
    use super::*;
    use std::collections::HashMap;
    // Varint edge cases: zero, the 1-byte/2-byte boundary (127/128), and a
    // known multi-byte value (300 -> AC 02).
    #[test]
    fn test_proto_varint_encoding() {
        use super::proto::encode_varint;
        assert_eq!(encode_varint(0), vec![0]);
        assert_eq!(encode_varint(1), vec![1]);
        assert_eq!(encode_varint(127), vec![127]);
        assert_eq!(encode_varint(128), vec![0x80, 0x01]);
        assert_eq!(encode_varint(300), vec![0xAC, 0x02]);
    }
    // A model with one initializer must serialize to a non-empty byte buffer.
    #[test]
    fn test_onnx_model_to_bytes_nonempty() {
        let graph = OnnxGraph {
            name: "test".to_string(),
            nodes: vec![],
            inputs: vec![],
            outputs: vec![],
            initializers: vec![OnnxTensor {
                name: "w".to_string(),
                dims: vec![2, 3],
                data_type: OnnxDataType::Float,
                float_data: vec![1.0f32; 6],
            }],
        };
        let model = OnnxModel::new(graph);
        let bytes = model.to_bytes().expect("serialization must succeed");
        assert!(!bytes.is_empty());
    }
    // save() must produce a real file of non-trivial size; cleans up after.
    #[test]
    fn test_onnx_save_and_file_size() {
        let graph = OnnxGraph {
            name: "weight_export".to_string(),
            nodes: vec![],
            inputs: vec![],
            outputs: vec![],
            initializers: vec![OnnxTensor {
                name: "embed.weight".to_string(),
                dims: vec![4, 8],
                data_type: OnnxDataType::Float,
                float_data: (0..32).map(|i| i as f32 * 0.01).collect(),
            }],
        };
        let model = OnnxModel::new(graph);
        let path = std::env::temp_dir().join("test_kizzasi_export.onnx");
        model.save(&path).expect("save must succeed");
        let metadata = std::fs::metadata(&path).expect("file must exist after save");
        assert!(metadata.len() > 10, "exported file must be non-trivial");
        let _ = std::fs::remove_file(&path);
    }
    // Happy path: a single tensor with matching shape becomes one initializer.
    #[test]
    fn test_export_weights_to_onnx() {
        let mut weights = HashMap::new();
        weights.insert("layer.weight".to_string(), vec![1.0f32; 12]);
        let mut shapes = HashMap::new();
        shapes.insert("layer.weight".to_string(), vec![3, 4]);
        let model =
            export_weights_to_onnx(&weights, &shapes, "test_model").expect("export must succeed");
        assert_eq!(model.graph.initializers.len(), 1);
        assert_eq!(model.graph.initializers[0].dims, vec![3i64, 4]);
    }
    // With a bias, the exported graph is MatMul followed by Add.
    #[test]
    fn test_export_linear_layer() {
        let weight = vec![1.0f32; 6]; let bias = vec![0.1f32; 2];
        let graph =
            export_linear_layer(&weight, &[2, 3], Some(&bias), "input", "output", "linear0")
                .expect("export_linear_layer must succeed");
        assert!(!graph.nodes.is_empty(), "must have at least one node");
        assert!(
            !graph.initializers.is_empty(),
            "must have at least one initializer"
        );
        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.nodes[0].op_type, "MatMul");
        assert_eq!(graph.nodes[1].op_type, "Add");
    }
    // Without a bias, the graph collapses to a single MatMul node.
    #[test]
    fn test_export_linear_layer_no_bias() {
        let weight = vec![0.5f32; 12]; let graph = export_linear_layer(&weight, &[3, 4], None, "x", "y", "fc")
            .expect("export must succeed");
        assert_eq!(graph.nodes.len(), 1);
        assert_eq!(graph.nodes[0].op_type, "MatMul");
        assert_eq!(graph.initializers.len(), 1);
    }
    // Little-endian f32 bytes must round-trip losslessly (raw_data format).
    #[test]
    fn test_onnx_tensor_raw_data() {
        let floats = vec![1.0f32, 2.0, 3.0];
        let raw: Vec<u8> = floats.iter().flat_map(|f| f.to_le_bytes()).collect();
        assert_eq!(raw.len(), 12);
        let recovered: Vec<f32> = raw
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect();
        assert_eq!(recovered, floats);
    }
    // Shape volume (3*3=9) disagreeing with data length (6) must error.
    #[test]
    fn test_export_weights_shape_mismatch() {
        let mut weights = HashMap::new();
        weights.insert("w".to_string(), vec![1.0f32; 6]);
        let mut shapes = HashMap::new();
        shapes.insert("w".to_string(), vec![3, 3]);
        let result = export_weights_to_onnx(&weights, &shapes, "bad");
        assert!(result.is_err(), "must fail on volume mismatch");
    }
    // A weight with no declared shape must error rather than be guessed.
    #[test]
    fn test_export_weights_missing_shape() {
        let mut weights = HashMap::new();
        weights.insert("w".to_string(), vec![1.0f32; 6]);
        let shapes: HashMap<String, Vec<usize>> = HashMap::new();
        let result = export_weights_to_onnx(&weights, &shapes, "bad");
        assert!(result.is_err(), "must fail when shape is absent");
    }
    // Builder must carry name, initializer, input and output into the graph.
    #[test]
    fn test_onnx_graph_builder() {
        let graph = OnnxGraphBuilder::new("built_graph")
            .add_initializer(OnnxTensor {
                name: "param".to_string(),
                dims: vec![4, 4],
                data_type: OnnxDataType::Float,
                float_data: vec![0.0f32; 16],
            })
            .add_input(OnnxValueInfo {
                name: "x".to_string(),
                data_type: OnnxDataType::Float,
                shape: vec![None, Some(4)],
            })
            .add_output(OnnxValueInfo {
                name: "y".to_string(),
                data_type: OnnxDataType::Float,
                shape: vec![None, Some(4)],
            })
            .build();
        assert_eq!(graph.name, "built_graph");
        assert_eq!(graph.initializers.len(), 1);
        assert_eq!(graph.inputs.len(), 1);
        assert_eq!(graph.outputs.len(), 1);
    }
    // Smoke test: an Int attribute serializes to a non-empty body.
    #[test]
    fn test_attribute_int_encoding() {
        let attr = OnnxAttribute::Int("axis".to_string(), 1);
        let bytes = attr.to_proto_bytes();
        assert!(!bytes.is_empty());
    }
    // Smoke test: a Floats attribute serializes to a non-empty body.
    #[test]
    fn test_attribute_floats_encoding() {
        let attr = OnnxAttribute::Floats("scales".to_string(), vec![1.0, 2.0, 3.0]);
        let bytes = attr.to_proto_bytes();
        assert!(!bytes.is_empty());
    }
    // Multi-tensor export: 4 layers x (weight + bias) = 8 initializers, and
    // the bytes written by save() must match to_bytes() exactly.
    #[test]
    fn test_multi_tensor_export_round_trip() {
        let mut weights = HashMap::new();
        let mut shapes = HashMap::new();
        for layer in 0..4 {
            let key = format!("layer{layer}.weight");
            weights.insert(key.clone(), vec![0.1f32; 8]);
            shapes.insert(key, vec![2, 4]);
            let bkey = format!("layer{layer}.bias");
            weights.insert(bkey.clone(), vec![0.0f32; 2]);
            shapes.insert(bkey, vec![2]);
        }
        let model =
            export_weights_to_onnx(&weights, &shapes, "multi_layer").expect("export must succeed");
        assert_eq!(model.graph.initializers.len(), 8);
        let bytes = model.to_bytes().expect("to_bytes must succeed");
        assert!(!bytes.is_empty());
        let path = std::env::temp_dir().join("test_kizzasi_multi_layer.onnx");
        model.save(&path).expect("save must succeed");
        let on_disk = std::fs::read(&path).expect("must read back file");
        assert_eq!(on_disk, bytes);
        let _ = std::fs::remove_file(&path);
    }
}