ghostflow_nn/
onnx.rs

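//! ONNX-style export and import functionality
//!
//! This module saves GhostFlow models to, and loads them from, a simplified
//! ONNX-like binary container (magic bytes `ONNX` followed by little-endian,
//! length-prefixed fields). NOTE: this is not the protobuf `ModelProto` encoding
//! defined by the official ONNX specification, so standard ONNX tooling is not
//! expected to read these files.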
5
6use ghostflow_core::{Result, Tensor, GhostError};
7use std::collections::HashMap;
8use std::fs::File;
9use std::io::{Write, Read};
10
/// Element types supported by GhostFlow's ONNX-style serialization.
///
/// The explicit discriminants match the one-byte data-type tags stored in saved
/// files. Never renumber or reorder existing variants: previously saved files
/// would then decode with the wrong element type. Append new variants with new
/// numbers instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum ONNXDataType {
    /// 32-bit floating point.
    Float32 = 0,
    /// 64-bit floating point.
    Float64 = 1,
    /// Signed 32-bit integer.
    Int32 = 2,
    /// Signed 64-bit integer.
    Int64 = 3,
    /// Unsigned 8-bit integer.
    Uint8 = 4,
}
20
21/// ONNX tensor information
22#[derive(Debug, Clone)]
23pub struct ONNXTensor {
24    pub name: String,
25    pub dtype: ONNXDataType,
26    pub shape: Vec<i64>,
27    pub data: Vec<u8>,
28}
29
30/// ONNX node (operation)
31#[derive(Debug, Clone)]
32pub struct ONNXNode {
33    pub name: String,
34    pub op_type: String,
35    pub inputs: Vec<String>,
36    pub outputs: Vec<String>,
37    pub attributes: HashMap<String, ONNXAttribute>,
38}
39
/// Value of a node attribute: a scalar, a string, or a list.
///
/// Derives `PartialEq` so attribute maps can be compared. Not `Eq`, because of
/// the `f32` payloads.
#[derive(Debug, Clone, PartialEq)]
pub enum ONNXAttribute {
    /// A single integer.
    Int(i64),
    /// A single float.
    Float(f32),
    /// A string value.
    String(String),
    /// A list of integers.
    Ints(Vec<i64>),
    /// A list of floats.
    Floats(Vec<f32>),
}
49
50/// ONNX graph representation
51#[derive(Debug, Clone)]
52pub struct ONNXGraph {
53    pub name: String,
54    pub nodes: Vec<ONNXNode>,
55    pub inputs: Vec<ONNXTensor>,
56    pub outputs: Vec<ONNXTensor>,
57    pub initializers: Vec<ONNXTensor>,
58}
59
60/// ONNX model
61#[derive(Debug, Clone)]
62pub struct ONNXModel {
63    pub ir_version: i64,
64    pub producer_name: String,
65    pub producer_version: String,
66    pub graph: ONNXGraph,
67}
68
69impl ONNXModel {
70    /// Create a new ONNX model
71    pub fn new(name: &str) -> Self {
72        Self {
73            ir_version: 8, // ONNX IR version 8
74            producer_name: "GhostFlow".to_string(),
75            producer_version: env!("CARGO_PKG_VERSION").to_string(),
76            graph: ONNXGraph {
77                name: name.to_string(),
78                nodes: Vec::new(),
79                inputs: Vec::new(),
80                outputs: Vec::new(),
81                initializers: Vec::new(),
82            },
83        }
84    }
85
86    /// Export model to ONNX file
87    pub fn save(&self, path: &str) -> Result<()> {
88        let serialized = self.serialize()?;
89        let mut file = File::create(path)
90            .map_err(|e| GhostError::IOError(format!("Failed to create file: {}", e)))?;
91        file.write_all(&serialized)
92            .map_err(|e| GhostError::IOError(format!("Failed to write file: {}", e)))?;
93        Ok(())
94    }
95
96    /// Load ONNX model from file
97    pub fn load(path: &str) -> Result<Self> {
98        let mut file = File::open(path)
99            .map_err(|e| GhostError::IOError(format!("Failed to open file: {}", e)))?;
100        let mut buffer = Vec::new();
101        file.read_to_end(&mut buffer)
102            .map_err(|e| GhostError::IOError(format!("Failed to read file: {}", e)))?;
103        Self::deserialize(&buffer)
104    }
105
106    /// Serialize to bytes (simplified protobuf-like format)
107    fn serialize(&self) -> Result<Vec<u8>> {
108        let mut buffer = Vec::new();
109        
110        // Magic number for ONNX
111        buffer.extend_from_slice(b"ONNX");
112        
113        // IR version
114        buffer.extend_from_slice(&self.ir_version.to_le_bytes());
115        
116        // Producer name length and data
117        let producer_bytes = self.producer_name.as_bytes();
118        buffer.extend_from_slice(&(producer_bytes.len() as u32).to_le_bytes());
119        buffer.extend_from_slice(producer_bytes);
120        
121        // Producer version length and data
122        let version_bytes = self.producer_version.as_bytes();
123        buffer.extend_from_slice(&(version_bytes.len() as u32).to_le_bytes());
124        buffer.extend_from_slice(version_bytes);
125        
126        // Graph name
127        let graph_name_bytes = self.graph.name.as_bytes();
128        buffer.extend_from_slice(&(graph_name_bytes.len() as u32).to_le_bytes());
129        buffer.extend_from_slice(graph_name_bytes);
130        
131        // Number of nodes
132        buffer.extend_from_slice(&(self.graph.nodes.len() as u32).to_le_bytes());
133        
134        // Serialize nodes
135        for node in &self.graph.nodes {
136            self.serialize_node(node, &mut buffer)?;
137        }
138        
139        // Number of initializers
140        buffer.extend_from_slice(&(self.graph.initializers.len() as u32).to_le_bytes());
141        
142        // Serialize initializers
143        for tensor in &self.graph.initializers {
144            self.serialize_tensor(tensor, &mut buffer)?;
145        }
146        
147        Ok(buffer)
148    }
149
150    /// Deserialize from bytes
151    fn deserialize(buffer: &[u8]) -> Result<Self> {
152        let mut offset = 0;
153        
154        // Check magic number
155        if &buffer[0..4] != b"ONNX" {
156            return Err(GhostError::InvalidFormat("Invalid ONNX magic number".to_string()));
157        }
158        offset += 4;
159        
160        // Read IR version
161        let ir_version = i64::from_le_bytes(buffer[offset..offset+8].try_into().unwrap());
162        offset += 8;
163        
164        // Read producer name
165        let name_len = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
166        offset += 4;
167        let producer_name = String::from_utf8(buffer[offset..offset+name_len].to_vec())
168            .map_err(|e| GhostError::InvalidFormat(format!("Invalid UTF-8: {}", e)))?;
169        offset += name_len;
170        
171        // Read producer version
172        let version_len = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
173        offset += 4;
174        let producer_version = String::from_utf8(buffer[offset..offset+version_len].to_vec())
175            .map_err(|e| GhostError::InvalidFormat(format!("Invalid UTF-8: {}", e)))?;
176        offset += version_len;
177        
178        // Read graph name
179        let graph_name_len = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
180        offset += 4;
181        let graph_name = String::from_utf8(buffer[offset..offset+graph_name_len].to_vec())
182            .map_err(|e| GhostError::InvalidFormat(format!("Invalid UTF-8: {}", e)))?;
183        offset += graph_name_len;
184        
185        // Read nodes
186        let num_nodes = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
187        offset += 4;
188        
189        let mut nodes = Vec::new();
190        for _ in 0..num_nodes {
191            let (node, new_offset) = Self::deserialize_node(buffer, offset)?;
192            nodes.push(node);
193            offset = new_offset;
194        }
195        
196        // Read initializers
197        let num_initializers = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
198        offset += 4;
199        
200        let mut initializers = Vec::new();
201        for _ in 0..num_initializers {
202            let (tensor, new_offset) = Self::deserialize_tensor(buffer, offset)?;
203            initializers.push(tensor);
204            offset = new_offset;
205        }
206        
207        Ok(Self {
208            ir_version,
209            producer_name,
210            producer_version,
211            graph: ONNXGraph {
212                name: graph_name,
213                nodes,
214                inputs: Vec::new(),
215                outputs: Vec::new(),
216                initializers,
217            },
218        })
219    }
220
221    fn serialize_node(&self, node: &ONNXNode, buffer: &mut Vec<u8>) -> Result<()> {
222        // Node name
223        let name_bytes = node.name.as_bytes();
224        buffer.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
225        buffer.extend_from_slice(name_bytes);
226        
227        // Op type
228        let op_bytes = node.op_type.as_bytes();
229        buffer.extend_from_slice(&(op_bytes.len() as u32).to_le_bytes());
230        buffer.extend_from_slice(op_bytes);
231        
232        // Inputs
233        buffer.extend_from_slice(&(node.inputs.len() as u32).to_le_bytes());
234        for input in &node.inputs {
235            let input_bytes = input.as_bytes();
236            buffer.extend_from_slice(&(input_bytes.len() as u32).to_le_bytes());
237            buffer.extend_from_slice(input_bytes);
238        }
239        
240        // Outputs
241        buffer.extend_from_slice(&(node.outputs.len() as u32).to_le_bytes());
242        for output in &node.outputs {
243            let output_bytes = output.as_bytes();
244            buffer.extend_from_slice(&(output_bytes.len() as u32).to_le_bytes());
245            buffer.extend_from_slice(output_bytes);
246        }
247        
248        Ok(())
249    }
250
251    fn deserialize_node(buffer: &[u8], mut offset: usize) -> Result<(ONNXNode, usize)> {
252        // Node name
253        let name_len = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
254        offset += 4;
255        let name = String::from_utf8(buffer[offset..offset+name_len].to_vec())
256            .map_err(|e| GhostError::InvalidFormat(format!("Invalid UTF-8: {}", e)))?;
257        offset += name_len;
258        
259        // Op type
260        let op_len = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
261        offset += 4;
262        let op_type = String::from_utf8(buffer[offset..offset+op_len].to_vec())
263            .map_err(|e| GhostError::InvalidFormat(format!("Invalid UTF-8: {}", e)))?;
264        offset += op_len;
265        
266        // Inputs
267        let num_inputs = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
268        offset += 4;
269        let mut inputs = Vec::new();
270        for _ in 0..num_inputs {
271            let input_len = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
272            offset += 4;
273            let input = String::from_utf8(buffer[offset..offset+input_len].to_vec())
274                .map_err(|e| GhostError::InvalidFormat(format!("Invalid UTF-8: {}", e)))?;
275            offset += input_len;
276            inputs.push(input);
277        }
278        
279        // Outputs
280        let num_outputs = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
281        offset += 4;
282        let mut outputs = Vec::new();
283        for _ in 0..num_outputs {
284            let output_len = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
285            offset += 4;
286            let output = String::from_utf8(buffer[offset..offset+output_len].to_vec())
287                .map_err(|e| GhostError::InvalidFormat(format!("Invalid UTF-8: {}", e)))?;
288            offset += output_len;
289            outputs.push(output);
290        }
291        
292        Ok((ONNXNode {
293            name,
294            op_type,
295            inputs,
296            outputs,
297            attributes: HashMap::new(),
298        }, offset))
299    }
300
301    fn serialize_tensor(&self, tensor: &ONNXTensor, buffer: &mut Vec<u8>) -> Result<()> {
302        // Tensor name
303        let name_bytes = tensor.name.as_bytes();
304        buffer.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
305        buffer.extend_from_slice(name_bytes);
306        
307        // Data type
308        buffer.push(tensor.dtype as u8);
309        
310        // Shape
311        buffer.extend_from_slice(&(tensor.shape.len() as u32).to_le_bytes());
312        for dim in &tensor.shape {
313            buffer.extend_from_slice(&dim.to_le_bytes());
314        }
315        
316        // Data
317        buffer.extend_from_slice(&(tensor.data.len() as u32).to_le_bytes());
318        buffer.extend_from_slice(&tensor.data);
319        
320        Ok(())
321    }
322
323    fn deserialize_tensor(buffer: &[u8], mut offset: usize) -> Result<(ONNXTensor, usize)> {
324        // Tensor name
325        let name_len = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
326        offset += 4;
327        let name = String::from_utf8(buffer[offset..offset+name_len].to_vec())
328            .map_err(|e| GhostError::InvalidFormat(format!("Invalid UTF-8: {}", e)))?;
329        offset += name_len;
330        
331        // Data type
332        let dtype = match buffer[offset] {
333            0 => ONNXDataType::Float32,
334            1 => ONNXDataType::Float64,
335            2 => ONNXDataType::Int32,
336            3 => ONNXDataType::Int64,
337            4 => ONNXDataType::Uint8,
338            _ => return Err(GhostError::InvalidFormat("Unknown data type".to_string())),
339        };
340        offset += 1;
341        
342        // Shape
343        let shape_len = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
344        offset += 4;
345        let mut shape = Vec::new();
346        for _ in 0..shape_len {
347            let dim = i64::from_le_bytes(buffer[offset..offset+8].try_into().unwrap());
348            offset += 8;
349            shape.push(dim);
350        }
351        
352        // Data
353        let data_len = u32::from_le_bytes(buffer[offset..offset+4].try_into().unwrap()) as usize;
354        offset += 4;
355        let data = buffer[offset..offset+data_len].to_vec();
356        offset += data_len;
357        
358        Ok((ONNXTensor {
359            name,
360            dtype,
361            shape,
362            data,
363        }, offset))
364    }
365
366    /// Add a node to the graph
367    pub fn add_node(&mut self, node: ONNXNode) {
368        self.graph.nodes.push(node);
369    }
370
371    /// Add an initializer (weight tensor)
372    pub fn add_initializer(&mut self, tensor: ONNXTensor) {
373        self.graph.initializers.push(tensor);
374    }
375}
376
377/// Helper to convert GhostFlow tensor to ONNX tensor
378pub fn tensor_to_onnx(name: &str, tensor: &Tensor) -> ONNXTensor {
379    let shape: Vec<i64> = tensor.dims().iter().map(|&d| d as i64).collect();
380    let data = tensor.data_f32();
381    let bytes: Vec<u8> = data.iter()
382        .flat_map(|&f| f.to_le_bytes())
383        .collect();
384    
385    ONNXTensor {
386        name: name.to_string(),
387        dtype: ONNXDataType::Float32,
388        shape,
389        data: bytes,
390    }
391}
392
393/// Helper to convert ONNX tensor to GhostFlow tensor
394pub fn onnx_to_tensor(onnx_tensor: &ONNXTensor) -> Result<Tensor> {
395    if onnx_tensor.dtype != ONNXDataType::Float32 {
396        return Err(GhostError::InvalidFormat("Only Float32 supported".to_string()));
397    }
398    
399    let floats: Vec<f32> = onnx_tensor.data
400        .chunks_exact(4)
401        .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
402        .collect();
403    
404    let shape: Vec<usize> = onnx_tensor.shape.iter().map(|&d| d as usize).collect();
405    Tensor::from_slice(&floats, &shape)
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_onnx_model_creation() {
        let model = ONNXModel::new("test_model");
        assert_eq!(model.graph.name, "test_model");
        assert_eq!(model.producer_name, "GhostFlow");
    }

    #[test]
    fn test_onnx_serialization() {
        let mut model = ONNXModel::new("test");

        // Add a simple node
        model.add_node(ONNXNode {
            name: "linear1".to_string(),
            op_type: "Gemm".to_string(),
            inputs: vec!["input".to_string(), "weight".to_string()],
            outputs: vec!["output".to_string()],
            attributes: HashMap::new(),
        });

        // Serialize and deserialize
        let bytes = model.serialize().unwrap();
        let loaded = ONNXModel::deserialize(&bytes).unwrap();

        assert_eq!(loaded.graph.name, "test");
        assert_eq!(loaded.graph.nodes.len(), 1);
        let node = &loaded.graph.nodes[0];
        assert_eq!(node.name, "linear1");
        assert_eq!(node.op_type, "Gemm");
        assert_eq!(node.inputs, vec!["input".to_string(), "weight".to_string()]);
        assert_eq!(node.outputs, vec!["output".to_string()]);
    }

    #[test]
    fn test_initializer_round_trip() {
        let mut model = ONNXModel::new("weights");
        let weight = Tensor::from_slice(&[0.5f32, -1.0, 2.25], &[3, 1]).unwrap();
        model.add_initializer(tensor_to_onnx("w", &weight));

        let loaded = ONNXModel::deserialize(&model.serialize().unwrap()).unwrap();

        assert_eq!(loaded.ir_version, model.ir_version);
        assert_eq!(loaded.producer_name, model.producer_name);
        assert_eq!(loaded.producer_version, model.producer_version);
        assert_eq!(loaded.graph.initializers.len(), 1);
        let init = &loaded.graph.initializers[0];
        assert_eq!(init.name, "w");
        assert_eq!(init.dtype, ONNXDataType::Float32);
        assert_eq!(init.shape, vec![3, 1]);
        assert_eq!(init.data, model.graph.initializers[0].data);
    }

    #[test]
    fn test_rejects_bad_magic() {
        assert!(ONNXModel::deserialize(b"NOPE").is_err());
    }

    #[test]
    fn test_tensor_conversion() {
        let tensor = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let onnx_tensor = tensor_to_onnx("test", &tensor);

        assert_eq!(onnx_tensor.name, "test");
        assert_eq!(onnx_tensor.shape, vec![2, 2]);
        assert_eq!(onnx_tensor.data.len(), 16);

        let converted = onnx_to_tensor(&onnx_tensor).unwrap();
        assert_eq!(converted.dims(), &[2, 2]);
        let values: Vec<f32> = converted.data_f32().iter().copied().collect();
        assert_eq!(values, vec![1.0, 2.0, 3.0, 4.0]);
    }
}
453