1use serde::{Serialize, Deserialize};
7use std::collections::HashMap;
8
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
11pub enum ModelKind {
12 SaveONNX,
14 LoadONNX,
16 ExportGraph,
18 ImportGraph,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
24pub enum TrainingKind {
25 Forward,
27 Backward,
29 UpdateWeights,
31 Loss,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
37pub enum LossKind {
38 MSE,
40 CrossEntropy,
42 BinaryCrossEntropy,
44 L1,
46 Huber,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
52pub enum OptimizerKind {
53 SGD { learning_rate: f32, momentum: f32 },
55 Adam { learning_rate: f32, beta1: f32, beta2: f32, epsilon: f32 },
57 RMSprop { learning_rate: f32, decay: f32, epsilon: f32 },
59 AdaGrad { learning_rate: f32, epsilon: f32 },
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct ModelMetadata {
66 pub name: String,
68 pub version: i64,
70 pub author: String,
72 pub description: String,
74 pub domain: String,
76 pub opset_version: i64,
78 pub producer_name: String,
80 pub producer_version: String,
82 pub metadata: HashMap<String, String>,
84}
85
86impl Default for ModelMetadata {
87 fn default() -> Self {
88 Self {
89 name: "numrs_model".to_string(),
90 version: 1,
91 author: "NumRs".to_string(),
92 description: "Model created with NumRs".to_string(),
93 domain: "ai.numrs".to_string(),
94 opset_version: 18, producer_name: "NumRs".to_string(),
96 producer_version: env!("CARGO_PKG_VERSION").to_string(),
97 metadata: HashMap::new(),
98 }
99 }
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct OnnxNode {
105 pub name: String,
107 pub op_type: String,
109 pub inputs: Vec<String>,
111 pub outputs: Vec<String>,
113 pub attributes: HashMap<String, OnnxAttribute>,
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize)]
119pub enum OnnxAttribute {
120 Float(f32),
121 Int(i64),
122 String(String),
123 Tensor(Vec<usize>, Vec<f32>), Floats(Vec<f32>),
125 Ints(Vec<i64>),
126 Strings(Vec<String>),
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct OnnxTensor {
132 pub name: String,
134 pub dtype: i32,
136 pub shape: Vec<usize>,
138 pub data: Vec<u8>,
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct OnnxGraph {
145 pub name: String,
147 pub nodes: Vec<OnnxNode>,
149 pub inputs: Vec<OnnxTensor>,
151 pub outputs: Vec<String>,
153 pub initializers: Vec<OnnxTensor>,
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize)]
159pub struct OnnxModel {
160 pub metadata: ModelMetadata,
162 pub graph: OnnxGraph,
164}
165
166impl OnnxModel {
167 pub fn new(name: &str) -> Self {
169 let mut metadata = ModelMetadata::default();
170 metadata.name = name.to_string();
171
172 Self {
173 metadata,
174 graph: OnnxGraph {
175 name: name.to_string(),
176 nodes: Vec::new(),
177 inputs: Vec::new(),
178 outputs: Vec::new(),
179 initializers: Vec::new(),
180 },
181 }
182 }
183
184 pub fn add_node(&mut self, node: OnnxNode) {
186 self.graph.nodes.push(node);
187 }
188
189 pub fn add_input(&mut self, tensor: OnnxTensor) {
191 self.graph.inputs.push(tensor);
192 }
193
194 pub fn add_initializer(&mut self, tensor: OnnxTensor) {
196 self.graph.initializers.push(tensor);
197 }
198
199 pub fn set_outputs(&mut self, outputs: Vec<String>) {
201 self.graph.outputs = outputs;
202 }
203}
204
205#[derive(Debug, Clone, Serialize, Deserialize)]
207pub struct TrainingState {
208 pub epoch: usize,
210 pub iteration: usize,
212 pub loss: f32,
214 pub optimizer: OptimizerKind,
216 pub gradients: HashMap<String, Vec<f32>>,
218 pub optimizer_state: HashMap<String, Vec<f32>>,
220}
221
222impl TrainingState {
223 pub fn new_sgd(learning_rate: f32) -> Self {
225 Self {
226 epoch: 0,
227 iteration: 0,
228 loss: 0.0,
229 optimizer: OptimizerKind::SGD {
230 learning_rate,
231 momentum: 0.9
232 },
233 gradients: HashMap::new(),
234 optimizer_state: HashMap::new(),
235 }
236 }
237
238 pub fn new_adam(learning_rate: f32) -> Self {
240 Self {
241 epoch: 0,
242 iteration: 0,
243 loss: 0.0,
244 optimizer: OptimizerKind::Adam {
245 learning_rate,
246 beta1: 0.9,
247 beta2: 0.999,
248 epsilon: 1e-8,
249 },
250 gradients: HashMap::new(),
251 optimizer_state: HashMap::new(),
252 }
253 }
254}