aprender_tsp/
model.rs

1//! TSP model persistence in .apr format.
2//!
3//! Toyota Way Principle: *Standardized Work* - Consistent .apr format enables
4//! reproducible results across environments.
5
6use crate::error::{TspError, TspResult};
7use crate::solver::TspAlgorithm;
8use std::io::{Read, Write};
9use std::path::Path;
10
/// File signature found in the first four bytes of every TSP .apr file.
const MAGIC: &[u8; 4] = b"APR\0";

/// Format version stored in the header; loading rejects any other value.
const VERSION: u32 = 1;

/// Model type tag for TSP models (`0x5453_5000`, i.e. "TSP\0" read big-endian).
///
/// The header stores this value with `to_le_bytes`, so the bytes appear in the
/// file in reverse order (`00 50 53 54`).
const MODEL_TYPE_TSP: u32 = u32::from_be_bytes(*b"TSP\0");
19
/// Algorithm-specific parameters.
///
/// Each variant carries the tunables of one solver. When a model is loaded the
/// parameter layout is chosen from the model's algorithm, so the variant stored
/// in a model should be the one that corresponds to that algorithm.
///
/// All `usize` counts are written to the .apr payload as little-endian `u32`
/// values; floating-point values are written as little-endian `f64`.
#[derive(Debug, Clone, PartialEq)]
pub enum TspParams {
    /// ACO parameters
    Aco {
        /// ACO `alpha` coefficient (built-in default: 1.0).
        alpha: f64,
        /// ACO `beta` coefficient (built-in default: 2.5).
        beta: f64,
        /// ACO `rho` coefficient (built-in default: 0.1).
        rho: f64,
        /// ACO `q0` coefficient (built-in default: 0.9).
        q0: f64,
        /// Number of ants (built-in default: 20).
        num_ants: usize,
    },
    /// Tabu Search parameters
    Tabu {
        /// Tabu tenure (built-in default: 20).
        tenure: usize,
        /// Maximum number of neighbors (built-in default: 100).
        max_neighbors: usize,
    },
    /// Genetic Algorithm parameters
    Ga {
        /// Population size (built-in default: 50).
        population_size: usize,
        /// Crossover rate (built-in default: 0.9).
        crossover_rate: f64,
        /// Mutation rate (built-in default: 0.1).
        mutation_rate: f64,
    },
    /// Hybrid parameters
    Hybrid {
        /// Fraction assigned to the GA component (built-in default: 0.4).
        // NOTE(review): presumably the three fractions are expected to sum to 1.0 — confirm against the solver.
        ga_fraction: f64,
        /// Fraction assigned to the Tabu component (built-in default: 0.3).
        tabu_fraction: f64,
        /// Fraction assigned to the ACO component (built-in default: 0.3).
        aco_fraction: f64,
    },
}
46
47impl Default for TspParams {
48    fn default() -> Self {
49        Self::Aco {
50            alpha: 1.0,
51            beta: 2.5,
52            rho: 0.1,
53            q0: 0.9,
54            num_ants: 20,
55        }
56    }
57}
58
/// Training metadata stored alongside the parameters in an .apr file.
///
/// `Default` yields all-zero metadata (derived: every field's default is zero),
/// which is what a freshly created, untrained model carries.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct TspModelMetadata {
    /// Number of instances used for training
    pub trained_instances: u32,
    /// Average instance size
    pub avg_instance_size: u32,
    /// Best known gap achieved during training
    pub best_known_gap: f64,
    /// Training time in seconds
    pub training_time_secs: f64,
}
82
/// TSP model persisted in .apr format.
///
/// On disk a model is a 16-byte header (magic, version, model type and the
/// `crc32fast` checksum of the payload, each 4 bytes) followed by a payload
/// that holds a one-byte algorithm tag, the training metadata and the
/// algorithm-specific parameters, all little-endian.
#[derive(Debug, Clone)]
pub struct TspModel {
    /// Solver algorithm; determines the tag byte written to the payload and the
    /// parameter layout expected when the payload is read back.
    pub algorithm: TspAlgorithm,
    /// Learned parameters (algorithm-specific)
    pub params: TspParams,
    /// Training metadata
    pub metadata: TspModelMetadata,
}
93
94impl TspModel {
95    /// Create a new TSP model with default ACO parameters
96    pub fn new(algorithm: TspAlgorithm) -> Self {
97        let params = match algorithm {
98            TspAlgorithm::Aco => TspParams::Aco {
99                alpha: 1.0,
100                beta: 2.5,
101                rho: 0.1,
102                q0: 0.9,
103                num_ants: 20,
104            },
105            TspAlgorithm::Tabu => TspParams::Tabu {
106                tenure: 20,
107                max_neighbors: 100,
108            },
109            TspAlgorithm::Ga => TspParams::Ga {
110                population_size: 50,
111                crossover_rate: 0.9,
112                mutation_rate: 0.1,
113            },
114            TspAlgorithm::Hybrid => TspParams::Hybrid {
115                ga_fraction: 0.4,
116                tabu_fraction: 0.3,
117                aco_fraction: 0.3,
118            },
119        };
120
121        Self {
122            algorithm,
123            params,
124            metadata: TspModelMetadata::default(),
125        }
126    }
127
128    /// Set parameters
129    pub fn with_params(mut self, params: TspParams) -> Self {
130        self.params = params;
131        self
132    }
133
134    /// Set metadata
135    pub fn with_metadata(mut self, metadata: TspModelMetadata) -> Self {
136        self.metadata = metadata;
137        self
138    }
139
140    /// Save model to .apr file
141    pub fn save(&self, path: &Path) -> TspResult<()> {
142        let mut file = std::fs::File::create(path)?;
143
144        // Serialize payload first to compute checksum
145        let payload = self.serialize_payload();
146
147        // Write header
148        file.write_all(MAGIC)?;
149        file.write_all(&VERSION.to_le_bytes())?;
150        file.write_all(&MODEL_TYPE_TSP.to_le_bytes())?;
151
152        // Compute and write checksum
153        let checksum = crc32fast::hash(&payload);
154        file.write_all(&checksum.to_le_bytes())?;
155
156        // Write payload
157        file.write_all(&payload)?;
158
159        Ok(())
160    }
161
162    /// Load model from .apr file
163    pub fn load(path: &Path) -> TspResult<Self> {
164        let mut file = std::fs::File::open(path)?;
165        let mut data = Vec::new();
166        file.read_to_end(&mut data)?;
167
168        Self::from_bytes(&data, path)
169    }
170
171    /// Load from bytes
172    fn from_bytes(data: &[u8], path: &Path) -> TspResult<Self> {
173        // Minimum size: magic(4) + version(4) + type(4) + checksum(4) + min_payload
174        if data.len() < 16 {
175            return Err(TspError::InvalidFormat {
176                message: "File too small".into(),
177                hint: "Ensure this is a valid .apr file".into(),
178            });
179        }
180
181        // Verify magic
182        if &data[0..4] != MAGIC {
183            return Err(TspError::InvalidFormat {
184                message: "Not an .apr file".into(),
185                hint: format!("Expected magic 'APR\\x00', got {:?}", &data[0..4]),
186            });
187        }
188
189        // Verify version
190        let version = u32::from_le_bytes([data[4], data[5], data[6], data[7]]);
191        if version != VERSION {
192            return Err(TspError::InvalidFormat {
193                message: format!("Unsupported version: {version}"),
194                hint: format!("This tool supports version {VERSION}"),
195            });
196        }
197
198        // Verify model type
199        let model_type = u32::from_le_bytes([data[8], data[9], data[10], data[11]]);
200        if model_type != MODEL_TYPE_TSP {
201            return Err(TspError::InvalidFormat {
202                message: "Not a TSP model".into(),
203                hint: "This file contains a different model type".into(),
204            });
205        }
206
207        // Verify checksum
208        let stored_checksum = u32::from_le_bytes([data[12], data[13], data[14], data[15]]);
209        let payload = &data[16..];
210        let computed_checksum = crc32fast::hash(payload);
211
212        if stored_checksum != computed_checksum {
213            return Err(TspError::ChecksumMismatch {
214                expected: stored_checksum,
215                computed: computed_checksum,
216            });
217        }
218
219        Self::deserialize_payload(payload, path)
220    }
221
222    /// Serialize payload (without header)
223    fn serialize_payload(&self) -> Vec<u8> {
224        let mut payload = Vec::new();
225
226        // Algorithm type (1 byte)
227        let algo_byte = match self.algorithm {
228            TspAlgorithm::Aco => 0u8,
229            TspAlgorithm::Tabu => 1u8,
230            TspAlgorithm::Ga => 2u8,
231            TspAlgorithm::Hybrid => 3u8,
232        };
233        payload.push(algo_byte);
234
235        // Metadata
236        payload.extend_from_slice(&self.metadata.trained_instances.to_le_bytes());
237        payload.extend_from_slice(&self.metadata.avg_instance_size.to_le_bytes());
238        payload.extend_from_slice(&self.metadata.best_known_gap.to_le_bytes());
239        payload.extend_from_slice(&self.metadata.training_time_secs.to_le_bytes());
240
241        // Algorithm-specific parameters
242        match &self.params {
243            TspParams::Aco {
244                alpha,
245                beta,
246                rho,
247                q0,
248                num_ants,
249            } => {
250                payload.extend_from_slice(&alpha.to_le_bytes());
251                payload.extend_from_slice(&beta.to_le_bytes());
252                payload.extend_from_slice(&rho.to_le_bytes());
253                payload.extend_from_slice(&q0.to_le_bytes());
254                payload.extend_from_slice(&(*num_ants as u32).to_le_bytes());
255            }
256            TspParams::Tabu {
257                tenure,
258                max_neighbors,
259            } => {
260                payload.extend_from_slice(&(*tenure as u32).to_le_bytes());
261                payload.extend_from_slice(&(*max_neighbors as u32).to_le_bytes());
262            }
263            TspParams::Ga {
264                population_size,
265                crossover_rate,
266                mutation_rate,
267            } => {
268                payload.extend_from_slice(&(*population_size as u32).to_le_bytes());
269                payload.extend_from_slice(&crossover_rate.to_le_bytes());
270                payload.extend_from_slice(&mutation_rate.to_le_bytes());
271            }
272            TspParams::Hybrid {
273                ga_fraction,
274                tabu_fraction,
275                aco_fraction,
276            } => {
277                payload.extend_from_slice(&ga_fraction.to_le_bytes());
278                payload.extend_from_slice(&tabu_fraction.to_le_bytes());
279                payload.extend_from_slice(&aco_fraction.to_le_bytes());
280            }
281        }
282
283        payload
284    }
285
286    /// Deserialize payload
287    #[allow(clippy::too_many_lines)]
288    fn deserialize_payload(payload: &[u8], path: &Path) -> TspResult<Self> {
289        if payload.is_empty() {
290            return Err(TspError::ParseError {
291                file: path.to_path_buf(),
292                line: None,
293                cause: "Empty payload".into(),
294            });
295        }
296
297        let algo_byte = payload[0];
298        let algorithm = match algo_byte {
299            0 => TspAlgorithm::Aco,
300            1 => TspAlgorithm::Tabu,
301            2 => TspAlgorithm::Ga,
302            3 => TspAlgorithm::Hybrid,
303            _ => {
304                return Err(TspError::InvalidFormat {
305                    message: format!("Unknown algorithm type: {algo_byte}"),
306                    hint: "Supported: aco (0), tabu (1), ga (2), hybrid (3)".into(),
307                });
308            }
309        };
310
311        // Minimum payload size check
312        let min_size = 1 + 4 + 4 + 8 + 8; // algo + metadata
313        if payload.len() < min_size {
314            return Err(TspError::ParseError {
315                file: path.to_path_buf(),
316                line: None,
317                cause: "Payload too small for metadata".into(),
318            });
319        }
320
321        let mut offset = 1;
322
323        // Metadata
324        let trained_instances = u32::from_le_bytes([
325            payload[offset],
326            payload[offset + 1],
327            payload[offset + 2],
328            payload[offset + 3],
329        ]);
330        offset += 4;
331
332        let avg_instance_size = u32::from_le_bytes([
333            payload[offset],
334            payload[offset + 1],
335            payload[offset + 2],
336            payload[offset + 3],
337        ]);
338        offset += 4;
339
340        let best_known_gap = f64::from_le_bytes([
341            payload[offset],
342            payload[offset + 1],
343            payload[offset + 2],
344            payload[offset + 3],
345            payload[offset + 4],
346            payload[offset + 5],
347            payload[offset + 6],
348            payload[offset + 7],
349        ]);
350        offset += 8;
351
352        let training_time_secs = f64::from_le_bytes([
353            payload[offset],
354            payload[offset + 1],
355            payload[offset + 2],
356            payload[offset + 3],
357            payload[offset + 4],
358            payload[offset + 5],
359            payload[offset + 6],
360            payload[offset + 7],
361        ]);
362        offset += 8;
363
364        let metadata = TspModelMetadata {
365            trained_instances,
366            avg_instance_size,
367            best_known_gap,
368            training_time_secs,
369        };
370
371        // Algorithm-specific parameters
372        let params =
373            match algorithm {
374                TspAlgorithm::Aco => {
375                    let alpha =
376                        f64::from_le_bytes(payload[offset..offset + 8].try_into().map_err(
377                            |_| TspError::ParseError {
378                                file: path.to_path_buf(),
379                                line: None,
380                                cause: "Failed to read alpha".into(),
381                            },
382                        )?);
383                    offset += 8;
384
385                    let beta = f64::from_le_bytes(payload[offset..offset + 8].try_into().map_err(
386                        |_| TspError::ParseError {
387                            file: path.to_path_buf(),
388                            line: None,
389                            cause: "Failed to read beta".into(),
390                        },
391                    )?);
392                    offset += 8;
393
394                    let rho = f64::from_le_bytes(payload[offset..offset + 8].try_into().map_err(
395                        |_| TspError::ParseError {
396                            file: path.to_path_buf(),
397                            line: None,
398                            cause: "Failed to read rho".into(),
399                        },
400                    )?);
401                    offset += 8;
402
403                    let q0 = f64::from_le_bytes(payload[offset..offset + 8].try_into().map_err(
404                        |_| TspError::ParseError {
405                            file: path.to_path_buf(),
406                            line: None,
407                            cause: "Failed to read q0".into(),
408                        },
409                    )?);
410                    offset += 8;
411
412                    let num_ants =
413                        u32::from_le_bytes(payload[offset..offset + 4].try_into().map_err(
414                            |_| TspError::ParseError {
415                                file: path.to_path_buf(),
416                                line: None,
417                                cause: "Failed to read num_ants".into(),
418                            },
419                        )?) as usize;
420
421                    TspParams::Aco {
422                        alpha,
423                        beta,
424                        rho,
425                        q0,
426                        num_ants,
427                    }
428                }
429                TspAlgorithm::Tabu => {
430                    let tenure =
431                        u32::from_le_bytes(payload[offset..offset + 4].try_into().map_err(
432                            |_| TspError::ParseError {
433                                file: path.to_path_buf(),
434                                line: None,
435                                cause: "Failed to read tenure".into(),
436                            },
437                        )?) as usize;
438                    offset += 4;
439
440                    let max_neighbors =
441                        u32::from_le_bytes(payload[offset..offset + 4].try_into().map_err(
442                            |_| TspError::ParseError {
443                                file: path.to_path_buf(),
444                                line: None,
445                                cause: "Failed to read max_neighbors".into(),
446                            },
447                        )?) as usize;
448
449                    TspParams::Tabu {
450                        tenure,
451                        max_neighbors,
452                    }
453                }
454                TspAlgorithm::Ga => {
455                    let population_size =
456                        u32::from_le_bytes(payload[offset..offset + 4].try_into().map_err(
457                            |_| TspError::ParseError {
458                                file: path.to_path_buf(),
459                                line: None,
460                                cause: "Failed to read population_size".into(),
461                            },
462                        )?) as usize;
463                    offset += 4;
464
465                    let crossover_rate =
466                        f64::from_le_bytes(payload[offset..offset + 8].try_into().map_err(
467                            |_| TspError::ParseError {
468                                file: path.to_path_buf(),
469                                line: None,
470                                cause: "Failed to read crossover_rate".into(),
471                            },
472                        )?);
473                    offset += 8;
474
475                    let mutation_rate =
476                        f64::from_le_bytes(payload[offset..offset + 8].try_into().map_err(
477                            |_| TspError::ParseError {
478                                file: path.to_path_buf(),
479                                line: None,
480                                cause: "Failed to read mutation_rate".into(),
481                            },
482                        )?);
483
484                    TspParams::Ga {
485                        population_size,
486                        crossover_rate,
487                        mutation_rate,
488                    }
489                }
490                TspAlgorithm::Hybrid => {
491                    let ga_fraction =
492                        f64::from_le_bytes(payload[offset..offset + 8].try_into().map_err(
493                            |_| TspError::ParseError {
494                                file: path.to_path_buf(),
495                                line: None,
496                                cause: "Failed to read ga_fraction".into(),
497                            },
498                        )?);
499                    offset += 8;
500
501                    let tabu_fraction =
502                        f64::from_le_bytes(payload[offset..offset + 8].try_into().map_err(
503                            |_| TspError::ParseError {
504                                file: path.to_path_buf(),
505                                line: None,
506                                cause: "Failed to read tabu_fraction".into(),
507                            },
508                        )?);
509                    offset += 8;
510
511                    let aco_fraction =
512                        f64::from_le_bytes(payload[offset..offset + 8].try_into().map_err(
513                            |_| TspError::ParseError {
514                                file: path.to_path_buf(),
515                                line: None,
516                                cause: "Failed to read aco_fraction".into(),
517                            },
518                        )?);
519
520                    TspParams::Hybrid {
521                        ga_fraction,
522                        tabu_fraction,
523                        aco_fraction,
524                    }
525                }
526            };
527
528        Ok(Self {
529            algorithm,
530            params,
531            metadata,
532        })
533    }
534}
535
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writes `model` to `test.apr` in a fresh temporary directory and reads it back.
    fn save_then_load(model: &TspModel) -> TspModel {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.apr");

        model.save(&path).expect("should save");
        TspModel::load(&path).expect("should load")
    }

    /// Writes `bytes` to `file_name` in a fresh temporary directory, checks that
    /// loading fails and returns the error rendered as text.
    fn load_failure_message(file_name: &str, bytes: &[u8]) -> String {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join(file_name);
        std::fs::write(&path, bytes).unwrap();

        let result = TspModel::load(&path);
        assert!(result.is_err());
        result.unwrap_err().to_string()
    }

    /// True when `actual` is within 1e-10 of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-10
    }

    #[test]
    fn test_model_new_aco() {
        let TspModel {
            algorithm, params, ..
        } = TspModel::new(TspAlgorithm::Aco);
        assert_eq!(algorithm, TspAlgorithm::Aco);
        assert!(matches!(params, TspParams::Aco { .. }));
    }

    #[test]
    fn test_model_new_tabu() {
        let TspModel {
            algorithm, params, ..
        } = TspModel::new(TspAlgorithm::Tabu);
        assert_eq!(algorithm, TspAlgorithm::Tabu);
        assert!(matches!(params, TspParams::Tabu { .. }));
    }

    #[test]
    fn test_model_new_ga() {
        let TspModel {
            algorithm, params, ..
        } = TspModel::new(TspAlgorithm::Ga);
        assert_eq!(algorithm, TspAlgorithm::Ga);
        assert!(matches!(params, TspParams::Ga { .. }));
    }

    #[test]
    fn test_model_new_hybrid() {
        let TspModel {
            algorithm, params, ..
        } = TspModel::new(TspAlgorithm::Hybrid);
        assert_eq!(algorithm, TspAlgorithm::Hybrid);
        assert!(matches!(params, TspParams::Hybrid { .. }));
    }

    #[test]
    fn test_model_save_load_aco() {
        let custom = TspParams::Aco {
            alpha: 2.0,
            beta: 3.5,
            rho: 0.2,
            q0: 0.85,
            num_ants: 30,
        };
        let loaded = save_then_load(&TspModel::new(TspAlgorithm::Aco).with_params(custom));

        assert_eq!(loaded.algorithm, TspAlgorithm::Aco);
        match loaded.params {
            TspParams::Aco {
                alpha,
                beta,
                rho,
                q0,
                num_ants,
            } => {
                assert!(close(alpha, 2.0));
                assert!(close(beta, 3.5));
                assert!(close(rho, 0.2));
                assert!(close(q0, 0.85));
                assert_eq!(num_ants, 30);
            }
            _ => panic!("Expected ACO params"),
        }
    }

    #[test]
    fn test_model_save_load_tabu() {
        let custom = TspParams::Tabu {
            tenure: 25,
            max_neighbors: 150,
        };
        let loaded = save_then_load(&TspModel::new(TspAlgorithm::Tabu).with_params(custom));

        assert_eq!(loaded.algorithm, TspAlgorithm::Tabu);
        match loaded.params {
            TspParams::Tabu {
                tenure,
                max_neighbors,
            } => {
                assert_eq!(tenure, 25);
                assert_eq!(max_neighbors, 150);
            }
            _ => panic!("Expected Tabu params"),
        }
    }

    #[test]
    fn test_model_save_load_ga() {
        let custom = TspParams::Ga {
            population_size: 100,
            crossover_rate: 0.85,
            mutation_rate: 0.15,
        };
        let loaded = save_then_load(&TspModel::new(TspAlgorithm::Ga).with_params(custom));

        assert_eq!(loaded.algorithm, TspAlgorithm::Ga);
        match loaded.params {
            TspParams::Ga {
                population_size,
                crossover_rate,
                mutation_rate,
            } => {
                assert_eq!(population_size, 100);
                assert!(close(crossover_rate, 0.85));
                assert!(close(mutation_rate, 0.15));
            }
            _ => panic!("Expected GA params"),
        }
    }

    #[test]
    fn test_model_save_load_hybrid() {
        let custom = TspParams::Hybrid {
            ga_fraction: 0.5,
            tabu_fraction: 0.25,
            aco_fraction: 0.25,
        };
        let loaded = save_then_load(&TspModel::new(TspAlgorithm::Hybrid).with_params(custom));

        assert_eq!(loaded.algorithm, TspAlgorithm::Hybrid);
        match loaded.params {
            TspParams::Hybrid {
                ga_fraction,
                tabu_fraction,
                aco_fraction,
            } => {
                assert!(close(ga_fraction, 0.5));
                assert!(close(tabu_fraction, 0.25));
                assert!(close(aco_fraction, 0.25));
            }
            _ => panic!("Expected Hybrid params"),
        }
    }

    #[test]
    fn test_model_metadata_roundtrip() {
        let model = TspModel::new(TspAlgorithm::Aco).with_metadata(TspModelMetadata {
            trained_instances: 10,
            avg_instance_size: 52,
            best_known_gap: 0.03,
            training_time_secs: 2.5,
        });

        let TspModelMetadata {
            trained_instances,
            avg_instance_size,
            best_known_gap,
            training_time_secs,
        } = save_then_load(&model).metadata;

        assert_eq!(trained_instances, 10);
        assert_eq!(avg_instance_size, 52);
        assert!(close(best_known_gap, 0.03));
        assert!(close(training_time_secs, 2.5));
    }

    #[test]
    fn test_model_invalid_magic() {
        // Wrong magic, zero-padded to 20 bytes so the size check passes first.
        let mut data = b"BAD\x00".to_vec();
        data.resize(20, 0u8);

        let err = load_failure_message("bad.apr", &data);
        assert!(err.contains("Not an .apr file"), "Unexpected error: {err}");
    }

    #[test]
    fn test_model_invalid_checksum() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("corrupt.apr");

        // Start from a valid file...
        TspModel::new(TspAlgorithm::Aco)
            .save(&path)
            .expect("should save");

        // ...then flip every bit of the first stored checksum byte.
        let mut data = std::fs::read(&path).unwrap();
        data[12] ^= 0xFF;
        std::fs::write(&path, &data).unwrap();

        let result = TspModel::load(&path);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("checksum mismatch"));
    }

    #[test]
    fn test_model_file_too_small() {
        // Only the magic: shorter than the 16-byte header.
        let err = load_failure_message("small.apr", b"APR\x00");
        assert!(err.contains("too small"));
    }

    #[test]
    fn test_model_unsupported_version() {
        // Header only: magic, a future version (99), the TSP type tag and a fake checksum.
        let mut data = MAGIC.to_vec();
        for field in [99u32, MODEL_TYPE_TSP, 0u32] {
            data.extend_from_slice(&field.to_le_bytes());
        }

        let err = load_failure_message("future.apr", &data);
        assert!(err.contains("version"));
    }
}