use crate::{AiError, Result};
use glam::Vec2;
use jugar_apr::{AprModel, ModelArchitecture, ModelData};
use std::collections::HashMap;
/// AI configuration for a single agent: which model drives it, its mutable behavior
/// state, and a difficulty setting.
#[derive(Clone, Debug)]
pub struct AiComponent {
    /// Key naming the model to run (this module's tests use keys such as `"builtin:chase"`).
    /// NOTE(review): presumably resolved through `YamlAiBridge::resolve` — confirm at call sites.
    pub model_id: String,
    /// Mutable behavior state owned by this agent.
    pub state: BehaviorState,
    /// Difficulty level; `AiComponent::new` sets 5 and `with_difficulty` clamps to `1..=10`.
    pub difficulty: u8,
}
impl AiComponent {
    /// Creates a component bound to `model_id` with default behavior state and a
    /// mid-range difficulty of 5.
    #[must_use]
    pub fn new(model_id: impl Into<String>) -> Self {
        let model_id: String = model_id.into();
        Self {
            model_id,
            state: BehaviorState::default(),
            difficulty: 5,
        }
    }

    /// Returns the component with its difficulty set to `difficulty`, clamped into `1..=10`.
    ///
    /// Written as a `match` (rather than `Ord::clamp`) so the function can stay `const`.
    #[must_use]
    pub const fn with_difficulty(mut self, difficulty: u8) -> Self {
        self.difficulty = match difficulty {
            0 => 1,
            1..=10 => difficulty,
            _ => 10,
        };
        self
    }
}
/// Mutable per-agent behavior bookkeeping; every field starts at zero via `Default`.
///
/// NOTE(review): no code in this module reads or writes these fields, so the
/// descriptions below are inferred from the names only — confirm against the callers.
#[derive(Clone, Debug, Default)]
pub struct BehaviorState {
    /// Presumably the agent's current heading.
    pub direction: Vec2,
    /// Presumably time spent in the current state (units unconfirmed).
    pub state_time: f32,
    /// Presumably the index of the next patrol waypoint.
    pub waypoint_index: usize,
    /// Scratch value for behavior implementations — meaning unconfirmed.
    pub internal_state: f32,
}
/// Observations handed to a model for one inference step.
///
/// Prefer `AiInputs::from_positions`, which keeps the derived distance/direction
/// fields consistent with the two positions.
#[derive(Clone, Debug, Default)]
pub struct AiInputs {
    /// The agent's current position.
    pub position: Vec2,
    /// The position the agent is reasoning about (e.g. the thing it chases).
    pub target_position: Vec2,
    /// Euclidean distance from `position` to `target_position`.
    pub distance_to_target: f32,
    /// Unit vector toward the target, or zero when the points (nearly) coincide.
    pub direction_to_target: Vec2,
    /// Time step for this update (the tests pass `0.016`).
    pub dt: f32,
}
impl AiInputs {
    /// Builds inputs for an agent at `position` observing `target`.
    ///
    /// The distance is the length of `target - position`. The direction is that offset
    /// normalized, or [`Vec2::ZERO`] when the distance is not greater than 0.001 (or is NaN),
    /// where dividing by the length would be unstable.
    #[must_use]
    pub fn from_positions(position: Vec2, target: Vec2, dt: f32) -> Self {
        let offset = target - position;
        let distance_to_target = offset.length();
        let direction_to_target = Some(distance_to_target)
            .filter(|&len| len > 0.001)
            .map_or(Vec2::ZERO, |len| offset / len);

        Self {
            position,
            target_position: target,
            distance_to_target,
            direction_to_target,
            dt,
        }
    }

