Skip to main content

jugar_apr/
format.rs

//! APR file format handling.
//!
//! Per spec Section 4.1: an APR file starts with a magic number, a version, and a checksum.
4
5use crate::error::AprError;
6use crate::metadata::AprMetadata;
7use crate::model::{AprModel, ModelData};
8use crate::{MAX_MODEL_SIZE, MIN_SUPPORTED_VERSION};
9
/// Four-byte signature (ASCII `APNR`) that every APR file begins with.
pub const APR_MAGIC: &[u8; 4] = b"APNR";

/// Current version of the APR file format.
pub const APR_VERSION: u16 = 1;

/// Size in bytes of the fixed file header: magic, `u16` version, `u32` checksum.
const HEADER_SIZE: usize =
    APR_MAGIC.len() + std::mem::size_of::<u16>() + std::mem::size_of::<u32>();
18
19/// Parsed APR file
20#[derive(Debug)]
21pub struct AprFile {
22    /// File version
23    pub version: u16,
24    /// Loaded model
25    pub model: AprModel,
26}
27
28impl AprFile {
29    /// Check if bytes start with APR magic number
30    #[must_use]
31    pub fn has_magic(bytes: &[u8]) -> bool {
32        bytes.len() >= 4 && &bytes[0..4] == APR_MAGIC
33    }
34
35    /// Parse an APR file from bytes
36    ///
37    /// # Errors
38    ///
39    /// Returns error if:
40    /// - File too small
41    /// - Invalid magic number
42    /// - Unsupported version
43    /// - Checksum mismatch
44    /// - Invalid metadata or model data
45    pub fn from_bytes(bytes: &[u8]) -> Result<Self, AprError> {
46        // Check minimum size
47        if bytes.len() < HEADER_SIZE {
48            return Err(AprError::FileTooSmall { size: bytes.len() });
49        }
50
51        // Check magic
52        if !Self::has_magic(bytes) {
53            return Err(AprError::invalid_magic(bytes));
54        }
55
56        // Read version
57        let version = u16::from_le_bytes([bytes[4], bytes[5]]);
58        if version < MIN_SUPPORTED_VERSION {
59            return Err(AprError::UnsupportedVersion {
60                version,
61                min_supported: MIN_SUPPORTED_VERSION,
62            });
63        }
64
65        // Read and verify checksum
66        let stored_checksum = u32::from_le_bytes([bytes[6], bytes[7], bytes[8], bytes[9]]);
67        let computed_checksum = crc32fast::hash(&bytes[HEADER_SIZE..]);
68        if stored_checksum != computed_checksum {
69            return Err(AprError::ChecksumMismatch {
70                expected: stored_checksum,
71                computed: computed_checksum,
72            });
73        }
74
75        // Check size limit
76        if bytes.len() > MAX_MODEL_SIZE {
77            return Err(AprError::ModelTooLarge {
78                size: bytes.len(),
79                max: MAX_MODEL_SIZE,
80            });
81        }
82
83        // Read metadata length (4 bytes after header)
84        if bytes.len() < HEADER_SIZE + 4 {
85            return Err(AprError::FileTooSmall { size: bytes.len() });
86        }
87        let metadata_len = u32::from_le_bytes([
88            bytes[HEADER_SIZE],
89            bytes[HEADER_SIZE + 1],
90            bytes[HEADER_SIZE + 2],
91            bytes[HEADER_SIZE + 3],
92        ]) as usize;
93
94        // Validate metadata bounds
95        let metadata_start = HEADER_SIZE + 4;
96        let metadata_end = metadata_start + metadata_len;
97        if metadata_end > bytes.len() {
98            return Err(AprError::CborDecode(
99                "Metadata length exceeds file size".to_string(),
100            ));
101        }
102
103        // Parse metadata
104        let metadata = AprMetadata::from_cbor(&bytes[metadata_start..metadata_end])?;
105
106        // Parse compressed model data
107        let data_start = metadata_end;
108        let data = ModelData::decompress(&bytes[data_start..])?;
109
110        Ok(Self {
111            version,
112            model: AprModel { metadata, data },
113        })
114    }
115}
116
#[cfg(test)]
// Tests may panic freely: a failed unwrap/expect is a test failure, not a library bug.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::model::ModelArchitecture;

    /// Builds `magic + APR_VERSION + CRC-32(payload) + payload`, i.e. a buffer whose
    /// header passes the magic, version, and checksum checks.
    /// Assumes APR_VERSION >= MIN_SUPPORTED_VERSION (as the roundtrip test implies).
    fn with_valid_header(payload: &[u8]) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(HEADER_SIZE + payload.len());
        bytes.extend_from_slice(APR_MAGIC);
        bytes.extend_from_slice(&APR_VERSION.to_le_bytes());
        bytes.extend_from_slice(&crc32fast::hash(payload).to_le_bytes());
        bytes.extend_from_slice(payload);
        bytes
    }

    #[test]
    fn test_has_magic_valid() {
        let valid = b"APNRsomedata";
        assert!(AprFile::has_magic(valid));
    }

    #[test]
    fn test_has_magic_invalid() {
        let invalid = b"WRONGdata";
        assert!(!AprFile::has_magic(invalid));
    }

    #[test]
    fn test_has_magic_too_short() {
        let short = b"APR";
        assert!(!AprFile::has_magic(short));
    }

    #[test]
    fn test_file_too_small() {
        let tiny = b"APNR";
        let result = AprFile::from_bytes(tiny);
        assert!(matches!(result, Err(AprError::FileTooSmall { .. })));
    }

    #[test]
    fn test_invalid_magic() {
        let bad = b"WRONG_____";
        let result = AprFile::from_bytes(bad);
        assert!(matches!(result, Err(AprError::InvalidMagic { .. })));
    }

    #[test]
    fn test_version_zero_rejected() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(APR_MAGIC);
        bytes.extend_from_slice(&0_u16.to_le_bytes());
        bytes.extend_from_slice(&0_u32.to_le_bytes());

        let result = AprFile::from_bytes(&bytes);
        assert!(matches!(result, Err(AprError::UnsupportedVersion { .. })));
    }

    #[test]
    fn test_missing_metadata_length() {
        // Header is valid, but the payload is too short to hold the u32 metadata length.
        for payload in [&[][..], &[0x01, 0x00][..]] {
            let bytes = with_valid_header(payload);
            let result = AprFile::from_bytes(&bytes);
            assert!(matches!(result, Err(AprError::FileTooSmall { .. })));
        }
    }

    #[test]
    fn test_metadata_length_exceeds_file() {
        let bytes = with_valid_header(&1000_u32.to_le_bytes());
        let result = AprFile::from_bytes(&bytes);
        assert!(matches!(result, Err(AprError::CborDecode(_))));
    }

    #[test]
    fn test_metadata_length_max_u32_rejected() {
        // Must be a clean error, never an arithmetic-overflow or slice panic.
        let bytes = with_valid_header(&u32::MAX.to_le_bytes());
        let result = AprFile::from_bytes(&bytes);
        assert!(matches!(result, Err(AprError::CborDecode(_))));
    }

    #[test]
    fn test_full_roundtrip() {
        let model = AprModel {
            metadata: AprMetadata::builder()
                .name("roundtrip-test")
                .version("1.0.0")
                .author("Test")
                .license("MIT")
                .build()
                .expect("metadata"),
            data: ModelData {
                weights: vec![1.0, 2.0, 3.0],
                biases: vec![0.1],
                architecture: ModelArchitecture::Mlp {
                    layers: vec![1, 2, 1],
                },
            },
        };

        let bytes = model.to_bytes().expect("serialize");
        let loaded = AprFile::from_bytes(&bytes).expect("deserialize");

        assert_eq!(loaded.model.metadata.name, "roundtrip-test");
        assert_eq!(loaded.model.data.weights, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_checksum_corruption_detected() {
        let model = AprModel::new_test_model();
        let mut bytes = model.to_bytes().expect("serialize");

        // Fail loudly instead of silently skipping the corruption step.
        assert!(bytes.len() > 20, "test model too small to corrupt byte 20");
        bytes[20] ^= 0xFF;

        let result = AprFile::from_bytes(&bytes);
        assert!(matches!(result, Err(AprError::ChecksumMismatch { .. })));
    }
}