use ghostflow_core::{Result, Tensor, GhostError};
use std::collections::HashMap;
use std::fs::File;
use std::io::{Write, Read};

/// Element types supported for tensor data. The discriminant values (0..=4)
/// double as the byte tags used in the binary format below.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ONNXDataType {
    Float32,
    Float64,
    Int32,
    Int64,
    Uint8,
}

/// A named tensor with element type, shape, and raw element bytes.
#[derive(Debug, Clone)]
pub struct ONNXTensor {
    pub name: String,
    pub dtype: ONNXDataType,
    pub shape: Vec<i64>,
    pub data: Vec<u8>,
}

/// A single operator node: op type plus named inputs, outputs, and attributes.
#[derive(Debug, Clone)]
pub struct ONNXNode {
    pub name: String,
    pub op_type: String,
    pub inputs: Vec<String>,
    pub outputs: Vec<String>,
    pub attributes: HashMap<String, ONNXAttribute>,
}

/// Attribute values that can be attached to a node.
#[derive(Debug, Clone)]
pub enum ONNXAttribute {
    Int(i64),
    Float(f32),
    String(String),
    Ints(Vec<i64>),
    Floats(Vec<f32>),
}

/// A computation graph: nodes plus input, output, and initializer tensors.
#[derive(Debug, Clone)]
pub struct ONNXGraph {
    pub name: String,
    pub nodes: Vec<ONNXNode>,
    pub inputs: Vec<ONNXTensor>,
    pub outputs: Vec<ONNXTensor>,
    pub initializers: Vec<ONNXTensor>,
}

/// Top-level model: versioning metadata and the graph itself.
#[derive(Debug, Clone)]
pub struct ONNXModel {
    pub ir_version: i64,
    pub producer_name: String,
    pub producer_version: String,
    pub graph: ONNXGraph,
}

impl ONNXModel {
    /// Create an empty model with the given graph name.
    pub fn new(name: &str) -> Self {
        Self {
            ir_version: 8,
            producer_name: "GhostFlow".to_string(),
            producer_version: env!("CARGO_PKG_VERSION").to_string(),
            graph: ONNXGraph {
                name: name.to_string(),
                nodes: Vec::new(),
                inputs: Vec::new(),
                outputs: Vec::new(),
                initializers: Vec::new(),
            },
        }
    }

    /// Serialize the model and write it to `path`.
    pub fn save(&self, path: &str) -> Result<()> {
        let serialized = self.serialize()?;
        let mut file = File::create(path)
            .map_err(|e| GhostError::IOError(format!("Failed to create file: {}", e)))?;
        file.write_all(&serialized)
            .map_err(|e| GhostError::IOError(format!("Failed to write file: {}", e)))?;
        Ok(())
    }

    /// Read a model previously written by [`ONNXModel::save`].
    pub fn load(path: &str) -> Result<Self> {
        let mut file = File::open(path)
            .map_err(|e| GhostError::IOError(format!("Failed to open file: {}", e)))?;
        let mut buffer = Vec::new();
        file.read_to_end(&mut buffer)
            .map_err(|e| GhostError::IOError(format!("Failed to read file: {}", e)))?;
        Self::deserialize(&buffer)
    }
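
    // Binary layout written by `serialize` and read by `deserialize`:
    //   magic "ONNX" (4 bytes)
    //   ir_version (i64, little-endian)
    //   producer_name, producer_version, graph name (each: u32 length + UTF-8 bytes)
    //   node count (u32), then each node (name, op_type, inputs, outputs)
    //   initializer count (u32), then each tensor (name, dtype tag, shape, data)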

    /// Serialize the model into the crate's simple binary container format.
    fn serialize(&self) -> Result<Vec<u8>> {
        let mut buffer = Vec::new();

        // Magic number
        buffer.extend_from_slice(b"ONNX");

        // IR version
        buffer.extend_from_slice(&self.ir_version.to_le_bytes());

        // Producer name (length-prefixed)
        let producer_bytes = self.producer_name.as_bytes();
        buffer.extend_from_slice(&(producer_bytes.len() as u32).to_le_bytes());
        buffer.extend_from_slice(producer_bytes);

        // Producer version (length-prefixed)
        let version_bytes = self.producer_version.as_bytes();
        buffer.extend_from_slice(&(version_bytes.len() as u32).to_le_bytes());
        buffer.extend_from_slice(version_bytes);

        // Graph name (length-prefixed)
        let graph_name_bytes = self.graph.name.as_bytes();
        buffer.extend_from_slice(&(graph_name_bytes.len() as u32).to_le_bytes());
        buffer.extend_from_slice(graph_name_bytes);

        // Nodes
        buffer.extend_from_slice(&(self.graph.nodes.len() as u32).to_le_bytes());
        for node in &self.graph.nodes {
            self.serialize_node(node, &mut buffer)?;
        }

        // Initializers
        buffer.extend_from_slice(&(self.graph.initializers.len() as u32).to_le_bytes());
        for tensor in &self.graph.initializers {
            self.serialize_tensor(tensor, &mut buffer)?;
        }

        Ok(buffer)
    }

    /// Deserialize a model from bytes produced by [`ONNXModel::serialize`].
    fn deserialize(buffer: &[u8]) -> Result<Self> {
        let mut offset = 0;

        // Magic number (also guards against buffers too short to slice)
        if buffer.len() < 4 || &buffer[0..4] != b"ONNX" {
            return Err(GhostError::InvalidFormat("Invalid ONNX magic number".to_string()));
        }
        offset += 4;

        // IR version
        let ir_version = i64::from_le_bytes(buffer[offset..offset+8].try_into().unwrap());
        offset += 8;

        // Producer name
        let name_len = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
        offset += 4;
        let producer_name = String::from_utf8(buffer[offset..offset+name_len].to_vec())
            .map_err(|e| GhostError::InvalidFormat(format!("Invalid UTF-8: {}", e)))?;
        offset += name_len;

        // Producer version
        let version_len = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
        offset += 4;
        let producer_version = String::from_utf8(buffer[offset..offset+version_len].to_vec())
            .map_err(|e| GhostError::InvalidFormat(format!("Invalid UTF-8: {}", e)))?;
        offset += version_len;

        // Graph name
        let graph_name_len = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
        offset += 4;
        let graph_name = String::from_utf8(buffer[offset..offset+graph_name_len].to_vec())
            .map_err(|e| GhostError::InvalidFormat(format!("Invalid UTF-8: {}", e)))?;
        offset += graph_name_len;

        // Nodes
        let num_nodes = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
        offset += 4;
        let mut nodes = Vec::new();
        for _ in 0..num_nodes {
            let (node, new_offset) = Self::deserialize_node(buffer, offset)?;
            nodes.push(node);
            offset = new_offset;
        }

        // Initializers
        let num_initializers = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
        offset += 4;
        let mut initializers = Vec::new();
        for _ in 0..num_initializers {
            let (tensor, new_offset) = Self::deserialize_tensor(buffer, offset)?;
            initializers.push(tensor);
            offset = new_offset;
        }

        Ok(Self {
            ir_version,
            producer_name,
            producer_version,
            graph: ONNXGraph {
                name: graph_name,
                nodes,
                // Graph inputs/outputs are not part of the binary format and stay empty.
                inputs: Vec::new(),
                outputs: Vec::new(),
                initializers,
            },
        })
    }

    /// Serialize a single node. Note: attributes are not written by this format.
    fn serialize_node(&self, node: &ONNXNode, buffer: &mut Vec<u8>) -> Result<()> {
        let name_bytes = node.name.as_bytes();
        buffer.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
        buffer.extend_from_slice(name_bytes);

        let op_bytes = node.op_type.as_bytes();
        buffer.extend_from_slice(&(op_bytes.len() as u32).to_le_bytes());
        buffer.extend_from_slice(op_bytes);

        buffer.extend_from_slice(&(node.inputs.len() as u32).to_le_bytes());
        for input in &node.inputs {
            let input_bytes = input.as_bytes();
            buffer.extend_from_slice(&(input_bytes.len() as u32).to_le_bytes());
            buffer.extend_from_slice(input_bytes);
        }

        buffer.extend_from_slice(&(node.outputs.len() as u32).to_le_bytes());
        for output in &node.outputs {
            let output_bytes = output.as_bytes();
            buffer.extend_from_slice(&(output_bytes.len() as u32).to_le_bytes());
            buffer.extend_from_slice(output_bytes);
        }

        Ok(())
    }

    /// Deserialize a node starting at `offset`; returns the node and the new offset.
    fn deserialize_node(buffer: &[u8], mut offset: usize) -> Result<(ONNXNode, usize)> {
        let name_len = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
        offset += 4;
        let name = String::from_utf8(buffer[offset..offset+name_len].to_vec())
            .map_err(|e| GhostError::InvalidFormat(format!("Invalid UTF-8: {}", e)))?;
        offset += name_len;

        let op_len = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
        offset += 4;
        let op_type = String::from_utf8(buffer[offset..offset+op_len].to_vec())
            .map_err(|e| GhostError::InvalidFormat(format!("Invalid UTF-8: {}", e)))?;
        offset += op_len;

        let num_inputs = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
        offset += 4;
        let mut inputs = Vec::new();
        for _ in 0..num_inputs {
            let input_len = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
            offset += 4;
            let input = String::from_utf8(buffer[offset..offset+input_len].to_vec())
                .map_err(|e| GhostError::InvalidFormat(format!("Invalid UTF-8: {}", e)))?;
            offset += input_len;
            inputs.push(input);
        }

        let num_outputs = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
        offset += 4;
        let mut outputs = Vec::new();
        for _ in 0..num_outputs {
            let output_len = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
            offset += 4;
            let output = String::from_utf8(buffer[offset..offset+output_len].to_vec())
                .map_err(|e| GhostError::InvalidFormat(format!("Invalid UTF-8: {}", e)))?;
            offset += output_len;
            outputs.push(output);
        }

        Ok((ONNXNode {
            name,
            op_type,
            inputs,
            outputs,
            // Attributes are not stored in the binary format.
            attributes: HashMap::new(),
        }, offset))
    }

    /// Serialize a tensor: name, dtype tag, shape dims, then raw data bytes.
    fn serialize_tensor(&self, tensor: &ONNXTensor, buffer: &mut Vec<u8>) -> Result<()> {
        let name_bytes = tensor.name.as_bytes();
        buffer.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
        buffer.extend_from_slice(name_bytes);

        buffer.push(tensor.dtype as u8);

        buffer.extend_from_slice(&(tensor.shape.len() as u32).to_le_bytes());
        for dim in &tensor.shape {
            buffer.extend_from_slice(&dim.to_le_bytes());
        }

        buffer.extend_from_slice(&(tensor.data.len() as u32).to_le_bytes());
        buffer.extend_from_slice(&tensor.data);

        Ok(())
    }

    /// Deserialize a tensor starting at `offset`; returns the tensor and the new offset.
    fn deserialize_tensor(buffer: &[u8], mut offset: usize) -> Result<(ONNXTensor, usize)> {
        let name_len = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
        offset += 4;
        let name = String::from_utf8(buffer[offset..offset+name_len].to_vec())
            .map_err(|e| GhostError::InvalidFormat(format!("Invalid UTF-8: {}", e)))?;
        offset += name_len;

        let dtype = match buffer[offset] {
            0 => ONNXDataType::Float32,
            1 => ONNXDataType::Float64,
            2 => ONNXDataType::Int32,
            3 => ONNXDataType::Int64,
            4 => ONNXDataType::Uint8,
            _ => return Err(GhostError::InvalidFormat("Unknown data type".to_string())),
        };
        offset += 1;

        let shape_len = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
        offset += 4;
        let mut shape = Vec::new();
        for _ in 0..shape_len {
            let dim = i64::from_le_bytes(buffer[offset..offset+8].try_into().unwrap());
            offset += 8;
            shape.push(dim);
        }

        let data_len = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
        offset += 4;
        let data = buffer[offset..offset+data_len].to_vec();
        offset += data_len;

        Ok((ONNXTensor {
            name,
            dtype,
            shape,
            data,
        }, offset))
    }

    /// Append a node to the graph.
    pub fn add_node(&mut self, node: ONNXNode) {
        self.graph.nodes.push(node);
    }

    /// Append an initializer (weight) tensor to the graph.
    pub fn add_initializer(&mut self, tensor: ONNXTensor) {
        self.graph.initializers.push(tensor);
    }
}

/// Convert a GhostFlow tensor into an ONNXTensor with Float32 data.
pub fn tensor_to_onnx(name: &str, tensor: &Tensor) -> ONNXTensor {
    let shape: Vec<i64> = tensor.dims().iter().map(|&d| d as i64).collect();
    let data = tensor.data_f32();
    let bytes: Vec<u8> = data.iter()
        .flat_map(|&f| f.to_le_bytes())
        .collect();

    ONNXTensor {
        name: name.to_string(),
        dtype: ONNXDataType::Float32,
        shape,
        data: bytes,
    }
}

/// Convert an ONNXTensor back into a GhostFlow tensor. Only Float32 is supported.
pub fn onnx_to_tensor(onnx_tensor: &ONNXTensor) -> Result<Tensor> {
    if onnx_tensor.dtype != ONNXDataType::Float32 {
        return Err(GhostError::InvalidFormat("Only Float32 supported".to_string()));
    }

    let floats: Vec<f32> = onnx_tensor.data
        .chunks_exact(4)
        .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
        .collect();

    let shape: Vec<usize> = onnx_tensor.shape.iter().map(|&d| d as usize).collect();
    Tensor::from_slice(&floats, &shape)
}
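
// Typical export flow (illustrative sketch; `linear_weight` stands in for a
// hypothetical GhostFlow tensor holding a layer's weights):
//
//     let mut model = ONNXModel::new("mlp");
//     model.add_initializer(tensor_to_onnx("weight", &linear_weight));
//     model.add_node(ONNXNode {
//         name: "linear1".to_string(),
//         op_type: "Gemm".to_string(),
//         inputs: vec!["input".to_string(), "weight".to_string()],
//         outputs: vec!["output".to_string()],
//         attributes: HashMap::new(),
//     });
//     model.save("model.onnx")?;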

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_onnx_model_creation() {
        let model = ONNXModel::new("test_model");
        assert_eq!(model.graph.name, "test_model");
        assert_eq!(model.producer_name, "GhostFlow");
    }

    #[test]
    fn test_onnx_serialization() {
        let mut model = ONNXModel::new("test");

        model.add_node(ONNXNode {
            name: "linear1".to_string(),
            op_type: "Gemm".to_string(),
            inputs: vec!["input".to_string(), "weight".to_string()],
            outputs: vec!["output".to_string()],
            attributes: HashMap::new(),
        });

        let bytes = model.serialize().unwrap();
        let loaded = ONNXModel::deserialize(&bytes).unwrap();

        assert_eq!(loaded.graph.name, "test");
        assert_eq!(loaded.graph.nodes.len(), 1);
        assert_eq!(loaded.graph.nodes[0].op_type, "Gemm");
    }
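
    // Illustrative end-to-end round trip through the filesystem; assumes the
    // process can write to the system temp directory.
    #[test]
    fn test_onnx_save_load_roundtrip() {
        let mut model = ONNXModel::new("roundtrip");
        model.add_initializer(ONNXTensor {
            name: "weight".to_string(),
            dtype: ONNXDataType::Float32,
            shape: vec![2],
            // [1.0f32, 2.0f32] as little-endian bytes
            data: vec![0, 0, 128, 63, 0, 0, 0, 64],
        });

        let path = std::env::temp_dir().join("ghostflow_onnx_roundtrip.bin");
        let path_str = path.to_str().unwrap();
        model.save(path_str).unwrap();
        let loaded = ONNXModel::load(path_str).unwrap();
        std::fs::remove_file(&path).ok();

        assert_eq!(loaded.graph.name, "roundtrip");
        assert_eq!(loaded.graph.initializers.len(), 1);
        assert_eq!(loaded.graph.initializers[0].shape, vec![2]);
        assert_eq!(loaded.graph.initializers[0].data, vec![0, 0, 128, 63, 0, 0, 0, 64]);
    }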

    #[test]
    fn test_tensor_conversion() {
        let tensor = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let onnx_tensor = tensor_to_onnx("test", &tensor);

        assert_eq!(onnx_tensor.name, "test");
        assert_eq!(onnx_tensor.shape, vec![2, 2]);

        let converted = onnx_to_tensor(&onnx_tensor).unwrap();
        assert_eq!(converted.dims(), &[2, 2]);
    }
}