    /// Flattens the inputs into the 4-element feature vector used by MLP models:
    /// `[dir.x, dir.y, distance / 100, dt]`.
    #[must_use]
    pub fn to_vector(&self) -> Vec<f32> {
        let dir = self.direction_to_target;
        // Distance is divided by 100, presumably to keep it near the unit range of the
        // direction components — confirm against how the model weights were trained.
        let scaled_distance = self.distance_to_target / 100.0;
        [dir.x, dir.y, scaled_distance, self.dt].to_vec()
    }
}
/// Decision produced by a model for one inference step.
#[derive(Clone, Debug, Default)]
pub struct AiOutputs {
    /// Desired movement direction; `AiOutputs::from_raw` yields a unit vector or zero.
    pub movement: Vec2,
    /// Speed factor; `AiOutputs::from_raw` keeps it within `0.0..=1.0`.
    pub speed: f32,
    /// Whether the agent should perform its action this step.
    pub action: bool,
}
impl AiOutputs {
    /// Decodes a raw model output vector.
    ///
    /// * `values[0..2]` → `movement`, normalized with `normalize_or_zero` (zero when fewer
    ///   than two values are present, or the vector is zero-length or non-finite).
    /// * `values[2]` → `speed`, clamped to `0.0..=1.0`; defaults to `1.0` when absent and
    ///   becomes `0.0` when NaN.
    /// * `values[3]` → `action`, true when strictly greater than `0.5`; false when absent.
    #[must_use]
    pub fn from_raw(values: &[f32]) -> Self {
        let movement = match values {
            [x, y, ..] => Vec2::new(*x, *y).normalize_or_zero(),
            _ => Vec2::ZERO,
        };
        // `f32::clamp` returns NaN unchanged, so a diverged model would otherwise leak a
        // NaN speed to callers even though `movement` is already sanitized above.
        let speed = match values.get(2) {
            None => 1.0,
            Some(raw) if raw.is_nan() => 0.0,
            Some(raw) => raw.clamp(0.0, 1.0),
        };
        let action = values.get(3).is_some_and(|&raw| raw > 0.5);
        Self {
            movement,
            speed,
            action,
        }
    }
}
/// Registry of loaded AI models keyed by a caller-chosen id, with inference entry points.
#[derive(Default, Debug)]
pub struct AiSystem {
    /// Loaded models by id; registering an existing id replaces the previous model.
    models: HashMap<String, LoadedModel>,
}
/// A registered model together with its weights pre-sliced per MLP layer.
#[derive(Clone, Debug)]
struct LoadedModel {
    /// The model as loaded or constructed.
    model: AprModel,
    /// One entry per MLP layer transition; empty for behavior-tree models.
    layer_weights: Vec<LayerWeights>,
}
/// Dense weights for one fully connected MLP layer.
#[derive(Clone, Debug)]
struct LayerWeights {
    /// Row-major `output_size × input_size` matrix: entry `(row, col)` lives at
    /// `row * input_size + col`.
    weights: Vec<f32>,
    /// One bias per output neuron.
    biases: Vec<f32>,
    /// Number of inputs this layer consumes.
    input_size: usize,
    /// Number of outputs this layer produces.
    output_size: usize,
}
impl AiSystem {
    /// Creates an empty system with no models registered.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Reads the `.apr` file at `path`, parses it, and registers its model under `id`.
    ///
    /// # Errors
    ///
    /// Returns [`AiError::PreconditionsNotMet`] when the file cannot be read or parsed, or
    /// when [`Self::register_model`] rejects the model.
    pub fn load_model_from_file(&mut self, id: &str, path: &str) -> Result<()> {
        let bytes = std::fs::read(path).map_err(|e| AiError::PreconditionsNotMet(e.to_string()))?;
        let apr_file = jugar_apr::AprFile::from_bytes(&bytes)
            .map_err(|e| AiError::PreconditionsNotMet(e.to_string()))?;
        self.register_model(id, apr_file.model)
    }

    /// Looks up the built-in model `builtin_name` via `AprModel::builtin` and registers it
    /// under `id`.
    ///
    /// # Errors
    ///
    /// Returns [`AiError::PreconditionsNotMet`] when no built-in has that name, or when
    /// [`Self::register_model`] rejects the model.
    pub fn load_builtin(&mut self, id: &str, builtin_name: &str) -> Result<()> {
        let model = AprModel::builtin(builtin_name)
            .map_err(|e| AiError::PreconditionsNotMet(e.to_string()))?;
        self.register_model(id, model)
    }

