numrs/ops/
model.rs

1//! Model operations for saving, loading and inference
2//!
3//! This module provides high-level operations for:
4//! - Saving models in ONNX format
5//! - Loading models from ONNX format
6//! - Running inference on loaded models
7//! - Training support (forward/backward passes)
8
9use crate::array::{Array, DTypeValue};
10pub use crate::llo::{OnnxAttribute, OnnxModel, OnnxNode, OnnxTensor, TrainingState};
11use crate::ops::{add, div, matmul, mul, relu, sigmoid, softmax, sub, transpose};
12use anyhow::{anyhow, Result};
13use std::collections::HashMap;
14
15/// Serialize model to JSON string
16pub fn serialize_onnx(model: &OnnxModel) -> Result<String> {
17    Ok(serde_json::to_string_pretty(model)?)
18}
19
20/// Save model to ONNX format
21///
22/// # Arguments
23/// * `model` - The ONNX model to save
24/// * `path` - File path to save the model
25///
26/// # Example
27/// ```no_run
28/// use numrs::ops::model::{save_onnx, OnnxModel};
29///
30/// let model = OnnxModel::new("my_model");
31/// save_onnx(&model, "model.onnx")?;
32/// # Ok::<(), anyhow::Error>(())
33/// ```
34pub fn save_onnx(model: &OnnxModel, path: &str) -> Result<()> {
35    // Serialize to JSON (or could use protobuf for full ONNX compatibility)
36    let json = serialize_onnx(model)?;
37    std::fs::write(path, json)?;
38    Ok(())
39}
40
41/// Deserialize model from JSON string
42pub fn deserialize_onnx(json: &str) -> Result<OnnxModel> {
43    let model = serde_json::from_str(json)?;
44    Ok(model)
45}
46
47/// Load model from ONNX format
48///
49/// # Arguments
50/// * `path` - File path to load the model from
51///
52/// # Returns
53/// The loaded ONNX model
54///
55/// # Example
56/// ```no_run
57/// use numrs::ops::model::load_onnx;
58///
59/// let model = load_onnx("model.onnx")?;
60/// # Ok::<(), anyhow::Error>(())
61/// ```
62pub fn load_onnx(path: &str) -> Result<OnnxModel> {
63    let json = std::fs::read_to_string(path)?;
64    deserialize_onnx(&json)
65}
66
67/// Create a linear layer node
68///
69/// # Arguments
70/// * `name` - Node name
71/// * `input` - Input tensor name
72/// * `weights` - Weights tensor name
73/// * `bias` - Bias tensor name (optional)
74/// * `output` - Output tensor name
75pub fn create_linear_node(
76    name: &str,
77    input: &str,
78    weights: &str,
79    bias: Option<&str>,
80    output: &str,
81) -> OnnxNode {
82    let mut inputs = vec![input.to_string(), weights.to_string()];
83    if let Some(b) = bias {
84        inputs.push(b.to_string());
85    }
86
87    OnnxNode {
88        name: name.to_string(),
89        op_type: "Gemm".to_string(), // ONNX Gemm = General Matrix Multiply
90        inputs,
91        outputs: vec![output.to_string()],
92        attributes: {
93            let mut attrs = HashMap::new();
94            attrs.insert("alpha".to_string(), OnnxAttribute::Float(1.0));
95            attrs.insert("beta".to_string(), OnnxAttribute::Float(1.0));
96            attrs.insert("transA".to_string(), OnnxAttribute::Int(0));
97            attrs.insert("transB".to_string(), OnnxAttribute::Int(1)); // Transpose weights
98            attrs
99        },
100    }
101}
102
103/// Create a ReLU activation node
104pub fn create_relu_node(name: &str, input: &str, output: &str) -> OnnxNode {
105    OnnxNode {
106        name: name.to_string(),
107        op_type: "Relu".to_string(),
108        inputs: vec![input.to_string()],
109        outputs: vec![output.to_string()],
110        attributes: HashMap::new(),
111    }
112}
113
114/// Create a Softmax activation node
115pub fn create_softmax_node(name: &str, input: &str, output: &str, axis: i64) -> OnnxNode {
116    let mut attrs = HashMap::new();
117    attrs.insert("axis".to_string(), OnnxAttribute::Int(axis));
118
119    OnnxNode {
120        name: name.to_string(),
121        op_type: "Softmax".to_string(),
122        inputs: vec![input.to_string()],
123        outputs: vec![output.to_string()],
124        attributes: attrs,
125    }
126}
127
128/// Create a MatMul node
129pub fn create_matmul_node(name: &str, input_a: &str, input_b: &str, output: &str) -> OnnxNode {
130    OnnxNode {
131        name: name.to_string(),
132        op_type: "MatMul".to_string(),
133        inputs: vec![input_a.to_string(), input_b.to_string()],
134        outputs: vec![output.to_string()],
135        attributes: HashMap::new(),
136    }
137}
138
139/// Create an Add node
140pub fn create_add_node(name: &str, input_a: &str, input_b: &str, output: &str) -> OnnxNode {
141    OnnxNode {
142        name: name.to_string(),
143        op_type: "Add".to_string(),
144        inputs: vec![input_a.to_string(), input_b.to_string()],
145        outputs: vec![output.to_string()],
146        attributes: HashMap::new(),
147    }
148}
149
/// Create a tensor from an Array
///
/// Copies the array's raw element bytes into an [`OnnxTensor`] and maps the
/// Rust element type `T` to an ONNX dtype code.
///
/// # Errors
/// Returns an error when `T` is not one of f32 / f64 / i32 / i64.
pub fn array_to_onnx_tensor<T: DTypeValue>(name: &str, array: &Array<T>) -> Result<OnnxTensor> {
    // Convert array data to bytes
    let data = &array.data;
    // SAFETY: `data` is a valid, initialized slice of `T`, so reinterpreting it
    // as `len * size_of::<T>()` bytes is sound: u8 has alignment 1 and every
    // initialized byte is a valid u8. The bytes are copied out immediately via
    // `to_vec()`, so the raw view does not outlive `data`.
    // NOTE(review): bytes are emitted in host byte order, while `infer` reads
    // them back with `from_le_bytes` — on a big-endian host these disagree.
    // Confirm little-endian is the intended on-disk format.
    let bytes: Vec<u8> = unsafe {
        std::slice::from_raw_parts(
            data.as_ptr() as *const u8,
            data.len() * std::mem::size_of::<T>(),
        )
        .to_vec()
    };

    // Determine ONNX dtype (1=FLOAT, 6=INT32, 7=INT64, 11=DOUBLE)
    // NOTE(review): `std::any::type_name` output is explicitly documented as
    // unstable across compiler versions; it yields "f32" etc. for primitives
    // in practice, but a `TypeId`-based comparison would be more robust —
    // confirm `DTypeValue` implies `'static` before switching.
    let dtype = match std::any::type_name::<T>() {
        "f32" => 1,  // FLOAT
        "f64" => 11, // DOUBLE
        "i32" => 6,  // INT32
        "i64" => 7,  // INT64
        _ => return Err(anyhow!("Unsupported data type for ONNX")),
    };

    Ok(OnnxTensor {
        name: name.to_string(),
        dtype,
        shape: array.shape().to_vec(),
        data: bytes,
    })
}
178
179/// Create a simple feedforward neural network model
180///
181/// # Arguments
182/// * `name` - Model name
183/// * `input_size` - Input feature size
184/// * `hidden_size` - Hidden layer size
185/// * `output_size` - Output size
186/// * `weights` - Layer weights [w1, b1, w2, b2]
187///
188/// # Example
189/// ```no_run
190/// use numrs::ops::model::create_mlp;
191/// use numrs::Array;
192///
193/// let w1 = Array::new(vec![784, 128], vec![0.0; 784 * 128]);
194/// let b1 = Array::new(vec![128], vec![0.0; 128]);
195/// let w2 = Array::new(vec![128, 10], vec![0.0; 128 * 10]);
196/// let b2 = Array::new(vec![10], vec![0.0; 10]);
197///
198/// let model = create_mlp("mnist_classifier", 784, 128, 10,
199///                        vec![&w1, &b1, &w2, &b2])?;
200/// # Ok::<(), anyhow::Error>(())
201/// ```
202pub fn create_mlp(
203    name: &str,
204    input_size: usize,
205    _hidden_size: usize,
206    _output_size: usize,
207    weights: Vec<&Array>,
208) -> Result<OnnxModel> {
209    if weights.len() != 4 {
210        return Err(anyhow!("Expected 4 weight arrays [w1, b1, w2, b2]"));
211    }
212
213    let mut model = OnnxModel::new(name);
214
215    // Add input tensor
216    let input = OnnxTensor {
217        name: "input".to_string(),
218        dtype: 1,                   // FLOAT
219        shape: vec![1, input_size], // Batch size 1
220        data: Vec::new(),
221    };
222    model.add_input(input);
223
224    // Add weights as initializers
225    model.add_initializer(array_to_onnx_tensor("w1", weights[0])?);
226    model.add_initializer(array_to_onnx_tensor("b1", weights[1])?);
227    model.add_initializer(array_to_onnx_tensor("w2", weights[2])?);
228    model.add_initializer(array_to_onnx_tensor("b2", weights[3])?);
229
230    // Layer 1: Linear + ReLU
231    model.add_node(create_linear_node(
232        "fc1",
233        "input",
234        "w1",
235        Some("b1"),
236        "hidden",
237    ));
238    model.add_node(create_relu_node("relu1", "hidden", "hidden_act"));
239
240    // Layer 2: Linear + Softmax
241    model.add_node(create_linear_node(
242        "fc2",
243        "hidden_act",
244        "w2",
245        Some("b2"),
246        "logits",
247    ));
248    model.add_node(create_softmax_node("softmax", "logits", "output", 1));
249
250    // Set output
251    model.set_outputs(vec!["output".to_string()]);
252
253    Ok(model)
254}
255
/// Run inference on a model
///
/// Executes the graph's nodes sequentially (assumes the node list is already
/// topologically sorted, as produced by this module's export helpers).
///
/// # Arguments
/// * `model` - The ONNX model
/// * `inputs` - Input tensors as a map of name -> Array
///
/// # Returns
/// Output tensors as a map of name -> Array
///
/// # Errors
/// Fails on missing node inputs, unsupported op types or initializer dtypes,
/// and missing declared outputs after execution.
///
/// # Note
/// This is a simplified inference engine. For production use, consider using
/// a full ONNX runtime like onnxruntime-rs.
pub fn infer(model: &OnnxModel, inputs: HashMap<String, Array>) -> Result<HashMap<String, Array>> {
    // 1. Initialize value store with inputs and initializers
    let mut values: HashMap<String, Array> = inputs;

    // Load initializers (weights) into values
    for init in &model.graph.initializers {
        // Convert OnnxTensor bytes back to Array
        // Assuming f32 for now as per export limitation, but handling INT64 shape tensors via cast
        if init.dtype == 1 {
            // FLOAT (f32): reassemble 4-byte little-endian chunks
            let data_f32: Vec<f32> = init
                .data
                .chunks_exact(4)
                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
                .collect();
            let array = Array::new(init.shape.clone(), data_f32);
            values.insert(init.name.clone(), array);
        } else if init.dtype == 7 {
            // INT64 (cast to f32 for compatibility with Array<f32> map)
            // NOTE(review): i64 -> f32 loses precision above 2^24; fine for
            // shape tensors, not for large integer payloads.
            let data_f32: Vec<f32> = init
                .data
                .chunks_exact(8)
                .map(|b| {
                    i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
                })
                .collect();
            let array = Array::new(init.shape.clone(), data_f32);
            values.insert(init.name.clone(), array);
        } else {
            return Err(anyhow!("Unsupported initializer dtype: {}", init.dtype));
        }
    }

    // 2. Execute nodes sequentially (assuming topological sort in export)
    for node in &model.graph.nodes {
        // Gather borrowed input arrays in declaration order.
        let mut node_inputs = Vec::new();
        for input_name in &node.inputs {
            let val = values.get(input_name).ok_or_else(|| {
                anyhow!("Missing input '{}' for node '{}'", input_name, node.name)
            })?;
            node_inputs.push(val);
        }

        // Assuming single output for now.
        // NOTE(review): indexing panics if a node declares no outputs —
        // malformed models are not rejected gracefully here.
        let output_name = &node.outputs[0];

        let output_val = match node.op_type.as_str() {
            "Add" => add(&node_inputs[0], &node_inputs[1])?,
            "Sub" => sub(&node_inputs[0], &node_inputs[1])?,
            "Mul" => mul(&node_inputs[0], &node_inputs[1])?,
            "Div" => div(&node_inputs[0], &node_inputs[1])?,
            "MatMul" => matmul(&node_inputs[0], &node_inputs[1])?,
            "Relu" => relu(&node_inputs[0])?,
            "Sigmoid" => sigmoid(&node_inputs[0])?,
            "Softmax" => {
                // Check axis attribute.
                // NOTE(review): `*i as usize` cannot represent negative ONNX
                // axes (e.g. -1, the opset-13 default) — confirm exporters
                // here only ever emit non-negative axes.
                let axis = node
                    .attributes
                    .get("axis")
                    .and_then(|a| match a {
                        OnnxAttribute::Int(i) => Some(*i as usize),
                        _ => None,
                    })
                    .unwrap_or(1); // Default axis 1
                softmax(&node_inputs[0], Some(axis))?
            }
            "Gemm" => {
                // Y = alpha * A' * B' + beta * C
                let a = node_inputs[0];
                let b = node_inputs[1];
                let c = node_inputs.get(2); // Optional bias

                // alpha/beta scaling factors default to 1.0 per ONNX spec.
                let alpha = node
                    .attributes
                    .get("alpha")
                    .and_then(|v| match v {
                        OnnxAttribute::Float(f) => Some(*f),
                        _ => None,
                    })
                    .unwrap_or(1.0);
                let beta = node
                    .attributes
                    .get("beta")
                    .and_then(|v| match v {
                        OnnxAttribute::Float(f) => Some(*f),
                        _ => None,
                    })
                    .unwrap_or(1.0);

                // transA/transB flags (integer 0/1) default to "no transpose".
                let trans_a = node
                    .attributes
                    .get("transA")
                    .and_then(|v| match v {
                        OnnxAttribute::Int(i) => Some(*i == 1),
                        _ => None,
                    })
                    .unwrap_or(false);
                let trans_b = node
                    .attributes
                    .get("transB")
                    .and_then(|v| match v {
                        OnnxAttribute::Int(i) => Some(*i == 1),
                        _ => None,
                    })
                    .unwrap_or(false);

                // Transpose if needed (clones when no transpose is required,
                // since matmul below takes owned-by-reference operands).
                let a_processed = if trans_a {
                    transpose(a, None)?
                } else {
                    a.clone()
                };
                let b_processed = if trans_b {
                    transpose(b, None)?
                } else {
                    b.clone()
                };

                // MatMul
                let mut mul_res = matmul(&a_processed, &b_processed)?;

                // Alpha scaling (skipped for the common alpha == 1.0 case).
                if alpha != 1.0 {
                    let scalar = Array::new(vec![1], vec![alpha]);
                    mul_res = mul(&mul_res, &scalar)?;
                }

                // Add bias (Beta * C)
                if let Some(bias) = c {
                    if beta != 1.0 {
                        let scalar = Array::new(vec![1], vec![beta]);
                        let bias_scaled = mul(bias, &scalar)?;
                        add(&mul_res, &bias_scaled)?
                    } else {
                        add(&mul_res, bias)?
                    }
                } else {
                    mul_res
                }
            }
            "Transpose" => transpose(node_inputs[0], None)?,
            "Conv" => {
                // Inputs: X, W, [B]
                let x = node_inputs[0];
                let w = node_inputs[1];
                let b = node_inputs.get(2).map(|&v| v);

                // Attributes: only the first element of `pads`/`strides` is
                // used — this dispatches to a 1D convolution.
                let padding = node
                    .attributes
                    .get("pads")
                    .and_then(|a| match a {
                        OnnxAttribute::Ints(v) => Some(v[0] as usize),
                        _ => None,
                    })
                    .unwrap_or(0);

                let stride = node
                    .attributes
                    .get("strides")
                    .and_then(|a| match a {
                        OnnxAttribute::Ints(v) => Some(v[0] as usize),
                        _ => None,
                    })
                    .unwrap_or(1);

                crate::ops::conv::conv1d(x, w, b, stride, padding)?
            }
            "Reshape" => {
                // Inputs: Data, Shape (Tensor)
                let data = node_inputs[0];
                let shape_tensor = node_inputs[1];

                // Extract shape from tensor data. Shape tensors were cast to
                // f32 when loading initializers, so cast back element-wise.
                let shape_vec: Vec<isize> = shape_tensor.data.iter().map(|&v| v as isize).collect();

                crate::ops::shape::reshape(data, &shape_vec)?
            }
            "Flatten" => {
                // Inputs: Data. Attribute: axis (default 1).
                let axis = node
                    .attributes
                    .get("axis")
                    .and_then(|a| match a {
                        OnnxAttribute::Int(i) => Some(*i as usize),
                        _ => None,
                    })
                    .unwrap_or(1);

                // Flatten dimensions [axis ..= last] into one.
                // NOTE(review): ONNX Flatten collapses to 2D ([0..axis) x
                // [axis..)); confirm ops::shape::flatten(start, end) matches
                // that semantics for axis > 1.
                crate::ops::shape::flatten(node_inputs[0], axis, node_inputs[0].shape().len() - 1)?
            }
            "BatchNormalization" => {
                // Inputs: X, scale, B, mean, var
                let x = node_inputs[0];
                let scale = node_inputs[1];
                let b = node_inputs[2];
                let mean = node_inputs[3];
                let var = node_inputs[4];

                let epsilon = node
                    .attributes
                    .get("epsilon")
                    .and_then(|a| match a {
                        OnnxAttribute::Float(f) => Some(*f),
                        _ => None,
                    })
                    .unwrap_or(1e-5);

                let momentum = node
                    .attributes
                    .get("momentum")
                    .and_then(|a| match a {
                        OnnxAttribute::Float(f) => Some(*f),
                        _ => None,
                    })
                    .unwrap_or(0.9);

                // Clone running stats because batch_norm expects &mut even if training=false
                let mut mean_clone = mean.clone();
                let mut var_clone = var.clone();

                crate::ops::batchnorm::batch_norm(
                    x,
                    &mut mean_clone,
                    &mut var_clone,
                    scale,
                    b,
                    false, // inference mode: running stats are read, not updated
                    momentum,
                    epsilon,
                )?
            }
            "Dropout" => {
                // Identity for inference
                node_inputs[0].clone()
            }
            _ => return Err(anyhow!("Unsupported op type: {}", node.op_type)),
        };

        values.insert(output_name.clone(), output_val);
    }

    // 3. Collect outputs
    let mut results = HashMap::new();
    for out_name in &model.graph.outputs {
        if let Some(val) = values.get(out_name) {
            results.insert(out_name.clone(), val.clone());
        } else {
            return Err(anyhow!(
                "Model output '{}' not found after execution",
                out_name
            ));
        }
    }

    Ok(results)
}
543
544/// Save training checkpoint
545///
546/// # Arguments
547/// * `model` - The model to save
548/// * `training_state` - Current training state (epoch, loss, optimizer state)
549/// * `path` - Path to save checkpoint
550pub fn save_checkpoint(
551    model: &OnnxModel,
552    training_state: &TrainingState,
553    path: &str,
554) -> Result<()> {
555    #[derive(serde::Serialize)]
556    struct Checkpoint<'a> {
557        model: &'a OnnxModel,
558        training_state: &'a TrainingState,
559    }
560
561    let checkpoint = Checkpoint {
562        model,
563        training_state,
564    };
565    let json = serde_json::to_string_pretty(&checkpoint)?;
566    std::fs::write(path, json)?;
567    Ok(())
568}
569
570/// Load training checkpoint
571///
572/// # Arguments
573/// * `path` - Path to load checkpoint from
574///
575/// # Returns
576/// Tuple of (model, training_state)
577pub fn load_checkpoint(path: &str) -> Result<(OnnxModel, TrainingState)> {
578    #[derive(serde::Deserialize)]
579    struct Checkpoint {
580        model: OnnxModel,
581        training_state: TrainingState,
582    }
583
584    let json = std::fs::read_to_string(path)?;
585    let checkpoint: Checkpoint = serde_json::from_str(&json)?;
586    Ok((checkpoint.model, checkpoint.training_state))
587}