1use ghostflow_core::tensor::Tensor;
6use std::collections::HashMap;
7use std::fs::File;
8use std::io::{Read, Write, BufReader, BufWriter};
9use std::path::Path;
10
11#[derive(Clone, Debug)]
13pub struct ModelCheckpoint {
14 pub parameters: HashMap<String, Tensor>,
16 pub metadata: ModelMetadata,
18 pub optimizer_state: Option<HashMap<String, Vec<f32>>>,
20}
21
/// Descriptive information stored alongside a model's parameters.
#[derive(Debug, Clone)]
pub struct ModelMetadata {
    /// Name of the model.
    pub name: String,
    /// Version string of the model.
    pub version: String,
    /// Version string of the framework associated with the model.
    pub framework_version: String,
    /// Epoch counter recorded with the checkpoint.
    pub epoch: usize,
    /// Loss value recorded with the checkpoint.
    pub loss: f32,
    /// Free-form key/value pairs.
    pub extra: HashMap<String, String>,
}
38
39impl Default for ModelMetadata {
40 fn default() -> Self {
41 Self {
42 name: "ghostflow_model".to_string(),
43 version: "1.0.0".to_string(),
44 framework_version: env!("CARGO_PKG_VERSION").to_string(),
45 epoch: 0,
46 loss: 0.0,
47 extra: HashMap::new(),
48 }
49 }
50}
51
52impl ModelCheckpoint {
53 pub fn new(parameters: HashMap<String, Tensor>) -> Self {
55 Self {
56 parameters,
57 metadata: ModelMetadata::default(),
58 optimizer_state: None,
59 }
60 }
61
62 pub fn with_metadata(mut self, metadata: ModelMetadata) -> Self {
64 self.metadata = metadata;
65 self
66 }
67
68 pub fn with_optimizer_state(mut self, state: HashMap<String, Vec<f32>>) -> Self {
70 self.optimizer_state = Some(state);
71 self
72 }
73
74 pub fn save<P: AsRef<Path>>(&self, path: P) -> std::io::Result<()> {
76 let file = File::create(path)?;
77 let mut writer = BufWriter::new(file);
78
79 writer.write_all(b"GFCP")?; writer.write_all(&[0, 4, 0])?; self.write_metadata(&mut writer)?;
87
88 self.write_parameters(&mut writer)?;
90
91 if let Some(ref state) = self.optimizer_state {
93 writer.write_all(&[1])?; self.write_optimizer_state(&mut writer, state)?;
95 } else {
96 writer.write_all(&[0])?; }
98
99 writer.flush()?;
100 Ok(())
101 }
102
103 pub fn load<P: AsRef<Path>>(path: P) -> std::io::Result<Self> {
105 let file = File::open(path)?;
106 let mut reader = BufReader::new(file);
107
108 let mut magic = [0u8; 4];
110 reader.read_exact(&mut magic)?;
111 if &magic != b"GFCP" {
112 return Err(std::io::Error::new(
113 std::io::ErrorKind::InvalidData,
114 "Invalid checkpoint file format",
115 ));
116 }
117
118 let mut version = [0u8; 3];
120 reader.read_exact(&mut version)?;
121
122 let metadata = Self::read_metadata(&mut reader)?;
124
125 let parameters = Self::read_parameters(&mut reader)?;
127
128 let mut has_optimizer = [0u8; 1];
130 reader.read_exact(&mut has_optimizer)?;
131 let optimizer_state = if has_optimizer[0] == 1 {
132 Some(Self::read_optimizer_state(&mut reader)?)
133 } else {
134 None
135 };
136
137 Ok(Self {
138 parameters,
139 metadata,
140 optimizer_state,
141 })
142 }
143
144 fn write_metadata<W: Write>(&self, writer: &mut W) -> std::io::Result<()> {
145 self.write_string(writer, &self.metadata.name)?;
147 self.write_string(writer, &self.metadata.version)?;
149 self.write_string(writer, &self.metadata.framework_version)?;
151 writer.write_all(&self.metadata.epoch.to_le_bytes())?;
153 writer.write_all(&self.metadata.loss.to_le_bytes())?;
155 writer.write_all(&(self.metadata.extra.len() as u32).to_le_bytes())?;
157 for (key, value) in &self.metadata.extra {
158 self.write_string(writer, key)?;
159 self.write_string(writer, value)?;
160 }
161 Ok(())
162 }
163
164 fn read_metadata<R: Read>(reader: &mut R) -> std::io::Result<ModelMetadata> {
165 let name = Self::read_string(reader)?;
166 let version = Self::read_string(reader)?;
167 let framework_version = Self::read_string(reader)?;
168
169 let mut epoch_bytes = [0u8; 8];
170 reader.read_exact(&mut epoch_bytes)?;
171 let epoch = usize::from_le_bytes(epoch_bytes);
172
173 let mut loss_bytes = [0u8; 4];
174 reader.read_exact(&mut loss_bytes)?;
175 let loss = f32::from_le_bytes(loss_bytes);
176
177 let mut count_bytes = [0u8; 4];
178 reader.read_exact(&mut count_bytes)?;
179 let count = u32::from_le_bytes(count_bytes) as usize;
180
181 let mut extra = HashMap::new();
182 for _ in 0..count {
183 let key = Self::read_string(reader)?;
184 let value = Self::read_string(reader)?;
185 extra.insert(key, value);
186 }
187
188 Ok(ModelMetadata {
189 name,
190 version,
191 framework_version,
192 epoch,
193 loss,
194 extra,
195 })
196 }
197
198 fn write_parameters<W: Write>(&self, writer: &mut W) -> std::io::Result<()> {
199 writer.write_all(&(self.parameters.len() as u32).to_le_bytes())?;
201
202 for (name, tensor) in &self.parameters {
203 self.write_string(writer, name)?;
205
206 let shape = tensor.shape().dims();
208 writer.write_all(&(shape.len() as u32).to_le_bytes())?;
209 for &dim in shape {
210 writer.write_all(&(dim as u64).to_le_bytes())?;
211 }
212
213 let data = tensor.storage().as_slice::<f32>();
215 writer.write_all(&(data.len() as u64).to_le_bytes())?;
216 for &value in data.iter() {
217 writer.write_all(&value.to_le_bytes())?;
218 }
219 }
220
221 Ok(())
222 }
223
224 fn read_parameters<R: Read>(reader: &mut R) -> std::io::Result<HashMap<String, Tensor>> {
225 let mut count_bytes = [0u8; 4];
226 reader.read_exact(&mut count_bytes)?;
227 let count = u32::from_le_bytes(count_bytes) as usize;
228
229 let mut parameters = HashMap::new();
230
231 for _ in 0..count {
232 let name = Self::read_string(reader)?;
234
235 let mut shape_len_bytes = [0u8; 4];
237 reader.read_exact(&mut shape_len_bytes)?;
238 let shape_len = u32::from_le_bytes(shape_len_bytes) as usize;
239
240 let mut shape = Vec::with_capacity(shape_len);
241 for _ in 0..shape_len {
242 let mut dim_bytes = [0u8; 8];
243 reader.read_exact(&mut dim_bytes)?;
244 shape.push(u64::from_le_bytes(dim_bytes) as usize);
245 }
246
247 let mut data_len_bytes = [0u8; 8];
249 reader.read_exact(&mut data_len_bytes)?;
250 let data_len = u64::from_le_bytes(data_len_bytes) as usize;
251
252 let mut data = Vec::with_capacity(data_len);
253 for _ in 0..data_len {
254 let mut value_bytes = [0u8; 4];
255 reader.read_exact(&mut value_bytes)?;
256 data.push(f32::from_le_bytes(value_bytes));
257 }
258
259 let tensor = Tensor::from_slice(&data, &shape)
260 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
261
262 parameters.insert(name, tensor);
263 }
264
265 Ok(parameters)
266 }
267
268 fn write_optimizer_state<W: Write>(
269 &self,
270 writer: &mut W,
271 state: &HashMap<String, Vec<f32>>,
272 ) -> std::io::Result<()> {
273 writer.write_all(&(state.len() as u32).to_le_bytes())?;
274
275 for (name, values) in state {
276 self.write_string(writer, name)?;
277 writer.write_all(&(values.len() as u64).to_le_bytes())?;
278 for &value in values {
279 writer.write_all(&value.to_le_bytes())?;
280 }
281 }
282
283 Ok(())
284 }
285
286 fn read_optimizer_state<R: Read>(reader: &mut R) -> std::io::Result<HashMap<String, Vec<f32>>> {
287 let mut count_bytes = [0u8; 4];
288 reader.read_exact(&mut count_bytes)?;
289 let count = u32::from_le_bytes(count_bytes) as usize;
290
291 let mut state = HashMap::new();
292
293 for _ in 0..count {
294 let name = Self::read_string(reader)?;
295
296 let mut len_bytes = [0u8; 8];
297 reader.read_exact(&mut len_bytes)?;
298 let len = u64::from_le_bytes(len_bytes) as usize;
299
300 let mut values = Vec::with_capacity(len);
301 for _ in 0..len {
302 let mut value_bytes = [0u8; 4];
303 reader.read_exact(&mut value_bytes)?;
304 values.push(f32::from_le_bytes(value_bytes));
305 }
306
307 state.insert(name, values);
308 }
309
310 Ok(state)
311 }
312
313 fn write_string<W: Write>(&self, writer: &mut W, s: &str) -> std::io::Result<()> {
314 let bytes = s.as_bytes();
315 writer.write_all(&(bytes.len() as u32).to_le_bytes())?;
316 writer.write_all(bytes)?;
317 Ok(())
318 }
319
320 fn read_string<R: Read>(reader: &mut R) -> std::io::Result<String> {
321 let mut len_bytes = [0u8; 4];
322 reader.read_exact(&mut len_bytes)?;
323 let len = u32::from_le_bytes(len_bytes) as usize;
324
325 let mut bytes = vec![0u8; len];
326 reader.read_exact(&mut bytes)?;
327
328 String::from_utf8(bytes)
329 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
330 }
331}
332
333pub fn save_model<P: AsRef<Path>>(
335 path: P,
336 parameters: HashMap<String, Tensor>,
337) -> std::io::Result<()> {
338 let checkpoint = ModelCheckpoint::new(parameters);
339 checkpoint.save(path)
340}
341
342pub fn load_model<P: AsRef<Path>>(path: P) -> std::io::Result<HashMap<String, Tensor>> {
343 let checkpoint = ModelCheckpoint::load(path)?;
344 Ok(checkpoint.parameters)
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Checkpoint file in the system temp directory that is removed when the
    /// value is dropped, so a failing assertion does not leave files behind
    /// and nothing is written into the source tree.
    struct ScratchFile(std::path::PathBuf);

    impl ScratchFile {
        /// Builds a path that is unique per test process and per `name`.
        fn new(name: &str) -> Self {
            let file_name = format!("ghostflow_{}_{}", std::process::id(), name);
            ScratchFile(std::env::temp_dir().join(file_name))
        }

        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for ScratchFile {
        fn drop(&mut self) {
            // Best effort: the file may not exist if saving failed.
            let _ = fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_save_load_checkpoint() {
        let mut parameters = HashMap::new();
        parameters.insert(
            "weight".to_string(),
            Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
        );
        parameters.insert(
            "bias".to_string(),
            Tensor::from_slice(&[0.5f32, 0.6], &[2]).unwrap(),
        );

        let checkpoint = ModelCheckpoint::new(parameters);

        let scratch = ScratchFile::new("test_checkpoint.gfcp");
        checkpoint.save(scratch.path()).unwrap();

        let loaded = ModelCheckpoint::load(scratch.path()).unwrap();

        assert_eq!(loaded.parameters.len(), 2);
        assert!(loaded.parameters.contains_key("weight"));
        assert!(loaded.parameters.contains_key("bias"));
        assert!(loaded.optimizer_state.is_none());
    }

    #[test]
    fn test_checkpoint_with_metadata() {
        let mut parameters = HashMap::new();
        parameters.insert(
            "layer1".to_string(),
            Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap(),
        );

        let mut metadata = ModelMetadata::default();
        metadata.name = "test_model".to_string();
        metadata.epoch = 10;
        metadata.loss = 0.123;

        let checkpoint = ModelCheckpoint::new(parameters).with_metadata(metadata);

        let scratch = ScratchFile::new("test_metadata.gfcp");
        checkpoint.save(scratch.path()).unwrap();

        let loaded = ModelCheckpoint::load(scratch.path()).unwrap();

        assert_eq!(loaded.metadata.name, "test_model");
        assert_eq!(loaded.metadata.epoch, 10);
        assert!((loaded.metadata.loss - 0.123).abs() < 0.001);
    }

    #[test]
    fn test_checkpoint_with_optimizer_state() {
        let mut parameters = HashMap::new();
        parameters.insert(
            "weight".to_string(),
            Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap(),
        );

        let mut optimizer_state = HashMap::new();
        optimizer_state.insert("momentum".to_string(), vec![0.1f32, 0.2]);

        let checkpoint = ModelCheckpoint::new(parameters).with_optimizer_state(optimizer_state);

        let scratch = ScratchFile::new("test_optimizer.gfcp");
        checkpoint.save(scratch.path()).unwrap();

        let loaded = ModelCheckpoint::load(scratch.path()).unwrap();

        let state = loaded
            .optimizer_state
            .expect("optimizer state should survive the round trip");
        assert!(state.contains_key("momentum"));
        // Values are stored as raw little-endian f32 bytes, so they come back exactly.
        assert_eq!(state["momentum"], vec![0.1f32, 0.2]);
    }

    #[test]
    fn test_simple_save_load() {
        let mut parameters = HashMap::new();
        parameters.insert(
            "test".to_string(),
            Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap(),
        );

        let scratch = ScratchFile::new("test_simple.gfcp");
        save_model(scratch.path(), parameters).unwrap();

        let loaded = load_model(scratch.path()).unwrap();

        assert_eq!(loaded.len(), 1);
        assert!(loaded.contains_key("test"));
    }
}