ghostflow_core/
serialize.rs

1//! Model serialization and deserialization
2
3use crate::tensor::Tensor;
4use crate::dtype::DType;
5use crate::error::{GhostError, Result};
6use std::collections::HashMap;
7use std::io::{Read, Write, BufReader, BufWriter};
8use std::fs::File;
9use std::path::Path;
10
11/// Magic number for GhostFlow model files
12const MAGIC: &[u8; 8] = b"GHOSTFLW";
13/// Current format version
14const VERSION: u32 = 1;
15
16/// State dictionary - maps parameter names to tensors
17pub type StateDict = HashMap<String, Tensor>;
18
19/// Save a state dictionary to a file
20pub fn save_state_dict<P: AsRef<Path>>(state_dict: &StateDict, path: P) -> Result<()> {
21    let file = File::create(path)
22        .map_err(|e| GhostError::InvalidOperation(format!("Failed to create file: {}", e)))?;
23    let mut writer = BufWriter::new(file);
24    
25    // Write header
26    writer.write_all(MAGIC)
27        .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
28    writer.write_all(&VERSION.to_le_bytes())
29        .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
30    
31    // Write number of tensors
32    let num_tensors = state_dict.len() as u32;
33    writer.write_all(&num_tensors.to_le_bytes())
34        .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
35    
36    // Write each tensor
37    for (name, tensor) in state_dict {
38        write_tensor(&mut writer, name, tensor)?;
39    }
40    
41    writer.flush()
42        .map_err(|e| GhostError::InvalidOperation(format!("Flush error: {}", e)))?;
43    
44    Ok(())
45}
46
47/// Load a state dictionary from a file
48pub fn load_state_dict<P: AsRef<Path>>(path: P) -> Result<StateDict> {
49    let file = File::open(path)
50        .map_err(|e| GhostError::InvalidOperation(format!("Failed to open file: {}", e)))?;
51    let mut reader = BufReader::new(file);
52    
53    // Read and verify header
54    let mut magic = [0u8; 8];
55    reader.read_exact(&mut magic)
56        .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
57    if &magic != MAGIC {
58        return Err(GhostError::InvalidOperation("Invalid file format".into()));
59    }
60    
61    let mut version_bytes = [0u8; 4];
62    reader.read_exact(&mut version_bytes)
63        .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
64    let version = u32::from_le_bytes(version_bytes);
65    if version > VERSION {
66        return Err(GhostError::InvalidOperation(format!(
67            "Unsupported version: {} (max: {})", version, VERSION
68        )));
69    }
70    
71    // Read number of tensors
72    let mut num_bytes = [0u8; 4];
73    reader.read_exact(&mut num_bytes)
74        .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
75    let num_tensors = u32::from_le_bytes(num_bytes) as usize;
76    
77    // Read tensors
78    let mut state_dict = HashMap::with_capacity(num_tensors);
79    for _ in 0..num_tensors {
80        let (name, tensor) = read_tensor(&mut reader)?;
81        state_dict.insert(name, tensor);
82    }
83    
84    Ok(state_dict)
85}
86
87fn write_tensor<W: Write>(writer: &mut W, name: &str, tensor: &Tensor) -> Result<()> {
88    // Write name length and name
89    let name_bytes = name.as_bytes();
90    let name_len = name_bytes.len() as u32;
91    writer.write_all(&name_len.to_le_bytes())
92        .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
93    writer.write_all(name_bytes)
94        .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
95    
96    // Write dtype
97    let dtype_byte = dtype_to_byte(tensor.dtype());
98    writer.write_all(&[dtype_byte])
99        .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
100    
101    // Write shape
102    let dims = tensor.dims();
103    let ndim = dims.len() as u32;
104    writer.write_all(&ndim.to_le_bytes())
105        .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
106    for &dim in dims {
107        writer.write_all(&(dim as u64).to_le_bytes())
108            .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
109    }
110    
111    // Write data
112    let data = tensor.data_f32();
113    let data_bytes: Vec<u8> = data.iter()
114        .flat_map(|&f| f.to_le_bytes())
115        .collect();
116    writer.write_all(&data_bytes)
117        .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
118    
119    Ok(())
120}
121
122fn read_tensor<R: Read>(reader: &mut R) -> Result<(String, Tensor)> {
123    // Read name
124    let mut name_len_bytes = [0u8; 4];
125    reader.read_exact(&mut name_len_bytes)
126        .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
127    let name_len = u32::from_le_bytes(name_len_bytes) as usize;
128    
129    let mut name_bytes = vec![0u8; name_len];
130    reader.read_exact(&mut name_bytes)
131        .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
132    let name = String::from_utf8(name_bytes)
133        .map_err(|e| GhostError::InvalidOperation(format!("Invalid UTF-8: {}", e)))?;
134    
135    // Read dtype
136    let mut dtype_byte = [0u8; 1];
137    reader.read_exact(&mut dtype_byte)
138        .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
139    let _dtype = byte_to_dtype(dtype_byte[0])?;
140    
141    // Read shape
142    let mut ndim_bytes = [0u8; 4];
143    reader.read_exact(&mut ndim_bytes)
144        .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
145    let ndim = u32::from_le_bytes(ndim_bytes) as usize;
146    
147    let mut dims = Vec::with_capacity(ndim);
148    for _ in 0..ndim {
149        let mut dim_bytes = [0u8; 8];
150        reader.read_exact(&mut dim_bytes)
151            .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
152        dims.push(u64::from_le_bytes(dim_bytes) as usize);
153    }
154    
155    // Read data
156    let numel: usize = dims.iter().product();
157    let mut data_bytes = vec![0u8; numel * 4]; // f32 = 4 bytes
158    reader.read_exact(&mut data_bytes)
159        .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
160    
161    let data: Vec<f32> = data_bytes
162        .chunks_exact(4)
163        .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
164        .collect();
165    
166    let tensor = Tensor::from_slice(&data, &dims)?;
167    
168    Ok((name, tensor))
169}
170
171fn dtype_to_byte(dtype: DType) -> u8 {
172    match dtype {
173        DType::F16 => 0,
174        DType::BF16 => 1,
175        DType::F32 => 2,
176        DType::F64 => 3,
177        DType::I8 => 4,
178        DType::I16 => 5,
179        DType::I32 => 6,
180        DType::I64 => 7,
181        DType::U8 => 8,
182        DType::Bool => 9,
183    }
184}
185
186fn byte_to_dtype(byte: u8) -> Result<DType> {
187    match byte {
188        0 => Ok(DType::F16),
189        1 => Ok(DType::BF16),
190        2 => Ok(DType::F32),
191        3 => Ok(DType::F64),
192        4 => Ok(DType::I8),
193        5 => Ok(DType::I16),
194        6 => Ok(DType::I32),
195        7 => Ok(DType::I64),
196        8 => Ok(DType::U8),
197        9 => Ok(DType::Bool),
198        _ => Err(GhostError::InvalidOperation(format!("Unknown dtype: {}", byte))),
199    }
200}
201
202/// Trait for models that can be serialized
203pub trait Serializable {
204    /// Get state dictionary
205    fn state_dict(&self) -> StateDict;
206    
207    /// Load state dictionary
208    fn load_state_dict(&mut self, state_dict: &StateDict) -> Result<()>;
209    
210    /// Save model to file
211    fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
212        save_state_dict(&self.state_dict(), path)
213    }
214    
215    /// Load model from file
216    fn load<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
217        let state_dict = load_state_dict(path)?;
218        self.load_state_dict(&state_dict)
219    }
220}
221
222/// SafeTensors format support (compatible with HuggingFace)
223pub mod safetensors {
224    use super::*;
225    
226    /// Save in SafeTensors format
227    pub fn save<P: AsRef<Path>>(state_dict: &StateDict, path: P) -> Result<()> {
228        // SafeTensors format:
229        // - 8 bytes: header size (little endian)
230        // - header_size bytes: JSON header
231        // - tensor data
232        
233        let file = File::create(path)
234            .map_err(|e| GhostError::InvalidOperation(format!("Failed to create file: {}", e)))?;
235        let mut writer = BufWriter::new(file);
236        
237        // Build header
238        let mut header = String::from("{");
239        let mut offset = 0usize;
240        let mut tensor_data: Vec<u8> = Vec::new();
241        
242        for (i, (name, tensor)) in state_dict.iter().enumerate() {
243            if i > 0 {
244                header.push(',');
245            }
246            
247            let data = tensor.data_f32();
248            let data_bytes: Vec<u8> = data.iter()
249                .flat_map(|&f| f.to_le_bytes())
250                .collect();
251            let data_len = data_bytes.len();
252            
253            // Add to header
254            header.push_str(&format!(
255                "\"{}\":{{\"dtype\":\"F32\",\"shape\":{:?},\"data_offsets\":[{},{}]}}",
256                name,
257                tensor.dims(),
258                offset,
259                offset + data_len
260            ));
261            
262            tensor_data.extend(data_bytes);
263            offset += data_len;
264        }
265        header.push('}');
266        
267        // Write header size
268        let header_bytes = header.as_bytes();
269        let header_size = header_bytes.len() as u64;
270        writer.write_all(&header_size.to_le_bytes())
271            .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
272        
273        // Write header
274        writer.write_all(header_bytes)
275            .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
276        
277        // Write tensor data
278        writer.write_all(&tensor_data)
279            .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
280        
281        writer.flush()
282            .map_err(|e| GhostError::InvalidOperation(format!("Flush error: {}", e)))?;
283        
284        Ok(())
285    }
286    
287    /// Load from SafeTensors format
288    pub fn load<P: AsRef<Path>>(path: P) -> Result<StateDict> {
289        let file = File::open(path)
290            .map_err(|e| GhostError::InvalidOperation(format!("Failed to open file: {}", e)))?;
291        let mut reader = BufReader::new(file);
292        
293        // Read header size
294        let mut header_size_bytes = [0u8; 8];
295        reader.read_exact(&mut header_size_bytes)
296            .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
297        let header_size = u64::from_le_bytes(header_size_bytes) as usize;
298        
299        // Read header
300        let mut header_bytes = vec![0u8; header_size];
301        reader.read_exact(&mut header_bytes)
302            .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
303        let header = String::from_utf8(header_bytes)
304            .map_err(|e| GhostError::InvalidOperation(format!("Invalid UTF-8: {}", e)))?;
305        
306        // Read all tensor data
307        let mut tensor_data = Vec::new();
308        reader.read_to_end(&mut tensor_data)
309            .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
310        
311        // Parse header (simplified JSON parsing)
312        let state_dict = parse_safetensors_header(&header, &tensor_data)?;
313        
314        Ok(state_dict)
315    }
316    
317    fn parse_safetensors_header(header: &str, data: &[u8]) -> Result<StateDict> {
318        // Very simplified JSON parsing - in production, use serde_json
319        let mut state_dict = HashMap::new();
320        
321        // Remove only the outermost braces
322        let content = header.trim();
323        let content = if content.starts_with('{') && content.ends_with('}') {
324            &content[1..content.len()-1]
325        } else {
326            content
327        };
328        let content = content.trim();
329        
330        if content.is_empty() {
331            return Ok(state_dict);
332        }
333        
334        // Parse: "name":{"dtype":"F32","shape":[2,3],"data_offsets":[0,24]}
335        let mut chars = content.chars().peekable();
336        let mut current_name = String::new();
337        let mut tensor_json = String::new();
338        let mut in_quotes = false;
339        let mut in_name = false;
340        let mut in_value = false;
341        let mut brace_depth = 0;
342        
343        while let Some(ch) = chars.next() {
344            match ch {
345                '"' => {
346                    if in_value {
347                        // Inside value, just add the quote
348                        tensor_json.push(ch);
349                        in_quotes = !in_quotes;
350                    } else {
351                        // Outside value, toggle quotes for name parsing
352                        in_quotes = !in_quotes;
353                        if !in_value && !in_name && !in_quotes {
354                            // Just closed the name
355                            in_name = false;
356                        } else if !in_value && !in_name && in_quotes {
357                            // Starting a name
358                            in_name = true;
359                            current_name.clear();
360                        }
361                    }
362                }
363                ':' if !in_quotes && !in_value => {
364                    // After name, before value
365                    in_name = false;
366                    in_value = true;
367                    tensor_json.clear();
368                    // Skip whitespace
369                    while let Some(&' ') = chars.peek() {
370                        chars.next();
371                    }
372                }
373                '{' if !in_quotes && in_value => {
374                    brace_depth += 1;
375                    tensor_json.push(ch);
376                }
377                '}' => {
378                    if !in_quotes && in_value {
379                        tensor_json.push(ch);
380                        brace_depth -= 1;
381                        if brace_depth == 0 {
382                            // End of tensor value
383                            if let Ok(tensor) = parse_tensor_entry(&current_name, &tensor_json, data) {
384                                state_dict.insert(current_name.clone(), tensor);
385                            }
386                            in_value = false;
387                            current_name.clear();
388                            tensor_json.clear();
389                        }
390                    }
391                }
392                ',' if !in_quotes && !in_value => {
393                    // Between entries
394                    continue;
395                }
396                _ => {
397                    if in_name && in_quotes {
398                        current_name.push(ch);
399                    } else if in_value {
400                        // Include everything in the value, including quotes
401                        tensor_json.push(ch);
402                    }
403                }
404            }
405        }
406        
407        Ok(state_dict)
408    }
409    
410    fn parse_tensor_entry(_name: &str, json: &str, data: &[u8]) -> Result<Tensor> {
411        // Extract shape and offsets from JSON (simplified)
412        // Format: {"dtype":"F32","shape":[2,3],"data_offsets":[0,24]}
413        
414        // Find shape
415        let shape_start = json.find("\"shape\":").ok_or_else(|| 
416            GhostError::InvalidOperation("Missing shape".into()))? + 8;
417        let shape_end = json[shape_start..].find(']').ok_or_else(||
418            GhostError::InvalidOperation("Invalid shape".into()))? + shape_start + 1;
419        let shape_str = &json[shape_start..shape_end];
420        
421        // Parse shape array
422        let shape: Vec<usize> = shape_str
423            .trim_start_matches('[')
424            .trim_end_matches(']')
425            .split(',')
426            .filter_map(|s| s.trim().parse().ok())
427            .collect();
428        
429        // Find data offsets
430        let offsets_start = json.find("\"data_offsets\":").ok_or_else(||
431            GhostError::InvalidOperation("Missing offsets".into()))? + 15;
432        let offsets_end = json[offsets_start..].find(']').ok_or_else(||
433            GhostError::InvalidOperation("Invalid offsets".into()))? + offsets_start + 1;
434        let offsets_str = &json[offsets_start..offsets_end];
435        
436        let offsets: Vec<usize> = offsets_str
437            .trim_start_matches('[')
438            .trim_end_matches(']')
439            .split(',')
440            .filter_map(|s| s.trim().parse().ok())
441            .collect();
442        
443        if offsets.len() != 2 {
444            return Err(GhostError::InvalidOperation("Invalid offsets".into()));
445        }
446        
447        // Extract tensor data
448        let tensor_bytes = &data[offsets[0]..offsets[1]];
449        let tensor_data: Vec<f32> = tensor_bytes
450            .chunks_exact(4)
451            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
452            .collect();
453        
454        Tensor::from_slice(&tensor_data, &shape)
455    }
456}
457
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::PathBuf;

    /// Build a per-process path in the system temp directory, so tests never
    /// write into the crate's working directory.
    fn scratch_path(file_name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("ghostflow_{}_{}", std::process::id(), file_name))
    }

    #[test]
    fn test_save_load_state_dict() {
        let mut state_dict = HashMap::new();
        state_dict.insert("weight".to_string(), Tensor::randn(&[3, 4]));
        state_dict.insert("bias".to_string(), Tensor::zeros(&[4]));

        let path = scratch_path("test_model.gf");
        let outcome = save_state_dict(&state_dict, &path).and_then(|_| load_state_dict(&path));
        // Clean up before asserting so a failure does not leave the file behind.
        fs::remove_file(&path).ok();
        let loaded = outcome.unwrap();

        assert_eq!(loaded.len(), 2);
        assert!(loaded.contains_key("weight"));
        assert!(loaded.contains_key("bias"));
        assert_eq!(loaded["weight"].shape().dims(), &[3, 4]);
        assert_eq!(loaded["bias"].shape().dims(), &[4]);
    }

    #[test]
    fn test_safetensors_save_load() {
        let mut state_dict = HashMap::new();
        state_dict.insert("layer.weight".to_string(), Tensor::randn(&[2, 3]));

        let path = scratch_path("test_model.safetensors");
        let outcome = safetensors::save(&state_dict, &path).and_then(|_| safetensors::load(&path));
        // Clean up before asserting so a failure does not leave the file behind.
        fs::remove_file(&path).ok();
        let loaded = outcome.unwrap();

        assert!(loaded.contains_key("layer.weight"), "Loaded dict should contain layer.weight");
        assert_eq!(loaded["layer.weight"].shape().dims(), &[2, 3]);
    }
}
496}