    /// Pre-slices `model`'s weights per layer and stores it under `id`, replacing any model
    /// already registered with that id.
    ///
    /// # Errors
    ///
    /// Returns [`AiError::PreconditionsNotMet`] when an MLP has fewer than two layers or its
    /// layer sizes overflow `usize`.
    pub fn register_model(&mut self, id: &str, model: AprModel) -> Result<()> {
        let layer_weights = Self::prepare_weights(&model.data)?;
        let loaded = LoadedModel {
            model,
            layer_weights,
        };
        let _ = self.models.insert(id.to_string(), loaded);
        Ok(())
    }

    /// Splits a model's flat weight/bias arrays into one [`LayerWeights`] per consecutive
    /// pair of MLP layer sizes. Behavior-tree models need no weights and yield an empty list.
    ///
    /// When the stored arrays are too short for a layer, that layer falls back to weights
    /// of `0.1` and biases of `0.0` instead of failing.
    ///
    /// # Errors
    ///
    /// Returns [`AiError::PreconditionsNotMet`] when the MLP has fewer than two layers, or
    /// when the layer sizes (which may come from an on-disk file) overflow `usize`.
    fn prepare_weights(data: &ModelData) -> Result<Vec<LayerWeights>> {
        match &data.architecture {
            ModelArchitecture::Mlp { layers } => {
                if layers.len() < 2 {
                    return Err(AiError::PreconditionsNotMet(
                        "MLP needs at least 2 layers".to_string(),
                    ));
                }
                // Non-capturing closure, so it is `Copy` and can be reused for every check.
                let size_overflow = || {
                    AiError::PreconditionsNotMet("MLP layer sizes overflow usize".to_string())
                };
                let mut result = Vec::with_capacity(layers.len() - 1);
                let mut weight_offset: usize = 0;
                let mut bias_offset: usize = 0;
                for pair in layers.windows(2) {
                    let (input_size, output_size) = (pair[0], pair[1]);
                    // Checked so that hostile/corrupt sizes error out instead of wrapping
                    // (release) or panicking (debug).
                    let weight_count = input_size
                        .checked_mul(output_size)
                        .ok_or_else(size_overflow)?;
                    let weight_end = weight_offset
                        .checked_add(weight_count)
                        .ok_or_else(size_overflow)?;
                    let bias_end = bias_offset
                        .checked_add(output_size)
                        .ok_or_else(size_overflow)?;
                    let weights = if weight_end <= data.weights.len() {
                        data.weights[weight_offset..weight_end].to_vec()
                    } else {
                        vec![0.1; weight_count]
                    };
                    let biases = if bias_end <= data.biases.len() {
                        data.biases[bias_offset..bias_end].to_vec()
                    } else {
                        vec![0.0; output_size]
                    };
                    result.push(LayerWeights {
                        weights,
                        biases,
                        input_size,
                        output_size,
                    });
                    weight_offset = weight_end;
                    bias_offset = bias_end;
                }
                Ok(result)
            }
            ModelArchitecture::BehaviorTree { .. } => Ok(Vec::new()),
        }
    }

    /// Runs the model registered under `model_id` on `inputs`.
    ///
    /// MLP models are fed [`AiInputs::to_vector`] and their raw output is decoded with
    /// [`AiOutputs::from_raw`]; behavior-tree models dispatch on the model's metadata name.
    ///
    /// # Errors
    ///
    /// Returns [`AiError::PreconditionsNotMet`] when no model has that id, or when a
    /// behavior-tree model's name is not a known built-in behavior.
    pub fn infer(&self, model_id: &str, inputs: &AiInputs) -> Result<AiOutputs> {
        let loaded = self
            .models
            .get(model_id)
            .ok_or_else(|| AiError::PreconditionsNotMet(format!("Model not found: {model_id}")))?;
        match &loaded.model.data.architecture {
            ModelArchitecture::Mlp { .. } => {
                let raw_outputs =
                    Self::run_mlp_inference(&loaded.layer_weights, &inputs.to_vector());
                Ok(AiOutputs::from_raw(&raw_outputs))
            }
            ModelArchitecture::BehaviorTree { .. } => {
                Self::run_behavior_inference(&loaded.model.metadata.name, inputs)
            }
        }
    }

