1use crate::error::{TspError, TspResult};
7use crate::solver::TspAlgorithm;
8use std::io::{Read, Write};
9use std::path::Path;
10
/// Signature at the very start of every `.apr` file: the ASCII letters `APR`
/// followed by a single NUL byte.
const MAGIC: &[u8; 4] = &[b'A', b'P', b'R', 0x00];

/// Format version written by `TspModel::save`; `TspModel::load` rejects any other value.
const VERSION: u32 = 1;
/// Model-type tag stored in bytes 8..12 of the `.apr` header.
///
/// The literal reads as ASCII `"TSP\0"` most-significant byte first; because the
/// header is written little-endian, the bytes on disk are `00 50 53 54`.
const MODEL_TYPE_TSP: u32 = 0x54_53_50_00;

/// Algorithm-specific hyper-parameters of a [`TspModel`].
///
/// There is one variant per solver family in `TspAlgorithm`. Count-like fields are
/// `usize` in memory but are stored as `u32` in `.apr` files.
#[derive(Debug, Clone, PartialEq)]
pub enum TspParams {
    /// Ant Colony Optimisation.
    ///
    /// NOTE(review): the names follow the usual ACO notation (`alpha`/`beta`
    /// exponents, evaporation rate `rho`, exploitation threshold `q0`). Confirm
    /// against the solver implementation.
    Aco {
        alpha: f64,
        beta: f64,
        rho: f64,
        q0: f64,
        /// Number of ants per iteration.
        num_ants: usize,
    },
    /// Tabu search.
    Tabu {
        /// Tabu tenure. Presumably the number of iterations a move stays tabu; confirm.
        tenure: usize,
        /// Upper bound on the neighbours examined per step (confirm in the solver).
        max_neighbors: usize,
    },
    /// Genetic algorithm.
    Ga {
        /// Number of individuals in the population.
        population_size: usize,
        /// Crossover probability.
        crossover_rate: f64,
        /// Mutation probability.
        mutation_rate: f64,
    },
    /// Hybrid solver combining the GA, Tabu and ACO components.
    ///
    /// NOTE(review): the fractions presumably split the search budget between the
    /// components; nothing here enforces that they sum to 1.
    Hybrid {
        ga_fraction: f64,
        tabu_fraction: f64,
        aco_fraction: f64,
    },
}

