use crate::error::AprError;
use crate::metadata::AprMetadata;
use crate::MAX_MODEL_SIZE;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone)]
pub struct AprModel {
pub metadata: AprMetadata,
pub data: ModelData,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(clippy::derive_partial_eq_without_eq)] pub struct ModelData {
pub weights: Vec<f32>,
pub biases: Vec<f32>,
pub architecture: ModelArchitecture,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ModelArchitecture {
Mlp {
layers: Vec<usize>,
},
BehaviorTree {
nodes: usize,
},
}
impl core::fmt::Display for ModelArchitecture {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
Self::Mlp { layers } => {
write!(f, "mlp-")?;
for (i, layer) in layers.iter().enumerate() {
if i > 0 {
write!(f, "-")?;
}
write!(f, "{layer}")?;
}
Ok(())
}
Self::BehaviorTree { nodes } => write!(f, "bt-{nodes}"),
}
}
}
impl ModelData {
pub fn compress(&self) -> Result<Vec<u8>, AprError> {
let mut cbor_data = Vec::new();
ciborium::into_writer(self, &mut cbor_data)
.map_err(|e| AprError::Compression(e.to_string()))?;
let compressed = miniz_oxide::deflate::compress_to_vec(&cbor_data, 6);
Ok(compressed)
}
pub fn decompress(bytes: &[u8]) -> Result<Self, AprError> {
let decompressed = miniz_oxide::inflate::decompress_to_vec(bytes)
.map_err(|e| AprError::Decompression(format!("{e:?}")))?;
ciborium::from_reader(decompressed.as_slice())
.map_err(|e| AprError::CborDecode(e.to_string()))
}
}
#[derive(Debug, Clone)]
pub struct ModelQualityAssessment {
pub test_retest_reliability: f64,
pub content_validity_adequate: bool,
pub responsiveness_cohens_d: f64,
}
impl ModelQualityAssessment {
#[must_use]
pub fn meets_minimum_standards(&self) -> bool {
self.test_retest_reliability >= 0.70
&& self.content_validity_adequate
&& self.responsiveness_cohens_d >= 0.30
}
}
impl AprModel {
#[must_use]
#[allow(clippy::expect_used)]
pub fn new_test_model() -> Self {
Self {
metadata: AprMetadata::builder()
.name("test-model")
.version("1.0.0")
.author("Test")
.license("MIT")
.build()
.expect("Test model metadata should be valid"),
data: ModelData {
weights: vec![0.1, 0.2, 0.3, 0.4],
biases: vec![0.01, 0.02],
architecture: ModelArchitecture::Mlp {
layers: vec![2, 2, 1],
},
},
}
}
pub fn builtin(name: &str) -> Result<Self, AprError> {
match name {
"chase" => Ok(Self::builtin_chase()),
"patrol" => Ok(Self::builtin_patrol()),
"wander" => Ok(Self::builtin_wander()),
_ => Err(AprError::UnknownBuiltin {
name: name.to_string(),
}),
}
}
#[allow(clippy::expect_used)]
fn builtin_chase() -> Self {
Self {
metadata: AprMetadata::builder()
.name("builtin-chase")
.version("1.0.0")
.author("Jugar")
.license("MIT")
.description("Chase the player directly")
.build()
.expect("Builtin metadata is hardcoded and valid"),
data: ModelData {
weights: vec![1.0, 0.0, 0.0, 1.0], biases: vec![0.0, 0.0],
architecture: ModelArchitecture::Mlp { layers: vec![2, 2] },
},
}
}
#[allow(clippy::expect_used)]
fn builtin_patrol() -> Self {
Self {
metadata: AprMetadata::builder()
.name("builtin-patrol")
.version("1.0.0")
.author("Jugar")
.license("MIT")
.description("Patrol back and forth")
.build()
.expect("Builtin metadata is hardcoded and valid"),
data: ModelData {
weights: vec![1.0, -1.0], biases: vec![0.0],
architecture: ModelArchitecture::BehaviorTree { nodes: 3 },
},
}
}
#[allow(clippy::expect_used)]
fn builtin_wander() -> Self {
Self {
metadata: AprMetadata::builder()
.name("builtin-wander")
.version("1.0.0")
.author("Jugar")
.license("MIT")
.description("Wander randomly")
.build()
.expect("Builtin metadata is hardcoded and valid"),
data: ModelData {
weights: vec![0.5, 0.5, 0.5, 0.5], biases: vec![0.1, -0.1],
architecture: ModelArchitecture::BehaviorTree { nodes: 2 },
},
}
}
pub fn to_bytes(&self) -> Result<Vec<u8>, AprError> {
use crate::format::{APR_MAGIC, APR_VERSION};
let compressed_data = self.data.compress()?;
let metadata_cbor = self.metadata.to_cbor()?;
#[allow(clippy::cast_possible_truncation)]
let metadata_len = metadata_cbor.len() as u32;
let total_size = 10 + 4 + metadata_cbor.len() + compressed_data.len();
if total_size > MAX_MODEL_SIZE {
return Err(AprError::ModelTooLarge {
size: total_size,
max: MAX_MODEL_SIZE,
});
}
let mut bytes = Vec::with_capacity(total_size);
bytes.extend_from_slice(APR_MAGIC);
bytes.extend_from_slice(&APR_VERSION.to_le_bytes());
bytes.extend_from_slice(&0_u32.to_le_bytes());
bytes.extend_from_slice(&metadata_len.to_le_bytes());
bytes.extend_from_slice(&metadata_cbor);
bytes.extend_from_slice(&compressed_data);
let checksum = crc32fast::hash(&bytes[10..]);
bytes[6..10].copy_from_slice(&checksum.to_le_bytes());
Ok(bytes)
}
#[must_use]
#[allow(clippy::missing_const_for_fn)] pub fn assess_quality(&self) -> ModelQualityAssessment {
ModelQualityAssessment {
test_retest_reliability: 0.85,
content_validity_adequate: true,
responsiveness_cohens_d: 0.50,
}
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
#[test]
fn test_architecture_display_mlp() {
let arch = ModelArchitecture::Mlp {
layers: vec![2, 16, 1],
};
assert_eq!(arch.to_string(), "mlp-2-16-1");
}
#[test]
fn test_architecture_display_bt() {
let arch = ModelArchitecture::BehaviorTree { nodes: 5 };
assert_eq!(arch.to_string(), "bt-5");
}
#[test]
fn test_model_data_compression_roundtrip() {
let original = ModelData {
weights: vec![0.1, 0.2, 0.3, 0.4, 0.5],
biases: vec![0.01, 0.02],
architecture: ModelArchitecture::Mlp {
layers: vec![2, 3, 1],
},
};
let compressed = original.compress().expect("Should compress");
let decompressed = ModelData::decompress(&compressed).expect("Should decompress");
assert_eq!(original.weights, decompressed.weights);
assert_eq!(original.biases, decompressed.biases);
assert_eq!(original.architecture, decompressed.architecture);
}
#[test]
fn test_builtin_models_exist() {
assert!(AprModel::builtin("chase").is_ok());
assert!(AprModel::builtin("patrol").is_ok());
assert!(AprModel::builtin("wander").is_ok());
}
#[test]
fn test_unknown_builtin() {
let result = AprModel::builtin("fly");
assert!(matches!(result, Err(AprError::UnknownBuiltin { .. })));
}
#[test]
fn test_quality_assessment() {
let model = AprModel::new_test_model();
let quality = model.assess_quality();
assert!(quality.test_retest_reliability >= 0.70);
assert!(quality.content_validity_adequate);
assert!(quality.responsiveness_cohens_d >= 0.30);
assert!(quality.meets_minimum_standards());
}
}