    /// Forward pass of a dense MLP.
    ///
    /// Every layer computes `max(0, W·x + b)` (ReLU, including the final layer) and `tanh`
    /// is applied to the last layer's result. Each layer reads at most `input_size`
    /// activations; missing activations, weights, or biases contribute zero.
    ///
    /// NOTE(review): ReLU on the output layer followed by `tanh` keeps every output in
    /// `[0, 1)`, so MLP-driven movement can never point in a negative direction. This is
    /// kept as-is for compatibility with existing weights and tests; revisit deliberately.
    fn run_mlp_inference(layers: &[LayerWeights], input: &[f32]) -> Vec<f32> {
        let mut activations = input.to_vec();
        for layer in layers {
            let next: Vec<f32> = (0..layer.output_size)
                .map(|row| {
                    let bias = layer.biases.get(row).copied().unwrap_or(0.0);
                    let row_start = row * layer.input_size;
                    // Limiting to `input_size` columns keeps each row inside its own slice
                    // of the row-major matrix; without it, extra activations would be
                    // multiplied by weights that belong to the next output row.
                    let pre_activation = activations
                        .iter()
                        .take(layer.input_size)
                        .enumerate()
                        .fold(bias, |acc, (col, &x)| {
                            let weight =
                                layer.weights.get(row_start + col).copied().unwrap_or(0.0);
                            acc + weight * x
                        });
                    pre_activation.max(0.0)
                })
                .collect();
            activations = next;
        }
        activations.iter().map(|&x| x.tanh()).collect()
    }

    /// Evaluates one of the hard-coded behaviors selected by model name.
    ///
    /// * `builtin-chase`: move along `direction_to_target` at full speed; act when the
    ///   target is closer than 50 units.
    /// * `builtin-patrol`: move horizontally at half speed, direction given by the sign of
    ///   `sin(position.x / 100)`.
    /// * `builtin-wander`: move at speed 0.3 along an angle derived deterministically from
    ///   the current position.
    ///
    /// # Errors
    ///
    /// Returns [`AiError::PreconditionsNotMet`] for any other name.
    fn run_behavior_inference(behavior_name: &str, inputs: &AiInputs) -> Result<AiOutputs> {
        match behavior_name {
            "builtin-chase" => Ok(AiOutputs {
                movement: inputs.direction_to_target,
                speed: 1.0,
                action: inputs.distance_to_target < 50.0,
            }),
            "builtin-patrol" => {
                let phase = (inputs.position.x / 100.0).sin();
                Ok(AiOutputs {
                    movement: Vec2::new(phase.signum(), 0.0),
                    speed: 0.5,
                    action: false,
                })
            }
            "builtin-wander" => {
                // Plain multiply-add is kept (rather than `mul_add`) for readability; the
                // lint would otherwise suggest the fused form.
                #[allow(clippy::suboptimal_flops)]
                let angle = (inputs.position.x * 0.1 + inputs.position.y * 0.07).sin()
                    * core::f32::consts::PI;
                Ok(AiOutputs {
                    movement: Vec2::new(angle.cos(), angle.sin()),
                    speed: 0.3,
                    action: false,
                })
            }
            _ => Err(AiError::PreconditionsNotMet(format!(
                "Unknown behavior: {behavior_name}"
            ))),
        }
    }

    /// Returns `true` if a model is registered under `id`.
    #[must_use]
    pub fn has_model(&self, id: &str) -> bool {
        self.models.contains_key(id)
    }

    /// Number of registered models.
    #[must_use]
    pub fn model_count(&self) -> usize {
        self.models.len()
    }

    /// Removes the model registered under `id`; returns whether one was present.
    pub fn unload_model(&mut self, id: &str) -> bool {
        self.models.remove(id).is_some()
    }
}
/// Translates AI behavior keys found in YAML into model ids registered with an `AiSystem`.
#[derive(Default, Debug)]
pub struct YamlAiBridge {
    /// User-registered aliases: YAML key → path of the model file to load for it.
    custom_models: HashMap<String, String>,
}
impl YamlAiBridge {
    /// Creates a bridge with no custom aliases.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Maps `yaml_key` to the model file at `path`; a later registration for the same key
    /// replaces the earlier path.
    pub fn register_custom(&mut self, yaml_key: &str, path: &str) {
        let (key, file) = (yaml_key.to_string(), path.to_string());
        let _ = self.custom_models.insert(key, file);
    }

