use crate::error::{TokenizerError, TokenizerResult};
use crate::{ReconstructionMetrics, TrainingConfig};
use chrono::{DateTime, Utc};
use scirs2_core::ndarray::Array2;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::fs::File;
use std::io::{Read, Write};
use std::path::Path;

/// Semantic version (`major.minor.patch`) attached to saved models.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ModelVersion {
    /// Major version; incremented on breaking format changes.
    pub major: u32,
    /// Minor version; incremented on backward-compatible additions.
    pub minor: u32,
    /// Patch version; incremented on fixes that do not change the format.
    pub patch: u32,
}

impl ModelVersion {
    /// Create a new version triple.
    pub fn new(major: u32, minor: u32, patch: u32) -> Self {
        Self {
            major,
            minor,
            patch,
        }
    }

    /// Parse a `"major.minor.patch"` string.
    pub fn parse(s: &str) -> TokenizerResult<Self> {
        let parts: Vec<&str> = s.split('.').collect();
        if parts.len() != 3 {
            return Err(TokenizerError::InvalidConfig(format!(
                "Invalid version string: {}",
                s
            )));
        }

        let major = parts[0]
            .parse()
            .map_err(|_| TokenizerError::InvalidConfig("Invalid major version".into()))?;
        let minor = parts[1]
            .parse()
            .map_err(|_| TokenizerError::InvalidConfig("Invalid minor version".into()))?;
        let patch = parts[2]
            .parse()
            .map_err(|_| TokenizerError::InvalidConfig("Invalid patch version".into()))?;

        Ok(Self::new(major, minor, patch))
    }

    /// Two versions are compatible when their major components match.
    pub fn is_compatible_with(&self, other: &ModelVersion) -> bool {
        self.major == other.major
    }
}

impl fmt::Display for ModelVersion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}.{}.{}", self.major, self.minor, self.patch)
    }
}

/// Metadata stored alongside model weights in a checkpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Format version of the checkpoint.
    pub version: ModelVersion,
    /// Human-readable model type name.
    pub model_type: String,
    /// Creation timestamp (UTC).
    pub created_at: DateTime<Utc>,
    /// Last-modified timestamp (UTC).
    pub modified_at: DateTime<Utc>,
    /// Input dimensionality of the model.
    pub input_dim: usize,
    /// Embedding dimensionality of the model.
    pub embed_dim: usize,
    /// Training configuration used to produce the model, if known.
    pub training_config: Option<TrainingConfig>,
    /// Reconstruction metrics from evaluation, if available.
    pub metrics: Option<ReconstructionMetrics>,
    /// Free-form key/value annotations.
    pub custom: HashMap<String, String>,
}

impl ModelMetadata {
    /// Create metadata with both timestamps set to now.
    pub fn new(
        version: ModelVersion,
        model_type: String,
        input_dim: usize,
        embed_dim: usize,
    ) -> Self {
        let now = Utc::now();
        Self {
            version,
            model_type,
            created_at: now,
            modified_at: now,
            input_dim,
            embed_dim,
            training_config: None,
            metrics: None,
            custom: HashMap::new(),
        }
    }

    /// Update the modification timestamp to now.
    pub fn touch(&mut self) {
        self.modified_at = Utc::now();
    }

    /// Attach a custom key/value annotation.
    pub fn add_custom(&mut self, key: String, value: String) {
        self.custom.insert(key, value);
    }
}

/// A serializable snapshot of model weights plus metadata.
#[derive(Debug)]
pub struct ModelCheckpoint {
    /// Checkpoint metadata.
    pub metadata: ModelMetadata,
    /// Flat weight buffers, keyed by tensor name.
    pub weights: HashMap<String, Vec<f32>>,
    /// Tensor shapes, keyed by the same names as `weights`.
    pub shapes: HashMap<String, Vec<usize>>,
}

impl ModelCheckpoint {
    /// Create an empty checkpoint for the given metadata.
    pub fn new(metadata: ModelMetadata) -> Self {
        Self {
            metadata,
            weights: HashMap::new(),
            shapes: HashMap::new(),
        }
    }

    /// Register a named weight buffer together with its shape.
    pub fn add_weight(&mut self, name: String, data: Vec<f32>, shape: Vec<usize>) {
        self.weights.insert(name.clone(), data);
        self.shapes.insert(name, shape);
    }

    /// Register a 2-D array as a named weight (row-major order).
    pub fn add_array2(&mut self, name: String, array: &Array2<f32>) {
        let shape = array.shape();
        let data: Vec<f32> = array.iter().copied().collect();
        self.add_weight(name, data, vec![shape[0], shape[1]]);
    }

    /// Look up a weight buffer and its shape by name.
    pub fn get_weight(&self, name: &str) -> Option<(&[f32], &[usize])> {
        self.weights
            .get(name)
            .and_then(|w| self.shapes.get(name).map(|s| (w.as_slice(), s.as_slice())))
    }

    /// Reconstruct a named weight as a 2-D array.
    pub fn get_array2(&self, name: &str) -> TokenizerResult<Array2<f32>> {
        let (data, shape) = self
            .get_weight(name)
            .ok_or_else(|| TokenizerError::InvalidConfig(format!("Weight '{}' not found", name)))?;

        if shape.len() != 2 {
            return Err(TokenizerError::InvalidConfig(format!(
                "Expected 2D array for '{}', got {}D",
                name,
                shape.len()
            )));
        }

        // `from_shape_vec` builds the array in row-major order and also
        // validates that the element count matches the declared shape.
        Array2::from_shape_vec((shape[0], shape[1]), data.to_vec()).map_err(|e| {
            TokenizerError::InvalidConfig(format!("Shape mismatch for '{}': {}", name, e))
        })
    }

    /// Save the checkpoint in a simple length-prefixed binary layout:
    /// a `u32` metadata length and the metadata as JSON, followed by one
    /// record per tensor of `[name_len][name][ndim][dims..][data_len][data]`,
    /// with all integers little-endian `u32` and data as little-endian `f32`.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> TokenizerResult<()> {
        let path = path.as_ref();

        let metadata_json = serde_json::to_string(&self.metadata)
            .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

        let mut file = File::create(path)
            .map_err(|e| TokenizerError::InternalError(format!("Failed to create file: {}", e)))?;

        // Metadata header: u32 length followed by the JSON payload.
        let metadata_bytes = metadata_json.as_bytes();
        file.write_all(&(metadata_bytes.len() as u32).to_le_bytes())
            .map_err(|e| TokenizerError::InternalError(format!("Failed to write: {}", e)))?;
        file.write_all(metadata_bytes)
            .map_err(|e| TokenizerError::InternalError(format!("Failed to write: {}", e)))?;

        // One record per tensor, written directly from the weight map.
        for (name, data) in &self.weights {
            let shape = self.shapes.get(name).ok_or_else(|| {
                TokenizerError::InternalError(format!("Missing shape for weight '{}'", name))
            })?;

            let name_bytes = name.as_bytes();
            file.write_all(&(name_bytes.len() as u32).to_le_bytes())
                .map_err(|e| TokenizerError::InternalError(format!("Failed to write: {}", e)))?;
            file.write_all(name_bytes)
                .map_err(|e| TokenizerError::InternalError(format!("Failed to write: {}", e)))?;

            file.write_all(&(shape.len() as u32).to_le_bytes())
                .map_err(|e| TokenizerError::InternalError(format!("Failed to write: {}", e)))?;
            for &dim in shape {
                file.write_all(&(dim as u32).to_le_bytes()).map_err(|e| {
                    TokenizerError::InternalError(format!("Failed to write: {}", e))
                })?;
            }

            let data_bytes: Vec<u8> = data.iter().flat_map(|&f| f.to_le_bytes()).collect();
            file.write_all(&(data_bytes.len() as u32).to_le_bytes())
                .map_err(|e| TokenizerError::InternalError(format!("Failed to write: {}", e)))?;
            file.write_all(&data_bytes)
                .map_err(|e| TokenizerError::InternalError(format!("Failed to write: {}", e)))?;
        }

        Ok(())
    }

    /// Load a checkpoint previously written by [`ModelCheckpoint::save`].
    pub fn load<P: AsRef<Path>>(path: P) -> TokenizerResult<Self> {
        let path = path.as_ref();
        let mut file = File::open(path)
            .map_err(|e| TokenizerError::InternalError(format!("Failed to open file: {}", e)))?;

        // Metadata header: u32 length followed by the JSON payload.
        let mut len_buf = [0u8; 4];
        file.read_exact(&mut len_buf)
            .map_err(|e| TokenizerError::InternalError(format!("Failed to read: {}", e)))?;
        let metadata_len = u32::from_le_bytes(len_buf) as usize;

        let mut metadata_buf = vec![0u8; metadata_len];
        file.read_exact(&mut metadata_buf)
            .map_err(|e| TokenizerError::InternalError(format!("Failed to read: {}", e)))?;
        let metadata: ModelMetadata = serde_json::from_slice(&metadata_buf).map_err(|e| {
            TokenizerError::InternalError(format!("Failed to parse metadata: {}", e))
        })?;

        let mut checkpoint = ModelCheckpoint::new(metadata);

        // Read tensor records until a clean EOF at a record boundary.
        loop {
            let mut name_len_buf = [0u8; 4];
            match file.read_exact(&mut name_len_buf) {
                Ok(_) => {}
                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
                Err(e) => {
                    return Err(TokenizerError::InternalError(format!(
                        "Failed to read: {}",
                        e
                    )))
                }
            }
            let name_len = u32::from_le_bytes(name_len_buf) as usize;

            let mut name_buf = vec![0u8; name_len];
            file.read_exact(&mut name_buf)
                .map_err(|e| TokenizerError::InternalError(format!("Failed to read: {}", e)))?;
            let name = String::from_utf8(name_buf)
                .map_err(|e| TokenizerError::InternalError(format!("Invalid UTF-8: {}", e)))?;

            let mut shape_len_buf = [0u8; 4];
            file.read_exact(&mut shape_len_buf)
                .map_err(|e| TokenizerError::InternalError(format!("Failed to read: {}", e)))?;
            let shape_len = u32::from_le_bytes(shape_len_buf) as usize;

            let mut shape = Vec::with_capacity(shape_len);
            for _ in 0..shape_len {
                let mut dim_buf = [0u8; 4];
                file.read_exact(&mut dim_buf)
                    .map_err(|e| TokenizerError::InternalError(format!("Failed to read: {}", e)))?;
                shape.push(u32::from_le_bytes(dim_buf) as usize);
            }

            let mut data_len_buf = [0u8; 4];
            file.read_exact(&mut data_len_buf)
                .map_err(|e| TokenizerError::InternalError(format!("Failed to read: {}", e)))?;
            let data_len = u32::from_le_bytes(data_len_buf) as usize;

            let mut data_bytes = vec![0u8; data_len];
            file.read_exact(&mut data_bytes)
                .map_err(|e| TokenizerError::InternalError(format!("Failed to read: {}", e)))?;

            // Decode little-endian f32 values; `chunks_exact` ignores any
            // trailing bytes that do not form a full 4-byte chunk.
            let data: Vec<f32> = data_bytes
                .chunks_exact(4)
                .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
                .collect();

            checkpoint.add_weight(name, data, shape);
        }

        Ok(checkpoint)
    }
}

/// Serialize a training configuration to pretty-printed JSON.
pub fn save_config<P: AsRef<Path>>(config: &TrainingConfig, path: P) -> TokenizerResult<()> {
    let json = serde_json::to_string_pretty(config)
        .map_err(|e| TokenizerError::InternalError(e.to_string()))?;

    std::fs::write(path, json)
        .map_err(|e| TokenizerError::InternalError(format!("Failed to write config: {}", e)))?;

    Ok(())
}

/// Load a training configuration from a JSON file.
pub fn load_config<P: AsRef<Path>>(path: P) -> TokenizerResult<TrainingConfig> {
    let json = std::fs::read_to_string(path)
        .map_err(|e| TokenizerError::InternalError(format!("Failed to read config: {}", e)))?;

    let config = serde_json::from_str(&json)
        .map_err(|e| TokenizerError::InternalError(format!("Failed to parse config: {}", e)))?;

    Ok(config)
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    #[test]
    fn test_model_version() {
        let v1 = ModelVersion::new(1, 2, 3);
        assert_eq!(v1.to_string(), "1.2.3");

        let v2 = ModelVersion::parse("1.2.3").unwrap();
        assert_eq!(v1, v2);

        assert!(v1.is_compatible_with(&v2));

        let v3 = ModelVersion::new(2, 0, 0);
        assert!(!v1.is_compatible_with(&v3));
    }

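    // Illustrative extra test (a minimal sketch of the error paths): `parse`
    // requires exactly three dot-separated numeric parts, and versions that
    // differ only in minor/patch remain compatible.
    #[test]
    fn test_model_version_parse_invalid() {
        assert!(ModelVersion::parse("1.2").is_err());
        assert!(ModelVersion::parse("1.2.3.4").is_err());
        assert!(ModelVersion::parse("a.b.c").is_err());

        assert!(ModelVersion::new(1, 2, 3).is_compatible_with(&ModelVersion::new(1, 9, 9)));
    }
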
    #[test]
    fn test_model_metadata() {
        let version = ModelVersion::new(1, 0, 0);
        let mut metadata = ModelMetadata::new(version, "TestModel".to_string(), 8, 16);

        metadata.add_custom("author".to_string(), "Test User".to_string());
        assert_eq!(metadata.custom.get("author").unwrap(), "Test User");

        let before = metadata.modified_at;
        std::thread::sleep(std::time::Duration::from_millis(10));
        metadata.touch();
        assert!(metadata.modified_at > before);
    }

    #[test]
    fn test_checkpoint_creation() {
        let version = ModelVersion::new(1, 0, 0);
        let metadata = ModelMetadata::new(version, "TestModel".to_string(), 8, 16);
        let mut checkpoint = ModelCheckpoint::new(metadata);

        let data = vec![1.0, 2.0, 3.0, 4.0];
        let shape = vec![2, 2];
        checkpoint.add_weight("test_weight".to_string(), data.clone(), shape.clone());

        let (retrieved_data, retrieved_shape) = checkpoint.get_weight("test_weight").unwrap();
        assert_eq!(retrieved_data, &data[..]);
        assert_eq!(retrieved_shape, &shape[..]);
    }

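    // Illustrative extra test covering the lookup error paths: a missing name
    // yields `None`/`Err`, and a non-2D shape is rejected by `get_array2`.
    #[test]
    fn test_checkpoint_lookup_errors() {
        let version = ModelVersion::new(1, 0, 0);
        let metadata = ModelMetadata::new(version, "TestModel".to_string(), 8, 16);
        let mut checkpoint = ModelCheckpoint::new(metadata);

        assert!(checkpoint.get_weight("absent").is_none());
        assert!(checkpoint.get_array2("absent").is_err());

        // A 1-D shape cannot be reconstructed as an `Array2`.
        checkpoint.add_weight("vector".to_string(), vec![1.0, 2.0, 3.0], vec![3]);
        assert!(checkpoint.get_array2("vector").is_err());
    }
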
    #[test]
    fn test_checkpoint_array2() {
        let version = ModelVersion::new(1, 0, 0);
        let metadata = ModelMetadata::new(version, "TestModel".to_string(), 8, 16);
        let mut checkpoint = ModelCheckpoint::new(metadata);

        let mut array = Array2::zeros((2, 3));
        array[[0, 0]] = 1.0;
        array[[0, 1]] = 2.0;
        array[[0, 2]] = 3.0;
        array[[1, 0]] = 4.0;
        array[[1, 1]] = 5.0;
        array[[1, 2]] = 6.0;

        checkpoint.add_array2("matrix".to_string(), &array);

        let retrieved = checkpoint.get_array2("matrix").unwrap();
        assert_eq!(retrieved.shape(), &[2, 3]);
        assert_eq!(retrieved[[0, 0]], 1.0);
        assert_eq!(retrieved[[1, 2]], 6.0);
    }

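    // Illustrative extra test tied to the length validation in `get_array2`:
    // a buffer whose element count disagrees with its declared shape is
    // rejected rather than silently zero-filled.
    #[test]
    fn test_get_array2_length_mismatch() {
        let version = ModelVersion::new(1, 0, 0);
        let metadata = ModelMetadata::new(version, "TestModel".to_string(), 8, 16);
        let mut checkpoint = ModelCheckpoint::new(metadata);

        // Three elements cannot fill a 2x2 matrix.
        checkpoint.add_weight("bad".to_string(), vec![1.0, 2.0, 3.0], vec![2, 2]);
        assert!(checkpoint.get_array2("bad").is_err());
    }
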
    #[test]
    fn test_checkpoint_save_load() {
        let temp_dir = env::temp_dir();
        // The on-disk format is the custom length-prefixed layout above, so a
        // neutral extension is used here.
        let checkpoint_path = temp_dir.join("test_checkpoint.bin");

        let version = ModelVersion::new(1, 0, 0);
        let mut metadata = ModelMetadata::new(version, "TestModel".to_string(), 4, 8);
        metadata.add_custom("test".to_string(), "value".to_string());

        let mut checkpoint = ModelCheckpoint::new(metadata);

        let mut encoder = Array2::zeros((4, 8));
        for i in 0..4 {
            for j in 0..8 {
                encoder[[i, j]] = (i * 8 + j) as f32;
            }
        }
        checkpoint.add_array2("encoder".to_string(), &encoder);

        checkpoint.save(&checkpoint_path).unwrap();

        let loaded = ModelCheckpoint::load(&checkpoint_path).unwrap();
        assert_eq!(loaded.metadata.model_type, "TestModel");
        assert_eq!(loaded.metadata.input_dim, 4);
        assert_eq!(loaded.metadata.embed_dim, 8);
        assert_eq!(loaded.metadata.custom.get("test").unwrap(), "value");

        let loaded_encoder = loaded.get_array2("encoder").unwrap();
        assert_eq!(loaded_encoder.shape(), &[4, 8]);
        assert_eq!(loaded_encoder[[0, 0]], 0.0);
        assert_eq!(loaded_encoder[[3, 7]], 31.0);

        std::fs::remove_file(&checkpoint_path).ok();
    }

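    // Illustrative extra test: loading from a nonexistent path should surface
    // an error rather than panic.
    #[test]
    fn test_load_missing_file() {
        let missing = env::temp_dir().join("nonexistent_checkpoint.bin");
        std::fs::remove_file(&missing).ok();
        assert!(ModelCheckpoint::load(&missing).is_err());
    }
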
    #[test]
    fn test_save_load_config() {
        let temp_dir = env::temp_dir();
        let config_path = temp_dir.join("test_config.json");

        let config = TrainingConfig {
            learning_rate: 0.001,
            num_epochs: 50,
            batch_size: 16,
            ..Default::default()
        };

        save_config(&config, &config_path).unwrap();
        let loaded_config = load_config(&config_path).unwrap();

        assert_eq!(loaded_config.learning_rate, 0.001);
        assert_eq!(loaded_config.num_epochs, 50);
        assert_eq!(loaded_config.batch_size, 16);

        std::fs::remove_file(&config_path).ok();
    }
}