1use crate::array::{Array, DTypeValue};
10pub use crate::llo::{OnnxAttribute, OnnxModel, OnnxNode, OnnxTensor, TrainingState};
11use crate::ops::{add, div, matmul, mul, relu, sigmoid, softmax, sub, transpose};
12use anyhow::{anyhow, Result};
13use std::collections::HashMap;
14
15pub fn serialize_onnx(model: &OnnxModel) -> Result<String> {
17 Ok(serde_json::to_string_pretty(model)?)
18}
19
20pub fn save_onnx(model: &OnnxModel, path: &str) -> Result<()> {
35 let json = serialize_onnx(model)?;
37 std::fs::write(path, json)?;
38 Ok(())
39}
40
41pub fn deserialize_onnx(json: &str) -> Result<OnnxModel> {
43 let model = serde_json::from_str(json)?;
44 Ok(model)
45}
46
47pub fn load_onnx(path: &str) -> Result<OnnxModel> {
63 let json = std::fs::read_to_string(path)?;
64 deserialize_onnx(&json)
65}
66
67pub fn create_linear_node(
76 name: &str,
77 input: &str,
78 weights: &str,
79 bias: Option<&str>,
80 output: &str,
81) -> OnnxNode {
82 let mut inputs = vec![input.to_string(), weights.to_string()];
83 if let Some(b) = bias {
84 inputs.push(b.to_string());
85 }
86
87 OnnxNode {
88 name: name.to_string(),
89 op_type: "Gemm".to_string(), inputs,
91 outputs: vec![output.to_string()],
92 attributes: {
93 let mut attrs = HashMap::new();
94 attrs.insert("alpha".to_string(), OnnxAttribute::Float(1.0));
95 attrs.insert("beta".to_string(), OnnxAttribute::Float(1.0));
96 attrs.insert("transA".to_string(), OnnxAttribute::Int(0));
97 attrs.insert("transB".to_string(), OnnxAttribute::Int(1)); attrs
99 },
100 }
101}
102
103pub fn create_relu_node(name: &str, input: &str, output: &str) -> OnnxNode {
105 OnnxNode {
106 name: name.to_string(),
107 op_type: "Relu".to_string(),
108 inputs: vec![input.to_string()],
109 outputs: vec![output.to_string()],
110 attributes: HashMap::new(),
111 }
112}
113
114pub fn create_softmax_node(name: &str, input: &str, output: &str, axis: i64) -> OnnxNode {
116 let mut attrs = HashMap::new();
117 attrs.insert("axis".to_string(), OnnxAttribute::Int(axis));
118
119 OnnxNode {
120 name: name.to_string(),
121 op_type: "Softmax".to_string(),
122 inputs: vec![input.to_string()],
123 outputs: vec![output.to_string()],
124 attributes: attrs,
125 }
126}
127
128pub fn create_matmul_node(name: &str, input_a: &str, input_b: &str, output: &str) -> OnnxNode {
130 OnnxNode {
131 name: name.to_string(),
132 op_type: "MatMul".to_string(),
133 inputs: vec![input_a.to_string(), input_b.to_string()],
134 outputs: vec![output.to_string()],
135 attributes: HashMap::new(),
136 }
137}
138
139pub fn create_add_node(name: &str, input_a: &str, input_b: &str, output: &str) -> OnnxNode {
141 OnnxNode {
142 name: name.to_string(),
143 op_type: "Add".to_string(),
144 inputs: vec![input_a.to_string(), input_b.to_string()],
145 outputs: vec![output.to_string()],
146 attributes: HashMap::new(),
147 }
148}
149
/// Converts an [`Array`] into an [`OnnxTensor`], copying the element buffer
/// into raw bytes and mapping the element type to an ONNX dtype code
/// (f32 → 1, f64 → 11, i32 → 6, i64 → 7).
///
/// # Errors
/// Returns an error for any element type other than the four listed above.
pub fn array_to_onnx_tensor<T: DTypeValue>(name: &str, array: &Array<T>) -> Result<OnnxTensor> {
    let data = &array.data;
    // SAFETY: `data` is a contiguous slice of `T`, so viewing the same
    // allocation as `len * size_of::<T>()` bytes stays in bounds, and `u8` has
    // no alignment or validity requirements. The borrow ends immediately via
    // `to_vec`, which copies the bytes out.
    // NOTE(review): this captures native byte order; the decoder in `infer`
    // uses `from_le_bytes` — confirm all targets are little-endian.
    let bytes: Vec<u8> = unsafe {
        std::slice::from_raw_parts(
            data.as_ptr() as *const u8,
            data.len() * std::mem::size_of::<T>(),
        )
        .to_vec()
    };

    // NOTE(review): `type_name` output is not guaranteed stable by the stdlib,
    // though it does yield bare names ("f32", …) for primitives today.
    let dtype = match std::any::type_name::<T>() {
        "f32" => 1, "f64" => 11, "i32" => 6, "i64" => 7, _ => return Err(anyhow!("Unsupported data type for ONNX")),
    };

    Ok(OnnxTensor {
        name: name.to_string(),
        dtype,
        shape: array.shape().to_vec(),
        data: bytes,
    })
}
178
179pub fn create_mlp(
203 name: &str,
204 input_size: usize,
205 _hidden_size: usize,
206 _output_size: usize,
207 weights: Vec<&Array>,
208) -> Result<OnnxModel> {
209 if weights.len() != 4 {
210 return Err(anyhow!("Expected 4 weight arrays [w1, b1, w2, b2]"));
211 }
212
213 let mut model = OnnxModel::new(name);
214
215 let input = OnnxTensor {
217 name: "input".to_string(),
218 dtype: 1, shape: vec![1, input_size], data: Vec::new(),
221 };
222 model.add_input(input);
223
224 model.add_initializer(array_to_onnx_tensor("w1", weights[0])?);
226 model.add_initializer(array_to_onnx_tensor("b1", weights[1])?);
227 model.add_initializer(array_to_onnx_tensor("w2", weights[2])?);
228 model.add_initializer(array_to_onnx_tensor("b2", weights[3])?);
229
230 model.add_node(create_linear_node(
232 "fc1",
233 "input",
234 "w1",
235 Some("b1"),
236 "hidden",
237 ));
238 model.add_node(create_relu_node("relu1", "hidden", "hidden_act"));
239
240 model.add_node(create_linear_node(
242 "fc2",
243 "hidden_act",
244 "w2",
245 Some("b2"),
246 "logits",
247 ));
248 model.add_node(create_softmax_node("softmax", "logits", "output", 1));
249
250 model.set_outputs(vec!["output".to_string()]);
252
253 Ok(model)
254}
255
256pub fn infer(model: &OnnxModel, inputs: HashMap<String, Array>) -> Result<HashMap<String, Array>> {
277 let mut values: HashMap<String, Array> = inputs;
279
280 for init in &model.graph.initializers {
282 if init.dtype == 1 {
285 let data_f32: Vec<f32> = init
287 .data
288 .chunks_exact(4)
289 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
290 .collect();
291 let array = Array::new(init.shape.clone(), data_f32);
292 values.insert(init.name.clone(), array);
293 } else if init.dtype == 7 {
294 let data_f32: Vec<f32> = init
296 .data
297 .chunks_exact(8)
298 .map(|b| {
299 i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
300 })
301 .collect();
302 let array = Array::new(init.shape.clone(), data_f32);
303 values.insert(init.name.clone(), array);
304 } else {
305 return Err(anyhow!("Unsupported initializer dtype: {}", init.dtype));
306 }
307 }
308
309 for node in &model.graph.nodes {
311 let mut node_inputs = Vec::new();
313 for input_name in &node.inputs {
314 let val = values.get(input_name).ok_or_else(|| {
315 anyhow!("Missing input '{}' for node '{}'", input_name, node.name)
316 })?;
317 node_inputs.push(val);
318 }
319
320 let output_name = &node.outputs[0]; let output_val = match node.op_type.as_str() {
323 "Add" => add(&node_inputs[0], &node_inputs[1])?,
324 "Sub" => sub(&node_inputs[0], &node_inputs[1])?,
325 "Mul" => mul(&node_inputs[0], &node_inputs[1])?,
326 "Div" => div(&node_inputs[0], &node_inputs[1])?,
327 "MatMul" => matmul(&node_inputs[0], &node_inputs[1])?,
328 "Relu" => relu(&node_inputs[0])?,
329 "Sigmoid" => sigmoid(&node_inputs[0])?,
330 "Softmax" => {
331 let axis = node
333 .attributes
334 .get("axis")
335 .and_then(|a| match a {
336 OnnxAttribute::Int(i) => Some(*i as usize),
337 _ => None,
338 })
339 .unwrap_or(1); softmax(&node_inputs[0], Some(axis))?
341 }
342 "Gemm" => {
343 let a = node_inputs[0];
345 let b = node_inputs[1];
346 let c = node_inputs.get(2); let alpha = node
349 .attributes
350 .get("alpha")
351 .and_then(|v| match v {
352 OnnxAttribute::Float(f) => Some(*f),
353 _ => None,
354 })
355 .unwrap_or(1.0);
356 let beta = node
357 .attributes
358 .get("beta")
359 .and_then(|v| match v {
360 OnnxAttribute::Float(f) => Some(*f),
361 _ => None,
362 })
363 .unwrap_or(1.0);
364
365 let trans_a = node
366 .attributes
367 .get("transA")
368 .and_then(|v| match v {
369 OnnxAttribute::Int(i) => Some(*i == 1),
370 _ => None,
371 })
372 .unwrap_or(false);
373 let trans_b = node
374 .attributes
375 .get("transB")
376 .and_then(|v| match v {
377 OnnxAttribute::Int(i) => Some(*i == 1),
378 _ => None,
379 })
380 .unwrap_or(false);
381
382 let a_processed = if trans_a {
384 transpose(a, None)?
385 } else {
386 a.clone()
387 };
388 let b_processed = if trans_b {
389 transpose(b, None)?
390 } else {
391 b.clone()
392 };
393
394 let mut mul_res = matmul(&a_processed, &b_processed)?;
396
397 if alpha != 1.0 {
399 let scalar = Array::new(vec![1], vec![alpha]);
400 mul_res = mul(&mul_res, &scalar)?;
401 }
402
403 if let Some(bias) = c {
405 if beta != 1.0 {
406 let scalar = Array::new(vec![1], vec![beta]);
407 let bias_scaled = mul(bias, &scalar)?;
408 add(&mul_res, &bias_scaled)?
409 } else {
410 add(&mul_res, bias)?
411 }
412 } else {
413 mul_res
414 }
415 }
416 "Transpose" => transpose(node_inputs[0], None)?,
417 "Conv" => {
418 let x = node_inputs[0];
420 let w = node_inputs[1];
421 let b = node_inputs.get(2).map(|&v| v);
422
423 let padding = node
425 .attributes
426 .get("pads")
427 .and_then(|a| match a {
428 OnnxAttribute::Ints(v) => Some(v[0] as usize),
429 _ => None,
430 })
431 .unwrap_or(0);
432
433 let stride = node
434 .attributes
435 .get("strides")
436 .and_then(|a| match a {
437 OnnxAttribute::Ints(v) => Some(v[0] as usize),
438 _ => None,
439 })
440 .unwrap_or(1);
441
442 crate::ops::conv::conv1d(x, w, b, stride, padding)?
443 }
444 "Reshape" => {
445 let data = node_inputs[0];
447 let shape_tensor = node_inputs[1];
448
449 let shape_vec: Vec<isize> = shape_tensor.data.iter().map(|&v| v as isize).collect();
451
452 crate::ops::shape::reshape(data, &shape_vec)?
453 }
454 "Flatten" => {
455 let axis = node
458 .attributes
459 .get("axis")
460 .and_then(|a| match a {
461 OnnxAttribute::Int(i) => Some(*i as usize),
462 _ => None,
463 })
464 .unwrap_or(1);
465
466 crate::ops::shape::flatten(node_inputs[0], axis, node_inputs[0].shape().len() - 1)?
476 }
477 "BatchNormalization" => {
478 let x = node_inputs[0];
480 let scale = node_inputs[1];
481 let b = node_inputs[2];
482 let mean = node_inputs[3];
483 let var = node_inputs[4];
484
485 let epsilon = node
486 .attributes
487 .get("epsilon")
488 .and_then(|a| match a {
489 OnnxAttribute::Float(f) => Some(*f),
490 _ => None,
491 })
492 .unwrap_or(1e-5);
493
494 let momentum = node
495 .attributes
496 .get("momentum")
497 .and_then(|a| match a {
498 OnnxAttribute::Float(f) => Some(*f),
499 _ => None,
500 })
501 .unwrap_or(0.9);
502
503 let mut mean_clone = mean.clone();
505 let mut var_clone = var.clone();
506
507 crate::ops::batchnorm::batch_norm(
508 x,
509 &mut mean_clone,
510 &mut var_clone,
511 scale,
512 b,
513 false,
514 momentum,
515 epsilon,
516 )?
517 }
518 "Dropout" => {
519 node_inputs[0].clone()
521 }
522 _ => return Err(anyhow!("Unsupported op type: {}", node.op_type)),
523 };
524
525 values.insert(output_name.clone(), output_val);
526 }
527
528 let mut results = HashMap::new();
530 for out_name in &model.graph.outputs {
531 if let Some(val) = values.get(out_name) {
532 results.insert(out_name.clone(), val.clone());
533 } else {
534 return Err(anyhow!(
535 "Model output '{}' not found after execution",
536 out_name
537 ));
538 }
539 }
540
541 Ok(results)
542}
543
544pub fn save_checkpoint(
551 model: &OnnxModel,
552 training_state: &TrainingState,
553 path: &str,
554) -> Result<()> {
555 #[derive(serde::Serialize)]
556 struct Checkpoint<'a> {
557 model: &'a OnnxModel,
558 training_state: &'a TrainingState,
559 }
560
561 let checkpoint = Checkpoint {
562 model,
563 training_state,
564 };
565 let json = serde_json::to_string_pretty(&checkpoint)?;
566 std::fs::write(path, json)?;
567 Ok(())
568}
569
570pub fn load_checkpoint(path: &str) -> Result<(OnnxModel, TrainingState)> {
578 #[derive(serde::Deserialize)]
579 struct Checkpoint {
580 model: OnnxModel,
581 training_state: TrainingState,
582 }
583
584 let json = std::fs::read_to_string(path)?;
585 let checkpoint: Checkpoint = serde_json::from_str(&json)?;
586 Ok((checkpoint.model, checkpoint.training_state))
587}