    /// Resolves `yaml_key` to a model id, loading the model into `system` on first use.
    ///
    /// Keys are tried in this order:
    /// 1. `builtin:<name>` → id `builtin-<name>`, loaded with `AiSystem::load_builtin`.
    /// 2. A path with an `.apr` extension (case-insensitive) → id is the path with `/`, `\`
    ///    and `.` replaced by `_`, loaded from that file.
    ///    NOTE(review): distinct paths can sanitize to the same id (`a/b.apr` vs
    ///    `a_b.apr`), in which case the second reuses the first's cached model.
    /// 3. A key registered via [`Self::register_custom`] → the key itself is the id.
    /// 4. The bare names `chase`, `patrol`, `wander` → treated like `builtin:<name>`.
    ///
    /// # Errors
    ///
    /// Propagates loading errors, and returns [`AiError::PreconditionsNotMet`] when the
    /// key matches none of the forms above.
    pub fn resolve(&self, yaml_key: &str, system: &mut AiSystem) -> Result<String> {
        if let Some(name) = yaml_key.strip_prefix("builtin:") {
            return Self::ensure_builtin(system, name);
        }

        let is_apr_file = std::path::Path::new(yaml_key)
            .extension()
            .is_some_and(|ext| ext.eq_ignore_ascii_case("apr"));
        if is_apr_file {
            let sanitized_id = yaml_key.replace(['/', '\\', '.'], "_");
            return Self::ensure_file(system, sanitized_id, yaml_key);
        }

        if let Some(path) = self.custom_models.get(yaml_key) {
            return Self::ensure_file(system, yaml_key.to_string(), path);
        }

        match yaml_key {
            "chase" | "patrol" | "wander" => Self::ensure_builtin(system, yaml_key),
            _ => Err(AiError::PreconditionsNotMet(format!(
                "Unknown AI behavior: {yaml_key}"
            ))),
        }
    }

    /// Registers built-in `name` as `builtin-<name>` unless already present; returns the id.
    fn ensure_builtin(system: &mut AiSystem, name: &str) -> Result<String> {
        let id = format!("builtin-{name}");
        if !system.has_model(&id) {
            system.load_builtin(&id, name)?;
        }
        Ok(id)
    }

