1use crate::error::AprError;
6use crate::metadata::AprMetadata;
7use crate::MAX_MODEL_SIZE;
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone)]
12pub struct AprModel {
13 pub metadata: AprMetadata,
15 pub data: ModelData,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
21#[allow(clippy::derive_partial_eq_without_eq)] pub struct ModelData {
23 pub weights: Vec<f32>,
25 pub biases: Vec<f32>,
27 pub architecture: ModelArchitecture,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
33pub enum ModelArchitecture {
34 Mlp {
36 layers: Vec<usize>,
38 },
39 BehaviorTree {
41 nodes: usize,
43 },
44}
45
46impl core::fmt::Display for ModelArchitecture {
47 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
48 match self {
49 Self::Mlp { layers } => {
50 write!(f, "mlp-")?;
51 for (i, layer) in layers.iter().enumerate() {
52 if i > 0 {
53 write!(f, "-")?;
54 }
55 write!(f, "{layer}")?;
56 }
57 Ok(())
58 }
59 Self::BehaviorTree { nodes } => write!(f, "bt-{nodes}"),
60 }
61 }
62}
63
64impl ModelData {
65 pub fn compress(&self) -> Result<Vec<u8>, AprError> {
71 let mut cbor_data = Vec::new();
73 ciborium::into_writer(self, &mut cbor_data)
74 .map_err(|e| AprError::Compression(e.to_string()))?;
75
76 let compressed = miniz_oxide::deflate::compress_to_vec(&cbor_data, 6); Ok(compressed)
80 }
81
82 pub fn decompress(bytes: &[u8]) -> Result<Self, AprError> {
88 let decompressed = miniz_oxide::inflate::decompress_to_vec(bytes)
90 .map_err(|e| AprError::Decompression(format!("{e:?}")))?;
91
92 ciborium::from_reader(decompressed.as_slice())
94 .map_err(|e| AprError::CborDecode(e.to_string()))
95 }
96}
97
/// Quality indicators for a model. [`ModelQualityAssessment::meets_minimum_standards`]
/// checks them against fixed thresholds.
#[derive(Debug, Clone)]
pub struct ModelQualityAssessment {
    /// Test-retest reliability coefficient.
    pub test_retest_reliability: f64,
    /// Whether content validity has been judged adequate.
    pub content_validity_adequate: bool,
    /// Responsiveness expressed as a Cohen's d effect size.
    pub responsiveness_cohens_d: f64,
}
109
110impl ModelQualityAssessment {
111 #[must_use]
118 pub fn meets_minimum_standards(&self) -> bool {
119 self.test_retest_reliability >= 0.70
120 && self.content_validity_adequate
121 && self.responsiveness_cohens_d >= 0.30
122 }
123}
124
125impl AprModel {
126 #[must_use]
132 #[allow(clippy::expect_used)]
133 pub fn new_test_model() -> Self {
134 Self {
135 metadata: AprMetadata::builder()
136 .name("test-model")
137 .version("1.0.0")
138 .author("Test")
139 .license("MIT")
140 .build()
141 .expect("Test model metadata should be valid"),
142 data: ModelData {
143 weights: vec![0.1, 0.2, 0.3, 0.4],
144 biases: vec![0.01, 0.02],
145 architecture: ModelArchitecture::Mlp {
146 layers: vec![2, 2, 1],
147 },
148 },
149 }
150 }
151
152 pub fn builtin(name: &str) -> Result<Self, AprError> {
158 match name {
159 "chase" => Ok(Self::builtin_chase()),
160 "patrol" => Ok(Self::builtin_patrol()),
161 "wander" => Ok(Self::builtin_wander()),
162 _ => Err(AprError::UnknownBuiltin {
163 name: name.to_string(),
164 }),
165 }
166 }
167
168 #[allow(clippy::expect_used)]
170 fn builtin_chase() -> Self {
171 Self {
172 metadata: AprMetadata::builder()
173 .name("builtin-chase")
174 .version("1.0.0")
175 .author("Jugar")
176 .license("MIT")
177 .description("Chase the player directly")
178 .build()
179 .expect("Builtin metadata is hardcoded and valid"),
180 data: ModelData {
181 weights: vec![1.0, 0.0, 0.0, 1.0], biases: vec![0.0, 0.0],
184 architecture: ModelArchitecture::Mlp { layers: vec![2, 2] },
185 },
186 }
187 }
188
189 #[allow(clippy::expect_used)]
191 fn builtin_patrol() -> Self {
192 Self {
193 metadata: AprMetadata::builder()
194 .name("builtin-patrol")
195 .version("1.0.0")
196 .author("Jugar")
197 .license("MIT")
198 .description("Patrol back and forth")
199 .build()
200 .expect("Builtin metadata is hardcoded and valid"),
201 data: ModelData {
202 weights: vec![1.0, -1.0], biases: vec![0.0],
204 architecture: ModelArchitecture::BehaviorTree { nodes: 3 },
205 },
206 }
207 }
208
209 #[allow(clippy::expect_used)]
211 fn builtin_wander() -> Self {
212 Self {
213 metadata: AprMetadata::builder()
214 .name("builtin-wander")
215 .version("1.0.0")
216 .author("Jugar")
217 .license("MIT")
218 .description("Wander randomly")
219 .build()
220 .expect("Builtin metadata is hardcoded and valid"),
221 data: ModelData {
222 weights: vec![0.5, 0.5, 0.5, 0.5], biases: vec![0.1, -0.1],
224 architecture: ModelArchitecture::BehaviorTree { nodes: 2 },
225 },
226 }
227 }
228
229 pub fn to_bytes(&self) -> Result<Vec<u8>, AprError> {
235 use crate::format::{APR_MAGIC, APR_VERSION};
236
237 let compressed_data = self.data.compress()?;
239
240 let metadata_cbor = self.metadata.to_cbor()?;
242
243 #[allow(clippy::cast_possible_truncation)]
246 let metadata_len = metadata_cbor.len() as u32;
247 let total_size = 10 + 4 + metadata_cbor.len() + compressed_data.len();
248
249 if total_size > MAX_MODEL_SIZE {
251 return Err(AprError::ModelTooLarge {
252 size: total_size,
253 max: MAX_MODEL_SIZE,
254 });
255 }
256
257 let mut bytes = Vec::with_capacity(total_size);
259
260 bytes.extend_from_slice(APR_MAGIC);
262 bytes.extend_from_slice(&APR_VERSION.to_le_bytes());
263 bytes.extend_from_slice(&0_u32.to_le_bytes()); bytes.extend_from_slice(&metadata_len.to_le_bytes());
267 bytes.extend_from_slice(&metadata_cbor);
268
269 bytes.extend_from_slice(&compressed_data);
271
272 let checksum = crc32fast::hash(&bytes[10..]);
274 bytes[6..10].copy_from_slice(&checksum.to_le_bytes());
275
276 Ok(bytes)
277 }
278
279 #[must_use]
281 #[allow(clippy::missing_const_for_fn)] pub fn assess_quality(&self) -> ModelQualityAssessment {
283 ModelQualityAssessment {
286 test_retest_reliability: 0.85,
287 content_validity_adequate: true,
288 responsiveness_cohens_d: 0.50,
289 }
290 }
291}
292
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Reads a little-endian `u32` starting at `offset`.
    fn read_u32_le(bytes: &[u8], offset: usize) -> u32 {
        let mut word = [0_u8; 4];
        word.copy_from_slice(&bytes[offset..offset + 4]);
        u32::from_le_bytes(word)
    }

    #[test]
    fn test_architecture_display_mlp() {
        let arch = ModelArchitecture::Mlp {
            layers: vec![2, 16, 1],
        };
        assert_eq!(arch.to_string(), "mlp-2-16-1");
    }

    #[test]
    fn test_architecture_display_bt() {
        let arch = ModelArchitecture::BehaviorTree { nodes: 5 };
        assert_eq!(arch.to_string(), "bt-5");
    }

    #[test]
    fn test_model_data_compression_roundtrip() {
        let original = ModelData {
            weights: vec![0.1, 0.2, 0.3, 0.4, 0.5],
            biases: vec![0.01, 0.02],
            architecture: ModelArchitecture::Mlp {
                layers: vec![2, 3, 1],
            },
        };

        let compressed = original.compress().expect("Should compress");
        let decompressed = ModelData::decompress(&compressed).expect("Should decompress");

        // `ModelData` derives `PartialEq`, so compare the whole value at once.
        assert_eq!(original, decompressed);
    }

    #[test]
    fn test_decompress_rejects_invalid_deflate() {
        // 0xFF opens a final deflate block with the reserved block type 0b11.
        let result = ModelData::decompress(&[0xFF]);
        assert!(matches!(result, Err(AprError::Decompression(_))));
    }

    #[test]
    fn test_builtin_models_exist() {
        assert!(AprModel::builtin("chase").is_ok());
        assert!(AprModel::builtin("patrol").is_ok());
        assert!(AprModel::builtin("wander").is_ok());
    }

    #[test]
    fn test_builtin_models_roundtrip_through_compression() {
        for name in ["chase", "patrol", "wander"] {
            let model = AprModel::builtin(name).expect("builtin should exist");
            let compressed = model.data.compress().expect("Should compress");
            let restored = ModelData::decompress(&compressed).expect("Should decompress");
            assert_eq!(model.data, restored, "round-trip mismatch for {name}");
        }
    }

    #[test]
    fn test_unknown_builtin() {
        let result = AprModel::builtin("fly");
        assert!(matches!(result, Err(AprError::UnknownBuiltin { .. })));
    }

    #[test]
    fn test_to_bytes_checksum_and_payload() {
        let model = AprModel::new_test_model();
        let bytes = model.to_bytes().expect("test model should serialize");

        // Assumed header: magic + version in bytes 0..6, then CRC32 in 6..10
        // covering everything after it, then the metadata length in 10..14.
        assert_eq!(read_u32_le(&bytes, 6), crc32fast::hash(&bytes[10..]));

        let metadata_len = read_u32_le(&bytes, 10) as usize;
        let payload = &bytes[14 + metadata_len..];
        let expected = model.data.compress().expect("Should compress");
        assert_eq!(payload, expected.as_slice());
        assert_eq!(ModelData::decompress(payload).expect("Should decompress"), model.data);
    }

    #[test]
    fn test_quality_assessment() {
        let model = AprModel::new_test_model();
        let quality = model.assess_quality();

        assert!(quality.test_retest_reliability >= 0.70);
        assert!(quality.content_validity_adequate);
        assert!(quality.responsiveness_cohens_d >= 0.30);
        assert!(quality.meets_minimum_standards());
    }

    #[test]
    fn test_quality_below_threshold_fails() {
        let quality = ModelQualityAssessment {
            test_retest_reliability: 0.69,
            content_validity_adequate: true,
            responsiveness_cohens_d: 0.50,
        };
        assert!(!quality.meets_minimum_standards());
    }
}