1use crate::error::AprError;
6use crate::metadata::AprMetadata;
7use crate::model::{AprModel, ModelData};
8use crate::{MAX_MODEL_SIZE, MIN_SUPPORTED_VERSION};
9
/// Four-byte signature expected at the very start of an APR file.
///
/// `AprFile::has_magic` compares the first four input bytes against this value.
pub const APR_MAGIC: &[u8; 4] = b"APNR";

/// Format version constant exposed by this module.
///
/// NOTE(review): the parsing code in this file never compares against this
/// value (it only rejects versions below `MIN_SUPPORTED_VERSION`); presumably
/// it is the version the writer emits — confirm against the serializer.
pub const APR_VERSION: u16 = 1;

/// Size in bytes of the fixed header read by `AprFile::from_bytes`:
/// magic (bytes 0..4) + version (bytes 4..6) + CRC32 checksum (bytes 6..10).
const HEADER_SIZE: usize = 10;
18
/// A parsed APR file: the format version taken from the header together with
/// the model decoded from the rest of the file.
#[derive(Debug)]
pub struct AprFile {
    /// Format version read from header bytes 4..6 (little-endian `u16`).
    pub version: u16,
    /// Model assembled from the decoded metadata section and model data section.
    pub model: AprModel,
}
27
28impl AprFile {
29 #[must_use]
31 pub fn has_magic(bytes: &[u8]) -> bool {
32 bytes.len() >= 4 && &bytes[0..4] == APR_MAGIC
33 }
34
35 pub fn from_bytes(bytes: &[u8]) -> Result<Self, AprError> {
46 if bytes.len() < HEADER_SIZE {
48 return Err(AprError::FileTooSmall { size: bytes.len() });
49 }
50
51 if !Self::has_magic(bytes) {
53 return Err(AprError::invalid_magic(bytes));
54 }
55
56 let version = u16::from_le_bytes([bytes[4], bytes[5]]);
58 if version < MIN_SUPPORTED_VERSION {
59 return Err(AprError::UnsupportedVersion {
60 version,
61 min_supported: MIN_SUPPORTED_VERSION,
62 });
63 }
64
65 let stored_checksum = u32::from_le_bytes([bytes[6], bytes[7], bytes[8], bytes[9]]);
67 let computed_checksum = crc32fast::hash(&bytes[HEADER_SIZE..]);
68 if stored_checksum != computed_checksum {
69 return Err(AprError::ChecksumMismatch {
70 expected: stored_checksum,
71 computed: computed_checksum,
72 });
73 }
74
75 if bytes.len() > MAX_MODEL_SIZE {
77 return Err(AprError::ModelTooLarge {
78 size: bytes.len(),
79 max: MAX_MODEL_SIZE,
80 });
81 }
82
83 if bytes.len() < HEADER_SIZE + 4 {
85 return Err(AprError::FileTooSmall { size: bytes.len() });
86 }
87 let metadata_len = u32::from_le_bytes([
88 bytes[HEADER_SIZE],
89 bytes[HEADER_SIZE + 1],
90 bytes[HEADER_SIZE + 2],
91 bytes[HEADER_SIZE + 3],
92 ]) as usize;
93
94 let metadata_start = HEADER_SIZE + 4;
96 let metadata_end = metadata_start + metadata_len;
97 if metadata_end > bytes.len() {
98 return Err(AprError::CborDecode(
99 "Metadata length exceeds file size".to_string(),
100 ));
101 }
102
103 let metadata = AprMetadata::from_cbor(&bytes[metadata_start..metadata_end])?;
105
106 let data_start = metadata_end;
108 let data = ModelData::decompress(&bytes[data_start..])?;
109
110 Ok(Self {
111 version,
112 model: AprModel { metadata, data },
113 })
114 }
115}
116
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::model::ModelArchitecture;

    /// A buffer that starts with the magic is recognised.
    #[test]
    fn test_has_magic_valid() {
        assert!(AprFile::has_magic(b"APNRsomedata"));
    }

    /// A buffer with a different prefix is not recognised.
    #[test]
    fn test_has_magic_invalid() {
        assert!(!AprFile::has_magic(b"WRONGdata"));
    }

    /// Fewer than four bytes can never carry the magic.
    #[test]
    fn test_has_magic_too_short() {
        assert!(!AprFile::has_magic(b"APR"));
    }

    /// Input shorter than the fixed header is rejected as too small.
    #[test]
    fn test_file_too_small() {
        let outcome = AprFile::from_bytes(b"APNR");
        assert!(matches!(outcome, Err(AprError::FileTooSmall { .. })));
    }

    /// Header-sized input with the wrong signature is rejected.
    #[test]
    fn test_invalid_magic() {
        let outcome = AprFile::from_bytes(b"WRONG_____");
        assert!(matches!(outcome, Err(AprError::InvalidMagic { .. })));
    }

    /// A header declaring version 0 is rejected as unsupported.
    #[test]
    fn test_version_zero_rejected() {
        // magic | version = 0 | checksum = 0
        let header: Vec<u8> = [
            &APR_MAGIC[..],
            &0_u16.to_le_bytes()[..],
            &0_u32.to_le_bytes()[..],
        ]
        .concat();

        let outcome = AprFile::from_bytes(&header);
        assert!(matches!(outcome, Err(AprError::UnsupportedVersion { .. })));
    }

    /// Serialising a model and parsing it back preserves name and weights.
    #[test]
    fn test_full_roundtrip() {
        let metadata = AprMetadata::builder()
            .name("roundtrip-test")
            .version("1.0.0")
            .author("Test")
            .license("MIT")
            .build()
            .expect("metadata");
        let data = ModelData {
            weights: vec![1.0, 2.0, 3.0],
            biases: vec![0.1],
            architecture: ModelArchitecture::Mlp {
                layers: vec![1, 2, 1],
            },
        };
        let original = AprModel { metadata, data };

        let encoded = original.to_bytes().expect("serialize");
        let decoded = AprFile::from_bytes(&encoded).expect("deserialize");

        assert_eq!(decoded.model.metadata.name, "roundtrip-test");
        assert_eq!(decoded.model.data.weights, vec![1.0, 2.0, 3.0]);
    }

    /// Flipping a byte after the header must surface as a checksum mismatch.
    #[test]
    fn test_checksum_corruption_detected() {
        let mut encoded = AprModel::new_test_model().to_bytes().expect("serialize");

        // Invert every bit of byte 20 when the buffer is long enough to have one.
        if let Some(byte) = encoded.get_mut(20) {
            *byte ^= 0xFF;
        }

        let outcome = AprFile::from_bytes(&encoded);
        assert!(matches!(outcome, Err(AprError::ChecksumMismatch { .. })));
    }
}