    /// Loads the file at `path` under `id` unless that id is already present; returns `id`.
    fn ensure_file(system: &mut AiSystem, id: String, path: &str) -> Result<String> {
        if !system.has_model(&id) {
            system.load_model_from_file(&id, path)?;
        }
        Ok(id)
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    mod ai_component_tests {
        use super::*;

        #[test]
        fn test_ai_component_new() {
            let component = AiComponent::new("builtin:chase");
            assert_eq!(component.model_id, "builtin:chase");
            assert_eq!(component.difficulty, 5);
        }

        #[test]
        fn test_ai_component_with_difficulty() {
            let component = AiComponent::new("chase").with_difficulty(8);
            assert_eq!(component.difficulty, 8);
        }

        #[test]
        fn test_ai_component_difficulty_clamped() {
            let low = AiComponent::new("chase").with_difficulty(0);
            assert_eq!(low.difficulty, 1);
            let high = AiComponent::new("chase").with_difficulty(100);
            assert_eq!(high.difficulty, 10);
        }
    }

    mod ai_inputs_tests {
        use super::*;

        #[test]
        fn test_from_positions() {
            let inputs =
                AiInputs::from_positions(Vec2::new(0.0, 0.0), Vec2::new(100.0, 0.0), 0.016);
            assert!((inputs.distance_to_target - 100.0).abs() < 0.01);
            assert!((inputs.direction_to_target.x - 1.0).abs() < 0.01);
            assert!(inputs.direction_to_target.y.abs() < 0.01);
        }

        #[test]
        fn test_from_positions_same_point() {
            let inputs =
                AiInputs::from_positions(Vec2::new(50.0, 50.0), Vec2::new(50.0, 50.0), 0.016);
            assert!(inputs.distance_to_target < 0.001);
            assert_eq!(inputs.direction_to_target, Vec2::ZERO);
        }

        #[test]
        fn test_to_vector() {
            let inputs = AiInputs::from_positions(Vec2::ZERO, Vec2::new(100.0, 0.0), 0.016);
            let vec = inputs.to_vector();
            assert_eq!(vec.len(), 4);
            assert!((vec[0] - 1.0).abs() < 0.01);
            assert!(vec[1].abs() < 0.01);
            // Distance 100 is scaled by 1/100.
            assert!((vec[2] - 1.0).abs() < 0.01);
        }
    }

    mod ai_outputs_tests {
        use super::*;

        #[test]
        fn test_from_raw() {
            let outputs = AiOutputs::from_raw(&[0.5, 0.5, 0.8, 0.9]);
            assert!(outputs.movement.length() > 0.0);
            assert!((outputs.speed - 0.8).abs() < 0.01);
            assert!(outputs.action);
        }

        #[test]
        fn test_from_raw_empty() {
            let outputs = AiOutputs::from_raw(&[]);
            assert_eq!(outputs.movement, Vec2::ZERO);
            assert!((outputs.speed - 1.0).abs() < 0.01);
            assert!(!outputs.action);
        }

        #[test]
        fn test_from_raw_speed_clamped() {
            let outputs = AiOutputs::from_raw(&[0.0, 0.0, 2.0]);
            assert!((outputs.speed - 1.0).abs() < 0.01);
            let outputs2 = AiOutputs::from_raw(&[0.0, 0.0, -1.0]);
            assert!(outputs2.speed.abs() < 0.01);
        }
    }

    mod ai_system_tests {
        use super::*;

        #[test]
        fn test_new_system() {
            let system = AiSystem::new();
            assert_eq!(system.model_count(), 0);
        }

        #[test]
        fn test_load_builtin_chase() {
            let mut system = AiSystem::new();
            system.load_builtin("chase", "chase").unwrap();
            assert!(system.has_model("chase"));
            assert_eq!(system.model_count(), 1);
        }

        #[test]
        fn test_load_builtin_patrol() {
            let mut system = AiSystem::new();
            system.load_builtin("patrol", "patrol").unwrap();
            assert!(system.has_model("patrol"));
        }

        #[test]
        fn test_load_builtin_wander() {
            let mut system = AiSystem::new();
            system.load_builtin("wander", "wander").unwrap();
            assert!(system.has_model("wander"));
        }

        #[test]
        fn test_load_unknown_builtin() {
            let mut system = AiSystem::new();
            let result = system.load_builtin("unknown", "unknown");
            assert!(result.is_err());
        }

        #[test]
        fn test_register_model() {
            let mut system = AiSystem::new();
            let model = AprModel::new_test_model();
            system.register_model("test", model).unwrap();
            assert!(system.has_model("test"));
        }

        #[test]
        fn test_unload_model() {
            let mut system = AiSystem::new();
            system.load_builtin("chase", "chase").unwrap();
            assert!(system.unload_model("chase"));
            assert!(!system.has_model("chase"));
        }

        #[test]
        fn test_infer_chase() {
            let mut system = AiSystem::new();
            system.load_builtin("chase", "chase").unwrap();
            let inputs =
                AiInputs::from_positions(Vec2::new(0.0, 0.0), Vec2::new(100.0, 0.0), 0.016);
            let outputs = system.infer("chase", &inputs).unwrap();
            assert!(outputs.movement.x > 0.0);
            assert!((outputs.speed - 1.0).abs() < 0.01);
        }

        #[test]
        fn test_infer_patrol() {
            let mut system = AiSystem::new();
            system.load_builtin("patrol", "patrol").unwrap();
            let inputs =
                AiInputs::from_positions(Vec2::new(50.0, 0.0), Vec2::new(0.0, 0.0), 0.016);
            let outputs = system.infer("patrol", &inputs).unwrap();
            assert!(outputs.movement.length() > 0.0);
            assert!((outputs.speed - 0.5).abs() < 0.01);
        }

        #[test]
        fn test_infer_wander() {
            let mut system = AiSystem::new();
            system.load_builtin("wander", "wander").unwrap();
            let inputs =
                AiInputs::from_positions(Vec2::new(25.0, 75.0), Vec2::new(0.0, 0.0), 0.016);
            let outputs = system.infer("wander", &inputs).unwrap();
            assert!(outputs.movement.length() > 0.0);
            assert!((outputs.speed - 0.3).abs() < 0.01);
        }

        #[test]
        fn test_infer_mlp_model() {
            let mut system = AiSystem::new();
            let model = AprModel::new_test_model();
            system.register_model("mlp", model).unwrap();
            let inputs =
                AiInputs::from_positions(Vec2::new(0.0, 0.0), Vec2::new(50.0, 50.0), 0.016);
            let outputs = system.infer("mlp", &inputs).unwrap();
            // The previous `length() >= 0.0` check could never fail. `AiOutputs::from_raw`
            // guarantees movement is either (near) zero or unit length, and speed in [0, 1].
            let len = outputs.movement.length();
            assert!(
                len < 1e-6 || (len - 1.0).abs() < 1e-4,
                "unexpected movement length {len}"
            );
            assert!((0.0..=1.0).contains(&outputs.speed));
        }

        #[test]
        fn test_infer_unknown_model() {
            let system = AiSystem::new();
            let inputs = AiInputs::default();
            let result = system.infer("nonexistent", &inputs);
            assert!(result.is_err());
        }
    }

    mod yaml_bridge_tests {
        use super::*;

        #[test]
        fn test_resolve_builtin_prefix() {
            let bridge = YamlAiBridge::new();
            let mut system = AiSystem::new();
            let id = bridge.resolve("builtin:chase", &mut system).unwrap();
            assert_eq!(id, "builtin-chase");
            assert!(system.has_model("builtin-chase"));
        }

        #[test]
        fn test_resolve_simple_builtin() {
            let bridge = YamlAiBridge::new();
            let mut system = AiSystem::new();
            let id = bridge.resolve("patrol", &mut system).unwrap();
            assert_eq!(id, "builtin-patrol");
            assert!(system.has_model("builtin-patrol"));
        }

        #[test]
        fn test_resolve_all_builtins() {
            let bridge = YamlAiBridge::new();
            let mut system = AiSystem::new();
            let _ = bridge.resolve("chase", &mut system).unwrap();
            let _ = bridge.resolve("patrol", &mut system).unwrap();
            let _ = bridge.resolve("wander", &mut system).unwrap();
            assert_eq!(system.model_count(), 3);
        }

        #[test]
        fn test_resolve_unknown() {
            let bridge = YamlAiBridge::new();
            let mut system = AiSystem::new();
            let result = bridge.resolve("unknown_behavior", &mut system);
            assert!(result.is_err());
        }

        #[test]
        fn test_resolve_caches_model() {
            let bridge = YamlAiBridge::new();
            let mut system = AiSystem::new();
            let _ = bridge.resolve("builtin:chase", &mut system).unwrap();
            let _ = bridge.resolve("builtin:chase", &mut system).unwrap();
            assert_eq!(system.model_count(), 1);
        }

        #[test]
        fn test_register_custom() {
            let mut bridge = YamlAiBridge::new();
            bridge.register_custom("smart-ghost", "models/ghost.apr");
            assert!(!bridge.custom_models.is_empty());
        }
    }

    mod mlp_inference_tests {
        use super::*;

        #[test]
        fn test_simple_mlp() {
            // Identity weights: the ReLU zeroes the negative input.
            let layers = vec![LayerWeights {
                weights: vec![1.0, 0.0, 0.0, 1.0],
                biases: vec![0.0, 0.0],
                input_size: 2,
                output_size: 2,
            }];
            let input = vec![0.5, -0.5];
            let output = AiSystem::run_mlp_inference(&layers, &input);
            assert!(output[0] > 0.0);
            assert!(output[1].abs() < 0.01);
        }

        #[test]
        fn test_multi_layer_mlp() {
            let layers = vec![
                LayerWeights {
                    weights: vec![0.5, 0.5, 0.5, 0.5],
                    biases: vec![0.0, 0.0],
                    input_size: 2,
                    output_size: 2,
                },
                LayerWeights {
                    weights: vec![1.0, 1.0],
                    biases: vec![0.0],
                    input_size: 2,
                    output_size: 1,
                },
            ];
            let input = vec![1.0, 1.0];
            let output = AiSystem::run_mlp_inference(&layers, &input);
            assert_eq!(output.len(), 1);
            assert!(output[0] > 0.0);
        }
    }
}