// numrs/llo/model.rs

1//! Model operations for ONNX compatibility
2//! 
3//! This module defines operations for training, saving, and loading models
4//! compatible with ONNX format.
5
6use serde::{Serialize, Deserialize};
7use std::collections::HashMap;
8
9/// Model save/load operations
/// Model persistence operations (serialization to / from the ONNX format).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ModelKind {
    /// Save model to ONNX format
    SaveONNX,
    /// Load model from ONNX format
    LoadONNX,
    /// Export model graph
    ExportGraph,
    /// Import model graph
    ImportGraph,
}
21
22/// Training-related operations
/// Training-related operations making up one optimization step.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum TrainingKind {
    /// Forward pass
    Forward,
    /// Backward pass (gradient computation)
    Backward,
    /// Update weights (gradient descent step)
    UpdateWeights,
    /// Compute loss
    Loss,
}
34
35/// Loss function types
/// Supported loss function types.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum LossKind {
    /// Mean Squared Error
    MSE,
    /// Cross Entropy
    CrossEntropy,
    /// Binary Cross Entropy
    BinaryCrossEntropy,
    /// L1 Loss (Mean Absolute Error)
    L1,
    /// Huber Loss
    Huber,
}
49
50/// Optimizer types
/// Optimizer types together with their hyperparameters.
///
/// Note: derives `PartialEq` but not `Eq` because the variants carry
/// `f32` fields, and `f32` does not implement `Eq`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum OptimizerKind {
    /// Stochastic Gradient Descent
    SGD { learning_rate: f32, momentum: f32 },
    /// Adam optimizer
    Adam { learning_rate: f32, beta1: f32, beta2: f32, epsilon: f32 },
    /// RMSprop optimizer
    RMSprop { learning_rate: f32, decay: f32, epsilon: f32 },
    /// AdaGrad optimizer
    AdaGrad { learning_rate: f32, epsilon: f32 },
}
62
63/// ONNX Model metadata
/// ONNX model metadata (name, versioning, provenance).
///
/// Mirrors the informational fields of an ONNX `ModelProto`;
/// see `Default` impl below for the values used by new models.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Model name
    pub name: String,
    /// Model version
    pub version: i64,
    /// Model author
    pub author: String,
    /// Model description
    pub description: String,
    /// Model domain (reverse DNS, e.g. "ai.numrs")
    pub domain: String,
    /// ONNX opset version targeted by the model's operators
    pub opset_version: i64,
    /// Producer name
    pub producer_name: String,
    /// Producer version
    pub producer_version: String,
    /// Additional free-form key/value metadata
    pub metadata: HashMap<String, String>,
}
85
86impl Default for ModelMetadata {
87    fn default() -> Self {
88        Self {
89            name: "numrs_model".to_string(),
90            version: 1,
91            author: "NumRs".to_string(),
92            description: "Model created with NumRs".to_string(),
93            domain: "ai.numrs".to_string(),
94            opset_version: 18, // ONNX 1.13+ uses opset 18
95            producer_name: "NumRs".to_string(),
96            producer_version: env!("CARGO_PKG_VERSION").to_string(),
97            metadata: HashMap::new(),
98        }
99    }
100}
101
/// ONNX Node representation: one operator invocation in the graph,
/// wired to its inputs and outputs by tensor name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnnxNode {
    /// Node name (unique within the graph by ONNX convention — TODO confirm enforcement)
    pub name: String,
    /// Operator type (Add, MatMul, Relu, etc.)
    pub op_type: String,
    /// Names of the tensors consumed by this node
    pub inputs: Vec<String>,
    /// Names of the tensors produced by this node
    pub outputs: Vec<String>,
    /// Operator attributes, keyed by attribute name
    pub attributes: HashMap<String, OnnxAttribute>,
}
116
117/// ONNX Attribute value
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub enum OnnxAttribute {
120    Float(f32),
121    Int(i64),
122    String(String),
123    Tensor(Vec<usize>, Vec<f32>), // shape, data
124    Floats(Vec<f32>),
125    Ints(Vec<i64>),
126    Strings(Vec<String>),
127}
128
129/// ONNX Tensor representation
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct OnnxTensor {
132    /// Tensor name
133    pub name: String,
134    /// Data type (1=FLOAT, 6=INT32, 7=INT64, 11=DOUBLE)
135    pub dtype: i32,
136    /// Shape
137    pub shape: Vec<usize>,
138    /// Raw data bytes
139    pub data: Vec<u8>,
140}
141
142/// ONNX Graph representation
143#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct OnnxGraph {
145    /// Graph name
146    pub name: String,
147    /// List of nodes
148    pub nodes: Vec<OnnxNode>,
149    /// Input tensors
150    pub inputs: Vec<OnnxTensor>,
151    /// Output tensor names
152    pub outputs: Vec<String>,
153    /// Initializers (constant tensors like weights)
154    pub initializers: Vec<OnnxTensor>,
155}
156
157/// ONNX Model representation
158#[derive(Debug, Clone, Serialize, Deserialize)]
159pub struct OnnxModel {
160    /// Model metadata
161    pub metadata: ModelMetadata,
162    /// Model graph
163    pub graph: OnnxGraph,
164}
165
166impl OnnxModel {
167    /// Create a new ONNX model
168    pub fn new(name: &str) -> Self {
169        let mut metadata = ModelMetadata::default();
170        metadata.name = name.to_string();
171        
172        Self {
173            metadata,
174            graph: OnnxGraph {
175                name: name.to_string(),
176                nodes: Vec::new(),
177                inputs: Vec::new(),
178                outputs: Vec::new(),
179                initializers: Vec::new(),
180            },
181        }
182    }
183    
184    /// Add a node to the graph
185    pub fn add_node(&mut self, node: OnnxNode) {
186        self.graph.nodes.push(node);
187    }
188    
189    /// Add an input tensor
190    pub fn add_input(&mut self, tensor: OnnxTensor) {
191        self.graph.inputs.push(tensor);
192    }
193    
194    /// Add an initializer (weights/constants)
195    pub fn add_initializer(&mut self, tensor: OnnxTensor) {
196        self.graph.initializers.push(tensor);
197    }
198    
199    /// Set output names
200    pub fn set_outputs(&mut self, outputs: Vec<String>) {
201        self.graph.outputs = outputs;
202    }
203}
204
205/// Training state for a model
206#[derive(Debug, Clone, Serialize, Deserialize)]
207pub struct TrainingState {
208    /// Current epoch
209    pub epoch: usize,
210    /// Current iteration
211    pub iteration: usize,
212    /// Current loss value
213    pub loss: f32,
214    /// Optimizer state
215    pub optimizer: OptimizerKind,
216    /// Parameter gradients
217    pub gradients: HashMap<String, Vec<f32>>,
218    /// Optimizer momentum/state variables
219    pub optimizer_state: HashMap<String, Vec<f32>>,
220}
221
222impl TrainingState {
223    /// Create new training state with SGD optimizer
224    pub fn new_sgd(learning_rate: f32) -> Self {
225        Self {
226            epoch: 0,
227            iteration: 0,
228            loss: 0.0,
229            optimizer: OptimizerKind::SGD { 
230                learning_rate, 
231                momentum: 0.9 
232            },
233            gradients: HashMap::new(),
234            optimizer_state: HashMap::new(),
235        }
236    }
237    
238    /// Create new training state with Adam optimizer
239    pub fn new_adam(learning_rate: f32) -> Self {
240        Self {
241            epoch: 0,
242            iteration: 0,
243            loss: 0.0,
244            optimizer: OptimizerKind::Adam {
245                learning_rate,
246                beta1: 0.9,
247                beta2: 0.999,
248                epsilon: 1e-8,
249            },
250            gradients: HashMap::new(),
251            optimizer_state: HashMap::new(),
252        }
253    }
254}