/// The default parameter set is the ACO configuration.
impl Default for TspParams {
    fn default() -> Self {
        Self::Aco {
            alpha: 1.0,
            beta: 2.5,
            rho: 0.1,
            q0: 0.9,
            num_ants: 20,
        }
    }
}
58
/// Descriptive statistics stored alongside a [`TspModel`] in its `.apr` file.
///
/// `Default` yields all-zero values. That is what a freshly constructed model carries.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct TspModelMetadata {
    /// Number of problem instances the model was trained on.
    pub trained_instances: u32,
    /// Average size of those instances. Presumably the city count; confirm with the trainer.
    pub avg_instance_size: u32,
    /// Best known optimality gap. NOTE(review): unit (fraction vs. percent) not established here.
    pub best_known_gap: f64,
    /// Training time, in seconds.
    pub training_time_secs: f64,
}
82
83#[derive(Debug, Clone)]
85pub struct TspModel {
86 pub algorithm: TspAlgorithm,
88 pub params: TspParams,
90 pub metadata: TspModelMetadata,
92}
93
94impl TspModel {
95 pub fn new(algorithm: TspAlgorithm) -> Self {
97 let params = match algorithm {
98 TspAlgorithm::Aco => TspParams::Aco {
99 alpha: 1.0,
100 beta: 2.5,
101 rho: 0.1,
102 q0: 0.9,
103 num_ants: 20,
104 },
105 TspAlgorithm::Tabu => TspParams::Tabu {
106 tenure: 20,
107 max_neighbors: 100,
108 },
109 TspAlgorithm::Ga => TspParams::Ga {
110 population_size: 50,
111 crossover_rate: 0.9,
112 mutation_rate: 0.1,
113 },
114 TspAlgorithm::Hybrid => TspParams::Hybrid {
115 ga_fraction: 0.4,
116 tabu_fraction: 0.3,
117 aco_fraction: 0.3,
118 },
119 };
120
121 Self {
122 algorithm,
123 params,
124 metadata: TspModelMetadata::default(),
125 }
126 }
127
128 pub fn with_params(mut self, params: TspParams) -> Self {
130 self.params = params;
131 self
132 }
133
134 pub fn with_metadata(mut self, metadata: TspModelMetadata) -> Self {
136 self.metadata = metadata;
137 self
138 }
139
140 pub fn save(&self, path: &Path) -> TspResult<()> {
142 let mut file = std::fs::File::create(path)?;
143
144 let payload = self.serialize_payload();
146
147 file.write_all(MAGIC)?;
149 file.write_all(&VERSION.to_le_bytes())?;
150 file.write_all(&MODEL_TYPE_TSP.to_le_bytes())?;
151
152 let checksum = crc32fast::hash(&payload);
154 file.write_all(&checksum.to_le_bytes())?;
155
156 file.write_all(&payload)?;
158
159 Ok(())
160 }
161
162 pub fn load(path: &Path) -> TspResult<Self> {
164 let mut file = std::fs::File::open(path)?;
165 let mut data = Vec::new();
166 file.read_to_end(&mut data)?;
167
168 Self::from_bytes(&data, path)
169 }
170
171 fn from_bytes(data: &[u8], path: &Path) -> TspResult<Self> {
173 if data.len() < 16 {
175 return Err(TspError::InvalidFormat {
176 message: "File too small".into(),
177 hint: "Ensure this is a valid .apr file".into(),
178 });
179 }
180
181 if &data[0..4] != MAGIC {
183 return Err(TspError::InvalidFormat {
184 message: "Not an .apr file".into(),
185 hint: format!("Expected magic 'APR\\x00', got {:?}", &data[0..4]),
186 });
187 }
188
189 let version = u32::from_le_bytes([data[4], data[5], data[6], data[7]]);
191 if version != VERSION {
192 return Err(TspError::InvalidFormat {
193 message: format!("Unsupported version: {version}"),
194 hint: format!("This tool supports version {VERSION}"),
195 });
196 }
197
198 let model_type = u32::from_le_bytes([data[8], data[9], data[10], data[11]]);
200 if model_type != MODEL_TYPE_TSP {
201 return Err(TspError::InvalidFormat {
202 message: "Not a TSP model".into(),
203 hint: "This file contains a different model type".into(),
204 });
205 }
206
207 let stored_checksum = u32::from_le_bytes([data[12], data[13], data[14], data[15]]);
209 let payload = &data[16..];
210 let computed_checksum = crc32fast::hash(payload);
211
212 if stored_checksum != computed_checksum {
213 return Err(TspError::ChecksumMismatch {
214 expected: stored_checksum,
215 computed: computed_checksum,
216 });
217 }
218
219 Self::deserialize_payload(payload, path)
220 }
221
222 fn serialize_payload(&self) -> Vec<u8> {
224 let mut payload = Vec::new();
225
226 let algo_byte = match self.algorithm {
228 TspAlgorithm::Aco => 0u8,
229 TspAlgorithm::Tabu => 1u8,
230 TspAlgorithm::Ga => 2u8,
231 TspAlgorithm::Hybrid => 3u8,
232 };
233 payload.push(algo_byte);
234
235 payload.extend_from_slice(&self.metadata.trained_instances.to_le_bytes());
237 payload.extend_from_slice(&self.metadata.avg_instance_size.to_le_bytes());
238 payload.extend_from_slice(&self.metadata.best_known_gap.to_le_bytes());
239 payload.extend_from_slice(&self.metadata.training_time_secs.to_le_bytes());
240
241 match &self.params {
243 TspParams::Aco {
244 alpha,
245 beta,
246 rho,
247 q0,
248 num_ants,
249 } => {
250 payload.extend_from_slice(&alpha.to_le_bytes());
251 payload.extend_from_slice(&beta.to_le_bytes());
252 payload.extend_from_slice(&rho.to_le_bytes());
253 payload.extend_from_slice(&q0.to_le_bytes());
254 payload.extend_from_slice(&(*num_ants as u32).to_le_bytes());
255 }
256 TspParams::Tabu {
257 tenure,
258 max_neighbors,
259 } => {
260 payload.extend_from_slice(&(*tenure as u32).to_le_bytes());
261 payload.extend_from_slice(&(*max_neighbors as u32).to_le_bytes());
262 }
263 TspParams::Ga {
264 population_size,
265 crossover_rate,
266 mutation_rate,
267 } => {
268 payload.extend_from_slice(&(*population_size as u32).to_le_bytes());
269 payload.extend_from_slice(&crossover_rate.to_le_bytes());
270 payload.extend_from_slice(&mutation_rate.to_le_bytes());
271 }
272 TspParams::Hybrid {
273 ga_fraction,
274 tabu_fraction,
275 aco_fraction,
276 } => {
277 payload.extend_from_slice(&ga_fraction.to_le_bytes());
278 payload.extend_from_slice(&tabu_fraction.to_le_bytes());
279 payload.extend_from_slice(&aco_fraction.to_le_bytes());
280 }
281 }
282
283 payload
284 }
285
286 #[allow(clippy::too_many_lines)]
288 fn deserialize_payload(payload: &[u8], path: &Path) -> TspResult<Self> {
289 if payload.is_empty() {
290 return Err(TspError::ParseError {
291 file: path.to_path_buf(),
292 line: None,
293 cause: "Empty payload".into(),
294 });
295 }
296
297 let algo_byte = payload[0];
298 let algorithm = match algo_byte {
299 0 => TspAlgorithm::Aco,
300 1 => TspAlgorithm::Tabu,
301 2 => TspAlgorithm::Ga,
302 3 => TspAlgorithm::Hybrid,
303 _ => {
304 return Err(TspError::InvalidFormat {
305 message: format!("Unknown algorithm type: {algo_byte}"),
306 hint: "Supported: aco (0), tabu (1), ga (2), hybrid (3)".into(),
307 });
308 }
309 };
310
311 let min_size = 1 + 4 + 4 + 8 + 8; if payload.len() < min_size {
314 return Err(TspError::ParseError {
315 file: path.to_path_buf(),
316 line: None,
317 cause: "Payload too small for metadata".into(),
318 });
319 }
320
321 let mut offset = 1;
322
323 let trained_instances = u32::from_le_bytes([
325 payload[offset],
326 payload[offset + 1],
327 payload[offset + 2],
328 payload[offset + 3],
329 ]);
330 offset += 4;
331
332 let avg_instance_size = u32::from_le_bytes([
333 payload[offset],
334 payload[offset + 1],
335 payload[offset + 2],
336 payload[offset + 3],
337 ]);
338 offset += 4;
339
340 let best_known_gap = f64::from_le_bytes([
341 payload[offset],
342 payload[offset + 1],
343 payload[offset + 2],
344 payload[offset + 3],
345 payload[offset + 4],
346 payload[offset + 5],
347 payload[offset + 6],
348 payload[offset + 7],
349 ]);
350 offset += 8;
351
352 let training_time_secs = f64::from_le_bytes([
353 payload[offset],
354 payload[offset + 1],
355 payload[offset + 2],
356 payload[offset + 3],
357 payload[offset + 4],
358 payload[offset + 5],
359 payload[offset + 6],
360 payload[offset + 7],
361 ]);
362 offset += 8;
363
364 let metadata = TspModelMetadata {
365 trained_instances,
366 avg_instance_size,
367 best_known_gap,
368 training_time_secs,
369 };
370
371 let params =
373 match algorithm {
374 TspAlgorithm::Aco => {
375 let alpha =
376 f64::from_le_bytes(payload[offset..offset + 8].try_into().map_err(
377 |_| TspError::ParseError {
378 file: path.to_path_buf(),
379 line: None,
380 cause: "Failed to read alpha".into(),
381 },
382 )?);
383 offset += 8;
384
385 let beta = f64::from_le_bytes(payload[offset..offset + 8].try_into().map_err(
386 |_| TspError::ParseError {
387 file: path.to_path_buf(),
388 line: None,
389 cause: "Failed to read beta".into(),
390 },
391 )?);
392 offset += 8;
393
394 let rho = f64::from_le_bytes(payload[offset..offset + 8].try_into().map_err(
395 |_| TspError::ParseError {
396 file: path.to_path_buf(),
397 line: None,
398 cause: "Failed to read rho".into(),
399 },
400 )?);
401 offset += 8;
402
403 let q0 = f64::from_le_bytes(payload[offset..offset + 8].try_into().map_err(
404 |_| TspError::ParseError {
405 file: path.to_path_buf(),
406 line: None,
407 cause: "Failed to read q0".into(),
408 },
409 )?);
410 offset += 8;
411
412 let num_ants =
413 u32::from_le_bytes(payload[offset..offset + 4].try_into().map_err(
414 |_| TspError::ParseError {
415 file: path.to_path_buf(),
416 line: None,
417 cause: "Failed to read num_ants".into(),
418 },
419 )?) as usize;
420
421 TspParams::Aco {
422 alpha,
423 beta,
424 rho,
425 q0,
426 num_ants,
427 }
428 }
429 TspAlgorithm::Tabu => {
430 let tenure =
431 u32::from_le_bytes(payload[offset..offset + 4].try_into().map_err(
432 |_| TspError::ParseError {
433 file: path.to_path_buf(),
434 line: None,
435 cause: "Failed to read tenure".into(),
436 },
437 )?) as usize;
438 offset += 4;
439
440 let max_neighbors =
441 u32::from_le_bytes(payload[offset..offset + 4].try_into().map_err(
442 |_| TspError::ParseError {
443 file: path.to_path_buf(),
444 line: None,
445 cause: "Failed to read max_neighbors".into(),
446 },
447 )?) as usize;
448
449 TspParams::Tabu {
450 tenure,
451 max_neighbors,
452 }
453 }
454 TspAlgorithm::Ga => {
455 let population_size =
456 u32::from_le_bytes(payload[offset..offset + 4].try_into().map_err(
457 |_| TspError::ParseError {
458 file: path.to_path_buf(),
459 line: None,
460 cause: "Failed to read population_size".into(),
461 },
462 )?) as usize;
463 offset += 4;
464
465 let crossover_rate =
466 f64::from_le_bytes(payload[offset..offset + 8].try_into().map_err(
467 |_| TspError::ParseError {
468 file: path.to_path_buf(),
469 line: None,
470 cause: "Failed to read crossover_rate".into(),
471 },
472 )?);
473 offset += 8;
474
475 let mutation_rate =
476 f64::from_le_bytes(payload[offset..offset + 8].try_into().map_err(
477 |_| TspError::ParseError {
478 file: path.to_path_buf(),
479 line: None,
480 cause: "Failed to read mutation_rate".into(),
481 },
482 )?);
483
484 TspParams::Ga {
485 population_size,
486 crossover_rate,
487 mutation_rate,
488 }
489 }
490 TspAlgorithm::Hybrid => {
491 let ga_fraction =
492 f64::from_le_bytes(payload[offset..offset + 8].try_into().map_err(
493 |_| TspError::ParseError {
494 file: path.to_path_buf(),
495 line: None,
496 cause: "Failed to read ga_fraction".into(),
497 },
498 )?);
499 offset += 8;
500
501 let tabu_fraction =
502 f64::from_le_bytes(payload[offset..offset + 8].try_into().map_err(
503 |_| TspError::ParseError {
504 file: path.to_path_buf(),
505 line: None,
506 cause: "Failed to read tabu_fraction".into(),
507 },
508 )?);
509 offset += 8;
510
511 let aco_fraction =
512 f64::from_le_bytes(payload[offset..offset + 8].try_into().map_err(
513 |_| TspError::ParseError {
514 file: path.to_path_buf(),
515 line: None,
516 cause: "Failed to read aco_fraction".into(),
517 },
518 )?);
519
520 TspParams::Hybrid {
521 ga_fraction,
522 tabu_fraction,
523 aco_fraction,
524 }
525 }
526 };
527
528 Ok(Self {
529 algorithm,
530 params,
531 metadata,
532 })
533 }
534}
535
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Tolerance for f64 values that went through a save/load cycle.
    const EPS: f64 = 1e-10;

    /// Saves `model` to `test.apr` in a fresh temporary directory and loads it back.
    fn roundtrip(model: &TspModel) -> TspModel {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("test.apr");
        model.save(&file).expect("should save");
        TspModel::load(&file).expect("should load")
    }

    /// Writes `bytes` to `name` in a fresh temporary directory and tries to load it.
    ///
    /// Asserts that the load fails and returns the error's display text.
    fn load_error(name: &str, bytes: &[u8]) -> String {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join(name);
        std::fs::write(&file, bytes).unwrap();
        let outcome = TspModel::load(&file);
        assert!(outcome.is_err());
        outcome.unwrap_err().to_string()
    }

    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < EPS);
    }

    #[test]
    fn test_model_new_aco() {
        let fresh = TspModel::new(TspAlgorithm::Aco);
        assert_eq!(fresh.algorithm, TspAlgorithm::Aco);
        assert!(matches!(fresh.params, TspParams::Aco { .. }));
    }

    #[test]
    fn test_model_new_tabu() {
        let fresh = TspModel::new(TspAlgorithm::Tabu);
        assert_eq!(fresh.algorithm, TspAlgorithm::Tabu);
        assert!(matches!(fresh.params, TspParams::Tabu { .. }));
    }

    #[test]
    fn test_model_new_ga() {
        let fresh = TspModel::new(TspAlgorithm::Ga);
        assert_eq!(fresh.algorithm, TspAlgorithm::Ga);
        assert!(matches!(fresh.params, TspParams::Ga { .. }));
    }

    #[test]
    fn test_model_new_hybrid() {
        let fresh = TspModel::new(TspAlgorithm::Hybrid);
        assert_eq!(fresh.algorithm, TspAlgorithm::Hybrid);
        assert!(matches!(fresh.params, TspParams::Hybrid { .. }));
    }

    #[test]
    fn test_model_save_load_aco() {
        let original = TspModel::new(TspAlgorithm::Aco).with_params(TspParams::Aco {
            alpha: 2.0,
            beta: 3.5,
            rho: 0.2,
            q0: 0.85,
            num_ants: 30,
        });

        let restored = roundtrip(&original);

        assert_eq!(restored.algorithm, TspAlgorithm::Aco);
        match restored.params {
            TspParams::Aco {
                alpha,
                beta,
                rho,
                q0,
                num_ants,
            } => {
                assert_close(alpha, 2.0);
                assert_close(beta, 3.5);
                assert_close(rho, 0.2);
                assert_close(q0, 0.85);
                assert_eq!(num_ants, 30);
            }
            _ => panic!("Expected ACO params"),
        }
    }

    #[test]
    fn test_model_save_load_tabu() {
        let original = TspModel::new(TspAlgorithm::Tabu).with_params(TspParams::Tabu {
            tenure: 25,
            max_neighbors: 150,
        });

        let restored = roundtrip(&original);

        assert_eq!(restored.algorithm, TspAlgorithm::Tabu);
        match restored.params {
            TspParams::Tabu {
                tenure,
                max_neighbors,
            } => {
                assert_eq!(tenure, 25);
                assert_eq!(max_neighbors, 150);
            }
            _ => panic!("Expected Tabu params"),
        }
    }

    #[test]
    fn test_model_save_load_ga() {
        let original = TspModel::new(TspAlgorithm::Ga).with_params(TspParams::Ga {
            population_size: 100,
            crossover_rate: 0.85,
            mutation_rate: 0.15,
        });

        let restored = roundtrip(&original);

        assert_eq!(restored.algorithm, TspAlgorithm::Ga);
        match restored.params {
            TspParams::Ga {
                population_size,
                crossover_rate,
                mutation_rate,
            } => {
                assert_eq!(population_size, 100);
                assert_close(crossover_rate, 0.85);
                assert_close(mutation_rate, 0.15);
            }
            _ => panic!("Expected GA params"),
        }
    }

    #[test]
    fn test_model_save_load_hybrid() {
        let original = TspModel::new(TspAlgorithm::Hybrid).with_params(TspParams::Hybrid {
            ga_fraction: 0.5,
            tabu_fraction: 0.25,
            aco_fraction: 0.25,
        });

        let restored = roundtrip(&original);

        assert_eq!(restored.algorithm, TspAlgorithm::Hybrid);
        match restored.params {
            TspParams::Hybrid {
                ga_fraction,
                tabu_fraction,
                aco_fraction,
            } => {
                assert_close(ga_fraction, 0.5);
                assert_close(tabu_fraction, 0.25);
                assert_close(aco_fraction, 0.25);
            }
            _ => panic!("Expected Hybrid params"),
        }
    }

    #[test]
    fn test_model_metadata_roundtrip() {
        let original = TspModel::new(TspAlgorithm::Aco).with_metadata(TspModelMetadata {
            trained_instances: 10,
            avg_instance_size: 52,
            best_known_gap: 0.03,
            training_time_secs: 2.5,
        });

        let restored = roundtrip(&original).metadata;

        assert_eq!(restored.trained_instances, 10);
        assert_eq!(restored.avg_instance_size, 52);
        assert_close(restored.best_known_gap, 0.03);
        assert_close(restored.training_time_secs, 2.5);
    }

    #[test]
    fn test_model_invalid_magic() {
        let mut bytes = [0u8; 20];
        bytes[..4].copy_from_slice(b"BAD\x00");

        let err = load_error("bad.apr", &bytes);
        assert!(err.contains("Not an .apr file"), "Unexpected error: {err}");
    }

    #[test]
    fn test_model_invalid_checksum() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("corrupt.apr");
        TspModel::new(TspAlgorithm::Aco)
            .save(&file)
            .expect("should save");

        let mut bytes = std::fs::read(&file).unwrap();
        // Byte 12 is the first byte of the stored CRC32 in the header.
        bytes[12] ^= 0xFF;

        let err = load_error("corrupt.apr", &bytes);
        assert!(err.contains("checksum mismatch"));
    }

    #[test]
    fn test_model_file_too_small() {
        let err = load_error("small.apr", b"APR\x00");
        assert!(err.contains("too small"));
    }

    #[test]
    fn test_model_unsupported_version() {
        // A header-only file: valid magic and model type, version 99, zero checksum.
        let bytes = [
            &MAGIC[..],
            &99u32.to_le_bytes()[..],
            &MODEL_TYPE_TSP.to_le_bytes()[..],
            &0u32.to_le_bytes()[..],
        ]
        .concat();

        let err = load_error("future.apr", &bytes);
        assert!(err.contains("version"));
    }
}