1use crate::tensor::Tensor;
4use crate::dtype::DType;
5use crate::error::{GhostError, Result};
6use std::collections::HashMap;
7use std::io::{Read, Write, BufReader, BufWriter};
8use std::fs::File;
9use std::path::Path;
10
/// File magic identifying a GhostFlow checkpoint (exactly 8 bytes).
const MAGIC: &[u8; 8] = b"GHOSTFLW";
/// Format version written to new files; `load_state_dict` rejects anything newer.
const VERSION: u32 = 1;

/// Flat mapping from parameter name to tensor, as produced by `Serializable::state_dict`.
pub type StateDict = HashMap<String, Tensor>;
18
19pub fn save_state_dict<P: AsRef<Path>>(state_dict: &StateDict, path: P) -> Result<()> {
21 let file = File::create(path)
22 .map_err(|e| GhostError::InvalidOperation(format!("Failed to create file: {}", e)))?;
23 let mut writer = BufWriter::new(file);
24
25 writer.write_all(MAGIC)
27 .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
28 writer.write_all(&VERSION.to_le_bytes())
29 .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
30
31 let num_tensors = state_dict.len() as u32;
33 writer.write_all(&num_tensors.to_le_bytes())
34 .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
35
36 for (name, tensor) in state_dict {
38 write_tensor(&mut writer, name, tensor)?;
39 }
40
41 writer.flush()
42 .map_err(|e| GhostError::InvalidOperation(format!("Flush error: {}", e)))?;
43
44 Ok(())
45}
46
47pub fn load_state_dict<P: AsRef<Path>>(path: P) -> Result<StateDict> {
49 let file = File::open(path)
50 .map_err(|e| GhostError::InvalidOperation(format!("Failed to open file: {}", e)))?;
51 let mut reader = BufReader::new(file);
52
53 let mut magic = [0u8; 8];
55 reader.read_exact(&mut magic)
56 .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
57 if &magic != MAGIC {
58 return Err(GhostError::InvalidOperation("Invalid file format".into()));
59 }
60
61 let mut version_bytes = [0u8; 4];
62 reader.read_exact(&mut version_bytes)
63 .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
64 let version = u32::from_le_bytes(version_bytes);
65 if version > VERSION {
66 return Err(GhostError::InvalidOperation(format!(
67 "Unsupported version: {} (max: {})", version, VERSION
68 )));
69 }
70
71 let mut num_bytes = [0u8; 4];
73 reader.read_exact(&mut num_bytes)
74 .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
75 let num_tensors = u32::from_le_bytes(num_bytes) as usize;
76
77 let mut state_dict = HashMap::with_capacity(num_tensors);
79 for _ in 0..num_tensors {
80 let (name, tensor) = read_tensor(&mut reader)?;
81 state_dict.insert(name, tensor);
82 }
83
84 Ok(state_dict)
85}
86
87fn write_tensor<W: Write>(writer: &mut W, name: &str, tensor: &Tensor) -> Result<()> {
88 let name_bytes = name.as_bytes();
90 let name_len = name_bytes.len() as u32;
91 writer.write_all(&name_len.to_le_bytes())
92 .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
93 writer.write_all(name_bytes)
94 .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
95
96 let dtype_byte = dtype_to_byte(tensor.dtype());
98 writer.write_all(&[dtype_byte])
99 .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
100
101 let dims = tensor.dims();
103 let ndim = dims.len() as u32;
104 writer.write_all(&ndim.to_le_bytes())
105 .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
106 for &dim in dims {
107 writer.write_all(&(dim as u64).to_le_bytes())
108 .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
109 }
110
111 let data = tensor.data_f32();
113 let data_bytes: Vec<u8> = data.iter()
114 .flat_map(|&f| f.to_le_bytes())
115 .collect();
116 writer.write_all(&data_bytes)
117 .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
118
119 Ok(())
120}
121
122fn read_tensor<R: Read>(reader: &mut R) -> Result<(String, Tensor)> {
123 let mut name_len_bytes = [0u8; 4];
125 reader.read_exact(&mut name_len_bytes)
126 .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
127 let name_len = u32::from_le_bytes(name_len_bytes) as usize;
128
129 let mut name_bytes = vec![0u8; name_len];
130 reader.read_exact(&mut name_bytes)
131 .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
132 let name = String::from_utf8(name_bytes)
133 .map_err(|e| GhostError::InvalidOperation(format!("Invalid UTF-8: {}", e)))?;
134
135 let mut dtype_byte = [0u8; 1];
137 reader.read_exact(&mut dtype_byte)
138 .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
139 let _dtype = byte_to_dtype(dtype_byte[0])?;
140
141 let mut ndim_bytes = [0u8; 4];
143 reader.read_exact(&mut ndim_bytes)
144 .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
145 let ndim = u32::from_le_bytes(ndim_bytes) as usize;
146
147 let mut dims = Vec::with_capacity(ndim);
148 for _ in 0..ndim {
149 let mut dim_bytes = [0u8; 8];
150 reader.read_exact(&mut dim_bytes)
151 .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
152 dims.push(u64::from_le_bytes(dim_bytes) as usize);
153 }
154
155 let numel: usize = dims.iter().product();
157 let mut data_bytes = vec![0u8; numel * 4]; reader.read_exact(&mut data_bytes)
159 .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
160
161 let data: Vec<f32> = data_bytes
162 .chunks_exact(4)
163 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
164 .collect();
165
166 let tensor = Tensor::from_slice(&data, &dims)?;
167
168 Ok((name, tensor))
169}
170
171fn dtype_to_byte(dtype: DType) -> u8 {
172 match dtype {
173 DType::F16 => 0,
174 DType::BF16 => 1,
175 DType::F32 => 2,
176 DType::F64 => 3,
177 DType::I8 => 4,
178 DType::I16 => 5,
179 DType::I32 => 6,
180 DType::I64 => 7,
181 DType::U8 => 8,
182 DType::Bool => 9,
183 }
184}
185
186fn byte_to_dtype(byte: u8) -> Result<DType> {
187 match byte {
188 0 => Ok(DType::F16),
189 1 => Ok(DType::BF16),
190 2 => Ok(DType::F32),
191 3 => Ok(DType::F64),
192 4 => Ok(DType::I8),
193 5 => Ok(DType::I16),
194 6 => Ok(DType::I32),
195 7 => Ok(DType::I64),
196 8 => Ok(DType::U8),
197 9 => Ok(DType::Bool),
198 _ => Err(GhostError::InvalidOperation(format!("Unknown dtype: {}", byte))),
199 }
200}
201
202pub trait Serializable {
204 fn state_dict(&self) -> StateDict;
206
207 fn load_state_dict(&mut self, state_dict: &StateDict) -> Result<()>;
209
210 fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
212 save_state_dict(&self.state_dict(), path)
213 }
214
215 fn load<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
217 let state_dict = load_state_dict(path)?;
218 self.load_state_dict(&state_dict)
219 }
220}
221
222pub mod safetensors {
224 use super::*;
225
226 pub fn save<P: AsRef<Path>>(state_dict: &StateDict, path: P) -> Result<()> {
228 let file = File::create(path)
234 .map_err(|e| GhostError::InvalidOperation(format!("Failed to create file: {}", e)))?;
235 let mut writer = BufWriter::new(file);
236
237 let mut header = String::from("{");
239 let mut offset = 0usize;
240 let mut tensor_data: Vec<u8> = Vec::new();
241
242 for (i, (name, tensor)) in state_dict.iter().enumerate() {
243 if i > 0 {
244 header.push(',');
245 }
246
247 let data = tensor.data_f32();
248 let data_bytes: Vec<u8> = data.iter()
249 .flat_map(|&f| f.to_le_bytes())
250 .collect();
251 let data_len = data_bytes.len();
252
253 header.push_str(&format!(
255 "\"{}\":{{\"dtype\":\"F32\",\"shape\":{:?},\"data_offsets\":[{},{}]}}",
256 name,
257 tensor.dims(),
258 offset,
259 offset + data_len
260 ));
261
262 tensor_data.extend(data_bytes);
263 offset += data_len;
264 }
265 header.push('}');
266
267 let header_bytes = header.as_bytes();
269 let header_size = header_bytes.len() as u64;
270 writer.write_all(&header_size.to_le_bytes())
271 .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
272
273 writer.write_all(header_bytes)
275 .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
276
277 writer.write_all(&tensor_data)
279 .map_err(|e| GhostError::InvalidOperation(format!("Write error: {}", e)))?;
280
281 writer.flush()
282 .map_err(|e| GhostError::InvalidOperation(format!("Flush error: {}", e)))?;
283
284 Ok(())
285 }
286
287 pub fn load<P: AsRef<Path>>(path: P) -> Result<StateDict> {
289 let file = File::open(path)
290 .map_err(|e| GhostError::InvalidOperation(format!("Failed to open file: {}", e)))?;
291 let mut reader = BufReader::new(file);
292
293 let mut header_size_bytes = [0u8; 8];
295 reader.read_exact(&mut header_size_bytes)
296 .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
297 let header_size = u64::from_le_bytes(header_size_bytes) as usize;
298
299 let mut header_bytes = vec![0u8; header_size];
301 reader.read_exact(&mut header_bytes)
302 .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
303 let header = String::from_utf8(header_bytes)
304 .map_err(|e| GhostError::InvalidOperation(format!("Invalid UTF-8: {}", e)))?;
305
306 let mut tensor_data = Vec::new();
308 reader.read_to_end(&mut tensor_data)
309 .map_err(|e| GhostError::InvalidOperation(format!("Read error: {}", e)))?;
310
311 let state_dict = parse_safetensors_header(&header, &tensor_data)?;
313
314 Ok(state_dict)
315 }
316
317 fn parse_safetensors_header(header: &str, data: &[u8]) -> Result<StateDict> {
318 let mut state_dict = HashMap::new();
320
321 let content = header.trim();
323 let content = if content.starts_with('{') && content.ends_with('}') {
324 &content[1..content.len()-1]
325 } else {
326 content
327 };
328 let content = content.trim();
329
330 if content.is_empty() {
331 return Ok(state_dict);
332 }
333
334 let mut chars = content.chars().peekable();
336 let mut current_name = String::new();
337 let mut tensor_json = String::new();
338 let mut in_quotes = false;
339 let mut in_name = false;
340 let mut in_value = false;
341 let mut brace_depth = 0;
342
343 while let Some(ch) = chars.next() {
344 match ch {
345 '"' => {
346 if in_value {
347 tensor_json.push(ch);
349 in_quotes = !in_quotes;
350 } else {
351 in_quotes = !in_quotes;
353 if !in_value && !in_name && !in_quotes {
354 in_name = false;
356 } else if !in_value && !in_name && in_quotes {
357 in_name = true;
359 current_name.clear();
360 }
361 }
362 }
363 ':' if !in_quotes && !in_value => {
364 in_name = false;
366 in_value = true;
367 tensor_json.clear();
368 while let Some(&' ') = chars.peek() {
370 chars.next();
371 }
372 }
373 '{' if !in_quotes && in_value => {
374 brace_depth += 1;
375 tensor_json.push(ch);
376 }
377 '}' => {
378 if !in_quotes && in_value {
379 tensor_json.push(ch);
380 brace_depth -= 1;
381 if brace_depth == 0 {
382 if let Ok(tensor) = parse_tensor_entry(¤t_name, &tensor_json, data) {
384 state_dict.insert(current_name.clone(), tensor);
385 }
386 in_value = false;
387 current_name.clear();
388 tensor_json.clear();
389 }
390 }
391 }
392 ',' if !in_quotes && !in_value => {
393 continue;
395 }
396 _ => {
397 if in_name && in_quotes {
398 current_name.push(ch);
399 } else if in_value {
400 tensor_json.push(ch);
402 }
403 }
404 }
405 }
406
407 Ok(state_dict)
408 }
409
410 fn parse_tensor_entry(_name: &str, json: &str, data: &[u8]) -> Result<Tensor> {
411 let shape_start = json.find("\"shape\":").ok_or_else(||
416 GhostError::InvalidOperation("Missing shape".into()))? + 8;
417 let shape_end = json[shape_start..].find(']').ok_or_else(||
418 GhostError::InvalidOperation("Invalid shape".into()))? + shape_start + 1;
419 let shape_str = &json[shape_start..shape_end];
420
421 let shape: Vec<usize> = shape_str
423 .trim_start_matches('[')
424 .trim_end_matches(']')
425 .split(',')
426 .filter_map(|s| s.trim().parse().ok())
427 .collect();
428
429 let offsets_start = json.find("\"data_offsets\":").ok_or_else(||
431 GhostError::InvalidOperation("Missing offsets".into()))? + 15;
432 let offsets_end = json[offsets_start..].find(']').ok_or_else(||
433 GhostError::InvalidOperation("Invalid offsets".into()))? + offsets_start + 1;
434 let offsets_str = &json[offsets_start..offsets_end];
435
436 let offsets: Vec<usize> = offsets_str
437 .trim_start_matches('[')
438 .trim_end_matches(']')
439 .split(',')
440 .filter_map(|s| s.trim().parse().ok())
441 .collect();
442
443 if offsets.len() != 2 {
444 return Err(GhostError::InvalidOperation("Invalid offsets".into()));
445 }
446
447 let tensor_bytes = &data[offsets[0]..offsets[1]];
449 let tensor_data: Vec<f32> = tensor_bytes
450 .chunks_exact(4)
451 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
452 .collect();
453
454 Tensor::from_slice(&tensor_data, &shape)
455 }
456}
457
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    // Round-trips two tensors through the native GhostFlow binary format.
    #[test]
    fn test_save_load_state_dict() {
        let dict: StateDict = HashMap::from([
            ("weight".to_string(), Tensor::randn(&[3, 4])),
            ("bias".to_string(), Tensor::zeros(&[4])),
        ]);

        let path = "test_model.gf";
        save_state_dict(&dict, path).unwrap();

        let restored = load_state_dict(path).unwrap();

        assert_eq!(restored.len(), 2);
        assert!(restored.contains_key("weight"));
        assert!(restored.contains_key("bias"));

        fs::remove_file(path).ok();
    }

    // Round-trips one tensor through the safetensors writer/reader and
    // checks the shape survives.
    #[test]
    fn test_safetensors_save_load() {
        let mut dict: StateDict = HashMap::new();
        dict.insert("layer.weight".to_string(), Tensor::randn(&[2, 3]));

        let path = "test_model.safetensors";
        safetensors::save(&dict, path).unwrap();

        let restored = safetensors::load(path).unwrap();

        assert!(restored.contains_key("layer.weight"), "Loaded dict should contain layer.weight");
        assert_eq!(restored["layer.weight"].shape().dims(), &[2, 3]);

        fs::remove_file(path).ok();
    }
}